├── .dockerignore ├── .gitignore ├── Dockerfile ├── LICENSE ├── Makefile ├── README.md ├── adminserver └── admin.go ├── amqp-0-9-1.xml ├── amqp ├── amqp.proto ├── amqpreader.go ├── amqpwriter.go ├── constants_generated.go ├── domains_generated.go ├── errors.go ├── messages.proto ├── protocol_generated.proto ├── protocol_protobuf_readwrite_generated.go ├── readwrite_test.go ├── table.go ├── table_test.go ├── testlib.go ├── tx.go ├── tx_test.go └── types.go ├── amqp0-9-1.extended.xml ├── amqpgen ├── amqpgen.go └── templates.go ├── binding ├── binding.go └── binding_test.go ├── config.default.json ├── consumer ├── consumer.go └── consumer_test.go ├── dev └── config.json ├── dispatchd ├── config.go └── main.go ├── exchange ├── exchange.go └── exchange_test.go ├── gen └── server.proto ├── msgstore ├── msgstore.go ├── msgstore_test.go └── testlib.go ├── persist ├── persist.go └── persist_test.go ├── queue ├── queue.go └── queue_test.go ├── scripts ├── benchmark_helper.sh └── cover.py ├── server ├── auth.go ├── basicMethods.go ├── channel.go ├── channelMethods.go ├── connection.go ├── connectionMethods.go ├── exchangeMethods.go ├── queueMethods.go ├── server.go ├── server_consumer_test.go ├── server_exchange_test.go ├── server_publish_test.go ├── server_queue_test.go ├── server_test.go ├── server_tx_test.go └── txMethods.go ├── static ├── admin.html └── admin.js ├── stats ├── stats.go └── stats_test.go └── util ├── util.go └── util_test.go /.dockerignore: -------------------------------------------------------------------------------- 1 | scripts/external/ 2 | .git -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.pb.go 2 | *.cover 3 | *.db 4 | scripts/external/ -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | # 2 | 
# Everything about this is kind of gross, but it does get a server running 3 | # 4 | 5 | FROM centos:latest 6 | # OS setup 7 | RUN yum install -y make golang git 8 | RUN mkdir -p /app/dispatchd && mkdir -p /data/dispatchd/ 9 | RUN yum install -y python-setuptools.noarch gcc-c++ glibc-headers 10 | RUN easy_install mako 11 | 12 | # protobuf 13 | RUN cd /tmp && curl -L -o protobuf-2.6.1.tar.gz https://github.com/google/protobuf/releases/download/v2.6.1/protobuf-2.6.1.tar.gz 14 | RUN cd /tmp && tar -xzf protobuf-2.6.1.tar.gz 15 | RUN cd /tmp/protobuf-2.6.1/ && ./configure && make install 16 | 17 | # Build dispatchd 18 | ENV BUILD_DIR /app/dispatchd/src/github.com/jeffjenkins/dispatchd/ 19 | RUN mkdir -p $BUILD_DIR 20 | COPY . $BUILD_DIR 21 | ENV GOPATH /app/dispatchd 22 | RUN cd $BUILD_DIR && PATH=$PATH:$GOPATH/bin make install 23 | 24 | # Runtime configuration 25 | ENV STATIC_PATH $BUILD_DIR/static 26 | RUN cp $BUILD_DIR/config.default.json /etc/dispatchd.json 27 | CMD ["/app/dispatchd/bin/server", "-config-file=/etc/dispatchd.json", "-persist-dir=/data/dispatchd/"] 28 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2015 Jeffrey Jenkins 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | 2 | .PHONY: all protoc_present deps gen_all gen_pb gen_amqp build test full_coverage \ 3 | real_line_count devserver benchmark_dev benchmark install clean 4 | 5 | PROTOC := protoc -I=${GOPATH}/src:${GOPATH}/src/github.com/gogo/protobuf/protobuf/ 6 | PROJECT_PATH := ${GOPATH}/src/github.com/jeffjenkins/dispatchd 7 | 8 | RUN_PORT=5672 9 | PERF_SCRIPT=scripts/external/perf-client/runjava.sh 10 | 11 | all: build 12 | 13 | clean: 14 | rm -Rf scripts/external/ 15 | rm -f */*.pb.go 16 | rm -f ${GOPATH}/bin/server 17 | 18 | protoc_present: 19 | which protoc 20 | 21 | deps: 22 | go get github.com/boltdb/bolt \ 23 | github.com/gogo/protobuf/gogoproto \ 24 | github.com/gogo/protobuf/proto \ 25 | github.com/gogo/protobuf/protoc-gen-gogo \ 26 | github.com/rcrowley/go-metrics \ 27 | github.com/streadway/amqp \ 28 | github.com/wadey/gocovmerge \ 29 | golang.org/x/crypto/bcrypt 30 | 31 | gen_all: deps gen_pb gen_amqp 32 | 33 | gen_pb: gen_amqp protoc_present 34 | $(PROTOC) --gogo_out=${GOPATH}/src ${PROJECT_PATH}/amqp/*.proto 35 | $(PROTOC) --gogo_out=${GOPATH}/src ${PROJECT_PATH}/gen/*.proto 36 | 37 | gen_amqp: 38 | go run amqpgen/*.go --spec=amqp0-9-1.extended.xml && go fmt github.com/jeffjenkins/dispatchd/... 
39 | gofmt -w amqp/*generated*.go 40 | 41 | build: deps gen_all 42 | go build -o ${GOPATH}/dispatchd github.com/jeffjenkins/dispatchd/server 43 | 44 | install: deps gen_all 45 | go install github.com/jeffjenkins/dispatchd/server 46 | 47 | test: deps gen_all 48 | go test -cover github.com/jeffjenkins/dispatchd/... 49 | 50 | full_coverage: test 51 | # Output: $$GOPATH/all.cover 52 | python scripts/cover.py 53 | 54 | real_line_count: 55 | find . | grep '.go$$' | grep -v pb.go | grep -v generated | xargs cat | wc -l 56 | 57 | ${PERF_SCRIPT}: 58 | mkdir -p scripts/external/ 59 | curl -o scripts/external/perf-client.tar.gz 'https://www.rabbitmq.com/releases/rabbitmq-java-client/v3.5.6/rabbitmq-java-client-bin-3.5.6.tar.gz' 60 | tar -C scripts/external/ -zxf scripts/external/perf-client.tar.gz 61 | mv scripts/external/rabbitmq-java-client-bin-3.5.6 scripts/external/perf-client/ 62 | 63 | devserver: install 64 | go install github.com/jeffjenkins/dispatchd/server 65 | STATIC_PATH=${GOPATH}/src/github.com/jeffjenkins/dispatchd/static \ 66 | ${GOPATH}/bin/server \ 67 | -config-file ${GOPATH}/src/github.com/jeffjenkins/dispatchd/dev/config.json 68 | 69 | benchmark_dev: scripts/external/perf-client/runjava.sh 70 | RUN_PORT=1111 scripts/benchmark_helper.sh 71 | 72 | benchmark: scripts/external/perf-client/runjava.sh 73 | RUN_PORT=${RUN_PORT} scripts/benchmark_helper.sh 74 | 75 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # dispatchd - A Message Broker and Queue Server 2 | 3 | ## Status 4 | 5 | `dispatchd` is in alpha. 6 | 7 | It generally works but is not hardened enough for production use. 
More on what it can and can't do below 8 | 9 | ## Features 10 | 11 | * Basically all (and many optional) amqp 0-9-1 features are supported 12 | * Some rabbitmq extensions are implemented: 13 | * nack (almost: the same consumer could receive the message again) 14 | * internal exchanges (flag exists, but currently unused) 15 | * auto-delete exchanges 16 | * Rabbit's reinterpretation of basic.qos 17 | * There is a simple admin page that can show basic info about what's 18 | happening in the server 19 | 20 | Notably missing from the features in the amqp spec are: 21 | 22 | * support for multiple priority levels 23 | * handling of queue/memory limits being exceeded 24 | 25 | ## Configuration 26 | 27 | There are command line flags for basic configuration: 28 | 29 | -admin-port int 30 | Port for admin server. Default: 8080 31 | -amqp-port int 32 | Port for amqp protocol messages. Default: 5672 33 | -config-file string 34 | Directory for the server and message database files. Default: do not read a config file 35 | -debug-port int 36 | Port for the golang debug handlers. Default: 6060 37 | -persist-dir string 38 | Directory for the server and message database files. Default: /data/dispatchd/ 39 | 40 | These options can be overridden if `-config-file` is specified. The config file is JSON and will complain loudly if any types don't look right rather than ignoring or working around them. 41 | 42 | Right now the only config file exclusive options are for users and passwords. In the future the config file will have tuning parameters as well. 43 | 44 | ## Running Dispatchd 45 | 46 | Dispatchd is currently only packaged as a docker image. You can run it with this command: 47 | 48 | docker run \ 49 | -p=8080:8080 \ 50 | -p=5672:5672 \ 51 | --volume=YOUR_CONFIG_FILE:/etc/dispatchd.json \ 52 | --volume=YOUR_DATA_DIR:/data/dispatchd/ \ 53 | dispatchd/dispatchd 54 | 55 | Config file can be left out for the default behaviors. 
The data volume needs 56 | to be specified so that data is persisted outside of the container. 57 | 58 | ## Security/Auth 59 | 60 | Dispatchd uses SASL PLAIN auth as required by the amqp spec. There is a default user (user: guest, pw: guest) which is available if there is no config file. If there is a config file the user entries look like this: 61 | 62 | { 63 | "users" : { 64 | "guest" : { 65 | "password_bcrypt_base64" : "JDJhJDExJENobGk4dG5rY0RGemJhTjhsV21xR3VNNnFZZ1ZqTzUzQWxtbGtyMHRYN3RkUHMuYjF5SUt5" 66 | } 67 | } 68 | } 69 | 70 | Passwords are generated using bcrypt and then base64 encoded. 71 | 72 | ## Performance compared to RabbitMQ 73 | 74 | All perf testing is done with RabbitMQ's Java perf testing tool. Generally using this command line: 75 | 76 | ./runjava.sh com.rabbitmq.examples.PerfTest --exchange perf-test -uri amqp://guest:guest@localhost:5672 --queue some-queue --consumers 4 --producers 2 --qos 100 77 | 78 | On a late 2014 i7 mac mini the results were as follows: 79 | 80 | RabbitMQ Send: ~13000 msg/s, consistent 81 | RabbitMQ Recv: ~10000 msg/s, consistent 82 | Dispatchd Send: ~18000 msg/s, varying between 15k and 22k 83 | Dispatchd Recv: ~18000 msg/s, consistent 84 | 85 | It is unclear whether this difference in performance would go away if the server had complete feature parity with Rabbit. Based on the feature diff it isn't clear why it would, but Rabbit is highly tuned and extremely performant. 86 | 87 | With the `-flag persistent` performance drops a bit: 88 | 89 | RabbitMQ Send: ~9000 msg/s, varying between 6k and 12k 90 | RabbitMQ Recv: ~7000 msg/s, consistent 91 | Dispatchd Send: ~13500 msg/s, varying between 11k and 15k 92 | Dispatchd Recv: ~13000 msg/s, varying between 11k and 15k 93 | 94 | The one thing to note about Dispatchd's send (publish) performance here is that it does not have any internal flow control, so it can get backlogged writing messages to disk. 
It could be that Rabbit is doing a sustainable 9k and Dispatchd would lose way more messages than come in during one coalesce interval. 95 | 96 | On the Receieve (deliver) side, Dispatchd reconciles messages which don't need to by persisted (because they have already been delivered/acked) and so there is no performance hit to persistence if all messages are delivered before the next write to disk happens (every 200ms by default). 97 | 98 | ## Testing and Code Coverage 99 | 100 | Dispatchd has a fairly extensive test suite. Almost all of the major functions are tested and test coverage—ignoring generated code—is around 80% 101 | 102 | ## What's Next? How do I request changes? 103 | 104 | Non-trivial changes are tracked through [github issues](https://github.com/jeffjenkins/dispatchd/issues). -------------------------------------------------------------------------------- /adminserver/admin.go: -------------------------------------------------------------------------------- 1 | package adminserver 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | "github.com/jeffjenkins/dispatchd/server" 7 | "github.com/rcrowley/go-metrics" 8 | "net/http" 9 | "os" 10 | ) 11 | 12 | func homeJSON(w http.ResponseWriter, r *http.Request, server *server.Server) { 13 | var b, err = json.MarshalIndent(server, "", " ") 14 | if err != nil { 15 | w.Write([]byte(err.Error())) 16 | } 17 | w.Write(b) 18 | } 19 | 20 | func statsJSON(w http.ResponseWriter, r *http.Request, server *server.Server) { 21 | // fmt.Println(metrics.DefaultRegistry) 22 | var b, err = json.MarshalIndent(metrics.DefaultRegistry, "", " ") 23 | if err != nil { 24 | w.Write([]byte(err.Error())) 25 | } 26 | w.Write(b) 27 | } 28 | 29 | func StartAdminServer(server *server.Server, port int) { 30 | // Static files 31 | var path = os.Getenv("STATIC_PATH") 32 | if len(path) == 0 { 33 | panic("No static file path in $STATIC_PATH!") 34 | } 35 | var fileServer = http.FileServer(http.Dir(path)) 36 | http.Handle("/static/", 
http.StripPrefix("/static", fileServer)) 37 | 38 | // Home 39 | http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { 40 | var p = path + "/admin.html" 41 | fmt.Println(p) 42 | http.ServeFile(w, r, p) 43 | }) 44 | 45 | // API 46 | http.HandleFunc("/api/server", func(w http.ResponseWriter, r *http.Request) { 47 | homeJSON(w, r, server) 48 | }) 49 | 50 | http.HandleFunc("/api/stats", func(w http.ResponseWriter, r *http.Request) { 51 | statsJSON(w, r, server) 52 | }) 53 | 54 | // Boot admin server 55 | fmt.Printf("Admin server on port %d, static files from: %s\n", port, path) 56 | http.ListenAndServe(fmt.Sprintf(":%d", port), nil) 57 | } 58 | -------------------------------------------------------------------------------- /amqp/amqp.proto: -------------------------------------------------------------------------------- 1 | 2 | package amqp; 3 | 4 | import "github.com/gogo/protobuf/gogoproto/gogo.proto"; 5 | 6 | option (gogoproto.marshaler_all) = true; 7 | option (gogoproto.unmarshaler_all) = true; 8 | option (gogoproto.sizer_all) = true; 9 | 10 | message FieldValuePair { 11 | option (gogoproto.goproto_unrecognized) = false; 12 | option (gogoproto.goproto_getters) = false; 13 | optional string key = 1; 14 | optional FieldValue value = 2; 15 | } 16 | 17 | message Table { 18 | option (gogoproto.goproto_unrecognized) = false; 19 | option (gogoproto.goproto_getters) = false; 20 | repeated FieldValuePair table = 1; 21 | } 22 | 23 | message FieldArray { 24 | option (gogoproto.goproto_unrecognized) = false; 25 | option (gogoproto.goproto_getters) = false; 26 | repeated FieldValue value = 1; 27 | } 28 | 29 | message Decimal { 30 | option (gogoproto.goproto_unrecognized) = false; 31 | option (gogoproto.goproto_getters) = false; 32 | optional uint32 scale = 1 [(gogoproto.casttype) = "uint8"]; 33 | optional int32 value = 2; 34 | } 35 | 36 | message FieldValue { 37 | option (gogoproto.goproto_unrecognized) = false; 38 | option (gogoproto.goproto_getters) = false; 39 
| oneof value { 40 | bool v_boolean = 1; 41 | int32 v_int8 = 2 [(gogoproto.casttype) = "int8"]; 42 | uint32 v_uint8 = 3 [(gogoproto.casttype) = "uint8"]; 43 | int32 v_int16 = 4 [(gogoproto.casttype) = "int16"]; 44 | uint32 v_uint16 = 5 [(gogoproto.casttype) = "uint16"]; 45 | int32 v_int32 = 6 ; 46 | uint32 v_uint32 = 7 ; 47 | int64 v_int64 = 8 ; 48 | uint64 v_uint64 = 9 ; 49 | float v_float = 10 ; 50 | double v_double = 11 ; 51 | Decimal v_decimal = 12 ; 52 | string v_shortstr = 13 ; // < 256 bytes 53 | bytes v_longstr = 14 ; 54 | FieldArray v_array = 15 ; 55 | uint64 v_timestamp = 16 ; 56 | Table v_table = 17 ; 57 | bytes v_bytes = 18 ; 58 | } 59 | } 60 | -------------------------------------------------------------------------------- /amqp/amqpreader.go: -------------------------------------------------------------------------------- 1 | package amqp 2 | 3 | import ( 4 | "bytes" 5 | "encoding/binary" 6 | "errors" 7 | "fmt" 8 | "io" 9 | ) 10 | 11 | func ReadFrame(reader io.Reader) (*WireFrame, error) { 12 | // Using little since other functions will assume big 13 | 14 | // get fixed size portion 15 | var incoming = make([]byte, 1+2+4) 16 | var err = binary.Read(reader, binary.LittleEndian, incoming) 17 | if err != nil { 18 | return nil, err 19 | } 20 | var memReader = bytes.NewBuffer(incoming) 21 | 22 | var f = &WireFrame{} 23 | 24 | // The reads from memReader are guaranteed to succeed because the 7 bytes 25 | // were allocated above and that is what is read out 26 | 27 | // frame type 28 | frameType, _ := ReadOctet(memReader) 29 | f.FrameType = frameType 30 | 31 | // channel 32 | channel, _ := ReadShort(memReader) 33 | f.Channel = channel 34 | 35 | // Variable length part 36 | var length uint32 37 | err = binary.Read(memReader, binary.BigEndian, &length) 38 | 39 | var slice = make([]byte, length+1) 40 | err = binary.Read(reader, binary.BigEndian, slice) 41 | if err != nil { 42 | return nil, errors.New("Bad frame payload: " + err.Error()) 43 | } 44 | f.Payload = 
slice[0:length] 45 | return f, nil 46 | } 47 | 48 | // Fields 49 | 50 | func ReadOctet(buf io.Reader) (data byte, err error) { 51 | err = binary.Read(buf, binary.BigEndian, &data) 52 | if err != nil { 53 | return 0, errors.New("Could not read byte: " + err.Error()) 54 | } 55 | return data, nil 56 | } 57 | 58 | func ReadShort(buf io.Reader) (data uint16, err error) { 59 | err = binary.Read(buf, binary.BigEndian, &data) 60 | if err != nil { 61 | return 0, errors.New("Could not read uint16: " + err.Error()) 62 | } 63 | return data, nil 64 | } 65 | 66 | func ReadLong(buf io.Reader) (data uint32, err error) { 67 | err = binary.Read(buf, binary.BigEndian, &data) 68 | if err != nil { 69 | return 0, errors.New("Could not read uint32: " + err.Error()) 70 | } 71 | return data, nil 72 | } 73 | 74 | func ReadLonglong(buf io.Reader) (data uint64, err error) { 75 | err = binary.Read(buf, binary.BigEndian, &data) 76 | if err != nil { 77 | return 0, errors.New("Could not read uint64: " + err.Error()) 78 | } 79 | return data, nil 80 | } 81 | 82 | func ReadShortstr(buf io.Reader) (string, error) { 83 | var length uint8 84 | var err = binary.Read(buf, binary.BigEndian, &length) 85 | if err != nil { 86 | return "", err 87 | } 88 | var slice = make([]byte, length) 89 | err = binary.Read(buf, binary.BigEndian, slice) 90 | if err != nil { 91 | return "", err 92 | } 93 | return string(slice), nil 94 | } 95 | 96 | func ReadLongstr(buf io.Reader) ([]byte, error) { 97 | var length uint32 98 | var err = binary.Read(buf, binary.BigEndian, &length) 99 | if err != nil { 100 | return nil, err 101 | } 102 | var slice = make([]byte, length) 103 | err = binary.Read(buf, binary.BigEndian, slice) 104 | if err != nil { 105 | return nil, err 106 | } 107 | return slice, err 108 | } 109 | 110 | // Can't get coverage on this easily since I can't currently generate 111 | // timestamp values in Tables because protobuf doesn't give me a type that 112 | // is different from uint64 113 | func ReadTimestamp(buf 
io.Reader) (uint64, error) { // pragma: nocover 114 | var t uint64 115 | var err = binary.Read(buf, binary.BigEndian, &t) 116 | if err != nil { // pragma: nocover 117 | return 0, errors.New("Could not read uint64") 118 | } 119 | return t, nil // pragma: nocover 120 | } 121 | 122 | func ReadTable(reader io.Reader, strictMode bool) (*Table, error) { 123 | var seen = make(map[string]bool) 124 | var table = &Table{Table: make([]*FieldValuePair, 0)} 125 | var byteData, err = ReadLongstr(reader) 126 | if err != nil { 127 | return nil, errors.New("Error reading table longstr: " + err.Error()) 128 | } 129 | var data = bytes.NewBuffer(byteData) 130 | for data.Len() > 0 { 131 | key, err := ReadShortstr(data) 132 | if err != nil { 133 | return nil, errors.New("Error reading key: " + err.Error()) 134 | } 135 | if _, found := seen[key]; found { 136 | return nil, fmt.Errorf("Duplicate key in table: %s", key) 137 | } 138 | value, err := readValue(data, strictMode) 139 | if err != nil { 140 | return nil, errors.New("Error reading value for '" + key + "': " + err.Error()) 141 | } 142 | table.Table = append(table.Table, &FieldValuePair{Key: &key, Value: value}) 143 | seen[key] = true 144 | } 145 | return table, nil 146 | } 147 | 148 | func readValue(reader io.Reader, strictMode bool) (*FieldValue, error) { 149 | var t, err = ReadOctet(reader) 150 | if err != nil { 151 | return nil, err 152 | } 153 | 154 | switch { 155 | case t == 't': 156 | var v, err = ReadOctet(reader) 157 | if err != nil { 158 | return nil, err 159 | } 160 | var vb = v != 0 161 | return &FieldValue{Value: &FieldValue_VBoolean{VBoolean: vb}}, nil 162 | case t == 'b': 163 | var v int8 164 | if err = binary.Read(reader, binary.BigEndian, &v); err != nil { 165 | return nil, err 166 | } 167 | return &FieldValue{Value: &FieldValue_VInt8{VInt8: v}}, nil 168 | case t == 'B' && strictMode: 169 | var v uint8 170 | if err = binary.Read(reader, binary.BigEndian, &v); err != nil { 171 | return nil, err 172 | } 173 | return 
&FieldValue{Value: &FieldValue_VUint8{VUint8: v}}, nil 174 | case t == 'U' && strictMode || t == 's' && !strictMode: 175 | var v int16 176 | if err = binary.Read(reader, binary.BigEndian, &v); err != nil { 177 | return nil, err 178 | } 179 | return &FieldValue{Value: &FieldValue_VInt16{VInt16: v}}, nil 180 | case t == 'u' && strictMode: 181 | var v uint16 182 | if err = binary.Read(reader, binary.BigEndian, &v); err != nil { 183 | return nil, err 184 | } 185 | return &FieldValue{Value: &FieldValue_VUint16{VUint16: v}}, nil 186 | case t == 'I': 187 | var v int32 188 | if err = binary.Read(reader, binary.BigEndian, &v); err != nil { 189 | return nil, err 190 | } 191 | return &FieldValue{Value: &FieldValue_VInt32{VInt32: v}}, nil 192 | case t == 'i' && strictMode: 193 | var v uint32 194 | if err = binary.Read(reader, binary.BigEndian, &v); err != nil { 195 | return nil, err 196 | } 197 | return &FieldValue{Value: &FieldValue_VUint32{VUint32: v}}, nil 198 | case t == 'L' && strictMode || t == 'l' && !strictMode: 199 | var v int64 200 | if err = binary.Read(reader, binary.BigEndian, &v); err != nil { 201 | return nil, err 202 | } 203 | return &FieldValue{Value: &FieldValue_VInt64{VInt64: v}}, nil 204 | case t == 'l' && strictMode: 205 | var v uint64 206 | if err = binary.Read(reader, binary.BigEndian, &v); err != nil { 207 | return nil, err 208 | } 209 | return &FieldValue{Value: &FieldValue_VUint64{VUint64: v}}, nil 210 | case t == 'f': 211 | var v float32 212 | if err = binary.Read(reader, binary.BigEndian, &v); err != nil { 213 | return nil, err 214 | } 215 | return &FieldValue{Value: &FieldValue_VFloat{VFloat: v}}, nil 216 | case t == 'd': 217 | var v float64 218 | if err = binary.Read(reader, binary.BigEndian, &v); err != nil { 219 | return nil, err 220 | } 221 | return &FieldValue{Value: &FieldValue_VDouble{VDouble: v}}, nil 222 | case t == 'D': 223 | var scale uint8 = 0 224 | var val int32 = 0 225 | var v = Decimal{&scale, &val} 226 | if err = binary.Read(reader, 
binary.BigEndian, v.Scale); err != nil { 227 | return nil, err 228 | } 229 | if err = binary.Read(reader, binary.BigEndian, v.Value); err != nil { 230 | return nil, err 231 | } 232 | return &FieldValue{Value: &FieldValue_VDecimal{VDecimal: &v}}, nil 233 | case t == 's' && strictMode: 234 | v, err := ReadShortstr(reader) 235 | if err != nil { 236 | return nil, err 237 | } 238 | return &FieldValue{Value: &FieldValue_VShortstr{VShortstr: v}}, nil 239 | case t == 'S': 240 | v, err := ReadLongstr(reader) 241 | if err != nil { 242 | return nil, err 243 | } 244 | return &FieldValue{Value: &FieldValue_VLongstr{VLongstr: v}}, nil 245 | case t == 'A': 246 | v, err := readArray(reader, strictMode) 247 | if err != nil { 248 | return nil, err 249 | } 250 | return &FieldValue{Value: &FieldValue_VArray{VArray: &FieldArray{Value: v}}}, nil 251 | case t == 'T': 252 | v, err := ReadTimestamp(reader) 253 | if err != nil { 254 | return nil, err 255 | } 256 | return &FieldValue{Value: &FieldValue_VTimestamp{VTimestamp: v}}, nil 257 | case t == 'F': 258 | v, err := ReadTable(reader, strictMode) 259 | if err != nil { 260 | return nil, err 261 | } 262 | return &FieldValue{Value: &FieldValue_VTable{VTable: v}}, nil 263 | case t == 'V': 264 | return nil, nil 265 | case t == 'x': 266 | v, err := ReadLongstr(reader) 267 | if err != nil { 268 | return nil, err 269 | } 270 | return &FieldValue{Value: &FieldValue_VBytes{VBytes: v}}, nil 271 | } 272 | return nil, fmt.Errorf("Unknown table value type '%c' (%d)", t, t) 273 | } 274 | 275 | func readArray(reader io.Reader, strictMode bool) ([]*FieldValue, error) { 276 | var ret = make([]*FieldValue, 0, 0) 277 | var longstr, errs = ReadLongstr(reader) 278 | if errs != nil { 279 | return nil, errs 280 | } 281 | var data = bytes.NewBuffer(longstr) 282 | for data.Len() > 0 { 283 | var value, err = readValue(data, strictMode) 284 | if err != nil { 285 | return nil, err 286 | } 287 | ret = append(ret, value) 288 | } 289 | return ret, nil 290 | } 291 | 292 
| func (props *BasicContentHeaderProperties) ReadProps(flags uint16, reader io.Reader, strictMode bool) (err error) { 293 | if MaskContentType&flags != 0 { 294 | v, err := ReadShortstr(reader) 295 | props.ContentType = &v 296 | if err != nil { 297 | return err 298 | } 299 | } 300 | if MaskContentEncoding&flags != 0 { 301 | v, err := ReadShortstr(reader) 302 | props.ContentEncoding = &v 303 | if err != nil { 304 | return err 305 | } 306 | } 307 | if MaskHeaders&flags != 0 { 308 | v, err := ReadTable(reader, strictMode) 309 | props.Headers = v 310 | if err != nil { 311 | return err 312 | } 313 | } 314 | if MaskDeliveryMode&flags != 0 { 315 | v, err := ReadOctet(reader) 316 | props.DeliveryMode = &v 317 | if err != nil { 318 | return err 319 | } 320 | } 321 | if MaskPriority&flags != 0 { 322 | v, err := ReadOctet(reader) 323 | props.Priority = &v 324 | if err != nil { 325 | return err 326 | } 327 | } 328 | if MaskCorrelationId&flags != 0 { 329 | v, err := ReadShortstr(reader) 330 | props.CorrelationId = &v 331 | if err != nil { 332 | return err 333 | } 334 | } 335 | if MaskReplyTo&flags != 0 { 336 | v, err := ReadShortstr(reader) 337 | props.ReplyTo = &v 338 | if err != nil { 339 | return err 340 | } 341 | } 342 | if MaskExpiration&flags != 0 { 343 | v, err := ReadShortstr(reader) 344 | props.Expiration = &v 345 | if err != nil { 346 | return err 347 | } 348 | } 349 | if MaskMessageId&flags != 0 { 350 | v, err := ReadShortstr(reader) 351 | props.MessageId = &v 352 | if err != nil { 353 | return err 354 | } 355 | } 356 | if MaskTimestamp&flags != 0 { 357 | v, err := ReadLonglong(reader) 358 | props.Timestamp = &v 359 | if err != nil { 360 | return err 361 | } 362 | } 363 | if MaskType&flags != 0 { 364 | v, err := ReadShortstr(reader) 365 | props.Type = &v 366 | if err != nil { 367 | return err 368 | } 369 | } 370 | if MaskUserId&flags != 0 { 371 | v, err := ReadShortstr(reader) 372 | props.UserId = &v 373 | if err != nil { 374 | return err 375 | } 376 | } 377 | if 
MaskAppId&flags != 0 { 378 | v, err := ReadShortstr(reader) 379 | props.AppId = &v 380 | if err != nil { 381 | return err 382 | } 383 | } 384 | if MaskReserved&flags != 0 { 385 | v, err := ReadShortstr(reader) 386 | props.Reserved = &v 387 | if err != nil { 388 | return err 389 | } 390 | } 391 | return nil 392 | } 393 | -------------------------------------------------------------------------------- /amqp/amqpwriter.go: -------------------------------------------------------------------------------- 1 | package amqp 2 | 3 | import ( 4 | "bytes" 5 | "encoding/binary" 6 | "errors" 7 | "io" 8 | ) 9 | 10 | func WriteFrame(buf io.Writer, frame *WireFrame) { 11 | bb := make([]byte, 0, 1+2+4+len(frame.Payload)+2) 12 | buf2 := bytes.NewBuffer(bb) 13 | WriteOctet(buf2, frame.FrameType) 14 | WriteShort(buf2, frame.Channel) 15 | WriteLongstr(buf2, frame.Payload) 16 | // buf.Write(frame.Payload.Bytes()) 17 | WriteFrameEnd(buf2) 18 | 19 | // binary.LittleEndian since we want to stick to the system 20 | // byte order and the other write functions are writing BigEndian 21 | // TODO: error checking 22 | binary.Write(buf, binary.LittleEndian, buf2.Bytes()) 23 | } 24 | 25 | // Constants 26 | 27 | func WriteProtocolHeader(buf io.Writer) error { 28 | return binary.Write(buf, binary.BigEndian, []byte{'A', 'M', 'Q', 'P', 0, 0, 9, 1}) 29 | } 30 | 31 | func WriteVersion(buf io.Writer) error { 32 | return binary.Write(buf, binary.BigEndian, []byte{0, 9}) 33 | } 34 | 35 | func WriteFrameEnd(buf io.Writer) error { 36 | return binary.Write(buf, binary.BigEndian, byte(0xCE)) 37 | } 38 | 39 | // Fields 40 | 41 | func WriteOctet(buf io.Writer, b byte) error { 42 | return binary.Write(buf, binary.BigEndian, b) 43 | } 44 | 45 | func WriteShort(buf io.Writer, i uint16) error { 46 | return binary.Write(buf, binary.BigEndian, i) 47 | } 48 | 49 | func WriteLong(buf io.Writer, i uint32) error { 50 | return binary.Write(buf, binary.BigEndian, i) 51 | } 52 | 53 | func WriteLonglong(buf io.Writer, i 
uint64) error { 54 | return binary.Write(buf, binary.BigEndian, i) 55 | } 56 | 57 | func WriteStringChar(buf io.Writer, b byte) error { 58 | return binary.Write(buf, binary.BigEndian, b) 59 | } 60 | 61 | func WriteShortstr(buf io.Writer, s string) error { 62 | if len(s) > int(MaxShortStringLength) { 63 | return errors.New("String too long for short string") 64 | } 65 | err := binary.Write(buf, binary.BigEndian, byte(len(s))) 66 | if err != nil { 67 | return errors.New("Could not write bytes: " + err.Error()) 68 | } 69 | return binary.Write(buf, binary.BigEndian, []byte(s)) 70 | } 71 | 72 | func WriteLongstr(buf io.Writer, bytes []byte) (err error) { 73 | if err = binary.Write(buf, binary.BigEndian, uint32(len(bytes))); err != nil { 74 | return 75 | } 76 | if err = binary.Write(buf, binary.BigEndian, bytes); err != nil { 77 | return 78 | } 79 | return nil 80 | } 81 | 82 | func WriteTimestamp(buf io.Writer, timestamp uint64) error { 83 | return binary.Write(buf, binary.BigEndian, timestamp) 84 | } 85 | 86 | func WriteTable(writer io.Writer, table *Table) error { 87 | var buf = bytes.NewBuffer(make([]byte, 0)) 88 | for _, kv := range table.Table { 89 | if err := WriteShortstr(buf, *kv.Key); err != nil { 90 | return err 91 | } 92 | if err := writeValue(buf, kv.Value); err != nil { 93 | return err 94 | } 95 | } 96 | return WriteLongstr(writer, buf.Bytes()) 97 | } 98 | 99 | func writeArray(writer io.Writer, array []*FieldValue) error { 100 | var buf = bytes.NewBuffer([]byte{}) 101 | for _, v := range array { 102 | if err := writeValue(buf, v); err != nil { 103 | return err 104 | } 105 | } 106 | return WriteLongstr(writer, buf.Bytes()) 107 | } 108 | 109 | func writeValue(writer io.Writer, value *FieldValue) (err error) { 110 | switch v := value.Value.(type) { 111 | case *FieldValue_VBoolean: 112 | if err = binary.Write(writer, binary.BigEndian, byte('t')); err == nil { 113 | if v.VBoolean { 114 | err = WriteOctet(writer, uint8(1)) 115 | } else { 116 | err = 
WriteOctet(writer, uint8(0)) 117 | } 118 | } 119 | case *FieldValue_VInt8: 120 | if err = binary.Write(writer, binary.BigEndian, byte('b')); err == nil { 121 | err = binary.Write(writer, binary.BigEndian, int8(v.VInt8)) 122 | } 123 | case *FieldValue_VUint8: 124 | if err = binary.Write(writer, binary.BigEndian, byte('B')); err == nil { 125 | err = binary.Write(writer, binary.BigEndian, uint8(v.VUint8)) 126 | } 127 | case *FieldValue_VInt16: 128 | if err = binary.Write(writer, binary.BigEndian, byte('U')); err == nil { 129 | err = binary.Write(writer, binary.BigEndian, int16(v.VInt16)) 130 | } 131 | case *FieldValue_VUint16: 132 | if err = binary.Write(writer, binary.BigEndian, byte('u')); err == nil { 133 | err = binary.Write(writer, binary.BigEndian, uint16(v.VUint16)) 134 | } 135 | case *FieldValue_VInt32: 136 | if err = binary.Write(writer, binary.BigEndian, byte('I')); err == nil { 137 | err = binary.Write(writer, binary.BigEndian, int32(v.VInt32)) 138 | } 139 | case *FieldValue_VUint32: 140 | if err = binary.Write(writer, binary.BigEndian, byte('i')); err == nil { 141 | err = binary.Write(writer, binary.BigEndian, uint32(v.VUint32)) 142 | } 143 | case *FieldValue_VInt64: 144 | if err = binary.Write(writer, binary.BigEndian, byte('L')); err == nil { 145 | err = binary.Write(writer, binary.BigEndian, int64(v.VInt64)) 146 | } 147 | case *FieldValue_VUint64: 148 | if err = binary.Write(writer, binary.BigEndian, byte('l')); err == nil { 149 | err = binary.Write(writer, binary.BigEndian, uint64(v.VUint64)) 150 | } 151 | case *FieldValue_VFloat: 152 | if err = binary.Write(writer, binary.BigEndian, byte('f')); err == nil { 153 | err = binary.Write(writer, binary.BigEndian, float32(v.VFloat)) 154 | } 155 | case *FieldValue_VDouble: 156 | if err = binary.Write(writer, binary.BigEndian, byte('d')); err == nil { 157 | err = binary.Write(writer, binary.BigEndian, float64(v.VDouble)) 158 | } 159 | case *FieldValue_VDecimal: 160 | if err = binary.Write(writer, 
binary.BigEndian, byte('D')); err == nil { 161 | if err = binary.Write(writer, binary.BigEndian, byte(*v.VDecimal.Scale)); err == nil { 162 | err = binary.Write(writer, binary.BigEndian, uint32(*v.VDecimal.Value)) 163 | } 164 | } 165 | case *FieldValue_VShortstr: 166 | if err = WriteOctet(writer, byte('s')); err == nil { 167 | err = WriteShortstr(writer, v.VShortstr) 168 | } 169 | case *FieldValue_VLongstr: 170 | if err = WriteOctet(writer, byte('S')); err == nil { 171 | err = WriteLongstr(writer, v.VLongstr) 172 | } 173 | case *FieldValue_VArray: 174 | if err = WriteOctet(writer, byte('A')); err == nil { 175 | err = writeArray(writer, v.VArray.Value) 176 | } 177 | case *FieldValue_VTimestamp: 178 | if err = WriteOctet(writer, byte('T')); err == nil { 179 | err = WriteTimestamp(writer, v.VTimestamp) 180 | } 181 | case *FieldValue_VTable: 182 | if err = WriteOctet(writer, byte('F')); err == nil { 183 | err = WriteTable(writer, v.VTable) 184 | } 185 | case nil: 186 | err = binary.Write(writer, binary.BigEndian, byte('V')) 187 | default: 188 | panic("unsupported type!") 189 | } 190 | return 191 | } 192 | 193 | func (props *BasicContentHeaderProperties) WriteProps(writer io.Writer) (flags uint16, err error) { 194 | if props.ContentType != nil { 195 | flags = flags | MaskContentType 196 | err = WriteShortstr(writer, *props.ContentType) 197 | if err != nil { 198 | return 199 | } 200 | } 201 | if props.ContentEncoding != nil { 202 | flags = flags | MaskContentEncoding 203 | err = WriteShortstr(writer, *props.ContentEncoding) 204 | if err != nil { 205 | return 206 | } 207 | } 208 | if props.Headers != nil { 209 | flags = flags | MaskHeaders 210 | err = WriteTable(writer, props.Headers) 211 | if err != nil { 212 | return 213 | } 214 | } 215 | if props.DeliveryMode != nil { 216 | flags = flags | MaskDeliveryMode 217 | err = WriteOctet(writer, *props.DeliveryMode) 218 | if err != nil { 219 | return 220 | } 221 | } 222 | if props.Priority != nil { 223 | flags = flags | 
MaskPriority 224 | err = WriteOctet(writer, *props.Priority) 225 | if err != nil { 226 | return 227 | } 228 | } 229 | if props.CorrelationId != nil { 230 | flags = flags | MaskCorrelationId 231 | err = WriteShortstr(writer, *props.CorrelationId) 232 | if err != nil { 233 | return 234 | } 235 | } 236 | if props.ReplyTo != nil { 237 | flags = flags | MaskReplyTo 238 | err = WriteShortstr(writer, *props.ReplyTo) 239 | if err != nil { 240 | return 241 | } 242 | } 243 | if props.Expiration != nil { 244 | flags = flags | MaskExpiration 245 | err = WriteShortstr(writer, *props.Expiration) 246 | if err != nil { 247 | return 248 | } 249 | } 250 | if props.MessageId != nil { 251 | flags = flags | MaskMessageId 252 | err = WriteShortstr(writer, *props.MessageId) 253 | if err != nil { 254 | return 255 | } 256 | } 257 | if props.Timestamp != nil { 258 | flags = flags | MaskTimestamp 259 | err = WriteLonglong(writer, *props.Timestamp) 260 | if err != nil { 261 | return 262 | } 263 | } 264 | if props.Type != nil { 265 | flags = flags | MaskType 266 | err = WriteShortstr(writer, *props.Type) 267 | if err != nil { 268 | return 269 | } 270 | } 271 | if props.UserId != nil { 272 | flags = flags | MaskUserId 273 | err = WriteShortstr(writer, *props.UserId) 274 | if err != nil { 275 | return 276 | } 277 | } 278 | if props.AppId != nil { 279 | flags = flags | MaskAppId 280 | err = WriteShortstr(writer, *props.AppId) 281 | if err != nil { 282 | return 283 | } 284 | } 285 | if props.Reserved != nil { 286 | flags = flags | MaskReserved 287 | err = WriteShortstr(writer, *props.Reserved) 288 | if err != nil { 289 | return 290 | } 291 | } 292 | return 293 | } 294 | -------------------------------------------------------------------------------- /amqp/constants_generated.go: -------------------------------------------------------------------------------- 1 | package amqp 2 | 3 | var MaxShortStringLength uint8 = 255 4 | var FrameMethod = 1 5 | var FrameHeader = 2 6 | var FrameBody = 3 7 | var 
FrameHeartbeat = 8 8 | var FrameMinSize = 4096 9 | var FrameEnd = 206 10 | 11 | // Indicates that the method completed successfully. This reply code is 12 | // reserved for future use - the current protocol design does not use positive 13 | // confirmation and reply codes are sent only in case of an error. 14 | var ReplySuccess = 200 15 | 16 | // The client attempted to transfer content larger than the server could accept 17 | // at the present time. The client may retry at a later time. 18 | var ContentTooLarge = 311 19 | 20 | // When the exchange cannot deliver to a consumer when the immediate flag is 21 | // set. As a result of pending data on the queue or the absence of any 22 | // consumers of the queue. 23 | var NoConsumers = 313 24 | 25 | // An operator intervened to close the connection for some reason. The client 26 | // may retry at some later date. 27 | var ConnectionForced = 320 28 | 29 | // The client tried to work with an unknown virtual host. 30 | var InvalidPath = 402 31 | 32 | // The client attempted to work with a server entity to which it has no 33 | // access due to security settings. 34 | var AccessRefused = 403 35 | 36 | // The client attempted to work with a server entity that does not exist. 37 | var NotFound = 404 38 | 39 | // The client attempted to work with a server entity to which it has no 40 | // access because another client is working with it. 41 | var ResourceLocked = 405 42 | 43 | // The client requested a method that was not allowed because some precondition 44 | // failed. 45 | var PreconditionFailed = 406 46 | 47 | // The sender sent a malformed frame that the recipient could not decode. 48 | // This strongly implies a programming error in the sending peer. 49 | var FrameError = 501 50 | 51 | // The sender sent a frame that contained illegal values for one or more 52 | // fields. This strongly implies a programming error in the sending peer. 
53 | var SyntaxError = 502 54 | 55 | // The client sent an invalid sequence of frames, attempting to perform an 56 | // operation that was considered invalid by the server. This usually implies 57 | // a programming error in the client. 58 | var CommandInvalid = 503 59 | 60 | // The client attempted to work with a channel that had not been correctly 61 | // opened. This most likely indicates a fault in the client layer. 62 | var ChannelError = 504 63 | 64 | // The peer sent a frame that was not expected, usually in the context of 65 | // a content header and body. This strongly indicates a fault in the peer's 66 | // content processing. 67 | var UnexpectedFrame = 505 68 | 69 | // The server could not complete the method because it lacked sufficient 70 | // resources. This may be due to the client creating too many of some type 71 | // of entity. 72 | var ResourceError = 506 73 | 74 | // The client tried to work with some entity in a manner that is prohibited 75 | // by the server, due to security settings or by some other criteria. 76 | var NotAllowed = 530 77 | 78 | // The client tried to use functionality that is not implemented in the 79 | // server. 80 | var NotImplemented = 540 81 | 82 | // The server could not complete the method because of an internal error. 83 | // The server may require intervention by an operator in order to resume 84 | // normal operations. 
85 | var InternalError = 541 86 | -------------------------------------------------------------------------------- /amqp/domains_generated.go: -------------------------------------------------------------------------------- 1 | package amqp 2 | 3 | var ReadMethodId = ReadShort 4 | var WriteMethodId = WriteShort 5 | var ReadReplyText = ReadShortstr 6 | var WriteReplyText = WriteShortstr 7 | var ReadPeerProperties = ReadTable 8 | var WritePeerProperties = WriteTable 9 | var ReadClassId = ReadShort 10 | var WriteClassId = WriteShort 11 | var ReadMessageCount = ReadLong 12 | var WriteMessageCount = WriteLong 13 | var ReadReplyCode = ReadShort 14 | var WriteReplyCode = WriteShort 15 | var ReadQueueName = ReadShortstr 16 | var WriteQueueName = WriteShortstr 17 | var ReadConsumerTag = ReadShortstr 18 | var WriteConsumerTag = WriteShortstr 19 | var ReadPath = ReadShortstr 20 | var WritePath = WriteShortstr 21 | var ReadExchangeName = ReadShortstr 22 | var WriteExchangeName = WriteShortstr 23 | var ReadDeliveryTag = ReadLonglong 24 | var WriteDeliveryTag = WriteLonglong 25 | -------------------------------------------------------------------------------- /amqp/errors.go: -------------------------------------------------------------------------------- 1 | package amqp 2 | 3 | // Soft error (close channel) 4 | 5 | type AMQPError struct { 6 | Code uint16 7 | Class uint16 8 | Method uint16 9 | Msg string 10 | Soft bool 11 | } 12 | 13 | func NewSoftError(code uint16, msg string, class uint16, method uint16) *AMQPError { 14 | return &AMQPError{ 15 | Code: code, 16 | Class: class, 17 | Method: method, 18 | Msg: msg, 19 | Soft: true, 20 | } 21 | } 22 | 23 | // Hard error (close connection) 24 | 25 | func NewHardError(code uint16, msg string, class uint16, method uint16) *AMQPError { 26 | return &AMQPError{ 27 | Code: code, 28 | Class: class, 29 | Method: method, 30 | Msg: msg, 31 | Soft: false, 32 | } 33 | } 34 | 
-------------------------------------------------------------------------------- /amqp/messages.proto: -------------------------------------------------------------------------------- 1 | 2 | package amqp; 3 | 4 | import "github.com/gogo/protobuf/gogoproto/gogo.proto"; 5 | import "github.com/jeffjenkins/dispatchd/amqp/protocol_generated.proto"; 6 | 7 | option (gogoproto.marshaler_all) = true; 8 | option (gogoproto.unmarshaler_all) = true; 9 | option (gogoproto.sizer_all) = true; 10 | 11 | message WireFrame { 12 | option (gogoproto.goproto_unrecognized) = false; 13 | option (gogoproto.goproto_getters) = false; 14 | optional uint32 frameType = 1 [(gogoproto.nullable) = false, (gogoproto.casttype) = "uint8"]; 15 | optional uint32 channel = 2 [(gogoproto.nullable) = false, (gogoproto.casttype) = "uint16"]; 16 | optional bytes payload = 3 ; 17 | } 18 | 19 | message IndexMessage { 20 | // The ID of the underlying message 21 | optional int64 id = 1 [(gogoproto.nullable) = false]; 22 | // The number of outstanding references to the message 23 | optional int32 refs = 2 [(gogoproto.nullable) = false]; 24 | optional bool durable = 3 [(gogoproto.nullable) = false]; 25 | optional int32 deliveryCount = 4 [(gogoproto.nullable) = false]; 26 | optional bool persisted = 5 [(gogoproto.nullable) = false]; 27 | } 28 | 29 | message Message { 30 | optional int64 id = 1 [(gogoproto.nullable) = false]; 31 | optional amqp.ContentHeaderFrame header = 2; 32 | repeated WireFrame payload = 3; 33 | optional string exchange = 4 [(gogoproto.nullable) = false]; 34 | optional string key = 5 [(gogoproto.nullable) = false]; 35 | optional amqp.BasicPublish method = 6; 36 | optional uint32 redelivered = 7 [(gogoproto.nullable) = false]; 37 | optional int64 local_id = 8 [(gogoproto.nullable) = false]; 38 | } 39 | 40 | message QueueMessage { 41 | optional int64 id = 1 [(gogoproto.nullable) = false]; 42 | optional int32 deliveryCount = 2 [(gogoproto.nullable) = false]; 43 | optional bool durable = 3 
[(gogoproto.nullable) = false]; 44 | optional uint32 msgSize = 4 [(gogoproto.nullable) = false]; 45 | optional int64 localId = 5 [(gogoproto.nullable) = false]; 46 | } 47 | 48 | message ContentHeaderFrame { 49 | optional uint32 content_class = 1 [(gogoproto.casttype) = "uint16", (gogoproto.nullable) = false]; 50 | optional uint32 content_weight = 2 [(gogoproto.casttype) = "uint16", (gogoproto.nullable) = false]; 51 | optional uint64 content_body_size = 3 [(gogoproto.nullable) = false]; 52 | optional uint32 property_flags = 4 [(gogoproto.casttype) = "uint16", (gogoproto.nullable) = false]; 53 | optional amqp.BasicContentHeaderProperties properties = 5; 54 | } 55 | 56 | message TxMessage { 57 | optional Message msg = 1; 58 | optional string queue_name = 2 [(gogoproto.nullable) = false]; 59 | } 60 | 61 | message TxAck { 62 | optional uint64 tag = 1 [(gogoproto.nullable) = false]; 63 | optional bool multiple = 2 [(gogoproto.nullable) = false]; 64 | optional bool nack = 3 [(gogoproto.nullable) = false]; 65 | optional bool requeue_nack = 4 [(gogoproto.nullable) = false]; 66 | } 67 | 68 | message UnackedMessage { 69 | optional string consumer_tag = 1 [(gogoproto.nullable) = false]; 70 | optional QueueMessage msg = 2; 71 | optional string queue_name = 3 [(gogoproto.nullable) = false]; 72 | } 73 | -------------------------------------------------------------------------------- /amqp/readwrite_test.go: -------------------------------------------------------------------------------- 1 | package amqp 2 | 3 | import ( 4 | "bytes" 5 | "fmt" 6 | "math/rand" 7 | "reflect" 8 | "testing" 9 | ) 10 | 11 | var testRand *rand.Rand = nil 12 | 13 | func init() { 14 | var source = rand.NewSource(int64(1234)) 15 | testRand = rand.New(source) 16 | } 17 | 18 | func TestMalformedTable(t *testing.T) { 19 | var table = NewTable() 20 | table.SetKey("hi", "bye") 21 | table.SetKey("mi", "bye") 22 | var method = &ExchangeBind{ 23 | Destination: "dest", 24 | Source: "src", 25 | RoutingKey: "rk", 26 | 
NoWait: true, 27 | Arguments: table, 28 | } 29 | var outBuf = bytes.NewBuffer([]byte{}) 30 | err := method.Write(outBuf) 31 | if err != nil { 32 | t.Errorf(err.Error()) 33 | } 34 | outBytes := outBuf.Bytes()[4:] 35 | // Use this to see which bytes to alter 36 | printWireBytes(outBytes, t) 37 | 38 | // make key too long 39 | outBytes[19] = 250 40 | var inMethod = &ExchangeBind{} 41 | err = inMethod.Read(bytes.NewBuffer(outBytes), true) 42 | if err == nil { 43 | t.Errorf("Successfully read malformed bytes!") 44 | } 45 | outBytes[19] = 2 46 | 47 | // make value the wrong type 48 | outBytes[22] = 'S' 49 | inMethod = &ExchangeBind{} 50 | err = inMethod.Read(bytes.NewBuffer(outBytes), true) 51 | if err == nil { 52 | t.Errorf("Successfully read malformed bytes!") 53 | } 54 | outBytes[22] = 's' 55 | 56 | // duplicate keys 57 | outBytes[28] = 'h' 58 | inMethod = &ExchangeBind{} 59 | err = inMethod.Read(bytes.NewBuffer(outBytes), true) 60 | if err == nil { 61 | t.Errorf("Successfully read malformed bytes!") 62 | } 63 | outBytes[28] = 'm' 64 | 65 | // can't read value type 66 | inMethod = &ExchangeBind{} 67 | outBytes[27] = 7 68 | err = inMethod.Read(bytes.NewBuffer(outBytes), true) 69 | if err == nil { 70 | t.Errorf("Successfully read malformed bytes!") 71 | } 72 | outBytes[27] = 2 73 | 74 | // t.Errorf("fail") 75 | } 76 | 77 | func TestReadValue(t *testing.T) { 78 | tryRead := func(bts []byte, msg string) { 79 | _, err := readValue(bytes.NewBuffer(bts), true) 80 | if err == nil { 81 | t.Errorf(msg) 82 | } 83 | } 84 | var types = []byte{'t', 'b', 'B', 'U', 'u', 'I', 'i', 'L', 'l', 85 | 'f', 'd', 'D', 's', 'S', 'A', 'T', 'F'} 86 | // 'V' isn't included since it has no value 87 | for i := 0; i < len(types); i++ { 88 | tryRead([]byte{types[i]}, fmt.Sprintf("Successfully read malformed value for type %c", types[i])) 89 | } 90 | tryRead([]byte{'D', 1}, "read malformed decimal") 91 | 92 | tryRead([]byte{'~', 1}, "read bad value type") 93 | 94 | // successful timestamp read since 
the server doesn't really support them 95 | val, err := readValue(bytes.NewBuffer([]byte{'T', 0, 0, 0, 0, 0, 0, 0, 2}), true) 96 | if err != nil { 97 | t.Errorf(err.Error()) 98 | } 99 | if val.GetVTimestamp() != uint64(2) { 100 | t.Errorf("Failed to deserialize uint64") 101 | } 102 | // successful 'V' (no value) read, mainly for coverage 103 | val, err = readValue(bytes.NewBuffer([]byte{'V'}), true) 104 | if err != nil { 105 | t.Errorf(err.Error()) 106 | } 107 | if val != nil { 108 | t.Errorf("Failed to deserialize 'V' field ") 109 | } 110 | // tryRead() 111 | 112 | } 113 | 114 | func sptr(s string) *string { 115 | return &s 116 | } 117 | 118 | func bptr(b byte) *byte { 119 | return &b 120 | } 121 | 122 | // func allFlags() uint16 { 123 | // return (MaskContentType | MaskContentEncoding | MaskHeaders | 124 | // MaskDeliveryMode | MaskPriority | MaskCorrelationId | MaskReplyTo | 125 | // MaskExpiration | MaskMessageId | MaskTimestamp | MaskType | MaskUserId | 126 | // MaskAppId | MaskReserved) 127 | // } 128 | 129 | func TestReadingContentHeaderProps(t *testing.T) { 130 | time := uint64(1312312) 131 | var props = BasicContentHeaderProperties{ 132 | ContentType: sptr("ContentType"), 133 | ContentEncoding: sptr("ContentEncoding"), 134 | Headers: NewTable(), 135 | DeliveryMode: bptr(byte(1)), 136 | Priority: bptr(byte(1)), 137 | CorrelationId: sptr("CorrelationId"), 138 | ReplyTo: sptr("ReplyTo"), 139 | Expiration: sptr("Expiration"), 140 | MessageId: sptr("MessageId"), 141 | Timestamp: &time, 142 | Type: sptr("Type"), 143 | UserId: sptr("UserId"), 144 | AppId: sptr("AppId"), 145 | Reserved: sptr(""), 146 | } 147 | var outBuf = bytes.NewBuffer([]byte{}) 148 | flags, err := props.WriteProps(outBuf) 149 | if err != nil { 150 | t.Errorf(err.Error()) 151 | } 152 | outBytes := outBuf.Bytes() 153 | // Use subsets of bytes to trigger all failure conditions 154 | for i := 0; i < len(outBytes); i++ { 155 | var partialBuffer = bytes.NewBuffer(outBytes[:i]) 156 | var inProps = 
&BasicContentHeaderProperties{} 157 | err = inProps.ReadProps(flags, partialBuffer, true) 158 | if err == nil { 159 | t.Errorf("Successfully read malformed props. %d/%d bytes read", i, len(outBytes)) 160 | } 161 | } 162 | // Succeed in reading all bytes 163 | var partialBuffer = bytes.NewBuffer(outBytes) 164 | var inProps = &BasicContentHeaderProperties{} 165 | err = inProps.ReadProps(flags, partialBuffer, true) 166 | if err != nil { 167 | t.Errorf(err.Error()) 168 | } 169 | // check a few random fields 170 | if *inProps.ContentType != "ContentType" { 171 | t.Errorf("Bad content type: %s", *inProps.ContentType) 172 | } 173 | if *inProps.Timestamp != time { 174 | t.Errorf("Bad timestamp: %s", *inProps.Timestamp) 175 | } 176 | } 177 | 178 | func TestReadArrayFailures(t *testing.T) { 179 | _, err := readArray(bytes.NewBuffer([]byte{0, 0, 0, 1, 'L'}), true) 180 | if err == nil { 181 | t.Errorf("Read a malformed array value") 182 | } 183 | } 184 | 185 | func TestWireFrame(t *testing.T) { 186 | // Write frame to bytes 187 | var outFrame = &WireFrame{ 188 | FrameType: uint8(10), 189 | Channel: uint16(12311), 190 | Payload: []byte{0, 0, 9, 1}, 191 | } 192 | var buf = bytes.NewBuffer(make([]byte, 0)) 193 | WriteFrame(buf, outFrame) 194 | 195 | // Read frame from bytes 196 | var outBytes = buf.Bytes() 197 | var inFrame, err = ReadFrame(bytes.NewBuffer(outBytes)) 198 | if err != nil { 199 | t.Errorf(err.Error()) 200 | } 201 | if !reflect.DeepEqual(inFrame, outFrame) { 202 | t.Errorf("Couldn't read the frame that was written") 203 | } 204 | // Incomplete frames 205 | for i := 0; i < len(outBytes); i++ { 206 | var noTypeBuf = bytes.NewBuffer(outBytes[:i]) 207 | _, err = ReadFrame(noTypeBuf) 208 | if err == nil { 209 | t.Errorf("No error on malformed frame. 
%d/%d bytes read", i, len(buf.Bytes())) 210 | } 211 | } 212 | } 213 | 214 | func TestMethodTypes(t *testing.T) { 215 | for _, method := range methodsForTesting() { 216 | var outBuf = bytes.NewBuffer([]byte{}) 217 | err := method.Write(outBuf) 218 | if err != nil { 219 | t.Errorf(err.Error()) 220 | } 221 | 222 | var outBytes = outBuf.Bytes()[4:] 223 | // Try all lengths of bytes below the ones needed 224 | for index, _ := range outBytes { 225 | var inBind = reflect.New(reflect.TypeOf(method).Elem()).Interface().(MethodFrame) 226 | err = inBind.Read(bytes.NewBuffer(outBytes[:index]), true) 227 | if err == nil { 228 | printWireBytes(outBytes[:index], t) 229 | t.Errorf("Parsed malformed request bytes") 230 | return 231 | } 232 | } 233 | // printWireBytes(outBytes, t) 234 | // Try the right set of bytes 235 | var inBind = reflect.New(reflect.TypeOf(method).Elem()).Interface().(MethodFrame) 236 | err = inBind.Read(bytes.NewBuffer(outBytes), true) 237 | if err != nil { 238 | t.Logf("Method is %s", method.MethodName()) 239 | printWireBytes(outBytes, t) 240 | t.Errorf(err.Error()) 241 | return 242 | } 243 | } 244 | } 245 | 246 | func methodsForTesting() []MethodFrame { 247 | return []MethodFrame{ 248 | &ExchangeBind{ 249 | Destination: "dest", 250 | Source: "src", 251 | RoutingKey: "rk", 252 | NoWait: true, 253 | Arguments: EverythingTable(), 254 | }, 255 | &ConnectionTune{ 256 | ChannelMax: uint16(3), 257 | FrameMax: uint32(1), 258 | Heartbeat: uint16(2), 259 | }, 260 | &BasicDeliver{ 261 | ConsumerTag: string("deliver"), 262 | DeliveryTag: uint64(4), 263 | Redelivered: true, 264 | Exchange: string("ex1"), 265 | RoutingKey: string("rk1"), 266 | }, 267 | } 268 | } 269 | 270 | func printWireBytes(bs []byte, t *testing.T) { 271 | t.Logf("Byte count: %d", len(bs)) 272 | for i, b := range bs { 273 | t.Logf("%d:(%c %d),", i, b, b) 274 | } 275 | t.Logf("\n") 276 | } 277 | -------------------------------------------------------------------------------- /amqp/table.go: 
-------------------------------------------------------------------------------- 1 | package amqp 2 | 3 | import ( 4 | "fmt" 5 | "reflect" 6 | ) 7 | 8 | func NewTable() *Table { 9 | return &Table{Table: make([]*FieldValuePair, 0)} 10 | } 11 | 12 | func EquivalentTables(t1 *Table, t2 *Table) bool { 13 | if len(t1.Table) == 0 && len(t2.Table) == 0 { 14 | return true 15 | } 16 | return reflect.DeepEqual(t1, t2) 17 | } 18 | 19 | func (table *Table) GetKey(key string) *FieldValue { 20 | for _, kv := range table.Table { 21 | var k = *kv.Key 22 | if k == key { 23 | return kv.Value 24 | } 25 | } 26 | return nil 27 | } 28 | 29 | func (table *Table) SetKey(key string, value interface{}) error { 30 | var fieldValue *FieldValue = nil 31 | for _, kv := range table.Table { 32 | var k = *kv.Key 33 | if k == key { 34 | fieldValue = kv.Value 35 | } 36 | } 37 | if fieldValue == nil { 38 | fieldValue = &FieldValue{ 39 | Value: nil, 40 | } 41 | table.Table = append(table.Table, &FieldValuePair{ 42 | Key: &key, 43 | Value: fieldValue, 44 | }) 45 | } 46 | fv, err := calcValue(value) 47 | 48 | if err != nil { 49 | return err 50 | } 51 | fieldValue.Value = fv 52 | return nil 53 | } 54 | 55 | func calcValue(value interface{}) (isFieldValue_Value, error) { 56 | switch v := value.(type) { 57 | case bool: 58 | return &FieldValue_VBoolean{VBoolean: v}, nil 59 | case int8: 60 | return &FieldValue_VInt8{VInt8: v}, nil 61 | case uint8: 62 | return &FieldValue_VUint8{VUint8: v}, nil 63 | case int16: 64 | return &FieldValue_VInt16{VInt16: v}, nil 65 | case uint16: 66 | return &FieldValue_VUint16{VUint16: v}, nil 67 | case int32: 68 | return &FieldValue_VInt32{VInt32: v}, nil 69 | case uint32: 70 | return &FieldValue_VUint32{VUint32: v}, nil 71 | case int64: 72 | return &FieldValue_VInt64{VInt64: v}, nil 73 | case uint64: 74 | return &FieldValue_VUint64{VUint64: v}, nil 75 | case float32: 76 | return &FieldValue_VFloat{VFloat: v}, nil 77 | case float64: 78 | return &FieldValue_VDouble{VDouble: v}, 
nil 79 | case *Decimal: 80 | return &FieldValue_VDecimal{VDecimal: v}, nil 81 | case string: 82 | return &FieldValue_VShortstr{VShortstr: v}, nil 83 | case []byte: 84 | return &FieldValue_VLongstr{VLongstr: v}, nil 85 | case *FieldArray: 86 | return &FieldValue_VArray{VArray: v}, nil 87 | // TODO: not currently reachable since uint64 will take it 88 | // case timestamp: 89 | // return &FieldValue_VTimestamp{VTimestamp: v} 90 | case *Table: 91 | return &FieldValue_VTable{VTable: v}, nil 92 | } 93 | return nil, fmt.Errorf("Field value has invalid type for table/array: %s", value) 94 | } 95 | 96 | func NewFieldArray() *FieldArray { 97 | return &FieldArray{Value: make([]*FieldValue, 0)} 98 | } 99 | 100 | func (fa *FieldArray) AppendFA(value interface{}) error { 101 | var fv, err = calcValue(value) 102 | if err != nil { 103 | return err 104 | } 105 | fa.Value = append(fa.Value, &FieldValue{Value: fv}) 106 | return nil 107 | } 108 | -------------------------------------------------------------------------------- /amqp/table_test.go: -------------------------------------------------------------------------------- 1 | package amqp 2 | 3 | import ( 4 | "bytes" 5 | "github.com/gogo/protobuf/proto" 6 | "math/rand" 7 | "reflect" 8 | "testing" 9 | ) 10 | 11 | func TestPersistRoundtripEmpty(t *testing.T) { 12 | // Protobuf returns a nil Table.Table if there are no entries. 13 | // This test makes sure equality with a NewTable() table works 14 | var table = NewTable() 15 | var bb, _ = proto.Marshal(table) 16 | var table2 = &Table{} 17 | proto.Unmarshal(bb, table2) 18 | if !EquivalentTables(table, table2) { 19 | t.Errorf("%#v, %#v\n", table, table2) 20 | } 21 | } 22 | 23 | func TestNilTableSet(t *testing.T) { 24 | // Protobuf returns a nil Table.Table if there are no entries. 
25 | // This test makes sure setting keys on a nil table works 26 | table := Table{Table: nil} 27 | err := table.SetKey("a", uint32(1)) 28 | if err != nil { 29 | t.Errorf(err.Error()) 30 | } 31 | val := table.GetKey("a") 32 | if val == nil { 33 | t.Errorf("No value returned for key 'a'") 34 | } 35 | } 36 | 37 | func TestBasicFieldArray(t *testing.T) { 38 | // TODO: test this more thoroughly. it gets some working out in other tests 39 | // but this could do more. 40 | var fa = NewFieldArray() 41 | err := fa.AppendFA(make(map[bool]bool)) 42 | if err == nil { 43 | t.Errorf("No error with bad append value") 44 | } 45 | } 46 | 47 | func TestBasicTable(t *testing.T) { 48 | // Create 49 | var table = NewTable() 50 | table.SetKey("product", "dispatchd") 51 | table.SetKey("version", uint8(7)) 52 | table.SetKey("version", uint8(6)) // for code coverage, reset a value 53 | err := table.SetKey("bad", make(map[bool]bool)) // for code coverage, a type it doesn't understand 54 | if err == nil { 55 | t.Errorf("No error on bad set value") 56 | } 57 | 58 | var fv = table.GetKey("version") 59 | if fv.Value.(*FieldValue_VUint8).VUint8 != uint8(6) { 60 | t.Errorf("Didn't get the right key from table") 61 | } 62 | if table.GetKey("DOES NOT EXIST") != nil { 63 | t.Errorf("Found key that shouldn't exist!") 64 | } 65 | 66 | } 67 | 68 | func TestTableTypes(t *testing.T) { 69 | var inTable = EverythingTable() 70 | 71 | // Encode 72 | writer := bytes.NewBuffer(make([]byte, 0)) 73 | err := WriteTable(writer, inTable) 74 | if err != nil { 75 | t.Errorf(err.Error()) 76 | } 77 | 78 | // decode 79 | var reader = bytes.NewReader(writer.Bytes()) 80 | outTable, err := ReadTable(reader, true) 81 | if err != nil { 82 | t.Errorf(err.Error()) 83 | } 84 | 85 | // compare 86 | if !EquivalentTables(inTable, outTable) { 87 | t.Errorf("Tables no equal") 88 | } 89 | } 90 | 91 | func (table *Table) Generator(rand *rand.Rand, size int) reflect.Value { 92 | return reflect.ValueOf(EverythingTable()) 93 | } 94 | 95 
| func EverythingTable() *Table { 96 | var inTable = NewTable() 97 | 98 | // Basic types 99 | inTable.SetKey("bool", true) 100 | inTable.SetKey("int8", int8(-2)) 101 | inTable.SetKey("uint8", uint8(3)) 102 | inTable.SetKey("int16", int16(-4)) 103 | inTable.SetKey("uint16", uint16(5)) 104 | inTable.SetKey("int32", int32(-6)) 105 | inTable.SetKey("uint32", uint32(7)) 106 | inTable.SetKey("int64", int64(-8)) 107 | inTable.SetKey("uint64", uint64(9)) 108 | inTable.SetKey("float32", float32(10.1)) 109 | inTable.SetKey("float64", float64(-11.2)) 110 | inTable.SetKey("string", "string value") 111 | inTable.SetKey("[]byte", []byte{14, 15, 16, 17}) 112 | // TODO: timestamp 113 | // Decimal 114 | var scale = uint8(12) 115 | var value = int32(-13) 116 | inTable.SetKey("*Decimal", &Decimal{&scale, &value}) 117 | // Field Array 118 | var fa = NewFieldArray() 119 | fa.AppendFA(int8(101)) 120 | inTable.SetKey("*FieldArray", fa) 121 | // Table 122 | var innerTable = NewTable() 123 | innerTable.SetKey("some key", "some value") 124 | inTable.SetKey("*Table", innerTable) 125 | return inTable 126 | } 127 | -------------------------------------------------------------------------------- /amqp/testlib.go: -------------------------------------------------------------------------------- 1 | package amqp 2 | 3 | import ( 4 | // "fmt" 5 | "math/rand" 6 | "sync/atomic" 7 | "time" 8 | ) 9 | 10 | var counter int64 11 | 12 | func init() { 13 | rand.Seed(time.Now().UnixNano()) 14 | counter = 0 15 | } 16 | 17 | func nextId() int64 { 18 | return atomic.AddInt64(&counter, 1) 19 | } 20 | 21 | func RandomMessage(persistent bool) *Message { 22 | var size, payload = RandomPayload() 23 | var randomHeader = RandomHeader(size, persistent) 24 | var exchange = "exchange-name" 25 | var routingKey = "routing-key" 26 | var randomPublish = &BasicPublish{ 27 | Exchange: exchange, 28 | RoutingKey: routingKey, 29 | Mandatory: false, 30 | Immediate: false, 31 | } 32 | return &Message{ 33 | Id: nextId(), 34 | 
Header: randomHeader, 35 | Payload: payload, 36 | Exchange: exchange, 37 | Key: routingKey, 38 | Method: randomPublish, 39 | Redelivered: 0, 40 | LocalId: nextId(), 41 | } 42 | } 43 | 44 | func RandomPayload() (uint64, []*WireFrame) { 45 | var size uint64 = 500 46 | var payload = make([]byte, size) 47 | return size, []*WireFrame{ 48 | &WireFrame{ 49 | FrameType: 3, 50 | Channel: 1, 51 | Payload: payload, 52 | }, 53 | } 54 | } 55 | 56 | func RandomHeader(size uint64, persistent bool) *ContentHeaderFrame { 57 | var props = BasicContentHeaderProperties{} 58 | if persistent { 59 | var b = byte(2) 60 | props.DeliveryMode = &b 61 | } else { 62 | var b = byte(1) 63 | props.DeliveryMode = &b 64 | } 65 | return &ContentHeaderFrame{ 66 | ContentClass: 60, 67 | ContentWeight: 0, 68 | ContentBodySize: size, 69 | PropertyFlags: 0, 70 | Properties: &props, 71 | } 72 | } 73 | -------------------------------------------------------------------------------- /amqp/tx.go: -------------------------------------------------------------------------------- 1 | package amqp 2 | 3 | func NewTxMessage(msg *Message, queueName string) *TxMessage { 4 | return &TxMessage{ 5 | QueueName: queueName, 6 | Msg: msg, 7 | } 8 | } 9 | 10 | func NewTxAck(tag uint64, nack bool, requeueNack bool, multiple bool) *TxAck { 11 | return &TxAck{ 12 | Tag: tag, 13 | Nack: nack, 14 | RequeueNack: requeueNack, 15 | Multiple: multiple, 16 | } 17 | } 18 | -------------------------------------------------------------------------------- /amqp/tx_test.go: -------------------------------------------------------------------------------- 1 | package amqp 2 | 3 | import ( 4 | "reflect" 5 | "testing" 6 | ) 7 | 8 | func TestNewTxMessage(t *testing.T) { 9 | var msg = RandomMessage(false) 10 | var txMessage = NewTxMessage(msg, "hithere") 11 | if !reflect.DeepEqual(txMessage.GetMsg(), msg) { 12 | t.Errorf("Messages not equal") 13 | } 14 | if txMessage.GetQueueName() != "hithere" { 15 | t.Errorf("bad queue name") 16 | } 17 | } 18 
| 19 | func TestNewTxAck(t *testing.T) { 20 | var tag uint64 = 12345 //, nack bool, requeueNack bool, multiple bool 21 | var txack = NewTxAck(tag, false, false, true) 22 | if txack.GetTag() != 12345 { 23 | t.Errorf("tag mismatch") 24 | } 25 | if !txack.Multiple { 26 | t.Errorf("bad multiple flag") 27 | } 28 | txack = NewTxAck(tag, false, true, false) 29 | if !txack.RequeueNack { 30 | t.Errorf("bad multiple requeue") 31 | } 32 | txack = NewTxAck(tag, true, false, true) 33 | if !txack.Nack { 34 | t.Errorf("bad multiple nack") 35 | } 36 | } 37 | -------------------------------------------------------------------------------- /amqp/types.go: -------------------------------------------------------------------------------- 1 | package amqp 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "github.com/jeffjenkins/dispatchd/util" 7 | "io" 8 | "regexp" 9 | ) 10 | 11 | type Frame interface { 12 | FrameType() byte 13 | } 14 | 15 | type MethodFrame interface { 16 | MethodName() string 17 | MethodIdentifier() (uint16, uint16) 18 | Read(reader io.Reader, strictMode bool) (err error) 19 | Write(writer io.Writer) (err error) 20 | FrameType() byte 21 | } 22 | 23 | // A message resource is something which has limits on the count of 24 | // messages it can handle as well as the cumulative size of the messages 25 | // it can handle. 
// exchangeNameRegex accepts names made only of ASCII letters, digits and the
// characters '-', '_', '.' and ':'. The empty string is accepted as well.
var exchangeNameRegex = regexp.MustCompile(`^[a-zA-Z0-9-_.:]*$`)

// CheckExchangeOrQueueName validates the name of an exchange or a queue. It
// returns nil when the name is at most 127 bytes long and consists solely of
// the characters allowed by exchangeNameRegex, and a descriptive error
// otherwise. The length is checked first, so an over-long name is reported as
// too long even if it also contains invalid characters.
func CheckExchangeOrQueueName(s string) error {
	// NOTE(review): the original author wondered whether this limit is
	// ignored in practice because a short string can be about twice as
	// long. Confirm before relying on it.
	switch n := len(s); {
	case n > 127:
		return fmt.Errorf("Exchange name too long: %d", n)
	case !exchangeNameRegex.MatchString(s):
		return fmt.Errorf("Exchange name invalid: %s", s)
	default:
		return nil
	}
}
41 | type Amqp struct { 42 | Constants []*Constant `xml:"constant"` 43 | Domains []*Domain `xml:"domain"` 44 | Classes []*Class `xml:"class"` 45 | } 46 | 47 | type Domain struct { 48 | Name string `xml:"name,attr"` 49 | Type string `xml:"type,attr"` 50 | } 51 | 52 | type Class struct { 53 | Methods []*Method `xml:"method"` 54 | Name string `xml:"name,attr"` 55 | NormName string 56 | Handler string `xml:"handler,attr"` 57 | Index string `xml:"index,attr"` 58 | Fields []*Field `xml:"field"` 59 | } 60 | 61 | type Constant struct { 62 | Name string `xml:"name,attr"` 63 | NormName string 64 | Value uint16 `xml:"value,attr"` 65 | Class string `xml:"class,attr"` 66 | } 67 | 68 | type Method struct { 69 | Name string `xml:"name,attr"` 70 | NormName string 71 | Synchronous string `xml:"synchronous,attr"` 72 | Index string `xml:"index,attr"` 73 | Fields []*Field `xml:"field"` 74 | BitsAtEnd bool 75 | } 76 | 77 | type Field struct { 78 | Name string `xml:"name,attr"` 79 | Domain string `xml:"domain,attr"` 80 | AmqpType string `xml:"type,attr"` 81 | NormName string 82 | ProtoType string 83 | ProtoIndex int 84 | ProtoName string 85 | Options string 86 | MaskIndex string 87 | Serializer string 88 | ReadArgs string 89 | GoType string 90 | BitOffset int 91 | PreviousBit bool 92 | } 93 | 94 | var specFile string 95 | 96 | func transform(r *Root) { 97 | transformConstants(r.Amqp.Constants) 98 | domainTypes := transformDomains(r.Amqp.Domains) 99 | transformClasses(r.Amqp.Classes, domainTypes) 100 | } 101 | 102 | func transformConstants(cs []*Constant) { 103 | for _, c := range cs { 104 | c.NormName = normalizeName(c.Name) 105 | } 106 | } 107 | 108 | func transformClasses(cs []*Class, domainTypes map[string]string) { 109 | for _, c := range cs { 110 | c.NormName = normalizeName(c.Name) 111 | transformFields(c.Fields, domainTypes, true, 1) 112 | transformMethods(c.NormName, c.Methods, domainTypes) 113 | } 114 | } 115 | 116 | func transformDomains(ds []*Domain) map[string]string { 117 | 
domainTypes := make(map[string]string) 118 | for _, d := range ds { 119 | domainTypes[d.Name] = d.Type 120 | } 121 | return domainTypes 122 | } 123 | 124 | func transformMethods(className string, ms []*Method, domainTypes map[string]string) { 125 | for _, m := range ms { 126 | m.NormName = className + normalizeName(m.Name) 127 | m.BitsAtEnd = transformFields(m.Fields, domainTypes, false, 1) 128 | } 129 | } 130 | 131 | func transformFields(fs []*Field, domainTypes map[string]string, nullable bool, offset int) bool { 132 | var bits = 0 133 | var bitsAtEnd = false 134 | for index, f := range fs { 135 | var ok bool 136 | domain := f.Domain 137 | if f.AmqpType != "" { 138 | domain = f.AmqpType 139 | } else { 140 | f.AmqpType, ok = domainTypes[domain] 141 | if !ok { 142 | panic("") 143 | } 144 | } 145 | f.ProtoType, ok = amqpToProto[f.AmqpType] 146 | if !ok { 147 | panic("") 148 | } 149 | f.GoType, ok = amqpToGo[f.AmqpType] 150 | if !ok { 151 | panic("") 152 | } 153 | 154 | f.NormName = normalizeName(f.Name) 155 | f.ProtoName = protoName(f.Name) 156 | f.ProtoIndex = index + offset 157 | f.MaskIndex = fmt.Sprintf("%04x", (0 | 1<<(uint(15)-uint(index)))) 158 | f.Serializer = normalizeName(domain) 159 | if strings.Contains(f.Serializer, "Properties") || strings.Contains(f.Serializer, "Table") { 160 | f.ReadArgs = ", strictMode" 161 | } 162 | if f.AmqpType == "bit" { 163 | f.BitOffset = bits 164 | bits += 1 165 | bitsAtEnd = true 166 | } else { 167 | f.PreviousBit = bits > 0 168 | f.BitOffset = -1 169 | bits = 0 170 | bitsAtEnd = false 171 | } 172 | f.Options = fieldOptions(f, domainTypes, nullable) 173 | } 174 | return bitsAtEnd 175 | } 176 | 177 | func fieldOptions(f *Field, domainTypes map[string]string, nullable bool) string { 178 | var options = make([]string, 0) 179 | if f.GoType != f.ProtoType && f.ProtoType != "bytes" { 180 | options = append(options, fmt.Sprintf("(gogoproto.casttype) = \"%s\"", f.GoType)) 181 | } 182 | if f.ProtoType != "Table" && f.ProtoType != 
"bytes" && !nullable { 183 | options = append(options, "(gogoproto.nullable) = false") 184 | } 185 | if len(options) > 0 { 186 | return " [" + strings.Join(options, ", ") + "]" 187 | } 188 | return "" 189 | 190 | } 191 | 192 | func normalizeName(s string) string { 193 | parts := strings.Split(s, "-") 194 | ret := "" 195 | for _, p := range parts { 196 | ret += upperFirst(p) 197 | } 198 | return ret 199 | } 200 | func protoName(s string) string { 201 | if strings.Contains(s, "reserved") { 202 | return strings.Join(strings.Split(s, "-"), "") 203 | } 204 | return strings.Join(strings.Split(s, "-"), "_") 205 | } 206 | 207 | func upperFirst(s string) string { 208 | if s == "" { 209 | return "" 210 | } 211 | 212 | return string(bytes.ToUpper([]byte(s[0:1]))) + s[1:] 213 | } 214 | 215 | func main() { 216 | // fmt.Println(protoTemplate) 217 | flag.StringVar(&specFile, "spec", "", "Spec XML file") 218 | flag.Parse() 219 | var bytes, err = ioutil.ReadFile(specFile) 220 | if err != nil { 221 | panic(err) 222 | } 223 | var root Root 224 | err = xml.Unmarshal(bytes, &root.Amqp) 225 | if err != nil { 226 | panic(err.Error()) 227 | } 228 | transform(&root) 229 | // Proto 230 | f, err := os.Create("amqp/protocol_generated.proto") 231 | if err != nil { 232 | panic(err.Error()) 233 | } 234 | protoTemplate.Execute(f, &root) 235 | 236 | // Readers/Writers 237 | f, err = os.Create("amqp/protocol_protobuf_readwrite_generated.go") 238 | if err != nil { 239 | panic(err.Error()) 240 | } 241 | readWriteTemplate.Execute(f, &root) 242 | } 243 | -------------------------------------------------------------------------------- /amqpgen/templates.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "text/template" 5 | ) 6 | 7 | var protoTemplate *template.Template 8 | var readWriteTemplate *template.Template 9 | 10 | func init() { 11 | // 12 | // PROTOCOL BUFFER TEMPLATE 13 | // 14 | t, err := 
template.New("proto").Parse(`package amqp; 15 | 16 | import "github.com/gogo/protobuf/gogoproto/gogo.proto"; 17 | import "github.com/jeffjenkins/dispatchd/amqp/amqp.proto"; 18 | 19 | option (gogoproto.marshaler_all) = true; 20 | option (gogoproto.sizer_all) = true; 21 | option (gogoproto.unmarshaler_all) = true; 22 | {{range .Amqp.Classes}} 23 | {{if .Fields}} 24 | message {{.NormName}}ContentHeaderProperties { 25 | {{range $index, $field := .Fields}} 26 | optional {{.ProtoType}} {{.ProtoName}} = {{.ProtoIndex}}{{.Options}};{{end}} 27 | } 28 | {{end}} 29 | {{range .Methods}} 30 | message {{.NormName}} { 31 | option (gogoproto.goproto_unrecognized) = false; 32 | option (gogoproto.goproto_getters) = false; 33 | {{range $index, $field := .Fields}} 34 | optional {{.ProtoType}} {{.ProtoName}} = {{.ProtoIndex}}{{.Options}};{{end}} 35 | } 36 | {{end}} 37 | {{end}} 38 | 39 | 40 | `) 41 | if err != nil { 42 | panic(err.Error()) 43 | } 44 | protoTemplate = t 45 | } 46 | func init() { 47 | // 48 | // READ/WRITE TEMPLATE 49 | // 50 | t, err := template.New("proto").Parse(` 51 | package amqp 52 | 53 | import ( 54 | "io" 55 | "errors" 56 | "fmt" 57 | ) 58 | 59 | {{range $class := .Amqp.Classes}} 60 | var ClassId{{.NormName}} uint16 = {{.Index}} 61 | {{if .Fields}}{{range $index, $field := .Fields}} 62 | var Mask{{.NormName}} uint16 = 0x{{.MaskIndex}};{{end}} 63 | {{end}} 64 | {{range .Methods}} 65 | // ************************ 66 | // {{.NormName}} 67 | // ************************ 68 | var MethodId{{.NormName}} uint16 = {{.Index}}; 69 | 70 | func (f *{{.NormName}}) MethodIdentifier() (uint16, uint16) { 71 | return {{$class.Index}}, {{.Index}} 72 | } 73 | 74 | func (f *{{.NormName}}) MethodName() string { 75 | return "{{.NormName}}" 76 | } 77 | 78 | func (f *{{.NormName}}) FrameType() byte { 79 | return 1 80 | } 81 | 82 | // Reader 83 | func (f *{{.NormName}}) Read(reader io.Reader, strictMode bool) (err error) { 84 | {{range $index, $field := .Fields}} 85 | {{if eq .BitOffset 
-1}} 86 | f.{{.NormName}}, err = Read{{.Serializer}}(reader{{.ReadArgs}}) 87 | if err != nil { 88 | return errors.New("Error reading field {{.NormName}}: " + err.Error()) 89 | } 90 | {{else}} 91 | {{if eq .BitOffset 0}} 92 | bits, err := ReadOctet(reader) 93 | if err != nil { 94 | return errors.New("Error reading bit fields" + err.Error()) 95 | } 96 | {{end}} 97 | f.{{.NormName}} = (bits & (1 << {{.BitOffset}}) > 0) 98 | {{end}} 99 | {{end}}{{/* end fields */}} 100 | return 101 | } 102 | 103 | // Writer 104 | func (f *{{.NormName}}) Write(writer io.Writer) (err error) { 105 | if err = WriteShort(writer, {{$class.Index}}); err != nil { 106 | return err 107 | } 108 | if err = WriteShort(writer, {{.Index}}); err != nil { 109 | return err 110 | } 111 | 112 | {{range $index, $field := .Fields}} 113 | {{if eq .GoType "bool"}} 114 | {{if eq .BitOffset 0}} 115 | var bits byte 116 | {{end}} 117 | if f.{{.NormName}} { 118 | bits |= 1 << {{.BitOffset}} 119 | } 120 | {{else}}{{/* else go type is not bool */}} 121 | {{if .PreviousBit}} 122 | err = WriteOctet(writer, bits) 123 | if err != nil { 124 | return errors.New("Error writing bit fields") 125 | } 126 | {{end}} 127 | err = Write{{.Serializer}}(writer, f.{{.NormName}}) 128 | if err != nil { 129 | return errors.New("Error writing field {{.NormName}}") 130 | } 131 | {{end}}{{/* end go type bool */}} 132 | {{end}}{{/* end fields */}} 133 | {{if .BitsAtEnd}} 134 | err = WriteOctet(writer, bits) 135 | if err != nil { 136 | return errors.New("Error writing bit fields") 137 | } 138 | {{end}} 139 | return 140 | } 141 | 142 | {{end}}{{/* end methods */}} 143 | {{end}}{{/* end amqp.classes */}} 144 | 145 | // ******************************** 146 | // METHOD READER 147 | // ******************************** 148 | 149 | func ReadMethod(reader io.Reader, strictMode bool) (MethodFrame, error) { 150 | classIndex, err := ReadShort(reader) 151 | if err != nil { 152 | return nil, err 153 | } 154 | methodIndex, err := ReadShort(reader) 155 | 
// topicRoutingPatternPattern validates a topic binding key: either the empty
// string, or dot-separated tokens where each token is a run of word
// characters, a single "*", or a single "#".
//
// MustCompile is used so that a mistake in the pattern fails loudly at
// package initialization instead of silently leaving a nil *regexp.Regexp
// behind that would panic on first use.
var topicRoutingPatternPattern = regexp.MustCompile(`^((\w+|\*|#)(\.(\w+|\*|#))*|)$`)
binding.QueueName == other.QueueName && 46 | binding.ExchangeName == other.ExchangeName && 47 | binding.Key == other.Key 48 | } 49 | 50 | func (binding *Binding) Depersist(db *bolt.DB) error { 51 | return persist.DepersistOne(db, BINDINGS_BUCKET_NAME, string(binding.Id)) 52 | } 53 | 54 | func (binding *Binding) DepersistBoltTx(tx *bolt.Tx) error { 55 | bucket, err := tx.CreateBucketIfNotExists(BINDINGS_BUCKET_NAME) 56 | if err != nil { // pragma: nocover 57 | // If we're hitting this it means the disk is full, the db is readonly, 58 | // or something else has gone irrecoverably wrong 59 | panic(fmt.Sprintf("create bucket: %s", err)) 60 | } 61 | return persist.DepersistOneBoltTx(bucket, string(binding.Id)) 62 | } 63 | 64 | func NewBinding(queueName string, exchangeName string, key string, arguments *amqp.Table, topic bool) (*Binding, error) { 65 | var re *regexp.Regexp = nil 66 | // Topic routing key 67 | if topic { 68 | if !topicRoutingPatternPattern.MatchString(key) { 69 | return nil, fmt.Errorf("Topic exchange routing key can only have a-zA-Z0-9, or # or *") 70 | } 71 | var parts = strings.Split(key, ".") 72 | for i, part := range parts { 73 | if part == "*" { 74 | parts[i] = `[^\.]+` 75 | } else if part == "#" { 76 | parts[i] = ".*" 77 | } else { 78 | parts[i] = regexp.QuoteMeta(parts[i]) 79 | } 80 | } 81 | expression := "^" + strings.Join(parts, `\.`) + "$" 82 | var err error = nil 83 | re, err = regexp.Compile(expression) 84 | if err != nil { // pragma: nocover 85 | // This is impossible to get to based on the earlier 86 | // code, so we panic and don't count it for coverage 87 | panic(fmt.Sprintf("Could not compile regex: '%s'", expression)) 88 | } 89 | } 90 | 91 | return &Binding{ 92 | BindingState: gen.BindingState{ 93 | Id: calcId(queueName, exchangeName, key, arguments), 94 | QueueName: queueName, 95 | ExchangeName: exchangeName, 96 | Key: key, 97 | Arguments: arguments, 98 | Topic: topic, 99 | }, 100 | topicMatcher: re, 101 | }, nil 102 | } 103 | 104 | 
func LoadAllBindings(db *bolt.DB) (map[string]*Binding, error) { 105 | exStateMap, err := persist.LoadAll(db, BINDINGS_BUCKET_NAME, &BindingStateFactory{}) 106 | if err != nil { 107 | return nil, err 108 | } 109 | var ret = make(map[string]*Binding) 110 | for key, state := range exStateMap { 111 | var sb = state.(*gen.BindingState) 112 | // TODO: we don't actually know if topic is true, so this is extra work 113 | // for other exchange binding types 114 | ret[key], err = NewBinding(sb.QueueName, sb.ExchangeName, sb.Key, sb.Arguments, sb.Topic) 115 | if err != nil { 116 | return nil, err 117 | } 118 | } 119 | return ret, nil 120 | } 121 | 122 | func (b *Binding) Persist(db *bolt.DB) error { 123 | return persist.PersistOne(db, BINDINGS_BUCKET_NAME, string(b.Id), b) 124 | } 125 | 126 | func (b *Binding) MatchDirect(message *amqp.BasicPublish) bool { 127 | return message.Exchange == b.ExchangeName && b.Key == message.RoutingKey 128 | } 129 | 130 | func (b *Binding) MatchFanout(message *amqp.BasicPublish) bool { 131 | return message.Exchange == b.ExchangeName 132 | } 133 | 134 | func (b *Binding) MatchTopic(message *amqp.BasicPublish) bool { 135 | var ex = b.ExchangeName == message.Exchange 136 | var match = b.topicMatcher.MatchString(message.RoutingKey) 137 | return ex && match 138 | } 139 | 140 | // Calculate an ID by encoding the QueueBind call that created this binding and 141 | // taking a hash of it. 
142 | func calcId(queueName string, exchangeName string, key string, arguments *amqp.Table) []byte { 143 | var method = &amqp.QueueBind{ 144 | Queue: queueName, 145 | Exchange: exchangeName, 146 | RoutingKey: key, 147 | Arguments: arguments, 148 | } 149 | var buffer = bytes.NewBuffer(make([]byte, 0)) 150 | method.Write(buffer) 151 | // trim off the first four bytes, they're the class/method, which we 152 | // already know 153 | var value = buffer.Bytes()[4:] 154 | // bindings aren't named, so we hash the bytes we encoded 155 | hash := sha1.New() 156 | hash.Write(value) 157 | return []byte(hash.Sum(nil)) 158 | } 159 | -------------------------------------------------------------------------------- /binding/binding_test.go: -------------------------------------------------------------------------------- 1 | package binding 2 | 3 | import ( 4 | "encoding/json" 5 | "github.com/boltdb/bolt" 6 | "github.com/jeffjenkins/dispatchd/amqp" 7 | // "github.com/jeffjenkins/dispatchd/persist" 8 | "os" 9 | "testing" 10 | ) 11 | 12 | func TestFanout(t *testing.T) { 13 | b, _ := NewBinding("q1", "e1", "rk", amqp.NewTable(), false) 14 | if b.MatchFanout(basicPublish("DIFF", "asdf")) { 15 | t.Errorf("Fanout did not check exchanges") 16 | } 17 | if !b.MatchFanout(basicPublish("e1", "asdf")) { 18 | t.Errorf("Fanout didn't match regardless of key") 19 | } 20 | } 21 | 22 | func TestDirect(t *testing.T) { 23 | b, _ := NewBinding("q1", "e1", "rk", amqp.NewTable(), false) 24 | if b.MatchDirect(basicPublish("DIFF", "asdf")) { 25 | t.Errorf("MatchDirect did not check exchanges") 26 | } 27 | if b.MatchDirect(basicPublish("e1", "asdf")) { 28 | t.Errorf("MatchDirect matched even with the wrong key") 29 | } 30 | if !b.MatchDirect(basicPublish("e1", "rk")) { 31 | t.Errorf("MatchDirect did not match with the correct key and exchange") 32 | } 33 | } 34 | 35 | func TestEquals(t *testing.T) { 36 | var bNil *Binding = nil 37 | b, _ := NewBinding("q1", "e1", "rk", amqp.NewTable(), false) 38 | same, _ := 
NewBinding("q1", "e1", "rk", amqp.NewTable(), false) 39 | diffQ, _ := NewBinding("DIFF", "e1", "rk", amqp.NewTable(), false) 40 | diffE, _ := NewBinding("q1", "DIFF", "rk", amqp.NewTable(), false) 41 | diffR, _ := NewBinding("q1", "e1", "DIFF", amqp.NewTable(), false) 42 | 43 | if b == nil || same == nil || diffQ == nil || diffE == nil || diffR == nil { 44 | t.Errorf("Failed to construct bindings") 45 | } 46 | if b.Equals(nil) || bNil.Equals(b) { 47 | t.Errorf("Comparison to nil was true!") 48 | } 49 | if !b.Equals(same) { 50 | t.Errorf("Equals returns false!") 51 | } 52 | if b.Equals(diffQ) { 53 | t.Errorf("Equals returns true on queue name diff!") 54 | } 55 | if b.Equals(diffE) { 56 | t.Errorf("Equals returns true on exchange name diff!") 57 | } 58 | if b.Equals(diffR) { 59 | t.Errorf("Equals returns true on routing key diff!") 60 | } 61 | 62 | } 63 | 64 | func TestTopicRouting(t *testing.T) { 65 | _, err := NewBinding("q1", "e1", "(", amqp.NewTable(), true) 66 | if err == nil { 67 | t.Errorf("Bad topic patter compiled!") 68 | } 69 | basic, _ := NewBinding("q1", "e1", "hello.world", amqp.NewTable(), true) 70 | singleWild, _ := NewBinding("q1", "e1", "hello.*.world", amqp.NewTable(), true) 71 | multiWild, _ := NewBinding("q1", "e1", "hello.#.world", amqp.NewTable(), true) 72 | multiWild2, _ := NewBinding("q1", "e1", "hello.#.world.#", amqp.NewTable(), true) 73 | if !basic.MatchTopic(basicPublish("e1", "hello.world")) { 74 | t.Errorf("Basic match failed") 75 | } 76 | if basic.MatchTopic(basicPublish("e1", "hello.worlds")) { 77 | t.Errorf("Incorrect match with suffix") 78 | } 79 | if !basic.MatchTopic(basicPublish("e1", "hello.world")) { 80 | t.Errorf("Match succeeded despite mismatched exchange") 81 | } 82 | if !singleWild.MatchTopic(basicPublish("e1", "hello.one.world")) { 83 | t.Errorf("Failed to match single wildcard") 84 | } 85 | if singleWild.MatchTopic(basicPublish("e1", "hello.world")) { 86 | t.Errorf("Matched without wildcard token") 87 | } 88 | if 
!multiWild.MatchTopic(basicPublish("e1", "hello.one.two.three.world")) { 89 | t.Errorf("Failed to match multi wildcard") 90 | } 91 | if !multiWild2.MatchTopic(basicPublish("e1", "hello.one.world.hi")) { 92 | t.Errorf("Multiple multi-wild tokens failed") 93 | } 94 | 95 | } 96 | 97 | func basicPublish(e string, key string) *amqp.BasicPublish { 98 | return &amqp.BasicPublish{ 99 | Exchange: e, 100 | RoutingKey: key, 101 | } 102 | } 103 | 104 | func TestJson(t *testing.T) { 105 | basic, _ := NewBinding("q1", "e1", "hello.world", amqp.NewTable(), true) 106 | var basicBytes, err = basic.MarshalJSON() 107 | if err != nil { 108 | t.Errorf(err.Error()) 109 | } 110 | 111 | expectedBytes, err := json.Marshal(map[string]interface{}{ 112 | "queueName": "q1", 113 | "exchangeName": "e1", 114 | "key": "hello.world", 115 | "arguments": make(map[string]interface{}), 116 | }) 117 | if string(expectedBytes) != string(basicBytes) { 118 | t.Logf("Expected: %s", expectedBytes) 119 | t.Logf("Got: %s", basicBytes) 120 | t.Errorf("Wrong json bytes!") 121 | } 122 | 123 | } 124 | 125 | func TestPersistence(t *testing.T) { 126 | var dbFile = "TestBindingPersistence.db" 127 | os.Remove(dbFile) 128 | defer os.Remove(dbFile) 129 | db, err := bolt.Open(dbFile, 0600, nil) 130 | if err != nil { 131 | t.Errorf("Failed to create db") 132 | } 133 | 134 | // Persist 135 | b, err := NewBinding("q1", "ex1", "rk1", amqp.NewTable(), true) 136 | if err != nil { 137 | t.Errorf("Error in NewBinding") 138 | } 139 | err = b.Persist(db) 140 | if err != nil { 141 | t.Errorf("Error in NewBinding") 142 | } 143 | 144 | // Read 145 | bMap, err := LoadAllBindings(db) 146 | if err != nil { 147 | t.Errorf("Error in LoadAllBindings") 148 | } 149 | if len(bMap) != 1 { 150 | t.Errorf("Wrong number of bindings") 151 | } 152 | for _, b2 := range bMap { 153 | if !b2.Equals(b) { 154 | t.Errorf("Did not get the same binding from the db") 155 | } 156 | } 157 | 158 | // Depersist 159 | b.Depersist(db) 160 | 161 | bMap, err = 
// Consumer delivers messages from one queue to one channel. A goroutine
// started by Start waits on incoming and, for every wake-up, tries to take a
// single message from cqueue and send it over cchannel.
type Consumer struct {
	// msgStore is used to drop the message reference right after delivery
	// when the consumer does not expect acks.
	msgStore *msgstore.MessageStore
	// arguments is stored as passed to NewConsumer.
	// NOTE(review): no reader of this field is visible in this part of the
	// file. Confirm where it is used.
	arguments *amqp.Table
	// cchannel is the channel messages and methods are sent on; it is also
	// asked for flow state and registers unacked messages.
	cchannel ConsumerChannel
	// ConsumerTag identifies the consumer in deliveries, cancels and JSON.
	ConsumerTag string
	// exclusive is stored as passed to NewConsumer.
	// NOTE(review): not read in the code visible here.
	exclusive bool
	// incoming is the wake-up channel (created with a buffer of 1). Ping
	// does a non-blocking send on it, Stop closes it.
	incoming chan bool
	// noAck means deliveries are not tracked as unacked and the prefetch
	// limits are not enforced.
	noAck bool
	// noLocal makes AcquireResources reject messages whose LocalId equals
	// localId.
	noLocal bool
	// cqueue is the queue messages are taken from.
	cqueue ConsumerQueue
	// queueName is passed along when registering unacked messages and when
	// removing message references.
	queueName string
	// consumeLock is held for the whole of consumeOne and ConsumeImmediate.
	consumeLock sync.Mutex
	// limitLock guards activeSize and activeCount in AcquireResources and
	// ReleaseResources.
	limitLock sync.Mutex
	// prefetchSize and prefetchCount are the limits checked against
	// activeSize and activeCount; 0 means no limit.
	prefetchSize  uint32
	prefetchCount uint16
	// activeSize and activeCount track the message bytes and the number of
	// messages currently acquired and not yet released.
	activeSize  uint32
	activeCount uint16
	// stopLock is taken by Stop and Ping around the use of stopped and
	// incoming.
	stopLock sync.Mutex
	// stopped is set by Stop; once set, Ping no longer sends on incoming.
	stopped bool
	// StatCount is incremented once per message delivered by consumeOne and
	// reported as "total" in the JSON form.
	StatCount uint64
	// localId is compared with a message's LocalId for the no-local check.
	localId int64
	// stats
	// Histograms recording the timings of consumeOne and of its get,
	// ack-bookkeeping and send phases.
	statConsumeOneGetOne stats.Histogram
	statConsumeOne       stats.Histogram
	statConsumeOneAck    stats.Histogram
	statConsumeOneSend   stats.Histogram
}
48 | SendMethod(method amqp.MethodFrame) 49 | FlowActive() bool 50 | AddUnackedMessage(consumerTag string, qm *amqp.QueueMessage, queueName string) uint64 51 | } 52 | 53 | func NewConsumer( 54 | msgStore *msgstore.MessageStore, 55 | arguments *amqp.Table, 56 | cchannel ConsumerChannel, 57 | consumerTag string, 58 | exclusive bool, 59 | noAck bool, 60 | noLocal bool, 61 | cqueue ConsumerQueue, 62 | queueName string, 63 | prefetchSize uint32, 64 | prefetchCount uint16, 65 | localId int64, 66 | ) *Consumer { 67 | return &Consumer{ 68 | msgStore: msgStore, 69 | arguments: arguments, 70 | cchannel: cchannel, 71 | ConsumerTag: consumerTag, 72 | exclusive: exclusive, 73 | incoming: make(chan bool, 1), 74 | noAck: noAck, 75 | noLocal: noLocal, 76 | cqueue: cqueue, 77 | queueName: queueName, 78 | prefetchSize: prefetchSize, 79 | prefetchCount: prefetchCount, 80 | localId: localId, 81 | // stats 82 | statConsumeOneGetOne: stats.MakeHistogram("Consume-One-Get-One"), 83 | statConsumeOne: stats.MakeHistogram("Consume-One-"), 84 | statConsumeOneAck: stats.MakeHistogram("Consume-One-Ack"), 85 | statConsumeOneSend: stats.MakeHistogram("Consume-One-Send"), 86 | } 87 | } 88 | 89 | func (consumer *Consumer) MarshalJSON() ([]byte, error) { 90 | return json.Marshal(map[string]interface{}{ 91 | "tag": consumer.ConsumerTag, 92 | "stats": map[string]interface{}{ 93 | "total": consumer.StatCount, 94 | "active_size_bytes": consumer.activeSize, 95 | "active_count": consumer.activeCount, 96 | }, 97 | "ack": !consumer.noAck, 98 | }) 99 | } 100 | 101 | // TODO: make this a field that we construct on init 102 | func (consumer *Consumer) MessageResourceHolders() []amqp.MessageResourceHolder { 103 | return []amqp.MessageResourceHolder{consumer, consumer.cchannel} 104 | } 105 | 106 | func (consumer *Consumer) Stop() { 107 | if !consumer.stopped { 108 | consumer.stopLock.Lock() 109 | consumer.stopped = true 110 | close(consumer.incoming) 111 | consumer.stopLock.Unlock() 112 | } 113 | } 114 | 115 | 
func (consumer *Consumer) AcquireResources(qm *amqp.QueueMessage) bool { 116 | consumer.limitLock.Lock() 117 | defer consumer.limitLock.Unlock() 118 | 119 | // If no-local was set on the consumer, reject messages 120 | if consumer.noLocal && qm.LocalId == consumer.localId { 121 | return false 122 | } 123 | 124 | // If the channel is in flow mode we don't consume 125 | // TODO: If flow is mostly for producers, then maybe we 126 | // should consume? I feel like the right answer here is for 127 | // clients to not produce and consume on the same channel. 128 | if !consumer.cchannel.FlowActive() { 129 | return false 130 | } 131 | // If we aren't acking then there are no resource limits. Up the stats 132 | // and return true 133 | if consumer.noAck { 134 | consumer.activeCount += 1 135 | consumer.activeSize += qm.MsgSize 136 | return true 137 | } 138 | 139 | // Calculate whether we're over either of the size and count limits 140 | var sizeOk = consumer.prefetchSize == 0 || consumer.activeSize < consumer.prefetchSize 141 | var countOk = consumer.prefetchCount == 0 || consumer.activeCount < consumer.prefetchCount 142 | if sizeOk && countOk { 143 | consumer.activeCount += 1 144 | consumer.activeSize += qm.MsgSize 145 | return true 146 | } 147 | return false 148 | } 149 | 150 | func (consumer *Consumer) ReleaseResources(qm *amqp.QueueMessage) { 151 | consumer.limitLock.Lock() 152 | consumer.activeCount -= 1 153 | consumer.activeSize -= qm.MsgSize 154 | consumer.limitLock.Unlock() 155 | } 156 | 157 | func (consumer *Consumer) Start() { 158 | go consumer.consume(0) 159 | } 160 | 161 | func (consumer *Consumer) Ping() { 162 | consumer.stopLock.Lock() 163 | defer consumer.stopLock.Unlock() 164 | if !consumer.stopped { 165 | select { 166 | case consumer.incoming <- true: 167 | default: 168 | } 169 | } 170 | 171 | } 172 | 173 | func (consumer *Consumer) consume(id uint16) { 174 | // TODO: what is this doing? 
175 | consumer.cqueue.MaybeReady() <- false 176 | for _ = range consumer.incoming { 177 | 178 | consumer.consumeOne() 179 | } 180 | } 181 | 182 | func (consumer *Consumer) consumeOne() { 183 | defer stats.RecordHisto(consumer.statConsumeOne, stats.Start()) 184 | var err error 185 | // Check local limit 186 | consumer.consumeLock.Lock() 187 | defer consumer.consumeLock.Unlock() 188 | // Try to get message/check channel limit 189 | 190 | var start = stats.Start() 191 | var qm, msg = consumer.cqueue.GetOne(consumer.cchannel, consumer) 192 | stats.RecordHisto(consumer.statConsumeOneGetOne, start) 193 | if qm == nil { 194 | return 195 | } 196 | var tag uint64 = 0 197 | start = stats.Start() 198 | if !consumer.noAck { 199 | tag = consumer.cchannel.AddUnackedMessage(consumer.ConsumerTag, qm, consumer.queueName) 200 | } else { 201 | // We aren't expecting an ack, so this is the last time the message 202 | // will be referenced. 203 | var rhs = []amqp.MessageResourceHolder{consumer.cchannel, consumer} 204 | err = consumer.msgStore.RemoveRef(qm, consumer.queueName, rhs) 205 | if err != nil { 206 | panic("Error getting queue message") 207 | } 208 | } 209 | stats.RecordHisto(consumer.statConsumeOneAck, start) 210 | start = stats.Start() 211 | consumer.cchannel.SendContent(&amqp.BasicDeliver{ 212 | ConsumerTag: consumer.ConsumerTag, 213 | DeliveryTag: tag, 214 | Redelivered: qm.DeliveryCount > 0, 215 | Exchange: msg.Exchange, 216 | RoutingKey: msg.Key, 217 | }, msg) 218 | stats.RecordHisto(consumer.statConsumeOneSend, start) 219 | consumer.StatCount += 1 220 | // Since we succeeded in processing a message, ping so that we try again 221 | consumer.Ping() 222 | } 223 | 224 | func (consumer *Consumer) SendCancel() { 225 | consumer.cchannel.SendMethod(&amqp.BasicCancel{consumer.ConsumerTag, true}) 226 | } 227 | 228 | func (consumer *Consumer) ConsumeImmediate(qm *amqp.QueueMessage, msg *amqp.Message) bool { 229 | consumer.consumeLock.Lock() 230 | defer consumer.consumeLock.Unlock() 
231 | var tag uint64 = 0 232 | if !consumer.noAck { 233 | tag = consumer.cchannel.AddUnackedMessage(consumer.ConsumerTag, qm, consumer.queueName) 234 | } 235 | consumer.cchannel.SendContent(&amqp.BasicDeliver{ 236 | ConsumerTag: consumer.ConsumerTag, 237 | DeliveryTag: tag, 238 | Redelivered: msg.Redelivered > 0, 239 | Exchange: msg.Exchange, 240 | RoutingKey: msg.Key, 241 | }, msg) 242 | consumer.StatCount += 1 243 | return true 244 | } 245 | 246 | // Send again, leave all stats the same since this consumer was already 247 | // dealing with this message 248 | func (consumer *Consumer) Redeliver(tag uint64, qm *amqp.QueueMessage) { 249 | msg, found := consumer.msgStore.GetNoChecks(qm.Id) 250 | if !found { 251 | panic("Integrity error, message not found in message store") 252 | } 253 | consumer.cchannel.SendContent(&amqp.BasicDeliver{ 254 | ConsumerTag: consumer.ConsumerTag, 255 | DeliveryTag: tag, 256 | Redelivered: msg.Redelivered > 0, 257 | Exchange: msg.Exchange, 258 | RoutingKey: msg.Key, 259 | }, msg) 260 | } 261 | -------------------------------------------------------------------------------- /consumer/consumer_test.go: -------------------------------------------------------------------------------- 1 | package consumer 2 | 3 | import ( 4 | "testing" 5 | ) 6 | 7 | func TestMarshalJson(t *testing.T) { 8 | 9 | } 10 | -------------------------------------------------------------------------------- /dev/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "users" : { 3 | "guest" : { 4 | "password_bcrypt_base64" : "JDJhJDExJENobGk4dG5rY0RGemJhTjhsV21xR3VNNnFZZ1ZqTzUzQWxtbGtyMHRYN3RkUHMuYjF5SUt5" 5 | } 6 | } 7 | } -------------------------------------------------------------------------------- /dispatchd/config.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "encoding/json" 5 | "flag" 6 | "io/ioutil" 7 | ) 8 | 9 | var amqpPort int 10 | var 
amqpPortDefault = 5672 11 | var adminPort int 12 | var adminPortDefault = 8080 13 | var persistDir string 14 | var persistDirDefault = "/data/dispatchd/" 15 | var configFile string 16 | var configFileDefault = "" 17 | var strictMode bool 18 | 19 | func init() { 20 | flag.IntVar(&amqpPort, "amqp-port", 0, "Port for amqp protocol messages. Default: 5672") 21 | flag.IntVar(&adminPort, "admin-port", 0, "Port for admin server. Default: 8080") 22 | flag.StringVar(&persistDir, "persist-dir", "", "Directory for the server and message database files. Default: /data/dispatchd/") 23 | flag.BoolVar(&strictMode, "strict-mode", false, "Obey the AMQP spec even where it differs from common implementations") 24 | flag.StringVar( 25 | &configFile, 26 | "config-file", 27 | "", 28 | "Location of the configuration file. Default: do not read a config file", 29 | ) 30 | } 31 | 32 | func configure() map[string]interface{} { 33 | // TODO: It's no great that this is manual. I should make/find a small library 34 | // to automate this. 
// configureIntParam resolves an integer setting. A non-zero *param (set on the
// command line) wins; otherwise the value stored under configName in config is
// used; otherwise defaultValue. An empty configName skips the config lookup.
//
// JSON numbers decode as float64, so the config value is truncated to int.
// It panics with a descriptive message if the config value is not a number.
func configureIntParam(param *int, defaultValue int, configName string, config map[string]interface{}) {
	if *param != 0 {
		return
	}
	if len(configName) != 0 {
		if value, ok := config[configName]; ok {
			number, isNumber := value.(float64)
			if !isNumber {
				panic("Config value for '" + configName + "' must be a number")
			}
			*param = int(number)
			return
		}
	}
	*param = defaultValue
}

// configureBoolParam resolves a boolean setting. A *param that is already true
// (set on the command line) wins, matching the precedence used for int and
// string settings; otherwise the value stored under configName in config is
// used; otherwise the setting stays false.
//
// It panics with a descriptive message if the config value is not a boolean.
func configureBoolParam(param *bool, configName string, config map[string]interface{}) {
	// Previously a missing config key reset the flag to false, silently
	// discarding a flag passed on the command line.
	if *param {
		return
	}
	if len(configName) != 0 {
		if value, ok := config[configName]; ok {
			enabled, isBool := value.(bool)
			if !isBool {
				panic("Config value for '" + configName + "' must be a boolean")
			}
			*param = enabled
			return
		}
	}
	*param = false
}

// configureStringParam resolves a string setting. A non-empty *param (set on
// the command line) wins; otherwise the value stored under configName in
// config is used; otherwise defaultValue.
//
// It panics with a descriptive message if the config value is not a string.
func configureStringParam(param *string, defaultValue string, configName string, config map[string]interface{}) {
	if *param != "" {
		return
	}
	if len(configName) != 0 {
		if value, ok := config[configName]; ok {
			text, isString := value.(string)
			if !isString {
				panic("Config value for '" + configName + "' must be a string")
			}
			*param = text
			return
		}
	}
	*param = defaultValue
}

// parseConfigFile reads the JSON document at path into a generic map.
// It panics if the file cannot be read or is not valid JSON.
func parseConfigFile(path string) map[string]interface{} {
	ret := make(map[string]interface{})
	data, err := ioutil.ReadFile(path)
	if err != nil {
		panic("Could not read config file: " + err.Error())
	}
	err = json.Unmarshal(data, &ret)
	if err != nil {
		panic("Could not parse config file: " + err.Error())
	}
	return ret
}
-------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "flag" 5 | "fmt" 6 | "net" 7 | // _ "net/http/pprof" // uncomment for debugging 8 | "github.com/jeffjenkins/dispatchd/adminserver" 9 | "github.com/jeffjenkins/dispatchd/server" 10 | "os" 11 | "path/filepath" 12 | "runtime" 13 | ) 14 | 15 | func handleConnection(server *server.Server, conn net.Conn) { 16 | server.OpenConnection(conn) 17 | } 18 | 19 | func main() { 20 | flag.Parse() 21 | config := configure() 22 | runtime.SetBlockProfileRate(1) 23 | serverDbPath := filepath.Join(persistDir, "dispatchd-server.db") 24 | msgDbPath := filepath.Join(persistDir, "messages.db") 25 | var server = server.NewServer(serverDbPath, msgDbPath, config["users"].(map[string]interface{}), strictMode) 26 | ln, err := net.Listen("tcp", fmt.Sprintf(":%d", amqpPort)) 27 | if err != nil { 28 | fmt.Printf("Error!\n") 29 | os.Exit(1) 30 | } 31 | fmt.Printf("Listening on port %d\n", amqpPort) 32 | go func() { 33 | adminserver.StartAdminServer(server, adminPort) 34 | }() 35 | for { 36 | conn, err := ln.Accept() 37 | if err != nil { 38 | fmt.Printf("Error accepting connection!\n") 39 | os.Exit(1) 40 | } 41 | go handleConnection(server, conn) 42 | } 43 | } 44 | -------------------------------------------------------------------------------- /exchange/exchange.go: -------------------------------------------------------------------------------- 1 | package exchange 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | "github.com/boltdb/bolt" 7 | "github.com/gogo/protobuf/proto" 8 | "github.com/jeffjenkins/dispatchd/amqp" 9 | "github.com/jeffjenkins/dispatchd/binding" 10 | "github.com/jeffjenkins/dispatchd/gen" 11 | "github.com/jeffjenkins/dispatchd/persist" 12 | "sync" 13 | "time" 14 | ) 15 | 16 | type ExchangeStateFactory struct{} 17 | 18 | func (esf *ExchangeStateFactory) New() proto.Unmarshaler { 19 | return &gen.ExchangeState{} 20 | } 21 | 22 | var EXCHANGES_BUCKET_NAME = 
[]byte("exchanges") 23 | 24 | const ( 25 | EX_TYPE_DIRECT uint8 = 1 26 | EX_TYPE_FANOUT uint8 = 2 27 | EX_TYPE_TOPIC uint8 = 3 28 | EX_TYPE_HEADERS uint8 = 4 29 | ) 30 | 31 | type Exchange struct { 32 | gen.ExchangeState 33 | bindings []*binding.Binding 34 | bindingsLock sync.Mutex 35 | incoming chan amqp.Frame 36 | Closed bool 37 | deleteActive time.Time 38 | deleteChan chan *Exchange 39 | autodeletePeriod time.Duration 40 | } 41 | 42 | func (exchange *Exchange) Close() { 43 | exchange.Closed = true 44 | } 45 | 46 | func (exchange *Exchange) MarshalJSON() ([]byte, error) { 47 | var typ, err = exchangeTypeToName(exchange.ExType) 48 | if err != nil { 49 | return nil, err 50 | } 51 | return json.Marshal(map[string]interface{}{ 52 | "type": typ, 53 | "bindings": exchange.bindings, 54 | }) 55 | } 56 | 57 | func NewExchange( 58 | name string, 59 | extype uint8, 60 | durable bool, 61 | autodelete bool, 62 | internal bool, 63 | arguments *amqp.Table, 64 | system bool, 65 | deleteChan chan *Exchange, 66 | ) *Exchange { 67 | return &Exchange{ 68 | ExchangeState: gen.ExchangeState{ 69 | Name: name, 70 | ExType: extype, 71 | Durable: durable, 72 | AutoDelete: autodelete, 73 | Internal: internal, 74 | Arguments: arguments, 75 | System: system, 76 | }, 77 | deleteChan: deleteChan, 78 | // not passed in 79 | incoming: make(chan amqp.Frame), 80 | bindings: make([]*binding.Binding, 0), 81 | autodeletePeriod: 5 * time.Second, 82 | } 83 | } 84 | 85 | func NewFromExchangeState(exState *gen.ExchangeState, deleteChan chan *Exchange) *Exchange { 86 | return &Exchange{ 87 | ExchangeState: *exState, 88 | deleteChan: deleteChan, 89 | incoming: make(chan amqp.Frame), 90 | bindings: make([]*binding.Binding, 0), 91 | autodeletePeriod: 5 * time.Second, 92 | } 93 | } 94 | 95 | func NewFromMethod(method *amqp.ExchangeDeclare, system bool, exchangeDeleter chan *Exchange) (*Exchange, *amqp.AMQPError) { 96 | var classId, methodId = method.MethodIdentifier() 97 | var tp, err = 
ExchangeNameToType(method.Type) 98 | if err != nil || tp == EX_TYPE_HEADERS { 99 | return nil, amqp.NewHardError(503, "Bad exchange type", classId, methodId) 100 | } 101 | var ex = NewExchange( 102 | method.Exchange, 103 | tp, 104 | method.Durable, 105 | method.AutoDelete, 106 | method.Internal, 107 | method.Arguments, 108 | system, 109 | exchangeDeleter, 110 | ) 111 | return ex, nil 112 | } 113 | 114 | func (ex1 *Exchange) EquivalentExchanges(ex2 *Exchange) bool { 115 | // NOTE: auto-delete is ignored for existing exchanges, so we 116 | // do not check it here. 117 | if ex1.Name != ex2.Name { 118 | return false 119 | } 120 | if ex1.ExType != ex2.ExType { 121 | return false 122 | } 123 | if ex1.Durable != ex2.Durable { 124 | return false 125 | } 126 | if ex1.Internal != ex2.Internal { 127 | return false 128 | } 129 | if !amqp.EquivalentTables(ex1.Arguments, ex2.Arguments) { 130 | return false 131 | } 132 | return true 133 | } 134 | 135 | func ExchangeNameToType(et string) (uint8, error) { 136 | switch { 137 | case et == "direct": 138 | return EX_TYPE_DIRECT, nil 139 | case et == "fanout": 140 | return EX_TYPE_FANOUT, nil 141 | case et == "topic": 142 | return EX_TYPE_TOPIC, nil 143 | case et == "headers": 144 | return EX_TYPE_HEADERS, nil 145 | default: 146 | return 0, fmt.Errorf("Unknown exchang type '%s', %d %d", et, len(et), len("direct")) 147 | } 148 | } 149 | 150 | func exchangeTypeToName(et uint8) (string, error) { 151 | switch { 152 | case et == EX_TYPE_DIRECT: 153 | return "direct", nil 154 | case et == EX_TYPE_FANOUT: 155 | return "fanout", nil 156 | case et == EX_TYPE_TOPIC: 157 | return "topic", nil 158 | case et == EX_TYPE_HEADERS: 159 | return "headers", nil 160 | default: 161 | return "", fmt.Errorf("bad exchange type: %d", et) 162 | } 163 | } 164 | 165 | func LoadAllExchanges(db *bolt.DB, deleteChan chan *Exchange) (map[string]*Exchange, error) { 166 | exStateMap, err := persist.LoadAll(db, EXCHANGES_BUCKET_NAME, &ExchangeStateFactory{}) 167 | if err 
!= nil { 168 | return nil, err 169 | } 170 | var ret = make(map[string]*Exchange) 171 | for key, state := range exStateMap { 172 | ret[key] = NewFromExchangeState(state.(*gen.ExchangeState), deleteChan) 173 | } 174 | return ret, nil 175 | } 176 | 177 | func (exchange *Exchange) QueuesForPublish(msg *amqp.Message) (map[string]bool, *amqp.AMQPError) { 178 | var queues = make(map[string]bool) 179 | if msg.Method.Exchange != exchange.Name { 180 | return queues, nil 181 | } 182 | switch { 183 | case exchange.ExType == EX_TYPE_DIRECT: 184 | // In a direct exchange we can return the first match since there is 185 | // only one queue with a particular name 186 | for _, binding := range exchange.bindings { 187 | if binding.MatchDirect(msg.Method) { 188 | queues[binding.QueueName] = true 189 | return queues, nil 190 | } 191 | } 192 | case exchange.ExType == EX_TYPE_FANOUT: 193 | for _, binding := range exchange.bindings { 194 | queues[binding.QueueName] = true 195 | } 196 | case exchange.ExType == EX_TYPE_TOPIC: 197 | for _, binding := range exchange.bindings { 198 | if binding.MatchTopic(msg.Method) { 199 | var _, alreadySeen = queues[binding.QueueName] 200 | if alreadySeen { 201 | continue 202 | } 203 | queues[binding.QueueName] = true 204 | } 205 | } 206 | // case exchange.ExType == EX_TYPE_HEADERS: 207 | // // TODO: implement 208 | // panic("Headers is not implemented!") 209 | default: // pragma: nocover 210 | panic("Unknown exchange type created somehow. 
Server integrity error!") 211 | } 212 | return queues, nil 213 | } 214 | 215 | func (exchange *Exchange) Persist(db *bolt.DB) error { 216 | var key = exchange.Name 217 | if key == "" { 218 | key = "~" 219 | } 220 | return persist.PersistOne(db, EXCHANGES_BUCKET_NAME, key, &exchange.ExchangeState) 221 | } 222 | 223 | func NewFromDisk(db *bolt.DB, key string, deleteChan chan *Exchange) (ex *Exchange, err error) { 224 | err = db.View(func(tx *bolt.Tx) error { 225 | bucket := tx.Bucket(EXCHANGES_BUCKET_NAME) 226 | if bucket == nil { 227 | return fmt.Errorf("Bucket not found: 'exchanges'") 228 | } 229 | ex, err = NewFromDiskBoltTx(bucket, []byte(key), deleteChan) 230 | return err 231 | }) 232 | return 233 | } 234 | 235 | func NewFromDiskBoltTx(bucket *bolt.Bucket, key []byte, deleteChan chan *Exchange) (ex *Exchange, err error) { 236 | var lookupKey = key 237 | if len(key) == 0 { 238 | lookupKey = []byte{'~'} 239 | } 240 | exBytes := bucket.Get(lookupKey) 241 | if exBytes == nil { 242 | return nil, fmt.Errorf("Key not found: '%s'", key) 243 | } 244 | exState := gen.ExchangeState{} 245 | err = exState.Unmarshal(exBytes) 246 | if err != nil { 247 | return nil, fmt.Errorf("Could not unmarshal exchange %s", key) 248 | } 249 | return NewFromExchangeState(&exState, deleteChan), nil 250 | } 251 | 252 | func (exchange *Exchange) Depersist(db *bolt.DB) error { 253 | return db.Update(func(tx *bolt.Tx) error { 254 | bucket := tx.Bucket(EXCHANGES_BUCKET_NAME) 255 | if bucket == nil { 256 | return fmt.Errorf("Bucket not found: '%s'", bucket) 257 | } 258 | for _, binding := range exchange.bindings { 259 | if err := binding.DepersistBoltTx(tx); err != nil { // pragma: nocover 260 | return err 261 | } 262 | } 263 | return persist.DepersistOneBoltTx(bucket, exchange.Name) 264 | }) 265 | } 266 | 267 | func (exchange *Exchange) IsTopic() bool { 268 | return exchange.ExType == EX_TYPE_TOPIC 269 | } 270 | 271 | func (exchange *Exchange) AddBinding(b *binding.Binding, connId int64) error { 
272 | exchange.bindingsLock.Lock() 273 | defer exchange.bindingsLock.Unlock() 274 | 275 | for _, b2 := range exchange.bindings { 276 | if b.Equals(b2) { 277 | return nil 278 | } 279 | } 280 | 281 | if exchange.AutoDelete { 282 | exchange.deleteActive = time.Unix(0, 0) 283 | } 284 | exchange.bindings = append(exchange.bindings, b) 285 | return nil 286 | } 287 | 288 | func (exchange *Exchange) BindingsForQueue(queueName string) []*binding.Binding { 289 | var ret = make([]*binding.Binding, 0) 290 | exchange.bindingsLock.Lock() 291 | defer exchange.bindingsLock.Unlock() 292 | for _, b := range exchange.bindings { 293 | if b.QueueName == queueName { 294 | ret = append(ret, b) 295 | } 296 | } 297 | return ret 298 | } 299 | 300 | func (exchange *Exchange) RemoveBindingsForQueue(queueName string) { 301 | var remaining = make([]*binding.Binding, 0) 302 | exchange.bindingsLock.Lock() 303 | defer exchange.bindingsLock.Unlock() 304 | for _, b := range exchange.bindings { 305 | if b.QueueName != queueName { 306 | remaining = append(remaining, b) 307 | } 308 | } 309 | exchange.bindings = remaining 310 | } 311 | 312 | func (exchange *Exchange) RemoveBinding(binding *binding.Binding) error { 313 | exchange.bindingsLock.Lock() 314 | defer exchange.bindingsLock.Unlock() 315 | 316 | // Delete binding 317 | for i, b := range exchange.bindings { 318 | if binding.Equals(b) { 319 | exchange.bindings = append(exchange.bindings[:i], exchange.bindings[i+1:]...) 320 | if exchange.AutoDelete && len(exchange.bindings) == 0 { 321 | go exchange.autodeleteTimeout() 322 | } 323 | return nil 324 | } 325 | } 326 | return nil 327 | } 328 | 329 | func (exchange *Exchange) autodeleteTimeout() { 330 | // There's technically a race condition here where a new binding could be 331 | // added right as we check this, but after a 5 second wait with no activity 332 | // I think this is probably safe enough. 
333 | var now = time.Now() 334 | exchange.deleteActive = now 335 | time.Sleep(exchange.autodeletePeriod) 336 | if exchange.deleteActive == now { 337 | exchange.deleteChan <- exchange 338 | } 339 | } 340 | -------------------------------------------------------------------------------- /gen/server.proto: -------------------------------------------------------------------------------- 1 | package gen; 2 | 3 | import "github.com/jeffjenkins/dispatchd/amqp/amqp.proto"; 4 | import "github.com/gogo/protobuf/gogoproto/gogo.proto"; 5 | 6 | option (gogoproto.marshaler_all) = true; 7 | option (gogoproto.sizer_all) = true; 8 | option (gogoproto.unmarshaler_all) = true; 9 | 10 | message ExchangeState { 11 | option (gogoproto.goproto_unrecognized) = false; 12 | option (gogoproto.goproto_getters) = false; 13 | optional string name = 1 [(gogoproto.nullable) = false]; 14 | optional uint32 ex_type = 2 [(gogoproto.casttype) = "uint8", (gogoproto.nullable) = false]; 15 | optional bool passive = 3 [(gogoproto.nullable) = false]; 16 | optional bool durable = 4 [(gogoproto.nullable) = false]; 17 | optional bool auto_delete = 5 [(gogoproto.nullable) = false]; 18 | optional bool internal = 6 [(gogoproto.nullable) = false]; 19 | optional bool system = 7 [(gogoproto.nullable) = false]; 20 | optional amqp.Table arguments = 8; 21 | } 22 | 23 | message BindingState { 24 | option (gogoproto.goproto_unrecognized) = false; 25 | option (gogoproto.goproto_getters) = false; 26 | optional bytes id = 1; 27 | optional string queue_name = 2 [(gogoproto.nullable) = false]; 28 | optional string exchange_name = 3 [(gogoproto.nullable) = false]; 29 | optional string key = 4 [(gogoproto.nullable) = false]; 30 | optional amqp.Table arguments = 5; 31 | optional bool topic = 6 [(gogoproto.nullable) = false]; 32 | } 33 | 34 | message QueueState { 35 | option (gogoproto.goproto_unrecognized) = false; 36 | option (gogoproto.goproto_getters) = false; 37 | optional string name = 1 [(gogoproto.nullable) = false]; 38 
| optional bool durable = 2 [(gogoproto.nullable) = false]; 39 | optional amqp.Table arguments = 3; 40 | } -------------------------------------------------------------------------------- /msgstore/msgstore_test.go: -------------------------------------------------------------------------------- 1 | package msgstore 2 | 3 | import ( 4 | // "container/list" 5 | "fmt" 6 | "github.com/jeffjenkins/dispatchd/amqp" 7 | "os" 8 | "testing" 9 | ) 10 | 11 | func TestWrite(t *testing.T) { 12 | // Setup 13 | var dbFile = "TestWrite.db" 14 | os.Remove(dbFile) 15 | defer os.Remove(dbFile) 16 | rhs := []amqp.MessageResourceHolder{&TestResourceHolder{}} 17 | ms, err := NewMessageStore(dbFile) 18 | // Create messages 19 | msg1 := amqp.RandomMessage(true) 20 | msg2 := amqp.RandomMessage(true) 21 | fmt.Printf("Creating ids: %d, %d\n", msg1.Id, msg2.Id) 22 | 23 | // Store messages and delete one 24 | fmt.Println("Adding message 1") 25 | ms.AddMessage(msg1, []string{"some-queue", "some-other-queue"}) 26 | fmt.Println("Adding message 2") 27 | qm2Map, err := ms.AddMessage(msg2, []string{"some-queue"}) 28 | _, err = ms.GetAndDecrRef(qm2Map["some-queue"][0], "some-queue", rhs) 29 | if err != nil { 30 | t.Errorf(err.Error()) 31 | return 32 | } 33 | 34 | // Close DB 35 | ms.db.Close() 36 | keys := map[int64]bool{ 37 | msg1.Id: true, 38 | // msg2.Id: true, 39 | } 40 | // Assert that the DB is correct 41 | err = assertKeys(dbFile, keys) 42 | if err != nil { 43 | t.Errorf(err.Error()) 44 | return 45 | } 46 | 47 | // try loading from disk in the message store 48 | ms2, err := NewMessageStore(dbFile) 49 | _, err = ms2.LoadQueueFromDisk("some-queue") 50 | if err != nil { 51 | t.Errorf(err.Error()) 52 | return 53 | } 54 | _, err = ms2.LoadQueueFromDisk("some-other-queue") 55 | if err != nil { 56 | t.Errorf(err.Error()) 57 | return 58 | } 59 | } 60 | -------------------------------------------------------------------------------- /msgstore/testlib.go: 
-------------------------------------------------------------------------------- 1 | package msgstore 2 | 3 | import ( 4 | "fmt" 5 | "github.com/boltdb/bolt" 6 | "github.com/jeffjenkins/dispatchd/amqp" 7 | "reflect" 8 | ) 9 | 10 | // ids to map 11 | func idsToMap() { 12 | 13 | } 14 | 15 | // Check if the msg store is composed of exactly these keys 16 | func assertKeys(dbName string, keys map[int64]bool) error { 17 | // Open DB 18 | db, err := bolt.Open(dbName, 0600, nil) 19 | defer db.Close() 20 | if err != nil { 21 | return err 22 | } 23 | err = db.View(func(tx *bolt.Tx) error { 24 | // Check index 25 | bucket := tx.Bucket([]byte("message_index")) 26 | if bucket == nil { 27 | return nil 28 | } 29 | 30 | // get from db 31 | var indexKeys, err1 = keysForBucket(tx, MESSAGE_INDEX_BUCKET) 32 | var contentKeys, err2 = keysForBucket(tx, MESSAGE_CONTENT_BUCKET) 33 | if err1 != nil { 34 | return err1 35 | } 36 | if err2 != nil { 37 | return err2 38 | } 39 | 40 | // Check equality 41 | // TODO: return key diff 42 | indexNotKeys := subtract(indexKeys, keys) 43 | keysNotIndex := subtract(keys, indexKeys) 44 | contentNotKeys := subtract(contentKeys, keys) 45 | keysNotContent := subtract(keys, contentKeys) 46 | 47 | if !reflect.DeepEqual(keys, indexKeys) { 48 | return fmt.Errorf("Different values in index!\nindexNotKeys:%q\nkeysNotIndex:%q", indexNotKeys, keysNotIndex) 49 | } 50 | if !reflect.DeepEqual(keys, contentKeys) { 51 | return fmt.Errorf("Different values in content!\ncontentNotKeys:%q\nkeysNotContent:%q", contentNotKeys, keysNotContent) 52 | } 53 | return nil 54 | }) 55 | if err != nil { 56 | return err 57 | } 58 | return nil 59 | } 60 | 61 | func subtract(original map[int64]bool, subtractThis map[int64]bool) []int64 { 62 | ret := make([]int64, 0) 63 | for id, _ := range original { 64 | _, found := subtractThis[id] 65 | if !found { 66 | ret = append(ret, id) 67 | } 68 | } 69 | return ret 70 | } 71 | 72 | func keysForBucket(tx *bolt.Tx, bucketName []byte) 
(map[int64]bool, error) { 73 | // Check index 74 | bucket := tx.Bucket(bucketName) 75 | if bucket == nil { 76 | return nil, fmt.Errorf("No bucket!") 77 | } 78 | var cursor = bucket.Cursor() 79 | var keys = make(map[int64]bool) 80 | for bid, _ := cursor.First(); bid != nil; bid, _ = cursor.Next() { 81 | fmt.Printf("%s, key:%d\n", bucketName, bytesToInt64(bid)) 82 | keys[bytesToInt64(bid)] = true 83 | } 84 | return keys, nil 85 | } 86 | 87 | type TestResourceHolder struct { 88 | } 89 | 90 | func (trh *TestResourceHolder) AcquireResources(qm *amqp.QueueMessage) bool { 91 | return true 92 | } 93 | func (trh *TestResourceHolder) ReleaseResources(qm *amqp.QueueMessage) { 94 | 95 | } 96 | -------------------------------------------------------------------------------- /persist/persist.go: -------------------------------------------------------------------------------- 1 | package persist 2 | 3 | import ( 4 | "fmt" 5 | "github.com/boltdb/bolt" 6 | "github.com/gogo/protobuf/proto" 7 | ) 8 | 9 | type UnmarshalerFactory interface { 10 | New() proto.Unmarshaler 11 | } 12 | 13 | // 14 | // Persist 15 | // 16 | 17 | func PersistOne(db *bolt.DB, bucketName []byte, key string, obj proto.Marshaler) error { 18 | return db.Update(func(tx *bolt.Tx) error { 19 | bucket, err := tx.CreateBucketIfNotExists(bucketName) 20 | if err != nil { // pragma: nocover 21 | return fmt.Errorf("create bucket: %s", err) 22 | } 23 | return PersistOneBoltTx(bucket, key, obj) 24 | }) 25 | } 26 | 27 | func PersistOneBoltTx(bucket *bolt.Bucket, key string, obj proto.Marshaler) error { 28 | exBytes, err := obj.Marshal() 29 | if err != nil { // pragma: nocover -- no idea how to produce this error 30 | return fmt.Errorf("Could not marshal object") 31 | } 32 | return bucket.Put([]byte(key), exBytes) 33 | } 34 | 35 | func PersistMany(db *bolt.DB, bucketName []byte, objs map[string]proto.Marshaler) error { 36 | return db.Update(func(tx *bolt.Tx) error { 37 | bucket, err := tx.CreateBucketIfNotExists(bucketName) 38 
| if err != nil { // pragma: nocover 39 | return fmt.Errorf("create bucket: %s", err) 40 | } 41 | return PersistManyBoltTx(bucket, objs) 42 | }) 43 | } 44 | 45 | func PersistManyBoltTx(bucket *bolt.Bucket, objs map[string]proto.Marshaler) error { 46 | for key, obj := range objs { 47 | err := PersistOneBoltTx(bucket, key, obj) 48 | if err != nil { 49 | return err 50 | } 51 | } 52 | return nil 53 | } 54 | 55 | // 56 | // Load 57 | // 58 | 59 | func LoadOne(db *bolt.DB, bucketName []byte, key string, obj proto.Unmarshaler) error { 60 | return db.Update(func(tx *bolt.Tx) error { 61 | bucket := tx.Bucket(bucketName) 62 | if bucket == nil { 63 | return fmt.Errorf("Bucket not found: '%s'", bucket) 64 | } 65 | return LoadOneBoltTx(bucket, key, obj) 66 | }) 67 | } 68 | 69 | func LoadOneBoltTx(bucket *bolt.Bucket, key string, obj proto.Unmarshaler) error { 70 | objBytes := bucket.Get([]byte(key)) 71 | if objBytes == nil { 72 | return fmt.Errorf("Key not found: '%s'", key) 73 | } 74 | err := obj.Unmarshal(objBytes) 75 | if err != nil { 76 | return fmt.Errorf("Could not unmarshal key %s", key) 77 | } 78 | return nil 79 | } 80 | 81 | func LoadMany(db *bolt.DB, bucketName []byte, objs map[string]proto.Unmarshaler) error { 82 | return db.Update(func(tx *bolt.Tx) error { 83 | bucket := tx.Bucket(bucketName) 84 | if bucket == nil { // pragma: nocover 85 | return fmt.Errorf("create bucket: '%s'", bucket) 86 | } 87 | return LoadManyBoltTx(bucket, objs) 88 | }) 89 | } 90 | 91 | func LoadManyBoltTx(bucket *bolt.Bucket, objs map[string]proto.Unmarshaler) error { 92 | for key, obj := range objs { 93 | err := LoadOneBoltTx(bucket, key, obj) 94 | if err != nil { 95 | return err 96 | } 97 | } 98 | return nil 99 | } 100 | 101 | func LoadAll(db *bolt.DB, bucket []byte, factory UnmarshalerFactory) (map[string]proto.Unmarshaler, error) { 102 | ret := make(map[string]proto.Unmarshaler) 103 | err := db.View(func(tx *bolt.Tx) error { 104 | bucket := tx.Bucket(bucket) 105 | if bucket == nil { 106 | 
return nil 107 | } 108 | // iterate through queues 109 | cursor := bucket.Cursor() 110 | for name, data := cursor.First(); name != nil; name, data = cursor.Next() { 111 | obj := factory.New() 112 | err := obj.Unmarshal(data) 113 | if err != nil { 114 | return fmt.Errorf("Could not unmarshal key %s", string(name)) 115 | } 116 | ret[string(name)] = obj 117 | } 118 | return nil 119 | }) 120 | if err != nil { 121 | return nil, err 122 | } 123 | return ret, nil 124 | } 125 | 126 | // 127 | // Depersist 128 | // 129 | 130 | func DepersistOne(db *bolt.DB, bucketName []byte, key string) error { 131 | return db.Update(func(tx *bolt.Tx) error { 132 | bucket := tx.Bucket(bucketName) 133 | if bucket == nil { 134 | return fmt.Errorf("Bucket not found: '%s'", bucket) 135 | } 136 | return DepersistOneBoltTx(bucket, key) 137 | }) 138 | } 139 | 140 | func DepersistOneBoltTx(bucket *bolt.Bucket, key string) error { 141 | return bucket.Delete([]byte(key)) 142 | } 143 | 144 | func DepersistMany(db *bolt.DB, bucketName []byte, keys map[string]bool) error { 145 | return db.Update(func(tx *bolt.Tx) error { 146 | bucket := tx.Bucket(bucketName) 147 | if bucket == nil { // pragma: nocover 148 | return fmt.Errorf("create bucket: '%s'", bucket) 149 | } 150 | return DepersistManyBoltTx(bucket, keys) 151 | }) 152 | } 153 | 154 | func DepersistManyBoltTx(bucket *bolt.Bucket, keys map[string]bool) error { 155 | for key, _ := range keys { 156 | err := DepersistOneBoltTx(bucket, key) 157 | if err != nil { 158 | return err 159 | } 160 | } 161 | return nil 162 | } 163 | -------------------------------------------------------------------------------- /persist/persist_test.go: -------------------------------------------------------------------------------- 1 | package persist 2 | 3 | import ( 4 | "testing" 5 | ) 6 | 7 | func TestPersist(t *testing.T) { 8 | 9 | } 10 | -------------------------------------------------------------------------------- /queue/queue.go: 
// Queue is an in-memory message queue together with its consumers. The
// embedded gen.QueueState holds the part that is persisted (NewQueue fills in
// Name, Durable and Arguments); everything else is runtime-only state.
type Queue struct {
	gen.QueueState
	// autoDelete: when set, removing the last consumer schedules the queue
	// for deletion (see removeConsumer / autodeleteTimeout).
	autoDelete bool
	// exclusive is compared by EquivalentQueues and reported by MarshalJSON.
	exclusive bool
	// Closed is set by Close; Add and GetOne refuse to work on a closed queue.
	Closed bool
	// objLock is not referenced by any method visible in this part of the file.
	objLock sync.RWMutex
	// queue holds the pending messages; elements are *amqp.QueueMessage
	// (Add pushes them, GetOne type-asserts them).
	queue *list.List
	// queueLock is held by Add, Purge, Readd, GetOne, GetOneForced, Close and
	// Delete while they touch queue / Closed.
	queueLock sync.Mutex
	// consumerLock is held while consumers / soleConsumer are read or changed.
	consumerLock sync.RWMutex
	// consumers is the list of attached consumers.
	consumers []*consumer.Consumer
	// currentConsumer is the index processOne advances to visit consumers in
	// rotation.
	currentConsumer int
	// statCount is incremented by Add for every accepted message.
	statCount uint64
	// maybeReady is a capacity-1 signal channel; producers poke it with a
	// non-blocking send and the goroutine started by Start drains it.
	maybeReady chan bool
	// soleConsumer is the consumer that was added with exclusive=true, if any.
	soleConsumer *consumer.Consumer
	// ConnId is the id passed to NewQueue; NewFromPersistedState sets it to -1.
	ConnId int64
	// deleteActive is stamped by autodeleteTimeout and reset by AddConsumer so
	// a pending auto-delete can notice that a consumer arrived meanwhile.
	deleteActive time.Time
	// hasHadConsumers is set by AddConsumer; auto-delete only triggers once a
	// consumer has been attached at least once.
	hasHadConsumers bool
	// msgStore is the message store consulted when handing messages out.
	msgStore *msgstore.MessageStore
	// statProcOne is the histogram recorded by processOne.
	statProcOne stats.Histogram
	// deleteChan receives the queue itself when an auto-delete timeout fires.
	deleteChan chan *Queue
}
make(chan bool, 1), 76 | } 77 | } 78 | 79 | func NewFromPersistedState(state *gen.QueueState, msgStore *msgstore.MessageStore, deleteChan chan *Queue) *Queue { 80 | return &Queue{ 81 | QueueState: *state, 82 | exclusive: false, 83 | autoDelete: false, 84 | ConnId: -1, 85 | msgStore: msgStore, 86 | deleteChan: deleteChan, 87 | // Fields that aren't passed in 88 | statProcOne: stats.MakeHistogram("queue-proc-one"), 89 | queue: list.New(), 90 | consumers: make([]*consumer.Consumer, 0, 1), 91 | maybeReady: make(chan bool, 1), 92 | } 93 | } 94 | 95 | func (q1 *Queue) EquivalentQueues(q2 *Queue) bool { 96 | if q1 == nil { 97 | return q2 == nil 98 | } 99 | if q2 == nil { 100 | return false 101 | } 102 | 103 | // Note: autodelete is not included since the spec says to ignore 104 | // the field if the queue is already created 105 | if q1.Name != q2.Name { 106 | return false 107 | } 108 | if q1.Durable != q2.Durable { 109 | return false 110 | } 111 | if q1.exclusive != q2.exclusive { 112 | return false 113 | } 114 | if !amqp.EquivalentTables(q1.Arguments, q2.Arguments) { 115 | return false 116 | } 117 | return true 118 | } 119 | 120 | func (q *Queue) Len() uint32 { 121 | var l = q.queue.Len() 122 | if l < 0 { 123 | panic("Queue length overflow!") 124 | } 125 | return uint32(l) 126 | } 127 | 128 | func (q *Queue) ActiveConsumerCount() uint32 { 129 | // TODO(MUST): don't count consumers in the Channel.Flow state once 130 | // that is implemented 131 | return uint32(len(q.consumers)) 132 | } 133 | 134 | func (q *Queue) MarshalJSON() ([]byte, error) { 135 | return json.Marshal(map[string]interface{}{ 136 | "name": q.Name, 137 | "durable": q.Durable, 138 | "exclusive": q.exclusive, 139 | "connId": q.ConnId, 140 | "autoDelete": q.autoDelete, 141 | "size": q.queue.Len(), 142 | "consumers": q.consumers, 143 | }) 144 | } 145 | 146 | func (q *Queue) Persist(db *bolt.DB) error { 147 | return persist.PersistOne(db, QUEUE_BUCKET_NAME, q.Name, q) 148 | } 149 | 150 | func (q *Queue) 
Depersist(db *bolt.DB) error { 151 | return persist.DepersistOne(db, QUEUE_BUCKET_NAME, q.Name) 152 | } 153 | 154 | func (q *Queue) DepersistBoltTx(tx *bolt.Tx) error { 155 | bucket, err := tx.CreateBucketIfNotExists(QUEUE_BUCKET_NAME) 156 | if err != nil { 157 | return fmt.Errorf("create bucket: %s", err) 158 | } 159 | return persist.DepersistOneBoltTx(bucket, q.Name) 160 | } 161 | 162 | func LoadAllQueues(db *bolt.DB, msgStore *msgstore.MessageStore, deleteChan chan *Queue) (map[string]*Queue, error) { 163 | queueStateMap, err := persist.LoadAll(db, QUEUE_BUCKET_NAME, &QueueStateFactory{}) 164 | if err != nil { 165 | return nil, err 166 | } 167 | var ret = make(map[string]*Queue) 168 | for key, state := range queueStateMap { 169 | ret[key] = NewFromPersistedState(state.(*gen.QueueState), msgStore, deleteChan) 170 | } 171 | return ret, nil 172 | } 173 | 174 | func (q *Queue) LoadFromMsgStore(msgStore *msgstore.MessageStore) { 175 | queueList, err := msgStore.LoadQueueFromDisk(q.Name) 176 | if err != nil { 177 | panic("Integrity error reading queue from disk! " + err.Error()) 178 | } 179 | q.queue = queueList 180 | select { 181 | case q.maybeReady <- true: 182 | default: 183 | } 184 | } 185 | 186 | func (q *Queue) Close() { 187 | // This discards any messages which would be added. It does not 188 | // do cleanup 189 | q.queueLock.Lock() 190 | defer q.queueLock.Unlock() 191 | q.Closed = true 192 | } 193 | 194 | func (q *Queue) Purge() uint32 { 195 | q.queueLock.Lock() 196 | defer q.queueLock.Unlock() 197 | return q.purgeNotThreadSafe() 198 | } 199 | 200 | func (q *Queue) purgeNotThreadSafe() uint32 { 201 | var length = q.queue.Len() 202 | q.queue.Init() 203 | return uint32(length) 204 | } 205 | 206 | func (q *Queue) Add(qm *amqp.QueueMessage) bool { 207 | // NOTE: I tried using consumeImmediate before adding things to the queue, 208 | // but it caused a pretty significant slowdown. 
209 | q.queueLock.Lock() 210 | defer q.queueLock.Unlock() 211 | if !q.Closed { 212 | q.statCount += 1 213 | q.queue.PushBack(qm) 214 | select { 215 | case q.maybeReady <- true: 216 | default: 217 | } 218 | return true 219 | } else { 220 | return false 221 | } 222 | } 223 | 224 | func (q *Queue) ConsumeImmediate(qm *amqp.QueueMessage) bool { 225 | // TODO: randomize or round-robin through consumers 226 | q.consumerLock.RLock() 227 | defer q.consumerLock.RUnlock() 228 | for _, consumer := range q.consumers { 229 | var msg, acquired = q.msgStore.Get(qm, consumer.MessageResourceHolders()) 230 | if acquired { 231 | consumer.ConsumeImmediate(qm, msg) 232 | return true 233 | } 234 | } 235 | return false 236 | } 237 | 238 | func (q *Queue) Delete(ifUnused bool, ifEmpty bool) (uint32, error) { 239 | // Lock 240 | if !q.Closed { 241 | panic("Queue deleted before it was closed!") 242 | } 243 | q.queueLock.Lock() 244 | defer q.queueLock.Unlock() 245 | 246 | // Check 247 | var usedOk = !ifUnused || len(q.consumers) == 0 248 | var emptyOk = !ifEmpty || q.queue.Len() == 0 249 | if !usedOk { 250 | return 0, errors.New("if-unused specified and there are consumers") 251 | } 252 | if !emptyOk { 253 | return 0, errors.New("if-empty specified and there are messages in the queue") 254 | } 255 | // Purge 256 | q.cancelConsumers() 257 | return q.purgeNotThreadSafe(), nil 258 | } 259 | 260 | func (q *Queue) Readd(queueName string, msg *amqp.QueueMessage) { 261 | // TODO: if there is a consumer available, dispatch 262 | q.queueLock.Lock() 263 | defer q.queueLock.Unlock() 264 | // this method is only called when we get a nack or we shut down a channel, 265 | // so it means the message was not acked. 
266 | q.msgStore.IncrDeliveryCount(queueName, msg) 267 | q.queue.PushFront(msg) 268 | select { 269 | case q.maybeReady <- true: 270 | default: 271 | } 272 | } 273 | 274 | func (q *Queue) removeConsumer(consumerTag string) { 275 | q.consumerLock.Lock() 276 | defer q.consumerLock.Unlock() 277 | if q.soleConsumer != nil && q.soleConsumer.ConsumerTag == consumerTag { 278 | q.soleConsumer = nil 279 | } 280 | // remove from list 281 | for i, c := range q.consumers { 282 | if c.ConsumerTag == consumerTag { 283 | q.consumers = append(q.consumers[:i], q.consumers[i+1:]...) 284 | } 285 | } 286 | var size = len(q.consumers) 287 | if size == 0 { 288 | q.currentConsumer = 0 289 | if q.autoDelete && q.hasHadConsumers { 290 | go q.autodeleteTimeout() 291 | } 292 | } else { 293 | q.currentConsumer = q.currentConsumer % size 294 | } 295 | 296 | } 297 | 298 | func (q *Queue) autodeleteTimeout() { 299 | // There's technically a race condition here where a new binding could be 300 | // added right as we check this, but after a 5 second wait with no activity 301 | // I think this is probably safe enough. 
302 | var now = time.Now() 303 | q.deleteActive = now 304 | time.Sleep(5 * time.Second) 305 | if q.deleteActive == now { 306 | q.deleteChan <- q 307 | } 308 | } 309 | 310 | func (q *Queue) cancelConsumers() { 311 | q.consumerLock.Lock() 312 | defer q.consumerLock.Unlock() 313 | q.soleConsumer = nil 314 | // Send cancel to each consumer 315 | for _, c := range q.consumers { 316 | c.SendCancel() 317 | c.Stop() 318 | } 319 | q.consumers = make([]*consumer.Consumer, 0, 1) 320 | } 321 | 322 | func (q *Queue) AddConsumer(c *consumer.Consumer, exclusive bool) (uint16, error) { 323 | if q.Closed { 324 | return 0, nil 325 | } 326 | // Reset auto-delete 327 | q.deleteActive = time.Unix(0, 0) 328 | 329 | // Add consumer 330 | q.consumerLock.Lock() 331 | if exclusive { 332 | if len(q.consumers) == 0 { 333 | q.soleConsumer = c 334 | } else { 335 | return 403, fmt.Errorf("Exclusive access denied, %d consumers active", len(q.consumers)) 336 | } 337 | } 338 | q.consumers = append(q.consumers, c) 339 | q.hasHadConsumers = true 340 | q.consumerLock.Unlock() 341 | return 0, nil 342 | } 343 | 344 | func (q *Queue) Start() { 345 | go func() { 346 | select { 347 | case q.maybeReady <- true: 348 | default: 349 | } 350 | for _ = range q.maybeReady { 351 | if q.Closed { 352 | fmt.Printf("Queue closed!\n") 353 | break 354 | } 355 | q.processOne() 356 | } 357 | }() 358 | } 359 | 360 | func (q *Queue) MaybeReady() chan bool { 361 | return q.maybeReady 362 | } 363 | 364 | func (q *Queue) processOne() { 365 | defer stats.RecordHisto(q.statProcOne, stats.Start()) 366 | q.consumerLock.RLock() 367 | defer q.consumerLock.RUnlock() 368 | var size = len(q.consumers) 369 | if size == 0 { 370 | return 371 | } 372 | for count := 0; count < size; count++ { 373 | q.currentConsumer = (q.currentConsumer + 1) % size 374 | var c = q.consumers[q.currentConsumer] 375 | c.Ping() 376 | } 377 | } 378 | 379 | func (q *Queue) GetOneForced() *amqp.QueueMessage { 380 | q.queueLock.Lock() 381 | defer 
q.queueLock.Unlock() 382 | if q.queue.Len() == 0 { 383 | return nil 384 | } 385 | qMsg := q.queue.Remove(q.queue.Front()).(*amqp.QueueMessage) 386 | return qMsg 387 | } 388 | 389 | func (q *Queue) GetOne(rhs ...amqp.MessageResourceHolder) (*amqp.QueueMessage, *amqp.Message) { 390 | q.queueLock.Lock() 391 | defer q.queueLock.Unlock() 392 | // Empty check 393 | if q.queue.Len() == 0 || q.Closed { 394 | return nil, nil 395 | } 396 | 397 | // Get one message. If there is a message try to acquire the resources 398 | // from the channel. 399 | var qm = q.queue.Front().Value.(*amqp.QueueMessage) 400 | 401 | var msg, acquired = q.msgStore.Get(qm, rhs) 402 | if acquired { 403 | q.queue.Remove(q.queue.Front()) 404 | return qm, msg 405 | } 406 | return nil, nil 407 | } 408 | -------------------------------------------------------------------------------- /queue/queue_test.go: -------------------------------------------------------------------------------- 1 | package queue 2 | 3 | import ( 4 | "testing" 5 | ) 6 | 7 | func TestQueueAdd(t *testing.T) { 8 | 9 | } 10 | -------------------------------------------------------------------------------- /scripts/benchmark_helper.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -x 4 | 5 | cd scripts/external/perf-client/ 6 | ./runjava.sh com.rabbitmq.examples.PerfTest \ 7 | --exchange perf-test \ 8 | -uri amqp://guest:guest@localhost:${RUN_PORT} \ 9 | --queue perf-test-transient \ 10 | --consumers 4 \ 11 | --producers 2 \ 12 | --qos 20 \ 13 | --time 20 "$@" 14 | -------------------------------------------------------------------------------- /scripts/cover.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | import sys 3 | import subprocess 4 | import os 5 | 6 | PREFIX = 'github.com/jeffjenkins/dispatchd/' 7 | 8 | class Colors(object): 9 | PURPLE = '\033[95m' 10 | BLUE = '\033[94m' 11 | GREEN = 
def nocover(file, line):
    """Return True when a source line carries the ``pragma: nocover`` marker.

    ``file`` is a path relative to ``$GOPATH/src`` and ``line`` is a 1-based
    line number (an int or a numeric string). A line number outside the file
    yields False rather than wrapping around to the end of the file (line 0)
    or raising IndexError (past the last line).
    """
    with open(os.path.join(os.environ['GOPATH'], 'src', file)) as inf:
        lines = inf.readlines()
    index = int(line) - 1
    if index < 0 or index >= len(lines):
        return False
    return 'pragma: nocover' in lines[index]
set() 74 | for name in cover_names: 75 | with open(name) as inf: 76 | for line in inf.readlines(): 77 | line = line.strip() 78 | if line[-1] != '0': 79 | continue 80 | full_file, _, report = line.partition(':') 81 | file = full_file.replace(PREFIX, '') 82 | if 'pb.go' in file or 'generated' in file or 'testlib' in file: 83 | continue 84 | range, _, _ = report.partition(' ') 85 | first, _, second = range.partition(',') 86 | first_line, _, _ = first.partition('.') 87 | second_line, _, _ = second.partition('.') 88 | range = (int(first_line), int(second_line)) 89 | if full_file not in seen: 90 | total_lines += count_lines(full_file) 91 | seen.add(full_file) 92 | total_missing += range[1] - range[0] + 1 93 | if nocover(full_file, first_line): 94 | continue 95 | missing[file].append(range) 96 | 97 | 98 | for file, ranges in sorted(missing.items()): 99 | if ranges is not None: 100 | ranges = sorted(ranges) 101 | ranges = ['{}-{}'.format(x,y) for x, y in ranges] 102 | if len(ranges) == 0: 103 | print Colors.GREEN, file+':', Colors.ENDC, 'Full coverage!' 
104 | print Colors.RED, file+':', Colors.ENDC, ', '.join(ranges) 105 | else: 106 | print Colors.RED, file+':', Colors.ENDC, 'No coverage' 107 | print Colors.RED, 'Remaining lines on files without full coverage:', total_missing, '/', total_lines, Colors.ENDC 108 | 109 | 110 | 111 | 112 | if __name__ == '__main__': 113 | main(sys.argv[1:]) -------------------------------------------------------------------------------- /server/auth.go: -------------------------------------------------------------------------------- 1 | package server 2 | 3 | import ( 4 | "bytes" 5 | "encoding/base64" 6 | "golang.org/x/crypto/bcrypt" 7 | ) 8 | 9 | type User struct { 10 | name string 11 | password []byte 12 | } 13 | 14 | func (s *Server) addUsers(userJson map[string]interface{}) { 15 | if len(userJson) == 0 { 16 | decoded, err := base64.StdEncoding.DecodeString(defaultUserPasswordBase64) 17 | if err != nil { 18 | panic("System integrity error: Could not base64 decode password for built in default user!") 19 | } 20 | s.users[defaultUserName] = User{name: defaultUserName, password: decoded} 21 | } 22 | for name, user := range userJson { 23 | var encoded = user.(map[string]interface{})["password_bcrypt_base64"].(string) 24 | decoded, err := base64.StdEncoding.DecodeString(encoded) 25 | if err != nil { 26 | panic("Could not base64 decode password for default config file user: " + name) 27 | } 28 | s.users[name] = User{name: name, password: decoded} 29 | } 30 | } 31 | 32 | // guest/guest 33 | var defaultUserName = "guest" 34 | var defaultUserPasswordBase64 = "JDJhJDExJENobGk4dG5rY0RGemJhTjhsV21xR3VNNnFZZ1ZqTzUzQWxtbGtyMHRYN3RkUHMuYjF5SUt5" 35 | 36 | func (s *Server) authenticate(mechanism string, blob []byte) bool { 37 | // Split. 
SASL PLAIN has three parts 38 | parts := bytes.Split(blob, []byte{0}) 39 | if len(parts) != 3 { 40 | return false 41 | } 42 | 43 | for name, user := range s.users { 44 | if string(parts[1]) != name { 45 | continue 46 | } 47 | err := bcrypt.CompareHashAndPassword(user.password, parts[2]) 48 | if err == nil { 49 | return true 50 | } 51 | } 52 | return false 53 | } 54 | -------------------------------------------------------------------------------- /server/basicMethods.go: -------------------------------------------------------------------------------- 1 | package server 2 | 3 | import ( 4 | "github.com/jeffjenkins/dispatchd/amqp" 5 | "github.com/jeffjenkins/dispatchd/stats" 6 | "github.com/jeffjenkins/dispatchd/util" 7 | ) 8 | 9 | func (channel *Channel) basicRoute(methodFrame amqp.MethodFrame) *amqp.AMQPError { 10 | switch method := methodFrame.(type) { 11 | case *amqp.BasicQos: 12 | return channel.basicQos(method) 13 | case *amqp.BasicRecover: 14 | return channel.basicRecover(method) 15 | case *amqp.BasicNack: 16 | return channel.basicNack(method) 17 | case *amqp.BasicConsume: 18 | return channel.basicConsume(method) 19 | case *amqp.BasicCancel: 20 | return channel.basicCancel(method) 21 | case *amqp.BasicCancelOk: 22 | return channel.basicCancelOk(method) 23 | case *amqp.BasicPublish: 24 | return channel.basicPublish(method) 25 | case *amqp.BasicGet: 26 | return channel.basicGet(method) 27 | case *amqp.BasicAck: 28 | return channel.basicAck(method) 29 | case *amqp.BasicReject: 30 | return channel.basicReject(method) 31 | } 32 | var classId, methodId = methodFrame.MethodIdentifier() 33 | return amqp.NewHardError(540, "Unable to route method frame", classId, methodId) 34 | } 35 | 36 | func (channel *Channel) basicQos(method *amqp.BasicQos) *amqp.AMQPError { 37 | channel.setPrefetch(method.PrefetchCount, method.PrefetchSize, method.Global) 38 | channel.SendMethod(&amqp.BasicQosOk{}) 39 | return nil 40 | } 41 | 42 | func (channel *Channel) basicRecover(method 
*amqp.BasicRecover) *amqp.AMQPError { 43 | channel.recover(method.Requeue) 44 | channel.SendMethod(&amqp.BasicRecoverOk{}) 45 | return nil 46 | } 47 | 48 | func (channel *Channel) basicNack(method *amqp.BasicNack) *amqp.AMQPError { 49 | if method.Multiple { 50 | return channel.nackBelow(method.DeliveryTag, method.Requeue, false) 51 | } 52 | return channel.nackOne(method.DeliveryTag, method.Requeue, false) 53 | } 54 | 55 | func (channel *Channel) basicConsume(method *amqp.BasicConsume) *amqp.AMQPError { 56 | var classId, methodId = method.MethodIdentifier() 57 | // Check queue 58 | if len(method.Queue) == 0 { 59 | if len(channel.lastQueueName) == 0 { 60 | return amqp.NewSoftError(404, "Queue not found", classId, methodId) 61 | } else { 62 | method.Queue = channel.lastQueueName 63 | } 64 | } 65 | // TODO: do not directly access channel.conn.server.queues 66 | var queue, found = channel.conn.server.queues[method.Queue] 67 | if !found { 68 | // Spec doesn't say, but seems like a 404? 69 | return amqp.NewSoftError(404, "Queue not found", classId, methodId) 70 | } 71 | if len(method.ConsumerTag) == 0 { 72 | method.ConsumerTag = util.RandomId() 73 | } 74 | amqpErr := channel.addConsumer(queue, method) 75 | if amqpErr != nil { 76 | return amqpErr 77 | } 78 | if !method.NoWait { 79 | channel.SendMethod(&amqp.BasicConsumeOk{method.ConsumerTag}) 80 | } 81 | 82 | return nil 83 | } 84 | 85 | func (channel *Channel) basicCancel(method *amqp.BasicCancel) *amqp.AMQPError { 86 | 87 | if err := channel.removeConsumer(method.ConsumerTag); err != nil { 88 | var classId, methodId = method.MethodIdentifier() 89 | return amqp.NewSoftError(404, "Consumer not found", classId, methodId) 90 | } 91 | 92 | if !method.NoWait { 93 | channel.SendMethod(&amqp.BasicCancelOk{method.ConsumerTag}) 94 | } 95 | return nil 96 | } 97 | 98 | func (channel *Channel) basicCancelOk(method *amqp.BasicCancelOk) *amqp.AMQPError { 99 | // TODO(MAY) 100 | var classId, methodId = method.MethodIdentifier() 101 | 
return amqp.NewHardError(540, "Not implemented", classId, methodId) 102 | } 103 | 104 | func (channel *Channel) basicPublish(method *amqp.BasicPublish) *amqp.AMQPError { 105 | defer stats.RecordHisto(channel.statPublish, stats.Start()) 106 | var _, found = channel.server.exchanges[method.Exchange] 107 | if !found { 108 | var classId, methodId = method.MethodIdentifier() 109 | return amqp.NewSoftError(404, "Exchange not found", classId, methodId) 110 | } 111 | channel.startPublish(method) 112 | return nil 113 | } 114 | 115 | func (channel *Channel) basicGet(method *amqp.BasicGet) *amqp.AMQPError { 116 | // var classId, methodId = method.MethodIdentifier() 117 | // channel.conn.connectionErrorWithMethod(540, "Not implemented", classId, methodId) 118 | var queue, found = channel.conn.server.queues[method.Queue] 119 | if !found { 120 | // Spec doesn't say, but seems like a 404? 121 | var classId, methodId = method.MethodIdentifier() 122 | return amqp.NewSoftError(404, "Queue not found", classId, methodId) 123 | } 124 | var qm = queue.GetOneForced() 125 | if qm == nil { 126 | channel.SendMethod(&amqp.BasicGetEmpty{}) 127 | return nil 128 | } 129 | 130 | var rhs = []amqp.MessageResourceHolder{channel} 131 | msg, err := channel.server.msgStore.GetAndDecrRef(qm, queue.Name, rhs) 132 | if err != nil { 133 | // TODO: return 500 error 134 | channel.SendMethod(&amqp.BasicGetEmpty{}) 135 | return nil 136 | } 137 | 138 | channel.SendContent(&amqp.BasicGetOk{ 139 | DeliveryTag: channel.nextDeliveryTag(), 140 | Redelivered: qm.DeliveryCount > 0, 141 | Exchange: msg.Exchange, 142 | RoutingKey: msg.Key, 143 | MessageCount: 1, 144 | }, msg) 145 | return nil 146 | } 147 | 148 | func (channel *Channel) basicAck(method *amqp.BasicAck) *amqp.AMQPError { 149 | if method.Multiple { 150 | return channel.ackBelow(method.DeliveryTag, false) 151 | } 152 | return channel.ackOne(method.DeliveryTag, false) 153 | } 154 | 155 | func (channel *Channel) basicReject(method *amqp.BasicReject) 
*amqp.AMQPError { 156 | return channel.nackOne(method.DeliveryTag, method.Requeue, false) 157 | } 158 | -------------------------------------------------------------------------------- /server/channelMethods.go: -------------------------------------------------------------------------------- 1 | package server 2 | 3 | import ( 4 | "github.com/jeffjenkins/dispatchd/amqp" 5 | ) 6 | 7 | func (channel *Channel) channelRoute(methodFrame amqp.MethodFrame) *amqp.AMQPError { 8 | switch method := methodFrame.(type) { 9 | case *amqp.ChannelOpen: 10 | return channel.channelOpen(method) 11 | case *amqp.ChannelFlow: 12 | return channel.channelFlow(method) 13 | case *amqp.ChannelFlowOk: 14 | return channel.channelFlowOk(method) 15 | case *amqp.ChannelClose: 16 | return channel.channelClose(method) 17 | case *amqp.ChannelCloseOk: 18 | return channel.channelCloseOk(method) 19 | // case *amqp.ChannelOpenOk: 20 | // return channel.channelOpenOk(method) 21 | } 22 | var classId, methodId = methodFrame.MethodIdentifier() 23 | return amqp.NewHardError(540, "Unable to route method frame", classId, methodId) 24 | } 25 | 26 | func (channel *Channel) channelOpen(method *amqp.ChannelOpen) *amqp.AMQPError { 27 | if channel.state == CH_STATE_OPEN { 28 | var classId, methodId = method.MethodIdentifier() 29 | return amqp.NewHardError(504, "Channel already open", classId, methodId) 30 | } 31 | channel.SendMethod(&amqp.ChannelOpenOk{}) 32 | channel.setStateOpen() 33 | return nil 34 | } 35 | 36 | func (channel *Channel) channelFlow(method *amqp.ChannelFlow) *amqp.AMQPError { 37 | channel.changeFlow(method.Active) 38 | channel.SendMethod(&amqp.ChannelFlowOk{channel.flow}) 39 | return nil 40 | } 41 | 42 | func (channel *Channel) channelFlowOk(method *amqp.ChannelFlowOk) *amqp.AMQPError { 43 | var classId, methodId = method.MethodIdentifier() 44 | return amqp.NewHardError(540, "Not implemented", classId, methodId) 45 | } 46 | 47 | func (channel *Channel) channelClose(method *amqp.ChannelClose) 
*amqp.AMQPError { 48 | // TODO(MAY): Report the class and method that are the reason for the close 49 | channel.SendMethod(&amqp.ChannelCloseOk{}) 50 | channel.shutdown() 51 | return nil 52 | } 53 | 54 | func (channel *Channel) channelCloseOk(method *amqp.ChannelCloseOk) *amqp.AMQPError { 55 | channel.shutdown() 56 | return nil 57 | } 58 | -------------------------------------------------------------------------------- /server/connection.go: -------------------------------------------------------------------------------- 1 | package server 2 | 3 | import ( 4 | "bytes" 5 | "encoding/json" 6 | "fmt" 7 | "github.com/jeffjenkins/dispatchd/amqp" 8 | "github.com/jeffjenkins/dispatchd/stats" 9 | "github.com/jeffjenkins/dispatchd/util" 10 | "net" 11 | "sync" 12 | "time" 13 | ) 14 | 15 | // TODO: we can only be "in" one of these at once, so this should probably 16 | // be one field 17 | type ConnectStatus struct { 18 | start bool 19 | startOk bool 20 | secure bool 21 | secureOk bool 22 | tune bool 23 | tuneOk bool 24 | open bool 25 | openOk bool 26 | closing bool 27 | closed bool 28 | } 29 | 30 | type AMQPConnection struct { 31 | id int64 32 | nextChannel int 33 | channels map[uint16]*Channel 34 | outgoing chan *amqp.WireFrame 35 | connectStatus ConnectStatus 36 | server *Server 37 | network net.Conn 38 | lock sync.Mutex 39 | ttl time.Time 40 | sendHeartbeatInterval time.Duration 41 | receiveHeartbeatInterval time.Duration 42 | maxChannels uint16 43 | maxFrameSize uint32 44 | clientProperties *amqp.Table 45 | // stats 46 | statOutBlocked stats.Histogram 47 | statOutNetwork stats.Histogram 48 | statInBlocked stats.Histogram 49 | statInNetwork stats.Histogram 50 | } 51 | 52 | func (conn *AMQPConnection) MarshalJSON() ([]byte, error) { 53 | return json.Marshal(map[string]interface{}{ 54 | "id": conn.id, 55 | "address": fmt.Sprintf("%s", conn.network.RemoteAddr()), 56 | "clientProperties": conn.clientProperties.Table, 57 | "channelCount": len(conn.channels), 58 | }) 59 | } 60 | 
61 | func NewAMQPConnection(server *Server, network net.Conn) *AMQPConnection { 62 | return &AMQPConnection{ 63 | // If outgoing has a buffer the server performs better. I'm not adding one 64 | // in until I fully understand why that is 65 | id: util.NextId(), 66 | network: network, 67 | channels: make(map[uint16]*Channel), 68 | outgoing: make(chan *amqp.WireFrame, 100), 69 | connectStatus: ConnectStatus{}, 70 | server: server, 71 | receiveHeartbeatInterval: 10 * time.Second, 72 | maxChannels: 4096, 73 | maxFrameSize: 65536, 74 | // stats 75 | statOutBlocked: stats.MakeHistogram("Connection.Out.Blocked"), 76 | statOutNetwork: stats.MakeHistogram("Connection.Out.Network"), 77 | statInBlocked: stats.MakeHistogram("Connection.In.Blocked"), 78 | statInNetwork: stats.MakeHistogram("Connection.In.Network"), 79 | } 80 | } 81 | 82 | func (conn *AMQPConnection) openConnection() { 83 | // Negotiate Protocol 84 | buf := make([]byte, 8) 85 | _, err := conn.network.Read(buf) 86 | if err != nil { 87 | conn.hardClose() 88 | return 89 | } 90 | 91 | var supported = []byte{'A', 'M', 'Q', 'P', 0, 0, 9, 1} 92 | if bytes.Compare(buf, supported) != 0 { 93 | conn.network.Write(supported) 94 | conn.hardClose() 95 | return 96 | } 97 | 98 | // Create channel 0 and start the connection handshake 99 | conn.channels[0] = NewChannel(0, conn) 100 | conn.channels[0].start() 101 | conn.handleOutgoing() 102 | conn.handleIncoming() 103 | } 104 | 105 | func (conn *AMQPConnection) cleanUp() { 106 | 107 | } 108 | 109 | func (conn *AMQPConnection) deregisterChannel(id uint16) { 110 | delete(conn.channels, id) 111 | } 112 | 113 | func (conn *AMQPConnection) hardClose() { 114 | conn.network.Close() 115 | conn.connectStatus.closed = true 116 | conn.server.deregisterConnection(conn.id) 117 | conn.server.deleteQueuesForConn(conn.id) 118 | for _, channel := range conn.channels { 119 | channel.shutdown() 120 | } 121 | } 122 | 123 | func (conn *AMQPConnection) setMaxChannels(max uint16) { 124 | conn.maxChannels 
= max 125 | } 126 | 127 | func (conn *AMQPConnection) setMaxFrameSize(max uint32) { 128 | conn.maxFrameSize = max 129 | } 130 | 131 | func (conn *AMQPConnection) startSendHeartbeat(interval time.Duration) { 132 | conn.sendHeartbeatInterval = interval 133 | conn.handleSendHeartbeat() 134 | } 135 | 136 | func (conn *AMQPConnection) handleSendHeartbeat() { 137 | go func() { 138 | for { 139 | if conn.connectStatus.closed { 140 | break 141 | } 142 | time.Sleep(conn.sendHeartbeatInterval / 2) 143 | conn.outgoing <- &amqp.WireFrame{8, 0, make([]byte, 0)} 144 | } 145 | }() 146 | } 147 | 148 | func (conn *AMQPConnection) handleClientHeartbeatTimeout() { 149 | // TODO(MUST): The spec is that any octet is a heartbeat substitute. Right 150 | // now this is only looking at frames, so a long send could cause a timeout 151 | // TODO(MUST): if the client isn't heartbeating how do we know when it's 152 | // gone? 153 | go func() { 154 | for { 155 | if conn.connectStatus.closed { 156 | break 157 | } 158 | time.Sleep(conn.receiveHeartbeatInterval / 2) // 159 | // If now is higher than TTL we need to time the client out 160 | if conn.ttl.Before(time.Now()) { 161 | conn.hardClose() 162 | } 163 | } 164 | }() 165 | } 166 | 167 | func (conn *AMQPConnection) handleOutgoing() { 168 | // TODO(MUST): Use SetWriteDeadline so we never wait too long. It should be 169 | // higher than the heartbeat in use. It should be reset after the heartbeat 170 | // interval is known. 171 | go func() { 172 | for { 173 | if conn.connectStatus.closed { 174 | break 175 | } 176 | var start = stats.Start() 177 | var frame = <-conn.outgoing 178 | stats.RecordHisto(conn.statOutBlocked, start) 179 | 180 | // fmt.Printf("Sending outgoing message. type: %d\n", frame.FrameType) 181 | // TODO(MUST): Hard close on irrecoverable errors, retry on recoverable 182 | // ones some number of times. 
183 | start = stats.Start() 184 | amqp.WriteFrame(conn.network, frame) 185 | stats.RecordHisto(conn.statOutNetwork, start) 186 | // for wire protocol debugging: 187 | // for _, b := range frame.Payload { 188 | // fmt.Printf("%d,", b) 189 | // } 190 | // fmt.Printf("\n") 191 | } 192 | }() 193 | } 194 | 195 | func (conn *AMQPConnection) connectionErrorWithMethod(amqpErr *amqp.AMQPError) { 196 | fmt.Println("Sending connection error:", amqpErr.Msg) 197 | conn.connectStatus.closing = true 198 | conn.channels[0].SendMethod(&amqp.ConnectionClose{ 199 | ReplyCode: amqpErr.Code, 200 | ReplyText: amqpErr.Msg, 201 | ClassId: amqpErr.Class, 202 | MethodId: amqpErr.Method, 203 | }) 204 | } 205 | 206 | func (conn *AMQPConnection) handleIncoming() { 207 | for { 208 | // If the connection is done, we stop handling frames 209 | if conn.connectStatus.closed { 210 | break 211 | } 212 | // Read from the network 213 | // TODO(MUST): Add a timeout to the read, esp. if there is no heartbeat 214 | // TODO(MUST): Hard close on unrecoverable errors, retry (with backoff?) 215 | // for recoverable ones 216 | var start = stats.Start() 217 | frame, err := amqp.ReadFrame(conn.network) 218 | if err != nil { 219 | fmt.Println("Error reading frame: " + err.Error()) 220 | conn.hardClose() 221 | break 222 | } 223 | stats.RecordHisto(conn.statInNetwork, start) 224 | conn.handleFrame(frame) 225 | } 226 | } 227 | 228 | func (conn *AMQPConnection) handleFrame(frame *amqp.WireFrame) { 229 | 230 | // Upkeep. 
Remove things which have expired, etc 231 | conn.cleanUp() 232 | conn.ttl = time.Now().Add(conn.receiveHeartbeatInterval * 2) 233 | 234 | switch { 235 | case frame.FrameType == 8: 236 | // TODO(MUST): Update last heartbeat time 237 | return 238 | } 239 | 240 | if !conn.connectStatus.open && frame.Channel != 0 { 241 | fmt.Println("Non-0 channel for unopened connection") 242 | conn.hardClose() 243 | return 244 | } 245 | var channel, ok = conn.channels[frame.Channel] 246 | // TODO(MUST): Check that the channel number if in the valid range 247 | if !ok { 248 | channel = NewChannel(frame.Channel, conn) 249 | conn.channels[frame.Channel] = channel 250 | conn.channels[frame.Channel].start() 251 | } 252 | // Dispatch 253 | start := stats.Start() 254 | channel.incoming <- frame 255 | stats.RecordHisto(conn.statInBlocked, start) 256 | } 257 | -------------------------------------------------------------------------------- /server/connectionMethods.go: -------------------------------------------------------------------------------- 1 | package server 2 | 3 | import ( 4 | "github.com/jeffjenkins/dispatchd/amqp" 5 | "os" 6 | "runtime" 7 | "time" 8 | ) 9 | 10 | func (channel *Channel) connectionRoute(conn *AMQPConnection, methodFrame amqp.MethodFrame) *amqp.AMQPError { 11 | switch method := methodFrame.(type) { 12 | case *amqp.ConnectionStartOk: 13 | return channel.connectionStartOk(conn, method) 14 | case *amqp.ConnectionTuneOk: 15 | return channel.connectionTuneOk(conn, method) 16 | case *amqp.ConnectionOpen: 17 | return channel.connectionOpen(conn, method) 18 | case *amqp.ConnectionClose: 19 | return channel.connectionClose(conn, method) 20 | case *amqp.ConnectionSecureOk: 21 | return channel.connectionSecureOk(conn, method) 22 | case *amqp.ConnectionCloseOk: 23 | return channel.connectionCloseOk(conn, method) 24 | case *amqp.ConnectionBlocked: 25 | return channel.connectionBlocked(conn, method) 26 | case *amqp.ConnectionUnblocked: 27 | return 
channel.connectionUnblocked(conn, method) 28 | } 29 | var classId, methodId = methodFrame.MethodIdentifier() 30 | return amqp.NewHardError(540, "Unable to route method frame", classId, methodId) 31 | } 32 | 33 | func (channel *Channel) connectionOpen(conn *AMQPConnection, method *amqp.ConnectionOpen) *amqp.AMQPError { 34 | // TODO(MAY): Add support for virtual hosts. Check for access to the 35 | // selected one 36 | conn.connectStatus.open = true 37 | channel.SendMethod(&amqp.ConnectionOpenOk{""}) 38 | conn.connectStatus.openOk = true 39 | return nil 40 | } 41 | 42 | func (channel *Channel) connectionTuneOk(conn *AMQPConnection, method *amqp.ConnectionTuneOk) *amqp.AMQPError { 43 | conn.connectStatus.tuneOk = true 44 | if method.ChannelMax > conn.maxChannels || method.FrameMax > conn.maxFrameSize { 45 | conn.hardClose() 46 | return nil 47 | } 48 | 49 | conn.setMaxChannels(method.ChannelMax) 50 | conn.setMaxFrameSize(method.FrameMax) 51 | 52 | if method.Heartbeat > 0 { 53 | // Start sending heartbeats to the client 54 | conn.startSendHeartbeat(time.Duration(method.Heartbeat) * time.Second) 55 | } 56 | // Start listening for heartbeats from the client. 
57 | // We always ask for them since we want to shut down 58 | // connections not in use 59 | conn.handleClientHeartbeatTimeout() 60 | return nil 61 | } 62 | 63 | func (channel *Channel) connectionStartOk(conn *AMQPConnection, method *amqp.ConnectionStartOk) *amqp.AMQPError { 64 | // TODO(SHOULD): record product/version/platform/copyright/information 65 | // TODO(MUST): assert mechanism, response, locale are not null 66 | conn.connectStatus.startOk = true 67 | 68 | if method.Mechanism != "PLAIN" { 69 | conn.hardClose() 70 | } 71 | 72 | if !conn.server.authenticate(method.Mechanism, method.Response) { 73 | var classId, methodId = method.MethodIdentifier() 74 | return &amqp.AMQPError{ 75 | Code: 530, 76 | Class: classId, 77 | Method: methodId, 78 | Msg: "Authorization failed", 79 | Soft: false, 80 | } 81 | } 82 | 83 | conn.clientProperties = method.ClientProperties 84 | // TODO(MUST): add support these being enforced at the connection level. 85 | channel.SendMethod(&amqp.ConnectionTune{ 86 | conn.maxChannels, 87 | conn.maxFrameSize, 88 | uint16(conn.receiveHeartbeatInterval.Nanoseconds() / int64(time.Second)), 89 | }) 90 | // TODO: Implement secure/secure-ok later if needed 91 | conn.connectStatus.secure = true 92 | conn.connectStatus.secureOk = true 93 | conn.connectStatus.tune = true 94 | return nil 95 | } 96 | 97 | func (channel *Channel) startConnection() *amqp.AMQPError { 98 | // TODO(SHOULD): add fields: host, product, version, platform, copyright, information 99 | var capabilities = amqp.NewTable() 100 | capabilities.SetKey("publisher_confirms", false) 101 | capabilities.SetKey("basic.nack", true) 102 | var serverProps = amqp.NewTable() 103 | // TODO: the java rabbitmq client I'm using for load testing doesn't like these string 104 | // fields even though the go/python clients do. 
If they are set as longstr (bytes) 105 | // instead they work, so I'm doing that for now 106 | serverProps.SetKey("product", []byte("dispatchd")) 107 | serverProps.SetKey("version", []byte("0.1")) 108 | serverProps.SetKey("copyright", []byte("Jeffrey Jenkins, 2015")) 109 | serverProps.SetKey("capabilities", capabilities) 110 | serverProps.SetKey("platform", []byte(runtime.GOARCH)) 111 | host, err := os.Hostname() 112 | if err != nil { 113 | serverProps.SetKey("host", []byte("UnknownHostError")) 114 | } else { 115 | serverProps.SetKey("host", []byte(host)) 116 | } 117 | 118 | serverProps.SetKey("information", []byte("http://dispatchd.org")) 119 | 120 | channel.SendMethod(&amqp.ConnectionStart{0, 9, serverProps, []byte("PLAIN"), []byte("en_US")}) 121 | return nil 122 | } 123 | 124 | func (channel *Channel) connectionClose(conn *AMQPConnection, method *amqp.ConnectionClose) *amqp.AMQPError { 125 | channel.SendMethod(&amqp.ConnectionCloseOk{}) 126 | conn.hardClose() 127 | return nil 128 | } 129 | 130 | func (channel *Channel) connectionCloseOk(conn *AMQPConnection, method *amqp.ConnectionCloseOk) *amqp.AMQPError { 131 | conn.hardClose() 132 | return nil 133 | } 134 | 135 | func (channel *Channel) connectionSecureOk(conn *AMQPConnection, method *amqp.ConnectionSecureOk) *amqp.AMQPError { 136 | // TODO(MAY): If other security mechanisms are in place, handle this 137 | conn.hardClose() 138 | return nil 139 | } 140 | 141 | func (channel *Channel) connectionBlocked(conn *AMQPConnection, method *amqp.ConnectionBlocked) *amqp.AMQPError { 142 | return amqp.NewHardError(540, "Not implemented", 10, 60) 143 | } 144 | 145 | func (channel *Channel) connectionUnblocked(conn *AMQPConnection, method *amqp.ConnectionUnblocked) *amqp.AMQPError { 146 | return amqp.NewHardError(540, "Not implemented", 10, 61) 147 | } 148 | -------------------------------------------------------------------------------- /server/exchangeMethods.go: 
-------------------------------------------------------------------------------- 1 | package server 2 | 3 | import ( 4 | _ "fmt" 5 | "github.com/jeffjenkins/dispatchd/amqp" 6 | "github.com/jeffjenkins/dispatchd/exchange" 7 | "strings" 8 | ) 9 | 10 | func (channel *Channel) exchangeRoute(methodFrame amqp.MethodFrame) *amqp.AMQPError { 11 | switch method := methodFrame.(type) { 12 | case *amqp.ExchangeDeclare: 13 | return channel.exchangeDeclare(method) 14 | case *amqp.ExchangeBind: 15 | return channel.exchangeBind(method) 16 | case *amqp.ExchangeUnbind: 17 | return channel.exchangeUnbind(method) 18 | case *amqp.ExchangeDelete: 19 | return channel.exchangeDelete(method) 20 | } 21 | var classId, methodId = methodFrame.MethodIdentifier() 22 | return amqp.NewHardError(540, "Not implemented", classId, methodId) 23 | } 24 | 25 | func (channel *Channel) exchangeDeclare(method *amqp.ExchangeDeclare) *amqp.AMQPError { 26 | var classId, methodId = method.MethodIdentifier() 27 | // The client I'm using for testing thought declaring the empty exchange 28 | // was OK. Check later 29 | // if len(method.Exchange) > 0 && !method.Passive { 30 | // var msg = "The empty exchange name is reserved" 31 | // channel.channelErrorWithMethod(406, msg, classId, methodId) 32 | // return nil 33 | // } 34 | 35 | // Check the name format 36 | var err = amqp.CheckExchangeOrQueueName(method.Exchange) 37 | if err != nil { 38 | return amqp.NewSoftError(406, err.Error(), classId, methodId) 39 | } 40 | 41 | // Declare! 
42 | var ex, amqpErr = exchange.NewFromMethod(method, false, channel.server.exchangeDeleter) 43 | if amqpErr != nil { 44 | return amqpErr 45 | } 46 | tp, err := exchange.ExchangeNameToType(method.Type) 47 | if err != nil || tp == exchange.EX_TYPE_HEADERS { 48 | return amqp.NewHardError(503, err.Error(), classId, methodId) 49 | } 50 | existing, hasKey := channel.server.exchanges[ex.Name] 51 | if !hasKey && method.Passive { 52 | return amqp.NewSoftError(404, "Exchange does not exist", classId, methodId) 53 | } 54 | if hasKey { 55 | // if diskLoad { 56 | // panic(fmt.Sprintf("Can't disk load a key that exists: %s", ex.Name)) 57 | // } 58 | if existing.ExType != ex.ExType { 59 | return amqp.NewHardError(530, "Cannot redeclare an exchange with a different type", classId, methodId) 60 | } 61 | if existing.EquivalentExchanges(ex) { 62 | if !method.NoWait { 63 | channel.SendMethod(&amqp.ExchangeDeclareOk{}) 64 | } 65 | return nil 66 | } 67 | // Not equivalent, error in passive mode 68 | if method.Passive { 69 | return amqp.NewSoftError(406, "Exchange with this name already exists", classId, methodId) 70 | } 71 | } 72 | if method.Passive { 73 | if !method.NoWait { 74 | channel.SendMethod(&amqp.ExchangeDeclareOk{}) 75 | } 76 | return nil 77 | } 78 | 79 | // outside of passive mode you can't create an exchange starting with 80 | // amq. 81 | if strings.HasPrefix(method.Exchange, "amq.") { 82 | return amqp.NewSoftError(403, "Exchange names starting with 'amq.' 
are reserved", classId, methodId) 83 | } 84 | 85 | err = channel.server.addExchange(ex) 86 | if err != nil { 87 | return amqp.NewSoftError(500, err.Error(), classId, methodId) 88 | } 89 | err = ex.Persist(channel.server.db) 90 | if err != nil { 91 | return amqp.NewSoftError(500, err.Error(), classId, methodId) 92 | } 93 | if !method.NoWait { 94 | channel.SendMethod(&amqp.ExchangeDeclareOk{}) 95 | } 96 | return nil 97 | } 98 | 99 | func (channel *Channel) exchangeDelete(method *amqp.ExchangeDelete) *amqp.AMQPError { 100 | var classId, methodId = method.MethodIdentifier() 101 | var errCode, err = channel.server.deleteExchange(method) 102 | if err != nil { 103 | return amqp.NewSoftError(errCode, err.Error(), classId, methodId) 104 | } 105 | if !method.NoWait { 106 | channel.SendMethod(&amqp.ExchangeDeleteOk{}) 107 | } 108 | return nil 109 | } 110 | 111 | func (channel *Channel) exchangeBind(method *amqp.ExchangeBind) *amqp.AMQPError { 112 | var classId, methodId = method.MethodIdentifier() 113 | return amqp.NewHardError(540, "Not implemented", classId, methodId) 114 | // if !method.NoWait { 115 | // channel.SendMethod(&amqp.ExchangeBindOk{}) 116 | // } 117 | } 118 | func (channel *Channel) exchangeUnbind(method *amqp.ExchangeUnbind) *amqp.AMQPError { 119 | var classId, methodId = method.MethodIdentifier() 120 | return amqp.NewHardError(540, "Not implemented", classId, methodId) 121 | // if !method.NoWait { 122 | // channel.SendMethod(&amqp.ExchangeUnbindOk{}) 123 | // } 124 | } 125 | -------------------------------------------------------------------------------- /server/queueMethods.go: -------------------------------------------------------------------------------- 1 | package server 2 | 3 | import ( 4 | "fmt" 5 | "github.com/jeffjenkins/dispatchd/amqp" 6 | "github.com/jeffjenkins/dispatchd/binding" 7 | "github.com/jeffjenkins/dispatchd/queue" 8 | "github.com/jeffjenkins/dispatchd/util" 9 | ) 10 | 11 | func (channel *Channel) queueRoute(methodFrame amqp.MethodFrame) 
*amqp.AMQPError { 12 | switch method := methodFrame.(type) { 13 | case *amqp.QueueDeclare: 14 | return channel.queueDeclare(method) 15 | case *amqp.QueueBind: 16 | return channel.queueBind(method) 17 | case *amqp.QueuePurge: 18 | return channel.queuePurge(method) 19 | case *amqp.QueueDelete: 20 | return channel.queueDelete(method) 21 | case *amqp.QueueUnbind: 22 | return channel.queueUnbind(method) 23 | } 24 | var classId, methodId = methodFrame.MethodIdentifier() 25 | return amqp.NewHardError(540, "Not implemented", classId, methodId) 26 | } 27 | 28 | func (channel *Channel) queueDeclare(method *amqp.QueueDeclare) *amqp.AMQPError { 29 | var classId, methodId = method.MethodIdentifier() 30 | // No name means generate a name 31 | if len(method.Queue) == 0 { 32 | method.Queue = util.RandomId() 33 | } 34 | 35 | // Check the name format 36 | var err = amqp.CheckExchangeOrQueueName(method.Queue) 37 | if err != nil { 38 | return amqp.NewSoftError(406, err.Error(), classId, methodId) 39 | } 40 | 41 | // If this is a passive request, do the appropriate checks and return 42 | if method.Passive { 43 | queue, found := channel.conn.server.queues[method.Queue] 44 | if found { 45 | if !method.NoWait { 46 | var qsize = uint32(queue.Len()) 47 | var csize = queue.ActiveConsumerCount() 48 | channel.SendMethod(&amqp.QueueDeclareOk{method.Queue, qsize, csize}) 49 | } 50 | channel.lastQueueName = method.Queue 51 | return nil 52 | } 53 | return amqp.NewSoftError(404, "Queue not found", classId, methodId) 54 | } 55 | 56 | // Create the new queue 57 | var connId = channel.conn.id 58 | if !method.Exclusive { 59 | connId = -1 60 | } 61 | var queue = queue.NewQueue( 62 | method.Queue, 63 | method.Durable, 64 | method.Exclusive, 65 | method.AutoDelete, 66 | method.Arguments, 67 | connId, 68 | channel.server.msgStore, 69 | channel.server.queueDeleter, 70 | ) 71 | 72 | // If the new queue exists already, ensure the settings are the same. 
If it 73 | // doesn't, add it and optionally persist it 74 | existing, hasKey := channel.server.queues[queue.Name] 75 | if hasKey { 76 | if existing.ConnId != -1 && existing.ConnId != channel.conn.id { 77 | return amqp.NewSoftError(405, "Queue is locked to another connection", classId, methodId) 78 | } 79 | if !existing.EquivalentQueues(queue) { 80 | return amqp.NewSoftError(406, "Queue exists and is not equivalent to existing", classId, methodId) 81 | } 82 | } else { 83 | err = channel.server.addQueue(queue) 84 | if err != nil { // pragma: nocover 85 | return amqp.NewSoftError(500, "Error creating queue", classId, methodId) 86 | } 87 | // Persist 88 | if queue.Durable { 89 | queue.Persist(channel.server.db) 90 | } 91 | } 92 | 93 | channel.lastQueueName = method.Queue 94 | if !method.NoWait { 95 | channel.SendMethod(&amqp.QueueDeclareOk{queue.Name, uint32(0), uint32(0)}) 96 | } 97 | return nil 98 | } 99 | 100 | func (channel *Channel) queueBind(method *amqp.QueueBind) *amqp.AMQPError { 101 | var classId, methodId = method.MethodIdentifier() 102 | 103 | if len(method.Queue) == 0 { 104 | if len(channel.lastQueueName) == 0 { 105 | return amqp.NewSoftError(404, "Queue not found", classId, methodId) 106 | } else { 107 | method.Queue = channel.lastQueueName 108 | } 109 | } 110 | 111 | // Check exchange 112 | var exchange, foundExchange = channel.server.exchanges[method.Exchange] 113 | if !foundExchange { 114 | return amqp.NewSoftError(404, "Exchange not found", classId, methodId) 115 | } 116 | 117 | // Check queue 118 | var queue, foundQueue = channel.server.queues[method.Queue] 119 | if !foundQueue || queue.Closed { 120 | return amqp.NewSoftError(404, fmt.Sprintf("Queue not found: %s", method.Queue), classId, methodId) 121 | } 122 | 123 | if queue.ConnId != -1 && queue.ConnId != channel.conn.id { 124 | return amqp.NewSoftError(405, fmt.Sprintf("Queue is locked to another connection"), classId, methodId) 125 | } 126 | 127 | // Create binding 128 | b, err := 
binding.NewBinding(method.Queue, method.Exchange, method.RoutingKey, method.Arguments, exchange.IsTopic()) 129 | if err != nil { 130 | return amqp.NewSoftError(500, err.Error(), classId, methodId) 131 | } 132 | 133 | // Add binding 134 | err = exchange.AddBinding(b, channel.conn.id) 135 | if err != nil { 136 | return amqp.NewSoftError(500, err.Error(), classId, methodId) 137 | } 138 | 139 | // Persist durable bindings 140 | if exchange.Durable && queue.Durable { 141 | var err = b.Persist(channel.server.db) 142 | if err != nil { 143 | return amqp.NewSoftError(500, err.Error(), classId, methodId) 144 | } 145 | } 146 | 147 | if !method.NoWait { 148 | channel.SendMethod(&amqp.QueueBindOk{}) 149 | } 150 | return nil 151 | } 152 | 153 | func (channel *Channel) queuePurge(method *amqp.QueuePurge) *amqp.AMQPError { 154 | fmt.Println("Got queuePurge") 155 | var classId, methodId = method.MethodIdentifier() 156 | 157 | // Check queue 158 | if len(method.Queue) == 0 { 159 | if len(channel.lastQueueName) == 0 { 160 | return amqp.NewSoftError(404, "Queue not found", classId, methodId) 161 | } else { 162 | method.Queue = channel.lastQueueName 163 | } 164 | } 165 | 166 | var queue, foundQueue = channel.server.queues[method.Queue] 167 | if !foundQueue { 168 | return amqp.NewSoftError(404, "Queue not found", classId, methodId) 169 | } 170 | 171 | if queue.ConnId != -1 && queue.ConnId != channel.conn.id { 172 | return amqp.NewSoftError(405, "Queue is locked to another connection", classId, methodId) 173 | } 174 | 175 | numPurged := queue.Purge() 176 | if !method.NoWait { 177 | channel.SendMethod(&amqp.QueuePurgeOk{numPurged}) 178 | } 179 | return nil 180 | } 181 | 182 | func (channel *Channel) queueDelete(method *amqp.QueueDelete) *amqp.AMQPError { 183 | fmt.Println("Got queueDelete") 184 | var classId, methodId = method.MethodIdentifier() 185 | 186 | // Check queue 187 | if len(method.Queue) == 0 { 188 | if len(channel.lastQueueName) == 0 { 189 | return amqp.NewSoftError(404, 
"Queue not found", classId, methodId) 190 | } else { 191 | method.Queue = channel.lastQueueName 192 | } 193 | } 194 | 195 | numPurged, errCode, err := channel.server.deleteQueue(method, channel.conn.id) 196 | if err != nil { 197 | return amqp.NewSoftError(errCode, err.Error(), classId, methodId) 198 | } 199 | 200 | if !method.NoWait { 201 | channel.SendMethod(&amqp.QueueDeleteOk{numPurged}) 202 | } 203 | return nil 204 | } 205 | 206 | func (channel *Channel) queueUnbind(method *amqp.QueueUnbind) *amqp.AMQPError { 207 | var classId, methodId = method.MethodIdentifier() 208 | 209 | // Check queue 210 | if len(method.Queue) == 0 { 211 | if len(channel.lastQueueName) == 0 { 212 | return amqp.NewSoftError(404, "Queue not found", classId, methodId) 213 | } else { 214 | method.Queue = channel.lastQueueName 215 | } 216 | } 217 | 218 | var queue, foundQueue = channel.server.queues[method.Queue] 219 | if !foundQueue { 220 | return amqp.NewSoftError(404, "Queue not found", classId, methodId) 221 | } 222 | 223 | if queue.ConnId != -1 && queue.ConnId != channel.conn.id { 224 | return amqp.NewSoftError(405, "Queue is locked to another connection", classId, methodId) 225 | } 226 | 227 | // Check exchange 228 | var exchange, foundExchange = channel.server.exchanges[method.Exchange] 229 | if !foundExchange { 230 | return amqp.NewSoftError(404, "Exchange not found", classId, methodId) 231 | } 232 | 233 | var binding, err = binding.NewBinding( 234 | method.Queue, 235 | method.Exchange, 236 | method.RoutingKey, 237 | method.Arguments, 238 | exchange.IsTopic(), 239 | ) 240 | 241 | if err != nil { 242 | return amqp.NewSoftError(500, err.Error(), classId, methodId) 243 | } 244 | 245 | if queue.Durable && exchange.Durable { 246 | err := binding.Depersist(channel.server.db) 247 | if err != nil { 248 | return amqp.NewSoftError(500, "Could not de-persist binding!", classId, methodId) 249 | } 250 | } 251 | 252 | if err := exchange.RemoveBinding(binding); err != nil { 253 | return 
amqp.NewSoftError(500, err.Error(), classId, methodId) 254 | } 255 | channel.SendMethod(&amqp.QueueUnbindOk{}) 256 | return nil 257 | } 258 | -------------------------------------------------------------------------------- /server/server.go: -------------------------------------------------------------------------------- 1 | package server 2 | 3 | import ( 4 | "encoding/json" 5 | "errors" 6 | "fmt" 7 | "github.com/boltdb/bolt" 8 | "github.com/jeffjenkins/dispatchd/amqp" 9 | "github.com/jeffjenkins/dispatchd/binding" 10 | "github.com/jeffjenkins/dispatchd/exchange" 11 | "github.com/jeffjenkins/dispatchd/msgstore" 12 | "github.com/jeffjenkins/dispatchd/queue" 13 | "net" 14 | "sync" 15 | ) 16 | 17 | type Server struct { 18 | exchanges map[string]*exchange.Exchange 19 | queues map[string]*queue.Queue 20 | bindings []*binding.Binding 21 | idLock sync.Mutex 22 | conns map[int64]*AMQPConnection 23 | db *bolt.DB 24 | serverLock sync.Mutex 25 | msgStore *msgstore.MessageStore 26 | exchangeDeleter chan *exchange.Exchange 27 | queueDeleter chan *queue.Queue 28 | users map[string]User 29 | strictMode bool 30 | } 31 | 32 | func (server *Server) MarshalJSON() ([]byte, error) { 33 | conns := make(map[string]*AMQPConnection) 34 | for id, value := range server.conns { 35 | conns[fmt.Sprintf("%d", id)] = value 36 | } 37 | return json.Marshal(map[string]interface{}{ 38 | "exchanges": server.exchanges, 39 | "queues": server.queues, 40 | "connections": conns, 41 | "msgCount": server.msgStore.MessageCount(), 42 | "msgIndexCount": server.msgStore.IndexCount(), 43 | }) 44 | } 45 | 46 | func NewServer(dbPath string, msgStorePath string, userJson map[string]interface{}, strictMode bool) *Server { 47 | db, err := bolt.Open(dbPath, 0600, nil) 48 | if err != nil { 49 | panic(err.Error()) 50 | 51 | } 52 | msgStore, err := msgstore.NewMessageStore(msgStorePath) 53 | msgStore.Start() 54 | if err != nil { 55 | panic("Could not create message store!") 56 | } 57 | 58 | var server = &Server{ 59 | 
exchanges: make(map[string]*exchange.Exchange), 60 | queues: make(map[string]*queue.Queue), 61 | bindings: make([]*binding.Binding, 0), 62 | conns: make(map[int64]*AMQPConnection), 63 | db: db, 64 | msgStore: msgStore, 65 | exchangeDeleter: make(chan *exchange.Exchange), 66 | queueDeleter: make(chan *queue.Queue), 67 | users: make(map[string]User), 68 | strictMode: strictMode, 69 | } 70 | 71 | server.init() 72 | server.addUsers(userJson) 73 | return server 74 | } 75 | 76 | func (server *Server) init() { 77 | server.msgStore.LoadMessages() //this must be before initQueues 78 | server.initExchanges() 79 | server.initQueues() 80 | server.initBindings() // this must be after init{Exchanges,Queues} 81 | go server.exchangeDeleteMonitor() 82 | go server.queueDeleteMonitor() 83 | } 84 | 85 | func (server *Server) exchangeDeleteMonitor() { 86 | for e := range server.exchangeDeleter { 87 | var dele = &amqp.ExchangeDelete{ 88 | Exchange: e.Name, 89 | NoWait: true, 90 | } 91 | server.deleteExchange(dele) 92 | } 93 | } 94 | 95 | func (server *Server) queueDeleteMonitor() { 96 | for q := range server.queueDeleter { 97 | var delq = &amqp.QueueDelete{ 98 | Queue: q.Name, 99 | NoWait: true, 100 | } 101 | server.deleteQueue(delq, -1) 102 | } 103 | } 104 | 105 | func (server *Server) initBindings() { 106 | // Load bindings 107 | bindings, err := binding.LoadAllBindings(server.db) 108 | if err != nil { 109 | panic("Couldn't load bindings!") 110 | } 111 | for _, b := range bindings { 112 | // Get Exchange 113 | var exchange, foundExchange = server.exchanges[b.ExchangeName] 114 | if !foundExchange { 115 | panic("Couldn't bind non-existant exchange " + b.ExchangeName) 116 | } 117 | // Add Binding 118 | err = exchange.AddBinding(b, -1) 119 | if err != nil { 120 | panic(err.Error()) 121 | } 122 | } 123 | } 124 | 125 | func (server *Server) initQueues() { 126 | // Load queues 127 | queues, err := queue.LoadAllQueues(server.db, server.msgStore, server.queueDeleter) 128 | if err != nil { 129 
| panic("Couldn't load queues!") 130 | } 131 | for _, queue := range queues { 132 | err = server.addQueue(queue) 133 | if err != nil { 134 | panic("Couldn't load queues!") 135 | } 136 | } 137 | // Load queue data 138 | for _, queue := range server.queues { 139 | queue.LoadFromMsgStore(server.msgStore) 140 | } 141 | } 142 | 143 | func (server *Server) initExchanges() { 144 | // LOAD FROM PERSISTENT STORAGE 145 | exchanges, err := exchange.LoadAllExchanges(server.db, server.exchangeDeleter) 146 | if err != nil { 147 | panic("Couldn't load exchanges!") 148 | } 149 | for _, ex := range exchanges { 150 | err = server.addExchange(ex) 151 | if err != nil { 152 | panic("Couldn't load queues!") 153 | } 154 | } 155 | if err != nil { 156 | panic("FAILED TO LOAD EXCHANGES: " + err.Error()) 157 | } 158 | 159 | // DECLARE MISSING SYSEM EXCHANGES 160 | server.genDefaultExchange("", exchange.EX_TYPE_DIRECT) 161 | server.genDefaultExchange("amq.direct", exchange.EX_TYPE_DIRECT) 162 | server.genDefaultExchange("amq.fanout", exchange.EX_TYPE_FANOUT) 163 | server.genDefaultExchange("amq.topic", exchange.EX_TYPE_TOPIC) 164 | } 165 | 166 | func (server *Server) genDefaultExchange(name string, typ uint8) { 167 | _, hasKey := server.exchanges[name] 168 | if !hasKey { 169 | var ex = exchange.NewExchange( 170 | name, 171 | exchange.EX_TYPE_TOPIC, 172 | true, 173 | false, 174 | false, 175 | amqp.NewTable(), 176 | true, 177 | server.exchangeDeleter, 178 | ) 179 | // Persist 180 | ex.Persist(server.db) 181 | err := server.addExchange(ex) 182 | if err != nil { 183 | panic(err.Error()) 184 | } 185 | } 186 | } 187 | 188 | func (server *Server) addExchange(ex *exchange.Exchange) error { 189 | server.serverLock.Lock() 190 | defer server.serverLock.Unlock() 191 | server.exchanges[ex.Name] = ex 192 | return nil 193 | } 194 | 195 | func (server *Server) addQueue(q *queue.Queue) error { 196 | server.serverLock.Lock() 197 | defer server.serverLock.Unlock() 198 | server.queues[q.Name] = q 199 | var 
defaultExchange = server.exchanges[""] 200 | var defaultBinding, err = binding.NewBinding(q.Name, "", q.Name, amqp.NewTable(), false) 201 | if err != nil { 202 | return err 203 | } 204 | defaultExchange.AddBinding(defaultBinding, q.ConnId) 205 | q.Start() 206 | return nil 207 | } 208 | 209 | func (server *Server) deleteQueuesForConn(connId int64) { 210 | server.serverLock.Lock() 211 | var queues = make([]*queue.Queue, 0) 212 | for _, queue := range server.queues { 213 | if queue.ConnId == connId { 214 | queues = append(queues, queue) 215 | } 216 | } 217 | server.serverLock.Unlock() 218 | for _, queue := range queues { 219 | var method = &amqp.QueueDelete{ 220 | Queue: queue.Name, 221 | } 222 | server.deleteQueue(method, connId) 223 | } 224 | } 225 | 226 | func (server *Server) deleteQueue(method *amqp.QueueDelete, connId int64) (uint32, uint16, error) { 227 | server.serverLock.Lock() 228 | defer server.serverLock.Unlock() 229 | // Validate 230 | var queue, foundQueue = server.queues[method.Queue] 231 | if !foundQueue { 232 | return 0, 404, errors.New("Queue not found") 233 | } 234 | 235 | if queue.ConnId != -1 && queue.ConnId != connId { 236 | return 0, 405, fmt.Errorf("Queue is locked to another connection") 237 | } 238 | 239 | // Close to stop anything from changing 240 | queue.Close() 241 | // Delete for storage 242 | bindings := server.bindingsForQueue(queue.Name) 243 | server.removeBindingsForQueue(method.Queue) 244 | server.depersistQueue(queue, bindings) 245 | 246 | // Cleanup 247 | numPurged, err := queue.Delete(method.IfUnused, method.IfEmpty) 248 | delete(server.queues, method.Queue) 249 | if err != nil { 250 | return 0, 406, err 251 | } 252 | return numPurged, 0, nil 253 | 254 | } 255 | 256 | func (server *Server) depersistQueue(queue *queue.Queue, bindings []*binding.Binding) error { 257 | return server.db.Update(func(tx *bolt.Tx) error { 258 | for _, binding := range bindings { 259 | if err := binding.DepersistBoltTx(tx); err != nil { 260 | return err 
261 | } 262 | } 263 | return queue.DepersistBoltTx(tx) 264 | }) 265 | } 266 | 267 | func (server *Server) bindingsForQueue(queueName string) []*binding.Binding { 268 | ret := make([]*binding.Binding, 0) 269 | for _, exchange := range server.exchanges { 270 | ret = append(ret, exchange.BindingsForQueue(queueName)...) 271 | } 272 | return ret 273 | } 274 | 275 | func (server *Server) removeBindingsForQueue(queueName string) { 276 | for _, exchange := range server.exchanges { 277 | exchange.RemoveBindingsForQueue(queueName) 278 | } 279 | } 280 | 281 | func (server *Server) deleteExchange(method *amqp.ExchangeDelete) (uint16, error) { 282 | server.serverLock.Lock() 283 | defer server.serverLock.Unlock() 284 | exchange, found := server.exchanges[method.Exchange] 285 | if !found { 286 | return 404, fmt.Errorf("Exchange not found: '%s'", method.Exchange) 287 | } 288 | if exchange.System { 289 | return 530, fmt.Errorf("Cannot delete system exchange: '%s'", method.Exchange) 290 | } 291 | exchange.Close() 292 | exchange.Depersist(server.db) 293 | // Note: we don't need to delete the bindings from the queues they are 294 | // associated with because they are stored on the exchange. 
295 | delete(server.exchanges, method.Exchange) 296 | return 0, nil 297 | } 298 | 299 | func (server *Server) deregisterConnection(id int64) { 300 | delete(server.conns, id) 301 | } 302 | 303 | func (server *Server) OpenConnection(network net.Conn) { 304 | c := NewAMQPConnection(server, network) 305 | server.conns[c.id] = c 306 | c.openConnection() 307 | } 308 | 309 | func (server *Server) returnMessage(msg *amqp.Message, code uint16, text string) *amqp.BasicReturn { 310 | return &amqp.BasicReturn{ 311 | Exchange: msg.Method.Exchange, 312 | RoutingKey: msg.Method.RoutingKey, 313 | ReplyCode: code, 314 | ReplyText: text, 315 | } 316 | } 317 | 318 | func (server *Server) publish(exchange *exchange.Exchange, msg *amqp.Message) (*amqp.BasicReturn, *amqp.AMQPError) { 319 | // Concurrency note: Since there is no lock we can, technically, have messages 320 | // published after the exchange has been closed. These couldn't be on the same 321 | // channel as the close is happening on, so that seems justifiable. 322 | if exchange.Closed { 323 | if msg.Method.Mandatory || msg.Method.Immediate { 324 | var rm = server.returnMessage(msg, 313, "Exchange closed, cannot route to queues or consumers") 325 | return rm, nil 326 | } 327 | return nil, nil 328 | } 329 | queues, amqpErr := exchange.QueuesForPublish(msg) 330 | if amqpErr != nil { 331 | return nil, amqpErr 332 | } 333 | 334 | if len(queues) == 0 { 335 | // If we got here the message was unroutable. 
336 | if msg.Method.Mandatory || msg.Method.Immediate { 337 | var rm = server.returnMessage(msg, 313, "No queues available") 338 | return rm, nil 339 | } 340 | } 341 | 342 | var queueNames = make([]string, 0, len(queues)) 343 | for k, _ := range queues { 344 | queueNames = append(queueNames, k) 345 | } 346 | 347 | // Immediate messages 348 | if msg.Method.Immediate { 349 | var consumed = false 350 | // Add message to message store 351 | queueMessagesByQueue, err := server.msgStore.AddMessage(msg, queueNames) 352 | if err != nil { 353 | return nil, amqp.NewSoftError(500, err.Error(), 60, 40) 354 | } 355 | // Try to immediately consumed it 356 | for queueName, _ := range queues { 357 | qms := queueMessagesByQueue[queueName] 358 | for _, qm := range qms { 359 | queue, found := server.queues[queueName] 360 | if !found { 361 | // The queue must have been deleted since the queuesForPublish call 362 | continue 363 | } 364 | var oneConsumed = queue.ConsumeImmediate(qm) 365 | var rhs = make([]amqp.MessageResourceHolder, 0) 366 | if !oneConsumed { 367 | server.msgStore.RemoveRef(qm, queueName, rhs) 368 | } 369 | consumed = oneConsumed || consumed 370 | } 371 | } 372 | if !consumed { 373 | var rm = server.returnMessage(msg, 313, "No consumers available for immediate message") 374 | return rm, nil 375 | } 376 | return nil, nil 377 | } 378 | 379 | // Add the message to the message store along with the queues we're about to add it to 380 | queueMessagesByQueue, err := server.msgStore.AddMessage(msg, queueNames) 381 | if err != nil { 382 | return nil, amqp.NewSoftError(500, err.Error(), 60, 40) 383 | } 384 | 385 | for queueName, _ := range queues { 386 | qms := queueMessagesByQueue[queueName] 387 | for _, qm := range qms { 388 | queue, found := server.queues[queueName] 389 | if !found || !queue.Add(qm) { 390 | // If we couldn't add it means the queue is closed and we should 391 | // remove the ref from the message store. 
The queue being closed means 392 | // it is going away, so worst case if the server dies we have to process 393 | // and discard the message on boot. 394 | var rhs = make([]amqp.MessageResourceHolder, 0) 395 | server.msgStore.RemoveRef(qm, queueName, rhs) 396 | } 397 | } 398 | } 399 | return nil, nil 400 | } 401 | -------------------------------------------------------------------------------- /server/server_consumer_test.go: -------------------------------------------------------------------------------- 1 | package server 2 | 3 | import ( 4 | "github.com/jeffjenkins/dispatchd/util" 5 | "testing" 6 | ) 7 | 8 | func TestAckNackOne(t *testing.T) { 9 | // 10 | // Setup 11 | // 12 | tc := newTestClient(t) 13 | defer tc.cleanup() 14 | conn := tc.connect() 15 | ch, _, _ := channelHelper(tc, conn) 16 | 17 | ch.QueueDeclare("q1", false, false, false, false, NO_ARGS) 18 | ch.QueueBind("q1", "abc", "amq.direct", false, NO_ARGS) 19 | 20 | // 21 | // Publish and consume one message. Check that channel.awaitingAcks is updated 22 | // properly before and after acking the message 23 | // 24 | deliveries, err := ch.Consume("q1", "TestAckNackOne-1", false, false, false, false, NO_ARGS) 25 | if err != nil { 26 | t.Fatalf("Failed to consume") 27 | } 28 | 29 | ch.Publish("amq.direct", "abc", false, false, TEST_TRANSIENT_MSG) 30 | msg := <-deliveries 31 | 32 | if len(tc.connFromServer().channels[1].awaitingAcks) != 1 { 33 | t.Fatalf("No awaiting ack for message just received") 34 | } 35 | 36 | msg.Ack(false) 37 | tc.wait(ch) 38 | if len(tc.connFromServer().channels[1].awaitingAcks) != 0 { 39 | t.Fatalf("No awaiting ack for message just received") 40 | } 41 | 42 | // 43 | // Publish and consume two messages. 
Check that awaiting acks is updated 44 | // and that nacking with and without the requeue options works 45 | // 46 | ch.Publish("amq.direct", "abc", false, false, TEST_TRANSIENT_MSG) 47 | ch.Publish("amq.direct", "abc", false, false, TEST_TRANSIENT_MSG) 48 | msg1 := <-deliveries 49 | msg2 := <-deliveries 50 | 51 | tc.wait(ch) 52 | if len(tc.connFromServer().channels[1].awaitingAcks) != 2 { 53 | t.Fatalf("Should have 2 messages awaiting acks") 54 | } 55 | // Stop consuming so we can check requeue 56 | ch.Cancel("TestAckNackOne-1", false) 57 | msg1.Nack(false, false) 58 | msg2.Nack(false, true) 59 | tc.wait(ch) 60 | if tc.s.queues["q1"].Len() != 1 { 61 | t.Fatalf("Should have 1 message in queue") 62 | } 63 | 64 | deliveries, err = ch.Consume("q1", "TestAckNackOne-1", false, false, false, false, NO_ARGS) 65 | if err != nil { 66 | t.Fatalf("Failed to consume") 67 | } 68 | msg2_again := <-deliveries 69 | if !msg2_again.Redelivered { 70 | t.Fatalf("Redelivered message wasn't flagged") 71 | } 72 | } 73 | 74 | func TestAckNackMany(t *testing.T) { 75 | // 76 | // Setup 77 | // 78 | tc := newTestClient(t) 79 | defer tc.cleanup() 80 | conn := tc.connect() 81 | ch, _, _ := channelHelper(tc, conn) 82 | 83 | ch.QueueDeclare("q1", false, false, false, false, NO_ARGS) 84 | ch.QueueBind("q1", "abc", "amq.direct", false, NO_ARGS) 85 | var consumerId = util.RandomId() 86 | 87 | // Ack Many 88 | // Publish and consume two messages. 
Acck-multiple the second one and 89 | // check that both are acked 90 | // 91 | deliveries, err := ch.Consume("q1", consumerId, false, false, false, false, NO_ARGS) 92 | if err != nil { 93 | t.Fatalf("Failed to consume") 94 | } 95 | 96 | ch.Publish("amq.direct", "abc", false, false, TEST_TRANSIENT_MSG) 97 | ch.Publish("amq.direct", "abc", false, false, TEST_TRANSIENT_MSG) 98 | _ = <-deliveries 99 | msg2 := <-deliveries 100 | 101 | msg2.Ack(true) 102 | tc.wait(ch) 103 | if len(tc.connFromServer().channels[1].awaitingAcks) != 0 { 104 | t.Fatalf("No awaiting ack for message just received") 105 | } 106 | 107 | // Nack Many (no requeue) 108 | // Publish and consume two messages. Nack-multiple the second one and 109 | // check that both are acked and the queue is empty 110 | // 111 | ch.Publish("amq.direct", "abc", false, false, TEST_TRANSIENT_MSG) 112 | ch.Publish("amq.direct", "abc", false, false, TEST_TRANSIENT_MSG) 113 | _ = <-deliveries 114 | msg2 = <-deliveries 115 | 116 | // Stop consuming so we can check requeue 117 | ch.Cancel(consumerId, false) 118 | msg2.Nack(true, false) 119 | tc.wait(ch) 120 | if tc.s.queues["q1"].Len() != 0 { 121 | t.Fatalf("Should have 0 message in queue") 122 | } 123 | 124 | // Nack Many (no requeue) 125 | // Publish and consume two messages. 
Nack-multiple the second one and 126 | // check that both are acked and the queue is empty 127 | // 128 | consumerId = util.RandomId() 129 | deliveries, err = ch.Consume("q1", consumerId, false, false, false, false, NO_ARGS) 130 | ch.Publish("amq.direct", "abc", false, false, TEST_TRANSIENT_MSG) 131 | ch.Publish("amq.direct", "abc", false, false, TEST_TRANSIENT_MSG) 132 | _ = <-deliveries 133 | msg2 = <-deliveries 134 | 135 | // Stop consuming so we can check requeue 136 | ch.Cancel(consumerId, false) 137 | msg2.Nack(true, true) 138 | tc.wait(ch) 139 | if tc.s.queues["q1"].Len() != 2 { 140 | t.Fatalf("Should have 2 message in queue") 141 | } 142 | } 143 | 144 | func TestRecover(t *testing.T) { 145 | // 146 | // Setup 147 | // 148 | tc := newTestClient(t) 149 | defer tc.cleanup() 150 | conn := tc.connect() 151 | ch, _, _ := channelHelper(tc, conn) 152 | cTag := util.RandomId() 153 | ch.QueueDeclare("q1", false, false, false, false, NO_ARGS) 154 | ch.QueueBind("q1", "abc", "amq.direct", false, NO_ARGS) 155 | 156 | // Recover - no requeue 157 | // Publish two messages, consume them, call recover, consume them again 158 | // 159 | deliveries, err := ch.Consume("q1", cTag, false, false, false, false, NO_ARGS) 160 | if err != nil { 161 | t.Fatalf("Failed to consume") 162 | } 163 | 164 | ch.Publish("amq.direct", "abc", false, false, TEST_TRANSIENT_MSG) 165 | ch.Publish("amq.direct", "abc", false, false, TEST_TRANSIENT_MSG) 166 | tc.wait(ch) 167 | <-deliveries 168 | lastMsg := <-deliveries 169 | ch.Recover(false) 170 | <-deliveries 171 | <-deliveries 172 | lastMsg.Ack(true) 173 | 174 | // Recover - requeue 175 | // Publish two messages, consume them, call recover, check that they are 176 | // requeued 177 | // 178 | ch.Publish("amq.direct", "abc", false, false, TEST_TRANSIENT_MSG) 179 | ch.Publish("amq.direct", "abc", false, false, TEST_TRANSIENT_MSG) 180 | tc.wait(ch) 181 | <-deliveries 182 | <-deliveries 183 | ch.Cancel(cTag, false) 184 | ch.Recover(true) 185 | 
tc.wait(ch) 186 | msgCount := tc.s.queues["q1"].Len() 187 | if msgCount != 2 { 188 | t.Fatalf("Should have 2 message in queue. Found", msgCount) 189 | } 190 | } 191 | 192 | func TestGet(t *testing.T) { 193 | // 194 | // Setup 195 | // 196 | tc := newTestClient(t) 197 | defer tc.cleanup() 198 | conn := tc.connect() 199 | ch, _, _ := channelHelper(tc, conn) 200 | 201 | ch.QueueDeclare("q1", false, false, false, false, NO_ARGS) 202 | ch.QueueBind("q1", "abc", "amq.direct", false, NO_ARGS) 203 | ch.Publish("amq.direct", "abc", false, false, TEST_TRANSIENT_MSG) 204 | msg, ok, err := ch.Get("q1", false) 205 | if err != nil { 206 | t.Fatalf(err.Error()) 207 | } 208 | if !ok { 209 | t.Fatalf("Did not receive message") 210 | } 211 | if string(msg.Body) != string(TEST_TRANSIENT_MSG.Body) { 212 | t.Fatalf("wrong message response in get") 213 | } 214 | } 215 | -------------------------------------------------------------------------------- /server/server_exchange_test.go: -------------------------------------------------------------------------------- 1 | package server 2 | 3 | import ( 4 | "testing" 5 | ) 6 | 7 | func TestExchangeMethods(t *testing.T) { 8 | tc := newTestClient(t) 9 | defer tc.cleanup() 10 | conn := tc.connect() 11 | channel, _, _ := channelHelper(tc, conn) 12 | 13 | channel.ExchangeDeclare("ex-1", "topic", false, false, false, false, NO_ARGS) 14 | 15 | // Create exchange 16 | if len(tc.s.exchanges) != 5 { 17 | t.Errorf("Wrong number of exchanges: %d", len(tc.s.exchanges)) 18 | } 19 | 20 | // Create Queue 21 | channel.QueueDeclare("q-1", false, false, false, false, NO_ARGS) 22 | if len(tc.s.queues) != 1 { 23 | t.Errorf("Wrong number of queues: %d", len(tc.s.queues)) 24 | } 25 | 26 | // Delete exchange 27 | channel.ExchangeDelete("ex-1", false, false) 28 | if len(tc.s.exchanges) != 4 { 29 | t.Errorf("Wrong number of exchanges: %d", len(tc.s.exchanges)) 30 | } 31 | } 32 | -------------------------------------------------------------------------------- 
/server/server_publish_test.go: -------------------------------------------------------------------------------- 1 | package server 2 | 3 | import ( 4 | "github.com/jeffjenkins/dispatchd/util" 5 | "testing" 6 | ) 7 | 8 | func TestImmediateFail(t *testing.T) { 9 | tc := newTestClient(t) 10 | defer tc.cleanup() 11 | conn := tc.connect() 12 | ch, retChan, _ := channelHelper(tc, conn) 13 | 14 | ch.QueueDeclare("q1", false, false, false, false, NO_ARGS) 15 | ch.QueueBind("q1", "abc", "amq.direct", false, NO_ARGS) 16 | ch.Publish("amq.direct", "abc", false, true, TEST_TRANSIENT_MSG) 17 | 18 | ret := <-retChan 19 | 20 | if ret.ReplyCode != 313 { 21 | t.Fatalf("Wrong reply code with Immediate return") 22 | } 23 | if string(ret.Body) != string(TEST_TRANSIENT_MSG.Body) { 24 | t.Fatalf("Did not get same payload back in BasicReturn") 25 | } 26 | } 27 | 28 | func TestImmediate(t *testing.T) { 29 | tc := newTestClient(t) 30 | defer tc.cleanup() 31 | conn := tc.connect() 32 | ch, _, _ := channelHelper(tc, conn) 33 | 34 | ch.QueueDeclare("q1", false, false, false, false, NO_ARGS) 35 | ch.QueueBind("q1", "abc", "amq.direct", false, NO_ARGS) 36 | 37 | deliveries, err := ch.Consume("q1", util.RandomId(), false, false, false, false, NO_ARGS) 38 | if err != nil { 39 | t.Fatalf("Failed to consume") 40 | } 41 | ch.Publish("amq.direct", "abc", false, true, TEST_TRANSIENT_MSG) 42 | <-deliveries 43 | } 44 | 45 | func TestMandatory(t *testing.T) { 46 | tc := newTestClient(t) 47 | defer tc.cleanup() 48 | conn := tc.connect() 49 | ch, retChan, _ := channelHelper(tc, conn) 50 | 51 | ch.Publish("amq.direct", "abc", false, true, TEST_TRANSIENT_MSG) 52 | 53 | ret := <-retChan 54 | 55 | if ret.ReplyCode != 313 { 56 | t.Fatalf("Wrong reply code with Mandatory return") 57 | } 58 | if string(ret.Body) != string(TEST_TRANSIENT_MSG.Body) { 59 | t.Fatalf("Did not get same payload back in BasicReturn") 60 | } 61 | } 62 | -------------------------------------------------------------------------------- 
/server/server_queue_test.go: -------------------------------------------------------------------------------- 1 | package server 2 | 3 | import ( 4 | "testing" 5 | ) 6 | 7 | func TestQueueMethods(t *testing.T) { 8 | tc := newTestClient(t) 9 | defer tc.cleanup() 10 | conn := tc.connect() 11 | ch, _, _ := channelHelper(tc, conn) 12 | 13 | // Create Queue 14 | ch.QueueDeclare("q1", false, false, false, false, NO_ARGS) 15 | 16 | if len(tc.s.queues) != 1 { 17 | t.Errorf("Wrong number of queues: %d", len(tc.s.queues)) 18 | } 19 | 20 | // Passive Check 21 | ch.QueueDeclarePassive("q1", true, false, false, false, NO_ARGS) 22 | 23 | // Bind 24 | ch.QueueBind("q1", "rk.*.#", "amq.topic", false, NO_ARGS) 25 | if len(tc.s.exchanges["amq.topic"].BindingsForQueue("q1")) != 1 { 26 | t.Errorf("Failed to bind to q1") 27 | } 28 | 29 | // Unbind 30 | ch.QueueUnbind("q1", "rk.*.#", "amq.topic", NO_ARGS) 31 | 32 | if len(tc.s.exchanges["amq.topic"].BindingsForQueue("q1")) != 0 { 33 | t.Errorf("Failed to unbind from q1") 34 | } 35 | 36 | // Delete 37 | ch.QueueDelete("q1", false, false, false) 38 | if len(tc.s.queues) != 0 { 39 | t.Errorf("Wrong number of queues: %d", len(tc.s.queues)) 40 | } 41 | } 42 | 43 | func TestAutoAssignedQueue(t *testing.T) { 44 | tc := newTestClient(t) 45 | defer tc.cleanup() 46 | conn := tc.connect() 47 | ch, _, _ := channelHelper(tc, conn) 48 | 49 | // Create Queue 50 | resp, err := ch.QueueDeclare("", false, false, false, false, NO_ARGS) 51 | if err != nil { 52 | t.Fatalf("Error declaring queue") 53 | } 54 | if len(resp.Name) == 0 { 55 | t.Errorf("Autogenerate queue name failed") 56 | } 57 | } 58 | 59 | func TestBadQueueName(t *testing.T) { 60 | tc := newTestClient(t) 61 | defer tc.cleanup() 62 | conn := tc.connect() 63 | ch, _, errChan := channelHelper(tc, conn) 64 | 65 | // Create Queue 66 | _, err := ch.QueueDeclare("!", false, false, false, true, NO_ARGS) 67 | if err != nil { 68 | panic("failed to declare queue") 69 | } 70 | resp := <-errChan 71 | if 
resp.Code != 406 { 72 | t.Errorf("Wrong response code") 73 | } 74 | } 75 | 76 | func TestPassiveNotFound(t *testing.T) { 77 | tc := newTestClient(t) 78 | defer tc.cleanup() 79 | conn := tc.connect() 80 | ch, _, errChan := channelHelper(tc, conn) 81 | 82 | // Create Queue 83 | _, err := ch.QueueDeclarePassive("does.not.exist", true, false, false, true, NO_ARGS) 84 | if err != nil { 85 | panic("failed to declare queue") 86 | } 87 | resp := <-errChan 88 | if resp.Code != 404 { 89 | t.Errorf("Wrong response code") 90 | } 91 | } 92 | 93 | func TestPurge(t *testing.T) { 94 | tc := newTestClient(t) 95 | defer tc.cleanup() 96 | conn := tc.connect() 97 | ch, _, _ := channelHelper(tc, conn) 98 | 99 | ch.QueueDeclare("q1", false, false, false, false, NO_ARGS) 100 | ch.QueueBind("q1", "a.b.c", "amq.topic", false, NO_ARGS) 101 | 102 | ch.Publish("amq.topic", "a.b.c", false, false, TEST_TRANSIENT_MSG) 103 | 104 | // This unbind is just to block us on the message being processed by the 105 | // channel so that the server has it. 106 | ch.QueueUnbind("q1", "a.b.c", "amq.topic", NO_ARGS) 107 | if tc.s.queues["q1"].Len() == 0 { 108 | t.Fatalf("Message did not make it into queue") 109 | } 110 | 111 | resp, err := ch.QueuePurge("q1", false) 112 | if err != nil { 113 | t.Fatalf("Failed to call QueuePurge") 114 | } 115 | 116 | if tc.s.queues["q1"].Len() != 0 { 117 | t.Fatalf("Message did not get purged from queue. 
Got %d", tc.s.queues["q1"].Len()) 118 | } 119 | 120 | if resp != 1 { 121 | t.Fatalf("Purge did not return the right number of messages deleted") 122 | } 123 | } 124 | 125 | func TestExclusive(t *testing.T) { 126 | tc := newTestClient(t) 127 | defer tc.cleanup() 128 | conn := tc.connect() 129 | ch, _, errChan := channelHelper(tc, conn) 130 | 131 | // Create Queue 132 | ch.QueueDeclare("q1", false, false, true, false, NO_ARGS) 133 | 134 | // Check conn id 135 | serverConn := tc.connFromServer() 136 | q, ok := tc.s.queues["q1"] 137 | if !ok { 138 | t.Fatalf("Could not find q1") 139 | } 140 | if serverConn.id != q.ConnId { 141 | t.Fatalf("Exclusive queue does not have connId set") 142 | } 143 | 144 | // Cheat and change the connection id so we don't need a second conn 145 | // for this test 146 | q.ConnId = 54321 147 | // NOTE: if nowait isn't true this blocks forever 148 | ch.QueueDeclare("q1", false, false, true, true, NO_ARGS) 149 | 150 | resp := <-errChan 151 | if resp.Code != 405 { 152 | t.Errorf("Wrong response code") 153 | } 154 | } 155 | 156 | func TestNonMatchingQueue(t *testing.T) { 157 | tc := newTestClient(t) 158 | defer tc.cleanup() 159 | conn := tc.connect() 160 | ch, _, errChan := channelHelper(tc, conn) 161 | 162 | // Create Queue 163 | ch.QueueDeclare("q1", false, false, false, false, NO_ARGS) 164 | 165 | // Create queue again with different args 166 | // NOTE: if nowait isn't true this blocks forever 167 | ch.QueueDeclare("q1", true, false, false, true, NO_ARGS) 168 | resp := <-errChan 169 | if resp.Code != 406 { 170 | t.Errorf("Wrong response code") 171 | } 172 | } 173 | -------------------------------------------------------------------------------- /server/server_test.go: -------------------------------------------------------------------------------- 1 | package server 2 | 3 | import ( 4 | "github.com/jeffjenkins/dispatchd/util" 5 | amqpclient "github.com/streadway/amqp" 6 | "net" 7 | "os" 8 | "testing" 9 | "time" 10 | ) 11 | 12 | var NO_ARGS = 
make(amqpclient.Table) 13 | var TEST_TRANSIENT_MSG = amqpclient.Publishing{ 14 | Body: []byte("dispatchd"), 15 | } 16 | 17 | type testClient struct { 18 | t *testing.T 19 | s *Server 20 | serverDb string 21 | msgDb string 22 | } 23 | 24 | func newTestClient(t *testing.T) *testClient { 25 | serverDb := dbPath() 26 | msgDb := dbPath() 27 | s := NewServer(serverDb, msgDb, nil, false) 28 | s.init() 29 | tc := &testClient{ 30 | t: t, 31 | s: s, 32 | serverDb: serverDb, 33 | msgDb: msgDb, 34 | } 35 | return tc 36 | } 37 | 38 | func channelHelper( 39 | tc *testClient, 40 | conn *amqpclient.Connection, 41 | ) ( 42 | *amqpclient.Channel, 43 | chan amqpclient.Return, 44 | chan *amqpclient.Error, 45 | ) { 46 | ch, err := conn.Channel() 47 | if err != nil { 48 | panic("Bad channel!") 49 | } 50 | retChan := make(chan amqpclient.Return) 51 | closeChan := make(chan *amqpclient.Error) 52 | ch.NotifyReturn(retChan) 53 | ch.NotifyClose(closeChan) 54 | return ch, retChan, closeChan 55 | } 56 | 57 | func (tc *testClient) connect() *amqpclient.Connection { 58 | internal, external := net.Pipe() 59 | go tc.s.OpenConnection(internal) 60 | // Set up connection 61 | clientconfig := amqpclient.Config{ 62 | SASL: nil, 63 | Vhost: "/", 64 | ChannelMax: 100000, 65 | FrameSize: 100000, 66 | Heartbeat: time.Duration(0), 67 | TLSClientConfig: nil, 68 | Properties: make(amqpclient.Table), 69 | Dial: func(network, addr string) (net.Conn, error) { 70 | return external, nil 71 | }, 72 | } 73 | 74 | client, err := amqpclient.DialConfig("amqp://localhost:1234", clientconfig) 75 | if err != nil { 76 | panic(err.Error()) 77 | } 78 | return client 79 | } 80 | 81 | func (tc *testClient) wait(ch *amqpclient.Channel) { 82 | ch.QueueDeclare(util.RandomId(), false, false, false, false, NO_ARGS) 83 | } 84 | 85 | func (tc *testClient) cleanup() { 86 | os.Remove(tc.msgDb) 87 | os.Remove(tc.serverDb) 88 | // tc.client.Close() 89 | } 90 | 91 | func dbPath() string { 92 | return "/tmp/" + util.RandomId() + 
".dispatchd.test.db" 93 | } 94 | 95 | func (tc *testClient) connFromServer() *AMQPConnection { 96 | for _, conn := range tc.s.conns { 97 | return conn 98 | } 99 | panic("no connections!") 100 | } 101 | -------------------------------------------------------------------------------- /server/server_tx_test.go: -------------------------------------------------------------------------------- 1 | package server 2 | 3 | import ( 4 | "github.com/jeffjenkins/dispatchd/util" 5 | "testing" 6 | ) 7 | 8 | func TestTxCommitPublish(t *testing.T) { 9 | // 10 | // Setup 11 | // 12 | tc := newTestClient(t) 13 | defer tc.cleanup() 14 | conn := tc.connect() 15 | ch, _, _ := channelHelper(tc, conn) 16 | 17 | ch.QueueDeclare("q1", false, false, false, false, NO_ARGS) 18 | ch.QueueBind("q1", "abc", "amq.direct", false, NO_ARGS) 19 | ch.Tx() 20 | ch.Publish("amq.direct", "abc", false, false, TEST_TRANSIENT_MSG) 21 | ch.Publish("amq.direct", "abc", false, false, TEST_TRANSIENT_MSG) 22 | ch.Publish("amq.direct", "abc", false, false, TEST_TRANSIENT_MSG) 23 | tc.wait(ch) 24 | if tc.s.queues["q1"].Len() != 0 { 25 | t.Fatalf("Tx failed to buffer messages") 26 | } 27 | ch.TxCommit() 28 | tc.wait(ch) 29 | if tc.s.queues["q1"].Len() != 3 { 30 | t.Fatalf("All messages were not added to queue") 31 | } 32 | } 33 | 34 | func TestTxRollbackPublish(t *testing.T) { 35 | // 36 | // Setup 37 | // 38 | tc := newTestClient(t) 39 | defer tc.cleanup() 40 | conn := tc.connect() 41 | ch, _, _ := channelHelper(tc, conn) 42 | 43 | ch.QueueDeclare("q1", false, false, false, false, NO_ARGS) 44 | ch.QueueBind("q1", "abc", "amq.direct", false, NO_ARGS) 45 | ch.Tx() 46 | ch.Publish("amq.direct", "abc", false, false, TEST_TRANSIENT_MSG) 47 | ch.Publish("amq.direct", "abc", false, false, TEST_TRANSIENT_MSG) 48 | ch.Publish("amq.direct", "abc", false, false, TEST_TRANSIENT_MSG) 49 | ch.TxRollback() 50 | ch.TxCommit() 51 | tc.wait(ch) 52 | if tc.s.queues["q1"].Len() != 0 { 53 | t.Fatalf("Tx Rollback still put messages in 
queue") 54 | } 55 | } 56 | 57 | func TestTxCommitAckNack(t *testing.T) { 58 | // 59 | // Setup 60 | // 61 | tc := newTestClient(t) 62 | defer tc.cleanup() 63 | conn := tc.connect() 64 | ch, _, _ := channelHelper(tc, conn) 65 | 66 | ch.QueueDeclare("q1", false, false, false, false, NO_ARGS) 67 | ch.QueueBind("q1", "abc", "amq.direct", false, NO_ARGS) 68 | ch.Tx() 69 | ch.Publish("amq.direct", "abc", false, false, TEST_TRANSIENT_MSG) 70 | ch.Publish("amq.direct", "abc", false, false, TEST_TRANSIENT_MSG) 71 | ch.Publish("amq.direct", "abc", false, false, TEST_TRANSIENT_MSG) 72 | ch.TxCommit() 73 | cTag := util.RandomId() 74 | deliveries, _ := ch.Consume("q1", cTag, false, false, false, false, NO_ARGS) 75 | <-deliveries 76 | <-deliveries 77 | lastMsg := <-deliveries 78 | lastMsg.Ack(true) 79 | tc.wait(ch) 80 | if len(tc.connFromServer().channels[1].awaitingAcks) != 3 { 81 | t.Fatalf("Acks were not held in tx") 82 | } 83 | ch.TxCommit() 84 | if len(tc.connFromServer().channels[1].awaitingAcks) != 0 { 85 | t.Fatalf("Acks were not processed on commit") 86 | } 87 | } 88 | 89 | func TestTxRollbackAckNack(t *testing.T) { 90 | // 91 | // Setup 92 | // 93 | tc := newTestClient(t) 94 | defer tc.cleanup() 95 | conn := tc.connect() 96 | ch, _, _ := channelHelper(tc, conn) 97 | 98 | ch.QueueDeclare("q1", false, false, false, false, NO_ARGS) 99 | ch.QueueBind("q1", "abc", "amq.direct", false, NO_ARGS) 100 | ch.Tx() 101 | ch.Publish("amq.direct", "abc", false, false, TEST_TRANSIENT_MSG) 102 | ch.Publish("amq.direct", "abc", false, false, TEST_TRANSIENT_MSG) 103 | ch.Publish("amq.direct", "abc", false, false, TEST_TRANSIENT_MSG) 104 | ch.TxCommit() 105 | 106 | cTag := util.RandomId() 107 | deliveries, _ := ch.Consume("q1", cTag, false, false, false, false, NO_ARGS) 108 | <-deliveries 109 | <-deliveries 110 | lastMsg := <-deliveries 111 | lastMsg.Nack(true, true) 112 | tc.wait(ch) 113 | if len(tc.connFromServer().channels[1].awaitingAcks) != 3 { 114 | t.Fatalf("Messages were acked 
despite rollback") 115 | } 116 | } 117 | -------------------------------------------------------------------------------- /server/txMethods.go: -------------------------------------------------------------------------------- 1 | package server 2 | 3 | import ( 4 | "github.com/jeffjenkins/dispatchd/amqp" 5 | ) 6 | 7 | func (channel *Channel) txRoute(methodFrame amqp.MethodFrame) *amqp.AMQPError { 8 | switch method := methodFrame.(type) { 9 | case *amqp.TxSelect: 10 | return channel.txSelect(method) 11 | case *amqp.TxCommit: 12 | return channel.txCommit(method) 13 | case *amqp.TxRollback: 14 | return channel.txRollback(method) 15 | } 16 | var classId, methodId = methodFrame.MethodIdentifier() 17 | return amqp.NewHardError(540, "Unable to route method frame", classId, methodId) 18 | } 19 | 20 | func (channel *Channel) txSelect(method *amqp.TxSelect) *amqp.AMQPError { 21 | channel.startTxMode() 22 | channel.SendMethod(&amqp.TxSelectOk{}) 23 | return nil 24 | } 25 | 26 | func (channel *Channel) txCommit(method *amqp.TxCommit) *amqp.AMQPError { 27 | if amqpErr := channel.commitTx(); amqpErr != nil { 28 | return amqpErr 29 | } 30 | channel.SendMethod(&amqp.TxCommitOk{}) 31 | return nil 32 | } 33 | 34 | func (channel *Channel) txRollback(method *amqp.TxRollback) *amqp.AMQPError { 35 | channel.rollbackTx() 36 | channel.SendMethod(&amqp.TxRollbackOk{}) 37 | return nil 38 | } 39 | -------------------------------------------------------------------------------- /static/admin.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | Admin 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 |
15 |

Server

16 |
Message Store Size: {{server.msgCount}}
17 |
Message Index Size: {{server.msgIndexCount}}
18 |

Exchanges

19 |
20 |
21 | "{{name}}" 22 | (default exchange) 23 | — Type: {{exchange.type}} 24 |
25 | 26 |
27 |

Bindings

28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 |
QueueKey
{{binding.queueName}}{{binding.key}}
38 |
39 |
40 |

Queues

41 |
42 |
43 | {{name}} 44 | 45 | (exclusive: {{queue.connId}}) 46 | 47 |
48 |
49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 |
TagAck?QoSActiveLifetime
{{consumer.tag}}{{consumer.ack}}{{consumer.stats.qos}}{{consumer.stats.active_count}}N/A{{consumer.stats.total}}
66 |
67 | 71 |
72 | 73 |

Connections

74 |
75 |
76 | id: {{id}} – addr: {{conn.address}} — {{conn.channelCount}} channels active 77 |
78 |
79 |

Client Properties

80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | 88 | 89 |
namevalue
{{name}}{{value}}
90 |
91 |
92 |
93 | 94 | -------------------------------------------------------------------------------- /static/admin.js: -------------------------------------------------------------------------------- 1 | 2 | REFRESH_INTERVAL = 2000 3 | 4 | var app = angular.module('adminApp', []) 5 | 6 | app.service('serverState', ['$http', '$interval', 7 | function($http, $interval) { 8 | var self = this 9 | self.server = {} 10 | self.refresh = function() { 11 | $http({ 12 | 'url' : '/api/server' 13 | }).then(function(data) { 14 | angular.forEach(self.server, function(value, key) { 15 | delete self.server[key] 16 | }) 17 | angular.forEach(data.data, function(value, key) { 18 | self.server[key] = value 19 | }) 20 | }) 21 | } 22 | self.refresh() 23 | $interval(self.refresh, REFRESH_INTERVAL) 24 | }]) 25 | 26 | app.controller('ServerController', ['$scope', 'serverState', 27 | function($scope, serverState) { 28 | $scope.server = serverState.server 29 | }]) 30 | -------------------------------------------------------------------------------- /stats/stats.go: -------------------------------------------------------------------------------- 1 | package stats 2 | 3 | import ( 4 | "github.com/rcrowley/go-metrics" 5 | "time" 6 | ) 7 | 8 | func MakeHistogram(name string) metrics.Histogram { 9 | return metrics.GetOrRegisterHistogram( 10 | name, 11 | metrics.DefaultRegistry, 12 | metrics.NewUniformSample(10000), 13 | ) 14 | } 15 | 16 | func RecordHisto(histo metrics.Histogram, start int64) { 17 | histo.Update(time.Now().UnixNano() - start) 18 | } 19 | 20 | func Start() int64 { 21 | return time.Now().UnixNano() 22 | } 23 | 24 | type Histogram metrics.Histogram 25 | -------------------------------------------------------------------------------- /stats/stats_test.go: -------------------------------------------------------------------------------- 1 | package stats 2 | 3 | import ( 4 | "github.com/rcrowley/go-metrics" 5 | "testing" 6 | ) 7 | 8 | func TestStats(t *testing.T) { 9 | // This is basically for 
code coverage. These are simple wrappers. 10 | // I'll add more tests once this start needing to do more. 11 | var m = MakeHistogram("hello") 12 | if m != metrics.Get("hello") { 13 | t.Errorf("Got different histogram from MakeHistogram") 14 | } 15 | 16 | var start = Start() 17 | RecordHisto(m, start-20) 18 | if m.Max() < 19 || m.Max() > start { 19 | t.Errorf("Bad value in histo %d", m.Max()) 20 | } 21 | } 22 | -------------------------------------------------------------------------------- /util/util.go: -------------------------------------------------------------------------------- 1 | package util 2 | 3 | import ( 4 | "math/rand" 5 | "sync/atomic" 6 | "time" 7 | ) 8 | 9 | var counter int64 10 | 11 | func init() { 12 | rand.Seed(time.Now().UnixNano()) 13 | // System Integrity Note: 14 | // 15 | // Start the counter at the time of server boot. This is so that we have 16 | // message IDs which are consistent within a single server even if it is 17 | // restarted. One implication of using the time is that if the server boot 18 | // time is earlier than the previous boot time durable messages may be 19 | // loaded out of order 20 | counter = time.Now().UnixNano() 21 | } 22 | 23 | var chars = []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890") 24 | 25 | func RandomId() string { 26 | var size = 32 27 | var numChars = len(chars) 28 | id := make([]rune, size) 29 | for i := range id { 30 | id[i] = chars[rand.Intn(numChars)] 31 | } 32 | return string(id) 33 | } 34 | 35 | func NextId() int64 { 36 | return atomic.AddInt64(&counter, 1) 37 | } 38 | -------------------------------------------------------------------------------- /util/util_test.go: -------------------------------------------------------------------------------- 1 | package util 2 | 3 | import ( 4 | "regexp" 5 | "testing" 6 | "time" 7 | ) 8 | 9 | func TestRandomId(t *testing.T) { 10 | var re, err = regexp.Compile("^\\w{32}$") 11 | if err != nil { 12 | t.Errorf(err.Error()) 13 | } 14 | for i 
:= 0; i < 100; i++ { 15 | var r = RandomId() 16 | if !re.MatchString(r) { 17 | t.Errorf("Did not match string: '%s'", r) 18 | } 19 | } 20 | } 21 | 22 | func TestNextId(t *testing.T) { 23 | var now = time.Now().UnixNano() 24 | var next = NextId() 25 | var nextNext = NextId() 26 | if now < next { 27 | t.Errorf("NextId was less than current time") 28 | } 29 | if next+1 != nextNext { 30 | t.Errorf( 31 | "subsequent NextId calls do not produce incrementing numbers: %d, %d", 32 | next, 33 | nextNext, 34 | ) 35 | } 36 | } 37 | --------------------------------------------------------------------------------