├── go.mod ├── go.sum ├── .github └── workflows │ └── go.yml ├── LICENSE ├── example_test.go ├── README.md ├── semgroup.go └── semgroup_test.go /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/fatih/semgroup 2 | 3 | go 1.17 4 | 5 | require golang.org/x/sync v0.8.0 6 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ= 2 | golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= 3 | -------------------------------------------------------------------------------- /.github/workflows/go.yml: -------------------------------------------------------------------------------- 1 | name: CI 2 | 3 | on: [push, pull_request] 4 | 5 | jobs: 6 | ci: 7 | name: "Go build" 8 | runs-on: ubuntu-latest 9 | steps: 10 | 11 | - name: Set up Go 1.17 12 | uses: actions/setup-go@v2 13 | with: 14 | go-version: 1.17 15 | id: go 16 | 17 | - name: Check out code into the Go module directory 18 | uses: actions/checkout@v2 19 | 20 | - name: Test 21 | run: | 22 | go mod tidy -v 23 | go test -race ./... 24 | 25 | - run: "go vet ./..." 26 | 27 | - name: Staticcheck 28 | uses: dominikh/staticcheck-action@v1.1.0 29 | with: 30 | version: "2021.1.1" 31 | install-go: false 32 | 33 | - name: Build 34 | run: go build ./... 35 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2022, Fatih Arslan 2 | All rights reserved. 
3 | 
4 | Redistribution and use in source and binary forms, with or without
5 | modification, are permitted provided that the following conditions are met:
6 | 
7 | * Redistributions of source code must retain the above copyright notice, this
8 | list of conditions and the following disclaimer.
9 | 
10 | * Redistributions in binary form must reproduce the above copyright notice,
11 | this list of conditions and the following disclaimer in the documentation
12 | and/or other materials provided with the distribution.
13 | 
14 | * Neither the name of semgroup nor the names of its
15 | contributors may be used to endorse or promote products derived from
16 | this software without specific prior written permission.
17 | 
18 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
19 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
20 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
21 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
22 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
23 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
24 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
25 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
26 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
27 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
28 | 
--------------------------------------------------------------------------------
/example_test.go:
--------------------------------------------------------------------------------
1 | package semgroup_test
2 | 
3 | import (
4 | 	"context"
5 | 	"errors"
6 | 	"fmt"
7 | 	"sync"
8 | 
9 | 	"github.com/fatih/semgroup"
10 | )
11 | 
12 | // This example increases a counter for each visit concurrently, using a
13 | // semgroup.Group to block until all the visitors have finished. It only runs
14 | // 2 tasks at any time.
15 | func ExampleGroup_parallel() {
16 | 	const maxWorkers = 2
17 | 	s := semgroup.NewGroup(context.Background(), maxWorkers)
18 | 
19 | 	var (
20 | 		counter int
21 | 		mu      sync.Mutex // protects counter
22 | 	)
23 | 
24 | 	visitors := []int{5, 2, 10, 8, 9, 3, 1}
25 | 
26 | 	for _, v := range visitors {
27 | 		v := v
28 | 
29 | 		s.Go(func() error {
30 | 			mu.Lock()
31 | 			counter += v
32 | 			mu.Unlock()
33 | 			return nil
34 | 		})
35 | 	}
36 | 
37 | 	// Wait for all visits to complete. Any errors are accumulated.
38 | 	if err := s.Wait(); err != nil {
39 | 		fmt.Println(err)
40 | 	}
41 | 
42 | 	fmt.Printf("Counter: %d", counter)
43 | 
44 | 	// Output:
45 | 	// Counter: 38
46 | }
47 | 
48 | // This example shows that Wait accumulates the errors returned by all tasks
49 | // and reports them together as a single error.
50 | func ExampleGroup_withErrors() {
51 | 	const maxWorkers = 2
52 | 	s := semgroup.NewGroup(context.Background(), maxWorkers)
53 | 
54 | 	visitors := []int{1, 1, 1, 1, 2, 2, 1, 1, 2}
55 | 
56 | 	for _, v := range visitors {
57 | 		v := v
58 | 
59 | 		s.Go(func() error {
60 | 			if v != 1 {
61 | 				return errors.New("only one visitor is allowed")
62 | 			}
63 | 			return nil
64 | 		})
65 | 	}
66 | 
67 | 	// Wait for all visits to complete. Any errors are accumulated.
68 | 	if err := s.Wait(); err != nil {
69 | 		fmt.Println(err)
70 | 	}
71 | 
72 | 	// Output:
73 | 	// 3 error(s) occurred:
74 | 	// * only one visitor is allowed
75 | 	// * only one visitor is allowed
76 | 	// * only one visitor is allowed
77 | }
78 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # semgroup [![CI](https://github.com/fatih/semgroup/actions/workflows/go.yml/badge.svg)](https://github.com/fatih/semgroup/actions/workflows/go.yml) [![PkgGoDev](https://pkg.go.dev/badge/github.com/fatih/semgroup)](https://pkg.go.dev/github.com/fatih/semgroup)
2 | 
3 | semgroup provides synchronization and error propagation for groups of goroutines working on subtasks of a common task.
It uses a weighted semaphore to make sure that at most a configured maximum number of tasks run at any time. 4 | 5 | Unlike [golang.org/x/sync/errgroup](https://pkg.go.dev/golang.org/x/sync/errgroup), it doesn't return only the first non-nil error; instead, it accumulates all errors and returns them together, so a failing task does not prevent the others from running. 6 | 7 | 8 | # Install 9 | 10 | ```bash 11 | go get github.com/fatih/semgroup 12 | ``` 13 | 14 | # Example 15 | 16 | With no errors: 17 | 18 | ```go 19 | package main 20 | 21 | import ( 22 | "context" 23 | "fmt" 24 | 25 | "github.com/fatih/semgroup" 26 | ) 27 | 28 | func main() { 29 | const maxWorkers = 2 30 | s := semgroup.NewGroup(context.Background(), maxWorkers) 31 | 32 | visitors := []int{5, 2, 10, 8, 9, 3, 1} 33 | 34 | for _, v := range visitors { 35 | v := v 36 | 37 | s.Go(func() error { 38 | fmt.Println("Visits: ", v) 39 | return nil 40 | }) 41 | } 42 | 43 | // Wait for all visits to complete. Any errors are accumulated. 44 | if err := s.Wait(); err != nil { 45 | fmt.Println(err) 46 | } 47 | 48 | // Output (order may vary): 49 | // Visits: 2 50 | // Visits: 10 51 | // Visits: 8 52 | // Visits: 9 53 | // Visits: 3 54 | // Visits: 1 55 | // Visits: 5 56 | } 57 | ``` 58 | 59 | With errors: 60 | 61 | 62 | ```go 63 | package main 64 | 65 | import ( 66 | "context" 67 | "errors" 68 | "fmt" 69 | 70 | "github.com/fatih/semgroup" 71 | ) 72 | 73 | func main() { 74 | const maxWorkers = 2 75 | s := semgroup.NewGroup(context.Background(), maxWorkers) 76 | 77 | visitors := []int{1, 1, 1, 1, 2, 2, 1, 1, 2} 78 | 79 | for _, v := range visitors { 80 | v := v 81 | 82 | s.Go(func() error { 83 | if v != 1 { 84 | return errors.New("only one visitor is allowed") 85 | } 86 | return nil 87 | }) 88 | } 89 | 90 | // Wait for all visits to complete. Any errors are accumulated.
91 | if err := s.Wait(); err != nil {
92 | fmt.Println(err)
93 | }
94 | 
95 | // Output:
96 | // 3 error(s) occurred:
97 | // * only one visitor is allowed
98 | // * only one visitor is allowed
99 | // * only one visitor is allowed
100 | }
101 | ```
102 | 
103 | 
--------------------------------------------------------------------------------
/semgroup.go:
--------------------------------------------------------------------------------
1 | // Package semgroup provides synchronization and error propagation for groups
2 | // of goroutines working on subtasks of a common task. It uses a weighted
3 | // semaphore to make sure that at most a configured maximum number of tasks
4 | // run at any time.
5 | //
6 | // Unlike golang.org/x/sync/errgroup, it doesn't return only the first non-nil
7 | // error; instead, it accumulates all errors and returns them together, so a
8 | // failing task does not prevent the others from running.
9 | package semgroup
10 | 
11 | import (
12 | 	"context"
13 | 	"errors"
14 | 	"fmt"
15 | 	"strings"
16 | 	"sync"
17 | 
18 | 	"golang.org/x/sync/semaphore"
19 | )
20 | 
21 | // A Group is a collection of goroutines working on subtasks that are part of
22 | // the same overall task. Create one with NewGroup: the zero value has no
23 | // semaphore and cannot be used.
24 | type Group struct {
25 | 	sem *semaphore.Weighted
26 | 	wg  sync.WaitGroup
27 | 	ctx context.Context
28 | 
29 | 	errs MultiError
30 | 	mu   sync.Mutex // protects errs
31 | }
32 | 
33 | // NewGroup returns a new Group that runs at most maxWorkers functions
34 | // concurrently.
35 | //
36 | // ctx bounds how long Go may block waiting for a free slot: if a slot cannot
37 | // be acquired because ctx is done, the function is not started and the error
38 | // is recorded instead. ctx is not passed to the functions themselves.
39 | func NewGroup(ctx context.Context, maxWorkers int64) *Group {
40 | 	return &Group{
41 | 		ctx: ctx,
42 | 		sem: semaphore.NewWeighted(maxWorkers),
43 | 	}
44 | }
45 | 
46 | // Go calls the given function in a new goroutine. It first acquires the
47 | // semaphore with a weight of 1, blocking the caller until a slot is available
48 | // or the Group's context is done.
49 | //
50 | // If the semaphore cannot be acquired, f is not called; instead an error
51 | // wrapping the acquisition error (the context's error) is recorded, so that
52 | // errors.Is(err, context.Canceled) works on the result of Wait. Any non-nil
53 | // error returned by f is recorded as well. Wait returns all recorded errors.
54 | func (g *Group) Go(f func() error) {
55 | 	g.wg.Add(1)
56 | 
57 | 	if err := g.sem.Acquire(g.ctx, 1); err != nil {
58 | 		// Record the error before calling Done so that it is guaranteed to be
59 | 		// visible once Wait returns, even if Wait runs on another goroutine.
60 | 		g.addErr(fmt.Errorf("couldn't acquire semaphore: %w", err))
61 | 		g.wg.Done()
62 | 		return
63 | 	}
64 | 
65 | 	go func() {
66 | 		defer g.wg.Done()
67 | 		// Deferred calls run in reverse order, so the slot is released before
68 | 		// Done: no slot is still held once Wait returns.
69 | 		defer g.sem.Release(1)
70 | 
71 | 		if err := f(); err != nil {
72 | 			g.addErr(err)
73 | 		}
74 | 	}()
75 | }
76 | 
77 | // addErr appends err to the accumulated errors while holding g.mu.
78 | func (g *Group) addErr(err error) {
79 | 	g.mu.Lock()
80 | 	g.errs = append(g.errs, err)
81 | 	g.mu.Unlock()
82 | }
83 | 
84 | // Wait blocks until all function calls from the Go method have returned, then
85 | // returns all accumulated errors (if any): the non-nil errors returned by the
86 | // functions and the errors recorded when Go could not acquire the semaphore.
87 | //
88 | // If a non-nil error is returned, it will be of type [MultiError].
89 | func (g *Group) Wait() error {
90 | 	g.wg.Wait()
91 | 
92 | 	// Read errs under mu, the lock that guards it.
93 | 	g.mu.Lock()
94 | 	defer g.mu.Unlock()
95 | 	return g.errs.ErrorOrNil()
96 | }
97 | 
98 | // MultiError is the error type returned by Wait when at least one error was
99 | // recorded. It holds the errors in the order they were recorded.
100 | type MultiError []error
101 | 
102 | // Error formats e as an "N error(s) occurred:" header followed by one
103 | // "* <message>" line per error.
104 | func (e MultiError) Error() string {
105 | 	var b strings.Builder
106 | 	fmt.Fprintf(&b, "%d error(s) occurred:\n", len(e))
107 | 
108 | 	for i, err := range e {
109 | 		if i > 0 {
110 | 			b.WriteByte('\n')
111 | 		}
112 | 		b.WriteString("* ")
113 | 		b.WriteString(err.Error())
114 | 	}
115 | 
116 | 	return b.String()
117 | }
118 | 
119 | // ErrorOrNil returns nil if there are no errors, otherwise returns itself.
120 | // Returning a nil error, rather than an empty MultiError, keeps callers'
121 | // "err != nil" checks correct.
122 | func (e MultiError) ErrorOrNil() error {
123 | 	if len(e) == 0 {
124 | 		return nil
125 | 	}
126 | 
127 | 	return e
128 | }
129 | 
130 | // Is reports whether any error in e matches target, as defined by errors.Is,
131 | // which lets errors.Is look inside a MultiError.
132 | func (e MultiError) Is(target error) bool {
133 | 	for _, err := range e {
134 | 		if errors.Is(err, target) {
135 | 			return true
136 | 		}
137 | 	}
138 | 	return false
139 | }
140 | 
141 | // As finds the first error in e that matches target, as defined by errors.As.
142 | // If one is found, target is set to that error value and As returns true.
143 | func (e MultiError) As(target interface{}) bool {
144 | 	for _, err := range e {
145 | 		if errors.As(err, target) {
146 | 			return true
147 | 		}
148 | 	}
149 | 	return false
150 | }
151 | 
--------------------------------------------------------------------------------
/semgroup_test.go:
--------------------------------------------------------------------------------
1 | package semgroup
2 | 
3 | import (
4 | 	"context"
5 | 	"errors"
6 | 	"os"
7 | 	"strings"
8 | 	"sync"
9 | 	"testing"
10 | )
11 | 
12 | func TestGroup_single_task(t *testing.T) {
13 | 	ctx := context.Background()
14 | 	g := NewGroup(ctx, 1)
15 | 
16 | 	g.Go(func() error { return nil })
17 | 
18 | 	err := g.Wait()
19 | 	if err != nil {
20 | 		t.Errorf("g.Wait() should not return an error")
21 | 	}
22 | }
23 | 
24 | func TestGroup_multiple_tasks(t *testing.T) {
25 | 	ctx := context.Background()
26 | 	g := NewGroup(ctx, 1)
27 | 
28 | 	count := 0
29 | 	var mu sync.Mutex
30 | 
31 | 	inc := func() error {
32 | 		mu.Lock()
33 | 		count++
34 | 		mu.Unlock()
35 | 		return nil
36 | 	}
37 | 
38 | 	g.Go(func() error { return inc() })
39 | 	g.Go(func() error { return inc() })
40 | 	g.Go(func() error { return inc() })
41 | 	g.Go(func() error { return inc() })
42 | 
43 | 	err := g.Wait()
44 | 	if err != nil {
45 | 		t.Errorf("g.Wait() should not return an error")
46 | 	}
47 | 
48 | 	if count != 4 {
49 | 		t.Errorf("count should be %d, got: %d", 4, count)
50 | 	}
51 | }
52 | 
53 | func TestGroup_multiple_tasks_errors(t *testing.T) {
54 | 	ctx := context.Background()
55 | 	g := NewGroup(ctx, 1)
56 | 
57 | 	g.Go(func() error { return errors.New("foo") })
58 | 	g.Go(func() error { return nil })
59 | 	g.Go(func() error { return errors.New("bar") })
60 | 	g.Go(func() error { return nil })
61 | 
62 | 	err := g.Wait()
63 | 	if err == nil {
64 | 		t.Fatalf("g.Wait() should return an error")
65 | 	}
66 | 	if !errors.As(err, &MultiError{}) {
67 | 		t.Fatalf("the error should be of type MultiError")
68 | 	}
69 | 
70 | 	wantErr := `2 error(s) occurred:
71 | * foo
72 | * bar`
73 | 
74 | 	if wantErr != err.Error() {
75 | 		t.Errorf("error should be:\n%s\ngot:\n%s\n", wantErr, err.Error())
76 | 	}
77 | }
78 | 
79 | func TestGroup_deadlock(t *testing.T) {
80 | 	canceledCtx, cancel := context.WithCancel(context.Background())
81 | 	cancel()
82 | 	g := NewGroup(canceledCtx, 1)
83 | 
84 | 	g.Go(func() error { return nil })
85 | 	g.Go(func() error { return nil })
86 | 
87 | 	err := g.Wait()
88 | 	if err == nil {
89 | 		t.Fatalf("g.Wait() should return an error")
90 | 	}
91 | 
92 | 	wantErr := `couldn't acquire semaphore: context canceled`
93 | 
94 | 	if !strings.Contains(err.Error(), wantErr) {
95 | 		t.Errorf("error should contain:\n%s\ngot:\n%s\n", wantErr, err.Error())
96 | 	}
97 | 
98 | 	// The acquisition error wraps ctx.Err(), so it can be matched with errors.Is.
99 | 	if !errors.Is(err, context.Canceled) {
100 | 		t.Errorf("error should match context.Canceled, got: %v", err)
101 | 	}
102 | }
103 | 
104 | func TestGroup_multiple_tasks_errors_Is(t *testing.T) {
105 | 	ctx := context.Background()
106 | 	g := NewGroup(ctx, 1)
107 | 
108 | 	var (
109 | 		fooErr = errors.New("foo")
110 | 		barErr = errors.New("bar")
111 | 		bazErr = errors.New("baz")
112 | 	)
113 | 
114 | 	g.Go(func() error { return fooErr })
115 | 	g.Go(func() error { return nil })
116 | 	g.Go(func() error { return barErr })
117 | 	g.Go(func() error { return nil })
118 | 
119 | 	err := g.Wait()
120 | 	if err == nil {
121 | 		t.Fatalf("g.Wait() should return an error")
122 | 	}
123 | 
124 | 	if !errors.Is(err, fooErr) {
125 | 		t.Errorf("error should be contained %v\n", fooErr)
126 | 	}
127 | 
128 | 	if !errors.Is(err, barErr) {
129 | 		t.Errorf("error should be contained %v\n", barErr)
130 | 	}
131 | 
132 | 	if errors.Is(err, bazErr) {
133 | 		t.Errorf("error should not be contained %v\n", bazErr)
134 | 	}
135 | 
136 | 	var gotMultiErr MultiError
137 | 	if !errors.As(err, &gotMultiErr) {
138 | 		t.Fatalf("error should be matched MultiError")
139 | 	}
140 | 	expectedErr := (MultiError{fooErr, barErr}).Error()
141 | 	if gotMultiErr.Error() != expectedErr {
142 | 		t.Errorf("error should be %q, got %q", expectedErr, gotMultiErr.Error())
143 | 	}
144 | }
145 | 
146 | type foobarErr struct{ str string }
147 | 
148 | func (e foobarErr) Error() string {
149 | 	return "foobar"
150 | }
151 | 
152 | func TestGroup_multiple_tasks_errors_As(t *testing.T) {
153 | 	ctx := context.Background()
154 | 	g := NewGroup(ctx, 1)
155 | 
156 | 	g.Go(func() error { return foobarErr{"baz"} })
157 | 	g.Go(func() error { return nil })
158 | 
159 | 	err := g.Wait()
160 | 	if err == nil {
161 | 		t.Fatalf("g.Wait() should return an error")
162 | 	}
163 | 
164 | 	var (
165 | 		fbe foobarErr
166 | 		pe  *os.PathError
167 | 	)
168 | 
169 | 	if !errors.As(err, &fbe) {
170 | 		t.Error("error should be matched foobarErr")
171 | 	}
172 | 
173 | 	if errors.As(err, &pe) {
174 | 		t.Error("error should not be matched os.PathError")
175 | 	}
176 | }
177 | 
--------------------------------------------------------------------------------