# AVL Tree in Go

tech · Nov 15, 2020 · ~13 min

After implementing a BST in Go, I wanted to turn it into an AVL tree, because a plain BST is not always an efficient tree data structure.

When I said this:

> To find a specific node, you don’t have to go through the whole tree: a BST can route to a specific node by comparing node values.

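It’s only half true, because a BST can degenerate into a linear tree. This happens, for example, when the values are inserted in ascending order.

If you then want to find the node with value 6, you end up traveling through the whole tree, so the search costs O(n) instead of O(log n). That’s why we need an AVL tree to improve the time complexity. An AVL tree rebalances itself whenever an insertion or deletion makes it imbalanced. This keeps its height, and therefore every search, at O(log n).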
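To make the problem concrete, here is a minimal sketch that inserts sorted values into an unbalanced BST and measures the result. It uses a hypothetical `bstNode` type and helpers written only for this illustration, not the code from the BST post. Inserting 0 through 6 in order produces a chain of height 7, so searching for 6 visits all 7 nodes. A balanced tree holding the same 7 values would have height 3.

```go
package main

import "fmt"

// bstNode is a plain, unbalanced BST node (hypothetical, for illustration only).
type bstNode struct {
	value       int
	left, right *bstNode
}

// insert adds val to a plain BST without any rebalancing; duplicates are ignored.
func insert(n *bstNode, val int) *bstNode {
	if n == nil {
		return &bstNode{value: val}
	}
	if val < n.value {
		n.left = insert(n.left, val)
	} else if val > n.value {
		n.right = insert(n.right, val)
	}
	return n
}

// height counts the nodes on the longest root-to-leaf path.
func height(n *bstNode) int {
	if n == nil {
		return 0
	}
	l, r := height(n.left), height(n.right)
	if l > r {
		return l + 1
	}
	return r + 1
}

// steps counts how many nodes a search visits before it finds val or runs out of nodes.
func steps(n *bstNode, val int) int {
	count := 0
	for n != nil {
		count++
		if val == n.value {
			return count
		}
		if val < n.value {
			n = n.left
		} else {
			n = n.right
		}
	}
	return count
}

func main() {
	var root *bstNode
	for v := 0; v <= 6; v++ { // sorted input: 0, 1, ..., 6
		root = insert(root, v)
	}
	fmt.Println(height(root))   // 7: every node hangs off the right of the previous one
	fmt.Println(steps(root, 6)) // 7: the search visits every node
}
```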

The whole concept of an AVL tree is much the same as a BST, apart from the rebalancing algorithm. After every insertion or deletion, an AVL tree rebalances itself by rotating each imbalanced subtree on the way back up to the root. So we’re going to reuse all the code from here and modify it a bit.

To check whether the tree is balanced, we need to store the height of each node. A node’s height is computed recursively: it is the height of its taller child plus 1 for the node itself. So a node with no children has height 1, and a missing (nil) child counts as height 0.
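Written as a formula, with $h$ denoting height, the rule above is:

$$
h(\text{nil}) = 0, \qquad h(n) = 1 + \max\bigl(h(n_{\text{left}}),\ h(n_{\text{right}})\bigr)
$$

As a worked example, take a small tree of our own choosing. It has root 2 with left child 1 and right child 3, and node 3 has a right child 4:

$$
\begin{aligned}
h(1) &= 1 + \max(0, 0) = 1 \\
h(4) &= 1 + \max(0, 0) = 1 \\
h(3) &= 1 + \max\bigl(h(\text{nil}), h(4)\bigr) = 1 + \max(0, 1) = 2 \\
h(2) &= 1 + \max\bigl(h(1), h(3)\bigr) = 1 + \max(1, 2) = 3
\end{aligned}
$$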

Update the `node` struct by adding a `height` field and a `Height` getter. Then set the height to 1 inside the constructor, since a new node is always a leaf.

```go
type node struct {
	height, value int
	left, right   *node
}

func (n *node) Height() int {
	if n == nil {
		return 0
	}
	return n.height
}

func newNode(val int) *node {
	return &node{
		height: 1,
		value:  val,
		left:   nil,
		right:  nil,
	}
}
```
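The `Height` getter works even when `n` is nil. In Go, a method with a pointer receiver can be called on a nil pointer; only dereferencing that pointer would panic. That is what lets later code write `n.left.Height()` without a nil check. The program below reuses the `node` type and constructor from above and adds a `main` for illustration. It shows that asking a leaf for the height of its missing child returns 0 instead of panicking.

```go
package main

import "fmt"

type node struct {
	height, value int
	left, right   *node
}

// Height can be called on a nil *node: the nil check happens inside the method.
func (n *node) Height() int {
	if n == nil {
		return 0
	}
	return n.height
}

func newNode(val int) *node {
	return &node{height: 1, value: val}
}

func main() {
	leaf := newNode(5)
	fmt.Println(leaf.Height())      // 1: a new node is a leaf
	fmt.Println(leaf.left.Height()) // 0: leaf.left is nil, and no panic occurs
}
```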

To keep track of each node’s height and balance after an insertion or deletion, we also need an `updateHeight` function and a `balanceFactor` function.

```go
func (n *node) balanceFactor() int {
	if n == nil {
		return 0
	}
	return n.left.Height() - n.right.Height()
}

func max(a, b int) int {
	if a > b {
		return a
	}
	return b
}

func (n *node) updateHeight() {
	// the height of the taller child + 1 for the node itself
	n.height = max(n.left.Height(), n.right.Height()) + 1
}
```

The `balanceFactor` function tells us whether a subtree is heavier on the left or the right side. It returns the left child’s height minus the right child’s height. A negative value means the right side is taller, and a positive value means the left side is taller. A node is balanced as long as its balance factor is -1, 0, or 1. If `balanceFactor` returns less than -1, the node is too heavy on the right and we need to rotate it to the left. If it returns greater than 1, the node is too heavy on the left and we need to rotate it to the right.
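The thresholds can be checked by hand on a right-leaning chain 1 → 2 → 3. This is a minimal sketch that copies `Height` and `balanceFactor` from above. It adds a hypothetical `rotationNeeded` helper, which is not part of the post's code, to turn a balance factor into the direction of the needed rotation. The root's balance factor is 0 − 2 = −2, so it needs a left rotation. Its right child's balance factor is −1, which is still balanced.

```go
package main

import "fmt"

type node struct {
	height, value int
	left, right   *node
}

func (n *node) Height() int {
	if n == nil {
		return 0
	}
	return n.height
}

func (n *node) balanceFactor() int {
	if n == nil {
		return 0
	}
	return n.left.Height() - n.right.Height()
}

// rotationNeeded is a hypothetical helper: it maps a balance factor to the
// rotation that node needs, if any.
func rotationNeeded(bf int) string {
	switch {
	case bf > 1:
		return "right" // left side is taller by 2 or more
	case bf < -1:
		return "left" // right side is taller by 2 or more
	default:
		return "none" // -1, 0 and 1 are all balanced
	}
}

func main() {
	// Right-leaning chain 1 -> 2 -> 3, with heights filled in by hand.
	three := &node{height: 1, value: 3}
	two := &node{height: 2, value: 2, right: three}
	one := &node{height: 3, value: 1, right: two}

	fmt.Println(one.balanceFactor(), rotationNeeded(one.balanceFactor())) // -2 left
	fmt.Println(two.balanceFactor(), rotationNeeded(two.balanceFactor())) // -1 none
}
```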

Now let’s create the rotate functions. There are 2 types of rotation, `leftRotate` and `rightRotate`, but there are 4 cases in which the tree needs to be rotated on insertion and deletion. You can read about them, with pictures, here.

```go
func rightRotate(x *node) *node {
	y := x.left
	t := y.right

	y.right = x
	x.left = t

	x.updateHeight()
	y.updateHeight()

	return y
}

func leftRotate(x *node) *node {
	y := x.right
	t := y.left

	y.left = x
	x.right = t

	x.updateHeight()
	y.updateHeight()

	return y
}
```
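Here is `leftRotate` applied to the right-leaning chain 1 → 2 → 3. This is a self-contained sketch: the rotation is copied from above, and `updateHeight` inlines the comparison instead of calling `max`. After the rotation, 2 becomes the subtree root with 1 and 3 as its children. The order of the two `updateHeight` calls matters: `x` now sits below `y`, so `x` must be updated first.

```go
package main

import "fmt"

type node struct {
	height, value int
	left, right   *node
}

func (n *node) Height() int {
	if n == nil {
		return 0
	}
	return n.height
}

func (n *node) updateHeight() {
	l, r := n.left.Height(), n.right.Height()
	if l > r {
		n.height = l + 1
	} else {
		n.height = r + 1
	}
}

// leftRotate: x's right child y becomes the new subtree root, and y's old
// left subtree t is re-attached as x's right child.
func leftRotate(x *node) *node {
	y := x.right
	t := y.left
	y.left = x
	x.right = t
	x.updateHeight() // x is now y's child, so update it first
	y.updateHeight()
	return y
}

func main() {
	// Right-leaning chain 1 -> 2 -> 3.
	three := &node{height: 1, value: 3}
	two := &node{height: 2, value: 2, right: three}
	one := &node{height: 3, value: 1, right: two}

	root := leftRotate(one)
	fmt.Println(root.value, root.left.value, root.right.value) // 2 1 3
	fmt.Println(root.Height(), one.Height())                   // 2 1
}
```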

The conditions to rotate the tree on insertion are:

1. When the tree leans `linearly to the right` (right-right case), use `leftRotate` on the `current node`.
2. When the tree leans `linearly to the left` (left-left case), use `rightRotate` on the `current node`.
3. When the tree forms a `Less Than` symbol (left-right case), use `leftRotate` on the `left child`, then `rightRotate` on the `current node`.
4. When the tree forms a `Greater Than` symbol (right-left case), use `rightRotate` on the `right child`, then `leftRotate` on the `current node`.
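Case 3 is the least obvious, so here is a minimal sketch of it. It assumes the tree was built by inserting 3, 1, 2 into a plain BST, which gives the `<` shape 3 → 1 → 2. A left rotation on the left child straightens the shape into a left-leaning line. A right rotation on the current node then makes 2 the root. The rotation functions are copied from above, with `updateHeight` inlining the comparison.

```go
package main

import "fmt"

type node struct {
	height, value int
	left, right   *node
}

func (n *node) Height() int {
	if n == nil {
		return 0
	}
	return n.height
}

func (n *node) updateHeight() {
	l, r := n.left.Height(), n.right.Height()
	if l > r {
		n.height = l + 1
	} else {
		n.height = r + 1
	}
}

func rightRotate(x *node) *node {
	y := x.left
	t := y.right
	y.right = x
	x.left = t
	x.updateHeight()
	y.updateHeight()
	return y
}

func leftRotate(x *node) *node {
	y := x.right
	t := y.left
	y.left = x
	x.right = t
	x.updateHeight()
	y.updateHeight()
	return y
}

func main() {
	// "Less than" shape after inserting 3, 1, 2 into a plain BST:
	//     3
	//    /
	//   1
	//    \
	//     2
	two := &node{height: 1, value: 2}
	one := &node{height: 2, value: 1, right: two}
	three := &node{height: 3, value: 3, left: one}

	// Case 3: leftRotate the left child, then rightRotate the current node.
	three.left = leftRotate(three.left)
	root := rightRotate(three)

	fmt.Println(root.value, root.left.value, root.right.value) // 2 1 3
}
```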

## Insertion

```go
func rotateInsert(node *node, val int) *node {
	// update the height on every insertion
	node.updateHeight()

	// bFactor will tell you which side the weight is on
	bFactor := node.balanceFactor()

	// linearly to the left
	if bFactor > 1 && val < node.left.value {
		return rightRotate(node)
	}

	// linearly to the right
	if bFactor < -1 && val > node.right.value {
		return leftRotate(node)
	}

	// less than symbol
	if bFactor > 1 && val > node.left.value {
		node.left = leftRotate(node.left)
		return rightRotate(node)
	}

	// greater than symbol
	if bFactor < -1 && val < node.right.value {
		node.right = rightRotate(node.right)
		return leftRotate(node)
	}

	return node
}
```

Lastly, update the return statement of the `insertNode` function. Every node on the insertion path is then rebalanced as the recursion unwinds.

```go
func insertNode(node *node, val int) (*node, error) {
	...
	return rotateInsert(node, val), nil
}
```

## Traverse Operation and Validation

To make the results easy to visualize, change the traverse function to a `pre-order` traversal. Then open the BST Visualization page and the AVL Visualization page.

```go
func traverse(node *node) {
	// exit condition
	if node == nil {
		return
	}

	// pre-order: visit the node first, then its left and right subtrees
	fmt.Println(node.value)
	traverse(node.left)
	traverse(node.right)
}

func main() {
	tree := avl.New()

	// To check whether your implementation is correct,
	// first insert these values one by one
	// into the AVL Visualization page
	tree.Insert(0)
	tree.Insert(1)
	tree.Insert(2)
	tree.Insert(3)
	tree.Insert(4)
	tree.Insert(5)
	tree.Insert(6)
	tree.Insert(7)

	// then insert the values printed by Traverse, in the same order,
	// into the BST Visualization page
	tree.Traverse() // prints 3 1 0 2 5 4 6 7, one value per line
}
```

If both visualizations show the same balanced tree, your implementation is correct.
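Besides comparing pictures, you could check the AVL invariants in code. This is a minimal sketch built on a hypothetical `checkAVL` helper, which is not part of the post's package. It assumes the `node` struct from above and that no stored value equals `math.MinInt` or `math.MaxInt`. It verifies BST ordering, the stored heights, and that every balance factor stays within [-1, 1]. The first tree is the one the post expects after inserting 0 to 7, built by hand. The second is an unbalanced chain.

```go
package main

import (
	"fmt"
	"math"
)

type node struct {
	height, value int
	left, right   *node
}

func (n *node) Height() int {
	if n == nil {
		return 0
	}
	return n.height
}

// checkAVL (hypothetical) returns the first broken AVL invariant in subtree n,
// whose values must lie strictly between lo and hi.
func checkAVL(n *node, lo, hi int) error {
	if n == nil {
		return nil
	}
	// BST ordering: every value must fit between its ancestors' bounds.
	if n.value <= lo || n.value >= hi {
		return fmt.Errorf("node %d breaks BST ordering", n.value)
	}
	if err := checkAVL(n.left, lo, n.value); err != nil {
		return err
	}
	if err := checkAVL(n.right, n.value, hi); err != nil {
		return err
	}
	// The stored height must equal the taller child's height + 1.
	l, r := n.left.Height(), n.right.Height()
	want := l + 1
	if r > l {
		want = r + 1
	}
	if n.height != want {
		return fmt.Errorf("node %d has height %d, want %d", n.value, n.height, want)
	}
	// The balance factor must stay within [-1, 1].
	if bf := l - r; bf < -1 || bf > 1 {
		return fmt.Errorf("node %d has balance factor %d", n.value, bf)
	}
	return nil
}

// validate checks the whole tree.
func validate(root *node) error {
	return checkAVL(root, math.MinInt, math.MaxInt)
}

func main() {
	// Expected AVL tree after inserting 0..7 (pre-order 3 1 0 2 5 4 6 7).
	n1 := &node{height: 2, value: 1, left: &node{height: 1, value: 0}, right: &node{height: 1, value: 2}}
	n6 := &node{height: 2, value: 6, right: &node{height: 1, value: 7}}
	n5 := &node{height: 3, value: 5, left: &node{height: 1, value: 4}, right: n6}
	root := &node{height: 4, value: 3, left: n1, right: n5}
	fmt.Println(validate(root)) // <nil>

	// Right-leaning chain 0 -> 1 -> 2: heights are correct, but it was never rebalanced.
	chain := &node{height: 3, value: 0, right: &node{height: 2, value: 1, right: &node{height: 1, value: 2}}}
	fmt.Println(validate(chain)) // node 0 has balance factor -2
}
```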

## Deletion

```go
func rotateDelete(node *node) *node {
	node.updateHeight()
	bFactor := node.balanceFactor()

	// linearly to the left
	if bFactor > 1 && node.left.balanceFactor() >= 0 {
		return rightRotate(node)
	}

	// less than symbol
	if bFactor > 1 && node.left.balanceFactor() < 0 {
		node.left = leftRotate(node.left)
		return rightRotate(node)
	}

	// linearly to the right
	if bFactor < -1 && node.right.balanceFactor() <= 0 {
		return leftRotate(node)
	}

	// greater than symbol
	if bFactor < -1 && node.right.balanceFactor() > 0 {
		node.right = rightRotate(node.right)
		return leftRotate(node)
	}

	return node
}
```
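In `rotateDelete`, the case is chosen from the child's balance factor instead of the inserted value. This sketch uses a hypothetical `deleteCase` helper that applies the same conditions but returns the case name instead of rotating. In the first example, node 4 has just lost its right child 5, and its left child 2 is perfectly balanced (balance factor 0). After an insertion, the child of an imbalanced node always leans by exactly 1. So the balance factor 0 case only appears on deletion, which is why the left-left check uses `>= 0`.

```go
package main

import "fmt"

type node struct {
	height, value int
	left, right   *node
}

func (n *node) Height() int {
	if n == nil {
		return 0
	}
	return n.height
}

func (n *node) balanceFactor() int {
	if n == nil {
		return 0
	}
	return n.left.Height() - n.right.Height()
}

// deleteCase (hypothetical) mirrors rotateDelete's conditions and names the case.
func deleteCase(n *node) string {
	bf := n.balanceFactor()
	switch {
	case bf > 1 && n.left.balanceFactor() >= 0:
		return "LL" // rightRotate(node)
	case bf > 1:
		return "LR" // leftRotate(node.left), then rightRotate(node)
	case bf < -1 && n.right.balanceFactor() <= 0:
		return "RR" // leftRotate(node)
	case bf < -1:
		return "RL" // rightRotate(node.right), then leftRotate(node)
	default:
		return "none"
	}
}

func main() {
	// Node 4 after its right child 5 was deleted; its left child 2 has balance factor 0.
	two := &node{height: 2, value: 2, left: &node{height: 1, value: 1}, right: &node{height: 1, value: 3}}
	four := &node{height: 3, value: 4, left: two}
	fmt.Println(four.balanceFactor(), two.balanceFactor(), deleteCase(four)) // 2 0 LL

	// Here the left child leans right, so the double rotation is needed.
	one := &node{height: 2, value: 1, right: &node{height: 1, value: 2}}
	other := &node{height: 3, value: 4, left: one}
	fmt.Println(deleteCase(other)) // LR
}
```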

`Deletion` is not like `insertion`, where we can pick the rotation by comparing the inserted value with the child’s value. The value we removed is no longer in the tree, so we compare the current node’s balance factor with the balance factor of its child instead. Now, you need to modify the `removeNode` function. Remember that when removing a node with 2 children, we need to replace it with a `successor`, and there are 2 ways to find the `successor`:

1. Find the node with the smallest value in the right subtree (the in-order successor)
2. Find the node with the largest value in the left subtree (strictly speaking, the in-order predecessor)

We used the first way in the BST post, while the BST & AVL Visualization pages use the second way. The `removeNode` code below uses the second way, so its results are easy to compare with the visualization.
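Both choices keep the BST ordering, because the replacement value is still greater than everything on its left and less than everything on its right. This sketch on a small example tree compares them. `greatest` is the helper used by `removeNode`, and `least` is a hypothetical mirror of it for the first way. When deleting the root 4, the first way picks 5 and the second way picks 3.

```go
package main

import "fmt"

type node struct {
	value       int
	left, right *node
}

// greatest keeps going right to reach the largest value in a subtree.
func greatest(node *node) *node {
	if node == nil {
		return nil
	}
	if node.right == nil {
		return node
	}
	return greatest(node.right)
}

// least (hypothetical) keeps going left to reach the smallest value in a subtree.
func least(node *node) *node {
	if node == nil {
		return nil
	}
	if node.left == nil {
		return node
	}
	return least(node.left)
}

func main() {
	//        4
	//      /   \
	//     2     6
	//    / \   / \
	//   1   3 5   7
	root := &node{
		value: 4,
		left:  &node{value: 2, left: &node{value: 1}, right: &node{value: 3}},
		right: &node{value: 6, left: &node{value: 5}, right: &node{value: 7}},
	}

	// Removing 4 (it has two children) needs a replacement value.
	fmt.Println(least(root.right).value)   // 5: smallest in the right subtree (first way)
	fmt.Println(greatest(root.left).value) // 3: largest in the left subtree (second way)
}
```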

```go
func greatest(node *node) *node {
	if node == nil {
		return nil
	}

	if node.right == nil {
		return node
	}

	return greatest(node.right)
}

func removeNode(node *node, val int) (*node, error) {
	if node == nil {
		return nil, ErrNodeNotFound
	}

	if val > node.value {
		right, err := removeNode(node.right, val)
		if err != nil {
			return nil, err
		}

		node.right = right
	} else if val < node.value {
		left, err := removeNode(node.left, val)
		if err != nil {
			return nil, err
		}

		node.left = left
	} else {
		if node.left != nil && node.right != nil {
			// has 2 children

			// find the successor: the largest value in the left subtree
			successor := greatest(node.left)
			value := successor.value

			// remove the successor from the left subtree
			left, err := removeNode(node.left, value)
			if err != nil {
				return nil, err
			}
			node.left = left

			// copy the successor value to the current node
			node.value = value
		} else if node.left != nil || node.right != nil {
			// has 1 child
			// move the child position to the current node
			if node.left != nil {
				node = node.left
			} else {
				node = node.right
			}
		} else if node.left == nil && node.right == nil {
			// has no child
			// simply remove the node
			node = nil
		}
	}

	if node == nil {
		return nil, nil
	}

	return rotateDelete(node), nil
}
```

You can validate and recheck your AVL implementation with the BST & AVL visualization page.

Here is the modified `node.go` file.

## node.go

```go
package avl

import (
	"errors"
	"fmt"
)

var (
	ErrDuplicatedNode error = errors.New("avl: found duplicated value on tree")
	ErrNodeNotFound   error = errors.New("avl: node not found")
)

type node struct {
	height, value int
	left, right   *node
}

func (n *node) balanceFactor() int {
	if n == nil {
		return 0
	}
	return n.left.Height() - n.right.Height()
}

func (n *node) updateHeight() {
	// the height of the taller child + 1 for the node itself
	n.height = max(n.left.Height(), n.right.Height()) + 1
}

func (n *node) Height() int {
	if n == nil {
		return 0
	}
	return n.height
}

func (n *node) Value() int {
	return n.value
}

func (n *node) Left() *node {
	return n.left
}

func (n *node) Right() *node {
	return n.right
}

func newNode(val int) *node {
	return &node{
		height: 1,
		value:  val,
		left:   nil,
		right:  nil,
	}
}

func insertNode(node *node, val int) (*node, error) {
	// if there's no node, create one
	if node == nil {
		return newNode(val), nil
	}

	// if there's a duplicated node, return an error
	if node.value == val {
		return nil, ErrDuplicatedNode
	}

	// if value is greater than current node's value, insert to the right
	if val > node.value {
		right, err := insertNode(node.right, val)
		if err != nil {
			return nil, err
		}
		node.right = right
	}

	// if value is less than current node's value, insert to the left
	if val < node.value {
		left, err := insertNode(node.left, val)
		if err != nil {
			return nil, err
		}
		node.left = left
	}

	return rotateInsert(node, val), nil
}

func removeNode(node *node, val int) (*node, error) {
	if node == nil {
		return nil, ErrNodeNotFound
	}

	if val > node.value {
		right, err := removeNode(node.right, val)
		if err != nil {
			return nil, err
		}
		node.right = right
	} else if val < node.value {
		left, err := removeNode(node.left, val)
		if err != nil {
			return nil, err
		}
		node.left = left
	} else {
		if node.left != nil && node.right != nil {
			// has 2 children

			// find the successor: the largest value in the left subtree
			successor := greatest(node.left)
			value := successor.value

			// remove the successor from the left subtree
			left, err := removeNode(node.left, value)
			if err != nil {
				return nil, err
			}
			node.left = left

			// copy the successor value to the current node
			node.value = value
		} else if node.left != nil || node.right != nil {
			// has 1 child
			// move the child position to the current node
			if node.left != nil {
				node = node.left
			} else {
				node = node.right
			}
		} else if node.left == nil && node.right == nil {
			// has no child
			// simply remove the node
			node = nil
		}
	}

	if node == nil {
		return nil, nil
	}

	return rotateDelete(node), nil
}

func findNode(node *node, val int) *node {
	if node == nil {
		return nil
	}

	// if the node is found, return the node
	if node.value == val {
		return node
	}

	// if value is greater than current node's value, search recursively to the right
	if val > node.value {
		return findNode(node.right, val)
	}

	// if value is less than current node's value, search recursively to the left
	if val < node.value {
		return findNode(node.left, val)
	}

	return nil
}

func rotateInsert(node *node, val int) *node {
	// update the height on every insertion
	node.updateHeight()

	// bFactor will tell you which side the weight is on
	bFactor := node.balanceFactor()

	// linearly to the left
	if bFactor > 1 && val < node.left.value {
		return rightRotate(node)
	}

	// linearly to the right
	if bFactor < -1 && val > node.right.value {
		return leftRotate(node)
	}

	// less than symbol
	if bFactor > 1 && val > node.left.value {
		node.left = leftRotate(node.left)
		return rightRotate(node)
	}

	// greater than symbol
	if bFactor < -1 && val < node.right.value {
		node.right = rightRotate(node.right)
		return leftRotate(node)
	}

	return node
}

func rotateDelete(node *node) *node {
	node.updateHeight()
	bFactor := node.balanceFactor()

	// linearly to the left
	if bFactor > 1 && node.left.balanceFactor() >= 0 {
		return rightRotate(node)
	}

	// less than symbol
	if bFactor > 1 && node.left.balanceFactor() < 0 {
		node.left = leftRotate(node.left)
		return rightRotate(node)
	}

	// linearly to the right
	if bFactor < -1 && node.right.balanceFactor() <= 0 {
		return leftRotate(node)
	}

	// greater than symbol
	if bFactor < -1 && node.right.balanceFactor() > 0 {
		node.right = rightRotate(node.right)
		return leftRotate(node)
	}

	return node
}

func rightRotate(x *node) *node {
	y := x.left
	t := y.right

	y.right = x
	x.left = t

	x.updateHeight()
	y.updateHeight()

	return y
}

func leftRotate(x *node) *node {
	y := x.right
	t := y.left

	y.left = x
	x.right = t

	x.updateHeight()
	y.updateHeight()

	return y
}

func greatest(node *node) *node {
	if node == nil {
		return nil
	}

	if node.right == nil {
		return node
	}

	return greatest(node.right)
}

func traverse(node *node) {
	// exit condition
	if node == nil {
		return
	}

	// pre-order: visit the node first, then its left and right subtrees
	fmt.Println(node.value)
	traverse(node.left)
	traverse(node.right)
}

func max(a, b int) int {
	if a > b {
		return a
	}
	return b
}
```