├── .gitignore
├── nlp
├── doc.go
├── wordnet
│ ├── string.go
│ ├── types.go
│ └── parser_test.go
├── tfidf.go
├── tokenize.go
└── lda-tool
│ └── main.go
├── rhash
├── rabin32_test.go
├── rabin64_test.go
├── doc.go
├── buz_test.go
├── fuzz_test.go
├── bench_test.go
├── rabin32.go
├── rabin64.go
└── common_test.go
├── go.mod
├── bnry
├── doc.go
├── bnry_test.go
└── write.go
├── README.md
├── .github
└── workflows
│ └── go.yml
├── hashx
├── hashx_test.go
└── hashx.go
├── go.sum
├── aio
├── wrap.go
└── open.go
├── ppln
├── common.go
├── doc.go
├── common_test.go
├── nserial.go
├── nserial_test.go
├── serial_test.go
└── serial.go
├── clustering
├── kmeans_test.go
├── agglo_test.go
├── upgma_test.go
├── rand.go
└── upgma.go
├── morris
├── error
│ └── error.go
├── morris_test.go
└── morris.go
├── csvx
├── csvx_test.go
├── examples_test.go
└── csvx.go
├── repeat
├── repeat_test.go
└── repeat.go
├── reservoir
├── reservoir.go
└── reservoir_test.go
├── LICENSE
├── gnum
├── vecs_test.go
├── vecs.go
├── gnum.go
└── gnum_test.go
├── snm
├── queue_test.go
├── queue.go
└── snm_test.go
├── sets
├── sorted_test.go
├── sets_test.go
├── sorted.go
└── sets.go
├── xmlnode
├── read_test.go
└── read.go
├── graphs
├── bfsdfs.go
├── graph_test.go
├── bfsdfs_test.go
└── graph.go
├── ezpprof
└── ezpprof.go
├── jio
└── jio.go
├── iterx
├── slice.go
├── slice_test.go
├── lines.go
├── unreader.go
├── unreader2.go
├── unreader_test.go
└── unreader2_test.go
├── bits
├── bits.go
└── bits_test.go
├── hll
├── hll_test.go
└── hll.go
├── heaps
├── heaps_test.go
└── heaps.go
├── ptimer
├── ptimer_test.go
└── ptimer.go
├── flagx
├── flagx_test.go
└── flagx.go
├── bloom
├── bloom_test.go
└── bloom.go
├── prefixtree
└── tree.go
└── minhash
├── minhash_test.go
└── minhash.go
/.gitignore:
--------------------------------------------------------------------------------
1 | .vscode
2 | testdata
3 |
--------------------------------------------------------------------------------
/nlp/doc.go:
--------------------------------------------------------------------------------
1 | // Package nlp provides basic NLP utilities.
2 | package nlp
3 |
--------------------------------------------------------------------------------
/rhash/rabin32_test.go:
--------------------------------------------------------------------------------
1 | package rhash
2 |
3 | import (
4 | "hash"
5 | "testing"
6 | )
7 |
8 | func TestRabin32(t *testing.T) {
9 | test32(t, func(n int) hash.Hash32 { return NewRabinFingerprint32(n) })
10 | }
11 |
--------------------------------------------------------------------------------
/rhash/rabin64_test.go:
--------------------------------------------------------------------------------
1 | package rhash
2 |
3 | import (
4 | "hash"
5 | "testing"
6 | )
7 |
8 | func TestRabin64(t *testing.T) {
9 | test64(t, func(n int) hash.Hash64 { return NewRabinFingerprint64(n) })
10 | }
11 |
--------------------------------------------------------------------------------
/go.mod:
--------------------------------------------------------------------------------
1 | module github.com/fluhus/gostuff
2 |
3 | go 1.24
4 |
5 | require (
6 | github.com/agonopol/go-stem v0.0.0-20150630113328-985885018250
7 | github.com/klauspost/compress v1.17.9
8 | github.com/spaolacci/murmur3 v1.1.0
9 | golang.org/x/exp v0.0.0-20220325121720-054d8573a5d8
10 | )
11 |
--------------------------------------------------------------------------------
/rhash/doc.go:
--------------------------------------------------------------------------------
1 | // Package rhash provides implementations of rolling-hash functions.
2 | //
3 | // A rolling-hash is a hash function that "remembers" only the last n bytes it
4 | // received, where n is a parameter. Meaning, the hash of a byte sequence
5 | // always equals the hash of its last n bytes.
6 | package rhash
7 |
--------------------------------------------------------------------------------
/bnry/doc.go:
--------------------------------------------------------------------------------
1 | // Package bnry provides simple functions for encoding and decoding values as
2 | // binary.
3 | //
4 | // # Supported data types
5 | //
6 | // The types that can be encoded and decoded are
7 | // int*, uint* (excluding int and uint), float*, bool, string and
8 | // slices of these types.
9 | // [Read] and [UnmarshalBinary] expect pointers to these types,
10 | // while [Write] and [MarshalBinary] expect non-pointers.
11 | package bnry
12 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | gostuff
2 | =======
3 |
4 | [Go Reference](https://pkg.go.dev/github.com/fluhus/gostuff)
5 | [Go Report Card](https://goreportcard.com/report/github.com/fluhus/gostuff)
6 |
7 | I believe that simple ideas deserve simple code.
8 |
9 | This is a collection of packages I wrote for my personal work as a data
10 | scientist, and may benefit others as well.
11 | The emphasis is on simplicity and minimalism, to keep code readable and clean.
12 |
--------------------------------------------------------------------------------
/.github/workflows/go.yml:
--------------------------------------------------------------------------------
1 | # This workflow will build a golang project
2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-go
3 |
4 | name: Go
5 |
6 | on:
7 | push:
8 | branches: [ "master" ]
9 | pull_request:
10 | branches: [ "master" ]
11 |
12 | jobs:
13 |
14 | build:
15 | runs-on: ubuntu-latest
16 | steps:
17 | - uses: actions/checkout@v3
18 |
19 | - name: Set up Go
20 | uses: actions/setup-go@v4
21 | with:
22 | go-version: '1.24'
23 |
24 | - name: Build
25 | run: go build -v ./...
26 |
27 | - name: Test
28 | run: go test -v ./...
29 |
--------------------------------------------------------------------------------
/hashx/hashx_test.go:
--------------------------------------------------------------------------------
1 | package hashx
2 |
3 | import (
4 | "testing"
5 |
6 | "github.com/fluhus/gostuff/sets"
7 | )
8 |
9 | func TestInt(t *testing.T) {
10 | const n = 100000
11 |
12 | hashes := sets.Set[uint64]{}
13 | hx := New()
14 | for i := range n {
15 | hashes.Add(IntHashx(hx, i))
16 | }
17 | if len(hashes) != n {
18 | t.Errorf("len(hashes)=%v, want %v", len(hashes), n)
19 | }
20 |
21 | hashes = sets.Set[uint64]{}
22 | hx = New()
23 | for i := range n {
24 | hashes.Add(IntHashx(hx, i))
25 | hashes.Add(IntHashx(hx, -i))
26 | }
27 | want := n*2 - 1
28 | if len(hashes) != want {
29 | t.Errorf("len(hashes)=%v, want %v", len(hashes), want)
30 | }
31 | }
32 |
--------------------------------------------------------------------------------
/go.sum:
--------------------------------------------------------------------------------
1 | github.com/agonopol/go-stem v0.0.0-20150630113328-985885018250 h1:znLjWXbvRGRvFlxYoiRu4Yf38isHmYYG07scbZKKg9I=
2 | github.com/agonopol/go-stem v0.0.0-20150630113328-985885018250/go.mod h1:JpR7ykfRJUCcS6aOUCB6dPImrYufY0NoBCDg/wqeIIo=
3 | github.com/klauspost/compress v1.17.9 h1:6KIumPrER1LHsvBVuDa0r5xaG0Es51mhhB9BQB2qeMA=
4 | github.com/klauspost/compress v1.17.9/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw=
5 | github.com/spaolacci/murmur3 v1.1.0 h1:7c1g84S4BPRrfL5Xrdp6fOJ206sU9y293DDHaoy0bLI=
6 | github.com/spaolacci/murmur3 v1.1.0/go.mod h1:JwIasOWyU6f++ZhiEuf87xNszmSA2myDM2Kzu9HwQUA=
7 | golang.org/x/exp v0.0.0-20220325121720-054d8573a5d8 h1:Xt4/LzbTwfocTk9ZLEu4onjeFucl88iW+v4j4PWbQuE=
8 | golang.org/x/exp v0.0.0-20220325121720-054d8573a5d8/go.mod h1:lgLbSvA5ygNOMpwM/9anMpWVlVJ7Z+cHWq/eFuinpGE=
9 |
--------------------------------------------------------------------------------
/aio/wrap.go:
--------------------------------------------------------------------------------
1 | package aio
2 |
3 | import (
4 | "bufio"
5 | "io"
6 | )
7 |
// writerWrapper wraps a writer and its underlying closer.
//
// Writes go to top. Close closes top first and then bottom.
type writerWrapper struct {
	top, bottom io.WriteCloser
}

// Write writes p to the top writer.
func (w *writerWrapper) Write(p []byte) (int, error) {
	return w.top.Write(p)
}

// Close closes the top writer and then the bottom one.
//
// The bottom closer is always closed, even when closing the top fails, so
// that the underlying resource is not leaked. The error from the top takes
// precedence; otherwise the bottom's error is returned.
func (w *writerWrapper) Close() error {
	topErr := w.top.Close()
	bottomErr := w.bottom.Close()
	if topErr != nil {
		return topErr
	}
	return bottomErr
}
23 |
// Reader is a buffered reader that can also close its underlying source.
type Reader struct {
	bufio.Reader

	// r is the underlying source, closed by Close.
	r io.ReadCloser
}

// Close closes the underlying reader and returns its error.
func (rd *Reader) Close() error {
	err := rd.r.Close()
	return err
}
32 |
// Writer is a buffered writer that can also close its underlying
// destination.
type Writer struct {
	bufio.Writer

	// w is the underlying destination, closed by Close.
	w io.WriteCloser
}

// Close flushes buffered data and closes the underlying writer.
//
// The underlying writer is always closed, even when flushing fails, so that
// the resource is not leaked. A flush error takes precedence over a close
// error.
func (w *Writer) Close() error {
	flushErr := w.Flush()
	closeErr := w.w.Close()
	if flushErr != nil {
		return flushErr
	}
	return closeErr
}
44 |
--------------------------------------------------------------------------------
/ppln/common.go:
--------------------------------------------------------------------------------
1 | package ppln
2 |
3 | import (
4 | "iter"
5 | )
6 |
7 | // SliceInput returns a function that iterates over a slice,
8 | // to be used as the input function in [Serial] and [NonSerial].
9 | func SliceInput[T any](s []T) iter.Seq2[T, error] {
10 | return func(yield func(T, error) bool) {
11 | for _, t := range s {
12 | if !yield(t, nil) {
13 | return
14 | }
15 | }
16 | }
17 | }
18 |
19 | // RangeInput returns a function that iterates over a range of integers,
20 | // starting at start and ending at (and excluding) stop,
21 | // to be used as the input function in [Serial] and [NonSerial].
22 | func RangeInput(start, stop int) iter.Seq2[int, error] {
23 | return func(yield func(int, error) bool) {
24 | for i := start; i < stop; i++ {
25 | if !yield(i, nil) {
26 | return
27 | }
28 | }
29 | }
30 | }
31 |
--------------------------------------------------------------------------------
/clustering/kmeans_test.go:
--------------------------------------------------------------------------------
1 | package clustering
2 |
3 | import (
4 | "math/rand"
5 | "reflect"
6 | "testing"
7 | "time"
8 | )
9 |
10 | func TestKmeans(t *testing.T) {
11 | rand.Seed(time.Now().UnixNano())
12 |
13 | m := [][]float64{
14 | {0.1, 0.0},
15 | {0.9, 1.0},
16 | {-0.1, 0.0},
17 | {0.0, -0.1},
18 | {1.1, 1.0},
19 | {1.0, 1.1},
20 | {1.0, 0.9},
21 | {0.0, 0.1},
22 | }
23 |
24 | means, tags := Kmeans(m, 2)
25 |
26 | if tags[0] == 0 {
27 | assertEqual(tags, []int{0, 1, 0, 0, 1, 1, 1, 0}, t)
28 | assertEqual(means, [][]float64{{0, 0}, {1, 1}}, t)
29 | } else {
30 | assertEqual(tags, []int{1, 0, 1, 1, 0, 0, 0, 1}, t)
31 | assertEqual(means, [][]float64{{1, 1}, {0, 0}}, t)
32 | }
33 | }
34 |
// assertEqual fails the test immediately if act and exp are not deeply equal
// according to reflect.DeepEqual.
func assertEqual(act, exp any, t *testing.T) {
	// Report failures at the caller's line rather than inside this helper.
	t.Helper()
	if !reflect.DeepEqual(act, exp) {
		t.Fatalf("Wrong value: %v, expected %v", act, exp)
	}
}
40 |
--------------------------------------------------------------------------------
/rhash/buz_test.go:
--------------------------------------------------------------------------------
1 | package rhash
2 |
3 | import (
4 | "hash"
5 | "testing"
6 | )
7 |
8 | func TestBuz64(t *testing.T) {
9 | test64(t, func(i int) hash.Hash64 { return NewBuz(i) })
10 | }
11 |
12 | func TestBuz32(t *testing.T) {
13 | test32(t, func(i int) hash.Hash32 { return NewBuz(i) })
14 | }
15 |
16 | func TestBuz64_seed(t *testing.T) {
17 | seed := BuzRandomSeed()
18 | test64(t, func(i int) hash.Hash64 {
19 | return NewBuzWithSeed(i, seed)
20 | })
21 | }
22 |
23 | func TestBuz64_modifySeed(t *testing.T) {
24 | seed := &BuzSeed{}
25 | for i := range seed {
26 | seed[i] = uint64(i)
27 | }
28 | input := "dsfhjkdfhdjsfjksdjadi"
29 | h := NewBuzWithSeed(len(input), seed)
30 | h.Write([]byte(input))
31 | want := h.Sum64()
32 | for i := range seed {
33 | seed[i] = 0
34 | }
35 | h.Write([]byte(input))
36 | got := h.Sum64()
37 | if got != want {
38 | t.Fatalf("Buz.Sum64(%q)=%v, want %v", input, got, want)
39 | }
40 | }
41 |
--------------------------------------------------------------------------------
/ppln/doc.go:
--------------------------------------------------------------------------------
1 | // Package ppln provides generic parallel processing pipelines.
2 | //
3 | // # General usage
4 | //
5 | // This package provides two modes of operation: serial and non-serial.
6 | // In [Serial] the outputs are ordered in the same order of the inputs.
7 | // In [NonSerial] the order of outputs is arbitrary,
8 | // but correlated with the order of inputs.
9 | //
10 | // Each of the functions blocks the calling function until either the processing
11 | // is done (output was called on the last value) or until an error is returned.
12 | //
13 | // # Stopping
14 | //
15 | // Each user-function (input, transform, output) may return an error.
16 | // Returning a non-nil error stops the pipeline prematurely, and that
17 | // error is returned to the caller.
18 | //
19 | // # Requirements
20 | //
21 | // This package relies on the standard [iter] package and on range-over-func
22 | // iterators, which are part of the language as of go 1.23.
23 | package ppln
24 |
--------------------------------------------------------------------------------
/morris/error/error.go:
--------------------------------------------------------------------------------
1 | // Prints error rates of Morris counter using different m's.
2 | package main
3 |
4 | import (
5 | "fmt"
6 |
7 | "github.com/fluhus/gostuff/gnum"
8 | "github.com/fluhus/gostuff/morris"
9 | "github.com/fluhus/gostuff/ptimer"
10 | )
11 |
12 | const (
13 | upto = 10000000
14 | reps = 100
15 | )
16 |
17 | func main() {
18 | ms := []uint{1, 3, 10, 30, 100, 300, 1000, 3000, 10000}
19 | var errs []float64
20 | for _, m := range ms {
21 | pt := ptimer.NewMessage(fmt.Sprint("{} (", m, ")"))
22 | err := 0.0
23 | for rep := 1; rep <= reps; rep++ {
24 | c := uint(0)
25 | for i := 1; i <= upto; i++ {
26 | c = morris.Raise(c, m)
27 | r := morris.Restore(c, m)
28 | err += float64(gnum.Abs(int(i)-int(r))) / float64(i)
29 | pt.Inc()
30 | }
31 | }
32 | pt.Done()
33 | errs = append(errs, err/upto/reps)
34 | }
35 | for i := range ms {
36 | fmt.Printf("// % 10d: %.1f%%\n", ms[i], errs[i]*100)
37 | }
38 | }
39 |
--------------------------------------------------------------------------------
/nlp/wordnet/string.go:
--------------------------------------------------------------------------------
1 | package wordnet
2 |
3 | import (
4 | "bytes"
5 | "fmt"
6 | )
7 |
8 | // String functions for types.
9 |
10 | // TODO(amit): Consider removing String() functions to simplify the API.
11 |
12 | // Returns a compact string representation of the WordNet data collection, for
13 | // debugging.
14 | func (wn *WordNet) String() string {
15 | return fmt.Sprintf("WordNet[%d lemmas, %d synsets, %d exceptions,"+
16 | " %d examples]",
17 | len(wn.Lemma), len(wn.Synset), len(wn.Exception), len(wn.Example))
18 | }
19 |
20 | // Returns a string representation of the synset, for debugging.
21 | func (s *Synset) String() string {
22 | result := bytes.NewBuffer(make([]byte, 0, 100))
23 | fmt.Fprintf(result, "Synset[%s.", s.Pos)
24 | for i, word := range s.Word {
25 | if i > 0 {
26 | fmt.Fprintf(result, ",")
27 | }
28 | fmt.Fprintf(result, " %v", word)
29 | }
30 | fmt.Fprintf(result, ": %s]", s.Gloss)
31 | return result.String()
32 | }
33 |
--------------------------------------------------------------------------------
/csvx/csvx_test.go:
--------------------------------------------------------------------------------
1 | package csvx
2 |
3 | import (
4 | "slices"
5 | "strings"
6 | "testing"
7 |
8 | "github.com/fluhus/gostuff/iterx"
9 | )
10 |
11 | func TestReader(t *testing.T) {
12 | input := "a,bb,ccc\nddd,ee,f"
13 | want := [][]string{{"a", "bb", "ccc"}, {"ddd", "ee", "f"}}
14 | got, err := iterx.CollectErr(Reader(strings.NewReader(input)))
15 | if err != nil {
16 | t.Fatalf("Reader(%q) failed: %v", input, err)
17 | }
18 | if !slices.EqualFunc(got, want, slices.Equal) {
19 | t.Fatalf("Reader(%q)=%q, want %q", input, got, want)
20 | }
21 | }
22 |
23 | func TestReaderTSV(t *testing.T) {
24 | input := "a\tbb\tccc\nddd\tee\tf"
25 | want := [][]string{{"a", "bb", "ccc"}, {"ddd", "ee", "f"}}
26 | got, err := iterx.CollectErr(Reader(strings.NewReader(input), TSV))
27 | if err != nil {
28 | t.Fatalf("Reader(%q) failed: %v", input, err)
29 | }
30 | if !slices.EqualFunc(got, want, slices.Equal) {
31 | t.Fatalf("Reader(%q)=%q, want %q", input, got, want)
32 | }
33 | }
34 |
--------------------------------------------------------------------------------
/csvx/examples_test.go:
--------------------------------------------------------------------------------
1 | package csvx
2 |
3 | import (
4 | "fmt"
5 | "strings"
6 | )
7 |
8 | func ExampleDecodeReader() {
9 | type person struct {
10 | Name string
11 | Age int
12 | }
13 | input := strings.NewReader("alice,30\nbob,25")
14 |
15 | for p, err := range DecodeReader[person](input) {
16 | if err != nil {
17 | panic(err)
18 | }
19 | fmt.Println(p.Name, "is", p.Age, "years old")
20 | }
21 |
22 | //Output:
23 | //alice is 30 years old
24 | //bob is 25 years old
25 | }
26 |
27 | func ExampleDecodeReaderHeader() {
28 | type person struct {
29 | Name string
30 | Age int
31 | }
32 | input := strings.NewReader(
33 | "user_id,age,city,name\n" +
34 | "111,30,paris,alice\n" +
35 | "222,25,london,bob")
36 |
37 | for p, err := range DecodeReaderHeader[person](input) {
38 | if err != nil {
39 | panic(err)
40 | }
41 | fmt.Println(p.Name, "is", p.Age, "years old")
42 | }
43 |
44 | //Output:
45 | //alice is 30 years old
46 | //bob is 25 years old
47 | }
48 |
--------------------------------------------------------------------------------
/repeat/repeat_test.go:
--------------------------------------------------------------------------------
1 | package repeat
2 |
3 | import (
4 | "io"
5 | "testing"
6 | )
7 |
8 | func TestReader(t *testing.T) {
9 | r := NewReader([]byte("amit"), 2)
10 | buf := make([]byte, 3)
11 | want := []string{"ami", "tam", "it"}
12 |
13 | for _, w := range want {
14 | n, err := r.Read(buf)
15 | if err != nil {
16 | t.Fatalf("Read() failed: %v", err)
17 | }
18 | if got := string(buf[:n]); got != w {
19 | t.Fatalf("Read()=%q, want %q", got, "ami")
20 | }
21 | }
22 | if _, err := r.Read(buf); err != io.EOF {
23 | t.Fatalf("Read() err=%v, want EOF", err)
24 | }
25 | }
26 |
27 | func TestReader_infinite(t *testing.T) {
28 | r := NewReader([]byte("amit"), -1)
29 | buf := make([]byte, 3)
30 | want := []string{"ami", "tam", "ita", "mit", "ami", "tam", "ita", "mit"}
31 |
32 | for i, w := range want {
33 | n, err := r.Read(buf)
34 | if err != nil {
35 | t.Fatalf("#%v: Read() failed: %v", i, err)
36 | }
37 | if got := string(buf[:n]); got != w {
38 | t.Fatalf("#%v: Read()=%q, want %q", i, got, "ami")
39 | }
40 | }
41 | }
42 |
--------------------------------------------------------------------------------
/reservoir/reservoir.go:
--------------------------------------------------------------------------------
1 | // Package reservoir implements reservoir sampling.
2 | //
3 | // Reservoir sampling allows sampling m uniformly random elements
4 | // from a stream,
5 | // using O(m) memory regardless of the stream length.
6 | package reservoir
7 |
8 | import "math/rand/v2"
9 |
10 | // Sampler samples a fixed number of elements with uniform distribution
11 | // from a stream.
12 | type Sampler[T any] struct {
13 | Elements []T // Elements selected so far.
14 | r *rand.Rand
15 | n int
16 | }
17 |
18 | // New returns a new sampler that samples n elements.
19 | func New[T any](n int) *Sampler[T] {
20 | return &Sampler[T]{
21 | Elements: make([]T, 0, n),
22 | r: rand.New(rand.NewPCG(rand.Uint64(), rand.Uint64())),
23 | }
24 | }
25 |
26 | // Add maybe adds t to the selected sample.
27 | func (r *Sampler[T]) Add(t T) {
28 | r.n++
29 | if len(r.Elements) < cap(r.Elements) {
30 | r.Elements = append(r.Elements, t)
31 | return
32 | }
33 | i := r.r.IntN(r.n)
34 | if i >= len(r.Elements) {
35 | return
36 | }
37 | r.Elements[i] = t
38 | }
39 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright (c) 2015 Amit Lavon
2 |
3 | Permission is hereby granted, free of charge, to any person obtaining a copy of
4 | this software and associated documentation files (the "Software"), to deal in
5 | the Software without restriction, including without limitation the rights to
6 | use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
7 | the Software, and to permit persons to whom the Software is furnished to do so,
8 | subject to the following conditions:
9 |
10 | The above copyright notice and this permission notice shall be included in all
11 | copies or substantial portions of the Software.
12 |
13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
15 | FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
16 | COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
17 | IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
18 | CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
19 |
--------------------------------------------------------------------------------
/nlp/tfidf.go:
--------------------------------------------------------------------------------
1 | package nlp
2 |
3 | // TF-IDF functionality.
4 |
5 | import (
6 | "math"
7 | )
8 |
9 | // TfIdf returns the TF-IDF scores of the given corpus. For each documet,
10 | // returns a map from token to TF-IDF score.
11 | //
12 | // TF = count(token in document) / count(all tokens in document)
13 | //
14 | // IDF = log(count(documents) / count(documents with token))
15 | func TfIdf(docTokens [][]string) []map[string]float64 {
16 | tf := make([]map[string]float64, len(docTokens))
17 | idf := map[string]float64{}
18 |
19 | // Collect TF and DF.
20 | for i := range docTokens {
21 | tf[i] = map[string]float64{}
22 | for j := range docTokens[i] {
23 | tf[i][docTokens[i][j]]++
24 | }
25 | for token := range tf[i] {
26 | tf[i][token] /= float64(len(docTokens[i]))
27 | idf[token]++
28 | }
29 | }
30 |
31 | // Turn DF to IDF.
32 | for token, df := range idf {
33 | idf[token] = math.Log(float64(len(docTokens)) / df)
34 | }
35 |
36 | // Turn TF to TF-IDF.
37 | for i := range tf {
38 | for token := range tf[i] {
39 | tf[i][token] *= idf[token]
40 | }
41 | }
42 |
43 | return tf
44 | }
45 |
--------------------------------------------------------------------------------
/ppln/common_test.go:
--------------------------------------------------------------------------------
1 | package ppln
2 |
3 | import (
4 | "slices"
5 | "testing"
6 | )
7 |
8 | func TestRangeInput(t *testing.T) {
9 | tests := []struct {
10 | from, to int
11 | want []int
12 | }{
13 | {0, 5, []int{0, 1, 2, 3, 4}},
14 | {-2, 3, []int{-2, -1, 0, 1, 2}},
15 | {0, 0, nil},
16 | {1, 1, nil},
17 | }
18 | for _, test := range tests {
19 | var got []int
20 | for i, err := range RangeInput(test.from, test.to) {
21 | if err != nil {
22 | t.Fatalf("RangeInput(%d,%d) returned error: %v",
23 | test.from, test.to, err)
24 | }
25 | got = append(got, i)
26 | }
27 | if !slices.Equal(got, test.want) {
28 | t.Fatalf("RangeInput(%d,%d)=%v, want %v",
29 | test.from, test.to, got, test.want)
30 | }
31 | }
32 | }
33 |
34 | func TestSliceInput(t *testing.T) {
35 | want := []int{3, 5, 1, 7}
36 | var got []int
37 | for i, err := range SliceInput(slices.Clone(want)) {
38 | if err != nil {
39 | t.Fatalf("SliceInput(%v) returned error: %v",
40 | want, err)
41 | }
42 | got = append(got, i)
43 | }
44 | if !slices.Equal(got, want) {
45 | t.Fatalf("SliceInput(%v)=%v, want %v",
46 | want, got, want)
47 | }
48 | }
49 |
--------------------------------------------------------------------------------
/gnum/vecs_test.go:
--------------------------------------------------------------------------------
1 | package gnum
2 |
3 | import (
4 | "slices"
5 | "testing"
6 | )
7 |
8 | func TestOnes(t *testing.T) {
9 | want := []int8{1, 1, 1, 1}
10 | got := Ones[[]int8](4)
11 | if !slices.Equal(got, want) {
12 | t.Errorf("Ones(4) = %v, want %v", got, want)
13 | }
14 | }
15 |
16 | func TestAdd(t *testing.T) {
17 | a := []uint{4, 6}
18 | b := []uint{2, 3}
19 | got := Add(nil, a, b)
20 | want := []uint{6, 9}
21 | if !slices.Equal(got, want) {
22 | t.Errorf("Add(nil, %v, %v)=%v, want %v", a, b, got, want)
23 | }
24 | if want := []uint{4, 6}; !slices.Equal(a, want) {
25 | t.Errorf("a=%v, want %v", a, want)
26 | }
27 | if want := []uint{2, 3}; !slices.Equal(b, want) {
28 | t.Errorf("b=%v, want %v", b, want)
29 | }
30 | }
31 |
32 | func TestAdd_inplace(t *testing.T) {
33 | a := []uint{4, 6}
34 | b := []uint{2, 3}
35 | got := Add(a, b)
36 | want := []uint{6, 9}
37 | if !slices.Equal(got, want) {
38 | t.Errorf("Add(nil, %v, %v)=%v, want %v", a, b, got, want)
39 | }
40 | if !slices.Equal(a, want) {
41 | t.Errorf("a=%v, want %v", a, want)
42 | }
43 | if want := []uint{2, 3}; !slices.Equal(b, want) {
44 | t.Errorf("b=%v, want %v", b, want)
45 | }
46 | }
47 |
--------------------------------------------------------------------------------
/clustering/agglo_test.go:
--------------------------------------------------------------------------------
1 | package clustering
2 |
3 | import (
4 | "math"
5 | "reflect"
6 | "testing"
7 | )
8 |
9 | // TODO(amit): Add more test cases.
10 |
11 | func TestClink(t *testing.T) {
12 | points := []float64{0, 1, 5}
13 | steps := []AggloStep{
14 | {0, 1, 1},
15 | {1, 2, 5},
16 | }
17 | agg := clink(len(points), func(i, j int) float64 {
18 | return math.Abs(points[i] - points[j])
19 | })
20 | if agg.Len() != len(points)-1 {
21 | t.Fatalf("Len()=%v, want %v", agg.Len(), len(points)-1)
22 | }
23 | for i := range steps {
24 | if step := agg.Step(i); !reflect.DeepEqual(steps[i], step) {
25 | t.Errorf("Step(%v)=%v, want %v", i, step, steps[i])
26 | }
27 | }
28 | }
29 |
30 | func TestSlink(t *testing.T) {
31 | points := []float64{0, 1, 5}
32 | steps := []AggloStep{
33 | {0, 1, 1},
34 | {1, 2, 4},
35 | }
36 | agg := slink(len(points), func(i, j int) float64 {
37 | return math.Abs(points[i] - points[j])
38 | })
39 | if agg.Len() != len(points)-1 {
40 | t.Fatalf("Len()=%v, want %v", agg.Len(), len(points)-1)
41 | }
42 | for i := range steps {
43 | if step := agg.Step(i); !reflect.DeepEqual(steps[i], step) {
44 | t.Errorf("Step(%v)=%v, want %v", i, step, steps[i])
45 | }
46 | }
47 | }
48 |
--------------------------------------------------------------------------------
/snm/queue_test.go:
--------------------------------------------------------------------------------
1 | package snm
2 |
3 | import "testing"
4 |
5 | func TestQueue(t *testing.T) {
6 | q := &Queue[int]{}
7 |
8 | qEnqueue(q, 11)
9 | qExpect(q, t, 11)
10 |
11 | qEnqueue(q, 22, 33)
12 | qExpect(q, t, 22, 33)
13 |
14 | qEnqueue(q, 44, 55, 66)
15 | qExpect(q, t, 44)
16 | qEnqueue(q, 77)
17 | qExpect(q, t, 55)
18 | qEnqueue(q, 88)
19 | qExpect(q, t, 66, 77, 88)
20 | }
21 |
22 | func qEnqueue(q *Queue[int], x ...int) {
23 | for _, xx := range x {
24 | q.Enqueue(xx)
25 | }
26 | }
27 |
28 | func qExpect(q *Queue[int], t *testing.T, x ...int) {
29 | for _, xx := range x {
30 | if got := q.Peek(); got != xx {
31 | t.Fatalf("q.pull()=%v, want %v", got, xx)
32 | }
33 | if got := q.Dequeue(); got != xx {
34 | t.Fatalf("q.pull()=%v, want %v", got, xx)
35 | }
36 | }
37 | }
38 |
39 | func FuzzQueue(f *testing.F) {
40 | f.Add(1, 1, 1, 1, 1, 1, 1, 1, 1, 1)
41 | f.Fuzz(func(t *testing.T, a, b, c, d, e, f, g, h, i, j int) {
42 | q := &Queue[int]{}
43 | qEnqueue(q, a)
44 | qExpect(q, t, a)
45 | qEnqueue(q, b, c)
46 | qExpect(q, t, b, c)
47 | qEnqueue(q, d, e)
48 | qExpect(q, t, d)
49 | qEnqueue(q, f)
50 | qExpect(q, t, e)
51 | qEnqueue(q, g, h, i)
52 | qExpect(q, t, f, g)
53 | qEnqueue(q, j)
54 | qExpect(q, t, h, i, j)
55 | })
56 | }
57 |
--------------------------------------------------------------------------------
/nlp/tokenize.go:
--------------------------------------------------------------------------------
1 | package nlp
2 |
3 | import (
4 | "github.com/agonopol/go-stem"
5 | "regexp"
6 | "strings"
7 | )
8 |
// Tokenizer is the regexp that matches a single word when splitting text
// into tokens: a word character, optionally followed by word characters and
// apostrophes that end in a word character.
// Changing this regexp will affect the Tokenize function.
var Tokenizer = regexp.MustCompile(`\w([\w']*\w)?`)
12 |
13 | // Tokenize splits a given text to a slice of stemmed, lowercase words. If
14 | // keepStopWords is false, will drop stop words.
15 | func Tokenize(s string, keepStopWords bool) []string {
16 | s = correctUtf8Punctuation(s)
17 | s = strings.ToLower(s)
18 | words := Tokenizer.FindAllString(s, -1)
19 | var result []string
20 | for _, word := range words {
21 | if !keepStopWords && StopWords[word] {
22 | continue
23 | }
24 | result = append(result, Stem(word))
25 | }
26 |
27 | return result
28 | }
29 |
30 | // Stem porter-stems the given word.
31 | func Stem(s string) string {
32 | if strings.HasSuffix(s, "'s") {
33 | s = s[:len(s)-2]
34 | }
35 | return string(stemmer.Stem([]byte(s)))
36 | }
37 |
// correctUtf8Punctuation translates non-ASCII punctuation characters to
// ASCII. Currently it only replaces the right single quotation mark with an
// apostrophe.
//
// TODO(amit): Improve this function with more characters.
func correctUtf8Punctuation(s string) string {
	return strings.ReplaceAll(s, "’", "'")
}
43 |
--------------------------------------------------------------------------------
/sets/sorted_test.go:
--------------------------------------------------------------------------------
1 | package sets
2 |
3 | import (
4 | "slices"
5 | "testing"
6 | )
7 |
8 | func TestSorted(t *testing.T) {
9 | tests := []struct {
10 | a, b, u, i []int
11 | }{
12 | {nil, nil, nil, nil},
13 | {[]int{1}, nil, []int{1}, nil},
14 | {nil, []int{2}, []int{2}, nil},
15 | {[]int{1}, []int{2}, []int{1, 2}, nil},
16 | {[]int{2}, []int{1}, []int{1, 2}, nil},
17 | {[]int{1, 3, 5}, []int{3, 4, 5, 6}, []int{1, 3, 4, 5, 6}, []int{3, 5}},
18 |
19 | {[]int{1, 1}, []int{1, 2}, []int{1, 1, 2}, []int{1}},
20 | {[]int{1, 2, 2, 3}, []int{2, 3, 3, 4}, []int{1, 2, 2, 3, 3, 4}, []int{2, 3}},
21 | }
22 | for _, test := range tests {
23 | i := SortedIntersection(test.a, test.b)
24 | u := SortedUnion(test.a, test.b)
25 | il := SortedIntersectionLen(test.a, test.b)
26 | ul := SortedUnionLen(test.a, test.b)
27 | if !slices.Equal(i, test.i) {
28 | t.Fatalf("SortedIntersection(%v,%v)=%v, want %v",
29 | test.a, test.b, i, test.i)
30 | }
31 | if !slices.Equal(u, test.u) {
32 | t.Fatalf("SortedUnion(%v,%v)=%v, want %v",
33 | test.a, test.b, u, test.u)
34 | }
35 | if il != len(i) {
36 | t.Fatalf("SortedIntersectionLen(%v,%v)=%v, want %v",
37 | test.a, test.b, il, len(i))
38 | }
39 | if ul != len(u) {
40 | t.Fatalf("SortedUnionLen(%v,%v)=%v, want %v",
41 | test.a, test.b, ul, len(u))
42 | }
43 | }
44 | }
45 |
--------------------------------------------------------------------------------
/reservoir/reservoir_test.go:
--------------------------------------------------------------------------------
1 | package reservoir
2 |
3 | import (
4 | "fmt"
5 | "slices"
6 | "testing"
7 | )
8 |
9 | func TestSmall(t *testing.T) {
10 | var want []int
11 | r := New[int](10)
12 | if len(r.Elements) != 0 {
13 | t.Fatalf("len(E)=%v, want 0", len(r.Elements))
14 | }
15 | for i := range 10 {
16 | r.Add(i)
17 | want = append(want, i)
18 | if !slices.Equal(r.Elements, want) {
19 | t.Fatalf("E=%v, want %v", r.Elements, want)
20 | }
21 | }
22 | r.Add(10)
23 | want = append(want, 10)
24 | if slices.Equal(r.Elements, want) {
25 | t.Fatalf("E=%v, want smaller", r.Elements)
26 | }
27 | if len(r.Elements) != 10 {
28 | t.Fatalf("len(E)=%v, want 10", len(r.Elements))
29 | }
30 | }
31 |
32 | func TestBig(t *testing.T) {
33 | counts := make([]int, 10)
34 | for range 1000 {
35 | r := New[int](5)
36 | for i := range 10 {
37 | r.Add(i)
38 | }
39 | for _, i := range r.Elements {
40 | counts[i]++
41 | }
42 | }
43 | wantMin, wantMax := 450, 550
44 | for i, c := range counts {
45 | if c < wantMin || c > wantMax {
46 | t.Errorf("count[%v]=%v, want %v-%v", i, c, wantMin, wantMax)
47 | }
48 | }
49 | }
50 |
51 | func Example() {
52 | // Select 10 random elements uniformly out of a stream.
53 | sampler := New[int](10)
54 | for i := range 1000 {
55 | sampler.Add(i)
56 | }
57 | fmt.Println("Selected subsample:", sampler.Elements)
58 | }
59 |
--------------------------------------------------------------------------------
/xmlnode/read_test.go:
--------------------------------------------------------------------------------
1 | package xmlnode
2 |
3 | import (
4 | "strings"
5 | "testing"
6 | )
7 |
8 | func TestReadAll(t *testing.T) {
9 | text := "worldworld"
10 | node, err := ReadAll(strings.NewReader(text))
11 |
12 | if err != nil {
13 | t.Fatal(err)
14 | }
15 |
16 | result := nodeToString(node)
17 | if result != text {
18 | t.Fatal("expected:", text, "actual:", result)
19 | }
20 | }
21 |
22 | // Inefficient stringifier for testing.
23 | func nodeToString(n Node) string {
24 | switch n := n.(type) {
25 | case *root:
26 | result := ""
27 | for _, child := range n.children {
28 | result += nodeToString(child)
29 | }
30 | return result
31 |
32 | case *tag:
33 | result := "<" + n.tagName
34 | for _, attr := range n.attr {
35 | result += " " + attr.Name.Local + "=\"" + attr.Value + "\""
36 | }
37 | result += ">"
38 | for _, child := range n.children {
39 | result += nodeToString(child)
40 | }
41 | result += "" + n.tagName + ">"
42 | return result
43 |
44 | case *text:
45 | return n.text
46 |
47 | case *comment:
48 | return ""
49 |
50 | case *procInst:
51 | return "" + n.target + " " + n.inst + "?>"
52 |
53 | case *directive:
54 | return ""
55 |
56 | default:
57 | panic("Unknown node type.")
58 | }
59 | }
60 |
--------------------------------------------------------------------------------
/csvx/csvx.go:
--------------------------------------------------------------------------------
1 | package csvx
2 |
3 | import (
4 | "encoding/csv"
5 | "io"
6 | "iter"
7 |
8 | "github.com/fluhus/gostuff/aio"
9 | )
10 |
11 | // Reader iterates over CSV entries from a reader.
12 | // Applies the given modifiers before iteration.
13 | func Reader(r io.Reader, mods ...ReaderModifier) iter.Seq2[[]string, error] {
14 | return func(yield func([]string, error) bool) {
15 | c := csv.NewReader(r)
16 | for _, mod := range mods {
17 | mod(c)
18 | }
19 | for {
20 | e, err := c.Read()
21 | if err == io.EOF {
22 | return
23 | }
24 | if !yield(e, nil) {
25 | return
26 | }
27 | }
28 | }
29 | }
30 |
31 | // File iterates over CSV entries from a file.
32 | // Applies the given modifiers before iteration.
33 | func File(file string, mods ...ReaderModifier) iter.Seq2[[]string, error] {
34 | return func(yield func([]string, error) bool) {
35 | f, err := aio.Open(file)
36 | if err != nil {
37 | yield(nil, err)
38 | return
39 | }
40 | defer f.Close()
41 | c := csv.NewReader(f)
42 | for _, mod := range mods {
43 | mod(c)
44 | }
45 | for {
46 | e, err := c.Read()
47 | if err == io.EOF {
48 | return
49 | }
50 | if !yield(e, nil) {
51 | return
52 | }
53 | }
54 | }
55 | }
56 |
// ReaderModifier modifies the settings of a CSV reader
// before iteration starts.
//
// Modifiers are passed to [Reader] and [File], which call each of them,
// in order, on the underlying [csv.Reader] before reading the first record.
type ReaderModifier = func(*csv.Reader)
60 |
// TSV is a reader modifier that sets the field delimiter to a tab.
func TSV(r *csv.Reader) {
	const tab = '\t'
	r.Comma = tab
}
65 |
--------------------------------------------------------------------------------
/graphs/bfsdfs.go:
--------------------------------------------------------------------------------
1 | package graphs
2 |
3 | import (
4 | "iter"
5 |
6 | "github.com/fluhus/gostuff/sets"
7 | "github.com/fluhus/gostuff/snm"
8 | )
9 |
10 | // BFS iterates over this graph's nodes in a breadth-first ordering,
11 | // including start.
12 | func (g *Graph[T]) BFS(start T) iter.Seq[T] {
13 | return func(yield func(T) bool) {
14 | if _, ok := g.v[start]; !ok {
15 | return
16 | }
17 |
18 | elems := g.v.Elements()
19 | edges := g.edgeSlices()
20 | istart := g.v.IndexOf(start)
21 | done := sets.Set[int]{}.Add(istart)
22 | q := &snm.Queue[int]{}
23 | q.Enqueue(istart)
24 |
25 | for v := range q.Seq() {
26 | if !yield(elems[v]) {
27 | return
28 | }
29 | for _, e := range edges[v] {
30 | if done.Has(e) {
31 | continue
32 | }
33 | done.Add(e)
34 | q.Enqueue(e)
35 | }
36 | }
37 | }
38 | }
39 |
40 | // DFS iterates over this graph's nodes in a depth-first ordering,
41 | // including start.
42 | func (g *Graph[T]) DFS(start T) iter.Seq[T] {
43 | return func(yield func(T) bool) {
44 | if _, ok := g.v[start]; !ok {
45 | return
46 | }
47 |
48 | elems := g.v.Elements()
49 | edges := g.edgeSlices()
50 | istart := g.v.IndexOf(start)
51 | done := sets.Set[int]{}.Add(istart)
52 | q := []int{istart}
53 |
54 | for len(q) > 0 {
55 | v := q[len(q)-1]
56 | q = q[:len(q)-1]
57 | if !yield(elems[v]) {
58 | return
59 | }
60 | for _, e := range edges[v] {
61 | if done.Has(e) {
62 | continue
63 | }
64 | done.Add(e)
65 | q = append(q, e)
66 | }
67 | }
68 | }
69 | }
70 |
--------------------------------------------------------------------------------
/rhash/fuzz_test.go:
--------------------------------------------------------------------------------
1 | package rhash
2 |
3 | import (
4 | "hash"
5 | "testing"
6 |
7 | "github.com/fluhus/gostuff/snm"
8 | )
9 |
10 | func FuzzBuz64(f *testing.F) {
11 | fuzz64(f, func(n int) hash.Hash64 { return NewBuz(n) })
12 | }
13 |
14 | func FuzzRabin64(f *testing.F) {
15 | fuzz64(f, func(n int) hash.Hash64 { return NewRabinFingerprint64(n) })
16 | }
17 |
18 | func FuzzBuz32(f *testing.F) {
19 | fuzz32(f, func(n int) hash.Hash32 { return NewBuz(n) })
20 | }
21 |
22 | func FuzzRabin32(f *testing.F) {
23 | fuzz32(f, func(n int) hash.Hash32 { return NewRabinFingerprint32(n) })
24 | }
25 |
26 | func fuzz64(f *testing.F, fn func(n int) hash.Hash64) {
27 | const m = 10
28 | prefix := snm.Slice(30, func(i int) byte { return byte(i) })
29 |
30 | f.Add([]byte{})
31 | f.Fuzz(func(t *testing.T, a []byte) {
32 | a = append(prefix, a...)
33 | h := fn(m)
34 | h.Write(a[len(a)-m:])
35 | want := h.Sum64()
36 | for i := range a[m:] {
37 | h.Write(a[i:])
38 | if got := h.Sum64(); got != want {
39 | t.Fatalf("Sum64(%v)=%v, want %v", a[i:], got, want)
40 | }
41 | }
42 | })
43 | }
44 |
45 | func fuzz32(f *testing.F, fn func(n int) hash.Hash32) {
46 | const m = 10
47 | prefix := snm.Slice(30, func(i int) byte { return byte(i) })
48 |
49 | f.Add([]byte{})
50 | f.Fuzz(func(t *testing.T, a []byte) {
51 | a = append(prefix, a...)
52 | h := fn(m)
53 | h.Write(a[len(a)-m:])
54 | want := h.Sum32()
55 | for i := range a[m:] {
56 | h.Write(a[i:])
57 | if got := h.Sum32(); got != want {
58 | t.Fatalf("Sum32(%v)=%v, want %v", a[i:], got, want)
59 | }
60 | }
61 | })
62 | }
63 |
--------------------------------------------------------------------------------
/ezpprof/ezpprof.go:
--------------------------------------------------------------------------------
1 | // Package ezpprof is a convenience wrapper over the runtime/pprof package.
2 | //
3 | // This package helps to quickly introduce profiling to a piece of code without
4 | // the mess of opening files and checking errors.
5 | //
6 | // A typical use of this package looks like:
7 | //
8 | // ezpprof.Start("myfile.pprof")
9 | // {... some complicated code ...}
10 | // ezpprof.Stop()
11 | //
12 | // Or alternatively:
13 | //
14 | // const profile = true
15 | //
16 | // if profile {
17 | // ezpprof.Start("myfile.pprof")
18 | // defer ezpprof.Stop()
19 | // }
20 | package ezpprof
21 |
22 | import (
23 | "io"
24 | "runtime/pprof"
25 |
26 | "github.com/fluhus/gostuff/aio"
27 | )
28 |
29 | var fout io.WriteCloser
30 |
31 | // Start starts CPU profiling and writes to the given file.
32 | // Panics if an error occurs.
33 | func Start(file string) {
34 | if fout != nil {
35 | panic("already profiling")
36 | }
37 | f, err := aio.CreateRaw(file)
38 | if err != nil {
39 | panic(err)
40 | }
41 | fout = f
42 | pprof.StartCPUProfile(fout)
43 | }
44 |
45 | // Stop stops CPU profiling and closes the output file.
46 | // Panics if called without calling Start.
47 | func Stop() {
48 | if fout == nil {
49 | panic("Stop called without calling Start")
50 | }
51 | pprof.StopCPUProfile()
52 | if err := fout.Close(); err != nil {
53 | panic(err)
54 | }
55 | fout = nil
56 | }
57 |
58 | // Heap writes heap profile to the given file. Panics if an error occurs.
59 | func Heap(file string) {
60 | f, err := aio.Create(file)
61 | if err != nil {
62 | panic(err)
63 | }
64 | defer f.Close()
65 |
66 | err = pprof.WriteHeapProfile(f)
67 | if err != nil {
68 | panic(err)
69 | }
70 | }
71 |
--------------------------------------------------------------------------------
/snm/queue.go:
--------------------------------------------------------------------------------
1 | package snm
2 |
3 | import "iter"
4 |
5 | // Queue is a memory-efficient FIFO container.
6 | type Queue[T any] struct {
7 | q []T
8 | i, n int
9 | }
10 |
11 | // Enqueue inserts an element to the queue.
12 | func (q *Queue[T]) Enqueue(x T) {
13 | if q.n == len(q.q) {
14 | nq := make([]T, len(q.q)*2+1)
15 | _ = append(append(append(nq[:0], q.q[q.i:]...), q.q[:q.i]...), x)
16 | q.q = nq
17 | q.n++
18 | q.i = 0
19 | return
20 | }
21 | i := (q.i + q.n) % len(q.q)
22 | q.q[i] = x
23 | q.n++
24 | }
25 |
26 | // Dequeue removes and returns the next element in the queue.
27 | // Panics if the queue is empty.
28 | func (q *Queue[T]) Dequeue() T {
29 | if q.n == 0 {
30 | panic("pull with 0 elements")
31 | }
32 | x := q.q[q.i]
33 | var zero T
34 | q.q[q.i] = zero // Remove element to allow GC.
35 | q.n--
36 | q.i = (q.i + 1) % len(q.q)
37 | return x
38 | }
39 |
40 | // Peek returns the next element in the queue,
41 | // without modifying its contents.
42 | // Panics if the queue is empty.
43 | func (q *Queue[T]) Peek() T {
44 | if q.n == 0 {
45 | panic("pull with 0 elements")
46 | }
47 | return q.q[q.i]
48 | }
49 |
50 | // Len return the current number of elements in the queue.
51 | func (q *Queue[T]) Len() int {
52 | return q.n
53 | }
54 |
55 | // Seq returns an iterator over the queue's elements,
56 | // dequeueing each one.
57 | //
58 | // It is okay to enqueue elements while iterating,
59 | // from within the same goroutine.
60 | // The new elements will be included in the same loop.
61 | func (q *Queue[T]) Seq() iter.Seq[T] {
62 | return func(yield func(T) bool) {
63 | for q.Len() > 0 {
64 | if !yield(q.Dequeue()) {
65 | break
66 | }
67 | }
68 | }
69 | }
70 |
--------------------------------------------------------------------------------
/jio/jio.go:
--------------------------------------------------------------------------------
1 | // Package jio provides convenience functions for saving and loading
2 | // JSON-encoded values.
3 | //
4 | // Uses the [aio] package for I/O.
5 | package jio
6 |
7 | import (
8 | "encoding/json"
9 | "io"
10 |
11 | "github.com/fluhus/gostuff/aio"
12 | )
13 |
14 | // Write saves v to the given file, encoded as JSON.
15 | func Write(file string, v any) error {
16 | f, err := aio.Create(file)
17 | if err != nil {
18 | return err
19 | }
20 | e := json.NewEncoder(f)
21 | e.SetIndent("", " ")
22 | if err := e.Encode(v); err != nil {
23 | f.Close()
24 | return err
25 | }
26 | return f.Close()
27 | }
28 |
29 | // Read loads a JSON encoded value from the given file and populates v with it.
30 | func Read(file string, v any) error {
31 | f, err := aio.Open(file)
32 | if err != nil {
33 | return err
34 | }
35 | defer f.Close()
36 | return json.NewDecoder(f).Decode(v)
37 | }
38 |
39 | // ReadAs reads a JSON encoded value of type T and returns it.
40 | func ReadAs[T any](file string) (T, error) {
41 | var t T
42 | err := Read(file, &t)
43 | return t, err
44 | }
45 |
46 | // Iter returns an iterator over sequential JSON values in a file.
47 | //
48 | // Note: a file with several independent JSON values is not a valid JSON file.
49 | func Iter[T any](file string) func(yield func(T, error) bool) {
50 | return func(yield func(T, error) bool) {
51 | f, err := aio.Open(file)
52 | if err != nil {
53 | var t T
54 | yield(t, err)
55 | return
56 | }
57 | defer f.Close()
58 | j := json.NewDecoder(f)
59 | for {
60 | var t T
61 | err := j.Decode(&t)
62 | if err == io.EOF {
63 | return
64 | }
65 | if err != nil {
66 | yield(t, err)
67 | return
68 | }
69 | if !yield(t, nil) {
70 | return
71 | }
72 | }
73 | }
74 | }
75 |
--------------------------------------------------------------------------------
/morris/morris_test.go:
--------------------------------------------------------------------------------
1 | package morris
2 |
3 | import "testing"
4 |
5 | func TestRestore(t *testing.T) {
6 | tests := []struct {
7 | i byte
8 | m uint
9 | want uint
10 | }{
11 | {0, 1, 0},
12 | {1, 1, 1},
13 | {2, 1, 4},
14 | {3, 1, 9},
15 | {4, 1, 19},
16 | {0, 10, 0},
17 | {1, 10, 1},
18 | {2, 10, 2},
19 | {10, 10, 10},
20 | {15, 10, 20},
21 | {20, 10, 31},
22 | {25, 10, 51},
23 | {30, 10, 72},
24 | }
25 | for _, test := range tests {
26 | if got := Restore(test.i, test.m); got != test.want {
27 | t.Errorf("Restore(%d,%d)=%d, want %d",
28 | test.i, test.m, got, test.want)
29 | }
30 | }
31 | }
32 |
33 | func TestRaise_overflow(t *testing.T) {
34 | if !checkOverFlow {
35 | t.Skip()
36 | }
37 | defer func() { recover() }()
38 | got := Raise(byte(255), 10)
39 | t.Fatalf("Raise(byte(255)=%d, want fail", got)
40 | }
41 |
42 | func TestRaise(t *testing.T) {
43 | const reps = 1000
44 | want := map[byte]int{
45 | 1: 1, 2: 2, 3: 3, 4: 4, 5: 5, 6: 6, 7: 7, 8: 8, 9: 9, 10: 10,
46 | 15: 20, 20: 30, 25: 50, 30: 70, 35: 110, 40: 150,
47 | }
48 | margins := map[byte]int{
49 | 1: 0, 2: 0, 3: 0, 4: 0, 5: 0, 6: 0, 7: 0, 8: 0, 9: 0, 10: 0,
50 | 15: 2, 20: 2, 25: 2, 30: 2, 35: 2, 40: 2,
51 | }
52 | got := map[byte]int{}
53 |
54 | for i := 0; i < reps; i++ {
55 | a, n := byte(0), 0
56 | for a < 40 {
57 | n++
58 | aa := Raise(a, 10)
59 | if aa == a {
60 | continue
61 | }
62 | a = aa
63 | if _, ok := want[a]; !ok {
64 | continue
65 | }
66 | got[a] += n
67 | }
68 | }
69 |
70 | for k, n := range got {
71 | got := n / reps
72 | if got < want[k]-margins[k] || got > want[k]+margins[k] {
73 | t.Errorf("Raise() to %d took %d, want %d-%d", k, got,
74 | want[k]-margins[k], want[k]+margins[k])
75 | }
76 | }
77 | }
78 |
--------------------------------------------------------------------------------
/hashx/hashx.go:
--------------------------------------------------------------------------------
1 | // Package hashx provides simple hashing functions for various input types.
2 | package hashx
3 |
4 | import (
5 | "hash"
6 | "io"
7 |
8 | "github.com/spaolacci/murmur3"
9 | "golang.org/x/exp/constraints"
10 | )
11 |
12 | // Hashx calculates hash values for various input types.
13 | type Hashx struct {
14 | h hash.Hash64
15 | buf []byte
16 | }
17 |
18 | // NewSeed returns a new Hashx with the given seed.
19 | func NewSeed(seed uint32) *Hashx {
20 | return &Hashx{murmur3.New64WithSeed(seed), make([]byte, 8)}
21 | }
22 |
23 | // New returns a new Hashx.
24 | func New() *Hashx {
25 | return &Hashx{murmur3.New64(), make([]byte, 8)}
26 | }
27 |
28 | // Bytes returns the hash value of the given byte sequence.
29 | func (h *Hashx) Bytes(b []byte) uint64 {
30 | h.h.Reset()
31 | h.h.Write(b)
32 | return h.h.Sum64()
33 | }
34 |
35 | // String returns the hash value of the given string.
36 | func (h *Hashx) String(s string) uint64 {
37 | // TODO(amit): Optimize this?
38 | h.h.Reset()
39 | io.WriteString(h.h, s)
40 | return h.h.Sum64()
41 | }
42 |
43 | // IntHashx returns the hash value of the given integer,
44 | // using the given Hashx instance.
45 | func IntHashx[I constraints.Integer](h *Hashx, i I) uint64 {
46 | h.h.Reset()
47 | u := uint64(i)
48 | for j := range 8 {
49 | h.buf[j] = byte(u >> (j * 8))
50 | }
51 | h.h.Write(h.buf)
52 | return h.h.Sum64()
53 | }
54 |
55 | // The default hashx.
56 | var dflt = New()
57 |
58 | // Bytes returns the hash value of the given byte sequence.
59 | func Bytes(b []byte) uint64 { return dflt.Bytes(b) }
60 |
61 | // String returns the hash value of the given string.
62 | func String(s string) uint64 { return dflt.String(s) }
63 |
64 | // Int returns the hash value of the given integer.
65 | func Int[I constraints.Integer](i I) uint64 { return IntHashx(dflt, i) }
66 |
--------------------------------------------------------------------------------
/rhash/bench_test.go:
--------------------------------------------------------------------------------
1 | package rhash
2 |
3 | import (
4 | "crypto/rand"
5 | "fmt"
6 | "testing"
7 |
8 | "github.com/fluhus/gostuff/hashx"
9 | )
10 |
11 | func BenchmarkWrite1K(b *testing.B) {
12 | const n = 20
13 | buf := make([]byte, 1024)
14 | rand.Read(buf)
15 | b.Run("buz64", func(b *testing.B) {
16 | h := NewBuz(n)
17 | for b.Loop() {
18 | h.Write(buf)
19 | }
20 | })
21 | b.Run("rabin32", func(b *testing.B) {
22 | h := NewRabinFingerprint32(n)
23 | for b.Loop() {
24 | h.Write(buf)
25 | }
26 | })
27 | b.Run("rabin64", func(b *testing.B) {
28 | h := NewRabinFingerprint64(n)
29 | for b.Loop() {
30 | h.Write(buf)
31 | }
32 | })
33 | }
34 |
35 | func BenchmarkRolling(b *testing.B) {
36 | text := make([]byte, 10000)
37 | rand.Read(text)
38 | for _, ln := range []int{10, 30, 100} {
39 | b.Run(fmt.Sprint("buz", ln), func(b *testing.B) {
40 | h := NewBuz(ln)
41 | for b.Loop() {
42 | h.Write(text[:ln-1])
43 | var s uint64
44 | for _, b := range text[ln:] {
45 | h.WriteByte(b)
46 | s += h.Sum64()
47 | }
48 | if s == 0 {
49 | b.Error("placeholder to not optimize s out")
50 | }
51 | }
52 | })
53 | b.Run(fmt.Sprint("rabin", ln), func(b *testing.B) {
54 | h := NewRabinFingerprint64(ln)
55 | for b.Loop() {
56 | h.Write(text[:ln-1])
57 | var s uint64
58 | for _, b := range text[ln:] {
59 | h.WriteByte(b)
60 | s += h.Sum64()
61 | }
62 | if s == 0 {
63 | b.Error("placeholder to not optimize s out")
64 | }
65 | }
66 | })
67 | b.Run(fmt.Sprint("murmur3", ln), func(b *testing.B) {
68 | for b.Loop() {
69 | var s uint64
70 | for i := range text[ln:] {
71 | s += hashx.Bytes(text[i : i+ln])
72 | }
73 | if s == 0 {
74 | b.Error("placeholder to not optimize s out")
75 | }
76 | }
77 | })
78 | }
79 | }
80 |
--------------------------------------------------------------------------------
/repeat/repeat.go:
--------------------------------------------------------------------------------
1 | // Package repeat implements the repeat-reader. A repeat-reader outputs a given
2 | // constant byte sequence repeatedly.
3 | //
4 | // This package was originally written for testing and profiling parsers.
5 | // A typical use may look something like this:
6 | //
7 | // input := "some line to be parsed\n"
8 | // r := NewReader([]byte(input), 1000)
9 | // parser := myparse.NewParser(r)
10 | //
11 | // (start profiling)
12 | // for range parser.Items() {} // Exhaust parser.
13 | // (stop profiling)
14 | package repeat
15 |
16 | import "io"
17 |
// Reader outputs a given constant byte sequence repeatedly.
type Reader struct {
	data []byte // The sequence to repeat.
	i    int    // Offset in data of the next byte to output.
	n    int    // Repetitions left to output; negative means infinite.
}

// NewReader returns a reader that outputs data n times. If n is negative, repeats
// infinitely. Copies the contents of data.
//
// If data is empty, the reader has nothing to output and reports EOF
// right away, regardless of n.
func NewReader(data []byte, n int) *Reader {
	cp := append(make([]byte, 0, len(data)), data...)
	return &Reader{data: cp, n: n}
}

// Read fills p with repetitions of the reader's data. Writes until p is full or
// until the last repetition was written. Subsequent calls to Read resume from where
// the last repetition stopped. When no more bytes are available, returns 0, EOF.
// Otherwise the error is nil.
func (r *Reader) Read(p []byte) (int, error) {
	if len(p) == 0 {
		return 0, nil
	}
	// Empty data can never make progress on p. Without this check the loop
	// below would spin forever on an infinite reader, and return 0, nil
	// on a finite one.
	if len(r.data) == 0 || (r.i == 0 && r.n == 0) {
		return 0, io.EOF
	}
	m := 0
	for {
		n := copy(p, r.data[r.i:])
		r.i += n
		if r.i == len(r.data) {
			// Finished a repetition; start over.
			r.i = 0
			if r.n > 0 {
				r.n--
			}
		}
		p = p[n:]
		m += n
		if len(p) == 0 || (r.i == 0 && r.n == 0) {
			break
		}
	}
	return m, nil
}

// Close is a no-op. Implements [io.ReadCloser].
func (r *Reader) Close() error {
	return nil
}
66 |
--------------------------------------------------------------------------------
/sets/sets_test.go:
--------------------------------------------------------------------------------
1 | package sets
2 |
3 | import (
4 | "encoding/json"
5 | "reflect"
6 | "testing"
7 |
8 | "golang.org/x/exp/maps"
9 | )
10 |
11 | func TestJSON(t *testing.T) {
12 | input := Set[int]{}.Add(1, 3, 6)
13 | want := Set[int]{}.Add(1, 3, 6)
14 | j, err := input.MarshalJSON()
15 | if err != nil {
16 | t.Fatalf("%v.MarshalJSON() failed: %v", input, err)
17 | }
18 |
19 | got := Set[int]{}
20 | err = got.UnmarshalJSON(j)
21 | if err != nil {
22 | t.Fatalf("%v.UnmarshalJSON(%q) failed: %v", input, j, err)
23 | }
24 |
25 | if !maps.Equal(got, want) {
26 | t.Fatalf("UnmarshalJSON(%q)=%v, want %v", j, got, want)
27 | }
28 | }
29 |
30 | func TestJSON_slice(t *testing.T) {
31 | input := []Set[int]{
32 | Set[int]{}.Add(1, 3, 6),
33 | Set[int]{}.Add(7, 9, 6),
34 | }
35 | want := []Set[int]{
36 | Set[int]{}.Add(1, 3, 6),
37 | Set[int]{}.Add(7, 9, 6),
38 | }
39 | j, err := json.Marshal(input)
40 | if err != nil {
41 | t.Fatalf("Marshal(%v) failed: %v", input, err)
42 | }
43 |
44 | got := []Set[int]{}
45 | err = json.Unmarshal(j, &got)
46 | if err != nil {
47 | t.Fatalf("Unmarshal(%q) failed: %v", j, err)
48 | }
49 |
50 | if !reflect.DeepEqual(input, got) {
51 | t.Fatalf("Unmarshal(%q)=%v, want %v", j, got, want)
52 | }
53 | }
54 |
55 | func TestJSON_map(t *testing.T) {
56 | input := map[string]Set[int]{
57 | "a": Set[int]{}.Add(1, 3, 6),
58 | "x": Set[int]{}.Add(7, 9, 6),
59 | }
60 | want := map[string]Set[int]{
61 | "a": Set[int]{}.Add(1, 3, 6),
62 | "x": Set[int]{}.Add(7, 9, 6),
63 | }
64 | j, err := json.Marshal(input)
65 | if err != nil {
66 | t.Fatalf("Marshal(%v) failed: %v", input, err)
67 | }
68 |
69 | got := map[string]Set[int]{}
70 | err = json.Unmarshal(j, &got)
71 | if err != nil {
72 | t.Fatalf("Unmarshal(%q) failed: %v", j, err)
73 | }
74 |
75 | if !reflect.DeepEqual(input, got) {
76 | t.Fatalf("Unmarshal(%q)=%v, want %v", j, got, want)
77 | }
78 | }
79 |
--------------------------------------------------------------------------------
/iterx/slice.go:
--------------------------------------------------------------------------------
1 | package iterx
2 |
3 | import (
4 | "iter"
5 | "slices"
6 | )
7 |
8 | // Slice returns an iterator over the slice values.
9 | //
10 | // Deprecated: use [slices.Values] instead.
11 | func Slice[T any](s []T) iter.Seq[T] {
12 | return slices.Values(s)
13 | }
14 |
15 | // ISlice returns an iterator over the slice values and their indices,
16 | // like in a range expression.
17 | //
18 | // Deprecated: use [slices.All] instead.
19 | func ISlice[T any](s []T) iter.Seq2[int, T] {
20 | return slices.All(s)
21 | }
22 |
23 | // Limit returns an iterator that stops after n elements,
24 | // if the underlying iterator does not stop before.
25 | func Limit[T any](it iter.Seq[T], n int) iter.Seq[T] {
26 | return func(yield func(T) bool) {
27 | i := 0
28 | for x := range it {
29 | i++
30 | if i > n {
31 | return
32 | }
33 | if !yield(x) {
34 | return
35 | }
36 | }
37 | }
38 | }
39 |
40 | // Limit2 returns an iterator that stops after n elements,
41 | // if the underlying iterator does not stop before.
42 | func Limit2[T any, S any](it iter.Seq2[T, S], n int) iter.Seq2[T, S] {
43 | return func(yield func(T, S) bool) {
44 | i := 0
45 | for x, y := range it {
46 | i++
47 | if i > n {
48 | return
49 | }
50 | if !yield(x, y) {
51 | return
52 | }
53 | }
54 | }
55 | }
56 |
57 | // Skip returns an iterator without the first n elements.
58 | func Skip[T any](it iter.Seq[T], n int) iter.Seq[T] {
59 | return func(yield func(T) bool) {
60 | i := 0
61 | for x := range it {
62 | i++
63 | if i <= n {
64 | continue
65 | }
66 | if !yield(x) {
67 | return
68 | }
69 | }
70 | }
71 | }
72 |
73 | // Skip2 returns an iterator without the first n elements.
74 | func Skip2[T any, S any](it iter.Seq2[T, S], n int) iter.Seq2[T, S] {
75 | return func(yield func(T, S) bool) {
76 | i := 0
77 | for x, y := range it {
78 | i++
79 | if i <= n {
80 | continue
81 | }
82 | if !yield(x, y) {
83 | return
84 | }
85 | }
86 | }
87 | }
88 |
--------------------------------------------------------------------------------
/sets/sorted.go:
--------------------------------------------------------------------------------
1 | package sets
2 |
3 | import (
4 | "cmp"
5 | )
6 |
// SortedIntersection returns the intersection of
// two sorted slices a and b.
//
// The slices are merged with one pass over each. An element is added to
// the result once for each pair of equal elements matched between a and b.
// The result is nil if nothing matched.
func SortedIntersection[T cmp.Ordered](a, b []T) []T {
	var common []T
	for ia, ib := 0, 0; ia < len(a) && ib < len(b); {
		if c := cmp.Compare(a[ia], b[ib]); c < 0 {
			ia++
		} else if c > 0 {
			ib++
		} else {
			common = append(common, a[ia])
			ia++
			ib++
		}
	}
	return common
}
26 |
// SortedUnion returns the union of
// two sorted slices a and b.
//
// The slices are merged with one pass over each. A pair of equal elements
// matched between a and b is added to the result once; every other
// element is added as is. The result is nil if both slices are empty.
func SortedUnion[T cmp.Ordered](a, b []T) []T {
	var merged []T
	ia, ib := 0, 0
	for ia < len(a) && ib < len(b) {
		c := cmp.Compare(a[ia], b[ib])
		if c > 0 {
			merged = append(merged, b[ib])
			ib++
			continue
		}
		// a's element is smaller than or equal to b's; take it.
		merged = append(merged, a[ia])
		ia++
		if c == 0 {
			// Equal elements are consumed together.
			ib++
		}
	}
	// At most one of the slices has elements left.
	merged = append(merged, a[ia:]...)
	merged = append(merged, b[ib:]...)
	return merged
}
51 |
// SortedIntersectionLen returns the length of the intersection of
// two sorted slices a and b.
//
// Counts the pairs of equal elements matched between a and b,
// without allocating the intersection itself.
func SortedIntersectionLen[T cmp.Ordered](a, b []T) int {
	matches := 0
	for ia, ib := 0, 0; ia < len(a) && ib < len(b); {
		c := cmp.Compare(a[ia], b[ib])
		if c <= 0 {
			ia++
		}
		if c >= 0 {
			ib++
		}
		if c == 0 {
			matches++
		}
	}
	return matches
}
71 |
// SortedUnionLen returns the length of the union of
// two sorted slices a and b.
//
// Every element counts once, except that a pair of equal elements matched
// between a and b counts once in total. The union itself is not allocated.
func SortedUnionLen[T cmp.Ordered](a, b []T) int {
	matches := 0
	for ia, ib := 0, 0; ia < len(a) && ib < len(b); {
		c := cmp.Compare(a[ia], b[ib])
		if c <= 0 {
			ia++
		}
		if c >= 0 {
			ib++
		}
		if c == 0 {
			matches++
		}
	}
	return len(a) + len(b) - matches
}
94 |
--------------------------------------------------------------------------------
/morris/morris.go:
--------------------------------------------------------------------------------
1 | // Package morris provides an implementation of Morris's algorithm
2 | // for approximate counting with few bits.
3 | //
4 | // The original formula raises a counter i with probability 2^(-i).
5 | // The restored value is 2^i - 1.
6 | //
7 | // This package introduces a parameter m, so that the first m increments are
8 | // made with probability 1, then m increments with probability 1/2, then m with
9 | // probability 1/4... Using m=1 is equivalent to the original formula.
10 | // A large m increases accuracy but costs more bits.
11 | // A single counter should use the same m for all calls to Raise and Restore.
12 | //
13 | // This package is experimental.
14 | //
15 | // # Error rates
16 | //
17 | // Average error rates for different values of m:
18 | //
19 | // 1: 54.2%
20 | // 3: 31.1%
21 | // 10: 15.5%
22 | // 30: 8.8%
23 | // 100: 4.6%
24 | // 300: 2.5%
25 | // 1000: 1.6%
26 | // 3000: 0.9%
27 | // 10000: 0.5%
28 | package morris
29 |
30 | import (
31 | "fmt"
32 | "math/rand"
33 |
34 | "golang.org/x/exp/constraints"
35 | )
36 |
37 | // If true, panics when a counter is about to be raised beyond its maximal
38 | // value.
39 | const checkOverFlow = true
40 |
41 | // Raise returns the new value of i after one increment.
42 | // m controls the restoration accuracy.
43 | // The approximate number of calls to Raise can be restored using Restore.
44 | func Raise[T constraints.Unsigned](i T, m uint) T {
45 | if checkOverFlow {
46 | max := T(0) - 1
47 | if i == max {
48 | panic(fmt.Sprintf("counter reached maximal value: %d", i))
49 | }
50 | }
51 | r := 1 << (uint(i) / m)
52 | if rand.Intn(r) == 0 {
53 | return i + 1
54 | }
55 | return i
56 | }
57 |
58 | // Restore returns an approximation of the number of calls to Raise on i.
59 | // m should have the same value that was used with Raise.
60 | func Restore[T constraints.Unsigned](i T, m uint) uint {
61 | ui := uint(i)
62 | if ui <= m {
63 | return ui
64 | }
65 | return m*(1<<(ui/m)-1) + (ui%m)*(1<<(ui/m)) + (1 << (uint(i)/m - 2))
66 | }
67 |
--------------------------------------------------------------------------------
/iterx/slice_test.go:
--------------------------------------------------------------------------------
1 | package iterx
2 |
3 | import (
4 | "slices"
5 | "testing"
6 | )
7 |
8 | func TestSlice(t *testing.T) {
9 | input := []string{"hello", "world", "hi"}
10 | want := slices.Clone(input)
11 | var got []string
12 | for x := range Slice(input) {
13 | got = append(got, x)
14 | }
15 | if !slices.Equal(input, want) {
16 | t.Fatalf("Slice(%v) changed input to %v", want, input)
17 | }
18 | if !slices.Equal(got, want) {
19 | t.Fatalf("Slice(%v)=%v, want %v", input, got, want)
20 | }
21 | }
22 |
23 | func TestISlice(t *testing.T) {
24 | input := []string{"hello", "world", "hi"}
25 | want := slices.Clone(input)
26 | var got []string
27 | for i, x := range ISlice(input) {
28 | if i != len(got) {
29 | t.Fatalf("ISlice(%v) i=%v, want %v", input, i, len(got))
30 | }
31 | got = append(got, x)
32 | }
33 | if !slices.Equal(input, want) {
34 | t.Fatalf("ISlice(%v) changed input to %v", want, input)
35 | }
36 | if !slices.Equal(got, want) {
37 | t.Fatalf("ISlice(%v)=%v, want %v", input, got, want)
38 | }
39 | }
40 |
41 | func TestLimit(t *testing.T) {
42 | input := []string{"bla", "blu", "bli", "ble"}
43 | tests := []struct {
44 | n int
45 | want []string
46 | }{
47 | {-1, nil}, {0, nil}, {1, input[:1]}, {2, input[:2]},
48 | {3, input[:3]}, {4, input}, {5, input}, {6, input},
49 | }
50 | for _, test := range tests {
51 | var got []string
52 | for x := range Limit(Slice(input), test.n) {
53 | got = append(got, x)
54 | }
55 | if !slices.Equal(got, test.want) {
56 | t.Errorf("Limit(%v,%v)=%v, want %v", input, test.n, got, test.want)
57 | }
58 | }
59 | }
60 |
61 | func TestSkip(t *testing.T) {
62 | input := []string{"bla", "blu", "bli", "ble"}
63 | tests := []struct {
64 | n int
65 | want []string
66 | }{
67 | {-1, input}, {0, input}, {1, input[1:]}, {2, input[2:]},
68 | {3, input[3:]}, {4, nil}, {5, nil}, {6, nil},
69 | }
70 | for _, test := range tests {
71 | var got []string
72 | for x := range Skip(Slice(input), test.n) {
73 | got = append(got, x)
74 | }
75 | if !slices.Equal(got, test.want) {
76 | t.Errorf("Skip(%v,%v)=%v, want %v", input, test.n, got, test.want)
77 | }
78 | }
79 | }
80 |
--------------------------------------------------------------------------------
/clustering/upgma_test.go:
--------------------------------------------------------------------------------
1 | package clustering
2 |
3 | import (
4 | "fmt"
5 | "math"
6 | "reflect"
7 | "testing"
8 | )
9 |
10 | func TestUPGMA(t *testing.T) {
11 | points := []float64{1, 4, 6, 10}
12 | steps := []AggloStep{
13 | {1, 2, 2},
14 | {0, 2, 4},
15 | {2, 3, 19.0 / 3.0},
16 | }
17 | agg := upgma(len(points), func(i, j int) float64 {
18 | return math.Abs(points[i] - points[j])
19 | })
20 | if agg.Len() != len(points)-1 {
21 | t.Fatalf("Len()=%v, want %v", agg.Len(), len(points)-1)
22 | }
23 | for i := range steps {
24 | if step := agg.Step(i); !reflect.DeepEqual(steps[i], step) {
25 | t.Errorf("Step(%v)=%v, want %v", i, step, steps[i])
26 | }
27 | }
28 | }
29 |
30 | func TestUPGMA_more(t *testing.T) {
31 | points := []float64{1, 3, 8, 12, 20, 28}
32 | steps := []AggloStep{
33 | {0, 1, 2},
34 | {2, 3, 4},
35 | {1, 3, 8},
36 | {4, 5, 8},
37 | {3, 5, 18},
38 | }
39 | agg := upgma(len(points), func(i, j int) float64 {
40 | return math.Abs(points[i] - points[j])
41 | })
42 | if agg.Len() != len(points)-1 {
43 | t.Fatalf("Len()=%v, want %v", agg.Len(), len(points)-1)
44 | }
45 | for i := range steps {
46 | if step := agg.Step(i); !reflect.DeepEqual(steps[i], step) {
47 | t.Errorf("Step(%v)=%v, want %v", i, step, steps[i])
48 | }
49 | }
50 | }
51 |
52 | func BenchmarkUPGMA(b *testing.B) {
53 | for _, n := range []int{10, 30, 100} {
54 | b.Run(fmt.Sprint(n), func(b *testing.B) {
55 | nums := make([]float64, n)
56 | for i := range nums {
57 | nums[i] = 1.0 / float64(i+1)
58 | if i%2 == 0 {
59 | nums[i] += 10
60 | }
61 | }
62 | b.ResetTimer()
63 | for i := 0; i < b.N; i++ {
64 | upgma(n, func(i1, i2 int) float64 {
65 | return math.Abs(nums[i1] - nums[i2])
66 | })
67 | }
68 | })
69 | }
70 | }
71 |
72 | func FuzzUPGMA(f *testing.F) {
73 | f.Add(1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0)
74 | f.Fuzz(func(t *testing.T, a float64, b float64, c float64,
75 | d float64, e float64, f float64, g float64, h float64, i float64) {
76 | nums := []float64{a, b, c, d, e, f, g, h, i}
77 | upgma(len(nums), func(i1, i2 int) float64 {
78 | return math.Abs(nums[i1] - nums[i2])
79 | })
80 | })
81 | }
82 |
--------------------------------------------------------------------------------
/graphs/graph_test.go:
--------------------------------------------------------------------------------
1 | package graphs
2 |
3 | import (
4 | "reflect"
5 | "slices"
6 | "testing"
7 |
8 | "github.com/fluhus/gostuff/snm"
9 | )
10 |
11 | func TestComponents(t *testing.T) {
12 | edges := [][2]int{
13 | {0, 1}, {1, 2}, {5, 7}, {6, 9}, {9, 10}, {8, 10}, {7, 8},
14 | }
15 | want := [][]int{
16 | {0, 1, 2}, {3}, {4}, {5, 6, 7, 8, 9, 10}, {11},
17 | }
18 | g := New[int]()
19 | for i := range 12 {
20 | g.AddVertices(i)
21 | }
22 | for _, e := range edges {
23 | g.AddEdge(e[0], e[1])
24 | }
25 | got := g.ConnectedComponents()
26 | if !reflect.DeepEqual(got, want) {
27 | t.Fatalf("components(...)=%v, want %v", got, want)
28 | }
29 | }
30 |
31 | func TestComponents_string(t *testing.T) {
32 | edges := [][2]string{
33 | {"a", "bb"}, {"eeeee", "dddd"}, {"bb", "ccc"}, {"dddd", "eeeee"},
34 | }
35 | want := [][]string{
36 | {"ffffff"}, {"a", "bb", "ccc"}, {"eeeee", "dddd"},
37 | }
38 | g := New[string]()
39 | g.AddVertices("ffffff")
40 | for _, e := range edges {
41 | g.AddEdge(e[0], e[1])
42 | }
43 | got := g.ConnectedComponents()
44 | if !reflect.DeepEqual(got, want) {
45 | t.Fatalf("components(...)=%v, want %v", got, want)
46 | }
47 | }
48 |
49 | func TestVerticesEdges(t *testing.T) {
50 | vertices := []string{"ffffff", "bb"}
51 | edges := [][2]string{
52 | {"a", "bb"}, {"eeeee", "dddd"}, {"bb", "ccc"}, {"dddd", "eeeee"},
53 | }
54 |
55 | wantVertices := []string{"a", "bb", "ccc", "dddd", "eeeee", "ffffff"}
56 | wantEdges := [][2]string{
57 | {"bb", "a"}, {"bb", "ccc"}, {"eeeee", "dddd"},
58 | }
59 |
60 | g := New[string]()
61 | g.AddVertices(vertices...)
62 | for _, e := range edges {
63 | g.AddEdge(e[0], e[1])
64 | }
65 |
66 | gotVertices := snm.Sorted(slices.Collect(g.Vertices()))
67 | if !slices.Equal(gotVertices, wantVertices) {
68 | t.Errorf("Vertices()=%q, want %q", gotVertices, wantVertices)
69 | }
70 |
71 | var gotEdges [][2]string
72 | for a, b := range g.Edges() {
73 | gotEdges = append(gotEdges, [2]string{a, b})
74 | }
75 | slices.SortFunc(gotEdges, func(a, b [2]string) int {
76 | return slices.Compare(a[:], b[:])
77 | })
78 | if !slices.Equal(gotEdges, wantEdges) {
79 | t.Errorf("Edges()=%q, want %q", gotEdges, wantEdges)
80 | }
81 | }
82 |
--------------------------------------------------------------------------------
/bits/bits.go:
--------------------------------------------------------------------------------
1 | // Package bits provides operations on bit arrays.
2 | package bits
3 |
4 | import (
5 | "iter"
6 | mbits "math/bits"
7 |
8 | "golang.org/x/exp/constraints"
9 | )
10 |
11 | // Set sets the n'th bit to 1 or 0 for values of true or false
12 | // respectively.
13 | func Set[I constraints.Integer](data []byte, n I, value bool) {
14 | if value {
15 | Set1(data, n)
16 | } else {
17 | Set0(data, n)
18 | }
19 | }
20 |
21 | // Set1 sets the n'th bit in data to 1.
22 | func Set1[I constraints.Integer](data []byte, n I) {
23 | data[n/8] |= 1 << (n % 8)
24 | }
25 |
26 | // Set0 sets the n'th bit in data to 0.
27 | func Set0[I constraints.Integer](data []byte, n I) {
28 | data[n/8] &= ^(1 << (n % 8))
29 | }
30 |
31 | // Get returns the value of the n'th bit (0 or 1).
32 | func Get[I constraints.Integer](data []byte, n I) int {
33 | return int((data[n/8] >> (n % 8)) & 1)
34 | }
35 |
// Sum counts the bits in data that are set to 1.
func Sum(data []byte) int {
	total := 0
	for i := range data {
		total += mbits.OnesCount8(data[i])
	}
	return total
}
44 |
45 | // Ones iterates over the indexes of bits whose values are 1.
46 | func Ones(data []byte) iter.Seq[int] {
47 | return func(yield func(int) bool) {
48 | for i, x := range data {
49 | for _, b := range byteOnes[x] {
50 | if !yield(i*8 + b) {
51 | return
52 | }
53 | }
54 | }
55 | }
56 | }
57 |
58 | // Zeros iterates over the indexes of bits whose values are 0.
59 | func Zeros(data []byte) iter.Seq[int] {
60 | return func(yield func(int) bool) {
61 | for i, x := range data {
62 | for _, b := range byteZeros[x] {
63 | if !yield(i*8 + b) {
64 | return
65 | }
66 | }
67 | }
68 | }
69 | }
70 |
// Lookup tables indexed by byte value, filled in by init.
var (
	// byteOnes[v] holds the positions (0-7, ascending) of the 1 bits of v.
	byteOnes [][]int

	// byteZeros[v] holds the positions (0-7, ascending) of the 0 bits of v.
	byteZeros [][]int
)

// init fills byteOnes and byteZeros for all 256 byte values.
func init() {
	const nvalues = 256
	byteOnes = make([][]int, nvalues)
	byteZeros = make([][]int, nvalues)
	for v := 0; v < nvalues; v++ {
		for bit := 0; bit < 8; bit++ {
			if v&(1<<bit) != 0 {
				byteOnes[v] = append(byteOnes[v], bit)
			} else {
				byteZeros[v] = append(byteZeros[v], bit)
			}
		}
	}
}
94 |
--------------------------------------------------------------------------------
/rhash/rabin32.go:
--------------------------------------------------------------------------------
1 | package rhash
2 |
3 | import (
4 | "fmt"
5 | "hash"
6 | )
7 |
8 | var _ hash.Hash32 = &RabinFingerprint32{}
9 |
10 | // RabinFingerprint32 implements a Rabin fingerprint rolling-hash.
11 | // Implements [hash.Hash32].
12 | type RabinFingerprint32 struct {
13 | h, pow uint32
14 | i int
15 | hist []byte
16 | }
17 |
18 | // NewRabinFingerprint32 returns a new rolling hash with a window size of n.
19 | func NewRabinFingerprint32(n int) *RabinFingerprint32 {
20 | if n < 1 {
21 | panic(fmt.Sprintf("bad n: %d", n))
22 | }
23 | return &RabinFingerprint32{0, 1, 0, make([]byte, n)}
24 | }
25 |
26 | // Write updates the hash with the given bytes. Always returns len(data), nil.
27 | func (h *RabinFingerprint32) Write(data []byte) (int, error) {
28 | for _, b := range data {
29 | h.WriteByte(b)
30 | }
31 | return len(data), nil
32 | }
33 |
34 | // WriteByte updates the hash with the given byte. Always returns nil.
35 | func (h *RabinFingerprint32) WriteByte(b byte) error {
36 | h.h = h.h*rabinPrime + uint32(b)
37 | i := h.i % len(h.hist)
38 | h.h -= h.pow * uint32(h.hist[i])
39 | h.hist[i] = b
40 | if h.i < len(h.hist) {
41 | h.pow *= rabinPrime
42 | }
43 | h.i++
44 | return nil
45 | }
46 |
47 | // Sum appends the current hash to b and returns the resulting slice.
48 | func (h *RabinFingerprint32) Sum(b []byte) []byte {
49 | s := h.Sum32()
50 | n := h.Size()
51 | for i := 0; i < n; i++ {
52 | b = append(b, byte(s))
53 | s >>= 8
54 | }
55 | return b
56 | }
57 |
58 | // Sum32 returns the current hash.
59 | func (h *RabinFingerprint32) Sum32() uint32 {
60 | return h.h
61 | }
62 |
63 | // Size returns the number of bytes Sum will return, which is four.
64 | func (h *RabinFingerprint32) Size() int {
65 | return 4
66 | }
67 |
68 | // BlockSize returns the hash's block size, which is one.
69 | func (h *RabinFingerprint32) BlockSize() int {
70 | return 1
71 | }
72 |
73 | // Reset resets the hash to its initial state.
74 | func (h *RabinFingerprint32) Reset() {
75 | h.h = 0
76 | h.i = 0
77 | h.pow = 1
78 | for i := range h.hist {
79 | h.hist[i] = 0
80 | }
81 | }
82 |
83 | // RabinFingerprintSum32 returns the Rabin fingerprint of data.
84 | func RabinFingerprintSum32(data []byte) uint32 {
85 | if len(data) == 0 {
86 | return 0
87 | }
88 | h := NewRabinFingerprint32(len(data))
89 | h.Write(data)
90 | return h.Sum32()
91 | }
92 |
--------------------------------------------------------------------------------
/clustering/rand.go:
--------------------------------------------------------------------------------
1 | package clustering
2 |
3 | import (
4 | "fmt"
5 | "slices"
6 |
7 | "github.com/fluhus/gostuff/sets"
8 | )
9 |
10 | // AdjustedRandIndex compares 2 taggings of the data for similarity. A score of
11 | // 1 means identical, a score of 0 means as good as random, and a negative
12 | // score means worse than random.
13 | func AdjustedRandIndex(tags1, tags2 []int) float64 {
14 | // Check input.
15 | if len(tags1) != len(tags2) {
16 | panic(fmt.Sprintf("Mismatching lengths: %d, %d",
17 | len(tags1), len(tags2)))
18 | }
19 |
20 | sets1 := tagsToSets(tags1)
21 | sets2 := tagsToSets(tags2)
22 |
23 | r := randIndex(sets1, sets2)
24 | e := expectedRandIndex(sets1, sets2)
25 | m := maxRandIndex(sets1, sets2)
26 | return (r - e) / (m - e)
27 | }
28 |
29 | // randIndex returns the RI part of the adjusted index.
30 | func randIndex(tags1, tags2 [][]int) float64 {
31 | r := 0
32 | for _, t1 := range tags1 {
33 | for _, t2 := range tags2 {
34 | r += choose2(sets.SortedIntersectionLen(t1, t2))
35 | }
36 | }
37 | return float64(r)
38 | }
39 |
40 | // expectedRandIndex returns the expected index according to hypergeometrical
41 | // distribution.
42 | func expectedRandIndex(tags1, tags2 [][]int) float64 {
43 | p1 := 0
44 | n := 0
45 | for _, tags := range tags1 {
46 | n += len(tags)
47 | p1 += choose2(len(tags))
48 | }
49 | p2 := 0
50 | for _, tags := range tags2 {
51 | p2 += choose2(len(tags))
52 | }
53 | p := float64(choose2(n))
54 | return float64(p1) * float64(p2) / p
55 | }
56 |
57 | // maxRandIndex returns the maximal possible index.
58 | func maxRandIndex(tags1, tags2 [][]int) float64 {
59 | p := 0
60 | for _, tags := range tags1 {
61 | p += choose2(len(tags))
62 | }
63 | for _, tags := range tags2 {
64 | p += choose2(len(tags))
65 | }
66 | return float64(p) / 2
67 | }
68 |
// choose2 returns the number of unordered pairs among n items.
func choose2(n int) int {
	ordered := n * (n - 1)
	return ordered / 2
}
72 |
// tagsToSets groups element indexes by tag. It returns one sorted slice of
// indexes per distinct tag; the order of the slices is unspecified because
// it follows map iteration.
func tagsToSets(tags []int) [][]int {
	// Collect the indexes that carry each tag.
	byTag := make(map[int][]int)
	for idx := range tags {
		tag := tags[idx]
		byTag[tag] = append(byTag[tag], idx)
	}

	// Flatten the map into a slice of index sets.
	groups := make([][]int, 0, len(byTag))
	for tag := range byTag {
		members := byTag[tag]
		slices.Sort(members)
		groups = append(groups, members)
	}
	return groups
}
91 |
--------------------------------------------------------------------------------
/rhash/rabin64.go:
--------------------------------------------------------------------------------
1 | package rhash
2 |
3 | import (
4 | "fmt"
5 | "hash"
6 | )
7 |
// rabinPrime is the multiplier used by the Rabin fingerprint hashes.
const rabinPrime = 16777619

// Compile-time check that *RabinFingerprint64 satisfies hash.Hash64.
var _ hash.Hash64 = (*RabinFingerprint64)(nil)

// RabinFingerprint64 implements a Rabin fingerprint rolling-hash.
// Implements [hash.Hash64].
type RabinFingerprint64 struct {
	h    uint64 // Current hash value.
	pow  uint64 // rabinPrime raised to the number of bytes written, capped at the window size.
	i    int    // Number of bytes written since creation or Reset.
	hist []byte // Circular buffer holding the bytes of the current window.
}

// NewRabinFingerprint64 returns a new rolling hash with a window size of n.
// Panics if n is less than 1.
func NewRabinFingerprint64(n int) *RabinFingerprint64 {
	if n < 1 {
		panic(fmt.Sprintf("bad n: %d", n))
	}
	return &RabinFingerprint64{pow: 1, hist: make([]byte, n)}
}

// Write updates the hash with the given bytes. Always returns len(data), nil.
func (h *RabinFingerprint64) Write(data []byte) (int, error) {
	for j := 0; j < len(data); j++ {
		h.WriteByte(data[j])
	}
	return len(data), nil
}

// WriteByte updates the hash with the given byte. Always returns nil.
func (h *RabinFingerprint64) WriteByte(b byte) error {
	// The slot about to be overwritten holds the byte leaving the window,
	// or zero while the window is still filling up.
	slot := h.i % len(h.hist)
	leaving := uint64(h.hist[slot])
	h.hist[slot] = b

	h.h = h.h*rabinPrime + uint64(b) - h.pow*leaving

	// pow stops growing once the window is full.
	if h.i < len(h.hist) {
		h.pow *= rabinPrime
	}
	h.i++
	return nil
}

// Sum appends the current hash to b, least significant byte first, and
// returns the resulting slice.
func (h *RabinFingerprint64) Sum(b []byte) []byte {
	s := h.h
	return append(b,
		byte(s), byte(s>>8), byte(s>>16), byte(s>>24),
		byte(s>>32), byte(s>>40), byte(s>>48), byte(s>>56))
}

// Sum64 returns the current hash.
func (h *RabinFingerprint64) Sum64() uint64 {
	return h.h
}

// Size returns the number of bytes Sum will return, which is eight.
func (h *RabinFingerprint64) Size() int {
	return 8
}

// BlockSize returns the hash's block size, which is one.
func (h *RabinFingerprint64) BlockSize() int {
	return 1
}

// Reset resets the hash to its initial state, keeping the window size.
func (h *RabinFingerprint64) Reset() {
	h.h, h.pow, h.i = 0, 1, 0
	for j := 0; j < len(h.hist); j++ {
		h.hist[j] = 0
	}
}

// RabinFingerprintSum64 returns the Rabin fingerprint of data, using a
// window that spans all of data. Returns 0 for empty data.
func RabinFingerprintSum64(data []byte) uint64 {
	n := len(data)
	if n == 0 {
		return 0
	}
	fp := NewRabinFingerprint64(n)
	fp.Write(data)
	return fp.Sum64()
}
94 |
--------------------------------------------------------------------------------
/hll/hll_test.go:
--------------------------------------------------------------------------------
1 | package hll
2 |
3 | import (
4 | "math"
5 | "testing"
6 |
7 | "github.com/fluhus/gostuff/bnry"
8 | "github.com/spaolacci/murmur3"
9 | )
10 |
11 | func TestCount_short(t *testing.T) {
12 | upto := 10000000
13 | if testing.Short() {
14 | upto = 1000
15 | }
16 |
17 | hll := newIntHLL()
18 | next := 1
19 | ratioSum := 0.0
20 | ratioCount := 0.0
21 | for i := 1; i <= upto; i++ {
22 | hll.Add(i)
23 | if i != next { // Check only a sample.
24 | continue
25 | }
26 | next = (next + 1) * 21 / 20
27 | ratio := float64(i) / float64(hll.ApproxCount())
28 | ratioSum += math.Abs(math.Log(ratio))
29 | ratioCount++
30 | }
31 | avg := math.Exp(ratioSum / ratioCount)
32 | want := 1.003
33 | if avg > want {
34 | t.Errorf("average error=%f, want at most %f",
35 | avg, want)
36 | }
37 | }
38 |
39 | func TestCount_zero(t *testing.T) {
40 | hll := newIntHLL()
41 | if count := hll.ApproxCount(); count != 0 {
42 | t.Fatalf("ApproxCount()=%v, want 0", count)
43 | }
44 | }
45 |
46 | func TestAddHLL(t *testing.T) {
47 | hll1 := newIntHLL()
48 | for i := 1; i <= 5; i++ {
49 | hll1.Add(i)
50 | }
51 | if count := hll1.ApproxCount(); count != 5 {
52 | t.Fatalf("ApproxCount()=%v, want 5", count)
53 | }
54 |
55 | hll2 := newIntHLL()
56 | for i := 4; i <= 9; i++ {
57 | hll2.Add(i)
58 | }
59 | if count := hll2.ApproxCount(); count != 6 {
60 | t.Fatalf("ApproxCount()=%v, want 6", count)
61 | }
62 |
63 | hll1.AddHLL(hll2)
64 | if count := hll1.ApproxCount(); count != 9 {
65 | t.Fatalf("ApproxCount()=%v, want 9", count)
66 | }
67 | }
68 |
69 | func BenchmarkAdd(b *testing.B) {
70 | hll := New(16, func(i int) uint64 { return uint64(i) })
71 | for i := 0; i < b.N; i++ {
72 | hll.Add(i)
73 | }
74 | }
75 |
76 | func BenchmarkAdd_intHLL(b *testing.B) {
77 | hll := newIntHLL()
78 | for i := 0; i < b.N; i++ {
79 | hll.Add(i)
80 | }
81 | }
82 |
83 | func BenchmarkCount(b *testing.B) {
84 | const nelements = 1000000
85 | hll := newIntHLL()
86 | for i := 0; i < nelements; i++ {
87 | hll.Add(i)
88 | }
89 | b.Run("", func(b *testing.B) {
90 | for i := 0; i < b.N; i++ {
91 | hll.ApproxCount()
92 | }
93 | })
94 | }
95 |
96 | func newIntHLL() *HLL[int] {
97 | h := murmur3.New64()
98 | w := bnry.NewWriter(h)
99 | return New(16, func(i int) uint64 {
100 | h.Reset()
101 | w.Write(i)
102 | return h.Sum64()
103 | })
104 | }
105 |
--------------------------------------------------------------------------------
/ppln/nserial.go:
--------------------------------------------------------------------------------
1 | package ppln
2 |
3 | import (
4 | "fmt"
5 | "iter"
6 | "sync"
7 | "sync/atomic"
8 | )
9 |
10 | // NonSerial starts a multi-goroutine transformation pipeline.
11 | //
12 | // Input is an iterator over the input values to be transformed.
13 | // It will be called in a thread-safe manner.
14 | // Transform receives an input (a) and a 0-based goroutine number (g),
15 | // and returns the result of processing a.
16 | // Output acts on a single result, and will be called in a thread-safe manner.
17 | // The order of outputs is arbitrary, but correlated with the order of
18 | // inputs.
19 | //
20 | // If one of the functions returns a non-nil error, the process stops and the
21 | // error is returned. Otherwise returns nil.
22 | func NonSerial[T1 any, T2 any](
23 | ngoroutines int,
24 | input iter.Seq2[T1, error],
25 | transform func(a T1, g int) (T2, error),
26 | output func(a T2) error) error {
27 | if ngoroutines < 1 {
28 | panic(fmt.Sprintf("bad number of goroutines: %d", ngoroutines))
29 | }
30 | pull, pstop := iter.Pull2(input)
31 | defer pstop()
32 |
33 | // An optimization for a single thread.
34 | if ngoroutines == 1 {
35 | for {
36 | t1, err, ok := pull()
37 |
38 | if !ok {
39 | return nil
40 | }
41 | if err != nil {
42 | return err
43 | }
44 |
45 | t2, err := transform(t1, 0)
46 | if err != nil {
47 | return err
48 | }
49 | if err := output(t2); err != nil {
50 | return err
51 | }
52 | }
53 | }
54 |
55 | ilock := &sync.Mutex{}
56 | olock := &sync.Mutex{}
57 | errs := make(chan error, ngoroutines)
58 | stop := &atomic.Bool{}
59 |
60 | for g := 0; g < ngoroutines; g++ {
61 | go func(g int) {
62 | for {
63 | if stop.Load() {
64 | errs <- nil
65 | return
66 | }
67 |
68 | ilock.Lock()
69 | t1, err, ok := pull()
70 | ilock.Unlock()
71 |
72 | if !ok {
73 | errs <- nil
74 | return
75 | }
76 | if err != nil {
77 | stop.Store(true)
78 | errs <- err
79 | return
80 | }
81 |
82 | t2, err := transform(t1, g)
83 | if err != nil {
84 | stop.Store(true)
85 | errs <- err
86 | return
87 | }
88 |
89 | olock.Lock()
90 | err = output(t2)
91 | olock.Unlock()
92 | if err != nil {
93 | stop.Store(true)
94 | errs <- err
95 | return
96 | }
97 | }
98 | }(g)
99 | }
100 |
101 | for g := 0; g < ngoroutines; g++ {
102 | if err := <-errs; err != nil {
103 | return err
104 | }
105 | }
106 | return nil
107 | }
108 |
--------------------------------------------------------------------------------
/xmlnode/read.go:
--------------------------------------------------------------------------------
1 | // Package xmlnode provides a hierarchical node representation of XML documents.
2 | // This package wraps encoding/xml and can be used instead of it.
3 | //
4 | // Each node has an underlying concrete type, but calling all functions is
5 | // legal. For example, here is how you can traverse the node tree:
6 | //
7 | // func traverse(n Node) {
8 | // // Text() returns an empty string for non-text nodes.
9 | // doSomeTextSearch(n.Text())
10 | //
11 | // // Children() returns nil for non-parent nodes.
12 | // for _, child := range n.Children() {
13 | // traverse(child)
14 | // }
15 | // }
16 | package xmlnode
17 |
18 | import (
19 | "encoding/xml"
20 | "io"
21 | )
22 |
23 | // ReadAll reads all XML data from the given reader and stores it in a root node.
24 | func ReadAll(r io.Reader) (Node, error) {
25 | // Create root node.
26 | // Starting with Tag instead of Root, to eliminate type checks when referring
27 | // to parent nodes during reading. Will be replaced with a Root node at the
28 | // end.
29 | result := &tag{
30 | nil,
31 | "",
32 | nil,
33 | nil,
34 | }
35 | dec := xml.NewDecoder(r)
36 |
37 | var t xml.Token
38 | var err error
39 | current := result
40 |
41 | // Parse tokens.
42 | for t, err = dec.Token(); err == nil; t, err = dec.Token() {
43 | switch t := t.(type) {
44 | case xml.StartElement:
45 | // Copy attributes.
46 | attrs := make([]*xml.Attr, len(t.Attr))
47 | for i, attr := range t.Attr {
48 | attrs[i] = &xml.Attr{Name: attr.Name, Value: attr.Value}
49 | }
50 |
51 | // Create child node.
52 | child := &tag{
53 | current,
54 | t.Name.Local,
55 | attrs,
56 | nil,
57 | }
58 |
59 | current.children = append(current.children, child)
60 | current = child
61 |
62 | case xml.EndElement:
63 | current = current.Parent().(*tag)
64 |
65 | case xml.CharData:
66 | child := &text{
67 | current,
68 | string(t),
69 | }
70 |
71 | current.children = append(current.children, child)
72 |
73 | case xml.Comment:
74 | child := &comment{
75 | current,
76 | string(t),
77 | }
78 |
79 | current.children = append(current.children, child)
80 |
81 | case xml.ProcInst:
82 | child := &procInst{
83 | current,
84 | string(t.Target),
85 | string(t.Inst),
86 | }
87 |
88 | current.children = append(current.children, child)
89 |
90 | case xml.Directive:
91 | child := &directive{
92 | current,
93 | string(t),
94 | }
95 |
96 | current.children = append(current.children, child)
97 | }
98 | }
99 |
100 | // EOF is ok.
101 | if err != io.EOF {
102 | return nil, err
103 | }
104 |
105 | return &root{result.children}, nil
106 | }
107 |
--------------------------------------------------------------------------------
/iterx/lines.go:
--------------------------------------------------------------------------------
1 | // Package iterx provides convenience functions for iterators.
2 | package iterx
3 |
4 | import (
5 | "bufio"
6 | "encoding/csv"
7 | "io"
8 | "iter"
9 |
10 | "github.com/fluhus/gostuff/aio"
11 | )
12 |
13 | // LinesReader iterates over text lines from a reader.
14 | func LinesReader(r io.Reader) iter.Seq2[string, error] {
15 | return func(yield func(string, error) bool) {
16 | sc := bufio.NewScanner(r)
17 | for sc.Scan() {
18 | if !yield(sc.Text(), nil) {
19 | return
20 | }
21 | }
22 | if err := sc.Err(); err != nil {
23 | yield("", err)
24 | }
25 | }
26 | }
27 |
28 | // LinesFile iterates over text lines from a reader.
29 | func LinesFile(file string) iter.Seq2[string, error] {
30 | return func(yield func(string, error) bool) {
31 | f, err := aio.Open(file)
32 | if err != nil {
33 | yield("", err)
34 | return
35 | }
36 | defer f.Close()
37 | sc := bufio.NewScanner(f)
38 | for sc.Scan() {
39 | if !yield(sc.Text(), nil) {
40 | return
41 | }
42 | }
43 | if err := sc.Err(); err != nil {
44 | yield("", err)
45 | }
46 | }
47 | }
48 |
49 | // CSVReader iterates over CSV entries from a reader.
50 | // fn is an optional function for modifying the CSV parser,
51 | // for example for changing the delimiter.
52 | //
53 | // Deprecated: use package csvx.
54 | func CSVReader(r io.Reader, fn func(*csv.Reader)) iter.Seq2[[]string, error] {
55 | return func(yield func([]string, error) bool) {
56 | c := csv.NewReader(r)
57 | if fn != nil {
58 | fn(c)
59 | }
60 | for {
61 | e, err := c.Read()
62 | if err == io.EOF {
63 | return
64 | }
65 | if !yield(e, nil) {
66 | return
67 | }
68 | }
69 | }
70 | }
71 |
72 | // CSVFile iterates over CSV entries from a file.
73 | // fn is an optional function for modifying the CSV parser,
74 | // for example for changing the delimiter.
75 | //
76 | // Deprecated: use package csvx.
77 | func CSVFile(file string, fn func(*csv.Reader)) iter.Seq2[[]string, error] {
78 | return func(yield func([]string, error) bool) {
79 | f, err := aio.Open(file)
80 | if err != nil {
81 | yield(nil, err)
82 | return
83 | }
84 | defer f.Close()
85 | c := csv.NewReader(f)
86 | if fn != nil {
87 | fn(c)
88 | }
89 | for {
90 | e, err := c.Read()
91 | if err == io.EOF {
92 | return
93 | }
94 | if !yield(e, nil) {
95 | return
96 | }
97 | }
98 | }
99 | }
100 |
101 | // CollectErr collects the given T's in a slice until the error is non-nil.
102 | func CollectErr[T any](it iter.Seq2[T, error]) ([]T, error) {
103 | var a []T
104 | for t, err := range it {
105 | if err != nil {
106 | return a, err
107 | }
108 | a = append(a, t)
109 | }
110 | return a, nil
111 | }
112 |
--------------------------------------------------------------------------------
/ppln/nserial_test.go:
--------------------------------------------------------------------------------
1 | package ppln
2 |
3 | import (
4 | "fmt"
5 | "math"
6 | "testing"
7 | )
8 |
9 | func TestNonSerial(t *testing.T) {
10 | want := 21082009.0
11 | for _, nt := range []int{1, 2, 4, 8} {
12 | t.Run(fmt.Sprint(nt), func(t *testing.T) {
13 | got := 0.0
14 | NonSerial[int, float64](
15 | nt,
16 | RangeInput(1, 100001),
17 | func(a, g int) (float64, error) {
18 | return math.Sqrt(float64(a)), nil
19 | },
20 | func(a float64) error {
21 | got += a
22 | return nil
23 | },
24 | )
25 | if math.Round(got) != want {
26 | t.Fatalf("NonSerial: got %f, want %f", got, want)
27 | }
28 | })
29 | }
30 | }
31 |
32 | func TestNonSerial_inputError(t *testing.T) {
33 | for _, nt := range []int{1, 2, 4, 8} {
34 | t.Run(fmt.Sprint(nt), func(t *testing.T) {
35 | got := 0.0
36 | err := NonSerial[int, float64](
37 | nt,
38 | func(yield func(int, error) bool) {
39 | for i, err := range RangeInput(1, 100001) {
40 | if i == 1000 {
41 | yield(0, fmt.Errorf("oh no"))
42 | return
43 | }
44 | if !yield(i, err) {
45 | return
46 | }
47 | }
48 | },
49 | func(a, g int) (float64, error) {
50 | return math.Sqrt(float64(a)), nil
51 | },
52 | func(a float64) error {
53 | got += a
54 | return nil
55 | },
56 | )
57 | if err == nil {
58 | t.Fatalf("NonSerial succeeded, want error")
59 | }
60 | })
61 | }
62 | }
63 |
64 | func TestNonSerial_transformError(t *testing.T) {
65 | for _, nt := range []int{1, 2, 4, 8} {
66 | t.Run(fmt.Sprint(nt), func(t *testing.T) {
67 | got := 0.0
68 | err := NonSerial[int, float64](
69 | nt,
70 | RangeInput(1, 100001),
71 | func(a, g int) (float64, error) {
72 | if a == 1000 {
73 | return 0, fmt.Errorf("oh no")
74 | }
75 | return math.Sqrt(float64(a)), nil
76 | },
77 | func(a float64) error {
78 | got += a
79 | return nil
80 | },
81 | )
82 | if err == nil {
83 | t.Fatalf("NonSerial succeeded, want error")
84 | }
85 | })
86 | }
87 | }
88 |
89 | func TestNonSerial_outputError(t *testing.T) {
90 | for _, nt := range []int{1, 2, 4, 8} {
91 | t.Run(fmt.Sprint(nt), func(t *testing.T) {
92 | got := 0.0
93 | err := NonSerial[int, float64](
94 | nt,
95 | RangeInput(1, 100001),
96 | func(a, g int) (float64, error) {
97 | return math.Sqrt(float64(a)), nil
98 | },
99 | func(a float64) error {
100 | if a == 32 {
101 | return fmt.Errorf("oh no")
102 | }
103 | got += a
104 | return nil
105 | },
106 | )
107 | if err == nil {
108 | t.Fatalf("NonSerial succeeded, want error")
109 | }
110 | })
111 | }
112 | }
113 |
114 | // TODO(amit): Error tests.
115 |
--------------------------------------------------------------------------------
/iterx/unreader.go:
--------------------------------------------------------------------------------
1 | package iterx
2 |
3 | import (
4 | "fmt"
5 | "iter"
6 | )
7 |
8 | // An Unreader wraps an iterator and adds an unread function.
9 | type Unreader[T any] struct {
10 | read func() (T, bool) // The underlying reader
11 | stop func()
12 | head T // The last read element
13 | hasHead bool // Unread was called
14 | }
15 |
16 | // Read returns the result of calling the underlying reader,
17 | // or the last unread element.
18 | func (r *Unreader[T]) Read() (T, bool) {
19 | if r.hasHead {
20 | // Remove headX values to allow GC.
21 | var t T
22 | t, r.head = r.head, t
23 | r.hasHead = false
24 | return t, true
25 | }
26 | t, ok := r.read()
27 | r.head = t
28 | return t, ok
29 | }
30 |
31 | // Unread makes the next call to Read return the last element and a nil error.
32 | // Can be called up to once per call to Read.
33 | func (r *Unreader[T]) Unread() {
34 | if r.hasHead {
35 | panic(fmt.Sprintf("called Unread twice: first with %v", r.head))
36 | }
37 | r.hasHead = true
38 | }
39 |
40 | // Until calls Read until stop returns true.
41 | func (r *Unreader[T]) Until(stop func(T) bool) iter.Seq[T] {
42 | return func(yield func(T) bool) {
43 | for {
44 | t, ok := r.Read()
45 | if !ok {
46 | return
47 | }
48 | if stop(t) {
49 | r.Unread()
50 | return
51 | }
52 | if !yield(t) {
53 | r.stop()
54 | return
55 | }
56 | }
57 | }
58 | }
59 |
60 | // GroupBy returns an iterator over groups, where each group is an iterator that
61 | // yields elements as long as the elements have sameGroup==true with the first
62 | // element in the group.
63 | func (r *Unreader[T]) GroupBy(sameGroup func(old T, nu T) bool,
64 | ) iter.Seq[iter.Seq[T]] {
65 | return func(yield func(iter.Seq[T]) bool) {
66 | for {
67 | group, ok := r.nextGroup(sameGroup)
68 | if !ok {
69 | return
70 | }
71 | if !yield(group) {
72 | r.stop()
73 | return
74 | }
75 | }
76 | }
77 | }
78 |
79 | // Returns an iterator that yields elements while sameGroup==true.
80 | func (r *Unreader[T]) nextGroup(sameGroup func(T, T) bool) (iter.Seq[T], bool) {
81 | firstT, ok := r.Read()
82 | if !ok {
83 | return nil, false
84 | }
85 | return func(yield func(T) bool) {
86 | if !yield(firstT) {
87 | return
88 | }
89 | for {
90 | t, ok := r.Read()
91 | if !ok {
92 | return
93 | }
94 | if !sameGroup(firstT, t) {
95 | r.Unread()
96 | return
97 | }
98 | if !yield(t) {
99 | r.stop()
100 | return
101 | }
102 | }
103 | }, true
104 | }
105 |
106 | // NewUnreader returns an Unreader with seq as its underlying iterator.
107 | func NewUnreader[T any](seq iter.Seq[T]) *Unreader[T] {
108 | read, stop := iter.Pull(seq)
109 | return &Unreader[T]{read: read, stop: stop, hasHead: false}
110 | }
111 |
--------------------------------------------------------------------------------
/nlp/wordnet/types.go:
--------------------------------------------------------------------------------
1 | package wordnet
2 |
3 | // WordNet is an entire wordnet database.
4 | type WordNet struct {
5 | // Maps from synset ID to synset.
6 | Synset map[string]*Synset `json:"synset"`
7 |
8 | // Maps from pos.lemma to synset IDs that contain it.
9 | Lemma map[string][]string `json:"lemma"`
10 |
11 | // Like Lemma, but synsets are ordered from the most frequently used to the
12 | // least. Only a subset of the synsets are ranked, so LemmaRanked has less
13 | // synsets.
14 | LemmaRanked map[string][]string `json:"lemmaRanked"`
15 |
16 | // Maps from exceptional word to its forms.
17 | Exception map[string][]string `json:"exception"`
18 |
19 | // Maps from example ID to sentence template. Using string keys for JSON
20 | // compatibility.
21 | Example map[string]string `json:"example"`
22 | }
23 |
24 | // Synset is a set of synonymous words.
25 | type Synset struct {
26 | // Synset offset, also used as an identifier.
27 | Offset string `json:"offset"`
28 |
29 | // Part of speech, including 's' for adjective satellite.
30 | Pos string `json:"pos"`
31 |
32 | // Words in this synset.
33 | Word []string `json:"word"`
34 |
35 | // Pointers to other synsets.
36 | Pointer []*Pointer `json:"pointer"`
37 |
38 | // Sentence frames for verbs.
39 | Frame []*Frame `json:"frame"`
40 |
41 | // Lexical definition.
42 | Gloss string `json:"gloss"`
43 |
44 | // Usage examples for words in this synset. Verbs only.
45 | Example []*Example `json:"example"`
46 | }
47 |
// A Frame ties a word of a synset to a generic phrase that shows how the
// word is used. Frames exist for verbs only.
//
// The frames are listed here:
// https://wordnet.princeton.edu/man/wninput.5WN.html#sect4
type Frame struct {
	// Position of the word within its synset, or -1 for the whole synset.
	WordNumber int `json:"wordNumber"`

	// Number of the frame, as listed on the WordNet site.
	FrameNumber int `json:"frameNumber"`
}
60 |
// A Pointer describes a semantic relation from one synset or word to
// another.
//
// The pointer symbols are listed here:
// https://wordnet.princeton.edu/man/wninput.5WN.html#sect3
type Pointer struct {
	// The kind of relation, read as "target is to source". The meaning of
	// each symbol is given by the package constants.
	Symbol string `json:"symbol"`

	// ID of the synset being pointed at.
	Synset string `json:"synset"`

	// Position of the word within the source synset, or -1 for the whole
	// synset.
	Source int `json:"source"`

	// Position of the word within the target synset, or -1 for the whole
	// synset.
	Target int `json:"target"`
}
79 |
// An Example links a synset word to an example sentence template.
// Applies to verbs only.
type Example struct {
	// Index of the word in the containing synset.
	WordNumber int `json:"wordNumber"`

	// Number of the sentence template; it is the key to look up in the
	// WordNet.Example field.
	TemplateNumber int `json:"templateNumber"`
}
88 |
--------------------------------------------------------------------------------
/sets/sets.go:
--------------------------------------------------------------------------------
1 | // Package sets provides generic sets.
2 | package sets
3 |
4 | import (
5 | "encoding/json"
6 |
7 | "golang.org/x/exp/maps"
8 | )
9 |
// Set is a generic set of comparable elements, backed by map[T]struct{}.
//
// Mutating methods return the receiver so that calls can be chained.
// A nil Set can be queried, but adding to it panics like any nil map.
type Set[T comparable] map[T]struct{}

// Add puts each of the given elements in s and returns s.
func (s Set[T]) Add(t ...T) Set[T] {
	for i := range t {
		s[t[i]] = struct{}{}
	}
	return s
}

// AddSet puts every element of t in s and returns s.
func (s Set[T]) AddSet(t Set[T]) Set[T] {
	for elem := range t {
		s[elem] = struct{}{}
	}
	return s
}

// Remove takes each of the given elements out of s and returns s.
// Elements that are not in s are ignored.
func (s Set[T]) Remove(t ...T) Set[T] {
	for i := range t {
		delete(s, t[i])
	}
	return s
}

// RemoveSet takes every element of t out of s and returns s.
func (s Set[T]) RemoveSet(t Set[T]) Set[T] {
	for elem := range t {
		delete(s, elem)
	}
	return s
}

// Has reports whether t is a member of s.
func (s Set[T]) Has(t T) bool {
	_, found := s[t]
	return found
}

// Intersect returns a new set holding the elements that appear in both
// s and t. Neither operand is modified.
func (s Set[T]) Intersect(t Set[T]) Set[T] {
	// Scan the smaller operand and probe the larger one.
	small, large := s, t
	if len(small) > len(large) {
		small, large = large, small
	}
	common := Set[T]{}
	for elem := range small {
		if _, found := large[elem]; found {
			common[elem] = struct{}{}
		}
	}
	return common
}
65 |
66 | // MarshalJSON implements the json.Marshaler interface.
67 | func (s Set[T]) MarshalJSON() ([]byte, error) {
68 | return json.Marshal(maps.Keys(s))
69 | }
70 |
71 | // UnmarshalJSON implements the json.Unmarshaler interface.
72 | func (s *Set[T]) UnmarshalJSON(b []byte) error {
73 | var slice []T
74 | if err := json.Unmarshal(b, &slice); err != nil {
75 | return err
76 | }
77 | if *s == nil {
78 | *s = Set[T]{}
79 | }
80 | s.Add(slice...)
81 | return nil
82 | }
83 |
84 | // AddKeys adds the keys of a map to a set.
85 | func AddKeys[K comparable, V any](s Set[K], m map[K]V) Set[K] {
86 | for k := range m {
87 | s.Add(k)
88 | }
89 | return s
90 | }
91 |
92 | // AddValues adds the values of a map to a set.
93 | func AddValues[K comparable, V comparable](s Set[V], m map[K]V) Set[V] {
94 | for _, v := range m {
95 | s.Add(v)
96 | }
97 | return s
98 | }
99 |
100 | // Of returns a new set containing the given elements.
101 | func Of[T comparable](t ...T) Set[T] {
102 | return make(Set[T], len(t)).Add(t...)
103 | }
104 |
105 | // FromKeys returns a new set containing the keys of the given map.
106 | func FromKeys[K comparable, V any](m map[K]V) Set[K] {
107 | return AddKeys(make(Set[K], len(m)), m)
108 | }
109 |
110 | // FromValues returns a new set containing the values of the given map.
111 | func FromValues[K comparable, V comparable](m map[K]V) Set[V] {
112 | return AddValues(Set[V]{}, m)
113 | }
114 |
--------------------------------------------------------------------------------
/heaps/heaps_test.go:
--------------------------------------------------------------------------------
1 | package heaps
2 |
3 | import (
4 | "fmt"
5 | "math/rand/v2"
6 | "slices"
7 | "testing"
8 |
9 | "github.com/fluhus/gostuff/snm"
10 | )
11 |
12 | func TestHeap(t *testing.T) {
13 | input := []string{"bb", "a", "ffff", "ddddd"}
14 | want := []string{"a", "bb", "ddddd", "ffff"}
15 | h := Min[string]()
16 | for _, v := range input {
17 | h.Push(v)
18 | }
19 | if ln := h.Len(); ln != len(input) {
20 | t.Fatalf("Len()=%d, want %d", ln, len(input))
21 | }
22 | var got []string
23 | for h.Len() > 0 {
24 | got = append(got, h.Pop())
25 | }
26 | if !slices.Equal(got, want) {
27 | t.Fatalf("Pop=%v, want %v", got, want)
28 | }
29 | }
30 |
31 | func TestHeap_big(t *testing.T) {
32 | input := []int{
33 | 5, 8, 25, 21, 22, 15, 13, 20, 1, 14,
34 | 24, 12, 7, 18, 27, 3, 30, 28, 23, 29,
35 | 19, 2, 6, 4, 26, 9, 17, 10, 11, 16,
36 | }
37 | want := []int{
38 | 1, 2, 3, 4, 5, 6, 7, 8, 9, 10,
39 | 11, 12, 13, 14, 15, 16, 17, 18, 19, 20,
40 | 21, 22, 23, 24, 25, 26, 27, 28, 29, 30,
41 | }
42 | h := New(func(i1, i2 int) bool {
43 | return i1 < i2
44 | })
45 | for _, v := range input {
46 | h.Push(v)
47 | }
48 | if ln := h.Len(); ln != len(input) {
49 | t.Fatalf("Len()=%d, want %d", ln, len(input))
50 | }
51 | var got []int
52 | for h.Len() > 0 {
53 | got = append(got, h.Pop())
54 | }
55 | if !slices.Equal(got, want) {
56 | t.Fatalf("Pop=%v, want %v", got, want)
57 | }
58 | }
59 |
60 | func TestHeap_pushSlice(t *testing.T) {
61 | input := []int{
62 | 5, 8, 25, 21, 22, 15, 13, 20, 1, 14,
63 | 24, 12, 7, 18, 27, 3, 30, 28, 23, 29,
64 | 19, 2, 6, 4, 26, 9, 17, 10, 11, 16,
65 | }
66 | h := Min[int]()
67 | h.PushSlice(input)
68 | for i := range h.a {
69 | if i == 0 {
70 | continue
71 | }
72 | ia := (i - 1) / 2
73 | if h.a[i] < h.a[ia] {
74 | t.Errorf("h[%d] < h[%d]: %d < %d", i, ia, h.a[i], h.a[ia])
75 | }
76 | }
77 | }
78 |
79 | func Benchmark(b *testing.B) {
80 | for _, n := range []int{1000, 10000, 100000, 1000000} {
81 | nums := snm.Slice(n, func(i int) int {
82 | return rand.Int()
83 | })
84 | b.Run(fmt.Sprint("Heap.Push.", n), func(b *testing.B) {
85 | for range b.N {
86 | h := Min[int]()
87 | for _, i := range nums {
88 | h.Push(i)
89 | }
90 | }
91 | })
92 | b.Run(fmt.Sprint("Heap.PushSlice.", n), func(b *testing.B) {
93 | for range b.N {
94 | h := Min[int]()
95 | h.PushSlice(nums)
96 | }
97 | })
98 | }
99 | }
100 |
101 | func FuzzHeap(f *testing.F) {
102 | f.Add(0, 0, 0, 0, 0, 0, 0)
103 | f.Fuzz(func(t *testing.T, a, b, c, d, e, f, g int) {
104 | h := Min[int]()
105 | h.Push(a)
106 | h.Push(b)
107 | h.Push(c)
108 | h.Push(d)
109 | h.Push(e)
110 | h.Push(f)
111 | h.Push(g)
112 | got := make([]int, 0, 7)
113 | for h.Len() > 0 {
114 | got = append(got, h.Pop())
115 | }
116 | if !slices.IsSorted(got) {
117 | t.Fatalf("Min().Pop()=%v, want sorted", got)
118 | }
119 | })
120 | }
121 |
--------------------------------------------------------------------------------
/iterx/unreader2.go:
--------------------------------------------------------------------------------
1 | package iterx
2 |
3 | import (
4 | "fmt"
5 | "iter"
6 | )
7 |
8 | // An Unreader2 wraps an iterator and adds an unread function.
9 | type Unreader2[T, S any] struct {
10 | read func() (T, S, bool) // The underlying reader
11 | stop func()
12 | headT T // The last read element
13 | headS S // The last read element
14 | hasHead bool // Unread was called
15 | }
16 |
17 | // Read returns the result of calling the underlying reader,
18 | // or the last unread element.
19 | func (r *Unreader2[T, S]) Read() (T, S, bool) {
20 | if r.hasHead {
21 | // Remove headX values to allow GC.
22 | var t T
23 | var s S
24 | t, r.headT = r.headT, t
25 | s, r.headS = r.headS, s
26 | r.hasHead = false
27 | return t, s, true
28 | }
29 | t, s, ok := r.read()
30 | r.headT = t
31 | r.headS = s
32 | return t, s, ok
33 | }
34 |
35 | // Unread makes the next call to Read return the last element and a nil error.
36 | // Can be called up to once per call to Read.
37 | func (r *Unreader2[T, S]) Unread() {
38 | if r.hasHead {
39 | panic(fmt.Sprintf("called Unread twice: first with (%v,%v)",
40 | r.headT, r.headS))
41 | }
42 | r.hasHead = true
43 | }
44 |
45 | // Until calls Read until stop returns true.
46 | func (r *Unreader2[T, S]) Until(stop func(T, S) bool) iter.Seq2[T, S] {
47 | return func(yield func(T, S) bool) {
48 | for {
49 | t, s, ok := r.Read()
50 | if !ok {
51 | return
52 | }
53 | if stop(t, s) {
54 | r.Unread()
55 | return
56 | }
57 | if !yield(t, s) {
58 | r.stop()
59 | return
60 | }
61 | }
62 | }
63 | }
64 |
65 | // GroupBy returns an iterator over groups, where each group is an iterator that
66 | // yields elements as long as the elements have sameGroup==true with the first
67 | // element in the group.
68 | func (r *Unreader2[T, S]) GroupBy(
69 | sameGroup func(oldT T, oldS S, newT T, newS S) bool,
70 | ) iter.Seq[iter.Seq2[T, S]] {
71 | return func(yield func(iter.Seq2[T, S]) bool) {
72 | for {
73 | group, ok := r.nextGroup(sameGroup)
74 | if !ok {
75 | return
76 | }
77 | if !yield(group) {
78 | r.stop()
79 | return
80 | }
81 | }
82 | }
83 | }
84 |
85 | // Returns an iterator that yields elements while sameGroup==true.
86 | func (r *Unreader2[T, S]) nextGroup(sameGroup func(T, S, T, S) bool) (iter.Seq2[T, S], bool) {
87 | firstT, firstS, ok := r.Read()
88 | if !ok {
89 | return nil, false
90 | }
91 | return func(yield func(T, S) bool) {
92 | if !yield(firstT, firstS) {
93 | return
94 | }
95 | for {
96 | t, s, ok := r.Read()
97 | if !ok {
98 | return
99 | }
100 | if !sameGroup(firstT, firstS, t, s) {
101 | r.Unread()
102 | return
103 | }
104 | if !yield(t, s) {
105 | r.stop()
106 | return
107 | }
108 | }
109 | }, true
110 | }
111 |
112 | // NewUnreader2 returns an Unreader2 with seq as its underlying iterator.
113 | func NewUnreader2[T, S any](seq iter.Seq2[T, S]) *Unreader2[T, S] {
114 | read, stop := iter.Pull2(seq)
115 | return &Unreader2[T, S]{read: read, stop: stop, hasHead: false}
116 | }
117 |
--------------------------------------------------------------------------------
/ptimer/ptimer_test.go:
--------------------------------------------------------------------------------
1 | package ptimer
2 |
3 | import (
4 | "bytes"
5 | "fmt"
6 | "io"
7 | "regexp"
8 | "slices"
9 | "testing"
10 | "time"
11 | )
12 |
// timePattern is a regular expression that matches a duration as printed by
// the timer: HH:MM:SS followed by a dot and six digits.
const timePattern = `\d\d:\d\d:\d\d\.\d\d\d\d\d\d`
14 |
15 | func TestNew(t *testing.T) {
16 | want := "^"
17 | for _, i := range []int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 20, 30, 35} {
18 | want += fmt.Sprintf("\r%s \\(%s\\) %d", timePattern, timePattern, i)
19 | }
20 | want += "\n$"
21 |
22 | got := bytes.NewBuffer(nil)
23 | pt := New()
24 | pt.W = got
25 | for i := 0; i < 35; i++ {
26 | pt.Inc()
27 | }
28 | pt.Done()
29 |
30 | if match, _ := regexp.MatchString(want, got.String()); !match {
31 | t.Fatalf("Inc()+Done()=%q, want %q", got.String(), want)
32 | }
33 | }
34 |
35 | func TestNewMessage(t *testing.T) {
36 | want := "^"
37 | for _, i := range []int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 20, 30, 35} {
38 | want += fmt.Sprintf("\r%s \\(%s\\) hey %d ho",
39 | timePattern, timePattern, i)
40 | }
41 | want += "\n$"
42 |
43 | got := bytes.NewBuffer(nil)
44 | pt := NewMessage("hey {} ho")
45 | pt.W = got
46 | for i := 0; i < 35; i++ {
47 | pt.Inc()
48 | }
49 | pt.Done()
50 |
51 | if match, _ := regexp.MatchString(want, got.String()); !match {
52 | t.Fatalf("Inc()+Done()=%q, want %q", got.String(), want)
53 | }
54 | }
55 |
56 | func TestNewFunc(t *testing.T) {
57 | want := "^"
58 | for _, i := range []float64{1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5,
59 | 10.5, 20.5, 30.5, 35.5} {
60 | want += fmt.Sprintf("\r%s \\(%s\\) ho ho %f",
61 | timePattern, timePattern, i)
62 | }
63 | want += "\n$"
64 |
65 | got := bytes.NewBuffer(nil)
66 | pt := NewFunc(func(i int) string {
67 | return fmt.Sprintf("ho ho %f", float64(i)+0.5)
68 | })
69 | pt.W = got
70 | for i := 0; i < 35; i++ {
71 | pt.Inc()
72 | }
73 | pt.Done()
74 |
75 | if match, _ := regexp.MatchString(want, got.String()); !match {
76 | t.Fatalf("Inc()+Done()=%q, want %q", got.String(), want)
77 | }
78 | }
79 |
80 | func TestDone(t *testing.T) {
81 | want := "^" + timePattern + " hello\n$"
82 |
83 | got := bytes.NewBuffer(nil)
84 | pt := NewMessage("hello")
85 | pt.W = got
86 | pt.Done()
87 |
88 | if match, _ := regexp.MatchString(want, got.String()); !match {
89 | t.Fatalf("Done()=%q, want %q", got.String(), want)
90 | }
91 | }
92 |
93 | func TestNextCheckpoint(t *testing.T) {
94 | want := []int{
95 | 1, 2, 3, 4, 5, 6, 7, 8, 9,
96 | 10, 20, 30, 40, 50, 60, 70, 80, 90,
97 | 100, 200, 300, 400, 500}
98 | var got []int
99 | i := 0
100 | for range want {
101 | i = nextCheckpoint(i)
102 | got = append(got, i)
103 | }
104 | if !slices.Equal(got, want) {
105 | t.Fatalf("nextCheckpoint(...)=%v, want %v", got, want)
106 | }
107 | }
108 |
109 | func BenchmarkTimer_inc(b *testing.B) {
110 | pt := New()
111 | pt.W = io.Discard
112 | b.ResetTimer()
113 | for i := 0; i < b.N; i++ {
114 | pt.Inc()
115 | }
116 | }
117 |
118 | func Example() {
119 | pt := New()
120 | for i := 0; i < 45; i++ {
121 | time.Sleep(100 * time.Millisecond)
122 | pt.Inc()
123 | }
124 | pt.Done()
125 | }
126 |
--------------------------------------------------------------------------------
/ppln/serial_test.go:
--------------------------------------------------------------------------------
1 | package ppln
2 |
3 | import (
4 | "fmt"
5 | "math/rand"
6 | "testing"
7 | "time"
8 |
9 | "github.com/fluhus/gostuff/gnum"
10 | )
11 |
12 | func ExampleSerial() {
13 | ngoroutines := 4
14 | var results []float64
15 |
16 | Serial[int, float64](ngoroutines,
17 | // Read/generate input data.
18 | RangeInput(1, 101),
19 | // Some processing.
20 | func(a, i, g int) (float64, error) {
21 | return float64(a*a) + 0.5, nil
22 | },
23 | // Accumulate/forward outputs.
24 | func(a float64) error {
25 | results = append(results, a)
26 | return nil
27 | })
28 |
29 | fmt.Println(results[:3], results[len(results)-3:])
30 |
31 | // Output:
32 | // [1.5 4.5 9.5] [9604.5 9801.5 10000.5]
33 | }
34 |
35 | func ExampleSerial_parallelAggregation() {
36 | ngoroutines := 4
37 | results := make([]int, ngoroutines) // Goroutine-specific data and objects.
38 |
39 | Serial(
40 | ngoroutines,
41 | // Read/generate input data.
42 | RangeInput(1, 101),
43 | // Accumulate in goroutine-specific memory.
44 | func(a int, i, g int) (int, error) {
45 | results[g] += a
46 | return 0, nil // Unused.
47 | },
48 | // No outputs.
49 | func(a int) error { return nil })
50 |
51 | // Collect the results of all goroutines.
52 | fmt.Println("Sum of 1-100:", gnum.Sum(results))
53 |
54 | // Output:
55 | // Sum of 1-100: 5050
56 | }
57 |
58 | func TestSerial(t *testing.T) {
59 | for _, nt := range []int{1, 2, 4, 8} {
60 | t.Run(fmt.Sprint(nt), func(t *testing.T) {
61 | n := nt * 100
62 | var result []int
63 | err := Serial(
64 | nt,
65 | RangeInput(0, n),
66 | func(a int, i int, g int) (int, error) {
67 | time.Sleep(time.Millisecond * time.Duration(rand.Intn(3)))
68 | return a * a, nil
69 | },
70 | func(i int) error {
71 | result = append(result, i)
72 | return nil
73 | })
74 | if err != nil {
75 | t.Fatalf("Serial(...) failed: %d", err)
76 | }
77 | for i := range result {
78 | if result[i] != i*i {
79 | t.Errorf("result[%d]=%d, want %d", i, result[i], i*i)
80 | }
81 | }
82 | })
83 | }
84 | }
85 |
86 | func TestSerial_error(t *testing.T) {
87 | for _, nt := range []int{1, 2, 4, 8} {
88 | t.Run(fmt.Sprint(nt), func(t *testing.T) {
89 | n := nt * 100
90 | var result []int
91 | err := Serial(
92 | nt,
93 | RangeInput(0, n),
94 | func(a int, i int, g int) (int, error) {
95 | time.Sleep(time.Millisecond * time.Duration(rand.Intn(3)))
96 | if a > 300 {
97 | return 0, fmt.Errorf("a too big: %d", a)
98 | }
99 | return a * a, nil
100 | },
101 | func(i int) error {
102 | result = append(result, i)
103 | return nil
104 | })
105 | if nt <= 3 {
106 | if err != nil {
107 | t.Fatalf("Serial(...) failed: %d", err)
108 | }
109 | for i := range result {
110 | if result[i] != i*i {
111 | t.Errorf("result[%d]=%d, want %d", i, result[i], i*i)
112 | }
113 | }
114 | } else { // n > 3
115 | if err == nil {
116 | t.Fatalf("Serial(...) succeeded, want error")
117 | }
118 | }
119 | })
120 | }
121 | }
122 |
--------------------------------------------------------------------------------
/iterx/unreader_test.go:
--------------------------------------------------------------------------------
1 | package iterx
2 |
3 | import (
4 | "iter"
5 | "reflect"
6 | "slices"
7 | "testing"
8 | )
9 |
10 | func TestUnreader_until(t *testing.T) {
11 | input := []int{1, 4, 2, 6, 8, 4, 5, 7}
12 | tests := []struct {
13 | until int
14 | want []int
15 | }{
16 | {6, []int{1, 4, 2}}, {4, []int{6, 8}}, {1, []int{4, 5, 7}},
17 | }
18 | r := NewUnreader(Slice(input))
19 | for _, test := range tests {
20 | var got []int
21 | for i := range r.Until(func(j int) bool { return j == test.until }) {
22 | got = append(got, i)
23 | }
24 | if !slices.Equal(got, test.want) {
25 | t.Fatalf("New(%v).Until(%d)=%v, want %v", input, test.until, got, test.want)
26 | }
27 | }
28 | }
29 |
30 | func TestUnreader_groupBy(t *testing.T) {
31 | input := []int{1, 4, 2, 6, 9, 4, 5, 7}
32 | want := [][]int{{1}, {4, 2, 6}, {9}, {4}, {5, 7}}
33 | var got [][]int
34 | r := NewUnreader(Slice(input))
35 | for group := range r.GroupBy(func(i int, j int) bool {
36 | return i%2 == j%2
37 | }) {
38 | var got1 []int
39 | for i := range group {
40 | got1 = append(got1, i)
41 | }
42 | got = append(got, got1)
43 | }
44 | if !reflect.DeepEqual(got, want) {
45 | t.Fatalf("New(%v).GroupBy(...)=%v, want %v", input, got, want)
46 | }
47 | }
48 |
49 | func TestUnreader_stop(t *testing.T) {
50 | input := []int{1, 2, 3, 4, 5, 6, 7, 8, 9}
51 |
52 | stopped := false
53 | broke := false
54 | r := NewUnreader(stopIter(input, &stopped))
55 | for x := range r.Until(func(i int) bool { return i == 5 }) {
56 | if x == 3 {
57 | broke = true
58 | break
59 | }
60 | }
61 | if !broke {
62 | t.Fatalf("Unreader(%v) never broke", input)
63 | }
64 | if !stopped {
65 | t.Fatalf("Unreader(%v) never called stop", input)
66 | }
67 |
68 | stopped, broke = false, false
69 | continued := true
70 | r = NewUnreader(stopIter(input, &stopped))
71 | for g := range r.GroupBy(func(i int, j int) bool {
72 | return i/3 == j/3
73 | }) {
74 | continued = true
75 | for x := range g {
76 | if x == 5 {
77 | broke = true
78 | continued = false
79 | break
80 | }
81 | }
82 | }
83 | if !broke {
84 | t.Fatalf("Unreader(%v) never broke", input)
85 | }
86 | if continued {
87 | t.Fatalf("Unreader(%v).GroupBy outer loop continued "+
88 | "after inner loop broke", input)
89 | }
90 | if !stopped {
91 | t.Fatalf("Unreader(%v) never called stop", input)
92 | }
93 |
94 | stopped = false
95 | toBreak := false
96 | r = NewUnreader(stopIter(input, &stopped))
97 | for g := range r.GroupBy(func(i int, j int) bool {
98 | return i/3 == j/3
99 | }) {
100 | for x := range g {
101 | if x == 5 {
102 | toBreak = true
103 | }
104 | }
105 | if toBreak {
106 | break
107 | }
108 | }
109 | if !toBreak {
110 | t.Fatalf("Unreader(%v) never broke", input)
111 | }
112 | if !stopped {
113 | t.Fatalf("Unreader(%v) never called stop", input)
114 | }
115 | }
116 |
117 | func stopIter[T any](s []T, stopped *bool) iter.Seq[T] {
118 | return func(yield func(T) bool) {
119 | defer func() {
120 | *stopped = true
121 | }()
122 | for _, x := range s {
123 | if !yield(x) {
124 | return
125 | }
126 | }
127 | }
128 | }
129 |
--------------------------------------------------------------------------------
/nlp/lda-tool/main.go:
--------------------------------------------------------------------------------
1 | // Command lda-tool performs LDA on the input documents.
2 | package main
3 |
4 | import (
5 | "bufio"
6 | "encoding/json"
7 | "flag"
8 | "fmt"
9 | "io"
10 | "os"
11 | "regexp"
12 | "runtime"
13 | "strings"
14 |
15 | "github.com/fluhus/gostuff/nlp"
16 | "golang.org/x/exp/maps"
17 | )
18 |
// Command-line flags.
var (
	k          = flag.Int("k", 0, "Number of topics")
	numThreads = flag.Int("t", 1, "Number of threads to use")
	js         = flag.Bool("j", false, "Output as JSON instead of default format")
)
24 |
25 | func main() {
26 | parseArgs()
27 |
28 | // Read input and perform LDA.
29 | fmt.Fprintln(os.Stdout, "Run with no arguments for usage help.")
30 | fmt.Fprintln(os.Stdout, "Reading documents from stdin...")
31 | docs, err := readDocs(os.Stdin)
32 | if err != nil {
33 | die("Error: failed to read input:", err)
34 | }
35 | fmt.Fprintln(os.Stdout, "Found", len(docs), "documents.")
36 |
37 | fmt.Fprintln(os.Stdout, "Performing LDA...")
38 | lda, _ := nlp.LdaThreads(docs, *k, *numThreads)
39 |
40 | // Print output.
41 | if *js {
42 | j, _ := json.MarshalIndent(lda, "", "\t")
43 | fmt.Println(string(j))
44 | } else {
45 | for _, w := range maps.Keys(lda) {
46 | fmt.Print(w)
47 | for _, x := range lda[w] {
48 | fmt.Printf(" %v", x)
49 | }
50 | fmt.Println()
51 | }
52 | }
53 | }
54 |
// readDocs reads documents from r, one document per line.
// Each document is lowercased and split into its words (runs of \w
// characters). The documents are returned as a 2d slice, in input order.
func readDocs(r io.Reader) ([][]string, error) {
	wordsRe := regexp.MustCompile(`\w+`)

	var docs [][]string
	sc := bufio.NewScanner(r)
	for sc.Scan() {
		words := wordsRe.FindAllString(strings.ToLower(sc.Text()), -1)

		// Keep an exact-size copy of the words, to reduce memory usage.
		doc := make([]string, len(words))
		copy(doc, words)
		docs = append(docs, doc)
	}
	if err := sc.Err(); err != nil {
		return nil, err
	}
	return docs, nil
}
73 |
// die prints an error message to the standard error and exits the program
// with exit code 2. The arguments are formatted like in Println.
func die(a ...any) {
	fmt.Fprintln(os.Stderr, a...)
	os.Exit(2)
}
80 |
81 | // parseArgs parses the program's arguments and validates them.
82 | // Exits with an error message upon validation error.
83 | func parseArgs() {
84 | flag.Parse()
85 | if len(os.Args) == 1 {
86 | fmt.Fprintln(os.Stderr, help)
87 | flag.PrintDefaults()
88 | os.Exit(1)
89 | }
90 | if *k < 1 {
91 | die("Error: invalid k:", *k)
92 | }
93 | if *numThreads < 0 {
94 | die("Error: invalid number of threads:", *numThreads)
95 | }
96 | if *numThreads == 0 {
97 | *numThreads = runtime.NumCPU()
98 | }
99 | }
100 |
// help is the usage text that parseArgs prints when the program is run
// without arguments. It is followed by the flag defaults.
var help = `Performs LDA on the given documents.

Input is read from the standard input. Format is one document per line.
Documents will be lowercased and normalized (spaces and punctuation omitted).

Output is printed to the standard output. Format is one word per line.
Each word is followed by K numbers, the i'th number represents the likelihood
of the i'th topic to emit that word.
`
110 |
--------------------------------------------------------------------------------
/flagx/flagx_test.go:
--------------------------------------------------------------------------------
1 | package flagx
2 |
3 | import (
4 | "flag"
5 | "fmt"
6 | "testing"
7 | )
8 |
9 | func TestRegexp(t *testing.T) {
10 | fs := flag.NewFlagSet("", flag.PanicOnError)
11 |
12 | re := RegexpFlagSet(fs, "a", nil, "")
13 | fs.Parse([]string{"-a", "a..b"})
14 | if (*re).String() != "a..b" {
15 | t.Errorf("RegexpFlagSet(...)=%q, want %q",
16 | (*re).String(), "a..b")
17 | }
18 |
19 | fs = flag.NewFlagSet("", flag.PanicOnError)
20 | re = RegexpFlagSet(fs, "a", nil, "")
21 | fs.Parse(nil)
22 | if (*re) != nil {
23 | t.Errorf("RegexpFlagSet(...)=%q, want nil", (*re).String())
24 | }
25 | }
26 |
27 | func TestIntBetween(t *testing.T) {
28 | fs := flag.NewFlagSet("", flag.PanicOnError)
29 |
30 | ii := IntBetweenFlagSet(fs, "i", 3, "", 1, 5)
31 | if *ii != 3 {
32 | t.Errorf("IntBetweenFlagSet(...)=%v, want %v", ii, 3)
33 | }
34 |
35 | // Valid values.
36 | for i := 1; i <= 5; i++ {
37 | args := []string{"-i", fmt.Sprint(i)}
38 | fs.Parse(args)
39 | if *ii != i {
40 | t.Errorf("Parse(%v)=%v, want %v", args, ii, i)
41 | }
42 | }
43 |
44 | // Invalid values.
45 | for _, i := range []int{-1, 0, 6, 7, 10} {
46 | func() {
47 | args := []string{"-i", fmt.Sprint(i)}
48 | defer func() {
49 | recover()
50 | }()
51 | fs.Parse(args)
52 | t.Errorf("Parse(%v)=%v, want error", args, ii)
53 | }()
54 | }
55 | }
56 |
57 | func TestOneOf_string(t *testing.T) {
58 | fs := flag.NewFlagSet("", flag.PanicOnError)
59 |
60 | vals := []string{"blue", "yellow", "red"}
61 | ss := OneOfFlagSet(fs, "s", vals[0], "", vals...)
62 | if *ss != vals[0] {
63 | t.Errorf("StringFromFlagSet(...)=%v, want %v", ss, vals[0])
64 | }
65 |
66 | // Valid values.
67 | for _, s := range vals {
68 | args := []string{"-s", s}
69 | fs.Parse(args)
70 | if *ss != s {
71 | t.Errorf("Parse(%v)=%v, want %v", args, ss, s)
72 | }
73 | }
74 |
75 | // Invalid values.
76 | for _, s := range vals {
77 | func() {
78 | args := []string{"-s", s + "."}
79 | defer func() {
80 | recover()
81 | }()
82 | fs.Parse(args)
83 | t.Errorf("Parse(%v)=%v, want error", args, ss)
84 | }()
85 | }
86 | }
87 |
88 | func TestOneOf_int(t *testing.T) {
89 | fs := flag.NewFlagSet("", flag.PanicOnError)
90 |
91 | vals := []int{3, 55, 888}
92 | oct := []string{"0o3", "0o67", "0o1570"}
93 | ss := OneOfFlagSet(fs, "s", vals[0], "", vals...)
94 | if *ss != vals[0] {
95 | t.Errorf("StringFromFlagSet(...)=%v, want %v", ss, vals[0])
96 | }
97 |
98 | // Valid values.
99 | for _, i := range vals {
100 | args := []string{"-s", fmt.Sprint(i)}
101 | fs.Parse(args)
102 | if *ss != i {
103 | t.Errorf("Parse(%v)=%v, want %v", args, ss, i)
104 | }
105 | }
106 |
107 | // Octal representation.
108 | for i, s := range oct {
109 | args := []string{"-s", s}
110 | fs.Parse(args)
111 | want := vals[i]
112 | if *ss != want {
113 | t.Errorf("Parse(%v)=%v, want %v", args, ss, want)
114 | }
115 | }
116 |
117 | // Invalid values.
118 | for _, i := range vals {
119 | func() {
120 | args := []string{"-s", fmt.Sprint(i) + "aaa"}
121 | defer func() {
122 | recover()
123 | }()
124 | fs.Parse(args)
125 | t.Errorf("Parse(%v)=%v, want error", args, *ss)
126 | }()
127 | }
128 | }
129 |
--------------------------------------------------------------------------------
/ptimer/ptimer.go:
--------------------------------------------------------------------------------
1 | // Package ptimer provides a progress timer for iterative processes.
2 | //
3 | // A timer prints how much time passed since its creation at exponentially
4 | // growing time-points.
5 | // Precisely, prints are triggered after i calls to Inc, if i has only one non-zero
6 | // digit. That is: 1, 2, 3 .. 9, 10, 20, 30 .. 90, 100, 200, 300...
7 | //
8 | // # Output Format
9 | //
10 | // For a regular use:
11 | //
12 | // 00:00:00.000000 (00:00:00.000000) message
13 | // | | |
14 | // Total time since creation | |
15 | // | |
16 | // Average time per call to Inc |
17 | // |
18 | // User-defined message ----------------|
19 | // (default message is number of calls to Inc)
20 | //
21 | // When calling Done without calling Inc:
22 | //
23 | // 00:00:00.000000 message
24 | package ptimer
25 |
26 | import (
27 | "fmt"
28 | "io"
29 | "os"
30 | "strings"
31 | "time"
32 | )
33 |
// A Timer measures time during an iterative process and prints the progress
// at checkpoints that grow exponentially apart.
type Timer struct {
	N int              // Current count, incremented with each call to Inc
	W io.Writer        // Timer's output, defaults to stderr
	t time.Time        // Start time
	f func(int) string // Message function
	c int              // Next checkpoint
}

// print writes a progress line to t.W: the time since creation, the average
// time per count (only if the count is not zero) and the message.
func (t *Timer) print() {
	elapsed := time.Since(t.t)
	if t.N != 0 {
		average := elapsed / time.Duration(t.N)
		fmt.Fprintf(t.W, "\r%s (%s) %s", fmtDuration(elapsed),
			fmtDuration(average), t.f(t.N))
		return
	}
	// No count yet; happens when Done is called without Inc.
	fmt.Fprintf(t.W, "%s %s", fmtDuration(elapsed), t.f(t.N))
}

// fmtDuration formats a duration in a constant-width format: hours, minutes,
// seconds and microseconds.
func fmtDuration(d time.Duration) string {
	hours := d / time.Hour
	minutes := d % time.Hour / time.Minute
	seconds := d % time.Minute / time.Second
	micros := d % time.Second / time.Microsecond
	return fmt.Sprintf("%02d:%02d:%02d.%06d", hours, minutes, seconds, micros)
}

// NewFunc returns a new timer that, on checkpoints, calls f with the current
// count and prints its output. The timer starts measuring immediately.
func NewFunc(f func(i int) string) *Timer {
	return &Timer{
		W: os.Stderr,
		t: time.Now(),
		f: f,
	}
}

// NewMessage returns a new timer that prints msg on checkpoints.
// Every "{}" in msg is replaced with the current count.
func NewMessage(msg string) *Timer {
	render := func(i int) string {
		return strings.ReplaceAll(msg, "{}", fmt.Sprint(i))
	}
	return NewFunc(render)
}

// New returns a new timer that prints the current count on checkpoints.
func New() *Timer {
	return NewMessage("{}")
}

// Inc increments t's counter, and prints the progress if a checkpoint was
// reached.
func (t *Timer) Inc() {
	t.N++
	if t.N < t.c {
		return
	}
	t.print()
	// Advance to the first checkpoint that is beyond the current count.
	for t.c <= t.N {
		t.c = nextCheckpoint(t.c)
	}
}

// Done prints the progress as if a checkpoint was reached, followed by a
// new line.
func (t *Timer) Done() {
	t.print()
	fmt.Fprintln(t.W)
}

// nextCheckpoint returns the checkpoint that follows checkpoint i.
// Panics if i is not a multiple of the largest power of 10 that is not
// greater than it.
func nextCheckpoint(i int) int {
	step := 1
	for step*10 <= i {
		step *= 10
	}
	if i%step != 0 {
		panic(fmt.Sprintf(
			"bad checkpoint: %d, should be a multiple of a power of 10", i))
	}
	return i + step
}
113 |
--------------------------------------------------------------------------------
/iterx/unreader2_test.go:
--------------------------------------------------------------------------------
1 | package iterx
2 |
3 | import (
4 | "iter"
5 | "reflect"
6 | "slices"
7 | "testing"
8 |
9 | "github.com/fluhus/gostuff/ppln"
10 | )
11 |
12 | func TestUnreader2_until(t *testing.T) {
13 | input := []int{1, 4, 2, 6, 8, 4, 5, 7}
14 | tests := []struct {
15 | until int
16 | want []int
17 | }{
18 | {6, []int{1, 4, 2}}, {4, []int{6, 8}}, {1, []int{4, 5, 7}},
19 | }
20 | r := NewUnreader2(ppln.SliceInput(input))
21 | for _, test := range tests {
22 | var got []int
23 | for i, err := range r.Until(func(j int, err error) bool { return j == test.until }) {
24 | if err != nil {
25 | t.Fatalf("New(%v).Until(%d) failed: %v", input, test.until, err)
26 | }
27 | got = append(got, i)
28 | }
29 | if !slices.Equal(got, test.want) {
30 | t.Fatalf("New(%v).Until(%d)=%v, want %v", input, test.until, got, test.want)
31 | }
32 | }
33 | }
34 |
35 | func TestUnreader2_groupBy(t *testing.T) {
36 | input := []int{1, 4, 2, 6, 9, 4, 5, 7}
37 | want := [][]int{{1}, {4, 2, 6}, {9}, {4}, {5, 7}}
38 | var got [][]int
39 | r := NewUnreader2(ppln.SliceInput(input))
40 | for group := range r.GroupBy(func(i int, ie error, j int, je error) bool {
41 | return i%2 == j%2
42 | }) {
43 | var got1 []int
44 | for i := range group {
45 | got1 = append(got1, i)
46 | }
47 | got = append(got, got1)
48 | }
49 | if !reflect.DeepEqual(got, want) {
50 | t.Fatalf("New(%v).GroupBy(...)=%v, want %v", input, got, want)
51 | }
52 | }
53 |
54 | func TestUnreader2_stop(t *testing.T) {
55 | input := []int{1, 2, 3, 4, 5, 6, 7, 8, 9}
56 |
57 | stopped := false
58 | broke := false
59 | r := NewUnreader2(stopIter2(input, &stopped))
60 | for x := range r.Until(func(i int, err error) bool { return i == 5 }) {
61 | if x == 3 {
62 | broke = true
63 | break
64 | }
65 | }
66 | if !broke {
67 | t.Fatalf("Unreader2(%v) never broke", input)
68 | }
69 | if !stopped {
70 | t.Fatalf("Unreader2(%v) never called stop", input)
71 | }
72 |
73 | stopped, broke = false, false
74 | continued := true
75 | r = NewUnreader2(stopIter2(input, &stopped))
76 | for g := range r.GroupBy(func(i int, ie error, j int, je error) bool {
77 | return i/3 == j/3
78 | }) {
79 | continued = true
80 | for x := range g {
81 | if x == 5 {
82 | broke = true
83 | continued = false
84 | break
85 | }
86 | }
87 | }
88 | if !broke {
89 | t.Fatalf("Unreader2(%v) never broke", input)
90 | }
91 | if continued {
92 | t.Fatalf("Unreader2(%v).GroupBy outer loop continued "+
93 | "after inner loop broke", input)
94 | }
95 | if !stopped {
96 | t.Fatalf("Unreader2(%v) never called stop", input)
97 | }
98 |
99 | stopped = false
100 | toBreak := false
101 | r = NewUnreader2(stopIter2(input, &stopped))
102 | for g := range r.GroupBy(func(i int, ie error, j int, je error) bool {
103 | return i/3 == j/3
104 | }) {
105 | for x := range g {
106 | if x == 5 {
107 | toBreak = true
108 | }
109 | }
110 | if toBreak {
111 | break
112 | }
113 | }
114 | if !toBreak {
115 | t.Fatalf("Unreader2(%v) never broke", input)
116 | }
117 | if !stopped {
118 | t.Fatalf("Unreader2(%v) never called stop", input)
119 | }
120 | }
121 |
122 | func stopIter2[T any](s []T, stopped *bool) iter.Seq2[T, error] {
123 | return func(yield func(T, error) bool) {
124 | defer func() {
125 | *stopped = true
126 | }()
127 | for _, x := range s {
128 | if !yield(x, nil) {
129 | return
130 | }
131 | }
132 | }
133 | }
134 |
--------------------------------------------------------------------------------
/nlp/wordnet/parser_test.go:
--------------------------------------------------------------------------------
1 | package wordnet
2 |
3 | import (
4 | "encoding/json"
5 | "reflect"
6 | "strings"
7 | "testing"
8 | )
9 |
10 | func TestDataParser(t *testing.T) {
11 | expected := map[string]*Synset{
12 | "v111": {
13 | "111",
14 | "v",
15 | []string{
16 | "foo",
17 | "bar",
18 | "baz",
19 | },
20 | []*Pointer{
21 | {"!", "n123", -1, -1},
22 | {"@", "a321", 0, 1},
23 | },
24 | []*Frame{
25 | {4, 4},
26 | {6, 6},
27 | },
28 | "hello world",
29 | nil,
30 | },
31 | }
32 |
33 | actual := map[string]*Synset{}
34 | err := parseDataFile(strings.NewReader(testData), "v", map[string][]int{},
35 | actual)
36 | if err != nil {
37 | t.Fatal("Parsing error:", err)
38 | }
39 | if !reflect.DeepEqual(expected, actual) {
40 | t.Error("Non-equal values:")
41 | t.Error(stringify(expected))
42 | t.Error(stringify(actual))
43 | }
44 | for key, ss := range actual {
45 | if ss.Id() != key {
46 | t.Errorf("ss.Id()=%v, want %v", ss.Id(), key)
47 | }
48 | }
49 | }
50 |
51 | func TestExceptionParser(t *testing.T) {
52 | expected := map[string][]string{
53 | "n.foo": {"n.bar"},
54 | "n.baz": {"n.bla", "n.blu"},
55 | }
56 |
57 | actual := map[string][]string{}
58 | err := parseExceptionFile(strings.NewReader(testException), "n", actual)
59 | if err != nil {
60 | t.Fatal("Parsing error:", err)
61 | }
62 | if !reflect.DeepEqual(expected, actual) {
63 | t.Error("Non-equal values:")
64 | t.Error(stringify(expected))
65 | t.Error(stringify(actual))
66 | }
67 | }
68 |
69 | func TestExampleIndexParser(t *testing.T) {
70 | expected := map[string][]int{
71 | "abash.37.0": {126, 127},
72 | "abhor.37.0": {138, 139, 15},
73 | }
74 |
75 | actual, err := parseExampleIndex(strings.NewReader(testExampleIndex))
76 | if err != nil {
77 | t.Fatal("Parsing error:", err)
78 | }
79 | if !reflect.DeepEqual(expected, actual) {
80 | t.Error("Non-equal values:")
81 | t.Error(expected)
82 | t.Error(actual)
83 | }
84 | }
85 |
86 | func TestExampleParser(t *testing.T) {
87 | expected := map[string]string{
88 | "111": "hello world",
89 | "222": "goodbye universe",
90 | }
91 |
92 | actual, err := parseExamples(strings.NewReader(testExamples))
93 | if err != nil {
94 | t.Fatal("Parsing error:", err)
95 | }
96 | if !reflect.DeepEqual(expected, actual) {
97 | t.Error("Non-equal values:")
98 | t.Error(expected)
99 | t.Error(actual)
100 | }
101 | }
102 |
103 | func TestIndexParser(t *testing.T) {
104 | expected := map[string][]string{
105 | "n.thing": {"na", "nb"},
106 | "v.thing2": {"vc", "vd"},
107 | }
108 |
109 | actual, err := parseIndex(strings.NewReader(testIndex))
110 | if err != nil {
111 | t.Fatal("Parsing error:", err)
112 | }
113 | if !reflect.DeepEqual(expected, actual) {
114 | t.Error("Non-equal values:")
115 | t.Error(expected)
116 | t.Error(actual)
117 | }
118 | }
119 |
// stringify returns the JSON encoding of a, for printing values in test
// failure messages. If a cannot be encoded, the result describes the encoding
// error rather than being silently empty.
func stringify(a any) string {
	j, err := json.Marshal(a)
	if err != nil {
		return "<json error: " + err.Error() + ">"
	}
	return string(j)
}
124 |
// testData is the input given to parseDataFile in TestDataParser.
var testData = ` copyright line
111 1 v 3 foo 1 bar 3 baz 5 2 ! 123 n 0000 @ 321 a 0102 2 + 4 5 + 6 7 | hello world`

// testException is the input given to parseExceptionFile in
// TestExceptionParser.
var testException = `foo bar
baz bla blu`

// testExampleIndex is the input given to parseExampleIndex in
// TestExampleIndexParser.
var testExampleIndex = `abash%2:37:00:: 126,127
abhor%2:37:00:: 138,139,15`

// testExamples is the input given to parseExamples in TestExampleParser.
var testExamples = `111 hello world
222 goodbye universe`

// testIndex is the input given to parseIndex in TestIndexParser.
var testIndex = ` copyright line
thing n 2 3 x y z 2 2 a b
thing2 v 4 1 x 4 2 c d e f`
140 |
--------------------------------------------------------------------------------
/bits/bits_test.go:
--------------------------------------------------------------------------------
1 | package bits
2 |
3 | import (
4 | "slices"
5 | "testing"
6 | )
7 |
8 | func TestSet_true(t *testing.T) {
9 | tests := []struct {
10 | n int
11 | want []byte
12 | }{
13 | {0, []byte{1, 0}},
14 | {1, []byte{2, 0}},
15 | {2, []byte{4, 0}},
16 | {3, []byte{8, 0}},
17 | {4, []byte{16, 0}},
18 | {5, []byte{32, 0}},
19 | {6, []byte{64, 0}},
20 | {7, []byte{128, 0}},
21 | {8, []byte{0, 1}},
22 | {9, []byte{0, 2}},
23 | {10, []byte{0, 4}},
24 | {11, []byte{0, 8}},
25 | {12, []byte{0, 16}},
26 | {13, []byte{0, 32}},
27 | {14, []byte{0, 64}},
28 | {15, []byte{0, 128}},
29 | }
30 | for _, test := range tests {
31 | b := []byte{0, 0}
32 | Set(b, test.n, true)
33 | if !slices.Equal(b, test.want) {
34 | t.Errorf("Set(%v,%v,%v)=%v, want %v",
35 | []byte{0, 0}, test.n, true, b, test.want)
36 | }
37 | }
38 | }
39 |
40 | func TestSet_false(t *testing.T) {
41 | tests := []struct {
42 | n int
43 | want []byte
44 | }{
45 | {0, []byte{255 - 1, 255}},
46 | {1, []byte{255 - 2, 255}},
47 | {2, []byte{255 - 4, 255}},
48 | {3, []byte{255 - 8, 255}},
49 | {4, []byte{255 - 16, 255}},
50 | {5, []byte{255 - 32, 255}},
51 | {6, []byte{255 - 64, 255}},
52 | {7, []byte{255 - 128, 255}},
53 | {8, []byte{255, 255 - 1}},
54 | {9, []byte{255, 255 - 2}},
55 | {10, []byte{255, 255 - 4}},
56 | {11, []byte{255, 255 - 8}},
57 | {12, []byte{255, 255 - 16}},
58 | {13, []byte{255, 255 - 32}},
59 | {14, []byte{255, 255 - 64}},
60 | {15, []byte{255, 255 - 128}},
61 | }
62 | for _, test := range tests {
63 | b := []byte{255, 255}
64 | Set(b, test.n, false)
65 | if !slices.Equal(b, test.want) {
66 | t.Errorf("Set(%v,%v,%v)=%v, want %v",
67 | []byte{0, 0}, test.n, false, b, test.want)
68 | }
69 | }
70 | }
71 |
72 | func TestGet(t *testing.T) {
73 | input := []byte{0b10011001, 0b01100111}
74 | want := []int{1, 0, 0, 1, 1, 0, 0, 1,
75 | 1, 1, 1, 0, 0, 1, 1, 0}
76 | for i := range want {
77 | if got := Get(input, i); got != want[i] {
78 | t.Errorf("Get(%v,%v)=%v, want %v", input, i, got, want[i])
79 | }
80 | }
81 | }
82 |
83 | func TestSum(t *testing.T) {
84 | tests := []struct {
85 | input []byte
86 | want int
87 | }{
88 | {nil, 0},
89 | {[]byte{0}, 0},
90 | {[]byte{0b01100111}, 5},
91 | {[]byte{0b10011001, 0b01100111}, 9},
92 | }
93 | for _, test := range tests {
94 | if got := Sum(test.input); got != test.want {
95 | t.Errorf("Sum(%v)=%v, want %v", test.input, got, test.want)
96 | }
97 | }
98 | }
99 |
100 | func TestByteOnes(t *testing.T) {
101 | tests := []struct {
102 | i int
103 | want []int
104 | }{
105 | {0, nil},
106 | {1, []int{0}},
107 | {2, []int{1}},
108 | {3, []int{0, 1}},
109 | {4, []int{2}},
110 | {5, []int{0, 2}},
111 | {255, []int{0, 1, 2, 3, 4, 5, 6, 7}},
112 | }
113 | for _, test := range tests {
114 | if got := byteOnes[test.i]; !slices.Equal(got, test.want) {
115 | t.Errorf("byteOnes(%v)=%v, want %v", test.i, got, test.want)
116 | }
117 | }
118 | }
119 |
120 | func TestOnesZeros(t *testing.T) {
121 | input := []byte{0b01011100, 0b11101010}
122 | wantOnes := []int{2, 3, 4, 6, 9, 11, 13, 14, 15}
123 | wantZeros := []int{0, 1, 5, 7, 8, 10, 12}
124 |
125 | gotOnes := slices.Collect(Ones(input))
126 | if !slices.Equal(gotOnes, wantOnes) {
127 | t.Errorf("Ones(%v)=%v, want %v", input, gotOnes, wantOnes)
128 | }
129 |
130 | gotZeros := slices.Collect(Zeros(input))
131 | if !slices.Equal(gotZeros, wantZeros) {
132 | t.Errorf("Zeros(%v)=%v, want %v", input, gotZeros, wantZeros)
133 | }
134 | }
135 |
--------------------------------------------------------------------------------
/clustering/upgma.go:
--------------------------------------------------------------------------------
1 | package clustering
2 |
3 | import (
4 | "fmt"
5 | "math"
6 |
7 | "github.com/fluhus/gostuff/gnum"
8 | "github.com/fluhus/gostuff/heaps"
9 | )
10 |
// distPyramid is a distance half-matrix: row i holds one distance for each
// index smaller than i.
type distPyramid [][]float64

// dist returns the distance between a and b, regardless of argument order.
// The larger index selects the row and the smaller one the column.
func (d distPyramid) dist(a, b int) float64 {
	lo, hi := a, b
	if lo > hi {
		lo, hi = hi, lo
	}
	return d[hi][lo]
}

// makePyramid creates a distance half-matrix for n elements. f is called as
// f(j, i) for every pair j < i, and its result is stored at row i, column j.
// All rows are consecutive windows of a single backing array.
func makePyramid(n int, f func(int, int) float64) distPyramid {
	flat := make([]float64, 0, n*(n-1)/2)
	rows := make([][]float64, n)
	for i := range rows {
		start := len(flat)
		for j := 0; j < i; j++ {
			flat = append(flat, f(j, i))
		}
		rows[i] = flat[start:len(flat)]
	}
	return rows
}
39 |
40 | // upgma is an implementation of UPGMA clustering. The distance between clusters
41 | // is the average distance between pairs of their individual elements.
42 | func upgma(n int, f func(int, int) float64) *AggloResult {
43 | pi := make([]int, n) // Index of first merge target of each element.
44 | lambda := make([]float64, n) // Distance of first merge target of each element.
45 |
46 | // Last cluster does not get matched with anyone -> max distance.
47 | lambda[len(lambda)-1] = math.MaxFloat64
48 |
49 | // Calculate raw distances.
50 | d := makePyramid(n, f)
51 | heapss := make([]*heaps.Heap[upgmaCluster], n)
52 | for i := range heapss {
53 | heapss[i] = heaps.New(compareUpgmaClusters)
54 | }
55 | for i := 1; i < n; i++ {
56 | for j := 0; j < i; j++ {
57 | heapss[i].Push(upgmaCluster{j, d[i][j]})
58 | heapss[j].Push(upgmaCluster{i, d[i][j]})
59 | }
60 | }
61 |
62 | // Clustering.
63 | sizes := gnum.Ones[[]float64](n) // Cluster sizes
64 | // The identifier of each cluster = highest index of an element
65 | names := make([]int, n)
66 | for i := range names {
67 | names[i] = i
68 | }
69 | for i := 0; i < n-1; i++ {
70 | // Find lowest distance.
71 | fmin := math.MaxFloat64
72 | a, b := -1, -1
73 | for hi, h := range heapss {
74 | if h == nil {
75 | continue
76 | }
77 | // Clean up removed clusters.
78 | if h.Len() == 0 {
79 | panic(fmt.Sprintf("heap %d with length 0", hi))
80 | }
81 | for heapss[h.Head().i] == nil {
82 | h.Pop()
83 | }
84 | if h.Head().d < fmin {
85 | a = hi
86 | fmin = h.Head().d
87 | b = h.Head().i
88 | }
89 | }
90 |
91 | // Create agglo step.
92 | nmin := min(names[a], names[b])
93 | nmax := max(names[a], names[b])
94 | pi[nmin] = nmax
95 | lambda[nmin] = fmin
96 |
97 | // Merge clusters.
98 | names = append(names, nmax)
99 | sizes = append(sizes, sizes[a]+sizes[b])
100 | heapss[a] = nil
101 | heapss[b] = nil
102 | var cdist []float64
103 | cheap := heaps.New(compareUpgmaClusters)
104 | for hi, h := range heapss {
105 | if h == nil {
106 | cdist = append(cdist, 0)
107 | continue
108 | }
109 | da := d.dist(a, hi) * sizes[a]
110 | db := d.dist(b, hi) * sizes[b]
111 | dd := (da + db) / (sizes[a] + sizes[b])
112 | cdist = append(cdist, dd)
113 | h.Push(upgmaCluster{len(sizes) - 1, dd})
114 | cheap.Push(upgmaCluster{hi, dd})
115 | }
116 | d = append(d, cdist)
117 | heapss = append(heapss, cheap)
118 | }
119 |
120 | return newAggloResult(pi, lambda)
121 | }
122 |
// upgmaCluster is a heap entry in UPGMA: another cluster and the distance
// to it.
type upgmaCluster struct {
	i int     // Cluster index
	d float64 // Distance from cluster i
}

// compareUpgmaClusters reports whether a is at a strictly smaller distance
// than b.
func compareUpgmaClusters(a, b upgmaCluster) bool {
	return b.d > a.d
}
132 |
--------------------------------------------------------------------------------
/graphs/bfsdfs_test.go:
--------------------------------------------------------------------------------
1 | package graphs
2 |
3 | import (
4 | "cmp"
5 | "slices"
6 | "testing"
7 |
8 | "github.com/fluhus/gostuff/snm"
9 | )
10 |
11 | func TestBFS(t *testing.T) {
12 | edges := [][2]string{
13 | {"a", "b"},
14 | {"a", "c"},
15 | {"b", "d"},
16 | {"b", "e"},
17 | {"c", "f"},
18 | {"e", "g"},
19 | {"e", "h"},
20 | }
21 | want := [][]string{
22 | {"a"},
23 | {"b", "c"},
24 | {"d", "e", "f"},
25 | {"g", "h"},
26 | }
27 |
28 | snm.Shuffle(edges)
29 | g := New[string]()
30 | for _, e := range edges {
31 | g.AddEdge(e[0], e[1])
32 | }
33 |
34 | got := slices.Collect(g.BFS("a"))
35 | if wantLen := sumLens(want); len(got) != wantLen {
36 | t.Fatalf("BFS(...) len=%v, want %v", len(got), wantLen)
37 | }
38 | ggot := groupLike(got, want)
39 | for _, g := range ggot {
40 | slices.Sort(g)
41 | }
42 | if !slices.EqualFunc(ggot, want, slices.Equal) {
43 | t.Fatalf("BFS(...)=%v, want %v", got, want)
44 | }
45 | }
46 |
47 | func TestBFS_loop(t *testing.T) {
48 | edges := [][2]string{
49 | {"a", "b"},
50 | {"b", "c"},
51 | {"b", "d"},
52 | {"c", "e"},
53 | {"d", "f"},
54 | {"e", "g"},
55 | {"f", "g"},
56 | }
57 | want := [][]string{
58 | {"a"},
59 | {"b"},
60 | {"c", "d"},
61 | {"e", "f"},
62 | {"g"},
63 | }
64 |
65 | snm.Shuffle(edges)
66 | g := New[string]()
67 | for _, e := range edges {
68 | g.AddEdge(e[0], e[1])
69 | }
70 |
71 | got := slices.Collect(g.BFS("a"))
72 | if wantLen := sumLens(want); len(got) != wantLen {
73 | t.Fatalf("BFS(...) len=%v, want %v", len(got), wantLen)
74 | }
75 | ggot := groupLike(got, want)
76 | for _, g := range ggot {
77 | slices.Sort(g)
78 | }
79 | if !slices.EqualFunc(ggot, want, slices.Equal) {
80 | t.Fatalf("BFS(...)=%v, want %v", got, want)
81 | }
82 | }
83 |
84 | func TestDFS(t *testing.T) {
85 | edges := [][2]string{
86 | {"a", "b"},
87 | {"a", "c"},
88 | {"b", "d"},
89 | {"b", "e"},
90 | {"c", "f"},
91 | {"c", "g"},
92 | }
93 | want := [][]string{
94 | {"a"},
95 | {"b", "d", "e"},
96 | {"c", "f", "g"},
97 | }
98 |
99 | snm.Shuffle(edges)
100 | g := New[string]()
101 | for _, e := range edges {
102 | g.AddEdge(e[0], e[1])
103 | }
104 |
105 | got := slices.Collect(g.DFS("a"))
106 | if wantLen := sumLens(want); len(got) != wantLen {
107 | t.Fatalf("DFS(...) len=%v, want %v", len(got), wantLen)
108 | }
109 | ggot := groupLike(got, want)
110 | for _, g := range ggot {
111 | slices.Sort(g[1:])
112 | }
113 | slices.SortFunc(ggot, func(a, b []string) int {
114 | return cmp.Compare(a[0], b[0])
115 | })
116 | if !slices.EqualFunc(ggot, want, slices.Equal) {
117 | t.Fatalf("DFS(...)=%v, want %v", got, want)
118 | }
119 | }
120 |
121 | func TestDFS_loop(t *testing.T) {
122 | edges := [][2]string{
123 | {"a", "b"},
124 | {"b", "c"},
125 | {"c", "d"},
126 | {"d", "a"},
127 | }
128 | want := [][]string{
129 | {"b", "c", "d", "a"},
130 | {"b", "a", "d", "c"},
131 | }
132 |
133 | snm.Shuffle(edges)
134 | g := New[string]()
135 | for _, e := range edges {
136 | g.AddEdge(e[0], e[1])
137 | }
138 |
139 | got := slices.Collect(g.DFS("b"))
140 | found := slices.IndexFunc(want, func(a []string) bool {
141 | return slices.Equal(a, got)
142 | })
143 | if found == -1 {
144 | t.Fatalf("DFS(...)=%v, want one of %v", got, want)
145 | }
146 | }
147 |
// sumLens returns the total number of strings in all the slices of a.
func sumLens(a [][]string) int {
	total := 0
	for i := range a {
		total += len(a[i])
	}
	return total
}
156 |
// groupLike cuts a into consecutive sub-slices whose lengths match those of
// the slices in like, in order. The sub-slices share a's backing array.
func groupLike(a []string, like [][]string) [][]string {
	var groups [][]string
	rest := a
	for i := range like {
		n := len(like[i])
		head, tail := rest[:n], rest[n:]
		groups = append(groups, head)
		rest = tail
	}
	return groups
}
166 |
--------------------------------------------------------------------------------
/hll/hll.go:
--------------------------------------------------------------------------------
1 | // Package hll provides an implementation of the HyperLogLog algorithm.
2 | //
3 | // A HyperLogLog counter can approximate the cardinality of a set with high
4 | // accuracy and little memory.
5 | //
6 | // # Accuracy
7 | //
8 | // Average error for 1,000,000,000 elements for different values of logSize:
9 | //
10 | // logSize average error %
11 | // 4 21
12 | // 5 12
13 | // 6 10
14 | // 7 8.1
15 | // 8 4.8
16 | // 9 3.6
17 | // 10 1.9
18 | // 11 1.2
19 | // 12 1.0
20 | // 13 0.7
21 | // 14 0.5
22 | // 15 0.33
23 | // 16 0.25
24 | //
25 | // # Citation
26 | //
27 | // Flajolet, Philippe; Fusy, Éric; Gandouet, Olivier; Meunier, Frédéric (2007).
28 | // "Hyperloglog: The analysis of a near-optimal cardinality estimation
29 | // algorithm". Discrete Mathematics and Theoretical Computer Science
30 | // Proceedings.
31 | package hll
32 |
33 | import (
34 | "fmt"
35 | "math"
36 | )
37 |
// An HLL is a HyperLogLog counter for values of type T.
type HLL[T any] struct {
	counters []byte         // One register per bucket.
	h        func(T) uint64 // Hash function for added values.
	nbits    int            // Number of hash bits that select a bucket (logSize).
	m        int            // Number of buckets, 2^nbits.
	mask     uint64         // m-1; extracts the bucket index from a hash.
}

// New creates a new HyperLogLog counter.
// The counter will use 2^logSize bytes.
// h is the hash function to use for added values.
// Panics if logSize is less than 4.
func New[T any](logSize int, h func(T) uint64) *HLL[T] {
	if logSize < 4 {
		panic(fmt.Sprintf("logSize=%v, should be at least 4", logSize))
	}
	size := 1 << logSize
	result := &HLL[T]{
		h:     h,
		nbits: logSize,
		m:     size,
		mask:  uint64(size - 1),
	}
	result.counters = make([]byte, size)
	return result
}

// Add adds t to the counter. Calls the hash function once.
func (h *HLL[T]) Add(t T) {
	hash := h.h(t)
	// The low nbits bits choose the bucket; the remaining bits give the rank.
	bucket := hash & h.mask
	rank := byte(h.nzeros(hash>>h.nbits)) + 1
	h.counters[bucket] = max(h.counters[bucket], rank)
}

// ApproxCount returns the current approximate count.
// Does not alter the state of the counter.
func (h *HLL[T]) ApproxCount() int {
	sum, zeros := 0.0, 0
	for _, v := range h.counters {
		sum += math.Pow(2, -float64(v))
		if v == 0 {
			zeros++
		}
	}
	fm := float64(h.m)
	estimate := int(h.alpha() * fm * fm * (1.0 / sum))

	// For small estimates, if some registers are zero, use linear counting.
	if estimate >= h.m*5/2 || zeros == 0 {
		return estimate
	}
	return int(fm * math.Log(fm/float64(zeros)))
}

// alpha returns the alpha value to use depending on m.
func (h *HLL[T]) alpha() float64 {
	if h.m == 16 {
		return 0.673
	}
	if h.m == 32 {
		return 0.697
	}
	if h.m == 64 {
		return 0.709
	}
	return 0.7213 / (1 + 1.079/float64(h.m))
}

// nzeros counts the number of zeros on the right side of a binary number.
// For zero, it returns the number of hash bits left after the first nbits.
func (h *HLL[T]) nzeros(a uint64) int {
	if a == 0 {
		return 64 - h.nbits
	}
	n := 0
	for ; a&1 == 0; a >>= 1 {
		n++
	}
	return n
}

// AddHLL adds the state of another counter to h,
// assuming they use the same hash function.
// The result is equivalent to adding all the values of other to h.
// Panics if the counters have different sizes.
func (h *HLL[T]) AddHLL(other *HLL[T]) {
	if len(h.counters) != len(other.counters) {
		panic("merging HLLs with different sizes")
	}
	for i := range h.counters {
		h.counters[i] = max(h.counters[i], other.counters[i])
	}
}

// LogSize returns the logSize parameter that was used to create this counter.
func (h *HLL[T]) LogSize() int {
	return h.nbits
}
146 |
--------------------------------------------------------------------------------
/ppln/serial.go:
--------------------------------------------------------------------------------
1 | package ppln
2 |
3 | import (
4 | "fmt"
5 | "iter"
6 | "sync"
7 | "sync/atomic"
8 |
9 | "github.com/fluhus/gostuff/heaps"
10 | )
11 |
12 | // Serial starts a multi-goroutine transformation pipeline that maintains the
13 | // order of the inputs.
14 | //
15 | // Input is an iterator over the input values to be transformed.
16 | // It will be called in a thread-safe manner.
17 | // Transform receives an input (a), 0-based input serial number (i), 0-based
18 | // goroutine number (g), and returns the result of processing a.
19 | // Output acts on a single result, and will be called by the same
20 | // order of the input, in a thread-safe manner.
21 | //
22 | // If one of the functions returns a non-nil error, the process stops and the
23 | // error is returned. Otherwise returns nil.
24 | func Serial[T1 any, T2 any](
25 | ngoroutines int,
26 | input iter.Seq2[T1, error],
27 | transform func(a T1, i int, g int) (T2, error),
28 | output func(a T2) error) error {
29 | if ngoroutines < 1 {
30 | panic(fmt.Sprintf("bad number of goroutines: %d", ngoroutines))
31 | }
32 | pull, pstop := iter.Pull2(input)
33 | defer pstop()
34 |
35 | // An optimization for a single thread.
36 | if ngoroutines == 1 {
37 | i := 0
38 | for {
39 | t1, err, ok := pull()
40 | ii := i
41 | i++
42 |
43 | if !ok {
44 | return nil
45 | }
46 | if err != nil {
47 | return err
48 | }
49 |
50 | t2, err := transform(t1, ii, 0)
51 | if err != nil {
52 | return err
53 | }
54 | if err := output(t2); err != nil {
55 | return err
56 | }
57 | }
58 | }
59 |
60 | ilock := &sync.Mutex{}
61 | olock := &sync.Mutex{}
62 | errs := make(chan error, ngoroutines)
63 | stop := &atomic.Bool{}
64 | items := &serialHeap[T2]{
65 | data: heaps.New(func(a, b serialItem[T2]) bool {
66 | return a.i < b.i
67 | }),
68 | }
69 |
70 | i := 0
71 | for g := 0; g < ngoroutines; g++ {
72 | go func(g int) {
73 | for {
74 | if stop.Load() {
75 | errs <- nil
76 | return
77 | }
78 |
79 | ilock.Lock()
80 | t1, err, ok := pull()
81 | ii := i
82 | i++
83 | ilock.Unlock()
84 |
85 | if !ok {
86 | errs <- nil
87 | return
88 | }
89 | if err != nil {
90 | stop.Store(true)
91 | errs <- err
92 | return
93 | }
94 |
95 | t2, err := transform(t1, ii, g)
96 | if err != nil {
97 | stop.Store(true)
98 | errs <- err
99 | return
100 | }
101 |
102 | olock.Lock()
103 | items.put(serialItem[T2]{ii, t2})
104 | for items.ok() {
105 | err = output(items.pop())
106 | if err != nil {
107 | olock.Unlock()
108 | stop.Store(true)
109 | errs <- err
110 | return
111 | }
112 | }
113 | olock.Unlock()
114 | }
115 | }(g)
116 | }
117 |
118 | for g := 0; g < ngoroutines; g++ {
119 | if err := <-errs; err != nil {
120 | return err
121 | }
122 | }
123 | return nil
124 | }
125 |
// serialItem is a value tagged with a serial number.
type serialItem[T any] struct {
	i    int // Serial number.
	data T   // The tagged value.
}
131 |
132 | // A heap of serial items. Sorts by serial number.
133 | type serialHeap[T any] struct {
134 | next int
135 | data *heaps.Heap[serialItem[T]]
136 | }
137 |
138 | // Checks whether the minimal element in the heap is the next in the series.
139 | func (s *serialHeap[T]) ok() bool {
140 | return s.data.Len() > 0 && s.data.Head().i == s.next
141 | }
142 |
143 | // Removes and returns the minimal element in the heap. Panics if the element
144 | // is not the next in the series.
145 | func (s *serialHeap[T]) pop() T {
146 | if !s.ok() {
147 | panic("get when not ok")
148 | }
149 | s.next++
150 | a := s.data.Pop()
151 | return a.data
152 | }
153 |
154 | // Adds an item to the heap.
155 | func (s *serialHeap[T]) put(item serialItem[T]) {
156 | if item.i < s.next {
157 | panic(fmt.Sprintf("put(%d) when next is %d", item.i, s.next))
158 | }
159 | s.data.Push(item)
160 | }
161 |
--------------------------------------------------------------------------------
/heaps/heaps.go:
--------------------------------------------------------------------------------
1 | // Package heaps provides generic heaps.
2 | //
// This package provides better run speeds than the standard [container/heap] package.
4 | package heaps
5 |
6 | import (
7 | "golang.org/x/exp/constraints"
8 | )
9 |
// Heap is a generic heap, ordered by a caller-supplied comparator.
type Heap[T any] struct {
	a    []T               // The elements, in heap order.
	less func(T, T) bool   // Reports whether its first argument sorts first.
}

// New returns a new, empty heap that uses the given comparator function.
func New[T any](less func(T, T) bool) *Heap[T] {
	return &Heap[T]{less: less}
}
20 |
21 | // Min returns a new min-heap of an ordered type by its natural order.
22 | func Min[T constraints.Ordered]() *Heap[T] {
23 | return New(func(t1, t2 T) bool {
24 | return t1 < t2
25 | })
26 | }
27 |
28 | // Max returns a new max-heap of an ordered type by its natural order.
29 | func Max[T constraints.Ordered]() *Heap[T] {
30 | return New(func(t1, t2 T) bool {
31 | return t1 > t2
32 | })
33 | }
34 |
35 | // Push adds x to h while maintaining its heap invariants.
36 | func (h *Heap[T]) Push(x T) {
37 | h.a = append(h.a, x)
38 | i := len(h.a) - 1
39 | for i != -1 {
40 | i = h.bubbleUp(i)
41 | }
42 | }
43 |
44 | // PushSlice adds the elements of s to h while maintaining its heap invariants.
45 | // The complexity is O(new n), so it should be typically used to initialize a
46 | // new heap.
47 | func (h *Heap[T]) PushSlice(s []T) {
48 | h.a = append(h.a, s...)
49 | for i := len(h.a) - 1; i >= 0; i-- {
50 | j := i
51 | for j != -1 {
52 | j = h.bubbleDown(j)
53 | }
54 | }
55 | }
56 |
57 | // Pop removes and returns the minimal element in h.
58 | func (h *Heap[T]) Pop() T {
59 | if len(h.a) == 0 {
60 | panic("called Pop() on an empty heap")
61 | }
62 | x := h.a[0]
63 | h.a[0] = h.a[len(h.a)-1]
64 | h.a = h.a[:len(h.a)-1]
65 | i := 0
66 | for i != -1 {
67 | i = h.bubbleDown(i)
68 | }
69 | // Shrink if needed.
70 | if cap(h.a) >= 16 && len(h.a) <= cap(h.a)/4 {
71 | h.a = append(make([]T, 0, cap(h.a)/2), h.a...)
72 | }
73 | return x
74 | }
75 |
76 | // Len returns the number of elements in h.
77 | func (h *Heap[T]) Len() int {
78 | return len(h.a)
79 | }
80 |
81 | // Moves the i'th element down and returns its new index.
82 | // Returns -1 when no more bubble-downs are needed.
83 | func (h *Heap[T]) bubbleDown(i int) int {
84 | ia, ib := i*2+1, i*2+2
85 | if len(h.a) < ib { // No children
86 | return -1
87 | }
88 | if len(h.a) == ib { // Only one child
89 | if h.less(h.a[ia], h.a[i]) {
90 | h.a[i], h.a[ia] = h.a[ia], h.a[i]
91 | }
92 | return -1
93 | }
94 | if h.less(h.a[ib], h.a[ia]) {
95 | ia = ib
96 | }
97 | if h.less(h.a[ia], h.a[i]) {
98 | h.a[i], h.a[ia] = h.a[ia], h.a[i]
99 | }
100 | return ia
101 | }
102 |
103 | // Moves the i'th element up and returns its new index.
104 | // Returns -1 when no more bubble-ups are needed.
105 | func (h *Heap[T]) bubbleUp(i int) int {
106 | if i == 0 {
107 | return -1
108 | }
109 | ia := (i - 1) / 2
110 | if !h.less(h.a[i], h.a[ia]) {
111 | return -1
112 | }
113 | h.a[i], h.a[ia] = h.a[ia], h.a[i]
114 | return ia
115 | }
116 |
117 | // View returns the underlying slice of h, containing all of its elements.
118 | // Modifying the slice may invalidate the heap.
119 | func (h *Heap[T]) View() []T {
120 | return h.a
121 | }
122 |
123 | // Head returns the minimal element in h.
124 | func (h *Heap[T]) Head() T {
125 | return h.a[0]
126 | }
127 |
128 | // Fix fixes the heap after a single value had been modified.
129 | // i is the index of the modified value.
130 | func (h *Heap[T]) Fix(i int) {
131 | wentUp := false
132 | for {
133 | j := h.bubbleUp(i)
134 | if j == -1 {
135 | break
136 | }
137 | i = j
138 | wentUp = true
139 | }
140 | if !wentUp {
141 | for {
142 | j := h.bubbleDown(i)
143 | if j == -1 {
144 | break
145 | }
146 | i = j
147 | }
148 | }
149 | }
150 |
151 | // Clip removes unused capacity from the heap.
152 | func (h *Heap[T]) Clip() {
153 | h.a = append(make([]T, 0, len(h.a)), h.a...)
154 | }
155 |
--------------------------------------------------------------------------------
/bnry/bnry_test.go:
--------------------------------------------------------------------------------
1 | package bnry
2 |
3 | import (
4 | "reflect"
5 | "slices"
6 | "testing"
7 | )
8 |
9 | func TestMarshal(t *testing.T) {
10 | a := byte(113)
11 | b := uint64(2391278932173219)
12 | c := "amit"
13 | d := int16(10000)
14 | e := []int32{1, 11, 100, 433223}
15 | f := true
16 | g := false
17 | buf := MarshalBinary(a, b, c, d, e, f, g)
18 | var aa byte
19 | var bb uint64
20 | var cc string
21 | var dd int16
22 | var ee []int32
23 | var ff bool
24 | var gg bool
25 | if err := UnmarshalBinary(
26 | buf, &aa, &bb, &cc, &dd, &ee, &ff, &gg); err != nil {
27 | t.Fatalf("UnmarshalBinary(%v) failed: %v", buf, err)
28 | }
29 | if aa != a {
30 | t.Errorf("UnmarshalBinary(...)=%v, want %v", aa, a)
31 | }
32 | if bb != b {
33 | t.Errorf("UnmarshalBinary(...)=%v, want %v", bb, b)
34 | }
35 | if cc != c {
36 | t.Errorf("UnmarshalBinary(...)=%v, want %v", cc, c)
37 | }
38 | if dd != d {
39 | t.Errorf("UnmarshalBinary(...)=%v, want %v", dd, d)
40 | }
41 | if !slices.Equal(ee, e) {
42 | t.Errorf("UnmarshalBinary(...)=%v, want %v", ee, e)
43 | }
44 | if ff != f {
45 | t.Errorf("UnmarshalBinary(...)=%v, want %v", dd, d)
46 | }
47 | if gg != g {
48 | t.Errorf("UnmarshalBinary(...)=%v, want %v", dd, d)
49 | }
50 | }
51 |
52 | func FuzzMarshal(f *testing.F) {
53 | f.Add(uint8(1), int16(1), uint32(1), int64(1), "", true, 1.0, float32(1))
54 | f.Fuzz(func(t *testing.T, a uint8, b int16, c uint32, d int64, e string,
55 | g bool, h float64, i float32) {
56 | buf := MarshalBinary(a, b, c, d, e, g, h, i)
57 | var (
58 | aa uint8
59 | bb int16
60 | cc uint32
61 | dd int64
62 | ee string
63 | gg bool
64 | hh float64
65 | ii float32
66 | )
67 | err := UnmarshalBinary(buf, &aa, &bb, &cc, &dd, &ee, &gg, &hh, &ii)
68 | if err != nil {
69 | t.Fatal(err)
70 | }
71 | if aa != a {
72 | t.Fatalf("got %v, want %v", aa, a)
73 | }
74 | if bb != b {
75 | t.Fatalf("got %v, want %v", bb, b)
76 | }
77 | if cc != c {
78 | t.Fatalf("got %v, want %v", cc, c)
79 | }
80 | if dd != d {
81 | t.Fatalf("got %v, want %v", dd, d)
82 | }
83 | if ee != e {
84 | t.Fatalf("got %v, want %v", ee, e)
85 | }
86 | if gg != g {
87 | t.Fatalf("got %v, want %v", gg, g)
88 | }
89 | if hh != h {
90 | t.Fatalf("got %v, want %v", hh, h)
91 | }
92 | if ii != i {
93 | t.Fatalf("got %v, want %v", ii, i)
94 | }
95 | })
96 | }
97 |
98 | func TestMarshal_slices(t *testing.T) {
99 | a := []uint32{321321, 213, 4944}
100 | b := []string{"this", "is", "", "a", "slice"}
101 | c := []int8{100, 9, 0, -21}
102 | d := []bool{true, false, false, true, true}
103 | buf := MarshalBinary(slices.Clone(a), slices.Clone(b),
104 | slices.Clone(c), slices.Clone(d))
105 | var (
106 | aa []uint32
107 | bb []string
108 | cc []int8
109 | dd []bool
110 | )
111 | err := UnmarshalBinary(buf, &aa, &bb, &cc, &dd)
112 | if err != nil {
113 | t.Fatal("UnmarshalBinary(...) failed:", err)
114 | }
115 | inputs := []any{a, b, c, d}
116 | outputs := []any{aa, bb, cc, dd}
117 | for i := range inputs {
118 | if !reflect.DeepEqual(inputs[i], outputs[i]) {
119 | t.Fatalf("UnmarshalBinary(...)=%v, want %v", outputs[i], inputs[i])
120 | }
121 | }
122 | }
123 |
124 | func TestMarshal_single(t *testing.T) {
125 | testMarshalSingle(t, int8(123))
126 | testMarshalSingle(t, uint8(123))
127 | testMarshalSingle(t, int32(12345))
128 | testMarshalSingle(t, uint32(12345))
129 | testMarshalSingle(t, int(12345))
130 | testMarshalSingle(t, uint(12345))
131 | testMarshalSingle(t, float64(33.33))
132 | testMarshalSingle(t, float32(33.33))
133 | testMarshalSingle(t, "amit")
134 | }
135 |
136 | func testMarshalSingle[T comparable](t *testing.T, val T) {
137 | buf := MarshalBinary(val)
138 | var got T
139 | if err := UnmarshalBinary(buf, &got); err != nil {
140 | t.Errorf("UnmarshalBinary(%#v) failed: %s", val, err)
141 | return
142 | }
143 | if got != val {
144 | t.Errorf("UnmarshalBinary(%#v)=%#v, want %#v", val, got, val)
145 | }
146 | }
147 |
--------------------------------------------------------------------------------
/bloom/bloom_test.go:
--------------------------------------------------------------------------------
1 | package bloom
2 |
3 | import (
4 | "bytes"
5 | "fmt"
6 | "testing"
7 |
8 | "github.com/fluhus/gostuff/bnry"
9 | "github.com/fluhus/gostuff/gnum"
10 | )
11 |
12 | func TestLen(t *testing.T) {
13 | tests := []struct {
14 | bits int
15 | want int
16 | }{
17 | {1, 8},
18 | {2, 8},
19 | {3, 8},
20 | {4, 8},
21 | {5, 8},
22 | {6, 8},
23 | {7, 8},
24 | {8, 8},
25 | {9, 16},
26 | {16, 16},
27 | {17, 24},
28 | }
29 |
30 | for _, test := range tests {
31 | f := New(test.bits, 1)
32 | if l := f.NBits(); l != test.want {
33 | t.Errorf("New(%v,1).Len()=%v, want %v",
34 | test.bits, l, test.want)
35 | }
36 | if f.NHash() != 1 {
37 | t.Errorf("New(%v,1).K()=%v, want 1",
38 | test.bits, f.NHash())
39 | }
40 | }
41 | }
42 |
43 | func TestFilter(t *testing.T) {
44 | f := New(80, 4)
45 | data := []byte{1, 2, 3, 4}
46 | if f.Has(data) {
47 | t.Fatalf("Has(%v)=true, want false", data)
48 | }
49 | if f.Add(data) {
50 | t.Fatalf("Add(%v)=true, want false", data)
51 | }
52 | if !f.Has(data) {
53 | t.Fatalf("Has(%v)=false, want true", data)
54 | }
55 | if !f.Add(data) {
56 | t.Fatalf("Add(%v)=false, want true", data)
57 | }
58 |
59 | data2 := []byte{4, 3, 2, 1}
60 | if f.Has(data2) {
61 | t.Fatalf("Has(%v)=true, want false", data2)
62 | }
63 | }
64 |
65 | func TestNewOptimal(t *testing.T) {
66 | n := 1000000
67 | p := 0.01
68 | f := NewOptimal(n, p)
69 | t.Logf("bits=%v, k=%v", f.NBits(), f.NHash())
70 | fp := 0
71 | for i := 0; i < n; i++ {
72 | buf := bnry.MarshalBinary(uint64(i))
73 | if f.Add(buf) {
74 | fp++
75 | }
76 | }
77 | if fpr := float64(fp) / float64(n); fpr > p {
78 | t.Fatalf("fp=%v, want <%v", fpr, p)
79 | }
80 | }
81 |
82 | func TestEncode(t *testing.T) {
83 | data1 := []byte{1, 2, 3, 4}
84 | data2 := []byte{4, 3, 2, 1}
85 | f1 := New(80, 4)
86 | f1.SetSeed(5678)
87 | f1.Add(data1)
88 |
89 | if !f1.Has(data1) {
90 | t.Fatalf("Has(%v)=false, want true", data1)
91 | }
92 | if f1.Has(data2) {
93 | t.Fatalf("Has(%v)=true, want false", data2)
94 | }
95 |
96 | buf := bytes.NewBuffer(nil)
97 | if err := f1.Encode(buf); err != nil {
98 | t.Fatalf("Encode(...) failed: %v", err)
99 | }
100 | f2 := &Filter{}
101 | if err := f2.Decode(buf); err != nil {
102 | t.Fatalf("Decode(...) failed: %v", err)
103 | }
104 |
105 | if !bytes.Equal(f1.b, f2.b) {
106 | t.Fatalf("Decode(...) bytes=%v, want %v", f2.b, f1.b)
107 | }
108 | if f1.seed != f2.seed {
109 | t.Fatalf("Decode(...) seed=%v, want %v", f2.seed, f1.seed)
110 | }
111 |
112 | if !f2.Has(data1) {
113 | t.Fatalf("Decode(...).Has(%v)=false, want true", data1)
114 | }
115 | if f2.Has(data2) {
116 | t.Fatalf("Decode(...).Has(%v)=true, want false", data2)
117 | }
118 | }
119 |
120 | func TestAddFilter(t *testing.T) {
121 | data := [][]byte{
122 | {1, 2, 3, 4},
123 | {5, 6, 7, 8},
124 | {9, 10},
125 | }
126 | f1 := New(100, 3)
127 | f2 := New(100, 3)
128 | f2.SetSeed(f1.Seed())
129 |
130 | f1.Add(data[0])
131 | f2.Add(data[1])
132 | f1.AddFilter(f2)
133 |
134 | if !f1.Has(data[0]) {
135 | t.Fatalf("Has(%v)=false, want true", data[0])
136 | }
137 | if !f1.Has(data[1]) {
138 | t.Fatalf("Has(%v)=false, want true", data[1])
139 | }
140 | if f1.Has(data[2]) {
141 | t.Fatalf("Has(%v)=true, want false", data[2])
142 | }
143 | }
144 |
145 | func TestNElements(t *testing.T) {
146 | bf := New(2000, 4)
147 | bf.SetSeed(0)
148 | n := 100
149 | errs := 0
150 | for i := 1; i <= n; i++ {
151 | bf.Add([]byte{byte(i)})
152 | if gnum.Diff(bf.NElements(), i) > 1 {
153 | errs++
154 | }
155 | }
156 | if errs > n/10 {
157 | t.Fatalf("Too many errors: %v, want %v", errs, n/10)
158 | }
159 | }
160 |
161 | func BenchmarkHas(b *testing.B) {
162 | for _, n := range []int{10, 30, 100} {
163 | for k := 1; k <= 3; k++ {
164 | b.Run(fmt.Sprintf("n=%v,k=%v", n, k), func(b *testing.B) {
165 | f := New(1000000, k)
166 | buf := make([]byte, n)
167 | f.Add(buf)
168 | b.ResetTimer()
169 | for i := 0; i < b.N; i++ {
170 | f.Has(buf)
171 | }
172 | })
173 | }
174 | }
175 | }
176 |
177 | func BenchmarkAdd(b *testing.B) {
178 | for _, n := range []int{10, 30, 100} {
179 | for k := 1; k <= 3; k++ {
180 | b.Run(fmt.Sprintf("n=%v,k=%v", n, k), func(b *testing.B) {
181 | f := New(1000000, k)
182 | buf := make([]byte, n)
183 | b.ResetTimer()
184 | for i := 0; i < b.N; i++ {
185 | f.Add(buf)
186 | }
187 | })
188 | }
189 | }
190 | }
191 |
--------------------------------------------------------------------------------
/graphs/graph.go:
--------------------------------------------------------------------------------
1 | // Package graphs implements a simple graph.
2 | package graphs
3 |
4 | import (
5 | "cmp"
6 | "iter"
7 | "slices"
8 |
9 | "github.com/fluhus/gostuff/sets"
10 | "github.com/fluhus/gostuff/snm"
11 | )
12 |
// Graph is a simple graph.
// It contains a set of vertices and a set of edges,
// which are pairs of vertices.
//
// Vertices are stored as integer IDs handed out by the enumerator, and each
// edge is a pair of such IDs. toEdge orders every pair with the smaller ID
// first, so (a,b) and (b,a) map to the same set entry.
type Graph[T comparable] struct {
	v snm.Enumerator[T] // Value to ID.
	e sets.Set[[2]int]  // Pairs of IDs.
}
20 |
21 | // New returns an empty graph.
22 | func New[T comparable]() *Graph[T] {
23 | return &Graph[T]{snm.Enumerator[T]{}, sets.Set[[2]int]{}}
24 | }
25 |
26 | // AddVertices adds the given values as vertices.
27 | // Values that already exist are ignored.
28 | func (g *Graph[T]) AddVertices(t ...T) {
29 | for _, v := range t {
30 | g.v.IndexOf(v)
31 | }
32 | }
33 |
// NumVertices returns the current number of vertices,
// which is the number of values held by the vertex enumerator.
func (g *Graph[T]) NumVertices() int {
	return len(g.v)
}
38 |
// NumEdges returns the current number of edges,
// which is the size of the edge set.
func (g *Graph[T]) NumEdges() int {
	return len(g.e)
}
43 |
44 | // Edges iterates over current set of edges.
45 | func (g *Graph[T]) Edges() iter.Seq2[T, T] {
46 | return func(yield func(T, T) bool) {
47 | flat := g.v.Elements()
48 | for e := range g.e {
49 | if !yield(flat[e[0]], flat[e[1]]) {
50 | return
51 | }
52 | }
53 | }
54 | }
55 |
56 | // Vertices iterates over current set of vertices,
57 | // by order of addition to the graph.
58 | func (g *Graph[T]) Vertices() iter.Seq[T] {
59 | return func(yield func(T) bool) {
60 | for _, x := range g.v.Elements() {
61 | if !yield(x) {
62 | return
63 | }
64 | }
65 | }
66 | }
67 |
68 | // HasEdge returns whether there is an edge between a and b.
69 | func (g *Graph[T]) HasEdge(a, b T) bool {
70 | return g.e.Has(g.toEdge(a, b))
71 | }
72 |
73 | // AddEdge adds a and b to the vertex set and adds an edge between them.
74 | // The edge is undirected, meaning that AddEdge(a,b) is equivalent
75 | // to AddEdge(b,a).
76 | // If the edge already exists, this is a no-op.
77 | func (g *Graph[T]) AddEdge(a, b T) {
78 | g.e.Add(g.toEdge(a, b))
79 | }
80 |
81 | // DeleteEdge removes the edge between a and b,
82 | // while keeping them in the vertex set.
83 | func (g *Graph[T]) DeleteEdge(a, b T) {
84 | delete(g.e, g.toEdge(a, b))
85 | }
86 |
87 | // ConnectedComponents returns a slice of connected components.
88 | // In each component, the elements are ordered by order of addition to the
89 | // graph.
90 | // The components are ordered by the order of addition of their
91 | // first elements.
92 | func (g *Graph[T]) ConnectedComponents() [][]T {
93 | edges := g.edgeSlices()
94 | m := snm.Slice(g.NumVertices(), func(i int) int { return -1 })
95 | queue := &snm.Queue[int]{}
96 |
97 | for i := range g.NumVertices() {
98 | if m[i] != -1 {
99 | continue
100 | }
101 | m[i] = i
102 | queue.Enqueue(i)
103 | for queue.Len() > 0 {
104 | e := queue.Dequeue()
105 | for _, j := range edges[e] {
106 | if m[j] == -1 {
107 | m[j] = i
108 | queue.Enqueue(j)
109 | }
110 | }
111 | edges[e] = nil
112 | }
113 | }
114 |
115 | comps := map[int][]int{}
116 | for k, v := range m {
117 | comps[v] = append(comps[v], k)
118 | }
119 | poncs := make([][]int, 0, len(comps))
120 | for _, v := range comps {
121 | poncs = append(poncs, snm.Sorted(v))
122 | }
123 | slices.SortFunc(poncs, func(a, b []int) int {
124 | return cmp.Compare(a[0], b[0])
125 | })
126 |
127 | // Convert indices to vertex values.
128 | i2v := g.v.Elements()
129 | return snm.Slice(len(poncs), func(i int) []T {
130 | return snm.Slice(len(poncs[i]), func(j int) T {
131 | return i2v[poncs[i][j]]
132 | })
133 | })
134 | }
135 |
136 | // Returns (without adding) an edge between a and b.
137 | func (g *Graph[T]) toEdge(a, b T) [2]int {
138 | ia, ib := g.v.IndexOf(a), g.v.IndexOf(b)
139 | if ia > ib {
140 | return [2]int{ib, ia}
141 | }
142 | return [2]int{ia, ib}
143 | }
144 |
145 | // Returns a slice representation of this graph's edges.
146 | func (g *Graph[T]) edgeSlices() [][]int {
147 | // Pre-allocate slices.
148 | counts := make([]int, g.NumVertices())
149 | for e := range g.e {
150 | counts[e[0]]++
151 | counts[e[1]]++
152 | }
153 | edges := snm.Slice(len(counts), func(i int) []int {
154 | return make([]int, 0, counts[i])
155 | })
156 |
157 | // Populate with values.
158 | for e := range g.e {
159 | edges[e[0]] = append(edges[e[0]], e[1])
160 | edges[e[1]] = append(edges[e[1]], e[0])
161 | }
162 | return edges
163 | }
164 |
--------------------------------------------------------------------------------
/gnum/vecs.go:
--------------------------------------------------------------------------------
1 | package gnum
2 |
3 | import (
4 | "fmt"
5 | "math"
6 |
7 | "golang.org/x/exp/constraints"
8 | )
9 |
10 | // L1 returns the L1 (Manhattan) distance between a and b.
11 | // Equivalent to Lp(1) but returns the same type.
12 | func L1[S ~[]N, N Number](a, b S) N {
13 | assertMatchingLengths(a, b)
14 | var sum N
15 | for i := range a {
16 | sum += Diff(a[i], b[i])
17 | }
18 | return sum
19 | }
20 |
21 | // L2 returns the L2 (Euclidean) distance between a and b.
22 | // Equivalent to Lp(2).
23 | func L2[S ~[]N, N Number](a, b S) float64 {
24 | assertMatchingLengths(a, b)
25 | var sum N
26 | for i := range a {
27 | d := (a[i] - b[i])
28 | sum += d * d
29 | }
30 | return math.Sqrt(float64(sum))
31 | }
32 |
33 | // Lp returns an Lp distance function. Lp is calculated as follows:
34 | //
35 | // Lp(v) = (sum_i(v[i]^p))^(1/p)
36 | func Lp[S ~[]N, N Number](p int) func(S, S) float64 {
37 | if p < 1 {
38 | panic(fmt.Sprintf("invalid p: %d", p))
39 | }
40 |
41 | if p == 1 {
42 | return func(a, b S) float64 {
43 | return float64(L1(a, b))
44 | }
45 | }
46 | if p == 2 {
47 | return L2[S, N]
48 | }
49 |
50 | return func(a, b S) float64 {
51 | assertMatchingLengths(a, b)
52 | fp := float64(p)
53 | var sum float64
54 | for i := range a {
55 | sum += math.Pow(float64(Diff(a[i], b[i])), fp)
56 | }
57 | return math.Pow(sum, 1/fp)
58 | }
59 | }
60 |
61 | // Add adds b to a and returns a. b is unchanged. If a is nil, creates a new
62 | // vector.
63 | func Add[S ~[]N, N Number](a S, b ...S) S {
64 | if a == nil {
65 | if len(b) == 0 {
66 | return nil
67 | }
68 | a = make(S, len(b[0]))
69 | }
70 | for i := range b {
71 | assertMatchingLengths(a, b[i])
72 | for j := range a {
73 | a[j] += b[i][j]
74 | }
75 | }
76 | return a
77 | }
78 |
79 | // Sub subtracts b from a and returns a. b is unchanged. If a is nil, creates a
80 | // new vector.
81 | func Sub[S ~[]N, N Number](a S, b ...S) S {
82 | if a == nil {
83 | if len(b) == 0 {
84 | return nil
85 | }
86 | a = make(S, len(b[0]))
87 | }
88 | for i := range b {
89 | assertMatchingLengths(a, b[i])
90 | for j := range a {
91 | a[j] -= b[i][j]
92 | }
93 | }
94 | return a
95 | }
96 |
97 | // Mul multiplies a by b and returns a. b is unchanged. If a is nil, creates a
98 | // new vector.
99 | func Mul[S ~[]N, N Number](a S, b ...S) S {
100 | if a == nil {
101 | if len(b) == 0 {
102 | return nil
103 | }
104 | a = Ones[S](len(b[0]))
105 | }
106 | for i := range b {
107 | assertMatchingLengths(a, b[i])
108 | for j := range a {
109 | a[j] -= b[i][j]
110 | }
111 | }
112 | return a
113 | }
114 |
115 | // Add1 adds m to a and returns a.
116 | func Add1[S ~[]N, N Number](a S, m N) S {
117 | for i := range a {
118 | a[i] += m
119 | }
120 | return a
121 | }
122 |
123 | // Sub1 subtracts m from a and returns a.
124 | func Sub1[S ~[]N, N Number](a S, m N) S {
125 | for i := range a {
126 | a[i] -= m
127 | }
128 | return a
129 | }
130 |
131 | // Mul1 multiplies the values of a by m and returns a.
132 | func Mul1[S ~[]N, N Number](a S, m N) S {
133 | for i := range a {
134 | a[i] *= m
135 | }
136 | return a
137 | }
138 |
139 | // Dot returns the dot product of the input vectors.
140 | func Dot[S ~[]N, N Number](a, b S) N {
141 | assertMatchingLengths(a, b)
142 | var sum N
143 | for i := range a {
144 | sum += a[i] * b[i]
145 | }
146 | return sum
147 | }
148 |
149 | // Norm returns the L2 norm of the vector.
150 | func Norm[S ~[]N, N constraints.Float](a S) float64 {
151 | var norm N
152 | for _, v := range a {
153 | norm += v * v
154 | }
155 | return math.Sqrt(float64(norm))
156 | }
157 |
158 | // Ones returns a slice of n ones. Panics if n is negative.
159 | func Ones[S ~[]N, N Number](n int) S {
160 | if n < 0 {
161 | panic(fmt.Sprintf("bad vector length: %d", n))
162 | }
163 | a := make(S, n)
164 | for i := range a {
165 | a[i] = 1
166 | }
167 | return a
168 | }
169 |
// Copy returns a newly allocated slice with the same elements as a.
// The result never shares memory with a and is non-nil even when a is nil.
func Copy[S ~[]N, N any](a S) S {
	return append(make(S, 0, len(a)), a...)
}
176 |
177 | // Cast casts the values of a and places them in a new slice.
178 | func Cast[S ~[]N, T ~[]M, N Number, M Number](a S) T {
179 | t := make(T, len(a))
180 | for i, s := range a {
181 | t[i] = M(s)
182 | }
183 | return t
184 | }
185 |
// assertMatchingLengths panics with a message listing both lengths if a and
// b are not of the same length. It does nothing otherwise.
func assertMatchingLengths[S ~[]N, N any](a, b S) {
	if la, lb := len(a), len(b); la != lb {
		panic(fmt.Sprintf("mismatching lengths: %d, %d", la, lb))
	}
}
192 |
--------------------------------------------------------------------------------
/aio/open.go:
--------------------------------------------------------------------------------
1 | // Package aio provides buffered file I/O.
2 | package aio
3 |
4 | import (
5 | "bufio"
6 | "compress/bzip2"
7 | "compress/gzip"
8 | "io"
9 | "os"
10 | "path/filepath"
11 |
12 | "github.com/klauspost/compress/zstd"
13 | )
14 |
// Switches for the built-in suffix handlers. Each one guards the
// registration of its handlers in init.
const (
	gzipSupport = true // If true, .gz files are automatically compressed/decompressed.
	zstdSupport = true // If true, .zst files are automatically compressed/decompressed.
	bzipSupport = true // If true, .bz2 files are automatically decompressed.
)
20 |
21 | // OpenRaw opens a file for reading, with a buffer.
22 | func OpenRaw(file string) (*Reader, error) {
23 | f, err := os.Open(file)
24 | if err != nil {
25 | return nil, err
26 | }
27 | return &Reader{*bufio.NewReader(f), f}, nil
28 | }
29 |
30 | // CreateRaw opens a file for writing, with a buffer.
31 | // Erases any previously existing content.
32 | func CreateRaw(file string) (*Writer, error) {
33 | f, err := os.Create(file)
34 | if err != nil {
35 | return nil, err
36 | }
37 | return &Writer{*bufio.NewWriter(f), f}, nil
38 | }
39 |
40 | // AppendRaw opens a file for writing, with a buffer.
41 | // Appends to previously existing content if any.
42 | func AppendRaw(file string) (*Writer, error) {
43 | f, err := os.OpenFile(file, os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0o644)
44 | if err != nil {
45 | return nil, err
46 | }
47 | return &Writer{*bufio.NewWriter(f), f}, nil
48 | }
49 |
// Registered suffix handlers, keyed by file extension as returned by
// filepath.Ext (including the dot).
// rsuffixes is consulted by Open; wsuffixes is consulted by Create and Append.
var (
	rsuffixes = map[string]func(io.Reader) (io.Reader, error){}
	wsuffixes = map[string]func(io.WriteCloser) (io.WriteCloser, error){}
)
54 |
55 | // Open opens a file for reading, with a buffer.
56 | // Decompresses the data according to the file's suffix.
57 | func Open(file string) (*Reader, error) {
58 | f, err := OpenRaw(file)
59 | if err != nil {
60 | return nil, err
61 | }
62 | fn := rsuffixes[filepath.Ext(file)]
63 | if fn == nil {
64 | return f, nil
65 | }
66 | ff, err := fn(f)
67 | if err != nil {
68 | return nil, err
69 | }
70 | return &Reader{*bufio.NewReader(ff), f}, nil
71 | }
72 |
73 | // Create opens a file for writing, with a buffer.
74 | // Erases any previously existing content.
75 | // Compresses the data according to the file's suffix.
76 | func Create(file string) (*Writer, error) {
77 | f, err := CreateRaw(file)
78 | if err != nil {
79 | return nil, err
80 | }
81 | fn := wsuffixes[filepath.Ext(file)]
82 | if fn == nil {
83 | return f, nil
84 | }
85 | ff, err := fn(f)
86 | if err != nil {
87 | return nil, err
88 | }
89 | wrapper := &writerWrapper{ff, f}
90 | return &Writer{*bufio.NewWriter(ff), wrapper}, nil
91 | }
92 |
93 | // Append opens a file for writing, with a buffer.
94 | // Appends to previously existing content if any.
95 | // Compresses the data according to the file's suffix.
96 | func Append(file string) (*Writer, error) {
97 | f, err := AppendRaw(file)
98 | if err != nil {
99 | return nil, err
100 | }
101 | fn := wsuffixes[filepath.Ext(file)]
102 | if fn == nil {
103 | return f, nil
104 | }
105 | ff, err := fn(f)
106 | if err != nil {
107 | return nil, err
108 | }
109 | wrapper := &writerWrapper{ff, f}
110 | return &Writer{*bufio.NewWriter(ff), wrapper}, nil
111 | }
112 |
// AddReadSuffix adds a supported suffix for automatic decompression.
// suffix should include the dot. f should take a raw reader and return a reader
// that decompresses the data.
// Registering a suffix again replaces its previous handler.
func AddReadSuffix(suffix string, f func(io.Reader) (io.Reader, error)) {
	rsuffixes[suffix] = f
}
119 |
// AddWriteSuffix adds a supported suffix for automatic compression.
// suffix should include the dot. f should take a raw writer and return a writer
// that compresses the data.
// Registering a suffix again replaces its previous handler.
func AddWriteSuffix(suffix string, f func(io.WriteCloser) (
	io.WriteCloser, error)) {
	wsuffixes[suffix] = f
}
127 |
128 | func init() {
129 | if gzipSupport {
130 | AddReadSuffix(".gz", func(r io.Reader) (io.Reader, error) {
131 | return gzip.NewReader(r)
132 | })
133 | AddWriteSuffix(".gz", func(w io.WriteCloser) (io.WriteCloser, error) {
134 | return gzip.NewWriterLevel(w, 1)
135 | })
136 | }
137 | if bzipSupport {
138 | AddReadSuffix(".bz2", func(r io.Reader) (io.Reader, error) {
139 | return bzip2.NewReader(r), nil
140 | })
141 | }
142 | if zstdSupport {
143 | AddReadSuffix(".zst", func(r io.Reader) (io.Reader, error) {
144 | return zstd.NewReader(r, zstd.WithDecoderConcurrency(1))
145 | })
146 | AddWriteSuffix(".zst", func(w io.WriteCloser) (io.WriteCloser, error) {
147 | return zstd.NewWriter(w, zstd.WithEncoderConcurrency(1))
148 | })
149 | }
150 | }
151 |
--------------------------------------------------------------------------------
/rhash/common_test.go:
--------------------------------------------------------------------------------
1 | // A generic test suite for rolling hashes.
2 |
3 | package rhash
4 |
5 | import (
6 | "crypto/rand"
7 | "hash"
8 | "testing"
9 |
10 | "github.com/fluhus/gostuff/sets"
11 | )
12 |
13 | // Runs the test suite for a hash64.
14 | func test64(t *testing.T, f func(n int) hash.Hash64) {
15 | t.Run("basic", func(t *testing.T) { test64basic(t, f) })
16 | t.Run("cyclic", func(t *testing.T) { test64cyclic(t, f) })
17 | t.Run("big-n", func(t *testing.T) { test64bigN(t, f) })
18 | }
19 |
20 | func test64basic(t *testing.T, f func(n int) hash.Hash64) {
21 | data := []byte("amitamit")
22 | tests := []struct {
23 | n int
24 | wantSize []int
25 | wantEq []int
26 | }{
27 | {2, []int{1, 2, 3, 4, 5, 5, 5, 5}, []int{-1, -1, -1, -1, -1, 1, 2, 3}},
28 | {3, []int{1, 2, 3, 4, 5, 6, 6, 6}, []int{-1, -1, -1, -1, -1, -1, 2, 3}},
29 | }
30 | for _, test := range tests {
31 | slice := []uint64{}
32 | set := sets.Set[uint64]{}
33 | h := f(test.n)
34 | for i := range data {
35 | h.Write(data[i : i+1])
36 | slice = append(slice, h.Sum64())
37 | set.Add(h.Sum64())
38 | if len(set) != test.wantSize[i] {
39 | t.Fatalf("n=%d #%d: set size=%v, want %v",
40 | test.n, i, len(set), test.wantSize[i])
41 | }
42 | if test.wantEq[i] != -1 && h.Sum64() != slice[test.wantEq[i]] {
43 | t.Fatalf("n=%d #%d: Sum64()=%d, want %d",
44 | test.n, i, h.Sum64(), slice[test.wantEq[i]])
45 | }
46 | }
47 | }
48 | }
49 |
// test64cyclic writes several inputs, each exactly as long as the window, to
// one long-lived hash. After each input the hash must equal that of a fresh
// hash that saw only that input, i.e. earlier data must have no effect.
func test64cyclic(t *testing.T, f func(n int) hash.Hash64) {
	inputs := []string{
		"asdjadasdk",
		"uioewrmnoc",
		"wiewuwikxa",
		"mfhddl/lcc",
		"28n9789dkd",
	}
	rolling := f(10)
	for _, input := range inputs {
		rolling.Write([]byte(input))
		fresh := f(10)
		fresh.Write([]byte(input))
		if got, want := rolling.Sum64(), fresh.Sum64(); got != want {
			t.Fatalf("Sum64(%q)=%v, want %v", input, got, want)
		}
	}
}
69 |
70 | func test64bigN(t *testing.T, f func(n int) hash.Hash64) {
71 | const n = 100
72 |
73 | // Create random input.
74 | buf := make([]byte, n)
75 | _, err := rand.Read(buf)
76 | if err != nil {
77 | t.Fatalf("rand.Read() failed: %v", err)
78 | }
79 |
80 | // Repeat 3 times.
81 | buf = append(buf, buf...)
82 | buf = append(buf, buf...)
83 |
84 | h := f(n)
85 | hashes := sets.Set[uint64]{}
86 | for i := range buf {
87 | h.Write(buf[i : i+1])
88 | hashes.Add(h.Sum64())
89 | want := min(i+1, n*2-1)
90 | if len(hashes) != want {
91 | t.Fatalf("got %d unique hashes, want %d", len(hashes), want)
92 | }
93 | }
94 | }
95 |
96 | // Runs the test suite for a hash32.
97 | func test32(t *testing.T, f func(n int) hash.Hash32) {
98 | t.Run("basic", func(t *testing.T) { test32basic(t, f) })
99 | t.Run("cyclic", func(t *testing.T) { test32cyclic(t, f) })
100 | t.Run("big-n", func(t *testing.T) { test32bigN(t, f) })
101 | }
102 |
103 | func test32basic(t *testing.T, f func(n int) hash.Hash32) {
104 | data := []byte("amitamit")
105 | tests := []struct {
106 | n int
107 | wantSize []int
108 | wantEq []int
109 | }{
110 | {2, []int{1, 2, 3, 4, 5, 5, 5, 5}, []int{-1, -1, -1, -1, -1, 1, 2, 3}},
111 | {3, []int{1, 2, 3, 4, 5, 6, 6, 6}, []int{-1, -1, -1, -1, -1, -1, 2, 3}},
112 | }
113 | for _, test := range tests {
114 | slice := []uint32{}
115 | set := sets.Set[uint32]{}
116 | h := f(test.n)
117 | for i := range data {
118 | h.Write(data[i : i+1])
119 | slice = append(slice, h.Sum32())
120 | set.Add(h.Sum32())
121 | if len(set) != test.wantSize[i] {
122 | t.Fatalf("n=%d #%d: set size=%v, want %v",
123 | test.n, i, len(set), test.wantSize[i])
124 | }
125 | if test.wantEq[i] != -1 && h.Sum32() != slice[test.wantEq[i]] {
126 | t.Fatalf("n=%d #%d: Sum32()=%d, want %d",
127 | test.n, i, h.Sum32(), slice[test.wantEq[i]])
128 | }
129 | }
130 | }
131 | }
132 |
// test32cyclic writes several inputs, each exactly as long as the window, to
// one long-lived hash. After each input the hash must equal that of a fresh
// hash that saw only that input, i.e. earlier data must have no effect.
func test32cyclic(t *testing.T, f func(n int) hash.Hash32) {
	inputs := []string{
		"asdjadasdk",
		"uioewrmnoc",
		"wiewuwikxa",
		"mfhddl/lcc",
		"28n9789dkd",
	}
	rolling := f(10)
	for _, input := range inputs {
		rolling.Write([]byte(input))
		fresh := f(10)
		fresh.Write([]byte(input))
		if got, want := rolling.Sum32(), fresh.Sum32(); got != want {
			t.Fatalf("Sum32(%q)=%v, want %v", input, got, want)
		}
	}
}
152 |
153 | func test32bigN(t *testing.T, f func(n int) hash.Hash32) {
154 | const n = 100
155 |
156 | // Create random input.
157 | buf := make([]byte, n)
158 | _, err := rand.Read(buf)
159 | if err != nil {
160 | t.Fatalf("rand.Read() failed: %v", err)
161 | }
162 |
163 | // Repeat 3 times.
164 | buf = append(buf, buf...)
165 | buf = append(buf, buf...)
166 |
167 | h := f(n)
168 | hashes := sets.Set[uint32]{}
169 | for i := range buf {
170 | h.Write(buf[i : i+1])
171 | hashes.Add(h.Sum32())
172 | want := min(i+1, n*2-1)
173 | if len(hashes) != want {
174 | t.Fatalf("got %d unique hashes, want %d", len(hashes), want)
175 | }
176 | }
177 | }
178 |
--------------------------------------------------------------------------------
/snm/snm_test.go:
--------------------------------------------------------------------------------
1 | package snm
2 |
3 | import (
4 | "encoding/json"
5 | "fmt"
6 | "slices"
7 | "testing"
8 |
9 | "golang.org/x/exp/maps"
10 | )
11 |
12 | func TestSlice(t *testing.T) {
13 | want := []int{1, 4, 9, 16}
14 | got := Slice(4, func(i int) int { return (i + 1) * (i + 1) })
15 | if !slices.Equal(got, want) {
16 | t.Fatalf("Slice((i+1)*(i+1))=%v, want %v", got, want)
17 | }
18 | }
19 |
20 | func TestSliceToSlice(t *testing.T) {
21 | input := []int{1, 4, 9, 16}
22 | want := []float64{1.5, 4.5, 9.5, 16.5}
23 | got := SliceToSlice(input, func(i int) float64 {
24 | return float64(i) + 0.5
25 | })
26 | if !slices.Equal(got, want) {
27 | t.Fatalf("SliceToSlice(%v)=%v, want %v", input, got, want)
28 | }
29 | }
30 |
31 | func TestMapToMap(t *testing.T) {
32 | input := map[string]string{"a": "bbb", "cccc": "ddddddd"}
33 | want := map[int]int{1: 3, 4: 7}
34 | got := MapToMap(input, func(k, v string) (int, int) {
35 | return len(k), len(v)
36 | })
37 | if !maps.Equal(got, want) {
38 | t.Fatalf("MapToMap(%v)=%v, want %v", input, got, want)
39 | }
40 | }
41 |
42 | func TestMapToMap_equalKeys(t *testing.T) {
43 | input := map[string]string{"a": "bbb", "cccc": "ddddddd", "e": "ff"}
44 | want1 := map[int]int{1: 3, 4: 7}
45 | want2 := map[int]int{1: 2, 4: 7}
46 | got := MapToMap(input, func(k, v string) (int, int) {
47 | return len(k), len(v)
48 | })
49 | if !maps.Equal(got, want1) && !maps.Equal(got, want2) {
50 | t.Fatalf("MapToMap(%v)=%v, want %v or %v", input, got, want1, want2)
51 | }
52 | }
53 |
54 | func TestDefaultMap(t *testing.T) {
55 | m := NewDefaultMap[int, string](func(i int) string {
56 | return fmt.Sprint(i + 1)
57 | })
58 | if got, want := m.Get(2), "3"; got != want {
59 | t.Fatalf("Get(%d)=%s, want %s", 2, got, want)
60 | }
61 | if got, want := m.Get(6), "7"; got != want {
62 | t.Fatalf("Get(%d)=%s, want %s", 6, got, want)
63 | }
64 | m.Set(2, "a")
65 | if got, want := m.Get(2), "a"; got != want {
66 | t.Fatalf("Get(%d)=%s, want %s", 2, got, want)
67 | }
68 | if got, want := m.Get(6), "7"; got != want {
69 | t.Fatalf("Get(%d)=%s, want %s", 6, got, want)
70 | }
71 | if got, want := len(m.M), 2; got != want {
72 | t.Fatalf("Len=%d, want %d", got, want)
73 | }
74 | }
75 |
76 | func TestCompareReverse(t *testing.T) {
77 | input := []int{3, 4, 2, 1, 5}
78 | want := []int{5, 4, 3, 2, 1}
79 |
80 | cp := slices.Clone(input)
81 | slices.SortFunc(cp, CompareReverse)
82 | if !slices.Equal(cp, want) {
83 | t.Errorf("SortFunc(%v, Compare)=%v, want %v",
84 | input, cp, want)
85 | }
86 | }
87 |
88 | func ExampleSortedKeys() {
89 | ages := map[string]int{
90 | "Alice": 30,
91 | "Bob": 20,
92 | "Charlie": 25,
93 | }
94 | for _, name := range SortedKeys(ages) {
95 | fmt.Printf("%s: %d\n", name, ages[name])
96 | }
97 | // Output:
98 | // Bob: 20
99 | // Charlie: 25
100 | // Alice: 30
101 | }
102 |
103 | func ExampleSortedKeysFunc_reverse() {
104 | ages := map[string]int{
105 | "Alice": 30,
106 | "Bob": 20,
107 | "Charlie": 25,
108 | }
109 | // Sort by reverse natural order.
110 | for _, name := range SortedKeysFunc(ages, CompareReverse) {
111 | fmt.Printf("%s: %d\n", name, ages[name])
112 | }
113 | // Output:
114 | // Alice: 30
115 | // Charlie: 25
116 | // Bob: 20
117 | }
118 |
119 | func TestEnumerator(t *testing.T) {
120 | tests := []struct {
121 | i, want int
122 | }{
123 | {6, 0}, {3, 1}, {6, 0}, {2, 2}, {3, 1}, {10, 3}, {10, 3}, {2, 2},
124 | {6, 0}, {3, 1},
125 | }
126 | e := Enumerator[int]{}
127 | for _, test := range tests {
128 | if got := e.IndexOf(test.i); got != test.want {
129 | t.Fatalf("%v.IndexOf(%v)=%v, want %v", e, test.i, got, test.want)
130 | }
131 | }
132 |
133 | wantElem := []int{6, 3, 2, 10}
134 | if got := e.Elements(); !slices.Equal(got, wantElem) {
135 | t.Fatalf("%v.Elements()=%v, want %v", e, got, wantElem)
136 | }
137 | }
138 |
139 | func ExampleCapMap() {
140 | data := [][]string{
141 | {"a", "b", "c", "a", "b", "b"},
142 | // ...
143 | }
144 | counter := NewCapMap[string, int]()
145 | for _, x := range data {
146 | m := counter.Map()
147 | countValues(x, m)
148 |
149 | // Do something with m.
150 | j, _ := json.Marshal(m)
151 | fmt.Println(string(j))
152 | counter.Clear()
153 | }
154 | //Output:
155 | //{"a":2,"b":3,"c":1}
156 | }
157 |
// countValues adds the number of occurrences of each string in vals to the
// corresponding entry of out. Existing entries in out are incremented, not
// reset.
func countValues(vals []string, out map[string]int) {
	for i := range vals {
		out[vals[i]] += 1
	}
}
163 |
164 | func TestShuffle(t *testing.T) {
165 | nums := Slice(10, func(i int) int { return i })
166 | found := make([]bool, len(nums))
167 | counts := Slice(len(nums), func(i int) []int { return make([]int, len(nums)) })
168 | for range 1000 {
169 | Shuffle(nums)
170 | clear(found)
171 | for i, x := range nums {
172 | found[x] = true
173 | counts[i][x]++
174 | }
175 | for i, f := range found {
176 | if !f {
177 | t.Fatalf("did not find %v: %v", i, nums)
178 | }
179 | }
180 | }
181 | for i, c := range counts {
182 | for j, x := range c {
183 | if x < 70 {
184 | t.Errorf("count of %v at position %v: %v, want >%v",
185 | j, i, x, 70)
186 | }
187 | }
188 | }
189 | }
190 |
--------------------------------------------------------------------------------
/prefixtree/tree.go:
--------------------------------------------------------------------------------
1 | // Package prefixtree provides a basic prefix tree implementation.
2 | //
3 | // Add, Has, HasPrefix, Delete and DeletePrefix are linear in query length,
4 | // regardless of the size of the tree.
5 | // All operations are non-recursive.
6 | package prefixtree
7 |
8 | import (
9 | "iter"
10 | "slices"
11 |
12 | "golang.org/x/exp/maps"
13 | )
14 |
// A Tree is a prefix tree over byte sequences.
//
// Every node stands for the sequence of bytes on the path leading to it from
// the root; the root itself stands for the empty sequence.
//
// A zero value tree is invalid; use New to create a new instance.
type Tree struct {
	isElem bool           // Is this node an element in the tree.
	m      map[byte]*Tree // Child nodes, keyed by the next byte.
}

// New returns an empty tree.
func New() *Tree {
	return &Tree{m: map[byte]*Tree{}}
}

// Add inserts x to the tree.
// If x was already added, the tree is unchanged.
func (t *Tree) Add(x []byte) {
	cur := t
	for _, b := range x {
		next := cur.m[b]
		if next == nil {
			next = New()
			cur.m[b] = next
		}
		cur = next
	}
	cur.isElem = true
}

// Has returns whether x was added to the tree.
func (t *Tree) Has(x []byte) bool {
	cur := t
	for _, b := range x {
		next := cur.m[b]
		if next == nil {
			return false
		}
		cur = next
	}
	return cur.isElem
}

// HasPrefix returns whether x is a prefix of an element in the tree.
func (t *Tree) HasPrefix(x []byte) bool {
	cur := t
	for _, b := range x {
		next := cur.m[b]
		if next == nil {
			return false
		}
		cur = next
	}
	return true
}

// FindPrefixes returns all the elements in the tree
// that are prefixes of x, ordered by length.
// The returned slices are sub-slices of x and share its memory.
func (t *Tree) FindPrefixes(x []byte) [][]byte {
	cur := t
	var result [][]byte
	if cur.isElem {
		result = append(result, x[:0])
	}
	for i, b := range x {
		cur = cur.m[b]
		if cur == nil {
			break
		}
		if cur.isElem {
			result = append(result, x[:i+1])
		}
	}
	return result
}

// Delete removes x from the tree, if possible.
// Returns the result of Has(x) before deletion.
// Other elements, including those that are prefixes of x, are unchanged.
func (t *Tree) Delete(x []byte) bool {
	path, node := t.descend(x)
	if node == nil || !node.isElem {
		return false
	}
	node.isElem = false
	if len(node.m) == 0 {
		// The node is a leaf that no longer represents anything.
		pruneBranch(path, x)
	}
	return true
}

// DeletePrefix removes prefix x from the tree.
// All sequences that have x as their prefix, including x itself,
// are removed. Other sequences are unchanged.
func (t *Tree) DeletePrefix(x []byte) {
	// Length 0 means all sequences, including the empty one.
	if len(x) == 0 {
		clear(t.m)
		t.isElem = false
		return
	}

	path, node := t.descend(x)
	if node == nil {
		return
	}
	pruneBranch(path, x)
}

// descend walks down the tree along x.
// It returns the node reached at the end and the path leading to it, where
// path[i] is the node whose child x[i] was followed.
// Both results are nil if the tree has no node for x.
func (t *Tree) descend(x []byte) (path []*Tree, node *Tree) {
	path = make([]*Tree, len(x))
	node = t
	for i, b := range x {
		path[i] = node
		node = node.m[b]
		if node == nil {
			return nil, nil
		}
	}
	return path, node
}

// pruneBranch unlinks the node at the end of path from its parent, and keeps
// unlinking upwards for as long as the parent is left useless.
// A parent is kept, and pruning stops, if it is itself an element or still
// has other children.
func pruneBranch(path []*Tree, x []byte) {
	for i := len(path) - 1; i >= 0; i-- {
		delete(path[i].m, x[i])
		if path[i].isElem || len(path[i].m) > 0 {
			return
		}
	}
}
151 |
152 | // Iter iterates over the elements of t,
153 | // in no particular order.
154 | // The tree should not be modified during iteration.
155 | func (t *Tree) Iter() iter.Seq[[]byte] {
156 | return func(yield func([]byte) bool) {
157 | stack := []*iterStep{{t, maps.Keys(t.m)}}
158 | var cur []byte
159 | if t.isElem && !yield(cur) {
160 | return
161 | }
162 | for {
163 | step := stack[len(stack)-1]
164 | if len(step.k) == 0 { // Finished with this branch.
165 | stack = stack[:len(stack)-1]
166 | if len(stack) == 0 { // Done.
167 | break
168 | }
169 | cur = cur[:len(cur)-1]
170 | continue
171 | }
172 | // Handle next child.
173 | key := step.k[0]
174 | child := step.t.m[key]
175 | stack = append(stack, &iterStep{child, maps.Keys(child.m)})
176 | step.k = step.k[1:]
177 | cur = append(cur, key)
178 | if child.isElem && !yield(slices.Clone(cur)) {
179 | return
180 | }
181 | }
182 | }
183 | }
184 |
185 | // IterPrefix iterates the elements in the tree that have x as their prefix.
186 | // Calling IterPrefix with a zero-length slice is equivalent to Iter().
187 | func (t *Tree) IterPrefix(x []byte) iter.Seq[[]byte] {
188 | return func(yield func([]byte) bool) {
189 | cur := t
190 | for i := range x {
191 | cur = cur.m[x[i]]
192 | if cur == nil {
193 | return
194 | }
195 | }
196 | for y := range cur.Iter() {
197 | if !yield(slices.Concat(x, y)) {
198 | return
199 | }
200 | }
201 | }
202 | }
203 |
// A step in the iteration stack.
// Holds one node of the traversal together with the child keys
// that were not visited yet.
type iterStep struct {
	t *Tree  // Tree to process
	k []byte // Tree's remaining children
}
209 |
--------------------------------------------------------------------------------
/minhash/minhash_test.go:
--------------------------------------------------------------------------------
1 | package minhash
2 |
3 | import (
4 | "fmt"
5 | "hash/crc64"
6 | "math"
7 | "math/rand"
8 | "reflect"
9 | "slices"
10 | "sort"
11 | "testing"
12 | )
13 |
14 | func TestCollection(t *testing.T) {
15 | tests := []struct {
16 | n int
17 | input []uint64
18 | want []uint64
19 | }{
20 | {
21 | 3,
22 | []uint64{1, 2, 2, 2, 2, 1, 1, 3, 3, 3, 1, 2, 3, 1, 3, 3, 2},
23 | []uint64{1, 2, 3},
24 | },
25 | {
26 | 3,
27 | []uint64{1, 2, 3, 4, 5, 6, 7, 8, 9},
28 | []uint64{1, 2, 3},
29 | },
30 | {
31 | 3,
32 | []uint64{9, 8, 7, 6, 5, 4, 3, 2, 1},
33 | []uint64{1, 2, 3},
34 | },
35 | {
36 | 5,
37 | []uint64{40, 19, 55, 10, 32, 1, 100, 5, 99, 16, 16},
38 | []uint64{1, 5, 10, 16, 19},
39 | },
40 | }
41 | for _, test := range tests {
42 | mh := New[uint64](test.n)
43 | for _, k := range test.input {
44 | mh.Push(k)
45 | }
46 | got := mh.View()
47 | sort.Slice(got, func(i, j int) bool {
48 | return got[i] < got[j]
49 | })
50 | if !reflect.DeepEqual(got, test.want) {
51 | t.Errorf("New(%d).Push(%v)=%v, want %v",
52 | test.n, test.input, got, test.want)
53 | }
54 | }
55 | }
56 |
57 | func TestJSON(t *testing.T) {
58 | input := New[int](5)
59 | input.Push(1)
60 | input.Push(4)
61 | input.Push(9)
62 | input.Push(16)
63 | input.Push(25)
64 | input.Push(36)
65 | jsn, err := input.MarshalJSON()
66 | if err != nil {
67 | t.Fatalf("MinHash(1,4,9,16,25,36).MarshalJSON() failed: %v", err)
68 | }
69 | got := New[int](2)
70 | err = got.UnmarshalJSON(jsn)
71 | if err != nil {
72 | t.Fatalf("UnmarshalJSON(%q) failed: %v", jsn, err)
73 | }
74 | if !slices.Equal(got.View(), input.View()) {
75 | t.Fatalf("UnmarshalJSON(%q)=%v, want %v", jsn, got, input)
76 | }
77 | }
78 |
79 | func TestJaccard(t *testing.T) {
80 | tests := []struct {
81 | a, b []uint64
82 | k int
83 | want float64
84 | }{
85 | {[]uint64{1, 2, 3}, []uint64{1, 2, 3}, 3, 1},
86 | {[]uint64{1, 2, 3}, []uint64{2, 3, 4}, 3, 2.0 / 3.0},
87 | {[]uint64{2, 3, 4}, []uint64{1, 2, 3}, 3, 2.0 / 3.0},
88 | {[]uint64{1, 2, 3, 4, 5}, []uint64{1, 3, 5}, 5, 0.6},
89 | }
90 | for _, test := range tests {
91 | a, b := New[uint64](test.k), New[uint64](test.k)
92 | for _, i := range test.a {
93 | a.Push(i)
94 | }
95 | for _, i := range test.b {
96 | b.Push(i)
97 | }
98 | a.Sort()
99 | b.Sort()
100 | if got := a.Jaccard(b); math.Abs(got-test.want) > 0.00001 {
101 | t.Errorf("Jaccard(%v,%v)=%f, want %f",
102 | test.a, test.b, got, test.want)
103 | }
104 | }
105 | }
106 |
107 | func TestCollection_largeInput(t *testing.T) {
108 | const k = 10000
109 | tests := []struct {
110 | from1, to1, from2, to2 int
111 | }{
112 | {1, 75000, 25000, 100000},
113 | {1, 60000, 40000, 60000},
114 | {1, 60000, 20000, 60000},
115 | {1, 40000, 40001, 60000},
116 | }
117 | for _, test := range tests {
118 | a, b := New[uint64](k), New[uint64](k)
119 | h := crc64.New(crc64.MakeTable(crc64.ECMA))
120 | for i := test.from1; i <= test.to1; i++ {
121 | h.Reset()
122 | fmt.Fprint(h, i)
123 | a.Push(h.Sum64())
124 | }
125 | for i := test.from2; i <= test.to2; i++ {
126 | h.Reset()
127 | fmt.Fprint(h, i)
128 | b.Push(h.Sum64())
129 | }
130 | a.Sort()
131 | b.Sort()
132 | want := float64(test.to1-test.from2+1) / float64(
133 | test.to2-test.from1+1)
134 | if got := a.Jaccard(b); math.Abs(got-want) > want/100 {
135 | t.Errorf("Jaccard(...)=%f, want %f", got, want)
136 | }
137 | }
138 | }
139 |
140 | func FuzzCollection(f *testing.F) {
141 | f.Add(1, 2, 3, 4, 5, 6)
142 | f.Fuzz(func(t *testing.T, a int, b int, c int, d int, e int, f int) {
143 | col := New[int](2)
144 | col.Push(a)
145 | col.Push(b)
146 | col.Push(c)
147 | col.Push(d)
148 | col.Push(e)
149 | col.Push(f)
150 | v := col.View()
151 | if len(v) != 2 {
152 | t.Errorf("len()=%d, want %d", len(v), 2)
153 | }
154 | if v[0] < v[1] {
155 | t.Errorf("v[0]=", v[0], v[1])
156 | }
157 | })
158 | }
159 |
160 | func TestFrozen(t *testing.T) {
161 | mh := New[int](3)
162 | mh.Push(27872)
163 | mh.Push(16978)
164 | mh.Push(28696)
165 | mh.Sort()
166 |
167 | fr := mh.Frozen()
168 | if !slices.Equal(mh.View(), fr.View()) {
169 | t.Fatalf("View()=%v, want %v", fr.View(), mh.View())
170 | }
171 |
172 | mh2 := New[int](3)
173 | mh.Push(27872)
174 | mh.Push(16978)
175 | mh.Push(28697)
176 | mh2.Sort()
177 |
178 | want := mh.Jaccard(mh2)
179 | got := fr.Jaccard(mh2.Frozen())
180 | if got != want {
181 | t.Fatalf("Jaccard=%v, want %v", got, want)
182 | }
183 | }
184 |
185 | func TestFrozen_modifySort(t *testing.T) {
186 | mh := New[int](1)
187 | mh.Push(27872)
188 | mh = mh.Frozen()
189 | defer func() {
190 | recover()
191 | }()
192 | mh.Sort()
193 | t.Fatalf(".Frozen().Sort() succeeded, want panic")
194 | }
195 |
196 | func TestFrozen_modifyPush(t *testing.T) {
197 | mh := New[int](1)
198 | mh.Push(27872)
199 | mh = mh.Frozen()
200 | defer func() {
201 | recover()
202 | }()
203 | mh.Push(123)
204 | t.Fatalf(".Frozen().Sort() succeeded, want panic")
205 | }
206 |
207 | func BenchmarkPush(b *testing.B) {
208 | nums := rand.Perm(b.N)
209 | mh := New[int](b.N)
210 | b.ResetTimer()
211 | for i := 0; i < b.N; i++ {
212 | mh.Push(nums[i])
213 | }
214 | }
215 |
--------------------------------------------------------------------------------
/minhash/minhash.go:
--------------------------------------------------------------------------------
1 | // Package minhash provides a min-hash collection for approximating Jaccard
2 | // similarity.
3 | package minhash
4 |
5 | import (
6 | "encoding/json"
7 | "fmt"
8 | "slices"
9 |
10 | "github.com/fluhus/gostuff/heaps"
11 | "github.com/fluhus/gostuff/sets"
12 | "github.com/fluhus/gostuff/snm"
13 | "golang.org/x/exp/constraints"
14 | )
15 |
// A MinHash is a min-hash collection. Retains the k lowest unique values out of all
// the values that were added to it.
//
// Instances returned by Frozen have a nil set and reject Push and Sort.
type MinHash[T constraints.Integer] struct {
	h *heaps.Heap[T] // Min-hash heap (max-heap, so the head is the largest retained value)
	s sets.Set[T]    // Keeps elements unique; nil marks a frozen instance
	k int            // Max size of the collection
	n int            // Number of calls to Push
}
24 |
25 | // New returns an empty collection that stores k values.
26 | func New[T constraints.Integer](k int) *MinHash[T] {
27 | if k < 1 {
28 | panic(fmt.Sprintf("invalid n: %d, should be positive", k))
29 | }
30 | return &MinHash[T]{
31 | heaps.Max[T](),
32 | make(sets.Set[T], k),
33 | k, 0,
34 | }
35 | }
36 |
37 | // Push tries to add a hash to the collection. x is added only if it does not
38 | // already exist, and there are less than k elements lesser than x.
39 | // Returns true if x was added and false if not.
40 | func (mh *MinHash[T]) Push(x T) bool {
41 | if mh.frozen() {
42 | panic("called Push on a frozen MinHash")
43 | }
44 | mh.n++
45 | if mh.h.Len() == mh.k && x >= mh.h.Head() {
46 | // x is too large.
47 | return false
48 | }
49 | if mh.s.Has(x) {
50 | return false
51 | }
52 | if mh.h.Len() == mh.k {
53 | mh.s.Remove(mh.h.Pop())
54 | }
55 | mh.h.Push(x)
56 | mh.s.Add(x)
57 | return true
58 | }
59 |
// K returns the maximal number of elements in mh,
// as given to New.
func (mh *MinHash[T]) K() int {
	return mh.k
}
64 |
// N returns the number of calls that were made to Push.
// Represents the size of the original set.
// Every call is counted, including calls that did not add a value.
func (mh *MinHash[T]) N() int {
	return mh.n
}
70 |
71 | // View returns the underlying slice of values.
72 | func (mh *MinHash[T]) View() []T {
73 | if mh.frozen() {
74 | return slices.Clone(mh.h.View())
75 | }
76 | return mh.h.View()
77 | }
78 |
79 | // MarshalJSON implements the json.Marshaler interface.
80 | func (mh *MinHash[T]) MarshalJSON() ([]byte, error) {
81 | return json.Marshal(struct {
82 | K int `json:"k"`
83 | N int `json:"n"`
84 | H []T `json:"h"`
85 | }{mh.k, mh.n, mh.View()})
86 | }
87 |
88 | // UnmarshalJSON implements the json.Unmarshaler interface.
89 | func (mh *MinHash[T]) UnmarshalJSON(b []byte) error {
90 | var raw struct {
91 | K int
92 | N int
93 | H []T
94 | }
95 | if err := json.Unmarshal(b, &raw); err != nil {
96 | return err
97 | }
98 | ss := New[T](raw.K)
99 | ss.h.PushSlice(raw.H)
100 | ss.s.Add(raw.H...)
101 | ss.n = raw.N
102 | *mh = *ss
103 | return nil
104 | }
105 |
106 | // Returns the intersection and union sizes of mh and other,
107 | // in min-hash terms.
108 | func (mh *MinHash[T]) intersect(other *MinHash[T]) (int, int) {
109 | a, b := mh.View(), other.View()
110 | if !mh.frozen() && !slices.IsSortedFunc(a, snm.CompareReverse) {
111 | panic("receiver is not sorted")
112 | }
113 | if !other.frozen() && !slices.IsSortedFunc(b, snm.CompareReverse) {
114 | panic("other is not sorted")
115 | }
116 | intersection := 0
117 | i, j, m := len(a)-1, len(b)-1, 0
118 | for ; i >= 0 && j >= 0 && m < mh.k; m++ {
119 | if a[i] > b[j] {
120 | j--
121 | } else if a[i] < b[j] {
122 | i--
123 | } else { // a[i] == b[j]
124 | intersection++
125 | i--
126 | j--
127 | }
128 | }
129 | union := min(mh.k, m+len(a)-i+len(b)-j)
130 | return intersection, union
131 | }
132 |
133 | // Jaccard returns the approximated Jaccard similarity between mh and other.
134 | //
135 | // Sort needs to be called before calling this function.
136 | func (mh *MinHash[T]) Jaccard(other *MinHash[T]) float64 {
137 | i, u := mh.intersect(other)
138 | return float64(i) / float64(u)
139 | }
140 |
141 | // SoftJaccard returns the Jaccard similarity between mh and other,
142 | // adding one agreed upon element and one disagreed upon element to
143 | // the calculation.
144 | //
145 | // Sort needs to be called before calling this function.
146 | func (mh *MinHash[T]) SoftJaccard(other *MinHash[T]) float64 {
147 | r := mh.Jaccard(other)
148 | sum := float64(mh.N() + other.N())
149 | ri, ru := r*sum/(r+1), sum/(r+1)
150 | return (ri + 1) / (ru + 2)
151 | }
152 |
153 | // Sort sorts the collection, making it ready for Jaccard calculation.
154 | // The collection is still valid after calling Sort.
155 | func (mh *MinHash[T]) Sort() {
156 | if mh.frozen() {
157 | panic("called Sort on a frozen MinHash " +
158 | "(frozen instances are already sorted)")
159 | }
160 | slices.SortFunc(mh.h.View(), snm.CompareReverse)
161 | }
162 |
163 | // Frozen returns an immutable version of this instance.
164 | // The original instance is unchanged.
165 | //
166 | // Frozen instances are sorted, take up less memory
167 | // and calculate Jaccard faster.
168 | // Calls to View are slower because the data is cloned.
169 | func (mh *MinHash[T]) Frozen() *MinHash[T] {
170 | h := heaps.Max[T]()
171 | h.PushSlice(mh.View())
172 | h.Clip()
173 | slices.SortFunc(h.View(), snm.CompareReverse)
174 | result := &MinHash[T]{h, nil, mh.k, mh.n}
175 | return result
176 | }
177 |
// Returns whether this minhash is frozen.
// Frozen instances are recognized by their nil uniqueness set.
func (mh *MinHash[T]) frozen() bool {
	return mh.s == nil
}
182 |
--------------------------------------------------------------------------------
/gnum/gnum.go:
--------------------------------------------------------------------------------
1 | // Package gnum provides generic numerical functions.
2 | package gnum
3 |
4 | import (
5 | "fmt"
6 | "math"
7 |
8 | "golang.org/x/exp/constraints"
9 | )
10 |
// Number is a constraint that contains comparable numbers:
// all float and integer types.
type Number interface {
	constraints.Float | constraints.Integer
}
15 |
16 | // Max returns the maximal value in the slice or zero if the slice is empty.
17 | func Max[S ~[]N, N constraints.Ordered](s S) N {
18 | if len(s) == 0 {
19 | var zero N
20 | return zero
21 | }
22 | e := s[0]
23 | for _, v := range s[1:] {
24 | e = max(e, v)
25 | }
26 | return e
27 | }
28 |
29 | // Min returns the maximal value in the slice or zero if the slice is empty.
30 | func Min[S ~[]N, N constraints.Ordered](s S) N {
31 | if len(s) == 0 {
32 | var zero N
33 | return zero
34 | }
35 | e := s[0]
36 | for _, v := range s[1:] {
37 | e = min(e, v)
38 | }
39 | return e
40 | }
41 |
42 | // ArgMax returns the index of the maximal value in the slice or -1 if the slice is empty.
43 | func ArgMax[S ~[]E, E constraints.Ordered](s S) int {
44 | if len(s) == 0 {
45 | return -1
46 | }
47 | imax, max := 0, s[0]
48 | for i, v := range s {
49 | if v > max {
50 | imax, max = i, v
51 | }
52 | }
53 | return imax
54 | }
55 |
56 | // ArgMin returns the index of the minimal value in the slice or -1 if the slice is empty.
57 | func ArgMin[S ~[]E, E constraints.Ordered](s S) int {
58 | if len(s) == 0 {
59 | return -1
60 | }
61 | imin, min := 0, s[0]
62 | for i, v := range s {
63 | if v < min {
64 | imin, min = i, v
65 | }
66 | }
67 | return imin
68 | }
69 |
70 | // Abs returns the absolute value of n.
71 | //
72 | // For floats use [math.Abs].
73 | func Abs[N constraints.Signed](n N) N {
74 | if n < 0 {
75 | return -n
76 | }
77 | return n
78 | }
79 |
80 | // Diff returns the non-negative difference between a and b.
81 | func Diff[N Number](a, b N) N {
82 | if a > b {
83 | return a - b
84 | }
85 | return b - a
86 | }
87 |
88 | // Sum returns the sum of the slice.
89 | func Sum[S ~[]N, N Number](a S) N {
90 | var sum N
91 | for _, v := range a {
92 | sum += v
93 | }
94 | return sum
95 | }
96 |
97 | // Mean returns the average of the slice.
98 | func Mean[S ~[]N, N Number](a S) float64 {
99 | return float64(Sum(a)) / float64(len(a))
100 | }
101 |
102 | // ExpMean returns the exponential average of the slice.
103 | // Non-positive values result in NaN.
104 | func ExpMean[S ~[]N, N Number](a S) float64 {
105 | if len(a) == 0 {
106 | return math.NaN()
107 | }
108 | sum := 0.0
109 | for _, v := range a {
110 | sum += math.Log(float64(v))
111 | }
112 | return math.Exp(sum / float64(len(a)))
113 | }
114 |
115 | // Cov returns the covariance of a and b.
116 | func Cov[S ~[]N, N Number](a, b S) float64 {
117 | assertMatchingLengths(a, b)
118 | ma := Mean(a)
119 | mb := Mean(b)
120 | cov := 0.0
121 | for i := range a {
122 | cov += (float64(a[i]) - ma) * (float64(b[i]) - mb)
123 | }
124 | cov /= float64(len(a))
125 | return cov
126 | }
127 |
// Var returns the variance of a.
// Calculated as the covariance of a with itself.
func Var[S ~[]N, N Number](a S) float64 {
	return Cov(a, a)
}
132 |
// Std returns the standard deviation of a,
// which is the square root of its variance.
func Std[S ~[]N, N Number](a S) float64 {
	return math.Sqrt(Var(a))
}
137 |
// Corr returns the Pearson correlation between a and b:
// their covariance divided by both standard deviations.
// No check is made for a zero standard deviation.
func Corr[S ~[]N, N Number](a, b S) float64 {
	return Cov(a, b) / Std(a) / Std(b)
}
142 |
143 | // Entropy returns the Shannon-entropy of a.
144 | // The elements in a don't have to sum up to 1.
145 | func Entropy[S ~[]N, N Number](a S) float64 {
146 | sum := float64(Sum(a))
147 | result := 0.0
148 | for i, v := range a {
149 | if v < 0 {
150 | panic(fmt.Sprintf("negative value at position %d: %v",
151 | i, v))
152 | }
153 | if v == 0 {
154 | continue
155 | }
156 | p := float64(v) / sum
157 | result -= p * math.Log2(p)
158 | }
159 | return result
160 | }
161 |
162 | // Idiv divides a by b, rounded to the nearest integer.
163 | func Idiv[T constraints.Integer](a, b T) T {
164 | return T(math.Round(float64(a) / float64(b)))
165 | }
166 |
// Quantiles returns the elements that divide the given slice
// at the given ratios.
//
// For example, 0.5 returns the middle element,
// 0.25 returns the element at a quarter of the length, etc.
// 0 and 1 return the first and last element, respectively.
//
// Returns nil if no ratios are given. Panics if ratios are given
// and the slice is empty.
func Quantiles[T any](a []T, qq ...float64) []T {
	switch {
	case len(qq) == 0:
		return nil
	case len(a) == 0:
		panic("input slice cannot be empty")
	}
	last := float64(len(a) - 1)
	picked := make([]T, len(qq))
	for k, q := range qq {
		// Scale the ratio to the nearest index.
		picked[k] = a[int(math.Round(q*last))]
	}
	return picked
}
188 |
189 | // NQuantiles returns the elements that divide
190 | // the given slice into n equal parts (up to rounding),
191 | // including the first and last elements.
192 | //
193 | // For example, for n=2 it returns the first element,
194 | // the middle, and the last element.
195 | func NQuantiles[T any](a []T, n int) []T {
196 | q := make([]float64, n+1)
197 | for i := range q {
198 | q[i] = float64(i) / float64(n)
199 | }
200 | return Quantiles(a, q...)
201 | }
202 |
// LogFactorial returns an approximation of log(n!),
// calculated in constant time.
// Exact (zero) for n=0 and n=1. Panics if n is negative.
func LogFactorial(n int) float64 {
	switch {
	case n < 0:
		panic(fmt.Sprintf("n cannot be negative: %v", n))
	case n <= 1:
		return 0
	}
	// Stirling's approximation.
	const halfLog2pi = 0x1.d67f1c864beb4p-01 // 0.5*math.Log(2*math.Pi)
	x := float64(n)
	lx := math.Log(x)
	return halfLog2pi + 0.5*lx + x*(lx-1)
}
218 |
--------------------------------------------------------------------------------
/bloom/bloom.go:
--------------------------------------------------------------------------------
1 | // Package bloom provides a simple bloom filter implementation.
2 | package bloom
3 |
4 | import (
5 | "fmt"
6 | "hash"
7 | "io"
8 | "math"
9 | _ "unsafe"
10 |
11 | "github.com/fluhus/gostuff/bnry"
12 | "github.com/spaolacci/murmur3"
13 | )
14 |
// fastrand returns a uint32 from the Go runtime.
// It has no body here; the directive below links it to runtime.fastrand.
// New uses it to pick the initial seed of a filter.
//
//go:linkname fastrand runtime.fastrand
func fastrand() uint32
17 |
// Filter is a single bloom filter.
// Use New or NewOptimal to create one.
type Filter struct {
	b    []byte        // Filter data.
	h    []hash.Hash64 // Hash functions.
	seed uint32        // Seed from which the hash functions were derived.
}
24 |
// NHash returns the number of hash functions this filter uses
// (the k parameter of New).
func (f *Filter) NHash() int {
	return len(f.h)
}
29 |
// NBits returns the number of bits this filter uses.
// Always a multiple of 8, since the data is stored in whole bytes.
func (f *Filter) NBits() int {
	return 8 * len(f.b)
}
34 |
35 | // NElements returns an approximation of the number of elements added to the
36 | // filter.
37 | func (f *Filter) NElements() int {
38 | m := float64(f.NBits())
39 | k := float64(f.NHash())
40 | x := 0.0 // Number of bits that are 1.
41 | for _, bt := range f.b {
42 | for bt > 0 {
43 | if bt&1 > 0 {
44 | x++
45 | }
46 | bt >>= 1
47 | }
48 | }
49 | return int(math.Round(-m / k * math.Log(1-x/m)))
50 | }
51 |
52 | // Has checks if all k hash values of v were encountered.
53 | // Makes at most k hash calculations.
54 | func (f *Filter) Has(v []byte) bool {
55 | for i := range f.h {
56 | f.h[i].Reset()
57 | f.h[i].Write(v)
58 | hash := int(f.h[i].Sum64() % uint64(len(f.b)*8))
59 | if getBit(f.b, hash) == 0 {
60 | return false
61 | }
62 | }
63 | return true
64 | }
65 |
66 | // Add adds v to the filter, and returns the value of Has(v) before adding.
67 | // After calling Add, Has(v) will always be true. Makes k calls to hash.
68 | func (f *Filter) Add(v []byte) bool {
69 | has := true
70 | for i := range f.h {
71 | f.h[i].Reset()
72 | f.h[i].Write(v)
73 | hash := int(f.h[i].Sum64() % uint64(len(f.b)*8))
74 | if getBit(f.b, hash) == 0 {
75 | has = false
76 | setBit(f.b, hash, 1)
77 | }
78 | }
79 | return has
80 | }
81 |
82 | // AddFilter merges other into f. After merging, f is equivalent to have been added
83 | // all the elements of other.
84 | func (f *Filter) AddFilter(other *Filter) {
85 | // Make sure the two filters are compatible.
86 | if f.NBits() != other.NBits() {
87 | panic(fmt.Sprintf("mismatching number of bits: this has %v, other has %v",
88 | f.NBits(), other.NBits()))
89 | }
90 | if f.NHash() != other.NHash() {
91 | panic(fmt.Sprintf("mismatching number of hashes: this has %v, other has %v",
92 | f.NHash(), other.NHash()))
93 | }
94 | if f.Seed() != other.Seed() {
95 | panic(fmt.Sprintf("mismatching seeds: this has %v, other has %v",
96 | f.Seed(), other.Seed()))
97 | }
98 |
99 | // Merge.
100 | for i := range f.b {
101 | f.b[i] |= other.b[i]
102 | }
103 | }
104 |
// Seed returns the hash seed of this filter.
// A new filter starts with a random seed.
// Use SetSeed to change it.
func (f *Filter) Seed() uint32 {
	return f.seed
}
110 |
111 | // SetSeed sets the hash seed of this filter.
112 | // The filter must be empty.
113 | func (f *Filter) SetSeed(seed uint32) {
114 | for _, b := range f.b {
115 | if b != 0 {
116 | panic("cannot change seed after elements were added")
117 | }
118 | }
119 | f.seed = seed
120 | h := murmur3.New32WithSeed(seed)
121 | for i := range f.h {
122 | h.Write([]byte{1})
123 | f.h[i] = murmur3.New64WithSeed(h.Sum32())
124 | }
125 | }
126 |
127 | // Encode writes this filter to the stream. Can be reproduced later with Decode.
128 | func (f *Filter) Encode(w io.Writer) error {
129 | // Order is k, seed, bytes.
130 | return bnry.Write(w, uint64(len(f.h)), f.seed, f.b)
131 | }
132 |
133 | // Decode reads an encoded filter from the stream and sets this filter's state
134 | // to match it. Destroys the previously existing state of this filter.
135 | func (f *Filter) Decode(r io.ByteReader) error {
136 | var k uint64
137 | var seed uint32
138 | var b []byte
139 | if err := bnry.Read(r, &k, &seed, &b); err != nil {
140 | return err
141 | }
142 | f.h = make([]hash.Hash64, k)
143 | f.SetSeed(uint32(seed))
144 | f.b = b
145 |
146 | return nil
147 | }
148 |
149 | // New creates a new bloom filter with the given parameters. Number of
150 | // bits is rounded up to the nearest multiple of 8.
151 | //
152 | // See NewOptimal for an alternative way to decide on the parameters.
153 | func New(bits int, k int) *Filter {
154 | if bits < 1 {
155 | panic(fmt.Sprintf("number of bits should be at least 1, got %v", bits))
156 | }
157 | if k < 1 {
158 | panic(fmt.Sprintf("k should be at least 1, got %v", k))
159 | }
160 |
161 | result := &Filter{
162 | b: make([]byte, ((bits-1)/8)+1),
163 | h: make([]hash.Hash64, k),
164 | }
165 | result.SetSeed(fastrand())
166 | return result
167 | }
168 |
169 | // NewOptimal creates a new bloom filter, with parameters optimal for the
170 | // expected number of elements (n) and the required false-positive rate (p).
171 | //
172 | // The calculation is taken from:
173 | // https://en.wikipedia.org/wiki/Bloom_filter#Optimal_number_of_hash_functions
174 | func NewOptimal(n int, p float64) *Filter {
175 | m := math.Round(-float64(n) * math.Log(p) / math.Ln2 / math.Ln2)
176 | k := math.Round(-math.Log2(p))
177 | return New(int(m), int(k))
178 | }
179 |
// Returns the value of the n'th bit in a byte slice (0 or 1).
// Bit 0 is the least significant bit of the first byte.
func getBit(b []byte, n int) int {
	cell := b[n/8]
	if cell&(byte(1)<<(n%8)) == 0 {
		return 0
	}
	return 1
}
184 |
// Sets the value of the n'th bit in a byte slice.
// v should be 0 or 1, panics otherwise.
func setBit(b []byte, n, v int) {
	switch v {
	case 0:
		b[n/8] &= ^(byte(1) << (n % 8))
	case 1:
		b[n/8] |= byte(1) << (n % 8)
	default:
		panic(fmt.Sprintf("bad value: %v, expected 0 or 1", v))
	}
}
195 |
--------------------------------------------------------------------------------
/gnum/gnum_test.go:
--------------------------------------------------------------------------------
1 | package gnum
2 |
3 | import (
4 | "math"
5 | "slices"
6 | "testing"
7 | )
8 |
9 | func TestEntropy(t *testing.T) {
10 | input1 := []int{1, 2, 3, 4}
11 | input2 := []uint{1, 2, 3, 4}
12 | input3 := []float32{1, 2, 3, 4}
13 | input4 := []float64{1, 2, 3, 4}
14 | want := 1.8464393
15 | if got := Entropy(input1); Diff(got, want) > 0.00000005 {
16 | t.Errorf("Entropy(%v)=%v, want %v", input1, got, want)
17 | }
18 | if got := Entropy(input2); Diff(got, want) > 0.00000005 {
19 | t.Errorf("Entropy(%v)=%v, want %v", input2, got, want)
20 | }
21 | if got := Entropy(input3); Diff(got, want) > 0.00000005 {
22 | t.Errorf("Entropy(%v)=%v, want %v", input3, got, want)
23 | }
24 | if got := Entropy(input4); Diff(got, want) > 0.00000005 {
25 | t.Errorf("Entropy(%v)=%v, want %v", input4, got, want)
26 | }
27 | }
28 |
29 | func TestIdiv(t *testing.T) {
30 | tests := []struct {
31 | a, b, want int
32 | }{
33 | {8, 1, 8},
34 | {8, 2, 4},
35 | {8, 3, 3},
36 | {8, 4, 2},
37 | {8, 5, 2},
38 | {8, 6, 1},
39 | {8, 7, 1},
40 | {8, 8, 1},
41 | {8, 9, 1},
42 | {8, 10, 1},
43 | {8, 11, 1},
44 | {8, 12, 1},
45 | {8, 13, 1},
46 | {8, 14, 1},
47 | {8, 15, 1},
48 | {8, 17, 0},
49 | {8, 18, 0},
50 | {8, 19, 0},
51 | {8, 20, 0},
52 | }
53 | for _, test := range tests {
54 | if got := Idiv(test.a, test.b); got != test.want {
55 | t.Errorf("Idiv(%v,%v)=%v, want %v", test.a, test.b, got, test.want)
56 | }
57 | }
58 | }
59 |
60 | func TestMinMax(t *testing.T) {
61 | tests := []struct {
62 | input []int
63 | mn, mx, amn, amx int
64 | }{
65 | {nil, 0, 0, -1, -1},
66 | {[]int{42}, 42, 42, 0, 0},
67 | {[]int{42, 42}, 42, 42, 0, 0},
68 | {[]int{42, 42, 42}, 42, 42, 0, 0},
69 | {[]int{1, 2, 3}, 1, 3, 0, 2},
70 | {[]int{1, 3, 2}, 1, 3, 0, 1},
71 | {[]int{2, 1, 3}, 1, 3, 1, 2},
72 | {[]int{2, 3, 1}, 1, 3, 2, 1},
73 | {[]int{3, 1, 2}, 1, 3, 1, 0},
74 | {[]int{3, 2, 1}, 1, 3, 2, 0},
75 | }
76 | for _, test := range tests {
77 | mn, mx, amn, amx := Min(test.input), Max(test.input), ArgMin(test.input), ArgMax(test.input)
78 | if mn != test.mn {
79 | t.Errorf("Min(%v)=%v, want %v", test.input, mn, test.mn)
80 | }
81 | if mx != test.mx {
82 | t.Errorf("Max(%v)=%v, want %v", test.input, mx, test.mx)
83 | }
84 | if amn != test.amn {
85 | t.Errorf("ArgMin(%v)=%v, want %v", test.input, amn, test.amn)
86 | }
87 | if amx != test.amx {
88 | t.Errorf("ArgMax(%v)=%v, want %v", test.input, amx, test.amx)
89 | }
90 | }
91 | }
92 |
93 | func TestSum(t *testing.T) {
94 | tests := []struct {
95 | input []int
96 | want int
97 | }{
98 | {nil, 0},
99 | {[]int{1}, 1},
100 | {[]int{1, 1}, 2},
101 | {[]int{1, 1, 1, 1}, 4},
102 | {[]int{6, 4, 1}, 11},
103 | }
104 | for _, test := range tests {
105 | if got := Sum(test.input); got != test.want {
106 | t.Errorf("Sum(%v)=%v, want %v", test.input, got, test.want)
107 | }
108 | }
109 | }
110 |
111 | func TestMean(t *testing.T) {
112 | tests := []struct {
113 | input []int
114 | want float64
115 | }{
116 | {[]int{1}, 1},
117 | {[]int{1, 1}, 1},
118 | {[]int{1, 1, 1, 1}, 1},
119 | {[]int{6, 4, -1}, 3},
120 | }
121 | for _, test := range tests {
122 | if got := Mean(test.input); got != test.want {
123 | t.Errorf("Mean(%v)=%v, want %v", test.input, got, test.want)
124 | }
125 | }
126 | }
127 |
128 | func TestExpMean(t *testing.T) {
129 | tests := []struct {
130 | input []int
131 | want float64
132 | }{
133 | {[]int{1}, 1},
134 | {[]int{1, 1}, 1},
135 | {[]int{3, 3, 3, 3}, 3},
136 | {[]int{10, 1000}, 100},
137 | {[]int{10, 100}, math.Sqrt(1000)},
138 | {[]int{10, 100, 1000}, 100},
139 | }
140 | const tolerance = 0.0000001
141 | for _, test := range tests {
142 | if got := ExpMean(test.input); Diff(got, test.want) > tolerance {
143 | t.Errorf("ExpMean(%v)=%v, want %v", test.input, got, test.want)
144 | }
145 | }
146 | }
147 |
148 | func FuzzSumMean(f *testing.F) {
149 | f.Add(0.0, 0.0, 0.0, 0.0)
150 | f.Fuzz(func(t *testing.T, a float64, b float64, c float64, d float64) {
151 | slice := []float64{a, b, c, d}
152 | want := a + b + c + d
153 | if got := Sum(slice); got != want {
154 | t.Fatalf("Sum([%v,%v,%v,%v])=%v, want %v", a, b, c, d, got, want)
155 | }
156 | want /= 4
157 | if got := Mean(slice); got != want {
158 | t.Fatalf("Mean([%v,%v,%v,%v])=%v, want %v", a, b, c, d, got, want)
159 | }
160 | if a > 0 && b > 0 && c > 0 && d > 0 {
161 | const tol = 0.0000001
162 | want = math.Pow(a*b*c*d, 0.25)
163 | if got := ExpMean(slice); Diff(got, want) > tol {
164 | t.Fatalf("ExpMean([%v,%v,%v,%v])=%v, want %v", a, b, c, d, got, want)
165 | }
166 | }
167 | })
168 | }
169 |
170 | func TestQuantiles(t *testing.T) {
171 | input := []int{1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144}
172 | q := []float64{0.3, 0.5, 0, 1, 0.9}
173 | want1 := []int{3, 8, 1, 144, 89}
174 | want2 := []int{3, 13, 1, 144, 89}
175 | got := Quantiles(input, q...)
176 | if !slices.Equal(got, want1) && !slices.Equal(got, want2) {
177 | t.Fatalf("Quantiles(%v,%v)=%v, want %v or %v",
178 | input, q, got, want1, want2)
179 | }
180 | }
181 |
182 | func TestNQuantiles(t *testing.T) {
183 | input := []int{1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144}
184 | want := []int{1, 5, 21, 144}
185 | got := NQuantiles(input, 3)
186 | if !slices.Equal(got, want) {
187 | t.Fatalf("Quantiles(%v,3)=%v, want %v", input, got, want)
188 | }
189 | }
190 |
191 | func TestLogFactorial(t *testing.T) {
192 | tests := [][]int{
193 | {0, 1}, {1, 1}, {2, 2}, {3, 6}, {4, 24}, {5, 120}, {6, 720},
194 | }
195 | for _, test := range tests {
196 | got := math.Exp(LogFactorial(test[0]))
197 | want := float64(test[1])
198 | if Diff(got, want) > want*0.05 {
199 | t.Errorf("lf(%v)=%v, want %v", test[0], got, test[1])
200 | }
201 | }
202 | }
203 |
204 | func BenchmarkLogFactorial(b *testing.B) {
205 | for i := range b.N {
206 | LogFactorial(i)
207 | }
208 | }
209 |
--------------------------------------------------------------------------------
/flagx/flagx.go:
--------------------------------------------------------------------------------
1 | // Package flagx provides additional [flag] functions.
2 | package flagx
3 |
4 | import (
5 | "flag"
6 | "fmt"
7 | "os"
8 | "regexp"
9 | "slices"
10 | "strconv"
11 | "strings"
12 |
13 | "github.com/fluhus/gostuff/snm"
14 | )
15 |
// RegexpFlagSet defines a regular expression flag with specified name,
// default value, and usage string.
// The return value is the address of a regular expression variable that
// stores the value of the flag.
//
// A value that does not compile is reported as a parsing error,
// and the stored value is left unchanged.
func RegexpFlagSet(fs *flag.FlagSet, name string,
	value *regexp.Regexp, usage string) **regexp.Regexp {
	// value is a local copy, so its address can serve as the flag's storage.
	result := &value
	parse := func(s string) error {
		compiled, err := regexp.Compile(s)
		if err == nil {
			*result = compiled
		}
		return err
	}
	fs.Func(name, usage, parse)
	return result
}
33 |
// Regexp defines a regular expression flag with specified name,
// default value, and usage string.
// The return value is the address of a regular expression variable that
// stores the value of the flag.
//
// The flag is registered on [flag.CommandLine].
func Regexp(name string, value *regexp.Regexp, usage string) **regexp.Regexp {
	return RegexpFlagSet(flag.CommandLine, name, value, usage)
}
41 |
// IntBetweenFlagSet defines an int flag with specified name,
// default value, usage string and bounds.
// The return value is the address of an int variable that
// stores the value of the flag.
//
// Both bounds are inclusive. A value that is not an integer or that is
// out of bounds is reported as a parsing error.
// The default value is not checked against the bounds.
func IntBetweenFlagSet(fs *flag.FlagSet, name string,
	value int, usage string, minVal, maxVal int) *int {
	// value is a local copy, so its address can serve as the flag's storage.
	result := &value
	parse := func(s string) error {
		parsed, err := strconv.Atoi(s)
		switch {
		case err != nil:
			return err
		case parsed < minVal || parsed > maxVal:
			return fmt.Errorf("got %d, want %d-%d", parsed, minVal, maxVal)
		}
		*result = parsed
		return nil
	}
	fs.Func(name, usage, parse)
	return result
}
62 |
63 | // IntBetween defines an int flag with specified name,
64 | // default value, usage string and bounds.
65 | // The return value is the address of an int variable that
66 | // stores the value of the flag.
67 | func IntBetween(name string, value int, usage string, minVal, maxVal int) *int {
68 | return IntBetweenFlagSet(
69 | flag.CommandLine, name, value, usage, minVal, maxVal)
70 | }
71 |
// FloatBetweenFlagSet defines a float flag with specified name,
// default value, usage string and bounds.
// incMin and incMax define whether min and max are included in the
// allowed values.
// An argument that does not parse as a float, that lies outside the
// bounds, or that is NaN is rejected with an error, and the stored value
// is left as it was.
// The return value is the address of a float variable that
// stores the value of the flag.
func FloatBetweenFlagSet(fs *flag.FlagSet, name string, value float64,
	usage string, minVal, maxVal float64, incMin, incMax bool) *float64 {
	p := &value
	fs.Func(name, usage, func(s string) error {
		f, err := strconv.ParseFloat(s, 64)
		if err != nil {
			return err
		}
		// Phrase the check as "is inside" rather than "is outside":
		// ParseFloat accepts "NaN", and every comparison with NaN is
		// false, so an "is outside" test would let NaN through.
		aboveMin := f > minVal || (incMin && f == minVal)
		belowMax := f < maxVal || (incMax && f == maxVal)
		if !aboveMin || !belowMax {
			smin, smax := "(", ")"
			if incMin {
				smin = "["
			}
			if incMax {
				smax = "]"
			}
			return fmt.Errorf("got %f, want %s%f,%f%s",
				f, smin, minVal, maxVal, smax)
		}
		*p = f
		return nil
	})
	return p
}
103 |
104 | // FloatBetween defines a float flag with specified name,
105 | // default value, usage string and bounds.
106 | // incMin and incMax define whether min and max are included in the
107 | // allowed values.
108 | // The return value is the address of a float variable that
109 | // stores the value of the flag.
110 | func FloatBetween(name string, value float64, usage string,
111 | minVal, maxVal float64, incMin, incMax bool) *float64 {
112 | return FloatBetweenFlagSet(
113 | flag.CommandLine, name, value, usage,
114 | minVal, maxVal, incMin, incMax)
115 | }
116 |
// FileExistsFlagSet registers a string flag called name on fs whose
// argument must be the path of an existing file.
// Parsing fails if [os.Stat] reports an error for the path or if the path
// is a directory; in that case the stored value is left as it was.
// The return value is the address of the variable that holds the flag's
// value.
func FileExistsFlagSet(fs *flag.FlagSet, name string, value string,
	usage string) *string {
	result := &value
	check := func(path string) error {
		info, err := os.Stat(path)
		switch {
		case err != nil:
			return err
		case info.IsDir():
			return fmt.Errorf("path is a directory")
		}
		*result = path
		return nil
	}
	fs.Func(name, usage, check)
	return result
}
135 |
136 | // FileExists defines a string flag that represents
137 | // an existing file. Returns an error if the file does not exist.
138 | func FileExists(name string, value string, usage string) *string {
139 | return FileExistsFlagSet(flag.CommandLine, name, value, usage)
140 | }
141 |
142 | // OneOfFlagSet defines a flag that must have one of the given values.
143 | // The type must be one that can be read by [fmt.Scan].
144 | func OneOfFlagSet[T comparable](fs *flag.FlagSet, name string,
145 | value T, usage string, of ...T) *T {
146 | if len(of) == 0 {
147 | panic("called with 0 possible values")
148 | }
149 |
150 | // Create usage string.
151 | options := strings.Join(snm.SliceToSlice(of, func(t T) string {
152 | return fmt.Sprint(t)
153 | }), ", ")
154 | dflt := ""
155 | var zero T
156 | if value != zero {
157 | dflt = fmt.Sprintf("; default %s", fmt.Sprint(value))
158 | }
159 | usage = fmt.Sprintf("%s (one of [%s]%s)",
160 | usage, options, dflt)
161 |
162 | v := value
163 | fs.Func(name, usage, func(s string) error {
164 | _, err := fmt.Sscanln(s, &v)
165 | if err != nil {
166 | return err
167 | }
168 | if slices.Index(of, v) == -1 {
169 | return fmt.Errorf("want one of %v", of)
170 | }
171 | return nil
172 | })
173 | return &v
174 | }
175 |
176 | // OneOf defines a flag that must have one of the given values.
177 | // The type must be one that can be read by [fmt.Scan].
178 | func OneOf[T comparable](name string, value T, usage string, of ...T) *T {
179 | return OneOfFlagSet(flag.CommandLine, name, value, usage, of...)
180 | }
181 |
--------------------------------------------------------------------------------
/bnry/write.go:
--------------------------------------------------------------------------------
1 | package bnry
2 |
3 | import (
4 | "bytes"
5 | "encoding/binary"
6 | "fmt"
7 | "io"
8 | "math"
9 | "reflect"
10 | "strings"
11 |
12 | "golang.org/x/exp/constraints"
13 | )
14 |
15 | // Write writes the given values to the given writer.
16 | // Values should be of any of the supported types.
17 | // Panics if a value is of an unsupported type.
18 | func Write(w io.Writer, vals ...any) error {
19 | return NewWriter(w).Write(vals...)
20 | }
21 |
22 | // MarshalBinary writes the given values to a byte slice.
23 | // Values should be of any of the supported types.
24 | // Panics if a value is of an unsupported type.
25 | func MarshalBinary(vals ...any) []byte {
26 | buf := bytes.NewBuffer(nil)
27 | NewWriter(buf).Write(vals...)
28 | return buf.Bytes()
29 | }
30 |
// Writer encodes values in binary form and writes them to an underlying
// [io.Writer]. Create one with [NewWriter].
type Writer struct {
	// Destination of all encoded output.
	w io.Writer
	// Scratch space reused for encoding single bytes and varints.
	buf [binary.MaxVarintLen64]byte
}
35 |
36 | func NewWriter(w io.Writer) *Writer {
37 | return &Writer{w: w}
38 | }
39 |
40 | // Write writes the given values.
41 | // Values should be of any of the supported types.
42 | // Panics if a value is of an unsupported type.
43 | func (w *Writer) Write(vals ...any) error {
44 | for _, v := range vals {
45 | if err := w.writeSingle(v); err != nil {
46 | return err
47 | }
48 | }
49 | return nil
50 | }
51 |
// writeSingle encodes a single value and writes it out.
// The encoding is selected by the dynamic type of val:
//   - uint8, int8 and bool are written as one byte.
//   - Other unsigned integers are written as unsigned varints.
//   - Other signed integers are written as signed varints.
//   - Floats are written as the unsigned varint of their bit pattern.
//   - Strings and slices are delegated to the matching slice/string
//     helper.
//
// Panics if val is of any other type.
func (w *Writer) writeSingle(val any) error {
	switch val := val.(type) {
	// Scalars.
	case uint8:
		return w.writeByte(val)
	case uint16:
		return writeUint(w, val)
	case uint32:
		return writeUint(w, val)
	case uint64:
		return writeUint(w, val)
	case uint:
		return writeUint(w, val)
	case int8:
		// A single byte; the conversion keeps the bit pattern.
		return w.writeByte(byte(val))
	case int16:
		return writeInt(w, val)
	case int32:
		return writeInt(w, val)
	case int64:
		return writeInt(w, val)
	case int:
		return writeInt(w, val)
	case float32:
		return writeUint(w, math.Float32bits(val))
	case float64:
		return writeUint(w, math.Float64bits(val))
	case bool:
		return w.writeByte(boolToByte(val))
	case string:
		return w.writeString(val)
	// Slices.
	case []uint8:
		return w.writeUint8Slice(val)
	case []uint16:
		return writeUintSlice(w, val)
	case []uint32:
		return writeUintSlice(w, val)
	case []uint64:
		return writeUintSlice(w, val)
	case []uint:
		return writeUintSlice(w, val)
	case []int8:
		return w.writeInt8Slice(val)
	case []int16:
		return writeIntSlice(w, val)
	case []int32:
		return writeIntSlice(w, val)
	case []int64:
		return writeIntSlice(w, val)
	case []int:
		return writeIntSlice(w, val)
	case []float32:
		return w.writeFloat32Slice(val)
	case []float64:
		return w.writeFloat64Slice(val)
	case []bool:
		return w.writeBoolSlice(val)
	case []string:
		return w.writeStringSlice(val)
	default:
		// Unsupported types are treated as a programming error.
		panic(fmt.Sprintf("unsupported type: %v",
			reflect.TypeOf(val)))
	}
}
116 |
117 | func (w *Writer) writeByte(b byte) error {
118 | w.buf[0] = b
119 | _, err := w.w.Write(w.buf[:1])
120 | return err
121 | }
122 |
123 | func writeUint[T constraints.Unsigned](w *Writer, i T) error {
124 | _, err := w.w.Write(binary.AppendUvarint(w.buf[:0], uint64(i)))
125 | return err
126 | }
127 |
128 | func writeInt[T constraints.Signed](w *Writer, i T) error {
129 | _, err := w.w.Write(binary.AppendVarint(w.buf[:0], int64(i)))
130 | return err
131 | }
132 |
133 | func (w *Writer) writeUint8Slice(s []uint8) error {
134 | if err := writeUint(w, uint(len(s))); err != nil {
135 | return err
136 | }
137 | _, err := w.w.Write(s)
138 | return err
139 | }
140 |
141 | func (w *Writer) writeString(s string) error {
142 | if err := writeUint(w, uint(len(s))); err != nil {
143 | return err
144 | }
145 | _, err := strings.NewReader(s).WriteTo(w.w)
146 | return err
147 | }
148 |
149 | func (w *Writer) writeInt8Slice(s []int8) error {
150 | if err := writeUint(w, uint(len(s))); err != nil {
151 | return err
152 | }
153 | for _, x := range s {
154 | if err := w.writeByte(byte(x)); err != nil {
155 | return err
156 | }
157 | }
158 | return nil
159 | }
160 |
161 | func writeUintSlice[T constraints.Unsigned](w *Writer, s []T) error {
162 | if err := writeUint(w, uint(len(s))); err != nil {
163 | return err
164 | }
165 | for _, x := range s {
166 | if err := writeUint(w, x); err != nil {
167 | return err
168 | }
169 | }
170 | return nil
171 | }
172 |
173 | func writeIntSlice[T constraints.Signed](w *Writer, s []T) error {
174 | if err := writeUint(w, uint(len(s))); err != nil {
175 | return err
176 | }
177 | for _, x := range s {
178 | if err := writeInt(w, x); err != nil {
179 | return err
180 | }
181 | }
182 | return nil
183 | }
184 |
185 | func (w *Writer) writeFloat32Slice(s []float32) error {
186 | if err := writeUint(w, uint(len(s))); err != nil {
187 | return err
188 | }
189 | for _, x := range s {
190 | if err := writeUint(w, math.Float32bits(x)); err != nil {
191 | return err
192 | }
193 | }
194 | return nil
195 | }
196 |
197 | func (w *Writer) writeFloat64Slice(s []float64) error {
198 | if err := writeUint(w, uint(len(s))); err != nil {
199 | return err
200 | }
201 | for _, x := range s {
202 | if err := writeUint(w, math.Float64bits(x)); err != nil {
203 | return err
204 | }
205 | }
206 | return nil
207 | }
208 |
209 | func (w *Writer) writeBoolSlice(s []bool) error {
210 | if err := writeUint(w, uint(len(s))); err != nil {
211 | return err
212 | }
213 | for _, x := range s {
214 | if err := w.writeByte(boolToByte(x)); err != nil {
215 | return err
216 | }
217 | }
218 | return nil
219 | }
220 |
221 | func (w *Writer) writeStringSlice(s []string) error {
222 | if err := writeUint(w, uint(len(s))); err != nil {
223 | return err
224 | }
225 | for _, x := range s {
226 | if err := w.writeString(x); err != nil {
227 | return err
228 | }
229 | }
230 | return nil
231 | }
232 |
// boolToByte returns 1 if b is true and 0 otherwise.
func boolToByte(b bool) byte {
	if b {
		return 1
	}
	return 0
}
240 |
--------------------------------------------------------------------------------