├── .gitignore
├── LICENSE
├── README.md
├── batch.go
├── batch_test.go
├── consume.go
├── consume_test.go
├── doc.go
├── fork_join.go
├── fork_join_test.go
├── go.mod
├── go.sum
├── invoke_all.go
├── invoke_all_test.go
├── mocks
│   └── partitioner.go
├── partition.go
├── partition_test.go
├── repeat.go
├── repeat_test.go
├── spread.go
├── spread_test.go
├── task.go
└── task_test.go

/.gitignore:
--------------------------------------------------------------------------------
vendor
.idea/

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
# MIT License

Copyright (c) 2019 Grabtaxi Holdings PTE LTE (GRAB)
Copyright (c) 2021 Roman Atachiants

Permission is hereby granted, free of charge, to any person obtaining a copy of
this software and associated documentation files (the "Software"), to deal in
the Software without restriction, including without limitation the rights to
use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
the Software, and to permit persons to whom the Software is furnished to do so,
subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Async

## What is package async

Package async simplifies the implementation of orchestration patterns for concurrent systems. It is similar to Java Future or JS Promise, which makes life much easier when dealing with asynchronous operations and concurrent processing. Golang is excellent in terms of parallel programming. However, dealing with goroutines and channels can become a big headache when business logic gets complicated. Wrapping them into higher-level functions gives the code much better readability and makes it easier for developers to reason about.

Currently, this package includes:

- Asynchronous tasks with cancellations, context propagation and state.
- Task chaining by using continuations.
- Fork/join pattern - running a bunch of work and waiting for everything to finish.
- Throttling pattern - throttling task execution on a specified rate.
- Spread pattern - spreading tasks across time.
- Partition pattern - partitioning data concurrently.
- Repeat pattern - repeating a certain task at a specified interval.
- Batch pattern - batching many tasks into a single one with individual continuations.

## Concept

**Task** is a basic concept like Future in Java. You can create a Task from an executable function that takes a context and returns a result and an error.

```
task := NewTask(func(context.Context) (any, error) {
	// run the job
	return res, err
})
```

#### Get the result

The function will be evaluated asynchronously. You can query whether it has completed by calling task.State(), which is non-blocking. Alternatively, you can wait for the result with task.Outcome(), which blocks until the job is done.
These two functions are similar to Future.isDone() and Future.get() in Java, respectively.

#### Cancelling

There may be cases where we no longer care about the result some time after execution has started. In that case, the task can be aborted by calling task.Cancel().

#### Chaining

To run a follow-up action after the task, simply call ContinueWith(). This is useful for building a processing chain, or for running a teardown step after the job.

## Examples

For example, suppose we want to upload numerous files efficiently. There are multiple strategies we can take, given a file uploading function like:

```
func upload(ctx context.Context, file string) (any, error) {
	// do file uploading
	return res, err
}
```

#### Fork join

The main characteristic of a fork/join task is that it spawns new subtasks that run concurrently. They can be different parts of the main task that run independently. The following code example illustrates how you can send files to S3 concurrently with a few lines of code.

```
func uploadFilesConcurrently(files []string) {
	tasks := []Task{}
	for _, file := range files {
		file := file // capture the loop variable (needed before Go 1.22)
		tasks = append(tasks, NewTask(func(ctx context.Context) (any, error) {
			return upload(ctx, file)
		}))
	}

	// ForkJoin returns a Task; Outcome blocks until every upload has finished.
	_, _ = ForkJoin(context.Background(), tasks).Outcome()
}
```

#### Invoke All

ForkJoin may not suit every case, for example when the number of tasks grows very large. In that case, the number of concurrently running tasks and goroutines, and the resulting CPU utilisation, could overwhelm the node. One solution is to constrain the maximum concurrency. InvokeAll is introduced for this purpose: it behaves like a fixed-size goroutine pool that attempts to serve the given tasks in the shortest time.

```
InvokeAll(context.Background(), concurrency, tasks)
```

#### Spread

Sometimes we don't really care about the concurrency but just want to make sure the tasks finish within a certain time period. The Spread function is useful in this case, as it spreads the tasks evenly over the given period.

```
Spread(context.Background(), period, tasks)
```

For example, if we want to send 50 files within 10 seconds, the Spread function would start a task every 0.2 seconds. This assumes that every task takes a similar amount of time. A more sophisticated approach would need an adaptive model that derives each task's duration from its characteristics or parameters.

--------------------------------------------------------------------------------
/batch.go:
--------------------------------------------------------------------------------
// Copyright 2019 Grabtaxi Holdings PTE LTE (GRAB), All rights reserved.
// Use of this source code is governed by an MIT-style license that can be found in the LICENSE file

package async

import (
	"context"
	"errors"
	"fmt"
	"sync"
)

type batchEntry struct {
	id      uint64
	payload any  // Will be used as input when the batch is processed
	task    Task // The callback will be called when this entry is processed
}

type batch struct {
	sync.RWMutex
	ctx       context.Context
	lastID    uint64            // The last id for result matching
	pending   []batchEntry      // The pending entries to the batch
	batchTask Task              // The current batch task
	batch     chan []batchEntry // The current batch channel to execute
	process   func([]any) []any // The function which will be executed to process the items of the NewBatch
}

// Batch represents a batch where one can append to the batch and process it as a whole.
type Batch interface {
	Append(payload any) Task
	Size() int
	Reduce()
}

// NewBatch creates a new batch
func NewBatch(ctx context.Context, process func([]any) []any) Batch {
	return &batch{
		ctx:     ctx,
		pending: []batchEntry{},
		batch:   make(chan []batchEntry),
		process: process,
	}
}

// Append adds a new payload to the batch and returns the task for that particular
// payload. You should listen for the outcome, as the task will be executed by the reducer.
func (b *batch) Append(payload any) Task {
	b.Lock()
	defer b.Unlock()

	b.lastID = b.lastID + 1
	id := b.lastID

	// Make sure we have a batch task
	if b.batchTask == nil {
		b.batchTask = b.createBatchTask()
	}

	// Batch task will need to continue with this one
	t := b.batchTask.ContinueWith(b.ctx, func(batchResult any, _ error) (any, error) {
		if res, ok := batchResult.(map[uint64]any); ok {
			return res[id], nil
		}

		actualType := fmt.Sprintf("%T", batchResult)
		return nil, errors.New("Invalid batch type, got: " + actualType)
	})

	// Add to the task queue
	b.pending = append(b.pending, batchEntry{
		id:      id,
		payload: payload,
		task:    t,
	})

	// Return the task we created
	return t
}

// Reduce sends all pending entries to the current batch task for processing
// and starts a new batch task for subsequent appends.
func (b *batch) Reduce() {
	b.Lock()
	defer b.Unlock()

	// Skip if the queue is empty
	if len(b.pending) == 0 {
		return
	}

	// Prepare the batch
	batch := append([]batchEntry{}, b.pending...)

	// Run the current batch
	b.batch <- batch

	// Swap the batch and clear the pending queue so entries are not processed twice
	b.batchTask = b.createBatchTask()
	b.pending = b.pending[:0]
}

// Size returns the length of the pending queue
func (b *batch) Size() int {
	b.RLock()
	defer b.RUnlock()
	return len(b.pending)
}

// createBatchTask starts a task that blocks until Reduce sends it a batch, then
// processes the whole batch and returns the results keyed by entry id.
func (b *batch) createBatchTask() Task {
	return Invoke(b.ctx, func(context.Context) (any, error) {
		// block here until a batch is ordered to be processed
		batch := <-b.batch
		m := map[uint64]any{}

		// prepare the input for the batch reduce call
		input := make([]any, len(batch))
		for i, entry := range batch {
			input[i] = entry.payload
		}

		// process the input
		result := b.process(input)
		for i, res := range result {
			id := batch[i].id
			m[id] = res
		}

		// return the map of associations
		return m, nil
	})
}

--------------------------------------------------------------------------------
/batch_test.go:
--------------------------------------------------------------------------------
// Copyright 2019 Grabtaxi Holdings PTE LTE (GRAB), All rights reserved.
// Use of this source code is governed by an MIT-style license that can be found in the LICENSE file

package async

import (
	"context"
	"fmt"
	"sync"
	"testing"

	"github.com/stretchr/testify/assert"
)

func TestBatch(t *testing.T) {
	const taskCount = 10
	var wg sync.WaitGroup
	wg.Add(taskCount)

	// reducer that multiplies items by 10 at once
	r := NewBatch(context.Background(), func(input []any) []any {
		result := make([]any, len(input))
		for i, number := range input {
			result[i] = number.(int) * 10
		}
		return result
	})

	for i := 0; i < taskCount; i++ {
		number := i
		r.Append(number).ContinueWith(context.TODO(), func(result any, err error) (any, error) {
			assert.Equal(t, number*10, result.(int))
			assert.NoError(t, err)
			wg.Done()
			return nil, nil
		})
	}

	assert.Equal(t, 10, r.Size())

	r.Reduce()
	wg.Wait()
}

func ExampleBatch() {
	var wg sync.WaitGroup
	wg.Add(2)

	r := NewBatch(context.Background(), func(input []any) []any {
		fmt.Println(input)
		return input
	})

	r.Append(1).ContinueWith(context.TODO(), func(result any, err error) (any, error) {
		wg.Done()
		return nil, nil
	})
	r.Append(2).ContinueWith(context.TODO(), func(result any, err error) (any, error) {
		wg.Done()
		return nil, nil
	})
	r.Reduce()
	wg.Wait()

	// Output:
	// [1 2]
}

--------------------------------------------------------------------------------
/consume.go:
--------------------------------------------------------------------------------
// Copyright 2019 Grabtaxi Holdings PTE LTE (GRAB), All rights reserved.
// Use of this source code is governed by an MIT-style license that can be found in the LICENSE file

package async

import (
	"context"
	"runtime"
)

// Consume reads tasks from the channel and runs them with at most concurrency tasks
// in flight; a non-positive concurrency defaults to runtime.NumCPU(). The returned
// task completes once the channel is closed and every started task has finished,
// or fails with the context error when its context is cancelled.
func Consume(ctx context.Context, concurrency int, tasks chan Task) Task {
	if concurrency <= 0 {
		concurrency = runtime.NumCPU()
	}

	return Invoke(ctx, func(taskCtx context.Context) (any, error) {
		workers := make(chan int, concurrency)
		concurrentTasks := make([]Task, concurrency)
		// generate worker IDs
		for id := 0; id < concurrency; id++ {
			workers <- id
		}

		for {
			select {
			// context cancelled
			case <-taskCtx.Done():
				WaitAll(concurrentTasks)
				return nil, taskCtx.Err()

			// worker available
			case workerID := <-workers:
				select {
				// worker is waiting for job when context is cancelled
				case <-taskCtx.Done():
					WaitAll(concurrentTasks)
					return nil, taskCtx.Err()

				case t, ok := <-tasks:
					// if task channel is closed
					if !ok {
						WaitAll(concurrentTasks)
						return nil, nil
					}
					concurrentTasks[workerID] = t
					t.Run(taskCtx).ContinueWith(taskCtx,
						func(any, error) (any, error) {
							workers <- workerID
							return nil, nil
						})
				}
			}
		}
	})
}

--------------------------------------------------------------------------------
/consume_test.go:
--------------------------------------------------------------------------------
// Copyright 2019 Grabtaxi Holdings PTE LTE (GRAB), All rights reserved.
// Use of this source code is governed by an MIT-style license that can be found in the LICENSE file

package async

import (
	"context"
	"testing"
	"time"

	"github.com/stretchr/testify/assert"
)

func TestProcessTaskPool_HappyPath(t *testing.T) {
	tests := []struct {
		desc        string
		taskCount   int
		concurrency int
	}{
		{
			desc:        "10 tasks in channel to be run with default concurrency",
			taskCount:   10,
			concurrency: 0,
		},
		{
			desc:        "10 tasks in channel to be run with 2 workers",
			taskCount:   10,
			concurrency: 2,
		},
		{
			desc:        "10 tasks in channel to be run with 10 workers",
			taskCount:   10,
			concurrency: 10,
		},
		{
			desc:        "10 tasks in channel to be run with 20 workers",
			taskCount:   10,
			concurrency: 20,
		},
	}

	for _, test := range tests {
		m := test
		resChan := make(chan struct{}, m.taskCount)
		taskChan := make(chan Task)

		go func() {
			for i := 0; i < m.taskCount; i++ {
				taskChan <- NewTask(func(context.Context) (any, error) {
					resChan <- struct{}{}
					time.Sleep(time.Millisecond * 10)
					return nil, nil
				})
			}
			close(taskChan)
		}()
		p := Consume(context.Background(), m.concurrency, taskChan)
		_, err := p.Outcome()
		close(resChan)
		assert.Nil(t, err, m.desc)

		var res []struct{}
		for r := range resChan {
			res = append(res, r)
		}
		assert.Len(t, res, m.taskCount, m.desc)
	}
}

// test context cancellation
func TestProcessTaskPool_SadPath(t *testing.T) {
	tests := []struct {
		desc        string
		taskCount   int
		concurrency int
		timeOut     time.Duration // in milliseconds
	}{
		{
			desc:        "2 workers cannot finish 10 tasks in 20 ms where 1 task takes 10 ms. Context cancelled while waiting for available worker",
			taskCount:   10,
			concurrency: 2,
			timeOut:     20,
		},
		{
			desc:        "once 10 tasks are completed, workers will wait for more tasks. Then context will timeout in 20ms",
			taskCount:   10,
			concurrency: 20,
			timeOut:     20,
		},
	}

	for _, test := range tests {
		m := test
		taskChan := make(chan Task)
		ctx, cancel := context.WithTimeout(context.Background(), m.timeOut*time.Millisecond)
		defer cancel()

		go func() {
			for i := 0; i < m.taskCount; i++ {
				taskChan <- NewTask(func(context.Context) (any, error) {
					time.Sleep(time.Millisecond * 10)
					return nil, nil
				})
			}
		}()
		p := Consume(ctx, m.concurrency, taskChan)
		_, err := p.Outcome()
		assert.NotNil(t, err, m.desc)
	}
}

--------------------------------------------------------------------------------
/doc.go:
--------------------------------------------------------------------------------
// Copyright 2019 Grabtaxi Holdings PTE LTE (GRAB), All rights reserved.
// Use of this source code is governed by an MIT-style license that can be found in the LICENSE file

// Package async simplifies the implementation of orchestration patterns for concurrent systems. Currently, it includes:
//
// Asynchronous tasks with cancellations, context propagation and state.
//
// Task chaining by using continuations.
//
// Fork/join pattern - running a bunch of work and waiting for everything to finish.
//
// Throttling pattern - throttling task execution on a specified rate.
//
// Spread pattern - spreading tasks across time.
//
// Partition pattern - partitioning data concurrently.
//
// Repeat pattern - repeating a certain task at a specified interval.
//
// Batch pattern - batching many tasks into a single one with individual continuations.
package async

--------------------------------------------------------------------------------
/fork_join.go:
--------------------------------------------------------------------------------
// Copyright 2019 Grabtaxi Holdings PTE LTE (GRAB), All rights reserved.
// Use of this source code is governed by an MIT-style license that can be found in the LICENSE file

package async

import "context"

// ForkJoin runs the input tasks in parallel and returns a task that completes
// once ALL of them have finished.
func ForkJoin(ctx context.Context, tasks []Task) Task {
	return Invoke(ctx, func(context.Context) (any, error) {
		for _, task := range tasks {
			_ = task.Run(ctx)
		}
		WaitAll(tasks)
		return nil, nil
	})
}

// WaitAll waits for all tasks to finish.
func WaitAll(tasks []Task) {
	for _, task := range tasks {
		if task != nil {
			_, _ = task.Outcome()
		}
	}
}

// CancelAll cancels all specified tasks, skipping nil entries.
func CancelAll(tasks []Task) {
	for _, task := range tasks {
		if task != nil {
			task.Cancel()
		}
	}
}

--------------------------------------------------------------------------------
/fork_join_test.go:
--------------------------------------------------------------------------------
// Copyright 2019 Grabtaxi Holdings PTE LTE (GRAB), All rights reserved.
// Use of this source code is governed by an MIT-style license that can be found in the LICENSE file

package async

import (
	"context"
	"errors"
	"fmt"
	"testing"

	"github.com/stretchr/testify/assert"
)

func TestForkJoin(t *testing.T) {
	first := NewTask(func(context.Context) (any, error) {
		return 1, nil
	})
	second := NewTask(func(context.Context) (any, error) {
		return nil, errors.New("some error")
	})
	third := NewTask(func(context.Context) (any, error) {
		return 3, nil
	})

	ForkJoin(context.Background(), []Task{first, second, third})

	outcome1, error1 := first.Outcome()
	assert.Equal(t, 1, outcome1)
	assert.Nil(t, error1)

	outcome2, error2 := second.Outcome()
	assert.Nil(t, outcome2)
	assert.NotNil(t, error2)

	outcome3, error3 := third.Outcome()
	assert.Equal(t, 3, outcome3)
	assert.Nil(t, error3)
}

func ExampleForkJoin() {
	first := NewTask(func(context.Context) (any, error) {
		return 1, nil
	})

	second := NewTask(func(context.Context) (any, error) {
		return nil, errors.New("some error")
	})

	ForkJoin(context.Background(), []Task{first, second})

	// Outcome returns (result, error), so Println prints both values.
	fmt.Println(first.Outcome())
	fmt.Println(second.Outcome())

	// Output:
	// 1 <nil>
	// <nil> some error
}

--------------------------------------------------------------------------------
/go.mod:
--------------------------------------------------------------------------------
module github.com/kelindar/async

go 1.18

require (
	github.com/stretchr/testify v1.7.0
	golang.org/x/time v0.0.0-20210220033141-f8bda1e9f3ba
)

require (
	github.com/davecgh/go-spew v1.1.0 // indirect
	github.com/pmezard/go-difflib v1.0.0 // indirect
	github.com/stretchr/objx v0.1.0 // indirect
	gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c // indirect
)

--------------------------------------------------------------------------------
/go.sum:
--------------------------------------------------------------------------------
github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/objx v0.1.0 h1:4G4v2dO3VZwixGIRoQ5Lfboy6nUhCyYzaqnIAPPhYs4=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
golang.org/x/time v0.0.0-20210220033141-f8bda1e9f3ba h1:O8mE0/t419eoIwhTFpKVkHiTs/Igowgfkj25AcZrtiE=
golang.org/x/time v0.0.0-20210220033141-f8bda1e9f3ba/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

--------------------------------------------------------------------------------
/invoke_all.go:
--------------------------------------------------------------------------------
// Copyright 2019 Grabtaxi Holdings PTE LTE (GRAB), All rights reserved.
// Use of this source code is governed by an MIT-style license that can be found in the LICENSE file

package async

import "context"

// InvokeAll runs the tasks with at most concurrency of them executing at the same
// time and returns a task that completes once all of them have finished. A
// non-positive concurrency runs every task at once, as ForkJoin does.
func InvokeAll(ctx context.Context, concurrency int, tasks []Task) Task {
	if concurrency <= 0 {
		return ForkJoin(ctx, tasks)
	}

	return Invoke(ctx, func(context.Context) (any, error) {
		sem := make(chan struct{}, concurrency)
		for _, task := range tasks {
			sem <- struct{}{}
			task.Run(ctx).ContinueWith(ctx,
				func(any, error) (any, error) {
					<-sem
					return nil, nil
				})
		}
		WaitAll(tasks)
		return nil, nil
	})
}

--------------------------------------------------------------------------------
/invoke_all_test.go:
--------------------------------------------------------------------------------
// Copyright 2019 Grabtaxi Holdings PTE LTE (GRAB), All rights reserved.
// Use of this source code is governed by an MIT-style license that can be found in the LICENSE file

package async

import (
	"context"
	"fmt"
	"testing"
	"time"

	"github.com/stretchr/testify/assert"
)

func TestInvokeAll(t *testing.T) {
	resChan := make(chan int, 6)
	works := make([]Work, 6)
	for i := range works {
		j := i
		works[j] = func(context.Context) (any, error) {
			resChan <- j / 2
			time.Sleep(time.Millisecond * 10)
			return nil, nil
		}
	}
	tasks := NewTasks(works...)
	InvokeAll(context.Background(), 2, tasks)
	WaitAll(tasks)
	close(resChan)
	res := []int{}
	for r := range resChan {
		res = append(res, r)
	}
	assert.Equal(t, []int{0, 0, 1, 1, 2, 2}, res)
}

func TestInvokeAllWithZeroConcurrency(t *testing.T) {
	resChan := make(chan int, 6)
	works := make([]Work, 6)
	for i := range works {
		j := i
		works[j] = func(context.Context) (any, error) {
			resChan <- 1
			time.Sleep(time.Millisecond * 10)
			return nil, nil
		}
	}
	tasks := NewTasks(works...)
	InvokeAll(context.Background(), 0, tasks)
	WaitAll(tasks)
	close(resChan)
	res := []int{}
	for r := range resChan {
		res = append(res, r)
	}
	assert.Equal(t, []int{1, 1, 1, 1, 1, 1}, res)
}

func ExampleInvokeAll() {
	works := make([]Work, 6)
	for i := range works {
		j := i
		works[j] = func(context.Context) (any, error) {
			fmt.Println(j / 2)
			time.Sleep(time.Millisecond * 10)
			return nil, nil
		}
	}
	tasks := NewTasks(works...)
	InvokeAll(context.Background(), 2, tasks)
	WaitAll(tasks)

	// Output:
	// 0
	// 0
	// 1
	// 1
	// 2
	// 2
}

--------------------------------------------------------------------------------
/mocks/partitioner.go:
--------------------------------------------------------------------------------
// Copyright 2019 Grabtaxi Holdings PTE LTE (GRAB), All rights reserved.
// Use of this source code is governed by an MIT-style license that can be found in the LICENSE file

// Code generated by mockery v1.0.0. DO NOT EDIT.
package mocks

import (
	"github.com/kelindar/async"
	"github.com/stretchr/testify/mock"
)

// Partitioner is an autogenerated mock type for the Partitioner type
type Partitioner struct {
	mock.Mock
}

// Append provides a mock function with given fields: items
func (_m *Partitioner) Append(items any) async.Task {
	ret := _m.Called(items)

	var r0 async.Task
	if rf, ok := ret.Get(0).(func(any) async.Task); ok {
		r0 = rf(items)
	} else {
		if ret.Get(0) != nil {
			r0 = ret.Get(0).(async.Task)
		}
	}

	return r0
}

// Partition provides a mock function with given fields:
func (_m *Partitioner) Partition() map[string][]any {
	ret := _m.Called()

	var r0 map[string][]any
	if rf, ok := ret.Get(0).(func() map[string][]any); ok {
		r0 = rf()
	} else {
		if ret.Get(0) != nil {
			r0 = ret.Get(0).(map[string][]any)
		}
	}

	return r0
}

--------------------------------------------------------------------------------
/partition.go:
--------------------------------------------------------------------------------
// Copyright 2019 Grabtaxi Holdings PTE LTE (GRAB), All rights reserved.
// Use of this source code is governed by an MIT-style license that can be found in the LICENSE file

package async

import (
	"context"
	"reflect"
	"sync"
)

type partitioner struct {
	sync.RWMutex
	ctx       context.Context
	queue     *queue
	partition PartitionFunc // The function that maps each item to its partition key
}

const defaultCapacity = 1 << 14

type partitionedItems map[string][]any

// Partitioner partitions events
type Partitioner interface {
	// Append items to the queue which is pending partition
	Append(items any) Task

	// Partition items and output the result
	Partition() map[string][]any
}

// PartitionFunc takes in data and outputs its partition key.
// If ok is false, the data doesn't fall into any partition and is skipped.
type PartitionFunc func(data any) (key string, ok bool)

// NewPartitioner creates a new partitioner
func NewPartitioner(ctx context.Context, partition PartitionFunc) Partitioner {
	return &partitioner{
		ctx:       ctx,
		queue:     newQueue(),
		partition: partition,
	}
}

// Append asynchronously partitions items, which must be a slice, and adds the
// result to the buffer.
func (p *partitioner) Append(items any) Task {
	return Invoke(p.ctx, func(context.Context) (any, error) {
		p.queue.Append(p.transform(items))
		return nil, nil
	})
}

// transform groups the elements of the items slice by partition key, skipping
// elements that fall into no partition.
func (p *partitioner) transform(items any) partitionedItems {
	t := reflect.TypeOf(items)
	if t.Kind() != reflect.Slice {
		panic("transform requires a slice")
	}

	rv := reflect.ValueOf(items)
	mapped := partitionedItems{}
	for i := 0; i < rv.Len(); i++ {
		e := rv.Index(i).Interface()
		if key, ok := p.partition(e); ok {
			mapped[key] = append(mapped[key], e)
		}
	}
	return mapped
}

// Partition flushes the buffer and merges the buffered items into a single map
// from partition key to items.
func (p *partitioner) Partition() map[string][]any {
	out := partitionedItems{}
	for _, pMap := range p.queue.Flush() {
		for k, v := range pMap {
			out[k] = append(out[k], v...)
		}
	}
	return out
}

// ------------------------------------------------------

// queue represents a batch queue for faster inserts
type queue struct {
	sync.Mutex
	queue []partitionedItems
}

// newQueue creates a new event queue
func newQueue() *queue {
	return &queue{
		queue: make([]partitionedItems, 0, defaultCapacity),
	}
}

// Append appends to the concurrent queue
func (q *queue) Append(events partitionedItems) {
	q.Lock()
	q.queue = append(q.queue, events)
	q.Unlock()
}

// Flush flushes the event queue
func (q *queue) Flush() []partitionedItems {
	q.Lock()
	defer q.Unlock()

	flushed := q.queue
	q.queue = make([]partitionedItems, 0, defaultCapacity)

	return flushed
}

--------------------------------------------------------------------------------
/partition_test.go:
--------------------------------------------------------------------------------
// Copyright 2019 Grabtaxi Holdings PTE LTE (GRAB), All rights reserved.
// Use of this source code is governed by an MIT-style license that can be found in the LICENSE file

package async

import (
	"context"
	"fmt"
	"reflect"
	"testing"

	"github.com/stretchr/testify/assert"
)

func TestPartitioner(t *testing.T) {
	partitionFunc := func(data any) (string, bool) {
		xevent, ok := data.(map[string]string)
		if !ok {
			return "", false
		}
		key, ok := xevent["pre"]
		return key, ok
	}
	p := NewPartitioner(context.Background(), partitionFunc)

	input1 := []any{
		map[string]string{"pre": "a", "val": "val1"},
		map[string]string{"pre": "b", "val": "val2"},
		map[string]string{"pre": "a", "val": "val4"},
		map[string]string{"pre": "c", "val": "val5"},
	}

	input2 := []any{
		map[string]string{"pre": "a", "val": "val3"},
		map[string]string{"pre": "c", "val": "val4"},
	}

	// Both Append calls return tasks that run asynchronously, so either batch
	// may be processed first; both resulting orderings are accepted below.
	expectedRes1 := map[string][]any{
		"a": {
			map[string]string{"pre": "a", "val": "val1"},
			map[string]string{"pre": "a", "val": "val4"},
			map[string]string{"pre": "a", "val": "val3"},
		},
		"b": {
			map[string]string{"pre": "b", "val": "val2"},
		},
		"c": {
			map[string]string{"pre": "c", "val": "val5"},
			map[string]string{"pre": "c", "val": "val4"},
		},
	}

	expectedRes2 := map[string][]any{
		"a": {
			map[string]string{"pre": "a", "val": "val3"},
			map[string]string{"pre": "a", "val": "val1"},
			map[string]string{"pre": "a", "val": "val4"},
		},
		"b": {
			map[string]string{"pre": "b", "val": "val2"},
		},
		"c": {
			map[string]string{"pre": "c", "val": "val4"},
			map[string]string{"pre": "c", "val": "val5"},
		},
	}

	t1 := p.Append(input1)
	t2 := p.Append(input2)
	_, _ = t1.Outcome()
	_, _ = t2.Outcome()

	res := p.Partition()
	assert.True(t, reflect.DeepEqual(expectedRes1, res) || reflect.DeepEqual(expectedRes2, res))
}

func ExamplePartitioner() {
	partitionFunc := func(data any) (string, bool) {
		xevent, ok := data.(map[string]string)
		if !ok {
			return "", false
		}
		key, ok := xevent["pre"]
		return key, ok
	}
	p := NewPartitioner(context.Background(), partitionFunc)

	input := []any{
		map[string]string{"pre": "a", "val": "val1"},
		map[string]string{"pre": "b", "val": "val2"},
		map[string]string{"pre": "a", "val": "val4"},
		map[string]string{"pre": "c", "val": "val5"},
	}
	t := p.Append(input)
	_, _ = t.Outcome()

	res := p.Partition()
	first := res["a"][0].(map[string]string)
	fmt.Println(first["pre"])
	fmt.Println(first["val"])

	// Output:
	// a
	// val1
}

func TestQueue(t *testing.T) {
	q := newQueue()
	input1 := partitionedItems{
		"a": []any{"val1"},
		"b": []any{"val2"},
	}

	input2 := partitionedItems{
		"a": []any{"val4"},
		"c": []any{"val5"},
	}

	expectedRes := []partitionedItems{
		{
			"a": []any{"val1"},
			"b": []any{"val2"},
		}, {
			"a": []any{"val4"},
			"c": []any{"val5"},
		},
	}

	q.Append(input1)
	q.Append(input2)
	assert.Equal(t, expectedRes, q.Flush())
}

func TestQueue_Flush(t *testing.T) {
	q := newQueue()

	// Fill the queue beyond its default capacity
	items := defaultCapacity + 10
	for x := 0; x < items; x++ {
		q.Append(partitionedItems{})
	}
	assert.True(t, defaultCapacity < cap(q.queue))
	assert.True(t, defaultCapacity < len(q.queue))

	// Flush
	flushedItems := q.Flush()

	// Validate: every item is returned and the queue shrinks back to the default capacity
	assert.Equal(t, items, len(flushedItems))
	assert.Equal(t, 0, len(q.queue))
	assert.Equal(t, defaultCapacity, cap(q.queue))
}
--------------------------------------------------------------------------------
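The partition tests above pin down only the observable behaviour: items are grouped under the key returned by the partition function, and items keep their relative order within a group. A rough, single-goroutine sketch of that grouping step follows. The `partition` helper is hypothetical and is not the package's `Partitioner` (which, judging by the tests, also queues batches and processes `Append` calls as tasks); the sketch also assumes that items whose key function reports `false` are simply skipped.

```go
package main

import "fmt"

// partition groups items by the key returned from keyFn, preserving the
// relative order of items within each group. Items for which keyFn reports
// ok == false are skipped (an assumption made for this sketch).
// Hypothetical helper for illustration only; not the package's Partitioner.
func partition(items []any, keyFn func(any) (string, bool)) map[string][]any {
	out := make(map[string][]any)
	for _, item := range items {
		if key, ok := keyFn(item); ok {
			out[key] = append(out[key], item)
		}
	}
	return out
}

func main() {
	// Same key function as in the tests: use the "pre" field as the key.
	keyFn := func(data any) (string, bool) {
		m, ok := data.(map[string]string)
		if !ok {
			return "", false
		}
		key, ok := m["pre"]
		return key, ok
	}

	res := partition([]any{
		map[string]string{"pre": "a", "val": "val1"},
		map[string]string{"pre": "b", "val": "val2"},
		map[string]string{"pre": "a", "val": "val4"},
		"not a map", // skipped: keyFn reports ok == false
	}, keyFn)

	fmt.Println(len(res["a"]), len(res["b"]))             // 2 1
	fmt.Println(res["a"][0].(map[string]string)["val"]) // val1
}
```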
/repeat.go:
--------------------------------------------------------------------------------
// Copyright 2019 Grabtaxi Holdings PTE LTE (GRAB), All rights reserved.
// Use of this source code is governed by an MIT-style license that can be found in the LICENSE file

package async

import (
	"context"
	"time"
)

// Repeat performs an action asynchronously on a predetermined interval.
// Results and errors returned by the action are discarded, and the loop
// stops once the task's context is done.
func Repeat(ctx context.Context, interval time.Duration, action Work) Task {

	// Invoke the task timer
	return Invoke(ctx, func(taskCtx context.Context) (any, error) {
		ticker := time.NewTicker(interval)
		defer ticker.Stop()

		for {
			select {
			case <-taskCtx.Done():
				return nil, nil

			case <-ticker.C:
				_, _ = action(taskCtx)
			}
		}
	})
}
--------------------------------------------------------------------------------
/repeat_test.go:
--------------------------------------------------------------------------------
// Copyright 2019 Grabtaxi Holdings PTE LTE (GRAB), All rights reserved.
// Use of this source code is governed by an MIT-style license that can be found in the LICENSE file

package async

import (
	"context"
	"fmt"
	"testing"
	"time"

	"github.com/stretchr/testify/assert"
)

func TestRepeat(t *testing.T) {
	assert.NotPanics(t, func() {
		out := make(chan bool, 1)
		task := Repeat(context.TODO(), time.Nanosecond*10, func(context.Context) (any, error) {
			out <- true
			return nil, nil
		})

		<-out
		v := <-out
		assert.True(t, v)
		task.Cancel()
	})
}

func ExampleRepeat() {
	out := make(chan bool, 1)
	task := Repeat(context.TODO(), time.Nanosecond*10, func(context.Context) (any, error) {
		out <- true
		return nil, nil
	})

	<-out
	v := <-out
	fmt.Println(v)
	task.Cancel()

	// Output:
	// true
}

/*
func TestRepeatFirstActionPanic(t *testing.T) {
	assert.NotPanics(t, func() {
		task := Repeat(context.TODO(), time.Nanosecond*10, func(context.Context) (any, error) {
			panic("test")
		})

		task.Cancel()
	})
}

func TestRepeatPanic(t *testing.T) {
	assert.NotPanics(t, func() {
		var counter int32
		task := Repeat(context.TODO(), time.Nanosecond*10, func(context.Context) (any, error) {
			atomic.AddInt32(&counter, 1)
			panic("test")
		})

		for atomic.LoadInt32(&counter) <= 10 {
		}

		task.Cancel()
	})
}
*/
--------------------------------------------------------------------------------
/spread.go:
--------------------------------------------------------------------------------
// Copyright 2019 Grabtaxi Holdings PTE LTE (GRAB), All rights reserved.
// Use of this source code is governed by an MIT-style license that can be found in the LICENSE file

package async

import (
	"context"
	"time"

	"golang.org/x/time/rate"
)

// Throttle runs the tasks at a rate of rateLimit tasks per every interval
// (one task start every every/rateLimit), then waits for all of them to
// complete. rateLimit must be greater than zero.
func Throttle(ctx context.Context, tasks []Task, rateLimit int, every time.Duration) Task {
	return Invoke(ctx, func(context.Context) (any, error) {
		limiter := rate.NewLimiter(rate.Every(every/time.Duration(rateLimit)), 1)
		for i, task := range tasks {
			select {
			case <-ctx.Done():
				CancelAll(tasks[i:])
				return nil, errCancelled
			default:
				// If waiting fails (the context is done or its deadline is too close),
				// cancel the current and remaining tasks so nobody waiting on them blocks.
				if err := limiter.Wait(ctx); err != nil {
					CancelAll(tasks[i:])
					return nil, errCancelled
				}
				task.Run(ctx)
			}
		}

		WaitAll(tasks)
		return nil, nil
	})
}

// Spread evenly spreads the work within the specified duration.
func Spread(ctx context.Context, within time.Duration, tasks []Task) Task {
	return Invoke(ctx, func(context.Context) (any, error) {
		if len(tasks) == 0 {
			return nil, nil // Nothing to spread; also avoids a division by zero below
		}

		sleep := within / time.Duration(len(tasks))
		for i, task := range tasks {
			select {
			case <-ctx.Done():
				// Cancel the tasks that were never started so that waiters do not block
				CancelAll(tasks[i:])
				return nil, errCancelled
			default:
				task.Run(ctx)
				time.Sleep(sleep)
			}
		}

		WaitAll(tasks)
		return nil, nil
	})
}
--------------------------------------------------------------------------------
/spread_test.go:
--------------------------------------------------------------------------------
// Copyright 2019 Grabtaxi Holdings PTE LTE (GRAB), All rights reserved.
// Use of this source code is governed by an MIT-style license that can be found in the LICENSE file

package async

import (
	"context"
	"fmt"
	"testing"
	"time"

	"github.com/stretchr/testify/assert"
)

func newTasks() []Task {
	work := func(context.Context) (any, error) {
		return 1, nil
	}

	return NewTasks(work, work, work, work, work)
}

func TestThrottle(t *testing.T) {
	tasks := newTasks()

	// Throttle and calculate the duration. At 3 tasks per 50ms the first task
	// starts immediately and the remaining 4 start ~16.7ms apart, so the whole
	// run should take roughly 67ms.
	t0 := time.Now()
	task := Throttle(context.Background(), tasks, 3, 50*time.Millisecond)
	_, _ = task.Outcome() // Wait

	// Make sure we completed within duration
	dt := int(time.Since(t0).Seconds() * 1000)
	assert.True(t, dt > 50 && dt < 100, fmt.Sprintf("%v ms.", dt))
}

func TestThrottle_Cancel(t *testing.T) {
	tasks := newTasks()

	ctx, cancel := context.WithCancel(context.Background())
	cancel()

	// Throttle with an already-cancelled context: every task should end up cancelled
	Throttle(ctx, tasks, 3, 50*time.Millisecond)
	WaitAll(tasks)
	cancelled := 0
	for _, task := range tasks {
		if task.State() == IsCancelled {
			cancelled++
		}
	}

	assert.Equal(t, 5, cancelled)
}

func TestSpread(t *testing.T) {
	tasks := newTasks()
	within := 200 * time.Millisecond

	// Spread and calculate the duration
	t0 := time.Now()
	task := Spread(context.Background(), within, tasks)
	_, _ = task.Outcome() // Wait

	// Make sure we completed within duration
	dt := int(time.Since(t0).Seconds() * 1000)
	assert.True(t, dt > 150 && dt < 250, fmt.Sprintf("%v ms.", dt))

	// Make sure all tasks are done
	for _, task := range tasks {
		v, _ := task.Outcome()
		assert.Equal(t, 1, v.(int))
	}
}

func ExampleSpread() {
	tasks := newTasks()
	within := 200 * time.Millisecond

	// Spread
	task := Spread(context.Background(), within, tasks)
	_, _ = task.Outcome() // Wait

	// Make sure all tasks are done
	for _, task := range tasks {
		v, _ := task.Outcome()
		fmt.Println(v)
	}

	// Output:
	// 1
	// 1
	// 1
	// 1
	// 1
}
--------------------------------------------------------------------------------
/task.go:
--------------------------------------------------------------------------------
// Copyright 2019 Grabtaxi Holdings PTE LTE (GRAB), All rights reserved.
// Copyright 2021 Roman Atachiants
// Use of this source code is governed by an MIT-style license that can be found in the LICENSE file

package async

import (
	"context"
	"errors"
	"fmt"
	"runtime/debug"
	"sync/atomic"
	"time"
)

var (
	// errCancelled is the outcome error of a task that was cancelled, either
	// directly via Cancel or by a pattern such as Throttle or Spread.
	errCancelled = errors.New("context canceled")

	// ErrPanic is wrapped by the outcome error of a task whose work panicked.
	ErrPanic = errors.New("panic in async task")
)

// now is the clock used to measure task durations.
var now = time.Now

// Work represents a handler to execute
type Work func(context.Context) (any, error)

// State represents the state enumeration for a task.
type State byte

// Various task states
const (
	IsCreated   State = iota // IsCreated represents a newly created task
	IsRunning                // IsRunning represents a task which is currently running
	IsCompleted              // IsCompleted represents a task which was completed successfully or errored out
	IsCancelled              // IsCancelled represents a task which was cancelled or has timed out
)

// signal is used as a broadcast: closing it wakes up every receiver.
type signal chan struct{}

// outcome of the task contains a result and an error
type outcome struct {
	result any   // The result of the work
	err    error // The error
}

// task is the default implementation of Task.
type task struct {
	state     int32         // The current State of the task, accessed atomically
	cancelled int32         // Set to 1 by the first Cancel call that closes the cancel channel
	cancel    signal        // The cancellation channel
	done      signal        // The outcome channel, closed once the task is finished
	action    Work          // The work to do
	outcome   outcome       // This is used to store the result
	duration  time.Duration // The duration of the task
}

// Task represents a unit of work to be done
type Task interface {
	Run(ctx context.Context) Task
	Cancel()
	State() State
	Outcome() (any, error)
	ContinueWith(ctx context.Context, nextAction func(any, error) (any, error)) Task
	Duration() time.Duration
}

// NewTask creates a new task.
func NewTask(action Work) Task {
	return &task{
		action: action,
		done:   make(signal, 1),
		cancel: make(signal, 1),
	}
}

// NewTasks creates a set of new tasks.
func NewTasks(actions ...Work) []Task {
	tasks := make([]Task, 0, len(actions))
	for _, action := range actions {
		tasks = append(tasks, NewTask(action))
	}
	return tasks
}

// Invoke creates a new task and runs it asynchronously.
func Invoke(ctx context.Context, action Work) Task {
	return NewTask(action).Run(ctx)
}

// Outcome waits until the task is done and returns the final result and error.
func (t *task) Outcome() (any, error) {
	<-t.done
	return t.outcome.result, t.outcome.err
}

// State returns the current state of the task. This operation is non-blocking.
func (t *task) State() State {
	v := atomic.LoadInt32(&t.state)
	return State(v)
}

// Duration returns how long the task ran. It is only meaningful once the
// task has finished, for example after Outcome has returned.
func (t *task) Duration() time.Duration {
	return t.duration
}

// Run starts the task asynchronously.
func (t *task) Run(ctx context.Context) Task {
	go t.run(ctx)
	return t
}

// Cancel cancels the task. A task that has not started yet is cancelled
// immediately; a running task is signalled through its cancel channel.
// Cancel is safe to call more than once, including from multiple goroutines.
func (t *task) Cancel() {

	// If the task was created but never started, transition directly to cancelled state
	// and close the done channel and set the error.
	if t.changeState(IsCreated, IsCancelled) {
		t.outcome = outcome{err: errCancelled}
		close(t.done)
		return
	}

	// Attempt to cancel the task if it's in the running state. The atomic flag
	// guarantees the cancel channel is closed at most once, even when Cancel is
	// called concurrently (a check-then-close on the channel alone is racy).
	if t.cancel != nil && atomic.CompareAndSwapInt32(&t.cancelled, 0, 1) {
		close(t.cancel)
	}
}

// run executes the task on the calling goroutine and blocks until it finishes.
func (t *task) run(ctx context.Context) {
	if !t.changeState(IsCreated, IsRunning) {
		return // Prevent from running the same task twice
	}

	// Derive a context that is cancelled as soon as run returns, so that work
	// honouring ctx.Done() (such as the Repeat loop) also stops when the task
	// is cancelled via Cancel instead of leaking in the background.
	ctx, cancelWork := context.WithCancel(ctx)
	defer cancelWork()

	// Notify everyone of the completion/error state
	defer close(t.done)

	// Execute the task
	startedAt := now().UnixNano()
	outcomeCh := make(chan outcome, 1)
	go func() {
		defer func() {
			if out := recover(); out != nil {
				outcomeCh <- outcome{err: fmt.Errorf("%w: %v\n%s",
					ErrPanic, out, debug.Stack())}
			}
		}()

		r, e := t.action(ctx)
		outcomeCh <- outcome{result: r, err: e}
	}()

	select {

	// In case of a manual task cancellation, set the outcome and transition
	// to the cancelled state.
	case <-t.cancel:
		t.duration = time.Nanosecond * time.Duration(now().UnixNano()-startedAt)
		t.outcome = outcome{err: errCancelled}
		t.changeState(IsRunning, IsCancelled)
		return

	// In case of the context timeout or other error, change the state of the
	// task to cancelled and return right away.
	case <-ctx.Done():
		t.duration = time.Nanosecond * time.Duration(now().UnixNano()-startedAt)
		t.outcome = outcome{err: ctx.Err()}
		t.changeState(IsRunning, IsCancelled)
		return

	// In case where we got an outcome (happy path)
	case o := <-outcomeCh:
		t.duration = time.Nanosecond * time.Duration(now().UnixNano()-startedAt)
		t.outcome = o
		t.changeState(IsRunning, IsCompleted)
		return
	}
}

// ContinueWith proceeds with the next task once the current one is finished.
func (t *task) ContinueWith(ctx context.Context, nextAction func(any, error) (any, error)) Task {
	return Invoke(ctx, func(context.Context) (any, error) {
		result, err := t.Outcome()
		return nextAction(result, err)
	})
}

// changeState atomically moves the task from one state to another and
// reports whether the transition took place.
func (t *task) changeState(from, to State) bool {
	return atomic.CompareAndSwapInt32(&t.state, int32(from), int32(to))
}

// -------------------------------- No-Op Task --------------------------------

// Completed returns a task that has already completed with a nil result and no error.
func Completed() Task {
	t := &task{
		state:   int32(IsCompleted),
		done:    make(signal, 1),
		cancel:  make(signal, 1),
		outcome: outcome{},
	}
	close(t.done)
	return t
}
--------------------------------------------------------------------------------
/task_test.go:
--------------------------------------------------------------------------------
// Copyright 2019 Grabtaxi Holdings PTE LTE (GRAB), All rights reserved.
// Use of this source code is governed by an MIT-style license that can be found in the LICENSE file

package async

import (
	"context"
	"errors"
	"sync"
	"testing"
	"time"

	"github.com/stretchr/testify/assert"
)

func TestNewTasks(t *testing.T) {
	work := func(context.Context) (any, error) {
		return 1, nil
	}

	tasks := NewTasks(work, work, work)
	assert.Equal(t, 3, len(tasks))
}

func TestOutcome(t *testing.T) {
	task := Invoke(context.Background(), func(context.Context) (any, error) {
		return 1, nil
	})

	var wg sync.WaitGroup
	wg.Add(100)
	for i := 0; i < 100; i++ {
		go func() {
			// Signal completion only after asserting, so no assertion can run
			// after the test function has returned.
			defer wg.Done()
			o, _ := task.Outcome()
			assert.Equal(t, 1, o)
		}()
	}
	wg.Wait()
}

func TestOutcomeTimeout(t *testing.T) {
	ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
	defer cancel()

	task := Invoke(ctx, func(context.Context) (any, error) {
		time.Sleep(500 * time.Millisecond)
		return 1, nil
	})

	_, err := task.Outcome()
	assert.Equal(t, "context deadline exceeded", err.Error())
}

func TestContinueWithChain(t *testing.T) {
	task1 := Invoke(context.Background(), func(context.Context) (any, error) {
		return 1, nil
	})

	ctx := context.TODO()
	task2 := task1.ContinueWith(ctx, func(result any, _ error) (any, error) {
		return result.(int) + 1, nil
	})

	task3 := task2.ContinueWith(ctx, func(result any, _ error) (any, error) {
		return result.(int) + 1, nil
	})

	result, err := task3.Outcome()
	assert.Equal(t, 3, result)
	assert.Nil(t, err)
}

func TestContinueTimeout(t *testing.T) {
	ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
	defer cancel()

	first := Invoke(ctx, func(context.Context) (any, error) {
		return 5, nil
	})

	second := first.ContinueWith(ctx, func(result any, err error) (any, error) {
		time.Sleep(500 * time.Millisecond)
		return result, err
	})

	r1, err1 := first.Outcome()
	assert.Equal(t, 5, r1)
	assert.Nil(t, err1)

	_, err2 := second.Outcome()
	assert.Equal(t, "context deadline exceeded", err2.Error())
}

func TestTaskCancelStarted(t *testing.T) {
	task := Invoke(context.Background(), func(context.Context) (any, error) {
		time.Sleep(500 * time.Millisecond)
		return 1, nil
	})

	task.Cancel()

	_, err := task.Outcome()
	assert.Equal(t, errCancelled, err)
}

func TestTaskCancelRunning(t *testing.T) {
	task := Invoke(context.Background(), func(context.Context) (any, error) {
		time.Sleep(500 * time.Millisecond)
		return 1, nil
	})

	time.Sleep(10 * time.Millisecond)

	task.Cancel()

	_, err := task.Outcome()
	assert.Equal(t, errCancelled, err)
}

func TestTaskCancelTwice(t *testing.T) {
	task := Invoke(context.Background(), func(context.Context) (any, error) {
		time.Sleep(500 * time.Millisecond)
		return 1, nil
	})

	assert.NotPanics(t, func() {
		for i := 0; i < 100; i++ {
			task.Cancel()
		}
	})

	_, err := task.Outcome()
	assert.Equal(t, errCancelled, err)
}

func TestCompleted(t *testing.T) {
	task := Completed()
	assert.Equal(t, IsCompleted, task.State())
	v, err := task.Outcome()
	assert.Nil(t, err)
	assert.Nil(t, v)
}

func TestPanic(t *testing.T) {
	assert.NotPanics(t, func() {
		_, err := Invoke(context.Background(), func(context.Context) (any, error) {
			panic("test")
		}).Outcome()

		assert.Error(t, err)
		assert.True(t, errors.Is(err, ErrPanic))
	})
}
--------------------------------------------------------------------------------