├── .dockerignore
├── .gitignore
├── Dockerfile
├── LICENSE
├── Makefile
├── README.md
├── adminserver
└── admin.go
├── amqp-0-9-1.xml
├── amqp
├── amqp.proto
├── amqpreader.go
├── amqpwriter.go
├── constants_generated.go
├── domains_generated.go
├── errors.go
├── messages.proto
├── protocol_generated.proto
├── protocol_protobuf_readwrite_generated.go
├── readwrite_test.go
├── table.go
├── table_test.go
├── testlib.go
├── tx.go
├── tx_test.go
└── types.go
├── amqp0-9-1.extended.xml
├── amqpgen
├── amqpgen.go
└── templates.go
├── binding
├── binding.go
└── binding_test.go
├── config.default.json
├── consumer
├── consumer.go
└── consumer_test.go
├── dev
└── config.json
├── dispatchd
├── config.go
└── main.go
├── exchange
├── exchange.go
└── exchange_test.go
├── gen
└── server.proto
├── msgstore
├── msgstore.go
├── msgstore_test.go
└── testlib.go
├── persist
├── persist.go
└── persist_test.go
├── queue
├── queue.go
└── queue_test.go
├── scripts
├── benchmark_helper.sh
└── cover.py
├── server
├── auth.go
├── basicMethods.go
├── channel.go
├── channelMethods.go
├── connection.go
├── connectionMethods.go
├── exchangeMethods.go
├── queueMethods.go
├── server.go
├── server_consumer_test.go
├── server_exchange_test.go
├── server_publish_test.go
├── server_queue_test.go
├── server_test.go
├── server_tx_test.go
└── txMethods.go
├── static
├── admin.html
└── admin.js
├── stats
├── stats.go
└── stats_test.go
└── util
├── util.go
└── util_test.go
/.dockerignore:
--------------------------------------------------------------------------------
1 | scripts/external/
2 | .git
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | *.pb.go
2 | *.cover
3 | *.db
4 | scripts/external/
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
#
# Everything about this is kind of gross, but it does get a server running
#

FROM centos:latest
# OS setup
RUN yum install -y make golang git
RUN mkdir -p /app/dispatchd && mkdir -p /data/dispatchd/
# gcc-c++/glibc-headers are presumably needed to compile protobuf below;
# python-setuptools provides easy_install, which installs mako on the next line.
# NOTE(review): nothing in the Makefile visibly uses mako — confirm it is still needed.
RUN yum install -y python-setuptools.noarch gcc-c++ glibc-headers
RUN easy_install mako

# protobuf
# Build protoc 2.6.1 from source; the Makefile's gen_pb target requires
# `protoc` on PATH (checked by its protoc_present target).
RUN cd /tmp && curl -L -o protobuf-2.6.1.tar.gz https://github.com/google/protobuf/releases/download/v2.6.1/protobuf-2.6.1.tar.gz
RUN cd /tmp && tar -xzf protobuf-2.6.1.tar.gz
RUN cd /tmp/protobuf-2.6.1/ && ./configure && make install

# Build dispatchd
# The source is copied into a GOPATH layout so the
# github.com/jeffjenkins/dispatchd import paths used by the Makefile resolve.
ENV BUILD_DIR /app/dispatchd/src/github.com/jeffjenkins/dispatchd/
RUN mkdir -p $BUILD_DIR
COPY . $BUILD_DIR
ENV GOPATH /app/dispatchd
# $GOPATH/bin is added to PATH so tools installed by `go get` (e.g. the
# protoc-gen-gogo plugin) can be found; `make install` puts the server
# binary at $GOPATH/bin/server.
RUN cd $BUILD_DIR && PATH=$PATH:$GOPATH/bin make install

# Runtime configuration
# The admin server serves static files from $STATIC_PATH and panics if it is unset.
ENV STATIC_PATH $BUILD_DIR/static
RUN cp $BUILD_DIR/config.default.json /etc/dispatchd.json
CMD ["/app/dispatchd/bin/server", "-config-file=/etc/dispatchd.json", "-persist-dir=/data/dispatchd/"]
28 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2015 Jeffrey Jenkins
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 |
.PHONY: all protoc_present deps gen_all gen_pb gen_amqp build test full_coverage \
	real_line_count devserver benchmark_dev benchmark install clean

# protoc with GOPATH sources and gogo's bundled protos on the import path.
PROTOC := protoc -I=${GOPATH}/src:${GOPATH}/src/github.com/gogo/protobuf/protobuf/
PROJECT_PATH := ${GOPATH}/src/github.com/jeffjenkins/dispatchd

# AMQP port handed to the benchmark helper by `make benchmark`.
RUN_PORT=5672
# RabbitMQ Java client's launcher script, downloaded on demand by the rule below.
PERF_SCRIPT=scripts/external/perf-client/runjava.sh

all: build

# Remove downloaded tools, generated protobuf code, and the binaries produced by
# both `install` (${GOPATH}/bin/server) and `build` (${GOPATH}/dispatchd).
clean:
	rm -Rf scripts/external/
	rm -f */*.pb.go
	rm -f ${GOPATH}/bin/server
	rm -f ${GOPATH}/dispatchd

protoc_present:
	which protoc

deps:
	go get github.com/boltdb/bolt \
		github.com/gogo/protobuf/gogoproto \
		github.com/gogo/protobuf/proto \
		github.com/gogo/protobuf/protoc-gen-gogo \
		github.com/rcrowley/go-metrics \
		github.com/streadway/amqp \
		github.com/wadey/gocovmerge \
		golang.org/x/crypto/bcrypt

gen_all: deps gen_pb gen_amqp

# Compile the project's .proto files with the gogo plugin.
gen_pb: gen_amqp protoc_present
	$(PROTOC) --gogo_out=${GOPATH}/src ${PROJECT_PATH}/amqp/*.proto
	$(PROTOC) --gogo_out=${GOPATH}/src ${PROJECT_PATH}/gen/*.proto

# Run the code generator against the extended spec, then format the output.
gen_amqp:
	go run amqpgen/*.go --spec=amqp0-9-1.extended.xml && go fmt github.com/jeffjenkins/dispatchd/...
	gofmt -w amqp/*generated*.go

build: deps gen_all
	go build -o ${GOPATH}/dispatchd github.com/jeffjenkins/dispatchd/server

install: deps gen_all
	go install github.com/jeffjenkins/dispatchd/server

test: deps gen_all
	go test -cover github.com/jeffjenkins/dispatchd/...

full_coverage: test
	# Output: $$GOPATH/all.cover
	python scripts/cover.py

real_line_count:
	find . | grep '.go$$' | grep -v pb.go | grep -v generated | xargs cat | wc -l

${PERF_SCRIPT}:
	mkdir -p scripts/external/
	curl -o scripts/external/perf-client.tar.gz 'https://www.rabbitmq.com/releases/rabbitmq-java-client/v3.5.6/rabbitmq-java-client-bin-3.5.6.tar.gz'
	tar -C scripts/external/ -zxf scripts/external/perf-client.tar.gz
	mv scripts/external/rabbitmq-java-client-bin-3.5.6 scripts/external/perf-client/

# The `install` prerequisite already go-installs the server binary.
devserver: install
	STATIC_PATH=${GOPATH}/src/github.com/jeffjenkins/dispatchd/static \
		${GOPATH}/bin/server \
		-config-file ${GOPATH}/src/github.com/jeffjenkins/dispatchd/dev/config.json

benchmark_dev: ${PERF_SCRIPT}
	RUN_PORT=1111 scripts/benchmark_helper.sh

benchmark: ${PERF_SCRIPT}
	RUN_PORT=${RUN_PORT} scripts/benchmark_helper.sh
74 |
75 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # dispatchd - A Message Broker and Queue Server
2 |
3 | ## Status
4 |
5 | `dispatchd` is in alpha.
6 |
It generally works but is not hardened enough for production use. More on what it can and can't do below.
8 |
9 | ## Features
10 |
11 | * Basically all (and many optional) amqp 0-9-1 features are supported
12 | * Some rabbitmq extensions are implemented:
13 | * nack (almost: the same consumer could receive the message again)
14 | * internal exchanges (flag exists, but currently unused)
15 | * auto-delete exchanges
16 | * Rabbit's reinterpretation of basic.qos
17 | * There is a simple admin page that can show basic info about what's
18 | happening in the server
19 |
20 | Notably missing from the features in the amqp spec are:
21 |
22 | * support for multiple priority levels
23 | * handling of queue/memory limits being exceeded
24 |
25 | ## Configuration
26 |
27 | There are command line flags for basic configuration:
28 |
29 | -admin-port int
30 | Port for admin server. Default: 8080
31 | -amqp-port int
32 | Port for amqp protocol messages. Default: 5672
33 | -config-file string
        Path to a JSON config file. Default: do not read a config file
35 | -debug-port int
36 | Port for the golang debug handlers. Default: 6060
37 | -persist-dir string
38 | Directory for the server and message database files. Default: /data/dispatchd/
39 |
40 | These options can be overridden if `-config-file` is specified. The config file is JSON and will complain loudly if any types don't look right rather than ignoring or working around them.
41 |
42 | Right now the only config file exclusive options are for users and passwords. In the future the config file will have tuning parameters as well.
43 |
44 | ## Running Dispatchd
45 |
46 | Dispatchd is currently only packaged as a docker image. You can run it with this command:
47 |
48 | docker run \
49 | -p=8080:8080 \
50 | -p=5672:5672 \
51 | --volume=YOUR_CONFIG_FILE:/etc/dispatchd.json \
52 | --volume=YOUR_DATA_DIR:/data/dispatchd/ \
53 | dispatchd/dispatchd
54 |
55 | Config file can be left out for the default behaviors. The data volume needs
56 | to be specified so that data is persisted outside of the container.
57 |
58 | ## Security/Auth
59 |
60 | Dispatchd uses SASL PLAIN auth as required by the amqp spec. There is a default user (user: guest, pw: guest) which is available if there is no config file. If there is a config file the user entries look like this:
61 |
62 | {
63 | "users" : {
64 | "guest" : {
65 | "password_bcrypt_base64" : "JDJhJDExJENobGk4dG5rY0RGemJhTjhsV21xR3VNNnFZZ1ZqTzUzQWxtbGtyMHRYN3RkUHMuYjF5SUt5"
66 | }
67 | }
68 | }
69 |
70 | Passwords are generated using bcrypt and then base64 encoded.
71 |
72 | ## Performance compared to RabbitMQ
73 |
74 | All perf testing is done with RabbitMQ's Java perf testing tool. Generally using this command line:
75 |
76 | ./runjava.sh com.rabbitmq.examples.PerfTest --exchange perf-test -uri amqp://guest:guest@localhost:5672 --queue some-queue --consumers 4 --producers 2 --qos 100
77 |
78 | On a late 2014 i7 mac mini the results were as follows:
79 |
80 | RabbitMQ Send: ~13000 msg/s, consistent
81 | RabbitMQ Recv: ~10000 msg/s, consistent
82 | Dispatchd Send: ~18000 msg/s, varying between 15k and 22k
83 | Dispatchd Recv: ~18000 msg/s, consistent
84 |
85 | It is unclear whether this difference in performance would go away if the server had complete feature parity with Rabbit. Based on the feature diff it isn't clear why it would, but Rabbit is highly tuned and extremely performant.
86 |
With the `--flag persistent` option (persistent messages), performance drops a bit:

    RabbitMQ Send: ~9000 msg/s, varying between 6k and 12k
    RabbitMQ Recv: ~7000 msg/s, consistent
    Dispatchd Send: ~13500 msg/s, varying between 11k and 15k
    Dispatchd Recv: ~13000 msg/s, varying between 11k and 15k
93 |
One thing to note about Dispatchd's send (publish) performance here is that it has no internal flow control, so it can get backlogged writing messages to disk. It could be that Rabbit is sustaining a steady 9k msg/s while Dispatchd accepts messages faster than it can persist them. If the server crashed in that state, Dispatchd could lose far more messages than arrive during a single coalesce interval.
95 |
On the receive (deliver) side, Dispatchd reconciles away messages that don't need to be persisted (because they have already been delivered and acked). So persistence costs nothing if all messages are delivered before the next write to disk happens (every 200ms by default).
97 |
98 | ## Testing and Code Coverage
99 |
Dispatchd has a fairly extensive test suite. Almost all of the major functions are tested, and test coverage (ignoring generated code) is around 80%.
101 |
102 | ## What's Next? How do I request changes?
103 |
104 | Non-trivial changes are tracked through [github issues](https://github.com/jeffjenkins/dispatchd/issues).
--------------------------------------------------------------------------------
/adminserver/admin.go:
--------------------------------------------------------------------------------
1 | package adminserver
2 |
3 | import (
4 | "encoding/json"
5 | "fmt"
6 | "github.com/jeffjenkins/dispatchd/server"
7 | "github.com/rcrowley/go-metrics"
8 | "net/http"
9 | "os"
10 | )
11 |
12 | func homeJSON(w http.ResponseWriter, r *http.Request, server *server.Server) {
13 | var b, err = json.MarshalIndent(server, "", " ")
14 | if err != nil {
15 | w.Write([]byte(err.Error()))
16 | }
17 | w.Write(b)
18 | }
19 |
20 | func statsJSON(w http.ResponseWriter, r *http.Request, server *server.Server) {
21 | // fmt.Println(metrics.DefaultRegistry)
22 | var b, err = json.MarshalIndent(metrics.DefaultRegistry, "", " ")
23 | if err != nil {
24 | w.Write([]byte(err.Error()))
25 | }
26 | w.Write(b)
27 | }
28 |
29 | func StartAdminServer(server *server.Server, port int) {
30 | // Static files
31 | var path = os.Getenv("STATIC_PATH")
32 | if len(path) == 0 {
33 | panic("No static file path in $STATIC_PATH!")
34 | }
35 | var fileServer = http.FileServer(http.Dir(path))
36 | http.Handle("/static/", http.StripPrefix("/static", fileServer))
37 |
38 | // Home
39 | http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
40 | var p = path + "/admin.html"
41 | fmt.Println(p)
42 | http.ServeFile(w, r, p)
43 | })
44 |
45 | // API
46 | http.HandleFunc("/api/server", func(w http.ResponseWriter, r *http.Request) {
47 | homeJSON(w, r, server)
48 | })
49 |
50 | http.HandleFunc("/api/stats", func(w http.ResponseWriter, r *http.Request) {
51 | statsJSON(w, r, server)
52 | })
53 |
54 | // Boot admin server
55 | fmt.Printf("Admin server on port %d, static files from: %s\n", port, path)
56 | http.ListenAndServe(fmt.Sprintf(":%d", port), nil)
57 | }
58 |
--------------------------------------------------------------------------------
/amqp/amqp.proto:
--------------------------------------------------------------------------------
1 |
package amqp;

import "github.com/gogo/protobuf/gogoproto/gogo.proto";

// Have gogoproto generate Marshal, Unmarshal and Size methods for every
// message in this file.
option (gogoproto.marshaler_all) = true;
option (gogoproto.unmarshaler_all) = true;
option (gogoproto.sizer_all) = true;

// One key/value entry of an AMQP field table.
message FieldValuePair {
  option (gogoproto.goproto_unrecognized) = false;
  option (gogoproto.goproto_getters) = false;
  optional string key = 1;
  optional FieldValue value = 2;
}

// An AMQP field table, stored as an ordered list of pairs rather than a map
// so that entry order is preserved.
message Table {
  option (gogoproto.goproto_unrecognized) = false;
  option (gogoproto.goproto_getters) = false;
  repeated FieldValuePair table = 1;
}

// The elements of an AMQP field array, in order.
message FieldArray {
  option (gogoproto.goproto_unrecognized) = false;
  option (gogoproto.goproto_getters) = false;
  repeated FieldValue value = 1;
}

// An AMQP decimal value: a one-octet scale plus a 32-bit value.
message Decimal {
  option (gogoproto.goproto_unrecognized) = false;
  option (gogoproto.goproto_getters) = false;
  // casttype makes the generated Go field a uint8, matching the one-octet scale.
  optional uint32 scale = 1 [(gogoproto.casttype) = "uint8"];
  optional int32 value = 2;
}

// A single AMQP field value; exactly one member of the oneof is set.
// Protobuf has no 8- or 16-bit integer types, so those widths are carried in
// 32-bit fields and cast to the narrower Go type with gogoproto.casttype.
message FieldValue {
  option (gogoproto.goproto_unrecognized) = false;
  option (gogoproto.goproto_getters) = false;
  oneof value {
    bool v_boolean = 1;
    int32 v_int8 = 2 [(gogoproto.casttype) = "int8"];
    uint32 v_uint8 = 3 [(gogoproto.casttype) = "uint8"];
    int32 v_int16 = 4 [(gogoproto.casttype) = "int16"];
    uint32 v_uint16 = 5 [(gogoproto.casttype) = "uint16"];
    int32 v_int32 = 6 ;
    uint32 v_uint32 = 7 ;
    int64 v_int64 = 8 ;
    uint64 v_uint64 = 9 ;
    float v_float = 10 ;
    double v_double = 11 ;
    Decimal v_decimal = 12 ;
    string v_shortstr = 13 ; // < 256 bytes
    // Long string (table tag 'S' in amqpreader.go).
    bytes v_longstr = 14 ;
    FieldArray v_array = 15 ;
    uint64 v_timestamp = 16 ;
    Table v_table = 17 ;
    // Byte array (table tag 'x' in amqpreader.go), kept distinct from v_longstr.
    bytes v_bytes = 18 ;
  }
}
60 |
--------------------------------------------------------------------------------
/amqp/amqpreader.go:
--------------------------------------------------------------------------------
1 | package amqp
2 |
3 | import (
4 | "bytes"
5 | "encoding/binary"
6 | "errors"
7 | "fmt"
8 | "io"
9 | )
10 |
11 | func ReadFrame(reader io.Reader) (*WireFrame, error) {
12 | // Using little since other functions will assume big
13 |
14 | // get fixed size portion
15 | var incoming = make([]byte, 1+2+4)
16 | var err = binary.Read(reader, binary.LittleEndian, incoming)
17 | if err != nil {
18 | return nil, err
19 | }
20 | var memReader = bytes.NewBuffer(incoming)
21 |
22 | var f = &WireFrame{}
23 |
24 | // The reads from memReader are guaranteed to succeed because the 7 bytes
25 | // were allocated above and that is what is read out
26 |
27 | // frame type
28 | frameType, _ := ReadOctet(memReader)
29 | f.FrameType = frameType
30 |
31 | // channel
32 | channel, _ := ReadShort(memReader)
33 | f.Channel = channel
34 |
35 | // Variable length part
36 | var length uint32
37 | err = binary.Read(memReader, binary.BigEndian, &length)
38 |
39 | var slice = make([]byte, length+1)
40 | err = binary.Read(reader, binary.BigEndian, slice)
41 | if err != nil {
42 | return nil, errors.New("Bad frame payload: " + err.Error())
43 | }
44 | f.Payload = slice[0:length]
45 | return f, nil
46 | }
47 |
48 | // Fields
49 |
// ReadOctet reads one byte from buf. A failed read yields 0 and an error
// whose text is "Could not read byte: " followed by the underlying error.
func ReadOctet(buf io.Reader) (data byte, err error) {
	if readErr := binary.Read(buf, binary.BigEndian, &data); readErr != nil {
		return 0, errors.New("Could not read byte: " + readErr.Error())
	}
	return data, nil
}
57 |
// ReadShort reads a big-endian uint16 from buf. A failed read yields 0 and
// an error whose text is "Could not read uint16: " plus the underlying error.
func ReadShort(buf io.Reader) (data uint16, err error) {
	var raw [2]byte
	if _, err = io.ReadFull(buf, raw[:]); err != nil {
		return 0, errors.New("Could not read uint16: " + err.Error())
	}
	return binary.BigEndian.Uint16(raw[:]), nil
}
65 |
// ReadLong reads a big-endian uint32 from buf. A failed read yields 0 and
// an error whose text is "Could not read uint32: " plus the underlying error.
func ReadLong(buf io.Reader) (data uint32, err error) {
	var raw [4]byte
	if _, err = io.ReadFull(buf, raw[:]); err != nil {
		return 0, errors.New("Could not read uint32: " + err.Error())
	}
	return binary.BigEndian.Uint32(raw[:]), nil
}
73 |
// ReadLonglong reads a big-endian uint64 from buf. A failed read yields 0 and
// an error whose text is "Could not read uint64: " plus the underlying error.
func ReadLonglong(buf io.Reader) (data uint64, err error) {
	var raw [8]byte
	if _, err = io.ReadFull(buf, raw[:]); err != nil {
		return 0, errors.New("Could not read uint64: " + err.Error())
	}
	return binary.BigEndian.Uint64(raw[:]), nil
}
81 |
// ReadShortstr reads an AMQP short string: a one-byte length followed by that
// many bytes of content. Read errors are returned as-is.
func ReadShortstr(buf io.Reader) (string, error) {
	var size uint8
	if err := binary.Read(buf, binary.BigEndian, &size); err != nil {
		return "", err
	}
	raw := make([]byte, int(size))
	if _, err := io.ReadFull(buf, raw); err != nil {
		return "", err
	}
	return string(raw), nil
}
95 |
// ReadLongstr reads an AMQP long string: a 4-byte big-endian length followed
// by that many raw bytes. Read errors are returned as-is.
func ReadLongstr(buf io.Reader) ([]byte, error) {
	var header [4]byte
	if _, err := io.ReadFull(buf, header[:]); err != nil {
		return nil, err
	}
	body := make([]byte, binary.BigEndian.Uint32(header[:]))
	if _, err := io.ReadFull(buf, body); err != nil {
		return nil, err
	}
	return body, nil
}
109 |
// ReadTimestamp reads a 64-bit big-endian AMQP timestamp from buf. On a
// failed read it returns 0 and the error "Could not read uint64".
//
// Can't get coverage on this easily since timestamp values can't currently be
// generated in Tables: protobuf doesn't offer a type distinct from uint64.
func ReadTimestamp(buf io.Reader) (uint64, error) { // pragma: nocover
	var stamp [8]byte
	if _, err := io.ReadFull(buf, stamp[:]); err != nil { // pragma: nocover
		return 0, errors.New("Could not read uint64")
	}
	return binary.BigEndian.Uint64(stamp[:]), nil // pragma: nocover
}
121 |
122 | func ReadTable(reader io.Reader, strictMode bool) (*Table, error) {
123 | var seen = make(map[string]bool)
124 | var table = &Table{Table: make([]*FieldValuePair, 0)}
125 | var byteData, err = ReadLongstr(reader)
126 | if err != nil {
127 | return nil, errors.New("Error reading table longstr: " + err.Error())
128 | }
129 | var data = bytes.NewBuffer(byteData)
130 | for data.Len() > 0 {
131 | key, err := ReadShortstr(data)
132 | if err != nil {
133 | return nil, errors.New("Error reading key: " + err.Error())
134 | }
135 | if _, found := seen[key]; found {
136 | return nil, fmt.Errorf("Duplicate key in table: %s", key)
137 | }
138 | value, err := readValue(data, strictMode)
139 | if err != nil {
140 | return nil, errors.New("Error reading value for '" + key + "': " + err.Error())
141 | }
142 | table.Table = append(table.Table, &FieldValuePair{Key: &key, Value: value})
143 | seen[key] = true
144 | }
145 | return table, nil
146 | }
147 |
// readValue decodes one field value from reader: a one-byte type tag followed
// by the value's encoding.
//
// strictMode changes how some tags are interpreted. With strictMode true,
// 's' is a short string and 'l' an unsigned 64-bit integer. With strictMode
// false, 's' is a signed 16-bit integer and 'l' a signed 64-bit integer
// (presumably RabbitMQ's variant — confirm). The tags 'B', 'U', 'u', 'i' and
// 'L' are only accepted in strict mode; otherwise they fall through to the
// "Unknown table value type" error.
//
// NOTE(review): the void tag 'V' returns (nil, nil), so ReadTable/readArray
// store a nil *FieldValue. Code that later dereferences the value (e.g. when
// re-encoding the table) must tolerate nil — confirm.
func readValue(reader io.Reader, strictMode bool) (*FieldValue, error) {
	// Type tag.
	var t, err = ReadOctet(reader)
	if err != nil {
		return nil, err
	}

	// Cases are tried top to bottom and the first match wins. Order matters:
	// 's' and 'l' each appear in two cases distinguished only by strictMode.
	switch {
	case t == 't':
		// Boolean: one octet, any non-zero value is true.
		var v, err = ReadOctet(reader)
		if err != nil {
			return nil, err
		}
		var vb = v != 0
		return &FieldValue{Value: &FieldValue_VBoolean{VBoolean: vb}}, nil
	case t == 'b':
		var v int8
		if err = binary.Read(reader, binary.BigEndian, &v); err != nil {
			return nil, err
		}
		return &FieldValue{Value: &FieldValue_VInt8{VInt8: v}}, nil
	case t == 'B' && strictMode:
		var v uint8
		if err = binary.Read(reader, binary.BigEndian, &v); err != nil {
			return nil, err
		}
		return &FieldValue{Value: &FieldValue_VUint8{VUint8: v}}, nil
	case t == 'U' && strictMode || t == 's' && !strictMode:
		// Signed 16-bit integer: 'U' in strict mode, 's' otherwise.
		var v int16
		if err = binary.Read(reader, binary.BigEndian, &v); err != nil {
			return nil, err
		}
		return &FieldValue{Value: &FieldValue_VInt16{VInt16: v}}, nil
	case t == 'u' && strictMode:
		var v uint16
		if err = binary.Read(reader, binary.BigEndian, &v); err != nil {
			return nil, err
		}
		return &FieldValue{Value: &FieldValue_VUint16{VUint16: v}}, nil
	case t == 'I':
		var v int32
		if err = binary.Read(reader, binary.BigEndian, &v); err != nil {
			return nil, err
		}
		return &FieldValue{Value: &FieldValue_VInt32{VInt32: v}}, nil
	case t == 'i' && strictMode:
		var v uint32
		if err = binary.Read(reader, binary.BigEndian, &v); err != nil {
			return nil, err
		}
		return &FieldValue{Value: &FieldValue_VUint32{VUint32: v}}, nil
	case t == 'L' && strictMode || t == 'l' && !strictMode:
		// Signed 64-bit integer: 'L' in strict mode, 'l' otherwise.
		var v int64
		if err = binary.Read(reader, binary.BigEndian, &v); err != nil {
			return nil, err
		}
		return &FieldValue{Value: &FieldValue_VInt64{VInt64: v}}, nil
	case t == 'l' && strictMode:
		var v uint64
		if err = binary.Read(reader, binary.BigEndian, &v); err != nil {
			return nil, err
		}
		return &FieldValue{Value: &FieldValue_VUint64{VUint64: v}}, nil
	case t == 'f':
		var v float32
		if err = binary.Read(reader, binary.BigEndian, &v); err != nil {
			return nil, err
		}
		return &FieldValue{Value: &FieldValue_VFloat{VFloat: v}}, nil
	case t == 'd':
		var v float64
		if err = binary.Read(reader, binary.BigEndian, &v); err != nil {
			return nil, err
		}
		return &FieldValue{Value: &FieldValue_VDouble{VDouble: v}}, nil
	case t == 'D':
		// Decimal: a one-octet scale followed by a 4-byte value, read
		// directly into the pointers held by the Decimal message.
		var scale uint8 = 0
		var val int32 = 0
		var v = Decimal{&scale, &val}
		if err = binary.Read(reader, binary.BigEndian, v.Scale); err != nil {
			return nil, err
		}
		if err = binary.Read(reader, binary.BigEndian, v.Value); err != nil {
			return nil, err
		}
		return &FieldValue{Value: &FieldValue_VDecimal{VDecimal: &v}}, nil
	case t == 's' && strictMode:
		v, err := ReadShortstr(reader)
		if err != nil {
			return nil, err
		}
		return &FieldValue{Value: &FieldValue_VShortstr{VShortstr: v}}, nil
	case t == 'S':
		v, err := ReadLongstr(reader)
		if err != nil {
			return nil, err
		}
		return &FieldValue{Value: &FieldValue_VLongstr{VLongstr: v}}, nil
	case t == 'A':
		v, err := readArray(reader, strictMode)
		if err != nil {
			return nil, err
		}
		return &FieldValue{Value: &FieldValue_VArray{VArray: &FieldArray{Value: v}}}, nil
	case t == 'T':
		v, err := ReadTimestamp(reader)
		if err != nil {
			return nil, err
		}
		return &FieldValue{Value: &FieldValue_VTimestamp{VTimestamp: v}}, nil
	case t == 'F':
		// Nested table.
		v, err := ReadTable(reader, strictMode)
		if err != nil {
			return nil, err
		}
		return &FieldValue{Value: &FieldValue_VTable{VTable: v}}, nil
	case t == 'V':
		// Void: no payload, and no FieldValue to represent it (see NOTE above).
		return nil, nil
	case t == 'x':
		// Byte array, length-prefixed the same way as a long string.
		v, err := ReadLongstr(reader)
		if err != nil {
			return nil, err
		}
		return &FieldValue{Value: &FieldValue_VBytes{VBytes: v}}, nil
	}
	return nil, fmt.Errorf("Unknown table value type '%c' (%d)", t, t)
}
274 |
275 | func readArray(reader io.Reader, strictMode bool) ([]*FieldValue, error) {
276 | var ret = make([]*FieldValue, 0, 0)
277 | var longstr, errs = ReadLongstr(reader)
278 | if errs != nil {
279 | return nil, errs
280 | }
281 | var data = bytes.NewBuffer(longstr)
282 | for data.Len() > 0 {
283 | var value, err = readValue(data, strictMode)
284 | if err != nil {
285 | return nil, err
286 | }
287 | ret = append(ret, value)
288 | }
289 | return ret, nil
290 | }
291 |
292 | func (props *BasicContentHeaderProperties) ReadProps(flags uint16, reader io.Reader, strictMode bool) (err error) {
293 | if MaskContentType&flags != 0 {
294 | v, err := ReadShortstr(reader)
295 | props.ContentType = &v
296 | if err != nil {
297 | return err
298 | }
299 | }
300 | if MaskContentEncoding&flags != 0 {
301 | v, err := ReadShortstr(reader)
302 | props.ContentEncoding = &v
303 | if err != nil {
304 | return err
305 | }
306 | }
307 | if MaskHeaders&flags != 0 {
308 | v, err := ReadTable(reader, strictMode)
309 | props.Headers = v
310 | if err != nil {
311 | return err
312 | }
313 | }
314 | if MaskDeliveryMode&flags != 0 {
315 | v, err := ReadOctet(reader)
316 | props.DeliveryMode = &v
317 | if err != nil {
318 | return err
319 | }
320 | }
321 | if MaskPriority&flags != 0 {
322 | v, err := ReadOctet(reader)
323 | props.Priority = &v
324 | if err != nil {
325 | return err
326 | }
327 | }
328 | if MaskCorrelationId&flags != 0 {
329 | v, err := ReadShortstr(reader)
330 | props.CorrelationId = &v
331 | if err != nil {
332 | return err
333 | }
334 | }
335 | if MaskReplyTo&flags != 0 {
336 | v, err := ReadShortstr(reader)
337 | props.ReplyTo = &v
338 | if err != nil {
339 | return err
340 | }
341 | }
342 | if MaskExpiration&flags != 0 {
343 | v, err := ReadShortstr(reader)
344 | props.Expiration = &v
345 | if err != nil {
346 | return err
347 | }
348 | }
349 | if MaskMessageId&flags != 0 {
350 | v, err := ReadShortstr(reader)
351 | props.MessageId = &v
352 | if err != nil {
353 | return err
354 | }
355 | }
356 | if MaskTimestamp&flags != 0 {
357 | v, err := ReadLonglong(reader)
358 | props.Timestamp = &v
359 | if err != nil {
360 | return err
361 | }
362 | }
363 | if MaskType&flags != 0 {
364 | v, err := ReadShortstr(reader)
365 | props.Type = &v
366 | if err != nil {
367 | return err
368 | }
369 | }
370 | if MaskUserId&flags != 0 {
371 | v, err := ReadShortstr(reader)
372 | props.UserId = &v
373 | if err != nil {
374 | return err
375 | }
376 | }
377 | if MaskAppId&flags != 0 {
378 | v, err := ReadShortstr(reader)
379 | props.AppId = &v
380 | if err != nil {
381 | return err
382 | }
383 | }
384 | if MaskReserved&flags != 0 {
385 | v, err := ReadShortstr(reader)
386 | props.Reserved = &v
387 | if err != nil {
388 | return err
389 | }
390 | }
391 | return nil
392 | }
393 |
--------------------------------------------------------------------------------
/amqp/amqpwriter.go:
--------------------------------------------------------------------------------
1 | package amqp
2 |
3 | import (
4 | "bytes"
5 | "encoding/binary"
6 | "errors"
7 | "io"
8 | )
9 |
10 | func WriteFrame(buf io.Writer, frame *WireFrame) {
11 | bb := make([]byte, 0, 1+2+4+len(frame.Payload)+2)
12 | buf2 := bytes.NewBuffer(bb)
13 | WriteOctet(buf2, frame.FrameType)
14 | WriteShort(buf2, frame.Channel)
15 | WriteLongstr(buf2, frame.Payload)
16 | // buf.Write(frame.Payload.Bytes())
17 | WriteFrameEnd(buf2)
18 |
19 | // binary.LittleEndian since we want to stick to the system
20 | // byte order and the other write functions are writing BigEndian
21 | // TODO: error checking
22 | binary.Write(buf, binary.LittleEndian, buf2.Bytes())
23 | }
24 |
25 | // Constants
26 |
// WriteProtocolHeader writes the 8-byte AMQP 0-9-1 protocol header
// ("AMQP" followed by the octets 0, 0, 9, 1) to buf.
func WriteProtocolHeader(buf io.Writer) error {
	header := append([]byte("AMQP"), 0, 0, 9, 1)
	return binary.Write(buf, binary.BigEndian, header)
}
30 |
// WriteVersion writes the two version octets 0 and 9 to buf.
func WriteVersion(buf io.Writer) error {
	const major, minor byte = 0, 9
	return binary.Write(buf, binary.BigEndian, []byte{major, minor})
}
34 |
// WriteFrameEnd writes the 0xCE octet that terminates every AMQP frame.
func WriteFrameEnd(buf io.Writer) error {
	const frameEnd byte = 0xCE
	return binary.Write(buf, binary.BigEndian, frameEnd)
}
38 |
39 | // Fields
40 |
// WriteOctet writes the single byte b to buf.
func WriteOctet(buf io.Writer, b byte) error {
	return binary.Write(buf, binary.BigEndian, []byte{b})
}
44 |
// WriteShort writes i to buf as a big-endian uint16, returning any error
// from the underlying Write.
func WriteShort(buf io.Writer, i uint16) error {
	var encoded [2]byte
	binary.BigEndian.PutUint16(encoded[:], i)
	_, err := buf.Write(encoded[:])
	return err
}
48 |
// WriteLong writes i to buf as a big-endian uint32, returning any error from
// the underlying Write.
func WriteLong(buf io.Writer, i uint32) error {
	var encoded [4]byte
	binary.BigEndian.PutUint32(encoded[:], i)
	_, err := buf.Write(encoded[:])
	return err
}
52 |
// WriteLonglong writes i to buf as a big-endian uint64, returning any error
// from the underlying Write.
func WriteLonglong(buf io.Writer, i uint64) error {
	var encoded [8]byte
	binary.BigEndian.PutUint64(encoded[:], i)
	_, err := buf.Write(encoded[:])
	return err
}
56 |
// WriteStringChar writes the single character byte b to buf.
func WriteStringChar(buf io.Writer, b byte) error {
	one := [1]byte{b}
	return binary.Write(buf, binary.BigEndian, one[:])
}
60 |
61 | func WriteShortstr(buf io.Writer, s string) error {
62 | if len(s) > int(MaxShortStringLength) {
63 | return errors.New("String too long for short string")
64 | }
65 | err := binary.Write(buf, binary.BigEndian, byte(len(s)))
66 | if err != nil {
67 | return errors.New("Could not write bytes: " + err.Error())
68 | }
69 | return binary.Write(buf, binary.BigEndian, []byte(s))
70 | }
71 |
// WriteLongstr writes data as an AMQP long string: a 4-byte big-endian length
// followed by the raw bytes.
//
// If data is too long for a 32-bit length prefix, an error is returned and
// nothing is written. Previously the length was silently truncated, which
// produced a corrupt encoding. The parameter is named data so it no longer
// shadows the imported bytes package.
func WriteLongstr(buf io.Writer, data []byte) (err error) {
	if uint64(len(data)) > 0xFFFFFFFF {
		return errors.New("Byte slice too long for long string")
	}
	if err = binary.Write(buf, binary.BigEndian, uint32(len(data))); err != nil {
		return err
	}
	return binary.Write(buf, binary.BigEndian, data)
}
81 |
// WriteTimestamp writes timestamp to buf as a big-endian uint64, returning
// any error from the underlying Write.
func WriteTimestamp(buf io.Writer, timestamp uint64) error {
	var encoded [8]byte
	binary.BigEndian.PutUint64(encoded[:], timestamp)
	_, err := buf.Write(encoded[:])
	return err
}
85 |
86 | func WriteTable(writer io.Writer, table *Table) error {
87 | var buf = bytes.NewBuffer(make([]byte, 0))
88 | for _, kv := range table.Table {
89 | if err := WriteShortstr(buf, *kv.Key); err != nil {
90 | return err
91 | }
92 | if err := writeValue(buf, kv.Value); err != nil {
93 | return err
94 | }
95 | }
96 | return WriteLongstr(writer, buf.Bytes())
97 | }
98 |
99 | func writeArray(writer io.Writer, array []*FieldValue) error {
100 | var buf = bytes.NewBuffer([]byte{})
101 | for _, v := range array {
102 | if err := writeValue(buf, v); err != nil {
103 | return err
104 | }
105 | }
106 | return WriteLongstr(writer, buf.Bytes())
107 | }
108 |
109 | func writeValue(writer io.Writer, value *FieldValue) (err error) {
110 | switch v := value.Value.(type) {
111 | case *FieldValue_VBoolean:
112 | if err = binary.Write(writer, binary.BigEndian, byte('t')); err == nil {
113 | if v.VBoolean {
114 | err = WriteOctet(writer, uint8(1))
115 | } else {
116 | err = WriteOctet(writer, uint8(0))
117 | }
118 | }
119 | case *FieldValue_VInt8:
120 | if err = binary.Write(writer, binary.BigEndian, byte('b')); err == nil {
121 | err = binary.Write(writer, binary.BigEndian, int8(v.VInt8))
122 | }
123 | case *FieldValue_VUint8:
124 | if err = binary.Write(writer, binary.BigEndian, byte('B')); err == nil {
125 | err = binary.Write(writer, binary.BigEndian, uint8(v.VUint8))
126 | }
127 | case *FieldValue_VInt16:
128 | if err = binary.Write(writer, binary.BigEndian, byte('U')); err == nil {
129 | err = binary.Write(writer, binary.BigEndian, int16(v.VInt16))
130 | }
131 | case *FieldValue_VUint16:
132 | if err = binary.Write(writer, binary.BigEndian, byte('u')); err == nil {
133 | err = binary.Write(writer, binary.BigEndian, uint16(v.VUint16))
134 | }
135 | case *FieldValue_VInt32:
136 | if err = binary.Write(writer, binary.BigEndian, byte('I')); err == nil {
137 | err = binary.Write(writer, binary.BigEndian, int32(v.VInt32))
138 | }
139 | case *FieldValue_VUint32:
140 | if err = binary.Write(writer, binary.BigEndian, byte('i')); err == nil {
141 | err = binary.Write(writer, binary.BigEndian, uint32(v.VUint32))
142 | }
143 | case *FieldValue_VInt64:
144 | if err = binary.Write(writer, binary.BigEndian, byte('L')); err == nil {
145 | err = binary.Write(writer, binary.BigEndian, int64(v.VInt64))
146 | }
147 | case *FieldValue_VUint64:
148 | if err = binary.Write(writer, binary.BigEndian, byte('l')); err == nil {
149 | err = binary.Write(writer, binary.BigEndian, uint64(v.VUint64))
150 | }
151 | case *FieldValue_VFloat:
152 | if err = binary.Write(writer, binary.BigEndian, byte('f')); err == nil {
153 | err = binary.Write(writer, binary.BigEndian, float32(v.VFloat))
154 | }
155 | case *FieldValue_VDouble:
156 | if err = binary.Write(writer, binary.BigEndian, byte('d')); err == nil {
157 | err = binary.Write(writer, binary.BigEndian, float64(v.VDouble))
158 | }
159 | case *FieldValue_VDecimal:
160 | if err = binary.Write(writer, binary.BigEndian, byte('D')); err == nil {
161 | if err = binary.Write(writer, binary.BigEndian, byte(*v.VDecimal.Scale)); err == nil {
162 | err = binary.Write(writer, binary.BigEndian, uint32(*v.VDecimal.Value))
163 | }
164 | }
165 | case *FieldValue_VShortstr:
166 | if err = WriteOctet(writer, byte('s')); err == nil {
167 | err = WriteShortstr(writer, v.VShortstr)
168 | }
169 | case *FieldValue_VLongstr:
170 | if err = WriteOctet(writer, byte('S')); err == nil {
171 | err = WriteLongstr(writer, v.VLongstr)
172 | }
173 | case *FieldValue_VArray:
174 | if err = WriteOctet(writer, byte('A')); err == nil {
175 | err = writeArray(writer, v.VArray.Value)
176 | }
177 | case *FieldValue_VTimestamp:
178 | if err = WriteOctet(writer, byte('T')); err == nil {
179 | err = WriteTimestamp(writer, v.VTimestamp)
180 | }
181 | case *FieldValue_VTable:
182 | if err = WriteOctet(writer, byte('F')); err == nil {
183 | err = WriteTable(writer, v.VTable)
184 | }
185 | case nil:
186 | err = binary.Write(writer, binary.BigEndian, byte('V'))
187 | default:
188 | panic("unsupported type!")
189 | }
190 | return
191 | }
192 |
193 | func (props *BasicContentHeaderProperties) WriteProps(writer io.Writer) (flags uint16, err error) {
194 | if props.ContentType != nil {
195 | flags = flags | MaskContentType
196 | err = WriteShortstr(writer, *props.ContentType)
197 | if err != nil {
198 | return
199 | }
200 | }
201 | if props.ContentEncoding != nil {
202 | flags = flags | MaskContentEncoding
203 | err = WriteShortstr(writer, *props.ContentEncoding)
204 | if err != nil {
205 | return
206 | }
207 | }
208 | if props.Headers != nil {
209 | flags = flags | MaskHeaders
210 | err = WriteTable(writer, props.Headers)
211 | if err != nil {
212 | return
213 | }
214 | }
215 | if props.DeliveryMode != nil {
216 | flags = flags | MaskDeliveryMode
217 | err = WriteOctet(writer, *props.DeliveryMode)
218 | if err != nil {
219 | return
220 | }
221 | }
222 | if props.Priority != nil {
223 | flags = flags | MaskPriority
224 | err = WriteOctet(writer, *props.Priority)
225 | if err != nil {
226 | return
227 | }
228 | }
229 | if props.CorrelationId != nil {
230 | flags = flags | MaskCorrelationId
231 | err = WriteShortstr(writer, *props.CorrelationId)
232 | if err != nil {
233 | return
234 | }
235 | }
236 | if props.ReplyTo != nil {
237 | flags = flags | MaskReplyTo
238 | err = WriteShortstr(writer, *props.ReplyTo)
239 | if err != nil {
240 | return
241 | }
242 | }
243 | if props.Expiration != nil {
244 | flags = flags | MaskExpiration
245 | err = WriteShortstr(writer, *props.Expiration)
246 | if err != nil {
247 | return
248 | }
249 | }
250 | if props.MessageId != nil {
251 | flags = flags | MaskMessageId
252 | err = WriteShortstr(writer, *props.MessageId)
253 | if err != nil {
254 | return
255 | }
256 | }
257 | if props.Timestamp != nil {
258 | flags = flags | MaskTimestamp
259 | err = WriteLonglong(writer, *props.Timestamp)
260 | if err != nil {
261 | return
262 | }
263 | }
264 | if props.Type != nil {
265 | flags = flags | MaskType
266 | err = WriteShortstr(writer, *props.Type)
267 | if err != nil {
268 | return
269 | }
270 | }
271 | if props.UserId != nil {
272 | flags = flags | MaskUserId
273 | err = WriteShortstr(writer, *props.UserId)
274 | if err != nil {
275 | return
276 | }
277 | }
278 | if props.AppId != nil {
279 | flags = flags | MaskAppId
280 | err = WriteShortstr(writer, *props.AppId)
281 | if err != nil {
282 | return
283 | }
284 | }
285 | if props.Reserved != nil {
286 | flags = flags | MaskReserved
287 | err = WriteShortstr(writer, *props.Reserved)
288 | if err != nil {
289 | return
290 | }
291 | }
292 | return
293 | }
294 |
--------------------------------------------------------------------------------
/amqp/constants_generated.go:
--------------------------------------------------------------------------------
1 | package amqp
2 |
// AMQP 0-9-1 protocol constants and reply codes. They are package-level
// variables (not consts). NOTE(review): this file looks generated (the
// _generated suffix), so edits here are presumably overwritten on regeneration.
var (
	// MaxShortStringLength is the largest length a one-octet short-string
	// length prefix can express.
	MaxShortStringLength uint8 = 255

	// Frame type identifiers and frame-level limits/markers.
	FrameMethod    = 1
	FrameHeader    = 2
	FrameBody      = 3
	FrameHeartbeat = 8
	FrameMinSize   = 4096
	FrameEnd       = 206

	// Indicates that the method completed successfully. This reply code is
	// reserved for future use - the current protocol design does not use positive
	// confirmation and reply codes are sent only in case of an error.
	ReplySuccess = 200

	// The client attempted to transfer content larger than the server could accept
	// at the present time. The client may retry at a later time.
	ContentTooLarge = 311

	// When the exchange cannot deliver to a consumer when the immediate flag is
	// set. As a result of pending data on the queue or the absence of any
	// consumers of the queue.
	NoConsumers = 313

	// An operator intervened to close the connection for some reason. The client
	// may retry at some later date.
	ConnectionForced = 320

	// The client tried to work with an unknown virtual host.
	InvalidPath = 402

	// The client attempted to work with a server entity to which it has no
	// access due to security settings.
	AccessRefused = 403

	// The client attempted to work with a server entity that does not exist.
	NotFound = 404

	// The client attempted to work with a server entity to which it has no
	// access because another client is working with it.
	ResourceLocked = 405

	// The client requested a method that was not allowed because some precondition
	// failed.
	PreconditionFailed = 406

	// The sender sent a malformed frame that the recipient could not decode.
	// This strongly implies a programming error in the sending peer.
	FrameError = 501

	// The sender sent a frame that contained illegal values for one or more
	// fields. This strongly implies a programming error in the sending peer.
	SyntaxError = 502

	// The client sent an invalid sequence of frames, attempting to perform an
	// operation that was considered invalid by the server. This usually implies
	// a programming error in the client.
	CommandInvalid = 503

	// The client attempted to work with a channel that had not been correctly
	// opened. This most likely indicates a fault in the client layer.
	ChannelError = 504

	// The peer sent a frame that was not expected, usually in the context of
	// a content header and body. This strongly indicates a fault in the peer's
	// content processing.
	UnexpectedFrame = 505

	// The server could not complete the method because it lacked sufficient
	// resources. This may be due to the client creating too many of some type
	// of entity.
	ResourceError = 506

	// The client tried to work with some entity in a manner that is prohibited
	// by the server, due to security settings or by some other criteria.
	NotAllowed = 530

	// The client tried to use functionality that is not implemented in the
	// server.
	NotImplemented = 540

	// The server could not complete the method because of an internal error.
	// The server may require intervention by an operator in order to resume
	// normal operations.
	InternalError = 541
)
86 |
--------------------------------------------------------------------------------
/amqp/domains_generated.go:
--------------------------------------------------------------------------------
1 | package amqp
2 |
3 | var ReadMethodId = ReadShort
4 | var WriteMethodId = WriteShort
5 | var ReadReplyText = ReadShortstr
6 | var WriteReplyText = WriteShortstr
7 | var ReadPeerProperties = ReadTable
8 | var WritePeerProperties = WriteTable
9 | var ReadClassId = ReadShort
10 | var WriteClassId = WriteShort
11 | var ReadMessageCount = ReadLong
12 | var WriteMessageCount = WriteLong
13 | var ReadReplyCode = ReadShort
14 | var WriteReplyCode = WriteShort
15 | var ReadQueueName = ReadShortstr
16 | var WriteQueueName = WriteShortstr
17 | var ReadConsumerTag = ReadShortstr
18 | var WriteConsumerTag = WriteShortstr
19 | var ReadPath = ReadShortstr
20 | var WritePath = WriteShortstr
21 | var ReadExchangeName = ReadShortstr
22 | var WriteExchangeName = WriteShortstr
23 | var ReadDeliveryTag = ReadLonglong
24 | var WriteDeliveryTag = WriteLonglong
25 |
--------------------------------------------------------------------------------
/amqp/errors.go:
--------------------------------------------------------------------------------
1 | package amqp
2 |
// AMQPError describes a protocol error to be reported to the peer.
//
// A soft error (Soft == true) closes only the channel; a hard error
// (Soft == false) closes the whole connection.
type AMQPError struct {
	Code   uint16 // reply code
	Class  uint16 // class id associated with the error
	Method uint16 // method id associated with the error
	Msg    string // reply text
	Soft   bool   // true: close channel; false: close connection
}

// NewSoftError builds a channel-level (soft) AMQPError.
func NewSoftError(code uint16, msg string, class uint16, method uint16) *AMQPError {
	e := &AMQPError{Code: code, Class: class, Method: method, Msg: msg}
	e.Soft = true
	return e
}

// NewHardError builds a connection-level (hard) AMQPError.
func NewHardError(code uint16, msg string, class uint16, method uint16) *AMQPError {
	e := NewSoftError(code, msg, class, method)
	e.Soft = false
	return e
}
34 |
--------------------------------------------------------------------------------
/amqp/messages.proto:
--------------------------------------------------------------------------------
1 |
// Without an explicit syntax statement protoc warns and silently defaults to
// proto2; state it so the intent is unambiguous (generated code is unchanged).
syntax = "proto2";

package amqp;

import "github.com/gogo/protobuf/gogoproto/gogo.proto";
import "github.com/jeffjenkins/dispatchd/amqp/protocol_generated.proto";

option (gogoproto.marshaler_all) = true;
option (gogoproto.unmarshaler_all) = true;
option (gogoproto.sizer_all) = true;

// One AMQP frame: frame type, channel number and raw payload. The protobuf
// fields are uint32, but the generated Go fields are cast to uint8 and uint16.
message WireFrame {
  option (gogoproto.goproto_unrecognized) = false;
  option (gogoproto.goproto_getters) = false;
  optional uint32 frameType = 1 [(gogoproto.nullable) = false, (gogoproto.casttype) = "uint8"];
  optional uint32 channel = 2 [(gogoproto.nullable) = false, (gogoproto.casttype) = "uint16"];
  optional bytes payload = 3 ;
}

message IndexMessage {
  // The ID of the underlying message
  optional int64 id = 1 [(gogoproto.nullable) = false];
  // The number of outstanding references to the message
  optional int32 refs = 2 [(gogoproto.nullable) = false];
  optional bool durable = 3 [(gogoproto.nullable) = false];
  optional int32 deliveryCount = 4 [(gogoproto.nullable) = false];
  optional bool persisted = 5 [(gogoproto.nullable) = false];
}

// A published message: the publish method, its content header, and the body
// frames that make up its payload.
message Message {
  optional int64 id = 1 [(gogoproto.nullable) = false];
  optional amqp.ContentHeaderFrame header = 2;
  repeated WireFrame payload = 3;
  // Copied from the BasicPublish method's exchange and routing key.
  optional string exchange = 4 [(gogoproto.nullable) = false];
  optional string key = 5 [(gogoproto.nullable) = false];
  optional amqp.BasicPublish method = 6;
  optional uint32 redelivered = 7 [(gogoproto.nullable) = false];
  optional int64 local_id = 8 [(gogoproto.nullable) = false];
}

message QueueMessage {
  optional int64 id = 1 [(gogoproto.nullable) = false];
  optional int32 deliveryCount = 2 [(gogoproto.nullable) = false];
  optional bool durable = 3 [(gogoproto.nullable) = false];
  optional uint32 msgSize = 4 [(gogoproto.nullable) = false];
  optional int64 localId = 5 [(gogoproto.nullable) = false];
}

// Content header of a message. property_flags is the bitmask describing
// which optional properties are present.
message ContentHeaderFrame {
  optional uint32 content_class = 1 [(gogoproto.casttype) = "uint16", (gogoproto.nullable) = false];
  optional uint32 content_weight = 2 [(gogoproto.casttype) = "uint16", (gogoproto.nullable) = false];
  optional uint64 content_body_size = 3 [(gogoproto.nullable) = false];
  optional uint32 property_flags = 4 [(gogoproto.casttype) = "uint16", (gogoproto.nullable) = false];
  optional amqp.BasicContentHeaderProperties properties = 5;
}

// A message published inside a transaction, with the queue it targets.
message TxMessage {
  optional Message msg = 1;
  optional string queue_name = 2 [(gogoproto.nullable) = false];
}

// An ack or nack issued inside a transaction.
message TxAck {
  optional uint64 tag = 1 [(gogoproto.nullable) = false];
  optional bool multiple = 2 [(gogoproto.nullable) = false];
  optional bool nack = 3 [(gogoproto.nullable) = false];
  optional bool requeue_nack = 4 [(gogoproto.nullable) = false];
}

message UnackedMessage {
  optional string consumer_tag = 1 [(gogoproto.nullable) = false];
  optional QueueMessage msg = 2;
  optional string queue_name = 3 [(gogoproto.nullable) = false];
}
73 |
--------------------------------------------------------------------------------
/amqp/readwrite_test.go:
--------------------------------------------------------------------------------
1 | package amqp
2 |
3 | import (
4 | "bytes"
5 | "fmt"
6 | "math/rand"
7 | "reflect"
8 | "testing"
9 | )
10 |
// testRand is a deterministically seeded random source for tests, so that
// any randomized test input is reproducible across runs.
var testRand = rand.New(rand.NewSource(1234))
17 |
18 | func TestMalformedTable(t *testing.T) {
19 | var table = NewTable()
20 | table.SetKey("hi", "bye")
21 | table.SetKey("mi", "bye")
22 | var method = &ExchangeBind{
23 | Destination: "dest",
24 | Source: "src",
25 | RoutingKey: "rk",
26 | NoWait: true,
27 | Arguments: table,
28 | }
29 | var outBuf = bytes.NewBuffer([]byte{})
30 | err := method.Write(outBuf)
31 | if err != nil {
32 | t.Errorf(err.Error())
33 | }
34 | outBytes := outBuf.Bytes()[4:]
35 | // Use this to see which bytes to alter
36 | printWireBytes(outBytes, t)
37 |
38 | // make key too long
39 | outBytes[19] = 250
40 | var inMethod = &ExchangeBind{}
41 | err = inMethod.Read(bytes.NewBuffer(outBytes), true)
42 | if err == nil {
43 | t.Errorf("Successfully read malformed bytes!")
44 | }
45 | outBytes[19] = 2
46 |
47 | // make value the wrong type
48 | outBytes[22] = 'S'
49 | inMethod = &ExchangeBind{}
50 | err = inMethod.Read(bytes.NewBuffer(outBytes), true)
51 | if err == nil {
52 | t.Errorf("Successfully read malformed bytes!")
53 | }
54 | outBytes[22] = 's'
55 |
56 | // duplicate keys
57 | outBytes[28] = 'h'
58 | inMethod = &ExchangeBind{}
59 | err = inMethod.Read(bytes.NewBuffer(outBytes), true)
60 | if err == nil {
61 | t.Errorf("Successfully read malformed bytes!")
62 | }
63 | outBytes[28] = 'm'
64 |
65 | // can't read value type
66 | inMethod = &ExchangeBind{}
67 | outBytes[27] = 7
68 | err = inMethod.Read(bytes.NewBuffer(outBytes), true)
69 | if err == nil {
70 | t.Errorf("Successfully read malformed bytes!")
71 | }
72 | outBytes[27] = 2
73 |
74 | // t.Errorf("fail")
75 | }
76 |
77 | func TestReadValue(t *testing.T) {
78 | tryRead := func(bts []byte, msg string) {
79 | _, err := readValue(bytes.NewBuffer(bts), true)
80 | if err == nil {
81 | t.Errorf(msg)
82 | }
83 | }
84 | var types = []byte{'t', 'b', 'B', 'U', 'u', 'I', 'i', 'L', 'l',
85 | 'f', 'd', 'D', 's', 'S', 'A', 'T', 'F'}
86 | // 'V' isn't included since it has no value
87 | for i := 0; i < len(types); i++ {
88 | tryRead([]byte{types[i]}, fmt.Sprintf("Successfully read malformed value for type %c", types[i]))
89 | }
90 | tryRead([]byte{'D', 1}, "read malformed decimal")
91 |
92 | tryRead([]byte{'~', 1}, "read bad value type")
93 |
94 | // successful timestamp read since the server doesn't really support them
95 | val, err := readValue(bytes.NewBuffer([]byte{'T', 0, 0, 0, 0, 0, 0, 0, 2}), true)
96 | if err != nil {
97 | t.Errorf(err.Error())
98 | }
99 | if val.GetVTimestamp() != uint64(2) {
100 | t.Errorf("Failed to deserialize uint64")
101 | }
102 | // successful 'V' (no value) read, mainly for coverage
103 | val, err = readValue(bytes.NewBuffer([]byte{'V'}), true)
104 | if err != nil {
105 | t.Errorf(err.Error())
106 | }
107 | if val != nil {
108 | t.Errorf("Failed to deserialize 'V' field ")
109 | }
110 | // tryRead()
111 |
112 | }
113 |
// sptr returns a pointer to a fresh copy of s, handy for populating optional
// (pointer-typed) string fields in test fixtures.
func sptr(s string) *string {
	p := new(string)
	*p = s
	return p
}

// bptr returns a pointer to a fresh copy of b.
func bptr(b byte) *byte {
	p := new(byte)
	*p = b
	return p
}
121 |
122 | // func allFlags() uint16 {
123 | // return (MaskContentType | MaskContentEncoding | MaskHeaders |
124 | // MaskDeliveryMode | MaskPriority | MaskCorrelationId | MaskReplyTo |
125 | // MaskExpiration | MaskMessageId | MaskTimestamp | MaskType | MaskUserId |
126 | // MaskAppId | MaskReserved)
127 | // }
128 |
129 | func TestReadingContentHeaderProps(t *testing.T) {
130 | time := uint64(1312312)
131 | var props = BasicContentHeaderProperties{
132 | ContentType: sptr("ContentType"),
133 | ContentEncoding: sptr("ContentEncoding"),
134 | Headers: NewTable(),
135 | DeliveryMode: bptr(byte(1)),
136 | Priority: bptr(byte(1)),
137 | CorrelationId: sptr("CorrelationId"),
138 | ReplyTo: sptr("ReplyTo"),
139 | Expiration: sptr("Expiration"),
140 | MessageId: sptr("MessageId"),
141 | Timestamp: &time,
142 | Type: sptr("Type"),
143 | UserId: sptr("UserId"),
144 | AppId: sptr("AppId"),
145 | Reserved: sptr(""),
146 | }
147 | var outBuf = bytes.NewBuffer([]byte{})
148 | flags, err := props.WriteProps(outBuf)
149 | if err != nil {
150 | t.Errorf(err.Error())
151 | }
152 | outBytes := outBuf.Bytes()
153 | // Use subsets of bytes to trigger all failure conditions
154 | for i := 0; i < len(outBytes); i++ {
155 | var partialBuffer = bytes.NewBuffer(outBytes[:i])
156 | var inProps = &BasicContentHeaderProperties{}
157 | err = inProps.ReadProps(flags, partialBuffer, true)
158 | if err == nil {
159 | t.Errorf("Successfully read malformed props. %d/%d bytes read", i, len(outBytes))
160 | }
161 | }
162 | // Succeed in reading all bytes
163 | var partialBuffer = bytes.NewBuffer(outBytes)
164 | var inProps = &BasicContentHeaderProperties{}
165 | err = inProps.ReadProps(flags, partialBuffer, true)
166 | if err != nil {
167 | t.Errorf(err.Error())
168 | }
169 | // check a few random fields
170 | if *inProps.ContentType != "ContentType" {
171 | t.Errorf("Bad content type: %s", *inProps.ContentType)
172 | }
173 | if *inProps.Timestamp != time {
174 | t.Errorf("Bad timestamp: %s", *inProps.Timestamp)
175 | }
176 | }
177 |
178 | func TestReadArrayFailures(t *testing.T) {
179 | _, err := readArray(bytes.NewBuffer([]byte{0, 0, 0, 1, 'L'}), true)
180 | if err == nil {
181 | t.Errorf("Read a malformed array value")
182 | }
183 | }
184 |
185 | func TestWireFrame(t *testing.T) {
186 | // Write frame to bytes
187 | var outFrame = &WireFrame{
188 | FrameType: uint8(10),
189 | Channel: uint16(12311),
190 | Payload: []byte{0, 0, 9, 1},
191 | }
192 | var buf = bytes.NewBuffer(make([]byte, 0))
193 | WriteFrame(buf, outFrame)
194 |
195 | // Read frame from bytes
196 | var outBytes = buf.Bytes()
197 | var inFrame, err = ReadFrame(bytes.NewBuffer(outBytes))
198 | if err != nil {
199 | t.Errorf(err.Error())
200 | }
201 | if !reflect.DeepEqual(inFrame, outFrame) {
202 | t.Errorf("Couldn't read the frame that was written")
203 | }
204 | // Incomplete frames
205 | for i := 0; i < len(outBytes); i++ {
206 | var noTypeBuf = bytes.NewBuffer(outBytes[:i])
207 | _, err = ReadFrame(noTypeBuf)
208 | if err == nil {
209 | t.Errorf("No error on malformed frame. %d/%d bytes read", i, len(buf.Bytes()))
210 | }
211 | }
212 | }
213 |
214 | func TestMethodTypes(t *testing.T) {
215 | for _, method := range methodsForTesting() {
216 | var outBuf = bytes.NewBuffer([]byte{})
217 | err := method.Write(outBuf)
218 | if err != nil {
219 | t.Errorf(err.Error())
220 | }
221 |
222 | var outBytes = outBuf.Bytes()[4:]
223 | // Try all lengths of bytes below the ones needed
224 | for index, _ := range outBytes {
225 | var inBind = reflect.New(reflect.TypeOf(method).Elem()).Interface().(MethodFrame)
226 | err = inBind.Read(bytes.NewBuffer(outBytes[:index]), true)
227 | if err == nil {
228 | printWireBytes(outBytes[:index], t)
229 | t.Errorf("Parsed malformed request bytes")
230 | return
231 | }
232 | }
233 | // printWireBytes(outBytes, t)
234 | // Try the right set of bytes
235 | var inBind = reflect.New(reflect.TypeOf(method).Elem()).Interface().(MethodFrame)
236 | err = inBind.Read(bytes.NewBuffer(outBytes), true)
237 | if err != nil {
238 | t.Logf("Method is %s", method.MethodName())
239 | printWireBytes(outBytes, t)
240 | t.Errorf(err.Error())
241 | return
242 | }
243 | }
244 | }
245 |
246 | func methodsForTesting() []MethodFrame {
247 | return []MethodFrame{
248 | &ExchangeBind{
249 | Destination: "dest",
250 | Source: "src",
251 | RoutingKey: "rk",
252 | NoWait: true,
253 | Arguments: EverythingTable(),
254 | },
255 | &ConnectionTune{
256 | ChannelMax: uint16(3),
257 | FrameMax: uint32(1),
258 | Heartbeat: uint16(2),
259 | },
260 | &BasicDeliver{
261 | ConsumerTag: string("deliver"),
262 | DeliveryTag: uint64(4),
263 | Redelivered: true,
264 | Exchange: string("ex1"),
265 | RoutingKey: string("rk1"),
266 | },
267 | }
268 | }
269 |
// printWireBytes logs bs one byte per line as "index:(char decimal)". This is
// useful when picking byte offsets to corrupt in malformed-input tests.
func printWireBytes(bs []byte, t *testing.T) {
	t.Logf("Byte count: %d", len(bs))
	for idx := 0; idx < len(bs); idx++ {
		b := bs[idx]
		t.Logf("%d:(%c %d),", idx, b, b)
	}
	t.Logf("\n")
}
277 |
--------------------------------------------------------------------------------
/amqp/table.go:
--------------------------------------------------------------------------------
1 | package amqp
2 |
3 | import (
4 | "fmt"
5 | "reflect"
6 | )
7 |
8 | func NewTable() *Table {
9 | return &Table{Table: make([]*FieldValuePair, 0)}
10 | }
11 |
12 | func EquivalentTables(t1 *Table, t2 *Table) bool {
13 | if len(t1.Table) == 0 && len(t2.Table) == 0 {
14 | return true
15 | }
16 | return reflect.DeepEqual(t1, t2)
17 | }
18 |
19 | func (table *Table) GetKey(key string) *FieldValue {
20 | for _, kv := range table.Table {
21 | var k = *kv.Key
22 | if k == key {
23 | return kv.Value
24 | }
25 | }
26 | return nil
27 | }
28 |
29 | func (table *Table) SetKey(key string, value interface{}) error {
30 | var fieldValue *FieldValue = nil
31 | for _, kv := range table.Table {
32 | var k = *kv.Key
33 | if k == key {
34 | fieldValue = kv.Value
35 | }
36 | }
37 | if fieldValue == nil {
38 | fieldValue = &FieldValue{
39 | Value: nil,
40 | }
41 | table.Table = append(table.Table, &FieldValuePair{
42 | Key: &key,
43 | Value: fieldValue,
44 | })
45 | }
46 | fv, err := calcValue(value)
47 |
48 | if err != nil {
49 | return err
50 | }
51 | fieldValue.Value = fv
52 | return nil
53 | }
54 |
55 | func calcValue(value interface{}) (isFieldValue_Value, error) {
56 | switch v := value.(type) {
57 | case bool:
58 | return &FieldValue_VBoolean{VBoolean: v}, nil
59 | case int8:
60 | return &FieldValue_VInt8{VInt8: v}, nil
61 | case uint8:
62 | return &FieldValue_VUint8{VUint8: v}, nil
63 | case int16:
64 | return &FieldValue_VInt16{VInt16: v}, nil
65 | case uint16:
66 | return &FieldValue_VUint16{VUint16: v}, nil
67 | case int32:
68 | return &FieldValue_VInt32{VInt32: v}, nil
69 | case uint32:
70 | return &FieldValue_VUint32{VUint32: v}, nil
71 | case int64:
72 | return &FieldValue_VInt64{VInt64: v}, nil
73 | case uint64:
74 | return &FieldValue_VUint64{VUint64: v}, nil
75 | case float32:
76 | return &FieldValue_VFloat{VFloat: v}, nil
77 | case float64:
78 | return &FieldValue_VDouble{VDouble: v}, nil
79 | case *Decimal:
80 | return &FieldValue_VDecimal{VDecimal: v}, nil
81 | case string:
82 | return &FieldValue_VShortstr{VShortstr: v}, nil
83 | case []byte:
84 | return &FieldValue_VLongstr{VLongstr: v}, nil
85 | case *FieldArray:
86 | return &FieldValue_VArray{VArray: v}, nil
87 | // TODO: not currently reachable since uint64 will take it
88 | // case timestamp:
89 | // return &FieldValue_VTimestamp{VTimestamp: v}
90 | case *Table:
91 | return &FieldValue_VTable{VTable: v}, nil
92 | }
93 | return nil, fmt.Errorf("Field value has invalid type for table/array: %s", value)
94 | }
95 |
96 | func NewFieldArray() *FieldArray {
97 | return &FieldArray{Value: make([]*FieldValue, 0)}
98 | }
99 |
100 | func (fa *FieldArray) AppendFA(value interface{}) error {
101 | var fv, err = calcValue(value)
102 | if err != nil {
103 | return err
104 | }
105 | fa.Value = append(fa.Value, &FieldValue{Value: fv})
106 | return nil
107 | }
108 |
--------------------------------------------------------------------------------
/amqp/table_test.go:
--------------------------------------------------------------------------------
1 | package amqp
2 |
3 | import (
4 | "bytes"
5 | "github.com/gogo/protobuf/proto"
6 | "math/rand"
7 | "reflect"
8 | "testing"
9 | )
10 |
11 | func TestPersistRoundtripEmpty(t *testing.T) {
12 | // Protobuf returns a nil Table.Table if there are no entries.
13 | // This test makes sure equality with a NewTable() table works
14 | var table = NewTable()
15 | var bb, _ = proto.Marshal(table)
16 | var table2 = &Table{}
17 | proto.Unmarshal(bb, table2)
18 | if !EquivalentTables(table, table2) {
19 | t.Errorf("%#v, %#v\n", table, table2)
20 | }
21 | }
22 |
23 | func TestNilTableSet(t *testing.T) {
24 | // Protobuf returns a nil Table.Table if there are no entries.
25 | // This test makes sure setting keys on a nil table works
26 | table := Table{Table: nil}
27 | err := table.SetKey("a", uint32(1))
28 | if err != nil {
29 | t.Errorf(err.Error())
30 | }
31 | val := table.GetKey("a")
32 | if val == nil {
33 | t.Errorf("No value returned for key 'a'")
34 | }
35 | }
36 |
37 | func TestBasicFieldArray(t *testing.T) {
38 | // TODO: test this more thoroughly. it gets some working out in other tests
39 | // but this could do more.
40 | var fa = NewFieldArray()
41 | err := fa.AppendFA(make(map[bool]bool))
42 | if err == nil {
43 | t.Errorf("No error with bad append value")
44 | }
45 | }
46 |
47 | func TestBasicTable(t *testing.T) {
48 | // Create
49 | var table = NewTable()
50 | table.SetKey("product", "dispatchd")
51 | table.SetKey("version", uint8(7))
52 | table.SetKey("version", uint8(6)) // for code coverage, reset a value
53 | err := table.SetKey("bad", make(map[bool]bool)) // for code coverage, a type it doesn't understand
54 | if err == nil {
55 | t.Errorf("No error on bad set value")
56 | }
57 |
58 | var fv = table.GetKey("version")
59 | if fv.Value.(*FieldValue_VUint8).VUint8 != uint8(6) {
60 | t.Errorf("Didn't get the right key from table")
61 | }
62 | if table.GetKey("DOES NOT EXIST") != nil {
63 | t.Errorf("Found key that shouldn't exist!")
64 | }
65 |
66 | }
67 |
68 | func TestTableTypes(t *testing.T) {
69 | var inTable = EverythingTable()
70 |
71 | // Encode
72 | writer := bytes.NewBuffer(make([]byte, 0))
73 | err := WriteTable(writer, inTable)
74 | if err != nil {
75 | t.Errorf(err.Error())
76 | }
77 |
78 | // decode
79 | var reader = bytes.NewReader(writer.Bytes())
80 | outTable, err := ReadTable(reader, true)
81 | if err != nil {
82 | t.Errorf(err.Error())
83 | }
84 |
85 | // compare
86 | if !EquivalentTables(inTable, outTable) {
87 | t.Errorf("Tables no equal")
88 | }
89 | }
90 |
91 | func (table *Table) Generator(rand *rand.Rand, size int) reflect.Value {
92 | return reflect.ValueOf(EverythingTable())
93 | }
94 |
95 | func EverythingTable() *Table {
96 | var inTable = NewTable()
97 |
98 | // Basic types
99 | inTable.SetKey("bool", true)
100 | inTable.SetKey("int8", int8(-2))
101 | inTable.SetKey("uint8", uint8(3))
102 | inTable.SetKey("int16", int16(-4))
103 | inTable.SetKey("uint16", uint16(5))
104 | inTable.SetKey("int32", int32(-6))
105 | inTable.SetKey("uint32", uint32(7))
106 | inTable.SetKey("int64", int64(-8))
107 | inTable.SetKey("uint64", uint64(9))
108 | inTable.SetKey("float32", float32(10.1))
109 | inTable.SetKey("float64", float64(-11.2))
110 | inTable.SetKey("string", "string value")
111 | inTable.SetKey("[]byte", []byte{14, 15, 16, 17})
112 | // TODO: timestamp
113 | // Decimal
114 | var scale = uint8(12)
115 | var value = int32(-13)
116 | inTable.SetKey("*Decimal", &Decimal{&scale, &value})
117 | // Field Array
118 | var fa = NewFieldArray()
119 | fa.AppendFA(int8(101))
120 | inTable.SetKey("*FieldArray", fa)
121 | // Table
122 | var innerTable = NewTable()
123 | innerTable.SetKey("some key", "some value")
124 | inTable.SetKey("*Table", innerTable)
125 | return inTable
126 | }
127 |
--------------------------------------------------------------------------------
/amqp/testlib.go:
--------------------------------------------------------------------------------
1 | package amqp
2 |
3 | import (
4 | // "fmt"
5 | "math/rand"
6 | "sync/atomic"
7 | "time"
8 | )
9 |
// counter backs nextId and is only accessed atomically.
var counter int64

func init() {
	// Seed the process-wide math/rand source. NOTE(review): testlib.go is not
	// a _test.go file, so this also runs in non-test builds of the package.
	rand.Seed(time.Now().UnixNano())
	atomic.StoreInt64(&counter, 0)
}

// nextId returns the next value of a process-wide counter. The first call
// returns 1. It is safe for concurrent use.
func nextId() int64 {
	id := atomic.AddInt64(&counter, 1)
	return id
}
20 |
21 | func RandomMessage(persistent bool) *Message {
22 | var size, payload = RandomPayload()
23 | var randomHeader = RandomHeader(size, persistent)
24 | var exchange = "exchange-name"
25 | var routingKey = "routing-key"
26 | var randomPublish = &BasicPublish{
27 | Exchange: exchange,
28 | RoutingKey: routingKey,
29 | Mandatory: false,
30 | Immediate: false,
31 | }
32 | return &Message{
33 | Id: nextId(),
34 | Header: randomHeader,
35 | Payload: payload,
36 | Exchange: exchange,
37 | Key: routingKey,
38 | Method: randomPublish,
39 | Redelivered: 0,
40 | LocalId: nextId(),
41 | }
42 | }
43 |
44 | func RandomPayload() (uint64, []*WireFrame) {
45 | var size uint64 = 500
46 | var payload = make([]byte, size)
47 | return size, []*WireFrame{
48 | &WireFrame{
49 | FrameType: 3,
50 | Channel: 1,
51 | Payload: payload,
52 | },
53 | }
54 | }
55 |
56 | func RandomHeader(size uint64, persistent bool) *ContentHeaderFrame {
57 | var props = BasicContentHeaderProperties{}
58 | if persistent {
59 | var b = byte(2)
60 | props.DeliveryMode = &b
61 | } else {
62 | var b = byte(1)
63 | props.DeliveryMode = &b
64 | }
65 | return &ContentHeaderFrame{
66 | ContentClass: 60,
67 | ContentWeight: 0,
68 | ContentBodySize: size,
69 | PropertyFlags: 0,
70 | Properties: &props,
71 | }
72 | }
73 |
--------------------------------------------------------------------------------
/amqp/tx.go:
--------------------------------------------------------------------------------
1 | package amqp
2 |
3 | func NewTxMessage(msg *Message, queueName string) *TxMessage {
4 | return &TxMessage{
5 | QueueName: queueName,
6 | Msg: msg,
7 | }
8 | }
9 |
10 | func NewTxAck(tag uint64, nack bool, requeueNack bool, multiple bool) *TxAck {
11 | return &TxAck{
12 | Tag: tag,
13 | Nack: nack,
14 | RequeueNack: requeueNack,
15 | Multiple: multiple,
16 | }
17 | }
18 |
--------------------------------------------------------------------------------
/amqp/tx_test.go:
--------------------------------------------------------------------------------
1 | package amqp
2 |
3 | import (
4 | "reflect"
5 | "testing"
6 | )
7 |
8 | func TestNewTxMessage(t *testing.T) {
9 | var msg = RandomMessage(false)
10 | var txMessage = NewTxMessage(msg, "hithere")
11 | if !reflect.DeepEqual(txMessage.GetMsg(), msg) {
12 | t.Errorf("Messages not equal")
13 | }
14 | if txMessage.GetQueueName() != "hithere" {
15 | t.Errorf("bad queue name")
16 | }
17 | }
18 |
19 | func TestNewTxAck(t *testing.T) {
20 | var tag uint64 = 12345 //, nack bool, requeueNack bool, multiple bool
21 | var txack = NewTxAck(tag, false, false, true)
22 | if txack.GetTag() != 12345 {
23 | t.Errorf("tag mismatch")
24 | }
25 | if !txack.Multiple {
26 | t.Errorf("bad multiple flag")
27 | }
28 | txack = NewTxAck(tag, false, true, false)
29 | if !txack.RequeueNack {
30 | t.Errorf("bad multiple requeue")
31 | }
32 | txack = NewTxAck(tag, true, false, true)
33 | if !txack.Nack {
34 | t.Errorf("bad multiple nack")
35 | }
36 | }
37 |
--------------------------------------------------------------------------------
/amqp/types.go:
--------------------------------------------------------------------------------
1 | package amqp
2 |
3 | import (
4 | "errors"
5 | "fmt"
6 | "github.com/jeffjenkins/dispatchd/util"
7 | "io"
8 | "regexp"
9 | )
10 |
11 | type Frame interface {
12 | FrameType() byte
13 | }
14 |
15 | type MethodFrame interface {
16 | MethodName() string
17 | MethodIdentifier() (uint16, uint16)
18 | Read(reader io.Reader, strictMode bool) (err error)
19 | Write(writer io.Writer) (err error)
20 | FrameType() byte
21 | }
22 |
23 | // A message resource is something which has limits on the count of
24 | // messages it can handle as well as the cumulative size of the messages
25 | // it can handle.
26 | type MessageResourceHolder interface {
27 | AcquireResources(qm *QueueMessage) bool
28 | ReleaseResources(qm *QueueMessage)
29 | }
30 |
31 | func NewMessage(method *BasicPublish, localId int64) *Message {
32 | return &Message{
33 | Id: util.NextId(),
34 | Method: method,
35 | Exchange: method.Exchange,
36 | Key: method.RoutingKey,
37 | Payload: make([]*WireFrame, 0, 1),
38 | LocalId: localId,
39 | }
40 | }
41 |
42 | func NewTruncatedBodyFrame(channel uint16) WireFrame {
43 | return WireFrame{
44 | FrameType: byte(FrameBody),
45 | Channel: channel,
46 | Payload: make([]byte, 0, 0),
47 | }
48 | }
49 |
50 | func NewUnackedMessage(tag string, qm *QueueMessage, queueName string) *UnackedMessage {
51 | return &UnackedMessage{
52 | ConsumerTag: tag,
53 | Msg: qm,
54 | QueueName: queueName,
55 | }
56 | }
57 |
58 | func NewIndexMessage(id int64, refCount int32, durable bool, deliveryCount int32) *IndexMessage {
59 | return &IndexMessage{
60 | Id: id,
61 | Refs: refCount,
62 | Durable: durable,
63 | DeliveryCount: deliveryCount,
64 | }
65 | }
66 |
67 | func NewQueueMessage(id int64, deliveryCount int32, durable bool, msgSize uint32, localId int64) *QueueMessage {
68 | return &QueueMessage{
69 | Id: id,
70 | DeliveryCount: deliveryCount,
71 | Durable: durable,
72 | MsgSize: msgSize,
73 | LocalId: localId,
74 | // NOTE: When loading this from disk later we should zero localId out. This is the ID
75 | // of the publishing channel, which isn't relevant on server boot.
76 | }
77 | }
78 |
79 | func (frame *ContentHeaderFrame) FrameType() byte {
80 | return 2
81 | }
82 |
// exchangeNameRegex lists the characters allowed in exchange and queue names:
// ASCII letters, digits, '-', '_', '.' and ':'. The empty string matches.
var exchangeNameRegex = regexp.MustCompile(`^[a-zA-Z0-9-_.:]*$`)

// CheckExchangeOrQueueName validates s as an exchange or queue name. It
// returns an error if s is longer than 127 bytes or contains a character
// outside exchangeNameRegex, and nil otherwise (the empty name is accepted).
//
// The messages mention both kinds of name, because this check is used for
// both. The offending name is quoted so that whitespace or control characters
// are visible.
func CheckExchangeOrQueueName(s string) error {
	// Names travel as a shortstr, which allows up to 255 bytes, but names are
	// capped at 127. len counts bytes, which matches the wire encoding.
	if len(s) > 127 {
		return fmt.Errorf("Exchange or queue name too long: %d", len(s))
	}
	if !exchangeNameRegex.MatchString(s) {
		return fmt.Errorf("Exchange or queue name invalid: %q", s)
	}
	return nil
}
96 |
97 | func (frame *ContentHeaderFrame) Read(reader io.Reader, strictMode bool) (err error) {
98 | frame.ContentClass, err = ReadShort(reader)
99 | if err != nil {
100 | return err
101 | }
102 |
103 | frame.ContentWeight, err = ReadShort(reader)
104 | if err != nil {
105 | return err
106 | }
107 | if frame.ContentWeight != 0 {
108 | return errors.New("Bad content weight in header frame. Should be 0")
109 | }
110 |
111 | frame.ContentBodySize, err = ReadLonglong(reader)
112 | if err != nil {
113 | return err
114 | }
115 |
116 | frame.PropertyFlags, err = ReadShort(reader)
117 | if err != nil {
118 | return err
119 | }
120 |
121 | frame.Properties = &BasicContentHeaderProperties{}
122 | err = frame.Properties.ReadProps(frame.PropertyFlags, reader, strictMode)
123 | if err != nil {
124 | return err
125 | }
126 | return nil
127 | }
128 |
--------------------------------------------------------------------------------
/amqpgen/amqpgen.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | "bytes"
5 | "encoding/xml"
6 | "flag"
7 | "fmt"
8 | "io/ioutil"
9 | "os"
10 | "strings"
11 | )
12 |
var (
	// amqpToGo maps each primitive AMQP field type to the Go type used in the
	// generated structs and reader/writer code.
	amqpToGo = map[string]string{
		"bit":       "bool",
		"long":      "uint32",
		"longlong":  "uint64",
		"longstr":   "[]byte",
		"octet":     "byte",
		"short":     "uint16",
		"shortstr":  "string",
		"table":     "Table",
		"timestamp": "uint64",
	}

	// amqpToProto maps each primitive AMQP field type to the protobuf type
	// used in the generated schema. Narrow integers widen to uint32, and
	// fieldOptions adds a gogoproto casttype when the Go type differs.
	amqpToProto = map[string]string{
		"bit":       "bool",
		"long":      "uint32",
		"longlong":  "uint64",
		"longstr":   "bytes",
		"octet":     "uint32",
		"short":     "uint32",
		"shortstr":  "string",
		"table":     "Table",
		"timestamp": "uint64",
	}
)
36 |
// The types below mirror the AMQP XML specification so that encoding/xml can
// decode it directly. Fields with xml tags are read from the spec file. The
// untagged fields are filled in afterwards by transform and its helpers, and
// the code-generation templates read them.
type (
	// Root is the data handed to both templates. main decodes the spec
	// document into Root.Amqp.
	Root struct {
		Amqp Amqp `xml:"amqp"`
	}

	// Amqp holds the spec's top-level <constant>, <domain> and <class>
	// elements.
	Amqp struct {
		Constants []*Constant `xml:"constant"`
		Domains   []*Domain   `xml:"domain"`
		Classes   []*Class    `xml:"class"`
	}

	// Domain names an AMQP type; fields may refer to a domain instead of a
	// type.
	Domain struct {
		Name string `xml:"name,attr"`
		Type string `xml:"type,attr"`
	}

	// Class is an AMQP class: its methods plus any content-header property
	// fields.
	Class struct {
		Methods  []*Method `xml:"method"`
		Name     string    `xml:"name,attr"`
		NormName string    // CamelCase form of Name
		Handler  string    `xml:"handler,attr"`
		Index    string    `xml:"index,attr"`
		Fields   []*Field  `xml:"field"`
	}

	// Constant is a named numeric constant from the spec.
	Constant struct {
		Name     string `xml:"name,attr"`
		NormName string // CamelCase form of Name
		Value    uint16 `xml:"value,attr"`
		Class    string `xml:"class,attr"`
	}

	// Method is one method of a class.
	Method struct {
		Name        string `xml:"name,attr"`
		NormName    string // class NormName + CamelCase method name
		Synchronous string `xml:"synchronous,attr"`
		Index       string `xml:"index,attr"`
		Fields      []*Field `xml:"field"`
		BitsAtEnd   bool     // last field is a bit, so a bit octet is flushed after all fields
	}

	// Field is a method argument or content property. Name, Domain and
	// AmqpType come from the spec; the rest are derived by transformFields.
	Field struct {
		Name        string `xml:"name,attr"`
		Domain      string `xml:"domain,attr"`
		AmqpType    string `xml:"type,attr"`
		NormName    string
		ProtoType   string
		ProtoIndex  int
		ProtoName   string
		Options     string
		MaskIndex   string
		Serializer  string
		ReadArgs    string
		GoType      string
		BitOffset   int
		PreviousBit bool
	}
)

// specFile is the path of the spec XML file, set by the -spec flag.
var specFile string
95 |
96 | func transform(r *Root) {
97 | transformConstants(r.Amqp.Constants)
98 | domainTypes := transformDomains(r.Amqp.Domains)
99 | transformClasses(r.Amqp.Classes, domainTypes)
100 | }
101 |
102 | func transformConstants(cs []*Constant) {
103 | for _, c := range cs {
104 | c.NormName = normalizeName(c.Name)
105 | }
106 | }
107 |
108 | func transformClasses(cs []*Class, domainTypes map[string]string) {
109 | for _, c := range cs {
110 | c.NormName = normalizeName(c.Name)
111 | transformFields(c.Fields, domainTypes, true, 1)
112 | transformMethods(c.NormName, c.Methods, domainTypes)
113 | }
114 | }
115 |
116 | func transformDomains(ds []*Domain) map[string]string {
117 | domainTypes := make(map[string]string)
118 | for _, d := range ds {
119 | domainTypes[d.Name] = d.Type
120 | }
121 | return domainTypes
122 | }
123 |
124 | func transformMethods(className string, ms []*Method, domainTypes map[string]string) {
125 | for _, m := range ms {
126 | m.NormName = className + normalizeName(m.Name)
127 | m.BitsAtEnd = transformFields(m.Fields, domainTypes, false, 1)
128 | }
129 | }
130 |
131 | func transformFields(fs []*Field, domainTypes map[string]string, nullable bool, offset int) bool {
132 | var bits = 0
133 | var bitsAtEnd = false
134 | for index, f := range fs {
135 | var ok bool
136 | domain := f.Domain
137 | if f.AmqpType != "" {
138 | domain = f.AmqpType
139 | } else {
140 | f.AmqpType, ok = domainTypes[domain]
141 | if !ok {
142 | panic("")
143 | }
144 | }
145 | f.ProtoType, ok = amqpToProto[f.AmqpType]
146 | if !ok {
147 | panic("")
148 | }
149 | f.GoType, ok = amqpToGo[f.AmqpType]
150 | if !ok {
151 | panic("")
152 | }
153 |
154 | f.NormName = normalizeName(f.Name)
155 | f.ProtoName = protoName(f.Name)
156 | f.ProtoIndex = index + offset
157 | f.MaskIndex = fmt.Sprintf("%04x", (0 | 1<<(uint(15)-uint(index))))
158 | f.Serializer = normalizeName(domain)
159 | if strings.Contains(f.Serializer, "Properties") || strings.Contains(f.Serializer, "Table") {
160 | f.ReadArgs = ", strictMode"
161 | }
162 | if f.AmqpType == "bit" {
163 | f.BitOffset = bits
164 | bits += 1
165 | bitsAtEnd = true
166 | } else {
167 | f.PreviousBit = bits > 0
168 | f.BitOffset = -1
169 | bits = 0
170 | bitsAtEnd = false
171 | }
172 | f.Options = fieldOptions(f, domainTypes, nullable)
173 | }
174 | return bitsAtEnd
175 | }
176 |
177 | func fieldOptions(f *Field, domainTypes map[string]string, nullable bool) string {
178 | var options = make([]string, 0)
179 | if f.GoType != f.ProtoType && f.ProtoType != "bytes" {
180 | options = append(options, fmt.Sprintf("(gogoproto.casttype) = \"%s\"", f.GoType))
181 | }
182 | if f.ProtoType != "Table" && f.ProtoType != "bytes" && !nullable {
183 | options = append(options, "(gogoproto.nullable) = false")
184 | }
185 | if len(options) > 0 {
186 | return " [" + strings.Join(options, ", ") + "]"
187 | }
188 | return ""
189 |
190 | }
191 |
// normalizeName converts a dash-separated spec name to CamelCase, e.g.
// "frame-end" -> "FrameEnd". Empty segments contribute nothing.
func normalizeName(s string) string {
	parts := strings.Split(s, "-")
	for i, part := range parts {
		parts[i] = upperFirst(part)
	}
	return strings.Join(parts, "")
}

// protoName converts a spec name to a protobuf field name. Dashes become
// underscores ("no-wait" -> "no_wait"), except in names containing
// "reserved", where they are dropped ("reserved-1" -> "reserved1").
func protoName(s string) string {
	sep := "_"
	if strings.Contains(s, "reserved") {
		sep = ""
	}
	return strings.Replace(s, "-", sep, -1)
}

// upperFirst upper-cases the first byte of s and leaves the rest unchanged.
func upperFirst(s string) string {
	if len(s) == 0 {
		return s
	}
	head := bytes.ToUpper([]byte(s[:1]))
	return string(head) + s[1:]
}
214 |
215 | func main() {
216 | // fmt.Println(protoTemplate)
217 | flag.StringVar(&specFile, "spec", "", "Spec XML file")
218 | flag.Parse()
219 | var bytes, err = ioutil.ReadFile(specFile)
220 | if err != nil {
221 | panic(err)
222 | }
223 | var root Root
224 | err = xml.Unmarshal(bytes, &root.Amqp)
225 | if err != nil {
226 | panic(err.Error())
227 | }
228 | transform(&root)
229 | // Proto
230 | f, err := os.Create("amqp/protocol_generated.proto")
231 | if err != nil {
232 | panic(err.Error())
233 | }
234 | protoTemplate.Execute(f, &root)
235 |
236 | // Readers/Writers
237 | f, err = os.Create("amqp/protocol_protobuf_readwrite_generated.go")
238 | if err != nil {
239 | panic(err.Error())
240 | }
241 | readWriteTemplate.Execute(f, &root)
242 | }
243 |
--------------------------------------------------------------------------------
/amqpgen/templates.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | "text/template"
5 | )
6 |
// protoTemplate renders amqp/protocol_generated.proto and readWriteTemplate
// renders amqp/protocol_protobuf_readwrite_generated.go. main executes both
// against the transformed spec. Each is parsed by an init function in this
// file, which panics on a template syntax error.
var protoTemplate *template.Template
var readWriteTemplate *template.Template

// init parses the protobuf schema template. The template emits:
//   - for each class with content-property fields, a
//     <Class>ContentHeaderProperties message
//   - for each method, a message named after the method, with generated
//     getters and unrecognized-field storage disabled
//
// Each field line uses the ProtoType/ProtoName/ProtoIndex/Options values
// computed by transformFields. The raw string below is emitted verbatim, so
// edits to it change the generated file.
func init() {
	//
	// PROTOCOL BUFFER TEMPLATE
	//
	t, err := template.New("proto").Parse(`package amqp;

import "github.com/gogo/protobuf/gogoproto/gogo.proto";
import "github.com/jeffjenkins/dispatchd/amqp/amqp.proto";

option (gogoproto.marshaler_all) = true;
option (gogoproto.sizer_all) = true;
option (gogoproto.unmarshaler_all) = true;
{{range .Amqp.Classes}}
{{if .Fields}}
message {{.NormName}}ContentHeaderProperties {
{{range $index, $field := .Fields}}
optional {{.ProtoType}} {{.ProtoName}} = {{.ProtoIndex}}{{.Options}};{{end}}
}
{{end}}
{{range .Methods}}
message {{.NormName}} {
option (gogoproto.goproto_unrecognized) = false;
option (gogoproto.goproto_getters) = false;
{{range $index, $field := .Fields}}
optional {{.ProtoType}} {{.ProtoName}} = {{.ProtoIndex}}{{.Options}};{{end}}
}
{{end}}
{{end}}


`)
	if err != nil {
		panic(err.Error())
	}
	protoTemplate = t
}
// init parses the Go reader/writer template. The template emits:
//   - for every class, a ClassId<Class> variable, plus one Mask<Field>
//     property-flag variable per content-property field
//   - for every method, a MethodId<Method> variable, the methods
//     MethodIdentifier, MethodName and FrameType (always 1), and Read/Write
//     methods that (de)serialize the fields in spec order using
//     Read<Serializer>/Write<Serializer>, packing consecutive bit fields into
//     a single octet
//   - ReadMethod, which reads the class and method ids and dispatches to the
//     matching method's Read
//
// NOTE(review): at every BitOffset 0, Read declares the bit octet with
// `bits, err :=` and Write declares it with `var bits byte`. A method with two
// separate runs of bit fields would therefore generate code that does not
// compile. This relies on the spec having at most one bit run per method.
// NOTE(review): this template is also named "proto"; the name only appears in
// template error messages.
// NOTE(review): the generated Write error messages drop the underlying error,
// and "Error reading bit fields" has no ": " before the appended error text.
func init() {
	//
	// READ/WRITE TEMPLATE
	//
	t, err := template.New("proto").Parse(`
package amqp

import (
"io"
"errors"
"fmt"
)

{{range $class := .Amqp.Classes}}
var ClassId{{.NormName}} uint16 = {{.Index}}
{{if .Fields}}{{range $index, $field := .Fields}}
var Mask{{.NormName}} uint16 = 0x{{.MaskIndex}};{{end}}
{{end}}
{{range .Methods}}
// ************************
// {{.NormName}}
// ************************
var MethodId{{.NormName}} uint16 = {{.Index}};

func (f *{{.NormName}}) MethodIdentifier() (uint16, uint16) {
return {{$class.Index}}, {{.Index}}
}

func (f *{{.NormName}}) MethodName() string {
return "{{.NormName}}"
}

func (f *{{.NormName}}) FrameType() byte {
return 1
}

// Reader
func (f *{{.NormName}}) Read(reader io.Reader, strictMode bool) (err error) {
{{range $index, $field := .Fields}}
{{if eq .BitOffset -1}}
f.{{.NormName}}, err = Read{{.Serializer}}(reader{{.ReadArgs}})
if err != nil {
return errors.New("Error reading field {{.NormName}}: " + err.Error())
}
{{else}}
{{if eq .BitOffset 0}}
bits, err := ReadOctet(reader)
if err != nil {
return errors.New("Error reading bit fields" + err.Error())
}
{{end}}
f.{{.NormName}} = (bits & (1 << {{.BitOffset}}) > 0)
{{end}}
{{end}}{{/* end fields */}}
return
}

// Writer
func (f *{{.NormName}}) Write(writer io.Writer) (err error) {
if err = WriteShort(writer, {{$class.Index}}); err != nil {
return err
}
if err = WriteShort(writer, {{.Index}}); err != nil {
return err
}

{{range $index, $field := .Fields}}
{{if eq .GoType "bool"}}
{{if eq .BitOffset 0}}
var bits byte
{{end}}
if f.{{.NormName}} {
bits |= 1 << {{.BitOffset}}
}
{{else}}{{/* else go type is not bool */}}
{{if .PreviousBit}}
err = WriteOctet(writer, bits)
if err != nil {
return errors.New("Error writing bit fields")
}
{{end}}
err = Write{{.Serializer}}(writer, f.{{.NormName}})
if err != nil {
return errors.New("Error writing field {{.NormName}}")
}
{{end}}{{/* end go type bool */}}
{{end}}{{/* end fields */}}
{{if .BitsAtEnd}}
err = WriteOctet(writer, bits)
if err != nil {
return errors.New("Error writing bit fields")
}
{{end}}
return
}

{{end}}{{/* end methods */}}
{{end}}{{/* end amqp.classes */}}

// ********************************
// METHOD READER
// ********************************

func ReadMethod(reader io.Reader, strictMode bool) (MethodFrame, error) {
classIndex, err := ReadShort(reader)
if err != nil {
return nil, err
}
methodIndex, err := ReadShort(reader)
if err != nil {
return nil, err
}
switch {
{{range $class := .Amqp.Classes}}
case classIndex == {{.Index}}:
switch {
{{range .Methods}}
case methodIndex == {{.Index}}:
var method = &{{.NormName}}{}
err = method.Read(reader, strictMode)
if err != nil {
return nil, err
}
return method, nil
{{end}}
}
{{end}}
}

return nil, errors.New(fmt.Sprintf("Bad method or class Id! classId: %d, methodIndex: %d", classIndex, methodIndex))
}
`)
	if err != nil {
		panic(err.Error())
	}
	readWriteTemplate = t
}
183 |
--------------------------------------------------------------------------------
/binding/binding.go:
--------------------------------------------------------------------------------
1 | package binding
2 |
3 | import (
4 | "bytes"
5 | "crypto/sha1"
6 | "encoding/json"
7 | "fmt"
8 | "github.com/boltdb/bolt"
9 | "github.com/gogo/protobuf/proto"
10 | "github.com/jeffjenkins/dispatchd/amqp"
11 | "github.com/jeffjenkins/dispatchd/gen"
12 | "github.com/jeffjenkins/dispatchd/persist"
13 | "regexp"
14 | "strings"
15 | )
16 |
17 | type BindingStateFactory struct{}
18 |
19 | func (bsf *BindingStateFactory) New() proto.Unmarshaler {
20 | return &gen.BindingState{}
21 | }
22 |
// BINDINGS_BUCKET_NAME is the bolt bucket that Persist, Depersist,
// DepersistBoltTx and LoadAllBindings use for bindings. The exported name is
// kept for existing callers; Go style would be BindingsBucketName.
var BINDINGS_BUCKET_NAME = []byte("bindings")

// Binding connects an exchange to a queue under a routing key. The persisted
// fields live in the embedded gen.BindingState. topicMatcher is the compiled
// key pattern; NewBinding sets it only for topic bindings and leaves it nil
// otherwise.
type Binding struct {
	gen.BindingState
	topicMatcher *regexp.Regexp
}
29 |
// topicRoutingPatternPattern validates topic binding keys. A valid key is
// zero or more dot-separated words, where each word is either a run of word
// characters (\w) or one of the wildcards `*` and `#`. The empty key is
// allowed. MustCompile makes a broken pattern fail loudly at package init,
// instead of silently leaving a nil *Regexp that would panic on first use.
var topicRoutingPatternPattern = regexp.MustCompile(`^((\w+|\*|#)(\.(\w+|\*|#))*|)$`)
31 |
32 | func (binding *Binding) MarshalJSON() ([]byte, error) {
33 | return json.Marshal(map[string]interface{}{
34 | "queueName": binding.QueueName,
35 | "exchangeName": binding.ExchangeName,
36 | "key": binding.Key,
37 | "arguments": binding.Arguments,
38 | })
39 | }
40 |
41 | func (binding *Binding) Equals(other *Binding) bool {
42 | if other == nil || binding == nil {
43 | return false
44 | }
45 | return binding.QueueName == other.QueueName &&
46 | binding.ExchangeName == other.ExchangeName &&
47 | binding.Key == other.Key
48 | }
49 |
50 | func (binding *Binding) Depersist(db *bolt.DB) error {
51 | return persist.DepersistOne(db, BINDINGS_BUCKET_NAME, string(binding.Id))
52 | }
53 |
54 | func (binding *Binding) DepersistBoltTx(tx *bolt.Tx) error {
55 | bucket, err := tx.CreateBucketIfNotExists(BINDINGS_BUCKET_NAME)
56 | if err != nil { // pragma: nocover
57 | // If we're hitting this it means the disk is full, the db is readonly,
58 | // or something else has gone irrecoverably wrong
59 | panic(fmt.Sprintf("create bucket: %s", err))
60 | }
61 | return persist.DepersistOneBoltTx(bucket, string(binding.Id))
62 | }
63 |
64 | func NewBinding(queueName string, exchangeName string, key string, arguments *amqp.Table, topic bool) (*Binding, error) {
65 | var re *regexp.Regexp = nil
66 | // Topic routing key
67 | if topic {
68 | if !topicRoutingPatternPattern.MatchString(key) {
69 | return nil, fmt.Errorf("Topic exchange routing key can only have a-zA-Z0-9, or # or *")
70 | }
71 | var parts = strings.Split(key, ".")
72 | for i, part := range parts {
73 | if part == "*" {
74 | parts[i] = `[^\.]+`
75 | } else if part == "#" {
76 | parts[i] = ".*"
77 | } else {
78 | parts[i] = regexp.QuoteMeta(parts[i])
79 | }
80 | }
81 | expression := "^" + strings.Join(parts, `\.`) + "$"
82 | var err error = nil
83 | re, err = regexp.Compile(expression)
84 | if err != nil { // pragma: nocover
85 | // This is impossible to get to based on the earlier
86 | // code, so we panic and don't count it for coverage
87 | panic(fmt.Sprintf("Could not compile regex: '%s'", expression))
88 | }
89 | }
90 |
91 | return &Binding{
92 | BindingState: gen.BindingState{
93 | Id: calcId(queueName, exchangeName, key, arguments),
94 | QueueName: queueName,
95 | ExchangeName: exchangeName,
96 | Key: key,
97 | Arguments: arguments,
98 | Topic: topic,
99 | },
100 | topicMatcher: re,
101 | }, nil
102 | }
103 |
104 | func LoadAllBindings(db *bolt.DB) (map[string]*Binding, error) {
105 | exStateMap, err := persist.LoadAll(db, BINDINGS_BUCKET_NAME, &BindingStateFactory{})
106 | if err != nil {
107 | return nil, err
108 | }
109 | var ret = make(map[string]*Binding)
110 | for key, state := range exStateMap {
111 | var sb = state.(*gen.BindingState)
112 | // TODO: we don't actually know if topic is true, so this is extra work
113 | // for other exchange binding types
114 | ret[key], err = NewBinding(sb.QueueName, sb.ExchangeName, sb.Key, sb.Arguments, sb.Topic)
115 | if err != nil {
116 | return nil, err
117 | }
118 | }
119 | return ret, nil
120 | }
121 |
122 | func (b *Binding) Persist(db *bolt.DB) error {
123 | return persist.PersistOne(db, BINDINGS_BUCKET_NAME, string(b.Id), b)
124 | }
125 |
126 | func (b *Binding) MatchDirect(message *amqp.BasicPublish) bool {
127 | return message.Exchange == b.ExchangeName && b.Key == message.RoutingKey
128 | }
129 |
130 | func (b *Binding) MatchFanout(message *amqp.BasicPublish) bool {
131 | return message.Exchange == b.ExchangeName
132 | }
133 |
134 | func (b *Binding) MatchTopic(message *amqp.BasicPublish) bool {
135 | var ex = b.ExchangeName == message.Exchange
136 | var match = b.topicMatcher.MatchString(message.RoutingKey)
137 | return ex && match
138 | }
139 |
140 | // Calculate an ID by encoding the QueueBind call that created this binding and
141 | // taking a hash of it.
142 | func calcId(queueName string, exchangeName string, key string, arguments *amqp.Table) []byte {
143 | var method = &amqp.QueueBind{
144 | Queue: queueName,
145 | Exchange: exchangeName,
146 | RoutingKey: key,
147 | Arguments: arguments,
148 | }
149 | var buffer = bytes.NewBuffer(make([]byte, 0))
150 | method.Write(buffer)
151 | // trim off the first four bytes, they're the class/method, which we
152 | // already know
153 | var value = buffer.Bytes()[4:]
154 | // bindings aren't named, so we hash the bytes we encoded
155 | hash := sha1.New()
156 | hash.Write(value)
157 | return []byte(hash.Sum(nil))
158 | }
159 |
--------------------------------------------------------------------------------
/binding/binding_test.go:
--------------------------------------------------------------------------------
1 | package binding
2 |
3 | import (
4 | "encoding/json"
5 | "github.com/boltdb/bolt"
6 | "github.com/jeffjenkins/dispatchd/amqp"
7 | // "github.com/jeffjenkins/dispatchd/persist"
8 | "os"
9 | "testing"
10 | )
11 |
12 | func TestFanout(t *testing.T) {
13 | b, _ := NewBinding("q1", "e1", "rk", amqp.NewTable(), false)
14 | if b.MatchFanout(basicPublish("DIFF", "asdf")) {
15 | t.Errorf("Fanout did not check exchanges")
16 | }
17 | if !b.MatchFanout(basicPublish("e1", "asdf")) {
18 | t.Errorf("Fanout didn't match regardless of key")
19 | }
20 | }
21 |
22 | func TestDirect(t *testing.T) {
23 | b, _ := NewBinding("q1", "e1", "rk", amqp.NewTable(), false)
24 | if b.MatchDirect(basicPublish("DIFF", "asdf")) {
25 | t.Errorf("MatchDirect did not check exchanges")
26 | }
27 | if b.MatchDirect(basicPublish("e1", "asdf")) {
28 | t.Errorf("MatchDirect matched even with the wrong key")
29 | }
30 | if !b.MatchDirect(basicPublish("e1", "rk")) {
31 | t.Errorf("MatchDirect did not match with the correct key and exchange")
32 | }
33 | }
34 |
35 | func TestEquals(t *testing.T) {
36 | var bNil *Binding = nil
37 | b, _ := NewBinding("q1", "e1", "rk", amqp.NewTable(), false)
38 | same, _ := NewBinding("q1", "e1", "rk", amqp.NewTable(), false)
39 | diffQ, _ := NewBinding("DIFF", "e1", "rk", amqp.NewTable(), false)
40 | diffE, _ := NewBinding("q1", "DIFF", "rk", amqp.NewTable(), false)
41 | diffR, _ := NewBinding("q1", "e1", "DIFF", amqp.NewTable(), false)
42 |
43 | if b == nil || same == nil || diffQ == nil || diffE == nil || diffR == nil {
44 | t.Errorf("Failed to construct bindings")
45 | }
46 | if b.Equals(nil) || bNil.Equals(b) {
47 | t.Errorf("Comparison to nil was true!")
48 | }
49 | if !b.Equals(same) {
50 | t.Errorf("Equals returns false!")
51 | }
52 | if b.Equals(diffQ) {
53 | t.Errorf("Equals returns true on queue name diff!")
54 | }
55 | if b.Equals(diffE) {
56 | t.Errorf("Equals returns true on exchange name diff!")
57 | }
58 | if b.Equals(diffR) {
59 | t.Errorf("Equals returns true on routing key diff!")
60 | }
61 |
62 | }
63 |
64 | func TestTopicRouting(t *testing.T) {
65 | _, err := NewBinding("q1", "e1", "(", amqp.NewTable(), true)
66 | if err == nil {
67 | t.Errorf("Bad topic patter compiled!")
68 | }
69 | basic, _ := NewBinding("q1", "e1", "hello.world", amqp.NewTable(), true)
70 | singleWild, _ := NewBinding("q1", "e1", "hello.*.world", amqp.NewTable(), true)
71 | multiWild, _ := NewBinding("q1", "e1", "hello.#.world", amqp.NewTable(), true)
72 | multiWild2, _ := NewBinding("q1", "e1", "hello.#.world.#", amqp.NewTable(), true)
73 | if !basic.MatchTopic(basicPublish("e1", "hello.world")) {
74 | t.Errorf("Basic match failed")
75 | }
76 | if basic.MatchTopic(basicPublish("e1", "hello.worlds")) {
77 | t.Errorf("Incorrect match with suffix")
78 | }
79 | if !basic.MatchTopic(basicPublish("e1", "hello.world")) {
80 | t.Errorf("Match succeeded despite mismatched exchange")
81 | }
82 | if !singleWild.MatchTopic(basicPublish("e1", "hello.one.world")) {
83 | t.Errorf("Failed to match single wildcard")
84 | }
85 | if singleWild.MatchTopic(basicPublish("e1", "hello.world")) {
86 | t.Errorf("Matched without wildcard token")
87 | }
88 | if !multiWild.MatchTopic(basicPublish("e1", "hello.one.two.three.world")) {
89 | t.Errorf("Failed to match multi wildcard")
90 | }
91 | if !multiWild2.MatchTopic(basicPublish("e1", "hello.one.world.hi")) {
92 | t.Errorf("Multiple multi-wild tokens failed")
93 | }
94 |
95 | }
96 |
97 | func basicPublish(e string, key string) *amqp.BasicPublish {
98 | return &amqp.BasicPublish{
99 | Exchange: e,
100 | RoutingKey: key,
101 | }
102 | }
103 |
104 | func TestJson(t *testing.T) {
105 | basic, _ := NewBinding("q1", "e1", "hello.world", amqp.NewTable(), true)
106 | var basicBytes, err = basic.MarshalJSON()
107 | if err != nil {
108 | t.Errorf(err.Error())
109 | }
110 |
111 | expectedBytes, err := json.Marshal(map[string]interface{}{
112 | "queueName": "q1",
113 | "exchangeName": "e1",
114 | "key": "hello.world",
115 | "arguments": make(map[string]interface{}),
116 | })
117 | if string(expectedBytes) != string(basicBytes) {
118 | t.Logf("Expected: %s", expectedBytes)
119 | t.Logf("Got: %s", basicBytes)
120 | t.Errorf("Wrong json bytes!")
121 | }
122 |
123 | }
124 |
125 | func TestPersistence(t *testing.T) {
126 | var dbFile = "TestBindingPersistence.db"
127 | os.Remove(dbFile)
128 | defer os.Remove(dbFile)
129 | db, err := bolt.Open(dbFile, 0600, nil)
130 | if err != nil {
131 | t.Errorf("Failed to create db")
132 | }
133 |
134 | // Persist
135 | b, err := NewBinding("q1", "ex1", "rk1", amqp.NewTable(), true)
136 | if err != nil {
137 | t.Errorf("Error in NewBinding")
138 | }
139 | err = b.Persist(db)
140 | if err != nil {
141 | t.Errorf("Error in NewBinding")
142 | }
143 |
144 | // Read
145 | bMap, err := LoadAllBindings(db)
146 | if err != nil {
147 | t.Errorf("Error in LoadAllBindings")
148 | }
149 | if len(bMap) != 1 {
150 | t.Errorf("Wrong number of bindings")
151 | }
152 | for _, b2 := range bMap {
153 | if !b2.Equals(b) {
154 | t.Errorf("Did not get the same binding from the db")
155 | }
156 | }
157 |
158 | // Depersist
159 | b.Depersist(db)
160 |
161 | bMap, err = LoadAllBindings(db)
162 | if err != nil {
163 | t.Errorf("Error in LoadAllBindings")
164 | }
165 | if len(bMap) != 0 {
166 | t.Errorf("Wrong number of bindings")
167 | }
168 |
169 | }
170 |
--------------------------------------------------------------------------------
/config.default.json:
--------------------------------------------------------------------------------
1 | {
2 | "users" : {
3 | "guest" : {
4 | "password_bcrypt_base64" : "JDJhJDExJENobGk4dG5rY0RGemJhTjhsV21xR3VNNnFZZ1ZqTzUzQWxtbGtyMHRYN3RkUHMuYjF5SUt5"
5 | }
6 | }
7 | }
--------------------------------------------------------------------------------
/consumer/consumer.go:
--------------------------------------------------------------------------------
1 | package consumer
2 |
3 | import (
4 | "encoding/json"
5 | "github.com/jeffjenkins/dispatchd/amqp"
6 | "github.com/jeffjenkins/dispatchd/msgstore"
7 | "github.com/jeffjenkins/dispatchd/stats"
8 | "sync"
9 | )
10 |
// Consumer delivers messages from one queue (cqueue) to one channel
// (cchannel) for a single consumer tag. It is also an
// amqp.MessageResourceHolder that enforces its own prefetch limits on the
// messages it acquires.
type Consumer struct {
	// msgStore is used to drop the message reference once a no-ack
	// delivery has been taken.
	msgStore  *msgstore.MessageStore
	arguments *amqp.Table
	cchannel  ConsumerChannel
	ConsumerTag string
	exclusive   bool
	// incoming wakes the consume loop. NewConsumer gives it capacity 1 and
	// Stop closes it.
	incoming chan bool
	// noAck: deliveries are not tracked as unacked and prefetch limits are
	// not enforced.
	noAck bool
	// noLocal: reject messages whose LocalId equals localId.
	noLocal   bool
	cqueue    ConsumerQueue
	queueName string
	// consumeLock serializes consumeOne.
	consumeLock sync.Mutex
	// limitLock is held by AcquireResources/ReleaseResources while they
	// update activeSize and activeCount.
	limitLock     sync.Mutex
	prefetchSize  uint32 // 0 means no byte limit
	prefetchCount uint16 // 0 means no count limit
	activeSize    uint32 // bytes of currently acquired messages
	activeCount   uint16 // number of currently acquired messages
	// stopLock guards stopped and the closing of incoming.
	stopLock  sync.Mutex
	stopped   bool
	StatCount uint64 // reported as "total" by MarshalJSON
	localId   int64  // compared against QueueMessage.LocalId for no-local
	// stats
	statConsumeOneGetOne stats.Histogram
	statConsumeOne       stats.Histogram
	statConsumeOneAck    stats.Histogram
	statConsumeOneSend   stats.Histogram
}
38 |
// ConsumerQueue is the queue-side API a Consumer uses.
type ConsumerQueue interface {
	// GetOne returns the next message that every holder in rhs accepts.
	// consumeOne treats a nil QueueMessage as "nothing to deliver".
	GetOne(rhs ...amqp.MessageResourceHolder) (*amqp.QueueMessage, *amqp.Message)
	// MaybeReady returns a channel that consume sends false on when it
	// starts; its meaning is defined by the queue implementation.
	MaybeReady() chan bool
}

// The methods necessary for a consumer to interact with a channel. The
// channel is itself a MessageResourceHolder, so its limits also apply to
// each delivery.
type ConsumerChannel interface {
	amqp.MessageResourceHolder
	SendContent(method amqp.MethodFrame, msg *amqp.Message)
	SendMethod(method amqp.MethodFrame)
	// FlowActive reports whether delivery is currently allowed;
	// AcquireResources refuses messages while it returns false.
	FlowActive() bool
	// AddUnackedMessage records an outstanding delivery and returns the
	// delivery tag sent to the client.
	AddUnackedMessage(consumerTag string, qm *amqp.QueueMessage, queueName string) uint64
}
52 |
53 | func NewConsumer(
54 | msgStore *msgstore.MessageStore,
55 | arguments *amqp.Table,
56 | cchannel ConsumerChannel,
57 | consumerTag string,
58 | exclusive bool,
59 | noAck bool,
60 | noLocal bool,
61 | cqueue ConsumerQueue,
62 | queueName string,
63 | prefetchSize uint32,
64 | prefetchCount uint16,
65 | localId int64,
66 | ) *Consumer {
67 | return &Consumer{
68 | msgStore: msgStore,
69 | arguments: arguments,
70 | cchannel: cchannel,
71 | ConsumerTag: consumerTag,
72 | exclusive: exclusive,
73 | incoming: make(chan bool, 1),
74 | noAck: noAck,
75 | noLocal: noLocal,
76 | cqueue: cqueue,
77 | queueName: queueName,
78 | prefetchSize: prefetchSize,
79 | prefetchCount: prefetchCount,
80 | localId: localId,
81 | // stats
82 | statConsumeOneGetOne: stats.MakeHistogram("Consume-One-Get-One"),
83 | statConsumeOne: stats.MakeHistogram("Consume-One-"),
84 | statConsumeOneAck: stats.MakeHistogram("Consume-One-Ack"),
85 | statConsumeOneSend: stats.MakeHistogram("Consume-One-Send"),
86 | }
87 | }
88 |
89 | func (consumer *Consumer) MarshalJSON() ([]byte, error) {
90 | return json.Marshal(map[string]interface{}{
91 | "tag": consumer.ConsumerTag,
92 | "stats": map[string]interface{}{
93 | "total": consumer.StatCount,
94 | "active_size_bytes": consumer.activeSize,
95 | "active_count": consumer.activeCount,
96 | },
97 | "ack": !consumer.noAck,
98 | })
99 | }
100 |
101 | // TODO: make this a field that we construct on init
102 | func (consumer *Consumer) MessageResourceHolders() []amqp.MessageResourceHolder {
103 | return []amqp.MessageResourceHolder{consumer, consumer.cchannel}
104 | }
105 |
106 | func (consumer *Consumer) Stop() {
107 | if !consumer.stopped {
108 | consumer.stopLock.Lock()
109 | consumer.stopped = true
110 | close(consumer.incoming)
111 | consumer.stopLock.Unlock()
112 | }
113 | }
114 |
115 | func (consumer *Consumer) AcquireResources(qm *amqp.QueueMessage) bool {
116 | consumer.limitLock.Lock()
117 | defer consumer.limitLock.Unlock()
118 |
119 | // If no-local was set on the consumer, reject messages
120 | if consumer.noLocal && qm.LocalId == consumer.localId {
121 | return false
122 | }
123 |
124 | // If the channel is in flow mode we don't consume
125 | // TODO: If flow is mostly for producers, then maybe we
126 | // should consume? I feel like the right answer here is for
127 | // clients to not produce and consume on the same channel.
128 | if !consumer.cchannel.FlowActive() {
129 | return false
130 | }
131 | // If we aren't acking then there are no resource limits. Up the stats
132 | // and return true
133 | if consumer.noAck {
134 | consumer.activeCount += 1
135 | consumer.activeSize += qm.MsgSize
136 | return true
137 | }
138 |
139 | // Calculate whether we're over either of the size and count limits
140 | var sizeOk = consumer.prefetchSize == 0 || consumer.activeSize < consumer.prefetchSize
141 | var countOk = consumer.prefetchCount == 0 || consumer.activeCount < consumer.prefetchCount
142 | if sizeOk && countOk {
143 | consumer.activeCount += 1
144 | consumer.activeSize += qm.MsgSize
145 | return true
146 | }
147 | return false
148 | }
149 |
150 | func (consumer *Consumer) ReleaseResources(qm *amqp.QueueMessage) {
151 | consumer.limitLock.Lock()
152 | consumer.activeCount -= 1
153 | consumer.activeSize -= qm.MsgSize
154 | consumer.limitLock.Unlock()
155 | }
156 |
157 | func (consumer *Consumer) Start() {
158 | go consumer.consume(0)
159 | }
160 |
161 | func (consumer *Consumer) Ping() {
162 | consumer.stopLock.Lock()
163 | defer consumer.stopLock.Unlock()
164 | if !consumer.stopped {
165 | select {
166 | case consumer.incoming <- true:
167 | default:
168 | }
169 | }
170 |
171 | }
172 |
173 | func (consumer *Consumer) consume(id uint16) {
174 | // TODO: what is this doing?
175 | consumer.cqueue.MaybeReady() <- false
176 | for _ = range consumer.incoming {
177 |
178 | consumer.consumeOne()
179 | }
180 | }
181 |
182 | func (consumer *Consumer) consumeOne() {
183 | defer stats.RecordHisto(consumer.statConsumeOne, stats.Start())
184 | var err error
185 | // Check local limit
186 | consumer.consumeLock.Lock()
187 | defer consumer.consumeLock.Unlock()
188 | // Try to get message/check channel limit
189 |
190 | var start = stats.Start()
191 | var qm, msg = consumer.cqueue.GetOne(consumer.cchannel, consumer)
192 | stats.RecordHisto(consumer.statConsumeOneGetOne, start)
193 | if qm == nil {
194 | return
195 | }
196 | var tag uint64 = 0
197 | start = stats.Start()
198 | if !consumer.noAck {
199 | tag = consumer.cchannel.AddUnackedMessage(consumer.ConsumerTag, qm, consumer.queueName)
200 | } else {
201 | // We aren't expecting an ack, so this is the last time the message
202 | // will be referenced.
203 | var rhs = []amqp.MessageResourceHolder{consumer.cchannel, consumer}
204 | err = consumer.msgStore.RemoveRef(qm, consumer.queueName, rhs)
205 | if err != nil {
206 | panic("Error getting queue message")
207 | }
208 | }
209 | stats.RecordHisto(consumer.statConsumeOneAck, start)
210 | start = stats.Start()
211 | consumer.cchannel.SendContent(&amqp.BasicDeliver{
212 | ConsumerTag: consumer.ConsumerTag,
213 | DeliveryTag: tag,
214 | Redelivered: qm.DeliveryCount > 0,
215 | Exchange: msg.Exchange,
216 | RoutingKey: msg.Key,
217 | }, msg)
218 | stats.RecordHisto(consumer.statConsumeOneSend, start)
219 | consumer.StatCount += 1
220 | // Since we succeeded in processing a message, ping so that we try again
221 | consumer.Ping()
222 | }
223 |
224 | func (consumer *Consumer) SendCancel() {
225 | consumer.cchannel.SendMethod(&amqp.BasicCancel{consumer.ConsumerTag, true})
226 | }
227 |
228 | func (consumer *Consumer) ConsumeImmediate(qm *amqp.QueueMessage, msg *amqp.Message) bool {
229 | consumer.consumeLock.Lock()
230 | defer consumer.consumeLock.Unlock()
231 | var tag uint64 = 0
232 | if !consumer.noAck {
233 | tag = consumer.cchannel.AddUnackedMessage(consumer.ConsumerTag, qm, consumer.queueName)
234 | }
235 | consumer.cchannel.SendContent(&amqp.BasicDeliver{
236 | ConsumerTag: consumer.ConsumerTag,
237 | DeliveryTag: tag,
238 | Redelivered: msg.Redelivered > 0,
239 | Exchange: msg.Exchange,
240 | RoutingKey: msg.Key,
241 | }, msg)
242 | consumer.StatCount += 1
243 | return true
244 | }
245 |
246 | // Send again, leave all stats the same since this consumer was already
247 | // dealing with this message
248 | func (consumer *Consumer) Redeliver(tag uint64, qm *amqp.QueueMessage) {
249 | msg, found := consumer.msgStore.GetNoChecks(qm.Id)
250 | if !found {
251 | panic("Integrity error, message not found in message store")
252 | }
253 | consumer.cchannel.SendContent(&amqp.BasicDeliver{
254 | ConsumerTag: consumer.ConsumerTag,
255 | DeliveryTag: tag,
256 | Redelivered: msg.Redelivered > 0,
257 | Exchange: msg.Exchange,
258 | RoutingKey: msg.Key,
259 | }, msg)
260 | }
261 |
--------------------------------------------------------------------------------
/consumer/consumer_test.go:
--------------------------------------------------------------------------------
1 | package consumer
2 |
3 | import (
4 | "testing"
5 | )
6 |
// TestMarshalJson is a placeholder for Consumer.MarshalJSON coverage. It skips
// explicitly so the missing test is visible in `go test -v` output instead of
// passing silently.
func TestMarshalJson(t *testing.T) {
	t.Skip("TODO: test Consumer.MarshalJSON")
}
10 |
--------------------------------------------------------------------------------
/dev/config.json:
--------------------------------------------------------------------------------
1 | {
2 | "users" : {
3 | "guest" : {
4 | "password_bcrypt_base64" : "JDJhJDExJENobGk4dG5rY0RGemJhTjhsV21xR3VNNnFZZ1ZqTzUzQWxtbGtyMHRYN3RkUHMuYjF5SUt5"
5 | }
6 | }
7 | }
--------------------------------------------------------------------------------
/dispatchd/config.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | "encoding/json"
5 | "flag"
6 | "io/ioutil"
7 | )
8 |
// Command-line settings. Numeric and string flags are registered with zero
// defaults. The configure* helpers treat a zero value as "not given on the
// command line" and fall back to the config file, then to the *Default values
// below.
var (
	amqpPort          int
	amqpPortDefault   = 5672
	adminPort         int
	adminPortDefault  = 8080
	persistDir        string
	persistDirDefault = "/data/dispatchd/"
	configFile        string
	configFileDefault = ""
	strictMode        bool
)

// init registers the server's command-line flags.
func init() {
	flag.IntVar(&amqpPort, "amqp-port", 0,
		"Port for amqp protocol messages. Default: 5672")
	flag.IntVar(&adminPort, "admin-port", 0,
		"Port for admin server. Default: 8080")
	flag.StringVar(&persistDir, "persist-dir", "",
		"Directory for the server and message database files. Default: /data/dispatchd/")
	flag.BoolVar(&strictMode, "strict-mode", false,
		"Obey the AMQP spec even where it differs from common implementations")
	flag.StringVar(&configFile, "config-file", "",
		"Location of the configuration file. Default: do not read a config file")
}
31 |
32 | func configure() map[string]interface{} {
33 | // TODO: It's no great that this is manual. I should make/find a small library
34 | // to automate this.
35 | var config = make(map[string]interface{})
36 | if configFile != "" {
37 | config = parseConfigFile(configFile)
38 | }
39 | configureIntParam(&amqpPort, amqpPortDefault, "amqp-port", config)
40 | configureIntParam(&adminPort, adminPortDefault, "admin-port", config)
41 | configureStringParam(&persistDir, persistDirDefault, "persist-dir", config)
42 | configureBoolParam(&strictMode, "strict-mode", config)
43 | _, ok := config["users"]
44 | if !ok {
45 | config["users"] = make(map[string]interface{})
46 | }
47 | return config
48 | }
49 |
// configureIntParam fills *param unless it is already non-zero (i.e. it was set
// on the command line). Otherwise it uses config[configName] when that key is
// present, and defaultValue when it is not.
//
// JSON numbers decode as float64, so the config value is converted to int. A
// present value of any other type panics.
func configureIntParam(param *int, defaultValue int, configName string, config map[string]interface{}) {
	switch {
	case *param != 0:
		return
	case configName != "":
		if value, ok := config[configName]; ok {
			*param = int(value.(float64))
			return
		}
	}
	*param = defaultValue
}
63 |
// configureBoolParam resolves a boolean setting.
//
//   - A true value already in *param (set on the command line) is kept.
//   - Otherwise config[configName] is used when that key is present.
//   - If the key is absent, the setting stays false.
//
// A present config value that is not a JSON boolean panics with a descriptive
// message, in line with the panic-on-bad-config behaviour of parseConfigFile.
func configureBoolParam(param *bool, configName string, config map[string]interface{}) {
	// The command line wins, as in configureIntParam and configureStringParam.
	// A -strict-mode flag must not be reset by a config file that omits the key.
	if *param {
		return
	}
	if configName == "" {
		return
	}
	value, ok := config[configName]
	if !ok {
		return
	}
	b, isBool := value.(bool)
	if !isBool {
		panic("Config value for '" + configName + "' must be a boolean")
	}
	*param = b
}
74 |
// configureStringParam fills *param unless it is already non-empty (i.e. it
// was set on the command line). Otherwise it uses config[configName] when
// configName is non-empty and present, and defaultValue when it is not. A
// present value that is not a string panics.
func configureStringParam(param *string, defaultValue string, configName string, config map[string]interface{}) {
	if *param != "" {
		return
	}
	value, found := config[configName]
	if configName != "" && found {
		*param = value.(string)
		return
	}
	*param = defaultValue
}
88 |
// parseConfigFile reads the JSON object stored at path and returns it as a
// map. It panics if the file cannot be read or is not valid JSON.
//
// The result is never nil, even for a file containing JSON null, so callers
// (configure adds a "users" key) can always write to it.
func parseConfigFile(path string) map[string]interface{} {
	data, err := ioutil.ReadFile(path)
	if err != nil {
		panic("Could not read config file: " + err.Error())
	}
	ret := make(map[string]interface{})
	if err := json.Unmarshal(data, &ret); err != nil {
		panic("Could not parse config file: " + err.Error())
	}
	if ret == nil {
		// json.Unmarshal sets a map to nil when the input is a literal null;
		// writing to a nil map would panic later.
		ret = make(map[string]interface{})
	}
	return ret
}
101 |
--------------------------------------------------------------------------------
/dispatchd/main.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | "flag"
5 | "fmt"
6 | "net"
7 | // _ "net/http/pprof" // uncomment for debugging
8 | "github.com/jeffjenkins/dispatchd/adminserver"
9 | "github.com/jeffjenkins/dispatchd/server"
10 | "os"
11 | "path/filepath"
12 | "runtime"
13 | )
14 |
15 | func handleConnection(server *server.Server, conn net.Conn) {
16 | server.OpenConnection(conn)
17 | }
18 |
19 | func main() {
20 | flag.Parse()
21 | config := configure()
22 | runtime.SetBlockProfileRate(1)
23 | serverDbPath := filepath.Join(persistDir, "dispatchd-server.db")
24 | msgDbPath := filepath.Join(persistDir, "messages.db")
25 | var server = server.NewServer(serverDbPath, msgDbPath, config["users"].(map[string]interface{}), strictMode)
26 | ln, err := net.Listen("tcp", fmt.Sprintf(":%d", amqpPort))
27 | if err != nil {
28 | fmt.Printf("Error!\n")
29 | os.Exit(1)
30 | }
31 | fmt.Printf("Listening on port %d\n", amqpPort)
32 | go func() {
33 | adminserver.StartAdminServer(server, adminPort)
34 | }()
35 | for {
36 | conn, err := ln.Accept()
37 | if err != nil {
38 | fmt.Printf("Error accepting connection!\n")
39 | os.Exit(1)
40 | }
41 | go handleConnection(server, conn)
42 | }
43 | }
44 |
--------------------------------------------------------------------------------
/exchange/exchange.go:
--------------------------------------------------------------------------------
1 | package exchange
2 |
3 | import (
4 | "encoding/json"
5 | "fmt"
6 | "github.com/boltdb/bolt"
7 | "github.com/gogo/protobuf/proto"
8 | "github.com/jeffjenkins/dispatchd/amqp"
9 | "github.com/jeffjenkins/dispatchd/binding"
10 | "github.com/jeffjenkins/dispatchd/gen"
11 | "github.com/jeffjenkins/dispatchd/persist"
12 | "sync"
13 | "time"
14 | )
15 |
16 | type ExchangeStateFactory struct{}
17 |
18 | func (esf *ExchangeStateFactory) New() proto.Unmarshaler {
19 | return &gen.ExchangeState{}
20 | }
21 |
// EXCHANGES_BUCKET_NAME is the bolt bucket that holds persisted exchanges.
var EXCHANGES_BUCKET_NAME = []byte("exchanges")

// Exchange type codes stored in ExchangeState.ExType. The values are written
// to disk, so they must never be renumbered or reordered.
const (
	EX_TYPE_DIRECT uint8 = iota + 1
	EX_TYPE_FANOUT
	EX_TYPE_TOPIC
	EX_TYPE_HEADERS
)
30 |
// Exchange is the in-memory representation of an AMQP exchange. Its persisted
// attributes live in the embedded gen.ExchangeState; the other fields are
// runtime-only state.
type Exchange struct {
	gen.ExchangeState
	// bindings lists the exchange's bindings. AddBinding, BindingsForQueue,
	// RemoveBindingsForQueue and RemoveBinding access it under bindingsLock.
	bindings     []*binding.Binding
	bindingsLock sync.Mutex
	// incoming is created by the constructors; it is not otherwise used in
	// this file.
	incoming chan amqp.Frame
	// Closed is set by Close.
	Closed bool
	// deleteActive identifies the pending autodelete timer: autodeleteTimeout
	// stores time.Now() here, and AddBinding overwrites it to cancel the
	// pending delete.
	deleteActive time.Time
	// deleteChan receives this exchange when autodelete fires.
	deleteChan chan *Exchange
	// autodeletePeriod is how long autodeleteTimeout waits (the constructors
	// set 5 seconds).
	autodeletePeriod time.Duration
}
49 | return nil, err
50 | }
51 | return json.Marshal(map[string]interface{}{
52 | "type": typ,
53 | "bindings": exchange.bindings,
54 | })
55 | }
56 |
57 | func NewExchange(
58 | name string,
59 | extype uint8,
60 | durable bool,
61 | autodelete bool,
62 | internal bool,
63 | arguments *amqp.Table,
64 | system bool,
65 | deleteChan chan *Exchange,
66 | ) *Exchange {
67 | return &Exchange{
68 | ExchangeState: gen.ExchangeState{
69 | Name: name,
70 | ExType: extype,
71 | Durable: durable,
72 | AutoDelete: autodelete,
73 | Internal: internal,
74 | Arguments: arguments,
75 | System: system,
76 | },
77 | deleteChan: deleteChan,
78 | // not passed in
79 | incoming: make(chan amqp.Frame),
80 | bindings: make([]*binding.Binding, 0),
81 | autodeletePeriod: 5 * time.Second,
82 | }
83 | }
84 |
85 | func NewFromExchangeState(exState *gen.ExchangeState, deleteChan chan *Exchange) *Exchange {
86 | return &Exchange{
87 | ExchangeState: *exState,
88 | deleteChan: deleteChan,
89 | incoming: make(chan amqp.Frame),
90 | bindings: make([]*binding.Binding, 0),
91 | autodeletePeriod: 5 * time.Second,
92 | }
93 | }
94 |
95 | func NewFromMethod(method *amqp.ExchangeDeclare, system bool, exchangeDeleter chan *Exchange) (*Exchange, *amqp.AMQPError) {
96 | var classId, methodId = method.MethodIdentifier()
97 | var tp, err = ExchangeNameToType(method.Type)
98 | if err != nil || tp == EX_TYPE_HEADERS {
99 | return nil, amqp.NewHardError(503, "Bad exchange type", classId, methodId)
100 | }
101 | var ex = NewExchange(
102 | method.Exchange,
103 | tp,
104 | method.Durable,
105 | method.AutoDelete,
106 | method.Internal,
107 | method.Arguments,
108 | system,
109 | exchangeDeleter,
110 | )
111 | return ex, nil
112 | }
113 |
114 | func (ex1 *Exchange) EquivalentExchanges(ex2 *Exchange) bool {
115 | // NOTE: auto-delete is ignored for existing exchanges, so we
116 | // do not check it here.
117 | if ex1.Name != ex2.Name {
118 | return false
119 | }
120 | if ex1.ExType != ex2.ExType {
121 | return false
122 | }
123 | if ex1.Durable != ex2.Durable {
124 | return false
125 | }
126 | if ex1.Internal != ex2.Internal {
127 | return false
128 | }
129 | if !amqp.EquivalentTables(ex1.Arguments, ex2.Arguments) {
130 | return false
131 | }
132 | return true
133 | }
134 |
135 | func ExchangeNameToType(et string) (uint8, error) {
136 | switch {
137 | case et == "direct":
138 | return EX_TYPE_DIRECT, nil
139 | case et == "fanout":
140 | return EX_TYPE_FANOUT, nil
141 | case et == "topic":
142 | return EX_TYPE_TOPIC, nil
143 | case et == "headers":
144 | return EX_TYPE_HEADERS, nil
145 | default:
146 | return 0, fmt.Errorf("Unknown exchang type '%s', %d %d", et, len(et), len("direct"))
147 | }
148 | }
149 |
150 | func exchangeTypeToName(et uint8) (string, error) {
151 | switch {
152 | case et == EX_TYPE_DIRECT:
153 | return "direct", nil
154 | case et == EX_TYPE_FANOUT:
155 | return "fanout", nil
156 | case et == EX_TYPE_TOPIC:
157 | return "topic", nil
158 | case et == EX_TYPE_HEADERS:
159 | return "headers", nil
160 | default:
161 | return "", fmt.Errorf("bad exchange type: %d", et)
162 | }
163 | }
164 |
165 | func LoadAllExchanges(db *bolt.DB, deleteChan chan *Exchange) (map[string]*Exchange, error) {
166 | exStateMap, err := persist.LoadAll(db, EXCHANGES_BUCKET_NAME, &ExchangeStateFactory{})
167 | if err != nil {
168 | return nil, err
169 | }
170 | var ret = make(map[string]*Exchange)
171 | for key, state := range exStateMap {
172 | ret[key] = NewFromExchangeState(state.(*gen.ExchangeState), deleteChan)
173 | }
174 | return ret, nil
175 | }
176 |
177 | func (exchange *Exchange) QueuesForPublish(msg *amqp.Message) (map[string]bool, *amqp.AMQPError) {
178 | var queues = make(map[string]bool)
179 | if msg.Method.Exchange != exchange.Name {
180 | return queues, nil
181 | }
182 | switch {
183 | case exchange.ExType == EX_TYPE_DIRECT:
184 | // In a direct exchange we can return the first match since there is
185 | // only one queue with a particular name
186 | for _, binding := range exchange.bindings {
187 | if binding.MatchDirect(msg.Method) {
188 | queues[binding.QueueName] = true
189 | return queues, nil
190 | }
191 | }
192 | case exchange.ExType == EX_TYPE_FANOUT:
193 | for _, binding := range exchange.bindings {
194 | queues[binding.QueueName] = true
195 | }
196 | case exchange.ExType == EX_TYPE_TOPIC:
197 | for _, binding := range exchange.bindings {
198 | if binding.MatchTopic(msg.Method) {
199 | var _, alreadySeen = queues[binding.QueueName]
200 | if alreadySeen {
201 | continue
202 | }
203 | queues[binding.QueueName] = true
204 | }
205 | }
206 | // case exchange.ExType == EX_TYPE_HEADERS:
207 | // // TODO: implement
208 | // panic("Headers is not implemented!")
209 | default: // pragma: nocover
210 | panic("Unknown exchange type created somehow. Server integrity error!")
211 | }
212 | return queues, nil
213 | }
214 |
215 | func (exchange *Exchange) Persist(db *bolt.DB) error {
216 | var key = exchange.Name
217 | if key == "" {
218 | key = "~"
219 | }
220 | return persist.PersistOne(db, EXCHANGES_BUCKET_NAME, key, &exchange.ExchangeState)
221 | }
222 |
223 | func NewFromDisk(db *bolt.DB, key string, deleteChan chan *Exchange) (ex *Exchange, err error) {
224 | err = db.View(func(tx *bolt.Tx) error {
225 | bucket := tx.Bucket(EXCHANGES_BUCKET_NAME)
226 | if bucket == nil {
227 | return fmt.Errorf("Bucket not found: 'exchanges'")
228 | }
229 | ex, err = NewFromDiskBoltTx(bucket, []byte(key), deleteChan)
230 | return err
231 | })
232 | return
233 | }
234 |
235 | func NewFromDiskBoltTx(bucket *bolt.Bucket, key []byte, deleteChan chan *Exchange) (ex *Exchange, err error) {
236 | var lookupKey = key
237 | if len(key) == 0 {
238 | lookupKey = []byte{'~'}
239 | }
240 | exBytes := bucket.Get(lookupKey)
241 | if exBytes == nil {
242 | return nil, fmt.Errorf("Key not found: '%s'", key)
243 | }
244 | exState := gen.ExchangeState{}
245 | err = exState.Unmarshal(exBytes)
246 | if err != nil {
247 | return nil, fmt.Errorf("Could not unmarshal exchange %s", key)
248 | }
249 | return NewFromExchangeState(&exState, deleteChan), nil
250 | }
251 |
252 | func (exchange *Exchange) Depersist(db *bolt.DB) error {
253 | return db.Update(func(tx *bolt.Tx) error {
254 | bucket := tx.Bucket(EXCHANGES_BUCKET_NAME)
255 | if bucket == nil {
256 | return fmt.Errorf("Bucket not found: '%s'", bucket)
257 | }
258 | for _, binding := range exchange.bindings {
259 | if err := binding.DepersistBoltTx(tx); err != nil { // pragma: nocover
260 | return err
261 | }
262 | }
263 | return persist.DepersistOneBoltTx(bucket, exchange.Name)
264 | })
265 | }
266 |
267 | func (exchange *Exchange) IsTopic() bool {
268 | return exchange.ExType == EX_TYPE_TOPIC
269 | }
270 |
271 | func (exchange *Exchange) AddBinding(b *binding.Binding, connId int64) error {
272 | exchange.bindingsLock.Lock()
273 | defer exchange.bindingsLock.Unlock()
274 |
275 | for _, b2 := range exchange.bindings {
276 | if b.Equals(b2) {
277 | return nil
278 | }
279 | }
280 |
281 | if exchange.AutoDelete {
282 | exchange.deleteActive = time.Unix(0, 0)
283 | }
284 | exchange.bindings = append(exchange.bindings, b)
285 | return nil
286 | }
287 |
288 | func (exchange *Exchange) BindingsForQueue(queueName string) []*binding.Binding {
289 | var ret = make([]*binding.Binding, 0)
290 | exchange.bindingsLock.Lock()
291 | defer exchange.bindingsLock.Unlock()
292 | for _, b := range exchange.bindings {
293 | if b.QueueName == queueName {
294 | ret = append(ret, b)
295 | }
296 | }
297 | return ret
298 | }
299 |
300 | func (exchange *Exchange) RemoveBindingsForQueue(queueName string) {
301 | var remaining = make([]*binding.Binding, 0)
302 | exchange.bindingsLock.Lock()
303 | defer exchange.bindingsLock.Unlock()
304 | for _, b := range exchange.bindings {
305 | if b.QueueName != queueName {
306 | remaining = append(remaining, b)
307 | }
308 | }
309 | exchange.bindings = remaining
310 | }
311 |
312 | func (exchange *Exchange) RemoveBinding(binding *binding.Binding) error {
313 | exchange.bindingsLock.Lock()
314 | defer exchange.bindingsLock.Unlock()
315 |
316 | // Delete binding
317 | for i, b := range exchange.bindings {
318 | if binding.Equals(b) {
319 | exchange.bindings = append(exchange.bindings[:i], exchange.bindings[i+1:]...)
320 | if exchange.AutoDelete && len(exchange.bindings) == 0 {
321 | go exchange.autodeleteTimeout()
322 | }
323 | return nil
324 | }
325 | }
326 | return nil
327 | }
328 |
329 | func (exchange *Exchange) autodeleteTimeout() {
330 | // There's technically a race condition here where a new binding could be
331 | // added right as we check this, but after a 5 second wait with no activity
332 | // I think this is probably safe enough.
333 | var now = time.Now()
334 | exchange.deleteActive = now
335 | time.Sleep(exchange.autodeletePeriod)
336 | if exchange.deleteActive == now {
337 | exchange.deleteChan <- exchange
338 | }
339 | }
340 |
--------------------------------------------------------------------------------
/gen/server.proto:
--------------------------------------------------------------------------------
1 | package gen;
2 |
3 | import "github.com/jeffjenkins/dispatchd/amqp/amqp.proto";
4 | import "github.com/gogo/protobuf/gogoproto/gogo.proto";
5 |
6 | option (gogoproto.marshaler_all) = true;
7 | option (gogoproto.sizer_all) = true;
8 | option (gogoproto.unmarshaler_all) = true;
9 |
// Persisted attributes of an exchange. This is embedded in exchange.Exchange
// and stored in the "exchanges" bolt bucket. Field numbers are part of the
// stored encoding and must not change.
message ExchangeState {
  option (gogoproto.goproto_unrecognized) = false;
  option (gogoproto.goproto_getters) = false;
  // Exchange name; the default exchange has an empty name.
  optional string name = 1 [(gogoproto.nullable) = false];
  // One of the EX_TYPE_* codes in exchange.go (1=direct, 2=fanout, 3=topic,
  // 4=headers). Exposed in Go as uint8.
  optional uint32 ex_type = 2 [(gogoproto.casttype) = "uint8", (gogoproto.nullable) = false];
  // NOTE(review): not set by exchange.NewExchange; confirm whether it is used.
  optional bool passive = 3 [(gogoproto.nullable) = false];
  optional bool durable = 4 [(gogoproto.nullable) = false];
  optional bool auto_delete = 5 [(gogoproto.nullable) = false];
  optional bool internal = 6 [(gogoproto.nullable) = false];
  // Passed as the `system` argument of exchange.NewExchange.
  optional bool system = 7 [(gogoproto.nullable) = false];
  // Declare-time arguments, compared by EquivalentExchanges.
  optional amqp.Table arguments = 8;
}
22 |
// Persisted state of a binding between an exchange and a queue.
message BindingState {
  option (gogoproto.goproto_unrecognized) = false;
  option (gogoproto.goproto_getters) = false;
  // Nullable bytes (a []byte in Go); presumably the binding's storage key —
  // confirm in binding.go.
  optional bytes id = 1;
  optional string queue_name = 2 [(gogoproto.nullable) = false];
  optional string exchange_name = 3 [(gogoproto.nullable) = false];
  // Binding key matched against a message's routing key.
  optional string key = 4 [(gogoproto.nullable) = false];
  optional amqp.Table arguments = 5;
  // Presumably true for bindings on topic exchanges — confirm in binding.go.
  optional bool topic = 6 [(gogoproto.nullable) = false];
}
33 |
// Persisted attributes of a queue (embedded in queue.Queue). Exclusive and
// auto-delete are not stored, so queue.NewFromPersistedState restores queues
// with both set to false.
message QueueState {
  option (gogoproto.goproto_unrecognized) = false;
  option (gogoproto.goproto_getters) = false;
  optional string name = 1 [(gogoproto.nullable) = false];
  optional bool durable = 2 [(gogoproto.nullable) = false];
  // Declare-time arguments, compared by EquivalentQueues.
  optional amqp.Table arguments = 3;
}
--------------------------------------------------------------------------------
/msgstore/msgstore_test.go:
--------------------------------------------------------------------------------
1 | package msgstore
2 |
3 | import (
4 | // "container/list"
5 | "fmt"
6 | "github.com/jeffjenkins/dispatchd/amqp"
7 | "os"
8 | "testing"
9 | )
10 |
11 | func TestWrite(t *testing.T) {
12 | // Setup
13 | var dbFile = "TestWrite.db"
14 | os.Remove(dbFile)
15 | defer os.Remove(dbFile)
16 | rhs := []amqp.MessageResourceHolder{&TestResourceHolder{}}
17 | ms, err := NewMessageStore(dbFile)
18 | // Create messages
19 | msg1 := amqp.RandomMessage(true)
20 | msg2 := amqp.RandomMessage(true)
21 | fmt.Printf("Creating ids: %d, %d\n", msg1.Id, msg2.Id)
22 |
23 | // Store messages and delete one
24 | fmt.Println("Adding message 1")
25 | ms.AddMessage(msg1, []string{"some-queue", "some-other-queue"})
26 | fmt.Println("Adding message 2")
27 | qm2Map, err := ms.AddMessage(msg2, []string{"some-queue"})
28 | _, err = ms.GetAndDecrRef(qm2Map["some-queue"][0], "some-queue", rhs)
29 | if err != nil {
30 | t.Errorf(err.Error())
31 | return
32 | }
33 |
34 | // Close DB
35 | ms.db.Close()
36 | keys := map[int64]bool{
37 | msg1.Id: true,
38 | // msg2.Id: true,
39 | }
40 | // Assert that the DB is correct
41 | err = assertKeys(dbFile, keys)
42 | if err != nil {
43 | t.Errorf(err.Error())
44 | return
45 | }
46 |
47 | // try loading from disk in the message store
48 | ms2, err := NewMessageStore(dbFile)
49 | _, err = ms2.LoadQueueFromDisk("some-queue")
50 | if err != nil {
51 | t.Errorf(err.Error())
52 | return
53 | }
54 | _, err = ms2.LoadQueueFromDisk("some-other-queue")
55 | if err != nil {
56 | t.Errorf(err.Error())
57 | return
58 | }
59 | }
60 |
--------------------------------------------------------------------------------
/msgstore/testlib.go:
--------------------------------------------------------------------------------
1 | package msgstore
2 |
3 | import (
4 | "fmt"
5 | "github.com/boltdb/bolt"
6 | "github.com/jeffjenkins/dispatchd/amqp"
7 | "reflect"
8 | )
9 |
// idsToMap is an unimplemented placeholder ("ids to map"): it takes no
// arguments and does nothing.
// NOTE(review): it has no callers in this file; consider removing it.
func idsToMap() {}
14 |
15 | // Check if the msg store is composed of exactly these keys
16 | func assertKeys(dbName string, keys map[int64]bool) error {
17 | // Open DB
18 | db, err := bolt.Open(dbName, 0600, nil)
19 | defer db.Close()
20 | if err != nil {
21 | return err
22 | }
23 | err = db.View(func(tx *bolt.Tx) error {
24 | // Check index
25 | bucket := tx.Bucket([]byte("message_index"))
26 | if bucket == nil {
27 | return nil
28 | }
29 |
30 | // get from db
31 | var indexKeys, err1 = keysForBucket(tx, MESSAGE_INDEX_BUCKET)
32 | var contentKeys, err2 = keysForBucket(tx, MESSAGE_CONTENT_BUCKET)
33 | if err1 != nil {
34 | return err1
35 | }
36 | if err2 != nil {
37 | return err2
38 | }
39 |
40 | // Check equality
41 | // TODO: return key diff
42 | indexNotKeys := subtract(indexKeys, keys)
43 | keysNotIndex := subtract(keys, indexKeys)
44 | contentNotKeys := subtract(contentKeys, keys)
45 | keysNotContent := subtract(keys, contentKeys)
46 |
47 | if !reflect.DeepEqual(keys, indexKeys) {
48 | return fmt.Errorf("Different values in index!\nindexNotKeys:%q\nkeysNotIndex:%q", indexNotKeys, keysNotIndex)
49 | }
50 | if !reflect.DeepEqual(keys, contentKeys) {
51 | return fmt.Errorf("Different values in content!\ncontentNotKeys:%q\nkeysNotContent:%q", contentNotKeys, keysNotContent)
52 | }
53 | return nil
54 | })
55 | if err != nil {
56 | return err
57 | }
58 | return nil
59 | }
60 |
// subtract returns the ids present in original but absent from subtractThis,
// in unspecified (map iteration) order. The result is never nil.
func subtract(original map[int64]bool, subtractThis map[int64]bool) []int64 {
	missing := []int64{}
	for id := range original {
		if _, found := subtractThis[id]; !found {
			missing = append(missing, id)
		}
	}
	return missing
}
71 |
72 | func keysForBucket(tx *bolt.Tx, bucketName []byte) (map[int64]bool, error) {
73 | // Check index
74 | bucket := tx.Bucket(bucketName)
75 | if bucket == nil {
76 | return nil, fmt.Errorf("No bucket!")
77 | }
78 | var cursor = bucket.Cursor()
79 | var keys = make(map[int64]bool)
80 | for bid, _ := cursor.First(); bid != nil; bid, _ = cursor.Next() {
81 | fmt.Printf("%s, key:%d\n", bucketName, bytesToInt64(bid))
82 | keys[bytesToInt64(bid)] = true
83 | }
84 | return keys, nil
85 | }
86 |
87 | type TestResourceHolder struct {
88 | }
89 |
90 | func (trh *TestResourceHolder) AcquireResources(qm *amqp.QueueMessage) bool {
91 | return true
92 | }
93 | func (trh *TestResourceHolder) ReleaseResources(qm *amqp.QueueMessage) {
94 |
95 | }
96 |
--------------------------------------------------------------------------------
/persist/persist.go:
--------------------------------------------------------------------------------
1 | package persist
2 |
3 | import (
4 | "fmt"
5 | "github.com/boltdb/bolt"
6 | "github.com/gogo/protobuf/proto"
7 | )
8 |
// UnmarshalerFactory creates empty protobuf values for LoadAll to decode
// stored records into. LoadAll keeps every value it decodes in its result map,
// so New should return a fresh value on each call.
type UnmarshalerFactory interface {
	New() proto.Unmarshaler
}
12 |
13 | //
14 | // Persist
15 | //
16 |
17 | func PersistOne(db *bolt.DB, bucketName []byte, key string, obj proto.Marshaler) error {
18 | return db.Update(func(tx *bolt.Tx) error {
19 | bucket, err := tx.CreateBucketIfNotExists(bucketName)
20 | if err != nil { // pragma: nocover
21 | return fmt.Errorf("create bucket: %s", err)
22 | }
23 | return PersistOneBoltTx(bucket, key, obj)
24 | })
25 | }
26 |
27 | func PersistOneBoltTx(bucket *bolt.Bucket, key string, obj proto.Marshaler) error {
28 | exBytes, err := obj.Marshal()
29 | if err != nil { // pragma: nocover -- no idea how to produce this error
30 | return fmt.Errorf("Could not marshal object")
31 | }
32 | return bucket.Put([]byte(key), exBytes)
33 | }
34 |
35 | func PersistMany(db *bolt.DB, bucketName []byte, objs map[string]proto.Marshaler) error {
36 | return db.Update(func(tx *bolt.Tx) error {
37 | bucket, err := tx.CreateBucketIfNotExists(bucketName)
38 | if err != nil { // pragma: nocover
39 | return fmt.Errorf("create bucket: %s", err)
40 | }
41 | return PersistManyBoltTx(bucket, objs)
42 | })
43 | }
44 |
45 | func PersistManyBoltTx(bucket *bolt.Bucket, objs map[string]proto.Marshaler) error {
46 | for key, obj := range objs {
47 | err := PersistOneBoltTx(bucket, key, obj)
48 | if err != nil {
49 | return err
50 | }
51 | }
52 | return nil
53 | }
54 |
55 | //
56 | // Load
57 | //
58 |
59 | func LoadOne(db *bolt.DB, bucketName []byte, key string, obj proto.Unmarshaler) error {
60 | return db.Update(func(tx *bolt.Tx) error {
61 | bucket := tx.Bucket(bucketName)
62 | if bucket == nil {
63 | return fmt.Errorf("Bucket not found: '%s'", bucket)
64 | }
65 | return LoadOneBoltTx(bucket, key, obj)
66 | })
67 | }
68 |
69 | func LoadOneBoltTx(bucket *bolt.Bucket, key string, obj proto.Unmarshaler) error {
70 | objBytes := bucket.Get([]byte(key))
71 | if objBytes == nil {
72 | return fmt.Errorf("Key not found: '%s'", key)
73 | }
74 | err := obj.Unmarshal(objBytes)
75 | if err != nil {
76 | return fmt.Errorf("Could not unmarshal key %s", key)
77 | }
78 | return nil
79 | }
80 |
81 | func LoadMany(db *bolt.DB, bucketName []byte, objs map[string]proto.Unmarshaler) error {
82 | return db.Update(func(tx *bolt.Tx) error {
83 | bucket := tx.Bucket(bucketName)
84 | if bucket == nil { // pragma: nocover
85 | return fmt.Errorf("create bucket: '%s'", bucket)
86 | }
87 | return LoadManyBoltTx(bucket, objs)
88 | })
89 | }
90 |
91 | func LoadManyBoltTx(bucket *bolt.Bucket, objs map[string]proto.Unmarshaler) error {
92 | for key, obj := range objs {
93 | err := LoadOneBoltTx(bucket, key, obj)
94 | if err != nil {
95 | return err
96 | }
97 | }
98 | return nil
99 | }
100 |
101 | func LoadAll(db *bolt.DB, bucket []byte, factory UnmarshalerFactory) (map[string]proto.Unmarshaler, error) {
102 | ret := make(map[string]proto.Unmarshaler)
103 | err := db.View(func(tx *bolt.Tx) error {
104 | bucket := tx.Bucket(bucket)
105 | if bucket == nil {
106 | return nil
107 | }
108 | // iterate through queues
109 | cursor := bucket.Cursor()
110 | for name, data := cursor.First(); name != nil; name, data = cursor.Next() {
111 | obj := factory.New()
112 | err := obj.Unmarshal(data)
113 | if err != nil {
114 | return fmt.Errorf("Could not unmarshal key %s", string(name))
115 | }
116 | ret[string(name)] = obj
117 | }
118 | return nil
119 | })
120 | if err != nil {
121 | return nil, err
122 | }
123 | return ret, nil
124 | }
125 |
126 | //
127 | // Depersist
128 | //
129 |
130 | func DepersistOne(db *bolt.DB, bucketName []byte, key string) error {
131 | return db.Update(func(tx *bolt.Tx) error {
132 | bucket := tx.Bucket(bucketName)
133 | if bucket == nil {
134 | return fmt.Errorf("Bucket not found: '%s'", bucket)
135 | }
136 | return DepersistOneBoltTx(bucket, key)
137 | })
138 | }
139 |
140 | func DepersistOneBoltTx(bucket *bolt.Bucket, key string) error {
141 | return bucket.Delete([]byte(key))
142 | }
143 |
144 | func DepersistMany(db *bolt.DB, bucketName []byte, keys map[string]bool) error {
145 | return db.Update(func(tx *bolt.Tx) error {
146 | bucket := tx.Bucket(bucketName)
147 | if bucket == nil { // pragma: nocover
148 | return fmt.Errorf("create bucket: '%s'", bucket)
149 | }
150 | return DepersistManyBoltTx(bucket, keys)
151 | })
152 | }
153 |
154 | func DepersistManyBoltTx(bucket *bolt.Bucket, keys map[string]bool) error {
155 | for key, _ := range keys {
156 | err := DepersistOneBoltTx(bucket, key)
157 | if err != nil {
158 | return err
159 | }
160 | }
161 | return nil
162 | }
163 |
--------------------------------------------------------------------------------
/persist/persist_test.go:
--------------------------------------------------------------------------------
1 | package persist
2 |
3 | import (
4 | "testing"
5 | )
6 |
// TestPersist is a placeholder for persist package coverage. It skips
// explicitly so the missing test is visible in `go test -v` output instead of
// passing silently.
func TestPersist(t *testing.T) {
	t.Skip("TODO: test PersistOne/LoadOne/LoadAll/DepersistOne round trips")
}
10 |
--------------------------------------------------------------------------------
/queue/queue.go:
--------------------------------------------------------------------------------
1 | package queue
2 |
3 | import (
4 | "container/list"
5 | "encoding/json"
6 | "errors"
7 | "fmt"
8 | "github.com/boltdb/bolt"
9 | "github.com/gogo/protobuf/proto"
10 | "github.com/jeffjenkins/dispatchd/amqp"
11 | "github.com/jeffjenkins/dispatchd/consumer"
12 | "github.com/jeffjenkins/dispatchd/gen"
13 | "github.com/jeffjenkins/dispatchd/msgstore"
14 | "github.com/jeffjenkins/dispatchd/persist"
15 | "github.com/jeffjenkins/dispatchd/stats"
16 | "sync"
17 | "time"
18 | )
19 |
20 | var QUEUE_BUCKET_NAME = []byte("queues")
21 |
22 | type QueueStateFactory struct{}
23 |
24 | func (qsf *QueueStateFactory) New() proto.Unmarshaler {
25 | return &gen.QueueState{}
26 | }
27 |
// Queue is the in-memory representation of an AMQP queue. The persisted
// attributes (name, durability, arguments) live in the embedded
// gen.QueueState; every other field is runtime-only state.
type Queue struct {
	gen.QueueState
	// autoDelete and exclusive are set by NewQueue. They are not part of
	// QueueState, so NewFromPersistedState leaves both false.
	autoDelete bool
	exclusive  bool
	Closed     bool
	objLock    sync.RWMutex
	// queue holds pending messages; the trailing note says the elements are
	// int64 (presumably message ids — confirm against the queue methods).
	queue     *list.List // int64
	queueLock sync.Mutex
	// consumers and currentConsumer presumably support round-robin delivery
	// under consumerLock — confirm against the queue methods.
	consumerLock    sync.RWMutex
	consumers       []*consumer.Consumer // *Consumer
	currentConsumer int
	statCount       uint64
	// maybeReady is a 1-slot bool channel created by both constructors.
	maybeReady   chan bool
	soleConsumer *consumer.Consumer
	// ConnId is set from NewQueue's connId and is -1 for queues restored by
	// NewFromPersistedState.
	ConnId          int64
	deleteActive    time.Time
	hasHadConsumers bool
	msgStore        *msgstore.MessageStore
	// statProcOne is the "queue-proc-one" histogram.
	statProcOne stats.Histogram
	deleteChan  chan *Queue
}
49 |
// NewQueue builds an empty Queue with no consumers from the given declaration
// attributes. Call Start to run its dispatch loop.
func NewQueue(
	name string,
	durable bool,
	exclusive bool,
	autoDelete bool,
	arguments *amqp.Table,
	connId int64,
	msgStore *msgstore.MessageStore,
	deleteChan chan *Queue,
) *Queue {
	q := &Queue{
		statProcOne: stats.MakeHistogram("queue-proc-one"),
		queue:       list.New(),
		consumers:   make([]*consumer.Consumer, 0, 1),
		maybeReady:  make(chan bool, 1),
	}
	q.QueueState = gen.QueueState{Name: name, Durable: durable, Arguments: arguments}
	q.exclusive = exclusive
	q.autoDelete = autoDelete
	q.ConnId = connId
	q.msgStore = msgStore
	q.deleteChan = deleteChan
	return q
}
78 |
// NewFromPersistedState rebuilds a Queue from a stored gen.QueueState.
// Restored queues are neither exclusive nor auto-delete and belong to no
// connection (ConnId is -1). No messages are loaded here; see LoadFromMsgStore.
func NewFromPersistedState(state *gen.QueueState, msgStore *msgstore.MessageStore, deleteChan chan *Queue) *Queue {
	q := &Queue{
		ConnId:      -1,
		msgStore:    msgStore,
		deleteChan:  deleteChan,
		statProcOne: stats.MakeHistogram("queue-proc-one"),
		queue:       list.New(),
		consumers:   make([]*consumer.Consumer, 0, 1),
		maybeReady:  make(chan bool, 1),
	}
	// exclusive and autoDelete keep their zero value (false).
	q.QueueState = *state
	return q
}
94 |
95 | func (q1 *Queue) EquivalentQueues(q2 *Queue) bool {
96 | if q1 == nil {
97 | return q2 == nil
98 | }
99 | if q2 == nil {
100 | return false
101 | }
102 |
103 | // Note: autodelete is not included since the spec says to ignore
104 | // the field if the queue is already created
105 | if q1.Name != q2.Name {
106 | return false
107 | }
108 | if q1.Durable != q2.Durable {
109 | return false
110 | }
111 | if q1.exclusive != q2.exclusive {
112 | return false
113 | }
114 | if !amqp.EquivalentTables(q1.Arguments, q2.Arguments) {
115 | return false
116 | }
117 | return true
118 | }
119 |
120 | func (q *Queue) Len() uint32 {
121 | var l = q.queue.Len()
122 | if l < 0 {
123 | panic("Queue length overflow!")
124 | }
125 | return uint32(l)
126 | }
127 |
128 | func (q *Queue) ActiveConsumerCount() uint32 {
129 | // TODO(MUST): don't count consumers in the Channel.Flow state once
130 | // that is implemented
131 | return uint32(len(q.consumers))
132 | }
133 |
// MarshalJSON renders a summary of the queue as a JSON object: name, flags,
// owning connection id, current message count and attached consumers.
func (q *Queue) MarshalJSON() ([]byte, error) {
	summary := make(map[string]interface{}, 7)
	summary["name"] = q.Name
	summary["durable"] = q.Durable
	summary["exclusive"] = q.exclusive
	summary["connId"] = q.ConnId
	summary["autoDelete"] = q.autoDelete
	summary["size"] = q.queue.Len()
	summary["consumers"] = q.consumers
	return json.Marshal(summary)
}
145 |
146 | func (q *Queue) Persist(db *bolt.DB) error {
147 | return persist.PersistOne(db, QUEUE_BUCKET_NAME, q.Name, q)
148 | }
149 |
150 | func (q *Queue) Depersist(db *bolt.DB) error {
151 | return persist.DepersistOne(db, QUEUE_BUCKET_NAME, q.Name)
152 | }
153 |
// DepersistBoltTx removes the queue's record within an existing bolt
// transaction, creating QUEUE_BUCKET_NAME first if it is missing.
func (q *Queue) DepersistBoltTx(tx *bolt.Tx) error {
	bucket, err := tx.CreateBucketIfNotExists(QUEUE_BUCKET_NAME)
	if err == nil {
		return persist.DepersistOneBoltTx(bucket, q.Name)
	}
	return fmt.Errorf("create bucket: %s", err)
}
161 |
162 | func LoadAllQueues(db *bolt.DB, msgStore *msgstore.MessageStore, deleteChan chan *Queue) (map[string]*Queue, error) {
163 | queueStateMap, err := persist.LoadAll(db, QUEUE_BUCKET_NAME, &QueueStateFactory{})
164 | if err != nil {
165 | return nil, err
166 | }
167 | var ret = make(map[string]*Queue)
168 | for key, state := range queueStateMap {
169 | ret[key] = NewFromPersistedState(state.(*gen.QueueState), msgStore, deleteChan)
170 | }
171 | return ret, nil
172 | }
173 |
174 | func (q *Queue) LoadFromMsgStore(msgStore *msgstore.MessageStore) {
175 | queueList, err := msgStore.LoadQueueFromDisk(q.Name)
176 | if err != nil {
177 | panic("Integrity error reading queue from disk! " + err.Error())
178 | }
179 | q.queue = queueList
180 | select {
181 | case q.maybeReady <- true:
182 | default:
183 | }
184 | }
185 |
186 | func (q *Queue) Close() {
187 | // This discards any messages which would be added. It does not
188 | // do cleanup
189 | q.queueLock.Lock()
190 | defer q.queueLock.Unlock()
191 | q.Closed = true
192 | }
193 |
194 | func (q *Queue) Purge() uint32 {
195 | q.queueLock.Lock()
196 | defer q.queueLock.Unlock()
197 | return q.purgeNotThreadSafe()
198 | }
199 |
200 | func (q *Queue) purgeNotThreadSafe() uint32 {
201 | var length = q.queue.Len()
202 | q.queue.Init()
203 | return uint32(length)
204 | }
205 |
206 | func (q *Queue) Add(qm *amqp.QueueMessage) bool {
207 | // NOTE: I tried using consumeImmediate before adding things to the queue,
208 | // but it caused a pretty significant slowdown.
209 | q.queueLock.Lock()
210 | defer q.queueLock.Unlock()
211 | if !q.Closed {
212 | q.statCount += 1
213 | q.queue.PushBack(qm)
214 | select {
215 | case q.maybeReady <- true:
216 | default:
217 | }
218 | return true
219 | } else {
220 | return false
221 | }
222 | }
223 |
224 | func (q *Queue) ConsumeImmediate(qm *amqp.QueueMessage) bool {
225 | // TODO: randomize or round-robin through consumers
226 | q.consumerLock.RLock()
227 | defer q.consumerLock.RUnlock()
228 | for _, consumer := range q.consumers {
229 | var msg, acquired = q.msgStore.Get(qm, consumer.MessageResourceHolders())
230 | if acquired {
231 | consumer.ConsumeImmediate(qm, msg)
232 | return true
233 | }
234 | }
235 | return false
236 | }
237 |
238 | func (q *Queue) Delete(ifUnused bool, ifEmpty bool) (uint32, error) {
239 | // Lock
240 | if !q.Closed {
241 | panic("Queue deleted before it was closed!")
242 | }
243 | q.queueLock.Lock()
244 | defer q.queueLock.Unlock()
245 |
246 | // Check
247 | var usedOk = !ifUnused || len(q.consumers) == 0
248 | var emptyOk = !ifEmpty || q.queue.Len() == 0
249 | if !usedOk {
250 | return 0, errors.New("if-unused specified and there are consumers")
251 | }
252 | if !emptyOk {
253 | return 0, errors.New("if-empty specified and there are messages in the queue")
254 | }
255 | // Purge
256 | q.cancelConsumers()
257 | return q.purgeNotThreadSafe(), nil
258 | }
259 |
260 | func (q *Queue) Readd(queueName string, msg *amqp.QueueMessage) {
261 | // TODO: if there is a consumer available, dispatch
262 | q.queueLock.Lock()
263 | defer q.queueLock.Unlock()
264 | // this method is only called when we get a nack or we shut down a channel,
265 | // so it means the message was not acked.
266 | q.msgStore.IncrDeliveryCount(queueName, msg)
267 | q.queue.PushFront(msg)
268 | select {
269 | case q.maybeReady <- true:
270 | default:
271 | }
272 | }
273 |
274 | func (q *Queue) removeConsumer(consumerTag string) {
275 | q.consumerLock.Lock()
276 | defer q.consumerLock.Unlock()
277 | if q.soleConsumer != nil && q.soleConsumer.ConsumerTag == consumerTag {
278 | q.soleConsumer = nil
279 | }
280 | // remove from list
281 | for i, c := range q.consumers {
282 | if c.ConsumerTag == consumerTag {
283 | q.consumers = append(q.consumers[:i], q.consumers[i+1:]...)
284 | }
285 | }
286 | var size = len(q.consumers)
287 | if size == 0 {
288 | q.currentConsumer = 0
289 | if q.autoDelete && q.hasHadConsumers {
290 | go q.autodeleteTimeout()
291 | }
292 | } else {
293 | q.currentConsumer = q.currentConsumer % size
294 | }
295 |
296 | }
297 |
298 | func (q *Queue) autodeleteTimeout() {
299 | // There's technically a race condition here where a new binding could be
300 | // added right as we check this, but after a 5 second wait with no activity
301 | // I think this is probably safe enough.
302 | var now = time.Now()
303 | q.deleteActive = now
304 | time.Sleep(5 * time.Second)
305 | if q.deleteActive == now {
306 | q.deleteChan <- q
307 | }
308 | }
309 |
310 | func (q *Queue) cancelConsumers() {
311 | q.consumerLock.Lock()
312 | defer q.consumerLock.Unlock()
313 | q.soleConsumer = nil
314 | // Send cancel to each consumer
315 | for _, c := range q.consumers {
316 | c.SendCancel()
317 | c.Stop()
318 | }
319 | q.consumers = make([]*consumer.Consumer, 0, 1)
320 | }
321 |
// AddConsumer attaches c to the queue and cancels any pending auto-delete
// (autodeleteTimeout only fires if deleteActive is unchanged).
//
// With exclusive set, the request is refused with code 403 unless the queue
// has no consumers; on success c becomes the queue's sole consumer.
// On a closed queue the call does nothing and returns (0, nil).
func (q *Queue) AddConsumer(c *consumer.Consumer, exclusive bool) (uint16, error) {
	if q.Closed {
		return 0, nil
	}
	// Reset auto-delete
	q.deleteActive = time.Unix(0, 0)

	q.consumerLock.Lock()
	// defer releases the lock on every path. Returning from the refusal
	// branch while still holding it would deadlock every later user of
	// consumerLock.
	defer q.consumerLock.Unlock()
	if exclusive {
		if len(q.consumers) != 0 {
			return 403, fmt.Errorf("Exclusive access denied, %d consumers active", len(q.consumers))
		}
		q.soleConsumer = c
	}
	q.consumers = append(q.consumers, c)
	q.hasHadConsumers = true
	return 0, nil
}
343 |
// Start launches the dispatch goroutine. It first primes maybeReady so that
// messages already present get processed. It then calls processOne for each
// wake-up, and exits on a wake-up that finds the queue closed.
func (q *Queue) Start() {
	go func() {
		select {
		case q.maybeReady <- true:
		default:
		}
		for range q.maybeReady {
			if q.Closed {
				fmt.Printf("Queue closed!\n")
				return
			}
			q.processOne()
		}
	}()
}
359 |
360 | func (q *Queue) MaybeReady() chan bool {
361 | return q.maybeReady
362 | }
363 |
364 | func (q *Queue) processOne() {
365 | defer stats.RecordHisto(q.statProcOne, stats.Start())
366 | q.consumerLock.RLock()
367 | defer q.consumerLock.RUnlock()
368 | var size = len(q.consumers)
369 | if size == 0 {
370 | return
371 | }
372 | for count := 0; count < size; count++ {
373 | q.currentConsumer = (q.currentConsumer + 1) % size
374 | var c = q.consumers[q.currentConsumer]
375 | c.Ping()
376 | }
377 | }
378 |
379 | func (q *Queue) GetOneForced() *amqp.QueueMessage {
380 | q.queueLock.Lock()
381 | defer q.queueLock.Unlock()
382 | if q.queue.Len() == 0 {
383 | return nil
384 | }
385 | qMsg := q.queue.Remove(q.queue.Front()).(*amqp.QueueMessage)
386 | return qMsg
387 | }
388 |
389 | func (q *Queue) GetOne(rhs ...amqp.MessageResourceHolder) (*amqp.QueueMessage, *amqp.Message) {
390 | q.queueLock.Lock()
391 | defer q.queueLock.Unlock()
392 | // Empty check
393 | if q.queue.Len() == 0 || q.Closed {
394 | return nil, nil
395 | }
396 |
397 | // Get one message. If there is a message try to acquire the resources
398 | // from the channel.
399 | var qm = q.queue.Front().Value.(*amqp.QueueMessage)
400 |
401 | var msg, acquired = q.msgStore.Get(qm, rhs)
402 | if acquired {
403 | q.queue.Remove(q.queue.Front())
404 | return qm, msg
405 | }
406 | return nil, nil
407 | }
408 |
--------------------------------------------------------------------------------
/queue/queue_test.go:
--------------------------------------------------------------------------------
1 | package queue
2 |
3 | import (
4 | "testing"
5 | )
6 |
// TestQueueAdd is a placeholder for Queue.Add tests; it performs no checks yet.
func TestQueueAdd(t *testing.T) {
	// intentionally empty
}
10 |
--------------------------------------------------------------------------------
/scripts/benchmark_helper.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# Runs the RabbitMQ PerfTest Java client (from scripts/external/perf-client)
# against a dispatchd server on localhost:$RUN_PORT. Any extra arguments are
# passed through to PerfTest.

set -x

# Fail fast with a clear message instead of building a URI with no port.
: "${RUN_PORT:?RUN_PORT must be set to the server port}"

# Don't run ./runjava.sh from the wrong directory if the cd fails.
cd scripts/external/perf-client/ || exit 1
./runjava.sh com.rabbitmq.examples.PerfTest \
  --exchange perf-test \
  -uri "amqp://guest:guest@localhost:${RUN_PORT}" \
  --queue perf-test-transient \
  --consumers 4 \
  --producers 2 \
  --qos 20 \
  --time 20 "$@"
14 |
--------------------------------------------------------------------------------
/scripts/cover.py:
--------------------------------------------------------------------------------
1 | from collections import defaultdict
2 | import sys
3 | import subprocess
4 | import os
5 |
PREFIX = 'github.com/jeffjenkins/dispatchd/'


class Colors(object):
    """ANSI terminal escape sequences used to colour the coverage report."""
    PURPLE, BLUE, GREEN, YELLOW, RED = (
        '\033[95m', '\033[94m', '\033[92m', '\033[93m', '\033[91m')
    ENDC = '\033[0m'
    BOLD = '\033[1m'
    UNDERLINE = '\033[4m'
17 |
def main(args):
    """Run ``go test`` with coverage for every package under PREFIX.

    Per-package profiles are written to $GOPATH/cover/<pkg>.cover. They are
    merged with $GOPATH/bin/gocovmerge into $GOPATH/cover/all.cover, and a
    summary of uncovered lines is printed. ``args`` is currently unused.
    """
    gopath = os.environ['GOPATH']
    output_dir = os.path.join(gopath, 'cover')
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    # universal_newlines=True makes check_output return text on both
    # Python 2 and 3, so split('\n') and the file write below work on either.
    listing = subprocess.check_output(
        ['go', 'list', '{}...'.format(PREFIX)], universal_newlines=True)
    test_packages = [t.strip().replace(PREFIX, '') for t in listing.split('\n')]
    test_packages = [t for t in test_packages if t]

    cover_names = []
    for pkg in test_packages:
        cover_name = os.path.join(output_dir, '{}.cover'.format(pkg))
        cover_names.append((pkg, cover_name))
        subprocess.check_call([
            'go',
            'test',
            '-coverpkg={}...'.format(PREFIX),
            '-coverprofile={}'.format(cover_name),
            os.path.join(PREFIX, pkg),
        ])

    # Packages without tests produce no profile, so only merge existing files.
    merge_call = [os.path.join(gopath, 'bin', 'gocovmerge')]
    merge_call.extend(n for _, n in cover_names if os.path.exists(n))
    merged = subprocess.check_output(merge_call, universal_newlines=True)
    all_cover = os.path.join(output_dir, 'all.cover')
    with open(all_cover, 'w') as f:
        f.write(merged)
    cover_summary([all_cover])
55 |
def nocover(file, line):
    """Return True if the 1-based ``line`` of ``file`` (a path relative to
    $GOPATH/src) contains the marker ``pragma: nocover``."""
    path = os.path.join(os.environ['GOPATH'], 'src', file)
    with open(path) as inf:
        source_lines = inf.readlines()
    return 'pragma: nocover' in source_lines[int(line) - 1]
62 |
def count_lines(file):
    """Return the number of lines in ``file`` (a path relative to $GOPATH/src)."""
    path = os.path.join(os.environ['GOPATH'], 'src', file)
    with open(path) as inf:
        return sum(1 for _ in inf)
66 |
def cover_summary(cover_names):
    """Print the uncovered line ranges found in the Go cover profiles.

    Each data line of a profile reads "file:l1.c1,l2.c2 numStmts count".
    Blocks with count 0 are reported per file unless their first line carries
    ``pragma: nocover``. Protobuf, generated and testlib files are skipped.
    Finally, the total number of uncovered lines is printed against the total
    line count of the files that had any.
    """
    def say(*parts):
        # Same output as Python 2's ``print a, b, c`` (space separated), but
        # valid under Python 3 as well.
        sys.stdout.write(' '.join(str(p) for p in parts) + '\n')

    say(Colors.BLUE, '===== Missing Coverage =====', Colors.ENDC)
    missing = defaultdict(list)
    total_missing = 0
    total_lines = 0
    seen = set()
    for name in cover_names:
        with open(name) as inf:
            for raw in inf:
                entry = raw.strip()
                # Skip blank lines and the leading "mode: ..." header.
                if not entry or entry.startswith('mode:'):
                    continue
                full_file, _, report = entry.partition(':')
                span, _, counts = report.partition(' ')
                # Compare the whole hit-count field. Testing only the last
                # character would treat counts like 10 or 20 (count/atomic
                # mode) as zero.
                if counts.rsplit(' ', 1)[-1] != '0':
                    continue
                rel_file = full_file.replace(PREFIX, '')
                if 'pb.go' in rel_file or 'generated' in rel_file or 'testlib' in rel_file:
                    continue
                first, _, second = span.partition(',')
                first_line = first.partition('.')[0]
                second_line = second.partition('.')[0]
                line_range = (int(first_line), int(second_line))
                if full_file not in seen:
                    total_lines += count_lines(full_file)
                    seen.add(full_file)
                total_missing += line_range[1] - line_range[0] + 1
                if nocover(full_file, first_line):
                    continue
                missing[rel_file].append(line_range)

    # Only files with at least one uncovered, non-pragma block are recorded.
    for rel_file, ranges in sorted(missing.items()):
        spans = ', '.join('{}-{}'.format(lo, hi) for lo, hi in sorted(ranges))
        say(Colors.RED, rel_file + ':', Colors.ENDC, spans)
    say(Colors.RED, 'Remaining lines on files without full coverage:',
        total_missing, '/', total_lines, Colors.ENDC)
108 |
109 |
110 |
111 |
if __name__ == '__main__':
    # Script entry point: pass the command-line arguments without argv[0].
    cli_args = sys.argv[1:]
    main(cli_args)
--------------------------------------------------------------------------------
/server/auth.go:
--------------------------------------------------------------------------------
1 | package server
2 |
3 | import (
4 | "bytes"
5 | "encoding/base64"
6 | "golang.org/x/crypto/bcrypt"
7 | )
8 |
// User is a configured login. password holds the user's bcrypt hash as raw
// bytes, which authenticate checks with bcrypt.CompareHashAndPassword.
type User struct {
	name     string
	password []byte
}
13 |
14 | func (s *Server) addUsers(userJson map[string]interface{}) {
15 | if len(userJson) == 0 {
16 | decoded, err := base64.StdEncoding.DecodeString(defaultUserPasswordBase64)
17 | if err != nil {
18 | panic("System integrity error: Could not base64 decode password for built in default user!")
19 | }
20 | s.users[defaultUserName] = User{name: defaultUserName, password: decoded}
21 | }
22 | for name, user := range userJson {
23 | var encoded = user.(map[string]interface{})["password_bcrypt_base64"].(string)
24 | decoded, err := base64.StdEncoding.DecodeString(encoded)
25 | if err != nil {
26 | panic("Could not base64 decode password for default config file user: " + name)
27 | }
28 | s.users[name] = User{name: name, password: decoded}
29 | }
30 | }
31 |
// Built-in fallback account (guest/guest), used when the config defines no
// users. The password is stored as a base64-encoded bcrypt hash.
var (
	defaultUserName           = "guest"
	defaultUserPasswordBase64 = "JDJhJDExJENobGk4dG5rY0RGemJhTjhsV21xR3VNNnFZZ1ZqTzUzQWxtbGtyMHRYN3RkUHMuYjF5SUt5"
)
35 |
// authenticate checks a SASL PLAIN response against the configured users.
//
// blob must split on NUL bytes into exactly three parts
// (authzid, authcid, password); anything else is rejected. The authcid
// selects the user, and the password is verified against that user's bcrypt
// hash. mechanism is not currently inspected.
func (s *Server) authenticate(mechanism string, blob []byte) bool {
	parts := bytes.Split(blob, []byte{0})
	if len(parts) != 3 {
		return false
	}
	// s.users is keyed by user name, so look the user up directly rather
	// than scanning every entry.
	user, found := s.users[string(parts[1])]
	if !found {
		return false
	}
	return bcrypt.CompareHashAndPassword(user.password, parts[2]) == nil
}
54 |
--------------------------------------------------------------------------------
/server/basicMethods.go:
--------------------------------------------------------------------------------
1 | package server
2 |
3 | import (
4 | "github.com/jeffjenkins/dispatchd/amqp"
5 | "github.com/jeffjenkins/dispatchd/stats"
6 | "github.com/jeffjenkins/dispatchd/util"
7 | )
8 |
9 | func (channel *Channel) basicRoute(methodFrame amqp.MethodFrame) *amqp.AMQPError {
10 | switch method := methodFrame.(type) {
11 | case *amqp.BasicQos:
12 | return channel.basicQos(method)
13 | case *amqp.BasicRecover:
14 | return channel.basicRecover(method)
15 | case *amqp.BasicNack:
16 | return channel.basicNack(method)
17 | case *amqp.BasicConsume:
18 | return channel.basicConsume(method)
19 | case *amqp.BasicCancel:
20 | return channel.basicCancel(method)
21 | case *amqp.BasicCancelOk:
22 | return channel.basicCancelOk(method)
23 | case *amqp.BasicPublish:
24 | return channel.basicPublish(method)
25 | case *amqp.BasicGet:
26 | return channel.basicGet(method)
27 | case *amqp.BasicAck:
28 | return channel.basicAck(method)
29 | case *amqp.BasicReject:
30 | return channel.basicReject(method)
31 | }
32 | var classId, methodId = methodFrame.MethodIdentifier()
33 | return amqp.NewHardError(540, "Unable to route method frame", classId, methodId)
34 | }
35 |
36 | func (channel *Channel) basicQos(method *amqp.BasicQos) *amqp.AMQPError {
37 | channel.setPrefetch(method.PrefetchCount, method.PrefetchSize, method.Global)
38 | channel.SendMethod(&amqp.BasicQosOk{})
39 | return nil
40 | }
41 |
42 | func (channel *Channel) basicRecover(method *amqp.BasicRecover) *amqp.AMQPError {
43 | channel.recover(method.Requeue)
44 | channel.SendMethod(&amqp.BasicRecoverOk{})
45 | return nil
46 | }
47 |
48 | func (channel *Channel) basicNack(method *amqp.BasicNack) *amqp.AMQPError {
49 | if method.Multiple {
50 | return channel.nackBelow(method.DeliveryTag, method.Requeue, false)
51 | }
52 | return channel.nackOne(method.DeliveryTag, method.Requeue, false)
53 | }
54 |
55 | func (channel *Channel) basicConsume(method *amqp.BasicConsume) *amqp.AMQPError {
56 | var classId, methodId = method.MethodIdentifier()
57 | // Check queue
58 | if len(method.Queue) == 0 {
59 | if len(channel.lastQueueName) == 0 {
60 | return amqp.NewSoftError(404, "Queue not found", classId, methodId)
61 | } else {
62 | method.Queue = channel.lastQueueName
63 | }
64 | }
65 | // TODO: do not directly access channel.conn.server.queues
66 | var queue, found = channel.conn.server.queues[method.Queue]
67 | if !found {
68 | // Spec doesn't say, but seems like a 404?
69 | return amqp.NewSoftError(404, "Queue not found", classId, methodId)
70 | }
71 | if len(method.ConsumerTag) == 0 {
72 | method.ConsumerTag = util.RandomId()
73 | }
74 | amqpErr := channel.addConsumer(queue, method)
75 | if amqpErr != nil {
76 | return amqpErr
77 | }
78 | if !method.NoWait {
79 | channel.SendMethod(&amqp.BasicConsumeOk{method.ConsumerTag})
80 | }
81 |
82 | return nil
83 | }
84 |
85 | func (channel *Channel) basicCancel(method *amqp.BasicCancel) *amqp.AMQPError {
86 |
87 | if err := channel.removeConsumer(method.ConsumerTag); err != nil {
88 | var classId, methodId = method.MethodIdentifier()
89 | return amqp.NewSoftError(404, "Consumer not found", classId, methodId)
90 | }
91 |
92 | if !method.NoWait {
93 | channel.SendMethod(&amqp.BasicCancelOk{method.ConsumerTag})
94 | }
95 | return nil
96 | }
97 |
98 | func (channel *Channel) basicCancelOk(method *amqp.BasicCancelOk) *amqp.AMQPError {
99 | // TODO(MAY)
100 | var classId, methodId = method.MethodIdentifier()
101 | return amqp.NewHardError(540, "Not implemented", classId, methodId)
102 | }
103 |
104 | func (channel *Channel) basicPublish(method *amqp.BasicPublish) *amqp.AMQPError {
105 | defer stats.RecordHisto(channel.statPublish, stats.Start())
106 | var _, found = channel.server.exchanges[method.Exchange]
107 | if !found {
108 | var classId, methodId = method.MethodIdentifier()
109 | return amqp.NewSoftError(404, "Exchange not found", classId, methodId)
110 | }
111 | channel.startPublish(method)
112 | return nil
113 | }
114 |
115 | func (channel *Channel) basicGet(method *amqp.BasicGet) *amqp.AMQPError {
116 | // var classId, methodId = method.MethodIdentifier()
117 | // channel.conn.connectionErrorWithMethod(540, "Not implemented", classId, methodId)
118 | var queue, found = channel.conn.server.queues[method.Queue]
119 | if !found {
120 | // Spec doesn't say, but seems like a 404?
121 | var classId, methodId = method.MethodIdentifier()
122 | return amqp.NewSoftError(404, "Queue not found", classId, methodId)
123 | }
124 | var qm = queue.GetOneForced()
125 | if qm == nil {
126 | channel.SendMethod(&amqp.BasicGetEmpty{})
127 | return nil
128 | }
129 |
130 | var rhs = []amqp.MessageResourceHolder{channel}
131 | msg, err := channel.server.msgStore.GetAndDecrRef(qm, queue.Name, rhs)
132 | if err != nil {
133 | // TODO: return 500 error
134 | channel.SendMethod(&amqp.BasicGetEmpty{})
135 | return nil
136 | }
137 |
138 | channel.SendContent(&amqp.BasicGetOk{
139 | DeliveryTag: channel.nextDeliveryTag(),
140 | Redelivered: qm.DeliveryCount > 0,
141 | Exchange: msg.Exchange,
142 | RoutingKey: msg.Key,
143 | MessageCount: 1,
144 | }, msg)
145 | return nil
146 | }
147 |
148 | func (channel *Channel) basicAck(method *amqp.BasicAck) *amqp.AMQPError {
149 | if method.Multiple {
150 | return channel.ackBelow(method.DeliveryTag, false)
151 | }
152 | return channel.ackOne(method.DeliveryTag, false)
153 | }
154 |
155 | func (channel *Channel) basicReject(method *amqp.BasicReject) *amqp.AMQPError {
156 | return channel.nackOne(method.DeliveryTag, method.Requeue, false)
157 | }
158 |
--------------------------------------------------------------------------------
/server/channelMethods.go:
--------------------------------------------------------------------------------
1 | package server
2 |
3 | import (
4 | "github.com/jeffjenkins/dispatchd/amqp"
5 | )
6 |
7 | func (channel *Channel) channelRoute(methodFrame amqp.MethodFrame) *amqp.AMQPError {
8 | switch method := methodFrame.(type) {
9 | case *amqp.ChannelOpen:
10 | return channel.channelOpen(method)
11 | case *amqp.ChannelFlow:
12 | return channel.channelFlow(method)
13 | case *amqp.ChannelFlowOk:
14 | return channel.channelFlowOk(method)
15 | case *amqp.ChannelClose:
16 | return channel.channelClose(method)
17 | case *amqp.ChannelCloseOk:
18 | return channel.channelCloseOk(method)
19 | // case *amqp.ChannelOpenOk:
20 | // return channel.channelOpenOk(method)
21 | }
22 | var classId, methodId = methodFrame.MethodIdentifier()
23 | return amqp.NewHardError(540, "Unable to route method frame", classId, methodId)
24 | }
25 |
26 | func (channel *Channel) channelOpen(method *amqp.ChannelOpen) *amqp.AMQPError {
27 | if channel.state == CH_STATE_OPEN {
28 | var classId, methodId = method.MethodIdentifier()
29 | return amqp.NewHardError(504, "Channel already open", classId, methodId)
30 | }
31 | channel.SendMethod(&amqp.ChannelOpenOk{})
32 | channel.setStateOpen()
33 | return nil
34 | }
35 |
36 | func (channel *Channel) channelFlow(method *amqp.ChannelFlow) *amqp.AMQPError {
37 | channel.changeFlow(method.Active)
38 | channel.SendMethod(&amqp.ChannelFlowOk{channel.flow})
39 | return nil
40 | }
41 |
42 | func (channel *Channel) channelFlowOk(method *amqp.ChannelFlowOk) *amqp.AMQPError {
43 | var classId, methodId = method.MethodIdentifier()
44 | return amqp.NewHardError(540, "Not implemented", classId, methodId)
45 | }
46 |
47 | func (channel *Channel) channelClose(method *amqp.ChannelClose) *amqp.AMQPError {
48 | // TODO(MAY): Report the class and method that are the reason for the close
49 | channel.SendMethod(&amqp.ChannelCloseOk{})
50 | channel.shutdown()
51 | return nil
52 | }
53 |
54 | func (channel *Channel) channelCloseOk(method *amqp.ChannelCloseOk) *amqp.AMQPError {
55 | channel.shutdown()
56 | return nil
57 | }
58 |
--------------------------------------------------------------------------------
/server/connection.go:
--------------------------------------------------------------------------------
1 | package server
2 |
import (
	"bytes"
	"encoding/json"
	"fmt"
	"io"
	"net"
	"sync"
	"time"

	"github.com/jeffjenkins/dispatchd/amqp"
	"github.com/jeffjenkins/dispatchd/stats"
	"github.com/jeffjenkins/dispatchd/util"
)
14 |
// ConnectStatus records which connection lifecycle steps have been seen.
//
// TODO: we can only be "in" one of these at once, so this should probably
// be one field
type ConnectStatus struct {
	start, startOk   bool
	secure, secureOk bool
	tune, tuneOk     bool
	open, openOk     bool
	closing, closed  bool
}
29 |
30 | type AMQPConnection struct {
31 | id int64
32 | nextChannel int
33 | channels map[uint16]*Channel
34 | outgoing chan *amqp.WireFrame
35 | connectStatus ConnectStatus
36 | server *Server
37 | network net.Conn
38 | lock sync.Mutex
39 | ttl time.Time
40 | sendHeartbeatInterval time.Duration
41 | receiveHeartbeatInterval time.Duration
42 | maxChannels uint16
43 | maxFrameSize uint32
44 | clientProperties *amqp.Table
45 | // stats
46 | statOutBlocked stats.Histogram
47 | statOutNetwork stats.Histogram
48 | statInBlocked stats.Histogram
49 | statInNetwork stats.Histogram
50 | }
51 |
52 | func (conn *AMQPConnection) MarshalJSON() ([]byte, error) {
53 | return json.Marshal(map[string]interface{}{
54 | "id": conn.id,
55 | "address": fmt.Sprintf("%s", conn.network.RemoteAddr()),
56 | "clientProperties": conn.clientProperties.Table,
57 | "channelCount": len(conn.channels),
58 | })
59 | }
60 |
// NewAMQPConnection wraps an accepted network connection. openConnection
// then checks the protocol header and starts the handshake on channel 0.
// Defaults: 10s receive heartbeat interval, 4096 channels, 65536-byte frames.
func NewAMQPConnection(server *Server, network net.Conn) *AMQPConnection {
	conn := &AMQPConnection{
		id:       util.NextId(),
		network:  network,
		channels: make(map[uint16]*Channel),
		// The outgoing queue is buffered (100 frames). The author observed
		// better server performance with a buffer but never pinned down why.
		outgoing:                 make(chan *amqp.WireFrame, 100),
		connectStatus:            ConnectStatus{},
		server:                   server,
		receiveHeartbeatInterval: 10 * time.Second,
		maxChannels:              4096,
		maxFrameSize:             65536,
	}
	conn.statOutBlocked = stats.MakeHistogram("Connection.Out.Blocked")
	conn.statOutNetwork = stats.MakeHistogram("Connection.Out.Network")
	conn.statInBlocked = stats.MakeHistogram("Connection.In.Blocked")
	conn.statInNetwork = stats.MakeHistogram("Connection.In.Network")
	return conn
}
81 |
82 | func (conn *AMQPConnection) openConnection() {
83 | // Negotiate Protocol
84 | buf := make([]byte, 8)
85 | _, err := conn.network.Read(buf)
86 | if err != nil {
87 | conn.hardClose()
88 | return
89 | }
90 |
91 | var supported = []byte{'A', 'M', 'Q', 'P', 0, 0, 9, 1}
92 | if bytes.Compare(buf, supported) != 0 {
93 | conn.network.Write(supported)
94 | conn.hardClose()
95 | return
96 | }
97 |
98 | // Create channel 0 and start the connection handshake
99 | conn.channels[0] = NewChannel(0, conn)
100 | conn.channels[0].start()
101 | conn.handleOutgoing()
102 | conn.handleIncoming()
103 | }
104 |
105 | func (conn *AMQPConnection) cleanUp() {
106 |
107 | }
108 |
109 | func (conn *AMQPConnection) deregisterChannel(id uint16) {
110 | delete(conn.channels, id)
111 | }
112 |
113 | func (conn *AMQPConnection) hardClose() {
114 | conn.network.Close()
115 | conn.connectStatus.closed = true
116 | conn.server.deregisterConnection(conn.id)
117 | conn.server.deleteQueuesForConn(conn.id)
118 | for _, channel := range conn.channels {
119 | channel.shutdown()
120 | }
121 | }
122 |
123 | func (conn *AMQPConnection) setMaxChannels(max uint16) {
124 | conn.maxChannels = max
125 | }
126 |
127 | func (conn *AMQPConnection) setMaxFrameSize(max uint32) {
128 | conn.maxFrameSize = max
129 | }
130 |
// startSendHeartbeat stores the outbound heartbeat interval and starts the
// heartbeat sender.
func (conn *AMQPConnection) startSendHeartbeat(every time.Duration) {
	conn.sendHeartbeatInterval = every
	conn.handleSendHeartbeat()
}
135 |
// handleSendHeartbeat starts a goroutine that queues a heartbeat frame
// (type 8, channel 0, empty payload) every sendHeartbeatInterval/2 until the
// connection is marked closed.
//
// If half the interval is not positive (AMQP uses 0 for "no heartbeat"),
// nothing is started. Otherwise the loop would sleep for zero time and flood
// the outgoing queue with heartbeat frames.
func (conn *AMQPConnection) handleSendHeartbeat() {
	if conn.sendHeartbeatInterval/2 <= 0 {
		return
	}
	go func() {
		for !conn.connectStatus.closed {
			time.Sleep(conn.sendHeartbeatInterval / 2)
			// Re-check after sleeping so no frame is queued for a connection
			// that closed while we slept.
			if conn.connectStatus.closed {
				return
			}
			conn.outgoing <- &amqp.WireFrame{8, 0, make([]byte, 0)}
		}
	}()
}
147 |
// handleClientHeartbeatTimeout starts a goroutine that, every
// receiveHeartbeatInterval/2, hard-closes the connection if ttl has passed.
// It stops once the connection is marked closed.
func (conn *AMQPConnection) handleClientHeartbeatTimeout() {
	// TODO(MUST): The spec is that any octet is a heartbeat substitute. Right
	// now this is only looking at frames, so a long send could cause a timeout
	// TODO(MUST): if the client isn't heartbeating how do we know when it's
	// gone?
	go func() {
		for !conn.connectStatus.closed {
			time.Sleep(conn.receiveHeartbeatInterval / 2)
			// If now is past the TTL we need to time the client out
			if time.Now().After(conn.ttl) {
				conn.hardClose()
			}
		}
	}()
}
166 |
167 | func (conn *AMQPConnection) handleOutgoing() {
168 | // TODO(MUST): Use SetWriteDeadline so we never wait too long. It should be
169 | // higher than the heartbeat in use. It should be reset after the heartbeat
170 | // interval is known.
171 | go func() {
172 | for {
173 | if conn.connectStatus.closed {
174 | break
175 | }
176 | var start = stats.Start()
177 | var frame = <-conn.outgoing
178 | stats.RecordHisto(conn.statOutBlocked, start)
179 |
180 | // fmt.Printf("Sending outgoing message. type: %d\n", frame.FrameType)
181 | // TODO(MUST): Hard close on irrecoverable errors, retry on recoverable
182 | // ones some number of times.
183 | start = stats.Start()
184 | amqp.WriteFrame(conn.network, frame)
185 | stats.RecordHisto(conn.statOutNetwork, start)
186 | // for wire protocol debugging:
187 | // for _, b := range frame.Payload {
188 | // fmt.Printf("%d,", b)
189 | // }
190 | // fmt.Printf("\n")
191 | }
192 | }()
193 | }
194 |
// connectionErrorWithMethod logs amqpErr, marks the connection closing and
// sends connection.close on channel 0 carrying the error's code, text,
// class and method.
func (conn *AMQPConnection) connectionErrorWithMethod(amqpErr *amqp.AMQPError) {
	fmt.Println("Sending connection error:", amqpErr.Msg)
	conn.connectStatus.closing = true
	closeMethod := &amqp.ConnectionClose{
		ReplyCode: amqpErr.Code,
		ReplyText: amqpErr.Msg,
		ClassId:   amqpErr.Class,
		MethodId:  amqpErr.Method,
	}
	conn.channels[0].SendMethod(closeMethod)
}
205 |
206 | func (conn *AMQPConnection) handleIncoming() {
207 | for {
208 | // If the connection is done, we stop handling frames
209 | if conn.connectStatus.closed {
210 | break
211 | }
212 | // Read from the network
213 | // TODO(MUST): Add a timeout to the read, esp. if there is no heartbeat
214 | // TODO(MUST): Hard close on unrecoverable errors, retry (with backoff?)
215 | // for recoverable ones
216 | var start = stats.Start()
217 | frame, err := amqp.ReadFrame(conn.network)
218 | if err != nil {
219 | fmt.Println("Error reading frame: " + err.Error())
220 | conn.hardClose()
221 | break
222 | }
223 | stats.RecordHisto(conn.statInNetwork, start)
224 | conn.handleFrame(frame)
225 | }
226 | }
227 |
228 | func (conn *AMQPConnection) handleFrame(frame *amqp.WireFrame) {
229 |
230 | // Upkeep. Remove things which have expired, etc
231 | conn.cleanUp()
232 | conn.ttl = time.Now().Add(conn.receiveHeartbeatInterval * 2)
233 |
234 | switch {
235 | case frame.FrameType == 8:
236 | // TODO(MUST): Update last heartbeat time
237 | return
238 | }
239 |
240 | if !conn.connectStatus.open && frame.Channel != 0 {
241 | fmt.Println("Non-0 channel for unopened connection")
242 | conn.hardClose()
243 | return
244 | }
245 | var channel, ok = conn.channels[frame.Channel]
246 | // TODO(MUST): Check that the channel number if in the valid range
247 | if !ok {
248 | channel = NewChannel(frame.Channel, conn)
249 | conn.channels[frame.Channel] = channel
250 | conn.channels[frame.Channel].start()
251 | }
252 | // Dispatch
253 | start := stats.Start()
254 | channel.incoming <- frame
255 | stats.RecordHisto(conn.statInBlocked, start)
256 | }
257 |
--------------------------------------------------------------------------------
/server/connectionMethods.go:
--------------------------------------------------------------------------------
1 | package server
2 |
3 | import (
4 | "github.com/jeffjenkins/dispatchd/amqp"
5 | "os"
6 | "runtime"
7 | "time"
8 | )
9 |
10 | func (channel *Channel) connectionRoute(conn *AMQPConnection, methodFrame amqp.MethodFrame) *amqp.AMQPError {
11 | switch method := methodFrame.(type) {
12 | case *amqp.ConnectionStartOk:
13 | return channel.connectionStartOk(conn, method)
14 | case *amqp.ConnectionTuneOk:
15 | return channel.connectionTuneOk(conn, method)
16 | case *amqp.ConnectionOpen:
17 | return channel.connectionOpen(conn, method)
18 | case *amqp.ConnectionClose:
19 | return channel.connectionClose(conn, method)
20 | case *amqp.ConnectionSecureOk:
21 | return channel.connectionSecureOk(conn, method)
22 | case *amqp.ConnectionCloseOk:
23 | return channel.connectionCloseOk(conn, method)
24 | case *amqp.ConnectionBlocked:
25 | return channel.connectionBlocked(conn, method)
26 | case *amqp.ConnectionUnblocked:
27 | return channel.connectionUnblocked(conn, method)
28 | }
29 | var classId, methodId = methodFrame.MethodIdentifier()
30 | return amqp.NewHardError(540, "Unable to route method frame", classId, methodId)
31 | }
32 |
33 | func (channel *Channel) connectionOpen(conn *AMQPConnection, method *amqp.ConnectionOpen) *amqp.AMQPError {
34 | // TODO(MAY): Add support for virtual hosts. Check for access to the
35 | // selected one
36 | conn.connectStatus.open = true
37 | channel.SendMethod(&amqp.ConnectionOpenOk{""})
38 | conn.connectStatus.openOk = true
39 | return nil
40 | }
41 |
42 | func (channel *Channel) connectionTuneOk(conn *AMQPConnection, method *amqp.ConnectionTuneOk) *amqp.AMQPError {
43 | conn.connectStatus.tuneOk = true
44 | if method.ChannelMax > conn.maxChannels || method.FrameMax > conn.maxFrameSize {
45 | conn.hardClose()
46 | return nil
47 | }
48 |
49 | conn.setMaxChannels(method.ChannelMax)
50 | conn.setMaxFrameSize(method.FrameMax)
51 |
52 | if method.Heartbeat > 0 {
53 | // Start sending heartbeats to the client
54 | conn.startSendHeartbeat(time.Duration(method.Heartbeat) * time.Second)
55 | }
56 | // Start listening for heartbeats from the client.
57 | // We always ask for them since we want to shut down
58 | // connections not in use
59 | conn.handleClientHeartbeatTimeout()
60 | return nil
61 | }
62 |
63 | func (channel *Channel) connectionStartOk(conn *AMQPConnection, method *amqp.ConnectionStartOk) *amqp.AMQPError {
64 | // TODO(SHOULD): record product/version/platform/copyright/information
65 | // TODO(MUST): assert mechanism, response, locale are not null
66 | conn.connectStatus.startOk = true
67 |
68 | if method.Mechanism != "PLAIN" {
69 | conn.hardClose()
70 | }
71 |
72 | if !conn.server.authenticate(method.Mechanism, method.Response) {
73 | var classId, methodId = method.MethodIdentifier()
74 | return &amqp.AMQPError{
75 | Code: 530,
76 | Class: classId,
77 | Method: methodId,
78 | Msg: "Authorization failed",
79 | Soft: false,
80 | }
81 | }
82 |
83 | conn.clientProperties = method.ClientProperties
84 | // TODO(MUST): add support these being enforced at the connection level.
85 | channel.SendMethod(&amqp.ConnectionTune{
86 | conn.maxChannels,
87 | conn.maxFrameSize,
88 | uint16(conn.receiveHeartbeatInterval.Nanoseconds() / int64(time.Second)),
89 | })
90 | // TODO: Implement secure/secure-ok later if needed
91 | conn.connectStatus.secure = true
92 | conn.connectStatus.secureOk = true
93 | conn.connectStatus.tune = true
94 | return nil
95 | }
96 |
97 | func (channel *Channel) startConnection() *amqp.AMQPError {
98 | // TODO(SHOULD): add fields: host, product, version, platform, copyright, information
99 | var capabilities = amqp.NewTable()
100 | capabilities.SetKey("publisher_confirms", false)
101 | capabilities.SetKey("basic.nack", true)
102 | var serverProps = amqp.NewTable()
103 | // TODO: the java rabbitmq client I'm using for load testing doesn't like these string
104 | // fields even though the go/python clients do. If they are set as longstr (bytes)
105 | // instead they work, so I'm doing that for now
106 | serverProps.SetKey("product", []byte("dispatchd"))
107 | serverProps.SetKey("version", []byte("0.1"))
108 | serverProps.SetKey("copyright", []byte("Jeffrey Jenkins, 2015"))
109 | serverProps.SetKey("capabilities", capabilities)
110 | serverProps.SetKey("platform", []byte(runtime.GOARCH))
111 | host, err := os.Hostname()
112 | if err != nil {
113 | serverProps.SetKey("host", []byte("UnknownHostError"))
114 | } else {
115 | serverProps.SetKey("host", []byte(host))
116 | }
117 |
118 | serverProps.SetKey("information", []byte("http://dispatchd.org"))
119 |
120 | channel.SendMethod(&amqp.ConnectionStart{0, 9, serverProps, []byte("PLAIN"), []byte("en_US")})
121 | return nil
122 | }
123 |
124 | func (channel *Channel) connectionClose(conn *AMQPConnection, method *amqp.ConnectionClose) *amqp.AMQPError {
125 | channel.SendMethod(&amqp.ConnectionCloseOk{})
126 | conn.hardClose()
127 | return nil
128 | }
129 |
130 | func (channel *Channel) connectionCloseOk(conn *AMQPConnection, method *amqp.ConnectionCloseOk) *amqp.AMQPError {
131 | conn.hardClose()
132 | return nil
133 | }
134 |
135 | func (channel *Channel) connectionSecureOk(conn *AMQPConnection, method *amqp.ConnectionSecureOk) *amqp.AMQPError {
136 | // TODO(MAY): If other security mechanisms are in place, handle this
137 | conn.hardClose()
138 | return nil
139 | }
140 |
141 | func (channel *Channel) connectionBlocked(conn *AMQPConnection, method *amqp.ConnectionBlocked) *amqp.AMQPError {
142 | return amqp.NewHardError(540, "Not implemented", 10, 60)
143 | }
144 |
145 | func (channel *Channel) connectionUnblocked(conn *AMQPConnection, method *amqp.ConnectionUnblocked) *amqp.AMQPError {
146 | return amqp.NewHardError(540, "Not implemented", 10, 61)
147 | }
148 |
--------------------------------------------------------------------------------
/server/exchangeMethods.go:
--------------------------------------------------------------------------------
1 | package server
2 |
3 | import (
4 | _ "fmt"
5 | "github.com/jeffjenkins/dispatchd/amqp"
6 | "github.com/jeffjenkins/dispatchd/exchange"
7 | "strings"
8 | )
9 |
10 | func (channel *Channel) exchangeRoute(methodFrame amqp.MethodFrame) *amqp.AMQPError {
11 | switch method := methodFrame.(type) {
12 | case *amqp.ExchangeDeclare:
13 | return channel.exchangeDeclare(method)
14 | case *amqp.ExchangeBind:
15 | return channel.exchangeBind(method)
16 | case *amqp.ExchangeUnbind:
17 | return channel.exchangeUnbind(method)
18 | case *amqp.ExchangeDelete:
19 | return channel.exchangeDelete(method)
20 | }
21 | var classId, methodId = methodFrame.MethodIdentifier()
22 | return amqp.NewHardError(540, "Not implemented", classId, methodId)
23 | }
24 |
25 | func (channel *Channel) exchangeDeclare(method *amqp.ExchangeDeclare) *amqp.AMQPError {
26 | var classId, methodId = method.MethodIdentifier()
27 | // The client I'm using for testing thought declaring the empty exchange
28 | // was OK. Check later
29 | // if len(method.Exchange) > 0 && !method.Passive {
30 | // var msg = "The empty exchange name is reserved"
31 | // channel.channelErrorWithMethod(406, msg, classId, methodId)
32 | // return nil
33 | // }
34 |
35 | // Check the name format
36 | var err = amqp.CheckExchangeOrQueueName(method.Exchange)
37 | if err != nil {
38 | return amqp.NewSoftError(406, err.Error(), classId, methodId)
39 | }
40 |
41 | // Declare!
42 | var ex, amqpErr = exchange.NewFromMethod(method, false, channel.server.exchangeDeleter)
43 | if amqpErr != nil {
44 | return amqpErr
45 | }
46 | tp, err := exchange.ExchangeNameToType(method.Type)
47 | if err != nil || tp == exchange.EX_TYPE_HEADERS {
48 | return amqp.NewHardError(503, err.Error(), classId, methodId)
49 | }
50 | existing, hasKey := channel.server.exchanges[ex.Name]
51 | if !hasKey && method.Passive {
52 | return amqp.NewSoftError(404, "Exchange does not exist", classId, methodId)
53 | }
54 | if hasKey {
55 | // if diskLoad {
56 | // panic(fmt.Sprintf("Can't disk load a key that exists: %s", ex.Name))
57 | // }
58 | if existing.ExType != ex.ExType {
59 | return amqp.NewHardError(530, "Cannot redeclare an exchange with a different type", classId, methodId)
60 | }
61 | if existing.EquivalentExchanges(ex) {
62 | if !method.NoWait {
63 | channel.SendMethod(&amqp.ExchangeDeclareOk{})
64 | }
65 | return nil
66 | }
67 | // Not equivalent, error in passive mode
68 | if method.Passive {
69 | return amqp.NewSoftError(406, "Exchange with this name already exists", classId, methodId)
70 | }
71 | }
72 | if method.Passive {
73 | if !method.NoWait {
74 | channel.SendMethod(&amqp.ExchangeDeclareOk{})
75 | }
76 | return nil
77 | }
78 |
79 | // outside of passive mode you can't create an exchange starting with
80 | // amq.
81 | if strings.HasPrefix(method.Exchange, "amq.") {
82 | return amqp.NewSoftError(403, "Exchange names starting with 'amq.' are reserved", classId, methodId)
83 | }
84 |
85 | err = channel.server.addExchange(ex)
86 | if err != nil {
87 | return amqp.NewSoftError(500, err.Error(), classId, methodId)
88 | }
89 | err = ex.Persist(channel.server.db)
90 | if err != nil {
91 | return amqp.NewSoftError(500, err.Error(), classId, methodId)
92 | }
93 | if !method.NoWait {
94 | channel.SendMethod(&amqp.ExchangeDeclareOk{})
95 | }
96 | return nil
97 | }
98 |
99 | func (channel *Channel) exchangeDelete(method *amqp.ExchangeDelete) *amqp.AMQPError {
100 | var classId, methodId = method.MethodIdentifier()
101 | var errCode, err = channel.server.deleteExchange(method)
102 | if err != nil {
103 | return amqp.NewSoftError(errCode, err.Error(), classId, methodId)
104 | }
105 | if !method.NoWait {
106 | channel.SendMethod(&amqp.ExchangeDeleteOk{})
107 | }
108 | return nil
109 | }
110 |
111 | func (channel *Channel) exchangeBind(method *amqp.ExchangeBind) *amqp.AMQPError {
112 | var classId, methodId = method.MethodIdentifier()
113 | return amqp.NewHardError(540, "Not implemented", classId, methodId)
114 | // if !method.NoWait {
115 | // channel.SendMethod(&amqp.ExchangeBindOk{})
116 | // }
117 | }
118 | func (channel *Channel) exchangeUnbind(method *amqp.ExchangeUnbind) *amqp.AMQPError {
119 | var classId, methodId = method.MethodIdentifier()
120 | return amqp.NewHardError(540, "Not implemented", classId, methodId)
121 | // if !method.NoWait {
122 | // channel.SendMethod(&amqp.ExchangeUnbindOk{})
123 | // }
124 | }
125 |
--------------------------------------------------------------------------------
/server/queueMethods.go:
--------------------------------------------------------------------------------
1 | package server
2 |
3 | import (
4 | "fmt"
5 | "github.com/jeffjenkins/dispatchd/amqp"
6 | "github.com/jeffjenkins/dispatchd/binding"
7 | "github.com/jeffjenkins/dispatchd/queue"
8 | "github.com/jeffjenkins/dispatchd/util"
9 | )
10 |
11 | func (channel *Channel) queueRoute(methodFrame amqp.MethodFrame) *amqp.AMQPError {
12 | switch method := methodFrame.(type) {
13 | case *amqp.QueueDeclare:
14 | return channel.queueDeclare(method)
15 | case *amqp.QueueBind:
16 | return channel.queueBind(method)
17 | case *amqp.QueuePurge:
18 | return channel.queuePurge(method)
19 | case *amqp.QueueDelete:
20 | return channel.queueDelete(method)
21 | case *amqp.QueueUnbind:
22 | return channel.queueUnbind(method)
23 | }
24 | var classId, methodId = methodFrame.MethodIdentifier()
25 | return amqp.NewHardError(540, "Not implemented", classId, methodId)
26 | }
27 |
28 | func (channel *Channel) queueDeclare(method *amqp.QueueDeclare) *amqp.AMQPError {
29 | var classId, methodId = method.MethodIdentifier()
30 | // No name means generate a name
31 | if len(method.Queue) == 0 {
32 | method.Queue = util.RandomId()
33 | }
34 |
35 | // Check the name format
36 | var err = amqp.CheckExchangeOrQueueName(method.Queue)
37 | if err != nil {
38 | return amqp.NewSoftError(406, err.Error(), classId, methodId)
39 | }
40 |
41 | // If this is a passive request, do the appropriate checks and return
42 | if method.Passive {
43 | queue, found := channel.conn.server.queues[method.Queue]
44 | if found {
45 | if !method.NoWait {
46 | var qsize = uint32(queue.Len())
47 | var csize = queue.ActiveConsumerCount()
48 | channel.SendMethod(&amqp.QueueDeclareOk{method.Queue, qsize, csize})
49 | }
50 | channel.lastQueueName = method.Queue
51 | return nil
52 | }
53 | return amqp.NewSoftError(404, "Queue not found", classId, methodId)
54 | }
55 |
56 | // Create the new queue
57 | var connId = channel.conn.id
58 | if !method.Exclusive {
59 | connId = -1
60 | }
61 | var queue = queue.NewQueue(
62 | method.Queue,
63 | method.Durable,
64 | method.Exclusive,
65 | method.AutoDelete,
66 | method.Arguments,
67 | connId,
68 | channel.server.msgStore,
69 | channel.server.queueDeleter,
70 | )
71 |
72 | // If the new queue exists already, ensure the settings are the same. If it
73 | // doesn't, add it and optionally persist it
74 | existing, hasKey := channel.server.queues[queue.Name]
75 | if hasKey {
76 | if existing.ConnId != -1 && existing.ConnId != channel.conn.id {
77 | return amqp.NewSoftError(405, "Queue is locked to another connection", classId, methodId)
78 | }
79 | if !existing.EquivalentQueues(queue) {
80 | return amqp.NewSoftError(406, "Queue exists and is not equivalent to existing", classId, methodId)
81 | }
82 | } else {
83 | err = channel.server.addQueue(queue)
84 | if err != nil { // pragma: nocover
85 | return amqp.NewSoftError(500, "Error creating queue", classId, methodId)
86 | }
87 | // Persist
88 | if queue.Durable {
89 | queue.Persist(channel.server.db)
90 | }
91 | }
92 |
93 | channel.lastQueueName = method.Queue
94 | if !method.NoWait {
95 | channel.SendMethod(&amqp.QueueDeclareOk{queue.Name, uint32(0), uint32(0)})
96 | }
97 | return nil
98 | }
99 |
100 | func (channel *Channel) queueBind(method *amqp.QueueBind) *amqp.AMQPError {
101 | var classId, methodId = method.MethodIdentifier()
102 |
103 | if len(method.Queue) == 0 {
104 | if len(channel.lastQueueName) == 0 {
105 | return amqp.NewSoftError(404, "Queue not found", classId, methodId)
106 | } else {
107 | method.Queue = channel.lastQueueName
108 | }
109 | }
110 |
111 | // Check exchange
112 | var exchange, foundExchange = channel.server.exchanges[method.Exchange]
113 | if !foundExchange {
114 | return amqp.NewSoftError(404, "Exchange not found", classId, methodId)
115 | }
116 |
117 | // Check queue
118 | var queue, foundQueue = channel.server.queues[method.Queue]
119 | if !foundQueue || queue.Closed {
120 | return amqp.NewSoftError(404, fmt.Sprintf("Queue not found: %s", method.Queue), classId, methodId)
121 | }
122 |
123 | if queue.ConnId != -1 && queue.ConnId != channel.conn.id {
124 | return amqp.NewSoftError(405, fmt.Sprintf("Queue is locked to another connection"), classId, methodId)
125 | }
126 |
127 | // Create binding
128 | b, err := binding.NewBinding(method.Queue, method.Exchange, method.RoutingKey, method.Arguments, exchange.IsTopic())
129 | if err != nil {
130 | return amqp.NewSoftError(500, err.Error(), classId, methodId)
131 | }
132 |
133 | // Add binding
134 | err = exchange.AddBinding(b, channel.conn.id)
135 | if err != nil {
136 | return amqp.NewSoftError(500, err.Error(), classId, methodId)
137 | }
138 |
139 | // Persist durable bindings
140 | if exchange.Durable && queue.Durable {
141 | var err = b.Persist(channel.server.db)
142 | if err != nil {
143 | return amqp.NewSoftError(500, err.Error(), classId, methodId)
144 | }
145 | }
146 |
147 | if !method.NoWait {
148 | channel.SendMethod(&amqp.QueueBindOk{})
149 | }
150 | return nil
151 | }
152 |
153 | func (channel *Channel) queuePurge(method *amqp.QueuePurge) *amqp.AMQPError {
154 | fmt.Println("Got queuePurge")
155 | var classId, methodId = method.MethodIdentifier()
156 |
157 | // Check queue
158 | if len(method.Queue) == 0 {
159 | if len(channel.lastQueueName) == 0 {
160 | return amqp.NewSoftError(404, "Queue not found", classId, methodId)
161 | } else {
162 | method.Queue = channel.lastQueueName
163 | }
164 | }
165 |
166 | var queue, foundQueue = channel.server.queues[method.Queue]
167 | if !foundQueue {
168 | return amqp.NewSoftError(404, "Queue not found", classId, methodId)
169 | }
170 |
171 | if queue.ConnId != -1 && queue.ConnId != channel.conn.id {
172 | return amqp.NewSoftError(405, "Queue is locked to another connection", classId, methodId)
173 | }
174 |
175 | numPurged := queue.Purge()
176 | if !method.NoWait {
177 | channel.SendMethod(&amqp.QueuePurgeOk{numPurged})
178 | }
179 | return nil
180 | }
181 |
182 | func (channel *Channel) queueDelete(method *amqp.QueueDelete) *amqp.AMQPError {
183 | fmt.Println("Got queueDelete")
184 | var classId, methodId = method.MethodIdentifier()
185 |
186 | // Check queue
187 | if len(method.Queue) == 0 {
188 | if len(channel.lastQueueName) == 0 {
189 | return amqp.NewSoftError(404, "Queue not found", classId, methodId)
190 | } else {
191 | method.Queue = channel.lastQueueName
192 | }
193 | }
194 |
195 | numPurged, errCode, err := channel.server.deleteQueue(method, channel.conn.id)
196 | if err != nil {
197 | return amqp.NewSoftError(errCode, err.Error(), classId, methodId)
198 | }
199 |
200 | if !method.NoWait {
201 | channel.SendMethod(&amqp.QueueDeleteOk{numPurged})
202 | }
203 | return nil
204 | }
205 |
206 | func (channel *Channel) queueUnbind(method *amqp.QueueUnbind) *amqp.AMQPError {
207 | var classId, methodId = method.MethodIdentifier()
208 |
209 | // Check queue
210 | if len(method.Queue) == 0 {
211 | if len(channel.lastQueueName) == 0 {
212 | return amqp.NewSoftError(404, "Queue not found", classId, methodId)
213 | } else {
214 | method.Queue = channel.lastQueueName
215 | }
216 | }
217 |
218 | var queue, foundQueue = channel.server.queues[method.Queue]
219 | if !foundQueue {
220 | return amqp.NewSoftError(404, "Queue not found", classId, methodId)
221 | }
222 |
223 | if queue.ConnId != -1 && queue.ConnId != channel.conn.id {
224 | return amqp.NewSoftError(405, "Queue is locked to another connection", classId, methodId)
225 | }
226 |
227 | // Check exchange
228 | var exchange, foundExchange = channel.server.exchanges[method.Exchange]
229 | if !foundExchange {
230 | return amqp.NewSoftError(404, "Exchange not found", classId, methodId)
231 | }
232 |
233 | var binding, err = binding.NewBinding(
234 | method.Queue,
235 | method.Exchange,
236 | method.RoutingKey,
237 | method.Arguments,
238 | exchange.IsTopic(),
239 | )
240 |
241 | if err != nil {
242 | return amqp.NewSoftError(500, err.Error(), classId, methodId)
243 | }
244 |
245 | if queue.Durable && exchange.Durable {
246 | err := binding.Depersist(channel.server.db)
247 | if err != nil {
248 | return amqp.NewSoftError(500, "Could not de-persist binding!", classId, methodId)
249 | }
250 | }
251 |
252 | if err := exchange.RemoveBinding(binding); err != nil {
253 | return amqp.NewSoftError(500, err.Error(), classId, methodId)
254 | }
255 | channel.SendMethod(&amqp.QueueUnbindOk{})
256 | return nil
257 | }
258 |
--------------------------------------------------------------------------------
/server/server.go:
--------------------------------------------------------------------------------
1 | package server
2 |
3 | import (
4 | "encoding/json"
5 | "errors"
6 | "fmt"
7 | "github.com/boltdb/bolt"
8 | "github.com/jeffjenkins/dispatchd/amqp"
9 | "github.com/jeffjenkins/dispatchd/binding"
10 | "github.com/jeffjenkins/dispatchd/exchange"
11 | "github.com/jeffjenkins/dispatchd/msgstore"
12 | "github.com/jeffjenkins/dispatchd/queue"
13 | "net"
14 | "sync"
15 | )
16 |
17 | type Server struct {
18 | exchanges map[string]*exchange.Exchange
19 | queues map[string]*queue.Queue
20 | bindings []*binding.Binding
21 | idLock sync.Mutex
22 | conns map[int64]*AMQPConnection
23 | db *bolt.DB
24 | serverLock sync.Mutex
25 | msgStore *msgstore.MessageStore
26 | exchangeDeleter chan *exchange.Exchange
27 | queueDeleter chan *queue.Queue
28 | users map[string]User
29 | strictMode bool
30 | }
31 |
32 | func (server *Server) MarshalJSON() ([]byte, error) {
33 | conns := make(map[string]*AMQPConnection)
34 | for id, value := range server.conns {
35 | conns[fmt.Sprintf("%d", id)] = value
36 | }
37 | return json.Marshal(map[string]interface{}{
38 | "exchanges": server.exchanges,
39 | "queues": server.queues,
40 | "connections": conns,
41 | "msgCount": server.msgStore.MessageCount(),
42 | "msgIndexCount": server.msgStore.IndexCount(),
43 | })
44 | }
45 |
46 | func NewServer(dbPath string, msgStorePath string, userJson map[string]interface{}, strictMode bool) *Server {
47 | db, err := bolt.Open(dbPath, 0600, nil)
48 | if err != nil {
49 | panic(err.Error())
50 |
51 | }
52 | msgStore, err := msgstore.NewMessageStore(msgStorePath)
53 | msgStore.Start()
54 | if err != nil {
55 | panic("Could not create message store!")
56 | }
57 |
58 | var server = &Server{
59 | exchanges: make(map[string]*exchange.Exchange),
60 | queues: make(map[string]*queue.Queue),
61 | bindings: make([]*binding.Binding, 0),
62 | conns: make(map[int64]*AMQPConnection),
63 | db: db,
64 | msgStore: msgStore,
65 | exchangeDeleter: make(chan *exchange.Exchange),
66 | queueDeleter: make(chan *queue.Queue),
67 | users: make(map[string]User),
68 | strictMode: strictMode,
69 | }
70 |
71 | server.init()
72 | server.addUsers(userJson)
73 | return server
74 | }
75 |
76 | func (server *Server) init() {
77 | server.msgStore.LoadMessages() //this must be before initQueues
78 | server.initExchanges()
79 | server.initQueues()
80 | server.initBindings() // this must be after init{Exchanges,Queues}
81 | go server.exchangeDeleteMonitor()
82 | go server.queueDeleteMonitor()
83 | }
84 |
85 | func (server *Server) exchangeDeleteMonitor() {
86 | for e := range server.exchangeDeleter {
87 | var dele = &amqp.ExchangeDelete{
88 | Exchange: e.Name,
89 | NoWait: true,
90 | }
91 | server.deleteExchange(dele)
92 | }
93 | }
94 |
95 | func (server *Server) queueDeleteMonitor() {
96 | for q := range server.queueDeleter {
97 | var delq = &amqp.QueueDelete{
98 | Queue: q.Name,
99 | NoWait: true,
100 | }
101 | server.deleteQueue(delq, -1)
102 | }
103 | }
104 |
105 | func (server *Server) initBindings() {
106 | // Load bindings
107 | bindings, err := binding.LoadAllBindings(server.db)
108 | if err != nil {
109 | panic("Couldn't load bindings!")
110 | }
111 | for _, b := range bindings {
112 | // Get Exchange
113 | var exchange, foundExchange = server.exchanges[b.ExchangeName]
114 | if !foundExchange {
115 | panic("Couldn't bind non-existant exchange " + b.ExchangeName)
116 | }
117 | // Add Binding
118 | err = exchange.AddBinding(b, -1)
119 | if err != nil {
120 | panic(err.Error())
121 | }
122 | }
123 | }
124 |
125 | func (server *Server) initQueues() {
126 | // Load queues
127 | queues, err := queue.LoadAllQueues(server.db, server.msgStore, server.queueDeleter)
128 | if err != nil {
129 | panic("Couldn't load queues!")
130 | }
131 | for _, queue := range queues {
132 | err = server.addQueue(queue)
133 | if err != nil {
134 | panic("Couldn't load queues!")
135 | }
136 | }
137 | // Load queue data
138 | for _, queue := range server.queues {
139 | queue.LoadFromMsgStore(server.msgStore)
140 | }
141 | }
142 |
143 | func (server *Server) initExchanges() {
144 | // LOAD FROM PERSISTENT STORAGE
145 | exchanges, err := exchange.LoadAllExchanges(server.db, server.exchangeDeleter)
146 | if err != nil {
147 | panic("Couldn't load exchanges!")
148 | }
149 | for _, ex := range exchanges {
150 | err = server.addExchange(ex)
151 | if err != nil {
152 | panic("Couldn't load queues!")
153 | }
154 | }
155 | if err != nil {
156 | panic("FAILED TO LOAD EXCHANGES: " + err.Error())
157 | }
158 |
159 | // DECLARE MISSING SYSEM EXCHANGES
160 | server.genDefaultExchange("", exchange.EX_TYPE_DIRECT)
161 | server.genDefaultExchange("amq.direct", exchange.EX_TYPE_DIRECT)
162 | server.genDefaultExchange("amq.fanout", exchange.EX_TYPE_FANOUT)
163 | server.genDefaultExchange("amq.topic", exchange.EX_TYPE_TOPIC)
164 | }
165 |
166 | func (server *Server) genDefaultExchange(name string, typ uint8) {
167 | _, hasKey := server.exchanges[name]
168 | if !hasKey {
169 | var ex = exchange.NewExchange(
170 | name,
171 | exchange.EX_TYPE_TOPIC,
172 | true,
173 | false,
174 | false,
175 | amqp.NewTable(),
176 | true,
177 | server.exchangeDeleter,
178 | )
179 | // Persist
180 | ex.Persist(server.db)
181 | err := server.addExchange(ex)
182 | if err != nil {
183 | panic(err.Error())
184 | }
185 | }
186 | }
187 |
188 | func (server *Server) addExchange(ex *exchange.Exchange) error {
189 | server.serverLock.Lock()
190 | defer server.serverLock.Unlock()
191 | server.exchanges[ex.Name] = ex
192 | return nil
193 | }
194 |
195 | func (server *Server) addQueue(q *queue.Queue) error {
196 | server.serverLock.Lock()
197 | defer server.serverLock.Unlock()
198 | server.queues[q.Name] = q
199 | var defaultExchange = server.exchanges[""]
200 | var defaultBinding, err = binding.NewBinding(q.Name, "", q.Name, amqp.NewTable(), false)
201 | if err != nil {
202 | return err
203 | }
204 | defaultExchange.AddBinding(defaultBinding, q.ConnId)
205 | q.Start()
206 | return nil
207 | }
208 |
209 | func (server *Server) deleteQueuesForConn(connId int64) {
210 | server.serverLock.Lock()
211 | var queues = make([]*queue.Queue, 0)
212 | for _, queue := range server.queues {
213 | if queue.ConnId == connId {
214 | queues = append(queues, queue)
215 | }
216 | }
217 | server.serverLock.Unlock()
218 | for _, queue := range queues {
219 | var method = &amqp.QueueDelete{
220 | Queue: queue.Name,
221 | }
222 | server.deleteQueue(method, connId)
223 | }
224 | }
225 |
226 | func (server *Server) deleteQueue(method *amqp.QueueDelete, connId int64) (uint32, uint16, error) {
227 | server.serverLock.Lock()
228 | defer server.serverLock.Unlock()
229 | // Validate
230 | var queue, foundQueue = server.queues[method.Queue]
231 | if !foundQueue {
232 | return 0, 404, errors.New("Queue not found")
233 | }
234 |
235 | if queue.ConnId != -1 && queue.ConnId != connId {
236 | return 0, 405, fmt.Errorf("Queue is locked to another connection")
237 | }
238 |
239 | // Close to stop anything from changing
240 | queue.Close()
241 | // Delete for storage
242 | bindings := server.bindingsForQueue(queue.Name)
243 | server.removeBindingsForQueue(method.Queue)
244 | server.depersistQueue(queue, bindings)
245 |
246 | // Cleanup
247 | numPurged, err := queue.Delete(method.IfUnused, method.IfEmpty)
248 | delete(server.queues, method.Queue)
249 | if err != nil {
250 | return 0, 406, err
251 | }
252 | return numPurged, 0, nil
253 |
254 | }
255 |
256 | func (server *Server) depersistQueue(queue *queue.Queue, bindings []*binding.Binding) error {
257 | return server.db.Update(func(tx *bolt.Tx) error {
258 | for _, binding := range bindings {
259 | if err := binding.DepersistBoltTx(tx); err != nil {
260 | return err
261 | }
262 | }
263 | return queue.DepersistBoltTx(tx)
264 | })
265 | }
266 |
267 | func (server *Server) bindingsForQueue(queueName string) []*binding.Binding {
268 | ret := make([]*binding.Binding, 0)
269 | for _, exchange := range server.exchanges {
270 | ret = append(ret, exchange.BindingsForQueue(queueName)...)
271 | }
272 | return ret
273 | }
274 |
275 | func (server *Server) removeBindingsForQueue(queueName string) {
276 | for _, exchange := range server.exchanges {
277 | exchange.RemoveBindingsForQueue(queueName)
278 | }
279 | }
280 |
281 | func (server *Server) deleteExchange(method *amqp.ExchangeDelete) (uint16, error) {
282 | server.serverLock.Lock()
283 | defer server.serverLock.Unlock()
284 | exchange, found := server.exchanges[method.Exchange]
285 | if !found {
286 | return 404, fmt.Errorf("Exchange not found: '%s'", method.Exchange)
287 | }
288 | if exchange.System {
289 | return 530, fmt.Errorf("Cannot delete system exchange: '%s'", method.Exchange)
290 | }
291 | exchange.Close()
292 | exchange.Depersist(server.db)
293 | // Note: we don't need to delete the bindings from the queues they are
294 | // associated with because they are stored on the exchange.
295 | delete(server.exchanges, method.Exchange)
296 | return 0, nil
297 | }
298 |
299 | func (server *Server) deregisterConnection(id int64) {
300 | delete(server.conns, id)
301 | }
302 |
303 | func (server *Server) OpenConnection(network net.Conn) {
304 | c := NewAMQPConnection(server, network)
305 | server.conns[c.id] = c
306 | c.openConnection()
307 | }
308 |
309 | func (server *Server) returnMessage(msg *amqp.Message, code uint16, text string) *amqp.BasicReturn {
310 | return &amqp.BasicReturn{
311 | Exchange: msg.Method.Exchange,
312 | RoutingKey: msg.Method.RoutingKey,
313 | ReplyCode: code,
314 | ReplyText: text,
315 | }
316 | }
317 |
318 | func (server *Server) publish(exchange *exchange.Exchange, msg *amqp.Message) (*amqp.BasicReturn, *amqp.AMQPError) {
319 | // Concurrency note: Since there is no lock we can, technically, have messages
320 | // published after the exchange has been closed. These couldn't be on the same
321 | // channel as the close is happening on, so that seems justifiable.
322 | if exchange.Closed {
323 | if msg.Method.Mandatory || msg.Method.Immediate {
324 | var rm = server.returnMessage(msg, 313, "Exchange closed, cannot route to queues or consumers")
325 | return rm, nil
326 | }
327 | return nil, nil
328 | }
329 | queues, amqpErr := exchange.QueuesForPublish(msg)
330 | if amqpErr != nil {
331 | return nil, amqpErr
332 | }
333 |
334 | if len(queues) == 0 {
335 | // If we got here the message was unroutable.
336 | if msg.Method.Mandatory || msg.Method.Immediate {
337 | var rm = server.returnMessage(msg, 313, "No queues available")
338 | return rm, nil
339 | }
340 | }
341 |
342 | var queueNames = make([]string, 0, len(queues))
343 | for k, _ := range queues {
344 | queueNames = append(queueNames, k)
345 | }
346 |
347 | // Immediate messages
348 | if msg.Method.Immediate {
349 | var consumed = false
350 | // Add message to message store
351 | queueMessagesByQueue, err := server.msgStore.AddMessage(msg, queueNames)
352 | if err != nil {
353 | return nil, amqp.NewSoftError(500, err.Error(), 60, 40)
354 | }
355 | // Try to immediately consumed it
356 | for queueName, _ := range queues {
357 | qms := queueMessagesByQueue[queueName]
358 | for _, qm := range qms {
359 | queue, found := server.queues[queueName]
360 | if !found {
361 | // The queue must have been deleted since the queuesForPublish call
362 | continue
363 | }
364 | var oneConsumed = queue.ConsumeImmediate(qm)
365 | var rhs = make([]amqp.MessageResourceHolder, 0)
366 | if !oneConsumed {
367 | server.msgStore.RemoveRef(qm, queueName, rhs)
368 | }
369 | consumed = oneConsumed || consumed
370 | }
371 | }
372 | if !consumed {
373 | var rm = server.returnMessage(msg, 313, "No consumers available for immediate message")
374 | return rm, nil
375 | }
376 | return nil, nil
377 | }
378 |
379 | // Add the message to the message store along with the queues we're about to add it to
380 | queueMessagesByQueue, err := server.msgStore.AddMessage(msg, queueNames)
381 | if err != nil {
382 | return nil, amqp.NewSoftError(500, err.Error(), 60, 40)
383 | }
384 |
385 | for queueName, _ := range queues {
386 | qms := queueMessagesByQueue[queueName]
387 | for _, qm := range qms {
388 | queue, found := server.queues[queueName]
389 | if !found || !queue.Add(qm) {
390 | // If we couldn't add it means the queue is closed and we should
391 | // remove the ref from the message store. The queue being closed means
392 | // it is going away, so worst case if the server dies we have to process
393 | // and discard the message on boot.
394 | var rhs = make([]amqp.MessageResourceHolder, 0)
395 | server.msgStore.RemoveRef(qm, queueName, rhs)
396 | }
397 | }
398 | }
399 | return nil, nil
400 | }
401 |
--------------------------------------------------------------------------------
/server/server_consumer_test.go:
--------------------------------------------------------------------------------
1 | package server
2 |
3 | import (
4 | "github.com/jeffjenkins/dispatchd/util"
5 | "testing"
6 | )
7 |
8 | func TestAckNackOne(t *testing.T) {
9 | //
10 | // Setup
11 | //
12 | tc := newTestClient(t)
13 | defer tc.cleanup()
14 | conn := tc.connect()
15 | ch, _, _ := channelHelper(tc, conn)
16 |
17 | ch.QueueDeclare("q1", false, false, false, false, NO_ARGS)
18 | ch.QueueBind("q1", "abc", "amq.direct", false, NO_ARGS)
19 |
20 | //
21 | // Publish and consume one message. Check that channel.awaitingAcks is updated
22 | // properly before and after acking the message
23 | //
24 | deliveries, err := ch.Consume("q1", "TestAckNackOne-1", false, false, false, false, NO_ARGS)
25 | if err != nil {
26 | t.Fatalf("Failed to consume")
27 | }
28 |
29 | ch.Publish("amq.direct", "abc", false, false, TEST_TRANSIENT_MSG)
30 | msg := <-deliveries
31 |
32 | if len(tc.connFromServer().channels[1].awaitingAcks) != 1 {
33 | t.Fatalf("No awaiting ack for message just received")
34 | }
35 |
36 | msg.Ack(false)
37 | tc.wait(ch)
38 | if len(tc.connFromServer().channels[1].awaitingAcks) != 0 {
39 | t.Fatalf("No awaiting ack for message just received")
40 | }
41 |
42 | //
43 | // Publish and consume two messages. Check that awaiting acks is updated
44 | // and that nacking with and without the requeue options works
45 | //
46 | ch.Publish("amq.direct", "abc", false, false, TEST_TRANSIENT_MSG)
47 | ch.Publish("amq.direct", "abc", false, false, TEST_TRANSIENT_MSG)
48 | msg1 := <-deliveries
49 | msg2 := <-deliveries
50 |
51 | tc.wait(ch)
52 | if len(tc.connFromServer().channels[1].awaitingAcks) != 2 {
53 | t.Fatalf("Should have 2 messages awaiting acks")
54 | }
55 | // Stop consuming so we can check requeue
56 | ch.Cancel("TestAckNackOne-1", false)
57 | msg1.Nack(false, false)
58 | msg2.Nack(false, true)
59 | tc.wait(ch)
60 | if tc.s.queues["q1"].Len() != 1 {
61 | t.Fatalf("Should have 1 message in queue")
62 | }
63 |
64 | deliveries, err = ch.Consume("q1", "TestAckNackOne-1", false, false, false, false, NO_ARGS)
65 | if err != nil {
66 | t.Fatalf("Failed to consume")
67 | }
68 | msg2_again := <-deliveries
69 | if !msg2_again.Redelivered {
70 | t.Fatalf("Redelivered message wasn't flagged")
71 | }
72 | }
73 |
74 | func TestAckNackMany(t *testing.T) {
75 | //
76 | // Setup
77 | //
78 | tc := newTestClient(t)
79 | defer tc.cleanup()
80 | conn := tc.connect()
81 | ch, _, _ := channelHelper(tc, conn)
82 |
83 | ch.QueueDeclare("q1", false, false, false, false, NO_ARGS)
84 | ch.QueueBind("q1", "abc", "amq.direct", false, NO_ARGS)
85 | var consumerId = util.RandomId()
86 |
87 | // Ack Many
88 | // Publish and consume two messages. Acck-multiple the second one and
89 | // check that both are acked
90 | //
91 | deliveries, err := ch.Consume("q1", consumerId, false, false, false, false, NO_ARGS)
92 | if err != nil {
93 | t.Fatalf("Failed to consume")
94 | }
95 |
96 | ch.Publish("amq.direct", "abc", false, false, TEST_TRANSIENT_MSG)
97 | ch.Publish("amq.direct", "abc", false, false, TEST_TRANSIENT_MSG)
98 | _ = <-deliveries
99 | msg2 := <-deliveries
100 |
101 | msg2.Ack(true)
102 | tc.wait(ch)
103 | if len(tc.connFromServer().channels[1].awaitingAcks) != 0 {
104 | t.Fatalf("No awaiting ack for message just received")
105 | }
106 |
107 | // Nack Many (no requeue)
108 | // Publish and consume two messages. Nack-multiple the second one and
109 | // check that both are acked and the queue is empty
110 | //
111 | ch.Publish("amq.direct", "abc", false, false, TEST_TRANSIENT_MSG)
112 | ch.Publish("amq.direct", "abc", false, false, TEST_TRANSIENT_MSG)
113 | _ = <-deliveries
114 | msg2 = <-deliveries
115 |
116 | // Stop consuming so we can check requeue
117 | ch.Cancel(consumerId, false)
118 | msg2.Nack(true, false)
119 | tc.wait(ch)
120 | if tc.s.queues["q1"].Len() != 0 {
121 | t.Fatalf("Should have 0 message in queue")
122 | }
123 |
124 | // Nack Many (no requeue)
125 | // Publish and consume two messages. Nack-multiple the second one and
126 | // check that both are acked and the queue is empty
127 | //
128 | consumerId = util.RandomId()
129 | deliveries, err = ch.Consume("q1", consumerId, false, false, false, false, NO_ARGS)
130 | ch.Publish("amq.direct", "abc", false, false, TEST_TRANSIENT_MSG)
131 | ch.Publish("amq.direct", "abc", false, false, TEST_TRANSIENT_MSG)
132 | _ = <-deliveries
133 | msg2 = <-deliveries
134 |
135 | // Stop consuming so we can check requeue
136 | ch.Cancel(consumerId, false)
137 | msg2.Nack(true, true)
138 | tc.wait(ch)
139 | if tc.s.queues["q1"].Len() != 2 {
140 | t.Fatalf("Should have 2 message in queue")
141 | }
142 | }
143 |
144 | func TestRecover(t *testing.T) {
145 | //
146 | // Setup
147 | //
148 | tc := newTestClient(t)
149 | defer tc.cleanup()
150 | conn := tc.connect()
151 | ch, _, _ := channelHelper(tc, conn)
152 | cTag := util.RandomId()
153 | ch.QueueDeclare("q1", false, false, false, false, NO_ARGS)
154 | ch.QueueBind("q1", "abc", "amq.direct", false, NO_ARGS)
155 |
156 | // Recover - no requeue
157 | // Publish two messages, consume them, call recover, consume them again
158 | //
159 | deliveries, err := ch.Consume("q1", cTag, false, false, false, false, NO_ARGS)
160 | if err != nil {
161 | t.Fatalf("Failed to consume")
162 | }
163 |
164 | ch.Publish("amq.direct", "abc", false, false, TEST_TRANSIENT_MSG)
165 | ch.Publish("amq.direct", "abc", false, false, TEST_TRANSIENT_MSG)
166 | tc.wait(ch)
167 | <-deliveries
168 | lastMsg := <-deliveries
169 | ch.Recover(false)
170 | <-deliveries
171 | <-deliveries
172 | lastMsg.Ack(true)
173 |
174 | // Recover - requeue
175 | // Publish two messages, consume them, call recover, check that they are
176 | // requeued
177 | //
178 | ch.Publish("amq.direct", "abc", false, false, TEST_TRANSIENT_MSG)
179 | ch.Publish("amq.direct", "abc", false, false, TEST_TRANSIENT_MSG)
180 | tc.wait(ch)
181 | <-deliveries
182 | <-deliveries
183 | ch.Cancel(cTag, false)
184 | ch.Recover(true)
185 | tc.wait(ch)
186 | msgCount := tc.s.queues["q1"].Len()
187 | if msgCount != 2 {
188 | t.Fatalf("Should have 2 message in queue. Found", msgCount)
189 | }
190 | }
191 |
192 | func TestGet(t *testing.T) {
193 | //
194 | // Setup
195 | //
196 | tc := newTestClient(t)
197 | defer tc.cleanup()
198 | conn := tc.connect()
199 | ch, _, _ := channelHelper(tc, conn)
200 |
201 | ch.QueueDeclare("q1", false, false, false, false, NO_ARGS)
202 | ch.QueueBind("q1", "abc", "amq.direct", false, NO_ARGS)
203 | ch.Publish("amq.direct", "abc", false, false, TEST_TRANSIENT_MSG)
204 | msg, ok, err := ch.Get("q1", false)
205 | if err != nil {
206 | t.Fatalf(err.Error())
207 | }
208 | if !ok {
209 | t.Fatalf("Did not receive message")
210 | }
211 | if string(msg.Body) != string(TEST_TRANSIENT_MSG.Body) {
212 | t.Fatalf("wrong message response in get")
213 | }
214 | }
215 |
--------------------------------------------------------------------------------
/server/server_exchange_test.go:
--------------------------------------------------------------------------------
1 | package server
2 |
3 | import (
4 | "testing"
5 | )
6 |
7 | func TestExchangeMethods(t *testing.T) {
8 | tc := newTestClient(t)
9 | defer tc.cleanup()
10 | conn := tc.connect()
11 | channel, _, _ := channelHelper(tc, conn)
12 |
13 | channel.ExchangeDeclare("ex-1", "topic", false, false, false, false, NO_ARGS)
14 |
15 | // Create exchange
16 | if len(tc.s.exchanges) != 5 {
17 | t.Errorf("Wrong number of exchanges: %d", len(tc.s.exchanges))
18 | }
19 |
20 | // Create Queue
21 | channel.QueueDeclare("q-1", false, false, false, false, NO_ARGS)
22 | if len(tc.s.queues) != 1 {
23 | t.Errorf("Wrong number of queues: %d", len(tc.s.queues))
24 | }
25 |
26 | // Delete exchange
27 | channel.ExchangeDelete("ex-1", false, false)
28 | if len(tc.s.exchanges) != 4 {
29 | t.Errorf("Wrong number of exchanges: %d", len(tc.s.exchanges))
30 | }
31 | }
32 |
--------------------------------------------------------------------------------
/server/server_publish_test.go:
--------------------------------------------------------------------------------
1 | package server
2 |
3 | import (
4 | "github.com/jeffjenkins/dispatchd/util"
5 | "testing"
6 | )
7 |
8 | func TestImmediateFail(t *testing.T) {
9 | tc := newTestClient(t)
10 | defer tc.cleanup()
11 | conn := tc.connect()
12 | ch, retChan, _ := channelHelper(tc, conn)
13 |
14 | ch.QueueDeclare("q1", false, false, false, false, NO_ARGS)
15 | ch.QueueBind("q1", "abc", "amq.direct", false, NO_ARGS)
16 | ch.Publish("amq.direct", "abc", false, true, TEST_TRANSIENT_MSG)
17 |
18 | ret := <-retChan
19 |
20 | if ret.ReplyCode != 313 {
21 | t.Fatalf("Wrong reply code with Immediate return")
22 | }
23 | if string(ret.Body) != string(TEST_TRANSIENT_MSG.Body) {
24 | t.Fatalf("Did not get same payload back in BasicReturn")
25 | }
26 | }
27 |
28 | func TestImmediate(t *testing.T) {
29 | tc := newTestClient(t)
30 | defer tc.cleanup()
31 | conn := tc.connect()
32 | ch, _, _ := channelHelper(tc, conn)
33 |
34 | ch.QueueDeclare("q1", false, false, false, false, NO_ARGS)
35 | ch.QueueBind("q1", "abc", "amq.direct", false, NO_ARGS)
36 |
37 | deliveries, err := ch.Consume("q1", util.RandomId(), false, false, false, false, NO_ARGS)
38 | if err != nil {
39 | t.Fatalf("Failed to consume")
40 | }
41 | ch.Publish("amq.direct", "abc", false, true, TEST_TRANSIENT_MSG)
42 | <-deliveries
43 | }
44 |
45 | func TestMandatory(t *testing.T) {
46 | tc := newTestClient(t)
47 | defer tc.cleanup()
48 | conn := tc.connect()
49 | ch, retChan, _ := channelHelper(tc, conn)
50 |
51 | ch.Publish("amq.direct", "abc", false, true, TEST_TRANSIENT_MSG)
52 |
53 | ret := <-retChan
54 |
55 | if ret.ReplyCode != 313 {
56 | t.Fatalf("Wrong reply code with Mandatory return")
57 | }
58 | if string(ret.Body) != string(TEST_TRANSIENT_MSG.Body) {
59 | t.Fatalf("Did not get same payload back in BasicReturn")
60 | }
61 | }
62 |
--------------------------------------------------------------------------------
/server/server_queue_test.go:
--------------------------------------------------------------------------------
1 | package server
2 |
3 | import (
4 | "testing"
5 | )
6 |
7 | func TestQueueMethods(t *testing.T) {
8 | tc := newTestClient(t)
9 | defer tc.cleanup()
10 | conn := tc.connect()
11 | ch, _, _ := channelHelper(tc, conn)
12 |
13 | // Create Queue
14 | ch.QueueDeclare("q1", false, false, false, false, NO_ARGS)
15 |
16 | if len(tc.s.queues) != 1 {
17 | t.Errorf("Wrong number of queues: %d", len(tc.s.queues))
18 | }
19 |
20 | // Passive Check
21 | ch.QueueDeclarePassive("q1", true, false, false, false, NO_ARGS)
22 |
23 | // Bind
24 | ch.QueueBind("q1", "rk.*.#", "amq.topic", false, NO_ARGS)
25 | if len(tc.s.exchanges["amq.topic"].BindingsForQueue("q1")) != 1 {
26 | t.Errorf("Failed to bind to q1")
27 | }
28 |
29 | // Unbind
30 | ch.QueueUnbind("q1", "rk.*.#", "amq.topic", NO_ARGS)
31 |
32 | if len(tc.s.exchanges["amq.topic"].BindingsForQueue("q1")) != 0 {
33 | t.Errorf("Failed to unbind from q1")
34 | }
35 |
36 | // Delete
37 | ch.QueueDelete("q1", false, false, false)
38 | if len(tc.s.queues) != 0 {
39 | t.Errorf("Wrong number of queues: %d", len(tc.s.queues))
40 | }
41 | }
42 |
43 | func TestAutoAssignedQueue(t *testing.T) {
44 | tc := newTestClient(t)
45 | defer tc.cleanup()
46 | conn := tc.connect()
47 | ch, _, _ := channelHelper(tc, conn)
48 |
49 | // Create Queue
50 | resp, err := ch.QueueDeclare("", false, false, false, false, NO_ARGS)
51 | if err != nil {
52 | t.Fatalf("Error declaring queue")
53 | }
54 | if len(resp.Name) == 0 {
55 | t.Errorf("Autogenerate queue name failed")
56 | }
57 | }
58 |
59 | func TestBadQueueName(t *testing.T) {
60 | tc := newTestClient(t)
61 | defer tc.cleanup()
62 | conn := tc.connect()
63 | ch, _, errChan := channelHelper(tc, conn)
64 |
65 | // Create Queue
66 | _, err := ch.QueueDeclare("!", false, false, false, true, NO_ARGS)
67 | if err != nil {
68 | panic("failed to declare queue")
69 | }
70 | resp := <-errChan
71 | if resp.Code != 406 {
72 | t.Errorf("Wrong response code")
73 | }
74 | }
75 |
76 | func TestPassiveNotFound(t *testing.T) {
77 | tc := newTestClient(t)
78 | defer tc.cleanup()
79 | conn := tc.connect()
80 | ch, _, errChan := channelHelper(tc, conn)
81 |
82 | // Create Queue
83 | _, err := ch.QueueDeclarePassive("does.not.exist", true, false, false, true, NO_ARGS)
84 | if err != nil {
85 | panic("failed to declare queue")
86 | }
87 | resp := <-errChan
88 | if resp.Code != 404 {
89 | t.Errorf("Wrong response code")
90 | }
91 | }
92 |
93 | func TestPurge(t *testing.T) {
94 | tc := newTestClient(t)
95 | defer tc.cleanup()
96 | conn := tc.connect()
97 | ch, _, _ := channelHelper(tc, conn)
98 |
99 | ch.QueueDeclare("q1", false, false, false, false, NO_ARGS)
100 | ch.QueueBind("q1", "a.b.c", "amq.topic", false, NO_ARGS)
101 |
102 | ch.Publish("amq.topic", "a.b.c", false, false, TEST_TRANSIENT_MSG)
103 |
104 | // This unbind is just to block us on the message being processed by the
105 | // channel so that the server has it.
106 | ch.QueueUnbind("q1", "a.b.c", "amq.topic", NO_ARGS)
107 | if tc.s.queues["q1"].Len() == 0 {
108 | t.Fatalf("Message did not make it into queue")
109 | }
110 |
111 | resp, err := ch.QueuePurge("q1", false)
112 | if err != nil {
113 | t.Fatalf("Failed to call QueuePurge")
114 | }
115 |
116 | if tc.s.queues["q1"].Len() != 0 {
117 | t.Fatalf("Message did not get purged from queue. Got %d", tc.s.queues["q1"].Len())
118 | }
119 |
120 | if resp != 1 {
121 | t.Fatalf("Purge did not return the right number of messages deleted")
122 | }
123 | }
124 |
125 | func TestExclusive(t *testing.T) {
126 | tc := newTestClient(t)
127 | defer tc.cleanup()
128 | conn := tc.connect()
129 | ch, _, errChan := channelHelper(tc, conn)
130 |
131 | // Create Queue
132 | ch.QueueDeclare("q1", false, false, true, false, NO_ARGS)
133 |
134 | // Check conn id
135 | serverConn := tc.connFromServer()
136 | q, ok := tc.s.queues["q1"]
137 | if !ok {
138 | t.Fatalf("Could not find q1")
139 | }
140 | if serverConn.id != q.ConnId {
141 | t.Fatalf("Exclusive queue does not have connId set")
142 | }
143 |
144 | // Cheat and change the connection id so we don't need a second conn
145 | // for this test
146 | q.ConnId = 54321
147 | // NOTE: if nowait isn't true this blocks forever
148 | ch.QueueDeclare("q1", false, false, true, true, NO_ARGS)
149 |
150 | resp := <-errChan
151 | if resp.Code != 405 {
152 | t.Errorf("Wrong response code")
153 | }
154 | }
155 |
156 | func TestNonMatchingQueue(t *testing.T) {
157 | tc := newTestClient(t)
158 | defer tc.cleanup()
159 | conn := tc.connect()
160 | ch, _, errChan := channelHelper(tc, conn)
161 |
162 | // Create Queue
163 | ch.QueueDeclare("q1", false, false, false, false, NO_ARGS)
164 |
165 | // Create queue again with different args
166 | // NOTE: if nowait isn't true this blocks forever
167 | ch.QueueDeclare("q1", true, false, false, true, NO_ARGS)
168 | resp := <-errChan
169 | if resp.Code != 406 {
170 | t.Errorf("Wrong response code")
171 | }
172 | }
173 |
--------------------------------------------------------------------------------
/server/server_test.go:
--------------------------------------------------------------------------------
1 | package server
2 |
3 | import (
4 | "github.com/jeffjenkins/dispatchd/util"
5 | amqpclient "github.com/streadway/amqp"
6 | "net"
7 | "os"
8 | "testing"
9 | "time"
10 | )
11 |
12 | var NO_ARGS = make(amqpclient.Table)
13 | var TEST_TRANSIENT_MSG = amqpclient.Publishing{
14 | Body: []byte("dispatchd"),
15 | }
16 |
17 | type testClient struct {
18 | t *testing.T
19 | s *Server
20 | serverDb string
21 | msgDb string
22 | }
23 |
24 | func newTestClient(t *testing.T) *testClient {
25 | serverDb := dbPath()
26 | msgDb := dbPath()
27 | s := NewServer(serverDb, msgDb, nil, false)
28 | s.init()
29 | tc := &testClient{
30 | t: t,
31 | s: s,
32 | serverDb: serverDb,
33 | msgDb: msgDb,
34 | }
35 | return tc
36 | }
37 |
38 | func channelHelper(
39 | tc *testClient,
40 | conn *amqpclient.Connection,
41 | ) (
42 | *amqpclient.Channel,
43 | chan amqpclient.Return,
44 | chan *amqpclient.Error,
45 | ) {
46 | ch, err := conn.Channel()
47 | if err != nil {
48 | panic("Bad channel!")
49 | }
50 | retChan := make(chan amqpclient.Return)
51 | closeChan := make(chan *amqpclient.Error)
52 | ch.NotifyReturn(retChan)
53 | ch.NotifyClose(closeChan)
54 | return ch, retChan, closeChan
55 | }
56 |
57 | func (tc *testClient) connect() *amqpclient.Connection {
58 | internal, external := net.Pipe()
59 | go tc.s.OpenConnection(internal)
60 | // Set up connection
61 | clientconfig := amqpclient.Config{
62 | SASL: nil,
63 | Vhost: "/",
64 | ChannelMax: 100000,
65 | FrameSize: 100000,
66 | Heartbeat: time.Duration(0),
67 | TLSClientConfig: nil,
68 | Properties: make(amqpclient.Table),
69 | Dial: func(network, addr string) (net.Conn, error) {
70 | return external, nil
71 | },
72 | }
73 |
74 | client, err := amqpclient.DialConfig("amqp://localhost:1234", clientconfig)
75 | if err != nil {
76 | panic(err.Error())
77 | }
78 | return client
79 | }
80 |
81 | func (tc *testClient) wait(ch *amqpclient.Channel) {
82 | ch.QueueDeclare(util.RandomId(), false, false, false, false, NO_ARGS)
83 | }
84 |
85 | func (tc *testClient) cleanup() {
86 | os.Remove(tc.msgDb)
87 | os.Remove(tc.serverDb)
88 | // tc.client.Close()
89 | }
90 |
91 | func dbPath() string {
92 | return "/tmp/" + util.RandomId() + ".dispatchd.test.db"
93 | }
94 |
95 | func (tc *testClient) connFromServer() *AMQPConnection {
96 | for _, conn := range tc.s.conns {
97 | return conn
98 | }
99 | panic("no connections!")
100 | }
101 |
--------------------------------------------------------------------------------
/server/server_tx_test.go:
--------------------------------------------------------------------------------
1 | package server
2 |
3 | import (
4 | "github.com/jeffjenkins/dispatchd/util"
5 | "testing"
6 | )
7 |
8 | func TestTxCommitPublish(t *testing.T) {
9 | //
10 | // Setup
11 | //
12 | tc := newTestClient(t)
13 | defer tc.cleanup()
14 | conn := tc.connect()
15 | ch, _, _ := channelHelper(tc, conn)
16 |
17 | ch.QueueDeclare("q1", false, false, false, false, NO_ARGS)
18 | ch.QueueBind("q1", "abc", "amq.direct", false, NO_ARGS)
19 | ch.Tx()
20 | ch.Publish("amq.direct", "abc", false, false, TEST_TRANSIENT_MSG)
21 | ch.Publish("amq.direct", "abc", false, false, TEST_TRANSIENT_MSG)
22 | ch.Publish("amq.direct", "abc", false, false, TEST_TRANSIENT_MSG)
23 | tc.wait(ch)
24 | if tc.s.queues["q1"].Len() != 0 {
25 | t.Fatalf("Tx failed to buffer messages")
26 | }
27 | ch.TxCommit()
28 | tc.wait(ch)
29 | if tc.s.queues["q1"].Len() != 3 {
30 | t.Fatalf("All messages were not added to queue")
31 | }
32 | }
33 |
34 | func TestTxRollbackPublish(t *testing.T) {
35 | //
36 | // Setup
37 | //
38 | tc := newTestClient(t)
39 | defer tc.cleanup()
40 | conn := tc.connect()
41 | ch, _, _ := channelHelper(tc, conn)
42 |
43 | ch.QueueDeclare("q1", false, false, false, false, NO_ARGS)
44 | ch.QueueBind("q1", "abc", "amq.direct", false, NO_ARGS)
45 | ch.Tx()
46 | ch.Publish("amq.direct", "abc", false, false, TEST_TRANSIENT_MSG)
47 | ch.Publish("amq.direct", "abc", false, false, TEST_TRANSIENT_MSG)
48 | ch.Publish("amq.direct", "abc", false, false, TEST_TRANSIENT_MSG)
49 | ch.TxRollback()
50 | ch.TxCommit()
51 | tc.wait(ch)
52 | if tc.s.queues["q1"].Len() != 0 {
53 | t.Fatalf("Tx Rollback still put messages in queue")
54 | }
55 | }
56 |
57 | func TestTxCommitAckNack(t *testing.T) {
58 | //
59 | // Setup
60 | //
61 | tc := newTestClient(t)
62 | defer tc.cleanup()
63 | conn := tc.connect()
64 | ch, _, _ := channelHelper(tc, conn)
65 |
66 | ch.QueueDeclare("q1", false, false, false, false, NO_ARGS)
67 | ch.QueueBind("q1", "abc", "amq.direct", false, NO_ARGS)
68 | ch.Tx()
69 | ch.Publish("amq.direct", "abc", false, false, TEST_TRANSIENT_MSG)
70 | ch.Publish("amq.direct", "abc", false, false, TEST_TRANSIENT_MSG)
71 | ch.Publish("amq.direct", "abc", false, false, TEST_TRANSIENT_MSG)
72 | ch.TxCommit()
73 | cTag := util.RandomId()
74 | deliveries, _ := ch.Consume("q1", cTag, false, false, false, false, NO_ARGS)
75 | <-deliveries
76 | <-deliveries
77 | lastMsg := <-deliveries
78 | lastMsg.Ack(true)
79 | tc.wait(ch)
80 | if len(tc.connFromServer().channels[1].awaitingAcks) != 3 {
81 | t.Fatalf("Acks were not held in tx")
82 | }
83 | ch.TxCommit()
84 | if len(tc.connFromServer().channels[1].awaitingAcks) != 0 {
85 | t.Fatalf("Acks were not processed on commit")
86 | }
87 | }
88 |
89 | func TestTxRollbackAckNack(t *testing.T) {
90 | //
91 | // Setup
92 | //
93 | tc := newTestClient(t)
94 | defer tc.cleanup()
95 | conn := tc.connect()
96 | ch, _, _ := channelHelper(tc, conn)
97 |
98 | ch.QueueDeclare("q1", false, false, false, false, NO_ARGS)
99 | ch.QueueBind("q1", "abc", "amq.direct", false, NO_ARGS)
100 | ch.Tx()
101 | ch.Publish("amq.direct", "abc", false, false, TEST_TRANSIENT_MSG)
102 | ch.Publish("amq.direct", "abc", false, false, TEST_TRANSIENT_MSG)
103 | ch.Publish("amq.direct", "abc", false, false, TEST_TRANSIENT_MSG)
104 | ch.TxCommit()
105 |
106 | cTag := util.RandomId()
107 | deliveries, _ := ch.Consume("q1", cTag, false, false, false, false, NO_ARGS)
108 | <-deliveries
109 | <-deliveries
110 | lastMsg := <-deliveries
111 | lastMsg.Nack(true, true)
112 | tc.wait(ch)
113 | if len(tc.connFromServer().channels[1].awaitingAcks) != 3 {
114 | t.Fatalf("Messages were acked despite rollback")
115 | }
116 | }
117 |
--------------------------------------------------------------------------------
/server/txMethods.go:
--------------------------------------------------------------------------------
1 | package server
2 |
3 | import (
4 | "github.com/jeffjenkins/dispatchd/amqp"
5 | )
6 |
7 | func (channel *Channel) txRoute(methodFrame amqp.MethodFrame) *amqp.AMQPError {
8 | switch method := methodFrame.(type) {
9 | case *amqp.TxSelect:
10 | return channel.txSelect(method)
11 | case *amqp.TxCommit:
12 | return channel.txCommit(method)
13 | case *amqp.TxRollback:
14 | return channel.txRollback(method)
15 | }
16 | var classId, methodId = methodFrame.MethodIdentifier()
17 | return amqp.NewHardError(540, "Unable to route method frame", classId, methodId)
18 | }
19 |
20 | func (channel *Channel) txSelect(method *amqp.TxSelect) *amqp.AMQPError {
21 | channel.startTxMode()
22 | channel.SendMethod(&amqp.TxSelectOk{})
23 | return nil
24 | }
25 |
26 | func (channel *Channel) txCommit(method *amqp.TxCommit) *amqp.AMQPError {
27 | if amqpErr := channel.commitTx(); amqpErr != nil {
28 | return amqpErr
29 | }
30 | channel.SendMethod(&amqp.TxCommitOk{})
31 | return nil
32 | }
33 |
34 | func (channel *Channel) txRollback(method *amqp.TxRollback) *amqp.AMQPError {
35 | channel.rollbackTx()
36 | channel.SendMethod(&amqp.TxRollbackOk{})
37 | return nil
38 | }
39 |
--------------------------------------------------------------------------------
/static/admin.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 | Admin
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
Server
16 |
Message Store Size: {{server.msgCount}}
17 |
Message Index Size: {{server.msgIndexCount}}
18 |
Exchanges
19 |
20 |
21 | "{{name}}"
22 | (default exchange)
23 | — Type: {{exchange.type}}
24 |
25 |
26 |
27 |
Bindings
28 |
29 |
30 | Queue |
31 | Key |
32 |
33 |
34 | {{binding.queueName}} |
35 | {{binding.key}} |
36 |
37 |
38 |
39 |
40 |
Queues
41 |
42 |
43 | {{name}}
44 |
45 | (exclusive: {{queue.connId}})
46 |
47 |
48 |
49 |
50 |
51 | Tag |
52 | Ack? |
53 | QoS |
54 | Active |
55 | Lifetime |
56 |
57 |
58 | {{consumer.tag}} |
59 | {{consumer.ack}} |
60 | {{consumer.stats.qos}} |
61 | {{consumer.stats.active_count}} |
62 | N/A |
63 | {{consumer.stats.total}} |
64 |
65 |
66 |
67 |
71 |
72 |
73 |
Connections
74 |
75 |
76 | id: {{id}} – addr: {{conn.address}} — {{conn.channelCount}} channels active
77 |
78 |
79 |
Client Properties
80 |
81 |
82 | name |
83 | value |
84 |
85 |
86 | {{name}} |
87 | {{value}} |
88 |
89 |
90 |
91 |
92 |
93 |
94 |
--------------------------------------------------------------------------------
/static/admin.js:
--------------------------------------------------------------------------------
1 |
// Milliseconds between polls of the admin API. Declared with `var`; the
// previous bare assignment created an implicit global, which throws a
// ReferenceError in strict mode.
var REFRESH_INTERVAL = 2000;

var app = angular.module('adminApp', []);

// serverState polls /api/server every REFRESH_INTERVAL ms and exposes the
// latest response body as `server`. The object is updated in place (old
// keys deleted, new ones copied in) rather than replaced, because
// ServerController keeps a reference to this same object on its scope.
app.service('serverState', ['$http', '$interval',
  function($http, $interval) {
    var self = this;
    self.server = {};
    self.refresh = function() {
      $http({
        'url' : '/api/server'
      }).then(function(response) {
        angular.forEach(self.server, function(value, key) {
          delete self.server[key];
        });
        angular.forEach(response.data, function(value, key) {
          self.server[key] = value;
        });
      });
    };
    self.refresh();
    $interval(self.refresh, REFRESH_INTERVAL);
  }]);

// ServerController publishes the shared, auto-refreshed server state to the
// view as $scope.server.
app.controller('ServerController', ['$scope', 'serverState',
  function($scope, serverState) {
    $scope.server = serverState.server;
  }]);
30 |
--------------------------------------------------------------------------------
/stats/stats.go:
--------------------------------------------------------------------------------
1 | package stats
2 |
3 | import (
4 | "github.com/rcrowley/go-metrics"
5 | "time"
6 | )
7 |
8 | func MakeHistogram(name string) metrics.Histogram {
9 | return metrics.GetOrRegisterHistogram(
10 | name,
11 | metrics.DefaultRegistry,
12 | metrics.NewUniformSample(10000),
13 | )
14 | }
15 |
16 | func RecordHisto(histo metrics.Histogram, start int64) {
17 | histo.Update(time.Now().UnixNano() - start)
18 | }
19 |
// Start returns the current wall-clock time in Unix nanoseconds. It is
// meant to be passed as the start argument to RecordHisto.
func Start() int64 {
	now := time.Now()
	return now.UnixNano()
}
23 |
// Histogram is a package-local defined type whose underlying type is
// metrics.Histogram. It is a new named type, not an alias (`type X = Y`).
type Histogram metrics.Histogram
25 |
--------------------------------------------------------------------------------
/stats/stats_test.go:
--------------------------------------------------------------------------------
1 | package stats
2 |
3 | import (
4 | "github.com/rcrowley/go-metrics"
5 | "testing"
6 | )
7 |
8 | func TestStats(t *testing.T) {
9 | // This is basically for code coverage. These are simple wrappers.
10 | // I'll add more tests once this start needing to do more.
11 | var m = MakeHistogram("hello")
12 | if m != metrics.Get("hello") {
13 | t.Errorf("Got different histogram from MakeHistogram")
14 | }
15 |
16 | var start = Start()
17 | RecordHisto(m, start-20)
18 | if m.Max() < 19 || m.Max() > start {
19 | t.Errorf("Bad value in histo %d", m.Max())
20 | }
21 | }
22 |
--------------------------------------------------------------------------------
/util/util.go:
--------------------------------------------------------------------------------
1 | package util
2 |
3 | import (
4 | "math/rand"
5 | "sync/atomic"
6 | "time"
7 | )
8 |
// counter holds the last id handed out by NextId. init sets it to the boot
// time; NextId advances it atomically.
var counter int64

// init seeds the global math/rand source used by RandomId and sets counter
// to the current time.
func init() {
	rand.Seed(time.Now().UnixNano())

	// System Integrity Note:
	//
	// The counter starts at server boot time so that message ids stay
	// consistent within one server, even across restarts. Because it is
	// time based, a boot time earlier than the previous one (e.g. after a
	// clock change) can make durable messages load out of order.
	bootNanos := time.Now().UnixNano()
	counter = bootNanos
}
22 |
// chars is the alphabet RandomId draws from: ASCII letters and digits.
var chars = []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890")

// RandomId returns a 32-character string with each character picked
// uniformly from chars using the global math/rand source. It is not
// cryptographically secure.
func RandomId() string {
	const idLen = 32
	buf := make([]rune, 0, idLen)
	for len(buf) < idLen {
		buf = append(buf, chars[rand.Intn(len(chars))])
	}
	return string(buf)
}
34 |
35 | func NextId() int64 {
36 | return atomic.AddInt64(&counter, 1)
37 | }
38 |
--------------------------------------------------------------------------------
/util/util_test.go:
--------------------------------------------------------------------------------
1 | package util
2 |
3 | import (
4 | "regexp"
5 | "testing"
6 | "time"
7 | )
8 |
9 | func TestRandomId(t *testing.T) {
10 | var re, err = regexp.Compile("^\\w{32}$")
11 | if err != nil {
12 | t.Errorf(err.Error())
13 | }
14 | for i := 0; i < 100; i++ {
15 | var r = RandomId()
16 | if !re.MatchString(r) {
17 | t.Errorf("Did not match string: '%s'", r)
18 | }
19 | }
20 | }
21 |
22 | func TestNextId(t *testing.T) {
23 | var now = time.Now().UnixNano()
24 | var next = NextId()
25 | var nextNext = NextId()
26 | if now < next {
27 | t.Errorf("NextId was less than current time")
28 | }
29 | if next+1 != nextNext {
30 | t.Errorf(
31 | "subsequent NextId calls do not produce incrementing numbers: %d, %d",
32 | next,
33 | nextNext,
34 | )
35 | }
36 | }
37 |
--------------------------------------------------------------------------------