├── go.mod ├── go.sum ├── .gitignore ├── .github └── workflows │ └── go.yml ├── LICENSE ├── errgroup_example_md5all_test.go ├── errgroup_test.go ├── errgroupn_test.go ├── errgroup.go ├── README.md └── benchmark_test.go /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/neilotoole/errgroup 2 | 3 | go 1.14 4 | 5 | require golang.org/x/sync v0.0.0-20200625203802-6e8e738ad208 6 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | golang.org/x/sync v0.0.0-20200625203802-6e8e738ad208 h1:qwRHBd0NqMbJxfbotnDhm2ByMI1Shq4Y6oRJo21SGJA= 2 | golang.org/x/sync v0.0.0-20200625203802-6e8e738ad208/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= 3 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Binaries for programs and plugins 2 | *.exe 3 | *.exe~ 4 | *.dll 5 | *.so 6 | *.dylib 7 | 8 | # Test binary, built with `go test -c` 9 | *.test 10 | 11 | # Output of the go coverage tool, specifically when used with LiteIDE 12 | *.out 13 | 14 | *.iml 15 | .idea 16 | 17 | # Dependency directories (remove the comment below to include it) 18 | # vendor/ 19 | -------------------------------------------------------------------------------- /.github/workflows/go.yml: -------------------------------------------------------------------------------- 1 | name: Go 2 | 3 | on: 4 | push: 5 | branches: [ master, dev ] 6 | pull_request: 7 | branches: [ master ] 8 | 9 | jobs: 10 | 11 | build: 12 | name: Build 13 | runs-on: ubuntu-latest 14 | steps: 15 | 16 | - name: Set up Go 1.x 17 | uses: actions/setup-go@v2 18 | with: 19 | go-version: ^1.14 20 | id: go 21 | 22 | - name: Check out code into the Go module directory 23 | uses: actions/checkout@v2 24 | 25 | - name: Get dependencies 26 | 
run: | 27 | go get -v -t -d ./... 28 | 29 | - name: Build 30 | run: go build -v . 31 | 32 | - name: Test 33 | run: go test -v . 34 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Neil O'Toole 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /errgroup_example_md5all_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2016 The Go Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | // Note: This file is copied directly from sync/errgroup 6 | // with the one-line change that pkg neilotoole/errgroup is imported 7 | // as errgroup. 
The purpose is to test if neilotoole/errgroup can be 8 | // characterized as a "drop-in" replacement for sync/errgroup, by 9 | // seamlessly passing all of sync/errgroup's tests. 10 | 11 | package errgroup_test 12 | 13 | import ( 14 | "context" 15 | "crypto/md5" 16 | "fmt" 17 | "io/ioutil" 18 | "log" 19 | "os" 20 | "path/filepath" 21 | 22 | "github.com/neilotoole/errgroup" 23 | ) 24 | 25 | // Pipeline demonstrates the use of a Group to implement a multi-stage 26 | // pipeline: a version of the MD5All function with bounded parallelism from 27 | // https://blog.golang.org/pipelines. 28 | func ExampleGroup_pipeline() { 29 | m, err := MD5All(context.Background(), ".") 30 | if err != nil { 31 | log.Fatal(err) 32 | } 33 | 34 | for k, sum := range m { 35 | fmt.Printf("%s:\t%x\n", k, sum) 36 | } 37 | } 38 | 39 | type result struct { 40 | path string 41 | sum [md5.Size]byte 42 | } 43 | 44 | // MD5All reads all the files in the file tree rooted at root and returns a map 45 | // from file path to the MD5 sum of the file's contents. If the directory walk 46 | // fails or any read operation fails, MD5All returns an error. 47 | func MD5All(ctx context.Context, root string) (map[string][md5.Size]byte, error) { 48 | // ctx is canceled when g.Wait() returns. When this version of MD5All returns 49 | // - even in case of error! - we know that all of the goroutines have finished 50 | // and the memory they were using can be garbage-collected. 51 | g, ctx := errgroup.WithContext(ctx) 52 | paths := make(chan string) 53 | 54 | g.Go(func() error { 55 | defer close(paths) 56 | return filepath.Walk(root, func(path string, info os.FileInfo, err error) error { 57 | if err != nil { 58 | return err 59 | } 60 | if !info.Mode().IsRegular() { 61 | return nil 62 | } 63 | select { 64 | case paths <- path: 65 | case <-ctx.Done(): 66 | return ctx.Err() 67 | } 68 | return nil 69 | }) 70 | }) 71 | 72 | // Start a fixed number of goroutines to read and digest files. 
73 | c := make(chan result) 74 | const numDigesters = 20 75 | for i := 0; i < numDigesters; i++ { 76 | g.Go(func() error { 77 | for path := range paths { 78 | data, err := ioutil.ReadFile(path) 79 | if err != nil { 80 | return err 81 | } 82 | select { 83 | case c <- result{path, md5.Sum(data)}: 84 | case <-ctx.Done(): 85 | return ctx.Err() 86 | } 87 | } 88 | return nil 89 | }) 90 | } 91 | go func() { 92 | g.Wait() 93 | close(c) 94 | }() 95 | 96 | m := make(map[string][md5.Size]byte) 97 | for r := range c { 98 | m[r.path] = r.sum 99 | } 100 | // Check whether any of the goroutines failed. Since g is accumulating the 101 | // errors, we don't need to send them (or check for them) in the individual 102 | // results sent on the channel. 103 | if err := g.Wait(); err != nil { 104 | return nil, err 105 | } 106 | return m, nil 107 | } 108 | -------------------------------------------------------------------------------- /errgroup_test.go: -------------------------------------------------------------------------------- 1 | // Copyright 2016 The Go Authors. All rights reserved. 2 | // Use of this source code is governed by a BSD-style 3 | // license that can be found in the LICENSE file. 4 | 5 | // Note: This file is copied directly from sync/errgroup 6 | // with the one-line change that pkg neilotoole/errgroup is imported 7 | // instead of sync/errgroup. The purpose is to test if neilotoole/errgroup 8 | // can be characterized as a "drop-in" replacement for sync/errgroup, by 9 | // seamlessly passing all of sync/errgroup's tests. 
10 | 11 | package errgroup_test 12 | 13 | import ( 14 | "context" 15 | "errors" 16 | "fmt" 17 | "net/http" 18 | "os" 19 | "testing" 20 | 21 | "github.com/neilotoole/errgroup" 22 | ) 23 | 24 | var ( 25 | Web = fakeSearch("web") 26 | Image = fakeSearch("image") 27 | Video = fakeSearch("video") 28 | ) 29 | 30 | type Result string 31 | type Search func(ctx context.Context, query string) (Result, error) 32 | 33 | func fakeSearch(kind string) Search { 34 | return func(_ context.Context, query string) (Result, error) { 35 | return Result(fmt.Sprintf("%s result for %q", kind, query)), nil 36 | } 37 | } 38 | 39 | // JustErrors illustrates the use of a Group in place of a sync.WaitGroup to 40 | // simplify goroutine counting and error handling. This example is derived from 41 | // the sync.WaitGroup example at https://golang.org/pkg/sync/#example_WaitGroup. 42 | func ExampleGroup_justErrors() { 43 | var g errgroup.Group 44 | var urls = []string{ 45 | "http://www.golang.org/", 46 | "http://www.google.com/", 47 | "http://www.somestupidname.com/", 48 | } 49 | for _, url := range urls { 50 | // Launch a goroutine to fetch the URL. 51 | url := url // https://golang.org/doc/faq#closures_and_goroutines 52 | g.Go(func() error { 53 | // Fetch the URL. 54 | resp, err := http.Get(url) 55 | if err == nil { 56 | resp.Body.Close() 57 | } 58 | return err 59 | }) 60 | } 61 | // Wait for all HTTP fetches to complete. 62 | if err := g.Wait(); err == nil { 63 | fmt.Println("Successfully fetched all URLs.") 64 | } 65 | } 66 | 67 | // Parallel illustrates the use of a Group for synchronizing a simple parallel 68 | // task: the "Google Search 2.0" function from 69 | // https://talks.golang.org/2012/concurrency.slide#46, augmented with a Context 70 | // and error-handling. 
71 | func ExampleGroup_parallel() { 72 | Google := func(ctx context.Context, query string) ([]Result, error) { 73 | g, ctx := errgroup.WithContext(ctx) 74 | 75 | searches := []Search{Web, Image, Video} 76 | results := make([]Result, len(searches)) 77 | for i, search := range searches { 78 | i, search := i, search // https://golang.org/doc/faq#closures_and_goroutines 79 | g.Go(func() error { 80 | result, err := search(ctx, query) 81 | if err == nil { 82 | results[i] = result 83 | } 84 | return err 85 | }) 86 | } 87 | if err := g.Wait(); err != nil { 88 | return nil, err 89 | } 90 | return results, nil 91 | } 92 | 93 | results, err := Google(context.Background(), "golang") 94 | if err != nil { 95 | fmt.Fprintln(os.Stderr, err) 96 | return 97 | } 98 | for _, result := range results { 99 | fmt.Println(result) 100 | } 101 | 102 | // Output: 103 | // web result for "golang" 104 | // image result for "golang" 105 | // video result for "golang" 106 | } 107 | 108 | func TestZeroGroup(t *testing.T) { 109 | err1 := errors.New("errgroup_test: 1") 110 | err2 := errors.New("errgroup_test: 2") 111 | 112 | cases := []struct { 113 | errs []error 114 | }{ 115 | {errs: []error{}}, 116 | {errs: []error{nil}}, 117 | {errs: []error{err1}}, 118 | {errs: []error{err1, nil}}, 119 | {errs: []error{err1, nil, err2}}, 120 | } 121 | 122 | for _, tc := range cases { 123 | var g errgroup.Group 124 | 125 | var firstErr error 126 | for i, err := range tc.errs { 127 | err := err 128 | g.Go(func() error { return err }) 129 | 130 | if firstErr == nil && err != nil { 131 | firstErr = err 132 | } 133 | 134 | if gErr := g.Wait(); gErr != firstErr { 135 | t.Errorf("after %T.Go(func() error { return err }) for err in %v\n"+ 136 | "g.Wait() = %v; want %v", 137 | g, tc.errs[:i+1], err, firstErr) 138 | } 139 | } 140 | } 141 | } 142 | 143 | func TestWithContext(t *testing.T) { 144 | errDoom := errors.New("group_test: doomed") 145 | 146 | cases := []struct { 147 | errs []error 148 | want error 149 | }{ 150 | 
{want: nil}, 151 | {errs: []error{nil}, want: nil}, 152 | {errs: []error{errDoom}, want: errDoom}, 153 | {errs: []error{errDoom, nil}, want: errDoom}, 154 | } 155 | 156 | for _, tc := range cases { 157 | g, ctx := errgroup.WithContext(context.Background()) 158 | 159 | for _, err := range tc.errs { 160 | err := err 161 | g.Go(func() error { return err }) 162 | } 163 | 164 | if err := g.Wait(); err != tc.want { 165 | t.Errorf("after %T.Go(func() error { return err }) for err in %v\n"+ 166 | "g.Wait() = %v; want %v", 167 | g, tc.errs, err, tc.want) 168 | } 169 | 170 | canceled := false 171 | select { 172 | case <-ctx.Done(): 173 | canceled = true 174 | default: 175 | } 176 | if !canceled { 177 | t.Errorf("after %T.Go(func() error { return err }) for err in %v\n"+ 178 | "ctx.Done() was not closed", 179 | g, tc.errs) 180 | } 181 | } 182 | } 183 | -------------------------------------------------------------------------------- /errgroupn_test.go: -------------------------------------------------------------------------------- 1 | package errgroup_test 2 | 3 | import ( 4 | "context" 5 | "sync" 6 | "testing" 7 | "time" 8 | 9 | "golang.org/x/sync/errgroup" 10 | 11 | errgroupn "github.com/neilotoole/errgroup" 12 | ) 13 | 14 | // grouper is an abstraction of errgroup.Group's exported methodset. 
15 | type grouper interface { 16 | Go(f func() error) 17 | Wait() error 18 | } 19 | 20 | var ( 21 | _ grouper = &errgroup.Group{} 22 | _ grouper = &errgroupn.Group{} 23 | ) 24 | 25 | func newErrgroupZero() (grouper, context.Context) { 26 | return &errgroup.Group{}, context.Background() 27 | } 28 | 29 | func newErrgroupnZero() (grouper, context.Context) { 30 | return &errgroupn.Group{}, context.Background() 31 | } 32 | 33 | func newErrgroupWithContext() (grouper, context.Context) { 34 | return errgroup.WithContext(context.Background()) 35 | } 36 | func newErrgroupnWithContext() (grouper, context.Context) { 37 | return errgroupn.WithContext(context.Background()) 38 | } 39 | 40 | func newErrgroupnWithContextN(numG, qSize int) func() (grouper, context.Context) { 41 | return func() (grouper, context.Context) { 42 | return errgroupn.WithContextN(context.Background(), numG, qSize) 43 | } 44 | } 45 | 46 | func TestGroup(t *testing.T) { 47 | testCases := []struct { 48 | name string 49 | newG func() (grouper, context.Context) 50 | }{ 51 | {name: "errgroup_zero", newG: newErrgroupZero}, 52 | {name: "errgroup_wctx", newG: newErrgroupWithContext}, 53 | {name: "errgroupn_zero", newG: newErrgroupnZero}, 54 | {name: "errgroupn_wctx", newG: newErrgroupnWithContext}, 55 | {name: "errgroupn_wctx_0_0", newG: newErrgroupnWithContextN(0, 0)}, 56 | {name: "errgroupn_wctx_1_0", newG: newErrgroupnWithContextN(1, 0)}, 57 | {name: "errgroupn_wctx_1_1", newG: newErrgroupnWithContextN(1, 1)}, 58 | {name: "errgroupn_wctx_4_16", newG: newErrgroupnWithContextN(4, 16)}, 59 | {name: "errgroupn_wctx_16_4", newG: newErrgroupnWithContextN(16, 4)}, 60 | } 61 | 62 | for _, tc := range testCases { 63 | tc := tc 64 | t.Run(tc.name, func(t *testing.T) { 65 | g, _ := tc.newG() 66 | 67 | vals := make([]int, fibMax) 68 | for i := 0; i < fibMax; i++ { 69 | i := i 70 | g.Go(func() error { 71 | vals[i] = fib(i) 72 | return nil 73 | }) 74 | } 75 | 76 | err := g.Wait() 77 | if err != nil { 78 | t.Error(err) 79 | }
80 | if !equalInts(vals, fibVals[fibMax-1]) { 81 | t.Errorf("vals (%d) incorrect: %v | %v", fibMax, vals, fibVals[fibMax-1]) 82 | } 83 | 84 | // Let's do this a second time to verify that g.Go continues 85 | // to work after the first call to g.Wait 86 | vals = make([]int, fibMax) 87 | for i := 0; i < fibMax; i++ { 88 | i := i 89 | g.Go(func() error { 90 | vals[i] = fib(i) 91 | return nil 92 | }) 93 | } 94 | 95 | err = g.Wait() 96 | if err != nil { 97 | t.Error(err) 98 | } 99 | if !equalInts(vals, fibVals[fibMax-1]) { 100 | t.Errorf("vals (%d) incorrect: %v | %v", fibMax, vals, fibVals[fibMax-1]) 101 | } 102 | }) 103 | 104 | } 105 | 106 | } 107 | 108 | func TestEquivalence_GoWaitThenGoAgain(t *testing.T) { 109 | testCases := []struct { 110 | name string 111 | newG func() (grouper, context.Context) 112 | }{ 113 | {name: "errgroup_zero", newG: newErrgroupZero}, 114 | {name: "errgroup_wctx", newG: newErrgroupWithContext}, 115 | {name: "errgroupn_zero", newG: newErrgroupnZero}, 116 | {name: "errgroupn_wctx", newG: newErrgroupnWithContext}, 117 | {name: "errgroupn_wctx_16_4", newG: newErrgroupnWithContextN(16, 4)}, 118 | } 119 | 120 | for _, tc := range testCases { 121 | tc := tc 122 | 123 | t.Run(tc.name, func(t *testing.T) { 124 | g, gctx := tc.newG() 125 | 126 | actionCh := make(chan struct{}, 1) 127 | actionMu := &sync.Mutex{} 128 | actionFn := func() error { 129 | actionMu.Lock() 130 | defer actionMu.Unlock() 131 | 132 | _ = doWork(gctx, 10) 133 | 134 | actionCh <- struct{}{} 135 | return nil 136 | } 137 | 138 | g.Go(actionFn) 139 | 140 | err := g.Wait() 141 | if err != nil { 142 | t.Error(err) 143 | } 144 | if len(actionCh) != 1 { 145 | t.Errorf("actionCh should have one item") 146 | } 147 | 148 | // drain actionCh 149 | <-actionCh 150 | 151 | g.Go(actionFn) 152 | 153 | err = g.Wait() 154 | if err != nil { 155 | t.Error(err) 156 | } 157 | if len(actionCh) != 1 { 158 | t.Errorf("actionCh should have one item") 159 | } 160 | }) 161 | } 162 | } 163 | 164 | func
TestEquivalence_WaitThenGo(t *testing.T) { 165 | testCases := []struct { 166 | name string 167 | newG func() (grouper, context.Context) 168 | }{ 169 | {name: "errgroup_zero", newG: newErrgroupZero}, 170 | {name: "errgroup_wctx", newG: newErrgroupWithContext}, 171 | {name: "errgroupn_zero", newG: newErrgroupnZero}, 172 | {name: "errgroupn_wctx", newG: newErrgroupnWithContext}, 173 | {name: "errgroupn_wctx_16_4", newG: newErrgroupnWithContextN(16, 4)}, 174 | } 175 | 176 | for _, tc := range testCases { 177 | tc := tc 178 | 179 | t.Run(tc.name, func(t *testing.T) { 180 | g, gctx := tc.newG() 181 | 182 | actionCh := make(chan struct{}, 1) 183 | actionMu := &sync.Mutex{} 184 | actionFn := func() error { 185 | actionMu.Lock() 186 | defer actionMu.Unlock() 187 | 188 | _ = doWork(gctx, 10) 189 | 190 | actionCh <- struct{}{} 191 | return nil 192 | } 193 | 194 | time.Sleep(time.Second) 195 | err := g.Wait() 196 | if err != nil { 197 | t.Error(err) 198 | } 199 | if len(actionCh) != 0 { 200 | t.Errorf("actionCh should have zero items") 201 | } 202 | 203 | g.Go(actionFn) 204 | 205 | time.Sleep(time.Second) 206 | err = g.Wait() 207 | if err != nil { 208 | t.Error(err) 209 | } 210 | if len(actionCh) != 1 { 211 | t.Errorf("actionCh should have one item") 212 | } 213 | }) 214 | } 215 | } 216 | 217 | // fibVals holds computed values of the fibonacci sequence. 218 | // Each row holds the fib sequence for that row's index. That is, 219 | // the first few rows look like: 220 | // 221 | // [0] 222 | // [0 1] 223 | // [0 1 1] 224 | // [0 1 1 2] 225 | // [0 1 1 2 3] 226 | var fibVals [][]int 227 | 228 | const fibMax = 50 229 | 230 | func init() { 231 | fibVals = make([][]int, fibMax) 232 | for i := 0; i < fibMax; i++ { 233 | fibVals[i] = make([]int, i+1) 234 | if i == 0 { 235 | fibVals[0][0] = 0 236 | continue 237 | } 238 | copy(fibVals[i], fibVals[i-1]) 239 | fibVals[i][i] = fib(i) 240 | } 241 | 242 | } 243 | 244 | // fib returns the nth number of the fibonacci sequence (fib(0) == 0).
245 | func fib(n int) int { 246 | a, b, temp := 0, 1, 0 247 | for i := 0; i < n; i++ { 248 | temp = a 249 | a = b 250 | b = temp + a 251 | } 252 | return a 253 | } 254 | 255 | func equalInts(a, b []int) bool { 256 | if len(a) != len(b) { 257 | return false 258 | } 259 | 260 | for i := range a { 261 | if a[i] != b[i] { 262 | return false 263 | } 264 | } 265 | 266 | return true 267 | } 268 | -------------------------------------------------------------------------------- /errgroup.go: -------------------------------------------------------------------------------- 1 | // Package neilotoole/errgroup is an extension of the sync/errgroup 2 | // concept, and much of the code herein is descended from 3 | // or directly copied from that sync/errgroup code which 4 | // has this header comment: 5 | // 6 | // Copyright 2016 The Go Authors. All rights reserved. 7 | // Use of this source code is governed by a BSD-style 8 | // license that can be found in the LICENSE file. 9 | 10 | // Package errgroup is a drop-in alternative to sync/errgroup but 11 | // limited to N goroutines. In effect, neilotoole/errgroup is 12 | // sync/errgroup but with a worker pool of N goroutines. 13 | package errgroup 14 | 15 | import ( 16 | "context" 17 | "runtime" 18 | "sync" 19 | 20 | "sync/atomic" 21 | ) 22 | 23 | // A Group is a collection of goroutines working on subtasks that are part of 24 | // the same overall task. 25 | // 26 | // A zero Group is valid and does not cancel on error. 27 | // 28 | // This Group implementation differs from sync/errgroup in that instead 29 | // of each call to Go spawning a new Go routine, the f passed to Go 30 | // is sent to a queue channel (qCh), and is picked up by one of N 31 | // worker goroutines. The number of goroutines (numG) and the queue 32 | // channel size (qSize) are args to WithContextN. The zero Group and 33 | // the Group returned by WithContext both use default values (the value 34 | // of runtime.NumCPU) for the numG and qSize args. 
A side-effect of this 35 | // implementation is that the Go method will block while qCh is full: in 36 | // contrast, errgroup.Group's Go method never blocks (it always spawns 37 | // a new goroutine). 38 | type Group struct { 39 | cancel func() 40 | 41 | wg sync.WaitGroup 42 | 43 | errOnce sync.Once 44 | err error 45 | 46 | // numG is the maximum number of goroutines that can be started. 47 | numG int 48 | 49 | // qSize is the capacity of qCh, used for buffering funcs 50 | // passed to method Go. 51 | qSize int 52 | 53 | // qCh is the buffer used to hold funcs passed to method Go 54 | // before they are picked up by worker goroutines. 55 | qCh chan func() error 56 | 57 | // qMu protects qCh. 58 | qMu sync.Mutex 59 | 60 | // gCount tracks the number of worker goroutines. 61 | gCount int64 62 | } 63 | 64 | // WithContext returns a new Group and an associated Context derived from ctx. 65 | // It is equivalent to WithContextN(ctx, 0, 0). 66 | func WithContext(ctx context.Context) (*Group, context.Context) { 67 | return WithContextN(ctx, 0, 0) // zero indicates default values 68 | } 69 | 70 | // WithContextN returns a new Group and an associated Context derived from ctx. 71 | // 72 | // The derived Context is canceled the first time a function passed to Go 73 | // returns a non-nil error or the first time Wait returns, whichever occurs 74 | // first. 75 | // 76 | // Param numG controls the number of worker goroutines. Param qSize 77 | // controls the size of the queue channel that holds functions passed 78 | // to method Go: while the queue channel is full, Go blocks. 79 | // If numG <= 0, the value of runtime.NumCPU is used; if qSize is 80 | // also <= 0, a qSize of runtime.NumCPU is used. 
81 | func WithContextN(ctx context.Context, numG, qSize int) (*Group, context.Context) { 82 | ctx, cancel := context.WithCancel(ctx) 83 | return &Group{cancel: cancel, numG: numG, qSize: qSize}, ctx 84 | } 85 | 86 | // Wait blocks until all function calls from the Go method have returned, then 87 | // returns the first non-nil error (if any) from them. 88 | func (g *Group) Wait() error { 89 | g.qMu.Lock() 90 | 91 | if g.qCh != nil { 92 | // qCh is typically initialized by the first call to method Go. 93 | // qCh can be nil if Wait is invoked before the first 94 | // call to Go, hence this check before we close qCh. 95 | close(g.qCh) 96 | } 97 | 98 | // Wait for the worker goroutines to finish. 99 | g.wg.Wait() 100 | 101 | // All of the worker goroutines have finished, 102 | // so it's safe to set qCh to nil. 103 | g.qCh = nil 104 | 105 | g.qMu.Unlock() 106 | 107 | if g.cancel != nil { 108 | g.cancel() 109 | } 110 | 111 | return g.err 112 | } 113 | 114 | // Go adds the given function to a queue of functions that are called 115 | // by one of g's worker goroutines. 116 | // 117 | // The first call to return a non-nil error cancels the group; its error will be 118 | // returned by Wait. As with sync/errgroup, functions that are already queued 119 | // (or queued later) are still called. 120 | // 121 | // Go may block while g's qCh is full. 122 | func (g *Group) Go(f func() error) { 123 | // qMu is held for the whole call, including the send on qCh, so that 124 | // a concurrent Wait cannot close qCh while f is being sent. 125 | g.qMu.Lock() 126 | defer g.qMu.Unlock() 127 | 128 | if g.qCh == nil { 129 | // We need to initialize g. 130 | 131 | // A numG <= 0 would mean no worker goroutines beyond the 132 | // first could be created, which would be daft. 133 | // We want the "effective" zero value to be runtime.NumCPU. 134 | if g.numG <= 0 { 135 | // Benchmarking has shown that the optimal numG and 136 | // qSize values depend on the particular workload. In 137 | // the absence of any other deciding factor, we somewhat 138 | // arbitrarily default to NumCPU, which seems to perform 139 | // reasonably in benchmarks. Users that care about performance 140 | // tuning will use the WithContextN func to specify the numG 141 | // and qSize args. 142 | g.numG = runtime.NumCPU() 143 | if g.qSize <= 0 { 144 | g.qSize = g.numG 145 | } 146 | } 147 | 148 | // make panics on a negative capacity, so a negative qSize 149 | // (with an explicit numG) is treated as zero: an unbuffered queue. 150 | if g.qSize < 0 { 151 | g.qSize = 0 152 | } 153 | 154 | g.qCh = make(chan func() error, g.qSize) 155 | 156 | // Being that g.Go has been invoked, we'll need at 157 | // least one goroutine. 158 | atomic.StoreInt64(&g.gCount, 1) 159 | g.startG() 160 | 161 | g.qCh <- f 162 | return 163 | } 164 | 165 | g.qCh <- f 166 | 167 | // Check if we can or should start a new goroutine? 168 | g.maybeStartG() 169 | }
170 | 171 | // maybeStartG might start a new worker goroutine, if 172 | // needed and allowed. 173 | func (g *Group) maybeStartG() { 174 | if len(g.qCh) == 0 { 175 | // No point starting a new goroutine if there's 176 | // nothing in qCh 177 | return 178 | } 179 | 180 | // We have at least one item in qCh. Maybe it's time to start 181 | // a new worker goroutine? 182 | if atomic.AddInt64(&g.gCount, 1) > int64(g.numG) { 183 | // Nope: not allowed. Starting a new goroutine would put us 184 | // over the numG limit, so we back out. 185 | atomic.AddInt64(&g.gCount, -1) 186 | return 187 | } 188 | 189 | // It's safe to start a new worker goroutine. 190 | g.startG() 191 | } 192 | 193 | // startG starts a new worker goroutine. It is only called with 194 | // g.qMu held and g.qCh non-nil. 195 | func (g *Group) startG() { 196 | // Capture qCh now: Wait sets g.qCh to nil once the 197 | // workers have finished. 198 | qCh := g.qCh 199 | 200 | g.wg.Add(1) 201 | go func() { 202 | defer g.wg.Done() 203 | defer atomic.AddInt64(&g.gCount, -1) 204 | 205 | // Call queued funcs until Wait closes qCh. 206 | for f := range qCh { 207 | if f == nil { 208 | // Ignore a nil func rather than panicking. 209 | continue 210 | } 211 | 212 | if err := f(); err != nil { 213 | g.errOnce.Do(func() { 214 | g.err = err 215 | if g.cancel != nil { 216 | g.cancel() 217 | } 218 | }) 219 | // Don't return here: if workers exited on error, funcs 220 | // already queued on qCh would never be called, and Go 221 | // could block forever once every worker had exited. 222 | } 223 | } 224 | }() 225 | }
226 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | [![Actions Status](https://github.com/neilotoole/errgroup/workflows/Go/badge.svg)](https://github.com/neilotoole/errgroup/actions?query=workflow%3AGo) 3 | [![Go Report Card](https://goreportcard.com/badge/neilotoole/errgroup)](https://goreportcard.com/report/neilotoole/errgroup) 4 | [![release](https://img.shields.io/badge/release-v0.1.5-green.svg)](https://github.com/neilotoole/errgroup/releases/tag/v0.1.5) 5 | [![Coverage](https://gocover.io/_badge/github.com/neilotoole/errgroup)](https://gocover.io/github.com/neilotoole/errgroup) 6 | [![GoDoc](https://godoc.org/github.com/golang/gddo?status.svg)](https://pkg.go.dev/github.com/neilotoole/errgroup) 7 | [![license](https://img.shields.io/github/license/neilotoole/errgroup)](./LICENSE) 8 | 9 | # neilotoole/errgroup 10 | `neilotoole/errgroup` is a drop-in alternative to Go's wonderful 11 | [`sync/errgroup`](https://pkg.go.dev/golang.org/x/sync/errgroup) but 12 | limited to `N` goroutines. This is useful for interaction with rate-limited 13 | APIs, databases, and the like. 14 | 15 | 16 | > **Note** 17 | > The `sync/errgroup` package [now](https://github.com/neilotoole/errgroup/issues/14) has a [Group.SetLimit](https://pkg.go.dev/golang.org/x/sync/errgroup#Group.SetLimit) method, 18 | > which eliminates the need for `neilotoole/errgroup`. This package will no longer be maintained. Use `sync/errgroup` instead. 19 | 20 | 21 | ## Overview 22 | In effect, `neilotoole/errgroup` is `sync/errgroup` but with a worker pool 23 | of `N` goroutines.
The exported API is identical but for an additional 24 | function `WithContextN`, which allows the caller 25 | to specify the maximum number of goroutines (`numG`) and the capacity 26 | of the queue channel (`qSize`) used to hold work before it is picked 27 | up by a worker goroutine. The zero `Group` and the `Group` returned 28 | by `WithContext` have `numG` and `qSize` equal to `runtime.NumCPU`. 29 | 30 | 31 | ## Usage 32 | The exported API of this package mirrors the `sync/errgroup` package. 33 | The only change needed is the import path of the package, from: 34 | 35 | ```go 36 | import ( 37 | "golang.org/x/sync/errgroup" 38 | ) 39 | ``` 40 | 41 | to 42 | 43 | ```go 44 | import ( 45 | "github.com/neilotoole/errgroup" 46 | ) 47 | ``` 48 | 49 | Then use in the normal manner. See the [godoc](https://pkg.go.dev/github.com/neilotoole/errgroup) for more. 50 | 51 | ```go 52 | g, ctx := errgroup.WithContext(ctx) 53 | g.Go(func() error { 54 | // do something 55 | return nil 56 | }) 57 | 58 | err := g.Wait() 59 | ``` 60 | 61 | Many users will have no need to tweak the `numG` and `qSize` params. However, benchmarking 62 | may suggest particular values for your workload. For that you'll need `WithContextN`: 63 | 64 | ```go 65 | numG, qSize := 8, 4 66 | g, ctx := errgroup.WithContextN(ctx, numG, qSize) 67 | 68 | ``` 69 | 70 | ## Performance 71 | The motivation for creating `neilotoole/errgroup` was to provide rate-limiting while 72 | maintaining the lovely `sync/errgroup` semantics. Sacrificing some 73 | performance vs `sync/errgroup` was assumed. However, benchmarking 74 | suggests that this implementation can be more effective than `sync/errgroup` 75 | when tuned for a specific workload. 76 | 77 | Below is a selection of benchmark results. How to read this: a workload is _X_ tasks 78 | of _Y_ complexity.
The workload is executed for: 79 | 80 | - `sync/errgroup`, listed as `sync_errgroup` 81 | - a non-parallel implementation (`sequential`) 82 | - various `{numG, qSize}` configurations of `neilotoole/errgroup`, listed as `errgroupn_{numG}_{qSize}` 83 | 84 | ``` 85 | BenchmarkGroup_Short/complexity_5/tasks_50/errgroupn_default_16_16-16 25574 46867 ns/op 688 B/op 12 allocs/op 86 | BenchmarkGroup_Short/complexity_5/tasks_50/errgroupn_4_4-16 24908 48926 ns/op 592 B/op 12 allocs/op 87 | BenchmarkGroup_Short/complexity_5/tasks_50/errgroupn_16_4-16 24895 48313 ns/op 592 B/op 12 allocs/op 88 | BenchmarkGroup_Short/complexity_5/tasks_50/errgroupn_32_4-16 24853 48284 ns/op 592 B/op 12 allocs/op 89 | BenchmarkGroup_Short/complexity_5/tasks_50/sync_errgroup-16 18784 65826 ns/op 1858 B/op 55 allocs/op 90 | BenchmarkGroup_Short/complexity_5/tasks_50/sequential-16 10000 111483 ns/op 0 B/op 0 allocs/op 91 | 92 | BenchmarkGroup_Short/complexity_20/tasks_50/errgroupn_default_16_16-16 3745 325993 ns/op 1168 B/op 27 allocs/op 93 | BenchmarkGroup_Short/complexity_20/tasks_50/errgroupn_4_4-16 5186 227034 ns/op 1072 B/op 27 allocs/op 94 | BenchmarkGroup_Short/complexity_20/tasks_50/errgroupn_16_4-16 3970 312816 ns/op 1076 B/op 27 allocs/op 95 | BenchmarkGroup_Short/complexity_20/tasks_50/errgroupn_32_4-16 3715 320757 ns/op 1073 B/op 27 allocs/op 96 | BenchmarkGroup_Short/complexity_20/tasks_50/sync_errgroup-16 2739 432093 ns/op 1862 B/op 55 allocs/op 97 | BenchmarkGroup_Short/complexity_20/tasks_50/sequential-16 2306 520947 ns/op 0 B/op 0 allocs/op 98 | 99 | BenchmarkGroup_Short/complexity_40/tasks_250/errgroupn_default_16_16-16 354 3602666 ns/op 1822 B/op 47 allocs/op 100 | BenchmarkGroup_Short/complexity_40/tasks_250/errgroupn_4_4-16 420 2468605 ns/op 1712 B/op 47 allocs/op 101 | BenchmarkGroup_Short/complexity_40/tasks_250/errgroupn_16_4-16 334 3581349 ns/op 1716 B/op 47 allocs/op 102 | BenchmarkGroup_Short/complexity_40/tasks_250/errgroupn_32_4-16 310 3890316 ns/op 1712 B/op 47 
allocs/op 103 | BenchmarkGroup_Short/complexity_40/tasks_250/sync_errgroup-16 253 4740462 ns/op 8303 B/op 255 allocs/op 104 | BenchmarkGroup_Short/complexity_40/tasks_250/sequential-16 200 5924693 ns/op 0 B/op 0 allocs/op 105 | ``` 106 | 107 | The overall impression is that `neilotoole/errgroup` can provide higher 108 | throughput than `sync/errgroup` for these (CPU-intensive) workloads, 109 | sometimes significantly so. As always, these benchmark results should 110 | not be taken as gospel: your results may vary. 111 | 112 | 113 | ## Design Note 114 | Why require an explicit `qSize` limit? 115 | 116 | If the number of calls to `Group.Go` results in `qCh` becoming 117 | full, the `Go` method will block until worker goroutines relieve `qCh`. 118 | This behavior is in contrast to `sync/errgroup`'s `Go` method, which doesn't block. 119 | While `neilotoole/errgroup` aims to be as much of a behaviorally similar 120 | "drop-in" alternative to `sync/errgroup` as possible, this blocking behavior 121 | is a conscious deviation. 122 | 123 | Noting that the capacity of `qCh` is controlled by `qSize`, it's probable an 124 | alternative implementation could be built that uses a (growable) slice 125 | acting - if `qCh` is full - as a buffer for functions passed to `Go`. 126 | Consideration of this potential design led to this [issue](https://github.com/golang/go/issues/20352) 127 | regarding _unlimited capacity channels_, or perhaps better characterized 128 | in this particular case as "_growable capacity channels_". If such a 129 | feature existed in the language, it's possible that this implementation might 130 | have taken advantage of it, at least in the first-pass release (benchmarking notwithstanding). 131 | However benchmarking seems to suggest that a relatively 132 | small `qSize` has performance benefits for some workloads, so it's possible 133 | that the explicit `qSize` requirement is a better design choice regardless. 
134 | -------------------------------------------------------------------------------- /benchmark_test.go: -------------------------------------------------------------------------------- 1 | package errgroup_test 2 | 3 | import ( 4 | "context" 5 | "crypto/sha512" 6 | "fmt" 7 | "runtime" 8 | "testing" 9 | 10 | "golang.org/x/sync/errgroup" 11 | 12 | errgroupn "github.com/neilotoole/errgroup" 13 | ) 14 | 15 | // BenchmarkGroup_Short is a (shorter) benchmark of errgroup. 16 | // 17 | // go test -run=XXX -bench=BenchmarkGroup_Short -benchtime=1s 18 | func BenchmarkGroup_Short(b *testing.B) { 19 | cpus := runtime.NumCPU() 20 | 21 | testImpls := []struct { 22 | name string 23 | fn func(tasks, complexity int) error 24 | }{ 25 | {name: fmt.Sprintf("errgroupn_default_%d_%d", cpus, cpus), fn: doErrgroupnFunc(0, 0)}, 26 | {name: "errgroupn_4_4", fn: doErrgroupnFunc(4, 4)}, 27 | {name: "errgroupn_16_4", fn: doErrgroupnFunc(16, 4)}, 28 | {name: "errgroupn_32_4", fn: doErrgroupnFunc(32, 4)}, 29 | {name: "sync_errgroup", fn: doErrgroup}, // this is the sync/errgroup impl 30 | {name: "sequential", fn: doSequential}, // for reference, the non-parallel way 31 | } 32 | 33 | for _, complexity := range []int{5, 20, 40} { 34 | complexity := complexity 35 | 36 | b.Run(fmt.Sprintf("complexity_%d", complexity), func(b *testing.B) { 37 | for _, tasks := range []int{10, 50, 250} { 38 | tasks := tasks 39 | 40 | b.Run(fmt.Sprintf("tasks_%d", tasks), func(b *testing.B) { 41 | for _, impl := range testImpls { 42 | impl := impl 43 | 44 | b.Run(impl.name, func(b *testing.B) { 45 | b.ReportAllocs() 46 | for i := 0; i < b.N; i++ { 47 | err := impl.fn(complexity, tasks) 48 | if err != nil { 49 | b.Error(err) 50 | } 51 | } 52 | }) 53 | } 54 | }) 55 | } 56 | }) 57 | } 58 | } 59 | 60 | // BenchmarkGroup_Long benchmarks errgroupn vs errgroup at 61 | // various configurations and workloads. 
62 | // 63 | // The benchmark setup is convoluted, but results in output like so: 64 | // 65 | // BenchmarkGroup_Long/complexity_2/tasks_10/errgroupn_default-16 79617 14827 ns/op 66 | // BenchmarkGroup_Long/complexity_2/tasks_10/errgroupn_1_0-16 63069 19095 ns/op 67 | // BenchmarkGroup_Long/complexity_2/tasks_10/errgroupn_1_1-16 60290 19913 ns/op 68 | // BenchmarkGroup_Long/complexity_2/tasks_10/errgroupn_2_1-16 60476 19621 ns/op 69 | // BenchmarkGroup_Long/complexity_2/tasks_10/errgroupn_2_2-16 85765 13931 ns/op 70 | // BenchmarkGroup_Long/complexity_2/tasks_10/errgroupn_2_4-16 83857 14335 ns/op 71 | // BenchmarkGroup_Long/complexity_2/tasks_10/sync_errgroup-16 59840 19987 ns/op 72 | // BenchmarkGroup_Long/complexity_2/tasks_10/sequential-16 79074 15193 ns/op 73 | // BenchmarkGroup_Long/complexity_2/tasks_10/errgroupn_4_2-16 85455 13924 ns/op 74 | // BenchmarkGroup_Long/complexity_2/tasks_10/errgroupn_4_8-16 83496 14323 ns/op 75 | // [...] 76 | // BenchmarkGroup_Long/complexity_10/tasks_10/errgroupn_default-16 29162 41861 ns/op 77 | // BenchmarkGroup_Long/complexity_10/tasks_10/errgroupn_1_0-16 13405 94965 ns/op 78 | // BenchmarkGroup_Long/complexity_10/tasks_10/errgroupn_1_1-16 14505 82001 ns/op 79 | // BenchmarkGroup_Long/complexity_10/tasks_10/errgroupn_2_1-16 19932 58991 ns/op 80 | // BenchmarkGroup_Long/complexity_10/tasks_10/errgroupn_2_2-16 22478 54035 ns/op 81 | // BenchmarkGroup_Long/complexity_10/tasks_10/errgroupn_2_4-16 24512 49981 ns/op 82 | // BenchmarkGroup_Long/complexity_10/tasks_10/sync_errgroup-16 26193 43099 ns/op 83 | // BenchmarkGroup_Long/complexity_10/tasks_10/sequential-16 20118 60036 ns/op 84 | // BenchmarkGroup_Long/complexity_10/tasks_10/errgroupn_4_2-16 24685 48527 ns/op 85 | // BenchmarkGroup_Long/complexity_10/tasks_10/errgroupn_4_8-16 28038 42742 ns/op 86 | // [...] 
87 | // 88 | // The goal of this benchmark is to generate a bunch of data on 89 | // how errgroupn performs in different configurations {numG, qSize} 90 | // at different workloads {task complexity, num tasks}, also 91 | // including sync/errgroup ("sync_errgroup") and non-parallel 92 | // ("sequential") in the benchmark. The benchmark uses sub-benchmarks 93 | // for: {task complexity, number of tasks, the impl being tested}. 94 | // 95 | // Note that this benchmark takes a long time to run. Typically you'll 96 | // need to set the -timeout flag to 20-60m depending upon your system. 97 | // 98 | // go test -run=XXX -bench=BenchmarkGroup_Short -benchtime=1s -timeout=60m 99 | func BenchmarkGroup_Long(b *testing.B) { 100 | if testing.Short() { 101 | b.Skipf("This benchmark takes a long time to run") 102 | } 103 | 104 | b.Log("Go grab lunch, this benchmark takes a long time to run") 105 | 106 | cpus := runtime.NumCPU() 107 | 108 | testImpls := []struct { 109 | name string 110 | fn func(tasks, complexity int) error 111 | }{ 112 | // These are impls we want for reference 113 | {name: fmt.Sprintf("errgroupn_default_%d_%d", cpus, cpus), fn: doErrgroupnFunc(0, 0)}, 114 | {name: "errgroupn_1_0", fn: doErrgroupnFunc(1, 0)}, 115 | {name: "errgroupn_1_1", fn: doErrgroupnFunc(1, 1)}, 116 | {name: "errgroupn_2_1", fn: doErrgroupnFunc(2, 1)}, 117 | {name: "errgroupn_2_2", fn: doErrgroupnFunc(2, 2)}, 118 | {name: "errgroupn_2_4", fn: doErrgroupnFunc(2, 4)}, 119 | {name: "sync_errgroup", fn: doErrgroup}, // this is the sync/errgroup impl 120 | {name: "sequential", fn: doSequential}, // for reference, the non-parallel way 121 | } 122 | 123 | for _, numG := range []int{4, 16, 64, 512} { 124 | for _, qSize := range []int{2, 8, 64} { 125 | testImpls = append(testImpls, struct { 126 | name string 127 | fn func(tasks int, complexity int) error 128 | }{name: fmt.Sprintf("errgroupn_%d_%d", numG, qSize), fn: doErrgroupnFunc(numG, qSize)}) 129 | } 130 | } 131 | 132 | for _, complexity := 
range []int{2, 10, 20, 40} { 133 | complexity := complexity 134 | 135 | b.Run(fmt.Sprintf("complexity_%d", complexity), func(b *testing.B) { 136 | for _, tasks := range []int{10, 100, 500 /*, 1000, 5000*/} { 137 | tasks := tasks 138 | 139 | b.Run(fmt.Sprintf("tasks_%d", tasks), func(b *testing.B) { 140 | for _, impl := range testImpls { 141 | impl := impl 142 | 143 | b.Run(impl.name, func(b *testing.B) { 144 | b.ReportAllocs() 145 | for i := 0; i < b.N; i++ { 146 | err := impl.fn(complexity, tasks) 147 | if err != nil { 148 | b.Error(err) 149 | } 150 | } 151 | }) 152 | } 153 | }) 154 | } 155 | }) 156 | } 157 | } 158 | 159 | // doWork spends some time doing something that the 160 | // compiler won't zap away. The complexity param controls 161 | // how long the work takes. 162 | func doWork(ctx context.Context, complexity int) error { 163 | const text = `In Xanadu did Kubla Khan 164 | A stately pleasure-dome decree: 165 | Where Alph, the sacred river, ran 166 | Through caverns measureless to man 167 | Down to a sunless sea. 168 | So twice five miles of fertile ground 169 | With walls and towers were girdled round; 170 | And there were gardens bright with sinuous rills, 171 | Where blossomed many an incense-bearing tree; 172 | And here were forests ancient as the hills, 173 | Enfolding sunny spots of greenery. 
174 | ` 175 | 176 | b := []byte(text) 177 | var res [sha512.Size256]byte 178 | 179 | for i := 0; i < complexity; i++ { 180 | select { 181 | case <-ctx.Done(): 182 | return ctx.Err() 183 | default: 184 | } 185 | res = sha512.Sum512_256(b) 186 | b = res[0:] 187 | } 188 | runtime.KeepAlive(b) 189 | return nil 190 | } 191 | 192 | func doSequential(complexity, tasks int) error { 193 | ctx := context.Background() 194 | var err error 195 | 196 | for i := 0; i <= tasks; i++ { 197 | err = doWork(ctx, complexity) 198 | if err != nil { 199 | break 200 | } 201 | } 202 | return err 203 | } 204 | 205 | func doErrgroup(complexity, tasks int) error { 206 | g, ctx := errgroup.WithContext(context.Background()) 207 | for i := 0; i <= tasks; i++ { 208 | g.Go(func() error { 209 | return doWork(ctx, complexity) 210 | }) 211 | } 212 | 213 | return g.Wait() 214 | } 215 | 216 | func doErrgroupnFunc(numG, qSize int) func(int, int) error { 217 | return func(tasks, complexity int) error { 218 | g, ctx := errgroupn.WithContextN(context.Background(), numG, qSize) 219 | for i := 0; i <= tasks; i++ { 220 | g.Go(func() error { 221 | return doWork(ctx, complexity) 222 | }) 223 | } 224 | 225 | return g.Wait() 226 | } 227 | } 228 | --------------------------------------------------------------------------------