├── .go-update ├── go.mod ├── .github └── workflows │ └── go-presubmit.yml ├── go.sum ├── bench_test.go ├── throttle.go ├── LICENSE ├── gatherer.go ├── single.go ├── examples └── copytree │ └── copytree.go ├── example_test.go ├── taskgroup.go ├── taskgroup_test.go └── README.md /.go-update: -------------------------------------------------------------------------------- 1 | cleanup() { 2 | ( cd "$GOBIN" && rm -vf -- copytree ) 3 | } 4 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/creachadair/taskgroup 2 | 3 | go 1.24 4 | 5 | require ( 6 | golang.org/x/exp/typeparams v0.0.0-20231108232855-2478ac86f678 // indirect 7 | golang.org/x/mod v0.23.0 // indirect 8 | golang.org/x/sync v0.11.0 // indirect 9 | golang.org/x/tools v0.30.0 // indirect 10 | honnef.co/go/tools v0.6.1 // indirect 11 | ) 12 | 13 | tool honnef.co/go/tools/staticcheck 14 | -------------------------------------------------------------------------------- /.github/workflows/go-presubmit.yml: -------------------------------------------------------------------------------- 1 | name: Go presubmit 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | pull_request: 8 | types: [opened, reopened, synchronize] 9 | workflow_dispatch: 10 | 11 | permissions: 12 | contents: read 13 | 14 | jobs: 15 | build: 16 | name: Go presubmit 17 | runs-on: ${{ matrix.os }} 18 | strategy: 19 | matrix: 20 | go-version: ['stable'] 21 | os: ['ubuntu-24.04'] 22 | steps: 23 | - uses: actions/checkout@v6 24 | - name: Install Go ${{ matrix.go-version }} 25 | uses: actions/setup-go@v6 26 | with: 27 | go-version: ${{ matrix.go-version }} 28 | - uses: creachadair/go-presubmit-action@v2 29 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/BurntSushi/toml 
v1.4.1-0.20240526193622-a339e1f7089c h1:pxW6RcqyfI9/kWtOwnv/G+AzdKuy2ZrqINhenH4HyNs= 2 | github.com/BurntSushi/toml v1.4.1-0.20240526193622-a339e1f7089c/go.mod h1:ukJfTF/6rtPPRCnwkur4qwRxa8vTRFBF0uk2lLoLwho= 3 | github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= 4 | github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= 5 | golang.org/x/exp/typeparams v0.0.0-20231108232855-2478ac86f678 h1:1P7xPZEwZMoBoz0Yze5Nx2/4pxj6nw9ZqHWXqP0iRgQ= 6 | golang.org/x/exp/typeparams v0.0.0-20231108232855-2478ac86f678/go.mod h1:AbB0pIl9nAr9wVwH+Z2ZpaocVmF5I4GyWCDIsVjR0bk= 7 | golang.org/x/mod v0.23.0 h1:Zb7khfcRGKk+kqfxFaP5tZqCnDZMjC5VtUBs87Hr6QM= 8 | golang.org/x/mod v0.23.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY= 9 | golang.org/x/sync v0.11.0 h1:GGz8+XQP4FvTTrjZPzNKTMFtSXH80RAzG+5ghFPgK9w= 10 | golang.org/x/sync v0.11.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= 11 | golang.org/x/tools v0.30.0 h1:BgcpHewrV5AUp2G9MebG4XPFI1E2W41zU1SaqVA9vJY= 12 | golang.org/x/tools v0.30.0/go.mod h1:c347cR/OJfw5TI+GfX7RUPNMdDRRbjvYTS0jPyvsVtY= 13 | honnef.co/go/tools v0.6.1 h1:R094WgE8K4JirYjBaOpz/AvTyUu/3wbmAoskKN/pxTI= 14 | honnef.co/go/tools v0.6.1/go.mod h1:3puzxxljPCe8RGJX7BIy1plGbxEOZni5mR2aXe3/uk4= 15 | -------------------------------------------------------------------------------- /bench_test.go: -------------------------------------------------------------------------------- 1 | package taskgroup_test 2 | 3 | import ( 4 | "math/rand/v2" 5 | "sync" 6 | "testing" 7 | ) 8 | 9 | // A very rough benchmark comparing the performance of accumulating values with 10 | // a separate goroutine via a channel, vs. accumulating them directly under a 11 | // lock. The workload here is intentionally minimal, so the benchmark is 12 | // measuring more or less just the overhead. 
13 | 14 | func BenchmarkChan(b *testing.B) { 15 | ch := make(chan int) 16 | done := make(chan struct{}) 17 | var total int 18 | go func() { 19 | defer close(done) 20 | for v := range ch { 21 | total += v 22 | } 23 | }() 24 | b.ResetTimer() // discount the setup time. 25 | 26 | var wg sync.WaitGroup 27 | wg.Add(b.N) 28 | for i := 0; i < b.N; i++ { 29 | go func() { 30 | defer wg.Done() 31 | ch <- rand.IntN(1000) 32 | }() 33 | } 34 | wg.Wait() 35 | close(ch) 36 | <-done 37 | } 38 | 39 | func BenchmarkLock(b *testing.B) { 40 | var μ sync.Mutex 41 | var total int 42 | report := func(v int) { 43 | μ.Lock() 44 | defer μ.Unlock() 45 | total += v 46 | } 47 | b.ResetTimer() // discount the setup time. 48 | 49 | var wg sync.WaitGroup 50 | wg.Add(b.N) 51 | for i := 0; i < b.N; i++ { 52 | go func() { 53 | defer wg.Done() 54 | report(rand.IntN(1000)) 55 | }() 56 | } 57 | wg.Wait() 58 | } 59 | -------------------------------------------------------------------------------- /throttle.go: -------------------------------------------------------------------------------- 1 | package taskgroup 2 | 3 | // A Throttle rate-limits the number of concurrent goroutines that can execute 4 | // in parallel to some fixed number. A zero Throttle is ready for use, but 5 | // imposes no limit on parallel execution. 6 | type Throttle struct { 7 | adm chan struct{} 8 | } 9 | 10 | // NewThrottle constructs a [Throttle] with a capacity of n goroutines. 11 | // If n ≤ 0, the resulting Throttle imposes no limit. 12 | func NewThrottle(n int) Throttle { 13 | if n <= 0 { 14 | return Throttle{} 15 | } 16 | return Throttle{adm: make(chan struct{}, n)} 17 | } 18 | 19 | // Limit returns a function that starts each [Task] passed to it in g, 20 | // respecting the rate limit imposed by t. Each call to Limit yields a fresh 21 | // start function, and all the functions returned share the capacity of t. 
22 | func (t Throttle) Limit(g *Group) StartFunc { 23 | if t.adm == nil { 24 | return g.Go 25 | } 26 | return func(task Task) { 27 | t.adm <- struct{}{} // wait for a semaphore slot 28 | g.Go(func() error { 29 | defer func() { <-t.adm }() // yield a semaphore slot 30 | return task() 31 | }) 32 | } 33 | } 34 | 35 | // A StartFunc executes each [Task] passed to it in a [Group]. 36 | type StartFunc func(Task) 37 | 38 | // Go is a legibility shorthand for calling s with task. 39 | func (s StartFunc) Go(task Task) { s(task) } 40 | 41 | // Run is a legibility shorthand for calling s with a task that runs f and 42 | // reports a nil error. 43 | func (s StartFunc) Run(f func()) { s(noError(f)) } 44 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2015, Michael J. Fromberger 2 | All rights reserved. 3 | 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions are met: 6 | 7 | 1. Redistributions of source code must retain the above copyright notice, this 8 | list of conditions and the following disclaimer. 9 | 10 | 2. Redistributions in binary form must reproduce the above copyright notice, 11 | this list of conditions and the following disclaimer in the documentation 12 | and/or other materials provided with the distribution. 13 | 14 | 3. Neither the name of the copyright holder nor the names of its contributors 15 | may be used to endorse or promote products derived from this software 16 | without specific prior written permission. 17 | 18 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 19 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 20 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 21 | DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 22 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 23 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 24 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 25 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 26 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 27 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 28 | -------------------------------------------------------------------------------- /gatherer.go: -------------------------------------------------------------------------------- 1 | package taskgroup 2 | 3 | import "sync" 4 | 5 | // A Gatherer manages a group of [Task] functions that report values, and 6 | // gathers the values they return. 7 | type Gatherer[T any] struct { 8 | run func(Task) // start the task in a goroutine 9 | 10 | μ sync.Mutex 11 | gather func(T) // handle values reported by tasks 12 | } 13 | 14 | func (g *Gatherer[T]) report(v T) { 15 | g.μ.Lock() 16 | defer g.μ.Unlock() 17 | g.gather(v) 18 | } 19 | 20 | // Gather creates a new empty gatherer that uses run to execute tasks returning 21 | // values of type T. 22 | // 23 | // If gather != nil, values reported by successful tasks are passed to the 24 | // function, otherwise such values are discarded. Calls to gather are 25 | // synchronized to a single goroutine. 26 | // 27 | // If run == nil, Gather will panic. 28 | func Gather[T any](run func(Task), gather func(T)) *Gatherer[T] { 29 | if run == nil { 30 | panic("run function is nil") 31 | } 32 | if gather == nil { 33 | gather = func(T) {} 34 | } 35 | return &Gatherer[T]{run: run, gather: gather} 36 | } 37 | 38 | // Call runs f in g. If f reports an error, the error is propagated to the 39 | // runner; otherwise the non-error value reported by f is gathered. 
40 | func (g *Gatherer[T]) Call(f func() (T, error)) { 41 | g.run(func() error { 42 | v, err := f() 43 | if err == nil { 44 | g.report(v) 45 | } 46 | return err 47 | }) 48 | } 49 | 50 | // Run runs f in g, and gathers the value it reports. 51 | func (g *Gatherer[T]) Run(f func() T) { 52 | g.run(func() error { g.report(f()); return nil }) 53 | } 54 | 55 | // Report runs f in g. Any values passed to report are gathered. If f reports 56 | // an error, that error is propagated to the runner. Any values sent before f 57 | // returns are still gathered, even if f reports an error. 58 | func (g *Gatherer[T]) Report(f func(report func(T)) error) { 59 | g.run(func() error { return f(g.report) }) 60 | } 61 | -------------------------------------------------------------------------------- /single.go: -------------------------------------------------------------------------------- 1 | package taskgroup 2 | 3 | // A Single manages a single background goroutine. The task is started when the 4 | // value is first created, and the caller can use the Wait method to block 5 | // until it has exited. 6 | type Single[T any] struct { 7 | valc chan T 8 | val T 9 | } 10 | 11 | // Wait blocks until the task monitored by s has completed and returns the 12 | // value it reported. 13 | func (s *Single[T]) Wait() T { 14 | if v, ok := <-s.valc; ok { 15 | // This is the first call to receive a value: 16 | // Update val and close the channel (in that order). 17 | s.val = v 18 | close(s.valc) 19 | } 20 | return s.val 21 | } 22 | 23 | // Go runs task in a new goroutine. The caller must call Wait to wait for the 24 | // task to return and collect its value. 25 | func Go[T any](task func() T) *Single[T] { 26 | // N.B. This is closed by Wait. 27 | valc := make(chan T, 1) 28 | go func() { valc <- task() }() 29 | 30 | return &Single[T]{valc: valc} 31 | } 32 | 33 | // Run runs task in a new goroutine. The caller must call Wait to wait for the 34 | // task to return. 
The error reported by Wait is always nil. 35 | func Run(task func()) *Single[error] { return Go(noError(task)) } 36 | 37 | // Call starts task in a new goroutine. The caller must call Wait to wait for 38 | // the task to return and collect its result. 39 | func Call[T any](task func() (T, error)) *Single[Result[T]] { 40 | return Go(func() Result[T] { 41 | v, err := task() 42 | return Result[T]{Value: v, Err: err} 43 | }) 44 | } 45 | 46 | // A Result is a pair of an arbitrary value and an error. 47 | type Result[T any] struct { 48 | Value T 49 | Err error 50 | } 51 | 52 | // Get returns the fields of r as results. It is a convenience method for 53 | // unpacking the results of a Call. 54 | // 55 | // Typical usage: 56 | // 57 | // s := taskgroup.Call(func() (int, error) { ... }) 58 | // v, err := s.Wait().Get() 59 | func (r Result[T]) Get() (T, error) { return r.Value, r.Err } 60 | -------------------------------------------------------------------------------- /examples/copytree/copytree.go: -------------------------------------------------------------------------------- 1 | // Binary copytree is an example program to demonstrate the use of the group 2 | // and throttle packages to manage concurrency. It recursively copies a tree 3 | // of files from one directory to another. 
4 | // 5 | // Usage: 6 | // 7 | // copytree -from /path/to/source -to /path/to/target 8 | package main 9 | 10 | import ( 11 | "context" 12 | "flag" 13 | "io" 14 | "log" 15 | "os" 16 | "path/filepath" 17 | "strings" 18 | 19 | "github.com/creachadair/taskgroup" 20 | ) 21 | 22 | var ( 23 | srcPath = flag.String("from", "", "Source path (required)") 24 | dstPath = flag.String("to", "", "Destination path (required)") 25 | maxWorkers = flag.Int("workers", 1, "Maximum number of concurrent tasks") 26 | ) 27 | 28 | func main() { 29 | flag.Parse() 30 | 31 | if *srcPath == "" || *dstPath == "" { 32 | log.Fatal("You must provide both --from and --to paths") 33 | } 34 | var destExists bool 35 | if _, err := os.Stat(*dstPath); err == nil { 36 | destExists = true 37 | } 38 | 39 | ctx, cancel := context.WithCancel(context.Background()) 40 | g, start := taskgroup.New(cancel).Limit(*maxWorkers) 41 | 42 | err := filepath.Walk(*srcPath, func(path string, fi os.FileInfo, err error) error { 43 | if err != nil { 44 | return err 45 | } else if err := ctx.Err(); err != nil { 46 | return err 47 | } 48 | 49 | target := adjustPath(path) 50 | if fi.IsDir() { 51 | return os.MkdirAll(target, fi.Mode()) 52 | } else if fi.Mode()&os.ModeType == os.ModeSymlink { 53 | start(func() error { 54 | log.Printf("Relinking %q", path) 55 | return copyLink(ctx, path, target) 56 | }) 57 | } else { 58 | start(func() error { 59 | log.Printf("Copying %q", path) 60 | return copyFile(ctx, path, target) 61 | }) 62 | } 63 | return nil 64 | }) 65 | if err != nil { 66 | log.Printf("Error traversing directory: %v", err) 67 | cancel() 68 | } 69 | if err := g.Wait(); err != nil { 70 | log.Printf("Error copying: %v", err) 71 | if !destExists { 72 | log.Printf("Cleaning up %q...", *dstPath) 73 | os.RemoveAll(*dstPath) 74 | } 75 | os.Exit(1) 76 | } 77 | } 78 | 79 | // adjustPath modifies path to be relative to the destination by stripping off 80 | // the source prefix and conjoining it with the destination path. 
81 | func adjustPath(path string) string { 82 | return filepath.Join(*dstPath, strings.TrimPrefix(path, *srcPath)) 83 | } 84 | 85 | // copyFile copies a plain file from source to target. 86 | func copyFile(ctx context.Context, source, target string) error { 87 | if err := ctx.Err(); err != nil { 88 | return err 89 | } 90 | in, err := os.Open(source) 91 | if err != nil { 92 | return err 93 | } 94 | defer in.Close() 95 | out, err := os.Create(target) 96 | if err != nil { 97 | return err 98 | } 99 | if _, err := io.Copy(out, in); err != nil { 100 | out.Close() 101 | return err 102 | } 103 | return out.Close() 104 | } 105 | 106 | // copyLink transfers a symlink from source to target. It is an error if the 107 | // content of source cannot be made relative to source. 108 | func copyLink(ctx context.Context, source, target string) error { 109 | link, err := os.Readlink(source) 110 | if err != nil { 111 | return err 112 | } 113 | if !filepath.IsAbs(link) { 114 | link = filepath.Join(source, link) 115 | } 116 | rel, err := filepath.Rel(source, link) 117 | if err != nil { 118 | return err 119 | } 120 | return os.Symlink(rel, target) 121 | } 122 | -------------------------------------------------------------------------------- /example_test.go: -------------------------------------------------------------------------------- 1 | package taskgroup_test 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "fmt" 7 | "io" 8 | "log" 9 | "math/rand/v2" 10 | "strings" 11 | "time" 12 | 13 | "github.com/creachadair/taskgroup" 14 | ) 15 | 16 | func ExampleGroup() { 17 | msg := make(chan string) 18 | g := taskgroup.New(nil) 19 | g.Run(func() { 20 | msg <- "ping" 21 | fmt.Println(<-msg) 22 | }) 23 | g.Run(func() { 24 | fmt.Println(<-msg) 25 | msg <- "pong" 26 | }) 27 | g.Wait() 28 | fmt.Println("") 29 | 30 | // Output: 31 | // ping 32 | // pong 33 | // 34 | } 35 | 36 | func ExampleNew_cancel() { 37 | ctx, cancel := context.WithCancel(context.Background()) 38 | defer cancel() 39 | 40 | const 
badTask = 5 41 | 42 | // Construct a group in which any task error cancels the context. 43 | g := taskgroup.New(cancel) 44 | 45 | for i := range 10 { 46 | g.Go(func() error { 47 | if i == badTask { 48 | return fmt.Errorf("task %d failed", i) 49 | } 50 | select { 51 | case <-ctx.Done(): 52 | return ctx.Err() 53 | case <-time.After(10 * time.Second): 54 | return nil 55 | } 56 | }) 57 | } 58 | 59 | if err := g.Wait(); err == nil { 60 | log.Fatal("I expected an error here") 61 | } else { 62 | fmt.Println(err.Error()) 63 | } 64 | // Output: 65 | // task 5 failed 66 | } 67 | 68 | func ExampleNew_listen() { 69 | // The taskgroup itself will only report the first non-nil task error, but 70 | // you can use an error listener to accumulate all of them. 71 | // Calls to the listener are synchronized, so we don't need a lock. 72 | var all []error 73 | g := taskgroup.New(func(e error) { 74 | all = append(all, e) 75 | }) 76 | g.Go(func() error { return errors.New("badness 1") }) 77 | g.Go(func() error { return errors.New("badness 2") }) 78 | g.Go(func() error { return errors.New("badness 3") }) 79 | 80 | if err := g.Wait(); err == nil || !strings.Contains(err.Error(), "badness") { 81 | log.Fatalf("Unexpected error: %v", err) 82 | } 83 | fmt.Println(errors.Join(all...)) 84 | // Unordered output: 85 | // badness 1 86 | // badness 2 87 | // badness 3 88 | } 89 | 90 | func ExampleGroup_Limit() { 91 | var p peakValue 92 | 93 | g, start := taskgroup.New(nil).Limit(4) 94 | for range 100 { 95 | start.Run(func() { 96 | p.inc() 97 | defer p.dec() 98 | time.Sleep(1 * time.Microsecond) 99 | }) 100 | } 101 | g.Wait() 102 | fmt.Printf("Max active ≤ 4: %v\n", p.max <= 4) 103 | // Output: 104 | // Max active ≤ 4: true 105 | } 106 | 107 | type slowReader struct { 108 | n int 109 | d time.Duration 110 | } 111 | 112 | func (s *slowReader) Read(data []byte) (int, error) { 113 | if s.n == 0 { 114 | return 0, io.EOF 115 | } 116 | time.Sleep(s.d) 117 | nr := min(len(data), s.n) 118 | s.n -= nr
119 | for i := range nr { 120 | data[i] = 'x' 121 | } 122 | return nr, nil 123 | } 124 | 125 | func ExampleSingle() { 126 | // A fake reader to simulate a slow file read. 127 | // 2500 bytes and each read takes 50ms. 128 | sr := &slowReader{2500, 50 * time.Millisecond} 129 | 130 | // Start a task to read the "file" in the background. 131 | fmt.Println("start") 132 | s := taskgroup.Call(func() ([]byte, error) { 133 | return io.ReadAll(sr) 134 | }) 135 | 136 | fmt.Println("work, work") 137 | data, err := s.Wait().Get() 138 | if err != nil { 139 | log.Fatalf("Read failed: %v", err) 140 | } 141 | fmt.Println("done") 142 | fmt.Println(len(data), "bytes") 143 | 144 | // Output: 145 | // start 146 | // work, work 147 | // done 148 | // 2500 bytes 149 | } 150 | 151 | func ExampleGatherer() { 152 | const numTasks = 25 153 | input := rand.Perm(500) 154 | 155 | // Start a bunch of tasks to find elements in the input. 156 | g, start := taskgroup.New(nil).Limit(10) 157 | 158 | // We can pass g.Go directly, or as in this example we can use a throttled 159 | // start function. 160 | var total int 161 | c := taskgroup.Gather(start, func(v int) { 162 | total += v 163 | }) 164 | 165 | for i := range numTasks { 166 | target := i + 1 167 | c.Call(func() (int, error) { 168 | for _, v := range input { 169 | if v == target { 170 | return v, nil 171 | } 172 | } 173 | return 0, errors.New("not found") 174 | }) 175 | } 176 | 177 | // Wait for the searchers to finish. Once they do, it is safe to access the 178 | // state managed by the gatherer. 179 | g.Wait() 180 | 181 | // Now get the final result.
182 | fmt.Println(total) 183 | // Output: 184 | // 325 185 | } 186 | 187 | func ExampleGatherer_Report() { 188 | type val struct { 189 | who string 190 | v int 191 | } 192 | 193 | g := taskgroup.New(nil) 194 | c := taskgroup.Gather(g.Go, func(z val) { 195 | fmt.Println(z.who, z.v) 196 | }) 197 | 198 | // The Report method passes its argument a function to report multiple 199 | // values to the collector. 200 | c.Report(func(report func(v val)) error { 201 | for i := range 3 { 202 | report(val{"even", 2 * i}) 203 | } 204 | return nil 205 | }) 206 | // Multiple reporters are fine. 207 | c.Report(func(report func(v val)) error { 208 | for i := range 3 { 209 | report(val{"odd", 2*i + 1}) 210 | } 211 | // An error from a reporter is propagated like any other task error. 212 | return errors.New("no bueno") 213 | }) 214 | err := g.Wait() 215 | if err == nil || err.Error() != "no bueno" { 216 | log.Fatalf("Unexpected error: %v", err) 217 | } 218 | // Unordered output: 219 | // even 0 220 | // odd 1 221 | // even 2 222 | // odd 3 223 | // even 4 224 | // odd 5 225 | } 226 | 227 | func ExampleThrottle() { 228 | var p peakValue 229 | 230 | work := func() { 231 | p.inc() 232 | defer p.dec() 233 | time.Sleep(time.Microsecond) 234 | } 235 | 236 | // Create a throttle with shared capacity among multiple groups. 237 | t := taskgroup.NewThrottle(10) 238 | 239 | var g1 taskgroup.Group 240 | var g2 taskgroup.Group 241 | 242 | // Start functions for all the calls to Limit share the capacity of t. 
243 | start1 := t.Limit(&g1) 244 | start2 := t.Limit(&g2) 245 | 246 | for range 100 { 247 | start1.Run(work) 248 | start2.Run(work) 249 | } 250 | 251 | g1.Wait() 252 | g2.Wait() 253 | fmt.Printf("Max active ≤ 10: %v\n", p.max <= 10) 254 | // Output: 255 | // Max active ≤ 10: true 256 | } 257 | -------------------------------------------------------------------------------- /taskgroup.go: -------------------------------------------------------------------------------- 1 | // Package taskgroup manages collections of cooperating goroutines. 2 | // It defines a [Group] that handles waiting for goroutine termination and the 3 | // propagation of error values. The caller may provide a callback to filter and 4 | // respond to task errors. 5 | package taskgroup 6 | 7 | import ( 8 | "fmt" 9 | "reflect" 10 | "sync" 11 | "sync/atomic" 12 | ) 13 | 14 | // A Task function is the basic unit of work in a [Group]. Errors reported by 15 | // tasks are collected and reported by the group. 16 | type Task func() error 17 | 18 | // A Group manages a collection of cooperating goroutines. Add new tasks to 19 | // the group with [Group.Go] and [Group.Run]. Call [Group.Wait] to wait for 20 | // the tasks to complete. A zero value is ready for use, but must not be copied 21 | // after its first use. 22 | // 23 | // The group collects any errors returned by the tasks in the group. The first 24 | // non-nil error reported by any task (and not otherwise filtered) is returned 25 | // from the Wait method. 26 | type Group struct { 27 | wg sync.WaitGroup // counter for active goroutines 28 | 29 | // active is nonzero when the group is "active", meaning there has been at 30 | // least one call to Go since the group was created or the last Wait. 31 | // 32 | // Together active and μ work as a kind of resettable sync.Once; the fast 33 | // path reads active and only acquires μ if it discovers setup is needed. 
34 | active atomic.Bool 35 | 36 | μ sync.Mutex // guards the fields below 37 | err error // error returned from Wait 38 | onError errorFunc // called each time a task returns non-nil 39 | } 40 | 41 | // activate resets the state of the group and marks it as active. This is 42 | // triggered by adding a goroutine to an empty group. 43 | func (g *Group) activate() { 44 | g.μ.Lock() 45 | defer g.μ.Unlock() 46 | if !g.active.Load() { // still inactive 47 | g.err = nil 48 | g.active.Store(true) 49 | } 50 | } 51 | 52 | // New constructs a new empty group with the specified error filter. 53 | // See [Group.OnError] for a description of how errors are filtered. 54 | // If ef == nil, no filtering is performed. 55 | func New(ef any) *Group { return new(Group).OnError(ef) } 56 | 57 | // OnError sets the error filter for g. If ef == nil, the error filter is 58 | // removed and errors are no longer filtered. Otherwise, each non-nil error 59 | // reported by a task running in g is passed to ef. 60 | // 61 | // The concrete type of ef must be a function with one of the following 62 | // signature schemes, or OnError will panic. 63 | // 64 | // If ef is: 65 | // 66 | // func() 67 | // 68 | // then ef is called once per reported error, and the error is not modified. 69 | // 70 | // If ef is: 71 | // 72 | // func(error) 73 | // 74 | // then ef is called with each reported error, and the error is not modified. 75 | // 76 | // If ef is: 77 | // 78 | // func(error) error 79 | // 80 | // then ef is called with each reported error, and its result replaces the 81 | // reported value. This permits ef to suppress or replace the error value 82 | // selectively. 83 | // 84 | // Calls to ef are synchronized so that it is safe for ef to manipulate local 85 | // data structures without additional locking. It is safe to call OnError while 86 | // tasks are active in g. 
87 | func (g *Group) OnError(ef any) *Group { 88 | filter := adaptErrorFunc(ef) 89 | g.μ.Lock() 90 | defer g.μ.Unlock() 91 | g.onError = filter 92 | return g 93 | } 94 | 95 | // Go runs task in a new goroutine in g. 96 | func (g *Group) Go(task Task) { 97 | g.wg.Add(1) 98 | if !g.active.Load() { 99 | g.activate() 100 | } 101 | go func() { 102 | defer g.wg.Done() 103 | if err := task(); err != nil { 104 | g.handleError(err) 105 | } 106 | }() 107 | } 108 | 109 | // Run runs task in a new goroutine in g. 110 | // The resulting task reports a nil error. 111 | func (g *Group) Run(task func()) { g.Go(noError(task)) } 112 | 113 | func (g *Group) handleError(err error) { 114 | g.μ.Lock() 115 | defer g.μ.Unlock() 116 | e := g.onError.filter(err) 117 | if e != nil && g.err == nil { 118 | g.err = e // capture the first unfiltered error always 119 | } 120 | } 121 | 122 | // Wait blocks until all the goroutines currently active in the group have 123 | // returned, and all reported errors have been delivered to the callback. It 124 | // returns the first non-nil error reported by any of the goroutines in the 125 | // group and not filtered by an OnError callback. 126 | // 127 | // As with [sync.WaitGroup], new tasks can be added to g during a call to Wait 128 | // only if the group contains at least one active task when Wait is called and 129 | // continuously thereafter until the last concurrent call to g.Go returns. 130 | // 131 | // Wait may be called from at most one goroutine at a time. After Wait has 132 | // returned, the group is ready for reuse. 133 | func (g *Group) Wait() error { 134 | g.wg.Wait() 135 | g.μ.Lock() 136 | defer g.μ.Unlock() 137 | 138 | // If the group is still active, deactivate it now. 139 | if g.active.Load() { 140 | g.active.Store(false) 141 | } 142 | return g.err 143 | } 144 | 145 | // An errorFunc is called by a group each time a task reports an error. 
Its 146 | // return value replaces the reported error, so the errorFunc can filter or 147 | // suppress errors by modifying or discarding the input error. 148 | type errorFunc func(error) error 149 | 150 | func (ef errorFunc) filter(err error) error { 151 | if ef == nil { 152 | return err 153 | } 154 | return ef(err) 155 | } 156 | 157 | var ( 158 | triggerType = reflect.TypeOf(func() {}) 159 | listenType = reflect.TypeOf(func(error) {}) 160 | filterType = reflect.TypeOf(func(error) error { return nil }) 161 | ) 162 | 163 | func adaptErrorFunc(ef any) errorFunc { 164 | v := reflect.ValueOf(ef) 165 | if !v.IsValid() { 166 | // OK, ef == nil, nothing to do 167 | return nil 168 | } else if t := v.Type(); t.ConvertibleTo(triggerType) { 169 | f := v.Convert(triggerType).Interface().(func()) 170 | return func(err error) error { f(); return err } 171 | } else if t.ConvertibleTo(listenType) { 172 | f := v.Convert(listenType).Interface().(func(error)) 173 | return func(err error) error { f(err); return err } 174 | } else if t.ConvertibleTo(filterType) { 175 | return errorFunc(v.Convert(filterType).Interface().(func(error) error)) 176 | } else { 177 | panic(fmt.Sprintf("unsupported filter type %T", ef)) 178 | } 179 | } 180 | 181 | func noError(f func()) Task { return func() error { f(); return nil } } 182 | 183 | // Limit returns g and a [StartFunc] that starts each task passed to it in g, 184 | // allowing no more than n tasks to be active concurrently. If n ≤ 0, no limit 185 | // is enforced. 186 | // 187 | // The limiting mechanism is optional, and the underlying group is not 188 | // restricted. A call to the start function will block until a slot is 189 | // available, but calling g.Go directly will add a task unconditionally and 190 | // will not take up a limiter slot. 191 | // 192 | // This is a shorthand for constructing a [Throttle] with capacity n and 193 | // calling its Limit method. 
If n ≤ 0, the start function is equivalent to 194 | // g.Go, which enforces no limit. To share a throttle among multiple groups, 195 | // construct the throttle separately. 196 | func (g *Group) Limit(n int) (*Group, StartFunc) { t := NewThrottle(n); return g, t.Limit(g) } 197 | -------------------------------------------------------------------------------- /taskgroup_test.go: -------------------------------------------------------------------------------- 1 | package taskgroup_test 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "math" 7 | "math/rand/v2" 8 | "reflect" 9 | "sync" 10 | "sync/atomic" 11 | "testing" 12 | "testing/synctest" 13 | "time" 14 | 15 | "github.com/creachadair/taskgroup" 16 | ) 17 | 18 | const numTasks = 64 19 | 20 | // randms returns a random duration of up to n milliseconds. 21 | func randms(n int) time.Duration { return time.Duration(rand.IntN(n)) * time.Millisecond } 22 | 23 | // busyWork returns a Task that does nothing for n ms and returns err. 24 | func busyWork(n int, err error) taskgroup.Task { 25 | return func() error { time.Sleep(randms(n)); return err } 26 | } 27 | 28 | func TestBasic(t *testing.T) { 29 | synctest.Test(t, func(t *testing.T) { 30 | t.Logf("Group value is %d bytes", reflect.TypeOf((*taskgroup.Group)(nil)).Elem().Size()) 31 | 32 | // Verify that the group works at all. 33 | var g taskgroup.Group 34 | g.Go(busyWork(25, nil)) 35 | if err := g.Wait(); err != nil { 36 | t.Errorf("Unexpected task error: %v", err) 37 | } 38 | 39 | // Verify that the group can be reused. 
40 | g.Go(busyWork(50, nil)) 41 | g.Go(busyWork(75, nil)) 42 | if err := g.Wait(); err != nil { 43 | t.Errorf("Unexpected task error: %v", err) 44 | } 45 | }) 46 | 47 | t.Run("Zero", func(t *testing.T) { 48 | synctest.Test(t, func(t *testing.T) { 49 | g := taskgroup.New(nil) 50 | g.Go(busyWork(30, nil)) 51 | if err := g.Wait(); err != nil { 52 | t.Errorf("Unexpected task error: %v", err) 53 | } 54 | 55 | _, run := g.Limit(1) 56 | run(busyWork(60, nil)) 57 | if err := g.Wait(); err != nil { 58 | t.Errorf("Unexpected task error: %v", err) 59 | } 60 | }) 61 | }) 62 | } 63 | 64 | func TestErrorPropagation(t *testing.T) { 65 | synctest.Test(t, func(t *testing.T) { 66 | errBogus := errors.New("bogus") 67 | 68 | var g taskgroup.Group 69 | g.Go(func() error { return errBogus }) 70 | if err := g.Wait(); err != errBogus { 71 | t.Errorf("Wait: got error %v, wanted %v", err, errBogus) 72 | } 73 | 74 | g.OnError(func(error) error { return nil }) // discard 75 | g.Go(func() error { return errBogus }) 76 | if err := g.Wait(); err != nil { 77 | t.Errorf("Wait: got error %v, wanted nil", err) 78 | } 79 | }) 80 | } 81 | 82 | func TestCancellation(t *testing.T) { 83 | synctest.Test(t, func(t *testing.T) { 84 | 85 | var errs []error 86 | g := taskgroup.New(func(err error) { 87 | errs = append(errs, err) 88 | }) 89 | 90 | errOther := errors.New("something is wrong") 91 | ctx, cancel := context.WithCancel(context.Background()) 92 | var numOK int32 93 | for range numTasks { 94 | g.Go(func() error { 95 | select { 96 | case <-ctx.Done(): 97 | return ctx.Err() 98 | case <-time.After(randms(1)): 99 | return errOther 100 | case <-time.After(randms(1)): 101 | atomic.AddInt32(&numOK, 1) 102 | return nil 103 | } 104 | }) 105 | } 106 | cancel() 107 | g.Wait() 108 | var numCanceled, numOther int 109 | for _, err := range errs { 110 | switch err { 111 | case context.Canceled: 112 | numCanceled++ 113 | case errOther: 114 | numOther++ 115 | default: 116 | t.Errorf("Unexpected error: %v", err) 117 | } 
118 | } 119 | t.Logf("Got %d successful tasks, %d cancelled tasks, and %d other errors", 120 | numOK, numCanceled, numOther) 121 | if total := int(numOK) + numCanceled + numOther; total != numTasks { 122 | t.Errorf("Task count mismatch: got %d results, wanted %d", total, numTasks) 123 | } 124 | }) 125 | } 126 | 127 | func TestCapacity(t *testing.T) { 128 | synctest.Test(t, func(t *testing.T) { 129 | const maxCapacity = 25 130 | const numTasks = 1492 131 | 132 | // Verify that multiple groups sharing a throttle respect the combined 133 | // capacity limit. 134 | throttle := taskgroup.NewThrottle(maxCapacity) 135 | var g1, g2 taskgroup.Group 136 | start1 := throttle.Limit(&g1) 137 | start2 := throttle.Limit(&g2) 138 | 139 | var p peakValue 140 | var n int32 141 | for i := range numTasks { 142 | start := start1 143 | if i%2 == 1 { 144 | start = start2 145 | } 146 | start.Run(func() { 147 | p.inc() 148 | defer p.dec() 149 | time.Sleep(2 * time.Millisecond) 150 | atomic.AddInt32(&n, 1) 151 | }) 152 | } 153 | g1.Wait() 154 | g2.Wait() 155 | t.Logf("Total tasks completed: %d", n) 156 | if p.max > maxCapacity { 157 | t.Errorf("Exceeded maximum capacity: got %d, want %d", p.max, maxCapacity) 158 | } else { 159 | t.Logf("Maximum concurrent tasks: %d", p.max) 160 | } 161 | }) 162 | } 163 | 164 | func TestRegression(t *testing.T) { 165 | t.Run("WaitRace", func(t *testing.T) { 166 | synctest.Test(t, func(t *testing.T) { 167 | ready := make(chan struct{}) 168 | var g taskgroup.Group 169 | g.Go(func() error { 170 | <-ready 171 | return nil 172 | }) 173 | 174 | var wg sync.WaitGroup 175 | wg.Add(2) 176 | go func() { defer wg.Done(); g.Wait() }() 177 | go func() { defer wg.Done(); g.Wait() }() 178 | 179 | close(ready) 180 | wg.Wait() 181 | }) 182 | }) 183 | t.Run("WaitUnstarted", func(t *testing.T) { 184 | synctest.Test(t, func(t *testing.T) { 185 | defer func() { 186 | if x := recover(); x != nil { 187 | t.Errorf("Unexpected panic: %v", x) 188 | } 189 | }() 190 | var g 
taskgroup.Group 191 | g.Wait() 192 | }) 193 | }) 194 | } 195 | 196 | func TestSingleTask(t *testing.T) { 197 | sentinel := errors.New("expected value") 198 | 199 | t.Run("Early", func(t *testing.T) { 200 | synctest.Test(t, func(t *testing.T) { 201 | s := taskgroup.Go(func() error { 202 | return sentinel 203 | }) 204 | 205 | time.Sleep(time.Second) 206 | if err := s.Wait(); err != sentinel { 207 | t.Errorf("Wait: got %v, want %v", err, sentinel) 208 | } 209 | }) 210 | }) 211 | 212 | t.Run("Late", func(t *testing.T) { 213 | synctest.Test(t, func(t *testing.T) { 214 | s := taskgroup.Go(func() error { 215 | time.Sleep(5 * time.Second) 216 | return sentinel 217 | }) 218 | 219 | var g taskgroup.Group 220 | g.Run(func() { 221 | if err := s.Wait(); err != sentinel { 222 | t.Errorf("Background Wait: got %v, want %v", err, sentinel) 223 | } 224 | }) 225 | 226 | if err := s.Wait(); err != sentinel { 227 | t.Errorf("Foreground Wait: got %v, want %v", err, sentinel) 228 | } 229 | }) 230 | }) 231 | 232 | t.Run("MultipleWaiters", func(t *testing.T) { 233 | synctest.Test(t, func(t *testing.T) { 234 | // Here we want to verify that multiple concurrent waiters do not produce 235 | // a data race. 236 | s := taskgroup.Go(func() error { 237 | time.Sleep(5 * time.Second) 238 | return sentinel 239 | }) 240 | 241 | // Start multiple waiters, and wait for them to be running. 
242 | const numWaiters = 4 243 | for i := range numWaiters { 244 | go func() { 245 | if err := s.Wait(); err != sentinel { 246 | t.Errorf("Wait %d: got %v, want %v", i+1, err, sentinel) 247 | } 248 | }() 249 | } 250 | synctest.Wait() // waiters are running 251 | 252 | if err := s.Wait(); err != sentinel { 253 | t.Errorf("Wait: got %v, want %v", err, sentinel) 254 | } 255 | }) 256 | }) 257 | } 258 | 259 | func TestWaitMoreTasks(t *testing.T) { 260 | synctest.Test(t, func(t *testing.T) { 261 | var g taskgroup.Group 262 | var results int 263 | coll := taskgroup.Gather(g.Go, func(int) { 264 | results++ 265 | }) 266 | 267 | // Test that if a task spawns more tasks on its own recognizance, waiting 268 | // correctly waits for all of them provided we do not let the group go empty 269 | // before all the tasks are spawned. 270 | var countdown func(int) int 271 | countdown = func(n int) int { 272 | if n > 1 { 273 | // The subordinate task, if there is one, is started before this one 274 | // exits, ensuring the group is kept "afloat". 
275 | coll.Run(func() int { 276 | return countdown(n - 1) 277 | }) 278 | } 279 | return n 280 | } 281 | 282 | coll.Run(func() int { return countdown(15) }) 283 | g.Wait() 284 | 285 | if results != 15 { 286 | t.Errorf("Got %d results, want 15", results) 287 | } 288 | }) 289 | } 290 | 291 | func TestSingleResult(t *testing.T) { 292 | synctest.Test(t, func(t *testing.T) { 293 | s := taskgroup.Call(func() (int, error) { 294 | time.Sleep(5 * time.Second) 295 | return 25, nil 296 | }) 297 | 298 | res, err := s.Wait().Get() 299 | if err != nil { 300 | t.Errorf("Unexpected error: %v", err) 301 | } 302 | if res != 25 { 303 | t.Errorf("Result: got %v, want 25", res) 304 | } 305 | }) 306 | } 307 | 308 | func TestGatherer(t *testing.T) { 309 | g, run := taskgroup.New(nil).Limit(4) 310 | checkWait := func(t *testing.T) { 311 | t.Helper() 312 | if err := g.Wait(); err != nil { 313 | t.Errorf("Unexpected error from Wait: %v", err) 314 | } 315 | } 316 | 317 | t.Run("Call", func(t *testing.T) { 318 | synctest.Test(t, func(t *testing.T) { 319 | var sum int 320 | r := taskgroup.Gather(run, func(v int) { 321 | sum += v 322 | }) 323 | 324 | for _, v := range rand.Perm(15) { 325 | r.Call(func() (int, error) { 326 | if v > 10 { 327 | return -100, errors.New("don't add this") 328 | } 329 | return v, nil 330 | }) 331 | } 332 | 333 | g.Wait() 334 | if want := (10 * 11) / 2; sum != want { 335 | t.Errorf("Final result: got %d, want %d", sum, want) 336 | } 337 | }) 338 | }) 339 | 340 | t.Run("Run", func(t *testing.T) { 341 | synctest.Test(t, func(t *testing.T) { 342 | var sum int 343 | r := taskgroup.Gather(run, func(v int) { 344 | sum += v 345 | }) 346 | for _, v := range rand.Perm(15) { 347 | r.Run(func() int { return v + 1 }) 348 | } 349 | 350 | checkWait(t) 351 | if want := (15 * 16) / 2; sum != want { 352 | t.Errorf("Final result: got %d, want %d", sum, want) 353 | } 354 | }) 355 | }) 356 | 357 | t.Run("Report", func(t *testing.T) { 358 | synctest.Test(t, func(t *testing.T) { 359 | var 
sum uint32 360 | r := taskgroup.Gather(g.Go, func(v uint32) { 361 | sum |= v 362 | }) 363 | 364 | for _, i := range rand.Perm(32) { 365 | r.Report(func(report func(v uint32)) error { 366 | for _, v := range rand.Perm(i + 1) { 367 | report(uint32(1 << v)) 368 | } 369 | return nil 370 | }) 371 | } 372 | 373 | checkWait(t) 374 | if sum != math.MaxUint32 { 375 | t.Errorf("Final result: got %d, want %d", sum, math.MaxUint32) 376 | } 377 | }) 378 | }) 379 | } 380 | 381 | type peakValue struct { 382 | μ sync.Mutex 383 | cur, max int 384 | } 385 | 386 | func (p *peakValue) inc() { 387 | p.μ.Lock() 388 | p.cur++ 389 | if p.cur > p.max { 390 | p.max = p.cur 391 | } 392 | p.μ.Unlock() 393 | } 394 | 395 | func (p *peakValue) dec() { 396 | p.μ.Lock() 397 | p.cur-- 398 | p.μ.Unlock() 399 | } 400 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # taskgroup 2 | 3 | [![GoDoc](https://img.shields.io/static/v1?label=godoc&message=reference&color=khaki)](https://pkg.go.dev/github.com/creachadair/taskgroup) 4 | [![CI](https://github.com/creachadair/taskgroup/actions/workflows/go-presubmit.yml/badge.svg?event=push&branch=main)](https://github.com/creachadair/taskgroup/actions/workflows/go-presubmit.yml) 5 | 6 | A `*taskgroup.Group` represents a group of goroutines working on related tasks. 7 | New tasks can be added to the group at will, and the caller can wait until all 8 | tasks are complete. Errors are automatically collected and delivered 9 | synchronously to a user-provided callback. This does not replace the full 10 | generality of Go's built-in features, but it simplifies some of the plumbing 11 | for common concurrent tasks. 12 | 13 | Here is a [working example in the Go Playground](https://go.dev/play/p/miyrtp4PyOc). 
14 | 15 | ## Contents 16 | 17 | - [Rationale](#rationale) 18 | - [Overview](#overview) 19 | - [Filtering Errors](#filtering-errors) 20 | - [Controlling Concurrency](#controlling-concurrency) 21 | - [Solo Tasks](#solo-tasks) 22 | - [Gathering Results](#gathering-results) 23 | 24 | ## Rationale 25 | 26 | Go provides powerful concurrency primitives, including 27 | [goroutines](http://golang.org/ref/spec#Go_statements), 28 | [channels](http://golang.org/ref/spec#Channel_types), 29 | [select](http://golang.org/ref/spec#Select_statements), and the standard 30 | library's [sync](http://godoc.org/sync) package. In some common situations, 31 | however, managing goroutine lifetimes can become unwieldy using only what is 32 | built in. 33 | 34 | For example, consider the case of copying a large directory tree: Walk through 35 | a source directory recursively, creating a parallel target directory structure 36 | and starting a goroutine to copy each of the files concurrently. In outline: 37 | 38 | ```go 39 | func copyTree(source, target string) error { 40 | err := filepath.Walk(source, func(path string, fi os.FileInfo, err error) error { 41 | adjusted := adjustPath(path) 42 | if fi.IsDir() { 43 | return os.MkdirAll(adjusted, 0755) 44 | } 45 | go copyFile(adjusted, target) 46 | return nil 47 | }) 48 | if err != nil { 49 | // ... clean up the output directory ... 50 | } 51 | return err 52 | } 53 | ``` 54 | 55 | This solution is deficient, however, as it does not provide any way to detect 56 | when all the file copies are finished. To do that we will typically use a 57 | `sync.WaitGroup`: 58 | 59 | ```go 60 | var wg sync.WaitGroup 61 | ... 62 | wg.Add(1) 63 | go func() { 64 | defer wg.Done() 65 | copyFile(adjusted, target) 66 | }() 67 | ... 68 | wg.Wait() // block until all the tasks signal done 69 | ``` 70 | 71 | In addition, we need to handle errors. Copies might fail (the disk may fill, or 72 | there might be a permissions error). 
For some applications it might suffice to 73 | log the error and continue, but usually in case of error we should back out and 74 | clean up the partial state. 75 | 76 | To do that, we need to capture the return value from the function inside the 77 | goroutine―and that will require us either to add a lock or plumb in another 78 | channel: 79 | 80 | ```go 81 | errs := make(chan error) 82 | ... 83 | go copyFile(adjusted, target, errs) 84 | ``` 85 | 86 | Since multiple operations can be running in parallel, we will also need another 87 | goroutine to drain the errors channel and accumulate the results somewhere: 88 | 89 | ```go 90 | var failures []error 91 | go func() { 92 | for e := range errs { 93 | failures = append(failures, e) 94 | } 95 | }() 96 | ... 97 | wg.Wait() 98 | close(errs) 99 | ``` 100 | 101 | Once the work is finished, we must also detect when the error collector is 102 | done, so we can examine the `failures` without a data race. We'll need another 103 | channel or wait group to signal for this: 104 | 105 | ```go 106 | var failures []error 107 | edone := make(chan struct{}) 108 | go func() { 109 | defer close(edone) 110 | for e := range errs { 111 | failures = append(failures, e) 112 | } 113 | }() 114 | ... 115 | wg.Wait() // all the workers are done 116 | close(errs) // signal the error collector to stop 117 | <-edone // wait for the error collector to be done 118 | ``` 119 | 120 | Another issue is, if one of the file copies fails, we don't necessarily want to 121 | wait around for all the copies to finish before reporting the error―we want to 122 | stop everything and clean up the whole operation. Typically we would do this 123 | using a `context.Context`: 124 | 125 | ```go 126 | ctx, cancel := context.WithCancel(context.Background()) 127 | defer cancel() 128 | ... 
129 | copyFile(ctx, adjusted, target, errs) 130 | ``` 131 | 132 | Now `copyFile` will have to check for `ctx` to be finished: 133 | 134 | ```go 135 | func copyFile(ctx context.Context, source, target string, errs chan<- error) { 136 | if ctx.Err() != nil { 137 | return 138 | } 139 | // ... do the copy as normal, or propagate an error 140 | } 141 | ``` 142 | 143 | Finally, we want the ability to limit the number of concurrent copies. Even 144 | if the host has plenty of memory and CPU, unbounded concurrency is likely to 145 | run us out of file descriptors. To handle this we might use a 146 | [semaphore](https://godoc.org/golang.org/x/sync/semaphore) or a throttling 147 | channel: 148 | 149 | ```go 150 | throttle := make(chan struct{}, 64) // allow up to 64 concurrent copies 151 | go func() { 152 | throttle <- struct{}{} // block until the throttle has a free slot 153 | defer func() { wg.Done(); <-throttle }() 154 | copyFile(ctx, adjusted, target, errs) 155 | }() 156 | ``` 157 | 158 | So far, we're up to four channels (errs, edone, context, and throttle) plus a 159 | wait group. The point to note is that while these tools are quite able to 160 | express what we want, it can be tedious to wire them all together and keep 161 | track of the current state of the system. 162 | 163 | The `taskgroup` package exists to handle the plumbing for the common case of a 164 | group of tasks that are all working on a related outcome (_e.g.,_ copying a 165 | directory structure), and where an error on the part of any _single_ task may 166 | be grounds for terminating the work as a whole. 167 | 168 | The package provides a `taskgroup.Group` type that has built-in support for 169 | some of these concerns: 170 | 171 | - Limiting the number of active goroutines. 172 | - Collecting and filtering errors. 173 | - Waiting for completion and delivering status. 174 | 175 | A `taskgroup.Group` collects error values from each task and can deliver them 176 | to a user-provided callback. 
The callback can filter them or take other actions 177 | (such as cancellation). Invocations of the callback are all done from a single 178 | goroutine so it is safe to manipulate local resources without a lock. 179 | 180 | A group does not directly support cancellation, but integrates cleanly with the 181 | standard [context](https://godoc.org/context) package. A `context.CancelFunc` 182 | can be used as a trigger to signal the whole group when an error occurs. 183 | 184 | ## Overview 185 | 186 | A task is expressed as a `func() error`, and is added to a group using the `Go` 187 | method: 188 | 189 | ```go 190 | var g taskgroup.Group 191 | g.Go(myTask) 192 | ``` 193 | 194 | Any number of tasks may be added, and it is safe to do so from multiple 195 | goroutines concurrently. To wait for the tasks to finish, use: 196 | 197 | ```go 198 | err := g.Wait() 199 | ``` 200 | 201 | `Wait` blocks until all the tasks in the group have returned, and then reports 202 | the first non-nil error returned by any of the worker tasks. 203 | 204 | An implementation of this example can be found in `examples/copytree/copytree.go`. 205 | 206 | ## Filtering Errors 207 | 208 | The `taskgroup.New` function takes an optional callback to be invoked for each 209 | non-nil error reported by a task in the group. The callback may choose to 210 | propagate, replace, or discard the error. 
For some applications it might suffice to 210 | ignore "file not found" errors from a copy operation: 211 | 212 | ```go 213 | g := taskgroup.New(func(err error) error { 214 | if os.IsNotExist(err) { 215 | return nil // ignore files that do not exist 216 | } 217 | return err 218 | }) 219 | ``` 220 | 221 | This mechanism can also be used to trigger a context cancellation if a task 222 | fails, for example: 223 | 224 | ```go 225 | ctx, cancel := context.WithCancel(context.Background()) 226 | defer cancel() 227 | 228 | g := taskgroup.New(cancel) 229 | ``` 230 | 231 | Now, if a task in `g` reports an error, it will cancel the context, allowing 232 | any other running tasks to observe a context cancellation and bail out. 233 | 234 | ## Controlling Concurrency 235 | 236 | The `Limit` method supports limiting the number of concurrently _active_ 237 | goroutines in the group. It returns a `StartFunc` that adds goroutines to the 238 | group, but will block when the limit of goroutines is reached until some 239 | of the goroutines already running have finished. 240 | 241 | For example: 242 | 243 | ```go 244 | // Allow at most 3 concurrently-active goroutines in the group. 245 | g, start := taskgroup.New(nil).Limit(3) 246 | 247 | // Start tasks by calling the function returned by taskgroup.Limit: 248 | start(task1) 249 | start(task2) 250 | start(task3) 251 | start(task4) // blocks until one of the previous tasks is finished 252 | // ... 253 | ``` 254 | 255 | ## Solo Tasks 256 | 257 | In some cases it is useful to start a single background task to handle an 258 | isolated concern (elsewhere sometimes described as a "promise" or a "future"). 259 | 260 | For example, suppose we want to run some expensive background cleanup task 261 | while we take care of other work. 
Rather than create a whole group for a single 263 | goroutine we can create a solo task using the `Go` or `Run` functions: 264 | 265 | ```go 266 | s := taskgroup.Go(func() error { 267 | for _, v := range itemsToClean { 268 | if err := cleanup(v); err != nil { 269 | return err 270 | } 271 | } 272 | return nil 273 | }) 274 | ``` 275 | 276 | Once we're ready, we can `Wait` for this task to collect its result: 277 | 278 | ```go 279 | if err := s.Wait(); err != nil { 280 | log.Printf("WARNING: Cleanup failed: %v", err) 281 | } 282 | ``` 283 | 284 | Solo tasks are also helpful for functions that return a value. For example, 285 | suppose we want to read a file while we handle other matters. The `Call` 286 | function creates a solo task from such a function: 287 | 288 | ```go 289 | s := taskgroup.Call(func() ([]byte, error) { 290 | return os.ReadFile(filePath) 291 | }) 292 | ``` 293 | 294 | As before, we can `Wait` for the result when we're ready: 295 | 296 | ```go 297 | // N.B.: Wait returns a taskgroup.Result, whose Get method unpacks 298 | // it into a value and an error like a normal function call. 299 | data, err := s.Wait().Get() 300 | if err != nil { 301 | log.Fatalf("Read configuration: %v", err) 302 | } 303 | doThingsWith(data) 304 | ``` 305 | 306 | ## Gathering Results 307 | 308 | One common use for a background task is accumulating the results from a batch 309 | of concurrent workers. This could be handled by a solo task, as described 310 | above, but it is a common enough pattern that the library provides a `Gatherer` 311 | type to handle it specifically. 
312 | 313 | To use it, pass a function to `Gather` to receive the values: 314 | 315 | ```go 316 | var g taskgroup.Group 317 | 318 | var sum int 319 | c := taskgroup.Gather(g.Go, func(v int) { sum += v }) 320 | ``` 321 | 322 | The `Call`, `Run`, and `Report` methods of `c` can now be used to start tasks 323 | in `g` that yield values, and deliver those values to the accumulator: 324 | 325 | - `c.Call` takes a `func() (T, error)`, returning a value and an error. 326 | If the task reports an error, that error is returned as usual. Otherwise, 327 | its non-error value is gathered by the callback. 328 | 329 | - `c.Run` takes a `func() T`, returning only a value, which is gathered by the 330 | callback. 331 | 332 | - `c.Report` takes a `func(func(T)) error`, which allows a task to report 333 | _multiple_ values to the gatherer via a "report" callback. The task itself 334 | returns only an `error`, but it may call its argument any number of times to 335 | gather values. 336 | 337 | Calls to the callback are serialized so that it is safe to access state without 338 | additional locking: 339 | 340 | ```go 341 | // Report an error, no value is gathered. 342 | c.Call(func() (int, error) { 343 | return -1, errors.New("bad") 344 | }) 345 | 346 | // No error, so gather the value 25. 347 | c.Call(func() (int, error) { 348 | return 25, nil 349 | }) 350 | 351 | // Gather a random integer. 352 | c.Run(func() int { return rand.Intn(1000) }) 353 | 354 | // Gather the values 10, 20, and 30. 355 | // 356 | // Note that even if the function reports an error, any values it sent 357 | // before returning are still gathered. 
358 | c.Report(func(report func(int)) error { 359 | report(10) 360 | report(20) 361 | report(30) 362 | return nil 363 | }) 364 | ``` 365 | 366 | Once all the tasks passed to the gatherer are complete, it is safe to access 367 | the values accumulated by the callback: 368 | 369 | ```go 370 | g.Wait() // wait for tasks to finish 371 | 372 | // Now you can access the values accumulated by c. 373 | fmt.Println(sum) 374 | ``` 375 | --------------------------------------------------------------------------------