AVL Tree in Go

tech · Nov 15, 2020 · ~13 min
Photo by @loicleray on Unsplash

After implementing a BST in Go, I wanted to turn it into an AVL tree, because a plain BST does not guarantee good performance: its shape depends entirely on the order in which values are inserted.

In the BST post, I wrote:

To find a specific node you don’t have to go around the whole tree, you need to know that BST can route to a specific node by checking the node value

That is only half true, because a BST can degenerate into a linear tree (for example, when values are inserted in sorted order) like this:

Linear Tree

If you want to find the node with value 6 in such a tree, you end up walking through the whole tree, so a search costs O(n) instead of O(log n). That's why we need an AVL tree: it rebalances itself whenever an insertion or deletion leaves it imbalanced, which keeps its height, and therefore the cost of a search, at O(log n).
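To see the problem concretely, here is a minimal sketch of a plain BST with no rebalancing. The bstNode type and the insert, height, and visits functions are illustrative stand-ins, not the code from the BST post. Inserting 0 through 7 in ascending order produces a linear tree, so a search for 6 touches seven nodes.

```go
package main

import "fmt"

// bstNode is an illustrative plain BST node with no height and no balancing.
type bstNode struct {
	value       int
	left, right *bstNode
}

// insert adds val without rebalancing; duplicates are ignored.
func insert(n *bstNode, val int) *bstNode {
	if n == nil {
		return &bstNode{value: val}
	}
	if val < n.value {
		n.left = insert(n.left, val)
	} else if val > n.value {
		n.right = insert(n.right, val)
	}
	return n
}

// height counts the nodes on the longest root-to-leaf path (nil is 0).
func height(n *bstNode) int {
	if n == nil {
		return 0
	}
	l, r := height(n.left), height(n.right)
	if l > r {
		return l + 1
	}
	return r + 1
}

// visits counts how many nodes a search for val touches.
func visits(n *bstNode, val int) int {
	count := 0
	for n != nil {
		count++
		if val == n.value {
			break
		}
		if val < n.value {
			n = n.left
		} else {
			n = n.right
		}
	}
	return count
}

func main() {
	var root *bstNode
	for i := 0; i <= 7; i++ {
		root = insert(root, i) // ascending order: every node gets only a right child
	}
	fmt.Println("height:", height(root))              // height: 8
	fmt.Println("visits to find 6:", visits(root, 6)) // visits to find 6: 7
}
```

With the AVL rebalancing described in the rest of this post, the same eight values end up in a tree of height 4 instead of 8.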

An AVL tree works much the same as a BST; the only addition is the rebalancing algorithm. After every insertion or deletion, we rebalance the tree by rotating each subtree that has become imbalanced. So we're going to reuse all the code from here and modify it a bit.

To tell whether the tree is balanced, we need the height of each node. A node's height is computed recursively as 1 plus the greater of its children's heights: a node with no children has height 1, and an empty (nil) subtree has height 0.

Update the node struct by adding a height field, add a Height getter, and set height to 1 in the constructor. Height checks for a nil receiver, so it is safe to call on a missing child and returns 0 in that case.

type node struct {
  height, value int
  left, right   *node
}

func (n *node) Height() int {
  if n == nil {
    return 0
  }

  return n.height
}

func newNode(val int) *node {
  return &node{
    height: 1,
    value:  val,
    left:   nil,
    right:  nil,
  }
}
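The cached height field should always agree with the recursive definition above. As a sketch, the hypothetical computeHeight helper below recomputes a height from scratch, and the equally hypothetical heightsConsistent compares it with every cached value. Because it recomputes each subtree repeatedly, it is only meant for tests, not for the tree itself.

```go
package main

import "fmt"

type node struct {
	height, value int
	left, right   *node
}

// computeHeight recomputes a subtree's height recursively: nil is 0,
// a leaf is 1, and any other node is 1 plus its taller child's height.
func computeHeight(n *node) int {
	if n == nil {
		return 0
	}
	l, r := computeHeight(n.left), computeHeight(n.right)
	if l > r {
		return l + 1
	}
	return r + 1
}

// heightsConsistent reports whether every cached height matches computeHeight.
func heightsConsistent(n *node) bool {
	if n == nil {
		return true
	}
	return n.height == computeHeight(n) &&
		heightsConsistent(n.left) && heightsConsistent(n.right)
}

func main() {
	// Root 2 with children 1 and 3; 4 is the right child of 3.
	four := &node{height: 1, value: 4}
	three := &node{height: 2, value: 3, right: four}
	one := &node{height: 1, value: 1}
	root := &node{height: 3, value: 2, left: one, right: three}

	fmt.Println(computeHeight(root))     // 3
	fmt.Println(heightsConsistent(root)) // true
}
```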

To keep track of the height and the balance of each node after insertion/deletion, we also need an updateHeight function and a balanceFactor function.

func (n *node) balanceFactor() int {
  if n == nil {
    return 0
  }

  return n.left.Height() - n.right.Height()
}

func max(a, b int) int {
  if a > b {
    return a
  }
  return b
}

func (n *node) updateHeight() {
  // the taller child's height, plus 1 for the node itself
  n.height = max(n.left.Height(), n.right.Height()) + 1
}

The balanceFactor function tells which side of a subtree is heavier. A positive value means the left side is taller, and a negative value means the right side is taller. A balance factor of -1, 0, or 1 is acceptable. If balanceFactor returns less than -1, the subtree is too heavy on the right and must be rotated to the left; if it returns greater than 1, it is too heavy on the left and must be rotated to the right.
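For example, a chain leaning to the right (1, then 2, then 3) has a balance factor of 0 - 2 = -2 at its root, which is outside the allowed range. This sketch reuses Height and balanceFactor from above, writes updateHeight without the max helper, and adds a hypothetical isAVL function that checks the threshold at every node.

```go
package main

import "fmt"

type node struct {
	height, value int
	left, right   *node
}

func (n *node) Height() int {
	if n == nil {
		return 0
	}
	return n.height
}

// updateHeight sets the height to the taller child's height plus 1.
func (n *node) updateHeight() {
	l, r := n.left.Height(), n.right.Height()
	if l < r {
		l = r
	}
	n.height = l + 1
}

func (n *node) balanceFactor() int {
	if n == nil {
		return 0
	}
	return n.left.Height() - n.right.Height()
}

// isAVL reports whether every node's balance factor is -1, 0, or 1.
func isAVL(n *node) bool {
	if n == nil {
		return true
	}
	bf := n.balanceFactor()
	return bf >= -1 && bf <= 1 && isAVL(n.left) && isAVL(n.right)
}

func main() {
	// Right-leaning chain 1 -> 2 -> 3, linked bottom-up so heights are correct.
	three := &node{value: 3}
	three.updateHeight()
	two := &node{value: 2, right: three}
	two.updateHeight()
	one := &node{value: 1, right: two}
	one.updateHeight()

	fmt.Println(one.balanceFactor()) // -2: too heavy on the right
	fmt.Println(isAVL(one))          // false
}
```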

Now let's create the rotation functions. There are 2 kinds of rotation, leftRotate and rightRotate, but there are 4 situations in which the tree has to be rotated on insertion and deletion. You can read about them, with pictures, here.

func rightRotate(x *node) *node {
  y := x.left
  t := y.right

  y.right = x
  x.left = t

  x.updateHeight()
  y.updateHeight()

  return y
}

func leftRotate(x *node) *node {
  y := x.right
  t := y.left

  y.left = x
  x.right = t

  x.updateHeight()
  y.updateHeight()

  return y
}

The conditions to rotate the tree on insertion are:

  1. When the tree leans linearly to the right (the right-right case), use leftRotate on the current node.
  2. When the tree leans linearly to the left (the left-left case), use rightRotate on the current node.
  3. When the tree forms a "less than" symbol (<, the left-right case), use leftRotate on the left child, then rightRotate on the current node.
  4. When the tree forms a "greater than" symbol (>, the right-left case), use rightRotate on the right child, then leftRotate on the current node.
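All four cases can be checked on three-node trees. This sketch copies rightRotate and leftRotate from above (with updateHeight written without the max helper) and builds each shape by hand with a hypothetical mk helper; preorder is likewise a helper added only for the check. In every case the result should be 2 at the root with children 1 and 3, and a height of 2.

```go
package main

import "fmt"

type node struct {
	height, value int
	left, right   *node
}

func (n *node) Height() int {
	if n == nil {
		return 0
	}
	return n.height
}

func (n *node) updateHeight() {
	l, r := n.left.Height(), n.right.Height()
	if l < r {
		l = r
	}
	n.height = l + 1
}

func rightRotate(x *node) *node {
	y := x.left
	t := y.right

	y.right = x
	x.left = t

	x.updateHeight()
	y.updateHeight()
	return y
}

func leftRotate(x *node) *node {
	y := x.right
	t := y.left

	y.left = x
	x.right = t

	x.updateHeight()
	y.updateHeight()
	return y
}

// mk is a hypothetical helper that builds a node with a correct cached height.
func mk(val int, left, right *node) *node {
	n := &node{value: val, left: left, right: right}
	n.updateHeight()
	return n
}

// preorder collects values in root, left, right order.
func preorder(n *node, out []int) []int {
	if n == nil {
		return out
	}
	out = append(out, n.value)
	out = preorder(n.left, out)
	return preorder(n.right, out)
}

func main() {
	// 1. Linearly to the right (1 -> 2 -> 3): leftRotate the current node.
	rr := leftRotate(mk(1, nil, mk(2, nil, mk(3, nil, nil))))

	// 2. Linearly to the left (3 -> 2 -> 1): rightRotate the current node.
	ll := rightRotate(mk(3, mk(2, mk(1, nil, nil), nil), nil))

	// 3. Less than symbol: 3, left child 1, whose right child is 2.
	lr := mk(3, mk(1, nil, mk(2, nil, nil)), nil)
	lr.left = leftRotate(lr.left)
	lr = rightRotate(lr)

	// 4. Greater than symbol: 1, right child 3, whose left child is 2.
	rl := mk(1, nil, mk(3, mk(2, nil, nil), nil))
	rl.right = rightRotate(rl.right)
	rl = leftRotate(rl)

	for _, root := range []*node{rr, ll, lr, rl} {
		fmt.Println(preorder(root, nil), root.Height()) // [2 1 3] 2
	}
}
```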

Insertion

func rotateInsert(node *node, val int) *node {
  // update the height on every insertion
  node.updateHeight()

  // bFactor will tell you which side the weight is on
  bFactor := node.balanceFactor()

  // linearly to the left
  if bFactor > 1 && val < node.left.value {
    return rightRotate(node)
  }

  // linearly to the right
  if bFactor < -1 && val > node.right.value {
    return leftRotate(node)
  }

  // less than symbol
  if bFactor > 1 && val > node.left.value {
    node.left = leftRotate(node.left)
    return rightRotate(node)
  }

  // greater than symbol
  if bFactor < -1 && val < node.right.value {
    node.right = rightRotate(node.right)
    return leftRotate(node)
  }

  return node
}

Lastly, update the return statement of the insertNode function, so that every node on the path back up from the newly inserted node gets rebalanced.

func insertNode(node *node, val int) (*node, error) {
  ...
  return rotateInsert(node, val), nil
}

Traverse Operation and Validation

To make the results easy to visualize, change the traverse function to a pre-order traversal, then open the BST Visualization page and the AVL Visualization page.

func traverse(node *node) {
  // exit condition
  if node == nil {
    return
  }

  fmt.Println(node.value)
  traverse(node.left)
  traverse(node.right)
}

func main() {
  tree := avl.New()

  // To check your implementation, first insert these values
  // one by one on the AVL Visualization page
  tree.Insert(0)
  tree.Insert(1)
  tree.Insert(2)
  tree.Insert(3)
  tree.Insert(4)
  tree.Insert(5)
  tree.Insert(6)
  tree.Insert(7)

  // Then insert the Traverse output, in the same order,
  // on the BST Visualization page
  tree.Traverse() // 3 1 0 2 5 4 6 7
}

Inserting a tree's pre-order sequence into an empty BST rebuilds exactly the same shape, so if the two visualizations are identical and balanced, your implementation is correct.
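If you would rather check the result in code than by comparing pictures, the sketch below replays the same insertions and collects the pre-order output in a slice instead of printing it. It condenses the code above into one file, under a few assumptions: rotateInsert keeps the same four checks written as a switch, the rotations move t with a parallel assignment, and insert and preorder are simplified, hypothetical stand-ins for insertNode and traverse (duplicates are skipped rather than returning ErrDuplicatedNode).

```go
package main

import "fmt"

type node struct {
	height, value int
	left, right   *node
}

func (n *node) Height() int {
	if n == nil {
		return 0
	}
	return n.height
}

func (n *node) updateHeight() {
	l, r := n.left.Height(), n.right.Height()
	if l < r {
		l = r
	}
	n.height = l + 1
}

// balanceFactor is only called on non-nil nodes here, so the nil check is dropped.
func (n *node) balanceFactor() int { return n.left.Height() - n.right.Height() }

func rightRotate(x *node) *node {
	y := x.left
	x.left, y.right = y.right, x // y's right subtree (t) moves under x
	x.updateHeight()
	y.updateHeight()
	return y
}

func leftRotate(x *node) *node {
	y := x.right
	x.right, y.left = y.left, x // y's left subtree (t) moves under x
	x.updateHeight()
	y.updateHeight()
	return y
}

// rotateInsert keeps the four checks from the post, written as a switch.
func rotateInsert(n *node, val int) *node {
	n.updateHeight()
	bf := n.balanceFactor()
	switch {
	case bf > 1 && val < n.left.value: // linearly to the left
		return rightRotate(n)
	case bf < -1 && val > n.right.value: // linearly to the right
		return leftRotate(n)
	case bf > 1 && val > n.left.value: // less than symbol
		n.left = leftRotate(n.left)
		return rightRotate(n)
	case bf < -1 && val < n.right.value: // greater than symbol
		n.right = rightRotate(n.right)
		return leftRotate(n)
	}
	return n
}

// insert is a simplified insertNode that skips duplicates instead of returning an error.
func insert(n *node, val int) *node {
	if n == nil {
		return &node{height: 1, value: val}
	}
	switch {
	case val < n.value:
		n.left = insert(n.left, val)
	case val > n.value:
		n.right = insert(n.right, val)
	default:
		return n // duplicate: leave the tree unchanged
	}
	return rotateInsert(n, val)
}

// preorder collects values root, left, right, like the traverse function.
func preorder(n *node, out []int) []int {
	if n == nil {
		return out
	}
	out = append(out, n.value)
	out = preorder(n.left, out)
	return preorder(n.right, out)
}

func main() {
	var root *node
	for i := 0; i <= 7; i++ {
		root = insert(root, i)
	}
	fmt.Println(preorder(root, nil)) // [3 1 0 2 5 4 6 7]
	fmt.Println(root.Height())       // 4
}
```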

Deletion

func rotateDelete(node *node) *node {
  node.updateHeight()
  bFactor := node.balanceFactor()

  // linearly to the left
  if bFactor > 1 && node.left.balanceFactor() >= 0 {
    return rightRotate(node)
  }
  
  // less than symbol
  if bFactor > 1 && node.left.balanceFactor() < 0 {
    node.left = leftRotate(node.left)
    return rightRotate(node)
  }

  // linearly to the right
  if bFactor < -1 && node.right.balanceFactor() <= 0 {
    return leftRotate(node)
  }

  // greater than symbol
  if bFactor < -1 && node.right.balanceFactor() > 0 {
    node.right = rightRotate(node.right)
    return leftRotate(node)
  }

  return node
}

Deletion differs from insertion: we can't decide the rotation by comparing the given value with the child's value, because that value has already been removed from the tree. That's why we compare the current node's balance factor with the balance factor of the child on the heavier side instead. Now, modify the removeNode function. Remember that when removing a node with 2 children, we need a replacement node to take its place, and there are 2 ways to find it:

  1. Find the node with the least value in the right subtree (the in-order successor)
  2. Find the node with the greatest value in the left subtree (the in-order predecessor)

The previous BST implementation used the first way, while the BST & AVL Visualization pages use the second way. The removeNode function in this post switches to the second way, so its results are easy to compare with the visualization.
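Both choices preserve the BST ordering; they only pick a different replacement value. As a sketch, the least function below is a hypothetical mirror of the greatest function used in this post. In the tree 4(2(1, 3), 6(5, 7)), written as root(left, right), the first way would copy 5 into the removed node and the second way copies 3.

```go
package main

import "fmt"

type node struct {
	value       int
	left, right *node
}

// greatest returns the rightmost (largest) node of a subtree, as in the post.
func greatest(n *node) *node {
	if n == nil {
		return nil
	}
	if n.right == nil {
		return n
	}
	return greatest(n.right)
}

// least returns the leftmost (smallest) node of a subtree.
func least(n *node) *node {
	if n == nil {
		return nil
	}
	if n.left == nil {
		return n
	}
	return least(n.left)
}

func leaf(v int) *node { return &node{value: v} }

func main() {
	left := &node{value: 2, left: leaf(1), right: leaf(3)}
	right := &node{value: 6, left: leaf(5), right: leaf(7)}
	root := &node{value: 4, left: left, right: right}

	fmt.Println("way 1, least of the right subtree:", least(root.right).value)    // 5
	fmt.Println("way 2, greatest of the left subtree:", greatest(root.left).value) // 3
}
```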

func greatest(node *node) *node {
  if node == nil {
    return nil
  }

  if node.right == nil {
    return node
  }

  return greatest(node.right)
}

func removeNode(node *node, val int) (*node, error) {
  if node == nil {
    return nil, ErrNodeNotFound
  }

  if val > node.value {
    right, err := removeNode(node.right, val)
    if err != nil {
      return nil, err
    }

    node.right = right
  } else if val < node.value {
    left, err := removeNode(node.left, val)
    if err != nil {
      return nil, err
    }

    node.left = left
  } else {
    if node.left != nil && node.right != nil {
      // has 2 children

      // find the in-order predecessor (greatest node in the left subtree)
      predecessor := greatest(node.left)
      value := predecessor.value

      // remove the predecessor from the left subtree
      left, err := removeNode(node.left, value)
      if err != nil {
        return nil, err
      }
      node.left = left

      // copy the predecessor value to the current node
      node.value = value
    } else if node.left != nil || node.right != nil {
      // has 1 child
      // move the child position to the current node
      if node.left != nil {
        node = node.left
      } else {
        node = node.right
      }
    } else if node.left == nil && node.right == nil {
      // has no child
      // simply remove the node
      node = nil
    }
  }

  if node == nil {
    return nil, nil
  }

  return rotateDelete(node), nil
}

You can validate and recheck your AVL implementation with the BST & AVL Visualization pages.
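For a quick check of the deletion path without the visualization pages, take the AVL tree 4(2, 6(5, 7)), written as root(left, right), and remove 2. Node 4 then has a balance factor of -2 while its right child 6 has a balance factor of 0, which cannot happen right after an insertion; the `<= 0` check sends it to the "linearly to the right" case, so a single leftRotate makes 6 the new root. The sketch condenses rotateDelete and removeNode from above under a few assumptions: remove is a simplified, hypothetical stand-in that ignores missing values instead of returning ErrNodeNotFound, and it returns a lone child directly because that child is already balanced.

```go
package main

import "fmt"

type node struct {
	height, value int
	left, right   *node
}

func (n *node) Height() int {
	if n == nil {
		return 0
	}
	return n.height
}

func (n *node) updateHeight() {
	l, r := n.left.Height(), n.right.Height()
	if l < r {
		l = r
	}
	n.height = l + 1
}

// balanceFactor is only called on non-nil nodes here, so the nil check is dropped.
func (n *node) balanceFactor() int { return n.left.Height() - n.right.Height() }

func rightRotate(x *node) *node {
	y := x.left
	x.left, y.right = y.right, x // y's right subtree (t) moves under x
	x.updateHeight()
	y.updateHeight()
	return y
}

func leftRotate(x *node) *node {
	y := x.right
	x.right, y.left = y.left, x // y's left subtree (t) moves under x
	x.updateHeight()
	y.updateHeight()
	return y
}

// rotateDelete keeps the four checks from the post, written as a switch.
func rotateDelete(n *node) *node {
	n.updateHeight()
	bf := n.balanceFactor()
	switch {
	case bf > 1 && n.left.balanceFactor() >= 0: // linearly to the left
		return rightRotate(n)
	case bf > 1: // less than symbol (left child leans right)
		n.left = leftRotate(n.left)
		return rightRotate(n)
	case bf < -1 && n.right.balanceFactor() <= 0: // linearly to the right
		return leftRotate(n)
	case bf < -1: // greater than symbol (right child leans left)
		n.right = rightRotate(n.right)
		return leftRotate(n)
	}
	return n
}

// remove is a simplified removeNode that ignores missing values.
func remove(n *node, val int) *node {
	if n == nil {
		return nil
	}
	switch {
	case val > n.value:
		n.right = remove(n.right, val)
	case val < n.value:
		n.left = remove(n.left, val)
	case n.left != nil && n.right != nil:
		pred := n.left // greatest node in the left subtree
		for pred.right != nil {
			pred = pred.right
		}
		n.value = pred.value
		n.left = remove(n.left, pred.value)
	case n.left != nil:
		return n.left // a lone child replaces n and is already balanced
	default:
		return n.right // the right child, or nil for a leaf
	}
	return rotateDelete(n)
}

func main() {
	// 4(2, 6(5, 7)) with hand-set heights; node 4 has balance factor -1.
	six := &node{height: 2, value: 6, left: &node{height: 1, value: 5}, right: &node{height: 1, value: 7}}
	root := &node{height: 3, value: 4, left: &node{height: 1, value: 2}, right: six}

	root = remove(root, 2) // 4 drops to -2 while 6 stays at 0: one leftRotate
	fmt.Println(root.value, root.left.value, root.left.right.value, root.right.value) // 6 4 5 7
}
```

Removing 6 afterwards exercises the two-children path: its predecessor 5 is copied up and the tree becomes 5(4, 7).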

Here is the modified node.go file.

node.go

package avl

import (
  "errors"
  "fmt"
)

var (
  ErrDuplicatedNode error = errors.New("avl: found duplicated value on tree")
  ErrNodeNotFound   error = errors.New("avl: node not found")
)

type node struct {
  height, value int
  left, right   *node
}

func (n *node) balanceFactor() int {
  if n == nil {
    return 0
  }

  return n.left.Height() - n.right.Height()
}

func (n *node) updateHeight() {
  // the taller child's height, plus 1 for the node itself
  n.height = max(n.left.Height(), n.right.Height()) + 1
}

func (n *node) Height() int {
  if n == nil {
    return 0
  }

  return n.height
}

func (n *node) Value() int {
  return n.value
}

func (n *node) Left() *node {
  return n.left
}

func (n *node) Right() *node {
  return n.right
}

func newNode(val int) *node {
  return &node{
    height: 1,
    value:  val,
    left:   nil,
    right:  nil,
  }
}

func insertNode(node *node, val int) (*node, error) {
  // if there's no node, create one
  if node == nil {
    return newNode(val), nil
  }

  // if there's duplicated node returns error
  if node.value == val {
    return nil, ErrDuplicatedNode
  }

  // if value is greater than current node's value, insert to the right
  if val > node.value {
    right, err := insertNode(node.right, val)

    if err != nil {
      return nil, err
    }

    node.right = right
  }

  // if value is less than current node's value, insert to the left
  if val < node.value {
    left, err := insertNode(node.left, val)

    if err != nil {
      return nil, err
    }

    node.left = left
  }

  return rotateInsert(node, val), nil
}

func removeNode(node *node, val int) (*node, error) {
  if node == nil {
    return nil, ErrNodeNotFound
  }

  if val > node.value {
    right, err := removeNode(node.right, val)
    if err != nil {
      return nil, err
    }

    node.right = right
  } else if val < node.value {
    left, err := removeNode(node.left, val)
    if err != nil {
      return nil, err
    }

    node.left = left
  } else {
    if node.left != nil && node.right != nil {
      // has 2 children

      // find the in-order predecessor (greatest node in the left subtree)
      predecessor := greatest(node.left)
      value := predecessor.value

      // remove the predecessor from the left subtree
      left, err := removeNode(node.left, value)
      if err != nil {
        return nil, err
      }
      node.left = left

      // copy the predecessor value to the current node
      node.value = value
    } else if node.left != nil || node.right != nil {
      // has 1 child
      // move the child position to the current node
      if node.left != nil {
        node = node.left
      } else {
        node = node.right
      }
    } else if node.left == nil && node.right == nil {
      // has no child
      // simply remove the node
      node = nil
    }
  }

  if node == nil {
    return nil, nil
  }

  return rotateDelete(node), nil
}

func findNode(node *node, val int) *node {
  if node == nil {
    return nil
  }

  // if the node is found, return the node
  if node.value == val {
    return node
  }

  // if value is greater than current node's value, search recursively to the right
  if val > node.value {
    return findNode(node.right, val)
  }

  // if value is less than current node's value, search recursively to the left
  if val < node.value {
    return findNode(node.left, val)
  }

  return nil
}

func rotateInsert(node *node, val int) *node {
  // update the height on every insertion
  node.updateHeight()

  // bFactor will tell you which side the weight is on
  bFactor := node.balanceFactor()

  // linearly to the left
  if bFactor > 1 && val < node.left.value {
    return rightRotate(node)
  }

  // linearly to the right
  if bFactor < -1 && val > node.right.value {
    return leftRotate(node)
  }

  // less than symbol
  if bFactor > 1 && val > node.left.value {
    node.left = leftRotate(node.left)
    return rightRotate(node)
  }

  // greater than symbol
  if bFactor < -1 && val < node.right.value {
    node.right = rightRotate(node.right)
    return leftRotate(node)
  }

  return node
}

func rotateDelete(node *node) *node {
  node.updateHeight()
  bFactor := node.balanceFactor()

  // linearly to the left
  if bFactor > 1 && node.left.balanceFactor() >= 0 {
    return rightRotate(node)
  }

  // less than symbol
  if bFactor > 1 && node.left.balanceFactor() < 0 {
    node.left = leftRotate(node.left)
    return rightRotate(node)
  }

  // linearly to the right
  if bFactor < -1 && node.right.balanceFactor() <= 0 {
    return leftRotate(node)
  }

  // greater than symbol
  if bFactor < -1 && node.right.balanceFactor() > 0 {
    node.right = rightRotate(node.right)
    return leftRotate(node)
  }

  return node
}

func rightRotate(x *node) *node {
  y := x.left
  t := y.right

  y.right = x
  x.left = t

  x.updateHeight()
  y.updateHeight()

  return y
}

func leftRotate(x *node) *node {
  y := x.right
  t := y.left

  y.left = x
  x.right = t

  x.updateHeight()
  y.updateHeight()

  return y
}

func greatest(node *node) *node {
  if node == nil {
    return nil
  }

  if node.right == nil {
    return node
  }

  return greatest(node.right)
}

func traverse(node *node) {
  // exit condition
  if node == nil {
    return
  }

  fmt.Println(node.value)
  traverse(node.left)
  traverse(node.right)
}

func max(a, b int) int {
  if a > b {
    return a
  }
  return b
}

Thank you for reading!
