'''
k-d tree construction and point insertion, adapted from the example code in
the Wikipedia article "k-d tree".
'''
from collections import namedtuple
from operator import itemgetter
from pprint import pformat
class Node:
    """One k-d tree node: a stored point plus optional left/right children."""

    def __init__(self, location, left, right):
        # The point held by this node and its two child links (None if absent).
        self.location, self.left, self.right = location, left, right

    def __repr__(self):
        """Pretty-print the node as a nested ``(location, left, right)`` tuple."""
        fields = (self.location, self.left, self.right)
        return pformat(fields)
def init_kdtree(point_list, depth=0):
    """Build a balanced k-d tree from *point_list* and return its root node.

    The splitting axis cycles with depth (``depth % k``). At each level the
    points are sorted along that axis, and the median point becomes the node.
    Points sorted before it form the left subtree; points after it form the
    right subtree.

    Args:
        point_list: Sequence of coordinate sequences, all assumed to have the
            same length as the first one. It is not modified.
        depth: Depth of the returned root; selects the first splitting axis.

    Returns:
        The root ``Node``, or ``None`` when *point_list* is empty.
    """
    try:
        k = len(point_list[0])  # assumes all points have the same dimension
    except IndexError:  # empty input -> empty (sub)tree
        return None
    # Select axis based on depth so that axis cycles through all valid values
    axis = depth % k
    # sorted() returns a new list, so the caller's sequence is left untouched;
    # list.sort() would reorder it in place as a side effect.
    # itemgetter is faster than lambda tup: tup[axis]
    ordered = sorted(point_list, key=itemgetter(axis))
    median = len(ordered) // 2  # choose median as pivot element
    # Create node and construct subtrees from the points on either side
    return Node(
        location=ordered[median],
        left=init_kdtree(ordered[:median], depth + 1),
        right=init_kdtree(ordered[median + 1:], depth + 1),
    )
class kdtree:
    """Mutable k-d tree over points of a single dimension ``k``.

    The tree is built balanced by ``init_kdtree`` and can then grow one point
    at a time with ``insert``. A node at depth ``d`` splits on axis
    ``d % k``. ``insert`` sends a point that ties with a node on that axis
    into the node's left subtree, and does not re-insert a point equal to a
    node it visits.
    """

    def __init__(self, point_list, depth=0):
        """Build the tree from *point_list*.

        Args:
            point_list: Iterable of coordinate sequences, all assumed to have
                the same length. It is copied first, so the caller's list is
                not reordered by the build. If it is empty, ``k`` is left as
                ``None`` and is taken from the first inserted point.
            depth: Depth assigned to the root; selects the root's split axis.
        """
        points = list(point_list)
        # assumes all points have the same dimension as the first one
        self.k = len(points[0]) if points else None
        # Remembered so insert() follows the same axis cycle the build used.
        self._root_depth = depth
        self.tree = init_kdtree(points, depth)

    def __repr__(self):
        return pformat(self.tree)

    def insert(self, point, subtree=None, depth=0):
        """Insert *point* into the tree.

        Args:
            point: Coordinate sequence indexable on axes ``0 .. k-1``.
            subtree: Node to start descending from; ``None`` (the default)
                starts at the root.
            depth: Depth of *subtree*, used to pick its split axis. When
                starting at the root it is added to the root's build depth.

        A point equal to a visited node is not inserted again. A point whose
        coordinate compares neither <, > nor == to the split value (e.g. NaN)
        is dropped.
        """
        if subtree is None:
            if self.tree is None:
                # Empty tree: the point becomes the root.
                if self.k is None:
                    self.k = len(point)
                self.tree = Node(location=point, left=None, right=None)
                return
            subtree = self.tree
            depth += self._root_depth
        node = subtree
        # Descend iteratively so that reaching an empty child always attaches
        # the point there; a recursive call with subtree=None would instead
        # restart from the root at the wrong depth.
        while True:
            axis = depth % self.k
            coord, split = point[axis], node.location[axis]
            if coord > split:
                if node.right is None:
                    node.right = Node(location=point, left=None, right=None)
                    return
                node = node.right
            elif coord < split or (coord == split and point != node.location):
                # Ties on the split axis go to the left subtree.
                if node.left is None:
                    node.left = Node(location=point, left=None, right=None)
                    return
                node = node.left
            else:
                # Exact duplicate of this node, or an incomparable coordinate.
                return
            depth += 1
def main():
    """Demonstrate building a 2-d tree and inserting one extra point."""
    samples = [(2, 3), (5, 4), (9, 6), (4, 7), (8, 1), (7, 2)]
    kd = kdtree(samples)
    print(kd)  # tree as built from the samples
    kd.insert((3, 2))
    print(kd)  # tree after the insertion