├── .deepsource.toml ├── .gitignore ├── .travis.yml ├── Dockerfile.gcppubsub ├── Dockerfile.test ├── LICENSE ├── Makefile ├── README.md ├── docker-compose.test.yml ├── example ├── amqp │ └── main.go ├── redis │ └── main.go ├── tasks │ └── tasks.go └── tracers │ └── jaeger.go ├── go.mod ├── go.sum ├── instruction-notes └── dynamodb.md ├── integration-tests ├── amqp_amqp_test.go ├── amqp_get_pending_tasks_test.go ├── amqp_memcache_test.go ├── amqp_mongodb_test.go ├── amqp_redis_test.go ├── eager_eager_test.go ├── gcppubsub_redis_test.go ├── redis_get_pending_tasks_test.go ├── redis_memcache_test.go ├── redis_mongodb_test.go ├── redis_redis_test.go ├── redis_socket_test.go ├── sqs_amqp_test.go ├── sqs_mongodb_test.go ├── suite_test.go └── worker_only_consumes_registered_tasks_test.go ├── v1 ├── backends │ ├── amqp │ │ ├── amqp.go │ │ └── amqp_test.go │ ├── dynamodb │ │ ├── dynamodb.go │ │ ├── dynamodb_export_test.go │ │ └── dynamodb_test.go │ ├── eager │ │ ├── eager.go │ │ └── eager_test.go │ ├── iface │ │ └── interfaces.go │ ├── memcache │ │ ├── memcache.go │ │ └── memcache_test.go │ ├── mongo │ │ ├── mongodb.go │ │ └── mongodb_test.go │ ├── null │ │ └── null.go │ ├── package.go │ ├── redis │ │ ├── goredis.go │ │ ├── goredis_test.go │ │ ├── redis.go │ │ └── redis_test.go │ └── result │ │ └── async_result.go ├── brokers │ ├── amqp │ │ ├── amqp.go │ │ ├── amqp_concurrence_test.go │ │ └── amqp_test.go │ ├── eager │ │ └── eager.go │ ├── errs │ │ └── errors.go │ ├── gcppubsub │ │ └── gcp_pubsub.go │ ├── iface │ │ └── interfaces.go │ ├── package.go │ ├── redis │ │ ├── goredis.go │ │ └── redis.go │ └── sqs │ │ ├── sqs.go │ │ ├── sqs_export_test.go │ │ └── sqs_test.go ├── common │ ├── amqp.go │ ├── backend.go │ ├── broker.go │ ├── broker_test.go │ └── redis.go ├── config │ ├── config.go │ ├── env.go │ ├── env_test.go │ ├── file.go │ ├── file_test.go │ ├── test.env │ └── testconfig.yml ├── factories.go ├── factories_test.go ├── locks │ ├── eager │ │ ├── eager.go │ │ └── 
eager_test.go │ ├── iface │ │ └── interfaces.go │ └── redis │ │ └── redis.go ├── log │ ├── log.go │ └── log_test.go ├── package.go ├── retry │ ├── fibonacci.go │ ├── fibonacci_test.go │ └── retry.go ├── server.go ├── server_test.go ├── tasks │ ├── errors.go │ ├── reflect.go │ ├── reflect_test.go │ ├── result.go │ ├── result_test.go │ ├── signature.go │ ├── state.go │ ├── state_test.go │ ├── task.go │ ├── task_test.go │ ├── validate.go │ ├── validate_test.go │ ├── workflow.go │ └── workflow_test.go ├── tracing │ └── tracing.go ├── utils │ ├── deepcopy.go │ ├── deepcopy_test.go │ ├── utils.go │ ├── utils_test.go │ ├── uuid.go │ └── uuid_test.go ├── worker.go └── worker_test.go ├── v2 ├── backends │ ├── amqp │ │ ├── amqp.go │ │ └── amqp_test.go │ ├── dynamodb │ │ ├── dynamodb.go │ │ ├── dynamodb_export_test.go │ │ └── dynamodb_test.go │ ├── eager │ │ ├── eager.go │ │ └── eager_test.go │ ├── iface │ │ └── interfaces.go │ ├── memcache │ │ ├── memcache.go │ │ └── memcache_test.go │ ├── mongo │ │ ├── mongodb.go │ │ └── mongodb_test.go │ ├── null │ │ └── null.go │ ├── package.go │ ├── redis │ │ ├── goredis.go │ │ ├── goredis_test.go │ │ ├── redis.go │ │ └── redis_test.go │ └── result │ │ └── async_result.go ├── brokers │ ├── amqp │ │ ├── amqp.go │ │ ├── amqp_concurrence_test.go │ │ └── amqp_test.go │ ├── eager │ │ └── eager.go │ ├── errs │ │ └── errors.go │ ├── gcppubsub │ │ └── gcp_pubsub.go │ ├── iface │ │ └── interfaces.go │ ├── package.go │ ├── redis │ │ ├── goredis.go │ │ └── redis.go │ └── sqs │ │ ├── sqs.go │ │ ├── sqs_export_test.go │ │ └── sqs_test.go ├── common │ ├── amqp.go │ ├── backend.go │ ├── broker.go │ ├── broker_test.go │ └── redis.go ├── config │ ├── config.go │ ├── env.go │ ├── env_test.go │ ├── file.go │ ├── file_test.go │ ├── test.env │ └── testconfig.yml ├── example │ ├── amqp │ │ └── main.go │ ├── go-redis │ │ └── main.go │ ├── redigo │ │ └── main.go │ ├── tasks │ │ └── tasks.go │ └── tracers │ │ └── jaeger.go ├── go.mod ├── go.sum ├── 
integration-tests │ ├── amqp_amqp_test.go │ ├── redis_redis_test.go │ └── suite_test.go ├── locks │ ├── eager │ │ ├── eager.go │ │ └── eager_test.go │ ├── iface │ │ └── interfaces.go │ └── redis │ │ └── redis.go ├── log │ ├── log.go │ └── log_test.go ├── package.go ├── retry │ ├── fibonacci.go │ ├── fibonacci_test.go │ └── retry.go ├── server.go ├── server_test.go ├── tasks │ ├── errors.go │ ├── reflect.go │ ├── reflect_test.go │ ├── result.go │ ├── result_test.go │ ├── signature.go │ ├── state.go │ ├── state_test.go │ ├── task.go │ ├── task_test.go │ ├── validate.go │ ├── validate_test.go │ ├── workflow.go │ └── workflow_test.go ├── tracing │ └── tracing.go ├── utils │ ├── deepcopy.go │ ├── deepcopy_test.go │ ├── utils.go │ ├── utils_test.go │ ├── uuid.go │ └── uuid_test.go ├── worker.go └── worker_test.go └── wait-for-it.sh /.deepsource.toml: -------------------------------------------------------------------------------- 1 | version = 1 2 | 3 | test_patterns = ["**/*_test.go"] 4 | 5 | exclude_patterns = ["example/**"] 6 | 7 | [[analyzers]] 8 | name = "go" 9 | enabled = true 10 | 11 | [analyzers.meta] 12 | import_paths = ["github.com/RichardKnop/machinery"] -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | coverage* 2 | _vendor-* 3 | .idea/ 4 | .env 5 | .DS_Store 6 | dump.rdb 7 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | --- 2 | language: go 3 | 4 | go: 5 | - 1.13.x 6 | 7 | env: 8 | - GO111MODULE=on 9 | 10 | services: 11 | - docker 12 | 13 | script: 14 | - make ci 15 | 16 | after_success: 17 | - bash <(curl -s https://codecov.io/bash) 18 | -------------------------------------------------------------------------------- /Dockerfile.gcppubsub: 
-------------------------------------------------------------------------------- 1 | FROM google/cloud-sdk:216.0.0-alpine 2 | 3 | RUN apk --update add openjdk8-jre 4 | RUN gcloud components install --quiet beta pubsub-emulator 5 | RUN mkdir -p /var/pubsub 6 | 7 | EXPOSE 8085 8 | 9 | CMD [ "gcloud", "beta", "emulators", "pubsub", "start", "--host-port=0.0.0.0:8085"] 10 | -------------------------------------------------------------------------------- /Dockerfile.test: -------------------------------------------------------------------------------- 1 | # Start from a Debian image with the latest version of Go installed 2 | # and a workspace (GOPATH) configured at /go. 3 | FROM golang 4 | 5 | # Contact maintainer with any issues you encounter 6 | MAINTAINER Richard Knop 7 | 8 | # Set environment variables 9 | ENV PATH /go/bin:$PATH 10 | 11 | # Cd into the source code directory 12 | WORKDIR /go/src/github.com/RichardKnop/machinery 13 | 14 | # Copy the local package files to the container's workspace. 15 | ADD . /go/src/github.com/RichardKnop/machinery 16 | 17 | # Set GO111MODULE=on variable to activate module support 18 | ENV GO111MODULE on 19 | 20 | # Run integration tests as default command 21 | CMD /go/src/github.com/RichardKnop/machinery/wait-for-it.sh rabbitmq:5672 -- make test-with-coverage 22 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: fmt lint golint test test-with-coverage ci 2 | # Since Go 1.9, ./... no longer matches vendor/ packages; the vendor filter below is a harmless leftover. 3 | PACKAGES=`go list ./... | grep -v vendor | grep -v mocks` 4 | 5 | fmt: 6 | for pkg in ${PACKAGES}; do \ 7 | go fmt $$pkg; \ 8 | done; 9 | 10 | lint: 11 | gometalinter --tests --disable-all --deadline=120s -E vet -E gofmt -E misspell -E ineffassign -E goimports -E deadcode ./...
12 | 13 | golint: 14 | for pkg in ${PACKAGES}; do \ 15 | golint -set_exit_status $$pkg || GOLINT_FAILED=1; \ 16 | done; \ 17 | [ -z "$$GOLINT_FAILED" ] 18 | 19 | test: 20 | TEST_FAILED= ; \ 21 | for pkg in ${PACKAGES}; do \ 22 | go test $$pkg || TEST_FAILED=1; \ 23 | done; \ 24 | [ -z "$$TEST_FAILED" ] 25 | 26 | test-with-coverage: 27 | echo "" > coverage.out 28 | echo "mode: set" > coverage-all.out 29 | TEST_FAILED= ; \ 30 | for pkg in ${PACKAGES}; do \ 31 | go test -coverprofile=coverage.out -covermode=set $$pkg || TEST_FAILED=1; \ 32 | tail -n +2 coverage.out >> coverage-all.out; \ 33 | done; \ 34 | [ -z "$$TEST_FAILED" ] 35 | #go tool cover -html=coverage-all.out 36 | 37 | ci: 38 | bash -c 'docker-compose -f docker-compose.test.yml -p machinery_ci up --build --abort-on-container-exit --exit-code-from sut' 39 | -------------------------------------------------------------------------------- /docker-compose.test.yml: -------------------------------------------------------------------------------- 1 | version: "2" 2 | 3 | services: 4 | sut: 5 | container_name: machinery_sut 6 | image: machinery_sut:latest 7 | volumes: 8 | - "./:/go/src/github.com/RichardKnop/machinery" 9 | depends_on: 10 | - rabbitmq 11 | - redis 12 | - memcached 13 | - mongo 14 | - gcppubsub 15 | links: 16 | - rabbitmq 17 | - redis 18 | - memcached 19 | - mongo 20 | - gcppubsub 21 | build: 22 | context: . 
23 | dockerfile: ./Dockerfile.test 24 | environment: 25 | AMQP_URLS: 'amqp://guest:guest@dummy:5672/,amqp://guest:guest@rabbitmq:5672/' 26 | AMQP_URLS_SEPARATOR: ',' 27 | REDIS_URL: 'redis:6379' 28 | MEMCACHE_URL: 'memcached:11211' 29 | MONGODB_URL: 'mongo:27017' 30 | SQS_URL: ${SQS_URL} 31 | AWS_ACCESS_KEY_ID: ${AWS_ACCESS_KEY_ID} 32 | AWS_SECRET_ACCESS_KEY: ${AWS_SECRET_ACCESS_KEY} 33 | AWS_DEFAULT_REGION: ${AWS_DEFAULT_REGION} 34 | AWS_REGION: 'us-west-2' 35 | GCPPUBSUB_URL: 'gcppubsub://example-project/test_subscription_queue' 36 | GCPPUBSUB_TOPIC: 'test_topic_queue' 37 | PUBSUB_EMULATOR_HOST: 'gcppubsub:8085' 38 | AMQP_URL: 'amqp://guest:guest@rabbitmq:5672/' 39 | rabbitmq: 40 | container_name: machinery_sut_rabbitmq 41 | image: rabbitmq 42 | environment: 43 | - RABBITMQ_DEFAULT_USER=guest 44 | - RABBITMQ_DEFAULT_PASS=guest 45 | logging: 46 | driver: none 47 | 48 | redis: 49 | container_name: machinery_sut_redis 50 | image: redis 51 | logging: 52 | driver: none 53 | 54 | memcached: 55 | container_name: machinery_sut_memcached 56 | image: memcached 57 | logging: 58 | driver: none 59 | 60 | mongo: 61 | container_name: machinery_sut_mongo 62 | image: mongo 63 | logging: 64 | driver: none 65 | 66 | gcppubsub: 67 | container_name: machinery_sut_gcppubsub 68 | build: 69 | context: . 70 | dockerfile: ./Dockerfile.gcppubsub 71 | logging: 72 | driver: none 73 | -------------------------------------------------------------------------------- /example/tasks/tasks.go: -------------------------------------------------------------------------------- 1 | package exampletasks 2 | 3 | import ( 4 | "errors" 5 | "strings" 6 | "time" 7 | 8 | "github.com/RichardKnop/machinery/v1/log" 9 | ) 10 | 11 | // Add returns the sum of args (0 when no args are given). 12 | func Add(args ...int64) (int64, error) { 13 | sum := int64(0) 14 | for _, arg := range args { 15 | sum += arg 16 | } 17 | return sum, nil 18 | } 19 | 20 | // Multiply returns the product of args (1 when no args are given).
21 | func Multiply(args ...int64) (int64, error) { 22 | product := int64(1) 23 | for _, arg := range args { 24 | product *= arg 25 | } 26 | return product, nil 27 | } 28 | 29 | // SumInts returns the sum of numbers. 30 | func SumInts(numbers []int64) (int64, error) { 31 | var sum int64 32 | for _, num := range numbers { 33 | sum += num 34 | } 35 | return sum, nil 36 | } 37 | 38 | // SumFloats returns the sum of numbers. 39 | func SumFloats(numbers []float64) (float64, error) { 40 | var sum float64 41 | for _, num := range numbers { 42 | sum += num 43 | } 44 | return sum, nil 45 | } 46 | 47 | // Concat concatenates strs in order, without a separator. 48 | func Concat(strs []string) (string, error) { 49 | // strings.Join sizes the result once and copies each piece a single 50 | // time; repeated += re-copies the accumulated string on every 51 | // iteration, which is quadratic in the total length. 52 | res := strings.Join(strs, "") 53 | return res, nil 54 | } 55 | 56 | // Split splits str after each UTF-8 sequence, i.e. into its characters. 57 | func Split(str string) ([]string, error) { 58 | return strings.Split(str, ""), nil 59 | } 60 | 61 | // PanicTask always panics with an "oops" error (handy for exercising panic handling). 62 | func PanicTask() (string, error) { 63 | panic(errors.New("oops")) 64 | } 65 | 66 | // LongRunningTask logs a countdown from 10, sleeping one second per step. 67 | func LongRunningTask() error { 68 | log.INFO.Print("Long running task started") 69 | for i := 0; i < 10; i++ { 70 | log.INFO.Print(10 - i) 71 | time.Sleep(1 * time.Second) 72 | } 73 | log.INFO.Print("Long running task finished") 74 | return nil 75 | } 76 | -------------------------------------------------------------------------------- /example/tracers/jaeger.go: -------------------------------------------------------------------------------- 1 | package tracers 2 | 3 | // Uncomment the import statement for the jaeger tracer 4 | // and add the jaeger client to go.mod (go get github.com/uber/jaeger-client-go). 5 | // 6 | // import ( 7 | // jaeger "github.com/uber/jaeger-client-go" 8 | // jaegercfg "github.com/uber/jaeger-client-go/config" 9 | // ) 10 | 11 | // SetupTracer is the place where you'd set up your specific tracer. 12 | // The jaeger tracer is given as an example. 13 | // To capture the jaeger traces you should run the jaeger backend.
14 | // This can be done using the following docker command: 15 | // 16 | // `docker run -ti --rm -p6831:6831/udp -p16686:16686 jaegertracing/all-in-one:latest` 17 | // 18 | // The agent will be listening on localhost:6831 (UDP) 19 | // and the query UI is reachable on localhost:16686. 20 | func SetupTracer(serviceName string) (func(), error) { 21 | 22 | // Jaeger setup code 23 | // 24 | // config := jaegercfg.Configuration{ 25 | // Sampler: &jaegercfg.SamplerConfig{ 26 | // Type: jaeger.SamplerTypeConst, 27 | // Param: 1, 28 | // }, 29 | // } 30 | 31 | // closer, err := config.InitGlobalTracer(serviceName) 32 | // if err != nil { 33 | // return nil, err 34 | // } 35 | 36 | cleanupFunc := func() { 37 | // closer.Close() 38 | } 39 | 40 | return cleanupFunc, nil 41 | } 42 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/RichardKnop/machinery 2 | 3 | go 1.15 4 | 5 | require ( 6 | cloud.google.com/go v0.76.0 // indirect 7 | cloud.google.com/go/pubsub v1.10.0 8 | github.com/RichardKnop/logging v0.0.0-20190827224416-1a693bdd4fae 9 | github.com/aws/aws-sdk-go v1.37.16 10 | github.com/bradfitz/gomemcache v0.0.0-20190913173617-a41fca850d0b 11 | github.com/cpuguy83/go-md2man/v2 v2.0.0 // indirect 12 | github.com/go-redsync/redsync/v4 v4.8.1 13 | github.com/gomodule/redigo v1.8.10-0.20230511231101-78e255f9bd2a 14 | github.com/google/uuid v1.2.0 15 | github.com/kelseyhightower/envconfig v1.4.0 16 | github.com/opentracing/opentracing-go v1.2.0 17 | github.com/pkg/errors v0.9.1 18 | github.com/rabbitmq/amqp091-go v1.9.0 19 | github.com/redis/go-redis/v9 v9.0.5 20 | github.com/robfig/cron/v3 v3.0.1 21 | github.com/russross/blackfriday/v2 v2.1.0 // indirect 22 | github.com/stretchr/testify v1.8.1 23 | github.com/urfave/cli v1.22.5 24 | go.mongodb.org/mongo-driver v1.17.0 25 | go.opencensus.io v0.22.6 // indirect 26 | golang.org/x/oauth2
v0.0.0-20210201163806-010130855d6c // indirect 27 | gopkg.in/yaml.v2 v2.4.0 28 | ) 29 | 30 | replace git.apache.org/thrift.git => github.com/apache/thrift v0.0.0-20180902110319-2566ecd5d999 31 | -------------------------------------------------------------------------------- /instruction-notes/dynamodb.md: -------------------------------------------------------------------------------- 1 | # Using DynamoDB as a result backend 2 | ## What is DynamoDB 3 | Amazon DynamoDB is a fast and flexible NoSQL database service. 4 | See the [official DynamoDB website](https://aws.amazon.com/dynamodb/) 5 | for details. 6 | 7 | ## How to use DynamoDB as a result backend in Machinery 8 | ### Create two tables first 9 | As of 2018-01-12, two tables are required: 10 | * group_metas: stores the metadata of group tasks. Its primary key is `GroupUUID`, which must be configured when the table is created. 11 | * task_states: stores the state of every task. Its primary key is `TaskUUID`, which must be configured when the table is created. 12 | 13 | 14 | ### Add DynamoDB config to the config file 15 | #### Example config 16 | ```yaml 17 | broker: 'https://sqs.us-west-1.amazonaws.com/123456789012' 18 | default_queue: machinery-queue 19 | result_backend: 'https://dynamodb.us-west-1.amazonaws.com/123456789012' 20 | results_expire_in: 3600 21 | dynamodb: 22 | task_states_table: 'task_states' 23 | group_metas_table: 'group_metas' 24 | ``` 25 | With this configuration, DynamoDB is used as the result backend.
-------------------------------------------------------------------------------- /integration-tests/amqp_amqp_test.go: -------------------------------------------------------------------------------- 1 | package integration_test 2 | 3 | import ( 4 | "os" 5 | "testing" 6 | 7 | "github.com/RichardKnop/machinery/v1" 8 | "github.com/RichardKnop/machinery/v1/config" 9 | ) 10 | 11 | func TestAmqpAmqp(t *testing.T) { 12 | amqpURL := os.Getenv("AMQP_URL") 13 | if amqpURL == "" { 14 | t.Skip("AMQP_URL is not defined") 15 | } 16 | 17 | finalAmqpURL := amqpURL 18 | var finalSeparator string 19 | 20 | amqpURLs := os.Getenv("AMQP_URLS") 21 | if amqpURLs != "" { 22 | separator := os.Getenv("AMQP_URLS_SEPARATOR") 23 | if separator == "" { 24 | t.Skip("AMQP_URLS_SEPARATOR is not defined") 25 | } 26 | finalSeparator = separator 27 | finalAmqpURL = amqpURLs 28 | } 29 | 30 | // AMQP broker, AMQP result backend 31 | server := testSetup(&config.Config{ 32 | Broker: finalAmqpURL, 33 | MultipleBrokerSeparator: finalSeparator, 34 | DefaultQueue: "test_queue", 35 | ResultBackend: amqpURL, 36 | Lock: "eager", 37 | AMQP: &config.AMQPConfig{ 38 | Exchange: "test_exchange", 39 | ExchangeType: "direct", 40 | BindingKey: "test_task", 41 | PrefetchCount: 1, 42 | }, 43 | }) 44 | 45 | worker := server.(*machinery.Server).NewWorker("test_worker", 0) 46 | defer worker.Quit() 47 | go worker.Launch() 48 | testAll(server, t) 49 | } 50 | -------------------------------------------------------------------------------- /integration-tests/amqp_get_pending_tasks_test.go: -------------------------------------------------------------------------------- 1 | package integration_test 2 | 3 | import ( 4 | "fmt" 5 | "os" 6 | "testing" 7 | "time" 8 | 9 | "github.com/RichardKnop/machinery/v1" 10 | "github.com/RichardKnop/machinery/v1/backends/result" 11 | "github.com/RichardKnop/machinery/v1/config" 12 | "github.com/RichardKnop/machinery/v1/tasks" 13 | ) 14 | 15 | func TestAmqpGetPendingTasks(t *testing.T) { 16 | amqpURL := os.Getenv("AMQP_URL") 17 | if
amqpURL == "" { 18 | t.Skip("AMQP_URL is not defined") 19 | } 20 | 21 | finalAmqpURL := amqpURL 22 | var finalSeparator string 23 | 24 | amqpURLs := os.Getenv("AMQP_URLS") 25 | if amqpURLs != "" { 26 | separator := os.Getenv("AMQP_URLS_SEPARATOR") 27 | if separator == "" { 28 | t.Skip("AMQP_URLS_SEPARATOR is not defined") 29 | } 30 | finalSeparator = separator 31 | finalAmqpURL = amqpURLs 32 | } 33 | 34 | redisURL := os.Getenv("REDIS_URL") 35 | if redisURL == "" { 36 | t.Skip("REDIS_URL is not defined") 37 | } 38 | 39 | // AMQP broker, AMQP result backend 40 | server := testSetup(&config.Config{ 41 | Broker: finalAmqpURL, 42 | MultipleBrokerSeparator: finalSeparator, 43 | DefaultQueue: "test_queue", 44 | ResultBackend: amqpURL, 45 | Lock: fmt.Sprintf("redis://%v", redisURL), 46 | AMQP: &config.AMQPConfig{ 47 | Exchange: "test_exchange", 48 | ExchangeType: "direct", 49 | BindingKey: "test_task", 50 | PrefetchCount: 1, 51 | }, 52 | }) 53 | 54 | var results []*result.AsyncResult 55 | signatures := []*tasks.Signature{newAddTask(1, 2), newAddTask(3, 5), newAddTask(6, 7)} 56 | for _, s := range signatures { 57 | ar, err := server.SendTask(s) 58 | if err != nil { 59 | t.Fatal(err) 60 | } 61 | results = append(results, ar) 62 | } 63 | pendingMessages, err := server.GetBroker().GetPendingTasks(server.GetConfig().DefaultQueue) 64 | if err != nil { 65 | t.Fatal(err) 66 | } 67 | 68 | if len(pendingMessages) != len(signatures) { 69 | t.Fatalf( 70 | "%d pending messages, should be %d", 71 | len(pendingMessages), 72 | len(signatures), 73 | ) 74 | } 75 | for i := 0; i < len(signatures); i++ { 76 | compareSigs(t, signatures[i], pendingMessages[i]) 77 | } 78 | 79 | worker := server.(*machinery.Server).NewWorker("test_worker", 0) 80 | go worker.Launch() 81 | defer worker.Quit() 82 | for _, r := range results { 83 | r.Get(5 * time.Millisecond) 84 | } 85 | 86 | pendingMessages, err = server.GetBroker().GetPendingTasks(server.GetConfig().DefaultQueue) 87 | if err != nil { 88 | t.Error(err) 89 | } 90 | 91 | if
len(pendingMessages) != 0 { 92 | t.Errorf( 93 | "%d pending messages, should be 0", 94 | len(pendingMessages), 95 | ) 96 | } 97 | } 98 | 99 | // compareSigs reports mismatches in UUID, Name and argument count between 100 | // an enqueued signature (a) and the one read back from the broker (b). 101 | func compareSigs(t *testing.T, a *tasks.Signature, b *tasks.Signature) { 102 | t.Helper() 103 | if a.UUID != b.UUID { 104 | t.Errorf("UUID mismatch, %v != %v", a.UUID, b.UUID) 105 | } 106 | if a.Name != b.Name { 107 | t.Errorf("Name mismatch, %v != %v", a.Name, b.Name) 108 | } 109 | if len(a.Args) != len(b.Args) { 110 | t.Errorf("Arg length mismatch, %v != %v", len(a.Args), len(b.Args)) 111 | } 112 | } 113 | -------------------------------------------------------------------------------- /integration-tests/amqp_memcache_test.go: -------------------------------------------------------------------------------- 1 | package integration_test 2 | 3 | import ( 4 | "fmt" 5 | "os" 6 | "testing" 7 | 8 | "github.com/RichardKnop/machinery/v1" 9 | "github.com/RichardKnop/machinery/v1/config" 10 | ) 11 | 12 | func TestAmqpMemcache(t *testing.T) { 13 | amqpURL := os.Getenv("AMQP_URL") 14 | memcacheURL := os.Getenv("MEMCACHE_URL") 15 | if amqpURL == "" { 16 | t.Skip("AMQP_URL is not defined") 17 | } 18 | if memcacheURL == "" { 19 | t.Skip("MEMCACHE_URL is not defined") 20 | } 21 | 22 | redisURL := os.Getenv("REDIS_URL") 23 | if redisURL == "" { 24 | t.Skip("REDIS_URL is not defined") 25 | } 26 | 27 | // AMQP broker, Memcache result backend 28 | server := testSetup(&config.Config{ 29 | Broker: amqpURL, 30 | DefaultQueue: "test_queue", 31 | ResultBackend: fmt.Sprintf("memcache://%v", memcacheURL), 32 | Lock: fmt.Sprintf("redis://%v", redisURL), 33 | AMQP: &config.AMQPConfig{ 34 | Exchange: "test_exchange", 35 | ExchangeType: "direct", 36 | BindingKey: "test_task", 37 | PrefetchCount: 1, 38 | }, 39 | }) 40 | 41 | worker := server.(*machinery.Server).NewWorker("test_worker", 0) 42 | defer worker.Quit() 43 | go worker.Launch() 44 | testAll(server, t) 45 | } 46 | --------------------------------------------------------------------------------
/integration-tests/amqp_mongodb_test.go: -------------------------------------------------------------------------------- 1 | package integration_test 2 | 3 | import ( 4 | "fmt" 5 | "os" 6 | "testing" 7 | 8 | "github.com/RichardKnop/machinery/v1" 9 | "github.com/RichardKnop/machinery/v1/config" 10 | ) 11 | 12 | func TestAmqpMongodb(t *testing.T) { 13 | amqpURL := os.Getenv("AMQP_URL") 14 | mongodbURL := os.Getenv("MONGODB_URL") 15 | if amqpURL == "" { 16 | t.Skip("AMQP_URL is not defined") 17 | } 18 | if mongodbURL == "" { 19 | t.Skip("MONGODB_URL is not defined") 20 | } 21 | 22 | // AMQP broker, MongoDB result backend 23 | server := testSetup(&config.Config{ 24 | Broker: amqpURL, 25 | DefaultQueue: "test_queue", 26 | ResultsExpireIn: 30, 27 | ResultBackend: fmt.Sprintf("mongodb://%v", mongodbURL), 28 | Lock: "eager", 29 | AMQP: &config.AMQPConfig{ 30 | Exchange: "test_exchange", 31 | ExchangeType: "direct", 32 | BindingKey: "test_task", 33 | PrefetchCount: 1, 34 | }, 35 | }) 36 | 37 | worker := server.(*machinery.Server).NewWorker("test_worker", 0) 38 | defer worker.Quit() 39 | go worker.Launch() 40 | testAll(server, t) 41 | } 42 | -------------------------------------------------------------------------------- /integration-tests/amqp_redis_test.go: -------------------------------------------------------------------------------- 1 | package integration_test 2 | 3 | import ( 4 | "fmt" 5 | "os" 6 | "testing" 7 | 8 | "github.com/RichardKnop/machinery/v1" 9 | "github.com/RichardKnop/machinery/v1/config" 10 | ) 11 | 12 | func TestAmqpRedis(t *testing.T) { 13 | amqpURL := os.Getenv("AMQP_URL") 14 | redisURL := os.Getenv("REDIS_URL") 15 | if amqpURL == "" { 16 | t.Skip("AMQP_URL is not defined") 17 | } 18 | if redisURL == "" { 19 | t.Skip("REDIS_URL is not defined") 20 | } 21 | 22 | // AMQP broker, Redis result backend 23 | server := testSetup(&config.Config{ 24 | Broker: amqpURL, 25 | DefaultQueue: "test_queue", 26 | ResultBackend: fmt.Sprintf("redis://%v", redisURL),
27 | Lock: "eager", 28 | AMQP: &config.AMQPConfig{ 29 | Exchange: "test_exchange", 30 | ExchangeType: "direct", 31 | BindingKey: "test_task", 32 | PrefetchCount: 1, 33 | }, 34 | }) 35 | 36 | worker := server.(*machinery.Server).NewWorker("test_worker", 0) 37 | defer worker.Quit() 38 | go worker.Launch() 39 | testAll(server, t) 40 | } 41 | -------------------------------------------------------------------------------- /integration-tests/eager_eager_test.go: -------------------------------------------------------------------------------- 1 | package integration_test 2 | 3 | import ( 4 | "reflect" 5 | "testing" 6 | "time" 7 | 8 | "github.com/RichardKnop/machinery/v1" 9 | "github.com/RichardKnop/machinery/v1/config" 10 | "github.com/RichardKnop/machinery/v1/tasks" 11 | "github.com/stretchr/testify/suite" 12 | ) 13 | 14 | type EagerIntegrationTestSuite struct { 15 | suite.Suite 16 | 17 | srv *machinery.Server 18 | called float64 19 | } 20 | 21 | func TestEagerIntegrationTestSuite(t *testing.T) { 22 | suite.Run(t, &EagerIntegrationTestSuite{}) 23 | } 24 | 25 | func (s *EagerIntegrationTestSuite) SetupSuite() { 26 | var err error 27 | 28 | // init server 29 | cnf := config.Config{ 30 | Broker: "eager", 31 | ResultBackend: "eager", 32 | Lock: "eager", 33 | } 34 | s.srv, err = machinery.NewServer(&cnf) 35 | s.Nil(err) 36 | s.NotNil(s.srv) 37 | 38 | // register task 39 | s.called = 0 40 | s.srv.RegisterTask("float_called", func(i float64) (float64, error) { 41 | s.called = i 42 | return s.called, nil 43 | }) 44 | 45 | s.srv.RegisterTask("float_result", func(i float64) (float64, error) { 46 | return i + 100.0, nil 47 | }) 48 | 49 | s.srv.RegisterTask("int_result", func(i int64) (int64, error) { 50 | return i + 100, nil 51 | }) 52 | } 53 | 54 | func (s *EagerIntegrationTestSuite) TestCalled() { 55 | _, err := s.srv.SendTask(&tasks.Signature{ 56 | Name: "float_called", 57 | Args: []tasks.Arg{ 58 | { 59 | Type: "float64", 60 | Value: 100.0, 61 | }, 62 | }, 63 | }) 64 | 65 |
s.Nil(err) 66 | s.Equal(100.0, s.called) 67 | } 68 | 69 | func (s *EagerIntegrationTestSuite) TestSuccessResult() { 70 | // float64 71 | { 72 | asyncResult, err := s.srv.SendTask(&tasks.Signature{ 73 | Name: "float_result", 74 | Args: []tasks.Arg{ 75 | { 76 | Type: "float64", 77 | Value: 100.0, 78 | }, 79 | }, 80 | }) 81 | 82 | s.NotNil(asyncResult) 83 | s.Nil(err) 84 | 85 | s.True(asyncResult.GetState().IsCompleted()) 86 | s.True(asyncResult.GetState().IsSuccess()) 87 | 88 | results, err := asyncResult.Get(5 * time.Millisecond) 89 | if s.NoError(err) { 90 | if len(results) != 1 { 91 | s.T().Fatalf("Number of results returned = %d. Wanted %d", len(results), 1) 92 | } 93 | 94 | s.Equal(reflect.Float64, results[0].Kind()) 95 | if results[0].Kind() == reflect.Float64 { 96 | s.Equal(200.0, results[0].Float()) 97 | } 98 | } 99 | } 100 | 101 | // int 102 | { 103 | asyncResult, err := s.srv.SendTask(&tasks.Signature{ 104 | Name: "int_result", 105 | Args: []tasks.Arg{ 106 | { 107 | Type: "int64", 108 | Value: 100, 109 | }, 110 | }, 111 | }) 112 | 113 | s.NotNil(asyncResult) 114 | s.Nil(err) 115 | 116 | s.True(asyncResult.GetState().IsCompleted()) 117 | s.True(asyncResult.GetState().IsSuccess()) 118 | 119 | results, err := asyncResult.Get(5 * time.Millisecond) 120 | if s.NoError(err) { 121 | if len(results) != 1 { 122 | s.T().Fatalf("Number of results returned = %d.
Wanted %d", len(results), 1) 123 | } 124 | 125 | s.Equal(reflect.Int64, results[0].Kind()) 126 | if results[0].Kind() == reflect.Int64 { 127 | s.Equal(int64(200), results[0].Int()) 128 | } 129 | } 130 | } 131 | } 132 | -------------------------------------------------------------------------------- /integration-tests/gcppubsub_redis_test.go: -------------------------------------------------------------------------------- 1 | package integration_test 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "os" 7 | "testing" 8 | "time" 9 | 10 | "cloud.google.com/go/pubsub" 11 | 12 | "github.com/RichardKnop/machinery/v1" 13 | "github.com/RichardKnop/machinery/v1/config" 14 | ) 15 | 16 | func createGCPPubSubTopicAndSubscription(cli *pubsub.Client, topicName, subscriptionName string) { 17 | ctx := context.Background() 18 | 19 | var topic *pubsub.Topic 20 | 21 | topic = cli.Topic(topicName) 22 | topicExists, err := topic.Exists(ctx) 23 | if err != nil { 24 | panic(err) 25 | } 26 | 27 | if !topicExists { 28 | topic, err = cli.CreateTopic(ctx, topicName) 29 | if err != nil { 30 | panic(err) 31 | } 32 | } 33 | 34 | var sub *pubsub.Subscription 35 | 36 | sub = cli.Subscription(subscriptionName) 37 | subExists, err := sub.Exists(ctx) 38 | if err != nil { 39 | panic(err) 40 | } 41 | 42 | if !subExists { 43 | _, err = cli.CreateSubscription(ctx, subscriptionName, pubsub.SubscriptionConfig{ 44 | Topic: topic, 45 | AckDeadline: 10 * time.Second, 46 | }) 47 | if err != nil { 48 | panic(err) 49 | } 50 | } 51 | } 52 | 53 | func TestGCPPubSubRedis(t *testing.T) { 54 | // start Cloud Pub/Sub emulator 55 | // $ LANG=C gcloud beta emulators pubsub start 56 | // $ eval $(LANG=C gcloud beta emulators pubsub env-init) 57 | 58 | pubsubURL := os.Getenv("GCPPUBSUB_URL") 59 | if pubsubURL == "" { 60 | t.Skip("GCPPUBSUB_URL is not defined") 61 | } 62 | 63 | topicName := os.Getenv("GCPPUBSUB_TOPIC") 64 | if topicName == "" { 65 | t.Skip("GCPPUBSUB_TOPIC is not defined") 66 | } 67 | 68 | _,
subscriptionName, err := machinery.ParseGCPPubSubURL(pubsubURL) 69 | if err != nil { 70 | t.Fatal(err) 71 | } 72 | 73 | redisURL := os.Getenv("REDIS_URL") 74 | if redisURL == "" { 75 | t.Skip("REDIS_URL is not defined") 76 | } 77 | 78 | pubsubClient, err := pubsub.NewClient(context.Background(), "") 79 | if err != nil { 80 | t.Fatal(err) 81 | } 82 | 83 | // Create Cloud Pub/Sub Topic and Subscription 84 | createGCPPubSubTopicAndSubscription(pubsubClient, topicName, subscriptionName) 85 | 86 | // GCP Pub/Sub broker, Redis result backend 87 | server := testSetup(&config.Config{ 88 | Broker: pubsubURL, 89 | DefaultQueue: topicName, 90 | ResultBackend: fmt.Sprintf("redis://%v", redisURL), 91 | Lock: fmt.Sprintf("redis://%v", redisURL), 92 | GCPPubSub: &config.GCPPubSubConfig{ 93 | Client: pubsubClient, 94 | }, 95 | }) 96 | 97 | worker := server.(*machinery.Server).NewWorker("test_worker", 0) 98 | defer worker.Quit() 99 | go worker.Launch() 100 | testAll(server, t) 101 | } 102 | -------------------------------------------------------------------------------- /integration-tests/redis_get_pending_tasks_test.go: -------------------------------------------------------------------------------- 1 | package integration_test 2 | 3 | import ( 4 | "fmt" 5 | "os" 6 | "testing" 7 | 8 | "github.com/RichardKnop/machinery/v1/config" 9 | ) 10 | 11 | func TestRedisGetPendingTasks(t *testing.T) { 12 | redisURL := os.Getenv("REDIS_URL") 13 | if redisURL == "" { 14 | t.Skip("REDIS_URL is not defined") 15 | } 16 | 17 | // Redis broker, Redis result backend 18 | server := testSetup(&config.Config{ 19 | Broker: fmt.Sprintf("redis://%v", redisURL), 20 | DefaultQueue: "test_queue", 21 | ResultBackend: fmt.Sprintf("redis://%v", redisURL), 22 | Lock: fmt.Sprintf("redis://%v", redisURL), 23 | }) 24 | pendingMessages, err := server.GetBroker().GetPendingTasks(server.GetConfig().DefaultQueue) 25 | if err != nil { 26 | t.Error(err) 27 | } 28 | if len(pendingMessages) != 0 { 29 | t.Errorf( 30 | "%d pending
messages, should be %d", 31 | len(pendingMessages), 32 | 0, 33 | ) 34 | } 35 | } 36 | -------------------------------------------------------------------------------- /integration-tests/redis_memcache_test.go: -------------------------------------------------------------------------------- 1 | package integration_test 2 | 3 | import ( 4 | "fmt" 5 | "os" 6 | "testing" 7 | 8 | "github.com/RichardKnop/machinery/v1" 9 | "github.com/RichardKnop/machinery/v1/config" 10 | ) 11 | 12 | func TestRedisMemcache(t *testing.T) { 13 | redisURL := os.Getenv("REDIS_URL") 14 | memcacheURL := os.Getenv("MEMCACHE_URL") 15 | if redisURL == "" { 16 | t.Skip("REDIS_URL is not defined") 17 | } 18 | if memcacheURL == "" { 19 | t.Skip("MEMCACHE_URL is not defined") 20 | } 21 | 22 | // Redis broker, Memcache result backend 23 | server := testSetup(&config.Config{ 24 | Broker: fmt.Sprintf("redis://%v", redisURL), 25 | DefaultQueue: "test_queue", 26 | ResultBackend: fmt.Sprintf("memcache://%v", memcacheURL), 27 | Lock: fmt.Sprintf("redis://%v", redisURL), 28 | }) 29 | 30 | worker := server.(*machinery.Server).NewWorker("test_worker", 0) 31 | defer worker.Quit() 32 | go worker.Launch() 33 | testAll(server, t) 34 | } 35 | -------------------------------------------------------------------------------- /integration-tests/redis_mongodb_test.go: -------------------------------------------------------------------------------- 1 | package integration_test 2 | 3 | import ( 4 | "fmt" 5 | "os" 6 | "testing" 7 | 8 | "github.com/RichardKnop/machinery/v1" 9 | "github.com/RichardKnop/machinery/v1/config" 10 | ) 11 | 12 | func TestRedisMongodb(t *testing.T) { 13 | redisURL := os.Getenv("REDIS_URL") 14 | mongodbURL := os.Getenv("MONGODB_URL") 15 | if redisURL == "" { 16 | t.Skip("REDIS_URL is not defined") 17 | } 18 | if mongodbURL == "" { 19 | t.Skip("MONGODB_URL is not defined") 20 | } 21 | 22 | // Redis broker, MongoDB result backend 23 | server := testSetup(&config.Config{ 24 | Broker:
fmt.Sprintf("redis://%v", redisURL), 25 | DefaultQueue: "test_queue", 26 | ResultsExpireIn: 30, 27 | ResultBackend: fmt.Sprintf("mongodb://%v", mongodbURL), 28 | Lock: fmt.Sprintf("redis://%v", redisURL), 29 | }) 30 | 31 | worker := server.(*machinery.Server).NewWorker("test_worker", 0) 32 | defer worker.Quit() 33 | go worker.Launch() 34 | testAll(server, t) 35 | } 36 | -------------------------------------------------------------------------------- /integration-tests/redis_socket_test.go: -------------------------------------------------------------------------------- 1 | package integration_test 2 | 3 | import ( 4 | "fmt" 5 | "os" 6 | "testing" 7 | 8 | "github.com/RichardKnop/machinery/v1" 9 | "github.com/RichardKnop/machinery/v1/config" 10 | ) 11 | 12 | func TestRedisSocket(t *testing.T) { 13 | redisSocket := os.Getenv("REDIS_SOCKET") 14 | if redisSocket == "" { 15 | t.Skip("REDIS_SOCKET is not defined") 16 | } 17 | 18 | // Redis broker, Redis result backend 19 | server := testSetup(&config.Config{ 20 | Broker: fmt.Sprintf("redis+socket://%v", redisSocket), 21 | DefaultQueue: "test_queue", 22 | ResultBackend: fmt.Sprintf("redis+socket://%v", redisSocket), 23 | Lock: "eager", 24 | }) 25 | 26 | worker := server.(*machinery.Server).NewWorker("test_worker", 0) 27 | defer worker.Quit() 28 | go worker.Launch() 29 | testAll(server, t) 30 | } 31 | -------------------------------------------------------------------------------- /integration-tests/sqs_amqp_test.go: -------------------------------------------------------------------------------- 1 | package integration_test 2 | 3 | import ( 4 | "os" 5 | "testing" 6 | 7 | "github.com/RichardKnop/machinery/v1" 8 | "github.com/RichardKnop/machinery/v1/config" 9 | ) 10 | 11 | func TestSQSAmqp(t *testing.T) { 12 | sqsURL := os.Getenv("SQS_URL") 13 | if sqsURL == "" { 14 | t.Skip("SQS_URL is not defined") 15 | } 16 | 17 | amqpURL := os.Getenv("AMQP_URL") 18 | if amqpURL == "" { 19 | t.Skip("AMQP_URL is not defined") 20 | } 21 |
22 | // SQS broker, AMQP result backend 23 | server := testSetup(&config.Config{ 24 | Broker: sqsURL, 25 | DefaultQueue: "test_queue", 26 | ResultBackend: amqpURL, 27 | Lock: "eager", 28 | AMQP: &config.AMQPConfig{ 29 | Exchange: "test_exchange", 30 | ExchangeType: "direct", 31 | BindingKey: "test_task", 32 | PrefetchCount: 1, 33 | }, 34 | }) 35 | 36 | worker := server.(*machinery.Server).NewWorker("test_worker", 0) 37 | defer worker.Quit() 38 | go worker.Launch() 39 | testAll(server, t) 40 | } 41 | -------------------------------------------------------------------------------- /integration-tests/sqs_mongodb_test.go: -------------------------------------------------------------------------------- 1 | package integration_test 2 | 3 | import ( 4 | "fmt" 5 | "os" 6 | "testing" 7 | 8 | "github.com/RichardKnop/machinery/v1" 9 | "github.com/RichardKnop/machinery/v1/config" 10 | ) 11 | 12 | func TestSQSMongodb(t *testing.T) { 13 | sqsURL := os.Getenv("SQS_URL") 14 | mongodbURL := os.Getenv("MONGODB_URL") 15 | if sqsURL == "" { 16 | t.Skip("SQS_URL is not defined") 17 | } 18 | if mongodbURL == "" { 19 | t.Skip("MONGODB_URL is not defined") 20 | } 21 | 22 | // SQS broker, MongoDB result backend 23 | server := testSetup(&config.Config{ 24 | Broker: sqsURL, 25 | DefaultQueue: "test_queue", 26 | ResultsExpireIn: 30, 27 | ResultBackend: fmt.Sprintf("mongodb://%v", mongodbURL), 28 | Lock: "eager", 29 | }) 30 | worker := server.(*machinery.Server).NewWorker("test_worker", 0) 31 | defer worker.Quit() 32 | go worker.Launch() 33 | testAll(server, t) 34 | } 35 | -------------------------------------------------------------------------------- /v1/backends/amqp/amqp_test.go: -------------------------------------------------------------------------------- 1 | package amqp_test 2 | 3 | import ( 4 | "os" 5 | "testing" 6 | "time" 7 | 8 | "github.com/RichardKnop/machinery/v1/backends/amqp" 9 | "github.com/RichardKnop/machinery/v1/config" 10 | 
"github.com/RichardKnop/machinery/v1/tasks" 11 |
"github.com/stretchr/testify/assert" 12 | ) 13 | 14 | var ( 15 | amqpConfig *config.Config 16 | ) 17 | 18 | func init() { 19 | amqpURL := os.Getenv("AMQP_URL") 20 | if amqpURL == "" { 21 | return 22 | } 23 | 24 | finalAmqpURL := amqpURL 25 | var finalSeparator string 26 | 27 | amqpURLs := os.Getenv("AMQP_URLS") 28 | if amqpURLs != "" { 29 | separator := os.Getenv("AMQP_URLS_SEPARATOR") 30 | if separator == "" { 31 | return 32 | } 33 | finalSeparator = separator 34 | finalAmqpURL = amqpURLs 35 | } 36 | 37 | amqp2URL := os.Getenv("AMQP2_URL") 38 | if amqp2URL == "" { 39 | amqp2URL = amqpURL 40 | } 41 | 42 | amqpConfig = &config.Config{ 43 | Broker: finalAmqpURL, 44 | MultipleBrokerSeparator: finalSeparator, 45 | DefaultQueue: "test_queue", 46 | ResultBackend: amqp2URL, 47 | AMQP: &config.AMQPConfig{ 48 | Exchange: "test_exchange", 49 | ExchangeType: "direct", 50 | BindingKey: "test_task", 51 | PrefetchCount: 1, 52 | }, 53 | } 54 | } 55 | 56 | func TestGroupCompleted(t *testing.T) { 57 | if os.Getenv("AMQP_URL") == "" { 58 | t.Skip("AMQP_URL is not defined") 59 | } 60 | 61 | groupUUID := "testGroupUUID" 62 | groupTaskCount := 2 63 | task1 := &tasks.Signature{ 64 | UUID: "testTaskUUID1", 65 | GroupUUID: groupUUID, 66 | GroupTaskCount: groupTaskCount, 67 | } 68 | task2 := &tasks.Signature{ 69 | UUID: "testTaskUUID2", 70 | GroupUUID: groupUUID, 71 | GroupTaskCount: groupTaskCount, 72 | } 73 | 74 | backend := amqp.New(amqpConfig) 75 | 76 | // Cleanup before the test 77 | backend.PurgeState(task1.UUID) 78 | backend.PurgeState(task2.UUID) 79 | backend.PurgeGroupMeta(groupUUID) 80 | 81 | groupCompleted, err := backend.GroupCompleted(groupUUID, groupTaskCount) 82 | if assert.NoError(t, err) { 83 | assert.False(t, groupCompleted) 84 | } 85 | 86 | backend.InitGroup(groupUUID, []string{task1.UUID, task2.UUID}) 87 | 88 | groupCompleted, err = backend.GroupCompleted(groupUUID, groupTaskCount) 89 | if assert.NoError(t, err) { 90 | assert.False(t, groupCompleted) 91 | } 92 | 93 | 
backend.SetStatePending(task1) 94 | backend.SetStateStarted(task2) 95 | groupCompleted, err = backend.GroupCompleted(groupUUID, groupTaskCount) 96 | if assert.NoError(t, err) { 97 | assert.False(t, groupCompleted) 98 | } 99 | 100 | taskResults := []*tasks.TaskResult{new(tasks.TaskResult)} 101 | backend.SetStateSuccess(task1, taskResults) 102 | backend.SetStateSuccess(task2, taskResults) 103 | groupCompleted, err = backend.GroupCompleted(groupUUID, groupTaskCount) 104 | if assert.NoError(t, err) { 105 | assert.True(t, groupCompleted) 106 | } 107 | } 108 | 109 | func TestGetState(t *testing.T) { 110 | if os.Getenv("AMQP_URL") == "" { 111 | t.Skip("AMQP_URL is not defined") 112 | } 113 | 114 | signature := &tasks.Signature{ 115 | UUID: "testTaskUUID", 116 | GroupUUID: "testGroupUUID", 117 | } 118 | 119 | go func() { 120 | backend := amqp.New(amqpConfig) 121 | backend.SetStatePending(signature) 122 | time.Sleep(2 * time.Millisecond) 123 | backend.SetStateReceived(signature) 124 | time.Sleep(2 * time.Millisecond) 125 | backend.SetStateStarted(signature) 126 | time.Sleep(2 * time.Millisecond) 127 | 128 | taskResults := []*tasks.TaskResult{ 129 | { 130 | Type: "float64", 131 | Value: 2, 132 | }, 133 | } 134 | backend.SetStateSuccess(signature, taskResults) 135 | }() 136 | 137 | backend := amqp.New(amqpConfig) 138 | 139 | var ( 140 | taskState *tasks.TaskState 141 | err error 142 | ) 143 | for { 144 | taskState, err = backend.GetState(signature.UUID) 145 | if taskState == nil { 146 | assert.Equal(t, "No state ready", err.Error()) 147 | continue 148 | } 149 | 150 | assert.NoError(t, err) 151 | if taskState.IsCompleted() { 152 | break 153 | } 154 | } 155 | } 156 | 157 | func TestPurgeState(t *testing.T) { 158 | if os.Getenv("AMQP_URL") == "" { 159 | t.Skip("AMQP_URL is not defined") 160 | } 161 | 162 | signature := &tasks.Signature{ 163 | UUID: "testTaskUUID", 164 | GroupUUID: "testGroupUUID", 165 | } 166 | 167 | backend := amqp.New(amqpConfig) 168 | 169 | 
backend.SetStatePending(signature) 170 | backend.SetStateReceived(signature) 171 | taskState, err := backend.GetState(signature.UUID) 172 | assert.NotNil(t, taskState) 173 | assert.NoError(t, err) 174 | 175 | backend.PurgeState(taskState.TaskUUID) 176 | taskState, err = backend.GetState(signature.UUID) 177 | assert.Nil(t, taskState) 178 | assert.Error(t, err) 179 | } 180 | -------------------------------------------------------------------------------- /v1/backends/iface/interfaces.go: -------------------------------------------------------------------------------- 1 | package iface 2 | 3 | import ( 4 | "github.com/RichardKnop/machinery/v1/tasks" 5 | ) 6 | 7 | // Backend - a common interface for all result backends 8 | type Backend interface { 9 | // Group related functions 10 | InitGroup(groupUUID string, taskUUIDs []string) error 11 | GroupCompleted(groupUUID string, groupTaskCount int) (bool, error) 12 | GroupTaskStates(groupUUID string, groupTaskCount int) ([]*tasks.TaskState, error) 13 | TriggerChord(groupUUID string) (bool, error) 14 | 15 | // Setting / getting task state 16 | SetStatePending(signature *tasks.Signature) error 17 | SetStateReceived(signature *tasks.Signature) error 18 | SetStateStarted(signature *tasks.Signature) error 19 | SetStateRetry(signature *tasks.Signature) error 20 | SetStateSuccess(signature *tasks.Signature, results []*tasks.TaskResult) error 21 | SetStateFailure(signature *tasks.Signature, err string) error 22 | GetState(taskUUID string) (*tasks.TaskState, error) 23 | 24 | // Purging stored stored tasks states and group meta data 25 | IsAMQP() bool 26 | PurgeState(taskUUID string) error 27 | PurgeGroupMeta(groupUUID string) error 28 | } 29 | -------------------------------------------------------------------------------- /v1/backends/memcache/memcache_test.go: -------------------------------------------------------------------------------- 1 | package memcache_test 2 | 3 | import ( 4 | "os" 5 | "testing" 6 | "time" 7 | 8 | 
"github.com/RichardKnop/machinery/v1/backends/memcache" 9 | "github.com/RichardKnop/machinery/v1/config" 10 | "github.com/RichardKnop/machinery/v1/tasks" 11 | "github.com/stretchr/testify/assert" 12 | ) 13 | 14 | func TestGroupCompleted(t *testing.T) { 15 | memcacheURL := os.Getenv("MEMCACHE_URL") 16 | if memcacheURL == "" { 17 | t.Skip("MEMCACHE_URL is not defined") 18 | } 19 | 20 | groupUUID := "testGroupUUID" 21 | task1 := &tasks.Signature{ 22 | UUID: "testTaskUUID1", 23 | GroupUUID: groupUUID, 24 | } 25 | task2 := &tasks.Signature{ 26 | UUID: "testTaskUUID2", 27 | GroupUUID: groupUUID, 28 | } 29 | 30 | backend := memcache.New(new(config.Config), []string{memcacheURL}) 31 | 32 | // Cleanup before the test 33 | backend.PurgeState(task1.UUID) 34 | backend.PurgeState(task2.UUID) 35 | backend.PurgeGroupMeta(groupUUID) 36 | 37 | groupCompleted, err := backend.GroupCompleted(groupUUID, 2) 38 | if assert.Error(t, err) { 39 | assert.False(t, groupCompleted) 40 | assert.Equal(t, "memcache: cache miss", err.Error()) 41 | } 42 | 43 | backend.InitGroup(groupUUID, []string{task1.UUID, task2.UUID}) 44 | 45 | groupCompleted, err = backend.GroupCompleted(groupUUID, 2) 46 | if assert.Error(t, err) { 47 | assert.False(t, groupCompleted) 48 | assert.Equal(t, "memcache: cache miss", err.Error()) 49 | } 50 | 51 | backend.SetStatePending(task1) 52 | backend.SetStateStarted(task2) 53 | groupCompleted, err = backend.GroupCompleted(groupUUID, 2) 54 | if assert.NoError(t, err) { 55 | assert.False(t, groupCompleted) 56 | } 57 | 58 | taskResults := []*tasks.TaskResult{new(tasks.TaskResult)} 59 | backend.SetStateStarted(task1) 60 | backend.SetStateSuccess(task2, taskResults) 61 | groupCompleted, err = backend.GroupCompleted(groupUUID, 2) 62 | if assert.NoError(t, err) { 63 | assert.False(t, groupCompleted) 64 | } 65 | 66 | backend.SetStateFailure(task1, "Some error") 67 | groupCompleted, err = backend.GroupCompleted(groupUUID, 2) 68 | if assert.NoError(t, err) { 69 | assert.True(t, 
groupCompleted) 70 | } 71 | } 72 | 73 | func TestGetState(t *testing.T) { 74 | memcacheURL := os.Getenv("MEMCACHE_URL") 75 | if memcacheURL == "" { 76 | t.Skip("MEMCACHE_URL is not defined") 77 | } 78 | 79 | signature := &tasks.Signature{ 80 | UUID: "testTaskUUID", 81 | GroupUUID: "testGroupUUID", 82 | } 83 | 84 | backend := memcache.New(new(config.Config), []string{memcacheURL}) 85 | 86 | go func() { 87 | backend.SetStatePending(signature) 88 | time.Sleep(2 * time.Millisecond) 89 | backend.SetStateReceived(signature) 90 | time.Sleep(2 * time.Millisecond) 91 | backend.SetStateStarted(signature) 92 | time.Sleep(2 * time.Millisecond) 93 | taskResults := []*tasks.TaskResult{ 94 | { 95 | Type: "float64", 96 | Value: 2, 97 | }, 98 | } 99 | backend.SetStateSuccess(signature, taskResults) 100 | }() 101 | 102 | var ( 103 | taskState *tasks.TaskState 104 | err error 105 | ) 106 | for { 107 | taskState, err = backend.GetState(signature.UUID) 108 | if taskState == nil { 109 | assert.Equal(t, "memcache: cache miss", err.Error()) 110 | continue 111 | } 112 | 113 | assert.NoError(t, err) 114 | if taskState.IsCompleted() { 115 | break 116 | } 117 | } 118 | } 119 | 120 | func TestPurgeState(t *testing.T) { 121 | memcacheURL := os.Getenv("MEMCACHE_URL") 122 | if memcacheURL == "" { 123 | t.Skip("MEMCACHE_URL is not defined") 124 | } 125 | 126 | signature := &tasks.Signature{ 127 | UUID: "testTaskUUID", 128 | GroupUUID: "testGroupUUID", 129 | } 130 | 131 | backend := memcache.New(new(config.Config), []string{memcacheURL}) 132 | 133 | backend.SetStatePending(signature) 134 | taskState, err := backend.GetState(signature.UUID) 135 | assert.NotNil(t, taskState) 136 | assert.NoError(t, err) 137 | 138 | backend.PurgeState(taskState.TaskUUID) 139 | taskState, err = backend.GetState(signature.UUID) 140 | assert.Nil(t, taskState) 141 | assert.Error(t, err) 142 | } 143 | -------------------------------------------------------------------------------- /v1/backends/null/null.go: 
-------------------------------------------------------------------------------- 1 | package null 2 | 3 | import ( 4 | "fmt" 5 | 6 | "github.com/RichardKnop/machinery/v1/backends/iface" 7 | "github.com/RichardKnop/machinery/v1/common" 8 | "github.com/RichardKnop/machinery/v1/config" 9 | "github.com/RichardKnop/machinery/v1/tasks" 10 | ) 11 | 12 | // ErrGroupNotFound ... 13 | type ErrGroupNotFound struct { 14 | groupUUID string 15 | } 16 | 17 | // NewErrGroupNotFound returns new instance of ErrGroupNotFound 18 | func NewErrGroupNotFound(groupUUID string) ErrGroupNotFound { 19 | return ErrGroupNotFound{groupUUID: groupUUID} 20 | } 21 | 22 | // Error implements error interface 23 | func (e ErrGroupNotFound) Error() string { 24 | return fmt.Sprintf("Group not found: %v", e.groupUUID) 25 | } 26 | 27 | // ErrTasknotFound ... 28 | type ErrTasknotFound struct { 29 | taskUUID string 30 | } 31 | 32 | // NewErrTasknotFound returns new instance of ErrTasknotFound 33 | func NewErrTasknotFound(taskUUID string) ErrTasknotFound { 34 | return ErrTasknotFound{taskUUID: taskUUID} 35 | } 36 | 37 | // Error implements error interface 38 | func (e ErrTasknotFound) Error() string { 39 | return fmt.Sprintf("Task not found: %v", e.taskUUID) 40 | } 41 | 42 | // Backend represents an "null" result backend 43 | type Backend struct { 44 | common.Backend 45 | groups map[string]struct{} 46 | } 47 | 48 | // New creates NullBackend instance 49 | func New() iface.Backend { 50 | return &Backend{ 51 | Backend: common.NewBackend(new(config.Config)), 52 | groups: make(map[string]struct{}), 53 | } 54 | } 55 | 56 | // InitGroup creates and saves a group meta data object 57 | func (b *Backend) InitGroup(groupUUID string, taskUUIDs []string) error { 58 | b.groups[groupUUID] = struct{}{} 59 | return nil 60 | } 61 | 62 | // GroupCompleted returns true (always) 63 | func (b *Backend) GroupCompleted(groupUUID string, groupTaskCount int) (bool, error) { 64 | _, ok := b.groups[groupUUID] 65 | if !ok { 66 | 
return false, NewErrGroupNotFound(groupUUID) 67 | } 68 | 69 | return true, nil 70 | } 71 | 72 | // GroupTaskStates returns null states of all tasks in the group 73 | func (b *Backend) GroupTaskStates(groupUUID string, groupTaskCount int) ([]*tasks.TaskState, error) { 74 | _, ok := b.groups[groupUUID] 75 | if !ok { 76 | return nil, NewErrGroupNotFound(groupUUID) 77 | } 78 | 79 | ret := make([]*tasks.TaskState, 0, groupTaskCount) 80 | return ret, nil 81 | } 82 | 83 | // TriggerChord returns true (always) 84 | func (b *Backend) TriggerChord(groupUUID string) (bool, error) { 85 | return true, nil 86 | } 87 | 88 | // SetStatePending updates task state to PENDING 89 | func (b *Backend) SetStatePending(signature *tasks.Signature) error { 90 | state := tasks.NewPendingTaskState(signature) 91 | return b.updateState(state) 92 | } 93 | 94 | // SetStateReceived updates task state to RECEIVED 95 | func (b *Backend) SetStateReceived(signature *tasks.Signature) error { 96 | state := tasks.NewReceivedTaskState(signature) 97 | return b.updateState(state) 98 | } 99 | 100 | // SetStateStarted updates task state to STARTED 101 | func (b *Backend) SetStateStarted(signature *tasks.Signature) error { 102 | state := tasks.NewStartedTaskState(signature) 103 | return b.updateState(state) 104 | } 105 | 106 | // SetStateRetry updates task state to RETRY 107 | func (b *Backend) SetStateRetry(signature *tasks.Signature) error { 108 | state := tasks.NewRetryTaskState(signature) 109 | return b.updateState(state) 110 | } 111 | 112 | // SetStateSuccess updates task state to SUCCESS 113 | func (b *Backend) SetStateSuccess(signature *tasks.Signature, results []*tasks.TaskResult) error { 114 | state := tasks.NewSuccessTaskState(signature, results) 115 | return b.updateState(state) 116 | } 117 | 118 | // SetStateFailure updates task state to FAILURE 119 | func (b *Backend) SetStateFailure(signature *tasks.Signature, err string) error { 120 | state := tasks.NewFailureTaskState(signature, err) 121 | 
return b.updateState(state) 122 | } 123 | 124 | // GetState returns the latest task state 125 | func (b *Backend) GetState(taskUUID string) (*tasks.TaskState, error) { 126 | return nil, NewErrTasknotFound(taskUUID) 127 | } 128 | 129 | // PurgeState deletes stored task state 130 | func (b *Backend) PurgeState(taskUUID string) error { 131 | return NewErrTasknotFound(taskUUID) 132 | } 133 | 134 | // PurgeGroupMeta deletes stored group meta data 135 | func (b *Backend) PurgeGroupMeta(groupUUID string) error { 136 | _, ok := b.groups[groupUUID] 137 | if !ok { 138 | return NewErrGroupNotFound(groupUUID) 139 | } 140 | 141 | return nil 142 | } 143 | 144 | func (b *Backend) updateState(s *tasks.TaskState) error { 145 | return nil 146 | } 147 | -------------------------------------------------------------------------------- /v1/backends/package.go: -------------------------------------------------------------------------------- 1 | package backends 2 | -------------------------------------------------------------------------------- /v1/backends/redis/goredis_test.go: -------------------------------------------------------------------------------- 1 | package redis_test 2 | 3 | import ( 4 | "github.com/RichardKnop/machinery/v1/backends/iface" 5 | "os" 6 | "strings" 7 | "testing" 8 | 9 | "github.com/RichardKnop/machinery/v1/backends/redis" 10 | "github.com/RichardKnop/machinery/v1/config" 11 | "github.com/RichardKnop/machinery/v1/tasks" 12 | "github.com/stretchr/testify/assert" 13 | ) 14 | 15 | func getRedisG() iface.Backend { 16 | // host1:port1,host2:port2 17 | redisURL := os.Getenv("REDIS_URL_GR") 18 | //redisPassword := os.Getenv("REDIS_PASSWORD") 19 | if redisURL == "" { 20 | return nil 21 | } 22 | backend := redis.NewGR(new(config.Config), strings.Split(redisURL, ","), 0) 23 | return backend 24 | } 25 | 26 | func TestGroupCompletedGR(t *testing.T) { 27 | backend := getRedisG() 28 | if backend == nil { 29 | t.Skip() 30 | } 31 | 32 | groupUUID := "testGroupUUID" 33 | task1 
:= &tasks.Signature{ 34 | UUID: "testTaskUUID1", 35 | GroupUUID: groupUUID, 36 | } 37 | task2 := &tasks.Signature{ 38 | UUID: "testTaskUUID2", 39 | GroupUUID: groupUUID, 40 | } 41 | 42 | // Cleanup before the test 43 | backend.PurgeState(task1.UUID) 44 | backend.PurgeState(task2.UUID) 45 | backend.PurgeGroupMeta(groupUUID) 46 | 47 | groupCompleted, err := backend.GroupCompleted(groupUUID, 2) 48 | if assert.Error(t, err) { 49 | assert.False(t, groupCompleted) 50 | assert.Equal(t, "redis: nil", err.Error()) 51 | } 52 | 53 | backend.InitGroup(groupUUID, []string{task1.UUID, task2.UUID}) 54 | 55 | groupCompleted, err = backend.GroupCompleted(groupUUID, 2) 56 | if assert.Error(t, err) { 57 | assert.False(t, groupCompleted) 58 | assert.Equal(t, "redis: nil", err.Error()) 59 | } 60 | 61 | backend.SetStatePending(task1) 62 | backend.SetStateStarted(task2) 63 | groupCompleted, err = backend.GroupCompleted(groupUUID, 2) 64 | if assert.NoError(t, err) { 65 | assert.False(t, groupCompleted) 66 | } 67 | 68 | taskResults := []*tasks.TaskResult{new(tasks.TaskResult)} 69 | backend.SetStateStarted(task1) 70 | backend.SetStateSuccess(task2, taskResults) 71 | groupCompleted, err = backend.GroupCompleted(groupUUID, 2) 72 | if assert.NoError(t, err) { 73 | assert.False(t, groupCompleted) 74 | } 75 | 76 | backend.SetStateFailure(task1, "Some error") 77 | groupCompleted, err = backend.GroupCompleted(groupUUID, 2) 78 | if assert.NoError(t, err) { 79 | assert.True(t, groupCompleted) 80 | } 81 | } 82 | 83 | func TestGetStateGR(t *testing.T) { 84 | backend := getRedisG() 85 | if backend == nil { 86 | t.Skip() 87 | } 88 | 89 | signature := &tasks.Signature{ 90 | UUID: "testTaskUUID", 91 | GroupUUID: "testGroupUUID", 92 | } 93 | 94 | backend.PurgeState("testTaskUUID") 95 | 96 | var ( 97 | taskState *tasks.TaskState 98 | err error 99 | ) 100 | 101 | taskState, err = backend.GetState(signature.UUID) 102 | assert.Equal(t, "redis: nil", err.Error()) 103 | assert.Nil(t, taskState) 104 | 105 | 
//Pending State 106 | backend.SetStatePending(signature) 107 | taskState, err = backend.GetState(signature.UUID) 108 | assert.NoError(t, err) 109 | assert.Equal(t, signature.Name, taskState.TaskName) 110 | createdAt := taskState.CreatedAt 111 | 112 | //Received State 113 | backend.SetStateReceived(signature) 114 | taskState, err = backend.GetState(signature.UUID) 115 | assert.NoError(t, err) 116 | assert.Equal(t, signature.Name, taskState.TaskName) 117 | assert.Equal(t, createdAt, taskState.CreatedAt) 118 | 119 | //Started State 120 | backend.SetStateStarted(signature) 121 | taskState, err = backend.GetState(signature.UUID) 122 | assert.NoError(t, err) 123 | assert.Equal(t, signature.Name, taskState.TaskName) 124 | assert.Equal(t, createdAt, taskState.CreatedAt) 125 | 126 | //Success State 127 | taskResults := []*tasks.TaskResult{ 128 | { 129 | Type: "float64", 130 | Value: 2, 131 | }, 132 | } 133 | backend.SetStateSuccess(signature, taskResults) 134 | taskState, err = backend.GetState(signature.UUID) 135 | assert.NoError(t, err) 136 | assert.Equal(t, signature.Name, taskState.TaskName) 137 | assert.Equal(t, createdAt, taskState.CreatedAt) 138 | assert.NotNil(t, taskState.Results) 139 | } 140 | 141 | func TestPurgeStateGR(t *testing.T) { 142 | backend := getRedisG() 143 | if backend == nil { 144 | t.Skip() 145 | } 146 | 147 | signature := &tasks.Signature{ 148 | UUID: "testTaskUUID", 149 | GroupUUID: "testGroupUUID", 150 | } 151 | 152 | backend.SetStatePending(signature) 153 | taskState, err := backend.GetState(signature.UUID) 154 | assert.NotNil(t, taskState) 155 | assert.NoError(t, err) 156 | 157 | backend.PurgeState(taskState.TaskUUID) 158 | taskState, err = backend.GetState(signature.UUID) 159 | assert.Nil(t, taskState) 160 | assert.Error(t, err) 161 | } 162 | -------------------------------------------------------------------------------- /v1/backends/redis/redis_test.go: -------------------------------------------------------------------------------- 1 | 
package redis_test 2 | 3 | import ( 4 | "os" 5 | "testing" 6 | 7 | "github.com/RichardKnop/machinery/v1/backends/redis" 8 | "github.com/RichardKnop/machinery/v1/config" 9 | "github.com/RichardKnop/machinery/v1/tasks" 10 | "github.com/stretchr/testify/assert" 11 | ) 12 | 13 | func TestGroupCompleted(t *testing.T) { 14 | redisURL := os.Getenv("REDIS_URL") 15 | redisPassword := os.Getenv("REDIS_PASSWORD") 16 | if redisURL == "" { 17 | t.Skip("REDIS_URL is not defined") 18 | } 19 | 20 | groupUUID := "testGroupUUID" 21 | task1 := &tasks.Signature{ 22 | UUID: "testTaskUUID1", 23 | GroupUUID: groupUUID, 24 | } 25 | task2 := &tasks.Signature{ 26 | UUID: "testTaskUUID2", 27 | GroupUUID: groupUUID, 28 | } 29 | 30 | backend := redis.New(new(config.Config), redisURL, "", redisPassword, "", 0) 31 | 32 | // Cleanup before the test 33 | backend.PurgeState(task1.UUID) 34 | backend.PurgeState(task2.UUID) 35 | backend.PurgeGroupMeta(groupUUID) 36 | 37 | groupCompleted, err := backend.GroupCompleted(groupUUID, 2) 38 | if assert.Error(t, err) { 39 | assert.False(t, groupCompleted) 40 | assert.Equal(t, "redigo: nil returned", err.Error()) 41 | } 42 | 43 | backend.InitGroup(groupUUID, []string{task1.UUID, task2.UUID}) 44 | 45 | groupCompleted, err = backend.GroupCompleted(groupUUID, 2) 46 | if assert.Error(t, err) { 47 | assert.False(t, groupCompleted) 48 | assert.Equal(t, "Expected byte array, instead got: ", err.Error()) 49 | } 50 | 51 | backend.SetStatePending(task1) 52 | backend.SetStateStarted(task2) 53 | groupCompleted, err = backend.GroupCompleted(groupUUID, 2) 54 | if assert.NoError(t, err) { 55 | assert.False(t, groupCompleted) 56 | } 57 | 58 | taskResults := []*tasks.TaskResult{new(tasks.TaskResult)} 59 | backend.SetStateStarted(task1) 60 | backend.SetStateSuccess(task2, taskResults) 61 | groupCompleted, err = backend.GroupCompleted(groupUUID, 2) 62 | if assert.NoError(t, err) { 63 | assert.False(t, groupCompleted) 64 | } 65 | 66 | backend.SetStateFailure(task1, "Some error") 
67 | groupCompleted, err = backend.GroupCompleted(groupUUID, 2) 68 | if assert.NoError(t, err) { 69 | assert.True(t, groupCompleted) 70 | } 71 | } 72 | 73 | func TestGetState(t *testing.T) { 74 | redisURL := os.Getenv("REDIS_URL") 75 | redisPassword := os.Getenv("REDIS_PASSWORD") 76 | if redisURL == "" { 77 | t.Skip("REDIS_URL is not defined") 78 | } 79 | 80 | signature := &tasks.Signature{ 81 | UUID: "testTaskUUID", 82 | GroupUUID: "testGroupUUID", 83 | } 84 | 85 | backend := redis.New(new(config.Config), redisURL, "", redisPassword, "", 0) 86 | 87 | backend.PurgeState("testTaskUUID") 88 | 89 | var ( 90 | taskState *tasks.TaskState 91 | err error 92 | ) 93 | 94 | taskState, err = backend.GetState(signature.UUID) 95 | assert.EqualError(t, err, "redigo: nil returned") 96 | assert.Nil(t, taskState) 97 | 98 | //Pending State 99 | backend.SetStatePending(signature) 100 | taskState, err = backend.GetState(signature.UUID) 101 | assert.NoError(t, err) 102 | assert.Equal(t, signature.Name, taskState.TaskName) 103 | createdAt := taskState.CreatedAt 104 | 105 | //Received State 106 | backend.SetStateReceived(signature) 107 | taskState, err = backend.GetState(signature.UUID) 108 | assert.NoError(t, err) 109 | assert.Equal(t, signature.Name, taskState.TaskName) 110 | assert.Equal(t, createdAt, taskState.CreatedAt) 111 | 112 | //Started State 113 | backend.SetStateStarted(signature) 114 | taskState, err = backend.GetState(signature.UUID) 115 | assert.NoError(t, err) 116 | assert.Equal(t, signature.Name, taskState.TaskName) 117 | assert.Equal(t, createdAt, taskState.CreatedAt) 118 | 119 | //Success State 120 | taskResults := []*tasks.TaskResult{ 121 | { 122 | Type: "float64", 123 | Value: 2, 124 | }, 125 | } 126 | backend.SetStateSuccess(signature, taskResults) 127 | taskState, err = backend.GetState(signature.UUID) 128 | assert.NoError(t, err) 129 | assert.Equal(t, signature.Name, taskState.TaskName) 130 | assert.Equal(t, createdAt, taskState.CreatedAt) 131 | assert.NotNil(t, 
taskState.Results) 132 | }
133 | 134 | func TestPurgeState(t *testing.T) { 135 | redisURL := os.Getenv("REDIS_URL") 136 | redisPassword := os.Getenv("REDIS_PASSWORD") 137 | if redisURL == "" { 138 | t.Skip("REDIS_URL is not defined") 139 | } 140 | 141 | signature := &tasks.Signature{ 142 | UUID: "testTaskUUID", 143 | GroupUUID: "testGroupUUID", 144 | } 145 | 146 | backend := redis.New(new(config.Config), redisURL, "", redisPassword, "", 0) 147 | 148 | backend.SetStatePending(signature) 149 | taskState, err := backend.GetState(signature.UUID) 150 | assert.NotNil(t, taskState) 151 | assert.NoError(t, err) 152 | 153 | backend.PurgeState(signature.UUID) 154 | taskState, err = backend.GetState(signature.UUID) 155 | assert.Nil(t, taskState) 156 | assert.Error(t, err) 157 | } 158 | -------------------------------------------------------------------------------- /v1/brokers/amqp/amqp_concurrence_test.go: -------------------------------------------------------------------------------- 1 | package amqp 2 | 3 | import ( 4 | "fmt" 5 | "github.com/RichardKnop/machinery/v1/brokers/iface" 6 | "github.com/RichardKnop/machinery/v1/config" 7 | "github.com/RichardKnop/machinery/v1/tasks" 8 | amqp "github.com/rabbitmq/amqp091-go" 9 | "testing" 10 | "time" 11 | ) 12 | 13 | type doNothingProcessor struct{} 14 | 15 | func (doNothingProcessor) Process(signature *tasks.Signature) error { 16 | return fmt.Errorf("failed") 17 | } 18 | 19 | func (doNothingProcessor) CustomQueue() string { 20 | return "oops" 21 | } 22 | 23 | func (doNothingProcessor) PreConsumeHandler() bool { 24 | return true 25 | } 26 | 27 | func TestConsume(t *testing.T) { 28 | var ( 29 | iBroker iface.Broker 30 | deliveries = make(chan amqp.Delivery, 3) 31 | closeChan chan *amqp.Error 32 | processor doNothingProcessor 33 | ) 34 | 35 | t.Run("with deliveries more than the number of concurrency", func(t *testing.T) { 36 | iBroker = New(&config.Config{}) 37 | broker, _ := iBroker.(*Broker) 38 | errChan := make(chan error, 1) 39 | 40 | // simulate that there are too much
deliveries 41 | go func() { 42 | for i := 0; i < 3; i++ { 43 | deliveries <- amqp.Delivery{} // broker.consumeOne() will complain this error: Received an empty message 44 | } 45 | }() 46 | 47 | go func() { 48 | err := broker.consume(deliveries, 2, processor, closeChan) 49 | if err != nil { 50 | errChan <- err 51 | } 52 | }() 53 | 54 | select { 55 | case <-errChan: 56 | case <-time.After(1 * time.Second): 57 | t.Error("Maybe deadlock") 58 | } 59 | }) 60 | } 61 | -------------------------------------------------------------------------------- /v1/brokers/amqp/amqp_test.go: -------------------------------------------------------------------------------- 1 | package amqp_test 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/RichardKnop/machinery/v1/brokers/amqp" 7 | "github.com/RichardKnop/machinery/v1/brokers/iface" 8 | "github.com/RichardKnop/machinery/v1/config" 9 | "github.com/RichardKnop/machinery/v1/tasks" 10 | "github.com/stretchr/testify/assert" 11 | ) 12 | 13 | func TestAdjustRoutingKey(t *testing.T) { 14 | t.Parallel() 15 | 16 | var ( 17 | s *tasks.Signature 18 | broker iface.Broker 19 | ) 20 | 21 | t.Run("with routing and binding keys", func(t *testing.T) { 22 | s = &tasks.Signature{RoutingKey: "routing_key"} 23 | broker = amqp.New(&config.Config{ 24 | DefaultQueue: "queue", 25 | AMQP: &config.AMQPConfig{ 26 | ExchangeType: "direct", 27 | BindingKey: "binding_key", 28 | }, 29 | }) 30 | broker.AdjustRoutingKey(s) 31 | assert.Equal(t, "routing_key", s.RoutingKey) 32 | }) 33 | 34 | t.Run("with binding key", func(t *testing.T) { 35 | s = new(tasks.Signature) 36 | broker = amqp.New(&config.Config{ 37 | DefaultQueue: "queue", 38 | AMQP: &config.AMQPConfig{ 39 | ExchangeType: "direct", 40 | BindingKey: "binding_key", 41 | }, 42 | }) 43 | broker.AdjustRoutingKey(s) 44 | assert.Equal(t, "binding_key", s.RoutingKey) 45 | }) 46 | } 47 | -------------------------------------------------------------------------------- /v1/brokers/eager/eager.go:
-------------------------------------------------------------------------------- 1 | package eager 2 | 3 | import ( 4 | "bytes" 5 | "context" 6 | "encoding/json" 7 | "errors" 8 | "fmt" 9 | 10 | "github.com/RichardKnop/machinery/v1/brokers/iface" 11 | "github.com/RichardKnop/machinery/v1/common" 12 | "github.com/RichardKnop/machinery/v1/tasks" 13 | ) 14 | 15 | // Broker represents an "eager" in-memory broker 16 | type Broker struct { 17 | worker iface.TaskProcessor 18 | common.Broker 19 | } 20 | 21 | // New creates new Broker instance 22 | func New() iface.Broker { 23 | return new(Broker) 24 | } 25 | 26 | // Mode interface with methods specific for this broker 27 | type Mode interface { 28 | AssignWorker(p iface.TaskProcessor) 29 | } 30 | 31 | // StartConsuming enters a loop and waits for incoming messages 32 | func (eagerBroker *Broker) StartConsuming(consumerTag string, concurrency int, p iface.TaskProcessor) (bool, error) { 33 | return true, nil 34 | } 35 | 36 | // StopConsuming quits the loop 37 | func (eagerBroker *Broker) StopConsuming() { 38 | // do nothing 39 | } 40 | 41 | // Publish places a new message on the default queue 42 | func (eagerBroker *Broker) Publish(ctx context.Context, task *tasks.Signature) error { 43 | if eagerBroker.worker == nil { 44 | return errors.New("worker is not assigned in eager-mode") 45 | } 46 | 47 | // faking the behavior to marshal input into json 48 | // and unmarshal it back 49 | message, err := json.Marshal(task) 50 | if err != nil { 51 | return fmt.Errorf("JSON marshal error: %s", err) 52 | } 53 | 54 | signature := new(tasks.Signature) 55 | decoder := json.NewDecoder(bytes.NewReader(message)) 56 | decoder.UseNumber() 57 | if err := decoder.Decode(signature); err != nil { 58 | return fmt.Errorf("JSON unmarshal error: %s", err) 59 | } 60 | 61 | // blocking call to the task directly 62 | return eagerBroker.worker.Process(signature) 63 | } 64 | 65 | // AssignWorker assigns a worker to the eager broker 66 | func (eagerBroker 
*Broker) AssignWorker(w iface.TaskProcessor) { 67 | eagerBroker.worker = w 68 | } 69 | -------------------------------------------------------------------------------- /v1/brokers/errs/errors.go: -------------------------------------------------------------------------------- 1 | package errs 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | ) 7 | 8 | // ErrCouldNotUnmarshalTaskSignature is returned when a message cannot be decoded into a task signature; it keeps the raw message and the decode error text 9 | type ErrCouldNotUnmarshalTaskSignature struct { 10 | msg []byte 11 | reason string 12 | } 13 | 14 | // Error implements the error interface 15 | func (e ErrCouldNotUnmarshalTaskSignature) Error() string { 16 | return fmt.Sprintf("Could not unmarshal '%s' into a task signature: %v", e.msg, e.reason) 17 | } 18 | 19 | // NewErrCouldNotUnmarshalTaskSignature returns new ErrCouldNotUnmarshalTaskSignature instance; err must be non-nil because its Error() text is stored as the reason 20 | func NewErrCouldNotUnmarshalTaskSignature(msg []byte, err error) ErrCouldNotUnmarshalTaskSignature { 21 | return ErrCouldNotUnmarshalTaskSignature{msg: msg, reason: err.Error()} 22 | } 23 | 24 | // ErrConsumerStopped indicates that the operation is now illegal because of the consumer being stopped.
25 | var ErrConsumerStopped = errors.New("the server has been stopped") 26 | 27 | // ErrStopTaskDeletion indicates that the task should not be deleted from source after task failure 28 | var ErrStopTaskDeletion = errors.New("task should not be deleted") 29 | -------------------------------------------------------------------------------- /v1/brokers/iface/interfaces.go: -------------------------------------------------------------------------------- 1 | package iface 2 | 3 | import ( 4 | "context" 5 | 6 | "github.com/RichardKnop/machinery/v1/config" 7 | "github.com/RichardKnop/machinery/v1/tasks" 8 | ) 9 | 10 | // Broker is the common interface implemented by all brokers 11 | type Broker interface { 12 | GetConfig() *config.Config 13 | SetRegisteredTaskNames(names []string) 14 | IsTaskRegistered(name string) bool 15 | StartConsuming(consumerTag string, concurrency int, p TaskProcessor) (bool, error) 16 | StopConsuming() 17 | Publish(ctx context.Context, task *tasks.Signature) error 18 | GetPendingTasks(queue string) ([]*tasks.Signature, error) 19 | GetDelayedTasks() ([]*tasks.Signature, error) 20 | AdjustRoutingKey(s *tasks.Signature) 21 | } 22 | 23 | // TaskProcessor processes a delivered task 24 | // This will probably always be a worker instance 25 | type TaskProcessor interface { 26 | Process(signature *tasks.Signature) error 27 | CustomQueue() string 28 | PreConsumeHandler() bool 29 | } 30 | -------------------------------------------------------------------------------- /v1/brokers/package.go: -------------------------------------------------------------------------------- 1 | package brokers 2 | -------------------------------------------------------------------------------- /v1/common/amqp.go: -------------------------------------------------------------------------------- 1 | package common 2 | 3 | import ( 4 | "crypto/tls" 5 | "fmt" 6 | "strings" 7 | 8 | amqp "github.com/rabbitmq/amqp091-go" 9 | ) 10 | 11 | // AMQPConnector groups the helpers used to open, configure, inspect and close RabbitMQ connections and channels
12 | type AMQPConnector struct{} 13 | 14 | // Connect opens a connection to RabbitMQ, declares an exchange, opens a channel, 15 | // declares and binds the queue and enables publish notifications 16 | func (ac *AMQPConnector) Connect(urls string, urlSeparator string, tlsConfig *tls.Config, exchange, exchangeType, queueName string, queueDurable, queueDelete bool, queueBindingKey string, exchangeDeclareArgs, queueDeclareArgs, queueBindingArgs amqp.Table) (*amqp.Connection, *amqp.Channel, amqp.Queue, <-chan amqp.Confirmation, <-chan *amqp.Error, error) { 17 | urlsList := []string{urls} 18 | if urlSeparator != "" { 19 | urlsList = strings.Split(urls, urlSeparator) 20 | } 21 | 22 | var conn *amqp.Connection 23 | var channel *amqp.Channel 24 | var err error 25 | 26 | for _, url := range urlsList { 27 | conn, channel, err = ac.Open(url, tlsConfig) 28 | if err == nil { 29 | // Connected: keep this connection and stop trying the remaining URLs 30 | break 31 | } 32 | // Otherwise fall through to the next URL; if every URL fails, err 33 | // holds the last failure and is returned below 34 | } 35 | 36 | if err != nil { 37 | return nil, nil, amqp.Queue{}, nil, nil, err 38 | } 39 | 40 | if exchange != "" { 41 | // Declare an exchange 42 | if err = channel.ExchangeDeclare( 43 | exchange, // name of the exchange 44 | exchangeType, // type 45 | true, // durable 46 | false, // delete when complete 47 | false, // internal 48 | false, // noWait 49 | exchangeDeclareArgs, // arguments 50 | ); err != nil { 51 | return conn, channel, amqp.Queue{}, nil, nil, fmt.Errorf("Exchange declare error: %s", err) 52 | } 53 | } 54 | 55 | var queue amqp.Queue 56 | if queueName != "" { 57 | // Declare a queue 58 | queue, err = channel.QueueDeclare( 59 | queueName, // name 60 | queueDurable, // durable 61 | queueDelete, // delete when unused 62 | false, // exclusive 63 | false, // no-wait 64 | queueDeclareArgs, // arguments 65 | ) 66 | if err != nil { 67 | return conn, channel, amqp.Queue{}, nil, nil, fmt.Errorf("Queue declare error: %s", err) 68 | } 69 | 70 | // Bind the queue 71 | if err = channel.QueueBind( 72 | queue.Name, // name of the
queue 73 | queueBindingKey, // binding key 74 | exchange, // source exchange 75 | false, // noWait 76 | queueBindingArgs, // arguments 77 | ); err != nil { 78 | return conn, channel, queue, nil, nil, fmt.Errorf("Queue bind error: %s", err) 79 | } 80 | } 81 | 82 | // Enable publish confirmations 83 | if err = channel.Confirm(false); err != nil { 84 | return conn, channel, queue, nil, nil, fmt.Errorf("Channel could not be put into confirm mode: %s", err) 85 | } 86 | 87 | return conn, channel, queue, channel.NotifyPublish(make(chan amqp.Confirmation, 1)), conn.NotifyClose(make(chan *amqp.Error, 1)), nil 88 | } 89 | 90 | // DeleteQueue deletes a queue by name 91 | func (ac *AMQPConnector) DeleteQueue(channel *amqp.Channel, queueName string) error { 92 | // First return value is number of messages removed 93 | _, err := channel.QueueDelete( 94 | queueName, // name 95 | false, // ifUnused 96 | false, // ifEmpty 97 | false, // noWait 98 | ) 99 | 100 | return err 101 | } 102 | 103 | // InspectQueue provides information about a specific queue 104 | func (ac *AMQPConnector) InspectQueue(channel *amqp.Channel, queueName string) (*amqp.Queue, error) { 105 | queueState, err := channel.QueueInspect(queueName) 106 | if err != nil { 107 | return nil, fmt.Errorf("Queue inspect error: %s", err) 108 | } 109 | 110 | return &queueState, nil 111 | } 112 | 113 | // Open dials a new RabbitMQ connection and opens a channel on it; the connection is closed again if the channel cannot be opened 114 | func (ac *AMQPConnector) Open(url string, tlsConfig *tls.Config) (*amqp.Connection, *amqp.Channel, error) { 115 | // Connect 116 | // From amqp docs: DialTLS will use the provided tls.Config when it encounters an amqps:// scheme 117 | // and will dial a plain connection when it encounters an amqp:// scheme.
118 | conn, err := amqp.DialTLS(url, tlsConfig) 119 | if err != nil { 120 | return nil, nil, fmt.Errorf("Dial error: %s", err) 121 | } 122 | 123 | // Open a channel 124 | channel, err := conn.Channel() 125 | if err != nil { 126 | _ = conn.Close() // best effort: do not leak the connection; the channel error is the one reported 127 | return nil, nil, fmt.Errorf("Open channel error: %s", err) 128 | } 129 | 130 | return conn, channel, nil 131 | } 132 | 133 | // Close closes the channel and then the connection, skipping nil arguments and stopping at the first error 134 | func (ac *AMQPConnector) Close(channel *amqp.Channel, conn *amqp.Connection) error { 135 | if channel != nil { 136 | if err := channel.Close(); err != nil { 137 | return fmt.Errorf("Close channel error: %s", err) 138 | } 139 | } 140 | 141 | if conn != nil { 142 | if err := conn.Close(); err != nil { 143 | return fmt.Errorf("Close connection error: %s", err) 144 | } 145 | } 146 | 147 | return nil 148 | } 149 | -------------------------------------------------------------------------------- /v1/common/backend.go: -------------------------------------------------------------------------------- 1 | package common 2 | 3 | import ( 4 | "github.com/RichardKnop/machinery/v1/config" 5 | ) 6 | 7 | // Backend represents a base backend structure 8 | type Backend struct { 9 | cnf *config.Config 10 | } 11 | 12 | // NewBackend creates new Backend instance 13 | func NewBackend(cnf *config.Config) Backend { 14 | return Backend{cnf: cnf} 15 | } 16 | 17 | // GetConfig returns config 18 | func (b *Backend) GetConfig() *config.Config { 19 | return b.cnf 20 | } 21 | 22 | // IsAMQP reports whether this is an AMQP backend; the base Backend always returns false
23 | func (b *Backend) IsAMQP() bool { 24 | return false 25 | } 26 | -------------------------------------------------------------------------------- /v1/common/broker.go: -------------------------------------------------------------------------------- 1 | package common 2 | 3 | import ( 4 | "errors" 5 | "sync" 6 | 7 | "github.com/RichardKnop/machinery/v1/brokers/iface" 8 | "github.com/RichardKnop/machinery/v1/config" 9 | "github.com/RichardKnop/machinery/v1/log" 10 | "github.com/RichardKnop/machinery/v1/retry" 11 | "github.com/RichardKnop/machinery/v1/tasks" 12 | ) 13 | 14 | type registeredTaskNames struct { 15 | sync.RWMutex 16 | items []string 17 | } 18 | 19 | // Broker represents a base broker structure 20 | type Broker struct { 21 | cnf *config.Config 22 | registeredTaskNames registeredTaskNames 23 | retry bool 24 | retryFunc func(chan int) 25 | retryStopChan chan int 26 | stopChan chan int 27 | } 28 | 29 | // NewBroker creates new Broker instance 30 | func NewBroker(cnf *config.Config) Broker { 31 | return Broker{ 32 | cnf: cnf, 33 | retry: true, 34 | stopChan: make(chan int), 35 | retryStopChan: make(chan int), 36 | } 37 | } 38 | 39 | // GetConfig returns config 40 | func (b *Broker) GetConfig() *config.Config { 41 | return b.cnf 42 | } 43 | 44 | // GetRetry reports whether consuming should still be retried; NewBroker sets it to true and StopConsuming sets it to false 45 | func (b *Broker) GetRetry() bool { 46 | return b.retry 47 | } 48 | 49 | // GetRetryFunc returns the retry closure; StartConsuming installs the default closure when none is set 50 | func (b *Broker) GetRetryFunc() func(chan int) { 51 | return b.retryFunc 52 | } 53 | 54 | // GetRetryStopChan returns the channel StopConsuming uses to cut a pending retry closure short 55 | func (b *Broker) GetRetryStopChan() chan int { 56 | return b.retryStopChan 57 | } 58 | 59 | // GetStopChan returns the channel that StopConsuming closes to tell consumers to stop
60 | func (b *Broker) GetStopChan() chan int { 61 | return b.stopChan 62 | } 63 | 64 | // Publish is not implemented by the base broker and always returns an error 65 | func (b *Broker) Publish(signature *tasks.Signature) error { 66 | return errors.New("Not implemented") 67 | } 68 | 69 | // SetRegisteredTaskNames stores a copy of names, so later changes to the caller's slice do not affect the broker 70 | func (b *Broker) SetRegisteredTaskNames(names []string) { 71 | b.registeredTaskNames.Lock() 72 | defer b.registeredTaskNames.Unlock() 73 | b.registeredTaskNames.items = append(names[:0:0], names...) 74 | } 75 | 76 | // IsTaskRegistered returns true if the task is registered with this broker 77 | func (b *Broker) IsTaskRegistered(name string) bool { 78 | b.registeredTaskNames.RLock() 79 | defer b.registeredTaskNames.RUnlock() 80 | for _, registeredTaskName := range b.registeredTaskNames.items { 81 | if registeredTaskName == name { 82 | return true 83 | } 84 | } 85 | return false 86 | } 87 | 88 | // GetPendingTasks returns a slice of task.Signatures waiting in the queue (not implemented by the base broker) 89 | func (b *Broker) GetPendingTasks(queue string) ([]*tasks.Signature, error) { 90 | return nil, errors.New("Not implemented") 91 | } 92 | 93 | // GetDelayedTasks returns a slice of task.Signatures that are scheduled, but not yet in the queue (not implemented by the base broker) 94 | func (b *Broker) GetDelayedTasks() ([]*tasks.Signature, error) { 95 | return nil, errors.New("Not implemented") 96 | } 97 | 98 | // StartConsuming is the part of StartConsuming shared by all brokers: it installs the default retry closure if none is set (the arguments are currently unused) 99 | func (b *Broker) StartConsuming(consumerTag string, concurrency int, taskProcessor iface.TaskProcessor) { 100 | if b.retryFunc == nil { 101 | b.retryFunc = retry.Closure() 102 | } 103 | 104 | } 105 | 106 | // StopConsuming is the part of StopConsuming shared by all brokers; it closes the stop channel, so it must be called at most once 107 | func (b *Broker) StopConsuming() { 108 | // Do not retry from now on 109 | b.retry = false 110 | // Stop the retry closure earlier 111 | select { 112 | case b.retryStopChan <- 1: 113 | log.WARNING.Print("Stopping retry closure.") 114 | default: 115 | } 116 | // Notifying the stop channel stops consuming of messages
117 | close(b.stopChan) 118 | log.WARNING.Print("Stop channel") 119 | } 120 | 121 | // GetRegisteredTaskNames returns a copy of the registered task names 122 | func (b *Broker) GetRegisteredTaskNames() []string { 123 | b.registeredTaskNames.RLock() 124 | defer b.registeredTaskNames.RUnlock() 125 | items := b.registeredTaskNames.items 126 | return append(items[:0:0], items...) // copy (nil stays nil) so callers cannot mutate the slice guarded by the lock 127 | } 128 | 129 | // AdjustRoutingKey makes sure the routing key is set. 130 | // This base implementation only replaces an empty routing key with the 131 | // configured default queue name; brokers that need more (e.g. the AMQP 132 | // broker, which uses the binding key for direct exchanges) override it. 133 | func (b *Broker) AdjustRoutingKey(s *tasks.Signature) { 134 | if s.RoutingKey != "" { 135 | return 136 | } 137 | 138 | s.RoutingKey = b.GetConfig().DefaultQueue 139 | } 140 | -------------------------------------------------------------------------------- /v1/common/broker_test.go: -------------------------------------------------------------------------------- 1 | package common_test 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/RichardKnop/machinery/v1" 7 | "github.com/RichardKnop/machinery/v1/common" 8 | "github.com/RichardKnop/machinery/v1/config" 9 | "github.com/RichardKnop/machinery/v1/tasks" 10 | "github.com/stretchr/testify/assert" 11 | ) 12 | 13 | func TestIsTaskRegistered(t *testing.T) { 14 | t.Parallel() 15 | 16 | broker := common.NewBroker(new(config.Config)) 17 | broker.SetRegisteredTaskNames([]string{"foo", "bar"}) 18 | 19 | assert.True(t, broker.IsTaskRegistered("foo")) 20 | assert.False(t, broker.IsTaskRegistered("bogus")) 21 | } 22 | 23 | func TestAdjustRoutingKey(t *testing.T) { 24 | t.Parallel() 25 | 26 | var ( 27 | s *tasks.Signature 28 | broker common.Broker 29 | ) 30 | 31 | t.Run("with routing key", func(t *testing.T) { 32 | s = &tasks.Signature{RoutingKey: "routing_key"} 33 | broker = common.NewBroker(&config.Config{ 34 | DefaultQueue: "queue", 35 | }) 36 | broker.AdjustRoutingKey(s) 37 | assert.Equal(t, "routing_key", s.RoutingKey) 38 | }) 39 | 40 | t.Run("without routing
key", func(t *testing.T) { 41 | s = new(tasks.Signature) 42 | broker = common.NewBroker(&config.Config{ 43 | DefaultQueue: "queue", 44 | }) 45 | broker.AdjustRoutingKey(s) 46 | assert.Equal(t, "queue", s.RoutingKey) 47 | }) 48 | } 49 | 50 | func TestGetRegisteredTaskNames(t *testing.T) { 51 | t.Parallel() 52 | 53 | broker := common.NewBroker(new(config.Config)) 54 | fooTasks := []string{"foo", "bar", "baz"} 55 | broker.SetRegisteredTaskNames(fooTasks) 56 | assert.Equal(t, fooTasks, broker.GetRegisteredTaskNames()) 57 | } 58 | 59 | func TestStopConsuming(t *testing.T) { 60 | t.Parallel() 61 | 62 | t.Run("stop consuming", func(t *testing.T) { 63 | broker := common.NewBroker(&config.Config{ 64 | DefaultQueue: "queue", 65 | }) 66 | broker.StartConsuming("", 1, &machinery.Worker{}) 67 | broker.StopConsuming() 68 | select { 69 | case <-broker.GetStopChan(): 70 | default: 71 | assert.Fail(t, "still blocking") 72 | } 73 | }) 74 | } 75 | -------------------------------------------------------------------------------- /v1/common/redis.go: -------------------------------------------------------------------------------- 1 | package common 2 | 3 | import ( 4 | "crypto/tls" 5 | "time" 6 | 7 | "github.com/gomodule/redigo/redis" 8 | 9 | "github.com/RichardKnop/machinery/v1/config" 10 | ) 11 | 12 | var ( 13 | defaultConfig = &config.RedisConfig{ 14 | MaxIdle: 10, 15 | MaxActive: 100, 16 | IdleTimeout: 300, 17 | Wait: true, 18 | ReadTimeout: 15, 19 | WriteTimeout: 15, 20 | ConnectTimeout: 15, 21 | NormalTasksPollPeriod: 1000, 22 | DelayedTasksPollPeriod: 20, 23 | } 24 | ) 25 | 26 | // RedisConnector builds pools of Redis (redigo) connections
27 | type RedisConnector struct{} 28 | 29 | // NewPool returns a new pool of Redis connections; a nil cnf falls back to defaultConfig 30 | func (rc *RedisConnector) NewPool(socketPath, host, username, password string, db int, cnf *config.RedisConfig, tlsConfig *tls.Config) *redis.Pool { 31 | if cnf == nil { 32 | cnf = defaultConfig 33 | } 34 | return &redis.Pool{ 35 | MaxIdle: cnf.MaxIdle, 36 | IdleTimeout: time.Duration(cnf.IdleTimeout) * time.Second, 37 | MaxActive: cnf.MaxActive, 38 | Wait: cnf.Wait, 39 | // open passes redis.DialDatabase(db), so redigo selects the database while dialing; no extra SELECT is needed here 40 | Dial: func() (redis.Conn, error) { 41 | return rc.open(socketPath, host, username, password, db, cnf, tlsConfig) 42 | }, 43 | // PINGs connections that have been idle more than 10 seconds 44 | TestOnBorrow: func(c redis.Conn, t time.Time) error { 45 | if time.Since(t) < 10*time.Second { 46 | return nil 47 | } 48 | _, err := c.Do("PING") 49 | return err 50 | }, 51 | } 52 | } 53 | 54 | // open dials a new Redis connection, over the unix socket when socketPath is set and over TCP to host otherwise 55 | func (rc *RedisConnector) open(socketPath, host, username string, password string, db int, cnf *config.RedisConfig, tlsConfig *tls.Config) (redis.Conn, error) { 56 | var opts = []redis.DialOption{ 57 | redis.DialDatabase(db), 58 | redis.DialReadTimeout(time.Duration(cnf.ReadTimeout) * time.Second), 59 | redis.DialWriteTimeout(time.Duration(cnf.WriteTimeout) * time.Second), 60 | redis.DialConnectTimeout(time.Duration(cnf.ConnectTimeout) * time.Second), 61 | redis.DialClientName(cnf.ClientName), 62 | } 63 | 64 | if tlsConfig != nil { 65 | opts = append(opts, redis.DialTLSConfig(tlsConfig), redis.DialUseTLS(true)) 66 | } 67 | 68 | if username != "" { 69 | opts = append(opts, redis.DialUsername(username)) 70 | } 71 | 72 | if password != "" { 73 | opts = append(opts, redis.DialPassword(password)) 74 | } 75 | 76 | if socketPath != "" { 77 | return redis.Dial("unix", socketPath, opts...)
78 | } 79 | 80 | return redis.Dial("tcp", host, opts...) 81 | } 82 | -------------------------------------------------------------------------------- /v1/config/env.go: -------------------------------------------------------------------------------- 1 | package config 2 | 3 | import ( 4 | "github.com/kelseyhightower/envconfig" 5 | 6 | "github.com/RichardKnop/machinery/v1/log" 7 | ) 8 | 9 | // NewFromEnvironment creates a config object from environment variables 10 | func NewFromEnvironment() (*Config, error) { 11 | cnf, err := fromEnvironment() 12 | if err != nil { 13 | return nil, err 14 | } 15 | 16 | log.INFO.Print("Successfully loaded config from the environment") 17 | 18 | return cnf, nil 19 | } 20 | 21 | func fromEnvironment() (*Config, error) { 22 | loadedCnf, cnf := new(Config), new(Config) 23 | *cnf = *defaultCnf 24 | 25 | if err := envconfig.Process("", cnf); err != nil { 26 | return nil, err 27 | } 28 | if err := envconfig.Process("", loadedCnf); err != nil { 29 | return nil, err 30 | } 31 | 32 | if loadedCnf.AMQP == nil { 33 | cnf.AMQP = nil 34 | } 35 | 36 | return cnf, nil 37 | } 38 | -------------------------------------------------------------------------------- /v1/config/env_test.go: -------------------------------------------------------------------------------- 1 | package config_test 2 | 3 | import ( 4 | "bufio" 5 | "os" 6 | "strings" 7 | "testing" 8 | 9 | "github.com/RichardKnop/machinery/v1/config" 10 | "github.com/stretchr/testify/assert" 11 | ) 12 | 13 | func TestNewFromEnvironment(t *testing.T) { 14 | t.Parallel() 15 | 16 | file, err := os.Open("test.env") 17 | if err != nil { 18 | t.Fatal(err) 19 | } 20 | reader := bufio.NewReader(file) 21 | scanner := bufio.NewScanner(reader) 22 | scanner.Split(bufio.ScanLines) 23 | for scanner.Scan() { 24 | parts := strings.Split(scanner.Text(), "=") 25 | if len(parts) != 2 { 26 | continue 27 | } 28 | os.Setenv(parts[0], parts[1]) 29 | } 30 | 31 | cnf, err := config.NewFromEnvironment() 32 | if err != nil
{ 33 | t.Fatal(err) 34 | } 35 | 36 | assert.Equal(t, "broker", cnf.Broker) 37 | assert.Equal(t, "default_queue", cnf.DefaultQueue) 38 | assert.Equal(t, "result_backend", cnf.ResultBackend) 39 | assert.Equal(t, 123456, cnf.ResultsExpireIn) 40 | assert.Equal(t, "exchange", cnf.AMQP.Exchange) 41 | assert.Equal(t, "exchange_type", cnf.AMQP.ExchangeType) 42 | assert.Equal(t, "binding_key", cnf.AMQP.BindingKey) 43 | assert.Equal(t, "any", cnf.AMQP.QueueBindingArgs["x-match"]) 44 | assert.Equal(t, "png", cnf.AMQP.QueueBindingArgs["image-type"]) 45 | assert.Equal(t, 123, cnf.AMQP.PrefetchCount) 46 | } 47 | -------------------------------------------------------------------------------- /v1/config/file.go: -------------------------------------------------------------------------------- 1 | package config 2 | 3 | import ( 4 | "fmt" 5 | "io" 6 | "os" 7 | "time" 8 | 9 | "github.com/RichardKnop/machinery/v1/log" 10 | "gopkg.in/yaml.v2" 11 | ) 12 | 13 | // NewFromYaml creates a config object from YAML file 14 | func NewFromYaml(cnfPath string, keepReloading bool) (*Config, error) { 15 | cnf, err := fromFile(cnfPath) 16 | if err != nil { 17 | return nil, err 18 | } 19 | 20 | log.INFO.Printf("Successfully loaded config from file %s", cnfPath) 21 | 22 | if keepReloading { 23 | // Reload in the background forever. NOTE(review): *cnf is overwritten below without any locking, so goroutines reading cnf concurrently race with the reload 24 | go func() { 25 | for { 26 | // Delay after each request 27 | time.Sleep(reloadDelay) 28 | 29 | // Attempt to reload the config 30 | newCnf, newErr := fromFile(cnfPath) 31 | if newErr != nil { 32 | log.WARNING.Printf("Failed to reload config from file %s: %v", cnfPath, newErr) 33 | continue 34 | } 35 | 36 | *cnf = *newCnf 37 | } 38 | }() 39 | } 40 | 41 | return cnf, nil 42 | } 43 | 44 | // ReadFromFile reads the whole file at cnfPath and returns its contents 45 | func ReadFromFile(cnfPath string) ([]byte, error) { 46 | file, err := os.Open(cnfPath) 47 | 48 | // Config file not found 49 | if err != nil { 50 | return nil, fmt.Errorf("Open file error: %s", err) 51 | } 52 | defer file.Close() 53 | 54
| // Read until EOF; a single Read into a fixed-size buffer could return a truncated config 55 | data, err := io.ReadAll(file) 56 | if err != nil { 57 | return nil, fmt.Errorf("Read from file error: %s", err) 58 | } 59 | 60 | return data, nil 61 | } 62 | 63 | func fromFile(cnfPath string) (*Config, error) { 64 | loadedCnf, cnf := new(Config), new(Config) 65 | *cnf = *defaultCnf 66 | 67 | data, err := ReadFromFile(cnfPath) 68 | if err != nil { 69 | return nil, err 70 | } 71 | 72 | if err := yaml.Unmarshal(data, cnf); err != nil { 73 | return nil, fmt.Errorf("Unmarshal YAML error: %s", err) 74 | } 75 | if err := yaml.Unmarshal(data, loadedCnf); err != nil { 76 | return nil, fmt.Errorf("Unmarshal YAML error: %s", err) 77 | } 78 | if loadedCnf.AMQP == nil { 79 | cnf.AMQP = nil 80 | } 81 | 82 | return cnf, nil 83 | } 84 | -------------------------------------------------------------------------------- /v1/config/file_test.go: -------------------------------------------------------------------------------- 1 | package config_test 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/RichardKnop/machinery/v1/config" 7 | "github.com/stretchr/testify/assert" 8 | ) 9 | 10 | var configYAMLData = `--- 11 | broker: broker 12 | default_queue: default_queue 13 | result_backend: result_backend 14 | results_expire_in: 123456 15 | amqp: 16 | binding_key: binding_key 17 | exchange: exchange 18 | exchange_type: exchange_type 19 | prefetch_count: 123 20 | queue_declare_args: 21 | x-max-priority: 10 22 | queue_binding_args: 23 | image-type: png 24 | x-match: any 25 | sqs: 26 | receive_wait_time_seconds: 123 27 | receive_visibility_timeout: 456 28 | redis: 29 | max_idle: 12 30 | max_active: 123 31 | max_idle_timeout: 456 32 | wait: false 33 | read_timeout: 17 34 | write_timeout: 19 35 | connect_timeout: 21 36 | normal_tasks_poll_period: 1001 37 | delayed_tasks_poll_period: 23 38 | delayed_tasks_key: delayed_tasks_key 39 | master_name: master_name 40 | no_unix_signals: true 41 | dynamodb: 42 |
task_states_table: task_states_table 43 | group_metas_table: group_metas_table 44 | ` 45 | 46 | func TestReadFromFile(t *testing.T) { 47 | t.Parallel() 48 | 49 | data, err := config.ReadFromFile("testconfig.yml") 50 | if err != nil { 51 | t.Fatal(err) 52 | } 53 | 54 | assert.Equal(t, configYAMLData, string(data)) 55 | } 56 | 57 | func TestNewFromYaml(t *testing.T) { 58 | t.Parallel() 59 | 60 | cnf, err := config.NewFromYaml("testconfig.yml", false) 61 | if err != nil { 62 | t.Fatal(err) 63 | } 64 | 65 | assert.Equal(t, "broker", cnf.Broker) 66 | assert.Equal(t, "default_queue", cnf.DefaultQueue) 67 | assert.Equal(t, "result_backend", cnf.ResultBackend) 68 | assert.Equal(t, 123456, cnf.ResultsExpireIn) 69 | 70 | assert.Equal(t, "exchange", cnf.AMQP.Exchange) 71 | assert.Equal(t, "exchange_type", cnf.AMQP.ExchangeType) 72 | assert.Equal(t, "binding_key", cnf.AMQP.BindingKey) 73 | assert.Equal(t, 10, cnf.AMQP.QueueDeclareArgs["x-max-priority"]) 74 | assert.Equal(t, "any", cnf.AMQP.QueueBindingArgs["x-match"]) 75 | assert.Equal(t, "png", cnf.AMQP.QueueBindingArgs["image-type"]) 76 | assert.Equal(t, 123, cnf.AMQP.PrefetchCount) 77 | 78 | assert.Equal(t, 123, cnf.SQS.WaitTimeSeconds) 79 | assert.Equal(t, 456, *cnf.SQS.VisibilityTimeout) 80 | 81 | assert.Equal(t, 12, cnf.Redis.MaxIdle) 82 | assert.Equal(t, 123, cnf.Redis.MaxActive) 83 | assert.Equal(t, 456, cnf.Redis.IdleTimeout) 84 | assert.Equal(t, false, cnf.Redis.Wait) 85 | assert.Equal(t, 17, cnf.Redis.ReadTimeout) 86 | assert.Equal(t, 19, cnf.Redis.WriteTimeout) 87 | assert.Equal(t, 21, cnf.Redis.ConnectTimeout) 88 | assert.Equal(t, 1001, cnf.Redis.NormalTasksPollPeriod) 89 | assert.Equal(t, 23, cnf.Redis.DelayedTasksPollPeriod) 90 | assert.Equal(t, "delayed_tasks_key", cnf.Redis.DelayedTasksKey) 91 | assert.Equal(t, "master_name", cnf.Redis.MasterName) 92 | 93 | assert.Equal(t, true, cnf.NoUnixSignals) 94 | 95 | assert.Equal(t, "task_states_table", cnf.DynamoDB.TaskStatesTable) 96 | assert.Equal(t,
"group_metas_table", cnf.DynamoDB.GroupMetasTable) 97 | } 98 | -------------------------------------------------------------------------------- /v1/config/test.env: -------------------------------------------------------------------------------- 1 | BROKER=broker 2 | DEFAULT_QUEUE=default_queue 3 | RESULT_BACKEND=result_backend 4 | RESULTS_EXPIRE_IN=123456 5 | AMQP_BINDING_KEY=binding_key 6 | AMQP_EXCHANGE=exchange 7 | AMQP_EXCHANGE_TYPE=exchange_type 8 | AMQP_PREFETCH_COUNT=123 9 | AMQP_QUEUE_BINDING_ARGS=image-type:png,x-match:any 10 | -------------------------------------------------------------------------------- /v1/config/testconfig.yml: -------------------------------------------------------------------------------- 1 | --- 2 | broker: broker 3 | default_queue: default_queue 4 | result_backend: result_backend 5 | results_expire_in: 123456 6 | amqp: 7 | binding_key: binding_key 8 | exchange: exchange 9 | exchange_type: exchange_type 10 | prefetch_count: 123 11 | queue_declare_args: 12 | x-max-priority: 10 13 | queue_binding_args: 14 | image-type: png 15 | x-match: any 16 | sqs: 17 | receive_wait_time_seconds: 123 18 | receive_visibility_timeout: 456 19 | redis: 20 | max_idle: 12 21 | max_active: 123 22 | max_idle_timeout: 456 23 | wait: false 24 | read_timeout: 17 25 | write_timeout: 19 26 | connect_timeout: 21 27 | normal_tasks_poll_period: 1001 28 | delayed_tasks_poll_period: 23 29 | delayed_tasks_key: delayed_tasks_key 30 | master_name: master_name 31 | no_unix_signals: true 32 | dynamodb: 33 | task_states_table: task_states_table 34 | group_metas_table: group_metas_table 35 | -------------------------------------------------------------------------------- /v1/locks/eager/eager.go: -------------------------------------------------------------------------------- 1 | package eager 2 | 3 | import ( 4 | "errors" 5 | "sync" 6 | "time" 7 | ) 8 | 9 | var ( 10 | // ErrEagerLockFailed is returned when the key is still held by a lock whose expiry time has not passed 11 | ErrEagerLockFailed = errors.New("eager lock: failed to acquire lock") 12 | ) 13 | 14 | // Lock is an in-memory lock registry mapping each key to the UnixNano timestamp at which its lock expires 15 | type Lock struct
{ 16 | retries int 17 | interval time.Duration 18 | register struct { 19 | sync.RWMutex 20 | m map[string]int64 21 | } 22 | } 23 | 24 | // New returns a Lock whose LockWithRetries makes up to 4 attempts, 5 seconds apart 25 | func New() *Lock { 26 | return &Lock{ 27 | retries: 3, 28 | interval: 5 * time.Second, 29 | register: struct { 30 | sync.RWMutex 31 | m map[string]int64 32 | }{ 33 | m: make(map[string]int64), 34 | }, 35 | } 36 | } 37 | 38 | // LockWithRetries calls Lock up to retries+1 times, sleeping interval between attempts but not after the last one 39 | func (e *Lock) LockWithRetries(key string, value int64) error { 40 | for i := 0; i <= e.retries; i++ { 41 | err := e.Lock(key, value) 42 | if err == nil { 43 | // Lock acquired successfully, return 44 | return nil 45 | } 46 | 47 | if i < e.retries { 48 | time.Sleep(e.interval) 49 | } 50 | } 51 | return ErrEagerLockFailed 52 | } 53 | 54 | // Lock makes a single attempt: it succeeds when key is unknown or its stored expiry is in the past, and then stores value as the new expiry 55 | func (e *Lock) Lock(key string, value int64) error { 56 | e.register.Lock() 57 | defer e.register.Unlock() 58 | timeout, exist := e.register.m[key] 59 | if !exist || time.Now().UnixNano() > timeout { 60 | e.register.m[key] = value 61 | return nil 62 | } 63 | return ErrEagerLockFailed 64 | } 65 | -------------------------------------------------------------------------------- /v1/locks/eager/eager_test.go: -------------------------------------------------------------------------------- 1 | package eager 2 | 3 | import ( 4 | lockiface "github.com/RichardKnop/machinery/v1/locks/iface" 5 | "github.com/RichardKnop/machinery/v1/utils" 6 | "github.com/stretchr/testify/assert" 7 | "testing" 8 | "time" 9 | ) 10 | 11 | func TestLock_Lock(t *testing.T) { 12 | lock := New() 13 | keyName := utils.GetPureUUID() 14 | 15 | go func() { 16 | err := lock.Lock(keyName, time.Now().Add(25*time.Second).UnixNano()) 17 | assert.NoError(t, err) 18 | }() 19 | time.Sleep(1 * time.Second) 20 | err := lock.Lock(keyName, time.Now().Add(25*time.Second).UnixNano()) 21 | assert.Error(t, err) 22 | assert.EqualError(t, err, ErrEagerLockFailed.Error()) 23 | } 24 | 25 | func TestLock_LockWithRetries(t *testing.T) { 26 | lock := New() 27 | keyName := utils.GetPureUUID() 28 | 29 | go func() { 30 | err := lock.LockWithRetries(keyName,
time.Now().Add(25*time.Second).UnixNano()) 31 | assert.NoError(t, err) 32 | }() 33 | time.Sleep(1 * time.Second) 34 | err := lock.LockWithRetries(keyName, time.Now().Add(25*time.Second).UnixNano()) 35 | assert.Error(t, err) 36 | assert.EqualError(t, err, ErrEagerLockFailed.Error()) 37 | } 38 | 39 | func TestNew(t *testing.T) { 40 | lock := New() 41 | assert.Implements(t, (*lockiface.Lock)(nil), lock) 42 | } 43 | -------------------------------------------------------------------------------- /v1/locks/iface/interfaces.go: -------------------------------------------------------------------------------- 1 | package iface 2 | 3 | // Lock is the interface shared by the lock implementations (e.g. the eager in-memory lock) 4 | type Lock interface { 5 | // Acquire the lock, retrying on failure 6 | // key: the name of the lock, 7 | // value: the nanosecond (UnixNano) timestamp at which the lock is released automatically 8 | LockWithRetries(key string, value int64) error 9 | 10 | // Acquire the lock with a single attempt 11 | // key: the name of the lock, 12 | // value: the nanosecond (UnixNano) timestamp at which the lock is released automatically 13 | Lock(key string, value int64) error 14 | } 15 | -------------------------------------------------------------------------------- /v1/locks/redis/redis.go: -------------------------------------------------------------------------------- 1 | package redis 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "strconv" 7 | "strings" 8 | "time" 9 | 10 | "github.com/RichardKnop/machinery/v1/config" 11 | "github.com/redis/go-redis/v9" 12 | ) 13 | 14 | var ( 15 | // ErrRedisLockFailed is returned when the lock key is held by an owner whose expiry has not passed yet 16 | ErrRedisLockFailed = errors.New("redis lock: failed to acquire lock") 17 | ) 18 | 19 | // Lock is a Redis-backed lock: each key stores the UnixNano timestamp at which it expires 20 | type Lock struct { 21 | rclient redis.UniversalClient 22 | retries int 23 | interval time.Duration 24 | } 25 | 26 | // New builds a Lock for the given Redis addresses; a "password@" prefix on addrs[0] is split off and addrs[0] is rewritten in place. 27 | // NOTE(review): when retries <= 0 the zero Lock (nil client) is returned, and interval is never set, so retries are not spaced out. 28 | func New(cnf *config.Config, addrs []string, db, retries int) Lock { 29 | if retries <= 0 { 30 | return Lock{} 31 | } 32 | lock := Lock{retries: retries} 33 | 34 | var password string 35 | 36 | parts := strings.Split(addrs[0], "@") 37 | if len(parts) >= 2 { 38 | password = strings.Join(parts[:len(parts)-1], "@") 39
| addrs[0] = parts[len(parts)-1] // addr is the last one without @ 40 | } 41 | 42 | ropt := &redis.UniversalOptions{ 43 | Addrs: addrs, 44 | DB: db, 45 | Password: password, 46 | } 47 | if cnf.Redis != nil { 48 | ropt.MasterName = cnf.Redis.MasterName 49 | } 50 | 51 | lock.rclient = redis.NewUniversalClient(ropt) 52 | 53 | return lock 54 | } 55 | 56 | // LockWithRetries calls Lock up to retries+1 times, sleeping interval between attempts but not after the last one 57 | func (r Lock) LockWithRetries(key string, unixTsToExpireNs int64) error { 58 | for i := 0; i <= r.retries; i++ { 59 | err := r.Lock(key, unixTsToExpireNs) 60 | if err == nil { 61 | // Lock acquired successfully, return 62 | return nil 63 | } 64 | 65 | if i < r.retries { 66 | time.Sleep(r.interval) 67 | } 68 | } 69 | return ErrRedisLockFailed 70 | } 71 | 72 | // Lock tries once to acquire key, storing unixTsToExpireNs as its value and deriving the key's TTL from it; a key whose stored expiry is already in the past is taken over with GETSET 73 | func (r Lock) Lock(key string, unixTsToExpireNs int64) error { 74 | now := time.Now().UnixNano() 75 | expiration := time.Duration(unixTsToExpireNs + 1 - now) 76 | // No caller context is available here, so use a background context 77 | ctx := context.Background() 78 | 79 | success, err := r.rclient.SetNX(ctx, key, unixTsToExpireNs, expiration).Result() 80 | if err != nil { 81 | return err 82 | } 83 | 84 | if !success { 85 | v, err := r.rclient.Get(ctx, key).Result() 86 | if err != nil { 87 | return err 88 | } 89 | // Parse as int64: the value is a nanosecond timestamp, which overflows a 32-bit int 90 | timeout, err := strconv.ParseInt(v, 10, 64) 91 | if err != nil { 92 | return err 93 | } 94 | 95 | if timeout != 0 && now > timeout { 96 | newTimeout, err := r.rclient.GetSet(ctx, key, unixTsToExpireNs).Result() 97 | if err != nil { 98 | return err 99 | } 100 | 101 | curTimeout, err := strconv.ParseInt(newTimeout, 10, 64) 102 | if err != nil { 103 | return err 104 | } 105 | 106 | if now > curTimeout { 107 | // success to acquire lock with get set 108 | // set the expiration of redis key 109 | r.rclient.Expire(ctx, key, expiration) 110 | return nil 111 | } 112 | 113 | return ErrRedisLockFailed 114 | } 115 | 116 | return ErrRedisLockFailed 117 | } 118 | 119 | return nil 120 | } 121 | -------------------------------------------------------------------------------- /v1/log/log.go: -------------------------------------------------------------------------------- 1 | package
log 2 | 3 | import ( 4 | "github.com/RichardKnop/logging" 5 | ) 6 | 7 | var ( 8 | logger = logging.New(nil, nil, new(logging.ColouredFormatter)) 9 | 10 | // DEBUG ... 11 | DEBUG = logger[logging.DEBUG] 12 | // INFO ... 13 | INFO = logger[logging.INFO] 14 | // WARNING ... 15 | WARNING = logger[logging.WARNING] 16 | // ERROR ... 17 | ERROR = logger[logging.ERROR] 18 | // FATAL ... 19 | FATAL = logger[logging.FATAL] 20 | ) 21 | 22 | // Set sets a custom logger for all log levels 23 | func Set(l logging.LoggerInterface) { 24 | DEBUG = l 25 | INFO = l 26 | WARNING = l 27 | ERROR = l 28 | FATAL = l 29 | } 30 | 31 | // SetDebug sets a custom logger for DEBUG level logs 32 | func SetDebug(l logging.LoggerInterface) { 33 | DEBUG = l 34 | } 35 | 36 | // SetInfo sets a custom logger for INFO level logs 37 | func SetInfo(l logging.LoggerInterface) { 38 | INFO = l 39 | } 40 | 41 | // SetWarning sets a custom logger for WARNING level logs 42 | func SetWarning(l logging.LoggerInterface) { 43 | WARNING = l 44 | } 45 | 46 | // SetError sets a custom logger for ERROR level logs 47 | func SetError(l logging.LoggerInterface) { 48 | ERROR = l 49 | } 50 | 51 | // SetFatal sets a custom logger for FATAL level logs 52 | func SetFatal(l logging.LoggerInterface) { 53 | FATAL = l 54 | } 55 | -------------------------------------------------------------------------------- /v1/log/log_test.go: -------------------------------------------------------------------------------- 1 | package log_test 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/RichardKnop/machinery/v1/log" 7 | ) 8 | 9 | func TestDefaultLogger(t *testing.T) { 10 | log.INFO.Print("should not panic") 11 | log.WARNING.Print("should not panic") 12 | log.ERROR.Print("should not panic") 13 | log.FATAL.Print("should not panic") 14 | } 15 | -------------------------------------------------------------------------------- /v1/package.go: -------------------------------------------------------------------------------- 1 | package 
machinery 2 | -------------------------------------------------------------------------------- /v1/retry/fibonacci.go: -------------------------------------------------------------------------------- 1 | package retry 2 | 3 | // Fibonacci returns successive Fibonacci numbers starting from 1 4 | func Fibonacci() func() int { 5 | a, b := 0, 1 6 | return func() int { 7 | a, b = b, a+b 8 | return a 9 | } 10 | } 11 | 12 | // FibonacciNext returns next number in Fibonacci sequence greater than start 13 | func FibonacciNext(start int) int { 14 | fib := Fibonacci() 15 | num := fib() 16 | for num <= start { 17 | num = fib() 18 | } 19 | return num 20 | } 21 | -------------------------------------------------------------------------------- /v1/retry/fibonacci_test.go: -------------------------------------------------------------------------------- 1 | package retry_test 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/RichardKnop/machinery/v1/retry" 7 | "github.com/stretchr/testify/assert" 8 | ) 9 | 10 | func TestFibonacci(t *testing.T) { 11 | fibonacci := retry.Fibonacci() 12 | 13 | sequence := []int{ 14 | fibonacci(), 15 | fibonacci(), 16 | fibonacci(), 17 | fibonacci(), 18 | fibonacci(), 19 | fibonacci(), 20 | } 21 | 22 | assert.EqualValues(t, sequence, []int{1, 1, 2, 3, 5, 8}) 23 | } 24 | 25 | func TestFibonacciNext(t *testing.T) { 26 | assert.Equal(t, 1, retry.FibonacciNext(0)) 27 | assert.Equal(t, 2, retry.FibonacciNext(1)) 28 | assert.Equal(t, 5, retry.FibonacciNext(3)) 29 | assert.Equal(t, 5, retry.FibonacciNext(4)) 30 | assert.Equal(t, 8, retry.FibonacciNext(5)) 31 | assert.Equal(t, 13, retry.FibonacciNext(8)) 32 | } 33 | -------------------------------------------------------------------------------- /v1/retry/retry.go: -------------------------------------------------------------------------------- 1 | package retry 2 | 3 | import ( 4 | "fmt" 5 | "time" 6 | 7 | "github.com/RichardKnop/machinery/v1/log" 8 | ) 9 | 10 | // Closure - a useful closure we can use when 
there is a problem 11 | // connecting to the broker. It uses Fibonacci sequence to space out retry attempts 12 | var Closure = func() func(chan int) { 13 | retryIn := 0 14 | fibonacci := Fibonacci() 15 | return func(stopChan chan int) { 16 | if retryIn > 0 { 17 | durationString := fmt.Sprintf("%vs", retryIn) 18 | duration, _ := time.ParseDuration(durationString) 19 | 20 | log.WARNING.Printf("Retrying in %v seconds", retryIn) 21 | 22 | select { 23 | case <-stopChan: 24 | break 25 | case <-time.After(duration): 26 | break 27 | } 28 | } 29 | retryIn = fibonacci() 30 | } 31 | } 32 | -------------------------------------------------------------------------------- /v1/server_test.go: -------------------------------------------------------------------------------- 1 | package machinery_test 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/assert" 7 | 8 | "github.com/RichardKnop/machinery/v1" 9 | "github.com/RichardKnop/machinery/v1/config" 10 | ) 11 | 12 | func TestRegisterTasks(t *testing.T) { 13 | t.Parallel() 14 | 15 | server := getTestServer(t) 16 | err := server.RegisterTasks(map[string]interface{}{ 17 | "test_task": func() error { return nil }, 18 | }) 19 | assert.NoError(t, err) 20 | 21 | _, err = server.GetRegisteredTask("test_task") 22 | assert.NoError(t, err, "test_task is not registered but it should be") 23 | } 24 | 25 | func TestRegisterTask(t *testing.T) { 26 | t.Parallel() 27 | 28 | server := getTestServer(t) 29 | err := server.RegisterTask("test_task", func() error { return nil }) 30 | assert.NoError(t, err) 31 | 32 | _, err = server.GetRegisteredTask("test_task") 33 | assert.NoError(t, err, "test_task is not registered but it should be") 34 | } 35 | 36 | func TestRegisterTaskInRaceCondition(t *testing.T) { 37 | t.Parallel() 38 | 39 | server := getTestServer(t) 40 | for i:=0; i<10; i++ { 41 | go func() { 42 | err := server.RegisterTask("test_task", func() error { return nil }) 43 | assert.NoError(t, err) 44 | _, err = 
server.GetRegisteredTask("test_task") 45 | assert.NoError(t, err, "test_task is not registered but it should be") 46 | }() 47 | } 48 | } 49 | 50 | func TestGetRegisteredTask(t *testing.T) { 51 | t.Parallel() 52 | 53 | server := getTestServer(t) 54 | _, err := server.GetRegisteredTask("test_task") 55 | assert.Error(t, err, "test_task is registered but it should not be") 56 | } 57 | 58 | func TestGetRegisteredTaskNames(t *testing.T) { 59 | t.Parallel() 60 | 61 | server := getTestServer(t) 62 | 63 | taskName := "test_task" 64 | err := server.RegisterTask(taskName, func() error { return nil }) 65 | assert.NoError(t, err) 66 | 67 | taskNames := server.GetRegisteredTaskNames() 68 | assert.Equal(t, 1, len(taskNames)) 69 | assert.Equal(t, taskName, taskNames[0]) 70 | } 71 | 72 | func TestNewWorker(t *testing.T) { 73 | t.Parallel() 74 | 75 | server := getTestServer(t) 76 | 77 | server.NewWorker("test_worker", 1) 78 | assert.NoError(t, nil) 79 | } 80 | 81 | func TestNewCustomQueueWorker(t *testing.T) { 82 | t.Parallel() 83 | 84 | server := getTestServer(t) 85 | 86 | server.NewCustomQueueWorker("test_customqueueworker", 1, "test_queue") 87 | assert.NoError(t, nil) 88 | } 89 | 90 | func getTestServer(t *testing.T) *machinery.Server { 91 | server, err := machinery.NewServer(&config.Config{ 92 | Broker: "amqp://guest:guest@localhost:5672/", 93 | DefaultQueue: "machinery_tasks", 94 | ResultBackend: "redis://127.0.0.1:6379", 95 | Lock: "redis://127.0.0.1:6379", 96 | AMQP: &config.AMQPConfig{ 97 | Exchange: "machinery_exchange", 98 | ExchangeType: "direct", 99 | BindingKey: "machinery_task", 100 | PrefetchCount: 1, 101 | }, 102 | }) 103 | if err != nil { 104 | t.Error(err) 105 | } 106 | return server 107 | } 108 | -------------------------------------------------------------------------------- /v1/tasks/errors.go: -------------------------------------------------------------------------------- 1 | package tasks 2 | 3 | import ( 4 | "fmt" 5 | "time" 6 | ) 7 | 8 | // 
ErrRetryTaskLater ... 9 | type ErrRetryTaskLater struct { 10 | name, msg string 11 | retryIn time.Duration 12 | } 13 | 14 | // RetryIn returns time.Duration from now when task should be retried 15 | func (e ErrRetryTaskLater) RetryIn() time.Duration { 16 | return e.retryIn 17 | } 18 | 19 | // Error implements the error interface 20 | func (e ErrRetryTaskLater) Error() string { 21 | return fmt.Sprintf("Task error: %s Will retry in: %s", e.msg, e.retryIn) 22 | } 23 | 24 | // NewErrRetryTaskLater returns new ErrRetryTaskLater instance 25 | func NewErrRetryTaskLater(msg string, retryIn time.Duration) ErrRetryTaskLater { 26 | return ErrRetryTaskLater{msg: msg, retryIn: retryIn} 27 | } 28 | 29 | // Retriable is interface that retriable errors should implement 30 | type Retriable interface { 31 | RetryIn() time.Duration 32 | } 33 | -------------------------------------------------------------------------------- /v1/tasks/result.go: -------------------------------------------------------------------------------- 1 | package tasks 2 | 3 | import ( 4 | "fmt" 5 | "reflect" 6 | "strings" 7 | ) 8 | 9 | // TaskResult represents an actual return value of a processed task 10 | type TaskResult struct { 11 | Type string `bson:"type"` 12 | Value interface{} `bson:"value"` 13 | } 14 | 15 | // ReflectTaskResults ... 16 | func ReflectTaskResults(taskResults []*TaskResult) ([]reflect.Value, error) { 17 | resultValues := make([]reflect.Value, len(taskResults)) 18 | for i, taskResult := range taskResults { 19 | resultValue, err := ReflectValue(taskResult.Type, taskResult.Value) 20 | if err != nil { 21 | return nil, err 22 | } 23 | resultValues[i] = resultValue 24 | } 25 | return resultValues, nil 26 | } 27 | 28 | // HumanReadableResults ... 
29 | func HumanReadableResults(results []reflect.Value) string { 30 | if len(results) == 1 { 31 | return fmt.Sprintf("%v", results[0].Interface()) 32 | } 33 | 34 | readableResults := make([]string, len(results)) 35 | for i := 0; i < len(results); i++ { 36 | readableResults[i] = fmt.Sprintf("%v", results[i].Interface()) 37 | } 38 | 39 | return fmt.Sprintf("[%s]", strings.Join(readableResults, ", ")) 40 | } 41 | -------------------------------------------------------------------------------- /v1/tasks/result_test.go: -------------------------------------------------------------------------------- 1 | package tasks_test 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/RichardKnop/machinery/v1/tasks" 7 | "github.com/stretchr/testify/assert" 8 | ) 9 | 10 | func TestReflectTaskResults(t *testing.T) { 11 | t.Parallel() 12 | 13 | taskResults := []*tasks.TaskResult{ 14 | { 15 | Type: "[]string", 16 | Value: []string{"f", "o", "o"}, 17 | }, 18 | } 19 | results, err := tasks.ReflectTaskResults(taskResults) 20 | if assert.NoError(t, err) { 21 | assert.Equal(t, 1, len(results)) 22 | assert.Equal(t, 3, results[0].Len()) 23 | assert.Equal(t, "f", results[0].Index(0).String()) 24 | assert.Equal(t, "o", results[0].Index(1).String()) 25 | assert.Equal(t, "o", results[0].Index(2).String()) 26 | } 27 | } 28 | -------------------------------------------------------------------------------- /v1/tasks/signature.go: -------------------------------------------------------------------------------- 1 | package tasks 2 | 3 | import ( 4 | "fmt" 5 | "github.com/RichardKnop/machinery/v1/utils" 6 | "time" 7 | 8 | "github.com/google/uuid" 9 | ) 10 | 11 | // Arg represents a single argument passed to invocation fo a task 12 | type Arg struct { 13 | Name string `bson:"name"` 14 | Type string `bson:"type"` 15 | Value interface{} `bson:"value"` 16 | } 17 | 18 | // Headers represents the headers which should be used to direct the task 19 | type Headers map[string]interface{} 20 | 21 | // Set on 
Headers implements opentracing.TextMapWriter for trace propagation 22 | func (h Headers) Set(key, val string) { 23 | h[key] = val 24 | } 25 | 26 | // ForeachKey on Headers implements opentracing.TextMapReader for trace propagation. 27 | // It is essentially the same as the opentracing.TextMapReader implementation except 28 | // for the added casting from interface{} to string. 29 | func (h Headers) ForeachKey(handler func(key, val string) error) error { 30 | for k, v := range h { 31 | // Skip any non string values 32 | stringValue, ok := v.(string) 33 | if !ok { 34 | continue 35 | } 36 | 37 | if err := handler(k, stringValue); err != nil { 38 | return err 39 | } 40 | } 41 | 42 | return nil 43 | } 44 | 45 | // Signature represents a single task invocation 46 | type Signature struct { 47 | UUID string 48 | Name string 49 | RoutingKey string 50 | ETA *time.Time 51 | GroupUUID string 52 | GroupTaskCount int 53 | Args []Arg 54 | Headers Headers 55 | Priority uint8 56 | Immutable bool 57 | RetryCount int 58 | RetryTimeout int 59 | OnSuccess []*Signature 60 | OnError []*Signature 61 | ChordCallback *Signature 62 | //MessageGroupId for Broker, e.g. 
SQS 63 | BrokerMessageGroupId string 64 | //ReceiptHandle of SQS Message 65 | SQSReceiptHandle string 66 | // StopTaskDeletionOnError used with sqs when we want to send failed messages to dlq, 67 | // and don't want machinery to delete from source queue 68 | StopTaskDeletionOnError bool 69 | // IgnoreWhenTaskNotRegistered auto removes the request when there is no handeler available 70 | // When this is true a task with no handler will be ignored and not placed back in the queue 71 | IgnoreWhenTaskNotRegistered bool 72 | } 73 | 74 | // NewSignature creates a new task signature 75 | func NewSignature(name string, args []Arg) (*Signature, error) { 76 | signatureID := uuid.New().String() 77 | return &Signature{ 78 | UUID: fmt.Sprintf("task_%v", signatureID), 79 | Name: name, 80 | Args: args, 81 | }, nil 82 | } 83 | 84 | func CopySignatures(signatures ...*Signature) []*Signature { 85 | var sigs = make([]*Signature, len(signatures)) 86 | for index, signature := range signatures { 87 | sigs[index] = CopySignature(signature) 88 | } 89 | return sigs 90 | } 91 | 92 | func CopySignature(signature *Signature) *Signature { 93 | var sig = new(Signature) 94 | _ = utils.DeepCopy(sig, signature) 95 | return sig 96 | } 97 | -------------------------------------------------------------------------------- /v1/tasks/state.go: -------------------------------------------------------------------------------- 1 | package tasks 2 | 3 | import "time" 4 | 5 | const ( 6 | // StatePending - initial state of a task 7 | StatePending = "PENDING" 8 | // StateReceived - when task is received by a worker 9 | StateReceived = "RECEIVED" 10 | // StateStarted - when the worker starts processing the task 11 | StateStarted = "STARTED" 12 | // StateRetry - when failed task has been scheduled for retry 13 | StateRetry = "RETRY" 14 | // StateSuccess - when the task is processed successfully 15 | StateSuccess = "SUCCESS" 16 | // StateFailure - when processing of the task fails 17 | StateFailure = "FAILURE" 18 
| ) 19 | 20 | // TaskState represents a state of a task 21 | type TaskState struct { 22 | TaskUUID string `bson:"_id"` 23 | TaskName string `bson:"task_name"` 24 | State string `bson:"state"` 25 | Results []*TaskResult `bson:"results"` 26 | Error string `bson:"error"` 27 | CreatedAt time.Time `bson:"created_at"` 28 | TTL int64 `bson:"ttl,omitempty"` 29 | } 30 | 31 | // GroupMeta stores useful metadata about tasks within the same group 32 | // E.g. UUIDs of all tasks which are used in order to check if all tasks 33 | // completed successfully or not and thus whether to trigger chord callback 34 | type GroupMeta struct { 35 | GroupUUID string `bson:"_id"` 36 | TaskUUIDs []string `bson:"task_uuids"` 37 | ChordTriggered bool `bson:"chord_triggered"` 38 | Lock bool `bson:"lock"` 39 | CreatedAt time.Time `bson:"created_at"` 40 | TTL int64 `bson:"ttl,omitempty"` 41 | } 42 | 43 | // NewPendingTaskState ... 44 | func NewPendingTaskState(signature *Signature) *TaskState { 45 | return &TaskState{ 46 | TaskUUID: signature.UUID, 47 | TaskName: signature.Name, 48 | State: StatePending, 49 | CreatedAt: time.Now().UTC(), 50 | } 51 | } 52 | 53 | // NewReceivedTaskState ... 54 | func NewReceivedTaskState(signature *Signature) *TaskState { 55 | return &TaskState{ 56 | TaskUUID: signature.UUID, 57 | State: StateReceived, 58 | } 59 | } 60 | 61 | // NewStartedTaskState ... 62 | func NewStartedTaskState(signature *Signature) *TaskState { 63 | return &TaskState{ 64 | TaskUUID: signature.UUID, 65 | State: StateStarted, 66 | } 67 | } 68 | 69 | // NewSuccessTaskState ... 70 | func NewSuccessTaskState(signature *Signature, results []*TaskResult) *TaskState { 71 | return &TaskState{ 72 | TaskUUID: signature.UUID, 73 | State: StateSuccess, 74 | Results: results, 75 | } 76 | } 77 | 78 | // NewFailureTaskState ... 
79 | func NewFailureTaskState(signature *Signature, err string) *TaskState { 80 | return &TaskState{ 81 | TaskUUID: signature.UUID, 82 | State: StateFailure, 83 | Error: err, 84 | } 85 | } 86 | 87 | // NewRetryTaskState ... 88 | func NewRetryTaskState(signature *Signature) *TaskState { 89 | return &TaskState{ 90 | TaskUUID: signature.UUID, 91 | State: StateRetry, 92 | } 93 | } 94 | 95 | // IsCompleted returns true if state is SUCCESS or FAILURE, 96 | // i.e. the task has finished processing and either succeeded or failed. 97 | func (taskState *TaskState) IsCompleted() bool { 98 | return taskState.IsSuccess() || taskState.IsFailure() 99 | } 100 | 101 | // IsSuccess returns true if state is SUCCESS 102 | func (taskState *TaskState) IsSuccess() bool { 103 | return taskState.State == StateSuccess 104 | } 105 | 106 | // IsFailure returns true if state is FAILURE 107 | func (taskState *TaskState) IsFailure() bool { 108 | return taskState.State == StateFailure 109 | } 110 | -------------------------------------------------------------------------------- /v1/tasks/state_test.go: -------------------------------------------------------------------------------- 1 | package tasks_test 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/RichardKnop/machinery/v1/tasks" 7 | "github.com/stretchr/testify/assert" 8 | ) 9 | 10 | func TestTaskStateIsCompleted(t *testing.T) { 11 | t.Parallel() 12 | 13 | taskState := &tasks.TaskState{ 14 | TaskUUID: "taskUUID", 15 | State: tasks.StatePending, 16 | } 17 | 18 | assert.False(t, taskState.IsCompleted()) 19 | 20 | taskState.State = tasks.StateReceived 21 | assert.False(t, taskState.IsCompleted()) 22 | 23 | taskState.State = tasks.StateStarted 24 | assert.False(t, taskState.IsCompleted()) 25 | 26 | taskState.State = tasks.StateSuccess 27 | assert.True(t, taskState.IsCompleted()) 28 | 29 | taskState.State = tasks.StateFailure 30 | assert.True(t, taskState.IsCompleted()) 31 | } 32 | 
-------------------------------------------------------------------------------- /v1/tasks/task_test.go: -------------------------------------------------------------------------------- 1 | package tasks_test 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "math" 7 | "testing" 8 | "time" 9 | 10 | "github.com/RichardKnop/machinery/v1/tasks" 11 | "github.com/stretchr/testify/assert" 12 | ) 13 | 14 | func TestTaskCallErrorTest(t *testing.T) { 15 | t.Parallel() 16 | 17 | // Create test task that returns tasks.ErrRetryTaskLater error 18 | retriable := func() error { return tasks.NewErrRetryTaskLater("some error", 4*time.Hour) } 19 | 20 | task, err := tasks.New(retriable, []tasks.Arg{}) 21 | assert.NoError(t, err) 22 | 23 | // Invoke TryCall and validate that returned error can be cast to tasks.ErrRetryTaskLater 24 | results, err := task.Call() 25 | assert.Nil(t, results) 26 | assert.NotNil(t, err) 27 | _, ok := interface{}(err).(tasks.ErrRetryTaskLater) 28 | assert.True(t, ok, "Error should be castable to tasks.ErrRetryTaskLater") 29 | 30 | // Create test task that returns a standard error 31 | standard := func() error { return errors.New("some error") } 32 | 33 | task, err = tasks.New(standard, []tasks.Arg{}) 34 | assert.NoError(t, err) 35 | 36 | // Invoke TryCall and validate that returned error is standard 37 | results, err = task.Call() 38 | assert.Nil(t, results) 39 | assert.NotNil(t, err) 40 | assert.Equal(t, "some error", err.Error()) 41 | } 42 | 43 | func TestTaskReflectArgs(t *testing.T) { 44 | t.Parallel() 45 | 46 | task := new(tasks.Task) 47 | args := []tasks.Arg{ 48 | { 49 | Type: "[]int64", 50 | Value: []int64{1, 2}, 51 | }, 52 | } 53 | 54 | err := task.ReflectArgs(args) 55 | assert.NoError(t, err) 56 | assert.Equal(t, 1, len(task.Args)) 57 | assert.Equal(t, "[]int64", task.Args[0].Type().String()) 58 | } 59 | 60 | func TestTaskCallInvalidArgRobustnessError(t *testing.T) { 61 | t.Parallel() 62 | 63 | // Create a test task function 64 | f := func(x int) 
error { return nil } 65 | 66 | // Construct an invalid argument list and reflect it 67 | args := []tasks.Arg{ 68 | {Type: "bool", Value: true}, 69 | } 70 | 71 | task, err := tasks.New(f, args) 72 | assert.NoError(t, err) 73 | 74 | // Invoke TryCall and validate error handling 75 | results, err := task.Call() 76 | assert.Equal(t, "reflect: Call using bool as type int", err.Error()) 77 | assert.Nil(t, results) 78 | } 79 | 80 | func TestTaskCallInterfaceValuedResult(t *testing.T) { 81 | t.Parallel() 82 | 83 | // Create a test task function 84 | f := func() (interface{}, error) { return math.Pi, nil } 85 | 86 | task, err := tasks.New(f, []tasks.Arg{}) 87 | assert.NoError(t, err) 88 | 89 | taskResults, err := task.Call() 90 | assert.NoError(t, err) 91 | assert.Equal(t, "float64", taskResults[0].Type) 92 | assert.Equal(t, math.Pi, taskResults[0].Value) 93 | } 94 | 95 | func TestTaskCallWithContext(t *testing.T) { 96 | t.Parallel() 97 | 98 | f := func(c context.Context) (interface{}, error) { 99 | assert.NotNil(t, c) 100 | assert.Nil(t, tasks.SignatureFromContext(c)) 101 | return math.Pi, nil 102 | } 103 | task, err := tasks.New(f, []tasks.Arg{}) 104 | assert.NoError(t, err) 105 | taskResults, err := task.Call() 106 | assert.NoError(t, err) 107 | assert.Equal(t, "float64", taskResults[0].Type) 108 | assert.Equal(t, math.Pi, taskResults[0].Value) 109 | } 110 | 111 | func TestTaskCallWithSignatureInContext(t *testing.T) { 112 | t.Parallel() 113 | 114 | f := func(c context.Context) (interface{}, error) { 115 | assert.NotNil(t, c) 116 | signature := tasks.SignatureFromContext(c) 117 | assert.NotNil(t, signature) 118 | assert.Equal(t, "foo", signature.Name) 119 | return math.Pi, nil 120 | } 121 | signature, err := tasks.NewSignature("foo", []tasks.Arg{}) 122 | assert.NoError(t, err) 123 | task, err := tasks.NewWithSignature(f, signature) 124 | assert.NoError(t, err) 125 | taskResults, err := task.Call() 126 | assert.NoError(t, err) 127 | assert.Equal(t, "float64", 
taskResults[0].Type) 128 | assert.Equal(t, math.Pi, taskResults[0].Value) 129 | } 130 | -------------------------------------------------------------------------------- /v1/tasks/validate.go: -------------------------------------------------------------------------------- 1 | package tasks 2 | 3 | import ( 4 | "errors" 5 | "reflect" 6 | ) 7 | 8 | var ( 9 | // ErrTaskMustBeFunc ... 10 | ErrTaskMustBeFunc = errors.New("Task must be a func type") 11 | // ErrTaskReturnsNoValue ... 12 | ErrTaskReturnsNoValue = errors.New("Task must return at least a single value") 13 | // ErrLastReturnValueMustBeError .. 14 | ErrLastReturnValueMustBeError = errors.New("Last return value of a task must be error") 15 | ) 16 | 17 | // ValidateTask validates task function using reflection and makes sure 18 | // it has a proper signature. Functions used as tasks must return at least a 19 | // single value and the last return type must be error 20 | func ValidateTask(task interface{}) error { 21 | v := reflect.ValueOf(task) 22 | t := v.Type() 23 | 24 | // Task must be a function 25 | if t.Kind() != reflect.Func { 26 | return ErrTaskMustBeFunc 27 | } 28 | 29 | // Task must return at least a single value 30 | if t.NumOut() < 1 { 31 | return ErrTaskReturnsNoValue 32 | } 33 | 34 | // Last return value must be error 35 | lastReturnType := t.Out(t.NumOut() - 1) 36 | errorInterface := reflect.TypeOf((*error)(nil)).Elem() 37 | if !lastReturnType.Implements(errorInterface) { 38 | return ErrLastReturnValueMustBeError 39 | } 40 | 41 | return nil 42 | } 43 | -------------------------------------------------------------------------------- /v1/tasks/validate_test.go: -------------------------------------------------------------------------------- 1 | package tasks_test 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/RichardKnop/machinery/v1/tasks" 7 | "github.com/stretchr/testify/assert" 8 | ) 9 | 10 | func TestValidateTask(t *testing.T) { 11 | t.Parallel() 12 | 13 | type someStruct struct{} 14 | var ( 
15 | taskOfWrongType = new(someStruct) 16 | taskWithoutReturnValue = func() {} 17 | taskWithoutErrorAsLastReturnValue = func() int { return 0 } 18 | validTask = func(arg string) error { return nil } 19 | ) 20 | 21 | err := tasks.ValidateTask(taskOfWrongType) 22 | assert.Equal(t, tasks.ErrTaskMustBeFunc, err) 23 | 24 | err = tasks.ValidateTask(taskWithoutReturnValue) 25 | assert.Equal(t, tasks.ErrTaskReturnsNoValue, err) 26 | 27 | err = tasks.ValidateTask(taskWithoutErrorAsLastReturnValue) 28 | assert.Equal(t, tasks.ErrLastReturnValueMustBeError, err) 29 | 30 | err = tasks.ValidateTask(validTask) 31 | assert.NoError(t, err) 32 | } 33 | -------------------------------------------------------------------------------- /v1/tasks/workflow.go: -------------------------------------------------------------------------------- 1 | package tasks 2 | 3 | import ( 4 | "fmt" 5 | 6 | "github.com/google/uuid" 7 | ) 8 | 9 | // Chain creates a chain of tasks to be executed one after another 10 | type Chain struct { 11 | Tasks []*Signature 12 | } 13 | 14 | // Group creates a set of tasks to be executed in parallel 15 | type Group struct { 16 | GroupUUID string 17 | Tasks []*Signature 18 | } 19 | 20 | // Chord adds an optional callback to the group to be executed 21 | // after all tasks in the group finished 22 | type Chord struct { 23 | Group *Group 24 | Callback *Signature 25 | } 26 | 27 | // GetUUIDs returns slice of task UUIDS 28 | func (group *Group) GetUUIDs() []string { 29 | taskUUIDs := make([]string, len(group.Tasks)) 30 | for i, signature := range group.Tasks { 31 | taskUUIDs[i] = signature.UUID 32 | } 33 | return taskUUIDs 34 | } 35 | 36 | // NewChain creates a new chain of tasks to be processed one by one, passing 37 | // results unless task signatures are set to be immutable 38 | func NewChain(signatures ...*Signature) (*Chain, error) { 39 | // Auto generate task UUIDs if needed 40 | for _, signature := range signatures { 41 | if signature.UUID == "" { 42 | signatureID := 
uuid.New().String() 43 | signature.UUID = fmt.Sprintf("task_%v", signatureID) 44 | } 45 | } 46 | 47 | for i := len(signatures) - 1; i > 0; i-- { 48 | if i > 0 { 49 | signatures[i-1].OnSuccess = []*Signature{signatures[i]} 50 | } 51 | } 52 | 53 | chain := &Chain{Tasks: signatures} 54 | 55 | return chain, nil 56 | } 57 | 58 | // NewGroup creates a new group of tasks to be processed in parallel 59 | func NewGroup(signatures ...*Signature) (*Group, error) { 60 | // Generate a group UUID 61 | groupUUID := uuid.New().String() 62 | groupID := fmt.Sprintf("group_%v", groupUUID) 63 | 64 | // Auto generate task UUIDs if needed, group tasks by common group UUID 65 | for _, signature := range signatures { 66 | if signature.UUID == "" { 67 | signatureID := uuid.New().String() 68 | signature.UUID = fmt.Sprintf("task_%v", signatureID) 69 | } 70 | signature.GroupUUID = groupID 71 | signature.GroupTaskCount = len(signatures) 72 | } 73 | 74 | return &Group{ 75 | GroupUUID: groupID, 76 | Tasks: signatures, 77 | }, nil 78 | } 79 | 80 | // NewChord creates a new chord (a group of tasks with a single callback 81 | // to be executed after all tasks in the group has completed) 82 | func NewChord(group *Group, callback *Signature) (*Chord, error) { 83 | if callback.UUID == "" { 84 | // Generate a UUID for the chord callback 85 | callbackUUID := uuid.New().String() 86 | callback.UUID = fmt.Sprintf("chord_%v", callbackUUID) 87 | } 88 | 89 | // Add a chord callback to all tasks 90 | for _, signature := range group.Tasks { 91 | signature.ChordCallback = callback 92 | } 93 | 94 | return &Chord{Group: group, Callback: callback}, nil 95 | } 96 | -------------------------------------------------------------------------------- /v1/tasks/workflow_test.go: -------------------------------------------------------------------------------- 1 | package tasks_test 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/RichardKnop/machinery/v1/tasks" 7 | "github.com/stretchr/testify/assert" 8 | ) 9 | 10 | func 
TestNewChain(t *testing.T) { 11 | t.Parallel() 12 | 13 | task1 := tasks.Signature{ 14 | Name: "foo", 15 | Args: []tasks.Arg{ 16 | { 17 | Type: "float64", 18 | Value: interface{}(1), 19 | }, 20 | { 21 | Type: "float64", 22 | Value: interface{}(1), 23 | }, 24 | }, 25 | } 26 | 27 | task2 := tasks.Signature{ 28 | Name: "bar", 29 | Args: []tasks.Arg{ 30 | { 31 | Type: "float64", 32 | Value: interface{}(5), 33 | }, 34 | { 35 | Type: "float64", 36 | Value: interface{}(6), 37 | }, 38 | }, 39 | } 40 | 41 | task3 := tasks.Signature{ 42 | Name: "qux", 43 | Args: []tasks.Arg{ 44 | { 45 | Type: "float64", 46 | Value: interface{}(4), 47 | }, 48 | }, 49 | } 50 | 51 | chain, err := tasks.NewChain(&task1, &task2, &task3) 52 | if err != nil { 53 | t.Fatal(err) 54 | } 55 | 56 | firstTask := chain.Tasks[0] 57 | 58 | assert.Equal(t, "foo", firstTask.Name) 59 | assert.Equal(t, "bar", firstTask.OnSuccess[0].Name) 60 | assert.Equal(t, "qux", firstTask.OnSuccess[0].OnSuccess[0].Name) 61 | } 62 | -------------------------------------------------------------------------------- /v1/utils/deepcopy.go: -------------------------------------------------------------------------------- 1 | package utils 2 | 3 | import ( 4 | "errors" 5 | "reflect" 6 | ) 7 | 8 | var ( 9 | ErrNoMatchType = errors.New("no match type") 10 | ErrNoPointer = errors.New("must be interface") 11 | ErrInvalidArgument = errors.New("invalid arguments") 12 | ) 13 | 14 | func deepCopy(dst, src reflect.Value) { 15 | switch src.Kind() { 16 | case reflect.Interface: 17 | value := src.Elem() 18 | if !value.IsValid() { 19 | return 20 | } 21 | newValue := reflect.New(value.Type()).Elem() 22 | deepCopy(newValue, value) 23 | dst.Set(newValue) 24 | case reflect.Ptr: 25 | value := src.Elem() 26 | if !value.IsValid() { 27 | return 28 | } 29 | dst.Set(reflect.New(value.Type())) 30 | deepCopy(dst.Elem(), value) 31 | case reflect.Map: 32 | dst.Set(reflect.MakeMap(src.Type())) 33 | keys := src.MapKeys() 34 | for _, key := range keys { 35 | value 
:= src.MapIndex(key) 36 | newValue := reflect.New(value.Type()).Elem() 37 | deepCopy(newValue, value) 38 | dst.SetMapIndex(key, newValue) 39 | } 40 | case reflect.Slice: 41 | dst.Set(reflect.MakeSlice(src.Type(), src.Len(), src.Cap())) 42 | for i := 0; i < src.Len(); i++ { 43 | deepCopy(dst.Index(i), src.Index(i)) 44 | } 45 | case reflect.Struct: 46 | typeSrc := src.Type() 47 | for i := 0; i < src.NumField(); i++ { 48 | value := src.Field(i) 49 | tag := typeSrc.Field(i).Tag 50 | if value.CanSet() && tag.Get("deepcopy") != "-" { 51 | deepCopy(dst.Field(i), value) 52 | } 53 | } 54 | default: 55 | dst.Set(src) 56 | } 57 | } 58 | 59 | func DeepCopy(dst, src interface{}) error { 60 | typeDst := reflect.TypeOf(dst) 61 | typeSrc := reflect.TypeOf(src) 62 | if typeDst != typeSrc { 63 | return ErrNoMatchType 64 | } 65 | if typeSrc.Kind() != reflect.Ptr { 66 | return ErrNoPointer 67 | } 68 | 69 | valueDst := reflect.ValueOf(dst).Elem() 70 | valueSrc := reflect.ValueOf(src).Elem() 71 | if !valueDst.IsValid() || !valueSrc.IsValid() { 72 | return ErrInvalidArgument 73 | } 74 | 75 | deepCopy(valueDst, valueSrc) 76 | return nil 77 | } 78 | 79 | func DeepClone(v interface{}) interface{} { 80 | dst := reflect.New(reflect.TypeOf(v)).Elem() 81 | deepCopy(dst, reflect.ValueOf(v)) 82 | return dst.Interface() 83 | } 84 | -------------------------------------------------------------------------------- /v1/utils/deepcopy_test.go: -------------------------------------------------------------------------------- 1 | package utils 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/assert" 7 | ) 8 | 9 | func TestDeepCopy(t *testing.T) { 10 | t.Parallel() 11 | 12 | type s struct { 13 | A float64 14 | B int 15 | C []int 16 | D *int 17 | E map[string]int 18 | } 19 | var d = 3 20 | var dst = new(s) 21 | var src = s{1.0, 1, []int{1, 2, 3}, &d, map[string]int{"a": 1}} 22 | 23 | err := DeepCopy(dst, &src) 24 | src.A = 2 25 | 26 | assert.NoError(t, err) 27 | assert.Equal(t, 1.0, 
dst.A) 28 | assert.Equal(t, 1, dst.B) 29 | assert.Equal(t, []int{1, 2, 3}, dst.C) 30 | assert.Equal(t, &d, dst.D) 31 | assert.Equal(t, map[string]int{"a": 1}, dst.E) 32 | } 33 | -------------------------------------------------------------------------------- /v1/utils/utils.go: -------------------------------------------------------------------------------- 1 | package utils 2 | 3 | import ( 4 | "os" 5 | "path/filepath" 6 | ) 7 | 8 | const ( 9 | LockKeyPrefix = "machinery_lock_" 10 | ) 11 | 12 | func GetLockName(name, spec string) string { 13 | return LockKeyPrefix + filepath.Base(os.Args[0]) + name + spec 14 | } 15 | -------------------------------------------------------------------------------- /v1/utils/utils_test.go: -------------------------------------------------------------------------------- 1 | package utils 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/assert" 7 | ) 8 | 9 | func TestGetLockName(t *testing.T) { 10 | t.Parallel() 11 | 12 | lockName := GetLockName("test", "*/3 * * *") 13 | assert.Equal(t, "machinery_lock_utils.testtest*/3 * * *", lockName) 14 | } 15 | -------------------------------------------------------------------------------- /v1/utils/uuid.go: -------------------------------------------------------------------------------- 1 | package utils 2 | 3 | import ( 4 | "github.com/google/uuid" 5 | "strings" 6 | ) 7 | 8 | func GetPureUUID() string { 9 | uid, _ := uuid.NewUUID() 10 | return strings.Replace(uid.String(), "-", "", -1) 11 | } 12 | -------------------------------------------------------------------------------- /v1/utils/uuid_test.go: -------------------------------------------------------------------------------- 1 | package utils 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/assert" 7 | ) 8 | 9 | func TestGetPureUUID(t *testing.T) { 10 | t.Parallel() 11 | 12 | assert.Len(t, GetPureUUID(), 32) 13 | } 14 | -------------------------------------------------------------------------------- 
/v1/worker_test.go: -------------------------------------------------------------------------------- 1 | package machinery_test 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/assert" 7 | 8 | "github.com/RichardKnop/machinery/v1" 9 | ) 10 | 11 | func TestRedactURL(t *testing.T) { 12 | t.Parallel() 13 | 14 | broker := "amqp://guest:guest@localhost:5672" 15 | redactedURL := machinery.RedactURL(broker) 16 | assert.Equal(t, "amqp://localhost:5672", redactedURL) 17 | } 18 | 19 | func TestPreConsumeHandler(t *testing.T) { 20 | t.Parallel() 21 | 22 | worker := &machinery.Worker{} 23 | 24 | worker.SetPreConsumeHandler(SamplePreConsumeHandler) 25 | assert.True(t, worker.PreConsumeHandler()) 26 | } 27 | 28 | func SamplePreConsumeHandler(w *machinery.Worker) bool { 29 | return true 30 | } 31 | -------------------------------------------------------------------------------- /v2/backends/amqp/amqp_test.go: -------------------------------------------------------------------------------- 1 | package amqp_test 2 | 3 | import ( 4 | "os" 5 | "testing" 6 | "time" 7 | 8 | "github.com/RichardKnop/machinery/v2/backends/amqp" 9 | "github.com/RichardKnop/machinery/v2/config" 10 | "github.com/RichardKnop/machinery/v2/tasks" 11 | "github.com/stretchr/testify/assert" 12 | ) 13 | 14 | var ( 15 | amqpConfig *config.Config 16 | ) 17 | 18 | func init() { 19 | amqpURL := os.Getenv("AMQP_URL") 20 | if amqpURL == "" { 21 | return 22 | } 23 | 24 | finalAmqpURL := amqpURL 25 | var finalSeparator string 26 | 27 | amqpURLs := os.Getenv("AMQP_URLS") 28 | if amqpURLs != "" { 29 | separator := os.Getenv("AMQP_URLS_SEPARATOR") 30 | if separator == "" { 31 | return 32 | } 33 | finalSeparator = separator 34 | finalAmqpURL = amqpURLs 35 | } 36 | 37 | amqp2URL := os.Getenv("AMQP2_URL") 38 | if amqp2URL == "" { 39 | amqp2URL = amqpURL 40 | } 41 | 42 | amqpConfig = &config.Config{ 43 | Broker: finalAmqpURL, 44 | MultipleBrokerSeparator: finalSeparator, 45 | DefaultQueue: "test_queue", 46 | 
ResultBackend: amqp2URL, 47 | AMQP: &config.AMQPConfig{ 48 | Exchange: "test_exchange", 49 | ExchangeType: "direct", 50 | BindingKey: "test_task", 51 | PrefetchCount: 1, 52 | }, 53 | } 54 | } 55 | 56 | func TestGroupCompleted(t *testing.T) { 57 | if os.Getenv("AMQP_URL") == "" { 58 | t.Skip("AMQP_URL is not defined") 59 | } 60 | 61 | groupUUID := "testGroupUUID" 62 | groupTaskCount := 2 63 | task1 := &tasks.Signature{ 64 | UUID: "testTaskUUID1", 65 | GroupUUID: groupUUID, 66 | GroupTaskCount: groupTaskCount, 67 | } 68 | task2 := &tasks.Signature{ 69 | UUID: "testTaskUUID2", 70 | GroupUUID: groupUUID, 71 | GroupTaskCount: groupTaskCount, 72 | } 73 | 74 | backend := amqp.New(amqpConfig) 75 | 76 | // Cleanup before the test 77 | backend.PurgeState(task1.UUID) 78 | backend.PurgeState(task2.UUID) 79 | backend.PurgeGroupMeta(groupUUID) 80 | 81 | groupCompleted, err := backend.GroupCompleted(groupUUID, groupTaskCount) 82 | if assert.NoError(t, err) { 83 | assert.False(t, groupCompleted) 84 | } 85 | 86 | backend.InitGroup(groupUUID, []string{task1.UUID, task2.UUID}) 87 | 88 | groupCompleted, err = backend.GroupCompleted(groupUUID, groupTaskCount) 89 | if assert.NoError(t, err) { 90 | assert.False(t, groupCompleted) 91 | } 92 | 93 | backend.SetStatePending(task1) 94 | backend.SetStateStarted(task2) 95 | groupCompleted, err = backend.GroupCompleted(groupUUID, groupTaskCount) 96 | if assert.NoError(t, err) { 97 | assert.False(t, groupCompleted) 98 | } 99 | 100 | taskResults := []*tasks.TaskResult{new(tasks.TaskResult)} 101 | backend.SetStateSuccess(task1, taskResults) 102 | backend.SetStateSuccess(task2, taskResults) 103 | groupCompleted, err = backend.GroupCompleted(groupUUID, groupTaskCount) 104 | if assert.NoError(t, err) { 105 | assert.True(t, groupCompleted) 106 | } 107 | } 108 | 109 | func TestGetState(t *testing.T) { 110 | if os.Getenv("AMQP_URL") == "" { 111 | t.Skip("AMQP_URL is not defined") 112 | } 113 | 114 | signature := &tasks.Signature{ 115 | UUID: 
"testTaskUUID", 116 | GroupUUID: "testGroupUUID", 117 | } 118 | 119 | go func() { 120 | backend := amqp.New(amqpConfig) 121 | backend.SetStatePending(signature) 122 | time.Sleep(2 * time.Millisecond) 123 | backend.SetStateReceived(signature) 124 | time.Sleep(2 * time.Millisecond) 125 | backend.SetStateStarted(signature) 126 | time.Sleep(2 * time.Millisecond) 127 | 128 | taskResults := []*tasks.TaskResult{ 129 | { 130 | Type: "float64", 131 | Value: 2, 132 | }, 133 | } 134 | backend.SetStateSuccess(signature, taskResults) 135 | }() 136 | 137 | backend := amqp.New(amqpConfig) 138 | 139 | var ( 140 | taskState *tasks.TaskState 141 | err error 142 | ) 143 | for { 144 | taskState, err = backend.GetState(signature.UUID) 145 | if taskState == nil { 146 | assert.Equal(t, "No state ready", err.Error()) 147 | continue 148 | } 149 | 150 | assert.NoError(t, err) 151 | if taskState.IsCompleted() { 152 | break 153 | } 154 | } 155 | } 156 | 157 | func TestPurgeState(t *testing.T) { 158 | if os.Getenv("AMQP_URL") == "" { 159 | t.Skip("AMQP_URL is not defined") 160 | } 161 | 162 | signature := &tasks.Signature{ 163 | UUID: "testTaskUUID", 164 | GroupUUID: "testGroupUUID", 165 | } 166 | 167 | backend := amqp.New(amqpConfig) 168 | 169 | backend.SetStatePending(signature) 170 | backend.SetStateReceived(signature) 171 | taskState, err := backend.GetState(signature.UUID) 172 | assert.NotNil(t, taskState) 173 | assert.NoError(t, err) 174 | 175 | backend.PurgeState(taskState.TaskUUID) 176 | taskState, err = backend.GetState(signature.UUID) 177 | assert.Nil(t, taskState) 178 | assert.Error(t, err) 179 | } 180 | -------------------------------------------------------------------------------- /v2/backends/iface/interfaces.go: -------------------------------------------------------------------------------- 1 | package iface 2 | 3 | import ( 4 | "github.com/RichardKnop/machinery/v2/tasks" 5 | ) 6 | 7 | // Backend - a common interface for all result backends 8 | type Backend interface { 9 | // 
Group related functions 10 | InitGroup(groupUUID string, taskUUIDs []string) error 11 | GroupCompleted(groupUUID string, groupTaskCount int) (bool, error) 12 | GroupTaskStates(groupUUID string, groupTaskCount int) ([]*tasks.TaskState, error) 13 | TriggerChord(groupUUID string) (bool, error) 14 | 15 | // Setting / getting task state 16 | SetStatePending(signature *tasks.Signature) error 17 | SetStateReceived(signature *tasks.Signature) error 18 | SetStateStarted(signature *tasks.Signature) error 19 | SetStateRetry(signature *tasks.Signature) error 20 | SetStateSuccess(signature *tasks.Signature, results []*tasks.TaskResult) error 21 | SetStateFailure(signature *tasks.Signature, err string) error 22 | GetState(taskUUID string) (*tasks.TaskState, error) 23 | 24 | // Purging stored stored tasks states and group meta data 25 | IsAMQP() bool 26 | PurgeState(taskUUID string) error 27 | PurgeGroupMeta(groupUUID string) error 28 | } 29 | -------------------------------------------------------------------------------- /v2/backends/memcache/memcache_test.go: -------------------------------------------------------------------------------- 1 | package memcache_test 2 | 3 | import ( 4 | "os" 5 | "testing" 6 | "time" 7 | 8 | "github.com/RichardKnop/machinery/v2/backends/memcache" 9 | "github.com/RichardKnop/machinery/v2/config" 10 | "github.com/RichardKnop/machinery/v2/tasks" 11 | "github.com/stretchr/testify/assert" 12 | ) 13 | 14 | func TestGroupCompleted(t *testing.T) { 15 | memcacheURL := os.Getenv("MEMCACHE_URL") 16 | if memcacheURL == "" { 17 | t.Skip("MEMCACHE_URL is not defined") 18 | } 19 | 20 | groupUUID := "testGroupUUID" 21 | task1 := &tasks.Signature{ 22 | UUID: "testTaskUUID1", 23 | GroupUUID: groupUUID, 24 | } 25 | task2 := &tasks.Signature{ 26 | UUID: "testTaskUUID2", 27 | GroupUUID: groupUUID, 28 | } 29 | 30 | backend := memcache.New(new(config.Config), []string{memcacheURL}) 31 | 32 | // Cleanup before the test 33 | backend.PurgeState(task1.UUID) 34 | 
backend.PurgeState(task2.UUID) 35 | backend.PurgeGroupMeta(groupUUID) 36 | 37 | groupCompleted, err := backend.GroupCompleted(groupUUID, 2) 38 | if assert.Error(t, err) { 39 | assert.False(t, groupCompleted) 40 | assert.Equal(t, "memcache: cache miss", err.Error()) 41 | } 42 | 43 | backend.InitGroup(groupUUID, []string{task1.UUID, task2.UUID}) 44 | 45 | groupCompleted, err = backend.GroupCompleted(groupUUID, 2) 46 | if assert.Error(t, err) { 47 | assert.False(t, groupCompleted) 48 | assert.Equal(t, "memcache: cache miss", err.Error()) 49 | } 50 | 51 | backend.SetStatePending(task1) 52 | backend.SetStateStarted(task2) 53 | groupCompleted, err = backend.GroupCompleted(groupUUID, 2) 54 | if assert.NoError(t, err) { 55 | assert.False(t, groupCompleted) 56 | } 57 | 58 | taskResults := []*tasks.TaskResult{new(tasks.TaskResult)} 59 | backend.SetStateStarted(task1) 60 | backend.SetStateSuccess(task2, taskResults) 61 | groupCompleted, err = backend.GroupCompleted(groupUUID, 2) 62 | if assert.NoError(t, err) { 63 | assert.False(t, groupCompleted) 64 | } 65 | 66 | backend.SetStateFailure(task1, "Some error") 67 | groupCompleted, err = backend.GroupCompleted(groupUUID, 2) 68 | if assert.NoError(t, err) { 69 | assert.True(t, groupCompleted) 70 | } 71 | } 72 | 73 | func TestGetState(t *testing.T) { 74 | memcacheURL := os.Getenv("MEMCACHE_URL") 75 | if memcacheURL == "" { 76 | t.Skip("MEMCACHE_URL is not defined") 77 | } 78 | 79 | signature := &tasks.Signature{ 80 | UUID: "testTaskUUID", 81 | GroupUUID: "testGroupUUID", 82 | } 83 | 84 | backend := memcache.New(new(config.Config), []string{memcacheURL}) 85 | 86 | go func() { 87 | backend.SetStatePending(signature) 88 | time.Sleep(2 * time.Millisecond) 89 | backend.SetStateReceived(signature) 90 | time.Sleep(2 * time.Millisecond) 91 | backend.SetStateStarted(signature) 92 | time.Sleep(2 * time.Millisecond) 93 | taskResults := []*tasks.TaskResult{ 94 | { 95 | Type: "float64", 96 | Value: 2, 97 | }, 98 | } 99 | 
backend.SetStateSuccess(signature, taskResults) 100 | }() 101 | 102 | var ( 103 | taskState *tasks.TaskState 104 | err error 105 | ) 106 | for { 107 | taskState, err = backend.GetState(signature.UUID) 108 | if taskState == nil { 109 | assert.Equal(t, "memcache: cache miss", err.Error()) 110 | continue 111 | } 112 | 113 | assert.NoError(t, err) 114 | if taskState.IsCompleted() { 115 | break 116 | } 117 | } 118 | } 119 | 120 | func TestPurgeState(t *testing.T) { 121 | memcacheURL := os.Getenv("MEMCACHE_URL") 122 | if memcacheURL == "" { 123 | t.Skip("MEMCACHE_URL is not defined") 124 | } 125 | 126 | signature := &tasks.Signature{ 127 | UUID: "testTaskUUID", 128 | GroupUUID: "testGroupUUID", 129 | } 130 | 131 | backend := memcache.New(new(config.Config), []string{memcacheURL}) 132 | 133 | backend.SetStatePending(signature) 134 | taskState, err := backend.GetState(signature.UUID) 135 | assert.NotNil(t, taskState) 136 | assert.NoError(t, err) 137 | 138 | backend.PurgeState(taskState.TaskUUID) 139 | taskState, err = backend.GetState(signature.UUID) 140 | assert.Nil(t, taskState) 141 | assert.Error(t, err) 142 | } 143 | -------------------------------------------------------------------------------- /v2/backends/null/null.go: -------------------------------------------------------------------------------- 1 | package null 2 | 3 | import ( 4 | "fmt" 5 | 6 | "github.com/RichardKnop/machinery/v2/backends/iface" 7 | "github.com/RichardKnop/machinery/v2/common" 8 | "github.com/RichardKnop/machinery/v2/config" 9 | "github.com/RichardKnop/machinery/v2/tasks" 10 | ) 11 | 12 | // ErrGroupNotFound ... 
13 | type ErrGroupNotFound struct { 14 | groupUUID string 15 | } 16 | 17 | // NewErrGroupNotFound returns new instance of ErrGroupNotFound 18 | func NewErrGroupNotFound(groupUUID string) ErrGroupNotFound { 19 | return ErrGroupNotFound{groupUUID: groupUUID} 20 | } 21 | 22 | // Error implements error interface 23 | func (e ErrGroupNotFound) Error() string { 24 | return fmt.Sprintf("Group not found: %v", e.groupUUID) 25 | } 26 | 27 | // ErrTasknotFound ... 28 | type ErrTasknotFound struct { 29 | taskUUID string 30 | } 31 | 32 | // NewErrTasknotFound returns new instance of ErrTasknotFound 33 | func NewErrTasknotFound(taskUUID string) ErrTasknotFound { 34 | return ErrTasknotFound{taskUUID: taskUUID} 35 | } 36 | 37 | // Error implements error interface 38 | func (e ErrTasknotFound) Error() string { 39 | return fmt.Sprintf("Task not found: %v", e.taskUUID) 40 | } 41 | 42 | // Backend represents an "null" result backend 43 | type Backend struct { 44 | common.Backend 45 | groups map[string]struct{} 46 | } 47 | 48 | // New creates NullBackend instance 49 | func New() iface.Backend { 50 | return &Backend{ 51 | Backend: common.NewBackend(new(config.Config)), 52 | groups: make(map[string]struct{}), 53 | } 54 | } 55 | 56 | // InitGroup creates and saves a group meta data object 57 | func (b *Backend) InitGroup(groupUUID string, taskUUIDs []string) error { 58 | b.groups[groupUUID] = struct{}{} 59 | return nil 60 | } 61 | 62 | // GroupCompleted returns true (always) 63 | func (b *Backend) GroupCompleted(groupUUID string, groupTaskCount int) (bool, error) { 64 | _, ok := b.groups[groupUUID] 65 | if !ok { 66 | return false, NewErrGroupNotFound(groupUUID) 67 | } 68 | 69 | return true, nil 70 | } 71 | 72 | // GroupTaskStates returns null states of all tasks in the group 73 | func (b *Backend) GroupTaskStates(groupUUID string, groupTaskCount int) ([]*tasks.TaskState, error) { 74 | _, ok := b.groups[groupUUID] 75 | if !ok { 76 | return nil, NewErrGroupNotFound(groupUUID) 77 | } 78 | 79 
| ret := make([]*tasks.TaskState, 0, groupTaskCount) 80 | return ret, nil 81 | } 82 | 83 | // TriggerChord returns true (always) 84 | func (b *Backend) TriggerChord(groupUUID string) (bool, error) { 85 | return true, nil 86 | } 87 | 88 | // SetStatePending updates task state to PENDING 89 | func (b *Backend) SetStatePending(signature *tasks.Signature) error { 90 | state := tasks.NewPendingTaskState(signature) 91 | return b.updateState(state) 92 | } 93 | 94 | // SetStateReceived updates task state to RECEIVED 95 | func (b *Backend) SetStateReceived(signature *tasks.Signature) error { 96 | state := tasks.NewReceivedTaskState(signature) 97 | return b.updateState(state) 98 | } 99 | 100 | // SetStateStarted updates task state to STARTED 101 | func (b *Backend) SetStateStarted(signature *tasks.Signature) error { 102 | state := tasks.NewStartedTaskState(signature) 103 | return b.updateState(state) 104 | } 105 | 106 | // SetStateRetry updates task state to RETRY 107 | func (b *Backend) SetStateRetry(signature *tasks.Signature) error { 108 | state := tasks.NewRetryTaskState(signature) 109 | return b.updateState(state) 110 | } 111 | 112 | // SetStateSuccess updates task state to SUCCESS 113 | func (b *Backend) SetStateSuccess(signature *tasks.Signature, results []*tasks.TaskResult) error { 114 | state := tasks.NewSuccessTaskState(signature, results) 115 | return b.updateState(state) 116 | } 117 | 118 | // SetStateFailure updates task state to FAILURE 119 | func (b *Backend) SetStateFailure(signature *tasks.Signature, err string) error { 120 | state := tasks.NewFailureTaskState(signature, err) 121 | return b.updateState(state) 122 | } 123 | 124 | // GetState returns the latest task state 125 | func (b *Backend) GetState(taskUUID string) (*tasks.TaskState, error) { 126 | return nil, NewErrTasknotFound(taskUUID) 127 | } 128 | 129 | // PurgeState deletes stored task state 130 | func (b *Backend) PurgeState(taskUUID string) error { 131 | return NewErrTasknotFound(taskUUID) 132 | 
} 133 | 134 | // PurgeGroupMeta deletes stored group meta data 135 | func (b *Backend) PurgeGroupMeta(groupUUID string) error { 136 | _, ok := b.groups[groupUUID] 137 | if !ok { 138 | return NewErrGroupNotFound(groupUUID) 139 | } 140 | 141 | return nil 142 | } 143 | 144 | func (b *Backend) updateState(s *tasks.TaskState) error { 145 | return nil 146 | } 147 | -------------------------------------------------------------------------------- /v2/backends/package.go: -------------------------------------------------------------------------------- 1 | package backends 2 | -------------------------------------------------------------------------------- /v2/backends/redis/goredis_test.go: -------------------------------------------------------------------------------- 1 | package redis_test 2 | 3 | import ( 4 | "github.com/RichardKnop/machinery/v2/backends/iface" 5 | "os" 6 | "strings" 7 | "testing" 8 | 9 | "github.com/RichardKnop/machinery/v2/backends/redis" 10 | "github.com/RichardKnop/machinery/v2/config" 11 | "github.com/RichardKnop/machinery/v2/tasks" 12 | "github.com/stretchr/testify/assert" 13 | ) 14 | 15 | func getRedisG() iface.Backend { 16 | // host1:port1,host2:port2 17 | redisURL := os.Getenv("REDIS_URL_GR") 18 | //redisPassword := os.Getenv("REDIS_PASSWORD") 19 | if redisURL == "" { 20 | return nil 21 | } 22 | backend := redis.NewGR(new(config.Config), strings.Split(redisURL, ","), 0) 23 | return backend 24 | } 25 | 26 | func TestGroupCompletedGR(t *testing.T) { 27 | backend := getRedisG() 28 | if backend == nil { 29 | t.Skip() 30 | } 31 | 32 | groupUUID := "testGroupUUID" 33 | task1 := &tasks.Signature{ 34 | UUID: "testTaskUUID1", 35 | GroupUUID: groupUUID, 36 | } 37 | task2 := &tasks.Signature{ 38 | UUID: "testTaskUUID2", 39 | GroupUUID: groupUUID, 40 | } 41 | 42 | // Cleanup before the test 43 | backend.PurgeState(task1.UUID) 44 | backend.PurgeState(task2.UUID) 45 | backend.PurgeGroupMeta(groupUUID) 46 | 47 | groupCompleted, err := 
backend.GroupCompleted(groupUUID, 2) 48 | if assert.Error(t, err) { 49 | assert.False(t, groupCompleted) 50 | assert.Equal(t, "redis: nil", err.Error()) 51 | } 52 | 53 | backend.InitGroup(groupUUID, []string{task1.UUID, task2.UUID}) 54 | 55 | groupCompleted, err = backend.GroupCompleted(groupUUID, 2) 56 | if assert.Error(t, err) { 57 | assert.False(t, groupCompleted) 58 | assert.Equal(t, "redis: nil", err.Error()) 59 | } 60 | 61 | backend.SetStatePending(task1) 62 | backend.SetStateStarted(task2) 63 | groupCompleted, err = backend.GroupCompleted(groupUUID, 2) 64 | if assert.NoError(t, err) { 65 | assert.False(t, groupCompleted) 66 | } 67 | 68 | taskResults := []*tasks.TaskResult{new(tasks.TaskResult)} 69 | backend.SetStateStarted(task1) 70 | backend.SetStateSuccess(task2, taskResults) 71 | groupCompleted, err = backend.GroupCompleted(groupUUID, 2) 72 | if assert.NoError(t, err) { 73 | assert.False(t, groupCompleted) 74 | } 75 | 76 | backend.SetStateFailure(task1, "Some error") 77 | groupCompleted, err = backend.GroupCompleted(groupUUID, 2) 78 | if assert.NoError(t, err) { 79 | assert.True(t, groupCompleted) 80 | } 81 | } 82 | 83 | func TestGetStateGR(t *testing.T) { 84 | backend := getRedisG() 85 | if backend == nil { 86 | t.Skip() 87 | } 88 | 89 | signature := &tasks.Signature{ 90 | UUID: "testTaskUUID", 91 | GroupUUID: "testGroupUUID", 92 | } 93 | 94 | backend.PurgeState("testTaskUUID") 95 | 96 | var ( 97 | taskState *tasks.TaskState 98 | err error 99 | ) 100 | 101 | taskState, err = backend.GetState(signature.UUID) 102 | assert.Equal(t, "redis: nil", err.Error()) 103 | assert.Nil(t, taskState) 104 | 105 | //Pending State 106 | backend.SetStatePending(signature) 107 | taskState, err = backend.GetState(signature.UUID) 108 | assert.NoError(t, err) 109 | assert.Equal(t, signature.Name, taskState.TaskName) 110 | createdAt := taskState.CreatedAt 111 | 112 | //Received State 113 | backend.SetStateReceived(signature) 114 | taskState, err = 
backend.GetState(signature.UUID) 115 | assert.NoError(t, err) 116 | assert.Equal(t, signature.Name, taskState.TaskName) 117 | assert.Equal(t, createdAt, taskState.CreatedAt) 118 | 119 | //Started State 120 | backend.SetStateStarted(signature) 121 | taskState, err = backend.GetState(signature.UUID) 122 | assert.NoError(t, err) 123 | assert.Equal(t, signature.Name, taskState.TaskName) 124 | assert.Equal(t, createdAt, taskState.CreatedAt) 125 | 126 | //Success State 127 | taskResults := []*tasks.TaskResult{ 128 | { 129 | Type: "float64", 130 | Value: 2, 131 | }, 132 | } 133 | backend.SetStateSuccess(signature, taskResults) 134 | taskState, err = backend.GetState(signature.UUID) 135 | assert.NoError(t, err) 136 | assert.Equal(t, signature.Name, taskState.TaskName) 137 | assert.Equal(t, createdAt, taskState.CreatedAt) 138 | assert.NotNil(t, taskState.Results) 139 | } 140 | 141 | func TestPurgeStateGR(t *testing.T) { 142 | backend := getRedisG() 143 | if backend == nil { 144 | t.Skip() 145 | } 146 | 147 | signature := &tasks.Signature{ 148 | UUID: "testTaskUUID", 149 | GroupUUID: "testGroupUUID", 150 | } 151 | 152 | backend.SetStatePending(signature) 153 | taskState, err := backend.GetState(signature.UUID) 154 | assert.NotNil(t, taskState) 155 | assert.NoError(t, err) 156 | 157 | backend.PurgeState(taskState.TaskUUID) 158 | taskState, err = backend.GetState(signature.UUID) 159 | assert.Nil(t, taskState) 160 | assert.Error(t, err) 161 | } 162 | -------------------------------------------------------------------------------- /v2/brokers/amqp/amqp_concurrence_test.go: -------------------------------------------------------------------------------- 1 | package amqp 2 | 3 | import ( 4 | "fmt" 5 | "github.com/RichardKnop/machinery/v2/brokers/iface" 6 | "github.com/RichardKnop/machinery/v2/config" 7 | "github.com/RichardKnop/machinery/v2/tasks" 8 | amqp "github.com/rabbitmq/amqp091-go" 9 | "testing" 10 | "time" 11 | ) 12 | 13 | type doNothingProcessor struct{} 14 | 15 | func (_ 
doNothingProcessor) Process(signature *tasks.Signature) error { 16 | return fmt.Errorf("failed") 17 | } 18 | 19 | func (_ doNothingProcessor) CustomQueue() string { 20 | return "oops" 21 | } 22 | 23 | func (_ doNothingProcessor) PreConsumeHandler() bool { 24 | return true 25 | } 26 | 27 | func TestConsume(t *testing.T) { 28 | var ( 29 | iBroker iface.Broker 30 | deliveries = make(chan amqp.Delivery, 3) 31 | closeChan chan *amqp.Error 32 | processor doNothingProcessor 33 | ) 34 | 35 | t.Run("with deliveries more than the number of concurrency", func(t *testing.T) { 36 | iBroker = New(&config.Config{}) 37 | broker, _ := iBroker.(*Broker) 38 | errChan := make(chan error) 39 | 40 | // simulate that there are too much deliveries 41 | go func() { 42 | for i := 0; i < 3; i++ { 43 | deliveries <- amqp.Delivery{} // broker.consumeOne() will complain this error: Received an empty message 44 | } 45 | }() 46 | 47 | go func() { 48 | err := broker.consume(deliveries, 2, processor, closeChan) 49 | if err != nil { 50 | errChan <- err 51 | } 52 | }() 53 | 54 | select { 55 | case <-errChan: 56 | case <-time.After(1 * time.Second): 57 | t.Error("Maybe deadlock") 58 | } 59 | }) 60 | } 61 | -------------------------------------------------------------------------------- /v2/brokers/amqp/amqp_test.go: -------------------------------------------------------------------------------- 1 | package amqp_test 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/RichardKnop/machinery/v2/brokers/amqp" 7 | "github.com/RichardKnop/machinery/v2/brokers/iface" 8 | "github.com/RichardKnop/machinery/v2/config" 9 | "github.com/RichardKnop/machinery/v2/tasks" 10 | "github.com/stretchr/testify/assert" 11 | ) 12 | 13 | func TestAdjustRoutingKey(t *testing.T) { 14 | t.Parallel() 15 | 16 | var ( 17 | s *tasks.Signature 18 | broker iface.Broker 19 | ) 20 | 21 | t.Run("with routing and binding keys", func(t *testing.T) { 22 | s := &tasks.Signature{RoutingKey: "routing_key"} 23 | broker = 
amqp.New(&config.Config{ 24 | DefaultQueue: "queue", 25 | AMQP: &config.AMQPConfig{ 26 | ExchangeType: "direct", 27 | BindingKey: "binding_key", 28 | }, 29 | }) 30 | broker.AdjustRoutingKey(s) 31 | assert.Equal(t, "routing_key", s.RoutingKey) 32 | }) 33 | 34 | t.Run("with binding key", func(t *testing.T) { 35 | s = new(tasks.Signature) 36 | broker = amqp.New(&config.Config{ 37 | DefaultQueue: "queue", 38 | AMQP: &config.AMQPConfig{ 39 | ExchangeType: "direct", 40 | BindingKey: "binding_key", 41 | }, 42 | }) 43 | broker.AdjustRoutingKey(s) 44 | assert.Equal(t, "binding_key", s.RoutingKey) 45 | }) 46 | } 47 | -------------------------------------------------------------------------------- /v2/brokers/eager/eager.go: -------------------------------------------------------------------------------- 1 | package eager 2 | 3 | import ( 4 | "bytes" 5 | "context" 6 | "encoding/json" 7 | "errors" 8 | "fmt" 9 | 10 | "github.com/RichardKnop/machinery/v2/brokers/iface" 11 | "github.com/RichardKnop/machinery/v2/common" 12 | "github.com/RichardKnop/machinery/v2/tasks" 13 | ) 14 | 15 | // Broker represents an "eager" in-memory broker 16 | type Broker struct { 17 | worker iface.TaskProcessor 18 | common.Broker 19 | } 20 | 21 | // New creates new Broker instance 22 | func New() iface.Broker { 23 | return new(Broker) 24 | } 25 | 26 | // Mode interface with methods specific for this broker 27 | type Mode interface { 28 | AssignWorker(p iface.TaskProcessor) 29 | } 30 | 31 | // StartConsuming enters a loop and waits for incoming messages 32 | func (eagerBroker *Broker) StartConsuming(consumerTag string, concurrency int, p iface.TaskProcessor) (bool, error) { 33 | return true, nil 34 | } 35 | 36 | // StopConsuming quits the loop 37 | func (eagerBroker *Broker) StopConsuming() { 38 | // do nothing 39 | } 40 | 41 | // Publish places a new message on the default queue 42 | func (eagerBroker *Broker) Publish(ctx context.Context, task *tasks.Signature) error { 43 | if eagerBroker.worker == nil 
{ 44 | return errors.New("worker is not assigned in eager-mode") 45 | } 46 | 47 | // faking the behavior to marshal input into json 48 | // and unmarshal it back 49 | message, err := json.Marshal(task) 50 | if err != nil { 51 | return fmt.Errorf("JSON marshal error: %s", err) 52 | } 53 | 54 | signature := new(tasks.Signature) 55 | decoder := json.NewDecoder(bytes.NewReader(message)) 56 | decoder.UseNumber() 57 | if err := decoder.Decode(signature); err != nil { 58 | return fmt.Errorf("JSON unmarshal error: %s", err) 59 | } 60 | 61 | // blocking call to the task directly 62 | return eagerBroker.worker.Process(signature) 63 | } 64 | 65 | // AssignWorker assigns a worker to the eager broker 66 | func (eagerBroker *Broker) AssignWorker(w iface.TaskProcessor) { 67 | eagerBroker.worker = w 68 | } 69 | -------------------------------------------------------------------------------- /v2/brokers/errs/errors.go: -------------------------------------------------------------------------------- 1 | package errs 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | ) 7 | 8 | // ErrCouldNotUnmarshalTaskSignature ... 9 | type ErrCouldNotUnmarshalTaskSignature struct { 10 | msg []byte 11 | reason string 12 | } 13 | 14 | // Error implements the error interface 15 | func (e ErrCouldNotUnmarshalTaskSignature) Error() string { 16 | return fmt.Sprintf("Could not unmarshal '%s' into a task signature: %v", e.msg, e.reason) 17 | } 18 | 19 | // NewErrCouldNotUnmarshalTaskSignature returns new ErrCouldNotUnmarshalTaskSignature instance 20 | func NewErrCouldNotUnmarshalTaskSignature(msg []byte, err error) ErrCouldNotUnmarshalTaskSignature { 21 | return ErrCouldNotUnmarshalTaskSignature{msg: msg, reason: err.Error()} 22 | } 23 | 24 | // ErrConsumerStopped indicates that the operation is now illegal because of the consumer being stopped. 
25 | var ErrConsumerStopped = errors.New("the server has been stopped") 26 | 27 | // ErrStopTaskDeletion indicates that the task should not be deleted from source after task failure 28 | var ErrStopTaskDeletion = errors.New("task should not be deleted") 29 | -------------------------------------------------------------------------------- /v2/brokers/iface/interfaces.go: -------------------------------------------------------------------------------- 1 | package iface 2 | 3 | import ( 4 | "context" 5 | 6 | "github.com/RichardKnop/machinery/v2/config" 7 | "github.com/RichardKnop/machinery/v2/tasks" 8 | ) 9 | 10 | // Broker - a common interface for all brokers 11 | type Broker interface { 12 | GetConfig() *config.Config 13 | SetRegisteredTaskNames(names []string) 14 | IsTaskRegistered(name string) bool 15 | StartConsuming(consumerTag string, concurrency int, p TaskProcessor) (bool, error) 16 | StopConsuming() 17 | Publish(ctx context.Context, task *tasks.Signature) error 18 | GetPendingTasks(queue string) ([]*tasks.Signature, error) 19 | GetDelayedTasks() ([]*tasks.Signature, error) 20 | AdjustRoutingKey(s *tasks.Signature) 21 | } 22 | 23 | // TaskProcessor - can process a delivered task 24 | // This will probably always be a worker instance 25 | type TaskProcessor interface { 26 | Process(signature *tasks.Signature) error 27 | CustomQueue() string 28 | PreConsumeHandler() bool 29 | } 30 | -------------------------------------------------------------------------------- /v2/brokers/package.go: -------------------------------------------------------------------------------- 1 | package brokers 2 | -------------------------------------------------------------------------------- /v2/common/amqp.go: -------------------------------------------------------------------------------- 1 | package common 2 | 3 | import ( 4 | "crypto/tls" 5 | "fmt" 6 | "strings" 7 | 8 | amqp "github.com/rabbitmq/amqp091-go" 9 | ) 10 | 11 | // AMQPConnector ... 
12 | type AMQPConnector struct{} 13 | 14 | // Connect opens a connection to RabbitMQ, declares an exchange, opens a channel, 15 | // declares and binds the queue and enables publish notifications 16 | func (ac *AMQPConnector) Connect(urls string, urlSeparator string, tlsConfig *tls.Config, exchange, exchangeType, queueName string, queueDurable, queueDelete bool, queueBindingKey string, exchangeDeclareArgs, queueDeclareArgs, queueBindingArgs amqp.Table) (*amqp.Connection, *amqp.Channel, amqp.Queue, <-chan amqp.Confirmation, <-chan *amqp.Error, error) { 17 | urlsList := []string{urls} 18 | if urlSeparator != "" { 19 | urlsList = strings.Split(urls, urlSeparator) 20 | } 21 | 22 | var conn *amqp.Connection 23 | var channel *amqp.Channel 24 | var err error 25 | 26 | for _, url := range urlsList { 27 | // Connect to server 28 | conn, channel, err = ac.Open(url, tlsConfig) 29 | if err != nil { 30 | continue 31 | } else { 32 | break 33 | } 34 | } 35 | 36 | if err != nil { 37 | return nil, nil, amqp.Queue{}, nil, nil, err 38 | } 39 | 40 | if exchange != "" { 41 | // Declare an exchange 42 | if err = channel.ExchangeDeclare( 43 | exchange, // name of the exchange 44 | exchangeType, // type 45 | true, // durable 46 | false, // delete when complete 47 | false, // internal 48 | false, // noWait 49 | exchangeDeclareArgs, // arguments 50 | ); err != nil { 51 | return conn, channel, amqp.Queue{}, nil, nil, fmt.Errorf("Exchange declare error: %s", err) 52 | } 53 | } 54 | 55 | var queue amqp.Queue 56 | if queueName != "" { 57 | // Declare a queue 58 | queue, err = channel.QueueDeclare( 59 | queueName, // name 60 | queueDurable, // durable 61 | queueDelete, // delete when unused 62 | false, // exclusive 63 | false, // no-wait 64 | queueDeclareArgs, // arguments 65 | ) 66 | if err != nil { 67 | return conn, channel, amqp.Queue{}, nil, nil, fmt.Errorf("Queue declare error: %s", err) 68 | } 69 | 70 | // Bind the queue 71 | if err = channel.QueueBind( 72 | queue.Name, // name of the 
queue 73 | queueBindingKey, // binding key 74 | exchange, // source exchange 75 | false, // noWait 76 | queueBindingArgs, // arguments 77 | ); err != nil { 78 | return conn, channel, queue, nil, nil, fmt.Errorf("Queue bind error: %s", err) 79 | } 80 | } 81 | 82 | // Enable publish confirmations 83 | if err = channel.Confirm(false); err != nil { 84 | return conn, channel, queue, nil, nil, fmt.Errorf("Channel could not be put into confirm mode: %s", err) 85 | } 86 | 87 | return conn, channel, queue, channel.NotifyPublish(make(chan amqp.Confirmation, 1)), conn.NotifyClose(make(chan *amqp.Error, 1)), nil 88 | } 89 | 90 | // DeleteQueue deletes a queue by name 91 | func (ac *AMQPConnector) DeleteQueue(channel *amqp.Channel, queueName string) error { 92 | // First return value is number of messages removed 93 | _, err := channel.QueueDelete( 94 | queueName, // name 95 | false, // ifUnused 96 | false, // ifEmpty 97 | false, // noWait 98 | ) 99 | 100 | return err 101 | } 102 | 103 | // InspectQueue provides information about a specific queue 104 | func (*AMQPConnector) InspectQueue(channel *amqp.Channel, queueName string) (*amqp.Queue, error) { 105 | queueState, err := channel.QueueInspect(queueName) 106 | if err != nil { 107 | return nil, fmt.Errorf("Queue inspect error: %s", err) 108 | } 109 | 110 | return &queueState, nil 111 | } 112 | 113 | // Open new RabbitMQ connection 114 | func (ac *AMQPConnector) Open(url string, tlsConfig *tls.Config) (*amqp.Connection, *amqp.Channel, error) { 115 | // Connect 116 | // From amqp docs: DialTLS will use the provided tls.Config when it encounters an amqps:// scheme 117 | // and will dial a plain connection when it encounters an amqp:// scheme. 
118 | conn, err := amqp.DialTLS(url, tlsConfig) 119 | if err != nil { 120 | return nil, nil, fmt.Errorf("Dial error: %s", err) 121 | } 122 | 123 | // Open a channel 124 | channel, err := conn.Channel() 125 | if err != nil { 126 | return nil, nil, fmt.Errorf("Open channel error: %s", err) 127 | } 128 | 129 | return conn, channel, nil 130 | } 131 | 132 | // Close connection 133 | func (ac *AMQPConnector) Close(channel *amqp.Channel, conn *amqp.Connection) error { 134 | if channel != nil { 135 | if err := channel.Close(); err != nil { 136 | return fmt.Errorf("Close channel error: %s", err) 137 | } 138 | } 139 | 140 | if conn != nil { 141 | if err := conn.Close(); err != nil { 142 | return fmt.Errorf("Close connection error: %s", err) 143 | } 144 | } 145 | 146 | return nil 147 | } 148 | -------------------------------------------------------------------------------- /v2/common/backend.go: -------------------------------------------------------------------------------- 1 | package common 2 | 3 | import ( 4 | "github.com/RichardKnop/machinery/v2/config" 5 | ) 6 | 7 | // Backend represents a base backend structure 8 | type Backend struct { 9 | cnf *config.Config 10 | } 11 | 12 | // NewBackend creates new Backend instance 13 | func NewBackend(cnf *config.Config) Backend { 14 | return Backend{cnf: cnf} 15 | } 16 | 17 | // GetConfig returns config 18 | func (b *Backend) GetConfig() *config.Config { 19 | return b.cnf 20 | } 21 | 22 | // IsAMQP ... 
23 | func (b *Backend) IsAMQP() bool { 24 | return false 25 | } 26 | -------------------------------------------------------------------------------- /v2/common/broker.go: -------------------------------------------------------------------------------- 1 | package common 2 | 3 | import ( 4 | "errors" 5 | "sync" 6 | 7 | "github.com/RichardKnop/machinery/v2/brokers/iface" 8 | "github.com/RichardKnop/machinery/v2/config" 9 | "github.com/RichardKnop/machinery/v2/log" 10 | "github.com/RichardKnop/machinery/v2/retry" 11 | "github.com/RichardKnop/machinery/v2/tasks" 12 | ) 13 | 14 | type registeredTaskNames struct { 15 | sync.RWMutex 16 | items []string 17 | } 18 | 19 | // Broker represents a base broker structure 20 | type Broker struct { 21 | cnf *config.Config 22 | registeredTaskNames registeredTaskNames 23 | retry bool 24 | retryFunc func(chan int) 25 | retryStopChan chan int 26 | stopChan chan int 27 | } 28 | 29 | // NewBroker creates new Broker instance 30 | func NewBroker(cnf *config.Config) Broker { 31 | return Broker{ 32 | cnf: cnf, 33 | retry: true, 34 | stopChan: make(chan int), 35 | retryStopChan: make(chan int), 36 | } 37 | } 38 | 39 | // GetConfig returns config 40 | func (b *Broker) GetConfig() *config.Config { 41 | return b.cnf 42 | } 43 | 44 | // GetRetry ... 45 | func (b *Broker) GetRetry() bool { 46 | return b.retry 47 | } 48 | 49 | // GetRetryFunc ... 50 | func (b *Broker) GetRetryFunc() func(chan int) { 51 | return b.retryFunc 52 | } 53 | 54 | // GetRetryStopChan ... 55 | func (b *Broker) GetRetryStopChan() chan int { 56 | return b.retryStopChan 57 | } 58 | 59 | // GetStopChan ... 
60 | func (b *Broker) GetStopChan() chan int { 61 | return b.stopChan 62 | } 63 | 64 | // Publish places a new message on the default queue 65 | func (b *Broker) Publish(signature *tasks.Signature) error { 66 | return errors.New("Not implemented") 67 | } 68 | 69 | // SetRegisteredTaskNames sets registered task names 70 | func (b *Broker) SetRegisteredTaskNames(names []string) { 71 | b.registeredTaskNames.Lock() 72 | defer b.registeredTaskNames.Unlock() 73 | b.registeredTaskNames.items = names 74 | } 75 | 76 | // IsTaskRegistered returns true if the task is registered with this broker 77 | func (b *Broker) IsTaskRegistered(name string) bool { 78 | b.registeredTaskNames.RLock() 79 | defer b.registeredTaskNames.RUnlock() 80 | for _, registeredTaskName := range b.registeredTaskNames.items { 81 | if registeredTaskName == name { 82 | return true 83 | } 84 | } 85 | return false 86 | } 87 | 88 | // GetPendingTasks returns a slice of task.Signatures waiting in the queue 89 | func (b *Broker) GetPendingTasks(queue string) ([]*tasks.Signature, error) { 90 | return nil, errors.New("Not implemented") 91 | } 92 | 93 | // GetDelayedTasks returns a slice of task.Signatures that are scheduled, but not yet in the queue 94 | func (b *Broker) GetDelayedTasks() ([]*tasks.Signature, error) { 95 | return nil, errors.New("Not implemented") 96 | } 97 | 98 | // StartConsuming is a common part of StartConsuming method 99 | func (b *Broker) StartConsuming(consumerTag string, concurrency int, taskProcessor iface.TaskProcessor) { 100 | if b.retryFunc == nil { 101 | b.retryFunc = retry.Closure() 102 | } 103 | 104 | } 105 | 106 | // StopConsuming is a common part of StopConsuming 107 | func (b *Broker) StopConsuming() { 108 | // Do not retry from now on 109 | b.retry = false 110 | // Stop the retry closure earlier 111 | select { 112 | case b.retryStopChan <- 1: 113 | log.WARNING.Print("Stopping retry closure.") 114 | default: 115 | } 116 | // Notifying the stop channel stops consuming of messages 
117 | close(b.stopChan) 118 | log.WARNING.Print("Stop channel") 119 | } 120 | 121 | // GetRegisteredTaskNames returns registered tasks names 122 | func (b *Broker) GetRegisteredTaskNames() []string { 123 | b.registeredTaskNames.RLock() 124 | defer b.registeredTaskNames.RUnlock() 125 | items := b.registeredTaskNames.items 126 | return items 127 | } 128 | 129 | // AdjustRoutingKey makes sure the routing key is correct. 130 | // If the routing key is an empty string: 131 | // a) set it to binding key for direct exchange type 132 | // b) set it to default queue name 133 | func (b *Broker) AdjustRoutingKey(s *tasks.Signature) { 134 | if s.RoutingKey != "" { 135 | return 136 | } 137 | 138 | s.RoutingKey = b.GetConfig().DefaultQueue 139 | } 140 | -------------------------------------------------------------------------------- /v2/common/broker_test.go: -------------------------------------------------------------------------------- 1 | package common_test 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/RichardKnop/machinery/v2" 7 | "github.com/RichardKnop/machinery/v2/common" 8 | "github.com/RichardKnop/machinery/v2/config" 9 | "github.com/RichardKnop/machinery/v2/tasks" 10 | "github.com/stretchr/testify/assert" 11 | ) 12 | 13 | func TestIsTaskRegistered(t *testing.T) { 14 | t.Parallel() 15 | 16 | broker := common.NewBroker(new(config.Config)) 17 | broker.SetRegisteredTaskNames([]string{"foo", "bar"}) 18 | 19 | assert.True(t, broker.IsTaskRegistered("foo")) 20 | assert.False(t, broker.IsTaskRegistered("bogus")) 21 | } 22 | 23 | func TestAdjustRoutingKey(t *testing.T) { 24 | t.Parallel() 25 | 26 | var ( 27 | s *tasks.Signature 28 | broker common.Broker 29 | ) 30 | 31 | t.Run("with routing key", func(t *testing.T) { 32 | s = &tasks.Signature{RoutingKey: "routing_key"} 33 | broker = common.NewBroker(&config.Config{ 34 | DefaultQueue: "queue", 35 | }) 36 | broker.AdjustRoutingKey(s) 37 | assert.Equal(t, "routing_key", s.RoutingKey) 38 | }) 39 | 40 | t.Run("without routing 
key", func(t *testing.T) { 41 | s = new(tasks.Signature) 42 | broker = common.NewBroker(&config.Config{ 43 | DefaultQueue: "queue", 44 | }) 45 | broker.AdjustRoutingKey(s) 46 | assert.Equal(t, "queue", s.RoutingKey) 47 | }) 48 | } 49 | 50 | func TestGetRegisteredTaskNames(t *testing.T) { 51 | t.Parallel() 52 | 53 | broker := common.NewBroker(new(config.Config)) 54 | fooTasks := []string{"foo", "bar", "baz"} 55 | broker.SetRegisteredTaskNames(fooTasks) 56 | assert.Equal(t, fooTasks, broker.GetRegisteredTaskNames()) 57 | } 58 | 59 | func TestStopConsuming(t *testing.T) { 60 | t.Parallel() 61 | 62 | t.Run("stop consuming", func(t *testing.T) { 63 | broker := common.NewBroker(&config.Config{ 64 | DefaultQueue: "queue", 65 | }) 66 | broker.StartConsuming("", 1, &machinery.Worker{}) 67 | broker.StopConsuming() 68 | select { 69 | case <-broker.GetStopChan(): 70 | default: 71 | assert.Fail(t, "still blocking") 72 | } 73 | }) 74 | } 75 | -------------------------------------------------------------------------------- /v2/common/redis.go: -------------------------------------------------------------------------------- 1 | package common 2 | 3 | import ( 4 | "crypto/tls" 5 | "time" 6 | 7 | "github.com/gomodule/redigo/redis" 8 | 9 | "github.com/RichardKnop/machinery/v2/config" 10 | ) 11 | 12 | var ( 13 | defaultConfig = &config.RedisConfig{ 14 | MaxIdle: 10, 15 | MaxActive: 100, 16 | IdleTimeout: 300, 17 | Wait: true, 18 | ReadTimeout: 15, 19 | WriteTimeout: 15, 20 | ConnectTimeout: 15, 21 | NormalTasksPollPeriod: 1000, 22 | DelayedTasksPollPeriod: 20, 23 | } 24 | ) 25 | 26 | // RedisConnector ... 
27 | type RedisConnector struct{} 28 | 29 | // NewPool returns a new pool of Redis connections 30 | func (rc *RedisConnector) NewPool(socketPath, host, username, password string, db int, cnf *config.RedisConfig, tlsConfig *tls.Config) *redis.Pool { 31 | if cnf == nil { 32 | cnf = defaultConfig 33 | } 34 | return &redis.Pool{ 35 | MaxIdle: cnf.MaxIdle, 36 | IdleTimeout: time.Duration(cnf.IdleTimeout) * time.Second, 37 | MaxActive: cnf.MaxActive, 38 | Wait: cnf.Wait, 39 | Dial: func() (redis.Conn, error) { 40 | c, err := rc.open(socketPath, host, username, password, db, cnf, tlsConfig) 41 | if err != nil { 42 | return nil, err 43 | } 44 | 45 | if db != 0 { 46 | _, err = c.Do("SELECT", db) 47 | if err != nil { 48 | return nil, err 49 | } 50 | } 51 | 52 | return c, err 53 | }, 54 | // PINGs connections that have been idle more than 10 seconds 55 | TestOnBorrow: func(c redis.Conn, t time.Time) error { 56 | if time.Since(t) < time.Duration(10*time.Second) { 57 | return nil 58 | } 59 | _, err := c.Do("PING") 60 | return err 61 | }, 62 | } 63 | } 64 | 65 | // Open a new Redis connection 66 | func (rc *RedisConnector) open(socketPath, host, username, password string, db int, cnf *config.RedisConfig, tlsConfig *tls.Config) (redis.Conn, error) { 67 | var opts = []redis.DialOption{ 68 | redis.DialDatabase(db), 69 | redis.DialReadTimeout(time.Duration(cnf.ReadTimeout) * time.Second), 70 | redis.DialWriteTimeout(time.Duration(cnf.WriteTimeout) * time.Second), 71 | redis.DialConnectTimeout(time.Duration(cnf.ConnectTimeout) * time.Second), 72 | redis.DialClientName(cnf.ClientName), 73 | } 74 | 75 | if tlsConfig != nil { 76 | opts = append(opts, redis.DialTLSConfig(tlsConfig), redis.DialUseTLS(true)) 77 | } 78 | if username != "" { 79 | opts = append(opts, redis.DialUsername(username)) 80 | } 81 | 82 | if password != "" { 83 | opts = append(opts, redis.DialPassword(password)) 84 | } 85 | 86 | if socketPath != "" { 87 | return redis.Dial("unix", socketPath, opts...) 
88 | } 89 | 90 | return redis.Dial("tcp", host, opts...) 91 | } 92 | -------------------------------------------------------------------------------- /v2/config/env.go: -------------------------------------------------------------------------------- 1 | package config 2 | 3 | import ( 4 | "github.com/kelseyhightower/envconfig" 5 | 6 | "github.com/RichardKnop/machinery/v2/log" 7 | ) 8 | 9 | // NewFromEnvironment creates a config object from environment variables 10 | func NewFromEnvironment() (*Config, error) { 11 | cnf, err := fromEnvironment() 12 | if err != nil { 13 | return nil, err 14 | } 15 | 16 | log.INFO.Print("Successfully loaded config from the environment") 17 | 18 | return cnf, nil 19 | } 20 | 21 | func fromEnvironment() (*Config, error) { 22 | loadedCnf, cnf := new(Config), new(Config) 23 | *cnf = *defaultCnf 24 | 25 | if err := envconfig.Process("", cnf); err != nil { 26 | return nil, err 27 | } 28 | if err := envconfig.Process("", loadedCnf); err != nil { 29 | return nil, err 30 | } 31 | 32 | if loadedCnf.AMQP == nil { 33 | cnf.AMQP = nil 34 | } 35 | 36 | return cnf, nil 37 | } 38 | -------------------------------------------------------------------------------- /v2/config/env_test.go: -------------------------------------------------------------------------------- 1 | package config_test 2 | 3 | import ( 4 | "bufio" 5 | "os" 6 | "strings" 7 | "testing" 8 | 9 | "github.com/RichardKnop/machinery/v2/config" 10 | "github.com/stretchr/testify/assert" 11 | ) 12 | 13 | func TestNewFromEnvironment(t *testing.T) { 14 | t.Parallel() 15 | 16 | file, err := os.Open("test.env") 17 | if err != nil { 18 | t.Fatal(err) 19 | } 20 | reader := bufio.NewReader(file) 21 | scanner := bufio.NewScanner(reader) 22 | scanner.Split(bufio.ScanLines) 23 | for scanner.Scan() { 24 | parts := strings.Split(scanner.Text(), "=") 25 | if len(parts) != 2 { 26 | continue 27 | } 28 | os.Setenv(parts[0], parts[1]) 29 | } 30 | 31 | cnf, err := config.NewFromEnvironment() 32 | if err != nil 
{ 33 | t.Fatal(err) 34 | } 35 | 36 | assert.Equal(t, "broker", cnf.Broker) 37 | assert.Equal(t, "default_queue", cnf.DefaultQueue) 38 | assert.Equal(t, "result_backend", cnf.ResultBackend) 39 | assert.Equal(t, 123456, cnf.ResultsExpireIn) 40 | assert.Equal(t, "exchange", cnf.AMQP.Exchange) 41 | assert.Equal(t, "exchange_type", cnf.AMQP.ExchangeType) 42 | assert.Equal(t, "binding_key", cnf.AMQP.BindingKey) 43 | assert.Equal(t, "any", cnf.AMQP.QueueBindingArgs["x-match"]) 44 | assert.Equal(t, "png", cnf.AMQP.QueueBindingArgs["image-type"]) 45 | assert.Equal(t, 123, cnf.AMQP.PrefetchCount) 46 | } 47 | -------------------------------------------------------------------------------- /v2/config/file.go: -------------------------------------------------------------------------------- 1 | package config 2 | 3 | import ( 4 | "fmt" 5 | "os" 6 | "time" 7 | 8 | "github.com/RichardKnop/machinery/v2/log" 9 | "gopkg.in/yaml.v2" 10 | ) 11 | 12 | // NewFromYaml creates a config object from YAML file 13 | func NewFromYaml(cnfPath string, keepReloading bool) (*Config, error) { 14 | cnf, err := fromFile(cnfPath) 15 | if err != nil { 16 | return nil, err 17 | } 18 | 19 | log.INFO.Printf("Successfully loaded config from file %s", cnfPath) 20 | 21 | if keepReloading { 22 | // Open a goroutine to watch remote changes forever 23 | go func() { 24 | for { 25 | // Delay after each request 26 | time.Sleep(reloadDelay) 27 | 28 | // Attempt to reload the config 29 | newCnf, newErr := fromFile(cnfPath) 30 | if newErr != nil { 31 | log.WARNING.Printf("Failed to reload config from file %s: %v", cnfPath, newErr) 32 | continue 33 | } 34 | 35 | *cnf = *newCnf 36 | } 37 | }() 38 | } 39 | 40 | return cnf, nil 41 | } 42 | 43 | // ReadFromFile reads data from a file 44 | func ReadFromFile(cnfPath string) ([]byte, error) { 45 | file, err := os.Open(cnfPath) 46 | 47 | // Config file not found 48 | if err != nil { 49 | return nil, fmt.Errorf("Open file error: %s", err) 50 | } 51 | defer file.Close() 52 | 53 
| // Config file found, let's try to read it 54 | data := make([]byte, 1000) 55 | count, err := file.Read(data) 56 | if err != nil { 57 | return nil, fmt.Errorf("Read from file error: %s", err) 58 | } 59 | 60 | return data[:count], nil 61 | } 62 | 63 | func fromFile(cnfPath string) (*Config, error) { 64 | loadedCnf, cnf := new(Config), new(Config) 65 | *cnf = *defaultCnf 66 | 67 | data, err := ReadFromFile(cnfPath) 68 | if err != nil { 69 | return nil, err 70 | } 71 | 72 | if err := yaml.Unmarshal(data, cnf); err != nil { 73 | return nil, fmt.Errorf("Unmarshal YAML error: %s", err) 74 | } 75 | if err := yaml.Unmarshal(data, loadedCnf); err != nil { 76 | return nil, fmt.Errorf("Unmarshal YAML error: %s", err) 77 | } 78 | if loadedCnf.AMQP == nil { 79 | cnf.AMQP = nil 80 | } 81 | 82 | return cnf, nil 83 | } 84 | -------------------------------------------------------------------------------- /v2/config/file_test.go: -------------------------------------------------------------------------------- 1 | package config_test 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/RichardKnop/machinery/v2/config" 7 | "github.com/stretchr/testify/assert" 8 | ) 9 | 10 | var configYAMLData = `--- 11 | broker: broker 12 | default_queue: default_queue 13 | result_backend: result_backend 14 | results_expire_in: 123456 15 | amqp: 16 | binding_key: binding_key 17 | exchange: exchange 18 | exchange_type: exchange_type 19 | prefetch_count: 123 20 | queue_declare_args: 21 | x-max-priority: 10 22 | queue_binding_args: 23 | image-type: png 24 | x-match: any 25 | sqs: 26 | receive_wait_time_seconds: 123 27 | receive_visibility_timeout: 456 28 | redis: 29 | max_idle: 12 30 | max_active: 123 31 | max_idle_timeout: 456 32 | wait: false 33 | read_timeout: 17 34 | write_timeout: 19 35 | connect_timeout: 21 36 | normal_tasks_poll_period: 1001 37 | delayed_tasks_poll_period: 23 38 | delayed_tasks_key: delayed_tasks_key 39 | master_name: master_name 40 | no_unix_signals: true 41 | dynamodb: 42 | 
task_states_table: task_states_table 43 | group_metas_table: group_metas_table 44 | ` 45 | 46 | func TestReadFromFile(t *testing.T) { 47 | t.Parallel() 48 | 49 | data, err := config.ReadFromFile("testconfig.yml") 50 | if err != nil { 51 | t.Fatal(err) 52 | } 53 | 54 | assert.Equal(t, configYAMLData, string(data)) 55 | } 56 | 57 | func TestNewFromYaml(t *testing.T) { 58 | t.Parallel() 59 | 60 | cnf, err := config.NewFromYaml("testconfig.yml", false) 61 | if err != nil { 62 | t.Fatal(err) 63 | } 64 | 65 | assert.Equal(t, "broker", cnf.Broker) 66 | assert.Equal(t, "default_queue", cnf.DefaultQueue) 67 | assert.Equal(t, "result_backend", cnf.ResultBackend) 68 | assert.Equal(t, 123456, cnf.ResultsExpireIn) 69 | 70 | assert.Equal(t, "exchange", cnf.AMQP.Exchange) 71 | assert.Equal(t, "exchange_type", cnf.AMQP.ExchangeType) 72 | assert.Equal(t, "binding_key", cnf.AMQP.BindingKey) 73 | assert.Equal(t, 10, cnf.AMQP.QueueDeclareArgs["x-max-priority"]) 74 | assert.Equal(t, "any", cnf.AMQP.QueueBindingArgs["x-match"]) 75 | assert.Equal(t, "png", cnf.AMQP.QueueBindingArgs["image-type"]) 76 | assert.Equal(t, 123, cnf.AMQP.PrefetchCount) 77 | 78 | assert.Equal(t, 123, cnf.SQS.WaitTimeSeconds) 79 | assert.Equal(t, 456, *cnf.SQS.VisibilityTimeout) 80 | 81 | assert.Equal(t, 12, cnf.Redis.MaxIdle) 82 | assert.Equal(t, 123, cnf.Redis.MaxActive) 83 | assert.Equal(t, 456, cnf.Redis.IdleTimeout) 84 | assert.Equal(t, false, cnf.Redis.Wait) 85 | assert.Equal(t, 17, cnf.Redis.ReadTimeout) 86 | assert.Equal(t, 19, cnf.Redis.WriteTimeout) 87 | assert.Equal(t, 21, cnf.Redis.ConnectTimeout) 88 | assert.Equal(t, 1001, cnf.Redis.NormalTasksPollPeriod) 89 | assert.Equal(t, 23, cnf.Redis.DelayedTasksPollPeriod) 90 | assert.Equal(t, "delayed_tasks_key", cnf.Redis.DelayedTasksKey) 91 | assert.Equal(t, "master_name", cnf.Redis.MasterName) 92 | 93 | assert.Equal(t, true, cnf.NoUnixSignals) 94 | 95 | assert.Equal(t, "task_states_table", cnf.DynamoDB.TaskStatesTable) 96 | assert.Equal(t, 
"group_metas_table", cnf.DynamoDB.GroupMetasTable) 97 | } 98 | -------------------------------------------------------------------------------- /v2/config/test.env: -------------------------------------------------------------------------------- 1 | BROKER=broker 2 | DEFAULT_QUEUE=default_queue 3 | RESULT_BACKEND=result_backend 4 | RESULTS_EXPIRE_IN=123456 5 | AMQP_BINDING_KEY=binding_key 6 | AMQP_EXCHANGE=exchange 7 | AMQP_EXCHANGE_TYPE=exchange_type 8 | AMQP_PREFETCH_COUNT=123 9 | AMQP_QUEUE_BINDING_ARGS=image-type:png,x-match:any 10 | -------------------------------------------------------------------------------- /v2/config/testconfig.yml: -------------------------------------------------------------------------------- 1 | --- 2 | broker: broker 3 | default_queue: default_queue 4 | result_backend: result_backend 5 | results_expire_in: 123456 6 | amqp: 7 | binding_key: binding_key 8 | exchange: exchange 9 | exchange_type: exchange_type 10 | prefetch_count: 123 11 | queue_declare_args: 12 | x-max-priority: 10 13 | queue_binding_args: 14 | image-type: png 15 | x-match: any 16 | sqs: 17 | receive_wait_time_seconds: 123 18 | receive_visibility_timeout: 456 19 | redis: 20 | max_idle: 12 21 | max_active: 123 22 | max_idle_timeout: 456 23 | wait: false 24 | read_timeout: 17 25 | write_timeout: 19 26 | connect_timeout: 21 27 | normal_tasks_poll_period: 1001 28 | delayed_tasks_poll_period: 23 29 | delayed_tasks_key: delayed_tasks_key 30 | master_name: master_name 31 | no_unix_signals: true 32 | dynamodb: 33 | task_states_table: task_states_table 34 | group_metas_table: group_metas_table 35 | -------------------------------------------------------------------------------- /v2/example/tasks/tasks.go: -------------------------------------------------------------------------------- 1 | package exampletasks 2 | 3 | import ( 4 | "errors" 5 | "strings" 6 | "time" 7 | 8 | "github.com/RichardKnop/machinery/v2/log" 9 | ) 10 | 11 | // Add ... 
12 | func Add(args ...int64) (int64, error) { 13 | sum := int64(0) 14 | for _, arg := range args { 15 | sum += arg 16 | } 17 | return sum, nil 18 | } 19 | 20 | // Multiply ... 21 | func Multiply(args ...int64) (int64, error) { 22 | sum := int64(1) 23 | for _, arg := range args { 24 | sum *= arg 25 | } 26 | return sum, nil 27 | } 28 | 29 | // SumInts ... 30 | func SumInts(numbers []int64) (int64, error) { 31 | var sum int64 32 | for _, num := range numbers { 33 | sum += num 34 | } 35 | return sum, nil 36 | } 37 | 38 | // SumFloats ... 39 | func SumFloats(numbers []float64) (float64, error) { 40 | var sum float64 41 | for _, num := range numbers { 42 | sum += num 43 | } 44 | return sum, nil 45 | } 46 | 47 | // Concat ... 48 | func Concat(strs []string) (string, error) { 49 | var res string 50 | for _, s := range strs { 51 | res += s 52 | } 53 | return res, nil 54 | } 55 | 56 | // Split ... 57 | func Split(str string) ([]string, error) { 58 | return strings.Split(str, ""), nil 59 | } 60 | 61 | // PanicTask ... 62 | func PanicTask() (string, error) { 63 | panic(errors.New("oops")) 64 | } 65 | 66 | // LongRunningTask ... 67 | func LongRunningTask() error { 68 | log.INFO.Print("Long running task started") 69 | for i := 0; i < 10; i++ { 70 | log.INFO.Print(10 - i) 71 | time.Sleep(1 * time.Second) 72 | } 73 | log.INFO.Print("Long running task finished") 74 | return nil 75 | } 76 | -------------------------------------------------------------------------------- /v2/example/tracers/jaeger.go: -------------------------------------------------------------------------------- 1 | package tracers 2 | 3 | // Uncomment the import statement for the jaeger tracer. 4 | // make sure you run dep ensure to pull in the jaeger client 5 | // 6 | // import ( 7 | // jaeger "github.com/uber/jaeger-client-go" 8 | // jaegercfg "github.com/uber/jaeger-client-go/config" 9 | // ) 10 | 11 | // SetupTracer is the place where you'd setup your specific tracer. 
12 | // The jaeger tracer is given as an example. 13 | // To capture the jaeger traces you should run the jaeger backend. 14 | // This can be done using the following docker command: 15 | // 16 | // `docker run -ti --rm -p6831:6831/udp -p16686:16686 jaegertracing/all-in-one:latest` 17 | // 18 | // The collector will be listening on localhost:6831 19 | // and the query UI is reachable on localhost:16686. 20 | func SetupTracer(serviceName string) (func(), error) { 21 | 22 | // Jaeger setup code 23 | // 24 | // config := jaegercfg.Configuration{ 25 | // Sampler: &jaegercfg.SamplerConfig{ 26 | // Type: jaeger.SamplerTypeConst, 27 | // Param: 1, 28 | // }, 29 | // } 30 | 31 | // closer, err := config.InitGlobalTracer(serviceName) 32 | // if err != nil { 33 | // return nil, err 34 | // } 35 | 36 | cleanupFunc := func() { 37 | // closer.Close() 38 | } 39 | 40 | return cleanupFunc, nil 41 | } 42 | -------------------------------------------------------------------------------- /v2/go.mod: -------------------------------------------------------------------------------- 1 | module github.com/RichardKnop/machinery/v2 2 | 3 | go 1.15 4 | 5 | require ( 6 | cloud.google.com/go/pubsub v1.10.0 7 | github.com/RichardKnop/logging v0.0.0-20190827224416-1a693bdd4fae 8 | github.com/aws/aws-sdk-go v1.37.16 9 | github.com/bradfitz/gomemcache v0.0.0-20190913173617-a41fca850d0b 10 | github.com/go-redsync/redsync/v4 v4.8.1 11 | github.com/gomodule/redigo v1.9.2 12 | github.com/google/uuid v1.2.0 13 | github.com/kelseyhightower/envconfig v1.4.0 14 | github.com/opentracing/opentracing-go v1.2.0 15 | github.com/pkg/errors v0.9.1 16 | github.com/rabbitmq/amqp091-go v1.9.0 17 | github.com/redis/go-redis/v9 v9.0.5 18 | github.com/robfig/cron/v3 v3.0.1 19 | github.com/stretchr/testify v1.8.4 20 | github.com/urfave/cli v1.22.5 21 | go.mongodb.org/mongo-driver v1.17.0 22 | gopkg.in/yaml.v2 v2.4.0 23 | ) 24 | 25 | replace git.apache.org/thrift.git => github.com/apache/thrift 
v0.0.0-20180902110319-2566ecd5d999 26 | -------------------------------------------------------------------------------- /v2/integration-tests/amqp_amqp_test.go: -------------------------------------------------------------------------------- 1 | package integration_test 2 | 3 | import ( 4 | "os" 5 | "testing" 6 | 7 | "github.com/RichardKnop/machinery/v2" 8 | "github.com/RichardKnop/machinery/v2/config" 9 | 10 | amqpbackend "github.com/RichardKnop/machinery/v2/backends/amqp" 11 | amqpbroker "github.com/RichardKnop/machinery/v2/brokers/amqp" 12 | eagerlock "github.com/RichardKnop/machinery/v2/locks/eager" 13 | ) 14 | 15 | func TestAmqpAmqp(t *testing.T) { 16 | amqpURL := os.Getenv("AMQP_URL") 17 | if amqpURL == "" { 18 | t.Skip("AMQP_URL is not defined") 19 | } 20 | 21 | finalAmqpURL := amqpURL 22 | var finalSeparator string 23 | 24 | amqpURLs := os.Getenv("AMQP_URLS") 25 | if amqpURLs != "" { 26 | separator := os.Getenv("AMQP_URLS_SEPARATOR") 27 | if separator == "" { 28 | return 29 | } 30 | finalSeparator = separator 31 | finalAmqpURL = amqpURLs 32 | } 33 | 34 | cnf := &config.Config{ 35 | Broker: finalAmqpURL, 36 | MultipleBrokerSeparator: finalSeparator, 37 | DefaultQueue: "machinery_tasks", 38 | ResultBackend: amqpURL, 39 | ResultsExpireIn: 3600, 40 | AMQP: &config.AMQPConfig{ 41 | Exchange: "test_exchange", 42 | ExchangeType: "direct", 43 | BindingKey: "test_task", 44 | PrefetchCount: 1, 45 | }, 46 | } 47 | 48 | broker := amqpbroker.New(cnf) 49 | backend := amqpbackend.New(cnf) 50 | lock := eagerlock.New() 51 | server := machinery.NewServer(cnf, broker, backend, lock) 52 | 53 | registerTestTasks(server) 54 | 55 | worker := server.NewWorker("test_worker", 0) 56 | defer worker.Quit() 57 | go worker.Launch() 58 | testAll(server, t) 59 | } 60 | -------------------------------------------------------------------------------- /v2/integration-tests/redis_redis_test.go: -------------------------------------------------------------------------------- 1 | package 
integration_test 2 | 3 | import ( 4 | "os" 5 | "testing" 6 | 7 | "github.com/RichardKnop/machinery/v2" 8 | redisbackend "github.com/RichardKnop/machinery/v2/backends/redis" 9 | redisbroker "github.com/RichardKnop/machinery/v2/brokers/redis" 10 | "github.com/RichardKnop/machinery/v2/config" 11 | eagerlock "github.com/RichardKnop/machinery/v2/locks/eager" 12 | ) 13 | 14 | func TestRedisRedis_GoRedis(t *testing.T) { 15 | redisURL := os.Getenv("REDIS_URL") 16 | if redisURL == "" { 17 | t.Skip("REDIS_URL is not defined") 18 | } 19 | 20 | cnf := &config.Config{ 21 | DefaultQueue: "machinery_tasks", 22 | ResultsExpireIn: 3600, 23 | Redis: &config.RedisConfig{ 24 | MaxIdle: 3, 25 | IdleTimeout: 240, 26 | ReadTimeout: 15, 27 | WriteTimeout: 15, 28 | ConnectTimeout: 15, 29 | NormalTasksPollPeriod: 1000, 30 | DelayedTasksPollPeriod: 500, 31 | }, 32 | } 33 | 34 | broker := redisbroker.NewGR(cnf, []string{redisURL}, 0) 35 | backend := redisbackend.NewGR(cnf, []string{redisURL}, 0) 36 | lock := eagerlock.New() 37 | server := machinery.NewServer(cnf, broker, backend, lock) 38 | 39 | registerTestTasks(server) 40 | 41 | worker := server.NewWorker("test_worker", 0) 42 | defer worker.Quit() 43 | go worker.Launch() 44 | testAll(server, t) 45 | } 46 | -------------------------------------------------------------------------------- /v2/locks/eager/eager.go: -------------------------------------------------------------------------------- 1 | package eager 2 | 3 | import ( 4 | "errors" 5 | "sync" 6 | "time" 7 | ) 8 | 9 | var ( 10 | ErrEagerLockFailed = errors.New("eager lock: failed to acquire lock") 11 | ) 12 | 13 | type Lock struct { 14 | retries int 15 | interval time.Duration 16 | register struct { 17 | sync.RWMutex 18 | m map[string]int64 19 | } 20 | } 21 | 22 | func New() *Lock { 23 | return &Lock{ 24 | retries: 3, 25 | interval: 5 * time.Second, 26 | register: struct { 27 | sync.RWMutex 28 | m map[string]int64 29 | }{m: make(map[string]int64)}, 30 | } 31 | } 32 | 33 | func (e 
*Lock) LockWithRetries(key string, value int64) error { 34 | for i := 0; i <= e.retries; i++ { 35 | err := e.Lock(key, value) 36 | if err == nil { 37 | //成功拿到锁,返回 38 | return nil 39 | } 40 | 41 | time.Sleep(e.interval) 42 | } 43 | return ErrEagerLockFailed 44 | } 45 | 46 | func (e *Lock) Lock(key string, value int64) error { 47 | e.register.Lock() 48 | defer e.register.Unlock() 49 | timeout, exist := e.register.m[key] 50 | if !exist || time.Now().UnixNano() > timeout { 51 | e.register.m[key] = value 52 | return nil 53 | } 54 | return ErrEagerLockFailed 55 | } 56 | -------------------------------------------------------------------------------- /v2/locks/eager/eager_test.go: -------------------------------------------------------------------------------- 1 | package eager 2 | 3 | import ( 4 | lockiface "github.com/RichardKnop/machinery/v2/locks/iface" 5 | "github.com/RichardKnop/machinery/v2/utils" 6 | "github.com/stretchr/testify/assert" 7 | "testing" 8 | "time" 9 | ) 10 | 11 | func TestLock_Lock(t *testing.T) { 12 | lock := New() 13 | keyName := utils.GetPureUUID() 14 | 15 | go func() { 16 | err := lock.Lock(keyName, time.Now().Add(25*time.Second).UnixNano()) 17 | assert.NoError(t, err) 18 | }() 19 | time.Sleep(1 * time.Second) 20 | err := lock.Lock(keyName, time.Now().Add(25*time.Second).UnixNano()) 21 | assert.Error(t, err) 22 | assert.EqualError(t, err, ErrEagerLockFailed.Error()) 23 | } 24 | 25 | func TestLock_LockWithRetries(t *testing.T) { 26 | lock := New() 27 | keyName := utils.GetPureUUID() 28 | 29 | go func() { 30 | err := lock.LockWithRetries(keyName, time.Now().Add(25*time.Second).UnixNano()) 31 | assert.NoError(t, err) 32 | }() 33 | time.Sleep(1 * time.Second) 34 | err := lock.LockWithRetries(keyName, time.Now().Add(25*time.Second).UnixNano()) 35 | assert.Error(t, err) 36 | assert.EqualError(t, err, ErrEagerLockFailed.Error()) 37 | } 38 | 39 | func TestNew(t *testing.T) { 40 | lock := New() 41 | assert.Implements(t, (*lockiface.Lock)(nil), lock) 42 | 
} 43 | -------------------------------------------------------------------------------- /v2/locks/iface/interfaces.go: -------------------------------------------------------------------------------- 1 | package iface 2 | 3 | type Lock interface { 4 | //Acquire the lock with retry 5 | //key: the name of the lock, 6 | //value: at the nanosecond timestamp that lock needs to be released automatically 7 | LockWithRetries(key string, value int64) error 8 | 9 | //Acquire the lock with once 10 | //key: the name of the lock, 11 | //value: at the nanosecond timestamp that lock needs to be released automatically 12 | Lock(key string, value int64) error 13 | } 14 | -------------------------------------------------------------------------------- /v2/locks/redis/redis.go: -------------------------------------------------------------------------------- 1 | package redis 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "strconv" 7 | "strings" 8 | "time" 9 | 10 | "github.com/redis/go-redis/v9" 11 | 12 | "github.com/RichardKnop/machinery/v2/config" 13 | ) 14 | 15 | var ( 16 | ErrRedisLockFailed = errors.New("redis lock: failed to acquire lock") 17 | ) 18 | 19 | type Lock struct { 20 | rclient redis.UniversalClient 21 | retries int 22 | interval time.Duration 23 | } 24 | 25 | func New(cnf *config.Config, addrs []string, db, retries int) Lock { 26 | if retries <= 0 { 27 | return Lock{} 28 | } 29 | lock := Lock{retries: retries} 30 | 31 | var password string 32 | 33 | parts := strings.Split(addrs[0], "@") 34 | if len(parts) >= 2 { 35 | password = strings.Join(parts[:len(parts)-1], "@") 36 | addrs[0] = parts[len(parts)-1] // addr is the last one without @ 37 | } 38 | 39 | ropt := &redis.UniversalOptions{ 40 | Addrs: addrs, 41 | DB: db, 42 | Password: password, 43 | } 44 | if cnf.Redis != nil { 45 | ropt.MasterName = cnf.Redis.MasterName 46 | } 47 | 48 | if cnf.Redis != nil && cnf.Redis.SentinelPassword != "" { 49 | ropt.SentinelPassword = cnf.Redis.SentinelPassword 50 | } 51 | 52 | if 
cnf.Redis != nil && cnf.Redis.ClusterEnabled { 53 | lock.rclient = redis.NewClusterClient(ropt.Cluster()) 54 | } else { 55 | lock.rclient = redis.NewUniversalClient(ropt) 56 | } 57 | 58 | return lock 59 | } 60 | 61 | func (r Lock) LockWithRetries(key string, unixTsToExpireNs int64) error { 62 | for i := 0; i <= r.retries; i++ { 63 | err := r.Lock(key, unixTsToExpireNs) 64 | if err == nil { 65 | // 成功拿到锁,返回 66 | return nil 67 | } 68 | 69 | time.Sleep(r.interval) 70 | } 71 | return ErrRedisLockFailed 72 | } 73 | 74 | func (r Lock) Lock(key string, unixTsToExpireNs int64) error { 75 | now := time.Now().UnixNano() 76 | expiration := time.Duration(unixTsToExpireNs + 1 - now) 77 | ctx := context.Background() 78 | 79 | success, err := r.rclient.SetNX(ctx, key, unixTsToExpireNs, expiration).Result() 80 | if err != nil { 81 | return err 82 | } 83 | 84 | if !success { 85 | v, err := r.rclient.Get(ctx, key).Result() 86 | if err != nil { 87 | return err 88 | } 89 | timeout, err := strconv.Atoi(v) 90 | if err != nil { 91 | return err 92 | } 93 | 94 | if timeout != 0 && now > int64(timeout) { 95 | newTimeout, err := r.rclient.GetSet(ctx, key, unixTsToExpireNs).Result() 96 | if err != nil { 97 | return err 98 | } 99 | 100 | curTimeout, err := strconv.Atoi(newTimeout) 101 | if err != nil { 102 | return err 103 | } 104 | 105 | if now > int64(curTimeout) { 106 | // success to acquire lock with get set 107 | // set the expiration of redis key 108 | r.rclient.Expire(ctx, key, expiration) 109 | return nil 110 | } 111 | 112 | return ErrRedisLockFailed 113 | } 114 | 115 | return ErrRedisLockFailed 116 | } 117 | 118 | return nil 119 | } 120 | -------------------------------------------------------------------------------- /v2/log/log.go: -------------------------------------------------------------------------------- 1 | package log 2 | 3 | import ( 4 | "github.com/RichardKnop/logging" 5 | ) 6 | 7 | var ( 8 | logger = logging.New(nil, nil, new(logging.ColouredFormatter)) 9 | 10 | // DEBUG 
... 11 | DEBUG = logger[logging.DEBUG] 12 | // INFO ... 13 | INFO = logger[logging.INFO] 14 | // WARNING ... 15 | WARNING = logger[logging.WARNING] 16 | // ERROR ... 17 | ERROR = logger[logging.ERROR] 18 | // FATAL ... 19 | FATAL = logger[logging.FATAL] 20 | ) 21 | 22 | // Set sets a custom logger for all log levels 23 | func Set(l logging.LoggerInterface) { 24 | DEBUG = l 25 | INFO = l 26 | WARNING = l 27 | ERROR = l 28 | FATAL = l 29 | } 30 | 31 | // SetDebug sets a custom logger for DEBUG level logs 32 | func SetDebug(l logging.LoggerInterface) { 33 | DEBUG = l 34 | } 35 | 36 | // SetInfo sets a custom logger for INFO level logs 37 | func SetInfo(l logging.LoggerInterface) { 38 | INFO = l 39 | } 40 | 41 | // SetWarning sets a custom logger for WARNING level logs 42 | func SetWarning(l logging.LoggerInterface) { 43 | WARNING = l 44 | } 45 | 46 | // SetError sets a custom logger for ERROR level logs 47 | func SetError(l logging.LoggerInterface) { 48 | ERROR = l 49 | } 50 | 51 | // SetFatal sets a custom logger for FATAL level logs 52 | func SetFatal(l logging.LoggerInterface) { 53 | FATAL = l 54 | } 55 | -------------------------------------------------------------------------------- /v2/log/log_test.go: -------------------------------------------------------------------------------- 1 | package log_test 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/RichardKnop/machinery/v2/log" 7 | ) 8 | 9 | func TestDefaultLogger(t *testing.T) { 10 | log.INFO.Print("should not panic") 11 | log.WARNING.Print("should not panic") 12 | log.ERROR.Print("should not panic") 13 | log.FATAL.Print("should not panic") 14 | } 15 | -------------------------------------------------------------------------------- /v2/package.go: -------------------------------------------------------------------------------- 1 | package machinery 2 | -------------------------------------------------------------------------------- /v2/retry/fibonacci.go: 
-------------------------------------------------------------------------------- 1 | package retry 2 | 3 | // Fibonacci returns successive Fibonacci numbers starting from 1 4 | func Fibonacci() func() int { 5 | a, b := 0, 1 6 | return func() int { 7 | a, b = b, a+b 8 | return a 9 | } 10 | } 11 | 12 | // FibonacciNext returns next number in Fibonacci sequence greater than start 13 | func FibonacciNext(start int) int { 14 | fib := Fibonacci() 15 | num := fib() 16 | for num <= start { 17 | num = fib() 18 | } 19 | return num 20 | } 21 | -------------------------------------------------------------------------------- /v2/retry/fibonacci_test.go: -------------------------------------------------------------------------------- 1 | package retry_test 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/RichardKnop/machinery/v2/retry" 7 | "github.com/stretchr/testify/assert" 8 | ) 9 | 10 | func TestFibonacci(t *testing.T) { 11 | fibonacci := retry.Fibonacci() 12 | 13 | sequence := []int{ 14 | fibonacci(), 15 | fibonacci(), 16 | fibonacci(), 17 | fibonacci(), 18 | fibonacci(), 19 | fibonacci(), 20 | } 21 | 22 | assert.EqualValues(t, sequence, []int{1, 1, 2, 3, 5, 8}) 23 | } 24 | 25 | func TestFibonacciNext(t *testing.T) { 26 | assert.Equal(t, 1, retry.FibonacciNext(0)) 27 | assert.Equal(t, 2, retry.FibonacciNext(1)) 28 | assert.Equal(t, 5, retry.FibonacciNext(3)) 29 | assert.Equal(t, 5, retry.FibonacciNext(4)) 30 | assert.Equal(t, 8, retry.FibonacciNext(5)) 31 | assert.Equal(t, 13, retry.FibonacciNext(8)) 32 | } 33 | -------------------------------------------------------------------------------- /v2/retry/retry.go: -------------------------------------------------------------------------------- 1 | package retry 2 | 3 | import ( 4 | "fmt" 5 | "time" 6 | 7 | "github.com/RichardKnop/machinery/v2/log" 8 | ) 9 | 10 | // Closure - a useful closure we can use when there is a problem 11 | // connecting to the broker. 
It uses Fibonacci sequence to space out retry attempts 12 | var Closure = func() func(chan int) { 13 | retryIn := 0 14 | fibonacci := Fibonacci() 15 | return func(stopChan chan int) { 16 | if retryIn > 0 { 17 | durationString := fmt.Sprintf("%vs", retryIn) 18 | duration, _ := time.ParseDuration(durationString) 19 | 20 | log.WARNING.Printf("Retrying in %v seconds", retryIn) 21 | 22 | select { 23 | case <-stopChan: 24 | break 25 | case <-time.After(duration): 26 | break 27 | } 28 | } 29 | retryIn = fibonacci() 30 | } 31 | } 32 | -------------------------------------------------------------------------------- /v2/server_test.go: -------------------------------------------------------------------------------- 1 | package machinery_test 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/assert" 7 | 8 | "github.com/RichardKnop/machinery/v2" 9 | "github.com/RichardKnop/machinery/v2/config" 10 | 11 | backend "github.com/RichardKnop/machinery/v2/backends/eager" 12 | broker "github.com/RichardKnop/machinery/v2/brokers/eager" 13 | lock "github.com/RichardKnop/machinery/v2/locks/eager" 14 | ) 15 | 16 | func TestRegisterTasks(t *testing.T) { 17 | t.Parallel() 18 | 19 | server := getTestServer(t) 20 | err := server.RegisterTasks(map[string]interface{}{ 21 | "test_task": func() error { return nil }, 22 | }) 23 | assert.NoError(t, err) 24 | 25 | _, err = server.GetRegisteredTask("test_task") 26 | assert.NoError(t, err, "test_task is not registered but it should be") 27 | } 28 | 29 | func TestRegisterTask(t *testing.T) { 30 | t.Parallel() 31 | 32 | server := getTestServer(t) 33 | err := server.RegisterTask("test_task", func() error { return nil }) 34 | assert.NoError(t, err) 35 | 36 | _, err = server.GetRegisteredTask("test_task") 37 | assert.NoError(t, err, "test_task is not registered but it should be") 38 | } 39 | 40 | func TestGetRegisteredTask(t *testing.T) { 41 | t.Parallel() 42 | 43 | server := getTestServer(t) 44 | _, err := 
server.GetRegisteredTask("test_task") 45 | assert.Error(t, err, "test_task is registered but it should not be") 46 | } 47 | 48 | func TestGetRegisteredTaskNames(t *testing.T) { 49 | t.Parallel() 50 | 51 | server := getTestServer(t) 52 | 53 | taskName := "test_task" 54 | err := server.RegisterTask(taskName, func() error { return nil }) 55 | assert.NoError(t, err) 56 | 57 | taskNames := server.GetRegisteredTaskNames() 58 | assert.Equal(t, 1, len(taskNames)) 59 | assert.Equal(t, taskName, taskNames[0]) 60 | } 61 | 62 | func TestNewWorker(t *testing.T) { 63 | t.Parallel() 64 | 65 | server := getTestServer(t) 66 | 67 | server.NewWorker("test_worker", 1) 68 | assert.NoError(t, nil) 69 | } 70 | 71 | func TestNewCustomQueueWorker(t *testing.T) { 72 | t.Parallel() 73 | 74 | server := getTestServer(t) 75 | 76 | server.NewCustomQueueWorker("test_customqueueworker", 1, "test_queue") 77 | assert.NoError(t, nil) 78 | } 79 | 80 | func getTestServer(t *testing.T) *machinery.Server { 81 | return machinery.NewServer(&config.Config{}, broker.New(), backend.New(), lock.New()) 82 | } 83 | -------------------------------------------------------------------------------- /v2/tasks/errors.go: -------------------------------------------------------------------------------- 1 | package tasks 2 | 3 | import ( 4 | "fmt" 5 | "time" 6 | ) 7 | 8 | // ErrRetryTaskLater ... 
9 | type ErrRetryTaskLater struct { 10 | name, msg string 11 | retryIn time.Duration 12 | } 13 | 14 | // RetryIn returns time.Duration from now when task should be retried 15 | func (e ErrRetryTaskLater) RetryIn() time.Duration { 16 | return e.retryIn 17 | } 18 | 19 | // Error implements the error interface 20 | func (e ErrRetryTaskLater) Error() string { 21 | return fmt.Sprintf("Task error: %s Will retry in: %s", e.msg, e.retryIn) 22 | } 23 | 24 | // NewErrRetryTaskLater returns new ErrRetryTaskLater instance 25 | func NewErrRetryTaskLater(msg string, retryIn time.Duration) ErrRetryTaskLater { 26 | return ErrRetryTaskLater{msg: msg, retryIn: retryIn} 27 | } 28 | 29 | // Retriable is interface that retriable errors should implement 30 | type Retriable interface { 31 | RetryIn() time.Duration 32 | } 33 | -------------------------------------------------------------------------------- /v2/tasks/result.go: -------------------------------------------------------------------------------- 1 | package tasks 2 | 3 | import ( 4 | "fmt" 5 | "reflect" 6 | "strings" 7 | ) 8 | 9 | // TaskResult represents an actual return value of a processed task 10 | type TaskResult struct { 11 | Type string `bson:"type"` 12 | Value interface{} `bson:"value"` 13 | } 14 | 15 | // ReflectTaskResults ... 16 | func ReflectTaskResults(taskResults []*TaskResult) ([]reflect.Value, error) { 17 | resultValues := make([]reflect.Value, len(taskResults)) 18 | for i, taskResult := range taskResults { 19 | resultValue, err := ReflectValue(taskResult.Type, taskResult.Value) 20 | if err != nil { 21 | return nil, err 22 | } 23 | resultValues[i] = resultValue 24 | } 25 | return resultValues, nil 26 | } 27 | 28 | // HumanReadableResults ... 
29 | func HumanReadableResults(results []reflect.Value) string { 30 | if len(results) == 1 { 31 | return fmt.Sprintf("%v", results[0].Interface()) 32 | } 33 | 34 | readableResults := make([]string, len(results)) 35 | for i := 0; i < len(results); i++ { 36 | readableResults[i] = fmt.Sprintf("%v", results[i].Interface()) 37 | } 38 | 39 | return fmt.Sprintf("[%s]", strings.Join(readableResults, ", ")) 40 | } 41 | -------------------------------------------------------------------------------- /v2/tasks/result_test.go: -------------------------------------------------------------------------------- 1 | package tasks_test 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/RichardKnop/machinery/v2/tasks" 7 | "github.com/stretchr/testify/assert" 8 | ) 9 | 10 | func TestReflectTaskResults(t *testing.T) { 11 | t.Parallel() 12 | 13 | taskResults := []*tasks.TaskResult{ 14 | { 15 | Type: "[]string", 16 | Value: []string{"f", "o", "o"}, 17 | }, 18 | } 19 | results, err := tasks.ReflectTaskResults(taskResults) 20 | if assert.NoError(t, err) { 21 | assert.Equal(t, 1, len(results)) 22 | assert.Equal(t, 3, results[0].Len()) 23 | assert.Equal(t, "f", results[0].Index(0).String()) 24 | assert.Equal(t, "o", results[0].Index(1).String()) 25 | assert.Equal(t, "o", results[0].Index(2).String()) 26 | } 27 | } 28 | -------------------------------------------------------------------------------- /v2/tasks/signature.go: -------------------------------------------------------------------------------- 1 | package tasks 2 | 3 | import ( 4 | "fmt" 5 | "github.com/RichardKnop/machinery/v2/utils" 6 | "time" 7 | 8 | "github.com/google/uuid" 9 | ) 10 | 11 | // Arg represents a single argument passed to invocation fo a task 12 | type Arg struct { 13 | Name string `bson:"name"` 14 | Type string `bson:"type"` 15 | Value interface{} `bson:"value"` 16 | } 17 | 18 | // Headers represents the headers which should be used to direct the task 19 | type Headers map[string]interface{} 20 | 21 | // Set on 
Headers implements opentracing.TextMapWriter for trace propagation 22 | func (h Headers) Set(key, val string) { 23 | h[key] = val 24 | } 25 | 26 | // ForeachKey on Headers implements opentracing.TextMapReader for trace propagation. 27 | // It is essentially the same as the opentracing.TextMapReader implementation except 28 | // for the added casting from interface{} to string. 29 | func (h Headers) ForeachKey(handler func(key, val string) error) error { 30 | for k, v := range h { 31 | // Skip any non string values 32 | stringValue, ok := v.(string) 33 | if !ok { 34 | continue 35 | } 36 | 37 | if err := handler(k, stringValue); err != nil { 38 | return err 39 | } 40 | } 41 | 42 | return nil 43 | } 44 | 45 | // Signature represents a single task invocation 46 | type Signature struct { 47 | UUID string 48 | Name string 49 | RoutingKey string 50 | ETA *time.Time 51 | GroupUUID string 52 | GroupTaskCount int 53 | Args []Arg 54 | Headers Headers 55 | Priority uint8 56 | Immutable bool 57 | RetryCount int 58 | RetryTimeout int 59 | OnSuccess []*Signature 60 | OnError []*Signature 61 | ChordCallback *Signature 62 | //MessageGroupId for Broker, e.g. 
SQS 63 | BrokerMessageGroupId string 64 | //ReceiptHandle of SQS Message 65 | SQSReceiptHandle string 66 | // StopTaskDeletionOnError used with sqs when we want to send failed messages to dlq, 67 | // and don't want machinery to delete from source queue 68 | StopTaskDeletionOnError bool 69 | // IgnoreWhenTaskNotRegistered auto removes the request when there is no handeler available 70 | // When this is true a task with no handler will be ignored and not placed back in the queue 71 | IgnoreWhenTaskNotRegistered bool 72 | } 73 | 74 | // NewSignature creates a new task signature 75 | func NewSignature(name string, args []Arg) (*Signature, error) { 76 | signatureID := uuid.New().String() 77 | return &Signature{ 78 | UUID: fmt.Sprintf("task_%v", signatureID), 79 | Name: name, 80 | Args: args, 81 | }, nil 82 | } 83 | 84 | func CopySignatures(signatures ...*Signature) []*Signature { 85 | var sigs = make([]*Signature, len(signatures)) 86 | for index, signature := range signatures { 87 | sigs[index] = CopySignature(signature) 88 | } 89 | return sigs 90 | } 91 | 92 | func CopySignature(signature *Signature) *Signature { 93 | var sig = new(Signature) 94 | _ = utils.DeepCopy(sig, signature) 95 | return sig 96 | } 97 | -------------------------------------------------------------------------------- /v2/tasks/state.go: -------------------------------------------------------------------------------- 1 | package tasks 2 | 3 | import "time" 4 | 5 | const ( 6 | // StatePending - initial state of a task 7 | StatePending = "PENDING" 8 | // StateReceived - when task is received by a worker 9 | StateReceived = "RECEIVED" 10 | // StateStarted - when the worker starts processing the task 11 | StateStarted = "STARTED" 12 | // StateRetry - when failed task has been scheduled for retry 13 | StateRetry = "RETRY" 14 | // StateSuccess - when the task is processed successfully 15 | StateSuccess = "SUCCESS" 16 | // StateFailure - when processing of the task fails 17 | StateFailure = "FAILURE" 18 
| ) 19 | 20 | // TaskState represents a state of a task 21 | type TaskState struct { 22 | TaskUUID string `bson:"_id"` 23 | TaskName string `bson:"task_name"` 24 | State string `bson:"state"` 25 | Results []*TaskResult `bson:"results"` 26 | Error string `bson:"error"` 27 | CreatedAt time.Time `bson:"created_at"` 28 | TTL int64 `bson:"ttl,omitempty"` 29 | } 30 | 31 | // GroupMeta stores useful metadata about tasks within the same group 32 | // E.g. UUIDs of all tasks which are used in order to check if all tasks 33 | // completed successfully or not and thus whether to trigger chord callback 34 | type GroupMeta struct { 35 | GroupUUID string `bson:"_id"` 36 | TaskUUIDs []string `bson:"task_uuids"` 37 | ChordTriggered bool `bson:"chord_triggered"` 38 | Lock bool `bson:"lock"` 39 | CreatedAt time.Time `bson:"created_at"` 40 | TTL int64 `bson:"ttl,omitempty"` 41 | } 42 | 43 | // NewPendingTaskState ... 44 | func NewPendingTaskState(signature *Signature) *TaskState { 45 | return &TaskState{ 46 | TaskUUID: signature.UUID, 47 | TaskName: signature.Name, 48 | State: StatePending, 49 | CreatedAt: time.Now().UTC(), 50 | } 51 | } 52 | 53 | // NewReceivedTaskState ... 54 | func NewReceivedTaskState(signature *Signature) *TaskState { 55 | return &TaskState{ 56 | TaskUUID: signature.UUID, 57 | State: StateReceived, 58 | } 59 | } 60 | 61 | // NewStartedTaskState ... 62 | func NewStartedTaskState(signature *Signature) *TaskState { 63 | return &TaskState{ 64 | TaskUUID: signature.UUID, 65 | State: StateStarted, 66 | } 67 | } 68 | 69 | // NewSuccessTaskState ... 70 | func NewSuccessTaskState(signature *Signature, results []*TaskResult) *TaskState { 71 | return &TaskState{ 72 | TaskUUID: signature.UUID, 73 | State: StateSuccess, 74 | Results: results, 75 | } 76 | } 77 | 78 | // NewFailureTaskState ... 
79 | func NewFailureTaskState(signature *Signature, err string) *TaskState { 80 | return &TaskState{ 81 | TaskUUID: signature.UUID, 82 | State: StateFailure, 83 | Error: err, 84 | } 85 | } 86 | 87 | // NewRetryTaskState ... 88 | func NewRetryTaskState(signature *Signature) *TaskState { 89 | return &TaskState{ 90 | TaskUUID: signature.UUID, 91 | State: StateRetry, 92 | } 93 | } 94 | 95 | // IsCompleted returns true if state is SUCCESS or FAILURE, 96 | // i.e. the task has finished processing and either succeeded or failed. 97 | func (taskState *TaskState) IsCompleted() bool { 98 | return taskState.IsSuccess() || taskState.IsFailure() 99 | } 100 | 101 | // IsSuccess returns true if state is SUCCESS 102 | func (taskState *TaskState) IsSuccess() bool { 103 | return taskState.State == StateSuccess 104 | } 105 | 106 | // IsFailure returns true if state is FAILURE 107 | func (taskState *TaskState) IsFailure() bool { 108 | return taskState.State == StateFailure 109 | } 110 | -------------------------------------------------------------------------------- /v2/tasks/state_test.go: -------------------------------------------------------------------------------- 1 | package tasks_test 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/RichardKnop/machinery/v2/tasks" 7 | "github.com/stretchr/testify/assert" 8 | ) 9 | 10 | func TestTaskStateIsCompleted(t *testing.T) { 11 | t.Parallel() 12 | 13 | taskState := &tasks.TaskState{ 14 | TaskUUID: "taskUUID", 15 | State: tasks.StatePending, 16 | } 17 | 18 | assert.False(t, taskState.IsCompleted()) 19 | 20 | taskState.State = tasks.StateReceived 21 | assert.False(t, taskState.IsCompleted()) 22 | 23 | taskState.State = tasks.StateStarted 24 | assert.False(t, taskState.IsCompleted()) 25 | 26 | taskState.State = tasks.StateSuccess 27 | assert.True(t, taskState.IsCompleted()) 28 | 29 | taskState.State = tasks.StateFailure 30 | assert.True(t, taskState.IsCompleted()) 31 | } 32 | 
-------------------------------------------------------------------------------- /v2/tasks/task_test.go: -------------------------------------------------------------------------------- 1 | package tasks_test 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "math" 7 | "testing" 8 | "time" 9 | 10 | "github.com/RichardKnop/machinery/v2/tasks" 11 | "github.com/stretchr/testify/assert" 12 | ) 13 | 14 | func TestTaskCallErrorTest(t *testing.T) { 15 | t.Parallel() 16 | 17 | // Create test task that returns tasks.ErrRetryTaskLater error 18 | retriable := func() error { return tasks.NewErrRetryTaskLater("some error", 4*time.Hour) } 19 | 20 | task, err := tasks.New(retriable, []tasks.Arg{}) 21 | assert.NoError(t, err) 22 | 23 | // Invoke TryCall and validate that returned error can be cast to tasks.ErrRetryTaskLater 24 | results, err := task.Call() 25 | assert.Nil(t, results) 26 | assert.NotNil(t, err) 27 | _, ok := interface{}(err).(tasks.ErrRetryTaskLater) 28 | assert.True(t, ok, "Error should be castable to tasks.ErrRetryTaskLater") 29 | 30 | // Create test task that returns a standard error 31 | standard := func() error { return errors.New("some error") } 32 | 33 | task, err = tasks.New(standard, []tasks.Arg{}) 34 | assert.NoError(t, err) 35 | 36 | // Invoke TryCall and validate that returned error is standard 37 | results, err = task.Call() 38 | assert.Nil(t, results) 39 | assert.NotNil(t, err) 40 | assert.Equal(t, "some error", err.Error()) 41 | } 42 | 43 | func TestTaskReflectArgs(t *testing.T) { 44 | t.Parallel() 45 | 46 | task := new(tasks.Task) 47 | args := []tasks.Arg{ 48 | { 49 | Type: "[]int64", 50 | Value: []int64{1, 2}, 51 | }, 52 | } 53 | 54 | err := task.ReflectArgs(args) 55 | assert.NoError(t, err) 56 | assert.Equal(t, 1, len(task.Args)) 57 | assert.Equal(t, "[]int64", task.Args[0].Type().String()) 58 | } 59 | 60 | func TestTaskCallInvalidArgRobustnessError(t *testing.T) { 61 | t.Parallel() 62 | 63 | // Create a test task function 64 | f := func(x int) 
error { return nil } 65 | 66 | // Construct an invalid argument list and reflect it 67 | args := []tasks.Arg{ 68 | {Type: "bool", Value: true}, 69 | } 70 | 71 | task, err := tasks.New(f, args) 72 | assert.NoError(t, err) 73 | 74 | // Invoke TryCall and validate error handling 75 | results, err := task.Call() 76 | assert.Equal(t, "reflect: Call using bool as type int", err.Error()) 77 | assert.Nil(t, results) 78 | } 79 | 80 | func TestTaskCallInterfaceValuedResult(t *testing.T) { 81 | t.Parallel() 82 | 83 | // Create a test task function 84 | f := func() (interface{}, error) { return math.Pi, nil } 85 | 86 | task, err := tasks.New(f, []tasks.Arg{}) 87 | assert.NoError(t, err) 88 | 89 | taskResults, err := task.Call() 90 | assert.NoError(t, err) 91 | assert.Equal(t, "float64", taskResults[0].Type) 92 | assert.Equal(t, math.Pi, taskResults[0].Value) 93 | } 94 | 95 | func TestTaskCallWithContext(t *testing.T) { 96 | t.Parallel() 97 | 98 | f := func(c context.Context) (interface{}, error) { 99 | assert.NotNil(t, c) 100 | assert.Nil(t, tasks.SignatureFromContext(c)) 101 | return math.Pi, nil 102 | } 103 | task, err := tasks.New(f, []tasks.Arg{}) 104 | assert.NoError(t, err) 105 | taskResults, err := task.Call() 106 | assert.NoError(t, err) 107 | assert.Equal(t, "float64", taskResults[0].Type) 108 | assert.Equal(t, math.Pi, taskResults[0].Value) 109 | } 110 | 111 | func TestTaskCallWithSignatureInContext(t *testing.T) { 112 | t.Parallel() 113 | 114 | f := func(c context.Context) (interface{}, error) { 115 | assert.NotNil(t, c) 116 | signature := tasks.SignatureFromContext(c) 117 | assert.NotNil(t, signature) 118 | assert.Equal(t, "foo", signature.Name) 119 | return math.Pi, nil 120 | } 121 | signature, err := tasks.NewSignature("foo", []tasks.Arg{}) 122 | assert.NoError(t, err) 123 | task, err := tasks.NewWithSignature(f, signature) 124 | assert.NoError(t, err) 125 | taskResults, err := task.Call() 126 | assert.NoError(t, err) 127 | assert.Equal(t, "float64", 
taskResults[0].Type) 128 | assert.Equal(t, math.Pi, taskResults[0].Value) 129 | } 130 | -------------------------------------------------------------------------------- /v2/tasks/validate.go: -------------------------------------------------------------------------------- 1 | package tasks 2 | 3 | import ( 4 | "errors" 5 | "reflect" 6 | ) 7 | 8 | var ( 9 | // ErrTaskMustBeFunc ... 10 | ErrTaskMustBeFunc = errors.New("Task must be a func type") 11 | // ErrTaskReturnsNoValue ... 12 | ErrTaskReturnsNoValue = errors.New("Task must return at least a single value") 13 | // ErrLastReturnValueMustBeError .. 14 | ErrLastReturnValueMustBeError = errors.New("Last return value of a task must be error") 15 | ) 16 | 17 | // ValidateTask validates task function using reflection and makes sure 18 | // it has a proper signature. Functions used as tasks must return at least a 19 | // single value and the last return type must be error 20 | func ValidateTask(task interface{}) error { 21 | v := reflect.ValueOf(task) 22 | t := v.Type() 23 | 24 | // Task must be a function 25 | if t.Kind() != reflect.Func { 26 | return ErrTaskMustBeFunc 27 | } 28 | 29 | // Task must return at least a single value 30 | if t.NumOut() < 1 { 31 | return ErrTaskReturnsNoValue 32 | } 33 | 34 | // Last return value must be error 35 | lastReturnType := t.Out(t.NumOut() - 1) 36 | errorInterface := reflect.TypeOf((*error)(nil)).Elem() 37 | if !lastReturnType.Implements(errorInterface) { 38 | return ErrLastReturnValueMustBeError 39 | } 40 | 41 | return nil 42 | } 43 | -------------------------------------------------------------------------------- /v2/tasks/validate_test.go: -------------------------------------------------------------------------------- 1 | package tasks_test 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/RichardKnop/machinery/v2/tasks" 7 | "github.com/stretchr/testify/assert" 8 | ) 9 | 10 | func TestValidateTask(t *testing.T) { 11 | t.Parallel() 12 | 13 | type someStruct struct{} 14 | var ( 
15 | taskOfWrongType = new(someStruct) 16 | taskWithoutReturnValue = func() {} 17 | taskWithoutErrorAsLastReturnValue = func() int { return 0 } 18 | validTask = func(arg string) error { return nil } 19 | ) 20 | 21 | err := tasks.ValidateTask(taskOfWrongType) 22 | assert.Equal(t, tasks.ErrTaskMustBeFunc, err) 23 | 24 | err = tasks.ValidateTask(taskWithoutReturnValue) 25 | assert.Equal(t, tasks.ErrTaskReturnsNoValue, err) 26 | 27 | err = tasks.ValidateTask(taskWithoutErrorAsLastReturnValue) 28 | assert.Equal(t, tasks.ErrLastReturnValueMustBeError, err) 29 | 30 | err = tasks.ValidateTask(validTask) 31 | assert.NoError(t, err) 32 | } 33 | -------------------------------------------------------------------------------- /v2/tasks/workflow.go: -------------------------------------------------------------------------------- 1 | package tasks 2 | 3 | import ( 4 | "fmt" 5 | 6 | "github.com/google/uuid" 7 | ) 8 | 9 | // Chain creates a chain of tasks to be executed one after another 10 | type Chain struct { 11 | Tasks []*Signature 12 | } 13 | 14 | // Group creates a set of tasks to be executed in parallel 15 | type Group struct { 16 | GroupUUID string 17 | Tasks []*Signature 18 | } 19 | 20 | // Chord adds an optional callback to the group to be executed 21 | // after all tasks in the group finished 22 | type Chord struct { 23 | Group *Group 24 | Callback *Signature 25 | } 26 | 27 | // GetUUIDs returns slice of task UUIDS 28 | func (group *Group) GetUUIDs() []string { 29 | taskUUIDs := make([]string, len(group.Tasks)) 30 | for i, signature := range group.Tasks { 31 | taskUUIDs[i] = signature.UUID 32 | } 33 | return taskUUIDs 34 | } 35 | 36 | // NewChain creates a new chain of tasks to be processed one by one, passing 37 | // results unless task signatures are set to be immutable 38 | func NewChain(signatures ...*Signature) (*Chain, error) { 39 | // Auto generate task UUIDs if needed 40 | for _, signature := range signatures { 41 | if signature.UUID == "" { 42 | signatureID := 
uuid.New().String() 43 | signature.UUID = fmt.Sprintf("task_%v", signatureID) 44 | } 45 | } 46 | 47 | for i := len(signatures) - 1; i > 0; i-- { 48 | if i > 0 { 49 | signatures[i-1].OnSuccess = []*Signature{signatures[i]} 50 | } 51 | } 52 | 53 | chain := &Chain{Tasks: signatures} 54 | 55 | return chain, nil 56 | } 57 | 58 | // NewGroup creates a new group of tasks to be processed in parallel 59 | func NewGroup(signatures ...*Signature) (*Group, error) { 60 | // Generate a group UUID 61 | groupUUID := uuid.New().String() 62 | groupID := fmt.Sprintf("group_%v", groupUUID) 63 | 64 | // Auto generate task UUIDs if needed, group tasks by common group UUID 65 | for _, signature := range signatures { 66 | if signature.UUID == "" { 67 | signatureID := uuid.New().String() 68 | signature.UUID = fmt.Sprintf("task_%v", signatureID) 69 | } 70 | signature.GroupUUID = groupID 71 | signature.GroupTaskCount = len(signatures) 72 | } 73 | 74 | return &Group{ 75 | GroupUUID: groupID, 76 | Tasks: signatures, 77 | }, nil 78 | } 79 | 80 | // NewChord creates a new chord (a group of tasks with a single callback 81 | // to be executed after all tasks in the group has completed) 82 | func NewChord(group *Group, callback *Signature) (*Chord, error) { 83 | if callback.UUID == "" { 84 | // Generate a UUID for the chord callback 85 | callbackUUID := uuid.New().String() 86 | callback.UUID = fmt.Sprintf("chord_%v", callbackUUID) 87 | } 88 | 89 | // Add a chord callback to all tasks 90 | for _, signature := range group.Tasks { 91 | signature.ChordCallback = callback 92 | } 93 | 94 | return &Chord{Group: group, Callback: callback}, nil 95 | } 96 | -------------------------------------------------------------------------------- /v2/tasks/workflow_test.go: -------------------------------------------------------------------------------- 1 | package tasks_test 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/RichardKnop/machinery/v2/tasks" 7 | "github.com/stretchr/testify/assert" 8 | ) 9 | 10 | func 
TestNewChain(t *testing.T) { 11 | t.Parallel() 12 | 13 | task1 := tasks.Signature{ 14 | Name: "foo", 15 | Args: []tasks.Arg{ 16 | { 17 | Type: "float64", 18 | Value: interface{}(1), 19 | }, 20 | { 21 | Type: "float64", 22 | Value: interface{}(1), 23 | }, 24 | }, 25 | } 26 | 27 | task2 := tasks.Signature{ 28 | Name: "bar", 29 | Args: []tasks.Arg{ 30 | { 31 | Type: "float64", 32 | Value: interface{}(5), 33 | }, 34 | { 35 | Type: "float64", 36 | Value: interface{}(6), 37 | }, 38 | }, 39 | } 40 | 41 | task3 := tasks.Signature{ 42 | Name: "qux", 43 | Args: []tasks.Arg{ 44 | { 45 | Type: "float64", 46 | Value: interface{}(4), 47 | }, 48 | }, 49 | } 50 | 51 | chain, err := tasks.NewChain(&task1, &task2, &task3) 52 | if err != nil { 53 | t.Fatal(err) 54 | } 55 | 56 | firstTask := chain.Tasks[0] 57 | 58 | assert.Equal(t, "foo", firstTask.Name) 59 | assert.Equal(t, "bar", firstTask.OnSuccess[0].Name) 60 | assert.Equal(t, "qux", firstTask.OnSuccess[0].OnSuccess[0].Name) 61 | } 62 | -------------------------------------------------------------------------------- /v2/utils/deepcopy.go: -------------------------------------------------------------------------------- 1 | package utils 2 | 3 | import ( 4 | "errors" 5 | "reflect" 6 | ) 7 | 8 | var ( 9 | ErrNoMatchType = errors.New("no match type") 10 | ErrNoPointer = errors.New("must be interface") 11 | ErrInvalidArgument = errors.New("invalid arguments") 12 | ) 13 | 14 | func deepCopy(dst, src reflect.Value) { 15 | switch src.Kind() { 16 | case reflect.Interface: 17 | value := src.Elem() 18 | if !value.IsValid() { 19 | return 20 | } 21 | newValue := reflect.New(value.Type()).Elem() 22 | deepCopy(newValue, value) 23 | dst.Set(newValue) 24 | case reflect.Ptr: 25 | value := src.Elem() 26 | if !value.IsValid() { 27 | return 28 | } 29 | dst.Set(reflect.New(value.Type())) 30 | deepCopy(dst.Elem(), value) 31 | case reflect.Map: 32 | dst.Set(reflect.MakeMap(src.Type())) 33 | keys := src.MapKeys() 34 | for _, key := range keys { 35 | value 
:= src.MapIndex(key) 36 | newValue := reflect.New(value.Type()).Elem() 37 | deepCopy(newValue, value) 38 | dst.SetMapIndex(key, newValue) 39 | } 40 | case reflect.Slice: 41 | dst.Set(reflect.MakeSlice(src.Type(), src.Len(), src.Cap())) 42 | for i := 0; i < src.Len(); i++ { 43 | deepCopy(dst.Index(i), src.Index(i)) 44 | } 45 | case reflect.Struct: 46 | typeSrc := src.Type() 47 | for i := 0; i < src.NumField(); i++ { 48 | value := src.Field(i) 49 | tag := typeSrc.Field(i).Tag 50 | if value.CanSet() && tag.Get("deepcopy") != "-" { 51 | deepCopy(dst.Field(i), value) 52 | } 53 | } 54 | default: 55 | dst.Set(src) 56 | } 57 | } 58 | 59 | func DeepCopy(dst, src interface{}) error { 60 | typeDst := reflect.TypeOf(dst) 61 | typeSrc := reflect.TypeOf(src) 62 | if typeDst != typeSrc { 63 | return ErrNoMatchType 64 | } 65 | if typeSrc.Kind() != reflect.Ptr { 66 | return ErrNoPointer 67 | } 68 | 69 | valueDst := reflect.ValueOf(dst).Elem() 70 | valueSrc := reflect.ValueOf(src).Elem() 71 | if !valueDst.IsValid() || !valueSrc.IsValid() { 72 | return ErrInvalidArgument 73 | } 74 | 75 | deepCopy(valueDst, valueSrc) 76 | return nil 77 | } 78 | 79 | func DeepClone(v interface{}) interface{} { 80 | dst := reflect.New(reflect.TypeOf(v)).Elem() 81 | deepCopy(dst, reflect.ValueOf(v)) 82 | return dst.Interface() 83 | } 84 | -------------------------------------------------------------------------------- /v2/utils/deepcopy_test.go: -------------------------------------------------------------------------------- 1 | package utils 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/assert" 7 | ) 8 | 9 | func TestDeepCopy(t *testing.T) { 10 | t.Parallel() 11 | 12 | type s struct { 13 | A float64 14 | B int 15 | C []int 16 | D *int 17 | E map[string]int 18 | } 19 | var d = 3 20 | var dst = new(s) 21 | var src = s{1.0, 1, []int{1, 2, 3}, &d, map[string]int{"a": 1}} 22 | 23 | err := DeepCopy(dst, &src) 24 | src.A = 2 25 | 26 | assert.NoError(t, err) 27 | assert.Equal(t, 1.0, 
dst.A) 28 | assert.Equal(t, 1, dst.B) 29 | assert.Equal(t, []int{1, 2, 3}, dst.C) 30 | assert.Equal(t, &d, dst.D) 31 | assert.Equal(t, map[string]int{"a": 1}, dst.E) 32 | } 33 | -------------------------------------------------------------------------------- /v2/utils/utils.go: -------------------------------------------------------------------------------- 1 | package utils 2 | 3 | import ( 4 | "os" 5 | "path/filepath" 6 | ) 7 | 8 | const ( 9 | LockKeyPrefix = "machinery_lock_" 10 | ) 11 | 12 | func GetLockName(name, spec string) string { 13 | return LockKeyPrefix + filepath.Base(os.Args[0]) + name + spec 14 | } 15 | -------------------------------------------------------------------------------- /v2/utils/utils_test.go: -------------------------------------------------------------------------------- 1 | package utils 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/assert" 7 | ) 8 | 9 | func TestGetLockName(t *testing.T) { 10 | t.Parallel() 11 | 12 | lockName := GetLockName("test", "*/3 * * *") 13 | assert.Equal(t, "machinery_lock_utils.testtest*/3 * * *", lockName) 14 | } 15 | -------------------------------------------------------------------------------- /v2/utils/uuid.go: -------------------------------------------------------------------------------- 1 | package utils 2 | 3 | import ( 4 | "github.com/google/uuid" 5 | "strings" 6 | ) 7 | 8 | func GetPureUUID() string { 9 | uid, _ := uuid.NewUUID() 10 | return strings.Replace(uid.String(), "-", "", -1) 11 | } 12 | -------------------------------------------------------------------------------- /v2/utils/uuid_test.go: -------------------------------------------------------------------------------- 1 | package utils 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/assert" 7 | ) 8 | 9 | func TestGetPureUUID(t *testing.T) { 10 | t.Parallel() 11 | 12 | assert.Len(t, GetPureUUID(), 32) 13 | } 14 | -------------------------------------------------------------------------------- 
/v2/worker_test.go: -------------------------------------------------------------------------------- 1 | package machinery_test 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/assert" 7 | 8 | "github.com/RichardKnop/machinery/v2" 9 | ) 10 | 11 | func TestRedactURL(t *testing.T) { 12 | t.Parallel() 13 | 14 | broker := "amqp://guest:guest@localhost:5672" 15 | redactedURL := machinery.RedactURL(broker) 16 | assert.Equal(t, "amqp://localhost:5672", redactedURL) 17 | } 18 | 19 | func TestPreConsumeHandler(t *testing.T) { 20 | t.Parallel() 21 | 22 | worker := &machinery.Worker{} 23 | 24 | worker.SetPreConsumeHandler(SamplePreConsumeHandler) 25 | assert.True(t, worker.PreConsumeHandler()) 26 | } 27 | 28 | func SamplePreConsumeHandler(w *machinery.Worker) bool { 29 | return true 30 | } 31 | --------------------------------------------------------------------------------