├── .gitignore ├── nlp ├── doc.go ├── wordnet │ ├── string.go │ ├── types.go │ └── parser_test.go ├── tfidf.go ├── tokenize.go └── lda-tool │ └── main.go ├── rhash ├── rabin32_test.go ├── rabin64_test.go ├── doc.go ├── buz_test.go ├── fuzz_test.go ├── bench_test.go ├── rabin32.go ├── rabin64.go └── common_test.go ├── go.mod ├── bnry ├── doc.go ├── bnry_test.go └── write.go ├── README.md ├── .github └── workflows │ └── go.yml ├── hashx ├── hashx_test.go └── hashx.go ├── go.sum ├── aio ├── wrap.go └── open.go ├── ppln ├── common.go ├── doc.go ├── common_test.go ├── nserial.go ├── nserial_test.go ├── serial_test.go └── serial.go ├── clustering ├── kmeans_test.go ├── agglo_test.go ├── upgma_test.go ├── rand.go └── upgma.go ├── morris ├── error │ └── error.go ├── morris_test.go └── morris.go ├── csvx ├── csvx_test.go ├── examples_test.go └── csvx.go ├── repeat ├── repeat_test.go └── repeat.go ├── reservoir ├── reservoir.go └── reservoir_test.go ├── LICENSE ├── gnum ├── vecs_test.go ├── vecs.go ├── gnum.go └── gnum_test.go ├── snm ├── queue_test.go ├── queue.go └── snm_test.go ├── sets ├── sorted_test.go ├── sets_test.go ├── sorted.go └── sets.go ├── xmlnode ├── read_test.go └── read.go ├── graphs ├── bfsdfs.go ├── graph_test.go ├── bfsdfs_test.go └── graph.go ├── ezpprof └── ezpprof.go ├── jio └── jio.go ├── iterx ├── slice.go ├── slice_test.go ├── lines.go ├── unreader.go ├── unreader2.go ├── unreader_test.go └── unreader2_test.go ├── bits ├── bits.go └── bits_test.go ├── hll ├── hll_test.go └── hll.go ├── heaps ├── heaps_test.go └── heaps.go ├── ptimer ├── ptimer_test.go └── ptimer.go ├── flagx ├── flagx_test.go └── flagx.go ├── bloom ├── bloom_test.go └── bloom.go ├── prefixtree └── tree.go └── minhash ├── minhash_test.go └── minhash.go /.gitignore: -------------------------------------------------------------------------------- 1 | .vscode 2 | testdata 3 | -------------------------------------------------------------------------------- /nlp/doc.go: 
// TestRabin32 runs the 32-bit rolling-hash test suite (test32) against
// hashes created by NewRabinFingerprint32.
func TestRabin32(t *testing.T) {
	// The closure adapts the constructor's return value to hash.Hash32,
	// which is the type the suite's factory argument returns.
	test32(t, func(n int) hash.Hash32 { return NewRabinFingerprint32(n) })
}
6 | package rhash 7 | -------------------------------------------------------------------------------- /bnry/doc.go: -------------------------------------------------------------------------------- 1 | // Package bnry provides simple functions for encoding and decoding values as 2 | // binary. 3 | // 4 | // # Supported data types 5 | // 6 | // The types that can be encoded and decoded are 7 | // int*, uint* (excluding int and uint), float*, bool, string and 8 | // slices of these types. 9 | // [Read] and [UnmarshalBinary] expect pointers to these types, 10 | // while [Write] and [MarshalBinary] expect non-pointers. 11 | package bnry 12 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | gostuff 2 | ======= 3 | 4 | [![Go Reference](https://pkg.go.dev/badge/github.com/fluhus/gostuff.svg)](https://pkg.go.dev/github.com/fluhus/gostuff) 5 | [![Go Report Card](https://goreportcard.com/badge/github.com/fluhus/gostuff)](https://goreportcard.com/report/github.com/fluhus/gostuff) 6 | 7 | I believe that simple ideas deserve simple code. 8 | 9 | This is a collection of packages I wrote for my personal work as a data 10 | scientist, and may benefit others as well. 11 | The emphasis is on simplicity and minimalism, to keep code readable and clean. 
// TestInt checks that IntHashx yields distinct hashes for distinct integers.
func TestInt(t *testing.T) {
	const n = 100000

	// The integers 0..n-1 should all hash to different values.
	hashes := sets.Set[uint64]{}
	hx := New()
	for i := range n {
		hashes.Add(IntHashx(hx, i))
	}
	if len(hashes) != n {
		t.Errorf("len(hashes)=%v, want %v", len(hashes), n)
	}

	// Repeat with the negated values mixed in.
	hashes = sets.Set[uint64]{}
	hx = New()
	for i := range n {
		hashes.Add(IntHashx(hx, i))
		hashes.Add(IntHashx(hx, -i))
	}
	// 0 and -0 are the same input, hence one less than 2n distinct hashes.
	want := n*2 - 1
	if len(hashes) != want {
		t.Errorf("len(hashes)=%v, want %v", len(hashes), want)
	}
}
// writerWrapper wraps a writer and its underlying closer,
// so that closing the wrapper closes both.
type writerWrapper struct {
	top, bottom io.WriteCloser
}

// Write writes p to the top writer.
func (w *writerWrapper) Write(p []byte) (int, error) {
	return w.top.Write(p)
}

// Close closes the top writer and then the bottom one.
//
// The bottom writer is closed even if closing the top one fails, so that the
// underlying resource is not leaked. If both fail, the top writer's error is
// returned.
func (w *writerWrapper) Close() error {
	topErr := w.top.Close()
	bottomErr := w.bottom.Close()
	if topErr != nil {
		return topErr
	}
	return bottomErr
}

// Reader is a bufio.Reader bundled with the io.ReadCloser that Close closes.
type Reader struct {
	bufio.Reader
	r io.ReadCloser
}

// Close closes the bundled io.ReadCloser.
func (r *Reader) Close() error {
	return r.r.Close()
}

// Writer is a bufio.Writer bundled with the io.WriteCloser that Close closes.
type Writer struct {
	bufio.Writer
	w io.WriteCloser
}

// Close flushes the buffered data and closes the bundled io.WriteCloser.
//
// The io.WriteCloser is closed even if flushing fails, so that the underlying
// resource is not leaked. If both fail, the flush error is returned.
func (w *Writer) Close() error {
	flushErr := w.Flush()
	closeErr := w.w.Close()
	if flushErr != nil {
		return flushErr
	}
	return closeErr
}
9 | func SliceInput[T any](s []T) iter.Seq2[T, error] { 10 | return func(yield func(T, error) bool) { 11 | for _, t := range s { 12 | if !yield(t, nil) { 13 | return 14 | } 15 | } 16 | } 17 | } 18 | 19 | // RangeInput returns a function that iterates over a range of integers, 20 | // starting at start and ending at (and excluding) stop, 21 | // to be used as the input function in [Serial] and [NonSerial]. 22 | func RangeInput(start, stop int) iter.Seq2[int, error] { 23 | return func(yield func(int, error) bool) { 24 | for i := start; i < stop; i++ { 25 | if !yield(i, nil) { 26 | return 27 | } 28 | } 29 | } 30 | } 31 | -------------------------------------------------------------------------------- /clustering/kmeans_test.go: -------------------------------------------------------------------------------- 1 | package clustering 2 | 3 | import ( 4 | "math/rand" 5 | "reflect" 6 | "testing" 7 | "time" 8 | ) 9 | 10 | func TestKmeans(t *testing.T) { 11 | rand.Seed(time.Now().UnixNano()) 12 | 13 | m := [][]float64{ 14 | {0.1, 0.0}, 15 | {0.9, 1.0}, 16 | {-0.1, 0.0}, 17 | {0.0, -0.1}, 18 | {1.1, 1.0}, 19 | {1.0, 1.1}, 20 | {1.0, 0.9}, 21 | {0.0, 0.1}, 22 | } 23 | 24 | means, tags := Kmeans(m, 2) 25 | 26 | if tags[0] == 0 { 27 | assertEqual(tags, []int{0, 1, 0, 0, 1, 1, 1, 0}, t) 28 | assertEqual(means, [][]float64{{0, 0}, {1, 1}}, t) 29 | } else { 30 | assertEqual(tags, []int{1, 0, 1, 1, 0, 0, 0, 1}, t) 31 | assertEqual(means, [][]float64{{1, 1}, {0, 0}}, t) 32 | } 33 | } 34 | 35 | func assertEqual(act, exp interface{}, t *testing.T) { 36 | if !reflect.DeepEqual(act, exp) { 37 | t.Fatalf("Wrong value: %v, expected %v", act, exp) 38 | } 39 | } 40 | -------------------------------------------------------------------------------- /rhash/buz_test.go: -------------------------------------------------------------------------------- 1 | package rhash 2 | 3 | import ( 4 | "hash" 5 | "testing" 6 | ) 7 | 8 | func TestBuz64(t *testing.T) { 9 | test64(t, func(i int) hash.Hash64 { 
// Package ppln provides generic parallel processing pipelines.
//
// # General usage
//
// This package provides two modes of operation: serial and non-serial.
// In [Serial] the outputs are ordered in the same order as the inputs.
// In [NonSerial] the order of outputs is arbitrary,
// but correlated with the order of inputs.
//
// Each of the functions blocks the calling function until either the processing
// is done (output was called on the last value) or until an error is returned.
//
// # Stopping
//
// Each user-function (input, transform, output) may return an error.
// Returning a non-nil error stops the pipeline prematurely, and that
// error is returned to the caller.
//
// # Iterators
//
// Input functions are [iter.Seq2] iterators from the standard [iter] package.
// The package is part of the standard library since Go 1.23,
// so no GOEXPERIMENT setting is needed.
package ppln
21 | func (s *Synset) String() string { 22 | result := bytes.NewBuffer(make([]byte, 0, 100)) 23 | fmt.Fprintf(result, "Synset[%s.", s.Pos) 24 | for i, word := range s.Word { 25 | if i > 0 { 26 | fmt.Fprintf(result, ",") 27 | } 28 | fmt.Fprintf(result, " %v", word) 29 | } 30 | fmt.Fprintf(result, ": %s]", s.Gloss) 31 | return result.String() 32 | } 33 | -------------------------------------------------------------------------------- /csvx/csvx_test.go: -------------------------------------------------------------------------------- 1 | package csvx 2 | 3 | import ( 4 | "slices" 5 | "strings" 6 | "testing" 7 | 8 | "github.com/fluhus/gostuff/iterx" 9 | ) 10 | 11 | func TestReader(t *testing.T) { 12 | input := "a,bb,ccc\nddd,ee,f" 13 | want := [][]string{{"a", "bb", "ccc"}, {"ddd", "ee", "f"}} 14 | got, err := iterx.CollectErr(Reader(strings.NewReader(input))) 15 | if err != nil { 16 | t.Fatalf("Reader(%q) failed: %v", input, err) 17 | } 18 | if !slices.EqualFunc(got, want, slices.Equal) { 19 | t.Fatalf("Reader(%q)=%q, want %q", input, got, want) 20 | } 21 | } 22 | 23 | func TestReaderTSV(t *testing.T) { 24 | input := "a\tbb\tccc\nddd\tee\tf" 25 | want := [][]string{{"a", "bb", "ccc"}, {"ddd", "ee", "f"}} 26 | got, err := iterx.CollectErr(Reader(strings.NewReader(input), TSV)) 27 | if err != nil { 28 | t.Fatalf("Reader(%q) failed: %v", input, err) 29 | } 30 | if !slices.EqualFunc(got, want, slices.Equal) { 31 | t.Fatalf("Reader(%q)=%q, want %q", input, got, want) 32 | } 33 | } 34 | -------------------------------------------------------------------------------- /csvx/examples_test.go: -------------------------------------------------------------------------------- 1 | package csvx 2 | 3 | import ( 4 | "fmt" 5 | "strings" 6 | ) 7 | 8 | func ExampleDecodeReader() { 9 | type person struct { 10 | Name string 11 | Age int 12 | } 13 | input := strings.NewReader("alice,30\nbob,25") 14 | 15 | for p, err := range DecodeReader[person](input) { 16 | if err != nil { 17 | 
panic(err) 18 | } 19 | fmt.Println(p.Name, "is", p.Age, "years old") 20 | } 21 | 22 | //Output: 23 | //alice is 30 years old 24 | //bob is 25 years old 25 | } 26 | 27 | func ExampleDecodeReaderHeader() { 28 | type person struct { 29 | Name string 30 | Age int 31 | } 32 | input := strings.NewReader( 33 | "user_id,age,city,name\n" + 34 | "111,30,paris,alice\n" + 35 | "222,25,london,bob") 36 | 37 | for p, err := range DecodeReaderHeader[person](input) { 38 | if err != nil { 39 | panic(err) 40 | } 41 | fmt.Println(p.Name, "is", p.Age, "years old") 42 | } 43 | 44 | //Output: 45 | //alice is 30 years old 46 | //bob is 25 years old 47 | } 48 | -------------------------------------------------------------------------------- /repeat/repeat_test.go: -------------------------------------------------------------------------------- 1 | package repeat 2 | 3 | import ( 4 | "io" 5 | "testing" 6 | ) 7 | 8 | func TestReader(t *testing.T) { 9 | r := NewReader([]byte("amit"), 2) 10 | buf := make([]byte, 3) 11 | want := []string{"ami", "tam", "it"} 12 | 13 | for _, w := range want { 14 | n, err := r.Read(buf) 15 | if err != nil { 16 | t.Fatalf("Read() failed: %v", err) 17 | } 18 | if got := string(buf[:n]); got != w { 19 | t.Fatalf("Read()=%q, want %q", got, "ami") 20 | } 21 | } 22 | if _, err := r.Read(buf); err != io.EOF { 23 | t.Fatalf("Read() err=%v, want EOF", err) 24 | } 25 | } 26 | 27 | func TestReader_infinite(t *testing.T) { 28 | r := NewReader([]byte("amit"), -1) 29 | buf := make([]byte, 3) 30 | want := []string{"ami", "tam", "ita", "mit", "ami", "tam", "ita", "mit"} 31 | 32 | for i, w := range want { 33 | n, err := r.Read(buf) 34 | if err != nil { 35 | t.Fatalf("#%v: Read() failed: %v", i, err) 36 | } 37 | if got := string(buf[:n]); got != w { 38 | t.Fatalf("#%v: Read()=%q, want %q", i, got, "ami") 39 | } 40 | } 41 | } 42 | -------------------------------------------------------------------------------- /reservoir/reservoir.go: 
-------------------------------------------------------------------------------- 1 | // Package reservoir implements reservoir sampling. 2 | // 3 | // Reservoir sampling allows sampling m uniformly random elements 4 | // from a stream, 5 | // using O(m) memory regardless of the stream length. 6 | package reservoir 7 | 8 | import "math/rand/v2" 9 | 10 | // Sampler samples a fixed number of elements with uniform distribution 11 | // from a stream. 12 | type Sampler[T any] struct { 13 | Elements []T // Elements selected so far. 14 | r *rand.Rand 15 | n int 16 | } 17 | 18 | // New returns a new sampler that samples n elements. 19 | func New[T any](n int) *Sampler[T] { 20 | return &Sampler[T]{ 21 | Elements: make([]T, 0, n), 22 | r: rand.New(rand.NewPCG(rand.Uint64(), rand.Uint64())), 23 | } 24 | } 25 | 26 | // Add maybe adds t to the selected sample. 27 | func (r *Sampler[T]) Add(t T) { 28 | r.n++ 29 | if len(r.Elements) < cap(r.Elements) { 30 | r.Elements = append(r.Elements, t) 31 | return 32 | } 33 | i := r.r.IntN(r.n) 34 | if i >= len(r.Elements) { 35 | return 36 | } 37 | r.Elements[i] = t 38 | } 39 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2015 Amit Lavon 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | this software and associated documentation files (the "Software"), to deal in 5 | the Software without restriction, including without limitation the rights to 6 | use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of 7 | the Software, and to permit persons to whom the Software is furnished to do so, 8 | subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in all 11 | copies or substantial portions of the Software. 
// TfIdf computes TF-IDF scores for a corpus of tokenized documents.
// Element i of the result maps each token of document i to its score.
//
//	TF = count(token in document) / count(all tokens in document)
//
//	IDF = log(count(documents) / count(documents with token))
func TfIdf(docTokens [][]string) []map[string]float64 {
	numDocs := float64(len(docTokens))
	scores := make([]map[string]float64, len(docTokens))
	docFreq := map[string]float64{}

	// Per document: term frequencies, and one document-frequency
	// increment for each distinct token.
	for i, tokens := range docTokens {
		counts := map[string]float64{}
		for _, tok := range tokens {
			counts[tok]++
		}
		total := float64(len(tokens))
		for tok := range counts {
			counts[tok] /= total
			docFreq[tok]++
		}
		scores[i] = counts
	}

	// Document frequency to inverse document frequency.
	inverse := make(map[string]float64, len(docFreq))
	for tok, df := range docFreq {
		inverse[tok] = math.Log(numDocs / df)
	}

	// Scale each term frequency by its token's IDF.
	for _, doc := range scores {
		for tok := range doc {
			doc[tok] *= inverse[tok]
		}
	}

	return scores
}
9} 21 | if !slices.Equal(got, want) { 22 | t.Errorf("Add(nil, %v, %v)=%v, want %v", a, b, got, want) 23 | } 24 | if want := []uint{4, 6}; !slices.Equal(a, want) { 25 | t.Errorf("a=%v, want %v", a, want) 26 | } 27 | if want := []uint{2, 3}; !slices.Equal(b, want) { 28 | t.Errorf("b=%v, want %v", b, want) 29 | } 30 | } 31 | 32 | func TestAdd_inplace(t *testing.T) { 33 | a := []uint{4, 6} 34 | b := []uint{2, 3} 35 | got := Add(a, b) 36 | want := []uint{6, 9} 37 | if !slices.Equal(got, want) { 38 | t.Errorf("Add(nil, %v, %v)=%v, want %v", a, b, got, want) 39 | } 40 | if !slices.Equal(a, want) { 41 | t.Errorf("a=%v, want %v", a, want) 42 | } 43 | if want := []uint{2, 3}; !slices.Equal(b, want) { 44 | t.Errorf("b=%v, want %v", b, want) 45 | } 46 | } 47 | -------------------------------------------------------------------------------- /clustering/agglo_test.go: -------------------------------------------------------------------------------- 1 | package clustering 2 | 3 | import ( 4 | "math" 5 | "reflect" 6 | "testing" 7 | ) 8 | 9 | // TODO(amit): Add more test cases. 
// TestClink checks the steps produced by clink for three points on a line,
// using absolute difference as the distance function.
func TestClink(t *testing.T) {
	points := []float64{0, 1, 5}
	// NOTE(review): presumably each step lists the two merged clusters and
	// their distance, 5 being the largest distance between {0, 1} and {5};
	// confirm against AggloStep and clink.
	steps := []AggloStep{
		{0, 1, 1},
		{1, 2, 5},
	}
	agg := clink(len(points), func(i, j int) float64 {
		return math.Abs(points[i] - points[j])
	})
	// One step fewer than the number of points is expected.
	if agg.Len() != len(points)-1 {
		t.Fatalf("Len()=%v, want %v", agg.Len(), len(points)-1)
	}
	for i := range steps {
		if step := agg.Step(i); !reflect.DeepEqual(steps[i], step) {
			t.Errorf("Step(%v)=%v, want %v", i, step, steps[i])
		}
	}
}
// correctUtf8Punctuation normalizes non-ASCII punctuation to ASCII.
// Currently only the right single quotation mark (U+2019) is handled,
// and it is replaced with an ASCII apostrophe.
//
// TODO(amit): Improve this function with more characters.
func correctUtf8Punctuation(s string) string {
	const curly, plain = "’", "'"
	return strings.ReplaceAll(s, curly, plain)
}
// TestBig checks that sampling is roughly uniform: over many independent
// samplers, every input value should be selected about equally often.
func TestBig(t *testing.T) {
	counts := make([]int, 10)
	for range 1000 {
		// Sample 5 out of the 10 values 0..9 with a fresh sampler.
		r := New[int](5)
		for i := range 10 {
			r.Add(i)
		}
		for _, i := range r.Elements {
			counts[i]++
		}
	}
	// With uniform sampling each value is expected in half of the 1000
	// samples (500 times); a deviation of up to 50 is tolerated.
	// The check is statistical, so rare random failures are possible.
	wantMin, wantMax := 450, 550
	for i, c := range counts {
		if c < wantMin || c > wantMax {
			t.Errorf("count[%v]=%v, want %v-%v", i, c, wantMin, wantMax)
		}
	}
}
23 | func nodeToString(n Node) string { 24 | switch n := n.(type) { 25 | case *root: 26 | result := "" 27 | for _, child := range n.children { 28 | result += nodeToString(child) 29 | } 30 | return result 31 | 32 | case *tag: 33 | result := "<" + n.tagName 34 | for _, attr := range n.attr { 35 | result += " " + attr.Name.Local + "=\"" + attr.Value + "\"" 36 | } 37 | result += ">" 38 | for _, child := range n.children { 39 | result += nodeToString(child) 40 | } 41 | result += "" 42 | return result 43 | 44 | case *text: 45 | return n.text 46 | 47 | case *comment: 48 | return "" 49 | 50 | case *procInst: 51 | return "" 52 | 53 | case *directive: 54 | return "" 55 | 56 | default: 57 | panic("Unknown node type.") 58 | } 59 | } 60 | -------------------------------------------------------------------------------- /csvx/csvx.go: -------------------------------------------------------------------------------- 1 | package csvx 2 | 3 | import ( 4 | "encoding/csv" 5 | "io" 6 | "iter" 7 | 8 | "github.com/fluhus/gostuff/aio" 9 | ) 10 | 11 | // Reader iterates over CSV entries from a reader. 12 | // Applies the given modifiers before iteration. 13 | func Reader(r io.Reader, mods ...ReaderModifier) iter.Seq2[[]string, error] { 14 | return func(yield func([]string, error) bool) { 15 | c := csv.NewReader(r) 16 | for _, mod := range mods { 17 | mod(c) 18 | } 19 | for { 20 | e, err := c.Read() 21 | if err == io.EOF { 22 | return 23 | } 24 | if !yield(e, nil) { 25 | return 26 | } 27 | } 28 | } 29 | } 30 | 31 | // File iterates over CSV entries from a file. 32 | // Applies the given modifiers before iteration. 
33 | func File(file string, mods ...ReaderModifier) iter.Seq2[[]string, error] { 34 | return func(yield func([]string, error) bool) { 35 | f, err := aio.Open(file) 36 | if err != nil { 37 | yield(nil, err) 38 | return 39 | } 40 | defer f.Close() 41 | c := csv.NewReader(f) 42 | for _, mod := range mods { 43 | mod(c) 44 | } 45 | for { 46 | e, err := c.Read() 47 | if err == io.EOF { 48 | return 49 | } 50 | if !yield(e, nil) { 51 | return 52 | } 53 | } 54 | } 55 | } 56 | 57 | // ReaderModifier modifies the settings of a CSV reader 58 | // before iteration starts. 59 | type ReaderModifier = func(*csv.Reader) 60 | 61 | // TSV makes the reader use tab as the delimiter. 62 | func TSV(r *csv.Reader) { 63 | r.Comma = '\t' 64 | } 65 | -------------------------------------------------------------------------------- /graphs/bfsdfs.go: -------------------------------------------------------------------------------- 1 | package graphs 2 | 3 | import ( 4 | "iter" 5 | 6 | "github.com/fluhus/gostuff/sets" 7 | "github.com/fluhus/gostuff/snm" 8 | ) 9 | 10 | // BFS iterates over this graph's nodes in a breadth-first ordering, 11 | // including start. 12 | func (g *Graph[T]) BFS(start T) iter.Seq[T] { 13 | return func(yield func(T) bool) { 14 | if _, ok := g.v[start]; !ok { 15 | return 16 | } 17 | 18 | elems := g.v.Elements() 19 | edges := g.edgeSlices() 20 | istart := g.v.IndexOf(start) 21 | done := sets.Set[int]{}.Add(istart) 22 | q := &snm.Queue[int]{} 23 | q.Enqueue(istart) 24 | 25 | for v := range q.Seq() { 26 | if !yield(elems[v]) { 27 | return 28 | } 29 | for _, e := range edges[v] { 30 | if done.Has(e) { 31 | continue 32 | } 33 | done.Add(e) 34 | q.Enqueue(e) 35 | } 36 | } 37 | } 38 | } 39 | 40 | // DFS iterates over this graph's nodes in a depth-first ordering, 41 | // including start. 
42 | func (g *Graph[T]) DFS(start T) iter.Seq[T] { 43 | return func(yield func(T) bool) { 44 | if _, ok := g.v[start]; !ok { 45 | return 46 | } 47 | 48 | elems := g.v.Elements() 49 | edges := g.edgeSlices() 50 | istart := g.v.IndexOf(start) 51 | done := sets.Set[int]{}.Add(istart) 52 | q := []int{istart} 53 | 54 | for len(q) > 0 { 55 | v := q[len(q)-1] 56 | q = q[:len(q)-1] 57 | if !yield(elems[v]) { 58 | return 59 | } 60 | for _, e := range edges[v] { 61 | if done.Has(e) { 62 | continue 63 | } 64 | done.Add(e) 65 | q = append(q, e) 66 | } 67 | } 68 | } 69 | } 70 | -------------------------------------------------------------------------------- /rhash/fuzz_test.go: -------------------------------------------------------------------------------- 1 | package rhash 2 | 3 | import ( 4 | "hash" 5 | "testing" 6 | 7 | "github.com/fluhus/gostuff/snm" 8 | ) 9 | 10 | func FuzzBuz64(f *testing.F) { 11 | fuzz64(f, func(n int) hash.Hash64 { return NewBuz(n) }) 12 | } 13 | 14 | func FuzzRabin64(f *testing.F) { 15 | fuzz64(f, func(n int) hash.Hash64 { return NewRabinFingerprint64(n) }) 16 | } 17 | 18 | func FuzzBuz32(f *testing.F) { 19 | fuzz32(f, func(n int) hash.Hash32 { return NewBuz(n) }) 20 | } 21 | 22 | func FuzzRabin32(f *testing.F) { 23 | fuzz32(f, func(n int) hash.Hash32 { return NewRabinFingerprint32(n) }) 24 | } 25 | 26 | func fuzz64(f *testing.F, fn func(n int) hash.Hash64) { 27 | const m = 10 28 | prefix := snm.Slice(30, func(i int) byte { return byte(i) }) 29 | 30 | f.Add([]byte{}) 31 | f.Fuzz(func(t *testing.T, a []byte) { 32 | a = append(prefix, a...) 
33 | h := fn(m) 34 | h.Write(a[len(a)-m:]) 35 | want := h.Sum64() 36 | for i := range a[m:] { 37 | h.Write(a[i:]) 38 | if got := h.Sum64(); got != want { 39 | t.Fatalf("Sum64(%v)=%v, want %v", a[i:], got, want) 40 | } 41 | } 42 | }) 43 | } 44 | 45 | func fuzz32(f *testing.F, fn func(n int) hash.Hash32) { 46 | const m = 10 47 | prefix := snm.Slice(30, func(i int) byte { return byte(i) }) 48 | 49 | f.Add([]byte{}) 50 | f.Fuzz(func(t *testing.T, a []byte) { 51 | a = append(prefix, a...) 52 | h := fn(m) 53 | h.Write(a[len(a)-m:]) 54 | want := h.Sum32() 55 | for i := range a[m:] { 56 | h.Write(a[i:]) 57 | if got := h.Sum32(); got != want { 58 | t.Fatalf("Sum32(%v)=%v, want %v", a[i:], got, want) 59 | } 60 | } 61 | }) 62 | } 63 | -------------------------------------------------------------------------------- /ezpprof/ezpprof.go: -------------------------------------------------------------------------------- 1 | // Package ezpprof is a convenience wrapper over the runtime/pprof package. 2 | // 3 | // This package helps to quickly introduce profiling to a piece of code without 4 | // the mess of opening files and checking errors. 5 | // 6 | // A typical use of this package looks like: 7 | // 8 | // ezpprof.Start("myfile.pprof") 9 | // {... some complicated code ...} 10 | // ezpprof.Stop() 11 | // 12 | // Or alternatively: 13 | // 14 | // const profile = true 15 | // 16 | // if profile { 17 | // ezpprof.Start("myfile.pprof") 18 | // defer ezpprof.Stop() 19 | // } 20 | package ezpprof 21 | 22 | import ( 23 | "io" 24 | "runtime/pprof" 25 | 26 | "github.com/fluhus/gostuff/aio" 27 | ) 28 | 29 | var fout io.WriteCloser 30 | 31 | // Start starts CPU profiling and writes to the given file. 32 | // Panics if an error occurs. 
33 | func Start(file string) { 34 | if fout != nil { 35 | panic("already profiling") 36 | } 37 | f, err := aio.CreateRaw(file) 38 | if err != nil { 39 | panic(err) 40 | } 41 | fout = f 42 | pprof.StartCPUProfile(fout) 43 | } 44 | 45 | // Stop stops CPU profiling and closes the output file. 46 | // Panics if called without calling Start. 47 | func Stop() { 48 | if fout == nil { 49 | panic("Stop called without calling Start") 50 | } 51 | pprof.StopCPUProfile() 52 | if err := fout.Close(); err != nil { 53 | panic(err) 54 | } 55 | fout = nil 56 | } 57 | 58 | // Heap writes heap profile to the given file. Panics if an error occurs. 59 | func Heap(file string) { 60 | f, err := aio.Create(file) 61 | if err != nil { 62 | panic(err) 63 | } 64 | defer f.Close() 65 | 66 | err = pprof.WriteHeapProfile(f) 67 | if err != nil { 68 | panic(err) 69 | } 70 | } 71 | -------------------------------------------------------------------------------- /snm/queue.go: -------------------------------------------------------------------------------- 1 | package snm 2 | 3 | import "iter" 4 | 5 | // Queue is a memory-efficient FIFO container. 6 | type Queue[T any] struct { 7 | q []T 8 | i, n int 9 | } 10 | 11 | // Enqueue inserts an element to the queue. 12 | func (q *Queue[T]) Enqueue(x T) { 13 | if q.n == len(q.q) { 14 | nq := make([]T, len(q.q)*2+1) 15 | _ = append(append(append(nq[:0], q.q[q.i:]...), q.q[:q.i]...), x) 16 | q.q = nq 17 | q.n++ 18 | q.i = 0 19 | return 20 | } 21 | i := (q.i + q.n) % len(q.q) 22 | q.q[i] = x 23 | q.n++ 24 | } 25 | 26 | // Dequeue removes and returns the next element in the queue. 27 | // Panics if the queue is empty. 28 | func (q *Queue[T]) Dequeue() T { 29 | if q.n == 0 { 30 | panic("pull with 0 elements") 31 | } 32 | x := q.q[q.i] 33 | var zero T 34 | q.q[q.i] = zero // Remove element to allow GC. 35 | q.n-- 36 | q.i = (q.i + 1) % len(q.q) 37 | return x 38 | } 39 | 40 | // Peek returns the next element in the queue, 41 | // without modifying its contents. 
42 | // Panics if the queue is empty. 43 | func (q *Queue[T]) Peek() T { 44 | if q.n == 0 { 45 | panic("pull with 0 elements") 46 | } 47 | return q.q[q.i] 48 | } 49 | 50 | // Len return the current number of elements in the queue. 51 | func (q *Queue[T]) Len() int { 52 | return q.n 53 | } 54 | 55 | // Seq returns an iterator over the queue's elements, 56 | // dequeueing each one. 57 | // 58 | // It is okay to enqueue elements while iterating, 59 | // from within the same goroutine. 60 | // The new elements will be included in the same loop. 61 | func (q *Queue[T]) Seq() iter.Seq[T] { 62 | return func(yield func(T) bool) { 63 | for q.Len() > 0 { 64 | if !yield(q.Dequeue()) { 65 | break 66 | } 67 | } 68 | } 69 | } 70 | -------------------------------------------------------------------------------- /jio/jio.go: -------------------------------------------------------------------------------- 1 | // Package jio provides convenience functions for saving and loading 2 | // JSON-encoded values. 3 | // 4 | // Uses the [aio] package for I/O. 5 | package jio 6 | 7 | import ( 8 | "encoding/json" 9 | "io" 10 | 11 | "github.com/fluhus/gostuff/aio" 12 | ) 13 | 14 | // Write saves v to the given file, encoded as JSON. 15 | func Write(file string, v any) error { 16 | f, err := aio.Create(file) 17 | if err != nil { 18 | return err 19 | } 20 | e := json.NewEncoder(f) 21 | e.SetIndent("", " ") 22 | if err := e.Encode(v); err != nil { 23 | f.Close() 24 | return err 25 | } 26 | return f.Close() 27 | } 28 | 29 | // Read loads a JSON encoded value from the given file and populates v with it. 30 | func Read(file string, v any) error { 31 | f, err := aio.Open(file) 32 | if err != nil { 33 | return err 34 | } 35 | defer f.Close() 36 | return json.NewDecoder(f).Decode(v) 37 | } 38 | 39 | // ReadAs reads a JSON encoded value of type T and returns it. 
40 | func ReadAs[T any](file string) (T, error) { 41 | var t T 42 | err := Read(file, &t) 43 | return t, err 44 | } 45 | 46 | // Iter returns an iterator over sequential JSON values in a file. 47 | // 48 | // Note: a file with several independent JSON values is not a valid JSON file. 49 | func Iter[T any](file string) func(yield func(T, error) bool) { 50 | return func(yield func(T, error) bool) { 51 | f, err := aio.Open(file) 52 | if err != nil { 53 | var t T 54 | yield(t, err) 55 | return 56 | } 57 | defer f.Close() 58 | j := json.NewDecoder(f) 59 | for { 60 | var t T 61 | err := j.Decode(&t) 62 | if err == io.EOF { 63 | return 64 | } 65 | if err != nil { 66 | yield(t, err) 67 | return 68 | } 69 | if !yield(t, nil) { 70 | return 71 | } 72 | } 73 | } 74 | } 75 | -------------------------------------------------------------------------------- /morris/morris_test.go: -------------------------------------------------------------------------------- 1 | package morris 2 | 3 | import "testing" 4 | 5 | func TestRestore(t *testing.T) { 6 | tests := []struct { 7 | i byte 8 | m uint 9 | want uint 10 | }{ 11 | {0, 1, 0}, 12 | {1, 1, 1}, 13 | {2, 1, 4}, 14 | {3, 1, 9}, 15 | {4, 1, 19}, 16 | {0, 10, 0}, 17 | {1, 10, 1}, 18 | {2, 10, 2}, 19 | {10, 10, 10}, 20 | {15, 10, 20}, 21 | {20, 10, 31}, 22 | {25, 10, 51}, 23 | {30, 10, 72}, 24 | } 25 | for _, test := range tests { 26 | if got := Restore(test.i, test.m); got != test.want { 27 | t.Errorf("Restore(%d,%d)=%d, want %d", 28 | test.i, test.m, got, test.want) 29 | } 30 | } 31 | } 32 | 33 | func TestRaise_overflow(t *testing.T) { 34 | if !checkOverFlow { 35 | t.Skip() 36 | } 37 | defer func() { recover() }() 38 | got := Raise(byte(255), 10) 39 | t.Fatalf("Raise(byte(255)=%d, want fail", got) 40 | } 41 | 42 | func TestRaise(t *testing.T) { 43 | const reps = 1000 44 | want := map[byte]int{ 45 | 1: 1, 2: 2, 3: 3, 4: 4, 5: 5, 6: 6, 7: 7, 8: 8, 9: 9, 10: 10, 46 | 15: 20, 20: 30, 25: 50, 30: 70, 35: 110, 40: 150, 47 | } 48 | margins := 
map[byte]int{ 49 | 1: 0, 2: 0, 3: 0, 4: 0, 5: 0, 6: 0, 7: 0, 8: 0, 9: 0, 10: 0, 50 | 15: 2, 20: 2, 25: 2, 30: 2, 35: 2, 40: 2, 51 | } 52 | got := map[byte]int{} 53 | 54 | for i := 0; i < reps; i++ { 55 | a, n := byte(0), 0 56 | for a < 40 { 57 | n++ 58 | aa := Raise(a, 10) 59 | if aa == a { 60 | continue 61 | } 62 | a = aa 63 | if _, ok := want[a]; !ok { 64 | continue 65 | } 66 | got[a] += n 67 | } 68 | } 69 | 70 | for k, n := range got { 71 | got := n / reps 72 | if got < want[k]-margins[k] || got > want[k]+margins[k] { 73 | t.Errorf("Raise() to %d took %d, want %d-%d", k, got, 74 | want[k]-margins[k], want[k]+margins[k]) 75 | } 76 | } 77 | } 78 | -------------------------------------------------------------------------------- /hashx/hashx.go: -------------------------------------------------------------------------------- 1 | // Package hashx provides simple hashing functions for various input types. 2 | package hashx 3 | 4 | import ( 5 | "hash" 6 | "io" 7 | 8 | "github.com/spaolacci/murmur3" 9 | "golang.org/x/exp/constraints" 10 | ) 11 | 12 | // Hashx calculates hash values for various input types. 13 | type Hashx struct { 14 | h hash.Hash64 15 | buf []byte 16 | } 17 | 18 | // NewSeed returns a new Hashx with the given seed. 19 | func NewSeed(seed uint32) *Hashx { 20 | return &Hashx{murmur3.New64WithSeed(seed), make([]byte, 8)} 21 | } 22 | 23 | // New returns a new Hashx. 24 | func New() *Hashx { 25 | return &Hashx{murmur3.New64(), make([]byte, 8)} 26 | } 27 | 28 | // Bytes returns the hash value of the given byte sequence. 29 | func (h *Hashx) Bytes(b []byte) uint64 { 30 | h.h.Reset() 31 | h.h.Write(b) 32 | return h.h.Sum64() 33 | } 34 | 35 | // String returns the hash value of the given string. 36 | func (h *Hashx) String(s string) uint64 { 37 | // TODO(amit): Optimize this? 38 | h.h.Reset() 39 | io.WriteString(h.h, s) 40 | return h.h.Sum64() 41 | } 42 | 43 | // IntHashx returns the hash value of the given integer, 44 | // using the given Hashx instance. 
45 | func IntHashx[I constraints.Integer](h *Hashx, i I) uint64 { 46 | h.h.Reset() 47 | u := uint64(i) 48 | for j := range 8 { 49 | h.buf[j] = byte(u >> (j * 8)) 50 | } 51 | h.h.Write(h.buf) 52 | return h.h.Sum64() 53 | } 54 | 55 | // The default hashx. 56 | var dflt = New() 57 | 58 | // Bytes returns the hash value of the given byte sequence. 59 | func Bytes(b []byte) uint64 { return dflt.Bytes(b) } 60 | 61 | // String returns the hash value of the given string. 62 | func String(s string) uint64 { return dflt.String(s) } 63 | 64 | // Int returns the hash value of the given integer. 65 | func Int[I constraints.Integer](i I) uint64 { return IntHashx(dflt, i) } 66 | -------------------------------------------------------------------------------- /rhash/bench_test.go: -------------------------------------------------------------------------------- 1 | package rhash 2 | 3 | import ( 4 | "crypto/rand" 5 | "fmt" 6 | "testing" 7 | 8 | "github.com/fluhus/gostuff/hashx" 9 | ) 10 | 11 | func BenchmarkWrite1K(b *testing.B) { 12 | const n = 20 13 | buf := make([]byte, 1024) 14 | rand.Read(buf) 15 | b.Run("buz64", func(b *testing.B) { 16 | h := NewBuz(n) 17 | for b.Loop() { 18 | h.Write(buf) 19 | } 20 | }) 21 | b.Run("rabin32", func(b *testing.B) { 22 | h := NewRabinFingerprint32(n) 23 | for b.Loop() { 24 | h.Write(buf) 25 | } 26 | }) 27 | b.Run("rabin64", func(b *testing.B) { 28 | h := NewRabinFingerprint64(n) 29 | for b.Loop() { 30 | h.Write(buf) 31 | } 32 | }) 33 | } 34 | 35 | func BenchmarkRolling(b *testing.B) { 36 | text := make([]byte, 10000) 37 | rand.Read(text) 38 | for _, ln := range []int{10, 30, 100} { 39 | b.Run(fmt.Sprint("buz", ln), func(b *testing.B) { 40 | h := NewBuz(ln) 41 | for b.Loop() { 42 | h.Write(text[:ln-1]) 43 | var s uint64 44 | for _, b := range text[ln:] { 45 | h.WriteByte(b) 46 | s += h.Sum64() 47 | } 48 | if s == 0 { 49 | b.Error("placeholder to not optimize s out") 50 | } 51 | } 52 | }) 53 | b.Run(fmt.Sprint("rabin", ln), func(b *testing.B) { 54 
// Reader outputs a given constant byte sequence repeatedly.
type Reader struct {
	data []byte // The sequence to repeat.
	i    int    // Offset in data of the next byte to output.
	n    int    // Remaining repetitions. Negative means infinite.
}

// NewReader returns a reader that outputs data n times. If n is negative, repeats
// infinitely. Copies the contents of data.
func NewReader(data []byte, n int) *Reader {
	cp := append(make([]byte, 0, len(data)), data...)
	return &Reader{data: cp, n: n}
}

// Read fills p with repetitions of the reader's data. Writes until p is full or
// until the last repetition was written. Subsequent calls to Read resume from where
// the last repetition stopped. When no more bytes are available, returns 0, EOF.
// Otherwise the error is nil.
//
// A reader with empty data has no bytes to output, so it returns 0, EOF
// regardless of the number of repetitions.
func (r *Reader) Read(p []byte) (int, error) {
	if len(p) == 0 {
		return 0, nil
	}
	// Without the empty-data check, the loop below would copy zero bytes
	// per iteration and never fill p; with infinite repetitions it would
	// never return.
	if len(r.data) == 0 || (r.i == 0 && r.n == 0) {
		return 0, io.EOF
	}
	m := 0
	for len(p) > 0 {
		n := copy(p, r.data[r.i:])
		r.i += n
		p = p[n:]
		m += n
		if r.i < len(r.data) {
			continue
		}
		// Finished one repetition.
		r.i = 0
		if r.n > 0 {
			r.n--
		}
		if r.n == 0 {
			break
		}
	}
	return m, nil
}

// Close is a no-op. Implements [io.ReadCloser].
func (r *Reader) Close() error {
	return nil
}
} 53 | } 54 | 55 | func TestJSON_map(t *testing.T) { 56 | input := map[string]Set[int]{ 57 | "a": Set[int]{}.Add(1, 3, 6), 58 | "x": Set[int]{}.Add(7, 9, 6), 59 | } 60 | want := map[string]Set[int]{ 61 | "a": Set[int]{}.Add(1, 3, 6), 62 | "x": Set[int]{}.Add(7, 9, 6), 63 | } 64 | j, err := json.Marshal(input) 65 | if err != nil { 66 | t.Fatalf("Marshal(%v) failed: %v", input, err) 67 | } 68 | 69 | got := map[string]Set[int]{} 70 | err = json.Unmarshal(j, &got) 71 | if err != nil { 72 | t.Fatalf("Unmarshal(%q) failed: %v", j, err) 73 | } 74 | 75 | if !reflect.DeepEqual(input, got) { 76 | t.Fatalf("Unmarshal(%q)=%v, want %v", j, got, want) 77 | } 78 | } 79 | -------------------------------------------------------------------------------- /iterx/slice.go: -------------------------------------------------------------------------------- 1 | package iterx 2 | 3 | import ( 4 | "iter" 5 | "slices" 6 | ) 7 | 8 | // Slice returns an iterator over the slice values. 9 | // 10 | // Deprecated: use [slices.Values] instead. 11 | func Slice[T any](s []T) iter.Seq[T] { 12 | return slices.Values(s) 13 | } 14 | 15 | // ISlice returns an iterator over the slice values and their indices, 16 | // like in a range expression. 17 | // 18 | // Deprecated: use [slices.All] instead. 19 | func ISlice[T any](s []T) iter.Seq2[int, T] { 20 | return slices.All(s) 21 | } 22 | 23 | // Limit returns an iterator that stops after n elements, 24 | // if the underlying iterator does not stop before. 25 | func Limit[T any](it iter.Seq[T], n int) iter.Seq[T] { 26 | return func(yield func(T) bool) { 27 | i := 0 28 | for x := range it { 29 | i++ 30 | if i > n { 31 | return 32 | } 33 | if !yield(x) { 34 | return 35 | } 36 | } 37 | } 38 | } 39 | 40 | // Limit2 returns an iterator that stops after n elements, 41 | // if the underlying iterator does not stop before. 
42 | func Limit2[T any, S any](it iter.Seq2[T, S], n int) iter.Seq2[T, S] { 43 | return func(yield func(T, S) bool) { 44 | i := 0 45 | for x, y := range it { 46 | i++ 47 | if i > n { 48 | return 49 | } 50 | if !yield(x, y) { 51 | return 52 | } 53 | } 54 | } 55 | } 56 | 57 | // Skip returns an iterator without the first n elements. 58 | func Skip[T any](it iter.Seq[T], n int) iter.Seq[T] { 59 | return func(yield func(T) bool) { 60 | i := 0 61 | for x := range it { 62 | i++ 63 | if i <= n { 64 | continue 65 | } 66 | if !yield(x) { 67 | return 68 | } 69 | } 70 | } 71 | } 72 | 73 | // Skip2 returns an iterator without the first n elements. 74 | func Skip2[T any, S any](it iter.Seq2[T, S], n int) iter.Seq2[T, S] { 75 | return func(yield func(T, S) bool) { 76 | i := 0 77 | for x, y := range it { 78 | i++ 79 | if i <= n { 80 | continue 81 | } 82 | if !yield(x, y) { 83 | return 84 | } 85 | } 86 | } 87 | } 88 | -------------------------------------------------------------------------------- /sets/sorted.go: -------------------------------------------------------------------------------- 1 | package sets 2 | 3 | import ( 4 | "cmp" 5 | ) 6 | 7 | // SortedIntersection returns the intersection of 8 | // two sorted slices a and b. 9 | func SortedIntersection[T cmp.Ordered](a, b []T) []T { 10 | var result []T 11 | i, j := 0, 0 12 | for i < len(a) && j < len(b) { 13 | switch cmp.Compare(a[i], b[j]) { 14 | case 0: 15 | result = append(result, a[i]) 16 | i++ 17 | j++ 18 | case 1: 19 | j++ 20 | case -1: 21 | i++ 22 | } 23 | } 24 | return result 25 | } 26 | 27 | // SortedUnion returns the union of 28 | // two sorted slices a and b. 
// SortedUnion returns the union of
// two sorted slices a and b.
//
// An element that is found in both slices at the current positions is
// added once. Returns nil if both slices are empty.
func SortedUnion[T cmp.Ordered](a, b []T) []T {
	var out []T
	ia, ib := 0, 0
	for ia < len(a) && ib < len(b) {
		c := cmp.Compare(a[ia], b[ib])
		// Take the smaller element; on a tie take it once, from a.
		if c > 0 {
			out = append(out, b[ib])
		} else {
			out = append(out, a[ia])
		}
		if c <= 0 {
			ia++
		}
		if c >= 0 {
			ib++
		}
	}
	// At most one of these is non-empty.
	out = append(out, a[ia:]...)
	out = append(out, b[ib:]...)
	return out
}

// SortedIntersectionLen returns the length of the intersection of
// two sorted slices a and b.
func SortedIntersectionLen[T cmp.Ordered](a, b []T) int {
	matches := 0
	ia, ib := 0, 0
	for ia < len(a) && ib < len(b) {
		c := cmp.Compare(a[ia], b[ib])
		if c == 0 {
			matches++
		}
		if c <= 0 {
			ia++
		}
		if c >= 0 {
			ib++
		}
	}
	return matches
}

// SortedUnionLen returns the length of the union of
// two sorted slices a and b.
func SortedUnionLen[T cmp.Ordered](a, b []T) int {
	// The merge emits one element per step, and only a matching step
	// consumes an element from both slices. Hence the union has one element
	// less than len(a)+len(b) for each match.
	return len(a) + len(b) - SortedIntersectionLen(a, b)
}
6 | // 7 | // This package introduces a parameter m, so that the first m increments are 8 | // made with probability 1, then m increments with probability 1/2, then m with 9 | // probability 1/4... Using m=1 is equivalent to the original formula. 10 | // A large m increases accuracy but costs more bits. 11 | // A single counter should use the same m for all calls to Raise and Restore. 12 | // 13 | // This package is experimental. 14 | // 15 | // # Error rates 16 | // 17 | // Average error rates for different values of m: 18 | // 19 | // 1: 54.2% 20 | // 3: 31.1% 21 | // 10: 15.5% 22 | // 30: 8.8% 23 | // 100: 4.6% 24 | // 300: 2.5% 25 | // 1000: 1.6% 26 | // 3000: 0.9% 27 | // 10000: 0.5% 28 | package morris 29 | 30 | import ( 31 | "fmt" 32 | "math/rand" 33 | 34 | "golang.org/x/exp/constraints" 35 | ) 36 | 37 | // If true, panics when a counter is about to be raised beyond its maximal 38 | // value. 39 | const checkOverFlow = true 40 | 41 | // Raise returns the new value of i after one increment. 42 | // m controls the restoration accuracy. 43 | // The approximate number of calls to Raise can be restored using Restore. 44 | func Raise[T constraints.Unsigned](i T, m uint) T { 45 | if checkOverFlow { 46 | max := T(0) - 1 47 | if i == max { 48 | panic(fmt.Sprintf("counter reached maximal value: %d", i)) 49 | } 50 | } 51 | r := 1 << (uint(i) / m) 52 | if rand.Intn(r) == 0 { 53 | return i + 1 54 | } 55 | return i 56 | } 57 | 58 | // Restore returns an approximation of the number of calls to Raise on i. 59 | // m should have the same value that was used with Raise. 
60 | func Restore[T constraints.Unsigned](i T, m uint) uint { 61 | ui := uint(i) 62 | if ui <= m { 63 | return ui 64 | } 65 | return m*(1<<(ui/m)-1) + (ui%m)*(1<<(ui/m)) + (1 << (uint(i)/m - 2)) 66 | } 67 | -------------------------------------------------------------------------------- /iterx/slice_test.go: -------------------------------------------------------------------------------- 1 | package iterx 2 | 3 | import ( 4 | "slices" 5 | "testing" 6 | ) 7 | 8 | func TestSlice(t *testing.T) { 9 | input := []string{"hello", "world", "hi"} 10 | want := slices.Clone(input) 11 | var got []string 12 | for x := range Slice(input) { 13 | got = append(got, x) 14 | } 15 | if !slices.Equal(input, want) { 16 | t.Fatalf("Slice(%v) changed input to %v", want, input) 17 | } 18 | if !slices.Equal(got, want) { 19 | t.Fatalf("Slice(%v)=%v, want %v", input, got, want) 20 | } 21 | } 22 | 23 | func TestISlice(t *testing.T) { 24 | input := []string{"hello", "world", "hi"} 25 | want := slices.Clone(input) 26 | var got []string 27 | for i, x := range ISlice(input) { 28 | if i != len(got) { 29 | t.Fatalf("ISlice(%v) i=%v, want %v", input, i, len(got)) 30 | } 31 | got = append(got, x) 32 | } 33 | if !slices.Equal(input, want) { 34 | t.Fatalf("ISlice(%v) changed input to %v", want, input) 35 | } 36 | if !slices.Equal(got, want) { 37 | t.Fatalf("ISlice(%v)=%v, want %v", input, got, want) 38 | } 39 | } 40 | 41 | func TestLimit(t *testing.T) { 42 | input := []string{"bla", "blu", "bli", "ble"} 43 | tests := []struct { 44 | n int 45 | want []string 46 | }{ 47 | {-1, nil}, {0, nil}, {1, input[:1]}, {2, input[:2]}, 48 | {3, input[:3]}, {4, input}, {5, input}, {6, input}, 49 | } 50 | for _, test := range tests { 51 | var got []string 52 | for x := range Limit(Slice(input), test.n) { 53 | got = append(got, x) 54 | } 55 | if !slices.Equal(got, test.want) { 56 | t.Errorf("Limit(%v,%v)=%v, want %v", input, test.n, got, test.want) 57 | } 58 | } 59 | } 60 | 61 | func TestSkip(t *testing.T) { 62 | input 
:= []string{"bla", "blu", "bli", "ble"} 63 | tests := []struct { 64 | n int 65 | want []string 66 | }{ 67 | {-1, input}, {0, input}, {1, input[1:]}, {2, input[2:]}, 68 | {3, input[3:]}, {4, nil}, {5, nil}, {6, nil}, 69 | } 70 | for _, test := range tests { 71 | var got []string 72 | for x := range Skip(Slice(input), test.n) { 73 | got = append(got, x) 74 | } 75 | if !slices.Equal(got, test.want) { 76 | t.Errorf("Skip(%v,%v)=%v, want %v", input, test.n, got, test.want) 77 | } 78 | } 79 | } 80 | -------------------------------------------------------------------------------- /clustering/upgma_test.go: -------------------------------------------------------------------------------- 1 | package clustering 2 | 3 | import ( 4 | "fmt" 5 | "math" 6 | "reflect" 7 | "testing" 8 | ) 9 | 10 | func TestUPGMA(t *testing.T) { 11 | points := []float64{1, 4, 6, 10} 12 | steps := []AggloStep{ 13 | {1, 2, 2}, 14 | {0, 2, 4}, 15 | {2, 3, 19.0 / 3.0}, 16 | } 17 | agg := upgma(len(points), func(i, j int) float64 { 18 | return math.Abs(points[i] - points[j]) 19 | }) 20 | if agg.Len() != len(points)-1 { 21 | t.Fatalf("Len()=%v, want %v", agg.Len(), len(points)-1) 22 | } 23 | for i := range steps { 24 | if step := agg.Step(i); !reflect.DeepEqual(steps[i], step) { 25 | t.Errorf("Step(%v)=%v, want %v", i, step, steps[i]) 26 | } 27 | } 28 | } 29 | 30 | func TestUPGMA_more(t *testing.T) { 31 | points := []float64{1, 3, 8, 12, 20, 28} 32 | steps := []AggloStep{ 33 | {0, 1, 2}, 34 | {2, 3, 4}, 35 | {1, 3, 8}, 36 | {4, 5, 8}, 37 | {3, 5, 18}, 38 | } 39 | agg := upgma(len(points), func(i, j int) float64 { 40 | return math.Abs(points[i] - points[j]) 41 | }) 42 | if agg.Len() != len(points)-1 { 43 | t.Fatalf("Len()=%v, want %v", agg.Len(), len(points)-1) 44 | } 45 | for i := range steps { 46 | if step := agg.Step(i); !reflect.DeepEqual(steps[i], step) { 47 | t.Errorf("Step(%v)=%v, want %v", i, step, steps[i]) 48 | } 49 | } 50 | } 51 | 52 | func BenchmarkUPGMA(b *testing.B) { 53 | for _, n := range 
[]int{10, 30, 100} { 54 | b.Run(fmt.Sprint(n), func(b *testing.B) { 55 | nums := make([]float64, n) 56 | for i := range nums { 57 | nums[i] = 1.0 / float64(i+1) 58 | if i%2 == 0 { 59 | nums[i] += 10 60 | } 61 | } 62 | b.ResetTimer() 63 | for i := 0; i < b.N; i++ { 64 | upgma(n, func(i1, i2 int) float64 { 65 | return math.Abs(nums[i1] - nums[i2]) 66 | }) 67 | } 68 | }) 69 | } 70 | } 71 | 72 | func FuzzUPGMA(f *testing.F) { 73 | f.Add(1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0) 74 | f.Fuzz(func(t *testing.T, a float64, b float64, c float64, 75 | d float64, e float64, f float64, g float64, h float64, i float64) { 76 | nums := []float64{a, b, c, d, e, f, g, h, i} 77 | upgma(len(nums), func(i1, i2 int) float64 { 78 | return math.Abs(nums[i1] - nums[i2]) 79 | }) 80 | }) 81 | } 82 | -------------------------------------------------------------------------------- /graphs/graph_test.go: -------------------------------------------------------------------------------- 1 | package graphs 2 | 3 | import ( 4 | "reflect" 5 | "slices" 6 | "testing" 7 | 8 | "github.com/fluhus/gostuff/snm" 9 | ) 10 | 11 | func TestComponents(t *testing.T) { 12 | edges := [][2]int{ 13 | {0, 1}, {1, 2}, {5, 7}, {6, 9}, {9, 10}, {8, 10}, {7, 8}, 14 | } 15 | want := [][]int{ 16 | {0, 1, 2}, {3}, {4}, {5, 6, 7, 8, 9, 10}, {11}, 17 | } 18 | g := New[int]() 19 | for i := range 12 { 20 | g.AddVertices(i) 21 | } 22 | for _, e := range edges { 23 | g.AddEdge(e[0], e[1]) 24 | } 25 | got := g.ConnectedComponents() 26 | if !reflect.DeepEqual(got, want) { 27 | t.Fatalf("components(...)=%v, want %v", got, want) 28 | } 29 | } 30 | 31 | func TestComponents_string(t *testing.T) { 32 | edges := [][2]string{ 33 | {"a", "bb"}, {"eeeee", "dddd"}, {"bb", "ccc"}, {"dddd", "eeeee"}, 34 | } 35 | want := [][]string{ 36 | {"ffffff"}, {"a", "bb", "ccc"}, {"eeeee", "dddd"}, 37 | } 38 | g := New[string]() 39 | g.AddVertices("ffffff") 40 | for _, e := range edges { 41 | g.AddEdge(e[0], e[1]) 42 | } 43 | got := 
g.ConnectedComponents() 44 | if !reflect.DeepEqual(got, want) { 45 | t.Fatalf("components(...)=%v, want %v", got, want) 46 | } 47 | } 48 | 49 | func TestVerticesEdges(t *testing.T) { 50 | vertices := []string{"ffffff", "bb"} 51 | edges := [][2]string{ 52 | {"a", "bb"}, {"eeeee", "dddd"}, {"bb", "ccc"}, {"dddd", "eeeee"}, 53 | } 54 | 55 | wantVertices := []string{"a", "bb", "ccc", "dddd", "eeeee", "ffffff"} 56 | wantEdges := [][2]string{ 57 | {"bb", "a"}, {"bb", "ccc"}, {"eeeee", "dddd"}, 58 | } 59 | 60 | g := New[string]() 61 | g.AddVertices(vertices...) 62 | for _, e := range edges { 63 | g.AddEdge(e[0], e[1]) 64 | } 65 | 66 | gotVertices := snm.Sorted(slices.Collect(g.Vertices())) 67 | if !slices.Equal(gotVertices, wantVertices) { 68 | t.Errorf("Vertices()=%q, want %q", gotVertices, wantVertices) 69 | } 70 | 71 | var gotEdges [][2]string 72 | for a, b := range g.Edges() { 73 | gotEdges = append(gotEdges, [2]string{a, b}) 74 | } 75 | slices.SortFunc(gotEdges, func(a, b [2]string) int { 76 | return slices.Compare(a[:], b[:]) 77 | }) 78 | if !slices.Equal(gotEdges, wantEdges) { 79 | t.Errorf("Edges()=%q, want %q", gotEdges, wantEdges) 80 | } 81 | } 82 | -------------------------------------------------------------------------------- /bits/bits.go: -------------------------------------------------------------------------------- 1 | // Package bits provides operations on bit arrays. 2 | package bits 3 | 4 | import ( 5 | "iter" 6 | mbits "math/bits" 7 | 8 | "golang.org/x/exp/constraints" 9 | ) 10 | 11 | // Set sets the n'th bit to 1 or 0 for values of true or false 12 | // respectively. 13 | func Set[I constraints.Integer](data []byte, n I, value bool) { 14 | if value { 15 | Set1(data, n) 16 | } else { 17 | Set0(data, n) 18 | } 19 | } 20 | 21 | // Set1 sets the n'th bit in data to 1. 22 | func Set1[I constraints.Integer](data []byte, n I) { 23 | data[n/8] |= 1 << (n % 8) 24 | } 25 | 26 | // Set0 sets the n'th bit in data to 0. 
27 | func Set0[I constraints.Integer](data []byte, n I) { 28 | data[n/8] &= ^(1 << (n % 8)) 29 | } 30 | 31 | // Get returns the value of the n'th bit (0 or 1). 32 | func Get[I constraints.Integer](data []byte, n I) int { 33 | return int((data[n/8] >> (n % 8)) & 1) 34 | } 35 | 36 | // Sum returns the number of bits that have a value of 1. 37 | func Sum(data []byte) int { 38 | a := 0 39 | for _, b := range data { 40 | a += mbits.OnesCount8(b) 41 | } 42 | return a 43 | } 44 | 45 | // Ones iterates over the indexes of bits whose values are 1. 46 | func Ones(data []byte) iter.Seq[int] { 47 | return func(yield func(int) bool) { 48 | for i, x := range data { 49 | for _, b := range byteOnes[x] { 50 | if !yield(i*8 + b) { 51 | return 52 | } 53 | } 54 | } 55 | } 56 | } 57 | 58 | // Zeros iterates over the indexes of bits whose values are 0. 59 | func Zeros(data []byte) iter.Seq[int] { 60 | return func(yield func(int) bool) { 61 | for i, x := range data { 62 | for _, b := range byteZeros[x] { 63 | if !yield(i*8 + b) { 64 | return 65 | } 66 | } 67 | } 68 | } 69 | } 70 | 71 | // Indexes of ones for each byte value. 72 | var byteOnes [][]int 73 | 74 | // Indexes of zeros for each byte value. 75 | var byteZeros [][]int 76 | 77 | // Calculates indexes of ones and zeros for each byte value. 
78 | func init() { 79 | byteOnes = make([][]int, 256) 80 | byteZeros = make([][]int, 256) 81 | for i := range 256 { 82 | var ones, zeros []int 83 | for b := range 8 { 84 | if (i>>b)&1 == 1 { 85 | ones = append(ones, b) 86 | } else { 87 | zeros = append(zeros, b) 88 | } 89 | } 90 | byteOnes[i] = ones 91 | byteZeros[i] = zeros 92 | } 93 | } 94 | -------------------------------------------------------------------------------- /rhash/rabin32.go: -------------------------------------------------------------------------------- 1 | package rhash 2 | 3 | import ( 4 | "fmt" 5 | "hash" 6 | ) 7 | 8 | var _ hash.Hash32 = &RabinFingerprint32{} 9 | 10 | // RabinFingerprint32 implements a Rabin fingerprint rolling-hash. 11 | // Implements [hash.Hash32]. 12 | type RabinFingerprint32 struct { 13 | h, pow uint32 14 | i int 15 | hist []byte 16 | } 17 | 18 | // NewRabinFingerprint32 returns a new rolling hash with a window size of n. 19 | func NewRabinFingerprint32(n int) *RabinFingerprint32 { 20 | if n < 1 { 21 | panic(fmt.Sprintf("bad n: %d", n)) 22 | } 23 | return &RabinFingerprint32{0, 1, 0, make([]byte, n)} 24 | } 25 | 26 | // Write updates the hash with the given bytes. Always returns len(data), nil. 27 | func (h *RabinFingerprint32) Write(data []byte) (int, error) { 28 | for _, b := range data { 29 | h.WriteByte(b) 30 | } 31 | return len(data), nil 32 | } 33 | 34 | // WriteByte updates the hash with the given byte. Always returns nil. 35 | func (h *RabinFingerprint32) WriteByte(b byte) error { 36 | h.h = h.h*rabinPrime + uint32(b) 37 | i := h.i % len(h.hist) 38 | h.h -= h.pow * uint32(h.hist[i]) 39 | h.hist[i] = b 40 | if h.i < len(h.hist) { 41 | h.pow *= rabinPrime 42 | } 43 | h.i++ 44 | return nil 45 | } 46 | 47 | // Sum appends the current hash to b and returns the resulting slice. 
48 | func (h *RabinFingerprint32) Sum(b []byte) []byte { 49 | s := h.Sum32() 50 | n := h.Size() 51 | for i := 0; i < n; i++ { 52 | b = append(b, byte(s)) 53 | s >>= 8 54 | } 55 | return b 56 | } 57 | 58 | // Sum32 returns the current hash. 59 | func (h *RabinFingerprint32) Sum32() uint32 { 60 | return h.h 61 | } 62 | 63 | // Size returns the number of bytes Sum will return, which is four. 64 | func (h *RabinFingerprint32) Size() int { 65 | return 4 66 | } 67 | 68 | // BlockSize returns the hash's block size, which is one. 69 | func (h *RabinFingerprint32) BlockSize() int { 70 | return 1 71 | } 72 | 73 | // Reset resets the hash to its initial state. 74 | func (h *RabinFingerprint32) Reset() { 75 | h.h = 0 76 | h.i = 0 77 | h.pow = 1 78 | for i := range h.hist { 79 | h.hist[i] = 0 80 | } 81 | } 82 | 83 | // RabinFingerprintSum32 returns the Rabin fingerprint of data. 84 | func RabinFingerprintSum32(data []byte) uint32 { 85 | if len(data) == 0 { 86 | return 0 87 | } 88 | h := NewRabinFingerprint32(len(data)) 89 | h.Write(data) 90 | return h.Sum32() 91 | } 92 | -------------------------------------------------------------------------------- /clustering/rand.go: -------------------------------------------------------------------------------- 1 | package clustering 2 | 3 | import ( 4 | "fmt" 5 | "slices" 6 | 7 | "github.com/fluhus/gostuff/sets" 8 | ) 9 | 10 | // AdjustedRandIndex compares 2 taggings of the data for similarity. A score of 11 | // 1 means identical, a score of 0 means as good as random, and a negative 12 | // score means worse than random. 13 | func AdjustedRandIndex(tags1, tags2 []int) float64 { 14 | // Check input. 
15 | if len(tags1) != len(tags2) { 16 | panic(fmt.Sprintf("Mismatching lengths: %d, %d", 17 | len(tags1), len(tags2))) 18 | } 19 | 20 | sets1 := tagsToSets(tags1) 21 | sets2 := tagsToSets(tags2) 22 | 23 | r := randIndex(sets1, sets2) 24 | e := expectedRandIndex(sets1, sets2) 25 | m := maxRandIndex(sets1, sets2) 26 | return (r - e) / (m - e) 27 | } 28 | 29 | // randIndex returns the RI part of the adjusted index. 30 | func randIndex(tags1, tags2 [][]int) float64 { 31 | r := 0 32 | for _, t1 := range tags1 { 33 | for _, t2 := range tags2 { 34 | r += choose2(sets.SortedIntersectionLen(t1, t2)) 35 | } 36 | } 37 | return float64(r) 38 | } 39 | 40 | // expectedRandIndex returns the expected index according to hypergeometrical 41 | // distribution. 42 | func expectedRandIndex(tags1, tags2 [][]int) float64 { 43 | p1 := 0 44 | n := 0 45 | for _, tags := range tags1 { 46 | n += len(tags) 47 | p1 += choose2(len(tags)) 48 | } 49 | p2 := 0 50 | for _, tags := range tags2 { 51 | p2 += choose2(len(tags)) 52 | } 53 | p := float64(choose2(n)) 54 | return float64(p1) * float64(p2) / p 55 | } 56 | 57 | // maxRandIndex returns the maximal possible index. 58 | func maxRandIndex(tags1, tags2 [][]int) float64 { 59 | p := 0 60 | for _, tags := range tags1 { 61 | p += choose2(len(tags)) 62 | } 63 | for _, tags := range tags2 { 64 | p += choose2(len(tags)) 65 | } 66 | return float64(p) / 2 67 | } 68 | 69 | func choose2(n int) int { 70 | return n * (n - 1) / 2 71 | } 72 | 73 | // tagsToSets converts a list of tags to a list of sets of indexes, one list 74 | // for each tag. 75 | func tagsToSets(tags []int) [][]int { 76 | // Make map from tag to its set. 77 | sets := map[int][]int{} 78 | for i, tag := range tags { 79 | sets[tag] = append(sets[tag], i) 80 | } 81 | 82 | // Convert map to slice. 
// rabinPrime is the multiplier applied to the running hash for each new byte.
const rabinPrime = 16777619

var _ hash.Hash64 = &RabinFingerprint64{}

// RabinFingerprint64 is a rolling hash that uses the Rabin fingerprint
// scheme with a 64-bit state. It satisfies [hash.Hash64].
type RabinFingerprint64 struct {
	h, pow uint64 // Current hash, and rabinPrime to the power of bytes seen (capped at the window size).
	i      int    // Number of bytes written so far.
	hist   []byte // Circular buffer of the bytes currently in the window.
}

// NewRabinFingerprint64 creates a rolling hash whose window holds n bytes.
// Panics if n is less than 1.
func NewRabinFingerprint64(n int) *RabinFingerprint64 {
	if n < 1 {
		panic(fmt.Sprintf("bad n: %d", n))
	}
	return &RabinFingerprint64{pow: 1, hist: make([]byte, n)}
}

// Write feeds data into the hash, byte by byte. It never fails: the result
// is always len(data), nil.
func (h *RabinFingerprint64) Write(data []byte) (int, error) {
	for i := range data {
		h.WriteByte(data[i])
	}
	return len(data), nil
}

// WriteByte feeds a single byte into the hash. The result is always nil.
func (h *RabinFingerprint64) WriteByte(b byte) error {
	slot := h.i % len(h.hist)
	evicted := uint64(h.hist[slot])
	// Shift the window by one byte: multiply in the new byte and cancel the
	// contribution of the byte that falls out of the window. Unsigned
	// arithmetic wraps, so the order of the add and subtract is irrelevant.
	h.h = h.h*rabinPrime + uint64(b) - h.pow*evicted
	h.hist[slot] = b
	// pow stops growing once the window has been filled for the first time.
	if h.i < len(h.hist) {
		h.pow *= rabinPrime
	}
	h.i++
	return nil
}

// Sum appends the current hash to b, least significant byte first, and
// returns the extended slice.
func (h *RabinFingerprint64) Sum(b []byte) []byte {
	s := h.Sum64()
	for shift := 0; shift < 8*h.Size(); shift += 8 {
		b = append(b, byte(s>>shift))
	}
	return b
}

// Sum64 returns the hash of the bytes currently in the window.
func (h *RabinFingerprint64) Sum64() uint64 {
	return h.h
}

// Size returns the length in bytes of the value appended by Sum: eight.
func (h *RabinFingerprint64) Size() int {
	return 8
}

// BlockSize returns the block size of the hash: one.
func (h *RabinFingerprint64) BlockSize() int {
	return 1
}

// Reset brings the hash back to the state it had right after construction.
func (h *RabinFingerprint64) Reset() {
	h.h, h.i, h.pow = 0, 0, 1
	clear(h.hist)
}

// RabinFingerprintSum64 hashes all of data with a window as long as data.
// The hash of empty data is 0.
func RabinFingerprintSum64(data []byte) uint64 {
	if len(data) == 0 {
		return 0
	}
	fp := NewRabinFingerprint64(len(data))
	fp.Write(data)
	return fp.Sum64()
}
24 | continue 25 | } 26 | next = (next + 1) * 21 / 20 27 | ratio := float64(i) / float64(hll.ApproxCount()) 28 | ratioSum += math.Abs(math.Log(ratio)) 29 | ratioCount++ 30 | } 31 | avg := math.Exp(ratioSum / ratioCount) 32 | want := 1.003 33 | if avg > want { 34 | t.Errorf("average error=%f, want at most %f", 35 | avg, want) 36 | } 37 | } 38 | 39 | func TestCount_zero(t *testing.T) { 40 | hll := newIntHLL() 41 | if count := hll.ApproxCount(); count != 0 { 42 | t.Fatalf("ApproxCount()=%v, want 0", count) 43 | } 44 | } 45 | 46 | func TestAddHLL(t *testing.T) { 47 | hll1 := newIntHLL() 48 | for i := 1; i <= 5; i++ { 49 | hll1.Add(i) 50 | } 51 | if count := hll1.ApproxCount(); count != 5 { 52 | t.Fatalf("ApproxCount()=%v, want 5", count) 53 | } 54 | 55 | hll2 := newIntHLL() 56 | for i := 4; i <= 9; i++ { 57 | hll2.Add(i) 58 | } 59 | if count := hll2.ApproxCount(); count != 6 { 60 | t.Fatalf("ApproxCount()=%v, want 6", count) 61 | } 62 | 63 | hll1.AddHLL(hll2) 64 | if count := hll1.ApproxCount(); count != 9 { 65 | t.Fatalf("ApproxCount()=%v, want 9", count) 66 | } 67 | } 68 | 69 | func BenchmarkAdd(b *testing.B) { 70 | hll := New(16, func(i int) uint64 { return uint64(i) }) 71 | for i := 0; i < b.N; i++ { 72 | hll.Add(i) 73 | } 74 | } 75 | 76 | func BenchmarkAdd_intHLL(b *testing.B) { 77 | hll := newIntHLL() 78 | for i := 0; i < b.N; i++ { 79 | hll.Add(i) 80 | } 81 | } 82 | 83 | func BenchmarkCount(b *testing.B) { 84 | const nelements = 1000000 85 | hll := newIntHLL() 86 | for i := 0; i < nelements; i++ { 87 | hll.Add(i) 88 | } 89 | b.Run("", func(b *testing.B) { 90 | for i := 0; i < b.N; i++ { 91 | hll.ApproxCount() 92 | } 93 | }) 94 | } 95 | 96 | func newIntHLL() *HLL[int] { 97 | h := murmur3.New64() 98 | w := bnry.NewWriter(h) 99 | return New(16, func(i int) uint64 { 100 | h.Reset() 101 | w.Write(i) 102 | return h.Sum64() 103 | }) 104 | } 105 | -------------------------------------------------------------------------------- /ppln/nserial.go: 
-------------------------------------------------------------------------------- 1 | package ppln 2 | 3 | import ( 4 | "fmt" 5 | "iter" 6 | "sync" 7 | "sync/atomic" 8 | ) 9 | 10 | // NonSerial starts a multi-goroutine transformation pipeline. 11 | // 12 | // Input is an iterator over the input values to be transformed. 13 | // It will be called in a thread-safe manner. 14 | // Transform receives an input (a) and a 0-based goroutine number (g), 15 | // and returns the result of processing a. 16 | // Output acts on a single result, and will be called in a thread-safe manner. 17 | // The order of outputs is arbitrary, but correlated with the order of 18 | // inputs. 19 | // 20 | // If one of the functions returns a non-nil error, the process stops and the 21 | // error is returned. Otherwise returns nil. 22 | func NonSerial[T1 any, T2 any]( 23 | ngoroutines int, 24 | input iter.Seq2[T1, error], 25 | transform func(a T1, g int) (T2, error), 26 | output func(a T2) error) error { 27 | if ngoroutines < 1 { 28 | panic(fmt.Sprintf("bad number of goroutines: %d", ngoroutines)) 29 | } 30 | pull, pstop := iter.Pull2(input) 31 | defer pstop() 32 | 33 | // An optimization for a single thread. 
34 | if ngoroutines == 1 { 35 | for { 36 | t1, err, ok := pull() 37 | 38 | if !ok { 39 | return nil 40 | } 41 | if err != nil { 42 | return err 43 | } 44 | 45 | t2, err := transform(t1, 0) 46 | if err != nil { 47 | return err 48 | } 49 | if err := output(t2); err != nil { 50 | return err 51 | } 52 | } 53 | } 54 | 55 | ilock := &sync.Mutex{} 56 | olock := &sync.Mutex{} 57 | errs := make(chan error, ngoroutines) 58 | stop := &atomic.Bool{} 59 | 60 | for g := 0; g < ngoroutines; g++ { 61 | go func(g int) { 62 | for { 63 | if stop.Load() { 64 | errs <- nil 65 | return 66 | } 67 | 68 | ilock.Lock() 69 | t1, err, ok := pull() 70 | ilock.Unlock() 71 | 72 | if !ok { 73 | errs <- nil 74 | return 75 | } 76 | if err != nil { 77 | stop.Store(true) 78 | errs <- err 79 | return 80 | } 81 | 82 | t2, err := transform(t1, g) 83 | if err != nil { 84 | stop.Store(true) 85 | errs <- err 86 | return 87 | } 88 | 89 | olock.Lock() 90 | err = output(t2) 91 | olock.Unlock() 92 | if err != nil { 93 | stop.Store(true) 94 | errs <- err 95 | return 96 | } 97 | } 98 | }(g) 99 | } 100 | 101 | for g := 0; g < ngoroutines; g++ { 102 | if err := <-errs; err != nil { 103 | return err 104 | } 105 | } 106 | return nil 107 | } 108 | -------------------------------------------------------------------------------- /xmlnode/read.go: -------------------------------------------------------------------------------- 1 | // Package xmlnode provides a hierarchical node representation of XML documents. 2 | // This package wraps encoding/xml and can be used instead of it. 3 | // 4 | // Each node has an underlying concrete type, but calling all functions is 5 | // legal. For example, here is how you can traverse the node tree: 6 | // 7 | // func traverse(n Node) { 8 | // // Text() returns an empty string for non-text nodes. 9 | // doSomeTextSearch(n.Text()) 10 | // 11 | // // Children() returns nil for non-parent nodes. 
12 | // for _, child := range n.Children() { 13 | // traverse(child) 14 | // } 15 | // } 16 | package xmlnode 17 | 18 | import ( 19 | "encoding/xml" 20 | "io" 21 | ) 22 | 23 | // ReadAll reads all XML data from the given reader and stores it in a root node. 24 | func ReadAll(r io.Reader) (Node, error) { 25 | // Create root node. 26 | // Starting with Tag instead of Root, to eliminate type checks when referring 27 | // to parent nodes during reading. Will be replaced with a Root node at the 28 | // end. 29 | result := &tag{ 30 | nil, 31 | "", 32 | nil, 33 | nil, 34 | } 35 | dec := xml.NewDecoder(r) 36 | 37 | var t xml.Token 38 | var err error 39 | current := result 40 | 41 | // Parse tokens. 42 | for t, err = dec.Token(); err == nil; t, err = dec.Token() { 43 | switch t := t.(type) { 44 | case xml.StartElement: 45 | // Copy attributes. 46 | attrs := make([]*xml.Attr, len(t.Attr)) 47 | for i, attr := range t.Attr { 48 | attrs[i] = &xml.Attr{Name: attr.Name, Value: attr.Value} 49 | } 50 | 51 | // Create child node. 52 | child := &tag{ 53 | current, 54 | t.Name.Local, 55 | attrs, 56 | nil, 57 | } 58 | 59 | current.children = append(current.children, child) 60 | current = child 61 | 62 | case xml.EndElement: 63 | current = current.Parent().(*tag) 64 | 65 | case xml.CharData: 66 | child := &text{ 67 | current, 68 | string(t), 69 | } 70 | 71 | current.children = append(current.children, child) 72 | 73 | case xml.Comment: 74 | child := &comment{ 75 | current, 76 | string(t), 77 | } 78 | 79 | current.children = append(current.children, child) 80 | 81 | case xml.ProcInst: 82 | child := &procInst{ 83 | current, 84 | string(t.Target), 85 | string(t.Inst), 86 | } 87 | 88 | current.children = append(current.children, child) 89 | 90 | case xml.Directive: 91 | child := &directive{ 92 | current, 93 | string(t), 94 | } 95 | 96 | current.children = append(current.children, child) 97 | } 98 | } 99 | 100 | // EOF is ok. 
101 | if err != io.EOF { 102 | return nil, err 103 | } 104 | 105 | return &root{result.children}, nil 106 | } 107 | -------------------------------------------------------------------------------- /iterx/lines.go: -------------------------------------------------------------------------------- 1 | // Package iterx provides convenience functions for iterators. 2 | package iterx 3 | 4 | import ( 5 | "bufio" 6 | "encoding/csv" 7 | "io" 8 | "iter" 9 | 10 | "github.com/fluhus/gostuff/aio" 11 | ) 12 | 13 | // LinesReader iterates over text lines from a reader. 14 | func LinesReader(r io.Reader) iter.Seq2[string, error] { 15 | return func(yield func(string, error) bool) { 16 | sc := bufio.NewScanner(r) 17 | for sc.Scan() { 18 | if !yield(sc.Text(), nil) { 19 | return 20 | } 21 | } 22 | if err := sc.Err(); err != nil { 23 | yield("", err) 24 | } 25 | } 26 | } 27 | 28 | // LinesFile iterates over text lines from a reader. 29 | func LinesFile(file string) iter.Seq2[string, error] { 30 | return func(yield func(string, error) bool) { 31 | f, err := aio.Open(file) 32 | if err != nil { 33 | yield("", err) 34 | return 35 | } 36 | defer f.Close() 37 | sc := bufio.NewScanner(f) 38 | for sc.Scan() { 39 | if !yield(sc.Text(), nil) { 40 | return 41 | } 42 | } 43 | if err := sc.Err(); err != nil { 44 | yield("", err) 45 | } 46 | } 47 | } 48 | 49 | // CSVReader iterates over CSV entries from a reader. 50 | // fn is an optional function for modifying the CSV parser, 51 | // for example for changing the delimiter. 52 | // 53 | // Deprecated: use package csvx. 54 | func CSVReader(r io.Reader, fn func(*csv.Reader)) iter.Seq2[[]string, error] { 55 | return func(yield func([]string, error) bool) { 56 | c := csv.NewReader(r) 57 | if fn != nil { 58 | fn(c) 59 | } 60 | for { 61 | e, err := c.Read() 62 | if err == io.EOF { 63 | return 64 | } 65 | if !yield(e, nil) { 66 | return 67 | } 68 | } 69 | } 70 | } 71 | 72 | // CSVFile iterates over CSV entries from a file. 
73 | // fn is an optional function for modifying the CSV parser, 74 | // for example for changing the delimiter. 75 | // 76 | // Deprecated: use package csvx. 77 | func CSVFile(file string, fn func(*csv.Reader)) iter.Seq2[[]string, error] { 78 | return func(yield func([]string, error) bool) { 79 | f, err := aio.Open(file) 80 | if err != nil { 81 | yield(nil, err) 82 | return 83 | } 84 | defer f.Close() 85 | c := csv.NewReader(f) 86 | if fn != nil { 87 | fn(c) 88 | } 89 | for { 90 | e, err := c.Read() 91 | if err == io.EOF { 92 | return 93 | } 94 | if !yield(e, nil) { 95 | return 96 | } 97 | } 98 | } 99 | } 100 | 101 | // CollectErr collects the given T's in a slice until the error is non-nil. 102 | func CollectErr[T any](it iter.Seq2[T, error]) ([]T, error) { 103 | var a []T 104 | for t, err := range it { 105 | if err != nil { 106 | return a, err 107 | } 108 | a = append(a, t) 109 | } 110 | return a, nil 111 | } 112 | -------------------------------------------------------------------------------- /ppln/nserial_test.go: -------------------------------------------------------------------------------- 1 | package ppln 2 | 3 | import ( 4 | "fmt" 5 | "math" 6 | "testing" 7 | ) 8 | 9 | func TestNonSerial(t *testing.T) { 10 | want := 21082009.0 11 | for _, nt := range []int{1, 2, 4, 8} { 12 | t.Run(fmt.Sprint(nt), func(t *testing.T) { 13 | got := 0.0 14 | NonSerial[int, float64]( 15 | nt, 16 | RangeInput(1, 100001), 17 | func(a, g int) (float64, error) { 18 | return math.Sqrt(float64(a)), nil 19 | }, 20 | func(a float64) error { 21 | got += a 22 | return nil 23 | }, 24 | ) 25 | if math.Round(got) != want { 26 | t.Fatalf("NonSerial: got %f, want %f", got, want) 27 | } 28 | }) 29 | } 30 | } 31 | 32 | func TestNonSerial_inputError(t *testing.T) { 33 | for _, nt := range []int{1, 2, 4, 8} { 34 | t.Run(fmt.Sprint(nt), func(t *testing.T) { 35 | got := 0.0 36 | err := NonSerial[int, float64]( 37 | nt, 38 | func(yield func(int, error) bool) { 39 | for i, err := range 
RangeInput(1, 100001) { 40 | if i == 1000 { 41 | yield(0, fmt.Errorf("oh no")) 42 | return 43 | } 44 | if !yield(i, err) { 45 | return 46 | } 47 | } 48 | }, 49 | func(a, g int) (float64, error) { 50 | return math.Sqrt(float64(a)), nil 51 | }, 52 | func(a float64) error { 53 | got += a 54 | return nil 55 | }, 56 | ) 57 | if err == nil { 58 | t.Fatalf("NonSerial succeeded, want error") 59 | } 60 | }) 61 | } 62 | } 63 | 64 | func TestNonSerial_transformError(t *testing.T) { 65 | for _, nt := range []int{1, 2, 4, 8} { 66 | t.Run(fmt.Sprint(nt), func(t *testing.T) { 67 | got := 0.0 68 | err := NonSerial[int, float64]( 69 | nt, 70 | RangeInput(1, 100001), 71 | func(a, g int) (float64, error) { 72 | if a == 1000 { 73 | return 0, fmt.Errorf("oh no") 74 | } 75 | return math.Sqrt(float64(a)), nil 76 | }, 77 | func(a float64) error { 78 | got += a 79 | return nil 80 | }, 81 | ) 82 | if err == nil { 83 | t.Fatalf("NonSerial succeeded, want error") 84 | } 85 | }) 86 | } 87 | } 88 | 89 | func TestNonSerial_outputError(t *testing.T) { 90 | for _, nt := range []int{1, 2, 4, 8} { 91 | t.Run(fmt.Sprint(nt), func(t *testing.T) { 92 | got := 0.0 93 | err := NonSerial[int, float64]( 94 | nt, 95 | RangeInput(1, 100001), 96 | func(a, g int) (float64, error) { 97 | return math.Sqrt(float64(a)), nil 98 | }, 99 | func(a float64) error { 100 | if a == 32 { 101 | return fmt.Errorf("oh no") 102 | } 103 | got += a 104 | return nil 105 | }, 106 | ) 107 | if err == nil { 108 | t.Fatalf("NonSerial succeeded, want error") 109 | } 110 | }) 111 | } 112 | } 113 | 114 | // TODO(amit): Error tests. 115 | -------------------------------------------------------------------------------- /iterx/unreader.go: -------------------------------------------------------------------------------- 1 | package iterx 2 | 3 | import ( 4 | "fmt" 5 | "iter" 6 | ) 7 | 8 | // An Unreader wraps an iterator and adds an unread function. 
9 | type Unreader[T any] struct { 10 | read func() (T, bool) // The underlying reader 11 | stop func() 12 | head T // The last read element 13 | hasHead bool // Unread was called 14 | } 15 | 16 | // Read returns the result of calling the underlying reader, 17 | // or the last unread element. 18 | func (r *Unreader[T]) Read() (T, bool) { 19 | if r.hasHead { 20 | // Remove headX values to allow GC. 21 | var t T 22 | t, r.head = r.head, t 23 | r.hasHead = false 24 | return t, true 25 | } 26 | t, ok := r.read() 27 | r.head = t 28 | return t, ok 29 | } 30 | 31 | // Unread makes the next call to Read return the last element and a nil error. 32 | // Can be called up to once per call to Read. 33 | func (r *Unreader[T]) Unread() { 34 | if r.hasHead { 35 | panic(fmt.Sprintf("called Unread twice: first with %v", r.head)) 36 | } 37 | r.hasHead = true 38 | } 39 | 40 | // Until calls Read until stop returns true. 41 | func (r *Unreader[T]) Until(stop func(T) bool) iter.Seq[T] { 42 | return func(yield func(T) bool) { 43 | for { 44 | t, ok := r.Read() 45 | if !ok { 46 | return 47 | } 48 | if stop(t) { 49 | r.Unread() 50 | return 51 | } 52 | if !yield(t) { 53 | r.stop() 54 | return 55 | } 56 | } 57 | } 58 | } 59 | 60 | // GroupBy returns an iterator over groups, where each group is an iterator that 61 | // yields elements as long as the elements have sameGroup==true with the first 62 | // element in the group. 63 | func (r *Unreader[T]) GroupBy(sameGroup func(old T, nu T) bool, 64 | ) iter.Seq[iter.Seq[T]] { 65 | return func(yield func(iter.Seq[T]) bool) { 66 | for { 67 | group, ok := r.nextGroup(sameGroup) 68 | if !ok { 69 | return 70 | } 71 | if !yield(group) { 72 | r.stop() 73 | return 74 | } 75 | } 76 | } 77 | } 78 | 79 | // Returns an iterator that yields elements while sameGroup==true. 
80 | func (r *Unreader[T]) nextGroup(sameGroup func(T, T) bool) (iter.Seq[T], bool) { 81 | firstT, ok := r.Read() 82 | if !ok { 83 | return nil, false 84 | } 85 | return func(yield func(T) bool) { 86 | if !yield(firstT) { 87 | return 88 | } 89 | for { 90 | t, ok := r.Read() 91 | if !ok { 92 | return 93 | } 94 | if !sameGroup(firstT, t) { 95 | r.Unread() 96 | return 97 | } 98 | if !yield(t) { 99 | r.stop() 100 | return 101 | } 102 | } 103 | }, true 104 | } 105 | 106 | // NewUnreader returns an Unreader with seq as its underlying iterator. 107 | func NewUnreader[T any](seq iter.Seq[T]) *Unreader[T] { 108 | read, stop := iter.Pull(seq) 109 | return &Unreader[T]{read: read, stop: stop, hasHead: false} 110 | } 111 | -------------------------------------------------------------------------------- /nlp/wordnet/types.go: -------------------------------------------------------------------------------- 1 | package wordnet 2 | 3 | // WordNet is an entire wordnet database. 4 | type WordNet struct { 5 | // Maps from synset ID to synset. 6 | Synset map[string]*Synset `json:"synset"` 7 | 8 | // Maps from pos.lemma to synset IDs that contain it. 9 | Lemma map[string][]string `json:"lemma"` 10 | 11 | // Like Lemma, but synsets are ordered from the most frequently used to the 12 | // least. Only a subset of the synsets are ranked, so LemmaRanked has less 13 | // synsets. 14 | LemmaRanked map[string][]string `json:"lemmaRanked"` 15 | 16 | // Maps from exceptional word to its forms. 17 | Exception map[string][]string `json:"exception"` 18 | 19 | // Maps from example ID to sentence template. Using string keys for JSON 20 | // compatibility. 21 | Example map[string]string `json:"example"` 22 | } 23 | 24 | // Synset is a set of synonymous words. 25 | type Synset struct { 26 | // Synset offset, also used as an identifier. 27 | Offset string `json:"offset"` 28 | 29 | // Part of speech, including 's' for adjective satellite. 30 | Pos string `json:"pos"` 31 | 32 | // Words in this synset. 
33 | Word []string `json:"word"` 34 | 35 | // Pointers to other synsets. 36 | Pointer []*Pointer `json:"pointer"` 37 | 38 | // Sentence frames for verbs. 39 | Frame []*Frame `json:"frame"` 40 | 41 | // Lexical definition. 42 | Gloss string `json:"gloss"` 43 | 44 | // Usage examples for words in this synset. Verbs only. 45 | Example []*Example `json:"example"` 46 | } 47 | 48 | // A Frame links a synset word to a generic phrase that illustrates how to use 49 | // it. Applies to verbs only. 50 | // 51 | // See the list of frames here: 52 | // https://wordnet.princeton.edu/man/wninput.5WN.html#sect4 53 | type Frame struct { 54 | // Index of word in the containing synset, -1 for entire synset. 55 | WordNumber int `json:"wordNumber"` 56 | 57 | // Frame number on the WordNet site. 58 | FrameNumber int `json:"frameNumber"` 59 | } 60 | 61 | // A Pointer denotes a semantic relation between one synset/word to another. 62 | // 63 | // See list of pointer symbols here: 64 | // https://wordnet.princeton.edu/man/wninput.5WN.html#sect3 65 | type Pointer struct { 66 | // Relation between the 2 words. Target is to source. See 67 | // package constants for meaning of symbols. 68 | Symbol string `json:"symbol"` 69 | 70 | // Target synset ID. 71 | Synset string `json:"synset"` 72 | 73 | // Index of word in source synset, -1 for entire synset. 74 | Source int `json:"source"` 75 | 76 | // Index of word in target synset, -1 for entire synset. 77 | Target int `json:"target"` 78 | } 79 | 80 | // An Example links a synset word to an example sentence. Applies to verbs only. 81 | type Example struct { 82 | // Index of word in the containing synset. 83 | WordNumber int `json:"wordNumber"` 84 | 85 | // Number of template in the WordNet.Example field. 
86 | TemplateNumber int `json:"templateNumber"` 87 | } 88 | -------------------------------------------------------------------------------- /sets/sets.go: -------------------------------------------------------------------------------- 1 | // Package sets provides generic sets. 2 | package sets 3 | 4 | import ( 5 | "encoding/json" 6 | 7 | "golang.org/x/exp/maps" 8 | ) 9 | 10 | // Set is a convenience wrapper around map[T]struct{}. 11 | type Set[T comparable] map[T]struct{} 12 | 13 | // Add inserts the given elements to s and returns s. 14 | func (s Set[T]) Add(t ...T) Set[T] { 15 | for _, v := range t { 16 | s[v] = struct{}{} 17 | } 18 | return s 19 | } 20 | 21 | // AddSet inserts the elements of t to s and returns s. 22 | func (s Set[T]) AddSet(t Set[T]) Set[T] { 23 | for v := range t { 24 | s[v] = struct{}{} 25 | } 26 | return s 27 | } 28 | 29 | // Remove deletes the given elements from s and returns s. 30 | func (s Set[T]) Remove(t ...T) Set[T] { 31 | for _, v := range t { 32 | delete(s, v) 33 | } 34 | return s 35 | } 36 | 37 | // RemoveSet deletes the elements of t from s and returns s. 38 | func (s Set[T]) RemoveSet(t Set[T]) Set[T] { 39 | for v := range t { 40 | delete(s, v) 41 | } 42 | return s 43 | } 44 | 45 | // Has returns whether t is a member of s. 46 | func (s Set[T]) Has(t T) bool { 47 | _, ok := s[t] 48 | return ok 49 | } 50 | 51 | // Intersect returns a new set holding the elements that are common 52 | // to s and t. 53 | func (s Set[T]) Intersect(t Set[T]) Set[T] { 54 | if len(s) > len(t) { // Iterate over the smaller one. 55 | s, t = t, s 56 | } 57 | result := Set[T]{} 58 | for v := range s { 59 | if t.Has(v) { 60 | result.Add(v) 61 | } 62 | } 63 | return result 64 | } 65 | 66 | // MarshalJSON implements the json.Marshaler interface. 67 | func (s Set[T]) MarshalJSON() ([]byte, error) { 68 | return json.Marshal(maps.Keys(s)) 69 | } 70 | 71 | // UnmarshalJSON implements the json.Unmarshaler interface. 
72 | func (s *Set[T]) UnmarshalJSON(b []byte) error { 73 | var slice []T 74 | if err := json.Unmarshal(b, &slice); err != nil { 75 | return err 76 | } 77 | if *s == nil { 78 | *s = Set[T]{} 79 | } 80 | s.Add(slice...) 81 | return nil 82 | } 83 | 84 | // AddKeys adds the keys of a map to a set. 85 | func AddKeys[K comparable, V any](s Set[K], m map[K]V) Set[K] { 86 | for k := range m { 87 | s.Add(k) 88 | } 89 | return s 90 | } 91 | 92 | // AddValues adds the values of a map to a set. 93 | func AddValues[K comparable, V comparable](s Set[V], m map[K]V) Set[V] { 94 | for _, v := range m { 95 | s.Add(v) 96 | } 97 | return s 98 | } 99 | 100 | // Of returns a new set containing the given elements. 101 | func Of[T comparable](t ...T) Set[T] { 102 | return make(Set[T], len(t)).Add(t...) 103 | } 104 | 105 | // FromKeys returns a new set containing the keys of the given map. 106 | func FromKeys[K comparable, V any](m map[K]V) Set[K] { 107 | return AddKeys(make(Set[K], len(m)), m) 108 | } 109 | 110 | // FromValues returns a new set containing the values of the given map. 
111 | func FromValues[K comparable, V comparable](m map[K]V) Set[V] { 112 | return AddValues(Set[V]{}, m) 113 | } 114 | -------------------------------------------------------------------------------- /heaps/heaps_test.go: -------------------------------------------------------------------------------- 1 | package heaps 2 | 3 | import ( 4 | "fmt" 5 | "math/rand/v2" 6 | "slices" 7 | "testing" 8 | 9 | "github.com/fluhus/gostuff/snm" 10 | ) 11 | 12 | func TestHeap(t *testing.T) { 13 | input := []string{"bb", "a", "ffff", "ddddd"} 14 | want := []string{"a", "bb", "ddddd", "ffff"} 15 | h := Min[string]() 16 | for _, v := range input { 17 | h.Push(v) 18 | } 19 | if ln := h.Len(); ln != len(input) { 20 | t.Fatalf("Len()=%d, want %d", ln, len(input)) 21 | } 22 | var got []string 23 | for h.Len() > 0 { 24 | got = append(got, h.Pop()) 25 | } 26 | if !slices.Equal(got, want) { 27 | t.Fatalf("Pop=%v, want %v", got, want) 28 | } 29 | } 30 | 31 | func TestHeap_big(t *testing.T) { 32 | input := []int{ 33 | 5, 8, 25, 21, 22, 15, 13, 20, 1, 14, 34 | 24, 12, 7, 18, 27, 3, 30, 28, 23, 29, 35 | 19, 2, 6, 4, 26, 9, 17, 10, 11, 16, 36 | } 37 | want := []int{ 38 | 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 39 | 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 40 | 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 41 | } 42 | h := New(func(i1, i2 int) bool { 43 | return i1 < i2 44 | }) 45 | for _, v := range input { 46 | h.Push(v) 47 | } 48 | if ln := h.Len(); ln != len(input) { 49 | t.Fatalf("Len()=%d, want %d", ln, len(input)) 50 | } 51 | var got []int 52 | for h.Len() > 0 { 53 | got = append(got, h.Pop()) 54 | } 55 | if !slices.Equal(got, want) { 56 | t.Fatalf("Pop=%v, want %v", got, want) 57 | } 58 | } 59 | 60 | func TestHeap_pushSlice(t *testing.T) { 61 | input := []int{ 62 | 5, 8, 25, 21, 22, 15, 13, 20, 1, 14, 63 | 24, 12, 7, 18, 27, 3, 30, 28, 23, 29, 64 | 19, 2, 6, 4, 26, 9, 17, 10, 11, 16, 65 | } 66 | h := Min[int]() 67 | h.PushSlice(input) 68 | for i := range h.a { 69 | if i == 0 { 70 | continue 71 | } 72 | 
ia := (i - 1) / 2 73 | if h.a[i] < h.a[ia] { 74 | t.Errorf("h[%d] < h[%d]: %d < %d", i, ia, h.a[i], h.a[ia]) 75 | } 76 | } 77 | } 78 | 79 | func Benchmark(b *testing.B) { 80 | for _, n := range []int{1000, 10000, 100000, 1000000} { 81 | nums := snm.Slice(n, func(i int) int { 82 | return rand.Int() 83 | }) 84 | b.Run(fmt.Sprint("Heap.Push.", n), func(b *testing.B) { 85 | for range b.N { 86 | h := Min[int]() 87 | for _, i := range nums { 88 | h.Push(i) 89 | } 90 | } 91 | }) 92 | b.Run(fmt.Sprint("Heap.PushSlice.", n), func(b *testing.B) { 93 | for range b.N { 94 | h := Min[int]() 95 | h.PushSlice(nums) 96 | } 97 | }) 98 | } 99 | } 100 | 101 | func FuzzHeap(f *testing.F) { 102 | f.Add(0, 0, 0, 0, 0, 0, 0) 103 | f.Fuzz(func(t *testing.T, a, b, c, d, e, f, g int) { 104 | h := Min[int]() 105 | h.Push(a) 106 | h.Push(b) 107 | h.Push(c) 108 | h.Push(d) 109 | h.Push(e) 110 | h.Push(f) 111 | h.Push(g) 112 | got := make([]int, 0, 7) 113 | for h.Len() > 0 { 114 | got = append(got, h.Pop()) 115 | } 116 | if !slices.IsSorted(got) { 117 | t.Fatalf("Min().Pop()=%v, want sorted", got) 118 | } 119 | }) 120 | } 121 | -------------------------------------------------------------------------------- /iterx/unreader2.go: -------------------------------------------------------------------------------- 1 | package iterx 2 | 3 | import ( 4 | "fmt" 5 | "iter" 6 | ) 7 | 8 | // An Unreader2 wraps an iterator and adds an unread function. 9 | type Unreader2[T, S any] struct { 10 | read func() (T, S, bool) // The underlying reader 11 | stop func() 12 | headT T // The last read element 13 | headS S // The last read element 14 | hasHead bool // Unread was called 15 | } 16 | 17 | // Read returns the result of calling the underlying reader, 18 | // or the last unread element. 19 | func (r *Unreader2[T, S]) Read() (T, S, bool) { 20 | if r.hasHead { 21 | // Remove headX values to allow GC. 
22 | var t T 23 | var s S 24 | t, r.headT = r.headT, t 25 | s, r.headS = r.headS, s 26 | r.hasHead = false 27 | return t, s, true 28 | } 29 | t, s, ok := r.read() 30 | r.headT = t 31 | r.headS = s 32 | return t, s, ok 33 | } 34 | 35 | // Unread makes the next call to Read return the last element and a nil error. 36 | // Can be called up to once per call to Read. 37 | func (r *Unreader2[T, S]) Unread() { 38 | if r.hasHead { 39 | panic(fmt.Sprintf("called Unread twice: first with (%v,%v)", 40 | r.headT, r.headS)) 41 | } 42 | r.hasHead = true 43 | } 44 | 45 | // Until calls Read until stop returns true. 46 | func (r *Unreader2[T, S]) Until(stop func(T, S) bool) iter.Seq2[T, S] { 47 | return func(yield func(T, S) bool) { 48 | for { 49 | t, s, ok := r.Read() 50 | if !ok { 51 | return 52 | } 53 | if stop(t, s) { 54 | r.Unread() 55 | return 56 | } 57 | if !yield(t, s) { 58 | r.stop() 59 | return 60 | } 61 | } 62 | } 63 | } 64 | 65 | // GroupBy returns an iterator over groups, where each group is an iterator that 66 | // yields elements as long as the elements have sameGroup==true with the first 67 | // element in the group. 68 | func (r *Unreader2[T, S]) GroupBy( 69 | sameGroup func(oldT T, oldS S, newT T, newS S) bool, 70 | ) iter.Seq[iter.Seq2[T, S]] { 71 | return func(yield func(iter.Seq2[T, S]) bool) { 72 | for { 73 | group, ok := r.nextGroup(sameGroup) 74 | if !ok { 75 | return 76 | } 77 | if !yield(group) { 78 | r.stop() 79 | return 80 | } 81 | } 82 | } 83 | } 84 | 85 | // Returns an iterator that yields elements while sameGroup==true. 
86 | func (r *Unreader2[T, S]) nextGroup(sameGroup func(T, S, T, S) bool) (iter.Seq2[T, S], bool) { 87 | firstT, firstS, ok := r.Read() 88 | if !ok { 89 | return nil, false 90 | } 91 | return func(yield func(T, S) bool) { 92 | if !yield(firstT, firstS) { 93 | return 94 | } 95 | for { 96 | t, s, ok := r.Read() 97 | if !ok { 98 | return 99 | } 100 | if !sameGroup(firstT, firstS, t, s) { 101 | r.Unread() 102 | return 103 | } 104 | if !yield(t, s) { 105 | r.stop() 106 | return 107 | } 108 | } 109 | }, true 110 | } 111 | 112 | // NewUnreader2 returns an Unreader2 with seq as its underlying iterator. 113 | func NewUnreader2[T, S any](seq iter.Seq2[T, S]) *Unreader2[T, S] { 114 | read, stop := iter.Pull2(seq) 115 | return &Unreader2[T, S]{read: read, stop: stop, hasHead: false} 116 | } 117 | -------------------------------------------------------------------------------- /ptimer/ptimer_test.go: -------------------------------------------------------------------------------- 1 | package ptimer 2 | 3 | import ( 4 | "bytes" 5 | "fmt" 6 | "io" 7 | "regexp" 8 | "slices" 9 | "testing" 10 | "time" 11 | ) 12 | 13 | const timePattern = "\\d\\d:\\d\\d:\\d\\d\\.\\d\\d\\d\\d\\d\\d" 14 | 15 | func TestNew(t *testing.T) { 16 | want := "^" 17 | for _, i := range []int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 20, 30, 35} { 18 | want += fmt.Sprintf("\r%s \\(%s\\) %d", timePattern, timePattern, i) 19 | } 20 | want += "\n$" 21 | 22 | got := bytes.NewBuffer(nil) 23 | pt := New() 24 | pt.W = got 25 | for i := 0; i < 35; i++ { 26 | pt.Inc() 27 | } 28 | pt.Done() 29 | 30 | if match, _ := regexp.MatchString(want, got.String()); !match { 31 | t.Fatalf("Inc()+Done()=%q, want %q", got.String(), want) 32 | } 33 | } 34 | 35 | func TestNewMessage(t *testing.T) { 36 | want := "^" 37 | for _, i := range []int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 20, 30, 35} { 38 | want += fmt.Sprintf("\r%s \\(%s\\) hey %d ho", 39 | timePattern, timePattern, i) 40 | } 41 | want += "\n$" 42 | 43 | got := bytes.NewBuffer(nil) 44 | pt := 
NewMessage("hey {} ho") 45 | pt.W = got 46 | for i := 0; i < 35; i++ { 47 | pt.Inc() 48 | } 49 | pt.Done() 50 | 51 | if match, _ := regexp.MatchString(want, got.String()); !match { 52 | t.Fatalf("Inc()+Done()=%q, want %q", got.String(), want) 53 | } 54 | } 55 | 56 | func TestNewFunc(t *testing.T) { 57 | want := "^" 58 | for _, i := range []float64{1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5, 59 | 10.5, 20.5, 30.5, 35.5} { 60 | want += fmt.Sprintf("\r%s \\(%s\\) ho ho %f", 61 | timePattern, timePattern, i) 62 | } 63 | want += "\n$" 64 | 65 | got := bytes.NewBuffer(nil) 66 | pt := NewFunc(func(i int) string { 67 | return fmt.Sprintf("ho ho %f", float64(i)+0.5) 68 | }) 69 | pt.W = got 70 | for i := 0; i < 35; i++ { 71 | pt.Inc() 72 | } 73 | pt.Done() 74 | 75 | if match, _ := regexp.MatchString(want, got.String()); !match { 76 | t.Fatalf("Inc()+Done()=%q, want %q", got.String(), want) 77 | } 78 | } 79 | 80 | func TestDone(t *testing.T) { 81 | want := "^" + timePattern + " hello\n$" 82 | 83 | got := bytes.NewBuffer(nil) 84 | pt := NewMessage("hello") 85 | pt.W = got 86 | pt.Done() 87 | 88 | if match, _ := regexp.MatchString(want, got.String()); !match { 89 | t.Fatalf("Done()=%q, want %q", got.String(), want) 90 | } 91 | } 92 | 93 | func TestNextCheckpoint(t *testing.T) { 94 | want := []int{ 95 | 1, 2, 3, 4, 5, 6, 7, 8, 9, 96 | 10, 20, 30, 40, 50, 60, 70, 80, 90, 97 | 100, 200, 300, 400, 500} 98 | var got []int 99 | i := 0 100 | for range want { 101 | i = nextCheckpoint(i) 102 | got = append(got, i) 103 | } 104 | if !slices.Equal(got, want) { 105 | t.Fatalf("nextCheckpoint(...)=%v, want %v", got, want) 106 | } 107 | } 108 | 109 | func BenchmarkTimer_inc(b *testing.B) { 110 | pt := New() 111 | pt.W = io.Discard 112 | b.ResetTimer() 113 | for i := 0; i < b.N; i++ { 114 | pt.Inc() 115 | } 116 | } 117 | 118 | func Example() { 119 | pt := New() 120 | for i := 0; i < 45; i++ { 121 | time.Sleep(100 * time.Millisecond) 122 | pt.Inc() 123 | } 124 | pt.Done() 125 | } 126 | 
-------------------------------------------------------------------------------- /ppln/serial_test.go: -------------------------------------------------------------------------------- 1 | package ppln 2 | 3 | import ( 4 | "fmt" 5 | "math/rand" 6 | "testing" 7 | "time" 8 | 9 | "github.com/fluhus/gostuff/gnum" 10 | ) 11 | 12 | func ExampleSerial() { 13 | ngoroutines := 4 14 | var results []float64 15 | 16 | Serial[int, float64](ngoroutines, 17 | // Read/generate input data. 18 | RangeInput(1, 101), 19 | // Some processing. 20 | func(a, i, g int) (float64, error) { 21 | return float64(a*a) + 0.5, nil 22 | }, 23 | // Accumulate/forward outputs. 24 | func(a float64) error { 25 | results = append(results, a) 26 | return nil 27 | }) 28 | 29 | fmt.Println(results[:3], results[len(results)-3:]) 30 | 31 | // Output: 32 | // [1.5 4.5 9.5] [9604.5 9801.5 10000.5] 33 | } 34 | 35 | func ExampleSerial_parallelAggregation() { 36 | ngoroutines := 4 37 | results := make([]int, ngoroutines) // Goroutine-specific data and objects. 38 | 39 | Serial( 40 | ngoroutines, 41 | // Read/generate input data. 42 | RangeInput(1, 101), 43 | // Accumulate in goroutine-specific memory. 44 | func(a int, i, g int) (int, error) { 45 | results[g] += a 46 | return 0, nil // Unused. 47 | }, 48 | // No outputs. 49 | func(a int) error { return nil }) 50 | 51 | // Collect the results of all goroutines. 52 | fmt.Println("Sum of 1-100:", gnum.Sum(results)) 53 | 54 | // Output: 55 | // Sum of 1-100: 5050 56 | } 57 | 58 | func TestSerial(t *testing.T) { 59 | for _, nt := range []int{1, 2, 4, 8} { 60 | t.Run(fmt.Sprint(nt), func(t *testing.T) { 61 | n := nt * 100 62 | var result []int 63 | err := Serial( 64 | nt, 65 | RangeInput(0, n), 66 | func(a int, i int, g int) (int, error) { 67 | time.Sleep(time.Millisecond * time.Duration(rand.Intn(3))) 68 | return a * a, nil 69 | }, 70 | func(i int) error { 71 | result = append(result, i) 72 | return nil 73 | }) 74 | if err != nil { 75 | t.Fatalf("Serial(...) 
failed: %d", err) 76 | } 77 | for i := range result { 78 | if result[i] != i*i { 79 | t.Errorf("result[%d]=%d, want %d", i, result[i], i*i) 80 | } 81 | } 82 | }) 83 | } 84 | } 85 | 86 | func TestSerial_error(t *testing.T) { 87 | for _, nt := range []int{1, 2, 4, 8} { 88 | t.Run(fmt.Sprint(nt), func(t *testing.T) { 89 | n := nt * 100 90 | var result []int 91 | err := Serial( 92 | nt, 93 | RangeInput(0, n), 94 | func(a int, i int, g int) (int, error) { 95 | time.Sleep(time.Millisecond * time.Duration(rand.Intn(3))) 96 | if a > 300 { 97 | return 0, fmt.Errorf("a too big: %d", a) 98 | } 99 | return a * a, nil 100 | }, 101 | func(i int) error { 102 | result = append(result, i) 103 | return nil 104 | }) 105 | if nt <= 3 { 106 | if err != nil { 107 | t.Fatalf("Serial(...) failed: %d", err) 108 | } 109 | for i := range result { 110 | if result[i] != i*i { 111 | t.Errorf("result[%d]=%d, want %d", i, result[i], i*i) 112 | } 113 | } 114 | } else { // n > 3 115 | if err == nil { 116 | t.Fatalf("Serial(...) 
succeeded, want error") 117 | } 118 | } 119 | }) 120 | } 121 | } 122 | -------------------------------------------------------------------------------- /iterx/unreader_test.go: -------------------------------------------------------------------------------- 1 | package iterx 2 | 3 | import ( 4 | "iter" 5 | "reflect" 6 | "slices" 7 | "testing" 8 | ) 9 | 10 | func TestUnreader_until(t *testing.T) { 11 | input := []int{1, 4, 2, 6, 8, 4, 5, 7} 12 | tests := []struct { 13 | until int 14 | want []int 15 | }{ 16 | {6, []int{1, 4, 2}}, {4, []int{6, 8}}, {1, []int{4, 5, 7}}, 17 | } 18 | r := NewUnreader(Slice(input)) 19 | for _, test := range tests { 20 | var got []int 21 | for i := range r.Until(func(j int) bool { return j == test.until }) { 22 | got = append(got, i) 23 | } 24 | if !slices.Equal(got, test.want) { 25 | t.Fatalf("New(%v).Until(%d)=%v, want %v", input, test.until, got, test.want) 26 | } 27 | } 28 | } 29 | 30 | func TestUnreader_groupBy(t *testing.T) { 31 | input := []int{1, 4, 2, 6, 9, 4, 5, 7} 32 | want := [][]int{{1}, {4, 2, 6}, {9}, {4}, {5, 7}} 33 | var got [][]int 34 | r := NewUnreader(Slice(input)) 35 | for group := range r.GroupBy(func(i int, j int) bool { 36 | return i%2 == j%2 37 | }) { 38 | var got1 []int 39 | for i := range group { 40 | got1 = append(got1, i) 41 | } 42 | got = append(got, got1) 43 | } 44 | if !reflect.DeepEqual(got, want) { 45 | t.Fatalf("New(%v).GroupBy(...)=%v, want %v", input, got, want) 46 | } 47 | } 48 | 49 | func TestUnreader_stop(t *testing.T) { 50 | input := []int{1, 2, 3, 4, 5, 6, 7, 8, 9} 51 | 52 | stopped := false 53 | broke := false 54 | r := NewUnreader(stopIter(input, &stopped)) 55 | for x := range r.Until(func(i int) bool { return i == 5 }) { 56 | if x == 3 { 57 | broke = true 58 | break 59 | } 60 | } 61 | if !broke { 62 | t.Fatalf("Unreader(%v) never broke", input) 63 | } 64 | if !stopped { 65 | t.Fatalf("Unreader(%v) never called stop", input) 66 | } 67 | 68 | stopped, broke = false, false 69 | continued := true 70 
| r = NewUnreader(stopIter(input, &stopped)) 71 | for g := range r.GroupBy(func(i int, j int) bool { 72 | return i/3 == j/3 73 | }) { 74 | continued = true 75 | for x := range g { 76 | if x == 5 { 77 | broke = true 78 | continued = false 79 | break 80 | } 81 | } 82 | } 83 | if !broke { 84 | t.Fatalf("Unreader(%v) never broke", input) 85 | } 86 | if continued { 87 | t.Fatalf("Unreader(%v).GroupBy outer loop continued "+ 88 | "after inner loop broke", input) 89 | } 90 | if !stopped { 91 | t.Fatalf("Unreader(%v) never called stop", input) 92 | } 93 | 94 | stopped = false 95 | toBreak := false 96 | r = NewUnreader(stopIter(input, &stopped)) 97 | for g := range r.GroupBy(func(i int, j int) bool { 98 | return i/3 == j/3 99 | }) { 100 | for x := range g { 101 | if x == 5 { 102 | toBreak = true 103 | } 104 | } 105 | if toBreak { 106 | break 107 | } 108 | } 109 | if !toBreak { 110 | t.Fatalf("Unreader(%v) never broke", input) 111 | } 112 | if !stopped { 113 | t.Fatalf("Unreader(%v) never called stop", input) 114 | } 115 | } 116 | 117 | func stopIter[T any](s []T, stopped *bool) iter.Seq[T] { 118 | return func(yield func(T) bool) { 119 | defer func() { 120 | *stopped = true 121 | }() 122 | for _, x := range s { 123 | if !yield(x) { 124 | return 125 | } 126 | } 127 | } 128 | } 129 | -------------------------------------------------------------------------------- /nlp/lda-tool/main.go: -------------------------------------------------------------------------------- 1 | // Command lda-tool performs LDA on the input documents. 
// readDocs reads documents, one per line, from the input reader.
// It lowercases each document and splits it into words (runs of letters,
// digits and underscores), and returns the documents as a 2d slice.
// An empty line yields an empty document.
//
// Returns the scanner's error if reading fails, including when a single
// line is longer than the maximum supported document size.
func readDocs(r io.Reader) ([][]string, error) {
	const (
		initialBufBytes = 64 * 1024
		// A whole document sits on one line, so the scanner's default
		// 64 KiB token limit is too small; allow lines of up to 1 GiB.
		maxDocBytes = 1 << 30
	)

	wordsRe := regexp.MustCompile(`\w+`)
	scanner := bufio.NewScanner(r)
	scanner.Buffer(make([]byte, 0, initialBufBytes), maxDocBytes)

	var result [][]string
	for scanner.Scan() {
		w := wordsRe.FindAllString(strings.ToLower(scanner.Text()), -1)

		// Copy line to a lower capacity slice, to reduce memory usage.
		doc := make([]string, len(w))
		copy(doc, w)
		result = append(result, doc)
	}
	if err := scanner.Err(); err != nil {
		return nil, err
	}
	return result, nil
}
75 | // Arguments are treated like Println. 76 | func die(a ...interface{}) { 77 | fmt.Fprintln(os.Stderr, a...) 78 | os.Exit(2) 79 | } 80 | 81 | // parseArgs parses the program's arguments and validates them. 82 | // Exits with an error message upon validation error. 83 | func parseArgs() { 84 | flag.Parse() 85 | if len(os.Args) == 1 { 86 | fmt.Fprintln(os.Stderr, help) 87 | flag.PrintDefaults() 88 | os.Exit(1) 89 | } 90 | if *k < 1 { 91 | die("Error: invalid k:", *k) 92 | } 93 | if *numThreads < 0 { 94 | die("Error: invalid number of threads:", *numThreads) 95 | } 96 | if *numThreads == 0 { 97 | *numThreads = runtime.NumCPU() 98 | } 99 | } 100 | 101 | var help = `Performs LDA on the given documents. 102 | 103 | Input is read from the standard input. Format is one document per line. 104 | Documents will be lowercased and normalized (spaces and punctuation omitted). 105 | 106 | Output is printed to the standard output. Format is one word per line. 107 | Each word is followed by K numbers, the i'th number represents the likelihood 108 | of the i'th topic to emit that word. 
109 | ` 110 | -------------------------------------------------------------------------------- /flagx/flagx_test.go: -------------------------------------------------------------------------------- 1 | package flagx 2 | 3 | import ( 4 | "flag" 5 | "fmt" 6 | "testing" 7 | ) 8 | 9 | func TestRegexp(t *testing.T) { 10 | fs := flag.NewFlagSet("", flag.PanicOnError) 11 | 12 | re := RegexpFlagSet(fs, "a", nil, "") 13 | fs.Parse([]string{"-a", "a..b"}) 14 | if (*re).String() != "a..b" { 15 | t.Errorf("RegexpFlagSet(...)=%q, want %q", 16 | (*re).String(), "a..b") 17 | } 18 | 19 | fs = flag.NewFlagSet("", flag.PanicOnError) 20 | re = RegexpFlagSet(fs, "a", nil, "") 21 | fs.Parse(nil) 22 | if (*re) != nil { 23 | t.Errorf("RegexpFlagSet(...)=%q, want nil", (*re).String()) 24 | } 25 | } 26 | 27 | func TestIntBetween(t *testing.T) { 28 | fs := flag.NewFlagSet("", flag.PanicOnError) 29 | 30 | ii := IntBetweenFlagSet(fs, "i", 3, "", 1, 5) 31 | if *ii != 3 { 32 | t.Errorf("IntBetweenFlagSet(...)=%v, want %v", ii, 3) 33 | } 34 | 35 | // Valid values. 36 | for i := 1; i <= 5; i++ { 37 | args := []string{"-i", fmt.Sprint(i)} 38 | fs.Parse(args) 39 | if *ii != i { 40 | t.Errorf("Parse(%v)=%v, want %v", args, ii, i) 41 | } 42 | } 43 | 44 | // Invalid values. 45 | for _, i := range []int{-1, 0, 6, 7, 10} { 46 | func() { 47 | args := []string{"-i", fmt.Sprint(i)} 48 | defer func() { 49 | recover() 50 | }() 51 | fs.Parse(args) 52 | t.Errorf("Parse(%v)=%v, want error", args, ii) 53 | }() 54 | } 55 | } 56 | 57 | func TestOneOf_string(t *testing.T) { 58 | fs := flag.NewFlagSet("", flag.PanicOnError) 59 | 60 | vals := []string{"blue", "yellow", "red"} 61 | ss := OneOfFlagSet(fs, "s", vals[0], "", vals...) 62 | if *ss != vals[0] { 63 | t.Errorf("StringFromFlagSet(...)=%v, want %v", ss, vals[0]) 64 | } 65 | 66 | // Valid values. 
67 | for _, s := range vals { 68 | args := []string{"-s", s} 69 | fs.Parse(args) 70 | if *ss != s { 71 | t.Errorf("Parse(%v)=%v, want %v", args, ss, s) 72 | } 73 | } 74 | 75 | // Invalid values. 76 | for _, s := range vals { 77 | func() { 78 | args := []string{"-s", s + "."} 79 | defer func() { 80 | recover() 81 | }() 82 | fs.Parse(args) 83 | t.Errorf("Parse(%v)=%v, want error", args, ss) 84 | }() 85 | } 86 | } 87 | 88 | func TestOneOf_int(t *testing.T) { 89 | fs := flag.NewFlagSet("", flag.PanicOnError) 90 | 91 | vals := []int{3, 55, 888} 92 | oct := []string{"0o3", "0o67", "0o1570"} 93 | ss := OneOfFlagSet(fs, "s", vals[0], "", vals...) 94 | if *ss != vals[0] { 95 | t.Errorf("StringFromFlagSet(...)=%v, want %v", ss, vals[0]) 96 | } 97 | 98 | // Valid values. 99 | for _, i := range vals { 100 | args := []string{"-s", fmt.Sprint(i)} 101 | fs.Parse(args) 102 | if *ss != i { 103 | t.Errorf("Parse(%v)=%v, want %v", args, ss, i) 104 | } 105 | } 106 | 107 | // Octal representation. 108 | for i, s := range oct { 109 | args := []string{"-s", s} 110 | fs.Parse(args) 111 | want := vals[i] 112 | if *ss != want { 113 | t.Errorf("Parse(%v)=%v, want %v", args, ss, want) 114 | } 115 | } 116 | 117 | // Invalid values. 118 | for _, i := range vals { 119 | func() { 120 | args := []string{"-s", fmt.Sprint(i) + "aaa"} 121 | defer func() { 122 | recover() 123 | }() 124 | fs.Parse(args) 125 | t.Errorf("Parse(%v)=%v, want error", args, *ss) 126 | }() 127 | } 128 | } 129 | -------------------------------------------------------------------------------- /ptimer/ptimer.go: -------------------------------------------------------------------------------- 1 | // Package ptimer provides a progress timer for iterative processes. 2 | // 3 | // A timer prints how much time passed since its creation at exponentially 4 | // growing time-points. 5 | // Precisely, prints are triggered after i calls to Inc, if i has only one non-zero 6 | // digit. That is: 1, 2, 3 .. 9, 10, 20, 30 .. 
// A Timer measures time during iterative processes and prints the progress on
// exponential checkpoints.
//
// A Timer does no locking of its own.
type Timer struct {
	N int              // Current count, incremented with each call to Inc
	W io.Writer        // Timer's output, defaults to stderr
	t time.Time        // Start time
	f func(int) string // Message function
	c int              // Next checkpoint
}

// Prints the progress: total time since creation, average time per call to
// Inc, and the message for the current count. The line starts with a carriage
// return, so that on a terminal each print overwrites the previous one.
func (t *Timer) print() {
	since := time.Since(t.t)
	if t.N == 0 { // Happens when calling Done without Inc.
		// No average to report, and nothing to overwrite.
		fmt.Fprintf(t.W, "%s %s", fmtDuration(since), t.f(t.N))
		return
	}
	fmt.Fprintf(t.W, "\r%s (%s) %s", fmtDuration(since),
		fmtDuration(since/time.Duration(t.N)), t.f(t.N))
}

// Formats a duration in constant-width format: HH:MM:SS.microseconds.
// The hours field grows beyond 2 digits for durations of 100 hours or more.
func fmtDuration(d time.Duration) string {
	return fmt.Sprintf("%02d:%02d:%02d.%06d",
		d/time.Hour,
		d%time.Hour/time.Minute,
		d%time.Minute/time.Second,
		d%time.Second/time.Microsecond,
	)
}

// NewFunc returns a new timer that calls f with the current count on checkpoints,
// and prints its output. The timer starts measuring time immediately and
// writes to stderr.
func NewFunc(f func(i int) string) *Timer {
	return &Timer{
		N: 0,
		W: os.Stderr,
		t: time.Now(),
		f: f,
		c: 0,
	}
}

// NewMessage returns a new timer that prints msg on checkpoints.
// A "{}" in msg will be replaced with the current count.
func NewMessage(msg string) *Timer {
	return NewFunc(func(i int) string {
		return strings.ReplaceAll(msg, "{}", fmt.Sprint(i))
	})
}

// New returns a new timer that prints the current count on checkpoints.
func New() *Timer {
	return NewMessage("{}")
}

// Inc increments t's counter and prints progress if reached a checkpoint.
func (t *Timer) Inc() {
	t.N++
	if t.N >= t.c {
		t.print()
		// Advance past the current count. This may take several steps if N
		// was changed from outside.
		for t.c <= t.N {
			t.c = nextCheckpoint(t.c)
		}
	}
}

// Done prints progress as if a checkpoint was reached, followed by a new
// line.
func (t *Timer) Done() {
	t.print()
	fmt.Fprintln(t.W)
}

// Returns the next i in which the timer should print, given the current i.
// Panics if i is not a multiple of the largest power of 10 that does not
// exceed it.
func nextCheckpoint(i int) int {
	// m becomes the largest power of 10 that is <= i (or 1 for i < 10).
	// Compare against i/10 instead of computing m*10, which overflows int
	// when i is within a factor of 10 of the maximal int.
	m := 1
	for m <= i/10 {
		m *= 10
	}
	if i%m != 0 {
		panic(fmt.Sprintf(
			"bad checkpoint: %d, should be a multiple of a power of 10", i))
	}
	return i + m
}
got, test.want) 31 | } 32 | } 33 | } 34 | 35 | func TestUnreader2_groupBy(t *testing.T) { 36 | input := []int{1, 4, 2, 6, 9, 4, 5, 7} 37 | want := [][]int{{1}, {4, 2, 6}, {9}, {4}, {5, 7}} 38 | var got [][]int 39 | r := NewUnreader2(ppln.SliceInput(input)) 40 | for group := range r.GroupBy(func(i int, ie error, j int, je error) bool { 41 | return i%2 == j%2 42 | }) { 43 | var got1 []int 44 | for i := range group { 45 | got1 = append(got1, i) 46 | } 47 | got = append(got, got1) 48 | } 49 | if !reflect.DeepEqual(got, want) { 50 | t.Fatalf("New(%v).GroupBy(...)=%v, want %v", input, got, want) 51 | } 52 | } 53 | 54 | func TestUnreader2_stop(t *testing.T) { 55 | input := []int{1, 2, 3, 4, 5, 6, 7, 8, 9} 56 | 57 | stopped := false 58 | broke := false 59 | r := NewUnreader2(stopIter2(input, &stopped)) 60 | for x := range r.Until(func(i int, err error) bool { return i == 5 }) { 61 | if x == 3 { 62 | broke = true 63 | break 64 | } 65 | } 66 | if !broke { 67 | t.Fatalf("Unreader2(%v) never broke", input) 68 | } 69 | if !stopped { 70 | t.Fatalf("Unreader2(%v) never called stop", input) 71 | } 72 | 73 | stopped, broke = false, false 74 | continued := true 75 | r = NewUnreader2(stopIter2(input, &stopped)) 76 | for g := range r.GroupBy(func(i int, ie error, j int, je error) bool { 77 | return i/3 == j/3 78 | }) { 79 | continued = true 80 | for x := range g { 81 | if x == 5 { 82 | broke = true 83 | continued = false 84 | break 85 | } 86 | } 87 | } 88 | if !broke { 89 | t.Fatalf("Unreader2(%v) never broke", input) 90 | } 91 | if continued { 92 | t.Fatalf("Unreader2(%v).GroupBy outer loop continued "+ 93 | "after inner loop broke", input) 94 | } 95 | if !stopped { 96 | t.Fatalf("Unreader2(%v) never called stop", input) 97 | } 98 | 99 | stopped = false 100 | toBreak := false 101 | r = NewUnreader2(stopIter2(input, &stopped)) 102 | for g := range r.GroupBy(func(i int, ie error, j int, je error) bool { 103 | return i/3 == j/3 104 | }) { 105 | for x := range g { 106 | if x == 5 { 107 | 
toBreak = true 108 | } 109 | } 110 | if toBreak { 111 | break 112 | } 113 | } 114 | if !toBreak { 115 | t.Fatalf("Unreader2(%v) never broke", input) 116 | } 117 | if !stopped { 118 | t.Fatalf("Unreader2(%v) never called stop", input) 119 | } 120 | } 121 | 122 | func stopIter2[T any](s []T, stopped *bool) iter.Seq2[T, error] { 123 | return func(yield func(T, error) bool) { 124 | defer func() { 125 | *stopped = true 126 | }() 127 | for _, x := range s { 128 | if !yield(x, nil) { 129 | return 130 | } 131 | } 132 | } 133 | } 134 | -------------------------------------------------------------------------------- /nlp/wordnet/parser_test.go: -------------------------------------------------------------------------------- 1 | package wordnet 2 | 3 | import ( 4 | "encoding/json" 5 | "reflect" 6 | "strings" 7 | "testing" 8 | ) 9 | 10 | func TestDataParser(t *testing.T) { 11 | expected := map[string]*Synset{ 12 | "v111": { 13 | "111", 14 | "v", 15 | []string{ 16 | "foo", 17 | "bar", 18 | "baz", 19 | }, 20 | []*Pointer{ 21 | {"!", "n123", -1, -1}, 22 | {"@", "a321", 0, 1}, 23 | }, 24 | []*Frame{ 25 | {4, 4}, 26 | {6, 6}, 27 | }, 28 | "hello world", 29 | nil, 30 | }, 31 | } 32 | 33 | actual := map[string]*Synset{} 34 | err := parseDataFile(strings.NewReader(testData), "v", map[string][]int{}, 35 | actual) 36 | if err != nil { 37 | t.Fatal("Parsing error:", err) 38 | } 39 | if !reflect.DeepEqual(expected, actual) { 40 | t.Error("Non-equal values:") 41 | t.Error(stringify(expected)) 42 | t.Error(stringify(actual)) 43 | } 44 | for key, ss := range actual { 45 | if ss.Id() != key { 46 | t.Errorf("ss.Id()=%v, want %v", ss.Id(), key) 47 | } 48 | } 49 | } 50 | 51 | func TestExceptionParser(t *testing.T) { 52 | expected := map[string][]string{ 53 | "n.foo": {"n.bar"}, 54 | "n.baz": {"n.bla", "n.blu"}, 55 | } 56 | 57 | actual := map[string][]string{} 58 | err := parseExceptionFile(strings.NewReader(testException), "n", actual) 59 | if err != nil { 60 | t.Fatal("Parsing error:", err) 61 
// stringify returns the JSON encoding of a, for printing values in test
// failure messages. If a cannot be encoded, the returned text describes the
// encoding error instead of being empty.
func stringify(a any) string {
	j, err := json.Marshal(a)
	if err != nil {
		return "<json.Marshal failed: " + err.Error() + ">"
	}
	return string(j)
}
123 n 0000 @ 321 a 0102 2 + 4 5 + 6 7 | hello world` 127 | 128 | var testException = `foo bar 129 | baz bla blu` 130 | 131 | var testExampleIndex = `abash%2:37:00:: 126,127 132 | abhor%2:37:00:: 138,139,15` 133 | 134 | var testExamples = `111 hello world 135 | 222 goodbye universe` 136 | 137 | var testIndex = ` copyright line 138 | thing n 2 3 x y z 2 2 a b 139 | thing2 v 4 1 x 4 2 c d e f` 140 | -------------------------------------------------------------------------------- /bits/bits_test.go: -------------------------------------------------------------------------------- 1 | package bits 2 | 3 | import ( 4 | "slices" 5 | "testing" 6 | ) 7 | 8 | func TestSet_true(t *testing.T) { 9 | tests := []struct { 10 | n int 11 | want []byte 12 | }{ 13 | {0, []byte{1, 0}}, 14 | {1, []byte{2, 0}}, 15 | {2, []byte{4, 0}}, 16 | {3, []byte{8, 0}}, 17 | {4, []byte{16, 0}}, 18 | {5, []byte{32, 0}}, 19 | {6, []byte{64, 0}}, 20 | {7, []byte{128, 0}}, 21 | {8, []byte{0, 1}}, 22 | {9, []byte{0, 2}}, 23 | {10, []byte{0, 4}}, 24 | {11, []byte{0, 8}}, 25 | {12, []byte{0, 16}}, 26 | {13, []byte{0, 32}}, 27 | {14, []byte{0, 64}}, 28 | {15, []byte{0, 128}}, 29 | } 30 | for _, test := range tests { 31 | b := []byte{0, 0} 32 | Set(b, test.n, true) 33 | if !slices.Equal(b, test.want) { 34 | t.Errorf("Set(%v,%v,%v)=%v, want %v", 35 | []byte{0, 0}, test.n, true, b, test.want) 36 | } 37 | } 38 | } 39 | 40 | func TestSet_false(t *testing.T) { 41 | tests := []struct { 42 | n int 43 | want []byte 44 | }{ 45 | {0, []byte{255 - 1, 255}}, 46 | {1, []byte{255 - 2, 255}}, 47 | {2, []byte{255 - 4, 255}}, 48 | {3, []byte{255 - 8, 255}}, 49 | {4, []byte{255 - 16, 255}}, 50 | {5, []byte{255 - 32, 255}}, 51 | {6, []byte{255 - 64, 255}}, 52 | {7, []byte{255 - 128, 255}}, 53 | {8, []byte{255, 255 - 1}}, 54 | {9, []byte{255, 255 - 2}}, 55 | {10, []byte{255, 255 - 4}}, 56 | {11, []byte{255, 255 - 8}}, 57 | {12, []byte{255, 255 - 16}}, 58 | {13, []byte{255, 255 - 32}}, 59 | {14, []byte{255, 255 - 64}}, 60 | {15, 
[]byte{255, 255 - 128}}, 61 | } 62 | for _, test := range tests { 63 | b := []byte{255, 255} 64 | Set(b, test.n, false) 65 | if !slices.Equal(b, test.want) { 66 | t.Errorf("Set(%v,%v,%v)=%v, want %v", 67 | []byte{0, 0}, test.n, false, b, test.want) 68 | } 69 | } 70 | } 71 | 72 | func TestGet(t *testing.T) { 73 | input := []byte{0b10011001, 0b01100111} 74 | want := []int{1, 0, 0, 1, 1, 0, 0, 1, 75 | 1, 1, 1, 0, 0, 1, 1, 0} 76 | for i := range want { 77 | if got := Get(input, i); got != want[i] { 78 | t.Errorf("Get(%v,%v)=%v, want %v", input, i, got, want[i]) 79 | } 80 | } 81 | } 82 | 83 | func TestSum(t *testing.T) { 84 | tests := []struct { 85 | input []byte 86 | want int 87 | }{ 88 | {nil, 0}, 89 | {[]byte{0}, 0}, 90 | {[]byte{0b01100111}, 5}, 91 | {[]byte{0b10011001, 0b01100111}, 9}, 92 | } 93 | for _, test := range tests { 94 | if got := Sum(test.input); got != test.want { 95 | t.Errorf("Sum(%v)=%v, want %v", test.input, got, test.want) 96 | } 97 | } 98 | } 99 | 100 | func TestByteOnes(t *testing.T) { 101 | tests := []struct { 102 | i int 103 | want []int 104 | }{ 105 | {0, nil}, 106 | {1, []int{0}}, 107 | {2, []int{1}}, 108 | {3, []int{0, 1}}, 109 | {4, []int{2}}, 110 | {5, []int{0, 2}}, 111 | {255, []int{0, 1, 2, 3, 4, 5, 6, 7}}, 112 | } 113 | for _, test := range tests { 114 | if got := byteOnes[test.i]; !slices.Equal(got, test.want) { 115 | t.Errorf("byteOnes(%v)=%v, want %v", test.i, got, test.want) 116 | } 117 | } 118 | } 119 | 120 | func TestOnesZeros(t *testing.T) { 121 | input := []byte{0b01011100, 0b11101010} 122 | wantOnes := []int{2, 3, 4, 6, 9, 11, 13, 14, 15} 123 | wantZeros := []int{0, 1, 5, 7, 8, 10, 12} 124 | 125 | gotOnes := slices.Collect(Ones(input)) 126 | if !slices.Equal(gotOnes, wantOnes) { 127 | t.Errorf("Ones(%v)=%v, want %v", input, gotOnes, wantOnes) 128 | } 129 | 130 | gotZeros := slices.Collect(Zeros(input)) 131 | if !slices.Equal(gotZeros, wantZeros) { 132 | t.Errorf("Zeros(%v)=%v, want %v", input, gotZeros, wantZeros) 133 | } 134 | 
// distPyramid is a distance half-matrix.
// In a pyramid built by makePyramid, row i holds the distances between
// element i and elements 0..i-1, so row 0 is empty.
type distPyramid [][]float64

// dist returns the distance between a and b, in either order.
// The larger index selects the row and the smaller one the column.
func (d distPyramid) dist(a, b int) float64 {
	if a > b {
		return d[a][b]
	}
	return d[b][a]
}

// makePyramid creates a distance half-matrix for n elements.
// f is called once for each pair j < i, as f(j, i).
func makePyramid(n int, f func(int, int) float64) distPyramid {
	// All rows share one backing array of n*(n-1)/2 distances.
	flat := make([]float64, 0, n*(n-1)/2)
	for i := 1; i < n; i++ {
		for j := 0; j < i; j++ {
			flat = append(flat, f(j, i))
		}
	}
	rows := make([][]float64, n)
	start := 0
	for i := range rows {
		end := start + i
		// Cap each row at its own length, so that appending to a row
		// reallocates instead of overwriting the next row's distances.
		rows[i] = flat[start:end:end]
		start = end
	}
	return rows
}
63 | sizes := gnum.Ones[[]float64](n) // Cluster sizes 64 | // The identifier of each cluster = highest index of an element 65 | names := make([]int, n) 66 | for i := range names { 67 | names[i] = i 68 | } 69 | for i := 0; i < n-1; i++ { 70 | // Find lowest distance. 71 | fmin := math.MaxFloat64 72 | a, b := -1, -1 73 | for hi, h := range heapss { 74 | if h == nil { 75 | continue 76 | } 77 | // Clean up removed clusters. 78 | if h.Len() == 0 { 79 | panic(fmt.Sprintf("heap %d with length 0", hi)) 80 | } 81 | for heapss[h.Head().i] == nil { 82 | h.Pop() 83 | } 84 | if h.Head().d < fmin { 85 | a = hi 86 | fmin = h.Head().d 87 | b = h.Head().i 88 | } 89 | } 90 | 91 | // Create agglo step. 92 | nmin := min(names[a], names[b]) 93 | nmax := max(names[a], names[b]) 94 | pi[nmin] = nmax 95 | lambda[nmin] = fmin 96 | 97 | // Merge clusters. 98 | names = append(names, nmax) 99 | sizes = append(sizes, sizes[a]+sizes[b]) 100 | heapss[a] = nil 101 | heapss[b] = nil 102 | var cdist []float64 103 | cheap := heaps.New(compareUpgmaClusters) 104 | for hi, h := range heapss { 105 | if h == nil { 106 | cdist = append(cdist, 0) 107 | continue 108 | } 109 | da := d.dist(a, hi) * sizes[a] 110 | db := d.dist(b, hi) * sizes[b] 111 | dd := (da + db) / (sizes[a] + sizes[b]) 112 | cdist = append(cdist, dd) 113 | h.Push(upgmaCluster{len(sizes) - 1, dd}) 114 | cheap.Push(upgmaCluster{hi, dd}) 115 | } 116 | d = append(d, cdist) 117 | heapss = append(heapss, cheap) 118 | } 119 | 120 | return newAggloResult(pi, lambda) 121 | } 122 | 123 | // Cluster info in UPGMA. 
// upgmaCluster is a heap entry used by UPGMA: another cluster and the
// distance to it.
type upgmaCluster struct {
	i int     // Cluster index
	d float64 // Distance from cluster i
}

// compareUpgmaClusters reports whether a is strictly closer than b.
// It is the ordering used for the per-cluster heaps, so that the nearest
// cluster is at the head.
func compareUpgmaClusters(a, b upgmaCluster) bool {
	return b.d > a.d
}
// sumLens returns the total number of strings held by all the slices in a.
func sumLens(a [][]string) int {
	total := 0
	for idx := range a {
		total += len(a[idx])
	}
	return total
}
// An HLL is a HyperLogLog counter for arbitrary values.
type HLL[T any] struct {
	counters []byte         // One register per bucket.
	h        func(T) uint64 // Hash function for added values.
	nbits    int            // The logSize parameter.
	m        int            // Number of registers, 2^nbits.
	mask     uint64         // Extracts the register index from a hash.
}

// New creates a new HyperLogLog counter.
// The counter will use 2^logSize bytes.
// h is the hash function to use for added values.
//
// Panics if logSize is less than 4, or so large that 2^logSize does not fit
// in an int.
func New[T any](logSize int, h func(T) uint64) *HLL[T] {
	if logSize < 4 {
		panic(fmt.Sprintf("logSize=%v, should be at least 4", logSize))
	}
	m := 1 << logSize
	// An oversized shift wraps to zero or a negative number. Without this
	// check the counter would be created with a bad register count and fail
	// later with an unrelated index error.
	if m <= 0 {
		panic(fmt.Sprintf("logSize=%v is too large", logSize))
	}
	return &HLL[T]{
		counters: make([]byte, m),
		h:        h,
		nbits:    logSize,
		m:        m,
		mask:     uint64(m - 1),
	}
}

// Add adds t to the counter. Calls hash once.
func (h *HLL[T]) Add(t T) {
	hash := h.h(t)
	// The low nbits select the register; the rest are the fingerprint.
	idx := hash & h.mask
	fp := hash >> h.nbits
	z := byte(h.nzeros(fp)) + 1
	if z > h.counters[idx] {
		h.counters[idx] = z
	}
}

// ApproxCount returns the current approximate count.
// Does not alter the state of the counter.
func (h *HLL[T]) ApproxCount() int {
	z := 0.0
	for _, v := range h.counters {
		z += math.Pow(2, -float64(v))
	}
	z = 1.0 / z
	fm := float64(h.m)
	result := int(h.alpha() * fm * fm * z)

	// Small-range correction.
	if result < h.m*5/2 {
		zeros := 0
		for _, v := range h.counters {
			if v == 0 {
				zeros++
			}
		}
		// If some registers are zero, use linear counting.
		if zeros > 0 {
			result = int(fm * math.Log(fm/float64(zeros)))
		}
	}

	return result
}

// Returns the alpha value to use depending on m.
func (h *HLL[T]) alpha() float64 {
	switch h.m {
	case 16:
		return 0.673
	case 32:
		return 0.697
	case 64:
		return 0.709
	}
	return 0.7213 / (1 + 1.079/float64(h.m))
}

// nzeros counts the number of zeros on the right side of a binary number.
func (h *HLL[T]) nzeros(a uint64) int {
	if a == 0 {
		return 64 - h.nbits // Number of bits after using the first nbits.
	}
	n := 0
	for a&1 == 0 {
		n++
		a /= 2
	}
	return n
}
131 | func (h *HLL[T]) AddHLL(other *HLL[T]) { 132 | if len(h.counters) != len(other.counters) { 133 | panic("merging HLLs with different sizes") 134 | } 135 | for i, b := range other.counters { 136 | if h.counters[i] < b { 137 | h.counters[i] = b 138 | } 139 | } 140 | } 141 | 142 | // LogSize returns the logSize parameter that was used to create this counter. 143 | func (h *HLL[T]) LogSize() int { 144 | return h.nbits 145 | } 146 | -------------------------------------------------------------------------------- /ppln/serial.go: -------------------------------------------------------------------------------- 1 | package ppln 2 | 3 | import ( 4 | "fmt" 5 | "iter" 6 | "sync" 7 | "sync/atomic" 8 | 9 | "github.com/fluhus/gostuff/heaps" 10 | ) 11 | 12 | // Serial starts a multi-goroutine transformation pipeline that maintains the 13 | // order of the inputs. 14 | // 15 | // Input is an iterator over the input values to be transformed. 16 | // It will be called in a thread-safe manner. 17 | // Transform receives an input (a), 0-based input serial number (i), 0-based 18 | // goroutine number (g), and returns the result of processing a. 19 | // Output acts on a single result, and will be called by the same 20 | // order of the input, in a thread-safe manner. 21 | // 22 | // If one of the functions returns a non-nil error, the process stops and the 23 | // error is returned. Otherwise returns nil. 24 | func Serial[T1 any, T2 any]( 25 | ngoroutines int, 26 | input iter.Seq2[T1, error], 27 | transform func(a T1, i int, g int) (T2, error), 28 | output func(a T2) error) error { 29 | if ngoroutines < 1 { 30 | panic(fmt.Sprintf("bad number of goroutines: %d", ngoroutines)) 31 | } 32 | pull, pstop := iter.Pull2(input) 33 | defer pstop() 34 | 35 | // An optimization for a single thread. 
36 | if ngoroutines == 1 { 37 | i := 0 38 | for { 39 | t1, err, ok := pull() 40 | ii := i 41 | i++ 42 | 43 | if !ok { 44 | return nil 45 | } 46 | if err != nil { 47 | return err 48 | } 49 | 50 | t2, err := transform(t1, ii, 0) 51 | if err != nil { 52 | return err 53 | } 54 | if err := output(t2); err != nil { 55 | return err 56 | } 57 | } 58 | } 59 | 60 | ilock := &sync.Mutex{} 61 | olock := &sync.Mutex{} 62 | errs := make(chan error, ngoroutines) 63 | stop := &atomic.Bool{} 64 | items := &serialHeap[T2]{ 65 | data: heaps.New(func(a, b serialItem[T2]) bool { 66 | return a.i < b.i 67 | }), 68 | } 69 | 70 | i := 0 71 | for g := 0; g < ngoroutines; g++ { 72 | go func(g int) { 73 | for { 74 | if stop.Load() { 75 | errs <- nil 76 | return 77 | } 78 | 79 | ilock.Lock() 80 | t1, err, ok := pull() 81 | ii := i 82 | i++ 83 | ilock.Unlock() 84 | 85 | if !ok { 86 | errs <- nil 87 | return 88 | } 89 | if err != nil { 90 | stop.Store(true) 91 | errs <- err 92 | return 93 | } 94 | 95 | t2, err := transform(t1, ii, g) 96 | if err != nil { 97 | stop.Store(true) 98 | errs <- err 99 | return 100 | } 101 | 102 | olock.Lock() 103 | items.put(serialItem[T2]{ii, t2}) 104 | for items.ok() { 105 | err = output(items.pop()) 106 | if err != nil { 107 | olock.Unlock() 108 | stop.Store(true) 109 | errs <- err 110 | return 111 | } 112 | } 113 | olock.Unlock() 114 | } 115 | }(g) 116 | } 117 | 118 | for g := 0; g < ngoroutines; g++ { 119 | if err := <-errs; err != nil { 120 | return err 121 | } 122 | } 123 | return nil 124 | } 125 | 126 | // General data with a serial number. 127 | type serialItem[T any] struct { 128 | i int 129 | data T 130 | } 131 | 132 | // A heap of serial items. Sorts by serial number. 133 | type serialHeap[T any] struct { 134 | next int 135 | data *heaps.Heap[serialItem[T]] 136 | } 137 | 138 | // Checks whether the minimal element in the heap is the next in the series. 
139 | func (s *serialHeap[T]) ok() bool { 140 | return s.data.Len() > 0 && s.data.Head().i == s.next 141 | } 142 | 143 | // Removes and returns the minimal element in the heap. Panics if the element 144 | // is not the next in the series. 145 | func (s *serialHeap[T]) pop() T { 146 | if !s.ok() { 147 | panic("get when not ok") 148 | } 149 | s.next++ 150 | a := s.data.Pop() 151 | return a.data 152 | } 153 | 154 | // Adds an item to the heap. 155 | func (s *serialHeap[T]) put(item serialItem[T]) { 156 | if item.i < s.next { 157 | panic(fmt.Sprintf("put(%d) when next is %d", item.i, s.next)) 158 | } 159 | s.data.Push(item) 160 | } 161 | -------------------------------------------------------------------------------- /heaps/heaps.go: -------------------------------------------------------------------------------- 1 | // Package heaps provides generic heaps. 2 | // 3 | // This package provides better run speeds than the standard [heap] package. 4 | package heaps 5 | 6 | import ( 7 | "golang.org/x/exp/constraints" 8 | ) 9 | 10 | // Heap is a generic heap. 11 | type Heap[T any] struct { 12 | a []T 13 | less func(T, T) bool 14 | } 15 | 16 | // New returns a new heap that uses the given comparator function. 17 | func New[T any](less func(T, T) bool) *Heap[T] { 18 | return &Heap[T]{nil, less} 19 | } 20 | 21 | // Min returns a new min-heap of an ordered type by its natural order. 22 | func Min[T constraints.Ordered]() *Heap[T] { 23 | return New(func(t1, t2 T) bool { 24 | return t1 < t2 25 | }) 26 | } 27 | 28 | // Max returns a new max-heap of an ordered type by its natural order. 29 | func Max[T constraints.Ordered]() *Heap[T] { 30 | return New(func(t1, t2 T) bool { 31 | return t1 > t2 32 | }) 33 | } 34 | 35 | // Push adds x to h while maintaining its heap invariants. 
36 | func (h *Heap[T]) Push(x T) { 37 | h.a = append(h.a, x) 38 | i := len(h.a) - 1 39 | for i != -1 { 40 | i = h.bubbleUp(i) 41 | } 42 | } 43 | 44 | // PushSlice adds the elements of s to h while maintaining its heap invariants. 45 | // The complexity is O(new n), so it should be typically used to initialize a 46 | // new heap. 47 | func (h *Heap[T]) PushSlice(s []T) { 48 | h.a = append(h.a, s...) 49 | for i := len(h.a) - 1; i >= 0; i-- { 50 | j := i 51 | for j != -1 { 52 | j = h.bubbleDown(j) 53 | } 54 | } 55 | } 56 | 57 | // Pop removes and returns the minimal element in h. 58 | func (h *Heap[T]) Pop() T { 59 | if len(h.a) == 0 { 60 | panic("called Pop() on an empty heap") 61 | } 62 | x := h.a[0] 63 | h.a[0] = h.a[len(h.a)-1] 64 | h.a = h.a[:len(h.a)-1] 65 | i := 0 66 | for i != -1 { 67 | i = h.bubbleDown(i) 68 | } 69 | // Shrink if needed. 70 | if cap(h.a) >= 16 && len(h.a) <= cap(h.a)/4 { 71 | h.a = append(make([]T, 0, cap(h.a)/2), h.a...) 72 | } 73 | return x 74 | } 75 | 76 | // Len returns the number of elements in h. 77 | func (h *Heap[T]) Len() int { 78 | return len(h.a) 79 | } 80 | 81 | // Moves the i'th element down and returns its new index. 82 | // Returns -1 when no more bubble-downs are needed. 83 | func (h *Heap[T]) bubbleDown(i int) int { 84 | ia, ib := i*2+1, i*2+2 85 | if len(h.a) < ib { // No children 86 | return -1 87 | } 88 | if len(h.a) == ib { // Only one child 89 | if h.less(h.a[ia], h.a[i]) { 90 | h.a[i], h.a[ia] = h.a[ia], h.a[i] 91 | } 92 | return -1 93 | } 94 | if h.less(h.a[ib], h.a[ia]) { 95 | ia = ib 96 | } 97 | if h.less(h.a[ia], h.a[i]) { 98 | h.a[i], h.a[ia] = h.a[ia], h.a[i] 99 | } 100 | return ia 101 | } 102 | 103 | // Moves the i'th element up and returns its new index. 104 | // Returns -1 when no more bubble-ups are needed. 
105 | func (h *Heap[T]) bubbleUp(i int) int { 106 | if i == 0 { 107 | return -1 108 | } 109 | ia := (i - 1) / 2 110 | if !h.less(h.a[i], h.a[ia]) { 111 | return -1 112 | } 113 | h.a[i], h.a[ia] = h.a[ia], h.a[i] 114 | return ia 115 | } 116 | 117 | // View returns the underlying slice of h, containing all of its elements. 118 | // Modifying the slice may invalidate the heap. 119 | func (h *Heap[T]) View() []T { 120 | return h.a 121 | } 122 | 123 | // Head returns the minimal element in h. 124 | func (h *Heap[T]) Head() T { 125 | return h.a[0] 126 | } 127 | 128 | // Fix fixes the heap after a single value had been modified. 129 | // i is the index of the modified value. 130 | func (h *Heap[T]) Fix(i int) { 131 | wentUp := false 132 | for { 133 | j := h.bubbleUp(i) 134 | if j == -1 { 135 | break 136 | } 137 | i = j 138 | wentUp = true 139 | } 140 | if !wentUp { 141 | for { 142 | j := h.bubbleDown(i) 143 | if j == -1 { 144 | break 145 | } 146 | i = j 147 | } 148 | } 149 | } 150 | 151 | // Clip removes unused capacity from the heap. 152 | func (h *Heap[T]) Clip() { 153 | h.a = append(make([]T, 0, len(h.a)), h.a...) 
154 | } 155 | -------------------------------------------------------------------------------- /bnry/bnry_test.go: -------------------------------------------------------------------------------- 1 | package bnry 2 | 3 | import ( 4 | "reflect" 5 | "slices" 6 | "testing" 7 | ) 8 | 9 | func TestMarshal(t *testing.T) { 10 | a := byte(113) 11 | b := uint64(2391278932173219) 12 | c := "amit" 13 | d := int16(10000) 14 | e := []int32{1, 11, 100, 433223} 15 | f := true 16 | g := false 17 | buf := MarshalBinary(a, b, c, d, e, f, g) 18 | var aa byte 19 | var bb uint64 20 | var cc string 21 | var dd int16 22 | var ee []int32 23 | var ff bool 24 | var gg bool 25 | if err := UnmarshalBinary( 26 | buf, &aa, &bb, &cc, &dd, &ee, &ff, &gg); err != nil { 27 | t.Fatalf("UnmarshalBinary(%v) failed: %v", buf, err) 28 | } 29 | if aa != a { 30 | t.Errorf("UnmarshalBinary(...)=%v, want %v", aa, a) 31 | } 32 | if bb != b { 33 | t.Errorf("UnmarshalBinary(...)=%v, want %v", bb, b) 34 | } 35 | if cc != c { 36 | t.Errorf("UnmarshalBinary(...)=%v, want %v", cc, c) 37 | } 38 | if dd != d { 39 | t.Errorf("UnmarshalBinary(...)=%v, want %v", dd, d) 40 | } 41 | if !slices.Equal(ee, e) { 42 | t.Errorf("UnmarshalBinary(...)=%v, want %v", ee, e) 43 | } 44 | if ff != f { 45 | t.Errorf("UnmarshalBinary(...)=%v, want %v", dd, d) 46 | } 47 | if gg != g { 48 | t.Errorf("UnmarshalBinary(...)=%v, want %v", dd, d) 49 | } 50 | } 51 | 52 | func FuzzMarshal(f *testing.F) { 53 | f.Add(uint8(1), int16(1), uint32(1), int64(1), "", true, 1.0, float32(1)) 54 | f.Fuzz(func(t *testing.T, a uint8, b int16, c uint32, d int64, e string, 55 | g bool, h float64, i float32) { 56 | buf := MarshalBinary(a, b, c, d, e, g, h, i) 57 | var ( 58 | aa uint8 59 | bb int16 60 | cc uint32 61 | dd int64 62 | ee string 63 | gg bool 64 | hh float64 65 | ii float32 66 | ) 67 | err := UnmarshalBinary(buf, &aa, &bb, &cc, &dd, &ee, &gg, &hh, &ii) 68 | if err != nil { 69 | t.Fatal(err) 70 | } 71 | if aa != a { 72 | t.Fatalf("got %v, want %v", 
aa, a) 73 | } 74 | if bb != b { 75 | t.Fatalf("got %v, want %v", bb, b) 76 | } 77 | if cc != c { 78 | t.Fatalf("got %v, want %v", cc, c) 79 | } 80 | if dd != d { 81 | t.Fatalf("got %v, want %v", dd, d) 82 | } 83 | if ee != e { 84 | t.Fatalf("got %v, want %v", ee, e) 85 | } 86 | if gg != g { 87 | t.Fatalf("got %v, want %v", gg, g) 88 | } 89 | if hh != h { 90 | t.Fatalf("got %v, want %v", hh, h) 91 | } 92 | if ii != i { 93 | t.Fatalf("got %v, want %v", ii, i) 94 | } 95 | }) 96 | } 97 | 98 | func TestMarshal_slices(t *testing.T) { 99 | a := []uint32{321321, 213, 4944} 100 | b := []string{"this", "is", "", "a", "slice"} 101 | c := []int8{100, 9, 0, -21} 102 | d := []bool{true, false, false, true, true} 103 | buf := MarshalBinary(slices.Clone(a), slices.Clone(b), 104 | slices.Clone(c), slices.Clone(d)) 105 | var ( 106 | aa []uint32 107 | bb []string 108 | cc []int8 109 | dd []bool 110 | ) 111 | err := UnmarshalBinary(buf, &aa, &bb, &cc, &dd) 112 | if err != nil { 113 | t.Fatal("UnmarshalBinary(...) 
failed:", err) 114 | } 115 | inputs := []any{a, b, c, d} 116 | outputs := []any{aa, bb, cc, dd} 117 | for i := range inputs { 118 | if !reflect.DeepEqual(inputs[i], outputs[i]) { 119 | t.Fatalf("UnmarshalBinary(...)=%v, want %v", outputs[i], inputs[i]) 120 | } 121 | } 122 | } 123 | 124 | func TestMarshal_single(t *testing.T) { 125 | testMarshalSingle(t, int8(123)) 126 | testMarshalSingle(t, uint8(123)) 127 | testMarshalSingle(t, int32(12345)) 128 | testMarshalSingle(t, uint32(12345)) 129 | testMarshalSingle(t, int(12345)) 130 | testMarshalSingle(t, uint(12345)) 131 | testMarshalSingle(t, float64(33.33)) 132 | testMarshalSingle(t, float32(33.33)) 133 | testMarshalSingle(t, "amit") 134 | } 135 | 136 | func testMarshalSingle[T comparable](t *testing.T, val T) { 137 | buf := MarshalBinary(val) 138 | var got T 139 | if err := UnmarshalBinary(buf, &got); err != nil { 140 | t.Errorf("UnmarshalBinary(%#v) failed: %s", val, err) 141 | return 142 | } 143 | if got != val { 144 | t.Errorf("UnmarshalBinary(%#v)=%#v, want %#v", val, got, val) 145 | } 146 | } 147 | -------------------------------------------------------------------------------- /bloom/bloom_test.go: -------------------------------------------------------------------------------- 1 | package bloom 2 | 3 | import ( 4 | "bytes" 5 | "fmt" 6 | "testing" 7 | 8 | "github.com/fluhus/gostuff/bnry" 9 | "github.com/fluhus/gostuff/gnum" 10 | ) 11 | 12 | func TestLen(t *testing.T) { 13 | tests := []struct { 14 | bits int 15 | want int 16 | }{ 17 | {1, 8}, 18 | {2, 8}, 19 | {3, 8}, 20 | {4, 8}, 21 | {5, 8}, 22 | {6, 8}, 23 | {7, 8}, 24 | {8, 8}, 25 | {9, 16}, 26 | {16, 16}, 27 | {17, 24}, 28 | } 29 | 30 | for _, test := range tests { 31 | f := New(test.bits, 1) 32 | if l := f.NBits(); l != test.want { 33 | t.Errorf("New(%v,1).Len()=%v, want %v", 34 | test.bits, l, test.want) 35 | } 36 | if f.NHash() != 1 { 37 | t.Errorf("New(%v,1).K()=%v, want 1", 38 | test.bits, f.NHash()) 39 | } 40 | } 41 | } 42 | 43 | func TestFilter(t 
*testing.T) { 44 | f := New(80, 4) 45 | data := []byte{1, 2, 3, 4} 46 | if f.Has(data) { 47 | t.Fatalf("Has(%v)=true, want false", data) 48 | } 49 | if f.Add(data) { 50 | t.Fatalf("Add(%v)=true, want false", data) 51 | } 52 | if !f.Has(data) { 53 | t.Fatalf("Has(%v)=false, want true", data) 54 | } 55 | if !f.Add(data) { 56 | t.Fatalf("Add(%v)=false, want true", data) 57 | } 58 | 59 | data2 := []byte{4, 3, 2, 1} 60 | if f.Has(data2) { 61 | t.Fatalf("Has(%v)=true, want false", data2) 62 | } 63 | } 64 | 65 | func TestNewOptimal(t *testing.T) { 66 | n := 1000000 67 | p := 0.01 68 | f := NewOptimal(n, p) 69 | t.Logf("bits=%v, k=%v", f.NBits(), f.NHash()) 70 | fp := 0 71 | for i := 0; i < n; i++ { 72 | buf := bnry.MarshalBinary(uint64(i)) 73 | if f.Add(buf) { 74 | fp++ 75 | } 76 | } 77 | if fpr := float64(fp) / float64(n); fpr > p { 78 | t.Fatalf("fp=%v, want <%v", fpr, p) 79 | } 80 | } 81 | 82 | func TestEncode(t *testing.T) { 83 | data1 := []byte{1, 2, 3, 4} 84 | data2 := []byte{4, 3, 2, 1} 85 | f1 := New(80, 4) 86 | f1.SetSeed(5678) 87 | f1.Add(data1) 88 | 89 | if !f1.Has(data1) { 90 | t.Fatalf("Has(%v)=false, want true", data1) 91 | } 92 | if f1.Has(data2) { 93 | t.Fatalf("Has(%v)=true, want false", data2) 94 | } 95 | 96 | buf := bytes.NewBuffer(nil) 97 | if err := f1.Encode(buf); err != nil { 98 | t.Fatalf("Encode(...) failed: %v", err) 99 | } 100 | f2 := &Filter{} 101 | if err := f2.Decode(buf); err != nil { 102 | t.Fatalf("Decode(...) failed: %v", err) 103 | } 104 | 105 | if !bytes.Equal(f1.b, f2.b) { 106 | t.Fatalf("Decode(...) bytes=%v, want %v", f2.b, f1.b) 107 | } 108 | if f1.seed != f2.seed { 109 | t.Fatalf("Decode(...) 
seed=%v, want %v", f2.seed, f1.seed) 110 | } 111 | 112 | if !f2.Has(data1) { 113 | t.Fatalf("Decode(...).Has(%v)=false, want true", data1) 114 | } 115 | if f2.Has(data2) { 116 | t.Fatalf("Decode(...).Has(%v)=true, want false", data2) 117 | } 118 | } 119 | 120 | func TestAddFilter(t *testing.T) { 121 | data := [][]byte{ 122 | {1, 2, 3, 4}, 123 | {5, 6, 7, 8}, 124 | {9, 10}, 125 | } 126 | f1 := New(100, 3) 127 | f2 := New(100, 3) 128 | f2.SetSeed(f1.Seed()) 129 | 130 | f1.Add(data[0]) 131 | f2.Add(data[1]) 132 | f1.AddFilter(f2) 133 | 134 | if !f1.Has(data[0]) { 135 | t.Fatalf("Has(%v)=false, want true", data[0]) 136 | } 137 | if !f1.Has(data[1]) { 138 | t.Fatalf("Has(%v)=false, want true", data[1]) 139 | } 140 | if f1.Has(data[2]) { 141 | t.Fatalf("Has(%v)=true, want false", data[2]) 142 | } 143 | } 144 | 145 | func TestNElements(t *testing.T) { 146 | bf := New(2000, 4) 147 | bf.SetSeed(0) 148 | n := 100 149 | errs := 0 150 | for i := 1; i <= n; i++ { 151 | bf.Add([]byte{byte(i)}) 152 | if gnum.Diff(bf.NElements(), i) > 1 { 153 | errs++ 154 | } 155 | } 156 | if errs > n/10 { 157 | t.Fatalf("Too many errors: %v, want %v", errs, n/10) 158 | } 159 | } 160 | 161 | func BenchmarkHas(b *testing.B) { 162 | for _, n := range []int{10, 30, 100} { 163 | for k := 1; k <= 3; k++ { 164 | b.Run(fmt.Sprintf("n=%v,k=%v", n, k), func(b *testing.B) { 165 | f := New(1000000, k) 166 | buf := make([]byte, n) 167 | f.Add(buf) 168 | b.ResetTimer() 169 | for i := 0; i < b.N; i++ { 170 | f.Has(buf) 171 | } 172 | }) 173 | } 174 | } 175 | } 176 | 177 | func BenchmarkAdd(b *testing.B) { 178 | for _, n := range []int{10, 30, 100} { 179 | for k := 1; k <= 3; k++ { 180 | b.Run(fmt.Sprintf("n=%v,k=%v", n, k), func(b *testing.B) { 181 | f := New(1000000, k) 182 | buf := make([]byte, n) 183 | b.ResetTimer() 184 | for i := 0; i < b.N; i++ { 185 | f.Add(buf) 186 | } 187 | }) 188 | } 189 | } 190 | } 191 | -------------------------------------------------------------------------------- /graphs/graph.go: 
-------------------------------------------------------------------------------- 1 | // Package graphs implements a simple graph. 2 | package graphs 3 | 4 | import ( 5 | "cmp" 6 | "iter" 7 | "slices" 8 | 9 | "github.com/fluhus/gostuff/sets" 10 | "github.com/fluhus/gostuff/snm" 11 | ) 12 | 13 | // Graph is a simple graph. 14 | // It contains a set of vertices and a set of edges, 15 | // which are pairs of vertices. 16 | type Graph[T comparable] struct { 17 | v snm.Enumerator[T] // Value to ID. 18 | e sets.Set[[2]int] // Pairs of IDs. 19 | } 20 | 21 | // New returns an empty graph. 22 | func New[T comparable]() *Graph[T] { 23 | return &Graph[T]{snm.Enumerator[T]{}, sets.Set[[2]int]{}} 24 | } 25 | 26 | // AddVertices adds the given values as vertices. 27 | // Values that already exist are ignored. 28 | func (g *Graph[T]) AddVertices(t ...T) { 29 | for _, v := range t { 30 | g.v.IndexOf(v) 31 | } 32 | } 33 | 34 | // NumVertices returns the current number of vertices. 35 | func (g *Graph[T]) NumVertices() int { 36 | return len(g.v) 37 | } 38 | 39 | // NumEdges returns the current number of edges. 40 | func (g *Graph[T]) NumEdges() int { 41 | return len(g.e) 42 | } 43 | 44 | // Edges iterates over current set of edges. 45 | func (g *Graph[T]) Edges() iter.Seq2[T, T] { 46 | return func(yield func(T, T) bool) { 47 | flat := g.v.Elements() 48 | for e := range g.e { 49 | if !yield(flat[e[0]], flat[e[1]]) { 50 | return 51 | } 52 | } 53 | } 54 | } 55 | 56 | // Vertices iterates over current set of vertices, 57 | // by order of addition to the graph. 58 | func (g *Graph[T]) Vertices() iter.Seq[T] { 59 | return func(yield func(T) bool) { 60 | for _, x := range g.v.Elements() { 61 | if !yield(x) { 62 | return 63 | } 64 | } 65 | } 66 | } 67 | 68 | // HasEdge returns whether there is an edge between a and b. 69 | func (g *Graph[T]) HasEdge(a, b T) bool { 70 | return g.e.Has(g.toEdge(a, b)) 71 | } 72 | 73 | // AddEdge adds a and b to the vertex set and adds an edge between them. 
74 | // The edge is undirected, meaning that AddEdge(a,b) is equivalent 75 | // to AddEdge(b,a). 76 | // If the edge already exists, this is a no-op. 77 | func (g *Graph[T]) AddEdge(a, b T) { 78 | g.e.Add(g.toEdge(a, b)) 79 | } 80 | 81 | // DeleteEdge removes the edge between a and b, 82 | // while keeping them in the vertex set. 83 | func (g *Graph[T]) DeleteEdge(a, b T) { 84 | delete(g.e, g.toEdge(a, b)) 85 | } 86 | 87 | // ConnectedComponents returns a slice of connected components. 88 | // In each component, the elements are ordered by order of addition to the 89 | // graph. 90 | // The components are ordered by the order of addition of their 91 | // first elements. 92 | func (g *Graph[T]) ConnectedComponents() [][]T { 93 | edges := g.edgeSlices() 94 | m := snm.Slice(g.NumVertices(), func(i int) int { return -1 }) 95 | queue := &snm.Queue[int]{} 96 | 97 | for i := range g.NumVertices() { 98 | if m[i] != -1 { 99 | continue 100 | } 101 | m[i] = i 102 | queue.Enqueue(i) 103 | for queue.Len() > 0 { 104 | e := queue.Dequeue() 105 | for _, j := range edges[e] { 106 | if m[j] == -1 { 107 | m[j] = i 108 | queue.Enqueue(j) 109 | } 110 | } 111 | edges[e] = nil 112 | } 113 | } 114 | 115 | comps := map[int][]int{} 116 | for k, v := range m { 117 | comps[v] = append(comps[v], k) 118 | } 119 | poncs := make([][]int, 0, len(comps)) 120 | for _, v := range comps { 121 | poncs = append(poncs, snm.Sorted(v)) 122 | } 123 | slices.SortFunc(poncs, func(a, b []int) int { 124 | return cmp.Compare(a[0], b[0]) 125 | }) 126 | 127 | // Convert indices to vertex values. 128 | i2v := g.v.Elements() 129 | return snm.Slice(len(poncs), func(i int) []T { 130 | return snm.Slice(len(poncs[i]), func(j int) T { 131 | return i2v[poncs[i][j]] 132 | }) 133 | }) 134 | } 135 | 136 | // Returns (without adding) an edge between a and b. 
137 | func (g *Graph[T]) toEdge(a, b T) [2]int { 138 | ia, ib := g.v.IndexOf(a), g.v.IndexOf(b) 139 | if ia > ib { 140 | return [2]int{ib, ia} 141 | } 142 | return [2]int{ia, ib} 143 | } 144 | 145 | // Returns a slice representation of this graph's edges. 146 | func (g *Graph[T]) edgeSlices() [][]int { 147 | // Pre-allocate slices. 148 | counts := make([]int, g.NumVertices()) 149 | for e := range g.e { 150 | counts[e[0]]++ 151 | counts[e[1]]++ 152 | } 153 | edges := snm.Slice(len(counts), func(i int) []int { 154 | return make([]int, 0, counts[i]) 155 | }) 156 | 157 | // Populate with values. 158 | for e := range g.e { 159 | edges[e[0]] = append(edges[e[0]], e[1]) 160 | edges[e[1]] = append(edges[e[1]], e[0]) 161 | } 162 | return edges 163 | } 164 | -------------------------------------------------------------------------------- /gnum/vecs.go: -------------------------------------------------------------------------------- 1 | package gnum 2 | 3 | import ( 4 | "fmt" 5 | "math" 6 | 7 | "golang.org/x/exp/constraints" 8 | ) 9 | 10 | // L1 returns the L1 (Manhattan) distance between a and b. 11 | // Equivalent to Lp(1) but returns the same type. 12 | func L1[S ~[]N, N Number](a, b S) N { 13 | assertMatchingLengths(a, b) 14 | var sum N 15 | for i := range a { 16 | sum += Diff(a[i], b[i]) 17 | } 18 | return sum 19 | } 20 | 21 | // L2 returns the L2 (Euclidean) distance between a and b. 22 | // Equivalent to Lp(2). 23 | func L2[S ~[]N, N Number](a, b S) float64 { 24 | assertMatchingLengths(a, b) 25 | var sum N 26 | for i := range a { 27 | d := (a[i] - b[i]) 28 | sum += d * d 29 | } 30 | return math.Sqrt(float64(sum)) 31 | } 32 | 33 | // Lp returns an Lp distance function. 
Lp is calculated as follows: 34 | // 35 | // Lp(v) = (sum_i(v[i]^p))^(1/p) 36 | func Lp[S ~[]N, N Number](p int) func(S, S) float64 { 37 | if p < 1 { 38 | panic(fmt.Sprintf("invalid p: %d", p)) 39 | } 40 | 41 | if p == 1 { 42 | return func(a, b S) float64 { 43 | return float64(L1(a, b)) 44 | } 45 | } 46 | if p == 2 { 47 | return L2[S, N] 48 | } 49 | 50 | return func(a, b S) float64 { 51 | assertMatchingLengths(a, b) 52 | fp := float64(p) 53 | var sum float64 54 | for i := range a { 55 | sum += math.Pow(float64(Diff(a[i], b[i])), fp) 56 | } 57 | return math.Pow(sum, 1/fp) 58 | } 59 | } 60 | 61 | // Add adds b to a and returns a. b is unchanged. If a is nil, creates a new 62 | // vector. 63 | func Add[S ~[]N, N Number](a S, b ...S) S { 64 | if a == nil { 65 | if len(b) == 0 { 66 | return nil 67 | } 68 | a = make(S, len(b[0])) 69 | } 70 | for i := range b { 71 | assertMatchingLengths(a, b[i]) 72 | for j := range a { 73 | a[j] += b[i][j] 74 | } 75 | } 76 | return a 77 | } 78 | 79 | // Sub subtracts b from a and returns a. b is unchanged. If a is nil, creates a 80 | // new vector. 81 | func Sub[S ~[]N, N Number](a S, b ...S) S { 82 | if a == nil { 83 | if len(b) == 0 { 84 | return nil 85 | } 86 | a = make(S, len(b[0])) 87 | } 88 | for i := range b { 89 | assertMatchingLengths(a, b[i]) 90 | for j := range a { 91 | a[j] -= b[i][j] 92 | } 93 | } 94 | return a 95 | } 96 | 97 | // Mul multiplies a by b and returns a. b is unchanged. If a is nil, creates a 98 | // new vector. 99 | func Mul[S ~[]N, N Number](a S, b ...S) S { 100 | if a == nil { 101 | if len(b) == 0 { 102 | return nil 103 | } 104 | a = Ones[S](len(b[0])) 105 | } 106 | for i := range b { 107 | assertMatchingLengths(a, b[i]) 108 | for j := range a { 109 | a[j] -= b[i][j] 110 | } 111 | } 112 | return a 113 | } 114 | 115 | // Add1 adds m to a and returns a. 
116 | func Add1[S ~[]N, N Number](a S, m N) S { 117 | for i := range a { 118 | a[i] += m 119 | } 120 | return a 121 | } 122 | 123 | // Sub1 subtracts m from a and returns a. 124 | func Sub1[S ~[]N, N Number](a S, m N) S { 125 | for i := range a { 126 | a[i] -= m 127 | } 128 | return a 129 | } 130 | 131 | // Mul1 multiplies the values of a by m and returns a. 132 | func Mul1[S ~[]N, N Number](a S, m N) S { 133 | for i := range a { 134 | a[i] *= m 135 | } 136 | return a 137 | } 138 | 139 | // Dot returns the dot product of the input vectors. 140 | func Dot[S ~[]N, N Number](a, b S) N { 141 | assertMatchingLengths(a, b) 142 | var sum N 143 | for i := range a { 144 | sum += a[i] * b[i] 145 | } 146 | return sum 147 | } 148 | 149 | // Norm returns the L2 norm of the vector. 150 | func Norm[S ~[]N, N constraints.Float](a S) float64 { 151 | var norm N 152 | for _, v := range a { 153 | norm += v * v 154 | } 155 | return math.Sqrt(float64(norm)) 156 | } 157 | 158 | // Ones returns a slice of n ones. Panics if n is negative. 159 | func Ones[S ~[]N, N Number](n int) S { 160 | if n < 0 { 161 | panic(fmt.Sprintf("bad vector length: %d", n)) 162 | } 163 | a := make(S, n) 164 | for i := range a { 165 | a[i] = 1 166 | } 167 | return a 168 | } 169 | 170 | // Copy returns a copy of the given slice. 171 | func Copy[S ~[]N, N any](a S) S { 172 | result := make(S, len(a)) 173 | copy(result, a) 174 | return result 175 | } 176 | 177 | // Cast casts the values of a and places them in a new slice. 178 | func Cast[S ~[]N, T ~[]M, N Number, M Number](a S) T { 179 | t := make(T, len(a)) 180 | for i, s := range a { 181 | t[i] = M(s) 182 | } 183 | return t 184 | } 185 | 186 | // Panics if the input vectors are of different lengths. 
187 | func assertMatchingLengths[S ~[]N, N any](a, b S) { 188 | if len(a) != len(b) { 189 | panic(fmt.Sprintf("mismatching lengths: %d, %d", len(a), len(b))) 190 | } 191 | } 192 | -------------------------------------------------------------------------------- /aio/open.go: -------------------------------------------------------------------------------- 1 | // Package aio provides buffered file I/O. 2 | package aio 3 | 4 | import ( 5 | "bufio" 6 | "compress/bzip2" 7 | "compress/gzip" 8 | "io" 9 | "os" 10 | "path/filepath" 11 | 12 | "github.com/klauspost/compress/zstd" 13 | ) 14 | 15 | const ( 16 | gzipSupport = true // If true, .gz files are automatically compressed/decompressed. 17 | zstdSupport = true // If true, .zst files are automatically compressed/decompressed. 18 | bzipSupport = true // If true, .bz2 files are automatically decompressed. 19 | ) 20 | 21 | // OpenRaw opens a file for reading, with a buffer. 22 | func OpenRaw(file string) (*Reader, error) { 23 | f, err := os.Open(file) 24 | if err != nil { 25 | return nil, err 26 | } 27 | return &Reader{*bufio.NewReader(f), f}, nil 28 | } 29 | 30 | // CreateRaw opens a file for writing, with a buffer. 31 | // Erases any previously existing content. 32 | func CreateRaw(file string) (*Writer, error) { 33 | f, err := os.Create(file) 34 | if err != nil { 35 | return nil, err 36 | } 37 | return &Writer{*bufio.NewWriter(f), f}, nil 38 | } 39 | 40 | // AppendRaw opens a file for writing, with a buffer. 41 | // Appends to previously existing content if any. 42 | func AppendRaw(file string) (*Writer, error) { 43 | f, err := os.OpenFile(file, os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0o644) 44 | if err != nil { 45 | return nil, err 46 | } 47 | return &Writer{*bufio.NewWriter(f), f}, nil 48 | } 49 | 50 | var ( 51 | rsuffixes = map[string]func(io.Reader) (io.Reader, error){} 52 | wsuffixes = map[string]func(io.WriteCloser) (io.WriteCloser, error){} 53 | ) 54 | 55 | // Open opens a file for reading, with a buffer. 
56 | // Decompresses the data according to the file's suffix. 57 | func Open(file string) (*Reader, error) { 58 | f, err := OpenRaw(file) 59 | if err != nil { 60 | return nil, err 61 | } 62 | fn := rsuffixes[filepath.Ext(file)] 63 | if fn == nil { 64 | return f, nil 65 | } 66 | ff, err := fn(f) 67 | if err != nil { 68 | return nil, err 69 | } 70 | return &Reader{*bufio.NewReader(ff), f}, nil 71 | } 72 | 73 | // Create opens a file for writing, with a buffer. 74 | // Erases any previously existing content. 75 | // Compresses the data according to the file's suffix. 76 | func Create(file string) (*Writer, error) { 77 | f, err := CreateRaw(file) 78 | if err != nil { 79 | return nil, err 80 | } 81 | fn := wsuffixes[filepath.Ext(file)] 82 | if fn == nil { 83 | return f, nil 84 | } 85 | ff, err := fn(f) 86 | if err != nil { 87 | return nil, err 88 | } 89 | wrapper := &writerWrapper{ff, f} 90 | return &Writer{*bufio.NewWriter(ff), wrapper}, nil 91 | } 92 | 93 | // Append opens a file for writing, with a buffer. 94 | // Appends to previously existing content if any. 95 | // Compresses the data according to the file's suffix. 96 | func Append(file string) (*Writer, error) { 97 | f, err := AppendRaw(file) 98 | if err != nil { 99 | return nil, err 100 | } 101 | fn := wsuffixes[filepath.Ext(file)] 102 | if fn == nil { 103 | return f, nil 104 | } 105 | ff, err := fn(f) 106 | if err != nil { 107 | return nil, err 108 | } 109 | wrapper := &writerWrapper{ff, f} 110 | return &Writer{*bufio.NewWriter(ff), wrapper}, nil 111 | } 112 | 113 | // AddReadSuffix adds a supported suffix for automatic decompression. 114 | // suffix should include the dot. f should take a raw reader and return a reader 115 | // that decompresses the data. 116 | func AddReadSuffix(suffix string, f func(io.Reader) (io.Reader, error)) { 117 | rsuffixes[suffix] = f 118 | } 119 | 120 | // AddWriteSuffix adds a supported suffix for automatic compression. 121 | // suffix should include the dot. 
f should take a raw writer and return a writer 122 | // that compresses the data. 123 | func AddWriteSuffix(suffix string, f func(io.WriteCloser) ( 124 | io.WriteCloser, error)) { 125 | wsuffixes[suffix] = f 126 | } 127 | 128 | func init() { 129 | if gzipSupport { 130 | AddReadSuffix(".gz", func(r io.Reader) (io.Reader, error) { 131 | return gzip.NewReader(r) 132 | }) 133 | AddWriteSuffix(".gz", func(w io.WriteCloser) (io.WriteCloser, error) { 134 | return gzip.NewWriterLevel(w, 1) 135 | }) 136 | } 137 | if bzipSupport { 138 | AddReadSuffix(".bz2", func(r io.Reader) (io.Reader, error) { 139 | return bzip2.NewReader(r), nil 140 | }) 141 | } 142 | if zstdSupport { 143 | AddReadSuffix(".zst", func(r io.Reader) (io.Reader, error) { 144 | return zstd.NewReader(r, zstd.WithDecoderConcurrency(1)) 145 | }) 146 | AddWriteSuffix(".zst", func(w io.WriteCloser) (io.WriteCloser, error) { 147 | return zstd.NewWriter(w, zstd.WithEncoderConcurrency(1)) 148 | }) 149 | } 150 | } 151 | -------------------------------------------------------------------------------- /rhash/common_test.go: -------------------------------------------------------------------------------- 1 | // A generic test suite for rolling hashes. 2 | 3 | package rhash 4 | 5 | import ( 6 | "crypto/rand" 7 | "hash" 8 | "testing" 9 | 10 | "github.com/fluhus/gostuff/sets" 11 | ) 12 | 13 | // Runs the test suite for a hash64. 
14 | func test64(t *testing.T, f func(n int) hash.Hash64) { 15 | t.Run("basic", func(t *testing.T) { test64basic(t, f) }) 16 | t.Run("cyclic", func(t *testing.T) { test64cyclic(t, f) }) 17 | t.Run("big-n", func(t *testing.T) { test64bigN(t, f) }) 18 | } 19 | 20 | func test64basic(t *testing.T, f func(n int) hash.Hash64) { 21 | data := []byte("amitamit") 22 | tests := []struct { 23 | n int 24 | wantSize []int 25 | wantEq []int 26 | }{ 27 | {2, []int{1, 2, 3, 4, 5, 5, 5, 5}, []int{-1, -1, -1, -1, -1, 1, 2, 3}}, 28 | {3, []int{1, 2, 3, 4, 5, 6, 6, 6}, []int{-1, -1, -1, -1, -1, -1, 2, 3}}, 29 | } 30 | for _, test := range tests { 31 | slice := []uint64{} 32 | set := sets.Set[uint64]{} 33 | h := f(test.n) 34 | for i := range data { 35 | h.Write(data[i : i+1]) 36 | slice = append(slice, h.Sum64()) 37 | set.Add(h.Sum64()) 38 | if len(set) != test.wantSize[i] { 39 | t.Fatalf("n=%d #%d: set size=%v, want %v", 40 | test.n, i, len(set), test.wantSize[i]) 41 | } 42 | if test.wantEq[i] != -1 && h.Sum64() != slice[test.wantEq[i]] { 43 | t.Fatalf("n=%d #%d: Sum64()=%d, want %d", 44 | test.n, i, h.Sum64(), slice[test.wantEq[i]]) 45 | } 46 | } 47 | } 48 | } 49 | 50 | func test64cyclic(t *testing.T, f func(n int) hash.Hash64) { 51 | inputs := []string{ 52 | "asdjadasdk", 53 | "uioewrmnoc", 54 | "wiewuwikxa", 55 | "mfhddl/lcc", 56 | "28n9789dkd", 57 | } 58 | h := f(10) 59 | for _, input := range inputs { 60 | h.Write([]byte(input)) 61 | h2 := f(10) 62 | h2.Write([]byte(input)) 63 | got, want := h.Sum64(), h2.Sum64() 64 | if got != want { 65 | t.Fatalf("Sum64(%q)=%v, want %v", input, got, want) 66 | } 67 | } 68 | } 69 | 70 | func test64bigN(t *testing.T, f func(n int) hash.Hash64) { 71 | const n = 100 72 | 73 | // Create random input. 74 | buf := make([]byte, n) 75 | _, err := rand.Read(buf) 76 | if err != nil { 77 | t.Fatalf("rand.Read() failed: %v", err) 78 | } 79 | 80 | // Repeat 3 times. 81 | buf = append(buf, buf...) 82 | buf = append(buf, buf...) 
83 | 84 | h := f(n) 85 | hashes := sets.Set[uint64]{} 86 | for i := range buf { 87 | h.Write(buf[i : i+1]) 88 | hashes.Add(h.Sum64()) 89 | want := min(i+1, n*2-1) 90 | if len(hashes) != want { 91 | t.Fatalf("got %d unique hashes, want %d", len(hashes), want) 92 | } 93 | } 94 | } 95 | 96 | // Runs the test suite for a hash32. 97 | func test32(t *testing.T, f func(n int) hash.Hash32) { 98 | t.Run("basic", func(t *testing.T) { test32basic(t, f) }) 99 | t.Run("cyclic", func(t *testing.T) { test32cyclic(t, f) }) 100 | t.Run("big-n", func(t *testing.T) { test32bigN(t, f) }) 101 | } 102 | 103 | func test32basic(t *testing.T, f func(n int) hash.Hash32) { 104 | data := []byte("amitamit") 105 | tests := []struct { 106 | n int 107 | wantSize []int 108 | wantEq []int 109 | }{ 110 | {2, []int{1, 2, 3, 4, 5, 5, 5, 5}, []int{-1, -1, -1, -1, -1, 1, 2, 3}}, 111 | {3, []int{1, 2, 3, 4, 5, 6, 6, 6}, []int{-1, -1, -1, -1, -1, -1, 2, 3}}, 112 | } 113 | for _, test := range tests { 114 | slice := []uint32{} 115 | set := sets.Set[uint32]{} 116 | h := f(test.n) 117 | for i := range data { 118 | h.Write(data[i : i+1]) 119 | slice = append(slice, h.Sum32()) 120 | set.Add(h.Sum32()) 121 | if len(set) != test.wantSize[i] { 122 | t.Fatalf("n=%d #%d: set size=%v, want %v", 123 | test.n, i, len(set), test.wantSize[i]) 124 | } 125 | if test.wantEq[i] != -1 && h.Sum32() != slice[test.wantEq[i]] { 126 | t.Fatalf("n=%d #%d: Sum32()=%d, want %d", 127 | test.n, i, h.Sum32(), slice[test.wantEq[i]]) 128 | } 129 | } 130 | } 131 | } 132 | 133 | func test32cyclic(t *testing.T, f func(n int) hash.Hash32) { 134 | inputs := []string{ 135 | "asdjadasdk", 136 | "uioewrmnoc", 137 | "wiewuwikxa", 138 | "mfhddl/lcc", 139 | "28n9789dkd", 140 | } 141 | h := f(10) 142 | for _, input := range inputs { 143 | h.Write([]byte(input)) 144 | h2 := f(10) 145 | h2.Write([]byte(input)) 146 | got, want := h.Sum32(), h2.Sum32() 147 | if got != want { 148 | t.Fatalf("Sum32(%q)=%v, want %v", input, got, want) 149 | } 150 | } 151 | 
} 152 | 153 | func test32bigN(t *testing.T, f func(n int) hash.Hash32) { 154 | const n = 100 155 | 156 | // Create random input. 157 | buf := make([]byte, n) 158 | _, err := rand.Read(buf) 159 | if err != nil { 160 | t.Fatalf("rand.Read() failed: %v", err) 161 | } 162 | 163 | // Repeat 3 times. 164 | buf = append(buf, buf...) 165 | buf = append(buf, buf...) 166 | 167 | h := f(n) 168 | hashes := sets.Set[uint32]{} 169 | for i := range buf { 170 | h.Write(buf[i : i+1]) 171 | hashes.Add(h.Sum32()) 172 | want := min(i+1, n*2-1) 173 | if len(hashes) != want { 174 | t.Fatalf("got %d unique hashes, want %d", len(hashes), want) 175 | } 176 | } 177 | } 178 | -------------------------------------------------------------------------------- /snm/snm_test.go: -------------------------------------------------------------------------------- 1 | package snm 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | "slices" 7 | "testing" 8 | 9 | "golang.org/x/exp/maps" 10 | ) 11 | 12 | func TestSlice(t *testing.T) { 13 | want := []int{1, 4, 9, 16} 14 | got := Slice(4, func(i int) int { return (i + 1) * (i + 1) }) 15 | if !slices.Equal(got, want) { 16 | t.Fatalf("Slice((i+1)*(i+1))=%v, want %v", got, want) 17 | } 18 | } 19 | 20 | func TestSliceToSlice(t *testing.T) { 21 | input := []int{1, 4, 9, 16} 22 | want := []float64{1.5, 4.5, 9.5, 16.5} 23 | got := SliceToSlice(input, func(i int) float64 { 24 | return float64(i) + 0.5 25 | }) 26 | if !slices.Equal(got, want) { 27 | t.Fatalf("SliceToSlice(%v)=%v, want %v", input, got, want) 28 | } 29 | } 30 | 31 | func TestMapToMap(t *testing.T) { 32 | input := map[string]string{"a": "bbb", "cccc": "ddddddd"} 33 | want := map[int]int{1: 3, 4: 7} 34 | got := MapToMap(input, func(k, v string) (int, int) { 35 | return len(k), len(v) 36 | }) 37 | if !maps.Equal(got, want) { 38 | t.Fatalf("MapToMap(%v)=%v, want %v", input, got, want) 39 | } 40 | } 41 | 42 | func TestMapToMap_equalKeys(t *testing.T) { 43 | input := map[string]string{"a": "bbb", "cccc": 
"ddddddd", "e": "ff"} 44 | want1 := map[int]int{1: 3, 4: 7} 45 | want2 := map[int]int{1: 2, 4: 7} 46 | got := MapToMap(input, func(k, v string) (int, int) { 47 | return len(k), len(v) 48 | }) 49 | if !maps.Equal(got, want1) && !maps.Equal(got, want2) { 50 | t.Fatalf("MapToMap(%v)=%v, want %v or %v", input, got, want1, want2) 51 | } 52 | } 53 | 54 | func TestDefaultMap(t *testing.T) { 55 | m := NewDefaultMap[int, string](func(i int) string { 56 | return fmt.Sprint(i + 1) 57 | }) 58 | if got, want := m.Get(2), "3"; got != want { 59 | t.Fatalf("Get(%d)=%s, want %s", 2, got, want) 60 | } 61 | if got, want := m.Get(6), "7"; got != want { 62 | t.Fatalf("Get(%d)=%s, want %s", 6, got, want) 63 | } 64 | m.Set(2, "a") 65 | if got, want := m.Get(2), "a"; got != want { 66 | t.Fatalf("Get(%d)=%s, want %s", 2, got, want) 67 | } 68 | if got, want := m.Get(6), "7"; got != want { 69 | t.Fatalf("Get(%d)=%s, want %s", 6, got, want) 70 | } 71 | if got, want := len(m.M), 2; got != want { 72 | t.Fatalf("Len=%d, want %d", got, want) 73 | } 74 | } 75 | 76 | func TestCompareReverse(t *testing.T) { 77 | input := []int{3, 4, 2, 1, 5} 78 | want := []int{5, 4, 3, 2, 1} 79 | 80 | cp := slices.Clone(input) 81 | slices.SortFunc(cp, CompareReverse) 82 | if !slices.Equal(cp, want) { 83 | t.Errorf("SortFunc(%v, Compare)=%v, want %v", 84 | input, cp, want) 85 | } 86 | } 87 | 88 | func ExampleSortedKeys() { 89 | ages := map[string]int{ 90 | "Alice": 30, 91 | "Bob": 20, 92 | "Charlie": 25, 93 | } 94 | for _, name := range SortedKeys(ages) { 95 | fmt.Printf("%s: %d\n", name, ages[name]) 96 | } 97 | // Output: 98 | // Bob: 20 99 | // Charlie: 25 100 | // Alice: 30 101 | } 102 | 103 | func ExampleSortedKeysFunc_reverse() { 104 | ages := map[string]int{ 105 | "Alice": 30, 106 | "Bob": 20, 107 | "Charlie": 25, 108 | } 109 | // Sort by reverse natural order. 
110 | for _, name := range SortedKeysFunc(ages, CompareReverse) { 111 | fmt.Printf("%s: %d\n", name, ages[name]) 112 | } 113 | // Output: 114 | // Alice: 30 115 | // Charlie: 25 116 | // Bob: 20 117 | } 118 | 119 | func TestEnumerator(t *testing.T) { 120 | tests := []struct { 121 | i, want int 122 | }{ 123 | {6, 0}, {3, 1}, {6, 0}, {2, 2}, {3, 1}, {10, 3}, {10, 3}, {2, 2}, 124 | {6, 0}, {3, 1}, 125 | } 126 | e := Enumerator[int]{} 127 | for _, test := range tests { 128 | if got := e.IndexOf(test.i); got != test.want { 129 | t.Fatalf("%v.IndexOf(%v)=%v, want %v", e, test.i, got, test.want) 130 | } 131 | } 132 | 133 | wantElem := []int{6, 3, 2, 10} 134 | if got := e.Elements(); !slices.Equal(got, wantElem) { 135 | t.Fatalf("%v.Elements()=%v, want %v", e, got, wantElem) 136 | } 137 | } 138 | 139 | func ExampleCapMap() { 140 | data := [][]string{ 141 | {"a", "b", "c", "a", "b", "b"}, 142 | // ... 143 | } 144 | counter := NewCapMap[string, int]() 145 | for _, x := range data { 146 | m := counter.Map() 147 | countValues(x, m) 148 | 149 | // Do something with m. 
150 | j, _ := json.Marshal(m) 151 | fmt.Println(string(j)) 152 | counter.Clear() 153 | } 154 | //Output: 155 | //{"a":2,"b":3,"c":1} 156 | } 157 | 158 | func countValues(vals []string, out map[string]int) { 159 | for _, v := range vals { 160 | out[v]++ 161 | } 162 | } 163 | 164 | func TestShuffle(t *testing.T) { 165 | nums := Slice(10, func(i int) int { return i }) 166 | found := make([]bool, len(nums)) 167 | counts := Slice(len(nums), func(i int) []int { return make([]int, len(nums)) }) 168 | for range 1000 { 169 | Shuffle(nums) 170 | clear(found) 171 | for i, x := range nums { 172 | found[x] = true 173 | counts[i][x]++ 174 | } 175 | for i, f := range found { 176 | if !f { 177 | t.Fatalf("did not find %v: %v", i, nums) 178 | } 179 | } 180 | } 181 | for i, c := range counts { 182 | for j, x := range c { 183 | if x < 70 { 184 | t.Errorf("count of %v at position %v: %v, want >%v", 185 | j, i, x, 70) 186 | } 187 | } 188 | } 189 | } 190 | -------------------------------------------------------------------------------- /prefixtree/tree.go: -------------------------------------------------------------------------------- 1 | // Package prefixtree provides a basic prefix tree implementation. 2 | // 3 | // Add, Has, HasPrefix, Delete and DeletePrefix are linear in query length, 4 | // regardless of the size of the tree. 5 | // All operations are non-recursive. 6 | package prefixtree 7 | 8 | import ( 9 | "iter" 10 | "slices" 11 | 12 | "golang.org/x/exp/maps" 13 | ) 14 | 15 | // A Tree is a prefix tree. 16 | // 17 | // A zero value tree is invalid; use New to create a new instance. 18 | type Tree struct { 19 | isElem bool // Is this node an element in the tree. 20 | m map[byte]*Tree 21 | } 22 | 23 | // New returns an empty tree. 24 | func New() *Tree { 25 | return &Tree{m: map[byte]*Tree{}} 26 | } 27 | 28 | // Add inserts x to the tree. 29 | // If a was already added, the tree is unchanged. 
30 | func (t *Tree) Add(x []byte) { 31 | cur := t 32 | for _, b := range x { 33 | next := cur.m[b] 34 | if next == nil { 35 | next = New() 36 | cur.m[b] = next 37 | } 38 | cur = next 39 | } 40 | cur.isElem = true 41 | } 42 | 43 | // Has returns whether x was added to the tree. 44 | func (t *Tree) Has(x []byte) bool { 45 | cur := t 46 | for _, b := range x { 47 | next := cur.m[b] 48 | if next == nil { 49 | return false 50 | } 51 | cur = next 52 | } 53 | return cur.isElem 54 | } 55 | 56 | // HasPrefix returns whether x is a prefix of an element in the tree. 57 | func (t *Tree) HasPrefix(x []byte) bool { 58 | cur := t 59 | for _, b := range x { 60 | next := cur.m[b] 61 | if next == nil { 62 | return false 63 | } 64 | cur = next 65 | } 66 | return true 67 | } 68 | 69 | // FindPrefixes returns all the elements in the tree 70 | // that are prefixes of x, ordered by length. 71 | func (t *Tree) FindPrefixes(x []byte) [][]byte { 72 | cur := t 73 | var result [][]byte 74 | if cur.isElem { 75 | result = append(result, x[:0]) 76 | } 77 | for i, b := range x { 78 | cur = cur.m[b] 79 | if cur == nil { 80 | break 81 | } 82 | if cur.isElem { 83 | result = append(result, x[:i+1]) 84 | } 85 | } 86 | return result 87 | } 88 | 89 | // Delete removes x from the tree, if possible. 90 | // Returns the result of Has(a) before deletion. 91 | func (t *Tree) Delete(x []byte) bool { 92 | // Delve in and create a stack. 93 | stack := make([]*Tree, len(x)) 94 | cur := t 95 | for i := range x { 96 | stack[i] = cur 97 | cur = cur.m[x[i]] 98 | if cur == nil { 99 | return false 100 | } 101 | } 102 | if !cur.isElem { 103 | return false 104 | } 105 | cur.isElem = false 106 | if len(cur.m) > 0 { 107 | return true 108 | } 109 | 110 | // Go back and delete nodes. 111 | for i := len(stack) - 1; i >= 0; i-- { 112 | delete(stack[i].m, x[i]) 113 | if len(stack[i].m) > 0 { 114 | // Stop deleting if node has other children. 
115 | break 116 | } 117 | } 118 | return true 119 | } 120 | 121 | // DeletePrefix removes prefix x from the tree. 122 | // All sequences that have x as their prefix are removed. 123 | // Other sequences are unchanged. 124 | func (t *Tree) DeletePrefix(x []byte) { 125 | // Length 0 mean all strings. 126 | if len(x) == 0 { 127 | clear(t.m) 128 | return 129 | } 130 | 131 | // Delve in and create a stack. 132 | stack := make([]*Tree, len(x)) 133 | cur := t 134 | for i := range x { 135 | stack[i] = cur 136 | cur = cur.m[x[i]] 137 | if cur == nil { 138 | return 139 | } 140 | } 141 | 142 | // Go back and delete nodes. 143 | for i := len(stack) - 1; i >= 0; i-- { 144 | delete(stack[i].m, x[i]) 145 | if len(stack[i].m) > 0 { 146 | // Stop deleting if node has other children. 147 | break 148 | } 149 | } 150 | } 151 | 152 | // Iter iterates over the elements of t, 153 | // in no particular order. 154 | // The tree should not be modified during iteration. 155 | func (t *Tree) Iter() iter.Seq[[]byte] { 156 | return func(yield func([]byte) bool) { 157 | stack := []*iterStep{{t, maps.Keys(t.m)}} 158 | var cur []byte 159 | if t.isElem && !yield(cur) { 160 | return 161 | } 162 | for { 163 | step := stack[len(stack)-1] 164 | if len(step.k) == 0 { // Finished with this branch. 165 | stack = stack[:len(stack)-1] 166 | if len(stack) == 0 { // Done. 167 | break 168 | } 169 | cur = cur[:len(cur)-1] 170 | continue 171 | } 172 | // Handle next child. 173 | key := step.k[0] 174 | child := step.t.m[key] 175 | stack = append(stack, &iterStep{child, maps.Keys(child.m)}) 176 | step.k = step.k[1:] 177 | cur = append(cur, key) 178 | if child.isElem && !yield(slices.Clone(cur)) { 179 | return 180 | } 181 | } 182 | } 183 | } 184 | 185 | // IterPrefix iterates the elements in the tree that have x as their prefix. 186 | // Calling IterPrefix with a zero-length slice is equivalent to Iter(). 
187 | func (t *Tree) IterPrefix(x []byte) iter.Seq[[]byte] { 188 | return func(yield func([]byte) bool) { 189 | cur := t 190 | for i := range x { 191 | cur = cur.m[x[i]] 192 | if cur == nil { 193 | return 194 | } 195 | } 196 | for y := range cur.Iter() { 197 | if !yield(slices.Concat(x, y)) { 198 | return 199 | } 200 | } 201 | } 202 | } 203 | 204 | // A step in the iteration stack. 205 | type iterStep struct { 206 | t *Tree // Tree to process 207 | k []byte // Tree's remaining children 208 | } 209 | -------------------------------------------------------------------------------- /minhash/minhash_test.go: -------------------------------------------------------------------------------- 1 | package minhash 2 | 3 | import ( 4 | "fmt" 5 | "hash/crc64" 6 | "math" 7 | "math/rand" 8 | "reflect" 9 | "slices" 10 | "sort" 11 | "testing" 12 | ) 13 | 14 | func TestCollection(t *testing.T) { 15 | tests := []struct { 16 | n int 17 | input []uint64 18 | want []uint64 19 | }{ 20 | { 21 | 3, 22 | []uint64{1, 2, 2, 2, 2, 1, 1, 3, 3, 3, 1, 2, 3, 1, 3, 3, 2}, 23 | []uint64{1, 2, 3}, 24 | }, 25 | { 26 | 3, 27 | []uint64{1, 2, 3, 4, 5, 6, 7, 8, 9}, 28 | []uint64{1, 2, 3}, 29 | }, 30 | { 31 | 3, 32 | []uint64{9, 8, 7, 6, 5, 4, 3, 2, 1}, 33 | []uint64{1, 2, 3}, 34 | }, 35 | { 36 | 5, 37 | []uint64{40, 19, 55, 10, 32, 1, 100, 5, 99, 16, 16}, 38 | []uint64{1, 5, 10, 16, 19}, 39 | }, 40 | } 41 | for _, test := range tests { 42 | mh := New[uint64](test.n) 43 | for _, k := range test.input { 44 | mh.Push(k) 45 | } 46 | got := mh.View() 47 | sort.Slice(got, func(i, j int) bool { 48 | return got[i] < got[j] 49 | }) 50 | if !reflect.DeepEqual(got, test.want) { 51 | t.Errorf("New(%d).Push(%v)=%v, want %v", 52 | test.n, test.input, got, test.want) 53 | } 54 | } 55 | } 56 | 57 | func TestJSON(t *testing.T) { 58 | input := New[int](5) 59 | input.Push(1) 60 | input.Push(4) 61 | input.Push(9) 62 | input.Push(16) 63 | input.Push(25) 64 | input.Push(36) 65 | jsn, err := input.MarshalJSON() 66 | if err != 
nil { 67 | t.Fatalf("MinHash(1,4,9,16,25,36).MarshalJSON() failed: %v", err) 68 | } 69 | got := New[int](2) 70 | err = got.UnmarshalJSON(jsn) 71 | if err != nil { 72 | t.Fatalf("UnmarshalJSON(%q) failed: %v", jsn, err) 73 | } 74 | if !slices.Equal(got.View(), input.View()) { 75 | t.Fatalf("UnmarshalJSON(%q)=%v, want %v", jsn, got, input) 76 | } 77 | } 78 | 79 | func TestJaccard(t *testing.T) { 80 | tests := []struct { 81 | a, b []uint64 82 | k int 83 | want float64 84 | }{ 85 | {[]uint64{1, 2, 3}, []uint64{1, 2, 3}, 3, 1}, 86 | {[]uint64{1, 2, 3}, []uint64{2, 3, 4}, 3, 2.0 / 3.0}, 87 | {[]uint64{2, 3, 4}, []uint64{1, 2, 3}, 3, 2.0 / 3.0}, 88 | {[]uint64{1, 2, 3, 4, 5}, []uint64{1, 3, 5}, 5, 0.6}, 89 | } 90 | for _, test := range tests { 91 | a, b := New[uint64](test.k), New[uint64](test.k) 92 | for _, i := range test.a { 93 | a.Push(i) 94 | } 95 | for _, i := range test.b { 96 | b.Push(i) 97 | } 98 | a.Sort() 99 | b.Sort() 100 | if got := a.Jaccard(b); math.Abs(got-test.want) > 0.00001 { 101 | t.Errorf("Jaccard(%v,%v)=%f, want %f", 102 | test.a, test.b, got, test.want) 103 | } 104 | } 105 | } 106 | 107 | func TestCollection_largeInput(t *testing.T) { 108 | const k = 10000 109 | tests := []struct { 110 | from1, to1, from2, to2 int 111 | }{ 112 | {1, 75000, 25000, 100000}, 113 | {1, 60000, 40000, 60000}, 114 | {1, 60000, 20000, 60000}, 115 | {1, 40000, 40001, 60000}, 116 | } 117 | for _, test := range tests { 118 | a, b := New[uint64](k), New[uint64](k) 119 | h := crc64.New(crc64.MakeTable(crc64.ECMA)) 120 | for i := test.from1; i <= test.to1; i++ { 121 | h.Reset() 122 | fmt.Fprint(h, i) 123 | a.Push(h.Sum64()) 124 | } 125 | for i := test.from2; i <= test.to2; i++ { 126 | h.Reset() 127 | fmt.Fprint(h, i) 128 | b.Push(h.Sum64()) 129 | } 130 | a.Sort() 131 | b.Sort() 132 | want := float64(test.to1-test.from2+1) / float64( 133 | test.to2-test.from1+1) 134 | if got := a.Jaccard(b); math.Abs(got-want) > want/100 { 135 | t.Errorf("Jaccard(...)=%f, want %f", got, want) 136 
| } 137 | } 138 | } 139 | 140 | func FuzzCollection(f *testing.F) { 141 | f.Add(1, 2, 3, 4, 5, 6) 142 | f.Fuzz(func(t *testing.T, a int, b int, c int, d int, e int, f int) { 143 | col := New[int](2) 144 | col.Push(a) 145 | col.Push(b) 146 | col.Push(c) 147 | col.Push(d) 148 | col.Push(e) 149 | col.Push(f) 150 | v := col.View() // Duplicates are dropped, so fewer than 2 values may be retained. 151 | if want := min(2, len(slices.Compact(slices.Sorted(slices.Values([]int{a, b, c, d, e, f}))))); len(v) != want { 152 | t.Fatalf("len()=%d, want %d", len(v), want) 153 | } 154 | if len(v) == 2 && v[0] < v[1] { 155 | t.Errorf("v[0]=%d < v[1]=%d, want v[0] >= v[1]", v[0], v[1]) 156 | } 157 | }) 158 | } 159 | 160 | func TestFrozen(t *testing.T) { 161 | mh := New[int](3) 162 | mh.Push(27872) 163 | mh.Push(16978) 164 | mh.Push(28696) 165 | mh.Sort() 166 | 167 | fr := mh.Frozen() 168 | if !slices.Equal(mh.View(), fr.View()) { 169 | t.Fatalf("View()=%v, want %v", fr.View(), mh.View()) 170 | } 171 | 172 | mh2 := New[int](3) // Shares two of three values with mh. 173 | mh2.Push(27872) 174 | mh2.Push(16978) 175 | mh2.Push(28697) 176 | mh2.Sort() 177 | 178 | want := mh.Jaccard(mh2) 179 | got := fr.Jaccard(mh2.Frozen()) 180 | if got != want { 181 | t.Fatalf("Jaccard=%v, want %v", got, want) 182 | } 183 | } 184 | 185 | func TestFrozen_modifySort(t *testing.T) { 186 | mh := New[int](1) 187 | mh.Push(27872) 188 | mh = mh.Frozen() 189 | defer func() { 190 | recover() 191 | }() 192 | mh.Sort() 193 | t.Fatalf(".Frozen().Sort() succeeded, want panic") 194 | } 195 | 196 | func TestFrozen_modifyPush(t *testing.T) { 197 | mh := New[int](1) 198 | mh.Push(27872) 199 | mh = mh.Frozen() 200 | defer func() { 201 | recover() 202 | }() 203 | mh.Push(123) 204 | t.Fatalf(".Frozen().Push() succeeded, want panic") 205 | } 206 | 207 | func BenchmarkPush(b *testing.B) { 208 | nums := rand.Perm(b.N) 209 | mh := New[int](b.N) 210 | b.ResetTimer() 211 | for i := 0; i < b.N; i++ { 212 | mh.Push(nums[i]) 213 | } 214 | } 215 | -------------------------------------------------------------------------------- /minhash/minhash.go: -------------------------------------------------------------------------------- 1 | // Package minhash provides a min-hash collection
for approximating Jaccard 2 | // similarity. 3 | package minhash 4 | 5 | import ( 6 | "encoding/json" 7 | "fmt" 8 | "slices" 9 | 10 | "github.com/fluhus/gostuff/heaps" 11 | "github.com/fluhus/gostuff/sets" 12 | "github.com/fluhus/gostuff/snm" 13 | "golang.org/x/exp/constraints" 14 | ) 15 | 16 | // A MinHash is a min-hash collection. Retains the k lowest unique values out of all 17 | // the values that were added to it. 18 | type MinHash[T constraints.Integer] struct { 19 | h *heaps.Heap[T] // Max-heap of the retained values; its head is the largest of them 20 | s sets.Set[T] // Keeps elements unique; nil for frozen instances 21 | k int // Max size of the collection 22 | n int // Number of calls to Push 23 | } 24 | 25 | // New returns an empty collection that stores k values. Panics if k is not positive. 26 | func New[T constraints.Integer](k int) *MinHash[T] { 27 | if k < 1 { 28 | panic(fmt.Sprintf("invalid k: %d, should be positive", k)) 29 | } 30 | return &MinHash[T]{ 31 | heaps.Max[T](), 32 | make(sets.Set[T], k), 33 | k, 0, 34 | } 35 | } 36 | 37 | // Push tries to add a hash to the collection. x is added only if it does not 38 | // already exist, and there are fewer than k retained elements smaller than x. 39 | // Returns true if x was added and false if not. Panics if mh is frozen. 40 | func (mh *MinHash[T]) Push(x T) bool { 41 | if mh.frozen() { 42 | panic("called Push on a frozen MinHash") 43 | } 44 | mh.n++ 45 | if mh.h.Len() == mh.k && x >= mh.h.Head() { 46 | // x is too large. 47 | return false 48 | } 49 | if mh.s.Has(x) { 50 | return false 51 | } 52 | if mh.h.Len() == mh.k { 53 | mh.s.Remove(mh.h.Pop()) 54 | } 55 | mh.h.Push(x) 56 | mh.s.Add(x) 57 | return true 58 | } 59 | 60 | // K returns the maximal number of elements in mh. 61 | func (mh *MinHash[T]) K() int { 62 | return mh.k 63 | } 64 | 65 | // N returns the number of calls that were made to Push. 66 | // Represents the size of the original set. 67 | func (mh *MinHash[T]) N() int { 68 | return mh.n 69 | } 70 | 71 | // View returns the underlying slice of values (a copy, if mh is frozen).
72 | func (mh *MinHash[T]) View() []T { 73 | if mh.frozen() { 74 | return slices.Clone(mh.h.View()) 75 | } 76 | return mh.h.View() 77 | } 78 | 79 | // MarshalJSON implements the json.Marshaler interface. 80 | func (mh *MinHash[T]) MarshalJSON() ([]byte, error) { 81 | return json.Marshal(struct { 82 | K int `json:"k"` 83 | N int `json:"n"` 84 | H []T `json:"h"` 85 | }{mh.k, mh.n, mh.View()}) 86 | } 87 | 88 | // UnmarshalJSON implements the json.Unmarshaler interface. The result is an unfrozen instance. 89 | func (mh *MinHash[T]) UnmarshalJSON(b []byte) error { 90 | var raw struct { 91 | K int 92 | N int 93 | H []T 94 | } 95 | if err := json.Unmarshal(b, &raw); err != nil { 96 | return err 97 | } 98 | ss := New[T](raw.K) 99 | ss.h.PushSlice(raw.H) 100 | ss.s.Add(raw.H...) 101 | ss.n = raw.N 102 | *mh = *ss 103 | return nil 104 | } 105 | 106 | // Returns the intersection and union sizes of mh and other, in min-hash terms: 107 | // only the mh.k lowest values of the union are considered. Panics if an unfrozen side is not sorted. 108 | func (mh *MinHash[T]) intersect(other *MinHash[T]) (int, int) { 109 | a, b := mh.h.View(), other.h.View() // Read-only access, so View's defensive copy of frozen data is not needed. 110 | if !mh.frozen() && !slices.IsSortedFunc(a, snm.CompareReverse) { 111 | panic("receiver is not sorted") 112 | } 113 | if !other.frozen() && !slices.IsSortedFunc(b, snm.CompareReverse) { 114 | panic("other is not sorted") 115 | } 116 | intersection := 0 117 | i, j, m := len(a)-1, len(b)-1, 0 // Slices are descending, so walk from the end to visit values in ascending order. 118 | for ; i >= 0 && j >= 0 && m < mh.k; m++ { 119 | if a[i] > b[j] { 120 | j-- 121 | } else if a[i] < b[j] { 122 | i-- 123 | } else { // a[i] == b[j] 124 | intersection++ 125 | i-- 126 | j-- 127 | } 128 | } 129 | union := min(mh.k, max(1, m+(i+1)+(j+1))) // Visited values plus the unvisited rest of each side; at least 1 so that two empty collections divide to 0 rather than NaN. 130 | return intersection, union 131 | } 132 | 133 | // Jaccard returns the approximated Jaccard similarity between mh and other. 134 | // 135 | // Sort needs to be called before calling this function.
136 | func (mh *MinHash[T]) Jaccard(other *MinHash[T]) float64 { 137 | i, u := mh.intersect(other) 138 | return float64(i) / float64(u) 139 | } 140 | 141 | // SoftJaccard returns the Jaccard similarity between mh and other, 142 | // adding one agreed upon element and one disagreed upon element to 143 | // the calculation. The intersection and union sizes are estimated from the Jaccard value and the N() of both collections. 144 | // 145 | // Sort needs to be called before calling this function. 146 | func (mh *MinHash[T]) SoftJaccard(other *MinHash[T]) float64 { 147 | r := mh.Jaccard(other) 148 | sum := float64(mh.N() + other.N()) 149 | ri, ru := r*sum/(r+1), sum/(r+1) 150 | return (ri + 1) / (ru + 2) 151 | } 152 | 153 | // Sort sorts the collection in descending order, making it ready for Jaccard calculation. 154 | // The collection is still valid after calling Sort. Panics if mh is frozen. 155 | func (mh *MinHash[T]) Sort() { 156 | if mh.frozen() { 157 | panic("called Sort on a frozen MinHash " + 158 | "(frozen instances are already sorted)") 159 | } 160 | slices.SortFunc(mh.h.View(), snm.CompareReverse) 161 | } 162 | 163 | // Frozen returns an immutable version of this instance. 164 | // The original instance is unchanged. 165 | // 166 | // Frozen instances are sorted, take up less memory 167 | // and calculate Jaccard faster. 168 | // Calls to View are slower because the data is cloned. 169 | func (mh *MinHash[T]) Frozen() *MinHash[T] { 170 | h := heaps.Max[T]() 171 | h.PushSlice(mh.View()) 172 | h.Clip() 173 | slices.SortFunc(h.View(), snm.CompareReverse) 174 | result := &MinHash[T]{h, nil, mh.k, mh.n} // The nil set is what marks the instance as frozen. 175 | return result 176 | } 177 | 178 | // Returns whether this minhash is frozen, i.e. has no uniqueness set. 179 | func (mh *MinHash[T]) frozen() bool { 180 | return mh.s == nil 181 | } 182 | -------------------------------------------------------------------------------- /gnum/gnum.go: -------------------------------------------------------------------------------- 1 | // Package gnum provides generic numerical functions.
2 | package gnum 3 | 4 | import ( 5 | "fmt" 6 | "math" 7 | 8 | "golang.org/x/exp/constraints" 9 | ) 10 | 11 | // Number is a constraint that contains comparable numbers. 12 | type Number interface { 13 | constraints.Float | constraints.Integer 14 | } 15 | 16 | // Max returns the maximal value in the slice or the zero value if the slice is empty. 17 | func Max[S ~[]N, N constraints.Ordered](s S) N { 18 | if len(s) == 0 { 19 | var zero N 20 | return zero 21 | } 22 | e := s[0] 23 | for _, v := range s[1:] { 24 | e = max(e, v) 25 | } 26 | return e 27 | } 28 | 29 | // Min returns the minimal value in the slice or the zero value if the slice is empty. 30 | func Min[S ~[]N, N constraints.Ordered](s S) N { 31 | if len(s) == 0 { 32 | var zero N 33 | return zero 34 | } 35 | e := s[0] 36 | for _, v := range s[1:] { 37 | e = min(e, v) 38 | } 39 | return e 40 | } 41 | 42 | // ArgMax returns the index of the maximal value in the slice or -1 if the slice is empty. On ties, the first occurrence wins. 43 | func ArgMax[S ~[]E, E constraints.Ordered](s S) int { 44 | if len(s) == 0 { 45 | return -1 46 | } 47 | imax, max := 0, s[0] 48 | for i, v := range s { 49 | if v > max { 50 | imax, max = i, v 51 | } 52 | } 53 | return imax 54 | } 55 | 56 | // ArgMin returns the index of the minimal value in the slice or -1 if the slice is empty. On ties, the first occurrence wins. 57 | func ArgMin[S ~[]E, E constraints.Ordered](s S) int { 58 | if len(s) == 0 { 59 | return -1 60 | } 61 | imin, min := 0, s[0] 62 | for i, v := range s { 63 | if v < min { 64 | imin, min = i, v 65 | } 66 | } 67 | return imin 68 | } 69 | 70 | // Abs returns the absolute value of n. 71 | // 72 | // For floats use [math.Abs]. 73 | func Abs[N constraints.Signed](n N) N { 74 | if n < 0 { 75 | return -n 76 | } 77 | return n 78 | } 79 | 80 | // Diff returns the non-negative difference between a and b. 81 | func Diff[N Number](a, b N) N { 82 | if a > b { 83 | return a - b 84 | } 85 | return b - a 86 | } 87 | 88 | // Sum returns the sum of the slice.
89 | func Sum[S ~[]N, N Number](a S) N { 90 | var sum N 91 | for _, v := range a { 92 | sum += v 93 | } 94 | return sum 95 | } 96 | 97 | // Mean returns the average of the slice. An empty slice results in NaN. 98 | func Mean[S ~[]N, N Number](a S) float64 { 99 | return float64(Sum(a)) / float64(len(a)) 100 | } 101 | 102 | // ExpMean returns the exponential average of the slice: the exponent of the mean of the logarithms (geometric mean). 103 | // An empty slice or a negative value results in NaN. A zero value contributes log(0) = -Inf, which typically results in 0. 104 | func ExpMean[S ~[]N, N Number](a S) float64 { 105 | if len(a) == 0 { 106 | return math.NaN() 107 | } 108 | sum := 0.0 109 | for _, v := range a { 110 | sum += math.Log(float64(v)) 111 | } 112 | return math.Exp(sum / float64(len(a))) 113 | } 114 | 115 | // Cov returns the covariance of a and b, normalized by the number of elements (population covariance). 116 | func Cov[S ~[]N, N Number](a, b S) float64 { 117 | assertMatchingLengths(a, b) // Defined elsewhere; presumably panics when the lengths differ. 118 | ma := Mean(a) 119 | mb := Mean(b) 120 | cov := 0.0 121 | for i := range a { 122 | cov += (float64(a[i]) - ma) * (float64(b[i]) - mb) 123 | } 124 | cov /= float64(len(a)) 125 | return cov 126 | } 127 | 128 | // Var returns the variance of a, calculated as the covariance of a with itself. 129 | func Var[S ~[]N, N Number](a S) float64 { 130 | return Cov(a, a) 131 | } 132 | 133 | // Std returns the standard deviation of a. 134 | func Std[S ~[]N, N Number](a S) float64 { 135 | return math.Sqrt(Var(a)) 136 | } 137 | 138 | // Corr returns the Pearson correlation between a and b. 139 | func Corr[S ~[]N, N Number](a, b S) float64 { 140 | return Cov(a, b) / Std(a) / Std(b) 141 | } 142 | 143 | // Entropy returns the Shannon-entropy of a, in bits. Panics if a contains a negative value. 144 | // The elements in a don't have to sum up to 1. 145 | func Entropy[S ~[]N, N Number](a S) float64 { 146 | sum := float64(Sum(a)) 147 | result := 0.0 148 | for i, v := range a { 149 | if v < 0 { 150 | panic(fmt.Sprintf("negative value at position %d: %v", 151 | i, v)) 152 | } 153 | if v == 0 { 154 | continue 155 | } 156 | p := float64(v) / sum 157 | result -= p * math.Log2(p) 158 | } 159 | return result 160 | } 161 | 162 | // Idiv divides a by b, rounded to the nearest integer.
163 | func Idiv[T constraints.Integer](a, b T) T { 164 | return T(math.Round(float64(a) / float64(b))) // NOTE(review): computed in float64, so magnitudes above 2^53 may be inexact, and b == 0 is not rejected. 165 | } 166 | 167 | // Quantiles returns the elements that divide the given slice 168 | // at the given ratios. Returns nil if no ratios are given; otherwise panics if a is empty. 169 | // 170 | // For example, 0.5 returns the middle element, 171 | // 0.25 returns the element at a quarter of the length, etc. 172 | // 0 and 1 return the first and last element, respectively. Ratios are not validated; values outside [0,1] index outside a. 173 | func Quantiles[T any](a []T, qq ...float64) []T { 174 | if len(qq) == 0 { 175 | return nil 176 | } 177 | if len(a) == 0 { 178 | panic("input slice cannot be empty") 179 | } 180 | result := make([]T, 0, len(qq)) 181 | n := float64(len(a) - 1) 182 | for _, q := range qq { 183 | i := int(math.Round(q * n)) 184 | result = append(result, a[i]) 185 | } 186 | return result 187 | } 188 | 189 | // NQuantiles returns the elements that divide 190 | // the given slice into n equal parts (up to rounding), 191 | // including the first and last elements. The result has n+1 elements. 192 | // 193 | // For example, for n=2 it returns the first element, 194 | // the middle, and the last element. n is not validated and should be at least 1. 195 | func NQuantiles[T any](a []T, n int) []T { 196 | q := make([]float64, n+1) 197 | for i := range q { 198 | q[i] = float64(i) / float64(n) 199 | } 200 | return Quantiles(a, q...) 201 | } 202 | 203 | // LogFactorial returns an approximation of log(n!), 204 | // calculated in constant time. 205 | func LogFactorial(n int) float64 { 206 | if n < 0 { 207 | panic(fmt.Sprintf("n cannot be negative: %v", n)) 208 | } 209 | if n == 0 || n == 1 { 210 | return 0 211 | } 212 | // Stirling's approximation.
213 | const halfLog2pi = 0x1.d67f1c864beb4p-01 // 0.5*math.Log(2*math.Pi) 214 | nf := float64(n) 215 | logn := math.Log(nf) 216 | return halfLog2pi + 0.5*logn + nf*(logn-1) 217 | } 218 | -------------------------------------------------------------------------------- /bloom/bloom.go: -------------------------------------------------------------------------------- 1 | // Package bloom provides a simple bloom filter implementation. 2 | package bloom 3 | 4 | import ( 5 | "fmt" 6 | "hash" 7 | "io" 8 | "math" 9 | _ "unsafe" 10 | 11 | "github.com/fluhus/gostuff/bnry" 12 | "github.com/spaolacci/murmur3" 13 | ) 14 | 15 | //go:linkname fastrand runtime.fastrand 16 | func fastrand() uint32 17 | 18 | // Filter is a single bloom filter. 19 | type Filter struct { 20 | b []byte // Filter data. 21 | h []hash.Hash64 // Hash functions. 22 | seed uint32 23 | } 24 | 25 | // NHash returns the number of hash functions this filter uses. 26 | func (f *Filter) NHash() int { 27 | return len(f.h) 28 | } 29 | 30 | // NBits returns the number of bits this filter uses. 31 | func (f *Filter) NBits() int { 32 | return 8 * len(f.b) 33 | } 34 | 35 | // NElements returns an approximation of the number of elements added to the 36 | // filter. 37 | func (f *Filter) NElements() int { 38 | m := float64(f.NBits()) 39 | k := float64(f.NHash()) 40 | x := 0.0 // Number of bits that are 1. 41 | for _, bt := range f.b { 42 | for bt > 0 { 43 | if bt&1 > 0 { 44 | x++ 45 | } 46 | bt >>= 1 47 | } 48 | } 49 | return int(math.Round(-m / k * math.Log(1-x/m))) 50 | } 51 | 52 | // Has checks if all k hash values of v were encountered. 53 | // Makes at most k hash calculations. 54 | func (f *Filter) Has(v []byte) bool { 55 | for i := range f.h { 56 | f.h[i].Reset() 57 | f.h[i].Write(v) 58 | hash := int(f.h[i].Sum64() % uint64(len(f.b)*8)) 59 | if getBit(f.b, hash) == 0 { 60 | return false 61 | } 62 | } 63 | return true 64 | } 65 | 66 | // Add adds v to the filter, and returns the value of Has(v) before adding. 
67 | // After calling Add, Has(v) will always be true. Makes k calls to hash. 68 | func (f *Filter) Add(v []byte) bool { 69 | has := true 70 | for i := range f.h { 71 | f.h[i].Reset() 72 | f.h[i].Write(v) 73 | hash := int(f.h[i].Sum64() % uint64(len(f.b)*8)) 74 | if getBit(f.b, hash) == 0 { 75 | has = false 76 | setBit(f.b, hash, 1) 77 | } 78 | } 79 | return has 80 | } 81 | 82 | // AddFilter merges other into f. After merging, f is equivalent to have been added 83 | // all the elements of other. 84 | func (f *Filter) AddFilter(other *Filter) { 85 | // Make sure the two filters are compatible. 86 | if f.NBits() != other.NBits() { 87 | panic(fmt.Sprintf("mismatching number of bits: this has %v, other has %v", 88 | f.NBits(), other.NBits())) 89 | } 90 | if f.NHash() != other.NHash() { 91 | panic(fmt.Sprintf("mismatching number of hashes: this has %v, other has %v", 92 | f.NHash(), other.NHash())) 93 | } 94 | if f.Seed() != other.Seed() { 95 | panic(fmt.Sprintf("mismatching seeds: this has %v, other has %v", 96 | f.Seed(), other.Seed())) 97 | } 98 | 99 | // Merge. 100 | for i := range f.b { 101 | f.b[i] |= other.b[i] 102 | } 103 | } 104 | 105 | // Seed returns the hash seed of this filter. 106 | // A new filter starts with a random seed. 107 | func (f *Filter) Seed() uint32 { 108 | return f.seed 109 | } 110 | 111 | // SetSeed sets the hash seed of this filter. 112 | // The filter must be empty. 113 | func (f *Filter) SetSeed(seed uint32) { 114 | for _, b := range f.b { 115 | if b != 0 { 116 | panic("cannot change seed after elements were added") 117 | } 118 | } 119 | f.seed = seed 120 | h := murmur3.New32WithSeed(seed) 121 | for i := range f.h { 122 | h.Write([]byte{1}) 123 | f.h[i] = murmur3.New64WithSeed(h.Sum32()) 124 | } 125 | } 126 | 127 | // Encode writes this filter to the stream. Can be reproduced later with Decode. 128 | func (f *Filter) Encode(w io.Writer) error { 129 | // Order is k, seed, bytes. 
130 | return bnry.Write(w, uint64(len(f.h)), f.seed, f.b) 131 | } 132 | 133 | // Decode reads an encoded filter from the stream and sets this filter's state 134 | // to match it. Destroys the previously existing state of this filter. On a read error the filter is left unchanged. 135 | func (f *Filter) Decode(r io.ByteReader) error { 136 | var k uint64 137 | var seed uint32 138 | var b []byte 139 | if err := bnry.Read(r, &k, &seed, &b); err != nil { 140 | return err 141 | } 142 | f.h = make([]hash.Hash64, k) 143 | f.b = nil // Detach the old bits first; SetSeed panics on a filter that already has elements. 144 | f.SetSeed(seed) 145 | f.b = b 146 | return nil 147 | } 148 | 149 | // New creates a new bloom filter with the given parameters. Number of 150 | // bits is rounded up to the nearest multiple of 8. Panics if bits or k is less than 1. 151 | // 152 | // See NewOptimal for an alternative way to decide on the parameters. 153 | func New(bits int, k int) *Filter { 154 | if bits < 1 { 155 | panic(fmt.Sprintf("number of bits should be at least 1, got %v", bits)) 156 | } 157 | if k < 1 { 158 | panic(fmt.Sprintf("k should be at least 1, got %v", k)) 159 | } 160 | 161 | result := &Filter{ 162 | b: make([]byte, ((bits-1)/8)+1), 163 | h: make([]hash.Hash64, k), 164 | } 165 | result.SetSeed(fastrand()) 166 | return result 167 | } 168 | 169 | // NewOptimal creates a new bloom filter, with parameters optimal for the 170 | // expected number of elements (n) and the required false-positive rate (p). 171 | // 172 | // The calculation is taken from: 173 | // https://en.wikipedia.org/wiki/Bloom_filter#Optimal_number_of_hash_functions 174 | func NewOptimal(n int, p float64) *Filter { 175 | m := math.Round(-float64(n) * math.Log(p) / math.Ln2 / math.Ln2) 176 | k := math.Round(-math.Log2(p)) 177 | return New(int(m), int(k)) 178 | } 179 | 180 | // Returns the value of the n'th bit in a byte slice. Bits are counted from the least significant bit of each byte. 181 | func getBit(b []byte, n int) int { 182 | return int(b[n/8] >> (n % 8) & 1) 183 | } 184 | 185 | // Sets the value of the n'th bit in a byte slice.
186 | func setBit(b []byte, n, v int) { 187 | if v == 0 { 188 | b[n/8] &= ^(byte(1) << (n % 8)) 189 | } else if v == 1 { 190 | b[n/8] |= byte(1) << (n % 8) 191 | } else { 192 | panic(fmt.Sprintf("bad value: %v, expected 0 or 1", v)) 193 | } 194 | } 195 | -------------------------------------------------------------------------------- /gnum/gnum_test.go: -------------------------------------------------------------------------------- 1 | package gnum 2 | 3 | import ( 4 | "math" 5 | "slices" 6 | "testing" 7 | ) 8 | 9 | func TestEntropy(t *testing.T) { 10 | input1 := []int{1, 2, 3, 4} 11 | input2 := []uint{1, 2, 3, 4} 12 | input3 := []float32{1, 2, 3, 4} 13 | input4 := []float64{1, 2, 3, 4} 14 | want := 1.8464393 15 | if got := Entropy(input1); Diff(got, want) > 0.00000005 { 16 | t.Errorf("Entropy(%v)=%v, want %v", input1, got, want) 17 | } 18 | if got := Entropy(input2); Diff(got, want) > 0.00000005 { 19 | t.Errorf("Entropy(%v)=%v, want %v", input2, got, want) 20 | } 21 | if got := Entropy(input3); Diff(got, want) > 0.00000005 { 22 | t.Errorf("Entropy(%v)=%v, want %v", input3, got, want) 23 | } 24 | if got := Entropy(input4); Diff(got, want) > 0.00000005 { 25 | t.Errorf("Entropy(%v)=%v, want %v", input4, got, want) 26 | } 27 | } 28 | 29 | func TestIdiv(t *testing.T) { 30 | tests := []struct { 31 | a, b, want int 32 | }{ 33 | {8, 1, 8}, 34 | {8, 2, 4}, 35 | {8, 3, 3}, 36 | {8, 4, 2}, 37 | {8, 5, 2}, 38 | {8, 6, 1}, 39 | {8, 7, 1}, 40 | {8, 8, 1}, 41 | {8, 9, 1}, 42 | {8, 10, 1}, 43 | {8, 11, 1}, 44 | {8, 12, 1}, 45 | {8, 13, 1}, 46 | {8, 14, 1}, 47 | {8, 15, 1}, 48 | {8, 17, 0}, 49 | {8, 18, 0}, 50 | {8, 19, 0}, 51 | {8, 20, 0}, 52 | } 53 | for _, test := range tests { 54 | if got := Idiv(test.a, test.b); got != test.want { 55 | t.Errorf("Idiv(%v,%v)=%v, want %v", test.a, test.b, got, test.want) 56 | } 57 | } 58 | } 59 | 60 | func TestMinMax(t *testing.T) { 61 | tests := []struct { 62 | input []int 63 | mn, mx, amn, amx int 64 | }{ 65 | {nil, 0, 0, -1, -1}, 66 | 
{[]int{42}, 42, 42, 0, 0}, 67 | {[]int{42, 42}, 42, 42, 0, 0}, 68 | {[]int{42, 42, 42}, 42, 42, 0, 0}, 69 | {[]int{1, 2, 3}, 1, 3, 0, 2}, 70 | {[]int{1, 3, 2}, 1, 3, 0, 1}, 71 | {[]int{2, 1, 3}, 1, 3, 1, 2}, 72 | {[]int{2, 3, 1}, 1, 3, 2, 1}, 73 | {[]int{3, 1, 2}, 1, 3, 1, 0}, 74 | {[]int{3, 2, 1}, 1, 3, 2, 0}, 75 | } 76 | for _, test := range tests { 77 | mn, mx, amn, amx := Min(test.input), Max(test.input), ArgMin(test.input), ArgMax(test.input) 78 | if mn != test.mn { 79 | t.Errorf("Min(%v)=%v, want %v", test.input, mn, test.mn) 80 | } 81 | if mx != test.mx { 82 | t.Errorf("Max(%v)=%v, want %v", test.input, mx, test.mx) 83 | } 84 | if amn != test.amn { 85 | t.Errorf("ArgMin(%v)=%v, want %v", test.input, amn, test.amn) 86 | } 87 | if amx != test.amx { 88 | t.Errorf("ArgMax(%v)=%v, want %v", test.input, amx, test.amx) 89 | } 90 | } 91 | } 92 | 93 | func TestSum(t *testing.T) { 94 | tests := []struct { 95 | input []int 96 | want int 97 | }{ 98 | {nil, 0}, 99 | {[]int{1}, 1}, 100 | {[]int{1, 1}, 2}, 101 | {[]int{1, 1, 1, 1}, 4}, 102 | {[]int{6, 4, 1}, 11}, 103 | } 104 | for _, test := range tests { 105 | if got := Sum(test.input); got != test.want { 106 | t.Errorf("Sum(%v)=%v, want %v", test.input, got, test.want) 107 | } 108 | } 109 | } 110 | 111 | func TestMean(t *testing.T) { 112 | tests := []struct { 113 | input []int 114 | want float64 115 | }{ 116 | {[]int{1}, 1}, 117 | {[]int{1, 1}, 1}, 118 | {[]int{1, 1, 1, 1}, 1}, 119 | {[]int{6, 4, -1}, 3}, 120 | } 121 | for _, test := range tests { 122 | if got := Mean(test.input); got != test.want { 123 | t.Errorf("Mean(%v)=%v, want %v", test.input, got, test.want) 124 | } 125 | } 126 | } 127 | 128 | func TestExpMean(t *testing.T) { 129 | tests := []struct { 130 | input []int 131 | want float64 132 | }{ 133 | {[]int{1}, 1}, 134 | {[]int{1, 1}, 1}, 135 | {[]int{3, 3, 3, 3}, 3}, 136 | {[]int{10, 1000}, 100}, 137 | {[]int{10, 100}, math.Sqrt(1000)}, 138 | {[]int{10, 100, 1000}, 100}, 139 | } 140 | const tolerance = 
0.0000001 141 | for _, test := range tests { 142 | if got := ExpMean(test.input); Diff(got, test.want) > tolerance { 143 | t.Errorf("ExpMean(%v)=%v, want %v", test.input, got, test.want) 144 | } 145 | } 146 | } 147 | 148 | func FuzzSumMean(f *testing.F) { 149 | f.Add(0.0, 0.0, 0.0, 0.0) 150 | f.Fuzz(func(t *testing.T, a float64, b float64, c float64, d float64) { 151 | slice := []float64{a, b, c, d} 152 | want := a + b + c + d 153 | if got := Sum(slice); got != want { 154 | t.Fatalf("Sum([%v,%v,%v,%v])=%v, want %v", a, b, c, d, got, want) 155 | } 156 | want /= 4 157 | if got := Mean(slice); got != want { 158 | t.Fatalf("Mean([%v,%v,%v,%v])=%v, want %v", a, b, c, d, got, want) 159 | } 160 | if a > 0 && b > 0 && c > 0 && d > 0 { 161 | const tol = 0.0000001 162 | want = math.Pow(a*b*c*d, 0.25) 163 | if got := ExpMean(slice); Diff(got, want) > tol { 164 | t.Fatalf("ExpMean([%v,%v,%v,%v])=%v, want %v", a, b, c, d, got, want) 165 | } 166 | } 167 | }) 168 | } 169 | 170 | func TestQuantiles(t *testing.T) { 171 | input := []int{1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144} 172 | q := []float64{0.3, 0.5, 0, 1, 0.9} 173 | want1 := []int{3, 8, 1, 144, 89} 174 | want2 := []int{3, 13, 1, 144, 89} 175 | got := Quantiles(input, q...) 
176 | if !slices.Equal(got, want1) && !slices.Equal(got, want2) { 177 | t.Fatalf("Quantiles(%v,%v)=%v, want %v or %v", 178 | input, q, got, want1, want2) 179 | } 180 | } 181 | 182 | func TestNQuantiles(t *testing.T) { 183 | input := []int{1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144} 184 | want := []int{1, 5, 21, 144} 185 | got := NQuantiles(input, 3) 186 | if !slices.Equal(got, want) { 187 | t.Fatalf("Quantiles(%v,3)=%v, want %v", input, got, want) 188 | } 189 | } 190 | 191 | func TestLogFactorial(t *testing.T) { 192 | tests := [][]int{ 193 | {0, 1}, {1, 1}, {2, 2}, {3, 6}, {4, 24}, {5, 120}, {6, 720}, 194 | } 195 | for _, test := range tests { 196 | got := math.Exp(LogFactorial(test[0])) 197 | want := float64(test[1]) 198 | if Diff(got, want) > want*0.05 { 199 | t.Errorf("lf(%v)=%v, want %v", test[0], got, test[1]) 200 | } 201 | } 202 | } 203 | 204 | func BenchmarkLogFactorial(b *testing.B) { 205 | for i := range b.N { 206 | LogFactorial(i) 207 | } 208 | } 209 | -------------------------------------------------------------------------------- /flagx/flagx.go: -------------------------------------------------------------------------------- 1 | // Package flagx provides additional [flag] functions. 2 | package flagx 3 | 4 | import ( 5 | "flag" 6 | "fmt" 7 | "os" 8 | "regexp" 9 | "slices" 10 | "strconv" 11 | "strings" 12 | 13 | "github.com/fluhus/gostuff/snm" 14 | ) 15 | 16 | // RegexpFlagSet defines a regular expression flag with specified name, 17 | // default value, and usage string. 18 | // The return value is the address of a regular expression variable that 19 | // stores the value of the flag. 
20 | func RegexpFlagSet(fs *flag.FlagSet, name string, 21 | value *regexp.Regexp, usage string) **regexp.Regexp { 22 | p := &value 23 | fs.Func(name, usage, func(s string) error { 24 | r, err := regexp.Compile(s) 25 | if err != nil { 26 | return err 27 | } 28 | *p = r 29 | return nil 30 | }) 31 | return p 32 | } 33 | 34 | // Regexp defines a regular expression flag with specified name, 35 | // default value, and usage string. 36 | // The return value is the address of a regular expression variable that 37 | // stores the value of the flag. 38 | func Regexp(name string, value *regexp.Regexp, usage string) **regexp.Regexp { 39 | return RegexpFlagSet(flag.CommandLine, name, value, usage) 40 | } 41 | 42 | // IntBetweenFlagSet defines an int flag with specified name, 43 | // default value, usage string and bounds. 44 | // The return value is the address of an int variable that 45 | // stores the value of the flag. 46 | func IntBetweenFlagSet(fs *flag.FlagSet, name string, 47 | value int, usage string, minVal, maxVal int) *int { 48 | p := &value 49 | fs.Func(name, usage, func(s string) error { 50 | i, err := strconv.Atoi(s) 51 | if err != nil { 52 | return err 53 | } 54 | if i < minVal || i > maxVal { 55 | return fmt.Errorf("got %d, want %d-%d", i, minVal, maxVal) 56 | } 57 | *p = i 58 | return nil 59 | }) 60 | return p 61 | } 62 | 63 | // IntBetween defines an int flag with specified name, 64 | // default value, usage string and bounds. 65 | // The return value is the address of an int variable that 66 | // stores the value of the flag. 67 | func IntBetween(name string, value int, usage string, minVal, maxVal int) *int { 68 | return IntBetweenFlagSet( 69 | flag.CommandLine, name, value, usage, minVal, maxVal) 70 | } 71 | 72 | // FloatBetweenFlagSet defines a float flag with specified name, 73 | // default value, usage string and bounds. 74 | // incMin and incMax define whether min and max are included in the 75 | // allowed values. 
76 | // The return value is the address of a float variable that 77 | // stores the value of the flag. 78 | func FloatBetweenFlagSet(fs *flag.FlagSet, name string, value float64, 79 | usage string, minVal, maxVal float64, incMin, incMax bool) *float64 { 80 | p := &value 81 | fs.Func(name, usage, func(s string) error { 82 | f, err := strconv.ParseFloat(s, 64) 83 | if err != nil { 84 | return err 85 | } 86 | if (incMin && f < minVal) || (!incMin && f <= minVal) || 87 | (incMax && f > maxVal) || (!incMax && f >= maxVal) { 88 | smin, smax := "(", ")" 89 | if incMin { 90 | smin = "[" 91 | } 92 | if incMax { 93 | smax = "]" 94 | } 95 | return fmt.Errorf("got %f, want %s%f,%f%s", 96 | f, smin, minVal, maxVal, smax) 97 | } 98 | *p = f 99 | return nil 100 | }) 101 | return p 102 | } 103 | 104 | // FloatBetween defines a float flag with specified name, 105 | // default value, usage string and bounds. 106 | // incMin and incMax define whether min and max are included in the 107 | // allowed values. 108 | // The return value is the address of a float variable that 109 | // stores the value of the flag. 110 | func FloatBetween(name string, value float64, usage string, 111 | minVal, maxVal float64, incMin, incMax bool) *float64 { 112 | return FloatBetweenFlagSet( 113 | flag.CommandLine, name, value, usage, 114 | minVal, maxVal, incMin, incMax) 115 | } 116 | 117 | // FileExistsFlagSet defines a string flag that represents 118 | // an existing file. Returns an error if the file does not exist. 119 | func FileExistsFlagSet(fs *flag.FlagSet, name string, value string, 120 | usage string) *string { 121 | v := &value 122 | fs.Func(name, usage, func(s string) error { 123 | f, err := os.Stat(s) 124 | if err != nil { 125 | return err 126 | } 127 | if f.IsDir() { 128 | return fmt.Errorf("path is a directory") 129 | } 130 | *v = s 131 | return nil 132 | }) 133 | return v 134 | } 135 | 136 | // FileExists defines a string flag that represents 137 | // an existing file. 
Returns an error if the file does not exist. 138 | func FileExists(name string, value string, usage string) *string { 139 | return FileExistsFlagSet(flag.CommandLine, name, value, usage) 140 | } 141 | 142 | // OneOfFlagSet defines a flag that must have one of the given values. 143 | // The type must be one that can be read by [fmt.Scan]. Panics if no values are given. 144 | func OneOfFlagSet[T comparable](fs *flag.FlagSet, name string, 145 | value T, usage string, of ...T) *T { 146 | if len(of) == 0 { 147 | panic("called with 0 possible values") 148 | } 149 | 150 | // Create usage string. 151 | options := strings.Join(snm.SliceToSlice(of, func(t T) string { 152 | return fmt.Sprint(t) 153 | }), ", ") 154 | dflt := "" 155 | var zero T 156 | if value != zero { 157 | dflt = fmt.Sprintf("; default %s", fmt.Sprint(value)) 158 | } 159 | usage = fmt.Sprintf("%s (one of [%s]%s)", 160 | usage, options, dflt) 161 | v := value 162 | fs.Func(name, usage, func(s string) error { 163 | var parsed T // Scan into a scratch value so that a rejected input cannot clobber v. 164 | if _, err := fmt.Sscanln(s, &parsed); err != nil { 165 | return err 166 | } 167 | if !slices.Contains(of, parsed) { 168 | return fmt.Errorf("want one of %v", of) 169 | } 170 | v = parsed 171 | return nil 172 | }) 173 | return &v 174 | } 175 | 176 | // OneOf defines a flag that must have one of the given values. 177 | // The type must be one that can be read by [fmt.Scan]. 178 | func OneOf[T comparable](name string, value T, usage string, of ...T) *T { 179 | return OneOfFlagSet(flag.CommandLine, name, value, usage, of...) 180 | } 181 | -------------------------------------------------------------------------------- /bnry/write.go: -------------------------------------------------------------------------------- 1 | package bnry 2 | 3 | import ( 4 | "bytes" 5 | "encoding/binary" 6 | "fmt" 7 | "io" 8 | "math" 9 | "reflect" 10 | "strings" 11 | 12 | "golang.org/x/exp/constraints" 13 | ) 14 | 15 | // Write writes the given values to the given writer. 16 | // Values should be of any of the supported types.
17 | // Panics if a value is of an unsupported type. 18 | func Write(w io.Writer, vals ...any) error { 19 | return NewWriter(w).Write(vals...) 20 | } 21 | 22 | // MarshalBinary writes the given values to a byte slice. 23 | // Values should be of any of the supported types. 24 | // Panics if a value is of an unsupported type. 25 | func MarshalBinary(vals ...any) []byte { 26 | buf := bytes.NewBuffer(nil) 27 | NewWriter(buf).Write(vals...) 28 | return buf.Bytes() 29 | } 30 | 31 | type Writer struct { 32 | w io.Writer 33 | buf [binary.MaxVarintLen64]byte 34 | } 35 | 36 | func NewWriter(w io.Writer) *Writer { 37 | return &Writer{w: w} 38 | } 39 | 40 | // Write writes the given values. 41 | // Values should be of any of the supported types. 42 | // Panics if a value is of an unsupported type. 43 | func (w *Writer) Write(vals ...any) error { 44 | for _, v := range vals { 45 | if err := w.writeSingle(v); err != nil { 46 | return err 47 | } 48 | } 49 | return nil 50 | } 51 | 52 | // Writes a single value as binary. 
53 | func (w *Writer) writeSingle(val any) error { 54 | switch val := val.(type) { 55 | case uint8: 56 | return w.writeByte(val) 57 | case uint16: 58 | return writeUint(w, val) 59 | case uint32: 60 | return writeUint(w, val) 61 | case uint64: 62 | return writeUint(w, val) 63 | case uint: 64 | return writeUint(w, val) 65 | case int8: 66 | return w.writeByte(byte(val)) 67 | case int16: 68 | return writeInt(w, val) 69 | case int32: 70 | return writeInt(w, val) 71 | case int64: 72 | return writeInt(w, val) 73 | case int: 74 | return writeInt(w, val) 75 | case float32: 76 | return writeUint(w, math.Float32bits(val)) 77 | case float64: 78 | return writeUint(w, math.Float64bits(val)) 79 | case bool: 80 | return w.writeByte(boolToByte(val)) 81 | case string: 82 | return w.writeString(val) 83 | case []uint8: 84 | return w.writeUint8Slice(val) 85 | case []uint16: 86 | return writeUintSlice(w, val) 87 | case []uint32: 88 | return writeUintSlice(w, val) 89 | case []uint64: 90 | return writeUintSlice(w, val) 91 | case []uint: 92 | return writeUintSlice(w, val) 93 | case []int8: 94 | return w.writeInt8Slice(val) 95 | case []int16: 96 | return writeIntSlice(w, val) 97 | case []int32: 98 | return writeIntSlice(w, val) 99 | case []int64: 100 | return writeIntSlice(w, val) 101 | case []int: 102 | return writeIntSlice(w, val) 103 | case []float32: 104 | return w.writeFloat32Slice(val) 105 | case []float64: 106 | return w.writeFloat64Slice(val) 107 | case []bool: 108 | return w.writeBoolSlice(val) 109 | case []string: 110 | return w.writeStringSlice(val) 111 | default: 112 | panic(fmt.Sprintf("unsupported type: %v", 113 | reflect.TypeOf(val))) 114 | } 115 | } 116 | 117 | func (w *Writer) writeByte(b byte) error { 118 | w.buf[0] = b 119 | _, err := w.w.Write(w.buf[:1]) 120 | return err 121 | } 122 | 123 | func writeUint[T constraints.Unsigned](w *Writer, i T) error { 124 | _, err := w.w.Write(binary.AppendUvarint(w.buf[:0], uint64(i))) 125 | return err 126 | } 127 | 128 | func 
writeInt[T constraints.Signed](w *Writer, i T) error { 129 | _, err := w.w.Write(binary.AppendVarint(w.buf[:0], int64(i))) 130 | return err 131 | } 132 | 133 | func (w *Writer) writeUint8Slice(s []uint8) error { 134 | if err := writeUint(w, uint(len(s))); err != nil { 135 | return err 136 | } 137 | _, err := w.w.Write(s) 138 | return err 139 | } 140 | 141 | func (w *Writer) writeString(s string) error { 142 | if err := writeUint(w, uint(len(s))); err != nil { 143 | return err 144 | } 145 | _, err := strings.NewReader(s).WriteTo(w.w) 146 | return err 147 | } 148 | 149 | func (w *Writer) writeInt8Slice(s []int8) error { 150 | if err := writeUint(w, uint(len(s))); err != nil { 151 | return err 152 | } 153 | for _, x := range s { 154 | if err := w.writeByte(byte(x)); err != nil { 155 | return err 156 | } 157 | } 158 | return nil 159 | } 160 | 161 | func writeUintSlice[T constraints.Unsigned](w *Writer, s []T) error { 162 | if err := writeUint(w, uint(len(s))); err != nil { 163 | return err 164 | } 165 | for _, x := range s { 166 | if err := writeUint(w, x); err != nil { 167 | return err 168 | } 169 | } 170 | return nil 171 | } 172 | 173 | func writeIntSlice[T constraints.Signed](w *Writer, s []T) error { 174 | if err := writeUint(w, uint(len(s))); err != nil { 175 | return err 176 | } 177 | for _, x := range s { 178 | if err := writeInt(w, x); err != nil { 179 | return err 180 | } 181 | } 182 | return nil 183 | } 184 | 185 | func (w *Writer) writeFloat32Slice(s []float32) error { 186 | if err := writeUint(w, uint(len(s))); err != nil { 187 | return err 188 | } 189 | for _, x := range s { 190 | if err := writeUint(w, math.Float32bits(x)); err != nil { 191 | return err 192 | } 193 | } 194 | return nil 195 | } 196 | 197 | func (w *Writer) writeFloat64Slice(s []float64) error { 198 | if err := writeUint(w, uint(len(s))); err != nil { 199 | return err 200 | } 201 | for _, x := range s { 202 | if err := writeUint(w, math.Float64bits(x)); err != nil { 203 | return err 204 | } 
205 | } 206 | return nil 207 | } 208 | 209 | func (w *Writer) writeBoolSlice(s []bool) error { 210 | if err := writeUint(w, uint(len(s))); err != nil { 211 | return err 212 | } 213 | for _, x := range s { 214 | if err := w.writeByte(boolToByte(x)); err != nil { 215 | return err 216 | } 217 | } 218 | return nil 219 | } 220 | 221 | func (w *Writer) writeStringSlice(s []string) error { 222 | if err := writeUint(w, uint(len(s))); err != nil { 223 | return err 224 | } 225 | for _, x := range s { 226 | if err := w.writeString(x); err != nil { 227 | return err 228 | } 229 | } 230 | return nil 231 | } 232 | 233 | func boolToByte(b bool) byte { 234 | if b { 235 | return 1 236 | } else { 237 | return 0 238 | } 239 | } 240 | --------------------------------------------------------------------------------