├── message.go ├── .gitignore ├── README.md ├── demo └── main.go ├── result.go ├── logging.go ├── queue.go ├── broker.go ├── LICENSE ├── task.go ├── connection.go ├── celery.go └── amqp_connection.go /message.go: -------------------------------------------------------------------------------- 1 | package celery 2 | 3 | type Message struct { 4 | ContentType string 5 | Body []byte 6 | Receipt Receipt 7 | } 8 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Compiled Object files, Static and Dynamic libs (Shared Objects) 2 | *.o 3 | *.a 4 | *.so 5 | 6 | # Folders 7 | _obj 8 | _test 9 | 10 | # Architecture specific extensions/prefixes 11 | *.[568vq] 12 | [568vq].out 13 | 14 | *.cgo1.go 15 | *.cgo2.c 16 | _cgo_defun.c 17 | _cgo_gotypes.go 18 | _cgo_export.* 19 | 20 | _testmain.go 21 | 22 | *.exe 23 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # [Celery](http://www.celeryproject.org/) consumer 2 | Consumes jobs that are published using [Celery](http://www.celeryproject.org/) in Go. 3 | 4 | *__Note__: In very early development.* 5 | 6 | ## Installation 7 | 8 | ``` 9 | $ go get github.com/mattrobenolt/go-celery 10 | ``` 11 | 12 | ## Current status 13 | * AMQP consumer 14 | * Handles replies 15 | 16 | ## TODO 17 | * Redis consumer 18 | * ETAs and retries 19 | * Lots of other things 20 | -------------------------------------------------------------------------------- /demo/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import "celery" 4 | import "time" 5 | 6 | type Adder struct {} 7 | func (a *Adder) Exec(task *celery.Task) (result interface{}, err error) { 8 | sum := float64(0) 9 | for _, arg := range task.Args { 10 | sum += arg.(float64) 11 | } 12 | result = sum 13 | time.Sleep(5*time.Second) 14 | return 15 | } 16 | 17 | func main() { 18 | celery.RegisterTask("myapp.add", &Adder{}) 19 | celery.Init() 20 | } 21 | -------------------------------------------------------------------------------- /result.go: -------------------------------------------------------------------------------- 1 | package celery 2 | 3 | type ResultStatus string 4 | 5 | const ( 6 | StatusSuccess ResultStatus = "SUCCESS" 7 | ) 8 | 9 | // {'status': 'SUCCESS', 'traceback': None, 'result': 2, 'task_id': 'd3858e68-48da-4631-b42b-7dbd0ffa08d1', 'children': []} 10 | 11 | type Result struct { 12 | Status ResultStatus `json:"status"` 13 | Traceback []string `json:"traceback"` 14 | Result interface{} `json:"result"` 15 | Id string `json:"task_id"` 16 | Children []string `json:"children"` 17 | } 18 | -------------------------------------------------------------------------------- /logging.go: -------------------------------------------------------------------------------- 1 | package celery 2 | 3 | import ( 4 | log "code.google.com/p/log4go" 5 | "flag" 6 | "strings" 7 | ) 8 | 9 | var ( 10 | loglevel = flag.String("l", "error", "Log level") 11 | ) 12 | 13 | var logger log.Logger 14 | 15 | func SetupLogging() { 16 | flag.Parse() 17 | 18 | level := log.ERROR 19 | 20 | switch strings.ToLower(*loglevel) { 21 | case "debug": 22 | level = log.DEBUG 23 | case "trace": 24 | level = log.TRACE 25 | case "info": 26 | level = log.INFO 27 | case "warning": 28 | level = log.WARNING 29 | case "error": 30 | level = log.ERROR 31 | case "critical": 32 | level = log.CRITICAL 33 | } 34 | 35 | logger = log.NewDefaultLogger(level) 36 | } 37 | 38 | func GetLogger() log.Logger { 39 | return logger 40 | } 41 | -------------------------------------------------------------------------------- /queue.go: -------------------------------------------------------------------------------- 1 | package celery 2 | 3 | type Exchange struct { 4 | Name string 5 | Type string 6 | Durable bool 7 | DeleteWhenComplete bool 8 | } 9 | 10 | type Queue struct { 11 | Name string 12 | Durable bool 13 | DeleteWhenUnused bool 14 | Ttl int 15 | } 16 | 17 | type Binding struct { 18 | Name string 19 | Queue *Queue 20 | Exchange *Exchange 21 | } 22 | 23 | type Publishing struct { 24 | Key string 25 | Exchange *Exchange 26 | Body []byte 27 | } 28 | 29 | func NewExchange(name string, durable bool) *Exchange { 30 | return &Exchange{ 31 | Name: name, 32 | Type: "direct", // not sure when we'd ever want anything else 33 | Durable: durable, 34 | DeleteWhenComplete: !durable, 35 | } 36 | } 37 | 38 | func NewDurableExchange(name string) *Exchange { 39 | return NewExchange(name, true) 40 | } 41 | 42 | func NewQueue(name string, durable bool, ttl int) *Queue { 43 | return &Queue{ 44 | Name: name, 45 | Durable: durable, 46 | DeleteWhenUnused: !durable, 47 | Ttl: ttl, 48 | } 49 | } 50 | 51 | func NewDurableQueue(name string) *Queue { 52 | return NewQueue(name, true, 0) 53 | } 54 | 55 | func NewExpiringQueue(name string, ttl int) *Queue { 56 | return NewQueue(name, false, ttl) 57 | } 58 | 59 | func NewBinding(name string, q *Queue, e *Exchange) *Binding { 60 | return &Binding{ 61 | Name: name, 62 | Queue: q, 63 | Exchange: e, 64 | } 65 | } 66 | -------------------------------------------------------------------------------- /broker.go: -------------------------------------------------------------------------------- 1 | package celery 2 | 3 | import ( 4 | "strings" 5 | "time" 6 | "errors" 7 | "fmt" 8 | "encoding/json" 9 | ) 10 | 11 | var ( 12 | TwoSeconds = 2 * time.Second 13 | MaximumRetriesError = errors.New("Maximum retries exceeded") 14 | ) 15 | 16 | type Deliveries chan *Task 17 | 18 | type Broker struct { 19 | conn *Connection 20 | } 21 | 22 | func (b *Broker) StartConsuming(q *Queue, rate int) Deliveries { 23 | b.conn.DeclareQueue(q) 24 | deliveries := make(Deliveries) 25 | go func() { 26 | for { 27 | messages, err := b.conn.Consume(q, rate) 28 | if err != nil { 29 | logger.Error(err) 30 | time.Sleep(TwoSeconds) 31 | continue 32 | } 33 | for msg := range messages { 34 | go func(msg *Message) { 35 | task := &Task{ 36 | Receipt: msg.Receipt, 37 | } 38 | switch msg.ContentType { 39 | case "application/json": 40 | json.Unmarshal(msg.Body, &task) 41 | default: 42 | logger.Warn("Unsupported content-type [%s]", msg.ContentType) 43 | // msg.Reject(false) 44 | return 45 | } 46 | deliveries <- task 47 | }(msg) 48 | } 49 | } 50 | }() 51 | return deliveries 52 | } 53 | 54 | func NewBroker(uri string) *Broker { 55 | var scheme = strings.SplitN(uri, "://", 2)[0] 56 | 57 | if transport, ok := transportRegistry[scheme]; ok { 58 | driver := transport.Open(uri) 59 | conn := NewConnection(driver) 60 | return &Broker{conn} 61 | } 62 | 63 | panic(fmt.Sprintf("Unknown transport [%s]", scheme)) 64 | } 65 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2013, Matt Robenolt 2 | All rights reserved. 3 | 4 | Redistribution and use in source and binary forms, with or without modification, 5 | are permitted provided that the following conditions are met: 6 | 7 | Redistributions of source code must retain the above copyright notice, this 8 | list of conditions and the following disclaimer. 9 | 10 | Redistributions in binary form must reproduce the above copyright notice, this 11 | list of conditions and the following disclaimer in the documentation and/or 12 | other materials provided with the distribution. 13 | 14 | Neither the name of the {organization} nor the names of its 15 | contributors may be used to endorse or promote products derived from 16 | this software without specific prior written permission. 17 | 18 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 19 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 20 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 21 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR 22 | ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 23 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 24 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON 25 | ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 26 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 27 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 28 | -------------------------------------------------------------------------------- /task.go: -------------------------------------------------------------------------------- 1 | package celery 2 | 3 | import ( 4 | "bytes" 5 | "fmt" 6 | "time" 7 | ) 8 | 9 | const CELERY_FORMAT = `"2006-01-02T15:04:05.999999999"` 10 | 11 | type celeryTime struct { 12 | time.Time 13 | } 14 | 15 | var null = []byte("null") 16 | 17 | func (ct *celeryTime) UnmarshalJSON(data []byte) (err error) { 18 | if bytes.Equal(data, null) { 19 | return 20 | } 21 | t, err := time.Parse(CELERY_FORMAT, string(data)) 22 | if err == nil { 23 | *ct = celeryTime{t} 24 | } 25 | return 26 | } 27 | 28 | func (ct *celeryTime) MarshalJSON() (data []byte, err error) { 29 | if ct.IsZero() { 30 | return null, nil 31 | } 32 | return []byte(ct.Format(CELERY_FORMAT)), nil 33 | } 34 | 35 | type Receipt interface { 36 | Reply(string, interface{}) 37 | Ack() 38 | Requeue() 39 | Reject() 40 | } 41 | 42 | type Task struct { 43 | Task string `json:"task"` 44 | Id string `json:"id"` 45 | Args []interface{} `json:"args"` 46 | Kwargs map[string]interface{} `json:"kwargs"` 47 | Retries int `json:"retries"` 48 | Eta celeryTime `json:"eta"` 49 | Expires celeryTime `json:"expires"` 50 | Receipt Receipt `json:"-"` 51 | } 52 | 53 | func (t *Task) Ack(result interface{}) { 54 | if result != nil { 55 | t.Receipt.Reply(t.Id, result) 56 | } 57 | t.Receipt.Ack() 58 | } 59 | 60 | func (t *Task) Requeue() { 61 | go func() { 62 | time.Sleep(time.Second) 63 | t.Receipt.Requeue() 64 | }() 65 | } 66 | 67 | func (t *Task) Reject() { 68 | t.Receipt.Reject() 69 | } 70 | 71 | func (t *Task) String() string { 72 | return fmt.Sprintf("%s[%s]", t.Task, t.Id) 73 | } 74 | -------------------------------------------------------------------------------- /connection.go: -------------------------------------------------------------------------------- 1 | package celery 2 | 3 | type Transport interface { 4 | Open(string) Driver 5 | } 6 | 7 | type Driver interface { 8 | Connect() error 9 | DeclareExchange(*Exchange) error 10 | DeclareQueue(*Queue) error 11 | Bind(*Binding) error 12 | GetMessages(*Queue, int) (<-chan *Message, error) 13 | Publish(*Publishing) error 14 | IsConnected() bool 15 | } 16 | 17 | type Connection struct { 18 | driver Driver 19 | } 20 | 21 | func (c *Connection) Ping() (err error) { 22 | if c.driver.IsConnected() { 23 | return 24 | } 25 | err = c.driver.Connect() 26 | if err != nil { 27 | // lol, not sure what we should do here 28 | logger.Error("Error connecting [%s]", err) 29 | return 30 | } 31 | return 32 | } 33 | 34 | func (c *Connection) DeclareExchange(e *Exchange) error { 35 | err := c.Ping() 36 | if err != nil { 37 | return err 38 | } 39 | return c.driver.DeclareExchange(e) 40 | } 41 | 42 | func (c *Connection) DeclareQueue(q *Queue) error { 43 | err := c.Ping() 44 | if err != nil { 45 | return err 46 | } 47 | logger.Info("Declaring queue [%s]", q.Name) 48 | return c.driver.DeclareQueue(q) 49 | } 50 | 51 | func (c *Connection) Bind(b *Binding) error { 52 | err := c.Ping() 53 | if err != nil { 54 | return err 55 | } 56 | return c.driver.Bind(b) 57 | } 58 | 59 | func (c *Connection) Consume(q *Queue, rate int) (<-chan *Message, error) { 60 | err := c.Ping() 61 | if err != nil { 62 | return nil, err 63 | } 64 | logger.Info("Consuming from [%s]", q.Name) 65 | return c.driver.GetMessages(q, rate) 66 | } 67 | 68 | func NewConnection(driver Driver) *Connection { 69 | return &Connection{driver: driver} 70 | } 71 | 72 | var transportRegistry = make(map[string]Transport) 73 | 74 | func RegisterTransport(name string, t Transport) { 75 | transportRegistry[name] = t 76 | } 77 | -------------------------------------------------------------------------------- /celery.go: -------------------------------------------------------------------------------- 1 | package celery 2 | 3 | import ( 4 | "flag" 5 | "time" 6 | "os" 7 | "os/signal" 8 | "runtime" 9 | "errors" 10 | "fmt" 11 | "encoding/json" 12 | "sync" 13 | "syscall" 14 | ) 15 | 16 | var ( 17 | broker = flag.String("broker", "amqp://guest:guest@localhost:5672//", "Broker") 18 | queue = flag.String("Q", "celery", "queue") 19 | concurrency = flag.Int("c", runtime.NumCPU(), "concurrency") 20 | ) 21 | 22 | type Worker interface { 23 | Exec(*Task) (interface{}, error) 24 | } 25 | 26 | var registry = make(map[string]Worker) 27 | 28 | func RegisterTask(name string, worker Worker) { 29 | registry[name] = worker 30 | } 31 | 32 | var ( 33 | RetryError = errors.New("Retry task again") 34 | RejectError = errors.New("Reject task") 35 | ) 36 | 37 | func shutdown(status int) { 38 | fmt.Println("\nceleryd: Warm shutdown") 39 | os.Exit(status) 40 | } 41 | 42 | func Init() { 43 | flag.Parse() 44 | SetupLogging() 45 | 46 | runtime.GOMAXPROCS(*concurrency) 47 | broker := NewBroker(*broker) 48 | fmt.Println("") 49 | fmt.Println("[Tasks]") 50 | for key, _ := range registry { 51 | fmt.Printf(" %s\n", key) 52 | } 53 | fmt.Println("") 54 | hostname, _ := os.Hostname() 55 | logger.Warn("celery@%s ready.", hostname) 56 | 57 | queue := NewDurableQueue(*queue) 58 | deliveries := broker.StartConsuming(queue, *concurrency) 59 | var wg sync.WaitGroup 60 | draining := false 61 | go func() { 62 | c := make(chan os.Signal, 1) 63 | signal.Notify(c, os.Interrupt, syscall.SIGTERM) 64 | for _ = range c { 65 | // If interrupting for the second time, 66 | // terminate un-gracefully 67 | if draining { 68 | shutdown(1) 69 | } 70 | fmt.Println("\nceleryd: Hitting Ctrl+C again will terminate all running tasks!") 71 | // Gracefully shut down 72 | draining = true 73 | go func() { 74 | wg.Wait() 75 | shutdown(0) 76 | }() 77 | } 78 | }() 79 | for !draining { 80 | task := <- deliveries 81 | wg.Add(1) 82 | go func(task *Task) { 83 | defer wg.Done() 84 | if worker, ok := registry[task.Task]; ok { 85 | logger.Info("Got task from broker: %s", task) 86 | start := time.Now() 87 | result, err := worker.Exec(task) 88 | end := time.Now() 89 | if err != nil { 90 | switch err { 91 | case RetryError: 92 | task.Requeue() 93 | default: 94 | task.Reject() 95 | } 96 | } else { 97 | logger.Info(func()string { 98 | res, _ := json.Marshal(result) 99 | return fmt.Sprintf("Task %s succeeded in %s: %s", task, end.Sub(start), res) 100 | }) 101 | task.Ack(result) 102 | } 103 | } else { 104 | task.Reject() 105 | logger.Error("Received unregistered task of type '%s'.\nThe message has been ignored and discarded.\n", task.Task) 106 | } 107 | }(task) 108 | } 109 | } 110 | -------------------------------------------------------------------------------- /amqp_connection.go: -------------------------------------------------------------------------------- 1 | package celery 2 | 3 | import ( 4 | "strings" 5 | "github.com/streadway/amqp" 6 | "time" 7 | "encoding/json" 8 | ) 9 | 10 | type AMQPReceipt struct { 11 | driver *AMQPDriver 12 | delivery amqp.Delivery 13 | } 14 | 15 | func (r *AMQPReceipt) Ack() { 16 | r.delivery.Ack(false) 17 | } 18 | 19 | func (r *AMQPReceipt) Requeue() { 20 | r.delivery.Reject(true) 21 | } 22 | 23 | func (r *AMQPReceipt) Reject() { 24 | r.delivery.Reject(false) 25 | } 26 | 27 | func (r *AMQPReceipt) Reply(id string, data interface{}) { 28 | result := &Result{ 29 | Status: StatusSuccess, 30 | Result: data, 31 | Id: id, 32 | } 33 | 34 | id = strings.Replace(id, "-", "", -1) 35 | 36 | payload, err := json.Marshal(result) 37 | if err != nil { 38 | logger.Error("Error marshalling reply [%s]", err) 39 | return 40 | } 41 | 42 | r.driver.Connect() 43 | err = r.driver.DeclareQueue(NewExpiringQueue(id, 86400000)) 44 | if err != nil { 45 | logger.Error("Error declaring queue [%s]", err) 46 | return 47 | } 48 | 49 | publishing := &Publishing{ 50 | Key: id, 51 | Exchange: NewDurableExchange(""), 52 | Body: payload, 53 | } 54 | err = r.driver.Publish(publishing) 55 | if err != nil { 56 | logger.Error("Error publishing [%s]", err) 57 | return 58 | } 59 | } 60 | 61 | type AMQPDriver struct { 62 | uris []string 63 | alive bool 64 | i int 65 | channel *amqp.Channel 66 | } 67 | 68 | func (c *AMQPDriver) Connect() (err error) { 69 | if c.alive { 70 | return 71 | } 72 | defer func() { 73 | // On next connect, use the next uri 74 | c.i = (c.i + 1) % len(c.uris) 75 | }() 76 | 77 | logger.Info("Dialing [%s]", c.uris[c.i]) 78 | conn, err := amqp.Dial(c.uris[c.i]) 79 | if err != nil { 80 | return 81 | } 82 | 83 | c.channel, err = conn.Channel() 84 | if err != nil { 85 | return 86 | } 87 | c.alive = true 88 | return 89 | } 90 | 91 | func (c *AMQPDriver) IsConnected() bool { 92 | return c.alive 93 | } 94 | 95 | func (c *AMQPDriver) DeclareExchange(e *Exchange) error { 96 | return c.channel.ExchangeDeclare( 97 | e.Name, 98 | e.Type, 99 | e.Durable, 100 | e.DeleteWhenComplete, 101 | false, // internal 102 | false, // noWait 103 | nil, 104 | ) 105 | } 106 | 107 | func (c *AMQPDriver) DeclareQueue(q *Queue) error { 108 | var ( 109 | arguments amqp.Table 110 | ) 111 | if q.Ttl > 0 { 112 | arguments = amqp.Table{"x-expires": int32(q.Ttl)} 113 | } 114 | _, err := c.channel.QueueDeclare( 115 | q.Name, 116 | q.Durable, 117 | q.DeleteWhenUnused, 118 | false, // exclusive 119 | false, // noWait 120 | arguments, 121 | ) 122 | return err 123 | } 124 | 125 | func (c *AMQPDriver) Bind(b *Binding) error { 126 | return c.channel.QueueBind( 127 | b.Queue.Name, 128 | b.Name, 129 | b.Exchange.Name, 130 | false, // noWait 131 | nil, // arguments 132 | ) 133 | } 134 | 135 | func (c *AMQPDriver) Publish(p *Publishing) error { 136 | msg := amqp.Publishing{ 137 | DeliveryMode: amqp.Persistent, 138 | Timestamp: time.Now(), 139 | ContentType: "application/json", 140 | Body: p.Body, 141 | } 142 | return c.channel.Publish( 143 | p.Exchange.Name, 144 | p.Key, 145 | false, 146 | false, 147 | msg, 148 | ) 149 | } 150 | 151 | func (c *AMQPDriver) GetMessages(q *Queue, rate int) (<-chan *Message, error) { 152 | c.channel.Qos(rate, 0, false) 153 | deliveries, err := c.channel.Consume( 154 | q.Name, 155 | "", // consumerTag 156 | false, // autoAck 157 | false, // exclusive 158 | false, // noLocal 159 | false, // noWait 160 | nil, // arguments 161 | ) 162 | if err != nil { 163 | c.alive = false 164 | return nil, err 165 | } 166 | 167 | messages := make(chan *Message) 168 | go func() { 169 | for d := range deliveries { 170 | messages <- &Message{ 171 | ContentType: d.ContentType, 172 | Body: d.Body, 173 | Receipt: &AMQPReceipt{ 174 | driver: c, 175 | delivery: d, 176 | }, 177 | } 178 | } 179 | // connection was lost 180 | c.alive = false 181 | close(messages) 182 | }() 183 | return messages, nil 184 | } 185 | 186 | type AMQP struct {} 187 | func (a *AMQP) Open(uri string) Driver { 188 | return &AMQPDriver{ 189 | uris: strings.Split(uri, ";"), 190 | } 191 | } 192 | 193 | func init() { 194 | RegisterTransport("amqp", &AMQP{}) 195 | } 196 | --------------------------------------------------------------------------------