├── .gitignore ├── .golangci.yml ├── examples ├── redis │ ├── myproject.py │ ├── producer.py │ ├── producer │ │ └── main.go │ ├── consumer │ │ └── main.go │ ├── goredis │ │ └── main.go │ ├── retry │ │ └── main.go │ ├── redis │ │ └── main.go │ └── metrics │ │ └── main.go ├── rabbitmq │ ├── myproject.py │ ├── producer.py │ ├── consumer │ │ └── main.go │ └── producer │ │ └── main.go ├── go.mod └── go.sum ├── go.mod ├── .github └── workflows │ └── ci.yml ├── protocol ├── testdata │ ├── v1_noparams.json │ ├── v2_noparams.json │ └── v2_argskwargs.json ├── json.go ├── json_test.go ├── serializer_test.go └── serializer.go ├── internal └── broker │ ├── move2back.go │ └── move2back_test.go ├── LICENSE ├── goredis ├── broker_test.go └── broker.go ├── redis ├── broker_test.go └── broker.go ├── rabbitmq ├── broker_test.go └── broker.go ├── go.sum ├── config.go ├── param_test.go ├── param.go ├── celery_test.go ├── bench-old.txt ├── celery.go └── README.md /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | dump.rdb 3 | bench-new.txt 4 | __pycache__ 5 | -------------------------------------------------------------------------------- /.golangci.yml: -------------------------------------------------------------------------------- 1 | version: "2" 2 | linters: 3 | exclusions: 4 | rules: 5 | - source: "\\.Log\\(" 6 | linters: 7 | - errcheck 8 | - path: '(.+)_test\.go' 9 | linters: 10 | - errcheck 11 | -------------------------------------------------------------------------------- /examples/redis/myproject.py: -------------------------------------------------------------------------------- 1 | """ 2 | myproject 3 | ~~~~~~~~~ 4 | 5 | Run a Celery worker as follows: 6 | 7 | $ celery --app myproject worker --queues important --loglevel=debug --without-heartbeat --without-mingle 8 | 9 | It will process tasks from the "important" queue.
10 | 11 | """ 12 | from celery import Celery 13 | 14 | app = Celery(broker='redis://localhost:6379') 15 | 16 | 17 | @app.task 18 | def mytask(a, b): 19 | print('received a={} b={}'.format(a, b)) 20 | -------------------------------------------------------------------------------- /examples/rabbitmq/myproject.py: -------------------------------------------------------------------------------- 1 | """ 2 | myproject 3 | ~~~~~~~~~ 4 | 5 | Run a Celery worker as follows: 6 | 7 | $ celery --app myproject worker --queues important --loglevel=debug --without-heartbeat --without-mingle 8 | 9 | It will process tasks from the "important" queue. 10 | 11 | """ 12 | from celery import Celery 13 | 14 | app = Celery(broker='amqp://guest:guest@localhost:5672/') 15 | 16 | 17 | @app.task 18 | def mytask(a, b): 19 | print('received a={} b={}'.format(a, b)) 20 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/marselester/gopher-celery 2 | 3 | go 1.19 4 | 5 | require ( 6 | github.com/go-kit/log v0.2.1 7 | github.com/gomodule/redigo v1.9.2 8 | github.com/google/go-cmp v0.6.0 9 | github.com/google/uuid v1.6.0 10 | github.com/rabbitmq/amqp091-go v1.10.0 11 | github.com/redis/go-redis/v9 v9.7.0 12 | golang.org/x/sync v0.10.0 13 | ) 14 | 15 | require ( 16 | github.com/cespare/xxhash/v2 v2.3.0 // indirect 17 | github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect 18 | github.com/go-logfmt/logfmt v0.6.0 // indirect 19 | ) 20 | -------------------------------------------------------------------------------- /examples/redis/producer.py: -------------------------------------------------------------------------------- 1 | """ 2 | producer 3 | ~~~~~~~~ 4 | 5 | This module sends a "myproject.mytask" task to the "important" queue using the Celery task protocol given by --protocol (2 by default).
6 | 7 | """ 8 | import argparse 9 | 10 | from celery import Celery 11 | 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument('--protocol', type=int, default=2, help='Celery protocol version') 14 | args = parser.parse_args() 15 | 16 | app = Celery( 17 | main='myproject', 18 | broker='redis://localhost:6379', 19 | ) 20 | app.conf.update( 21 | CELERY_TASK_SERIALIZER='json', 22 | CELERY_TASK_PROTOCOL=args.protocol, 23 | ) 24 | 25 | # Stub that lets apply_async() publish "myproject.mytask"; its body is never run by this producer. 26 | @app.task 27 | def mytask(a, b): 28 | pass 29 | 30 | mytask.apply_async(args=('fizz',), kwargs={'b': 'bazz'}, queue='important') 31 | -------------------------------------------------------------------------------- /examples/rabbitmq/producer.py: -------------------------------------------------------------------------------- 1 | """ 2 | producer 3 | ~~~~~~~~ 4 | 5 | This module sends a "myproject.mytask" task to the "important" queue using the Celery task protocol given by --protocol (2 by default). 6 | 7 | """ 8 | import argparse 9 | 10 | from celery import Celery 11 | 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument('--protocol', type=int, default=2, help='Celery protocol version') 14 | args = parser.parse_args() 15 | 16 | app = Celery( 17 | main='myproject', 18 | broker='amqp://guest:guest@localhost:5672/', 19 | ) 20 | app.conf.update( 21 | CELERY_TASK_SERIALIZER='json', 22 | CELERY_TASK_PROTOCOL=args.protocol, 23 | ) 24 | 25 | # Stub that lets apply_async() publish "myproject.mytask"; its body is never run by this producer. 26 | @app.task 27 | def mytask(a, b): 28 | pass 29 | 30 | mytask.apply_async(args=('fizz',), kwargs={'b': 'bazz'}, queue='important') 31 | -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | on: 2 | push: 3 | pull_request: 4 | 5 | jobs: 6 | test: 7 | runs-on: ubuntu-latest 8 | steps: 9 | - uses: actions/checkout@v4 10 | - uses: actions/setup-go@v5 11 | with: 12 | go-version: stable 13 | - uses: supercharge/redis-github-action@1.8.0 14 | with: 15 | redis-version: 7 16 | - uses: nijel/rabbitmq-action@v1.0.0 17 |
with:
18 | rabbitmq version: '4.1.1-management'
19 | - run: go test -count=1 -v ./...
20 | lint:
21 | runs-on: ubuntu-latest
22 | steps:
23 | - uses: actions/checkout@v4
24 | - uses: actions/setup-go@v5
25 | with:
26 | go-version: stable
27 | - uses: golangci/golangci-lint-action@v8
28 | with:
29 | version: v2.1
30 |
--------------------------------------------------------------------------------
/examples/redis/producer/main.go:
--------------------------------------------------------------------------------
1 | // Program producer sends two "myproject.mytask" tasks to "important" queue:
2 | // one using task protocol v2, then one using protocol v1.
3 | package main
4 |
5 | import (
6 | 	"os"
7 |
8 | 	"github.com/go-kit/log"
9 | 	celery "github.com/marselester/gopher-celery"
10 | )
11 |
12 | // deliveries lists, in send order, the task protocol version to publish with
13 | // and the message logged after each attempt.
14 | var deliveries = []struct {
15 | 	protocol int
16 | 	logMsg   string
17 | }{
18 | 	{protocol: 2, logMsg: "task was sent using protocol v2"},
19 | 	{protocol: 1, logMsg: "task was sent using protocol v1"},
20 | }
21 |
22 | func main() {
23 | 	logger := log.NewJSONLogger(log.NewSyncWriter(os.Stderr))
24 |
25 | 	// Each delivery gets its own app configured with the matching protocol;
26 | 	// a failed send is only logged and does not stop the next one.
27 | 	for _, d := range deliveries {
28 | 		app := celery.NewApp(
29 | 			celery.WithLogger(logger),
30 | 			celery.WithTaskProtocol(d.protocol),
31 | 		)
32 | 		err := app.Delay("myproject.mytask", "important", "fizz", "bazz")
33 | 		logger.Log("msg", d.logMsg, "err", err)
34 | 	}
35 | }
36 |
--------------------------------------------------------------------------------
/examples/redis/consumer/main.go:
--------------------------------------------------------------------------------
1 | // Program consumer receives "myproject.mytask" tasks from "important" queue.
2 | package main 3 | 4 | import ( 5 | "context" 6 | "fmt" 7 | "os" 8 | "os/signal" 9 | 10 | "github.com/go-kit/log" 11 | celery "github.com/marselester/gopher-celery" 12 | ) 13 | 14 | func main() { 15 | logger := log.NewJSONLogger(log.NewSyncWriter(os.Stderr)) 16 | 17 | app := celery.NewApp( 18 | celery.WithLogger(logger), 19 | ) 20 | app.Register( 21 | "myproject.mytask", 22 | "important", 23 | func(ctx context.Context, p *celery.TaskParam) error { 24 | p.NameArgs("a", "b") 25 | fmt.Printf("received a=%s b=%s\n", p.MustString("a"), p.MustString("b")) 26 | return nil 27 | }, 28 | ) 29 | 30 | ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt) 31 | defer stop() 32 | 33 | logger.Log("msg", "waiting for tasks...") 34 | err := app.Run(ctx) 35 | logger.Log("msg", "program stopped", "err", err) 36 | } 37 | -------------------------------------------------------------------------------- /protocol/testdata/v1_noparams.json: -------------------------------------------------------------------------------- 1 | { 2 | "body": "eyJ0YXNrIjogIm15cHJvamVjdC5hcHBzLm15YXBwLnRhc2tzLm15dGFzayIsICJpZCI6ICIwZDA5YTZkZC05OWZjLTQzNmEtYTQxYS0wZGNhYTQ4NzU0NTkiLCAiYXJncyI6IFtdLCAia3dhcmdzIjoge30sICJncm91cCI6IG51bGwsICJncm91cF9pbmRleCI6IG51bGwsICJyZXRyaWVzIjogMCwgImV0YSI6IG51bGwsICJleHBpcmVzIjogbnVsbCwgInV0YyI6IHRydWUsICJjYWxsYmFja3MiOiBudWxsLCAiZXJyYmFja3MiOiBudWxsLCAidGltZWxpbWl0IjogW251bGwsIG51bGxdLCAidGFza3NldCI6IG51bGwsICJjaG9yZCI6IG51bGx9", 3 | "content-encoding": "utf-8", 4 | "content-type": "application/json", 5 | "headers": {}, 6 | "properties": { 7 | "correlation_id": "0d09a6dd-99fc-436a-a41a-0dcaa4875459", 8 | "reply_to": "b06fb011-8bf7-3a3d-bb0f-20fdce0595c5", 9 | "delivery_mode": 2, 10 | "delivery_info": { 11 | "exchange": "", 12 | "routing_key": "important" 13 | }, 14 | "priority": 0, 15 | "body_encoding": "base64", 16 | "delivery_tag": "b7e41188-d25c-4ae9-880a-6518e82b293d" 17 | } 18 | } 19 | 
--------------------------------------------------------------------------------
/internal/broker/move2back.go:
--------------------------------------------------------------------------------
1 | package broker
2 |
3 | // Move2back relocates the first occurrence of v to the last position of ss,
4 | // keeping the relative order of all other items.
5 | // For example, given slice [a, b, c, d, e, f] and item c,
6 | // the result is [a, b, d, e, f, c].
7 | // The slice is modified in place. Nothing happens when ss has fewer than
8 | // two items, when the last item already equals v, or when v is absent.
9 | // The running time is linear in the worst case.
10 | func Move2back(ss []string, v string) {
11 | 	last := len(ss) - 1
12 | 	if last < 1 || ss[last] == v {
13 | 		return
14 | 	}
15 |
16 | 	for i, s := range ss[:last] {
17 | 		if s != v {
18 | 			continue
19 | 		}
20 | 		// Shift everything after i one slot to the left to close the gap,
21 | 		// then put v into the freed last slot.
22 | 		copy(ss[i:], ss[i+1:])
23 | 		ss[last] = v
24 | 		return
25 | 	}
26 | }
27 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2025 Marsel Mavletkulov
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/examples/rabbitmq/consumer/main.go:
--------------------------------------------------------------------------------
1 | // Program consumer receives "myproject.mytask" tasks from "important" queue.
2 | package main 3 | 4 | import ( 5 | "context" 6 | "fmt" 7 | "os" 8 | "os/signal" 9 | 10 | "github.com/go-kit/log" 11 | celery "github.com/marselester/gopher-celery" 12 | celeryrabbitmq "github.com/marselester/gopher-celery/rabbitmq" 13 | ) 14 | 15 | func main() { 16 | logger := log.NewJSONLogger(log.NewSyncWriter(os.Stderr)) 17 | 18 | broker := celeryrabbitmq.NewBroker(celeryrabbitmq.WithAmqpUri("amqp://guest:guest@localhost:5672/")) 19 | app := celery.NewApp( 20 | celery.WithBroker(broker), 21 | celery.WithLogger(logger), 22 | ) 23 | app.Register( 24 | "myproject.mytask", 25 | "important", 26 | func(ctx context.Context, p *celery.TaskParam) error { 27 | p.NameArgs("a", "b") 28 | fmt.Printf("received a=%s b=%s\n", p.MustString("a"), p.MustString("b")) 29 | return nil 30 | }, 31 | ) 32 | 33 | ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt) 34 | defer stop() 35 | 36 | logger.Log("msg", "waiting for tasks...") 37 | err := app.Run(ctx) 38 | logger.Log("msg", "program stopped", "err", err) 39 | } 40 | -------------------------------------------------------------------------------- /examples/rabbitmq/producer/main.go: -------------------------------------------------------------------------------- 1 | // Program producer sends two "myproject.mytask" tasks to "important" queue. 
2 | package main
3 |
4 | import (
5 | 	"os"
6 |
7 | 	"github.com/go-kit/log"
8 | 	"github.com/go-kit/log/level"
9 | 	celery "github.com/marselester/gopher-celery"
10 | 	celeryrabbitmq "github.com/marselester/gopher-celery/rabbitmq"
11 | )
12 |
13 | // amqpURI is the RabbitMQ server the example publishes to.
14 | const amqpURI = "amqp://guest:guest@localhost:5672/"
15 |
16 | func main() {
17 | 	logger := log.NewJSONLogger(log.NewSyncWriter(os.Stderr))
18 |
19 | 	send(logger, 2, "task was sent using protocol v2")
20 | 	send(logger, 1, "task was sent using protocol v1")
21 | }
22 |
23 | // send publishes one "myproject.mytask" task to "important" queue through a
24 | // newly created RabbitMQ broker using the given task protocol version,
25 | // and logs msg together with the outcome of the send.
26 | // NewBroker reports failures through its second (error) result, so a broker
27 | // that failed to initialize is never handed to the app.
28 | func send(logger log.Logger, protocol int, msg string) {
29 | 	broker, err := celeryrabbitmq.NewBroker(celeryrabbitmq.WithAmqpUri(amqpURI))
30 | 	if err != nil {
31 | 		level.Error(logger).Log("msg", "failed to create RabbitMQ broker", "err", err)
32 | 		return
33 | 	}
34 |
35 | 	app := celery.NewApp(
36 | 		celery.WithBroker(broker),
37 | 		celery.WithLogger(logger),
38 | 		celery.WithTaskProtocol(protocol),
39 | 	)
40 | 	err = app.Delay("myproject.mytask", "important", "fizz", "bazz")
41 | 	logger.Log("msg", msg, "err", err)
42 | }
43 |
--------------------------------------------------------------------------------
/goredis/broker_test.go:
--------------------------------------------------------------------------------
1 | package goredis
2 |
3 | import (
4 | 	"context"
5 | 	"testing"
6 | 	"time"
7 |
8 | 	"github.com/google/go-cmp/cmp"
9 | )
10 |
11 | func TestReceive(t *testing.T) {
12 | 	q := "goredisq"
13 | 	br := NewBroker(WithReceiveTimeout(time.Second))
14 | 	br.Observe([]string{q})
15 |
16 | 	tests := map[string]struct {
17 | 		input []string
18 | 		want  []byte
19 | 		err   error
20 | 	}{
21 | 		"timeout": {
22 | 			input: nil,
23 | 			want:  nil,
24 | 			err:   nil,
25 | 		},
26 | 		"one-msg": {
27 | 			input: []string{"{}"},
28 | 			want:  []byte("{}"),
29 | 			err:   nil,
30 | 		},
31 | 		"oldest-msg": {
32 | 			input: []string{"1", "2"},
33 | 			want:  []byte("1"),
34 | 			err:   nil,
35 | 		},
36 | 	}
37 |
38 | 	for name, tc := range tests {
39 | 		t.Run(name, func(t *testing.T) {
40 | 			t.Cleanup(func() {
41 | 				// Empty the queue so messages left over by one case (e.g. "2" in
42 | 				// "oldest-msg") can't leak into the next one; a failed cleanup
43 | 				// is reported instead of being silently ignored.
44 | 				if err := br.pool.Del(context.Background(), q).Err(); err != nil {
45 | 					t.Error(err)
46 | 				}
47 | 			})
48 |
49 | 			for _, m := range tc.input {
50 | 				if err := br.Send([]byte(m), q); err != nil {
51 | 					t.Fatal(err)
52 | 				}
53 | 			}
54 |
55 | 			got, err := br.Receive()
56 | 			if err != tc.err {
57 | 				t.Fatal(err)
58 | 			}
59 | 			if diff := cmp.Diff(tc.want, got); diff != "" {
60 | 				t.Error(diff, string(got))
61 | 			}
62 | 		})
63 | 	}
64 | }
65 |
--------------------------------------------------------------------------------
/examples/go.mod:
--------------------------------------------------------------------------------
1 | module github.com/marselester/gopher-celery/examples
2 |
3 | go 1.21
4 |
5 | replace github.com/marselester/gopher-celery => ../
6 |
7 | require (
8 | github.com/go-kit/log v0.2.1
9 | github.com/gomodule/redigo v1.9.2
10 | github.com/marselester/backoff v0.0.1
11 | github.com/oklog/run v1.1.0
12 | github.com/prometheus/client_golang v1.20.5
13 | github.com/redis/go-redis/v9 v9.7.0
14 | github.com/marselester/gopher-celery v0.0.0-00010101000000-000000000000
15 | )
16 |
17 | require (
18 | github.com/beorn7/perks v1.0.1 // indirect
19 | github.com/cespare/xxhash/v2 v2.3.0 // indirect
20 | github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
21 | github.com/go-logfmt/logfmt v0.6.0 // indirect
22 | github.com/google/uuid v1.6.0 // indirect
23 | github.com/klauspost/compress v1.17.9 // indirect
24 | github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
25 | github.com/prometheus/client_model v0.6.1 // indirect
26 | github.com/prometheus/common v0.61.0 // indirect
27 | github.com/prometheus/procfs v0.15.1 // indirect
28 | github.com/rabbitmq/amqp091-go v1.10.0 // indirect
29 | golang.org/x/sync v0.10.0 // indirect
30 | golang.org/x/sys v0.28.0 // indirect
31 | google.golang.org/protobuf v1.36.1 // indirect
32 | )
33 |
--------------------------------------------------------------------------------
/redis/broker_test.go:
--------------------------------------------------------------------------------
1 | package redis
2 |
3 | import (
4 | 	"testing"
5 | 	"time"
6 |
7 | 	"github.com/google/go-cmp/cmp"
8 | )
9 |
10 | func TestReceive(t *testing.T) {
11 | 	q := "redigoq"
12 | 	br := NewBroker(WithReceiveTimeout(time.Second))
13 | 	br.Observe([]string{q})
14 |
15 | 	tests := map[string]struct {
16 | 		input []string
17 | 		want  []byte
18 | 		err   error
19 | 	}{
20 | 		"timeout": {
21 | 			input: nil,
22 | 			want:  nil,
23 | 			err:   nil,
24 | 		},
25 | 		"one-msg": {
26 | 			input: []string{"{}"},
27 | 			want:  []byte("{}"),
28 | 			err:   nil,
29 | 		},
30 | 		"oldest-msg": {
31 | 			input: []string{"1", "2"},
32 | 			want:  []byte("1"),
33 | 			err:   nil,
34 | 		},
35 | 	}
36 |
37 | 	for name, tc := range tests {
38 | 		t.Run(name, func(t *testing.T) {
39 | 			t.Cleanup(func() {
40 | 				c := br.pool.Get()
41 | 				defer c.Close()
42 |
43 | 				if _, err := c.Do("DEL", q); err != nil {
44 | 					t.Fatal(err)
45 | 				}
46 | 			})
47 |
48 | 			for _, m := range tc.input {
49 | 				if err := br.Send([]byte(m), q); err != nil {
50 | 					t.Fatal(err)
51 | 				}
52 | 			}
53 |
54 | 			got, err := br.Receive()
55 | 			if err != tc.err {
56 | 				t.Fatal(err)
57 | 			}
58 | 			if diff := cmp.Diff(tc.want, got); diff != "" {
59 | 				t.Error(diff, string(got))
60 | 			}
61 | 		})
62 | 	}
63 | }
64 |
--------------------------------------------------------------------------------
/rabbitmq/broker_test.go:
--------------------------------------------------------------------------------
1 | package rabbitmq
2 |
3 | import (
4 | 	"testing"
5 | 	"time"
6 |
7 | 	"github.com/google/go-cmp/cmp"
8 | )
9 |
10 | func TestReceive(t *testing.T) {
11 | 	q := "rabbitmqq"
12 |
13 | 	tests := map[string]struct {
14 | 		input []string
15 | 		want  []byte
16 | 		err   error
17 | 	}{
18 | 		"timeout": {
19 | 			input: nil,
20 | 			want:  nil,
21 | 			err:   nil,
22 | 		},
23 | 		"one-msg": {
24 | 			input: []string{"{}"},
25 | 			want:  []byte("{}"),
26 | 			err:   nil,
27 | 		},
28 | 		"oldest-msg": {
29 | 			input: []string{"1", "2"},
30 | 			want:  []byte("1"),
31 | 			err:   nil,
32 | 		},
33 | 	}
34 |
35 | 	for name, tc := range tests {
36 | 		t.Run(name, func(t *testing.T) {
37 | 			br, err := NewBroker(WithReceiveTimeout(time.Second))
38 | 			if err != nil {
39 | 				t.Fatal(err)
40 | 			}
41 |
42 | 			br.rawMode = true
43 | 			br.Observe([]string{q})
44 |
45 | 			t.Cleanup(func() {
46 | 				// Drop the queue so leftover messages don't leak into the next
47 | 				// case, and report a failed cleanup instead of ignoring it.
48 | 				if _, err := br.channel.QueueDelete(q, false, false, false); err != nil {
49 | 					t.Error(err)
50 | 				}
51 | 				// NOTE(review): each subtest creates a broker that is never closed;
52 | 				// if Broker exposes a Close method, call it here — TODO confirm.
53 | 			})
54 |
55 | 			for _, m := range tc.input {
56 | 				if err := br.Send([]byte(m), q); err != nil {
57 | 					t.Fatal(err)
58 | 				}
59 | 			}
60 |
61 | 			got, err := br.Receive()
62 | 			if err != tc.err {
63 | 				t.Fatal(err)
64 | 			}
65 | 			if diff := cmp.Diff(tc.want, got); diff != "" {
66 | 				t.Error(diff, string(got))
67 | 			}
68 | 		})
69 | 	}
70 | }
71 |
--------------------------------------------------------------------------------
/protocol/testdata/v2_noparams.json:
--------------------------------------------------------------------------------
1 | {
2 | "body": "W1tdLCB7fSwgeyJjYWxsYmFja3MiOiBudWxsLCAiZXJyYmFja3MiOiBudWxsLCAiY2hhaW4iOiBudWxsLCAiY2hvcmQiOiBudWxsfV0=",
3 | "content-encoding": "utf-8",
4 | "content-type": "application/json",
5 | "headers": {
6 | "lang": "py",
7 | "task": "myproject.apps.myapp.tasks.mytask",
8 | "id": "3802f860-8d3c-4dad-b18c-597fb2ac728b",
9 | "shadow": null,
10 | "eta": null,
11 | "expires": null,
12 | "group": null,
13 | "group_index": null,
14 | "retries": 0,
15 | "timelimit": [
16 | null,
17 | null
18 | ],
19 | "root_id": "3802f860-8d3c-4dad-b18c-597fb2ac728b",
20 | "parent_id": null,
21 | "argsrepr": "()",
22 | "kwargsrepr": "{}",
23 | "origin": "gen8382@mdesk.hitronhub.home",
24 | "ignore_result": false
25 | },
26 | "properties": {
27 | "correlation_id": "3802f860-8d3c-4dad-b18c-597fb2ac728b",
28 | "reply_to": "288a69eb-08c6-39c4-a11d-86c0e966a634",
29 | "delivery_mode": 2,
30 | "delivery_info": {
31 | "exchange": "",
32 | "routing_key": "important"
33 | },
34 | "priority": 0,
35 | "body_encoding": "base64",
36 | "delivery_tag": "abd5cf11-115b-485a-a93f-d445114b11cb"
37 | }
38 | }
39 |
-------------------------------------------------------------------------------- /protocol/testdata/v2_argskwargs.json: -------------------------------------------------------------------------------- 1 | { 2 | "body": "W1siZml6eiJdLCB7ImIiOiAiYmF6eiJ9LCB7ImNhbGxiYWNrcyI6IG51bGwsICJlcnJiYWNrcyI6IG51bGwsICJjaGFpbiI6IG51bGwsICJjaG9yZCI6IG51bGx9XQ==", 3 | "content-encoding": "utf-8", 4 | "content-type": "application/json", 5 | "headers": { 6 | "lang": "py", 7 | "task": "myproject.apps.myapp.tasks.mytask", 8 | "id": "0ad73c66-f4c9-4600-bd20-96746e720eed", 9 | "shadow": null, 10 | "eta": null, 11 | "expires": null, 12 | "group": null, 13 | "group_index": null, 14 | "retries": 0, 15 | "timelimit": [ 16 | null, 17 | null 18 | ], 19 | "root_id": "0ad73c66-f4c9-4600-bd20-96746e720eed", 20 | "parent_id": null, 21 | "argsrepr": "(fizz,)", 22 | "kwargsrepr": "{b: bazz}", 23 | "origin": "gen7968@mdesk.hitronhub.home", 24 | "ignore_result": false 25 | }, 26 | "properties": { 27 | "correlation_id": "0ad73c66-f4c9-4600-bd20-96746e720eed", 28 | "reply_to": "94f453cb-c38c-341b-aa92-07177cf72c4a", 29 | "delivery_mode": 2, 30 | "delivery_info": { 31 | "exchange": "", 32 | "routing_key": "important" 33 | }, 34 | "priority": 0, 35 | "body_encoding": "base64", 36 | "delivery_tag": "04f7364a-38b2-42a3-a8f7-9ca448aa1bb9" 37 | } 38 | } 39 | -------------------------------------------------------------------------------- /examples/redis/goredis/main.go: -------------------------------------------------------------------------------- 1 | // Program goredis shows how to use github.com/redis/go-redis. 
2 | package main
3 |
4 | import (
5 | 	"context"
6 | 	"fmt"
7 | 	"os"
8 | 	"os/signal"
9 |
10 | 	"github.com/go-kit/log"
11 | 	"github.com/go-kit/log/level"
12 | 	celery "github.com/marselester/gopher-celery"
13 | 	celeryredis "github.com/marselester/gopher-celery/goredis"
14 | 	"github.com/redis/go-redis/v9"
15 | )
16 |
17 | func main() {
18 | 	logger := log.NewJSONLogger(log.NewSyncWriter(os.Stderr))
19 |
20 | 	ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt)
21 | 	defer stop()
22 |
23 | 	c := redis.NewClient(&redis.Options{
24 | 		Addr: "localhost:6379",
25 | 	})
26 | 	defer func() {
27 | 		if err := c.Close(); err != nil {
28 | 			level.Error(logger).Log("msg", "failed to close Redis client", "err", err)
29 | 		}
30 | 	}()
31 |
32 | 	if _, err := c.Ping(ctx).Result(); err != nil {
33 | 		level.Error(logger).Log("msg", "Redis connection failed", "err", err)
34 | 		return
35 | 	}
36 |
37 | 	broker := celeryredis.NewBroker(
38 | 		celeryredis.WithClient(c),
39 | 	)
40 | 	app := celery.NewApp(
41 | 		celery.WithBroker(broker),
42 | 		celery.WithLogger(logger),
43 | 	)
44 | 	app.Register(
45 | 		"myproject.mytask",
46 | 		"important",
47 | 		func(ctx context.Context, p *celery.TaskParam) error {
48 | 			p.NameArgs("a", "b")
49 | 			fmt.Printf("received a=%s b=%s\n", p.MustString("a"), p.MustString("b"))
50 | 			return nil
51 | 		},
52 | 	)
53 |
54 | 	logger.Log("msg", "waiting for tasks...")
55 | 	err := app.Run(ctx)
56 | 	logger.Log("msg", "program stopped", "err", err)
57 | }
58 |
--------------------------------------------------------------------------------
/internal/broker/move2back_test.go:
--------------------------------------------------------------------------------
1 | package broker
2 |
3 | import (
4 | 	"testing"
5 |
6 | 	"github.com/google/go-cmp/cmp"
7 | )
8 |
9 | func TestMove2back(t *testing.T) {
10 | 	tests := map[string]struct {
11 | 		input []string
12 | 		v     string
13 | 		want  []string
14 | 	}{
15 | 		"start": {
16 | 			input: []string{"a", "b", "c", "d", "e", "f"},
17 | 			v:     "a",
18 | 			want:  []string{"b", "c", "d", "e", "f", "a"},
19 | 		},
20 | 		"middle": {
21 | 			input: []string{"a", "b", "c", "d", "e", "f"},
22 | 			v:     "c",
23 | 			want:  []string{"a", "b", "d", "e", "f", "c"},
24 | 		},
25 | 		"end": {
26 | 			input: []string{"a", "b", "c", "d", "e", "f"},
27 | 			v:     "f",
28 | 			want:  []string{"a", "b", "c", "d", "e", "f"},
29 | 		},
30 | 		"nil": {
31 | 			input: nil,
32 | 			v:     "f",
33 | 			want:  nil,
34 | 		},
35 | 		"one-item": {
36 | 			input: []string{"a"},
37 | 			v:     "a",
38 | 			want:  []string{"a"},
39 | 		},
40 | 		"two-items-start": {
41 | 			input: []string{"a", "b"},
42 | 			v:     "a",
43 | 			want:  []string{"b", "a"},
44 | 		},
45 | 		"two-items-end": {
46 | 			input: []string{"a", "b"},
47 | 			v:     "b",
48 | 			want:  []string{"a", "b"},
49 | 		},
50 | 		// Empty but non-nil input must be left untouched.
51 | 		"empty": {
52 | 			input: []string{},
53 | 			v:     "a",
54 | 			want:  []string{},
55 | 		},
56 | 		// An item that is not in the slice leaves it unchanged.
57 | 		"absent": {
58 | 			input: []string{"a", "b", "c"},
59 | 			v:     "x",
60 | 			want:  []string{"a", "b", "c"},
61 | 		},
62 | 		// Only the first occurrence moves; later duplicates keep their order.
63 | 		"first-of-duplicates": {
64 | 			input: []string{"c", "a", "c", "b"},
65 | 			v:     "c",
66 | 			want:  []string{"a", "c", "b", "c"},
67 | 		},
68 | 		// When the last item already equals v, nothing moves,
69 | 		// even if v also occurs earlier.
70 | 		"duplicate-at-end": {
71 | 			input: []string{"c", "a", "c"},
72 | 			v:     "c",
73 | 			want:  []string{"c", "a", "c"},
74 | 		},
75 | 	}
76 |
77 | 	for name, tc := range tests {
78 | 		t.Run(name, func(t *testing.T) {
79 | 			Move2back(tc.input, tc.v)
80 | 			if diff := cmp.Diff(tc.want, tc.input); diff != "" {
81 | 				t.Error(diff)
82 | 			}
83 | 		})
84 | 	}
85 | }
86 |
87 | func BenchmarkMove2back(b *testing.B) {
88 | 	ss := []string{"a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l", "m", "n", "o", "p", "q", "r", "s", "t", "u", "v", "w", "x", "y", "z"}
89 | 	for n := 0; n < b.N; n++ {
90 | 		Move2back(ss, "a")
91 | 	}
92 | }
93 |
--------------------------------------------------------------------------------
/examples/redis/retry/main.go:
--------------------------------------------------------------------------------
1 | // Program retry shows how a task's business logic can be retried
2 | // without rescheduling the task itself.
3 | package main 4 | 5 | import ( 6 | "context" 7 | "fmt" 8 | "os" 9 | "os/signal" 10 | "time" 11 | 12 | "github.com/go-kit/log" 13 | "github.com/marselester/backoff" 14 | celery "github.com/marselester/gopher-celery" 15 | ) 16 | // main enqueues one "myproject.mytask" task and then runs the app until interrupted. 17 | func main() { 18 | logger := log.NewJSONLogger(log.NewSyncWriter(os.Stderr)) 19 | logger = log.With(logger, "ts", log.DefaultTimestampUTC) 20 | 21 | r := requestor{ 22 | maxRetries: 2, 23 | logger: logger, 24 | } 25 | 26 | app := celery.NewApp( 27 | celery.WithLogger(logger), 28 | ) 29 | app.Register( 30 | "myproject.mytask", 31 | "important", 32 | r.request, 33 | ) 34 | 35 | ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt) 36 | defer stop() 37 | 38 | err := app.Delay("myproject.mytask", "important", "fizz", "bazz") 39 | logger.Log("msg", "task was sent", "err", err) 40 | 41 | logger.Log("msg", "waiting for tasks...") 42 | err = app.Run(ctx) 43 | logger.Log("msg", "program stopped", "err", err) 44 | } 45 | // requestor holds the retry budget and logger used by the request task handler. 46 | type requestor struct { 47 | maxRetries int // retries allowed after the first attempt 48 | logger log.Logger 49 | } 50 | // request is registered in main as the "myproject.mytask" handler; it retries the business logic in-process instead of rescheduling the task. 51 | func (rq *requestor) request(ctx context.Context, p *celery.TaskParam) error { 52 | // Make rq.maxRetries+1 delivery attempts with exponential backoff (the 1st attempt plus rq.maxRetries retries; main sets maxRetries to 2, i.e., 3 attempts). 53 | // The 5 seconds multiplier increases the wait intervals. 54 | // Max waiting time between attempts is 15 seconds. 55 | r := backoff.NewDecorrJitter( 56 | backoff.WithMaxRetries(rq.maxRetries), 57 | backoff.WithMultiplier(5*time.Second), 58 | backoff.WithMaxWait(15*time.Second), 59 | ) 60 | 61 | return backoff.Run(ctx, r, func(attempt int) (err error) { 62 | if err = rq.work(); err != nil { 63 | rq.logger.Log("msg", "request failed", "attempt", attempt, "err", err) 64 | } else { 65 | rq.logger.Log("msg", "request succeeded", "attempt", attempt) 66 | } 67 | return err 68 | }) 69 | } 70 | 71 | // work is a dummy function that imitates work; it always returns an error so that every retry is exercised.
72 | func (rq *requestor) work() error { 73 | return fmt.Errorf("uh oh") 74 | } 75 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs= 2 | github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA= 3 | github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= 4 | github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= 5 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= 6 | github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= 7 | github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= 8 | github.com/go-kit/log v0.2.1 h1:MRVx0/zhvdseW+Gza6N9rVzU/IVzaeE1SFI4raAhmBU= 9 | github.com/go-kit/log v0.2.1/go.mod h1:NwTd00d/i8cPZ3xOwwiv2PO5MOcx78fFErGNcVmBjv0= 10 | github.com/go-logfmt/logfmt v0.6.0 h1:wGYYu3uicYdqXVgoYbvnkrPVXkuLM1p1ifugDMEdRi4= 11 | github.com/go-logfmt/logfmt v0.6.0/go.mod h1:WYhtIu8zTZfxdn5+rREduYbwxfcBr/Vr6KEVveWlfTs= 12 | github.com/gomodule/redigo v1.9.2 h1:HrutZBLhSIU8abiSfW8pj8mPhOyMYjZT/wcA4/L9L9s= 13 | github.com/gomodule/redigo v1.9.2/go.mod h1:KsU3hiK/Ay8U42qpaJk+kuNa3C+spxapWpM+ywhcgtw= 14 | github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= 15 | github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= 16 | github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= 17 | github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= 18 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 19 | github.com/rabbitmq/amqp091-go v1.10.0
h1:STpn5XsHlHGcecLmMFCtg7mqq0RnD+zFr4uzukfVhBw= 20 | github.com/rabbitmq/amqp091-go v1.10.0/go.mod h1:Hy4jKW5kQART1u+JkDTF9YYOQUHXqMuhrgxOEeS7G4o= 21 | github.com/redis/go-redis/v9 v9.7.0 h1:HhLSs+B6O021gwzl+locl0zEDnyNkxMtf/Z3NNBMa9E= 22 | github.com/redis/go-redis/v9 v9.7.0/go.mod h1:f6zhXITC7JUJIlPEiBOTXxJgPLdZcA93GewI7inzyWw= 23 | github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= 24 | go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= 25 | golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ= 26 | golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= 27 | gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= 28 | -------------------------------------------------------------------------------- /examples/redis/redis/main.go: -------------------------------------------------------------------------------- 1 | // Program redis shows how to pass a Redis connection pool to the broker. 2 | package main 3 | 4 | import ( 5 | "context" 6 | "fmt" 7 | "os" 8 | "os/signal" 9 | "time" 10 | 11 | "github.com/go-kit/log" 12 | "github.com/go-kit/log/level" 13 | "github.com/gomodule/redigo/redis" 14 | celery "github.com/marselester/gopher-celery" 15 | celeryredis "github.com/marselester/gopher-celery/redis" 16 | ) 17 | 18 | func main() { 19 | logger := log.NewJSONLogger(log.NewSyncWriter(os.Stderr)) 20 | 21 | pool := redis.Pool{ 22 | Dial: func() (redis.Conn, error) { 23 | c, err := redis.DialURL( 24 | "redis://localhost", 25 | redis.DialConnectTimeout(5*time.Second), 26 | // The Conn.Do method sets a write deadline, 27 | // writes the command arguments to the network connection, 28 | // sets the read deadline and reads the response from the network connection 29 | // https://github.com/gomodule/redigo/issues/320. 30 | // 31 | // Note, the read timeout should be big enough for BRPOP to finish 32 | // or else the broker returns i/o timeout error. 
33 | redis.DialWriteTimeout(5*time.Second), 34 | redis.DialReadTimeout(10*time.Second), 35 | ) 36 | if err != nil { 37 | level.Error(logger).Log("msg", "Redis dial failed", "err", err) 38 | } 39 | return c, err 40 | }, 41 | // Check the health of an idle connection before using it. 42 | // It PINGs connections that have been idle more than a minute. 43 | TestOnBorrow: func(c redis.Conn, t time.Time) error { 44 | if time.Since(t) < time.Minute { 45 | return nil 46 | } 47 | _, err := c.Do("PING") 48 | return err 49 | }, 50 | // Maximum number of idle connections in the pool. 51 | MaxIdle: 3, 52 | // Close connections after remaining idle for given duration. 53 | IdleTimeout: 5 * time.Minute, 54 | } 55 | defer func() { 56 | if err := pool.Close(); err != nil { 57 | level.Error(logger).Log("msg", "failed to close Redis connection pool", "err", err) 58 | } 59 | }() 60 | 61 | c := pool.Get() 62 | if _, err := c.Do("PING"); err != nil { 63 | level.Error(logger).Log("msg", "Redis connection failed", "err", err) 64 | return 65 | } 66 | c.Close() 67 | 68 | broker := celeryredis.NewBroker( 69 | celeryredis.WithPool(&pool), 70 | ) 71 | app := celery.NewApp( 72 | celery.WithBroker(broker), 73 | celery.WithLogger(logger), 74 | ) 75 | app.Register( 76 | "myproject.mytask", 77 | "important", 78 | func(ctx context.Context, p *celery.TaskParam) error { 79 | p.NameArgs("a", "b") 80 | fmt.Printf("received a=%s b=%s\n", p.MustString("a"), p.MustString("b")) 81 | return nil 82 | }, 83 | ) 84 | 85 | ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt) 86 | defer stop() 87 | 88 | logger.Log("msg", "waiting for tasks...") 89 | err := app.Run(ctx) 90 | logger.Log("msg", "program stopped", "err", err) 91 | } 92 | -------------------------------------------------------------------------------- /config.go: -------------------------------------------------------------------------------- 1 | package celery 2 | 3 | import ( 4 | "github.com/go-kit/log" 5 | 6 | 
"github.com/marselester/gopher-celery/protocol" 7 | ) 8 | 9 | // DefaultMaxWorkers is the default upper limit of goroutines 10 | // allowed to process Celery tasks. 11 | // Note, the workers are launched only when there are tasks to process. 12 | // 13 | // Let's say it takes ~5s to process a task on average, 14 | // so 1000 goroutines should be able to handle 200 tasks per second 15 | // (X = N / R = 1000 / 5) according to Little's law N = X * R. 16 | const DefaultMaxWorkers = 1000 17 | 18 | // Option sets up a Config. 19 | type Option func(*Config) 20 | 21 | // WithCustomTaskSerializer registers a custom serializer where 22 | // mime is the mime-type describing the serialized structure, e.g., application/json, 23 | // and encoding is the content encoding which is usually utf-8 or binary. 24 | func WithCustomTaskSerializer(serializer protocol.Serializer, mime, encoding string) Option { 25 | return func(c *Config) { 26 | c.registry.Register(serializer, mime, encoding) 27 | } 28 | } 29 | 30 | // WithTaskSerializer sets a serializer mime-type, e.g., 31 | // the message's body is encoded in JSON when a task is sent to the broker. 32 | // It is equivalent to CELERY_TASK_SERIALIZER in Python. 33 | func WithTaskSerializer(mime string) Option { 34 | return func(c *Config) { 35 | switch mime { 36 | case protocol.MimeJSON: 37 | c.mime = mime 38 | default: 39 | c.mime = protocol.MimeJSON 40 | } 41 | } 42 | } 43 | 44 | // WithTaskProtocol sets the default task message protocol version used to send tasks. 45 | // It is equivalent to CELERY_TASK_PROTOCOL in Python. 46 | func WithTaskProtocol(version int) Option { 47 | return func(c *Config) { 48 | switch version { 49 | case protocol.V1, protocol.V2: 50 | c.protocol = version 51 | default: 52 | c.protocol = protocol.V2 53 | } 54 | } 55 | } 56 | 57 | // WithBroker allows a caller to replace the default broker. 
58 | func WithBroker(broker Broker) Option { 59 | return func(c *Config) { 60 | c.broker = broker 61 | } 62 | } 63 | 64 | // WithLogger sets a structured logger. 65 | func WithLogger(logger log.Logger) Option { 66 | return func(c *Config) { 67 | c.logger = logger 68 | } 69 | } 70 | 71 | // WithMaxWorkers sets an upper limit of goroutines 72 | // allowed to process Celery tasks. 73 | func WithMaxWorkers(n int) Option { 74 | return func(c *Config) { 75 | c.maxWorkers = n 76 | } 77 | } 78 | 79 | // WithMiddlewares sets a chain of task middlewares. 80 | // The first middleware is treated as the outermost middleware. 81 | func WithMiddlewares(chain ...Middleware) Option { 82 | return func(c *Config) { 83 | c.chain = func(next TaskF) TaskF { 84 | for i := len(chain) - 1; i >= 0; i-- { 85 | next = chain[i](next) 86 | } 87 | return next 88 | } 89 | } 90 | } 91 | 92 | // Config represents Celery settings. 93 | type Config struct { 94 | logger log.Logger 95 | broker Broker 96 | registry *protocol.SerializerRegistry 97 | mime string 98 | protocol int 99 | maxWorkers int 100 | chain Middleware 101 | } 102 | -------------------------------------------------------------------------------- /param_test.go: -------------------------------------------------------------------------------- 1 | package celery 2 | 3 | import ( 4 | "fmt" 5 | "testing" 6 | 7 | "github.com/google/go-cmp/cmp" 8 | ) 9 | 10 | func ExampleTaskParam() { 11 | var ( 12 | args = []interface{}{2} 13 | kwargs = map[string]interface{}{"b": 3} 14 | ) 15 | p := NewTaskParam(args, kwargs) 16 | p.NameArgs("a", "b") 17 | 18 | fmt.Println(p.Get("a")) 19 | fmt.Println(p.Get("b")) 20 | fmt.Println(p.Get("c")) 21 | // Output: 22 | // 2 true 23 | // 3 true 24 | // false 25 | } 26 | 27 | func TestTaskParamGet(t *testing.T) { 28 | tests := map[string]struct { 29 | p *TaskParam 30 | argNames []string 31 | pname string 32 | want interface{} 33 | exists bool 34 | }{ 35 | "no-params": { 36 | p: NewTaskParam(nil, nil), 37 | pname: 
"a", 38 | want: nil, 39 | exists: false, 40 | }, 41 | "found-kwarg": { 42 | p: NewTaskParam( 43 | nil, 44 | map[string]interface{}{"b": 3}, 45 | ), 46 | pname: "b", 47 | want: 3, 48 | exists: true, 49 | }, 50 | "unnamed-arg": { 51 | p: NewTaskParam( 52 | []interface{}{2}, 53 | nil, 54 | ), 55 | pname: "a", 56 | want: nil, 57 | exists: false, 58 | }, 59 | "found-named-arg": { 60 | p: NewTaskParam( 61 | []interface{}{2}, 62 | nil, 63 | ), 64 | pname: "a", 65 | argNames: []string{"a"}, 66 | want: 2, 67 | exists: true, 68 | }, 69 | "args-lt-names": { 70 | p: NewTaskParam( 71 | []interface{}{2}, 72 | nil, 73 | ), 74 | argNames: []string{"a", "b", "c"}, 75 | pname: "c", 76 | want: nil, 77 | exists: false, 78 | }, 79 | } 80 | 81 | for name, tc := range tests { 82 | t.Run(name, func(t *testing.T) { 83 | tc.p.NameArgs(tc.argNames...) 84 | 85 | got, ok := tc.p.Get(tc.pname) 86 | if diff := cmp.Diff(tc.want, got); diff != "" { 87 | t.Error(diff) 88 | } 89 | if tc.exists != ok { 90 | t.Errorf("expected %t got %t", tc.exists, ok) 91 | } 92 | }) 93 | } 94 | } 95 | 96 | func TestTaskParamMustInt(t *testing.T) { 97 | tests := map[string]struct { 98 | p *TaskParam 99 | argNames []string 100 | pname string 101 | want int 102 | panics bool 103 | }{ 104 | "float64": { 105 | p: NewTaskParam([]interface{}{2.0}, nil), 106 | argNames: []string{"a"}, 107 | pname: "a", 108 | want: 2, 109 | }, 110 | "int": { 111 | p: NewTaskParam([]interface{}{2}, nil), 112 | argNames: []string{"a"}, 113 | pname: "a", 114 | want: 2, 115 | }, 116 | "string": { 117 | p: NewTaskParam([]interface{}{"2"}, nil), 118 | argNames: []string{"a"}, 119 | pname: "a", 120 | panics: true, 121 | }, 122 | } 123 | 124 | for name, tc := range tests { 125 | t.Run(name, func(t *testing.T) { 126 | tc.p.NameArgs(tc.argNames...) 
127 | 128 | defer func() { 129 | if r := recover(); r != nil { 130 | if !tc.panics { 131 | t.Error(r) 132 | } 133 | } 134 | }() 135 | 136 | got := tc.p.MustInt(tc.pname) 137 | if tc.want != got { 138 | t.Errorf("expected %d got %d", tc.want, got) 139 | } 140 | }) 141 | } 142 | } 143 | -------------------------------------------------------------------------------- /examples/redis/metrics/main.go: -------------------------------------------------------------------------------- 1 | // Program metrics is a Celery worker with metrics middleware. 2 | package main 3 | 4 | import ( 5 | "context" 6 | "fmt" 7 | "net/http" 8 | "os" 9 | "os/signal" 10 | "time" 11 | 12 | "github.com/go-kit/log" 13 | celery "github.com/marselester/gopher-celery" 14 | "github.com/oklog/run" 15 | "github.com/prometheus/client_golang/prometheus" 16 | "github.com/prometheus/client_golang/prometheus/promhttp" 17 | ) 18 | 19 | const serverAddr = ":8080" 20 | 21 | func main() { 22 | logger := log.NewJSONLogger(log.NewSyncWriter(os.Stderr)) 23 | 24 | m := metrics{ 25 | total: prometheus.NewCounterVec( 26 | prometheus.CounterOpts{ 27 | Name: "tasks_total", 28 | Help: "How many Celery tasks processed, partitioned by task name and error.", 29 | }, 30 | []string{"task", "error"}, 31 | ), 32 | duration: prometheus.NewHistogramVec( 33 | prometheus.HistogramOpts{ 34 | Name: "task_duration_seconds", 35 | Help: "How long it took in seconds to process a task.", 36 | Buckets: []float64{ 37 | 0.016, 0.032, 0.064, 0.128, 0.256, 0.512, 1.024, 2.048, 4.096, 8.192, 16.384, 32.768, 60, 38 | }, 39 | }, 40 | []string{"task"}, 41 | ), 42 | } 43 | prometheus.MustRegister(m.total) 44 | prometheus.MustRegister(m.duration) 45 | http.Handle("/metrics", promhttp.Handler()) 46 | 47 | app := celery.NewApp( 48 | celery.WithLogger(logger), 49 | celery.WithMiddlewares(m.middleware), 50 | ) 51 | app.Register( 52 | "myproject.mytask", 53 | "important", 54 | func(ctx context.Context, p *celery.TaskParam) error { 55 | p.NameArgs("a", 
"b") 56 | fmt.Printf("received a=%s b=%s\n", p.MustString("a"), p.MustString("b")) 57 | return nil 58 | }, 59 | ) 60 | 61 | ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt) 62 | defer stop() 63 | 64 | srv := http.Server{Addr: serverAddr} 65 | 66 | var g run.Group 67 | { 68 | g.Add(func() error { 69 | logger.Log("msg", "starting http server with metrics") 70 | return srv.ListenAndServe() 71 | }, func(err error) { 72 | logger.Log("msg", "shutting down http server", "err", err) 73 | err = srv.Shutdown(ctx) 74 | logger.Log("msg", "http server shut down", "err", err) 75 | }) 76 | } 77 | { 78 | g.Add(func() error { 79 | logger.Log("msg", "waiting for tasks...") 80 | return app.Run(ctx) 81 | }, func(err error) { 82 | stop() 83 | logger.Log("msg", "celery shut down", "err", err) 84 | }) 85 | } 86 | err := g.Run() 87 | 88 | logger.Log("msg", "program stopped", "err", err) 89 | } 90 | 91 | type metrics struct { 92 | total *prometheus.CounterVec 93 | duration *prometheus.HistogramVec 94 | } 95 | 96 | func (m *metrics) middleware(next celery.TaskF) celery.TaskF { 97 | return func(ctx context.Context, p *celery.TaskParam) (err error) { 98 | name, ok := ctx.Value(celery.ContextKeyTaskName).(string) 99 | if !ok { 100 | return fmt.Errorf("task name not found in context") 101 | } 102 | 103 | defer func(begin time.Time) { 104 | m.total.With(prometheus.Labels{ 105 | "task": name, 106 | "error": fmt.Sprint(err != nil), 107 | }).Inc() 108 | m.duration.With(prometheus.Labels{ 109 | "task": name, 110 | }).Observe(time.Since(begin).Seconds()) 111 | }(time.Now()) 112 | 113 | return next(ctx, p) 114 | } 115 | } 116 | -------------------------------------------------------------------------------- /param.go: -------------------------------------------------------------------------------- 1 | package celery 2 | 3 | import "reflect" 4 | 5 | // NewTaskParam returns a task param which facilitates access to args and kwargs. 
6 | func NewTaskParam(args []interface{}, kwargs map[string]interface{}) *TaskParam { 7 | return &TaskParam{ 8 | args: args, 9 | kwargs: kwargs, 10 | } 11 | } 12 | 13 | // TaskParam provides access to task's positional and keyword arguments. 14 | // A task function might not know upfront how parameters will be supplied from the caller. 15 | // They could be passed as positional arguments f(2, 3), 16 | // keyword arguments f(a=2, b=3) or a mix of both f(2, b=3). 17 | // In this case the arguments should be named and accessed by name, 18 | // see NameArgs and Get methods. 19 | // 20 | // Methods prefixed with Must panic if they can't find an argument name 21 | // or can't cast it to the corresponding type. 22 | // The panic is logged by a worker and it doesn't affect other tasks. 23 | type TaskParam struct { 24 | // argNames is map of argument names to the respective args indices. 25 | argNames map[string]int 26 | // args are arguments. 27 | args []interface{} 28 | // kwargs are keyword arguments. 29 | kwargs map[string]interface{} 30 | } 31 | 32 | // Args returns task's positional arguments. 33 | func (p *TaskParam) Args() []interface{} { 34 | return p.args 35 | } 36 | 37 | // Kwargs returns task's keyword arguments. 38 | func (p *TaskParam) Kwargs() map[string]interface{} { 39 | return p.kwargs 40 | } 41 | 42 | // NameArgs assigns names to the task arguments. 43 | func (p *TaskParam) NameArgs(name ...string) { 44 | p.argNames = make(map[string]int, len(p.args)) 45 | 46 | for i := 0; i < len(name); i++ { 47 | p.argNames[name[i]] = i 48 | } 49 | } 50 | 51 | // Get returns a parameter by name. 52 | // Firstly it tries to look it up in Kwargs, 53 | // and then in Args if their names were provided by the client. 
54 | func (p *TaskParam) Get(name string) (v interface{}, ok bool) { 55 | if v, ok = p.kwargs[name]; ok { 56 | return v, true 57 | } 58 | 59 | var pos int 60 | pos, ok = p.argNames[name] 61 | if !ok || pos >= len(p.args) { 62 | return nil, false 63 | } 64 | 65 | return p.args[pos], true 66 | } 67 | 68 | // MustString looks up a parameter by name and casts it to string. 69 | // It panics if a parameter is missing or of a wrong type. 70 | func (p *TaskParam) MustString(name string) string { 71 | v, ok := p.Get(name) 72 | if !ok { 73 | panic("param not found") 74 | } 75 | return v.(string) 76 | } 77 | 78 | // MustInt looks up a parameter by name and casts it to integer. 79 | // It panics if a parameter is missing or of a wrong type. 80 | func (p *TaskParam) MustInt(name string) int { 81 | v, ok := p.Get(name) 82 | if !ok { 83 | panic("param not found") 84 | } 85 | 86 | switch reflect.TypeOf(v).Kind() { 87 | case reflect.Float64: 88 | return int(v.(float64)) 89 | default: 90 | return v.(int) 91 | } 92 | } 93 | 94 | // MustFloat looks up a parameter by name and casts it to float. 95 | // It panics if a parameter is missing or of a wrong type. 96 | func (p *TaskParam) MustFloat(name string) float64 { 97 | v, ok := p.Get(name) 98 | if !ok { 99 | panic("param not found") 100 | } 101 | return v.(float64) 102 | } 103 | 104 | // MustBool looks up a parameter by name and casts it to boolean. 105 | // It panics if a parameter is missing or of a wrong type. 106 | func (p *TaskParam) MustBool(name string) bool { 107 | v, ok := p.Get(name) 108 | if !ok { 109 | panic("param not found") 110 | } 111 | return v.(bool) 112 | } 113 | -------------------------------------------------------------------------------- /goredis/broker.go: -------------------------------------------------------------------------------- 1 | // Package goredis implements a Celery broker using Redis 2 | // and https://github.com/redis/go-redis. 
3 | package goredis 4 | 5 | import ( 6 | "context" 7 | "fmt" 8 | "time" 9 | 10 | "github.com/redis/go-redis/v9" 11 | 12 | "github.com/marselester/gopher-celery/internal/broker" 13 | ) 14 | 15 | // DefaultReceiveTimeout defines how many seconds the broker's Receive command 16 | // should block waiting for results from Redis. 17 | const DefaultReceiveTimeout = 5 18 | 19 | // BrokerOption sets up a Broker. 20 | type BrokerOption func(*Broker) 21 | 22 | // WithReceiveTimeout sets a timeout of how long the broker's Receive command 23 | // should block waiting for results from Redis. 24 | // Larger the timeout, longer the client will have to wait for Celery app to exit. 25 | // Smaller the timeout, more BRPOP commands would have to be sent to Redis. 26 | func WithReceiveTimeout(timeout time.Duration) BrokerOption { 27 | return func(br *Broker) { 28 | br.receiveTimeout = timeout 29 | } 30 | } 31 | 32 | // WithClient sets Redis client representing a pool of connections. 33 | func WithClient(c *redis.Client) BrokerOption { 34 | return func(br *Broker) { 35 | br.pool = c 36 | } 37 | } 38 | 39 | // NewBroker creates a broker backed by Redis. 40 | // By default, it connects to localhost. 41 | func NewBroker(options ...BrokerOption) *Broker { 42 | br := Broker{ 43 | receiveTimeout: DefaultReceiveTimeout * time.Second, 44 | ctx: context.Background(), 45 | } 46 | for _, opt := range options { 47 | opt(&br) 48 | } 49 | 50 | if br.pool == nil { 51 | br.pool = redis.NewClient(&redis.Options{}) 52 | } 53 | return &br 54 | } 55 | 56 | // Broker is a Redis broker that sends/receives messages from specified queues. 57 | type Broker struct { 58 | pool *redis.Client 59 | queues []string 60 | receiveTimeout time.Duration 61 | ctx context.Context 62 | } 63 | 64 | // Send inserts the specified message at the head of the queue using LPUSH command. 65 | // Note, the method is safe to call concurrently. 
66 | func (br *Broker) Send(m []byte, q string) error { 67 | res := br.pool.LPush(br.ctx, q, m) 68 | return res.Err() 69 | } 70 | 71 | // Observe sets the queues from which the tasks should be received. 72 | // Note, the method is not concurrency safe. 73 | func (br *Broker) Observe(queues []string) error { 74 | br.queues = queues 75 | return nil 76 | } 77 | 78 | // Receive fetches a Celery task message from a tail of one of the queues in Redis. 79 | // After a timeout it returns nil, nil. 80 | // 81 | // Celery relies on BRPOP command to process messages fairly, see https://github.com/celery/kombu/issues/166. 82 | // Redis BRPOP is a blocking list pop primitive. 83 | // It blocks the connection when there are no elements to pop from any of the given lists. 84 | // An element is popped from the tail of the first list that is non-empty, 85 | // with the given keys being checked in the order that they are given, 86 | // see https://redis.io/commands/brpop/. 87 | // 88 | // Note, the method is not concurrency safe. 89 | func (br *Broker) Receive() ([]byte, error) { 90 | res := br.pool.BRPop(br.ctx, br.receiveTimeout, br.queues...) 91 | err := res.Err() 92 | if err == redis.Nil { 93 | return nil, nil 94 | } 95 | if err != nil { 96 | return nil, fmt.Errorf("failed to BRPOP %v: %w", br.queues, err) 97 | } 98 | 99 | // Put the Celery queue name to the end of the slice for fair processing. 100 | q := res.Val()[0] 101 | b := res.Val()[1] 102 | broker.Move2back(br.queues, q) 103 | return []byte(b), nil 104 | } 105 | -------------------------------------------------------------------------------- /redis/broker.go: -------------------------------------------------------------------------------- 1 | // Package redis implements a Celery broker using Redis 2 | // and github.com/gomodule/redigo. 
3 | package redis 4 | 5 | import ( 6 | "fmt" 7 | "time" 8 | 9 | "github.com/gomodule/redigo/redis" 10 | 11 | "github.com/marselester/gopher-celery/internal/broker" 12 | ) 13 | 14 | // DefaultReceiveTimeout defines how many seconds the broker's Receive command 15 | // should block waiting for results from Redis. 16 | const DefaultReceiveTimeout = 5 17 | 18 | // BrokerOption sets up a Broker. 19 | type BrokerOption func(*Broker) 20 | 21 | // WithReceiveTimeout sets a timeout of how long the broker's Receive command 22 | // should block waiting for results from Redis. 23 | // Larger the timeout, longer the client will have to wait for Celery app to exit. 24 | // Smaller the timeout, more BRPOP commands would have to be sent to Redis. 25 | // 26 | // Note, the read timeout you specified with redis.DialReadTimeout() method 27 | // should be bigger than the receive timeout. 28 | // Otherwise redigo would return i/o timeout error. 29 | func WithReceiveTimeout(timeout time.Duration) BrokerOption { 30 | return func(br *Broker) { 31 | sec := int(timeout.Seconds()) 32 | if sec <= 0 { 33 | sec = 1 34 | } 35 | br.receiveTimeout = sec 36 | } 37 | } 38 | 39 | // WithPool sets Redis connection pool. 40 | func WithPool(pool *redis.Pool) BrokerOption { 41 | return func(br *Broker) { 42 | br.pool = pool 43 | } 44 | } 45 | 46 | // NewBroker creates a broker backed by Redis. 47 | // By default it connects to localhost. 48 | func NewBroker(options ...BrokerOption) *Broker { 49 | br := Broker{ 50 | receiveTimeout: DefaultReceiveTimeout, 51 | } 52 | for _, opt := range options { 53 | opt(&br) 54 | } 55 | 56 | if br.pool == nil { 57 | br.pool = &redis.Pool{ 58 | Dial: func() (redis.Conn, error) { 59 | return redis.DialURL("redis://localhost") 60 | }, 61 | } 62 | } 63 | return &br 64 | } 65 | 66 | // Broker is a Redis broker that sends/receives messages from specified queues. 
67 | type Broker struct { 68 | pool *redis.Pool 69 | queues []string 70 | receiveTimeout int 71 | } 72 | 73 | // Send inserts the specified message at the head of the queue using LPUSH command. 74 | // Note, the method is safe to call concurrently. 75 | func (br *Broker) Send(m []byte, q string) error { 76 | conn := br.pool.Get() 77 | defer conn.Close() //nolint:errcheck 78 | 79 | _, err := conn.Do("LPUSH", q, m) 80 | return err 81 | } 82 | 83 | // Observe sets the queues from which the tasks should be received. 84 | // Note, the method is not concurrency safe. 85 | func (br *Broker) Observe(queues []string) error { 86 | br.queues = queues 87 | return nil 88 | } 89 | 90 | // Receive fetches a Celery task message from a tail of one of the queues in Redis. 91 | // After a timeout it returns nil, nil. 92 | // 93 | // Celery relies on BRPOP command to process messages fairly, see https://github.com/celery/kombu/issues/166. 94 | // Redis BRPOP is a blocking list pop primitive. 95 | // It blocks the connection when there are no elements to pop from any of the given lists. 96 | // An element is popped from the tail of the first list that is non-empty, 97 | // with the given keys being checked in the order that they are given, 98 | // see https://redis.io/commands/brpop/. 99 | // 100 | // Note, the method is not concurrency safe. 101 | func (br *Broker) Receive() ([]byte, error) { 102 | conn := br.pool.Get() 103 | defer conn.Close() //nolint:errcheck 104 | 105 | // See the discussion regarding timeout and Context cancellation 106 | // https://github.com/gomodule/redigo/issues/207#issuecomment-283815775. 
107 | res, err := redis.ByteSlices(conn.Do( 108 | "BRPOP", 109 | redis.Args{}.AddFlat(br.queues).Add(br.receiveTimeout)..., 110 | )) 111 | if err == redis.ErrNil { 112 | return nil, nil 113 | } 114 | if err != nil { 115 | return nil, fmt.Errorf("failed to BRPOP %v: %w", br.queues, err) 116 | } 117 | 118 | // Put the Celery queue name to the end of the slice for fair processing. 119 | q := string(res[0]) 120 | b := res[1] 121 | broker.Move2back(br.queues, q) 122 | return b, nil 123 | } 124 | -------------------------------------------------------------------------------- /examples/go.sum: -------------------------------------------------------------------------------- 1 | github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= 2 | github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= 3 | github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs= 4 | github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c= 5 | github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA= 6 | github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0= 7 | github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= 8 | github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= 9 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= 10 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 11 | github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= 12 | github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= 13 | github.com/go-kit/log v0.2.1 h1:MRVx0/zhvdseW+Gza6N9rVzU/IVzaeE1SFI4raAhmBU= 14 | github.com/go-kit/log v0.2.1/go.mod h1:NwTd00d/i8cPZ3xOwwiv2PO5MOcx78fFErGNcVmBjv0= 15 | 
github.com/go-logfmt/logfmt v0.6.0 h1:wGYYu3uicYdqXVgoYbvnkrPVXkuLM1p1ifugDMEdRi4= 16 | github.com/go-logfmt/logfmt v0.6.0/go.mod h1:WYhtIu8zTZfxdn5+rREduYbwxfcBr/Vr6KEVveWlfTs= 17 | github.com/gomodule/redigo v1.9.2 h1:HrutZBLhSIU8abiSfW8pj8mPhOyMYjZT/wcA4/L9L9s= 18 | github.com/gomodule/redigo v1.9.2/go.mod h1:KsU3hiK/Ay8U42qpaJk+kuNa3C+spxapWpM+ywhcgtw= 19 | github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= 20 | github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= 21 | github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= 22 | github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= 23 | github.com/klauspost/compress v1.17.9 h1:6KIumPrER1LHsvBVuDa0r5xaG0Es51mhhB9BQB2qeMA= 24 | github.com/klauspost/compress v1.17.9/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw= 25 | github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= 26 | github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= 27 | github.com/marselester/backoff v0.0.1 h1:kqdo2QHvKxfo2rH+28usaI94atP7uiGYNsICLHY9kgE= 28 | github.com/marselester/backoff v0.0.1/go.mod h1:ONi5Ngkrx9Tcyb3rgcRIBuz2FM6aRDjyZEncpWgayns= 29 | github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA= 30 | github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ= 31 | github.com/oklog/run v1.1.0 h1:GEenZ1cK0+q0+wsJew9qUg/DyD8k3JzYsZAi5gYi2mA= 32 | github.com/oklog/run v1.1.0/go.mod h1:sVPdnTZT1zYwAJeCMu2Th4T21pA3FPOQRfWjQlk7DVU= 33 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 34 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 35 | github.com/prometheus/client_golang v1.20.5 h1:cxppBPuYhUnsO6yo/aoRol4L7q7UFfdm+bR9r+8l63Y= 36 | 
github.com/prometheus/client_golang v1.20.5/go.mod h1:PIEt8X02hGcP8JWbeHyeZ53Y/jReSnHgO035n//V5WE= 37 | github.com/prometheus/client_model v0.6.1 h1:ZKSh/rekM+n3CeS952MLRAdFwIKqeY8b62p8ais2e9E= 38 | github.com/prometheus/client_model v0.6.1/go.mod h1:OrxVMOVHjw3lKMa8+x6HeMGkHMQyHDk9E3jmP2AmGiY= 39 | github.com/prometheus/common v0.61.0 h1:3gv/GThfX0cV2lpO7gkTUwZru38mxevy90Bj8YFSRQQ= 40 | github.com/prometheus/common v0.61.0/go.mod h1:zr29OCN/2BsJRaFwG8QOBr41D6kkchKbpeNH7pAjb/s= 41 | github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0learggepc= 42 | github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk= 43 | github.com/rabbitmq/amqp091-go v1.10.0 h1:STpn5XsHlHGcecLmMFCtg7mqq0RnD+zFr4uzukfVhBw= 44 | github.com/rabbitmq/amqp091-go v1.10.0/go.mod h1:Hy4jKW5kQART1u+JkDTF9YYOQUHXqMuhrgxOEeS7G4o= 45 | github.com/redis/go-redis/v9 v9.7.0 h1:HhLSs+B6O021gwzl+locl0zEDnyNkxMtf/Z3NNBMa9E= 46 | github.com/redis/go-redis/v9 v9.7.0/go.mod h1:f6zhXITC7JUJIlPEiBOTXxJgPLdZcA93GewI7inzyWw= 47 | github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= 48 | github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= 49 | go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= 50 | go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= 51 | golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ= 52 | golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= 53 | golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA= 54 | golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= 55 | google.golang.org/protobuf v1.36.1 h1:yBPeRvTftaleIgM3PZ/WBIZ7XM/eEYAaEyCwvyjq/gk= 56 | google.golang.org/protobuf v1.36.1/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE= 57 | gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= 58 | 
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 59 | -------------------------------------------------------------------------------- /protocol/json.go: -------------------------------------------------------------------------------- 1 | package protocol 2 | 3 | import ( 4 | "bytes" 5 | "encoding/base64" 6 | "encoding/json" 7 | "fmt" 8 | "sync" 9 | "time" 10 | ) 11 | 12 | // NewJSONSerializer returns JSONSerializer. 13 | func NewJSONSerializer() *JSONSerializer { 14 | return &JSONSerializer{ 15 | pool: sync.Pool{New: func() interface{} { 16 | return &bytes.Buffer{} 17 | }}, 18 | now: time.Now, 19 | } 20 | } 21 | 22 | // JSONSerializer encodes/decodes a task messages in JSON format. 23 | // The zero value is not usable. 24 | type JSONSerializer struct { 25 | pool sync.Pool 26 | now func() time.Time 27 | } 28 | 29 | // jsonInboundV1Body helps to decode a message body v1. 30 | type jsonInboundV1Body struct { 31 | ID string `json:"id"` 32 | Task string `json:"task"` 33 | Args []interface{} `json:"args"` 34 | Kwargs map[string]interface{} `json:"kwargs"` 35 | Expires time.Time `json:"expires"` 36 | } 37 | 38 | // Decode parses the JSON-encoded message body s 39 | // depending on Celery protocol version (v1 or v2). 40 | // The task t is updated with the decoded params. 
41 | func (ser *JSONSerializer) Decode(p int, s string, t *Task) error { 42 | b, err := base64.StdEncoding.DecodeString(s) 43 | if err != nil { 44 | return fmt.Errorf("base64 decode: %w", err) 45 | } 46 | 47 | switch p { 48 | case V1: 49 | var body jsonInboundV1Body 50 | if err := json.Unmarshal(b, &body); err != nil { 51 | return fmt.Errorf("json decode: %w", err) 52 | } 53 | 54 | t.ID = body.ID 55 | t.Name = body.Task 56 | t.Args = body.Args 57 | t.Kwargs = body.Kwargs 58 | t.Expires = body.Expires 59 | case V2: 60 | var a [3]interface{} 61 | if err := json.Unmarshal(b, &a); err != nil { 62 | return fmt.Errorf("json decode: %w", err) 63 | } 64 | args, ok := a[0].([]interface{}) 65 | if !ok { 66 | return fmt.Errorf("expected args: %v", a[0]) 67 | } 68 | kwargs, ok := a[1].(map[string]interface{}) 69 | if !ok { 70 | return fmt.Errorf("expected kwargs: %v", a[1]) 71 | } 72 | 73 | t.Args = args 74 | t.Kwargs = kwargs 75 | default: 76 | return fmt.Errorf("unknown protocol version %d", p) 77 | } 78 | 79 | return nil 80 | } 81 | 82 | // Encode encodes task t using protocol version p and returns the message body s. 83 | func (ser *JSONSerializer) Encode(p int, t *Task) (s string, err error) { 84 | if p == V1 { 85 | return ser.encodeV1(t) 86 | } 87 | return ser.encodeV2(t) 88 | } 89 | 90 | // jsonOutboundV1Body is an auxiliary task struct to encode the message body v1 in json. 91 | type jsonOutboundV1Body struct { 92 | ID string `json:"id"` 93 | Task string `json:"task"` 94 | Args []interface{} `json:"args"` 95 | Kwargs json.RawMessage `json:"kwargs"` 96 | Expires *string `json:"expires"` 97 | // Retries is a current number of times this task has been retried. 98 | // It's always set to zero. 99 | Retries int `json:"retries"` 100 | // ETA is an estimated time of arrival in ISO 8601 format, e.g., 2009-11-17T12:30:56.527191. 101 | // If not provided the message isn't scheduled, but will be executed ASAP. 
102 | ETA string `json:"eta"` 103 | // UTC indicates to use the UTC timezone or the current local timezone. 104 | UTC bool `json:"utc"` 105 | } 106 | 107 | func (ser *JSONSerializer) encodeV1(t *Task) (s string, err error) { 108 | v := jsonOutboundV1Body{ 109 | ID: t.ID, 110 | Task: t.Name, 111 | Args: t.Args, 112 | Kwargs: jsonEmptyMap, 113 | ETA: ser.now().Format(time.RFC3339), 114 | UTC: true, 115 | } 116 | if t.Args == nil { 117 | v.Args = make([]interface{}, 0) 118 | } 119 | if t.Kwargs != nil { 120 | if v.Kwargs, err = json.Marshal(t.Kwargs); err != nil { 121 | return "", fmt.Errorf("kwargs json encode: %w", err) 122 | } 123 | } 124 | if !t.Expires.IsZero() { 125 | s := t.Expires.Format(time.RFC3339) 126 | v.Expires = &s 127 | } 128 | 129 | buf := ser.pool.Get().(*bytes.Buffer) 130 | buf.Reset() 131 | defer ser.pool.Put(buf) 132 | 133 | if err = json.NewEncoder(buf).Encode(&v); err != nil { 134 | return "", fmt.Errorf("json encode: %w", err) 135 | } 136 | 137 | return base64.StdEncoding.EncodeToString(buf.Bytes()), nil 138 | } 139 | 140 | const ( 141 | // jsonV2opts represents blank task options in protocol v2. 142 | // They are blank because none of those features are supported here. 143 | jsonV2opts = `{"callbacks":null,"errbacks":null,"chain":null,"chord":null}` 144 | // jsonV2noparams is a base64+json encoded task with no args/kwargs when protocol v2 is used. 145 | // It helps to reduce allocs. 
146 | jsonV2noparams = "W1tdLCB7fSwgeyJjYWxsYmFja3MiOiBudWxsLCAiZXJyYmFja3MiOiBudWxsLCAiY2hhaW4iOiBudWxsLCAiY2hvcmQiOiBudWxsfV0=" 147 | ) 148 | 149 | func (ser *JSONSerializer) encodeV2(t *Task) (s string, err error) { 150 | if t.Args == nil && t.Kwargs == nil { 151 | return jsonV2noparams, nil 152 | } 153 | 154 | buf := ser.pool.Get().(*bytes.Buffer) 155 | buf.Reset() 156 | defer ser.pool.Put(buf) 157 | 158 | buf.WriteRune('[') 159 | { 160 | js := json.NewEncoder(buf) 161 | if t.Args == nil { 162 | buf.WriteString("[]") 163 | } else if err = js.Encode(t.Args); err != nil { 164 | return "", fmt.Errorf("args json encode: %w", err) 165 | } 166 | 167 | buf.WriteRune(',') 168 | 169 | if t.Kwargs == nil { 170 | buf.WriteString("{}") 171 | } else if err = js.Encode(t.Kwargs); err != nil { 172 | return "", fmt.Errorf("kwargs json encode: %w", err) 173 | } 174 | 175 | buf.WriteRune(',') 176 | 177 | buf.WriteString(jsonV2opts) 178 | } 179 | buf.WriteRune(']') 180 | 181 | return base64.StdEncoding.EncodeToString(buf.Bytes()), nil 182 | } 183 | -------------------------------------------------------------------------------- /protocol/json_test.go: -------------------------------------------------------------------------------- 1 | package protocol 2 | 3 | import ( 4 | "bytes" 5 | "encoding/base64" 6 | "testing" 7 | "time" 8 | 9 | "github.com/google/go-cmp/cmp" 10 | ) 11 | 12 | func TestJSONSerializerEncode(t *testing.T) { 13 | tests := map[string]struct { 14 | task Task 15 | version int 16 | body string 17 | }{ 18 | "v2_noparams": { 19 | task: Task{ 20 | ID: "0ad73c66-f4c9-4600-bd20-96746e720eed", 21 | Name: "myproject.apps.myapp.tasks.mytask", 22 | }, 23 | version: 2, 24 | body: `[[], {}, {"callbacks": null, "errbacks": null, "chain": null, "chord": null}]`, 25 | }, 26 | "v2_args": { 27 | task: Task{ 28 | ID: "0ad73c66-f4c9-4600-bd20-96746e720eed", 29 | Name: "myproject.apps.myapp.tasks.mytask", 30 | Args: []interface{}{"fizz"}, 31 | }, 32 | version: 2, 33 | body: 
`[["fizz"],{},{"callbacks":null,"errbacks":null,"chain":null,"chord":null}]`, 34 | }, 35 | "v2_argskwargs": { 36 | task: Task{ 37 | ID: "0ad73c66-f4c9-4600-bd20-96746e720eed", 38 | Name: "myproject.apps.myapp.tasks.mytask", 39 | Args: []interface{}{"fizz"}, 40 | Kwargs: map[string]interface{}{"b": "bazz"}, 41 | }, 42 | version: 2, 43 | body: `[["fizz"],{"b":"bazz"},{"callbacks":null,"errbacks":null,"chain":null,"chord":null}]`, 44 | }, 45 | "v1_noparams": { 46 | task: Task{ 47 | ID: "0ad73c66-f4c9-4600-bd20-96746e720eed", 48 | Name: "myproject.apps.myapp.tasks.mytask", 49 | }, 50 | version: 1, 51 | body: `{"id":"0ad73c66-f4c9-4600-bd20-96746e720eed","task":"myproject.apps.myapp.tasks.mytask","args":[],"kwargs":{},"expires":null,"retries":0,"eta":"2009-11-17T12:30:56Z","utc":true}`, 52 | }, 53 | "v1_args": { 54 | task: Task{ 55 | ID: "0ad73c66-f4c9-4600-bd20-96746e720eed", 56 | Name: "myproject.apps.myapp.tasks.mytask", 57 | Args: []interface{}{"fizz"}, 58 | }, 59 | version: 1, 60 | body: `{"id":"0ad73c66-f4c9-4600-bd20-96746e720eed","task":"myproject.apps.myapp.tasks.mytask","args":["fizz"],"kwargs":{},"expires":null,"retries":0,"eta":"2009-11-17T12:30:56Z","utc":true}`, 61 | }, 62 | "v1_argskwargs": { 63 | task: Task{ 64 | ID: "0ad73c66-f4c9-4600-bd20-96746e720eed", 65 | Name: "myproject.apps.myapp.tasks.mytask", 66 | Args: []interface{}{"fizz"}, 67 | Kwargs: map[string]interface{}{"b": "bazz"}, 68 | }, 69 | version: 1, 70 | body: `{"id":"0ad73c66-f4c9-4600-bd20-96746e720eed","task":"myproject.apps.myapp.tasks.mytask","args":["fizz"],"kwargs":{"b":"bazz"},"expires":null,"retries":0,"eta":"2009-11-17T12:30:56Z","utc":true}`, 71 | }, 72 | } 73 | 74 | s := NewJSONSerializer() 75 | s.now = func() time.Time { 76 | // 2009-11-17T12:30:56Z 77 | return time.Date(2009, 11, 17, 12, 30, 56, 0, time.UTC) 78 | } 79 | for name, tc := range tests { 80 | t.Run(name, func(t *testing.T) { 81 | b64, err := s.Encode(tc.version, &tc.task) 82 | if err != nil { 83 | t.Fatal(err) 84 | } 
85 | gotb64, err := base64.StdEncoding.DecodeString(b64) 86 | if err != nil { 87 | t.Fatal(err) 88 | } 89 | 90 | got := string( 91 | bytes.ReplaceAll(gotb64, []byte("\n"), nil), 92 | ) 93 | if diff := cmp.Diff(tc.body, got); diff != "" { 94 | t.Error(diff, got) 95 | } 96 | }) 97 | } 98 | } 99 | 100 | func BenchmarkJSONSerializerEncode_v2NoParams(b *testing.B) { 101 | s := NewJSONSerializer() 102 | task := Task{ 103 | ID: "0ad73c66-f4c9-4600-bd20-96746e720eed", 104 | Name: "myproject.apps.myapp.tasks.mytask", 105 | } 106 | b.ResetTimer() 107 | 108 | for n := 0; n < b.N; n++ { 109 | s.Encode(2, &task) 110 | } 111 | } 112 | 113 | func BenchmarkJSONSerializerEncode_v2Args(b *testing.B) { 114 | s := NewJSONSerializer() 115 | task := Task{ 116 | ID: "0ad73c66-f4c9-4600-bd20-96746e720eed", 117 | Name: "myproject.apps.myapp.tasks.mytask", 118 | Args: []interface{}{"fizz"}, 119 | } 120 | b.ResetTimer() 121 | 122 | for n := 0; n < b.N; n++ { 123 | s.Encode(2, &task) 124 | } 125 | } 126 | 127 | func BenchmarkJSONSerializerEncode_v2Kwargs(b *testing.B) { 128 | s := NewJSONSerializer() 129 | task := Task{ 130 | ID: "0ad73c66-f4c9-4600-bd20-96746e720eed", 131 | Name: "myproject.apps.myapp.tasks.mytask", 132 | Kwargs: map[string]interface{}{"b": "bazz"}, 133 | } 134 | b.ResetTimer() 135 | 136 | for n := 0; n < b.N; n++ { 137 | s.Encode(2, &task) 138 | } 139 | } 140 | 141 | func BenchmarkJSONSerializerEncode_v2ArgsKwargs(b *testing.B) { 142 | s := NewJSONSerializer() 143 | task := Task{ 144 | ID: "0ad73c66-f4c9-4600-bd20-96746e720eed", 145 | Name: "myproject.apps.myapp.tasks.mytask", 146 | Args: []interface{}{"fizz"}, 147 | Kwargs: map[string]interface{}{"b": "bazz"}, 148 | } 149 | b.ResetTimer() 150 | 151 | for n := 0; n < b.N; n++ { 152 | s.Encode(2, &task) 153 | } 154 | } 155 | 156 | func BenchmarkJSONSerializerEncode_v1NoParams(b *testing.B) { 157 | s := NewJSONSerializer() 158 | task := Task{ 159 | ID: "0ad73c66-f4c9-4600-bd20-96746e720eed", 160 | Name: 
"myproject.apps.myapp.tasks.mytask", 161 | } 162 | b.ResetTimer() 163 | 164 | for n := 0; n < b.N; n++ { 165 | s.Encode(1, &task) 166 | } 167 | } 168 | 169 | func BenchmarkJSONSerializerEncode_v1Args(b *testing.B) { 170 | s := NewJSONSerializer() 171 | task := Task{ 172 | ID: "0ad73c66-f4c9-4600-bd20-96746e720eed", 173 | Name: "myproject.apps.myapp.tasks.mytask", 174 | Args: []interface{}{"fizz"}, 175 | } 176 | b.ResetTimer() 177 | 178 | for n := 0; n < b.N; n++ { 179 | s.Encode(1, &task) 180 | } 181 | } 182 | 183 | func BenchmarkJSONSerializerEncode_v1Kwargs(b *testing.B) { 184 | s := NewJSONSerializer() 185 | task := Task{ 186 | ID: "0ad73c66-f4c9-4600-bd20-96746e720eed", 187 | Name: "myproject.apps.myapp.tasks.mytask", 188 | Kwargs: map[string]interface{}{"b": "bazz"}, 189 | } 190 | b.ResetTimer() 191 | 192 | for n := 0; n < b.N; n++ { 193 | s.Encode(1, &task) 194 | } 195 | } 196 | 197 | func BenchmarkJSONSerializerEncode_v1ArgsKwargs(b *testing.B) { 198 | s := NewJSONSerializer() 199 | task := Task{ 200 | ID: "0ad73c66-f4c9-4600-bd20-96746e720eed", 201 | Name: "myproject.apps.myapp.tasks.mytask", 202 | Args: []interface{}{"fizz"}, 203 | Kwargs: map[string]interface{}{"b": "bazz"}, 204 | } 205 | b.ResetTimer() 206 | 207 | for n := 0; n < b.N; n++ { 208 | s.Encode(1, &task) 209 | } 210 | } 211 | -------------------------------------------------------------------------------- /protocol/serializer_test.go: -------------------------------------------------------------------------------- 1 | package protocol 2 | 3 | import ( 4 | "bytes" 5 | "os" 6 | "path/filepath" 7 | "strings" 8 | "testing" 9 | "time" 10 | "unicode" 11 | 12 | "github.com/google/go-cmp/cmp" 13 | ) 14 | 15 | func TestSerializerRegistryDecode(t *testing.T) { 16 | tests := map[string]Task{ 17 | "v2_argskwargs.json": { 18 | ID: "0ad73c66-f4c9-4600-bd20-96746e720eed", 19 | Name: "myproject.apps.myapp.tasks.mytask", 20 | Args: []interface{}{"fizz"}, 21 | Kwargs: map[string]interface{}{"b": "bazz"}, 22 | 
Expires: time.Time{}, 23 | }, 24 | "v2_noparams.json": { 25 | ID: "3802f860-8d3c-4dad-b18c-597fb2ac728b", 26 | Name: "myproject.apps.myapp.tasks.mytask", 27 | Args: []interface{}{}, 28 | Kwargs: map[string]interface{}{}, 29 | Expires: time.Time{}, 30 | }, 31 | "v1_noparams.json": { 32 | ID: "0d09a6dd-99fc-436a-a41a-0dcaa4875459", 33 | Name: "myproject.apps.myapp.tasks.mytask", 34 | Args: []interface{}{}, 35 | Kwargs: map[string]interface{}{}, 36 | Expires: time.Time{}, 37 | }, 38 | } 39 | 40 | r := NewSerializerRegistry() 41 | for testfile, want := range tests { 42 | t.Run(testfile, func(t *testing.T) { 43 | filename := filepath.Join("testdata", testfile) 44 | content, err := os.ReadFile(filename) 45 | if err != nil { 46 | t.Fatal(err) 47 | } 48 | 49 | got, err := r.Decode(content) 50 | if err != nil { 51 | t.Fatal(err) 52 | } 53 | 54 | if diff := cmp.Diff(&want, got); diff != "" { 55 | t.Error(diff) 56 | } 57 | }) 58 | } 59 | } 60 | 61 | func TestSerializerRegistryEncode(t *testing.T) { 62 | tests := map[string]struct { 63 | task Task 64 | queue string 65 | format string 66 | version int 67 | 68 | msg string 69 | }{ 70 | "v1_argskwargs": { 71 | task: Task{ 72 | ID: "0ad73c66-f4c9-4600-bd20-96746e720eed", 73 | Name: "myproject.apps.myapp.tasks.mytask", 74 | Args: []interface{}{"fizz"}, 75 | Kwargs: map[string]interface{}{"b": "bazz"}, 76 | }, 77 | queue: "important", 78 | format: "application/json", 79 | version: 1, 80 | msg: ` 81 | { 82 | "body": "", 83 | "content-encoding": "utf-8", 84 | "content-type": "application/json", 85 | "headers": {}, 86 | "properties": { 87 | "delivery_info": { 88 | "exchange": "important", 89 | "routing_key": "important" 90 | }, 91 | "correlation_id": "0ad73c66-f4c9-4600-bd20-96746e720eed", 92 | "reply_to": "967ff33a-e83c-4225-99ea-1d945c62526a", 93 | "body_encoding": "base64", 94 | "delivery_tag": "967ff33a-e83c-4225-99ea-1d945c62526a", 95 | "delivery_mode": 2, 96 | "priority": 0 97 | } 98 | }`, 99 | }, 100 | "v2_argskwargs": { 101 | 
task: Task{ 102 | ID: "0ad73c66-f4c9-4600-bd20-96746e720eed", 103 | Name: "myproject.apps.myapp.tasks.mytask", 104 | Args: []interface{}{"fizz"}, 105 | Kwargs: map[string]interface{}{"b": "bazz"}, 106 | }, 107 | queue: "important", 108 | format: "application/json", 109 | version: 2, 110 | msg: ` 111 | { 112 | "body": "", 113 | "content-encoding": "utf-8", 114 | "content-type": "application/json", 115 | "headers": { 116 | "lang": "go", 117 | "id": "0ad73c66-f4c9-4600-bd20-96746e720eed", 118 | "root_id": "0ad73c66-f4c9-4600-bd20-96746e720eed", 119 | "task": "myproject.apps.myapp.tasks.mytask", 120 | "origin": "123@home", 121 | "expires": null, 122 | "retries": 0 123 | }, 124 | "properties": { 125 | "delivery_info": { 126 | "exchange": "important", 127 | "routing_key": "important" 128 | }, 129 | "correlation_id": "0ad73c66-f4c9-4600-bd20-96746e720eed", 130 | "reply_to": "967ff33a-e83c-4225-99ea-1d945c62526a", 131 | "body_encoding": "base64", 132 | "delivery_tag": "967ff33a-e83c-4225-99ea-1d945c62526a", 133 | "delivery_mode": 2, 134 | "priority": 0 135 | } 136 | }`, 137 | }, 138 | "v2_noparams": { 139 | task: Task{ 140 | ID: "0ad73c66-f4c9-4600-bd20-96746e720eed", 141 | Name: "myproject.apps.myapp.tasks.mytask", 142 | }, 143 | queue: "important", 144 | format: "application/json", 145 | version: 2, 146 | msg: ` 147 | { 148 | "body": "", 149 | "content-encoding": "utf-8", 150 | "content-type": "application/json", 151 | "headers": { 152 | "lang": "go", 153 | "id": "0ad73c66-f4c9-4600-bd20-96746e720eed", 154 | "root_id": "0ad73c66-f4c9-4600-bd20-96746e720eed", 155 | "task": "myproject.apps.myapp.tasks.mytask", 156 | "origin": "123@home", 157 | "expires": null, 158 | "retries": 0 159 | }, 160 | "properties": { 161 | "delivery_info": { 162 | "exchange": "important", 163 | "routing_key": "important" 164 | }, 165 | "correlation_id": "0ad73c66-f4c9-4600-bd20-96746e720eed", 166 | "reply_to": "967ff33a-e83c-4225-99ea-1d945c62526a", 167 | "body_encoding": "base64", 168 | 
"delivery_tag": "967ff33a-e83c-4225-99ea-1d945c62526a", 169 | "delivery_mode": 2, 170 | "priority": 0 171 | } 172 | }`, 173 | }, 174 | } 175 | 176 | r := NewSerializerRegistry() 177 | // Suppress body encoding to simplify testing. 178 | r.serializers["application/json"] = &mockSerializer{} 179 | r.uuid4 = func() string { 180 | return "967ff33a-e83c-4225-99ea-1d945c62526a" 181 | } 182 | r.origin = "123@home" 183 | 184 | for name, tc := range tests { 185 | t.Run(name, func(t *testing.T) { 186 | b, err := r.Encode(tc.queue, tc.format, tc.version, &tc.task) 187 | if err != nil { 188 | t.Fatal(err) 189 | } 190 | 191 | got := string( 192 | bytes.ReplaceAll(b, []byte("\n"), nil), 193 | ) 194 | want := strings.Map(func(r rune) rune { 195 | if !unicode.IsSpace(r) { 196 | return r 197 | } 198 | return -1 199 | }, tc.msg) 200 | if diff := cmp.Diff(want, got); diff != "" { 201 | t.Error(diff, want, got) 202 | } 203 | }) 204 | } 205 | } 206 | 207 | func BenchmarkSerializerRegistryEncode_v1Args(b *testing.B) { 208 | r := NewSerializerRegistry() 209 | r.serializers["application/json"] = &mockSerializer{} 210 | 211 | task := Task{ 212 | ID: "0ad73c66-f4c9-4600-bd20-96746e720eed", 213 | Name: "myproject.apps.myapp.tasks.mytask", 214 | Args: []interface{}{"fizz"}, 215 | } 216 | b.ResetTimer() 217 | 218 | for n := 0; n < b.N; n++ { 219 | r.Encode("important", "application/json", 1, &task) 220 | } 221 | } 222 | 223 | func BenchmarkSerializerRegistryEncode_v2Args(b *testing.B) { 224 | r := NewSerializerRegistry() 225 | r.serializers["application/json"] = &mockSerializer{} 226 | 227 | task := Task{ 228 | ID: "0ad73c66-f4c9-4600-bd20-96746e720eed", 229 | Name: "myproject.apps.myapp.tasks.mytask", 230 | Args: []interface{}{"fizz"}, 231 | } 232 | b.ResetTimer() 233 | 234 | for n := 0; n < b.N; n++ { 235 | r.Encode("important", "application/json", 2, &task) 236 | } 237 | } 238 | 239 | type mockSerializer struct { 240 | DecodeF func(p int, s string, t *Task) error 241 | EncodeF func(p int, t 
*Task) (s string, err error) 242 | } 243 | 244 | func (ser *mockSerializer) Decode(p int, s string, t *Task) error { 245 | if ser.DecodeF == nil { 246 | return nil 247 | } 248 | return ser.DecodeF(p, s, t) 249 | } 250 | 251 | func (ser *mockSerializer) Encode(p int, t *Task) (string, error) { 252 | if ser.EncodeF == nil { 253 | return "", nil 254 | } 255 | return ser.EncodeF(p, t) 256 | } 257 | -------------------------------------------------------------------------------- /rabbitmq/broker.go: -------------------------------------------------------------------------------- 1 | // Package rabbitmq implements a Celery broker using RabbitMQ 2 | // and github.com/rabbitmq/amqp091-go. 3 | package rabbitmq 4 | 5 | import ( 6 | "encoding/base64" 7 | "encoding/json" 8 | "time" 9 | 10 | amqp "github.com/rabbitmq/amqp091-go" 11 | 12 | "github.com/marselester/gopher-celery/internal/broker" 13 | ) 14 | 15 | // DefaultAmqpUri defines the default AMQP URI which is used to connect to RabbitMQ. 16 | const DefaultAmqpUri = "amqp://guest:guest@localhost:5672/" 17 | 18 | // DefaultReceiveTimeout defines how many seconds the broker's Receive command 19 | // should block waiting for results from RabbitMQ. 20 | const DefaultReceiveTimeout = 5 21 | 22 | // BrokerOption sets up a Broker. 23 | type BrokerOption func(*Broker) 24 | 25 | // Broker is a RabbitMQ broker that sends/receives messages from specified queues. 26 | type Broker struct { 27 | amqpUri string 28 | receiveTimeout time.Duration 29 | rawMode bool 30 | queues []string 31 | conn *amqp.Connection 32 | channel *amqp.Channel 33 | delivery map[string]<-chan amqp.Delivery 34 | } 35 | 36 | // WithAmqpUri sets the AMQP connection URI to RabbitMQ. 37 | func WithAmqpUri(amqpUri string) BrokerOption { 38 | return func(br *Broker) { 39 | br.amqpUri = amqpUri 40 | } 41 | } 42 | 43 | // WithReceiveTimeout sets a timeout of how long the broker's Receive command 44 | // should block waiting for results from RabbitMQ. 
45 | // Larger the timeout, longer the client will have to wait for Celery app to exit. 46 | // Smaller the timeout, more Get commands would have to be sent to RabbitMQ. 47 | func WithReceiveTimeout(timeout time.Duration) BrokerOption { 48 | return func(br *Broker) { 49 | br.receiveTimeout = timeout 50 | } 51 | } 52 | 53 | // WithClient sets RabbitMQ client representing a connection to RabbitMQ. 54 | func WithClient(c *amqp.Connection) BrokerOption { 55 | return func(br *Broker) { 56 | br.conn = c 57 | } 58 | } 59 | 60 | // NewBroker creates a broker backed by RabbitMQ. 61 | // By default, it connects to localhost. 62 | func NewBroker(options ...BrokerOption) (*Broker, error) { 63 | br := Broker{ 64 | amqpUri: DefaultAmqpUri, 65 | receiveTimeout: DefaultReceiveTimeout * time.Second, 66 | rawMode: false, 67 | delivery: make(map[string]<-chan amqp.Delivery), 68 | } 69 | for _, opt := range options { 70 | opt(&br) 71 | } 72 | 73 | if br.conn == nil { 74 | conn, err := amqp.Dial(br.amqpUri) 75 | if err != nil { 76 | return nil, err 77 | } 78 | 79 | br.conn = conn 80 | } 81 | 82 | channel, err := br.conn.Channel() 83 | if err != nil { 84 | return nil, err 85 | } 86 | 87 | br.channel = channel 88 | 89 | return &br, nil 90 | } 91 | 92 | // Send inserts the specified message at the head of the queue. 93 | // Note, the method is safe to call concurrently. 
94 | func (br *Broker) Send(m []byte, q string) error { 95 | var ( 96 | headers map[string]interface{} 97 | body []byte 98 | contentType string 99 | contentEncoding string 100 | deliveryMode uint8 101 | correlationId string 102 | replyTo string 103 | ) 104 | 105 | if br.rawMode { 106 | headers = make(amqp.Table) 107 | body = m 108 | contentType = "application/json" 109 | contentEncoding = "utf-8" 110 | deliveryMode = 2 111 | correlationId = "" 112 | replyTo = "" 113 | } else { 114 | var msgmap map[string]interface{} 115 | err := json.Unmarshal(m, &msgmap) 116 | if err != nil { 117 | return err 118 | } 119 | 120 | headers = msgmap["headers"].(map[string]interface{}) 121 | body, err = base64.StdEncoding.DecodeString(msgmap["body"].(string)) 122 | if err != nil { 123 | return err 124 | } 125 | contentType = msgmap["content-type"].(string) 126 | contentEncoding = msgmap["content-encoding"].(string) 127 | 128 | properties_in := msgmap["properties"].(map[string]interface{}) 129 | deliveryMode = uint8(properties_in["delivery_mode"].(float64)) 130 | correlationId = properties_in["correlation_id"].(string) 131 | replyTo = properties_in["reply_to"].(string) 132 | } 133 | 134 | return br.channel.Publish( 135 | "", // exchange 136 | q, // routing key 137 | false, // mandatory 138 | false, // immediate 139 | amqp.Publishing{ 140 | Headers: headers, 141 | ContentType: contentType, 142 | ContentEncoding: contentEncoding, 143 | DeliveryMode: deliveryMode, 144 | CorrelationId: correlationId, 145 | ReplyTo: replyTo, 146 | Body: body, 147 | }) 148 | } 149 | 150 | // Observe sets the queues from which the tasks should be received. 151 | // Note, the method is not concurrency safe. 152 | func (br *Broker) Observe(queues []string) error { 153 | br.queues = queues 154 | 155 | var ( 156 | durable = true 157 | autoDelete = false 158 | exclusive = false 159 | noWait = false 160 | ) 161 | for _, queue := range queues { 162 | // Check whether the queue exists. 
163 | // If the queue doesn't exist, attempt to create it. 164 | _, err := br.channel.QueueDeclarePassive( 165 | queue, 166 | durable, 167 | autoDelete, 168 | exclusive, 169 | noWait, 170 | nil, 171 | ) 172 | if err != nil { 173 | // QueueDeclarePassive() will close the channel if the queue does not exist, 174 | // so we have to create a new channel when this happens. 175 | if br.channel.IsClosed() { 176 | channel, err := br.conn.Channel() 177 | if err != nil { 178 | return err 179 | } 180 | 181 | br.channel = channel 182 | } 183 | 184 | _, err = br.channel.QueueDeclare( 185 | queue, 186 | durable, 187 | autoDelete, 188 | exclusive, 189 | noWait, 190 | nil, 191 | ) 192 | if err != nil { 193 | return err 194 | } 195 | } 196 | } 197 | 198 | return nil 199 | } 200 | 201 | // Receive fetches a Celery task message from a tail of one of the queues in RabbitMQ. 202 | // After a timeout it returns nil, nil. 203 | func (br *Broker) Receive() ([]byte, error) { 204 | queue := br.queues[0] 205 | // Put the Celery queue name to the end of the slice for fair processing. 206 | broker.Move2back(br.queues, queue) 207 | 208 | var err error 209 | 210 | delivery, deliveryExists := br.delivery[queue] 211 | if !deliveryExists { 212 | delivery, err = br.channel.Consume( 213 | queue, // queue 214 | "", // consumer 215 | true, // autoAck 216 | false, // exclusive 217 | false, // noLocal (ignored) 218 | false, // noWait 219 | nil, // args 220 | ) 221 | if err != nil { 222 | return nil, err 223 | } 224 | 225 | br.delivery[queue] = delivery 226 | } 227 | 228 | select { 229 | case msg := <-delivery: 230 | if br.rawMode { 231 | return msg.Body, nil 232 | } 233 | 234 | // Marshal msg from RabbitMQ Celery format to internal Celery format. 
235 | 236 | properties := make(map[string]interface{}) 237 | properties["correlation_id"] = msg.CorrelationId 238 | properties["reply_to"] = msg.ReplyTo 239 | properties["delivery_mode"] = msg.DeliveryMode 240 | properties["delivery_info"] = map[string]interface{}{ 241 | "exchange": msg.Exchange, 242 | "routing_key": msg.RoutingKey, 243 | } 244 | properties["priority"] = msg.Priority 245 | properties["body_encoding"] = "base64" 246 | properties["delivery_tag"] = msg.DeliveryTag 247 | 248 | imsg := make(map[string]interface{}) 249 | imsg["body"] = msg.Body 250 | imsg["content-encoding"] = msg.ContentEncoding 251 | imsg["content-type"] = msg.ContentType 252 | imsg["headers"] = msg.Headers 253 | imsg["properties"] = properties 254 | 255 | var result []byte 256 | result, err := json.Marshal(imsg) 257 | if err != nil { 258 | return nil, err 259 | } 260 | 261 | return result, nil 262 | 263 | case <-time.After(br.receiveTimeout): 264 | // Receive timeout 265 | return nil, nil 266 | } 267 | } 268 | -------------------------------------------------------------------------------- /celery_test.go: -------------------------------------------------------------------------------- 1 | package celery 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "os" 7 | "strings" 8 | "sync/atomic" 9 | "testing" 10 | "time" 11 | 12 | "github.com/go-kit/log" 13 | 14 | "github.com/marselester/gopher-celery/goredis" 15 | "github.com/marselester/gopher-celery/protocol" 16 | "github.com/marselester/gopher-celery/rabbitmq" 17 | ) 18 | 19 | func TestExecuteTaskPanic(t *testing.T) { 20 | app := NewApp() 21 | app.Register( 22 | "myproject.apps.myapp.tasks.mytask", 23 | "important", 24 | func(ctx context.Context, p *TaskParam) error { 25 | _ = p.Args()[100] 26 | return nil 27 | }, 28 | ) 29 | 30 | m := protocol.Task{ 31 | ID: "0ad73c66-f4c9-4600-bd20-96746e720eed", 32 | Name: "myproject.apps.myapp.tasks.mytask", 33 | Args: []interface{}{"fizz"}, 34 | Kwargs: map[string]interface{}{ 35 | "b": "bazz", 36 | }, 
37 | } 38 | 39 | want := "unexpected task error" 40 | err := app.executeTask(context.Background(), &m) 41 | if !strings.HasPrefix(err.Error(), want) { 42 | t.Errorf("expected %q got %q", want, err) 43 | } 44 | } 45 | 46 | func TestExecuteTaskMiddlewares(t *testing.T) { 47 | // The middlewares are called in the order they were defined, e.g., A -> B -> task. 48 | tests := map[string]struct { 49 | middlewares []Middleware 50 | want string 51 | }{ 52 | "A-B-task": { 53 | middlewares: []Middleware{ 54 | func(next TaskF) TaskF { 55 | return func(ctx context.Context, p *TaskParam) error { 56 | err := next(ctx, p) 57 | return fmt.Errorf("A -> %w", err) 58 | } 59 | }, 60 | func(next TaskF) TaskF { 61 | return func(ctx context.Context, p *TaskParam) error { 62 | err := next(ctx, p) 63 | return fmt.Errorf("B -> %w", err) 64 | } 65 | }, 66 | }, 67 | want: "A -> B -> task", 68 | }, 69 | "A-task": { 70 | middlewares: []Middleware{ 71 | func(next TaskF) TaskF { 72 | return func(ctx context.Context, p *TaskParam) error { 73 | err := next(ctx, p) 74 | return fmt.Errorf("A -> %w", err) 75 | } 76 | }, 77 | }, 78 | want: "A -> task", 79 | }, 80 | "empty chain": { 81 | middlewares: []Middleware{}, 82 | want: "task", 83 | }, 84 | "nil chain": { 85 | middlewares: nil, 86 | want: "task", 87 | }, 88 | "nil middleware panic": { 89 | middlewares: []Middleware{nil}, 90 | want: "unexpected task error", 91 | }, 92 | } 93 | 94 | ctx := context.Background() 95 | m := protocol.Task{ 96 | ID: "0ad73c66-f4c9-4600-bd20-96746e720eed", 97 | Name: "myproject.apps.myapp.tasks.mytask", 98 | Args: []interface{}{"fizz"}, 99 | Kwargs: map[string]interface{}{ 100 | "b": "bazz", 101 | }, 102 | } 103 | for name, tc := range tests { 104 | t.Run(name, func(t *testing.T) { 105 | app := NewApp( 106 | WithMiddlewares(tc.middlewares...), 107 | ) 108 | app.Register( 109 | "myproject.apps.myapp.tasks.mytask", 110 | "important", 111 | func(ctx context.Context, p *TaskParam) error { 112 | return fmt.Errorf("task") 113 | 
}, 114 | ) 115 | 116 | err := app.executeTask(ctx, &m) 117 | if !strings.HasPrefix(err.Error(), tc.want) { 118 | t.Errorf("expected %q got %q", tc.want, err) 119 | } 120 | }) 121 | } 122 | } 123 | 124 | func TestProduceAndConsume(t *testing.T) { 125 | app := NewApp(WithLogger(log.NewJSONLogger(os.Stderr))) 126 | err := app.Delay( 127 | "myproject.apps.myapp.tasks.mytask", 128 | "important", 129 | 2, 130 | 3, 131 | ) 132 | if err != nil { 133 | t.Fatal(err) 134 | } 135 | 136 | // The test finishes either when ctx times out or the task finishes. 137 | ctx, cancel := context.WithTimeout(context.Background(), time.Second) 138 | t.Cleanup(cancel) 139 | 140 | var sum int 141 | app.Register( 142 | "myproject.apps.myapp.tasks.mytask", 143 | "important", 144 | func(ctx context.Context, p *TaskParam) error { 145 | defer cancel() 146 | 147 | p.NameArgs("a", "b") 148 | sum = p.MustInt("a") + p.MustInt("b") 149 | return nil 150 | }, 151 | ) 152 | if err := app.Run(ctx); err != nil { 153 | t.Error(err) 154 | } 155 | 156 | want := 5 157 | if want != sum { 158 | t.Errorf("expected sum %d got %d", want, sum) 159 | } 160 | } 161 | 162 | func TestProduceAndConsume100times(t *testing.T) { 163 | app := NewApp(WithLogger(log.NewJSONLogger(os.Stderr))) 164 | for i := 0; i < 100; i++ { 165 | err := app.Delay( 166 | "myproject.apps.myapp.tasks.mytask", 167 | "important", 168 | 2, 169 | 3, 170 | ) 171 | if err != nil { 172 | t.Fatal(err) 173 | } 174 | } 175 | 176 | // The test finishes either when ctx times out or all the tasks finish. 
177 | ctx, cancel := context.WithTimeout(context.Background(), time.Second) 178 | t.Cleanup(cancel) 179 | 180 | var sum int32 181 | app.Register( 182 | "myproject.apps.myapp.tasks.mytask", 183 | "important", 184 | func(ctx context.Context, p *TaskParam) error { 185 | p.NameArgs("a", "b") 186 | atomic.AddInt32( 187 | &sum, 188 | int32(p.MustInt("a")+p.MustInt("b")), 189 | ) 190 | return nil 191 | }, 192 | ) 193 | if err := app.Run(ctx); err != nil { 194 | t.Error(err) 195 | } 196 | 197 | var want int32 = 500 198 | if want != sum { 199 | t.Errorf("expected sum %d got %d", want, sum) 200 | } 201 | } 202 | 203 | func TestGoredisProduceAndConsume100times(t *testing.T) { 204 | app := NewApp( 205 | WithBroker(goredis.NewBroker()), 206 | WithLogger(log.NewJSONLogger(os.Stderr)), 207 | ) 208 | for i := 0; i < 100; i++ { 209 | err := app.Delay( 210 | "myproject.apps.myapp.tasks.mytask", 211 | "important", 212 | 2, 213 | 3, 214 | ) 215 | if err != nil { 216 | t.Fatal(err) 217 | } 218 | } 219 | 220 | // The test finishes either when ctx times out or all the tasks finish. 
221 | ctx, cancel := context.WithTimeout(context.Background(), time.Second) 222 | t.Cleanup(cancel) 223 | 224 | var sum int32 225 | app.Register( 226 | "myproject.apps.myapp.tasks.mytask", 227 | "important", 228 | func(ctx context.Context, p *TaskParam) error { 229 | p.NameArgs("a", "b") 230 | atomic.AddInt32( 231 | &sum, 232 | int32(p.MustInt("a")+p.MustInt("b")), 233 | ) 234 | return nil 235 | }, 236 | ) 237 | if err := app.Run(ctx); err != nil { 238 | t.Error(err) 239 | } 240 | 241 | var want int32 = 500 242 | if want != sum { 243 | t.Errorf("expected sum %d got %d", want, sum) 244 | } 245 | } 246 | 247 | func TestRabbitmqProduceAndConsume100times(t *testing.T) { 248 | br, err := rabbitmq.NewBroker() 249 | if err != nil { 250 | t.Fatal(err) 251 | } 252 | 253 | app := NewApp( 254 | WithBroker(br), 255 | WithLogger(log.NewJSONLogger(os.Stderr)), 256 | ) 257 | 258 | queue := "rabbitmq_broker_test" 259 | 260 | // Create the queue, if it doesn't exist. 261 | app.conf.broker.Observe([]string{queue}) 262 | 263 | for i := 0; i < 100; i++ { 264 | err := app.Delay( 265 | "myproject.apps.myapp.tasks.mytask", 266 | queue, 267 | 2, 268 | 3, 269 | ) 270 | if err != nil { 271 | t.Fatal(err) 272 | } 273 | } 274 | 275 | // The test finishes either when ctx times out or all the tasks finish. 
276 | ctx, cancel := context.WithTimeout(context.Background(), time.Second) 277 | t.Cleanup(cancel) 278 | 279 | var sum int32 280 | app.Register( 281 | "myproject.apps.myapp.tasks.mytask", 282 | queue, 283 | func(ctx context.Context, p *TaskParam) error { 284 | p.NameArgs("a", "b") 285 | atomic.AddInt32( 286 | &sum, 287 | int32(p.MustInt("a")+p.MustInt("b")), 288 | ) 289 | return nil 290 | }, 291 | ) 292 | if err := app.Run(ctx); err != nil { 293 | t.Error(err) 294 | } 295 | 296 | var want int32 = 500 297 | if want != sum { 298 | t.Errorf("expected sum %d got %d", want, sum) 299 | } 300 | 301 | } 302 | 303 | func TestConsumeSequentially(t *testing.T) { 304 | app := NewApp( 305 | WithLogger(log.NewJSONLogger(os.Stderr)), 306 | WithMaxWorkers(1), 307 | ) 308 | if err := app.Delay("t1", "q"); err != nil { 309 | t.Fatal(err) 310 | } 311 | if err := app.Delay("t2", "q"); err != nil { 312 | t.Fatal(err) 313 | } 314 | 315 | // The test finishes either when ctx times out or all the tasks finish. 
316 | ctx, cancel := context.WithTimeout(context.Background(), time.Second) 317 | t.Cleanup(cancel) 318 | 319 | var t1Done, t2Done atomic.Bool 320 | app.Register("t1", "q", func(ctx context.Context, p *TaskParam) error { 321 | time.Sleep(100 * time.Millisecond) 322 | 323 | if t2Done.Load() { 324 | t.Error("t2 finished before t1") 325 | } 326 | t1Done.Store(true) 327 | 328 | return nil 329 | }) 330 | app.Register("t2", "q", func(ctx context.Context, p *TaskParam) error { 331 | if !t1Done.Load() { 332 | t.Error("t2 started before t1 finished") 333 | } 334 | t2Done.Store(true) 335 | 336 | return nil 337 | }) 338 | if err := app.Run(ctx); err != nil { 339 | t.Error(err) 340 | } 341 | } 342 | -------------------------------------------------------------------------------- /bench-old.txt: -------------------------------------------------------------------------------- 1 | goos: darwin 2 | goarch: amd64 3 | pkg: github.com/marselester/gopher-celery/internal/protocol 4 | cpu: Intel(R) Core(TM) i5-10600 CPU @ 3.30GHz 5 | BenchmarkJSONSerializerEncode_v2NoParams-12 405000379 2.960 ns/op 0 B/op 0 allocs/op 6 | BenchmarkJSONSerializerEncode_v2NoParams-12 403379344 2.970 ns/op 0 B/op 0 allocs/op 7 | BenchmarkJSONSerializerEncode_v2NoParams-12 403119211 2.980 ns/op 0 B/op 0 allocs/op 8 | BenchmarkJSONSerializerEncode_v2NoParams-12 403116849 2.967 ns/op 0 B/op 0 allocs/op 9 | BenchmarkJSONSerializerEncode_v2NoParams-12 403092967 2.988 ns/op 0 B/op 0 allocs/op 10 | BenchmarkJSONSerializerEncode_v2NoParams-12 396784292 2.973 ns/op 0 B/op 0 allocs/op 11 | BenchmarkJSONSerializerEncode_v2NoParams-12 404032770 2.962 ns/op 0 B/op 0 allocs/op 12 | BenchmarkJSONSerializerEncode_v2NoParams-12 399675436 2.964 ns/op 0 B/op 0 allocs/op 13 | BenchmarkJSONSerializerEncode_v2NoParams-12 399877645 2.985 ns/op 0 B/op 0 allocs/op 14 | BenchmarkJSONSerializerEncode_v2NoParams-12 402808989 2.965 ns/op 0 B/op 0 allocs/op 15 | BenchmarkJSONSerializerEncode_v2Args-12 3420657 351.1 ns/op 248 B/op 3 
allocs/op 16 | BenchmarkJSONSerializerEncode_v2Args-12 3384900 350.4 ns/op 248 B/op 3 allocs/op 17 | BenchmarkJSONSerializerEncode_v2Args-12 3437300 349.7 ns/op 248 B/op 3 allocs/op 18 | BenchmarkJSONSerializerEncode_v2Args-12 3422995 349.4 ns/op 248 B/op 3 allocs/op 19 | BenchmarkJSONSerializerEncode_v2Args-12 3424119 350.1 ns/op 248 B/op 3 allocs/op 20 | BenchmarkJSONSerializerEncode_v2Args-12 3433743 351.0 ns/op 248 B/op 3 allocs/op 21 | BenchmarkJSONSerializerEncode_v2Args-12 3383856 350.0 ns/op 248 B/op 3 allocs/op 22 | BenchmarkJSONSerializerEncode_v2Args-12 3403886 351.0 ns/op 248 B/op 3 allocs/op 23 | BenchmarkJSONSerializerEncode_v2Args-12 3436648 351.7 ns/op 248 B/op 3 allocs/op 24 | BenchmarkJSONSerializerEncode_v2Args-12 3411801 350.1 ns/op 248 B/op 3 allocs/op 25 | BenchmarkJSONSerializerEncode_v2Kwargs-12 2062453 581.1 ns/op 472 B/op 7 allocs/op 26 | BenchmarkJSONSerializerEncode_v2Kwargs-12 2057394 580.3 ns/op 472 B/op 7 allocs/op 27 | BenchmarkJSONSerializerEncode_v2Kwargs-12 2050737 581.9 ns/op 472 B/op 7 allocs/op 28 | BenchmarkJSONSerializerEncode_v2Kwargs-12 1906161 584.2 ns/op 472 B/op 7 allocs/op 29 | BenchmarkJSONSerializerEncode_v2Kwargs-12 2050518 581.9 ns/op 472 B/op 7 allocs/op 30 | BenchmarkJSONSerializerEncode_v2Kwargs-12 2048047 582.1 ns/op 472 B/op 7 allocs/op 31 | BenchmarkJSONSerializerEncode_v2Kwargs-12 2061298 581.8 ns/op 472 B/op 7 allocs/op 32 | BenchmarkJSONSerializerEncode_v2Kwargs-12 2063181 583.8 ns/op 472 B/op 7 allocs/op 33 | BenchmarkJSONSerializerEncode_v2Kwargs-12 2057473 584.6 ns/op 472 B/op 7 allocs/op 34 | BenchmarkJSONSerializerEncode_v2Kwargs-12 2043186 582.4 ns/op 472 B/op 7 allocs/op 35 | BenchmarkJSONSerializerEncode_v2ArgsKwargs-12 1522341 787.0 ns/op 528 B/op 8 allocs/op 36 | BenchmarkJSONSerializerEncode_v2ArgsKwargs-12 1523054 792.9 ns/op 528 B/op 8 allocs/op 37 | BenchmarkJSONSerializerEncode_v2ArgsKwargs-12 1447650 811.3 ns/op 528 B/op 8 allocs/op 38 | BenchmarkJSONSerializerEncode_v2ArgsKwargs-12 1522234 
786.9 ns/op 528 B/op 8 allocs/op 39 | BenchmarkJSONSerializerEncode_v2ArgsKwargs-12 1526050 786.4 ns/op 528 B/op 8 allocs/op 40 | BenchmarkJSONSerializerEncode_v2ArgsKwargs-12 1522459 786.8 ns/op 528 B/op 8 allocs/op 41 | BenchmarkJSONSerializerEncode_v2ArgsKwargs-12 1519941 786.8 ns/op 528 B/op 8 allocs/op 42 | BenchmarkJSONSerializerEncode_v2ArgsKwargs-12 1523570 817.0 ns/op 528 B/op 8 allocs/op 43 | BenchmarkJSONSerializerEncode_v2ArgsKwargs-12 1526043 787.6 ns/op 528 B/op 8 allocs/op 44 | BenchmarkJSONSerializerEncode_v2ArgsKwargs-12 1517802 788.9 ns/op 528 B/op 8 allocs/op 45 | BenchmarkJSONSerializerEncode_v1NoParams-12 1000000 1108 ns/op 672 B/op 4 allocs/op 46 | BenchmarkJSONSerializerEncode_v1NoParams-12 1000000 1111 ns/op 672 B/op 4 allocs/op 47 | BenchmarkJSONSerializerEncode_v1NoParams-12 1000000 1108 ns/op 672 B/op 4 allocs/op 48 | BenchmarkJSONSerializerEncode_v1NoParams-12 1000000 1117 ns/op 672 B/op 4 allocs/op 49 | BenchmarkJSONSerializerEncode_v1NoParams-12 1000000 1111 ns/op 672 B/op 4 allocs/op 50 | BenchmarkJSONSerializerEncode_v1NoParams-12 1000000 1115 ns/op 672 B/op 4 allocs/op 51 | BenchmarkJSONSerializerEncode_v1NoParams-12 1000000 1120 ns/op 672 B/op 4 allocs/op 52 | BenchmarkJSONSerializerEncode_v1NoParams-12 1000000 1120 ns/op 673 B/op 4 allocs/op 53 | BenchmarkJSONSerializerEncode_v1NoParams-12 1000000 1125 ns/op 672 B/op 4 allocs/op 54 | BenchmarkJSONSerializerEncode_v1NoParams-12 993504 1131 ns/op 673 B/op 4 allocs/op 55 | BenchmarkJSONSerializerEncode_v1Args-12 984577 1221 ns/op 672 B/op 4 allocs/op 56 | BenchmarkJSONSerializerEncode_v1Args-12 962990 1205 ns/op 672 B/op 4 allocs/op 57 | BenchmarkJSONSerializerEncode_v1Args-12 985591 1204 ns/op 672 B/op 4 allocs/op 58 | BenchmarkJSONSerializerEncode_v1Args-12 968157 1202 ns/op 672 B/op 4 allocs/op 59 | BenchmarkJSONSerializerEncode_v1Args-12 991500 1209 ns/op 672 B/op 4 allocs/op 60 | BenchmarkJSONSerializerEncode_v1Args-12 984794 1206 ns/op 672 B/op 4 allocs/op 61 | 
BenchmarkJSONSerializerEncode_v1Args-12 991324 1207 ns/op 672 B/op 4 allocs/op 62 | BenchmarkJSONSerializerEncode_v1Args-12 991214 1244 ns/op 672 B/op 4 allocs/op 63 | BenchmarkJSONSerializerEncode_v1Args-12 984279 1205 ns/op 672 B/op 4 allocs/op 64 | BenchmarkJSONSerializerEncode_v1Args-12 964975 1205 ns/op 673 B/op 4 allocs/op 65 | BenchmarkJSONSerializerEncode_v1Kwargs-12 695530 1677 ns/op 1001 B/op 10 allocs/op 66 | BenchmarkJSONSerializerEncode_v1Kwargs-12 694477 1688 ns/op 1001 B/op 10 allocs/op 67 | BenchmarkJSONSerializerEncode_v1Kwargs-12 702664 1687 ns/op 1001 B/op 10 allocs/op 68 | BenchmarkJSONSerializerEncode_v1Kwargs-12 700444 1682 ns/op 1001 B/op 10 allocs/op 69 | BenchmarkJSONSerializerEncode_v1Kwargs-12 692666 1682 ns/op 1001 B/op 10 allocs/op 70 | BenchmarkJSONSerializerEncode_v1Kwargs-12 705742 1685 ns/op 1001 B/op 10 allocs/op 71 | BenchmarkJSONSerializerEncode_v1Kwargs-12 702127 1679 ns/op 1001 B/op 10 allocs/op 72 | BenchmarkJSONSerializerEncode_v1Kwargs-12 693442 1680 ns/op 1001 B/op 10 allocs/op 73 | BenchmarkJSONSerializerEncode_v1Kwargs-12 675585 1680 ns/op 1001 B/op 10 allocs/op 74 | BenchmarkJSONSerializerEncode_v1Kwargs-12 699783 1687 ns/op 1001 B/op 10 allocs/op 75 | BenchmarkJSONSerializerEncode_v1ArgsKwargs-12 668502 1775 ns/op 1001 B/op 10 allocs/op 76 | BenchmarkJSONSerializerEncode_v1ArgsKwargs-12 639879 1775 ns/op 1001 B/op 10 allocs/op 77 | BenchmarkJSONSerializerEncode_v1ArgsKwargs-12 671182 1776 ns/op 1001 B/op 10 allocs/op 78 | BenchmarkJSONSerializerEncode_v1ArgsKwargs-12 668834 1769 ns/op 1001 B/op 10 allocs/op 79 | BenchmarkJSONSerializerEncode_v1ArgsKwargs-12 671733 1770 ns/op 1001 B/op 10 allocs/op 80 | BenchmarkJSONSerializerEncode_v1ArgsKwargs-12 659529 1773 ns/op 1001 B/op 10 allocs/op 81 | BenchmarkJSONSerializerEncode_v1ArgsKwargs-12 671095 1770 ns/op 1001 B/op 10 allocs/op 82 | BenchmarkJSONSerializerEncode_v1ArgsKwargs-12 646405 1768 ns/op 1001 B/op 10 allocs/op 83 | BenchmarkJSONSerializerEncode_v1ArgsKwargs-12 
660436 1771 ns/op 1001 B/op 10 allocs/op 84 | BenchmarkJSONSerializerEncode_v1ArgsKwargs-12 668602 1773 ns/op 1001 B/op 10 allocs/op 85 | PASS 86 | ok github.com/marselester/gopher-celery/internal/protocol 115.912s 87 | -------------------------------------------------------------------------------- /protocol/serializer.go: -------------------------------------------------------------------------------- 1 | // Package protocol provides means to encode/decode task messages 2 | // as described in https://github.com/celery/celery/blob/master/docs/internals/protocol.rst. 3 | package protocol 4 | 5 | import ( 6 | "encoding/json" 7 | "fmt" 8 | "os" 9 | "time" 10 | 11 | "github.com/google/uuid" 12 | ) 13 | 14 | // Task represents a task message that provides essential params to run a task. 15 | type Task struct { 16 | // ID id a unique id of the task in UUID v4 format (required). 17 | ID string 18 | // Name is a name of the task (required). 19 | Name string 20 | // Args is a list of arguments. 21 | // It will be an empty list if not provided. 22 | Args []interface{} 23 | // Kwargs is a dictionary of keyword arguments. 24 | // It will be an empty dictionary if not provided. 25 | Kwargs map[string]interface{} 26 | // Expires is an expiration date in ISO 8601 format. 27 | // If not provided the message will never expire. 28 | // The message will be expired when the message is received and the expiration date has been exceeded. 29 | Expires time.Time 30 | } 31 | 32 | // IsExpired returns true if the message is expired 33 | // and shouldn't be processed. 34 | func (t *Task) IsExpired() bool { 35 | return !t.Expires.IsZero() && t.Expires.Before(time.Now()) 36 | } 37 | 38 | // The mime-type describing the serializers. 39 | const ( 40 | MimeJSON = "application/json" 41 | ) 42 | 43 | // Supported protocol versions. 44 | const ( 45 | V1 = 1 46 | V2 = 2 47 | ) 48 | 49 | // Serializer encodes/decodes Celery tasks (message's body param to be precise). 
50 | // See https://docs.celeryq.dev/projects/kombu/en/latest/userguide/serialization.html. 51 | type Serializer interface { 52 | // Decode decodes the message body s into task t 53 | // using protocol p which could be version 1 or 2. 54 | Decode(p int, s string, t *Task) error 55 | // Encode encodes task t using protocol p and returns the message body s. 56 | Encode(p int, t *Task) (s string, err error) 57 | } 58 | 59 | // NewSerializerRegistry creates a registry of serializers. 60 | func NewSerializerRegistry() *SerializerRegistry { 61 | js := NewJSONSerializer() 62 | r := SerializerRegistry{ 63 | serializers: make(map[string]Serializer), 64 | encoding: make(map[string]string), 65 | uuid4: uuid.NewString, 66 | } 67 | r.Register(js, "json", "utf-8") 68 | r.Register(js, "application/json", "utf-8") 69 | 70 | var ( 71 | host string 72 | err error 73 | ) 74 | if host, err = os.Hostname(); err != nil { 75 | host = "unknown" 76 | } 77 | r.origin = fmt.Sprintf("%d@%s", os.Getpid(), host) 78 | 79 | return &r 80 | } 81 | 82 | // SerializerRegistry encodes/decodes task messages using registered serializers. 83 | // Celery relies on JSON format to store message metadata 84 | // such as content type and headers. 85 | // Task details (args, kwargs) are encoded in message body in base64 and JSON by default. 86 | // The encoding is indicated by body_encoding and content-type message params. 87 | // Therefore a client doesn't have to specify the formats since the registry can 88 | // pick an appropriate decoder based on the aforementioned params. 89 | type SerializerRegistry struct { 90 | // serializers helps to look up a serializer by a content-type, 91 | // see also https://github.com/celery/kombu/blob/master/kombu/serialization.py#L388. 92 | serializers map[string]Serializer 93 | // encoding maps content-type to its encoding, e.g., application/json uses utf-8 encoding. 94 | encoding map[string]string 95 | // uuid4 returns uuid v4, e.g., 0ad73c66-f4c9-4600-bd20-96746e720eed. 
96 | uuid4 func() string 97 | // origin is a pid@host used in encoding task messages. 98 | origin string 99 | } 100 | 101 | // Register registers a custom serializer where 102 | // mime is the mime-type describing the serialized structure, e.g., application/json, 103 | // and encoding is the content encoding which is usually utf-8 or binary. 104 | func (r *SerializerRegistry) Register(serializer Serializer, mime, encoding string) { 105 | r.serializers[mime] = serializer 106 | r.encoding[mime] = encoding 107 | } 108 | 109 | type inboundMessage struct { 110 | Body string `json:"body"` 111 | ContentType string `json:"content-type"` 112 | Header inboundMessageV2Header `json:"headers"` 113 | } 114 | type inboundMessageV2Header struct { 115 | ID string `json:"id"` 116 | Task string `json:"task"` 117 | Expires time.Time `json:"expires"` 118 | } 119 | 120 | // Decode decodes the raw message and returns a task info. 121 | // If the header doesn't contain a task name, then protocol v1 is assumed. 122 | // Otherwise the protocol v2 is used. 123 | func (r *SerializerRegistry) Decode(raw []byte) (*Task, error) { 124 | var m inboundMessage 125 | err := json.Unmarshal(raw, &m) 126 | if err != nil { 127 | return nil, fmt.Errorf("json decode: %w", err) 128 | } 129 | 130 | var ( 131 | prot int 132 | t Task 133 | ) 134 | // Protocol version is detected by the presence of a task message header. 
135 | if m.Header.Task == "" { 136 | prot = V1 137 | } else { 138 | prot = V2 139 | t.ID = m.Header.ID 140 | t.Name = m.Header.Task 141 | t.Expires = m.Header.Expires 142 | } 143 | 144 | ser := r.serializers[m.ContentType] 145 | if ser == nil { 146 | return nil, fmt.Errorf("unregistered serializer: %s", m.ContentType) 147 | } 148 | if err = ser.Decode(prot, m.Body, &t); err != nil { 149 | return nil, fmt.Errorf("parsing body v%d: %w", prot, err) 150 | } 151 | if t.Name == "" { 152 | return nil, fmt.Errorf("missing task name") 153 | } 154 | 155 | return &t, err 156 | } 157 | 158 | // Encode encodes the task message. 159 | func (r *SerializerRegistry) Encode(queue, mime string, prot int, t *Task) ([]byte, error) { 160 | if prot != V1 && prot != V2 { 161 | return nil, fmt.Errorf("unknown protocol version %d", prot) 162 | } 163 | 164 | ser := r.serializers[mime] 165 | if ser == nil { 166 | return nil, fmt.Errorf("unregistered serializer %s", mime) 167 | } 168 | if r.encoding[mime] == "" { 169 | return nil, fmt.Errorf("unregistered serializer encoding %s", mime) 170 | } 171 | 172 | body, err := ser.Encode(prot, t) 173 | if err != nil { 174 | return nil, fmt.Errorf("%s encode %d: %w", mime, prot, err) 175 | } 176 | 177 | if prot == V1 { 178 | return r.encodeV1(body, queue, mime, t) 179 | } else { 180 | return r.encodeV2(body, queue, mime, t) 181 | } 182 | } 183 | 184 | // jsonEmptyMap helps to reduce allocs when encoding empty maps in json. 
185 | var jsonEmptyMap = json.RawMessage("{}") 186 | 187 | type outboundMessageV1 struct { 188 | Body string `json:"body"` 189 | ContentEncoding string `json:"content-encoding"` 190 | ContentType string `json:"content-type"` 191 | Header json.RawMessage `json:"headers"` 192 | Property outboundMessageProperty `json:"properties"` 193 | } 194 | 195 | type outboundMessageProperty struct { 196 | DeliveryInfo outboundMessageDeliveryInfo `json:"delivery_info"` 197 | CorrelationID string `json:"correlation_id"` 198 | ReplyTo string `json:"reply_to"` 199 | BodyEncoding string `json:"body_encoding"` 200 | DeliveryTag string `json:"delivery_tag"` 201 | DeliveryMode int `json:"delivery_mode"` 202 | // Priority is a number between 0 and 255, where 255 is the highest priority in RabbitMQ 203 | // and 0 is the highest in Redis. 204 | Priority int `json:"priority"` 205 | } 206 | type outboundMessageDeliveryInfo struct { 207 | Exchange string `json:"exchange"` 208 | RoutingKey string `json:"routing_key"` 209 | } 210 | 211 | func (r *SerializerRegistry) encodeV1(body, queue, mime string, t *Task) ([]byte, error) { 212 | m := outboundMessageV1{ 213 | Body: body, 214 | ContentEncoding: r.encoding[mime], 215 | ContentType: mime, 216 | Header: jsonEmptyMap, 217 | Property: outboundMessageProperty{ 218 | BodyEncoding: "base64", 219 | CorrelationID: t.ID, 220 | ReplyTo: r.uuid4(), 221 | DeliveryInfo: outboundMessageDeliveryInfo{ 222 | Exchange: queue, 223 | RoutingKey: queue, 224 | }, 225 | DeliveryMode: 2, 226 | DeliveryTag: r.uuid4(), 227 | }, 228 | } 229 | 230 | return json.Marshal(&m) 231 | } 232 | 233 | type outboundMessageV2 struct { 234 | Body string `json:"body"` 235 | ContentEncoding string `json:"content-encoding"` 236 | ContentType string `json:"content-type"` 237 | Header outboundMessageV2Header `json:"headers"` 238 | Property outboundMessageProperty `json:"properties"` 239 | } 240 | type outboundMessageV2Header struct { 241 | // Lang enables support for multiple languages. 
242 | // Worker may redirect the message to a worker that supports the language. 243 | Lang string `json:"lang"` 244 | ID string `json:"id"` 245 | // RootID helps to keep track of workflows. 246 | RootID string `json:"root_id"` 247 | Task string `json:"task"` 248 | // Origin is the name of the node sending the task, 249 | // '@'.join([os.getpid(), socket.gethostname()]). 250 | Origin string `json:"origin"` 251 | Expires *string `json:"expires"` 252 | // Retries is a current number of times this task has been retried. 253 | // It's always set to zero. 254 | Retries int `json:"retries"` 255 | } 256 | 257 | func (r *SerializerRegistry) encodeV2(body, queue, mime string, t *Task) ([]byte, error) { 258 | m := outboundMessageV2{ 259 | Body: body, 260 | ContentEncoding: r.encoding[mime], 261 | ContentType: mime, 262 | Header: outboundMessageV2Header{ 263 | Lang: "go", 264 | ID: t.ID, 265 | RootID: t.ID, 266 | Task: t.Name, 267 | Origin: r.origin, 268 | }, 269 | Property: outboundMessageProperty{ 270 | BodyEncoding: "base64", 271 | CorrelationID: t.ID, 272 | ReplyTo: r.uuid4(), 273 | DeliveryInfo: outboundMessageDeliveryInfo{ 274 | Exchange: queue, 275 | RoutingKey: queue, 276 | }, 277 | DeliveryMode: 2, 278 | DeliveryTag: r.uuid4(), 279 | }, 280 | } 281 | if !t.Expires.IsZero() { 282 | s := t.Expires.Format(time.RFC3339) 283 | m.Header.Expires = &s 284 | } 285 | 286 | return json.Marshal(&m) 287 | } 288 | -------------------------------------------------------------------------------- /celery.go: -------------------------------------------------------------------------------- 1 | // Package celery helps to work with Celery (place tasks in queues and execute them). 
2 | package celery 3 | 4 | import ( 5 | "context" 6 | "fmt" 7 | "runtime/debug" 8 | "slices" 9 | "time" 10 | 11 | "github.com/go-kit/log" 12 | "github.com/go-kit/log/level" 13 | "github.com/google/uuid" 14 | "golang.org/x/sync/errgroup" 15 | 16 | "github.com/marselester/gopher-celery/protocol" 17 | "github.com/marselester/gopher-celery/redis" 18 | ) 19 | 20 | // TaskF represents a Celery task implemented by the client. 21 | // The error doesn't affect anything, it's logged though. 22 | type TaskF func(ctx context.Context, p *TaskParam) error 23 | 24 | // Middleware is a chainable behavior modifier for tasks. 25 | // For example, a caller can collect task metrics. 26 | type Middleware func(next TaskF) TaskF 27 | 28 | // Broker is responsible for receiving and sending task messages. 29 | // For example, it knows how to read a message from a given queue in Redis. 30 | // The messages can be in defferent formats depending on Celery protocol version. 31 | type Broker interface { 32 | // Send puts a message to a queue. 33 | // Note, the method is safe to call concurrently. 34 | Send(msg []byte, queue string) error 35 | // Observe sets the queues from which the tasks should be received. 36 | // Note, the method is not concurrency safe. 37 | Observe(queues []string) error 38 | // Receive returns a raw message from one of the queues. 39 | // It blocks until there is a message available for consumption. 40 | // Note, the method is not concurrency safe. 41 | Receive() ([]byte, error) 42 | } 43 | 44 | // AsyncParam represents parameters for sending a task message. 45 | type AsyncParam struct { 46 | // Args is a list of arguments. 47 | // It will be an empty list if not provided. 48 | Args []interface{} 49 | // Kwargs is a dictionary of keyword arguments. 50 | // It will be an empty dictionary if not provided. 51 | Kwargs map[string]interface{} 52 | // Expires is an expiration date. 53 | // If not provided the message will never expire. 
54 | Expires time.Time 55 | } 56 | 57 | // NewApp creates a Celery app. 58 | // The default broker is Redis assumed to run on localhost. 59 | // When producing tasks the default message serializer is json and protocol is v2. 60 | func NewApp(options ...Option) *App { 61 | app := App{ 62 | conf: Config{ 63 | logger: log.NewNopLogger(), 64 | registry: protocol.NewSerializerRegistry(), 65 | mime: protocol.MimeJSON, 66 | protocol: protocol.V2, 67 | maxWorkers: DefaultMaxWorkers, 68 | }, 69 | task: make(map[string]TaskF), 70 | taskQueue: make(map[string]string), 71 | } 72 | 73 | for _, opt := range options { 74 | opt(&app.conf) 75 | } 76 | 77 | if app.conf.broker == nil { 78 | app.conf.broker = redis.NewBroker() 79 | } 80 | 81 | return &app 82 | } 83 | 84 | // App is a Celery app to produce or consume tasks asynchronously. 85 | type App struct { 86 | // conf represents app settings. 87 | conf Config 88 | 89 | // task maps a Celery task path to a task itself, e.g., 90 | // "myproject.apps.myapp.tasks.mytask": TaskF. 91 | task map[string]TaskF 92 | // taskQueue helps to determine which queue a task belongs to, e.g., 93 | // "myproject.apps.myapp.tasks.mytask": "important". 94 | taskQueue map[string]string 95 | } 96 | 97 | // Register associates the task with given Python path and queue. 98 | // For example, when "myproject.apps.myapp.tasks.mytask" 99 | // is seen in "important" queue, the TaskF task is executed. 100 | // 101 | // Note, the method is not concurrency safe. 102 | // The tasks mustn't be registered after the app starts processing tasks. 103 | func (a *App) Register(path, queue string, task TaskF) { 104 | a.task[path] = task 105 | a.taskQueue[path] = queue 106 | } 107 | 108 | // ApplyAsync sends a task message. 
109 | func (a *App) ApplyAsync(path, queue string, p *AsyncParam) error { 110 | m := protocol.Task{ 111 | ID: uuid.NewString(), 112 | Name: path, 113 | Args: p.Args, 114 | Kwargs: p.Kwargs, 115 | Expires: p.Expires, 116 | } 117 | rawMsg, err := a.conf.registry.Encode(queue, a.conf.mime, a.conf.protocol, &m) 118 | if err != nil { 119 | return fmt.Errorf("failed to encode task message: %w", err) 120 | } 121 | 122 | if err = a.conf.broker.Send(rawMsg, queue); err != nil { 123 | return fmt.Errorf("failed to send task message to broker: %w", err) 124 | } 125 | return nil 126 | } 127 | 128 | // Delay is a shortcut to send a task message, 129 | // i.e., it places the task associated with given Python path into queue. 130 | func (a *App) Delay(path, queue string, args ...interface{}) error { 131 | m := protocol.Task{ 132 | ID: uuid.NewString(), 133 | Name: path, 134 | Args: args, 135 | } 136 | rawMsg, err := a.conf.registry.Encode(queue, a.conf.mime, a.conf.protocol, &m) 137 | if err != nil { 138 | return fmt.Errorf("failed to encode task message: %w", err) 139 | } 140 | 141 | if err = a.conf.broker.Send(rawMsg, queue); err != nil { 142 | return fmt.Errorf("failed to send task message to broker: %w", err) 143 | } 144 | return nil 145 | } 146 | 147 | // Run launches the workers that process the tasks received from the broker. 148 | // The call is blocking until ctx is cancelled. 149 | // The caller mustn't register any new tasks at this point. 150 | func (a *App) Run(ctx context.Context) error { 151 | // Build list of all unique, non-empty queue names for registered tasks. 152 | qq := []string{} 153 | for _, v := range a.taskQueue { 154 | if v != "" && slices.Index(qq, v) < 0 { 155 | qq = append(qq, v) 156 | } 157 | } 158 | 159 | err := a.conf.broker.Observe(qq) 160 | if err != nil { 161 | return err 162 | } 163 | 164 | level.Debug(a.conf.logger).Log("msg", "observing queues", "queues", qq) 165 | 166 | // Tasks are processed concurrently only if there are multiple workers. 
167 | if a.conf.maxWorkers <= 1 { 168 | return a.syncRun(ctx) 169 | } 170 | 171 | g, ctx := errgroup.WithContext(ctx) 172 | // There will be at most maxWorkers goroutines processing tasks, and one fetching them. 173 | g.SetLimit(a.conf.maxWorkers + 1) 174 | 175 | msgs := make(chan *protocol.Task, 1) 176 | g.Go(func() error { 177 | defer close(msgs) 178 | 179 | // One goroutine fetching and decoding tasks from queues 180 | // shouldn't be a bottleneck since the worker goroutines 181 | // usually take seconds/minutes to complete. 182 | for { 183 | // Stop fetching tasks. 184 | if ctx.Err() != nil { 185 | return nil 186 | } 187 | 188 | rawMsg, err := a.conf.broker.Receive() 189 | if err != nil { 190 | return fmt.Errorf("failed to receive a raw task message: %w", err) 191 | } 192 | // No messages in the broker so far. 193 | if rawMsg == nil { 194 | continue 195 | } 196 | 197 | m, err := a.conf.registry.Decode(rawMsg) 198 | if err != nil { 199 | level.Error(a.conf.logger).Log("msg", "failed to decode task message", "rawmsg", rawMsg, "err", err) 200 | continue 201 | } 202 | 203 | msgs <- m 204 | } 205 | }) 206 | 207 | go func() { 208 | // Start a worker when there is a task. 209 | for m := range msgs { 210 | level.Debug(a.conf.logger).Log("msg", "task received", "name", m.Name) 211 | 212 | if a.task[m.Name] == nil { 213 | level.Debug(a.conf.logger).Log("msg", "unregistered task", "name", m.Name) 214 | continue 215 | } 216 | if m.IsExpired() { 217 | level.Debug(a.conf.logger).Log("msg", "task message expired", "name", m.Name) 218 | continue 219 | } 220 | 221 | // Stop processing tasks. 
222 | if ctx.Err() != nil { 223 | return 224 | } 225 | 226 | m := m 227 | g.Go(func() error { 228 | if err := a.executeTask(ctx, m); err != nil { 229 | level.Error(a.conf.logger).Log("msg", "task failed", "taskmsg", m, "err", err) 230 | } else { 231 | level.Debug(a.conf.logger).Log("msg", "task succeeded", "name", m.Name) 232 | } 233 | return nil 234 | }) 235 | } 236 | }() 237 | 238 | return g.Wait() 239 | } 240 | 241 | // syncRun processes tasks one by one. 242 | // Note, it doesn't fetch a new task until the current one is finished. 243 | func (a *App) syncRun(ctx context.Context) error { 244 | for { 245 | // Stop fetching and processing tasks. 246 | if ctx.Err() != nil { 247 | return nil 248 | } 249 | 250 | rawMsg, err := a.conf.broker.Receive() 251 | if err != nil { 252 | return fmt.Errorf("failed to receive a raw task message: %w", err) 253 | } 254 | // No messages in the broker so far. 255 | if rawMsg == nil { 256 | continue 257 | } 258 | 259 | m, err := a.conf.registry.Decode(rawMsg) 260 | if err != nil { 261 | level.Error(a.conf.logger).Log("msg", "failed to decode task message", "rawmsg", rawMsg, "err", err) 262 | continue 263 | } 264 | 265 | level.Debug(a.conf.logger).Log("msg", "task received", "name", m.Name) 266 | 267 | if a.task[m.Name] == nil { 268 | level.Debug(a.conf.logger).Log("msg", "unregistered task", "name", m.Name) 269 | continue 270 | } 271 | if m.IsExpired() { 272 | level.Debug(a.conf.logger).Log("msg", "task message expired", "name", m.Name) 273 | continue 274 | } 275 | 276 | if err = a.executeTask(ctx, m); err != nil { 277 | level.Error(a.conf.logger).Log("msg", "task failed", "taskmsg", m, "err", err) 278 | } else { 279 | level.Debug(a.conf.logger).Log("msg", "task succeeded", "name", m.Name) 280 | } 281 | } 282 | } 283 | 284 | type contextKey int 285 | 286 | const ( 287 | // ContextKeyTaskName is a context key to access task names. 288 | ContextKeyTaskName contextKey = iota 289 | // ContextKeyTaskID is a context key to access task IDs. 
290 | ContextKeyTaskID 291 | ) 292 | 293 | // executeTask calls the task function with args and kwargs from the message. 294 | // If the task panics, the stack trace is returned as an error. 295 | func (a *App) executeTask(ctx context.Context, m *protocol.Task) (err error) { 296 | defer func() { 297 | if r := recover(); r != nil { 298 | err = fmt.Errorf("unexpected task error: %v: %s", r, debug.Stack()) 299 | } 300 | }() 301 | 302 | task := a.task[m.Name] 303 | // Use middlewares if a client provided them. 304 | if a.conf.chain != nil { 305 | task = a.conf.chain(task) 306 | } 307 | 308 | ctx = context.WithValue(ctx, ContextKeyTaskName, m.Name) 309 | ctx = context.WithValue(ctx, ContextKeyTaskID, m.ID) 310 | p := NewTaskParam(m.Args, m.Kwargs) 311 | return task(ctx, p) 312 | } 313 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Gopher Celery 🥬 2 | 3 | [![Documentation](https://godoc.org/github.com/marselester/gopher-celery?status.svg)](https://pkg.go.dev/github.com/marselester/gopher-celery) 4 | [![Go Report Card](https://goreportcard.com/badge/github.com/marselester/gopher-celery)](https://goreportcard.com/report/github.com/marselester/gopher-celery) 5 | 6 | The objective of this project is to provide 7 | the very basic mechanism to efficiently produce and consume Celery tasks on Go side. 8 | Therefore there are no plans to support all the rich features the Python version provides, 9 | such as tasks chains, etc. 10 | Even task result backend has no practical value in the context of Gopher Celery, 11 | so it wasn't taken into account. 12 | Note, Celery has [no result backend](https://docs.celeryq.dev/en/stable/userguide/tasks.html?#result-backends) 13 | enabled by default (it incurs overhead). 
14 | 15 | Typically one would want to use Gopher Celery when certain tasks on the Python side 16 | take too long to complete or there is a big volume of tasks requiring lots of Python workers 17 | (expensive infrastructure). 18 | 19 | This project offers a slightly more convenient API than https://github.com/gocelery/gocelery, 20 | including support for Celery protocol v2. 21 | 22 | ## Usage 23 | 24 | The Celery app can be used as either a producer or a consumer (worker). 25 | To send tasks to a queue for a worker to consume, use the `Delay` method. 26 | To process a task, register it using the `Register` method. 27 | 28 | ```python 29 | def mytask(a, b): 30 | print(a + b) 31 | ``` 32 | 33 | For example, whenever a task `mytask` is popped from the `important` queue, 34 | the Go function is executed with args and kwargs obtained from the task message. 35 | By default, the Redis broker (localhost) is used with JSON task message serialization. 36 | There is also a RabbitMQ broker available. 37 | 38 | ```go 39 | app := celery.NewApp() 40 | app.Register( 41 | "myproject.apps.myapp.tasks.mytask", 42 | "important", 43 | func(ctx context.Context, p *celery.TaskParam) error { 44 | p.NameArgs("a", "b") 45 | // Methods prefixed with Must panic if they can't find an argument name 46 | // or can't cast it to the corresponding type. 47 | // The panic doesn't affect the execution of other tasks; it's logged. 48 | fmt.Println(p.MustInt("a") + p.MustInt("b")) 49 | // Non-nil errors are logged. 50 | return nil 51 | }, 52 | ) 53 | if err := app.Run(context.Background()); err != nil { 54 | log.Printf("celery worker error: %v", err) 55 | } 56 | ``` 57 | 58 | Here is an example of sending the `mytask` task to the `important` queue with `a=2`, `b=3` arguments. 59 | If the task is processed on the Python side, 60 | you don't need to register the task or run the app.
61 | 62 | ```go 63 | app := celery.NewApp() 64 | err := app.Delay( 65 | "myproject.apps.myapp.tasks.mytask", 66 | "important", 67 | 2, 68 | 3, 69 | ) 70 | if err != nil { 71 | log.Printf("failed to send mytask: %v", err) 72 | } 73 | ``` 74 | 75 | ### Redis Examples 76 | 77 | Redis examples can be found in [the redis examples](examples/redis) dir. 78 | Note: You'll need a Redis server to run the examples which use Redis. 79 | 80 | ```sh 81 | $ redis-server 82 | $ cd ./examples/redis 83 | ``` 84 | 85 |
86 | 87 | Sending tasks from Go and receiving them on the Python side. 88 | 89 | ```sh 90 | $ go run ./producer/ 91 | {"err":null,"msg":"task was sent using protocol v2"} 92 | {"err":null,"msg":"task was sent using protocol v1"} 93 | $ celery --app myproject worker --queues important --loglevel=debug --without-heartbeat --without-mingle 94 | ... 95 | [... WARNING/ForkPoolWorker-1] received a=fizz b=bazz 96 | [... WARNING/ForkPoolWorker-8] received a=fizz b=bazz 97 | ``` 98 | 99 |
100 | 101 |
102 | 103 | Sending tasks from Python and receiving them on the Go side. 104 | 105 | ```sh 106 | $ python producer.py 107 | $ go run ./consumer/ 108 | {"msg":"waiting for tasks..."} 109 | received a=fizz b=bazz 110 | ``` 111 | 112 | To send a task with Celery protocol version 1, run *producer.py* with the `--protocol=1` command-line argument. 113 | 114 | ```sh 115 | $ python producer.py --protocol=1 116 | ``` 117 | 118 |
119 | 120 |
121 | 122 | Most likely your Redis server won't be running on localhost when the service is deployed, 123 | so you would need to pass a connection pool to the broker. 124 | 125 | Redis connection pool. 126 | 127 | ```sh 128 | $ go run ./producer/ 129 | {"err":null,"msg":"task was sent using protocol v2"} 130 | {"err":null,"msg":"task was sent using protocol v1"} 131 | $ go run ./redis/ 132 | ``` 133 | 134 |
135 | 136 |
137 | 138 | Prometheus task metrics. 139 | 140 | ```sh 141 | $ go run ./producer/ 142 | $ go run ./metrics/ 143 | $ curl http://0.0.0.0:8080/metrics 144 | # HELP task_duration_seconds How long it took in seconds to process a task. 145 | # TYPE task_duration_seconds histogram 146 | task_duration_seconds_bucket{task="myproject.mytask",le="0.016"} 2 147 | task_duration_seconds_bucket{task="myproject.mytask",le="0.032"} 2 148 | task_duration_seconds_bucket{task="myproject.mytask",le="0.064"} 2 149 | task_duration_seconds_bucket{task="myproject.mytask",le="0.128"} 2 150 | task_duration_seconds_bucket{task="myproject.mytask",le="0.256"} 2 151 | task_duration_seconds_bucket{task="myproject.mytask",le="0.512"} 2 152 | task_duration_seconds_bucket{task="myproject.mytask",le="1.024"} 2 153 | task_duration_seconds_bucket{task="myproject.mytask",le="2.048"} 2 154 | task_duration_seconds_bucket{task="myproject.mytask",le="4.096"} 2 155 | task_duration_seconds_bucket{task="myproject.mytask",le="8.192"} 2 156 | task_duration_seconds_bucket{task="myproject.mytask",le="16.384"} 2 157 | task_duration_seconds_bucket{task="myproject.mytask",le="32.768"} 2 158 | task_duration_seconds_bucket{task="myproject.mytask",le="60"} 2 159 | task_duration_seconds_bucket{task="myproject.mytask",le="+Inf"} 2 160 | task_duration_seconds_sum{task="myproject.mytask"} 7.2802e-05 161 | task_duration_seconds_count{task="myproject.mytask"} 2 162 | # HELP tasks_total How many Celery tasks processed, partitioned by task name and error. 163 | # TYPE tasks_total counter 164 | tasks_total{error="false",task="myproject.mytask"} 2 165 | ``` 166 | 167 |
168 | 169 |
170 | 171 | Although there is no built-in support for task retries (publishing a task back to Redis), 172 | you can still retry the operation within the same goroutine. 173 | 174 | Task retries. 175 | 176 | ```sh 177 | $ go run ./retry/ 178 | ... 179 | {"attempt":1,"err":"uh oh","msg":"request failed","ts":"2022-08-07T23:42:23.401191Z"} 180 | {"attempt":2,"err":"uh oh","msg":"request failed","ts":"2022-08-07T23:42:28.337204Z"} 181 | {"attempt":3,"err":"uh oh","msg":"request failed","ts":"2022-08-07T23:42:37.279873Z"} 182 | ``` 183 | 184 |
185 | 186 | ### RabbitMQ Examples 187 | 188 | RabbitMQ examples can be found in [the rabbitmq examples](examples/rabbitmq) dir. 189 | Note: You'll need a RabbitMQ server to run the examples which use RabbitMQ. 190 | 191 | ```sh 192 | $ rabbitmq-server 193 | $ cd ./examples/rabbitmq 194 | ``` 195 | 196 |
197 | 198 | Sending tasks from Go and receiving them on the Python side. 199 | 200 | ```sh 201 | $ go run ./producer/ 202 | {"err":null,"msg":"task was sent using protocol v2"} 203 | {"err":null,"msg":"task was sent using protocol v1"} 204 | $ celery --app myproject worker --queues important --loglevel=debug --without-heartbeat --without-mingle 205 | ... 206 | [... WARNING/ForkPoolWorker-1] received a=fizz b=bazz 207 | [... WARNING/ForkPoolWorker-8] received a=fizz b=bazz 208 | ``` 209 | 210 |
211 | 212 |
213 | 214 | Sending tasks from Python and receiving them on Go side. 215 | 216 | ```sh 217 | $ python producer.py 218 | $ go run ./consumer/ 219 | {"msg":"waiting for tasks..."} 220 | received a=fizz b=bazz 221 | ``` 222 | 223 | To send a task with Celery Protocol version 1, run *producer.py* with the `--protocol=1` command-line argument. 224 | 225 | ```sh 226 | $ python producer.py --protocol=1 227 | ``` 228 | 229 |
230 | 231 | ## Testing 232 | 233 | Tests require both a Redis and a RabbitMQ server running locally. 234 | 235 | ```sh 236 | $ go test -v -count=1 ./... 237 | ``` 238 | 239 | Benchmarks help to spot performance changes as the project evolves 240 | and also to compare the performance of serializers. 241 | For example, based on the results below, protocol v2 is faster than v1 when encoding args: 242 | 243 | - v2: 350 nanoseconds mean time, 3 allocations (248 bytes) with 0% variation across the samples 244 | - v1: 1.21 microseconds mean time, 4 allocations (672 bytes) with 0% variation across the samples 245 | 246 | It is recommended to run benchmarks multiple times and check 247 | how stable they are using the [Benchstat](https://pkg.go.dev/golang.org/x/perf/cmd/benchstat) tool. 248 | 249 | ```sh 250 | $ go test -bench=. -benchmem -count=10 ./protocol/... | tee bench-new.txt 251 | ``` 252 | 253 |
254 | 255 | 256 | 257 | ```sh 258 | $ benchstat bench-old.txt 259 | ``` 260 | 261 | 262 | 263 | ``` 264 | name time/op 265 | JSONSerializerEncode_v2NoParams-12 2.97ns ± 1% 266 | JSONSerializerEncode_v2Args-12 350ns ± 0% 267 | JSONSerializerEncode_v2Kwargs-12 582ns ± 0% 268 | JSONSerializerEncode_v2ArgsKwargs-12 788ns ± 1% 269 | JSONSerializerEncode_v1NoParams-12 1.12µs ± 1% 270 | JSONSerializerEncode_v1Args-12 1.21µs ± 0% 271 | JSONSerializerEncode_v1Kwargs-12 1.68µs ± 0% 272 | JSONSerializerEncode_v1ArgsKwargs-12 1.77µs ± 0% 273 | 274 | name alloc/op 275 | JSONSerializerEncode_v2NoParams-12 0.00B 276 | JSONSerializerEncode_v2Args-12 248B ± 0% 277 | JSONSerializerEncode_v2Kwargs-12 472B ± 0% 278 | JSONSerializerEncode_v2ArgsKwargs-12 528B ± 0% 279 | JSONSerializerEncode_v1NoParams-12 672B ± 0% 280 | JSONSerializerEncode_v1Args-12 672B ± 0% 281 | JSONSerializerEncode_v1Kwargs-12 1.00kB ± 0% 282 | JSONSerializerEncode_v1ArgsKwargs-12 1.00kB ± 0% 283 | 284 | name allocs/op 285 | JSONSerializerEncode_v2NoParams-12 0.00 286 | JSONSerializerEncode_v2Args-12 3.00 ± 0% 287 | JSONSerializerEncode_v2Kwargs-12 7.00 ± 0% 288 | JSONSerializerEncode_v2ArgsKwargs-12 8.00 ± 0% 289 | JSONSerializerEncode_v1NoParams-12 4.00 ± 0% 290 | JSONSerializerEncode_v1Args-12 4.00 ± 0% 291 | JSONSerializerEncode_v1Kwargs-12 10.0 ± 0% 292 | JSONSerializerEncode_v1ArgsKwargs-12 10.0 ± 0% 293 | ``` 294 | 295 |
296 | 297 | The old and new stats are compared as follows. 298 | 299 | ```sh 300 | $ benchstat bench-old.txt bench-new.txt 301 | ``` 302 | --------------------------------------------------------------------------------