# eightHundreds
# 1/25/2017 - 7:03 AM
#
# AVLTree.py

class AVLTreeNode:
    """A single AVL tree node: a key, two child links and a cached height."""

    def __init__(self, key=None, left=None, right=None, height: int = 0):
        """
        Constructor.
        :param key: the value stored in this node
        :param left: left child node (or None)
        :param right: right child node (or None)
        :param height: cached height of the subtree rooted at this node
        """
        self.left = left  # type: AVLTreeNode
        self.right = right  # type: AVLTreeNode
        self.key = key
        self.height = height

    def compareTo(self, other_key):
        """
        Compare this node's key with ``other_key``.

        Returns a positive number if this key is greater, a negative number if
        it is smaller, and 0 if they are equal. Using ``<``/``>`` rather than
        subtraction allows any mutually ordered keys (numbers, strings,
        tuples, ...) instead of only numeric ones.
        :param other_key: key to compare against
        :return: 1, -1 or 0
        """
        return (self.key > other_key) - (self.key < other_key)

class AVLTree():
    """
    Self-balancing binary search tree (AVL tree).

    Every node caches the height of its subtree; an empty subtree has height 0
    and a leaf has height 1. Keys are ordered by the sign of
    ``AVLTreeNode.compareTo``.
    """

    def __init__(self, root=None):
        """
        Constructor.
        :param root: optional existing root node (or None for an empty tree)
        """
        self.root = root

    def _update_height(self, node: AVLTreeNode):
        """Recompute ``node.height`` from its children's cached heights."""
        node.height = max(self._height(node.left), self._height(node.right)) + 1

    def rr_rotation(self, root: AVLTreeNode):
        """
        Right-right case: single left rotation around ``root``.
        :param root: subtree root before the rotation (must have a right child)
        :return: subtree root after the rotation
        """
        new_root = root.right  # type: AVLTreeNode
        root.right = new_root.left
        new_root.left = root

        # The old root is now the child, so its height is fixed first.
        self._update_height(root)
        self._update_height(new_root)
        return new_root

    def ll_rotation(self, root: AVLTreeNode):
        """
        Left-left case: single right rotation around ``root``.
        :param root: subtree root before the rotation (must have a left child)
        :return: subtree root after the rotation
        """
        new_root = root.left  # type: AVLTreeNode
        root.left = new_root.right
        new_root.right = root

        # The old root is now the child, so its height is fixed first.
        self._update_height(root)
        self._update_height(new_root)
        return new_root

    def rl_rotation(self, node: AVLTreeNode):
        """
        Right-left case: rotate the right child right, then ``node`` left.
        :param node: subtree root before the rotation
        :return: subtree root after the rotation
        """
        node.right = self.ll_rotation(node.right)
        return self.rr_rotation(node)

    def lr_rotation(self, node: AVLTreeNode):
        """
        Left-right case: rotate the left child left, then ``node`` right.
        :param node: subtree root before the rotation
        :return: subtree root after the rotation
        """
        node.left = self.rr_rotation(node.left)
        return self.ll_rotation(node)

    def remove(self, key):
        """
        Remove ``key`` from the tree; does nothing if the key is absent.
        :param key: key to remove
        """
        if self._search(self.root, key) is not None:
            self.root = self._remove(self.root, key)

    def _remove(self, tree: AVLTreeNode, key):
        """
        Remove ``key`` from the subtree rooted at ``tree`` and rebalance it.
        :param tree: subtree root (may be None)
        :param key: key to remove
        :return: new subtree root (None if the subtree becomes empty)
        """
        # Only an empty subtree ends the recursion: a falsy key such as 0 is
        # a legitimate key and must not be treated as "nothing to remove".
        if tree is None:
            return None
        cmp = tree.compareTo(key)
        if cmp > 0:
            tree.left = self._remove(tree.left, key)

            # Removing on the left may leave the right side too tall.
            if self._height(tree.right) - self._height(tree.left) > 1:
                child = tree.right
                if self._height(child.left) > self._height(child.right):
                    tree = self.rl_rotation(tree)
                else:
                    tree = self.rr_rotation(tree)
        elif cmp < 0:
            tree.right = self._remove(tree.right, key)

            # Removing on the right may leave the left side too tall.
            if self._height(tree.left) - self._height(tree.right) > 1:
                child = tree.left
                if self._height(child.right) > self._height(child.left):
                    tree = self.lr_rotation(tree)
                else:
                    tree = self.ll_rotation(tree)
        else:  # this node holds the key to delete
            if tree.left is not None and tree.right is not None:
                # Two children: overwrite this key with its in-order neighbour
                # from the taller subtree, then delete that neighbour's key
                # from that subtree. Shrinking the taller side cannot unbalance
                # this node, so no rotation is needed here.
                if self._height(tree.left) > self._height(tree.right):
                    pred = self.maximum(tree.left)
                    tree.key = pred.key
                    tree.left = self._remove(tree.left, pred.key)
                else:
                    succ = self.minimum(tree.right)
                    tree.key = succ.key
                    tree.right = self._remove(tree.right, succ.key)
            else:
                # Zero or one child: splice the (possibly empty) child in.
                tree = tree.left if tree.left is not None else tree.right

        # Heights along the deletion path may have shrunk; refresh them so
        # later balance checks compare accurate values.
        if tree is not None:
            self._update_height(tree)
        return tree

    def maximum(self, tree):
        """
        Return the node with the largest key in ``tree``.
        :param tree: subtree root (may be None)
        :return: the right-most node, or None for an empty subtree
        """
        if tree is None:
            return None
        while tree.right is not None:
            tree = tree.right
        return tree

    def minimum(self, tree):
        """
        Return the node with the smallest key in ``tree``.
        :param tree: subtree root (may be None)
        :return: the left-most node, or None for an empty subtree
        """
        if tree is None:
            return None
        while tree.left is not None:
            tree = tree.left
        return tree

    # Original (misspelled) name kept for backward compatibility.
    minimun = minimum

    def insert(self, key):
        """
        Insert ``key`` into the tree.
        :param key: key to insert
        :raises ValueError: if the key is already present
        """
        self.root = self._insert(self.root, key)

    def _insert(self, tree, insert_key):
        """
        Insert ``insert_key`` into the subtree rooted at ``tree`` and rebalance.
        :param tree: subtree root (may be None)
        :param insert_key: key to insert
        :return: new subtree root
        :raises ValueError: if the key is already present
        """
        if tree is None:
            tree = AVLTreeNode(insert_key)
        else:
            cmp = tree.compareTo(insert_key)

            if cmp > 0:  # goes into the left subtree
                tree.left = self._insert(tree.left, insert_key)
                if self._height(tree.left) - self._height(tree.right) > 1:
                    # Left-left if the key went left of the left child.
                    if tree.left.compareTo(insert_key) > 0:
                        tree = self.ll_rotation(tree)
                    else:
                        tree = self.lr_rotation(tree)
            elif cmp < 0:  # goes into the right subtree
                tree.right = self._insert(tree.right, insert_key)
                if self._height(tree.right) - self._height(tree.left) > 1:
                    # Right-right if the key went right of the right child.
                    if tree.right.compareTo(insert_key) < 0:
                        tree = self.rr_rotation(tree)
                    else:
                        tree = self.rl_rotation(tree)
            else:
                raise ValueError('不可插入已经存在的节点')

        self._update_height(tree)
        return tree

    def search(self, key):
        """
        Find the node holding ``key``.
        :param key: key to look up
        :return: the matching node, or None if absent (or the tree is empty)
        """
        return self._search(self.root, key)

    def _search(self, node: AVLTreeNode, key):
        """
        Recursively find the node holding ``key`` in the subtree at ``node``.
        :param node: subtree root (may be None)
        :param key: key to look up
        :return: the matching node, or None if absent
        """
        if node is None:
            return None
        cmp = node.compareTo(key)
        if cmp > 0:
            return self._search(node.left, key)
        elif cmp < 0:
            return self._search(node.right, key)
        else:
            return node

    def _height(self, tree):
        """
        Cached height of ``tree``; an empty tree has height 0.
        :param tree: subtree root (may be None)
        :return: height as an int
        """
        if tree is not None:
            return tree.height
        return 0

    def height(self):
        """Height of the whole tree (0 when empty)."""
        return self._height(self.root)

    def print(self):
        """
        Print each node and its relation to its parent, pre-order.
        Prints nothing for an empty tree.
        """
        if self.root is None:
            return
        self._print(self.root, self.root.key, 0)

    def _print(self, node: AVLTreeNode, key, direction):
        """
        Print the subtree rooted at ``node`` in pre-order.
        :param node: subtree root (may be None)
        :param key: key of ``node``'s parent (ignored for the root)
        :param direction: 0 means ``node`` is the root;
                          -1 means it is its parent's left child;
                           1 means it is its parent's right child.
        """
        if node is not None:
            if direction == 0:
                print('%s is root' % node.key)
            else:
                print('%s is %s %s child' % (node.key, key, 'right' if direction == 1 else 'left'))
            # Recurse into both children, passing this node as their parent.
            self._print(node.left, node.key, -1)
            self._print(node.right, node.key, 1)