CodyKochmann
11/26/2016 - 10:22 PM

Binary search tree in Python

A binary search tree that can reduce its depth by rebuilding itself from randomly shuffled copies of its values.

# -*- coding: utf-8 -*-
# @Author: cody
# @Date:   2016-11-26 14:18:19
# @Last Modified 2016-11-26
# @Last Modified time: 2016-11-26 17:19:44

from random import shuffle

class Node:
    """ a single node of a binary search tree

    Attributes:
        l: left child (subtree of smaller values) or None
        r: right child (subtree of larger values) or None
        v: the value stored in this node
        depth: depth of this node, counting a freshly created root as depth 1
    """

    def __init__(self, val, depth=1):
        self.l = None
        self.r = None
        self.v = val
        self.depth = depth

    def __call__(self):
        """ shorthand way of getting the val """
        return self.v

    def add(self, val, depth=None):
        """ inserts val into the subtree rooted at this node

        A value equal to one already in the subtree is ignored, so the tree
        never stores duplicates.

        Args:
            val: value to insert; must be orderable against the stored values
            depth: depth recorded on a node created as a direct child of this
                node; each further level down adds one. Defaults to
                self.depth + 1, the correct depth for a direct child.
        """
        if depth is None:
            depth = self.depth + 1
        node = self
        # walk down iteratively so list-shaped trees (e.g. built from sorted
        # input) cannot exceed the interpreter's recursion limit
        while True:
            if val < node.v:  # go to left branch
                if node.l is None:
                    node.l = Node(val, depth)
                    return
                node = node.l
            elif val > node.v:  # go to right branch
                if node.r is None:
                    node.r = Node(val, depth)
                    return
                node = node.r
            else:  # equal value already stored: nothing to do
                return
            depth += 1

    def find(self, target):
        """ returns True if a value equal to target is stored in this subtree """
        node = self
        while node is not None:
            # equality, not identity: `is` misses equal-but-distinct objects
            # such as large ints or non-interned strings
            if node.v == target:
                return True
            node = node.l if node.v > target else node.r
        return False

    def map(self):
        """ generator for all nodes in pre-order (node, left subtree, right subtree) """
        stack = [self]
        while stack:
            node = stack.pop()
            yield node
            # push right before left so the left subtree is visited first
            if node.r is not None:
                stack.append(node.r)
            if node.l is not None:
                stack.append(node.l)

    def __iter__(self):
        """ iterates over the nodes of this subtree in pre-order """
        return self.map()

    def map_values(self):
        """ generator for all values, in pre-order """
        for node in self:
            yield node()

    def map_depth(self):
        """ generator for the depth of every node, in pre-order """
        for node in self:
            yield node.depth

class BinaryTree:
    """ binary search tree that can try to reduce its depth by rebuilding
    itself from randomly shuffled copies of its values """

    def __init__(self):
        self.root = None

    @property
    def has_root(self):
        """ returns True once at least one value has been added """
        return self.root is not None

    def add(self, val):
        """ inserts val; a value already in the tree is ignored """
        if self.has_root:
            self.root.add(val)
        else:
            self.root = Node(val)

    @property
    def shuffled_copy(self):
        """ returns a new tree built by inserting this tree's values in random order """
        copy = BinaryTree()
        values = list(self.map_values())
        shuffle(values)
        for value in values:
            copy.add(value)
        return copy

    @property
    def deepest_branch(self):
        """ returns the depth of the deepest node, counting the root as depth 1 """
        assert self.has_root, "no root found"
        return max(self.root.map_depth())

    def attempt_optimization(self):
        """ builds one shuffled copy and adopts its root if it is strictly shallower """
        assert self.has_root, "no root found"
        candidate = self.shuffled_copy
        if self.deepest_branch > candidate.deepest_branch:
            self.root = candidate.root

    def optimize(self, rounds=100):
        """ makes up to `rounds` shuffle attempts, stopping early once optimal;
        an empty tree is left untouched """
        if not self.has_root:
            return
        for _ in range(rounds):
            if self.is_optimal:
                break
            self.attempt_optimization()

    @property
    def is_optimal(self):
        """ returns True when no tree with the same number of values can be shallower.

        A tree of depth d holds at most 2**d - 1 nodes, so depth d is the
        minimum possible exactly when count >= 2**(d-1). An empty tree is
        trivially optimal.
        """
        if not self.has_root:
            return True
        # `<=`, not `<`: e.g. 1 node at depth 1 or 4 nodes at depth 3 are optimal
        return 2 ** (self.deepest_branch - 1) <= self.count

    @property
    def count(self):
        """ returns how many nodes are in the tree (0 when empty) """
        return sum(1 for _ in self)

    def __iter__(self):
        """ iterates over the stored values in pre-order """
        return self.map_values()

    def map_values(self):
        """ generator for all values in pre-order; yields nothing for an empty tree """
        if self.has_root:
            for value in self.root.map_values():
                yield value

    def find(self, target):
        """ returns True if the instance is there """
        if self.has_root:
            return self.root.find(target)
        return False


if __name__ == "__main__":
    # demo: build a tree from random values, optimize it, dump it, then probe it
    from random import randint

    def r():
        """ returns a random int in [1, 100] """
        return randint(1, 100)

    bt = BinaryTree()
    for _ in range(50):
        bt.add(r())
    bt.optimize()
    # print() with a single pre-formatted argument gives identical output on
    # Python 2 and Python 3 (the old print statements are a SyntaxError on 3)
    for value in bt:
        print(value)
    for _ in range(50):
        probe = r()
        print("{} {}".format(probe, bt.find(probe)))