Skip to content

Commit 8160bae

Browse files
authoredJan 17, 2023
基于红黑树的 TreeMap 实现 (ecodeclub#142)
1 parent 276e47b commit 8160bae

File tree

5 files changed

+2380
-0
lines changed

5 files changed

+2380
-0
lines changed
 

‎.CHANGELOG.md

+1
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
- [ekit: 修改代码风格,增加bool类型支持](https://github.com/gotomicro/ekit/pull/135)
66
- [mapx: hashmap添加刪除功能](https://github.com/gotomicro/ekit/pull/138)
77
- [mapx: HashMap 增加 Keys 和 Values 方法](https://github.com/gotomicro/ekit/pull/141)
8+
- [mapx: TreeMap](https://github.com/gotomicro/ekit/pull/142)
89

910
# v0.0.5
1011
- [atomicx: 泛型封装 atomic.Value](https://github.com/gotomicro/ekit/pull/101)

‎internal/tree/red_black_tree.go

+511
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,511 @@
1+
// Copyright 2021 gotomicro
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package tree
16+
17+
import (
18+
"errors"
19+
20+
"github.com/gotomicro/ekit"
21+
)
22+
23+
type color bool
24+
25+
const (
26+
Red color = false
27+
Black color = true
28+
)
29+
30+
var (
31+
ErrRBTreeSameRBNode = errors.New("ekit: RBTree不能添加重复节点Key")
32+
ErrRBTreeNotRBNode = errors.New("ekit: RBTree不存在节点Key")
33+
// errRBTreeCantRepaceNil = errors.New("ekit: RBTree不能将节点替换为nil")
34+
)
35+
36+
type RBTree[K any, V any] struct {
37+
root *rbNode[K, V]
38+
compare ekit.Comparator[K]
39+
size int
40+
}
41+
42+
func (rb *RBTree[K, V]) Size() int {
43+
if rb == nil {
44+
return 0
45+
}
46+
return rb.size
47+
}
48+
49+
type rbNode[K any, V any] struct {
50+
color color
51+
key K
52+
value V
53+
left, right, parent *rbNode[K, V]
54+
}
55+
56+
func (node *rbNode[K, V]) setNode(v V) {
57+
if node == nil {
58+
return
59+
}
60+
node.value = v
61+
}
62+
63+
// NewRBTree 构建红黑树
64+
func NewRBTree[K any, V any](compare ekit.Comparator[K]) *RBTree[K, V] {
65+
return &RBTree[K, V]{
66+
compare: compare,
67+
root: nil,
68+
}
69+
}
70+
71+
func newRBNode[K any, V any](key K, value V) *rbNode[K, V] {
72+
return &rbNode[K, V]{
73+
key: key,
74+
value: value,
75+
color: Red,
76+
left: nil,
77+
right: nil,
78+
parent: nil,
79+
}
80+
}
81+
82+
// Add 增加节点
83+
func (rb *RBTree[K, V]) Add(key K, value V) error {
84+
return rb.addNode(newRBNode(key, value))
85+
}
86+
87+
// Delete 删除节点
88+
func (rb *RBTree[K, V]) Delete(key K) {
89+
if node := rb.findNode(key); node != nil {
90+
rb.deleteNode(node)
91+
}
92+
}
93+
94+
// Find 查找节点
95+
func (rb *RBTree[K, V]) Find(key K) (V, error) {
96+
var v V
97+
if node := rb.findNode(key); node != nil {
98+
return node.value, nil
99+
}
100+
return v, ErrRBTreeNotRBNode
101+
}
102+
func (rb *RBTree[K, V]) Set(key K, value V) error {
103+
if node := rb.findNode(key); node != nil {
104+
node.setNode(value)
105+
return nil
106+
}
107+
return ErrRBTreeNotRBNode
108+
}
109+
110+
// addNode 插入新节点
111+
func (rb *RBTree[K, V]) addNode(node *rbNode[K, V]) error {
112+
var fixNode *rbNode[K, V]
113+
if rb.root == nil {
114+
rb.root = newRBNode[K, V](node.key, node.value)
115+
fixNode = rb.root
116+
} else {
117+
t := rb.root
118+
cmp := 0
119+
parent := &rbNode[K, V]{}
120+
for t != nil {
121+
parent = t
122+
cmp = rb.compare(node.key, t.key)
123+
if cmp < 0 {
124+
t = t.left
125+
} else if cmp > 0 {
126+
t = t.right
127+
} else if cmp == 0 {
128+
return ErrRBTreeSameRBNode
129+
}
130+
}
131+
fixNode = &rbNode[K, V]{
132+
key: node.key,
133+
parent: parent,
134+
value: node.value,
135+
color: Red,
136+
}
137+
if cmp < 0 {
138+
parent.left = fixNode
139+
} else {
140+
parent.right = fixNode
141+
}
142+
}
143+
rb.size++
144+
rb.fixAfterAdd(fixNode)
145+
return nil
146+
}
147+
148+
// deleteNode 红黑树删除方法
149+
// 删除分两步,第一步取出后继节点,第二部着色旋转
150+
// 取后继节点
151+
// case1:node左右非空子节点,通过getSuccessor获取后继节点
152+
// case2:node左右只有一个非空子节点
153+
// case3:node左右均为空节点
154+
// 着色旋转
155+
// case1:当删除节点非空且为黑色时,会违反红黑树任何路径黑节点个数相同的约束,所以需要重新平衡
156+
// case2:当删除红色节点时,不会破坏任何约束,所以不需要平衡
157+
func (rb *RBTree[K, V]) deleteNode(node *rbNode[K, V]) {
158+
// node左右非空,取后继节点
159+
if node.left != nil && node.right != nil {
160+
s := rb.findSuccessor(node)
161+
node.key = s.key
162+
node.value = s.value
163+
node = s
164+
}
165+
var replacement *rbNode[K, V]
166+
// node节点只有一个非空子节点
167+
if node.left != nil {
168+
replacement = node.left
169+
} else {
170+
replacement = node.right
171+
}
172+
if replacement != nil {
173+
replacement.parent = node.parent
174+
if node.parent == nil {
175+
rb.root = replacement
176+
} else if node == node.parent.left {
177+
node.parent.left = replacement
178+
} else {
179+
node.parent.right = replacement
180+
}
181+
node.left = nil
182+
node.right = nil
183+
node.parent = nil
184+
if node.getColor() {
185+
rb.fixAfterDelete(replacement)
186+
}
187+
} else if node.parent == nil {
188+
// 如果node节点无父节点,说明node为root节点
189+
rb.root = nil
190+
} else {
191+
// node子节点均为空
192+
if node.getColor() {
193+
rb.fixAfterDelete(node)
194+
}
195+
if node.parent != nil {
196+
if node == node.parent.left {
197+
node.parent.left = nil
198+
} else if node == node.parent.right {
199+
node.parent.right = nil
200+
}
201+
node.parent = nil
202+
}
203+
}
204+
rb.size--
205+
}
206+
207+
// findSuccessor 寻找后继节点
208+
// case1: node节点存在右子节点,则右子树的最小节点是node的后继节点
209+
// case2: node节点不存在右子节点,则其第一个为左节点的祖先的父节点为node的后继节点
210+
func (rb *RBTree[K, V]) findSuccessor(node *rbNode[K, V]) *rbNode[K, V] {
211+
if node == nil {
212+
return nil
213+
} else if node.right != nil {
214+
p := node.right
215+
for p.left != nil {
216+
p = p.left
217+
}
218+
return p
219+
} else {
220+
p := node.parent
221+
ch := node
222+
for p != nil && ch == p.right {
223+
ch = p
224+
p = p.parent
225+
}
226+
return p
227+
}
228+
229+
}
230+
231+
func (rb *RBTree[K, V]) findNode(key K) *rbNode[K, V] {
232+
node := rb.root
233+
for node != nil {
234+
cmp := rb.compare(key, node.key)
235+
if cmp < 0 {
236+
node = node.left
237+
} else if cmp > 0 {
238+
node = node.right
239+
} else {
240+
return node
241+
}
242+
}
243+
return nil
244+
}
245+
246+
// fixAfterAdd 插入时着色旋转
247+
// 如果是空节点、root节点、父节点是黑无需构建
248+
// 可分为3种情况
249+
// fixUncleRed 叔叔节点是红色右节点
250+
// fixAddLeftBlack 叔叔节点是黑色右节点
251+
// fixAddRightBlack 叔叔节点是黑色左节点
252+
func (rb *RBTree[K, V]) fixAfterAdd(x *rbNode[K, V]) {
253+
x.color = Red
254+
for x != nil && x != rb.root && x.getParent().getColor() == Red {
255+
uncle := x.getUncle()
256+
if uncle.getColor() == Red {
257+
x = rb.fixUncleRed(x, uncle)
258+
continue
259+
}
260+
if x.getParent() == x.getGrandParent().getLeft() {
261+
x = rb.fixAddLeftBlack(x)
262+
continue
263+
}
264+
x = rb.fixAddRightBlack(x)
265+
}
266+
rb.root.setColor(Black)
267+
}
268+
269+
// fixAddLeftRed 叔叔节点是红色右节点,由于不能存在连续红色节点,此时祖父节点x.getParent().getParent()必为黑。另x为红所以叔父节点需要变黑,祖父变红,此时红黑树完成
270+
//
271+
// b(b) b(r)
272+
// / \ / \
273+
// a(r) y(r) -> a(b) y(b)
274+
// / \ / \ / \ / \
275+
// x(r) nil nil nil x (r) nil nil nil
276+
// / \ / \
277+
// nil nil nil nil
278+
func (rb *RBTree[K, V]) fixUncleRed(x *rbNode[K, V], y *rbNode[K, V]) *rbNode[K, V] {
279+
x.getParent().setColor(Black)
280+
y.setColor(Black)
281+
x.getGrandParent().setColor(Red)
282+
x = x.getGrandParent()
283+
return x
284+
}
285+
286+
// fixAddLeftBlack 叔叔节点是黑色右节点.x节点是父节点左节点,执行左旋,此时x节点变为原x节点的父节点a,也就是左子节点。的接着将x的父节点和爷爷节点的颜色对换。然后对爷爷节点进行右旋转,此时红黑树完成
287+
// 如果x为左节点则跳过左旋操作
288+
//
289+
// b(b) b(b) b(r)
290+
// / \ / \ / \
291+
// a(r) y(b) -> a(r) y(b) -> a(b) y(b)
292+
// / \ / \ / \ / \ / \ / \
293+
// nil x (r) nil nil x(r) nil nil nil x(r) nil nil nil
294+
// / \ / \ / \
295+
// nil nil nil nil nil nil
296+
func (rb *RBTree[K, V]) fixAddLeftBlack(x *rbNode[K, V]) *rbNode[K, V] {
297+
if x == x.getParent().getRight() {
298+
x = x.getParent()
299+
rb.rotateLeft(x)
300+
}
301+
x.getParent().setColor(Black)
302+
x.getGrandParent().setColor(Red)
303+
rb.rotateRight(x.getGrandParent())
304+
return x
305+
}
306+
307+
// fixAddRightBlack 叔叔节点是黑色左节点.x节点是父节点右节点,执行右旋,此时x节点变为原x节点的父节点a,也就是右子节点。接着将x的父节点和爷爷节点的颜色对换。然后对爷爷节点进行右旋转,此时红黑树完成
308+
// 如果x为右节点则跳过右旋操作
309+
//
310+
// b(b) b(b) b(r)
311+
// / \ / \ / \
312+
// y(b) a(r) -> y(b) a(r) -> y(b) a(b)
313+
// / \ / \ / \ / \ / \ / \
314+
// nil nil x(r) nil nil nil nil x(r) nil nil nil x(r)
315+
// / \ / \ / \
316+
// nil nil nil nil nil nil
317+
func (rb *RBTree[K, V]) fixAddRightBlack(x *rbNode[K, V]) *rbNode[K, V] {
318+
if x == x.getParent().getLeft() {
319+
x = x.getParent()
320+
rb.rotateRight(x)
321+
}
322+
x.getParent().setColor(Black)
323+
x.getGrandParent().setColor(Red)
324+
rb.rotateLeft(x.getGrandParent())
325+
return x
326+
}
327+
328+
// fixAfterDelete 删除时着色旋转
329+
// 根据x是节点位置分为fixAfterDeleteLeft,fixAfterDeleteRight两种情况
330+
func (rb *RBTree[K, V]) fixAfterDelete(x *rbNode[K, V]) {
331+
for x != rb.root && x.getColor() == Black {
332+
if x == x.parent.getLeft() {
333+
x = rb.fixAfterDeleteLeft(x)
334+
} else {
335+
x = rb.fixAfterDeleteRight(x)
336+
}
337+
}
338+
x.setColor(Black)
339+
}
340+
341+
// fixAfterDeleteLeft 处理x为左子节点时的平衡处理
342+
func (rb *RBTree[K, V]) fixAfterDeleteLeft(x *rbNode[K, V]) *rbNode[K, V] {
343+
sib := x.getParent().getRight()
344+
if sib.getColor() == Red {
345+
sib.setColor(Black)
346+
sib.getParent().setColor(Red)
347+
rb.rotateLeft(x.getParent())
348+
sib = x.getParent().getRight()
349+
}
350+
if sib.getLeft().getColor() == Black && sib.getRight().getColor() == Black {
351+
sib.setColor(Red)
352+
x = x.getParent()
353+
} else {
354+
if sib.getRight().getColor() == Black {
355+
sib.getLeft().setColor(Black)
356+
sib.setColor(Red)
357+
rb.rotateRight(sib)
358+
sib = x.getParent().getRight()
359+
}
360+
sib.setColor(x.getParent().getColor())
361+
x.getParent().setColor(Black)
362+
sib.getRight().setColor(Black)
363+
rb.rotateLeft(x.getParent())
364+
x = rb.root
365+
}
366+
return x
367+
}
368+
369+
// fixAfterDeleteRight 处理x为右子节点时的平衡处理
370+
func (rb *RBTree[K, V]) fixAfterDeleteRight(x *rbNode[K, V]) *rbNode[K, V] {
371+
sib := x.getParent().getLeft()
372+
if sib.getColor() == Red {
373+
sib.setColor(Black)
374+
x.getParent().setColor(Red)
375+
rb.rotateRight(x.getParent())
376+
sib = x.getBrother()
377+
}
378+
if sib.getRight().getColor() == Black && sib.getLeft().getColor() == Black {
379+
sib.setColor(Red)
380+
x = x.getParent()
381+
} else {
382+
if sib.getLeft().getColor() == Black {
383+
sib.getRight().setColor(Black)
384+
sib.setColor(Red)
385+
rb.rotateLeft(sib)
386+
sib = x.getParent().getLeft()
387+
}
388+
sib.setColor(x.getParent().getColor())
389+
x.getParent().setColor(Black)
390+
sib.getLeft().setColor(Black)
391+
rb.rotateRight(x.getParent())
392+
x = rb.root
393+
}
394+
return x
395+
}
396+
397+
// rotateLeft 左旋转
398+
//
399+
// b a
400+
// / \ / \
401+
// c a -> b y
402+
// / \ / \
403+
// x y c x
404+
405+
func (rb *RBTree[K, V]) rotateLeft(node *rbNode[K, V]) {
406+
if node == nil || node.getRight() == nil {
407+
return
408+
}
409+
r := node.right
410+
node.right = r.left
411+
if r.left != nil {
412+
r.left.parent = node
413+
}
414+
r.parent = node.parent
415+
if node.parent == nil {
416+
rb.root = r
417+
} else if node.parent.left == node {
418+
node.parent.left = r
419+
} else {
420+
node.parent.right = r
421+
}
422+
r.left = node
423+
node.parent = r
424+
425+
}
426+
427+
// rotateRight 右旋转
428+
//
429+
// b c
430+
// / \ / \
431+
// c a -> x b
432+
// / \ / \
433+
// x y y a
434+
func (rb *RBTree[K, V]) rotateRight(node *rbNode[K, V]) {
435+
if node == nil || node.getLeft() == nil {
436+
return
437+
}
438+
l := node.left
439+
node.left = l.right
440+
if l.right != nil {
441+
l.right.parent = node
442+
}
443+
l.parent = node.parent
444+
if node.parent == nil {
445+
rb.root = l
446+
} else if node.parent.right == node {
447+
node.parent.right = l
448+
} else {
449+
node.parent.left = l
450+
}
451+
l.right = node
452+
node.parent = l
453+
454+
}
455+
456+
func (node *rbNode[K, V]) getColor() color {
457+
if node == nil {
458+
return Black
459+
}
460+
return node.color
461+
}
462+
463+
func (node *rbNode[K, V]) setColor(color color) {
464+
if node == nil {
465+
return
466+
}
467+
node.color = color
468+
}
469+
470+
func (node *rbNode[K, V]) getParent() *rbNode[K, V] {
471+
if node == nil {
472+
return nil
473+
}
474+
return node.parent
475+
}
476+
477+
func (node *rbNode[K, V]) getLeft() *rbNode[K, V] {
478+
if node == nil {
479+
return nil
480+
}
481+
return node.left
482+
}
483+
484+
func (node *rbNode[K, V]) getRight() *rbNode[K, V] {
485+
if node == nil {
486+
return nil
487+
}
488+
return node.right
489+
}
490+
491+
func (node *rbNode[K, V]) getUncle() *rbNode[K, V] {
492+
if node == nil {
493+
return nil
494+
}
495+
return node.getParent().getBrother()
496+
}
497+
func (node *rbNode[K, V]) getGrandParent() *rbNode[K, V] {
498+
if node == nil {
499+
return nil
500+
}
501+
return node.getParent().getParent()
502+
}
503+
func (node *rbNode[K, V]) getBrother() *rbNode[K, V] {
504+
if node == nil {
505+
return nil
506+
}
507+
if node == node.getParent().getLeft() {
508+
return node.getParent().getRight()
509+
}
510+
return node.getParent().getLeft()
511+
}

‎internal/tree/red_black_tree_test.go

+1,400
Large diffs are not rendered by default.

‎mapx/treemap.go

+85
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
// Copyright 2021 gotomicro
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package mapx
16+
17+
import (
18+
"errors"
19+
20+
"github.com/gotomicro/ekit"
21+
"github.com/gotomicro/ekit/internal/tree"
22+
)
23+
24+
var (
25+
errTreeMapComparatorIsNull = errors.New("ekit: Comparator不能为nil")
26+
)
27+
28+
// TreeMap 是基于红黑树实现的Map
29+
type TreeMap[K any, V any] struct {
30+
*tree.RBTree[K, V]
31+
}
32+
33+
// NewTreeMapWithMap TreeMap构造方法
34+
// 支持通过传入的map构造生成TreeMap
35+
func NewTreeMapWithMap[K comparable, V any](compare ekit.Comparator[K], m map[K]V) (*TreeMap[K, V], error) {
36+
treeMap, err := NewTreeMap[K, V](compare)
37+
if err != nil {
38+
return treeMap, err
39+
}
40+
putAll(treeMap, m)
41+
return treeMap, nil
42+
}
43+
44+
// NewTreeMap TreeMap构造方法,创建一个的TreeMap
45+
// 需注意比较器compare不能为nil
46+
func NewTreeMap[K any, V any](compare ekit.Comparator[K]) (*TreeMap[K, V], error) {
47+
if compare == nil {
48+
return nil, errTreeMapComparatorIsNull
49+
}
50+
return &TreeMap[K, V]{
51+
RBTree: tree.NewRBTree[K, V](compare),
52+
}, nil
53+
}
54+
55+
// putAll 将map传入TreeMap
56+
// 需注意如果map中的key已存在,value将被替换
57+
func putAll[K comparable, V any](treeMap *TreeMap[K, V], m map[K]V) {
58+
for k, v := range m {
59+
_ = treeMap.Put(k, v)
60+
}
61+
}
62+
63+
// Put 在TreeMap插入指定值
64+
// 需注意如果TreeMap已存在该Key那么原值会被替换
65+
func (treeMap *TreeMap[K, V]) Put(key K, value V) error {
66+
err := treeMap.Add(key, value)
67+
if err == tree.ErrRBTreeSameRBNode {
68+
return treeMap.Set(key, value)
69+
}
70+
return nil
71+
}
72+
73+
// Get 在TreeMap找到指定Key的节点,返回Val
74+
// TreeMap未找到指定节点将会返回false
75+
func (treeMap *TreeMap[K, V]) Get(key K) (V, bool) {
76+
v, err := treeMap.Find(key)
77+
return v, err == nil
78+
}
79+
80+
// Remove TreeMap中删除指定key的节点
81+
func (treeMap *TreeMap[T, V]) Remove(k T) {
82+
treeMap.Delete(k)
83+
}
84+
85+
var _ mapi[any, any] = (*TreeMap[any, any])(nil)

‎mapx/treemap_test.go

+383
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,383 @@
1+
// Copyright 2021 gotomicro
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package mapx
16+
17+
import (
18+
"errors"
19+
"testing"
20+
21+
"github.com/gotomicro/ekit"
22+
"github.com/stretchr/testify/assert"
23+
)
24+
25+
func TestNewTreeMapWithMap(t *testing.T) {
26+
tests := []struct {
27+
name string
28+
m map[int]int
29+
comparable ekit.Comparator[int]
30+
wantKey []int
31+
wantVal []int
32+
wantErr error
33+
}{
34+
{
35+
name: "nil",
36+
m: nil,
37+
comparable: nil,
38+
wantKey: nil,
39+
wantVal: nil,
40+
wantErr: errors.New("ekit: Comparator不能为nil"),
41+
},
42+
{
43+
name: "empty",
44+
m: map[int]int{},
45+
comparable: compare(),
46+
wantKey: nil,
47+
wantVal: nil,
48+
wantErr: nil,
49+
},
50+
{
51+
name: "single",
52+
m: map[int]int{
53+
0: 0,
54+
},
55+
comparable: compare(),
56+
wantKey: []int{0},
57+
wantVal: []int{0},
58+
wantErr: nil,
59+
},
60+
{
61+
name: "multiple",
62+
m: map[int]int{
63+
0: 0,
64+
1: 1,
65+
2: 2,
66+
},
67+
comparable: compare(),
68+
wantKey: []int{0, 1, 2},
69+
wantVal: []int{0, 1, 2},
70+
wantErr: nil,
71+
},
72+
{
73+
name: "disorder",
74+
m: map[int]int{
75+
1: 1,
76+
2: 2,
77+
0: 0,
78+
3: 3,
79+
5: 5,
80+
4: 4,
81+
},
82+
comparable: compare(),
83+
wantKey: []int{0, 1, 2, 3, 5, 4},
84+
wantVal: []int{0, 1, 2, 3, 5, 4},
85+
wantErr: nil,
86+
},
87+
}
88+
for _, tt := range tests {
89+
t.Run(tt.name, func(t *testing.T) {
90+
treeMap, err := NewTreeMapWithMap[int, int](tt.comparable, tt.m)
91+
if err != nil {
92+
assert.Equal(t, tt.wantErr, err)
93+
return
94+
}
95+
for k, v := range tt.m {
96+
value, _ := treeMap.Get(k)
97+
assert.Equal(t, true, v == value)
98+
}
99+
100+
})
101+
102+
}
103+
}
104+
105+
func TestTreeMap_Get(t *testing.T) {
106+
var tests = []struct {
107+
name string
108+
m map[int]int
109+
findKey int
110+
wantVal int
111+
wantBool bool
112+
}{
113+
{
114+
name: "empty-TreeMap",
115+
m: map[int]int{},
116+
findKey: 0,
117+
wantVal: 0,
118+
wantBool: false,
119+
},
120+
{
121+
name: "find",
122+
m: map[int]int{
123+
1: 1,
124+
2: 2,
125+
0: 0,
126+
3: 3,
127+
5: 5,
128+
4: 4,
129+
},
130+
findKey: 2,
131+
wantVal: 2,
132+
wantBool: true,
133+
},
134+
{
135+
name: "not-find",
136+
m: map[int]int{
137+
1: 1,
138+
2: 2,
139+
0: 0,
140+
3: 3,
141+
5: 5,
142+
4: 4,
143+
},
144+
findKey: 6,
145+
wantVal: 0,
146+
wantBool: false,
147+
},
148+
}
149+
for _, tt := range tests {
150+
t.Run(tt.name, func(t *testing.T) {
151+
treeMap, _ := NewTreeMap[int, int](compare())
152+
putAll(treeMap, tt.m)
153+
val, b := treeMap.Get(tt.findKey)
154+
assert.Equal(t, tt.wantBool, b)
155+
assert.Equal(t, tt.wantVal, val)
156+
})
157+
}
158+
}
159+
160+
func TestTreeMap_Put(t *testing.T) {
161+
162+
tests := []struct {
163+
name string
164+
k []int
165+
v []string
166+
wantKey []int
167+
wantVal []string
168+
wantErr error
169+
}{
170+
{
171+
name: "single",
172+
k: []int{0},
173+
v: []string{"0"},
174+
wantKey: []int{0},
175+
wantVal: []string{"0"},
176+
wantErr: nil,
177+
},
178+
{
179+
name: "multiple",
180+
k: []int{0, 1, 2},
181+
v: []string{"0", "1", "2"},
182+
wantKey: []int{0, 1, 2},
183+
wantVal: []string{"0", "1", "2"},
184+
wantErr: nil,
185+
},
186+
{
187+
name: "same",
188+
k: []int{0, 0, 1, 2, 2, 3},
189+
v: []string{"0", "999", "1", "998", "2", "3"},
190+
wantKey: []int{0, 1, 2, 3},
191+
wantVal: []string{"999", "1", "2", "3"},
192+
wantErr: nil,
193+
},
194+
{
195+
name: "same",
196+
k: []int{0, 0},
197+
v: []string{"0", "999"},
198+
wantKey: []int{0},
199+
wantVal: []string{"999"},
200+
wantErr: nil,
201+
},
202+
{
203+
name: "disorder",
204+
k: []int{1, 2, 0, 3, 5, 4},
205+
v: []string{"1", "2", "0", "3", "5", "4"},
206+
wantKey: []int{0, 1, 2, 3, 4, 5},
207+
wantVal: []string{"0", "1", "2", "3", "4", "5"},
208+
wantErr: nil,
209+
},
210+
{
211+
name: "disorder-same",
212+
k: []int{1, 3, 2, 0, 2, 3},
213+
v: []string{"1", "2", "998", "0", "3", "997"},
214+
wantKey: []int{0, 1, 2, 3},
215+
wantVal: []string{"0", "1", "3", "997"},
216+
wantErr: nil,
217+
},
218+
}
219+
for _, tt := range tests {
220+
t.Run(tt.name, func(t *testing.T) {
221+
treeMap, _ := NewTreeMap[int, string](compare())
222+
for i := 0; i < len(tt.k); i++ {
223+
err := treeMap.Put(tt.k[i], tt.v[i])
224+
if err != nil {
225+
assert.Equal(t, tt.wantErr, err)
226+
return
227+
}
228+
}
229+
for i := 0; i < len(tt.wantKey); i++ {
230+
v, b := treeMap.Get(tt.wantKey[i])
231+
assert.Equal(t, true, b)
232+
assert.Equal(t, tt.wantVal[i], v)
233+
}
234+
235+
})
236+
}
237+
subTests := []struct {
238+
name string
239+
k []int
240+
v []string
241+
wantKey []int
242+
wantVal []string
243+
wantErr error
244+
}{
245+
{
246+
name: "nil",
247+
k: []int{0},
248+
v: nil,
249+
wantKey: []int{0},
250+
wantVal: []string(nil),
251+
},
252+
{
253+
name: "nil",
254+
k: []int{0},
255+
v: []string{"0"},
256+
wantKey: []int{0},
257+
wantVal: []string{"0"},
258+
},
259+
}
260+
for _, tt := range subTests {
261+
t.Run(tt.name, func(t *testing.T) {
262+
treeMap, _ := NewTreeMap[int, []string](compare())
263+
for i := 0; i < len(tt.k); i++ {
264+
err := treeMap.Put(tt.k[i], tt.v)
265+
if err != nil {
266+
assert.Equal(t, tt.wantErr, err)
267+
return
268+
}
269+
}
270+
for i := 0; i < len(tt.wantKey); i++ {
271+
v, b := treeMap.Get(tt.wantKey[i])
272+
assert.Equal(t, true, b)
273+
assert.Equal(t, tt.wantVal, v)
274+
}
275+
276+
})
277+
}
278+
}
279+
280+
func TestTreeMap_Remove(t *testing.T) {
281+
var tests = []struct {
282+
name string
283+
m map[int]int
284+
delKey int
285+
wantVal int
286+
wantBool bool
287+
}{
288+
{
289+
name: "empty-TreeMap",
290+
m: map[int]int{},
291+
delKey: 0,
292+
wantVal: 0,
293+
},
294+
{
295+
name: "find",
296+
m: map[int]int{
297+
1: 1,
298+
2: 2,
299+
0: 0,
300+
3: 3,
301+
5: 5,
302+
4: 4,
303+
},
304+
delKey: 2,
305+
wantVal: 0,
306+
},
307+
{
308+
name: "not-find",
309+
m: map[int]int{
310+
1: 1,
311+
2: 2,
312+
0: 0,
313+
3: 3,
314+
5: 5,
315+
4: 4,
316+
},
317+
delKey: 6,
318+
wantVal: 0,
319+
},
320+
}
321+
for _, tt := range tests {
322+
t.Run(tt.name, func(t *testing.T) {
323+
treeMap, _ := NewTreeMap[int, int](compare())
324+
treeMap.Remove(tt.delKey)
325+
val, err := treeMap.Get(tt.delKey)
326+
assert.Equal(t, tt.wantBool, err)
327+
assert.Equal(t, tt.wantVal, val)
328+
})
329+
}
330+
}
331+
332+
func compare() ekit.Comparator[int] {
333+
return ekit.ComparatorRealNumber[int]
334+
}
335+
336+
// goos: windows
337+
// goarch: amd64
338+
// pkg: github.com/gotomicro/ekit/mapx
339+
// cpu: Intel(R) Core(TM) i5-7500 CPU @ 3.40GHz
340+
// BenchmarkTreeMap/treeMap_put-4 10000 250.6 ns/op 95 B/op 1 allocs/op
341+
// BenchmarkTreeMap/map_put-4 10000 103.0 ns/op 68 B/op 0 allocs/op
342+
// BenchmarkTreeMap/hashMap_put-4 10000 250.6 ns/op 107 B/op 1 allocs/op
343+
// BenchmarkTreeMap/treeMap_get-4 10000 52.16 ns/op 0 B/op 0 allocs/op
344+
// BenchmarkTreeMap/map_get-4 10000 0 B/op 0 allocs/op
345+
// BenchmarkTreeMap/hashMap_get-4 10000 52.89 ns/op 7 B/op 0 allocs/op
346+
// PASS
347+
// ok github.com/gotomicro/ekit/mapx 0.797s
348+
func BenchmarkTreeMap(b *testing.B) {
349+
hashmap := NewHashMap[hashInt, int](10)
350+
treeMap, _ := NewTreeMap[uint64, int](ekit.ComparatorRealNumber[uint64])
351+
m := make(map[uint64]int, 10)
352+
b.Run("treeMap_put", func(b *testing.B) {
353+
for i := 0; i < b.N; i++ {
354+
_ = treeMap.Put(uint64(i), i)
355+
}
356+
})
357+
b.Run("map_put", func(b *testing.B) {
358+
for i := 0; i < b.N; i++ {
359+
m[uint64(i)] = i
360+
}
361+
})
362+
b.Run("hashMap_put", func(b *testing.B) {
363+
for i := 0; i < b.N; i++ {
364+
_ = hashmap.Put(hashInt(uint64(i)), i)
365+
}
366+
})
367+
b.Run("treeMap_get", func(b *testing.B) {
368+
for i := 0; i < b.N; i++ {
369+
_, _ = treeMap.Get(uint64(i))
370+
}
371+
})
372+
b.Run("map_get", func(b *testing.B) {
373+
for i := 0; i < b.N; i++ {
374+
_ = m[uint64(i)]
375+
}
376+
})
377+
b.Run("hashMap_get", func(b *testing.B) {
378+
for i := 0; i < b.N; i++ {
379+
_, _ = hashmap.Get(hashInt(uint64(i)))
380+
}
381+
})
382+
383+
}

0 commit comments

Comments
 (0)
Please sign in to comment.