├── README.md ├── algebra └── algebra.go2 ├── build ├── concur └── concur.go2 ├── future └── cache.go2 ├── go.mod ├── hacks ├── empty.s └── hacks.go ├── mapreduce └── mapreduce.go2 ├── maps └── sortedmap.go2 ├── metric └── metric.go2 ├── number └── number.go2 ├── oddities └── odd.go2 ├── pq └── pq.go2 ├── slices └── slices.go2 ├── stream └── stream.go2 └── striped └── striped.go2 /README.md: -------------------------------------------------------------------------------- 1 | # generics 2 | Quick experiments with Go generics 3 | 4 | - `algebra`, a generic square root function for float, complex and rational. 5 | - `concur`, various concurrency utilities. 6 | - `future`, a concurrent cache ("future cache"). 7 | - `mapreduce`, parallel Map, Reduce, and ForEach utilities 8 | - `maps`, a map with sorted keys based on a binary tree. 9 | - `metric`, a streamz-style multidimensional variable for production monitoring. 10 | - `number`, generic functions related to numbers (min, max, abs) and a user-defined complex type. 11 | - `oddities`, bugs and quirks. 12 | - `pq`, a priority queue 13 | - `slices`, generic slice utilities, and a user-defined Slice type. 14 | - `stream`, a streams library. 15 | - `striped`, a concurrency-safe map using lock striping, and also a custom hash/eq relation. 16 | 17 | First impression: 18 | 19 | This is really nice. It addresses the main things I miss about generics, namely 20 | 21 | - being able to change the equivalence relation of map; 22 | - better APIs for data structures like trees, graphs, and priority queues; and 23 | - the ability to generate efficient specialized code for a range of data types 24 | (though I understand that's not guaranteed); 25 | 26 | and does so without overwhelming complexity. 27 | 28 | By comparison, C++'s templates are extremely powerful, but syntactically and 29 | semantically quite complex, and historically the error messages have 30 | been both confusing and too late. 
(Templates also lead to considerable bloat 31 | in the text segment, but that may be a risk of the Go approach too.) 32 | Java's generics are limited to reference types, and thus are no use for 33 | efficient algorithms on (say) arrays of integers. 34 | Somehow the Go approach seems to do most of what I need while still 35 | feeling simple and easy to use. 36 | 37 | Observations: 38 | 39 | - Slices can now be implemented in the language, but not without `unsafe` pointer arithmetic. 40 | The generic slice algorithms work nicely. So does sorting with a custom order. 41 | 42 | - `unsafe.Sizeof` is disallowed on type parameters, yet it can be simulated using pointer arithmetic 43 | (though not as a constant expression). Why? 44 | 45 | - I often need a hash table with an alternative hash function, for 46 | - comparing non-canonical pointers (e.g. *big.Int, go/types.Type) by their referent; 47 | - comparing non-comparable values (such as slices) under the obvious relation; 48 | - using an alternative comparator (e.g. case insensitive, absolute value) for simple types. 49 | However, the custom hash function often wants to be at least partly defined in terms of 50 | the standard hash function, so the latter needs to be exposed somehow; see hacks.RuntimeHash. 51 | I imagine that could be problematic. 52 | 53 | - min, max, abs work nicely. 54 | 55 | - Go's built-in complex numbers could be satisfactorily replaced by a library. 56 | 57 | - The abstract algebraic ring generates pretty good code. 58 | One can imagine writing some numerical analysis routines this way 59 | when the algorithm is sufficiently complex that it is best not duplicated. 60 | 61 | - I couldn't find a way to achieve ad-hoc polymorphism, that is, defining a generic 62 | function by cases specialized to each possible type. For example, I don't know 63 | how to write a generic version of all the math/bits.OnesCount functions that uses 64 | the most efficient implementation. 
65 | (Typeswitch doesn't handle named variants of built-in types; and using reflect is cheating.) 66 | 67 | - In practice I suspect concurrent loop abstractions will nearly always want additional 68 | parameters such as: a limit on parallelism; a context; cancellation; control over 69 | errors (e.g. ignore, choose first, choose arbitrarily, or combine all). 70 | 71 | Users will no doubt build generic libraries of collections, of stream processing functions, 72 | and of numeric analysis routines. The design space for each is large, and I imagine arriving 73 | at simple, efficient, and coherent APIs worthy of the standard library will be an arduous task. 74 | But there is no need to hurry. 75 | 76 | My experiments have tended to parameterize over built-in types, or "any". 77 | I should probably spend more time investigating interface-constrained types. 78 | -------------------------------------------------------------------------------- /algebra/algebra.go2: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "fmt" 5 | "math/big" 6 | ) 7 | 8 | // Algebra is a set of traits that define an algebraic ring. 9 | // It may be defined for various kinds of number, for example. 10 | type Algebra[N any] interface { 11 | Add(N, N) N 12 | Sub(N, N) N 13 | Mul(N, N) N 14 | Div(N, N) N 15 | } 16 | 17 | // sqrt is a generic square root algorithm using Newton's method. 18 | // The generated code is efficient and specialized. 19 | func sqrt[N any, A Algebra[N]](a N) N { 20 | var alg A // trivial instantiation: Algebra is traits, not state 21 | x := a 22 | // TODO: stop when f_ is smaller than some algebra-defined epsilon. 
23 | for i := 0; i < 10; i++ { 24 | f := alg.Sub(alg.Mul(x, x), a) 25 | f_ := alg.Add(x, x) 26 | x = alg.Sub(x, alg.Div(f, f_)) 27 | } 28 | return x 29 | } 30 | 31 | type float struct{} 32 | 33 | var _ Algebra[float64] = float{} 34 | 35 | func (float) Add(x, y float64) float64 { return x + y } 36 | func (float) Sub(x, y float64) float64 { return x - y } 37 | func (float) Mul(x, y float64) float64 { return x * y } 38 | func (float) Div(x, y float64) float64 { return x / y } 39 | 40 | type rational struct{} 41 | 42 | var _ Algebra[*big.Rat] = rational{} 43 | 44 | func (rational) Add(x, y *big.Rat) *big.Rat { return new(big.Rat).Add(x, y) } 45 | func (rational) Sub(x, y *big.Rat) *big.Rat { return new(big.Rat).Sub(x, y) } 46 | func (rational) Mul(x, y *big.Rat) *big.Rat { return new(big.Rat).Mul(x, y) } 47 | func (rational) Div(x, y *big.Rat) *big.Rat { return new(big.Rat).Mul(x, new(big.Rat).Inv(y)) } 48 | 49 | // This is currently the same as float, 50 | // but would diverge as we add operators. 51 | type complex struct{} 52 | 53 | var _ Algebra[complex128] = complex{} 54 | 55 | func (complex) Add(x, y complex128) complex128 { return x + y } 56 | func (complex) Sub(x, y complex128) complex128 { return x - y } 57 | func (complex) Mul(x, y complex128) complex128 { return x * y } 58 | func (complex) Div(x, y complex128) complex128 { return x / y } 59 | 60 | func main() { 61 | // Observation: unless we define Sqrt{Float,Rat,Complex} wrappers, 62 | // the need to specify both the numeric type and its algebra is clumsy. 
63 | 64 | // Prints: 1.1111111060555556 65 | fmt.Println(sqrt[float64, float](1.23456789)) 66 | 67 | // Prints: 1.11111110605555554405416661433535 68 | fmt.Println(sqrt[*big.Rat, rational](big.NewRat(123456789, 1e8)).FloatString(32)) 69 | 70 | // Prints: (1.1111111060555556+0i) 71 | fmt.Println(sqrt[complex128, complex](1.23456789)) 72 | } 73 | -------------------------------------------------------------------------------- /build: -------------------------------------------------------------------------------- 1 | # Emacs: use 'M-x compile ../build' from within any .go2 file. 2 | 3 | set -eux 4 | export GOROOT=$HOME/w/goroot 5 | export PATH=$GOROOT/bin:$PATH 6 | go tool go2go translate *.go2 7 | # go tool compile -S *.go 8 | go run . 9 | 10 | -------------------------------------------------------------------------------- /concur/concur.go2: -------------------------------------------------------------------------------- 1 | // Package concur provides various concurrency utilities. 2 | package main 3 | 4 | import ( 5 | "context" 6 | "fmt" 7 | "time" 8 | ) 9 | 10 | // The APIs below all potentially encourage goroutine leaks. 11 | // The inability to parameterize over a function's result type (e.g. T vs. (X, Y) vs (T, error)) 12 | // suggests that these utilities might be needed in many variants. 13 | // If the called function needs a context, presumably f would have to close over it. 14 | 15 | // CallWithCancel returns the result of calling f(), or an error if the context is cancelled. 16 | // Even upon cancellation, the call to f always runs to completion. 17 | func CallWithCancel[T any](ctx context.Context, f func() (T, error)) (T, error) { 18 | return CallUntilDone(ctx.Done(), f) 19 | } 20 | 21 | // CallWithTimeout is a variant of CallUntilDone that takes a duration. 
22 | func CallWithTimeout[T any](timeout time.Duration, f func() (T, error)) (T, error) { 23 | return CallUntilDone(time.After(timeout), f) 24 | } 25 | 26 | // CallUntilDone is a variant of CallWithCancel that takes a done channel. 27 | // It returns context.DeadlineExceeded if the done channel is closed before the call returns. 28 | func CallUntilDone[T, U any](done <- chan U, f func() (T, error)) (T, error) { 29 | type result struct { T; error } 30 | ch := make(chan result, 1) 31 | go func() { 32 | var res result 33 | res.T, res.error = f() 34 | ch <- res 35 | }() 36 | select { 37 | case res := <- ch: 38 | return res.T, res.error 39 | case <-done: 40 | return *new(T), context.DeadlineExceeded 41 | } 42 | } 43 | 44 | // -- test -- 45 | 46 | func main() { 47 | ctx2, _ := context.WithTimeout(context.Background(), 100*time.Millisecond) 48 | x, err := CallWithCancel[int] (ctx2, func() (int, error) { return slowAdd(1, 2), nil }) 49 | fmt.Println(x, err) 50 | 51 | ctx2, _ = context.WithTimeout(context.Background(), 100*time.Millisecond) 52 | x, err = CallUntilDone[int, struct{}] (ctx2.Done(), func() (int, error) { return slowAdd(1, 2), nil }) 53 | fmt.Println(x, err) 54 | 55 | x, err = CallWithTimeout[int] (100*time.Millisecond, func() (int, error) { return slowAdd(1, 2), nil }) 56 | fmt.Println(x, err) 57 | } 58 | 59 | func slowAdd(x, y int) int { 60 | time.Sleep(1 * time.Second) 61 | return x + y 62 | } 63 | -------------------------------------------------------------------------------- /future/cache.go2: -------------------------------------------------------------------------------- 1 | // A generic concurrency-safe cache that memoizes a function. 2 | package main 3 | 4 | import ( 5 | "fmt" 6 | "sync" 7 | "time" 8 | ) 9 | 10 | // TODO: 11 | // - make error result of f optional somehow. 12 | // - add option to suppress caching errors. 
13 | // - support cancellation 14 | 15 | // A future.Cache is a concurrency safe memoization of a function f such 16 | // that the value f(k) for each distinct key is computed at most once. 17 | // 18 | // TODO: if we used the generic hash map we could drop the 'comparable' constraint 19 | // and let the user specify it. 20 | type Cache[K comparable, V any] struct { 21 | mu sync.Mutex 22 | m map[K]*future[V] 23 | f func(K) (V, error) 24 | } 25 | 26 | // New returns a new Cache that memoizes calls to f. 27 | // f must be concurrency safe. 28 | func New[K comparable, V any](f func(K) (V, error)) *Cache[K, V] { 29 | return &Cache[K, V]{ 30 | m: make(map[K]*future[V]), 31 | f: f, 32 | } 33 | } 34 | 35 | type future[V any] struct { 36 | done chan struct{} 37 | value V 38 | err error 39 | } 40 | 41 | // Get returns the value of f(k). 42 | func (c *Cache[K, V]) Get(k K) (V, error) { 43 | c.mu.Lock() 44 | f, ok := c.m[k] 45 | if !ok { 46 | // first request: compute it 47 | f = &future[V]{done: make(chan struct{})} 48 | c.m[k] = f 49 | c.mu.Unlock() 50 | f.value, f.err = c.f(k) 51 | close(f.done) 52 | } else { 53 | // subsequent request: wait 54 | c.mu.Unlock() 55 | <-f.done 56 | } 57 | return f.value, f.err 58 | } 59 | 60 | func main() { 61 | t0 := time.Now() 62 | done := make(chan struct{}) 63 | cache := New[string, int](slowStrlen) 64 | go func() { 65 | fmt.Println(cache.Get("hello")) 66 | fmt.Println(cache.Get("world")) 67 | close(done) 68 | }() 69 | fmt.Println(cache.Get("hello")) 70 | fmt.Println(cache.Get("world")) 71 | <-done 72 | fmt.Println(time.Since(t0)) // about 2s (not 4) 73 | } 74 | 75 | func slowStrlen(s string) (int, error) { 76 | time.Sleep(time.Second) 77 | return len(s), nil 78 | } 79 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/adonovan/generics 2 | 3 | go 1.15 4 | 
-------------------------------------------------------------------------------- /hacks/empty.s: -------------------------------------------------------------------------------- 1 | // The presence of this file allows the package to use the 2 | // "go:linkname" hack to call non-exported functions in the 3 | // Go runtime, such as string hashing. 4 | -------------------------------------------------------------------------------- /hacks/hacks.go: -------------------------------------------------------------------------------- 1 | // Hacks needed for generics. 2 | // (Must be a separate package because the linkname hack doesn't work in go2go.) 3 | package hacks // import "generic/hacks" 4 | 5 | import "unsafe" 6 | 7 | // RuntimeHash returns the hash of key used by the Go runtime's maps, 8 | // or panics if key is unhashable. 9 | func RuntimeHash(key interface{}, seed uintptr) uintptr { 10 | type eface struct { t, v unsafe.Pointer } 11 | e := (*eface)(unsafe.Pointer(&key)) 12 | if e.v == nil { 13 | return 0 14 | } 15 | return typehash(e.t, e.v, seed) 16 | } 17 | 18 | //go:linkname typehash reflect.typehash 19 | func typehash(t, p unsafe.Pointer, h uintptr) uintptr 20 | -------------------------------------------------------------------------------- /mapreduce/mapreduce.go2: -------------------------------------------------------------------------------- 1 | // Generic utilities for parallel map/reduce and other functional-programming staples. 2 | package main 3 | 4 | import ( 5 | "fmt" 6 | "sync" 7 | ) 8 | 9 | // ForEach calls f(x) for x in elems, in parallel. 10 | func ForEach[T any](elems []T, f func(T)) { 11 | var wg sync.WaitGroup 12 | for _, elem := range elems { 13 | wg.Add(1) 14 | go func(elem T) { 15 | f(elem) 16 | wg.Done() 17 | }(elem) 18 | } 19 | wg.Wait() 20 | } 21 | 22 | // Map computes [f(x) for x in elems] in parallel. 
23 | func Map[A, B any](elems []A, f func(A) B) []B { 24 | res := make([]B, len(elems)) 25 | var wg sync.WaitGroup 26 | wg.Add(len(elems)) 27 | for i, x := range elems { 28 | go func(i int, x A) { 29 | res[i] = f(x) 30 | wg.Done() 31 | }(i, x) 32 | } 33 | wg.Wait() 34 | return res 35 | } 36 | 37 | // Reduce computes, in parallel, the reduction of a non-empty slice 38 | // using the binary operator f, which must be a monoid (e.g. +). 39 | // f may be called concurrently. 40 | func Reduce[T any](elems []T, f func(T, T) T) T { 41 | switch n := len(elems); n { 42 | case 0: 43 | panic("empty") 44 | case 1: 45 | return elems[0] 46 | default: 47 | x := make(chan T) 48 | go func() { x <- Reduce(elems[:n/2], f) }() 49 | y := Reduce(elems[n/2:], f) 50 | return f(<-x, y) 51 | } 52 | } 53 | 54 | // --test-- 55 | 56 | func main() { 57 | ForEach([]int{1, 2, 3}, func(x int) { fmt.Println(x) }) // prints 1, 2, 3 concurrently 58 | 59 | square := func(x int) int { return x * x } 60 | fmt.Println(Map([]int{1, 2, 3}, square)) // [1 4 9] 61 | 62 | add := func(x, y int) int { return x + y } 63 | fmt.Println(Reduce([]int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, add)) // 55 64 | mul := func(x, y int) int { return x * y } 65 | fmt.Println(Reduce([]int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, mul)) // 3628800 66 | } 67 | -------------------------------------------------------------------------------- /maps/sortedmap.go2: -------------------------------------------------------------------------------- 1 | // A map data structure with sorted keys. See also striped.Map. 
2 | package main 3 | 4 | import ( 5 | "fmt" 6 | "strings" 7 | ) 8 | 9 | // TODO: pass in the order relation 10 | type ordered interface { 11 | type int, int8, int16, int32, int64, 12 | uint, uint8, uint16, uint32, uint64, uintptr, 13 | float32, float64, 14 | string 15 | } 16 | 17 | type Map[K ordered, V any] struct {root *node[K,V] } 18 | 19 | func New[K ordered, V any]() *Map[K, V] { return new(Map[K, V]) } 20 | 21 | func (m *Map[K, V]) Put(k K, v V) { put(&m.root, k, v) } 22 | 23 | func (m *Map[K, V]) Get(k K) V { return m.root.get(k) } 24 | 25 | func (m *Map[K, V]) Keys() []K { return m.root.appendKeys(nil) } 26 | 27 | func (m *Map[K, V]) String() string { 28 | var buf strings.Builder 29 | buf.WriteString("{") 30 | for i, k := range m.Keys() { 31 | if i > 0 { 32 | buf.WriteString(", ") 33 | } 34 | fmt.Fprintf(&buf, "%v: %v", k, m.Get(k)) 35 | } 36 | buf.WriteString("}") 37 | return buf.String() 38 | } 39 | 40 | 41 | // -- impl -- 42 | 43 | // Trivial unbalanced binary tree. 44 | type node[K ordered, V any] struct { left, right *node[K, V]; k K; v V } 45 | 46 | func put[K ordered, V any](naddr **node[K, V], k K, v V) { 47 | if *naddr == nil { 48 | *naddr = &node[K, V]{k: k, v: v} 49 | return 50 | } 51 | n := *naddr 52 | if k < n.k { 53 | put(&n.left, k, v) 54 | } else if k > n.k { 55 | put(&n.right, k, v) 56 | } else { 57 | n.v = v 58 | } 59 | } 60 | 61 | func (n *node[K, V]) get(k K) (_ V) { 62 | if n == nil { 63 | return // nope 64 | } 65 | if k < n.k { 66 | return n.left.get(k) 67 | } else if k > n.k { 68 | return n.right.get(k) 69 | } else { 70 | return n.v 71 | } 72 | } 73 | 74 | func (n *node[K, V]) appendKeys(out []K) []K { 75 | if n != nil { 76 | out = n.left.appendKeys(out) 77 | out = append(out, n.k) 78 | out = n.right.appendKeys(out) 79 | } 80 | return out 81 | } 82 | 83 | func main() { 84 | var m Map[string, int] 85 | m.Put("two", -2) 86 | m.Put("three", 3) 87 | m.Put("two", 2) 88 | m.Put("one", 1) 89 | fmt.Println(&m) // {one: 1, three: 3, two: 2} 90 
| } 91 | 92 | -------------------------------------------------------------------------------- /metric/metric.go2: -------------------------------------------------------------------------------- 1 | // Metrics, in the manner of Google's monarch/streamz, which inspired prometheus.io. 2 | package main 3 | 4 | import ( 5 | "fmt" 6 | ) 7 | 8 | // A Metric is essentially a function. Parameterization over variable numbers of arguments 9 | // is not easily done in any generic type system. 10 | 11 | // I expect this design will cause a lot of code expansion, 12 | // as does the template-heavy C++ implementation of streamz used at Google, 13 | // which accounts for a significant chunk of the code size of large C++ binaries. 14 | // (I have always wondered whether that design was a mistake. Sorry about that.) 15 | 16 | // A Metric is a time-varying function of type V, whose argument is a key of type K, 17 | // which is a struct whose elements are each constrained to be a Field (int64, string, or bool). 18 | // (Algebraically, Field^n). 19 | // Metrics may be represented as a map of variables that is updated by the process, 20 | // or lazily, in terms of a callback function for each map cell. 21 | // 22 | // What about constraints on V? Usually it must be one of 23 | // - a counter (int64) 24 | // - a continuous random variable (float64) 25 | // - a protocol message (for more complex scenarios) 26 | // 27 | // TODO: locking 28 | // 29 | // This interface allows the monitoring service (not shown) to read all the process's 30 | // registered metrics (registration also not shown) during a monitoring request. 31 | // 32 | // (A better API would expose all the concrete types to the user using custom 33 | // signatures such as Get() for a scalar, or Get(K1, K2) for a 2-field map, and use the 34 | // generic abstraction over types only in the internal communication with the registration 35 | // and monitoring code.) 
36 | type Metric[K comparable, V any] interface { 37 | Get(K) V 38 | Set(K, V) 39 | enumerate() []Pair[K, V] // called by monitoring system to enumerate k/v pairs 40 | } 41 | 42 | // Specialization for when V=int64, revealing atomic Increment method. 43 | // This interface is intended for users, not the monitoring system. 44 | type CounterMetric[K comparable] interface { 45 | // Metric[K, int64] // FIXME for some reason this is a parse error. 46 | // Write out base interface in full: 47 | enumerate() []Pair[K, int64] 48 | Get(K) int64 49 | Set(K, int64) 50 | 51 | Increment(K) 52 | } 53 | 54 | type Pair[X, Y any] struct {X X; Y Y} 55 | 56 | // A Field is one component of a key. (Prometheus calls them "metric labels"). 57 | type Field interface { 58 | type int64, bool, string; 59 | } 60 | 61 | type K0 struct{} 62 | type K1[T1 Field] struct{F1 T1} 63 | type K2[T1, T2 Field] struct{F1 T1; F2 T2} 64 | type K3[T1, T2, T3 Field] struct{F1 T1; F2 T2; F3 T3} 65 | type K4[T1, T2, T3, T4 Field] struct{F1 T1; F2 T2; F3 T3; F4 T4} // enough already? 66 | 67 | // -- mapMetric -- 68 | 69 | // A metric defined in terms of a map. 70 | type mapMetric[K comparable, V any] struct { 71 | m map[K]V 72 | } 73 | 74 | func (m *mapMetric[K, V]) enumerate() []Pair[K, V] { 75 | // Assertions of interface satisfaction must now live 76 | // inside the scope of a type parameter, not at top level. 77 | var _ Metric[K, V] = m 78 | 79 | res := make([]Pair[K, V], 0, len(m.m)) 80 | for k, v := range m.m { 81 | res = append(res, Pair[K, V]{k, v}) 82 | } 83 | return res 84 | } 85 | 86 | func (m *mapMetric[K, V]) Get(k K) V { 87 | if m.m == nil { 88 | m.m = make(map[K]V) 89 | } 90 | return m.m[k] 91 | } 92 | 93 | func (m *mapMetric[K, V]) Set(k K, v V) { m.m[k] = v } 94 | 95 | // counterMapMetric is a specialization for V=int64, and has an atomic Increment method. 
96 | type counterMapMetric[K comparable] struct { 97 | mapMetric[K, int64] // FIXME embedding doesn't work 98 | } 99 | 100 | func (m *counterMapMetric[K]) Increment(k K) { m.m[k]++ } 101 | 102 | // --scalarMetric-- 103 | 104 | // A metric with a single variable, and no key. 105 | type scalarMetric[V any] struct { 106 | v V 107 | } 108 | 109 | func (m *scalarMetric[V]) enumerate() []Pair[K0, V] { 110 | // Assertions of interface satisfaction must now live 111 | // inside the scope of a type parameter, not at top level. 112 | var _ Metric[K0, V] = m 113 | 114 | return []Pair[K0, V]{{Y: m.v}} 115 | } 116 | 117 | func (m *scalarMetric[V]) Get(_ K0) V { return m.v } 118 | 119 | func (m *scalarMetric[V]) Set(_ K0, v V) { m.v = v } 120 | 121 | // counterScalarMetric is a specialization for V=int64, and has an atomic Increment method. 122 | type counterScalarMetric struct { 123 | scalarMetric[int64] // FIXME embedding doesn't work 124 | } 125 | 126 | func (m *counterScalarMetric) Increment(_ K0) { m.v++ } 127 | 128 | // TODO: implement a callback-defined metric. 129 | 130 | // --test-- 131 | 132 | type hostport K2[string, int64] 133 | 134 | func main() { 135 | fmt.Println() 136 | // An example metric that counts connections. 137 | // type hostport K2[string, int64] // go2go doesn't allow locally defined types 138 | var connections counterMapMetric[hostport] 139 | 140 | // FIXME all of the .mapMetric operations below should not be required, 141 | // but appear to work around a compiler bug. 142 | // Nonetheless, we still get confusing compiler errors of the form: 143 | // "m.scalarMetric undefined (type *counterScalarMetric has no field or method scalarMetric)" 144 | // reported in the wrong place. 
145 | // See https://github.com/golang/go/issues/44688 146 | 147 | k := hostport{"localhost", 80} 148 | connections.mapMetric.Set(k, connections.mapMetric.Get(k) + 1) // compile errors: see 'oddities' package 149 | connections.Increment(k) 150 | 151 | k = hostport{"github.com", 443} 152 | connections.mapMetric.Set(k, connections.mapMetric.Get(k) + 1) 153 | connections.Increment(k) 154 | 155 | for _, e := range connections.mapMetric.enumerate() { 156 | fmt.Println(e) 157 | } 158 | 159 | var counter counterScalarMetric 160 | counter.Increment(K0{}) 161 | counter.Increment(K0{}) 162 | counter.Increment(K0{}) 163 | for _, e := range counter.scalarMetric.enumerate() { 164 | fmt.Println(e) 165 | } 166 | } 167 | -------------------------------------------------------------------------------- /number/number.go2: -------------------------------------------------------------------------------- 1 | // Generic functions related to numbers. 2 | package main 3 | 4 | // TODO: 5 | // - routines from Hacker's Delight (e.g. clp2, nlz, popcount) 6 | // Specialized routines already exist as math/bits.OnesCount[N], 7 | // but I don't see an easy way to write a generic one that 8 | // dispatches to the most efficient implementation for the width. 9 | // But see `algebra` package for a generic implementation of Newton's 10 | // method that works for float64, complex128, and *big.Rat. 11 | // - Galois fields of (2^N). 12 | 13 | import ( 14 | "math/big" 15 | "fmt" 16 | "unsafe" 17 | ) 18 | 19 | type integer interface { 20 | type int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, uintptr; 21 | } 22 | 23 | type float interface { 24 | type float32, float64; 25 | } 26 | 27 | // NB: real, not complex. 28 | type number interface { 29 | // Q. is there a shorter syntax 'for integer + float'? 30 | type int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, uintptr, 31 | float32, float64; 32 | } 33 | 34 | // Bits returns the number of bits in an integer or float. 
35 | func Bits[T number]() int { 36 | // WRONG: Sizeof(T) is disallowed for some reason. 37 | // (Presumably because it would be non-constant?) 38 | // return unsafe.Sizeof(*new(*T)) * 8 39 | 40 | // WRONG: doesn't work for named (aka defined) types. 41 | // switch (interface{}(*new(T))).(type) { 42 | // case int8, uint8: 43 | // return 8 44 | // case int16, uint16: 45 | // return 16 46 | // ... 47 | 48 | // The compiler doesn't recognize this as a constant, 49 | // and generates poor code. 50 | var array [2]T 51 | return 8 * int(uintptr(unsafe.Pointer(&array[1])) - uintptr(unsafe.Pointer(&array[0]))) 52 | } 53 | 54 | // Signed reports whether T is a signed integer type. 55 | func Signed[T integer]() bool { 56 | // This causes a type checker crash (https://github.com/golang/go/issues/44762) 57 | // return T(0) - 1 < 0 58 | 59 | var zero T 60 | return zero - 1 < 0 61 | } 62 | 63 | // Max returns the maximum value of integer type T. 64 | func Max[T integer]() T { 65 | bits := Bits[T]() 66 | if Signed[T]() { 67 | return 1<<(bits-1) - 1 68 | } else { 69 | return (1< 0. 59 | func (pq *PriorityQueue[T]) Min() (T, int) { 60 | it := pq.items[0] 61 | return it.value, it.priority 62 | } 63 | 64 | // TakeMin removes and returns the least element in the queue, 65 | // along with its priority. 66 | // Precondition: Len() > 0. 67 | func (pq *PriorityQueue[T]) TakeMin() (T, int) { 68 | it := heap.Pop(&pq.items).(*item[T]) 69 | return it.value, it.priority 70 | } 71 | 72 | // -- impl -- 73 | 74 | // repr implements heap.Interface. 75 | // Q. Is the pointer indirection necessary? 
(avoids garbage alloc in pop) 76 | type repr[T any] []*item[T] 77 | 78 | type item[T any] struct { 79 | value T 80 | priority int 81 | } 82 | 83 | func (pq repr[T]) Len() int { return len(pq) } 84 | 85 | func (pq repr[T]) Less(i, j int) bool { 86 | return pq[i].priority < pq[j].priority 87 | } 88 | 89 | func (pq repr[T]) Swap(i, j int) { 90 | pq[i], pq[j] = pq[j], pq[i] 91 | } 92 | 93 | func (pq *repr[T]) Push(x interface{}) { 94 | *pq = append(*pq, x.(*item[T])) 95 | } 96 | 97 | func (pq *repr[T]) Pop() interface{} { 98 | old := *pq 99 | last := len(old)-1 100 | it := old[last] 101 | old[last] = nil 102 | *pq = old[:last] 103 | return it 104 | } 105 | 106 | // -- test -- 107 | 108 | func main() { 109 | // FromMap: initial elements and priorities supplied by map. 110 | pq := FromMap(map[string]int{"banana": 3, "apple": 2, "pear": 4}) 111 | pq.Add("orange", 1) 112 | for pq.Len() > 0 { 113 | v, pri := pq.TakeMin() 114 | fmt.Println(pri, v) 115 | } 116 | fmt.Println() 117 | // Output: 118 | // 1 orange 119 | // 2 apple 120 | // 3 banana 121 | // 4 pear 122 | 123 | // FromSlice: initial elements supplied from slice, priorities from a function. 124 | pq = FromSlice([]string{"apple", "banana", "orange"}, strlen) 125 | pq.Add("pear", 4) 126 | for pq.Len() > 0 { 127 | v, pri := pq.TakeMin() 128 | fmt.Println(pri, v) 129 | } 130 | fmt.Println() 131 | // Output: 132 | // 4 pear 133 | // 5 apple 134 | // 6 orange 135 | // 6 banana 136 | } 137 | 138 | func strlen(x string) int { return len(x) } 139 | -------------------------------------------------------------------------------- /slices/slices.go2: -------------------------------------------------------------------------------- 1 | // Generic algorithms over slices. 2 | package main 3 | 4 | import ( 5 | "fmt" 6 | "math/big" 7 | "sort" 8 | "unsafe" 9 | ) 10 | 11 | // A user-defined slice type, to demonstrate that the runtime slice type can now 12 | // be implemented within the language. 
13 | // We tried *[maxint]T for the data field, but it exceeds the size of the address space. 14 | // (We can't use maxint/unsafe.Sizeof(T) in case T is a zero-length type.) 15 | // data could equally be an unsafe.Pointer: the same number of casts are needed, 16 | // but in different places; in particular, offset() calls don't need offset[T]). 17 | // We prefer *T for clarity. 18 | type Slice[T any] struct { 19 | data *T // pointer to an element of an array of type [n]T 20 | len, cap int 21 | } 22 | 23 | func (s Slice[T]) Len() int { return s.len } 24 | func (s Slice[T]) Cap() int { return s.cap } 25 | func (s Slice[T]) IsNil() bool { return s.data == nil } 26 | func (s Slice[T]) Addr(i int) *T { 27 | if 0 <= i && i < s.len { 28 | return offset(s.data, i) 29 | } 30 | panic("index out of range") 31 | } 32 | func (s Slice[T]) Elem(i int) T { return *s.Addr(i) } 33 | func (s Slice[T]) Append(elems...T) Slice[T] { 34 | newlen := s.len + len(elems) 35 | newcap := s.cap 36 | if newlen > s.cap { 37 | // Expand capacity by doubling to ensure geometric growth. 38 | newcap = max(newlen, s.len * 2) 39 | new := make([]T, newcap) 40 | copy(new, s.ToSlice()) 41 | s.data = fromRuntime(new).data 42 | } 43 | for i, elem := range elems { 44 | *offset(s.data, s.len + i) = elem 45 | } 46 | s.len = newlen 47 | s.cap = newcap 48 | return s 49 | } 50 | func (s Slice[T]) AppendSlice(t Slice[T]) Slice[T] { return s.Append(t.ToSlice()...) 
} 51 | func (s Slice[T]) ToSlice() []T { return toRuntime(s) } 52 | func (s Slice[T]) String() string { return fmt.Sprint(s.ToSlice()) } 53 | 54 | // s[i:j] 55 | func (s Slice[T]) Slice(i, j int) Slice[T] { 56 | if 0 <= i && i <= j && j <= s.len { 57 | s.len = j - i 58 | s.cap -= i 59 | s.data = offset(s.data, i) 60 | return s; 61 | } 62 | panic("indices out of range") 63 | } 64 | 65 | // s[::cap] 66 | func (s Slice[T]) WithCap(cap int) Slice[T] { 67 | if cap < s.len { 68 | panic("invalid cap index") 69 | } 70 | return Slice[T]{s.data, s.len, cap} 71 | } 72 | 73 | // -- unsafe hacks --- 74 | 75 | type uP = unsafe.Pointer 76 | 77 | func offset[T any](data *T, index int) *T { 78 | return (*T)(uP(uintptr(uP(data)) + sizeof[T]() * uintptr(index))) 79 | } 80 | 81 | // sizeof returns the address difference between adjacent []T array elements. 82 | // (Why does unsafe.Sizeof(T) not work, when we can implement it in the language?) 83 | func sizeof[T any]() uintptr { 84 | // The compiler generates very poor code for this function. 85 | var array [2]T 86 | return uintptr(uP(&array[1])) - uintptr(uP(&array[0])) 87 | } 88 | 89 | func fromRuntime[T any](slice []T) Slice[T] { 90 | return *(*Slice[T])(uP(&slice)) 91 | } 92 | 93 | func toRuntime[T any](slice Slice[T]) []T { 94 | return *(*[]T)(uP(&slice)) 95 | } 96 | 97 | // -- utils -- 98 | 99 | func max(x, y int) int { 100 | if x > y { 101 | return x 102 | } else { 103 | return y 104 | } 105 | } 106 | 107 | // ------------------------------------------------------------------------ 108 | 109 | // Generic algorithms over runtime slices. 110 | 111 | type integer interface { 112 | type int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, uintptr; 113 | } 114 | 115 | // SortInts sorts the slice of integers using its natural order. 116 | // (No need for strings, as sort.Strings exists.) 117 | // (No need for floats, as they are not a strict weak order.) 
118 | func SortInts[T integer](x []T) { 119 | sort.Slice(x, func(i, j int) bool { return x[i] < x[j] }) 120 | } 121 | 122 | // Sort sorts a slice using the given strict weak order. 123 | func Sort[T any](slice []T, less func(x, y T) bool) { 124 | sort.Slice(slice, func(i, j int) bool { return less(slice[i], slice[j]) }) 125 | } 126 | 127 | // Uniq combines adjacent elements that are equal, in place. 128 | // Don't forget to use the result! 129 | func Uniq[T comparable](in []T) []T { 130 | out := in[:0] 131 | for _, x := range in { 132 | if len(out) > 0 && x == out[len(out)-1] { 133 | continue // duplicate 134 | } 135 | out = append(out, x) 136 | } 137 | return out 138 | } 139 | 140 | // Filter discards elements for which !keep(x), in place. 141 | // Don't forget to use the result! 142 | func Filter[T any](in []T, keep func(x T) bool) []T { 143 | out := in[:0] 144 | for _, x := range in { 145 | if keep(x) { 146 | out = append(out, x) 147 | } 148 | } 149 | return out 150 | } 151 | 152 | type Pair[X, Y any] struct {X X; Y Y} 153 | 154 | // Zip produces the x-major cross product of two slices. 
155 | func Zip[X, Y any](xx []X, yy []Y) (res []Pair[X, Y]) { 156 | for _, x := range xx { 157 | for _, y := range yy { 158 | res = append(res, Pair[X, Y]{x, y}) 159 | } 160 | } 161 | return 162 | } 163 | 164 | // --test-- 165 | 166 | func main() { 167 | var s Slice[string] 168 | s = s.Append("hello") 169 | s = s.Append("world") 170 | s = s.AppendSlice(s) 171 | fmt.Println(s) // [hello world hello world] 172 | s = s.Slice(0, 3) 173 | fmt.Println(s) // [hello world hello] 174 | *s.Addr(2) = "goodbye" 175 | fmt.Println(s) // [hello world goodbye] 176 | 177 | // discard odd-indexed elements, in place (destroys s) 178 | out := s.Slice(0, 0) // zero-length prefix 179 | for i := 0; i < s.Len(); i++ { 180 | if i & 1 == 0 { 181 | out = out.Append(s.Elem(i)) 182 | } 183 | } 184 | fmt.Println(out) // [hello goodbye] 185 | 186 | // ------------ 187 | 188 | // Uniq, Filter 189 | a := []string{"one", "two", "three", "two"} 190 | sort.Strings(a) 191 | a = Uniq(a) 192 | a = Filter(a, func(x string) bool { return x[0] == 't' }) 193 | fmt.Println(a) // [three two] 194 | 195 | // SortInts 196 | b := []uint16{9, 3, 7, 0, 7} 197 | SortInts(b) 198 | fmt.Println(b) // [0 3 7 7 9] 199 | 200 | // Sort slice of pointers using custom order. 201 | c := []*big.Int{bigInt(9), bigInt(3), bigInt(7), bigInt(0), bigInt(7)} 202 | Sort(c, bigIntLess) 203 | fmt.Println(c) // [0 3 7 7 9] 204 | 205 | // Zip 206 | d := Zip(a, b) 207 | fmt.Println(d) // [{three 0} {three 3} {three 7} {three 7} {three 9} {two 0} {two 3} {two 7} {two 7} {two 9}] 208 | } 209 | 210 | // bigInt returns a bigint (a value whose standard hash/eq/< relations are not the logical ones). 
211 | func bigInt(x int64) *big.Int { return new(big.Int).SetInt64(x) } 212 | 213 | func bigIntLess(x, y *big.Int) bool { return x.Cmp(y) < 0; } 214 | 215 | 216 | -------------------------------------------------------------------------------- /stream/stream.go2: -------------------------------------------------------------------------------- 1 | // Generic utilities for streams processing. 2 | package main 3 | 4 | import ( 5 | "fmt" 6 | "sync" 7 | "time" 8 | ) 9 | 10 | // A stream is a function that retrieves and returns the next element in a sequence, 11 | // and whether that was successful. 12 | type Stream[T any] func() (T, bool) 13 | 14 | // MakeStream is a stream source that yields the specified elements. 15 | func MakeStream[T any](x ...T) Stream[T] { 16 | s := sliceStream[T](x) 17 | return s.next 18 | } 19 | 20 | // FibonacciStream is another stream source, this one infinite, that yields the Fibonacci sequence. 21 | func FibonacciStream() Stream[int] { 22 | x, y := 0, 1 23 | return func() (int, bool) { 24 | x, y = y, x+y 25 | return x, true 26 | } 27 | } 28 | 29 | // PrintStream is a stream sink that prints each element. 30 | func PrintStream[T any](input Stream[T]) { 31 | for { 32 | x, ok := input() 33 | if !ok { 34 | break 35 | } 36 | fmt.Println(x) 37 | } 38 | } 39 | 40 | // SumStream is another sink, that adds numbers (or concatenates strings). 41 | func SumStream[T addable](input Stream[T]) (sum T) { 42 | for { 43 | x, ok := input() 44 | if !ok { 45 | return 46 | } 47 | sum += x 48 | } 49 | } 50 | 51 | type addable interface { 52 | type int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, uintptr, float32, float64, complex64, complex128, string 53 | } 54 | 55 | // Limit is a stream transformer that truncates the input at limit elements. 
56 | func Limit[T any](input Stream[T], limit int) Stream[T] { 57 | i := 0 58 | return func() (_ T, ok bool) { 59 | if i < limit { 60 | i++ 61 | return input() 62 | } 63 | return 64 | } 65 | } 66 | 67 | // Transform transforms a stream of X to a stream of Y. 68 | func Transform[X, Y any](input Stream[X], f func(X) Y) Stream[Y] { 69 | return func() (y Y, ok bool) { 70 | x, ok := input() 71 | if ok { 72 | y = f(x) 73 | } 74 | return 75 | } 76 | } 77 | 78 | // ParallelTransform transforms a stream of X to an (unordered) stream of Y, in parallel, 79 | // as fast as it can read items from the input stream. 80 | // The function f is called concurrently. 81 | // The client must drain the result stream (see chanStream). 82 | // 83 | // TODO: 84 | // - cancellation 85 | // - concurrency limiting 86 | // - avoid goroutine leak from failure to drain result. 87 | // - buffer results so that they are reported in the original order? 88 | // The design space is huge (as is the potential for explosion of the generated code, 89 | // and for impenetrably abstract control flow). 90 | func ParallelTransform[X, Y any](input Stream[X], f func(X) Y) Stream[Y] { 91 | ch := make(chan Y, 1) 92 | var wg sync.WaitGroup 93 | for { 94 | x, ok := input() 95 | if !ok { 96 | break 97 | } 98 | wg.Add(1) 99 | go func() { 100 | ch <- f(x) 101 | wg.Done() 102 | }() 103 | } 104 | go func() { 105 | wg.Wait() 106 | close(ch) 107 | }() 108 | return ChanStream(ch) 109 | } 110 | 111 | // Closure for a stream over slice elements. 112 | type sliceStream[T any] []T 113 | 114 | func (it *sliceStream[T]) next() (elem T, ok bool) { 115 | if len(*it) > 0 { 116 | elem, *it, ok = (*it)[0], (*it)[1:], true 117 | } 118 | return 119 | } 120 | 121 | // Closure for an iterator over channels. 122 | // This is a bad design, because failure to drain the iterator 123 | // leads to a goroutine leak. But the types work. 
124 | type chanStream[T any] <-chan T 125 | 126 | func (ch chanStream[T]) next() (x T, ok bool) { x, ok = <-ch; return } 127 | 128 | func ChanStream[T any](ch <-chan T) Stream[T] { 129 | return chanStream[T](ch).next 130 | } 131 | 132 | // --test-- 133 | 134 | func main() { 135 | square := func(x int) int { return x * x } 136 | PrintStream(ParallelTransform(MakeStream(1, 2, 3), square)) // prints 1, 4, 9 in some order 137 | fmt.Println(SumStream(ParallelTransform(MakeStream(1, 2, 3), square))) // 14 138 | 139 | // Prints 10ms, 20ms, 30ms in that order. 140 | const ms = time.Millisecond 141 | PrintStream( 142 | ParallelTransform( 143 | MakeStream(30*ms, 10*ms, 20*ms), 144 | func(x time.Duration) time.Duration { time.Sleep(x); return x })) 145 | 146 | // Prints "1 1 2 3 5 " (with a trailing space) 147 | fmt.Println( 148 | SumStream( 149 | Transform( 150 | Limit(FibonacciStream(), 5), 151 | func(x int) string { return fmt.Sprintf("%d ", x) }))) 152 | 153 | } 154 | -------------------------------------------------------------------------------- /striped/striped.go2: -------------------------------------------------------------------------------- 1 | // A concurrency-safe map using striped locks. 2 | // 3 | // There are no Len() or Iterate() methods as these are poorly 4 | // defined or inefficient wrt concurrency, but they could easily be added. 5 | package main 6 | 7 | import ( 8 | "fmt" 9 | "hash/fnv" 10 | "io" 11 | "runtime" 12 | "strings" 13 | "sync" 14 | "math/bits" 15 | "github.com/adonovan/generics/hacks" 16 | ) 17 | 18 | // A Map is a concurrency-safe hash table using striped locks (one per bucket). 19 | // 20 | // By default, it uses the same key hash function and equivalence relation 21 | // (KeyComp) as a standard map; consequently the key type must be comparable, 22 | // and panic results if hashing or equivalence fails dynamically. 23 | // Clients may override this behavior by specifying an alternative KeyComp. 
24 | type Map[K, V any] struct { 25 | buckets []bucket[K,V] // nonempty; len is power of 2 26 | kcomp KeyComp[K] 27 | } 28 | 29 | type KeyComp[K any] interface { 30 | Hash(K) uintptr // TODO: add seed param 31 | Equal(K, K) bool 32 | } 33 | 34 | type bucket[K, V any] struct { 35 | mu sync.Mutex 36 | entries []entry[K, V] 37 | } 38 | 39 | type entry[K, V any] struct { 40 | hash uintptr 41 | key K 42 | value V 43 | } 44 | 45 | // New returns a new, empty map using the standard key comparator. 46 | func New[K comparable, V any]() *Map[K, V] { 47 | var std stdKeyComp[K] 48 | return NewWithKeyComp[K, V](std) 49 | } 50 | 51 | // NewWithKeyComp returns a new, empty map using the specified key comparator. 52 | func NewWithKeyComp[K, V any](kcomp KeyComp[K]) *Map[K, V] { 53 | // TODO: better estimate 54 | nbuckets := int(clp2(uint64(runtime.GOMAXPROCS(-1)))) 55 | return &Map[K,V]{ 56 | kcomp: kcomp, 57 | buckets: make([]bucket[K,V], nbuckets), 58 | } 59 | } 60 | 61 | // FromMap returns a new Map containing the given elements, 62 | // using the standard key comparator. 63 | func FromMap[K comparable, V any](elems map[K]V) *Map[K, V] { 64 | m := New[K,V]() // TODO: use initial size 65 | for k, v := range elems { 66 | m.Insert(k, v) 67 | } 68 | return m 69 | } 70 | 71 | // Insert adds an entry to the map that associates the given key and value. 72 | // It returns the previously associated value, or the zero value if there was none. 73 | // Concurrency-safe. 74 | func (m *Map[K, V]) Insert(k K, v V) V { 75 | hash, b := m.getBucket(k) 76 | return b.insert(hash, k, v, m.kcomp) 77 | } 78 | 79 | // Remove removes the map entry for a given key, and returns its associated value, 80 | // or the zero value if there was none. Concurrency-safe. 81 | func (m *Map[K, V]) Remove(k K) V { 82 | hash, b := m.getBucket(k) 83 | return b.remove(hash, k, m.kcomp) 84 | } 85 | 86 | // Clear removes all entries from the map in an unspecified order. 87 | // Concurrency-safe. 
88 | func (m *Map[K, V]) Clear() { 89 | for i := range m.buckets { // index, so that the bucket (and its mutex) is not copied 90 | m.buckets[i].clear() 91 | } 92 | } 93 | 94 | // Get returns the value associated with key k, or the zero value if not found. 95 | // Concurrency-safe. 96 | // TODO: provide a "v, ok = m[k]" flavor too. 97 | func (m *Map[K, V]) Get(k K) V { 98 | hash, b := m.getBucket(k) 99 | return b.get(hash, k, m.kcomp) 100 | } 101 | 102 | func (m *Map[K, V]) String() string { 103 | var out strings.Builder 104 | out.WriteString("{") 105 | for i := range m.buckets { // index, so that the bucket (and its mutex) is not copied 106 | m.buckets[i].foreach(func(k K, v V) { 107 | if out.Len() > 1 { 108 | out.WriteString(", ") 109 | } 110 | fmt.Fprintf(&out, "%v: %v", k, v) 111 | }) 112 | } 113 | out.WriteString("}") 114 | return out.String() 115 | } 116 | 117 | // -- impl -- 118 | 119 | func (b *bucket[K, V]) insert(hash uintptr, k K, v V, kcomp KeyComp[K]) (prev V) { 120 | b.mu.Lock() 121 | defer b.mu.Unlock() 122 | for i := range b.entries { 123 | e := &b.entries[i] 124 | if e.hash == hash && kcomp.Equal(e.key, k) { 125 | prev = e.value 126 | e.key = k 127 | e.value = v 128 | return 129 | } 130 | } 131 | b.entries = append(b.entries, entry[K, V]{hash, k, v}) 132 | return 133 | } 134 | 135 | func (b *bucket[K, V]) remove(hash uintptr, k K, kcomp KeyComp[K]) (prev V) { 136 | b.mu.Lock() 137 | defer b.mu.Unlock() 138 | for i := range b.entries { 139 | e := &b.entries[i] 140 | if e.hash == hash && kcomp.Equal(e.key, k) { 141 | prev = e.value 142 | last := len(b.entries)-1 143 | b.entries[i] = b.entries[last] 144 | b.entries = b.entries[:last] 145 | break 146 | } 147 | } 148 | return 149 | } 150 | 151 | func (b *bucket[K, V]) get(hash uintptr, k K, kcomp KeyComp[K]) (v V) { 152 | b.mu.Lock() 153 | defer b.mu.Unlock() 154 | for i := range b.entries { 155 | e := &b.entries[i] 156 | if e.hash == hash && kcomp.Equal(e.key, k) { 157 | v = e.value 158 | break 159 | } 160 | } 161 | return 162 | } 163 | 164 | func (b *bucket[K, V]) clear() { 165 | b.mu.Lock() 166 | for i := range b.entries { 167 | 
b.entries[i] = entry[K, V]{} // aid GC 168 | } 169 | b.entries = b.entries[:0] 170 | b.mu.Unlock() 171 | } 172 | 173 | func (b *bucket[K, V]) foreach(f func(k K, v V)) { 174 | b.mu.Lock() 175 | for _, e := range b.entries { 176 | f(e.key, e.value) 177 | } 178 | b.mu.Unlock() 179 | } 180 | 181 | func (m *Map[K, V]) getBucket(k K) (hash uintptr, b *bucket[K, V]) { 182 | // TODO: Seed properly. Flood protection. Don't discard top bits. 183 | hash = m.kcomp.Hash(k) 184 | return hash, &m.buckets[hash % uintptr(len(m.buckets))] 185 | } 186 | 187 | // A key comparator that uses the same relation as the standard map. May panic. 188 | type stdKeyComp[K comparable] struct{} 189 | 190 | func (stdKeyComp[K]) Hash(k K) uintptr { 191 | return hacks.RuntimeHash(k, /*seed=*/0) // may panic 192 | } 193 | 194 | func (stdKeyComp[K]) Equal(x, y K) bool { 195 | return x == y // may panic 196 | } 197 | 198 | // clp2 returns x rounded up to a power of 2 ("ceiling power 2"). See HD 3-2. 199 | func clp2(x uint64) uint64 { 200 | return (uint64(1)<<63) >> uint64(bits.LeadingZeros64(x-1)-1) 201 | } 202 | 203 | // -- test -- 204 | 205 | func main() { 206 | // int keys 207 | m1 := New[int, int]() 208 | fmt.Println(m1.Insert(1, 2)) // =0 209 | fmt.Println(m1.Insert(2, 4)) // =0 210 | fmt.Println(m1.Get(1)) // =2 211 | fmt.Println(m1.Remove(1)) // =2 212 | fmt.Println(m1.Remove(2)) // =4 213 | fmt.Println(m1.Remove(2)) // =0 214 | 215 | // string keys (wider than 1 word) 216 | m2 := New[string, int]() 217 | fmt.Println(m2.Insert("one", 1)) // =0 218 | fmt.Println(m2.Insert("one", 2)) // =1 219 | fmt.Println(m2.Get("one")) // =2 220 | fmt.Println(m2.Remove("one")) // =2 221 | fmt.Println(m2.Remove("one")) // =0 222 | 223 | // string keys, case insensitive 224 | m2a := NewWithKeyComp[string, int](stringCompNoCase{}) 225 | fmt.Println(m2a.Insert("one", 1)) // =0 226 | fmt.Println(m2a.Insert("One", 2)) // =1 227 | fmt.Println(m2a.Get("ONE")) // =2 228 | fmt.Println(m2a.Remove("One")) // =2 229 | 
fmt.Println(m2a.Remove("one")) // =0 230 | 231 | // pointer keys 232 | // type S struct{int} // a bug: this S must be defined outside 'main' 233 | var a, b S 234 | m3 := New[*S, string]() 235 | fmt.Println(m3.Insert(&a, "a")) // ="" 236 | fmt.Println(m3.Insert(&b, "b")) // ="" 237 | fmt.Println(m3.Get(&a)) // ="a" 238 | fmt.Println(m3.Get(&b)) // ="b" 239 | fmt.Println(m3.Remove(&a)) // ="a" 240 | fmt.Println(m3.Remove(&a)) // ="" 241 | 242 | // interface keys 243 | m4 := New[interface{}, string]() 244 | fmt.Println(m4.Insert(1, "1")) // ="" 245 | fmt.Println(m4.Insert("two", "2")) // ="" 246 | fmt.Println(m4.Insert(S{3}, "3")) // ="" 247 | fmt.Println(m4.Get(1)) // ="1" 248 | fmt.Println(m4.Get("two")) // ="2" 249 | fmt.Println(m4.Remove(1)) // ="1" 250 | fmt.Println(m4.Remove("two")) // ="2" 251 | fmt.Println(m4.Remove(S{3})) // ="3" 252 | 253 | // slice keys (but don't mutate them) 254 | m5 := NewWithKeyComp[[]string, int](sliceComp[string]{}) 255 | m5.Insert(strings.Fields("hello, world"), 1) 256 | m5.Insert(strings.Fields("a b c"), 2) 257 | m5.Insert([]string{}, 3) 258 | m5.Insert(nil, 4) 259 | fmt.Println(m5) // {[]: 4, [a b c]: 2, [hello, world]: 1} 260 | 261 | // dynamically unhashable 262 | fmt.Println(m4.Remove(main)) // panic: runtime error: hash of unhashable type func() 263 | } 264 | 265 | type S struct{int} 266 | 267 | // A case insensitive comparator for string keys. 268 | type stringCompNoCase struct{} 269 | 270 | func (stringCompNoCase) Hash(x string) uintptr { 271 | // Would be nice if the runtime's string hash were easily and efficiently 272 | // accessible. (Maphash isn't there yet; see 273 | // https://github.com/golang/go/issues/42710#issuecomment-763950234). 
274 | h := fnv.New64a() 275 | io.WriteString(h, strings.ToLower(x)) 276 | return uintptr(h.Sum64()) 277 | } 278 | 279 | func (stringCompNoCase) Equal(x, y string) bool { 280 | return strings.ToLower(x)==strings.ToLower(y) 281 | } 282 | 283 | // A comparator for slices of comparable element type T. 284 | // May panic on elements that are not dynamically hashable/comparable. 285 | type sliceComp[T comparable] struct{} 286 | 287 | func (sliceComp[T]) Hash(slice []T) uintptr { 288 | var hash uintptr 289 | for _, elem := range slice { 290 | hash = hash * 7 + stdKeyComp[T]{}.Hash(elem) 291 | } 292 | return hash 293 | } 294 | 295 | func (sliceComp[T]) Equal(x, y []T) bool { 296 | if len(x) != len(y) { 297 | return false 298 | } 299 | for i := range x { 300 | if x[i] != y[i] { 301 | return false 302 | } 303 | } 304 | return true 305 | } 306 | --------------------------------------------------------------------------------
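
The TODO on `Map.Get` in striped.go2 asks for a `v, ok = m[k]` flavor. A minimal sketch of what that might look like at the bucket level follows, written in the released Go 1.18+ generics syntax rather than the `.go2` prototype syntax. The method name `lookup` and the plain `equal` function parameter (standing in for `KeyComp.Equal`) are assumptions made for illustration and are not part of the original code.

```go
package main

import (
	"fmt"
	"sync"
)

// entry and bucket mirror the shapes used by the striped map.
type entry[K, V any] struct {
	hash  uintptr
	key   K
	value V
}

type bucket[K, V any] struct {
	mu      sync.Mutex
	entries []entry[K, V]
}

// lookup is a hypothetical comma-ok variant of bucket.get: it reports
// whether the key was present, so that a stored zero value can be
// distinguished from a missing key.
func (b *bucket[K, V]) lookup(hash uintptr, k K, equal func(x, y K) bool) (v V, ok bool) {
	b.mu.Lock()
	defer b.mu.Unlock()
	for i := range b.entries {
		e := &b.entries[i]
		if e.hash == hash && equal(e.key, k) {
			return e.value, true
		}
	}
	return v, false
}

func main() {
	eq := func(x, y string) bool { return x == y }
	b := &bucket[string, int]{
		entries: []entry[string, int]{{hash: 1, key: "zero", value: 0}},
	}
	fmt.Println(b.lookup(1, "zero", eq)) // 0 true
	fmt.Println(b.lookup(2, "one", eq))  // 0 false
}
```

A `Map`-level method would presumably call `getBucket` and forward to this helper, much as `Get` forwards to `get`.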