├── .github ├── ISSUE_TEMPLATE │ ├── bug-report.md │ ├── feature-request.md │ └── general-question.md └── workflows │ └── go.yml ├── .gitignore ├── LICENSE ├── Makefile ├── README.md ├── api └── api.go ├── avltree ├── avltree.go ├── avltree_bench_test.go └── avltree_test.go ├── btree ├── btree.go ├── btree_bench_test.go └── btree_test.go ├── cmap ├── cmap.go ├── cmap_bench_test.go ├── cmap_stdmap.go └── cmap_test.go ├── cmp └── cmp.go ├── go.mod ├── go.sum ├── ifop ├── ifop.go └── ifop_test.go ├── linkedlist ├── linkedlist.go ├── linkedlist_benchmark_test.go └── linkedlist_test.go ├── mapex ├── mapex.go └── mapex_test.go ├── must └── must.go ├── radix └── radix.go ├── rbtree ├── rbtree.go ├── rbtree_bench_test.go └── rbtree_test.go ├── rhashmap ├── opt.go ├── rhashmap.go ├── rhashmap_bench_test.go └── rhashmap_test.go ├── rwmap ├── rwmap.go ├── rwmap_bench_test.go └── rwmap_test.go ├── set ├── set.go └── set_test.go ├── skiplist ├── skiplist.go ├── skiplist_bench_test.go └── skiplist_test.go ├── trie ├── trie_map.go └── trie_map_test.go ├── vec ├── example_search_test.go ├── vec.go └── vec_test.go └── vecdeque ├── vecdeque.go └── vecdeque_test.go /.github/ISSUE_TEMPLATE/bug-report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: "\U0001F91D Bug Report" 3 | about: As a User, I want to report a Bug. 4 | labels: type/bug 5 | --- 6 | 7 | ## Bug Report 8 | 9 | Please answer these questions before submitting your issue. Thanks! 10 | 11 | ### 1. Minimal reproduce step (Required) 12 | 13 | 14 | 15 | ### 2. What did you expect to see? (Required) 16 | 17 | ### 3. What did you see instead (Required) 18 | 19 | ### 4. What is your gstl version? 
(Required) 20 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature-request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: "\U0001F44F Feature Request" 3 | about: As a user, I want to request a New Feature on the product. 4 | labels: type/feature-request 5 | --- 6 | 7 | ## Feature Request 8 | 9 | **Is your feature request related to a problem? Please describe:** 10 | 11 | 12 | **Describe the feature you'd like:** 13 | 14 | 15 | **Describe alternatives you've considered:** 16 | 17 | 18 | **Teachability, Documentation, Adoption, Migration Strategy:** 19 | 20 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/general-question.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: "\U0001F600 Ask a Question" 3 | about: I want to ask a question. 4 | labels: type/question 5 | --- 6 | 7 | ## General Question 8 | 9 | 20 | -------------------------------------------------------------------------------- /.github/workflows/go.yml: -------------------------------------------------------------------------------- 1 | name: Go 2 | 3 | on: 4 | push: 5 | pull_request: 6 | 7 | jobs: 8 | 9 | build: 10 | runs-on: ubuntu-latest 11 | strategy: 12 | matrix: 13 | go: ['1.22', '1.23'] 14 | name: Go ${{ matrix.go }} sample 15 | 16 | steps: 17 | 18 | - name: Set up Go ${{ matrix.go }} 19 | uses: actions/setup-go@v1 20 | with: 21 | go-version: ${{ matrix.go }} 22 | id: go 23 | 24 | - name: Check out code into the Go module directory 25 | uses: actions/checkout@v1 26 | 27 | - name: Get dependencies 28 | run: | 29 | go get -v -t -d ./... 30 | if [ -f Gopkg.toml ]; then 31 | curl https://raw.githubusercontent.com/golang/dep/master/install.sh | sh 32 | dep ensure 33 | fi 34 | 35 | - name: Test 36 | run: go test -v -coverprofile='coverage.out' -covermode=count ./...
37 | 38 | - name: Upload Coverage report 39 | uses: codecov/codecov-action@v1 40 | with: 41 | token: ${{secrets.CODECOV_TOKEN}} 42 | file: ./coverage.out 43 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Binaries for programs and plugins 2 | *.exe 3 | *.exe~ 4 | *.dll 5 | *.so 6 | *.dylib 7 | 8 | # Test binary, built with `go test -c` 9 | *.test 10 | 11 | # Output of the go coverage tool, specifically when used with LiteIDE 12 | *.out 13 | *.log 14 | 15 | # Dependency directories (remove the comment below to include it) 16 | # vendor/ 17 | cover.cov 18 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 
25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. 
You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. 
(Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | all: fmt test 2 | 3 | # lvim没有很好适配泛型语, 不能自动格式化. 这里先手动执行下 4 | bug: 5 | #go test -test.run=Test_Btree_RangePrev ./... 6 | 7 | fmt: 8 | go fmt ./... 9 | test: 10 | go test ./... 
11 | 12 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # gstl 2 | 支持泛型的数据结构库 3 | [![Go](https://github.com/antlabs/gstl/workflows/Go/badge.svg)](https://github.com/antlabs/gstl/actions) 4 | [![codecov](https://codecov.io/gh/antlabs/gstl/branch/master/graph/badge.svg)](https://codecov.io/gh/antlabs/gstl) 5 | 6 | ## 一、`vec` 7 | ```go 8 | 9 | ``` 10 | ## 二、`LinkedList` 11 | 12 | `LinkedList` 是一个支持泛型的双向链表容器,提供了加锁和不加锁的实现。 13 | 14 | #### 不加锁的使用方式 15 | 16 | ```go 17 | package main 18 | 19 | import ( 20 | "fmt" 21 | "github.com/antlabs/gstl/linkedlist" 22 | ) 23 | 24 | func main() { 25 | // 创建一个不加锁的链表 26 | list := linkedlist.New[int]() 27 | 28 | // 插入元素 29 | list.PushBack(1) 30 | list.PushFront(0) 31 | 32 | // 遍历链表 33 | list.Range(func(value int) { 34 | fmt.Println(value) 35 | }) 36 | 37 | // 删除元素 38 | list.Remove(0) 39 | } 40 | ``` 41 | 42 | #### 加锁的使用方式 43 | 44 | ```go 45 | package main 46 | 47 | import ( 48 | "fmt" 49 | "sync" 50 | "github.com/antlabs/gstl/linkedlist" 51 | ) 52 | 53 | func main() { 54 | // 创建一个加锁的链表 55 | list := linkedlist.NewConcurrent[int]() 56 | 57 | var wg sync.WaitGroup 58 | wg.Add(2) 59 | 60 | // 并发插入元素 61 | go func() { 62 | defer wg.Done() 63 | list.PushBack(1) 64 | list.PushFront(0) 65 | }() 66 | 67 | // 并发遍历链表 68 | go func() { 69 | defer wg.Done() 70 | list.Range(func(value int) { 71 | fmt.Println(value) 72 | }) 73 | }() 74 | 75 | wg.Wait() 76 | 77 | // 删除元素 78 | list.Remove(0) 79 | } 80 | ``` 81 | 82 | ### 区别 83 | 84 | - **不加锁的链表**:适用于单线程环境,性能更高。 85 | - **加锁的链表**:适用于多线程环境,保证线程安全。 86 | 87 | ## 三、`rhashmap` 88 | 和标准库不同的地方是有序hash 89 | ```go 90 | ``` 91 | 92 | ## 四、`btree` 93 | ```go 94 | ``` 95 | ## 五、`SkipList` 96 | 97 | `SkipList` 是一种高效的有序数据结构,支持快速的插入、删除和查找操作。 98 | 99 | #### 基本使用 100 | 101 | ```go 102 | package main 103 | 104 | import ( 105 | "fmt" 106 | "github.com/antlabs/gstl/skiplist" 107 | ) 108 | 109 | func main()
{ 110 | // 创建一个新的 SkipList 111 | sl := skiplist.New[int, string]() 112 | 113 | // 插入元素 114 | sl.Insert(1, "one") 115 | sl.Insert(2, "two") 116 | 117 | // 获取元素 118 | if value, ok := sl.Get(1); ok { 119 | fmt.Println("Key 1:", value) 120 | } 121 | 122 | // 删除元素 123 | sl.Delete(1) 124 | } 125 | ``` 126 | 127 | #### 并发安全的使用 128 | 129 | ```go 130 | package main 131 | 132 | import ( 133 | "fmt" 134 | "sync" 135 | "github.com/antlabs/gstl/skiplist" 136 | ) 137 | 138 | func main() { 139 | // 创建一个并发安全的 SkipList 140 | csl := skiplist.NewConcurrent[int, string]() 141 | var wg sync.WaitGroup 142 | 143 | // 并发插入元素 144 | wg.Add(2) 145 | go func() { 146 | defer wg.Done() 147 | csl.Insert(1, "one") 148 | }() 149 | go func() { 150 | defer wg.Done() 151 | csl.Insert(2, "two") 152 | }() 153 | wg.Wait() 154 | 155 | // 并发获取元素 156 | wg.Add(2) 157 | go func() { 158 | defer wg.Done() 159 | if value, ok := csl.Get(1); ok { 160 | fmt.Println("Key 1:", value) 161 | } 162 | }() 163 | go func() { 164 | defer wg.Done() 165 | if value, ok := csl.Get(2); ok { 166 | fmt.Println("Key 2:", value) 167 | } 168 | }() 169 | wg.Wait() 170 | } 171 | ``` 172 | 173 | ## 六、`rbtree` 174 | ```go 175 | ``` 176 | 177 | ## 七、`avltree` 178 | ```go 179 | ``` 180 | 181 | ## 八、`trie` 182 | ```go 183 | // 声明一个bool类型的trie tree 184 | t := trie.New[bool]() 185 | 186 | // 新增一个key 187 | t.Set("hello", true) 188 | 189 | // 获取值 190 | v := t.Get("hello") 191 | 192 | // 检查trie中是否有hello前缀的数据 193 | ok := t.HasPrefix("hello") 194 | 195 | // 删除键 196 | t.Delete("hello") 197 | 198 | // 返回trie中保存的元素个数 199 | t.Len() 200 | ``` 201 | 202 | ## 九、`set` 203 | ```go 204 | // 声明一个string类型的set 205 | s := set.New[string]() 206 | 207 | // 新加成员 208 | s.Set("1") 209 | s.Set("2") 210 | s.Set("3") 211 | 212 | // 查看某个变量是否存在set中 213 | s.IsMember("1") 214 | 215 | // 长度 216 | s.Len() 217 | 218 | // set转slice 219 | s.ToSlice() 220 | 221 | // 深度复制一份 222 | newSet := s.Clone() 223 | 224 | // 集合取差集 s - s2 225 | s := From("hello", "world", "1234", "4567") 226
| s2 := From("1234", "4567") 227 | 228 | newSet := s.Diff(s2) 229 | assert.Equal(t, newSet.ToSlice(), []string{"hello", "world"}) 230 | 231 | // 集合取交集 232 | s := From("1234", "5678", "9abc") 233 | s2 := From("abcde", "5678", "9abc") 234 | 235 | v := s.Intersection(s2).ToSlice() 236 | assert.Equal(t, v, []string{"5678", "9abc"}) 237 | 238 | // 集合取并集 239 | s := From("1111") 240 | s1 := From("2222") 241 | s2 := From("3333") 242 | 243 | newSet := s.Union(s1, s2) 244 | assert.Equal(t, newSet.ToSlice(), []string{"1111", "2222", "3333"}) 245 | 246 | // 测试集合s每个元素是否在s1里面, s <= s1 247 | s := From("5678", "9abc") 248 | s2 := From("abcde", "5678", "9abc") 249 | 250 | assert.True(t, s.IsSubset(s2)) 251 | 252 | // 测试集合s1每个元素是否在s里面 s1 <= s 253 | s2 := From("5678", "9abc") 254 | s := From("abcde", "5678", "9abc") 255 | 256 | assert.True(t, s.IsSuperset(s2)) 257 | 258 | // 遍历某个集合 259 | a := []string{"1111", "2222", "3333"} 260 | s := From(a...) 261 | for _, v := range a { 262 | s.Set(v) 263 | } 264 | 265 | s.Range(func(k string) bool { 266 | fmt.Println(k) 267 | return true 268 | }) 269 | 270 | // 测试两个集合是否相等 Equal 271 | s := New[int]() 272 | max := 1000 273 | for i := 0; i < max; i++ { 274 | s.Set(i) 275 | } 276 | 277 | s2 := s.Clone() 278 | 279 | assert.True(t, s.Equal(s2)) 280 | ``` 281 | 282 | ## 十、`ifop` 283 | ifop是弥补下golang没有三目运算符,使用函数模拟 284 | ### 10.1 if else部分类型相同 285 | ```go 286 | // 如果该值不为0, 返回原来的值,否则默认值 287 | val = IfElse(len(val) != 0, val, "default") 288 | ``` 289 | ### 10.2 if else部分类型不同 290 | ```go 291 | o := map[string]any{"hello": "hello"} 292 | a := []any{"hello", "world"} 293 | fmt.Printf("%#v", IfElseAny(o != nil, o, a)) 294 | ``` 295 | ## 十一、`mapex` 296 | 薄薄一层包装,增加标准库map的接口 297 | * mapex.Keys() 298 | ```go 299 | m := make(map[string]string) 300 | m["a"] = "1" 301 | m["b"] = "2" 302 | m["c"] = "3" 303 | get := mapex.Keys(m)// 返回map的所有key 304 | 305 | ``` 306 | * mapex.Values() 307 | ```go 308 | m := make(map[string]string) 309 | m["a"] = "1" 310 | m["b"] = "2" 311 
| m["c"] = "3" 312 | get := mapex.Values(m) 313 | ``` 314 | ## 十二、`rwmap` 315 | rwmap与sync.Map类似支持并发访问,只解决sync.Map 2个问题. 316 | 1. 没有Len成员函数 317 | 2. 以及没有使用泛型语法,有运行才发现类型使用错误的烦恼 318 | ```go 319 | var m rwmap.RWMap[string, string] // 声明一个string, string的map 320 | m.Store("hello", "1") // 保存 321 | v1, ok1 := m.Load("hello") // 获取值 322 | v1, ok1 = m.LoadAndDelete("hello") //返回hello对应值,然后删除hello 323 | Delete("hello") // 删除 324 | v1, ok1 = m.LoadOrStore("hello", "world") 325 | 326 | // 遍历,使用回调函数 327 | m.Range(func(key, val string) bool { 328 | fmt.Printf("k:%s, val:%s\n"i, key, val) 329 | return true 330 | }) 331 | 332 | // 遍历,迭代器 333 | for pair := range m.Iter() { 334 | fmt.Printf("k:%s, val:%s\n", pair.Key, pair.Val) 335 | } 336 | 337 | m.Len()// 获取长度 338 | allKeys := m.Keys() //返回所有的key 339 | allValues := m.Values()// 返回所有的value 340 | ``` 341 | ## 十三、`cmap` 342 | cmap是用锁分区的方式实现的,(TODO优化,目前只有几个指标比sync.Map快) 343 | ```go 344 | var m cmap.CMap[string, string] // 声明一个string, string的map 345 | m.Store("hello", "1") // 保存 346 | v1, ok1 := m.Load("hello") // 获取值 347 | v1, ok1 = m.LoadAndDelete("hello") //返回hello对应值,然后删除hello 348 | Delete("hello") // 删除 349 | v1, ok1 = m.LoadOrStore("hello", "world") 350 | 351 | // 遍历,使用回调函数 352 | m.Range(func(key, val string) bool { 353 | fmt.Printf("k:%s, val:%s\n"i, key, val) 354 | return true 355 | }) 356 | 357 | // 遍历,迭代器 358 | for pair := range m.Iter() { 359 | fmt.Printf("k:%s, val:%s\n", pair.Key, pair.Val) 360 | } 361 | 362 | m.Len()// 获取长度 363 | allKeys := m.Keys() //返回所有的key 364 | allValues := m.Values()// 返回所有的value 365 | -------------------------------------------------------------------------------- /api/api.go: -------------------------------------------------------------------------------- 1 | package api 2 | 3 | import "golang.org/x/exp/constraints" 4 | 5 | type Map[K constraints.Ordered, V any] interface { 6 | // 获取 7 | Get(k K) (elem V) 8 | // 获取 9 | TryGet(k K) (elem V, ok bool) 10 | // 删除 11 | Delete(k K) 12 | // 设置 13 | 
Set(k K, v V) 14 | // 设置值 15 | Swap(k K, v V) (prev V, replaced bool) 16 | // int 17 | Len() int 18 | // 遍历 19 | Range(callback func(k K, v V) bool) 20 | } 21 | 22 | type SortedMap[K constraints.Ordered, V any] interface { 23 | Map[K, V] 24 | TopMin(limit int, callback func(k K, v V) bool) 25 | TopMax(limit int, callback func(k K, v V) bool) 26 | } 27 | 28 | // TODO 29 | type Set[K constraints.Ordered] interface { 30 | Set(k K) 31 | } 32 | 33 | type Trie[V any] interface { 34 | Get(k string) (v V) 35 | Swap(k string, v V) (prev V, replaced bool) 36 | HasPrefix(k string) bool 37 | TryGet(k string) (v V, found bool) 38 | Delete(k string) 39 | Len() int 40 | } 41 | 42 | type CMaper[K comparable, V any] interface { 43 | Delete(key K) 44 | Load(key K) (value V, ok bool) 45 | LoadAndDelete(key K) (value V, loaded bool) 46 | LoadOrStore(key K, value V) (actual V, loaded bool) 47 | Range(f func(key K, value V) bool) 48 | Store(key K, value V) 49 | } 50 | -------------------------------------------------------------------------------- /avltree/avltree.go: -------------------------------------------------------------------------------- 1 | package avltree 2 | 3 | // apache 2.0 antlabs 4 | 5 | // 参考资料 6 | // https://github.com/skywind3000/avlmini 7 | import ( 8 | "fmt" 9 | 10 | "github.com/antlabs/gstl/api" 11 | "github.com/antlabs/gstl/cmp" 12 | "github.com/antlabs/gstl/vec" 13 | "golang.org/x/exp/constraints" 14 | ) 15 | 16 | var _ api.SortedMap[int, int] = (*AvlTree[int, int])(nil) 17 | 18 | // 元素 19 | type pair[K constraints.Ordered, V any] struct { 20 | val V 21 | key K 22 | } 23 | 24 | type node[K constraints.Ordered, V any] struct { 25 | left *node[K, V] 26 | right *node[K, V] 27 | parent *node[K, V] 28 | pair[K, V] 29 | height int 30 | } 31 | 32 | // 返回左子树高度 33 | func (n *node[K, V]) leftHeight() int { 34 | if n.left != nil { 35 | return n.left.height 36 | } 37 | 38 | return 0 39 | } 40 | 41 | // 返回右子树高度 42 | func (n *node[K, V]) rightHeight() int { 43 | if n.right != 
nil { 44 | return n.right.height 45 | } 46 | return 0 47 | } 48 | 49 | func (n *node[K, V]) heightUpdate() { 50 | lh := n.leftHeight() 51 | rh := n.rightHeight() 52 | n.height = cmp.Max(lh, rh) + 1 53 | } 54 | 55 | func (n *node[K, V]) link(parent *node[K, V], link **node[K, V]) { 56 | n.parent = parent 57 | *link = n 58 | } 59 | 60 | type root[K constraints.Ordered, V any] struct { 61 | node *node[K, V] 62 | } 63 | 64 | func (r *root[K, V]) fixLeft(node *node[K, V]) *node[K, V] { 65 | right := node.right 66 | // 右节点, 左子树高度 67 | rlh := right.leftHeight() 68 | // 右节点, 右子树高度 69 | rrh := right.rightHeight() 70 | 71 | if rlh > rrh { 72 | right = r.rotateRight(right) 73 | right.right.heightUpdate() 74 | right.heightUpdate() 75 | } 76 | node = r.rotateLeft(node) 77 | node.left.heightUpdate() 78 | node.heightUpdate() 79 | 80 | return node 81 | } 82 | 83 | func (r *root[K, V]) fixRight(node *node[K, V]) *node[K, V] { 84 | left := node.left 85 | // 右节点, 左子树高度 86 | llh := left.leftHeight() 87 | // 右节点, 右子树高度 88 | lrh := left.rightHeight() 89 | 90 | if llh < lrh { 91 | left = r.rotateLeft(left) 92 | left.left.heightUpdate() 93 | left.heightUpdate() 94 | } 95 | 96 | node = r.rotateRight(node) 97 | if node.right != nil { 98 | node.right.heightUpdate() 99 | } 100 | 101 | node.heightUpdate() 102 | 103 | return node 104 | } 105 | 106 | func (r *root[K, V]) postInsert(node *node[K, V]) { 107 | node.height = 1 108 | 109 | for node = node.parent; node != nil; node = node.parent { 110 | lh := node.leftHeight() 111 | rh := node.rightHeight() 112 | height := cmp.Max(lh, rh) + 1 113 | 114 | diff := lh - rh 115 | if node.height == height { 116 | break 117 | } 118 | node.height = height 119 | 120 | if diff <= -2 { 121 | node = r.fixLeft(node) 122 | } else if diff >= 2 { 123 | node = r.fixRight(node) 124 | } 125 | } 126 | } 127 | 128 | func (r *root[K, V]) childReplace(oldNode, newNode, parent *node[K, V]) { 129 | if parent != nil { 130 | if parent.left == oldNode { 131 | parent.left = 
newNode 132 | } else { 133 | parent.right = newNode 134 | } 135 | } else { 136 | r.node = newNode 137 | } 138 | 139 | } 140 | 141 | // 左旋就是拽住node往左下拉, node.right升为父节点 142 | func (r *root[K, V]) rotateLeft(node *node[K, V]) *node[K, V] { 143 | right := node.right 144 | parent := node.parent 145 | 146 | // node会滑成right的左节点 147 | // 这里安排下node.right的位置, 这里不再指向right, 再向right的左孩子 148 | // right.left 大于node, 小于right, 所以新的位置就是node.right 149 | node.right = right.left 150 | if right.left != nil { 151 | right.left.parent = node 152 | } 153 | 154 | // 把node从父的位置降下来 155 | right.left = node 156 | right.parent = parent 157 | r.childReplace(node, right, parent) 158 | node.parent = right 159 | return right 160 | } 161 | 162 | // 右旋就是拽往node往右下拉, node.left升为父节点 163 | func (r *root[K, V]) rotateRight(node *node[K, V]) *node[K, V] { 164 | left := node.left 165 | parent := node.parent 166 | node.left = left.right 167 | if left.right != nil { 168 | left.right.parent = node 169 | } 170 | 171 | left.right = node 172 | left.parent = parent 173 | r.childReplace(node, left, parent) 174 | node.parent = left 175 | 176 | return node 177 | } 178 | 179 | // avl tree的结构 180 | type AvlTree[K constraints.Ordered, V any] struct { 181 | length int 182 | root root[K, V] 183 | } 184 | 185 | // 如果有值,则这个回调函数会被调用 186 | type InsertOrUpdateCb[V any] func(prev V, new V) V 187 | 188 | // 构造函数 189 | func New[K constraints.Ordered, V any]() *AvlTree[K, V] { 190 | return &AvlTree[K, V]{} 191 | } 192 | 193 | // 第一个节点 194 | func (a *AvlTree[K, V]) First() (v V, ok bool) { 195 | n := a.root.node 196 | if n == nil { 197 | ok = false 198 | return 199 | } 200 | 201 | for n.left != nil { 202 | n = n.left 203 | } 204 | 205 | return n.val, true 206 | } 207 | 208 | // 最后一个节点 209 | func (a *AvlTree[K, V]) Last() (v V, ok bool) { 210 | n := a.root.node 211 | if n == nil { 212 | ok = false 213 | return 214 | } 215 | 216 | for n.right != nil { 217 | n = n.right 218 | } 219 | 220 | return n.val, true 221 | } 222 | 223 | // Get 
224 | func (a *AvlTree[K, V]) Get(k K) (v V) { 225 | v, _ = a.TryGet(k) 226 | return 227 | } 228 | 229 | // TryGet looks up k in the tree; ok reports whether it was found. 230 | func (a *AvlTree[K, V]) TryGet(k K) (v V, ok bool) { 231 | n := a.root.node 232 | for n != nil { 233 | if n.key == k { 234 | return n.val, true 235 | } 236 | 237 | if k > n.key { 238 | n = n.right 239 | } else { 240 | n = n.left 241 | } 242 | } 243 | 244 | return 245 | } 246 | // Set stores v under k, replacing any existing value. 247 | func (a *AvlTree[K, V]) Set(k K, v V) { 248 | _, _ = a.Swap(k, v) 249 | } 250 | 251 | // Swap stores v under k. If k already exists its value is replaced and the previous value is returned with replaced=true; otherwise a new node is inserted and the tree is rebalanced. 252 | func (a *AvlTree[K, V]) Swap(k K, v V) (prev V, replaced bool) { 253 | link := &a.root.node 254 | var parent *node[K, V] 255 | // the new node is allocated only after the search, so replacing an existing key does not allocate 256 | 257 | for *link != nil { 258 | parent = *link 259 | if parent.key == k { 260 | prev = parent.val 261 | parent.val = v 262 | return prev, true 263 | } 264 | 265 | if parent.key < k { 266 | link = &parent.right 267 | } else { 268 | link = &parent.left 269 | } 270 | } 271 | node := &node[K, V]{pair: pair[K, V]{key: k, val: v}} 272 | node.link(parent, link) 273 | a.root.postInsert(node) 274 | a.length++ 275 | return 276 | } 277 | // InsertOrUpdate stores v under k; if k already exists, cb(prev, v) is stored instead. 278 | func (a *AvlTree[K, V]) InsertOrUpdate(k K, v V, cb InsertOrUpdateCb[V]) { 279 | if prev, ok := a.TryGet(k); ok { 280 | v = cb(prev, v) 281 | } 282 | a.Set(k, v) 283 | } 284 | // rebalance walks from node up to the root, refreshing cached heights and rotating any node whose subtree heights differ by 2 or more; it stops early once a node's height is unchanged and it is balanced. 285 | func (r *root[K, V]) rebalance(node *node[K, V]) { 286 | 287 | for ; node != nil; node = node.parent { 288 | lh := node.leftHeight() 289 | rh := node.rightHeight() 290 | height := cmp.Max(lh, rh) + 1 291 | 292 | diff := lh - rh 293 | if node.height != height { 294 | node.height = height 295 | } else if diff >= -1 && diff <= 1 { 296 | break 297 | } 298 | 299 | if diff <= -2 { 300 | node = r.fixLeft(node) 301 | } else if diff >= 2 { 302 | node = r.fixRight(node) 303 | } 304 | } 305 | } 306 | // Delete removes k from the tree; it is a no-op if k is absent. 307 | func (a *AvlTree[K, V]) Delete(k K) { 308 | a.Remove(k) 309 | } 310 | // Remove deletes k from the tree if present, keeps Len() in sync, and returns the tree for chaining. 311 | func (a *AvlTree[K, V]) Remove(k K) *AvlTree[K, V] { 312 | n := a.root.node 313 | for n != nil { 314 | if n.key == k { 315 | 
goto found 316 | } 317 | 318 | if k > n.key { 319 | n = n.right 320 | } else { 321 | n = n.left 322 | } 323 | } 324 | 325 | return a 326 | 327 | found: 328 | var child, parent *node[K, V] 329 | if n.left != nil && n.right != nil { 330 | old := n 331 | n = n.right 332 | for left := n; left != nil; left = left.left { 333 | n = left 334 | } 335 | // n is now old's in-order successor (leftmost node of old.right); it will take old's place once old is unlinked. 336 | // n is leftmost, so n.left is always nil: the only subtree that must be re-attached is n.right. 337 | child = n.right 338 | parent = n.parent 339 | if child != nil { 340 | // child now hangs directly off n's former parent. 341 | child.parent = parent 342 | } 343 | // Unlink n from its current position, splicing child into its slot. 344 | a.root.childReplace(n, child, parent) 345 | // If n was old's direct child, rebalancing must start at n itself, since its former parent is being removed. 346 | if n.parent == old { 347 | parent = n 348 | } 349 | 350 | // Move n into old's position, inheriting old's children, parent and height. 351 | n.left = old.left 352 | n.right = old.right 353 | n.parent = old.parent 354 | n.height = old.height 355 | 356 | a.root.childReplace(old, n, old.parent) 357 | old.left.parent = n 358 | 359 | if old.right != nil { 360 | old.right.parent = n 361 | } 362 | } else { 363 | if n.left == nil { 364 | child = n.right 365 | } else { 366 | child = n.left 367 | } 368 | parent = n.parent 369 | a.root.childReplace(n, child, parent) 370 | if child != nil { 371 | child.parent = parent 372 | } 373 | } 374 | a.length-- 375 | if parent != nil { 376 | a.root.rebalance(parent) 377 | } 378 | return a 379 | } 380 | // rangeInner visits the subtree rooted at n in ascending key order and returns false as soon as callback returns false. 381 | func (n *node[K, V]) rangeInner(callback func(k K, v V) bool) bool { 382 | 383 | if n == nil { 384 | return true 385 | } 386 | 387 | if n.left != nil { 388 | if !n.left.rangeInner(callback) { 389 | return false 390 | } 391 | } 392 | 393 | if !callback(n.key, n.val) { 394 | return false 395 | } 396 | 397 | if n.right != nil { 398 | if !n.right.rangeInner(callback) { 399 | return false 400 | } 401 | } 402 | return true 403 | } 404 | // rangePrevInner visits the subtree rooted at n in descending key order and returns false as soon as callback returns false. 405 | func (n *node[K, V]) rangePrevInner(callback func(k K, v V) bool) bool { 406 | 407 | if n == nil { 408 | return true 409 | } 410 | 411 | if n.right != nil { 412 | if !n.right.rangePrevInner(callback) { 413 | return false 414 | } 415 | } 416 | 417 | if !callback(n.key, n.val) { 418 | return false
} 420 | 421 | if n.left != nil { 422 | if !n.left.rangePrevInner(callback) { 423 | return false 424 | } 425 | } 426 | 427 | return true 428 | } 429 | 430 | // Range visits every element in ascending key order until callback returns false. 431 | func (a *AvlTree[K, V]) Range(callback func(k K, v V) bool) { 432 | // empty tree: nothing to visit 433 | if a.root.node == nil { 434 | return 435 | } 436 | 437 | a.root.node.rangeInner(callback) 438 | return 439 | } 440 | 441 | // RangePrev visits every element in descending key order until callback returns false. 442 | func (a *AvlTree[K, V]) RangePrev(callback func(k K, v V) bool) { 443 | // empty tree: nothing to visit 444 | if a.root.node == nil { 445 | return 446 | } 447 | 448 | a.root.node.rangePrevInner(callback) 449 | return 450 | } 451 | // TopMax calls callback on at most limit elements, largest key first, stopping early if callback returns false. 452 | func (a *AvlTree[K, V]) TopMax(limit int, callback func(k K, v V) bool) { 453 | a.RangePrev(func(k K, v V) bool { 454 | 455 | if limit <= 0 { 456 | return false 457 | } 458 | 459 | if !callback(k, v) { 460 | return false 461 | } 462 | 463 | limit-- 464 | return true 465 | }) 466 | } 467 | // TopMin calls callback on at most limit elements, smallest key first, stopping early if callback returns false. 468 | func (a *AvlTree[K, V]) TopMin(limit int, callback func(k K, v V) bool) { 469 | 470 | a.Range(func(k K, v V) bool { 471 | 472 | if limit <= 0 { 473 | return false 474 | } 475 | 476 | if !callback(k, v) { 477 | return false 478 | } 479 | 480 | limit-- 481 | return true 482 | }) 483 | } 484 | // Len returns the number of elements currently stored in the tree. 485 | func (a *AvlTree[K, V]) Len() int { 486 | return a.length 487 | } 488 | // Draw prints the tree level by level to stdout. 489 | func (a *AvlTree[K, V]) Draw() { 490 | if a.root.node == nil { 491 | return 492 | } 493 | 494 | a.root.node.draw(a.root.node) 495 | } 496 | 497 | // draw prints the subtree rooted at root, one depth level per line, 498 | // using a breadth-first (level-order) traversal. 499 | func (n *node[K, V]) draw(root *node[K, V]) { 500 | if root == nil { 501 | return 502 | } 503 | 504 | q := vec.New(root) 505 | for height := 0; q.Len() > 0; height++ { 506 | tmp := q.ToSlice() 507 | q = vec.New[*node[K, V]]() 508 | 509 | fmt.Printf("height:%d ", height) 510 | for _, node := range tmp { 511 | fmt.Printf("%v ", node.pair) 512 | 513 | if node.left != nil { 514 | 515 | q.Push(node.left) 516 | } 517 | 518 | if node.right != nil { 519 | q.Push(node.right) 520 | } 521 | 522 | } 523 | fmt.Printf("\n") 524 | 525
| } 526 | } 527 | -------------------------------------------------------------------------------- /avltree/avltree_bench_test.go: -------------------------------------------------------------------------------- 1 | package avltree 2 | 3 | import ( 4 | "fmt" 5 | "testing" 6 | ) 7 | 8 | // b.N = 30 million (3kw) 9 | // pkg: github.com/antlabs/gstl/avltree 10 | // BenchmarkGetAsc-8 33178270 41.07 ns/op 11 | // BenchmarkGetDesc-8 33488839 39.91 ns/op 12 | // BenchmarkGetStd-8 29553132 49.34 ns/op 13 | // BenchmarkGetAsc inserts b.N keys in ascending order, then times ascending Get calls. 14 | func BenchmarkGetAsc(b *testing.B) { 15 | //max := 1000000.0 * 5 16 | max := float64(b.N) 17 | set := New[float64, float64]() 18 | for i := 0.0; i < max; i++ { 19 | set.Set(i, i) 20 | } 21 | 22 | b.ResetTimer() 23 | 24 | for i := 0.0; i < max; i++ { 25 | v := set.Get(i) 26 | if v != i { 27 | panic(fmt.Sprintf("need:%f, got:%f", i, v)) 28 | } 29 | } 30 | } 31 | // BenchmarkGetDesc inserts keys in descending order, then times ascending Get calls. 32 | func BenchmarkGetDesc(b *testing.B) { 33 | //max := 1000000.0 * 5 34 | max := float64(b.N) 35 | set := New[float64, float64]() 36 | for i := max; i >= 0; i-- { 37 | set.Set(i, i) 38 | } 39 | 40 | b.ResetTimer() 41 | 42 | for i := 0.0; i < max; i++ { 43 | v := set.Get(i) 44 | if v != i { 45 | panic(fmt.Sprintf("need:%f, got:%f", i, v)) 46 | } 47 | } 48 | } 49 | // BenchmarkGetStd runs the same insert/Get workload against a built-in map for comparison. 50 | func BenchmarkGetStd(b *testing.B) { 51 | 52 | //max := 1000000.0 * 5 53 | max := float64(b.N) 54 | set := make(map[float64]float64, int(max)) 55 | for i := 0.0; i < max; i++ { 56 | set[i] = i 57 | } 58 | 59 | b.ResetTimer() 60 | 61 | for i := 0.0; i < max; i++ { 62 | v := set[i] 63 | if v != i { 64 | panic(fmt.Sprintf("need:%f, got:%f", i, v)) 65 | } 66 | } 67 | } 68 | -------------------------------------------------------------------------------- /avltree/avltree_test.go: -------------------------------------------------------------------------------- 1 | package avltree 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/antlabs/gstl/cmp" 7 | "github.com/antlabs/gstl/vec" 8 | ) 9 | 10 | // Insert keys in ascending order. 11 | func Test_SetAndGet(t *testing.T) { 12 | b := 
New[int, int]() 13 | max := 1000 14 | for i := 0; i < max; i++ { 15 | b.Swap(i, i) 16 | } 17 | 18 | for i := 0; i < max; i++ { 19 | v, ok := b.TryGet(i) 20 | if !ok { 21 | t.Errorf("expected true, got false for index %d", i) 22 | } 23 | if v != i { 24 | t.Errorf("expected %d, got %d for index %d", i, v, i) 25 | } 26 | } 27 | } 28 | 29 | // Insert keys in descending order. 30 | func Test_SetAndGet2(t *testing.T) { 31 | b := New[int, int]() 32 | max := 1000 33 | for i := max; i >= 0; i-- { 34 | b.Swap(i, i) 35 | } 36 | 37 | for i := max; i >= 0; i-- { 38 | v, ok := b.TryGet(i) 39 | if !ok { 40 | t.Errorf("expected true, got false for index %d", i) 41 | } 42 | if v != i { 43 | t.Errorf("expected %d, got %d for index %d", i, v, i) 44 | } 45 | } 46 | } 47 | 48 | // Test AvlTree deletion with small element counts. 49 | func Test_AVLTree_Delete1(t *testing.T) { 50 | for max := 3; max < 1000; max++ { 51 | 52 | b := New[int, int]() 53 | 54 | // set keys 0..max 55 | for i := 0; i < max; i++ { 56 | b.Set(i, i) 57 | } 58 | 59 | // delete keys 0..max/2 60 | for i := 0; i < max/2; i++ { 61 | b.Delete(i) 62 | } 63 | 64 | // keys max/2..max must still be found 65 | for i := max / 2; i < max; i++ { 66 | v, ok := b.TryGet(i) 67 | if !ok { 68 | t.Errorf("expected true, got false for index %d", i) 69 | } 70 | if v != i { 71 | t.Errorf("expected %d, got %d for index %d", i, v, i) 72 | } 73 | } 74 | if b.Len() != max-max/2 { t.Errorf("expected length %d, got %d", max-max/2, b.Len()) } 75 | // keys 0..max/2 must no longer be found 76 | for i := 0; i < max/2; i++ { 77 | v, ok := b.TryGet(i) 78 | if ok { 79 | t.Errorf("expected false, got true for index %d", i) 80 | } 81 | if v != 0 { 82 | t.Errorf("expected 0, got %d for index %d", v, i) 83 | } 84 | } 85 | } 86 | } 87 | 88 | // Test TopMax: it returns the largest elements in descending order. 89 | func Test_AvlTree_TopMax(t *testing.T) { 90 | 91 | need := [3][]int{} 92 | count10 := 10 93 | count100 := 100 94 | count1000 := 1000 95 | count := []int{count10, count100, count1000} 96 | 97 | for i := 0; i < len(count); i++ { 98 | for j, k := count[i]-1, count100-1; j >= 0 && k >= 0; j-- { 99 | need[i] = append(need[i], j) 100 | k-- 101 | } 102 | } 103 | 104 | for i, b := range 
[]*AvlTree[int, int]{ 105 | // the tree holds fewer elements than TopMax is asked to return 106 | func() *AvlTree[int, int] { 107 | b := New[int, int]() 108 | for i := 0; i < count10; i++ { 109 | b.Set(i, i) 110 | } 111 | 112 | b.Draw() 113 | 114 | if b.Len() != count10 { 115 | t.Errorf("expected length %d, got %d", count10, b.Len()) 116 | } 117 | return b 118 | }(), 119 | // the tree holds exactly as many elements as TopMax is asked to return 120 | func() *AvlTree[int, int] { 121 | 122 | b := New[int, int]() 123 | for i := 0; i < count100; i++ { 124 | b.Set(int(i), i) 125 | } 126 | if b.Len() != count100 { 127 | t.Errorf("expected length %d, got %d", count100, b.Len()) 128 | } 129 | return b 130 | }(), 131 | // the tree holds more elements than TopMax is asked to return 132 | func() *AvlTree[int, int] { 133 | 134 | b := New[int, int]() 135 | for i := 0; i < count1000; i++ { 136 | b.Set(int(i), i) 137 | } 138 | if b.Len() != count1000 { 139 | t.Errorf("expected length %d, got %d", count1000, b.Len()) 140 | } 141 | return b 142 | }(), 143 | } { 144 | var key, val []int 145 | b.TopMax(count100, func(k int, v int) bool { 146 | key = append(key, int(k)) 147 | val = append(val, v) 148 | return true 149 | }) 150 | length := cmp.Min(count[i], len(need[i])) 151 | if !equalSlices(key, need[i][:length]) { 152 | t.Errorf("expected keys %v, got %v", need[i][:length], key) 153 | } 154 | if !equalSlices(val, need[i][:length]) { 155 | t.Errorf("expected values %v, got %v", need[i][:length], val) 156 | } 157 | } 158 | } 159 | 160 | // Test TopMin: it returns the smallest few values in ascending order. 161 | func Test_AvlTree_TopMin(t *testing.T) { 162 | 163 | need := []int{} 164 | count10 := 10 165 | count100 := 100 166 | count1000 := 1000 167 | 168 | for i := 0; i < count1000; i++ { 169 | need = append(need, i) 170 | } 171 | 172 | needCount := []int{count10, count100, count100} 173 | for i, b := range []*AvlTree[int, int]{ 174 | // the tree holds fewer elements than TopMin is asked to return 175 | func() *AvlTree[int, int] { 176 | b := New[int, int]() 177 | for i := 0; i < count10; i++ { 178 | b.Set(i, i) 179 | } 180 | 181 | if b.Len() != count10 { 182 | 
t.Errorf("expected length %d, got %d", count10, b.Len()) 183 | } 184 | return b 185 | }(), 186 | // the tree holds exactly as many elements as TopMin is asked to return 187 | func() *AvlTree[int, int] { 188 | 189 | b := New[int, int]() 190 | for i := 0; i < count100; i++ { 191 | b.Set(i, i) 192 | } 193 | if b.Len() != count100 { 194 | t.Errorf("expected length %d, got %d", count100, b.Len()) 195 | } 196 | return b 197 | }(), 198 | // the tree holds more elements than TopMin is asked to return 199 | func() *AvlTree[int, int] { 200 | 201 | b := New[int, int]() 202 | for i := 0; i < count1000; i++ { 203 | b.Set(i, i) 204 | } 205 | if b.Len() != count1000 { 206 | t.Errorf("expected length %d, got %d", count1000, b.Len()) 207 | } 208 | return b 209 | }(), 210 | } { 211 | var key, val []int 212 | b.TopMin(count100, func(k, v int) bool { 213 | key = append(key, k) 214 | val = append(val, v) 215 | return true 216 | }) 217 | if !equalSlices(key, need[:needCount[i]]) { 218 | t.Errorf("expected keys %v, got %v", need[:needCount[i]], key) 219 | } 220 | if !equalSlices(val, need[:needCount[i]]) { 221 | t.Errorf("expected values %v, got %v", need[:needCount[i]], val) 222 | } 223 | } 224 | } 225 | 226 | func Test_RangePrev(t *testing.T) { 227 | a := New[int, int]() 228 | data := []int{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10} 229 | 230 | dataRev := vec.New(data...).Clone().Rev().ToSlice() 231 | for i := len(data) / 2; i >= 0; i-- { 232 | a.Set(i, i) 233 | } 234 | 235 | for i := len(data)/2 + 1; i < len(data); i++ { 236 | a.Set(i, i) 237 | } 238 | 239 | //a.Draw() 240 | 241 | var gotKey []int 242 | var gotVal []int 243 | a.RangePrev(func(k, v int) bool { 244 | gotKey = append(gotKey, k) 245 | gotVal = append(gotVal, v) 246 | 247 | return true 248 | }) 249 | 250 | if !equalSlices(gotKey, dataRev) { 251 | t.Errorf("expected keys %v, got %v", dataRev, gotKey) 252 | } 253 | if !equalSlices(gotVal, dataRev) { 254 | t.Errorf("expected values %v, got %v", dataRev, gotVal) 255 | } 256 | } 257 | 258 | func Test_AvlTree_InsertOrUpdate(t *testing.T) { 259 | b := New[int, 
int]() 260 | max := 100 261 | 262 | // Insert elements 263 | for i := 0; i < max; i++ { 264 | b.InsertOrUpdate(i, i, func(prev, new int) int { 265 | return prev + new 266 | }) 267 | } 268 | 269 | // Update elements 270 | for i := 0; i < max; i++ { 271 | b.InsertOrUpdate(i, i, func(prev, new int) int { 272 | return prev + new 273 | }) 274 | } 275 | 276 | // Verify elements 277 | for i := 0; i < max; i++ { 278 | v, ok := b.TryGet(i) 279 | if !ok || v != i*2 { 280 | t.Errorf("expected %d, got %v", i*2, v) 281 | } 282 | } 283 | } 284 | 285 | func Test_AvlTree_InsertOrUpdate2(t *testing.T) { 286 | b := New[int, int]() 287 | max := 100 288 | 289 | // Insert elements 290 | for i := 0; i < max; i++ { 291 | b.InsertOrUpdate(i, i, func(prev, new int) int { 292 | return prev + new 293 | }) 294 | } 295 | 296 | // Update elements 297 | for i := 0; i < max; i++ { 298 | b.InsertOrUpdate(i, i*2, func(prev, new int) int { 299 | return prev + new 300 | }) 301 | } 302 | 303 | // Verify elements 304 | for i := 0; i < max; i++ { 305 | v, ok := b.TryGet(i) 306 | if !ok || v != i*3 { 307 | t.Errorf("expected %d, got %v", i*3, v) 308 | } 309 | } 310 | } 311 | 312 | // equalSlices reports whether two int slices have the same length and elements. 313 | func equalSlices(a, b []int) bool { 314 | if len(a) != len(b) { 315 | return false 316 | } 317 | for i := range a { 318 | if a[i] != b[i] { 319 | return false 320 | } 321 | } 322 | return true 323 | } 324 | 325 | // Test_AVLTree_DeleteTwoChildren deletes a node whose in-order successor has a right child; 326 | // that subtree must survive the deletion and Len must drop by one. 327 | func Test_AVLTree_DeleteTwoChildren(t *testing.T) { 328 | b := New[int, int]() 329 | // Shape (no rotations needed): 5(2(1, 3), 8(6(_, 7), 9)); deleting 5 moves 6 up and must keep 7. 330 | for _, k := range []int{5, 2, 8, 1, 3, 6, 9, 7} { 331 | b.Set(k, k) 332 | } 333 | 334 | b.Delete(5) 335 | if b.Len() != 7 { 336 | t.Errorf("expected length %d, got %d", 7, b.Len()) 337 | } 338 | if _, ok := b.TryGet(5); ok { 339 | t.Errorf("expected false, got true for index %d", 5) 340 | } 341 | for _, k := range []int{1, 2, 3, 6, 7, 8, 9} { 342 | v, ok := b.TryGet(k) 343 | if !ok || v != k { 344 | t.Errorf("expected %d, got %d (found=%v)", k, v, ok) 345 | } 346 | } 347 | } 348 | -------------------------------------------------------------------------------- /btree/btree.go: -------------------------------------------------------------------------------- 1 | package btree 2 | 3 | // apache 2.0 antlabs 4 | 5 | // References 6 | // https://github.com/tidwall/btree 7 | import ( 8 | "fmt" 9 | 10 | "github.com/antlabs/gstl/api" 11 | "github.com/antlabs/gstl/must" 12 | "github.com/antlabs/gstl/vec" 13 | "golang.org/x/exp/constraints" 14 | ) 15 | 16 | var _ api.SortedMap[int, int] = (*Btree[int, int])(nil) 17 | var notFound = "not found element" 18 | 19 | // Btree is the B-tree header: element count, root node and per-node item limits. 20 | type Btree[K 
constraints.Ordered, V any] struct { 21 | count int // current number of elements 22 | root *node[K, V] // pointer to the root node 23 | maxItems int 24 | minItems int 25 | } 26 | 27 | // pair is a single key/value element. 28 | type pair[K constraints.Ordered, V any] struct { 29 | val V 30 | key K 31 | } 32 | 33 | // node is a B-tree node. 34 | type node[K constraints.Ordered, V any] struct { 35 | items *vec.Vec[pair[K, V]] // elements stored in this node, in key order 36 | children *vec.Vec[*node[K, V]] // child nodes; nil or empty for a leaf 37 | } 38 | // leaf reports whether n has no children. 39 | func (n *node[K, V]) leaf() bool { 40 | return n.children == nil || n.children.Len() == 0 41 | } 42 | // New creates an empty Btree. degree <= 0 selects the default (128); each node holds at most 2*degree-1 items. 43 | func New[K constraints.Ordered, V any](degree int) *Btree[K, V] { 44 | 45 | if degree <= 0 { 46 | degree = 128 // arbitrary default; should be tuned by benchmarking 47 | } 48 | 49 | maxItems := degree*2 - 1 // max items per node. max children is +1 50 | return &Btree[K, V]{ 51 | maxItems: maxItems, 52 | minItems: maxItems / 2, 53 | } 54 | } 55 | 56 | // Len returns the number of elements in the btree. 57 | func (b *Btree[K, V]) Len() int { 58 | return b.count 59 | } 60 | 61 | // Set stores v under k, replacing any existing value. 62 | func (b *Btree[K, V]) Set(k K, v V) { 63 | 64 | _, _ = b.Swap(k, v) 65 | } 66 | 67 | // newNode allocates a node; non-leaf nodes get an empty children vector. 68 | func (b *Btree[K, V]) newNode(leaf bool) (n *node[K, V]) { 69 | n = &node[K, V]{} 70 | if !leaf { 71 | n.children = vec.New[*node[K, V]]() 72 | } 73 | return 74 | } 75 | 76 | // newLeaf allocates a leaf node. 77 | func (b *Btree[K, V]) newLeaf() *node[K, V] { 78 | return b.newNode(true) 79 | } 80 | // find locates key in n.items: found=true with its index, or found=false with the index of the first item greater than key (the child subtree to descend into). 81 | func (b *Btree[K, V]) find(n *node[K, V], key K) (index int, found bool) { 82 | 83 | index = n.items.SearchFunc(func(elem pair[K, V]) bool { return key < elem.key }) 84 | if index > 0 && n.items.Get(index-1).key >= key { 85 | return index - 1, true 86 | } 87 | 88 | return index, false 89 | } 90 | 91 | // nodeSplit splits n around its median item: n keeps the left half, the right half goes to the returned node, and the median is returned for the parent. 92 | func (b *Btree[K, V]) nodeSplit(n *node[K, V]) (right *node[K, V], median pair[K, V]) { 93 | i := b.maxItems / 2 94 | //fmt.Printf("nodeSplit:%#v i(%d):len(%d)\n", n.items, i, n.items.Len()) 95 | median = n.items.Get(i) 96 | 97 | // n itself becomes the new left child 98 | // after SplitOff, n.items holds the left half plus the median 99 | rightItems := n.items.SplitOff(i + 1) 100 | // 
drop the median from n.items 101 | n.items.SetLen(n.items.Len() - 1) 102 | 103 | // if n has children, split them between the two halves as well 104 | right = b.newNode(n.leaf()) 105 | right.items = rightItems 106 | //fmt.Printf("nodeSplit: %p, left:%v, median:%v %p, right:%v\n", n.items, n.items, median, right, rightItems) 107 | if !n.leaf() { 108 | right.children = n.children.SplitOff(i + 1) 109 | } 110 | 111 | return 112 | } 113 | 114 | // nodeSet inserts or replaces item in the subtree rooted at n; needSplit=true means a full node was reached and the caller must split it and retry. 115 | func (b *Btree[K, V]) nodeSet(n *node[K, V], item pair[K, V]) (prev V, replaced bool, needSplit bool) { 116 | i, found := b.find(n, item.key) 117 | // key already present: replace the value in place 118 | if found { 119 | //fmt.Printf("1.## i = %v, item:%v\n", item.key, n.items) 120 | prevPtr := n.items.GetPtr(i) 121 | prev = prevPtr.val 122 | prevPtr.val = item.val 123 | return prev, true, false 124 | } 125 | 126 | // leaf node 127 | if n.leaf() { 128 | // no room for a new element; the parent must split this node 129 | if n.items.Len() == b.maxItems { 130 | needSplit = true 131 | return 132 | } 133 | n.items.Insert(i, item) 134 | return 135 | } 136 | 137 | prev, replaced, needSplit = b.nodeSet(n.children.Get(i), item) 138 | if needSplit { 139 | // no room in n for the child's median either; propagate the split request upward 140 | if n.items.Len() == b.maxItems { 141 | 142 | needSplit = true 143 | return 144 | } 145 | 146 | right, median := b.nodeSplit(n.children.Get(i)) 147 | 148 | n.children.Insert(i+1, right) 149 | n.items.Insert(i, median) 150 | 151 | return b.nodeSet(n, item) 152 | } 153 | 154 | return 155 | } 156 | 157 | // Swap stores v under k. If k already existed, the previous value is returned with replaced=true; otherwise a new element is added. 158 | func (b *Btree[K, V]) Swap(k K, v V) (prev V, replaced bool) { 159 | item := pair[K, V]{key: k, val: v} 160 | // empty tree: create the root leaf holding just this item 161 | if b.root == nil { 162 | b.root = b.newLeaf() 163 | if b.root.items == nil { 164 | b.root.items = vec.New[pair[K, V]]() 165 | } 166 | b.root.items.Push(item) 167 | b.count = 1 168 | return 169 | } 170 | 171 | prev, replaced, needSplit := b.nodeSet(b.root, item) 172 | if needSplit { 173 | left := b.root 174 | right, median := b.nodeSplit(left) 175 | b.root = b.newNode(false) 176 | if 
b.root.children == nil { 177 | b.root.children = vec.WithCapacity[*node[K, V]](b.maxItems + 1) 178 | } 179 | 180 | b.root.children.Push(left, right) 181 | if b.root.items == nil { 182 | b.root.items = vec.New(median) 183 | } else { 184 | b.root.items.Push(median) 185 | } 186 | 187 | // the root has been split, so there is room now; retry the insert 188 | return b.Swap(item.key, item.val) 189 | } 190 | 191 | if replaced { 192 | return prev, true 193 | } 194 | b.count++ 195 | return 196 | } 197 | 198 | // Get returns the value stored under k, or the zero value if k is absent. 199 | func (b *Btree[K, V]) Get(k K) (v V) { 200 | v, _ = b.TryGet(k) 201 | return 202 | } 203 | 204 | // TryGet returns the value stored under k; ok is true when k was found 205 | // and false otherwise. 206 | func (b *Btree[K, V]) TryGet(k K) (v V, ok bool) { 207 | if b.root == nil { 208 | return 209 | } 210 | 211 | n := b.root 212 | for { 213 | i, found := b.find(n, k) 214 | if found { 215 | return n.items.Get(i).val, true 216 | } 217 | 218 | if n.leaf() { 219 | return 220 | } 221 | 222 | n = (*n.children)[i] 223 | } 224 | } 225 | 226 | // Delete removes k from the tree; it is a no-op if k is absent. 227 | func (b *Btree[K, V]) Delete(k K) { 228 | b.DeleteWithPrev(k) 229 | } 230 | 231 | // DeleteWithPrev removes k and returns its previous value; deleted is false if k was absent. 232 | func (b *Btree[K, V]) DeleteWithPrev(k K) (prev V, deleted bool) { 233 | if b.root == nil { 234 | return 235 | } 236 | 237 | prevPair, deleted := b.delete(b.root, false, k) 238 | if !deleted { 239 | return 240 | } 241 | 242 | if b.root.items.Len() == 0 && !b.root.leaf() { 243 | var ok bool 244 | b.root, ok = b.root.children.First() 245 | if !ok { 246 | panic("not found first element") 247 | } 248 | } 249 | 250 | b.count-- 251 | if b.count == 0 { 252 | b.root = nil 253 | } 254 | return prevPair.val, true 255 | } 256 | // delete removes k (or, when max is true, the largest item) from the subtree rooted at n, fixing up any child that falls below minItems. 257 | func (b *Btree[K, V]) delete(n *node[K, V], max bool, k K) (prev pair[K, V], deleted bool) { 258 | 259 | var i int 260 | var found bool 261 | 262 | var emptykv pair[K, V] 263 | if max { 264 | i, found = n.items.Len()-1, true 265 | } else { 266 | i, found = b.find(n, k) 267 | } 268 | 269 | // a leaf that does not contain k: nothing to delete 270 | if n.leaf() && !found { 271 | return emptykv, false 272 | } 273 | 274 | if 
found { 275 | if n.leaf() { 276 | // leaf node: remove the item directly 277 | prev = n.items.Get(i) 278 | n.items.Remove(i) 279 | return prev, true 280 | } 281 | 282 | if max { 283 | i++ 284 | prev, deleted = b.delete(n.children.Get(i), true, emptykv.key) 285 | } else { 286 | prev = n.items.Get(i) 287 | maxItems, _ := b.delete(n.children.Get(i), true, emptykv.key) 288 | deleted = true 289 | n.items.Set(i, maxItems) 290 | } 291 | } else { 292 | prev, deleted = b.delete(n.children.Get(i), max, k) 293 | } 294 | 295 | if !deleted { 296 | return emptykv, false 297 | } 298 | 299 | // the child may have underflowed; merge or borrow to restore minItems 300 | if n.children.Get(i).items.Len() < b.minItems { 301 | b.rebalance(n, i) 302 | } 303 | 304 | return prev, true 305 | } 306 | // rebalance restores the minimum occupancy of n.children[i] by merging it with a sibling or borrowing one item through the parent separator. 307 | func (b *Btree[K, V]) rebalance(n *node[K, V], i int) { 308 | if i == n.items.Len() { 309 | i-- 310 | } 311 | 312 | left, right := n.children.Get(i), n.children.Get(i+1) 313 | 314 | // combined size of left and right is below maxItems 315 | if left.items.Len()+right.items.Len() < b.maxItems { 316 | // merge into left: left = left + parent separator + right 317 | // pull down the parent separator 318 | left.items.Push(n.items.Get(i)) 319 | // append the right sibling's items 320 | left.items.Append(right.items) 321 | 322 | if !left.leaf() { 323 | // append the right sibling's children 324 | left.children.Append(right.children) 325 | } 326 | 327 | // remove the separator from the parent 328 | n.items.Remove(i) 329 | // remove the now-merged right child 330 | n.children.Remove(i + 1) 331 | } else if left.items.Len() > right.items.Len() { 332 | // rotate right: borrow from the left sibling 333 | // the parent separator moves down to the front of right 334 | right.items.Insert(0, n.items.Get(i)) 335 | // left's last item becomes the new separator 336 | last, ok := left.items.Pop() 337 | if !ok { 338 | panic(notFound) 339 | } 340 | 341 | // last was borrowed from the end of the left sibling 342 | n.items.Set(i, last) 343 | 344 | if !left.leaf() { 345 | 346 | last, ok := left.children.Pop() 347 | if !ok { 348 | panic(notFound) 349 | } 350 | right.children.Insert(0, last) 351 | } 352 | } else { 353 | 354 | // rotate left: borrow from the right sibling 355 | // left first takes the parent separator 356 | left.items.Push(n.items.Get(i)) 357 | // right's first item becomes the new separator 358 | first, ok := right.items.PopFront() 359 | if !ok { 360 | panic(notFound) 361 | } 362 | 363 | // first was borrowed from the front of the right sibling 364 | n.items.Set(i, first) 365 |
366 | if !left.leaf() { 367 | first, ok := right.children.PopFront() 368 | if !ok { 369 | panic(notFound) 370 | } 371 | 372 | left.children.Push(first) 373 | } 374 | } 375 | } 376 | 377 | // Range visits every element in ascending key order until callback returns false. 378 | func (b *Btree[K, V]) Range(callback func(k K, v V) bool) { 379 | // empty tree: nothing to visit 380 | if b.root == nil { 381 | return 382 | } 383 | 384 | b.root.rangeInner(callback) 385 | return 386 | } 387 | 388 | // TopMin calls callback on at most limit elements in ascending key order (e.g. 0,1,2,3), stopping early if callback returns false. 389 | func (b *Btree[K, V]) TopMin(limit int, callback func(k K, v V) bool) { 390 | b.Range(func(k K, v V) bool { 391 | if limit <= 0 { 392 | return false 393 | } 394 | if !callback(k, v) { return false } 395 | limit-- 396 | return true 397 | }) 398 | } 399 | // Draw prints the tree level by level to stdout. 400 | func (b *Btree[K, V]) Draw() { 401 | if b.root == nil { 402 | return 403 | } 404 | 405 | b.root.draw(b.root) 406 | //b.root.draw(0, b.root.items.Len() == b.root.children.Len()) 407 | } 408 | 409 | // draw prints the tree rooted at root, one depth level per line, 410 | // using a breadth-first (level-order) traversal. 411 | func (n *node[K, V]) draw(root *node[K, V]) { 412 | if root == nil { 413 | return 414 | } 415 | 416 | q := vec.New(root) 417 | for height := 0; q.Len() > 0; height++ { 418 | tmp := q.ToSlice() 419 | q = vec.New[*node[K, V]]() 420 | 421 | fmt.Printf("height:%d ", height) 422 | for _, node := range tmp { 423 | fmt.Printf("%v ", node.items) 424 | 425 | if node.children != nil { 426 | children := node.children.ToSlice() 427 | 428 | for _, nodeChild := range children { 429 | 430 | q.Push(nodeChild) 431 | } 432 | } 433 | 434 | } 435 | fmt.Printf("\n") 436 | 437 | } 438 | } 439 | 440 | // rangeInner visits the subtree rooted at n in ascending key order and returns false as soon as callback returns false. 441 | func (n *node[K, V]) rangeInner(callback func(k K, v V) bool) bool { 442 | 443 | // leaf node 444 | if n.leaf() { 445 | // just walk n.items in order 446 | for i, l := 0, n.items.Len(); i < l; i++ { 447 | item := n.items.Get(i) 448 | if !callback(item.key, item.val) { 449 | return false 450 | } 451 | } 452 | 453 | return true 454 | } 455 | 456 | for i, l := 0, n.items.Len(); i < l; i++ { 457 | if !n.children.Get(i).rangeInner(callback) { 458 | return false 459 | } 460 | 461 | item :=
n.items.Get(i) 462 | if !callback(item.key, item.val) { 463 | return false 464 | } 465 | } 466 | 467 | // n.children has one more entry than n.items; the last child must not be skipped 468 | return must.TakeOneDiscardBool(n.children.Last()).rangeInner(callback) 469 | } 470 | 471 | // RangePrev visits every element in descending key order until callback returns false, and returns b. 472 | func (b *Btree[K, V]) RangePrev(callback func(k K, v V) bool) *Btree[K, V] { 473 | // empty tree: nothing to visit 474 | if b.root == nil { 475 | return b 476 | } 477 | 478 | b.root.rangePrevInner(callback) 479 | return b 480 | } 481 | 482 | // TopMax calls callback on at most limit elements in descending key order (e.g. 10, 9, 8, 7), stopping early if callback returns false. 483 | func (b *Btree[K, V]) TopMax(limit int, callback func(k K, v V) bool) { 484 | b.RangePrev(func(k K, v V) bool { 485 | if limit <= 0 { 486 | return false 487 | } 488 | if !callback(k, v) { return false } 489 | limit-- 490 | return true 491 | }) 492 | } 493 | 494 | // rangePrevInner visits the subtree rooted at n in descending key order and returns false as soon as callback returns false. TODO: benchmark whether hoisting the leaf check out of the loop (two separate loops) is faster. 495 | func (n *node[K, V]) rangePrevInner(callback func(k K, v V) bool) bool { 496 | 497 | // right-most child first 498 | if !n.leaf() { 499 | if !must.TakeOneDiscardBool(n.children.Last()).rangePrevInner(callback) { 500 | return false 501 | } 502 | } 503 | 504 | for i := n.items.Len() - 1; i >= 0; i-- { 505 | 506 | // then the item itself 507 | item := n.items.Get(i) 508 | if !callback(item.key, item.val) { 509 | return false 510 | } 511 | // then the child to its left 512 | if !n.leaf() { 513 | if !n.children.Get(i).rangePrevInner(callback) { 514 | return false 515 | } 516 | } 517 | } 518 | 519 | return true 520 | } 521 | -------------------------------------------------------------------------------- /btree/btree_bench_test.go: -------------------------------------------------------------------------------- 1 | package btree 2 | 3 | // apache 2.0 antlabs 4 | import ( 5 | "fmt" 6 | "testing" 7 | ) 8 | 9 | // goos: darwin 10 | // goarch: amd64 11 | // pkg: github.com/antlabs/gstl/btree 12 | // cpu: Intel(R) Core(TM) i7-1068NG7 CPU @ 2.30GHz 13 | // BenchmarkGet-8 1000000000 0.5326 ns/op 14 | // PASS 15 | // ok github.com/antlabs/gstl/btree 25.315s 16 | // Get timing with 5 million elements 17 | 18 | // goos: darwin 19 | // goarch: arm64 20 |
// pkg: github.com/antlabs/gstl/btree 21 | // BenchmarkGetAsc-8 17242494 79.54 ns/op 22 | // BenchmarkGetDesc-8 17556082 78.17 ns/op 23 | // BenchmarkGetStd-8 29304117 50.49 ns/op 24 | // PASS 25 | // ok github.com/antlabs/gstl/btree 10.503s 26 | func BenchmarkGetAsc(b *testing.B) { 27 | //max := 1000000.0 * 5 28 | set := New[float64, float64](0) 29 | max := float64(b.N) 30 | for i := 0.0; i < max; i++ { 31 | set.Set(i, i) 32 | } 33 | 34 | b.ResetTimer() 35 | 36 | for i := 0.0; i < max; i++ { 37 | v := set.Get(i) 38 | if v != i { 39 | panic(fmt.Sprintf("need:%f, got:%f", i, v)) 40 | } 41 | } 42 | } 43 | // BenchmarkGetDesc inserts keys in descending order, then times ascending Get calls. 44 | func BenchmarkGetDesc(b *testing.B) { 45 | max := float64(b.N) 46 | //max := 1000000.0 * 5 47 | set := New[float64, float64](0) 48 | for i := max; i >= 0; i-- { 49 | set.Set(i, i) 50 | } 51 | 52 | b.ResetTimer() 53 | 54 | for i := 0.0; i < max; i++ { 55 | v := set.Get(i) 56 | if v != i { 57 | panic(fmt.Sprintf("need:%f, got:%f", i, v)) 58 | } 59 | } 60 | } 61 | // BenchmarkGetStd runs the same insert/Get workload against a built-in map for comparison. 62 | func BenchmarkGetStd(b *testing.B) { 63 | 64 | max := float64(b.N) 65 | //max := 1000000.0 * 5 66 | set := make(map[float64]float64) 67 | //set := make(map[float64]float64, int(max)) 68 | for i := 0.0; i < max; i++ { 69 | set[i] = i 70 | } 71 | 72 | b.ResetTimer() 73 | 74 | for i := 0.0; i < max; i++ { 75 | v := set[i] 76 | if v != i { 77 | panic(fmt.Sprintf("need:%f, got:%f", i, v)) 78 | } 79 | } 80 | } 81 | -------------------------------------------------------------------------------- /btree/btree_test.go: -------------------------------------------------------------------------------- 1 | package btree 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/antlabs/gstl/cmp" 7 | ) 8 | 9 | // Test get/set 10 | // without node splits (default degree, only 10 elements) 11 | func Test_Btree_SetAndGet(t *testing.T) { 12 | b := New[int, int](0) 13 | 14 | max := 10 15 | for i := 0; i < max; i++ { 16 | b.Set(i, i) 17 | } 18 | 19 | for i := 0; i < max; i++ { 20 | v, ok := b.TryGet(i) 21 | if !ok { 22 | t.Errorf("Expected true, got false for key %d", i) 23 | }
24 | if v != i { 25 | t.Errorf("Expected %d, got %d for key %d", i, v, i) 26 | } 27 | } 28 | } 29 | 30 | // Test get/set 31 | // with node splits (degree 2) 32 | func Test_Btree_SetAndGet_Split(t *testing.T) { 33 | b := New[int, int](2) 34 | 35 | max := 10 36 | for i := 0; i < max; i++ { 37 | b.Set(i, i) 38 | } 39 | 40 | for i := 0; i < max; i++ { 41 | v, ok := b.TryGet(i) 42 | if !ok { 43 | t.Errorf("Expected true, got false for key %d", i) 44 | } 45 | if v != i { 46 | t.Errorf("Expected %d, got %d for key %d", i, v, i) 47 | } 48 | } 49 | } 50 | 51 | // Test get/set with a large number of elements 52 | // with node splits. NOTE(review): degree=max allows 2*max-1 items per node, so this likely never splits; confirm intent. 53 | func Test_Btree_SetAndGet_Split_Big(t *testing.T) { 54 | max := 10000 55 | b := New[int, int](max) 56 | 57 | for i := 0; i < max; i++ { 58 | b.Set(i, i) 59 | } 60 | 61 | for i := 0; i < max; i++ { 62 | v, ok := b.TryGet(i) 63 | if !ok { 64 | t.Errorf("Expected true, got false for key %d", i) 65 | } 66 | if v != i { 67 | t.Errorf("Expected %d, got %d for key %d", i, v, i) 68 | } 69 | } 70 | } 71 | 72 | // Test get/set value replacement with a small number of elements 73 | func Test_Btree_SetAndGet_Replace(t *testing.T) { 74 | max := 10 75 | b := New[int, int](max) 76 | 77 | for i := 0; i < max; i++ { 78 | b.Set(i, i) 79 | } 80 | 81 | for i := 0; i < max; i++ { 82 | prev, replace := b.Swap(i, i+1) 83 | if !replace { 84 | t.Errorf("Expected true, got false for key %d", i) 85 | } 86 | if prev != i { 87 | t.Errorf("Expected %d, got %d for key %d", i, prev, i) 88 | } 89 | } 90 | 91 | for i := 0; i < max; i++ { 92 | v, ok := b.TryGet(i) 93 | if !ok { 94 | t.Errorf("Expected true, got false for key %d", i) 95 | } 96 | if v != i+1 { 97 | t.Errorf("Expected %d, got %d for key %d", i+1, v, i) 98 | } 99 | } 100 | } 101 | 102 | // Test Range with a small number of elements 103 | func Test_Btree_Range(t *testing.T) { 104 | b := New[int, int](2) 105 | max := 100 106 | key := make([]int, 0, max) 107 | val := make([]int, 0, max) 108 | need := make([]int, 0, max) 109 | for i := max - 1; i >= 0; i-- { 110 | b.Set(i, i) 111 | } 112 | 113 | for i := 0; i < max; i++ { 114 | need =
append(need, i) 115 | } 116 | // NOTE(review): the callback below appends k (not v) to val, so the val check never exercises values; it only passes because keys equal values here. 117 | b.Range(func(k, v int) bool { 118 | key = append(key, k) 119 | val = append(val, k) 120 | return true 121 | }) 122 | 123 | if !slicesEqual(key, need) { 124 | t.Errorf("Expected %v, got %v", need, key) 125 | } 126 | if !slicesEqual(val, need) { 127 | t.Errorf("Expected %v, got %v", need, val) 128 | } 129 | } 130 | 131 | // Test TopMin: it returns the smallest few values in ascending order. 132 | func Test_Btree_TopMin(t *testing.T) { 133 | need := []int{} 134 | count10 := 10 135 | count100 := 100 136 | count1000 := 1000 137 | 138 | for i := 0; i < count1000; i++ { 139 | need = append(need, i) 140 | } 141 | 142 | needCount := []int{count10, count100, count100} 143 | for i, b := range []*Btree[int, int]{ 144 | // the tree holds fewer elements than TopMin is asked to return 145 | func() *Btree[int, int] { 146 | b := New[int, int](2) 147 | for i := 0; i < count10; i++ { 148 | b.Set(i, i) 149 | } 150 | return b 151 | }(), 152 | // the tree holds exactly as many elements as TopMin is asked to return 153 | func() *Btree[int, int] { 154 | b := New[int, int](2) 155 | for i := 0; i < count100; i++ { 156 | b.Set(i, i) 157 | } 158 | return b 159 | }(), 160 | // the tree holds more elements than TopMin is asked to return 161 | func() *Btree[int, int] { 162 | b := New[int, int](2) 163 | for i := 0; i < count1000; i++ { 164 | b.Set(i, i) 165 | } 166 | return b 167 | }(), 168 | } { 169 | var key, val []int 170 | b.TopMin(count100, func(k, v int) bool { 171 | key = append(key, k) 172 | val = append(val, v) 173 | return true 174 | }) 175 | if !slicesEqual(key, need[:needCount[i]]) { 176 | t.Errorf("Expected %v, got %v", need[:needCount[i]], key) 177 | } 178 | if !slicesEqual(val, need[:needCount[i]]) { 179 | t.Errorf("Expected %v, got %v", need[:needCount[i]], val) 180 | } 181 | } 182 | } 183 | 184 | // Test descending-order traversal (RangePrev) 185 | func Test_Btree_RangePrev(t *testing.T) { 186 | b := New[int, int](2) 187 | max := 1000 188 | key := make([]int, 0, max) 189 | val := make([]int, 0, max) 190 | need := make([]int, 0, max) 191 | for i := 0; i < max; i++ { 192 | b.Set(i, i) 193 | } 194 | 195 | for i := max - 1; i >= 0; i-- {
196 | need = append(need, i) 197 | } 198 | 199 | b.RangePrev(func(k, v int) bool { 200 | key = append(key, k) 201 | val = append(val, k) 202 | return true 203 | }) 204 | 205 | if !slicesEqual(key, need) { 206 | t.Errorf("Expected %v, got %v", need, key) 207 | } 208 | } 209 | 210 | func Test_Btree_RangePrev2(t *testing.T) { 211 | b := New[int, int](2) 212 | max := 1000 213 | key := make([]int, 0, max) 214 | val := make([]int, 0, max) 215 | need := make([]int, 0, max) 216 | for i := max - 1; i >= 0; i-- { 217 | b.Set(i, i) 218 | } 219 | 220 | for i := max - 1; i >= 0; i-- { 221 | need = append(need, i) 222 | } 223 | 224 | b.RangePrev(func(k, v int) bool { 225 | key = append(key, k) 226 | val = append(val, k) 227 | return true 228 | }) 229 | 230 | if !slicesEqual(key, need) { 231 | t.Errorf("Expected %v, got %v", need, key) 232 | } 233 | } 234 | 235 | // Test the internal find helper 236 | func Test_Btree_Find(t *testing.T) { 237 | b := New[int, int](2) 238 | b.Set(0, 0) 239 | b.Set(1, 1) 240 | b.Set(2, 2) 241 | 242 | index, _ := b.find(b.root, 2) 243 | if index != 2 { 244 | t.Errorf("Expected 2, got %d", index) 245 | } 246 | 247 | index, _ = b.find(b.root, 4) 248 | if index != 3 { 249 | t.Errorf("Expected 3, got %d", index) 250 | } 251 | } 252 | 253 | // Test TopMax: it returns the largest few elements in descending order. 254 | func Test_Btree_TopMax(t *testing.T) { 255 | need := [3][]int{} 256 | count10 := 10 257 | count100 := 100 258 | count1000 := 1000 259 | count := []int{count10, count100, count1000} 260 | 261 | for i := 0; i < len(count); i++ { 262 | for j, k := count[i]-1, count100-1; j >= 0 && k >= 0; j-- { 263 | need[i] = append(need[i], j) 264 | k-- 265 | } 266 | } 267 | 268 | for i, b := range []*Btree[int, int]{ 269 | // the tree holds fewer elements than TopMax is asked to return 270 | func() *Btree[int, int] { 271 | b := New[int, int](2) 272 | for i := 0; i < count10; i++ { 273 | b.Set(i, i) 274 | } 275 | return b 276 | }(), 277 | // the tree holds exactly as many elements as TopMax is asked to return 278 | func() *Btree[int, int] { 279 | b := New[int, int](2) 280 | for i := 0; i < count100;
i++ { 281 | b.Set(i, i) 282 | } 283 | return b 284 | }(), 285 | // the tree holds more elements than TopMax is asked to return 286 | func() *Btree[int, int] { 287 | b := New[int, int](2) 288 | for i := 0; i < count1000; i++ { 289 | b.Set(i, i) 290 | } 291 | return b 292 | }(), 293 | } { 294 | var key, val []int 295 | b.TopMax(count100, func(k, v int) bool { 296 | key = append(key, k) 297 | val = append(val, v) 298 | return true 299 | }) 300 | length := cmp.Min(count[i], len(need[i])) 301 | if !slicesEqual(key, need[i][:length]) { 302 | t.Errorf("Expected %v, got %v", need[i][:length], key) 303 | } 304 | if !slicesEqual(val, need[i][:length]) { 305 | t.Errorf("Expected %v, got %v", need[i][:length], val) 306 | } 307 | } 308 | } 309 | 310 | // Test btree deletion with small element counts 311 | func Test_Btree_Delete1(t *testing.T) { 312 | for max := 3; max < 1000; max++ { 313 | b := New[int, int](64) 314 | 315 | // set keys 0..max 316 | for i := 0; i < max; i++ { 317 | b.Set(i, i) 318 | } 319 | 320 | // delete keys 0..max/2 321 | for i := 0; i < max/2; i++ { 322 | b.Delete(i) 323 | } 324 | 325 | // keys max/2..max must still be found 326 | for i := max / 2; i < max; i++ { 327 | v, ok := b.TryGet(i) 328 | if !ok { 329 | t.Errorf("Expected true, got false for key %d", i) 330 | } 331 | if v != i { 332 | t.Errorf("Expected %d, got %d for key %d", i, v, i) 333 | } 334 | } 335 | 336 | // keys 0..max/2 must no longer be found 337 | for i := 0; i < max/2; i++ { 338 | v, ok := b.TryGet(i) 339 | if ok { 340 | t.Errorf("Expected false, got true for key %d", i) 341 | } 342 | if v != 0 { 343 | t.Errorf("Expected 0, got %d for key %d", v, i) 344 | } 345 | } 346 | } 347 | } 348 | 349 | // Test Draw 350 | func Test_Btree_Draw(t *testing.T) { 351 | b := New[int, int](2) 352 | for i := 0; i < 10; i++ { 353 | b.Set(i, i) 354 | } 355 | 356 | b.Draw() 357 | } 358 | 359 | func Test_Btree_Delete2(t *testing.T) { 360 | b := New[int, int](2) 361 | 362 | for max := 0; max <= 500; max++ { 363 | for i := 0; i < max; i++ { 364 | b.Set(i, i) 365 | } 366 | 367 | start := max / 2 368 | // delete the upper half 369 | for i := start; i <
max; i++ { 370 | prev, ok := b.DeleteWithPrev(i) 371 | if !ok { 372 | t.Errorf("Expected true, got false for key %d", i) 373 | } 374 | if prev != i { 375 | t.Errorf("Expected %d, got %d for key %d", i, prev, i) 376 | } 377 | } 378 | 379 | // the upper half must no longer be found 380 | for i := start; i < max; i++ { 381 | v, ok := b.TryGet(i) 382 | if ok { 383 | t.Errorf("Expected false, got true for key %d", i) 384 | } 385 | if v != 0 { 386 | t.Errorf("Expected 0, got %d for key %d", v, i) 387 | } 388 | } 389 | 390 | // the lower half must still be found 391 | for i := 0; i < start; i++ { 392 | v, ok := b.TryGet(i) 393 | if !ok { 394 | t.Errorf("Expected true, got false for key %d", i) 395 | } 396 | if v != i { 397 | t.Errorf("Expected %d, got %d for key %d", i, v, i) 398 | } 399 | } 400 | } 401 | } 402 | 403 | // Helper function to compare slices 404 | func slicesEqual[T comparable](a, b []T) bool { 405 | if len(a) != len(b) { 406 | return false 407 | } 408 | for i := range a { 409 | if a[i] != b[i] { 410 | return false 411 | } 412 | } 413 | return true 414 | } 415 | -------------------------------------------------------------------------------- /cmap/cmap.go: -------------------------------------------------------------------------------- 1 | package cmap 2 | 3 | import ( 4 | "reflect" 5 | "runtime" 6 | "sync" 7 | "unsafe" 8 | 9 | "github.com/antlabs/gstl/api" 10 | xxhash "github.com/cespare/xxhash/v2" 11 | "golang.org/x/exp/constraints" 12 | ) 13 | 14 | var _ api.CMaper[int, int] = (*CMap[int, int])(nil) 15 | // Pair is an exported key/value pair. 16 | type Pair[K constraints.Ordered, V any] struct { 17 | Key K 18 | Val V 19 | } 20 | // CMap holds a slice of buckets (Item), each with its own RWMutex and inner map. 21 | type CMap[K constraints.Ordered, V any] struct { 22 | bucket []Item[K, V] 23 | keySize int 24 | isKeyStr bool 25 | } 26 | // Item is one bucket: an inner map m guarded by rw. 27 | type Item[K constraints.Ordered, V any] struct { 28 | rw sync.RWMutex 29 | m api.Map[K, V] 30 | } 31 | // New creates a CMap and initialises it via init(0). 32 | func New[K constraints.Ordered, V any]() (c *CMap[K, V]) { 33 | c = &CMap[K, V]{} 34 | c.init(0) 35 | return c 36 | } 37 | 38 | func (c *CMap[K, V]) init(n int) { 39 | np :=
runtime.GOMAXPROCS(0) 40 | if np <= 0 { 41 | np = 8 42 | } 43 | 44 | if n > 0 { 45 | np = n 46 | } 47 | 48 | c.bucket = make([]Item[K, V], np) 49 | 50 | for i := range c.bucket { 51 | c.bucket[i].m = newStdMap[K, V]() 52 | } 53 | 54 | } 55 | 56 | // 计算hash值 57 | func (c *CMap[K, V]) calHash(k K) uint64 { 58 | var key string 59 | 60 | if c.isKeyStr { 61 | // 直接赋值会报错, 使用unsafe绕过编译器检查 62 | key = *(*string)(unsafe.Pointer(&k)) 63 | } else { 64 | // 因为xxhash.Sum64String 接收string, 所以要把非string类型变量当成string类型来处理 65 | key = *(*string)(unsafe.Pointer(&reflect.StringHeader{ 66 | Data: uintptr(unsafe.Pointer(&k)), 67 | Len: c.keySize, 68 | })) 69 | } 70 | 71 | return xxhash.Sum64String(key) 72 | } 73 | 74 | // 保存key的类型和key的长度 75 | func (h *CMap[K, V]) keyTypeAndKeySize() { 76 | var k K 77 | switch (interface{})(k).(type) { 78 | case string: 79 | h.isKeyStr = true 80 | default: 81 | h.keySize = int(unsafe.Sizeof(k)) 82 | } 83 | } 84 | 85 | // 找到索引 86 | func (c *CMap[K, V]) findIndex(key K) *Item[K, V] { 87 | index := c.calHash(key) % uint64(len(c.bucket)) 88 | return &c.bucket[index] 89 | } 90 | 91 | // 删除 92 | func (c *CMap[K, V]) Delete(key K) { 93 | item := c.findIndex(key) 94 | item.rw.Lock() 95 | item.m.Delete(key) 96 | item.rw.Unlock() 97 | } 98 | 99 | type UpdataOrInsertCb[K constraints.Ordered, V any] func(exist bool, old V) (newVal V) 100 | 101 | // 删除或者更新 102 | func (c *CMap[K, V]) UpdateOrInsert(k K, cb UpdataOrInsertCb[K, V]) { 103 | item := c.findIndex(k) 104 | item.rw.Lock() 105 | old, ok := item.m.TryGet(k) 106 | newVal := cb(ok, old) 107 | item.m.Set(k, newVal) 108 | item.rw.Unlock() 109 | 110 | } 111 | 112 | func (c *CMap[K, V]) Load(key K) (value V, ok bool) { 113 | item := c.findIndex(key) 114 | item.rw.RLock() 115 | value, ok = item.m.TryGet(key) 116 | item.rw.RUnlock() 117 | return 118 | } 119 | 120 | func (c *CMap[K, V]) LoadAndDelete(key K) (value V, loaded bool) { 121 | item := c.findIndex(key) 122 | item.rw.Lock() 123 | value, loaded = item.m.TryGet(key) 
124 | if !loaded { 125 | item.rw.Unlock() 126 | return 127 | } 128 | item.m.Delete(key) 129 | item.rw.Unlock() 130 | return 131 | } 132 | 133 | func (c *CMap[K, V]) LoadOrStore(key K, value V) (actual V, loaded bool) { 134 | item := c.findIndex(key) 135 | item.rw.Lock() 136 | actual, loaded = item.m.TryGet(key) 137 | if !loaded { 138 | actual = value 139 | item.m.Set(key, actual) 140 | item.rw.Unlock() 141 | return 142 | } 143 | 144 | actual, loaded = item.m.TryGet(key) 145 | item.rw.Unlock() 146 | return 147 | } 148 | 149 | func (c *CMap[K, V]) Range(f func(key K, value V) bool) { 150 | for i := 0; i < len(c.bucket); i++ { 151 | item := &c.bucket[i] 152 | item.rw.RLock() 153 | item.m.Range(f) 154 | item.rw.RUnlock() 155 | } 156 | } 157 | 158 | func (c *CMap[K, V]) Iter() (rv chan Pair[K, V]) { 159 | 160 | rv = make(chan Pair[K, V]) 161 | var wg sync.WaitGroup 162 | 163 | wg.Add(len(c.bucket)) 164 | 165 | go func() { 166 | wg.Wait() 167 | close(rv) 168 | }() 169 | 170 | for i := 0; i < len(c.bucket); i++ { 171 | item := &c.bucket[i] 172 | go func(item *Item[K, V]) { 173 | 174 | defer wg.Done() 175 | item.rw.RLock() 176 | item.m.Range(func(key K, value V) bool { 177 | rv <- Pair[K, V]{Key: key, Val: value} 178 | return true 179 | }) 180 | item.rw.RUnlock() 181 | 182 | }(item) 183 | } 184 | return rv 185 | 186 | } 187 | 188 | func (c *CMap[K, V]) Store(key K, value V) { 189 | item := c.findIndex(key) 190 | item.rw.Lock() 191 | item.m.Set(key, value) 192 | item.rw.Unlock() 193 | return 194 | } 195 | 196 | // TODO 优化 197 | func (c *CMap[K, V]) Keys() []K { 198 | l := c.Len() 199 | all := make([]K, 0, l) 200 | if l == 0 { 201 | return nil 202 | } 203 | 204 | for i := 0; i < len(c.bucket); i++ { 205 | 206 | item := &c.bucket[i] 207 | item.rw.RLock() 208 | item.m.Range(func(key K, value V) bool { 209 | all = append(all, key) 210 | return true 211 | }) 212 | item.rw.RUnlock() 213 | } 214 | return all 215 | } 216 | 217 | func (c *CMap[K, V]) Values() []V { 218 | l := 
c.Len() 219 | all := make([]V, 0, l) 220 | if l == 0 { 221 | return nil 222 | } 223 | 224 | for i := 0; i < len(c.bucket); i++ { 225 | 226 | item := &c.bucket[i] 227 | item.rw.RLock() 228 | item.m.Range(func(key K, value V) bool { 229 | all = append(all, value) 230 | return true 231 | }) 232 | item.rw.RUnlock() 233 | } 234 | return all 235 | } 236 | 237 | func (c *CMap[K, V]) Len() int { 238 | l := 0 239 | for i := 0; i < len(c.bucket); i++ { 240 | item := &c.bucket[i] 241 | item.rw.RLock() 242 | l += item.m.Len() 243 | item.rw.RUnlock() 244 | } 245 | return l 246 | } 247 | -------------------------------------------------------------------------------- /cmap/cmap_bench_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2016 The Go Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | // guonaihong: 修改如下 5 | // 1. interface的地方换成泛型语法 6 | 7 | package cmap 8 | 9 | import ( 10 | "fmt" 11 | "reflect" 12 | "sync" 13 | "sync/atomic" 14 | "testing" 15 | "unsafe" 16 | 17 | "github.com/antlabs/gstl/api" 18 | xxhash "github.com/cespare/xxhash/v2" 19 | ) 20 | 21 | type syncmap[K comparable, V any] struct { 22 | m sync.Map 23 | } 24 | 25 | func (c *syncmap[K, V]) Delete(key K) { 26 | c.m.Delete(key) 27 | } 28 | 29 | func (c *syncmap[K, V]) Load(key K) (value V, ok bool) { 30 | v, ok := c.m.Load(key) 31 | if !ok { 32 | return 33 | } 34 | 35 | return v.(V), ok 36 | } 37 | 38 | func (c *syncmap[K, V]) LoadAndDelete(key K) (value V, loaded bool) { 39 | v, ok := c.m.LoadAndDelete(key) 40 | if !ok { 41 | return 42 | } 43 | 44 | return v.(V), ok 45 | } 46 | 47 | func (c *syncmap[K, V]) LoadOrStore(key K, value V) (actual V, loaded bool) { 48 | v, ok := c.m.LoadOrStore(key, value) 49 | if !ok { 50 | return 51 | } 52 | return v.(V), ok 53 | } 54 | 55 | func (c *syncmap[K, V]) Range(f func(key K, value V) bool) { 56 | c.m.Range(func(key any, 
value any) bool { 57 | return f(key.(K), value.(V)) 58 | }) 59 | } 60 | 61 | func (c *syncmap[K, V]) Store(key K, value V) { 62 | c.m.Store(key, value) 63 | } 64 | 65 | type bench[K comparable, V any] struct { 66 | setup func(*testing.B, api.CMaper[K, V]) 67 | perG func(b *testing.B, pb *testing.PB, i int, m api.CMaper[K, V]) 68 | } 69 | 70 | func benchMap(b *testing.B, bench bench[int, int]) { 71 | for _, m := range [...]api.CMaper[int, int]{New[int, int](), &syncmap[int, int]{}} { 72 | b.Run(fmt.Sprintf("%T", m), func(b *testing.B) { 73 | m = reflect.New(reflect.TypeOf(m).Elem()).Interface().(api.CMaper[int, int]) 74 | if m2, ok := m.(*CMap[int, int]); ok { 75 | m2.init(64) 76 | } 77 | 78 | if bench.setup != nil { 79 | bench.setup(b, m) 80 | } 81 | 82 | b.ResetTimer() 83 | 84 | var i int64 85 | b.RunParallel(func(pb *testing.PB) { 86 | id := int(atomic.AddInt64(&i, 1) - 1) 87 | bench.perG(b, pb, id*b.N, m) 88 | }) 89 | }) 90 | } 91 | } 92 | 93 | func BenchmarkLoadMostlyHits(b *testing.B) { 94 | const hits, misses = 1023, 1 95 | 96 | benchMap(b, bench[int, int]{ 97 | setup: func(_ *testing.B, m api.CMaper[int, int]) { 98 | for i := 0; i < hits; i++ { 99 | m.LoadOrStore(i, i) 100 | } 101 | // Prime the map to get it into a steady state. 102 | for i := 0; i < hits*2; i++ { 103 | m.Load(i % hits) 104 | } 105 | }, 106 | 107 | perG: func(b *testing.B, pb *testing.PB, i int, m api.CMaper[int, int]) { 108 | for ; pb.Next(); i++ { 109 | m.Load(i % (hits + misses)) 110 | } 111 | }, 112 | }) 113 | } 114 | 115 | func BenchmarkLoadMostlyMisses(b *testing.B) { 116 | const hits, misses = 1, 1023 117 | 118 | benchMap(b, bench[int, int]{ 119 | setup: func(_ *testing.B, m api.CMaper[int, int]) { 120 | for i := 0; i < hits; i++ { 121 | m.LoadOrStore(i, i) 122 | } 123 | // Prime the map to get it into a steady state. 
124 | for i := 0; i < hits*2; i++ { 125 | m.Load(i % hits) 126 | } 127 | }, 128 | 129 | perG: func(b *testing.B, pb *testing.PB, i int, m api.CMaper[int, int]) { 130 | for ; pb.Next(); i++ { 131 | m.Load(i % (hits + misses)) 132 | } 133 | }, 134 | }) 135 | } 136 | 137 | func BenchmarkLoadOrStoreBalanced(b *testing.B) { 138 | const hits, misses = 128, 128 139 | 140 | benchMap(b, bench[int, int]{ 141 | setup: func(b *testing.B, m api.CMaper[int, int]) { 142 | for i := 0; i < hits; i++ { 143 | m.LoadOrStore(i, i) 144 | } 145 | // Prime the map to get it into a steady state. 146 | for i := 0; i < hits*2; i++ { 147 | m.Load(i % hits) 148 | } 149 | }, 150 | 151 | perG: func(b *testing.B, pb *testing.PB, i int, m api.CMaper[int, int]) { 152 | for ; pb.Next(); i++ { 153 | j := i % (hits + misses) 154 | if j < hits { 155 | if _, ok := m.LoadOrStore(j, i); !ok { 156 | b.Fatalf("unexpected miss for %v", j) 157 | } 158 | } else { 159 | if v, loaded := m.LoadOrStore(i, i); loaded { 160 | b.Fatalf("failed to store %v: existing value %v", i, v) 161 | } 162 | } 163 | } 164 | }, 165 | }) 166 | } 167 | 168 | func BenchmarkLoadOrStoreUnique(b *testing.B) { 169 | benchMap(b, bench[int, int]{ 170 | setup: func(b *testing.B, m api.CMaper[int, int]) { 171 | }, 172 | 173 | perG: func(b *testing.B, pb *testing.PB, i int, m api.CMaper[int, int]) { 174 | for ; pb.Next(); i++ { 175 | m.LoadOrStore(i, i) 176 | } 177 | }, 178 | }) 179 | } 180 | 181 | func BenchmarkDelete(b *testing.B) { 182 | benchMap(b, bench[int, int]{ 183 | setup: func(b *testing.B, m api.CMaper[int, int]) { 184 | for i := 0; i < 1000000; i++ { 185 | m.Store(i, i) 186 | } 187 | }, 188 | 189 | perG: func(b *testing.B, pb *testing.PB, i int, m api.CMaper[int, int]) { 190 | for ; pb.Next(); i++ { 191 | m.Delete(i) 192 | } 193 | }, 194 | }) 195 | } 196 | 197 | func BenchmarkStore(b *testing.B) { 198 | benchMap(b, bench[int, int]{ 199 | setup: func(_ *testing.B, m api.CMaper[int, int]) { 200 | //m.LoadOrStore(0, 0) 201 | }, 202 | 
203 | perG: func(b *testing.B, pb *testing.PB, i int, m api.CMaper[int, int]) { 204 | for ; pb.Next(); i++ { 205 | m.Store(i, i) 206 | } 207 | }, 208 | }) 209 | } 210 | 211 | func BenchmarkLoadOrStoreCollision(b *testing.B) { 212 | benchMap(b, bench[int, int]{ 213 | setup: func(_ *testing.B, m api.CMaper[int, int]) { 214 | m.LoadOrStore(0, 0) 215 | }, 216 | 217 | perG: func(b *testing.B, pb *testing.PB, i int, m api.CMaper[int, int]) { 218 | for ; pb.Next(); i++ { 219 | m.LoadOrStore(0, 0) 220 | } 221 | }, 222 | }) 223 | } 224 | 225 | func BenchmarkLoadAndDeleteBalanced(b *testing.B) { 226 | const hits, misses = 128, 128 227 | 228 | benchMap(b, bench[int, int]{ 229 | setup: func(b *testing.B, m api.CMaper[int, int]) { 230 | for i := 0; i < hits; i++ { 231 | m.LoadOrStore(i, i) 232 | } 233 | // Prime the map to get it into a steady state. 234 | for i := 0; i < hits*2; i++ { 235 | m.Load(i % hits) 236 | } 237 | }, 238 | 239 | perG: func(b *testing.B, pb *testing.PB, i int, m api.CMaper[int, int]) { 240 | for ; pb.Next(); i++ { 241 | j := i % (hits + misses) 242 | if j < hits { 243 | m.LoadAndDelete(j) 244 | } else { 245 | m.LoadAndDelete(i) 246 | } 247 | } 248 | }, 249 | }) 250 | } 251 | 252 | func BenchmarkLoadAndDeleteUnique(b *testing.B) { 253 | benchMap(b, bench[int, int]{ 254 | setup: func(b *testing.B, m api.CMaper[int, int]) { 255 | }, 256 | 257 | perG: func(b *testing.B, pb *testing.PB, i int, m api.CMaper[int, int]) { 258 | for ; pb.Next(); i++ { 259 | m.LoadAndDelete(i) 260 | } 261 | }, 262 | }) 263 | } 264 | 265 | func BenchmarkLoadAndDeleteCollision(b *testing.B) { 266 | benchMap(b, bench[int, int]{ 267 | setup: func(_ *testing.B, m api.CMaper[int, int]) { 268 | m.LoadOrStore(0, 0) 269 | }, 270 | 271 | perG: func(b *testing.B, pb *testing.PB, i int, m api.CMaper[int, int]) { 272 | for ; pb.Next(); i++ { 273 | m.LoadAndDelete(0) 274 | } 275 | }, 276 | }) 277 | } 278 | 279 | func BenchmarkRange(b *testing.B) { 280 | const mapSize = 1 << 10 281 | 282 | 
benchMap(b, bench[int, int]{ 283 | setup: func(_ *testing.B, m api.CMaper[int, int]) { 284 | for i := 0; i < mapSize; i++ { 285 | m.Store(i, i) 286 | } 287 | }, 288 | 289 | perG: func(b *testing.B, pb *testing.PB, i int, m api.CMaper[int, int]) { 290 | for ; pb.Next(); i++ { 291 | m.Range(func(_, _ int) bool { return true }) 292 | } 293 | }, 294 | }) 295 | } 296 | 297 | // BenchmarkAdversarialAlloc tests performance when we store a new value 298 | // immediately whenever the map is promoted to clean and otherwise load a 299 | // unique, missing key. 300 | // 301 | // This forces the Load calls to always acquire the map's mutex. 302 | func BenchmarkAdversarialAlloc(b *testing.B) { 303 | benchMap(b, bench[int, int]{ 304 | perG: func(b *testing.B, pb *testing.PB, i int, m api.CMaper[int, int]) { 305 | var stores, loadsSinceStore int 306 | for ; pb.Next(); i++ { 307 | m.Load(i) 308 | if loadsSinceStore++; loadsSinceStore > stores { 309 | m.LoadOrStore(i, stores) 310 | loadsSinceStore = 0 311 | stores++ 312 | } 313 | } 314 | }, 315 | }) 316 | } 317 | 318 | // BenchmarkAdversarialDelete tests performance when we periodically delete 319 | // one key and add a different one in a large map. 320 | // 321 | // This forces the Load calls to always acquire the map's mutex and periodically 322 | // makes a full copy of the map despite changing only one entry. 
323 |
324 | // Not run for CMap: calling m.Delete inside m.Range deadlocks with per-bucket locks (sync.RWMutex is not reentrant).
325 | /*
326 | func BenchmarkAdversarialDelete(b *testing.B) {
327 | 	const mapSize = 1 << 10
328 |
329 | 	benchMap(b, bench[int, int]{
330 | 		setup: func(_ *testing.B, m api.CMaper[int, int]) {
331 | 			for i := 0; i < mapSize; i++ {
332 | 				m.Store(i, i)
333 | 			}
334 | 		},
335 |
336 | 		perG: func(b *testing.B, pb *testing.PB, i int, m api.CMaper[int, int]) {
337 | 			for ; pb.Next(); i++ {
338 | 				m.Load(i)
339 |
340 | 				if i%mapSize == 0 {
341 | 					m.Range(func(k, _ int) bool {
342 | 						m.Delete(k)
343 | 						return false
344 | 					})
345 | 					m.Store(i, i)
346 | 				}
347 | 			}
348 | 		},
349 | 	})
350 | }
351 | */
352 |
353 | func BenchmarkDeleteCollision(b *testing.B) {
354 | 	benchMap(b, bench[int, int]{
355 | 		setup: func(_ *testing.B, m api.CMaper[int, int]) {
356 | 			m.LoadOrStore(0, 0)
357 | 		},
358 |
359 | 		perG: func(b *testing.B, pb *testing.PB, i int, m api.CMaper[int, int]) {
360 | 			for ; pb.Next(); i++ {
361 | 				m.Delete(0)
362 | 			}
363 | 		},
364 | 	})
365 | }
366 |
367 | // BenchmarkXXHash measures hashing the raw bytes of an int, like the non-string path of CMap.calHash.
368 | func BenchmarkXXHash(b *testing.B) {
369 | 	for i := 0; i < b.N; i++ {
370 | 		// unsafe.String replaces the deprecated reflect.StringHeader trick; unsafe.Sizeof
371 | 		// avoids reading past i on platforms where int is 4 bytes.
372 | 		key := unsafe.String((*byte)(unsafe.Pointer(&i)), unsafe.Sizeof(i))
373 |
374 | 		xxhash.Sum64String(key)
375 | 	}
376 | }
377 |
--------------------------------------------------------------------------------
/cmap/cmap_stdmap.go:
--------------------------------------------------------------------------------
1 | package cmap
2 |
3 | import (
4 | 	"github.com/antlabs/gstl/api"
5 | 	"golang.org/x/exp/constraints"
6 | )
7 |
8 | var _ api.Map[int, int] = (*stdmap[int, int])(nil)
9 |
10 | // stdmap adapts a built-in Go map to the api.Map interface. It does no locking
11 | // of its own; CMap guards each instance with its bucket's RWMutex.
12 | type stdmap[K constraints.Ordered, V any] struct {
13 | 	m map[K]V
14 | }
15 |
16 | // newStdMap returns an empty, ready-to-use stdmap.
17 | func newStdMap[K constraints.Ordered, V any]() *stdmap[K, V] {
18 | 	return &stdmap[K, V]{m: make(map[K]V)}
19 | }
20 |
21 | // Get returns the value stored for key, or the zero value of V if absent.
22 | func (s *stdmap[K, V]) Get(key K) (elem V) {
23 | 	return s.m[key]
24 | }
25 |
26 | // TryGet returns the value stored for key and whether it was present.
27 | func (s *stdmap[K, V]) TryGet(key K) (elem V, ok bool) {
28 | 	elem, ok = s.m[key]
29 | 	return
30 | }
31 |
32 | // Delete removes key; it is a no-op when key is absent.
33 | func (s *stdmap[K, V]) Delete(key K) {
34 | 	delete(s.m, key)
35 | }
36 |
37 | // Set stores value under key, overwriting any previous value.
38 | func (s *stdmap[K, V]) Set(key K, value V) {
39 | 	s.m[key] = value
40 | }
41 |
42 | // Swap stores value under key and returns the previous value and whether
43 | // one existed.
44 | func (s *stdmap[K, V]) Swap(key K, value V) (prev V, replaced bool) {
45 | 	prev, replaced = s.m[key]
46 | 	s.m[key] = value
47 | 	return
48 | }
49 |
50 | // Len returns the number of stored entries.
51 | func (s *stdmap[K, V]) Len() int {
52 | 	return len(s.m)
53 | }
54 |
55 | // Range calls callback for each entry in unspecified order and stops as soon
56 | // as callback returns false.
57 | func (s *stdmap[K, V]) Range(callback func(k K, v V) bool) {
58 | 	for k, v := range s.m {
59 | 		if !callback(k, v) {
60 | 			return
61 | 		}
62 | 	}
63 | }
64 |
--------------------------------------------------------------------------------
/cmap/cmap_test.go:
--------------------------------------------------------------------------------
1 | package cmap
2 |
3 | import (
4 | 	"fmt"
5 | 	"sort"
6 | 	"sync"
7 | 	"testing"
8 | )
9 |
10 | // Test_StoreAndLoad checks that stored values can be loaded back.
11 | func Test_StoreAndLoad(t *testing.T) {
12 | 	m := New[string, string]()
13 | 	m.Store("hello", "1")
14 | 	m.Store("world", "2")
15 | 	v1, ok1 := m.Load("hello")
16 | 	if v1 != "1" {
17 | 		t.Errorf("expected '1', got '%s'", v1)
18 | 	}
19 | 	if !ok1 {
20 | 		t.Errorf("expected true, got false")
21 | 	}
22 |
23 | 	v1, ok1 = m.Load("world")
24 | 	if v1 != "2" {
25 | 		t.Errorf("expected '2', got '%s'", v1)
26 | 	}
27 | 	if !ok1 {
28 | 		t.Errorf("expected true, got false")
29 | 	}
30 | }
31 |
32 | // Test_StoreDeleteLoad checks that deleted keys can no longer be loaded.
33 | func Test_StoreDeleteLoad(t *testing.T) {
34 | 	m := New[string, string]()
35 | 	m.Store("hello", "1")
36 | 	m.Store("world", "2")
37 |
38 | 	m.Delete("hello")
39 | 	m.Delete("world")
40 |
41 | 	v1, ok1 := m.Load("hello")
42 | 	if v1 != "" {
43 | 		t.Errorf("expected '', got '%s'", v1)
44 | 	}
45 | 	if ok1 {
46 | 		t.Errorf("expected false, got true")
47 | 	}
48 |
49 | 	v1, ok1 = m.Load("world")
50 | 	if v1 != "" {
51 | 		t.Errorf("expected '', got '%s'", v1)
52 | 	}
53 | 	if ok1 {
54 | 		t.Errorf("expected false, got true")
55 | 	}
56 | }
57 |
58 | func Test_LoadAndDelete(t *testing.T) {
59 | 	m := New[string, string]()
60 | 	v1, ok1 := m.LoadAndDelete("hello")
61 |
62 | 	if v1 != "" {
63 | 		t.Errorf("expected '', got '%s'", v1)
64 | 	}
65 | 	if ok1 {
66 | 		t.Errorf("expected false, got true")
67 | 	}
68 |
69 | 	m.Store("hello", "world")
70 | 	v1, ok1 = m.Load("hello")
71 |
72 | 	if v1 != "world" {
73 | 		t.Errorf("expected 'world', got '%s'", v1)
74 | 	}
75 | 	if !ok1 {
76 | 		t.Errorf("expected true, got false")
77 | 	}
78 |
79 | 	v1, ok1 = m.LoadAndDelete("hello")
80 | 	if v1 != "world" {
81 | 		t.Errorf("expected 'world', got '%s'", v1)
82 | 	}
83 | 	if !ok1 {
84 | 		t.Errorf("expected true, got false")
85 | 	}
86 |
87 | 	// The key must be gone after LoadAndDelete.
88 | 	if _, ok := m.Load("hello"); ok {
89 | 		t.Errorf("expected false, got true")
90 | 	}
91 | }
92 |
93 | // Test_loadOrStore compares CMap.LoadOrStore with sync.Map.LoadOrStore on both
94 | // the store (missing key) and load (existing key) paths.
95 | func Test_loadOrStore(t *testing.T) {
96 | 	m := New[string, string]()
97 | 	var m2 sync.Map
98 | 	v1, ok1 := m.LoadOrStore("hello", "world")
99 | 	v2, ok2 := m2.LoadOrStore("hello", "world")
100 |
101 | 	if ok1 != ok2 {
102 | 		t.Errorf("expected %v, got %v", ok2, ok1)
103 | 	}
104 | 	if v1 != v2.(string) {
105 | 		t.Errorf("expected '%s', got '%s'", v2.(string), v1)
106 | 	}
107 |
108 | 	v1, ok1 = m.LoadOrStore("hello", "other")
109 | 	v2, ok2 = m2.LoadOrStore("hello", "other")
110 |
111 | 	if ok1 != ok2 {
112 | 		t.Errorf("expected %v, got %v", ok2, ok1)
113 | 	}
114 | 	if v1 != v2.(string) {
115 | 		t.Errorf("expected '%s', got '%s'", v2.(string), v1)
116 | 	}
117 | }
118 |
119 | // NOTE(review): CMap.Range only stops iterating the current bucket when f
120 | // returns false. This test passes because New never calls keyTypeAndKeySize,
121 | // so calHash hashes an empty string and every key lands in one bucket.
122 | // Revisit once keys are actually spread across buckets.
123 | func Test_RangeBreak(t *testing.T) {
124 | 	m := New[string, string]()
125 | 	m.Store("1", "1")
126 | 	m.Store("2", "2")
127 |
128 | 	count := 0
129 | 	m.Range(func(key, val string) bool {
130 | 		count++
131 | 		return false
132 | 	})
133 |
134 | 	if count != 1 {
135 | 		t.Errorf("expected 1, got %d", count)
136 | 	}
137 | }
138 |
139 | func Test_Range(t *testing.T) {
140 | 	m := New[string, string]()
141 | 	max := 5
142 | 	keyAll := []string{}
143 | 	valAll := []string{}
144 |
145 | 	for i := 1; i < max; i++ {
146 | 		key := fmt.Sprintf("%dk", i)
147 | 		val := fmt.Sprintf("%dv", i)
148 | 		keyAll = append(keyAll, key)
149 | 		valAll = append(valAll, val)
150 | 		m.Store(key, val)
151 | 	}
152 |
153 | 	gotKey := []string{}
154 | 	gotVal := []string{}
155 | 	m.Range(func(key, val string) bool {
156 | 		gotKey = append(gotKey, key)
157 | 		gotVal = append(gotVal, val)
158 | 		return true
159 | 	})
160 |
161 | 	sort.Strings(gotKey)
162 | 	sort.Strings(gotVal)
163 |
164 | 	if !equalSlices(keyAll, gotKey) {
165 | 		t.Errorf("expected keys %v, got %v", keyAll, gotKey)
166 | 	}
167 | 	if !equalSlices(valAll, gotVal) {
168 | 		t.Errorf("expected values %v, got %v", valAll, gotVal)
169 | 	}
170 | }
171 |
172 | func Test_Iter(t *testing.T) {
173 | 	m := New[string, string]()
174 | 	max := 5
175 | 	keyAll := []string{}
176 | 	valAll := []string{}
177 |
178 | 	for i := 1; i < max; i++ {
179 | 		key := fmt.Sprintf("%dk", i)
180 | 		val := fmt.Sprintf("%dv", i)
181 | 		keyAll = append(keyAll, key)
182 | 		valAll = append(valAll, val)
183 | 		m.Store(key, val)
184 | 	}
185 |
186 | 	gotKey := []string{}
187 | 	gotVal := []string{}
188 | 	for pair := range m.Iter() {
189 | 		gotKey = append(gotKey, pair.Key)
190 | 		gotVal = append(gotVal, pair.Val)
191 | 	}
192 |
193 | 	sort.Strings(gotKey)
194 | 	sort.Strings(gotVal)
195 |
196 | 	if !equalSlices(keyAll, gotKey) {
197 | 		t.Errorf("expected keys %v, got %v", keyAll, gotKey)
198 | 	}
199 | 	if !equalSlices(valAll, gotVal) {
200 | 		t.Errorf("expected values %v, got %v", valAll, gotVal)
201 | 	}
202 | }
203 |
204 | func Test_Len(t *testing.T) {
205 | 	m := New[string, string]()
206 | 	m.Store("1", "1")
207 | 	m.Store("2", "2")
208 | 	m.Store("3", "3")
209 | 	if m.Len() != 3 {
210 | 		t.Errorf("expected 3, got %d", m.Len())
211 | 	}
212 | }
213 |
214 | func Test_New(t *testing.T) {
215 | 	m := New[string, string]()
216 | 	m.Store("1", "1")
217 | 	m.Store("2", "2")
218 | 	m.Store("3", "3")
219 | 	if m.Len() != 3 {
220 | 		t.Errorf("expected 3, got %d", m.Len())
221 | 	}
222 | }
223 |
224 | func Test_Keys(t *testing.T) {
225 | 	m := New[string, string]()
226 | 	m.Store("a", "1")
227 | 	m.Store("b", "2")
228 | 	m.Store("c", "3")
229 | 	get := m.Keys()
230 | 	sort.Strings(get)
231 | 	if !equalSlices(get, []string{"a", "b", "c"}) {
232 | 		t.Errorf("expected keys %v, got %v", []string{"a", "b", "c"}, get)
233 | 	}
234 |
235 | 	// An empty map must yield no keys.
236 | 	m2 := New[string, string]()
237 | 	if len(m2.Keys()) != 0 {
238 | 		t.Errorf("expected 0, got %d", len(m2.Keys()))
239 | 	}
240 | }
241 |
242 | func Test_Values(t *testing.T) {
243 | 	m := New[string, string]()
244 | 	m.Store("a", "1")
245 | 	m.Store("b", "2")
246 | 	m.Store("c", "3")
247 | 	get := m.Values()
248 | 	sort.Strings(get)
249 | 	if !equalSlices(get, []string{"1", "2", "3"}) {
250 | 		t.Errorf("expected values %v, got %v", []string{"1", "2", "3"}, get)
251 | 	}
252 |
253 | 	// An empty map must yield no values.
254 | 	m2 := New[string, string]()
255 | 	if len(m2.Values()) != 0 {
256 | 		t.Errorf("expected 0, got %d", len(m2.Values()))
257 | 	}
258 | }
259 |
260 | func Test_UpdateOrInsert(t *testing.T) {
261 | 	t.Run("Update", func(t *testing.T) {
262 | 		m := New[string, string]()
263 | 		m.Store("a", "1")
264 | 		m.Store("b", "2")
265 | 		m.Store("c", "3")
266 | 		m.UpdateOrInsert("a", func(exist bool, old string) string {
267 | 			if !exist {
268 | 				t.Error("should exist")
269 | 			}
270 | 			if exist {
271 | 				return "4"
272 | 			}
273 | 			return old
274 | 		})
275 | 		get, _ := m.Load("a")
276 | 		if get != "4" {
277 | 			t.Error("should be 4")
278 | 		}
279 | 	})
280 |
281 | 	t.Run("Insert", func(t *testing.T) {
282 | 		m := New[string, string]()
283 | 		m.Store("a", "1")
284 | 		m.UpdateOrInsert("b", func(exist bool, old string) string {
285 | 			if !exist {
286 | 				return "2"
287 | 			}
288 | 			return ""
289 | 		})
290 |
291 | 		get, _ := m.Load("b")
292 | 		if get != "2" {
293 | 			t.Error("should be 2")
294 | 		}
295 | 	})
296 | }
297 |
298 | // equalSlices reports whether a and b have the same length and the same
299 | // elements in the same order.
300 | func equalSlices(a, b []string) bool {
301 | 	if len(a) != len(b) {
302 | 		return false
303 | 	}
304 | 	for i := range a {
305 | 		if a[i] != b[i] {
306 | 			return false
307 | 		}
308 | 	}
309 | 	return true
310 | }
311 |
--------------------------------------------------------------------------------
/cmp/cmp.go:
--------------------------------------------------------------------------------
1 | // Package cmp provides generic comparison helpers for ordered types.
2 | package cmp
3 |
4 | // apache 2.0 antlabs
5 | import (
6 | 	"golang.org/x/exp/constraints"
7 | )
8 |
9 | // Max returns the larger of a and b; when they compare equal it returns b.
10 | func Max[T constraints.Ordered](a, b T) T {
11 | 	if a > b {
12 | 		return a
13 | 	}
14 | 	return b
15 | }
16 |
17 | // Min returns the smaller of a and b; when they compare equal it returns b.
18 | func Min[T constraints.Ordered](a, b T) T {
19 | 	if a < b {
20 | 		return a
21 | 	}
22 | 	return b
23 | }
24 |
25 | // MaxSlice returns the index of the first largest element of s, or -1 if s
26 | // is empty.
27 | func MaxSlice[T constraints.Ordered](s []T) int {
28 | 	if len(s) == 0 {
29 | 		return -1
30 | 	}
31 |
32 | 	// i indexes s itself (not s[1:]), so the recorded position is exact.
33 | 	maxIndex := 0
34 | 	for i := 1; i < len(s); i++ {
35 | 		if s[maxIndex] < s[i] {
36 | 			maxIndex = i
37 | 		}
38 | 	}
39 | 	return maxIndex
40 | }
41 |
42 | // MinSlice returns the index of the first smallest element of s, or -1 if s
43 | // is empty.
44 | func MinSlice[T constraints.Ordered](s []T) int {
45 | 	if len(s) == 0 {
46 | 		return -1
47 | 	}
48 |
49 | 	// i indexes s itself (not s[1:]), so the recorded position is exact.
50 | 	minIndex := 0
51 | 	for i := 1; i < len(s); i++ {
52 | 		if s[minIndex] > s[i] {
53 | 			minIndex = i
54 | 		}
55 | 	}
56 | 	return minIndex
57 | }
58 |
59 | // Compare returns -1 if a < b, +1 if a > b, and 0 otherwise.
60 | func Compare[T constraints.Ordered](a, b T) int {
61 | 	switch {
62 | 	case a < b:
63 | 		return -1
64 | 	case a > b:
65 | 		return 1
66 | 	}
67 | 	return 0
68 | }
69 |
--------------------------------------------------------------------------------
/go.mod:
--------------------------------------------------------------------------------
1 | module github.com/antlabs/gstl
2 |
3 | go 1.22.0
4 |
5 | require (
6 | 	github.com/cespare/xxhash/v2 v2.1.2
7 | 	golang.org/x/exp v0.0.0-20241217172543-b2144cdd0a67
8 | )
9 |
--------------------------------------------------------------------------------
/go.sum:
--------------------------------------------------------------------------------
1 | github.com/cespare/xxhash/v2 v2.1.2 h1:YRXhKfTDauu4ajMg1TPgFO5jnlC2HCbmLXMcTG5cbYE=
2 | github.com/cespare/xxhash/v2 v2.1.2/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
3 | golang.org/x/exp v0.0.0-20241217172543-b2144cdd0a67 h1:1UoZQm6f0P/ZO0w1Ri+f+ifG/gXhegadRdwBIXEFWDo=
4 | golang.org/x/exp v0.0.0-20241217172543-b2144cdd0a67/go.mod h1:qj5a5QZpwLU2NLQudwIN5koi3beDhSAlJwa67PuM98c=
5 |
--------------------------------------------------------------------------------
/ifop/ifop.go:
--------------------------------------------------------------------------------
1 | package ifop
2 |
3 | // apache 2.0 antlabs
4 | func If[T any](cond bool, t T) (zero T) {
5 | 	if cond {
6 | 		return t
7 | 	}
8 | 	return
9 | }
10 |
11 | func IfElse[T any](cond bool, ifVal T, elseVal T) T {
12 | 	if cond {
13 | 		return ifVal
14 | 	}
15 | 	return elseVal
16 | }
17 |
18 | // IfElseAny returns ifVal when cond is true and elseVal otherwise. It is the
19 | // non-generic counterpart of IfElse, usable when the two operands differ in type.
20 | func IfElseAny(cond bool, ifVal any, elseVal any) any {
21 | 	if cond {
22 | 		return ifVal
23 | 	}
24 | 	return elseVal
25 | }
26 |
--------------------------------------------------------------------------------
/ifop/ifop_test.go:
--------------------------------------------------------------------------------
1 | package ifop
2 |
3 | // apache 2.0 antlabs
4 | import (
5 | 	"reflect"
6 | 	"testing"
7 | )
8 |
9 | func TestIf(t *testing.T) {
10 | 	a := ""
11 | 	if result := If(len(a) == 0, "default"); result != "default" {
12 | 		t.Errorf("expected 'default', got '%s'", result)
13 | 	}
14 | 	// A false condition yields the zero value.
15 | 	if result := If(len(a) != 0, "default"); result != "" {
16 | 		t.Errorf("expected '', got '%s'", result)
17 | 	}
18 | }
19 |
20 | func TestIfElse(t *testing.T) {
21 | 	a := ""
22 | 	if result := IfElse(len(a) != 0, a, "default"); result != "default" {
23 | 		t.Errorf("expected 'default', got '%s'", result)
24 | 	}
25 | 	a = "hello"
26 | 	if result := IfElse(len(a) != 0, a, "default"); result != "hello" {
27 | 		t.Errorf("expected 'hello', got '%s'", result)
28 | 	}
29 | }
30 |
31 | func TestIfElse2(t *testing.T) {
32 | 	o := map[string]any{"hello": "hello"}
33 | 	a := []any{"hello", "world"}
34 |
35 | 	if result := IfElseAny(o != nil, o, a); !reflect.DeepEqual(result, o) {
36 | 		t.Errorf("expected %v, got %v", o, result)
37 | 	}
38 | 	o = nil
39 | 	if result := IfElseAny(o != nil, o, a); !reflect.DeepEqual(result, a) {
40 | 		t.Errorf("expected %v, got %v", a, result)
41 | 	}
42 | }
43 |
--------------------------------------------------------------------------------
/linkedlist/linkedlist_benchmark_test.go:
--------------------------------------------------------------------------------
1 | package linkedlist
2 |
3 | // apache 2.0 antlabs
4 | import (
5 | 	"container/list"
6 | 	"testing"
7 | 	"time"
8 | )
9 |
10 | // Sample results:
11 | //
12 | // goos: darwin
13 | // goarch: amd64
14 | // pkg: github.com/antlabs/gstl/linkedlist
15 | // cpu: Intel(R) Core(TM) i7-1068NG7 CPU @ 2.30GHz
16 | // Benchmark_ListAdd_Stdlib-8 5918479 190.0 ns/op
17 | // Benchmark_ListAdd_gstl-8 15942064 83.15 ns/op
18 | // PASS
19 | // ok github.com/antlabs/gstl/linkedlist 3.157s
20 |
21 | // timeNodeStdlib is the element value pushed by both list benchmarks.
22 | type timeNodeStdlib struct {
23 | 	expire     uint64
24 | 	userExpire time.Duration
25 | 	callback   func()
26 | 	isSchedule bool
27 | 	close      uint32
28 | 	lock       uint32
29 | }
30 |
31 | // Benchmark_ListAdd_Stdlib measures PushBack on the standard library's container/list.
32 | func Benchmark_ListAdd_Stdlib(b *testing.B) {
33 | 	head := list.New()
34 | 	for i := 0; i < b.N; i++ {
35 | 		node := timeNodeStdlib{}
36 | 		head.PushBack(node)
37 | 	}
38 | }
39 |
40 | // Benchmark_ListAdd_gstl measures PushBack on this package's generic list.
41 | func Benchmark_ListAdd_gstl(b *testing.B) {
42 | 	head := New[timeNodeStdlib]()
43 | 	for i := 0; i < b.N; i++ {
44 | 		node := timeNodeStdlib{}
45 | 		head.PushBack(node)
46 | 	}
47 | }
48 |
--------------------------------------------------------------------------------
/mapex/mapex.go:
--------------------------------------------------------------------------------
1 | // Package mapex provides generic helpers for extracting keys and values from maps.
2 | package mapex
3 |
4 | import (
5 | 	"slices"
6 |
7 | 	"golang.org/x/exp/constraints"
8 | )
9 |
10 | // Map is a named map type that exposes Keys and Values as methods.
11 | type Map[K comparable, V any] map[K]V
12 |
13 | // Keys returns the keys of m in unspecified order.
14 | func Keys[K comparable, V any](m map[K]V) (keys []K) {
15 | 	return Map[K, V](m).Keys()
16 | }
17 |
18 | // SortKeys returns the keys of m sorted in ascending order.
19 | func SortKeys[K constraints.Ordered, V any](m map[K]V) (keys []K) {
20 | 	keys = Keys(m)
21 | 	// slices.Sort avoids sort.Slice's reflection-based swapper and per-compare closure call.
22 | 	slices.Sort(keys)
23 | 	return keys
24 | }
25 |
26 | // Values returns the values of m in unspecified order.
27 | func Values[K comparable, V any](m map[K]V) (values []V) {
28 | 	return Map[K, V](m).Values()
29 | }
30 |
31 | // SortValues returns the values of m sorted in ascending order.
32 | func SortValues[K comparable, V constraints.Ordered](m map[K]V) (values []V) {
33 | 	values = Values(m)
34 | 	slices.Sort(values)
35 | 	return values
36 | }
37 |
38 | // Keys returns the keys of m in unspecified order; the result is never nil.
39 | func (m Map[K, V]) Keys() (keys []K) {
40 | 	keys = make([]K, 0, len(m))
41 | 	for k := range m {
42 | 		keys = append(keys, k)
43 | 	}
44 | 	return
45 | }
46 |
47 | // Values returns the values of m in unspecified order; the result is never nil.
48 | func (m Map[K, V]) Values() (values []V) {
49 | 	values = make([]V, 0, len(m))
50 | 	for _, v := range m {
51 | 		values = append(values, v)
52 | 	}
53 | 	return
54 | }
55 |
--------------------------------------------------------------------------------
/mapex/mapex_test.go:
-------------------------------------------------------------------------------- 1 | package mapex 2 | 3 | import ( 4 | "sort" 5 | "testing" 6 | ) 7 | 8 | func Test_Keys(t *testing.T) { 9 | m := make(map[string]string) 10 | m["a"] = "1" 11 | m["b"] = "2" 12 | m["c"] = "3" 13 | get := Keys(m) 14 | sort.Strings(get) 15 | expected := []string{"a", "b", "c"} 16 | if !equalSlices(get, expected) { 17 | t.Errorf("expected %v, got %v", expected, get) 18 | } 19 | get = Map[string, string](m).Keys() 20 | sort.Strings(get) 21 | if !equalSlices(get, expected) { 22 | t.Errorf("expected %v, got %v", expected, get) 23 | } 24 | } 25 | 26 | func Test_Values(t *testing.T) { 27 | m := make(map[string]string) 28 | m["a"] = "1" 29 | m["b"] = "2" 30 | m["c"] = "3" 31 | get := Values(m) 32 | sort.Strings(get) 33 | expected := []string{"1", "2", "3"} 34 | if !equalSlices(get, expected) { 35 | t.Errorf("expected %v, got %v", expected, get) 36 | } 37 | 38 | get = Map[string, string](m).Values() 39 | sort.Strings(get) 40 | if !equalSlices(get, expected) { 41 | t.Errorf("expected %v, got %v", expected, get) 42 | } 43 | } 44 | 45 | // 辅助函数,用于比较两个切片是否相等 46 | func equalSlices(a, b []string) bool { 47 | if len(a) != len(b) { 48 | return false 49 | } 50 | for i := range a { 51 | if a[i] != b[i] { 52 | return false 53 | } 54 | } 55 | return true 56 | } 57 | -------------------------------------------------------------------------------- /must/must.go: -------------------------------------------------------------------------------- 1 | package must 2 | 3 | // apache 2.0 antlabs 4 | func TakeOneDiscardBool[T any](v T, ok bool) T { 5 | if !ok { 6 | panic("ok is false") 7 | } 8 | return v 9 | } 10 | 11 | func TakeOne[T any](v T, err error) T { 12 | if err != nil { 13 | panic(err.Error()) 14 | } 15 | return v 16 | } 17 | 18 | func TakeTwo[T, U any](a T, b U, err error) (T, U) { 19 | if err != nil { 20 | panic(err.Error()) 21 | } 22 | return a, b 23 | } 24 | 25 | func TakeThree[T, U, V any](a T, b U, c 
V, err error) (T, U, V) { 26 | if err != nil { 27 | panic(err.Error()) 28 | } 29 | 30 | return a, b, c 31 | } 32 | 33 | func TakeOneErr[T any](v T, err error) error { 34 | return err 35 | } 36 | 37 | func TakeOneBool[T any](v T, ok bool) bool { 38 | return ok 39 | } 40 | 41 | func TakeTwoErr[T, U any](a T, b U, err error) error { 42 | return err 43 | } 44 | 45 | func TakeThreeErr[T, U, V any](a T, b U, c V, err error) error { 46 | return err 47 | } 48 | -------------------------------------------------------------------------------- /radix/radix.go: -------------------------------------------------------------------------------- 1 | package radix 2 | 3 | // apache 2.0 antlabs 4 | import ( 5 | "strings" 6 | "unicode/utf8" 7 | 8 | "github.com/antlabs/gstl/api" 9 | "github.com/antlabs/gstl/cmp" 10 | "github.com/antlabs/gstl/vec" 11 | ) 12 | 13 | var _ api.Trie[int] = (*Radix[int])(nil) 14 | 15 | // 健值对 16 | type pair[V any] struct { 17 | val V 18 | key string 19 | isSet bool 20 | } 21 | 22 | // 边 23 | type edge[V any] struct { 24 | label rune 25 | node *node[V] 26 | } 27 | 28 | // 节点 29 | type node[V any] struct { 30 | pair[V] 31 | prefix string 32 | edges vec.Vec[edge[V]] 33 | } 34 | 35 | // 头节点 36 | type Radix[V any] struct { 37 | root *node[V] 38 | length int 39 | } 40 | 41 | // 获取 42 | func (r *Radix[V]) Get(k string) (v V) { 43 | v, _ = r.TryGet(k) 44 | return 45 | } 46 | 47 | // 返回共同的前缀 48 | func commonPrefix(k1, k2 string) (i int) { 49 | min := cmp.Min(len(k1), len(k2)) 50 | 51 | for i = 0; i < min; i++ { 52 | if k1[i] != k2[i] { 53 | return i 54 | } 55 | } 56 | return 57 | } 58 | 59 | func (r *Radix[V]) newEdge(label rune, p pair[V], prefix string) edge[V] { 60 | return edge[V]{ 61 | label: label, 62 | node: &node[V]{ 63 | pair: p, 64 | prefix: prefix, 65 | }, 66 | } 67 | } 68 | 69 | // 设置 70 | func (r *Radix[V]) Swap(k string, v V) (prev V, replaced bool) { 71 | 72 | var parent *node[V] 73 | var found bool 74 | n := r.root 75 | remaining := k 76 | 77 | for { 
78 | 79 | if len(k) == 0 { 80 | if n.isSet { 81 | prev = n.val 82 | n.val = v 83 | return prev, true 84 | } 85 | 86 | n.key, n.val = k, v 87 | n.isSet = true 88 | r.length++ 89 | replaced = true 90 | return 91 | } 92 | 93 | rune, _ := utf8.DecodeLastRuneInString(remaining) 94 | parent = n 95 | n, found = n.children(rune) 96 | if !found { 97 | parent.edges.Push(r.newEdge(rune, pair[V]{ 98 | key: k, 99 | val: v, 100 | isSet: true, 101 | }, remaining)) 102 | r.length++ 103 | return 104 | } 105 | 106 | // 待插入节点 貌似和当前节点有共同的路径,先continue看看情况,等会插入 107 | commonPrefixLen := commonPrefix(remaining, n.prefix) 108 | if commonPrefixLen == len(n.prefix) { 109 | remaining = remaining[commonPrefixLen:] 110 | continue 111 | } 112 | 113 | subRune, _ := utf8.DecodeLastRuneInString(remaining[commonPrefixLen:]) 114 | // 这里遇到分叉 115 | // 这里的节点只加上,比如原来节点是/helloaxx, 现在要插入/hellobxx 116 | // /hello 会变成child 117 | r.length++ 118 | child := &node[V]{ 119 | // 共同路径成为两个分裂节点的父节点 120 | prefix: remaining[:commonPrefixLen], 121 | } 122 | 123 | // axx 变成child的儿子1 124 | child.edges.Push(edge[V]{ 125 | label: subRune, 126 | node: n, 127 | }) 128 | // 把/helloaxx 变成axx 129 | n.prefix = n.prefix[commonPrefixLen:] 130 | 131 | // 如果以前 parent指向/helloaxx,那么现在parent就指向/hello(即child节点) 132 | // axx 和bxx都将成为/hello的两个子节点 133 | parent.setChildren(rune, child) 134 | 135 | remaining = remaining[commonPrefixLen:] 136 | pairKV := pair[V]{ 137 | key: k, 138 | val: v, 139 | isSet: true, 140 | } 141 | 142 | // 如果新插入路径只是原路径的子集,就走起 143 | // 比如原路径是/helloaxx, 本次插入/hello 144 | if len(remaining) == 0 { 145 | child.pair = pairKV 146 | return 147 | } 148 | 149 | subRune, _ = utf8.DecodeLastRuneInString(remaining) 150 | 151 | // 把bxx的路径是在newEdge函数里面设置的 152 | child.insertChildren(subRune, r.newEdge(subRune, pairKV, remaining)) 153 | } 154 | 155 | } 156 | 157 | // 是否有这个前缀串 158 | func (r *Radix[V]) HasPrefix(k string) (ok bool) { 159 | return 160 | } 161 | 162 | func (n *node[V]) insertChildren(r rune, new edge[V]) { 163 | 164 | 
index, found := n.find(r) 165 | if found { 166 | panic("这个节点已经存在过???") 167 | } 168 | 169 | n.edges.Insert(index, new) 170 | } 171 | 172 | func (n *node[V]) setChildren(r rune, new *node[V]) { 173 | index, found := n.find(r) 174 | if !found { 175 | panic("没找到这个节点") 176 | } 177 | 178 | n.edges.GetPtr(index).node = new 179 | } 180 | 181 | func (n *node[V]) children(r rune) (v *node[V], found bool) { 182 | index, found := n.find(r) 183 | if !found { 184 | return 185 | } 186 | return n.edges.Get(index).node, true 187 | } 188 | 189 | func (n *node[V]) find(r rune) (index int, found bool) { 190 | 191 | index = n.edges.SearchFunc(func(elem edge[V]) bool { return r < elem.label }) 192 | if index > 0 && n.edges.Get(index-1).label >= r { 193 | return index - 1, true 194 | } 195 | 196 | return index, false 197 | } 198 | 199 | // 获取返回bool 200 | func (r *Radix[V]) TryGet(k string) (v V, found bool) { 201 | n := r.root 202 | 203 | for { 204 | 205 | // k 消费完,找到,或者找不到 206 | if len(k) == 0 { 207 | if n.isSet { 208 | return n.val, true 209 | } 210 | return 211 | } 212 | 213 | rune, _ := utf8.DecodeLastRuneInString(k) 214 | n, found = n.children(rune) 215 | if !found { 216 | return 217 | } 218 | 219 | if strings.HasPrefix(k, n.prefix) { 220 | k = k[len(n.prefix):] 221 | continue 222 | } 223 | return 224 | } 225 | 226 | } 227 | 228 | // 删除 229 | func (r *Radix[V]) Delete(k string) { 230 | 231 | } 232 | 233 | // 返回长度 234 | func (r *Radix[V]) Len() int { 235 | return r.length 236 | } 237 | -------------------------------------------------------------------------------- /rbtree/rbtree.go: -------------------------------------------------------------------------------- 1 | package rbtree 2 | 3 | // apache 2.0 antlabs 4 | // 参考资料 5 | // https://github.com/torvalds/linux/blob/master/lib/rbtree.c 6 | import ( 7 | "errors" 8 | 9 | "github.com/antlabs/gstl/api" 10 | "golang.org/x/exp/constraints" 11 | ) 12 | 13 | // 红黑树5条重要性质 14 | // 1. 节点为红色或者黑色 15 | // 2. 根节点是黑色(黑根) 16 | // 3. 
所有叶子(空节点)均为黑色 17 | // 4. 每个红色节点的两个子节点均为黑色(红父黑子) 18 | // 5. 从根到叶的每个路径包含相同数量的黑色节点(黑高相同) 19 | 20 | var _ api.SortedMap[int, int] = (*RBTree[int, int])(nil) 21 | 22 | var ErrNotFound = errors.New("rbtree: not found value") 23 | 24 | type color int8 25 | 26 | const ( 27 | RED color = 1 28 | BLACK color = 2 29 | ) 30 | 31 | // 元素 32 | type pair[K constraints.Ordered, V any] struct { 33 | val V 34 | key K 35 | } 36 | 37 | type parentColor[K constraints.Ordered, V any] struct { 38 | parent *node[K, V] 39 | color color 40 | } 41 | 42 | type node[K constraints.Ordered, V any] struct { 43 | left *node[K, V] 44 | right *node[K, V] 45 | pair[K, V] 46 | parentColor[K, V] 47 | } 48 | 49 | func (n *node[K, V]) setParent(parent *node[K, V]) { 50 | n.parent = parent 51 | } 52 | 53 | func (n *node[K, V]) link(parent *node[K, V], link **node[K, V]) { 54 | n.parent = parent 55 | n.color = RED 56 | *link = n 57 | } 58 | 59 | type root[K constraints.Ordered, V any] struct { 60 | node *node[K, V] 61 | } 62 | 63 | func (r *root[K, V]) rotateLeft(n *node[K, V]) { 64 | 65 | right := n.right 66 | 67 | n.right = right.left 68 | if right.left != nil { 69 | right.left.parent = n 70 | } 71 | right.left = n 72 | 73 | right.parent = n.parent 74 | 75 | if n.parent != nil { 76 | 77 | if n == n.parent.left { 78 | n.parent.left = right 79 | } else { 80 | n.parent.right = right 81 | } 82 | } else { 83 | r.node = right 84 | } 85 | n.parent = right 86 | } 87 | 88 | func (r *root[K, V]) rotateRight(n *node[K, V]) { 89 | left := n.left 90 | n.left = left.right 91 | if left.right != nil { 92 | left.right.parent = n 93 | } 94 | left.right = n 95 | 96 | left.parent = n.parent 97 | if n.parent != nil { 98 | if n == n.parent.right { 99 | n.parent.right = left 100 | } else { 101 | n.parent.left = left 102 | } 103 | } else { 104 | r.node = left 105 | } 106 | n.parent = left 107 | } 108 | 109 | func (r *root[K, V]) changeChild(old, new, parent *node[K, V]) { 110 | if parent != nil { 111 | if parent.left == old { 
112 | parent.left = new 113 | } else { 114 | parent.right = new 115 | } 116 | } else { 117 | r.node = new 118 | } 119 | 120 | } 121 | 122 | func (r *root[K, V]) insert(n *node[K, V]) { 123 | 124 | var parent, gparent *node[K, V] 125 | 126 | for parent = n.parent; parent != nil && parent.color == RED; parent = n.parent { 127 | 128 | gparent = parent.parent 129 | if parent == gparent.left { 130 | 131 | uncle := gparent.right 132 | if uncle != nil && uncle.color == RED { 133 | uncle.color = BLACK 134 | parent.color = BLACK 135 | gparent.color = RED 136 | n = gparent 137 | continue 138 | } 139 | 140 | if parent.right == n { 141 | r.rotateLeft(parent) 142 | parent, n = n, parent 143 | } 144 | 145 | parent.color = BLACK 146 | gparent.color = RED 147 | r.rotateRight(gparent) 148 | } else { 149 | uncle := gparent.left 150 | if uncle != nil && uncle.color == RED { 151 | uncle.color = BLACK 152 | parent.color = BLACK 153 | gparent.color = RED 154 | n = gparent 155 | continue 156 | } 157 | 158 | if parent.left == n { 159 | r.rotateRight(parent) 160 | parent, n = n, parent 161 | } 162 | parent.color = BLACK 163 | gparent.color = RED 164 | r.rotateLeft(gparent) 165 | } 166 | } 167 | r.node.color = BLACK //黑根 168 | } 169 | 170 | // 红黑树 171 | type RBTree[K constraints.Ordered, V any] struct { 172 | length int 173 | root root[K, V] 174 | } 175 | 176 | type InsertOrUpdateCb[V any] func(prev V, new V) V 177 | 178 | // InsertOrUpdate inserts or updates an element in the RBTree 179 | func (r *RBTree[K, V]) InsertOrUpdate(k K, v V, cb InsertOrUpdateCb[V]) { 180 | if prev, ok := r.TryGet(k); ok { 181 | v = cb(prev, v) 182 | } 183 | r.Set(k, v) 184 | } 185 | 186 | // 初始化函数 187 | func New[K constraints.Ordered, V any]() *RBTree[K, V] { 188 | return &RBTree[K, V]{} 189 | } 190 | 191 | // 第一个节点 192 | func (r *RBTree[K, V]) First() (v V, ok bool) { 193 | n := r.root.node 194 | if n == nil { 195 | return 196 | } 197 | 198 | for n.left != nil { 199 | n = n.left 200 | } 201 | 202 | return 
n.val, true 203 | } 204 | 205 | // 最后一个节点 206 | func (r *RBTree[K, V]) Last() (v V, ok bool) { 207 | n := r.root.node 208 | if n == nil { 209 | return 210 | } 211 | 212 | for n.right != nil { 213 | n = n.right 214 | } 215 | 216 | return n.val, true 217 | } 218 | 219 | func (r *RBTree[K, V]) Set(k K, v V) { 220 | _, _ = r.Swap(k, v) 221 | } 222 | 223 | // 设置 224 | func (r *RBTree[K, V]) Swap(k K, v V) (prev V, replaced bool) { 225 | link := &r.root.node 226 | var parent *node[K, V] 227 | 228 | node := &node[K, V]{pair: pair[K, V]{key: k, val: v}} 229 | 230 | for *link != nil { 231 | parent = *link 232 | if parent.key == k { 233 | prev = parent.val 234 | parent.val = v 235 | return prev, true 236 | } 237 | 238 | if parent.key < k { 239 | link = &parent.right 240 | } else { 241 | link = &parent.left 242 | } 243 | } 244 | 245 | node.link(parent, link) 246 | r.root.insert(node) 247 | r.length++ 248 | return 249 | } 250 | 251 | // Get 252 | func (r *RBTree[K, V]) Get(k K) (v V) { 253 | v, _ = r.TryGet(k) 254 | return 255 | } 256 | 257 | // 从rbtree 找到需要的值 258 | func (r *RBTree[K, V]) TryGet(k K) (v V, ok bool) { 259 | n := r.root.node 260 | for n != nil { 261 | if n.key == k { 262 | return n.val, true 263 | } 264 | 265 | if k > n.key { 266 | n = n.right 267 | } else { 268 | n = n.left 269 | } 270 | } 271 | 272 | return 273 | } 274 | 275 | // 删除 276 | func (r *root[K, V]) erase(n *node[K, V]) { 277 | 278 | var child, parent *node[K, V] 279 | var color color 280 | if n.left == nil { 281 | child = n.right 282 | } else if n.right == nil { 283 | child = n.left 284 | } else { 285 | old := n 286 | n = n.right 287 | left := n.left 288 | for ; left != nil; left = n.left { 289 | } 290 | child = n.right 291 | parent = n.parent 292 | color = n.color 293 | 294 | if child != nil { 295 | child.parent = parent 296 | } 297 | 298 | if parent != nil { 299 | if parent.left == n { 300 | parent.left = child 301 | } else { 302 | parent.right = child 303 | } 304 | } else { 305 | r.node = child 
306 | } 307 | 308 | if n.parent == old { 309 | parent = n 310 | } 311 | 312 | n.parent = old.parent 313 | n.color = old.color 314 | n.right = old.right 315 | n.left = old.left 316 | 317 | if old.parent != nil { 318 | if old.parent.left == old { 319 | old.parent.left = n 320 | } else { 321 | old.parent.right = n 322 | } 323 | } else { 324 | r.node = n 325 | } 326 | old.left.parent = n 327 | if old.right != nil { 328 | old.right.parent = n 329 | } 330 | goto color 331 | } 332 | parent = n.parent 333 | color = n.color 334 | 335 | if child != nil { 336 | child.parent = parent 337 | } 338 | if parent != nil { 339 | if parent.left == n { 340 | parent.left = child 341 | } else { 342 | parent.right = child 343 | } 344 | } else { 345 | r.node = child 346 | } 347 | 348 | color: 349 | if color == BLACK { 350 | r.eraseColor(child, parent) 351 | } 352 | } 353 | 354 | func (r *root[K, V]) eraseColor(n *node[K, V], parent *node[K, V]) { 355 | 356 | var other *node[K, V] 357 | for (n == nil || n.color == BLACK) && n != r.node { 358 | if parent.left == n { 359 | other = parent.right 360 | if other.color == RED { 361 | other.color = BLACK 362 | parent.color = RED 363 | r.rotateLeft(parent) 364 | other = parent.right 365 | } 366 | if (other.left == nil || other.left.color == BLACK) && (other.right == nil || other.right.color == BLACK) { 367 | 368 | other.color = RED 369 | n = parent 370 | parent = n.parent 371 | } else { 372 | 373 | if other.right == nil || other.right.color == BLACK { 374 | 375 | oleft := other.left 376 | if oleft != nil { 377 | oleft.color = BLACK 378 | } 379 | other.color = RED 380 | r.rotateRight(other) 381 | other = parent.right 382 | } 383 | other.color = parent.color 384 | parent.color = BLACK 385 | if other.right != nil { 386 | other.right.color = BLACK 387 | } 388 | r.rotateLeft(parent) 389 | n = r.node 390 | break 391 | } 392 | } else { 393 | 394 | other = parent.left 395 | if other.color == RED { 396 | 397 | other.color = BLACK 398 | parent.color = RED 399 
| r.rotateRight(parent) 400 | other = parent.left 401 | } 402 | 403 | if (other.left == nil || other.left.color == BLACK) && (other.right == nil || other.right.color == BLACK) { 404 | other.color = RED 405 | n = parent 406 | parent = n.parent 407 | } else { 408 | if other.left == nil || other.left.color == BLACK { 409 | 410 | oright := other.right 411 | if oright != nil { 412 | oright.color = BLACK 413 | } 414 | other.color = RED 415 | r.rotateLeft(other) 416 | other = parent.left 417 | } 418 | other.color = parent.color 419 | parent.color = BLACK 420 | if other.left != nil { 421 | other.left.color = BLACK 422 | } 423 | r.rotateRight(parent) 424 | 425 | n = r.node 426 | break 427 | } 428 | 429 | } 430 | } 431 | if n != nil { 432 | n.color = BLACK 433 | } 434 | 435 | } 436 | 437 | func (r *RBTree[K, V]) Delete(k K) { 438 | n := r.root.node 439 | for n != nil { 440 | if n.key == k { 441 | goto found 442 | } 443 | 444 | if k > n.key { 445 | n = n.right 446 | } else { 447 | n = n.left 448 | } 449 | } 450 | return 451 | 452 | found: 453 | r.root.erase(n) 454 | return 455 | } 456 | 457 | func (r *RBTree[K, V]) Len() int { 458 | return r.length 459 | } 460 | 461 | func (r *RBTree[K, V]) TopMin(limit int, callback func(k K, v V) bool) { 462 | 463 | r.Range(func(k K, v V) bool { 464 | 465 | if limit <= 0 { 466 | return false 467 | } 468 | 469 | if !callback(k, v) { 470 | return false 471 | } 472 | 473 | limit-- 474 | return true 475 | }) 476 | } 477 | 478 | // 遍历rbtree 479 | func (r *RBTree[K, V]) RangePrev(callback func(k K, v V) bool) { 480 | // 遍历 481 | if r.root.node == nil { 482 | return 483 | } 484 | 485 | r.root.node.rangePrevInner(callback) 486 | return 487 | } 488 | 489 | func (r *RBTree[K, V]) TopMax(limit int, callback func(k K, v V) bool) { 490 | 491 | r.RangePrev(func(k K, v V) bool { 492 | 493 | if limit <= 0 { 494 | return false 495 | } 496 | 497 | if !callback(k, v) { 498 | return false 499 | } 500 | 501 | limit-- 502 | return true 503 | }) 504 | } 505 | 506 
| func (n *node[K, V]) rangePrevInner(callback func(k K, v V) bool) bool { 507 | 508 | if n == nil { 509 | return true 510 | } 511 | 512 | if n.right != nil { 513 | if !n.right.rangePrevInner(callback) { 514 | return false 515 | } 516 | } 517 | 518 | if !callback(n.key, n.val) { 519 | return false 520 | } 521 | 522 | if n.left != nil { 523 | if !n.left.rangePrevInner(callback) { 524 | return false 525 | } 526 | } 527 | 528 | return true 529 | } 530 | 531 | func (n *node[K, V]) rangeInner(callback func(k K, v V) bool) bool { 532 | 533 | if n.left != nil { 534 | if !n.left.rangeInner(callback) { 535 | return false 536 | } 537 | } 538 | 539 | if !callback(n.key, n.val) { 540 | return false 541 | } 542 | 543 | if n.right != nil { 544 | if !n.right.rangeInner(callback) { 545 | return false 546 | } 547 | } 548 | return true 549 | } 550 | 551 | // 遍历rbtree 552 | func (a *RBTree[K, V]) Range(callback func(k K, v V) bool) { 553 | // 遍历 554 | if a.root.node == nil { 555 | return 556 | } 557 | 558 | a.root.node.rangeInner(callback) 559 | return 560 | } 561 | -------------------------------------------------------------------------------- /rbtree/rbtree_bench_test.go: -------------------------------------------------------------------------------- 1 | package rbtree 2 | 3 | // apache 2.0 antlabs 4 | 5 | // b.N = 500w 6 | // goos: darwin 7 | // goarch: amd64 8 | // pkg: github.com/antlabs/gstl/rbtree 9 | // cpu: Intel(R) Core(TM) i7-1068NG7 CPU @ 2.30GHz 10 | // BenchmarkGetAsc-8 1000000000 0.3336 ns/op 11 | // BenchmarkGetDesc-8 1000000000 0.3702 ns/op 12 | // BenchmarkGetStd-8 13 | // 1000000000 0.8940 ns/op 14 | // PASS 15 | // ok github.com/antlabs/gstl/rbtree 139.415s 16 | 17 | // b.N = 3kw 18 | // goos: darwin 19 | // goarch: arm64 20 | // pkg: github.com/antlabs/gstl/rbtree 21 | // BenchmarkGetAsc-8 32662837 40.22 ns/op 22 | // BenchmarkGetDesc-8 33250437 39.52 ns/op 23 | // BenchmarkGetStd-8 29353758 49.73 ns/op 24 | // PASS 25 | // ok github.com/antlabs/gstl/rbtree 
13.030s 26 | 27 | import ( 28 | "fmt" 29 | "testing" 30 | ) 31 | 32 | func BenchmarkSetAsc(b *testing.B) { 33 | //max := 1000000.0 * 5 34 | set := New[float64, float64]() 35 | max := float64(b.N) 36 | for i := 0.0; i < max; i++ { 37 | set.Set(i, i) 38 | } 39 | 40 | } 41 | 42 | func BenchmarkGetAsc(b *testing.B) { 43 | //max := 1000000.0 * 5 44 | set := New[float64, float64]() 45 | max := float64(b.N) 46 | for i := 0.0; i < max; i++ { 47 | set.Set(i, i) 48 | } 49 | 50 | b.ResetTimer() 51 | 52 | for i := 0.0; i < max; i++ { 53 | v := set.Get(i) 54 | if v != i { 55 | panic(fmt.Sprintf("need:%f, got:%f", i, v)) 56 | } 57 | } 58 | } 59 | 60 | func BenchmarkGetDesc(b *testing.B) { 61 | max := float64(b.N) 62 | //max := 1000000.0 * 5 63 | set := New[float64, float64]() 64 | for i := max; i >= 0; i-- { 65 | set.Set(i, i) 66 | } 67 | 68 | b.ResetTimer() 69 | 70 | for i := 0.0; i < max; i++ { 71 | v := set.Get(i) 72 | if v != i { 73 | panic(fmt.Sprintf("need:%f, got:%f", i, v)) 74 | } 75 | } 76 | } 77 | 78 | func BenchmarkGetStd(b *testing.B) { 79 | 80 | max := float64(b.N) 81 | //max := 1000000.0 * 5 82 | set := make(map[float64]float64, int(max)) 83 | for i := 0.0; i < max; i++ { 84 | set[i] = i 85 | } 86 | 87 | b.ResetTimer() 88 | 89 | for i := 0.0; i < max; i++ { 90 | v := set[i] 91 | if v != i { 92 | panic(fmt.Sprintf("need:%f, got:%f", i, v)) 93 | } 94 | } 95 | } 96 | 97 | func BenchmarkSet(b *testing.B) { 98 | max := float64(b.N) 99 | //max := 1000000.0 * 5 100 | set := New[float64, float64]() 101 | for i := max; i >= 0; i-- { 102 | set.Set(i, i) 103 | } 104 | 105 | } 106 | 107 | func BenchmarkSetStd(b *testing.B) { 108 | max := float64(b.N) 109 | //max := 1000000.0 * 5 110 | set := make(map[float64]float64) 111 | for i := max; i >= 0.0; i-- { 112 | set[i] = i 113 | } 114 | 115 | } 116 | -------------------------------------------------------------------------------- /rbtree/rbtree_test.go: 
-------------------------------------------------------------------------------- 1 | package rbtree 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/antlabs/gstl/cmp" 7 | "github.com/antlabs/gstl/vec" 8 | ) 9 | 10 | // 从小到大, 插入 11 | func Test_SetAndGet(t *testing.T) { 12 | b := New[int, int]() 13 | max := 1000 14 | for i := 0; i < max; i++ { 15 | b.Swap(i, i) 16 | } 17 | 18 | for i := 0; i < max; i++ { 19 | v, ok := b.TryGet(i) 20 | if !ok { 21 | t.Errorf("Expected true, got false for key %d", i) 22 | } 23 | if v != i { 24 | t.Errorf("Expected %d, got %d for key %d", i, v, i) 25 | } 26 | } 27 | } 28 | 29 | // 从大到小, 插入 30 | func Test_SetAndGet2(t *testing.T) { 31 | b := New[int, int]() 32 | max := 1000 33 | for i := max; i >= 0; i-- { 34 | b.Swap(i, i) 35 | } 36 | 37 | for i := max; i >= 0; i-- { 38 | v, ok := b.TryGet(i) 39 | if !ok { 40 | t.Errorf("Expected true, got false for key %d", i) 41 | } 42 | if v != i { 43 | t.Errorf("Expected %d, got %d for key %d", i, v, i) 44 | } 45 | } 46 | } 47 | 48 | // 测试avltree删除的情况, 少量数量 49 | func Test_RBTree_Delete1(t *testing.T) { 50 | for max := 3; max < 1000; max++ { 51 | b := New[int, int]() 52 | 53 | // 设置0-max 54 | for i := 0; i < max; i++ { 55 | b.Set(i, i) 56 | } 57 | 58 | // 删除0-max/2 59 | for i := 0; i < max/2; i++ { 60 | b.Delete(i) 61 | } 62 | 63 | // max/2-max应该能找到 64 | for i := max / 2; i < max; i++ { 65 | v, ok := b.TryGet(i) 66 | if !ok { 67 | t.Errorf("Expected true, got false for key %d", i) 68 | } 69 | if v != i { 70 | t.Errorf("Expected %d, got %d for key %d", i, v, i) 71 | } 72 | } 73 | 74 | // 0-max/2应该找不到 75 | for i := 0; i < max/2; i++ { 76 | v, ok := b.TryGet(i) 77 | if ok { 78 | t.Errorf("Expected false, got true for key %d", i) 79 | } 80 | if v != 0 { 81 | t.Errorf("Expected 0, got %d for key %d", v, i) 82 | } 83 | } 84 | } 85 | } 86 | 87 | // 测试TopMax, 返回最大的几个数据降序返回 88 | func Test_RBTree_TopMax(t *testing.T) { 89 | need := [3][]int{} 90 | count10 := 10 91 | count100 := 100 92 | count1000 := 1000 93 
| count := []int{count10, count100, count1000} 94 | 95 | for i := 0; i < len(count); i++ { 96 | for j, k := count[i]-1, count100-1; j >= 0 && k >= 0; j-- { 97 | need[i] = append(need[i], j) 98 | k-- 99 | } 100 | } 101 | 102 | for i, b := range []*RBTree[int, int]{ 103 | // btree里面元素 少于 TopMax 需要返回的值 104 | func() *RBTree[int, int] { 105 | b := New[int, int]() 106 | for i := 0; i < count10; i++ { 107 | b.Set(i, i) 108 | } 109 | return b 110 | }(), 111 | // btree里面元素 等于 TopMax 需要返回的值 112 | func() *RBTree[int, int] { 113 | b := New[int, int]() 114 | for i := 0; i < count100; i++ { 115 | b.Set(int(i), i) 116 | } 117 | return b 118 | }(), 119 | // btree里面元素 大于 TopMax 需要返回的值 120 | func() *RBTree[int, int] { 121 | b := New[int, int]() 122 | for i := 0; i < count1000; i++ { 123 | b.Set(int(i), i) 124 | } 125 | return b 126 | }(), 127 | } { 128 | var key, val []int 129 | b.TopMax(count100, func(k int, v int) bool { 130 | key = append(key, int(k)) 131 | val = append(val, v) 132 | return true 133 | }) 134 | length := cmp.Min(count[i], len(need[i])) 135 | if !slicesEqual(key, need[i][:length]) { 136 | t.Errorf("Expected %v, got %v", need[i][:length], key) 137 | } 138 | if !slicesEqual(val, need[i][:length]) { 139 | t.Errorf("Expected %v, got %v", need[i][:length], val) 140 | } 141 | } 142 | } 143 | 144 | // 测试TopMin, 它返回最小的几个值 145 | func Test_RBTree_TopMin(t *testing.T) { 146 | need := []int{} 147 | count10 := 10 148 | count100 := 100 149 | count1000 := 1000 150 | 151 | for i := 0; i < count1000; i++ { 152 | need = append(need, i) 153 | } 154 | 155 | needCount := []int{count10, count100, count100} 156 | for i, b := range []*RBTree[int, int]{ 157 | // btree里面元素 少于 TopMin 需要返回的值 158 | func() *RBTree[int, int] { 159 | b := New[int, int]() 160 | for i := 0; i < count10; i++ { 161 | b.Set(i, i) 162 | } 163 | return b 164 | }(), 165 | // btree里面元素 等于 TopMin 需要返回的值 166 | func() *RBTree[int, int] { 167 | b := New[int, int]() 168 | for i := 0; i < count100; i++ { 169 | b.Set(i, i) 170 | 
} 171 | return b 172 | }(), 173 | // btree里面元素 大于 TopMin 需要返回的值 174 | func() *RBTree[int, int] { 175 | b := New[int, int]() 176 | for i := 0; i < count1000; i++ { 177 | b.Set(i, i) 178 | } 179 | return b 180 | }(), 181 | } { 182 | var key, val []int 183 | b.TopMin(count100, func(k, v int) bool { 184 | key = append(key, k) 185 | val = append(val, v) 186 | return true 187 | }) 188 | if !slicesEqual(key, need[:needCount[i]]) { 189 | t.Errorf("Expected %v, got %v", need[:needCount[i]], key) 190 | } 191 | if !slicesEqual(val, need[:needCount[i]]) { 192 | t.Errorf("Expected %v, got %v", need[:needCount[i]], val) 193 | } 194 | } 195 | } 196 | 197 | func Test_RanePrev(t *testing.T) { 198 | a := New[int, int]() 199 | data := []int{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10} 200 | 201 | dataRev := vec.New(data...).Clone().Rev().ToSlice() 202 | for i := len(data) / 2; i >= 0; i-- { 203 | a.Set(i, i) 204 | } 205 | 206 | for i := len(data)/2 + 1; i < len(data); i++ { 207 | a.Set(i, i) 208 | } 209 | 210 | var gotKey []int 211 | var gotVal []int 212 | a.RangePrev(func(k, v int) bool { 213 | gotKey = append(gotKey, k) 214 | gotVal = append(gotVal, v) 215 | return true 216 | }) 217 | 218 | if !slicesEqual(gotKey, dataRev) { 219 | t.Errorf("Expected %v, got %v", dataRev, gotKey) 220 | } 221 | if !slicesEqual(gotVal, dataRev) { 222 | t.Errorf("Expected %v, got %v", dataRev, gotVal) 223 | } 224 | } 225 | 226 | func Test_RBTree_InsertOrUpdate(t *testing.T) { 227 | b := New[int, int]() 228 | max := 100 229 | 230 | // Insert elements 231 | for i := 0; i < max; i++ { 232 | b.InsertOrUpdate(i, i, func(prev, new int) int { 233 | return prev + new 234 | }) 235 | } 236 | 237 | // Update elements 238 | for i := 0; i < max; i++ { 239 | b.InsertOrUpdate(i, i, func(prev, new int) int { 240 | return prev + new 241 | }) 242 | } 243 | 244 | // Verify elements 245 | for i := 0; i < max; i++ { 246 | v, ok := b.TryGet(i) 247 | if !ok || v != i*2 { 248 | t.Errorf("expected %d, got %v", i*2, v) 249 | } 250 | } 251 
| } 252 | 253 | // Helper function to compare slices 254 | func slicesEqual[T comparable](a, b []T) bool { 255 | if len(a) != len(b) { 256 | return false 257 | } 258 | for i := range a { 259 | if a[i] != b[i] { 260 | return false 261 | } 262 | } 263 | return true 264 | } 265 | -------------------------------------------------------------------------------- /rhashmap/opt.go: -------------------------------------------------------------------------------- 1 | package rhashmap 2 | 3 | // apache 2.0 antlabs 4 | type Option interface { 5 | apply(*config) 6 | } 7 | 8 | type hashFunc func(str string) uint64 9 | 10 | func (h hashFunc) apply(c *config) { 11 | c.hashFunc = h 12 | } 13 | 14 | func WithHashFunc(hfunc func(str string) uint64) Option { 15 | return hashFunc(hfunc) 16 | } 17 | 18 | type withCap int 19 | 20 | func (wc withCap) apply(c *config) { 21 | c.cap = int(wc) 22 | } 23 | 24 | func WithCap(cap int) Option { 25 | return withCap(cap) 26 | } 27 | -------------------------------------------------------------------------------- /rhashmap/rhashmap.go: -------------------------------------------------------------------------------- 1 | package rhashmap 2 | 3 | // apache 2.0 antlabs 4 | // 参考资料 5 | // https://github.com/redis/redis/blob/unstable/src/dict.c 6 | import ( 7 | "errors" 8 | "math" 9 | "reflect" 10 | "unsafe" 11 | 12 | "github.com/antlabs/gstl/api" 13 | xxhash "github.com/cespare/xxhash/v2" 14 | ) 15 | 16 | var _ api.Map[int, int] = (*HashMap[int, int])(nil) 17 | 18 | const ( 19 | HT_INITIAL_EXP = 2 20 | HT_INITIAL_SIZE = (1 << (HT_INITIAL_EXP)) 21 | ) 22 | 23 | var forceResizeRatio = 5 24 | 25 | var ( 26 | ErrHashing = errors.New("rehashing...") 27 | ErrSize = errors.New("wrong size") 28 | ErrNotFound = errors.New("not found") 29 | ) 30 | 31 | // 元素 32 | type entry[K comparable, V any] struct { 33 | key K 34 | val V 35 | next *entry[K, V] 36 | } 37 | 38 | type config struct { 39 | hashFunc func(str string) uint64 40 | cap int 41 | } 42 | 43 | // hash 表头 
44 | type HashMap[K comparable, V any] struct { 45 | // 大多数情况, table[0]里就存在hash表元素的数据 46 | // 大小一尘不变hash随着数据的增强效率会降低, rhashmap的实现是超过某阈值时 47 | // table[1] 会先放新申请的hash表元素, 当table[0]都移动到table[1]时, table[1]赋值给table[0], 完成一次hash扩容 48 | // 移动的操作都分摊到Get, Set, Delete操作中, 每次移动一个槽位, 或者跳运100个空桶(TODO修改代码, 需要修改这边的注释) 49 | table [2][]*entry[K, V] //hash table 50 | used [2]uint64 // 记录每个table里面存在的元素个数 51 | sizeExp [2]int8 //记录exp 52 | 53 | rehashidx int // rehashid目前的槽位 54 | keySize int //key的长度 55 | config 56 | isKeyStr bool //是string类型的key, 或者不是 57 | init bool 58 | } 59 | 60 | // 初始化一个hashtable 61 | func New[K comparable, V any]() *HashMap[K, V] { 62 | h := &HashMap[K, V]{} 63 | h.Init() 64 | return h 65 | } 66 | 67 | func (h *HashMap[K, V]) Init() { 68 | 69 | h.rehashidx = -1 70 | h.hashFunc = xxhash.Sum64String 71 | h.init = true 72 | 73 | h.reset(0) 74 | h.reset(1) 75 | h.keyTypeAndKeySize() 76 | } 77 | 78 | func (h *HashMap[K, V]) lazyinit() { 79 | if !h.init { 80 | h.Init() 81 | } 82 | } 83 | 84 | // 初始化一个hashtable并且可以设置值 85 | func NewWithOpt[K comparable, V any](opts ...Option) *HashMap[K, V] { 86 | h := New[K, V]() 87 | for _, o := range opts { 88 | o.apply(&h.config) 89 | } 90 | 91 | if h.cap > 0 { 92 | h.Resize(uint64(h.cap)) 93 | } 94 | return h 95 | } 96 | 97 | // 保存key的类型和key的长度 98 | func (h *HashMap[K, V]) keyTypeAndKeySize() { 99 | var k K 100 | switch (interface{})(k).(type) { 101 | case string: 102 | h.isKeyStr = true 103 | default: 104 | h.keySize = int(unsafe.Sizeof(k)) 105 | } 106 | } 107 | 108 | // 计算hash值 109 | func (h *HashMap[K, V]) calHash(k K) uint64 { 110 | var key string 111 | 112 | if h.isKeyStr { 113 | // 直接赋值会报错, 使用unsafe绕过编译器检查 114 | key = *(*string)(unsafe.Pointer(&k)) 115 | } else { 116 | // 因为xxhash.Sum64String 接收string, 所以要把非string类型变量当成string类型来处理 117 | key = *(*string)(unsafe.Pointer(&reflect.StringHeader{ 118 | Data: uintptr(unsafe.Pointer(&k)), 119 | Len: h.keySize, 120 | })) 121 | } 122 | 123 | return xxhash.Sum64String(key) 124 | } 125 | 
126 | func (h *HashMap[K, V]) isRehashing() bool { 127 | return h.rehashidx != -1 128 | } 129 | 130 | // TODO 这个函数可以优化下 131 | func nextExp(size uint64) int8 { 132 | if size >= math.MaxUint64 { 133 | return 63 134 | } 135 | 136 | e := int8(HT_INITIAL_EXP) 137 | for { 138 | if 1<= size { 139 | return e 140 | } 141 | e++ 142 | } 143 | 144 | return e 145 | } 146 | 147 | func (h *HashMap[K, V]) expand() error { 148 | if h.isRehashing() { 149 | return nil 150 | } 151 | 152 | if hashSize(h.sizeExp[0]) == 0 { 153 | return h.Resize(HT_INITIAL_SIZE) 154 | } 155 | 156 | if h.used[0] >= hashSize(h.sizeExp[0]) || h.used[0]/hashSize(h.sizeExp[0]) > uint64(forceResizeRatio) { 157 | return h.Resize(h.used[0] + 1) 158 | } 159 | 160 | return nil 161 | } 162 | 163 | // 手动修改hashtable的大小 164 | func (h *HashMap[K, V]) Resize(size uint64) error { 165 | h.lazyinit() 166 | // 如果正在扩容中, 或者需要扩容的数据小于已存在的元素, 直接返回 167 | if h.isRehashing() || h.used[0] > uint64(size) { 168 | return ErrHashing 169 | } 170 | 171 | newSizeExp := nextExp(uint64(size)) 172 | // 新大小比需要的大小还小 173 | newSize := uint64(1 << newSizeExp) 174 | if newSize < size { 175 | return ErrSize 176 | } 177 | 178 | // 新扩容大小和以前的一样 179 | if uint64(newSizeExp) == uint64(h.sizeExp[0]) { 180 | return nil 181 | } 182 | 183 | newTable := make([]*entry[K, V], newSize) 184 | 185 | // 第一次初始化 186 | if h.table[0] == nil { 187 | h.sizeExp[0] = newSizeExp 188 | h.table[0] = newTable 189 | return nil 190 | } 191 | 192 | // 把新hash表放到table[1]里面 193 | h.sizeExp[1] = newSizeExp 194 | h.used[1] = 0 195 | h.table[1] = newTable 196 | h.rehashidx = 0 197 | return nil 198 | } 199 | 200 | // 收缩hash table 201 | func (h *HashMap[K, V]) ShrinkToFit() error { 202 | h.lazyinit() 203 | if h.isRehashing() { 204 | return ErrHashing 205 | } 206 | 207 | minimal := h.used[0] 208 | if minimal < HT_INITIAL_SIZE { 209 | minimal = HT_INITIAL_SIZE 210 | } 211 | 212 | return h.Resize(minimal) 213 | } 214 | 215 | // 返回索引值和entry 216 | func (h *HashMap[K, V]) findIndexAndEntry(key 
K) (i uint64, e *entry[K, V], err error) { 217 | if err := h.expand(); err != nil { 218 | return 0, nil, err 219 | } 220 | 221 | hashCode := h.calHash(key) 222 | idx := uint64(0) 223 | for table := 0; table < 2; table++ { 224 | idx = hashCode & sizeMask(h.sizeExp[table]) 225 | head := h.table[table][idx] 226 | for head != nil { 227 | if key == head.key { 228 | return idx, head, nil 229 | } 230 | 231 | head = head.next 232 | } 233 | 234 | if !h.isRehashing() { 235 | break 236 | } 237 | } 238 | 239 | return idx, nil, nil 240 | } 241 | 242 | func (h *HashMap[K, V]) rehash(n int) error { 243 | // 控制访问空槽位的个数 244 | emptyVisits := n * 10 245 | 246 | // 没有rehashing 就退出 247 | if !h.isRehashing() { 248 | return ErrHashing 249 | } 250 | 251 | // n是控制桶数 252 | for ; n > 0 && h.used[0] != 0; n-- { 253 | 254 | for h.table[0][h.rehashidx] == nil { 255 | h.rehashidx++ 256 | emptyVisits-- 257 | if emptyVisits == 0 { 258 | return nil 259 | } 260 | } 261 | 262 | // 取出hash槽中第一个元素 263 | head := h.table[0][h.rehashidx] 264 | for head != nil { 265 | next := head.next 266 | newIdx := h.calHash(head.key) & sizeMask(h.sizeExp[1]) 267 | head.next = h.table[1][newIdx] 268 | h.table[1][newIdx] = head 269 | h.used[0]-- 270 | h.used[1]++ 271 | head = next 272 | } 273 | 274 | h.table[0][h.rehashidx] = nil 275 | h.rehashidx++ 276 | } 277 | 278 | if h.used[0] == 0 { 279 | h.table[0] = h.table[1] 280 | h.used[0] = h.used[1] 281 | h.sizeExp[0] = h.sizeExp[1] 282 | h.reset(1) 283 | // 这里重装置为-1 284 | h.rehashidx = -1 285 | } 286 | return nil 287 | } 288 | 289 | func (h *HashMap[K, V]) reset(idx int) { 290 | h.table[idx] = nil 291 | h.sizeExp[idx] = -1 292 | h.used[idx] = 0 293 | } 294 | 295 | func hashSize(exp int8) uint64 { 296 | if exp == -1 { 297 | return 0 298 | } 299 | return 1 << exp 300 | } 301 | 302 | func sizeMask(exp int8) uint64 { 303 | if exp == -1 { 304 | return 0 305 | } 306 | 307 | return (1 << exp) - 1 308 | } 309 | 310 | // 获取 311 | func (h *HashMap[K, V]) TryGet(key K) (v V, ok bool) { 
312 | if h.Len() == 0 { 313 | return 314 | } 315 | 316 | if h.isRehashing() { 317 | h.rehash(1) 318 | } 319 | 320 | hashCode := h.calHash(key) 321 | idx := uint64(0) 322 | for table := 0; table < 2; table++ { 323 | idx = hashCode & sizeMask(h.sizeExp[table]) 324 | head := h.table[table][idx] 325 | for head != nil { 326 | if key == head.key { 327 | return head.val, true 328 | } 329 | 330 | head = head.next 331 | } 332 | 333 | if !h.isRehashing() { 334 | break 335 | } 336 | } 337 | return 338 | } 339 | 340 | // 获取 341 | func (h *HashMap[K, V]) Get(key K) (v V) { 342 | v, _ = h.TryGet(key) 343 | return 344 | } 345 | 346 | // 遍历 347 | func (h *HashMap[K, V]) Range(pr func(key K, val V) bool) { 348 | if h.Len() == 0 { 349 | //err = ErrNotFound 350 | return 351 | } 352 | 353 | if h.isRehashing() { 354 | h.rehash(1) 355 | } 356 | 357 | length := h.Len() 358 | for table := 0; table < 2 && length > 0; table++ { 359 | 360 | for idx := 0; idx < len(h.table[table]); idx++ { 361 | head := h.table[table][idx] 362 | for head != nil { 363 | if !pr(head.key, head.val) { 364 | return 365 | } 366 | head = head.next 367 | } 368 | 369 | length-- 370 | } 371 | if !h.isRehashing() { 372 | break 373 | } 374 | } 375 | } 376 | 377 | func (h *HashMap[K, V]) Set(k K, v V) { 378 | h.Swap(k, v) 379 | } 380 | 381 | // 设置 382 | func (h *HashMap[K, V]) Swap(k K, v V) (prev V, replaced bool) { 383 | h.lazyinit() 384 | if h.isRehashing() { 385 | h.rehash(1) 386 | } 387 | 388 | index, e, err := h.findIndexAndEntry(k) 389 | if err != nil { 390 | return 391 | } 392 | 393 | idx := 0 394 | if h.isRehashing() { 395 | //如果在rehasing过程中, 如果这个key是第一次存入到hash table, 优先写入到新hash table中 396 | idx = 1 397 | } 398 | 399 | // element存在, 这里是替换 400 | if e != nil { 401 | //e.key = k 402 | prev = e.val 403 | e.val = v 404 | return prev, true 405 | } 406 | 407 | e = &entry[K, V]{key: k, val: v} 408 | e.next = h.table[idx][index] 409 | h.table[idx][index] = e 410 | h.used[idx]++ 411 | return 412 | } 413 | 414 | type 
InsertOrUpdateCb[V any] func(prev V, new V) V 415 | 416 | // InsertOrUpdate inserts or updates an element in the HashMap 417 | func (h *HashMap[K, V]) InsertOrUpdate(k K, v V, cb InsertOrUpdateCb[V]) { 418 | if prev, ok := h.TryGet(k); ok { 419 | v = cb(prev, v) 420 | } 421 | h.Set(k, v) 422 | } 423 | 424 | // Remove是delete别名 425 | func (h *HashMap[K, V]) Delete(key K) { 426 | h.Remove(key) 427 | } 428 | 429 | // 删除 430 | func (h *HashMap[K, V]) Remove(key K) (err error) { 431 | if h.Len() == 0 { 432 | err = ErrNotFound 433 | return 434 | } 435 | 436 | if h.isRehashing() { 437 | h.rehash(1) 438 | } 439 | 440 | hashCode := h.calHash(key) 441 | idx := uint64(0) 442 | for table := 0; table < 2; table++ { 443 | idx = hashCode & sizeMask(h.sizeExp[table]) 444 | var prev *entry[K, V] 445 | head := h.table[table][idx] 446 | for head != nil { 447 | if key == head.key { 448 | if prev != nil { 449 | // 使用双指针删除中间的元素 450 | prev.next = head.next 451 | } else { 452 | // 表头元素, 直接跳过就可以删除 453 | h.table[table][idx] = head.next 454 | } 455 | h.used[table]-- 456 | return nil 457 | } 458 | 459 | prev = head 460 | head = head.next 461 | } 462 | 463 | if !h.isRehashing() { 464 | break 465 | } 466 | } 467 | return nil 468 | } 469 | 470 | // 测试长度 471 | func (h *HashMap[K, V]) Len() int { 472 | return int(h.used[0] + h.used[1]) 473 | } 474 | -------------------------------------------------------------------------------- /rhashmap/rhashmap_bench_test.go: -------------------------------------------------------------------------------- 1 | package rhashmap 2 | 3 | // apache 2.0 antlabs 4 | import ( 5 | "fmt" 6 | "testing" 7 | ) 8 | 9 | // goos: darwin 10 | // goarch: amd64 11 | // pkg: github.com/antlabs/gstl/rhashmap 12 | // cpu: Intel(R) Core(TM) i7-1068NG7 CPU @ 2.30GHz 13 | // BenchmarkGet-8 1000000000 0.4066 ns/op 14 | // BenchmarkGetStd-8 1000000000 0.8333 ns/op 15 | // PASS 16 | // ok github.com/antlabs/gstl/rhashmap 130.007s. 17 | // 比标准库快一倍. 
18 | 19 | // goos: darwin 20 | // goarch: amd64 21 | // pkg: github.com/antlabs/gstl/rhashmap 22 | // cpu: Intel(R) Core(TM) i7-1068NG7 CPU @ 2.30GHz 23 | // BenchmarkSet-8 1000000000 0.1690 ns/op 24 | // BenchmarkSetStd-8 1000000000 0.1470 ns/op 25 | // PASS 26 | // ok github.com/antlabs/gstl/rhashmap 3.970s 27 | // 五百万数据的Get操作时间 28 | 29 | // TODO 再优化下性能 30 | // go1.19.1 31 | // 3kw 32 | // goos: darwin 33 | // goarch: arm64 34 | // pkg: github.com/antlabs/gstl/rhashmap 35 | // BenchmarkGet-8 34664005 62.20 ns/op 36 | // BenchmarkGetStd-8 30007470 49.40 ns/op 37 | // BenchmarkSet-8 14623854 178.9 ns/op 38 | // BenchmarkSetStd-8 22709601 74.71 ns/op 39 | // PASS 40 | // ok github.com/antlabs/gstl/rhashmap 16.521s 41 | func BenchmarkGet(b *testing.B) { 42 | //max := 1000000.0 * 5 43 | max := float64(b.N) 44 | set := NewWithOpt[float64, float64](WithCap(int(max))) 45 | for i := 0.0; i < max; i++ { 46 | set.Set(i, i) 47 | } 48 | 49 | b.ResetTimer() 50 | 51 | for i := 0.0; i < max; i++ { 52 | v := set.Get(i) 53 | if v != i { 54 | panic(fmt.Sprintf("need:%f, got:%f", i, v)) 55 | } 56 | } 57 | } 58 | 59 | func BenchmarkGetStd(b *testing.B) { 60 | 61 | max := float64(b.N) 62 | set := make(map[float64]float64, int(max)) 63 | for i := 0.0; i < max; i++ { 64 | set[i] = i 65 | } 66 | 67 | b.ResetTimer() 68 | 69 | for i := 0.0; i < max; i++ { 70 | v := set[i] 71 | if v != i { 72 | panic(fmt.Sprintf("need:%f, got:%f", i, v)) 73 | } 74 | } 75 | } 76 | 77 | // gstl set 78 | func BenchmarkSet(b *testing.B) { 79 | max := float64(b.N) 80 | set := NewWithOpt[float64, float64](WithCap(int(max))) 81 | for i := 0.0; i < max; i++ { 82 | set.Set(i, i) 83 | } 84 | 85 | } 86 | 87 | // 标准库set 88 | func BenchmarkSetStd(b *testing.B) { 89 | 90 | max := float64(b.N) 91 | set := make(map[float64]float64, int(max)) 92 | for i := 0.0; i < max; i++ { 93 | set[i] = i 94 | } 95 | } 96 | -------------------------------------------------------------------------------- /rhashmap/rhashmap_test.go: 
-------------------------------------------------------------------------------- 1 | package rhashmap 2 | 3 | import ( 4 | "sort" 5 | "testing" 6 | ) 7 | 8 | // 1. set get测试 9 | // key string value bool 10 | func Test_SetGet_StringBool(t *testing.T) { 11 | hm := New[string, bool]() 12 | hm.Set("hello", true) 13 | hm.Set("world", true) 14 | hm.Set("ni", true) 15 | hm.Set("hao", true) 16 | 17 | if !hm.Get("hello") { 18 | t.Errorf("Expected true, got false for key 'hello'") 19 | } 20 | if !hm.Get("world") { 21 | t.Errorf("Expected true, got false for key 'world'") 22 | } 23 | if !hm.Get("ni") { 24 | t.Errorf("Expected true, got false for key 'ni'") 25 | } 26 | if !hm.Get("hao") { 27 | t.Errorf("Expected true, got false for key 'hao'") 28 | } 29 | } 30 | 31 | // 1. set get测试 32 | // key string, value string 33 | func Test_SetGet_StringString(t *testing.T) { 34 | hm := New[string, string]() 35 | hm.Set("hello", "hello") 36 | hm.Set("world", "world") 37 | hm.Set("ni", "ni") 38 | hm.Set("hao", "hao") 39 | 40 | if hm.Get("hello") != "hello" { 41 | t.Errorf("Expected 'hello', got %v for key 'hello'", hm.Get("hello")) 42 | } 43 | if hm.Get("world") != "world" { 44 | t.Errorf("Expected 'world', got %v for key 'world'", hm.Get("world")) 45 | } 46 | if hm.Get("ni") != "ni" { 47 | t.Errorf("Expected 'ni', got %v for key 'ni'", hm.Get("ni")) 48 | } 49 | if hm.Get("hao") != "hao" { 50 | t.Errorf("Expected 'hao', got %v for key 'hao'", hm.Get("hao")) 51 | } 52 | } 53 | 54 | // 1. 
set get测试 55 | // key string value string 56 | func Test_SetGet_IntString(t *testing.T) { 57 | hm := New[int, string]() 58 | hm.Set(1, "hello") 59 | hm.Set(2, "world") 60 | hm.Set(3, "ni") 61 | hm.Set(4, "hao") 62 | 63 | if hm.Get(1) != "hello" { 64 | t.Errorf("Expected 'hello', got %v for key 1", hm.Get(1)) 65 | } 66 | if hm.Get(2) != "world" { 67 | t.Errorf("Expected 'world', got %v for key 2", hm.Get(2)) 68 | } 69 | if hm.Get(3) != "ni" { 70 | t.Errorf("Expected 'ni', got %v for key 3", hm.Get(3)) 71 | } 72 | if hm.Get(4) != "hao" { 73 | t.Errorf("Expected 'hao', got %v for key 4", hm.Get(4)) 74 | } 75 | } 76 | 77 | // 1. set get测试 78 | func Test_SetGet_IntString_Lazyinit(t *testing.T) { 79 | var hm HashMap[int, string] 80 | hm.Set(1, "hello") 81 | hm.Set(2, "world") 82 | hm.Set(3, "ni") 83 | hm.Set(4, "hao") 84 | 85 | if hm.Get(1) != "hello" { 86 | t.Errorf("Expected 'hello', got %v for key 1", hm.Get(1)) 87 | } 88 | if hm.Get(2) != "world" { 89 | t.Errorf("Expected 'world', got %v for key 2", hm.Get(2)) 90 | } 91 | if hm.Get(3) != "ni" { 92 | t.Errorf("Expected 'ni', got %v for key 3", hm.Get(3)) 93 | } 94 | if hm.Get(4) != "hao" { 95 | t.Errorf("Expected 'hao', got %v for key 4", hm.Get(4)) 96 | } 97 | } 98 | 99 | // 1. set get测试 100 | // 设计重复key 101 | func Test_SetGet_Replace_IntString(t *testing.T) { 102 | hm := New[int, string]() 103 | hm.Set(1, "hello") 104 | hm.Set(1, "world") 105 | 106 | if hm.Get(1) != "world" { 107 | t.Errorf("Expected 'world', got %v for key 1", hm.Get(1)) 108 | } 109 | } 110 | 111 | // 1. 
set get测试 112 | // 获取空值数据 113 | func Test_SetGet_Zero(t *testing.T) { 114 | hm := New[int, int]() 115 | for i := 0; i < 10; i++ { 116 | if hm.Get(i) != 0 { 117 | t.Errorf("Expected 0, got %v for key %d", hm.Get(i), i) 118 | } 119 | } 120 | 121 | for i := 0; i < 10; i++ { 122 | v, err := hm.TryGet(i) 123 | if err { 124 | t.Errorf("Expected false, got true for key %d", i) 125 | } 126 | if v != 0 { 127 | t.Errorf("Expected 0, got %v for key %d", v, i) 128 | } 129 | } 130 | } 131 | 132 | // 1. set get测试 133 | // 测试重复key 134 | func Test_SetGet_NotFound(t *testing.T) { 135 | hm := New[int, string]() 136 | hm.Set(1, "hello") 137 | hm.Set(1, "world") 138 | 139 | _, err := hm.TryGet(3) 140 | 141 | if err { 142 | t.Errorf("Expected false, got true for key 3") 143 | } 144 | if hm.Get(1) != "world" { 145 | t.Errorf("Expected 'world', got %v for key 1", hm.Get(1)) 146 | } 147 | } 148 | 149 | // 1. set get测试 150 | // 测试 151 | func Test_SetGet_Rehashing(t *testing.T) { 152 | hm := New[int, string]() 153 | hm.Set(1, "hello") 154 | hm.Set(2, "world") 155 | hm.Set(3, "hello") 156 | hm.Set(4, "world") 157 | hm.Set(5, "world") 158 | 159 | _, err := hm.TryGet(7) 160 | 161 | if err { 162 | t.Errorf("Expected false, got true for key 7") 163 | } 164 | if hm.Get(1) != "hello" { 165 | t.Errorf("Expected 'hello', got %v for key 1", hm.Get(1)) 166 | } 167 | } 168 | 169 | // 测试Len接口 170 | func Test_Len(t *testing.T) { 171 | hm := New[int, int]() 172 | max := 3333 173 | for i := 0; i < max; i++ { 174 | hm.Set(i, i) 175 | } 176 | if hm.Len() != max { 177 | t.Errorf("Expected %d, got %v", max, hm.Len()) 178 | } 179 | } 180 | 181 | // 2.测试删除功能 182 | func Test_Delete(t *testing.T) { 183 | hm := New[int, int]() 184 | 185 | max := 3333 186 | for i := 0; i < max; i++ { 187 | hm.Set(i, i) 188 | } 189 | 190 | for i := 0; i < max; i++ { 191 | hm.Delete(i) 192 | } 193 | if hm.Len() != 0 { 194 | t.Errorf("Expected 0, got %v", hm.Len()) 195 | } 196 | } 197 | 198 | // 2. 
测试删除功能 199 | func Test_Delete_NotFound(t *testing.T) { 200 | hm := New[int, int]() 201 | 202 | max := 4 //不要修改4 203 | for i := 0; i < max; i++ { 204 | hm.Set(i, i) 205 | } 206 | 207 | hm.Delete(max + 1) 208 | for i := 0; i < max; i++ { 209 | hm.Delete(i) 210 | } 211 | 212 | if hm.Len() != 0 { 213 | t.Errorf("Expected 0, got %v", hm.Len()) 214 | } 215 | } 216 | 217 | // 2. 测试删除功能 218 | func Test_Delete_Empty(t *testing.T) { 219 | hm := New[int, int]() 220 | 221 | err := hm.Remove(0) 222 | if err == nil { 223 | t.Errorf("Expected error, got nil") 224 | } 225 | if hm.Len() != 0 { 226 | t.Errorf("Expected 0, got %v", hm.Len()) 227 | } 228 | } 229 | 230 | // 3. 测试Range 231 | func Test_Range(t *testing.T) { 232 | max := 100 233 | hm := NewWithOpt[int, int](WithCap(max)) 234 | need := []int{} 235 | for i := 0; i < max; i++ { 236 | need = append(need, i, i) 237 | hm.Set(i, i) 238 | } 239 | 240 | if hm.Len() != max { 241 | t.Errorf("Expected %d, got %v", max, hm.Len()) 242 | } 243 | got := make([]int, 0, max) 244 | hm.Range(func(key int, val int) bool { 245 | got = append(got, key, val) 246 | return true 247 | }) 248 | 249 | sort.Ints(got) 250 | if !slicesEqual(need, got) { 251 | t.Errorf("Expected %v, got %v", need, got) 252 | } 253 | } 254 | 255 | // 3. 
测试Range 256 | func Test_Range_Zero(t *testing.T) { 257 | max := 0 258 | hm := New[int, int]() 259 | need := []int{} 260 | 261 | if hm.Len() != max { 262 | t.Errorf("Expected %d, got %v", max, hm.Len()) 263 | } 264 | got := make([]int, 0, max) 265 | hm.Range(func(key int, val int) bool { 266 | got = append(got, key, val) 267 | return true 268 | }) 269 | 270 | sort.Ints(got) 271 | if !slicesEqual(need, got) { 272 | t.Errorf("Expected %v, got %v", need, got) 273 | } 274 | } 275 | 276 | func Test_Range_Rehasing(t *testing.T) { 277 | max := 5 278 | hm := New[int, int]() 279 | need := []int{} 280 | for i := 0; i < max; i++ { 281 | hm.Set(i, i) 282 | need = append(need, i, i) 283 | } 284 | 285 | if hm.Len() != max { 286 | t.Errorf("Expected %d, got %v", max, hm.Len()) 287 | } 288 | got := make([]int, 0, max) 289 | hm.Range(func(key int, val int) bool { 290 | got = append(got, key, val) 291 | return true 292 | }) 293 | 294 | sort.Ints(got) 295 | if !slicesEqual(need, got) { 296 | t.Errorf("Expected %v, got %v", need, got) 297 | } 298 | } 299 | 300 | // 测试shrink 301 | func Test_Range_ShrinkToFit(t *testing.T) { 302 | hm := New[int, int]() 303 | 304 | max := 3333 305 | for i := 0; i < max; i++ { 306 | hm.Set(i, i) 307 | } 308 | 309 | for i := 0; i < max; i++ { 310 | hm.Delete(i) 311 | } 312 | 313 | err := hm.ShrinkToFit() 314 | if err != nil { 315 | t.Errorf("Expected no error, got %v", err) 316 | } 317 | if hm.Len() != 0 { 318 | t.Errorf("Expected 0, got %v", hm.Len()) 319 | } 320 | } 321 | 322 | func Test_HashMap_InsertOrUpdate(t *testing.T) { 323 | hm := New[int, int]() 324 | max := 100 325 | 326 | // Insert elements 327 | for i := 0; i < max; i++ { 328 | hm.InsertOrUpdate(i, i, func(prev, new int) int { 329 | return prev + new 330 | }) 331 | } 332 | 333 | // Update elements 334 | for i := 0; i < max; i++ { 335 | hm.InsertOrUpdate(i, i, func(prev, new int) int { 336 | return prev + new 337 | }) 338 | } 339 | 340 | // Verify elements 341 | for i := 0; i < max; i++ { 342 | 
v, ok := hm.TryGet(i) 343 | if !ok || v != i*2 { 344 | t.Errorf("expected %d, got %v", i*2, v) 345 | } 346 | } 347 | } 348 | 349 | // Helper function to compare slices 350 | func slicesEqual[T comparable](a, b []T) bool { 351 | if len(a) != len(b) { 352 | return false 353 | } 354 | for i := range a { 355 | if a[i] != b[i] { 356 | return false 357 | } 358 | } 359 | return true 360 | } 361 | -------------------------------------------------------------------------------- /rwmap/rwmap.go: -------------------------------------------------------------------------------- 1 | // apache 2.0 antlabs 2 | 3 | package rwmap 4 | 5 | import ( 6 | "sync" 7 | 8 | "github.com/antlabs/gstl/api" 9 | "github.com/antlabs/gstl/mapex" 10 | ) 11 | 12 | // type Pair[K comparable, V any] = mapex.Pair[K comparable, V any] 13 | type Pair[K comparable, V any] struct { 14 | Key K 15 | Val V 16 | } 17 | 18 | var _ api.CMaper[int, int] = (*RWMap[int, int])(nil) 19 | 20 | type RWMap[K comparable, V any] struct { 21 | rw sync.RWMutex 22 | m map[K]V 23 | } 24 | 25 | // 通过new函数分配可以指定map的长度 26 | func New[K comparable, V any](l ...int) *RWMap[K, V] { 27 | if len(l) == 0 { 28 | return &RWMap[K, V]{ 29 | m: make(map[K]V), 30 | } 31 | } 32 | return &RWMap[K, V]{ 33 | m: make(map[K]V, l[0]), 34 | } 35 | } 36 | 37 | func (r *RWMap[K, V]) ToMap() map[K]V { 38 | return r.m 39 | } 40 | 41 | // 删除 42 | func (r *RWMap[K, V]) Delete(key K) { 43 | r.rw.Lock() 44 | delete(r.m, key) 45 | r.rw.Unlock() 46 | } 47 | 48 | // 加载 49 | func (r *RWMap[K, V]) Load(key K) (value V, ok bool) { 50 | r.rw.RLock() 51 | value, ok = r.m[key] 52 | r.rw.RUnlock() 53 | return 54 | } 55 | 56 | // 获取值,然后并删除 57 | func (r *RWMap[K, V]) LoadAndDelete(key K) (value V, loaded bool) { 58 | r.rw.Lock() 59 | if r.m == nil { 60 | r.rw.Unlock() 61 | return 62 | } 63 | value, loaded = r.m[key] 64 | delete(r.m, key) 65 | r.rw.Unlock() 66 | return 67 | } 68 | 69 | // 存在返回现有的值,loaded 为true 70 | // 不存在就保存现在的值,loaded为false 71 | func (r *RWMap[K, V]) 
LoadOrStore(key K, value V) (actual V, loaded bool) { 72 | r.rw.Lock() 73 | if r.m == nil { 74 | r.m = make(map[K]V) 75 | } 76 | actual, loaded = r.m[key] 77 | if !loaded { 78 | actual = value 79 | r.m[key] = actual 80 | } 81 | r.rw.Unlock() 82 | return 83 | } 84 | 85 | func (r *RWMap[K, V]) Range(f func(key K, value V) bool) { 86 | r.rw.RLock() 87 | for k, v := range r.m { 88 | if !f(k, v) { 89 | break 90 | } 91 | } 92 | r.rw.RUnlock() 93 | } 94 | 95 | func (r *RWMap[K, V]) Iter() <-chan Pair[K, V] { 96 | p := make(chan Pair[K, V]) 97 | go func() { 98 | r.rw.RLock() 99 | for k, v := range r.m { 100 | p <- Pair[K, V]{Key: k, Val: v} 101 | } 102 | close(p) 103 | r.rw.RUnlock() 104 | }() 105 | return p 106 | } 107 | 108 | // 保存值 109 | func (r *RWMap[K, V]) Store(key K, value V) { 110 | r.rw.Lock() 111 | if r.m == nil { 112 | r.m = make(map[K]V) 113 | } 114 | r.m[key] = value 115 | r.rw.Unlock() 116 | } 117 | 118 | // keys 119 | func (r *RWMap[K, V]) Keys() (keys []K) { 120 | r.rw.RLock() 121 | if r.m == nil { 122 | r.rw.RUnlock() 123 | return 124 | } 125 | keys = mapex.Keys(r.m) 126 | r.rw.RUnlock() 127 | return keys 128 | } 129 | 130 | // vals 131 | func (r *RWMap[K, V]) Values() (values []V) { 132 | r.rw.RLock() 133 | if r.m == nil { 134 | r.rw.RUnlock() 135 | return 136 | } 137 | values = mapex.Values(r.m) 138 | r.rw.RUnlock() 139 | return values 140 | } 141 | 142 | // 返回长度 143 | func (r *RWMap[K, V]) Len() (l int) { 144 | r.rw.RLock() 145 | l = len(r.m) 146 | r.rw.RUnlock() 147 | return 148 | } 149 | -------------------------------------------------------------------------------- /rwmap/rwmap_bench_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2016 The Go Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | // guonaihong: 修改如下 5 | // 1. 
interface的地方换成泛型语法 6 | 7 | package rwmap 8 | 9 | import ( 10 | "fmt" 11 | "reflect" 12 | "sync" 13 | "sync/atomic" 14 | "testing" 15 | 16 | "github.com/antlabs/gstl/api" 17 | ) 18 | 19 | type syncmap[K comparable, V any] struct { 20 | m sync.Map 21 | } 22 | 23 | func (c *syncmap[K, V]) Delete(key K) { 24 | c.m.Delete(key) 25 | } 26 | 27 | func (c *syncmap[K, V]) Load(key K) (value V, ok bool) { 28 | v, ok := c.m.Load(key) 29 | if !ok { 30 | return 31 | } 32 | 33 | return v.(V), ok 34 | } 35 | 36 | func (c *syncmap[K, V]) LoadAndDelete(key K) (value V, loaded bool) { 37 | v, ok := c.m.LoadAndDelete(key) 38 | if !ok { 39 | return 40 | } 41 | 42 | return v.(V), ok 43 | } 44 | 45 | func (c *syncmap[K, V]) LoadOrStore(key K, value V) (actual V, loaded bool) { 46 | v, ok := c.m.LoadOrStore(key, value) 47 | if !ok { 48 | return 49 | } 50 | return v.(V), ok 51 | } 52 | 53 | func (c *syncmap[K, V]) Range(f func(key K, value V) bool) { 54 | c.m.Range(func(key any, value any) bool { 55 | return f(key.(K), value.(V)) 56 | }) 57 | } 58 | 59 | func (c *syncmap[K, V]) Store(key K, value V) { 60 | c.m.Store(key, value) 61 | } 62 | 63 | type bench[K comparable, V any] struct { 64 | setup func(*testing.B, api.CMaper[K, V]) 65 | perG func(b *testing.B, pb *testing.PB, i int, m api.CMaper[K, V]) 66 | } 67 | 68 | func benchMap(b *testing.B, bench bench[int, int]) { 69 | for _, m := range [...]api.CMaper[int, int]{New[int, int](), &syncmap[int, int]{}} { 70 | b.Run(fmt.Sprintf("%T", m), func(b *testing.B) { 71 | m = reflect.New(reflect.TypeOf(m).Elem()).Interface().(api.CMaper[int, int]) 72 | // if m2, ok := m.(*RWMap[int, int]); ok { 73 | // m2.init(64) 74 | // } 75 | 76 | if bench.setup != nil { 77 | bench.setup(b, m) 78 | } 79 | 80 | b.ResetTimer() 81 | 82 | var i int64 83 | b.RunParallel(func(pb *testing.PB) { 84 | id := int(atomic.AddInt64(&i, 1) - 1) 85 | bench.perG(b, pb, id*b.N, m) 86 | }) 87 | }) 88 | } 89 | } 90 | 91 | func BenchmarkLoadMostlyHits(b *testing.B) { 92 | const 
hits, misses = 1023, 1 93 | 94 | benchMap(b, bench[int, int]{ 95 | setup: func(_ *testing.B, m api.CMaper[int, int]) { 96 | for i := 0; i < hits; i++ { 97 | m.LoadOrStore(i, i) 98 | } 99 | // Prime the map to get it into a steady state. 100 | for i := 0; i < hits*2; i++ { 101 | m.Load(i % hits) 102 | } 103 | }, 104 | 105 | perG: func(b *testing.B, pb *testing.PB, i int, m api.CMaper[int, int]) { 106 | for ; pb.Next(); i++ { 107 | m.Load(i % (hits + misses)) 108 | } 109 | }, 110 | }) 111 | } 112 | 113 | func BenchmarkLoadMostlyMisses(b *testing.B) { 114 | const hits, misses = 1, 1023 115 | 116 | benchMap(b, bench[int, int]{ 117 | setup: func(_ *testing.B, m api.CMaper[int, int]) { 118 | for i := 0; i < hits; i++ { 119 | m.LoadOrStore(i, i) 120 | } 121 | // Prime the map to get it into a steady state. 122 | for i := 0; i < hits*2; i++ { 123 | m.Load(i % hits) 124 | } 125 | }, 126 | 127 | perG: func(b *testing.B, pb *testing.PB, i int, m api.CMaper[int, int]) { 128 | for ; pb.Next(); i++ { 129 | m.Load(i % (hits + misses)) 130 | } 131 | }, 132 | }) 133 | } 134 | 135 | func BenchmarkLoadOrStoreBalanced(b *testing.B) { 136 | const hits, misses = 128, 128 137 | 138 | benchMap(b, bench[int, int]{ 139 | setup: func(b *testing.B, m api.CMaper[int, int]) { 140 | for i := 0; i < hits; i++ { 141 | m.LoadOrStore(i, i) 142 | } 143 | // Prime the map to get it into a steady state. 
144 | for i := 0; i < hits*2; i++ { 145 | m.Load(i % hits) 146 | } 147 | }, 148 | 149 | perG: func(b *testing.B, pb *testing.PB, i int, m api.CMaper[int, int]) { 150 | for ; pb.Next(); i++ { 151 | j := i % (hits + misses) 152 | if j < hits { 153 | if _, ok := m.LoadOrStore(j, i); !ok { 154 | b.Fatalf("unexpected miss for %v", j) 155 | } 156 | } else { 157 | if v, loaded := m.LoadOrStore(i, i); loaded { 158 | b.Fatalf("failed to store %v: existing value %v", i, v) 159 | } 160 | } 161 | } 162 | }, 163 | }) 164 | } 165 | 166 | func BenchmarkLoadOrStoreUnique(b *testing.B) { 167 | benchMap(b, bench[int, int]{ 168 | setup: func(b *testing.B, m api.CMaper[int, int]) { 169 | }, 170 | 171 | perG: func(b *testing.B, pb *testing.PB, i int, m api.CMaper[int, int]) { 172 | for ; pb.Next(); i++ { 173 | m.LoadOrStore(i, i) 174 | } 175 | }, 176 | }) 177 | } 178 | 179 | func BenchmarkDelete(b *testing.B) { 180 | benchMap(b, bench[int, int]{ 181 | setup: func(b *testing.B, m api.CMaper[int, int]) { 182 | for i := 0; i < 1000000; i++ { 183 | m.Store(i, i) 184 | } 185 | }, 186 | 187 | perG: func(b *testing.B, pb *testing.PB, i int, m api.CMaper[int, int]) { 188 | for ; pb.Next(); i++ { 189 | m.Delete(i) 190 | } 191 | }, 192 | }) 193 | } 194 | 195 | func BenchmarkStore(b *testing.B) { 196 | benchMap(b, bench[int, int]{ 197 | setup: func(_ *testing.B, m api.CMaper[int, int]) { 198 | //m.LoadOrStore(0, 0) 199 | }, 200 | 201 | perG: func(b *testing.B, pb *testing.PB, i int, m api.CMaper[int, int]) { 202 | for ; pb.Next(); i++ { 203 | m.Store(i, i) 204 | } 205 | }, 206 | }) 207 | } 208 | 209 | func BenchmarkLoadOrStoreCollision(b *testing.B) { 210 | benchMap(b, bench[int, int]{ 211 | setup: func(_ *testing.B, m api.CMaper[int, int]) { 212 | m.LoadOrStore(0, 0) 213 | }, 214 | 215 | perG: func(b *testing.B, pb *testing.PB, i int, m api.CMaper[int, int]) { 216 | for ; pb.Next(); i++ { 217 | m.LoadOrStore(0, 0) 218 | } 219 | }, 220 | }) 221 | } 222 | 223 | func 
BenchmarkLoadAndDeleteBalanced(b *testing.B) { 224 | const hits, misses = 128, 128 225 | 226 | benchMap(b, bench[int, int]{ 227 | setup: func(b *testing.B, m api.CMaper[int, int]) { 228 | for i := 0; i < hits; i++ { 229 | m.LoadOrStore(i, i) 230 | } 231 | // Prime the map to get it into a steady state. 232 | for i := 0; i < hits*2; i++ { 233 | m.Load(i % hits) 234 | } 235 | }, 236 | 237 | perG: func(b *testing.B, pb *testing.PB, i int, m api.CMaper[int, int]) { 238 | for ; pb.Next(); i++ { 239 | j := i % (hits + misses) 240 | if j < hits { 241 | m.LoadAndDelete(j) 242 | } else { 243 | m.LoadAndDelete(i) 244 | } 245 | } 246 | }, 247 | }) 248 | } 249 | 250 | func BenchmarkLoadAndDeleteUnique(b *testing.B) { 251 | benchMap(b, bench[int, int]{ 252 | setup: func(b *testing.B, m api.CMaper[int, int]) { 253 | }, 254 | 255 | perG: func(b *testing.B, pb *testing.PB, i int, m api.CMaper[int, int]) { 256 | for ; pb.Next(); i++ { 257 | m.LoadAndDelete(i) 258 | } 259 | }, 260 | }) 261 | } 262 | 263 | func BenchmarkLoadAndDeleteCollision(b *testing.B) { 264 | benchMap(b, bench[int, int]{ 265 | setup: func(_ *testing.B, m api.CMaper[int, int]) { 266 | m.LoadOrStore(0, 0) 267 | }, 268 | 269 | perG: func(b *testing.B, pb *testing.PB, i int, m api.CMaper[int, int]) { 270 | for ; pb.Next(); i++ { 271 | m.LoadAndDelete(0) 272 | } 273 | }, 274 | }) 275 | } 276 | 277 | func BenchmarkRange(b *testing.B) { 278 | const mapSize = 1 << 10 279 | 280 | benchMap(b, bench[int, int]{ 281 | setup: func(_ *testing.B, m api.CMaper[int, int]) { 282 | for i := 0; i < mapSize; i++ { 283 | m.Store(i, i) 284 | } 285 | }, 286 | 287 | perG: func(b *testing.B, pb *testing.PB, i int, m api.CMaper[int, int]) { 288 | for ; pb.Next(); i++ { 289 | m.Range(func(_, _ int) bool { return true }) 290 | } 291 | }, 292 | }) 293 | } 294 | 295 | // BenchmarkAdversarialAlloc tests performance when we store a new value 296 | // immediately whenever the map is promoted to clean and otherwise load a 297 | // unique, missing 
key. 298 | // 299 | // This forces the Load calls to always acquire the map's mutex. 300 | func BenchmarkAdversarialAlloc(b *testing.B) { 301 | benchMap(b, bench[int, int]{ 302 | perG: func(b *testing.B, pb *testing.PB, i int, m api.CMaper[int, int]) { 303 | var stores, loadsSinceStore int 304 | for ; pb.Next(); i++ { 305 | m.Load(i) 306 | if loadsSinceStore++; loadsSinceStore > stores { 307 | m.LoadOrStore(i, stores) 308 | loadsSinceStore = 0 309 | stores++ 310 | } 311 | } 312 | }, 313 | }) 314 | } 315 | 316 | // BenchmarkAdversarialDelete tests performance when we periodically delete 317 | // one key and add a different one in a large map. 318 | // 319 | // This forces the Load calls to always acquire the map's mutex and periodically 320 | // makes a full copy of the map despite changing only one entry. 321 | 322 | // 这个case不测试, 锁分区的方式这么使用会死锁 323 | /* 324 | func BenchmarkAdversarialDelete(b *testing.B) { 325 | const mapSize = 1 << 10 326 | 327 | benchMap(b, bench[int, int]{ 328 | setup: func(_ *testing.B, m api.CMaper[int, int]) { 329 | for i := 0; i < mapSize; i++ { 330 | m.Store(i, i) 331 | } 332 | }, 333 | 334 | perG: func(b *testing.B, pb *testing.PB, i int, m api.CMaper[int, int]) { 335 | for ; pb.Next(); i++ { 336 | m.Load(i) 337 | 338 | if i%mapSize == 0 { 339 | m.Range(func(k, _ int) bool { 340 | m.Delete(k) 341 | return false 342 | }) 343 | m.Store(i, i) 344 | } 345 | } 346 | }, 347 | }) 348 | } 349 | */ 350 | 351 | func BenchmarkDeleteCollision(b *testing.B) { 352 | benchMap(b, bench[int, int]{ 353 | setup: func(_ *testing.B, m api.CMaper[int, int]) { 354 | m.LoadOrStore(0, 0) 355 | }, 356 | 357 | perG: func(b *testing.B, pb *testing.PB, i int, m api.CMaper[int, int]) { 358 | for ; pb.Next(); i++ { 359 | m.Delete(0) 360 | } 361 | }, 362 | }) 363 | } 364 | -------------------------------------------------------------------------------- /rwmap/rwmap_test.go: -------------------------------------------------------------------------------- 1 | package rwmap 
2 | 3 | import ( 4 | "fmt" 5 | "sort" 6 | "sync" 7 | "testing" 8 | ) 9 | 10 | // Store And Load 11 | func Test_StoreAndLoad(t *testing.T) { 12 | var m RWMap[string, string] 13 | m.Store("hello", "1") 14 | m.Store("world", "2") 15 | v1, ok1 := m.Load("hello") 16 | if v1 != "1" { 17 | t.Errorf("expected '1', got '%s'", v1) 18 | } 19 | if !ok1 { 20 | t.Errorf("expected true, got false") 21 | } 22 | 23 | v1, ok1 = m.Load("world") 24 | if v1 != "2" { 25 | t.Errorf("expected '2', got '%s'", v1) 26 | } 27 | if !ok1 { 28 | t.Errorf("expected true, got false") 29 | } 30 | } 31 | 32 | // Store And Load 33 | func Test_StoreDeleteLoad(t *testing.T) { 34 | var m RWMap[string, string] 35 | m.Store("hello", "1") 36 | m.Store("world", "2") 37 | 38 | m.Delete("hello") 39 | m.Delete("world") 40 | 41 | v1, ok1 := m.Load("hello") 42 | if v1 != "" { 43 | t.Errorf("expected '', got '%s'", v1) 44 | } 45 | if ok1 { 46 | t.Errorf("expected false, got true") 47 | } 48 | 49 | v1, ok1 = m.Load("world") 50 | if v1 != "" { 51 | t.Errorf("expected '', got '%s'", v1) 52 | } 53 | if ok1 { 54 | t.Errorf("expected false, got true") 55 | } 56 | } 57 | 58 | func Test_LoadAndDelete(t *testing.T) { 59 | var m RWMap[string, string] 60 | v1, ok1 := m.LoadAndDelete("hello") 61 | 62 | if v1 != "" { 63 | t.Errorf("expected '', got '%s'", v1) 64 | } 65 | if ok1 { 66 | t.Errorf("expected false, got true") 67 | } 68 | 69 | m.Store("hello", "world") 70 | v1, ok1 = m.Load("hello") 71 | 72 | if v1 != "world" { 73 | t.Errorf("expected 'world', got '%s'", v1) 74 | } 75 | 76 | v1, ok1 = m.LoadAndDelete("hello") 77 | if v1 != "world" { 78 | t.Errorf("expected 'world', got '%s'", v1) 79 | } 80 | if !ok1 { 81 | t.Errorf("expected true, got false") 82 | } 83 | } 84 | 85 | func Test_loadOrStore(t *testing.T) { 86 | var m RWMap[string, string] 87 | var m2 sync.Map 88 | v1, ok1 := m.LoadOrStore("hello", "world") 89 | v2, ok2 := m2.LoadOrStore("hello", "world") 90 | 91 | if ok1 != ok2 { 92 | t.Errorf("expected %v, got %v", 
ok2, ok1) 93 | } 94 | if v1 != v2.(string) { 95 | t.Errorf("expected '%s', got '%s'", v2.(string), v1) 96 | } 97 | } 98 | // Returning false from the Range callback must stop iteration after the first entry. 99 | func Test_RangeBreak(t *testing.T) { 100 | var m RWMap[string, string] 101 | m.Store("1", "1") 102 | m.Store("2", "2") 103 | 104 | count := 0 105 | m.Range(func(key, val string) bool { 106 | count++ 107 | return false 108 | }) 109 | 110 | if count != 1 { 111 | t.Errorf("expected 1, got %d", count) 112 | } 113 | } 114 | // Range must visit every stored key/value pair; results are sorted before comparing so iteration order is not relied on. 115 | func Test_Range(t *testing.T) { 116 | var m RWMap[string, string] 117 | max := 5 118 | keyAll := []string{} 119 | valAll := []string{} 120 | 121 | for i := 1; i < max; i++ { 122 | key := fmt.Sprintf("%dk", i) 123 | val := fmt.Sprintf("%dv", i) 124 | keyAll = append(keyAll, key) 125 | valAll = append(valAll, val) 126 | m.Store(key, val) 127 | } 128 | 129 | gotKey := []string{} 130 | gotVal := []string{} 131 | m.Range(func(key, val string) bool { 132 | gotKey = append(gotKey, key) 133 | gotVal = append(gotVal, val) 134 | return true 135 | }) 136 | 137 | sort.Strings(gotKey) 138 | sort.Strings(gotVal) 139 | 140 | if !equalSlices(keyAll, gotKey) { 141 | t.Errorf("expected keys %v, got %v", keyAll, gotKey) 142 | } 143 | if !equalSlices(valAll, gotVal) { 144 | t.Errorf("expected values %v, got %v", valAll, gotVal) 145 | } 146 | } 147 | // Iter must yield every stored key/value pair, the same set Range visits. 148 | func Test_Iter(t *testing.T) { 149 | var m RWMap[string, string] 150 | max := 5 151 | keyAll := []string{} 152 | valAll := []string{} 153 | 154 | for i := 1; i < max; i++ { 155 | key := fmt.Sprintf("%dk", i) 156 | val := fmt.Sprintf("%dv", i) 157 | keyAll = append(keyAll, key) 158 | valAll = append(valAll, val) 159 | m.Store(key, val) 160 | } 161 | 162 | gotKey := []string{} 163 | gotVal := []string{} 164 | for pair := range m.Iter() { 165 | gotKey = append(gotKey, pair.Key) 166 | gotVal = append(gotVal, pair.Val) 167 | } 168 | 169 | sort.Strings(gotKey) 170 | sort.Strings(gotVal) 171 | 172 | if !equalSlices(keyAll, gotKey) { 173 | t.Errorf("expected keys %v, got %v", keyAll, gotKey) 174
| } 175 | if !equalSlices(valAll, gotVal) { 176 | t.Errorf("expected values %v, got %v", valAll, gotVal) 177 | } 178 | } 179 | // Len must report the number of stored entries. 180 | func Test_Len(t *testing.T) { 181 | var m RWMap[string, string] 182 | m.Store("1", "1") 183 | m.Store("2", "2") 184 | m.Store("3", "3") 185 | if m.Len() != 3 { 186 | t.Errorf("expected 3, got %d", m.Len()) 187 | } 188 | } 189 | // New with a capacity hint must return a usable map whose Len tracks stores. 190 | func Test_New(t *testing.T) { 191 | m := New[string, string](3) 192 | m.Store("1", "1") 193 | m.Store("2", "2") 194 | m.Store("3", "3") 195 | if m.Len() != 3 { 196 | t.Errorf("expected 3, got %d", m.Len()) 197 | } 198 | } 199 | // Keys must return every stored key, and no keys for a zero-value map. 200 | func Test_Keys(t *testing.T) { 201 | m := New[string, string](3) 202 | m.Store("a", "1") 203 | m.Store("b", "2") 204 | m.Store("c", "3") 205 | get := m.Keys() 206 | sort.Strings(get) 207 | if !equalSlices(get, []string{"a", "b", "c"}) { 208 | t.Errorf("expected keys %v, got %v", []string{"a", "b", "c"}, get) 209 | } 210 | // A zero-value RWMap must report no keys (check Keys itself, the method under test). 211 | var m2 RWMap[string, string] 212 | if len(m2.Keys()) != 0 { 213 | t.Errorf("expected 0, got %d", len(m2.Keys())) 214 | } 215 | } 216 | // Values must return every stored value, and no values for a zero-value map. 217 | func Test_Values(t *testing.T) { 218 | m := New[string, string](3) 219 | m.Store("a", "1") 220 | m.Store("b", "2") 221 | m.Store("c", "3") 222 | get := m.Values() 223 | sort.Strings(get) 224 | if !equalSlices(get, []string{"1", "2", "3"}) { 225 | t.Errorf("expected values %v, got %v", []string{"1", "2", "3"}, get) 226 | } 227 | // A zero-value RWMap must report no values (check Values itself, the method under test). 228 | var m2 RWMap[string, string] 229 | if len(m2.Values()) != 0 { 230 | t.Errorf("expected 0, got %d", len(m2.Values())) 231 | } 232 | } 233 | 234 | // equalSlices reports whether a and b have the same length and identical elements in the same order. 235 | func equalSlices(a, b []string) bool { 236 | if len(a) != len(b) { 237 | return false 238 | } 239 | for i := range a { 240 | if a[i] != b[i] { 241 | return false 242 | } 243 | } 244 | return true 245 | } 246 | -------------------------------------------------------------------------------- /set/set.go: -------------------------------------------------------------------------------- 1 | package set 2 | 3 | // apache
2.0 antlabs 4 | import ( 5 | "github.com/antlabs/gstl/api" 6 | "github.com/antlabs/gstl/rbtree" 7 | "golang.org/x/exp/constraints" 8 | ) 9 | // Set is an ordered set of keys stored as an api.SortedMap whose values are empty structs. 10 | type Set[K constraints.Ordered] struct { 11 | api.SortedMap[K, struct{}] 12 | } 13 | 14 | // New creates an empty Set. 15 | func New[K constraints.Ordered]() *Set[K] { 16 | // rbtree is a provisional choice; revisit after benchmarking the other SortedMap implementations. 17 | return &Set[K]{SortedMap: rbtree.New[K, struct{}]()} 18 | } 19 | 20 | // From creates a Set containing the elements of s. 21 | func From[K constraints.Ordered](s ...K) *Set[K] { 22 | var b rbtree.RBTree[K, struct{}] 23 | for _, v := range s { 24 | b.Set(v, struct{}{}) 25 | } 26 | 27 | return &Set[K]{SortedMap: &b} 28 | } 29 | 30 | // Set adds k to the set. 31 | func (s *Set[K]) Set(k K) { 32 | s.SortedMap.Set(k, struct{}{}) 33 | } 34 | 35 | // Len returns the number of elements in the set. 36 | func (s *Set[K]) Len() int { 37 | return s.SortedMap.Len() 38 | } 39 | // ToSlice returns the elements in the order Range visits them. 40 | func (s *Set[K]) ToSlice() (out []K) { 41 | out = make([]K, 0, s.Len()) 42 | s.Range(func(k K) bool { 43 | out = append(out, k) 44 | return true 45 | }) 46 | return 47 | } 48 | 49 | // Clone returns a new Set holding the same elements as s. 50 | func (s *Set[K]) Clone() (out *Set[K]) { 51 | out = New[K]() 52 | s.Range(func(k K) bool { 53 | out.Set(k) 54 | return true 55 | }) 56 | return 57 | } 58 | 59 | // IsMember reports whether k is in the set. 60 | func (s *Set[K]) IsMember(k K) (b bool) { 61 | _, b = s.TryGet(k) 62 | return 63 | } 64 | 65 | // Diff returns a new Set with the elements of s that are not in s1 (s - s1). 66 | func (s *Set[K]) Diff(s1 *Set[K]) (out *Set[K]) { 67 | 68 | out = New[K]() 69 | s.Range(func(k K) bool { 70 | if !s1.IsMember(k) { 71 | out.Set(k) 72 | } 73 | return true 74 | }) 75 | return 76 | } 77 | 78 | // Union returns a new Set containing every element of s and of each set in sets. 79 | func (s *Set[K]) Union(sets ...*Set[K]) (out *Set[K]) { 80 | 81 | out = New[K]() 82 | s.Range(func(k K) bool { 83 | out.Set(k) 84 | return true 85 | }) 86 | 87 | for _, s1 := range sets { 88 | s1.Range(func(k K) bool { 89 | out.Set(k) 90 | return true 91 | }) 92 | } 93 | 94 | return 95 | } 96 | 97 | // Intersection returns a new Set with the elements present in both s and s1. 98 | func (s *Set[K]) Intersection(s1 *Set[K]) (out *Set[K]) { 99 | if s.Len() >= s1.Len() { // iterate over the smaller set and probe the larger one 100 | s, s1 = s1, s 101 | } 102 | 103 |
out = New[K]() 104 | s.Range(func(k K) bool { 105 | if s1.IsMember(k) { 106 | out.Set(k) 107 | } 108 | return true 109 | }) 110 | return 111 | } 112 | 113 | // IsSubset reports whether every element of s is also in s1 (s ⊆ s1). 114 | func (s *Set[K]) IsSubset(s1 *Set[K]) (b bool) { 115 | if s.Len() > s1.Len() { 116 | return false 117 | } 118 | 119 | b = true 120 | s.Range(func(k K) bool { 121 | if !s1.IsMember(k) { 122 | b = false 123 | return false 124 | } 125 | return true 126 | }) 127 | return 128 | } 129 | 130 | // IsSuperset reports whether every element of s1 is also in s (s ⊇ s1). 131 | func (s *Set[K]) IsSuperset(s1 *Set[K]) (b bool) { 132 | return s1.IsSubset(s) 133 | } 134 | 135 | // Range calls cb for each element; returning false from cb asks the underlying SortedMap.Range to stop. 136 | func (s *Set[K]) Range(cb func(k K) bool) { 137 | s.SortedMap.Range(func(k K, _ struct{}) bool { 138 | return cb(k) 139 | }) 140 | } 141 | 142 | // Equal reports whether s and s1 contain exactly the same elements. 143 | func (s *Set[K]) Equal(s1 *Set[K]) (b bool) { 144 | if s.Len() != s1.Len() { 145 | return false 146 | } 147 | 148 | b = true 149 | s.Range(func(k K) bool { 150 | _, b = s1.TryGet(k) 151 | return b 152 | }) 153 | 154 | return 155 | } 156 | -------------------------------------------------------------------------------- /set/set_test.go: -------------------------------------------------------------------------------- 1 | package set 2 | 3 | // apache 2.0 antlabs 4 | import ( 5 | "testing" 6 | ) 7 | // Range over a Set built with New must visit the inserted elements in ascending order. 8 | func Test_Range_New(t *testing.T) { 9 | s := New[string]() 10 | a := []string{"1111", "2222", "3333"} 11 | for _, v := range a { 12 | s.Set(v) 13 | } 14 | var got []string 15 | s.Range(func(k string) bool { 16 | got = append(got, k) 17 | return true 18 | }) 19 | 20 | if !equalSlices(got, a) { 21 | t.Errorf("expected %v, got %v", a, got) 22 | } 23 | } 24 | // Range over a Set built with From must return exactly the elements passed to From. 25 | func Test_Range_From(t *testing.T) { 26 | a := []string{"1111", "2222", "3333"} 27 | s := From(a...)
28 | // No further Set calls: Range must reflect exactly what From built, 29 | // otherwise this test would still pass if From silently dropped 30 | // elements. 31 | var got []string 32 | s.Range(func(k string) bool { 33 | got = append(got, k) 34 | return true 35 | }) 36 | 37 | if !equalSlices(got, a) { 38 | t.Errorf("expected %v, got %v", a, got) 39 | } 40 | } 41 | // Len must count every distinct inserted element. 42 | func Test_Len(t *testing.T) { 43 | s := New[int]() 44 | max := 1000 45 | for i := 0; i < max; i++ { 46 | s.Set(i) 47 | } 48 | if s.Len() != max { 49 | t.Errorf("expected %d, got %d", max, s.Len()) 50 | } 51 | } 52 | // A Clone must be Equal to its source. 53 | func Test_Equal(t *testing.T) { 54 | s := New[int]() 55 | max := 1000 56 | for i := 0; i < max; i++ { 57 | s.Set(i) 58 | } 59 | if s.Len() != max { 60 | t.Errorf("expected %d, got %d", max, s.Len()) 61 | } 62 | 63 | s2 := s.Clone() 64 | 65 | if s.Len() != s2.Len() { 66 | t.Errorf("expected %d, got %d", s.Len(), s2.Len()) 67 | } 68 | if !s.Equal(s2) { 69 | t.Errorf("expected sets to be equal") 70 | } 71 | } 72 | // Sets of equal size but different elements (0..999 vs -1..998) must not be Equal. 73 | func Test_Not_Equal(t *testing.T) { 74 | s := New[int]() 75 | s2 := New[int]() 76 | max := 1000 77 | 78 | for i := 0; i < max; i++ { 79 | s.Set(i) 80 | } 81 | 82 | if s.Len() != max { 83 | t.Errorf("expected %d, got %d", max, s.Len()) 84 | } 85 | 86 | for i := 0; i < max; i++ { 87 | s2.Set(i - 1) 88 | } 89 | 90 | if s.Len() != s2.Len() { 91 | t.Errorf("expected %d, got %d", s.Len(), s2.Len()) 92 | } 93 | if s.Equal(s2) { 94 | t.Errorf("expected sets to be not equal") 95 | } 96 | } 97 | // Sets of different sizes must not be Equal. 98 | func Test_Not_Equal2(t *testing.T) { 99 | s := New[int]() 100 | s2 := New[int]() 101 | max := 1000 102 | 103 | for i := 0; i < max; i++ { 104 | s.Set(i) 105 | } 106 | 107 | if s.Len() != max { 108 | t.Errorf("expected %d, got %d", max, s.Len()) 109 | } 110 | 111 | for i := 0; i < max/2; i++ { 112 | s2.Set(i - 1) 113 | } 114 | 115 | if s.Equal(s2) { 116 | t.Errorf("expected sets to be not equal") 117 | } 118 | } 119 | // IsMember must find an inserted element. 120 | func Test_IsMember(t *testing.T) { 121 | s := New[int]() 122 | for i := 0; i < 10; i++ { 123 | s.Set(i) 124 | } 125 | if !s.IsMember(1) { 126 | t.Errorf("expected true, got
false") 127 | } 128 | } 129 | 130 | func Test_Union(t *testing.T) { 131 | s := From("1111") 132 | s1 := From("2222") 133 | s2 := From("3333") 134 | 135 | newSet := s.Union(s1, s2) 136 | expected := []string{"1111", "2222", "3333"} 137 | if !equalSlices(newSet.ToSlice(), expected) { 138 | t.Errorf("expected %v, got %v", expected, newSet.ToSlice()) 139 | } 140 | } 141 | 142 | func Test_Diff(t *testing.T) { 143 | s := From("hello", "world", "1234", "4567") 144 | s2 := From("1234", "4567") 145 | 146 | newSet := s.Diff(s2) 147 | expected := []string{"hello", "world"} 148 | if !equalSlices(newSet.ToSlice(), expected) { 149 | t.Errorf("expected %v, got %v", expected, newSet.ToSlice()) 150 | } 151 | } 152 | 153 | func Test_Intersection(t *testing.T) { 154 | s := From("1234", "5678", "9abc") 155 | s2 := From("abcde", "5678", "9abc") 156 | 157 | v := s.Intersection(s2).ToSlice() 158 | expected := []string{"5678", "9abc"} 159 | if !equalSlices(v, expected) { 160 | t.Errorf("expected %v, got %v", expected, v) 161 | } 162 | } 163 | 164 | func Test_IsSubset(t *testing.T) { 165 | s := From("5678", "9abc") 166 | s2 := From("abcde", "5678", "9abc") 167 | 168 | if !s.IsSubset(s2) { 169 | t.Errorf("expected true, got false") 170 | } 171 | } 172 | 173 | func Test_IsSubset_Not(t *testing.T) { 174 | s := From("aa", "5678", "9abc") 175 | s2 := From("abcde", "5678", "9abc") 176 | 177 | if s.IsSubset(s2) { 178 | t.Errorf("expected false, got true") 179 | } 180 | } 181 | 182 | func Test_IsSubset_Not2(t *testing.T) { 183 | s := From("aa", "5678", "9abc", "33333") 184 | s2 := From("abcde", "5678", "9abc") 185 | 186 | if s.IsSubset(s2) { 187 | t.Errorf("expected false, got true") 188 | } 189 | } 190 | 191 | func Test_IsSuperset(t *testing.T) { 192 | s2 := From("5678", "9abc") 193 | s := From("abcde", "5678", "9abc") 194 | 195 | if !s.IsSuperset(s2) { 196 | t.Errorf("expected true, got false") 197 | } 198 | } 199 | 200 | func Test_IsSuperset_Not(t *testing.T) { 201 | s2 := From("aa", "5678",
"9abc") 202 | s := From("abcde", "5678", "9abc") 203 | 204 | if s.IsSuperset(s2) { 205 | t.Errorf("expected false, got true") 206 | } 207 | } 208 | 209 | func Test_IsSuperset_Not2(t *testing.T) { 210 | s2 := From("aa", "5678", "9abc", "33333") 211 | s := From("abcde", "5678", "9abc") 212 | 213 | if s.IsSuperset(s2) { 214 | t.Errorf("expected false, got true") 215 | } 216 | } 217 | 218 | // equalSlices reports whether a and b have the same length and identical elements in the same order. 219 | func equalSlices(a, b []string) bool { 220 | if len(a) != len(b) { 221 | return false 222 | } 223 | for i := range a { 224 | if a[i] != b[i] { 225 | return false 226 | } 227 | } 228 | return true 229 | } 230 | -------------------------------------------------------------------------------- /skiplist/skiplist_bench_test.go: -------------------------------------------------------------------------------- 1 | package skiplist 2 | 3 | // apache 2.0 antlabs 4 | import ( 5 | "fmt" 6 | "testing" 7 | ) 8 | 9 | // goos: darwin 10 | // goarch: amd64 11 | // pkg: github.com/antlabs/gstl/skiplist 12 | // cpu: Intel(R) Core(TM) i7-1068NG7 CPU @ 2.30GHz 13 | // BenchmarkGet-8 1000000000 0.7746 ns/op 14 | // BenchmarkGetStd-8 1000000000 0.7847 ns/op 15 | // PASS 16 | // ok github.com/antlabs/gstl/skiplist 178.377s 17 | // BenchmarkGet times Get over a skiplist holding b.N keys (the commented-out max was five million). 18 | func BenchmarkGet(b *testing.B) { 19 | //max := 1000000.0 * 5 20 | max := float64(b.N) 21 | set := New[float64, float64]() 22 | for i := 0.0; i < max; i++ { 23 | set.Set(i, i) 24 | } 25 | 26 | b.ResetTimer() 27 | 28 | for i := 0.0; i < max; i++ { 29 | v := set.Get(i) 30 | if v != i { 31 | panic(fmt.Sprintf("need:%f, got:%f", i, v)) 32 | } 33 | } 34 | } 35 | // BenchmarkGetStd is the built-in map baseline for BenchmarkGet. 36 | func BenchmarkGetStd(b *testing.B) { 37 | 38 | //max := 1000000.0 * 5 39 | max := float64(b.N) 40 | set := make(map[float64]float64, int(max)) 41 | for i := 0.0; i < max; i++ { 42 | set[i] = i 43 | } 44 | 45 | b.ResetTimer() 46 | 47 | for i := 0.0; i < max; i++ { 48 | v := set[i] 49 | if v != i { 50 | panic(fmt.Sprintf("need:%f, got:%f", i, v)) 51 | } 52 | } 53 | } 54 |
-------------------------------------------------------------------------------- /skiplist/skiplist_test.go: -------------------------------------------------------------------------------- 1 | package skiplist 2 | 3 | // apache 2.0 antlabs 4 | import ( 5 | "fmt" 6 | "sync" 7 | "testing" 8 | 9 | "github.com/antlabs/gstl/cmp" 10 | ) 11 | // Test_New checks that New returns a non-nil skiplist. 12 | func Test_New(t *testing.T) { 13 | n := New[int, int]() 14 | if n == nil { 15 | t.Errorf("expected non-nil, got nil") 16 | } 17 | } 18 | // Test_SetGet stores keys 0..99 with their decimal string form and reads each one back with Get. 19 | func Test_SetGet(t *testing.T) { 20 | zset := New[float64, string]() 21 | max := 100.0 22 | for i := 0.0; i < max; i++ { 23 | zset.Set(i, fmt.Sprintf("%d", int(i))) 24 | } 25 | 26 | for i := 0.0; i < max; i++ { 27 | v := zset.Get(i) 28 | if v != fmt.Sprintf("%d", int(i)) { 29 | t.Errorf("expected %s, got %s", fmt.Sprintf("%d", int(i)), v) 30 | } 31 | } 32 | } 33 | 34 | // Setting an existing key again must overwrite its value: Get returns the newer value. 35 | func Test_InsertRepeatingElement(t *testing.T) { 36 | sl := New[float64, string]() 37 | max := 100 38 | for i := 0; i < max; i++ { 39 | sl.Set(float64(i), fmt.Sprint(i)) 40 | } 41 | 42 | for i := 0; i < max; i++ { 43 | sl.Set(float64(i), fmt.Sprint(i+1)) 44 | } 45 | 46 | for i := 0; i < max; i++ { 47 | if sl.Get(float64(i)) != fmt.Sprint(i+1) { 48 | t.Errorf("expected %s, got %s", fmt.Sprint(i+1), sl.Get(float64(i))) 49 | } 50 | } 51 | } 52 | // Test_SetGetRemove removes each key in turn, checks that Len is max-1 and every other key is still readable, then re-inserts the removed key. 53 | func Test_SetGetRemove(t *testing.T) { 54 | zset := New[float64, float64]() 55 | 56 | max := 100.0 57 | for i := 0.0; i < max; i++ { 58 | zset.Set(i, i) 59 | } 60 | 61 | for i := 0.0; i < max; i++ { 62 | zset.Remove(i) 63 | if float64(zset.Len()) != max-1 { 64 | t.Errorf("expected %f, got %f", max-1, float64(zset.Len())) 65 | } 66 | for j := 0.0; j < max; j++ { 67 | if j == i { 68 | continue 69 | } 70 | v, ok := zset.TryGet(j) 71 | if !ok { 72 | t.Errorf("expected true for score:%f, i:%f, j:%f", j, i, j) 73 | return 74 | } 75 | if v != j { 76 | t.Errorf("expected %f, got %f", j, v) 77 | } 78 | } 79 | zset.Set(i, i) 80 | } 81 | } 82 | 83 | // Test_Skiplist_TopMin checks that TopMin yields the smallest keys in ascending order. 84 |
func Test_Skiplist_TopMin(t *testing.T) { 85 | 86 | need := []int{} 87 | count10 := 10 88 | count100 := 100 89 | count1000 := 1000 90 | 91 | for i := 0; i < count1000; i++ { 92 | need = append(need, i) 93 | } 94 | // TopMin(count100) yields min(Len, 100) keys, so the expected counts per case are 10, 100, 100. 95 | needCount := []int{count10, count100, count100} 96 | for i, b := range []*SkipList[float64, int]{ 97 | // the skiplist holds fewer elements than TopMin is asked for 98 | func() *SkipList[float64, int] { 99 | b := New[float64, int]() 100 | for i := 0; i < count10; i++ { 101 | b.Set(float64(i), i) 102 | } 103 | 104 | if b.Len() != count10 { 105 | t.Errorf("expected %d, got %d", count10, b.Len()) 106 | } 107 | return b 108 | }(), 109 | // the skiplist holds exactly as many elements as TopMin is asked for 110 | func() *SkipList[float64, int] { 111 | 112 | b := New[float64, int]() 113 | for i := 0; i < count100; i++ { 114 | b.Set(float64(i), i) 115 | } 116 | if b.Len() != count100 { 117 | t.Errorf("expected %d, got %d", count100, b.Len()) 118 | } 119 | return b 120 | }(), 121 | // the skiplist holds more elements than TopMin is asked for 122 | func() *SkipList[float64, int] { 123 | 124 | b := New[float64, int]() 125 | for i := 0; i < count1000; i++ { 126 | b.Set(float64(i), i) 127 | } 128 | if b.Len() != count1000 { 129 | t.Errorf("expected %d, got %d", count1000, b.Len()) 130 | } 131 | return b 132 | }(), 133 | } { 134 | var key, val []int 135 | b.TopMin(count100, func(k float64, v int) bool { 136 | key = append(key, int(k)) 137 | val = append(val, v) 138 | return true 139 | }) 140 | if !equalSlices(key, need[:needCount[i]]) { 141 | t.Errorf("expected %v, got %v", need[:needCount[i]], key) 142 | } 143 | if !equalSlices(val, need[:needCount[i]]) { 144 | t.Errorf("expected %v, got %v", need[:needCount[i]], val) 145 | } 146 | } 147 | } 148 | 149 | // TopMin must also order negative keys correctly. 150 | func Test_Skiplist_TopMin2(t *testing.T) { 151 | start := -10 152 | max := 100 153 | limit := 10 154 | sl := New[float64, int]() 155 | 156 | need := make([]int, 0, limit) 157 | for i, l := start, limit; i < max && l > 0; i++ { 158 | sl.Set(float64(i), i) 159 | need = append(need, i) 160 | l-- 161 |
162 | 163 | got := make([]int, 0, limit) 164 | sl.TopMin(10, func(k float64, v int) bool { 165 | got = append(got, int(k)) 166 | return true 167 | }) 168 | 169 | if !equalSlices(need, got) { 170 | t.Errorf("expected %v, got %v", need, got) 171 | } 172 | } 173 | 174 | // Debug helper: inserts keys at explicitly chosen levels via InsertInner, draws the list, then reads each key back; per-key metadata goes through t.Logf so it only shows with -v or on failure. 175 | func Test_SkipList_SetAndGet_Level(t *testing.T) { 176 | 177 | sl := New[float64, int]() 178 | 179 | keys := []int{5, 8, 10} 180 | level := []int{2, 3, 5} 181 | for i, key := range keys { 182 | sl.InsertInner(float64(key), key, level[i]) 183 | } 184 | 185 | sl.Draw() 186 | for _, i := range keys { 187 | v, count, _ := sl.GetWithMeta(float64(i)) 188 | t.Logf("get %v count = %v, nodes:%v, level:%v maxlevel:%v\n", 189 | float64(i), 190 | count.Total, 191 | count.Keys, 192 | count.Level, 193 | count.MaxLevel) 194 | if v != i { 195 | t.Errorf("expected %d, got %d", i, v) 196 | } 197 | } 198 | } 199 | 200 | // Debug entry point: inserts keys 1000 down to -1, draws the list, then reads keys -1..999 back with GetWithMeta (logged via t.Logf). 201 | func Test_SkipList_SetAndGet2(t *testing.T) { 202 | 203 | sl := New[float64, int]() 204 | 205 | max := 1000 206 | start := -1 207 | for i := max; i >= start; i-- { 208 | sl.Set(float64(i), i) 209 | } 210 | 211 | sl.Draw() 212 | for i := start; i < max; i++ { 213 | v, count, _ := sl.GetWithMeta(float64(i)) 214 | t.Logf("get %v count = %v, nodes:%v, level:%v maxlevel:%v\n", 215 | float64(i), 216 | count.Total, 217 | count.Keys, 218 | count.Level, 219 | count.MaxLevel) 220 | if v != i { 221 | t.Errorf("expected %d, got %d", i, v) 222 | } 223 | } 224 | } 225 | 226 | // Test_Skiplist_TopMax checks that TopMax yields the largest keys in descending order. 227 | func Test_Skiplist_TopMax(t *testing.T) { 228 | 229 | need := [3][]int{} 230 | count10 := 10 231 | count100 := 100 232 | count1000 := 1000 233 | count := []int{count10, count100, count1000} 234 | 235 | for i := 0; i < len(count); i++ { 236 | for j, k := count[i]-1, count100-1; j >= 0 && k >= 0; j-- { 237 | need[i] = append(need[i], j) 238 | k-- 239 | } 240 | } 241 | 242 | for i, b := range []*SkipList[float64, int]{ 243 | // the skiplist holds fewer elements than TopMax is asked for 244 | func()
*SkipList[float64, int] { 245 | b := New[float64, int]() 246 | for i := 0; i < count10; i++ { 247 | b.Set(float64(i), i) 248 | } 249 | 250 | if b.Len() != count10 { 251 | t.Errorf("expected %d, got %d", count10, b.Len()) 252 | } 253 | return b 254 | }(), 255 | // the skiplist holds exactly as many elements as TopMax is asked for 256 | func() *SkipList[float64, int] { 257 | 258 | b := New[float64, int]() 259 | for i := 0; i < count100; i++ { 260 | b.Set(float64(i), i) 261 | } 262 | if b.Len() != count100 { 263 | t.Errorf("expected %d, got %d", count100, b.Len()) 264 | } 265 | return b 266 | }(), 267 | // the skiplist holds more elements than TopMax is asked for 268 | func() *SkipList[float64, int] { 269 | 270 | b := New[float64, int]() 271 | for i := 0; i < count1000; i++ { 272 | b.Set(float64(i), i) 273 | } 274 | if b.Len() != count1000 { 275 | t.Errorf("expected %d, got %d", count1000, b.Len()) 276 | } 277 | return b 278 | }(), 279 | } { 280 | var key, val []int 281 | b.TopMax(count100, func(k float64, v int) bool { 282 | key = append(key, int(k)) 283 | val = append(val, v) 284 | return true 285 | }) 286 | length := cmp.Min(count[i], len(need[i])) 287 | if !equalSlices(key, need[i][:length]) { 288 | t.Errorf("expected %v, got %v", need[i][:length], key) 289 | } 290 | if !equalSlices(val, need[i][:length]) { 291 | t.Errorf("expected %v, got %v", need[i][:length], val) 292 | } 293 | } 294 | } 295 | 296 | func Test_ConcurrentSkipList_InsertGet(t *testing.T) { 297 | csl := NewConcurrent[int, string]() 298 | var wg sync.WaitGroup 299 | count := 1000 300 | 301 | // Concurrent inserts 302 | wg.Add(count) 303 | for i := 0; i < count; i++ { 304 | go func(i int) { 305 | defer wg.Done() 306 | csl.Insert(i, fmt.Sprintf("value%d", i)) 307 | }(i) 308 | } 309 | 310 | wg.Wait() 311 | 312 | // Concurrent gets 313 | wg.Add(count) 314 | for i := 0; i < count; i++ { 315 | go func(i int) { 316 | defer wg.Done() 317 | if val, ok := csl.Get(i); !ok || val != fmt.Sprintf("value%d", i) { 318 | t.Errorf("expected value%d, got %v", i, val) 319 | } 320 | }(i)
321 | } 322 | 323 | wg.Wait() 324 | } 325 | 326 | func Test_ConcurrentSkipList_Delete(t *testing.T) { 327 | csl := NewConcurrent[int, string]() 328 | var wg sync.WaitGroup 329 | count := 1000 330 | 331 | // Insert elements 332 | for i := 0; i < count; i++ { 333 | csl.Insert(i, fmt.Sprintf("value%d", i)) 334 | } 335 | 336 | // Concurrent deletes 337 | wg.Add(count) 338 | for i := 0; i < count; i++ { 339 | go func(i int) { 340 | defer wg.Done() 341 | csl.Delete(i) 342 | }(i) 343 | } 344 | 345 | wg.Wait() 346 | 347 | // Verify all elements are deleted 348 | for i := 0; i < count; i++ { 349 | if _, ok := csl.Get(i); ok { 350 | t.Errorf("expected element %d to be deleted", i) 351 | } 352 | } 353 | } 354 | 355 | func Test_ConcurrentSkipList_Get(t *testing.T) { 356 | csl := NewConcurrent[int, string]() 357 | var wg sync.WaitGroup 358 | count := 1000 359 | 360 | // Insert elements 361 | for i := 0; i < count; i++ { 362 | csl.Insert(i, fmt.Sprintf("value%d", i)) 363 | } 364 | 365 | // Concurrent gets 366 | wg.Add(count) 367 | for i := 0; i < count; i++ { 368 | go func(i int) { 369 | defer wg.Done() 370 | if val, ok := csl.Get(i); !ok || val != fmt.Sprintf("value%d", i) { 371 | t.Errorf("expected value%d, got %v", i, val) 372 | } 373 | }(i) 374 | } 375 | 376 | wg.Wait() 377 | } 378 | 379 | func Test_ConcurrentSkipList_Range(t *testing.T) { 380 | csl := NewConcurrent[int, string]() 381 | count := 1000 382 | 383 | // Insert elements 384 | for i := 0; i < count; i++ { 385 | csl.Insert(i, fmt.Sprintf("value%d", i)) 386 | } 387 | 388 | // Range over elements 389 | elements := make(map[int]string) 390 | csl.Range(func(score int, value string) bool { 391 | elements[score] = value 392 | return true 393 | }) 394 | 395 | // Verify all elements are ranged 396 | for i := 0; i < count; i++ { 397 | if val, exists := elements[i]; !exists || val != fmt.Sprintf("value%d", i) { 398 | t.Errorf("expected value%d, got %v", i, val) 399 | } 400 | } 401 | } 402 | 403 | func
Test_ConcurrentSkipList_Remove(t *testing.T) { 404 | csl := NewConcurrent[int, string]() 405 | var wg sync.WaitGroup 406 | count := 1000 407 | 408 | // Insert elements 409 | for i := 0; i < count; i++ { 410 | csl.Insert(i, fmt.Sprintf("value%d", i)) 411 | } 412 | 413 | // Concurrent removes 414 | wg.Add(count) 415 | for i := 0; i < count; i++ { 416 | go func(i int) { 417 | defer wg.Done() 418 | csl.Remove(i) 419 | }(i) 420 | } 421 | 422 | wg.Wait() 423 | 424 | // Verify all elements are removed 425 | for i := 0; i < count; i++ { 426 | if _, ok := csl.Get(i); ok { 427 | t.Errorf("expected element %d to be removed", i) 428 | } 429 | } 430 | } 431 | 432 | // equalSlices reports whether a and b have the same length and identical elements in the same order. 433 | func equalSlices(a, b []int) bool { 434 | if len(a) != len(b) { 435 | return false 436 | } 437 | for i := range a { 438 | if a[i] != b[i] { 439 | return false 440 | } 441 | } 442 | return true 443 | } 444 | -------------------------------------------------------------------------------- /trie/trie_map.go: -------------------------------------------------------------------------------- 1 | package trie 2 | 3 | import ( 4 | "unicode/utf8" 5 | 6 | "github.com/antlabs/gstl/api" 7 | ) 8 | 9 | // apache 2.0 antlabs 10 | 11 | var _ api.Trie[int] = (*Trie[int])(nil) 12 | // Trie is a rune-keyed prefix tree mapping string keys to values of type V. 13 | type Trie[V any] struct { 14 | v V // value stored at this node; meaningful only when isSet 15 | // children could also be a btree, avltree, skiplist or a sorted slice (binary search for lookup and ordered insert); 16 | // TODO: benchmark the alternatives. 17 | children map[rune]*Trie[V] 18 | isSet bool // true when a key ends at this node (not merely a prefix of other keys) 19 | length int // number of stored keys; maintained on the receiver of Swap/Delete, i.e. the root the caller holds 20 | } 21 | // New returns an empty Trie. 22 | func New[V any]() *Trie[V] { 23 | return &Trie[V]{} 24 | } 25 | // Set stores v under k, replacing any previous value. 26 | func (t *Trie[V]) Set(k string, v V) { 27 | _, _ = t.Swap(k, v) 28 | } 29 | // Swap stores v under k and returns the previous value (the zero value if k was not set); replaced reports whether k was already set. 30 | func (t *Trie[V]) Swap(k string, v V) (prev V, replaced bool) { 31 | n := t 32 | for _, r := range k { 33 | c := n.children[r] 34 | if c == nil { 35 | if n.children == nil { 36 | n.children = map[rune]*Trie[V]{} 37 | } 38 | c = &Trie[V]{} 39 | n.children[r] = c 40 | } 41 | 42 | n = c 43 | } 44 | 45 | prev = n.v 46 | n.v = v 47 | 48 | replaced = n.isSet 49 | if !replaced {
50 | t.length++ 51 | } 52 | n.isSet = true 53 | return 54 | } 55 | // HasPrefix reports whether a node exists for every rune of k, i.e. k is a prefix of (or equal to) a key in the trie; the empty string always matches. 56 | func (t *Trie[V]) HasPrefix(k string) bool { 57 | 58 | n := t 59 | for _, r := range k { 60 | n = n.children[r] 61 | if n == nil { 62 | return false 63 | } 64 | } 65 | 66 | return true 67 | } 68 | // TryGet returns the value stored under k and whether k is set; a path that is only a prefix of other keys is not set. 69 | func (t *Trie[V]) TryGet(k string) (v V, found bool) { 70 | 71 | n := t 72 | for _, r := range k { 73 | n = n.children[r] 74 | if n == nil { 75 | return 76 | } 77 | } 78 | return n.v, n.isSet 79 | } 80 | // Get returns the value stored under k, or the zero value when k is not set. 81 | func (t *Trie[V]) Get(k string) (v V) { 82 | v, _ = t.TryGet(k) 83 | return 84 | } 85 | // isLeaf reports whether the node has no children. 86 | func (t *Trie[V]) isLeaf() bool { 87 | return len(t.children) == 0 88 | } 89 | 90 | // recogNode records one step of the path walked by Delete: the rune taken and the node it was taken from. 91 | type recogNode[V any] struct { 92 | r rune 93 | n *Trie[V] 94 | } 95 | 96 | // Delete removes k; deleting a key that is absent or only a prefix of other keys is a no-op. 97 | // It records each (rune, parent) pair on the way down and, if the deleted node is a leaf, walks back up 98 | // pruning nodes that became childless and hold no value. (Alternative: a per-node parent pointer instead of the path.) 99 | func (t *Trie[V]) Delete(k string) { 100 | recog := make([]recogNode[V], 0, utf8.RuneCountInString(k)) 101 | 102 | var v V 103 | n := t 104 | 105 | for _, r := range k { 106 | recog = append(recog, recogNode[V]{r, n}) 107 | n = n.children[r] 108 | if n == nil { 109 | return 110 | } 111 | } 112 | 113 | // k is only a prefix of other keys: nothing is stored here, so Len must not change. 114 | if !n.isSet { 115 | return 116 | } 117 | 118 | n.v = v 119 | n.isSet = false 120 | // Len reads t.length, so the count is decremented on the receiver, not on the deleted node. 121 | t.length-- 122 | 123 | if !n.isLeaf() { 124 | return 125 | } 126 | // Walk back up removing the now-dangling child link; stop at the first ancestor that still has children or stores a value. 127 | for last := len(recog) - 1; last >= 0; last-- { 128 | p := recog[last].n 129 | delete(p.children, recog[last].r) 130 | 131 | if !p.isLeaf() { 132 | return 133 | } 134 | 135 | if p.isSet { 136 | return 137 | } 138 | } 139 | } 140 | 141 | // Len returns the number of keys currently stored. 142 | func (t *Trie[V]) Len() int { 143 | return t.length 144 | } 145 | -------------------------------------------------------------------------------- /trie/trie_map_test.go: -------------------------------------------------------------------------------- 1 | package trie 2 | 3 | import ( 4 | "fmt" 5 | "testing" 6 | ) 7 | 8 | // Every key written with Set must be readable with Get. 9 | func Test_TrieMap_SetGet(t *testing.T) { 10 | 11 | tm := New[string]() 12 | max
:= 1000 13 | 14 | for i := 1; i < max; i++ { 15 | 16 | key := fmt.Sprint(i) 17 | 18 | tm.Set(key, key) 19 | val := tm.Get(key) 20 | if key != val { 21 | t.Errorf("expected %s, got %s", key, val) 22 | } 23 | } 24 | } 25 | 26 | // HasPrefix must find every prefix of a stored key. 27 | func Test_TrieMap_HasPrefix(t *testing.T) { 28 | tm := New[string]() 29 | key := "/hello/world" 30 | tm.Set("/hello", "1") 31 | tm.Set("/hello/world", "1") 32 | for i := 1; i < len(key); i++ { 33 | 34 | if !tm.HasPrefix(key[:i]) { 35 | t.Errorf("expected true for prefix %s", key[:i]) 36 | } 37 | } 38 | } 39 | 40 | // HasPrefix must reject a path that diverges from every stored key. 41 | func Test_TrieMap_HasPrefix_notFound(t *testing.T) { 42 | tm := New[string]() 43 | key := "/hello/world" 44 | tm.Set("/hello", "1") 45 | tm.Set("/hello/world", "1") 46 | for i := 1; i < len(key); i++ { 47 | 48 | if !tm.HasPrefix(key[:i]) { 49 | t.Errorf("expected true for prefix %s", key[:i]) 50 | } 51 | } 52 | 53 | if tm.HasPrefix("/ha") { 54 | t.Errorf("expected false for prefix /ha") 55 | } 56 | } 57 | // TryGet must report false both for a missing path and for a prefix that was never set. 58 | func Test_TrieMap_TryGet_notFound(t *testing.T) { 59 | tm := New[string]() 60 | key := "/hello/world" 61 | tm.Set("/hello", "1") 62 | tm.Set("/hello/world", "1") 63 | for i := 1; i < len(key); i++ { 64 | 65 | if !tm.HasPrefix(key[:i]) { 66 | t.Errorf("expected true for prefix %s", key[:i]) 67 | } 68 | } 69 | _, ok := tm.TryGet("/ha") 70 | if ok { 71 | t.Errorf("expected false for /ha") 72 | } 73 | _, ok = tm.TryGet("/he") 74 | if ok { 75 | t.Errorf("expected false for /he") 76 | } 77 | } 78 | // Each key must be gone (ok == false, zero value) right after Delete; deleting a missing key is harmless. 79 | func Test_TrieMap_Delete(t *testing.T) { 80 | 81 | tm := New[string]() 82 | max := 1000 83 | 84 | for i := 1; i < max; i++ { 85 | 86 | key := fmt.Sprint(i) 87 | 88 | tm.Set(key, key) 89 | val := tm.Get(key) 90 | if key != val { 91 | t.Errorf("expected %s, got %s", key, val) 92 | } 93 | tm.Delete(key) 94 | val, ok := tm.TryGet(key) 95 | if ok { 96 | t.Errorf("expected false for key %s", key) 97 | } 98 | if val != "" { 99 | t.Errorf("expected empty string, got %s", val) 100 | } 101 | } 102 |
103 | key := fmt.Sprint(max + 1) 104 | tm.Delete(key) 105 | val, ok := tm.TryGet(key) 106 | if ok { 107 | t.Errorf("expected false for key %s", key) 108 | } 109 | if val != "" { 110 | t.Errorf("expected empty string, got %s", val) 111 | } 112 | } 113 | 114 | // Deleting the longer key must keep the shorter one. 115 | func Test_TrieMap_Delete2(t *testing.T) { 116 | 117 | tm := New[string]() 118 | 119 | tm.Set("/1", "/1") 120 | tm.Set("/12", "/12") 121 | tm.Delete("/12") 122 | if tm.Get("/12") != "" { 123 | t.Errorf("expected empty string for /12, got %s", tm.Get("/12")) 124 | } 125 | if tm.Get("/1") != "/1" { 126 | t.Errorf("expected /1, got %s", tm.Get("/1")) 127 | } 128 | } 129 | 130 | // Deleting the shorter key (a prefix of another) must keep the longer one. 131 | func Test_TrieMap_Delete3(t *testing.T) { 132 | 133 | tm := New[string]() 134 | 135 | tm.Set("/1", "/1") 136 | tm.Set("/12", "/12") 137 | tm.Delete("/1") 138 | if tm.Get("/12") != "/12" { 139 | t.Errorf("expected /12, got %s", tm.Get("/12")) 140 | } 141 | if tm.Get("/1") != "" { 142 | t.Errorf("expected empty string for /1, got %s", tm.Get("/1")) 143 | } 144 | } 145 | 146 | // Same as Delete3 but with multi-byte (Chinese) runes. 147 | func Test_TrieMap_Delete4(t *testing.T) { 148 | 149 | tm := New[string]() 150 | 151 | tm.Set("中", "中") 152 | tm.Set("中国", "中国") 153 | tm.Delete("中") 154 | if tm.Get("中国") != "中国" { 155 | t.Errorf("expected 中国, got %s", tm.Get("中国")) 156 | } 157 | if tm.Get("中") != "" { 158 | t.Errorf("expected empty string for 中, got %s", tm.Get("中")) 159 | } 160 | } 161 | // Deleting both a key and its extension must leave neither behind. 162 | func Test_TrieMap_Delete5(t *testing.T) { 163 | 164 | tm := New[string]() 165 | 166 | tm.Set("中", "中") 167 | tm.Set("中国", "中国") 168 | tm.Delete("中") 169 | tm.Delete("中国") 170 | if tm.Get("中国") != "" { 171 | t.Errorf("expected empty string for 中国, got %s", tm.Get("中国")) 172 | } 173 | if tm.Get("中") != "" { 174 | t.Errorf("expected empty string for 中, got %s", tm.Get("中")) 175 | } 176 | } 177 | // Deleting one of two sibling keys must keep the other sibling and their shared prefix key. 178 | func Test_TrieMap_Delete6(t *testing.T) { 179 | 180 | tm := New[string]() 181 | 182 | tm.Set("/1", "/1") 183 | tm.Set("/12", "/12") 184 | tm.Set("/13", "/13") 185 |
tm.Delete("/12") 186 | if tm.Get("/12") != "" { 187 | t.Errorf("expected empty string for /12, got %s", tm.Get("/12")) 188 | } 189 | if tm.Get("/1") != "/1" { 190 | t.Errorf("expected /1, got %s", tm.Get("/1")) 191 | } 192 | if tm.Get("/13") != "/13" { 193 | t.Errorf("expected /13, got %s", tm.Get("/13")) 194 | } 195 | } 196 | 197 | // Len counts only keys that are set: deleting a missing key or a bare prefix must not 198 | // change it, and deleting a stored key must decrement it exactly once. 199 | func Test_TrieMap_Len(t *testing.T) { 200 | tm := New[string]() 201 | tm.Set("/1", "/1") 202 | tm.Set("/12", "/12") 203 | if tm.Len() != 2 { 204 | t.Fatalf("expected 2, got %d", tm.Len()) 205 | } 206 | 207 | tm.Delete("/") // "/" is only a prefix, never set 208 | tm.Delete("/9") // not present at all 209 | if tm.Len() != 2 { 210 | t.Errorf("expected 2 after no-op deletes, got %d", tm.Len()) 211 | } 212 | 213 | tm.Delete("/12") 214 | tm.Delete("/12") // the second delete of the same key is a no-op 215 | if tm.Len() != 1 { 216 | t.Errorf("expected 1, got %d", tm.Len()) 217 | } 218 | } 219 | -------------------------------------------------------------------------------- /vec/example_search_test.go: -------------------------------------------------------------------------------- 1 | package vec 2 | 3 | // apache 2.0 antlabs 4 | import "fmt" 5 | // NOTE(review): there is no Output comment in the body, so go test compiles this example but never runs it. 6 | func Example_search() { 7 | vec := New(1, 2, 3, 4, 5, 6, 7) 8 | index := vec.SearchFunc(func(e int) bool { 9 | return 7 <= e 10 | }) 11 | 12 | fmt.Println(index) 13 | } 14 | -------------------------------------------------------------------------------- /vec/vec.go: -------------------------------------------------------------------------------- 1 | package vec 2 | 3 | // apache 2.0 antlabs 4 | // Modeled on Rust's Vec; references: 5 | // https://doc.rust-lang.org/src/alloc/vec/mod.rs.html 6 | // https://doc.rust-lang.org/std/vec/struct.Vec.html 7 | 8 | import ( 9 | "errors" 10 | "fmt" 11 | 12 | "github.com/antlabs/gstl/cmp" 13 | ) 14 | 15 | var ( 16 | ErrLenGreaterCap = errors.New("len is too long > length of cap") 17 | ) 18 | 19 | const coefficient = 1.5 20 | 21 | // Vec is a slice-backed dynamic array modeled on Rust's Vec. 22 | type Vec[T any] []T 23 | 24 | // New returns a Vec wrapping the variadic slice a; when called as New(s...), it shares s's backing array. 25 | func New[T any](a ...T) *Vec[T] { 26 | return (*Vec[T])(&a) 27 | } 28 | 29 | // FromSlicePtr reinterprets ptr as a *Vec[T]; changes made through either are visible through the other. 30 | func FromSlicePtr[T any](ptr *[]T) *Vec[T] { 31 | return (*Vec[T])(ptr) 32 | } 33 | 34 | // WithCapacity returns an empty Vec whose backing slice has the given capacity. 35 | func WithCapacity[T any](capacity int) *Vec[T] { 36 | p := make([]T, 0, capacity) 37 | return (*Vec[T])(&p) 38 | } 39 | 40 | // Clear replaces the contents with a new empty slice; the old backing array is dropped, not reused. 41 | // TODO: measure; if this is too slow, consider zeroing the existing elements in place instead. 42 | func (v *Vec[T]) Clear() { 43 | *v = []T{} 44 | } 45 | 46 | // DedupFunc removes consecutive elements that cmp reports equal to the first element of their run, keeping that first element. 47 | // TODO: optimize;
look for a better approach. 48 | func (v *Vec[T]) DedupFunc(cmp func(a, b T) bool) *Vec[T] { 49 | if v.Len() <= 1 { 50 | return v 51 | } 52 | 53 | slice := v.ToSlice() 54 | i := 0 55 | for i < len(slice) { 56 | j := i + 1 57 | for j < len(slice) && cmp(slice[i], slice[j]) { 58 | j++ 59 | } 60 | 61 | if j != i+1 { 62 | copy(slice[i+1:], slice[j:]) 63 | slice = slice[:len(slice)-(j-i-1)] 64 | 65 | //fmt.Printf("i = %d:%v\n", i, slice) 66 | } 67 | i++ 68 | } 69 | 70 | *v = *New(slice...) 71 | return v 72 | } 73 | 74 | // Push appends one or more elements to the end of the Vec 75 | // and returns v for chaining. 76 | func (v *Vec[T]) Push(e ...T) *Vec[T] { 77 | *v = append(*v, e...) 78 | return v 79 | } 80 | 81 | // SetLen reslices the Vec to newLen; it panics with ErrLenGreaterCap if newLen exceeds Cap(). 82 | func (v *Vec[T]) SetLen(newLen int) { 83 | slice := []T(*v) 84 | if newLen > v.Cap() { 85 | panic(ErrLenGreaterCap) 86 | } 87 | 88 | slice = slice[:newLen] 89 | *v = Vec[T](slice) 90 | } 91 | 92 | // Append appends all elements of other to v and returns v. 93 | func (v *Vec[T]) Append(other *Vec[T]) *Vec[T] { 94 | *v = append(*v, other.ToSlice()...) 95 | return v 96 | } 97 | 98 | // PopFront removes and returns the first element; it is a synonym for TakeFirst. 99 | func (v *Vec[T]) PopFront() (e T, ok bool) { 100 | return v.TakeFirst() 101 | } 102 | 103 | // Pop removes and returns the last element; ok is false when the Vec is empty. 104 | func (v *Vec[T]) Pop() (e T, ok bool) { 105 | l := v.Len() 106 | if l == 0 { 107 | return 108 | } 109 | 110 | slice := v.ToSlice() 111 | e = slice[l-1] 112 | slice[l-1] = *new(T) // clear the vacated slot so the backing array does not keep e alive 113 | *v = slice[:l-1] // write through the receiver; assigning to v itself would only change the local pointer 114 | // Shrink the backing array once less than half of its capacity is in use. 115 | if v.Len()*2 < v.Cap() { 116 | newSlice := make([]T, v.Len()) 117 | copy(newSlice, *v) 118 | *v = newSlice 119 | } 120 | 121 | return e, true 122 | } 123 | 124 | // ToSlice returns the underlying slice without copying. 125 | func (v *Vec[T]) ToSlice() []T { 126 | return []T(*v) 127 | } 128 | 129 | // Insert inserts es at index i, shifting the elements from i onward to the right. 130 | // It panics if i > Len(); es may hold one or more values. 131 | func (v *Vec[T]) Insert(i int, es ...T) *Vec[T] { 132 | l := v.Len() 133 | if i == l { 134 | // slice=1 2 3 4 insert(4, 5), result=1 2 3 4 5 135 | v.Push(es...)
136 | return v 137 | } 138 | 139 | if i > l { 140 | panic(fmt.Sprintf("insertion index (is %d) should be <= len (is %d)", i, l)) 141 | } 142 | 143 | need := l + len(es) 144 | if need > v.Cap() { 145 | v.Reserve(len(es)) 146 | } 147 | 148 | slice := v.ToSlice() 149 | 150 | // before insertion: hello world 151 | // after insertion:  hello es world 152 | newSlice := slice[:need] 153 | copy(newSlice[i+len(es):], slice[i:]) // shift the tail right first (copy handles the overlap) 154 | copy(newSlice[i:], es) // then copy es into the gap at i 155 | 156 | // TODO: benchmark whether re-wrapping through New here is slow. 157 | *v = *New(newSlice...) 158 | return v 159 | } 160 | 161 | // Delete removes the elements in the index range [i, j) and returns v. 162 | func (v *Vec[T]) Delete(i, j int) *Vec[T] { 163 | slice := v.ToSlice() 164 | copy(slice[i:], slice[j:]) 165 | *v = *New(slice[:v.Len()-(j-i)]...) 166 | return v 167 | } 168 | 169 | // Get returns the element at index; it panics if index is out of range. 170 | func (v *Vec[T]) Get(index int) (e T) { 171 | slice := v.ToSlice() 172 | return slice[index] 173 | } 174 | 175 | // TryGet returns the element at index, or ok == false (instead of panicking) when index is out of range. 176 | func (v *Vec[T]) TryGet(index int) (e T, ok bool) { 177 | if index < 0 || index >= v.Len() { 178 | return 179 | } 180 | return v.Get(index), true 181 | } 182 | 183 | // GetPtr returns a pointer to the element at index; it aliases the current backing array, so later growth may detach it. 184 | func (v *Vec[T]) GetPtr(index int) (e *T) { 185 | slice := v.ToSlice() 186 | return &slice[index] 187 | } 188 | 189 | // Set stores value at index and returns v for chaining. 190 | func (v *Vec[T]) Set(index int, value T) *Vec[T] { 191 | v.ToSlice()[index] = value 192 | return v 193 | } 194 | 195 | // SwapRemove removes and returns the element at index by moving the last element into its place (order is not preserved). 196 | func (v *Vec[T]) SwapRemove(index int) (rv T) { 197 | l := v.Len() 198 | if index >= l { 199 | panic(fmt.Sprintf("SwapRemove index (is %d) should be < len (is %d)", index, l)) 200 | } 201 | 202 | rv = v.Get(index) 203 | v.Set(index, v.Get(l-1)) 204 | v.SetLen(l - 1) 205 | return 206 | } 207 | 208 | // SplitOff splits the Vec in two at index at: 209 | // it returns a new Vec holding [at, len), while 210 | // the original Vec keeps [0, at) with its capacity unchanged. 211 | func (v *Vec[T]) SplitOff(at int) (new *Vec[T]) { 212 | l := v.Len() 213 | 214 | if at > l { 215 | panic(fmt.Sprintf("`at` split index (is %d) should be <= len (is %d)", at, l)) 216 | } 217 | 218 | if at ==
0 { 219 | v2 := *v 220 | v.Clear() 221 | return &v2 222 | } 223 | 224 | newSlice := make([]T, l-at) 225 | copy(newSlice, v.ToSlice()[at:]) 226 | 227 | *v = *New(v.ToSlice()[:at]...) 228 | return New(newSlice...) 229 | } 230 | 231 | // 删除指定索引的元素 232 | func (v *Vec[T]) Remove(index int) *Vec[T] { 233 | l := v.Len() 234 | if index >= l { 235 | panic(fmt.Sprintf("removal index (is %d) should be < len (is %d)", index, l)) 236 | } 237 | 238 | copy(v.ToSlice()[index:], v.ToSlice()[index+1:]) 239 | v.SetLen(l - 1) 240 | 241 | return v 242 | } 243 | 244 | // 提前在现有基础上再额外申请 additional 长度空间 245 | // 可以避免频繁的重新分配 246 | // 如果容量已经满足, 则什么事也不做 247 | func (v *Vec[T]) Reserve(additional int) *Vec[T] { 248 | return v.reserve(additional, coefficient) 249 | } 250 | 251 | // 如果容量已经满足, 则什么事也不做 252 | // 保留最小容量, 提前在现有基础上再额外申请 additional 长度空间 253 | func (v *Vec[T]) ReserveExact(additional int) *Vec[T] { 254 | return v.reserve(additional, 1) 255 | } 256 | 257 | func (v *Vec[T]) reserve(additional int, factor float64) *Vec[T] { 258 | l := v.Len() 259 | if l+additional <= v.Cap() { 260 | return v 261 | } 262 | 263 | newSlice := make([]T, l, int(float64(l+additional)*factor)) 264 | copy(newSlice, v.ToSlice()) 265 | *v = Vec[T](newSlice) 266 | return v 267 | } 268 | 269 | // 向下收缩vec的容器 270 | func (v *Vec[T]) ShrinkToFit() *Vec[T] { 271 | l := v.Len() 272 | if v.Cap() > getCap(l) { 273 | v.ShrinkTo(l) 274 | } 275 | return v 276 | } 277 | 278 | // 向下收缩vec的容器, 会重新分配底层的slice 279 | func (v *Vec[T]) ShrinkTo(minCapacity int) *Vec[T] { 280 | cap := v.Cap() 281 | minCapacity = getCap(minCapacity) 282 | if cap > minCapacity { 283 | min := cmp.Min(cap, minCapacity) 284 | if min == 0 { 285 | min = int(0.66 * float64(cap)) 286 | } 287 | 288 | newSlice := append([]T{}, v.ToSlice()[:min]...) 
289 | *v = Vec[T](newSlice) 290 | } 291 | return v 292 | } 293 | 294 | // 修改vec可访问的容量, 但是不会修改底层的slice, 只是修改slice的len 295 | func (v *Vec[T]) Truncate(newLen int) { 296 | *v = Vec[T](v.ToSlice()[:newLen]) 297 | } 298 | 299 | // 在vec后面追加newLen 长度的value 300 | func (v *Vec[T]) ExtendWith(newLen int, value T) *Vec[T] { 301 | 302 | oldLen := v.Len() 303 | v.Reserve(newLen) 304 | slice := v.ToSlice() 305 | 306 | l := oldLen + newLen 307 | slice = slice[:l] 308 | 309 | for i := oldLen; i < l; i++ { 310 | slice[i] = value 311 | } 312 | *v = Vec[T](slice) 313 | 314 | return v 315 | } 316 | 317 | // 调整vec的大小, 使用len等于newLen 318 | // 如果newLen > len, 差值部分会填充value 319 | // 如果newLen < len, 多余的部分会被截断 320 | func (v *Vec[T]) Resize(newLen int, value T) *Vec[T] { 321 | l := v.Len() 322 | if newLen > l { 323 | v.ExtendWith(newLen-l, value) 324 | return v 325 | } 326 | 327 | v.Truncate(newLen) 328 | return v 329 | } 330 | 331 | // 深度拷贝一份 332 | func (v *Vec[T]) Clone() *Vec[T] { 333 | newSlice := make([]T, v.Len()) 334 | copy(newSlice, v.ToSlice()) 335 | return (*Vec[T])(&newSlice) 336 | } 337 | 338 | // 如果为空 339 | func (v *Vec[T]) IsEmpty() bool { 340 | return len(v.ToSlice()) == 0 341 | } 342 | 343 | // len 344 | func (v *Vec[T]) Len() int { 345 | return len(v.ToSlice()) 346 | } 347 | 348 | // cap 349 | func (v *Vec[T]) Cap() int { 350 | return cap(v.ToSlice()) 351 | } 352 | 353 | // 返回第1个元素 354 | func (v *Vec[T]) First() (n T, ok bool) { 355 | if v.Len() == 0 { 356 | return 357 | } 358 | 359 | return v.Get(0), true 360 | } 361 | 362 | // 删除vec第一个元素, 并返回它 363 | func (v *Vec[T]) TakeFirst() (n T, ok bool) { 364 | if v.Len() == 0 { 365 | return 366 | } 367 | 368 | n = v.Get(0) 369 | v.Remove(0) 370 | return n, true 371 | } 372 | 373 | // 返回最后一个元素 374 | func (v *Vec[T]) Last() (n T, ok bool) { 375 | if v.Len() == 0 { 376 | return 377 | } 378 | 379 | return v.Get(v.Len() - 1), true 380 | } 381 | 382 | // 原地操作, 回调函数会返回的元素值 383 | func (v *Vec[T]) Map(m func(e T) T) *Vec[T] { 384 | 385 | l := 
v.Len() 386 | 387 | slice := v.ToSlice() 388 | for i := 0; i < l; i++ { 389 | slice[i] = m(slice[i]) 390 | } 391 | 392 | return v 393 | } 394 | 395 | // Retain 是Filter函数的同义词 396 | func (v *Vec[T]) Retain(filter func(e T) bool) *Vec[T] { 397 | return v.Filter(filter) 398 | } 399 | 400 | // 原地操作, 回调函数返回true的元素保留 401 | func (v *Vec[T]) Filter(filter func(e T) bool) *Vec[T] { 402 | 403 | l := v.Len() 404 | left := 0 405 | 406 | slice := v.ToSlice() 407 | for i := 0; i < l; i++ { 408 | if filter(slice[i]) { 409 | if left != i { 410 | slice[left] = slice[i] 411 | } 412 | left++ 413 | } 414 | } 415 | v.SetLen(left) 416 | return v 417 | } 418 | 419 | // 原地旋转vec, 向左边旋转 420 | func (v *Vec[T]) RotateLeft(n int) *Vec[T] { 421 | l := v.Len() 422 | n %= l 423 | 424 | if n == 0 { 425 | return v 426 | } 427 | 428 | slice := v.ToSlice() 429 | left := make([]T, n) 430 | // 先备份左边 431 | copy(left, slice[:n]) 432 | // 备下的往左拷贝 433 | copy(slice, slice[n:]) 434 | // 右边需要被替换的空间 435 | copy(slice[l-n:], left) 436 | 437 | return v 438 | } 439 | 440 | // 反转 441 | func (v *Vec[T]) Rev() *Vec[T] { 442 | slice := v.ToSlice() 443 | for i, l := 0, v.Len()-1; i < l; i, l = i+1, l-1 { 444 | slice[i], slice[l] = slice[l], slice[i] 445 | } 446 | 447 | *v = *New(slice...) 
448 | return v 449 | } 450 | 451 | // 原地旋转vec, 向右边旋转 452 | func (v *Vec[T]) RotateRight(n int) *Vec[T] { 453 | l := v.Len() 454 | n %= l 455 | if n == 0 { 456 | return v 457 | } 458 | 459 | at := l - n 460 | slice := v.ToSlice() 461 | rightVec := make([]T, n) 462 | copy(rightVec, slice[at:]) 463 | 464 | for right, left := l-1, at-1; right >= 0 && left >= 0; { 465 | slice[right] = slice[left] 466 | right-- 467 | left-- 468 | } 469 | 470 | copy(slice[:n], rightVec) 471 | return v 472 | } 473 | 474 | // 用于写入重复的值, 返回新的内存块, 来创建新的vec 475 | func (v *Vec[T]) Repeat(count int) *Vec[T] { 476 | need := v.Len() * count 477 | rv := WithCapacity[T](need) 478 | 479 | for i := 0; i < count; i++ { 480 | rv.Append(v) 481 | } 482 | 483 | return rv 484 | } 485 | 486 | // 二分搜索 487 | func (v *Vec[T]) SearchFunc(f func(T) bool) int { 488 | 489 | // Define f(-1) == false and f(n) == true. 490 | // Invariant: f(i-1) == false, f(j) == true. 491 | i, j := 0, v.Len() 492 | for i < j { 493 | h := int(uint(i+j) >> 1) // avoid overflow when computing h 494 | // i ≤ h < j 495 | if !f(v.Get(h)) { 496 | i = h + 1 // preserves f(i-1) == false 497 | } else { 498 | j = h // preserves f(j) == true 499 | } 500 | } 501 | 502 | // i == j, f(i-1) == false, and f(j) (= f(i)) == true => answer is i. 
503 | return i 504 | } 505 | 506 | // 遍历, callback 返回false就停止遍历, 返回true继续遍历 507 | func (v *Vec[T]) Range(callback func(index int, v T) bool) *Vec[T] { 508 | slice := v.ToSlice() 509 | for i, val := range slice { 510 | if !callback(i, val) { 511 | return v 512 | } 513 | } 514 | return v 515 | } 516 | 517 | func getCap(l int) int { 518 | return int(float64(l) * coefficient) 519 | } 520 | -------------------------------------------------------------------------------- /vecdeque/vecdeque.go: -------------------------------------------------------------------------------- 1 | package vecdeque 2 | 3 | // apache 2.0 antlabs 4 | import ( 5 | "errors" 6 | "fmt" 7 | "math" 8 | 9 | "github.com/antlabs/gstl/cmp" 10 | ) 11 | 12 | // 参考文档如下 13 | // https://doc.rust-lang.org/std/collections/struct.VecDeque.html 14 | // https://doc.rust-lang.org/src/alloc/collections/vec_deque/mod.rs.html 15 | // https://doc.rust-lang.org/beta/src/alloc/collections/vec_deque/ring_slices.rs.html 16 | // 翻译好的中文文档 17 | // https://rustwiki.org/zh-CN/src/alloc/collections/vec_deque/mod.rs.html 18 | 19 | const ( 20 | INITIAL_CAPACITY uint = 7 // 2^3 - 1 21 | MINIMUM_CAPACITY uint = 1 // 2 - 1 22 | ) 23 | 24 | var ( 25 | ErrNoData = errors.New("no data") 26 | ) 27 | 28 | type VecDeque[T any] struct { 29 | // tail 总是指向可以读取的第一个元素 30 | // head 只是指向应该写入数据的位置 31 | // 如果tail == head, 则缓存区为空. 环形缓冲区的长度定义为两者之间的距离 32 | tail uint 33 | head uint 34 | buf []T 35 | } 36 | 37 | // 初始化 38 | func New[T any]() *VecDeque[T] { 39 | return WithCapacity[T](int(INITIAL_CAPACITY)) 40 | } 41 | 42 | // 初始VecDeque, 并设置实际需要的容量 43 | func WithCapacity[T any](capacity int) *VecDeque[T] { 44 | cap := nextPowOfTwo(cmp.Max(uint(capacity)+1, MINIMUM_CAPACITY+1)) 45 | return &VecDeque[T]{buf: make([]T, cap, cap)} 46 | } 47 | 48 | // 如果缓冲区满了. 
就返回true 49 | func (v *VecDeque[T]) IsFull() bool { 50 | return v.Cap()-v.Len() == 1 51 | } 52 | 53 | // 返回当前使用的容量 54 | func (v *VecDeque[T]) Len() int { 55 | return int(count(v.tail, v.head, uint(v.cap()))) 56 | } 57 | 58 | // 统计数据 59 | func count(tail, head, size uint) uint { 60 | // 结果和 math.Abs(head - tail) & (size -1) 一样 61 | return (head - tail) & (size - 1) 62 | } 63 | 64 | // 扩容 65 | func (v *VecDeque[T]) grow() *VecDeque[T] { 66 | if v.IsFull() { 67 | oldCap := v.cap() 68 | newBuf := make([]T, oldCap*2) 69 | copy(newBuf, v.buf) 70 | v.buf = newBuf 71 | v.handleCapIncrease(uint(oldCap)) 72 | } 73 | return v 74 | } 75 | 76 | // 扩容 77 | func (v *VecDeque[T]) handleCapIncrease(oldCap uint) { 78 | // Move the shortest contiguous section of the ring buffer 79 | // T H 80 | // [o o o o o o o . ] 81 | // T H 82 | // A [o o o o o o o . . . . . . . . . ] 83 | // H T 84 | // [o o . o o o o o ] 85 | // T H 86 | // B [. . . o o o o o o o . . . . . . ] 87 | // H T 88 | // [o o o o o . o o ] 89 | // H T 90 | // C [o o o o o . . . . . . . . . o o ] 91 | if v.tail <= v.head { 92 | // 不需要做啥 93 | return 94 | } 95 | 96 | // 把前面的数据移到后面, 合并起来, 中间没有空隙 97 | if v.head < oldCap-v.tail { 98 | copy(v.buf[oldCap:], v.buf[:v.head]) 99 | v.head += oldCap 100 | return 101 | } 102 | 103 | // 把老的cap右边的数据放到新的cap的最右端 104 | newTail := oldCap + v.tail 105 | copy(v.buf[newTail:], v.buf[v.tail:oldCap]) 106 | v.tail = newTail 107 | } 108 | 109 | // 判断VecDeque 110 | func (v *VecDeque[T]) IsEmpty() bool { 111 | return v.tail == v.head 112 | } 113 | 114 | // 删除最后一个元素, 并且返回它. 
如果为空, 返回ErrNoData 115 | func (v *VecDeque[T]) PopBack() (value T, err error) { 116 | if v.IsEmpty() { 117 | err = ErrNoData 118 | return 119 | } 120 | 121 | v.head = v.wrapSub(v.head, 1) 122 | value = v.buf[v.head] 123 | return 124 | } 125 | 126 | // 删除第一个元素, 并且返回它, 如果为空, 返回ErrNoData 127 | func (v *VecDeque[T]) PopFront() (value T, err error) { 128 | if v.IsEmpty() { 129 | err = ErrNoData 130 | return 131 | } 132 | 133 | value = v.buf[v.tail] 134 | v.tail = v.wrapAdd(v.tail, 1) 135 | return 136 | } 137 | 138 | // 将一个元素添加到VecDeque 后面 139 | func (v *VecDeque[T]) PushBack(value T) { 140 | 141 | // 先检查是否满了 142 | if v.IsFull() { 143 | // 满了就扩容 144 | v.grow() 145 | } 146 | 147 | head := v.head 148 | 149 | // 修改head的值 150 | v.head = v.wrapAdd(v.head, uint(1)) 151 | 152 | v.buf[head] = value 153 | // 修改head值 154 | } 155 | 156 | // 将一个元素添加到VecDeque的前面 157 | func (v *VecDeque[T]) PushFront(value T) { 158 | if v.IsFull() { 159 | v.grow() 160 | } 161 | 162 | v.tail = v.wrapSub(v.tail, 1) 163 | v.buf[v.tail] = value 164 | } 165 | 166 | // 根据索引获取指定的值 167 | func (v *VecDeque[T]) Get(i uint) T { 168 | idx := v.wrapAdd(v.tail, uint(i)) 169 | return v.buf[idx] 170 | } 171 | 172 | // 内存里面的物理容量 173 | func (v *VecDeque[T]) cap() int { 174 | return len(v.buf) 175 | } 176 | 177 | // 业务意义上的容量, 有一个格式是空的 178 | func (v *VecDeque[T]) Cap() int { 179 | return v.cap() - 1 180 | } 181 | 182 | // 对index 减去一些值 183 | func (v *VecDeque[T]) wrapSub(index uint, subtrahend uint) uint { 184 | return v.wrapIndex(index - subtrahend) 185 | } 186 | 187 | // 对index 增加一些值 188 | func (v *VecDeque[T]) wrapAdd(index uint, addend uint) uint { 189 | return v.wrapIndex(index + addend) 190 | } 191 | 192 | // 操作index的包装函数 193 | func (v *VecDeque[T]) wrapIndex(index uint) uint { 194 | return wrapIndex(index, uint(v.cap())) 195 | } 196 | 197 | // 操作index的核心函数 198 | func wrapIndex(index uint, size uint) uint { 199 | // 判断size是否是2的n次方 200 | if n := (size & (size - 1)); n != 0 { 201 | panic(fmt.Sprintf("size is always a 
power of 2, the current size is %d", size)) 202 | } 203 | 204 | return index & (size - 1) 205 | } 206 | 207 | // TODO 优化下 208 | // 使用更好的算法计算 209 | func nextPowOfTwo(n uint) uint { 210 | 211 | for i := 1; i < 32; i++ { 212 | 213 | if nextPowOfTwoNum := math.Pow(2, float64(i)); nextPowOfTwoNum > float64(n) { 214 | return uint(nextPowOfTwoNum) 215 | } 216 | } 217 | 218 | return 0 219 | } 220 | 221 | // 交换索引为i和j的元素 222 | func (v *VecDeque[T]) Swap(i, j uint) { 223 | ri := v.wrapAdd(v.tail, i) 224 | rj := v.wrapAdd(v.tail, j) 225 | v.buf[ri], v.buf[rj] = v.buf[rj], v.buf[ri] 226 | } 227 | 228 | // 向左旋转 229 | func (v *VecDeque[T]) RotateLeftInner(k uint) { 230 | v.head = v.wrapAdd(v.head, k) 231 | v.tail = v.wrapAdd(v.tail, k) 232 | } 233 | 234 | // 向左旋转 235 | func (v *VecDeque[T]) RotateLeft(k uint) { 236 | other := uint(v.Len()) - k 237 | 238 | if k <= other { 239 | v.RotateLeftInner(k) 240 | return 241 | } 242 | 243 | v.RotateRightInner(other) 244 | } 245 | 246 | // 向右旋转 247 | func (v *VecDeque[T]) RotateRightInner(k uint) { 248 | //v.wrapCopy() 249 | v.head = v.wrapSub(v.head, k) 250 | v.tail = v.wrapSub(v.tail, k) 251 | } 252 | 253 | // 向右旋转 254 | func (v *VecDeque[T]) RotateRight(k uint) { 255 | other := uint(v.Len()) - k 256 | if k <= other { 257 | // k = k 258 | // other = o 259 | // kkkkkkkkkkooo 260 | v.RotateRightInner(k) 261 | return 262 | } 263 | 264 | v.RotateLeftInner(other) 265 | 266 | } 267 | 268 | // 尽可能缩小VecDeque的容量 269 | // 它将尽可能接近Len的位置 270 | func (v *VecDeque[T]) ShrinkToFit() { 271 | v.ShrinkTo(0) 272 | } 273 | 274 | // 缩容 275 | func (v *VecDeque[T]) ShrinkTo(minCapacity uint) { 276 | minCapacity = cmp.Min(minCapacity, uint(v.Cap())) 277 | minCapacity = cmp.Max(minCapacity, uint(v.Len())) 278 | targetCap := nextPowOfTwo(cmp.Max(minCapacity+1, MINIMUM_CAPACITY+1)) 279 | 280 | if targetCap < uint(v.cap()) { 281 | 282 | //有三种情况值得关注: 283 | 284 | //所有元素都超出了预期范围 285 | 286 | //元素是连续的,head超出了所需的边界 287 | 288 | //元素是不连续的,尾部超出了期望的界限 289 | 290 | // 291 | 292 | 
//在所有其他时间,元素位置不受影响。 293 | 294 | // 295 | 296 | //指示应移动头部的元素。 297 | headOutside := v.head == 0 || v.head >= targetCap 298 | if v.tail >= targetCap && headOutside { 299 | // T H 300 | // [. . . . . . . . o o o o o o o . ] 301 | // T H 302 | // [o o o o o o o . ] 303 | copy(v.buf, v.buf[v.tail:v.head]) 304 | v.tail = 0 305 | v.head = uint(v.Len()) 306 | } else if v.tail != 0 && v.tail < targetCap && headOutside { 307 | 308 | // T H 309 | // [. . . o o o o o o o . . . . . . ] 310 | // H T 311 | // [o o . o o o o o ] 312 | length := v.wrapSub(v.head, targetCap) 313 | copy(v.buf, v.buf[targetCap:v.head]) 314 | v.head = length 315 | } else if v.tail >= targetCap { 316 | 317 | // H T 318 | // [o o o o o . . . . . . . . . o o ] 319 | // H T 320 | // [o o o o o . o o ] 321 | length := uint(len(v.buf)) - v.tail 322 | newTail := targetCap - length 323 | copy(v.buf[newTail:], v.buf[v.tail:]) 324 | v.tail = newTail 325 | } 326 | 327 | newBuf := make([]T, targetCap) 328 | copy(newBuf, v.buf) 329 | v.buf = newBuf 330 | } 331 | } 332 | 333 | func (v *VecDeque[T]) Truncate() { 334 | 335 | } 336 | 337 | func (v *VecDeque[T]) ToSlices() (first []T, second []T) { 338 | return 339 | } 340 | 341 | func (v *VecDeque[T]) wrapCopy(dst, src, length uint) { 342 | if src == dst || length == 0 { 343 | return 344 | } 345 | 346 | } 347 | 348 | func (v *VecDeque[T]) ReserveExact() { 349 | 350 | } 351 | 352 | func (v *VecDeque[T]) Reserve() { 353 | 354 | } 355 | 356 | func (v *VecDeque[T]) Contains(x T) bool { 357 | return false 358 | } 359 | 360 | // 获取第1个元素, 第二个参数返回错误 361 | func (v *VecDeque[T]) Front() (e T, err error) { 362 | if v.Len() == 0 { 363 | err = ErrNoData 364 | return 365 | } 366 | 367 | return v.Get(0), nil 368 | } 369 | 370 | // 获取最后一个元素, 第二个参数返回错误 371 | func (v *VecDeque[T]) Back() (e T, err error) { 372 | if v.Len() == 0 { 373 | err = ErrNoData 374 | return 375 | } 376 | 377 | newIndex := v.wrapSub(uint(v.Len()), uint(1)) 378 | return v.Get(newIndex), nil 379 | } 380 | 381 | // 从 
`VecDeque` 的任何位置删除一个元素并返回,并用第一个元素替换它。 382 | func (v *VecDeque[T]) SwapRemoveFront(index uint) (e T, err error) { 383 | length := uint(v.Len()) 384 | 385 | if index >= length { 386 | err = ErrNoData 387 | return 388 | } 389 | 390 | if length > 0 && index < length && index != 0 { 391 | v.Swap(index, 0) 392 | } 393 | 394 | return v.PopFront() 395 | 396 | } 397 | 398 | func (v *VecDeque[T]) SwapRemoveBack() { 399 | 400 | } 401 | 402 | // 在VecDeque内的index处插入一个元素, 所有索引大于或者等于'index'的元素向后移动 403 | // TODO 404 | func (v *VecDeque[T]) Insert(index uint, value T) { 405 | if v.IsFull() { 406 | v.grow() 407 | } 408 | 409 | // 移动环形缓冲区中最少的元素并插入 410 | // 给定对象 411 | // 412 | // 最多会移动len/2-1元素。O(min(n, n-i)) 413 | // 414 | // 主要有三种情况: 415 | // 元素是连续的 416 | // -尾部为0时的特殊情况 417 | // 元素不连续,插入部分位于尾部 418 | // 元素不连续,插入部分位于头部 419 | // 420 | // 对于每一种情况,还有两种情况: 421 | // 插入物更靠近尾部 422 | // 插入物更靠近头部 423 | // 424 | // key:H - v.head 425 | // T - v.tail 426 | // o - 有效元素 427 | // I - 插入元素 428 | // A - 应位于插入点之后的元素 429 | // M - 表示元素已移动 430 | 431 | //idx := v.wrapAdd(v.tail, index) 432 | distanceToTail := index 433 | distanceToHead := uint(v.Len()) - index 434 | contiguous := v.isContiguous() 435 | 436 | if contiguous && distanceToTail < distanceToHead { 437 | 438 | if index == 0 { 439 | // push_front 440 | // 441 | // T I H 442 | // [A o o o o o o . . . . . . . 443 | // . 444 | // .] 445 | // 446 | // H T 447 | // [A o o o o o o o . . . . . 
I] 448 | 449 | v.tail = v.wrapSub(v.tail, 1) 450 | } else { 451 | 452 | } 453 | } 454 | } 455 | 456 | func (v *VecDeque[T]) Remove(index int) { 457 | 458 | } 459 | 460 | func (v *VecDeque[T]) SplitOff() { 461 | 462 | } 463 | 464 | func (v *VecDeque[T]) Append(other *VecDeque[T]) { 465 | 466 | } 467 | 468 | func (v *VecDeque[T]) Retain() { 469 | 470 | } 471 | 472 | func (v *VecDeque[T]) ResizeWith() { 473 | 474 | } 475 | 476 | func (v *VecDeque[T]) isContiguous() bool { 477 | return v.tail <= v.head 478 | } 479 | 480 | func (v *VecDeque[T]) MakeContiguous() []T { 481 | if v.isContiguous() { 482 | return v.buf[v.tail:v.head] 483 | } 484 | 485 | cap := uint(v.cap()) //取出物理容量 486 | length := uint(v.Len()) // 取出已存元素个数 487 | free := v.tail - v.head // 空间的空间个数 488 | tailLen := cap - v.tail // tail到右顶边的个数 489 | 490 | if free >= tailLen { 491 | // 有足够的可用空间来一次性复制尾部,这意味着我们先将头向后移动,然后再将尾部复制到正确的位置。 492 | // 493 | // 494 | // 从: DEFGH....ABC 到: ABCDEFGH.... 495 | 496 | // ...DEFGH.ABC 497 | copy(v.buf[tailLen:], v.buf[:v.head]) 498 | // ABCDEFGH.... 499 | copy(v.buf, v.buf[v.tail:]) 500 | v.tail = 0 501 | v.head = length 502 | return v.buf[:v.head] 503 | } 504 | 505 | if free > v.head { 506 | // 有足够的自由空间可以一次性复制头部,这意味着我们先将尾部向前移动,然后再将头部复制到正确的位置。 507 | // 508 | // 509 | // 从: FGH....ABCDE 到: ...ABCDEFGH。 510 | // 511 | // 512 | 513 | // FGHABCDE.... 514 | copy(v.buf[v.head:], v.buf[v.tail:]) 515 | 516 | // ...ABCDEFGH. 517 | copy(v.buf[v.head+v.tail:], v.buf[:v.head]) 518 | v.tail = v.head 519 | v.head = v.wrapAdd(v.tail, length) 520 | } 521 | 522 | // free 小于头和尾,这意味着我们必须缓慢地 "swap" 尾和头。 523 | // 524 | // 从: EFGHI...ABCD 或 HIJK.ABCDEFG 525 | // 到: ABCDEFGHI... 或 ABCDEFGHIJK. 
526 | leftEdge := uint(0) 527 | rightEdge := v.tail 528 | 529 | // The general problem looks like this 530 | // GHIJKLM...ABCDEF - before any swaps 531 | // ABCDEFM...GHIJKL - after 1 pass of swaps 532 | // ABCDEFGHIJM...KL - swap until the left edge reaches the temp store 533 | // - then restart the algorithm with a new (smaller) store 534 | // Sometimes the temp store is reached when the right edge is at the end 535 | // of the buffer - this means we've hit the right order with fewer swaps! 536 | // E.g 537 | // EF..ABCD 538 | // ABCDEF.. - after four only swaps we've finished 539 | 540 | // TODO 再仔细捋一捋逻辑 541 | for leftEdge < length && rightEdge != cap { 542 | rightOffset := uint(0) 543 | for i := leftEdge; i < rightEdge; i++ { 544 | rightOffset = (i - leftEdge) % (cap - rightEdge) 545 | src := (rightEdge + rightOffset) 546 | v.buf[i], v.buf[src] = v.buf[src], v.buf[i] 547 | } 548 | 549 | nOps := rightEdge - leftEdge 550 | leftEdge += nOps 551 | rightEdge += rightOffset + 1 552 | } 553 | 554 | v.tail = 0 555 | v.head = length 556 | 557 | return v.buf[v.tail:v.head] 558 | } 559 | 560 | func (v *VecDeque[T]) BinarySearch() { 561 | 562 | } 563 | -------------------------------------------------------------------------------- /vecdeque/vecdeque_test.go: -------------------------------------------------------------------------------- 1 | package vecdeque 2 | 3 | // apache 2.0 antlabs 4 | import "testing" 5 | 6 | func Test_PushBack(t *testing.T) { 7 | v := New[int]() 8 | 9 | max := 100 10 | need := make([]int, 0, max) 11 | got := make([]int, 0, max) 12 | 13 | for i := 0; i < max; i++ { 14 | need = append(need, i) 15 | 16 | v.PushBack(i) 17 | v2, err := v.PopFront() 18 | if err != nil { 19 | break 20 | } 21 | got = append(got, v2) 22 | } 23 | } 24 | --------------------------------------------------------------------------------