├── .github ├── CODEOWNERS ├── splash_image.png └── workflows │ └── test.yml ├── .gitignore ├── .golangci.yml ├── .mockery.yml ├── LICENSE ├── Makefile ├── README.md ├── _examples ├── mysql │ ├── docker-compose.yml │ └── main.go └── postgres │ ├── docker-compose.yml │ └── main.go ├── cleaner.go ├── cleaner_test.go ├── client.go ├── client_test.go ├── consumer.go ├── consumer_test.go ├── export_test.go ├── go.mod ├── go.sum ├── inspector.go ├── inspector_test.go ├── mocks └── IRepository.go ├── producer.go ├── producer_test.go ├── repository.go ├── repository ├── mysql │ ├── export_test.go │ ├── mysql.go │ ├── mysqlTask.go │ ├── mysqlTask_test.go │ └── mysql_test.go └── postgres │ ├── export_test.go │ ├── postgres.go │ ├── postgresTask.go │ ├── postgresTask_test.go │ └── postgres_test.go ├── task.go └── task_test.go /.github/CODEOWNERS: -------------------------------------------------------------------------------- 1 | * @greencoda -------------------------------------------------------------------------------- /.github/splash_image.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/greencoda/tasq/cfff6aa01b5cda4faac73a706617161140d584c0/.github/splash_image.png -------------------------------------------------------------------------------- /.github/workflows/test.yml: -------------------------------------------------------------------------------- 1 | name: Test 2 | 3 | on: 4 | push: 5 | branches: [ "master" ] 6 | pull_request: 7 | branches: [ "master" ] 8 | 9 | jobs: 10 | lint: 11 | runs-on: ubuntu-latest 12 | steps: 13 | - name: Checkout 14 | uses: actions/checkout@v3 15 | 16 | - name: Setup Go 17 | uses: actions/setup-go@v3 18 | with: 19 | go-version: 1.19.x 20 | check-latest: true 21 | 22 | - name: Lint 23 | uses: golangci/golangci-lint-action@v3 24 | with: 25 | version: v1.50.1 26 | test: 27 | runs-on: ubuntu-latest 28 | steps: 29 | - name: Checkout 30 | uses: actions/checkout@v3 31 | 32 | 
- name: Setup Go 33 | uses: actions/setup-go@v3 34 | with: 35 | go-version: 1.19.x 36 | check-latest: true 37 | 38 | - name: Setup Dependencies 39 | run: make deps && git diff-index --quiet HEAD || { >&2 echo "Stale go.{mod,sum} detected. This can be fixed with 'make deps'."; exit 1; } 40 | 41 | - name: Run Test 42 | run: make test 43 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # .DS_Store 2 | .DS_Store 3 | 4 | # Binaries for programs and plugins 5 | *.exe 6 | *.exe~ 7 | *.dll 8 | *.so 9 | *.dylib 10 | 11 | # Test binary, build with `go test -c` 12 | *.test 13 | 14 | # Output of the go coverage tool, specifically when used with LiteIDE 15 | *.out 16 | 17 | # Ignore dependencies in ./vendor 18 | vendor 19 | 20 | # ignore build products 21 | *.zip 22 | bin/ 23 | 24 | # ignore docker db volume data 25 | .dbdata 26 | 27 | # IDE-related files 28 | .idea 29 | .vscode 30 | *.code-workspace 31 | 32 | # product from go test -cover / ginkgo -cover 33 | *.coverprofile 34 | 35 | # certificates 36 | *.pem -------------------------------------------------------------------------------- /.golangci.yml: -------------------------------------------------------------------------------- 1 | # Refer to golangci-lint's example config file for more options and information: 2 | # https://github.com/golangci/golangci-lint/blob/master/.golangci.reference.yml 3 | 4 | run: 5 | timeout: 5m 6 | modules-download-mode: readonly 7 | 8 | linters: 9 | enable-all: true 10 | disable: 11 | - dupl 12 | - funlen 13 | - gochecknoglobals 14 | - lll 15 | - revive 16 | - varnamelen 17 | 18 | linters-settings: 19 | paralleltest: 20 | ignore-missing: true 21 | interfacebloat: 22 | max: 20 23 | gosec: 24 | excludes: 25 | - G404 26 | 27 | issues: 28 | exclude-use-default: false 29 | max-issues-per-linter: 0 30 | max-same-issues: 0 
-------------------------------------------------------------------------------- /.mockery.yml: -------------------------------------------------------------------------------- 1 | name: "IRepository" 2 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 greencoda 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: deps lint mocks test test-cover 2 | .SILENT: test test-cover 3 | 4 | TESTABLE_PACKAGES = $(shell go list ./... 
| grep -v ./mocks) 5 | 6 | deps: 7 | go mod tidy 8 | go mod vendor 9 | 10 | lint: deps 11 | golangci-lint run -v 12 | 13 | mocks: 14 | rm -rf mocks/* 15 | mockery 16 | 17 | test: 18 | go test ${TESTABLE_PACKAGES} -count=1 -cover 19 | 20 | test-cover: 21 | go test ${TESTABLE_PACKAGES} -count=1 -cover -coverprofile=coverage.out 22 | go tool cover -html=coverage.out 23 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![godoc for greencoda/tasq][godoc-badge]][godoc-url] 2 | [![Build Status][actions-badge]][actions-url] 3 | [![Go 1.19][goversion-badge]][goversion-url] 4 | [![Go Report card][goreportcard-badge]][goreportcard-url] 5 | 6 |

7 | 8 | # tasq 9 | 10 | Tasq is a Go task queue that uses an SQL database for persistence. 11 | It currently supports: 12 | - PostgreSQL 13 | - MySQL 14 | 15 | ## Install 16 | 17 | ```shell 18 | go get -u github.com/greencoda/tasq 19 | ``` 20 | 21 | ## Usage Example 22 | 23 | To try tasq locally, you'll need a local DB running on your machine. You can use the supplied docker-compose.yml file to start a local instance (replace `<mysql|postgres>` with the database you want to try): 24 | ```shell 25 | docker-compose -f _examples/<mysql|postgres>/docker-compose.yml up -d 26 | ``` 27 | 28 | Afterwards, run the example's main.go file: 29 | ```shell 30 | go run _examples/<mysql|postgres>/main.go 31 | ``` 32 | 33 | [godoc-badge]: https://pkg.go.dev/badge/github.com/greencoda/tasq 34 | [godoc-url]: https://pkg.go.dev/github.com/greencoda/tasq 35 | [actions-badge]: https://github.com/greencoda/tasq/actions/workflows/test.yml/badge.svg 36 | [actions-url]: https://github.com/greencoda/tasq/actions/workflows/test.yml 37 | [goversion-badge]: https://img.shields.io/badge/Go-1.19-%2300ADD8?logo=go 38 | [goversion-url]: https://golang.org/doc/go1.19 39 | [goreportcard-badge]: https://goreportcard.com/badge/github.com/greencoda/tasq 40 | [goreportcard-url]: https://goreportcard.com/report/github.com/greencoda/tasq 41 | -------------------------------------------------------------------------------- /_examples/mysql/docker-compose.yml: -------------------------------------------------------------------------------- 1 | --- 2 | version: '3.8' 3 | 4 | services: 5 | mysql: 6 | image: mysql:8.0 7 | command: ["mysqld", "--skip-name-resolve"] 8 | ports: 9 | - "3306:3306" 10 | volumes: 11 | - ./.dbdata:/var/lib/mysql 12 | environment: 13 | MYSQL_DATABASE: test 14 | MYSQL_ROOT_PASSWORD: root 15 | healthcheck: 16 | test: ["CMD", "mysqladmin" ,"ping", "-h", "localhost"] 17 | interval: 10s 18 | timeout: 5s 19 | retries: 10 20 | networks: 21 | tasq: 22 | name: tasq_network 23 | external: false 24 | -------------------------------------------------------------------------------- 
/_examples/mysql/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | "database/sql" 6 | "fmt" 7 | "log" 8 | "math/rand" 9 | "sync" 10 | "time" 11 | 12 | "github.com/greencoda/tasq" 13 | tasqMySQL "github.com/greencoda/tasq/repository/mysql" 14 | ) 15 | 16 | const ( 17 | channelSize = 10 18 | taskType = "sampleTask" 19 | taskQueue = "sampleQueue" 20 | taskPriority = 20 21 | taskScanLimit = 20 22 | taskMaxReceives = 5 23 | deletionCutoff = 0.5 24 | migrationTimeout = 10 * time.Second 25 | pollInterval = 10 * time.Second 26 | consumerShutdownTimeout = 30 * time.Second 27 | ) 28 | 29 | // SampleTaskArgs is a struct that represents the arguments for the sample task. 30 | type SampleTaskArgs struct { 31 | ID int 32 | Value float64 33 | } 34 | 35 | func processSampleTask(task *tasq.Task) error { 36 | var args SampleTaskArgs 37 | 38 | err := task.UnmarshalArgs(&args) 39 | if err != nil { 40 | return fmt.Errorf("failed to unmarshal value: %w", err) 41 | } 42 | 43 | // do something here with the task arguments as input 44 | // for purposes of the sample, we'll just log its details here 45 | log.Printf("executed task '%s' with args '%+v'", task.ID, args) 46 | 47 | return nil 48 | } 49 | 50 | func consumeTasks(consumer *tasq.Consumer, wg *sync.WaitGroup) { 51 | defer wg.Done() 52 | 53 | for { 54 | job := <-consumer.Channel() 55 | if job == nil { 56 | return 57 | } 58 | 59 | // execute the job right away or feed it into a workerpool 60 | // such as workerpool.Add(*job) 61 | (*job)() 62 | } 63 | } 64 | 65 | func produceTasks(ctx context.Context, producer *tasq.Producer) { 66 | taskTicker := time.NewTicker(1 * time.Second) 67 | 68 | for taskIndex := 0; true; taskIndex++ { 69 | <-taskTicker.C 70 | 71 | seededRand := rand.New(rand.NewSource(time.Now().UnixNano())) 72 | 73 | taskArgs := SampleTaskArgs{ 74 | ID: taskIndex, 75 | Value: seededRand.Float64(), 76 | } 77 | 78 | t, err := 
producer.Submit(ctx, taskType, taskArgs, taskQueue, taskPriority, taskMaxReceives) 79 | if err != nil { 80 | log.Panicf("error while submitting task to tasq: %s", err) 81 | } else { 82 | log.Printf("successfully submitted task '%s'", t.ID) 83 | } 84 | } 85 | } 86 | 87 | func inspectTasks(ctx context.Context, inspector *tasq.Inspector) { 88 | taskTicker := time.NewTicker(1 * time.Second) 89 | 90 | for taskIndex := 0; true; taskIndex++ { 91 | <-taskTicker.C 92 | 93 | taskCount, err := inspector.Count(ctx, nil, []string{taskType}, []string{taskQueue}) 94 | if err != nil { 95 | log.Panicf("error while counting tasks: %s", err) 96 | } 97 | 98 | log.Printf("successfully counted %d tasks total", taskCount) 99 | 100 | tasks, err := inspector.Scan(ctx, []tasq.TaskStatus{tasq.StatusNew}, []string{taskType}, []string{taskQueue}, tasq.OrderingCreatedAtFirst, taskScanLimit) 101 | if err != nil { 102 | log.Panicf("error while scanning tasks: %s", err) 103 | } 104 | 105 | log.Printf("successfully scanned %d new tasks", len(tasks)) 106 | 107 | for _, task := range tasks { 108 | var args SampleTaskArgs 109 | 110 | err := task.UnmarshalArgs(&args) 111 | if err != nil { 112 | log.Printf("failed to unmarshal value for task '%s'", task.ID) 113 | 114 | continue 115 | } 116 | 117 | if args.Value < deletionCutoff { 118 | continue 119 | } 120 | 121 | err = inspector.Delete(ctx, true, task) 122 | if err != nil { 123 | log.Printf("failed to remove task '%s'", task.ID) 124 | 125 | continue 126 | } 127 | 128 | log.Printf("successfully removed task '%s'", task.ID) 129 | } 130 | } 131 | } 132 | 133 | func main() { 134 | ctx, cancelCtx := context.WithCancel(context.Background()) 135 | defer cancelCtx() 136 | 137 | db, err := sql.Open("mysql", "root:root@/test") 138 | if err != nil { 139 | log.Panicf("failed to open DB connection: %v", err) 140 | } 141 | 142 | // instantiate tasq repository to manage the database connection 143 | // you can also have it set up the sql DB for you if you provide the 
dsn string 144 | // instead of the *sql.DB instance 145 | tasqRepository, err := tasqMySQL.NewRepository(db, "tasq") 146 | if err != nil { 147 | log.Panicf("failed to create tasq repository: %s", err) 148 | } 149 | 150 | migrationCtx, migrationCancelCtx := context.WithTimeout(context.Background(), migrationTimeout) 151 | defer migrationCancelCtx() 152 | 153 | err = tasqRepository.Migrate(migrationCtx) 154 | if err != nil { 155 | log.Panicf("failed to migrate tasq repository: %s", err) 156 | } 157 | 158 | log.Print("database migrated successfully") 159 | 160 | // instantiate tasq client 161 | tasqClient := tasq.NewClient(tasqRepository) 162 | 163 | // set up tasq cleaner 164 | cleaner := tasqClient.NewCleaner(). 165 | WithTaskAge(time.Second) 166 | 167 | cleanedTaskCount, err := cleaner.Clean(ctx) 168 | if err != nil { 169 | log.Panicf("failed to clean old tasks from queue: %s", err) 170 | } 171 | 172 | log.Printf("cleaned %d finished tasks from the queue on startup", cleanedTaskCount) 173 | 174 | // set up tasq consumer 175 | consumer := tasqClient.NewConsumer(). 176 | WithQueues(taskQueue). 177 | WithChannelSize(channelSize). 178 | WithPollInterval(pollInterval). 179 | WithPollStrategy(tasq.PollStrategyByPriority). 180 | WithAutoDeleteOnSuccess(false). 
181 | WithLogger(log.Default()) 182 | 183 | // teach the consumer to handle tasks with the type "sampleTask" with the function "processSampleTask" 184 | err = consumer.Learn(taskType, processSampleTask, false) 185 | if err != nil { 186 | log.Panicf("failed to teach tasq consumer task handler: %s", err) 187 | } 188 | 189 | // start the consumer 190 | err = consumer.Start(ctx) 191 | if err != nil { 192 | log.Panicf("failed to start tasq consumer: %s", err) 193 | } 194 | 195 | var consumerWg sync.WaitGroup 196 | 197 | // start the goroutine which handles the tasq jobs received from the consumer 198 | consumerWg.Add(1) 199 | 200 | go consumeTasks(consumer, &consumerWg) 201 | 202 | // set up tasq inspector 203 | inspector := tasqClient.NewInspector() 204 | 205 | go inspectTasks(ctx, inspector) 206 | 207 | // set up tasq producer 208 | producer := tasqClient.NewProducer() 209 | 210 | // start the goroutine which produces the tasks and submits them to the tasq queue 211 | go produceTasks(ctx, producer) 212 | 213 | // block the execution 214 | <-time.After(consumerShutdownTimeout) 215 | 216 | err = consumer.Stop() 217 | if err != nil { 218 | log.Panicf("failed to stop tasq consumer: %s", err) 219 | } 220 | 221 | // wait until consumer go routine exits 222 | consumerWg.Wait() 223 | 224 | purgedTaskCount, err := inspector.Purge(ctx, true, tasq.GetTaskStatuses(tasq.OpenTasks), []string{taskType}, []string{taskQueue}) 225 | if err != nil { 226 | log.Panicf("failed to purge tasq queue: %s", err) 227 | } 228 | 229 | log.Printf("purged %d open tasks from the queue", purgedTaskCount) 230 | } 231 | -------------------------------------------------------------------------------- /_examples/postgres/docker-compose.yml: -------------------------------------------------------------------------------- 1 | --- 2 | version: "3.2" 3 | 4 | services: 5 | tasq: 6 | container_name: postgres 7 | image: postgres:14.2-alpine 8 | volumes: 9 | - ./.dbdata:/var/lib/postgresql 10 | ports: 11 | - 
"5432:5432" 12 | environment: 13 | LC_ALL: C.UTF-8 14 | POSTGRES_USER: test 15 | POSTGRES_PASSWORD: test 16 | POSTGRES_DB: test 17 | tmpfs: 18 | - /var/lib/postgresql/data 19 | healthcheck: 20 | test: [ "CMD", "pg_isready" ] 21 | interval: 10s 22 | timeout: 5s 23 | retries: 5 24 | networks: 25 | - tasq 26 | networks: 27 | tasq: 28 | name: tasq_network 29 | external: false 30 | -------------------------------------------------------------------------------- /_examples/postgres/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | "database/sql" 6 | "fmt" 7 | "log" 8 | "math/rand" 9 | "sync" 10 | "time" 11 | 12 | "github.com/greencoda/tasq" 13 | tasqPostgres "github.com/greencoda/tasq/repository/postgres" 14 | ) 15 | 16 | const ( 17 | channelSize = 10 18 | taskType = "sampleTask" 19 | taskQueue = "sampleQueue" 20 | taskPriority = 20 21 | taskScanLimit = 20 22 | taskMaxReceives = 5 23 | deletionCutoff = 0.5 24 | migrationTimeout = 10 * time.Second 25 | pollInterval = 10 * time.Second 26 | consumerShutdownTimeout = 30 * time.Second 27 | ) 28 | 29 | // SampleTaskArgs is a struct that represents the arguments for the sample task. 
30 | type SampleTaskArgs struct { 31 | ID int 32 | Value float64 33 | } 34 | 35 | func processSampleTask(task *tasq.Task) error { 36 | var args SampleTaskArgs 37 | 38 | err := task.UnmarshalArgs(&args) 39 | if err != nil { 40 | return fmt.Errorf("failed to unmarshal value: %w", err) 41 | } 42 | 43 | // do something here with the task arguments as input 44 | // for purposes of the sample, we'll just log its details here 45 | log.Printf("executed task '%s' with args '%+v'", task.ID, args) 46 | 47 | return nil 48 | } 49 | 50 | func consumeTasks(consumer *tasq.Consumer, wg *sync.WaitGroup) { 51 | defer wg.Done() 52 | 53 | for { 54 | job := <-consumer.Channel() 55 | if job == nil { 56 | return 57 | } 58 | 59 | // execute the job right away or feed it into a workerpool 60 | // such as workerpool.Add(*job) 61 | (*job)() 62 | } 63 | } 64 | 65 | func produceTasks(ctx context.Context, producer *tasq.Producer) { 66 | taskTicker := time.NewTicker(1 * time.Second) 67 | 68 | for taskIndex := 0; true; taskIndex++ { 69 | <-taskTicker.C 70 | 71 | seededRand := rand.New(rand.NewSource(time.Now().UnixNano())) 72 | 73 | taskArgs := SampleTaskArgs{ 74 | ID: taskIndex, 75 | Value: seededRand.Float64(), 76 | } 77 | 78 | t, err := producer.Submit(ctx, taskType, taskArgs, taskQueue, taskPriority, taskMaxReceives) 79 | if err != nil { 80 | log.Panicf("error while submitting task to tasq: %s", err) 81 | } else { 82 | log.Printf("successfully submitted task '%s'", t.ID) 83 | } 84 | } 85 | } 86 | 87 | func inspectTasks(ctx context.Context, inspector *tasq.Inspector) { 88 | taskTicker := time.NewTicker(1 * time.Second) 89 | 90 | for taskIndex := 0; true; taskIndex++ { 91 | <-taskTicker.C 92 | 93 | taskCount, err := inspector.Count(ctx, nil, []string{taskType}, []string{taskQueue}) 94 | if err != nil { 95 | log.Panicf("error while counting tasks: %s", err) 96 | } 97 | 98 | log.Printf("successfully counted %d tasks total", taskCount) 99 | 100 | tasks, err := inspector.Scan(ctx, 
[]tasq.TaskStatus{tasq.StatusNew}, []string{taskType}, []string{taskQueue}, tasq.OrderingCreatedAtFirst, taskScanLimit) 101 | if err != nil { 102 | log.Panicf("error while scanning tasks: %s", err) 103 | } 104 | 105 | log.Printf("successfully scanned %d new tasks", len(tasks)) 106 | 107 | for _, task := range tasks { 108 | var args SampleTaskArgs 109 | 110 | err := task.UnmarshalArgs(&args) 111 | if err != nil { 112 | log.Printf("failed to unmarshal value for task '%s'", task.ID) 113 | 114 | continue 115 | } 116 | 117 | if args.Value < deletionCutoff { 118 | continue 119 | } 120 | 121 | err = inspector.Delete(ctx, true, task) 122 | if err != nil { 123 | log.Printf("failed to remove task '%s'", task.ID) 124 | 125 | continue 126 | } 127 | 128 | log.Printf("successfully removed task '%s'", task.ID) 129 | } 130 | } 131 | } 132 | 133 | func main() { 134 | ctx, cancelCtx := context.WithCancel(context.Background()) 135 | defer cancelCtx() 136 | 137 | db, err := sql.Open("postgres", "host=127.0.0.1 user=test password=test dbname=test port=5432 sslmode=disable") 138 | if err != nil { 139 | log.Panicf("failed to open DB connection: %v", err) 140 | } 141 | 142 | // instantiate tasq repository to manage the database connection 143 | // you can also have it set up the sql DB for you if you provide the dsn string 144 | // instead of the *sql.DB instance 145 | tasqRepository, err := tasqPostgres.NewRepository(db, "tasq") 146 | if err != nil { 147 | log.Panicf("failed to create tasq repository: %s", err) 148 | } 149 | 150 | migrationCtx, migrationCancelCtx := context.WithTimeout(context.Background(), migrationTimeout) 151 | defer migrationCancelCtx() 152 | 153 | err = tasqRepository.Migrate(migrationCtx) 154 | if err != nil { 155 | log.Panicf("failed to migrate tasq repository: %s", err) 156 | } 157 | 158 | log.Print("database migrated successfully") 159 | 160 | // instantiate tasq client 161 | tasqClient := tasq.NewClient(tasqRepository) 162 | 163 | // set up tasq cleaner 164 | 
cleaner := tasqClient.NewCleaner(). 165 | WithTaskAge(time.Second) 166 | 167 | cleanedTaskCount, err := cleaner.Clean(ctx) 168 | if err != nil { 169 | log.Panicf("failed to clean old tasks from queue: %s", err) 170 | } 171 | 172 | log.Printf("cleaned %d finished tasks from the queue on startup", cleanedTaskCount) 173 | 174 | // set up tasq consumer 175 | consumer := tasqClient.NewConsumer(). 176 | WithQueues(taskQueue). 177 | WithChannelSize(channelSize). 178 | WithPollInterval(pollInterval). 179 | WithPollStrategy(tasq.PollStrategyByPriority). 180 | WithAutoDeleteOnSuccess(false). 181 | WithLogger(log.Default()) 182 | 183 | // teach the consumer to handle tasks with the type "sampleTask" with the function "processSampleTask" 184 | err = consumer.Learn(taskType, processSampleTask, false) 185 | if err != nil { 186 | log.Panicf("failed to teach tasq consumer task handler: %s", err) 187 | } 188 | 189 | // start the consumer 190 | err = consumer.Start(ctx) 191 | if err != nil { 192 | log.Panicf("failed to start tasq consumer: %s", err) 193 | } 194 | 195 | var consumerWg sync.WaitGroup 196 | 197 | // start the goroutine which handles the tasq jobs received from the consumer 198 | consumerWg.Add(1) 199 | 200 | go consumeTasks(consumer, &consumerWg) 201 | 202 | // set up tasq inspector 203 | inspector := tasqClient.NewInspector() 204 | 205 | go inspectTasks(ctx, inspector) 206 | 207 | // set up tasq producer 208 | producer := tasqClient.NewProducer() 209 | 210 | // start the goroutine which produces the tasks and submits them to the tasq queue 211 | go produceTasks(ctx, producer) 212 | 213 | // block the execution 214 | <-time.After(consumerShutdownTimeout) 215 | 216 | err = consumer.Stop() 217 | if err != nil { 218 | log.Panicf("failed to stop tasq consumer: %s", err) 219 | } 220 | 221 | // wait until consumer go routine exits 222 | consumerWg.Wait() 223 | 224 | purgedTaskCount, err := inspector.Purge(ctx, true, tasq.GetTaskStatuses(tasq.OpenTasks), []string{taskType}, 
[]string{taskQueue}) 225 | if err != nil { 226 | log.Panicf("failed to purge tasq queue: %s", err) 227 | } 228 | 229 | log.Printf("purged %d open tasks from the queue", purgedTaskCount) 230 | } 231 | -------------------------------------------------------------------------------- /cleaner.go: -------------------------------------------------------------------------------- 1 | package tasq 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "time" 7 | ) 8 | 9 | const defaultTaskAgeLimit = 15 * time.Minute 10 | 11 | // Cleaner is a service instance created by a Client with reference to that client 12 | // and the task age limit parameter. 13 | type Cleaner struct { 14 | client *Client 15 | 16 | taskAgeLimit time.Duration 17 | } 18 | 19 | // NewCleaner creates a new cleaner with a reference to the original tasq client. 20 | func (c *Client) NewCleaner() *Cleaner { 21 | return &Cleaner{ 22 | client: c, 23 | 24 | taskAgeLimit: defaultTaskAgeLimit, 25 | } 26 | } 27 | 28 | // WithTaskAge defines the minimum time duration that must have passed since the creation of a finished task 29 | // in order for it to be eligible for cleanup when the Cleaner's Clean() method is called. 30 | // 31 | // Default value: 15 minutes. 32 | func (c *Cleaner) WithTaskAge(taskAge time.Duration) *Cleaner { 33 | c.taskAgeLimit = taskAge 34 | 35 | return c 36 | } 37 | 38 | // Clean will initiate the removal of finished (either succeeded or failed) tasks from the tasks table 39 | // if they have been created long enough ago for them to be eligible. 
40 | func (c *Cleaner) Clean(ctx context.Context) (int64, error) { 41 | cleanedTaskCount, err := c.client.repository.CleanTasks(ctx, c.taskAgeLimit) 42 | if err != nil { 43 | return 0, fmt.Errorf("failed to clean tasks: %w", err) 44 | } 45 | 46 | return cleanedTaskCount, nil 47 | } 48 | -------------------------------------------------------------------------------- /cleaner_test.go: -------------------------------------------------------------------------------- 1 | package tasq_test 2 | 3 | import ( 4 | "context" 5 | "testing" 6 | "time" 7 | 8 | "github.com/greencoda/tasq" 9 | "github.com/greencoda/tasq/mocks" 10 | "github.com/stretchr/testify/assert" 11 | "github.com/stretchr/testify/require" 12 | "github.com/stretchr/testify/suite" 13 | ) 14 | 15 | type CleanerTestSuite struct { 16 | suite.Suite 17 | mockRepository *mocks.IRepository 18 | tasqClient *tasq.Client 19 | tasqCleaner *tasq.Cleaner 20 | } 21 | 22 | func TestCleanerTestSuite(t *testing.T) { 23 | t.Parallel() 24 | 25 | suite.Run(t, new(CleanerTestSuite)) 26 | } 27 | 28 | func (s *CleanerTestSuite) SetupTest() { 29 | s.mockRepository = mocks.NewIRepository(s.T()) 30 | 31 | s.tasqClient = tasq.NewClient(s.mockRepository) 32 | require.NotNil(s.T(), s.tasqClient) 33 | 34 | s.tasqCleaner = s.tasqClient.NewCleaner().WithTaskAge(time.Hour) 35 | } 36 | 37 | func (s *CleanerTestSuite) TestNewCleaner() { 38 | assert.NotNil(s.T(), s.tasqCleaner) 39 | } 40 | 41 | func (s *CleanerTestSuite) TestClean() { 42 | ctx := context.Background() 43 | 44 | s.mockRepository.On("CleanTasks", ctx, time.Hour).Return(int64(1), nil).Once() 45 | 46 | rowsAffected, err := s.tasqCleaner.Clean(ctx) 47 | 48 | assert.Equal(s.T(), int64(1), rowsAffected) 49 | assert.True(s.T(), s.mockRepository.AssertCalled(s.T(), "CleanTasks", ctx, time.Hour)) 50 | assert.Nil(s.T(), err) 51 | 52 | s.mockRepository.On("CleanTasks", ctx, time.Hour).Return(int64(0), errRepository).Once() 53 | rowsAffected, err = s.tasqCleaner.Clean(ctx) 54 | 55 | 
assert.Equal(s.T(), int64(0), rowsAffected) 56 | assert.True(s.T(), s.mockRepository.AssertCalled(s.T(), "CleanTasks", ctx, time.Hour)) 57 | assert.NotNil(s.T(), err) 58 | } 59 | -------------------------------------------------------------------------------- /client.go: -------------------------------------------------------------------------------- 1 | // Package tasq provides a task queue implementation compapible with multiple repositories 2 | package tasq 3 | 4 | // Client wraps the tasq repository interface which is used 5 | // by the different services to access the database. 6 | type Client struct { 7 | repository IRepository 8 | } 9 | 10 | // NewClient creates a new tasq client instance with the provided tasq. 11 | func NewClient(repository IRepository) *Client { 12 | return &Client{ 13 | repository: repository, 14 | } 15 | } 16 | -------------------------------------------------------------------------------- /client_test.go: -------------------------------------------------------------------------------- 1 | package tasq_test 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/greencoda/tasq" 7 | "github.com/greencoda/tasq/mocks" 8 | "github.com/stretchr/testify/assert" 9 | ) 10 | 11 | func TestNewClient(t *testing.T) { 12 | t.Parallel() 13 | 14 | repository := mocks.NewIRepository(t) 15 | 16 | tasqClient := tasq.NewClient(repository) 17 | assert.NotNil(t, tasqClient) 18 | } 19 | -------------------------------------------------------------------------------- /consumer.go: -------------------------------------------------------------------------------- 1 | package tasq 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "fmt" 7 | "io" 8 | "log" 9 | "sync" 10 | "time" 11 | 12 | "github.com/benbjohnson/clock" 13 | "github.com/google/uuid" 14 | ) 15 | 16 | // Collection of consumer errors. 
17 | var ( 18 | ErrConsumerAlreadyRunning = errors.New("consumer has already been started") 19 | ErrConsumerAlreadyStopped = errors.New("consumer has already been stopped") 20 | ErrCouldNotActivateTasks = errors.New("a number of tasks could not be activated") 21 | ErrCouldNotPollTasks = errors.New("could not poll tasks") 22 | ErrCouldNotPingTasks = errors.New("could not ping tasks") 23 | ErrTaskTypeAlreadyLearned = errors.New("task with this type already learned") 24 | ErrTaskTypeNotFound = errors.New("task with this type not found") 25 | ErrTaskTypeNotKnown = errors.New("task with this type is not known by this consumer") 26 | ErrUnknownPollStrategy = errors.New("unknown poll strategy") 27 | ErrVisibilityTimeoutTooShort = errors.New("visibility timeout must be longer than poll interval") 28 | ) 29 | 30 | // Logger is the interface used for event logging during task consumption. 31 | type Logger interface { 32 | Print(v ...any) 33 | Printf(format string, v ...any) 34 | } 35 | 36 | // HandlerFunc is the function signature for the handler functions that are used to process tasks. 37 | type HandlerFunc func(task *Task) error 38 | 39 | // PollStrategy is the label assigned to the ordering by which tasks are polled for consumption. 40 | type PollStrategy string 41 | 42 | // Collection of pollStrategies. 43 | const ( 44 | PollStrategyByCreatedAt PollStrategy = "pollByCreatedAt" // Poll by oldest tasks first 45 | PollStrategyByPriority PollStrategy = "pollByPriority" // Poll by highest priority task first 46 | ) 47 | 48 | const ( 49 | defaultQueue = "" 50 | defaultChannelSize = 10 51 | defaultPollInterval = 5 * time.Second 52 | defaultPollStrategy = PollStrategyByCreatedAt 53 | defaultPollLimit = 10 54 | defaultAutoDeleteOnSuccess = false 55 | defaultMaxActiveTasks = 10 56 | defaultVisibilityTimeout = 15 * time.Second 57 | ) 58 | 59 | // NoopLogger discards the log messages written to it. 
60 | func NoopLogger() *log.Logger { 61 | return log.New(io.Discard, "", 0) 62 | } 63 | 64 | // Consumer is a service instance created by a Client with reference to that client 65 | // and the various parameters that define the task consumption behaviour. 66 | type Consumer struct { 67 | running bool 68 | autoDeleteOnSuccess bool 69 | channelSize int 70 | pollLimit int 71 | maxActiveTasks int 72 | pollInterval time.Duration 73 | pollStrategy PollStrategy 74 | 75 | wg sync.WaitGroup 76 | 77 | channel chan *func() 78 | client *Client 79 | clock clock.Clock 80 | logger Logger 81 | 82 | handlerFuncMap map[string]HandlerFunc 83 | 84 | activeMutex sync.RWMutex 85 | activeTasks map[uuid.UUID]struct{} 86 | 87 | visibilityTimeout time.Duration 88 | queues []string 89 | 90 | stop chan struct{} 91 | } 92 | 93 | // NewConsumer creates a new consumer with a reference to the original tasq client 94 | // and default consumer parameters. 95 | func (c *Client) NewConsumer() *Consumer { 96 | return &Consumer{ 97 | running: false, 98 | autoDeleteOnSuccess: defaultAutoDeleteOnSuccess, 99 | channelSize: defaultChannelSize, 100 | pollLimit: defaultPollLimit, 101 | maxActiveTasks: defaultMaxActiveTasks, 102 | pollInterval: defaultPollInterval, 103 | pollStrategy: defaultPollStrategy, 104 | 105 | wg: sync.WaitGroup{}, 106 | 107 | channel: nil, 108 | client: c, 109 | clock: clock.New(), 110 | logger: NoopLogger(), 111 | 112 | handlerFuncMap: make(map[string]HandlerFunc), 113 | 114 | activeMutex: sync.RWMutex{}, 115 | activeTasks: make(map[uuid.UUID]struct{}), 116 | 117 | visibilityTimeout: defaultVisibilityTimeout, 118 | queues: []string{defaultQueue}, 119 | 120 | stop: make(chan struct{}, 1), 121 | } 122 | } 123 | 124 | // WithChannelSize sets the size of the buffered channel used for outputting the polled messages to. 125 | // 126 | // Default value: 10. 
127 | func (c *Consumer) WithChannelSize(channelSize int) *Consumer { 128 | c.channelSize = channelSize 129 | 130 | return c 131 | } 132 | 133 | // WithLogger sets the Logger interface that is used for event logging during task consumption. 134 | // 135 | // Default value: NoopLogger. 136 | func (c *Consumer) WithLogger(logger Logger) *Consumer { 137 | c.logger = logger 138 | 139 | return c 140 | } 141 | 142 | // WithPollInterval sets the interval at which the consumer will try and poll for new tasks to be executed 143 | // must not be greater than or equal to visibility timeout. 144 | // 145 | // Default value: 5 seconds. 146 | func (c *Consumer) WithPollInterval(pollInterval time.Duration) *Consumer { 147 | c.pollInterval = pollInterval 148 | 149 | return c 150 | } 151 | 152 | // WithPollLimit sets the maximum number of messages polled from the task queue. 153 | // 154 | // Default value: 10. 155 | func (c *Consumer) WithPollLimit(pollLimit int) *Consumer { 156 | c.pollLimit = pollLimit 157 | 158 | return c 159 | } 160 | 161 | // WithPollStrategy sets the ordering to be used when polling for tasks from the task queue. 162 | // 163 | // Default value: PollStrategyByCreatedAt. 164 | func (c *Consumer) WithPollStrategy(pollStrategy PollStrategy) *Consumer { 165 | c.pollStrategy = pollStrategy 166 | 167 | return c 168 | } 169 | 170 | // WithAutoDeleteOnSuccess sets whether successful tasks should be automatically deleted from the task queue 171 | // by the consumer. 172 | // 173 | // Default value: false. 174 | func (c *Consumer) WithAutoDeleteOnSuccess(autoDeleteOnSuccess bool) *Consumer { 175 | c.autoDeleteOnSuccess = autoDeleteOnSuccess 176 | 177 | return c 178 | } 179 | 180 | // WithMaxActiveTasks sets the maximum number of tasks a consumer can have enqueued at the same time 181 | // before polling for additional ones. 182 | // 183 | // Default value: 10. 
184 | func (c *Consumer) WithMaxActiveTasks(maxActiveTasks int) *Consumer { 185 | c.maxActiveTasks = maxActiveTasks 186 | 187 | return c 188 | } 189 | 190 | // WithVisibilityTimeout sets the duration by which each ping will extend a task's visibility timeout; 191 | // Once this timeout is up, a consumer instance may receive the task again. 192 | // 193 | // Default value: 15 seconds. 194 | func (c *Consumer) WithVisibilityTimeout(visibilityTimeout time.Duration) *Consumer { 195 | c.visibilityTimeout = visibilityTimeout 196 | 197 | return c 198 | } 199 | 200 | // WithQueues sets the queues from which the consumer may poll for tasks. 201 | // 202 | // Default value: empty slice of strings. 203 | func (c *Consumer) WithQueues(queues ...string) *Consumer { 204 | c.queues = queues 205 | 206 | return c 207 | } 208 | 209 | // Learn sets a handler function for the specified taskType. 210 | // If override is false and a handler function is already set for the specified 211 | // taskType, it'll return an error. 212 | func (c *Consumer) Learn(taskType string, f HandlerFunc, override bool) error { 213 | if _, exists := c.handlerFuncMap[taskType]; exists && !override { 214 | return fmt.Errorf("%w: %s", ErrTaskTypeAlreadyLearned, taskType) 215 | } 216 | 217 | c.handlerFuncMap[taskType] = f 218 | 219 | return nil 220 | } 221 | 222 | // Forget removes a handler function for the specified taskType from the map of 223 | // learned handler functions. 224 | // If the specified taskType does not exist, it'll return an error. 225 | func (c *Consumer) Forget(taskType string) error { 226 | if _, exists := c.handlerFuncMap[taskType]; !exists { 227 | return fmt.Errorf("%w: %s", ErrTaskTypeNotFound, taskType) 228 | } 229 | 230 | delete(c.handlerFuncMap, taskType) 231 | 232 | return nil 233 | } 234 | 235 | // Start launches the go routine which manages the pinging and polling of tasks 236 | // for the consumer, or returns an error if the consumer is not properly configured. 
237 | func (c *Consumer) Start(ctx context.Context) error { 238 | if c.isRunning() { 239 | return ErrConsumerAlreadyRunning 240 | } 241 | 242 | if c.visibilityTimeout <= c.pollInterval { 243 | return ErrVisibilityTimeoutTooShort 244 | } 245 | 246 | c.setRunning(true) 247 | 248 | c.channel = make(chan *func(), c.channelSize) 249 | 250 | ticker := c.clock.Ticker(c.pollInterval) 251 | 252 | go c.processLoop(ctx, ticker) 253 | 254 | return nil 255 | } 256 | 257 | // Stop sends the termination signal to the consumer so it'll no longer poll for news tasks. 258 | func (c *Consumer) Stop() error { 259 | if !c.isRunning() { 260 | return ErrConsumerAlreadyStopped 261 | } 262 | 263 | c.stop <- struct{}{} 264 | 265 | return nil 266 | } 267 | 268 | // Channel returns a read-only channel where the polled jobs can be read from. 269 | func (c *Consumer) Channel() <-chan *func() { 270 | return c.channel 271 | } 272 | 273 | func (c *Consumer) isRunning() bool { 274 | return c.running 275 | } 276 | 277 | func (c *Consumer) setRunning(isRunning bool) { 278 | c.running = isRunning 279 | } 280 | 281 | func (c *Consumer) registerTaskStart(ctx context.Context, task *Task) { 282 | _, err := c.client.repository.RegisterStart(ctx, task) 283 | if err != nil { 284 | panic(err) 285 | } 286 | } 287 | 288 | func (c *Consumer) registerTaskError(ctx context.Context, task *Task, taskError error) { 289 | _, err := c.client.repository.RegisterError(ctx, task, taskError) 290 | if err != nil { 291 | panic(err) 292 | } 293 | 294 | if task.MaxReceives > 0 && (task.ReceiveCount) >= task.MaxReceives { 295 | c.registerTaskFail(ctx, task) 296 | } else { 297 | c.requeueTask(ctx, task) 298 | } 299 | } 300 | 301 | func (c *Consumer) registerTaskSuccess(ctx context.Context, task *Task) { 302 | if c.autoDeleteOnSuccess { 303 | err := c.client.repository.DeleteTask(ctx, task, false) 304 | if err != nil { 305 | panic(err) 306 | } 307 | } else { 308 | _, err := c.client.repository.RegisterFinish(ctx, task, 
StatusSuccessful) 309 | if err != nil { 310 | panic(err) 311 | } 312 | } 313 | 314 | c.removeFromActiveTasks(task) 315 | } 316 | 317 | func (c *Consumer) registerTaskFail(ctx context.Context, task *Task) { 318 | _, err := c.client.repository.RegisterFinish(ctx, task, StatusFailed) 319 | if err != nil { 320 | panic(err) 321 | } 322 | 323 | c.removeFromActiveTasks(task) 324 | } 325 | 326 | func (c *Consumer) requeueTask(ctx context.Context, task *Task) { 327 | _, err := c.client.repository.RequeueTask(ctx, task) 328 | if err != nil { 329 | panic(err) 330 | } 331 | 332 | c.removeFromActiveTasks(task) 333 | } 334 | 335 | func (c *Consumer) getActiveTaskCount() int { 336 | return len(c.activeTasks) 337 | } 338 | 339 | func (c *Consumer) removeFromActiveTasks(task *Task) { 340 | c.activeMutex.Lock() 341 | delete(c.activeTasks, task.ID) 342 | c.activeMutex.Unlock() 343 | } 344 | 345 | func (c *Consumer) getActiveTaskIDs() []uuid.UUID { 346 | activeTaskIDs := make([]uuid.UUID, 0, len(c.activeTasks)) 347 | 348 | for taskID := range c.activeTasks { 349 | activeTaskIDs = append(activeTaskIDs, taskID) 350 | } 351 | 352 | return activeTaskIDs 353 | } 354 | 355 | func (c *Consumer) getKnownTaskTypes() []string { 356 | taskTypes := make([]string, 0, len(c.handlerFuncMap)) 357 | 358 | for taskType := range c.handlerFuncMap { 359 | taskTypes = append(taskTypes, taskType) 360 | } 361 | 362 | return taskTypes 363 | } 364 | 365 | func (c *Consumer) getPollOrdering() (Ordering, error) { 366 | switch c.pollStrategy { 367 | case PollStrategyByCreatedAt: 368 | return OrderingCreatedAtFirst, nil 369 | case PollStrategyByPriority: 370 | return OrderingPriorityFirst, nil 371 | default: 372 | return -1, fmt.Errorf("%w: %s", ErrUnknownPollStrategy, c.pollStrategy) 373 | } 374 | } 375 | 376 | func (c *Consumer) getPollQuantity() int { 377 | taskCapacity := c.maxActiveTasks - len(c.activeTasks) 378 | 379 | if c.pollLimit < taskCapacity { 380 | return c.pollLimit 381 | } 382 | 383 | return 
taskCapacity 384 | } 385 | 386 | func (c *Consumer) processLoop(ctx context.Context, ticker *clock.Ticker) { 387 | c.wg.Add(1) 388 | defer c.wg.Done() 389 | defer c.logger.Print("processing stopped") 390 | defer ticker.Stop() 391 | 392 | var ( 393 | tasks []*Task 394 | err error 395 | ) 396 | 397 | for { 398 | err = c.pingActiveTasks(ctx) 399 | if err != nil { 400 | c.logger.Printf("error pinging active tasks: %s", err) 401 | } 402 | 403 | if c.isRunning() { 404 | tasks, err = c.pollForTasks(ctx) 405 | if err != nil { 406 | c.logger.Printf("error polling for tasks: %s", err) 407 | } 408 | 409 | err = c.activateTasks(ctx, tasks) 410 | if err != nil { 411 | c.logger.Printf("error activating tasks: %s", err) 412 | } 413 | } else if c.getActiveTaskCount() == 0 { 414 | return 415 | } 416 | 417 | select { 418 | case <-c.stop: 419 | c.setRunning(false) 420 | close(c.channel) 421 | case <-ticker.C: 422 | continue 423 | } 424 | } 425 | } 426 | 427 | func (c *Consumer) pollForTasks(ctx context.Context) ([]*Task, error) { 428 | pollOrdering, err := c.getPollOrdering() 429 | if err != nil { 430 | return nil, err 431 | } 432 | 433 | tasks, err := c.client.repository.PollTasks(ctx, c.getKnownTaskTypes(), c.queues, c.visibilityTimeout, pollOrdering, c.getPollQuantity()) 434 | if err != nil { 435 | return nil, fmt.Errorf("%w: %s", ErrCouldNotPollTasks, err) 436 | } 437 | 438 | return tasks, nil 439 | } 440 | 441 | func (c *Consumer) pingActiveTasks(ctx context.Context) error { 442 | _, err := c.client.repository.PingTasks(ctx, c.getActiveTaskIDs(), c.visibilityTimeout) 443 | if err != nil { 444 | return fmt.Errorf("%w: %s", ErrCouldNotPingTasks, err) 445 | } 446 | 447 | return nil 448 | } 449 | 450 | func (c *Consumer) activateTasks(ctx context.Context, tasks []*Task) error { 451 | var errors []error 452 | 453 | for _, task := range tasks { 454 | err := c.activateTask(ctx, task) 455 | if err != nil { 456 | errors = append(errors, err) 457 | 458 | c.registerTaskFail(ctx, task) 459 
| } 460 | } 461 | 462 | if len(errors) > 0 { 463 | return fmt.Errorf("%w: %v", ErrCouldNotActivateTasks, len(errors)) 464 | } 465 | 466 | return nil 467 | } 468 | 469 | func (c *Consumer) activateTask(ctx context.Context, task *Task) error { 470 | job, err := c.createJobFromTask(ctx, task) 471 | if err != nil { 472 | return err 473 | } 474 | 475 | c.activeMutex.Lock() 476 | c.activeTasks[task.ID] = struct{}{} 477 | c.activeMutex.Unlock() 478 | 479 | c.channel <- job 480 | 481 | return nil 482 | } 483 | 484 | func (c *Consumer) createJobFromTask(ctx context.Context, task *Task) (*func(), error) { 485 | if handlerFunc, ok := c.handlerFuncMap[task.Type]; ok { 486 | return c.newJob(ctx, c, handlerFunc, task), nil 487 | } 488 | 489 | return nil, fmt.Errorf("%w: %s", ErrTaskTypeNotKnown, task.Type) 490 | } 491 | 492 | func (c *Consumer) newJob(ctx context.Context, consumer *Consumer, f HandlerFunc, task *Task) *func() { 493 | job := func() { 494 | consumer.registerTaskStart(ctx, task) 495 | 496 | if err := f(task); err == nil { 497 | consumer.registerTaskSuccess(ctx, task) 498 | } else { 499 | consumer.registerTaskError(ctx, task, err) 500 | } 501 | } 502 | 503 | return &job 504 | } 505 | -------------------------------------------------------------------------------- /consumer_test.go: -------------------------------------------------------------------------------- 1 | package tasq_test 2 | 3 | import ( 4 | "bytes" 5 | "context" 6 | "errors" 7 | "log" 8 | "testing" 9 | "time" 10 | 11 | "github.com/benbjohnson/clock" 12 | "github.com/google/uuid" 13 | "github.com/greencoda/tasq" 14 | "github.com/greencoda/tasq/mocks" 15 | "github.com/stretchr/testify/assert" 16 | "github.com/stretchr/testify/mock" 17 | "github.com/stretchr/testify/require" 18 | "github.com/stretchr/testify/suite" 19 | ) 20 | 21 | var ( 22 | errTaskFail = errors.New("task failed") 23 | errRepository = errors.New("repository error") 24 | ) 25 | 26 | type ConsumerTestSuite struct { 27 | suite.Suite 28 | 
mockRepository *mocks.IRepository 29 | mockClock *clock.Mock 30 | tasqClient *tasq.Client 31 | tasqConsumer *tasq.Consumer 32 | logBuffer bytes.Buffer 33 | } 34 | 35 | func TestConsumerTestSuite(t *testing.T) { 36 | t.Parallel() 37 | 38 | suite.Run(t, new(ConsumerTestSuite)) 39 | } 40 | 41 | func (s *ConsumerTestSuite) SetupTest() { 42 | s.mockRepository = mocks.NewIRepository(s.T()) 43 | s.mockClock = clock.NewMock() 44 | s.mockClock.Set(time.Now()) 45 | 46 | s.tasqClient = tasq.NewClient(s.mockRepository) 47 | require.NotNil(s.T(), s.tasqClient) 48 | 49 | s.tasqConsumer = s.tasqClient.NewConsumer().WithLogger(log.New(&s.logBuffer, "", 0)).SetClock(s.mockClock) 50 | require.NotNil(s.T(), s.tasqConsumer) 51 | 52 | s.logBuffer.Reset() 53 | } 54 | 55 | func (s *ConsumerTestSuite) TestNewConsumer() { 56 | assert.NotNil(s.T(), s.tasqConsumer. 57 | WithAutoDeleteOnSuccess(true). 58 | WithChannelSize(10). 59 | WithMaxActiveTasks(10). 60 | WithPollInterval(10*time.Second). 61 | WithPollLimit(10). 62 | WithPollStrategy(tasq.PollStrategyByCreatedAt). 63 | WithQueues("testQueue"). 
64 | WithVisibilityTimeout(30*time.Second)) 65 | 66 | assert.Empty(s.T(), s.logBuffer.String()) 67 | } 68 | 69 | func (s *ConsumerTestSuite) TestLearnAndForget() { 70 | // Learning a new task execution method is successful 71 | err := s.tasqConsumer.Learn("testTask", func(task *tasq.Task) error { 72 | return nil 73 | }, false) 74 | assert.Nil(s.T(), err) 75 | 76 | // Learning an already learned task execution method returns an error 77 | err = s.tasqConsumer.Learn("testTask", func(task *tasq.Task) error { 78 | return nil 79 | }, false) 80 | assert.NotNil(s.T(), err) 81 | 82 | // Learning an already learned task execution method with override being true is successful 83 | err = s.tasqConsumer.Learn("testTask", func(task *tasq.Task) error { 84 | return nil 85 | }, true) 86 | assert.Nil(s.T(), err) 87 | 88 | // Forgetting an already learned task execution method is successful 89 | err = s.tasqConsumer.Forget("testTask") 90 | assert.Nil(s.T(), err) 91 | 92 | // Forgetting an unknown execution method is successful 93 | err = s.tasqConsumer.Forget("anotherTestTask") 94 | assert.NotNil(s.T(), err) 95 | 96 | assert.Empty(s.T(), s.logBuffer.String()) 97 | } 98 | 99 | func (s *ConsumerTestSuite) TestStartWithInvalidVisibilityTimeoutParam() { 100 | ctx := context.Background() 101 | 102 | s.tasqConsumer. 103 | WithVisibilityTimeout(time.Second). 104 | WithPollInterval(5 * time.Second) 105 | 106 | // Start up the consumer 107 | err := s.tasqConsumer.Start(ctx) 108 | assert.NotNil(s.T(), err) 109 | 110 | assert.Empty(s.T(), s.logBuffer.String()) 111 | } 112 | 113 | func (s *ConsumerTestSuite) TestStartStopTwice() { 114 | ctx := context.Background() 115 | 116 | s.tasqConsumer. 
117 | WithQueues("testQueue") 118 | 119 | err := s.tasqConsumer.Learn("testTask", func(task *tasq.Task) error { 120 | return nil 121 | }, false) 122 | assert.Nil(s.T(), err) 123 | 124 | // Getting tasks 125 | s.mockRepository.On("PingTasks", ctx, []uuid.UUID{}, 15*time.Second).Return([]*tasq.Task{}, nil) 126 | 127 | // Polling fails on the first time 128 | s.mockRepository.On("PollTasks", ctx, []string{"testTask"}, []string{"testQueue"}, 15*time.Second, tasq.OrderingCreatedAtFirst, 10).Return([]*tasq.Task{}, errRepository).Once() 129 | 130 | // Polling succeeds on subsequent attempts 131 | s.mockRepository.On("PollTasks", ctx, []string{"testTask"}, []string{"testQueue"}, 15*time.Second, tasq.OrderingCreatedAtFirst, 10).Return([]*tasq.Task{}, nil) 132 | 133 | // Start up the consumer 134 | err = s.tasqConsumer.Start(ctx) 135 | assert.Nil(s.T(), err) 136 | 137 | // Start up the consumer again 138 | err = s.tasqConsumer.Start(ctx) 139 | assert.NotNil(s.T(), err) 140 | 141 | // Stop the consumer 142 | err = s.tasqConsumer.Stop() 143 | assert.Nil(s.T(), err) 144 | 145 | s.mockClock.Add(5 * time.Second) 146 | 147 | // Stop the consumer again 148 | err = s.tasqConsumer.Stop() 149 | assert.NotNil(s.T(), err) 150 | 151 | // Wait for goroutine to actually return and output log message 152 | s.tasqConsumer.GetWaitGroup().Wait() 153 | 154 | assert.Equal(s.T(), "error polling for tasks: could not poll tasks: repository error\nprocessing stopped\n", s.logBuffer.String()) 155 | } 156 | 157 | func (s *ConsumerTestSuite) TestConsumption() { 158 | ctx := context.Background() 159 | 160 | s.tasqConsumer. 
161 | WithQueues("testQueue") 162 | 163 | var ( 164 | successTestArgs = "success" 165 | failTestArgs = "fail" 166 | ) 167 | 168 | successTestTask, err := tasq.NewTask("testTask", successTestArgs, "testQueue", 100, 5) 169 | require.NotNil(s.T(), successTestTask) 170 | require.Nil(s.T(), err) 171 | 172 | failTestTask, err := tasq.NewTask("testTask", failTestArgs, "testQueue", 100, 5) 173 | require.NotNil(s.T(), failTestTask) 174 | require.Nil(s.T(), err) 175 | 176 | failNoRequeueTestTask, err := tasq.NewTask("testTask", failTestArgs, "testQueue", 100, 1) 177 | require.NotNil(s.T(), failNoRequeueTestTask) 178 | require.Nil(s.T(), err) 179 | 180 | err = s.tasqConsumer.Learn("testTask", func(task *tasq.Task) error { 181 | var args string 182 | 183 | err := task.UnmarshalArgs(&args) 184 | require.Nil(s.T(), err) 185 | 186 | if args == successTestArgs { 187 | return nil 188 | } 189 | 190 | return errTaskFail 191 | }, false) 192 | require.Nil(s.T(), err) 193 | 194 | // Increment receive counts 195 | for _, task := range []*tasq.Task{ 196 | successTestTask, 197 | failTestTask, 198 | failNoRequeueTestTask, 199 | } { 200 | task.ReceiveCount++ 201 | } 202 | 203 | // Getting tasks 204 | 205 | // First try - pinging fails 206 | s.mockRepository.On("PingTasks", ctx, []uuid.UUID{}, 15*time.Second).Once().Return([]*tasq.Task{}, errRepository) 207 | 208 | // Second try - pinging succeeds 209 | s.mockRepository.On("PingTasks", ctx, []uuid.UUID{}, 15*time.Second).Once().Return([]*tasq.Task{}, nil) 210 | s.mockRepository.On("PollTasks", ctx, []string{"testTask"}, []string{"testQueue"}, 15*time.Second, tasq.OrderingCreatedAtFirst, 10).Return([]*tasq.Task{ 211 | successTestTask, 212 | failTestTask, 213 | failNoRequeueTestTask, 214 | }, nil) 215 | 216 | // Start up the consumer 217 | err = s.tasqConsumer.Start(ctx) 218 | assert.Nil(s.T(), err) 219 | 220 | // Read the successful job from the consumer channel 221 | successJob := <-s.tasqConsumer.Channel() 222 | assert.NotNil(s.T(), 
successJob) 223 | 224 | // First try - registering task start fails 225 | s.mockRepository.On("RegisterStart", ctx, successTestTask).Once().Return(nil, errRepository) 226 | assert.Panics(s.T(), func() { 227 | (*successJob)() 228 | }) 229 | 230 | // Second try - registering task success fails 231 | s.mockRepository.On("RegisterStart", ctx, successTestTask).Once().Return(successTestTask, nil) 232 | s.mockRepository.On("RegisterFinish", ctx, successTestTask, tasq.StatusSuccessful).Once().Return(nil, errRepository) 233 | assert.Panics(s.T(), func() { 234 | (*successJob)() 235 | }) 236 | 237 | // Third try - repository succeeds 238 | s.mockRepository.On("RegisterStart", ctx, successTestTask).Once().Return(successTestTask, nil) 239 | s.mockRepository.On("RegisterFinish", ctx, successTestTask, tasq.StatusSuccessful).Once().Return(successTestTask, nil) 240 | assert.NotPanics(s.T(), func() { 241 | (*successJob)() 242 | }) 243 | 244 | // Read the failing job from the consumer channel 245 | failJob := <-s.tasqConsumer.Channel() 246 | assert.NotNil(s.T(), failJob) 247 | 248 | // First try - registering task error fails 249 | s.mockRepository.On("RegisterStart", ctx, failTestTask).Once().Return(failTestTask, nil) 250 | s.mockRepository.On("RegisterError", ctx, failTestTask, errTaskFail).Once().Return(nil, errRepository) 251 | assert.Panics(s.T(), func() { 252 | (*failJob)() 253 | }) 254 | 255 | // Second try - requeuing task fails 256 | s.mockRepository.On("RegisterStart", ctx, failTestTask).Once().Return(failTestTask, nil) 257 | s.mockRepository.On("RegisterError", ctx, failTestTask, errTaskFail).Once().Return(failTestTask, nil) 258 | s.mockRepository.On("RequeueTask", ctx, failTestTask).Once().Return(nil, errRepository) 259 | assert.Panics(s.T(), func() { 260 | (*failJob)() 261 | }) 262 | 263 | // Third try - repository succeeds 264 | s.mockRepository.On("RegisterStart", ctx, failTestTask).Once().Return(failTestTask, nil) 265 | s.mockRepository.On("RegisterError", ctx, 
failTestTask, errTaskFail).Once().Return(failTestTask, nil) 266 | s.mockRepository.On("RequeueTask", ctx, failTestTask).Once().Return(failTestTask, nil) 267 | assert.NotPanics(s.T(), func() { 268 | (*failJob)() 269 | }) 270 | 271 | // Read the failing job that shouldn't be requeued from the consumer channel 272 | failNoRequeueJob := <-s.tasqConsumer.Channel() 273 | assert.NotNil(s.T(), failNoRequeueJob) 274 | 275 | // First try - registering task failure fails 276 | s.mockRepository.On("RegisterStart", ctx, failNoRequeueTestTask).Once().Return(failNoRequeueTestTask, nil) 277 | s.mockRepository.On("RegisterError", ctx, failNoRequeueTestTask, errTaskFail).Once().Return(nil, errRepository) 278 | assert.Panics(s.T(), func() { 279 | (*failNoRequeueJob)() 280 | }) 281 | 282 | // Second try - registering task failure fails 283 | s.mockRepository.On("RegisterStart", ctx, failNoRequeueTestTask).Once().Return(failNoRequeueTestTask, nil) 284 | s.mockRepository.On("RegisterError", ctx, failNoRequeueTestTask, errTaskFail).Once().Return(failNoRequeueTestTask, nil) 285 | s.mockRepository.On("RegisterFinish", ctx, failNoRequeueTestTask, tasq.StatusFailed).Once().Return(nil, errRepository) 286 | assert.Panics(s.T(), func() { 287 | (*failNoRequeueJob)() 288 | }) 289 | 290 | // Third try - repository succeeds 291 | s.mockRepository.On("RegisterStart", ctx, failNoRequeueTestTask).Once().Return(failNoRequeueTestTask, nil) 292 | s.mockRepository.On("RegisterError", ctx, failNoRequeueTestTask, errTaskFail).Once().Return(failNoRequeueTestTask, nil) 293 | s.mockRepository.On("RegisterFinish", ctx, failNoRequeueTestTask, tasq.StatusFailed).Once().Return(failNoRequeueTestTask, nil) 294 | assert.NotPanics(s.T(), func() { 295 | (*failNoRequeueJob)() 296 | }) 297 | 298 | // Stop the consumer 299 | err = s.tasqConsumer.Stop() 300 | assert.Nil(s.T(), err) 301 | 302 | // Wait until channel is closed 303 | <-s.tasqConsumer.Channel() 304 | 305 | s.mockClock.Add(10 * time.Second) 306 | 307 | // Wait 
for goroutine to actually return and output log message 308 | s.tasqConsumer.GetWaitGroup().Wait() 309 | 310 | assert.Equal(s.T(), "error pinging active tasks: could not ping tasks: repository error\nprocessing stopped\n", s.logBuffer.String()) 311 | } 312 | 313 | func (s *ConsumerTestSuite) TestConsumptionWithAutoDeleteOnSuccess() { 314 | ctx := context.Background() 315 | 316 | s.tasqConsumer. 317 | WithQueues("testQueue"). 318 | WithAutoDeleteOnSuccess(true) 319 | 320 | successTestTask, err := tasq.NewTask("testTask", true, "testQueue", 100, 5) 321 | require.NotNil(s.T(), successTestTask) 322 | require.Nil(s.T(), err) 323 | 324 | err = s.tasqConsumer.Learn("testTask", func(task *tasq.Task) error { 325 | return nil 326 | }, false) 327 | require.Nil(s.T(), err) 328 | 329 | successTestTask.ReceiveCount++ 330 | 331 | // Getting tasks 332 | s.mockRepository.On("PingTasks", ctx, []uuid.UUID{}, 15*time.Second).Twice().Return([]*tasq.Task{}, nil) 333 | s.mockRepository.On("PollTasks", ctx, []string{"testTask"}, []string{"testQueue"}, 15*time.Second, tasq.OrderingCreatedAtFirst, 10).Return([]*tasq.Task{ 334 | successTestTask, 335 | }, nil) 336 | 337 | // Start up the consumer 338 | err = s.tasqConsumer.Start(ctx) 339 | assert.Nil(s.T(), err) 340 | 341 | // Read the successful job from the consumer channel 342 | successJob := <-s.tasqConsumer.Channel() 343 | assert.NotNil(s.T(), successJob) 344 | 345 | // First try - deleting task fails 346 | s.mockRepository.On("RegisterStart", ctx, successTestTask).Once().Return(successTestTask, nil) 347 | s.mockRepository.On("DeleteTask", ctx, successTestTask, false).Once().Return(errRepository) 348 | assert.Panics(s.T(), func() { 349 | (*successJob)() 350 | }) 351 | 352 | // Second try - deleting task succeeds 353 | s.mockRepository.On("RegisterStart", ctx, successTestTask).Once().Return(successTestTask, nil) 354 | s.mockRepository.On("DeleteTask", ctx, successTestTask, false).Once().Return(nil) 355 | assert.NotPanics(s.T(), func() { 
356 | (*successJob)() 357 | }) 358 | 359 | // Stop the consumer 360 | err = s.tasqConsumer.Stop() 361 | assert.Nil(s.T(), err) 362 | 363 | // Wait until channel is closed 364 | <-s.tasqConsumer.Channel() 365 | 366 | s.mockClock.Add(5 * time.Second) 367 | 368 | // Wait for goroutine to actually return and output log message 369 | s.tasqConsumer.GetWaitGroup().Wait() 370 | 371 | assert.Equal(s.T(), "processing stopped\n", s.logBuffer.String()) 372 | } 373 | 374 | func (s *ConsumerTestSuite) TestConsumptionWithPollStrategyByPriority() { 375 | ctx := context.Background() 376 | 377 | s.tasqConsumer. 378 | WithQueues("testQueue"). 379 | WithPollStrategy(tasq.PollStrategyByPriority) 380 | 381 | successTestTask, err := tasq.NewTask("testTask", true, "testQueue", 100, 5) 382 | require.NotNil(s.T(), successTestTask) 383 | require.Nil(s.T(), err) 384 | 385 | err = s.tasqConsumer.Learn("testTask", func(task *tasq.Task) error { 386 | return nil 387 | }, false) 388 | require.Nil(s.T(), err) 389 | 390 | successTestTask.ReceiveCount++ 391 | 392 | // Getting tasks 393 | s.mockRepository.On("PingTasks", ctx, []uuid.UUID{}, 15*time.Second).Twice().Return([]*tasq.Task{}, nil) 394 | s.mockRepository.On("PollTasks", ctx, []string{"testTask"}, []string{"testQueue"}, 15*time.Second, tasq.OrderingPriorityFirst, 10).Return([]*tasq.Task{ 395 | successTestTask, 396 | }, nil) 397 | 398 | // Start up the consumer 399 | err = s.tasqConsumer.Start(ctx) 400 | assert.Nil(s.T(), err) 401 | 402 | // Read the successful job from the consumer channel 403 | successJob := <-s.tasqConsumer.Channel() 404 | assert.NotNil(s.T(), successJob) 405 | 406 | // First try - repository succeeds 407 | s.mockRepository.On("RegisterStart", ctx, successTestTask).Once().Return(successTestTask, nil) 408 | s.mockRepository.On("RegisterFinish", ctx, successTestTask, tasq.StatusSuccessful).Once().Return(successTestTask, nil) 409 | assert.NotPanics(s.T(), func() { 410 | (*successJob)() 411 | }) 412 | 413 | // Stop the 
consumer 414 | err = s.tasqConsumer.Stop() 415 | assert.Nil(s.T(), err) 416 | 417 | // Wait until channel is closed 418 | <-s.tasqConsumer.Channel() 419 | 420 | s.mockClock.Add(5 * time.Second) 421 | 422 | // Wait for goroutine to actually return and output log message 423 | s.tasqConsumer.GetWaitGroup().Wait() 424 | 425 | assert.Equal(s.T(), "processing stopped\n", s.logBuffer.String()) 426 | } 427 | 428 | func (s *ConsumerTestSuite) TestConsumptionWithUnknownPollStrategy() { 429 | ctx := context.Background() 430 | 431 | s.tasqConsumer. 432 | WithQueues("testQueue"). 433 | WithPollStrategy(tasq.PollStrategy("pollByMagic")) 434 | 435 | // Getting tasks 436 | s.mockRepository.On("PingTasks", ctx, []uuid.UUID{}, 15*time.Second).Twice().Return([]*tasq.Task{}, nil) 437 | 438 | // Start up the consumer 439 | err := s.tasqConsumer.Start(ctx) 440 | assert.Nil(s.T(), err) 441 | 442 | // Stop the consumer 443 | err = s.tasqConsumer.Stop() 444 | assert.Nil(s.T(), err) 445 | 446 | // Wait until channel is closed 447 | <-s.tasqConsumer.Channel() 448 | 449 | s.mockClock.Add(5 * time.Second) 450 | 451 | // Wait for goroutine to actually return and output log message 452 | s.tasqConsumer.GetWaitGroup().Wait() 453 | 454 | assert.Equal(s.T(), "error polling for tasks: unknown poll strategy: pollByMagic\nprocessing stopped\n", s.logBuffer.String()) 455 | } 456 | 457 | func (s *ConsumerTestSuite) TestConsumptionOfUnknownTaskType() { 458 | ctx := context.Background() 459 | 460 | s.tasqConsumer. 
461 | WithQueues("testQueue") 462 | 463 | anotherTestTask, err := tasq.NewTask("anotherTestTask", true, "testQueue", 100, 5) 464 | require.NotNil(s.T(), anotherTestTask) 465 | require.Nil(s.T(), err) 466 | 467 | err = s.tasqConsumer.Learn("testTask", func(task *tasq.Task) error { 468 | return nil 469 | }, false) 470 | require.Nil(s.T(), err) 471 | 472 | // Getting tasks 473 | s.mockRepository.On("PingTasks", ctx, []uuid.UUID{}, 15*time.Second).Twice().Return([]*tasq.Task{}, nil) 474 | s.mockRepository.On("PollTasks", ctx, []string{"testTask"}, []string{"testQueue"}, 15*time.Second, tasq.OrderingCreatedAtFirst, 10).Return([]*tasq.Task{ 475 | anotherTestTask, 476 | }, nil) 477 | 478 | // Start up the consumer 479 | err = s.tasqConsumer.Start(ctx) 480 | assert.Nil(s.T(), err) 481 | 482 | s.mockRepository.On("RegisterFinish", ctx, anotherTestTask, tasq.StatusFailed).Return(anotherTestTask, nil) 483 | 484 | // Stop the consumer 485 | err = s.tasqConsumer.Stop() 486 | assert.Nil(s.T(), err) 487 | 488 | // Wait until channel is closed 489 | <-s.tasqConsumer.Channel() 490 | 491 | s.mockClock.Add(5 * time.Second) 492 | 493 | // Wait for goroutine to actually return and output log message 494 | s.tasqConsumer.GetWaitGroup().Wait() 495 | 496 | assert.Equal(s.T(), "error activating tasks: a number of tasks could not be activated: 1\nprocessing stopped\n", s.logBuffer.String()) 497 | } 498 | 499 | func (s *ConsumerTestSuite) TestLoopingConsumption() { 500 | s.tasqConsumer. 501 | WithQueues("testQueue"). 502 | WithPollInterval(5 * time.Second). 503 | WithPollLimit(1). 
504 | WithMaxActiveTasks(2) 505 | 506 | var ( 507 | ctx = context.Background() 508 | 509 | testTaskID1 = uuid.MustParse("1ada263f-61d5-44ac-b99d-2d5ad4f249de") 510 | testTaskID2 = uuid.MustParse("28032675-bc13-4dcd-8ec6-6aa430fc466a") 511 | 512 | testTasks = map[uuid.UUID]*tasq.Task{ 513 | testTaskID1: { 514 | ID: testTaskID1, 515 | Type: "testTask", 516 | Args: []uint8{0x3, 0x2, 0x0, 0x1}, 517 | Queue: "testQueue", 518 | Priority: 100, 519 | Status: tasq.StatusNew, 520 | ReceiveCount: 0, 521 | MaxReceives: 5, 522 | CreatedAt: s.mockClock.Now(), 523 | VisibleAt: s.mockClock.Now(), 524 | }, 525 | testTaskID2: { 526 | ID: testTaskID2, 527 | Type: "testTask", 528 | Args: []uint8{0x3, 0x2, 0x0, 0x1}, 529 | Queue: "testQueue", 530 | Priority: 100, 531 | Status: tasq.StatusNew, 532 | ReceiveCount: 0, 533 | MaxReceives: 5, 534 | CreatedAt: s.mockClock.Now(), 535 | VisibleAt: s.mockClock.Now(), 536 | }, 537 | } 538 | ) 539 | 540 | err := s.tasqConsumer.Learn("testTask", func(task *tasq.Task) error { 541 | return nil 542 | }, false) 543 | require.Nil(s.T(), err) 544 | 545 | // Respond to pings 546 | pingCall := s.mockRepository.On("PingTasks", ctx, mock.AnythingOfType("[]uuid.UUID"), 15*time.Second) 547 | pingCall.Run(func(args mock.Arguments) { 548 | inputTestIDs, ok := args[1].([]uuid.UUID) 549 | require.True(s.T(), ok) 550 | 551 | taskIDs, ok := args[1].([]uuid.UUID) 552 | require.True(s.T(), ok) 553 | 554 | returnTasks := make([]*tasq.Task, 0, len(inputTestIDs)) 555 | 556 | for _, taskID := range taskIDs { 557 | returnTask, ok := testTasks[taskID] 558 | require.True(s.T(), ok) 559 | 560 | returnTasks = append(returnTasks, returnTask) 561 | } 562 | 563 | pingCall.ReturnArguments = mock.Arguments{returnTasks, nil} 564 | }) 565 | 566 | // Respond to polls 567 | // First call 568 | s.mockRepository.On("PollTasks", ctx, []string{"testTask"}, []string{"testQueue"}, 15*time.Second, tasq.OrderingCreatedAtFirst, 1).Once(). 
569 | Return([]*tasq.Task{ 570 | testTasks[testTaskID1], 571 | }, nil) 572 | 573 | // Second call 574 | s.mockRepository.On("PollTasks", ctx, []string{"testTask"}, []string{"testQueue"}, 15*time.Second, tasq.OrderingCreatedAtFirst, 1).Once(). 575 | Return([]*tasq.Task{ 576 | testTasks[testTaskID2], 577 | }, nil) 578 | 579 | // Start up the consumer 580 | err = s.tasqConsumer.Start(ctx) 581 | assert.Nil(s.T(), err) 582 | 583 | s.mockClock.Add(5 * time.Second) 584 | 585 | // Stop the consumer 586 | err = s.tasqConsumer.Stop() 587 | assert.Nil(s.T(), err) 588 | 589 | // Mock job handling 590 | s.mockRepository.On("RegisterStart", ctx, testTasks[testTaskID1]).Once(). 591 | Return(testTasks[testTaskID1], nil) 592 | s.mockRepository.On("RegisterFinish", ctx, testTasks[testTaskID1], tasq.StatusSuccessful).Once(). 593 | Return(testTasks[testTaskID1], nil) 594 | 595 | s.mockRepository.On("RegisterStart", ctx, testTasks[testTaskID2]).Once(). 596 | Return(testTasks[testTaskID2], nil) 597 | s.mockRepository.On("RegisterFinish", ctx, testTasks[testTaskID2], tasq.StatusSuccessful).Once(). 
598 | Return(testTasks[testTaskID2], nil) 599 | 600 | // Drain channel of jobs 601 | for job := range s.tasqConsumer.Channel() { 602 | currentJob := job 603 | 604 | assert.NotPanics(s.T(), func() { 605 | (*currentJob)() 606 | }) 607 | } 608 | 609 | // Let 5 seconds pass so that the goroutine has a chance to finish its last loop 610 | s.mockClock.Add(5 * time.Second) 611 | 612 | // Wait for goroutine to actually return and output log message 613 | s.tasqConsumer.GetWaitGroup().Wait() 614 | 615 | assert.Equal(s.T(), "processing stopped\n", s.logBuffer.String()) 616 | } 617 | -------------------------------------------------------------------------------- /export_test.go: -------------------------------------------------------------------------------- 1 | package tasq 2 | 3 | import ( 4 | "sync" 5 | 6 | "github.com/benbjohnson/clock" 7 | ) 8 | 9 | // SetClock needs to be exported so we can adjust the clock during tests 10 | // to avoid failures due to different results of time.Now(). 11 | func (c *Consumer) SetClock(clock clock.Clock) *Consumer { 12 | c.clock = clock 13 | 14 | return c 15 | } 16 | 17 | // GetWaitGroup needs to be exported so we can get the waitgroup of the consumer 18 | // in order to allow the tests for the consumption to finish. 
19 | func (c *Consumer) GetWaitGroup() *sync.WaitGroup { 20 | return &c.wg 21 | } 22 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/greencoda/tasq 2 | 3 | go 1.19 4 | 5 | require ( 6 | github.com/DATA-DOG/go-sqlmock v1.5.0 7 | github.com/benbjohnson/clock v1.3.0 8 | github.com/go-sql-driver/mysql v1.7.0 9 | github.com/google/uuid v1.3.0 10 | github.com/jmoiron/sqlx v1.3.5 11 | github.com/lib/pq v1.10.8 12 | github.com/stretchr/testify v1.8.1 13 | ) 14 | 15 | require ( 16 | github.com/davecgh/go-spew v1.1.1 // indirect 17 | github.com/pmezard/go-difflib v1.0.0 // indirect 18 | github.com/stretchr/objx v0.5.0 // indirect 19 | gopkg.in/yaml.v3 v3.0.1 // indirect 20 | ) 21 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/DATA-DOG/go-sqlmock v1.5.0 h1:Shsta01QNfFxHCfpW6YH2STWB0MudeXXEWMr20OEh60= 2 | github.com/DATA-DOG/go-sqlmock v1.5.0/go.mod h1:f/Ixk793poVmq4qj/V1dPUg2JEAKC73Q5eFN3EC/SaM= 3 | github.com/benbjohnson/clock v1.3.0 h1:ip6w0uFQkncKQ979AypyG0ER7mqUSBdKLOgAle/AT8A= 4 | github.com/benbjohnson/clock v1.3.0/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA= 5 | github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 6 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= 7 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 8 | github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= 9 | github.com/go-sql-driver/mysql v1.7.0 h1:ueSltNNllEqE3qcWBTD0iQd3IpL/6U+mJxLkazJ7YPc= 10 | github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI= 11 | github.com/google/uuid v1.3.0
h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= 12 | github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= 13 | github.com/jmoiron/sqlx v1.3.5 h1:vFFPA71p1o5gAeqtEAwLU4dnX2napprKtHr7PYIcN3g= 14 | github.com/jmoiron/sqlx v1.3.5/go.mod h1:nRVWtLre0KfCLJvgxzCsLVMogSvQ1zNJtpYr2Ccp0mQ= 15 | github.com/lib/pq v1.2.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= 16 | github.com/lib/pq v1.10.8 h1:3fdt97i/cwSU83+E0hZTC/Xpc9mTZxc6UWSCRcSbxiE= 17 | github.com/lib/pq v1.10.8/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= 18 | github.com/mattn/go-sqlite3 v1.14.6 h1:dNPt6NO46WmLVt2DLNpwczCmdV5boIZ6g/tlDrlRUbg= 19 | github.com/mattn/go-sqlite3 v1.14.6/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU= 20 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 21 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 22 | github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= 23 | github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= 24 | github.com/stretchr/objx v0.5.0 h1:1zr/of2m5FGMsad5YfcqgdqdWrIhu+EBEJRhR1U7z/c= 25 | github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= 26 | github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= 27 | github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= 28 | github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= 29 | github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= 30 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= 31 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 32 | gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 
33 | gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= 34 | gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 35 | -------------------------------------------------------------------------------- /inspector.go: -------------------------------------------------------------------------------- 1 | package tasq 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | ) 7 | 8 | // Inspector is a service instance created by a Client with reference to that client 9 | // with the purpose of enabling the observation of tasks. 10 | type Inspector struct { 11 | client *Client 12 | } 13 | 14 | // NewInspector creates a new inspector with a reference to the original tasq client. 15 | func (c *Client) NewInspector() *Inspector { 16 | return &Inspector{ 17 | client: c, 18 | } 19 | } 20 | 21 | // Count returns the total number of tasks matching the supplied filter arguments. 22 | func (o *Inspector) Count(ctx context.Context, taskStatuses []TaskStatus, taskTypes, queues []string) (int64, error) { 23 | count, err := o.client.repository.CountTasks(ctx, taskStatuses, taskTypes, queues) 24 | if err != nil { 25 | return 0, fmt.Errorf("error counting tasks: %w", err) 26 | } 27 | 28 | return count, nil 29 | } 30 | 31 | // Scan returns a list of tasks based on the supplied filter arguments. 32 | func (o *Inspector) Scan(ctx context.Context, taskStatuses []TaskStatus, taskTypes, queues []string, ordering Ordering, limit int) ([]*Task, error) { 33 | tasks, err := o.client.repository.ScanTasks(ctx, taskStatuses, taskTypes, queues, ordering, limit) 34 | if err != nil { 35 | return nil, fmt.Errorf("error scanning tasks: %w", err) 36 | } 37 | 38 | return tasks, nil 39 | } 40 | 41 | // Purge will remove all tasks based on the supplied filter arguments.
42 | func (o *Inspector) Purge(ctx context.Context, safeDelete bool, taskStatuses []TaskStatus, taskTypes, queues []string) (int64, error) { 43 | count, err := o.client.repository.PurgeTasks(ctx, taskStatuses, taskTypes, queues, safeDelete) 44 | if err != nil { 45 | return 0, fmt.Errorf("error purging tasks: %w", err) 46 | } 47 | 48 | return count, nil 49 | } 50 | 51 | // Delete will remove the supplied tasks. 52 | func (o *Inspector) Delete(ctx context.Context, safeDelete bool, tasks ...*Task) error { 53 | for _, task := range tasks { 54 | if err := o.client.repository.DeleteTask(ctx, task, safeDelete); err != nil { 55 | return fmt.Errorf("error removing task: %w", err) 56 | } 57 | } 58 | 59 | return nil 60 | } 61 | -------------------------------------------------------------------------------- /inspector_test.go: -------------------------------------------------------------------------------- 1 | package tasq_test 2 | 3 | import ( 4 | "context" 5 | "testing" 6 | 7 | "github.com/greencoda/tasq" 8 | "github.com/greencoda/tasq/mocks" 9 | "github.com/stretchr/testify/assert" 10 | "github.com/stretchr/testify/require" 11 | "github.com/stretchr/testify/suite" 12 | ) 13 | 14 | type InspectorTestSuite struct { 15 | suite.Suite 16 | mockRepository *mocks.IRepository 17 | tasqClient *tasq.Client 18 | tasqInspector *tasq.Inspector 19 | } 20 | 21 | func TestInspectorTestSuite(t *testing.T) { 22 | t.Parallel() 23 | 24 | suite.Run(t, new(InspectorTestSuite)) 25 | } 26 | 27 | func (s *InspectorTestSuite) SetupTest() { 28 | s.mockRepository = mocks.NewIRepository(s.T()) 29 | 30 | s.tasqClient = tasq.NewClient(s.mockRepository) 31 | require.NotNil(s.T(), s.tasqClient) 32 | 33 | s.tasqInspector = s.tasqClient.NewInspector() 34 | } 35 | 36 | func (s *InspectorTestSuite) TestNewCleaner() { 37 | assert.NotNil(s.T(), s.tasqInspector) 38 | } 39 | 40 | func (s *InspectorTestSuite) TestCount() { 41 | ctx := context.Background() 42 | 43 | s.mockRepository.On("CountTasks", ctx,
[]tasq.TaskStatus{tasq.StatusNew}, []string{"testType"}, []string{"testQueue"}).Return(int64(1), nil).Once() 44 | 45 | taskCount, err := s.tasqInspector.Count(ctx, []tasq.TaskStatus{tasq.StatusNew}, []string{"testType"}, []string{"testQueue"}) 46 | assert.Equal(s.T(), int64(1), taskCount) 47 | assert.True(s.T(), s.mockRepository.AssertCalled(s.T(), "CountTasks", ctx, []tasq.TaskStatus{tasq.StatusNew}, []string{"testType"}, []string{"testQueue"})) 48 | assert.Nil(s.T(), err) 49 | 50 | s.mockRepository.On("CountTasks", ctx, []tasq.TaskStatus{tasq.StatusNew}, []string{"testType"}, []string{"testQueue"}).Return(int64(0), errRepository).Once() 51 | 52 | taskCount, err = s.tasqInspector.Count(ctx, []tasq.TaskStatus{tasq.StatusNew}, []string{"testType"}, []string{"testQueue"}) 53 | assert.Equal(s.T(), int64(0), taskCount) 54 | assert.True(s.T(), s.mockRepository.AssertCalled(s.T(), "CountTasks", ctx, []tasq.TaskStatus{tasq.StatusNew}, []string{"testType"}, []string{"testQueue"})) 55 | assert.NotNil(s.T(), err) 56 | } 57 | 58 | func (s *InspectorTestSuite) TestScan() { 59 | ctx := context.Background() 60 | 61 | testTask, err := tasq.NewTask("testTask", true, "testQueue", 100, 5) 62 | require.NotNil(s.T(), testTask) 63 | require.Nil(s.T(), err) 64 | 65 | s.mockRepository.On("ScanTasks", ctx, []tasq.TaskStatus{tasq.StatusNew}, []string{"testType"}, []string{"testQueue"}, tasq.OrderingCreatedAtFirst, 100).Return([]*tasq.Task{testTask}, nil).Once() 66 | 67 | tasks, err := s.tasqInspector.Scan(ctx, []tasq.TaskStatus{tasq.StatusNew}, []string{"testType"}, []string{"testQueue"}, tasq.OrderingCreatedAtFirst, 100) 68 | assert.Equal(s.T(), []*tasq.Task{testTask}, tasks) 69 | assert.True(s.T(), s.mockRepository.AssertCalled(s.T(), "ScanTasks", ctx, []tasq.TaskStatus{tasq.StatusNew}, []string{"testType"}, []string{"testQueue"}, tasq.OrderingCreatedAtFirst, 100)) 70 | assert.Nil(s.T(), err) 71 | 72 | s.mockRepository.On("ScanTasks", ctx, []tasq.TaskStatus{tasq.StatusNew},
[]string{"testType"}, []string{"testQueue"}, tasq.OrderingCreatedAtFirst, 100).Return([]*tasq.Task{}, errRepository).Once() 73 | 74 | tasks, err = s.tasqInspector.Scan(ctx, []tasq.TaskStatus{tasq.StatusNew}, []string{"testType"}, []string{"testQueue"}, tasq.OrderingCreatedAtFirst, 100) 75 | assert.Len(s.T(), tasks, 0) 76 | assert.True(s.T(), s.mockRepository.AssertCalled(s.T(), "ScanTasks", ctx, []tasq.TaskStatus{tasq.StatusNew}, []string{"testType"}, []string{"testQueue"}, tasq.OrderingCreatedAtFirst, 100)) 77 | assert.NotNil(s.T(), err) 78 | } 79 | 80 | func (s *InspectorTestSuite) TestPurge() { 81 | ctx := context.Background() 82 | 83 | s.mockRepository.On("PurgeTasks", ctx, []tasq.TaskStatus{tasq.StatusNew}, []string{"testType"}, []string{"testQueue"}, true).Return(int64(10), nil).Once() 84 | 85 | count, err := s.tasqInspector.Purge(ctx, true, []tasq.TaskStatus{tasq.StatusNew}, []string{"testType"}, []string{"testQueue"}) 86 | assert.True(s.T(), s.mockRepository.AssertCalled(s.T(), "PurgeTasks", ctx, []tasq.TaskStatus{tasq.StatusNew}, []string{"testType"}, []string{"testQueue"}, true)) 87 | assert.Equal(s.T(), int64(10), count) 88 | assert.Nil(s.T(), err) 89 | 90 | s.mockRepository.On("PurgeTasks", ctx, []tasq.TaskStatus{tasq.StatusNew}, []string{"testType"}, []string{"testQueue"}, true).Return(int64(0), errRepository).Once() 91 | 92 | count, err = s.tasqInspector.Purge(ctx, true, []tasq.TaskStatus{tasq.StatusNew}, []string{"testType"}, []string{"testQueue"}) 93 | assert.True(s.T(), s.mockRepository.AssertCalled(s.T(), "PurgeTasks", ctx, []tasq.TaskStatus{tasq.StatusNew}, []string{"testType"}, []string{"testQueue"}, true)) 94 | assert.Equal(s.T(), int64(0), count) 95 | assert.NotNil(s.T(), err) 96 | } 97 | 98 | func (s *InspectorTestSuite) TestDelete() { 99 | ctx := context.Background() 100 | 101 | testTask, err := tasq.NewTask("testTask", true, "testQueue", 100, 5) 102 | require.NotNil(s.T(), testTask) 103 | require.Nil(s.T(), err) 104 | 105 |
s.mockRepository.On("DeleteTask", ctx, testTask, true).Once().Return(errRepository) 106 | 107 | err = s.tasqInspector.Delete(ctx, true, testTask) 108 | assert.NotNil(s.T(), err) 109 | assert.True(s.T(), s.mockRepository.AssertCalled(s.T(), "DeleteTask", ctx, testTask, true)) 110 | 111 | s.mockRepository.On("DeleteTask", ctx, testTask, true).Once().Return(nil) 112 | 113 | err = s.tasqInspector.Delete(ctx, true, testTask) 114 | assert.Nil(s.T(), err) 115 | assert.True(s.T(), s.mockRepository.AssertCalled(s.T(), "DeleteTask", ctx, testTask, true)) 116 | } 117 | -------------------------------------------------------------------------------- /mocks/IRepository.go: -------------------------------------------------------------------------------- 1 | // Code generated by mockery v2.25.1. DO NOT EDIT. 2 | 3 | package mocks 4 | 5 | import ( 6 | context "context" 7 | 8 | tasq "github.com/greencoda/tasq" 9 | mock "github.com/stretchr/testify/mock" 10 | 11 | time "time" 12 | 13 | uuid "github.com/google/uuid" 14 | ) 15 | 16 | // IRepository is an autogenerated mock type for the IRepository type 17 | type IRepository struct { 18 | mock.Mock 19 | } 20 | 21 | // CleanTasks provides a mock function with given fields: ctx, minimumAge 22 | func (_m *IRepository) CleanTasks(ctx context.Context, minimumAge time.Duration) (int64, error) { 23 | ret := _m.Called(ctx, minimumAge) 24 | 25 | var r0 int64 26 | var r1 error 27 | if rf, ok := ret.Get(0).(func(context.Context, time.Duration) (int64, error)); ok { 28 | return rf(ctx, minimumAge) 29 | } 30 | if rf, ok := ret.Get(0).(func(context.Context, time.Duration) int64); ok { 31 | r0 = rf(ctx, minimumAge) 32 | } else { 33 | r0 = ret.Get(0).(int64) 34 | } 35 | 36 | if rf, ok := ret.Get(1).(func(context.Context, time.Duration) error); ok { 37 | r1 = rf(ctx, minimumAge) 38 | } else { 39 | r1 = ret.Error(1) 40 | } 41 | 42 | return r0, r1 43 | } 44 | 45 | // CountTasks provides a mock function with given fields: ctx, taskStatuses, taskTypes,
queues 46 | func (_m *IRepository) CountTasks(ctx context.Context, taskStatuses []tasq.TaskStatus, taskTypes []string, queues []string) (int64, error) { 47 | ret := _m.Called(ctx, taskStatuses, taskTypes, queues) 48 | 49 | var r0 int64 50 | var r1 error 51 | if rf, ok := ret.Get(0).(func(context.Context, []tasq.TaskStatus, []string, []string) (int64, error)); ok { 52 | return rf(ctx, taskStatuses, taskTypes, queues) 53 | } 54 | if rf, ok := ret.Get(0).(func(context.Context, []tasq.TaskStatus, []string, []string) int64); ok { 55 | r0 = rf(ctx, taskStatuses, taskTypes, queues) 56 | } else { 57 | r0 = ret.Get(0).(int64) 58 | } 59 | 60 | if rf, ok := ret.Get(1).(func(context.Context, []tasq.TaskStatus, []string, []string) error); ok { 61 | r1 = rf(ctx, taskStatuses, taskTypes, queues) 62 | } else { 63 | r1 = ret.Error(1) 64 | } 65 | 66 | return r0, r1 67 | } 68 | 69 | // DeleteTask provides a mock function with given fields: ctx, task, safeDelete 70 | func (_m *IRepository) DeleteTask(ctx context.Context, task *tasq.Task, safeDelete bool) error { 71 | ret := _m.Called(ctx, task, safeDelete) 72 | 73 | var r0 error 74 | if rf, ok := ret.Get(0).(func(context.Context, *tasq.Task, bool) error); ok { 75 | r0 = rf(ctx, task, safeDelete) 76 | } else { 77 | r0 = ret.Error(0) 78 | } 79 | 80 | return r0 81 | } 82 | 83 | // Migrate provides a mock function with given fields: ctx 84 | func (_m *IRepository) Migrate(ctx context.Context) error { 85 | ret := _m.Called(ctx) 86 | 87 | var r0 error 88 | if rf, ok := ret.Get(0).(func(context.Context) error); ok { 89 | r0 = rf(ctx) 90 | } else { 91 | r0 = ret.Error(0) 92 | } 93 | 94 | return r0 95 | } 96 | 97 | // PingTasks provides a mock function with given fields: ctx, taskIDs, visibilityTimeout 98 | func (_m *IRepository) PingTasks(ctx context.Context, taskIDs []uuid.UUID, visibilityTimeout time.Duration) ([]*tasq.Task, error) { 99 | ret := _m.Called(ctx, taskIDs, visibilityTimeout) 100 | 101 | var r0 []*tasq.Task 102 | var r1 error
103 | if rf, ok := ret.Get(0).(func(context.Context, []uuid.UUID, time.Duration) ([]*tasq.Task, error)); ok { 104 | return rf(ctx, taskIDs, visibilityTimeout) 105 | } 106 | if rf, ok := ret.Get(0).(func(context.Context, []uuid.UUID, time.Duration) []*tasq.Task); ok { 107 | r0 = rf(ctx, taskIDs, visibilityTimeout) 108 | } else { 109 | if ret.Get(0) != nil { 110 | r0 = ret.Get(0).([]*tasq.Task) 111 | } 112 | } 113 | 114 | if rf, ok := ret.Get(1).(func(context.Context, []uuid.UUID, time.Duration) error); ok { 115 | r1 = rf(ctx, taskIDs, visibilityTimeout) 116 | } else { 117 | r1 = ret.Error(1) 118 | } 119 | 120 | return r0, r1 121 | } 122 | 123 | // PollTasks provides a mock function with given fields: ctx, types, queues, visibilityTimeout, ordering, limit 124 | func (_m *IRepository) PollTasks(ctx context.Context, types []string, queues []string, visibilityTimeout time.Duration, ordering tasq.Ordering, limit int) ([]*tasq.Task, error) { 125 | ret := _m.Called(ctx, types, queues, visibilityTimeout, ordering, limit) 126 | 127 | var r0 []*tasq.Task 128 | var r1 error 129 | if rf, ok := ret.Get(0).(func(context.Context, []string, []string, time.Duration, tasq.Ordering, int) ([]*tasq.Task, error)); ok { 130 | return rf(ctx, types, queues, visibilityTimeout, ordering, limit) 131 | } 132 | if rf, ok := ret.Get(0).(func(context.Context, []string, []string, time.Duration, tasq.Ordering, int) []*tasq.Task); ok { 133 | r0 = rf(ctx, types, queues, visibilityTimeout, ordering, limit) 134 | } else { 135 | if ret.Get(0) != nil { 136 | r0 = ret.Get(0).([]*tasq.Task) 137 | } 138 | } 139 | 140 | if rf, ok := ret.Get(1).(func(context.Context, []string, []string, time.Duration, tasq.Ordering, int) error); ok { 141 | r1 = rf(ctx, types, queues, visibilityTimeout, ordering, limit) 142 | } else { 143 | r1 = ret.Error(1) 144 | } 145 | 146 | return r0, r1 147 | } 148 | 149 | // PurgeTasks provides a mock function with given fields: ctx, taskStatuses, taskTypes, queues, safeDelete 150 | func
(_m *IRepository) PurgeTasks(ctx context.Context, taskStatuses []tasq.TaskStatus, taskTypes []string, queues []string, safeDelete bool) (int64, error) { 151 | ret := _m.Called(ctx, taskStatuses, taskTypes, queues, safeDelete) 152 | 153 | var r0 int64 154 | var r1 error 155 | if rf, ok := ret.Get(0).(func(context.Context, []tasq.TaskStatus, []string, []string, bool) (int64, error)); ok { 156 | return rf(ctx, taskStatuses, taskTypes, queues, safeDelete) 157 | } 158 | if rf, ok := ret.Get(0).(func(context.Context, []tasq.TaskStatus, []string, []string, bool) int64); ok { 159 | r0 = rf(ctx, taskStatuses, taskTypes, queues, safeDelete) 160 | } else { 161 | r0 = ret.Get(0).(int64) 162 | } 163 | 164 | if rf, ok := ret.Get(1).(func(context.Context, []tasq.TaskStatus, []string, []string, bool) error); ok { 165 | r1 = rf(ctx, taskStatuses, taskTypes, queues, safeDelete) 166 | } else { 167 | r1 = ret.Error(1) 168 | } 169 | 170 | return r0, r1 171 | } 172 | 173 | // RegisterError provides a mock function with given fields: ctx, task, errTask 174 | func (_m *IRepository) RegisterError(ctx context.Context, task *tasq.Task, errTask error) (*tasq.Task, error) { 175 | ret := _m.Called(ctx, task, errTask) 176 | 177 | var r0 *tasq.Task 178 | var r1 error 179 | if rf, ok := ret.Get(0).(func(context.Context, *tasq.Task, error) (*tasq.Task, error)); ok { 180 | return rf(ctx, task, errTask) 181 | } 182 | if rf, ok := ret.Get(0).(func(context.Context, *tasq.Task, error) *tasq.Task); ok { 183 | r0 = rf(ctx, task, errTask) 184 | } else { 185 | if ret.Get(0) != nil { 186 | r0 = ret.Get(0).(*tasq.Task) 187 | } 188 | } 189 | 190 | if rf, ok := ret.Get(1).(func(context.Context, *tasq.Task, error) error); ok { 191 | r1 = rf(ctx, task, errTask) 192 | } else { 193 | r1 = ret.Error(1) 194 | } 195 | 196 | return r0, r1 197 | } 198 | 199 | // RegisterFinish provides a mock function with given fields: ctx, task, finishStatus 200 | func (_m *IRepository) RegisterFinish(ctx context.Context, task
*tasq.Task, finishStatus tasq.TaskStatus) (*tasq.Task, error) { 201 | ret := _m.Called(ctx, task, finishStatus) 202 | 203 | var r0 *tasq.Task 204 | var r1 error 205 | if rf, ok := ret.Get(0).(func(context.Context, *tasq.Task, tasq.TaskStatus) (*tasq.Task, error)); ok { 206 | return rf(ctx, task, finishStatus) 207 | } 208 | if rf, ok := ret.Get(0).(func(context.Context, *tasq.Task, tasq.TaskStatus) *tasq.Task); ok { 209 | r0 = rf(ctx, task, finishStatus) 210 | } else { 211 | if ret.Get(0) != nil { 212 | r0 = ret.Get(0).(*tasq.Task) 213 | } 214 | } 215 | 216 | if rf, ok := ret.Get(1).(func(context.Context, *tasq.Task, tasq.TaskStatus) error); ok { 217 | r1 = rf(ctx, task, finishStatus) 218 | } else { 219 | r1 = ret.Error(1) 220 | } 221 | 222 | return r0, r1 223 | } 224 | 225 | // RegisterStart provides a mock function with given fields: ctx, task 226 | func (_m *IRepository) RegisterStart(ctx context.Context, task *tasq.Task) (*tasq.Task, error) { 227 | ret := _m.Called(ctx, task) 228 | 229 | var r0 *tasq.Task 230 | var r1 error 231 | if rf, ok := ret.Get(0).(func(context.Context, *tasq.Task) (*tasq.Task, error)); ok { 232 | return rf(ctx, task) 233 | } 234 | if rf, ok := ret.Get(0).(func(context.Context, *tasq.Task) *tasq.Task); ok { 235 | r0 = rf(ctx, task) 236 | } else { 237 | if ret.Get(0) != nil { 238 | r0 = ret.Get(0).(*tasq.Task) 239 | } 240 | } 241 | 242 | if rf, ok := ret.Get(1).(func(context.Context, *tasq.Task) error); ok { 243 | r1 = rf(ctx, task) 244 | } else { 245 | r1 = ret.Error(1) 246 | } 247 | 248 | return r0, r1 249 | } 250 | 251 | // RequeueTask provides a mock function with given fields: ctx, task 252 | func (_m *IRepository) RequeueTask(ctx context.Context, task *tasq.Task) (*tasq.Task, error) { 253 | ret := _m.Called(ctx, task) 254 | 255 | var r0 *tasq.Task 256 | var r1 error 257 | if rf, ok := ret.Get(0).(func(context.Context, *tasq.Task) (*tasq.Task, error)); ok { 258 | return rf(ctx, task) 259 | } 260 | if rf, ok :=
ret.Get(0).(func(context.Context, *tasq.Task) *tasq.Task); ok { 261 | r0 = rf(ctx, task) 262 | } else { 263 | if ret.Get(0) != nil { 264 | r0 = ret.Get(0).(*tasq.Task) 265 | } 266 | } 267 | 268 | if rf, ok := ret.Get(1).(func(context.Context, *tasq.Task) error); ok { 269 | r1 = rf(ctx, task) 270 | } else { 271 | r1 = ret.Error(1) 272 | } 273 | 274 | return r0, r1 275 | } 276 | 277 | // ScanTasks provides a mock function with given fields: ctx, taskStatuses, taskTypes, queues, ordering, limit 278 | func (_m *IRepository) ScanTasks(ctx context.Context, taskStatuses []tasq.TaskStatus, taskTypes []string, queues []string, ordering tasq.Ordering, limit int) ([]*tasq.Task, error) { 279 | ret := _m.Called(ctx, taskStatuses, taskTypes, queues, ordering, limit) 280 | 281 | var r0 []*tasq.Task 282 | var r1 error 283 | if rf, ok := ret.Get(0).(func(context.Context, []tasq.TaskStatus, []string, []string, tasq.Ordering, int) ([]*tasq.Task, error)); ok { 284 | return rf(ctx, taskStatuses, taskTypes, queues, ordering, limit) 285 | } 286 | if rf, ok := ret.Get(0).(func(context.Context, []tasq.TaskStatus, []string, []string, tasq.Ordering, int) []*tasq.Task); ok { 287 | r0 = rf(ctx, taskStatuses, taskTypes, queues, ordering, limit) 288 | } else { 289 | if ret.Get(0) != nil { 290 | r0 = ret.Get(0).([]*tasq.Task) 291 | } 292 | } 293 | 294 | if rf, ok := ret.Get(1).(func(context.Context, []tasq.TaskStatus, []string, []string, tasq.Ordering, int) error); ok { 295 | r1 = rf(ctx, taskStatuses, taskTypes, queues, ordering, limit) 296 | } else { 297 | r1 = ret.Error(1) 298 | } 299 | 300 | return r0, r1 301 | } 302 | 303 | // SubmitTask provides a mock function with given fields: ctx, task 304 | func (_m *IRepository) SubmitTask(ctx context.Context, task *tasq.Task) (*tasq.Task, error) { 305 | ret := _m.Called(ctx, task) 306 | 307 | var r0 *tasq.Task 308 | var r1 error 309 | if rf, ok := ret.Get(0).(func(context.Context, *tasq.Task) (*tasq.Task, error)); ok { 310 | return rf(ctx, task) 311
| } 312 | if rf, ok := ret.Get(0).(func(context.Context, *tasq.Task) *tasq.Task); ok { 313 | r0 = rf(ctx, task) 314 | } else { 315 | if ret.Get(0) != nil { 316 | r0 = ret.Get(0).(*tasq.Task) 317 | } 318 | } 319 | 320 | if rf, ok := ret.Get(1).(func(context.Context, *tasq.Task) error); ok { 321 | r1 = rf(ctx, task) 322 | } else { 323 | r1 = ret.Error(1) 324 | } 325 | 326 | return r0, r1 327 | } 328 | 329 | type mockConstructorTestingTNewIRepository interface { 330 | mock.TestingT 331 | Cleanup(func()) 332 | } 333 | 334 | // NewIRepository creates a new instance of IRepository. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. 335 | func NewIRepository(t mockConstructorTestingTNewIRepository) *IRepository { 336 | mock := &IRepository{} 337 | mock.Mock.Test(t) 338 | 339 | t.Cleanup(func() { mock.AssertExpectations(t) }) 340 | 341 | return mock 342 | } 343 | -------------------------------------------------------------------------------- /producer.go: -------------------------------------------------------------------------------- 1 | package tasq 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | ) 7 | 8 | // Producer is a service instance created by a Client with reference to that client 9 | // with the purpose of enabling the submission of new tasks. 10 | type Producer struct { 11 | client *Client 12 | } 13 | 14 | // NewProducer creates a new producer with a reference to the original tasq client. 15 | func (c *Client) NewProducer() *Producer { 16 | return &Producer{ 17 | client: c, 18 | } 19 | } 20 | 21 | // Submit constructs and submits a new task to the queue based on the supplied arguments.
22 | func (p *Producer) Submit(ctx context.Context, taskType string, taskArgs any, queue string, priority int16, maxReceives int32) (*Task, error) { 23 | newTask, err := NewTask(taskType, taskArgs, queue, priority, maxReceives) 24 | if err != nil { 25 | return nil, fmt.Errorf("error creating task: %w", err) 26 | } 27 | 28 | return p.SubmitTask(ctx, newTask) 29 | } 30 | 31 | // SubmitTask submits an already constructed task to the queue via the client's repository. 32 | func (p *Producer) SubmitTask(ctx context.Context, task *Task) (*Task, error) { 33 | submittedTask, err := p.client.repository.SubmitTask(ctx, task) 34 | if err != nil { 35 | return nil, fmt.Errorf("error submitting task: %w", err) 36 | } 37 | 38 | return submittedTask, nil 39 | } 40 | -------------------------------------------------------------------------------- /producer_test.go: -------------------------------------------------------------------------------- 1 | package tasq_test 2 | 3 | import ( 4 | "context" 5 | "testing" 6 | 7 | "github.com/greencoda/tasq" 8 | "github.com/greencoda/tasq/mocks" 9 | "github.com/stretchr/testify/assert" 10 | "github.com/stretchr/testify/mock" 11 | "github.com/stretchr/testify/require" 12 | "github.com/stretchr/testify/suite" 13 | ) 14 | 15 | type ProducterTestSuite struct { 16 | suite.Suite 17 | mockRepository *mocks.IRepository 18 | tasqClient *tasq.Client 19 | tasqProducer *tasq.Producer 20 | testTask *tasq.Task 21 | } 22 | 23 | func TestProducterTestSuite(t *testing.T) { 24 | t.Parallel() 25 | 26 | suite.Run(t, new(ProducterTestSuite)) 27 | } 28 | 29 | func (s *ProducterTestSuite) SetupTest() { 30 | s.mockRepository = mocks.NewIRepository(s.T()) 31 | 32 | s.tasqClient = tasq.NewClient(s.mockRepository) 33 | require.NotNil(s.T(), s.tasqClient) 34 | 35 | s.tasqProducer = s.tasqClient.NewProducer() 36 | require.NotNil(s.T(), s.tasqProducer) 37 | 38 | testArgs := "testData" 39 | testTask, err := tasq.NewTask("testTask", testArgs, "testQueue", 100, 5) 40 |
require.Nil(s.T(), err) 41 | 42 | s.testTask = testTask 43 | } 44 | 45 | func (s *ProducterTestSuite) TestNewProducer() { 46 | assert.NotNil(s.T(), s.tasqProducer) 47 | } 48 | 49 | func (s *ProducterTestSuite) TestSubmitSuccessful() { 50 | ctx := context.Background() 51 | 52 | s.mockRepository.On("SubmitTask", ctx, mock.AnythingOfType("*tasq.Task")).Return(s.testTask, nil) 53 | 54 | task, err := s.tasqProducer.Submit(ctx, s.testTask.Type, s.testTask.Args, s.testTask.Queue, s.testTask.Priority, s.testTask.MaxReceives) 55 | 56 | assert.NotNil(s.T(), task) 57 | assert.True(s.T(), s.mockRepository.AssertCalled(s.T(), "SubmitTask", ctx, mock.AnythingOfType("*tasq.Task"))) 58 | assert.Nil(s.T(), err) 59 | } 60 | 61 | func (s *ProducterTestSuite) TestSubmitUnsuccessful() { 62 | var ( 63 | ctx = context.Background() 64 | testArgs = "testData" 65 | testTask, _ = tasq.NewTask("testTask", testArgs, "testQueue", 100, 5) 66 | ) 67 | 68 | s.mockRepository.On("SubmitTask", ctx, mock.AnythingOfType("*tasq.Task")).Return(nil, errRepository) 69 | 70 | task, err := s.tasqProducer.Submit(ctx, testTask.Type, testArgs, testTask.Queue, testTask.Priority, testTask.MaxReceives) 71 | 72 | assert.Nil(s.T(), task) 73 | assert.True(s.T(), s.mockRepository.AssertCalled(s.T(), "SubmitTask", ctx, mock.AnythingOfType("*tasq.Task"))) 74 | assert.NotNil(s.T(), err) 75 | } 76 | 77 | func (s *ProducterTestSuite) TestSubmitInvalidpriority() { 78 | ctx := context.Background() 79 | 80 | task, err := s.tasqProducer.Submit(ctx, "testData", nil, "testQueue", 100, 5) 81 | 82 | assert.Nil(s.T(), task) 83 | assert.NotNil(s.T(), err) 84 | } 85 | 86 | func (s *ProducterTestSuite) TestSubmitTask() { 87 | ctx := context.Background() 88 | 89 | s.mockRepository.On("SubmitTask", ctx, mock.AnythingOfType("*tasq.Task")).Return(s.testTask, nil) 90 | 91 | task, err := tasq.NewTask(s.testTask.Type, s.testTask.Args, s.testTask.Queue, s.testTask.Priority, s.testTask.MaxReceives) 92 | require.NotNil(s.T(), task) 93 |
require.Nil(s.T(), err) 94 | 95 | submittedTask, err := s.tasqProducer.SubmitTask(ctx, task) 96 | 97 | assert.NotNil(s.T(), submittedTask) 98 | assert.True(s.T(), s.mockRepository.AssertCalled(s.T(), "SubmitTask", ctx, mock.AnythingOfType("*tasq.Task"))) 99 | assert.Nil(s.T(), err) 100 | } 101 | -------------------------------------------------------------------------------- /repository.go: -------------------------------------------------------------------------------- 1 | package tasq 2 | 3 | import ( 4 | "context" 5 | "time" 6 | 7 | "github.com/google/uuid" 8 | ) 9 | 10 | // Ordering is an enum type describing the polling strategy utilized during 11 | // the polling process. 12 | type Ordering int 13 | 14 | // The collection of orderings. 15 | const ( 16 | OrderingCreatedAtFirst Ordering = iota 17 | OrderingPriorityFirst 18 | ) 19 | 20 | // IRepository describes the mandatory methods a repository must implement 21 | // in order for tasq to be able to use it. 22 | type IRepository interface { 23 | Migrate(ctx context.Context) error 24 | 25 | PingTasks(ctx context.Context, taskIDs []uuid.UUID, visibilityTimeout time.Duration) ([]*Task, error) 26 | PollTasks(ctx context.Context, types, queues []string, visibilityTimeout time.Duration, ordering Ordering, limit int) ([]*Task, error) 27 | CleanTasks(ctx context.Context, minimumAge time.Duration) (int64, error) 28 | 29 | RegisterStart(ctx context.Context, task *Task) (*Task, error) 30 | RegisterError(ctx context.Context, task *Task, errTask error) (*Task, error) 31 | RegisterFinish(ctx context.Context, task *Task, finishStatus TaskStatus) (*Task, error) 32 | 33 | SubmitTask(ctx context.Context, task *Task) (*Task, error) 34 | DeleteTask(ctx context.Context, task *Task, safeDelete bool) error 35 | RequeueTask(ctx context.Context, task *Task) (*Task, error) 36 | 37 | ScanTasks(ctx context.Context, taskStatuses []TaskStatus, taskTypes, queues []string, ordering Ordering, limit int) ([]*Task, error) 38 | CountTasks(ctx
context.Context, taskStatuses []TaskStatus, taskTypes, queues []string) (int64, error) 39 | PurgeTasks(ctx context.Context, taskStatuses []TaskStatus, taskTypes, queues []string, safeDelete bool) (int64, error) 40 | } 41 | -------------------------------------------------------------------------------- /repository/mysql/export_test.go: -------------------------------------------------------------------------------- 1 | package mysql 2 | 3 | import ( 4 | "database/sql" 5 | "database/sql/driver" 6 | 7 | "github.com/greencoda/tasq" 8 | ) 9 | 10 | func GetTestTaskValues(task *tasq.Task) []driver.Value { 11 | testMySQLTask := newFromTask(task) 12 | 13 | return []driver.Value{ 14 | testMySQLTask.ID, 15 | testMySQLTask.Type, 16 | testMySQLTask.Args, 17 | testMySQLTask.Queue, 18 | testMySQLTask.Priority, 19 | testMySQLTask.Status, 20 | testMySQLTask.ReceiveCount, 21 | testMySQLTask.MaxReceives, 22 | testMySQLTask.LastError, 23 | testMySQLTask.CreatedAt, 24 | testMySQLTask.StartedAt, 25 | testMySQLTask.FinishedAt, 26 | testMySQLTask.VisibleAt, 27 | } 28 | } 29 | 30 | func InterpolateSQL(sql string, params map[string]any) string { 31 | return interpolateSQL(sql, params) 32 | } 33 | 34 | func (d *Repository) GetQueryWithTableName(sqlTemplate string, args ...any) (string, []any) { 35 | return d.getQueryWithTableName(sqlTemplate, args...)
36 | } 37 | 38 | func StringToSQLNullString(input *string) sql.NullString { 39 | return stringToSQLNullString(input) 40 | } 41 | 42 | func ParseNullableString(input sql.NullString) *string { 43 | return parseNullableString(input) 44 | } 45 | -------------------------------------------------------------------------------- /repository/mysql/mysql.go: -------------------------------------------------------------------------------- 1 | // Package mysql provides the implementation of a tasq repository in MySQL. 2 | package mysql 3 | 4 | import ( 5 | "bytes" 6 | "context" 7 | "database/sql" 8 | "errors" 9 | "fmt" 10 | "strings" 11 | "text/template" 12 | "time" 13 | 14 | _ "github.com/go-sql-driver/mysql" // import mysql driver 15 | "github.com/google/uuid" 16 | "github.com/greencoda/tasq" 17 | "github.com/jmoiron/sqlx" 18 | "github.com/lib/pq" 19 | ) 20 | 21 | const driverName = "mysql" 22 | 23 | var ( 24 | errUnexpectedDataSourceType = errors.New("unexpected dataSource type") 25 | errFailedToBeginTx = errors.New("failed to begin transaction") 26 | errFailedToCommitTx = errors.New("failed to commit transaction") 27 | errFailedToExecuteSelect = errors.New("failed to execute select query") 28 | errFailedToExecuteUpdate = errors.New("failed to execute update query") 29 | errFailedToExecuteDelete = errors.New("failed to execute delete query") 30 | errFailedToExecuteInsert = errors.New("failed to execute insert query") 31 | errFailedToExecuteCreateTable = errors.New("failed to execute create table query") 32 | errFailedGetRowsAffected = errors.New("failed to get rows affected by query") 33 | ) 34 | 35 | // Repository implements the methods necessary for tasq to work in MySQL. 36 | type Repository struct { 37 | db *sqlx.DB 38 | tableName string 39 | } 40 | 41 | // NewRepository creates a new MySQL Repository instance.
42 | func NewRepository(dataSource any, prefix string) (*Repository, error) { 43 | switch d := dataSource.(type) { 44 | case string: 45 | return newRepositoryFromDSN(d, prefix) 46 | case *sql.DB: 47 | return newRepositoryFromDB(d, prefix) 48 | } 49 | 50 | return nil, fmt.Errorf("%w: %T", errUnexpectedDataSourceType, dataSource) 51 | } 52 | 53 | func newRepositoryFromDSN(dsn string, prefix string) (*Repository, error) { 54 | dbx, err := sqlx.Open(driverName, dsn) 55 | if err != nil { 56 | return nil, fmt.Errorf("failed to open DB from dsn: %w", err) 57 | } 58 | 59 | return &Repository{ 60 | db: dbx, 61 | tableName: tableName(prefix), 62 | }, nil 63 | } 64 | 65 | func newRepositoryFromDB(db *sql.DB, prefix string) (*Repository, error) { 66 | dbx := sqlx.NewDb(db, driverName) 67 | 68 | return &Repository{ 69 | db: dbx, 70 | tableName: tableName(prefix), 71 | }, nil 72 | } 73 | 74 | // Migrate prepares the database by adding the tasks table. 75 | func (d *Repository) Migrate(ctx context.Context) error { 76 | if err := d.migrateTable(ctx); err != nil { 77 | return err 78 | } 79 | 80 | return nil 81 | } 82 | 83 | // PingTasks pings a list of tasks by their ID 84 | // and extends their invisibility timestamp with the supplied timeout parameter. 
85 | func (d *Repository) PingTasks(ctx context.Context, taskIDs []uuid.UUID, visibilityTimeout time.Duration) ([]*tasq.Task, error) { 86 | if len(taskIDs) == 0 { 87 | return []*tasq.Task{}, nil 88 | } 89 | 90 | const ( 91 | updatePingedTasksSQLTemplate = `UPDATE 92 | {{.tableName}} 93 | SET 94 | visible_at = :visibleAt 95 | WHERE 96 | id IN (:pingedTaskIDs);` 97 | selectPingedTasksSQLTemplate = `SELECT 98 | * 99 | FROM 100 | {{.tableName}} 101 | WHERE 102 | id IN (:pingedTaskIDs);` 103 | ) 104 | 105 | tx, err := d.db.Beginx() 106 | if err != nil { 107 | return []*tasq.Task{}, fmt.Errorf("%s: %w", errFailedToBeginTx, err) 108 | } 109 | 110 | defer rollback(tx) 111 | 112 | var ( 113 | pingTime = time.Now() 114 | updatePingedTasksQuery, updatePingedTasksArgs = d.getQueryWithTableName(updatePingedTasksSQLTemplate, map[string]any{ 115 | "visibleAt": timeToString(pingTime.Add(visibilityTimeout)), 116 | "pingedTaskIDs": taskIDs, 117 | }) 118 | ) 119 | 120 | _, err = tx.ExecContext(ctx, updatePingedTasksQuery, updatePingedTasksArgs...) 121 | if err != nil { 122 | return []*tasq.Task{}, fmt.Errorf("%s: %w", errFailedToExecuteUpdate, err) 123 | } 124 | 125 | var ( 126 | pingedMySQLTasks []*mySQLTask 127 | selectPingedTasksQuery, selectPingedTasksArgs = d.getQueryWithTableName(selectPingedTasksSQLTemplate, map[string]any{ 128 | "pingedTaskIDs": taskIDs, 129 | }) 130 | ) 131 | 132 | err = tx.SelectContext(ctx, &pingedMySQLTasks, selectPingedTasksQuery, selectPingedTasksArgs...) 133 | if err != nil { 134 | return []*tasq.Task{}, fmt.Errorf("%s: %w", errFailedToExecuteSelect, err) 135 | } 136 | 137 | err = tx.Commit() 138 | if err != nil { 139 | return []*tasq.Task{}, fmt.Errorf("%s: %w", errFailedToCommitTx, err) 140 | } 141 | 142 | return mySQLTasksToTasks(pingedMySQLTasks), nil 143 | } 144 | 145 | // PollTasks polls for available tasks matching supplied the parameters 146 | // and sets their invisibility the supplied timeout parameter to the future. 
147 | func (d *Repository) PollTasks(ctx context.Context, types, queues []string, visibilityTimeout time.Duration, ordering tasq.Ordering, pollLimit int) ([]*tasq.Task, error) { 148 | if pollLimit == 0 { 149 | return []*tasq.Task{}, nil 150 | } 151 | 152 | const ( 153 | selectPolledTasksSQLTemplate = `SELECT 154 | id 155 | FROM 156 | {{.tableName}} 157 | WHERE 158 | type IN (:pollTypes) AND 159 | queue IN (:pollQueues) AND 160 | status IN (:pollStatuses) AND 161 | visible_at <= :pollTime 162 | ORDER BY 163 | :pollOrdering 164 | LIMIT :pollLimit 165 | FOR UPDATE SKIP LOCKED;` 166 | updatePolledTasksSQLTemplate = `UPDATE 167 | {{.tableName}} 168 | SET 169 | status = :status, 170 | receive_count = receive_count + 1, 171 | visible_at = :visibleAt 172 | WHERE 173 | id IN (:polledTaskIDs);` 174 | selectUpdatedPolledTasksSQLTemplate = `SELECT 175 | * 176 | FROM 177 | {{.tableName}} 178 | WHERE 179 | id IN (:polledTaskIDs);` 180 | ) 181 | 182 | tx, err := d.db.Beginx() 183 | if err != nil { 184 | return []*tasq.Task{}, fmt.Errorf("%s: %w", errFailedToBeginTx, err) 185 | } 186 | 187 | defer rollback(tx) 188 | 189 | var ( 190 | polledTaskIDs []TaskID 191 | pollTime = time.Now() 192 | selectPolledTasksQuery, selectPolledTasksArgs = d.getQueryWithTableName(selectPolledTasksSQLTemplate, map[string]any{ 193 | "pollTypes": types, 194 | "pollQueues": queues, 195 | "pollStatuses": tasq.GetTaskStatuses(tasq.OpenTasks), 196 | "pollTime": timeToString(pollTime), 197 | "pollOrdering": getOrderingDirectives(ordering), 198 | "pollLimit": pollLimit, 199 | }) 200 | ) 201 | 202 | err = tx.SelectContext(ctx, &polledTaskIDs, selectPolledTasksQuery, selectPolledTasksArgs...) 
203 | if err != nil { 204 | return []*tasq.Task{}, fmt.Errorf("%s: %w", errFailedToExecuteSelect, err) 205 | } 206 | 207 | if len(polledTaskIDs) == 0 { 208 | return []*tasq.Task{}, nil 209 | } 210 | 211 | updatePolledTasksQuery, updatePolledTasksArgs := d.getQueryWithTableName(updatePolledTasksSQLTemplate, map[string]any{ 212 | "status": tasq.StatusEnqueued, 213 | "visibleAt": timeToString(pollTime.Add(visibilityTimeout)), 214 | "polledTaskIDs": polledTaskIDs, 215 | }) 216 | 217 | _, err = tx.ExecContext(ctx, updatePolledTasksQuery, updatePolledTasksArgs...) 218 | if err != nil { 219 | return []*tasq.Task{}, fmt.Errorf("%s: %w", errFailedToExecuteUpdate, err) 220 | } 221 | 222 | var ( 223 | polledMySQLTasks []*mySQLTask 224 | selectUpdatedTasksQuery, selectUpdatedTasksArgs = d.getQueryWithTableName(selectUpdatedPolledTasksSQLTemplate, map[string]any{ 225 | "polledTaskIDs": polledTaskIDs, 226 | }) 227 | ) 228 | 229 | err = tx.SelectContext(ctx, &polledMySQLTasks, selectUpdatedTasksQuery, selectUpdatedTasksArgs...) 230 | if err != nil { 231 | return []*tasq.Task{}, fmt.Errorf("%s: %w", errFailedToExecuteSelect, err) 232 | } 233 | 234 | err = tx.Commit() 235 | if err != nil { 236 | return []*tasq.Task{}, fmt.Errorf("%s: %w", errFailedToCommitTx, err) 237 | } 238 | 239 | return mySQLTasksToTasks(polledMySQLTasks), nil 240 | } 241 | 242 | // CleanTasks removes finished tasks from the queue 243 | // if their creation date is past the supplied duration. 
244 | func (d *Repository) CleanTasks(ctx context.Context, cleanAge time.Duration) (int64, error) { 245 | const cleanTasksSQLTemplate = `DELETE FROM 246 | {{.tableName}} 247 | WHERE 248 | status IN (:statuses) AND 249 | created_at <= :cleanAt;` 250 | 251 | var ( 252 | cleanTime = time.Now() 253 | cleanTasksQuery, cleanTasksArgs = d.getQueryWithTableName(cleanTasksSQLTemplate, map[string]any{ 254 | "statuses": tasq.GetTaskStatuses(tasq.FinishedTasks), 255 | "cleanAt": timeToString(cleanTime.Add(-cleanAge)), 256 | }) 257 | ) 258 | 259 | result, err := d.db.ExecContext(ctx, cleanTasksQuery, cleanTasksArgs...) 260 | if err != nil { 261 | return 0, fmt.Errorf("%s: %w", errFailedToExecuteDelete, err) 262 | } 263 | 264 | rowsAffected, err := result.RowsAffected() 265 | if err != nil { 266 | return 0, fmt.Errorf("%s: %w", errFailedGetRowsAffected, err) 267 | } 268 | 269 | return rowsAffected, nil 270 | } 271 | 272 | // RegisterStart marks a task as started with the 'in progress' status 273 | // and records the time of start. 274 | func (d *Repository) RegisterStart(ctx context.Context, task *tasq.Task) (*tasq.Task, error) { 275 | const ( 276 | updateTaskSQLTemplate = `UPDATE 277 | {{.tableName}} 278 | SET 279 | status = :status, 280 | started_at = :startTime 281 | WHERE 282 | id = :taskID;` 283 | selectUpdatedTaskSQLTemplate = `SELECT * 284 | FROM 285 | {{.tableName}} 286 | WHERE 287 | id = :taskID;` 288 | ) 289 | 290 | tx, err := d.db.Beginx() 291 | if err != nil { 292 | return nil, fmt.Errorf("%s: %w", errFailedToBeginTx, err) 293 | } 294 | 295 | defer rollback(tx) 296 | 297 | var ( 298 | mySQLTask = newFromTask(task) 299 | startTime = time.Now() 300 | updateTaskQuery, updateTaskArgs = d.getQueryWithTableName(updateTaskSQLTemplate, map[string]any{ 301 | "status": tasq.StatusInProgress, 302 | "startTime": timeToString(startTime), 303 | "taskID": mySQLTask.ID, 304 | }) 305 | ) 306 | 307 | _, err = tx.ExecContext(ctx, updateTaskQuery, updateTaskArgs...) 
308 | if err != nil { 309 | return nil, fmt.Errorf("%s: %w", errFailedToExecuteUpdate, err) 310 | } 311 | 312 | selectUpdatedTaskQuery, selectUpdatedTaskArgs := d.getQueryWithTableName(selectUpdatedTaskSQLTemplate, map[string]any{ 313 | "taskID": mySQLTask.ID, 314 | }) 315 | 316 | err = tx.QueryRowxContext(ctx, selectUpdatedTaskQuery, selectUpdatedTaskArgs...). 317 | StructScan(mySQLTask) 318 | if err != nil { 319 | return nil, fmt.Errorf("%s: %w", errFailedToExecuteSelect, err) 320 | } 321 | 322 | err = tx.Commit() 323 | if err != nil { 324 | return nil, fmt.Errorf("%s: %w", errFailedToCommitTx, err) 325 | } 326 | 327 | return mySQLTask.toTask(), nil 328 | } 329 | 330 | // RegisterError records an error message on the task as last error. 331 | func (d *Repository) RegisterError(ctx context.Context, task *tasq.Task, errTask error) (*tasq.Task, error) { 332 | const ( 333 | updateTaskSQLTemplate = `UPDATE 334 | {{.tableName}} 335 | SET 336 | last_error = :errorMessage 337 | WHERE 338 | id = :taskID;` 339 | selectUpdatedTaskSQLTemplate = `SELECT * 340 | FROM 341 | {{.tableName}} 342 | WHERE 343 | id = :taskID;` 344 | ) 345 | 346 | tx, err := d.db.Beginx() 347 | if err != nil { 348 | return nil, fmt.Errorf("%s: %w", errFailedToBeginTx, err) 349 | } 350 | 351 | defer rollback(tx) 352 | 353 | var ( 354 | mySQLTask = newFromTask(task) 355 | updateTaskQuery, updateTaskArgs = d.getQueryWithTableName(updateTaskSQLTemplate, map[string]any{ 356 | "errorMessage": errTask.Error(), 357 | "taskID": mySQLTask.ID, 358 | }) 359 | ) 360 | 361 | _, err = tx.ExecContext(ctx, updateTaskQuery, updateTaskArgs...) 362 | if err != nil { 363 | return nil, fmt.Errorf("%s: %w", errFailedToExecuteUpdate, err) 364 | } 365 | 366 | selectUpdatedTaskQuery, selectUpdatedTaskArgs := d.getQueryWithTableName(selectUpdatedTaskSQLTemplate, map[string]any{ 367 | "taskID": mySQLTask.ID, 368 | }) 369 | 370 | err = tx.QueryRowxContext(ctx, selectUpdatedTaskQuery, selectUpdatedTaskArgs...). 
371 | StructScan(mySQLTask) 372 | if err != nil { 373 | return nil, fmt.Errorf("%s: %w", errFailedToExecuteSelect, err) 374 | } 375 | 376 | err = tx.Commit() 377 | if err != nil { 378 | return nil, fmt.Errorf("%s: %w", errFailedToCommitTx, err) 379 | } 380 | 381 | return mySQLTask.toTask(), nil 382 | } 383 | 384 | // RegisterFinish marks a task as finished with the supplied status 385 | // and records the time of finish. 386 | func (d *Repository) RegisterFinish(ctx context.Context, task *tasq.Task, finishStatus tasq.TaskStatus) (*tasq.Task, error) { 387 | const ( 388 | updateTaskSQLTemplate = `UPDATE 389 | {{.tableName}} 390 | SET 391 | status = :status, 392 | finished_at = :finishTime 393 | WHERE 394 | id = :taskID;` 395 | selectUpdatedTaskSQLTemplate = `SELECT * 396 | FROM 397 | {{.tableName}} 398 | WHERE 399 | id = :taskID;` 400 | ) 401 | 402 | tx, err := d.db.Beginx() 403 | if err != nil { 404 | return nil, fmt.Errorf("%s: %w", errFailedToBeginTx, err) 405 | } 406 | 407 | defer rollback(tx) 408 | 409 | var ( 410 | mySQLTask = newFromTask(task) 411 | finishTime = time.Now() 412 | updateTasksQuery, updateTasksArgs = d.getQueryWithTableName(updateTaskSQLTemplate, map[string]any{ 413 | "status": finishStatus, 414 | "finishTime": timeToString(finishTime), 415 | "taskID": mySQLTask.ID, 416 | }) 417 | ) 418 | 419 | _, err = tx.ExecContext(ctx, updateTasksQuery, updateTasksArgs...) 420 | if err != nil { 421 | return nil, fmt.Errorf("%s: %w", errFailedToExecuteUpdate, err) 422 | } 423 | 424 | selectUpdatedTasksQuery, selectUpdatedTasksArgs := d.getQueryWithTableName(selectUpdatedTaskSQLTemplate, map[string]any{ 425 | "taskID": mySQLTask.ID, 426 | }) 427 | 428 | err = tx.QueryRowxContext(ctx, selectUpdatedTasksQuery, selectUpdatedTasksArgs...). 
429 | StructScan(mySQLTask) 430 | if err != nil { 431 | return nil, fmt.Errorf("%s: %w", errFailedToExecuteSelect, err) 432 | } 433 | 434 | err = tx.Commit() 435 | if err != nil { 436 | return nil, fmt.Errorf("%s: %w", errFailedToCommitTx, err) 437 | } 438 | 439 | return mySQLTask.toTask(), nil 440 | } 441 | 442 | // SubmitTask adds the supplied task to the queue. 443 | func (d *Repository) SubmitTask(ctx context.Context, task *tasq.Task) (*tasq.Task, error) { 444 | const ( 445 | insertTaskSQLTemplate = `INSERT INTO 446 | {{.tableName}} 447 | (id, type, args, queue, priority, status, max_receives, created_at, visible_at) 448 | VALUES 449 | (:id, :type, :args, :queue, :priority, :status, :maxReceives, :createdAt, :visibleAt);` 450 | selectInsertedTaskSQLTemplate = `SELECT * 451 | FROM 452 | {{.tableName}} 453 | WHERE 454 | id = :taskID;` 455 | ) 456 | 457 | tx, err := d.db.Beginx() 458 | if err != nil { 459 | return nil, fmt.Errorf("%s: %w", errFailedToBeginTx, err) 460 | } 461 | 462 | defer rollback(tx) 463 | 464 | var ( 465 | mySQLTask = newFromTask(task) 466 | insertTaskQuery, insertTaskArgs = d.getQueryWithTableName(insertTaskSQLTemplate, map[string]any{ 467 | "id": mySQLTask.ID, 468 | "type": mySQLTask.Type, 469 | "args": mySQLTask.Args, 470 | "queue": mySQLTask.Queue, 471 | "priority": mySQLTask.Priority, 472 | "status": mySQLTask.Status, 473 | "maxReceives": mySQLTask.MaxReceives, 474 | "createdAt": mySQLTask.CreatedAt, 475 | "visibleAt": mySQLTask.VisibleAt, 476 | }) 477 | ) 478 | 479 | _, err = tx.ExecContext(ctx, insertTaskQuery, insertTaskArgs...) 480 | if err != nil { 481 | return nil, fmt.Errorf("%s: %w", errFailedToExecuteInsert, err) 482 | } 483 | 484 | selectInsertedTaskQuery, selectInsertedTaskArgs := d.getQueryWithTableName(selectInsertedTaskSQLTemplate, map[string]any{ 485 | "taskID": mySQLTask.ID, 486 | }) 487 | 488 | err = tx.QueryRowxContext(ctx, selectInsertedTaskQuery, selectInsertedTaskArgs...). 
489 | StructScan(mySQLTask) 490 | if err != nil { 491 | return nil, fmt.Errorf("%s: %w", errFailedToExecuteSelect, err) 492 | } 493 | 494 | err = tx.Commit() 495 | if err != nil { 496 | return nil, fmt.Errorf("%s: %w", errFailedToCommitTx, err) 497 | } 498 | 499 | return mySQLTask.toTask(), nil 500 | } 501 | 502 | // DeleteTask removes the supplied task from the queue. 503 | func (d *Repository) DeleteTask(ctx context.Context, task *tasq.Task, safeDelete bool) error { 504 | var ( 505 | mySQLTask = newFromTask(task) 506 | conditions = []string{ 507 | `id = :taskID`, 508 | } 509 | parameters = map[string]any{ 510 | "taskID": mySQLTask.ID, 511 | } 512 | ) 513 | 514 | if safeDelete { 515 | d.applySafeDeleteConditions(&conditions, ¶meters) 516 | } 517 | 518 | deleteTaskSQLTemplate := `DELETE FROM {{.tableName}} WHERE ` + strings.Join(conditions, ` AND `) + `;` 519 | 520 | deleteTaskQuery, deleteTaskArgs := d.getQueryWithTableName(deleteTaskSQLTemplate, parameters) 521 | 522 | _, err := d.db.ExecContext(ctx, deleteTaskQuery, deleteTaskArgs...) 523 | if err != nil { 524 | return fmt.Errorf("%s: %w", errFailedToExecuteDelete, err) 525 | } 526 | 527 | return nil 528 | } 529 | 530 | // RequeueTask marks a task as new, so it can be picked up again. 
531 | func (d *Repository) RequeueTask(ctx context.Context, task *tasq.Task) (*tasq.Task, error) { 532 | const ( 533 | updateTaskSQLTemplate = `UPDATE 534 | {{.tableName}} 535 | SET 536 | status = :status 537 | WHERE 538 | id = :taskID;` 539 | selectUpdatedTaskSQLTemplate = `SELECT * 540 | FROM 541 | {{.tableName}} 542 | WHERE 543 | id = :taskID;` 544 | ) 545 | 546 | tx, err := d.db.Beginx() 547 | if err != nil { 548 | return nil, fmt.Errorf("%s: %w", errFailedToBeginTx, err) 549 | } 550 | 551 | defer rollback(tx) 552 | 553 | var ( 554 | mySQLTask = newFromTask(task) 555 | updateTaskQuery, updateTaskArgs = d.getQueryWithTableName(updateTaskSQLTemplate, map[string]any{ 556 | "status": tasq.StatusNew, 557 | "taskID": mySQLTask.ID, 558 | }) 559 | ) 560 | 561 | _, err = tx.ExecContext(ctx, updateTaskQuery, updateTaskArgs...) 562 | if err != nil { 563 | return nil, fmt.Errorf("%s: %w", errFailedToExecuteUpdate, err) 564 | } 565 | 566 | selectUpdatedTaskQuery, selectUpdatedTaskArgs := d.getQueryWithTableName(selectUpdatedTaskSQLTemplate, map[string]any{ 567 | "taskID": mySQLTask.ID, 568 | }) 569 | 570 | err = tx.QueryRowxContext(ctx, selectUpdatedTaskQuery, selectUpdatedTaskArgs...). 571 | StructScan(mySQLTask) 572 | if err != nil { 573 | return nil, fmt.Errorf("%s: %w", errFailedToExecuteSelect, err) 574 | } 575 | 576 | err = tx.Commit() 577 | if err != nil { 578 | return nil, fmt.Errorf("%s: %w", errFailedToCommitTx, err) 579 | } 580 | 581 | return mySQLTask.toTask(), err 582 | } 583 | 584 | // CountTasks returns the number of tasks in the queue based on the supplied filters. 
585 | func (d *Repository) CountTasks(ctx context.Context, taskStatuses []tasq.TaskStatus, taskTypes, queues []string) (int64, error) { 586 | var ( 587 | count int64 588 | selectTaskCountQuery, selectTaskCountArgs = d.getQueryWithTableName( 589 | d.buildCountSQLTemplate(taskStatuses, taskTypes, queues), 590 | ) 591 | ) 592 | 593 | err := d.db.GetContext(ctx, &count, selectTaskCountQuery, selectTaskCountArgs...) 594 | if err != nil { 595 | return 0, fmt.Errorf("%s: %w", errFailedToExecuteSelect, err) 596 | } 597 | 598 | return count, nil 599 | } 600 | 601 | // ScanTasks returns a list of tasks in the queue based on the supplied filters. 602 | func (d *Repository) ScanTasks(ctx context.Context, taskStatuses []tasq.TaskStatus, taskTypes, queues []string, ordering tasq.Ordering, scanLimit int) ([]*tasq.Task, error) { 603 | var ( 604 | scannedTasks []*mySQLTask 605 | selectScannedTasksQuery, selectScannedTasksArgs = d.getQueryWithTableName( 606 | d.buildScanSQLTemplate(taskStatuses, taskTypes, queues, ordering, scanLimit), 607 | ) 608 | ) 609 | 610 | err := d.db.SelectContext(ctx, &scannedTasks, selectScannedTasksQuery, selectScannedTasksArgs...) 611 | if err != nil { 612 | return []*tasq.Task{}, fmt.Errorf("%s: %w", errFailedToExecuteSelect, err) 613 | } 614 | 615 | return mySQLTasksToTasks(scannedTasks), nil 616 | } 617 | 618 | // PurgeTasks removes all tasks from the queue based on the supplied filters. 619 | func (d *Repository) PurgeTasks(ctx context.Context, taskStatuses []tasq.TaskStatus, taskTypes, queues []string, safeDelete bool) (int64, error) { 620 | selectPurgedTasksQuery, selectPurgedTasksArgs := d.getQueryWithTableName( 621 | d.buildPurgeSQLTemplate(taskStatuses, taskTypes, queues, safeDelete), 622 | ) 623 | 624 | result, err := d.db.ExecContext(ctx, selectPurgedTasksQuery, selectPurgedTasksArgs...) 
625 | if err != nil { 626 | return 0, fmt.Errorf("%s: %w", errFailedToExecuteDelete, err) 627 | } 628 | 629 | rowsAffected, err := result.RowsAffected() 630 | if err != nil { 631 | return 0, fmt.Errorf("%s: %w", errFailedGetRowsAffected, err) 632 | } 633 | 634 | return rowsAffected, nil 635 | } 636 | 637 | func (d *Repository) buildCountSQLTemplate(taskStatuses []tasq.TaskStatus, taskTypes, queues []string) (string, map[string]any) { 638 | var ( 639 | conditions, parameters = d.buildFilterConditions(taskStatuses, taskTypes, queues) 640 | sqlTemplate = `SELECT COUNT(*) FROM {{.tableName}}` 641 | ) 642 | 643 | if len(conditions) > 0 { 644 | sqlTemplate += ` WHERE ` + strings.Join(conditions, " AND ") 645 | } 646 | 647 | return sqlTemplate + `;`, parameters 648 | } 649 | 650 | func (d *Repository) buildScanSQLTemplate(taskStatuses []tasq.TaskStatus, taskTypes, queues []string, ordering tasq.Ordering, scanLimit int) (string, map[string]any) { 651 | var ( 652 | conditions, parameters = d.buildFilterConditions(taskStatuses, taskTypes, queues) 653 | sqlTemplate = `SELECT * FROM {{.tableName}}` 654 | ) 655 | 656 | if len(conditions) > 0 { 657 | sqlTemplate += ` WHERE ` + strings.Join(conditions, " AND ") 658 | } 659 | 660 | sqlTemplate += ` ORDER BY :scanOrdering LIMIT :limit;` 661 | 662 | parameters["scanOrdering"] = pq.Array(getOrderingDirectives(ordering)) 663 | parameters["limit"] = scanLimit 664 | 665 | return sqlTemplate + `;`, parameters 666 | } 667 | 668 | func (d *Repository) buildPurgeSQLTemplate(taskStatuses []tasq.TaskStatus, taskTypes, queues []string, safeDelete bool) (string, map[string]any) { 669 | var ( 670 | conditions, parameters = d.buildFilterConditions(taskStatuses, taskTypes, queues) 671 | sqlTemplate = `DELETE FROM {{.tableName}}` 672 | ) 673 | 674 | if safeDelete { 675 | d.applySafeDeleteConditions(&conditions, ¶meters) 676 | } 677 | 678 | if len(conditions) > 0 { 679 | sqlTemplate += ` WHERE ` + strings.Join(conditions, " AND ") 680 | } 681 | 682 
| return sqlTemplate + `;`, parameters 683 | } 684 | 685 | func (d *Repository) applySafeDeleteConditions(conditions *[]string, parameters *map[string]any) { 686 | *conditions = append(*conditions, `( 687 | ( 688 | visible_at <= :visibleAt 689 | ) OR ( 690 | status IN (:statuses) AND 691 | visible_at > :visibleAt 692 | ) 693 | )`) 694 | (*parameters)["statuses"] = []tasq.TaskStatus{tasq.StatusNew} 695 | (*parameters)["visibleAt"] = time.Now() 696 | } 697 | 698 | func (d *Repository) buildFilterConditions(taskStatuses []tasq.TaskStatus, taskTypes, queues []string) ([]string, map[string]any) { 699 | var ( 700 | conditions []string 701 | parameters = make(map[string]any) 702 | ) 703 | 704 | if len(taskStatuses) > 0 { 705 | conditions = append(conditions, `status IN (:filterStatuses)`) 706 | parameters["filterStatuses"] = taskStatuses 707 | } 708 | 709 | if len(taskTypes) > 0 { 710 | conditions = append(conditions, `type IN (:filterTypes)`) 711 | parameters["filterTypes"] = taskTypes 712 | } 713 | 714 | if len(queues) > 0 { 715 | conditions = append(conditions, `queue IN (:filterQueues)`) 716 | parameters["filterQueues"] = queues 717 | } 718 | 719 | return conditions, parameters 720 | } 721 | 722 | func (d *Repository) getQueryWithTableName(sqlTemplate string, args ...any) (string, []any) { 723 | query := interpolateSQL(sqlTemplate, map[string]any{ 724 | "tableName": d.tableName, 725 | }) 726 | 727 | query, args, err := sqlx.Named(query, args) 728 | if err != nil { 729 | panic(err) 730 | } 731 | 732 | query, args, err = sqlx.In(query, args...) 
733 | if err != nil { 734 | panic(err) 735 | } 736 | 737 | return d.db.Rebind(query), args 738 | } 739 | 740 | func (d *Repository) migrateTable(ctx context.Context) error { 741 | const sqlTemplate = `CREATE TABLE IF NOT EXISTS {{.tableName}} ( 742 | id binary(16) NOT NULL, 743 | type text NOT NULL, 744 | args longblob NOT NULL, 745 | queue text NOT NULL, 746 | priority smallint NOT NULL, 747 | status enum({{.enumValues}}) NOT NULL, 748 | receive_count int NOT NULL DEFAULT '0', 749 | max_receives int NOT NULL DEFAULT '0', 750 | last_error text, 751 | created_at datetime(6) NOT NULL DEFAULT '0001-01-01 00:00:00.000000', 752 | started_at datetime(6), 753 | finished_at datetime(6), 754 | visible_at datetime(6) NOT NULL DEFAULT '0001-01-01 00:00:00.000000', 755 | PRIMARY KEY (id) 756 | );` 757 | 758 | query := interpolateSQL(sqlTemplate, map[string]any{ 759 | "tableName": d.tableName, 760 | "enumValues": sliceToMySQLValueList(tasq.GetTaskStatuses(tasq.AllTasks)), 761 | }) 762 | 763 | _, err := d.db.ExecContext(ctx, query) 764 | if err != nil { 765 | return fmt.Errorf("%s: %w", errFailedToExecuteCreateTable, err) 766 | } 767 | 768 | return nil 769 | } 770 | 771 | func getOrderingDirectives(ordering tasq.Ordering) []string { 772 | var ( 773 | OrderingCreatedAtFirst = []string{"created_at ASC", "priority DESC"} 774 | OrderingPriorityFirst = []string{"priority DESC", "created_at ASC"} 775 | ) 776 | 777 | if orderingDirectives, ok := map[tasq.Ordering][]string{ 778 | tasq.OrderingCreatedAtFirst: OrderingCreatedAtFirst, 779 | tasq.OrderingPriorityFirst: OrderingPriorityFirst, 780 | }[ordering]; ok { 781 | return orderingDirectives 782 | } 783 | 784 | return OrderingCreatedAtFirst 785 | } 786 | 787 | func rollback(tx *sqlx.Tx) { 788 | if err := tx.Rollback(); err != nil && !errors.Is(err, sql.ErrTxDone) { 789 | panic(err) 790 | } 791 | } 792 | 793 | func sliceToMySQLValueList[T any](slice []T) string { 794 | stringSlice := make([]string, 0, len(slice)) 795 | 796 | for _, s := 
range slice { 797 | stringSlice = append(stringSlice, fmt.Sprint(s)) 798 | } 799 | 800 | return fmt.Sprintf(`"%s"`, strings.Join(stringSlice, `", "`)) 801 | } 802 | 803 | func tableName(prefix string) string { 804 | const tableName = "tasks" 805 | 806 | if len(prefix) > 0 { 807 | return prefix + "_" + tableName 808 | } 809 | 810 | return tableName 811 | } 812 | 813 | func interpolateSQL(sql string, params map[string]any) string { 814 | template, err := template.New("sql").Parse(sql) 815 | if err != nil { 816 | panic(err) 817 | } 818 | 819 | var outputBuffer bytes.Buffer 820 | 821 | err = template.Execute(&outputBuffer, params) 822 | if err != nil { 823 | panic(err) 824 | } 825 | 826 | return outputBuffer.String() 827 | } 828 | -------------------------------------------------------------------------------- /repository/mysql/mysqlTask.go: -------------------------------------------------------------------------------- 1 | package mysql 2 | 3 | import ( 4 | "database/sql" 5 | "database/sql/driver" 6 | "errors" 7 | "fmt" 8 | "time" 9 | 10 | "github.com/google/uuid" 11 | "github.com/greencoda/tasq" 12 | ) 13 | 14 | const ( 15 | idLength = 16 16 | timeFormat = "2006-01-02 15:04:05.999999" 17 | ) 18 | 19 | var ( 20 | errIncorrectLength = errors.New("Scan: MySQLTaskID is of incorrect length") 21 | errUnableToScan = errors.New("Scan: unable to scan type into MySQLTaskID") 22 | ) 23 | 24 | // TaskID represents the types used to manage conversion of UUID 25 | // to MySQL's binary(16) format. 26 | type TaskID [idLength]byte 27 | 28 | // Scan implements sql.Scanner so TaskIDs can be read from MySQL transparently. 
29 | func (i *TaskID) Scan(src any) error { 30 | switch src := src.(type) { 31 | case nil: 32 | return nil 33 | case []byte: 34 | if len(src) == 0 { 35 | return nil 36 | } 37 | 38 | if len(src) != idLength { 39 | return fmt.Errorf("%w: %v", errIncorrectLength, len(src)) 40 | } 41 | 42 | copy((*i)[:], src) 43 | default: 44 | return fmt.Errorf("%w: %T", errUnableToScan, src) 45 | } 46 | 47 | return nil 48 | } 49 | 50 | // Value implements sql.Valuer so that TaskIDs can be written to MySQL 51 | // transparently. 52 | func (i TaskID) Value() (driver.Value, error) { 53 | return i[:], nil 54 | } 55 | 56 | type mySQLTask struct { 57 | ID TaskID `db:"id"` 58 | Type string `db:"type"` 59 | Args []byte `db:"args"` 60 | Queue string `db:"queue"` 61 | Priority int16 `db:"priority"` 62 | Status tasq.TaskStatus `db:"status"` 63 | ReceiveCount int32 `db:"receive_count"` 64 | MaxReceives int32 `db:"max_receives"` 65 | LastError sql.NullString `db:"last_error"` 66 | CreatedAt string `db:"created_at"` 67 | StartedAt sql.NullString `db:"started_at"` 68 | FinishedAt sql.NullString `db:"finished_at"` 69 | VisibleAt string `db:"visible_at"` 70 | } 71 | 72 | func newFromTask(task *tasq.Task) *mySQLTask { 73 | return &mySQLTask{ 74 | ID: TaskID(task.ID), 75 | Type: task.Type, 76 | Args: task.Args, 77 | Queue: task.Queue, 78 | Priority: task.Priority, 79 | Status: task.Status, 80 | ReceiveCount: task.ReceiveCount, 81 | MaxReceives: task.MaxReceives, 82 | LastError: stringToSQLNullString(task.LastError), 83 | CreatedAt: timeToString(task.CreatedAt), 84 | StartedAt: timeToSQLNullString(task.StartedAt), 85 | FinishedAt: timeToSQLNullString(task.FinishedAt), 86 | VisibleAt: timeToString(task.VisibleAt), 87 | } 88 | } 89 | 90 | func (t *mySQLTask) toTask() *tasq.Task { 91 | return &tasq.Task{ 92 | ID: uuid.UUID(t.ID), 93 | Type: t.Type, 94 | Args: t.Args, 95 | Queue: t.Queue, 96 | Priority: t.Priority, 97 | Status: t.Status, 98 | ReceiveCount: t.ReceiveCount, 99 | MaxReceives: t.MaxReceives, 
100 | LastError: parseNullableString(t.LastError), 101 | CreatedAt: parseTime(t.CreatedAt), 102 | StartedAt: parseNullableTime(t.StartedAt), 103 | FinishedAt: parseNullableTime(t.FinishedAt), 104 | VisibleAt: parseTime(t.VisibleAt), 105 | } 106 | } 107 | 108 | func mySQLTasksToTasks(mySQLTasks []*mySQLTask) []*tasq.Task { 109 | tasks := make([]*tasq.Task, len(mySQLTasks)) 110 | 111 | for i, mySQLTask := range mySQLTasks { 112 | tasks[i] = mySQLTask.toTask() 113 | } 114 | 115 | return tasks 116 | } 117 | 118 | func stringToSQLNullString(input *string) sql.NullString { 119 | if input == nil { 120 | return sql.NullString{ 121 | String: "", 122 | Valid: false, 123 | } 124 | } 125 | 126 | return sql.NullString{ 127 | String: *input, 128 | Valid: true, 129 | } 130 | } 131 | 132 | func timeToString(input time.Time) string { 133 | return input.Format(timeFormat) 134 | } 135 | 136 | func timeToSQLNullString(input *time.Time) sql.NullString { 137 | if input == nil { 138 | return sql.NullString{ 139 | String: "", 140 | Valid: false, 141 | } 142 | } 143 | 144 | return sql.NullString{ 145 | String: input.Format(timeFormat), 146 | Valid: true, 147 | } 148 | } 149 | 150 | func parseNullableString(input sql.NullString) *string { 151 | if !input.Valid { 152 | return nil 153 | } 154 | 155 | return &input.String 156 | } 157 | 158 | func parseTime(input string) time.Time { 159 | parsedTime, err := time.Parse(timeFormat, input) 160 | if err != nil { 161 | return time.Time{} 162 | } 163 | 164 | return parsedTime 165 | } 166 | 167 | func parseNullableTime(input sql.NullString) *time.Time { 168 | if !input.Valid { 169 | return nil 170 | } 171 | 172 | parsedTime := parseTime(input.String) 173 | 174 | return &parsedTime 175 | } 176 | -------------------------------------------------------------------------------- /repository/mysql/mysqlTask_test.go: -------------------------------------------------------------------------------- 1 | package mysql_test 2 | 3 | import ( 4 | "database/sql" 5 | 
"testing" 6 | 7 | "github.com/greencoda/tasq/repository/mysql" 8 | "github.com/stretchr/testify/assert" 9 | ) 10 | 11 | // TestMySQLTaskIDScan checks that TaskID.Scan accepts nil and an empty byte slice and rejects a one-byte slice and a string. 12 | func TestMySQLTaskIDScan(t *testing.T) { 13 | t.Parallel() 14 | 15 | var mySQLID mysql.TaskID 16 | 17 | err := mySQLID.Scan(nil) 18 | assert.Nil(t, err) 19 | 20 | err = mySQLID.Scan([]byte{}) 21 | assert.Nil(t, err) 22 | 23 | err = mySQLID.Scan([]byte{12}) 24 | assert.NotNil(t, err) 25 | 26 | err = mySQLID.Scan("test") 27 | assert.NotNil(t, err) 28 | } 29 | 30 | // TestStringToSQLNullString checks that non-nil pointers (including one to "") become valid NullStrings and a nil pointer becomes an invalid, empty one. 31 | func TestStringToSQLNullString(t *testing.T) { 32 | t.Parallel() 33 | 34 | var ( 35 | emptyInput = "" 36 | nonEmptyInput = "test" 37 | nilInput *string 38 | ) 39 | 40 | sqlNullString := mysql.StringToSQLNullString(&emptyInput) 41 | assert.Equal(t, sql.NullString{ 42 | String: emptyInput, 43 | Valid: true, 44 | }, sqlNullString) 45 | 46 | sqlNullString = mysql.StringToSQLNullString(&nonEmptyInput) 47 | assert.Equal(t, sql.NullString{ 48 | String: nonEmptyInput, 49 | Valid: true, 50 | }, sqlNullString) 51 | 52 | sqlNullString = mysql.StringToSQLNullString(nilInput) 53 | assert.Equal(t, sql.NullString{ 54 | String: "", 55 | Valid: false, 56 | }, sqlNullString) 57 | } 58 | 59 | // TestParseNullableString checks that valid NullStrings (including an empty one) map to a pointer to their value and an invalid one maps to nil. 60 | func TestParseNullableString(t *testing.T) { 61 | t.Parallel() 62 | 63 | var ( 64 | emptyInput = sql.NullString{ 65 | String: "", 66 | Valid: true, 67 | } 68 | nonEmptyInput = sql.NullString{ 69 | String: "test", 70 | Valid: true, 71 | } 72 | nilInput = sql.NullString{ 73 | String: "", 74 | Valid: false, 75 | } 76 | ) 77 | 78 | // Expected value goes first (testify order); comparing against a pointer lets Equal deep-compare the pointees and fail cleanly instead of panicking on a nil dereference. 79 | output := mysql.ParseNullableString(emptyInput) 80 | assert.NotNil(t, output) 81 | assert.Equal(t, &emptyInput.String, output) 82 | 83 | output = mysql.ParseNullableString(nonEmptyInput) 84 | assert.NotNil(t, output) 85 | assert.Equal(t, &nonEmptyInput.String, output) 86 | 87 | output = mysql.ParseNullableString(nilInput) 88 | assert.Nil(t, output) 89 | } 90 | -------------------------------------------------------------------------------- /repository/mysql/mysql_test.go:
-------------------------------------------------------------------------------- 1 | package mysql_test 2 | 3 | import ( 4 | "context" 5 | "database/sql" 6 | "database/sql/driver" 7 | "errors" 8 | "regexp" 9 | "testing" 10 | "time" 11 | 12 | "github.com/DATA-DOG/go-sqlmock" 13 | "github.com/google/uuid" 14 | "github.com/greencoda/tasq" 15 | "github.com/greencoda/tasq/repository/mysql" 16 | "github.com/stretchr/testify/assert" 17 | "github.com/stretchr/testify/require" 18 | "github.com/stretchr/testify/suite" 19 | ) 20 | // Package-level fixtures shared by the suite: testTask is a started task and taskValues holds its values for sqlmock rows built with taskColumns. 21 | var ( 22 | ctx = context.Background() 23 | testTask = getStartedTestTask() 24 | taskColumns = []string{ 25 | "id", 26 | "type", 27 | "args", 28 | "queue", 29 | "priority", 30 | "status", 31 | "receive_count", 32 | "max_receives", 33 | "last_error", 34 | "created_at", 35 | "started_at", 36 | "finished_at", 37 | "visible_at", 38 | } 39 | testTaskType = "testTask" 40 | testTaskQueue = "testQueue" 41 | taskValues = mysql.GetTestTaskValues(testTask) 42 | errSQL = errors.New("sql error") 43 | errTask = errors.New("task error") 44 | ) 45 | // getStartedTestTask returns a new task of type testTaskType in testTaskQueue whose StartedAt is one second after its CreatedAt. NOTE(review): the error from tasq.NewTask is discarded. 46 | func getStartedTestTask() *tasq.Task { 47 | var ( 48 | testTask, _ = tasq.NewTask(testTaskType, true, testTaskQueue, 100, 5) 49 | startTime = testTask.CreatedAt.Add(time.Second) 50 | ) 51 | 52 | testTask.StartedAt = &startTime 53 | 54 | return testTask 55 | } 56 | // MySQLTestSuite holds a sqlmock-backed *sql.DB and the MySQL repository under test; SetupTest rebuilds them before each test. 57 | type MySQLTestSuite struct { 58 | suite.Suite 59 | db *sql.DB 60 | sqlMock sqlmock.Sqlmock 61 | mockedRepository tasq.IRepository 62 | } 63 | // TestTaskTestSuite runs MySQLTestSuite. 64 | func TestTaskTestSuite(t *testing.T) { 65 | t.Parallel() 66 | 67 | suite.Run(t, new(MySQLTestSuite)) 68 | } 69 | // SetupTest creates a fresh sqlmock database and a repository with the "test" prefix; the expected queries in this file target the test_tasks table. 70 | func (s *MySQLTestSuite) SetupTest() { 71 | var err error 72 | 73 | s.db, s.sqlMock, err = sqlmock.New() 74 | require.Nil(s.T(), err) 75 | 76 | s.mockedRepository, err = mysql.NewRepository(s.db, "test") 77 | require.NotNil(s.T(), s.mockedRepository) 78 | require.Nil(s.T(), err) 79 | } 80 | // TestNewRepository checks NewRepository with a *sql.DB (with and without a prefix), a DSN string, an invalid DSN and a datasource of an unsupported type. 81 | func (s *MySQLTestSuite) TestNewRepository() { 82 | // providing the datasource as *sql.DB
83 | repository, err := mysql.NewRepository(s.db, "test") 84 | assert.NotNil(s.T(), repository) 85 | assert.Nil(s.T(), err) 86 | 87 | // providing the datasource as *sql.DB with no prefix 88 | repository, err = mysql.NewRepository(s.db, "") 89 | assert.NotNil(s.T(), repository) 90 | assert.Nil(s.T(), err) 91 | 92 | // providing the datasource as dsn string 93 | repository, err = mysql.NewRepository("root:root@/test", "test") 94 | assert.NotNil(s.T(), repository) 95 | assert.Nil(s.T(), err) 96 | 97 | // providing an invalid DSN string 98 | repository, err = mysql.NewRepository("invalidDSN", "test") 99 | assert.Nil(s.T(), repository) 100 | assert.NotNil(s.T(), err) 101 | 102 | // providing a datasource of an unsupported type 103 | repository, err = mysql.NewRepository(false, "test") 104 | assert.Nil(s.T(), repository) 105 | assert.NotNil(s.T(), err) 106 | } 107 | // TestMigrate checks that Migrate returns an error when creating the tasks table fails and succeeds otherwise. 108 | func (s *MySQLTestSuite) TestMigrate() { 109 | // First try - creating the tasks table fails 110 | s.sqlMock.ExpectExec(`CREATE TABLE IF NOT EXISTS test_tasks`).WillReturnError(errSQL) 111 | 112 | err := s.mockedRepository.Migrate(ctx) 113 | assert.NotNil(s.T(), err) 114 | 115 | // Second try - migration succeeds 116 | s.sqlMock.ExpectExec(`CREATE TABLE IF NOT EXISTS test_tasks`).WillReturnResult(sqlmock.NewResult(1, 1)) 117 | 118 | err = s.mockedRepository.Migrate(ctx) 119 | assert.Nil(s.T(), err) 120 | } 121 | // TestPingTasks exercises PingTasks with an empty ID list, each failure point of its begin/update/select/commit sequence (a failing rollback is expected to panic), and a successful ping. 122 | func (s *MySQLTestSuite) TestPingTasks() { 123 | var ( 124 | taskUUID = uuid.New() 125 | taskUUIDBytes, _ = taskUUID.MarshalBinary() 126 | updateMockRegexp = regexp.QuoteMeta(`UPDATE test_tasks SET visible_at = ?
WHERE id IN (?);`) 127 | selectMockRegexp = regexp.QuoteMeta(`SELECT * FROM test_tasks WHERE id IN (?);`) 128 | ) 129 | 130 | // pinging empty tasklist 131 | tasks, err := s.mockedRepository.PingTasks(ctx, []uuid.UUID{}, 15*time.Second) 132 | assert.Len(s.T(), tasks, 0) 133 | assert.Nil(s.T(), err) 134 | 135 | // beginning the transaction fails 136 | s.sqlMock.ExpectBegin().WillReturnError(errSQL) 137 | tasks, err = s.mockedRepository.PingTasks(ctx, []uuid.UUID{taskUUID}, 15*time.Second) 138 | assert.Len(s.T(), tasks, 0) 139 | assert.NotNil(s.T(), err) 140 | 141 | // pinging when the update fails 142 | s.sqlMock.ExpectBegin() 143 | s.sqlMock.ExpectExec(updateMockRegexp).WillReturnError(errSQL) 144 | s.sqlMock.ExpectRollback() 145 | 146 | tasks, err = s.mockedRepository.PingTasks(ctx, []uuid.UUID{taskUUID}, 15*time.Second) 147 | assert.Len(s.T(), tasks, 0) 148 | assert.NotNil(s.T(), err) 149 | 150 | // pinging when the update fails and the rollback fails too: PingTasks is expected to panic with errSQL 151 | assert.PanicsWithError(s.T(), errSQL.Error(), func() { 152 | s.sqlMock.ExpectBegin() 153 | s.sqlMock.ExpectExec(updateMockRegexp).WillReturnError(errSQL) 154 | s.sqlMock.ExpectRollback().WillReturnError(errSQL) 155 | 156 | _, _ = s.mockedRepository.PingTasks(ctx, []uuid.UUID{taskUUID}, 15*time.Second) 157 | }) 158 | 159 | // pinging when the update succeeds but selecting the pinged tasks fails 160 | s.sqlMock.ExpectBegin() 161 | s.sqlMock.ExpectExec(updateMockRegexp).WillReturnResult(sqlmock.NewResult(1, 1)) 162 | s.sqlMock.ExpectQuery(selectMockRegexp).WillReturnError(errSQL) 163 | s.sqlMock.ExpectRollback() 164 | 165 | tasks, err = s.mockedRepository.PingTasks(ctx, []uuid.UUID{taskUUID}, 15*time.Second) 166 | assert.Len(s.T(), tasks, 0) 167 | assert.NotNil(s.T(), err) 168 | 169 | // pinging existing task succeeds, commit fails 170 | s.sqlMock.ExpectBegin() 171 | s.sqlMock.ExpectExec(updateMockRegexp).WillReturnResult(sqlmock.NewResult(1, 1)) 172 |
s.sqlMock.ExpectQuery(selectMockRegexp).WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(taskUUIDBytes)) 173 | s.sqlMock.ExpectCommit().WillReturnError(errSQL) 174 | 175 | tasks, err = s.mockedRepository.PingTasks(ctx, []uuid.UUID{taskUUID}, 15*time.Second) 176 | assert.Len(s.T(), tasks, 0) 177 | assert.NotNil(s.T(), err) 178 | 179 | // pinging existing task succeeds, commit succeeds 180 | s.sqlMock.ExpectBegin() 181 | s.sqlMock.ExpectExec(updateMockRegexp).WillReturnResult(sqlmock.NewResult(1, 1)) 182 | s.sqlMock.ExpectQuery(selectMockRegexp).WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(taskUUIDBytes)) 183 | s.sqlMock.ExpectCommit() 184 | 185 | tasks, err = s.mockedRepository.PingTasks(ctx, []uuid.UUID{taskUUID}, 15*time.Second) 186 | assert.Len(s.T(), tasks, 1) 187 | assert.Nil(s.T(), err) 188 | } 189 | // TestPollTasks exercises PollTasks with a zero limit, each failure point of its begin/select-IDs/update/select-tasks/commit sequence, an empty ID result, and successful polls with a known and an unknown (-1) ordering. 190 | func (s *MySQLTestSuite) TestPollTasks() { 191 | var ( 192 | taskUUID = uuid.New() 193 | taskUUIDBytes, _ = taskUUID.MarshalBinary() 194 | selectMockRegexp = regexp.QuoteMeta(`SELECT 195 | id 196 | FROM 197 | test_tasks 198 | WHERE 199 | type IN (?) AND 200 | queue IN (?) AND 201 | status IN (?, ?, ?) AND 202 | visible_at <= ? 203 | ORDER BY 204 | ?, ? 205 | LIMIT ? 206 | FOR UPDATE SKIP LOCKED;`) 207 | updateMockRegexp = regexp.QuoteMeta(`UPDATE 208 | test_tasks 209 | SET 210 | status = ?, 211 | receive_count = receive_count + 1, 212 | visible_at = ?
213 | WHERE 214 | id IN (?);`) 215 | selectUpdatedMockRegexp = regexp.QuoteMeta(`SELECT 216 | * 217 | FROM 218 | test_tasks 219 | WHERE 220 | id IN (?);`) 221 | ) 222 | 223 | // polling with 0 limit 224 | tasks, err := s.mockedRepository.PollTasks(ctx, []string{"testTask"}, []string{"testQueue"}, 15*time.Second, tasq.OrderingCreatedAtFirst, 0) 225 | assert.Len(s.T(), tasks, 0) 226 | assert.Nil(s.T(), err) 227 | 228 | // beginning the transaction fails 229 | s.sqlMock.ExpectBegin().WillReturnError(errSQL) 230 | 231 | tasks, err = s.mockedRepository.PollTasks(ctx, []string{"testTask"}, []string{"testQueue"}, 15*time.Second, tasq.OrderingCreatedAtFirst, 1) 232 | assert.Len(s.T(), tasks, 0) 233 | assert.NotNil(s.T(), err) 234 | 235 | // polling when DB returns an error 236 | s.sqlMock.ExpectBegin() 237 | s.sqlMock.ExpectQuery(selectMockRegexp).WillReturnError(errSQL) 238 | s.sqlMock.ExpectRollback() 239 | 240 | tasks, err = s.mockedRepository.PollTasks(ctx, []string{"testTask"}, []string{"testQueue"}, 15*time.Second, tasq.OrderingCreatedAtFirst, 1) 241 | assert.Len(s.T(), tasks, 0) 242 | assert.NotNil(s.T(), err) 243 | 244 | // polling when DB returns no task IDs 245 | s.sqlMock.ExpectBegin() 246 | s.sqlMock.ExpectQuery(selectMockRegexp).WillReturnRows(sqlmock.NewRows([]string{"id"})) 247 | s.sqlMock.ExpectRollback() 248 | 249 | tasks, err = s.mockedRepository.PollTasks(ctx, []string{"testTask"}, []string{"testQueue"}, 15*time.Second, tasq.OrderingCreatedAtFirst, 1) 250 | assert.Len(s.T(), tasks, 0) 251 | assert.Nil(s.T(), err) 252 | 253 | // polling when DB fails to update task 254 | s.sqlMock.ExpectBegin() 255 | s.sqlMock.ExpectQuery(selectMockRegexp).WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(taskUUIDBytes)) 256 | s.sqlMock.ExpectExec(updateMockRegexp).WillReturnError(errSQL) 257 | s.sqlMock.ExpectRollback() 258 | 259 | tasks, err = s.mockedRepository.PollTasks(ctx, []string{"testTask"}, []string{"testQueue"}, 15*time.Second, tasq.OrderingCreatedAtFirst,
1) 260 | assert.Len(s.T(), tasks, 0) 261 | assert.NotNil(s.T(), err) 262 | 263 | // polling when DB succeeds to update task but fails to select the updated tasks 264 | s.sqlMock.ExpectBegin() 265 | s.sqlMock.ExpectQuery(selectMockRegexp).WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(taskUUIDBytes)) 266 | s.sqlMock.ExpectExec(updateMockRegexp).WillReturnResult(sqlmock.NewResult(1, 1)) 267 | s.sqlMock.ExpectQuery(selectUpdatedMockRegexp).WillReturnError(errSQL) 268 | s.sqlMock.ExpectRollback() 269 | 270 | tasks, err = s.mockedRepository.PollTasks(ctx, []string{"testTask"}, []string{"testQueue"}, 15*time.Second, tasq.OrderingCreatedAtFirst, 1) 271 | assert.Len(s.T(), tasks, 0) 272 | assert.NotNil(s.T(), err) 273 | 274 | // polling when the updated tasks are selected but the commit fails 275 | s.sqlMock.ExpectBegin() 276 | s.sqlMock.ExpectQuery(selectMockRegexp).WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(taskUUIDBytes)) 277 | s.sqlMock.ExpectExec(updateMockRegexp).WillReturnResult(sqlmock.NewResult(1, 1)) 278 | s.sqlMock.ExpectQuery(selectUpdatedMockRegexp).WillReturnRows(sqlmock.NewRows(taskColumns).AddRow(taskValues...)) 279 | s.sqlMock.ExpectCommit().WillReturnError(errSQL) 280 | 281 | tasks, err = s.mockedRepository.PollTasks(ctx, []string{"testTask"}, []string{"testQueue"}, 15*time.Second, tasq.OrderingCreatedAtFirst, 1) 282 | assert.Len(s.T(), tasks, 0) 283 | assert.NotNil(s.T(), err) 284 | 285 | // polling successfully 286 | s.sqlMock.ExpectBegin() 287 | s.sqlMock.ExpectQuery(selectMockRegexp).WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(taskUUIDBytes)) 288 | s.sqlMock.ExpectExec(updateMockRegexp).WillReturnResult(sqlmock.NewResult(1, 1)) 289 | s.sqlMock.ExpectQuery(selectUpdatedMockRegexp).WillReturnRows(sqlmock.NewRows(taskColumns).AddRow(taskValues...)) 290 | s.sqlMock.ExpectCommit() 291 | 292 | tasks, err = s.mockedRepository.PollTasks(ctx, []string{"testTask"}, []string{"testQueue"}, 15*time.Second,
tasq.OrderingCreatedAtFirst, 1) 293 | assert.Len(s.T(), tasks, 1) 294 | assert.Nil(s.T(), err) 295 | 296 | // polling successfully with unknown ordering 297 | s.sqlMock.ExpectBegin() 298 | s.sqlMock.ExpectQuery(selectMockRegexp).WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(taskUUIDBytes)) 299 | s.sqlMock.ExpectExec(updateMockRegexp).WillReturnResult(sqlmock.NewResult(1, 1)) 300 | s.sqlMock.ExpectQuery(selectUpdatedMockRegexp).WillReturnRows(sqlmock.NewRows(taskColumns).AddRow(taskValues...)) 301 | s.sqlMock.ExpectCommit() 302 | 303 | tasks, err = s.mockedRepository.PollTasks(ctx, []string{"testTask"}, []string{"testQueue"}, 15*time.Second, -1, 1) 304 | assert.Len(s.T(), tasks, 1) 305 | assert.Nil(s.T(), err) 306 | } 307 | // TestCleanTasks checks CleanTasks for an exec error, a result that cannot report affected rows, and a successful delete of one row. 308 | func (s *MySQLTestSuite) TestCleanTasks() { 309 | deleteMockRegexp := regexp.QuoteMeta(`DELETE 310 | FROM 311 | test_tasks 312 | WHERE 313 | status IN (?, ?) AND 314 | created_at <= ?;`) 315 | 316 | // cleaning when DB returns error 317 | s.sqlMock.ExpectExec(deleteMockRegexp).WillReturnError(errSQL) 318 | 319 | rowsAffected, err := s.mockedRepository.CleanTasks(ctx, time.Hour) 320 | assert.Zero(s.T(), rowsAffected) 321 | assert.NotNil(s.T(), err) 322 | 323 | // cleaning when the result cannot report affected rows (driver.ResultNoRows returns an error from RowsAffected) 324 | s.sqlMock.ExpectExec(deleteMockRegexp).WillReturnResult(driver.ResultNoRows) 325 | 326 | rowsAffected, err = s.mockedRepository.CleanTasks(ctx, time.Hour) 327 | assert.Equal(s.T(), int64(0), rowsAffected) 328 | assert.NotNil(s.T(), err) 329 | 330 | // cleaning successful 331 | s.sqlMock.ExpectExec(deleteMockRegexp).WillReturnResult(sqlmock.NewResult(1, 1)) 332 | 333 | rowsAffected, err = s.mockedRepository.CleanTasks(ctx, time.Hour) 334 | assert.Equal(s.T(), int64(1), rowsAffected) 335 | assert.Nil(s.T(), err) 336 | } 337 | // TestRegisterStart exercises RegisterStart through each failure point of its begin/update/select/commit sequence, an empty select result, and a successful run. 338 | func (s *MySQLTestSuite) TestRegisterStart() { 339 | var ( 340 | updateMockRegexp = regexp.QuoteMeta(`UPDATE 341 | test_tasks 342 | SET 343 | status = ?, 344 | started_at = ?
345 | WHERE 346 | id = ?;`) 347 | selectMockRegexp = regexp.QuoteMeta(`SELECT * 348 | FROM 349 | test_tasks 350 | WHERE 351 | id = ?;`) 352 | ) 353 | 354 | // beginning the transaction fails 355 | s.sqlMock.ExpectBegin().WillReturnError(errSQL) 356 | 357 | task, err := s.mockedRepository.RegisterStart(ctx, testTask) 358 | assert.Empty(s.T(), task) 359 | assert.NotNil(s.T(), err) 360 | 361 | // registering start when update fails 362 | s.sqlMock.ExpectBegin() 363 | s.sqlMock.ExpectExec(updateMockRegexp).WillReturnError(errSQL) 364 | s.sqlMock.ExpectRollback() 365 | 366 | task, err = s.mockedRepository.RegisterStart(ctx, testTask) 367 | assert.Empty(s.T(), task) 368 | assert.NotNil(s.T(), err) 369 | 370 | // registering start when update is successful but select fails 371 | s.sqlMock.ExpectBegin() 372 | s.sqlMock.ExpectExec(updateMockRegexp).WillReturnResult(sqlmock.NewResult(1, 1)) 373 | s.sqlMock.ExpectQuery(selectMockRegexp).WillReturnError(errSQL) 374 | s.sqlMock.ExpectRollback() 375 | 376 | task, err = s.mockedRepository.RegisterStart(ctx, testTask) 377 | assert.Empty(s.T(), task) 378 | assert.NotNil(s.T(), err) 379 | 380 | // registering start when commit fails 381 | s.sqlMock.ExpectBegin() 382 | s.sqlMock.ExpectExec(updateMockRegexp).WillReturnResult(sqlmock.NewResult(1, 1)) 383 | s.sqlMock.ExpectQuery(selectMockRegexp).WillReturnRows(sqlmock.NewRows(taskColumns).AddRow(taskValues...)) 384 | s.sqlMock.ExpectCommit().WillReturnError(errSQL) 385 | 386 | task, err = s.mockedRepository.RegisterStart(ctx, testTask) 387 | assert.Empty(s.T(), task) 388 | assert.NotNil(s.T(), err) 389 | 390 | // registering start when update is successful but select returns no rows 391 | s.sqlMock.ExpectBegin() 392 | s.sqlMock.ExpectExec(updateMockRegexp).WillReturnResult(sqlmock.NewResult(1, 1)) 393 | s.sqlMock.ExpectQuery(selectMockRegexp).WillReturnRows(sqlmock.NewRows(taskColumns)) 394 | s.sqlMock.ExpectRollback() 395 | 396 | task, err = s.mockedRepository.RegisterStart(ctx,
testTask) 397 | assert.Empty(s.T(), task) 398 | assert.NotNil(s.T(), err) 399 | 400 | // registering start successful 401 | s.sqlMock.ExpectBegin() 402 | s.sqlMock.ExpectExec(updateMockRegexp).WillReturnResult(sqlmock.NewResult(1, 1)) 403 | s.sqlMock.ExpectQuery(selectMockRegexp).WillReturnRows(sqlmock.NewRows(taskColumns).AddRow(taskValues...)) 404 | s.sqlMock.ExpectCommit() 405 | 406 | task, err = s.mockedRepository.RegisterStart(ctx, testTask) 407 | assert.NotEmpty(s.T(), task) 408 | assert.Nil(s.T(), err) 409 | } 410 | // TestRegisterError exercises RegisterError through each failure point of its begin/update/select/commit sequence, an empty select result, and a successful run. 411 | func (s *MySQLTestSuite) TestRegisterError() { 412 | var ( 413 | updateMockRegexp = regexp.QuoteMeta(`UPDATE 414 | test_tasks 415 | SET 416 | last_error = ? 417 | WHERE 418 | id = ?;`) 419 | selectMockRegexp = regexp.QuoteMeta(`SELECT * 420 | FROM 421 | test_tasks 422 | WHERE 423 | id = ?;`) 424 | ) 425 | 426 | // beginning the transaction fails 427 | s.sqlMock.ExpectBegin().WillReturnError(errSQL) 428 | 429 | task, err := s.mockedRepository.RegisterError(ctx, testTask, errTask) 430 | assert.Empty(s.T(), task) 431 | assert.NotNil(s.T(), err) 432 | 433 | // registering error when update fails 434 | s.sqlMock.ExpectBegin() 435 | s.sqlMock.ExpectExec(updateMockRegexp).WillReturnError(errSQL) 436 | s.sqlMock.ExpectRollback() 437 | 438 | task, err = s.mockedRepository.RegisterError(ctx, testTask, errTask) 439 | assert.Empty(s.T(), task) 440 | assert.NotNil(s.T(), err) 441 | 442 | // registering error when update is successful but select fails 443 | s.sqlMock.ExpectBegin() 444 | s.sqlMock.ExpectExec(updateMockRegexp).WillReturnResult(sqlmock.NewResult(1, 1)) 445 | s.sqlMock.ExpectQuery(selectMockRegexp).WillReturnError(errSQL) 446 | s.sqlMock.ExpectRollback() 447 | 448 | task, err = s.mockedRepository.RegisterError(ctx, testTask, errTask) 449 | assert.Empty(s.T(), task) 450 | assert.NotNil(s.T(), err) 451 | 452 | // registering error when update is successful but select returns no rows 453 | s.sqlMock.ExpectBegin() 454 |
s.sqlMock.ExpectExec(updateMockRegexp).WillReturnResult(sqlmock.NewResult(1, 1)) 455 | s.sqlMock.ExpectQuery(selectMockRegexp).WillReturnRows(sqlmock.NewRows(taskColumns)) 456 | s.sqlMock.ExpectRollback() 457 | 458 | task, err = s.mockedRepository.RegisterError(ctx, testTask, errTask) 459 | assert.Empty(s.T(), task) 460 | assert.NotNil(s.T(), err) 461 | 462 | // registering error when commit fails 463 | s.sqlMock.ExpectBegin() 464 | s.sqlMock.ExpectExec(updateMockRegexp).WillReturnResult(sqlmock.NewResult(1, 1)) 465 | s.sqlMock.ExpectQuery(selectMockRegexp).WillReturnRows(sqlmock.NewRows(taskColumns).AddRow(taskValues...)) 466 | s.sqlMock.ExpectCommit().WillReturnError(errSQL) 467 | 468 | task, err = s.mockedRepository.RegisterError(ctx, testTask, errTask) 469 | assert.Empty(s.T(), task) 470 | assert.NotNil(s.T(), err) 471 | 472 | // registering error successful 473 | s.sqlMock.ExpectBegin() 474 | s.sqlMock.ExpectExec(updateMockRegexp).WillReturnResult(sqlmock.NewResult(1, 1)) 475 | s.sqlMock.ExpectQuery(selectMockRegexp).WillReturnRows(sqlmock.NewRows(taskColumns).AddRow(taskValues...)) 476 | s.sqlMock.ExpectCommit() 477 | 478 | task, err = s.mockedRepository.RegisterError(ctx, testTask, errTask) 479 | assert.NotEmpty(s.T(), task) 480 | assert.Nil(s.T(), err) 481 | } 482 | // TestRegisterFinish exercises RegisterFinish (with StatusSuccessful) through each failure point of its begin/update/select/commit sequence, an empty select result, and a successful run. 483 | func (s *MySQLTestSuite) TestRegisterFinish() { 484 | var ( 485 | updateMockRegexp = regexp.QuoteMeta(`UPDATE 486 | test_tasks 487 | SET 488 | status = ?, 489 | finished_at = ?
490 | WHERE 491 | id = ?;`) 492 | selectMockRegexp = regexp.QuoteMeta(`SELECT * 493 | FROM 494 | test_tasks 495 | WHERE 496 | id = ?;`) 497 | ) 498 | 499 | // beginning the transaction fails 500 | s.sqlMock.ExpectBegin().WillReturnError(errSQL) 501 | 502 | task, err := s.mockedRepository.RegisterFinish(ctx, testTask, tasq.StatusSuccessful) 503 | assert.Empty(s.T(), task) 504 | assert.NotNil(s.T(), err) 505 | 506 | // registering success when update fails 507 | s.sqlMock.ExpectBegin() 508 | s.sqlMock.ExpectExec(updateMockRegexp).WillReturnError(errSQL) 509 | s.sqlMock.ExpectRollback() 510 | 511 | task, err = s.mockedRepository.RegisterFinish(ctx, testTask, tasq.StatusSuccessful) 512 | assert.Empty(s.T(), task) 513 | assert.NotNil(s.T(), err) 514 | 515 | // registering success when update is successful but select fails 516 | s.sqlMock.ExpectBegin() 517 | s.sqlMock.ExpectExec(updateMockRegexp).WillReturnResult(sqlmock.NewResult(1, 1)) 518 | s.sqlMock.ExpectQuery(selectMockRegexp).WillReturnError(errSQL) 519 | s.sqlMock.ExpectRollback() 520 | 521 | task, err = s.mockedRepository.RegisterFinish(ctx, testTask, tasq.StatusSuccessful) 522 | assert.Empty(s.T(), task) 523 | assert.NotNil(s.T(), err) 524 | 525 | // registering success when update is successful but select returns no rows 526 | s.sqlMock.ExpectBegin() 527 | s.sqlMock.ExpectExec(updateMockRegexp).WillReturnResult(sqlmock.NewResult(1, 1)) 528 | s.sqlMock.ExpectQuery(selectMockRegexp).WillReturnRows(sqlmock.NewRows(taskColumns)) 529 | s.sqlMock.ExpectRollback() 530 | 531 | task, err = s.mockedRepository.RegisterFinish(ctx, testTask, tasq.StatusSuccessful) 532 | assert.Empty(s.T(), task) 533 | assert.NotNil(s.T(), err) 534 | 535 | // registering success when commit fails 536 | s.sqlMock.ExpectBegin() 537 | s.sqlMock.ExpectExec(updateMockRegexp).WillReturnResult(sqlmock.NewResult(1, 1)) 538 | s.sqlMock.ExpectQuery(selectMockRegexp).WillReturnRows(sqlmock.NewRows(taskColumns).AddRow(taskValues...)) 539 |
s.sqlMock.ExpectCommit().WillReturnError(errSQL) 540 | 541 | task, err = s.mockedRepository.RegisterFinish(ctx, testTask, tasq.StatusSuccessful) 542 | assert.Empty(s.T(), task) 543 | assert.NotNil(s.T(), err) 544 | 545 | // registering a successful finish 546 | s.sqlMock.ExpectBegin() 547 | s.sqlMock.ExpectExec(updateMockRegexp).WillReturnResult(sqlmock.NewResult(1, 1)) 548 | s.sqlMock.ExpectQuery(selectMockRegexp).WillReturnRows(sqlmock.NewRows(taskColumns).AddRow(taskValues...)) 549 | s.sqlMock.ExpectCommit() 550 | 551 | task, err = s.mockedRepository.RegisterFinish(ctx, testTask, tasq.StatusSuccessful) 552 | assert.NotEmpty(s.T(), task) 553 | assert.Nil(s.T(), err) 554 | } 555 | // TestSubmitTask exercises SubmitTask through each failure point of its begin/insert/select/commit sequence, an empty select result, and a successful submit. 556 | func (s *MySQLTestSuite) TestSubmitTask() { 557 | var ( 558 | insertMockRegexp = regexp.QuoteMeta(`INSERT INTO 559 | test_tasks 560 | (id, type, args, queue, priority, status, max_receives, created_at, visible_at) 561 | VALUES 562 | (?, ?, ?, ?, ?, ?, ?, ?, ?);`) 563 | selectMockRegexp = regexp.QuoteMeta(`SELECT * 564 | FROM 565 | test_tasks 566 | WHERE 567 | id = ?;`) 568 | ) 569 | 570 | // beginning the transaction fails 571 | s.sqlMock.ExpectBegin().WillReturnError(errSQL) 572 | 573 | task, err := s.mockedRepository.SubmitTask(ctx, testTask) 574 | assert.Empty(s.T(), task) 575 | assert.NotNil(s.T(), err) 576 | 577 | // submitting task when insert fails 578 | s.sqlMock.ExpectBegin() 579 | s.sqlMock.ExpectExec(insertMockRegexp).WillReturnError(errSQL) 580 | s.sqlMock.ExpectRollback() 581 | 582 | task, err = s.mockedRepository.SubmitTask(ctx, testTask) 583 | assert.Empty(s.T(), task) 584 | assert.NotNil(s.T(), err) 585 | 586 | // submitting task when insert is successful but select fails 587 | s.sqlMock.ExpectBegin() 588 | s.sqlMock.ExpectExec(insertMockRegexp).WillReturnResult(sqlmock.NewResult(1, 1)) 589 | s.sqlMock.ExpectQuery(selectMockRegexp).WillReturnError(errSQL) 590 | s.sqlMock.ExpectRollback() 591 | 592 | task, err = s.mockedRepository.SubmitTask(ctx, testTask) 593
| assert.Empty(s.T(), task) 594 | assert.NotNil(s.T(), err) 595 | 596 | // submitting task when insert is successful but select returns no rows 597 | s.sqlMock.ExpectBegin() 598 | s.sqlMock.ExpectExec(insertMockRegexp).WillReturnResult(sqlmock.NewResult(1, 1)) 599 | s.sqlMock.ExpectQuery(selectMockRegexp).WillReturnRows(sqlmock.NewRows(taskColumns)) 600 | s.sqlMock.ExpectRollback() 601 | 602 | task, err = s.mockedRepository.SubmitTask(ctx, testTask) 603 | assert.Empty(s.T(), task) 604 | assert.NotNil(s.T(), err) 605 | 606 | // submitting task when commit fails 607 | s.sqlMock.ExpectBegin() 608 | s.sqlMock.ExpectExec(insertMockRegexp).WillReturnResult(sqlmock.NewResult(1, 1)) 609 | s.sqlMock.ExpectQuery(selectMockRegexp).WillReturnRows(sqlmock.NewRows(taskColumns).AddRow(taskValues...)) 610 | s.sqlMock.ExpectCommit().WillReturnError(errSQL) 611 | 612 | task, err = s.mockedRepository.SubmitTask(ctx, testTask) 613 | assert.Empty(s.T(), task) 614 | assert.NotNil(s.T(), err) 615 | 616 | // submitting task successful 617 | s.sqlMock.ExpectBegin() 618 | s.sqlMock.ExpectExec(insertMockRegexp).WillReturnResult(sqlmock.NewResult(1, 1)) 619 | s.sqlMock.ExpectQuery(selectMockRegexp).WillReturnRows(sqlmock.NewRows(taskColumns).AddRow(taskValues...)) 620 | s.sqlMock.ExpectCommit() 621 | 622 | task, err = s.mockedRepository.SubmitTask(ctx, testTask) 623 | assert.NotEmpty(s.T(), task) 624 | assert.Nil(s.T(), err) 625 | } 626 | // TestDeleteTask checks DeleteTask for an exec error, a plain delete, and a delete with the safe-delete flag set, which issues the visibility-guarded DELETE. 627 | func (s *MySQLTestSuite) TestDeleteTask() { 628 | var ( 629 | deleteMockRegexp = regexp.QuoteMeta(`DELETE 630 | FROM 631 | test_tasks 632 | WHERE 633 | id = ?;`) 634 | deleteSafeDeleteMockRegexp = regexp.QuoteMeta(`DELETE 635 | FROM 636 | test_tasks 637 | WHERE 638 | id = ? AND 639 | ( 640 | ( 641 | visible_at <= ? 642 | ) OR 643 | ( 644 | status IN (?) AND 645 | visible_at > ?
646 | ) 647 | );`) 648 | ) 649 | 650 | // deleting task when DB returns error 651 | s.sqlMock.ExpectExec(deleteMockRegexp).WillReturnError(errSQL) 652 | 653 | err := s.mockedRepository.DeleteTask(ctx, testTask, false) 654 | assert.NotNil(s.T(), err) 655 | 656 | // deleting task successful 657 | s.sqlMock.ExpectExec(deleteMockRegexp).WillReturnResult(sqlmock.NewResult(1, 1)) 658 | 659 | err = s.mockedRepository.DeleteTask(ctx, testTask, false) 660 | assert.Nil(s.T(), err) 661 | 662 | // deleting task with the safe-delete flag set successful 663 | s.sqlMock.ExpectExec(deleteSafeDeleteMockRegexp).WillReturnResult(sqlmock.NewResult(1, 1)) 664 | 665 | err = s.mockedRepository.DeleteTask(ctx, testTask, true) 666 | assert.Nil(s.T(), err) 667 | } 668 | // TestCountTasks checks CountTasks for a query error and a successful count. 669 | func (s *MySQLTestSuite) TestCountTasks() { 670 | selectMockRegexp := regexp.QuoteMeta(`SELECT 671 | COUNT(*) 672 | FROM 673 | test_tasks 674 | WHERE 675 | status IN (?) AND 676 | type IN (?) AND 677 | queue IN (?);`) 678 | 679 | // counting when DB returns error 680 | s.sqlMock.ExpectQuery(selectMockRegexp).WillReturnError(errSQL) 681 | 682 | count, err := s.mockedRepository.CountTasks(ctx, []tasq.TaskStatus{testTask.Status}, []string{testTask.Type}, []string{testTask.Queue}) 683 | assert.Equal(s.T(), int64(0), count) 684 | assert.NotNil(s.T(), err) 685 | 686 | // counting successful 687 | s.sqlMock.ExpectQuery(selectMockRegexp).WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(10)) 688 | 689 | count, err = s.mockedRepository.CountTasks(ctx, []tasq.TaskStatus{testTask.Status}, []string{testTask.Type}, []string{testTask.Queue}) 690 | assert.Equal(s.T(), int64(10), count) 691 | assert.Nil(s.T(), err) 692 | } 693 | // TestScanTasks checks ScanTasks for a query error and a successful scan returning one task. 694 | func (s *MySQLTestSuite) TestScanTasks() { 695 | selectMockRegexp := regexp.QuoteMeta(`SELECT 696 | * 697 | FROM 698 | test_tasks 699 | WHERE 700 | status IN (?) AND 701 | type IN (?) AND 702 | queue IN (?) 703 | ORDER BY ?
704 | LIMIT ?;`) 705 | 706 | // scanning when DB returns error 707 | s.sqlMock.ExpectQuery(selectMockRegexp).WillReturnError(errSQL) 708 | 709 | tasks, err := s.mockedRepository.ScanTasks(ctx, []tasq.TaskStatus{testTask.Status}, []string{testTask.Type}, []string{testTask.Queue}, tasq.OrderingCreatedAtFirst, 10) 710 | assert.Empty(s.T(), tasks) 711 | assert.NotNil(s.T(), err) 712 | 713 | // scanning successful 714 | s.sqlMock.ExpectQuery(selectMockRegexp).WillReturnRows(sqlmock.NewRows(taskColumns).AddRow(taskValues...)) 715 | 716 | tasks, err = s.mockedRepository.ScanTasks(ctx, []tasq.TaskStatus{testTask.Status}, []string{testTask.Type}, []string{testTask.Queue}, tasq.OrderingCreatedAtFirst, 10) 717 | assert.NotEmpty(s.T(), tasks) 718 | assert.Nil(s.T(), err) 719 | } 720 | // TestPurgeTasks checks PurgeTasks for an exec error, a result that cannot report affected rows, a plain purge and a safe-delete purge; the empty type list is expected to leave the type filter out of the query. 721 | func (s *MySQLTestSuite) TestPurgeTasks() { 722 | var ( 723 | purgeMockRegexp = regexp.QuoteMeta(`DELETE 724 | FROM 725 | test_tasks 726 | WHERE 727 | status IN (?) AND 728 | queue IN (?);`) 729 | purgeSafeDeleteMockRegexp = regexp.QuoteMeta(`DELETE 730 | FROM 731 | test_tasks 732 | WHERE 733 | status IN (?) AND 734 | queue IN (?) AND 735 | ( 736 | ( visible_at <= ? ) OR 737 | ( 738 | status IN (?) AND 739 | visible_at > ?
740 | ) 741 | );`) 742 | ) 743 | 744 | // purging tasks when DB returns error 745 | s.sqlMock.ExpectExec(purgeMockRegexp).WillReturnError(errSQL) 746 | 747 | count, err := s.mockedRepository.PurgeTasks(ctx, []tasq.TaskStatus{tasq.StatusFailed}, []string{}, []string{testTaskQueue}, false) 748 | assert.Equal(s.T(), int64(0), count) 749 | assert.NotNil(s.T(), err) 750 | 751 | // purging when the result cannot report affected rows (driver.ResultNoRows returns an error from RowsAffected) 752 | s.sqlMock.ExpectExec(purgeMockRegexp).WillReturnResult(driver.ResultNoRows) 753 | 754 | count, err = s.mockedRepository.PurgeTasks(ctx, []tasq.TaskStatus{tasq.StatusFailed}, []string{}, []string{testTaskQueue}, false) 755 | assert.Equal(s.T(), int64(0), count) 756 | assert.NotNil(s.T(), err) 757 | 758 | // purging tasks successful 759 | s.sqlMock.ExpectExec(purgeMockRegexp).WillReturnResult(sqlmock.NewResult(1, 1)) 760 | 761 | count, err = s.mockedRepository.PurgeTasks(ctx, []tasq.TaskStatus{tasq.StatusFailed}, []string{}, []string{testTaskQueue}, false) 762 | assert.Equal(s.T(), int64(1), count) 763 | assert.Nil(s.T(), err) 764 | 765 | // purging tasks with safeDelete successful 766 | s.sqlMock.ExpectExec(purgeSafeDeleteMockRegexp).WillReturnResult(sqlmock.NewResult(1, 1)) 767 | 768 | count, err = s.mockedRepository.PurgeTasks(ctx, []tasq.TaskStatus{tasq.StatusFailed}, []string{}, []string{testTaskQueue}, true) 769 | assert.Equal(s.T(), int64(1), count) 770 | assert.Nil(s.T(), err) 771 | } 772 | // TestRequeueTask exercises RequeueTask through each failure point of its begin/update/select/commit sequence, an empty select result, and a successful requeue. 773 | func (s *MySQLTestSuite) TestRequeueTask() { 774 | var ( 775 | updateMockRegexp = regexp.QuoteMeta(`UPDATE 776 | test_tasks 777 | SET 778 | status = ?
779 | WHERE 780 | id = ?;`) 781 | selectMockRegexp = regexp.QuoteMeta(`SELECT * 782 | FROM 783 | test_tasks 784 | WHERE 785 | id = ?;`) 786 | ) 787 | 788 | // beginning the transaction fails 789 | s.sqlMock.ExpectBegin().WillReturnError(errSQL) 790 | 791 | task, err := s.mockedRepository.RequeueTask(ctx, testTask) 792 | assert.Empty(s.T(), task) 793 | assert.NotNil(s.T(), err) 794 | 795 | // requeuing when update fails 796 | s.sqlMock.ExpectBegin() 797 | s.sqlMock.ExpectExec(updateMockRegexp).WillReturnError(errSQL) 798 | s.sqlMock.ExpectRollback() 799 | 800 | task, err = s.mockedRepository.RequeueTask(ctx, testTask) 801 | assert.Empty(s.T(), task) 802 | assert.NotNil(s.T(), err) 803 | 804 | // requeuing when update is successful but select fails 805 | s.sqlMock.ExpectBegin() 806 | s.sqlMock.ExpectExec(updateMockRegexp).WillReturnResult(sqlmock.NewResult(1, 1)) 807 | s.sqlMock.ExpectQuery(selectMockRegexp).WillReturnError(errSQL) 808 | s.sqlMock.ExpectRollback() 809 | 810 | task, err = s.mockedRepository.RequeueTask(ctx, testTask) 811 | assert.Empty(s.T(), task) 812 | assert.NotNil(s.T(), err) 813 | 814 | // requeuing when update is successful but select returns no rows 815 | s.sqlMock.ExpectBegin() 816 | s.sqlMock.ExpectExec(updateMockRegexp).WillReturnResult(sqlmock.NewResult(1, 1)) 817 | s.sqlMock.ExpectQuery(selectMockRegexp).WillReturnRows(sqlmock.NewRows(taskColumns)) 818 | s.sqlMock.ExpectRollback() 819 | 820 | task, err = s.mockedRepository.RequeueTask(ctx, testTask) 821 | assert.Empty(s.T(), task) 822 | assert.NotNil(s.T(), err) 823 | 824 | // requeuing when commit fails 825 | s.sqlMock.ExpectBegin() 826 | s.sqlMock.ExpectExec(updateMockRegexp).WillReturnResult(sqlmock.NewResult(1, 1)) 827 | s.sqlMock.ExpectQuery(selectMockRegexp).WillReturnRows(sqlmock.NewRows(taskColumns).AddRow(taskValues...)) 828 | s.sqlMock.ExpectCommit().WillReturnError(errSQL) 829 | 830 | task, err = s.mockedRepository.RequeueTask(ctx, testTask) 831 | assert.Empty(s.T(), task) 832 |
assert.NotNil(s.T(), err) 833 | 834 | // requeuing successful 835 | s.sqlMock.ExpectBegin() 836 | s.sqlMock.ExpectExec(updateMockRegexp).WillReturnResult(sqlmock.NewResult(1, 1)) 837 | s.sqlMock.ExpectQuery(selectMockRegexp).WillReturnRows(sqlmock.NewRows(taskColumns).AddRow(taskValues...)) 838 | s.sqlMock.ExpectCommit() 839 | 840 | task, err = s.mockedRepository.RequeueTask(ctx, testTask) 841 | assert.NotEmpty(s.T(), task) 842 | assert.Nil(s.T(), err) 843 | } 844 | // TestGetQueryWithTableName checks that GetQueryWithTableName panics for the ":taskID:" query and for an empty slice bound to an IN list, and otherwise renders test_tasks and rewrites :taskID to a "?" placeholder. 845 | func (s *MySQLTestSuite) TestGetQueryWithTableName() { 846 | var ( 847 | taskUUID = uuid.New() 848 | taskUUIDBytes, _ = taskUUID.MarshalBinary() 849 | ) 850 | 851 | mysqlRepository, ok := s.mockedRepository.(*mysql.Repository) 852 | require.True(s.T(), ok) 853 | 854 | assert.Panics(s.T(), func() { 855 | _, _ = mysqlRepository.GetQueryWithTableName("SELECT * FROM {{.tableName}} WHERE id = :taskID:;", map[string]any{ 856 | "taskID": taskUUIDBytes, 857 | }) 858 | }) 859 | 860 | assert.Panics(s.T(), func() { 861 | _, _ = mysqlRepository.GetQueryWithTableName("SELECT * FROM {{.tableName}} WHERE id IN (:taskIDs);", map[string]any{ 862 | "taskIDs": [][]byte{}, 863 | }) 864 | }) 865 | 866 | assert.NotPanics(s.T(), func() { 867 | query, args := mysqlRepository.GetQueryWithTableName("SELECT * FROM {{.tableName}} WHERE id = :taskID", map[string]any{ 868 | "taskID": taskUUIDBytes, 869 | }) 870 | assert.Equal(s.T(), "SELECT * FROM test_tasks WHERE id = ?", query) 871 | assert.Contains(s.T(), args, taskUUIDBytes) 872 | }) 873 | } 874 | // TestInterpolateSQL checks that InterpolateSQL renders a template and panics on templates that fail to parse or to execute. 875 | func (s *MySQLTestSuite) TestInterpolateSQL() { 876 | params := map[string]any{"tableName": "test_table"} 877 | 878 | // Interpolate SQL successfully 879 | interpolatedSQL := mysql.InterpolateSQL("SELECT * FROM {{.tableName}}", params) 880 | assert.Equal(s.T(), "SELECT * FROM test_table", interpolatedSQL) 881 | 882 | // Fail interpolating unparseable SQL template 883 | assert.Panics(s.T(), func() { 884 | unparseableTemplateSQL := mysql.InterpolateSQL("SELECT * FROM
{{.tableName", params) 885 | assert.Empty(s.T(), unparseableTemplateSQL) 886 | }) 887 | 888 | // Fail interpolating unexecutable SQL template 889 | assert.Panics(s.T(), func() { 890 | unexecutableTemplateSQL := mysql.InterpolateSQL(`SELECT * FROM {{if .tableName eq 1}} {{end}} {{.tableName}}`, params) 891 | assert.Empty(s.T(), unexecutableTemplateSQL) 892 | }) 893 | } 894 | -------------------------------------------------------------------------------- /repository/postgres/export_test.go: -------------------------------------------------------------------------------- 1 | package postgres 2 | 3 | import ( 4 | "database/sql" 5 | "database/sql/driver" 6 | 7 | "github.com/greencoda/tasq" 8 | "github.com/jmoiron/sqlx" 9 | ) 10 | // GetTestTaskValues returns the fields of newFromTask(task) as driver values in the order listed below; presumably used to build sqlmock rows, as the MySQL tests do. 11 | func GetTestTaskValues(task *tasq.Task) []driver.Value { 12 | testMySQLTask := newFromTask(task) // NOTE(review): the name looks carried over from the MySQL package; testPostgresTask would read more accurately. 13 | 14 | return []driver.Value{ 15 | testMySQLTask.ID, 16 | testMySQLTask.Type, 17 | testMySQLTask.Args, 18 | testMySQLTask.Queue, 19 | testMySQLTask.Priority, 20 | testMySQLTask.Status, 21 | testMySQLTask.ReceiveCount, 22 | testMySQLTask.MaxReceives, 23 | testMySQLTask.LastError, 24 | testMySQLTask.CreatedAt, 25 | testMySQLTask.StartedAt, 26 | testMySQLTask.FinishedAt, 27 | testMySQLTask.VisibleAt, 28 | } 29 | } 30 | // PrepareWithTableName exposes the unexported prepareWithTableName for tests. 31 | func (d *Repository) PrepareWithTableName(sqlTemplate string) *sqlx.NamedStmt { 32 | return d.prepareWithTableName(sqlTemplate) 33 | } 34 | // CloseNamedStmt exposes the unexported closeStmt for tests. 35 | func (d *Repository) CloseNamedStmt(stmt closeableStmt) { 36 | d.closeStmt(stmt) 37 | } 38 | // InterpolateSQL exposes the unexported interpolateSQL for tests. 39 | func InterpolateSQL(sql string, params map[string]any) string { 40 | return interpolateSQL(sql, params) 41 | } 42 | // StringToSQLNullString exposes the unexported stringToSQLNullString for tests. 43 | func StringToSQLNullString(input *string) sql.NullString { 44 | return stringToSQLNullString(input) 45 | } 46 | // ParseNullableString exposes the unexported parseNullableString for tests. 47 | func ParseNullableString(input sql.NullString) *string { 48 | return parseNullableString(input) 49 | } 50 | -------------------------------------------------------------------------------- /repository/postgres/postgres.go:
-------------------------------------------------------------------------------- 1 | // Package postgres provides the implementation of a tasq repository in PostgreSQL. 2 | package postgres 3 | 4 | import ( 5 | "bytes" 6 | "context" 7 | "database/sql" 8 | "errors" 9 | "fmt" 10 | "strings" 11 | "text/template" 12 | "time" 13 | 14 | "github.com/google/uuid" 15 | "github.com/greencoda/tasq" 16 | "github.com/jmoiron/sqlx" 17 | "github.com/lib/pq" 18 | ) 19 | 20 | const driverName = "postgres" 21 | 22 | var ( 23 | errUnexpectedDataSourceType = errors.New("unexpected dataSource type") 24 | errFailedToExecuteUpdate = errors.New("failed to execute update query") 25 | errFailedToExecuteDelete = errors.New("failed to execute delete query") 26 | errFailedToExecuteInsert = errors.New("failed to execute insert query") 27 | errFailedToExecuteCreateTable = errors.New("failed to execute create table query") 28 | errFailedToExecuteCreateType = errors.New("failed to execute create type query") 29 | ) 30 | 31 | // Repository implements the methods necessary for tasq to work in PostgreSQL. 32 | type Repository struct { 33 | db *sqlx.DB 34 | statusTypeName string 35 | tableName string 36 | } 37 | 38 | // NewRepository creates a new PostgreSQL Repository instance.
39 | func NewRepository(dataSource any, prefix string) (*Repository, error) { 40 | switch d := dataSource.(type) { 41 | case string: 42 | return newRepositoryFromDSN(d, prefix) 43 | case *sql.DB: 44 | return newRepositoryFromDB(d, prefix) 45 | } 46 | 47 | return nil, fmt.Errorf("%w: %T", errUnexpectedDataSourceType, dataSource) 48 | } 49 | 50 | func newRepositoryFromDSN(dsn string, prefix string) (*Repository, error) { 51 | dbx, _ := sqlx.Open(driverName, dsn) 52 | 53 | return &Repository{ 54 | db: dbx, 55 | statusTypeName: statusTypeName(prefix), 56 | tableName: tableName(prefix), 57 | }, nil 58 | } 59 | 60 | func newRepositoryFromDB(db *sql.DB, prefix string) (*Repository, error) { 61 | dbx := sqlx.NewDb(db, driverName) 62 | 63 | return &Repository{ 64 | db: dbx, 65 | statusTypeName: statusTypeName(prefix), 66 | tableName: tableName(prefix), 67 | }, nil 68 | } 69 | 70 | // Migrate prepares the database with the task status type 71 | // and by adding the tasks table. 72 | func (d *Repository) Migrate(ctx context.Context) error { 73 | err := d.migrateStatus(ctx) 74 | if err != nil { 75 | return err 76 | } 77 | 78 | err = d.migrateTable(ctx) 79 | if err != nil { 80 | return err 81 | } 82 | 83 | return nil 84 | } 85 | 86 | // PingTasks pings a list of tasks by their ID 87 | // and extends their invisibility timestamp with the supplied timeout parameter. 
88 | func (d *Repository) PingTasks(ctx context.Context, taskIDs []uuid.UUID, visibilityTimeout time.Duration) ([]*tasq.Task, error) { 89 | if len(taskIDs) == 0 { 90 | return []*tasq.Task{}, nil 91 | } 92 | 93 | var ( 94 | pingedTasks []*postgresTask 95 | pingTime = time.Now() 96 | sqlTemplate = `UPDATE 97 | {{.tableName}} 98 | SET 99 | "visible_at" = :visibleAt 100 | WHERE 101 | "id" = ANY(:pingedTaskIDs) 102 | RETURNING id;` 103 | stmt = d.prepareWithTableName(sqlTemplate) 104 | ) 105 | 106 | defer d.closeStmt(stmt) 107 | 108 | err := stmt.SelectContext(ctx, &pingedTasks, map[string]any{ 109 | "visibleAt": pingTime.Add(visibilityTimeout), 110 | "pingedTaskIDs": pq.Array(taskIDs), 111 | }) 112 | if err != nil && !errors.Is(err, sql.ErrNoRows) { 113 | return []*tasq.Task{}, fmt.Errorf("failed to update tasks: %w", err) 114 | } 115 | 116 | return postgresTasksToTasks(pingedTasks), nil 117 | } 118 | 119 | // PollTasks polls for available tasks matching supplied the parameters 120 | // and sets their invisibility the supplied timeout parameter to the future. 
121 | func (d *Repository) PollTasks(ctx context.Context, types, queues []string, visibilityTimeout time.Duration, ordering tasq.Ordering, pollLimit int) ([]*tasq.Task, error) { 122 | if pollLimit == 0 { 123 | return []*tasq.Task{}, nil 124 | } 125 | 126 | var ( 127 | polledTasks []*postgresTask 128 | pollTime = time.Now() 129 | sqlTemplate = `UPDATE {{.tableName}} SET 130 | "status" = :status, 131 | "receive_count" = "receive_count" + 1, 132 | "visible_at" = :visibleAt 133 | WHERE 134 | "id" IN ( 135 | SELECT 136 | "id" FROM {{.tableName}} 137 | WHERE 138 | "type" = ANY(:pollTypes) AND 139 | "queue" = ANY(:pollQueues) AND 140 | "status" = ANY(:pollStatuses) AND 141 | "visible_at" <= :pollTime 142 | ORDER BY 143 | :pollOrdering 144 | LIMIT :pollLimit 145 | FOR UPDATE ) 146 | RETURNING *;` 147 | stmt = d.prepareWithTableName(sqlTemplate) 148 | ) 149 | 150 | defer d.closeStmt(stmt) 151 | 152 | err := stmt.SelectContext(ctx, &polledTasks, map[string]any{ 153 | "status": tasq.StatusEnqueued, 154 | "visibleAt": pollTime.Add(visibilityTimeout), 155 | "pollTypes": pq.Array(types), 156 | "pollQueues": pq.Array(queues), 157 | "pollStatuses": pq.Array(tasq.GetTaskStatuses(tasq.OpenTasks)), 158 | "pollTime": pollTime, 159 | "pollOrdering": pq.Array(getOrderingDirectives(ordering)), 160 | "pollLimit": pollLimit, 161 | }) 162 | if err != nil && !errors.Is(err, sql.ErrNoRows) { 163 | return []*tasq.Task{}, fmt.Errorf("failed to update tasks: %w", err) 164 | } 165 | 166 | return postgresTasksToTasks(polledTasks), nil 167 | } 168 | 169 | // CleanTasks removes finished tasks from the queue 170 | // if their creation date is past the supplied duration. 
171 | func (d *Repository) CleanTasks(ctx context.Context, cleanAge time.Duration) (int64, error) { 172 | var ( 173 | cleanTime = time.Now() 174 | sqlTemplate = `DELETE FROM {{.tableName}} 175 | WHERE 176 | "status" = ANY(:statuses) AND 177 | "created_at" <= :cleanAt;` 178 | stmt = d.prepareWithTableName(sqlTemplate) 179 | ) 180 | 181 | defer d.closeStmt(stmt) 182 | 183 | result, err := stmt.ExecContext(ctx, map[string]any{ 184 | "statuses": pq.Array(tasq.GetTaskStatuses(tasq.FinishedTasks)), 185 | "cleanAt": cleanTime.Add(-cleanAge), 186 | }) 187 | if err != nil { 188 | return 0, fmt.Errorf("failed to delete tasks: %w", err) 189 | } 190 | 191 | rowsAffected, err := result.RowsAffected() 192 | if err != nil { 193 | return 0, fmt.Errorf("failed to get number of affected rows: %w", err) 194 | } 195 | 196 | return rowsAffected, nil 197 | } 198 | 199 | // RegisterStart marks a task as started with the 'in progress' status 200 | // and records the time of start. 201 | func (d *Repository) RegisterStart(ctx context.Context, task *tasq.Task) (*tasq.Task, error) { 202 | var ( 203 | updatedTask = new(postgresTask) 204 | startTime = time.Now() 205 | sqlTemplate = `UPDATE {{.tableName}} SET 206 | "status" = :status, 207 | "started_at" = :startTime 208 | WHERE 209 | "id" = :taskID 210 | RETURNING *;` 211 | stmt = d.prepareWithTableName(sqlTemplate) 212 | ) 213 | 214 | defer d.closeStmt(stmt) 215 | 216 | err := stmt. 217 | QueryRowContext(ctx, map[string]any{ 218 | "status": tasq.StatusInProgress, 219 | "startTime": startTime, 220 | "taskID": task.ID, 221 | }). 222 | StructScan(updatedTask) 223 | if err != nil { 224 | return nil, fmt.Errorf("%s: %w", errFailedToExecuteUpdate, err) 225 | } 226 | 227 | return updatedTask.toTask(), nil 228 | } 229 | 230 | // RegisterError records an error message on the task as last error. 
231 | func (d *Repository) RegisterError(ctx context.Context, task *tasq.Task, errTask error) (*tasq.Task, error) { 232 | var ( 233 | updatedTask = new(postgresTask) 234 | sqlTemplate = `UPDATE {{.tableName}} SET 235 | "last_error" = :errorMessage 236 | WHERE 237 | "id" = :taskID 238 | RETURNING *;` 239 | stmt = d.prepareWithTableName(sqlTemplate) 240 | ) 241 | 242 | defer d.closeStmt(stmt) 243 | 244 | err := stmt. 245 | QueryRowContext(ctx, map[string]any{ 246 | "errorMessage": errTask.Error(), 247 | "taskID": task.ID, 248 | }). 249 | StructScan(updatedTask) 250 | if err != nil { 251 | return nil, fmt.Errorf("%s: %w", errFailedToExecuteUpdate, err) 252 | } 253 | 254 | return updatedTask.toTask(), nil 255 | } 256 | 257 | // RegisterFinish marks a task as finished with the supplied status 258 | // and records the time of finish. 259 | func (d *Repository) RegisterFinish(ctx context.Context, task *tasq.Task, finishStatus tasq.TaskStatus) (*tasq.Task, error) { 260 | var ( 261 | updatedTask = new(postgresTask) 262 | finishTime = time.Now() 263 | sqlTemplate = `UPDATE {{.tableName}} SET 264 | "status" = :status, 265 | "finished_at" = :finishTime 266 | WHERE 267 | "id" = :taskID 268 | RETURNING *;` 269 | stmt = d.prepareWithTableName(sqlTemplate) 270 | ) 271 | 272 | defer d.closeStmt(stmt) 273 | 274 | err := stmt. 275 | QueryRowContext(ctx, map[string]any{ 276 | "status": finishStatus, 277 | "finishTime": finishTime, 278 | "taskID": task.ID, 279 | }). 280 | StructScan(updatedTask) 281 | if err != nil { 282 | return nil, fmt.Errorf("%s: %w", errFailedToExecuteUpdate, err) 283 | } 284 | 285 | return updatedTask.toTask(), nil 286 | } 287 | 288 | // SubmitTask adds the supplied task to the queue. 
289 | func (d *Repository) SubmitTask(ctx context.Context, task *tasq.Task) (*tasq.Task, error) { 290 | var ( 291 | postgresTask = newFromTask(task) 292 | sqlTemplate = `INSERT INTO {{.tableName}} 293 | (id, type, args, queue, priority, status, max_receives, created_at, visible_at) 294 | VALUES 295 | (:id, :type, :args, :queue, :priority, :status, :maxReceives, :createdAt, :visibleAt) 296 | RETURNING *;` 297 | stmt = d.prepareWithTableName(sqlTemplate) 298 | ) 299 | 300 | defer d.closeStmt(stmt) 301 | 302 | err := stmt. 303 | QueryRowContext(ctx, map[string]any{ 304 | "id": postgresTask.ID, 305 | "type": postgresTask.Type, 306 | "args": postgresTask.Args, 307 | "queue": postgresTask.Queue, 308 | "priority": postgresTask.Priority, 309 | "status": postgresTask.Status, 310 | "maxReceives": postgresTask.MaxReceives, 311 | "createdAt": postgresTask.CreatedAt, 312 | "visibleAt": postgresTask.VisibleAt, 313 | }). 314 | StructScan(postgresTask) 315 | if err != nil { 316 | return nil, fmt.Errorf("%s: %w", errFailedToExecuteInsert, err) 317 | } 318 | 319 | return postgresTask.toTask(), nil 320 | } 321 | 322 | // DeleteTask removes the supplied task from the queue. 323 | func (d *Repository) DeleteTask(ctx context.Context, task *tasq.Task, safeDelete bool) error { 324 | var ( 325 | conditions = []string{ 326 | `"id" = :taskID`, 327 | } 328 | parameters = map[string]any{ 329 | "taskID": task.ID, 330 | } 331 | ) 332 | 333 | if safeDelete { 334 | d.applySafeDeleteConditions(&conditions, ¶meters) 335 | } 336 | 337 | sqlTemplate := `DELETE FROM {{.tableName}} WHERE ` + strings.Join(conditions, ` AND `) + `;` 338 | 339 | _, err := d.prepareWithTableName(sqlTemplate).ExecContext(ctx, parameters) 340 | if err != nil { 341 | return fmt.Errorf("%s: %w", errFailedToExecuteDelete, err) 342 | } 343 | 344 | return nil 345 | } 346 | 347 | // RequeueTask marks a task as new, so it can be picked up again. 
348 | func (d *Repository) RequeueTask(ctx context.Context, task *tasq.Task) (*tasq.Task, error) { 349 | var ( 350 | updatedTask = new(postgresTask) 351 | sqlTemplate = `UPDATE {{.tableName}} SET 352 | "status" = :status 353 | WHERE 354 | "id" = :taskID 355 | RETURNING *;` 356 | stmt = d.prepareWithTableName(sqlTemplate) 357 | ) 358 | 359 | err := stmt. 360 | QueryRowContext(ctx, map[string]any{ 361 | "status": tasq.StatusNew, 362 | "taskID": task.ID, 363 | }). 364 | StructScan(updatedTask) 365 | if err != nil { 366 | return nil, fmt.Errorf("%s: %w", errFailedToExecuteUpdate, err) 367 | } 368 | 369 | return updatedTask.toTask(), err 370 | } 371 | 372 | // CountTasks returns the number of tasks in the queue based on the supplied filters. 373 | func (d *Repository) CountTasks(ctx context.Context, taskStatuses []tasq.TaskStatus, taskTypes, queues []string) (int64, error) { 374 | var ( 375 | count int64 376 | sqlTemplate, parameters = d.buildCountSQLTemplate(taskStatuses, taskTypes, queues) 377 | stmt = d.prepareWithTableName(sqlTemplate) 378 | ) 379 | 380 | err := stmt.GetContext(ctx, &count, parameters) 381 | if err != nil && !errors.Is(err, sql.ErrNoRows) { 382 | return 0, fmt.Errorf("failed to count tasks: %w", err) 383 | } 384 | 385 | return count, nil 386 | } 387 | 388 | // ScanTasks returns a list of tasks in the queue based on the supplied filters. 
389 | func (d *Repository) ScanTasks(ctx context.Context, taskStatuses []tasq.TaskStatus, taskTypes, queues []string, ordering tasq.Ordering, scanLimit int) ([]*tasq.Task, error) { 390 | var ( 391 | scannedTasks []*postgresTask 392 | sqlTemplate, parameters = d.buildScanSQLTemplate(taskStatuses, taskTypes, queues, ordering, scanLimit) 393 | stmt = d.prepareWithTableName(sqlTemplate) 394 | ) 395 | 396 | err := stmt.SelectContext(ctx, &scannedTasks, parameters) 397 | if err != nil && !errors.Is(err, sql.ErrNoRows) { 398 | return []*tasq.Task{}, fmt.Errorf("failed to scan tasks: %w", err) 399 | } 400 | 401 | return postgresTasksToTasks(scannedTasks), nil 402 | } 403 | 404 | // PurgeTasks removes all tasks from the queue based on the supplied filters. 405 | func (d *Repository) PurgeTasks(ctx context.Context, taskStatuses []tasq.TaskStatus, taskTypes, queues []string, safeDelete bool) (int64, error) { 406 | var ( 407 | sqlTemplate, parameters = d.buildPurgeSQLTemplate(taskStatuses, taskTypes, queues, safeDelete) 408 | stmt = d.prepareWithTableName(sqlTemplate) 409 | ) 410 | 411 | result, err := stmt.ExecContext(ctx, parameters) 412 | if err != nil && !errors.Is(err, sql.ErrNoRows) { 413 | return 0, fmt.Errorf("failed to purge tasks: %w", err) 414 | } 415 | 416 | rowsAffected, err := result.RowsAffected() 417 | if err != nil { 418 | return 0, fmt.Errorf("failed to get number of affected rows: %w", err) 419 | } 420 | 421 | return rowsAffected, nil 422 | } 423 | 424 | func (d *Repository) buildCountSQLTemplate(taskStatuses []tasq.TaskStatus, taskTypes, queues []string) (string, map[string]any) { 425 | var ( 426 | conditions, parameters = d.buildFilterConditions(taskStatuses, taskTypes, queues) 427 | sqlTemplate = `SELECT COUNT(*) FROM {{.tableName}}` 428 | ) 429 | 430 | if len(conditions) > 0 { 431 | sqlTemplate += ` WHERE ` + strings.Join(conditions, " AND ") 432 | } 433 | 434 | return sqlTemplate, parameters 435 | } 436 | 437 | func (d *Repository) 
buildScanSQLTemplate(taskStatuses []tasq.TaskStatus, taskTypes, queues []string, ordering tasq.Ordering, scanLimit int) (string, map[string]any) { 438 | var ( 439 | conditions, parameters = d.buildFilterConditions(taskStatuses, taskTypes, queues) 440 | sqlTemplate = `SELECT * FROM {{.tableName}}` 441 | ) 442 | 443 | if len(conditions) > 0 { 444 | sqlTemplate += ` WHERE ` + strings.Join(conditions, " AND ") 445 | } 446 | 447 | sqlTemplate += ` ORDER BY :scanOrdering LIMIT :limit;` 448 | 449 | parameters["scanOrdering"] = pq.Array(getOrderingDirectives(ordering)) 450 | parameters["limit"] = scanLimit 451 | 452 | return sqlTemplate, parameters 453 | } 454 | 455 | func (d *Repository) buildPurgeSQLTemplate(taskStatuses []tasq.TaskStatus, taskTypes, queues []string, safeDelete bool) (string, map[string]any) { 456 | var ( 457 | conditions, parameters = d.buildFilterConditions(taskStatuses, taskTypes, queues) 458 | sqlTemplate = `DELETE FROM {{.tableName}}` 459 | ) 460 | 461 | if safeDelete { 462 | d.applySafeDeleteConditions(&conditions, ¶meters) 463 | } 464 | 465 | if len(conditions) > 0 { 466 | sqlTemplate += ` WHERE ` + strings.Join(conditions, " AND ") 467 | } 468 | 469 | return sqlTemplate + `;`, parameters 470 | } 471 | 472 | func (d *Repository) applySafeDeleteConditions(conditions *[]string, parameters *map[string]any) { 473 | *conditions = append(*conditions, `( 474 | ( 475 | "visible_at" <= :visibleAt 476 | ) OR ( 477 | "status" = ANY(:statuses) AND 478 | "visible_at" > :visibleAt 479 | ) 480 | )`) 481 | (*parameters)["statuses"] = pq.Array([]tasq.TaskStatus{tasq.StatusNew}) 482 | (*parameters)["visibleAt"] = time.Now() 483 | } 484 | 485 | func (d *Repository) buildFilterConditions(taskStatuses []tasq.TaskStatus, taskTypes, queues []string) ([]string, map[string]any) { 486 | var ( 487 | conditions []string 488 | parameters = make(map[string]any) 489 | ) 490 | 491 | if len(taskStatuses) > 0 { 492 | conditions = append(conditions, `"status" = 
ANY(:filterStatuses)`) 493 | parameters["filterStatuses"] = pq.Array(taskStatuses) 494 | } 495 | 496 | if len(taskTypes) > 0 { 497 | conditions = append(conditions, `"type" = ANY(:filterTypes)`) 498 | parameters["filterTypes"] = pq.Array(taskTypes) 499 | } 500 | 501 | if len(queues) > 0 { 502 | conditions = append(conditions, `"queue" = ANY(:filterQueues)`) 503 | parameters["filterQueues"] = pq.Array(queues) 504 | } 505 | 506 | return conditions, parameters 507 | } 508 | 509 | func (d *Repository) migrateStatus(ctx context.Context) error { 510 | var ( 511 | sqlTemplate = `DO $$ 512 | BEGIN 513 | IF NOT EXISTS (SELECT 1 FROM pg_type WHERE typname = '{{.statusTypeName}}') THEN 514 | CREATE TYPE {{.statusTypeName}} AS ENUM ({{.enumValues}}); 515 | END IF; 516 | END$$;` 517 | query = interpolateSQL(sqlTemplate, map[string]any{ 518 | "statusTypeName": d.statusTypeName, 519 | "enumValues": sliceToPostgreSQLValueList(tasq.GetTaskStatuses(tasq.AllTasks)), 520 | }) 521 | ) 522 | 523 | _, err := d.db.ExecContext(ctx, query) 524 | if err != nil { 525 | return fmt.Errorf("%s: %w", errFailedToExecuteCreateType, err) 526 | } 527 | 528 | return nil 529 | } 530 | 531 | func (d *Repository) migrateTable(ctx context.Context) error { 532 | const sqlTemplate = `CREATE TABLE IF NOT EXISTS {{.tableName}} ( 533 | "id" UUID NOT NULL PRIMARY KEY, 534 | "type" TEXT NOT NULL, 535 | "args" BYTEA NOT NULL, 536 | "queue" TEXT NOT NULL, 537 | "priority" SMALLINT NOT NULL, 538 | "status" {{.statusTypeName}} NOT NULL, 539 | "receive_count" INTEGER NOT NULL DEFAULT 0, 540 | "max_receives" INTEGER NOT NULL DEFAULT 0, 541 | "last_error" TEXT, 542 | "created_at" TIMESTAMPTZ NOT NULL DEFAULT '0001-01-01 00:00:00.000000', 543 | "started_at" TIMESTAMPTZ, 544 | "finished_at" TIMESTAMPTZ, 545 | "visible_at" TIMESTAMPTZ NOT NULL DEFAULT '0001-01-01 00:00:00.000000' 546 | );` 547 | 548 | query := interpolateSQL(sqlTemplate, map[string]any{ 549 | "tableName": d.tableName, 550 | "statusTypeName": 
d.statusTypeName, 551 | }) 552 | 553 | _, err := d.db.ExecContext(ctx, query) 554 | if err != nil { 555 | return fmt.Errorf("%s: %w", errFailedToExecuteCreateTable, err) 556 | } 557 | 558 | return nil 559 | } 560 | 561 | func (d *Repository) prepareWithTableName(sqlTemplate string) *sqlx.NamedStmt { 562 | query := interpolateSQL(sqlTemplate, map[string]any{ 563 | "tableName": d.tableName, 564 | }) 565 | 566 | namedStmt, err := d.db.PrepareNamed(query) 567 | if err != nil { 568 | panic(err) 569 | } 570 | 571 | return namedStmt 572 | } 573 | 574 | type closeableStmt interface { 575 | Close() error 576 | } 577 | 578 | func (d *Repository) closeStmt(stmt closeableStmt) { 579 | if err := stmt.Close(); err != nil { 580 | panic(err) 581 | } 582 | } 583 | 584 | func getOrderingDirectives(ordering tasq.Ordering) []string { 585 | var ( 586 | OrderingCreatedAtFirst = []string{"created_at ASC", "priority DESC"} 587 | OrderingPriorityFirst = []string{"priority DESC", "created_at ASC"} 588 | ) 589 | 590 | if orderingDirectives, ok := map[tasq.Ordering][]string{ 591 | tasq.OrderingCreatedAtFirst: OrderingCreatedAtFirst, 592 | tasq.OrderingPriorityFirst: OrderingPriorityFirst, 593 | }[ordering]; ok { 594 | return orderingDirectives 595 | } 596 | 597 | return OrderingCreatedAtFirst 598 | } 599 | 600 | func sliceToPostgreSQLValueList[T any](slice []T) string { 601 | stringSlice := make([]string, 0, len(slice)) 602 | 603 | for _, s := range slice { 604 | stringSlice = append(stringSlice, fmt.Sprint(s)) 605 | } 606 | 607 | return fmt.Sprintf("'%s'", strings.Join(stringSlice, "','")) 608 | } 609 | 610 | func statusTypeName(prefix string) string { 611 | const statusTypeName = "task_status" 612 | 613 | if len(prefix) > 0 { 614 | return prefix + "_" + statusTypeName 615 | } 616 | 617 | return statusTypeName 618 | } 619 | 620 | func tableName(prefix string) string { 621 | const tableName = "tasks" 622 | 623 | if len(prefix) > 0 { 624 | return prefix + "_" + tableName 625 | } 626 | 627 | 
return tableName 628 | } 629 | 630 | func interpolateSQL(sql string, params map[string]any) string { 631 | template, err := template.New("sql").Parse(sql) 632 | if err != nil { 633 | panic(err) 634 | } 635 | 636 | var outputBuffer bytes.Buffer 637 | 638 | err = template.Execute(&outputBuffer, params) 639 | if err != nil { 640 | panic(err) 641 | } 642 | 643 | return outputBuffer.String() 644 | } 645 | -------------------------------------------------------------------------------- /repository/postgres/postgresTask.go: -------------------------------------------------------------------------------- 1 | package postgres 2 | 3 | import ( 4 | "database/sql" 5 | "time" 6 | 7 | "github.com/google/uuid" 8 | "github.com/greencoda/tasq" 9 | ) 10 | 11 | type postgresTask struct { 12 | ID uuid.UUID `db:"id"` 13 | Type string `db:"type"` 14 | Args []byte `db:"args"` 15 | Queue string `db:"queue"` 16 | Priority int16 `db:"priority"` 17 | Status tasq.TaskStatus `db:"status"` 18 | ReceiveCount int32 `db:"receive_count"` 19 | MaxReceives int32 `db:"max_receives"` 20 | LastError sql.NullString `db:"last_error"` 21 | CreatedAt time.Time `db:"created_at"` 22 | StartedAt *time.Time `db:"started_at"` 23 | FinishedAt *time.Time `db:"finished_at"` 24 | VisibleAt time.Time `db:"visible_at"` 25 | } 26 | 27 | func newFromTask(task *tasq.Task) *postgresTask { 28 | return &postgresTask{ 29 | ID: task.ID, 30 | Type: task.Type, 31 | Args: task.Args, 32 | Queue: task.Queue, 33 | Priority: task.Priority, 34 | Status: task.Status, 35 | ReceiveCount: task.ReceiveCount, 36 | MaxReceives: task.MaxReceives, 37 | LastError: stringToSQLNullString(task.LastError), 38 | CreatedAt: task.CreatedAt, 39 | StartedAt: task.StartedAt, 40 | FinishedAt: task.FinishedAt, 41 | VisibleAt: task.VisibleAt, 42 | } 43 | } 44 | 45 | func (t *postgresTask) toTask() *tasq.Task { 46 | return &tasq.Task{ 47 | ID: t.ID, 48 | Type: t.Type, 49 | Args: t.Args, 50 | Queue: t.Queue, 51 | Priority: t.Priority, 52 | Status: t.Status, 
53 | ReceiveCount: t.ReceiveCount, 54 | MaxReceives: t.MaxReceives, 55 | LastError: parseNullableString(t.LastError), 56 | CreatedAt: t.CreatedAt, 57 | StartedAt: t.StartedAt, 58 | FinishedAt: t.FinishedAt, 59 | VisibleAt: t.VisibleAt, 60 | } 61 | } 62 | 63 | func postgresTasksToTasks(postgresTasks []*postgresTask) []*tasq.Task { 64 | tasks := make([]*tasq.Task, len(postgresTasks)) 65 | 66 | for i, postgresTask := range postgresTasks { 67 | tasks[i] = postgresTask.toTask() 68 | } 69 | 70 | return tasks 71 | } 72 | 73 | func stringToSQLNullString(input *string) sql.NullString { 74 | if input == nil { 75 | return sql.NullString{ 76 | String: "", 77 | Valid: false, 78 | } 79 | } 80 | 81 | return sql.NullString{ 82 | String: *input, 83 | Valid: true, 84 | } 85 | } 86 | 87 | func parseNullableString(input sql.NullString) *string { 88 | if !input.Valid { 89 | return nil 90 | } 91 | 92 | return &input.String 93 | } 94 | -------------------------------------------------------------------------------- /repository/postgres/postgresTask_test.go: -------------------------------------------------------------------------------- 1 | package postgres_test 2 | 3 | import ( 4 | "database/sql" 5 | "testing" 6 | 7 | "github.com/greencoda/tasq/repository/postgres" 8 | "github.com/stretchr/testify/assert" 9 | ) 10 | 11 | func TestStringToSQLNullString(t *testing.T) { 12 | t.Parallel() 13 | 14 | var ( 15 | emptyInput = "" 16 | nonEmptyInput = "test" 17 | nilInput *string 18 | ) 19 | 20 | sqlNullString := postgres.StringToSQLNullString(&emptyInput) 21 | assert.Equal(t, sql.NullString{ 22 | String: emptyInput, 23 | Valid: true, 24 | }, sqlNullString) 25 | 26 | sqlNullString = postgres.StringToSQLNullString(&nonEmptyInput) 27 | assert.Equal(t, sql.NullString{ 28 | String: nonEmptyInput, 29 | Valid: true, 30 | }, sqlNullString) 31 | 32 | sqlNullString = postgres.StringToSQLNullString(nilInput) 33 | assert.Equal(t, sql.NullString{ 34 | String: "", 35 | Valid: false, 36 | }, sqlNullString) 37 
| } 38 | 39 | func TestParseNullableString(t *testing.T) { 40 | t.Parallel() 41 | 42 | var ( 43 | emptyInput = sql.NullString{ 44 | String: "", 45 | Valid: true, 46 | } 47 | nonEmptyInput = sql.NullString{ 48 | String: "test", 49 | Valid: true, 50 | } 51 | nilInput = sql.NullString{ 52 | String: "", 53 | Valid: false, 54 | } 55 | ) 56 | 57 | output := postgres.ParseNullableString(emptyInput) 58 | assert.NotNil(t, output) 59 | assert.Equal(t, *output, emptyInput.String) 60 | 61 | output = postgres.ParseNullableString(nonEmptyInput) 62 | assert.NotNil(t, output) 63 | assert.Equal(t, *output, nonEmptyInput.String) 64 | 65 | output = postgres.ParseNullableString(nilInput) 66 | assert.Nil(t, output) 67 | } 68 | -------------------------------------------------------------------------------- /repository/postgres/postgres_test.go: -------------------------------------------------------------------------------- 1 | package postgres_test 2 | 3 | import ( 4 | "context" 5 | "database/sql" 6 | "database/sql/driver" 7 | "errors" 8 | "regexp" 9 | "testing" 10 | "time" 11 | 12 | "github.com/DATA-DOG/go-sqlmock" 13 | "github.com/google/uuid" 14 | "github.com/greencoda/tasq" 15 | "github.com/greencoda/tasq/repository/postgres" 16 | "github.com/stretchr/testify/assert" 17 | "github.com/stretchr/testify/require" 18 | "github.com/stretchr/testify/suite" 19 | ) 20 | 21 | var ( 22 | ctx = context.Background() 23 | testTask = getStartedTestTask() 24 | taskColumns = []string{ 25 | "id", 26 | "type", 27 | "args", 28 | "queue", 29 | "priority", 30 | "status", 31 | "receive_count", 32 | "max_receives", 33 | "last_error", 34 | "created_at", 35 | "started_at", 36 | "finished_at", 37 | "visible_at", 38 | } 39 | testTaskType = "testTask" 40 | testTaskQueue = "testQueue" 41 | taskValues = postgres.GetTestTaskValues(testTask) 42 | errSQL = errors.New("sql error") 43 | errTask = errors.New("task error") 44 | ) 45 | 46 | func getStartedTestTask() *tasq.Task { 47 | var ( 48 | testTask, _ = 
tasq.NewTask(testTaskType, true, testTaskQueue, 100, 5) 49 | startTime = testTask.CreatedAt.Add(time.Second) 50 | ) 51 | 52 | testTask.StartedAt = &startTime 53 | 54 | return testTask 55 | } 56 | 57 | type PostgresTestSuite struct { 58 | suite.Suite 59 | db *sql.DB 60 | sqlMock sqlmock.Sqlmock 61 | mockedRepository tasq.IRepository 62 | } 63 | 64 | func TestTaskTestSuite(t *testing.T) { 65 | t.Parallel() 66 | 67 | suite.Run(t, new(PostgresTestSuite)) 68 | } 69 | 70 | func (s *PostgresTestSuite) SetupTest() { 71 | var err error 72 | 73 | s.db, s.sqlMock, err = sqlmock.New() 74 | require.Nil(s.T(), err) 75 | 76 | s.mockedRepository, err = postgres.NewRepository(s.db, "test") 77 | require.NotNil(s.T(), s.mockedRepository) 78 | require.Nil(s.T(), err) 79 | } 80 | 81 | func (s *PostgresTestSuite) TestNewRepository() { 82 | // providing the datasource as *sql.DB 83 | repository, err := postgres.NewRepository(s.db, "test") 84 | assert.NotNil(s.T(), repository) 85 | assert.Nil(s.T(), err) 86 | 87 | // providing the datasource as *sql.DB with no prefix 88 | repository, err = postgres.NewRepository(s.db, "") 89 | assert.NotNil(s.T(), repository) 90 | assert.Nil(s.T(), err) 91 | 92 | // providing the datasource as dsn string 93 | repository, err = postgres.NewRepository("testDSN", "test") 94 | assert.NotNil(s.T(), repository) 95 | assert.Nil(s.T(), err) 96 | 97 | // providing the datasource as unknown datasource type 98 | repository, err = postgres.NewRepository(false, "test") 99 | assert.Nil(s.T(), repository) 100 | assert.NotNil(s.T(), err) 101 | } 102 | 103 | func (s *PostgresTestSuite) TestMigrate() { 104 | // First try - creating the task_status type fails 105 | s.sqlMock.ExpectExec(`CREATE TYPE test_task_status AS ENUM`).WillReturnError(errSQL) 106 | 107 | err := s.mockedRepository.Migrate(ctx) 108 | assert.NotNil(s.T(), err) 109 | 110 | // Second try - creating the tasks table fails 111 | s.sqlMock.ExpectExec(`CREATE TYPE test_task_status AS 
ENUM`).WillReturnResult(sqlmock.NewResult(1, 1)) 112 | s.sqlMock.ExpectExec(`CREATE TABLE IF NOT EXISTS test_tasks`).WillReturnError(errSQL) 113 | 114 | err = s.mockedRepository.Migrate(ctx) 115 | assert.NotNil(s.T(), err) 116 | 117 | // Third try - migration succeeds 118 | s.sqlMock.ExpectExec(`CREATE TYPE test_task_status AS ENUM`).WillReturnResult(sqlmock.NewResult(1, 1)) 119 | s.sqlMock.ExpectExec(`CREATE TABLE IF NOT EXISTS test_tasks`).WillReturnResult(sqlmock.NewResult(1, 1)) 120 | 121 | err = s.mockedRepository.Migrate(ctx) 122 | assert.Nil(s.T(), err) 123 | } 124 | 125 | func (s *PostgresTestSuite) TestPingTasks() { 126 | var ( 127 | taskUUID = uuid.New() 128 | stmtMockRegexp = regexp.QuoteMeta(`UPDATE test_tasks SET "visible_at" = $1 WHERE "id" = ANY($2) RETURNING id;`) 129 | ) 130 | // pinging empty tasklist 131 | tasks, err := s.mockedRepository.PingTasks(ctx, []uuid.UUID{}, 15*time.Second) 132 | assert.Len(s.T(), tasks, 0) 133 | assert.Nil(s.T(), err) 134 | 135 | // pinging when stmt preparation returns an error 136 | s.sqlMock.ExpectPrepare(stmtMockRegexp).WillReturnError(errSQL) 137 | 138 | assert.PanicsWithError(s.T(), errSQL.Error(), func() { 139 | _, _ = s.mockedRepository.PingTasks(ctx, []uuid.UUID{taskUUID}, 15*time.Second) 140 | }) 141 | 142 | // pinging when DB returns no rows 143 | s.sqlMock.ExpectPrepare(stmtMockRegexp).ExpectQuery().WillReturnError(errSQL) 144 | 145 | tasks, err = s.mockedRepository.PingTasks(ctx, []uuid.UUID{taskUUID}, 15*time.Second) 146 | assert.Len(s.T(), tasks, 0) 147 | assert.NotNil(s.T(), err) 148 | 149 | // pinging existing task 150 | s.sqlMock.ExpectPrepare(stmtMockRegexp).ExpectQuery().WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(taskUUID)) 151 | 152 | tasks, err = s.mockedRepository.PingTasks(ctx, []uuid.UUID{taskUUID}, 15*time.Second) 153 | assert.Len(s.T(), tasks, 1) 154 | assert.Nil(s.T(), err) 155 | } 156 | 157 | func (s *PostgresTestSuite) TestPollTasks() { 158 | stmtMockRegexp := 
regexp.QuoteMeta(`UPDATE test_tasks SET 159 | "status" = $1, 160 | "receive_count" = "receive_count" + 1, 161 | "visible_at" = $2 162 | WHERE "id" IN ( 163 | SELECT 164 | "id" 165 | FROM test_tasks 166 | WHERE "type" = ANY($3) 167 | AND "queue" = ANY($4) 168 | AND "status" = ANY($5) 169 | AND "visible_at" <= $6 170 | ORDER BY $7 171 | LIMIT $8 172 | FOR UPDATE 173 | ) RETURNING *;`) 174 | 175 | // polling with 0 limit 176 | tasks, err := s.mockedRepository.PollTasks(ctx, []string{testTaskType}, []string{"testQueue"}, 15*time.Second, tasq.OrderingCreatedAtFirst, 0) 177 | assert.Len(s.T(), tasks, 0) 178 | assert.Nil(s.T(), err) 179 | 180 | // polling when DB returns no rows 181 | s.sqlMock.ExpectPrepare(stmtMockRegexp). 182 | ExpectQuery(). 183 | WillReturnError(sql.ErrNoRows) 184 | 185 | tasks, err = s.mockedRepository.PollTasks(ctx, []string{testTaskType}, []string{"testQueue"}, 15*time.Second, tasq.OrderingCreatedAtFirst, 1) 186 | assert.Len(s.T(), tasks, 0) 187 | assert.Nil(s.T(), err) 188 | 189 | // polling when DB returns error 190 | s.sqlMock.ExpectPrepare(stmtMockRegexp). 191 | ExpectQuery(). 192 | WillReturnError(errSQL) 193 | 194 | tasks, err = s.mockedRepository.PollTasks(ctx, []string{testTaskType}, []string{"testQueue"}, 15*time.Second, tasq.OrderingCreatedAtFirst, 1) 195 | assert.Len(s.T(), tasks, 0) 196 | assert.NotNil(s.T(), err) 197 | 198 | // polling for existing tasks 199 | s.sqlMock.ExpectPrepare(stmtMockRegexp). 200 | ExpectQuery(). 201 | WillReturnRows(sqlmock.NewRows(taskColumns).AddRow(taskValues...)) 202 | 203 | tasks, err = s.mockedRepository.PollTasks(ctx, []string{testTaskType}, []string{"testQueue"}, 15*time.Second, tasq.OrderingCreatedAtFirst, 1) 204 | assert.Len(s.T(), tasks, 1) 205 | assert.Nil(s.T(), err) 206 | 207 | // polling for existing tasks with unknown ordering 208 | s.sqlMock.ExpectPrepare(stmtMockRegexp). 209 | ExpectQuery(). 
210 | WillReturnRows(sqlmock.NewRows(taskColumns).AddRow(taskValues...)) 211 | 212 | tasks, err = s.mockedRepository.PollTasks(ctx, []string{testTaskType}, []string{"testQueue"}, 15*time.Second, -1, 1) 213 | assert.Len(s.T(), tasks, 1) 214 | assert.Nil(s.T(), err) 215 | } 216 | 217 | func (s *PostgresTestSuite) TestCleanTasks() { 218 | stmtMockRegexp := regexp.QuoteMeta(`DELETE FROM test_tasks WHERE "status" = ANY($1) AND "created_at" <= $2;`) 219 | 220 | // cleaning when DB returns error 221 | s.sqlMock.ExpectPrepare(stmtMockRegexp).ExpectExec().WillReturnError(errSQL) 222 | 223 | rowsAffected, err := s.mockedRepository.CleanTasks(ctx, time.Hour) 224 | assert.Zero(s.T(), rowsAffected) 225 | assert.NotNil(s.T(), err) 226 | 227 | // cleaning when no rows are found 228 | s.sqlMock.ExpectPrepare(stmtMockRegexp).ExpectExec().WillReturnResult(driver.ResultNoRows) 229 | 230 | rowsAffected, err = s.mockedRepository.CleanTasks(ctx, time.Hour) 231 | assert.Equal(s.T(), int64(0), rowsAffected) 232 | assert.NotNil(s.T(), err) 233 | 234 | // cleaning successful 235 | s.sqlMock.ExpectPrepare(stmtMockRegexp).ExpectExec().WillReturnResult(sqlmock.NewResult(1, 1)) 236 | 237 | rowsAffected, err = s.mockedRepository.CleanTasks(ctx, time.Hour) 238 | assert.Equal(s.T(), int64(1), rowsAffected) 239 | assert.Nil(s.T(), err) 240 | } 241 | 242 | func (s *PostgresTestSuite) TestRegisterStart() { 243 | stmtMockRegexp := regexp.QuoteMeta(`UPDATE test_tasks SET "status" = $1, "started_at" = $2 WHERE "id" = $3 RETURNING *;`) 244 | 245 | // registering start when DB returns error 246 | s.sqlMock.ExpectPrepare(stmtMockRegexp).ExpectQuery().WillReturnError(errSQL) 247 | 248 | task, err := s.mockedRepository.RegisterStart(ctx, testTask) 249 | assert.Empty(s.T(), task) 250 | assert.NotNil(s.T(), err) 251 | 252 | // registering start successful 253 | s.sqlMock.ExpectPrepare(stmtMockRegexp).ExpectQuery().WillReturnRows(sqlmock.NewRows(taskColumns).AddRow(taskValues...)) 254 | 255 | task, err = 
s.mockedRepository.RegisterStart(ctx, testTask) 256 | assert.NotEmpty(s.T(), task) 257 | assert.Nil(s.T(), err) 258 | } 259 | 260 | func (s *PostgresTestSuite) TestRegisterError() { 261 | stmtMockRegexp := regexp.QuoteMeta(`UPDATE test_tasks SET "last_error" = $1 WHERE "id" = $2 RETURNING *;`) 262 | 263 | // registering error when DB returns error 264 | s.sqlMock.ExpectPrepare(stmtMockRegexp).ExpectQuery().WillReturnError(errSQL) 265 | 266 | task, err := s.mockedRepository.RegisterError(ctx, testTask, errTask) 267 | assert.Empty(s.T(), task) 268 | assert.NotNil(s.T(), err) 269 | 270 | // registering error successful 271 | s.sqlMock.ExpectPrepare(stmtMockRegexp).ExpectQuery().WillReturnRows(sqlmock.NewRows(taskColumns).AddRow(taskValues...)) 272 | 273 | task, err = s.mockedRepository.RegisterError(ctx, testTask, errTask) 274 | assert.NotEmpty(s.T(), task) 275 | assert.Nil(s.T(), err) 276 | } 277 | 278 | func (s *PostgresTestSuite) TestRegisterFinish() { 279 | stmtMockRegexp := regexp.QuoteMeta(`UPDATE test_tasks SET "status" = $1, "finished_at" = $2 WHERE "id" = $3 RETURNING *;`) 280 | 281 | // registering failure when DB returns error 282 | s.sqlMock.ExpectPrepare(stmtMockRegexp).ExpectQuery().WillReturnError(errSQL) 283 | 284 | task, err := s.mockedRepository.RegisterFinish(ctx, testTask, tasq.StatusSuccessful) 285 | assert.Empty(s.T(), task) 286 | assert.NotNil(s.T(), err) 287 | 288 | // registering failure successful 289 | s.sqlMock.ExpectPrepare(stmtMockRegexp).ExpectQuery().WillReturnRows(sqlmock.NewRows(taskColumns).AddRow(taskValues...)) 290 | 291 | task, err = s.mockedRepository.RegisterFinish(ctx, testTask, tasq.StatusSuccessful) 292 | assert.NotEmpty(s.T(), task) 293 | assert.Nil(s.T(), err) 294 | } 295 | 296 | func (s *PostgresTestSuite) TestSubmitTask() { 297 | stmtMockRegexp := regexp.QuoteMeta(`INSERT INTO test_tasks (id, type, args, queue, priority, status, max_receives, created_at, visible_at) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) RETURNING 
*;`) 298 | 299 | // submitting task when DB returns error 300 | s.sqlMock.ExpectPrepare(stmtMockRegexp).ExpectQuery().WillReturnError(errSQL) 301 | 302 | task, err := s.mockedRepository.SubmitTask(ctx, testTask) 303 | assert.Empty(s.T(), task) 304 | assert.NotNil(s.T(), err) 305 | 306 | // submitting task successful 307 | s.sqlMock.ExpectPrepare(stmtMockRegexp).ExpectQuery().WillReturnRows(sqlmock.NewRows(taskColumns).AddRow(taskValues...)) 308 | 309 | task, err = s.mockedRepository.SubmitTask(ctx, testTask) 310 | assert.NotEmpty(s.T(), task) 311 | assert.Nil(s.T(), err) 312 | } 313 | 314 | func (s *PostgresTestSuite) TestDeleteTask() { 315 | var ( 316 | stmtMockRegexp = regexp.QuoteMeta(`DELETE 317 | FROM 318 | test_tasks 319 | WHERE 320 | "id" = $1;`) 321 | stmtInvisibleMockRegexp = regexp.QuoteMeta(`DELETE 322 | FROM 323 | test_tasks 324 | WHERE 325 | "id" = $1 AND 326 | ( 327 | ( 328 | "visible_at" <= $2 329 | ) OR 330 | ( 331 | "status" = ANY($3) AND 332 | "visible_at" > $4 333 | ) 334 | );`) 335 | ) 336 | 337 | // deleting task when DB returns error 338 | s.sqlMock.ExpectPrepare(stmtMockRegexp).ExpectExec().WillReturnError(errSQL) 339 | 340 | err := s.mockedRepository.DeleteTask(ctx, testTask, false) 341 | assert.NotNil(s.T(), err) 342 | 343 | // deleting task successful 344 | s.sqlMock.ExpectPrepare(stmtMockRegexp).ExpectExec().WillReturnResult(sqlmock.NewResult(1, 1)) 345 | 346 | err = s.mockedRepository.DeleteTask(ctx, testTask, false) 347 | assert.Nil(s.T(), err) 348 | 349 | // deleting invisible task successful 350 | s.sqlMock.ExpectPrepare(stmtInvisibleMockRegexp).ExpectExec().WillReturnResult(sqlmock.NewResult(1, 1)) 351 | 352 | err = s.mockedRepository.DeleteTask(ctx, testTask, true) 353 | assert.Nil(s.T(), err) 354 | } 355 | 356 | func (s *PostgresTestSuite) TestRequeueTask() { 357 | stmtMockRegexp := regexp.QuoteMeta(`UPDATE test_tasks SET "status" = $1 WHERE "id" = $2 RETURNING *;`) 358 | 359 | // requeuing task when DB returns error 360 | 
s.sqlMock.ExpectPrepare(stmtMockRegexp).ExpectQuery().WillReturnError(errSQL) 361 | 362 | task, err := s.mockedRepository.RequeueTask(ctx, testTask) 363 | assert.Empty(s.T(), task) 364 | assert.NotNil(s.T(), err) 365 | 366 | // requeuing task successful 367 | s.sqlMock.ExpectPrepare(stmtMockRegexp).ExpectQuery().WillReturnRows(sqlmock.NewRows(taskColumns).AddRow(taskValues...)) 368 | 369 | task, err = s.mockedRepository.RequeueTask(ctx, testTask) 370 | assert.NotEmpty(s.T(), task) 371 | assert.Nil(s.T(), err) 372 | } 373 | 374 | func (s *PostgresTestSuite) TestCountTasks() { 375 | stmtMockRegexp := regexp.QuoteMeta(`SELECT COUNT(*) FROM test_tasks WHERE "status" = ANY($1) AND "type" = ANY($2) AND "queue" = ANY($3)`) 376 | 377 | // counting tasks when DB returns error 378 | s.sqlMock.ExpectPrepare(stmtMockRegexp).ExpectQuery().WillReturnError(errSQL) 379 | 380 | count, err := s.mockedRepository.CountTasks(ctx, []tasq.TaskStatus{tasq.StatusNew}, []string{"test"}, []string{"test"}) 381 | assert.Equal(s.T(), int64(0), count) 382 | assert.NotNil(s.T(), err) 383 | 384 | // counting tasks successful 385 | s.sqlMock.ExpectPrepare(stmtMockRegexp).ExpectQuery().WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(10)) 386 | 387 | count, err = s.mockedRepository.CountTasks(ctx, []tasq.TaskStatus{tasq.StatusNew}, []string{"test"}, []string{"test"}) 388 | assert.Equal(s.T(), int64(10), count) 389 | assert.Nil(s.T(), err) 390 | } 391 | 392 | func (s *PostgresTestSuite) TestScanTasks() { 393 | stmtMockRegexp := regexp.QuoteMeta(`SELECT * FROM test_tasks WHERE "status" = ANY($1) AND "type" = ANY($2) AND "queue" = ANY($3) ORDER BY $4 LIMIT $5;`) 394 | 395 | // scanning tasks when DB returns error 396 | s.sqlMock.ExpectPrepare(stmtMockRegexp).ExpectQuery().WillReturnError(errSQL) 397 | 398 | tasks, err := s.mockedRepository.ScanTasks(ctx, []tasq.TaskStatus{tasq.StatusNew}, []string{"test"}, []string{"test"}, 0, 10) 399 | assert.Empty(s.T(), tasks) 400 | assert.NotNil(s.T(), err) 
401 | 402 | // scanning tasks successful 403 | s.sqlMock.ExpectPrepare(stmtMockRegexp).ExpectQuery().WillReturnRows(sqlmock.NewRows(taskColumns).AddRow(taskValues...)) 404 | 405 | tasks, err = s.mockedRepository.ScanTasks(ctx, []tasq.TaskStatus{tasq.StatusNew}, []string{"test"}, []string{"test"}, 0, 10) 406 | assert.NotEmpty(s.T(), tasks) 407 | assert.Nil(s.T(), err) 408 | } 409 | 410 | func (s *PostgresTestSuite) TestPurgeTasks() { 411 | var ( 412 | stmtMockRegexp = regexp.QuoteMeta(`DELETE 413 | FROM 414 | test_tasks 415 | WHERE 416 | "status" = ANY($1) AND 417 | "queue" = ANY($2);`) 418 | stmtSafeDeleteMockRegexp = regexp.QuoteMeta(`DELETE 419 | FROM 420 | test_tasks 421 | WHERE 422 | "status" = ANY($1) AND 423 | "queue" = ANY($2) AND 424 | ( 425 | ( "visible_at" <= $3 ) OR 426 | ( "status" = ANY($4) AND "visible_at" > $5 ) 427 | );`) 428 | ) 429 | 430 | // purging tasks when DB returns error 431 | s.sqlMock.ExpectPrepare(stmtMockRegexp).ExpectExec().WillReturnError(errSQL) 432 | 433 | count, err := s.mockedRepository.PurgeTasks(ctx, []tasq.TaskStatus{tasq.StatusFailed}, []string{}, []string{testTaskQueue}, false) 434 | assert.Equal(s.T(), int64(0), count) 435 | assert.NotNil(s.T(), err) 436 | 437 | // purging when no rows are found 438 | s.sqlMock.ExpectPrepare(stmtMockRegexp).ExpectExec().WillReturnResult(driver.ResultNoRows) 439 | 440 | count, err = s.mockedRepository.PurgeTasks(ctx, []tasq.TaskStatus{tasq.StatusFailed}, []string{}, []string{testTaskQueue}, false) 441 | assert.Equal(s.T(), int64(0), count) 442 | assert.NotNil(s.T(), err) 443 | 444 | // purging tasks successful 445 | s.sqlMock.ExpectPrepare(stmtMockRegexp).ExpectExec().WillReturnResult(sqlmock.NewResult(1, 1)) 446 | 447 | count, err = s.mockedRepository.PurgeTasks(ctx, []tasq.TaskStatus{tasq.StatusFailed}, []string{}, []string{testTaskQueue}, false) 448 | assert.Equal(s.T(), int64(1), count) 449 | assert.Nil(s.T(), err) 450 | 451 | // purging tasks with safeDelete successful 452 | 
s.sqlMock.ExpectPrepare(stmtSafeDeleteMockRegexp).ExpectExec().WillReturnResult(sqlmock.NewResult(1, 1)) 453 | 454 | count, err = s.mockedRepository.PurgeTasks(ctx, []tasq.TaskStatus{tasq.StatusFailed}, []string{}, []string{testTaskQueue}, true) 455 | assert.Equal(s.T(), int64(1), count) 456 | assert.Nil(s.T(), err) 457 | } 458 | 459 | func (s *PostgresTestSuite) TestPrepareWithTableName() { 460 | stmtMockRegexp := regexp.QuoteMeta(`SELECT * FROM test_tasks`) 461 | 462 | postgresRepository, ok := s.mockedRepository.(*postgres.Repository) 463 | require.True(s.T(), ok) 464 | 465 | // preparing stmt with table name when DB returns error 466 | s.sqlMock.ExpectPrepare(stmtMockRegexp).WillReturnError(errSQL) 467 | 468 | assert.PanicsWithError(s.T(), "sql error", func() { 469 | _ = postgresRepository.PrepareWithTableName("SELECT * FROM {{.tableName}}") 470 | }) 471 | } 472 | 473 | func (s *PostgresTestSuite) TestCloseNamedStmt() { 474 | stmtMockRegexp := regexp.QuoteMeta(`SELECT * FROM test_tasks`) 475 | 476 | postgresRepository, ok := s.mockedRepository.(*postgres.Repository) 477 | require.True(s.T(), ok) 478 | 479 | s.sqlMock.ExpectPrepare(stmtMockRegexp) 480 | 481 | stmt, err := s.db.Prepare("SELECT * FROM test_tasks") 482 | require.Nil(s.T(), err) 483 | 484 | // an alternative DB to test the panic 485 | altDB, altSQLMock, err := sqlmock.New() 486 | require.Nil(s.T(), err) 487 | 488 | altSQLMock.ExpectBegin() 489 | 490 | tx, err := altDB.Begin() 491 | require.Nil(s.T(), err) 492 | 493 | assert.PanicsWithError(s.T(), "sql: Tx.Stmt: statement from different database used", func() { 494 | postgresRepository.CloseNamedStmt(tx.Stmt(stmt)) 495 | }) 496 | } 497 | 498 | func (s *PostgresTestSuite) TestInterpolateSQL() { 499 | params := map[string]any{"tableName": "test_table"} 500 | 501 | // Interpolate SQL successfully 502 | interpolatedSQL := postgres.InterpolateSQL("SELECT * FROM {{.tableName}}", params) 503 | assert.Equal(s.T(), "SELECT * FROM test_table", interpolatedSQL) 
504 | 505 | // Fail interpolaing unparseable SQL template 506 | assert.Panics(s.T(), func() { 507 | unparseableTemplateSQL := postgres.InterpolateSQL("SELECT * FROM {{.tableName", params) 508 | assert.Empty(s.T(), unparseableTemplateSQL) 509 | }) 510 | 511 | // Fail interpolaing unexecutable SQL template 512 | assert.Panics(s.T(), func() { 513 | unexecutableTemplateSQL := postgres.InterpolateSQL(`SELECT * FROM {{if .tableName eq 1}} {{end}} {{.tableName}}`, params) 514 | assert.Empty(s.T(), unexecutableTemplateSQL) 515 | }) 516 | } 517 | -------------------------------------------------------------------------------- /task.go: -------------------------------------------------------------------------------- 1 | package tasq 2 | 3 | import ( 4 | "bytes" 5 | "encoding/gob" 6 | "fmt" 7 | "time" 8 | 9 | "github.com/google/uuid" 10 | ) 11 | 12 | // TaskStatus is an enum type describing the status a task is currently in. 13 | type TaskStatus string 14 | 15 | // The collection of possible task statuses. 16 | const ( 17 | StatusNew TaskStatus = "NEW" 18 | StatusEnqueued TaskStatus = "ENQUEUED" 19 | StatusInProgress TaskStatus = "IN_PROGRESS" 20 | StatusSuccessful TaskStatus = "SUCCESSFUL" 21 | StatusFailed TaskStatus = "FAILED" 22 | ) 23 | 24 | // TaskStatusGroup is an enum type describing the key used in the 25 | // map of TaskStatuses which groups them for different purposes. 26 | type TaskStatusGroup int 27 | 28 | // The collection of possible task status groupings. 29 | const ( 30 | AllTasks TaskStatusGroup = iota 31 | OpenTasks 32 | FinishedTasks 33 | ) 34 | 35 | // GetTaskStatuses returns a slice of TaskStatuses based on the TaskStatusGroup 36 | // passed as an argument. 
37 | func GetTaskStatuses(taskStatusGroup TaskStatusGroup) []TaskStatus { 38 | if selected, ok := map[TaskStatusGroup][]TaskStatus{ 39 | AllTasks: { 40 | StatusNew, 41 | StatusEnqueued, 42 | StatusInProgress, 43 | StatusSuccessful, 44 | StatusFailed, 45 | }, 46 | OpenTasks: { 47 | StatusNew, 48 | StatusEnqueued, 49 | StatusInProgress, 50 | }, 51 | FinishedTasks: { 52 | StatusSuccessful, 53 | StatusFailed, 54 | }, 55 | }[taskStatusGroup]; ok { 56 | return selected 57 | } 58 | 59 | return nil 60 | } 61 | 62 | // Task is the struct used to represent an atomic task managed by tasq. 63 | type Task struct { 64 | ID uuid.UUID 65 | Type string 66 | Args []byte 67 | Queue string 68 | Priority int16 69 | Status TaskStatus 70 | ReceiveCount int32 71 | MaxReceives int32 72 | LastError *string 73 | CreatedAt time.Time 74 | StartedAt *time.Time 75 | FinishedAt *time.Time 76 | VisibleAt time.Time 77 | } 78 | 79 | // NewTask creates a new Task struct based on the supplied arguments required to define it. 80 | func NewTask(taskType string, taskArgs any, queue string, priority int16, maxReceives int32) (*Task, error) { 81 | taskID, err := uuid.NewRandom() 82 | if err != nil { 83 | return nil, fmt.Errorf("failed to generate new task ID: %w", err) 84 | } 85 | 86 | encodedArgs, err := encodeTaskArgs(taskArgs) 87 | if err != nil { 88 | return nil, err 89 | } 90 | 91 | return &Task{ 92 | ID: taskID, 93 | Type: taskType, 94 | Args: encodedArgs, 95 | Queue: queue, 96 | Priority: priority, 97 | Status: StatusNew, 98 | ReceiveCount: 0, 99 | MaxReceives: maxReceives, 100 | LastError: nil, 101 | CreatedAt: time.Now(), 102 | StartedAt: nil, 103 | FinishedAt: nil, 104 | VisibleAt: time.Time{}, 105 | }, nil 106 | } 107 | 108 | // IsLastReceive returns true if the task has reached its maximum number of receives. 109 | func (t *Task) IsLastReceive() bool { 110 | return t.ReceiveCount >= t.MaxReceives 111 | } 112 | 113 | // SetVisibility sets the time at which the task will become visible again. 
114 | func (t *Task) SetVisibility(visibleAt time.Time) { 115 | t.VisibleAt = visibleAt 116 | } 117 | 118 | // UnmarshalArgs decodes the task arguments into the passed target interface. 119 | func (t *Task) UnmarshalArgs(target any) error { 120 | var ( 121 | buffer = bytes.NewBuffer(t.Args) 122 | decoder = gob.NewDecoder(buffer) 123 | ) 124 | 125 | if err := decoder.Decode(target); err != nil { 126 | return fmt.Errorf("failed to decode task arguments: %w", err) 127 | } 128 | 129 | return nil 130 | } 131 | 132 | func encodeTaskArgs(taskArgs any) ([]byte, error) { 133 | var ( 134 | buffer bytes.Buffer 135 | encoder = gob.NewEncoder(&buffer) 136 | ) 137 | 138 | err := encoder.Encode(taskArgs) 139 | if err != nil { 140 | return []byte{}, fmt.Errorf("failed to encode task arguments: %w", err) 141 | } 142 | 143 | return buffer.Bytes(), nil 144 | } 145 | -------------------------------------------------------------------------------- /task_test.go: -------------------------------------------------------------------------------- 1 | package tasq_test 2 | 3 | import ( 4 | "errors" 5 | "testing" 6 | "time" 7 | 8 | "github.com/google/uuid" 9 | "github.com/greencoda/tasq" 10 | "github.com/stretchr/testify/assert" 11 | "github.com/stretchr/testify/suite" 12 | ) 13 | 14 | var errTest = errors.New("test error") 15 | 16 | type errorReader int 17 | 18 | func (errorReader) Read(p []byte) (int, error) { 19 | return 0, errTest 20 | } 21 | 22 | type TaskTestSuite struct { 23 | suite.Suite 24 | } 25 | 26 | func TestTaskTestSuite(t *testing.T) { 27 | suite.Run(t, new(TaskTestSuite)) 28 | } 29 | 30 | func (s *TaskTestSuite) SetupTest() { 31 | uuid.SetRand(nil) 32 | } 33 | 34 | func (s *TaskTestSuite) TestGetTaskStatuses() { 35 | allTasks := tasq.GetTaskStatuses(tasq.AllTasks) 36 | assert.ElementsMatch(s.T(), allTasks, []tasq.TaskStatus{ 37 | tasq.StatusNew, 38 | tasq.StatusEnqueued, 39 | tasq.StatusInProgress, 40 | tasq.StatusSuccessful, 41 | tasq.StatusFailed, 42 | }) 43 | 44 | openTasks 
:= tasq.GetTaskStatuses(tasq.OpenTasks) 45 | assert.ElementsMatch(s.T(), openTasks, []tasq.TaskStatus{ 46 | tasq.StatusNew, 47 | tasq.StatusEnqueued, 48 | tasq.StatusInProgress, 49 | }) 50 | 51 | finishedTasks := tasq.GetTaskStatuses(tasq.FinishedTasks) 52 | assert.ElementsMatch(s.T(), finishedTasks, []tasq.TaskStatus{ 53 | tasq.StatusSuccessful, 54 | tasq.StatusFailed, 55 | }) 56 | 57 | unknownTasks := tasq.GetTaskStatuses(-1) 58 | assert.Empty(s.T(), unknownTasks) 59 | } 60 | 61 | func (s *TaskTestSuite) TestNewTask() { 62 | // Create task successfully 63 | task, _ := tasq.NewTask("testTask", true, "testQueue", 0, 5) 64 | assert.NotNil(s.T(), task) 65 | 66 | // Fail by creating task with nil args 67 | nilTask, err := tasq.NewTask("testTask", nil, "testQueue", 0, 5) 68 | assert.Nil(s.T(), nilTask) 69 | assert.NotNil(s.T(), err) 70 | 71 | // Fail by causing uuid generation to return error 72 | uuid.SetRand(new(errorReader)) 73 | 74 | invalidUUIDTask, err := tasq.NewTask("testTask", false, "testQueue", 0, 5) 75 | assert.Nil(s.T(), invalidUUIDTask) 76 | assert.NotNil(s.T(), err) 77 | } 78 | 79 | func (s *TaskTestSuite) TestTaskUnmarshalArgs() { 80 | // Create task successfully 81 | task, _ := tasq.NewTask("testTask", true, "testQueue", 0, 5) 82 | assert.NotNil(s.T(), task) 83 | 84 | // Unmarshal task args successfully 85 | var args bool 86 | err := task.UnmarshalArgs(&args) 87 | assert.Nil(s.T(), err) 88 | assert.True(s.T(), args) 89 | 90 | // Fail by unmarshaling args to incorrect type 91 | var incorrectTypeArgs string 92 | err = task.UnmarshalArgs(&incorrectTypeArgs) 93 | assert.NotNil(s.T(), err) 94 | assert.Empty(s.T(), incorrectTypeArgs) 95 | } 96 | 97 | func (s *TaskTestSuite) TestTaskIsLastReceive() { 98 | // Create singleReceiveTask successfully 99 | singleReceiveTask, _ := tasq.NewTask("testTask", true, "testQueue", 0, 1) 100 | singleReceiveTask.ReceiveCount = 1 101 | assert.NotNil(s.T(), singleReceiveTask) 102 | 103 | // Check if task is in its last receive 
before reaching the maximum amount of receives 104 | assert.True(s.T(), singleReceiveTask.IsLastReceive()) 105 | 106 | // Create multiReceiveTask successfully 107 | multiReceiveTask, _ := tasq.NewTask("testTask", true, "testQueue", 0, 5) 108 | multiReceiveTask.ReceiveCount = 1 109 | assert.NotNil(s.T(), multiReceiveTask) 110 | 111 | // Check if task is in its last receive before reaching the maximum amount of receives 112 | assert.False(s.T(), multiReceiveTask.IsLastReceive()) 113 | } 114 | 115 | func (s *TaskTestSuite) TestTaskSetVisibility() { 116 | // Create task successfully 117 | task, _ := tasq.NewTask("testTask", true, "testQueue", 0, 5) 118 | assert.NotNil(s.T(), task) 119 | 120 | // Set Visibility successfully 121 | visibilityTime := time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC) 122 | task.SetVisibility(visibilityTime) 123 | assert.Equal(s.T(), task.VisibleAt, visibilityTime) 124 | } 125 | --------------------------------------------------------------------------------