├── go.mod ├── README.md ├── LICENSE ├── omap.go ├── bench_test.go ├── treap.go ├── omap_test.go ├── avl.go └── llrb.go /go.mod: -------------------------------------------------------------------------------- 1 | module rsc.io/omap 2 | 3 | go 1.23.0 4 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | Package omap implements an ordered map[K]V. 2 | See the [API reference](https://pkg.go.dev/rsc.io/omap). 3 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2009 The Go Authors. All rights reserved. 2 | 3 | Redistribution and use in source and binary forms, with or without 4 | modification, are permitted provided that the following conditions are 5 | met: 6 | 7 | * Redistributions of source code must retain the above copyright 8 | notice, this list of conditions and the following disclaimer. 9 | * Redistributions in binary form must reproduce the above 10 | copyright notice, this list of conditions and the following disclaimer 11 | in the documentation and/or other materials provided with the 12 | distribution. 13 | * Neither the name of Google Inc. nor the names of its 14 | contributors may be used to endorse or promote products derived from 15 | this software without specific prior written permission. 16 | 17 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 18 | "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 19 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 20 | A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT 21 | OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 22 | SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 23 | LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 24 | DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 25 | THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 26 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 27 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 28 | -------------------------------------------------------------------------------- /omap.go: -------------------------------------------------------------------------------- 1 | // Copyright 2024 The Go Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | // Package omap implements in-memory ordered maps. 6 | // [Map][K, V] is suitable for ordered types K, 7 | // while [MapFunc][K, V] supports arbitrary keys and comparison functions. 8 | package omap 9 | 10 | // The implementation is a treap. See: 11 | // https://en.wikipedia.org/wiki/Treap 12 | // https://faculty.washington.edu/aragon/pubs/rst89.pdf 13 | 14 | import "cmp" 15 | 16 | // A Map is a map[K]V ordered according to K's standard Go ordering. 17 | // The zero value of a Map is an empty Map ready to use. 
18 | type Map[K cmp.Ordered, V any] struct { 19 | treapMap[K, V] 20 | } 21 | 22 | func (m *Map[K, V]) Join(more *Map[K, V]) { 23 | m.join(nil, more.treap) 24 | more.root = nil 25 | } 26 | 27 | func (m *Map[K, V]) Split(key K) (val V, ok bool, more *Map[K, V]) { 28 | x, after := m.split(key) 29 | if x != nil { 30 | val, ok = x.val, true 31 | } 32 | more = &Map[K, V]{treapMap[K, V]{after}} 33 | return val, ok, more 34 | } 35 | 36 | type MapFunc[K, V any] struct { 37 | treapMapFunc[K, V] 38 | } 39 | 40 | func NewMapFunc[K, V any](cmp func(K, K) int) *MapFunc[K, V] { 41 | m := new(MapFunc[K, V]) 42 | m.cmp = cmp 43 | return m 44 | } 45 | 46 | func (m *MapFunc[K, V]) Join(more *MapFunc[K, V]) { 47 | m.join(nil, more.treap) 48 | more.root = nil 49 | } 50 | 51 | func (m *MapFunc[K, V]) Split(key K) (val V, ok bool, more *MapFunc[K, V]) { 52 | x, after := m.split(key) 53 | if x != nil { 54 | val, ok = x.val, true 55 | } 56 | more = &MapFunc[K, V]{treapMapFunc[K, V]{after, m.cmp}} 57 | return val, ok, more 58 | } 59 | -------------------------------------------------------------------------------- /bench_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2024 The Go Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 
4 | 5 | package omap 6 | 7 | import ( 8 | "math/rand/v2" 9 | "testing" 10 | ) 11 | 12 | var getMap, getMapSeq *Map[int, int] 13 | 14 | func benchMaps(b *testing.B, bench func(b *testing.B, newMap func() Mapper[int, int])) { 15 | for _, m := range maps { 16 | b.Run(m.name, func(b *testing.B) { bench(b, m.new) }) 17 | } 18 | } 19 | 20 | func BenchmarkGetRandRand(b *testing.B) { 21 | benchMaps(b, func(b *testing.B, newMap func() Mapper[int, int]) { 22 | const N = 100000 23 | m := newMap() 24 | rand := rand.New(rand.NewPCG(1, 1)) 25 | perm := rand.Perm(N) 26 | for _, v := range rand.Perm(N) { 27 | m.Set(v, v) 28 | } 29 | //b.Logf("depth=%v", m.Depth()) 30 | perm = rand.Perm(N) 31 | b.ResetTimer() 32 | n := 0 33 | for range b.N { 34 | m.Get(perm[n]) 35 | n++ 36 | if n == N { 37 | n = 0 38 | } 39 | } 40 | }) 41 | } 42 | 43 | func BenchmarkGetSeqRand(b *testing.B) { 44 | benchMaps(b, func(b *testing.B, newMap func() Mapper[int, int]) { 45 | const N = 100000 46 | rand := rand.New(rand.NewPCG(1, 1)) 47 | m := newMap() 48 | for v := range N { 49 | m.Set(v, v) 50 | } 51 | //b.Logf("depth=%v", m.Depth()) 52 | perm := rand.Perm(N) 53 | b.ResetTimer() 54 | n := 0 55 | for range b.N { 56 | m.Get(perm[n]) 57 | n++ 58 | if n == N { 59 | n = 0 60 | } 61 | } 62 | }) 63 | } 64 | 65 | func BenchmarkSetDelete(b *testing.B) { 66 | benchMaps(b, func(b *testing.B, newMap func() Mapper[int, int]) { 67 | const N = 100000 68 | perm := rand.Perm(N) 69 | perm2 := rand.Perm(N) 70 | m := newMap() 71 | b.ResetTimer() 72 | n := 0 73 | for range b.N { 74 | if n < N { 75 | m.Set(perm[n], perm[n]) 76 | } else { 77 | m.Delete(perm2[n-N]) 78 | } 79 | n++ 80 | if n == 2*N { 81 | n = 0 82 | } 83 | } 84 | }) 85 | } 86 | -------------------------------------------------------------------------------- /treap.go: -------------------------------------------------------------------------------- 1 | // Copyright 2024 The Go Authors. All rights reserved. 
2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package omap 6 | 7 | import ( 8 | "bytes" 9 | "cmp" 10 | "fmt" 11 | "iter" 12 | "math/rand/v2" 13 | ) 14 | 15 | type treapMap[K cmp.Ordered, V any] struct { 16 | treap[K, V] 17 | } 18 | 19 | type treapMapFunc[K, V any] struct { 20 | treap[K, V] 21 | cmp func(K, K) int 22 | } 23 | 24 | func (t *treapMapFunc[K, V]) init(cmp func(K, K) int) { 25 | t.cmp = cmp 26 | } 27 | 28 | type treap[K, V any] struct { 29 | root *treapNode[K, V] 30 | // Rotates int 31 | } 32 | 33 | type treapNode[K, V any] struct { 34 | parent *treapNode[K, V] 35 | left *treapNode[K, V] 36 | right *treapNode[K, V] 37 | key K 38 | val V 39 | pri uint64 40 | } 41 | 42 | func (t *treap[K, V]) setRoot(x *treapNode[K, V]) { 43 | t.root = x 44 | if x != nil { 45 | x.parent = nil 46 | } 47 | } 48 | 49 | func (x *treapNode[K, V]) setLeft(y *treapNode[K, V]) { 50 | x.left = y 51 | if y != nil { 52 | y.parent = x 53 | } 54 | } 55 | 56 | func (x *treapNode[K, V]) setRight(y *treapNode[K, V]) { 57 | x.right = y 58 | if y != nil { 59 | y.parent = x 60 | } 61 | } 62 | 63 | func (m *treapMap[K, V]) Get(key K) (val V, ok bool) { 64 | x := m.get(key) 65 | if x == nil { 66 | return 67 | } 68 | return x.val, true 69 | } 70 | 71 | func (m *treapMap[K, V]) get(key K) *treapNode[K, V] { 72 | if m == nil { 73 | return nil 74 | } 75 | x := m.root 76 | for x != nil { 77 | if key == x.key { 78 | return x 79 | } 80 | if key < x.key { 81 | x = x.left 82 | } else { 83 | x = x.right 84 | } 85 | } 86 | return nil 87 | } 88 | 89 | func (m *treapMapFunc[K, V]) Get(key K) (val V, ok bool) { 90 | x := m.get(key) 91 | if x == nil { 92 | return 93 | } 94 | return x.val, true 95 | } 96 | 97 | func (m *treapMapFunc[K, V]) get(key K) *treapNode[K, V] { 98 | if m == nil { 99 | return nil 100 | } 101 | x := m.root 102 | for x != nil { 103 | c := m.cmp(key, x.key) 104 | if c == 0 { 105 | return x 106 | } 107 | if c < 0 { 
108 | x = x.left 109 | } else { 110 | x = x.right 111 | } 112 | } 113 | return nil 114 | } 115 | 116 | // Delete deletes m[key]. 117 | func (m *treapMap[K, V]) Delete(key K) { 118 | m.delete(m.get(key)) 119 | } 120 | 121 | // Delete deletes m[key]. 122 | func (m *treapMapFunc[K, V]) Delete(key K) { 123 | m.delete(m.get(key)) 124 | } 125 | 126 | func (t *treap[K, V]) delete(x *treapNode[K, V]) { 127 | if t == nil { 128 | panic("Delete of nil Map") 129 | } 130 | if x == nil { 131 | return 132 | } 133 | 134 | // Rotate x down to be leaf of tree for removal, respecting priorities. 135 | for x.right != nil || x.left != nil { 136 | if x.right == nil || x.left != nil && x.left.pri < x.right.pri { 137 | t.rotateRight(x) 138 | } else { 139 | t.rotateLeft(x) 140 | } 141 | } 142 | 143 | // Remove x, now a leaf. 144 | switch p := x.parent; { 145 | case p == nil: 146 | t.root = nil 147 | case p.left == x: 148 | p.left = nil 149 | default: 150 | p.right = nil 151 | } 152 | x.pri = 0 // mark deleted 153 | } 154 | 155 | // rotateUp rotates x upward in the tree to correct any priority inversions. 156 | func (t *treap[K, V]) rotateUp(x *treapNode[K, V]) { 157 | // Rotate up into tree according to priority. 158 | for x.parent != nil && x.parent.pri > x.pri { 159 | if x.parent.left == x { 160 | t.rotateRight(x.parent) 161 | } else { 162 | t.rotateLeft(x.parent) 163 | } 164 | } 165 | } 166 | 167 | // rotateLeft rotates the subtree rooted at node x. 168 | // turning (x a (y b c)) into (y (x a b) c). 
169 | func (t *treap[K, V]) rotateLeft(x *treapNode[K, V]) { 170 | // t.Rotates++ 171 | // p -> (x a (y b c)) 172 | p := x.parent 173 | y := x.right 174 | b := y.left 175 | 176 | y.left = x 177 | x.parent = y 178 | x.right = b 179 | if b != nil { 180 | b.parent = x 181 | } 182 | 183 | y.parent = p 184 | switch { 185 | case p == nil: 186 | t.root = y 187 | case p.left == x: 188 | p.left = y 189 | case p.right == x: 190 | p.right = y 191 | default: 192 | // unreachable 193 | panic("corrupt treap") 194 | } 195 | } 196 | 197 | // rotateRight rotates the subtree rooted at node y. 198 | // turning (y (x a b) c) into (x a (y b c)). 199 | func (t *treap[K, V]) rotateRight(y *treapNode[K, V]) { 200 | // t.Rotates++ 201 | // p -> (y (x a b) c) 202 | p := y.parent 203 | x := y.left 204 | b := x.right 205 | 206 | x.right = y 207 | y.parent = x 208 | y.left = b 209 | if b != nil { 210 | b.parent = y 211 | } 212 | 213 | x.parent = p 214 | switch { 215 | case p == nil: 216 | t.root = x 217 | case p.left == y: 218 | p.left = x 219 | case p.right == y: 220 | p.right = x 221 | default: 222 | // unreachable 223 | panic("corrupt treap") 224 | } 225 | } 226 | 227 | func (m *treapMap[K, V]) Set(key K, val V) { 228 | pos, parent := m.locate(key) 229 | m.set(key, val, pos, parent) 230 | } 231 | 232 | func (m *treapMapFunc[K, V]) Set(key K, val V) { 233 | pos, parent := m.locate(key) 234 | m.set(key, val, pos, parent) 235 | } 236 | 237 | func (t *treap[K, V]) set(key K, val V, pos **treapNode[K, V], parent *treapNode[K, V]) { 238 | if x := *pos; x != nil { 239 | x.val = val 240 | return 241 | } 242 | x := &treapNode[K, V]{key: key, val: val, parent: parent, pri: rand.Uint64() | 1} 243 | *pos = x 244 | t.rotateUp(x) 245 | } 246 | 247 | func (m *treapMap[K, V]) locate(key K) (pos **treapNode[K, V], parent *treapNode[K, V]) { 248 | pos, x := &m.root, m.root 249 | for x != nil && key != x.key { 250 | parent = x 251 | if key < x.key { 252 | pos, x = &x.left, x.left 253 | } else { 254 | pos, x = 
&x.right, x.right 255 | } 256 | } 257 | return pos, parent 258 | } 259 | 260 | func (m *treapMapFunc[K, V]) locate(key K) (pos **treapNode[K, V], parent *treapNode[K, V]) { 261 | pos, x := &m.root, m.root 262 | for x != nil { 263 | c := m.cmp(key, x.key) 264 | if c == 0 { 265 | break 266 | } 267 | parent = x 268 | if c < 0 { 269 | pos, x = &x.left, x.left 270 | } else { 271 | pos, x = &x.right, x.right 272 | } 273 | } 274 | return pos, parent 275 | } 276 | 277 | func (t *treapMap[K, V]) split(key K) (x *treapNode[K, V], after treap[K, V]) { 278 | return t.treap.split(t.locate(key)) 279 | } 280 | 281 | func (t *treapMapFunc[K, V]) split(key K) (x *treapNode[K, V], after treap[K, V]) { 282 | return t.treap.split(t.locate(key)) 283 | } 284 | 285 | func (t *treap[K, V]) split(pos **treapNode[K, V], parent *treapNode[K, V]) (x *treapNode[K, V], after treap[K, V]) { 286 | clear := false 287 | x = *pos 288 | if x == nil { 289 | x = &treapNode[K, V]{parent: parent} 290 | *pos = x 291 | clear = true 292 | } 293 | x.pri = 0 294 | t.rotateUp(x) 295 | t.setRoot(x.left) 296 | after.setRoot(x.right) 297 | x.left, x.right = nil, nil 298 | if clear { 299 | x = nil 300 | } 301 | return x, after 302 | } 303 | 304 | func (t *treap[K, V]) min() *treapNode[K, V] { 305 | x := t.root 306 | for x != nil && x.left != nil { 307 | x = x.left 308 | } 309 | return x 310 | } 311 | 312 | func (t *treap[K, V]) max() *treapNode[K, V] { 313 | x := t.root 314 | for x != nil && x.right != nil { 315 | x = x.right 316 | } 317 | return x 318 | } 319 | 320 | func (t *treap[K, V]) join(x *treapNode[K, V], after treap[K, V]) { 321 | if x == nil { 322 | if x = after.min(); x == nil { 323 | return 324 | } 325 | after.delete(x) 326 | } 327 | if x.left != nil || x.right != nil || x.pri != 0 { 328 | panic("treap join misuse") 329 | } 330 | x.setRight(after.root) 331 | max := t.max() 332 | if max == nil { 333 | t.setRoot(x) 334 | } else { 335 | max.setRight(x) 336 | } 337 | if x.right != nil { 338 | 
t.rotateUp(x.right) 339 | } else { 340 | t.rotateUp(x) 341 | } 342 | } 343 | 344 | func (t *treapMap[K, V]) DeleteRange(lo, hi K) { 345 | if t == nil { 346 | panic("nil DeleteRange") 347 | } 348 | if lo > hi { 349 | return 350 | } 351 | t.deleteRange(lo, hi, t.split) 352 | } 353 | 354 | func (t *treapMapFunc[K, V]) DeleteRange(lo, hi K) { 355 | if t == nil { 356 | panic("nil DeleteRange") 357 | } 358 | if t.cmp(lo, hi) > 0 { 359 | return 360 | } 361 | t.deleteRange(lo, hi, t.split) 362 | } 363 | 364 | func (t *treap[K, V]) deleteRange(lo, hi K, split func(K) (*treapNode[K, V], treap[K, V])) { 365 | _, after := split(hi) 366 | _, middle := split(lo) 367 | t.join(nil, after) 368 | middle.root.markDeleted() 369 | } 370 | 371 | func (x *treapNode[K, V]) markDeleted() { 372 | if x == nil { 373 | return 374 | } 375 | x.pri = 0 376 | x.left.markDeleted() 377 | x.right.markDeleted() 378 | } 379 | 380 | func (t *treap[K, V]) Depth() int { 381 | return t.root.depth() 382 | } 383 | 384 | func (x *treapNode[K, V]) depth() int { 385 | if x == nil { 386 | return 0 387 | } 388 | return 1 + max(x.left.depth(), x.right.depth()) 389 | } 390 | 391 | // All returns an iterator over the map m. 392 | // If m is modified during the iteration, some keys may not be visited. 393 | // No keys will be visited multiple times. 394 | func (m *treapMap[K, V]) All() iter.Seq2[K, V] { 395 | return m.all(m.locate) 396 | } 397 | 398 | // All returns an iterator over the map m. 399 | // If m is modified during the iteration, some keys may not be visited. 400 | // No keys will be visited multiple times. 
401 | func (m *treapMapFunc[K, V]) All() iter.Seq2[K, V] { 402 | return m.all(m.locate) 403 | } 404 | 405 | func (t *treap[K, V]) all(locate func(K) (**treapNode[K, V], *treapNode[K, V])) iter.Seq2[K, V] { 406 | return func(yield func(K, V) bool) { 407 | if t == nil { 408 | return 409 | } 410 | x := t.root 411 | if x != nil { 412 | for x.left != nil { 413 | x = x.left 414 | } 415 | } 416 | for x != nil && yield(x.key, x.val) { 417 | if x.pri != 0 { 418 | // still in tree 419 | x = x.next() 420 | } else { 421 | // deleted 422 | x = t.nextAfter(locate(x.key)) 423 | } 424 | } 425 | } 426 | } 427 | 428 | func (x *treapNode[K, V]) next() *treapNode[K, V] { 429 | if x.right == nil { 430 | for x.parent != nil && x.parent.right == x { 431 | x = x.parent 432 | } 433 | return x.parent 434 | } 435 | x = x.right 436 | for x.left != nil { 437 | x = x.left 438 | } 439 | return x 440 | } 441 | 442 | func (t *treap[K, V]) nextAfter(pos **treapNode[K, V], parent *treapNode[K, V]) *treapNode[K, V] { 443 | switch { 444 | case *pos != nil: 445 | return (*pos).next() 446 | case parent == nil: 447 | return nil 448 | case pos == &parent.left: 449 | return parent 450 | default: 451 | return parent.next() 452 | } 453 | } 454 | 455 | // Scan returns an iterator over the map m 456 | // limited to keys k satisfying lo ≤ k ≤ hi. 457 | // 458 | // If m is modified during the iteration, some keys may not be visited. 459 | // No keys will be visited multiple times. 460 | func (m *treapMap[K, V]) Scan(lo, hi K) iter.Seq2[K, V] { 461 | return m.scan(lo, hi, cmp.Compare[K], m.locate) 462 | } 463 | 464 | // Scan returns an iterator over the map m 465 | // limited to keys k satisfying lo ≤ k ≤ hi. 466 | // 467 | // If m is modified during the iteration, some keys may not be visited. 468 | // No keys will be visited multiple times. 
469 | func (m *treapMapFunc[K, V]) Scan(lo, hi K) iter.Seq2[K, V] { 470 | return m.scan(lo, hi, m.cmp, m.locate) 471 | } 472 | 473 | func (t *treap[K, V]) scan(lo, hi K, cmp func(K, K) int, locate func(K) (**treapNode[K, V], *treapNode[K, V])) iter.Seq2[K, V] { 474 | return func(yield func(K, V) bool) { 475 | if t == nil { 476 | return 477 | } 478 | pos, parent := locate(lo) 479 | x := *pos 480 | if x == nil { 481 | x = t.nextAfter(pos, parent) 482 | } 483 | for x != nil && cmp(x.key, hi) <= 0 && yield(x.key, x.val) { 484 | if x.pri != 0 { 485 | x = x.next() 486 | } else { 487 | x = t.nextAfter(locate(x.key)) 488 | } 489 | } 490 | } 491 | } 492 | 493 | func (t *treap[K, V]) Dump() string { 494 | var buf bytes.Buffer 495 | var walk func(*treapNode[K, V]) 496 | walk = func(x *treapNode[K, V]) { 497 | if x == nil { 498 | fmt.Fprintf(&buf, "nil") 499 | return 500 | } 501 | fmt.Fprintf(&buf, "(%v:%v ", x.key, x.val) 502 | walk(x.left) 503 | fmt.Fprintf(&buf, " ") 504 | walk(x.right) 505 | fmt.Fprintf(&buf, ")") 506 | } 507 | walk(t.root) 508 | return buf.String() 509 | 510 | } 511 | -------------------------------------------------------------------------------- /omap_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2024 The Go Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 
4 | 5 | //go:build go1.23 6 | 7 | package omap 8 | 9 | import ( 10 | "cmp" 11 | "fmt" 12 | "iter" 13 | "math/rand/v2" 14 | "slices" 15 | "sort" 16 | "testing" 17 | ) 18 | 19 | type Mapper[K, V any] interface { 20 | Get(key K) (V, bool) 21 | All() iter.Seq2[K, V] 22 | Delete(key K) 23 | DeleteRange(lo, hi K) 24 | Scan(lo, hi K) iter.Seq2[K, V] 25 | Set(key K, val V) 26 | Dump() string 27 | Depth() int 28 | } 29 | 30 | func permute(m Mapper[int, int], n int) (perm, slice []int) { 31 | perm = rand.Perm(n) 32 | slice = make([]int, 2*n+1) 33 | //println("P") 34 | for i, x := range perm { 35 | //println("X", 2*x+1, m.Dump()) 36 | m.Set(2*x+1, i+1) 37 | slice[2*x+1] = i + 1 38 | } 39 | // Overwrite-Set half the entries. 40 | for i, x := range perm[:len(perm)/2] { 41 | m.Set(2*x+1, i+100) 42 | slice[2*x+1] = i + 100 43 | } 44 | return perm, slice 45 | } 46 | 47 | var maps = []struct { 48 | name string 49 | new func() Mapper[int, int] 50 | }{ 51 | {"Map", func() Mapper[int, int] { return new(Map[int, int]) }}, 52 | {"treapMap", func() Mapper[int, int] { return new(treapMap[int, int]) }}, 53 | {"treapMapFunc", func() Mapper[int, int] { 54 | m := new(treapMapFunc[int, int]) 55 | m.init(cmp.Compare[int]) 56 | return m 57 | }}, 58 | {"avlMap", func() Mapper[int, int] { return new(avlMap[int, int]) }}, 59 | {"avlMapFunc", func() Mapper[int, int] { 60 | m := new(avlMapFunc[int, int]) 61 | m.init(cmp.Compare[int]) 62 | return m 63 | }}, 64 | {"llrbMap", func() Mapper[int, int] { return new(llrbMap[int, int]) }}, 65 | {"llrbMapFunc", func() Mapper[int, int] { 66 | m := new(llrbMapFunc[int, int]) 67 | m.init(cmp.Compare[int]) 68 | return m 69 | }}, 70 | } 71 | 72 | func test(t *testing.T, f func(*testing.T, func() Mapper[int, int])) { 73 | for _, m := range maps { 74 | t.Run(m.name, func(t *testing.T) { f(t, m.new) }) 75 | } 76 | } 77 | 78 | func TestGet(t *testing.T) { 79 | test(t, func(t *testing.T, newMap func() Mapper[int, int]) { 80 | for N := range 100 { 81 | m := newMap() 82 
| _, slice := permute(m, N) 83 | for k, want := range slice { 84 | v, ok := m.Get(k) 85 | if v != want || ok != (want > 0) { 86 | t.Fatalf("Get(%d) = %d, %v, want %d, %v\nM: %v", k, v, ok, want, want > 0, m.Dump()) 87 | } 88 | } 89 | } 90 | }) 91 | } 92 | 93 | func TestAll(t *testing.T) { 94 | test(t, func(t *testing.T, newMap func() Mapper[int, int]) { 95 | for N := range 11 { 96 | m := newMap() 97 | _, slice := permute(m, N) 98 | var have []int 99 | for k, v := range m.All() { 100 | if v != slice[k] { 101 | t.Errorf("All() returned %d, %d want %d, %d", k, v, k, slice[k]) 102 | } 103 | have = append(have, k) 104 | if len(have) > N+5 { // too many; looping? 105 | break 106 | } 107 | } 108 | var want []int 109 | for k, v := range slice { 110 | if v != 0 { 111 | want = append(want, k) 112 | } 113 | } 114 | if !slices.Equal(have, want) { 115 | t.Errorf("All() = %v, want %v", have, want) 116 | } 117 | } 118 | }) 119 | } 120 | 121 | func TestScan(t *testing.T) { 122 | test(t, func(t *testing.T, newMap func() Mapper[int, int]) { 123 | for N := range 11 { 124 | m := newMap() 125 | _, slice := permute(m, N) 126 | for hi := range slice { 127 | for lo := range hi + 1 { 128 | var have []int 129 | for k, v := range m.Scan(lo, hi) { 130 | if v != slice[k] { 131 | t.Errorf("All() returned %d, %d want %d, %d", k, v, k, slice[k]) 132 | } 133 | have = append(have, k) 134 | if len(have) > N+5 { // too many; looping? 
135 | break 136 | } 137 | } 138 | var want []int 139 | for k, v := range slice { 140 | if v != 0 && lo <= k && k <= hi { 141 | want = append(want, k) 142 | } 143 | } 144 | if !slices.Equal(have, want) { 145 | t.Errorf("All() = %v, want %v", have, want) 146 | } 147 | } 148 | } 149 | } 150 | }) 151 | } 152 | 153 | func TestDelete(t *testing.T) { 154 | test(t, func(t *testing.T, newMap func() Mapper[int, int]) { 155 | for N := range 50 { 156 | for range 100 { 157 | m := newMap() 158 | _, slice := permute(m, N) 159 | for _, x := range rand.Perm(len(slice)) { 160 | m.Delete(x) 161 | slice[x] = 0 162 | var have []int 163 | for k, _ := range m.All() { 164 | have = append(have, k) 165 | } 166 | var want []int 167 | for x, v := range slice { 168 | if v != 0 { 169 | want = append(want, x) 170 | } 171 | } 172 | slices.Sort(want) 173 | if !slices.Equal(have, want) { 174 | t.Errorf("after Delete(%v), All() = %v, want %v", x, have, want) 175 | } 176 | } 177 | } 178 | } 179 | }) 180 | } 181 | 182 | func TestDelete2(t *testing.T) { 183 | test(t, func(t *testing.T, newMap func() Mapper[int, int]) { 184 | for N := range 5 { 185 | for perm := range allPerm(N) { 186 | for dperm := range allPerm(N) { 187 | m := newMap() 188 | func() { 189 | defer func() { 190 | if e := recover(); e != nil { 191 | fmt.Println("SET", perm, m.Dump(), "DEL", dperm) 192 | panic(e) 193 | } 194 | }() 195 | for _, v := range perm { 196 | m.Set(v, v) 197 | } 198 | for _, v := range dperm { 199 | m.Delete(v) 200 | } 201 | }() 202 | } 203 | } 204 | } 205 | }) 206 | } 207 | 208 | func TestDelete3(t *testing.T) { 209 | set := []int{17, 9, 23, 7, 11, 19, 27} 210 | del := []int{17} 211 | m := new(llrbMap[int, int]) 212 | for _, v := range set { 213 | m.Set(v, v) 214 | } 215 | for _, v := range del { 216 | m.Delete(v) 217 | } 218 | } 219 | 220 | func TestDeleteRange(t *testing.T) { 221 | test(t, func(t *testing.T, newMap func() Mapper[int, int]) { 222 | for N := range 11 { 223 | for hi := range 2*N + 1 { 224 | for lo 
:= range hi + 1 { 225 | m := newMap() 226 | _, slice := permute(m, N) 227 | if lo < hi { 228 | m.DeleteRange(hi, lo) // want no-op 229 | } 230 | m.DeleteRange(lo, hi) 231 | var have []int 232 | for k, _ := range m.All() { 233 | have = append(have, k) 234 | } 235 | var want []int 236 | for k, v := range slice { 237 | if v != 0 && (k < lo || hi < k) { 238 | want = append(want, k) 239 | } 240 | } 241 | if !slices.Equal(have, want) { 242 | t.Fatalf("after DeleteRange(%d, %d), All() = %v, want %v", lo, hi, have, want) 243 | } 244 | } 245 | } 246 | } 247 | }) 248 | } 249 | 250 | func TestScanDelete(t *testing.T) { 251 | test(t, func(t *testing.T, newMap func() Mapper[int, int]) { 252 | for _, mode := range []string{"prev", "current", "next"} { 253 | for N := range 8 { 254 | for target := 1; target <= 2*N-1; target += 2 { 255 | m := newMap() 256 | _, slice := permute(m, N) 257 | var have []int 258 | var deleted int 259 | for k, _ := range m.All() { 260 | if k == target { 261 | switch mode { 262 | case "prev": 263 | deleted = k - 2 264 | case "current": 265 | deleted = k 266 | case "next": 267 | deleted = k + 2 268 | if k+2 < len(slice) { 269 | slice[k+2] = 0 270 | } 271 | } 272 | m.Delete(deleted) 273 | } 274 | have = append(have, k) 275 | } 276 | var want []int 277 | for k, v := range slice { 278 | if v != 0 { 279 | want = append(want, k) 280 | } 281 | } 282 | if !slices.Equal(have, want) { 283 | var have2 []int 284 | for k := range m.All() { 285 | have2 = append(have2, k) 286 | } 287 | t.Errorf("All() with Delete(%d) at %d = %v, want %v (after=%v)", deleted, target, have, want, have2) 288 | } 289 | } 290 | } 291 | } 292 | }) 293 | } 294 | 295 | func TestScanDeleteRange(t *testing.T) { 296 | test(t, func(t *testing.T, newMap func() Mapper[int, int]) { 297 | for _, mode := range []string{"prev", "current", "next"} { 298 | for N := range 8 { 299 | for target := 1; target <= 2*N-1; target += 2 { 300 | m := newMap() 301 | _, slice := permute(m, N) 302 | var have []int 303 | 
var deleteLo, deleteHi int 304 | for k, _ := range m.All() { 305 | if k == target { 306 | switch mode { 307 | case "prev": 308 | deleteLo, deleteHi = k-5, k-1 309 | case "current": 310 | deleteLo, deleteHi = k-2, k+2 311 | if k+2 < len(slice) { 312 | slice[k+2] = 0 313 | } 314 | case "next": 315 | deleteLo, deleteHi = k+1, k+5 316 | if k+2 < len(slice) { 317 | slice[k+2] = 0 318 | } 319 | if k+4 < len(slice) { 320 | slice[k+4] = 0 321 | } 322 | } 323 | m.DeleteRange(deleteLo, deleteHi) 324 | } 325 | have = append(have, k) 326 | } 327 | var want []int 328 | for k, v := range slice { 329 | if v != 0 { 330 | want = append(want, k) 331 | } 332 | } 333 | if !slices.Equal(have, want) { 334 | t.Errorf("All() with DeleteRange(%d, %d) at %d = %v, want %v", deleteLo, deleteHi, target, have, want) 335 | } 336 | } 337 | } 338 | } 339 | }) 340 | } 341 | 342 | func TestSplit(t *testing.T) { 343 | split := func(t *testing.T, m Mapper[int, int], target int) (val int, ok bool, more Mapper[int, int]) { 344 | switch m := m.(type) { 345 | default: 346 | t.Fatalf("bad split %T", m) 347 | panic("unreachable") 348 | case *Map[int, int]: 349 | return m.Split(target) 350 | case *treapMap[int, int]: 351 | x, after := m.split(target) 352 | if x != nil { 353 | val, ok = x.val, true 354 | } 355 | more = &treapMap[int, int]{after} 356 | return val, ok, more 357 | case *treapMapFunc[int, int]: 358 | x, after := m.split(target) 359 | if x != nil { 360 | val, ok = x.val, true 361 | } 362 | more = &treapMapFunc[int, int]{after, m.cmp} 363 | return val, ok, more 364 | case *avlMap[int, int]: 365 | x, after := m.split(target) 366 | if x != nil { 367 | val, ok = x.val, true 368 | } 369 | more = &avlMap[int, int]{after} 370 | return val, ok, more 371 | case *avlMapFunc[int, int]: 372 | x, after := m.split(target) 373 | if x != nil { 374 | val, ok = x.val, true 375 | } 376 | more = &avlMapFunc[int, int]{after, m.cmp} 377 | return val, ok, more 378 | } 379 | } 380 | 381 | test(t, func(t *testing.T, 
newMap func() Mapper[int, int]) { 382 | for N := range 16 { 383 | for range 100 { 384 | for target := 0; target <= 2*N; target++ { 385 | m := newMap() 386 | _, slice := permute(m, N) 387 | orig := getAll(m) 388 | var want1, want2 []int 389 | var wantOK bool 390 | var wantVal int 391 | for k, v := range slice { 392 | if v != 0 { 393 | if k < target { 394 | want1 = append(want1, k) 395 | } else if k > target { 396 | want2 = append(want2, k) 397 | } else { 398 | wantOK = true 399 | wantVal = v 400 | } 401 | } 402 | } 403 | val, ok, rest := split(t, m, target) 404 | have1 := getAll(m) 405 | have2 := getAll(rest) 406 | if val != wantVal || ok != wantOK || !slices.Equal(have1, want1) || !slices.Equal(have2, want2) { 407 | t.Fatalf("%v.Split(%v):\nhave m=%v val=%v ok=%v rest=%v\nwant m=%v val=%v ok=%v rest=%v", 408 | orig, target, have1, val, ok, have2, want1, wantVal, wantOK, want2) 409 | } 410 | } 411 | } 412 | } 413 | }) 414 | } 415 | 416 | func TestJoin(t *testing.T) { 417 | join := func(t *testing.T, m, more Mapper[int, int]) { 418 | switch m := m.(type) { 419 | default: 420 | t.Fatalf("bad join %T", m) 421 | case *Map[int, int]: 422 | m.Join(more.(*Map[int, int])) 423 | case *treapMap[int, int]: 424 | more := more.(*treapMap[int, int]) 425 | m.join(nil, more.treap) 426 | more.root = nil 427 | case *treapMapFunc[int, int]: 428 | more := more.(*treapMapFunc[int, int]) 429 | m.join(nil, more.treap) 430 | more.root = nil 431 | case *avlMap[int, int]: 432 | more := more.(*avlMap[int, int]) 433 | m.join(nil, more.avl) 434 | more.root = nil 435 | case *avlMapFunc[int, int]: 436 | more := more.(*avlMapFunc[int, int]) 437 | m.join(nil, more.avl) 438 | more.root = nil 439 | } 440 | } 441 | 442 | test(t, func(t *testing.T, newMap func() Mapper[int, int]) { 443 | for N := range 16 { 444 | for range 100 { 445 | perm := rand.Perm(N) 446 | target := rand.N(1 + N) 447 | 448 | m1 := newMap() 449 | m2 := newMap() 450 | var want1, want2 []int 451 | for _, v := range perm { 452 | if v 
< target { 453 | m1.Set(v, v) 454 | want1 = append(want1, v) 455 | } else { 456 | m2.Set(v, v) 457 | want2 = append(want2, v) 458 | } 459 | } 460 | sort.Ints(want1) 461 | sort.Ints(want2) 462 | have1 := getAll(m1) 463 | have2 := getAll(m2) 464 | if !slices.Equal(have1, want1) || !slices.Equal(have2, want2) { 465 | t.Fatalf("before join\nperm: %v target: %v\nhave: %v %v\nwant: %v %v", perm, target, have1, have2, want1, want2) 466 | } 467 | join(t, m1, m2) 468 | have1 = getAll(m1) 469 | have2 = getAll(m2) 470 | want := slices.Concat(want1, want2) 471 | if !slices.Equal(have1, want) || !slices.Equal(have2, nil) { 472 | t.Fatalf("after join\ninputs: %v %v\nhave: %v %v\nwant: %v %v", want1, want2, have1, have2, want, []int{}) 473 | } 474 | } 475 | } 476 | }) 477 | } 478 | 479 | func getAll(m Mapper[int, int]) []int { 480 | var x []int 481 | for k := range m.All() { 482 | x = append(x, k) 483 | } 484 | return x 485 | } 486 | 487 | func TestDepth(t *testing.T) { 488 | t.Skip("depth") 489 | test(t, func(t *testing.T, newMap func() Mapper[int, int]) { 490 | for _, mode := range []string{"seq", "rand"} { 491 | t.Run(mode, func(t *testing.T) { 492 | for range 3 { 493 | const N = 1000000 494 | m := newMap() 495 | if mode == "seq" { 496 | for i := range N { 497 | m.Set(i, i) 498 | } 499 | } else { 500 | for _, i := range rand.Perm(N) { 501 | m.Set(i, i) 502 | } 503 | } 504 | t.Logf("n=%d depth=%d", N, m.Depth()) 505 | } 506 | }) 507 | } 508 | }) 509 | } 510 | 511 | func allPerm(n int) iter.Seq[[]int] { 512 | return func(yield func([]int) bool) { 513 | x := make([]int, n) 514 | for i := range x { 515 | x[i] = i 516 | } 517 | genAllPerm(n, x, yield) 518 | } 519 | } 520 | 521 | func genAllPerm(k int, x []int, yield func([]int) bool) bool { 522 | if k <= 1 { 523 | return yield(x) 524 | } 525 | if !genAllPerm(k-1, x, yield) { 526 | return false 527 | } 528 | for i := range k - 1 { 529 | if k%2 == 0 { 530 | x[i], x[k-1] = x[k-1], x[i] 531 | } else { 532 | x[0], x[k-1] = x[k-1], x[0] 
533 | } 534 | if !genAllPerm(k-1, x, yield) { 535 | return false 536 | } 537 | } 538 | return true 539 | } 540 | -------------------------------------------------------------------------------- /avl.go: -------------------------------------------------------------------------------- 1 | // Copyright 2024 The Go Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | package omap 6 | 7 | import ( 8 | "bytes" 9 | "cmp" 10 | "fmt" 11 | "iter" 12 | ) 13 | 14 | // In-memory database stored as self-balancing AVL tree. 15 | // See Lewis & Denenberg, Data Structures and Their Algorithms. 16 | 17 | // A Map is a map[K]V ordered according to K's standard Go ordering. 18 | // The zero value of a Map is an empty Map ready to use. 19 | type avlMap[K cmp.Ordered, V any] struct { 20 | avl[K, V] 21 | } 22 | 23 | type avl[K, V any] struct { 24 | root *anode[K, V] 25 | } 26 | 27 | type avlMapFunc[K, V any] struct { 28 | avl[K, V] 29 | cmp func(K, K) int 30 | } 31 | 32 | func (t *avlMapFunc[K, V]) init(cmp func(K, K) int) { 33 | t.cmp = cmp 34 | } 35 | 36 | // An anode is a node in the AVL tree. 
type anode[K, V any] struct {
	parent *anode[K, V] // nil for the root (see setRoot)
	left   *anode[K, V]
	right  *anode[K, V]
	bal    int // right height minus left height (see setbal); must stay in [-1, +1]
	height int // 1 + max child height, so a leaf is 0; -1 while the node is not linked into the tree
	key    K
	val    V
}

// Depth returns the height of the tree: -1 when empty, 0 for a single node.
func (t *avl[K, V]) Depth() int {
	return t.root.safeHeight()
}

// setRoot makes x the root of t, clearing x's parent link.
func (t *avl[K, V]) setRoot(x *anode[K, V]) {
	t.root = x
	if x != nil {
		x.parent = nil
	}
}

// setLeft makes y the left child of x and points y's parent link at x.
func (x *anode[K, V]) setLeft(y *anode[K, V]) {
	x.left = y
	if y != nil {
		y.parent = x
	}
}

// setRight makes y the right child of x and points y's parent link at x.
func (x *anode[K, V]) setRight(y *anode[K, V]) {
	x.right = y
	if y != nil {
		y.parent = x
	}
}

// Get returns the value stored for key and whether key was present.
func (m *avlMap[K, V]) Get(key K) (val V, ok bool) {
	x := m.get(key)
	if x == nil {
		return
	}
	return x.val, true
}

// get returns the node holding key, or nil if key is absent or m is nil.
func (m *avlMap[K, V]) get(key K) *anode[K, V] {
	if m == nil {
		return nil
	}
	x := m.root
	for x != nil {
		if key == x.key {
			return x
		}
		if key < x.key {
			x = x.left
		} else {
			x = x.right
		}
	}
	return nil
}

// Get returns the value stored for key and whether key was present.
func (m *avlMapFunc[K, V]) Get(key K) (val V, ok bool) {
	x := m.get(key)
	if x == nil {
		return
	}
	return x.val, true
}

// get returns the node holding key, or nil if key is absent or m is nil.
// Keys are ordered by m.cmp.
func (m *avlMapFunc[K, V]) get(key K) *anode[K, V] {
	if m == nil {
		return nil
	}
	x := m.root
	for x != nil {
		c := m.cmp(key, x.key)
		if c == 0 {
			return x
		}
		if c < 0 {
			x = x.left
		} else {
			x = x.right
		}
	}
	return nil
}

// safeHeight returns n's stored height, treating a nil subtree as -1.
func (n *anode[K, V]) safeHeight() int {
	if n == nil {
		return -1
	}
	return n.height
}

// checkbal panics (after printing diagnostics) if n's stored balance
// factor disagrees with the stored heights of its children.
func (n *anode[K, V]) checkbal() {
	b := n.right.safeHeight() - n.left.safeHeight()
	if b != n.bal {
		println("bad balance", n.left.safeHeight(), n.right.safeHeight(), b, n.bal, n.dump())
		panic("bad balance")
	}
}

// setHeight recomputes n.height from its children's stored heights.
func (n *anode[K, V]) setHeight() {
	n.height = 1 + max(n.left.safeHeight(), n.right.safeHeight())
}

// setbal recomputes n.bal as right height minus left height.
func (n *anode[K, V]) setbal() {
	n.bal = n.right.safeHeight() - n.left.safeHeight()
}

// replaceChild replaces old, which must be a child of p (or t's root when
// p is nil), with x. It panics if old is not linked where claimed.
func (t *avl[K, V]) replaceChild(p, old, x *anode[K, V]) {
	switch {
	case p == nil:
		if t.root != old {
			panic("corrupt avl")
		}
		t.setRoot(x)
	case p.left == old:
		p.setLeft(x)
	case p.right == old:
		p.setRight(x)
	default:
		panic("corrupt avl")
	}
}

// rebalanceUp walks from x toward the root, recomputing each node's height
// and balance and rotating wherever the balance reaches ±2. It stops as soon
// as a subtree's height comes out unchanged, since ancestors depend only on
// that height. A node whose height is -1 (e.g. a newly created leaf) always
// changes height, so the walk continues past it.
func (t *avl[K, V]) rebalanceUp(x *anode[K, V]) {
	for x != nil {
		h := x.height
		x.setHeight()
		x.setbal()
		switch x.bal {
		case -2:
			// Left-heavy. If the left child leans right, rotate it first
			// (left-right case) so a single right rotation fixes x.
			if x.left.bal == 1 {
				t.rotateLeft(x.left)
			}
			x = t.rotateRight(x)

		case +2:
			// Right-heavy: mirror image of the case above.
			if x.right.bal == -1 {
				t.rotateRight(x.right)
			}
			x = t.rotateLeft(x)
		}
		// x is now the root of this (possibly rotated) subtree.
		if x.height == h {
			return
		}
		x = x.parent
	}
}

// rotateRight rotates the subtree rooted at node y,
// turning (y (x a b) c) into (x a (y b c)), and returns the new subtree
// root x. Both nodes' stored balances are verified (checkbal) before the
// rotation, and their heights and balances are recomputed afterward.
func (t *avl[K, V]) rotateRight(y *anode[K, V]) *anode[K, V] {
	//m.Rotates++
	// p -> (y (x a b) c)
	p := y.parent
	x := y.left
	b := x.right

	x.checkbal()
	y.checkbal()

	x.setRight(y)
	y.setLeft(b)
	t.replaceChild(p, y, x)

	// y is now below x, so its height must be recomputed first.
	y.setHeight()
	y.setbal()
	x.setHeight()
	x.setbal()
	return x
}

// rotateLeft rotates the subtree rooted at node x,
// turning (x a (y b c)) into (y (x a b) c), and returns the new subtree root y.
214 | func (t *avl[K, V]) rotateLeft(x *anode[K, V]) *anode[K, V] { 215 | //m.Rotates++ 216 | // p -> (x a (y b c)) 217 | p := x.parent 218 | y := x.right 219 | b := y.left 220 | 221 | x.checkbal() 222 | y.checkbal() 223 | 224 | y.setLeft(x) 225 | x.setRight(b) 226 | t.replaceChild(p, x, y) 227 | 228 | x.setHeight() 229 | x.setbal() 230 | y.setHeight() 231 | y.setbal() 232 | return y 233 | } 234 | 235 | func (m *avlMap[K, V]) Set(key K, val V) { 236 | pos, parent := m.locate(key) 237 | m.set(key, val, pos, parent) 238 | } 239 | 240 | func (m *avlMapFunc[K, V]) Set(key K, val V) { 241 | pos, parent := m.locate(key) 242 | m.set(key, val, pos, parent) 243 | } 244 | 245 | func (t *avl[K, V]) set(key K, val V, pos **anode[K, V], parent *anode[K, V]) { 246 | if x := *pos; x != nil { 247 | x.val = val 248 | return 249 | } 250 | x := &anode[K, V]{key: key, val: val, parent: parent, height: -1} 251 | *pos = x 252 | t.rebalanceUp(x) 253 | } 254 | 255 | // Delete deletes m[key]. 256 | func (m *avlMap[K, V]) Delete(key K) { 257 | pos, _ := m.locate(key) 258 | m.delete(pos) 259 | } 260 | 261 | // Delete deletes m[key]. 
262 | func (m *avlMapFunc[K, V]) Delete(key K) { 263 | pos, _ := m.locate(key) 264 | m.delete(pos) 265 | } 266 | 267 | func (t *avl[K, V]) delete(pos **anode[K, V]) { 268 | t.root.checkParents(nil) 269 | 270 | x := *pos 271 | switch { 272 | case x == nil: 273 | return 274 | 275 | case x.left == nil: 276 | if *pos = x.right; *pos != nil { 277 | (*pos).parent = x.parent 278 | } 279 | t.rebalanceUp(x.parent) 280 | 281 | case x.right == nil: 282 | *pos = x.left 283 | x.left.parent = x.parent 284 | t.rebalanceUp(x.parent) 285 | 286 | default: 287 | t.deleteSwap(pos) 288 | } 289 | 290 | x.bal = -100 291 | x.parent = nil 292 | x.left = nil 293 | x.right = nil 294 | x.height = -1 295 | t.root.checkParents(nil) 296 | } 297 | 298 | func (m *avlMap[K, V]) locate(key K) (pos **anode[K, V], parent *anode[K, V]) { 299 | pos, x := &m.root, m.root 300 | for x != nil && key != x.key { 301 | parent = x 302 | if key < x.key { 303 | pos, x = &x.left, x.left 304 | } else { 305 | pos, x = &x.right, x.right 306 | } 307 | } 308 | return pos, parent 309 | } 310 | 311 | func (m *avlMapFunc[K, V]) locate(key K) (pos **anode[K, V], parent *anode[K, V]) { 312 | pos, x := &m.root, m.root 313 | for x != nil { 314 | c := m.cmp(key, x.key) 315 | if c == 0 { 316 | break 317 | } 318 | parent = x 319 | if c < 0 { 320 | pos, x = &x.left, x.left 321 | } else { 322 | pos, x = &x.right, x.right 323 | } 324 | } 325 | return pos, parent 326 | } 327 | 328 | func (t *avlMap[K, V]) split(key K) (x *anode[K, V], after avl[K, V]) { 329 | return t.avl.split(key, cmp.Compare[K]) 330 | } 331 | 332 | func (t *avlMapFunc[K, V]) split(key K) (x *anode[K, V], after avl[K, V]) { 333 | return t.avl.split(key, t.cmp) 334 | } 335 | 336 | func (t *avl[K, V]) split(key K, cmp func(K, K) int) (x *anode[K, V], after avl[K, V]) { 337 | /* 338 | split(T,k) = 339 | if T = Leaf then (Leaf, false, Leaf) 340 | else 341 | (L,m,R) = expose(T) 342 | if k = m then (L, true, R) 343 | else if k < m then 344 | (LL, b, LR) = split(L, k) 
345 | (LL, b, join(LR, m, R)) 346 | else 347 | (RL, b, RR) = split(R, k) 348 | (join(L, m, RL), b, RR) 349 | 350 | (Figure 1) 351 | */ 352 | right := avl[K, V]{} 353 | if t.root == nil { 354 | return nil, right 355 | } 356 | 357 | mid := t.root 358 | t.setRoot(mid.left) 359 | right.setRoot(mid.right) 360 | mid.left, mid.right, mid.height = nil, nil, -1 361 | 362 | c := cmp(key, mid.key) 363 | if c == 0 { 364 | return mid, right 365 | } 366 | if c < 0 { 367 | leftMid, leftRight := t.split(key, cmp) 368 | leftRight.join(mid, right) 369 | return leftMid, leftRight 370 | } 371 | if c > 0 { 372 | rightMid, rightRight := right.split(key, cmp) 373 | t.join(mid, right) 374 | return rightMid, rightRight 375 | } 376 | panic("unreachable") 377 | } 378 | 379 | /* 380 | join(TL, k, TR) = 381 | if h(TL) > h(TR)+1 then joinRight(TL, k, TR) 382 | else if h(TR) > h(TL)+1 then joinLeft(TL,k,TR) 383 | else Node(TL, k, TR) 384 | 385 | joinRight(TL, k, TR) = 386 | (l, k', c) = expose(TL) 387 | if h(c) <= h(TR)+1 then 388 | T' = Node(c,k,TR) 389 | if h(T') <= h(l)+1 then Node(l,k',T') 390 | else rotateLeft(Node(l,k',rotateRight(T')) 391 | else 392 | T' = joinRight(c,k,TR) 393 | T'' = Node(l,k',T') 394 | if h(T') <= h(l)+1 then T'' 395 | else rotateLeft(T'') 396 | 397 | (Figure 2) 398 | 399 | split(T,k) = 400 | if T = Leaf then (Leaf, false, Leaf) 401 | else 402 | (L,m,R) = expose(T) 403 | if k = m then (L, true, R) 404 | else if k < m then 405 | (LL, b, LR) = split(L, k) 406 | (LL, b, join(LR, m, R)) 407 | else 408 | (RL, b, RR) = split(R, k) 409 | (join(L, m, RL), b, RR) 410 | 411 | (Figure 1) 412 | 413 | https://arxiv.org/pdf/1602.02120 414 | */ 415 | 416 | func (t *avl[K, V]) min() **anode[K, V] { 417 | pos, x := &t.root, t.root 418 | for x != nil && x.left != nil { 419 | pos, x = &x.left, x.left 420 | } 421 | return pos 422 | } 423 | 424 | func (t *avl[K, V]) max() **anode[K, V] { 425 | pos, x := &t.root, t.root 426 | for x != nil && x.right != nil { 427 | pos, x = &x.right, x.right 
428 | } 429 | return pos 430 | } 431 | 432 | func (t *avl[K, V]) join(y *anode[K, V], after avl[K, V]) { 433 | if y == nil { 434 | pos := after.min() 435 | y = *pos 436 | if y == nil { 437 | return 438 | } 439 | after.delete(pos) 440 | } 441 | 442 | if y.left != nil || y.right != nil || y.height >= 0 { 443 | panic("avl join misuse") 444 | } 445 | 446 | x := t.root 447 | z := after.root 448 | xh := x.safeHeight() 449 | zh := z.safeHeight() 450 | 451 | switch { 452 | case xh > zh+1: 453 | for x.right != nil && x.right.height > zh { 454 | x = x.right 455 | } 456 | // now x.height > zh but x.right.height <= zh 457 | // replacing x.right with y=node{x.right, z} will grow x.right.height at most 1 458 | // println("JOIN X", x.safeHeight(), x.left.safeHeight(), x.right.safeHeight(), z.safeHeight()) 459 | y.setLeft(x.right) 460 | y.setRight(z) 461 | x.setRight(y) 462 | y.height = -1 463 | t.rebalanceUp(y) 464 | t.root.checkAll() 465 | 466 | case zh > xh+1: 467 | for z.left != nil && z.left.height > xh { 468 | z = z.left 469 | } 470 | // println("JOIN Z", z.safeHeight(), z.left.safeHeight(), z.right.safeHeight(), x.safeHeight()) 471 | y.setLeft(x) 472 | y.setRight(z.left) 473 | z.setLeft(y) 474 | y.height = -1 475 | t.root = after.root 476 | t.rebalanceUp(y) 477 | t.root.checkAll() 478 | 479 | default: 480 | y.setLeft(x) 481 | y.setRight(z) 482 | t.setRoot(y) 483 | t.rebalanceUp(y) 484 | t.root.checkAll() 485 | } 486 | 487 | after.root = nil 488 | t.rebalanceUp(y) 489 | } 490 | 491 | func (m *avlMap[K, V]) Split(key K) (val V, ok bool, more avl[K, V]) { 492 | mid, more := m.split(key) 493 | if mid != nil { 494 | val, ok = mid.val, true 495 | } 496 | return val, ok, more 497 | } 498 | 499 | func (t *avl[K, V]) deleteMin(zpos **anode[K, V]) (z, zparent *anode[K, V]) { 500 | for (*zpos).left != nil { 501 | zpos = &(*zpos).left 502 | } 503 | z = *zpos 504 | zparent = z.parent 505 | *zpos = z.right 506 | if *zpos != nil { 507 | (*zpos).parent = zparent 508 | } 509 | return z, 
zparent 510 | } 511 | 512 | func (t *avl[K, V]) deleteSwap(pos **anode[K, V]) { 513 | x := *pos 514 | z, zparent := t.deleteMin(&x.right) 515 | 516 | *pos = z 517 | if zparent == x { 518 | zparent = z 519 | } 520 | z.parent = x.parent 521 | z.height = x.height 522 | z.bal = x.bal 523 | z.setLeft(x.left) 524 | z.setRight(x.right) 525 | 526 | t.rebalanceUp(zparent) 527 | } 528 | 529 | func (n *anode[K, V]) checkAll() { 530 | return 531 | if n == nil { 532 | return 533 | } 534 | if n.height != 1+max(n.left.safeHeight(), n.right.safeHeight()) { 535 | panic("bad height") 536 | } 537 | n.checkbal() 538 | n.left.checkAll() 539 | n.right.checkAll() 540 | } 541 | 542 | func (n *anode[K, V]) checkParents(p *anode[K, V]) { 543 | return 544 | if n == nil { 545 | return 546 | } 547 | if n.parent != p { 548 | panic("bad parent") 549 | } 550 | n.left.checkParents(n) 551 | n.right.checkParents(n) 552 | n.checkbal() 553 | } 554 | 555 | func (t *avlMap[K, V]) Dump() string { 556 | return t.root.dump() 557 | } 558 | 559 | func (root *anode[K, V]) dump() string { 560 | var buf bytes.Buffer 561 | var walk func(*anode[K, V]) 562 | walk = func(x *anode[K, V]) { 563 | if x == nil { 564 | fmt.Fprintf(&buf, "nil") 565 | return 566 | } 567 | fmt.Fprintf(&buf, "(%d/%d %v:%v ", x.bal, x.height, x.key, x.val) 568 | walk(x.left) 569 | fmt.Fprintf(&buf, " ") 570 | walk(x.right) 571 | fmt.Fprintf(&buf, ")") 572 | } 573 | walk(root) 574 | return buf.String() 575 | } 576 | 577 | func (t *avlMap[K, V]) DeleteRange(lo, hi K) { 578 | if t == nil { 579 | panic("nil DeleteRange") 580 | } 581 | if lo > hi { 582 | return 583 | } 584 | t.deleteRange(lo, hi, t.split) 585 | } 586 | 587 | func (t *avlMapFunc[K, V]) DeleteRange(lo, hi K) { 588 | if t == nil { 589 | panic("nil DeleteRange") 590 | } 591 | if t.cmp(lo, hi) > 0 { 592 | return 593 | } 594 | t.deleteRange(lo, hi, t.split) 595 | } 596 | 597 | func (t *avl[K, V]) deleteRange(lo, hi K, split func(K) (*anode[K, V], avl[K, V])) { 598 | _, after := 
split(hi) 599 | _, middle := split(lo) 600 | t.join(nil, after) 601 | middle.root.markDeleted() 602 | } 603 | 604 | func (x *anode[K, V]) markDeleted() { 605 | if x == nil { 606 | return 607 | } 608 | x.height = -1 609 | x.left.markDeleted() 610 | x.right.markDeleted() 611 | } 612 | 613 | // All returns an iterator over the map m. 614 | // If m is modified during the iteration, some keys may not be visited. 615 | // No keys will be visited multiple times. 616 | func (m *avlMap[K, V]) All() iter.Seq2[K, V] { 617 | return m.all(m.locate) 618 | } 619 | 620 | // All returns an iterator over the map m. 621 | // If m is modified during the iteration, some keys may not be visited. 622 | // No keys will be visited multiple times. 623 | func (m *avlMapFunc[K, V]) All() iter.Seq2[K, V] { 624 | return m.all(m.locate) 625 | } 626 | 627 | func (t *avl[K, V]) all(locate func(K) (**anode[K, V], *anode[K, V])) iter.Seq2[K, V] { 628 | return func(yield func(K, V) bool) { 629 | if t == nil { 630 | return 631 | } 632 | x := t.root 633 | if x != nil { 634 | for x.left != nil { 635 | x = x.left 636 | } 637 | } 638 | for x != nil && yield(x.key, x.val) { 639 | if x.height >= 0 { 640 | // still in tree 641 | x = x.next() 642 | } else { 643 | // deleted 644 | x = t.nextAfter(locate(x.key)) 645 | } 646 | } 647 | } 648 | } 649 | 650 | func (x *anode[K, V]) next() *anode[K, V] { 651 | if x.right == nil { 652 | for x.parent != nil && x.parent.right == x { 653 | x = x.parent 654 | } 655 | return x.parent 656 | } 657 | x = x.right 658 | for x.left != nil { 659 | x = x.left 660 | } 661 | return x 662 | } 663 | 664 | func (t *avl[K, V]) nextAfter(pos **anode[K, V], parent *anode[K, V]) *anode[K, V] { 665 | switch { 666 | case *pos != nil: 667 | return (*pos).next() 668 | case parent == nil: 669 | return nil 670 | case pos == &parent.left: 671 | return parent 672 | default: 673 | return parent.next() 674 | } 675 | } 676 | 677 | // Scan returns an iterator over the map m 678 | // limited to keys k 
satisfying lo ≤ k ≤ hi. 679 | // 680 | // If m is modified during the iteration, some keys may not be visited. 681 | // No keys will be visited multiple times. 682 | func (m *avlMap[K, V]) Scan(lo, hi K) iter.Seq2[K, V] { 683 | return m.scan(lo, hi, cmp.Compare[K], m.locate) 684 | } 685 | 686 | // Scan returns an iterator over the map m 687 | // limited to keys k satisfying lo ≤ k ≤ hi. 688 | // 689 | // If m is modified during the iteration, some keys may not be visited. 690 | // No keys will be visited multiple times. 691 | func (m *avlMapFunc[K, V]) Scan(lo, hi K) iter.Seq2[K, V] { 692 | return m.scan(lo, hi, m.cmp, m.locate) 693 | } 694 | 695 | func (t *avl[K, V]) scan(lo, hi K, cmp func(K, K) int, locate func(K) (**anode[K, V], *anode[K, V])) iter.Seq2[K, V] { 696 | return func(yield func(K, V) bool) { 697 | if t == nil { 698 | return 699 | } 700 | pos, parent := locate(lo) 701 | x := *pos 702 | if x == nil { 703 | x = t.nextAfter(pos, parent) 704 | } 705 | for x != nil && cmp(x.key, hi) <= 0 && yield(x.key, x.val) { 706 | if x.height >= 0 { 707 | x = x.next() 708 | } else { 709 | x = t.nextAfter(locate(x.key)) 710 | } 711 | } 712 | } 713 | } 714 | 715 | func (t *avl[K, V]) Dump() string { 716 | var buf bytes.Buffer 717 | var walk func(*anode[K, V]) 718 | walk = func(x *anode[K, V]) { 719 | if x == nil { 720 | fmt.Fprintf(&buf, "nil") 721 | return 722 | } 723 | fmt.Fprintf(&buf, "(h%d/b%+d %v:%v ", x.height, x.bal, x.key, x.val) 724 | walk(x.left) 725 | fmt.Fprintf(&buf, " ") 726 | walk(x.right) 727 | fmt.Fprintf(&buf, ")") 728 | } 729 | walk(t.root) 730 | return buf.String() 731 | 732 | } 733 | -------------------------------------------------------------------------------- /llrb.go: -------------------------------------------------------------------------------- 1 | // Copyright 2024 The Go Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 
4 | 5 | package omap 6 | 7 | import ( 8 | "bytes" 9 | "cmp" 10 | "fmt" 11 | "iter" 12 | ) 13 | 14 | // A Map is a map[K]V ordered according to K's standard Go ordering. 15 | // The zero value of a Map is an empty Map ready to use. 16 | type llrbMap[K cmp.Ordered, V any] struct { 17 | llrb[K, V] 18 | } 19 | 20 | type llrb[K, V any] struct { 21 | root *rbnode[K, V] 22 | } 23 | 24 | type llrbMapFunc[K, V any] struct { 25 | llrb[K, V] 26 | cmp func(K, K) int 27 | } 28 | 29 | func (t *llrbMapFunc[K, V]) init(cmp func(K, K) int) { 30 | t.cmp = cmp 31 | } 32 | 33 | // An rbnode is a node in the LLRB tree. 34 | type rbnode[K, V any] struct { 35 | parent *rbnode[K, V] 36 | left *rbnode[K, V] 37 | right *rbnode[K, V] 38 | red bool 39 | del bool 40 | key K 41 | val V 42 | } 43 | 44 | func (x *rbnode[K, V]) deleted() bool { 45 | return x.del 46 | } 47 | 48 | func (t *llrb[K, V]) Depth() int { 49 | return t.root.depth() 50 | } 51 | 52 | func (x *rbnode[K, V]) depth() int { 53 | if x == nil { 54 | return -1 55 | } 56 | return 1 + max(x.left.depth(), x.right.depth()) 57 | } 58 | 59 | func (t *llrb[K, V]) setRoot(x *rbnode[K, V]) { 60 | t.root = x 61 | if x != nil { 62 | x.parent = nil 63 | } 64 | } 65 | 66 | func (x *rbnode[K, V]) setLeft(y *rbnode[K, V]) { 67 | x.left = y 68 | if y != nil { 69 | y.parent = x 70 | } 71 | } 72 | 73 | func (x *rbnode[K, V]) setRight(y *rbnode[K, V]) { 74 | x.right = y 75 | if y != nil { 76 | y.parent = x 77 | } 78 | } 79 | 80 | func (m *llrbMap[K, V]) Get(key K) (val V, ok bool) { 81 | x := m.get(key) 82 | if x == nil { 83 | return 84 | } 85 | return x.val, true 86 | } 87 | 88 | func (m *llrbMap[K, V]) get(key K) *rbnode[K, V] { 89 | if m == nil { 90 | return nil 91 | } 92 | x := m.root 93 | for x != nil { 94 | if key == x.key { 95 | return x 96 | } 97 | if key < x.key { 98 | x = x.left 99 | } else { 100 | x = x.right 101 | } 102 | } 103 | return nil 104 | } 105 | 106 | func (m *llrbMapFunc[K, V]) Get(key K) (val V, ok bool) { 107 | x := m.get(key) 
108 | if x == nil { 109 | return 110 | } 111 | return x.val, true 112 | } 113 | 114 | func (m *llrbMapFunc[K, V]) get(key K) *rbnode[K, V] { 115 | if m == nil { 116 | return nil 117 | } 118 | x := m.root 119 | for x != nil { 120 | c := m.cmp(key, x.key) 121 | if c == 0 { 122 | return x 123 | } 124 | if c < 0 { 125 | x = x.left 126 | } else { 127 | x = x.right 128 | } 129 | } 130 | return nil 131 | } 132 | 133 | func (x *rbnode[K, V]) isRed() bool { 134 | if x == nil { 135 | return false 136 | } 137 | return x.red 138 | } 139 | 140 | func (t *llrb[K, V]) replaceChild(p, old, x *rbnode[K, V]) { 141 | switch { 142 | case p == nil: 143 | if t.root != old { 144 | panic("corrupt llrb") 145 | } 146 | t.setRoot(x) 147 | case p.left == old: 148 | p.setLeft(x) 149 | case p.right == old: 150 | p.setRight(x) 151 | default: 152 | panic("corrupt llrb") 153 | } 154 | } 155 | 156 | // rotateRight rotates the subtree rooted at node y. 157 | // turning (y (x a b) c) into (x a (y b c)). 158 | func (t *llrb[K, V]) rotateRight(y *rbnode[K, V]) *rbnode[K, V] { 159 | //m.Rotates++ 160 | // p -> (y (x a b) c) 161 | p := y.parent 162 | x := y.left 163 | b := x.right 164 | 165 | x.setRight(y) 166 | y.setLeft(b) 167 | t.replaceChild(p, y, x) 168 | 169 | x.red = y.red 170 | y.red = true 171 | 172 | return x 173 | } 174 | 175 | // rotateLeft rotates the subtree rooted at node x. 176 | // turning (x a (y b c)) into (y (x a b) c). 
177 | func (t *llrb[K, V]) rotateLeft(x *rbnode[K, V]) *rbnode[K, V] { 178 | //m.Rotates++ 179 | // p -> (x a (y b c)) 180 | p := x.parent 181 | y := x.right 182 | b := y.left 183 | 184 | y.setLeft(x) 185 | x.setRight(b) 186 | t.replaceChild(p, x, y) 187 | 188 | y.red = x.red 189 | x.red = true 190 | 191 | return y 192 | } 193 | 194 | func (t *llrb[K, V]) flipColors(x *rbnode[K, V]) { 195 | x.red = !x.red 196 | x.left.red = !x.left.red 197 | x.right.red = !x.right.red 198 | } 199 | 200 | func (m *llrbMap[K, V]) checkAll() { 201 | defer func() { 202 | if e := recover(); e != nil { 203 | println(m.Dump()) 204 | panic(e) 205 | } 206 | }() 207 | m.root.checkAll(cmp.Compare[K]) 208 | } 209 | 210 | func (m *llrbMapFunc[K, V]) checkAll() { 211 | defer func() { 212 | if e := recover(); e != nil { 213 | println(m.Dump()) 214 | panic(e) 215 | } 216 | }() 217 | m.root.checkAll(m.cmp) 218 | } 219 | 220 | func (x *rbnode[K, V]) checkAll(cmp func(K, K) int) int { 221 | if x == nil { 222 | return -1 223 | } 224 | if x.left != nil && cmp(x.left.key, x.key) >= 0 { 225 | panic("bad left order") 226 | } 227 | if x.left != nil && x.left.parent != x { 228 | println("P", x.key, x.left.key, x.left.parent.key) 229 | panic("bad left parent") 230 | } 231 | if x.right != nil && cmp(x.key, x.right.key) >= 0 { 232 | panic("bad right order") 233 | } 234 | if x.right != nil && x.right.parent != x { 235 | panic("bad right parent") 236 | } 237 | if x.red && x.left != nil && x.left.red { 238 | panic("bad llrb double red left") 239 | } 240 | if x.red && x.right != nil && x.right.red { 241 | panic("bad llrb double red right") 242 | } 243 | if x.right.isRed() { 244 | panic("bad llrb right red") 245 | } 246 | h1 := x.left.checkAll(cmp) 247 | h2 := x.right.checkAll(cmp) 248 | if h1 != h2 { 249 | panic("bad llrb height") 250 | } 251 | if !x.red { 252 | h1++ 253 | } 254 | return h1 255 | } 256 | 257 | func (t *llrb[K, V]) recolorUp(x *rbnode[K, V]) { 258 | for x != nil { 259 | if x.right.isRed() && 
!x.left.isRed() {
			x = t.rotateLeft(x)
		}
		if x.left.isRed() && x.left.left.isRed() {
			x = t.rotateRight(x)
		}
		if x.left.isRed() && x.right.isRed() {
			t.flipColors(x)
		}
		x = x.parent
	}
	if t.root != nil {
		t.root.red = false
	}
}

// Set sets m[key] = val.
func (m *llrbMap[K, V]) Set(key K, val V) {
	pos, parent := m.locate(key)
	m.set(key, val, pos, parent)
	m.checkAll()
}

// Set sets m[key] = val.
//
// Like llrbMap.Set and both Delete methods, it runs m.checkAll after the
// mutation so that both map variants are checked the same way.
func (m *llrbMapFunc[K, V]) Set(key K, val V) {
	pos, parent := m.locate(key)
	m.set(key, val, pos, parent)
	m.checkAll()
}

// set stores val under key at the link found by locate.
// pos and parent must be the results of locate(key).
// If *pos already holds a node, the key is present and only its value is
// replaced. Otherwise a new red node is linked in at *pos under parent, and
// recolorUp walks from the new node toward the root restoring the invariants.
func (t *llrb[K, V]) set(key K, val V, pos **rbnode[K, V], parent *rbnode[K, V]) {
	if x := *pos; x != nil {
		x.val = val
		return
	}
	x := &rbnode[K, V]{key: key, val: val, parent: parent, red: true}
	*pos = x
	t.recolorUp(x)
}

// Delete deletes m[key].
func (m *llrbMap[K, V]) Delete(key K) {
	m.delete(key, cmp.Compare[K])
	m.checkAll()
}

// Delete deletes m[key].
func (m *llrbMapFunc[K, V]) Delete(key K) {
	m.delete(key, m.cmp)
	m.checkAll()
}

// delete removes the node whose key compares equal to key under cmp, if any.
//
// The descent mirrors Sedgewick's left-leaning red-black delete: before
// stepping into a child whose link and grandchild link are both non-red, it
// applies moveRedLeft or moveRedRight. recolorUp then repairs the tree from
// the point of removal.
//
// If key is absent, the walk falls off the tree and recolorUp is still run
// from the last node visited, because the descent may already have rotated
// and recolored nodes on the path.
//
// When the matching node x has a right subtree, x is replaced in place by its
// in-order successor z (removed from x.right by deleteMin), which inherits
// x's children and color.
//
// The removed node is detached (parent, left, right cleared) and its del flag
// is set; presumably this is what deleted() reports to all and scan, which
// then re-locate their position.
func (t *llrb[K, V]) delete(key K, cmp func(K, K) int) {
	pos, parent, x := &t.root, t.root, t.root
	if x == nil {
		return
	}
	for {
		if x == nil {
			// Key not present; repair whatever the descent changed.
			t.recolorUp(parent)
			return
		}
		if cmp(key, x.key) < 0 {
			// x.left must be checked before reading x.left.left: the key may
			// be absent, so the descent can reach a node with no left child.
			if x.left != nil && !x.left.isRed() && !x.left.left.isRed() {
				x = t.moveRedLeft(x)
			}
			parent, pos, x = x, &x.left, x.left
			continue
		}
		if x.left.isRed() {
			x = t.rotateRight(x)
		}
		if cmp(key, x.key) == 0 && x.right == nil {
			// Found with no right subtree: unlink it directly.
			*pos = nil
			t.recolorUp(parent)
			break
		}
		// Same nil guard as on the left: x.right may be nil when key is absent.
		if x.right != nil && !x.right.isRed() && !x.right.left.isRed() {
			x = t.moveRedRight(x)
		}
		if cmp(key, x.key) == 0 {
			z, zparent := t.deleteMin(&x.right)
			if zparent == x {
				// z's old parent is x itself, which z is about to replace.
				zparent = z
			}
			z.setLeft(x.left)
			z.setRight(x.right)
			z.red = x.red
			t.replaceChild(x.parent, x, z)
			t.recolorUp(zparent)
			break
		}
		parent, pos, x = x, &x.right, x.right
	}

	// Detach the removed node and mark it deleted.
	x.parent = nil
	x.left = nil
	x.right = nil
	x.del = true
}

// locate returns the link pos where key is stored (or would be inserted)
// and the parent node owning that link (nil when pos is &m.root).
//
// Keys are ordered with cmp.Compare, the same ordering Delete, Scan and split
// use for llrbMap. This keeps float keys consistent: a NaN key matches an
// existing NaN entry and sorts before every non-NaN key. Plain == and <
// would never match NaN and would send it to the right.
func (m *llrbMap[K, V]) locate(key K) (pos **rbnode[K, V], parent *rbnode[K, V]) {
	pos, x := &m.root, m.root
	for x != nil {
		c := cmp.Compare(key, x.key)
		if c == 0 {
			break
		}
		parent = x
		if c < 0 {
			pos, x = &x.left, x.left
		} else {
			pos, x = &x.right, x.right
		}
	}
	return pos, parent
}

// moveRedLeft flips the colors at x. Then, if x.right has a red left child,
// it rotates right at x.right, rotates left at x, and flips again.
// It returns x, or the node returned by rotateLeft if a rotation happened.
func (t *llrb[K, V]) moveRedLeft(x *rbnode[K, V]) *rbnode[K, V] {
	t.flipColors(x)
	if x.right != nil && x.right.left.isRed() {
		t.rotateRight(x.right)
		x = t.rotateLeft(x)
		t.flipColors(x)
	}
	return x
}

// moveRedRight flips the colors at x. Then, if x.left has a red left child,
// it rotates right at x and flips again.
// It returns x, or the node returned by rotateRight if a rotation happened.
func (t *llrb[K, V]) moveRedRight(x *rbnode[K, V]) *rbnode[K, V] {
	t.flipColors(x)
	if x.left != nil && x.left.left != nil && x.left.left.red {
		x = t.rotateRight(x)
		t.flipColors(x)
	}
	return x
}

// locate returns the link pos where key is stored (or would be inserted)
// and the parent node owning that link (nil when pos is &m.root),
// ordering keys with m.cmp.
func (m *llrbMapFunc[K, V]) locate(key K) (pos **rbnode[K, V], parent *rbnode[K, V]) {
	pos, x := &m.root, m.root
	for x != nil {
		c := m.cmp(key, x.key)
		if c == 0 {
			break
		}
		parent = x
		if c < 0 {
			pos, x = &x.left, x.left
		} else {
			pos, x = &x.right, x.right
		}
	}
	return pos, parent
}

// split forwards to llrb.split using cmp.Compare (currently unimplemented).
func (m *llrbMap[K, V]) split(key K) (x *rbnode[K, V], after llrb[K, V]) {
	return m.llrb.split(key, cmp.Compare[K])
}

// split forwards to llrb.split using m.cmp (currently unimplemented).
func (m *llrbMapFunc[K, V]) split(key K) (x *rbnode[K, V], after llrb[K, V]) {
	return m.llrb.split(key, m.cmp)
}

// split is not implemented yet and always panics.
// The commented-out body sketches the recursive split from Figure 1 of
// https://arxiv.org/pdf/1602.02120 in terms of join.
func (t *llrb[K, V]) split(key K, cmp func(K, K) int) (x *rbnode[K, V], after llrb[K, V]) {
	panic("split")
	/*
		// split(T,k) =
		//   if T = Leaf then (Leaf, false, Leaf)
		//   else
		//     (L,m,R) = expose(T)
		//     if k = m then (L, true, R)
		//     else if k < m then
		//       (LL, b, LR) = split(L, k)
		//       (LL, b, join(LR, m, R))
		//     else
		//       (RL, b, RR) = split(R, k)
		//       (join(L, m, RL), b, RR)
		//
		// (Figure 1)
		right := llrb[K, V]{}
		if t.root == nil {
			return nil, right
		}

		mid := t.root
		t.setRoot(mid.left)
		right.setRoot(mid.right)
		mid.left, mid.right = nil, nil

		c := cmp(key, mid.key)
		if c == 0 {
			return mid, right
		}
		if c < 0 {
			leftMid, leftRight := t.split(key, cmp)
			leftRight.join(mid, right)
			return leftMid, leftRight
		}
		if c > 0 {
			rightMid, rightRight := right.split(key, cmp)
			t.join(mid, right)
			return rightMid, rightRight
		}
		panic("unreachable")
	*/
}

/*
join(TL, k, TR) =
	if h(TL) > h(TR)+1 then joinRight(TL, k, TR)
	else if h(TR) > h(TL)+1 then joinLeft(TL,k,TR)
	else Node(TL, k, TR)

joinRight(TL, k, TR) =
	(l, k', c) = expose(TL)
	if h(c) <= h(TR)+1 then
		T' = Node(c,k,TR)
		if h(T') <= h(l)+1 then Node(l,k',T')
		else rotateLeft(Node(l,k',rotateRight(T'))
	else
		T' = joinRight(c,k,TR)
		T'' = Node(l,k',T')
		if h(T') <= h(l)+1 then T''
		else rotateLeft(T'')

(Figure 2)

split(T,k) =
	if T = Leaf then (Leaf, false, Leaf)
	else
		(L,m,R) = expose(T)
		if k = m then (L, true, R)
		else if k < m then
			(LL, b, LR) = split(L, k)
			(LL, b, join(LR, m, R))
		else
			(RL, b, RR) = split(R, k)
			(join(L, m, RL), b, RR)

(Figure 1)

https://arxiv.org/pdf/1602.02120
*/

// min returns the link holding the leftmost node of t
// (&t.root itself when t is empty or the root has no left child).
func (t *llrb[K, V]) min() **rbnode[K, V] {
	pos, x := &t.root, t.root
	for x != nil && x.left != nil {
		pos, x = &x.left, x.left
	}
	return pos
}

// max returns the link holding the rightmost node of t
// (&t.root itself when t is empty or the root has no right child).
func (t *llrb[K, V]) max() **rbnode[K, V] {
	pos, x := &t.root, t.root
	for x != nil && x.right != nil {
		pos, x = &x.right, x.right
	}
	return pos
}

// NOTE(review): the join/Split sketch below does not compile against the
// current code. It refers to height and safeHeight, which the visible llrb
// code never uses, and it calls after.delete(pos), while delete now takes
// (key, cmp). It appears to be adapted from a height-based tree and needs
// reworking before it can be re-enabled.
/*
func (t *llrb[K, V]) join(y *rbnode[K, V], after llrb[K, V]) {
	if y == nil {
		pos := after.min()
		y = *pos
		if y == nil {
			return
		}
		after.delete(pos)
	}

	if y.left != nil || y.right != nil || y.height >= 0 {
		panic("llrb join misuse")
	}

	x := t.root
	z := after.root
	xh := x.safeHeight()
	zh := z.safeHeight()

	switch {
	case xh > zh+1:
		for x.right != nil && x.right.height > zh {
			x = x.right
		}
		// now x.height > zh but x.right.height <= zh
		// replacing x.right with y=node{x.right, z} will grow x.right.height at most 1
		// println("JOIN X", x.safeHeight(), x.left.safeHeight(), x.right.safeHeight(), z.safeHeight())
		y.setLeft(x.right)
		y.setRight(z)
		x.setRight(y)
		y.height = -1
		t.recolorUp(y)

	case zh > xh+1:
		for z.left != nil && z.left.height > xh {
			z = z.left
		}
		// println("JOIN Z", z.safeHeight(), z.left.safeHeight(), z.right.safeHeight(), x.safeHeight())
		y.setLeft(x)
		y.setRight(z.left)
		z.setLeft(y)
		y.height = -1
		t.root = after.root
		t.recolorUp(y)

	default:
		y.setLeft(x)
		y.setRight(z)
		t.setRoot(y)
		t.recolorUp(y)
	}

	after.root = nil
	t.recolorUp(y)
}

func (m *llrbMap[K, V]) Split(key K) (val V, ok bool, more llrb[K, V]) {
	mid, more := m.split(key)
	if mid != nil {
		val, ok = mid.val, true
	}
	return val, ok, more
}
*/

// deleteMin unlinks the leftmost node of the subtree rooted at *zpos and
// returns it (z) together with its former parent (zparent).
// On the way down it applies moveRedLeft wherever both z.left and
// z.left.left are non-red. The leftmost node must have no right child;
// deleteMin panics otherwise.
func (t *llrb[K, V]) deleteMin(zpos **rbnode[K, V]) (z, zparent *rbnode[K, V]) {
	z = *zpos
	for {
		if z.left == nil {
			zparent = z.parent
			if z.right != nil {
				panic("bad z.right")
			}
			*zpos = nil
			return z, zparent
		}
		if !z.left.isRed() && !z.left.left.isRed() {
			z = t.moveRedLeft(z)
		}
		zpos, z = &z.left, z.left
	}
}

// dump renders the subtree rooted at root for debugging.
// Each node prints as "(key:val left right)", prefixed by "RED " inside the
// parenthesis when red. Children are omitted for leaves, and an empty child
// prints as "nil".
func (root *rbnode[K, V]) dump() string {
	var buf bytes.Buffer
	var walk func(*rbnode[K, V])
	walk = func(x *rbnode[K, V]) {
		if x == nil {
			buf.WriteString("nil")
			return
		}
		buf.WriteString("(")
		if x.red {
			buf.WriteString("RED ")
		}
		fmt.Fprintf(&buf, "%v:%v", x.key, x.val)
		if x.left != nil || x.right != nil {
			buf.WriteString(" ")
			walk(x.left)
			buf.WriteString(" ")
			walk(x.right)
		}
		buf.WriteString(")")
	}
	walk(root)
	return buf.String()
}

// DeleteRange is meant to delete all entries with lo ≤ key ≤ hi.
// It panics on a nil map and returns immediately when lo orders after hi
// (under cmp.Compare, matching the ordering used by locate, Delete and Scan).
// Otherwise it currently panics, because deleteRange is not implemented yet.
func (m *llrbMap[K, V]) DeleteRange(lo, hi K) {
	if m == nil {
		panic("nil DeleteRange")
	}
	if cmp.Compare(lo, hi) > 0 {
		return
	}
	m.deleteRange(lo, hi, m.split)
}

// DeleteRange is meant to delete all entries with lo ≤ key ≤ hi under m.cmp.
// It panics on a nil map and returns immediately when lo orders after hi.
// Otherwise it currently panics, because deleteRange is not implemented yet.
func (m *llrbMapFunc[K, V]) DeleteRange(lo, hi K) {
	if m == nil {
		panic("nil DeleteRange")
	}
	if m.cmp(lo, hi) > 0 {
		return
	}
	m.deleteRange(lo, hi, m.split)
}

// deleteRange is not implemented yet and always panics.
// The commented-out sketch calls split(hi), then split(lo), joins t with the
// part after hi, and marks the middle subtree deleted.
func (t *llrb[K, V]) deleteRange(lo, hi K, split func(K) (*rbnode[K, V], llrb[K, V])) {
	panic("deleteRange")
	/*
		_, after := split(hi)
		_, middle := split(lo)
		t.join(nil, after)
		middle.root.markDeleted()
	*/
}

/*
func (x *rbnode[K, V]) markDeleted() {
	if x == nil {
		return
	}
	x.height = -1
	x.left.markDeleted()
	x.right.markDeleted()
}
*/

// All returns an iterator over the map m.
// If m is modified during the iteration, some keys may not be visited.
// No keys will be visited multiple times.
func (m *llrbMap[K, V]) All() iter.Seq2[K, V] {
	return m.all(m.locate)
}

// All returns an iterator over the map m.
// If m is modified during the iteration, some keys may not be visited.
// No keys will be visited multiple times.
func (m *llrbMapFunc[K, V]) All() iter.Seq2[K, V] {
	return m.all(m.locate)
}

// all yields every entry in key order, starting from the leftmost node.
// If the current node has been deleted by the time yield returns, iteration
// resumes from its key's position as reported by locate rather than
// following the node's (cleared) links.
func (t *llrb[K, V]) all(locate func(K) (**rbnode[K, V], *rbnode[K, V])) iter.Seq2[K, V] {
	return func(yield func(K, V) bool) {
		if t == nil {
			return
		}
		x := t.root
		if x != nil {
			for x.left != nil {
				x = x.left
			}
		}
		for x != nil && yield(x.key, x.val) {
			if x.deleted() {
				x = t.nextAfter(locate(x.key))
			} else {
				x = x.next()
			}
		}
	}
}

// next returns the in-order successor of x, or nil if x is the last node.
func (x *rbnode[K, V]) next() *rbnode[K, V] {
	if x.right == nil {
		// Climb while x is a right child; the first ancestor reached from
		// its left side is the successor.
		for x.parent != nil && x.parent.right == x {
			x = x.parent
		}
		return x.parent
	}
	x = x.right
	for x.left != nil {
		x = x.left
	}
	return x
}

// nextAfter returns the first node that follows the position (pos, parent)
// reported by locate:
//   - if a node occupies *pos, its successor;
//   - if the slot is empty and is the root, nil;
//   - if the slot is empty and is parent's left link, parent itself;
//   - otherwise parent's successor.
func (t *llrb[K, V]) nextAfter(pos **rbnode[K, V], parent *rbnode[K, V]) *rbnode[K, V] {
	switch {
	case *pos != nil:
		return (*pos).next()
	case parent == nil:
		return nil
	case pos == &parent.left:
		return parent
	default:
		return parent.next()
	}
}

// Scan returns an iterator over the map m
// limited to keys k satisfying lo ≤ k ≤ hi.
//
// If m is modified during the iteration, some keys may not be visited.
// No keys will be visited multiple times.
func (m *llrbMap[K, V]) Scan(lo, hi K) iter.Seq2[K, V] {
	return m.scan(lo, hi, cmp.Compare[K], m.locate)
}

// Scan returns an iterator over the map m
// limited to keys k satisfying lo ≤ k ≤ hi.
//
// If m is modified during the iteration, some keys may not be visited.
// No keys will be visited multiple times.
func (m *llrbMapFunc[K, V]) Scan(lo, hi K) iter.Seq2[K, V] {
	return m.scan(lo, hi, m.cmp, m.locate)
}

// scan yields entries in key order, starting at lo (or the first key after
// lo when lo is absent) and stopping once a key orders after hi under cmp.
// Deleted nodes are handled as in all: iteration resumes via locate.
func (t *llrb[K, V]) scan(lo, hi K, cmp func(K, K) int, locate func(K) (**rbnode[K, V], *rbnode[K, V])) iter.Seq2[K, V] {
	return func(yield func(K, V) bool) {
		if t == nil {
			return
		}
		pos, parent := locate(lo)
		x := *pos
		if x == nil {
			x = t.nextAfter(pos, parent)
		}
		for x != nil && cmp(x.key, hi) <= 0 && yield(x.key, x.val) {
			if x.deleted() {
				x = t.nextAfter(locate(x.key))
			} else {
				x = x.next()
			}
		}
	}
}

// Dump returns a debugging representation of t's tree in the format
// produced by (*rbnode).dump.
func (t *llrb[K, V]) Dump() string {
	return t.root.dump()
}

--------------------------------------------------------------------------------