class AVLTreeNode:
    """A single node of an AVL tree: a key, two children and a cached height."""

    def __init__(self, key=None, left=None, right=None, height: int = 0):
        """
        Constructor.
        :param key: the value stored in this node
        :param left: left child node, or None
        :param right: right child node, or None
        :param height: cached height of the subtree rooted at this node
        """
        self.left: 'AVLTreeNode' = left
        self.right: 'AVLTreeNode' = right
        self.key = key
        self.height = height

    def compareTo(self, other_key):
        """
        Three-way compare this node's key with ``other_key``.
        :param other_key: a key comparable with ``self.key``
        :return: 1 if self.key > other_key, -1 if self.key < other_key, 0 if equal
        """
        # Rich comparisons instead of subtraction, so any totally ordered key
        # type (str, tuple, ...) works, not only numbers. Callers only use the sign.
        return (self.key > other_key) - (self.key < other_key)

    def __repr__(self):
        return 'AVLTreeNode(key=%r, height=%r)' % (self.key, self.height)
class AVLTree:
    """
    Self-balancing binary search tree (AVL tree).

    Heights follow the convention of ``_height``: an empty subtree has height 0,
    so a leaf has height 1. Keys must be unique and are compared through the
    nodes' ``compareTo`` method.
    """

    def __init__(self, root=None):
        """
        Constructor.
        :param root: optional existing root node, or None for an empty tree
        """
        self.root = root

    def _update_height(self, node: 'AVLTreeNode'):
        """Recompute ``node.height`` from the cached heights of its children."""
        node.height = max(self._height(node.left), self._height(node.right)) + 1

    def rr_rotation(self, root: 'AVLTreeNode'):
        """
        Right-right case: single left rotation around ``root``.
        :param root: subtree root before the rotation; must have a right child
        :return: subtree root after the rotation (the former right child)
        """
        new_root = root.right  # type: AVLTreeNode
        root.right = new_root.left
        new_root.left = root
        # root is now new_root's child, so its height must be refreshed first.
        self._update_height(root)
        self._update_height(new_root)
        return new_root

    def ll_rotation(self, root: 'AVLTreeNode'):
        """
        Left-left case: single right rotation around ``root``.
        :param root: subtree root before the rotation; must have a left child
        :return: subtree root after the rotation (the former left child)
        """
        new_root = root.left  # type: AVLTreeNode
        root.left = new_root.right
        new_root.right = root
        # root is now new_root's child, so its height must be refreshed first.
        self._update_height(root)
        self._update_height(new_root)
        return new_root

    def rl_rotation(self, node: 'AVLTreeNode'):
        """
        Right-left case: right-rotate the right child, then left-rotate ``node``.
        :param node: subtree root before the rotation
        :return: subtree root after the rotation
        """
        node.right = self.ll_rotation(node.right)
        return self.rr_rotation(node)

    def lr_rotation(self, node: 'AVLTreeNode'):
        """
        Left-right case: left-rotate the left child, then right-rotate ``node``.
        :param node: subtree root before the rotation
        :return: subtree root after the rotation
        """
        node.left = self.rr_rotation(node.left)
        return self.ll_rotation(node)

    def remove(self, key):
        """
        Delete ``key`` from the tree; do nothing if it is not present.
        :param key: key to delete (falsy keys such as 0 are handled too)
        """
        if self._search(self.root, key) is not None:
            self.root = self._remove(self.root, key)

    def _remove(self, tree: 'AVLTreeNode', key):
        """
        Delete ``key`` from the subtree rooted at ``tree`` and rebalance it.
        :param tree: subtree root, may be None
        :param key: key to delete
        :return: new root of the subtree (None if the subtree became empty)
        """
        if tree is None:
            return None
        cmp = tree.compareTo(key)
        if cmp > 0:
            tree.left = self._remove(tree.left, key)
            # The left side shrank: the right side may now be 2 levels taller.
            if self._height(tree.right) - self._height(tree.left) == 2:
                r = tree.right
                if self._height(r.left) > self._height(r.right):
                    tree = self.rl_rotation(tree)
                else:
                    tree = self.rr_rotation(tree)
        elif cmp < 0:
            tree.right = self._remove(tree.right, key)
            # The right side shrank: the left side may now be 2 levels taller.
            if self._height(tree.left) - self._height(tree.right) == 2:
                l = tree.left
                if self._height(l.right) > self._height(l.left):
                    tree = self.lr_rotation(tree)
                else:
                    tree = self.ll_rotation(tree)
        else:  # this is the node to delete
            if tree.left is not None and tree.right is not None:
                # Take the replacement from the taller side, so removing it
                # cannot unbalance this node.
                if self._height(tree.left) > self._height(tree.right):
                    predecessor = self.maximum(tree.left)
                    tree.key = predecessor.key
                    tree.left = self._remove(tree.left, predecessor.key)
                else:
                    successor = self.minimum(tree.right)
                    tree.key = successor.key
                    tree.right = self._remove(tree.right, successor.key)
            else:
                tree = tree.left if tree.left is not None else tree.right
        # Keep the cached height correct even when no rotation happened.
        if tree is not None:
            self._update_height(tree)
        return tree

    def maximum(self, tree):
        """
        Return the node with the largest key in a subtree.
        :param tree: subtree root, may be None
        :return: the rightmost node, or None for an empty subtree
        """
        if tree is None:
            return None
        while tree.right is not None:
            tree = tree.right
        return tree

    def minimum(self, tree):
        """
        Return the node with the smallest key in a subtree.
        :param tree: subtree root, may be None
        :return: the leftmost node, or None for an empty subtree
        """
        if tree is None:
            return None
        while tree.left is not None:
            tree = tree.left
        return tree

    # Backward-compatible alias for the original (misspelled) method name.
    minimun = minimum

    def insert(self, key):
        """
        Insert ``key`` into the tree.
        :raises ValueError: if ``key`` is already present
        """
        self.root = self._insert(self.root, key)

    def _insert(self, tree, insert_key):
        """
        Insert ``insert_key`` into the subtree rooted at ``tree`` and rebalance it.
        :param tree: subtree root, may be None
        :param insert_key: key to insert
        :return: new root of the subtree
        :raises ValueError: if ``insert_key`` already exists in the subtree
        """
        if tree is None:
            tree = AVLTreeNode(insert_key)
        else:
            cmp = tree.compareTo(insert_key)
            if cmp > 0:  # insert into the left subtree
                tree.left = self._insert(tree.left, insert_key)
                if self._height(tree.left) - self._height(tree.right) == 2:
                    # Left-left if the key went left of tree.left, else left-right.
                    if tree.left.compareTo(insert_key) > 0:
                        tree = self.ll_rotation(tree)
                    else:
                        tree = self.lr_rotation(tree)
            elif cmp < 0:  # insert into the right subtree
                tree.right = self._insert(tree.right, insert_key)
                if self._height(tree.right) - self._height(tree.left) == 2:
                    # Right-right if the key went right of tree.right, else right-left.
                    if tree.right.compareTo(insert_key) < 0:
                        tree = self.rr_rotation(tree)
                    else:
                        tree = self.rl_rotation(tree)
            else:
                # Message: "cannot insert a node that already exists".
                raise ValueError('不可插入已经存在的节点')
        self._update_height(tree)
        return tree

    def search(self, key):
        """
        Find the node holding ``key``.
        :return: the matching node, or None if absent or the tree is empty
        """
        return self._search(self.root, key)

    def _search(self, node: 'AVLTreeNode', key):
        """
        Find the node holding ``key`` in the subtree rooted at ``node``.
        :param node: subtree root, may be None
        :param key: key to look for
        :return: the matching node, or None if absent
        """
        while node is not None:
            cmp = node.compareTo(key)
            if cmp > 0:
                node = node.left
            elif cmp < 0:
                node = node.right
            else:
                return node
        return None

    def _height(self, tree):
        """
        Height of a subtree; an empty tree (None) has height 0.
        """
        return tree.height if tree is not None else 0

    def height(self):
        """Height of the whole tree (0 when empty)."""
        return self._height(self.root)

    def print(self):
        """
        Print every node and its relation to its parent; prints nothing when empty.
        """
        if self.root is not None:
            self._print(self.root, self.root.key, 0)

    def _print(self, node: 'AVLTreeNode', key, direction):
        """
        Print the subtree rooted at ``node`` in pre-order.
        :param node: subtree root, may be None
        :param key: key of ``node``'s parent (ignored for the root)
        :param direction: 0 if node is the root,
                          -1 if node is its parent's left child,
                          1 if node is its parent's right child
        """
        if node is not None:
            if direction == 0:
                print('%s is root' % node.key)
            else:
                print('%s is %s %s child' % (node.key, key, 'right' if direction == 1 else 'left'))
            self._print(node.left, node.key, -1)
            self._print(node.right, node.key, 1)