# Author: jongmmm
# Date: 3/6/2017 - 2:21 AM
#
# File: KD_Tree.py

'''
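k-d tree construction and insertion; the construction routine is based on the example code in Wikipedia's "k-d tree" article.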
'''
from collections import namedtuple
from operator import itemgetter
from pprint import pformat

class Node:
    """A single k-d tree node: a point plus its left and right subtrees."""

    def __init__(self, location, left, right):
        # location: the point stored here; left/right: child Node or None.
        self.location, self.left, self.right = location, left, right

    def __repr__(self):
        # Render as a nested (location, left, right) tuple via pformat.
        triple = (self.location, self.left, self.right)
        return pformat(triple)

def init_kdtree(point_list, depth=0):
    """Build a balanced k-d tree from *point_list* and return its root Node.

    At each level the points are sorted on axis ``depth % k``. The median
    becomes the node, and the lower/upper halves are built recursively one
    level deeper, so successive levels split on successive coordinates.

    Args:
        point_list: iterable of points (e.g. tuples); assumed to all have the
            same dimension as the first one. It is copied, so the caller's
            list is never reordered.
        depth: depth of the returned root; selects its splitting axis.

    Returns:
        The root ``Node``, or ``None`` when *point_list* is empty.
    """
    # Copy first: sorting the caller's list in place was a hidden side effect,
    # and copying also lets us accept any iterable, not just sequences.
    points = list(point_list)
    if not points:
        return None
    k = len(points[0])  # assumes all points have the same dimension

    # Select axis based on depth so that axis cycles through all valid values
    axis = depth % k

    # Sort the private copy and choose the median as pivot element.
    # itemgetter is faster than lambda tup: tup[axis]
    points.sort(key=itemgetter(axis))
    median = len(points) // 2

    # Create node and construct subtrees from the halves on either side.
    return Node(
        location=points[median],
        left=init_kdtree(points[:median], depth + 1),
        right=init_kdtree(points[median + 1:], depth + 1),
    )

class kdtree:
    """A k-d tree over equal-length points that supports insertion.

    Attributes:
        k: dimensionality of the points, or ``None`` for a tree built from an
            empty list (set by the first ``insert``).
        tree: root ``Node``, or ``None`` when the tree is empty.
    """

    def __init__(self, point_list, depth=0):
        """Build the tree from *point_list*, which is copied, not reordered.

        Args:
            point_list: iterable of points; may be empty.
            depth: depth assigned to the root, selecting its split axis
                (``depth % k``). Later insertions start from this same depth
                so their axes agree with the constructed tree.
        """
        points = list(point_list)
        # assumes all points have the same dimension
        self.k = len(points[0]) if points else None
        self._root_depth = depth
        self.tree = init_kdtree(points, depth) if points else None

    def __repr__(self):
        return pformat(self.tree)

    def insert(self, point, subtree=None, depth=None):
        """Insert *point* below *subtree* (the root by default).

        Walks down iteratively, comparing on axis ``depth % k`` at each
        level. Smaller coordinates go left, larger go right, and ties go
        left. A point equal to an existing node's location is not inserted
        again.

        Args:
            point: the point to add; assumed to have ``k`` coordinates.
            subtree: node to start from; ``None`` means the root.
            depth: depth of *subtree*. Defaults to the tree's construction
                depth when starting at the root, and to 0 otherwise.

        Raises:
            ValueError: if a coordinate is neither <, > nor == the node's
                (e.g. NaN), so the point has no valid position.
        """
        if subtree is None:
            if self.tree is None:
                # First point into an empty tree becomes the root and fixes k.
                self.k = len(point)
                self.tree = Node(location=point, left=None, right=None)
                return
            subtree = self.tree
            if depth is None:
                depth = self._root_depth
        elif depth is None:
            depth = 0

        node = subtree
        while True:
            axis = depth % self.k
            coord, pivot = point[axis], node.location[axis]
            if coord < pivot:
                go_left = True
            elif coord > pivot:
                go_left = False
            elif coord == pivot:
                if point == node.location:
                    return  # duplicate point: tree keeps set semantics
                go_left = True  # ties go left
            else:
                raise ValueError(
                    "cannot order coordinate %r against %r on axis %d"
                    % (coord, pivot, axis)
                )

            child = node.left if go_left else node.right
            if child is None:
                # Attach here instead of falling back to the root, which the
                # old recursive version did whenever a tie hit a missing child.
                leaf = Node(location=point, left=None, right=None)
                if go_left:
                    node.left = leaf
                else:
                    node.right = leaf
                return
            node = child
            depth += 1
   
def main():
    """Demo: build a small 2-d tree, print it, insert (3, 2), print again."""
    sample = [(2, 3), (5, 4), (9, 6), (4, 7), (8, 1), (7, 2)]
    demo = kdtree(sample)
    print(demo)
    demo.insert((3, 2))
    print(demo)