├── .github
├── CODEOWNERS
├── splash_image.png
└── workflows
│ └── test.yml
├── .gitignore
├── .golangci.yml
├── .mockery.yml
├── LICENSE
├── Makefile
├── README.md
├── _examples
├── mysql
│ ├── docker-compose.yml
│ └── main.go
└── postgres
│ ├── docker-compose.yml
│ └── main.go
├── cleaner.go
├── cleaner_test.go
├── client.go
├── client_test.go
├── consumer.go
├── consumer_test.go
├── export_test.go
├── go.mod
├── go.sum
├── inspector.go
├── inspector_test.go
├── mocks
└── IRepository.go
├── producer.go
├── producer_test.go
├── repository.go
├── repository
├── mysql
│ ├── export_test.go
│ ├── mysql.go
│ ├── mysqlTask.go
│ ├── mysqlTask_test.go
│ └── mysql_test.go
└── postgres
│ ├── export_test.go
│ ├── postgres.go
│ ├── postgresTask.go
│ ├── postgresTask_test.go
│ └── postgres_test.go
├── task.go
└── task_test.go
/.github/CODEOWNERS:
--------------------------------------------------------------------------------
1 | * @greencoda
--------------------------------------------------------------------------------
/.github/splash_image.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/greencoda/tasq/cfff6aa01b5cda4faac73a706617161140d584c0/.github/splash_image.png
--------------------------------------------------------------------------------
/.github/workflows/test.yml:
--------------------------------------------------------------------------------
1 | name: Test
2 |
3 | on:
4 | push:
5 | branches: [ "master" ]
6 | pull_request:
7 | branches: [ "master" ]
8 |
9 | jobs:
10 | lint:
11 | runs-on: ubuntu-latest
12 | steps:
13 | - name: Checkout
14 | uses: actions/checkout@v3
15 |
16 | - name: Setup Go
17 | uses: actions/setup-go@v3
18 | with:
19 | go-version: 1.19.x
20 | check-latest: true
21 |
22 | - name: Lint
23 | uses: golangci/golangci-lint-action@v3
24 | with:
25 | version: v1.50.1
26 | test:
27 | runs-on: ubuntu-latest
28 | steps:
29 | - name: Checkout
30 | uses: actions/checkout@v3
31 |
32 | - name: Setup Go
33 | uses: actions/setup-go@v3
34 | with:
35 | go-version: 1.19.x
36 | check-latest: true
37 |
38 | - name: Setup Dependencies
39 | run: make deps && git diff-index --quiet HEAD || { >&2 echo "Stale go.{mod,sum} detected. This can be fixed with 'make deps'."; exit 1; }
40 |
41 | - name: Run Test
42 | run: make test
43 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # .DS_Store
2 | .DS_Store
3 |
4 | # Binaries for programs and plugins
5 | *.exe
6 | *.exe~
7 | *.dll
8 | *.so
9 | *.dylib
10 |
11 | # Test binary, build with `go test -c`
12 | *.test
13 |
14 | # Output of the go coverage tool, specifically when used with LiteIDE
15 | *.out
16 |
17 | # Ignore dependencies in ./vendor
18 | vendor
19 |
20 | # ignore build products
21 | *.zip
22 | bin/
23 |
24 | # ignore docker db volume data
25 | .dbdata
26 |
27 | # IDE-related files
28 | .idea
29 | .vscode
30 | *.code-workspace
31 |
32 | # product from go test -cover / ginkgo -cover
33 | *.coverprofile
34 |
35 | # certificates
36 | *.pem
--------------------------------------------------------------------------------
/.golangci.yml:
--------------------------------------------------------------------------------
1 | # Refer to golangci-lint's example config file for more options and information:
2 | # https://github.com/golangci/golangci-lint/blob/master/.golangci.reference.yml
3 |
4 | run:
5 | timeout: 5m
6 | modules-download-mode: readonly
7 |
8 | linters:
9 | enable-all: true
10 | disable:
11 | - dupl
12 | - funlen
13 | - gochecknoglobals
14 | - lll
15 | - revive
16 | - varnamelen
17 |
18 | linters-settings:
19 | paralleltest:
20 | ignore-missing: true
21 | interfacebloat:
22 | max: 20
23 | gosec:
24 | excludes:
25 | - G404
26 |
27 | issues:
28 | exclude-use-default: false
29 | max-issues-per-linter: 0
30 | max-same-issues: 0
--------------------------------------------------------------------------------
/.mockery.yml:
--------------------------------------------------------------------------------
1 | name: "IRepository"
2 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 greencoda
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 | .PHONY: deps lint mocks test test-cover
2 | .SILENT: test test-cover
3 |
4 | TESTABLE_PACKAGES = $(shell go list ./... | grep -v ./mocks)
5 |
6 | deps:
7 | go mod tidy
8 | go mod vendor
9 |
10 | lint: deps
11 | golangci-lint run -v
12 |
13 | mocks:
14 | rm -rf mocks/*
15 | mockery
16 |
17 | test:
18 | go test ${TESTABLE_PACKAGES} -count=1 -cover
19 |
20 | test-cover:
21 | go test ${TESTABLE_PACKAGES} -count=1 -cover -coverprofile=coverage.out
22 | go tool cover -html=coverage.out
23 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | [![godoc for greencoda/tasq][godoc-badge]][godoc-url]
2 | [![Build Status][actions-badge]][actions-url]
3 | [![Go 1.19][goversion-badge]][goversion-url]
4 | [![Go Report card][goreportcard-badge]][goreportcard-url]
5 |
6 |

7 |
8 | # tasq
9 |
10 | Tasq is a Golang task queue that uses an SQL database for persistence.
11 | Currently supports:
12 | - PostgreSQL
13 | - MySQL
14 |
15 | ## Install
16 |
17 | ```shell
18 | go get -u github.com/greencoda/tasq
19 | ```
20 |
21 | ## Usage Example
22 |
23 | To try tasq locally, you'll need a database running on your machine. Each example (`_examples/postgres` and `_examples/mysql`) ships a docker-compose.yml file you can use to start a local instance:
24 | ```shell
25 | docker-compose -f _examples/postgres/docker-compose.yml up -d
26 | ```
27 |
28 | Afterwards, run the matching example's main.go file (use `mysql` instead of `postgres` for the MySQL example):
29 | ```shell
30 | go run _examples/postgres/main.go
31 | ```
32 |
33 | [godoc-badge]: https://pkg.go.dev/badge/github.com/greencoda/tasq
34 | [godoc-url]: https://pkg.go.dev/github.com/greencoda/tasq
35 | [actions-badge]: https://github.com/greencoda/tasq/actions/workflows/test.yml/badge.svg
36 | [actions-url]: https://github.com/greencoda/tasq/actions/workflows/test.yml
37 | [goversion-badge]: https://img.shields.io/badge/Go-1.19-%2300ADD8?logo=go
38 | [goversion-url]: https://golang.org/doc/go1.19
39 | [goreportcard-badge]: https://goreportcard.com/badge/github.com/greencoda/tasq
40 | [goreportcard-url]: https://goreportcard.com/report/github.com/greencoda/tasq
41 |
--------------------------------------------------------------------------------
/_examples/mysql/docker-compose.yml:
--------------------------------------------------------------------------------
1 | ---
2 | version: '3.8'
3 |
4 | services:
5 | mysql:
6 | image: mysql:8.0
7 | command: ["mysqld", "--skip-name-resolve"]
8 | ports:
9 | - "3306:3306"
10 | volumes:
11 | - ./.dbdata:/var/lib/mysql
12 | environment:
13 | MYSQL_DATABASE: test
14 | MYSQL_ROOT_PASSWORD: root
15 | healthcheck:
16 | test: ["CMD", "mysqladmin" ,"ping", "-h", "localhost"]
17 | interval: 10s
18 | timeout: 5s
19 | retries: 10
20 | networks:
21 | tasq:
22 | name: tasq_network
23 | external: false
24 |
--------------------------------------------------------------------------------
/_examples/mysql/main.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | "context"
5 | "database/sql"
6 | "fmt"
7 | "log"
8 | "math/rand"
9 | "sync"
10 | "time"
11 |
12 | "github.com/greencoda/tasq"
13 | tasqMySQL "github.com/greencoda/tasq/repository/mysql"
14 | )
15 |
// Tuning knobs for the sample, grouped by the part of the program that uses them.
const (
	// Identity and submission parameters of every sample task.
	taskType        = "sampleTask"
	taskQueue       = "sampleQueue"
	taskPriority    = 20
	taskMaxReceives = 5

	// Consumer buffer size and inspector scan settings.
	channelSize    = 10
	taskScanLimit  = 20
	deletionCutoff = 0.5 // inspectTasks deletes scanned tasks whose Value is >= this

	// Durations: migration deadline, consumer poll period, and how long main
	// lets the sample run before stopping the consumer.
	migrationTimeout        = 10 * time.Second
	pollInterval            = 10 * time.Second
	consumerShutdownTimeout = 30 * time.Second
)

// SampleTaskArgs is the argument payload carried by every sample task:
// ID is the producer's running counter and Value a random number in [0, 1).
type SampleTaskArgs struct {
	ID    int
	Value float64
}
34 |
35 | func processSampleTask(task *tasq.Task) error {
36 | var args SampleTaskArgs
37 |
38 | err := task.UnmarshalArgs(&args)
39 | if err != nil {
40 | return fmt.Errorf("failed to unmarshal value: %w", err)
41 | }
42 |
43 | // do something here with the task arguments as input
44 | // for purposes of the sample, we'll just log its details here
45 | log.Printf("executed task '%s' with args '%+v'", task.ID, args)
46 |
47 | return nil
48 | }
49 |
50 | func consumeTasks(consumer *tasq.Consumer, wg *sync.WaitGroup) {
51 | defer wg.Done()
52 |
53 | for {
54 | job := <-consumer.Channel()
55 | if job == nil {
56 | return
57 | }
58 |
59 | // execute the job right away or feed it into a workerpool
60 | // such as workerpool.Add(*job)
61 | (*job)()
62 | }
63 | }
64 |
65 | func produceTasks(ctx context.Context, producer *tasq.Producer) {
66 | taskTicker := time.NewTicker(1 * time.Second)
67 |
68 | for taskIndex := 0; true; taskIndex++ {
69 | <-taskTicker.C
70 |
71 | seededRand := rand.New(rand.NewSource(time.Now().UnixNano()))
72 |
73 | taskArgs := SampleTaskArgs{
74 | ID: taskIndex,
75 | Value: seededRand.Float64(),
76 | }
77 |
78 | t, err := producer.Submit(ctx, taskType, taskArgs, taskQueue, taskPriority, taskMaxReceives)
79 | if err != nil {
80 | log.Panicf("error while submitting task to tasq: %s", err)
81 | } else {
82 | log.Printf("successfully submitted task '%s'", t.ID)
83 | }
84 | }
85 | }
86 |
87 | func inspectTasks(ctx context.Context, inspector *tasq.Inspector) {
88 | taskTicker := time.NewTicker(1 * time.Second)
89 |
90 | for taskIndex := 0; true; taskIndex++ {
91 | <-taskTicker.C
92 |
93 | taskCount, err := inspector.Count(ctx, nil, []string{taskType}, []string{taskQueue})
94 | if err != nil {
95 | log.Panicf("error while counting tasks: %s", err)
96 | }
97 |
98 | log.Printf("successfully counted %d tasks total", taskCount)
99 |
100 | tasks, err := inspector.Scan(ctx, []tasq.TaskStatus{tasq.StatusNew}, []string{taskType}, []string{taskQueue}, tasq.OrderingCreatedAtFirst, taskScanLimit)
101 | if err != nil {
102 | log.Panicf("error while scanning tasks: %s", err)
103 | }
104 |
105 | log.Printf("successfully scanned %d new tasks", len(tasks))
106 |
107 | for _, task := range tasks {
108 | var args SampleTaskArgs
109 |
110 | err := task.UnmarshalArgs(&args)
111 | if err != nil {
112 | log.Printf("failed to unmarshal value for task '%s'", task.ID)
113 |
114 | continue
115 | }
116 |
117 | if args.Value < deletionCutoff {
118 | continue
119 | }
120 |
121 | err = inspector.Delete(ctx, true, task)
122 | if err != nil {
123 | log.Printf("failed to remove task '%s'", task.ID)
124 |
125 | continue
126 | }
127 |
128 | log.Printf("successfully removed task '%s'", task.ID)
129 | }
130 | }
131 | }
132 |
133 | func main() {
134 | ctx, cancelCtx := context.WithCancel(context.Background())
135 | defer cancelCtx()
136 |
137 | db, err := sql.Open("mysql", "root:root@/test")
138 | if err != nil {
139 | log.Panicf("failed to open DB connection: %v", err)
140 | }
141 |
142 | // instantiate tasq repository to manage the database connection
143 | // you can also have it set up the sql DB for you if you provide the dsn string
144 | // instead of the *sql.DB instance
145 | tasqRepository, err := tasqMySQL.NewRepository(db, "tasq")
146 | if err != nil {
147 | log.Panicf("failed to create tasq repository: %s", err)
148 | }
149 |
150 | migrationCtx, migrationCancelCtx := context.WithTimeout(context.Background(), migrationTimeout)
151 | defer migrationCancelCtx()
152 |
153 | err = tasqRepository.Migrate(migrationCtx)
154 | if err != nil {
155 | log.Panicf("failed to migrate tasq repository: %s", err)
156 | }
157 |
158 | log.Print("database migrated successfully")
159 |
160 | // instantiate tasq client
161 | tasqClient := tasq.NewClient(tasqRepository)
162 |
163 | // set up tasq cleaner
164 | cleaner := tasqClient.NewCleaner().
165 | WithTaskAge(time.Second)
166 |
167 | cleanedTaskCount, err := cleaner.Clean(ctx)
168 | if err != nil {
169 | log.Panicf("failed to clean old tasks from queue: %s", err)
170 | }
171 |
172 | log.Printf("cleaned %d finished tasks from the queue on startup", cleanedTaskCount)
173 |
174 | // set up tasq consumer
175 | consumer := tasqClient.NewConsumer().
176 | WithQueues(taskQueue).
177 | WithChannelSize(channelSize).
178 | WithPollInterval(pollInterval).
179 | WithPollStrategy(tasq.PollStrategyByPriority).
180 | WithAutoDeleteOnSuccess(false).
181 | WithLogger(log.Default())
182 |
183 | // teach the consumer to handle tasks with the type "sampleTask" with the function "processSampleTask"
184 | err = consumer.Learn(taskType, processSampleTask, false)
185 | if err != nil {
186 | log.Panicf("failed to teach tasq consumer task handler: %s", err)
187 | }
188 |
189 | // start the consumer
190 | err = consumer.Start(ctx)
191 | if err != nil {
192 | log.Panicf("failed to start tasq consumer: %s", err)
193 | }
194 |
195 | var consumerWg sync.WaitGroup
196 |
197 | // start the goroutine which handles the tasq jobs received from the consumer
198 | consumerWg.Add(1)
199 |
200 | go consumeTasks(consumer, &consumerWg)
201 |
202 | // set up tasq inspector
203 | inspector := tasqClient.NewInspector()
204 |
205 | go inspectTasks(ctx, inspector)
206 |
207 | // set up tasq producer
208 | producer := tasqClient.NewProducer()
209 |
210 | // start the goroutine which produces the tasks and submits them to the tasq queue
211 | go produceTasks(ctx, producer)
212 |
213 | // block the execution
214 | <-time.After(consumerShutdownTimeout)
215 |
216 | err = consumer.Stop()
217 | if err != nil {
218 | log.Panicf("failed to stop tasq consumer: %s", err)
219 | }
220 |
221 | // wait until consumer go routine exits
222 | consumerWg.Wait()
223 |
224 | purgedTaskCount, err := inspector.Purge(ctx, true, tasq.GetTaskStatuses(tasq.OpenTasks), []string{taskType}, []string{taskQueue})
225 | if err != nil {
226 | log.Panicf("failed to purge tasq queue: %s", err)
227 | }
228 |
229 | log.Printf("purged %d open tasks from the queue", purgedTaskCount)
230 | }
231 |
--------------------------------------------------------------------------------
/_examples/postgres/docker-compose.yml:
--------------------------------------------------------------------------------
1 | ---
2 | version: "3.2"
3 |
4 | services:
5 | tasq:
6 | container_name: postgres
7 | image: postgres:14.2-alpine
8 | volumes:
9 | - ./.dbdata:/var/lib/postgresql
10 | ports:
11 | - "5432:5432"
12 | environment:
13 | LC_ALL: C.UTF-8
14 | POSTGRES_USER: test
15 | POSTGRES_PASSWORD: test
16 | POSTGRES_DB: test
17 | tmpfs:
18 | - /var/lib/postgresql/data
19 | healthcheck:
20 | test: [ "CMD", "pg_isready" ]
21 | interval: 10s
22 | timeout: 5s
23 | retries: 5
24 | networks:
25 | - tasq
26 | networks:
27 | tasq:
28 | name: tasq_network
29 | external: false
30 |
--------------------------------------------------------------------------------
/_examples/postgres/main.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | "context"
5 | "database/sql"
6 | "fmt"
7 | "log"
8 | "math/rand"
9 | "sync"
10 | "time"
11 |
12 | "github.com/greencoda/tasq"
13 | tasqPostgres "github.com/greencoda/tasq/repository/postgres"
14 | )
15 |
// Tuning knobs for the sample, grouped by the part of the program that uses them.
const (
	// Identity and submission parameters of every sample task.
	taskType        = "sampleTask"
	taskQueue       = "sampleQueue"
	taskPriority    = 20
	taskMaxReceives = 5

	// Consumer buffer size and inspector scan settings.
	channelSize    = 10
	taskScanLimit  = 20
	deletionCutoff = 0.5 // inspectTasks deletes scanned tasks whose Value is >= this

	// Durations: migration deadline, consumer poll period, and how long main
	// lets the sample run before stopping the consumer.
	migrationTimeout        = 10 * time.Second
	pollInterval            = 10 * time.Second
	consumerShutdownTimeout = 30 * time.Second
)

// SampleTaskArgs is the argument payload carried by every sample task:
// ID is the producer's running counter and Value a random number in [0, 1).
type SampleTaskArgs struct {
	ID    int
	Value float64
}
34 |
35 | func processSampleTask(task *tasq.Task) error {
36 | var args SampleTaskArgs
37 |
38 | err := task.UnmarshalArgs(&args)
39 | if err != nil {
40 | return fmt.Errorf("failed to unmarshal value: %w", err)
41 | }
42 |
43 | // do something here with the task arguments as input
44 | // for purposes of the sample, we'll just log its details here
45 | log.Printf("executed task '%s' with args '%+v'", task.ID, args)
46 |
47 | return nil
48 | }
49 |
50 | func consumeTasks(consumer *tasq.Consumer, wg *sync.WaitGroup) {
51 | defer wg.Done()
52 |
53 | for {
54 | job := <-consumer.Channel()
55 | if job == nil {
56 | return
57 | }
58 |
59 | // execute the job right away or feed it into a workerpool
60 | // such as workerpool.Add(*job)
61 | (*job)()
62 | }
63 | }
64 |
65 | func produceTasks(ctx context.Context, producer *tasq.Producer) {
66 | taskTicker := time.NewTicker(1 * time.Second)
67 |
68 | for taskIndex := 0; true; taskIndex++ {
69 | <-taskTicker.C
70 |
71 | seededRand := rand.New(rand.NewSource(time.Now().UnixNano()))
72 |
73 | taskArgs := SampleTaskArgs{
74 | ID: taskIndex,
75 | Value: seededRand.Float64(),
76 | }
77 |
78 | t, err := producer.Submit(ctx, taskType, taskArgs, taskQueue, taskPriority, taskMaxReceives)
79 | if err != nil {
80 | log.Panicf("error while submitting task to tasq: %s", err)
81 | } else {
82 | log.Printf("successfully submitted task '%s'", t.ID)
83 | }
84 | }
85 | }
86 |
87 | func inspectTasks(ctx context.Context, inspector *tasq.Inspector) {
88 | taskTicker := time.NewTicker(1 * time.Second)
89 |
90 | for taskIndex := 0; true; taskIndex++ {
91 | <-taskTicker.C
92 |
93 | taskCount, err := inspector.Count(ctx, nil, []string{taskType}, []string{taskQueue})
94 | if err != nil {
95 | log.Panicf("error while counting tasks: %s", err)
96 | }
97 |
98 | log.Printf("successfully counted %d tasks total", taskCount)
99 |
100 | tasks, err := inspector.Scan(ctx, []tasq.TaskStatus{tasq.StatusNew}, []string{taskType}, []string{taskQueue}, tasq.OrderingCreatedAtFirst, taskScanLimit)
101 | if err != nil {
102 | log.Panicf("error while scanning tasks: %s", err)
103 | }
104 |
105 | log.Printf("successfully scanned %d new tasks", len(tasks))
106 |
107 | for _, task := range tasks {
108 | var args SampleTaskArgs
109 |
110 | err := task.UnmarshalArgs(&args)
111 | if err != nil {
112 | log.Printf("failed to unmarshal value for task '%s'", task.ID)
113 |
114 | continue
115 | }
116 |
117 | if args.Value < deletionCutoff {
118 | continue
119 | }
120 |
121 | err = inspector.Delete(ctx, true, task)
122 | if err != nil {
123 | log.Printf("failed to remove task '%s'", task.ID)
124 |
125 | continue
126 | }
127 |
128 | log.Printf("successfully removed task '%s'", task.ID)
129 | }
130 | }
131 | }
132 |
133 | func main() {
134 | ctx, cancelCtx := context.WithCancel(context.Background())
135 | defer cancelCtx()
136 |
137 | db, err := sql.Open("postgres", "host=127.0.0.1 user=test password=test dbname=test port=5432 sslmode=disable")
138 | if err != nil {
139 | log.Panicf("failed to open DB connection: %v", err)
140 | }
141 |
142 | // instantiate tasq repository to manage the database connection
143 | // you can also have it set up the sql DB for you if you provide the dsn string
144 | // instead of the *sql.DB instance
145 | tasqRepository, err := tasqPostgres.NewRepository(db, "tasq")
146 | if err != nil {
147 | log.Panicf("failed to create tasq repository: %s", err)
148 | }
149 |
150 | migrationCtx, migrationCancelCtx := context.WithTimeout(context.Background(), migrationTimeout)
151 | defer migrationCancelCtx()
152 |
153 | err = tasqRepository.Migrate(migrationCtx)
154 | if err != nil {
155 | log.Panicf("failed to migrate tasq repository: %s", err)
156 | }
157 |
158 | log.Print("database migrated successfully")
159 |
160 | // instantiate tasq client
161 | tasqClient := tasq.NewClient(tasqRepository)
162 |
163 | // set up tasq cleaner
164 | cleaner := tasqClient.NewCleaner().
165 | WithTaskAge(time.Second)
166 |
167 | cleanedTaskCount, err := cleaner.Clean(ctx)
168 | if err != nil {
169 | log.Panicf("failed to clean old tasks from queue: %s", err)
170 | }
171 |
172 | log.Printf("cleaned %d finished tasks from the queue on startup", cleanedTaskCount)
173 |
174 | // set up tasq consumer
175 | consumer := tasqClient.NewConsumer().
176 | WithQueues(taskQueue).
177 | WithChannelSize(channelSize).
178 | WithPollInterval(pollInterval).
179 | WithPollStrategy(tasq.PollStrategyByPriority).
180 | WithAutoDeleteOnSuccess(false).
181 | WithLogger(log.Default())
182 |
183 | // teach the consumer to handle tasks with the type "sampleTask" with the function "processSampleTask"
184 | err = consumer.Learn(taskType, processSampleTask, false)
185 | if err != nil {
186 | log.Panicf("failed to teach tasq consumer task handler: %s", err)
187 | }
188 |
189 | // start the consumer
190 | err = consumer.Start(ctx)
191 | if err != nil {
192 | log.Panicf("failed to start tasq consumer: %s", err)
193 | }
194 |
195 | var consumerWg sync.WaitGroup
196 |
197 | // start the goroutine which handles the tasq jobs received from the consumer
198 | consumerWg.Add(1)
199 |
200 | go consumeTasks(consumer, &consumerWg)
201 |
202 | // set up tasq inspector
203 | inspector := tasqClient.NewInspector()
204 |
205 | go inspectTasks(ctx, inspector)
206 |
207 | // set up tasq producer
208 | producer := tasqClient.NewProducer()
209 |
210 | // start the goroutine which produces the tasks and submits them to the tasq queue
211 | go produceTasks(ctx, producer)
212 |
213 | // block the execution
214 | <-time.After(consumerShutdownTimeout)
215 |
216 | err = consumer.Stop()
217 | if err != nil {
218 | log.Panicf("failed to stop tasq consumer: %s", err)
219 | }
220 |
221 | // wait until consumer go routine exits
222 | consumerWg.Wait()
223 |
224 | purgedTaskCount, err := inspector.Purge(ctx, true, tasq.GetTaskStatuses(tasq.OpenTasks), []string{taskType}, []string{taskQueue})
225 | if err != nil {
226 | log.Panicf("failed to purge tasq queue: %s", err)
227 | }
228 |
229 | log.Printf("purged %d open tasks from the queue", purgedTaskCount)
230 | }
231 |
--------------------------------------------------------------------------------
/cleaner.go:
--------------------------------------------------------------------------------
1 | package tasq
2 |
3 | import (
4 | "context"
5 | "fmt"
6 | "time"
7 | )
8 |
9 | const defaultTaskAgeLimit = 15 * time.Minute
10 |
11 | // Cleaner is a service instance created by a Client with reference to that client
12 | // and the task age limit parameter.
13 | type Cleaner struct {
14 | client *Client
15 |
16 | taskAgeLimit time.Duration
17 | }
18 |
19 | // NewCleaner creates a new cleaner with a reference to the original tasq client.
20 | func (c *Client) NewCleaner() *Cleaner {
21 | return &Cleaner{
22 | client: c,
23 |
24 | taskAgeLimit: defaultTaskAgeLimit,
25 | }
26 | }
27 |
28 | // WithTaskAge defines the minimum time duration that must have passed since the creation of a finished task
29 | // in order for it to be eligible for cleanup when the Cleaner's Clean() method is called.
30 | //
31 | // Default value: 15 minutes.
32 | func (c *Cleaner) WithTaskAge(taskAge time.Duration) *Cleaner {
33 | c.taskAgeLimit = taskAge
34 |
35 | return c
36 | }
37 |
38 | // Clean will initiate the removal of finished (either succeeded or failed) tasks from the tasks table
39 | // if they have been created long enough ago for them to be eligible.
40 | func (c *Cleaner) Clean(ctx context.Context) (int64, error) {
41 | cleanedTaskCount, err := c.client.repository.CleanTasks(ctx, c.taskAgeLimit)
42 | if err != nil {
43 | return 0, fmt.Errorf("failed to clean tasks: %w", err)
44 | }
45 |
46 | return cleanedTaskCount, nil
47 | }
48 |
--------------------------------------------------------------------------------
/cleaner_test.go:
--------------------------------------------------------------------------------
1 | package tasq_test
2 |
3 | import (
4 | "context"
5 | "testing"
6 | "time"
7 |
8 | "github.com/greencoda/tasq"
9 | "github.com/greencoda/tasq/mocks"
10 | "github.com/stretchr/testify/assert"
11 | "github.com/stretchr/testify/require"
12 | "github.com/stretchr/testify/suite"
13 | )
14 |
15 | type CleanerTestSuite struct {
16 | suite.Suite
17 | mockRepository *mocks.IRepository
18 | tasqClient *tasq.Client
19 | tasqCleaner *tasq.Cleaner
20 | }
21 |
22 | func TestCleanerTestSuite(t *testing.T) {
23 | t.Parallel()
24 |
25 | suite.Run(t, new(CleanerTestSuite))
26 | }
27 |
28 | func (s *CleanerTestSuite) SetupTest() {
29 | s.mockRepository = mocks.NewIRepository(s.T())
30 |
31 | s.tasqClient = tasq.NewClient(s.mockRepository)
32 | require.NotNil(s.T(), s.tasqClient)
33 |
34 | s.tasqCleaner = s.tasqClient.NewCleaner().WithTaskAge(time.Hour)
35 | }
36 |
37 | func (s *CleanerTestSuite) TestNewCleaner() {
38 | assert.NotNil(s.T(), s.tasqCleaner)
39 | }
40 |
41 | func (s *CleanerTestSuite) TestClean() {
42 | ctx := context.Background()
43 |
44 | s.mockRepository.On("CleanTasks", ctx, time.Hour).Return(int64(1), nil).Once()
45 |
46 | rowsAffected, err := s.tasqCleaner.Clean(ctx)
47 |
48 | assert.Equal(s.T(), int64(1), rowsAffected)
49 | assert.True(s.T(), s.mockRepository.AssertCalled(s.T(), "CleanTasks", ctx, time.Hour))
50 | assert.Nil(s.T(), err)
51 |
52 | s.mockRepository.On("CleanTasks", ctx, time.Hour).Return(int64(0), errRepository).Once()
53 | rowsAffected, err = s.tasqCleaner.Clean(ctx)
54 |
55 | assert.Equal(s.T(), int64(0), rowsAffected)
56 | assert.True(s.T(), s.mockRepository.AssertCalled(s.T(), "CleanTasks", ctx, time.Hour))
57 | assert.NotNil(s.T(), err)
58 | }
59 |
--------------------------------------------------------------------------------
/client.go:
--------------------------------------------------------------------------------
1 | // Package tasq provides a task queue implementation compatible with multiple repositories.
2 | package tasq
3 |
4 | // Client wraps the tasq repository interface which is used
5 | // by the different services to access the database.
6 | type Client struct {
7 | repository IRepository
8 | }
9 |
10 | // NewClient creates a new tasq client instance with the provided tasq.
11 | func NewClient(repository IRepository) *Client {
12 | return &Client{
13 | repository: repository,
14 | }
15 | }
16 |
--------------------------------------------------------------------------------
/client_test.go:
--------------------------------------------------------------------------------
1 | package tasq_test
2 |
3 | import (
4 | "testing"
5 |
6 | "github.com/greencoda/tasq"
7 | "github.com/greencoda/tasq/mocks"
8 | "github.com/stretchr/testify/assert"
9 | )
10 |
11 | func TestNewClient(t *testing.T) {
12 | t.Parallel()
13 |
14 | repository := mocks.NewIRepository(t)
15 |
16 | tasqClient := tasq.NewClient(repository)
17 | assert.NotNil(t, tasqClient)
18 | }
19 |
--------------------------------------------------------------------------------
/consumer.go:
--------------------------------------------------------------------------------
1 | package tasq
2 |
3 | import (
4 | "context"
5 | "errors"
6 | "fmt"
7 | "io"
8 | "log"
9 | "sync"
10 | "time"
11 |
12 | "github.com/benbjohnson/clock"
13 | "github.com/google/uuid"
14 | )
15 |
// Collection of consumer errors.
var (
	// Lifecycle errors.
	ErrConsumerAlreadyRunning = errors.New("consumer has already been started")
	ErrConsumerAlreadyStopped = errors.New("consumer has already been stopped")

	// Polling, pinging, and activation errors.
	ErrCouldNotActivateTasks = errors.New("a number of tasks could not be activated")
	ErrCouldNotPollTasks     = errors.New("could not poll tasks")
	ErrCouldNotPingTasks     = errors.New("could not ping tasks")

	// Task-type registry errors.
	ErrTaskTypeAlreadyLearned = errors.New("task with this type already learned")
	ErrTaskTypeNotFound       = errors.New("task with this type not found")
	ErrTaskTypeNotKnown       = errors.New("task with this type is not known by this consumer")

	// Configuration errors.
	ErrUnknownPollStrategy       = errors.New("unknown poll strategy")
	ErrVisibilityTimeoutTooShort = errors.New("visibility timeout must be longer than poll interval")
)
29 |
// Logger receives the consumer's event log output during task consumption.
// *log.Logger satisfies it, so does the value returned by NoopLogger.
type Logger interface {
	Printf(format string, v ...any)
	Print(v ...any)
}
35 |
36 | // HandlerFunc is the function signature for the handler functions that are used to process tasks.
37 | type HandlerFunc func(task *Task) error
38 |
// PollStrategy names the ordering by which tasks are polled for consumption.
type PollStrategy string

// Supported poll strategies.
const (
	// PollStrategyByCreatedAt polls the oldest tasks first.
	PollStrategyByCreatedAt PollStrategy = "pollByCreatedAt"
	// PollStrategyByPriority polls the highest-priority tasks first.
	PollStrategyByPriority PollStrategy = "pollByPriority"
)
47 |
48 | const (
49 | defaultQueue = ""
50 | defaultChannelSize = 10
51 | defaultPollInterval = 5 * time.Second
52 | defaultPollStrategy = PollStrategyByCreatedAt
53 | defaultPollLimit = 10
54 | defaultAutoDeleteOnSuccess = false
55 | defaultMaxActiveTasks = 10
56 | defaultVisibilityTimeout = 15 * time.Second
57 | )
58 |
// NoopLogger returns a *log.Logger that writes to io.Discard with no prefix and
// no flags, so every message logged through it is dropped.
func NoopLogger() *log.Logger {
	const (
		noPrefix = ""
		noFlags  = 0
	)

	return log.New(io.Discard, noPrefix, noFlags)
}
63 |
64 | // Consumer is a service instance created by a Client with reference to that client
65 | // and the various parameters that define the task consumption behaviour.
66 | type Consumer struct {
67 | running bool
68 | autoDeleteOnSuccess bool
69 | channelSize int
70 | pollLimit int
71 | maxActiveTasks int
72 | pollInterval time.Duration
73 | pollStrategy PollStrategy
74 |
75 | wg sync.WaitGroup
76 |
77 | channel chan *func()
78 | client *Client
79 | clock clock.Clock
80 | logger Logger
81 |
82 | handlerFuncMap map[string]HandlerFunc
83 |
84 | activeMutex sync.RWMutex
85 | activeTasks map[uuid.UUID]struct{}
86 |
87 | visibilityTimeout time.Duration
88 | queues []string
89 |
90 | stop chan struct{}
91 | }
92 |
93 | // NewConsumer creates a new consumer with a reference to the original tasq client
94 | // and default consumer parameters.
95 | func (c *Client) NewConsumer() *Consumer {
96 | return &Consumer{
97 | running: false,
98 | autoDeleteOnSuccess: defaultAutoDeleteOnSuccess,
99 | channelSize: defaultChannelSize,
100 | pollLimit: defaultPollLimit,
101 | maxActiveTasks: defaultMaxActiveTasks,
102 | pollInterval: defaultPollInterval,
103 | pollStrategy: defaultPollStrategy,
104 |
105 | wg: sync.WaitGroup{},
106 |
107 | channel: nil,
108 | client: c,
109 | clock: clock.New(),
110 | logger: NoopLogger(),
111 |
112 | handlerFuncMap: make(map[string]HandlerFunc),
113 |
114 | activeMutex: sync.RWMutex{},
115 | activeTasks: make(map[uuid.UUID]struct{}),
116 |
117 | visibilityTimeout: defaultVisibilityTimeout,
118 | queues: []string{defaultQueue},
119 |
120 | stop: make(chan struct{}, 1),
121 | }
122 | }
123 |
124 | // WithChannelSize sets the size of the buffered channel used for outputting the polled messages to.
125 | //
126 | // Default value: 10.
127 | func (c *Consumer) WithChannelSize(channelSize int) *Consumer {
128 | c.channelSize = channelSize
129 |
130 | return c
131 | }
132 |
133 | // WithLogger sets the Logger interface that is used for event logging during task consumption.
134 | //
135 | // Default value: NoopLogger.
136 | func (c *Consumer) WithLogger(logger Logger) *Consumer {
137 | c.logger = logger
138 |
139 | return c
140 | }
141 |
142 | // WithPollInterval sets the interval at which the consumer will try and poll for new tasks to be executed
143 | // must not be greater than or equal to visibility timeout.
144 | //
145 | // Default value: 5 seconds.
146 | func (c *Consumer) WithPollInterval(pollInterval time.Duration) *Consumer {
147 | c.pollInterval = pollInterval
148 |
149 | return c
150 | }
151 |
152 | // WithPollLimit sets the maximum number of messages polled from the task queue.
153 | //
154 | // Default value: 10.
155 | func (c *Consumer) WithPollLimit(pollLimit int) *Consumer {
156 | c.pollLimit = pollLimit
157 |
158 | return c
159 | }
160 |
161 | // WithPollStrategy sets the ordering to be used when polling for tasks from the task queue.
162 | //
163 | // Default value: PollStrategyByCreatedAt.
164 | func (c *Consumer) WithPollStrategy(pollStrategy PollStrategy) *Consumer {
165 | c.pollStrategy = pollStrategy
166 |
167 | return c
168 | }
169 |
170 | // WithAutoDeleteOnSuccess sets whether successful tasks should be automatically deleted from the task queue
171 | // by the consumer.
172 | //
173 | // Default value: false.
174 | func (c *Consumer) WithAutoDeleteOnSuccess(autoDeleteOnSuccess bool) *Consumer {
175 | c.autoDeleteOnSuccess = autoDeleteOnSuccess
176 |
177 | return c
178 | }
179 |
180 | // WithMaxActiveTasks sets the maximum number of tasks a consumer can have enqueued at the same time
181 | // before polling for additional ones.
182 | //
183 | // Default value: 10.
184 | func (c *Consumer) WithMaxActiveTasks(maxActiveTasks int) *Consumer {
185 | c.maxActiveTasks = maxActiveTasks
186 |
187 | return c
188 | }
189 |
190 | // WithVisibilityTimeout sets the duration by which each ping will extend a task's visibility timeout;
191 | // Once this timeout is up, a consumer instance may receive the task again.
192 | //
193 | // Default value: 15 seconds.
194 | func (c *Consumer) WithVisibilityTimeout(visibilityTimeout time.Duration) *Consumer {
195 | c.visibilityTimeout = visibilityTimeout
196 |
197 | return c
198 | }
199 |
200 | // WithQueues sets the queues from which the consumer may poll for tasks.
201 | //
202 | // Default value: empty slice of strings.
203 | func (c *Consumer) WithQueues(queues ...string) *Consumer {
204 | c.queues = queues
205 |
206 | return c
207 | }
208 |
209 | // Learn sets a handler function for the specified taskType.
210 | // If override is false and a handler function is already set for the specified
211 | // taskType, it'll return an error.
212 | func (c *Consumer) Learn(taskType string, f HandlerFunc, override bool) error {
213 | if _, exists := c.handlerFuncMap[taskType]; exists && !override {
214 | return fmt.Errorf("%w: %s", ErrTaskTypeAlreadyLearned, taskType)
215 | }
216 |
217 | c.handlerFuncMap[taskType] = f
218 |
219 | return nil
220 | }
221 |
222 | // Forget removes a handler function for the specified taskType from the map of
223 | // learned handler functions.
224 | // If the specified taskType does not exist, it'll return an error.
225 | func (c *Consumer) Forget(taskType string) error {
226 | if _, exists := c.handlerFuncMap[taskType]; !exists {
227 | return fmt.Errorf("%w: %s", ErrTaskTypeNotFound, taskType)
228 | }
229 |
230 | delete(c.handlerFuncMap, taskType)
231 |
232 | return nil
233 | }
234 |
235 | // Start launches the go routine which manages the pinging and polling of tasks
236 | // for the consumer, or returns an error if the consumer is not properly configured.
237 | func (c *Consumer) Start(ctx context.Context) error {
238 | if c.isRunning() {
239 | return ErrConsumerAlreadyRunning
240 | }
241 |
242 | if c.visibilityTimeout <= c.pollInterval {
243 | return ErrVisibilityTimeoutTooShort
244 | }
245 |
246 | c.setRunning(true)
247 |
248 | c.channel = make(chan *func(), c.channelSize)
249 |
250 | ticker := c.clock.Ticker(c.pollInterval)
251 |
252 | go c.processLoop(ctx, ticker)
253 |
254 | return nil
255 | }
256 |
257 | // Stop sends the termination signal to the consumer so it'll no longer poll for news tasks.
258 | func (c *Consumer) Stop() error {
259 | if !c.isRunning() {
260 | return ErrConsumerAlreadyStopped
261 | }
262 |
263 | c.stop <- struct{}{}
264 |
265 | return nil
266 | }
267 |
268 | // Channel returns a read-only channel where the polled jobs can be read from.
269 | func (c *Consumer) Channel() <-chan *func() {
270 | return c.channel
271 | }
272 |
273 | func (c *Consumer) isRunning() bool {
274 | return c.running
275 | }
276 |
277 | func (c *Consumer) setRunning(isRunning bool) {
278 | c.running = isRunning
279 | }
280 |
281 | func (c *Consumer) registerTaskStart(ctx context.Context, task *Task) {
282 | _, err := c.client.repository.RegisterStart(ctx, task)
283 | if err != nil {
284 | panic(err)
285 | }
286 | }
287 |
288 | func (c *Consumer) registerTaskError(ctx context.Context, task *Task, taskError error) {
289 | _, err := c.client.repository.RegisterError(ctx, task, taskError)
290 | if err != nil {
291 | panic(err)
292 | }
293 |
294 | if task.MaxReceives > 0 && (task.ReceiveCount) >= task.MaxReceives {
295 | c.registerTaskFail(ctx, task)
296 | } else {
297 | c.requeueTask(ctx, task)
298 | }
299 | }
300 |
301 | func (c *Consumer) registerTaskSuccess(ctx context.Context, task *Task) {
302 | if c.autoDeleteOnSuccess {
303 | err := c.client.repository.DeleteTask(ctx, task, false)
304 | if err != nil {
305 | panic(err)
306 | }
307 | } else {
308 | _, err := c.client.repository.RegisterFinish(ctx, task, StatusSuccessful)
309 | if err != nil {
310 | panic(err)
311 | }
312 | }
313 |
314 | c.removeFromActiveTasks(task)
315 | }
316 |
317 | func (c *Consumer) registerTaskFail(ctx context.Context, task *Task) {
318 | _, err := c.client.repository.RegisterFinish(ctx, task, StatusFailed)
319 | if err != nil {
320 | panic(err)
321 | }
322 |
323 | c.removeFromActiveTasks(task)
324 | }
325 |
326 | func (c *Consumer) requeueTask(ctx context.Context, task *Task) {
327 | _, err := c.client.repository.RequeueTask(ctx, task)
328 | if err != nil {
329 | panic(err)
330 | }
331 |
332 | c.removeFromActiveTasks(task)
333 | }
334 |
335 | func (c *Consumer) getActiveTaskCount() int {
336 | return len(c.activeTasks)
337 | }
338 |
339 | func (c *Consumer) removeFromActiveTasks(task *Task) {
340 | c.activeMutex.Lock()
341 | delete(c.activeTasks, task.ID)
342 | c.activeMutex.Unlock()
343 | }
344 |
345 | func (c *Consumer) getActiveTaskIDs() []uuid.UUID {
346 | activeTaskIDs := make([]uuid.UUID, 0, len(c.activeTasks))
347 |
348 | for taskID := range c.activeTasks {
349 | activeTaskIDs = append(activeTaskIDs, taskID)
350 | }
351 |
352 | return activeTaskIDs
353 | }
354 |
355 | func (c *Consumer) getKnownTaskTypes() []string {
356 | taskTypes := make([]string, 0, len(c.handlerFuncMap))
357 |
358 | for taskType := range c.handlerFuncMap {
359 | taskTypes = append(taskTypes, taskType)
360 | }
361 |
362 | return taskTypes
363 | }
364 |
365 | func (c *Consumer) getPollOrdering() (Ordering, error) {
366 | switch c.pollStrategy {
367 | case PollStrategyByCreatedAt:
368 | return OrderingCreatedAtFirst, nil
369 | case PollStrategyByPriority:
370 | return OrderingPriorityFirst, nil
371 | default:
372 | return -1, fmt.Errorf("%w: %s", ErrUnknownPollStrategy, c.pollStrategy)
373 | }
374 | }
375 |
376 | func (c *Consumer) getPollQuantity() int {
377 | taskCapacity := c.maxActiveTasks - len(c.activeTasks)
378 |
379 | if c.pollLimit < taskCapacity {
380 | return c.pollLimit
381 | }
382 |
383 | return taskCapacity
384 | }
385 |
386 | func (c *Consumer) processLoop(ctx context.Context, ticker *clock.Ticker) {
387 | c.wg.Add(1)
388 | defer c.wg.Done()
389 | defer c.logger.Print("processing stopped")
390 | defer ticker.Stop()
391 |
392 | var (
393 | tasks []*Task
394 | err error
395 | )
396 |
397 | for {
398 | err = c.pingActiveTasks(ctx)
399 | if err != nil {
400 | c.logger.Printf("error pinging active tasks: %s", err)
401 | }
402 |
403 | if c.isRunning() {
404 | tasks, err = c.pollForTasks(ctx)
405 | if err != nil {
406 | c.logger.Printf("error polling for tasks: %s", err)
407 | }
408 |
409 | err = c.activateTasks(ctx, tasks)
410 | if err != nil {
411 | c.logger.Printf("error activating tasks: %s", err)
412 | }
413 | } else if c.getActiveTaskCount() == 0 {
414 | return
415 | }
416 |
417 | select {
418 | case <-c.stop:
419 | c.setRunning(false)
420 | close(c.channel)
421 | case <-ticker.C:
422 | continue
423 | }
424 | }
425 | }
426 |
427 | func (c *Consumer) pollForTasks(ctx context.Context) ([]*Task, error) {
428 | pollOrdering, err := c.getPollOrdering()
429 | if err != nil {
430 | return nil, err
431 | }
432 |
433 | tasks, err := c.client.repository.PollTasks(ctx, c.getKnownTaskTypes(), c.queues, c.visibilityTimeout, pollOrdering, c.getPollQuantity())
434 | if err != nil {
435 | return nil, fmt.Errorf("%w: %s", ErrCouldNotPollTasks, err)
436 | }
437 |
438 | return tasks, nil
439 | }
440 |
441 | func (c *Consumer) pingActiveTasks(ctx context.Context) error {
442 | _, err := c.client.repository.PingTasks(ctx, c.getActiveTaskIDs(), c.visibilityTimeout)
443 | if err != nil {
444 | return fmt.Errorf("%w: %s", ErrCouldNotPingTasks, err)
445 | }
446 |
447 | return nil
448 | }
449 |
450 | func (c *Consumer) activateTasks(ctx context.Context, tasks []*Task) error {
451 | var errors []error
452 |
453 | for _, task := range tasks {
454 | err := c.activateTask(ctx, task)
455 | if err != nil {
456 | errors = append(errors, err)
457 |
458 | c.registerTaskFail(ctx, task)
459 | }
460 | }
461 |
462 | if len(errors) > 0 {
463 | return fmt.Errorf("%w: %v", ErrCouldNotActivateTasks, len(errors))
464 | }
465 |
466 | return nil
467 | }
468 |
469 | func (c *Consumer) activateTask(ctx context.Context, task *Task) error {
470 | job, err := c.createJobFromTask(ctx, task)
471 | if err != nil {
472 | return err
473 | }
474 |
475 | c.activeMutex.Lock()
476 | c.activeTasks[task.ID] = struct{}{}
477 | c.activeMutex.Unlock()
478 |
479 | c.channel <- job
480 |
481 | return nil
482 | }
483 |
484 | func (c *Consumer) createJobFromTask(ctx context.Context, task *Task) (*func(), error) {
485 | if handlerFunc, ok := c.handlerFuncMap[task.Type]; ok {
486 | return c.newJob(ctx, c, handlerFunc, task), nil
487 | }
488 |
489 | return nil, fmt.Errorf("%w: %s", ErrTaskTypeNotKnown, task.Type)
490 | }
491 |
492 | func (c *Consumer) newJob(ctx context.Context, consumer *Consumer, f HandlerFunc, task *Task) *func() {
493 | job := func() {
494 | consumer.registerTaskStart(ctx, task)
495 |
496 | if err := f(task); err == nil {
497 | consumer.registerTaskSuccess(ctx, task)
498 | } else {
499 | consumer.registerTaskError(ctx, task, err)
500 | }
501 | }
502 |
503 | return &job
504 | }
505 |
--------------------------------------------------------------------------------
/consumer_test.go:
--------------------------------------------------------------------------------
1 | package tasq_test
2 |
3 | import (
4 | "bytes"
5 | "context"
6 | "errors"
7 | "log"
8 | "testing"
9 | "time"
10 |
11 | "github.com/benbjohnson/clock"
12 | "github.com/google/uuid"
13 | "github.com/greencoda/tasq"
14 | "github.com/greencoda/tasq/mocks"
15 | "github.com/stretchr/testify/assert"
16 | "github.com/stretchr/testify/mock"
17 | "github.com/stretchr/testify/require"
18 | "github.com/stretchr/testify/suite"
19 | )
20 |
// errTaskFail is returned by task handlers that are meant to fail.
var errTaskFail = errors.New("task failed")

// errRepository is returned by the mocked repository to simulate failures.
var errRepository = errors.New("repository error")
25 |
26 | type ConsumerTestSuite struct {
27 | suite.Suite
28 | mockRepository *mocks.IRepository
29 | mockClock *clock.Mock
30 | tasqClient *tasq.Client
31 | tasqConsumer *tasq.Consumer
32 | logBuffer bytes.Buffer
33 | }
34 |
35 | func TestConsumerTestSuite(t *testing.T) {
36 | t.Parallel()
37 |
38 | suite.Run(t, new(ConsumerTestSuite))
39 | }
40 |
41 | func (s *ConsumerTestSuite) SetupTest() {
42 | s.mockRepository = mocks.NewIRepository(s.T())
43 | s.mockClock = clock.NewMock()
44 | s.mockClock.Set(time.Now())
45 |
46 | s.tasqClient = tasq.NewClient(s.mockRepository)
47 | require.NotNil(s.T(), s.tasqClient)
48 |
49 | s.tasqConsumer = s.tasqClient.NewConsumer().WithLogger(log.New(&s.logBuffer, "", 0)).SetClock(s.mockClock)
50 | require.NotNil(s.T(), s.tasqConsumer)
51 |
52 | s.logBuffer.Reset()
53 | }
54 |
55 | func (s *ConsumerTestSuite) TestNewConsumer() {
56 | assert.NotNil(s.T(), s.tasqConsumer.
57 | WithAutoDeleteOnSuccess(true).
58 | WithChannelSize(10).
59 | WithMaxActiveTasks(10).
60 | WithPollInterval(10*time.Second).
61 | WithPollLimit(10).
62 | WithPollStrategy(tasq.PollStrategyByCreatedAt).
63 | WithQueues("testQueue").
64 | WithVisibilityTimeout(30*time.Second))
65 |
66 | assert.Empty(s.T(), s.logBuffer.String())
67 | }
68 |
69 | func (s *ConsumerTestSuite) TestLearnAndForget() {
70 | // Learning a new task execution method is successful
71 | err := s.tasqConsumer.Learn("testTask", func(task *tasq.Task) error {
72 | return nil
73 | }, false)
74 | assert.Nil(s.T(), err)
75 |
76 | // Learning an already learned task execution method returns an error
77 | err = s.tasqConsumer.Learn("testTask", func(task *tasq.Task) error {
78 | return nil
79 | }, false)
80 | assert.NotNil(s.T(), err)
81 |
82 | // Learning an already learned task execution method with override being true is successful
83 | err = s.tasqConsumer.Learn("testTask", func(task *tasq.Task) error {
84 | return nil
85 | }, true)
86 | assert.Nil(s.T(), err)
87 |
88 | // Forgetting an already learned task execution method is successful
89 | err = s.tasqConsumer.Forget("testTask")
90 | assert.Nil(s.T(), err)
91 |
92 | // Forgetting an unknown execution method is successful
93 | err = s.tasqConsumer.Forget("anotherTestTask")
94 | assert.NotNil(s.T(), err)
95 |
96 | assert.Empty(s.T(), s.logBuffer.String())
97 | }
98 |
99 | func (s *ConsumerTestSuite) TestStartWithInvalidVisibilityTimeoutParam() {
100 | ctx := context.Background()
101 |
102 | s.tasqConsumer.
103 | WithVisibilityTimeout(time.Second).
104 | WithPollInterval(5 * time.Second)
105 |
106 | // Start up the consumer
107 | err := s.tasqConsumer.Start(ctx)
108 | assert.NotNil(s.T(), err)
109 |
110 | assert.Empty(s.T(), s.logBuffer.String())
111 | }
112 |
// TestStartStopTwice verifies that starting an already running consumer and
// stopping an already stopped one both return errors. It also checks that a
// failed poll is only logged and does not end the processing loop.
func (s *ConsumerTestSuite) TestStartStopTwice() {
	ctx := context.Background()

	s.tasqConsumer.
		WithQueues("testQueue")

	err := s.tasqConsumer.Learn("testTask", func(task *tasq.Task) error {
		return nil
	}, false)
	assert.Nil(s.T(), err)

	// Every ping (with no active tasks) succeeds
	s.mockRepository.On("PingTasks", ctx, []uuid.UUID{}, 15*time.Second).Return([]*tasq.Task{}, nil)

	// Polling fails on the first time
	s.mockRepository.On("PollTasks", ctx, []string{"testTask"}, []string{"testQueue"}, 15*time.Second, tasq.OrderingCreatedAtFirst, 10).Return([]*tasq.Task{}, errRepository).Once()

	// Polling succeeds on subsequent attempts
	s.mockRepository.On("PollTasks", ctx, []string{"testTask"}, []string{"testQueue"}, 15*time.Second, tasq.OrderingCreatedAtFirst, 10).Return([]*tasq.Task{}, nil)

	// Start up the consumer
	err = s.tasqConsumer.Start(ctx)
	assert.Nil(s.T(), err)

	// Starting the consumer again while it is running fails
	err = s.tasqConsumer.Start(ctx)
	assert.NotNil(s.T(), err)

	// Stop the consumer
	err = s.tasqConsumer.Stop()
	assert.Nil(s.T(), err)

	// Advance the mock clock by 5 seconds
	s.mockClock.Add(5 * time.Second)

	// Stopping the consumer again fails, since the stop signal was processed
	err = s.tasqConsumer.Stop()
	assert.NotNil(s.T(), err)

	// Wait for goroutine to actually return and output log message
	s.tasqConsumer.GetWaitGroup().Wait()

	assert.Equal(s.T(), "error polling for tasks: could not poll tasks: repository error\nprocessing stopped\n", s.logBuffer.String())
}
156 |
// TestConsumption takes three tasks through the job lifecycle: one that
// succeeds, one that fails and is requeued, and one that fails after
// exhausting its receives and is marked as failed. For each, it checks that
// every repository error inside the job surfaces as a panic.
func (s *ConsumerTestSuite) TestConsumption() {
	ctx := context.Background()

	s.tasqConsumer.
		WithQueues("testQueue")

	var (
		successTestArgs = "success"
		failTestArgs    = "fail"
	)

	successTestTask, err := tasq.NewTask("testTask", successTestArgs, "testQueue", 100, 5)
	require.NotNil(s.T(), successTestTask)
	require.Nil(s.T(), err)

	failTestTask, err := tasq.NewTask("testTask", failTestArgs, "testQueue", 100, 5)
	require.NotNil(s.T(), failTestTask)
	require.Nil(s.T(), err)

	failNoRequeueTestTask, err := tasq.NewTask("testTask", failTestArgs, "testQueue", 100, 1)
	require.NotNil(s.T(), failNoRequeueTestTask)
	require.Nil(s.T(), err)

	// The handler succeeds for the "success" argument and fails otherwise
	err = s.tasqConsumer.Learn("testTask", func(task *tasq.Task) error {
		var args string

		err := task.UnmarshalArgs(&args)
		require.Nil(s.T(), err)

		if args == successTestArgs {
			return nil
		}

		return errTaskFail
	}, false)
	require.Nil(s.T(), err)

	// Increment receive counts, so failNoRequeueTestTask has reached its
	// receive limit while failTestTask has not
	for _, task := range []*tasq.Task{
		successTestTask,
		failTestTask,
		failNoRequeueTestTask,
	} {
		task.ReceiveCount++
	}

	// Getting tasks

	// First try - pinging fails
	s.mockRepository.On("PingTasks", ctx, []uuid.UUID{}, 15*time.Second).Once().Return([]*tasq.Task{}, errRepository)

	// Second try - pinging succeeds
	s.mockRepository.On("PingTasks", ctx, []uuid.UUID{}, 15*time.Second).Once().Return([]*tasq.Task{}, nil)
	s.mockRepository.On("PollTasks", ctx, []string{"testTask"}, []string{"testQueue"}, 15*time.Second, tasq.OrderingCreatedAtFirst, 10).Return([]*tasq.Task{
		successTestTask,
		failTestTask,
		failNoRequeueTestTask,
	}, nil)

	// Start up the consumer
	err = s.tasqConsumer.Start(ctx)
	assert.Nil(s.T(), err)

	// Read the successful job from the consumer channel
	successJob := <-s.tasqConsumer.Channel()
	assert.NotNil(s.T(), successJob)

	// First try - registering task start fails
	s.mockRepository.On("RegisterStart", ctx, successTestTask).Once().Return(nil, errRepository)
	assert.Panics(s.T(), func() {
		(*successJob)()
	})

	// Second try - registering task success fails
	s.mockRepository.On("RegisterStart", ctx, successTestTask).Once().Return(successTestTask, nil)
	s.mockRepository.On("RegisterFinish", ctx, successTestTask, tasq.StatusSuccessful).Once().Return(nil, errRepository)
	assert.Panics(s.T(), func() {
		(*successJob)()
	})

	// Third try - repository succeeds
	s.mockRepository.On("RegisterStart", ctx, successTestTask).Once().Return(successTestTask, nil)
	s.mockRepository.On("RegisterFinish", ctx, successTestTask, tasq.StatusSuccessful).Once().Return(successTestTask, nil)
	assert.NotPanics(s.T(), func() {
		(*successJob)()
	})

	// Read the failing job from the consumer channel
	failJob := <-s.tasqConsumer.Channel()
	assert.NotNil(s.T(), failJob)

	// First try - registering task error fails
	s.mockRepository.On("RegisterStart", ctx, failTestTask).Once().Return(failTestTask, nil)
	s.mockRepository.On("RegisterError", ctx, failTestTask, errTaskFail).Once().Return(nil, errRepository)
	assert.Panics(s.T(), func() {
		(*failJob)()
	})

	// Second try - requeuing task fails
	s.mockRepository.On("RegisterStart", ctx, failTestTask).Once().Return(failTestTask, nil)
	s.mockRepository.On("RegisterError", ctx, failTestTask, errTaskFail).Once().Return(failTestTask, nil)
	s.mockRepository.On("RequeueTask", ctx, failTestTask).Once().Return(nil, errRepository)
	assert.Panics(s.T(), func() {
		(*failJob)()
	})

	// Third try - repository succeeds
	s.mockRepository.On("RegisterStart", ctx, failTestTask).Once().Return(failTestTask, nil)
	s.mockRepository.On("RegisterError", ctx, failTestTask, errTaskFail).Once().Return(failTestTask, nil)
	s.mockRepository.On("RequeueTask", ctx, failTestTask).Once().Return(failTestTask, nil)
	assert.NotPanics(s.T(), func() {
		(*failJob)()
	})

	// Read the failing job that shouldn't be requeued from the consumer channel
	failNoRequeueJob := <-s.tasqConsumer.Channel()
	assert.NotNil(s.T(), failNoRequeueJob)

	// First try - registering task error fails
	s.mockRepository.On("RegisterStart", ctx, failNoRequeueTestTask).Once().Return(failNoRequeueTestTask, nil)
	s.mockRepository.On("RegisterError", ctx, failNoRequeueTestTask, errTaskFail).Once().Return(nil, errRepository)
	assert.Panics(s.T(), func() {
		(*failNoRequeueJob)()
	})

	// Second try - registering task failure fails
	s.mockRepository.On("RegisterStart", ctx, failNoRequeueTestTask).Once().Return(failNoRequeueTestTask, nil)
	s.mockRepository.On("RegisterError", ctx, failNoRequeueTestTask, errTaskFail).Once().Return(failNoRequeueTestTask, nil)
	s.mockRepository.On("RegisterFinish", ctx, failNoRequeueTestTask, tasq.StatusFailed).Once().Return(nil, errRepository)
	assert.Panics(s.T(), func() {
		(*failNoRequeueJob)()
	})

	// Third try - repository succeeds
	s.mockRepository.On("RegisterStart", ctx, failNoRequeueTestTask).Once().Return(failNoRequeueTestTask, nil)
	s.mockRepository.On("RegisterError", ctx, failNoRequeueTestTask, errTaskFail).Once().Return(failNoRequeueTestTask, nil)
	s.mockRepository.On("RegisterFinish", ctx, failNoRequeueTestTask, tasq.StatusFailed).Once().Return(failNoRequeueTestTask, nil)
	assert.NotPanics(s.T(), func() {
		(*failNoRequeueJob)()
	})

	// Stop the consumer
	err = s.tasqConsumer.Stop()
	assert.Nil(s.T(), err)

	// Wait until channel is closed
	<-s.tasqConsumer.Channel()

	s.mockClock.Add(10 * time.Second)

	// Wait for goroutine to actually return and output log message
	s.tasqConsumer.GetWaitGroup().Wait()

	assert.Equal(s.T(), "error pinging active tasks: could not ping tasks: repository error\nprocessing stopped\n", s.logBuffer.String())
}
312 |
// TestConsumptionWithAutoDeleteOnSuccess verifies that, with auto-delete on
// success enabled, a successful job deletes its task instead of registering a
// successful finish. It also checks that a failed delete panics inside the
// job.
func (s *ConsumerTestSuite) TestConsumptionWithAutoDeleteOnSuccess() {
	ctx := context.Background()

	s.tasqConsumer.
		WithQueues("testQueue").
		WithAutoDeleteOnSuccess(true)

	successTestTask, err := tasq.NewTask("testTask", true, "testQueue", 100, 5)
	require.NotNil(s.T(), successTestTask)
	require.Nil(s.T(), err)

	err = s.tasqConsumer.Learn("testTask", func(task *tasq.Task) error {
		return nil
	}, false)
	require.Nil(s.T(), err)

	successTestTask.ReceiveCount++

	// Getting tasks: one ping before Stop and one after it
	s.mockRepository.On("PingTasks", ctx, []uuid.UUID{}, 15*time.Second).Twice().Return([]*tasq.Task{}, nil)
	s.mockRepository.On("PollTasks", ctx, []string{"testTask"}, []string{"testQueue"}, 15*time.Second, tasq.OrderingCreatedAtFirst, 10).Return([]*tasq.Task{
		successTestTask,
	}, nil)

	// Start up the consumer
	err = s.tasqConsumer.Start(ctx)
	assert.Nil(s.T(), err)

	// Read the successful job from the consumer channel
	successJob := <-s.tasqConsumer.Channel()
	assert.NotNil(s.T(), successJob)

	// First try - deleting task fails
	s.mockRepository.On("RegisterStart", ctx, successTestTask).Once().Return(successTestTask, nil)
	s.mockRepository.On("DeleteTask", ctx, successTestTask, false).Once().Return(errRepository)
	assert.Panics(s.T(), func() {
		(*successJob)()
	})

	// Second try - deleting task succeeds
	s.mockRepository.On("RegisterStart", ctx, successTestTask).Once().Return(successTestTask, nil)
	s.mockRepository.On("DeleteTask", ctx, successTestTask, false).Once().Return(nil)
	assert.NotPanics(s.T(), func() {
		(*successJob)()
	})

	// Stop the consumer
	err = s.tasqConsumer.Stop()
	assert.Nil(s.T(), err)

	// Wait until channel is closed
	<-s.tasqConsumer.Channel()

	s.mockClock.Add(5 * time.Second)

	// Wait for goroutine to actually return and output log message
	s.tasqConsumer.GetWaitGroup().Wait()

	assert.Equal(s.T(), "processing stopped\n", s.logBuffer.String())
}
373 |
// TestConsumptionWithPollStrategyByPriority verifies that the priority poll
// strategy makes the consumer poll with OrderingPriorityFirst, and that a
// successful job registers a successful finish.
func (s *ConsumerTestSuite) TestConsumptionWithPollStrategyByPriority() {
	ctx := context.Background()

	s.tasqConsumer.
		WithQueues("testQueue").
		WithPollStrategy(tasq.PollStrategyByPriority)

	successTestTask, err := tasq.NewTask("testTask", true, "testQueue", 100, 5)
	require.NotNil(s.T(), successTestTask)
	require.Nil(s.T(), err)

	err = s.tasqConsumer.Learn("testTask", func(task *tasq.Task) error {
		return nil
	}, false)
	require.Nil(s.T(), err)

	successTestTask.ReceiveCount++

	// Getting tasks: one ping before Stop and one after it; polling must use
	// the priority-first ordering
	s.mockRepository.On("PingTasks", ctx, []uuid.UUID{}, 15*time.Second).Twice().Return([]*tasq.Task{}, nil)
	s.mockRepository.On("PollTasks", ctx, []string{"testTask"}, []string{"testQueue"}, 15*time.Second, tasq.OrderingPriorityFirst, 10).Return([]*tasq.Task{
		successTestTask,
	}, nil)

	// Start up the consumer
	err = s.tasqConsumer.Start(ctx)
	assert.Nil(s.T(), err)

	// Read the successful job from the consumer channel
	successJob := <-s.tasqConsumer.Channel()
	assert.NotNil(s.T(), successJob)

	// First try - repository succeeds
	s.mockRepository.On("RegisterStart", ctx, successTestTask).Once().Return(successTestTask, nil)
	s.mockRepository.On("RegisterFinish", ctx, successTestTask, tasq.StatusSuccessful).Once().Return(successTestTask, nil)
	assert.NotPanics(s.T(), func() {
		(*successJob)()
	})

	// Stop the consumer
	err = s.tasqConsumer.Stop()
	assert.Nil(s.T(), err)

	// Wait until channel is closed
	<-s.tasqConsumer.Channel()

	s.mockClock.Add(5 * time.Second)

	// Wait for goroutine to actually return and output log message
	s.tasqConsumer.GetWaitGroup().Wait()

	assert.Equal(s.T(), "processing stopped\n", s.logBuffer.String())
}
427 |
// TestConsumptionWithUnknownPollStrategy verifies that an unknown poll
// strategy is logged as a polling error. No PollTasks call is expected on the
// mock repository.
func (s *ConsumerTestSuite) TestConsumptionWithUnknownPollStrategy() {
	ctx := context.Background()

	s.tasqConsumer.
		WithQueues("testQueue").
		WithPollStrategy(tasq.PollStrategy("pollByMagic"))

	// Getting tasks: one ping before Stop and one after it
	s.mockRepository.On("PingTasks", ctx, []uuid.UUID{}, 15*time.Second).Twice().Return([]*tasq.Task{}, nil)

	// Start up the consumer
	err := s.tasqConsumer.Start(ctx)
	assert.Nil(s.T(), err)

	// Stop the consumer
	err = s.tasqConsumer.Stop()
	assert.Nil(s.T(), err)

	// Wait until channel is closed
	<-s.tasqConsumer.Channel()

	s.mockClock.Add(5 * time.Second)

	// Wait for goroutine to actually return and output log message
	s.tasqConsumer.GetWaitGroup().Wait()

	assert.Equal(s.T(), "error polling for tasks: unknown poll strategy: pollByMagic\nprocessing stopped\n", s.logBuffer.String())
}
456 |
// TestConsumptionOfUnknownTaskType verifies that a polled task whose type has
// no learned handler is registered as failed instead of being handed out, and
// that the activation failure is logged.
func (s *ConsumerTestSuite) TestConsumptionOfUnknownTaskType() {
	ctx := context.Background()

	s.tasqConsumer.
		WithQueues("testQueue")

	anotherTestTask, err := tasq.NewTask("anotherTestTask", true, "testQueue", 100, 5)
	require.NotNil(s.T(), anotherTestTask)
	require.Nil(s.T(), err)

	err = s.tasqConsumer.Learn("testTask", func(task *tasq.Task) error {
		return nil
	}, false)
	require.Nil(s.T(), err)

	// Getting tasks: the repository returns a task of a type the consumer
	// has not learned
	s.mockRepository.On("PingTasks", ctx, []uuid.UUID{}, 15*time.Second).Twice().Return([]*tasq.Task{}, nil)
	s.mockRepository.On("PollTasks", ctx, []string{"testTask"}, []string{"testQueue"}, 15*time.Second, tasq.OrderingCreatedAtFirst, 10).Return([]*tasq.Task{
		anotherTestTask,
	}, nil)

	// Start up the consumer
	err = s.tasqConsumer.Start(ctx)
	assert.Nil(s.T(), err)

	// The unknown task is expected to be registered as failed
	s.mockRepository.On("RegisterFinish", ctx, anotherTestTask, tasq.StatusFailed).Return(anotherTestTask, nil)

	// Stop the consumer
	err = s.tasqConsumer.Stop()
	assert.Nil(s.T(), err)

	// Wait until channel is closed
	<-s.tasqConsumer.Channel()

	s.mockClock.Add(5 * time.Second)

	// Wait for goroutine to actually return and output log message
	s.tasqConsumer.GetWaitGroup().Wait()

	assert.Equal(s.T(), "error activating tasks: a number of tasks could not be activated: 1\nprocessing stopped\n", s.logBuffer.String())
}
498 |
499 | func (s *ConsumerTestSuite) TestLoopingConsumption() {
500 | s.tasqConsumer.
501 | WithQueues("testQueue").
502 | WithPollInterval(5 * time.Second).
503 | WithPollLimit(1).
504 | WithMaxActiveTasks(2)
505 |
506 | var (
507 | ctx = context.Background()
508 |
509 | testTaskID1 = uuid.MustParse("1ada263f-61d5-44ac-b99d-2d5ad4f249de")
510 | testTaskID2 = uuid.MustParse("28032675-bc13-4dcd-8ec6-6aa430fc466a")
511 |
512 | testTasks = map[uuid.UUID]*tasq.Task{
513 | testTaskID1: {
514 | ID: testTaskID1,
515 | Type: "testTask",
516 | Args: []uint8{0x3, 0x2, 0x0, 0x1},
517 | Queue: "testQueue",
518 | Priority: 100,
519 | Status: tasq.StatusNew,
520 | ReceiveCount: 0,
521 | MaxReceives: 5,
522 | CreatedAt: s.mockClock.Now(),
523 | VisibleAt: s.mockClock.Now(),
524 | },
525 | testTaskID2: {
526 | ID: testTaskID2,
527 | Type: "testTask",
528 | Args: []uint8{0x3, 0x2, 0x0, 0x1},
529 | Queue: "testQueue",
530 | Priority: 100,
531 | Status: tasq.StatusNew,
532 | ReceiveCount: 0,
533 | MaxReceives: 5,
534 | CreatedAt: s.mockClock.Now(),
535 | VisibleAt: s.mockClock.Now(),
536 | },
537 | }
538 | )
539 |
540 | err := s.tasqConsumer.Learn("testTask", func(task *tasq.Task) error {
541 | return nil
542 | }, false)
543 | require.Nil(s.T(), err)
544 |
545 | // Respond to pings
546 | pingCall := s.mockRepository.On("PingTasks", ctx, mock.AnythingOfType("[]uuid.UUID"), 15*time.Second)
547 | pingCall.Run(func(args mock.Arguments) {
548 | inputTestIDs, ok := args[1].([]uuid.UUID)
549 | require.True(s.T(), ok)
550 |
551 | taskIDs, ok := args[1].([]uuid.UUID)
552 | require.True(s.T(), ok)
553 |
554 | returnTasks := make([]*tasq.Task, 0, len(inputTestIDs))
555 |
556 | for _, taskID := range taskIDs {
557 | returnTask, ok := testTasks[taskID]
558 | require.True(s.T(), ok)
559 |
560 | returnTasks = append(returnTasks, returnTask)
561 | }
562 |
563 | pingCall.ReturnArguments = mock.Arguments{returnTasks, nil}
564 | })
565 |
566 | // Respond to polls
567 | // First call
568 | s.mockRepository.On("PollTasks", ctx, []string{"testTask"}, []string{"testQueue"}, 15*time.Second, tasq.OrderingCreatedAtFirst, 1).Once().
569 | Return([]*tasq.Task{
570 | testTasks[testTaskID1],
571 | }, nil)
572 |
573 | // Second call
574 | s.mockRepository.On("PollTasks", ctx, []string{"testTask"}, []string{"testQueue"}, 15*time.Second, tasq.OrderingCreatedAtFirst, 1).Once().
575 | Return([]*tasq.Task{
576 | testTasks[testTaskID2],
577 | }, nil)
578 |
579 | // Start up the consumer
580 | err = s.tasqConsumer.Start(ctx)
581 | assert.Nil(s.T(), err)
582 |
583 | s.mockClock.Add(5 * time.Second)
584 |
585 | // Stop the consumer
586 | err = s.tasqConsumer.Stop()
587 | assert.Nil(s.T(), err)
588 |
589 | // Mock job handling
590 | s.mockRepository.On("RegisterStart", ctx, testTasks[testTaskID1]).Once().
591 | Return(testTasks[testTaskID1], nil)
592 | s.mockRepository.On("RegisterFinish", ctx, testTasks[testTaskID1], tasq.StatusSuccessful).Once().
593 | Return(testTasks[testTaskID1], nil)
594 |
595 | s.mockRepository.On("RegisterStart", ctx, testTasks[testTaskID2]).Once().
596 | Return(testTasks[testTaskID2], nil)
597 | s.mockRepository.On("RegisterFinish", ctx, testTasks[testTaskID2], tasq.StatusSuccessful).Once().
598 | Return(testTasks[testTaskID2], nil)
599 |
600 | // Drain channel of jobs
601 | for job := range s.tasqConsumer.Channel() {
602 | currentJob := job
603 |
604 | assert.NotPanics(s.T(), func() {
605 | (*currentJob)()
606 | })
607 | }
608 |
609 | // Let 5 seconds pass so that the goroutine has a chance to finish its last loop
610 | s.mockClock.Add(5 * time.Second)
611 |
612 | // Wait for goroutine to actually return and output log message
613 | s.tasqConsumer.GetWaitGroup().Wait()
614 |
615 | assert.Equal(s.T(), "processing stopped\n", s.logBuffer.String())
616 | }
617 |
--------------------------------------------------------------------------------
/export_test.go:
--------------------------------------------------------------------------------
1 | package tasq
2 |
3 | import (
4 | "sync"
5 |
6 | "github.com/benbjohnson/clock"
7 | )
8 |
9 | // SetClock needs to be exported so we can adjust the clock during tests
10 | // to avoid failures due to different results of time.Now().
11 | func (c *Consumer) SetClock(clock clock.Clock) *Consumer {
12 | c.clock = clock
13 |
14 | return c
15 | }
16 |
17 | // GetWaitGroup needs to be exported so we can get the waitgroup of the consumer
18 | // in order to allow the tests for the consumption to finish.
19 | func (c *Consumer) GetWaitGroup() *sync.WaitGroup {
20 | return &c.wg
21 | }
22 |
--------------------------------------------------------------------------------
/go.mod:
--------------------------------------------------------------------------------
1 | module github.com/greencoda/tasq
2 |
3 | go 1.19
4 |
5 | require (
6 | github.com/DATA-DOG/go-sqlmock v1.5.0
7 | github.com/benbjohnson/clock v1.3.0
8 | github.com/go-sql-driver/mysql v1.7.0
9 | github.com/google/uuid v1.3.0
10 | github.com/jmoiron/sqlx v1.3.5
11 | github.com/lib/pq v1.10.8
12 | github.com/stretchr/testify v1.8.1
13 | )
14 |
15 | require (
16 | github.com/davecgh/go-spew v1.1.1 // indirect
17 | github.com/pmezard/go-difflib v1.0.0 // indirect
18 | github.com/stretchr/objx v0.5.0 // indirect
19 | gopkg.in/yaml.v3 v3.0.1 // indirect
20 | )
21 |
--------------------------------------------------------------------------------
/go.sum:
--------------------------------------------------------------------------------
1 | github.com/DATA-DOG/go-sqlmock v1.5.0 h1:Shsta01QNfFxHCfpW6YH2STWB0MudeXXEWMr20OEh60=
2 | github.com/DATA-DOG/go-sqlmock v1.5.0/go.mod h1:f/Ixk793poVmq4qj/V1dPUg2JEAKC73Q5eFN3EC/SaM=
3 | github.com/benbjohnson/clock v1.3.0 h1:ip6w0uFQkncKQ979AypyG0ER7mqUSBdKLOgAle/AT8A=
4 | github.com/benbjohnson/clock v1.3.0/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA=
5 | github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
6 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
7 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
8 | github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg=
9 | github.com/go-sql-driver/mysql v1.7.0 h1:ueSltNNllEqE3qcWBTD0iQd3IpL/6U+mJxLkazJ7YPc=
10 | github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI=
11 | github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
12 | github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
13 | github.com/jmoiron/sqlx v1.3.5 h1:vFFPA71p1o5gAeqtEAwLU4dnX2napprKtHr7PYIcN3g=
14 | github.com/jmoiron/sqlx v1.3.5/go.mod h1:nRVWtLre0KfCLJvgxzCsLVMogSvQ1zNJtpYr2Ccp0mQ=
15 | github.com/lib/pq v1.2.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo=
16 | github.com/lib/pq v1.10.8 h1:3fdt97i/cwSU83+E0hZTC/Xpc9mTZxc6UWSCRcSbxiE=
17 | github.com/lib/pq v1.10.8/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
18 | github.com/mattn/go-sqlite3 v1.14.6 h1:dNPt6NO46WmLVt2DLNpwczCmdV5boIZ6g/tlDrlRUbg=
19 | github.com/mattn/go-sqlite3 v1.14.6/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU=
20 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
21 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
22 | github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
23 | github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
24 | github.com/stretchr/objx v0.5.0 h1:1zr/of2m5FGMsad5YfcqgdqdWrIhu+EBEJRhR1U7z/c=
25 | github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
26 | github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
27 | github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
28 | github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk=
29 | github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
30 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
31 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
32 | gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
33 | gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
34 | gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
35 |
--------------------------------------------------------------------------------
/inspector.go:
--------------------------------------------------------------------------------
1 | package tasq
2 |
3 | import (
4 | "context"
5 | "fmt"
6 | )
7 |
8 | // Inspector is a service instance created by a Client with reference to that client
9 | // with the purpose of enabling the observation of tasks.
10 | type Inspector struct {
11 | client *Client
12 | }
13 |
14 | // NewInspector creates a new inspector with a reference to the original tasq client.
15 | func (c *Client) NewInspector() *Inspector {
16 | return &Inspector{
17 | client: c,
18 | }
19 | }
20 |
21 | // Count returns a the total number of tasks based on the supplied filter arguments.
22 | func (o *Inspector) Count(ctx context.Context, taskStatuses []TaskStatus, taskTypes, queues []string) (int64, error) {
23 | count, err := o.client.repository.CountTasks(ctx, taskStatuses, taskTypes, queues)
24 | if err != nil {
25 | return 0, fmt.Errorf("error counting tasks: %w", err)
26 | }
27 |
28 | return count, nil
29 | }
30 |
31 | // Scan returns a list of tasks based on the supplied filter arguments.
32 | func (o *Inspector) Scan(ctx context.Context, taskStatuses []TaskStatus, taskTypes, queues []string, ordering Ordering, limit int) ([]*Task, error) {
33 | tasks, err := o.client.repository.ScanTasks(ctx, taskStatuses, taskTypes, queues, ordering, limit)
34 | if err != nil {
35 | return nil, fmt.Errorf("error scanning tasks: %w", err)
36 | }
37 |
38 | return tasks, nil
39 | }
40 |
41 | // Purge will remove all tasks based on the supplied filter arguments.
42 | func (o *Inspector) Purge(ctx context.Context, safeDelete bool, taskStatuses []TaskStatus, taskTypes, queues []string) (int64, error) {
43 | count, err := o.client.repository.PurgeTasks(ctx, taskStatuses, taskTypes, queues, safeDelete)
44 | if err != nil {
45 | return 0, fmt.Errorf("error purging tasks: %w", err)
46 | }
47 |
48 | return count, nil
49 | }
50 |
51 | // Delete will remove the supplied tasks.
52 | func (o *Inspector) Delete(ctx context.Context, safeDelete bool, tasks ...*Task) error {
53 | for _, task := range tasks {
54 | if err := o.client.repository.DeleteTask(ctx, task, safeDelete); err != nil {
55 | return fmt.Errorf("error removing task: %w", err)
56 | }
57 | }
58 |
59 | return nil
60 | }
61 |
--------------------------------------------------------------------------------
/inspector_test.go:
--------------------------------------------------------------------------------
1 | package tasq_test
2 |
3 | import (
4 | "context"
5 | "testing"
6 |
7 | "github.com/greencoda/tasq"
8 | "github.com/greencoda/tasq/mocks"
9 | "github.com/stretchr/testify/assert"
10 | "github.com/stretchr/testify/require"
11 | "github.com/stretchr/testify/suite"
12 | )
13 |
14 | type InspectorTestSuite struct {
15 | suite.Suite
16 | mockRepository *mocks.IRepository
17 | tasqClient *tasq.Client
18 | tasqInspector *tasq.Inspector
19 | }
20 |
21 | func TestInspectorTestSuite(t *testing.T) {
22 | t.Parallel()
23 |
24 | suite.Run(t, new(InspectorTestSuite))
25 | }
26 |
27 | func (s *InspectorTestSuite) SetupTest() {
28 | s.mockRepository = mocks.NewIRepository(s.T())
29 |
30 | s.tasqClient = tasq.NewClient(s.mockRepository)
31 | require.NotNil(s.T(), s.tasqClient)
32 |
33 | s.tasqInspector = s.tasqClient.NewInspector()
34 | }
35 |
36 | func (s *InspectorTestSuite) TestNewCleaner() {
37 | assert.NotNil(s.T(), s.tasqInspector)
38 | }
39 |
40 | func (s *InspectorTestSuite) TestCount() {
41 | ctx := context.Background()
42 |
43 | s.mockRepository.On("CountTasks", ctx, []tasq.TaskStatus{tasq.StatusNew}, []string{"testType"}, []string{"testQueue"}).Return(int64(1), nil).Once()
44 |
45 | taskCount, err := s.tasqInspector.Count(ctx, []tasq.TaskStatus{tasq.StatusNew}, []string{"testType"}, []string{"testQueue"})
46 | assert.Equal(s.T(), int64(1), taskCount)
47 | assert.True(s.T(), s.mockRepository.AssertCalled(s.T(), "CountTasks", ctx, []tasq.TaskStatus{tasq.StatusNew}, []string{"testType"}, []string{"testQueue"}))
48 | assert.Nil(s.T(), err)
49 |
50 | s.mockRepository.On("CountTasks", ctx, []tasq.TaskStatus{tasq.StatusNew}, []string{"testType"}, []string{"testQueue"}).Return(int64(0), errRepository).Once()
51 |
52 | taskCount, err = s.tasqInspector.Count(ctx, []tasq.TaskStatus{tasq.StatusNew}, []string{"testType"}, []string{"testQueue"})
53 | assert.Equal(s.T(), int64(0), taskCount)
54 | assert.True(s.T(), s.mockRepository.AssertCalled(s.T(), "CountTasks", ctx, []tasq.TaskStatus{tasq.StatusNew}, []string{"testType"}, []string{"testQueue"}))
55 | assert.NotNil(s.T(), err)
56 | }
57 |
58 | func (s *InspectorTestSuite) TestScan() {
59 | ctx := context.Background()
60 |
61 | testTask, err := tasq.NewTask("testTask", true, "testQueue", 100, 5)
62 | require.NotNil(s.T(), testTask)
63 | require.Nil(s.T(), err)
64 |
65 | s.mockRepository.On("ScanTasks", ctx, []tasq.TaskStatus{tasq.StatusNew}, []string{"testType"}, []string{"testQueue"}, tasq.OrderingCreatedAtFirst, 100).Return([]*tasq.Task{testTask}, nil).Once()
66 |
67 | tasks, err := s.tasqInspector.Scan(ctx, []tasq.TaskStatus{tasq.StatusNew}, []string{"testType"}, []string{"testQueue"}, tasq.OrderingCreatedAtFirst, 100)
68 | assert.Equal(s.T(), []*tasq.Task{testTask}, tasks)
69 | assert.True(s.T(), s.mockRepository.AssertCalled(s.T(), "ScanTasks", ctx, []tasq.TaskStatus{tasq.StatusNew}, []string{"testType"}, []string{"testQueue"}, tasq.OrderingCreatedAtFirst, 100))
70 | assert.Nil(s.T(), err)
71 |
72 | s.mockRepository.On("ScanTasks", ctx, []tasq.TaskStatus{tasq.StatusNew}, []string{"testType"}, []string{"testQueue"}, tasq.OrderingCreatedAtFirst, 100).Return([]*tasq.Task{}, errRepository).Once()
73 |
74 | tasks, err = s.tasqInspector.Scan(ctx, []tasq.TaskStatus{tasq.StatusNew}, []string{"testType"}, []string{"testQueue"}, tasq.OrderingCreatedAtFirst, 100)
75 | assert.Len(s.T(), tasks, 0)
76 | assert.True(s.T(), s.mockRepository.AssertCalled(s.T(), "ScanTasks", ctx, []tasq.TaskStatus{tasq.StatusNew}, []string{"testType"}, []string{"testQueue"}, tasq.OrderingCreatedAtFirst, 100))
77 | assert.NotNil(s.T(), err)
78 | }
79 |
80 | func (s *InspectorTestSuite) TestPurge() {
81 | ctx := context.Background()
82 |
83 | s.mockRepository.On("PurgeTasks", ctx, []tasq.TaskStatus{tasq.StatusNew}, []string{"testType"}, []string{"testQueue"}, true).Return(int64(10), nil).Once()
84 |
85 | count, err := s.tasqInspector.Purge(ctx, true, []tasq.TaskStatus{tasq.StatusNew}, []string{"testType"}, []string{"testQueue"})
86 | assert.True(s.T(), s.mockRepository.AssertCalled(s.T(), "PurgeTasks", ctx, []tasq.TaskStatus{tasq.StatusNew}, []string{"testType"}, []string{"testQueue"}, true))
87 | assert.Equal(s.T(), int64(10), count)
88 | assert.Nil(s.T(), err)
89 |
90 | s.mockRepository.On("PurgeTasks", ctx, []tasq.TaskStatus{tasq.StatusNew}, []string{"testType"}, []string{"testQueue"}, true).Return(int64(0), errRepository).Once()
91 |
92 | count, err = s.tasqInspector.Purge(ctx, true, []tasq.TaskStatus{tasq.StatusNew}, []string{"testType"}, []string{"testQueue"})
93 | assert.True(s.T(), s.mockRepository.AssertCalled(s.T(), "PurgeTasks", ctx, []tasq.TaskStatus{tasq.StatusNew}, []string{"testType"}, []string{"testQueue"}, true))
94 | assert.Equal(s.T(), int64(0), count)
95 | assert.NotNil(s.T(), err)
96 | }
97 |
98 | func (s *InspectorTestSuite) TestDelete() {
99 | ctx := context.Background()
100 |
101 | testTask, err := tasq.NewTask("testTask", true, "testQueue", 100, 5)
102 | require.NotNil(s.T(), testTask)
103 | require.Nil(s.T(), err)
104 |
105 | s.mockRepository.On("DeleteTask", ctx, testTask, true).Once().Return(errRepository)
106 |
107 | err = s.tasqInspector.Delete(ctx, true, testTask)
108 | assert.NotNil(s.T(), err)
109 | assert.True(s.T(), s.mockRepository.AssertCalled(s.T(), "DeleteTask", ctx, testTask, true))
110 |
111 | s.mockRepository.On("DeleteTask", ctx, testTask, true).Once().Return(nil)
112 |
113 | err = s.tasqInspector.Delete(ctx, true, testTask)
114 | assert.Nil(s.T(), err)
115 | assert.True(s.T(), s.mockRepository.AssertCalled(s.T(), "DeleteTask", ctx, testTask, true))
116 | }
117 |
--------------------------------------------------------------------------------
/mocks/IRepository.go:
--------------------------------------------------------------------------------
1 | // Code generated by mockery v2.25.1. DO NOT EDIT.
2 |
3 | package mocks
4 |
5 | import (
6 | context "context"
7 |
8 | tasq "github.com/greencoda/tasq"
9 | mock "github.com/stretchr/testify/mock"
10 |
11 | time "time"
12 |
13 | uuid "github.com/google/uuid"
14 | )
15 |
16 | // IRepository is an autogenerated mock type for the IRepository type. It embeds mock.Mock, so tests register expectations with On(...).Return(...) and verify them with AssertCalled / AssertExpectations; regenerate it with mockery instead of editing by hand.
17 | type IRepository struct {
18 | mock.Mock // provides On, Called, AssertCalled and AssertExpectations
19 | }
20 |
21 | // CleanTasks provides a mock function with given fields: ctx, minimumAge. It returns what was registered via On("CleanTasks", ...).Return(...); each return value may instead be registered as a func taking the same arguments, which is then called to produce it.
22 | func (_m *IRepository) CleanTasks(ctx context.Context, minimumAge time.Duration) (int64, error) {
23 | ret := _m.Called(ctx, minimumAge) // records the call and looks up the matching expectation
24 |
25 | var r0 int64
26 | var r1 error
27 | if rf, ok := ret.Get(0).(func(context.Context, time.Duration) (int64, error)); ok { // a single func returning both values takes precedence
28 | return rf(ctx, minimumAge)
29 | }
30 | if rf, ok := ret.Get(0).(func(context.Context, time.Duration) int64); ok {
31 | r0 = rf(ctx, minimumAge)
32 | } else {
33 | r0 = ret.Get(0).(int64) // unchecked assertion: the expectation must return an int64, e.g. int64(0)
34 | }
35 |
36 | if rf, ok := ret.Get(1).(func(context.Context, time.Duration) error); ok {
37 | r1 = rf(ctx, minimumAge)
38 | } else {
39 | r1 = ret.Error(1)
40 | }
41 |
42 | return r0, r1
43 | }
44 |
45 | // CountTasks provides a mock function with given fields: ctx, taskStatuses, taskTypes, queues. It returns what was registered via On("CountTasks", ...).Return(...); each return value may instead be registered as a func taking the same arguments, which is then called to produce it.
46 | func (_m *IRepository) CountTasks(ctx context.Context, taskStatuses []tasq.TaskStatus, taskTypes []string, queues []string) (int64, error) {
47 | ret := _m.Called(ctx, taskStatuses, taskTypes, queues) // records the call and looks up the matching expectation
48 |
49 | var r0 int64
50 | var r1 error
51 | if rf, ok := ret.Get(0).(func(context.Context, []tasq.TaskStatus, []string, []string) (int64, error)); ok { // a single func returning both values takes precedence
52 | return rf(ctx, taskStatuses, taskTypes, queues)
53 | }
54 | if rf, ok := ret.Get(0).(func(context.Context, []tasq.TaskStatus, []string, []string) int64); ok {
55 | r0 = rf(ctx, taskStatuses, taskTypes, queues)
56 | } else {
57 | r0 = ret.Get(0).(int64) // unchecked assertion: the expectation must return an int64, e.g. int64(1)
58 | }
59 |
60 | if rf, ok := ret.Get(1).(func(context.Context, []tasq.TaskStatus, []string, []string) error); ok {
61 | r1 = rf(ctx, taskStatuses, taskTypes, queues)
62 | } else {
63 | r1 = ret.Error(1)
64 | }
65 |
66 | return r0, r1
67 | }
68 |
69 | // DeleteTask provides a mock function with given fields: ctx, task, safeDelete. It returns the error registered via On("DeleteTask", ...).Return(...), or the result of a registered func taking the same arguments.
70 | func (_m *IRepository) DeleteTask(ctx context.Context, task *tasq.Task, safeDelete bool) error {
71 | ret := _m.Called(ctx, task, safeDelete) // records the call and looks up the matching expectation
72 |
73 | var r0 error
74 | if rf, ok := ret.Get(0).(func(context.Context, *tasq.Task, bool) error); ok {
75 | r0 = rf(ctx, task, safeDelete)
76 | } else {
77 | r0 = ret.Error(0)
78 | }
79 |
80 | return r0
81 | }
82 |
83 | // Migrate provides a mock function with given fields: ctx. It returns the error registered via On("Migrate", ...).Return(...), or the result of a registered func taking the same argument.
84 | func (_m *IRepository) Migrate(ctx context.Context) error {
85 | ret := _m.Called(ctx) // records the call and looks up the matching expectation
86 |
87 | var r0 error
88 | if rf, ok := ret.Get(0).(func(context.Context) error); ok {
89 | r0 = rf(ctx)
90 | } else {
91 | r0 = ret.Error(0)
92 | }
93 |
94 | return r0
95 | }
96 |
97 | // PingTasks provides a mock function with given fields: ctx, taskIDs, visibilityTimeout. It returns what was registered via On("PingTasks", ...).Return(...) (or set on the call's ReturnArguments, e.g. from a Run callback); each return value may instead be registered as a func taking the same arguments.
98 | func (_m *IRepository) PingTasks(ctx context.Context, taskIDs []uuid.UUID, visibilityTimeout time.Duration) ([]*tasq.Task, error) {
99 | ret := _m.Called(ctx, taskIDs, visibilityTimeout) // records the call and looks up the matching expectation
100 |
101 | var r0 []*tasq.Task
102 | var r1 error
103 | if rf, ok := ret.Get(0).(func(context.Context, []uuid.UUID, time.Duration) ([]*tasq.Task, error)); ok { // a single func returning both values takes precedence
104 | return rf(ctx, taskIDs, visibilityTimeout)
105 | }
106 | if rf, ok := ret.Get(0).(func(context.Context, []uuid.UUID, time.Duration) []*tasq.Task); ok {
107 | r0 = rf(ctx, taskIDs, visibilityTimeout)
108 | } else {
109 | if ret.Get(0) != nil { // lets tests register Return(nil, err) without a failing type assertion
110 | r0 = ret.Get(0).([]*tasq.Task)
111 | }
112 | }
113 |
114 | if rf, ok := ret.Get(1).(func(context.Context, []uuid.UUID, time.Duration) error); ok {
115 | r1 = rf(ctx, taskIDs, visibilityTimeout)
116 | } else {
117 | r1 = ret.Error(1)
118 | }
119 |
120 | return r0, r1
121 | }
122 |
123 | // PollTasks provides a mock function with given fields: ctx, types, queues, visibilityTimeout, ordering, limit. It returns what was registered via On("PollTasks", ...).Return(...); each return value may instead be registered as a func taking the same arguments, which is then called to produce it.
124 | func (_m *IRepository) PollTasks(ctx context.Context, types []string, queues []string, visibilityTimeout time.Duration, ordering tasq.Ordering, limit int) ([]*tasq.Task, error) {
125 | ret := _m.Called(ctx, types, queues, visibilityTimeout, ordering, limit) // records the call and looks up the matching expectation
126 |
127 | var r0 []*tasq.Task
128 | var r1 error
129 | if rf, ok := ret.Get(0).(func(context.Context, []string, []string, time.Duration, tasq.Ordering, int) ([]*tasq.Task, error)); ok { // a single func returning both values takes precedence
130 | return rf(ctx, types, queues, visibilityTimeout, ordering, limit)
131 | }
132 | if rf, ok := ret.Get(0).(func(context.Context, []string, []string, time.Duration, tasq.Ordering, int) []*tasq.Task); ok {
133 | r0 = rf(ctx, types, queues, visibilityTimeout, ordering, limit)
134 | } else {
135 | if ret.Get(0) != nil { // lets tests register Return(nil, err) without a failing type assertion
136 | r0 = ret.Get(0).([]*tasq.Task)
137 | }
138 | }
139 |
140 | if rf, ok := ret.Get(1).(func(context.Context, []string, []string, time.Duration, tasq.Ordering, int) error); ok {
141 | r1 = rf(ctx, types, queues, visibilityTimeout, ordering, limit)
142 | } else {
143 | r1 = ret.Error(1)
144 | }
145 |
146 | return r0, r1
147 | }
148 |
149 | // PurgeTasks provides a mock function with given fields: ctx, taskStatuses, taskTypes, queues, safeDelete. It returns what was registered via On("PurgeTasks", ...).Return(...); each return value may instead be registered as a func taking the same arguments, which is then called to produce it.
150 | func (_m *IRepository) PurgeTasks(ctx context.Context, taskStatuses []tasq.TaskStatus, taskTypes []string, queues []string, safeDelete bool) (int64, error) {
151 | ret := _m.Called(ctx, taskStatuses, taskTypes, queues, safeDelete) // records the call and looks up the matching expectation
152 |
153 | var r0 int64
154 | var r1 error
155 | if rf, ok := ret.Get(0).(func(context.Context, []tasq.TaskStatus, []string, []string, bool) (int64, error)); ok { // a single func returning both values takes precedence
156 | return rf(ctx, taskStatuses, taskTypes, queues, safeDelete)
157 | }
158 | if rf, ok := ret.Get(0).(func(context.Context, []tasq.TaskStatus, []string, []string, bool) int64); ok {
159 | r0 = rf(ctx, taskStatuses, taskTypes, queues, safeDelete)
160 | } else {
161 | r0 = ret.Get(0).(int64) // unchecked assertion: the expectation must return an int64, e.g. int64(10)
162 | }
163 |
164 | if rf, ok := ret.Get(1).(func(context.Context, []tasq.TaskStatus, []string, []string, bool) error); ok {
165 | r1 = rf(ctx, taskStatuses, taskTypes, queues, safeDelete)
166 | } else {
167 | r1 = ret.Error(1)
168 | }
169 |
170 | return r0, r1
171 | }
172 |
173 | // RegisterError provides a mock function with given fields: ctx, task, errTask. It returns what was registered via On("RegisterError", ...).Return(...); each return value may instead be registered as a func taking the same arguments, which is then called to produce it.
174 | func (_m *IRepository) RegisterError(ctx context.Context, task *tasq.Task, errTask error) (*tasq.Task, error) {
175 | ret := _m.Called(ctx, task, errTask) // records the call and looks up the matching expectation
176 |
177 | var r0 *tasq.Task
178 | var r1 error
179 | if rf, ok := ret.Get(0).(func(context.Context, *tasq.Task, error) (*tasq.Task, error)); ok { // a single func returning both values takes precedence
180 | return rf(ctx, task, errTask)
181 | }
182 | if rf, ok := ret.Get(0).(func(context.Context, *tasq.Task, error) *tasq.Task); ok {
183 | r0 = rf(ctx, task, errTask)
184 | } else {
185 | if ret.Get(0) != nil { // lets tests register Return(nil, err) without a failing type assertion
186 | r0 = ret.Get(0).(*tasq.Task)
187 | }
188 | }
189 |
190 | if rf, ok := ret.Get(1).(func(context.Context, *tasq.Task, error) error); ok {
191 | r1 = rf(ctx, task, errTask)
192 | } else {
193 | r1 = ret.Error(1)
194 | }
195 |
196 | return r0, r1
197 | }
198 |
199 | // RegisterFinish provides a mock function with given fields: ctx, task, finishStatus. It returns what was registered via On("RegisterFinish", ...).Return(...); each return value may instead be registered as a func taking the same arguments, which is then called to produce it.
200 | func (_m *IRepository) RegisterFinish(ctx context.Context, task *tasq.Task, finishStatus tasq.TaskStatus) (*tasq.Task, error) {
201 | ret := _m.Called(ctx, task, finishStatus) // records the call and looks up the matching expectation
202 |
203 | var r0 *tasq.Task
204 | var r1 error
205 | if rf, ok := ret.Get(0).(func(context.Context, *tasq.Task, tasq.TaskStatus) (*tasq.Task, error)); ok { // a single func returning both values takes precedence
206 | return rf(ctx, task, finishStatus)
207 | }
208 | if rf, ok := ret.Get(0).(func(context.Context, *tasq.Task, tasq.TaskStatus) *tasq.Task); ok {
209 | r0 = rf(ctx, task, finishStatus)
210 | } else {
211 | if ret.Get(0) != nil { // lets tests register Return(nil, err) without a failing type assertion
212 | r0 = ret.Get(0).(*tasq.Task)
213 | }
214 | }
215 |
216 | if rf, ok := ret.Get(1).(func(context.Context, *tasq.Task, tasq.TaskStatus) error); ok {
217 | r1 = rf(ctx, task, finishStatus)
218 | } else {
219 | r1 = ret.Error(1)
220 | }
221 |
222 | return r0, r1
223 | }
224 |
225 | // RegisterStart provides a mock function with given fields: ctx, task. It returns what was registered via On("RegisterStart", ...).Return(...); each return value may instead be registered as a func taking the same arguments, which is then called to produce it.
226 | func (_m *IRepository) RegisterStart(ctx context.Context, task *tasq.Task) (*tasq.Task, error) {
227 | ret := _m.Called(ctx, task) // records the call and looks up the matching expectation
228 |
229 | var r0 *tasq.Task
230 | var r1 error
231 | if rf, ok := ret.Get(0).(func(context.Context, *tasq.Task) (*tasq.Task, error)); ok { // a single func returning both values takes precedence
232 | return rf(ctx, task)
233 | }
234 | if rf, ok := ret.Get(0).(func(context.Context, *tasq.Task) *tasq.Task); ok {
235 | r0 = rf(ctx, task)
236 | } else {
237 | if ret.Get(0) != nil { // lets tests register Return(nil, err) without a failing type assertion
238 | r0 = ret.Get(0).(*tasq.Task)
239 | }
240 | }
241 |
242 | if rf, ok := ret.Get(1).(func(context.Context, *tasq.Task) error); ok {
243 | r1 = rf(ctx, task)
244 | } else {
245 | r1 = ret.Error(1)
246 | }
247 |
248 | return r0, r1
249 | }
250 |
251 | // RequeueTask provides a mock function with given fields: ctx, task. It returns what was registered via On("RequeueTask", ...).Return(...); each return value may instead be registered as a func taking the same arguments, which is then called to produce it.
252 | func (_m *IRepository) RequeueTask(ctx context.Context, task *tasq.Task) (*tasq.Task, error) {
253 | ret := _m.Called(ctx, task) // records the call and looks up the matching expectation
254 |
255 | var r0 *tasq.Task
256 | var r1 error
257 | if rf, ok := ret.Get(0).(func(context.Context, *tasq.Task) (*tasq.Task, error)); ok { // a single func returning both values takes precedence
258 | return rf(ctx, task)
259 | }
260 | if rf, ok := ret.Get(0).(func(context.Context, *tasq.Task) *tasq.Task); ok {
261 | r0 = rf(ctx, task)
262 | } else {
263 | if ret.Get(0) != nil { // lets tests register Return(nil, err) without a failing type assertion
264 | r0 = ret.Get(0).(*tasq.Task)
265 | }
266 | }
267 |
268 | if rf, ok := ret.Get(1).(func(context.Context, *tasq.Task) error); ok {
269 | r1 = rf(ctx, task)
270 | } else {
271 | r1 = ret.Error(1)
272 | }
273 |
274 | return r0, r1
275 | }
276 |
277 | // ScanTasks provides a mock function with given fields: ctx, taskStatuses, taskTypes, queues, ordering, limit. It returns what was registered via On("ScanTasks", ...).Return(...); each return value may instead be registered as a func taking the same arguments, which is then called to produce it.
278 | func (_m *IRepository) ScanTasks(ctx context.Context, taskStatuses []tasq.TaskStatus, taskTypes []string, queues []string, ordering tasq.Ordering, limit int) ([]*tasq.Task, error) {
279 | ret := _m.Called(ctx, taskStatuses, taskTypes, queues, ordering, limit) // records the call and looks up the matching expectation
280 |
281 | var r0 []*tasq.Task
282 | var r1 error
283 | if rf, ok := ret.Get(0).(func(context.Context, []tasq.TaskStatus, []string, []string, tasq.Ordering, int) ([]*tasq.Task, error)); ok { // a single func returning both values takes precedence
284 | return rf(ctx, taskStatuses, taskTypes, queues, ordering, limit)
285 | }
286 | if rf, ok := ret.Get(0).(func(context.Context, []tasq.TaskStatus, []string, []string, tasq.Ordering, int) []*tasq.Task); ok {
287 | r0 = rf(ctx, taskStatuses, taskTypes, queues, ordering, limit)
288 | } else {
289 | if ret.Get(0) != nil { // lets tests register Return(nil, err) without a failing type assertion
290 | r0 = ret.Get(0).([]*tasq.Task)
291 | }
292 | }
293 |
294 | if rf, ok := ret.Get(1).(func(context.Context, []tasq.TaskStatus, []string, []string, tasq.Ordering, int) error); ok {
295 | r1 = rf(ctx, taskStatuses, taskTypes, queues, ordering, limit)
296 | } else {
297 | r1 = ret.Error(1)
298 | }
299 |
300 | return r0, r1
301 | }
302 |
303 | // SubmitTask provides a mock function with given fields: ctx, task. It returns what was registered via On("SubmitTask", ...).Return(...); each return value may instead be registered as a func taking the same arguments, which is then called to produce it.
304 | func (_m *IRepository) SubmitTask(ctx context.Context, task *tasq.Task) (*tasq.Task, error) {
305 | ret := _m.Called(ctx, task) // records the call and looks up the matching expectation
306 |
307 | var r0 *tasq.Task
308 | var r1 error
309 | if rf, ok := ret.Get(0).(func(context.Context, *tasq.Task) (*tasq.Task, error)); ok { // a single func returning both values takes precedence
310 | return rf(ctx, task)
311 | }
312 | if rf, ok := ret.Get(0).(func(context.Context, *tasq.Task) *tasq.Task); ok {
313 | r0 = rf(ctx, task)
314 | } else {
315 | if ret.Get(0) != nil { // lets tests register Return(nil, err) without a failing type assertion
316 | r0 = ret.Get(0).(*tasq.Task)
317 | }
318 | }
319 |
320 | if rf, ok := ret.Get(1).(func(context.Context, *tasq.Task) error); ok {
321 | r1 = rf(ctx, task)
322 | } else {
323 | r1 = ret.Error(1)
324 | }
325 |
326 | return r0, r1
327 | }
328 |
329 | type mockConstructorTestingTNewIRepository interface { // the subset of *testing.T that NewIRepository needs
330 | mock.TestingT // failure reporting used by the mock
331 | Cleanup(func()) // used to run AssertExpectations when the test ends
332 | }
333 |
334 | // NewIRepository creates a new instance of IRepository. It also registers a testing interface on the mock and a cleanup function to assert the mock's expectations, so any expectation registered with On that was not met fails the test when it finishes.
335 | func NewIRepository(t mockConstructorTestingTNewIRepository) *IRepository {
336 | mock := &IRepository{} // note: shadows the imported mock package for the rest of this function
337 | mock.Mock.Test(t) // failures such as unexpected calls are reported through t instead of by panicking
338 |
339 | t.Cleanup(func() { mock.AssertExpectations(t) })
340 |
341 | return mock
342 | }
343 |
--------------------------------------------------------------------------------
/producer.go:
--------------------------------------------------------------------------------
1 | package tasq
2 |
3 | import (
4 | "context"
5 | "fmt"
6 | )
7 |
8 | // Producer is a service instance created by a Client with reference to that client
9 | // with the purpose of enabling the submission of new tasks.
10 | type Producer struct {
11 | client *Client
12 | }
13 |
14 | // NewProducer creates a new consumer with a reference to the original tasq client.
15 | func (c *Client) NewProducer() *Producer {
16 | return &Producer{
17 | client: c,
18 | }
19 | }
20 |
21 | // Submit constructs and submits a new task to the queue based on the supplied arguments.
22 | func (p *Producer) Submit(ctx context.Context, taskType string, taskArgs any, queue string, priority int16, maxReceives int32) (*Task, error) {
23 | newTask, err := NewTask(taskType, taskArgs, queue, priority, maxReceives)
24 | if err != nil {
25 | return nil, fmt.Errorf("error creating task: %w", err)
26 | }
27 |
28 | return p.SubmitTask(ctx, newTask)
29 | }
30 |
31 | // SubmitTask submits an existing task struct to the queue based on the supplied arguments.
32 | func (p *Producer) SubmitTask(ctx context.Context, task *Task) (*Task, error) {
33 | submittedTask, err := p.client.repository.SubmitTask(ctx, task)
34 | if err != nil {
35 | return nil, fmt.Errorf("error submitting task: %w", err)
36 | }
37 |
38 | return submittedTask, nil
39 | }
40 |
--------------------------------------------------------------------------------
/producer_test.go:
--------------------------------------------------------------------------------
1 | package tasq_test
2 |
3 | import (
4 | "context"
5 | "testing"
6 |
7 | "github.com/greencoda/tasq"
8 | "github.com/greencoda/tasq/mocks"
9 | "github.com/stretchr/testify/assert"
10 | "github.com/stretchr/testify/mock"
11 | "github.com/stretchr/testify/require"
12 | "github.com/stretchr/testify/suite"
13 | )
14 |
15 | type ProducterTestSuite struct {
16 | suite.Suite
17 | mockRepository *mocks.IRepository
18 | tasqClient *tasq.Client
19 | tasqProducer *tasq.Producer
20 | testTask *tasq.Task
21 | }
22 |
23 | func TestProducterTestSuite(t *testing.T) {
24 | t.Parallel()
25 |
26 | suite.Run(t, new(ProducterTestSuite))
27 | }
28 |
29 | func (s *ProducterTestSuite) SetupTest() {
30 | s.mockRepository = mocks.NewIRepository(s.T())
31 |
32 | s.tasqClient = tasq.NewClient(s.mockRepository)
33 | require.NotNil(s.T(), s.tasqClient)
34 |
35 | s.tasqProducer = s.tasqClient.NewProducer()
36 | require.NotNil(s.T(), s.tasqProducer)
37 |
38 | testArgs := "testData"
39 | testTask, err := tasq.NewTask("testTask", testArgs, "testQueue", 100, 5)
40 | require.Nil(s.T(), err)
41 |
42 | s.testTask = testTask
43 | }
44 |
45 | func (s *ProducterTestSuite) TestNewProducer() {
46 | assert.NotNil(s.T(), s.tasqProducer)
47 | }
48 |
49 | func (s *ProducterTestSuite) TestSubmitSuccessful() {
50 | ctx := context.Background()
51 |
52 | s.mockRepository.On("SubmitTask", ctx, mock.AnythingOfType("*tasq.Task")).Return(s.testTask, nil)
53 |
54 | task, err := s.tasqProducer.Submit(ctx, s.testTask.Type, s.testTask.Args, s.testTask.Queue, s.testTask.Priority, s.testTask.MaxReceives)
55 |
56 | assert.NotNil(s.T(), task)
57 | assert.True(s.T(), s.mockRepository.AssertCalled(s.T(), "SubmitTask", ctx, mock.AnythingOfType("*tasq.Task")))
58 | assert.Nil(s.T(), err)
59 | }
60 |
61 | func (s *ProducterTestSuite) TestSubmitUnsuccessful() {
62 | var (
63 | ctx = context.Background()
64 | testArgs = "testData"
65 | testTask, _ = tasq.NewTask("testTask", testArgs, "testQueue", 100, 5)
66 | )
67 |
68 | s.mockRepository.On("SubmitTask", ctx, mock.AnythingOfType("*tasq.Task")).Return(nil, errRepository)
69 |
70 | task, err := s.tasqProducer.Submit(ctx, testTask.Type, testArgs, testTask.Queue, testTask.Priority, testTask.MaxReceives)
71 |
72 | assert.Nil(s.T(), task)
73 | assert.True(s.T(), s.mockRepository.AssertCalled(s.T(), "SubmitTask", ctx, mock.AnythingOfType("*tasq.Task")))
74 | assert.NotNil(s.T(), err)
75 | }
76 |
77 | func (s *ProducterTestSuite) TestSubmitInvalidpriority() {
78 | ctx := context.Background()
79 |
80 | task, err := s.tasqProducer.Submit(ctx, "testData", nil, "testQueue", 100, 5)
81 |
82 | assert.Nil(s.T(), task)
83 | assert.NotNil(s.T(), err)
84 | }
85 |
86 | func (s *ProducterTestSuite) TestSubmitTask() {
87 | ctx := context.Background()
88 |
89 | s.mockRepository.On("SubmitTask", ctx, mock.AnythingOfType("*tasq.Task")).Return(s.testTask, nil)
90 |
91 | task, err := tasq.NewTask(s.testTask.Type, s.testTask.Args, s.testTask.Queue, s.testTask.Priority, s.testTask.MaxReceives)
92 | require.NotNil(s.T(), task)
93 | require.Nil(s.T(), err)
94 |
95 | submittedTask, err := s.tasqProducer.SubmitTask(ctx, task)
96 |
97 | assert.NotNil(s.T(), submittedTask)
98 | assert.True(s.T(), s.mockRepository.AssertCalled(s.T(), "SubmitTask", ctx, mock.AnythingOfType("*tasq.Task")))
99 | assert.Nil(s.T(), err)
100 | }
101 |
--------------------------------------------------------------------------------
/repository.go:
--------------------------------------------------------------------------------
1 | package tasq
2 |
3 | import (
4 | "context"
5 | "time"
6 |
7 | "github.com/google/uuid"
8 | )
9 |
// Ordering is an enum type describing the ordering strategy utilized during
// the polling process.
type Ordering int

// The collection of orderings.
const (
	// OrderingCreatedAtFirst is the zero value and therefore the default
	// Ordering. It presumably orders by creation time first; the exact ORDER BY
	// is defined by each repository implementation.
	OrderingCreatedAtFirst Ordering = iota
	// OrderingPriorityFirst presumably orders by priority first; the exact
	// ORDER BY is defined by each repository implementation.
	OrderingPriorityFirst
)
19 |
20 | // IRepository describes the mandatory methods a repository must implement
21 | // in order for tasq to be able to use it.
22 | type IRepository interface {
23 | Migrate(ctx context.Context) error
24 |
25 | PingTasks(ctx context.Context, taskIDs []uuid.UUID, visibilityTimeout time.Duration) ([]*Task, error)
26 | PollTasks(ctx context.Context, types, queues []string, visibilityTimeout time.Duration, ordering Ordering, limit int) ([]*Task, error)
27 | CleanTasks(ctx context.Context, minimumAge time.Duration) (int64, error)
28 |
29 | RegisterStart(ctx context.Context, task *Task) (*Task, error)
30 | RegisterError(ctx context.Context, task *Task, errTask error) (*Task, error)
31 | RegisterFinish(ctx context.Context, task *Task, finishStatus TaskStatus) (*Task, error)
32 |
33 | SubmitTask(ctx context.Context, task *Task) (*Task, error)
34 | DeleteTask(ctx context.Context, task *Task, safeDelete bool) error
35 | RequeueTask(ctx context.Context, task *Task) (*Task, error)
36 |
37 | ScanTasks(ctx context.Context, taskStatuses []TaskStatus, taskTypes, queues []string, ordering Ordering, limit int) ([]*Task, error)
38 | CountTasks(ctx context.Context, taskStatuses []TaskStatus, taskTypes, queues []string) (int64, error)
39 | PurgeTasks(ctx context.Context, taskStatuses []TaskStatus, taskTypes, queues []string, safeDelete bool) (int64, error)
40 | }
41 |
--------------------------------------------------------------------------------
/repository/mysql/export_test.go:
--------------------------------------------------------------------------------
1 | package mysql
2 |
3 | import (
4 | "database/sql"
5 | "database/sql/driver"
6 |
7 | "github.com/greencoda/tasq"
8 | )
9 |
10 | func GetTestTaskValues(task *tasq.Task) []driver.Value {
11 | testMySQLTask := newFromTask(task)
12 |
13 | return []driver.Value{
14 | testMySQLTask.ID,
15 | testMySQLTask.Type,
16 | testMySQLTask.Args,
17 | testMySQLTask.Queue,
18 | testMySQLTask.Priority,
19 | testMySQLTask.Status,
20 | testMySQLTask.ReceiveCount,
21 | testMySQLTask.MaxReceives,
22 | testMySQLTask.LastError,
23 | testMySQLTask.CreatedAt,
24 | testMySQLTask.StartedAt,
25 | testMySQLTask.FinishedAt,
26 | testMySQLTask.VisibleAt,
27 | }
28 | }
29 |
30 | func InterpolateSQL(sql string, params map[string]any) string {
31 | return interpolateSQL(sql, params)
32 | }
33 |
34 | func (d *Repository) GetQueryWithTableName(sqlTemplate string, args ...any) (string, []any) {
35 | return d.getQueryWithTableName(sqlTemplate, args...)
36 | }
37 |
38 | func StringToSQLNullString(input *string) sql.NullString {
39 | return stringToSQLNullString(input)
40 | }
41 |
42 | func ParseNullableString(input sql.NullString) *string {
43 | return parseNullableString(input)
44 | }
45 |
--------------------------------------------------------------------------------
/repository/mysql/mysql.go:
--------------------------------------------------------------------------------
1 | // Package mysql provides the implementation of a tasq repository in MySQL
2 | package mysql
3 |
4 | import (
5 | "bytes"
6 | "context"
7 | "database/sql"
8 | "errors"
9 | "fmt"
10 | "strings"
11 | "text/template"
12 | "time"
13 |
14 | _ "github.com/go-sql-driver/mysql" // import mysql driver
15 | "github.com/google/uuid"
16 | "github.com/greencoda/tasq"
17 | "github.com/jmoiron/sqlx"
18 | "github.com/lib/pq"
19 | )
20 |
// driverName is the database/sql driver name used to open MySQL connections
// (the driver is registered by the blank go-sql-driver/mysql import).
const driverName = "mysql"

// Sentinel errors used as prefixes or wrapped causes of repository failures.
var (
	errUnexpectedDataSourceType = errors.New("unexpected dataSource type")

	errFailedToBeginTx  = errors.New("failed to begin transaction")
	errFailedToCommitTx = errors.New("failed to commit transaction")

	errFailedToExecuteSelect      = errors.New("failed to execute select query")
	errFailedToExecuteUpdate      = errors.New("failed to execute update query")
	errFailedToExecuteDelete      = errors.New("failed to execute delete query")
	errFailedToExecuteInsert      = errors.New("failed to execute insert query")
	errFailedToExecuteCreateTable = errors.New("failed to execute create table query")

	errFailedGetRowsAffected = errors.New("failed to get rows affected by query")
)
34 |
35 | // Repository implements the menthods necessary for tasq to work in MySQL.
36 | type Repository struct {
37 | db *sqlx.DB
38 | tableName string
39 | }
40 |
41 | // NewRepository creates a new MySQL Repository instance.
42 | func NewRepository(dataSource any, prefix string) (*Repository, error) {
43 | switch d := dataSource.(type) {
44 | case string:
45 | return newRepositoryFromDSN(d, prefix)
46 | case *sql.DB:
47 | return newRepositoryFromDB(d, prefix)
48 | }
49 |
50 | return nil, fmt.Errorf("%w: %T", errUnexpectedDataSourceType, dataSource)
51 | }
52 |
53 | func newRepositoryFromDSN(dsn string, prefix string) (*Repository, error) {
54 | dbx, err := sqlx.Open(driverName, dsn)
55 | if err != nil {
56 | return nil, fmt.Errorf("failed to open DB from dsn: %w", err)
57 | }
58 |
59 | return &Repository{
60 | db: dbx,
61 | tableName: tableName(prefix),
62 | }, nil
63 | }
64 |
65 | func newRepositoryFromDB(db *sql.DB, prefix string) (*Repository, error) {
66 | dbx := sqlx.NewDb(db, driverName)
67 |
68 | return &Repository{
69 | db: dbx,
70 | tableName: tableName(prefix),
71 | }, nil
72 | }
73 |
74 | // Migrate prepares the database by adding the tasks table.
75 | func (d *Repository) Migrate(ctx context.Context) error {
76 | if err := d.migrateTable(ctx); err != nil {
77 | return err
78 | }
79 |
80 | return nil
81 | }
82 |
83 | // PingTasks pings a list of tasks by their ID
84 | // and extends their invisibility timestamp with the supplied timeout parameter.
85 | func (d *Repository) PingTasks(ctx context.Context, taskIDs []uuid.UUID, visibilityTimeout time.Duration) ([]*tasq.Task, error) {
86 | if len(taskIDs) == 0 {
87 | return []*tasq.Task{}, nil
88 | }
89 |
90 | const (
91 | updatePingedTasksSQLTemplate = `UPDATE
92 | {{.tableName}}
93 | SET
94 | visible_at = :visibleAt
95 | WHERE
96 | id IN (:pingedTaskIDs);`
97 | selectPingedTasksSQLTemplate = `SELECT
98 | *
99 | FROM
100 | {{.tableName}}
101 | WHERE
102 | id IN (:pingedTaskIDs);`
103 | )
104 |
105 | tx, err := d.db.Beginx()
106 | if err != nil {
107 | return []*tasq.Task{}, fmt.Errorf("%s: %w", errFailedToBeginTx, err)
108 | }
109 |
110 | defer rollback(tx)
111 |
112 | var (
113 | pingTime = time.Now()
114 | updatePingedTasksQuery, updatePingedTasksArgs = d.getQueryWithTableName(updatePingedTasksSQLTemplate, map[string]any{
115 | "visibleAt": timeToString(pingTime.Add(visibilityTimeout)),
116 | "pingedTaskIDs": taskIDs,
117 | })
118 | )
119 |
120 | _, err = tx.ExecContext(ctx, updatePingedTasksQuery, updatePingedTasksArgs...)
121 | if err != nil {
122 | return []*tasq.Task{}, fmt.Errorf("%s: %w", errFailedToExecuteUpdate, err)
123 | }
124 |
125 | var (
126 | pingedMySQLTasks []*mySQLTask
127 | selectPingedTasksQuery, selectPingedTasksArgs = d.getQueryWithTableName(selectPingedTasksSQLTemplate, map[string]any{
128 | "pingedTaskIDs": taskIDs,
129 | })
130 | )
131 |
132 | err = tx.SelectContext(ctx, &pingedMySQLTasks, selectPingedTasksQuery, selectPingedTasksArgs...)
133 | if err != nil {
134 | return []*tasq.Task{}, fmt.Errorf("%s: %w", errFailedToExecuteSelect, err)
135 | }
136 |
137 | err = tx.Commit()
138 | if err != nil {
139 | return []*tasq.Task{}, fmt.Errorf("%s: %w", errFailedToCommitTx, err)
140 | }
141 |
142 | return mySQLTasksToTasks(pingedMySQLTasks), nil
143 | }
144 |
145 | // PollTasks polls for available tasks matching supplied the parameters
146 | // and sets their invisibility the supplied timeout parameter to the future.
147 | func (d *Repository) PollTasks(ctx context.Context, types, queues []string, visibilityTimeout time.Duration, ordering tasq.Ordering, pollLimit int) ([]*tasq.Task, error) {
148 | if pollLimit == 0 {
149 | return []*tasq.Task{}, nil
150 | }
151 |
152 | const (
153 | selectPolledTasksSQLTemplate = `SELECT
154 | id
155 | FROM
156 | {{.tableName}}
157 | WHERE
158 | type IN (:pollTypes) AND
159 | queue IN (:pollQueues) AND
160 | status IN (:pollStatuses) AND
161 | visible_at <= :pollTime
162 | ORDER BY
163 | :pollOrdering
164 | LIMIT :pollLimit
165 | FOR UPDATE SKIP LOCKED;`
166 | updatePolledTasksSQLTemplate = `UPDATE
167 | {{.tableName}}
168 | SET
169 | status = :status,
170 | receive_count = receive_count + 1,
171 | visible_at = :visibleAt
172 | WHERE
173 | id IN (:polledTaskIDs);`
174 | selectUpdatedPolledTasksSQLTemplate = `SELECT
175 | *
176 | FROM
177 | {{.tableName}}
178 | WHERE
179 | id IN (:polledTaskIDs);`
180 | )
181 |
182 | tx, err := d.db.Beginx()
183 | if err != nil {
184 | return []*tasq.Task{}, fmt.Errorf("%s: %w", errFailedToBeginTx, err)
185 | }
186 |
187 | defer rollback(tx)
188 |
189 | var (
190 | polledTaskIDs []TaskID
191 | pollTime = time.Now()
192 | selectPolledTasksQuery, selectPolledTasksArgs = d.getQueryWithTableName(selectPolledTasksSQLTemplate, map[string]any{
193 | "pollTypes": types,
194 | "pollQueues": queues,
195 | "pollStatuses": tasq.GetTaskStatuses(tasq.OpenTasks),
196 | "pollTime": timeToString(pollTime),
197 | "pollOrdering": getOrderingDirectives(ordering),
198 | "pollLimit": pollLimit,
199 | })
200 | )
201 |
202 | err = tx.SelectContext(ctx, &polledTaskIDs, selectPolledTasksQuery, selectPolledTasksArgs...)
203 | if err != nil {
204 | return []*tasq.Task{}, fmt.Errorf("%s: %w", errFailedToExecuteSelect, err)
205 | }
206 |
207 | if len(polledTaskIDs) == 0 {
208 | return []*tasq.Task{}, nil
209 | }
210 |
211 | updatePolledTasksQuery, updatePolledTasksArgs := d.getQueryWithTableName(updatePolledTasksSQLTemplate, map[string]any{
212 | "status": tasq.StatusEnqueued,
213 | "visibleAt": timeToString(pollTime.Add(visibilityTimeout)),
214 | "polledTaskIDs": polledTaskIDs,
215 | })
216 |
217 | _, err = tx.ExecContext(ctx, updatePolledTasksQuery, updatePolledTasksArgs...)
218 | if err != nil {
219 | return []*tasq.Task{}, fmt.Errorf("%s: %w", errFailedToExecuteUpdate, err)
220 | }
221 |
222 | var (
223 | polledMySQLTasks []*mySQLTask
224 | selectUpdatedTasksQuery, selectUpdatedTasksArgs = d.getQueryWithTableName(selectUpdatedPolledTasksSQLTemplate, map[string]any{
225 | "polledTaskIDs": polledTaskIDs,
226 | })
227 | )
228 |
229 | err = tx.SelectContext(ctx, &polledMySQLTasks, selectUpdatedTasksQuery, selectUpdatedTasksArgs...)
230 | if err != nil {
231 | return []*tasq.Task{}, fmt.Errorf("%s: %w", errFailedToExecuteSelect, err)
232 | }
233 |
234 | err = tx.Commit()
235 | if err != nil {
236 | return []*tasq.Task{}, fmt.Errorf("%s: %w", errFailedToCommitTx, err)
237 | }
238 |
239 | return mySQLTasksToTasks(polledMySQLTasks), nil
240 | }
241 |
242 | // CleanTasks removes finished tasks from the queue
243 | // if their creation date is past the supplied duration.
244 | func (d *Repository) CleanTasks(ctx context.Context, cleanAge time.Duration) (int64, error) {
245 | const cleanTasksSQLTemplate = `DELETE FROM
246 | {{.tableName}}
247 | WHERE
248 | status IN (:statuses) AND
249 | created_at <= :cleanAt;`
250 |
251 | var (
252 | cleanTime = time.Now()
253 | cleanTasksQuery, cleanTasksArgs = d.getQueryWithTableName(cleanTasksSQLTemplate, map[string]any{
254 | "statuses": tasq.GetTaskStatuses(tasq.FinishedTasks),
255 | "cleanAt": timeToString(cleanTime.Add(-cleanAge)),
256 | })
257 | )
258 |
259 | result, err := d.db.ExecContext(ctx, cleanTasksQuery, cleanTasksArgs...)
260 | if err != nil {
261 | return 0, fmt.Errorf("%s: %w", errFailedToExecuteDelete, err)
262 | }
263 |
264 | rowsAffected, err := result.RowsAffected()
265 | if err != nil {
266 | return 0, fmt.Errorf("%s: %w", errFailedGetRowsAffected, err)
267 | }
268 |
269 | return rowsAffected, nil
270 | }
271 |
272 | // RegisterStart marks a task as started with the 'in progress' status
273 | // and records the time of start.
274 | func (d *Repository) RegisterStart(ctx context.Context, task *tasq.Task) (*tasq.Task, error) {
275 | const (
276 | updateTaskSQLTemplate = `UPDATE
277 | {{.tableName}}
278 | SET
279 | status = :status,
280 | started_at = :startTime
281 | WHERE
282 | id = :taskID;`
283 | selectUpdatedTaskSQLTemplate = `SELECT *
284 | FROM
285 | {{.tableName}}
286 | WHERE
287 | id = :taskID;`
288 | )
289 |
290 | tx, err := d.db.Beginx()
291 | if err != nil {
292 | return nil, fmt.Errorf("%s: %w", errFailedToBeginTx, err)
293 | }
294 |
295 | defer rollback(tx)
296 |
297 | var (
298 | mySQLTask = newFromTask(task)
299 | startTime = time.Now()
300 | updateTaskQuery, updateTaskArgs = d.getQueryWithTableName(updateTaskSQLTemplate, map[string]any{
301 | "status": tasq.StatusInProgress,
302 | "startTime": timeToString(startTime),
303 | "taskID": mySQLTask.ID,
304 | })
305 | )
306 |
307 | _, err = tx.ExecContext(ctx, updateTaskQuery, updateTaskArgs...)
308 | if err != nil {
309 | return nil, fmt.Errorf("%s: %w", errFailedToExecuteUpdate, err)
310 | }
311 |
312 | selectUpdatedTaskQuery, selectUpdatedTaskArgs := d.getQueryWithTableName(selectUpdatedTaskSQLTemplate, map[string]any{
313 | "taskID": mySQLTask.ID,
314 | })
315 |
316 | err = tx.QueryRowxContext(ctx, selectUpdatedTaskQuery, selectUpdatedTaskArgs...).
317 | StructScan(mySQLTask)
318 | if err != nil {
319 | return nil, fmt.Errorf("%s: %w", errFailedToExecuteSelect, err)
320 | }
321 |
322 | err = tx.Commit()
323 | if err != nil {
324 | return nil, fmt.Errorf("%s: %w", errFailedToCommitTx, err)
325 | }
326 |
327 | return mySQLTask.toTask(), nil
328 | }
329 |
330 | // RegisterError records an error message on the task as last error.
331 | func (d *Repository) RegisterError(ctx context.Context, task *tasq.Task, errTask error) (*tasq.Task, error) {
332 | const (
333 | updateTaskSQLTemplate = `UPDATE
334 | {{.tableName}}
335 | SET
336 | last_error = :errorMessage
337 | WHERE
338 | id = :taskID;`
339 | selectUpdatedTaskSQLTemplate = `SELECT *
340 | FROM
341 | {{.tableName}}
342 | WHERE
343 | id = :taskID;`
344 | )
345 |
346 | tx, err := d.db.Beginx()
347 | if err != nil {
348 | return nil, fmt.Errorf("%s: %w", errFailedToBeginTx, err)
349 | }
350 |
351 | defer rollback(tx)
352 |
353 | var (
354 | mySQLTask = newFromTask(task)
355 | updateTaskQuery, updateTaskArgs = d.getQueryWithTableName(updateTaskSQLTemplate, map[string]any{
356 | "errorMessage": errTask.Error(),
357 | "taskID": mySQLTask.ID,
358 | })
359 | )
360 |
361 | _, err = tx.ExecContext(ctx, updateTaskQuery, updateTaskArgs...)
362 | if err != nil {
363 | return nil, fmt.Errorf("%s: %w", errFailedToExecuteUpdate, err)
364 | }
365 |
366 | selectUpdatedTaskQuery, selectUpdatedTaskArgs := d.getQueryWithTableName(selectUpdatedTaskSQLTemplate, map[string]any{
367 | "taskID": mySQLTask.ID,
368 | })
369 |
370 | err = tx.QueryRowxContext(ctx, selectUpdatedTaskQuery, selectUpdatedTaskArgs...).
371 | StructScan(mySQLTask)
372 | if err != nil {
373 | return nil, fmt.Errorf("%s: %w", errFailedToExecuteSelect, err)
374 | }
375 |
376 | err = tx.Commit()
377 | if err != nil {
378 | return nil, fmt.Errorf("%s: %w", errFailedToCommitTx, err)
379 | }
380 |
381 | return mySQLTask.toTask(), nil
382 | }
383 |
384 | // RegisterFinish marks a task as finished with the supplied status
385 | // and records the time of finish.
386 | func (d *Repository) RegisterFinish(ctx context.Context, task *tasq.Task, finishStatus tasq.TaskStatus) (*tasq.Task, error) {
387 | const (
388 | updateTaskSQLTemplate = `UPDATE
389 | {{.tableName}}
390 | SET
391 | status = :status,
392 | finished_at = :finishTime
393 | WHERE
394 | id = :taskID;`
395 | selectUpdatedTaskSQLTemplate = `SELECT *
396 | FROM
397 | {{.tableName}}
398 | WHERE
399 | id = :taskID;`
400 | )
401 |
402 | tx, err := d.db.Beginx()
403 | if err != nil {
404 | return nil, fmt.Errorf("%s: %w", errFailedToBeginTx, err)
405 | }
406 |
407 | defer rollback(tx)
408 |
409 | var (
410 | mySQLTask = newFromTask(task)
411 | finishTime = time.Now()
412 | updateTasksQuery, updateTasksArgs = d.getQueryWithTableName(updateTaskSQLTemplate, map[string]any{
413 | "status": finishStatus,
414 | "finishTime": timeToString(finishTime),
415 | "taskID": mySQLTask.ID,
416 | })
417 | )
418 |
419 | _, err = tx.ExecContext(ctx, updateTasksQuery, updateTasksArgs...)
420 | if err != nil {
421 | return nil, fmt.Errorf("%s: %w", errFailedToExecuteUpdate, err)
422 | }
423 |
424 | selectUpdatedTasksQuery, selectUpdatedTasksArgs := d.getQueryWithTableName(selectUpdatedTaskSQLTemplate, map[string]any{
425 | "taskID": mySQLTask.ID,
426 | })
427 |
428 | err = tx.QueryRowxContext(ctx, selectUpdatedTasksQuery, selectUpdatedTasksArgs...).
429 | StructScan(mySQLTask)
430 | if err != nil {
431 | return nil, fmt.Errorf("%s: %w", errFailedToExecuteSelect, err)
432 | }
433 |
434 | err = tx.Commit()
435 | if err != nil {
436 | return nil, fmt.Errorf("%s: %w", errFailedToCommitTx, err)
437 | }
438 |
439 | return mySQLTask.toTask(), nil
440 | }
441 |
442 | // SubmitTask adds the supplied task to the queue.
443 | func (d *Repository) SubmitTask(ctx context.Context, task *tasq.Task) (*tasq.Task, error) {
444 | const (
445 | insertTaskSQLTemplate = `INSERT INTO
446 | {{.tableName}}
447 | (id, type, args, queue, priority, status, max_receives, created_at, visible_at)
448 | VALUES
449 | (:id, :type, :args, :queue, :priority, :status, :maxReceives, :createdAt, :visibleAt);`
450 | selectInsertedTaskSQLTemplate = `SELECT *
451 | FROM
452 | {{.tableName}}
453 | WHERE
454 | id = :taskID;`
455 | )
456 |
457 | tx, err := d.db.Beginx()
458 | if err != nil {
459 | return nil, fmt.Errorf("%s: %w", errFailedToBeginTx, err)
460 | }
461 |
462 | defer rollback(tx)
463 |
464 | var (
465 | mySQLTask = newFromTask(task)
466 | insertTaskQuery, insertTaskArgs = d.getQueryWithTableName(insertTaskSQLTemplate, map[string]any{
467 | "id": mySQLTask.ID,
468 | "type": mySQLTask.Type,
469 | "args": mySQLTask.Args,
470 | "queue": mySQLTask.Queue,
471 | "priority": mySQLTask.Priority,
472 | "status": mySQLTask.Status,
473 | "maxReceives": mySQLTask.MaxReceives,
474 | "createdAt": mySQLTask.CreatedAt,
475 | "visibleAt": mySQLTask.VisibleAt,
476 | })
477 | )
478 |
479 | _, err = tx.ExecContext(ctx, insertTaskQuery, insertTaskArgs...)
480 | if err != nil {
481 | return nil, fmt.Errorf("%s: %w", errFailedToExecuteInsert, err)
482 | }
483 |
484 | selectInsertedTaskQuery, selectInsertedTaskArgs := d.getQueryWithTableName(selectInsertedTaskSQLTemplate, map[string]any{
485 | "taskID": mySQLTask.ID,
486 | })
487 |
488 | err = tx.QueryRowxContext(ctx, selectInsertedTaskQuery, selectInsertedTaskArgs...).
489 | StructScan(mySQLTask)
490 | if err != nil {
491 | return nil, fmt.Errorf("%s: %w", errFailedToExecuteSelect, err)
492 | }
493 |
494 | err = tx.Commit()
495 | if err != nil {
496 | return nil, fmt.Errorf("%s: %w", errFailedToCommitTx, err)
497 | }
498 |
499 | return mySQLTask.toTask(), nil
500 | }
501 |
502 | // DeleteTask removes the supplied task from the queue.
503 | func (d *Repository) DeleteTask(ctx context.Context, task *tasq.Task, safeDelete bool) error {
504 | var (
505 | mySQLTask = newFromTask(task)
506 | conditions = []string{
507 | `id = :taskID`,
508 | }
509 | parameters = map[string]any{
510 | "taskID": mySQLTask.ID,
511 | }
512 | )
513 |
514 | if safeDelete {
515 | d.applySafeDeleteConditions(&conditions, ¶meters)
516 | }
517 |
518 | deleteTaskSQLTemplate := `DELETE FROM {{.tableName}} WHERE ` + strings.Join(conditions, ` AND `) + `;`
519 |
520 | deleteTaskQuery, deleteTaskArgs := d.getQueryWithTableName(deleteTaskSQLTemplate, parameters)
521 |
522 | _, err := d.db.ExecContext(ctx, deleteTaskQuery, deleteTaskArgs...)
523 | if err != nil {
524 | return fmt.Errorf("%s: %w", errFailedToExecuteDelete, err)
525 | }
526 |
527 | return nil
528 | }
529 |
530 | // RequeueTask marks a task as new, so it can be picked up again.
531 | func (d *Repository) RequeueTask(ctx context.Context, task *tasq.Task) (*tasq.Task, error) {
532 | const (
533 | updateTaskSQLTemplate = `UPDATE
534 | {{.tableName}}
535 | SET
536 | status = :status
537 | WHERE
538 | id = :taskID;`
539 | selectUpdatedTaskSQLTemplate = `SELECT *
540 | FROM
541 | {{.tableName}}
542 | WHERE
543 | id = :taskID;`
544 | )
545 |
546 | tx, err := d.db.Beginx()
547 | if err != nil {
548 | return nil, fmt.Errorf("%s: %w", errFailedToBeginTx, err)
549 | }
550 |
551 | defer rollback(tx)
552 |
553 | var (
554 | mySQLTask = newFromTask(task)
555 | updateTaskQuery, updateTaskArgs = d.getQueryWithTableName(updateTaskSQLTemplate, map[string]any{
556 | "status": tasq.StatusNew,
557 | "taskID": mySQLTask.ID,
558 | })
559 | )
560 |
561 | _, err = tx.ExecContext(ctx, updateTaskQuery, updateTaskArgs...)
562 | if err != nil {
563 | return nil, fmt.Errorf("%s: %w", errFailedToExecuteUpdate, err)
564 | }
565 |
566 | selectUpdatedTaskQuery, selectUpdatedTaskArgs := d.getQueryWithTableName(selectUpdatedTaskSQLTemplate, map[string]any{
567 | "taskID": mySQLTask.ID,
568 | })
569 |
570 | err = tx.QueryRowxContext(ctx, selectUpdatedTaskQuery, selectUpdatedTaskArgs...).
571 | StructScan(mySQLTask)
572 | if err != nil {
573 | return nil, fmt.Errorf("%s: %w", errFailedToExecuteSelect, err)
574 | }
575 |
576 | err = tx.Commit()
577 | if err != nil {
578 | return nil, fmt.Errorf("%s: %w", errFailedToCommitTx, err)
579 | }
580 |
581 | return mySQLTask.toTask(), err
582 | }
583 |
584 | // CountTasks returns the number of tasks in the queue based on the supplied filters.
585 | func (d *Repository) CountTasks(ctx context.Context, taskStatuses []tasq.TaskStatus, taskTypes, queues []string) (int64, error) {
586 | var (
587 | count int64
588 | selectTaskCountQuery, selectTaskCountArgs = d.getQueryWithTableName(
589 | d.buildCountSQLTemplate(taskStatuses, taskTypes, queues),
590 | )
591 | )
592 |
593 | err := d.db.GetContext(ctx, &count, selectTaskCountQuery, selectTaskCountArgs...)
594 | if err != nil {
595 | return 0, fmt.Errorf("%s: %w", errFailedToExecuteSelect, err)
596 | }
597 |
598 | return count, nil
599 | }
600 |
601 | // ScanTasks returns a list of tasks in the queue based on the supplied filters.
602 | func (d *Repository) ScanTasks(ctx context.Context, taskStatuses []tasq.TaskStatus, taskTypes, queues []string, ordering tasq.Ordering, scanLimit int) ([]*tasq.Task, error) {
603 | var (
604 | scannedTasks []*mySQLTask
605 | selectScannedTasksQuery, selectScannedTasksArgs = d.getQueryWithTableName(
606 | d.buildScanSQLTemplate(taskStatuses, taskTypes, queues, ordering, scanLimit),
607 | )
608 | )
609 |
610 | err := d.db.SelectContext(ctx, &scannedTasks, selectScannedTasksQuery, selectScannedTasksArgs...)
611 | if err != nil {
612 | return []*tasq.Task{}, fmt.Errorf("%s: %w", errFailedToExecuteSelect, err)
613 | }
614 |
615 | return mySQLTasksToTasks(scannedTasks), nil
616 | }
617 |
618 | // PurgeTasks removes all tasks from the queue based on the supplied filters.
619 | func (d *Repository) PurgeTasks(ctx context.Context, taskStatuses []tasq.TaskStatus, taskTypes, queues []string, safeDelete bool) (int64, error) {
620 | selectPurgedTasksQuery, selectPurgedTasksArgs := d.getQueryWithTableName(
621 | d.buildPurgeSQLTemplate(taskStatuses, taskTypes, queues, safeDelete),
622 | )
623 |
624 | result, err := d.db.ExecContext(ctx, selectPurgedTasksQuery, selectPurgedTasksArgs...)
625 | if err != nil {
626 | return 0, fmt.Errorf("%s: %w", errFailedToExecuteDelete, err)
627 | }
628 |
629 | rowsAffected, err := result.RowsAffected()
630 | if err != nil {
631 | return 0, fmt.Errorf("%s: %w", errFailedGetRowsAffected, err)
632 | }
633 |
634 | return rowsAffected, nil
635 | }
636 |
637 | func (d *Repository) buildCountSQLTemplate(taskStatuses []tasq.TaskStatus, taskTypes, queues []string) (string, map[string]any) {
638 | var (
639 | conditions, parameters = d.buildFilterConditions(taskStatuses, taskTypes, queues)
640 | sqlTemplate = `SELECT COUNT(*) FROM {{.tableName}}`
641 | )
642 |
643 | if len(conditions) > 0 {
644 | sqlTemplate += ` WHERE ` + strings.Join(conditions, " AND ")
645 | }
646 |
647 | return sqlTemplate + `;`, parameters
648 | }
649 |
650 | func (d *Repository) buildScanSQLTemplate(taskStatuses []tasq.TaskStatus, taskTypes, queues []string, ordering tasq.Ordering, scanLimit int) (string, map[string]any) {
651 | var (
652 | conditions, parameters = d.buildFilterConditions(taskStatuses, taskTypes, queues)
653 | sqlTemplate = `SELECT * FROM {{.tableName}}`
654 | )
655 |
656 | if len(conditions) > 0 {
657 | sqlTemplate += ` WHERE ` + strings.Join(conditions, " AND ")
658 | }
659 |
660 | sqlTemplate += ` ORDER BY :scanOrdering LIMIT :limit;`
661 |
662 | parameters["scanOrdering"] = pq.Array(getOrderingDirectives(ordering))
663 | parameters["limit"] = scanLimit
664 |
665 | return sqlTemplate + `;`, parameters
666 | }
667 |
668 | func (d *Repository) buildPurgeSQLTemplate(taskStatuses []tasq.TaskStatus, taskTypes, queues []string, safeDelete bool) (string, map[string]any) {
669 | var (
670 | conditions, parameters = d.buildFilterConditions(taskStatuses, taskTypes, queues)
671 | sqlTemplate = `DELETE FROM {{.tableName}}`
672 | )
673 |
674 | if safeDelete {
675 | d.applySafeDeleteConditions(&conditions, ¶meters)
676 | }
677 |
678 | if len(conditions) > 0 {
679 | sqlTemplate += ` WHERE ` + strings.Join(conditions, " AND ")
680 | }
681 |
682 | return sqlTemplate + `;`, parameters
683 | }
684 |
685 | func (d *Repository) applySafeDeleteConditions(conditions *[]string, parameters *map[string]any) {
686 | *conditions = append(*conditions, `(
687 | (
688 | visible_at <= :visibleAt
689 | ) OR (
690 | status IN (:statuses) AND
691 | visible_at > :visibleAt
692 | )
693 | )`)
694 | (*parameters)["statuses"] = []tasq.TaskStatus{tasq.StatusNew}
695 | (*parameters)["visibleAt"] = time.Now()
696 | }
697 |
698 | func (d *Repository) buildFilterConditions(taskStatuses []tasq.TaskStatus, taskTypes, queues []string) ([]string, map[string]any) {
699 | var (
700 | conditions []string
701 | parameters = make(map[string]any)
702 | )
703 |
704 | if len(taskStatuses) > 0 {
705 | conditions = append(conditions, `status IN (:filterStatuses)`)
706 | parameters["filterStatuses"] = taskStatuses
707 | }
708 |
709 | if len(taskTypes) > 0 {
710 | conditions = append(conditions, `type IN (:filterTypes)`)
711 | parameters["filterTypes"] = taskTypes
712 | }
713 |
714 | if len(queues) > 0 {
715 | conditions = append(conditions, `queue IN (:filterQueues)`)
716 | parameters["filterQueues"] = queues
717 | }
718 |
719 | return conditions, parameters
720 | }
721 |
722 | func (d *Repository) getQueryWithTableName(sqlTemplate string, args ...any) (string, []any) {
723 | query := interpolateSQL(sqlTemplate, map[string]any{
724 | "tableName": d.tableName,
725 | })
726 |
727 | query, args, err := sqlx.Named(query, args)
728 | if err != nil {
729 | panic(err)
730 | }
731 |
732 | query, args, err = sqlx.In(query, args...)
733 | if err != nil {
734 | panic(err)
735 | }
736 |
737 | return d.db.Rebind(query), args
738 | }
739 |
740 | func (d *Repository) migrateTable(ctx context.Context) error {
741 | const sqlTemplate = `CREATE TABLE IF NOT EXISTS {{.tableName}} (
742 | id binary(16) NOT NULL,
743 | type text NOT NULL,
744 | args longblob NOT NULL,
745 | queue text NOT NULL,
746 | priority smallint NOT NULL,
747 | status enum({{.enumValues}}) NOT NULL,
748 | receive_count int NOT NULL DEFAULT '0',
749 | max_receives int NOT NULL DEFAULT '0',
750 | last_error text,
751 | created_at datetime(6) NOT NULL DEFAULT '0001-01-01 00:00:00.000000',
752 | started_at datetime(6),
753 | finished_at datetime(6),
754 | visible_at datetime(6) NOT NULL DEFAULT '0001-01-01 00:00:00.000000',
755 | PRIMARY KEY (id)
756 | );`
757 |
758 | query := interpolateSQL(sqlTemplate, map[string]any{
759 | "tableName": d.tableName,
760 | "enumValues": sliceToMySQLValueList(tasq.GetTaskStatuses(tasq.AllTasks)),
761 | })
762 |
763 | _, err := d.db.ExecContext(ctx, query)
764 | if err != nil {
765 | return fmt.Errorf("%s: %w", errFailedToExecuteCreateTable, err)
766 | }
767 |
768 | return nil
769 | }
770 |
771 | func getOrderingDirectives(ordering tasq.Ordering) []string {
772 | var (
773 | OrderingCreatedAtFirst = []string{"created_at ASC", "priority DESC"}
774 | OrderingPriorityFirst = []string{"priority DESC", "created_at ASC"}
775 | )
776 |
777 | if orderingDirectives, ok := map[tasq.Ordering][]string{
778 | tasq.OrderingCreatedAtFirst: OrderingCreatedAtFirst,
779 | tasq.OrderingPriorityFirst: OrderingPriorityFirst,
780 | }[ordering]; ok {
781 | return orderingDirectives
782 | }
783 |
784 | return OrderingCreatedAtFirst
785 | }
786 |
787 | func rollback(tx *sqlx.Tx) {
788 | if err := tx.Rollback(); err != nil && !errors.Is(err, sql.ErrTxDone) {
789 | panic(err)
790 | }
791 | }
792 |
// sliceToMySQLValueList renders slice as a comma-separated list of
// double-quoted values (e.g. `"a", "b"`) for use in an enum(...) column
// definition. Values are formatted with fmt.Sprint and are not escaped. An
// empty slice yields `""`.
func sliceToMySQLValueList[T any](slice []T) string {
	var list strings.Builder

	list.WriteByte('"')

	for idx, item := range slice {
		if idx > 0 {
			list.WriteString(`", "`)
		}

		list.WriteString(fmt.Sprint(item))
	}

	list.WriteByte('"')

	return list.String()
}
802 |
// tableName returns "tasks", or "<prefix>_tasks" when prefix is non-empty.
func tableName(prefix string) string {
	const baseName = "tasks"

	if prefix == "" {
		return baseName
	}

	return prefix + "_" + baseName
}
812 |
// interpolateSQL renders sql as a text/template with params (no escaping is
// applied). It panics if the template cannot be parsed or executed, since
// all templates are defined inside this package.
func interpolateSQL(sql string, params map[string]any) string {
	tmpl := template.Must(template.New("sql").Parse(sql))

	var rendered bytes.Buffer
	if err := tmpl.Execute(&rendered, params); err != nil {
		panic(err)
	}

	return rendered.String()
}
828 |
--------------------------------------------------------------------------------
/repository/mysql/mysqlTask.go:
--------------------------------------------------------------------------------
1 | package mysql
2 |
3 | import (
4 | "database/sql"
5 | "database/sql/driver"
6 | "errors"
7 | "fmt"
8 | "time"
9 |
10 | "github.com/google/uuid"
11 | "github.com/greencoda/tasq"
12 | )
13 |
// idLength is the size in bytes of a task ID, matching the binary(16) id
// column.
const idLength = 16

// timeFormat is the layout used to render and parse datetime values.
const timeFormat = "2006-01-02 15:04:05.999999"

// Errors returned by TaskID.Scan.
var (
	errIncorrectLength = errors.New("Scan: MySQLTaskID is of incorrect length")
	errUnableToScan    = errors.New("Scan: unable to scan type into MySQLTaskID")
)

// TaskID represents the types used to manage conversion of UUID
// to MySQL's binary(16) format.
type TaskID [idLength]byte

// Scan implements sql.Scanner so TaskIDs can be read from MySQL
// transparently. NULL and empty byte slices leave the ID unchanged. Any
// other []byte must be exactly idLength bytes long, and other source types
// are rejected.
func (i *TaskID) Scan(src any) error {
	if src == nil {
		return nil
	}

	raw, isBytes := src.([]byte)
	if !isBytes {
		return fmt.Errorf("%w: %T", errUnableToScan, src)
	}

	if len(raw) == 0 {
		return nil
	}

	if len(raw) != idLength {
		return fmt.Errorf("%w: %v", errIncorrectLength, len(raw))
	}

	copy(i[:], raw)

	return nil
}

// Value implements driver.Valuer so that TaskIDs can be written to MySQL
// transparently, as their raw bytes.
func (i TaskID) Value() (driver.Value, error) {
	raw := make([]byte, idLength)
	copy(raw, i[:])

	return raw, nil
}
55 |
56 | type mySQLTask struct {
57 | ID TaskID `db:"id"`
58 | Type string `db:"type"`
59 | Args []byte `db:"args"`
60 | Queue string `db:"queue"`
61 | Priority int16 `db:"priority"`
62 | Status tasq.TaskStatus `db:"status"`
63 | ReceiveCount int32 `db:"receive_count"`
64 | MaxReceives int32 `db:"max_receives"`
65 | LastError sql.NullString `db:"last_error"`
66 | CreatedAt string `db:"created_at"`
67 | StartedAt sql.NullString `db:"started_at"`
68 | FinishedAt sql.NullString `db:"finished_at"`
69 | VisibleAt string `db:"visible_at"`
70 | }
71 |
72 | func newFromTask(task *tasq.Task) *mySQLTask {
73 | return &mySQLTask{
74 | ID: TaskID(task.ID),
75 | Type: task.Type,
76 | Args: task.Args,
77 | Queue: task.Queue,
78 | Priority: task.Priority,
79 | Status: task.Status,
80 | ReceiveCount: task.ReceiveCount,
81 | MaxReceives: task.MaxReceives,
82 | LastError: stringToSQLNullString(task.LastError),
83 | CreatedAt: timeToString(task.CreatedAt),
84 | StartedAt: timeToSQLNullString(task.StartedAt),
85 | FinishedAt: timeToSQLNullString(task.FinishedAt),
86 | VisibleAt: timeToString(task.VisibleAt),
87 | }
88 | }
89 |
90 | func (t *mySQLTask) toTask() *tasq.Task {
91 | return &tasq.Task{
92 | ID: uuid.UUID(t.ID),
93 | Type: t.Type,
94 | Args: t.Args,
95 | Queue: t.Queue,
96 | Priority: t.Priority,
97 | Status: t.Status,
98 | ReceiveCount: t.ReceiveCount,
99 | MaxReceives: t.MaxReceives,
100 | LastError: parseNullableString(t.LastError),
101 | CreatedAt: parseTime(t.CreatedAt),
102 | StartedAt: parseNullableTime(t.StartedAt),
103 | FinishedAt: parseNullableTime(t.FinishedAt),
104 | VisibleAt: parseTime(t.VisibleAt),
105 | }
106 | }
107 |
108 | func mySQLTasksToTasks(mySQLTasks []*mySQLTask) []*tasq.Task {
109 | tasks := make([]*tasq.Task, len(mySQLTasks))
110 |
111 | for i, mySQLTask := range mySQLTasks {
112 | tasks[i] = mySQLTask.toTask()
113 | }
114 |
115 | return tasks
116 | }
117 |
118 | func stringToSQLNullString(input *string) sql.NullString {
119 | if input == nil {
120 | return sql.NullString{
121 | String: "",
122 | Valid: false,
123 | }
124 | }
125 |
126 | return sql.NullString{
127 | String: *input,
128 | Valid: true,
129 | }
130 | }
131 |
// timeToString renders a timestamp with the package's timeFormat layout,
// the textual form stored in the timestamp columns.
func timeToString(input time.Time) string {
	formatted := input.Format(timeFormat)

	return formatted
}
135 |
// timeToSQLNullString maps an optional timestamp to sql.NullString. nil
// becomes SQL NULL; otherwise the time is formatted with timeFormat.
func timeToSQLNullString(input *time.Time) sql.NullString {
	if input != nil {
		return sql.NullString{String: input.Format(timeFormat), Valid: true}
	}

	return sql.NullString{}
}
149 |
// parseNullableString converts a sql.NullString back to an optional string.
// NULL yields nil; any valid value, including "", yields a pointer to a copy
// of it.
func parseNullableString(input sql.NullString) *string {
	if input.Valid {
		value := input.String

		return &value
	}

	return nil
}
157 |
// parseTime parses a timestamp stored with timeFormat. Parsing is
// best-effort: malformed input yields the zero time.Time rather than an error.
func parseTime(input string) time.Time {
	if parsed, err := time.Parse(timeFormat, input); err == nil {
		return parsed
	}

	return time.Time{}
}
166 |
// parseNullableTime converts a nullable timestamp column to *time.Time. NULL
// yields nil. A valid value is parsed with parseTime, so malformed text
// yields a pointer to the zero time rather than nil.
func parseNullableTime(input sql.NullString) *time.Time {
	if input.Valid {
		parsed := parseTime(input.String)

		return &parsed
	}

	return nil
}
176 |
--------------------------------------------------------------------------------
/repository/mysql/mysqlTask_test.go:
--------------------------------------------------------------------------------
1 | package mysql_test
2 |
3 | import (
4 | "database/sql"
5 | "testing"
6 |
7 | "github.com/greencoda/tasq/repository/mysql"
8 | "github.com/stretchr/testify/assert"
9 | )
10 |
// TestMySQLTaskIDScan covers each branch of TaskID.Scan:
//   - NULL and empty input are accepted as no-ops;
//   - a full-length byte slice is copied into the ID;
//   - a wrongly sized slice or a non-[]byte source is rejected.
//
// It also checks that Value returns the bytes that were scanned.
func TestMySQLTaskIDScan(t *testing.T) {
	t.Parallel()

	var mySQLID mysql.TaskID

	err := mySQLID.Scan(nil)
	assert.NoError(t, err)

	err = mySQLID.Scan([]byte{})
	assert.NoError(t, err)

	// A correctly sized slice must be copied verbatim into the TaskID.
	validID := make([]byte, len(mySQLID))
	for i := range validID {
		validID[i] = byte(i + 1)
	}

	err = mySQLID.Scan(validID)
	assert.NoError(t, err)
	assert.Equal(t, validID, mySQLID[:])

	value, err := mySQLID.Value()
	assert.NoError(t, err)
	assert.Equal(t, validID, value)

	err = mySQLID.Scan([]byte{12})
	assert.Error(t, err)

	err = mySQLID.Scan("test")
	assert.Error(t, err)
}
28 |
29 | func TestStringToSQLNullString(t *testing.T) {
30 | t.Parallel()
31 |
32 | var (
33 | emptyInput = ""
34 | nonEmptyInput = "test"
35 | nilInput *string
36 | )
37 |
38 | sqlNullString := mysql.StringToSQLNullString(&emptyInput)
39 | assert.Equal(t, sql.NullString{
40 | String: emptyInput,
41 | Valid: true,
42 | }, sqlNullString)
43 |
44 | sqlNullString = mysql.StringToSQLNullString(&nonEmptyInput)
45 | assert.Equal(t, sql.NullString{
46 | String: nonEmptyInput,
47 | Valid: true,
48 | }, sqlNullString)
49 |
50 | sqlNullString = mysql.StringToSQLNullString(nilInput)
51 | assert.Equal(t, sql.NullString{
52 | String: "",
53 | Valid: false,
54 | }, sqlNullString)
55 | }
56 |
// TestParseNullableString checks that valid NullStrings, including empty
// ones, yield a pointer to their value and that NULL yields nil.
//
// The dereference is guarded by assert.NotNil's result: a nil output then
// fails the test cleanly instead of panicking.
func TestParseNullableString(t *testing.T) {
	t.Parallel()

	var (
		emptyInput = sql.NullString{
			String: "",
			Valid:  true,
		}
		nonEmptyInput = sql.NullString{
			String: "test",
			Valid:  true,
		}
		nilInput = sql.NullString{
			String: "",
			Valid:  false,
		}
	)

	output := mysql.ParseNullableString(emptyInput)
	if assert.NotNil(t, output) {
		assert.Equal(t, emptyInput.String, *output)
	}

	output = mysql.ParseNullableString(nonEmptyInput)
	if assert.NotNil(t, output) {
		assert.Equal(t, nonEmptyInput.String, *output)
	}

	output = mysql.ParseNullableString(nilInput)
	assert.Nil(t, output)
}
86 |
--------------------------------------------------------------------------------
/repository/mysql/mysql_test.go:
--------------------------------------------------------------------------------
1 | package mysql_test
2 |
3 | import (
4 | "context"
5 | "database/sql"
6 | "database/sql/driver"
7 | "errors"
8 | "regexp"
9 | "testing"
10 | "time"
11 |
12 | "github.com/DATA-DOG/go-sqlmock"
13 | "github.com/google/uuid"
14 | "github.com/greencoda/tasq"
15 | "github.com/greencoda/tasq/repository/mysql"
16 | "github.com/stretchr/testify/assert"
17 | "github.com/stretchr/testify/require"
18 | "github.com/stretchr/testify/suite"
19 | )
20 |
// Shared fixtures for the MySQL repository test suite.
//
// Go initializes package-level variables in dependency order, so
// testTaskType and testTaskQueue are set before getStartedTestTask (which
// reads them) runs, even though they are declared further down.
var (
	ctx = context.Background()
	// testTask is a freshly created task whose StartedAt has been set.
	testTask = getStartedTestTask()
	// taskColumns are the column names used for mocked result rows.
	taskColumns = []string{
		"id",
		"type",
		"args",
		"queue",
		"priority",
		"status",
		"receive_count",
		"max_receives",
		"last_error",
		"created_at",
		"started_at",
		"finished_at",
		"visible_at",
	}
	testTaskType  = "testTask"
	testTaskQueue = "testQueue"
	// taskValues are testTask's column values as produced by the
	// export_test helper; presumably in taskColumns order — TODO confirm.
	taskValues = mysql.GetTestTaskValues(testTask)
	errSQL     = errors.New("sql error")
	errTask    = errors.New("task error")
)
45 |
46 | func getStartedTestTask() *tasq.Task {
47 | var (
48 | testTask, _ = tasq.NewTask(testTaskType, true, testTaskQueue, 100, 5)
49 | startTime = testTask.CreatedAt.Add(time.Second)
50 | )
51 |
52 | testTask.StartedAt = &startTime
53 |
54 | return testTask
55 | }
56 |
// MySQLTestSuite exercises the MySQL repository against a go-sqlmock
// database. SetupTest recreates every field before each test method.
type MySQLTestSuite struct {
	suite.Suite
	// db is the sqlmock-backed connection handed to the repository.
	db *sql.DB
	// sqlMock is used to queue the expected statements and their results.
	sqlMock sqlmock.Sqlmock
	// mockedRepository is the repository under test, built on db with the
	// "test" table prefix.
	mockedRepository tasq.IRepository
}
63 |
// TestTaskTestSuite is the `go test` entry point that runs every
// MySQLTestSuite test method.
func TestTaskTestSuite(t *testing.T) {
	t.Parallel()

	testSuite := &MySQLTestSuite{}
	suite.Run(t, testSuite)
}
69 |
70 | func (s *MySQLTestSuite) SetupTest() {
71 | var err error
72 |
73 | s.db, s.sqlMock, err = sqlmock.New()
74 | require.Nil(s.T(), err)
75 |
76 | s.mockedRepository, err = mysql.NewRepository(s.db, "test")
77 | require.NotNil(s.T(), s.mockedRepository)
78 | require.Nil(s.T(), err)
79 | }
80 |
81 | func (s *MySQLTestSuite) TestNewRepository() {
82 | // providing the datasource as *sql.DB
83 | repository, err := mysql.NewRepository(s.db, "test")
84 | assert.NotNil(s.T(), repository)
85 | assert.Nil(s.T(), err)
86 |
87 | // providing the datasource as *sql.DB with no prefix
88 | repository, err = mysql.NewRepository(s.db, "")
89 | assert.NotNil(s.T(), repository)
90 | assert.Nil(s.T(), err)
91 |
92 | // providing the datasource as dsn string
93 | repository, err = mysql.NewRepository("root:root@/test", "test")
94 | assert.NotNil(s.T(), repository)
95 | assert.Nil(s.T(), err)
96 |
97 | // providing an invalid drivdsner as dsn string
98 | repository, err = mysql.NewRepository("invalidDSN", "test")
99 | assert.Nil(s.T(), repository)
100 | assert.NotNil(s.T(), err)
101 |
102 | // providing the datasource as unknown datasource type
103 | repository, err = mysql.NewRepository(false, "test")
104 | assert.Nil(s.T(), repository)
105 | assert.NotNil(s.T(), err)
106 | }
107 |
// TestMigrate checks that Migrate surfaces a failing CREATE TABLE for the
// prefixed tasks table and succeeds once the statement executes.
func (s *MySQLTestSuite) TestMigrate() {
	// First try - creating the tasks table fails
	// (sqlmock's default matcher treats the expectation as a regular
	// expression, so only this prefix of the DDL has to match.)
	s.sqlMock.ExpectExec(`CREATE TABLE IF NOT EXISTS test_tasks`).WillReturnError(errSQL)

	err := s.mockedRepository.Migrate(ctx)
	assert.NotNil(s.T(), err)

	// Second try - migration succeeds
	s.sqlMock.ExpectExec(`CREATE TABLE IF NOT EXISTS test_tasks`).WillReturnResult(sqlmock.NewResult(1, 1))

	err = s.mockedRepository.Migrate(ctx)
	assert.Nil(s.T(), err)
}
121 |
// TestPingTasks drives PingTasks through each step of its transaction:
// begin, UPDATE of visible_at, re-SELECT of the pinged rows, and commit.
// A failure at any step must yield no tasks and an error. A failed rollback
// is expected to panic with the rollback error.
func (s *MySQLTestSuite) TestPingTasks() {
	var (
		taskUUID = uuid.New()
		// The error is discarded: uuid.UUID.MarshalBinary only copies the bytes.
		taskUUIDBytes, _ = taskUUID.MarshalBinary()
		updateMockRegexp = regexp.QuoteMeta(`UPDATE test_tasks SET visible_at = ? WHERE id IN (?);`)
		selectMockRegexp = regexp.QuoteMeta(`SELECT * FROM test_tasks WHERE id IN (?);`)
	)

	// pinging empty tasklist: no expectations are queued, so PingTasks must
	// return without touching the database
	tasks, err := s.mockedRepository.PingTasks(ctx, []uuid.UUID{}, 15*time.Second)
	assert.Len(s.T(), tasks, 0)
	assert.Nil(s.T(), err)

	// beginning the transaction fails
	s.sqlMock.ExpectBegin().WillReturnError(errSQL)
	tasks, err = s.mockedRepository.PingTasks(ctx, []uuid.UUID{taskUUID}, 15*time.Second)
	assert.Len(s.T(), tasks, 0)
	assert.NotNil(s.T(), err)

	// the UPDATE fails, so the transaction is rolled back
	s.sqlMock.ExpectBegin()
	s.sqlMock.ExpectExec(updateMockRegexp).WillReturnError(errSQL)
	s.sqlMock.ExpectRollback()

	tasks, err = s.mockedRepository.PingTasks(ctx, []uuid.UUID{taskUUID}, 15*time.Second)
	assert.Len(s.T(), tasks, 0)
	assert.NotNil(s.T(), err)

	// the UPDATE fails and the rollback fails too: PingTasks is expected to
	// panic with the rollback error
	assert.PanicsWithError(s.T(), errSQL.Error(), func() {
		s.sqlMock.ExpectBegin()
		s.sqlMock.ExpectExec(updateMockRegexp).WillReturnError(errSQL)
		s.sqlMock.ExpectRollback().WillReturnError(errSQL)

		_, _ = s.mockedRepository.PingTasks(ctx, []uuid.UUID{taskUUID}, 15*time.Second)
	})

	// the UPDATE succeeds but re-selecting the pinged tasks fails
	s.sqlMock.ExpectBegin()
	s.sqlMock.ExpectExec(updateMockRegexp).WillReturnResult(sqlmock.NewResult(1, 1))
	s.sqlMock.ExpectQuery(selectMockRegexp).WillReturnError(errSQL)
	s.sqlMock.ExpectRollback()

	tasks, err = s.mockedRepository.PingTasks(ctx, []uuid.UUID{taskUUID}, 15*time.Second)
	assert.Len(s.T(), tasks, 0)
	assert.NotNil(s.T(), err)

	// pinging existing task succeeds, commit fails
	// (only the id column is mocked in the returned row)
	s.sqlMock.ExpectBegin()
	s.sqlMock.ExpectExec(updateMockRegexp).WillReturnResult(sqlmock.NewResult(1, 1))
	s.sqlMock.ExpectQuery(selectMockRegexp).WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(taskUUIDBytes))
	s.sqlMock.ExpectCommit().WillReturnError(errSQL)

	tasks, err = s.mockedRepository.PingTasks(ctx, []uuid.UUID{taskUUID}, 15*time.Second)
	assert.Len(s.T(), tasks, 0)
	assert.NotNil(s.T(), err)

	// pinging existing task succeeds, commit succeeds
	s.sqlMock.ExpectBegin()
	s.sqlMock.ExpectExec(updateMockRegexp).WillReturnResult(sqlmock.NewResult(1, 1))
	s.sqlMock.ExpectQuery(selectMockRegexp).WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(taskUUIDBytes))
	s.sqlMock.ExpectCommit()

	tasks, err = s.mockedRepository.PingTasks(ctx, []uuid.UUID{taskUUID}, 15*time.Second)
	assert.Len(s.T(), tasks, 1)
	assert.Nil(s.T(), err)
}
189 |
// TestPollTasks drives PollTasks through its transaction, one failure point
// at a time:
//   - select candidate IDs (FOR UPDATE SKIP LOCKED);
//   - update their status, receive_count and visible_at;
//   - re-select the updated rows;
//   - commit.
//
// Only the fully successful runs may return tasks.
func (s *MySQLTestSuite) TestPollTasks() {
	var (
		taskUUID         = uuid.New()
		taskUUIDBytes, _ = taskUUID.MarshalBinary()
		// NOTE(review): the expected query binds the ORDER BY terms as
		// placeholders. In MySQL a bound parameter in ORDER BY is a constant
		// expression and does not sort by the named column — verify how
		// mysql.go builds the ordering clause.
		selectMockRegexp = regexp.QuoteMeta(`SELECT
id
FROM
test_tasks
WHERE
type IN (?) AND
queue IN (?) AND
status IN (?, ?, ?) AND
visible_at <= ?
ORDER BY
?, ?
LIMIT ?
FOR UPDATE SKIP LOCKED;`)
		updateMockRegexp = regexp.QuoteMeta(`UPDATE
test_tasks
SET
status = ?,
receive_count = receive_count + 1,
visible_at = ?
WHERE
id IN (?);`)
		selectUpdatedMockRegexp = regexp.QuoteMeta(`SELECT
*
FROM
test_tasks
WHERE
id IN (?);`)
	)

	// polling with 0 limit: no expectations are queued, so the database
	// must not be touched
	tasks, err := s.mockedRepository.PollTasks(ctx, []string{"testTask"}, []string{"testQueue"}, 15*time.Second, tasq.OrderingCreatedAtFirst, 0)
	assert.Len(s.T(), tasks, 0)
	assert.Nil(s.T(), err)

	// beginning the transaction fails
	s.sqlMock.ExpectBegin().WillReturnError(errSQL)

	tasks, err = s.mockedRepository.PollTasks(ctx, []string{"testTask"}, []string{"testQueue"}, 15*time.Second, tasq.OrderingCreatedAtFirst, 1)
	assert.Len(s.T(), tasks, 0)
	assert.NotNil(s.T(), err)

	// polling when DB returns an error
	s.sqlMock.ExpectBegin()
	s.sqlMock.ExpectQuery(selectMockRegexp).WillReturnError(errSQL)
	s.sqlMock.ExpectRollback()

	tasks, err = s.mockedRepository.PollTasks(ctx, []string{"testTask"}, []string{"testQueue"}, 15*time.Second, tasq.OrderingCreatedAtFirst, 1)
	assert.Len(s.T(), tasks, 0)
	assert.NotNil(s.T(), err)

	// polling when DB returns no task IDs: rolled back, but not an error
	s.sqlMock.ExpectBegin()
	s.sqlMock.ExpectQuery(selectMockRegexp).WillReturnRows(sqlmock.NewRows([]string{"id"}))
	s.sqlMock.ExpectRollback()

	tasks, err = s.mockedRepository.PollTasks(ctx, []string{"testTask"}, []string{"testQueue"}, 15*time.Second, tasq.OrderingCreatedAtFirst, 1)
	assert.Len(s.T(), tasks, 0)
	assert.Nil(s.T(), err)

	// polling when DB fails to update task
	s.sqlMock.ExpectBegin()
	s.sqlMock.ExpectQuery(selectMockRegexp).WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(taskUUIDBytes))
	s.sqlMock.ExpectExec(updateMockRegexp).WillReturnError(errSQL)
	s.sqlMock.ExpectRollback()

	tasks, err = s.mockedRepository.PollTasks(ctx, []string{"testTask"}, []string{"testQueue"}, 15*time.Second, tasq.OrderingCreatedAtFirst, 1)
	assert.Len(s.T(), tasks, 0)
	assert.NotNil(s.T(), err)

	// polling when DB succeeds to update task but fails to select the updated tasks
	s.sqlMock.ExpectBegin()
	s.sqlMock.ExpectQuery(selectMockRegexp).WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(taskUUIDBytes))
	s.sqlMock.ExpectExec(updateMockRegexp).WillReturnResult(sqlmock.NewResult(1, 1))
	s.sqlMock.ExpectQuery(selectUpdatedMockRegexp).WillReturnError(errSQL)
	s.sqlMock.ExpectRollback()

	tasks, err = s.mockedRepository.PollTasks(ctx, []string{"testTask"}, []string{"testQueue"}, 15*time.Second, tasq.OrderingCreatedAtFirst, 1)
	assert.Len(s.T(), tasks, 0)
	assert.NotNil(s.T(), err)

	// polling when every statement succeeds but the commit fails
	s.sqlMock.ExpectBegin()
	s.sqlMock.ExpectQuery(selectMockRegexp).WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(taskUUIDBytes))
	s.sqlMock.ExpectExec(updateMockRegexp).WillReturnResult(sqlmock.NewResult(1, 1))
	s.sqlMock.ExpectQuery(selectUpdatedMockRegexp).WillReturnRows(sqlmock.NewRows(taskColumns).AddRow(taskValues...))
	s.sqlMock.ExpectCommit().WillReturnError(errSQL)

	tasks, err = s.mockedRepository.PollTasks(ctx, []string{"testTask"}, []string{"testQueue"}, 15*time.Second, tasq.OrderingCreatedAtFirst, 1)
	assert.Len(s.T(), tasks, 0)
	assert.NotNil(s.T(), err)

	// polling successfully
	s.sqlMock.ExpectBegin()
	s.sqlMock.ExpectQuery(selectMockRegexp).WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(taskUUIDBytes))
	s.sqlMock.ExpectExec(updateMockRegexp).WillReturnResult(sqlmock.NewResult(1, 1))
	s.sqlMock.ExpectQuery(selectUpdatedMockRegexp).WillReturnRows(sqlmock.NewRows(taskColumns).AddRow(taskValues...))
	s.sqlMock.ExpectCommit()

	tasks, err = s.mockedRepository.PollTasks(ctx, []string{"testTask"}, []string{"testQueue"}, 15*time.Second, tasq.OrderingCreatedAtFirst, 1)
	assert.Len(s.T(), tasks, 1)
	assert.Nil(s.T(), err)

	// polling successfully with unknown ordering (-1 is not a defined
	// ordering value; the same query is still expected)
	s.sqlMock.ExpectBegin()
	s.sqlMock.ExpectQuery(selectMockRegexp).WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(taskUUIDBytes))
	s.sqlMock.ExpectExec(updateMockRegexp).WillReturnResult(sqlmock.NewResult(1, 1))
	s.sqlMock.ExpectQuery(selectUpdatedMockRegexp).WillReturnRows(sqlmock.NewRows(taskColumns).AddRow(taskValues...))
	s.sqlMock.ExpectCommit()

	tasks, err = s.mockedRepository.PollTasks(ctx, []string{"testTask"}, []string{"testQueue"}, 15*time.Second, -1, 1)
	assert.Len(s.T(), tasks, 1)
	assert.Nil(s.T(), err)
}
307 |
// TestCleanTasks checks that CleanTasks returns the number of deleted rows.
// It must report an error when the DELETE fails or when the affected-row
// count is unavailable.
func (s *MySQLTestSuite) TestCleanTasks() {
	deleteMockRegexp := regexp.QuoteMeta(`DELETE
FROM
test_tasks
WHERE
status IN (?, ?) AND
created_at <= ?;`)

	// cleaning when DB returns error
	s.sqlMock.ExpectExec(deleteMockRegexp).WillReturnError(errSQL)

	rowsAffected, err := s.mockedRepository.CleanTasks(ctx, time.Hour)
	assert.Zero(s.T(), rowsAffected)
	assert.NotNil(s.T(), err)

	// cleaning when no rows are found: driver.ResultNoRows returns an error
	// from RowsAffected, which CleanTasks must propagate
	s.sqlMock.ExpectExec(deleteMockRegexp).WillReturnResult(driver.ResultNoRows)

	rowsAffected, err = s.mockedRepository.CleanTasks(ctx, time.Hour)
	assert.Equal(s.T(), int64(0), rowsAffected)
	assert.NotNil(s.T(), err)

	// cleaning successful
	s.sqlMock.ExpectExec(deleteMockRegexp).WillReturnResult(sqlmock.NewResult(1, 1))

	rowsAffected, err = s.mockedRepository.CleanTasks(ctx, time.Hour)
	assert.Equal(s.T(), int64(1), rowsAffected)
	assert.Nil(s.T(), err)
}
337 |
// TestRegisterStart drives RegisterStart through its transaction: begin,
// UPDATE of status/started_at, re-SELECT of the task, and commit. Every
// failure, including a re-SELECT with no rows, must return an empty task and
// an error.
func (s *MySQLTestSuite) TestRegisterStart() {
	var (
		updateMockRegexp = regexp.QuoteMeta(`UPDATE
test_tasks
SET
status = ?,
started_at = ?
WHERE
id = ?;`)
		selectMockRegexp = regexp.QuoteMeta(`SELECT *
FROM
test_tasks
WHERE
id = ?;`)
	)

	// beginning the transaction fails
	s.sqlMock.ExpectBegin().WillReturnError(errSQL)

	task, err := s.mockedRepository.RegisterStart(ctx, testTask)
	assert.Empty(s.T(), task)
	assert.NotNil(s.T(), err)

	// registering start when update fails
	s.sqlMock.ExpectBegin()
	s.sqlMock.ExpectExec(updateMockRegexp).WillReturnError(errSQL)
	s.sqlMock.ExpectRollback()

	task, err = s.mockedRepository.RegisterStart(ctx, testTask)
	assert.Empty(s.T(), task)
	assert.NotNil(s.T(), err)

	// registering start when update is successful but select fails
	s.sqlMock.ExpectBegin()
	s.sqlMock.ExpectExec(updateMockRegexp).WillReturnResult(sqlmock.NewResult(1, 1))
	s.sqlMock.ExpectQuery(selectMockRegexp).WillReturnError(errSQL)
	s.sqlMock.ExpectRollback()

	task, err = s.mockedRepository.RegisterStart(ctx, testTask)
	assert.Empty(s.T(), task)
	assert.NotNil(s.T(), err)

	// registering start when commit fails
	s.sqlMock.ExpectBegin()
	s.sqlMock.ExpectExec(updateMockRegexp).WillReturnResult(sqlmock.NewResult(1, 1))
	s.sqlMock.ExpectQuery(selectMockRegexp).WillReturnRows(sqlmock.NewRows(taskColumns).AddRow(taskValues...))
	s.sqlMock.ExpectCommit().WillReturnError(errSQL)

	task, err = s.mockedRepository.RegisterStart(ctx, testTask)
	assert.Empty(s.T(), task)
	assert.NotNil(s.T(), err)

	// registering start when update is successful but select returns no rows
	s.sqlMock.ExpectBegin()
	s.sqlMock.ExpectExec(updateMockRegexp).WillReturnResult(sqlmock.NewResult(1, 1))
	s.sqlMock.ExpectQuery(selectMockRegexp).WillReturnRows(sqlmock.NewRows(taskColumns))
	s.sqlMock.ExpectRollback()

	task, err = s.mockedRepository.RegisterStart(ctx, testTask)
	assert.Empty(s.T(), task)
	assert.NotNil(s.T(), err)

	// registering start successful
	s.sqlMock.ExpectBegin()
	s.sqlMock.ExpectExec(updateMockRegexp).WillReturnResult(sqlmock.NewResult(1, 1))
	s.sqlMock.ExpectQuery(selectMockRegexp).WillReturnRows(sqlmock.NewRows(taskColumns).AddRow(taskValues...))
	s.sqlMock.ExpectCommit()

	task, err = s.mockedRepository.RegisterStart(ctx, testTask)
	assert.NotEmpty(s.T(), task)
	assert.Nil(s.T(), err)
}
410 |
// TestRegisterError drives RegisterError through its transaction: begin,
// UPDATE of last_error, re-SELECT of the task, and commit. Every failure,
// including a re-SELECT with no rows, must return an empty task and an error.
func (s *MySQLTestSuite) TestRegisterError() {
	var (
		updateMockRegexp = regexp.QuoteMeta(`UPDATE
test_tasks
SET
last_error = ?
WHERE
id = ?;`)
		selectMockRegexp = regexp.QuoteMeta(`SELECT *
FROM
test_tasks
WHERE
id = ?;`)
	)

	// beginning the transaction fails
	s.sqlMock.ExpectBegin().WillReturnError(errSQL)

	task, err := s.mockedRepository.RegisterError(ctx, testTask, errTask)
	assert.Empty(s.T(), task)
	assert.NotNil(s.T(), err)

	// registering error when update fails
	s.sqlMock.ExpectBegin()
	s.sqlMock.ExpectExec(updateMockRegexp).WillReturnError(errSQL)
	s.sqlMock.ExpectRollback()

	task, err = s.mockedRepository.RegisterError(ctx, testTask, errTask)
	assert.Empty(s.T(), task)
	assert.NotNil(s.T(), err)

	// registering error when update is successful but select fails
	s.sqlMock.ExpectBegin()
	s.sqlMock.ExpectExec(updateMockRegexp).WillReturnResult(sqlmock.NewResult(1, 1))
	s.sqlMock.ExpectQuery(selectMockRegexp).WillReturnError(errSQL)
	s.sqlMock.ExpectRollback()

	task, err = s.mockedRepository.RegisterError(ctx, testTask, errTask)
	assert.Empty(s.T(), task)
	assert.NotNil(s.T(), err)

	// registering error when update is successful but select returns no rows
	s.sqlMock.ExpectBegin()
	s.sqlMock.ExpectExec(updateMockRegexp).WillReturnResult(sqlmock.NewResult(1, 1))
	s.sqlMock.ExpectQuery(selectMockRegexp).WillReturnRows(sqlmock.NewRows(taskColumns))
	s.sqlMock.ExpectRollback()

	task, err = s.mockedRepository.RegisterError(ctx, testTask, errTask)
	assert.Empty(s.T(), task)
	assert.NotNil(s.T(), err)

	// registering error when commit fails
	s.sqlMock.ExpectBegin()
	s.sqlMock.ExpectExec(updateMockRegexp).WillReturnResult(sqlmock.NewResult(1, 1))
	s.sqlMock.ExpectQuery(selectMockRegexp).WillReturnRows(sqlmock.NewRows(taskColumns).AddRow(taskValues...))
	s.sqlMock.ExpectCommit().WillReturnError(errSQL)

	task, err = s.mockedRepository.RegisterError(ctx, testTask, errTask)
	assert.Empty(s.T(), task)
	assert.NotNil(s.T(), err)

	// registering error successful
	s.sqlMock.ExpectBegin()
	s.sqlMock.ExpectExec(updateMockRegexp).WillReturnResult(sqlmock.NewResult(1, 1))
	s.sqlMock.ExpectQuery(selectMockRegexp).WillReturnRows(sqlmock.NewRows(taskColumns).AddRow(taskValues...))
	s.sqlMock.ExpectCommit()

	task, err = s.mockedRepository.RegisterError(ctx, testTask, errTask)
	assert.NotEmpty(s.T(), task)
	assert.Nil(s.T(), err)
}
482 |
// TestRegisterFinish drives RegisterFinish (with StatusSuccessful) through its
// transaction: begin, UPDATE of status/finished_at, re-SELECT of the task,
// and commit. Every failure, including a re-SELECT with no rows, must return
// an empty task and an error.
func (s *MySQLTestSuite) TestRegisterFinish() {
	var (
		updateMockRegexp = regexp.QuoteMeta(`UPDATE
test_tasks
SET
status = ?,
finished_at = ?
WHERE
id = ?;`)
		selectMockRegexp = regexp.QuoteMeta(`SELECT *
FROM
test_tasks
WHERE
id = ?;`)
	)

	// beginning the transaction fails
	s.sqlMock.ExpectBegin().WillReturnError(errSQL)

	task, err := s.mockedRepository.RegisterFinish(ctx, testTask, tasq.StatusSuccessful)
	assert.Empty(s.T(), task)
	assert.NotNil(s.T(), err)

	// registering success when update fails
	s.sqlMock.ExpectBegin()
	s.sqlMock.ExpectExec(updateMockRegexp).WillReturnError(errSQL)
	s.sqlMock.ExpectRollback()

	task, err = s.mockedRepository.RegisterFinish(ctx, testTask, tasq.StatusSuccessful)
	assert.Empty(s.T(), task)
	assert.NotNil(s.T(), err)

	// registering success when update is successful but select fails
	s.sqlMock.ExpectBegin()
	s.sqlMock.ExpectExec(updateMockRegexp).WillReturnResult(sqlmock.NewResult(1, 1))
	s.sqlMock.ExpectQuery(selectMockRegexp).WillReturnError(errSQL)
	s.sqlMock.ExpectRollback()

	task, err = s.mockedRepository.RegisterFinish(ctx, testTask, tasq.StatusSuccessful)
	assert.Empty(s.T(), task)
	assert.NotNil(s.T(), err)

	// registering success when update is successful but select returns no rows
	s.sqlMock.ExpectBegin()
	s.sqlMock.ExpectExec(updateMockRegexp).WillReturnResult(sqlmock.NewResult(1, 1))
	s.sqlMock.ExpectQuery(selectMockRegexp).WillReturnRows(sqlmock.NewRows(taskColumns))
	s.sqlMock.ExpectRollback()

	task, err = s.mockedRepository.RegisterFinish(ctx, testTask, tasq.StatusSuccessful)
	assert.Empty(s.T(), task)
	assert.NotNil(s.T(), err)

	// registering success when commit fails
	s.sqlMock.ExpectBegin()
	s.sqlMock.ExpectExec(updateMockRegexp).WillReturnResult(sqlmock.NewResult(1, 1))
	s.sqlMock.ExpectQuery(selectMockRegexp).WillReturnRows(sqlmock.NewRows(taskColumns).AddRow(taskValues...))
	s.sqlMock.ExpectCommit().WillReturnError(errSQL)

	task, err = s.mockedRepository.RegisterFinish(ctx, testTask, tasq.StatusSuccessful)
	assert.Empty(s.T(), task)
	assert.NotNil(s.T(), err)

	// registering success successful
	s.sqlMock.ExpectBegin()
	s.sqlMock.ExpectExec(updateMockRegexp).WillReturnResult(sqlmock.NewResult(1, 1))
	s.sqlMock.ExpectQuery(selectMockRegexp).WillReturnRows(sqlmock.NewRows(taskColumns).AddRow(taskValues...))
	s.sqlMock.ExpectCommit()

	task, err = s.mockedRepository.RegisterFinish(ctx, testTask, tasq.StatusSuccessful)
	assert.NotEmpty(s.T(), task)
	assert.Nil(s.T(), err)
}
555 |
// TestSubmitTask drives SubmitTask through its transaction: begin, INSERT of
// the new task, re-SELECT of the inserted row, and commit. Every failure,
// including a re-SELECT with no rows, must return an empty task and an error.
func (s *MySQLTestSuite) TestSubmitTask() {
	var (
		insertMockRegexp = regexp.QuoteMeta(`INSERT INTO
test_tasks
(id, type, args, queue, priority, status, max_receives, created_at, visible_at)
VALUES
(?, ?, ?, ?, ?, ?, ?, ?, ?);`)
		selectMockRegexp = regexp.QuoteMeta(`SELECT *
FROM
test_tasks
WHERE
id = ?;`)
	)

	// beginning the transaction fails
	s.sqlMock.ExpectBegin().WillReturnError(errSQL)

	task, err := s.mockedRepository.SubmitTask(ctx, testTask)
	assert.Empty(s.T(), task)
	assert.NotNil(s.T(), err)

	// submitting when the insert fails
	s.sqlMock.ExpectBegin()
	s.sqlMock.ExpectExec(insertMockRegexp).WillReturnError(errSQL)
	s.sqlMock.ExpectRollback()

	task, err = s.mockedRepository.SubmitTask(ctx, testTask)
	assert.Empty(s.T(), task)
	assert.NotNil(s.T(), err)

	// submitting when the insert succeeds but the select fails
	s.sqlMock.ExpectBegin()
	s.sqlMock.ExpectExec(insertMockRegexp).WillReturnResult(sqlmock.NewResult(1, 1))
	s.sqlMock.ExpectQuery(selectMockRegexp).WillReturnError(errSQL)
	s.sqlMock.ExpectRollback()

	task, err = s.mockedRepository.SubmitTask(ctx, testTask)
	assert.Empty(s.T(), task)
	assert.NotNil(s.T(), err)

	// submitting when the insert succeeds but the select returns no rows
	s.sqlMock.ExpectBegin()
	s.sqlMock.ExpectExec(insertMockRegexp).WillReturnResult(sqlmock.NewResult(1, 1))
	s.sqlMock.ExpectQuery(selectMockRegexp).WillReturnRows(sqlmock.NewRows(taskColumns))
	s.sqlMock.ExpectRollback()

	task, err = s.mockedRepository.SubmitTask(ctx, testTask)
	assert.Empty(s.T(), task)
	assert.NotNil(s.T(), err)

	// submitting when the commit fails
	s.sqlMock.ExpectBegin()
	s.sqlMock.ExpectExec(insertMockRegexp).WillReturnResult(sqlmock.NewResult(1, 1))
	s.sqlMock.ExpectQuery(selectMockRegexp).WillReturnRows(sqlmock.NewRows(taskColumns).AddRow(taskValues...))
	s.sqlMock.ExpectCommit().WillReturnError(errSQL)

	task, err = s.mockedRepository.SubmitTask(ctx, testTask)
	assert.Empty(s.T(), task)
	assert.NotNil(s.T(), err)

	// submitting successfully
	s.sqlMock.ExpectBegin()
	s.sqlMock.ExpectExec(insertMockRegexp).WillReturnResult(sqlmock.NewResult(1, 1))
	s.sqlMock.ExpectQuery(selectMockRegexp).WillReturnRows(sqlmock.NewRows(taskColumns).AddRow(taskValues...))
	s.sqlMock.ExpectCommit()

	task, err = s.mockedRepository.SubmitTask(ctx, testTask)
	assert.NotEmpty(s.T(), task)
	assert.Nil(s.T(), err)
}
626 |
// TestDeleteTask checks both DeleteTask variants. safeDelete=false issues a
// plain DELETE by id. safeDelete=true adds the visible_at/status guard seen
// in deleteSafeDeleteMockRegexp.
func (s *MySQLTestSuite) TestDeleteTask() {
	var (
		deleteMockRegexp = regexp.QuoteMeta(`DELETE
FROM
test_tasks
WHERE
id = ?;`)
		deleteSafeDeleteMockRegexp = regexp.QuoteMeta(`DELETE
FROM
test_tasks
WHERE
id = ? AND
(
(
visible_at <= ?
) OR
(
status IN (?) AND
visible_at > ?
)
);`)
	)

	// deleting task when DB returns error
	s.sqlMock.ExpectExec(deleteMockRegexp).WillReturnError(errSQL)

	err := s.mockedRepository.DeleteTask(ctx, testTask, false)
	assert.NotNil(s.T(), err)

	// deleting task successful
	s.sqlMock.ExpectExec(deleteMockRegexp).WillReturnResult(sqlmock.NewResult(1, 1))

	err = s.mockedRepository.DeleteTask(ctx, testTask, false)
	assert.Nil(s.T(), err)

	// deleting with safeDelete enabled uses the guarded DELETE
	s.sqlMock.ExpectExec(deleteSafeDeleteMockRegexp).WillReturnResult(sqlmock.NewResult(1, 1))

	err = s.mockedRepository.DeleteTask(ctx, testTask, true)
	assert.Nil(s.T(), err)
}
668 |
// TestCountTasks checks that CountTasks filters by status, type and queue.
// It returns 0 with an error when the query fails, and the scanned count
// otherwise.
func (s *MySQLTestSuite) TestCountTasks() {
	selectMockRegexp := regexp.QuoteMeta(`SELECT
COUNT(*)
FROM
test_tasks
WHERE
status IN (?) AND
type IN (?) AND
queue IN (?);`)

	// counting when DB returns error
	s.sqlMock.ExpectQuery(selectMockRegexp).WillReturnError(errSQL)

	count, err := s.mockedRepository.CountTasks(ctx, []tasq.TaskStatus{testTask.Status}, []string{testTask.Type}, []string{testTask.Queue})
	assert.Equal(s.T(), int64(0), count)
	assert.NotNil(s.T(), err)

	// counting successful
	s.sqlMock.ExpectQuery(selectMockRegexp).WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(10))

	count, err = s.mockedRepository.CountTasks(ctx, []tasq.TaskStatus{testTask.Status}, []string{testTask.Type}, []string{testTask.Queue})
	assert.Equal(s.T(), int64(10), count)
	assert.Nil(s.T(), err)
}
693 |
// TestScanTasks checks that ScanTasks filters by status, type and queue with
// an ordering and a limit. It returns an error when the query fails, and the
// scanned tasks otherwise.
func (s *MySQLTestSuite) TestScanTasks() {
	// NOTE(review): as in TestPollTasks, ORDER BY is expected as a bound
	// placeholder, which MySQL treats as a constant rather than a column —
	// confirm the ordering actually takes effect.
	selectMockRegexp := regexp.QuoteMeta(`SELECT
*
FROM
test_tasks
WHERE
status IN (?) AND
type IN (?) AND
queue IN (?)
ORDER BY ?
LIMIT ?;`)

	// scanning when DB returns error
	s.sqlMock.ExpectQuery(selectMockRegexp).WillReturnError(errSQL)

	tasks, err := s.mockedRepository.ScanTasks(ctx, []tasq.TaskStatus{testTask.Status}, []string{testTask.Type}, []string{testTask.Queue}, tasq.OrderingCreatedAtFirst, 10)
	assert.Empty(s.T(), tasks)
	assert.NotNil(s.T(), err)

	// scanning successful
	s.sqlMock.ExpectQuery(selectMockRegexp).WillReturnRows(sqlmock.NewRows(taskColumns).AddRow(taskValues...))

	tasks, err = s.mockedRepository.ScanTasks(ctx, []tasq.TaskStatus{testTask.Status}, []string{testTask.Type}, []string{testTask.Queue}, tasq.OrderingCreatedAtFirst, 10)
	assert.NotEmpty(s.T(), tasks)
	assert.Nil(s.T(), err)
}
720 |
// TestPurgeTasks checks PurgeTasks with an empty type filter, which the
// expected queries show is left out of the WHERE clause. Both the plain and
// the safeDelete-guarded DELETE are covered, as is the error raised when the
// affected-row count is unavailable.
func (s *MySQLTestSuite) TestPurgeTasks() {
	var (
		purgeMockRegexp = regexp.QuoteMeta(`DELETE
FROM
test_tasks
WHERE
status IN (?) AND
queue IN (?);`)
		purgeSafeDeleteMockRegexp = regexp.QuoteMeta(`DELETE
FROM
test_tasks
WHERE
status IN (?) AND
queue IN (?) AND
(
( visible_at <= ? ) OR
(
status IN (?) AND
visible_at > ?
)
);`)
	)

	// purging tasks when DB returns error
	s.sqlMock.ExpectExec(purgeMockRegexp).WillReturnError(errSQL)

	count, err := s.mockedRepository.PurgeTasks(ctx, []tasq.TaskStatus{tasq.StatusFailed}, []string{}, []string{testTaskQueue}, false)
	assert.Equal(s.T(), int64(0), count)
	assert.NotNil(s.T(), err)

	// purging when no rows are found: driver.ResultNoRows makes
	// RowsAffected return an error
	s.sqlMock.ExpectExec(purgeMockRegexp).WillReturnResult(driver.ResultNoRows)

	count, err = s.mockedRepository.PurgeTasks(ctx, []tasq.TaskStatus{tasq.StatusFailed}, []string{}, []string{testTaskQueue}, false)
	assert.Equal(s.T(), int64(0), count)
	assert.NotNil(s.T(), err)

	// purging tasks successful
	s.sqlMock.ExpectExec(purgeMockRegexp).WillReturnResult(sqlmock.NewResult(1, 1))

	count, err = s.mockedRepository.PurgeTasks(ctx, []tasq.TaskStatus{tasq.StatusFailed}, []string{}, []string{testTaskQueue}, false)
	assert.Equal(s.T(), int64(1), count)
	assert.Nil(s.T(), err)

	// purging tasks with safeDelete successful
	s.sqlMock.ExpectExec(purgeSafeDeleteMockRegexp).WillReturnResult(sqlmock.NewResult(1, 1))

	count, err = s.mockedRepository.PurgeTasks(ctx, []tasq.TaskStatus{tasq.StatusFailed}, []string{}, []string{testTaskQueue}, true)
	assert.Equal(s.T(), int64(1), count)
	assert.Nil(s.T(), err)
}
772 |
// TestRequeueTask walks RequeueTask through each failure point of its
// transaction (begin, update, select, empty select, commit) and finally a
// fully successful requeue.
func (s *MySQLTestSuite) TestRequeueTask() {
	var (
		updateMockRegexp = regexp.QuoteMeta(`UPDATE
test_tasks
SET
status = ?
WHERE
id = ?;`)
		selectMockRegexp = regexp.QuoteMeta(`SELECT *
FROM
test_tasks
WHERE
id = ?;`)
	)

	// expectUpdated registers a transaction start followed by a successful update.
	expectUpdated := func() {
		s.sqlMock.ExpectBegin()
		s.sqlMock.ExpectExec(updateMockRegexp).WillReturnResult(sqlmock.NewResult(1, 1))
	}

	// expectTaskRow additionally registers a select that returns the test task.
	expectTaskRow := func() {
		expectUpdated()
		s.sqlMock.ExpectQuery(selectMockRegexp).WillReturnRows(sqlmock.NewRows(taskColumns).AddRow(taskValues...))
	}

	scenarios := []struct {
		name    string
		expect  func()
		wantErr bool
	}{
		{
			name: "beginning the transaction fails",
			expect: func() {
				s.sqlMock.ExpectBegin().WillReturnError(errSQL)
			},
			wantErr: true,
		},
		{
			name: "requeuing when update fails",
			expect: func() {
				s.sqlMock.ExpectBegin()
				s.sqlMock.ExpectExec(updateMockRegexp).WillReturnError(errSQL)
				s.sqlMock.ExpectRollback()
			},
			wantErr: true,
		},
		{
			name: "requeuing when update is successful but select fails",
			expect: func() {
				expectUpdated()
				s.sqlMock.ExpectQuery(selectMockRegexp).WillReturnError(errSQL)
				s.sqlMock.ExpectRollback()
			},
			wantErr: true,
		},
		{
			name: "requeuing when update is successful but select returns no rows",
			expect: func() {
				expectUpdated()
				s.sqlMock.ExpectQuery(selectMockRegexp).WillReturnRows(sqlmock.NewRows(taskColumns))
				s.sqlMock.ExpectRollback()
			},
			wantErr: true,
		},
		{
			name: "requeuing when commit fails",
			expect: func() {
				expectTaskRow()
				s.sqlMock.ExpectCommit().WillReturnError(errSQL)
			},
			wantErr: true,
		},
		{
			name: "requeuing successful",
			expect: func() {
				expectTaskRow()
				s.sqlMock.ExpectCommit()
			},
		},
	}

	// Scenarios run sequentially: each registers its mock expectations right
	// before the RequeueTask call that consumes them.
	for _, sc := range scenarios {
		sc.expect()

		task, err := s.mockedRepository.RequeueTask(ctx, testTask)
		if sc.wantErr {
			assert.Empty(s.T(), task, sc.name)
			assert.NotNil(s.T(), err, sc.name)
		} else {
			assert.NotEmpty(s.T(), task, sc.name)
			assert.Nil(s.T(), err, sc.name)
		}
	}
}
844 |
845 | func (s *MySQLTestSuite) TestGetQueryWithTableName() {
846 | var (
847 | taskUUID = uuid.New()
848 | taskUUIDBytes, _ = taskUUID.MarshalBinary()
849 | )
850 |
851 | mysqlRepository, ok := s.mockedRepository.(*mysql.Repository)
852 | require.True(s.T(), ok)
853 |
854 | assert.Panics(s.T(), func() {
855 | _, _ = mysqlRepository.GetQueryWithTableName("SELECT * FROM {{.tableName}} WHERE id = :taskID:;", map[string]any{
856 | "taskID": taskUUIDBytes,
857 | })
858 | })
859 |
860 | assert.Panics(s.T(), func() {
861 | _, _ = mysqlRepository.GetQueryWithTableName("SELECT * FROM {{.tableName}} WHERE id IN (:taskIDs);", map[string]any{
862 | "taskIDs": [][]byte{},
863 | })
864 | })
865 |
866 | assert.NotPanics(s.T(), func() {
867 | query, args := mysqlRepository.GetQueryWithTableName("SELECT * FROM {{.tableName}} WHERE id = :taskID", map[string]any{
868 | "taskID": taskUUIDBytes,
869 | })
870 | assert.Equal(s.T(), "SELECT * FROM test_tasks WHERE id = ?", query)
871 | assert.Contains(s.T(), args, taskUUIDBytes)
872 | })
873 | }
874 |
875 | func (s *MySQLTestSuite) TestInterpolateSQL() {
876 | params := map[string]any{"tableName": "test_table"}
877 |
878 | // Interpolate SQL successfully
879 | interpolatedSQL := mysql.InterpolateSQL("SELECT * FROM {{.tableName}}", params)
880 | assert.Equal(s.T(), "SELECT * FROM test_table", interpolatedSQL)
881 |
882 | // Fail interpolaing unparseable SQL template
883 | assert.Panics(s.T(), func() {
884 | unparseableTemplateSQL := mysql.InterpolateSQL("SELECT * FROM {{.tableName", params)
885 | assert.Empty(s.T(), unparseableTemplateSQL)
886 | })
887 |
888 | // Fail interpolaing unexecutable SQL template
889 | assert.Panics(s.T(), func() {
890 | unexecutableTemplateSQL := mysql.InterpolateSQL(`SELECT * FROM {{if .tableName eq 1}} {{end}} {{.tableName}}`, params)
891 | assert.Empty(s.T(), unexecutableTemplateSQL)
892 | })
893 | }
894 |
--------------------------------------------------------------------------------
/repository/postgres/export_test.go:
--------------------------------------------------------------------------------
1 | package postgres
2 |
3 | import (
4 | "database/sql"
5 | "database/sql/driver"
6 |
7 | "github.com/greencoda/tasq"
8 | "github.com/jmoiron/sqlx"
9 | )
10 |
// GetTestTaskValues converts task into its postgresTask row representation and
// returns the column values in the same order as the columns of the tasks
// table (id, type, args, queue, priority, status, receive_count, max_receives,
// last_error, created_at, started_at, finished_at, visible_at), so tests can
// feed them to sqlmock rows.
func GetTestTaskValues(task *tasq.Task) []driver.Value {
	pgTask := newFromTask(task)

	return []driver.Value{
		pgTask.ID,
		pgTask.Type,
		pgTask.Args,
		pgTask.Queue,
		pgTask.Priority,
		pgTask.Status,
		pgTask.ReceiveCount,
		pgTask.MaxReceives,
		pgTask.LastError,
		pgTask.CreatedAt,
		pgTask.StartedAt,
		pgTask.FinishedAt,
		pgTask.VisibleAt,
	}
}
30 |
31 | func (d *Repository) PrepareWithTableName(sqlTemplate string) *sqlx.NamedStmt {
32 | return d.prepareWithTableName(sqlTemplate)
33 | }
34 |
35 | func (d *Repository) CloseNamedStmt(stmt closeableStmt) {
36 | d.closeStmt(stmt)
37 | }
38 |
39 | func InterpolateSQL(sql string, params map[string]any) string {
40 | return interpolateSQL(sql, params)
41 | }
42 |
// StringToSQLNullString exposes the unexported stringToSQLNullString to the
// external test package.
func StringToSQLNullString(input *string) (nullString sql.NullString) {
	nullString = stringToSQLNullString(input)

	return nullString
}
46 |
// ParseNullableString exposes the unexported parseNullableString to the
// external test package.
func ParseNullableString(input sql.NullString) (parsed *string) {
	parsed = parseNullableString(input)

	return parsed
}
50 |
--------------------------------------------------------------------------------
/repository/postgres/postgres.go:
--------------------------------------------------------------------------------
1 | // Package postgres provides the implementation of a tasq repository in PostgreSQL
2 | package postgres
3 |
4 | import (
5 | "bytes"
6 | "context"
7 | "database/sql"
8 | "errors"
9 | "fmt"
10 | "strings"
11 | "text/template"
12 | "time"
13 |
14 | "github.com/google/uuid"
15 | "github.com/greencoda/tasq"
16 | "github.com/jmoiron/sqlx"
17 | "github.com/lib/pq"
18 | )
19 |
20 | const driverName = "postgres"
21 |
22 | var (
23 | errUnexpectedDataSourceType = errors.New("unexpected dataSource type")
24 | errFailedToExecuteUpdate = errors.New("failed to execute update query")
25 | errFailedToExecuteDelete = errors.New("failed to execute delete query")
26 | errFailedToExecuteInsert = errors.New("failed to execute insert query")
27 | errFailedToExecuteCreateTable = errors.New("failed to execute create table query")
28 | errFailedToExecuteCreateType = errors.New("failed to execute create type query")
29 | )
30 |
31 | // Repository implements the menthods necessary for tasq to work in PostgreSQL.
32 | type Repository struct {
33 | db *sqlx.DB
34 | statusTypeName string
35 | tableName string
36 | }
37 |
38 | // NewRepository creates a new PostgreSQL Repository instance.
39 | func NewRepository(dataSource any, prefix string) (*Repository, error) {
40 | switch d := dataSource.(type) {
41 | case string:
42 | return newRepositoryFromDSN(d, prefix)
43 | case *sql.DB:
44 | return newRepositoryFromDB(d, prefix)
45 | }
46 |
47 | return nil, fmt.Errorf("%w: %T", errUnexpectedDataSourceType, dataSource)
48 | }
49 |
// newRepositoryFromDSN opens a sqlx handle for dsn with the postgres driver and
// wraps it in a Repository whose table and type names are derived from prefix.
//
// sqlx.Open may only validate its arguments without connecting, but it can
// still fail (e.g. unknown driver or, depending on the driver, a malformed
// DSN). That error used to be discarded, returning a Repository with a nil db
// and a nil error; it is now returned to the caller.
func newRepositoryFromDSN(dsn string, prefix string) (*Repository, error) {
	dbx, err := sqlx.Open(driverName, dsn)
	if err != nil {
		return nil, fmt.Errorf("failed to open database: %w", err)
	}

	return &Repository{
		db:             dbx,
		statusTypeName: statusTypeName(prefix),
		tableName:      tableName(prefix),
	}, nil
}
59 |
// newRepositoryFromDB wraps an existing *sql.DB in a sqlx handle using the
// postgres driver name, with table and type names derived from prefix.
//
// A nil *sql.DB is rejected up front; previously it was accepted and produced
// a Repository that would fail on first use.
func newRepositoryFromDB(db *sql.DB, prefix string) (*Repository, error) {
	if db == nil {
		return nil, fmt.Errorf("%w: nil %T", errUnexpectedDataSourceType, db)
	}

	dbx := sqlx.NewDb(db, driverName)

	return &Repository{
		db:             dbx,
		statusTypeName: statusTypeName(prefix),
		tableName:      tableName(prefix),
	}, nil
}
69 |
// Migrate prepares the database: it first creates the task status enum type
// (if missing) and then the tasks table (if missing). It stops at, and
// returns, the first error.
func (d *Repository) Migrate(ctx context.Context) error {
	steps := []func(context.Context) error{
		d.migrateStatus,
		d.migrateTable,
	}

	for _, step := range steps {
		if err := step(ctx); err != nil {
			return err
		}
	}

	return nil
}
85 |
// PingTasks pushes the visible_at timestamp of the tasks with the given IDs to
// now + visibilityTimeout.
//
// The UPDATE only returns the id column, so the returned tasks carry just their
// IDs. An empty taskIDs slice returns an empty result without touching the DB.
func (d *Repository) PingTasks(ctx context.Context, taskIDs []uuid.UUID, visibilityTimeout time.Duration) ([]*tasq.Task, error) {
	if len(taskIDs) == 0 {
		return []*tasq.Task{}, nil
	}

	pingTime := time.Now()

	stmt := d.prepareWithTableName(`UPDATE
{{.tableName}}
SET
"visible_at" = :visibleAt
WHERE
"id" = ANY(:pingedTaskIDs)
RETURNING id;`)
	defer d.closeStmt(stmt)

	var pinged []*postgresTask

	params := map[string]any{
		"visibleAt":     pingTime.Add(visibilityTimeout),
		"pingedTaskIDs": pq.Array(taskIDs),
	}

	if err := stmt.SelectContext(ctx, &pinged, params); err != nil && !errors.Is(err, sql.ErrNoRows) {
		return []*tasq.Task{}, fmt.Errorf("failed to update tasks: %w", err)
	}

	return postgresTasksToTasks(pinged), nil
}
118 |
// PollTasks claims up to pollLimit tasks of the given types and queues that
// are in one of the tasq.OpenTasks statuses and whose visible_at has passed.
// Claimed tasks get the enqueued status, an incremented receive_count and a
// visible_at of now + visibilityTimeout; they are returned as stored after the
// update.
//
// ordering decides which eligible tasks are claimed first. The ORDER BY clause
// is written into the SQL text rather than bound as a parameter: a bound value
// in ORDER BY is a constant expression to PostgreSQL and sorts nothing. This is
// safe because getOrderingDirectives only returns fixed, hard-coded directives.
//
// NOTE(review): PostgreSQL does not guarantee that UPDATE ... RETURNING rows
// come back in the subquery's ORDER BY order.
func (d *Repository) PollTasks(ctx context.Context, types, queues []string, visibilityTimeout time.Duration, ordering tasq.Ordering, pollLimit int) ([]*tasq.Task, error) {
	// A non-positive limit can never claim anything, and PostgreSQL rejects a
	// negative LIMIT, so skip the round trip.
	if pollLimit <= 0 {
		return []*tasq.Task{}, nil
	}

	var (
		polledTasks []*postgresTask
		pollTime    = time.Now()
		orderBy     = strings.Join(getOrderingDirectives(ordering), ", ")
		sqlTemplate = `UPDATE {{.tableName}} SET
"status" = :status,
"receive_count" = "receive_count" + 1,
"visible_at" = :visibleAt
WHERE
"id" IN (
SELECT
"id" FROM {{.tableName}}
WHERE
"type" = ANY(:pollTypes) AND
"queue" = ANY(:pollQueues) AND
"status" = ANY(:pollStatuses) AND
"visible_at" <= :pollTime
ORDER BY
` + orderBy + `
LIMIT :pollLimit
FOR UPDATE )
RETURNING *;`
		stmt = d.prepareWithTableName(sqlTemplate)
	)

	defer d.closeStmt(stmt)

	err := stmt.SelectContext(ctx, &polledTasks, map[string]any{
		"status":       tasq.StatusEnqueued,
		"visibleAt":    pollTime.Add(visibilityTimeout),
		"pollTypes":    pq.Array(types),
		"pollQueues":   pq.Array(queues),
		"pollStatuses": pq.Array(tasq.GetTaskStatuses(tasq.OpenTasks)),
		"pollTime":     pollTime,
		"pollLimit":    pollLimit,
	})
	if err != nil && !errors.Is(err, sql.ErrNoRows) {
		return []*tasq.Task{}, fmt.Errorf("failed to update tasks: %w", err)
	}

	return postgresTasksToTasks(polledTasks), nil
}
168 |
169 | // CleanTasks removes finished tasks from the queue
170 | // if their creation date is past the supplied duration.
171 | func (d *Repository) CleanTasks(ctx context.Context, cleanAge time.Duration) (int64, error) {
172 | var (
173 | cleanTime = time.Now()
174 | sqlTemplate = `DELETE FROM {{.tableName}}
175 | WHERE
176 | "status" = ANY(:statuses) AND
177 | "created_at" <= :cleanAt;`
178 | stmt = d.prepareWithTableName(sqlTemplate)
179 | )
180 |
181 | defer d.closeStmt(stmt)
182 |
183 | result, err := stmt.ExecContext(ctx, map[string]any{
184 | "statuses": pq.Array(tasq.GetTaskStatuses(tasq.FinishedTasks)),
185 | "cleanAt": cleanTime.Add(-cleanAge),
186 | })
187 | if err != nil {
188 | return 0, fmt.Errorf("failed to delete tasks: %w", err)
189 | }
190 |
191 | rowsAffected, err := result.RowsAffected()
192 | if err != nil {
193 | return 0, fmt.Errorf("failed to get number of affected rows: %w", err)
194 | }
195 |
196 | return rowsAffected, nil
197 | }
198 |
// RegisterStart sets the task's status to in-progress and its started_at to
// the current time, returning the task as stored after the update.
func (d *Repository) RegisterStart(ctx context.Context, task *tasq.Task) (*tasq.Task, error) {
	startedAt := time.Now()

	stmt := d.prepareWithTableName(`UPDATE {{.tableName}} SET
"status" = :status,
"started_at" = :startTime
WHERE
"id" = :taskID
RETURNING *;`)
	defer d.closeStmt(stmt)

	var updated postgresTask

	row := stmt.QueryRowContext(ctx, map[string]any{
		"status":    tasq.StatusInProgress,
		"startTime": startedAt,
		"taskID":    task.ID,
	})
	if err := row.StructScan(&updated); err != nil {
		return nil, fmt.Errorf("%s: %w", errFailedToExecuteUpdate, err)
	}

	return updated.toTask(), nil
}
229 |
// RegisterError stores errTask's message in the task's last_error column and
// returns the task as stored after the update.
//
// NOTE(review): errTask must be non-nil; errTask.Error() is called unguarded.
func (d *Repository) RegisterError(ctx context.Context, task *tasq.Task, errTask error) (*tasq.Task, error) {
	stmt := d.prepareWithTableName(`UPDATE {{.tableName}} SET
"last_error" = :errorMessage
WHERE
"id" = :taskID
RETURNING *;`)
	defer d.closeStmt(stmt)

	var updated postgresTask

	row := stmt.QueryRowContext(ctx, map[string]any{
		"errorMessage": errTask.Error(),
		"taskID":       task.ID,
	})
	if err := row.StructScan(&updated); err != nil {
		return nil, fmt.Errorf("%s: %w", errFailedToExecuteUpdate, err)
	}

	return updated.toTask(), nil
}
256 |
// RegisterFinish sets the task's status to finishStatus and its finished_at to
// the current time, returning the task as stored after the update.
func (d *Repository) RegisterFinish(ctx context.Context, task *tasq.Task, finishStatus tasq.TaskStatus) (*tasq.Task, error) {
	finishedAt := time.Now()

	stmt := d.prepareWithTableName(`UPDATE {{.tableName}} SET
"status" = :status,
"finished_at" = :finishTime
WHERE
"id" = :taskID
RETURNING *;`)
	defer d.closeStmt(stmt)

	var updated postgresTask

	row := stmt.QueryRowContext(ctx, map[string]any{
		"status":     finishStatus,
		"finishTime": finishedAt,
		"taskID":     task.ID,
	})
	if err := row.StructScan(&updated); err != nil {
		return nil, fmt.Errorf("%s: %w", errFailedToExecuteUpdate, err)
	}

	return updated.toTask(), nil
}
287 |
288 | // SubmitTask adds the supplied task to the queue.
289 | func (d *Repository) SubmitTask(ctx context.Context, task *tasq.Task) (*tasq.Task, error) {
290 | var (
291 | postgresTask = newFromTask(task)
292 | sqlTemplate = `INSERT INTO {{.tableName}}
293 | (id, type, args, queue, priority, status, max_receives, created_at, visible_at)
294 | VALUES
295 | (:id, :type, :args, :queue, :priority, :status, :maxReceives, :createdAt, :visibleAt)
296 | RETURNING *;`
297 | stmt = d.prepareWithTableName(sqlTemplate)
298 | )
299 |
300 | defer d.closeStmt(stmt)
301 |
302 | err := stmt.
303 | QueryRowContext(ctx, map[string]any{
304 | "id": postgresTask.ID,
305 | "type": postgresTask.Type,
306 | "args": postgresTask.Args,
307 | "queue": postgresTask.Queue,
308 | "priority": postgresTask.Priority,
309 | "status": postgresTask.Status,
310 | "maxReceives": postgresTask.MaxReceives,
311 | "createdAt": postgresTask.CreatedAt,
312 | "visibleAt": postgresTask.VisibleAt,
313 | }).
314 | StructScan(postgresTask)
315 | if err != nil {
316 | return nil, fmt.Errorf("%s: %w", errFailedToExecuteInsert, err)
317 | }
318 |
319 | return postgresTask.toTask(), nil
320 | }
321 |
// DeleteTask removes the supplied task from the queue.
//
// When safeDelete is true, applySafeDeleteConditions narrows the delete so the
// row is only removed if its visible_at is not in the future, or if it is
// still in the new status.
func (d *Repository) DeleteTask(ctx context.Context, task *tasq.Task, safeDelete bool) error {
	var (
		conditions = []string{
			`"id" = :taskID`,
		}
		parameters = map[string]any{
			"taskID": task.ID,
		}
	)

	if safeDelete {
		d.applySafeDeleteConditions(&conditions, &parameters)
	}

	sqlTemplate := `DELETE FROM {{.tableName}} WHERE ` + strings.Join(conditions, ` AND `) + `;`

	// The prepared statement was previously never closed; release it once the
	// delete has run, as the other statement-based methods do.
	stmt := d.prepareWithTableName(sqlTemplate)
	defer d.closeStmt(stmt)

	if _, err := stmt.ExecContext(ctx, parameters); err != nil {
		return fmt.Errorf("%s: %w", errFailedToExecuteDelete, err)
	}

	return nil
}
346 |
// RequeueTask marks a task as new, so it can be picked up again, and returns
// the task as stored after the update.
func (d *Repository) RequeueTask(ctx context.Context, task *tasq.Task) (*tasq.Task, error) {
	var (
		updatedTask = new(postgresTask)
		sqlTemplate = `UPDATE {{.tableName}} SET
"status" = :status
WHERE
"id" = :taskID
RETURNING *;`
		stmt = d.prepareWithTableName(sqlTemplate)
	)

	// The prepared statement was previously never closed.
	defer d.closeStmt(stmt)

	err := stmt.
		QueryRowContext(ctx, map[string]any{
			"status": tasq.StatusNew,
			"taskID": task.ID,
		}).
		StructScan(updatedTask)
	if err != nil {
		return nil, fmt.Errorf("%s: %w", errFailedToExecuteUpdate, err)
	}

	return updatedTask.toTask(), nil
}
371 |
// CountTasks returns the number of tasks in the queue based on the supplied
// filters; an empty filter slice leaves that column unconstrained.
func (d *Repository) CountTasks(ctx context.Context, taskStatuses []tasq.TaskStatus, taskTypes, queues []string) (int64, error) {
	var (
		count                   int64
		sqlTemplate, parameters = d.buildCountSQLTemplate(taskStatuses, taskTypes, queues)
		stmt                    = d.prepareWithTableName(sqlTemplate)
	)

	// The prepared statement was previously never closed.
	defer d.closeStmt(stmt)

	err := stmt.GetContext(ctx, &count, parameters)
	if err != nil && !errors.Is(err, sql.ErrNoRows) {
		return 0, fmt.Errorf("failed to count tasks: %w", err)
	}

	return count, nil
}
387 |
// ScanTasks returns up to scanLimit tasks in the queue matching the supplied
// filters (an empty filter slice leaves that column unconstrained), sorted
// according to ordering.
func (d *Repository) ScanTasks(ctx context.Context, taskStatuses []tasq.TaskStatus, taskTypes, queues []string, ordering tasq.Ordering, scanLimit int) ([]*tasq.Task, error) {
	var (
		scannedTasks            []*postgresTask
		sqlTemplate, parameters = d.buildScanSQLTemplate(taskStatuses, taskTypes, queues, ordering, scanLimit)
		stmt                    = d.prepareWithTableName(sqlTemplate)
	)

	// The prepared statement was previously never closed.
	defer d.closeStmt(stmt)

	err := stmt.SelectContext(ctx, &scannedTasks, parameters)
	if err != nil && !errors.Is(err, sql.ErrNoRows) {
		return []*tasq.Task{}, fmt.Errorf("failed to scan tasks: %w", err)
	}

	return postgresTasksToTasks(scannedTasks), nil
}
403 |
// PurgeTasks removes all tasks matching the supplied filters and returns the
// number of deleted rows. With safeDelete, rows are additionally restricted by
// applySafeDeleteConditions.
func (d *Repository) PurgeTasks(ctx context.Context, taskStatuses []tasq.TaskStatus, taskTypes, queues []string, safeDelete bool) (int64, error) {
	var (
		sqlTemplate, parameters = d.buildPurgeSQLTemplate(taskStatuses, taskTypes, queues, safeDelete)
		stmt                    = d.prepareWithTableName(sqlTemplate)
	)

	// The prepared statement was previously never closed.
	defer d.closeStmt(stmt)

	// Any Exec error is returned: previously an sql.ErrNoRows error fell
	// through to result.RowsAffected() on a possibly nil result.
	result, err := stmt.ExecContext(ctx, parameters)
	if err != nil {
		return 0, fmt.Errorf("failed to purge tasks: %w", err)
	}

	rowsAffected, err := result.RowsAffected()
	if err != nil {
		return 0, fmt.Errorf("failed to get number of affected rows: %w", err)
	}

	return rowsAffected, nil
}
423 |
// buildCountSQLTemplate returns a COUNT(*) query template over the tasks table,
// with a WHERE clause joining the filter conditions by AND when any apply, and
// the named parameters those conditions reference.
func (d *Repository) buildCountSQLTemplate(taskStatuses []tasq.TaskStatus, taskTypes, queues []string) (string, map[string]any) {
	conditions, parameters := d.buildFilterConditions(taskStatuses, taskTypes, queues)

	var query strings.Builder

	query.WriteString(`SELECT COUNT(*) FROM {{.tableName}}`)

	if len(conditions) != 0 {
		query.WriteString(` WHERE `)
		query.WriteString(strings.Join(conditions, " AND "))
	}

	return query.String(), parameters
}
436 |
// buildScanSQLTemplate returns a SELECT query template over the tasks table
// filtered by the supplied statuses, types and queues, sorted by ordering and
// limited to scanLimit rows, together with its named parameters.
//
// The ORDER BY directives are written into the SQL text: binding them as a
// parameter made PostgreSQL order by a constant value, i.e. not sort at all.
// getOrderingDirectives only yields fixed, hard-coded directives, so this
// interpolation cannot inject caller-controlled SQL.
func (d *Repository) buildScanSQLTemplate(taskStatuses []tasq.TaskStatus, taskTypes, queues []string, ordering tasq.Ordering, scanLimit int) (string, map[string]any) {
	var (
		conditions, parameters = d.buildFilterConditions(taskStatuses, taskTypes, queues)
		sqlTemplate            = `SELECT * FROM {{.tableName}}`
	)

	if len(conditions) > 0 {
		sqlTemplate += ` WHERE ` + strings.Join(conditions, " AND ")
	}

	sqlTemplate += ` ORDER BY ` + strings.Join(getOrderingDirectives(ordering), ", ") + ` LIMIT :limit;`

	parameters["limit"] = scanLimit

	return sqlTemplate, parameters
}
454 |
// buildPurgeSQLTemplate returns a DELETE query template over the tasks table
// filtered by the supplied statuses, types and queues, plus the safe-delete
// conditions when safeDelete is true, together with its named parameters.
// With no conditions at all the template deletes every row.
func (d *Repository) buildPurgeSQLTemplate(taskStatuses []tasq.TaskStatus, taskTypes, queues []string, safeDelete bool) (string, map[string]any) {
	var (
		conditions, parameters = d.buildFilterConditions(taskStatuses, taskTypes, queues)
		sqlTemplate            = `DELETE FROM {{.tableName}}`
	)

	if safeDelete {
		d.applySafeDeleteConditions(&conditions, &parameters)
	}

	if len(conditions) > 0 {
		sqlTemplate += ` WHERE ` + strings.Join(conditions, " AND ")
	}

	return sqlTemplate + `;`, parameters
}
471 |
// applySafeDeleteConditions appends a condition that only matches rows whose
// visible_at is not after the current time, or rows that are in the new status
// (even if their visible_at is still in the future). It sets the :visibleAt
// and :statuses parameters the condition references in *parameters.
func (d *Repository) applySafeDeleteConditions(conditions *[]string, parameters *map[string]any) {
	params := *parameters
	params["visibleAt"] = time.Now()
	params["statuses"] = pq.Array([]tasq.TaskStatus{tasq.StatusNew})

	*conditions = append(*conditions, `(
(
"visible_at" <= :visibleAt
) OR (
"status" = ANY(:statuses) AND
"visible_at" > :visibleAt
)
)`)
}
484 |
485 | func (d *Repository) buildFilterConditions(taskStatuses []tasq.TaskStatus, taskTypes, queues []string) ([]string, map[string]any) {
486 | var (
487 | conditions []string
488 | parameters = make(map[string]any)
489 | )
490 |
491 | if len(taskStatuses) > 0 {
492 | conditions = append(conditions, `"status" = ANY(:filterStatuses)`)
493 | parameters["filterStatuses"] = pq.Array(taskStatuses)
494 | }
495 |
496 | if len(taskTypes) > 0 {
497 | conditions = append(conditions, `"type" = ANY(:filterTypes)`)
498 | parameters["filterTypes"] = pq.Array(taskTypes)
499 | }
500 |
501 | if len(queues) > 0 {
502 | conditions = append(conditions, `"queue" = ANY(:filterQueues)`)
503 | parameters["filterQueues"] = pq.Array(queues)
504 | }
505 |
506 | return conditions, parameters
507 | }
508 |
// migrateStatus creates the task status enum type, with one label per status in
// tasq.AllTasks, unless a type with that name already exists in pg_type.
func (d *Repository) migrateStatus(ctx context.Context) error {
	const sqlTemplate = `DO $$
BEGIN
IF NOT EXISTS (SELECT 1 FROM pg_type WHERE typname = '{{.statusTypeName}}') THEN
CREATE TYPE {{.statusTypeName}} AS ENUM ({{.enumValues}});
END IF;
END$$;`

	enumValues := sliceToPostgreSQLValueList(tasq.GetTaskStatuses(tasq.AllTasks))

	query := interpolateSQL(sqlTemplate, map[string]any{
		"statusTypeName": d.statusTypeName,
		"enumValues":     enumValues,
	})

	if _, err := d.db.ExecContext(ctx, query); err != nil {
		return fmt.Errorf("%s: %w", errFailedToExecuteCreateType, err)
	}

	return nil
}
530 |
531 | func (d *Repository) migrateTable(ctx context.Context) error {
532 | const sqlTemplate = `CREATE TABLE IF NOT EXISTS {{.tableName}} (
533 | "id" UUID NOT NULL PRIMARY KEY,
534 | "type" TEXT NOT NULL,
535 | "args" BYTEA NOT NULL,
536 | "queue" TEXT NOT NULL,
537 | "priority" SMALLINT NOT NULL,
538 | "status" {{.statusTypeName}} NOT NULL,
539 | "receive_count" INTEGER NOT NULL DEFAULT 0,
540 | "max_receives" INTEGER NOT NULL DEFAULT 0,
541 | "last_error" TEXT,
542 | "created_at" TIMESTAMPTZ NOT NULL DEFAULT '0001-01-01 00:00:00.000000',
543 | "started_at" TIMESTAMPTZ,
544 | "finished_at" TIMESTAMPTZ,
545 | "visible_at" TIMESTAMPTZ NOT NULL DEFAULT '0001-01-01 00:00:00.000000'
546 | );`
547 |
548 | query := interpolateSQL(sqlTemplate, map[string]any{
549 | "tableName": d.tableName,
550 | "statusTypeName": d.statusTypeName,
551 | })
552 |
553 | _, err := d.db.ExecContext(ctx, query)
554 | if err != nil {
555 | return fmt.Errorf("%s: %w", errFailedToExecuteCreateTable, err)
556 | }
557 |
558 | return nil
559 | }
560 |
561 | func (d *Repository) prepareWithTableName(sqlTemplate string) *sqlx.NamedStmt {
562 | query := interpolateSQL(sqlTemplate, map[string]any{
563 | "tableName": d.tableName,
564 | })
565 |
566 | namedStmt, err := d.db.PrepareNamed(query)
567 | if err != nil {
568 | panic(err)
569 | }
570 |
571 | return namedStmt
572 | }
573 |
574 | type closeableStmt interface {
575 | Close() error
576 | }
577 |
578 | func (d *Repository) closeStmt(stmt closeableStmt) {
579 | if err := stmt.Close(); err != nil {
580 | panic(err)
581 | }
582 | }
583 |
584 | func getOrderingDirectives(ordering tasq.Ordering) []string {
585 | var (
586 | OrderingCreatedAtFirst = []string{"created_at ASC", "priority DESC"}
587 | OrderingPriorityFirst = []string{"priority DESC", "created_at ASC"}
588 | )
589 |
590 | if orderingDirectives, ok := map[tasq.Ordering][]string{
591 | tasq.OrderingCreatedAtFirst: OrderingCreatedAtFirst,
592 | tasq.OrderingPriorityFirst: OrderingPriorityFirst,
593 | }[ordering]; ok {
594 | return orderingDirectives
595 | }
596 |
597 | return OrderingCreatedAtFirst
598 | }
599 |
// sliceToPostgreSQLValueList renders slice as a comma-separated list of
// single-quoted SQL string literals, e.g. []string{"a", "b"} -> 'a','b'.
//
// Each element is formatted with fmt.Sprint, and any single quote inside it is
// doubled ('' is PostgreSQL's escape for a quote in a string literal) so the
// generated literal stays well-formed. An empty slice yields ''.
func sliceToPostgreSQLValueList[T any](slice []T) string {
	quoted := make([]string, 0, len(slice))

	for _, item := range slice {
		quoted = append(quoted, strings.ReplaceAll(fmt.Sprint(item), "'", "''"))
	}

	return "'" + strings.Join(quoted, "','") + "'"
}
609 |
// statusTypeName returns the name of the task status enum type: "task_status",
// or "<prefix>_task_status" when prefix is non-empty.
func statusTypeName(prefix string) string {
	const baseName = "task_status"

	if prefix == "" {
		return baseName
	}

	return prefix + "_" + baseName
}
619 |
// tableName returns the name of the tasks table: "tasks", or "<prefix>_tasks"
// when prefix is non-empty.
func tableName(prefix string) string {
	const baseName = "tasks"

	if prefix == "" {
		return baseName
	}

	return prefix + "_" + baseName
}
629 |
// interpolateSQL renders sqlTemplate as a text/template with params as its
// data. It panics if the template fails to parse or to execute.
func interpolateSQL(sqlTemplate string, params map[string]any) string {
	tmpl := template.Must(template.New("sql").Parse(sqlTemplate))

	var rendered bytes.Buffer
	if err := tmpl.Execute(&rendered, params); err != nil {
		panic(err)
	}

	return rendered.String()
}
645 |
--------------------------------------------------------------------------------
/repository/postgres/postgresTask.go:
--------------------------------------------------------------------------------
1 | package postgres
2 |
3 | import (
4 | "database/sql"
5 | "time"
6 |
7 | "github.com/google/uuid"
8 | "github.com/greencoda/tasq"
9 | )
10 |
// postgresTask is the row representation of a tasq.Task in the PostgreSQL
// tasks table; each db tag names the column sqlx maps the field to.
type postgresTask struct {
	ID           uuid.UUID       `db:"id"`
	Type         string          `db:"type"`
	Args         []byte          `db:"args"`
	Queue        string          `db:"queue"`
	Priority     int16           `db:"priority"`
	Status       tasq.TaskStatus `db:"status"`
	ReceiveCount int32           `db:"receive_count"`
	MaxReceives  int32           `db:"max_receives"`
	// LastError maps the nullable last_error column.
	LastError sql.NullString `db:"last_error"`
	CreatedAt time.Time      `db:"created_at"`
	// StartedAt and FinishedAt map nullable columns; nil means NULL.
	StartedAt  *time.Time `db:"started_at"`
	FinishedAt *time.Time `db:"finished_at"`
	VisibleAt  time.Time  `db:"visible_at"`
}
26 |
27 | func newFromTask(task *tasq.Task) *postgresTask {
28 | return &postgresTask{
29 | ID: task.ID,
30 | Type: task.Type,
31 | Args: task.Args,
32 | Queue: task.Queue,
33 | Priority: task.Priority,
34 | Status: task.Status,
35 | ReceiveCount: task.ReceiveCount,
36 | MaxReceives: task.MaxReceives,
37 | LastError: stringToSQLNullString(task.LastError),
38 | CreatedAt: task.CreatedAt,
39 | StartedAt: task.StartedAt,
40 | FinishedAt: task.FinishedAt,
41 | VisibleAt: task.VisibleAt,
42 | }
43 | }
44 |
45 | func (t *postgresTask) toTask() *tasq.Task {
46 | return &tasq.Task{
47 | ID: t.ID,
48 | Type: t.Type,
49 | Args: t.Args,
50 | Queue: t.Queue,
51 | Priority: t.Priority,
52 | Status: t.Status,
53 | ReceiveCount: t.ReceiveCount,
54 | MaxReceives: t.MaxReceives,
55 | LastError: parseNullableString(t.LastError),
56 | CreatedAt: t.CreatedAt,
57 | StartedAt: t.StartedAt,
58 | FinishedAt: t.FinishedAt,
59 | VisibleAt: t.VisibleAt,
60 | }
61 | }
62 |
63 | func postgresTasksToTasks(postgresTasks []*postgresTask) []*tasq.Task {
64 | tasks := make([]*tasq.Task, len(postgresTasks))
65 |
66 | for i, postgresTask := range postgresTasks {
67 | tasks[i] = postgresTask.toTask()
68 | }
69 |
70 | return tasks
71 | }
72 |
73 | func stringToSQLNullString(input *string) sql.NullString {
74 | if input == nil {
75 | return sql.NullString{
76 | String: "",
77 | Valid: false,
78 | }
79 | }
80 |
81 | return sql.NullString{
82 | String: *input,
83 | Valid: true,
84 | }
85 | }
86 |
// parseNullableString converts a sql.NullString into an optional string: NULL
// becomes nil, a valid value a pointer to a copy of its string.
func parseNullableString(input sql.NullString) *string {
	if input.Valid {
		value := input.String

		return &value
	}

	return nil
}
94 |
--------------------------------------------------------------------------------
/repository/postgres/postgresTask_test.go:
--------------------------------------------------------------------------------
1 | package postgres_test
2 |
3 | import (
4 | "database/sql"
5 | "testing"
6 |
7 | "github.com/greencoda/tasq/repository/postgres"
8 | "github.com/stretchr/testify/assert"
9 | )
10 |
11 | func TestStringToSQLNullString(t *testing.T) {
12 | t.Parallel()
13 |
14 | var (
15 | emptyInput = ""
16 | nonEmptyInput = "test"
17 | nilInput *string
18 | )
19 |
20 | sqlNullString := postgres.StringToSQLNullString(&emptyInput)
21 | assert.Equal(t, sql.NullString{
22 | String: emptyInput,
23 | Valid: true,
24 | }, sqlNullString)
25 |
26 | sqlNullString = postgres.StringToSQLNullString(&nonEmptyInput)
27 | assert.Equal(t, sql.NullString{
28 | String: nonEmptyInput,
29 | Valid: true,
30 | }, sqlNullString)
31 |
32 | sqlNullString = postgres.StringToSQLNullString(nilInput)
33 | assert.Equal(t, sql.NullString{
34 | String: "",
35 | Valid: false,
36 | }, sqlNullString)
37 | }
38 |
// TestParseNullableString checks that valid NullStrings (empty or not) become
// pointers to their string and an invalid one becomes nil.
//
// The pointer is only dereferenced when NotNil passed, so a regression fails
// the test instead of panicking, and Equal now takes expected before actual.
func TestParseNullableString(t *testing.T) {
	t.Parallel()

	var (
		emptyInput = sql.NullString{
			String: "",
			Valid:  true,
		}
		nonEmptyInput = sql.NullString{
			String: "test",
			Valid:  true,
		}
		nilInput = sql.NullString{
			String: "",
			Valid:  false,
		}
	)

	for _, input := range []sql.NullString{emptyInput, nonEmptyInput} {
		output := postgres.ParseNullableString(input)
		if assert.NotNil(t, output) {
			assert.Equal(t, input.String, *output)
		}
	}

	assert.Nil(t, postgres.ParseNullableString(nilInput))
}
68 |
--------------------------------------------------------------------------------
/repository/postgres/postgres_test.go:
--------------------------------------------------------------------------------
1 | package postgres_test
2 |
3 | import (
4 | "context"
5 | "database/sql"
6 | "database/sql/driver"
7 | "errors"
8 | "regexp"
9 | "testing"
10 | "time"
11 |
12 | "github.com/DATA-DOG/go-sqlmock"
13 | "github.com/google/uuid"
14 | "github.com/greencoda/tasq"
15 | "github.com/greencoda/tasq/repository/postgres"
16 | "github.com/stretchr/testify/assert"
17 | "github.com/stretchr/testify/require"
18 | "github.com/stretchr/testify/suite"
19 | )
20 |
21 | var (
22 | ctx = context.Background()
23 | testTask = getStartedTestTask()
24 | taskColumns = []string{
25 | "id",
26 | "type",
27 | "args",
28 | "queue",
29 | "priority",
30 | "status",
31 | "receive_count",
32 | "max_receives",
33 | "last_error",
34 | "created_at",
35 | "started_at",
36 | "finished_at",
37 | "visible_at",
38 | }
39 | testTaskType = "testTask"
40 | testTaskQueue = "testQueue"
41 | taskValues = postgres.GetTestTaskValues(testTask)
42 | errSQL = errors.New("sql error")
43 | errTask = errors.New("task error")
44 | )
45 |
46 | func getStartedTestTask() *tasq.Task {
47 | var (
48 | testTask, _ = tasq.NewTask(testTaskType, true, testTaskQueue, 100, 5)
49 | startTime = testTask.CreatedAt.Add(time.Second)
50 | )
51 |
52 | testTask.StartedAt = &startTime
53 |
54 | return testTask
55 | }
56 |
57 | type PostgresTestSuite struct {
58 | suite.Suite
59 | db *sql.DB
60 | sqlMock sqlmock.Sqlmock
61 | mockedRepository tasq.IRepository
62 | }
63 |
64 | func TestTaskTestSuite(t *testing.T) {
65 | t.Parallel()
66 |
67 | suite.Run(t, new(PostgresTestSuite))
68 | }
69 |
70 | func (s *PostgresTestSuite) SetupTest() {
71 | var err error
72 |
73 | s.db, s.sqlMock, err = sqlmock.New()
74 | require.Nil(s.T(), err)
75 |
76 | s.mockedRepository, err = postgres.NewRepository(s.db, "test")
77 | require.NotNil(s.T(), s.mockedRepository)
78 | require.Nil(s.T(), err)
79 | }
80 |
81 | func (s *PostgresTestSuite) TestNewRepository() {
82 | // providing the datasource as *sql.DB
83 | repository, err := postgres.NewRepository(s.db, "test")
84 | assert.NotNil(s.T(), repository)
85 | assert.Nil(s.T(), err)
86 |
87 | // providing the datasource as *sql.DB with no prefix
88 | repository, err = postgres.NewRepository(s.db, "")
89 | assert.NotNil(s.T(), repository)
90 | assert.Nil(s.T(), err)
91 |
92 | // providing the datasource as dsn string
93 | repository, err = postgres.NewRepository("testDSN", "test")
94 | assert.NotNil(s.T(), repository)
95 | assert.Nil(s.T(), err)
96 |
97 | // providing the datasource as unknown datasource type
98 | repository, err = postgres.NewRepository(false, "test")
99 | assert.Nil(s.T(), repository)
100 | assert.NotNil(s.T(), err)
101 | }
102 |
103 | func (s *PostgresTestSuite) TestMigrate() {
104 | // First try - creating the task_status type fails
105 | s.sqlMock.ExpectExec(`CREATE TYPE test_task_status AS ENUM`).WillReturnError(errSQL)
106 |
107 | err := s.mockedRepository.Migrate(ctx)
108 | assert.NotNil(s.T(), err)
109 |
110 | // Second try - creating the tasks table fails
111 | s.sqlMock.ExpectExec(`CREATE TYPE test_task_status AS ENUM`).WillReturnResult(sqlmock.NewResult(1, 1))
112 | s.sqlMock.ExpectExec(`CREATE TABLE IF NOT EXISTS test_tasks`).WillReturnError(errSQL)
113 |
114 | err = s.mockedRepository.Migrate(ctx)
115 | assert.NotNil(s.T(), err)
116 |
117 | // Third try - migration succeeds
118 | s.sqlMock.ExpectExec(`CREATE TYPE test_task_status AS ENUM`).WillReturnResult(sqlmock.NewResult(1, 1))
119 | s.sqlMock.ExpectExec(`CREATE TABLE IF NOT EXISTS test_tasks`).WillReturnResult(sqlmock.NewResult(1, 1))
120 |
121 | err = s.mockedRepository.Migrate(ctx)
122 | assert.Nil(s.T(), err)
123 | }
124 |
125 | func (s *PostgresTestSuite) TestPingTasks() {
126 | var (
127 | taskUUID = uuid.New()
128 | stmtMockRegexp = regexp.QuoteMeta(`UPDATE test_tasks SET "visible_at" = $1 WHERE "id" = ANY($2) RETURNING id;`)
129 | )
130 | // pinging empty tasklist
131 | tasks, err := s.mockedRepository.PingTasks(ctx, []uuid.UUID{}, 15*time.Second)
132 | assert.Len(s.T(), tasks, 0)
133 | assert.Nil(s.T(), err)
134 |
135 | // pinging when stmt preparation returns an error
136 | s.sqlMock.ExpectPrepare(stmtMockRegexp).WillReturnError(errSQL)
137 |
138 | assert.PanicsWithError(s.T(), errSQL.Error(), func() {
139 | _, _ = s.mockedRepository.PingTasks(ctx, []uuid.UUID{taskUUID}, 15*time.Second)
140 | })
141 |
142 | // pinging when DB returns no rows
143 | s.sqlMock.ExpectPrepare(stmtMockRegexp).ExpectQuery().WillReturnError(errSQL)
144 |
145 | tasks, err = s.mockedRepository.PingTasks(ctx, []uuid.UUID{taskUUID}, 15*time.Second)
146 | assert.Len(s.T(), tasks, 0)
147 | assert.NotNil(s.T(), err)
148 |
149 | // pinging existing task
150 | s.sqlMock.ExpectPrepare(stmtMockRegexp).ExpectQuery().WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(taskUUID))
151 |
152 | tasks, err = s.mockedRepository.PingTasks(ctx, []uuid.UUID{taskUUID}, 15*time.Second)
153 | assert.Len(s.T(), tasks, 1)
154 | assert.Nil(s.T(), err)
155 | }
156 |
157 | func (s *PostgresTestSuite) TestPollTasks() {
158 | stmtMockRegexp := regexp.QuoteMeta(`UPDATE test_tasks SET
159 | "status" = $1,
160 | "receive_count" = "receive_count" + 1,
161 | "visible_at" = $2
162 | WHERE "id" IN (
163 | SELECT
164 | "id"
165 | FROM test_tasks
166 | WHERE "type" = ANY($3)
167 | AND "queue" = ANY($4)
168 | AND "status" = ANY($5)
169 | AND "visible_at" <= $6
170 | ORDER BY $7
171 | LIMIT $8
172 | FOR UPDATE
173 | ) RETURNING *;`)
174 |
175 | // polling with 0 limit
176 | tasks, err := s.mockedRepository.PollTasks(ctx, []string{testTaskType}, []string{"testQueue"}, 15*time.Second, tasq.OrderingCreatedAtFirst, 0)
177 | assert.Len(s.T(), tasks, 0)
178 | assert.Nil(s.T(), err)
179 |
180 | // polling when DB returns no rows
181 | s.sqlMock.ExpectPrepare(stmtMockRegexp).
182 | ExpectQuery().
183 | WillReturnError(sql.ErrNoRows)
184 |
185 | tasks, err = s.mockedRepository.PollTasks(ctx, []string{testTaskType}, []string{"testQueue"}, 15*time.Second, tasq.OrderingCreatedAtFirst, 1)
186 | assert.Len(s.T(), tasks, 0)
187 | assert.Nil(s.T(), err)
188 |
189 | // polling when DB returns error
190 | s.sqlMock.ExpectPrepare(stmtMockRegexp).
191 | ExpectQuery().
192 | WillReturnError(errSQL)
193 |
194 | tasks, err = s.mockedRepository.PollTasks(ctx, []string{testTaskType}, []string{"testQueue"}, 15*time.Second, tasq.OrderingCreatedAtFirst, 1)
195 | assert.Len(s.T(), tasks, 0)
196 | assert.NotNil(s.T(), err)
197 |
198 | // polling for existing tasks
199 | s.sqlMock.ExpectPrepare(stmtMockRegexp).
200 | ExpectQuery().
201 | WillReturnRows(sqlmock.NewRows(taskColumns).AddRow(taskValues...))
202 |
203 | tasks, err = s.mockedRepository.PollTasks(ctx, []string{testTaskType}, []string{"testQueue"}, 15*time.Second, tasq.OrderingCreatedAtFirst, 1)
204 | assert.Len(s.T(), tasks, 1)
205 | assert.Nil(s.T(), err)
206 |
207 | // polling for existing tasks with unknown ordering
208 | s.sqlMock.ExpectPrepare(stmtMockRegexp).
209 | ExpectQuery().
210 | WillReturnRows(sqlmock.NewRows(taskColumns).AddRow(taskValues...))
211 |
212 | tasks, err = s.mockedRepository.PollTasks(ctx, []string{testTaskType}, []string{"testQueue"}, 15*time.Second, -1, 1)
213 | assert.Len(s.T(), tasks, 1)
214 | assert.Nil(s.T(), err)
215 | }
216 |
217 | func (s *PostgresTestSuite) TestCleanTasks() {
218 | stmtMockRegexp := regexp.QuoteMeta(`DELETE FROM test_tasks WHERE "status" = ANY($1) AND "created_at" <= $2;`)
219 |
220 | // cleaning when DB returns error
221 | s.sqlMock.ExpectPrepare(stmtMockRegexp).ExpectExec().WillReturnError(errSQL)
222 |
223 | rowsAffected, err := s.mockedRepository.CleanTasks(ctx, time.Hour)
224 | assert.Zero(s.T(), rowsAffected)
225 | assert.NotNil(s.T(), err)
226 |
227 | // cleaning when no rows are found
228 | s.sqlMock.ExpectPrepare(stmtMockRegexp).ExpectExec().WillReturnResult(driver.ResultNoRows)
229 |
230 | rowsAffected, err = s.mockedRepository.CleanTasks(ctx, time.Hour)
231 | assert.Equal(s.T(), int64(0), rowsAffected)
232 | assert.NotNil(s.T(), err)
233 |
234 | // cleaning successful
235 | s.sqlMock.ExpectPrepare(stmtMockRegexp).ExpectExec().WillReturnResult(sqlmock.NewResult(1, 1))
236 |
237 | rowsAffected, err = s.mockedRepository.CleanTasks(ctx, time.Hour)
238 | assert.Equal(s.T(), int64(1), rowsAffected)
239 | assert.Nil(s.T(), err)
240 | }
241 |
242 | func (s *PostgresTestSuite) TestRegisterStart() {
243 | stmtMockRegexp := regexp.QuoteMeta(`UPDATE test_tasks SET "status" = $1, "started_at" = $2 WHERE "id" = $3 RETURNING *;`)
244 |
245 | // registering start when DB returns error
246 | s.sqlMock.ExpectPrepare(stmtMockRegexp).ExpectQuery().WillReturnError(errSQL)
247 |
248 | task, err := s.mockedRepository.RegisterStart(ctx, testTask)
249 | assert.Empty(s.T(), task)
250 | assert.NotNil(s.T(), err)
251 |
252 | // registering start successful
253 | s.sqlMock.ExpectPrepare(stmtMockRegexp).ExpectQuery().WillReturnRows(sqlmock.NewRows(taskColumns).AddRow(taskValues...))
254 |
255 | task, err = s.mockedRepository.RegisterStart(ctx, testTask)
256 | assert.NotEmpty(s.T(), task)
257 | assert.Nil(s.T(), err)
258 | }
259 |
260 | func (s *PostgresTestSuite) TestRegisterError() {
261 | stmtMockRegexp := regexp.QuoteMeta(`UPDATE test_tasks SET "last_error" = $1 WHERE "id" = $2 RETURNING *;`)
262 |
263 | // registering error when DB returns error
264 | s.sqlMock.ExpectPrepare(stmtMockRegexp).ExpectQuery().WillReturnError(errSQL)
265 |
266 | task, err := s.mockedRepository.RegisterError(ctx, testTask, errTask)
267 | assert.Empty(s.T(), task)
268 | assert.NotNil(s.T(), err)
269 |
270 | // registering error successful
271 | s.sqlMock.ExpectPrepare(stmtMockRegexp).ExpectQuery().WillReturnRows(sqlmock.NewRows(taskColumns).AddRow(taskValues...))
272 |
273 | task, err = s.mockedRepository.RegisterError(ctx, testTask, errTask)
274 | assert.NotEmpty(s.T(), task)
275 | assert.Nil(s.T(), err)
276 | }
277 |
278 | func (s *PostgresTestSuite) TestRegisterFinish() {
279 | stmtMockRegexp := regexp.QuoteMeta(`UPDATE test_tasks SET "status" = $1, "finished_at" = $2 WHERE "id" = $3 RETURNING *;`)
280 |
281 | // registering failure when DB returns error
282 | s.sqlMock.ExpectPrepare(stmtMockRegexp).ExpectQuery().WillReturnError(errSQL)
283 |
284 | task, err := s.mockedRepository.RegisterFinish(ctx, testTask, tasq.StatusSuccessful)
285 | assert.Empty(s.T(), task)
286 | assert.NotNil(s.T(), err)
287 |
288 | // registering failure successful
289 | s.sqlMock.ExpectPrepare(stmtMockRegexp).ExpectQuery().WillReturnRows(sqlmock.NewRows(taskColumns).AddRow(taskValues...))
290 |
291 | task, err = s.mockedRepository.RegisterFinish(ctx, testTask, tasq.StatusSuccessful)
292 | assert.NotEmpty(s.T(), task)
293 | assert.Nil(s.T(), err)
294 | }
295 |
296 | func (s *PostgresTestSuite) TestSubmitTask() {
297 | stmtMockRegexp := regexp.QuoteMeta(`INSERT INTO test_tasks (id, type, args, queue, priority, status, max_receives, created_at, visible_at) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) RETURNING *;`)
298 |
299 | // submitting task when DB returns error
300 | s.sqlMock.ExpectPrepare(stmtMockRegexp).ExpectQuery().WillReturnError(errSQL)
301 |
302 | task, err := s.mockedRepository.SubmitTask(ctx, testTask)
303 | assert.Empty(s.T(), task)
304 | assert.NotNil(s.T(), err)
305 |
306 | // submitting task successful
307 | s.sqlMock.ExpectPrepare(stmtMockRegexp).ExpectQuery().WillReturnRows(sqlmock.NewRows(taskColumns).AddRow(taskValues...))
308 |
309 | task, err = s.mockedRepository.SubmitTask(ctx, testTask)
310 | assert.NotEmpty(s.T(), task)
311 | assert.Nil(s.T(), err)
312 | }
313 |
314 | func (s *PostgresTestSuite) TestDeleteTask() {
315 | var (
316 | stmtMockRegexp = regexp.QuoteMeta(`DELETE
317 | FROM
318 | test_tasks
319 | WHERE
320 | "id" = $1;`)
321 | stmtInvisibleMockRegexp = regexp.QuoteMeta(`DELETE
322 | FROM
323 | test_tasks
324 | WHERE
325 | "id" = $1 AND
326 | (
327 | (
328 | "visible_at" <= $2
329 | ) OR
330 | (
331 | "status" = ANY($3) AND
332 | "visible_at" > $4
333 | )
334 | );`)
335 | )
336 |
337 | // deleting task when DB returns error
338 | s.sqlMock.ExpectPrepare(stmtMockRegexp).ExpectExec().WillReturnError(errSQL)
339 |
340 | err := s.mockedRepository.DeleteTask(ctx, testTask, false)
341 | assert.NotNil(s.T(), err)
342 |
343 | // deleting task successful
344 | s.sqlMock.ExpectPrepare(stmtMockRegexp).ExpectExec().WillReturnResult(sqlmock.NewResult(1, 1))
345 |
346 | err = s.mockedRepository.DeleteTask(ctx, testTask, false)
347 | assert.Nil(s.T(), err)
348 |
349 | // deleting invisible task successful
350 | s.sqlMock.ExpectPrepare(stmtInvisibleMockRegexp).ExpectExec().WillReturnResult(sqlmock.NewResult(1, 1))
351 |
352 | err = s.mockedRepository.DeleteTask(ctx, testTask, true)
353 | assert.Nil(s.T(), err)
354 | }
355 |
356 | func (s *PostgresTestSuite) TestRequeueTask() {
357 | stmtMockRegexp := regexp.QuoteMeta(`UPDATE test_tasks SET "status" = $1 WHERE "id" = $2 RETURNING *;`)
358 |
359 | // requeuing task when DB returns error
360 | s.sqlMock.ExpectPrepare(stmtMockRegexp).ExpectQuery().WillReturnError(errSQL)
361 |
362 | task, err := s.mockedRepository.RequeueTask(ctx, testTask)
363 | assert.Empty(s.T(), task)
364 | assert.NotNil(s.T(), err)
365 |
366 | // requeuing task successful
367 | s.sqlMock.ExpectPrepare(stmtMockRegexp).ExpectQuery().WillReturnRows(sqlmock.NewRows(taskColumns).AddRow(taskValues...))
368 |
369 | task, err = s.mockedRepository.RequeueTask(ctx, testTask)
370 | assert.NotEmpty(s.T(), task)
371 | assert.Nil(s.T(), err)
372 | }
373 |
374 | func (s *PostgresTestSuite) TestCountTasks() {
375 | stmtMockRegexp := regexp.QuoteMeta(`SELECT COUNT(*) FROM test_tasks WHERE "status" = ANY($1) AND "type" = ANY($2) AND "queue" = ANY($3)`)
376 |
377 | // counting tasks when DB returns error
378 | s.sqlMock.ExpectPrepare(stmtMockRegexp).ExpectQuery().WillReturnError(errSQL)
379 |
380 | count, err := s.mockedRepository.CountTasks(ctx, []tasq.TaskStatus{tasq.StatusNew}, []string{"test"}, []string{"test"})
381 | assert.Equal(s.T(), int64(0), count)
382 | assert.NotNil(s.T(), err)
383 |
384 | // counting tasks successful
385 | s.sqlMock.ExpectPrepare(stmtMockRegexp).ExpectQuery().WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(10))
386 |
387 | count, err = s.mockedRepository.CountTasks(ctx, []tasq.TaskStatus{tasq.StatusNew}, []string{"test"}, []string{"test"})
388 | assert.Equal(s.T(), int64(10), count)
389 | assert.Nil(s.T(), err)
390 | }
391 |
392 | func (s *PostgresTestSuite) TestScanTasks() {
393 | stmtMockRegexp := regexp.QuoteMeta(`SELECT * FROM test_tasks WHERE "status" = ANY($1) AND "type" = ANY($2) AND "queue" = ANY($3) ORDER BY $4 LIMIT $5;`)
394 |
395 | // scanning tasks when DB returns error
396 | s.sqlMock.ExpectPrepare(stmtMockRegexp).ExpectQuery().WillReturnError(errSQL)
397 |
398 | tasks, err := s.mockedRepository.ScanTasks(ctx, []tasq.TaskStatus{tasq.StatusNew}, []string{"test"}, []string{"test"}, 0, 10)
399 | assert.Empty(s.T(), tasks)
400 | assert.NotNil(s.T(), err)
401 |
402 | // scanning tasks successful
403 | s.sqlMock.ExpectPrepare(stmtMockRegexp).ExpectQuery().WillReturnRows(sqlmock.NewRows(taskColumns).AddRow(taskValues...))
404 |
405 | tasks, err = s.mockedRepository.ScanTasks(ctx, []tasq.TaskStatus{tasq.StatusNew}, []string{"test"}, []string{"test"}, 0, 10)
406 | assert.NotEmpty(s.T(), tasks)
407 | assert.Nil(s.T(), err)
408 | }
409 |
410 | func (s *PostgresTestSuite) TestPurgeTasks() {
411 | var (
412 | stmtMockRegexp = regexp.QuoteMeta(`DELETE
413 | FROM
414 | test_tasks
415 | WHERE
416 | "status" = ANY($1) AND
417 | "queue" = ANY($2);`)
418 | stmtSafeDeleteMockRegexp = regexp.QuoteMeta(`DELETE
419 | FROM
420 | test_tasks
421 | WHERE
422 | "status" = ANY($1) AND
423 | "queue" = ANY($2) AND
424 | (
425 | ( "visible_at" <= $3 ) OR
426 | ( "status" = ANY($4) AND "visible_at" > $5 )
427 | );`)
428 | )
429 |
430 | // purging tasks when DB returns error
431 | s.sqlMock.ExpectPrepare(stmtMockRegexp).ExpectExec().WillReturnError(errSQL)
432 |
433 | count, err := s.mockedRepository.PurgeTasks(ctx, []tasq.TaskStatus{tasq.StatusFailed}, []string{}, []string{testTaskQueue}, false)
434 | assert.Equal(s.T(), int64(0), count)
435 | assert.NotNil(s.T(), err)
436 |
437 | // purging when no rows are found
438 | s.sqlMock.ExpectPrepare(stmtMockRegexp).ExpectExec().WillReturnResult(driver.ResultNoRows)
439 |
440 | count, err = s.mockedRepository.PurgeTasks(ctx, []tasq.TaskStatus{tasq.StatusFailed}, []string{}, []string{testTaskQueue}, false)
441 | assert.Equal(s.T(), int64(0), count)
442 | assert.NotNil(s.T(), err)
443 |
444 | // purging tasks successful
445 | s.sqlMock.ExpectPrepare(stmtMockRegexp).ExpectExec().WillReturnResult(sqlmock.NewResult(1, 1))
446 |
447 | count, err = s.mockedRepository.PurgeTasks(ctx, []tasq.TaskStatus{tasq.StatusFailed}, []string{}, []string{testTaskQueue}, false)
448 | assert.Equal(s.T(), int64(1), count)
449 | assert.Nil(s.T(), err)
450 |
451 | // purging tasks with safeDelete successful
452 | s.sqlMock.ExpectPrepare(stmtSafeDeleteMockRegexp).ExpectExec().WillReturnResult(sqlmock.NewResult(1, 1))
453 |
454 | count, err = s.mockedRepository.PurgeTasks(ctx, []tasq.TaskStatus{tasq.StatusFailed}, []string{}, []string{testTaskQueue}, true)
455 | assert.Equal(s.T(), int64(1), count)
456 | assert.Nil(s.T(), err)
457 | }
458 |
459 | func (s *PostgresTestSuite) TestPrepareWithTableName() {
460 | stmtMockRegexp := regexp.QuoteMeta(`SELECT * FROM test_tasks`)
461 |
462 | postgresRepository, ok := s.mockedRepository.(*postgres.Repository)
463 | require.True(s.T(), ok)
464 |
465 | // preparing stmt with table name when DB returns error
466 | s.sqlMock.ExpectPrepare(stmtMockRegexp).WillReturnError(errSQL)
467 |
468 | assert.PanicsWithError(s.T(), "sql error", func() {
469 | _ = postgresRepository.PrepareWithTableName("SELECT * FROM {{.tableName}}")
470 | })
471 | }
472 |
473 | func (s *PostgresTestSuite) TestCloseNamedStmt() {
474 | stmtMockRegexp := regexp.QuoteMeta(`SELECT * FROM test_tasks`)
475 |
476 | postgresRepository, ok := s.mockedRepository.(*postgres.Repository)
477 | require.True(s.T(), ok)
478 |
479 | s.sqlMock.ExpectPrepare(stmtMockRegexp)
480 |
481 | stmt, err := s.db.Prepare("SELECT * FROM test_tasks")
482 | require.Nil(s.T(), err)
483 |
484 | // an alternative DB to test the panic
485 | altDB, altSQLMock, err := sqlmock.New()
486 | require.Nil(s.T(), err)
487 |
488 | altSQLMock.ExpectBegin()
489 |
490 | tx, err := altDB.Begin()
491 | require.Nil(s.T(), err)
492 |
493 | assert.PanicsWithError(s.T(), "sql: Tx.Stmt: statement from different database used", func() {
494 | postgresRepository.CloseNamedStmt(tx.Stmt(stmt))
495 | })
496 | }
497 |
498 | func (s *PostgresTestSuite) TestInterpolateSQL() {
499 | params := map[string]any{"tableName": "test_table"}
500 |
501 | // Interpolate SQL successfully
502 | interpolatedSQL := postgres.InterpolateSQL("SELECT * FROM {{.tableName}}", params)
503 | assert.Equal(s.T(), "SELECT * FROM test_table", interpolatedSQL)
504 |
505 | // Fail interpolaing unparseable SQL template
506 | assert.Panics(s.T(), func() {
507 | unparseableTemplateSQL := postgres.InterpolateSQL("SELECT * FROM {{.tableName", params)
508 | assert.Empty(s.T(), unparseableTemplateSQL)
509 | })
510 |
511 | // Fail interpolaing unexecutable SQL template
512 | assert.Panics(s.T(), func() {
513 | unexecutableTemplateSQL := postgres.InterpolateSQL(`SELECT * FROM {{if .tableName eq 1}} {{end}} {{.tableName}}`, params)
514 | assert.Empty(s.T(), unexecutableTemplateSQL)
515 | })
516 | }
517 |
--------------------------------------------------------------------------------
/task.go:
--------------------------------------------------------------------------------
1 | package tasq
2 |
3 | import (
4 | "bytes"
5 | "encoding/gob"
6 | "fmt"
7 | "time"
8 |
9 | "github.com/google/uuid"
10 | )
11 |
// TaskStatus is an enum type describing the status a task is currently in.
// Each status is represented by an upper-case string.
type TaskStatus string

// The collection of possible task statuses.
const (
	// StatusNew is the status NewTask assigns to a freshly created task.
	StatusNew = TaskStatus("NEW")
	// StatusEnqueued is an open (not yet finished) status.
	StatusEnqueued = TaskStatus("ENQUEUED")
	// StatusInProgress is an open (not yet finished) status.
	StatusInProgress = TaskStatus("IN_PROGRESS")
	// StatusSuccessful is a finished status.
	StatusSuccessful = TaskStatus("SUCCESSFUL")
	// StatusFailed is a finished status.
	StatusFailed = TaskStatus("FAILED")
)
23 |
// TaskStatusGroup is an enum type naming a grouping of TaskStatuses.
// GetTaskStatuses uses it to select which statuses belong to a group.
type TaskStatusGroup int

// The collection of possible task status groupings.
const (
	// AllTasks groups every task status.
	AllTasks TaskStatusGroup = 0
	// OpenTasks groups the statuses of tasks that are not finished yet.
	OpenTasks TaskStatusGroup = 1
	// FinishedTasks groups the statuses of tasks that have finished.
	FinishedTasks TaskStatusGroup = 2
)
34 |
35 | // GetTaskStatuses returns a slice of TaskStatuses based on the TaskStatusGroup
36 | // passed as an argument.
37 | func GetTaskStatuses(taskStatusGroup TaskStatusGroup) []TaskStatus {
38 | if selected, ok := map[TaskStatusGroup][]TaskStatus{
39 | AllTasks: {
40 | StatusNew,
41 | StatusEnqueued,
42 | StatusInProgress,
43 | StatusSuccessful,
44 | StatusFailed,
45 | },
46 | OpenTasks: {
47 | StatusNew,
48 | StatusEnqueued,
49 | StatusInProgress,
50 | },
51 | FinishedTasks: {
52 | StatusSuccessful,
53 | StatusFailed,
54 | },
55 | }[taskStatusGroup]; ok {
56 | return selected
57 | }
58 |
59 | return nil
60 | }
61 |
// Task is the struct used to represent an atomic task managed by tasq.
type Task struct {
	// ID uniquely identifies the task; NewTask generates it with uuid.NewRandom.
	ID uuid.UUID
	// Type is the task type name supplied to NewTask.
	Type string
	// Args holds the gob-encoded task arguments; decode them with UnmarshalArgs.
	Args []byte
	// Queue is the name of the queue the task belongs to.
	Queue string
	// Priority is the priority supplied to NewTask. NOTE(review): whether higher
	// or lower values run first is decided by the repository — confirm there.
	Priority int16
	// Status is the task's current TaskStatus; NewTask sets StatusNew.
	Status TaskStatus
	// ReceiveCount is how many times the task has been received; NewTask sets 0.
	ReceiveCount int32
	// MaxReceives is the receive limit compared against ReceiveCount by IsLastReceive.
	MaxReceives int32
	// LastError holds an error message recorded for the task, or nil if none was recorded.
	LastError *string
	// CreatedAt is set to time.Now() by NewTask.
	CreatedAt time.Time
	// StartedAt is nil until a start time is recorded for the task.
	StartedAt *time.Time
	// FinishedAt is nil until a finish time is recorded for the task.
	FinishedAt *time.Time
	// VisibleAt is when the task becomes visible again; it is the zero time
	// after NewTask and is updated through SetVisibility.
	VisibleAt time.Time
}
78 |
79 | // NewTask creates a new Task struct based on the supplied arguments required to define it.
80 | func NewTask(taskType string, taskArgs any, queue string, priority int16, maxReceives int32) (*Task, error) {
81 | taskID, err := uuid.NewRandom()
82 | if err != nil {
83 | return nil, fmt.Errorf("failed to generate new task ID: %w", err)
84 | }
85 |
86 | encodedArgs, err := encodeTaskArgs(taskArgs)
87 | if err != nil {
88 | return nil, err
89 | }
90 |
91 | return &Task{
92 | ID: taskID,
93 | Type: taskType,
94 | Args: encodedArgs,
95 | Queue: queue,
96 | Priority: priority,
97 | Status: StatusNew,
98 | ReceiveCount: 0,
99 | MaxReceives: maxReceives,
100 | LastError: nil,
101 | CreatedAt: time.Now(),
102 | StartedAt: nil,
103 | FinishedAt: nil,
104 | VisibleAt: time.Time{},
105 | }, nil
106 | }
107 |
108 | // IsLastReceive returns true if the task has reached its maximum number of receives.
109 | func (t *Task) IsLastReceive() bool {
110 | return t.ReceiveCount >= t.MaxReceives
111 | }
112 |
113 | // SetVisibility sets the time at which the task will become visible again.
114 | func (t *Task) SetVisibility(visibleAt time.Time) {
115 | t.VisibleAt = visibleAt
116 | }
117 |
118 | // UnmarshalArgs decodes the task arguments into the passed target interface.
119 | func (t *Task) UnmarshalArgs(target any) error {
120 | var (
121 | buffer = bytes.NewBuffer(t.Args)
122 | decoder = gob.NewDecoder(buffer)
123 | )
124 |
125 | if err := decoder.Decode(target); err != nil {
126 | return fmt.Errorf("failed to decode task arguments: %w", err)
127 | }
128 |
129 | return nil
130 | }
131 |
// encodeTaskArgs gob-encodes taskArgs. On failure it returns an empty
// (non-nil) byte slice and the wrapped encoding error.
func encodeTaskArgs(taskArgs any) ([]byte, error) {
	buf := new(bytes.Buffer)

	if encodeErr := gob.NewEncoder(buf).Encode(taskArgs); encodeErr != nil {
		return []byte{}, fmt.Errorf("failed to encode task arguments: %w", encodeErr)
	}

	return buf.Bytes(), nil
}
145 |
--------------------------------------------------------------------------------
/task_test.go:
--------------------------------------------------------------------------------
1 | package tasq_test
2 |
3 | import (
4 | "errors"
5 | "testing"
6 | "time"
7 |
8 | "github.com/google/uuid"
9 | "github.com/greencoda/tasq"
10 | "github.com/stretchr/testify/assert"
11 | "github.com/stretchr/testify/suite"
12 | )
13 |
// errTest is the sentinel error produced by errorReader.
var errTest = errors.New("test error")

// errorReader is a reader whose every Read fails with errTest. The tests
// install it as uuid's randomness source to force uuid generation to fail.
type errorReader int

// Read reads nothing and always reports errTest.
func (errorReader) Read(_ []byte) (n int, err error) {
	n, err = 0, errTest

	return n, err
}
21 |
22 | type TaskTestSuite struct {
23 | suite.Suite
24 | }
25 |
26 | func TestTaskTestSuite(t *testing.T) {
27 | suite.Run(t, new(TaskTestSuite))
28 | }
29 |
30 | func (s *TaskTestSuite) SetupTest() {
31 | uuid.SetRand(nil)
32 | }
33 |
34 | func (s *TaskTestSuite) TestGetTaskStatuses() {
35 | allTasks := tasq.GetTaskStatuses(tasq.AllTasks)
36 | assert.ElementsMatch(s.T(), allTasks, []tasq.TaskStatus{
37 | tasq.StatusNew,
38 | tasq.StatusEnqueued,
39 | tasq.StatusInProgress,
40 | tasq.StatusSuccessful,
41 | tasq.StatusFailed,
42 | })
43 |
44 | openTasks := tasq.GetTaskStatuses(tasq.OpenTasks)
45 | assert.ElementsMatch(s.T(), openTasks, []tasq.TaskStatus{
46 | tasq.StatusNew,
47 | tasq.StatusEnqueued,
48 | tasq.StatusInProgress,
49 | })
50 |
51 | finishedTasks := tasq.GetTaskStatuses(tasq.FinishedTasks)
52 | assert.ElementsMatch(s.T(), finishedTasks, []tasq.TaskStatus{
53 | tasq.StatusSuccessful,
54 | tasq.StatusFailed,
55 | })
56 |
57 | unknownTasks := tasq.GetTaskStatuses(-1)
58 | assert.Empty(s.T(), unknownTasks)
59 | }
60 |
61 | func (s *TaskTestSuite) TestNewTask() {
62 | // Create task successfully
63 | task, _ := tasq.NewTask("testTask", true, "testQueue", 0, 5)
64 | assert.NotNil(s.T(), task)
65 |
66 | // Fail by creating task with nil args
67 | nilTask, err := tasq.NewTask("testTask", nil, "testQueue", 0, 5)
68 | assert.Nil(s.T(), nilTask)
69 | assert.NotNil(s.T(), err)
70 |
71 | // Fail by causing uuid generation to return error
72 | uuid.SetRand(new(errorReader))
73 |
74 | invalidUUIDTask, err := tasq.NewTask("testTask", false, "testQueue", 0, 5)
75 | assert.Nil(s.T(), invalidUUIDTask)
76 | assert.NotNil(s.T(), err)
77 | }
78 |
79 | func (s *TaskTestSuite) TestTaskUnmarshalArgs() {
80 | // Create task successfully
81 | task, _ := tasq.NewTask("testTask", true, "testQueue", 0, 5)
82 | assert.NotNil(s.T(), task)
83 |
84 | // Unmarshal task args successfully
85 | var args bool
86 | err := task.UnmarshalArgs(&args)
87 | assert.Nil(s.T(), err)
88 | assert.True(s.T(), args)
89 |
90 | // Fail by unmarshaling args to incorrect type
91 | var incorrectTypeArgs string
92 | err = task.UnmarshalArgs(&incorrectTypeArgs)
93 | assert.NotNil(s.T(), err)
94 | assert.Empty(s.T(), incorrectTypeArgs)
95 | }
96 |
97 | func (s *TaskTestSuite) TestTaskIsLastReceive() {
98 | // Create singleReceiveTask successfully
99 | singleReceiveTask, _ := tasq.NewTask("testTask", true, "testQueue", 0, 1)
100 | singleReceiveTask.ReceiveCount = 1
101 | assert.NotNil(s.T(), singleReceiveTask)
102 |
103 | // Check if task is in its last receive before reaching the maximum amount of receives
104 | assert.True(s.T(), singleReceiveTask.IsLastReceive())
105 |
106 | // Create multiReceiveTask successfully
107 | multiReceiveTask, _ := tasq.NewTask("testTask", true, "testQueue", 0, 5)
108 | multiReceiveTask.ReceiveCount = 1
109 | assert.NotNil(s.T(), multiReceiveTask)
110 |
111 | // Check if task is in its last receive before reaching the maximum amount of receives
112 | assert.False(s.T(), multiReceiveTask.IsLastReceive())
113 | }
114 |
115 | func (s *TaskTestSuite) TestTaskSetVisibility() {
116 | // Create task successfully
117 | task, _ := tasq.NewTask("testTask", true, "testQueue", 0, 5)
118 | assert.NotNil(s.T(), task)
119 |
120 | // Set Visibility successfully
121 | visibilityTime := time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC)
122 | task.SetVisibility(visibilityTime)
123 | assert.Equal(s.T(), task.VisibleAt, visibilityTime)
124 | }
125 |
--------------------------------------------------------------------------------