├── .github └── workflows │ └── go.yml ├── .gitignore ├── LICENSE ├── Makefile ├── README.md ├── benchmark ├── benchmark.go ├── benchmark_test.go ├── go.mod └── go.sum ├── go.mod ├── go.sum ├── go.work ├── groups ├── group.go ├── group_test.go └── options.go ├── internal ├── utils.go └── utils_test.go ├── logs └── log.go └── queues ├── multiple_queue.go ├── options.go ├── queue.go ├── queue_test.go └── single_queue.go /.github/workflows/go.yml: -------------------------------------------------------------------------------- 1 | # This workflow will build a golang project 2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-go 3 | 4 | name: Go Test 5 | 6 | on: 7 | push: 8 | branches: [ "master" ] 9 | pull_request: 10 | branches: [ "master" ] 11 | 12 | jobs: 13 | 14 | build: 15 | runs-on: ubuntu-latest 16 | steps: 17 | - uses: actions/checkout@v3 18 | 19 | - name: Set up Go 20 | uses: actions/setup-go@v3 21 | with: 22 | go-version: 1.20.3 23 | - name: Test 24 | run: go test -v ./... 
25 | - name: Bench 26 | run: make bench 27 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Binaries for programs and plugins 2 | *.exe 3 | *.exe~ 4 | *.dll 5 | *.so 6 | *.dylib 7 | 8 | # Test binary, built with `go test -c` 9 | *.test 10 | 11 | # Output of the go coverage tool, specifically when used with LiteIDE 12 | *.out 13 | 14 | # Dependency directories (remove the comment below to include it) 15 | vendor/ 16 | examples/ 17 | .idea/ 18 | .vscode/ 19 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 道友请留步 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE.
22 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | test: 2 | go test -count=1 ./... 3 | 4 | bench: 5 | go test -benchmem -run=^$$ -bench . github.com/lxzan/concurrency/benchmark 6 | 7 | cover: 8 | go test -coverprofile=./bin/cover.out --cover ./... -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Concurrency 2 | 3 | [![Build Status](https://github.com/lxzan/concurrency/workflows/Go%20Test/badge.svg?branch=master)](https://github.com/lxzan/concurrency/actions?query=branch%3Amaster) [![Coverage Status][1]][2] 4 | 5 | [1]: https://codecov.io/gh/lxzan/concurrency/branch/master/graph/badge.svg 6 | [2]: https://codecov.io/gh/lxzan/concurrency 7 | 8 | ### Install 9 | 10 | ```bash 11 | go get -v github.com/lxzan/concurrency@latest 12 | ``` 13 | 14 | #### Usage 15 | 16 | ##### 任务组 17 | > 添加一组任务, 等待它们全部执行完成 18 | 19 | ```go 20 | package main 21 | 22 | import ( 23 | "fmt" 24 | "github.com/lxzan/concurrency/groups" 25 | "sync/atomic" 26 | ) 27 | 28 | func main() { 29 | sum := int64(0) 30 | w := groups.New[int64]() 31 | for i := int64(1); i <= 10; i++ { 32 | w.Push(i) 33 | } 34 | w.OnMessage = func(args int64) error { 35 | fmt.Printf("%v ", args) 36 | atomic.AddInt64(&sum, args) 37 | return nil 38 | } 39 | w.Start() 40 | fmt.Printf("sum=%d\n", sum) 41 | } 42 | ``` 43 | 44 | ``` 45 | 4 5 6 7 8 9 10 1 3 2 sum=55 46 | ``` 47 | 48 | ##### 任务队列 49 | > 把任务加入队列, 异步执行 50 | 51 | ```go 52 | package main 53 | 54 | import ( 55 | "context" 56 | "fmt" 57 | "github.com/lxzan/concurrency/queues" 58 | "sync/atomic" 59 | ) 60 | 61 | func main() { 62 | sum := int64(0) 63 | w := queues.New() 64 | for i := int64(1); i <= 10; i++ { 65 | var x = i 66 | w.Push(func() { 67 | fmt.Printf("%v ", x) 68 | atomic.AddInt64(&sum, x) 69 | }) 70 | } 71 |
w.Stop(context.Background()) 72 | fmt.Printf("sum=%d\n", sum) 73 | } 74 | ``` 75 | 76 | ``` 77 | 3 9 10 4 1 6 8 5 2 7 sum=55 78 | ``` 79 | 80 | ### Benchmark 81 | 82 | ``` 83 | go test -benchmem -run=^$ -bench . github.com/lxzan/concurrency/benchmark 84 | goos: linux 85 | goarch: amd64 86 | pkg: github.com/lxzan/concurrency/benchmark 87 | cpu: AMD Ryzen 5 PRO 4650G with Radeon Graphics 88 | Benchmark_Fib-12 1000000 1146 ns/op 0 B/op 0 allocs/op 89 | Benchmark_StdGo-12 3661 317905 ns/op 16064 B/op 1001 allocs/op 90 | Benchmark_QueuesSingle-12 2178 532224 ns/op 67941 B/op 1098 allocs/op 91 | Benchmark_QueuesMultiple-12 3691 317757 ns/op 61648 B/op 1256 allocs/op 92 | Benchmark_Ants-12 1569 751802 ns/op 22596 B/op 1097 allocs/op 93 | Benchmark_GoPool-12 2910 406935 ns/op 19042 B/op 1093 allocs/op 94 | PASS 95 | ok github.com/lxzan/concurrency/benchmark 7.271s 96 | ``` -------------------------------------------------------------------------------- /benchmark/benchmark.go: -------------------------------------------------------------------------------- 1 | package benchmark 2 | /* fib returns the n-th Fibonacci number by naive recursion; the benchmarks use it as a fixed CPU-bound workload. */ 3 | func fib(n int) int { 4 | switch n { 5 | case 0, 1: 6 | return n 7 | default: 8 | return fib(n-1) + fib(n-2) 9 | } 10 | } 11 | -------------------------------------------------------------------------------- /benchmark/benchmark_test.go: -------------------------------------------------------------------------------- 1 | package benchmark 2 | 3 | import ( 4 | "github.com/bytedance/gopkg/util/gopool" 5 | "github.com/lxzan/concurrency/queues" 6 | "github.com/panjf2000/ants/v2" 7 | "sync" 8 | "testing" 9 | ) 10 | /* Concurrency: worker limit given to the pooled runners; M: jobs per benchmark iteration; N: argument each job passes to fib. */ 11 | const ( 12 | Concurrency = 16 13 | M = 1000 14 | N = 13 15 | ) 16 | /* newJob returns a job that computes fib(N) and then calls wg.Done. */ 17 | func newJob(wg *sync.WaitGroup) func() { 18 | return func() { 19 | fib(N) 20 | wg.Done() 21 | } 22 | } 23 | /* Benchmark_Fib measures a single fib(N) call: the cost of one job without any scheduling. */ 24 | func Benchmark_Fib(b *testing.B) { 25 | for i := 0; i < b.N; i++ { 26 | fib(N) 27 | } 28 | } 29 | /* Benchmark_StdGo starts one goroutine per job, M jobs per iteration. */ 30 | func Benchmark_StdGo(b *testing.B) { 31 | for i := 0; i < b.N; i++ { 32 | wg :=
&sync.WaitGroup{} 33 | wg.Add(M) 34 | job := newJob(wg) 35 | for j := 0; j < M; j++ { 36 | go job() 37 | } 38 | wg.Wait() 39 | } 40 | } 41 | /* Benchmark_QueuesSingle pushes M jobs per iteration into a single-shard queue with concurrency Concurrency. */ 42 | func Benchmark_QueuesSingle(b *testing.B) { 43 | q := queues.New( 44 | queues.WithConcurrency(Concurrency), 45 | queues.WithSharding(1), 46 | ) 47 | /* NOTE(review): q is never stopped after the benchmark - confirm its workers do not outlive it. */ 48 | for i := 0; i < b.N; i++ { 49 | wg := &sync.WaitGroup{} 50 | wg.Add(M) 51 | job := newJob(wg) 52 | for j := 0; j < M; j++ { 53 | q.Push(job) 54 | } 55 | wg.Wait() 56 | } 57 | } 58 | /* Benchmark_QueuesMultiple pushes M jobs per iteration into Concurrency shards, each with concurrency 1. */ 59 | func Benchmark_QueuesMultiple(b *testing.B) { 60 | q := queues.New( 61 | queues.WithConcurrency(1), 62 | queues.WithSharding(Concurrency), 63 | ) 64 | 65 | for i := 0; i < b.N; i++ { 66 | wg := &sync.WaitGroup{} 67 | wg.Add(M) 68 | job := newJob(wg) 69 | for j := 0; j < M; j++ { 70 | q.Push(job) 71 | } 72 | wg.Wait() 73 | } 74 | } 75 | /* Benchmark_Ants runs M jobs per iteration on an ants pool created with size Concurrency. */ 76 | func Benchmark_Ants(b *testing.B) { 77 | q, _ := ants.NewPool(Concurrency) 78 | defer q.Release() 79 | /* NOTE(review): the results of NewPool and Submit are discarded; if Submit can fail, wg.Wait would never return. */ 80 | for i := 0; i < b.N; i++ { 81 | wg := &sync.WaitGroup{} 82 | wg.Add(M) 83 | job := newJob(wg) 84 | for j := 0; j < M; j++ { 85 | q.Submit(job) 86 | } 87 | wg.Wait() 88 | } 89 | } 90 | /* Benchmark_GoPool runs M jobs per iteration on a gopool pool created with capacity Concurrency. */ 91 | func Benchmark_GoPool(b *testing.B) { 92 | q := gopool.NewPool("", Concurrency, gopool.NewConfig()) 93 | 94 | for i := 0; i < b.N; i++ { 95 | wg := &sync.WaitGroup{} 96 | wg.Add(M) 97 | job := newJob(wg) 98 | for j := 0; j < M; j++ { 99 | q.Go(job) 100 | } 101 | wg.Wait() 102 | } 103 | } 104 | -------------------------------------------------------------------------------- /benchmark/go.mod: -------------------------------------------------------------------------------- 1 | module github.com/lxzan/concurrency/benchmark 2 | 3 | go 1.18 4 | 5 | replace github.com/lxzan/concurrency v0.0.0 => ../ 6 | 7 | require ( 8 | github.com/bytedance/gopkg v0.0.0-20230728082804-614d0af6619b 9 | github.com/lxzan/concurrency v0.0.0 10 | github.com/panjf2000/ants/v2 v2.8.1 11 | ) 12 | --------------------------------------------------------------------------------
/benchmark/go.sum: -------------------------------------------------------------------------------- 1 | github.com/bytedance/gopkg v0.0.0-20230728082804-614d0af6619b h1:R6PWoQtxEMpWJPHnpci+9LgFxCS7iJCfOGBvCgZeTKI= 2 | github.com/bytedance/gopkg v0.0.0-20230728082804-614d0af6619b/go.mod h1:FtQG3YbQG9L/91pbKSw787yBQPutC+457AvDW77fgUQ= 3 | github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 4 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= 5 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 6 | github.com/panjf2000/ants/v2 v2.8.1 h1:C+n/f++aiW8kHCExKlpX6X+okmxKXP7DWLutxuAPuwQ= 7 | github.com/panjf2000/ants/v2 v2.8.1/go.mod h1:KIBmYG9QQX5U2qzFP/yQJaq/nSb6rahS9iEHkrCMgM8= 8 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 9 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 10 | github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= 11 | github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= 12 | github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= 13 | github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= 14 | github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= 15 | github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= 16 | github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= 17 | github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= 18 | golang.org/x/net v0.0.0-20221014081412-f15817d10f9b/go.mod h1:YDH+HFinaLZZlnHAfSS6ZXJJ9M9t4Dl22yv3iI2vPwk= 19 | golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= 20 | golang.org/x/sync v0.1.0 
h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o= 21 | golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= 22 | golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 23 | golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 24 | golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= 25 | golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= 26 | golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= 27 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 28 | gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 29 | gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= 30 | gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 31 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/lxzan/concurrency 2 | 3 | go 1.20 4 | 5 | require ( 6 | github.com/lxzan/dao v1.1.7 7 | github.com/pkg/errors v0.9.1 8 | github.com/stretchr/testify v1.8.4 9 | ) 10 | 11 | require ( 12 | github.com/davecgh/go-spew v1.1.1 // indirect 13 | github.com/pmezard/go-difflib v1.0.0 // indirect 14 | gopkg.in/yaml.v3 v3.0.1 // indirect 15 | ) 16 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= 2 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 3 | github.com/lxzan/dao v1.1.7 
h1:I049e67buJIpr4QJ/vJbHSjKMLN4ZJlSMeK3Rq+CJl8= 4 | github.com/lxzan/dao v1.1.7/go.mod h1:5ChTIo7RSZ4upqRo16eicJ3XdJWhGwgMIsyuGLMUofM= 5 | github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= 6 | github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= 7 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 8 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 9 | github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= 10 | github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= 11 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= 12 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 13 | gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= 14 | gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 15 | -------------------------------------------------------------------------------- /go.work: -------------------------------------------------------------------------------- 1 | go 1.20 2 | 3 | use ( 4 | . 
5 | ./benchmark 6 | ) 7 | -------------------------------------------------------------------------------- /groups/group.go: -------------------------------------------------------------------------------- 1 | package groups 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "github.com/lxzan/concurrency/internal" 7 | "sync" 8 | "sync/atomic" 9 | "time" 10 | ) 11 | 12 | const ( 13 | defaultConcurrency = 8 // default number of tasks run concurrently 14 | defaultWaitTimeout = 60 * time.Second // default time Start waits for all tasks 15 | ) 16 | 17 | var defaultCaller Caller = func(args any, f func(any) error) error { return f(args) } 18 | 19 | type ( 20 | Caller func(args any, f func(any) error) error 21 | 22 | Group[T any] struct { 23 | options *options // configuration 24 | mu sync.Mutex // guards q, errs and the task counters; also held by Update 25 | ctx context.Context // bounds Start by the configured timeout; canceled by Cancel and when Start returns 26 | cancelFunc context.CancelFunc // cancels ctx 27 | canceled atomic.Uint32 // set to 1 by Cancel 28 | errs []error // errors returned by tasks 29 | done chan bool // signals that every pushed task has finished 30 | q []T // tasks waiting to run 31 | taskDone int64 // number of finished tasks 32 | taskTotal int64 // number of pushed tasks 33 | OnMessage func(args T) error // task handler 34 | OnError func(args T, err error) // called for every task that returns an error 35 | } 36 | ) 37 | 38 | // New creates a task group configured by opts. 39 | func New[T any](opts ...Option) *Group[T] { 40 | o := new(options) 41 | opts = append(opts, withInitialize()) 42 | for _, f := range opts { 43 | f(o) 44 | } 45 | 46 | c := &Group[T]{ 47 | options: o, 48 | q: make([]T, 0), 49 | taskDone: 0, 50 | done: make(chan bool, 1), // buffered: the last worker must not block forever if Start already returned on timeout or Cancel 51 | } 52 | c.ctx, c.cancelFunc = context.WithTimeout(context.Background(), o.timeout) 53 | c.OnMessage = func(args T) error { 54 | return nil 55 | } 56 | c.OnError = func(args T, err error) {} 57 | 58 | return c 59 | } 60 | 61 | func (c *Group[T]) clearJob() { 62 | c.mu.Lock() 63 | c.q = c.q[:0] 64 | c.mu.Unlock() 65 | } 66 | 67 | func (c *Group[T]) getJob() (v T, ok bool) { 68 | c.mu.Lock() 69 | defer c.mu.Unlock() 70 | 71 | if n := len(c.q); n == 0 { 72 | return 73 | } 74 | var result = c.q[0] 75 | c.q = c.q[1:] 76 | return result, true 77 | } 78 | 79 | // incrAndIsDone increments the finished-task counter 80 | // and reports whether every pushed task has now finished.
81 | func (c *Group[T]) incrAndIsDone() bool { 82 | c.mu.Lock() 83 | c.taskDone++ 84 | ok := c.taskDone == c.taskTotal 85 | c.mu.Unlock() 86 | return ok 87 | } 88 | 89 | func (c *Group[T]) getError() error { 90 | c.mu.Lock() 91 | defer c.mu.Unlock() 92 | return errors.Join(c.errs...) 93 | } 94 | 95 | func (c *Group[T]) jobFunc(v any) error { 96 | if c.canceled.Load() == 1 { 97 | return nil 98 | } 99 | return c.OnMessage(v.(T)) 100 | } 101 | 102 | // do runs args, then keeps taking queued jobs until the queue is empty or every task has finished. 103 | // It loops instead of recursing so a worker's stack stays flat no matter how many jobs it drains. 104 | func (c *Group[T]) do(args T) { 105 | for { 106 | if err := c.options.caller(args, c.jobFunc); err != nil { 107 | c.mu.Lock() 108 | c.errs = append(c.errs, err) 109 | c.mu.Unlock() 110 | c.OnError(args, err) 111 | } 112 | 113 | if c.incrAndIsDone() { 114 | c.done <- true 115 | return 116 | } 117 | 118 | next, ok := c.getJob() 119 | if !ok { 120 | return 121 | } 122 | args = next 123 | } 124 | } 125 | 126 | // Len returns the number of tasks still waiting in the queue. 127 | func (c *Group[T]) Len() int { 128 | c.mu.Lock() 129 | x := len(c.q) 130 | c.mu.Unlock() 131 | return x 132 | } 133 | 134 | // Cancel cancels the group's context; tasks that start afterwards skip OnMessage. 135 | func (c *Group[T]) Cancel() { 136 | if c.canceled.CompareAndSwap(0, 1) { 137 | c.cancelFunc() 138 | } 139 | } 140 | 141 | // Push appends tasks to the queue. 142 | func (c *Group[T]) Push(eles ...T) { 143 | c.mu.Lock() 144 | c.taskTotal += int64(len(eles)) 145 | c.q = append(c.q, eles...)
146 | c.mu.Unlock() 147 | } 148 | 149 | // Update runs f while holding the group's lock. 150 | func (c *Group[T]) Update(f func()) { 151 | c.mu.Lock() 152 | f() 153 | c.mu.Unlock() 154 | } 155 | 156 | // Start launches up to the configured number of workers and blocks until every task has finished (returning the joined task errors) or the context ends through timeout or Cancel (returning ctx.Err()). 157 | func (c *Group[T]) Start() error { 158 | var taskTotal = int64(c.Len()) 159 | if taskTotal == 0 { 160 | return nil 161 | } 162 | 163 | var co = internal.Min(c.options.concurrency, taskTotal) 164 | for i := int64(0); i < co; i++ { 165 | if item, ok := c.getJob(); ok { 166 | go c.do(item) 167 | } 168 | } 169 | 170 | defer c.cancelFunc() 171 | 172 | select { 173 | case <-c.done: 174 | return c.getError() 175 | case <-c.ctx.Done(): 176 | c.clearJob() 177 | return c.ctx.Err() 178 | } 179 | } 180 | -------------------------------------------------------------------------------- /groups/group_test.go: -------------------------------------------------------------------------------- 1 | package groups 2 | 3 | import ( 4 | "github.com/pkg/errors" 5 | "github.com/stretchr/testify/assert" 6 | "sync" 7 | "sync/atomic" 8 | "testing" 9 | "time" 10 | ) 11 | 12 | func TestNewTaskGroup(t *testing.T) { 13 | as := assert.New(t) 14 | 15 | t.Run("0 task", func(t *testing.T) { 16 | cc := New[int]() 17 | err := cc.Start() 18 | as.NoError(err) 19 | }) 20 | 21 | t.Run("1 task", func(t *testing.T) { 22 | cc := New[int]() 23 | cc.Push(0) 24 | err := cc.Start() 25 | as.NoError(err) 26 | }) 27 | 28 | t.Run("100 task", func(t *testing.T) { 29 | sum := int64(0) 30 | w := New[int64]() 31 | w.OnMessage = func(args int64) error { 32 | atomic.AddInt64(&sum, args) 33 | w.Update(func() {}) 34 | return nil 35 | } 36 | for i := int64(1); i <= 100; i++ { 37 | w.Push(i) 38 | } 39 | _ = w.Start() 40 | as.Equal(sum, int64(5050)) 41 | }) 42 | 43 | t.Run("error", func(t *testing.T) { 44 | cc := New[int]() 45 | cc.Push(1) 46 | cc.Push(2) 47 | cc.OnMessage = func(args int) error { 48 | return errors.New("test1") 49 | } 50 | err := cc.Start() 51 | as.Error(err) 52 | }) 53 | 54 | t.Run("timeout", func(t *testing.T) { 55
| var mu = &sync.Mutex{} 56 | var list = make([]int, 0) 57 | ctl := New[int](WithConcurrency(2), WithTimeout(time.Second)) 58 | ctl.Push(1, 3, 5, 7, 9) 59 | ctl.OnMessage = func(args int) error { 60 | mu.Lock() 61 | list = append(list, args) 62 | mu.Unlock() 63 | time.Sleep(2 * time.Second) 64 | return nil 65 | } 66 | err := ctl.Start() 67 | as.Error(err) 68 | as.ElementsMatch(list, []int{1, 3}) 69 | }) 70 | 71 | t.Run("recovery", func(t *testing.T) { 72 | ctl := New[int](WithRecovery()) 73 | ctl.Push(1) 74 | ctl.Push(2) 75 | ctl.OnMessage = func(args int) error { 76 | var err error 77 | println(err.Error()) 78 | return err 79 | } 80 | err := ctl.Start() 81 | as.Error(err) 82 | }) 83 | 84 | t.Run("cancel", func(t *testing.T) { 85 | ctl := New[int](WithConcurrency(1)) 86 | ctl.Push(1, 3, 5) 87 | arr := make([]int, 0) 88 | ctl.OnMessage = func(args int) error { 89 | ctl.Update(func() { 90 | arr = append(arr, args) 91 | }) 92 | switch args { 93 | case 3: 94 | return errors.New("3") 95 | default: 96 | return nil 97 | } 98 | } 99 | ctl.OnError = func(args int, err error) { 100 | ctl.Cancel() 101 | } 102 | err := ctl.Start() 103 | as.Error(err) 104 | as.ElementsMatch(arr, []int{1, 3}) 105 | }) 106 | } 107 | -------------------------------------------------------------------------------- /groups/options.go: -------------------------------------------------------------------------------- 1 | package groups 2 | 3 | import ( 4 | "github.com/lxzan/concurrency/internal" 5 | "github.com/pkg/errors" 6 | "runtime" 7 | "time" 8 | "unsafe" 9 | ) 10 | // options holds the configuration of a Group. 11 | type options struct { 12 | timeout time.Duration 13 | concurrency int64 14 | caller Caller 15 | } 16 | // Option configures a Group; options are passed to New. 17 | type Option func(o *options) 18 | 19 | // WithTimeout sets how long Start waits for all tasks before returning the context error (default 60s). 20 | func WithTimeout(t time.Duration) Option { 21 | return func(o *options) { 22 | o.timeout = t 23 | } 24 | } 25 | 26 | // WithConcurrency sets the maximum number of tasks run at the same time (default 8). 27 | func WithConcurrency(n uint32) Option { 28 | return func(o *options) { 29 | o.concurrency = int64(n) 30 | } 31
| } 32 | 33 | // WithRecovery makes every task recover from panics; the panic value and the stack trace are returned as that task's error. 34 | func WithRecovery() Option { 35 | return func(o *options) { 36 | o.caller = func(args any, f func(any) error) (err error) { 37 | defer func() { 38 | if e := recover(); e != nil { 39 | const size = 64 << 10 40 | buf := make([]byte, size) 41 | buf = buf[:runtime.Stack(buf, false)] 42 | msg := *(*string)(unsafe.Pointer(&buf)) 43 | err = errors.Errorf("panic: %v\n%s", e, msg) // include the panic value, not only the stack 44 | } 45 | }() 46 | 47 | return f(args) 48 | } 49 | } 50 | } 51 | // withInitialize fills in defaults for unset options; New applies it after the caller's options. 52 | func withInitialize() Option { 53 | return func(o *options) { 54 | o.timeout = internal.SelectValue(o.timeout <= 0, defaultWaitTimeout, o.timeout) 55 | o.concurrency = internal.SelectValue(o.concurrency <= 0, defaultConcurrency, o.concurrency) 56 | o.caller = internal.SelectValue(o.caller == nil, defaultCaller, o.caller) 57 | } 58 | } 59 | -------------------------------------------------------------------------------- /internal/utils.go: -------------------------------------------------------------------------------- 1 | package internal 2 | // Integer is the set of integer types accepted by the generic helpers in this package. 3 | type Integer interface { 4 | int | int64 | int32 | uint | uint64 | uint32 5 | } 6 | // Min returns the smaller of a and b. 7 | func Min[T Integer](a, b T) T { 8 | if a < b { 9 | return a 10 | } 11 | return b 12 | } 13 | // ToBinaryNumber returns the smallest power of two that is >= n (1 when n <= 1). NOTE: the loop never ends if that power of two does not fit in T. 14 | func ToBinaryNumber[T Integer](n T) T { 15 | var x T = 1 16 | for x < n { 17 | x *= 2 18 | } 19 | return x 20 | } 21 | // IsSameSlice reports whether a and b have the same length and equal elements at every index. 22 | func IsSameSlice[T comparable](a, b []T) bool { 23 | if len(a) != len(b) { 24 | return false 25 | } 26 | for i, v := range a { 27 | if v != b[i] { 28 | return false 29 | } 30 | } 31 | return true 32 | } 33 | // SelectValue returns a when ok is true and b otherwise. 34 | func SelectValue[T any](ok bool, a, b T) T { 35 | if ok { 36 | return a 37 | } 38 | return b 39 | } 40 | -------------------------------------------------------------------------------- /internal/utils_test.go: -------------------------------------------------------------------------------- 1 | package internal 2 | 3 | import ( 4 | "github.com/stretchr/testify/assert" 5 | "testing" 6 | ) 7 | 8 | func TestToBinaryNumber(t *testing.T) { 9 |
assert.Equal(t, 8, ToBinaryNumber(7)) 10 | assert.Equal(t, 1, ToBinaryNumber(0)) 11 | assert.Equal(t, 128, ToBinaryNumber(120)) 12 | assert.Equal(t, 1024, ToBinaryNumber(1024)) 13 | } 14 | 15 | func TestMin(t *testing.T) { 16 | assert.Equal(t, 1, Min(1, 2)) 17 | assert.Equal(t, 3, Min(4, 3)) 18 | } 19 | 20 | func TestIsSameSlice(t *testing.T) { 21 | assert.True(t, IsSameSlice( 22 | []int{1, 2, 3}, 23 | []int{1, 2, 3}, 24 | )) 25 | 26 | assert.False(t, IsSameSlice( 27 | []int{1, 2, 3}, 28 | []int{1, 2}, 29 | )) 30 | 31 | assert.False(t, IsSameSlice( 32 | []int{1, 2, 3}, 33 | []int{1, 2, 4}, 34 | )) 35 | } 36 | 37 | func TestSelectValue(t *testing.T) { 38 | assert.Equal(t, SelectValue(true, 1, 2), 1) 39 | assert.Equal(t, SelectValue(false, 1, 2), 2) 40 | } 41 | -------------------------------------------------------------------------------- /logs/log.go: -------------------------------------------------------------------------------- 1 | package logs 2 | 3 | import "log" 4 | // DefaultLogger is the package-level Logger backed by the standard log package. 5 | var DefaultLogger = new(logger) 6 | // Logger reports formatted error messages. 7 | type Logger interface { 8 | Errorf(format string, args ...any) 9 | } 10 | // logger is the default Logger implementation. 11 | type logger struct{} 12 | // Errorf formats and prints the message with log.Printf. 13 | func (c *logger) Errorf(format string, args ...any) { 14 | log.Printf(format, args...)
15 | } 16 | -------------------------------------------------------------------------------- /queues/multiple_queue.go: -------------------------------------------------------------------------------- 1 | package queues 2 | 3 | import ( 4 | "context" 5 | "sync" 6 | "sync/atomic" 7 | ) 8 | 9 | type ( 10 | multipleQueue struct { 11 | conf *options // configuration shared by every shard 12 | serial atomic.Int64 // counter used to pick shards round-robin 13 | stopped bool // stop flag (NOTE(review): not read or written in this file) 14 | qs []*singleQueue // shards 15 | } 16 | 17 | errWrapper struct{ err error } 18 | ) 19 | 20 | // newMultipleQueue creates a queue made of o.sharding independent singleQueue shards. 21 | func newMultipleQueue(o *options) *multipleQueue { 22 | qs := make([]*singleQueue, o.sharding) 23 | for i := int64(0); i < o.sharding; i++ { 24 | qs[i] = newSingleQueue(o) 25 | } 26 | return &multipleQueue{conf: o, qs: qs} 27 | } 28 | 29 | // Len returns the number of pending jobs summed over all shards. 30 | func (c *multipleQueue) Len() int { 31 | var sum = 0 32 | for _, q := range c.qs { 33 | sum += q.Len() 34 | } 35 | return sum 36 | } 37 | 38 | // Push appends a job to the next shard in round-robin order. 39 | // sharding is rounded up to a power of two by withInitialize, so the mask acts as a modulo. 40 | func (c *multipleQueue) Push(job Job) { 41 | i := c.serial.Add(1) & (c.conf.sharding - 1) 42 | c.qs[i].Push(job) 43 | } 44 | 45 | // Stop stops all shards concurrently and waits for each of them to return. 46 | // It returns the first non-nil error reported by a shard (such as a timeout), or nil if every shard stopped cleanly. 47 | func (c *multipleQueue) Stop(ctx context.Context) error { 48 | var first atomic.Pointer[errWrapper] 49 | var wg sync.WaitGroup 50 | wg.Add(len(c.qs)) 51 | for _, q := range c.qs { 52 | go func(q *singleQueue) { 53 | defer wg.Done() 54 | // Record only real errors: a shard that stops cleanly must not mask another shard's failure. 55 | if err := q.Stop(ctx); err != nil { 56 | first.CompareAndSwap(nil, &errWrapper{err}) 57 | } 58 | }(q) 59 | } 60 | wg.Wait() 61 | if w := first.Load(); w != nil { 62 | return w.err 63 | } 64 | return nil 65 | } 66 | -------------------------------------------------------------------------------- /queues/options.go: -------------------------------------------------------------------------------- 1 | package queues 2 | 3 | import ( 4 | "github.com/lxzan/concurrency/internal" 5 | "github.com/lxzan/concurrency/logs" 6 | "runtime" 7 | "time" 8 | "unsafe" 9 | ) 10 | 11 | type options struct { 12 | sharding int64 // number of shards 13 | concurrency uint32 // maximum parallelism of each shard 14 | timeout time.Duration // exit-wait timeout used when stopping 15 | caller Caller // function used to invoke each job 16 | logger
logs.Logger // 日志组件 17 | } 18 | 19 | type Option func(o *options) 20 | 21 | // WithSharding 设置分片数量, 有利于降低锁竞争开销, 默认为1 22 | func WithSharding(num uint32) Option { 23 | return func(o *options) { 24 | o.sharding = int64(num) 25 | } 26 | } 27 | 28 | // WithConcurrency 设置最大并行度, 默认为8 29 | func WithConcurrency(n uint32) Option { 30 | return func(o *options) { 31 | o.concurrency = n 32 | } 33 | } 34 | 35 | // WithTimeout 设置退出等待超时时间, 默认30s 36 | func WithTimeout(t time.Duration) Option { 37 | return func(o *options) { 38 | o.timeout = t 39 | } 40 | } 41 | 42 | // WithLogger 设置日志组件 43 | func WithLogger(logger logs.Logger) Option { 44 | return func(o *options) { 45 | o.logger = logger 46 | } 47 | } 48 | 49 | func withInitialize() Option { 50 | return func(o *options) { 51 | o.sharding = internal.SelectValue(o.sharding <= 0, defaultSharding, o.sharding) 52 | o.sharding = internal.ToBinaryNumber(o.sharding) 53 | o.concurrency = internal.SelectValue(o.concurrency <= 0, defaultConcurrency, o.concurrency) 54 | o.timeout = internal.SelectValue(o.timeout <= 0, defaultTimeout, o.timeout) 55 | o.logger = internal.SelectValue[logs.Logger](o.logger == nil, logs.DefaultLogger, o.logger) 56 | o.caller = internal.SelectValue(o.caller == nil, defaultCaller, o.caller) 57 | } 58 | } 59 | 60 | // WithRecovery 设置恢复程序 61 | func WithRecovery() Option { 62 | return func(o *options) { 63 | o.caller = func(logger logs.Logger, f func()) { 64 | defer func() { 65 | if e := recover(); e != nil { 66 | const size = 64 << 10 67 | buf := make([]byte, size) 68 | buf = buf[:runtime.Stack(buf, false)] 69 | msg := *(*string)(unsafe.Pointer(&buf)) 70 | logger.Errorf("fatal error: %v\n%v\n", e, msg) 71 | } 72 | }() 73 | 74 | f() 75 | } 76 | } 77 | } 78 | -------------------------------------------------------------------------------- /queues/queue.go: -------------------------------------------------------------------------------- 1 | package queues 2 | 3 | import ( 4 | "context" 5 | 
"github.com/lxzan/concurrency/logs" 6 | "time" 7 | ) 8 | 9 | var defaultCaller Caller = func(logger logs.Logger, f func()) { f() } 10 | 11 | const ( 12 | defaultSharding = 1 13 | defaultConcurrency = 8 14 | defaultTimeout = 30 * time.Second 15 | ) 16 | 17 | type ( 18 | Caller func(logger logs.Logger, f func()) 19 | 20 | Job func() 21 | 22 | Queue interface { 23 | // Len 获取队列中剩余任务数量 24 | Len() int 25 | 26 | // Push 追加任务 27 | Push(job Job) 28 | 29 | // Stop 停止 30 | // 停止后不能追加新的任务, 队列中剩余的任务会继续执行, 到收到上下文信号为止. 31 | Stop(ctx context.Context) error 32 | } 33 | ) 34 | 35 | func New(opts ...Option) Queue { 36 | opts = append(opts, withInitialize()) 37 | o := new(options) 38 | for _, f := range opts { 39 | f(o) 40 | } 41 | 42 | if o.sharding == 1 { 43 | return newSingleQueue(o) 44 | } 45 | return newMultipleQueue(o) 46 | } 47 | -------------------------------------------------------------------------------- /queues/queue_test.go: -------------------------------------------------------------------------------- 1 | package queues 2 | 3 | import ( 4 | "context" 5 | "github.com/lxzan/concurrency/logs" 6 | "github.com/stretchr/testify/assert" 7 | "sync" 8 | "sync/atomic" 9 | "testing" 10 | "time" 11 | ) 12 | 13 | func TestSingleQueue(t *testing.T) { 14 | as := assert.New(t) 15 | 16 | t.Run("sum", func(t *testing.T) { 17 | var val = int64(0) 18 | var wg = sync.WaitGroup{} 19 | wg.Add(1000) 20 | w := New(WithConcurrency(16)) 21 | for i := 1; i <= 1000; i++ { 22 | args := int64(i) 23 | w.Push(func() { 24 | atomic.AddInt64(&val, args) 25 | wg.Done() 26 | }) 27 | } 28 | wg.Wait() 29 | as.Equal(int64(500500), val) 30 | }) 31 | 32 | t.Run("recover", func(t *testing.T) { 33 | w := New(WithRecovery(), WithLogger(logs.DefaultLogger)) 34 | w.Push(func() { 35 | panic("test") 36 | }) 37 | }) 38 | 39 | t.Run("stop timeout", func(t *testing.T) { 40 | cc := New( 41 | WithConcurrency(1), 42 | WithTimeout(100*time.Millisecond), 43 | ) 44 | 45 | cc.Push(func() { 46 | time.Sleep(500 * 
time.Millisecond) 47 | }) 48 | cc.Push(func() { 49 | time.Sleep(500 * time.Millisecond) 50 | }) 51 | 52 | err := cc.Stop(context.Background()) 53 | as.Error(err) 54 | }) 55 | 56 | t.Run("stop graceful", func(t *testing.T) { 57 | cc := New(WithTimeout(time.Second)) 58 | sum := int64(0) 59 | cc.Push(func() { 60 | time.Sleep(time.Millisecond) 61 | atomic.AddInt64(&sum, 1) 62 | }) 63 | cc.Stop(context.Background()) 64 | assert.Equal(t, int64(1), atomic.LoadInt64(&sum)) 65 | }) 66 | 67 | t.Run("", func(t *testing.T) { 68 | q := New(WithConcurrency(1)) 69 | q.Push(func() { time.Sleep(100 * time.Millisecond) }) 70 | q.Push(func() {}) 71 | q.Stop(context.Background()) 72 | q.Push(func() {}) 73 | assert.Equal(t, 0, q.(*singleQueue).Len()) 74 | }) 75 | 76 | t.Run("stop", func(t *testing.T) { 77 | var q = New() 78 | var ctx = context.Background() 79 | q.Stop(ctx) 80 | q.Stop(ctx) 81 | }) 82 | } 83 | 84 | func TestMultiQueue(t *testing.T) { 85 | as := assert.New(t) 86 | 87 | t.Run("sum", func(t *testing.T) { 88 | var val = int64(0) 89 | var wg = sync.WaitGroup{} 90 | wg.Add(1000) 91 | w := New(WithConcurrency(16), WithSharding(8)) 92 | for i := 1; i <= 1000; i++ { 93 | args := int64(i) 94 | w.Push(func() { 95 | atomic.AddInt64(&val, args) 96 | wg.Done() 97 | }) 98 | } 99 | wg.Wait() 100 | as.Equal(int64(500500), val) 101 | }) 102 | 103 | t.Run("recover", func(t *testing.T) { 104 | w := New(WithRecovery(), WithLogger(logs.DefaultLogger), WithSharding(8)) 105 | w.Push(func() { 106 | panic("test") 107 | }) 108 | }) 109 | 110 | t.Run("stop", func(t *testing.T) { 111 | cc := New( 112 | WithSharding(2), 113 | WithConcurrency(1), 114 | ) 115 | as.Nil(cc.Stop(context.Background())) 116 | }) 117 | 118 | t.Run("stop finished", func(t *testing.T) { 119 | cc := New(WithSharding(8)) 120 | 121 | job := func() { 122 | time.Sleep(120 * time.Millisecond) 123 | } 124 | cc.Push(job) 125 | cc.Push(job) 126 | cc.Push(job) 127 | 128 | ctx, cancel := context.WithCancel(context.Background()) 129 | go 
func() { 130 | time.Sleep(150 * time.Millisecond) 131 | cancel() 132 | }() 133 | t0 := time.Now() 134 | err := cc.Stop(ctx) 135 | println(time.Since(t0).String()) 136 | as.Nil(err) 137 | }) 138 | 139 | t.Run("stop timeout", func(t *testing.T) { 140 | cc := New( 141 | WithTimeout(100*time.Millisecond), 142 | WithSharding(2), 143 | WithConcurrency(1), 144 | ) 145 | job := func() { 146 | time.Sleep(500 * time.Millisecond) 147 | } 148 | cc.Push(job) 149 | cc.Push(job) 150 | cc.Push(job) 151 | err := cc.Stop(context.Background()) 152 | as.Error(err) 153 | as.Equal(cc.Len(), 1) 154 | }) 155 | 156 | t.Run("stop graceful", func(t *testing.T) { 157 | cc := New(WithTimeout(time.Second), WithSharding(8)) 158 | sum := int64(0) 159 | cc.Push(func() { 160 | time.Sleep(time.Millisecond) 161 | atomic.AddInt64(&sum, 1) 162 | }) 163 | cc.Stop(context.Background()) 164 | assert.Equal(t, int64(1), atomic.LoadInt64(&sum)) 165 | }) 166 | 167 | t.Run("", func(t *testing.T) { 168 | q := New(WithConcurrency(1)) 169 | q.Push(func() { time.Sleep(100 * time.Millisecond) }) 170 | q.Push(func() {}) 171 | q.Stop(context.Background()) 172 | q.Push(func() {}) 173 | }) 174 | } 175 | -------------------------------------------------------------------------------- /queues/single_queue.go: -------------------------------------------------------------------------------- 1 | package queues 2 | 3 | import ( 4 | "context" 5 | "github.com/lxzan/dao/deque" 6 | "sync" 7 | "time" 8 | ) 9 | 10 | // newSingleQueue creates a job queue whose concurrency limit is taken from o.concurrency. 11 | func newSingleQueue(o *options) *singleQueue { 12 | return &singleQueue{ 13 | conf: o, 14 | maxConcurrency: int32(o.concurrency), 15 | q: deque.New[Job](8), 16 | } 17 | } 18 | // singleQueue is a FIFO job queue that runs at most maxConcurrency jobs at a time. 19 | type singleQueue struct { 20 | mu sync.Mutex // guards q, curConcurrency and stopped 21 | conf *options // configuration: caller, logger, timeout, concurrency 22 | q *deque.Deque[Job] // pending jobs, FIFO (PushBack / PopFront) 23 | maxConcurrency int32 // upper bound on simultaneously running jobs 24 | curConcurrency int32 // number of jobs currently running (slots taken) 25 | stopped bool // set once by Stop; newly pushed jobs are dropped afterwards 26 | } 27 | // Stop marks the queue as stopped and waits for all queued and running jobs to finish, polling every 100ms. It returns nil once the queue has drained, or the context error if ctx is done or conf.timeout elapses first; every call after the first returns nil immediately. NOTE(review): even an idle queue waits for the first 100ms tick, because finish() is not checked before the loop. 28 | func (c *singleQueue) Stop(ctx context.Context) error { 29 | if !c.cas(false, true) { 30 | return nil 31 | } 32 |
33 | ctx1, cancel := context.WithTimeout(ctx, c.conf.timeout) 34 | ticker := time.NewTicker(100 * time.Millisecond) 35 | defer func() { 36 | cancel() 37 | ticker.Stop() 38 | }() 39 | // Drain loop: re-check on every tick, and once more when the deadline fires so a just-finished queue still reports success. 40 | for { 41 | select { 42 | case <-ticker.C: 43 | if c.finish() { 44 | return nil 45 | } 46 | case <-ctx1.Done(): 47 | if c.finish() { 48 | return nil 49 | } 50 | return ctx1.Err() 51 | } 52 | } 53 | } 54 | 55 | // getJob, under mu, enqueues newJob (unless the queue is stopped or newJob is nil), adjusts curConcurrency by delta, and if a slot is free pops the next job and reserves a slot for it. It returns nil when no job should be started. 56 | func (c *singleQueue) getJob(newJob Job, delta int32) Job { 57 | c.mu.Lock() 58 | defer c.mu.Unlock() 59 | // After Stop, newly offered jobs are dropped rather than queued; already-queued jobs still run. 60 | if !c.stopped && newJob != nil { 61 | c.q.PushBack(newJob) 62 | } 63 | c.curConcurrency += delta 64 | if c.curConcurrency >= c.maxConcurrency { 65 | return nil 66 | } 67 | if job := c.q.PopFront(); job != nil { 68 | c.curConcurrency++ 69 | return job 70 | } 71 | return nil 72 | } 73 | 74 | // do runs job via conf.caller, then keeps pulling queued jobs on this goroutine until none is handed out; getJob(nil, -1) releases the slot of the job that just finished. 75 | func (c *singleQueue) do(job Job) { 76 | for job != nil { 77 | c.conf.caller(c.conf.logger, job) 78 | job = c.getJob(nil, -1) 79 | } 80 | } 81 | 82 | // Push appends job to the queue and, if a concurrency slot is free, starts a goroutine that runs the next queued job. 83 | func (c *singleQueue) Push(job Job) { 84 | if nextJob := c.getJob(job, 0); nextJob != nil { 85 | go c.do(nextJob) 86 | } 87 | } 88 | // Len returns the number of jobs waiting in the queue; running jobs are not counted. 89 | func (c *singleQueue) Len() int { 90 | c.mu.Lock() 91 | defer c.mu.Unlock() 92 | return c.q.Len() 93 | } 94 | // finish reports whether no jobs are queued or running. 95 | func (c *singleQueue) finish() bool { 96 | c.mu.Lock() 97 | defer c.mu.Unlock() 98 | return c.q.Len()+int(c.curConcurrency) == 0 99 | } 100 | // cas sets stopped to new if it currently equals old (under mu) and reports whether it did. 101 | func (c *singleQueue) cas(old, new bool) bool { 102 | c.mu.Lock() 103 | defer c.mu.Unlock() 104 | if c.stopped == old { 105 | c.stopped = new 106 | return true 107 | } 108 | return false 109 | } 110 | --------------------------------------------------------------------------------