├── .gitignore ├── go.mod ├── go.sum ├── slices ├── filter_test.go ├── group_test.go ├── nonnil_test.go ├── accum_test.go ├── map_test.go ├── append_test.go ├── reverse_test.go ├── each_test.go ├── insert_test.go ├── get_put_test.go ├── remove_test.go ├── replace_test.go ├── example_test.go ├── slice_test.go ├── dropin.go ├── slices.go └── dropin_test.go ├── parallel ├── protect_test.go ├── protect.go ├── parallel_test.go ├── example_test.go └── parallel.go ├── LICENSE ├── .github └── workflows │ └── go.yml ├── set ├── example_test.go ├── set_test.go └── set.go └── Readme.md /.gitignore: -------------------------------------------------------------------------------- 1 | /cover.out 2 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/bobg/go-generics/v4 2 | 3 | go 1.23 4 | 5 | require golang.org/x/sync v0.8.0 6 | 7 | retract v4.0.0 8 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ= 2 | golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= 3 | -------------------------------------------------------------------------------- /slices/filter_test.go: -------------------------------------------------------------------------------- 1 | package slices 2 | 3 | import ( 4 | "reflect" 5 | "testing" 6 | ) 7 | 8 | func TestFilter(t *testing.T) { 9 | inp := []int{2, 3, 4, 5, 6, 7, 8, 9, 10} 10 | want := []int{2, 3, 5, 7} 11 | got := Filter(inp, func(n int) bool { 12 | for i := 2; i < n; i++ { 13 | if n%i == 0 { 14 | return false 15 | } 16 | } 17 | return true 18 | }) 19 | if !reflect.DeepEqual(got, want) { 20 | t.Errorf("got %v, want %v", got, want) 21 | } 22 | } 23 | 
-------------------------------------------------------------------------------- /slices/group_test.go: -------------------------------------------------------------------------------- 1 | package slices 2 | 3 | import ( 4 | "reflect" 5 | "testing" 6 | ) 7 | 8 | func TestGroup(t *testing.T) { 9 | inp := []int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10} 10 | want := map[string][]int{ 11 | "evens": {2, 4, 6, 8, 10}, 12 | "odds": {1, 3, 5, 7, 9}, 13 | } 14 | got := Group(inp, func(n int) string { 15 | if n%2 == 0 { 16 | return "evens" 17 | } 18 | return "odds" 19 | }) 20 | if !reflect.DeepEqual(got, want) { 21 | t.Errorf("got %v, want %v", got, want) 22 | } 23 | } 24 | -------------------------------------------------------------------------------- /slices/nonnil_test.go: -------------------------------------------------------------------------------- 1 | package slices 2 | 3 | import ( 4 | "fmt" 5 | "reflect" 6 | "testing" 7 | ) 8 | 9 | func TestNonNil(t *testing.T) { 10 | cases := []struct { 11 | inp []int 12 | want []int 13 | }{{ 14 | inp: nil, 15 | want: []int{}, 16 | }, { 17 | inp: []int{}, 18 | want: []int{}, 19 | }, { 20 | inp: []int{1, 2, 3}, 21 | want: []int{1, 2, 3}, 22 | }} 23 | 24 | for i, c := range cases { 25 | t.Run(fmt.Sprintf("case_%d", i+1), func(t *testing.T) { 26 | got := NonNil(c.inp) 27 | if !reflect.DeepEqual(got, c.want) { 28 | t.Errorf("got %v, want %v", got, c.want) 29 | } 30 | }) 31 | } 32 | } 33 | -------------------------------------------------------------------------------- /slices/accum_test.go: -------------------------------------------------------------------------------- 1 | package slices 2 | 3 | import ( 4 | "fmt" 5 | "testing" 6 | ) 7 | 8 | func TestAccum(t *testing.T) { 9 | cases := []struct { 10 | inp []int 11 | want int 12 | }{{ 13 | inp: nil, want: 0, 14 | }, { 15 | inp: []int{1}, want: 1, 16 | }, { 17 | inp: []int{1, 2}, want: 3, 18 | }, { 19 | inp: []int{1, 2, 3, 4}, want: 10, 20 | }} 21 | 22 | for i, tc := range cases { 23 | 
t.Run(fmt.Sprintf("case_%02d", i+1), func(t *testing.T) { 24 | got := Accum(tc.inp, func(a, b int) int { return a + b }) 25 | if got != tc.want { 26 | t.Errorf("got %d, want %d", got, tc.want) 27 | } 28 | }) 29 | } 30 | } 31 | -------------------------------------------------------------------------------- /slices/map_test.go: -------------------------------------------------------------------------------- 1 | package slices 2 | 3 | import ( 4 | "fmt" 5 | "reflect" 6 | "strconv" 7 | "testing" 8 | ) 9 | 10 | func TestMap(t *testing.T) { 11 | cases := []struct { 12 | inp []int 13 | want []string 14 | }{{ 15 | inp: []int{2, 3, 5}, 16 | want: []string{"2", "3", "5"}, 17 | }, { 18 | inp: []int{}, 19 | want: nil, 20 | }} 21 | 22 | for i, tc := range cases { 23 | t.Run(fmt.Sprintf("case_%02d", i+1), func(t *testing.T) { 24 | got := Map(tc.inp, func(val int) string { return strconv.Itoa(val) }) 25 | if !reflect.DeepEqual(got, tc.want) { 26 | t.Errorf("got %v, want %v", got, tc.want) 27 | } 28 | }) 29 | } 30 | } 31 | -------------------------------------------------------------------------------- /parallel/protect_test.go: -------------------------------------------------------------------------------- 1 | package parallel 2 | 3 | import ( 4 | "context" 5 | "reflect" 6 | "testing" 7 | ) 8 | 9 | func TestProtect(t *testing.T) { 10 | reader, writer, closer := Protect(4) 11 | defer closer() 12 | 13 | ctx := context.Background() 14 | 15 | vals, err := Values(ctx, 3, func(_ context.Context, _ int) (int, error) { return reader(), nil }) 16 | if err != nil { 17 | t.Fatal(err) 18 | } 19 | 20 | writer(reader() + 1) 21 | 22 | vals2, err := Values(ctx, 3, func(_ context.Context, _ int) (int, error) { return reader(), nil }) 23 | if err != nil { 24 | t.Fatal(err) 25 | } 26 | 27 | vals = append(vals, vals2...) 
28 | 29 | want := []int{4, 4, 4, 5, 5, 5} 30 | if !reflect.DeepEqual(vals, want) { 31 | t.Errorf("got %v, want %v", vals, want) 32 | } 33 | } 34 | -------------------------------------------------------------------------------- /slices/append_test.go: -------------------------------------------------------------------------------- 1 | package slices 2 | 3 | import ( 4 | "fmt" 5 | "reflect" 6 | "testing" 7 | ) 8 | 9 | func TestAppend(t *testing.T) { 10 | cases := []struct { 11 | inp, vals, want []int 12 | }{{ 13 | inp: nil, 14 | vals: []int{1}, 15 | want: []int{1}, 16 | }, { 17 | inp: nil, 18 | vals: []int{1, 2}, 19 | want: []int{1, 2}, 20 | }, { 21 | inp: []int{1, 2}, 22 | vals: []int{3}, 23 | want: []int{1, 2, 3}, 24 | }, { 25 | inp: []int{1, 2}, 26 | vals: []int{3, 4}, 27 | want: []int{1, 2, 3, 4}, 28 | }} 29 | 30 | for i, tc := range cases { 31 | t.Run(fmt.Sprintf("case_%02d", i+1), func(t *testing.T) { 32 | got := Append(tc.inp, tc.vals...) 33 | if !reflect.DeepEqual(got, tc.want) { 34 | t.Errorf("got %v, want %v", got, tc.want) 35 | } 36 | }) 37 | } 38 | } 39 | -------------------------------------------------------------------------------- /slices/reverse_test.go: -------------------------------------------------------------------------------- 1 | package slices 2 | 3 | import ( 4 | "fmt" 5 | "reflect" 6 | "testing" 7 | ) 8 | 9 | func TestReverse(t *testing.T) { 10 | cases := []struct { 11 | in, want []int 12 | }{{ 13 | in: []int{1, 2, 3, 4, 5}, 14 | want: []int{5, 4, 3, 2, 1}, 15 | }, { 16 | in: []int{1, 2, 3, 4}, 17 | want: []int{4, 3, 2, 1}, 18 | }, { 19 | in: []int{1, 2, 3}, 20 | want: []int{3, 2, 1}, 21 | }, { 22 | in: []int{1, 2}, 23 | want: []int{2, 1}, 24 | }, { 25 | in: []int{1}, 26 | want: []int{1}, 27 | }, { 28 | in: nil, 29 | want: nil, 30 | }} 31 | 32 | for i, tc := range cases { 33 | t.Run(fmt.Sprintf("case_%02d", i+1), func(t *testing.T) { 34 | Reverse(tc.in) 35 | if !reflect.DeepEqual(tc.in, tc.want) { 36 | t.Errorf("got %v, want %v", tc.in, 
tc.want) 37 | } 38 | }) 39 | } 40 | } 41 | -------------------------------------------------------------------------------- /slices/each_test.go: -------------------------------------------------------------------------------- 1 | package slices 2 | 3 | import ( 4 | "errors" 5 | "reflect" 6 | "testing" 7 | ) 8 | 9 | func TestEachx(t *testing.T) { 10 | inp := []int{2, 3, 5} 11 | 12 | type wanttype struct { 13 | idx, val int 14 | } 15 | want := []wanttype{{ 16 | idx: 0, val: 2, 17 | }, { 18 | idx: 1, val: 3, 19 | }, { 20 | idx: 2, val: 5, 21 | }} 22 | 23 | var got []wanttype 24 | err := Eachx(inp, func(idx, val int) error { 25 | got = append(got, wanttype{idx: idx, val: val}) 26 | return nil 27 | }) 28 | if err != nil { 29 | t.Fatal(err) 30 | } 31 | if !reflect.DeepEqual(got, want) { 32 | t.Errorf("got %v, want %v", got, want) 33 | } 34 | 35 | e := errors.New("error") 36 | 37 | err = Eachx(inp, func(_, _ int) error { 38 | return e 39 | }) 40 | if !errors.Is(err, e) { 41 | t.Errorf("got %v, want error %v", err, e) 42 | } 43 | } 44 | -------------------------------------------------------------------------------- /slices/insert_test.go: -------------------------------------------------------------------------------- 1 | package slices 2 | 3 | import ( 4 | "fmt" 5 | "reflect" 6 | "testing" 7 | ) 8 | 9 | func TestInsert(t *testing.T) { 10 | cases := []struct { 11 | base, ins, want []int 12 | idx int 13 | }{{ 14 | base: nil, 15 | idx: 0, 16 | ins: []int{1, 2}, 17 | want: []int{1, 2}, 18 | }, { 19 | base: []int{1, 2, 3}, 20 | idx: 0, 21 | ins: []int{4}, 22 | want: []int{4, 1, 2, 3}, 23 | }, { 24 | base: []int{1, 2, 3}, 25 | idx: 0, 26 | ins: []int{4, 5}, 27 | want: []int{4, 5, 1, 2, 3}, 28 | }, { 29 | base: []int{1, 2, 3}, 30 | idx: 1, 31 | ins: []int{4, 5}, 32 | want: []int{1, 4, 5, 2, 3}, 33 | }, { 34 | base: []int{1, 2, 3}, 35 | idx: -1, 36 | ins: []int{4, 5}, 37 | want: []int{1, 2, 4, 5, 3}, 38 | }} 39 | 40 | for i, tc := range cases { 41 | 
t.Run(fmt.Sprintf("case_%02d", i+1), func(t *testing.T) { 42 | got := Insert(tc.base, tc.idx, tc.ins...) 43 | if !reflect.DeepEqual(got, tc.want) { 44 | t.Errorf("got %v, want %v", got, tc.want) 45 | } 46 | }) 47 | } 48 | } 49 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Bob Glickstein 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /slices/get_put_test.go: -------------------------------------------------------------------------------- 1 | package slices 2 | 3 | import ( 4 | "fmt" 5 | "reflect" 6 | "testing" 7 | ) 8 | 9 | func TestGet(t *testing.T) { 10 | cases := []struct { 11 | inp []int 12 | idx int 13 | want int 14 | }{{ 15 | inp: []int{4, 5, 6}, 16 | idx: 0, 17 | want: 4, 18 | }, { 19 | inp: []int{4, 5, 6}, 20 | idx: -1, 21 | want: 6, 22 | }} 23 | 24 | for i, tc := range cases { 25 | t.Run(fmt.Sprintf("case_%02d", i+1), func(t *testing.T) { 26 | got := Get(tc.inp, tc.idx) 27 | if got != tc.want { 28 | t.Errorf("got %d, want %d", got, tc.want) 29 | } 30 | }) 31 | } 32 | } 33 | 34 | func TestPut(t *testing.T) { 35 | cases := []struct { 36 | inp []int 37 | idx, val int 38 | want []int 39 | }{{ 40 | inp: []int{4, 5, 6}, 41 | idx: 0, 42 | val: 7, 43 | want: []int{7, 5, 6}, 44 | }, { 45 | inp: []int{4, 5, 6}, 46 | idx: -1, 47 | val: 7, 48 | want: []int{4, 5, 7}, 49 | }} 50 | 51 | for i, tc := range cases { 52 | t.Run(fmt.Sprintf("case_%02d", i+1), func(t *testing.T) { 53 | Put(tc.inp, tc.idx, tc.val) 54 | if !reflect.DeepEqual(tc.inp, tc.want) { 55 | t.Errorf("got %v, want %v", tc.inp, tc.want) 56 | } 57 | }) 58 | } 59 | } 60 | -------------------------------------------------------------------------------- /.github/workflows/go.yml: -------------------------------------------------------------------------------- 1 | name: Tests 2 | 3 | on: 4 | push: 5 | branches: [ master ] 6 | pull_request: 7 | branches: [ master ] 8 | 9 | jobs: 10 | test: 11 | runs-on: ubuntu-latest 12 | steps: 13 | - uses: actions/checkout@v3 14 | 15 | - name: Set up Go 16 | uses: actions/setup-go@v3 17 | with: 18 | go-version: '1.23' 19 | 20 | - name: golangci-lint 21 | uses: golangci/golangci-lint-action@v3 22 | with: 23 | # Optional: version of golangci-lint to use in form of v1.2 or v1.2.3 or `latest` to use the latest version 24 | 
version: latest 25 | 26 | # Optional: golangci-lint command line arguments. 27 | # args: --issues-exit-code=0 28 | 29 | # Optional: show only new issues if it's a pull request. The default value is `false`. 30 | # only-new-issues: true 31 | 32 | - name: Unit tests 33 | run: go test -v -coverprofile=cover.out ./... 34 | 35 | - name: Send coverage 36 | uses: shogo82148/actions-goveralls@v1 37 | with: 38 | path-to-profile: cover.out 39 | 40 | - name: Modver 41 | if: ${{ github.event_name == 'pull_request' }} 42 | uses: bobg/modver@v2.11.0 43 | with: 44 | github_token: ${{ secrets.GITHUB_TOKEN }} 45 | pull_request_url: https://github.com/${{ github.repository }}/pull/${{ github.event.number }} 46 | -------------------------------------------------------------------------------- /set/example_test.go: -------------------------------------------------------------------------------- 1 | package set_test 2 | 3 | import ( 4 | "fmt" 5 | 6 | "github.com/bobg/go-generics/v4/set" 7 | ) 8 | 9 | func ExampleDiff() { 10 | var ( 11 | s1 = set.New(1, 2, 3, 4, 5) 12 | s2 = set.New(4, 5, 6, 7, 8) 13 | diff = set.Diff(s1, s2) 14 | ) 15 | diff.Each(func(val int) { fmt.Println(val) }) 16 | // Unordered output: 17 | // 1 18 | // 2 19 | // 3 20 | } 21 | 22 | func ExampleIntersect() { 23 | var ( 24 | s1 = set.New(1, 2, 3, 4, 5) 25 | s2 = set.New(4, 5, 6, 7, 8) 26 | inter = set.Intersect(s1, s2) 27 | ) 28 | inter.Each(func(val int) { fmt.Println(val) }) 29 | // Unordered output: 30 | // 4 31 | // 5 32 | } 33 | 34 | func ExampleUnion() { 35 | var ( 36 | s1 = set.New(1, 2, 3, 4, 5) 37 | s2 = set.New(4, 5, 6, 7, 8) 38 | union = set.Union(s1, s2) 39 | ) 40 | union.Each(func(val int) { fmt.Println(val) }) 41 | // Unordered output: 42 | // 1 43 | // 2 44 | // 3 45 | // 4 46 | // 5 47 | // 6 48 | // 7 49 | // 8 50 | } 51 | 52 | func ExampleOf() { 53 | s := set.New(1, 2, 3, 4, 5) 54 | fmt.Println("1 is in the set?", s.Has(1)) 55 | fmt.Println("100 is in the set?", s.Has(100)) 56 | s.Add(100) 57 | 
fmt.Println("100 is in the set?", s.Has(100)) 58 | fmt.Println("set size is", s.Len()) 59 | s.Del(100) 60 | fmt.Println("100 is in the set?", s.Has(100)) 61 | fmt.Println("set size is", s.Len()) 62 | // Output: 63 | // 1 is in the set? true 64 | // 100 is in the set? false 65 | // 100 is in the set? true 66 | // set size is 6 67 | // 100 is in the set? false 68 | // set size is 5 69 | } 70 | -------------------------------------------------------------------------------- /slices/remove_test.go: -------------------------------------------------------------------------------- 1 | package slices 2 | 3 | import ( 4 | "fmt" 5 | "reflect" 6 | "testing" 7 | ) 8 | 9 | func TestRemoveN(t *testing.T) { 10 | cases := []struct { 11 | inp []int 12 | idx, n int 13 | want []int 14 | }{{ 15 | inp: nil, 16 | idx: 0, 17 | n: 0, 18 | want: nil, 19 | }, { 20 | inp: []int{4, 5, 6, 7}, 21 | idx: 1, 22 | n: 1, 23 | want: []int{4, 6, 7}, 24 | }, { 25 | inp: []int{4, 5, 6, 7}, 26 | idx: 1, 27 | n: 2, 28 | want: []int{4, 7}, 29 | }, { 30 | inp: []int{4, 5, 6, 7}, 31 | idx: -2, 32 | n: 1, 33 | want: []int{4, 5, 7}, 34 | }, { 35 | inp: []int{4, 5, 6, 7}, 36 | idx: -2, 37 | n: 2, 38 | want: []int{4, 5}, 39 | }} 40 | 41 | for i, tc := range cases { 42 | t.Run(fmt.Sprintf("case_%02d", i+1), func(t *testing.T) { 43 | got := RemoveN(tc.inp, tc.idx, tc.n) 44 | if !reflect.DeepEqual(got, tc.want) { 45 | t.Errorf("got %v, want %v", got, tc.want) 46 | } 47 | }) 48 | } 49 | } 50 | 51 | func TestRemoveTo(t *testing.T) { 52 | cases := []struct { 53 | inp []int 54 | from, to int 55 | want []int 56 | }{{ 57 | inp: nil, 58 | from: 0, 59 | to: 0, 60 | want: nil, 61 | }, { 62 | inp: []int{4, 5, 6, 7}, 63 | from: 1, 64 | to: 2, 65 | want: []int{4, 6, 7}, 66 | }, { 67 | inp: []int{4, 5, 6, 7}, 68 | from: 1, 69 | to: 3, 70 | want: []int{4, 7}, 71 | }, { 72 | inp: []int{4, 5, 6, 7}, 73 | from: -2, 74 | to: -1, 75 | want: []int{4, 5, 7}, 76 | }, { 77 | inp: []int{4, 5, 6, 7}, 78 | from: -2, 79 | to: 0, 80 | want: 
[]int{4, 5}, 81 | }} 82 | 83 | for i, tc := range cases { 84 | t.Run(fmt.Sprintf("case_%02d", i+1), func(t *testing.T) { 85 | got := RemoveTo(tc.inp, tc.from, tc.to) 86 | if !reflect.DeepEqual(got, tc.want) { 87 | t.Errorf("got %v, want %v", got, tc.want) 88 | } 89 | }) 90 | } 91 | } 92 | -------------------------------------------------------------------------------- /slices/replace_test.go: -------------------------------------------------------------------------------- 1 | package slices 2 | 3 | import ( 4 | "fmt" 5 | "reflect" 6 | "testing" 7 | ) 8 | 9 | func TestReplaceN(t *testing.T) { 10 | cases := []struct { 11 | inp []int 12 | idx, n int 13 | vals []int 14 | want []int 15 | }{{ 16 | inp: nil, 17 | idx: 0, 18 | n: 0, 19 | vals: []int{4}, 20 | want: []int{4}, 21 | }, { 22 | inp: []int{1, 2, 3}, 23 | idx: 0, 24 | n: 0, 25 | vals: []int{4}, 26 | want: []int{4, 1, 2, 3}, 27 | }, { 28 | inp: []int{1, 2, 3}, 29 | idx: 1, 30 | n: 0, 31 | vals: []int{4}, 32 | want: []int{1, 4, 2, 3}, 33 | }, { 34 | inp: []int{1, 2, 3}, 35 | idx: 1, 36 | n: 1, 37 | vals: []int{4}, 38 | want: []int{1, 4, 3}, 39 | }, { 40 | inp: []int{1, 2, 3}, 41 | idx: -2, 42 | n: 1, 43 | vals: []int{4}, 44 | want: []int{1, 4, 3}, 45 | }, { 46 | inp: []int{1, 2, 3}, 47 | idx: 1, 48 | n: 2, 49 | vals: []int{4}, 50 | want: []int{1, 4}, 51 | }, { 52 | inp: []int{1, 2, 3}, 53 | idx: 0, 54 | n: 3, 55 | vals: []int{4, 5}, 56 | want: []int{4, 5}, 57 | }, { 58 | inp: []int{1, 2, 3}, 59 | idx: -2, 60 | n: 2, 61 | vals: []int{4}, 62 | want: []int{1, 4}, 63 | }} 64 | 65 | for i, tc := range cases { 66 | t.Run(fmt.Sprintf("case_%02d", i+1), func(t *testing.T) { 67 | got := ReplaceN(tc.inp, tc.idx, tc.n, tc.vals...) 
68 | if !reflect.DeepEqual(got, tc.want) { 69 | t.Errorf("got %v, want %v", got, tc.want) 70 | } 71 | }) 72 | } 73 | } 74 | 75 | func TestReplaceTo(t *testing.T) { 76 | cases := []struct { 77 | inp []int 78 | from, to int 79 | vals []int 80 | want []int 81 | }{{ 82 | inp: []int{1, 2, 3}, 83 | from: 1, 84 | to: 1, 85 | vals: []int{4}, 86 | want: []int{1, 4, 2, 3}, 87 | }, { 88 | inp: []int{1, 2, 3}, 89 | from: 1, 90 | to: 2, 91 | vals: []int{4}, 92 | want: []int{1, 4, 3}, 93 | }, { 94 | inp: []int{1, 2, 3}, 95 | from: -2, 96 | to: -1, 97 | vals: []int{4}, 98 | want: []int{1, 4, 3}, 99 | }, { 100 | inp: []int{1, 2, 3}, 101 | from: 1, 102 | to: 0, 103 | vals: []int{4}, 104 | want: []int{1, 4}, 105 | }, { 106 | inp: []int{1, 2, 3}, 107 | from: -2, 108 | to: 3, 109 | vals: []int{4}, 110 | want: []int{1, 4}, 111 | }} 112 | 113 | for i, tc := range cases { 114 | t.Run(fmt.Sprintf("case_%02d", i+1), func(t *testing.T) { 115 | got := ReplaceTo(tc.inp, tc.from, tc.to, tc.vals...) 116 | if !reflect.DeepEqual(got, tc.want) { 117 | t.Errorf("got %v, want %v", got, tc.want) 118 | } 119 | }) 120 | } 121 | } 122 | -------------------------------------------------------------------------------- /set/set_test.go: -------------------------------------------------------------------------------- 1 | package set 2 | 3 | import ( 4 | "reflect" 5 | "slices" 6 | "testing" 7 | ) 8 | 9 | func TestSet(t *testing.T) { 10 | s := New[int](1, 2, 3) 11 | s.Add(4, 5, 6) 12 | if s.Has(0) { 13 | t.Error("set should not contain 0") 14 | } 15 | if !s.Has(1) { 16 | t.Error("set should contain 1") 17 | } 18 | if s.Len() != 6 { 19 | t.Errorf("got len %d, want 6", s.Len()) 20 | } 21 | 22 | got := make(map[int]struct{}) 23 | s.Each(func(val int) { got[val] = struct{}{} }) 24 | if !reflect.DeepEqual(got, map[int]struct{}{1: {}, 2: {}, 3: {}, 4: {}, 5: {}, 6: {}}) { 25 | t.Errorf("got %v, want [1 2 3 4 5 6]", got) 26 | } 27 | 28 | s2 := New[int](5, 6, 7, 8) 29 | i := Intersect(s, s2) 30 | if 
!reflect.DeepEqual(i, Of[int](map[int]struct{}{5: {}, 6: {}})) { 31 | t.Errorf("got %v, want [5 6]", i) 32 | } 33 | i = Intersect(s2, nil) 34 | if i.Len() != 0 { 35 | t.Errorf("got %v, want []", i) 36 | } 37 | 38 | u := Union(s, s2, nil) 39 | if !reflect.DeepEqual(u, Of[int](map[int]struct{}{1: {}, 2: {}, 3: {}, 4: {}, 5: {}, 6: {}, 7: {}, 8: {}})) { 40 | t.Errorf("got %v, want [1 2 3 4 5 6 7 8]", u) 41 | } 42 | 43 | d := Diff(s, s2) 44 | if !reflect.DeepEqual(d, Of[int](map[int]struct{}{1: {}, 2: {}, 3: {}, 4: {}})) { 45 | t.Errorf("got %v, want [1 2 3 4]", d) 46 | } 47 | 48 | var ( 49 | it = s.All() 50 | m = make(map[int]struct{}) 51 | ) 52 | for val := range it { 53 | m[val] = struct{}{} 54 | } 55 | if !reflect.DeepEqual(m, map[int]struct{}{1: {}, 2: {}, 3: {}, 4: {}, 5: {}, 6: {}}) { 56 | t.Errorf("got %v, want [1 2 3 4 5 6]", m) 57 | } 58 | 59 | m = make(map[int]struct{}) 60 | for _, val := range s.Slice() { 61 | m[val] = struct{}{} 62 | } 63 | if !reflect.DeepEqual(m, map[int]struct{}{1: {}, 2: {}, 3: {}, 4: {}, 5: {}, 6: {}}) { 64 | t.Errorf("got %v, want [1 2 3 4 5 6]", m) 65 | } 66 | } 67 | 68 | func TestEqual(t *testing.T) { 69 | var ( 70 | a = New[int](1, 2, 3) 71 | b = New[int](1, 2, 3) 72 | c = New[int](1, 3, 5) 73 | d = New[int](1, 5, 9) 74 | ) 75 | if !a.Equal(b) { 76 | t.Error("got a != b") 77 | } 78 | if !b.Equal(a) { 79 | t.Error("got b != a") 80 | } 81 | if a.Equal(c) { 82 | t.Error("got a == c") 83 | } 84 | if c.Equal(a) { 85 | t.Error("got c == a") 86 | } 87 | if a.Equal(d) { 88 | t.Error("got a == d") 89 | } 90 | if d.Equal(a) { 91 | t.Error("got d == a") 92 | } 93 | } 94 | 95 | func TestCollect(t *testing.T) { 96 | var ( 97 | nums = slices.Values([]int{1, 2, 3, 4, 5, 6}) 98 | got = Collect(nums) 99 | want = New[int](1, 2, 3, 4, 5, 6) 100 | ) 101 | if !got.Equal(want) { 102 | t.Errorf("got %v, want %v", got, want) 103 | } 104 | } 105 | -------------------------------------------------------------------------------- /parallel/protect.go: 
-------------------------------------------------------------------------------- 1 | package parallel 2 | 3 | import ( 4 | "sync" 5 | ) 6 | 7 | // Protect offers safe concurrent access to a protected value. 8 | // It is a "share memory by communicating" alternative 9 | // to protecting the value with [sync.RWMutex]. 10 | // 11 | // The caller gets back three functions: 12 | // a reader for getting the protected value, 13 | // a writer for updating it, 14 | // and a closer for releasing resources 15 | // when no further reads or writes are needed. 16 | // 17 | // Any number of calls to the reader may run concurrently. 18 | // If T is a "reference type" (see below) 19 | // then the caller should not make any changes 20 | // to the value it receives from the reader. 21 | // 22 | // A call to the writer prevents other reader and writer calls from running until it is done. 23 | // It waits for pending calls to finish before it executes. 24 | // After a call to the writer, 25 | // future reader calls will receive the updated value. 26 | // 27 | // The closer should be called to release resources 28 | // when no more reader or writer calls are needed. 29 | // Calling any of the functions (reader, writer, or closer) 30 | // after a call to the closer may cause a panic. 31 | // 32 | // The term "reference type" here means a type 33 | // (such as pointer, slice, map, channel, function, and interface) 34 | // that allows a caller C 35 | // to make changes that will be visible to other callers 36 | // outside of C's scope. 37 | // In other words, 38 | // if the type is int and caller A does this: 39 | // 40 | // val := reader() 41 | // val++ 42 | // 43 | // it will not affect the value that caller B sees when it does its own call to reader(). 44 | // But if the type is *int and caller A does this: 45 | // 46 | // val := reader() 47 | // *val++ 48 | // 49 | // then the change in the pointed-to value _will_ be seen by caller B. 
50 | // 51 | // For more on the fuzzy concept of "reference types" in Go, 52 | // see https://github.com/go101/go101/wiki/About-the-terminology-%22reference-type%22-in-Go 53 | func Protect[T any](val T) (reader func() T, writer func(T), closer func()) { 54 | ch := make(chan rwRequest[T]) 55 | 56 | go func() { 57 | var wg sync.WaitGroup 58 | 59 | for req := range ch { 60 | req := req // Go loop var pitfall 61 | if req.r != nil { 62 | wg.Add(1) 63 | go func() { 64 | req.r <- val 65 | close(req.r) 66 | wg.Done() 67 | }() 68 | continue 69 | } 70 | 71 | wg.Wait() 72 | val = <-req.w 73 | } 74 | }() 75 | 76 | reader = func() T { 77 | valch := make(chan T, 1) 78 | ch <- rwRequest[T]{r: valch} 79 | return <-valch 80 | } 81 | writer = func(val T) { 82 | valch := make(chan T, 1) 83 | valch <- val 84 | ch <- rwRequest[T]{w: valch} 85 | close(valch) 86 | } 87 | closer = func() { 88 | close(ch) 89 | } 90 | 91 | return reader, writer, closer 92 | } 93 | 94 | type rwRequest[T any] struct { 95 | r chan<- T 96 | w <-chan T 97 | } 98 | -------------------------------------------------------------------------------- /slices/example_test.go: -------------------------------------------------------------------------------- 1 | package slices_test 2 | 3 | import ( 4 | "fmt" 5 | "sort" 6 | 7 | "github.com/bobg/go-generics/v4/slices" 8 | ) 9 | 10 | func ExampleGet() { 11 | var ( 12 | s = []int{1, 2, 3, 4} 13 | last = slices.Get(s, -1) 14 | ) 15 | fmt.Println(last) 16 | // Output: 4 17 | } 18 | 19 | func ExampleInsert() { 20 | var ( 21 | s1 = []int{10, 15, 16} 22 | s2 = slices.Insert(s1, 1, 11, 12, 13, 14) 23 | ) 24 | fmt.Println(s2) 25 | // Output: [10 11 12 13 14 15 16] 26 | } 27 | 28 | func ExampleReplaceN() { 29 | var ( 30 | s1 = []int{99, 0, 0, 0, 97} 31 | s2 = slices.ReplaceN(s1, 1, 3, 98) 32 | ) 33 | fmt.Println(s2) 34 | // Output: [99 98 97] 35 | } 36 | 37 | func ExampleReplaceTo() { 38 | var ( 39 | s1 = []int{99, 0, 0, 0, 97} 40 | s2 = slices.ReplaceTo(s1, 1, -1, 98) 41 | ) 42 | 
fmt.Println(s2) 43 | // Output: [99 98 97] 44 | } 45 | 46 | func ExampleRemoveN() { 47 | var ( 48 | s1 = []int{1, 2, 3, 4, 5} 49 | s2 = slices.RemoveN(s1, -2, 2) 50 | ) 51 | fmt.Println(s2) 52 | // Output: [1 2 3] 53 | } 54 | 55 | func ExampleRemoveTo() { 56 | var ( 57 | s1 = []int{1, 2, 3, 4, 5} 58 | s2 = slices.RemoveTo(s1, -2, 0) 59 | ) 60 | fmt.Println(s2) 61 | // Output: [1 2 3] 62 | } 63 | 64 | func ExampleEachx() { 65 | s := []int{100, 200, 300} 66 | _ = slices.Eachx(s, func(idx, val int) error { 67 | fmt.Println(idx, val) 68 | return nil 69 | }) 70 | // Output: 71 | // 0 100 72 | // 1 200 73 | // 2 300 74 | } 75 | 76 | func ExampleMap() { 77 | var ( 78 | s1 = []int{1, 2, 3, 4, 5} 79 | s2 = slices.Map(s1, func(val int) string { return string([]byte{byte('a' + val - 1)}) }) 80 | ) 81 | fmt.Println(s2) 82 | // Output: [a b c d e] 83 | } 84 | 85 | func ExampleAccum() { 86 | var ( 87 | s = []int{1, 2, 3, 4, 5} 88 | sum = slices.Accum(s, func(a, b int) int { return a + b }) 89 | ) 90 | fmt.Println(sum) 91 | // Output: 15 92 | } 93 | 94 | func ExampleFilter() { 95 | var ( 96 | s = []int{1, 2, 3, 4, 5, 6, 7} 97 | evens = slices.Filter(s, func(val int) bool { return val%2 == 0 }) 98 | ) 99 | fmt.Println(evens) 100 | // Output: [2 4 6] 101 | } 102 | 103 | func ExampleGroup() { 104 | s := []int{1, 2, 3, 4, 5, 6, 7} 105 | groups := slices.Group(s, func(val int) string { 106 | if val%2 == 0 { 107 | return "even" 108 | } 109 | return "odd" 110 | }) 111 | 112 | for key, slice := range groups { 113 | fmt.Println(key, slice) 114 | } 115 | // Unordered output: 116 | // even [2 4 6] 117 | // odd [1 3 5 7] 118 | } 119 | 120 | func ExampleRotate() { 121 | s := []int{3, 4, 5, 1, 2} 122 | slices.Rotate(s, 2) 123 | fmt.Println(s) 124 | // Output: [1 2 3 4 5] 125 | } 126 | 127 | func ExampleKeyedSort() { 128 | var ( 129 | nums = []int{1, 2, 3, 4, 5} 130 | names = []string{"one", "two", "three", "four", "five"} 131 | ) 132 | 133 | // Sort the numbers in `nums` according to their 
names in `names`. 134 | slices.KeyedSort(nums, sort.StringSlice(names)) 135 | 136 | fmt.Println(nums) 137 | // Output: [5 4 1 3 2] 138 | } 139 | -------------------------------------------------------------------------------- /parallel/parallel_test.go: -------------------------------------------------------------------------------- 1 | package parallel 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "sync" 7 | "testing" 8 | 9 | "github.com/bobg/go-generics/v4/set" 10 | ) 11 | 12 | func TestValues(t *testing.T) { 13 | got, err := Values(context.Background(), 100, func(_ context.Context, n int) (int, error) { return n, nil }) 14 | if err != nil { 15 | t.Fatal(err) 16 | } 17 | if len(got) != 100 { 18 | t.Errorf("got len %d, want 100", len(got)) 19 | } 20 | for i := 0; i < 100; i++ { 21 | if got[i] != i { 22 | t.Errorf("got[%d] is %d, want %d", i, got[i], i) 23 | } 24 | } 25 | } 26 | 27 | func TestProducers(t *testing.T) { 28 | it, errptr := Producers(context.Background(), 10, func(_ context.Context, n int, send func(int) error) error { 29 | for i := 0; i < 10; i++ { 30 | err := send(10*n + i) 31 | if err != nil { 32 | return err 33 | } 34 | } 35 | return nil 36 | }) 37 | got := set.New[int]() 38 | for val := range it { 39 | got.Add(val) 40 | } 41 | if *errptr != nil { 42 | t.Fatal(*errptr) 43 | } 44 | if got.Len() != 100 { 45 | t.Errorf("got %d values, want 100", got.Len()) 46 | } 47 | for i := 0; i < 100; i++ { 48 | if !got.Has(i) { 49 | t.Errorf("%d missing from result", i) 50 | } 51 | } 52 | } 53 | 54 | func TestConsumers(t *testing.T) { 55 | var ( 56 | mu sync.Mutex 57 | got = set.New[int]() 58 | ) 59 | 60 | sendfn, closefn := Consumers(context.Background(), 10, func(_ context.Context, _, val int) error { 61 | mu.Lock() 62 | got.Add(val) 63 | mu.Unlock() 64 | return nil 65 | }) 66 | for i := 0; i < 100; i++ { 67 | err := sendfn(i) 68 | if err != nil { 69 | t.Fatal(err) 70 | } 71 | } 72 | err := closefn() 73 | if err != nil { 74 | t.Fatal(err) 75 | } 76 | if 
got.Len() != 100 { 77 | t.Errorf("got %d values, want 100", got.Len()) 78 | } 79 | for i := 0; i < 100; i++ { 80 | if !got.Has(i) { 81 | t.Errorf("%d missing from result", i) 82 | } 83 | } 84 | } 85 | 86 | func TestPool(t *testing.T) { 87 | var ( 88 | running, max int 89 | mu sync.Mutex 90 | cond = sync.NewCond(&mu) 91 | unblocked = make(map[int]bool) 92 | ) 93 | 94 | call := Pool(10, func(n int) (int, error) { 95 | mu.Lock() 96 | running++ 97 | if running > max { 98 | max = running 99 | } 100 | 101 | for !unblocked[n] { 102 | cond.Wait() 103 | } 104 | 105 | running-- 106 | mu.Unlock() 107 | 108 | return n, nil 109 | }) 110 | 111 | var ( 112 | errch = make(chan error, 100) 113 | wg sync.WaitGroup 114 | ) 115 | for i := 0; i < 100; i++ { 116 | wg.Add(1) 117 | go func() { 118 | got, err := call(i) 119 | if err != nil { 120 | errch <- err 121 | } 122 | if got != i { 123 | errch <- fmt.Errorf("got %d, want %d", got, i) 124 | } 125 | wg.Done() 126 | }() 127 | } 128 | 129 | for i := 0; i < 100; i++ { 130 | mu.Lock() 131 | unblocked[i] = true 132 | cond.Broadcast() 133 | mu.Unlock() 134 | } 135 | 136 | go func() { 137 | wg.Wait() 138 | close(errch) 139 | }() 140 | 141 | for err := range errch { 142 | t.Error(err) 143 | } 144 | 145 | if max > 10 { 146 | t.Errorf("max is %d, want <=10", max) 147 | } 148 | } 149 | -------------------------------------------------------------------------------- /parallel/example_test.go: -------------------------------------------------------------------------------- 1 | package parallel_test 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "sync" 7 | 8 | "github.com/bobg/go-generics/v4/parallel" 9 | ) 10 | 11 | func ExampleConsumers() { 12 | ctx := context.Background() 13 | 14 | // One of three goroutines prints incoming values. 15 | send, closefn := parallel.Consumers(ctx, 3, func(_ context.Context, _, val int) error { 16 | fmt.Println(val) 17 | return nil 18 | }) 19 | 20 | // Caller produces values. 
21 | for i := 1; i <= 5; i++ { 22 | err := send(i) 23 | if err != nil { 24 | panic(err) 25 | } 26 | } 27 | if err := closefn(); err != nil { 28 | panic(err) 29 | } 30 | // Unordered output: 31 | // 1 32 | // 2 33 | // 3 34 | // 4 35 | // 5 36 | } 37 | 38 | func ExampleProducers() { 39 | ctx := context.Background() 40 | 41 | // Five goroutines each produce their worker number and then exit. 42 | it, errptr := parallel.Producers(ctx, 5, func(_ context.Context, n int, send func(int) error) error { 43 | return send(n) 44 | }) 45 | 46 | // Caller consumes the produced values. 47 | for val := range it { 48 | fmt.Println(val) 49 | } 50 | if *errptr != nil { 51 | panic(*errptr) 52 | } 53 | // Unordered output: 54 | // 0 55 | // 1 56 | // 2 57 | // 3 58 | // 4 59 | } 60 | 61 | func ExampleValues() { 62 | ctx := context.Background() 63 | 64 | // Five goroutines, each placing its worker number in the corresponding slot of the result slice. 65 | values, err := parallel.Values(ctx, 5, func(_ context.Context, n int) (int, error) { 66 | return n, nil 67 | }) 68 | if err != nil { 69 | panic(err) 70 | } 71 | fmt.Println(values) 72 | // Output: 73 | // [0 1 2 3 4] 74 | } 75 | 76 | func ExamplePool() { 77 | // Three workers available, each negating its input. 78 | pool := parallel.Pool(3, func(n int) (int, error) { 79 | return -n, nil 80 | }) 81 | 82 | var wg sync.WaitGroup 83 | 84 | // Ten goroutines requesting work from those three workers. 
85 | for i := 1; i <= 10; i++ { 86 | wg.Add(1) 87 | go func() { 88 | neg, err := pool(i) 89 | if err != nil { 90 | panic(err) 91 | } 92 | fmt.Println(neg) 93 | wg.Done() 94 | }() 95 | } 96 | 97 | wg.Wait() 98 | 99 | // Unordered output: 100 | // -1 101 | // -2 102 | // -3 103 | // -4 104 | // -5 105 | // -6 106 | // -7 107 | // -8 108 | // -9 109 | // -10 110 | } 111 | 112 | func ExampleProtect() { 113 | // A caller is supplied with a reader and a writer 114 | // for purposes of accessing and updating the protected value safely 115 | // (in this case an int, initially 4). 116 | reader, writer, closer := parallel.Protect(4) 117 | defer closer() 118 | 119 | // Call the reader in three concurrent goroutines, each printing the protected value. 120 | var wg sync.WaitGroup 121 | for i := 0; i < 3; i++ { 122 | wg.Add(1) 123 | go func() { 124 | fmt.Println(reader()) 125 | wg.Done() 126 | }() 127 | } 128 | wg.Wait() 129 | 130 | // Increment the protected value. 131 | writer(reader() + 1) 132 | 133 | // Call the reader in three concurrent goroutines, each printing the protected value. 
134 | for i := 0; i < 3; i++ { 135 | wg.Add(1) 136 | go func() { 137 | fmt.Println(reader()) 138 | wg.Done() 139 | }() 140 | } 141 | wg.Wait() 142 | 143 | // Output: 144 | // 4 145 | // 4 146 | // 4 147 | // 5 148 | // 5 149 | // 5 150 | } 151 | -------------------------------------------------------------------------------- /Readme.md: -------------------------------------------------------------------------------- 1 | # Go-generics - Generic slice, map, set, iterator, and goroutine utilities for Go 2 | 3 | [![Go Reference](https://pkg.go.dev/badge/github.com/bobg/go-generics/v4.svg)](https://pkg.go.dev/github.com/bobg/go-generics/v4) 4 | [![Go Report Card](https://goreportcard.com/badge/github.com/bobg/go-generics/v4)](https://goreportcard.com/report/github.com/bobg/go-generics/v4) 5 | [![Tests](https://github.com/bobg/go-generics/actions/workflows/go.yml/badge.svg)](https://github.com/bobg/go-generics/actions/workflows/go.yml) 6 | [![Coverage Status](https://coveralls.io/repos/github/bobg/go-generics/badge.svg?branch=master)](https://coveralls.io/github/bobg/go-generics?branch=master) 7 | [![Mentioned in Awesome Go](https://awesome.re/mentioned-badge.svg)](https://github.com/avelino/awesome-go) 8 | 9 | This is go-generics, 10 | a collection of typesafe generic utilities 11 | for slices, sets, and goroutine patterns in Go. 12 | 13 | # Compatibility note 14 | 15 | This is version 4 of this library, 16 | for the release of Go 1.23. 17 | 18 | Earlier versions of this library included a package, 19 | `iter`, 20 | that defined an iterator type over several types of containers, 21 | and functions for operating with iterators. 22 | However, Go 1.23 defines its own, better iterator mechanism 23 | via the new “range over function” language feature, 24 | plus [a new standard-library package](https://pkg.go.dev/iter) also called `iter`. 25 | This version of the go-generics library therefore does away with its `iter` package. 
26 | The handy functions that `iter` contained for working with iterators 27 | (`Filter`, `Map`, `FirstN`, and many more) 28 | can now be found in the [github.com/bobg/seqs](https://pkg.go.dev/github.com/bobg/seqs) library, 29 | adapted for Go 1.23 iterators. 30 | 31 | (This version of go-generics might have kept `iter` as a drop-in replacement for the standard-library package, 32 | but could not, because the standard library defines two generic types, 33 | `iter.Seq[V]` and `iter.Seq2[K, V]`, 34 | that go-generics would have had to reference using type aliases, 35 | and as of Go 1.23, Go type aliases [could not yet have type parameters](https://github.com/golang/go/issues/46477#issuecomment-2101270785).) 36 | 37 | Earlier versions of this library included combinatorial operations in the `slices` package. 38 | Those have now been moved to their own library, 39 | [github.com/bobg/combo](https://pkg.go.dev/github.com/bobg/combo). 40 | 41 | Earlier versions of this library included a `maps` package, 42 | which was a drop-in replacement for the stdlib `maps` 43 | (added in Go 1.21) 44 | plus a few convenience functions. 45 | With the advent of Go 1.23 iterators, 46 | those few convenience functions are mostly redundant 47 | (and a couple of them — `Keys` and `Values` — conflict with new functions in the standard library), 48 | so `maps` has been removed. 49 | 50 | Earlier versions of this library included a `Find` method on the `set.Of[T]` type, 51 | for finding some element in the set that satisfies a given predicate. 52 | This method has been removed in favor of composing operations with functions from [github.com/bobg/seqs](https://pkg.go.dev/github.com/bobg/seqs). 53 | For example, `s.Find(pred)` can now be written as `seqs.First(seqs.Filter(s.All(), pred))`.
54 | 55 | # Slices 56 | 57 | The `slices` package is useful in three ways: 58 | 59 | - It encapsulates hard-to-remember Go idioms for inserting and removing elements to and from the middle of a slice; 60 | - It adds the ability to index from the right end of a slice using negative integers 61 | (for example, Get(s, -1) is the same as s[len(s)-1]); and 62 | - It includes `Map`, `Filter`, and a few other such functions 63 | for processing slice elements with callbacks. 64 | 65 | The `slices` package is a drop-in replacement 66 | for the `slices` package 67 | added to the Go stdlib 68 | in [Go 1.21](https://go.dev/doc/go1.21#slices). 69 | There is one difference: 70 | this version of `slices` 71 | allows the index value passed to `Insert`, `Delete`, and `Replace` 72 | to be negative for counting backward from the end of the slice. 73 | 74 | # Set 75 | 76 | The `set` package implements the usual collection of functions for sets: 77 | `Intersect`, `Union`, `Diff`, etc., 78 | as well as member functions for adding and removing items, 79 | checking for the presence of items, 80 | and iterating over items. 81 | 82 | # Parallel 83 | 84 | The `parallel` package contains functions for coordinating parallel workers: 85 | 86 | - `Consumers` manages a set of N workers consuming a stream of values produced by the caller. 87 | - `Producers` manages a set of N workers producing a stream of values consumed by the caller. 88 | - `Values` concurrently produces a set of N values. 89 | - `Pool` manages access to a pool of concurrent workers. 90 | - `Protect` manages concurrent access to a protected data value. 91 | -------------------------------------------------------------------------------- /set/set.go: -------------------------------------------------------------------------------- 1 | // Package set contains generic typesafe set operations. 2 | package set 3 | 4 | import ( 5 | "iter" 6 | "maps" 7 | "slices" 8 | ) 9 | 10 | // Of is a set of elements of type T. 
11 | // It is called "Of" so that when qualified with this package name 12 | // and instantiated with a member type, 13 | // it reads naturally: e.g., set.Of[int]. 14 | // 15 | // The zero value of Of is not safe for use. 16 | // Create one with New instead. 17 | type Of[T comparable] map[T]struct{} 18 | 19 | // New produces a new set containing the given values. 20 | func New[T comparable](vals ...T) Of[T] { 21 | s := Of[T](make(map[T]struct{})) 22 | for _, val := range vals { 23 | s.Add(val) 24 | } 25 | return Of[T](s) 26 | } 27 | 28 | // Collect collects the members of the given sequence into a new set. 29 | func Collect[T comparable](inp iter.Seq[T]) Of[T] { 30 | s := New[T]() 31 | s.AddSeq(inp) 32 | return s 33 | } 34 | 35 | // Add adds the given values to the set. 36 | // Items already present in the set are silently ignored. 37 | func (s Of[T]) Add(vals ...T) { 38 | for _, val := range vals { 39 | s[val] = struct{}{} 40 | } 41 | } 42 | 43 | // AddSeq adds the members of the given sequence to the set. 44 | func (s Of[T]) AddSeq(inp iter.Seq[T]) { 45 | for val := range inp { 46 | s.Add(val) 47 | } 48 | } 49 | 50 | // Has tells whether the given value is in the set. 51 | // The set may be nil. 52 | func (s Of[T]) Has(val T) bool { 53 | _, ok := s[val] 54 | return ok 55 | } 56 | 57 | // Del removes the given items from the set. 58 | // Items already absent from the set are silently ignored. 59 | func (s Of[T]) Del(vals ...T) { 60 | for _, val := range vals { 61 | delete(s, val) 62 | } 63 | } 64 | 65 | // Len tells the number of distinct values in the set. 66 | // The set may be nil. 67 | func (s Of[T]) Len() int { 68 | return len(s) 69 | } 70 | 71 | // Equal tests whether the set has the same membership as another. 72 | // Either set may be nil. 
73 | func (s Of[T]) Equal(other Of[T]) bool { 74 | if len(s) != len(other) { 75 | return false 76 | } 77 | for val := range s { 78 | if !other.Has(val) { 79 | return false 80 | } 81 | } 82 | return true 83 | } 84 | 85 | // Each calls a function on each element of the set in an indeterminate order. 86 | // It is safe to add and remove items during a call to Each, 87 | // but that can affect the sequence of values seen later during the same Each call. 88 | // The set may be nil. 89 | func (s Of[T]) Each(f func(T)) { 90 | _ = s.Eachx(func(val T) error { 91 | f(val) 92 | return nil 93 | }) 94 | } 95 | 96 | // Eachx calls a function on each element of the set in an indeterminate order. 97 | // It is safe to add and remove items during a call to Each, 98 | // but that can affect the sequence of values seen later during the same Eachx call. 99 | // The set may be nil. 100 | // If the function returns an error, 101 | // Eachx stops and returns that error. 102 | func (s Of[T]) Eachx(f func(T) error) error { 103 | for val := range s { 104 | err := f(val) 105 | if err != nil { 106 | return err 107 | } 108 | } 109 | return nil 110 | } 111 | 112 | // All produces an iterator over the members of the set, 113 | // in an indeterminate order. 114 | // The set may be nil. 115 | func (s Of[T]) All() iter.Seq[T] { 116 | return maps.Keys(s) 117 | } 118 | 119 | // Slice produces a new slice of the elements in the set. 120 | // The slice is in an indeterminate order. 121 | func (s Of[T]) Slice() []T { 122 | if s.Len() == 0 { 123 | return nil 124 | } 125 | return slices.Collect(s.All()) 126 | } 127 | 128 | // Intersect produces a new set containing only items that appear in all the given sets. 129 | // The input may include nils, 130 | // representing empty sets 131 | // and therefore producing an empty (but non-nil) intersection. 
132 | func Intersect[T comparable](sets ...Of[T]) Of[T] { 133 | result := New[T]() 134 | if len(sets) == 0 { 135 | return result 136 | } 137 | for _, s := range sets { 138 | if s == nil { 139 | return result 140 | } 141 | } 142 | sets[0].Each(func(val T) { 143 | for _, s := range sets[1:] { 144 | if !s.Has(val) { 145 | return 146 | } 147 | } 148 | result.Add(val) 149 | }) 150 | return result 151 | } 152 | 153 | // Union produces a new set containing all the items in all the given sets. 154 | // The input may include nils, 155 | // representing empty sets. 156 | // The result is never nil (but may be empty). 157 | func Union[T comparable](sets ...Of[T]) Of[T] { 158 | result := New[T]() 159 | for _, s := range sets { 160 | if s == nil { 161 | continue 162 | } 163 | s.Each(func(val T) { result.Add(val) }) 164 | } 165 | return result 166 | } 167 | 168 | // Diff produces a new set containing the items in s1 that are not also in s2. 169 | // Either set may be nil. 170 | // The result is never nil (but may be empty). 171 | func Diff[T comparable](s1, s2 Of[T]) Of[T] { 172 | s := New[T]() 173 | s1.Each(func(val T) { 174 | if !s2.Has(val) { 175 | s.Add(val) 176 | } 177 | }) 178 | return s 179 | } 180 | -------------------------------------------------------------------------------- /parallel/parallel.go: -------------------------------------------------------------------------------- 1 | // Package parallel contains generic typesafe functions to manage concurrent logic of various kinds. 2 | package parallel 3 | 4 | import ( 5 | "context" 6 | "fmt" 7 | "iter" 8 | "sync" 9 | 10 | "golang.org/x/sync/errgroup" 11 | ) 12 | 13 | // Error is an error type for wrapping errors returned from worker goroutines. 14 | // It contains the worker number of the goroutine that produced the error. 
15 | type Error struct { 16 | N int 17 | Err error 18 | } 19 | 20 | func (e Error) Error() string { 21 | return fmt.Sprintf("in goroutine %d: %s", e.N, e.Err) 22 | } 23 | 24 | func (e Error) Unwrap() error { 25 | return e.Err 26 | } 27 | 28 | // Values produces a slice of n values using n parallel workers each running the function f. 29 | // 30 | // Each worker receives its worker number (in the range 0 through n-1). 31 | // 32 | // An error from any worker cancels them all. 33 | // The first error is returned to the caller. 34 | // 35 | // The resulting slice has length n. 36 | // The value at position i comes from worker i. 37 | func Values[T any](ctx context.Context, n int, f func(context.Context, int) (T, error)) ([]T, error) { 38 | g, ctx := errgroup.WithContext(ctx) 39 | result := make([]T, n) 40 | 41 | for i := 0; i < n; i++ { 42 | g.Go(func() error { 43 | val, err := f(ctx, i) 44 | result[i] = val 45 | if err != nil { 46 | return Error{N: i, Err: err} 47 | } 48 | return nil 49 | }) 50 | } 51 | 52 | err := g.Wait() 53 | return result, err 54 | } 55 | 56 | // Producers launches n parallel workers each running the function f. 57 | // 58 | // Each worker receives its worker number 59 | // (in the range 0 through n-1) 60 | // and a callback to use for producing a value. 61 | // If the callback returns an error, 62 | // the worker should exit with that error. 63 | // 64 | // The callback that the worker uses to produce a value may block 65 | // until the caller is able to consume the value. 66 | // 67 | // An error from any worker cancels them all. 68 | // 69 | // The caller gets an iterator over the values produced 70 | // and a non-nil pointer to an error. 71 | // The caller may dereference the error pointer to see if any worker failed, 72 | // but not before the iterator has been fully consumed. 73 | // The error (if there is one) is of type [Error], 74 | // whose N field indicates which worker failed. 
75 | func Producers[T any](ctx context.Context, n int, f func(context.Context, int, func(T) error) error) (iter.Seq[T], *error) { 76 | ch := make(chan T) 77 | g, innerCtx := errgroup.WithContext(ctx) 78 | 79 | for i := 0; i < n; i++ { 80 | i := i 81 | g.Go(func() error { 82 | err := f(innerCtx, i, func(val T) error { 83 | select { 84 | case <-innerCtx.Done(): 85 | return innerCtx.Err() 86 | case ch <- val: 87 | return nil 88 | } 89 | }) 90 | if err != nil { 91 | err = Error{N: i, Err: err} 92 | } 93 | return err 94 | }) 95 | } 96 | 97 | var err error 98 | 99 | go func() { 100 | err = g.Wait() 101 | close(ch) 102 | }() 103 | 104 | // This could be FromSeq(ch), 105 | // but that would introduce a circular dependency on github.com/bobg/seqs. 106 | fromSeq := func(yield func(T) bool) { 107 | for x := range ch { 108 | if !yield(x) { 109 | return 110 | } 111 | } 112 | } 113 | 114 | return fromSeq, &err 115 | } 116 | 117 | // Consumers launches n parallel workers each consuming values supplied by the caller. 118 | // 119 | // When a value is available, 120 | // an available worker calls the function f to consume it. 121 | // This callback receives the worker's number 122 | // (in the range 0 through n-1) 123 | // and the value. 124 | // 125 | // The caller receives two callbacks: 126 | // one for sending a value to the workers via an internal channel, 127 | // and one for closing that channel, 128 | // signaling the end of input and causing the workers to exit normally. 129 | // 130 | // The value-sending callback may block until a worker is available to consume the value. 131 | // 132 | // An error from any worker cancels them all. 133 | // This error is returned from the close-channel callback. 134 | // After any error, the value-sending callback will return an error. 135 | // (Not the original error, however. 136 | // For that, the caller should still invoke the close callback.) 
137 | func Consumers[T any](ctx context.Context, n int, f func(context.Context, int, T) error) (func(T) error, func() error) { 138 | ch := make(chan T, n) 139 | 140 | g, ctx := errgroup.WithContext(ctx) 141 | 142 | for i := 0; i < n; i++ { 143 | i := i 144 | g.Go(func() error { 145 | for { 146 | select { 147 | case <-ctx.Done(): 148 | return Error{N: i, Err: ctx.Err()} 149 | case val, ok := <-ch: 150 | if !ok { 151 | return nil 152 | } 153 | if err := f(ctx, i, val); err != nil { 154 | return Error{N: i, Err: err} 155 | } 156 | } 157 | } 158 | }) 159 | } 160 | 161 | sendfn := func(val T) error { 162 | select { 163 | case <-ctx.Done(): 164 | return ctx.Err() 165 | case ch <- val: 166 | return nil 167 | } 168 | } 169 | 170 | closefn := func() error { 171 | close(ch) 172 | return g.Wait() 173 | } 174 | 175 | return sendfn, closefn 176 | } 177 | 178 | // Pool permits up to n concurrent calls to a function f. 179 | // The caller receives a callback for requesting a worker from this pool. 180 | // When no worker is available, 181 | // the callback blocks until one becomes available. 182 | // Then it invokes f and returns the result. 183 | // 184 | // Each call of the callback is synchronous. 185 | // Any desired concurrency is the responsibility of the caller. 
186 | func Pool[T, U any](n int, f func(T) (U, error)) func(T) (U, error) { 187 | var ( 188 | running int 189 | mu sync.Mutex 190 | cond = sync.NewCond(&mu) 191 | ) 192 | return func(val T) (U, error) { 193 | mu.Lock() 194 | for running >= n { 195 | cond.Wait() 196 | } 197 | running++ 198 | mu.Unlock() 199 | 200 | result, err := f(val) 201 | 202 | mu.Lock() 203 | running-- 204 | cond.Signal() 205 | mu.Unlock() 206 | 207 | return result, err 208 | } 209 | } 210 | -------------------------------------------------------------------------------- /slices/slice_test.go: -------------------------------------------------------------------------------- 1 | package slices 2 | 3 | import ( 4 | "fmt" 5 | "reflect" 6 | "slices" 7 | "testing" 8 | ) 9 | 10 | func TestPrefix(t *testing.T) { 11 | cases := []struct { 12 | inp []int 13 | idx int 14 | want []int 15 | }{{ 16 | inp: []int{4, 5, 6}, 17 | idx: 0, 18 | want: []int{}, 19 | }, { 20 | inp: []int{4, 5, 6}, 21 | idx: 1, 22 | want: []int{4}, 23 | }, { 24 | inp: []int{4, 5, 6}, 25 | idx: 2, 26 | want: []int{4, 5}, 27 | }, { 28 | inp: []int{4, 5, 6}, 29 | idx: -1, 30 | want: []int{4, 5}, 31 | }, { 32 | inp: []int{4, 5, 6}, 33 | idx: -2, 34 | want: []int{4}, 35 | }} 36 | 37 | for i, tc := range cases { 38 | t.Run(fmt.Sprintf("case_%02d", i+1), func(t *testing.T) { 39 | got := Prefix(tc.inp, tc.idx) 40 | if !reflect.DeepEqual(got, tc.want) { 41 | t.Errorf("got %v, want %v", got, tc.want) 42 | } 43 | }) 44 | } 45 | } 46 | 47 | func TestSuffix(t *testing.T) { 48 | cases := []struct { 49 | inp []int 50 | idx int 51 | want []int 52 | }{{ 53 | inp: []int{4, 5, 6}, 54 | idx: 0, 55 | want: []int{4, 5, 6}, 56 | }, { 57 | inp: []int{4, 5, 6}, 58 | idx: 1, 59 | want: []int{5, 6}, 60 | }, { 61 | inp: []int{4, 5, 6}, 62 | idx: 2, 63 | want: []int{6}, 64 | }, { 65 | inp: []int{4, 5, 6}, 66 | idx: -1, 67 | want: []int{6}, 68 | }, { 69 | inp: []int{4, 5, 6}, 70 | idx: -2, 71 | want: []int{5, 6}, 72 | }} 73 | 74 | for i, tc := range cases { 75 | 
t.Run(fmt.Sprintf("case_%02d", i+1), func(t *testing.T) { 76 | got := Suffix(tc.inp, tc.idx) 77 | if !reflect.DeepEqual(got, tc.want) { 78 | t.Errorf("got %v, want %v", got, tc.want) 79 | } 80 | }) 81 | } 82 | } 83 | 84 | func isEven(n int) bool { return n%2 == 0 } 85 | 86 | func TestPrefixFunc(t *testing.T) { 87 | cases := []struct { 88 | inp, want []int 89 | }{{ 90 | inp: nil, 91 | want: nil, 92 | }, { 93 | inp: []int{4}, 94 | want: []int{4}, 95 | }, { 96 | inp: []int{5}, 97 | want: nil, 98 | }, { 99 | inp: []int{4, 5, 6, 7}, 100 | want: []int{4}, 101 | }, { 102 | inp: []int{4, 6}, 103 | want: []int{4, 6}, 104 | }, { 105 | inp: []int{4, 6, 5, 7}, 106 | want: []int{4, 6}, 107 | }, { 108 | inp: []int{7, 5, 6, 4}, 109 | want: nil, 110 | }} 111 | 112 | for i, tc := range cases { 113 | t.Run(fmt.Sprintf("case_%02d", i+1), func(t *testing.T) { 114 | got := PrefixFunc(tc.inp, isEven) 115 | if !slices.Equal(got, tc.want) { 116 | t.Errorf("got %v, want %v", got, tc.want) 117 | } 118 | }) 119 | } 120 | } 121 | 122 | func TestSuffixFunc(t *testing.T) { 123 | cases := []struct { 124 | inp, want []int 125 | }{{ 126 | inp: nil, 127 | want: nil, 128 | }, { 129 | inp: []int{4}, 130 | want: []int{4}, 131 | }, { 132 | inp: []int{5}, 133 | want: nil, 134 | }, { 135 | inp: []int{4, 5, 6, 7}, 136 | want: nil, 137 | }, { 138 | inp: []int{4, 6}, 139 | want: []int{4, 6}, 140 | }, { 141 | inp: []int{7, 6, 5, 4}, 142 | want: []int{4}, 143 | }} 144 | 145 | for i, tc := range cases { 146 | t.Run(fmt.Sprintf("case_%02d", i+1), func(t *testing.T) { 147 | got := SuffixFunc(tc.inp, isEven) 148 | if !slices.Equal(got, tc.want) { 149 | t.Errorf("got %v, want %v", got, tc.want) 150 | } 151 | }) 152 | } 153 | } 154 | 155 | func TestRindex(t *testing.T) { 156 | cases := []struct { 157 | inp []int 158 | want int 159 | }{{ 160 | inp: nil, 161 | want: -1, 162 | }, { 163 | inp: []int{4, 5, 6, 7}, 164 | want: 1, 165 | }, { 166 | inp: []int{4, 5, 6, 7, 5}, 167 | want: 4, 168 | }, { 169 | inp: []int{6, 7, 
8}, 170 | want: -1, 171 | }, { 172 | inp: []int{6, 5, 4, 3}, 173 | want: 1, 174 | }, { 175 | inp: []int{5, 5, 5, 4}, 176 | want: 2, 177 | }} 178 | 179 | for i, tc := range cases { 180 | t.Run(fmt.Sprintf("case_%02d", i+1), func(t *testing.T) { 181 | got := Rindex(tc.inp, 5) 182 | if got != tc.want { 183 | t.Errorf("got %d, want %d", got, tc.want) 184 | } 185 | }) 186 | } 187 | } 188 | 189 | func TestSliceN(t *testing.T) { 190 | cases := []struct { 191 | inp []int 192 | idx, n int 193 | want []int 194 | }{{ 195 | inp: []int{4, 5, 6, 7}, 196 | idx: 0, 197 | n: 0, 198 | want: []int{}, 199 | }, { 200 | inp: []int{4, 5, 6, 7}, 201 | idx: 0, 202 | n: 1, 203 | want: []int{4}, 204 | }, { 205 | inp: []int{4, 5, 6, 7}, 206 | idx: 1, 207 | n: 1, 208 | want: []int{5}, 209 | }, { 210 | inp: []int{4, 5, 6, 7}, 211 | idx: 0, 212 | n: 2, 213 | want: []int{4, 5}, 214 | }, { 215 | inp: []int{4, 5, 6, 7}, 216 | idx: 1, 217 | n: 2, 218 | want: []int{5, 6}, 219 | }, { 220 | inp: []int{4, 5, 6, 7}, 221 | idx: -1, 222 | n: 1, 223 | want: []int{7}, 224 | }, { 225 | inp: []int{4, 5, 6, 7}, 226 | idx: -2, 227 | n: 1, 228 | want: []int{6}, 229 | }, { 230 | inp: []int{4, 5, 6, 7}, 231 | idx: -2, 232 | n: 2, 233 | want: []int{6, 7}, 234 | }} 235 | 236 | for i, tc := range cases { 237 | t.Run(fmt.Sprintf("case_%02d", i+1), func(t *testing.T) { 238 | got := SliceN(tc.inp, tc.idx, tc.n) 239 | if !reflect.DeepEqual(got, tc.want) { 240 | t.Errorf("got %v, want %v", got, tc.want) 241 | } 242 | }) 243 | } 244 | } 245 | 246 | func TestSliceTo(t *testing.T) { 247 | cases := []struct { 248 | inp []int 249 | from, to int 250 | want []int 251 | }{{ 252 | inp: []int{4, 5, 6, 7}, 253 | from: 0, 254 | to: 0, 255 | want: []int{4, 5, 6, 7}, 256 | }, { 257 | inp: []int{4, 5, 6, 7}, 258 | from: 0, 259 | to: 1, 260 | want: []int{4}, 261 | }, { 262 | inp: []int{4, 5, 6, 7}, 263 | from: 1, 264 | to: 2, 265 | want: []int{5}, 266 | }, { 267 | inp: []int{4, 5, 6, 7}, 268 | from: 0, 269 | to: 2, 270 | want: []int{4, 
5}, 271 | }, { 272 | inp: []int{4, 5, 6, 7}, 273 | from: 1, 274 | to: 3, 275 | want: []int{5, 6}, 276 | }, { 277 | inp: []int{4, 5, 6, 7}, 278 | from: -1, 279 | to: 0, 280 | want: []int{7}, 281 | }, { 282 | inp: []int{4, 5, 6, 7}, 283 | from: -2, 284 | to: -1, 285 | want: []int{6}, 286 | }, { 287 | inp: []int{4, 5, 6, 7}, 288 | from: -2, 289 | to: 0, 290 | want: []int{6, 7}, 291 | }} 292 | 293 | for i, tc := range cases { 294 | t.Run(fmt.Sprintf("case_%02d", i+1), func(t *testing.T) { 295 | got := SliceTo(tc.inp, tc.from, tc.to) 296 | if !reflect.DeepEqual(got, tc.want) { 297 | t.Errorf("got %v, want %v", got, tc.want) 298 | } 299 | }) 300 | } 301 | } 302 | 303 | func TestRotate(t *testing.T) { 304 | cases := []struct { 305 | inp []int 306 | n int 307 | want []int 308 | }{{ 309 | inp: []int{4, 5, 1, 2, 3}, 310 | n: -2, 311 | want: []int{1, 2, 3, 4, 5}, 312 | }, { 313 | inp: []int{4, 5, 1, 2, 3}, 314 | n: -7, 315 | want: []int{1, 2, 3, 4, 5}, 316 | }, { 317 | inp: []int{4, 5, 1, 2, 3}, 318 | n: 3, 319 | want: []int{1, 2, 3, 4, 5}, 320 | }, { 321 | inp: []int{4, 5, 1, 2, 3}, 322 | n: 8, 323 | want: []int{1, 2, 3, 4, 5}, 324 | }} 325 | 326 | for i, tc := range cases { 327 | t.Run(fmt.Sprintf("case_%02d", i+1), func(t *testing.T) { 328 | got := make([]int, len(tc.inp)) 329 | copy(got, tc.inp) 330 | Rotate(got, tc.n) 331 | if !reflect.DeepEqual(got, tc.want) { 332 | t.Errorf("got %v, want %v", got, tc.want) 333 | } 334 | }) 335 | } 336 | } 337 | -------------------------------------------------------------------------------- /slices/dropin.go: -------------------------------------------------------------------------------- 1 | package slices 2 | 3 | import ( 4 | "cmp" 5 | "iter" 6 | "slices" 7 | ) 8 | 9 | // This file contains entrypoints for each of the functions in the standard Go slices package 10 | // (except those, like Insert, extended by other functions in this package). 
// Equal reports whether s1 and s2 hold the same elements in the same order.
// Slices of different lengths are never equal. Otherwise the elements are
// compared with == from index 0 upward, stopping at the first mismatch, so a
// floating-point NaN never matches anything (not even another NaN).
func Equal[S ~[]E, E comparable](s1, s2 S) bool {
	return slices.Equal(s1, s2)
}

// EqualFunc is the custom-comparison form of [Equal]: eq decides whether an
// element of s1 matches the element of s2 at the same index, which also lets
// the two slices have different element types. Slices of different lengths
// are never equal; otherwise eq is applied in increasing index order and the
// scan stops at the first pair for which it returns false.
func EqualFunc[S1 ~[]E1, S2 ~[]E2, E1, E2 any](s1 S1, s2 S2, eq func(E1, E2) bool) bool {
	return slices.EqualFunc(s1, s2, eq)
}

// Compare orders s1 and s2 lexicographically. Element pairs are compared with
// [cmp.Compare] starting at index 0, and the first pair that differs decides
// the result. If one slice runs out while all compared pairs are equal, the
// shorter slice is the smaller one.
// The result is -1 if s1 < s2, 0 if s1 == s2, and +1 if s1 > s2.
func Compare[S ~[]E, E cmp.Ordered](s1, s2 S) int {
	return slices.Compare(s1, s2)
}

// CompareFunc is the custom-comparison form of [Compare]: element pairs are
// compared with cmp instead of [cmp.Compare].
// It returns the first non-zero value produced by cmp. If cmp returns 0 for
// every pair, the lengths decide: -1 if len(s1) < len(s2), 0 if they are
// equal, and +1 if len(s1) > len(s2).
func CompareFunc[S1 ~[]E1, S2 ~[]E2, E1, E2 any](s1 S1, s2 S2, cmp func(E1, E2) int) int {
	return slices.CompareFunc(s1, s2, cmp)
}

// Index returns the index of the first occurrence of v in s,
// or -1 if not present.
func Index[S ~[]E, E comparable](s S, v E) int {
	return slices.Index(s, v)
}

// IndexFunc returns the first index i satisfying f(s[i]),
// or -1 if none do.
func IndexFunc[S ~[]E, E any](s S, f func(E) bool) int {
	return slices.IndexFunc(s, f)
}

// Contains reports whether v is present in s.
func Contains[S ~[]E, E comparable](s S, v E) bool {
	return slices.Contains(s, v)
}

// ContainsFunc reports whether at least one
// element e of s satisfies f(e).
func ContainsFunc[S ~[]E, E any](s S, f func(E) bool) bool {
	return slices.ContainsFunc(s, f)
}

// DeleteFunc removes any elements from s for which del returns true,
// returning the modified slice.
// DeleteFunc zeroes the elements between the new length and the original length,
// so that objects they referenced can be garbage collected.
func DeleteFunc[S ~[]E, E any](s S, del func(E) bool) S {
	return slices.DeleteFunc(s, del)
}

// Clone returns a copy of the slice.
// The elements are copied using assignment, so this is a shallow clone.
// The result may have additional unused capacity.
// Cloning a nil slice yields nil.
func Clone[S ~[]E, E any](s S) S {
	return slices.Clone(s)
}

// Compact replaces consecutive runs of equal elements with a single copy.
// This is like the uniq command found on Unix.
// Compact modifies the contents of the slice s and returns the modified slice,
// which may have a smaller length.
// Compact zeroes the elements between the new length and the original length,
// so that objects they referenced can be garbage collected.
func Compact[S ~[]E, E comparable](s S) S {
	return slices.Compact(s)
}

// CompactFunc is like [Compact] but uses an equality function to compare elements.
// For runs of elements that compare equal, CompactFunc keeps the first one.
// Like Compact, it zeroes the elements between the new length and the original length.
func CompactFunc[S ~[]E, E any](s S, eq func(E, E) bool) S {
	return slices.CompactFunc(s, eq)
}

// Grow increases the slice's capacity, if necessary, to guarantee space for
// another n elements. After Grow(n), at least n elements can be appended
// to the slice without another allocation. If n is negative or too large to
// allocate the memory, Grow panics.
func Grow[S ~[]E, E any](s S, n int) S {
	return slices.Grow(s, n)
}

// Clip removes unused capacity from the slice, returning s[:len(s):len(s)].
func Clip[S ~[]E, E any](s S) S {
	return slices.Clip(s)
}

// Reverse reverses the elements of the slice in place.
func Reverse[S ~[]E, E any](s S) {
	slices.Reverse(s)
}

// Sort sorts a slice of any ordered type in ascending order.
// When sorting floating-point numbers, NaNs are ordered before other values.
func Sort[S ~[]E, E cmp.Ordered](x S) {
	slices.Sort(x)
}

// SortFunc sorts the slice x in ascending order as determined by the cmp
// function. This sort is not guaranteed to be stable.
// cmp(a, b) should return a negative number when a < b, a positive number when
// a > b and zero when a == b.
//
// SortFunc requires that cmp is a strict weak ordering.
// See https://en.wikipedia.org/wiki/Weak_ordering#Strict_weak_orderings.
func SortFunc[S ~[]E, E any](x S, cmp func(a, b E) int) {
	slices.SortFunc(x, cmp)
}

// SortStableFunc is the stable counterpart of [SortFunc]:
// it orders x in place according to cmp, and elements that cmp
// reports as equal keep their original relative order.
func SortStableFunc[S ~[]E, E any](x S, cmp func(a, b E) int) {
	slices.SortStableFunc(x, cmp)
}

// IsSorted tells whether the elements of x are already in ascending order.
func IsSorted[S ~[]E, E cmp.Ordered](x S) bool {
	return slices.IsSorted(x)
}

// IsSortedFunc tells whether the elements of x are already in ascending order,
// where the order is defined by cmp under the same contract as in [SortFunc].
func IsSortedFunc[S ~[]E, E any](x S, cmp func(a, b E) int) bool {
	return slices.IsSortedFunc(x, cmp)
}

// Min yields the smallest element of x and panics when x is empty.
// For floating-point element types, a NaN anywhere in x makes the result NaN.
func Min[S ~[]E, E cmp.Ordered](x S) E {
	return slices.Min(x)
}

// MinFunc yields the smallest element of x as judged by cmp, and panics when
// x is empty. When several elements tie for smallest, the earliest one wins.
func MinFunc[S ~[]E, E any](x S, cmp func(a, b E) int) E {
	return slices.MinFunc(x, cmp)
}

// Max yields the largest element of x and panics when x is empty.
// For floating-point element types, a NaN anywhere in x makes the result NaN.
func Max[S ~[]E, E cmp.Ordered](x S) E {
	return slices.Max(x)
}

// MaxFunc yields the largest element of x as judged by cmp, and panics when
// x is empty. When several elements tie for largest, the earliest one wins.
182 | func MaxFunc[S ~[]E, E any](x S, cmp func(a, b E) int) E { 183 | return slices.MaxFunc(x, cmp) 184 | } 185 | 186 | // BinarySearch searches for target in a sorted slice and returns the position 187 | // where target is found, or the position where target would appear in the 188 | // sort order; it also returns a bool saying whether the target is really found 189 | // in the slice. The slice must be sorted in increasing order. 190 | func BinarySearch[S ~[]E, E cmp.Ordered](x S, target E) (int, bool) { 191 | return slices.BinarySearch(x, target) 192 | } 193 | 194 | // BinarySearchFunc works like [BinarySearch], but uses a custom comparison 195 | // function. The slice must be sorted in increasing order, where "increasing" 196 | // is defined by cmp. cmp should return 0 if the slice element matches 197 | // the target, a negative number if the slice element precedes the target, 198 | // or a positive number if the slice element follows the target. 199 | // cmp must implement the same ordering as the slice, such that if 200 | // cmp(a, t) < 0 and cmp(b, t) >= 0, then a must precede b in the slice. 201 | func BinarySearchFunc[S ~[]E, E, T any](x S, target T, cmp func(E, T) int) (int, bool) { 202 | return slices.BinarySearchFunc(x, target, cmp) 203 | } 204 | 205 | // Concat returns a new slice concatenating the passed in slices. 206 | func Concat[S ~[]E, E any](s ...S) S { 207 | return slices.Concat(s...) 208 | } 209 | 210 | // The following are new for Go 1.23. 211 | 212 | // All returns an iterator over index-value pairs in the slice in the usual order. 213 | func All[Slice ~[]E, E any](s Slice) iter.Seq2[int, E] { 214 | return slices.All(s) 215 | } 216 | 217 | // Values returns an iterator that yields the slice elements in order. 218 | func Values[Slice ~[]E, E any](s Slice) iter.Seq[E] { 219 | return slices.Values(s) 220 | } 221 | 222 | // Backward returns an iterator over index-value pairs in the slice, traversing it backward with descending indices. 
223 | func Backward[Slice ~[]E, E any](s Slice) iter.Seq2[int, E] { 224 | return slices.Backward(s) 225 | } 226 | 227 | // Collect collects values from seq into a new slice and returns it. 228 | func Collect[E any](seq iter.Seq[E]) []E { 229 | return slices.Collect(seq) 230 | } 231 | 232 | // AppendSeq appends the values from seq to the slice and returns the extended slice. 233 | func AppendSeq[Slice ~[]E, E any](s Slice, seq iter.Seq[E]) Slice { 234 | return slices.AppendSeq(s, seq) 235 | } 236 | 237 | // Sorted collects values from seq into a new slice, sorts the slice, and returns it. 238 | func Sorted[E cmp.Ordered](seq iter.Seq[E]) []E { 239 | return slices.Sorted(seq) 240 | } 241 | 242 | // SortedFunc collects values from seq into a new slice, sorts the slice using the comparison function, and returns it. 243 | func SortedFunc[E any](seq iter.Seq[E], cmp func(E, E) int) []E { 244 | return slices.SortedFunc(seq, cmp) 245 | } 246 | 247 | // SortedStableFunc collects values from seq into a new slice. It then sorts the slice while keeping the original order of equal elements, using the comparison function to compare elements. It returns the new slice. 248 | func SortedStableFunc[E any](seq iter.Seq[E], cmp func(E, E) int) []E { 249 | return slices.SortedStableFunc(seq, cmp) 250 | } 251 | 252 | // Chunk returns an iterator over consecutive sub-slices of up to n elements of s. All but the last sub-slice will have size n. All sub-slices are clipped to have no capacity beyond the length. If s is empty, the sequence is empty: there is no empty slice in the sequence. Chunk panics if n is less than 1. 253 | func Chunk[Slice ~[]E, E any](s Slice, n int) iter.Seq[Slice] { 254 | return slices.Chunk(s, n) 255 | } 256 | -------------------------------------------------------------------------------- /slices/slices.go: -------------------------------------------------------------------------------- 1 | // Package slices contains utility functions for working with slices. 
2 | // It encapsulates hard-to-remember idioms for inserting and removing elements; 3 | // it adds the ability to index from the right end of a slice using negative integers 4 | // (for example, Get(s, -1) is the same as s[len(s)-1]), 5 | // and it includes Map, Filter, 6 | // and a few other such functions 7 | // for processing slice elements with callbacks. 8 | // 9 | // This package is a drop-in replacement 10 | // for the slices package 11 | // added to the Go stdlib 12 | // in Go 1.21 (https://go.dev/doc/go1.21#slices). 13 | // There is one difference: 14 | // this version of slices 15 | // allows the index value passed to `Insert`, `Delete`, and `Replace` 16 | // to be negative for counting backward from the end of the slice. 17 | package slices 18 | 19 | import ( 20 | "slices" 21 | "sort" 22 | ) 23 | 24 | // Get gets the idx'th element of s. 25 | // 26 | // If idx < 0 it counts from the end of s. 27 | func Get[T any, S ~[]T](s S, idx int) T { 28 | if idx < 0 { 29 | idx += len(s) 30 | } 31 | return s[idx] 32 | } 33 | 34 | // Put puts a given value into the idx'th location in s. 35 | // 36 | // If idx < 0 it counts from the end of s. 37 | // 38 | // The input slice is modified. 39 | func Put[T any, S ~[]T](s S, idx int, val T) { 40 | if idx < 0 { 41 | idx += len(s) 42 | } 43 | s[idx] = val 44 | } 45 | 46 | // Append is the same as Go's builtin append and is included for completeness. 47 | func Append[T any, S ~[]T](s S, vals ...T) S { 48 | return append(s, vals...) 49 | } 50 | 51 | // Insert inserts the given values at the idx'th location in s and returns the result. 52 | // After the insert, the first new value has position idx. 53 | // 54 | // If idx < 0, it counts from the end of s. 55 | // (This is a change from the behavior of Go's standard slices.Insert.) 56 | // 57 | // The input slice is modified. 
//
// Example: Insert([x, y, z], 1, a, b, c) -> [x, a, b, c, y, z]
func Insert[S ~[]E, E any](s S, idx int, vals ...E) S {
	if idx < 0 {
		idx += len(s)
	}
	return slices.Insert(s, idx, vals...)
}

// Delete removes the elements s[i:j] from s, returning the modified slice.
// Delete panics if, after the index adjustments described below,
// s[i:j] is not a valid slice of s.
// Delete is O(len(s)-i), so if many items must be deleted, it is better to
// make a single call deleting them all together than to delete one at a time.
// Delete zeroes the elements s[len(s)-(j-i):len(s)],
// so that any objects they referenced can be garbage collected.
//
// If i < 0 it counts from the end of s.
// If j <= 0 it counts from the end of s;
// in particular j == 0 means len(s), so Delete(s, 0, 0) removes every element.
// (This is a change from the behavior of Go's standard slices.Delete,
// for which Delete(s, 0, 0) removes nothing.)
func Delete[S ~[]E, E any](s S, i, j int) S {
	return RemoveTo(s, i, j)
}

// Replace replaces the elements s[i:j] by the given v, and returns the
// modified slice. Replace panics if, after the index adjustments described
// below, s[i:j] is not a valid slice of s.
// When len(v) < j-i, Replace zeroes the elements between the new length
// and the original length.
//
// If i < 0 it counts from the end of s.
// If j <= 0 it counts from the end of s; in particular j == 0 means len(s).
// (This is a change from the behavior of Go's standard slices.Replace.)
func Replace[S ~[]E, E any](s S, i, j int, v ...E) S {
	return ReplaceTo(s, i, j, v...)
}

// ReplaceN replaces the n values of s beginning at position idx with the given values.
// After the replace, the first new value has position idx.
//
// If idx < 0, it counts from the end of s.
//
// The input slice may be modified in place;
// if the result is shorter than s, the leftover tail of s is zeroed.
func ReplaceN[T any, S ~[]T](s S, idx, n int, vals ...T) S {
	if idx < 0 {
		idx += len(s)
	}
	return slices.Replace(s, idx, idx+n, vals...)
}

// ReplaceTo replaces the values of s beginning at from and ending before to with the given values.
// After the replace, the first new value has position from.
//
// If from < 0 it counts from the end of s.
// If to <= 0 it counts from the end of s (to == 0 means len(s)).
//
// The input slice may be modified in place;
// if the result is shorter than s, the leftover tail of s is zeroed.
func ReplaceTo[T any, S ~[]T](s S, from, to int, vals ...T) S {
	if from < 0 {
		from += len(s)
	}
	if to < 0 {
		to += len(s)
	} else if to == 0 {
		to = len(s)
	}
	return slices.Replace(s, from, to, vals...)
}

// RemoveN removes n items from s beginning at position idx and returns the result.
//
// If idx < 0 it counts from the end of s.
//
// The input slice is modified in place, and the vacated elements
// at the end of s are zeroed.
//
// Example: RemoveN([a, b, c, d], 1, 2) -> [a, d]
func RemoveN[T any, S ~[]T](s S, idx, n int) S {
	if idx < 0 {
		idx += len(s)
	}
	return slices.Delete(s, idx, idx+n)
}

// RemoveTo removes items from s beginning at position from and ending before position to.
// It returns the result.
//
// If from < 0 it counts from the end of s.
// If to <= 0 it counts from the end of s (to == 0 means len(s)).
//
// The input slice is modified in place, and the vacated elements
// at the end of s are zeroed.
//
// Example: RemoveTo([a, b, c, d], 1, 3) -> [a, d]
func RemoveTo[T any, S ~[]T](s S, from, to int) S {
	if from < 0 {
		from += len(s)
	}
	if to < 0 {
		to += len(s)
	} else if to == 0 {
		to = len(s)
	}
	return slices.Delete(s, from, to)
}

// Prefix returns s up to but not including position idx.
//
// If idx < 0 it counts from the end of s.
func Prefix[T any, S ~[]T](s S, idx int) S {
	if idx < 0 {
		idx += len(s)
	}
	return s[:idx]
}

// PrefixFunc returns the longest prefix of s whose elements all satisfy the given predicate.
func PrefixFunc[T any, S ~[]T](s S, f func(T) bool) S {
	end := 0
	for end < len(s) && f(s[end]) {
		end++
	}
	return s[:end]
}

// Suffix returns the tail of s that starts at position idx.
//
// A negative idx counts from the end of s.
func Suffix[T any, S ~[]T](s S, idx int) S {
	start := idx
	if start < 0 {
		start += len(s)
	}
	return s[start:]
}

// Rindex reports the position of the last element of s equal to v,
// or -1 when v does not occur in s.
func Rindex[T comparable, S ~[]T](s S, v T) int {
	return RindexFunc(s, func(item T) bool { return item == v })
}

// RindexFunc reports the position of the last element of s that satisfies f,
// or -1 when no element does. Elements are tested from the end of s toward the front.
func RindexFunc[T any, S ~[]T](s S, f func(T) bool) int {
	i := len(s) - 1
	for i >= 0 && !f(s[i]) {
		i--
	}
	return i
}

// SuffixFunc returns the longest tail of s in which every element satisfies f.
func SuffixFunc[T any, S ~[]T](s S, f func(T) bool) S {
	if last := RindexFunc(s, invert(f)); last >= 0 {
		return s[last+1:]
	}
	return s
}

// invert wraps pred in a predicate that reports the opposite answer.
func invert[T any](pred func(T) bool) func(T) bool {
	return func(val T) bool { return !pred(val) }
}

// SliceN returns the n elements of s that begin at position idx.
//
// A negative idx counts from the end of s.
func SliceN[T any, S ~[]T](s S, idx, n int) S {
	start := idx
	if start < 0 {
		start += len(s)
	}
	return s[start : start+n]
}

// SliceTo returns the elements of s from position from up to, but not including, position to.
//
// A negative from counts from the end of s.
// A negative to also counts from the end of s, and to == 0 means len(s).
func SliceTo[T any, S ~[]T](s S, from, to int) S {
	size := len(s)
	if from < 0 {
		from += size
	}
	switch {
	case to < 0:
		to += size
	case to == 0:
		to = size
	}
	return s[from:to]
}

// Each calls f on every element of s, in order.
func Each[T any, S ~[]T](s S, f func(T)) {
	for _, item := range s {
		f(item)
	}
}

// Eachx is the extended form of [Each].
// It calls f with the index and value of every element of s, in order,
// and stops at the first non-nil error from f, which it returns.
// It returns nil when every call succeeds.
func Eachx[T any, S ~[]T](s S, f func(int, T) error) error {
	for i := 0; i < len(s); i++ {
		if err := f(i, s[i]); err != nil {
			return err
		}
	}
	return nil
}

// Map builds a new slice holding f applied to each element of s, in order.
// It returns nil when s is empty.
func Map[T, U any, S ~[]T](s S, f func(T) U) []U {
	if len(s) == 0 {
		return nil
	}
	out := make([]U, len(s))
	for i, item := range s {
		out[i] = f(item)
	}
	return out
}

// Mapx is the extended form of [Map].
// It builds a new slice holding f(i, s[i]) for each index i of s, in order.
// It returns nil, nil when s is empty, and at the first error from f
// it stops and returns a nil slice with that error.
func Mapx[T, U any, S ~[]T](s S, f func(int, T) (U, error)) ([]U, error) {
	if len(s) == 0 {
		return nil, nil
	}
	out := make([]U, len(s))
	for i := range s {
		u, err := f(i, s[i])
		if err != nil {
			return nil, err
		}
		out[i] = u
	}
	return out, nil
}

// Accum folds the elements of a slice together by repeatedly applying a simple function.
//
// An empty slice yields the zero value of T, and a one-element slice yields s[0].
// Otherwise the result is a left fold seeded with the first element:
// f(...f(f(s[0], s[1]), s[2])..., s[len(s)-1]).
func Accum[T any, S ~[]T](s S, f func(T, T) T) T {
	var acc T
	for i, item := range s {
		if i == 0 {
			acc = item
			continue
		}
		acc = f(acc, item)
	}
	return acc
}

// Accumx is the extended form of [Accum]: the combining function may fail.
//
// An empty slice yields the zero value of T and a nil error,
// and a one-element slice yields s[0] without calling f.
// Otherwise f is applied left to right, seeded with s[0],
// as in f(...f(f(s[0], s[1]), s[2])..., s[len(s)-1]).
// The first error from f stops the fold; Accumx then returns that error
// together with whatever value f returned alongside it.
func Accumx[T any, S ~[]T](s S, f func(T, T) (T, error)) (T, error) {
	var acc T
	if len(s) == 0 {
		return acc, nil
	}
	acc = s[0]
	for _, item := range s[1:] {
		next, err := f(acc, item)
		if err != nil {
			return next, err
		}
		acc = next
	}
	return acc, nil
}

// Filter keeps the elements of s for which the simple predicate f reports true,
// returning them in their original order in a newly built slice
// (nil when nothing is kept).
func Filter[T any, S ~[]T](s S, f func(T) bool) S {
	var kept S
	for _, item := range s {
		if f(item) {
			kept = append(kept, item)
		}
	}
	return kept
}

// Filterx is the extended form of [Filter]: the predicate may fail.
// It returns, in their original order, the elements of s for which f reports true.
// The first error from f stops the scan, and Filterx returns a nil slice with that error.
func Filterx[T any, S ~[]T](s S, f func(T) (bool, error)) (S, error) {
	var result S
	for _, val := range s {
		ok, err := f(val)
		if err != nil {
			return nil, err
		}
		if ok {
			result = append(result, val)
		}
	}
	return result, nil
}

// Group partitions the elements of a slice into groups.
// It does this by calling a simple grouping function on each element,
// which produces a grouping key.
// The result is a map of group keys to slices of elements having that key.
// Within each group, elements keep their original relative order.
func Group[T any, K comparable, S ~[]T](s S, f func(T) K) map[K]S {
	// The wrapped callback never fails, so the error is always nil.
	result, _ := Groupx(s, func(val T) (K, error) {
		return f(val), nil
	})
	return result
}

// Groupx is the extended form of [Group].
// It partitions the elements of a slice into groups.
// It does this by calling a grouping function on each element,
// which produces a grouping key.
// The result is a map of group keys to slices of elements having that key.
// If any call to the function returns an error,
// Groupx stops and returns a nil map with that error.
func Groupx[T any, K comparable, S ~[]T](s S, f func(T) (K, error)) (map[K]S, error) {
	result := make(map[K]S)
	for _, val := range s {
		key, err := f(val)
		if err != nil {
			return nil, err
		}
		result[key] = append(result[key], val)
	}
	return result, nil
}

// Rotate rotates a slice in place by n places to the right.
// (With negative n, it's to the left.)
// n may have any magnitude; it is reduced modulo len(s).
// Rotating an empty or nil slice is a no-op.
// Example: Rotate([D, E, A, B, C], 3) -> [A, B, C, D, E]
func Rotate[T any, S ~[]T](s S, n int) {
	size := len(s)
	if size == 0 {
		// Nothing to rotate. Returning here also avoids the modulo by zero below.
		return
	}
	// Reduce n to an equivalent right-rotation in [0, size).
	// Go's % takes the sign of the dividend, so a leftward (negative) n leaves
	// a negative remainder that needs size added back. Working on the remainder
	// directly, rather than negating n first, also avoids overflow for math.MinInt.
	n %= size
	if n < 0 {
		n += size
	}
	if n == 0 {
		return
	}
	// Save the n elements that wrap around, shift the rest right, then put the saved ones in front.
	tmp := make([]T, n)
	copy(tmp, s[size-n:])
	copy(s[n:], s) // copy handles overlapping source and destination
	copy(s, tmp)
}

// KeyedSort sorts the given slice according to the ordering of the given keys,
// whose items must map 1:1 with the slice.
// It is an unchecked error if keys.Len() != len(slice).
//
// Both arguments end up sorted in place:
// keys according to its Less method,
// and slice by mirroring the reordering that happens in keys.
func KeyedSort[T any, S ~[]T](slice S, keys sort.Interface) {
	ks := keyedSorter[T, S]{
		keys:  keys,
		slice: slice,
	}
	sort.Sort(ks)
}

// keyedSorter allows sorting a slice according to the order of a set of sort keys.
// It is a [sort.Interface] whose Less consults keys and whose Swap applies
// every exchange to both keys and slice, so slice ends up permuted exactly
// as keys is sorted.
// keys must map 1:1 with the items of slice;
// it is an unchecked error for keys.Len() to differ from len(slice).
type keyedSorter[T any, S ~[]T] struct {
	keys  sort.Interface
	slice S
}

func (k keyedSorter[T, S]) Len() int           { return len(k.slice) }
func (k keyedSorter[T, S]) Less(i, j int) bool { return k.keys.Less(i, j) }
func (k keyedSorter[T, S]) Swap(i, j int) {
	k.keys.Swap(i, j)
	k.slice[i], k.slice[j] = k.slice[j], k.slice[i]
}

// NonNil converts a nil slice to a non-nil empty slice.
// It returns other slices unchanged.
//
// A nil slice is usually preferable,
// since it is equivalent to an empty slice in almost every way
// and does not have the overhead of an allocation.
// (See https://dave.cheney.net/2018/07/12/slices-from-the-ground-up.)
454 | // However, there are some corner cases where the difference matters, 455 | // notably when marshaling to JSON, 456 | // where an empty slice marshals as the array [] 457 | // but a nil slice marshals as the non-array `null`. 458 | func NonNil[T any, S ~[]T](s S) S { 459 | if s == nil { 460 | return S{} 461 | } 462 | return s 463 | } 464 | -------------------------------------------------------------------------------- /slices/dropin_test.go: -------------------------------------------------------------------------------- 1 | // Adapted from golang.org/x/exp/slices and the Go standard library. 2 | 3 | // Copyright 2021-2024 The Go Authors. All rights reserved. 4 | // Use of this source code is governed by a BSD-style 5 | // license that can be found in the LICENSE file. 6 | 7 | package slices 8 | 9 | import ( 10 | "cmp" 11 | "fmt" 12 | "iter" 13 | "math" 14 | "math/rand/v2" 15 | "sort" 16 | "strconv" 17 | "strings" 18 | "testing" 19 | ) 20 | 21 | var raceEnabled bool 22 | 23 | var equalIntTests = []struct { 24 | s1, s2 []int 25 | want bool 26 | }{ 27 | { 28 | []int{1}, 29 | nil, 30 | false, 31 | }, 32 | { 33 | []int{}, 34 | nil, 35 | true, 36 | }, 37 | { 38 | []int{1, 2, 3}, 39 | []int{1, 2, 3}, 40 | true, 41 | }, 42 | { 43 | []int{1, 2, 3}, 44 | []int{1, 2, 3, 4}, 45 | false, 46 | }, 47 | } 48 | 49 | var equalFloatTests = []struct { 50 | s1, s2 []float64 51 | wantEqual bool 52 | wantEqualNaN bool 53 | }{ 54 | { 55 | []float64{1, 2}, 56 | []float64{1, 2}, 57 | true, 58 | true, 59 | }, 60 | { 61 | []float64{1, 2, math.NaN()}, 62 | []float64{1, 2, math.NaN()}, 63 | false, 64 | true, 65 | }, 66 | } 67 | 68 | func TestEqual(t *testing.T) { 69 | for _, test := range equalIntTests { 70 | if got := Equal(test.s1, test.s2); got != test.want { 71 | t.Errorf("Equal(%v, %v) = %t, want %t", test.s1, test.s2, got, test.want) 72 | } 73 | } 74 | for _, test := range equalFloatTests { 75 | if got := Equal(test.s1, test.s2); got != test.wantEqual { 76 | t.Errorf("Equal(%v, 
%v) = %t, want %t", test.s1, test.s2, got, test.wantEqual) 77 | } 78 | } 79 | } 80 | 81 | // equal is simply ==. 82 | func equal[T comparable](v1, v2 T) bool { 83 | return v1 == v2 84 | } 85 | 86 | // equalNaN is like == except that all NaNs are equal. 87 | func equalNaN[T comparable](v1, v2 T) bool { 88 | isNaN := func(f T) bool { return f != f } 89 | return v1 == v2 || (isNaN(v1) && isNaN(v2)) 90 | } 91 | 92 | // offByOne returns true if integers v1 and v2 differ by 1. 93 | func offByOne(v1, v2 int) bool { 94 | return v1 == v2+1 || v1 == v2-1 95 | } 96 | 97 | func TestEqualFunc(t *testing.T) { 98 | for _, test := range equalIntTests { 99 | if got := EqualFunc(test.s1, test.s2, equal[int]); got != test.want { 100 | t.Errorf("EqualFunc(%v, %v, equal[int]) = %t, want %t", test.s1, test.s2, got, test.want) 101 | } 102 | } 103 | for _, test := range equalFloatTests { 104 | if got := EqualFunc(test.s1, test.s2, equal[float64]); got != test.wantEqual { 105 | t.Errorf("Equal(%v, %v, equal[float64]) = %t, want %t", test.s1, test.s2, got, test.wantEqual) 106 | } 107 | if got := EqualFunc(test.s1, test.s2, equalNaN[float64]); got != test.wantEqualNaN { 108 | t.Errorf("Equal(%v, %v, equalNaN[float64]) = %t, want %t", test.s1, test.s2, got, test.wantEqualNaN) 109 | } 110 | } 111 | 112 | s1 := []int{1, 2, 3} 113 | s2 := []int{2, 3, 4} 114 | if EqualFunc(s1, s1, offByOne) { 115 | t.Errorf("EqualFunc(%v, %v, offByOne) = true, want false", s1, s1) 116 | } 117 | if !EqualFunc(s1, s2, offByOne) { 118 | t.Errorf("EqualFunc(%v, %v, offByOne) = false, want true", s1, s2) 119 | } 120 | 121 | s3 := []string{"a", "b", "c"} 122 | s4 := []string{"A", "B", "C"} 123 | if !EqualFunc(s3, s4, strings.EqualFold) { 124 | t.Errorf("EqualFunc(%v, %v, strings.EqualFold) = false, want true", s3, s4) 125 | } 126 | 127 | cmpIntString := func(v1 int, v2 string) bool { 128 | return string(rune(v1)-1+'a') == v2 129 | } 130 | if !EqualFunc(s1, s3, cmpIntString) { 131 | t.Errorf("EqualFunc(%v, %v, 
cmpIntString) = false, want true", s1, s3) 132 | } 133 | } 134 | 135 | var compareIntTests = []struct { 136 | s1, s2 []int 137 | want int 138 | }{ 139 | { 140 | []int{1, 2, 3}, 141 | []int{1, 2, 3, 4}, 142 | -1, 143 | }, 144 | { 145 | []int{1, 2, 3, 4}, 146 | []int{1, 2, 3}, 147 | +1, 148 | }, 149 | { 150 | []int{1, 2, 3}, 151 | []int{1, 4, 3}, 152 | -1, 153 | }, 154 | { 155 | []int{1, 4, 3}, 156 | []int{1, 2, 3}, 157 | +1, 158 | }, 159 | } 160 | 161 | var compareFloatTests = []struct { 162 | s1, s2 []float64 163 | want int 164 | }{ 165 | { 166 | []float64{1, 2, math.NaN()}, 167 | []float64{1, 2, math.NaN()}, 168 | 0, 169 | }, 170 | { 171 | []float64{1, math.NaN(), 3}, 172 | []float64{1, math.NaN(), 4}, 173 | -1, 174 | }, 175 | { 176 | []float64{1, math.NaN(), 3}, 177 | []float64{1, 2, 4}, 178 | -1, 179 | }, 180 | { 181 | []float64{1, math.NaN(), 3}, 182 | []float64{1, 2, math.NaN()}, 183 | -1, 184 | }, 185 | { 186 | []float64{1, math.NaN(), 3, 4}, 187 | []float64{1, 2, math.NaN()}, 188 | -1, 189 | }, 190 | } 191 | 192 | func TestCompare(t *testing.T) { 193 | intWant := func(want bool) string { 194 | if want { 195 | return "0" 196 | } 197 | return "!= 0" 198 | } 199 | for _, test := range equalIntTests { 200 | if got := Compare(test.s1, test.s2); (got == 0) != test.want { 201 | t.Errorf("Compare(%v, %v) = %d, want %s", test.s1, test.s2, got, intWant(test.want)) 202 | } 203 | } 204 | for _, test := range equalFloatTests { 205 | if got := Compare(test.s1, test.s2); (got == 0) != test.wantEqualNaN { 206 | t.Errorf("Compare(%v, %v) = %d, want %s", test.s1, test.s2, got, intWant(test.wantEqualNaN)) 207 | } 208 | } 209 | 210 | for _, test := range compareIntTests { 211 | if got := Compare(test.s1, test.s2); got != test.want { 212 | t.Errorf("Compare(%v, %v) = %d, want %d", test.s1, test.s2, got, test.want) 213 | } 214 | } 215 | for _, test := range compareFloatTests { 216 | if got := Compare(test.s1, test.s2); got != test.want { 217 | t.Errorf("Compare(%v, %v) = %d, want 
%d", test.s1, test.s2, got, test.want) 218 | } 219 | } 220 | } 221 | 222 | func equalToCmp[T comparable](eq func(T, T) bool) func(T, T) int { 223 | return func(v1, v2 T) int { 224 | if eq(v1, v2) { 225 | return 0 226 | } 227 | return 1 228 | } 229 | } 230 | 231 | func TestCompareFunc(t *testing.T) { 232 | intWant := func(want bool) string { 233 | if want { 234 | return "0" 235 | } 236 | return "!= 0" 237 | } 238 | for _, test := range equalIntTests { 239 | if got := CompareFunc(test.s1, test.s2, equalToCmp(equal[int])); (got == 0) != test.want { 240 | t.Errorf("CompareFunc(%v, %v, equalToCmp(equal[int])) = %d, want %s", test.s1, test.s2, got, intWant(test.want)) 241 | } 242 | } 243 | for _, test := range equalFloatTests { 244 | if got := CompareFunc(test.s1, test.s2, equalToCmp(equal[float64])); (got == 0) != test.wantEqual { 245 | t.Errorf("CompareFunc(%v, %v, equalToCmp(equal[float64])) = %d, want %s", test.s1, test.s2, got, intWant(test.wantEqual)) 246 | } 247 | } 248 | 249 | for _, test := range compareIntTests { 250 | if got := CompareFunc(test.s1, test.s2, cmp.Compare); got != test.want { 251 | t.Errorf("CompareFunc(%v, %v, cmp.Compare) = %d, want %d", test.s1, test.s2, got, test.want) 252 | } 253 | } 254 | for _, test := range compareFloatTests { 255 | if got := CompareFunc(test.s1, test.s2, cmp.Compare); got != test.want { 256 | t.Errorf("CompareFunc(%v, %v, cmp.Compare) = %d, want %d", test.s1, test.s2, got, test.want) 257 | } 258 | } 259 | 260 | s1 := []int{1, 2, 3} 261 | s2 := []int{2, 3, 4} 262 | if got := CompareFunc(s1, s2, equalToCmp(offByOne)); got != 0 { 263 | t.Errorf("CompareFunc(%v, %v, offByOne) = %d, want 0", s1, s2, got) 264 | } 265 | 266 | s3 := []string{"a", "b", "c"} 267 | s4 := []string{"A", "B", "C"} 268 | if got := CompareFunc(s3, s4, strings.Compare); got != 1 { 269 | t.Errorf("CompareFunc(%v, %v, strings.Compare) = %d, want 1", s3, s4, got) 270 | } 271 | 272 | compareLower := func(v1, v2 string) int { 273 | return 
strings.Compare(strings.ToLower(v1), strings.ToLower(v2)) 274 | } 275 | if got := CompareFunc(s3, s4, compareLower); got != 0 { 276 | t.Errorf("CompareFunc(%v, %v, compareLower) = %d, want 0", s3, s4, got) 277 | } 278 | 279 | cmpIntString := func(v1 int, v2 string) int { 280 | return strings.Compare(string(rune(v1)-1+'a'), v2) 281 | } 282 | if got := CompareFunc(s1, s3, cmpIntString); got != 0 { 283 | t.Errorf("CompareFunc(%v, %v, cmpIntString) = %d, want 0", s1, s3, got) 284 | } 285 | } 286 | 287 | var indexTests = []struct { 288 | s []int 289 | v int 290 | want int 291 | }{ 292 | { 293 | nil, 294 | 0, 295 | -1, 296 | }, 297 | { 298 | []int{}, 299 | 0, 300 | -1, 301 | }, 302 | { 303 | []int{1, 2, 3}, 304 | 2, 305 | 1, 306 | }, 307 | { 308 | []int{1, 2, 2, 3}, 309 | 2, 310 | 1, 311 | }, 312 | { 313 | []int{1, 2, 3, 2}, 314 | 2, 315 | 1, 316 | }, 317 | } 318 | 319 | func TestIndex(t *testing.T) { 320 | for _, test := range indexTests { 321 | if got := Index(test.s, test.v); got != test.want { 322 | t.Errorf("Index(%v, %v) = %d, want %d", test.s, test.v, got, test.want) 323 | } 324 | } 325 | } 326 | 327 | func equalToIndex[T any](f func(T, T) bool, v1 T) func(T) bool { 328 | return func(v2 T) bool { 329 | return f(v1, v2) 330 | } 331 | } 332 | 333 | func TestIndexFunc(t *testing.T) { 334 | for _, test := range indexTests { 335 | if got := IndexFunc(test.s, equalToIndex(equal[int], test.v)); got != test.want { 336 | t.Errorf("IndexFunc(%v, equalToIndex(equal[int], %v)) = %d, want %d", test.s, test.v, got, test.want) 337 | } 338 | } 339 | 340 | s1 := []string{"hi", "HI"} 341 | if got := IndexFunc(s1, equalToIndex(equal[string], "HI")); got != 1 { 342 | t.Errorf("IndexFunc(%v, equalToIndex(equal[string], %q)) = %d, want %d", s1, "HI", got, 1) 343 | } 344 | if got := IndexFunc(s1, equalToIndex(strings.EqualFold, "HI")); got != 0 { 345 | t.Errorf("IndexFunc(%v, equalToIndex(strings.EqualFold, %q)) = %d, want %d", s1, "HI", got, 0) 346 | } 347 | } 348 | 349 | func 
TestContains(t *testing.T) { 350 | for _, test := range indexTests { 351 | if got := Contains(test.s, test.v); got != (test.want != -1) { 352 | t.Errorf("Contains(%v, %v) = %t, want %t", test.s, test.v, got, test.want != -1) 353 | } 354 | } 355 | } 356 | 357 | func TestContainsFunc(t *testing.T) { 358 | for _, test := range indexTests { 359 | if got := ContainsFunc(test.s, equalToIndex(equal[int], test.v)); got != (test.want != -1) { 360 | t.Errorf("ContainsFunc(%v, equalToIndex(equal[int], %v)) = %t, want %t", test.s, test.v, got, test.want != -1) 361 | } 362 | } 363 | 364 | s1 := []string{"hi", "HI"} 365 | if got := ContainsFunc(s1, equalToIndex(equal[string], "HI")); got != true { 366 | t.Errorf("ContainsFunc(%v, equalToContains(equal[string], %q)) = %t, want %t", s1, "HI", got, true) 367 | } 368 | if got := ContainsFunc(s1, equalToIndex(equal[string], "hI")); got != false { 369 | t.Errorf("ContainsFunc(%v, equalToContains(strings.EqualFold, %q)) = %t, want %t", s1, "hI", got, false) 370 | } 371 | if got := ContainsFunc(s1, equalToIndex(strings.EqualFold, "hI")); got != true { 372 | t.Errorf("ContainsFunc(%v, equalToContains(strings.EqualFold, %q)) = %t, want %t", s1, "hI", got, true) 373 | } 374 | } 375 | 376 | // var insertTests = []struct { 377 | // s []int 378 | // i int 379 | // add []int 380 | // want []int 381 | // }{ 382 | // { 383 | // []int{1, 2, 3}, 384 | // 0, 385 | // []int{4}, 386 | // []int{4, 1, 2, 3}, 387 | // }, 388 | // { 389 | // []int{1, 2, 3}, 390 | // 1, 391 | // []int{4}, 392 | // []int{1, 4, 2, 3}, 393 | // }, 394 | // { 395 | // []int{1, 2, 3}, 396 | // 3, 397 | // []int{4}, 398 | // []int{1, 2, 3, 4}, 399 | // }, 400 | // { 401 | // []int{1, 2, 3}, 402 | // 2, 403 | // []int{4, 5}, 404 | // []int{1, 2, 4, 5, 3}, 405 | // }, 406 | // } 407 | 408 | // func TestInsert(t *testing.T) { 409 | // s := []int{1, 2, 3} 410 | // if got := Insert(s, 0); !Equal(got, s) { 411 | // t.Errorf("Insert(%v, 0) = %v, want %v", s, got, s) 412 | // } 413 | 
// for _, test := range insertTests { 414 | // copy := Clone(test.s) 415 | // if got := Insert(copy, test.i, test.add...); !Equal(got, test.want) { 416 | // t.Errorf("Insert(%v, %d, %v...) = %v, want %v", test.s, test.i, test.add, got, test.want) 417 | // } 418 | // } 419 | // } 420 | 421 | var deleteTests = []struct { 422 | s []int 423 | i, j int 424 | want []int 425 | }{ 426 | { 427 | []int{1, 2, 3}, 428 | 0, 429 | 0, 430 | []int{}, 431 | }, 432 | { 433 | []int{1, 2, 3}, 434 | 0, 435 | 1, 436 | []int{2, 3}, 437 | }, 438 | { 439 | []int{1, 2, 3}, 440 | 3, 441 | 3, 442 | []int{1, 2, 3}, 443 | }, 444 | { 445 | []int{1, 2, 3}, 446 | 0, 447 | 2, 448 | []int{3}, 449 | }, 450 | { 451 | []int{1, 2, 3}, 452 | 0, 453 | 3, 454 | []int{}, 455 | }, 456 | } 457 | 458 | func TestDelete(t *testing.T) { 459 | for _, test := range deleteTests { 460 | copy := Clone(test.s) 461 | if got := Delete(copy, test.i, test.j); !Equal(got, test.want) { 462 | t.Errorf("Delete(%v, %d, %d) = %v, want %v", test.s, test.i, test.j, got, test.want) 463 | } 464 | } 465 | } 466 | 467 | func panics(f func()) (b bool) { 468 | defer func() { 469 | if x := recover(); x != nil { 470 | b = true 471 | } 472 | }() 473 | f() 474 | return false 475 | } 476 | 477 | func TestDeletePanics(t *testing.T) { 478 | for _, test := range []struct { 479 | name string 480 | s []int 481 | i, j int 482 | }{ 483 | {"with negative first index", []int{42}, -2, 1}, 484 | {"with negative second index", []int{42}, 1, -1}, 485 | {"with out-of-bounds first index", []int{42}, 2, 3}, 486 | {"with out-of-bounds second index", []int{42}, 0, 2}, 487 | } { 488 | if !panics(func() { Delete(test.s, test.i, test.j) }) { 489 | t.Errorf("Delete %s: got no panic, want panic", test.name) 490 | } 491 | } 492 | } 493 | 494 | func TestClone(t *testing.T) { 495 | s1 := []int{1, 2, 3} 496 | s2 := Clone(s1) 497 | if !Equal(s1, s2) { 498 | t.Errorf("Clone(%v) = %v, want %v", s1, s2, s1) 499 | } 500 | s1[0] = 4 501 | want := []int{1, 2, 3} 502 | if 
!Equal(s2, want) { 503 | t.Errorf("Clone(%v) changed unexpectedly to %v", want, s2) 504 | } 505 | if got := Clone([]int(nil)); got != nil { 506 | t.Errorf("Clone(nil) = %#v, want nil", got) 507 | } 508 | if got := Clone(s1[:0]); got == nil || len(got) != 0 { 509 | t.Errorf("Clone(%v) = %#v, want %#v", s1[:0], got, s1[:0]) 510 | } 511 | } 512 | 513 | var compactTests = []struct { 514 | name string 515 | s []int 516 | want []int 517 | }{ 518 | { 519 | "nil", 520 | nil, 521 | nil, 522 | }, 523 | { 524 | "one", 525 | []int{1}, 526 | []int{1}, 527 | }, 528 | { 529 | "sorted", 530 | []int{1, 2, 3}, 531 | []int{1, 2, 3}, 532 | }, 533 | { 534 | "1 item", 535 | []int{1, 1, 2}, 536 | []int{1, 2}, 537 | }, 538 | { 539 | "unsorted", 540 | []int{1, 2, 1}, 541 | []int{1, 2, 1}, 542 | }, 543 | { 544 | "many", 545 | []int{1, 2, 2, 3, 3, 4}, 546 | []int{1, 2, 3, 4}, 547 | }, 548 | } 549 | 550 | func TestCompact(t *testing.T) { 551 | for _, test := range compactTests { 552 | copy := Clone(test.s) 553 | if got := Compact(copy); !Equal(got, test.want) { 554 | t.Errorf("Compact(%v) = %v, want %v", test.s, got, test.want) 555 | } 556 | } 557 | } 558 | 559 | func TestCompactFunc(t *testing.T) { 560 | for _, test := range compactTests { 561 | copy := Clone(test.s) 562 | if got := CompactFunc(copy, equal[int]); !Equal(got, test.want) { 563 | t.Errorf("CompactFunc(%v, equal[int]) = %v, want %v", test.s, got, test.want) 564 | } 565 | } 566 | 567 | s1 := []string{"a", "a", "A", "B", "b"} 568 | copy := Clone(s1) 569 | want := []string{"a", "B"} 570 | if got := CompactFunc(copy, strings.EqualFold); !Equal(got, want) { 571 | t.Errorf("CompactFunc(%v, strings.EqualFold) = %v, want %v", s1, got, want) 572 | } 573 | } 574 | 575 | func TestGrow(t *testing.T) { 576 | s1 := []int{1, 2, 3} 577 | 578 | copy := Clone(s1) 579 | s2 := Grow(copy, 1000) 580 | if !Equal(s1, s2) { 581 | t.Errorf("Grow(%v) = %v, want %v", s1, s2, s1) 582 | } 583 | if cap(s2) < 1000+len(s1) { 584 | t.Errorf("after Grow(%v) cap = 
%d, want >= %d", s1, cap(s2), 1000+len(s1)) 585 | } 586 | 587 | // Test mutation of elements between length and capacity. 588 | copy = Clone(s1) 589 | s3 := Grow(copy[:1], 2)[:3] 590 | if !Equal(s1, s3) { 591 | t.Errorf("Grow should not mutate elements between length and capacity") 592 | } 593 | s3 = Grow(copy[:1], 1000)[:3] 594 | if !Equal(s1, s3) { 595 | t.Errorf("Grow should not mutate elements between length and capacity") 596 | } 597 | 598 | // Test number of allocations. 599 | if n := testing.AllocsPerRun(100, func() { Grow(s2, cap(s2)-len(s2)) }); n != 0 { 600 | t.Errorf("Grow should not allocate when given sufficient capacity; allocated %v times", n) 601 | } 602 | if n := testing.AllocsPerRun(100, func() { Grow(s2, cap(s2)-len(s2)+1) }); n != 1 { 603 | errorf := t.Errorf 604 | if raceEnabled { 605 | errorf = t.Logf // this allocates multiple times in race detector mode 606 | } 607 | errorf("Grow should allocate once when given insufficient capacity; allocated %v times", n) 608 | } 609 | 610 | // Test for negative growth sizes. 611 | var gotPanic bool 612 | func() { 613 | defer func() { gotPanic = recover() != nil }() 614 | Grow(s1, -1) 615 | }() 616 | if !gotPanic { 617 | t.Errorf("Grow(-1) did not panic; expected a panic") 618 | } 619 | } 620 | 621 | func TestClip(t *testing.T) { 622 | s1 := []int{1, 2, 3, 4, 5, 6}[:3] 623 | orig := Clone(s1) 624 | if len(s1) != 3 { 625 | t.Errorf("len(%v) = %d, want 3", s1, len(s1)) 626 | } 627 | if cap(s1) < 6 { 628 | t.Errorf("cap(%v[:3]) = %d, want >= 6", orig, cap(s1)) 629 | } 630 | s2 := Clip(s1) 631 | if !Equal(s1, s2) { 632 | t.Errorf("Clip(%v) = %v, want %v", s1, s2, s1) 633 | } 634 | if cap(s2) != 3 { 635 | t.Errorf("cap(Clip(%v)) = %d, want 3", orig, cap(s2)) 636 | } 637 | } 638 | 639 | // naiveReplace is a baseline implementation to the Replace function. 640 | func naiveReplace[S ~[]E, E any](s S, i, j int, v ...E) S { 641 | s = Delete(s, i, j) 642 | s = Insert(s, i, v...) 
643 | return s 644 | } 645 | 646 | func TestReplace(t *testing.T) { 647 | for _, test := range []struct { 648 | s, v []int 649 | i, j int 650 | }{ 651 | {}, // all zero value 652 | { 653 | s: []int{1, 2, 3, 4}, 654 | v: []int{5}, 655 | i: 1, 656 | j: 2, 657 | }, 658 | { 659 | s: []int{1, 2, 3, 4}, 660 | v: []int{5, 6, 7, 8}, 661 | i: 1, 662 | j: 2, 663 | }, 664 | { 665 | s: func() []int { 666 | s := make([]int, 3, 20) 667 | s[0] = 0 668 | s[1] = 1 669 | s[2] = 2 670 | return s 671 | }(), 672 | v: []int{3, 4, 5, 6, 7}, 673 | i: 0, 674 | j: 1, 675 | }, 676 | } { 677 | ss, vv := Clone(test.s), Clone(test.v) 678 | want := naiveReplace(ss, test.i, test.j, vv...) 679 | got := Replace(test.s, test.i, test.j, test.v...) 680 | if !Equal(got, want) { 681 | t.Errorf("Replace(%v, %v, %v, %v) = %v, want %v", test.s, test.i, test.j, test.v, got, want) 682 | } 683 | } 684 | } 685 | 686 | func TestReplacePanics(t *testing.T) { 687 | for _, test := range []struct { 688 | name string 689 | s, v []int 690 | i, j int 691 | }{ 692 | {"indexes out of order", []int{1, 2}, []int{3}, 2, 1}, 693 | {"large index", []int{1, 2}, []int{3}, 1, 10}, 694 | } { 695 | ss, vv := Clone(test.s), Clone(test.v) 696 | if !panics(func() { Replace(ss, test.i, test.j, vv...) 
}) { 697 | t.Errorf("Replace %s: should have panicked", test.name) 698 | } 699 | } 700 | } 701 | 702 | var ints = [...]int{74, 59, 238, -784, 9845, 959, 905, 0, 0, 42, 7586, -5467984, 7586} 703 | var float64s = [...]float64{74.3, 59.0, math.Inf(1), 238.2, -784.0, 2.3, math.Inf(-1), 9845.768, -959.7485, 905, 7.8, 7.8, 74.3, 59.0, math.Inf(1), 238.2, -784.0, 2.3} 704 | var float64sWithNaNs = [...]float64{74.3, 59.0, math.Inf(1), 238.2, -784.0, 2.3, math.NaN(), math.NaN(), math.Inf(-1), 9845.768, -959.7485, 905, 7.8, 7.8} 705 | var strs = [...]string{"", "Hello", "foo", "bar", "foo", "f00", "%*&^*&^&", "***"} 706 | 707 | func TestSortIntSlice(t *testing.T) { 708 | data := Clone(ints[:]) 709 | Sort(data) 710 | if !IsSorted(data) { 711 | t.Errorf("sorted %v", ints) 712 | t.Errorf(" got %v", data) 713 | } 714 | } 715 | 716 | func TestSortFuncIntSlice(t *testing.T) { 717 | data := Clone(ints[:]) 718 | SortFunc(data, func(a, b int) int { return a - b }) 719 | if !IsSorted(data) { 720 | t.Errorf("sorted %v", ints) 721 | t.Errorf(" got %v", data) 722 | } 723 | } 724 | 725 | func TestSortFloat64Slice(t *testing.T) { 726 | data := Clone(float64s[:]) 727 | Sort(data) 728 | if !IsSorted(data) { 729 | t.Errorf("sorted %v", float64s) 730 | t.Errorf(" got %v", data) 731 | } 732 | } 733 | 734 | func TestSortFloat64SliceWithNaNs(t *testing.T) { 735 | data := float64sWithNaNs[:] 736 | input := Clone(data) 737 | 738 | // Make sure Sort doesn't panic when the slice contains NaNs. 739 | Sort(data) 740 | // Check whether the result is a permutation of the input. 
741 | sort.Float64s(data) 742 | sort.Float64s(input) 743 | for i, v := range input { 744 | if data[i] != v && !(math.IsNaN(data[i]) && math.IsNaN(v)) { 745 | t.Fatalf("the result is not a permutation of the input\ngot %v\nwant %v", data, input) 746 | } 747 | } 748 | } 749 | 750 | func TestSortStringSlice(t *testing.T) { 751 | data := Clone(strs[:]) 752 | Sort(data) 753 | if !IsSorted(data) { 754 | t.Errorf("sorted %v", strs) 755 | t.Errorf(" got %v", data) 756 | } 757 | } 758 | 759 | func TestSortLarge_Random(t *testing.T) { 760 | n := 1000000 761 | if testing.Short() { 762 | n /= 100 763 | } 764 | data := make([]int, n) 765 | for i := 0; i < len(data); i++ { 766 | data[i] = rand.IntN(100) 767 | } 768 | if IsSorted(data) { 769 | t.Fatalf("terrible rand.rand") 770 | } 771 | Sort(data) 772 | if !IsSorted(data) { 773 | t.Errorf("sort didn't sort - 1M ints") 774 | } 775 | } 776 | 777 | type intPair struct { 778 | a, b int 779 | } 780 | 781 | type intPairs []intPair 782 | 783 | // Pairs compare on a only. 784 | func intPairCmp(x, y intPair) int { 785 | return x.a - y.a 786 | } 787 | 788 | // Record initial order in B. 789 | func (d intPairs) initB() { 790 | for i := range d { 791 | d[i].b = i 792 | } 793 | } 794 | 795 | // InOrder checks if a-equal elements were not reordered. 796 | // If reversed is true, expect reverse ordering. 
797 | func (d intPairs) inOrder(reversed bool) bool { 798 | lastA, lastB := -1, 0 799 | for i := 0; i < len(d); i++ { 800 | if lastA != d[i].a { 801 | lastA = d[i].a 802 | lastB = d[i].b 803 | continue 804 | } 805 | if !reversed { 806 | if d[i].b <= lastB { 807 | return false 808 | } 809 | } else { 810 | if d[i].b >= lastB { 811 | return false 812 | } 813 | } 814 | lastB = d[i].b 815 | } 816 | return true 817 | } 818 | 819 | func TestStability(t *testing.T) { 820 | n, m := 100000, 1000 821 | if testing.Short() { 822 | n, m = 1000, 100 823 | } 824 | data := make(intPairs, n) 825 | 826 | // random distribution 827 | for i := 0; i < len(data); i++ { 828 | data[i].a = rand.IntN(m) 829 | } 830 | if IsSortedFunc(data, intPairCmp) { 831 | t.Fatalf("terrible rand.rand") 832 | } 833 | data.initB() 834 | SortStableFunc(data, intPairCmp) 835 | if !IsSortedFunc(data, intPairCmp) { 836 | t.Errorf("Stable didn't sort %d ints", n) 837 | } 838 | if !data.inOrder(false) { 839 | t.Errorf("Stable wasn't stable on %d ints", n) 840 | } 841 | 842 | // already sorted 843 | data.initB() 844 | SortStableFunc(data, intPairCmp) 845 | if !IsSortedFunc(data, intPairCmp) { 846 | t.Errorf("Stable shuffled sorted %d ints (order)", n) 847 | } 848 | if !data.inOrder(false) { 849 | t.Errorf("Stable shuffled sorted %d ints (stability)", n) 850 | } 851 | 852 | // sorted reversed 853 | for i := 0; i < len(data); i++ { 854 | data[i].a = len(data) - i 855 | } 856 | data.initB() 857 | SortStableFunc(data, intPairCmp) 858 | if !IsSortedFunc(data, intPairCmp) { 859 | t.Errorf("Stable didn't sort %d ints", n) 860 | } 861 | if !data.inOrder(false) { 862 | t.Errorf("Stable wasn't stable on %d ints", n) 863 | } 864 | } 865 | 866 | func TestBinarySearch(t *testing.T) { 867 | str1 := []string{"foo"} 868 | str2 := []string{"ab", "ca"} 869 | str3 := []string{"mo", "qo", "vo"} 870 | str4 := []string{"ab", "ad", "ca", "xy"} 871 | 872 | // slice with repeating elements 873 | strRepeats := []string{"ba", "ca", "da", 
"da", "da", "ka", "ma", "ma", "ta"} 874 | 875 | // slice with all element equal 876 | strSame := []string{"xx", "xx", "xx"} 877 | 878 | tests := []struct { 879 | data []string 880 | target string 881 | wantPos int 882 | wantFound bool 883 | }{ 884 | {[]string{}, "foo", 0, false}, 885 | {[]string{}, "", 0, false}, 886 | 887 | {str1, "foo", 0, true}, 888 | {str1, "bar", 0, false}, 889 | {str1, "zx", 1, false}, 890 | 891 | {str2, "aa", 0, false}, 892 | {str2, "ab", 0, true}, 893 | {str2, "ad", 1, false}, 894 | {str2, "ca", 1, true}, 895 | {str2, "ra", 2, false}, 896 | 897 | {str3, "bb", 0, false}, 898 | {str3, "mo", 0, true}, 899 | {str3, "nb", 1, false}, 900 | {str3, "qo", 1, true}, 901 | {str3, "tr", 2, false}, 902 | {str3, "vo", 2, true}, 903 | {str3, "xr", 3, false}, 904 | 905 | {str4, "aa", 0, false}, 906 | {str4, "ab", 0, true}, 907 | {str4, "ac", 1, false}, 908 | {str4, "ad", 1, true}, 909 | {str4, "ax", 2, false}, 910 | {str4, "ca", 2, true}, 911 | {str4, "cc", 3, false}, 912 | {str4, "dd", 3, false}, 913 | {str4, "xy", 3, true}, 914 | {str4, "zz", 4, false}, 915 | 916 | {strRepeats, "da", 2, true}, 917 | {strRepeats, "db", 5, false}, 918 | {strRepeats, "ma", 6, true}, 919 | {strRepeats, "mb", 8, false}, 920 | 921 | {strSame, "xx", 0, true}, 922 | {strSame, "ab", 0, false}, 923 | {strSame, "zz", 3, false}, 924 | } 925 | for _, tt := range tests { 926 | t.Run(tt.target, func(t *testing.T) { 927 | { 928 | pos, found := BinarySearch(tt.data, tt.target) 929 | if pos != tt.wantPos || found != tt.wantFound { 930 | t.Errorf("BinarySearch got (%v, %v), want (%v, %v)", pos, found, tt.wantPos, tt.wantFound) 931 | } 932 | } 933 | 934 | { 935 | pos, found := BinarySearchFunc(tt.data, tt.target, strings.Compare) 936 | if pos != tt.wantPos || found != tt.wantFound { 937 | t.Errorf("BinarySearchFunc got (%v, %v), want (%v, %v)", pos, found, tt.wantPos, tt.wantFound) 938 | } 939 | } 940 | }) 941 | } 942 | } 943 | 944 | func TestBinarySearchInts(t *testing.T) { 945 | data := 
[]int{20, 30, 40, 50, 60, 70, 80, 90} 946 | tests := []struct { 947 | target int 948 | wantPos int 949 | wantFound bool 950 | }{ 951 | {20, 0, true}, 952 | {23, 1, false}, 953 | {43, 3, false}, 954 | {80, 6, true}, 955 | } 956 | for _, tt := range tests { 957 | t.Run(strconv.Itoa(tt.target), func(t *testing.T) { 958 | { 959 | pos, found := BinarySearch(data, tt.target) 960 | if pos != tt.wantPos || found != tt.wantFound { 961 | t.Errorf("BinarySearch got (%v, %v), want (%v, %v)", pos, found, tt.wantPos, tt.wantFound) 962 | } 963 | } 964 | 965 | { 966 | cmp := func(a, b int) int { 967 | return a - b 968 | } 969 | pos, found := BinarySearchFunc(data, tt.target, cmp) 970 | if pos != tt.wantPos || found != tt.wantFound { 971 | t.Errorf("BinarySearchFunc got (%v, %v), want (%v, %v)", pos, found, tt.wantPos, tt.wantFound) 972 | } 973 | } 974 | }) 975 | } 976 | } 977 | 978 | func TestBinarySearchFunc(t *testing.T) { 979 | data := []int{1, 10, 11, 2} // sorted lexicographically 980 | cmp := func(a int, b string) int { 981 | return strings.Compare(strconv.Itoa(a), b) 982 | } 983 | pos, found := BinarySearchFunc(data, "2", cmp) 984 | if pos != 3 || !found { 985 | t.Errorf("BinarySearchFunc(%v, %q, cmp) = %v, %v, want %v, %v", data, "2", pos, found, 3, true) 986 | } 987 | } 988 | 989 | func TestConcat(t *testing.T) { 990 | var ( 991 | a = []int{1, 2, 3} 992 | b = []int{4, 5, 6} 993 | want = []int{1, 2, 3, 4, 5, 6} 994 | ) 995 | if got := Concat(a, b); !Equal(got, want) { 996 | t.Errorf("Concat(%v, %v) = %v, want %v", a, b, got, want) 997 | } 998 | } 999 | 1000 | func TestAll(t *testing.T) { 1001 | for size := 0; size < 10; size++ { 1002 | var s []int 1003 | for i := range size { 1004 | s = append(s, i) 1005 | } 1006 | ei, ev := 0, 0 1007 | cnt := 0 1008 | for i, v := range All(s) { 1009 | if i != ei || v != ev { 1010 | t.Errorf("at iteration %d got %d, %d want %d, %d", cnt, i, v, ei, ev) 1011 | } 1012 | ei++ 1013 | ev++ 1014 | cnt++ 1015 | } 1016 | if cnt != size { 1017 | 
t.Errorf("read %d values expected %d", cnt, size) 1018 | } 1019 | } 1020 | } 1021 | 1022 | func TestBackward(t *testing.T) { 1023 | for size := 0; size < 10; size++ { 1024 | var s []int 1025 | for i := range size { 1026 | s = append(s, i) 1027 | } 1028 | ei, ev := size-1, size-1 1029 | cnt := 0 1030 | for i, v := range Backward(s) { 1031 | if i != ei || v != ev { 1032 | t.Errorf("at iteration %d got %d, %d want %d, %d", cnt, i, v, ei, ev) 1033 | } 1034 | ei-- 1035 | ev-- 1036 | cnt++ 1037 | } 1038 | if cnt != size { 1039 | t.Errorf("read %d values expected %d", cnt, size) 1040 | } 1041 | } 1042 | } 1043 | 1044 | func TestValues(t *testing.T) { 1045 | for size := 0; size < 10; size++ { 1046 | var s []int 1047 | for i := range size { 1048 | s = append(s, i) 1049 | } 1050 | ev := 0 1051 | cnt := 0 1052 | for v := range Values(s) { 1053 | if v != ev { 1054 | t.Errorf("at iteration %d got %d want %d", cnt, v, ev) 1055 | } 1056 | ev++ 1057 | cnt++ 1058 | } 1059 | if cnt != size { 1060 | t.Errorf("read %d values expected %d", cnt, size) 1061 | } 1062 | } 1063 | } 1064 | 1065 | func testSeq(yield func(int) bool) { 1066 | for i := 0; i < 10; i += 2 { 1067 | if !yield(i) { 1068 | return 1069 | } 1070 | } 1071 | } 1072 | 1073 | var testSeqResult = []int{0, 2, 4, 6, 8} 1074 | 1075 | func TestAppendSeq(t *testing.T) { 1076 | s := AppendSeq([]int{1, 2}, testSeq) 1077 | want := append([]int{1, 2}, testSeqResult...) 
1078 | if !Equal(s, want) { 1079 | t.Errorf("got %v, want %v", s, want) 1080 | } 1081 | } 1082 | 1083 | func TestCollect(t *testing.T) { 1084 | s := Collect(testSeq) 1085 | want := testSeqResult 1086 | if !Equal(s, want) { 1087 | t.Errorf("got %v, want %v", s, want) 1088 | } 1089 | } 1090 | 1091 | var iterTests = [][]string{ 1092 | nil, 1093 | {"a"}, 1094 | {"a", "b"}, 1095 | {"b", "a"}, 1096 | strs[:], 1097 | } 1098 | 1099 | func TestValuesAppendSeq(t *testing.T) { 1100 | for _, prefix := range iterTests { 1101 | for _, s := range iterTests { 1102 | got := AppendSeq(prefix, Values(s)) 1103 | want := append(prefix, s...) 1104 | if !Equal(got, want) { 1105 | t.Errorf("AppendSeq(%v, Values(%v)) == %v, want %v", prefix, s, got, want) 1106 | } 1107 | } 1108 | } 1109 | } 1110 | 1111 | func TestValuesCollect(t *testing.T) { 1112 | for _, s := range iterTests { 1113 | got := Collect(Values(s)) 1114 | if !Equal(got, s) { 1115 | t.Errorf("Collect(Values(%v)) == %v, want %v", s, got, s) 1116 | } 1117 | } 1118 | } 1119 | 1120 | func TestSorted(t *testing.T) { 1121 | s := Sorted(Values(ints[:])) 1122 | if !IsSorted(s) { 1123 | t.Errorf("sorted %v", ints) 1124 | t.Errorf(" got %v", s) 1125 | } 1126 | } 1127 | 1128 | func TestSortedFunc(t *testing.T) { 1129 | s := SortedFunc(Values(ints[:]), func(a, b int) int { return a - b }) 1130 | if !IsSorted(s) { 1131 | t.Errorf("sorted %v", ints) 1132 | t.Errorf(" got %v", s) 1133 | } 1134 | } 1135 | 1136 | func TestSortedStableFunc(t *testing.T) { 1137 | n, m := 1000, 100 1138 | data := make(intPairs, n) 1139 | for i := range data { 1140 | data[i].a = rand.IntN(m) 1141 | } 1142 | data.initB() 1143 | 1144 | s := intPairs(SortedStableFunc(Values(data), intPairCmp)) 1145 | if !IsSortedFunc(s, intPairCmp) { 1146 | t.Errorf("SortedStableFunc didn't sort %d ints", n) 1147 | } 1148 | if !s.inOrder(false) { 1149 | t.Errorf("SortedStableFunc wasn't stable on %d ints", n) 1150 | } 1151 | 1152 | // iterVal converts a Seq2 to a Seq. 
1153 | iterVal := func(seq iter.Seq2[int, intPair]) iter.Seq[intPair] { 1154 | return func(yield func(intPair) bool) { 1155 | for _, v := range seq { 1156 | if !yield(v) { 1157 | return 1158 | } 1159 | } 1160 | } 1161 | } 1162 | 1163 | s = intPairs(SortedStableFunc(iterVal(Backward(data)), intPairCmp)) 1164 | if !IsSortedFunc(s, intPairCmp) { 1165 | t.Errorf("SortedStableFunc didn't sort %d reverse ints", n) 1166 | } 1167 | if !s.inOrder(true) { 1168 | t.Errorf("SortedStableFunc wasn't stable on %d reverse ints", n) 1169 | } 1170 | } 1171 | 1172 | func TestChunk(t *testing.T) { 1173 | cases := []struct { 1174 | name string 1175 | s []int 1176 | n int 1177 | chunks [][]int 1178 | }{ 1179 | { 1180 | name: "nil", 1181 | s: nil, 1182 | n: 1, 1183 | chunks: nil, 1184 | }, 1185 | { 1186 | name: "empty", 1187 | s: []int{}, 1188 | n: 1, 1189 | chunks: nil, 1190 | }, 1191 | { 1192 | name: "short", 1193 | s: []int{1, 2}, 1194 | n: 3, 1195 | chunks: [][]int{{1, 2}}, 1196 | }, 1197 | { 1198 | name: "one", 1199 | s: []int{1, 2}, 1200 | n: 2, 1201 | chunks: [][]int{{1, 2}}, 1202 | }, 1203 | { 1204 | name: "even", 1205 | s: []int{1, 2, 3, 4}, 1206 | n: 2, 1207 | chunks: [][]int{{1, 2}, {3, 4}}, 1208 | }, 1209 | { 1210 | name: "odd", 1211 | s: []int{1, 2, 3, 4, 5}, 1212 | n: 2, 1213 | chunks: [][]int{{1, 2}, {3, 4}, {5}}, 1214 | }, 1215 | } 1216 | 1217 | for _, tc := range cases { 1218 | t.Run(tc.name, func(t *testing.T) { 1219 | var chunks [][]int 1220 | for c := range Chunk(tc.s, tc.n) { 1221 | chunks = append(chunks, c) 1222 | } 1223 | 1224 | if !chunkEqual(chunks, tc.chunks) { 1225 | t.Errorf("Chunk(%v, %d) = %v, want %v", tc.s, tc.n, chunks, tc.chunks) 1226 | } 1227 | 1228 | if len(chunks) == 0 { 1229 | return 1230 | } 1231 | 1232 | // Verify that appending to the end of the first chunk does not 1233 | // clobber the beginning of the next chunk. 
1234 | s := Clone(tc.s) 1235 | chunks[0] = append(chunks[0], -1) 1236 | if !Equal(s, tc.s) { 1237 | t.Errorf("slice was clobbered: %v, want %v", s, tc.s) 1238 | } 1239 | }) 1240 | } 1241 | } 1242 | 1243 | func TestChunkPanics(t *testing.T) { 1244 | for _, test := range []struct { 1245 | name string 1246 | x []struct{} 1247 | n int 1248 | }{ 1249 | { 1250 | name: "cannot be less than 1", 1251 | x: make([]struct{}, 0), 1252 | n: 0, 1253 | }, 1254 | } { 1255 | if !panics(func() { _ = Chunk(test.x, test.n) }) { 1256 | t.Errorf("Chunk %s: got no panic, want panic", test.name) 1257 | } 1258 | } 1259 | } 1260 | 1261 | func TestChunkRange(t *testing.T) { 1262 | // Verify Chunk iteration can be stopped. 1263 | var got [][]int 1264 | for c := range Chunk([]int{1, 2, 3, 4, -100}, 2) { 1265 | if len(got) == 2 { 1266 | // Found enough values, break early. 1267 | break 1268 | } 1269 | 1270 | got = append(got, c) 1271 | } 1272 | 1273 | if want := [][]int{{1, 2}, {3, 4}}; !chunkEqual(got, want) { 1274 | t.Errorf("Chunk iteration did not stop, got %v, want %v", got, want) 1275 | } 1276 | } 1277 | 1278 | func chunkEqual[Slice ~[]E, E comparable](s1, s2 []Slice) bool { 1279 | return EqualFunc(s1, s2, Equal[Slice]) 1280 | } 1281 | 1282 | type S struct { 1283 | a int 1284 | b string 1285 | } 1286 | 1287 | func cmpS(s1, s2 S) int { 1288 | return cmp.Compare(s1.a, s2.a) 1289 | } 1290 | 1291 | func TestMinMax(t *testing.T) { 1292 | intCmp := func(a, b int) int { return a - b } 1293 | 1294 | tests := []struct { 1295 | data []int 1296 | wantMin int 1297 | wantMax int 1298 | }{ 1299 | {[]int{7}, 7, 7}, 1300 | {[]int{1, 2}, 1, 2}, 1301 | {[]int{2, 1}, 1, 2}, 1302 | {[]int{1, 2, 3}, 1, 3}, 1303 | {[]int{3, 2, 1}, 1, 3}, 1304 | {[]int{2, 1, 3}, 1, 3}, 1305 | {[]int{2, 2, 3}, 2, 3}, 1306 | {[]int{3, 2, 3}, 2, 3}, 1307 | {[]int{0, 2, -9}, -9, 2}, 1308 | } 1309 | for _, tt := range tests { 1310 | t.Run(fmt.Sprintf("%v", tt.data), func(t *testing.T) { 1311 | gotMin := Min(tt.data) 1312 | if gotMin 
!= tt.wantMin { 1313 | t.Errorf("Min got %v, want %v", gotMin, tt.wantMin) 1314 | } 1315 | 1316 | gotMinFunc := MinFunc(tt.data, intCmp) 1317 | if gotMinFunc != tt.wantMin { 1318 | t.Errorf("MinFunc got %v, want %v", gotMinFunc, tt.wantMin) 1319 | } 1320 | 1321 | gotMax := Max(tt.data) 1322 | if gotMax != tt.wantMax { 1323 | t.Errorf("Max got %v, want %v", gotMax, tt.wantMax) 1324 | } 1325 | 1326 | gotMaxFunc := MaxFunc(tt.data, intCmp) 1327 | if gotMaxFunc != tt.wantMax { 1328 | t.Errorf("MaxFunc got %v, want %v", gotMaxFunc, tt.wantMax) 1329 | } 1330 | }) 1331 | } 1332 | 1333 | svals := []S{ 1334 | {1, "a"}, 1335 | {2, "a"}, 1336 | {1, "b"}, 1337 | {2, "b"}, 1338 | } 1339 | 1340 | gotMin := MinFunc(svals, cmpS) 1341 | wantMin := S{1, "a"} 1342 | if gotMin != wantMin { 1343 | t.Errorf("MinFunc(%v) = %v, want %v", svals, gotMin, wantMin) 1344 | } 1345 | 1346 | gotMax := MaxFunc(svals, cmpS) 1347 | wantMax := S{2, "a"} 1348 | if gotMax != wantMax { 1349 | t.Errorf("MaxFunc(%v) = %v, want %v", svals, gotMax, wantMax) 1350 | } 1351 | } 1352 | 1353 | func TestMinMaxNaNs(t *testing.T) { 1354 | fs := []float64{1.0, 999.9, 3.14, -400.4, -5.14} 1355 | if Min(fs) != -400.4 { 1356 | t.Errorf("got min %v, want -400.4", Min(fs)) 1357 | } 1358 | if Max(fs) != 999.9 { 1359 | t.Errorf("got max %v, want 999.9", Max(fs)) 1360 | } 1361 | 1362 | // No matter which element of fs is replaced with a NaN, both Min and Max 1363 | // should propagate the NaN to their output. 
1364 | for i := 0; i < len(fs); i++ { 1365 | testfs := Clone(fs) 1366 | testfs[i] = math.NaN() 1367 | 1368 | fmin := Min(testfs) 1369 | if !math.IsNaN(fmin) { 1370 | t.Errorf("got min %v, want NaN", fmin) 1371 | } 1372 | 1373 | fmax := Max(testfs) 1374 | if !math.IsNaN(fmax) { 1375 | t.Errorf("got max %v, want NaN", fmax) 1376 | } 1377 | } 1378 | } 1379 | 1380 | func TestMinMaxPanics(t *testing.T) { 1381 | intCmp := func(a, b int) int { return a - b } 1382 | emptySlice := []int{} 1383 | 1384 | if !panics(func() { Min(emptySlice) }) { 1385 | t.Errorf("Min([]): got no panic, want panic") 1386 | } 1387 | 1388 | if !panics(func() { Max(emptySlice) }) { 1389 | t.Errorf("Max([]): got no panic, want panic") 1390 | } 1391 | 1392 | if !panics(func() { MinFunc(emptySlice, intCmp) }) { 1393 | t.Errorf("MinFunc([]): got no panic, want panic") 1394 | } 1395 | 1396 | if !panics(func() { MaxFunc(emptySlice, intCmp) }) { 1397 | t.Errorf("MaxFunc([]): got no panic, want panic") 1398 | } 1399 | } 1400 | --------------------------------------------------------------------------------