├── .gitignore
├── LICENSE
├── README.md
├── batch.go
├── batch_test.go
├── consume.go
├── consume_test.go
├── doc.go
├── fork_join.go
├── fork_join_test.go
├── go.mod
├── go.sum
├── invoke_all.go
├── invoke_all_test.go
├── mocks
│   └── partitioner.go
├── partition.go
├── partition_test.go
├── repeat.go
├── repeat_test.go
├── spread.go
├── spread_test.go
├── task.go
└── task_test.go

/.gitignore:
--------------------------------------------------------------------------------
vendor
.idea/

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
# MIT License

Copyright (c) 2019 Grabtaxi Holdings PTE LTE (GRAB)

Permission is hereby granted, free of charge, to any person obtaining a copy of
this software and associated documentation files (the "Software"), to deal in
the Software without restriction, including without limitation the rights to
use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
the Software, and to permit persons to whom the Software is furnished to do so,
subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Async

## What is package async
Package async simplifies the implementation of orchestration patterns for concurrent systems. It is similar to Java Future or JS Promise, which make life much easier when dealing with asynchronous operations and concurrent processing. Golang is excellent in terms of parallel programming. However, dealing with goroutines and channels can become a big headache when business logic gets complicated. Wrapping them into higher-level functions makes the code much more readable and easier for developers to reason about.

Currently, this package includes:

* Asynchronous tasks with cancellations, context propagation and state.
* Task chaining by using continuations.
* Fork/join pattern - running a bunch of work and waiting for everything to finish.
* Throttling pattern - throttling task execution at a specified rate.
* Spread pattern - spreading tasks across time.
* Partition pattern - partitioning data concurrently.
* Repeat pattern - repeating a certain task at a specified interval.
* Batch pattern - batching many tasks into a single one with individual continuations.

## Concept
**Task** is a basic concept like Future in Java. You can create a Task with an executable function which takes in a context and returns a result and an error. Creating a task does not start it: call task.Run(ctx) to execute it, or use Invoke(ctx, fn), which creates and runs a task in one step.
```
task := NewTask(func(context.Context) (interface{}, error) {
	// run the job
	return res, err
})
task.Run(context.Background()) // start the job asynchronously
```
#### Get the result
The function is evaluated asynchronously. You can check whether it has completed by calling task.State(), which is non-blocking. Alternatively, you can wait for the result with task.Outcome(), which blocks until the job is done.
These two functions are similar to Java's Future.isDone() and Future.get(), respectively.

#### Cancelling
There may be cases where we no longer care about the result some time after execution has started. In that case, the task can be aborted by invoking task.Cancel().

#### Chaining
To run a follow-up action after the task, simply call ContinueWith(). This is useful for building a chain of processing steps, or for running a teardown step after the job.

## Examples
For example, suppose we want to upload numerous files efficiently. There are multiple strategies we can take, given a file uploading function like:
```
func upload(file string) Work {
	return func(ctx context.Context) (interface{}, error) {
		// do file uploading
		return res, err
	}
}
```

#### Fork join
The main characteristic of a fork/join task is that it spawns new subtasks which run concurrently. They can be different parts of the main task that run independently. The following code example illustrates how you can send files to S3 concurrently with a few lines of code.

```
func uploadFilesConcurrently(files []string) {
	tasks := []Task{}
	for _, file := range files {
		tasks = append(tasks, NewTask(upload(file)))
	}
	// ForkJoin returns a Task; Outcome blocks until every upload has finished
	_, _ = ForkJoin(context.Background(), tasks).Outcome()
}
```

#### Invoke All
Fork/join may not suit every case, for example when the number of tasks grows very large. The number of concurrently running tasks and goroutines, and the CPU utilisation, could then overwhelm the node. One solution is to constrain the maximum concurrency. InvokeAll is introduced for this purpose: it works like a fixed-size goroutine pool that attempts to serve the given tasks in the shortest time. A concurrency of 0 falls back to ForkJoin.
```
InvokeAll(context.Background(), concurrency, tasks)
```

#### Spread
Sometimes we don't really care about concurrency, but just want to make sure the tasks finish within a certain time period.
The Spread function is useful in this case: it spreads the tasks evenly across the given period.
```
Spread(context.Background(), period, tasks)
```
For example, if we want to send 50 files within 10 seconds, Spread would start a task every 0.2 seconds (10 s / 50 tasks). This assumes that every task takes a similar amount of time. A more sophisticated approach might use an adaptive model that derives each task's duration from its characteristics or parameters.

--------------------------------------------------------------------------------
/batch.go:
--------------------------------------------------------------------------------
// Copyright 2019 Grabtaxi Holdings PTE LTE (GRAB), All rights reserved.
// Use of this source code is governed by an MIT-style license that can be found in the LICENSE file

package async

import (
	"context"
	"errors"
	"fmt"
	"sync"
)

type batchEntry struct {
	id      uint64
	payload interface{} // Will be used as input when the batch is processed
	task    Task        // The callback will be called when this entry is processed
}

type batch struct {
	sync.RWMutex
	ctx       context.Context
	lastID    uint64                            // The last id for result matching
	pending   []batchEntry                      // The pending entries to the batch
	batchTask Task                              // The current batch task
	batch     chan []batchEntry                 // The current batch channel to execute
	process   func([]interface{}) []interface{} // The function which will be executed to process the items of the batch
}

// Batch represents a batch where one can append to the batch and process it as a whole.
type Batch interface {
	Append(payload interface{}) Task
	Size() int
	Reduce()
}

// NewBatch creates a new batch
func NewBatch(ctx context.Context, process func([]interface{}) []interface{}) Batch {
	return &batch{
		ctx:     ctx,
		pending: []batchEntry{},
		batch:   make(chan []batchEntry),
		process: process,
	}
}

// Append adds a new payload to the batch and returns the task for that particular
// payload. You should listen for the outcome, as the task will be executed by the reducer.
func (b *batch) Append(payload interface{}) Task {
	b.Lock()
	defer b.Unlock()

	b.lastID = b.lastID + 1
	id := b.lastID

	// Make sure we have a batch task
	if b.batchTask == nil {
		b.batchTask = b.createBatchTask()
	}

	// Batch task will need to continue with this one
	t := b.batchTask.ContinueWith(b.ctx, func(batchResult interface{}, _ error) (interface{}, error) {
		if res, ok := batchResult.(map[uint64]interface{}); ok {
			return res[id], nil
		}

		actualType := fmt.Sprintf("%T", batchResult)
		return nil, errors.New("Invalid batch type, got: " + actualType)
	})

	// Add to the task queue
	b.pending = append(b.pending, batchEntry{
		id:      id,
		payload: payload,
		task:    t,
	})

	// Return the task we created
	return t
}

// Reduce sends all pending entries to the current batch task for processing
func (b *batch) Reduce() {
	b.Lock()
	defer b.Unlock()

	// Skip if the queue is empty
	if len(b.pending) == 0 {
		return
	}

	// Prepare the batch
	batch := append([]batchEntry{}, b.pending...)

	// Run the current batch
	b.batch <- batch

	// Swap the batch and clear the pending queue; the entries sent above are a copy,
	// so they are not processed again by the next Reduce call
	b.batchTask = b.createBatchTask()
	b.pending = b.pending[:0]
}

// Size returns the length of the pending queue
func (b *batch) Size() int {
	b.RLock()
	defer b.RUnlock()
	return len(b.pending)
}

// createBatchTask creates a task for the batch. Triggering this task will trigger the whole batch.
func (b *batch) createBatchTask() Task {
	return Invoke(b.ctx, func(context.Context) (interface{}, error) {
		// block here until a batch is ordered to be processed
		batch := <-b.batch
		m := map[uint64]interface{}{}

		// prepare the input for the batch reduce call
		input := make([]interface{}, len(batch))
		for i, entry := range batch {
			input[i] = entry.payload
		}

		// process the input; results are matched to entries by position
		result := b.process(input)
		for i, res := range result {
			id := batch[i].id
			m[id] = res
		}

		// return the map of associations
		return m, nil
	})
}

--------------------------------------------------------------------------------
/batch_test.go:
--------------------------------------------------------------------------------
// Copyright 2019 Grabtaxi Holdings PTE LTE (GRAB), All rights reserved.
// Use of this source code is governed by an MIT-style license that can be found in the LICENSE file

package async

import (
	"context"
	"fmt"
	"sync"
	"testing"

	"github.com/stretchr/testify/assert"
)

func TestBatch(t *testing.T) {
	const taskCount = 10
	var wg sync.WaitGroup
	wg.Add(taskCount)

	// reducer that multiplies items by 10 at once
	r := NewBatch(context.Background(), func(input []interface{}) []interface{} {
		result := make([]interface{}, len(input))
		for i, number := range input {
			result[i] = number.(int) * 10
		}
		return result
	})

	for i := 0; i < taskCount; i++ {
		number := i
		r.Append(number).ContinueWith(context.TODO(), func(result interface{}, err error) (interface{}, error) {
			assert.Equal(t, number*10, result.(int))
			assert.NoError(t, err)
			wg.Done()
			return nil, nil
		})
	}

	assert.Equal(t, taskCount, r.Size())

	r.Reduce()
	wg.Wait()
}

func ExampleBatch() {
	var wg sync.WaitGroup
	wg.Add(2)

	r := NewBatch(context.Background(), func(input []interface{}) []interface{} {
		fmt.Println(input)
		return input
	})

	r.Append(1).ContinueWith(context.TODO(), func(result interface{}, err error) (interface{}, error) {
		wg.Done()
		return nil, nil
	})
	r.Append(2).ContinueWith(context.TODO(), func(result interface{}, err error) (interface{}, error) {
		wg.Done()
		return nil, nil
	})
	r.Reduce()
	wg.Wait()

	// Output:
	// [1 2]
}

--------------------------------------------------------------------------------
/consume.go:
--------------------------------------------------------------------------------
// Copyright 2019 Grabtaxi Holdings PTE LTE (GRAB), All rights reserved.
// Use of this source code is governed by an MIT-style license that can be found in the LICENSE file

package async

import (
	"context"
	"runtime"
)

// Consume runs the tasks received from the channel with the given maximum concurrency.
// If concurrency is not positive, runtime.NumCPU() workers are used. The returned task
// completes when the channel is closed and the started tasks have finished, or with the
// context's error if the context is cancelled first.
func Consume(ctx context.Context, concurrency int, tasks chan Task) Task {
	if concurrency <= 0 {
		concurrency = runtime.NumCPU()
	}

	return Invoke(ctx, func(taskCtx context.Context) (interface{}, error) {
		workers := make(chan int, concurrency)
		concurrentTasks := make([]Task, concurrency)
		// generate worker IDs
		for id := 0; id < concurrency; id++ {
			workers <- id
		}

		for {
			select {
			// context cancelled
			case <-taskCtx.Done():
				WaitAll(concurrentTasks)
				return nil, taskCtx.Err()

			// worker available
			case workerID := <-workers:
				select {
				// worker is waiting for a job when the context is cancelled
				case <-taskCtx.Done():
					WaitAll(concurrentTasks)
					return nil, taskCtx.Err()

				case t, ok := <-tasks:
					// task channel is closed
					if !ok {
						WaitAll(concurrentTasks)
						return nil, nil
					}
					concurrentTasks[workerID] = t
					// return the worker ID to the pool once the task has finished
					t.Run(taskCtx).ContinueWith(taskCtx,
						func(interface{}, error) (interface{}, error) {
							workers <- workerID
							return nil, nil
						})
				}
			}
		}
	})
}

--------------------------------------------------------------------------------
/consume_test.go:
--------------------------------------------------------------------------------
// Copyright 2019 Grabtaxi Holdings PTE LTE (GRAB), All rights reserved.
// Use of this source code is governed by an MIT-style license that can be found in the LICENSE file

package async

import (
	"context"
	"testing"
	"time"

	"github.com/stretchr/testify/assert"
)

func TestProcessTaskPool_HappyPath(t *testing.T) {
	tests := []struct {
		desc        string
		taskCount   int
		concurrency int
	}{
		{
			desc:        "10 tasks in channel to be run with default concurrency",
			taskCount:   10,
			concurrency: 0,
		},
		{
			desc:        "10 tasks in channel to be run with 2 workers",
			taskCount:   10,
			concurrency: 2,
		},
		{
			desc:        "10 tasks in channel to be run with 10 workers",
			taskCount:   10,
			concurrency: 10,
		},
		{
			desc:        "10 tasks in channel to be run with 20 workers",
			taskCount:   10,
			concurrency: 20,
		},
	}

	for _, test := range tests {
		m := test
		resChan := make(chan struct{}, m.taskCount)
		taskChan := make(chan Task)

		go func() {
			for i := 0; i < m.taskCount; i++ {
				taskChan <- NewTask(func(context.Context) (interface{}, error) {
					resChan <- struct{}{}
					time.Sleep(time.Millisecond * 10)
					return nil, nil
				})
			}
			close(taskChan)
		}()
		p := Consume(context.Background(), m.concurrency, taskChan)
		_, err := p.Outcome()
		close(resChan)
		assert.Nil(t, err, m.desc)

		var res []struct{}
		for r := range resChan {
			res = append(res, r)
		}
		assert.Len(t, res, m.taskCount, m.desc)
	}
}

// test context cancellation
func TestProcessTaskPool_SadPath(t *testing.T) {
	tests := []struct {
		desc        string
		taskCount   int
		concurrency int
		timeOut     time.Duration // in milliseconds
	}{
		{
			desc:        "2 workers cannot finish 10 tasks in 20 ms where 1 task takes 10 ms. Context cancelled while waiting for available worker",
			taskCount:   10,
			concurrency: 2,
			timeOut:     20,
		},
		{
			desc:        "once 10 tasks are completed, workers will wait for more tasks. Then context will time out in 20ms",
			taskCount:   10,
			concurrency: 20,
			timeOut:     20,
		},
	}

	for _, test := range tests {
		m := test
		taskChan := make(chan Task)
		ctx, cancel := context.WithTimeout(context.Background(), m.timeOut*time.Millisecond)
		go func() {
			for i := 0; i < m.taskCount; i++ {
				taskChan <- NewTask(func(context.Context) (interface{}, error) {
					time.Sleep(time.Millisecond * 10)
					return nil, nil
				})
			}
		}()
		p := Consume(ctx, m.concurrency, taskChan)
		_, err := p.Outcome()
		cancel() // release the resources held by the timeout context
		assert.NotNil(t, err, m.desc)
	}
}

--------------------------------------------------------------------------------
/doc.go:
--------------------------------------------------------------------------------
// Copyright 2019 Grabtaxi Holdings PTE LTE (GRAB), All rights reserved.
// Use of this source code is governed by an MIT-style license that can be found in the LICENSE file

// Package async simplifies the implementation of orchestration patterns for concurrent systems. Currently, it includes:
//
// Asynchronous tasks with cancellations, context propagation and state.
//
// Task chaining by using continuations.
//
// Fork/join pattern - running a bunch of work and waiting for everything to finish.
//
// Throttling pattern - throttling task execution at a specified rate.
//
// Spread pattern - spreading tasks across time.
//
// Partition pattern - partitioning data concurrently.
//
// Repeat pattern - repeating a certain task at a specified interval.
//
// Batch pattern - batching many tasks into a single one with individual continuations.
package async

--------------------------------------------------------------------------------
/fork_join.go:
--------------------------------------------------------------------------------
// Copyright 2019 Grabtaxi Holdings PTE LTE (GRAB), All rights reserved.
// Use of this source code is governed by an MIT-style license that can be found in the LICENSE file

package async

import "context"

// ForkJoin runs the input tasks in parallel and returns a task that completes
// only after ALL of them have finished.
func ForkJoin(ctx context.Context, tasks []Task) Task {
	return Invoke(ctx, func(context.Context) (interface{}, error) {
		for _, task := range tasks {
			_ = task.Run(ctx)
		}
		WaitAll(tasks)
		return nil, nil
	})
}

// WaitAll waits for all tasks to finish, skipping nil entries.
func WaitAll(tasks []Task) {
	for _, task := range tasks {
		if task != nil {
			_, _ = task.Outcome()
		}
	}
}

// CancelAll cancels all specified tasks, skipping nil entries.
func CancelAll(tasks []Task) {
	for _, task := range tasks {
		if task != nil {
			task.Cancel()
		}
	}
}

--------------------------------------------------------------------------------
/fork_join_test.go:
--------------------------------------------------------------------------------
// Copyright 2019 Grabtaxi Holdings PTE LTE (GRAB), All rights reserved.
// Use of this source code is governed by an MIT-style license that can be found in the LICENSE file

package async

import (
	"context"
	"errors"
	"fmt"
	"testing"

	"github.com/stretchr/testify/assert"
)

func TestForkJoin(t *testing.T) {
	first := NewTask(func(context.Context) (interface{}, error) {
		return 1, nil
	})
	second := NewTask(func(context.Context) (interface{}, error) {
		return nil, errors.New("some error")
	})
	third := NewTask(func(context.Context) (interface{}, error) {
		return 3, nil
	})

	ForkJoin(context.Background(), []Task{first, second, third})

	outcome1, error1 := first.Outcome()
	assert.Equal(t, 1, outcome1)
	assert.Nil(t, error1)

	outcome2, error2 := second.Outcome()
	assert.Nil(t, outcome2)
	assert.NotNil(t, error2)

	outcome3, error3 := third.Outcome()
	assert.Equal(t, 3, outcome3)
	assert.Nil(t, error3)
}

func ExampleForkJoin() {
	first := NewTask(func(context.Context) (interface{}, error) {
		return 1, nil
	})

	second := NewTask(func(context.Context) (interface{}, error) {
		return nil, errors.New("some error")
	})

	ForkJoin(context.Background(), []Task{first, second})

	// Outcome returns (result, error), so each line prints both values
	fmt.Println(first.Outcome())
	fmt.Println(second.Outcome())

	// Output:
	// 1 <nil>
	// <nil> some error
}

--------------------------------------------------------------------------------
/go.mod:
--------------------------------------------------------------------------------
module github.com/grab/async

go 1.12

require (
	github.com/stretchr/testify v1.3.0
	golang.org/x/time v0.0.0-20191024005414-555d28b269f0
)

--------------------------------------------------------------------------------
/go.sum:
--------------------------------------------------------------------------------
github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/objx v0.1.0 h1:4G4v2dO3VZwixGIRoQ5Lfboy6nUhCyYzaqnIAPPhYs4=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0Q=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
golang.org/x/time v0.0.0-20190308202827-9d24e82272b4 h1:SvFZT6jyqRaOeXpc5h/JSfZenJ2O330aBsf7JfSUXmQ=
golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.0.0-20191024005414-555d28b269f0 h1:/5xXl8Y5W96D+TtHSlonuFqGHIWVuyCkGJLwGh9JJFs=
golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=

--------------------------------------------------------------------------------
/invoke_all.go:
--------------------------------------------------------------------------------
// Copyright 2019 Grabtaxi Holdings PTE LTE (GRAB), All rights reserved.
// Use of this source code is governed by an MIT-style license that can be found in the LICENSE file

package async

import "context"

// InvokeAll runs the tasks with the given maximum concurrency. A non-positive
// concurrency falls back to ForkJoin, which runs all tasks at once.
func InvokeAll(ctx context.Context, concurrency int, tasks []Task) Task {
	if concurrency <= 0 {
		return ForkJoin(ctx, tasks)
	}

	return Invoke(ctx, func(context.Context) (interface{}, error) {
		// sem is a counting semaphore: a slot is taken before a task starts
		// and released by the task's continuation once it has finished
		sem := make(chan struct{}, concurrency)
		for _, task := range tasks {
			sem <- struct{}{}
			task.Run(ctx).ContinueWith(ctx,
				func(interface{}, error) (interface{}, error) {
					<-sem
					return nil, nil
				})
		}
		WaitAll(tasks)
		return nil, nil
	})
}

--------------------------------------------------------------------------------
/invoke_all_test.go:
--------------------------------------------------------------------------------
// Copyright 2019 Grabtaxi Holdings PTE LTE (GRAB), All rights reserved.
// Use of this source code is governed by an MIT-style license that can be found in the LICENSE file

package async

import (
	"context"
	"fmt"
	"testing"
	"time"

	"github.com/stretchr/testify/assert"
)

func TestInvokeAll(t *testing.T) {
	resChan := make(chan int, 6)
	works := make([]Work, 6, 6)
	for i := range works {
		j := i
		works[j] = func(context.Context) (interface{}, error) {
			resChan <- j / 2
			time.Sleep(time.Millisecond * 10)
			return nil, nil
		}
	}
	tasks := NewTasks(works...)
	InvokeAll(context.Background(), 2, tasks)
	WaitAll(tasks)
	close(resChan)
	res := []int{}
	for r := range resChan {
		res = append(res, r)
	}
	assert.Equal(t, []int{0, 0, 1, 1, 2, 2}, res)
}

func TestInvokeAllWithZeroConcurrency(t *testing.T) {
	resChan := make(chan int, 6)
	works := make([]Work, 6, 6)
	for i := range works {
		j := i
		works[j] = func(context.Context) (interface{}, error) {
			resChan <- 1
			time.Sleep(time.Millisecond * 10)
			return nil, nil
		}
	}
	tasks := NewTasks(works...)
	InvokeAll(context.Background(), 0, tasks)
	WaitAll(tasks)
	close(resChan)
	res := []int{}
	for r := range resChan {
		res = append(res, r)
	}
	assert.Equal(t, []int{1, 1, 1, 1, 1, 1}, res)
}

func ExampleInvokeAll() {
	works := make([]Work, 6)
	for i := range works {
		j := i
		works[j] = func(context.Context) (interface{}, error) {
			fmt.Println(j / 2)
			time.Sleep(time.Millisecond * 10)
			return nil, nil
		}
	}
	tasks := NewTasks(works...)
	InvokeAll(context.Background(), 2, tasks)
	WaitAll(tasks)

	// Output:
	// 0
	// 0
	// 1
	// 1
	// 2
	// 2
}

--------------------------------------------------------------------------------
/mocks/partitioner.go:
--------------------------------------------------------------------------------
// Copyright 2019 Grabtaxi Holdings PTE LTE (GRAB), All rights reserved.
// Use of this source code is governed by an MIT-style license that can be found in the LICENSE file

// Code generated by mockery v1.0.0. DO NOT EDIT.
package mocks

import (
	"github.com/grab/async"
	"github.com/stretchr/testify/mock"
)

// Partitioner is an autogenerated mock type for the Partitioner type
type Partitioner struct {
	mock.Mock
}

// Append provides a mock function with given fields: items
func (_m *Partitioner) Append(items interface{}) async.Task {
	ret := _m.Called(items)

	var r0 async.Task
	if rf, ok := ret.Get(0).(func(interface{}) async.Task); ok {
		r0 = rf(items)
	} else {
		if ret.Get(0) != nil {
			r0 = ret.Get(0).(async.Task)
		}
	}

	return r0
}

// Partition provides a mock function with given fields:
func (_m *Partitioner) Partition() map[string][]interface{} {
	ret := _m.Called()

	var r0 map[string][]interface{}
	if rf, ok := ret.Get(0).(func() map[string][]interface{}); ok {
		r0 = rf()
	} else {
		if ret.Get(0) != nil {
			r0 = ret.Get(0).(map[string][]interface{})
		}
	}

	return r0
}

--------------------------------------------------------------------------------
/partition.go:
--------------------------------------------------------------------------------
// Copyright 2019 Grabtaxi Holdings PTE LTE (GRAB), All rights reserved.
// Use of this source code is governed by an MIT-style license that can be found in the LICENSE file

package async

import (
	"context"
	"reflect"
	"sync"
)

type partitioner struct {
	sync.RWMutex
	ctx       context.Context
	queue     *queue
	partition PartitionFunc // The function which maps each item to its partition key
}

const defaultCapacity = 1 << 14

type partitionedItems map[string][]interface{}

// Partitioner partitions events
type Partitioner interface {
	// Append items to the queue which is pending partition
	Append(items interface{}) Task

	// Partition items and output the result
	Partition() map[string][]interface{}
}

// PartitionFunc takes in data and outputs a key.
// If ok is false, the data doesn't fall into any partition.
type PartitionFunc func(data interface{}) (key string, ok bool)

// NewPartitioner creates a new partitioner
func NewPartitioner(ctx context.Context, partition PartitionFunc) Partitioner {
	return &partitioner{
		ctx:       ctx,
		queue:     newQueue(),
		partition: partition,
	}
}

// Append adds a batch of events to the buffer. items must be a slice; it is
// partitioned asynchronously by the returned task.
func (p *partitioner) Append(items interface{}) Task {
	return Invoke(p.ctx, func(context.Context) (interface{}, error) {
		p.queue.Append(p.transform(items))
		return nil, nil
	})
}

// transform groups the items of a slice by the key returned from the partition function
func (p *partitioner) transform(items interface{}) partitionedItems {
	t := reflect.TypeOf(items)
	if t.Kind() != reflect.Slice {
		panic("transform requires a slice")
	}

	rv := reflect.ValueOf(items)
	mapped := partitionedItems{}
	for i := 0; i < rv.Len(); i++ {
		e := rv.Index(i).Interface()
		if key, ok := p.partition(e); ok {
			mapped[key] = append(mapped[key], e)
		}
	}
	return mapped
}

// Partition flushes the buffered items and returns them merged by partition key.
// Items whose Append task has not completed yet are not included.
func (p *partitioner) Partition() map[string][]interface{} {
	out := partitionedItems{}
	for _, pMap := range p.queue.Flush() {
		for k, v := range pMap {
			out[k] = append(out[k], v...)
		}
	}
	return out
}

// ------------------------------------------------------

// queue represents a batch queue for faster inserts
type queue struct {
	sync.Mutex
	queue []partitionedItems
}

// newQueue creates a new event queue
func newQueue() *queue {
	return &queue{
		queue: make([]partitionedItems, 0, defaultCapacity),
	}
}

// Append appends to the concurrent queue
func (q *queue) Append(events partitionedItems) {
	q.Lock()
	q.queue = append(q.queue, events)
	q.Unlock()
}

// Flush returns the queued items and replaces the queue with an empty one
func (q *queue) Flush() []partitionedItems {
	q.Lock()
	defer q.Unlock()

	flushed := q.queue
	q.queue = make([]partitionedItems, 0, defaultCapacity)

	return flushed
}

--------------------------------------------------------------------------------
/partition_test.go:
--------------------------------------------------------------------------------
// Copyright 2019 Grabtaxi Holdings PTE LTE (GRAB), All rights reserved.
// Use of this source code is governed by an MIT-style license that can be found in the LICENSE file

package async

import (
	"context"
	"fmt"
	"reflect"
	"testing"

	"github.com/stretchr/testify/assert"
)

func TestPartitioner(t *testing.T) {
	partitionFunc := func(data interface{}) (string, bool) {
		xevent, ok := data.(map[string]string)
		if !ok {
			return "", false
		}
		key, ok := xevent["pre"]
		return key, ok
	}
	p := NewPartitioner(context.Background(), partitionFunc)

	input1 := []interface{}{
		map[string]string{"pre": "a", "val": "val1"},
		map[string]string{"pre": "b", "val": "val2"},
		map[string]string{"pre": "a", "val": "val4"},
		map[string]string{"pre": "c", "val": "val5"},
	}

	input2 := []interface{}{
		map[string]string{"pre": "a", "val": "val3"},
		map[string]string{"pre": "c", "val": "val4"},
	}

	expectedRes1 := map[string][]interface{}{
		"a": {
			map[string]string{"pre": "a", "val": "val1"},
			map[string]string{"pre": "a", "val": "val4"},
			map[string]string{"pre": "a", "val": "val3"},
		},
		"b": {
			map[string]string{"pre": "b", "val": "val2"},
		},
		"c": {
			map[string]string{"pre": "c", "val": "val5"},
			map[string]string{"pre": "c", "val": "val4"},
		},
	}

	expectedRes2 := map[string][]interface{}{
		"a": {
			map[string]string{"pre": "a", "val": "val3"},
			map[string]string{"pre": "a", "val": "val1"},
			map[string]string{"pre": "a", "val": "val4"},
		},
		"b": {
			map[string]string{"pre": "b", "val": "val2"},
		},
		"c": {
			map[string]string{"pre": "c", "val": "val4"},
			map[string]string{"pre": "c", "val": "val5"},
		},
	}

	t1 := p.Append(input1)
	t2 := p.Append(input2)
	_, _ = t1.Outcome()
	_, _ = t2.Outcome()

	// the two Append tasks run concurrently, so either insertion order is valid
	res := p.Partition()
	assert.True(t, reflect.DeepEqual(expectedRes1, res) || reflect.DeepEqual(expectedRes2, res))
}

func ExamplePartitioner() {
	partitionFunc := func(data interface{}) (string, bool) {
		xevent, ok := data.(map[string]string)
		if !ok {
			return "", false
		}
		key, ok := xevent["pre"]
		return key, ok
	}
	p := NewPartitioner(context.Background(), partitionFunc)

	input := []interface{}{
		map[string]string{"pre": "a", "val": "val1"},
		map[string]string{"pre": "b", "val": "val2"},
		map[string]string{"pre": "a", "val": "val4"},
		map[string]string{"pre": "c", "val": "val5"},
	}
	t := p.Append(input)
	_, _ = t.Outcome()

	res := p.Partition()
	first := res["a"][0].(map[string]string)
	fmt.Println(first["pre"])
	fmt.Println(first["val"])

	// Output:
	// a
	// val1
}

func TestQueue(t *testing.T) {
	q := newQueue()
	input1 := partitionedItems{
		"a": []interface{}{"val1"},
		"b": []interface{}{"val2"},
	}

	input2 := partitionedItems{
		"a": []interface{}{"val4"},
		"c": []interface{}{"val5"},
	}

	expectedRes := []partitionedItems{
		{
			"a": []interface{}{"val1"},
			"b": []interface{}{"val2"},
		}, {
			"a": []interface{}{"val4"},
			"c": []interface{}{"val5"},
		},
	}

	q.Append(input1)
	q.Append(input2)
	assert.Equal(t, expectedRes, q.Flush())
}

func TestQueue_Flush(t *testing.T) {
	q := newQueue()

	// fill beyond the default capacity
	items := defaultCapacity + 10
	for x := 0; x < items; x++ {
		q.Append(partitionedItems{})
	}
	assert.True(t, defaultCapacity < cap(q.queue))
	assert.True(t, defaultCapacity < len(q.queue))

	// flush
	flushedItems := q.Flush()

	// validate that everything was returned and the queue was reset to its default capacity
	assert.Equal(t, items, len(flushedItems))
	assert.Equal(t, 0, len(q.queue))
	assert.Equal(t, defaultCapacity, cap(q.queue))
}

--------------------------------------------------------------------------------
/repeat.go:
--------------------------------------------------------------------------------
// Copyright 2019 Grabtaxi Holdings PTE LTE (GRAB), All rights reserved.
// Use of this source code is governed by an MIT-style license that can be found in the LICENSE file

package async

import (
	"context"
	"log"
	"runtime/debug"
	"time"
)

// Repeat performs an action asynchronously on a predetermined interval. Panics raised
// by the action are recovered and logged, and the loop stops when the context is cancelled.
func Repeat(ctx context.Context, interval time.Duration, action Work) Task {
	safeAction := func(ctx context.Context) (interface{}, error) {
		defer handlePanic()
		return action(ctx)
	}

	// Invoke the task timer
	return Invoke(ctx, func(taskCtx context.Context) (interface{}, error) {
		ticker := time.NewTicker(interval)
		defer ticker.Stop()
		for {
			select {
			case <-taskCtx.Done():
				return nil, nil

			case <-ticker.C:
				_, _ = safeAction(taskCtx)
			}
		}
	})
}

// handlePanic recovers from a panic and logs it together with the stack trace.
func handlePanic() {
	if r := recover(); r != nil {
		log.Printf("panic recovered: %v\n%s", r, debug.Stack())
	}
}

--------------------------------------------------------------------------------
/repeat_test.go:
--------------------------------------------------------------------------------
// Copyright 2019 Grabtaxi Holdings PTE LTE (GRAB), All rights reserved.
// Use of this source code is governed by an MIT-style license that can be found in the LICENSE file

package async

import (
	"context"
	"fmt"
	"testing"
	"time"

	"github.com/stretchr/testify/assert"
)

func TestRepeat(t *testing.T) {
	assert.NotPanics(t, func() {
		out := make(chan bool, 1)
		task := Repeat(context.TODO(), time.Nanosecond*10, func(context.Context) (interface{}, error) {
			out <- true
			return nil, nil
		})

		<-out
		v := <-out
		assert.True(t, v)
		task.Cancel()
	})
}

func ExampleRepeat() {
	out := make(chan bool, 1)
	task := Repeat(context.TODO(), time.Nanosecond*10, func(context.Context) (interface{}, error) {
		out <- true
		return nil, nil
	})

	<-out
	v := <-out
	fmt.Println(v)
	task.Cancel()

	// Output:
	// true
}

/*
func TestRepeatFirstActionPanic(t *testing.T) {
	assert.NotPanics(t, func() {
		task := Repeat(context.TODO(), time.Nanosecond*10, func(context.Context) (interface{}, error) {
			panic("test")
		})

		task.Cancel()
	})
}

func TestRepeatPanic(t *testing.T) {
	assert.NotPanics(t, func() {
		var counter int32
		task := Repeat(context.TODO(), time.Nanosecond*10, func(context.Context) (interface{}, error) {
			atomic.AddInt32(&counter, 1)
			panic("test")
		})

		for atomic.LoadInt32(&counter) <= 10 {
		}

		task.Cancel()
	})
}
*/
--------------------------------------------------------------------------------
/spread.go:
--------------------------------------------------------------------------------
// Copyright 2019 Grabtaxi Holdings PTE LTE (GRAB), All rights reserved.
// Use of this source code is governed by an MIT-style license that can be found in the LICENSE file

package async

import (
	"context"
	"time"

	"golang.org/x/time/rate"
)

// Throttle runs the tasks with a specified rate limiter.
func Throttle(ctx context.Context, tasks []Task, rateLimit int, every time.Duration) Task {
	return Invoke(ctx, func(context.Context) (interface{}, error) {
		limiter := rate.NewLimiter(rate.Every(every/time.Duration(rateLimit)), 1)
		for i, task := range tasks {
			select {
			case <-ctx.Done():
				CancelAll(tasks[i:])
				return nil, errCancelled
			default:
				if err := limiter.Wait(ctx); err != nil {
					// The limiter gave up (context done, or its deadline is too close),
					// so cancel the remaining tasks rather than leave them unstarted,
					// which would block WaitAll below forever.
					CancelAll(tasks[i:])
					return nil, err
				}
				task.Run(ctx)
			}
		}

		WaitAll(tasks)
		return nil, nil
	})
}

// Spread evenly spreads the work within the specified duration.
func Spread(ctx context.Context, within time.Duration, tasks []Task) Task {
	return Invoke(ctx, func(context.Context) (interface{}, error) {
		if len(tasks) == 0 {
			return nil, nil // Nothing to spread; also avoids dividing by zero below
		}

		sleep := within / time.Duration(len(tasks))
		for _, task := range tasks {
			select {
			case <-ctx.Done():
				return nil, errCancelled
			default:
				task.Run(ctx)
				time.Sleep(sleep)
			}
		}

		WaitAll(tasks)
		return nil, nil
	})
}
--------------------------------------------------------------------------------
/spread_test.go:
--------------------------------------------------------------------------------
// Copyright 2019 Grabtaxi Holdings PTE LTE (GRAB), All rights reserved.
// Use of this source code is governed by an MIT-style license that can be found in the LICENSE file

package async

import (
	"context"
	"fmt"
	"testing"
	"time"

	"github.com/stretchr/testify/assert"
)

func newTasks() []Task {
	work := func(context.Context) (interface{}, error) {
		return 1, nil
	}

	return NewTasks(work, work, work, work, work)
}

func TestThrottle(t *testing.T) {
	tasks := newTasks()

	// Throttle and calculate the duration
	t0 := time.Now()
	task := Throttle(context.Background(), tasks, 3, 50*time.Millisecond)
	_, _ = task.Outcome() // Wait

	// Make sure we completed within duration
	dt := int(time.Since(t0).Seconds() * 1000)
	assert.True(t, dt > 50 && dt < 100, fmt.Sprintf("%v ms.", dt))
}

func TestThrottle_Cancel(t *testing.T) {
	tasks := newTasks()

	ctx, cancel := context.WithCancel(context.Background())
	cancel()

	// Throttle with an already-cancelled context: every task should end up cancelled
	Throttle(ctx, tasks, 3, 50*time.Millisecond)
	WaitAll(tasks)
	cancelled := 0
	for _, task := range tasks {
		if task.State() == IsCancelled {
			cancelled++
		}
	}

	assert.Equal(t, 5, cancelled)
}

func TestSpread(t *testing.T) {
	tasks := newTasks()
	within := 200 * time.Millisecond

	// Spread and calculate the duration
	t0 := time.Now()
	task := Spread(context.Background(), within, tasks)
	_, _ = task.Outcome() // Wait

	// Make sure we completed within duration
	dt := int(time.Since(t0).Seconds() * 1000)
	assert.True(t, dt > 150 && dt < 250, fmt.Sprintf("%v ms.", dt))

	// Make sure all tasks are done
	for _, task := range tasks {
		v, _ := task.Outcome()
		assert.Equal(t, 1, v.(int))
	}
}

func ExampleSpread() {
	tasks := newTasks()
	within := 200 * time.Millisecond

	// Spread
	task := Spread(context.Background(), within, tasks)
	_, _ = task.Outcome() // Wait

	// Make sure all tasks are done
	for _, task := range tasks {
		v, _ := task.Outcome()
		fmt.Println(v)
	}

	// Output:
	// 1
	// 1
	// 1
	// 1
	// 1
}
--------------------------------------------------------------------------------
/task.go:
--------------------------------------------------------------------------------
// Copyright 2019 Grabtaxi Holdings PTE LTE (GRAB), All rights reserved.
// Use of this source code is governed by an MIT-style license that can be found in the LICENSE file

package async

import (
	"context"
	"errors"
	"sync/atomic"
	"time"
)

var errCancelled = errors.New("context canceled")

var now = time.Now

// Work represents a handler to execute
type Work func(context.Context) (interface{}, error)

// State represents the state enumeration for a task.
type State byte

// Various task states
const (
	IsCreated   State = iota // IsCreated represents a newly created task
	IsRunning                // IsRunning represents a task which is currently running
	IsCompleted              // IsCompleted represents a task which was completed successfully or errored out
	IsCancelled              // IsCancelled represents a task which was cancelled or has timed out
)

type signal chan struct{}

// outcome of the task contains a result and an error
type outcome struct {
	result interface{} // The result of the work
	err    error       // The error
}

// task is the default implementation of the Task interface.
type task struct {
	state    int32         // The current State, stored as int32 so it can be read and swapped atomically
	cancel   signal        // The cancellation channel
	done     signal        // The outcome channel
	action   Work          // The work to do
	outcome  outcome       // This is used to store the result
	duration time.Duration // The duration of the task, in nanoseconds
}

// Task represents a unit of work to be done
type Task interface {
	Run(ctx context.Context) Task
	Cancel()
	State() State
	Outcome() (interface{}, error)
	ContinueWith(ctx context.Context, nextAction func(interface{}, error) (interface{}, error)) Task
}

// NewTask creates a new task.
func NewTask(action Work) Task {
	return &task{
		action: action,
		done:   make(signal, 1),
		cancel: make(signal, 1),
	}
}

// NewTasks creates a set of new tasks.
func NewTasks(actions ...Work) []Task {
	tasks := make([]Task, 0, len(actions))
	for _, action := range actions {
		tasks = append(tasks, NewTask(action))
	}
	return tasks
}

// Invoke creates a new task and runs it asynchronously.
func Invoke(ctx context.Context, action Work) Task {
	return NewTask(action).Run(ctx)
}

// Outcome waits until the task is done and returns the final result and error.
func (t *task) Outcome() (interface{}, error) {
	<-t.done
	return t.outcome.result, t.outcome.err
}

// State returns the current state of the task. This operation is non-blocking.
func (t *task) State() State {
	v := atomic.LoadInt32(&t.state)
	return State(v)
}

// Duration returns the duration of the task.
func (t *task) Duration() time.Duration {
	return t.duration
}

// Run starts the task asynchronously.
func (t *task) Run(ctx context.Context) Task {
	go t.run(ctx)
	return t
}

// Cancel cancels the task. A task that was never started moves straight to
// IsCancelled; a running task is signalled through its cancel channel.
func (t *task) Cancel() {
	// If the task was created but never started, transition directly to the
	// cancelled state, set the error and close the done channel.
	if t.changeState(IsCreated, IsCancelled) {
		t.outcome = outcome{err: errCancelled}
		close(t.done)
		return
	}

	// Attempt to cancel the task if it's in the running state
	if t.cancel != nil {
		select {
		case <-t.cancel:
			return
		default:
			close(t.cancel)
		}
	}
}

// run starts the task synchronously.
func (t *task) run(ctx context.Context) {
	if !t.changeState(IsCreated, IsRunning) {
		return // Prevent the same task from running twice
	}

	// Notify everyone of the completion/error state
	defer close(t.done)

	// Execute the task. outcomeCh is buffered so the worker goroutine can always
	// deliver its outcome and exit, even if the task was cancelled first.
	startedAt := now().UnixNano()
	outcomeCh := make(chan outcome, 1)
	go func() {
		r, e := t.action(ctx)
		outcomeCh <- outcome{result: r, err: e}
	}()

	select {
	// In case of a manual task cancellation, set the outcome and transition
	// to the cancelled state.
	case <-t.cancel:
		t.duration = time.Nanosecond * time.Duration(now().UnixNano()-startedAt)
		t.outcome = outcome{err: errCancelled}
		t.changeState(IsRunning, IsCancelled)
		return

	// In case of the context timeout or other error, change the state of the
	// task to cancelled and return right away.
	case <-ctx.Done():
		t.duration = time.Nanosecond * time.Duration(now().UnixNano()-startedAt)
		t.outcome = outcome{err: ctx.Err()}
		t.changeState(IsRunning, IsCancelled)
		return

	// In case where we got an outcome (happy path)
	case o := <-outcomeCh:
		t.duration = time.Nanosecond * time.Duration(now().UnixNano()-startedAt)
		t.outcome = o
		t.changeState(IsRunning, IsCompleted)
		return
	}
}

// ContinueWith proceeds with the next task once the current one is finished.
func (t *task) ContinueWith(ctx context.Context, nextAction func(interface{}, error) (interface{}, error)) Task {
	return Invoke(ctx, func(context.Context) (interface{}, error) {
		result, err := t.Outcome()
		return nextAction(result, err)
	})
}

// changeState atomically moves the task from one state to another and reports
// whether the transition took place.
func (t *task) changeState(from, to State) bool {
	return atomic.CompareAndSwapInt32(&t.state, int32(from), int32(to))
}
--------------------------------------------------------------------------------
/task_test.go:
--------------------------------------------------------------------------------
// Copyright 2019 Grabtaxi Holdings PTE LTE (GRAB), All rights reserved.
// Use of this source code is governed by an MIT-style license that can be found in the LICENSE file

package async

import (
	"context"
	"sync"
	"testing"
	"time"

	"github.com/stretchr/testify/assert"
)

func TestNewTasks(t *testing.T) {
	work := func(context.Context) (interface{}, error) {
		return 1, nil
	}

	tasks := NewTasks(work, work, work)
	assert.Equal(t, 3, len(tasks))
}

func TestOutcome(t *testing.T) {
	task := Invoke(context.Background(), func(context.Context) (interface{}, error) {
		return 1, nil
	})

	var wg sync.WaitGroup
	wg.Add(100)
	for i := 0; i < 100; i++ {
		go func() {
			defer wg.Done() // Signal only after the assertion has run
			o, _ := task.Outcome()
			assert.Equal(t, 1, o.(int))
		}()
	}
	wg.Wait()
}

func TestOutcomeTimeout(t *testing.T) {
	ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
	defer cancel()

	task := Invoke(ctx, func(context.Context) (interface{}, error) {
		time.Sleep(500 * time.Millisecond)
		return 1, nil
	})

	_, err := task.Outcome()
	assert.Equal(t, "context deadline exceeded", err.Error())
}

func TestContinueWithChain(t *testing.T) {
	task1 := Invoke(context.Background(), func(context.Context) (interface{}, error) {
		return 1, nil
	})

	ctx := context.TODO()
	task2 := task1.ContinueWith(ctx, func(result interface{}, _ error) (interface{}, error) {
		return result.(int) + 1, nil
	})

	task3 := task2.ContinueWith(ctx, func(result interface{}, _ error) (interface{}, error) {
		return result.(int) + 1, nil
	})

	result, err := task3.Outcome()
	assert.Equal(t, 3, result)
	assert.Nil(t, err)
}

func TestContinueTimeout(t *testing.T) {
	ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
	defer cancel()

	first := Invoke(ctx, func(context.Context) (interface{}, error) {
		return 5, nil
	})

	second := first.ContinueWith(ctx, func(result interface{}, err error) (interface{}, error) {
		time.Sleep(500 * time.Millisecond)
		return result, err
	})

	r1, err1 := first.Outcome()
	assert.Equal(t, 5, r1)
	assert.Nil(t, err1)

	_, err2 := second.Outcome()
	assert.Equal(t, "context deadline exceeded", err2.Error())
}

func TestTaskCancelStarted(t *testing.T) {
	task := Invoke(context.Background(), func(context.Context) (interface{}, error) {
		time.Sleep(500 * time.Millisecond)
		return 1, nil
	})

	task.Cancel()

	_, err := task.Outcome()
	assert.Equal(t, errCancelled, err)
}

func TestTaskCancelRunning(t *testing.T) {
	task := Invoke(context.Background(), func(context.Context) (interface{}, error) {
		time.Sleep(500 * time.Millisecond)
		return 1, nil
	})

	time.Sleep(10 * time.Millisecond)

	task.Cancel()

	_, err := task.Outcome()
	assert.Equal(t, errCancelled, err)
}

func TestTaskCancelTwice(t *testing.T) {
	task := Invoke(context.Background(), func(context.Context) (interface{}, error) {
		time.Sleep(500 * time.Millisecond)
		return 1, nil
	})

	assert.NotPanics(t, func() {
		for i := 0; i < 100; i++ {
			task.Cancel()
		}
	})

	_, err := task.Outcome()
	assert.Equal(t, errCancelled, err)
}
--------------------------------------------------------------------------------