├── .github
└── workflows
│ └── mqtt-nats_test.yml
├── .gitignore
├── .golangci.yml
├── .idea
├── .gitignore
├── codeStyles
│ └── codeStyleConfig.xml
├── inspectionProfiles
│ └── Project_Default.xml
├── misc.xml
├── modules.xml
├── mqtt-nats.iml
├── vcs.xml
└── watcherTasks.xml
├── LICENSE
├── README.md
├── bridge
├── client.go
├── client_test.go
├── nats.go
├── natspub.go
├── options.go
├── package.go
├── replytopic.go
├── retained.go
├── server.go
└── session.go
├── cli
├── mqtt-nats.go
└── package.go
├── examples
├── certs
│ ├── .gitignore
│ ├── ca.json
│ ├── config.json
│ ├── csr.json
│ ├── generate.sh
│ └── server.json
├── jmeter
│ └── MQTT Pub Sampler.jmx
├── server.conf
└── tools
│ ├── nats-pub-repeat
│ └── nats-pub-repeat.go
│ └── nats-sub-reply
│ └── nats-sub-reply.go
├── go.mod
├── go.sum
├── logger
├── logger.go
└── logger_test.go
├── main.go
├── mqtt
├── package.go
├── pkg
│ ├── connect.go
│ ├── connect_test.go
│ ├── credentials.go
│ ├── idmanager.go
│ ├── idmanager_test.go
│ ├── packet.go
│ ├── packet_test.go
│ ├── ping.go
│ ├── ping_test.go
│ ├── publish.go
│ ├── publish_test.go
│ ├── subscribe.go
│ ├── subscribe_test.go
│ ├── unsubscribe.go
│ ├── unsubscribe_test.go
│ └── will.go
├── reader.go
├── reader_test.go
├── topic.go
├── topic_test.go
├── writer.go
└── writer_test.go
└── test
├── connect_test.go
├── full
├── bridge.go
├── client.go
└── nats.go
├── main_test.go
├── mock
├── connection.go
└── connection_test.go
├── package.go
├── packet
├── parse.go
└── parse_test.go
├── publish_test.go
├── retained_test.go
├── tls
├── connect_test.go
├── main_test.go
├── package.go
├── server.conf
└── testdata
│ ├── ca-key.pem
│ ├── ca.pem
│ ├── client-key.pem
│ ├── client.pem
│ ├── server-key.pem
│ └── server.pem
└── utils
├── checks.go
├── checks_test.go
├── logger.go
├── logger_test.go
├── panics.go
└── panics_test.go
/.github/workflows/mqtt-nats_test.yml:
--------------------------------------------------------------------------------
1 | name: MQTT-NATS Test
2 | on: [push, pull_request]
3 | jobs:
4 |
5 | test:
6 | name: Test Linux
7 | runs-on: ubuntu-latest
8 | steps:
9 |
10 | - name: Set up Go 1.14
11 | uses: actions/setup-go@v1
12 | with:
13 | go-version: 1.14
14 | id: go
15 |
16 | - name: Check out code into the Go module directory
17 | uses: actions/checkout@v1
18 |
19 | - name: Set up GolangCI-Lint
20 | run: curl -sfL https://install.goreleaser.com/github.com/golangci/golangci-lint.sh | sh -s -- latest
21 |
22 | - name: Lint
23 | run: ./bin/golangci-lint run ./...
24 |
25 | - name: Test
26 | run: go test -v -tags=citest -covermode=atomic -coverpkg=$(go list ./... | grep -v 'nats/test' | tr '\n' ,) -coverprofile coverage.tmp ./...
27 |
28 | # - name: Test Coverage Check
29 | # run: |
30 | # COV=$(go tool cover -func=coverage.tmp | grep -e '^total:\s*(statements)' | awk '{ print $3 }')
31 | # test $COV = '100.0%' || (echo "Expected 100% test coverage, got $COV" && exit 1)
32 |
33 | - name: Send coverage
34 | env:
35 | COVERALLS_TOKEN: ${{ secrets.GITHUB_TOKEN }}
36 | run: |
37 | go get github.com/mattn/goveralls@master
38 | $(go env GOPATH)/bin/goveralls -tags=citest -coverprofile=coverage.tmp -service=github
39 |
40 | test-windows:
41 | name: Test Windows
42 | runs-on: windows-latest
43 | steps:
44 |
45 | - name: Set up Go 1.13
46 | uses: actions/setup-go@v1
47 | with:
48 | go-version: 1.13
49 | id: go
50 |
51 | - name: Check out code into the Go module directory
52 | uses: actions/checkout@v1
53 |
54 | - name: Test
55 | run: go test -v -tags=citest ./...
56 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .idea/**/workspace.xml
2 | .idea/**/tasks.xml
3 | .idea/**/dictionaries
4 | .idea/**/shelf
5 |
6 | # Bridge persistence file
7 | mqtt-nats.json
8 |
9 | # Binaries for programs and plugins
10 | *.exe
11 | *.exe~
12 | *.dll
13 | *.so
14 | *.dylib
15 |
16 | # Test binary, build with `go test -c`
17 | *.test
18 |
19 | # Output of the go coverage tool, specifically when used with LiteIDE
20 | *.out
21 | *.tmp
22 |
--------------------------------------------------------------------------------
/.golangci.yml:
--------------------------------------------------------------------------------
1 | run:
2 | build-tags:
3 | - citest
4 |
5 | issues:
6 | exclude-use-default: false
7 | exclude-rules:
8 | - path: _test\.go
9 | linters:
10 | - const
11 | - dupl
12 | - gochecknoglobals
13 | - goconst
14 | - golint
15 | - unparam
16 |
17 | linters-settings:
18 | gocyclo:
19 | min-complexity: 35
20 |
21 | gocognit:
22 | min-complexity: 60
23 |
24 | lll:
25 | line-length: 140
26 | tab-width: 2
27 |
28 | misspell:
29 | ignore-words:
30 | - mosquitto
31 |
32 | linters:
33 | disable-all: true
34 | enable:
35 | - bodyclose
36 | - deadcode
37 | - depguard
38 | - dogsled
39 | - dupl
40 | - errcheck
41 | - gochecknoglobals
42 | - gochecknoinits
43 | - gocognit
44 | - goconst
45 | - gocritic
46 | - gocyclo
47 | - gofmt
48 | - goimports
49 | - golint
50 | - gosimple
51 | - govet
52 | - ineffassign
53 | - interfacer
54 | - lll
55 | - maligned
56 | - misspell
57 | - nakedret
58 | - prealloc
59 | - scopelint
60 | - structcheck
61 | - staticcheck
62 | - stylecheck
63 | - typecheck
64 | - unconvert
65 | - unused
66 | - varcheck
67 | - whitespace
68 | - unparam
69 |
70 | # don't enable:
71 | # - funlen
72 | # - godox
73 | # - gosec
74 |
--------------------------------------------------------------------------------
/.idea/.gitignore:
--------------------------------------------------------------------------------
1 | # Default ignored files
2 | /workspace.xml
--------------------------------------------------------------------------------
/.idea/codeStyles/codeStyleConfig.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/Project_Default.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
32 |
33 |
34 |
35 |
36 |
37 |
38 |
39 |
40 |
41 |
42 |
43 |
44 |
45 |
46 |
47 |
48 |
49 |
50 |
51 |
52 |
53 |
54 |
55 |
56 |
57 |
58 |
59 |
60 |
61 |
62 |
63 |
64 |
65 |
66 |
67 |
68 |
69 |
70 |
71 |
72 |
73 |
74 |
75 |
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/.idea/mqtt-nats.iml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
--------------------------------------------------------------------------------
/.idea/watcherTasks.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | The mqtt-nats bridge enables [MQTT](http://mqtt.org/) devices to publish and subscribe to the [NATS](https://nats.io) communication system.
2 |
3 | [License: Apache 2.0](https://opensource.org/licenses/Apache-2.0)
4 | [Go Report Card](https://goreportcard.com/report/github.com/tada/mqtt-nats)
5 | [GoDoc](https://godoc.org/github.com/tada/mqtt-nats)
6 | [Build Status](https://github.com/tada/mqtt-nats/actions)
7 | [Coverage Status](https://coveralls.io/github/tada/mqtt-nats)
8 |
9 | ## Project status
10 | Currently under development.
11 |
12 | ## Getting started
13 |
14 | ### Obtaining and running the bridge
15 | You can install the binary using go get
16 | ```
17 | go get github.com/tada/mqtt-nats
18 | ```
19 | If you start the bridge using just the command `mqtt-nats`, it will create an MQTT server on port 1883 that will attempt to connect
20 | to a NATS server using the default URL "nats://127.0.0.1:4222". Use:
21 | ```
22 | mqtt-nats -help
23 | ```
24 | to get a list of all configurable options.
25 |
26 | ### Run the tests
27 | The test utilities within this code-base are tagged with the special build tag "citest". This flag is required
28 | for most of the tests to build and run. I.e. to run all tests, use:
29 | ```
30 | go test -tags citest ./...
31 | ```
32 |
33 | ## Current limitations:
34 | - Only MQTT 3.1.1 is supported
35 | - Only QoS levels 0 (at most once) and 1 (at least once) are supported
36 | - The bridge has no way of knowing when new subscriptions are added in the NATS network and hence, cannot send retained
37 | messages in response to such subscriptions.
38 |
39 | ## Solutions and workarounds
40 |
41 | ### QoS level 1 is accomplished using the NATS reply-to subject.
42 | When the bridge receives an MQTT publish with QoS = 1 from a client, it forwards that to NATS with a reply-to subject.
43 | A PUBACK is sent to the MQTT client when the reply arrives. Similarly, if an MQTT client subscribes using desired QoS
44 | = 1, then a NATS publish with a reply-to will be considered in need of a PUBACK from the MQTT client.
45 |
46 | ### Retain request from NATS
47 | An MQTT client that subscribes to a topic will immediately receive all retained messages for that topic. The same is
48 | not true for a NATS client simply because the bridge has no way of knowing when a NATS client subscribes to a topic. To
49 | mitigate this, the bridge subscribes to a specific "retained request" topic (configurable option). A NATS client that
50 | publishes to this topic with a subscription string as the payload and a reply-to inbox, will get a JSON encoded reply
51 | containing all messages that match the subscription string.
52 |
--------------------------------------------------------------------------------
/bridge/client.go:
--------------------------------------------------------------------------------
1 | package bridge
2 |
3 | import (
4 | "errors"
5 | "fmt"
6 | "io"
7 | "net"
8 | "sync"
9 | "time"
10 |
11 | "github.com/nats-io/nats.go"
12 | "github.com/tada/mqtt-nats/logger"
13 | "github.com/tada/mqtt-nats/mqtt"
14 | "github.com/tada/mqtt-nats/mqtt/pkg"
15 | )
16 |
const (
	// StateInfant is the initial state of a client: it has been created but no
	// Connect packet has been processed yet.
	StateInfant byte = iota

	// StateConnected is the state of a client once a connection has been
	// successfully established.
	StateConnected

	// StateDisconnected is the state of a client after a disconnect packet has
	// arrived or a non recoverable error has occurred.
	StateDisconnected
)
27 |
// A Client represents a connection from a client. It is created in state
// StateInfant by NewClient and is driven by a call to Serve.
type Client interface {
	// Serve starts the read and write loops and then waits for them to finish which
	// normally happens after the receipt of a disconnect packet. Serve blocks until
	// both loops have ended.
	Serve()

	// State returns the current client state, one of StateInfant, StateConnected,
	// or StateDisconnected.
	State() byte

	// PublishResponse publishes a packet to the client in response to a subscription.
	// The qos argument is the quality of service level used for the delivery.
	PublishResponse(qos byte, pp *pkg.Publish)

	// SetDisconnected will end the read and write loop and eventually cause Serve() to end.
	// The given error, if any, is the reason for the disconnect.
	SetDisconnected(error)
}
43 |
// client is the Client implementation. One instance serves one MQTT connection.
type client struct {
	server         Server                        // the server that created this client
	log            logger.Logger                 // destination for Debug and Error output
	mqttConn       net.Conn                      // connection to the MQTT peer
	natsConn       *nats.Conn                    // NATS connection, obtained when the Connect packet is handled
	session        Session                       // session created or looked up when the Connect packet is handled
	connectPacket  *pkg.Connect                  // the received Connect packet, nil until one has been handled
	err            error                         // reason for the disconnect, logged by Serve
	maxWait        time.Duration                 // read deadline interval used by the read loop, zero means no deadline
	natsSubs       map[string]*nats.Subscription // NOTE(review): presumably NATS subscriptions keyed by topic; usage is not in this file
	writeQueue     chan pkg.Packet               // packets waiting to be written by the write loop
	stLock         sync.RWMutex                  // guards st and maxWait
	subLock        sync.Mutex                    // NOTE(review): presumably guards natsSubs; usage is not in this file
	workers        sync.WaitGroup                // tracks the read and write loop goroutines
	sessionPresent bool                          // true when a preexisting session was picked up on connect
	st             byte                          // current state, one of the State* constants
}
61 |
// writeQueueSize is the capacity of each client's write queue and also the
// maximum number of packets that the write loop sends in one single write.
//
// TODO: This should probably be configurable.
const writeQueueSize = 1024
64 |
65 | // NewClient returns a new Client instance with StateInfant state.
66 | func NewClient(s Server, log logger.Logger, conn net.Conn) Client {
67 | return &client{
68 | server: s,
69 | log: log,
70 | mqttConn: conn,
71 | natsSubs: make(map[string]*nats.Subscription),
72 | st: StateInfant,
73 | writeQueue: make(chan pkg.Packet, writeQueueSize)}
74 | }
75 |
76 | func (c *client) Serve() {
77 | defer func() {
78 | _ = c.mqttConn.Close()
79 |
80 | if c.natsConn != nil {
81 | c.natsConn.Close()
82 | }
83 | }()
84 |
85 | c.workers.Add(2)
86 | go c.readLoop()
87 | go c.writeLoop()
88 | c.workers.Wait()
89 |
90 | cp := c.connectPacket
91 | if c.err != nil {
92 | c.Error(c.err)
93 | }
94 | if cp == nil {
95 | // No connection was established
96 | c.Debug("client connection could not be established")
97 | } else {
98 | c.Debug("disconnected")
99 | if cp.CleanSession() {
100 | c.server.SessionManager().Remove(cp.ClientID())
101 | c.Debug("session removed")
102 | }
103 | }
104 | }
105 |
106 | // String returns a text suitable for logging of client messages.
107 | func (c *client) String() string {
108 | switch c.State() {
109 | case StateInfant:
110 | return "Client (not yet connected)"
111 | case StateConnected:
112 | return "Client " + c.connectPacket.ClientID()
113 | default:
114 | return "Client " + c.connectPacket.ClientID() + " (disconnected)"
115 | }
116 | }
117 |
118 | func (c *client) State() byte {
119 | var s byte
120 | c.stLock.RLock()
121 | s = c.st
122 | c.stLock.RUnlock()
123 | return s
124 | }
125 |
126 | func (c *client) StateAndMaxWait() (byte, time.Duration) {
127 | c.stLock.RLock()
128 | s := c.st
129 | m := c.maxWait
130 | c.stLock.RUnlock()
131 | return s, m
132 | }
133 |
134 | func (c *client) setState(newState byte) {
135 | c.stLock.Lock()
136 | c.st = newState
137 | c.stLock.Unlock()
138 | }
139 |
140 | func (c *client) setStateAndMaxWait(newState byte, maxWait time.Duration) {
141 | c.stLock.Lock()
142 | c.st = newState
143 | c.maxWait = maxWait
144 | c.stLock.Unlock()
145 | }
146 |
147 | func (c *client) SetDisconnected(err error) {
148 | doit := false
149 | c.stLock.Lock()
150 | if c.st != StateDisconnected {
151 | doit = true
152 | c.st = StateDisconnected
153 | c.maxWait = time.Millisecond
154 | }
155 | c.stLock.Unlock()
156 |
157 | if doit {
158 | if cp := c.connectPacket; cp != nil && cp.HasWill() {
159 | err := c.server.PublishWill(cp.Will(), cp.Credentials())
160 | if err != nil {
161 | c.Error(err)
162 | } else {
163 | c.Debug("will published to", cp.Will().Topic)
164 | }
165 | }
166 | // This packet will not be sent but it will terminate the write loop once everything else
167 | // has been flushed
168 | c.writeQueue <- pkg.DisconnectSingleton
169 |
170 | if err == io.EOF {
171 | err = io.ErrUnexpectedEOF
172 | }
173 | c.err = err
174 |
175 | // release reader block
176 | _ = c.mqttConn.SetReadDeadline(time.Now().Add(time.Millisecond))
177 | }
178 | }
179 |
180 | func (c *client) Debug(args ...interface{}) {
181 | if c.log.DebugEnabled() {
182 | c.log.Debug(c.addFirst(args)...)
183 | }
184 | }
185 |
186 | func (c *client) Error(args ...interface{}) {
187 | if c.log.ErrorEnabled() {
188 | c.log.Error(c.addFirst(args)...)
189 | }
190 | }
191 |
192 | // addFirst prepends the client to the args slice and returns the new slice
193 | func (c *client) addFirst(args []interface{}) []interface{} {
194 | na := make([]interface{}, len(args)+1)
195 | na[0] = c
196 | copy(na[1:], args)
197 | return na
198 | }
199 |
// readLoop reads packets from the MQTT connection, one at a time, and acts on
// them until the client state becomes StateDisconnected, a disconnect packet
// arrives, or an error occurs. It always ends by calling SetDisconnected with
// the error that ended the loop (nil after a normal disconnect) and then
// signals the workers wait group.
func (c *client) readLoop() {
	defer c.workers.Done()

	r := mqtt.NewReader(c.mqttConn)

	var err error

readNextPacket:
	// State and max wait are sampled anew for each packet since both may change
	// as a result of handling a packet or of a call to SetDisconnected.
	for st, maxWait := c.StateAndMaxWait(); st != StateDisconnected && err == nil; st, maxWait = c.StateAndMaxWait() {
		var (
			b  byte
			rl int
		)

		// A max wait of zero means that no read deadline is applied.
		if maxWait > 0 {
			_ = c.mqttConn.SetReadDeadline(time.Now().Add(maxWait))
		}

		// Read packet type and flags
		if b, err = r.ReadByte(); err != nil {
			break
		}

		// A Connect packet must be the first packet, and it must be the only one.
		pkgType := b & pkg.TpMask
		switch st {
		case StateConnected:
			if pkgType == pkg.TpConnect {
				err = errors.New("second connect packet")
				break readNextPacket
			}
		case StateInfant:
			if pkgType != pkg.TpConnect {
				err = errors.New("not connected")
				break readNextPacket
			}
		}

		// Read packet length
		if rl, err = r.ReadVarInt(); err != nil {
			break
		}

		var p pkg.Packet
		switch pkgType {
		case pkg.TpDisconnect:
			// Normal disconnect
			// Discard will
			c.Debug("received", pkg.DisconnectSingleton)
			c.connectPacket.DeleteWill()
			break readNextPacket
		case pkg.TpPing:
			pr := pkg.PingRequestSingleton
			c.Debug("received", pr)
			c.queueForWrite(pkg.PingResponseSingleton)
		case pkg.TpConnect:
			if p, err = pkg.ParseConnect(r, b, rl); err == nil {
				c.Debug("received", p)
				err = c.handleConnect(p.(*pkg.Connect))
				if err == nil {
					c.server.ManageClient(c)
				}
			}
			// An error that is a return code is reported to the peer in a ConnAck. The
			// state is set to StateConnected because queueForWrite only accepts packets
			// in that state. The error remains set, so the loop ends after this packet.
			if retCode, ok := err.(pkg.ReturnCode); ok {
				c.Debug("received", p, "return code", retCode)
				c.setState(StateConnected)
				c.queueForWrite(pkg.NewConnAck(c.sessionPresent, retCode))
			}
		case pkg.TpPublish:
			if p, err = pkg.ParsePublish(r, b, rl); err == nil {
				c.Debug("received", p)
				err = c.natsPublish(c.server.HandleRetain(p.(*pkg.Publish)))
			}
		case pkg.TpPubAck:
			if p, err = pkg.ParsePubAck(r, b, rl); err == nil {
				c.Debug("received", p)
				id := p.(pkg.PubAck).ID()
				c.session.ClientAckReceived(id, c.natsConn)
				c.server.ReleasePacketID(id)
			}
		case pkg.TpPubRec:
			if p, err = pkg.ParsePubRec(r, b, rl); err == nil {
				c.Debug("received", p)
				// TODO: handle PubRec
			}
		case pkg.TpPubRel:
			if p, err = pkg.ParsePubRel(r, b, rl); err == nil {
				c.Debug("received", p)
				// TODO: handle PubRel
			}
		case pkg.TpPubComp:
			if p, err = pkg.ParsePubComp(r, b, rl); err == nil {
				c.Debug("received", p)
				// TODO: handle PubComp
			}
		case pkg.TpSubscribe:
			if p, err = pkg.ParseSubscribe(r, b, rl); err == nil {
				c.Debug("received", p)
				sp := p.(*pkg.Subscribe)
				c.natsSubscribe(sp)
				c.server.PublishMatching(sp, c)
			}
		case pkg.TpUnsubscribe:
			if p, err = pkg.ParseUnsubscribe(r, b, rl); err == nil {
				c.Debug("received", p)
				c.natsUnsubscribe(p.(*pkg.Unsubscribe))
			}
		default:
			err = fmt.Errorf("received unknown packet type %d", (b&pkg.TpMask)>>4)
		}
	}
	// Reached on every exit from the loop. A nil err denotes an orderly disconnect.
	c.SetDisconnected(err)
}
312 |
313 | func (c *client) handleConnect(cp *pkg.Connect) error {
314 | var err error
315 | c.connectPacket = cp
316 | c.natsConn, err = c.server.NatsConn(cp.Credentials())
317 | if err != nil {
318 | // TODO: Different error codes depending on error from NATS
319 | c.Error("NATS connect failed", err)
320 | return pkg.RtServerUnavailable
321 | }
322 |
323 | cid := cp.ClientID()
324 | m := c.server.SessionManager()
325 | c.sessionPresent = false
326 |
327 | if cp.CleanSession() {
328 | c.session = m.Create(cid)
329 | } else {
330 | if s := m.Get(cid); s != nil {
331 | c.session = s
332 | c.sessionPresent = true
333 | } else {
334 | c.session = m.Create(cid)
335 | }
336 | }
337 |
338 | var maxWait time.Duration
339 | if cp.KeepAlive() > 0 {
340 | // Max wait between control packets is 1.5 times the keep alive value
341 | maxWait = (cp.KeepAlive() * 3) / 2
342 | }
343 | c.setStateAndMaxWait(StateConnected, maxWait)
344 | c.queueForWrite(pkg.NewConnAck(c.sessionPresent, 0))
345 |
346 | if cp.CleanSession() {
347 | c.Debug("connected with clean session")
348 | } else {
349 | if c.sessionPresent {
350 | c.Debug("connected using preexisting session")
351 | c.session.RestoreAckSubscriptions(c)
352 | c.session.ResendClientUnack(c)
353 | } else {
354 | c.Debug("connected using new (unclean) session")
355 | }
356 | }
357 | return nil
358 | }
359 |
360 | func (c *client) queueForWrite(p pkg.Packet) {
361 | if c.State() == StateConnected {
362 | c.writeQueue <- p
363 | }
364 | }
365 |
366 | func (c *client) writeLoop() {
367 | defer c.workers.Done()
368 |
369 | bulk := make([]pkg.Packet, writeQueueSize)
370 |
371 | // writer's buffer is reused for each bulk operation
372 | w := mqtt.NewWriter()
373 |
374 | // Each iteration of this loop with pick max writeQueueSize packets from the writeQueue
375 | // and then write those packets on mqtt.Writer (a bytes.Buffer extension). The resulting bytes
376 | // are then written to the connection using one single write on the connection.
377 | for connected := true; connected; {
378 | bulk[0] = <-c.writeQueue
379 | i := 1
380 | inner:
381 | for ; i < writeQueueSize; i++ {
382 | select {
383 | case p := <-c.writeQueue:
384 | bulk[i] = p
385 | default:
386 | break inner
387 | }
388 | }
389 | w.Reset()
390 |
391 | for n := 0; n < i; n++ {
392 | p := bulk[n]
393 | if p == pkg.DisconnectSingleton {
394 | connected = false
395 | break
396 | }
397 | c.Debug("sending", p)
398 | p.Write(w)
399 | }
400 | bs := w.Bytes()
401 | if len(bs) > 0 {
402 | if _, err := c.mqttConn.Write(bs); err != nil {
403 | if connected {
404 | c.SetDisconnected(err)
405 | } else {
406 | // Drain failed. Log the error
407 | c.Error(err)
408 | }
409 | break
410 | }
411 | }
412 | }
413 | }
414 |
415 | func (c *client) PublishResponse(qos byte, pp *pkg.Publish) {
416 | if qos > 0 {
417 | c.session.ClientAckRequested(pp)
418 | }
419 | c.queueForWrite(pp)
420 | }
421 |
--------------------------------------------------------------------------------
/bridge/client_test.go:
--------------------------------------------------------------------------------
1 | package bridge
2 |
3 | import (
4 | "encoding/json"
5 | "errors"
6 | "fmt"
7 | "io"
8 | "testing"
9 | "time"
10 |
11 | "github.com/tada/mqtt-nats/test/packet"
12 |
13 | "github.com/nats-io/nats.go"
14 | "github.com/tada/mqtt-nats/logger"
15 | "github.com/tada/mqtt-nats/mqtt"
16 | "github.com/tada/mqtt-nats/mqtt/pkg"
17 | "github.com/tada/mqtt-nats/test/mock"
18 | "github.com/tada/mqtt-nats/test/utils"
19 | )
20 |
// mockServer is a Server used by the client tests. It embeds a session manager
// and an ID manager and returns preconfigured results from NatsConn and
// PublishWill.
type mockServer struct {
	sm
	pkg.IDManager
	nc        *nats.Conn // connection returned by NatsConn
	ncError   error      // error returned by NatsConn
	willError error      // error returned by PublishWill
	t         *testing.T // test that is failed when an unimplemented method is called
}

// UnmarshalFromJSON is not implemented. It fails the test when called.
func (m *mockServer) UnmarshalFromJSON(js *json.Decoder, firstToken json.Token) {
	m.t.Helper()
	m.t.Fatal("implement me")
}

// MarshalToJSON is not implemented. It fails the test when called.
func (m *mockServer) MarshalToJSON(io.Writer) {
	m.t.Helper()
	m.t.Fatal("implement me")
}

// SessionManager returns the mock itself, which acts as session manager by
// means of the embedded sm.
func (m *mockServer) SessionManager() SessionManager {
	return m
}

// ManageClient does nothing.
func (m *mockServer) ManageClient(c Client) {
}

// NatsConn returns the preconfigured connection and error. The credentials are ignored.
func (m *mockServer) NatsConn(creds *pkg.Credentials) (*nats.Conn, error) {
	return m.nc, m.ncError
}

// HandleRetain returns the given packet as is.
func (m *mockServer) HandleRetain(pp *pkg.Publish) *pkg.Publish {
	return pp
}

// PublishMatching does nothing.
func (m *mockServer) PublishMatching(sp *pkg.Subscribe, c Client) {
}

// PublishWill publishes nothing and returns the preconfigured will error.
func (m *mockServer) PublishWill(will *pkg.Will, creds *pkg.Credentials) error {
	return m.willError
}
61 |
62 | func newMockServer(t *testing.T) *mockServer {
63 | return &mockServer{sm: sm{m: make(map[string]Session, 3)}, IDManager: pkg.NewIDManager(), t: t}
64 | }
65 |
66 | func writePacket(t *testing.T, p pkg.Packet, w io.Writer) {
67 | t.Helper()
68 | mw := mqtt.NewWriter()
69 | p.Write(mw)
70 | _, err := w.Write(mw.Bytes())
71 | utils.CheckNotError(err, t)
72 | }
73 |
// silent is a logger created with level logger.Silent, used by tests that do
// not inspect the log output.
var silent = logger.New(logger.Silent, nil, nil)
75 |
// Test_client_String tests that the client String method produces sane output
// in all states of the client (infant, connected, disconnected)
func Test_client_String(t *testing.T) {
	conn := mock.NewConnection()
	cl := NewClient(newMockServer(t), silent, conn)
	done := make(chan bool, 1)
	go func() {
		cl.Serve()
		done <- true
	}()

	// Nothing has been sent yet
	utils.CheckEqual("Client (not yet connected)", cl.(fmt.Stringer).String(), t)

	rConn := conn.Remote()
	writePacket(t, pkg.NewConnect("client-id", false, 1, nil, nil), rConn)

	// Wait for two bytes of reply before the state is checked
	bs := make([]byte, 2)
	_, err := rConn.Read(bs)
	utils.CheckNotError(err, t)
	utils.CheckEqual("Client client-id", cl.(fmt.Stringer).String(), t)

	// Serve ends as a result of the disconnect
	writePacket(t, pkg.DisconnectSingleton, rConn)
	<-done
	utils.CheckEqual("Client client-id (disconnected)", cl.(fmt.Stringer).String(), t)
}
99 |
// Test_client_natsConnError tests that the server responds with a ConnAck containing
// a pkg.RtServerUnavailable when the client was unable to establish a NATS connection.
func Test_client_natsConnError(t *testing.T) {
	conn := mock.NewConnection()
	rConn := conn.Remote()
	ms := newMockServer(t)
	// Make every attempt to obtain a NATS connection fail
	ms.ncError = errors.New("unauthorized")
	cl := NewClient(ms, silent, conn)
	go cl.Serve()

	writePacket(t, pkg.NewConnect("client-id", false, 1, nil, nil), rConn)
	ca, ok := packet.Parse(t, rConn).(*pkg.ConnAck)
	utils.CheckTrue(ok, t)
	utils.CheckEqual(pkg.RtServerUnavailable, ca.ReturnCode(), t)
}
115 |
// Test_client_natsSubscribeError tests that a subscription to a topic that is not
// acceptable as a NATS subject results in a SubAck with the failure return code 0x80
// and that the failure is logged at level logger.Error.
func Test_client_natsSubscribeError(t *testing.T) {
	mt := &collectLogsT{}
	conn := mock.NewConnection()
	ms := newMockServer(t)
	ms.ncError = nil
	ms.nc = &nats.Conn{}
	cl := NewClient(ms, utils.NewLogger(logger.Error, mt), conn)
	go cl.Serve()

	rConn := conn.Remote()
	writePacket(t, pkg.NewConnect("client-id", true, 1, nil, nil), rConn)
	ca, ok := packet.Parse(t, rConn).(*pkg.ConnAck)
	utils.CheckTrue(ok, t)
	utils.CheckEqual(pkg.RtAccepted, ca.ReturnCode(), t)

	// Newline is unacceptable in a subject
	writePacket(t, pkg.NewSubscribe(1, pkg.Topic{Name: "top\nic"}), rConn)
	sa, ok := packet.Parse(t, rConn).(*pkg.SubAck)
	utils.CheckTrue(ok, t)

	// Topic return code should be 0x80 to indicate failure
	utils.CheckEqual(pkg.NewSubAck(1, 0x80), sa, t)

	// At least one error should be logged (additional caused by forced disconnect)
	utils.CheckTrue(len(mt.logEntries) > 0, t)
	el := mt.logEntries[0]
	// The entry holds the level, the client, the message, the topic, and one more element
	utils.CheckEqual(5, len(el), t)
	utils.CheckEqual(el[0], "ERROR", t)
	utils.CheckTrue(cl == el[1], t)
	utils.CheckEqual("NATS subscribe", el[2], t)
	utils.CheckEqual("top\nic", el[3], t)
}
150 |
// collectLogsT collects everything that is logged so that tests can inspect it.
// It has the Log and Helper methods and is passed to utils.NewLogger by the tests.
type collectLogsT struct {
	logEntries [][]interface{} // one element per call to Log, holding the arguments of that call
}

// Log appends the arguments of the call, as one entry, to the collected log entries.
func (m *collectLogsT) Log(args ...interface{}) {
	m.logEntries = append(m.logEntries, args)
}

// Helper does nothing.
func (m *collectLogsT) Helper() {
}
161 |
162 | // Test_client_publishWillError tests that errors during an attempt to publish the will
163 | // provided in the CONNECT package are logged at level logger.Error
164 | func Test_client_publishWillError(t *testing.T) {
165 | mt := &collectLogsT{}
166 | conn := mock.NewConnection()
167 | ms := newMockServer(t)
168 | ms.willError = errors.New("unauthorized")
169 | cl := NewClient(ms, utils.NewLogger(logger.Error, mt), conn)
170 |
171 | done := make(chan bool, 1)
172 | go func() {
173 | cl.Serve()
174 | done <- true
175 | }()
176 |
177 | rConn := conn.Remote()
178 | writePacket(t, pkg.NewConnect("client-id", false, 1, &pkg.Will{
179 | Topic: "some/will",
180 | Message: []byte("will message")}, nil), rConn)
181 |
182 | ca, ok := packet.Parse(t, rConn).(*pkg.ConnAck)
183 | utils.CheckEqual(pkg.RtAccepted, ca.ReturnCode(), t)
184 | utils.CheckTrue(ok, t)
185 | _ = conn.Close()
186 | <-done
187 |
188 | // At least one error should be logged (additional caused by forced disconnect)
189 | utils.CheckTrue(len(mt.logEntries) > 0, t)
190 | el := mt.logEntries[0]
191 | utils.CheckEqual(len(el), 3, t)
192 | utils.CheckEqual(el[0], "ERROR", t)
193 | utils.CheckTrue(cl == el[1], t)
194 | err, ok := el[2].(error)
195 | utils.CheckTrue(ok, t)
196 | utils.CheckEqual("unauthorized", err.Error(), t)
197 | }
198 |
// Test_client_debugLog checks that the client performs debug logging
func Test_client_debugLog(t *testing.T) {
	mt := &collectLogsT{}
	conn := mock.NewConnection()
	cl := NewClient(newMockServer(t), utils.NewLogger(logger.Debug, mt), conn)

	done := make(chan bool, 1)
	go func() {
		cl.Serve()
		done <- true
	}()

	rConn := conn.Remote()
	writePacket(t, pkg.NewConnect("client-id", false, 1, nil, nil), rConn)
	ca, ok := packet.Parse(t, rConn).(*pkg.ConnAck)
	utils.CheckTrue(ok, t)
	utils.CheckEqual(pkg.RtAccepted, ca.ReturnCode(), t)

	// Send one packet of each kind that the client only logs, then disconnect
	writePacket(t, pkg.PubRec(1), rConn)
	writePacket(t, pkg.PubRel(2), rConn)
	writePacket(t, pkg.PubComp(3), rConn)
	writePacket(t, pkg.DisconnectSingleton, rConn)
	<-done

	// check that all received packages were logged
	cnt := 0
	for _, le := range mt.logEntries {
		// A matching entry holds the level, the client, the text "received", and the packet
		if len(le) == 4 && le[0] == "DEBUG" && le[2] == "received" {
			switch le[3].(type) {
			case *pkg.Connect, pkg.PubRec, pkg.PubRel, pkg.PubComp:
				cnt++
			}
		}
	}
	utils.CheckEqual(4, cnt, t)
}
234 |
// writeFailure is a mock connection whose Write method starts to fail after a
// given number of successful writes.
type writeFailure struct {
	*mock.Connection
	succeed uint      // number of writes that remain before Write starts to fail
	tick    chan bool // receives a value after each successful write
}

// Write fails with the error "write failed" when no successful writes remain.
// Otherwise it delegates to the embedded connection, decrements the number of
// remaining successful writes, and sends a value on the tick channel.
func (c *writeFailure) Write(bs []byte) (int, error) {
	if c.succeed == 0 {
		return 0, errors.New("write failed")
	}
	i, err := c.Connection.Write(bs)
	c.succeed--
	c.tick <- true
	return i, err
}
250 |
// Test_write_failure_when_connected checks that the client propagates write error
func Test_write_failure_when_connected(t *testing.T) {
	// No successful writes are configured, so the very first write fails
	conn := &writeFailure{Connection: mock.NewConnection()}
	mt := &collectLogsT{}
	cl := NewClient(newMockServer(t), utils.NewLogger(logger.Error, mt), conn)

	done := make(chan bool, 1)
	go func() {
		cl.Serve()
		done <- true
	}()

	writePacket(t, pkg.NewConnect("client-id", false, 1, nil, nil), conn.Remote())
	// Should fail with forced disconnect
	select {
	case <-done:
	case <-time.After(time.Second):
		t.Fatal("expected forced disconnect did not occur")
	}

	// At least one error should be logged (additional caused by forced disconnect)
	utils.CheckTrue(len(mt.logEntries) > 0, t)
	el := mt.logEntries[0]
	utils.CheckEqual(len(el), 3, t)
	utils.CheckEqual(el[0], "ERROR", t)
	utils.CheckTrue(cl == el[1], t)
	err, ok := el[2].(error)
	utils.CheckTrue(ok, t)
	utils.CheckEqual("write failed", err.Error(), t)
}
281 |
// Test_write_failure_during_drain checks that the client logs error that occurs during writeLoop drain
func Test_write_failure_during_drain(t *testing.T) {
	// Exactly one write succeeds. The tick channel tells when that write has been made.
	conn := &writeFailure{Connection: mock.NewConnection(), succeed: 1, tick: make(chan bool, 1)}
	mt := &collectLogsT{}
	cl := NewClient(newMockServer(t), utils.NewLogger(logger.Error, mt), conn)

	done := make(chan bool, 1)
	go func() {
		cl.Serve()
		done <- true
	}()

	rConn := conn.Remote()
	writePacket(t, pkg.NewConnect("client-id", false, 1, nil, nil), rConn)
	// Wait for the one successful write
	<-conn.tick

	// Queue a packet followed by the disconnect marker so that the next write,
	// which fails, contains that packet and takes place after the marker was seen
	cl.(*client).queueForWrite(pkg.PingResponseSingleton)
	cl.(*client).queueForWrite(pkg.DisconnectSingleton)
	cl.SetDisconnected(nil)

	// Should fail with forced disconnect
	select {
	case <-done:
	case <-time.After(100 * time.Millisecond):
		t.Fatal("expected forced disconnect did not occur")
	}

	// At least one error should be logged (additional caused by forced disconnect)
	utils.CheckTrue(len(mt.logEntries) > 0, t)
	el := mt.logEntries[0]
	utils.CheckEqual(len(el), 3, t)
	utils.CheckEqual(el[0], "ERROR", t)
	utils.CheckTrue(cl == el[1], t)
	err, ok := el[2].(error)
	utils.CheckTrue(ok, t)
	utils.CheckEqual("write failed", err.Error(), t)
}
318 |
--------------------------------------------------------------------------------
/bridge/nats.go:
--------------------------------------------------------------------------------
1 | package bridge
2 |
3 | import (
4 | "errors"
5 |
6 | "github.com/nats-io/nats.go"
7 | "github.com/tada/mqtt-nats/mqtt"
8 | "github.com/tada/mqtt-nats/mqtt/pkg"
9 | )
10 |
11 | func (c *client) natsPublish(pp *pkg.Publish) error {
12 | var err error
13 | if pp.IsDup() {
14 | if c.session.AwaitsAck(pp.ID()) {
15 | // Already waiting for this one
16 | return nil
17 | }
18 | }
19 |
20 | natsSubject := mqtt.ToNATS(pp.TopicName())
21 | switch pp.QoSLevel() {
22 | case 0:
23 | // Fire and forget
24 | err = c.natsConn.Publish(natsSubject, pp.Payload())
25 | case 1:
26 | // use client id and packet id to form a reply subject
27 | replyTo := NewReplyTopic(c.session, pp).String()
28 | var sub *nats.Subscription
29 | sub, err = c.natsSubscribeAck(replyTo)
30 | if err == nil {
31 | c.session.AckRequested(pp.ID(), sub)
32 | err = c.natsConn.PublishRequest(natsSubject, replyTo, pp.Payload())
33 | }
34 | case 2:
35 | err = errors.New("QoS level 2 is not supported")
36 | default:
37 | err = errors.New("invalid QoS level")
38 | }
39 | return err
40 | }
41 |
42 | func (c *client) natsSubscribeAck(topic string) (*nats.Subscription, error) {
43 | return c.natsConn.Subscribe(topic, func(m *nats.Msg) {
44 | // Client may have disconnected at this point which is why it is essential to ask
45 | // the session manager for the session based on the replyTo subject
46 | mt := ParseReplyTopic(m.Subject)
47 | if mt != nil {
48 | if s := c.server.SessionManager().Get(mt.ClientID()); s != nil && s.ID() == mt.SessionID() {
49 | c.cancelNatsSubscriptions(s.AckReceived(mt.PacketID()))
50 | c.queueForWrite(pkg.PubAck(mt.PacketID()))
51 | }
52 | }
53 | })
54 | }
55 |
56 | func (c *client) cancelNatsSubscriptions(nss []*nats.Subscription) {
57 | for i := range nss {
58 | ns := nss[i]
59 | if err := ns.Unsubscribe(); err != nil {
60 | c.Error("NATS unsubscribe", ns.Subject, err)
61 | }
62 | }
63 | }
64 |
65 | func (c *client) natsSubscribe(sp *pkg.Subscribe) {
66 | tps := sp.Topics()
67 | nms := make([]string, len(tps))
68 | qss := make([]byte, len(tps))
69 | var nss []*nats.Subscription
70 | c.subLock.Lock()
71 | for i := range tps {
72 | tp := tps[i]
73 | nm := mqtt.ToNATSSubscription(tp.Name)
74 | nms[i] = nm
75 | qss[i] = tp.QoS
76 | if os := c.natsSubs[nm]; os != nil {
77 | delete(c.natsSubs, nm)
78 | nss = append(nss, os)
79 | }
80 | }
81 | c.subLock.Unlock()
82 | c.cancelNatsSubscriptions(nss)
83 |
84 | nss = make([]*nats.Subscription, 0, len(nms))
85 | for i := range nms {
86 | nm := nms[i]
87 | qs := qss[i]
88 | if qs > 1 {
89 | qs = 1
90 | qss[i] = 1
91 | }
92 | ns, err := c.natsConn.Subscribe(nm, func(m *nats.Msg) {
93 | c.natsResponse(qs, m)
94 | })
95 | if err == nil {
96 | nss = append(nss, ns)
97 | } else {
98 | c.Error("NATS subscribe", nm, err)
99 | qss[i] = 0x80
100 | }
101 | }
102 | c.subLock.Lock()
103 | for i := range nss {
104 | ns := nss[i]
105 | c.natsSubs[ns.Subject] = ns
106 | }
107 | c.subLock.Unlock()
108 | c.queueForWrite(pkg.NewSubAck(sp.ID(), qss...))
109 | }
110 |
111 | func (c *client) natsUnsubscribe(up *pkg.Unsubscribe) {
112 | tps := up.Topics()
113 | nss := make([]*nats.Subscription, 0, len(tps))
114 | c.subLock.Lock()
115 | for i := range tps {
116 | subj := mqtt.ToNATSSubscription(tps[i])
117 | if ns := c.natsSubs[subj]; ns != nil {
118 | nss = append(nss, ns)
119 | delete(c.natsSubs, subj)
120 | }
121 | }
122 | c.subLock.Unlock()
123 | c.cancelNatsSubscriptions(nss)
124 | c.queueForWrite(pkg.UnsubAck(up.ID()))
125 | }
126 |
127 | func (c *client) natsResponse(desiredQoS byte, m *nats.Msg) {
128 | id := uint16(0)
129 | flags := byte(0)
130 | if desiredQoS > 0 && m.Reply != `` {
131 | if mt := ParseReplyTopic(m.Reply); mt != nil {
132 | id = mt.PacketID()
133 | flags = mt.Flags()
134 | } else {
135 | id = c.server.NextFreePacketID()
136 | flags = 2 // QoS level 1
137 | }
138 | }
139 | pp := pkg.NewPublish(id, mqtt.FromNATS(m.Subject), flags, m.Data, false, m.Reply)
140 | qos := desiredQoS
141 | if pp.QoSLevel() < qos {
142 | qos = pp.QoSLevel()
143 | }
144 | c.PublishResponse(qos, pp)
145 | }
146 |
--------------------------------------------------------------------------------
/bridge/natspub.go:
--------------------------------------------------------------------------------
1 | package bridge
2 |
3 | import (
4 | "encoding/json"
5 | "io"
6 |
7 | "github.com/tada/catch/pio"
8 | "github.com/tada/jsonstream"
9 | "github.com/tada/mqtt-nats/mqtt/pkg"
10 | )
11 |
12 | // natsPub represents a message which originated from this server (such as a client will) that has been
13 | // published to NATS using some given credentials and now awaits a reply.
14 | type natsPub struct {
15 | // pp is the packet that was published
16 | pp *pkg.Publish
17 |
18 | // creds are the client credentials for the publication
19 | creds *pkg.Credentials
20 | }
21 |
22 | func (n *natsPub) MarshalToJSON(w io.Writer) {
23 | pio.WriteString(w, `{"m":`)
24 | n.pp.MarshalToJSON(w)
25 | if n.creds != nil {
26 | pio.WriteString(w, `,"c":`)
27 | n.creds.MarshalToJSON(w)
28 | }
29 | pio.WriteByte(w, '}')
30 | }
31 |
32 | func (n *natsPub) UnmarshalFromJSON(js jsonstream.Decoder, t json.Token) {
33 | jsonstream.AssertDelim(t, '{')
34 | for {
35 | k, ok := js.ReadStringOrEnd('}')
36 | if !ok {
37 | break
38 | }
39 | switch k {
40 | case "m":
41 | n.pp = &pkg.Publish{}
42 | if !js.ReadConsumer(n.pp) {
43 | n.pp = nil
44 | }
45 | case "c":
46 | n.creds = &pkg.Credentials{}
47 | if !js.ReadConsumer(n.creds) {
48 | n.creds = nil
49 | }
50 | }
51 | }
52 | }
53 |
--------------------------------------------------------------------------------
/bridge/options.go:
--------------------------------------------------------------------------------
1 | package bridge
2 |
3 | import (
4 | "crypto/tls"
5 |
6 | "github.com/nats-io/nats.go"
7 | )
8 |
9 | // Options contains all configuration options for the mqtt-nats bridge.
10 | type Options struct {
11 | // Path to file where the bridge is persisted. Can be empty if no persistence is desired
12 | StoragePath string
13 |
14 | // NATSUrls is a comma separated list of URLs used when connecting to NATS
15 | NATSUrls string
16 |
17 | // RetainedRequestTopic is a NATS topic that a NATS client can publish to after doing a subscribe
18 | // in order to retrieve any messages that are retained for that subscription. The payload must be
19 | // the verbatim NATS subscription. Retained messages that matches the subscription will be published
20 | // to the reply-to topic in the form of a JSON list of objects with a "subject" string and a "payload"
21 | // base64 encoded string
22 | RetainedRequestTopic string
23 |
24 | // Port is the MQTT port
25 | Port int
26 |
27 | // RepeatRate is the delay in milliseconds between publishing packets that originated in this server
28 | // that have QoS > 0 but hasn't been acknowledged.
29 | RepeatRate int
30 |
31 | // NATSOpts are options specific to the NATS connection
32 | NATSOpts []nats.Option
33 |
34 | TLSTimeout float64
35 | TLSCert string
36 | TLSKey string
37 | TLSCaCert string
38 | TLSConfig *tls.Config
39 | TLS bool
40 | TLSVerify bool
41 | TLSMap bool
42 |
43 | // Debug enables debug level log output
44 | Debug bool
45 | }
46 |
--------------------------------------------------------------------------------
/bridge/package.go:
--------------------------------------------------------------------------------
1 | // Package bridge contains the MQTT-NATS bridge server implementation
2 | package bridge
3 |
--------------------------------------------------------------------------------
/bridge/replytopic.go:
--------------------------------------------------------------------------------
1 | package bridge
2 |
3 | import (
4 | "fmt"
5 | "strconv"
6 | "strings"
7 |
8 | "github.com/tada/mqtt-nats/mqtt/pkg"
9 | )
10 |
// ReplyTopic is the decoded form of the NATS reply-topic used by the bridge to keep track of
// messages that are in need of an ACK.
type ReplyTopic struct {
	s string // session ID
	c string // client ID
	p uint16 // packet identifier
	f byte   // packet flags
}
19 |
20 | // NewReplyTopic creates a new ReplyTopic based on a pkg.Publish packet.
21 | func NewReplyTopic(s Session, pp *pkg.Publish) *ReplyTopic {
22 | return &ReplyTopic{c: s.ClientID(), s: s.ID(), p: pp.ID(), f: pp.Flags()}
23 | }
24 |
25 | // ParseReplyTopic creates a new ReplyTopic by parsing a NATS reply-to string
26 | func ParseReplyTopic(s string) *ReplyTopic {
27 | ps := strings.Split(s, ".")
28 | if len(ps) == 5 && ps[0] == "_INBOX" {
29 | p, err := strconv.Atoi(ps[3])
30 | if err == nil {
31 | var f int
32 | f, err = strconv.Atoi(ps[4])
33 | if err == nil {
34 | return &ReplyTopic{c: ps[1], s: ps[2], p: uint16(p), f: byte(f)}
35 | }
36 | }
37 | }
38 | return nil
39 | }
40 |
41 | // ClientID returns the ID of the client where the message originated
42 | func (r *ReplyTopic) ClientID() string {
43 | return r.c
44 | }
45 |
46 | // SessionID returns the ID of the client session within the mqtt-nats bridge
47 | func (r *ReplyTopic) SessionID() string {
48 | return r.s
49 | }
50 |
51 | // PacketID returns the packet identifier of the original packet
52 | func (r *ReplyTopic) PacketID() uint16 {
53 | return r.p
54 | }
55 |
56 | // Flags returns the packet flags of the original packet
57 | func (r *ReplyTopic) Flags() byte {
58 | return r.f
59 | }
60 |
61 | // String returns the NATS string form of the reply-topic
62 | func (r *ReplyTopic) String() string {
63 | return fmt.Sprintf("_INBOX.%s.%s.%d.%d", r.c, r.s, r.p, r.f)
64 | }
65 |
--------------------------------------------------------------------------------
/bridge/retained.go:
--------------------------------------------------------------------------------
1 | package bridge
2 |
3 | import (
4 | "encoding/json"
5 | "io"
6 | "regexp"
7 | "strings"
8 | "sync"
9 |
10 | "github.com/nats-io/nats.go"
11 | "github.com/tada/catch/pio"
12 | "github.com/tada/jsonstream"
13 | "github.com/tada/mqtt-nats/mqtt"
14 | "github.com/tada/mqtt-nats/mqtt/pkg"
15 | )
16 |
17 | type retained struct {
18 | lock sync.RWMutex
19 | msgs map[string]*pkg.Publish
20 | order []string
21 | }
22 |
23 | func (r *retained) Empty() bool {
24 | r.lock.RLock()
25 | empty := len(r.order) == 0
26 | r.lock.RUnlock()
27 | return empty
28 | }
29 |
30 | func (r *retained) MarshalJSON() ([]byte, error) {
31 | return jsonstream.Marshal(r)
32 | }
33 |
34 | func (r *retained) UnmarshalJSON(bs []byte) error {
35 | return jsonstream.Unmarshal(r, bs)
36 | }
37 |
38 | func (r *retained) MarshalToJSON(w io.Writer) {
39 | r.lock.RLock()
40 | defer r.lock.RUnlock()
41 | sep := byte('{')
42 | for _, t := range r.order {
43 | pio.WriteByte(w, sep)
44 | sep = byte(',')
45 | jsonstream.WriteString(w, t)
46 | pio.WriteByte(w, ':')
47 | r.msgs[t].MarshalToJSON(w)
48 | }
49 | if sep == '{' {
50 | pio.WriteByte(w, sep)
51 | }
52 | pio.WriteByte(w, '}')
53 | }
54 |
55 | func (r *retained) UnmarshalFromJSON(js jsonstream.Decoder, t json.Token) {
56 | jsonstream.AssertDelim(t, '{')
57 | r.msgs = make(map[string]*pkg.Publish)
58 | r.order = nil
59 | for {
60 | s, ok := js.ReadStringOrEnd('}')
61 | if !ok {
62 | break
63 | }
64 | p := &pkg.Publish{}
65 | js.ReadConsumer(p)
66 | r.msgs[s] = p
67 | r.order = append(r.order, s)
68 | }
69 | }
70 |
71 | func (r *retained) add(m *pkg.Publish) bool {
72 | r.lock.Lock()
73 | t := m.TopicName()
74 | _, present := r.msgs[t]
75 | r.msgs[t] = m
76 | if !present {
77 | r.order = append(r.order, t)
78 | }
79 | r.lock.Unlock()
80 | return !present
81 | }
82 |
83 | func (r *retained) drop(t string) bool {
84 | dropped := false
85 | r.lock.Lock()
86 | if _, present := r.msgs[t]; present {
87 | delete(r.msgs, t)
88 | o := r.order
89 | last := len(o) - 1
90 | for i := 0; i <= last; i++ {
91 | if o[i] == t {
92 | copy(o[i:], o[i+1:])
93 | o[last] = `` // allow GC of last
94 | r.order = o[:last]
95 | dropped = true
96 | break
97 | }
98 | }
99 | }
100 | r.lock.Unlock()
101 | return dropped
102 | }
103 |
104 | func (r *retained) messagesMatchingRetainRequest(m *nats.Msg) ([]*pkg.Publish, []byte) {
105 | natsTopics := strings.Split(string(m.Data), ",")
106 | topics := make([]pkg.Topic, len(natsTopics))
107 | for i := range natsTopics {
108 | topics[i] = pkg.Topic{Name: mqtt.FromNATSSubscription(natsTopics[i])}
109 | }
110 | return r.matchingMessages(topics)
111 | }
112 |
113 | func (r *retained) publishMatching(s *pkg.Subscribe, c Client) {
114 | pps, qs := r.matchingMessages(s.Topics())
115 | for i := range pps {
116 | pp := pps[i]
117 | c.PublishResponse(qs[i], pp)
118 | }
119 | }
120 |
121 | func (r *retained) matchingMessages(tps []pkg.Topic) ([]*pkg.Publish, []byte) {
122 | tpl := len(tps)
123 |
124 | // Create slices of subscription regexps and desired QoS. One entry for each subscription topic
125 | txs := make([]*regexp.Regexp, tpl)
126 | dqs := make([]byte, tpl)
127 | for i := 0; i < tpl; i++ {
128 | tp := tps[i]
129 | txs[i] = mqtt.SubscriptionToRegexp(tp.Name)
130 | dqs[i] = tp.QoS
131 | }
132 |
133 | // For each subscription topic, extract matching packets and desired QoS
134 | pps := make([]*pkg.Publish, 0)
135 | qs := make([]byte, 0)
136 |
137 | r.lock.RLock()
138 | for i := 0; i < tpl; i++ {
139 | tx := txs[i]
140 | dq := dqs[i]
141 | for _, t := range r.order {
142 | if tx.MatchString(t) {
143 | pps = append(pps, r.msgs[t])
144 | qs = append(qs, dq)
145 | }
146 | }
147 | }
148 | r.lock.RUnlock()
149 | return pps, qs
150 | }
151 |
--------------------------------------------------------------------------------
/bridge/session.go:
--------------------------------------------------------------------------------
1 | package bridge
2 |
3 | import (
4 | "encoding/json"
5 | "io"
6 | "strconv"
7 | "sync"
8 |
9 | "github.com/tada/catch"
10 |
11 | "github.com/nats-io/nats.go"
12 | "github.com/tada/catch/pio"
13 | "github.com/tada/jsonstream"
14 | "github.com/tada/mqtt-nats/mqtt/pkg"
15 | )
16 |
17 | // A Session contains data associated with a client ID. The session might survive client
18 | // connections.
19 | type Session interface {
20 | jsonstream.Consumer
21 | jsonstream.Streamer
22 |
23 | // ID returns an identifier that is unique for this session
24 | ID() string
25 |
26 | // ClientID returns the id of the client that this session belongs to
27 | ClientID() string
28 |
29 | // Destroy the session
30 | Destroy()
31 |
32 | // AckRequested remembers the given subscription which represents an awaited ACK
33 | // for the given packetID
34 | AckRequested(uint16, *nats.Subscription)
35 |
36 | // AwaitsAck returns true if a subscription associated with the given packet identifier
37 | // is currently waiting for an Ack.
38 | AwaitsAck(uint16) bool
39 |
40 | // AckReceived will delete pending ack subscription from the session and return them. It is up
41 | // to the caller to cancel the returned subscriptions.
42 | AckReceived(uint16) []*nats.Subscription
43 |
44 | // ClientAckRequested remembers the id of a packet which has been sent to the client. The packet stems from a NATS
45 | // subscription with QoS level > 0 and it is now expected that the client sends an PubACK back to which can be
46 | // propagated to the reply-to address.
47 | ClientAckRequested(*pkg.Publish)
48 |
49 | // ClientAckReceived will close a pending response ack subscription and forward the ACK to the
50 | // replyTo subject. It returns whether or not such an ack was pending
51 | ClientAckReceived(uint16, *nats.Conn) bool
52 |
53 | // Resend all messages that the client hasn't acknowledged
54 | ResendClientUnack(c *client)
55 |
56 | // RestoreAckSubscriptions called when a client restores an old session. THe method restores subscriptions that
57 | // were peristed and then loaded again.
58 | RestoreAckSubscriptions(c *client)
59 | }
60 |
61 | type session struct {
62 | id string
63 | clientID string
64 | prelAwaitsAck map[uint16]string
65 | awaitsAck map[uint16]*nats.Subscription // awaits ack on reply-to to be propagated to client
66 | awaitsClientAck map[uint16]*pkg.Publish // awaits ack from client to be propagated to nats
67 | awaitsAckLock sync.RWMutex
68 | }
69 |
70 | func (s *session) MarshalJSON() ([]byte, error) {
71 | return jsonstream.Marshal(s)
72 | }
73 |
74 | func (s *session) MarshalToJSON(w io.Writer) {
75 | pio.WriteString(w, `{"id":`)
76 | jsonstream.WriteString(w, s.id)
77 | pio.WriteString(w, `,"cid":`)
78 | jsonstream.WriteString(w, s.clientID)
79 | s.awaitsAckLock.RLock()
80 | if len(s.awaitsAck) > 0 {
81 | pio.WriteString(w, `,"awAck":`)
82 | sep := byte('{')
83 | for k, v := range s.awaitsAck {
84 | pio.WriteByte(w, sep)
85 | sep = byte(',')
86 | pio.WriteByte(w, '"')
87 | pio.WriteInt(w, int64(k))
88 | pio.WriteString(w, `":`)
89 | jsonstream.WriteString(w, v.Subject)
90 | }
91 | pio.WriteByte(w, '}')
92 | }
93 | if len(s.awaitsClientAck) > 0 {
94 | pio.WriteString(w, `,"awClientAck":`)
95 | sep := byte('{')
96 | for k, v := range s.awaitsClientAck {
97 | pio.WriteByte(w, sep)
98 | sep = byte(',')
99 | pio.WriteByte(w, '"')
100 | pio.WriteInt(w, int64(k))
101 | pio.WriteString(w, `":`)
102 | v.MarshalToJSON(w)
103 | }
104 | pio.WriteByte(w, '}')
105 | }
106 | pio.WriteByte(w, '}')
107 | s.awaitsAckLock.RUnlock()
108 | }
109 |
110 | func (s *session) UnmarshalFromJSON(js jsonstream.Decoder, t json.Token) {
111 | jsonstream.AssertDelim(t, '{')
112 | for {
113 | k, ok := js.ReadStringOrEnd('}')
114 | if !ok {
115 | break
116 | }
117 | switch k {
118 | case "id":
119 | s.id = js.ReadString()
120 | case "cid":
121 | s.clientID = js.ReadString()
122 | case "awAck":
123 | js.ReadDelim('{')
124 | for {
125 | k, ok = js.ReadStringOrEnd('}')
126 | if !ok {
127 | break
128 | }
129 | if s.prelAwaitsAck == nil {
130 | s.prelAwaitsAck = make(map[uint16]string)
131 | }
132 | i, err := strconv.Atoi(k)
133 | if err != nil {
134 | panic(catch.Error(err))
135 | }
136 | s.prelAwaitsAck[uint16(i)] = js.ReadString()
137 | }
138 | case "awClientAck":
139 | js.ReadDelim('{')
140 | for {
141 | k, ok = js.ReadStringOrEnd('}')
142 | if !ok {
143 | break
144 | }
145 | if s.awaitsClientAck == nil {
146 | s.awaitsClientAck = make(map[uint16]*pkg.Publish)
147 | }
148 | i, err := strconv.Atoi(k)
149 | if err != nil {
150 | panic(catch.Error(err))
151 | }
152 | pp := &pkg.Publish{}
153 | if js.ReadConsumer(pp) {
154 | s.awaitsClientAck[uint16(i)] = pp
155 | }
156 | }
157 | }
158 | }
159 | }
160 |
161 | func (s *session) RestoreAckSubscriptions(c *client) {
162 | if s.prelAwaitsAck != nil {
163 | for k, v := range s.prelAwaitsAck {
164 | sb, err := c.natsSubscribeAck(v)
165 | if err != nil {
166 | c.Error(err)
167 | } else {
168 | s.AckRequested(k, sb)
169 | }
170 | }
171 | s.prelAwaitsAck = nil
172 | }
173 | }
174 |
175 | func (s *session) AckReceived(packetID uint16) []*nats.Subscription {
176 | var nss []*nats.Subscription
177 | s.awaitsAckLock.Lock()
178 | if s.awaitsAck != nil {
179 | if sb, awaits := s.awaitsAck[packetID]; awaits {
180 | nss = append(nss, sb)
181 | delete(s.awaitsAck, packetID)
182 | }
183 | }
184 | s.awaitsAckLock.Unlock()
185 | return nss
186 | }
187 |
188 | func (s *session) AckRequested(packetID uint16, sb *nats.Subscription) {
189 | s.awaitsAckLock.Lock()
190 | if s.awaitsAck == nil {
191 | s.awaitsAck = make(map[uint16]*nats.Subscription)
192 | }
193 | s.awaitsAck[packetID] = sb
194 | s.awaitsAckLock.Unlock()
195 | }
196 |
197 | func (s *session) AwaitsAck(packetID uint16) bool {
198 | awaits := false
199 | s.awaitsAckLock.RLock()
200 | if s.awaitsAck != nil {
201 | _, awaits = s.awaitsAck[packetID]
202 | }
203 | s.awaitsAckLock.RUnlock()
204 | return awaits
205 | }
206 |
207 | func (s *session) ClientAckReceived(packetID uint16, c *nats.Conn) bool {
208 | var pp *pkg.Publish
209 | s.awaitsAckLock.Lock()
210 | if s.awaitsClientAck != nil {
211 | var found bool
212 | if pp, found = s.awaitsClientAck[packetID]; found {
213 | delete(s.awaitsClientAck, packetID)
214 | }
215 | }
216 | s.awaitsAckLock.Unlock()
217 | if pp != nil {
218 | _ = c.Publish(pp.NatsReplyTo(), []byte{0})
219 | return true
220 | }
221 | return false
222 | }
223 |
224 | func (s *session) ClientAckRequested(pp *pkg.Publish) {
225 | s.awaitsAckLock.Lock()
226 | if s.awaitsClientAck == nil {
227 | s.awaitsClientAck = make(map[uint16]*pkg.Publish)
228 | }
229 | s.awaitsClientAck[pp.ID()] = pp
230 | s.awaitsAckLock.Unlock()
231 | }
232 |
233 | func (s *session) ResendClientUnack(c *client) {
234 | s.awaitsAckLock.RLock()
235 | as := make([]*pkg.Publish, 0, len(s.awaitsClientAck))
236 | for _, a := range s.awaitsClientAck {
237 | as = append(as, a)
238 | }
239 | s.awaitsAckLock.RUnlock()
240 | for i := range as {
241 | a := as[i]
242 | c.PublishResponse(a.QoSLevel(), a)
243 | }
244 | }
245 |
246 | func (s *session) ID() string {
247 | return s.id
248 | }
249 |
250 | func (s *session) ClientID() string {
251 | return s.clientID
252 | }
253 |
254 | func (s *session) Destroy() {
255 | // Unsubscribe all pending subscriptions
256 | s.awaitsAckLock.Lock()
257 | if s.awaitsAck != nil {
258 | for _, sb := range s.awaitsAck {
259 | _ = sb.Unsubscribe()
260 | }
261 | s.awaitsAck = nil
262 | }
263 | s.awaitsAckLock.Unlock()
264 | }
265 |
266 | // A SessionManager manages sessions.
267 | type SessionManager interface {
268 | // Create creates a new session for the given clientID. Any previous session registered for
269 | // the given id is discarded
270 | Create(clientID string) Session
271 |
272 | // Get returns an existing session for the given clientID or nil if no such session exists
273 | Get(clientID string) Session
274 |
275 | // Remove removes any session for the given clientID
276 | Remove(clientID string)
277 | }
278 |
279 | type sm struct {
280 | lock sync.RWMutex
281 | seed uint32
282 | m map[string]Session
283 | }
284 |
285 | func (m *sm) Get(clientID string) Session {
286 | var s Session
287 | m.lock.RLock()
288 | s = m.m[clientID]
289 | m.lock.RUnlock()
290 | return s
291 | }
292 |
293 | func (m *sm) Create(clientID string) Session {
294 | m.lock.Lock()
295 | m.seed++
296 | s := &session{id: `s` + strconv.Itoa(int(m.seed)), clientID: clientID}
297 | m.m[clientID] = s
298 | m.lock.Unlock()
299 | return s
300 | }
301 |
302 | func (m *sm) MarshalToJSON(w io.Writer) {
303 | m.lock.RLock()
304 | defer m.lock.RUnlock()
305 |
306 | pio.WriteString(w, `{"seed":`)
307 | pio.WriteInt(w, int64(m.seed))
308 | if len(m.m) > 0 {
309 | pio.WriteString(w, `,"sessions":`)
310 | sep := byte('{')
311 | for k, v := range m.m {
312 | pio.WriteByte(w, sep)
313 | sep = ','
314 | jsonstream.WriteString(w, k)
315 | pio.WriteByte(w, ':')
316 | v.MarshalToJSON(w)
317 | }
318 | pio.WriteByte(w, '}')
319 | }
320 | pio.WriteByte(w, '}')
321 | }
322 |
323 | func (m *sm) UnmarshalFromJSON(js jsonstream.Decoder, t json.Token) {
324 | jsonstream.AssertDelim(t, '{')
325 | for {
326 | k, ok := js.ReadStringOrEnd('}')
327 | if !ok {
328 | break
329 | }
330 | switch k {
331 | case "sessions":
332 | js.ReadDelim('{')
333 | for {
334 | k, ok = js.ReadStringOrEnd('}')
335 | if !ok {
336 | break
337 | }
338 | s := &session{}
339 | js.ReadConsumer(s)
340 | m.m[k] = s
341 | }
342 | case "seed":
343 | m.seed = uint32(js.ReadInt())
344 | }
345 | }
346 | }
347 |
348 | func (m *sm) Remove(clientID string) {
349 | var s Session
350 | m.lock.Lock()
351 | s = m.m[clientID]
352 | delete(m.m, clientID)
353 | m.lock.Unlock()
354 | if s != nil {
355 | s.Destroy()
356 | }
357 | }
358 |
--------------------------------------------------------------------------------
/cli/mqtt-nats.go:
--------------------------------------------------------------------------------
1 | // +build !citest
2 |
3 | package cli
4 |
5 | import (
6 | "flag"
7 | "io"
8 |
9 | "github.com/nats-io/nats.go"
10 | "github.com/tada/mqtt-nats/bridge"
11 | "github.com/tada/mqtt-nats/logger"
12 | )
13 |
14 | // Bridge parses the command line arguments of args into an bridge.Options instance and then starts the
15 | // bridge with those options.
16 | func Bridge(args []string, stdout, stderr io.Writer) int {
17 | fs := flag.NewFlagSet(args[0], flag.ExitOnError)
18 | fs.SetOutput(stderr)
19 |
20 | var (
21 | printHelp bool
22 | natsClientCert string
23 | natsClientKey string
24 | natsRootCAs string
25 | natsCredsFile string
26 | )
27 | opts := &bridge.Options{}
28 | fs.StringVar(&opts.NATSUrls, "natsurl", nats.DefaultURL, "NATS server URLs separated by comma")
29 | fs.IntVar(&opts.Port, "port", 0, "MQTT Port to listen on (defaults to 1883 or 8883 with TLS)")
30 | fs.BoolVar(&printHelp, "h", false, "")
31 | fs.BoolVar(&printHelp, "help", false, "Print this help")
32 | fs.IntVar(&opts.RepeatRate, "repeatrate", 5000, "time in milliseconds between each publish of unacknowledged messages")
33 | // persistence
34 | fs.StringVar(&opts.StoragePath, "storage", "mqtt-nats.json", "path to json file where server state is persisted")
35 |
36 | fs.BoolVar(&opts.Debug, "D", false, "Enable Debug logging")
37 | fs.BoolVar(&opts.Debug, "debug", false, "Enable Debug logging")
38 |
39 | // tls
40 | fs.BoolVar(&opts.TLS, "tls", false, "Enable TLS. If true, the -tlscert and -tlskey options are mandatory")
41 | fs.StringVar(&opts.TLSCert, "tlscert", "", "Server certificate file")
42 | fs.StringVar(&opts.TLSKey, "tlskey", "", "Private key for server certificate")
43 |
44 | // options to verify client certificate
45 | fs.BoolVar(&opts.TLSVerify, "tlsverify", false,
46 | "Enable verification of client TLS certificate. If true, the -tlscacert option is mandatory")
47 | fs.StringVar(&opts.TLSCaCert, "tlscacert", "", "Root Certificate for verification of client TLS certificate")
48 |
49 | fs.StringVar(&natsCredsFile, "nats-creds", "", "User Credentials File used when bridge connects to NATS")
50 | // tls when connecting to the NATS server
51 | fs.StringVar(&natsClientKey, "nats-key", "", "Public Key used by the bridge when connecting to NATS")
52 | fs.StringVar(&natsClientCert, "nats-cert", "", "Client Certificate used by the bridge when connecting to NATS")
53 | fs.StringVar(&natsRootCAs, "nats-cacert", "", "Client Root Certificate used by the bridge when connecting to NATS")
54 |
55 | _ = fs.Parse(args[1:])
56 | if printHelp {
57 | fs.SetOutput(stdout)
58 | fs.PrintDefaults()
59 | return 0
60 | }
61 |
62 | if opts.TLS {
63 | if opts.TLSCert == "" || opts.TLSKey == "" {
64 | _, _ = io.WriteString(stderr, "both -tlscert and -tlskey must be given when tls is enabled")
65 | return 2
66 | }
67 |
68 | if opts.Port == 0 {
69 | opts.Port = 8883
70 | }
71 | } else if opts.Port == 0 {
72 | opts.Port = 1883
73 | }
74 |
75 | opts.NATSOpts = []nats.Option{nats.Name("MQTT Bridge")}
76 | if natsCredsFile != "" {
77 | opts.NATSOpts = append(opts.NATSOpts, nats.UserCredentials(natsCredsFile))
78 | }
79 | if natsClientCert != "" || natsClientKey != "" {
80 | if natsClientCert == "" || natsClientKey == "" {
81 | _, _ = io.WriteString(stderr, "both -nats-cert and -nats-key must be given to enable client verification")
82 | return 2
83 | }
84 | opts.NATSOpts = append(opts.NATSOpts, nats.ClientCert(natsClientCert, natsClientKey))
85 | }
86 | if natsRootCAs != `` {
87 | opts.NATSOpts = append(opts.NATSOpts, nats.RootCAs(natsRootCAs))
88 | }
89 |
90 | lg := logger.New(logger.Debug, stdout, stderr)
91 | s, err := bridge.New(opts, lg)
92 | if err == nil {
93 | err = s.Serve(nil)
94 | }
95 |
96 | if err != nil {
97 | lg.Error(err)
98 | return 1
99 | }
100 | return 0
101 | }
102 |
--------------------------------------------------------------------------------
/cli/package.go:
--------------------------------------------------------------------------------
1 | // Package cli contains the Command Line Interface
2 | package cli
3 |
--------------------------------------------------------------------------------
/examples/certs/.gitignore:
--------------------------------------------------------------------------------
1 | *.pem
2 | *.csr
3 |
--------------------------------------------------------------------------------
/examples/certs/ca.json:
--------------------------------------------------------------------------------
1 | {
2 | "CN": "tada.se",
3 | "key": {
4 | "algo": "rsa",
5 | "size": 2048
6 | },
7 | "names": [
8 | {
9 | "C": "SE",
10 | "L": "Täby",
11 | "O": "Tada AB",
12 | "ST": "Stockholms Län"
13 | }
14 | ]
15 | }
--------------------------------------------------------------------------------
/examples/certs/config.json:
--------------------------------------------------------------------------------
1 | {
2 | "signing": {
3 | "default": {
4 | "expiry": "43800h"
5 | },
6 | "profiles": {
7 | "server": {
8 | "expiry": "43800h",
9 | "usages": [
10 | "signing",
11 | "digital signing",
12 | "key encipherment",
13 | "server auth"
14 | ]
15 | }
16 | }
17 | }
18 | }
--------------------------------------------------------------------------------
/examples/certs/csr.json:
--------------------------------------------------------------------------------
1 | {
2 | "hosts": [
3 | "127.0.0.1",
4 | "localhost",
5 | "tada",
6 | "home.tada.se"
7 | ],
8 | "CN": "mqtt-nats.tada.se",
9 | "key": {
10 | "algo": "rsa",
11 | "size": 2048
12 | },
13 | "names": [
14 | {
15 | "C": "SE",
16 | "L": "Täby",
17 | "O": "Tada AB",
18 | "ST": "Stockholms Län"
19 | }
20 | ]
21 | }
22 |
--------------------------------------------------------------------------------
/examples/certs/generate.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash
# Generates the example root CA, server certificate, and client certificate + key.
# Requires cfssl and cfssljson on the PATH. Run from the directory that contains the
# ca.json, config.json, server.json, and csr.json files.

# Stop at the first failing command (also inside a pipeline) so that a failed CA generation
# doesn't cascade into confusing errors from the signing steps that depend on its output.
set -euo pipefail

# Generate root CA
cfssl gencert -initca ca.json | cfssljson -bare ca

# Generate server cert
cfssl gencert -ca=ca.pem -ca-key=ca-key.pem -config=config.json -profile=server server.json | cfssljson -bare server

# Generate client cert + key
cfssl gencert -ca ca.pem -ca-key ca-key.pem csr.json | cfssljson -bare client
--------------------------------------------------------------------------------
/examples/certs/server.json:
--------------------------------------------------------------------------------
1 | {
2 | "CN": "Server",
3 | "hosts": [
4 | "127.0.0.1",
5 | "localhost",
6 | "tada",
7 | "home.tada.se"
8 | ]
9 | }
10 |
--------------------------------------------------------------------------------
/examples/jmeter/MQTT Pub Sampler.jmx:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 | false
7 | true
8 | false
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 | continue
17 |
18 | false
19 | 10
20 |
21 | 1
22 | 1
23 | false
24 |
25 |
26 | true
27 |
28 |
29 |
30 | 127.0.0.1
31 | 1883
32 | 3.1.1
33 | 10
34 | TCP
35 | false
36 |
37 |
38 |
39 |
40 |
41 |
42 | conn_
43 | true
44 | 300
45 | 0
46 | 0
47 |
48 |
49 |
50 | my/test/topic
51 | 1
52 | true
53 | String
54 | 1024
55 | This is my topic
56 |
57 |
58 |
59 |
60 |
61 |
62 |
--------------------------------------------------------------------------------
/examples/server.conf:
--------------------------------------------------------------------------------
1 | listen: 0.0.0.0:4222
2 | tls: {
3 | ca_file: "./certs/ca.pem"
4 | cert_file: "./certs/server.pem"
5 | key_file: "./certs/server-key.pem"
6 | verify: true
7 | }
8 |
--------------------------------------------------------------------------------
/examples/tools/nats-pub-repeat/nats-pub-repeat.go:
--------------------------------------------------------------------------------
1 | // +build !citest
2 |
3 | // Copyright 2012-2019 The NATS Authors
4 | // Licensed under the Apache License, Version 2.0 (the "License");
5 | // you may not use this file except in compliance with the License.
6 | // You may obtain a copy of the License at
7 | //
8 | // http://www.apache.org/licenses/LICENSE-2.0
9 | //
10 | // Unless required by applicable law or agreed to in writing, software
11 | // distributed under the License is distributed on an "AS IS" BASIS,
12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | // See the License for the specific language governing permissions and
14 | // limitations under the License.
15 |
16 | package main
17 |
18 | import (
19 | "bytes"
20 | "flag"
21 | "log"
22 | "os"
23 | "strconv"
24 | "time"
25 |
26 | "github.com/nats-io/nats.go"
27 | )
28 |
29 | // NOTE: Can test with demo servers.
30 | // nats-pub-repeat -s demo.nats.io
31 | // nats-pub-repeat -s demo.nats.io:4443 (TLS version)
32 |
// showUsageAndExit prints the defaults of all defined flags and then
// terminates the process with the given exit code.
func showUsageAndExit(exitcode int) {
	flag.PrintDefaults()
	os.Exit(exitcode)
}
37 |
38 | func main() {
39 | var urls = flag.String("s", nats.DefaultURL, "The nats server URLs (separated by comma)")
40 | var userCreds = flag.String("creds", "", "User Credentials File")
41 | var repeat = flag.Int("repeat", 1, "Repeat count")
42 | var askReply = flag.Int("askreply", 0, "seconds to wait for reply. 0 means don't wait")
43 | var showHelp = flag.Bool("h", false, "Show help message")
44 | var rootCA = flag.String("cacert", "", "TLS root certificate")
45 | var clientKey = flag.String("key", "", "TLS key")
46 | var clientCert = flag.String("cert", "", "TLS certificate")
47 |
48 | log.SetFlags(0)
49 | flag.Parse()
50 |
51 | if *showHelp {
52 | showUsageAndExit(0)
53 | }
54 |
55 | args := flag.Args()
56 | if len(args) != 2 {
57 | showUsageAndExit(1)
58 | }
59 |
60 | // Connect Options.
61 | opts := []nats.Option{nats.Name("NATS Sample Publisher")}
62 | if *rootCA != "" {
63 | opts = append(opts, nats.RootCAs(*rootCA))
64 | }
65 | if *clientCert != "" {
66 | opts = append(opts, nats.ClientCert(*clientCert, *clientKey))
67 | }
68 |
69 | // Use UserCredentials
70 | if *userCreds != "" {
71 | opts = append(opts, nats.UserCredentials(*userCreds))
72 | }
73 |
74 | // Connect to NATS
75 | nc, err := nats.Connect(*urls, opts...)
76 | if err != nil {
77 | log.Fatal(err)
78 | }
79 | defer nc.Close()
80 | subj, msg := args[0], args[1]
81 |
82 | buf := bytes.Buffer{}
83 | for i := 0; i < *repeat; i++ {
84 | buf.Reset()
85 | buf.WriteString(msg)
86 | buf.WriteByte(' ')
87 | buf.WriteString(strconv.Itoa(i))
88 | if *askReply != 0 {
89 | nc.Request(subj, buf.Bytes(), time.Duration(*askReply)*time.Second)
90 | } else {
91 | nc.Publish(subj, buf.Bytes())
92 | }
93 | }
94 | nc.Flush()
95 |
96 | if err := nc.LastError(); err != nil {
97 | log.Fatal(err)
98 | } else {
99 | log.Printf("Published [%s] : '%s'\n", subj, msg)
100 | }
101 | }
102 |
--------------------------------------------------------------------------------
/examples/tools/nats-sub-reply/nats-sub-reply.go:
--------------------------------------------------------------------------------
1 | // +build !citest
2 |
3 | // Copyright 2012-2019 The NATS Authors
4 | // Licensed under the Apache License, Version 2.0 (the "License");
5 | // you may not use this file except in compliance with the License.
6 | // You may obtain a copy of the License at
7 | //
8 | // http://www.apache.org/licenses/LICENSE-2.0
9 | //
10 | // Unless required by applicable law or agreed to in writing, software
11 | // distributed under the License is distributed on an "AS IS" BASIS,
12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | // See the License for the specific language governing permissions and
14 | // limitations under the License.
15 |
16 | package main
17 |
18 | import (
19 | "flag"
20 | "log"
21 | "os"
22 | "runtime"
23 | "time"
24 |
25 | "github.com/nats-io/nats.go"
26 | )
27 |
28 | // NOTE: Can test with demo servers.
29 | // nats-sub-reply -s demo.nats.io
30 | // nats-sub-reply -s demo.nats.io:4443 (TLS version)
31 |
// usage logs a one-line synopsis of the command followed by the defaults of
// all defined flags. main installs it as flag.Usage.
func usage() {
	log.Printf("Usage: nats-sub-reply [-s server] [-creds file] [-t] \n")
	flag.PrintDefaults()
}
36 |
// showUsageAndExit prints the usage text and then terminates the process
// with the given exit code.
func showUsageAndExit(exitcode int) {
	usage()
	os.Exit(exitcode)
}
41 |
42 | func printMsg(m *nats.Msg, i int) {
43 | log.Printf("[#%d] Received on [%s]: '%s'", i, m.Subject, string(m.Data))
44 | }
45 |
46 | func main() {
47 | var urls = flag.String("s", nats.DefaultURL, "The nats server URLs (separated by comma)")
48 | var userCreds = flag.String("creds", "", "User Credentials File")
49 | var showTime = flag.Bool("t", false, "Display timestamps")
50 | var showHelp = flag.Bool("h", false, "Show help message")
51 | var rootCA = flag.String("cacert", "", "TLS root certificate")
52 | var clientKey = flag.String("key", "", "TLS key")
53 | var clientCert = flag.String("cert", "", "TLS certificate")
54 |
55 | log.SetFlags(0)
56 | flag.Usage = usage
57 | flag.Parse()
58 |
59 | if *showHelp {
60 | showUsageAndExit(0)
61 | }
62 |
63 | args := flag.Args()
64 | if len(args) != 1 {
65 | showUsageAndExit(1)
66 | }
67 |
68 | // Connect Options.
69 | opts := []nats.Option{nats.Name("NATS Sample Subscriber")}
70 | opts = setupConnOptions(opts)
71 | if *rootCA != "" {
72 | opts = append(opts, nats.RootCAs(*rootCA))
73 | }
74 | if *clientCert != "" {
75 | opts = append(opts, nats.ClientCert(*clientCert, *clientKey))
76 | }
77 |
78 | // Use UserCredentials
79 | if *userCreds != "" {
80 | opts = append(opts, nats.UserCredentials(*userCreds))
81 | }
82 |
83 | // Use UserCredentials
84 | if *userCreds != "" {
85 | opts = append(opts, nats.UserCredentials(*userCreds))
86 | }
87 |
88 | // Connect to NATS
89 | nc, err := nats.Connect(*urls, opts...)
90 | if err != nil {
91 | log.Fatal(err)
92 | }
93 |
94 | subj, i := args[0], 0
95 |
96 | nc.Subscribe(subj, func(msg *nats.Msg) {
97 | i += 1
98 | printMsg(msg, i)
99 | if msg.Reply != `` {
100 | log.Printf("Replying to %s", msg.Reply)
101 | msg.Respond([]byte{0})
102 | }
103 | })
104 | nc.Flush()
105 |
106 | if err := nc.LastError(); err != nil {
107 | log.Fatal(err)
108 | }
109 |
110 | log.Printf("Listening on [%s]", subj)
111 | if *showTime {
112 | log.SetFlags(log.LstdFlags)
113 | }
114 |
115 | runtime.Goexit()
116 | }
117 |
118 | func setupConnOptions(opts []nats.Option) []nats.Option {
119 | totalWait := 10 * time.Minute
120 | reconnectDelay := time.Second
121 |
122 | opts = append(opts, nats.ReconnectWait(reconnectDelay))
123 | opts = append(opts, nats.MaxReconnects(int(totalWait/reconnectDelay)))
124 | opts = append(opts, nats.DisconnectHandler(func(nc *nats.Conn) {
125 | log.Printf("Disconnected: will attempt reconnects for %.0fm", totalWait.Minutes())
126 | }))
127 | opts = append(opts, nats.ReconnectHandler(func(nc *nats.Conn) {
128 | log.Printf("Reconnected [%s]", nc.ConnectedUrl())
129 | }))
130 | opts = append(opts, nats.ClosedHandler(func(nc *nats.Conn) {
131 | log.Fatalf("Exiting: %v", nc.LastError())
132 | }))
133 | return opts
134 | }
135 |
--------------------------------------------------------------------------------
/go.mod:
--------------------------------------------------------------------------------
1 | module github.com/tada/mqtt-nats
2 |
3 | go 1.14
4 |
5 | require (
6 | github.com/golang/protobuf v1.3.5 // indirect
7 | github.com/nats-io/nats-server/v2 v2.1.4
8 | github.com/nats-io/nats.go v1.9.1
9 | github.com/nats-io/nuid v1.0.1
10 | github.com/tada/catch v0.0.0-20200501140707-b8b11d55b4e6
11 | github.com/tada/jsonstream v0.0.0-20200501141504-4d34829515db
12 | golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550 // indirect
13 | )
14 |
--------------------------------------------------------------------------------
/go.sum:
--------------------------------------------------------------------------------
1 | github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
2 | github.com/golang/protobuf v1.3.5 h1:F768QJ1E9tib+q5Sc8MkdJi1RxLTbRcTf8LJV56aRls=
3 | github.com/golang/protobuf v1.3.5/go.mod h1:6O5/vntMXwX2lRkT1hjjk0nAC1IDOTvTlVgjlRvqsdk=
4 | github.com/nats-io/jwt v0.3.0 h1:xdnzwFETV++jNc4W1mw//qFyJGb2ABOombmZJQS4+Qo=
5 | github.com/nats-io/jwt v0.3.0/go.mod h1:fRYCDE99xlTsqUzISS1Bi75UBJ6ljOJQOAAu5VglpSg=
6 | github.com/nats-io/jwt v0.3.2 h1:+RB5hMpXUUA2dfxuhBTEkMOrYmM+gKIZYS1KjSostMI=
7 | github.com/nats-io/jwt v0.3.2/go.mod h1:/euKqTS1ZD+zzjYrY7pseZrTtWQSjujC7xjPc8wL6eU=
8 | github.com/nats-io/nats-server/v2 v2.1.4 h1:BILRnsJ2Yb/fefiFbBWADpViGF69uh4sxe8poVDQ06g=
9 | github.com/nats-io/nats-server/v2 v2.1.4/go.mod h1:Jw1Z28soD/QasIA2uWjXyM9El1jly3YwyFOuR8tH1rg=
10 | github.com/nats-io/nats.go v1.9.1 h1:ik3HbLhZ0YABLto7iX80pZLPw/6dx3T+++MZJwLnMrQ=
11 | github.com/nats-io/nats.go v1.9.1/go.mod h1:ZjDU1L/7fJ09jvUSRVBR2e7+RnLiiIQyqyzEE/Zbp4w=
12 | github.com/nats-io/nkeys v0.1.0 h1:qMd4+pRHgdr1nAClu+2h/2a5F2TmKcCzjCDazVgRoX4=
13 | github.com/nats-io/nkeys v0.1.0/go.mod h1:xpnFELMwJABBLVhffcfd1MZx6VsNRFpEugbxziKVo7w=
14 | github.com/nats-io/nkeys v0.1.3 h1:6JrEfig+HzTH85yxzhSVbjHRJv9cn0p6n3IngIcM5/k=
15 | github.com/nats-io/nkeys v0.1.3/go.mod h1:xpnFELMwJABBLVhffcfd1MZx6VsNRFpEugbxziKVo7w=
16 | github.com/nats-io/nuid v1.0.1 h1:5iA8DT8V7q8WK2EScv2padNa/rTESc1KdnPw4TC2paw=
17 | github.com/nats-io/nuid v1.0.1/go.mod h1:19wcPz3Ph3q0Jbyiqsd0kePYG7A95tJPxeL+1OSON2c=
18 | github.com/tada/catch v0.0.0-20200501140707-b8b11d55b4e6 h1:FOmtz4bkMV7ArdaFX9Ev8Mw1frMpw4XfTj8sAb4XprE=
19 | github.com/tada/catch v0.0.0-20200501140707-b8b11d55b4e6/go.mod h1:mL60x4NqUvoa7GzNDLlmi6IyIX8eRqQzkgcD2cs8dWM=
20 | github.com/tada/jsonstream v0.0.0-20200501141504-4d34829515db h1:VzNg3u3uJj5rNoop6te/tK4u/FPh2ytFv7IdbGBTlIE=
21 | github.com/tada/jsonstream v0.0.0-20200501141504-4d34829515db/go.mod h1:MzoAgsR5aJ/WBHQG8XdOmC6EP1Rpk1DpI+li6mgdDf0=
22 | golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
23 | golang.org/x/crypto v0.0.0-20190701094942-4def268fd1a4 h1:HuIa8hRrWRSrqYzx1qI49NNxhdi2PrY7gxVSq1JjLDc=
24 | golang.org/x/crypto v0.0.0-20190701094942-4def268fd1a4/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
25 | golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550 h1:ObdrDkeb4kJdCP557AjRjq69pTHfNouLtWZG7j9rPN8=
26 | golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
27 | golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
28 | golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
29 | golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
30 | golang.org/x/sys v0.0.0-20190726091711-fc99dfbffb4e h1:D5TXcfTk7xF7hvieo4QErS3qqCB4teTffacDWr7CI+0=
31 | golang.org/x/sys v0.0.0-20190726091711-fc99dfbffb4e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
32 | golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
33 |
--------------------------------------------------------------------------------
/logger/logger.go:
--------------------------------------------------------------------------------
1 | // Package logger contains a logger interface and an implementation that is based on the standard log.Logger
2 | package logger
3 |
4 | import (
5 | "io"
6 | "log"
7 | )
8 |
// Logger is a levelled logger. Every level has an emitter and a predicate
// that tells whether output at that level is enabled.
type Logger interface {
	// Error logs at error level. Arguments are handled in the manner of fmt.Println.
	Error(...interface{})

	// Info logs at info level. Arguments are handled in the manner of fmt.Println.
	Info(...interface{})

	// Debug logs at debug level. Arguments are handled in the manner of fmt.Println.
	Debug(...interface{})

	// ErrorEnabled returns true if error level logging is enabled
	ErrorEnabled() bool

	// InfoEnabled returns true if info level logging is enabled
	InfoEnabled() bool

	// DebugEnabled returns true if debug level logging is enabled
	DebugEnabled() bool
}
29 |
// Level determines the amount of logging that is enabled. The levels are
// ordered from Silent (nothing) to Debug (everything).
type Level int

const (
	// Silent means that all logging is disabled
	Silent = Level(iota)

	// Error means that only error logging is enabled
	Error

	// Info means that error and info logging is enabled
	Info

	// Debug means that all logging is enabled
	Debug
)
46 |
// silent is a logger that discards everything and reports every level as
// disabled.
type silent int

// The emitters are no-ops.

func (silent) Error(...interface{}) {}
func (silent) Info(...interface{})  {}
func (silent) Debug(...interface{}) {}

// The predicates always answer false.

func (silent) ErrorEnabled() bool { return false }
func (silent) InfoEnabled() bool  { return false }
func (silent) DebugEnabled() bool { return false }
64 |
// writer is a logger backed by one standard log.Logger per level. A level
// whose logger is nil is disabled.
type writer struct {
	debug *log.Logger
	info  *log.Logger
	err   *log.Logger
}

// DebugEnabled returns true if a debug logger is present.
func (l *writer) DebugEnabled() bool {
	return l.debug != nil
}

// InfoEnabled returns true if an info logger is present.
func (l *writer) InfoEnabled() bool {
	return l.info != nil
}

// ErrorEnabled returns true if an error logger is present.
func (l *writer) ErrorEnabled() bool {
	return l.err != nil
}

// Debug prints args on the debug logger unless debug logging is disabled.
func (l *writer) Debug(args ...interface{}) {
	if !l.DebugEnabled() {
		return
	}
	l.debug.Println(args...)
}

// Info prints args on the info logger unless info logging is disabled.
func (l *writer) Info(args ...interface{}) {
	if !l.InfoEnabled() {
		return
	}
	l.info.Println(args...)
}

// Error prints args on the error logger unless error logging is disabled.
func (l *writer) Error(args ...interface{}) {
	if !l.ErrorEnabled() {
		return
	}
	l.err.Println(args...)
}
100 |
101 | // New returns a logger that is based on the standard log.Logger.
102 | func New(level Level, out, err io.Writer) Logger {
103 | if level == Silent {
104 | return silent(0)
105 | }
106 |
107 | l := &writer{}
108 | switch level {
109 | case Debug:
110 | l.debug = log.New(out, "DEBUG ", log.LstdFlags)
111 | fallthrough
112 | case Info:
113 | l.info = log.New(out, "INFO ", log.LstdFlags)
114 | fallthrough
115 | case Error:
116 | l.err = log.New(err, "ERROR ", log.LstdFlags)
117 | }
118 | return l
119 | }
120 |
--------------------------------------------------------------------------------
/logger/logger_test.go:
--------------------------------------------------------------------------------
1 | package logger
2 |
3 | import (
4 | "bytes"
5 | "fmt"
6 | "regexp"
7 | "testing"
8 | )
9 |
10 | func TestLogger_New(t *testing.T) {
11 | l := New(Silent, nil, nil)
12 | if _, ok := l.(silent); !ok {
13 | t.Error("silent leven did not result in silent logger")
14 | }
15 | }
16 |
// msgEx matches the first line written by a standard log.Logger that has a
// level prefix and log.LstdFlags: the upper-case level, the date, the time
// and finally the message. Group 1 captures the level and group 2 the
// message.
var msgEx = regexp.MustCompile(`(?m)^([A-Z]+)\s+\S+\s+\S+\s+(.*)$`)

// assertEqual fails the test immediately unless the expected string e equals
// the actual string a.
func assertEqual(t *testing.T, e, a string) {
	t.Helper()
	if e != a {
		t.Fatalf("expected '%s', got '%s'", e, a)
	}
}

// checkLogOutput fails the test unless the first line in bf is a log line
// with level ls and message ms. Output that does not look like a log line at
// all is a failure too; it used to be silently accepted.
func checkLogOutput(t *testing.T, ls, ms string, bf fmt.Stringer) {
	t.Helper()
	out := bf.String()
	// A match consists of the full match followed by the two groups.
	ps := msgEx.FindStringSubmatch(out)
	if len(ps) != 3 {
		t.Fatalf("unexpected log output '%s'", out)
	}
	assertEqual(t, ls, ps[1])
	assertEqual(t, ms, ps[2])
}
33 |
34 | func TestLogger_Debug(t *testing.T) {
35 | o := bytes.Buffer{}
36 | e := bytes.Buffer{}
37 | l := New(Debug, &o, &e)
38 | m := "some message"
39 | if !(l.ErrorEnabled() && l.InfoEnabled() && l.DebugEnabled()) {
40 | t.Fatal("wrong levels enabled")
41 | }
42 | l.Debug(m)
43 | if e.Len() > 0 {
44 | t.Error("debug log produced output on stderr")
45 | }
46 | l.Error(m)
47 | checkLogOutput(t, "DEBUG", m, &o)
48 | checkLogOutput(t, "ERROR", m, &e)
49 | }
50 |
51 | func TestLogger_Info(t *testing.T) {
52 | o := bytes.Buffer{}
53 | e := bytes.Buffer{}
54 | l := New(Info, &o, &e)
55 | m := "some message"
56 | if !(l.ErrorEnabled() && l.InfoEnabled() && !l.DebugEnabled()) {
57 | t.Fatal("wrong levels enabled")
58 | }
59 | l.Info(m)
60 | if e.Len() > 0 {
61 | t.Error("info log produced output on stderr")
62 | }
63 | l.Error(m)
64 | checkLogOutput(t, "INFO", m, &o)
65 | checkLogOutput(t, "ERROR", m, &e)
66 | }
67 |
68 | func TestLogger_Error(t *testing.T) {
69 | o := bytes.Buffer{}
70 | e := bytes.Buffer{}
71 | l := New(Error, &o, &e)
72 | m := "some message"
73 | if !(l.ErrorEnabled() && !l.InfoEnabled() && !l.DebugEnabled()) {
74 | t.Fatal("wrong levels enabled")
75 | }
76 | l.Error(m)
77 | l.Info(m)
78 | if o.Len() > 0 {
79 | t.Error("error log produced output on stdout")
80 | }
81 | checkLogOutput(t, "ERROR", m, &e)
82 | }
83 |
84 | func TestLogger_Silent(t *testing.T) {
85 | o := bytes.Buffer{}
86 | e := bytes.Buffer{}
87 | l := New(Silent, &o, &e)
88 | m := "some message"
89 | if l.ErrorEnabled() || l.InfoEnabled() || l.DebugEnabled() {
90 | t.Fatal("wrong levels enabled")
91 | }
92 | l.Debug(m)
93 | l.Info(m)
94 | l.Error(m)
95 | if o.Len() > 0 {
96 | t.Error("silent log produced output on stdout")
97 | }
98 | if e.Len() > 0 {
99 | t.Error("silent log produced output on stderr")
100 | }
101 | }
102 |
--------------------------------------------------------------------------------
/main.go:
--------------------------------------------------------------------------------
1 | // +build !citest
2 |
3 | package main
4 |
5 | import (
6 | "os"
7 |
8 | "github.com/tada/mqtt-nats/cli"
9 | )
10 |
11 | func main() {
12 | os.Exit(cli.Bridge(os.Args, os.Stdout, os.Stderr))
13 | }
14 |
--------------------------------------------------------------------------------
/mqtt/package.go:
--------------------------------------------------------------------------------
1 | // Package mqtt contains the MQTT reader, writer and MQTT to NATS topic conversion utilities
2 | package mqtt
3 |
--------------------------------------------------------------------------------
/mqtt/pkg/connect.go:
--------------------------------------------------------------------------------
1 | package pkg
2 |
3 | import (
4 | "errors"
5 | "fmt"
6 | "time"
7 |
8 | "github.com/tada/mqtt-nats/mqtt"
9 | )
10 |
const (
	// protoName is the protocol name that ParseConnect requires and that
	// Connect.Write emits.
	protoName = "MQTT"

	// Bit masks for the connect flags byte of the CONNECT packet.

	// cleanSessionFlag is set when the client requests a clean session.
	cleanSessionFlag = byte(0x02)
	// willFlag is set when the payload contains a will topic and message.
	willFlag = byte(0x04)
	// willQoS masks the two bits that hold the QoS of the will (value << 3).
	willQoS = byte(0x18)
	// willRetainFlag is set when the will is to be retained.
	willRetainFlag = byte(0x20)
	// passwordFlag is set when the payload contains a password.
	passwordFlag = byte(0x40)
	// userNameFlag is set when the payload contains a user name.
	userNameFlag = byte(0x80)
)
21 |
// ReturnCode is the code that a server sends in the CONNACK response to a
// CONNECT. It implements the error interface.
type ReturnCode byte

const (
	// RtAccepted Connection Accepted
	RtAccepted = ReturnCode(iota)

	// RtUnacceptableProtocolVersion The Server does not support the level of the MQTT protocol requested by the Client
	RtUnacceptableProtocolVersion

	// RtIdentifierRejected The Client identifier is correct UTF-8 but not allowed by the Server
	RtIdentifierRejected

	// RtServerUnavailable The Network Connection has been made but the MQTT service is unavailable
	RtServerUnavailable

	// RtBadUserNameOrPassword The data in the user name or password is malformed
	RtBadUserNameOrPassword

	// RtNotAuthorized The Client is not authorized to connect
	RtNotAuthorized
)

// Error returns a short description of the return code, or "unknown error"
// for a code that is not one of the Rt constants.
func (r ReturnCode) Error() string {
	descriptions := [...]string{
		RtAccepted:                    "accepted",
		RtUnacceptableProtocolVersion: "unacceptable protocol version",
		RtIdentifierRejected:          "identifier rejected",
		RtServerUnavailable:           "server unavailable",
		RtBadUserNameOrPassword:       "bad user name or password",
		RtNotAuthorized:               "not authorized",
	}
	if int(r) >= len(descriptions) {
		return "unknown error"
	}
	return descriptions[r]
}
63 |
// Connect is the MQTT connect packet
type Connect struct {
	// clientID is the client identifier, the first element of the payload.
	clientID string
	// creds holds user name and/or password. ParseConnect leaves it nil when
	// neither the user name flag nor the password flag is set.
	creds *Credentials
	// will is the will of the client. ParseConnect leaves it nil when the
	// will flag is not set.
	will *Will
	// keepAlive is the keep alive interval in seconds.
	keepAlive uint16
	// clientLevel is the protocol level. ParseConnect only accepts 0x4.
	clientLevel byte
	// flags is the connect flags byte.
	flags byte
}
73 |
74 | // NewConnect creates a new MQTT connect packet
75 | func NewConnect(clientID string, cleanSession bool, keepAlive uint16, will *Will, creds *Credentials) *Connect {
76 | flags := byte(0)
77 | if cleanSession {
78 | flags |= cleanSessionFlag
79 | }
80 | if will != nil {
81 | flags |= willFlag | (will.QoS << 3)
82 | if will.Retain {
83 | flags |= willRetainFlag
84 | }
85 | }
86 | if creds != nil {
87 | if len(creds.Password) > 0 {
88 | flags |= passwordFlag
89 | }
90 | if len(creds.User) > 0 {
91 | flags |= userNameFlag
92 | }
93 | }
94 |
95 | return &Connect{
96 | clientID: clientID,
97 | creds: creds,
98 | keepAlive: keepAlive,
99 | clientLevel: 0x4,
100 | flags: flags,
101 | will: will}
102 | }
103 |
104 | // ParseConnect parses the connect packet from the given reader.
105 | func ParseConnect(r *mqtt.Reader, _ byte, pkLen int) (Packet, error) {
106 | var err error
107 | if r, err = r.ReadPacket(pkLen); err != nil {
108 | return nil, err
109 | }
110 |
111 | // Protocol Name
112 | var proto string
113 | if proto, err = r.ReadString(); err != nil {
114 | return nil, err
115 | }
116 | if proto != protoName {
117 | return nil, fmt.Errorf(`expected connect packet with protocol name "MQTT", got "%s"`, proto)
118 | }
119 |
120 | c := &Connect{}
121 |
122 | // Protocol Level
123 | if c.clientLevel, err = r.ReadByte(); err != nil {
124 | return nil, err
125 | }
126 | if c.clientLevel != 0x4 {
127 | return c, RtUnacceptableProtocolVersion
128 | }
129 |
130 | // Connect Flags
131 | if c.flags, err = r.ReadByte(); err != nil {
132 | return nil, err
133 | }
134 |
135 | // Keep Alive
136 | if c.keepAlive, err = r.ReadUint16(); err != nil {
137 | return nil, err
138 | }
139 |
140 | // Payload starts here
141 |
142 | // Client Identifier
143 | if c.clientID, err = r.ReadString(); err != nil {
144 | return nil, err
145 | }
146 |
147 | // Will
148 | if c.HasWill() {
149 | c.will = &Will{QoS: (c.flags & willQoS) >> 3, Retain: (c.flags & willRetainFlag) != 0}
150 | if c.will.Topic, err = r.ReadString(); err != nil {
151 | return nil, err
152 | }
153 | if c.will.Message, err = r.ReadBytes(); err != nil {
154 | return nil, err
155 | }
156 | }
157 |
158 | // User Name
159 | if (c.flags & (userNameFlag | passwordFlag)) != 0 {
160 | c.creds = &Credentials{}
161 | if c.HasUserName() {
162 | if c.creds.User, err = r.ReadString(); err != nil {
163 | return nil, err
164 | }
165 | }
166 | // Password
167 | if c.HasPassword() {
168 | if c.creds.Password, err = r.ReadBytes(); err != nil {
169 | return nil, err
170 | }
171 | }
172 | }
173 | return c, nil
174 | }
175 |
176 | // Equals returns true if this packet is equal to the given packet, false if not
177 | func (c *Connect) Equals(other interface{}) bool {
178 | oc, ok := other.(*Connect)
179 | return ok &&
180 | c.keepAlive == oc.keepAlive &&
181 | c.clientLevel == oc.clientLevel &&
182 | c.flags == oc.flags &&
183 | c.clientID == oc.clientID &&
184 | (c.will == oc.will || (c.will != nil && c.will.Equals(oc.will))) &&
185 | (c.creds == oc.creds || (c.creds != nil && c.creds.Equals(oc.creds)))
186 | }
187 |
// SetClientLevel sets the client level. Intended for testing purposes, e.g.
// to produce a packet with a level other than the 0x4 that NewConnect assigns
// and ParseConnect accepts.
func (c *Connect) SetClientLevel(cl byte) {
	c.clientLevel = cl
}
192 |
193 | // Write writes the MQTT bits of this packet on the given Writer
194 | func (c *Connect) Write(w *mqtt.Writer) {
195 | pkLen := 2 + len(protoName) +
196 | 1 + // clientLevel
197 | 1 + // flags
198 | 2 + // keepAlive
199 | 2 + len(c.clientID)
200 |
201 | if c.HasWill() {
202 | pkLen += 2 + len(c.will.Topic)
203 | pkLen += 2 + len(c.will.Message)
204 | }
205 | if c.HasUserName() {
206 | pkLen += 2 + len(c.creds.User)
207 | }
208 | if c.HasPassword() {
209 | pkLen += 2 + len(c.creds.Password)
210 | }
211 |
212 | w.WriteU8(TpConnect)
213 | w.WriteVarInt(pkLen)
214 | w.WriteString(protoName)
215 | w.WriteU8(c.clientLevel)
216 | w.WriteU8(c.flags)
217 | w.WriteU16(c.keepAlive)
218 | w.WriteString(c.clientID)
219 | if c.HasWill() {
220 | w.WriteString(c.will.Topic)
221 | w.WriteBytes(c.will.Message)
222 | }
223 | if c.HasUserName() {
224 | w.WriteString(c.creds.User)
225 | }
226 | if c.HasPassword() {
227 | w.WriteBytes(c.creds.Password)
228 | }
229 | }
230 |
231 | // String returns a brief string representation of the packet. Suitable for logging
232 | func (c *Connect) String() string {
233 | will := ""
234 | if c.HasWill() {
235 | will = ", " + c.Will().String()
236 | }
237 | return fmt.Sprintf("CONNECT (c%d, k%d, u%d, p%d%s)",
238 | (c.flags&cleanSessionFlag)>>1,
239 | time.Duration(c.keepAlive),
240 | (c.flags&userNameFlag)>>7,
241 | (c.flags&passwordFlag)>>6,
242 | will)
243 | }
244 |
245 | // CleanSession returns true if the connection is requesting a clean session
246 | func (c *Connect) CleanSession() bool {
247 | return (c.flags & cleanSessionFlag) != 0
248 | }
249 |
// ClientID returns the id provided by the client, i.e. the client identifier
// given to NewConnect or read from the payload by ParseConnect
func (c *Connect) ClientID() string {
	return c.clientID
}
254 |
255 | // HasPassword returns true if the connection contains a password
256 | func (c *Connect) HasPassword() bool {
257 | return (c.flags & passwordFlag) != 0
258 | }
259 |
260 | // HasUserName returns true if the connection contains a user name
261 | func (c *Connect) HasUserName() bool {
262 | return (c.flags & userNameFlag) != 0
263 | }
264 |
265 | // HasWill returns true if the connection contains a will
266 | func (c *Connect) HasWill() bool {
267 | return (c.flags & willFlag) != 0
268 | }
269 |
270 | // KeepAlive returns the desired keep alive duration
271 | func (c *Connect) KeepAlive() time.Duration {
272 | return time.Duration(c.keepAlive) * time.Second
273 | }
274 |
// Credentials returns the user name and password credentials or nil. A parsed
// packet has no credentials when neither the user name flag nor the password
// flag is set.
func (c *Connect) Credentials() *Credentials {
	return c.creds
}
279 |
// Will returns the client will or nil. A parsed packet has no will when the
// will flag is not set, and DeleteWill removes it.
func (c *Connect) Will() *Will {
	return c.will
}
284 |
285 | // DeleteWill clears will and all flags that are associated with the will
286 | func (c *Connect) DeleteWill() {
287 | c.flags &^= willFlag | willQoS | willRetainFlag
288 | c.will = nil
289 | }
290 |
// ConnAck is the MQTT CONNACK packet sent in response to a CONNECT
type ConnAck struct {
	// flags is the first byte of the packet body. NewConnAck sets bit 0 when
	// a session is present.
	flags byte
	// returnCode is the second byte of the packet body, see ReturnCode.
	returnCode byte
}
296 |
297 | // NewConnAck creates an CONNACK packet
298 | func NewConnAck(sessionPresent bool, returnCode ReturnCode) Packet {
299 | flags := byte(0x00)
300 | if sessionPresent {
301 | flags |= 0x01
302 | }
303 | return &ConnAck{flags: flags, returnCode: byte(returnCode)}
304 | }
305 |
306 | // ParseConnAck parses a CONNACK packet
307 | func ParseConnAck(r *mqtt.Reader, _ byte, pkLen int) (Packet, error) {
308 | var err error
309 | if pkLen != 2 {
310 | return nil, errors.New("malformed CONNACK")
311 | }
312 | var bs []byte
313 | bs, err = r.ReadExact(2)
314 | if err != nil {
315 | return nil, err
316 | }
317 | return &ConnAck{flags: bs[0], returnCode: bs[1]}, nil
318 | }
319 |
320 | // Equals returns true if this packet is equal to the given packet, false if not
321 | func (a *ConnAck) Equals(other interface{}) bool {
322 | ac, ok := other.(*ConnAck)
323 | return ok && *a == *ac
324 | }
325 |
// ReturnCode returns the return code from the server, i.e. the second byte of
// the packet body converted to a ReturnCode
func (a *ConnAck) ReturnCode() ReturnCode {
	return ReturnCode(a.returnCode)
}
330 |
331 | // String returns a brief string representation of the packet. Suitable for logging
332 | func (a *ConnAck) String() string {
333 | return fmt.Sprintf("CONNACK (s%d, rt%d)", a.flags, a.returnCode)
334 | }
335 |
336 | // Write writes the MQTT bits of this packet on the given Writer
337 | func (a *ConnAck) Write(w *mqtt.Writer) {
338 | w.WriteU8(TpConnAck)
339 | w.WriteU8(2)
340 | w.WriteU8(a.flags)
341 | w.WriteU8(a.returnCode)
342 | }
343 |
344 | // The Disconnect type represents the MQTT DISCONNECT packet
345 | type Disconnect int
346 |
347 | // DisconnectSingleton is the one and only instance of the Disconnect type
348 | const DisconnectSingleton = Disconnect(0)
349 |
350 | // Equals returns true if this packet is equal to the given packet, false if not
351 | func (Disconnect) Equals(other interface{}) bool {
352 | return other == DisconnectSingleton
353 | }
354 |
355 | // String returns a brief string representation of the packet. Suitable for logging
356 | func (Disconnect) String() string {
357 | return "DISCONNECT"
358 | }
359 |
360 | // Write writes the MQTT bits of this packet on the given Writer
361 | func (Disconnect) Write(w *mqtt.Writer) {
362 | w.WriteU8(TpDisconnect)
363 | w.WriteU8(0)
364 | }
365 |
--------------------------------------------------------------------------------
/mqtt/pkg/connect_test.go:
--------------------------------------------------------------------------------
1 | package pkg_test
2 |
3 | import (
4 | "bytes"
5 | "testing"
6 |
7 | "github.com/tada/mqtt-nats/test/utils"
8 |
9 | "github.com/tada/mqtt-nats/mqtt"
10 |
11 | "github.com/tada/mqtt-nats/mqtt/pkg"
12 | )
13 |
14 | func TestParseConnect(t *testing.T) {
15 | c1 := pkg.NewConnect(`cid`, true, 5, &pkg.Will{
16 | Topic: "my/will",
17 | Message: []byte("the will"),
18 | QoS: 1,
19 | Retain: false,
20 | }, &pkg.Credentials{User: "bob", Password: []byte("password")})
21 | writeReadAndCompare(t, c1, "CONNECT (c1, k5, u1, p1, w(r0, q1, 'my/will', ... (8 bytes)))")
22 | }
23 |
// TestParseConnect_badLen announces a packet of 20 bytes on an empty input.
func TestParseConnect_badLen(t *testing.T) {
	_, err := pkg.ParseConnect(mqtt.NewReader(bytes.NewReader([]byte{})), pkg.TpConnAck, 20)
	utils.CheckError(err, t)
}
28 |
29 | func TestParseConnect_badProto(t *testing.T) {
30 | w := mqtt.NewWriter()
31 | w.WriteU16(5)
32 | w.WriteU8('N')
33 | w.WriteU8('O')
34 | w.WriteU8('N')
35 | w.WriteU8('O')
36 | _, err := pkg.ParseConnect(mqtt.NewReader(bytes.NewReader(w.Bytes())), pkg.TpConnAck, 6)
37 | utils.CheckError(err, t)
38 | }
39 |
40 | func TestParseConnect_illegalProto(t *testing.T) {
41 | w := mqtt.NewWriter()
42 | w.WriteString("NONO")
43 | _, err := pkg.ParseConnect(mqtt.NewReader(bytes.NewReader(w.Bytes())), pkg.TpConnAck, 6)
44 | utils.CheckError(err, t)
45 | }
46 |
47 | func TestParseConnect_badClientLevel(t *testing.T) {
48 | w := mqtt.NewWriter()
49 | w.WriteString("MQTT")
50 | _, err := pkg.ParseConnect(mqtt.NewReader(bytes.NewReader(w.Bytes())), pkg.TpConnAck, 6)
51 | utils.CheckError(err, t)
52 | }
53 |
54 | func TestParseConnect_illegalClientLevel(t *testing.T) {
55 | w := mqtt.NewWriter()
56 | w.WriteString("MQTT")
57 | w.WriteU8(2)
58 | _, err := pkg.ParseConnect(mqtt.NewReader(bytes.NewReader(w.Bytes())), pkg.TpConnAck, 7)
59 | utils.CheckError(err, t)
60 | }
61 |
62 | func TestParseConnect_badConnectFlags(t *testing.T) {
63 | w := mqtt.NewWriter()
64 | w.WriteString("MQTT")
65 | w.WriteU8(4)
66 | _, err := pkg.ParseConnect(mqtt.NewReader(bytes.NewReader(w.Bytes())), pkg.TpConnAck, 7)
67 | utils.CheckError(err, t)
68 | }
69 |
70 | func TestParseConnect_badKeepAlive(t *testing.T) {
71 | w := mqtt.NewWriter()
72 | w.WriteString("MQTT")
73 | w.WriteU8(4)
74 | w.WriteU8(0)
75 | _, err := pkg.ParseConnect(mqtt.NewReader(bytes.NewReader(w.Bytes())), pkg.TpConnAck, 8)
76 | utils.CheckError(err, t)
77 | }
78 |
79 | func TestParseConnect_badClientID(t *testing.T) {
80 | w := mqtt.NewWriter()
81 | w.WriteString("MQTT")
82 | w.WriteU8(4)
83 | w.WriteU8(0)
84 | w.WriteU16(5)
85 | _, err := pkg.ParseConnect(mqtt.NewReader(bytes.NewReader(w.Bytes())), pkg.TpConnAck, 10)
86 | utils.CheckError(err, t)
87 | }
88 |
89 | func TestParseConnect_badWillTopic(t *testing.T) {
90 | w := mqtt.NewWriter()
91 | w.WriteString("MQTT")
92 | w.WriteU8(4)
93 | w.WriteU8(0x04)
94 | w.WriteU16(5)
95 | w.WriteString("cid")
96 | _, err := pkg.ParseConnect(mqtt.NewReader(bytes.NewReader(w.Bytes())), pkg.TpConnAck, 15)
97 | utils.CheckError(err, t)
98 | }
99 |
100 | func TestParseConnect_badWillMessage(t *testing.T) {
101 | w := mqtt.NewWriter()
102 | w.WriteString("MQTT")
103 | w.WriteU8(4)
104 | w.WriteU8(0x04)
105 | w.WriteU16(5)
106 | w.WriteString("cid")
107 | w.WriteString("wtp")
108 | _, err := pkg.ParseConnect(mqtt.NewReader(bytes.NewReader(w.Bytes())), pkg.TpConnAck, 20)
109 | utils.CheckError(err, t)
110 | }
111 |
112 | func TestParseConnect_badUser(t *testing.T) {
113 | w := mqtt.NewWriter()
114 | w.WriteString("MQTT")
115 | w.WriteU8(4)
116 | w.WriteU8(0x80)
117 | w.WriteU16(5)
118 | w.WriteString("cid")
119 | _, err := pkg.ParseConnect(mqtt.NewReader(bytes.NewReader(w.Bytes())), pkg.TpConnAck, 15)
120 | utils.CheckError(err, t)
121 | }
122 |
123 | func TestParseConnect_badPw(t *testing.T) {
124 | w := mqtt.NewWriter()
125 | w.WriteString("MQTT")
126 | w.WriteU8(4)
127 | w.WriteU8(0x40)
128 | w.WriteU16(5)
129 | w.WriteString("cid")
130 | _, err := pkg.ParseConnect(mqtt.NewReader(bytes.NewReader(w.Bytes())), pkg.TpConnAck, 15)
131 | utils.CheckError(err, t)
132 | }
133 |
134 | func TestParseConnAck(t *testing.T) {
135 | writeReadAndCompare(t, pkg.NewConnAck(false, 1), "CONNACK (s0, rt1)")
136 | }
137 |
138 | func TestParseConnAck_badLen(t *testing.T) {
139 | _, err := pkg.ParseConnAck(mqtt.NewReader(bytes.NewReader([]byte{})), pkg.TpConnAck, 0)
140 | utils.CheckError(err, t)
141 | }
142 |
143 | func TestParseConnAck_badBytes(t *testing.T) {
144 | _, err := pkg.ParseConnAck(mqtt.NewReader(bytes.NewReader([]byte{1})), pkg.TpConnAck, 2)
145 | utils.CheckError(err, t)
146 | }
147 |
148 | func TestParseDisconnect(t *testing.T) {
149 | writeReadAndCompare(t, pkg.DisconnectSingleton, "DISCONNECT")
150 | }
151 |
152 | func TestReturnCode_Error(t *testing.T) {
153 | tests := []struct {
154 | name string
155 | code pkg.ReturnCode
156 | want string
157 | }{
158 | {
159 | name: "RtAccepted",
160 | code: pkg.RtAccepted,
161 | want: "accepted",
162 | },
163 | {
164 | name: "RtUnacceptableProtocolVersion",
165 | code: pkg.RtUnacceptableProtocolVersion,
166 | want: "unacceptable protocol version",
167 | },
168 | {
169 | name: "RtIdentifierRejected",
170 | code: pkg.RtIdentifierRejected,
171 | want: "identifier rejected",
172 | },
173 | {
174 | name: "RtServerUnavailable",
175 | code: pkg.RtServerUnavailable,
176 | want: "server unavailable",
177 | },
178 | {
179 | name: "RtBadUserNameOrPassword",
180 | code: pkg.RtBadUserNameOrPassword,
181 | want: "bad user name or password",
182 | },
183 | {
184 | name: "RtNotAuthorized",
185 | code: pkg.RtNotAuthorized,
186 | want: "not authorized",
187 | },
188 | {
189 | name: "RtNotAuthorized",
190 | code: pkg.ReturnCode(99),
191 | want: "unknown error",
192 | },
193 | }
194 | for i := range tests {
195 | tt := tests[i]
196 | t.Run(tt.name, func(t *testing.T) {
197 | if got := tt.code.Error(); got != tt.want {
198 | t.Errorf("ReturnCode.Error() = %v, want %v", got, tt.want)
199 | }
200 | })
201 | }
202 | }
203 |
--------------------------------------------------------------------------------
/mqtt/pkg/credentials.go:
--------------------------------------------------------------------------------
1 | package pkg
2 |
3 | import (
4 | "bytes"
5 | "encoding/base64"
6 | "encoding/json"
7 | "io"
8 |
9 | "github.com/tada/catch"
10 |
11 | "github.com/tada/catch/pio"
12 | "github.com/tada/jsonstream"
13 | )
14 |
// Credentials are user credentials that originate from an MQTT CONNECT packet.
type Credentials struct {
	// User is the user name. It is stored under the JSON key "u" and omitted when empty.
	User string
	// Password is the raw password. It is stored base64 encoded under the JSON key "p"
	// and omitted when nil.
	Password []byte
}
20 |
21 | // MarshalToJSON streams the JSON encoded form of this instance onto the given io.Writer
22 | func (c *Credentials) MarshalToJSON(w io.Writer) {
23 | pio.WriteByte(w, '{')
24 | if c.User != "" {
25 | pio.WriteString(w, `"u":`)
26 | jsonstream.WriteString(w, c.User)
27 | }
28 | if c.Password != nil {
29 | if c.User != "" {
30 | pio.WriteByte(w, ',')
31 | }
32 | pio.WriteString(w, `"p":`)
33 | jsonstream.WriteString(w, base64.StdEncoding.EncodeToString(c.Password))
34 | }
35 | pio.WriteByte(w, '}')
36 | }
37 |
38 | // UnmarshalFromJSON initializes this instance from the tokens stream provided by the json.Decoder. The
39 | // first token has already been read and is passed as an argument.
40 | func (c *Credentials) UnmarshalFromJSON(js jsonstream.Decoder, t json.Token) {
41 | jsonstream.AssertDelim(t, '{')
42 | for {
43 | k, ok := js.ReadStringOrEnd('}')
44 | if !ok {
45 | break
46 | }
47 | switch k {
48 | case "u":
49 | c.User = js.ReadString()
50 | case "p":
51 | p, err := base64.StdEncoding.DecodeString(js.ReadString())
52 | if err != nil {
53 | panic(catch.Error(err))
54 | }
55 | c.Password = p
56 | }
57 | }
58 | }
59 |
60 | // Equals returns true if this instance is equal to the given instance, false if not
61 | func (c *Credentials) Equals(oc *Credentials) bool {
62 | return c.User == oc.User && bytes.Equal(c.Password, oc.Password)
63 | }
64 |
--------------------------------------------------------------------------------
/mqtt/pkg/idmanager.go:
--------------------------------------------------------------------------------
1 | package pkg
2 |
3 | import (
4 | "encoding/json"
5 | "io"
6 | "sync"
7 |
8 | "github.com/tada/catch/pio"
9 | "github.com/tada/jsonstream"
10 | )
11 |
// An IDManager manages packet IDs and ensures their uniqueness by maintaining a list of
// IDs that are in use.
type IDManager interface {
	// NextFreePacketID allocates and returns the next free packet ID. The returned ID
	// stays allocated until it is passed to ReleasePacketID.
	NextFreePacketID() uint16

	// ReleasePacketID releases a previously allocated packet ID so that it can be
	// handed out again.
	ReleasePacketID(uint16)
}
21 |
// idManager is the IDManager implementation created by NewIDManager.
type idManager struct {
	// pkgIDLock guards inFlight and nextFreePkgID.
	pkgIDLock sync.Mutex
	// inFlight has an entry for every packet ID that is currently allocated.
	inFlight map[uint16]bool
	// nextFreePkgID is incremented by NextFreePacketID before it is used, so between
	// calls it holds the most recently allocated ID.
	nextFreePkgID uint16
}
27 |
28 | // NewIDManager creates a new IDManager
29 | func NewIDManager() IDManager {
30 | return &idManager{nextFreePkgID: 1, inFlight: make(map[uint16]bool, 37)}
31 | }
32 |
33 | func (s *idManager) NextFreePacketID() uint16 {
34 | s.pkgIDLock.Lock()
35 | s.nextFreePkgID++
36 | if s.nextFreePkgID == 0 {
37 | // counter flipped over and zero is not a valid ID
38 | s.nextFreePkgID++
39 | }
40 | for s.inFlight[s.nextFreePkgID] {
41 | s.nextFreePkgID++
42 | if s.nextFreePkgID == 0 {
43 | // counter flipped over and zero is not a valid ID
44 | s.nextFreePkgID++
45 | }
46 | }
47 | s.inFlight[s.nextFreePkgID] = true
48 | s.pkgIDLock.Unlock()
49 | return s.nextFreePkgID
50 | }
51 |
52 | func (s *idManager) ReleasePacketID(id uint16) {
53 | s.pkgIDLock.Lock()
54 | delete(s.inFlight, id)
55 | s.pkgIDLock.Unlock()
56 | }
57 |
58 | func (s *idManager) MarshalToJSON(w io.Writer) {
59 | var (
60 | nf uint16
61 | inf []uint16
62 | )
63 |
64 | // take a snapshot of things in flight
65 | s.pkgIDLock.Lock()
66 | nf = s.nextFreePkgID
67 | inf = make([]uint16, len(s.inFlight))
68 | i := 0
69 | for k := range s.inFlight {
70 | inf[i] = k
71 | i++
72 | }
73 | s.pkgIDLock.Unlock()
74 |
75 | pio.WriteString(w, `{"next":`)
76 | pio.WriteInt(w, int64(nf))
77 | if len(inf) > 0 {
78 | pio.WriteString(w, `,"inFlight":[`)
79 | for i := range inf {
80 | if i > 0 {
81 | pio.WriteByte(w, ',')
82 | }
83 | pio.WriteInt(w, int64(inf[i]))
84 | }
85 | pio.WriteByte(w, ']')
86 | }
87 | pio.WriteByte(w, '}')
88 | }
89 |
90 | func (s *idManager) UnmarshalFromJSON(js jsonstream.Decoder, t json.Token) {
91 | s.inFlight = make(map[uint16]bool, 37)
92 | jsonstream.AssertDelim(t, '{')
93 | for {
94 | k, ok := js.ReadStringOrEnd('}')
95 | if !ok {
96 | break
97 | }
98 | switch k {
99 | case "next":
100 | s.nextFreePkgID = uint16(js.ReadInt())
101 | case "inFlight":
102 | js.ReadDelim('[')
103 | for {
104 | i, ok := js.ReadIntOrEnd(']')
105 | if !ok {
106 | break
107 | }
108 | s.inFlight[uint16(i)] = true
109 | }
110 | }
111 | }
112 | }
113 |
--------------------------------------------------------------------------------
/mqtt/pkg/idmanager_test.go:
--------------------------------------------------------------------------------
1 | package pkg
2 |
3 | import (
4 | "math"
5 | "testing"
6 |
7 | "github.com/tada/jsonstream"
8 | "github.com/tada/mqtt-nats/test/utils"
9 | )
10 |
11 | func TestIdManager_NextFreePacketID_flip(t *testing.T) {
12 | idm := NewIDManager().(*idManager)
13 | idm.nextFreePkgID = math.MaxUint16 - 1
14 | utils.CheckEqual(uint16(math.MaxUint16), idm.NextFreePacketID(), t)
15 | utils.CheckEqual(uint16(1), idm.NextFreePacketID(), t)
16 | }
17 |
18 | func TestIdManager_NextFreePacketID_flipInFlight(t *testing.T) {
19 | idm := NewIDManager().(*idManager)
20 | idm.nextFreePkgID = math.MaxUint16 - 2
21 | idm.inFlight[uint16(math.MaxUint16)] = true
22 | utils.CheckEqual(uint16(math.MaxUint16-1), idm.NextFreePacketID(), t)
23 | utils.CheckEqual(uint16(1), idm.NextFreePacketID(), t)
24 | }
25 |
26 | func TestIdManager_NextFreePacketID_json(t *testing.T) {
27 | idm := NewIDManager().(*idManager)
28 | idm.nextFreePkgID = math.MaxUint16 - 3
29 | idm.inFlight[uint16(math.MaxUint16-1)] = true
30 | idm.inFlight[uint16(math.MaxUint16)] = true
31 | idm.inFlight[uint16(1)] = true
32 | bs, err := jsonstream.Marshal(idm)
33 | utils.CheckNotError(err, t)
34 | idm = &idManager{}
35 | utils.CheckNotError(jsonstream.Unmarshal(idm, bs), t)
36 | utils.CheckEqual(uint16(math.MaxUint16-2), idm.NextFreePacketID(), t)
37 | utils.CheckEqual(uint16(2), idm.NextFreePacketID(), t)
38 | }
39 |
--------------------------------------------------------------------------------
/mqtt/pkg/packet.go:
--------------------------------------------------------------------------------
1 | // Package pkg contains the MQTT packet structures
2 | package pkg
3 |
4 | import "github.com/tada/mqtt-nats/mqtt"
5 |
// MQTT packet type codes. Each code occupies the bits selected by TpMask (0xf0).
const (
	// TpConnect is the MQTT CONNECT type
	TpConnect = 0x10

	// TpConnAck is the MQTT CONNACK type
	TpConnAck = 0x20

	// TpPublish is the MQTT PUBLISH type
	TpPublish = 0x30

	// TpPubAck is the MQTT PUBACK type
	TpPubAck = 0x40

	// TpPubRec is the MQTT PUBREC type
	TpPubRec = 0x50

	// TpPubRel is the MQTT PUBREL type
	TpPubRel = 0x60

	// TpPubComp is the MQTT PUBCOMP type
	TpPubComp = 0x70

	// TpSubscribe is the MQTT SUBSCRIBE type
	TpSubscribe = 0x80

	// TpSubAck is the MQTT SUBACK type
	TpSubAck = 0x90

	// TpUnsubscribe is the MQTT UNSUBSCRIBE type
	TpUnsubscribe = 0xa0

	// TpUnsubAck is the MQTT UNSUBACK type
	TpUnsubAck = 0xb0

	// TpPing is the MQTT PINGREQ type
	TpPing = 0xc0

	// TpPingResp is the MQTT PINGRESP type
	TpPingResp = 0xd0

	// TpDisconnect is the MQTT DISCONNECT type
	TpDisconnect = 0xe0

	// TpMask is the bitmask for the MQTT type
	TpMask = 0xf0
)
52 |
// The Packet interface is implemented by all MQTT packet types.
type Packet interface {
	// Equals returns true if this packet is equal to the given packet, false if not.
	// The argument is an empty interface, so implementations decide how to treat
	// values of other types.
	Equals(other interface{}) bool

	// Write writes the MQTT bits of this packet on the given Writer.
	Write(w *mqtt.Writer)
}
61 |
--------------------------------------------------------------------------------
/mqtt/pkg/packet_test.go:
--------------------------------------------------------------------------------
1 | package pkg_test
2 |
3 | import (
4 | "bytes"
5 | "fmt"
6 | "testing"
7 |
8 | "github.com/tada/mqtt-nats/mqtt"
9 | "github.com/tada/mqtt-nats/mqtt/pkg"
10 | "github.com/tada/mqtt-nats/test/packet"
11 | )
12 |
13 | func writeReadAndCompare(t *testing.T, p pkg.Packet, ex string) {
14 | w := &mqtt.Writer{}
15 | p.Write(w)
16 | p2 := packet.Parse(t, bytes.NewReader(w.Bytes()))
17 | if !p.Equals(p2) {
18 | t.Fatal(p, "!=", p2)
19 | }
20 | ac := p.(fmt.Stringer).String()
21 | if ex != ac {
22 | t.Errorf("expected '%s' got '%s'", ex, ac)
23 | }
24 | }
25 |
--------------------------------------------------------------------------------
/mqtt/pkg/ping.go:
--------------------------------------------------------------------------------
1 | package pkg
2 |
3 | import "github.com/tada/mqtt-nats/mqtt"
4 |
5 | // The PingRequest type represents the MQTT PINGREQ packet
6 | type PingRequest int
7 |
8 | // PingRequestSingleton is the one and only instance of the PingRequest type
9 | const PingRequestSingleton = PingRequest(0)
10 |
11 | // Equals returns true if this packet is equal to the given packet, false if not
12 | func (PingRequest) Equals(other interface{}) bool {
13 | return other == PingRequestSingleton
14 | }
15 |
16 | // String returns a brief string representation of the packet. Suitable for logging
17 | func (PingRequest) String() string {
18 | return "PINGREQ"
19 | }
20 |
21 | // Write writes the MQTT bits of this packet on the given Writer
22 | func (PingRequest) Write(w *mqtt.Writer) {
23 | w.WriteU8(TpPing)
24 | w.WriteU8(0)
25 | }
26 |
27 | // The PingResponse type represents the MQTT PINGRESP packet
28 | type PingResponse int
29 |
30 | // PingResponseSingleton is the one and only instance of the PingResponse type
31 | const PingResponseSingleton = PingResponse(0)
32 |
33 | // Equals returns true if this packet is equal to the given packet, false if not
34 | func (PingResponse) Equals(other interface{}) bool {
35 | return other == PingResponseSingleton
36 | }
37 |
38 | // String returns a brief string representation of the packet. Suitable for logging
39 | func (PingResponse) String() string {
40 | return "PINGRESP"
41 | }
42 |
43 | // Write writes the MQTT bits of this packet on the given Writer
44 | func (PingResponse) Write(w *mqtt.Writer) {
45 | w.WriteU8(TpPingResp)
46 | w.WriteU8(0)
47 | }
48 |
--------------------------------------------------------------------------------
/mqtt/pkg/ping_test.go:
--------------------------------------------------------------------------------
1 | package pkg_test
2 |
3 | import (
4 | "testing"
5 |
6 | "github.com/tada/mqtt-nats/mqtt/pkg"
7 | )
8 |
9 | func TestParsePingReq(t *testing.T) {
10 | writeReadAndCompare(t, pkg.PingRequestSingleton, "PINGREQ")
11 | }
12 |
13 | func TestParsePingResp(t *testing.T) {
14 | writeReadAndCompare(t, pkg.PingResponseSingleton, "PINGRESP")
15 | }
16 |
--------------------------------------------------------------------------------
/mqtt/pkg/publish.go:
--------------------------------------------------------------------------------
1 | package pkg
2 |
3 | import (
4 | "bytes"
5 | "encoding/base64"
6 | "encoding/json"
7 | "errors"
8 | "fmt"
9 | "io"
10 |
11 | "github.com/tada/catch"
12 |
13 | "github.com/tada/catch/pio"
14 | "github.com/tada/jsonstream"
15 | "github.com/tada/mqtt-nats/mqtt"
16 | )
17 |
// Bits of the PUBLISH packet flags.
const (
	// PublishRetain is the bit representing the MQTT PUBLISH "retain" flag
	PublishRetain = 0x01

	// PublishQoS is the mask for the MQTT PUBLISH "quality of service" bits. The QoS
	// level is the masked value shifted right by one bit.
	PublishQoS = 0x06

	// PublishDup is the bit representing the MQTT PUBLISH "dup" flag
	PublishDup = 0x08
)
28 |
// The Publish type represents the MQTT PUBLISH packet
type Publish struct {
	name     string // topic name
	replyTo  string // NATS replyTo subject, empty when there is none
	payload  []byte // message payload
	id       uint16 // packet identifier, only written and parsed when QoS > 0
	flags    byte   // retain, QoS and dup bits (see PublishRetain, PublishQoS, PublishDup)
	sentByUs bool   // set if the message originated from this server (happens when a client will is published)
}
38 |
39 | // SimplePublish creates a new Publish packet with all flags zero and no reply
40 | func SimplePublish(topic string, payload []byte) *Publish {
41 | return &Publish{name: topic, payload: payload}
42 | }
43 |
44 | // NewPublish creates a new Publish packet
45 | func NewPublish(id uint16, topic string, flags byte, payload []byte, sentByUs bool, natsReplyTo string) *Publish {
46 | return &Publish{
47 | name: topic,
48 | replyTo: natsReplyTo,
49 | payload: payload,
50 | id: id,
51 | flags: flags,
52 | sentByUs: sentByUs,
53 | }
54 | }
55 |
56 | // NewPublish2 creates a new Publish packet
57 | func NewPublish2(id uint16, topic string, payload []byte, qos byte, dup bool, retain bool) *Publish {
58 | flags := byte(0)
59 | if qos > 0 {
60 | flags |= qos << 1
61 | }
62 | if dup {
63 | flags |= PublishDup
64 | }
65 | if retain {
66 | flags |= PublishRetain
67 | }
68 | return &Publish{id: id, flags: flags, name: topic, payload: payload, sentByUs: false, replyTo: ""}
69 | }
70 |
71 | // ParsePublish parses the publish packet from the given reader.
72 | func ParsePublish(r *mqtt.Reader, flags byte, pkLen int) (Packet, error) {
73 | var err error
74 | if r, err = r.ReadPacket(pkLen); err != nil {
75 | return nil, err
76 | }
77 |
78 | pp := &Publish{flags: flags & 0xf}
79 | if pp.name, err = r.ReadString(); err != nil {
80 | return nil, err
81 | }
82 |
83 | if pp.QoSLevel() > 0 {
84 | if pp.id, err = r.ReadUint16(); err != nil {
85 | return nil, err
86 | }
87 | }
88 | if pp.payload, err = r.ReadRemainingBytes(); err != nil {
89 | return nil, err
90 | }
91 | return pp, nil
92 | }
93 |
94 | // Equals returns true if this packet is equal to the given packet, false if not
95 | func (p *Publish) Equals(other interface{}) bool {
96 | op, ok := other.(*Publish)
97 | return ok &&
98 | p.id == op.id &&
99 | p.flags == op.flags &&
100 | p.sentByUs == op.sentByUs &&
101 | p.name == op.name &&
102 | p.replyTo == op.replyTo &&
103 | bytes.Equal(p.payload, op.payload)
104 | }
105 |
// Flags returns the packet flags: the retain, QoS and dup bits as stored in the packet.
func (p *Publish) Flags() byte {
	return p.flags
}
110 |
// ID returns the MQTT Packet Identifier. The identifier is only valid if QoS > 0; it is
// neither parsed nor written for QoS 0.
func (p *Publish) ID() uint16 {
	return p.id
}
115 |
116 | // IsDup returns true if the packet is a duplicate of a previously sent packet
117 | func (p *Publish) IsDup() bool {
118 | return (p.flags & PublishDup) != 0
119 | }
120 |
121 | // SetDup sets the dup flag of the packet
122 | func (p *Publish) SetDup() {
123 | p.flags |= PublishDup
124 | }
125 |
// IsPrintableASCII returns true if the given bytes are constrained to the ASCII 7-bit
// character set and contain no control characters, i.e. every byte is in the range
// 0x20 (space) through 0x7e ('~'). An empty or nil slice yields true.
//
// DEL (0x7f) is a control character and is rejected; the earlier bound `c > 127` let
// it through.
func IsPrintableASCII(bs []byte) bool {
	for _, c := range bs {
		if c < 0x20 || c > 0x7e {
			return false
		}
	}
	return true
}
137 |
138 | // MarshalToJSON marshals the packet as a JSON object onto the given writer
139 | func (p *Publish) MarshalToJSON(w io.Writer) {
140 | pio.WriteString(w, `{"flags":`)
141 | pio.WriteInt(w, int64(p.flags))
142 | pio.WriteString(w, `,"id":`)
143 | pio.WriteInt(w, int64(p.id))
144 | pio.WriteString(w, `,"name":`)
145 | jsonstream.WriteString(w, p.name)
146 | if p.replyTo != "" {
147 | pio.WriteString(w, `,"replyTo":`)
148 | jsonstream.WriteString(w, p.replyTo)
149 | }
150 | if len(p.payload) > 0 {
151 | if IsPrintableASCII(p.payload) {
152 | pio.WriteString(w, `,"payload":`)
153 | jsonstream.WriteString(w, string(p.payload))
154 | } else {
155 | pio.WriteString(w, `,"payloadEnc":`)
156 | jsonstream.WriteString(w, base64.StdEncoding.EncodeToString(p.payload))
157 | }
158 | }
159 | pio.WriteByte(w, '}')
160 | }
161 |
// NatsReplyTo returns the NATS replyTo subject. Only valid when the packet represents
// something received from NATS due to a client subscribing to a topic with QoS level
// > 0. The value is an empty string when no reply subject was set.
func (p *Publish) NatsReplyTo() string {
	return p.replyTo
}
167 |
// Payload returns the payload of the published message. The returned slice is the
// packet's own slice, not a copy.
func (p *Publish) Payload() []byte {
	return p.payload
}
172 |
173 | // QoSLevel returns the quality of service level which is 0, 1 or 2.
174 | func (p *Publish) QoSLevel() byte {
175 | return (p.flags & PublishQoS) >> 1
176 | }
177 |
178 | // ResetRetain resets the retain flag
179 | func (p *Publish) ResetRetain() {
180 | p.flags &^= PublishRetain
181 | }
182 |
183 | // Retain returns the retain flag setting
184 | func (p *Publish) Retain() bool {
185 | return (p.flags & PublishRetain) != 0
186 | }
187 |
188 | // String returns a brief string representation of the packet. Suitable for logging
189 | func (p *Publish) String() string {
190 | // layout borrowed from mosquitto_sub log output
191 | return fmt.Sprintf("PUBLISH (d%d, q%d, r%b, m%d, '%s', ... (%d bytes))",
192 | (p.flags&0x08)>>3,
193 | p.QoSLevel(),
194 | p.flags&0x01,
195 | p.ID(),
196 | p.name,
197 | len(p.payload))
198 | }
199 |
// TopicName returns the name of the topic that the message is published to.
func (p *Publish) TopicName() string {
	return p.name
}
204 |
205 | // UnmarshalFromJSON expects the given token to be the object start '{'. If it is, the rest
206 | // of the object is unmarshalled into the receiver. The method will panic with a pio.Error
207 | // if any errors are detected.
208 | //
209 | // See jsonstreamer.Consumer for more info.
210 | func (p *Publish) UnmarshalFromJSON(js jsonstream.Decoder, t json.Token) {
211 | jsonstream.AssertDelim(t, '{')
212 | for {
213 | s, ok := js.ReadStringOrEnd('}')
214 | if !ok {
215 | break
216 | }
217 | switch s {
218 | case "flags":
219 | p.flags = byte(js.ReadInt())
220 | case "id":
221 | p.id = uint16(js.ReadInt())
222 | case "name":
223 | p.name = js.ReadString()
224 | case "replyTo":
225 | p.replyTo = js.ReadString()
226 | case "payload":
227 | p.payload = []byte(js.ReadString())
228 | case "payloadEnc":
229 | var err error
230 | p.payload, err = base64.StdEncoding.DecodeString(js.ReadString())
231 | if err != nil {
232 | panic(catch.Error(err))
233 | }
234 | }
235 | }
236 | }
237 |
238 | // Write writes the MQTT bits of this packet on the given Writer
239 | func (p *Publish) Write(w *mqtt.Writer) {
240 | w.WriteU8(TpPublish | p.flags)
241 | pkLen := 2 + len(p.name) + len(p.payload)
242 | if p.QoSLevel() > 0 {
243 | pkLen += 2
244 | }
245 | w.WriteVarInt(pkLen)
246 | w.WriteString(p.name)
247 | if p.QoSLevel() > 0 {
248 | w.WriteU16(p.id)
249 | }
250 | _, _ = w.Write(p.payload)
251 | }
252 |
253 | // The PubAck type represents the MQTT PUBACK packet
254 | type PubAck uint16
255 |
256 | // ParsePubAck parses a PUBACK packet
257 | func ParsePubAck(r *mqtt.Reader, _ byte, pkLen int) (Packet, error) {
258 | if pkLen != 2 {
259 | return PubAck(0), errors.New("malformed PUBACK")
260 | }
261 | id, err := r.ReadUint16()
262 | return PubAck(id), err
263 | }
264 |
265 | // Equals returns true if this packet is equal to the given packet, false if not
266 | func (p PubAck) Equals(other interface{}) bool {
267 | return p == other
268 | }
269 |
270 | // ID returns the packet ID
271 | func (p PubAck) ID() uint16 {
272 | return uint16(p)
273 | }
274 |
275 | // String returns a brief string representation of the packet. Suitable for logging
276 | func (p PubAck) String() string {
277 | return fmt.Sprintf("PUBACK (m%d)", p.ID())
278 | }
279 |
280 | // Write writes the MQTT bits of this packet on the given Writer
281 | func (p PubAck) Write(w *mqtt.Writer) {
282 | w.WriteU8(TpPubAck)
283 | w.WriteU8(2)
284 | w.WriteU16(uint16(p))
285 | }
286 |
287 | // The PubRec type represents the MQTT PUBREC packet
288 | type PubRec uint16
289 |
290 | // ParsePubRec parses a PUBREC packet
291 | func ParsePubRec(r *mqtt.Reader, _ byte, pkLen int) (Packet, error) {
292 | if pkLen != 2 {
293 | return PubRec(0), errors.New("malformed PUBREC")
294 | }
295 | id, err := r.ReadUint16()
296 | return PubRec(id), err
297 | }
298 |
299 | // Equals returns true if this packet is equal to the given packet, false if not
300 | func (p PubRec) Equals(other interface{}) bool {
301 | return p == other
302 | }
303 |
304 | // ID returns the packet ID
305 | func (p PubRec) ID() uint16 {
306 | return uint16(p)
307 | }
308 |
309 | // String returns a brief string representation of the packet. Suitable for logging
310 | func (p PubRec) String() string {
311 | return fmt.Sprintf("PUBREC (m%d)", p.ID())
312 | }
313 |
314 | // Write writes the MQTT bits of this packet on the given Writer
315 | func (p PubRec) Write(w *mqtt.Writer) {
316 | w.WriteU8(TpPubRec)
317 | w.WriteU8(2)
318 | w.WriteU16(uint16(p))
319 | }
320 |
321 | // The PubRel type represents the MQTT PUBREL packet
322 | type PubRel uint16
323 |
324 | // ParsePubRel parses a PUBREL packet
325 | func ParsePubRel(r *mqtt.Reader, _ byte, pkLen int) (Packet, error) {
326 | if pkLen != 2 {
327 | return PubRel(0), errors.New("malformed PUBREL")
328 | }
329 | id, err := r.ReadUint16()
330 | return PubRel(id), err
331 | }
332 |
333 | // Equals returns true if this packet is equal to the given packet, false if not
334 | func (p PubRel) Equals(other interface{}) bool {
335 | return p == other
336 | }
337 |
338 | // ID returns the packet ID
339 | func (p PubRel) ID() uint16 {
340 | return uint16(p)
341 | }
342 |
343 | // String returns a brief string representation of the packet. Suitable for logging
344 | func (p PubRel) String() string {
345 | return fmt.Sprintf("PUBREL (m%d)", p.ID())
346 | }
347 |
348 | // Write writes the MQTT bits of this packet on the given Writer
349 | func (p PubRel) Write(w *mqtt.Writer) {
350 | w.WriteU8(TpPubRel)
351 | w.WriteU8(2)
352 | w.WriteU16(uint16(p))
353 | }
354 |
355 | // The PubComp type represents the MQTT PUBCOMP packet
356 | type PubComp uint16
357 |
358 | // ParsePubComp parses a PUBCOMP packet
359 | func ParsePubComp(r *mqtt.Reader, _ byte, pkLen int) (Packet, error) {
360 | if pkLen != 2 {
361 | return PubComp(0), errors.New("malformed PUBCOMP")
362 | }
363 | id, err := r.ReadUint16()
364 | return PubComp(id), err
365 | }
366 |
367 | // Equals returns true if this packet is equal to the given packet, false if not
368 | func (p PubComp) Equals(other interface{}) bool {
369 | return p == other
370 | }
371 |
372 | // ID returns the packet ID
373 | func (p PubComp) ID() uint16 {
374 | return uint16(p)
375 | }
376 |
377 | // String returns a brief string representation of the packet. Suitable for logging
378 | func (p PubComp) String() string {
379 | return fmt.Sprintf("PUBCOMP (m%d)", p.ID())
380 | }
381 |
382 | // Write writes the MQTT bits of this packet on the given Writer
383 | func (p PubComp) Write(w *mqtt.Writer) {
384 | w.WriteU8(TpPubComp)
385 | w.WriteU8(2)
386 | w.WriteU16(uint16(p))
387 | }
388 |
--------------------------------------------------------------------------------
/mqtt/pkg/publish_test.go:
--------------------------------------------------------------------------------
1 | package pkg_test
2 |
3 | import (
4 | "testing"
5 |
6 | "github.com/tada/jsonstream"
7 | "github.com/tada/mqtt-nats/mqtt/pkg"
8 | )
9 |
10 | func TestParsePublish(t *testing.T) {
11 | writeReadAndCompare(t, pkg.NewPublish(23, "some/topic", 2, []byte(`the "message"`), false, ""),
12 | "PUBLISH (d0, q1, r0, m23, 'some/topic', ... (13 bytes))")
13 | }
14 |
15 | func TestParsePubAck(t *testing.T) {
16 | writeReadAndCompare(t, pkg.PubAck(23), "PUBACK (m23)")
17 | }
18 |
19 | func TestParsePubRec(t *testing.T) {
20 | writeReadAndCompare(t, pkg.PubRec(23), "PUBREC (m23)")
21 | }
22 |
23 | func TestParsePubRel(t *testing.T) {
24 | writeReadAndCompare(t, pkg.PubRel(23), "PUBREL (m23)")
25 | }
26 |
27 | func TestParsePubComp(t *testing.T) {
28 | writeReadAndCompare(t, pkg.PubComp(23), "PUBCOMP (m23)")
29 | }
30 |
31 | func TestPublish_MarshalToJSON(t *testing.T) {
32 | p1 := pkg.NewPublish(23, "some/topic", 2, []byte(`the "message"`), false, "")
33 | bs, err := jsonstream.Marshal(p1)
34 | if err != nil {
35 | t.Fatal(err)
36 | }
37 |
38 | p2 := &pkg.Publish{}
39 | err = jsonstream.Unmarshal(p2, bs)
40 | if err != nil {
41 | t.Fatal(err)
42 | }
43 | if !p1.Equals(p2) {
44 | t.Fatal(p1, "!=", p2)
45 | }
46 | }
47 |
48 | func TestPublish_MarshalToJSON_nonUTF(t *testing.T) {
49 | p1 := pkg.NewPublish(23, "some/topic", 2, []byte{0, 1, 2, 3, 5}, false, "")
50 | bs, err := jsonstream.Marshal(p1)
51 | if err != nil {
52 | t.Fatal(err)
53 | }
54 |
55 | p2 := &pkg.Publish{}
56 | err = jsonstream.Unmarshal(p2, bs)
57 | if err != nil {
58 | t.Fatal(err)
59 | }
60 | if !p1.Equals(p2) {
61 | t.Fatal(p1, "!=", p2)
62 | }
63 | }
64 |
--------------------------------------------------------------------------------
/mqtt/pkg/subscribe.go:
--------------------------------------------------------------------------------
1 | package pkg
2 |
3 | import (
4 | "bytes"
5 | "errors"
6 | "strconv"
7 |
8 | "github.com/tada/mqtt-nats/mqtt"
9 | )
10 |
// Topic is an MQTT Topic subscription with name and desired quality of service.
type Topic struct {
	// Name is the Topic Name
	Name string

	// QoS is the Quality of Service, will be 0, 1, or 2. ParseSubscribe rejects larger
	// values.
	QoS byte
}
19 |
// Subscribe is the MQTT subscribe packet
type Subscribe struct {
	id     uint16  // MQTT packet identifier
	topics []Topic // topics to subscribe to, in packet order
}
25 |
// fixedSubscribeFlags is the value that the low four bits of the first byte of a
// SUBSCRIBE packet must have. ParseSubscribe rejects any other value and Write always
// emits it.
const fixedSubscribeFlags = 2
27 |
28 | // NewSubscribe creates a new MQTT subscribe packet
29 | func NewSubscribe(id uint16, topics ...Topic) *Subscribe {
30 | return &Subscribe{id: id, topics: topics}
31 | }
32 |
33 | // ParseSubscribe parses the subscribe packet from the given reader.
34 | func ParseSubscribe(r *mqtt.Reader, b byte, pkLen int) (Packet, error) {
35 | if (b & 0xf) != fixedSubscribeFlags {
36 | return nil, errors.New("malformed subscribe header")
37 | }
38 |
39 | var err error
40 | if r, err = r.ReadPacket(pkLen); err != nil {
41 | return nil, err
42 | }
43 |
44 | sp := &Subscribe{}
45 | if sp.id, err = r.ReadUint16(); err != nil {
46 | return nil, err
47 | }
48 |
49 | for r.Len() > 0 {
50 | t := Topic{}
51 | if t.Name, err = r.ReadString(); err != nil {
52 | return nil, err
53 | }
54 | if t.QoS, err = r.ReadByte(); err != nil {
55 | return nil, err
56 | }
57 | if t.QoS > 2 {
58 | return nil, errors.New("malformed subscribed topic QoS")
59 | }
60 | sp.topics = append(sp.topics, t)
61 | }
62 | return sp, nil
63 | }
64 |
// ID returns the MQTT Packet Identifier of this SUBSCRIBE packet.
func (s *Subscribe) ID() uint16 {
	return s.id
}
69 |
70 | // Equals returns true if this packet is equal to the given packet, false if not
71 | func (s *Subscribe) Equals(other interface{}) bool {
72 | if os, ok := other.(*Subscribe); ok && s.id == os.id && len(s.topics) == len(os.topics) {
73 | for i := range s.topics {
74 | if s.topics[i] != os.topics[i] {
75 | return false
76 | }
77 | }
78 | return true
79 | }
80 | return false
81 | }
82 |
83 | // String returns a brief string representation of the packet. Suitable for logging
84 | func (s *Subscribe) String() string {
85 | bs := bytes.NewBufferString("SUBSCRIBE (m")
86 | bs.WriteString(strconv.Itoa(int(s.ID())))
87 | bs.WriteString(", ")
88 | wt := func(t Topic) {
89 | bs.WriteByte('q')
90 | bs.WriteString(strconv.Itoa(int(t.QoS)))
91 | bs.WriteString(", '")
92 | bs.WriteString(t.Name)
93 | bs.WriteByte('\'')
94 | }
95 | if len(s.topics) != 1 {
96 | bs.WriteByte('[')
97 | for i, t := range s.topics {
98 | if i > 0 {
99 | bs.WriteString(", ")
100 | }
101 | bs.WriteByte('(')
102 | wt(t)
103 | bs.WriteByte(')')
104 | }
105 | bs.WriteByte(']')
106 | } else {
107 | wt(s.topics[0])
108 | }
109 | bs.WriteByte(')')
110 | return bs.String()
111 | }
112 |
// Topics returns the list of topics to subscribe to. The returned slice is the packet's
// own slice, not a copy.
func (s *Subscribe) Topics() []Topic {
	return s.topics
}
117 |
118 | // Write writes the MQTT bits of this packet on the given Writer
119 | func (s *Subscribe) Write(w *mqtt.Writer) {
120 | pkLen := 2 // id
121 | for i := range s.topics {
122 | pkLen += 3 + len(s.topics[i].Name)
123 | }
124 | w.WriteU8(TpSubscribe | fixedSubscribeFlags)
125 | w.WriteVarInt(pkLen)
126 | w.WriteU16(s.id)
127 | for i := range s.topics {
128 | t := s.topics[i]
129 | w.WriteString(t.Name)
130 | w.WriteU8(t.QoS)
131 | }
132 | }
133 |
// SubAck is the MQTT SUBACK packet sent in response to a SUBSCRIBE
type SubAck struct {
	id           uint16 // packet identifier
	topicReturns []byte // one return value per subscribed topic
}
139 |
140 | // NewSubAck creates an SUBACK packet
141 | func NewSubAck(id uint16, topicReturns ...byte) *SubAck {
142 | return &SubAck{id: id, topicReturns: topicReturns}
143 | }
144 |
145 | // ParseSubAck parses a SUBACK packet
146 | func ParseSubAck(r *mqtt.Reader, _ byte, pkLen int) (Packet, error) {
147 | var err error
148 | if r, err = r.ReadPacket(pkLen); err != nil {
149 | return nil, err
150 | }
151 | s := &SubAck{}
152 | if s.id, err = r.ReadUint16(); err != nil {
153 | return nil, err
154 | }
155 | if s.topicReturns, err = r.ReadExact(pkLen - 2); err != nil {
156 | return nil, err
157 | }
158 | return s, nil
159 | }
160 |
161 | // Equals returns true if this packet is equal to the given packet, false if not
162 | func (s *SubAck) Equals(other interface{}) bool {
163 | os, ok := other.(*SubAck)
164 | return ok && s.id == os.id && bytes.Equal(s.topicReturns, os.topicReturns)
165 | }
166 |
167 | // ID returns the packet ID
168 | func (s *SubAck) ID() uint16 {
169 | return s.id
170 | }
171 |
172 | // String returns a brief string representation of the packet. Suitable for logging
173 | func (s *SubAck) String() string {
174 | bs := bytes.NewBufferString("SUBACK (m")
175 | bs.WriteString(strconv.Itoa(int(s.ID())))
176 | bs.WriteString(", ")
177 | if len(s.topicReturns) != 1 {
178 | bs.WriteByte('[')
179 | for i, t := range s.topicReturns {
180 | if i > 0 {
181 | bs.WriteString(", ")
182 | }
183 | bs.WriteString("rc")
184 | bs.WriteString(strconv.Itoa(int(t)))
185 | }
186 | bs.WriteByte(']')
187 | } else {
188 | bs.WriteString("rc")
189 | bs.WriteString(strconv.Itoa(int(s.topicReturns[0])))
190 | }
191 | bs.WriteByte(')')
192 | return bs.String()
193 | }
194 |
195 | // TopicReturns returns the desired QoS value for each subscribed topic
196 | func (s *SubAck) TopicReturns() []byte {
197 | return s.topicReturns
198 | }
199 |
200 | // Write writes the MQTT bits of this packet on the given Writer
201 | func (s *SubAck) Write(w *mqtt.Writer) {
202 | w.WriteU8(TpSubAck)
203 | w.WriteVarInt(2 + len(s.topicReturns))
204 | w.WriteU16(s.id)
205 | _, _ = w.Write(s.topicReturns)
206 | }
207 |
--------------------------------------------------------------------------------
/mqtt/pkg/subscribe_test.go:
--------------------------------------------------------------------------------
1 | package pkg_test
2 |
3 | import (
4 | "bytes"
5 | "testing"
6 |
7 | "github.com/tada/mqtt-nats/test/packet"
8 |
9 | "github.com/tada/mqtt-nats/mqtt/pkg"
10 |
11 | "github.com/tada/mqtt-nats/mqtt"
12 | "github.com/tada/mqtt-nats/test/utils"
13 | )
14 |
15 | func TestParseSubscribe(t *testing.T) {
16 | writeReadAndCompare(t, pkg.NewSubscribe(23, pkg.Topic{Name: "some/topic", QoS: 1}),
17 | "SUBSCRIBE (m23, q1, 'some/topic')")
18 | writeReadAndCompare(t, pkg.NewSubscribe(23, pkg.Topic{Name: "some/topic", QoS: 0}, pkg.Topic{Name: "some/other"}),
19 | "SUBSCRIBE (m23, [(q0, 'some/topic'), (q0, 'some/other')])")
20 | }
21 |
22 | func TestParseSubscribe_badFlags(t *testing.T) {
23 | utils.EnsureFailed(t, func(st *testing.T) {
24 | packet.Parse(st, mqtt.NewReader(bytes.NewReader([]byte{pkg.TpSubscribe | 1, 0})))
25 | })
26 | }
27 |
28 | func TestParseSubscribe_badLen(t *testing.T) {
29 | _, err := pkg.ParseSubscribe(mqtt.NewReader(bytes.NewReader([]byte{})), pkg.TpSubscribe|2, 8)
30 | utils.CheckError(err, t)
31 | }
32 |
33 | func TestParseSubscribe_badId(t *testing.T) {
34 | w := mqtt.NewWriter()
35 | w.WriteU16(28)
36 | _, err := pkg.ParseSubscribe(mqtt.NewReader(bytes.NewReader([]byte{1})), pkg.TpConnAck|2, 1)
37 | utils.CheckError(err, t)
38 | }
39 |
40 | func TestParseSubscribe_badCount(t *testing.T) {
41 | w := mqtt.NewWriter()
42 | w.WriteU16(28)
43 | _, err := pkg.ParseSubscribe(mqtt.NewReader(bytes.NewReader([]byte{0, 2, 3})), pkg.TpConnAck|2, 3)
44 | utils.CheckError(err, t)
45 | }
46 |
47 | func TestParseSubscribe_badTopic(t *testing.T) {
48 | w := mqtt.NewWriter()
49 | w.WriteU16(28)
50 | w.WriteU16(1)
51 | _, err := pkg.ParseSubscribe(mqtt.NewReader(bytes.NewReader(w.Bytes())), pkg.TpConnAck|2, 4)
52 | utils.CheckError(err, t)
53 | }
54 |
55 | func TestParseSubscribe_badQoS(t *testing.T) {
56 | w := mqtt.NewWriter()
57 | w.WriteU16(28)
58 | w.WriteString("tpc")
59 | _, err := pkg.ParseSubscribe(mqtt.NewReader(bytes.NewReader(w.Bytes())), pkg.TpConnAck|2, 7)
60 | utils.CheckError(err, t)
61 | }
62 |
63 | func TestParseSubscribe_invalidQoS(t *testing.T) {
64 | w := mqtt.NewWriter()
65 | w.WriteU16(28)
66 | w.WriteString("tpc")
67 | w.WriteU8(3)
68 | _, err := pkg.ParseSubscribe(mqtt.NewReader(bytes.NewReader(w.Bytes())), pkg.TpConnAck|2, 8)
69 | utils.CheckError(err, t)
70 | }
71 |
72 | func TestSubscribe_Equals(t *testing.T) {
73 | a := pkg.NewSubscribe(32, pkg.Topic{Name: "a"}, pkg.Topic{Name: "b", QoS: 1})
74 | b := pkg.NewSubscribe(32, pkg.Topic{Name: "a"}, pkg.Topic{Name: "b", QoS: 1})
75 | utils.CheckTrue(a.Equals(b), t)
76 |
77 | b = pkg.NewSubscribe(32, pkg.Topic{Name: "a"}, pkg.Topic{Name: "b", QoS: 2})
78 | utils.CheckFalse(a.Equals(b), t)
79 |
80 | c := pkg.NewSubAck(32, 0, 2)
81 | utils.CheckFalse(a.Equals(c), t)
82 | }
83 |
84 | func TestParseSubAck(t *testing.T) {
85 | writeReadAndCompare(t, pkg.NewSubAck(23, 1), "SUBACK (m23, rc1)")
86 | writeReadAndCompare(t, pkg.NewSubAck(23, 1, 1), "SUBACK (m23, [rc1, rc1])")
87 | }
88 |
89 | func TestSubAck_Equals(t *testing.T) {
90 | a := pkg.NewSubAck(32, 1, 2, 0)
91 | b := pkg.NewSubAck(32, 1, 2, 0)
92 | utils.CheckTrue(a.Equals(b), t)
93 |
94 | b = pkg.NewSubAck(32, 1, 2, 1)
95 | utils.CheckFalse(a.Equals(b), t)
96 |
97 | c := pkg.NewSubscribe(32, pkg.Topic{Name: "a"}, pkg.Topic{Name: "b", QoS: 1})
98 | utils.CheckFalse(a.Equals(c), t)
99 | }
100 |
--------------------------------------------------------------------------------
/mqtt/pkg/unsubscribe.go:
--------------------------------------------------------------------------------
1 | package pkg
2 |
3 | import (
4 | "bytes"
5 | "errors"
6 | "fmt"
7 | "strconv"
8 |
9 | "github.com/tada/mqtt-nats/mqtt"
10 | )
11 |
12 | // Unsubscribe is the MQTT UNSUBSCRIBE packet
13 | type Unsubscribe struct {
14 | id uint16
15 | topics []string
16 | }
17 |
18 | // NewUnsubscribe creates a new Unsubscribe packet
19 | func NewUnsubscribe(id uint16, topics ...string) *Unsubscribe {
20 | return &Unsubscribe{id: id, topics: topics}
21 | }
22 |
23 | // ParseUnsubscribe parses the unsubscribe packet from the given reader.
24 | func ParseUnsubscribe(r *mqtt.Reader, b byte, pkLen int) (Packet, error) {
25 | if (b & 0xf) != 2 {
26 | return nil, errors.New("malformed unsubscribe header")
27 | }
28 |
29 | var err error
30 | if r, err = r.ReadPacket(pkLen); err != nil {
31 | return nil, err
32 | }
33 |
34 | up := &Unsubscribe{}
35 | if up.id, err = r.ReadUint16(); err != nil {
36 | return nil, err
37 | }
38 |
39 | for r.Len() > 0 {
40 | var name string
41 | if name, err = r.ReadString(); err != nil {
42 | return nil, err
43 | }
44 | up.topics = append(up.topics, name)
45 | }
46 | return up, nil
47 | }
48 |
49 | // Write writes the MQTT bits of this packet on the given Writer
50 | func (u *Unsubscribe) Write(w *mqtt.Writer) {
51 | pkLen := 2 // packet id
52 | tps := u.topics
53 | for i := range tps {
54 | pkLen += 2 + len(tps[i])
55 | }
56 | w.WriteU8(TpUnsubscribe | 2)
57 | w.WriteVarInt(pkLen)
58 | w.WriteU16(u.id)
59 | for i := range tps {
60 | w.WriteString(tps[i])
61 | }
62 | }
63 |
64 | // Equals returns true if this packet is equal to the given packet, false if not
65 | func (u *Unsubscribe) Equals(other interface{}) bool {
66 | if os, ok := other.(*Unsubscribe); ok && u.id == os.id && len(u.topics) == len(os.topics) {
67 | for i := range u.topics {
68 | if u.topics[i] != os.topics[i] {
69 | return false
70 | }
71 | }
72 | return true
73 | }
74 | return false
75 | }
76 |
77 | // ID returns the MQTT Packet Identifier
78 | func (u *Unsubscribe) ID() uint16 {
79 | return u.id
80 | }
81 |
82 | // String returns a brief string representation of the packet. Suitable for logging
83 | func (u *Unsubscribe) String() string {
84 | bs := bytes.NewBufferString("UNSUBSCRIBE (m")
85 | bs.WriteString(strconv.Itoa(int(u.ID())))
86 | bs.WriteString(", [")
87 | for i, t := range u.topics {
88 | if i > 0 {
89 | bs.WriteString(", ")
90 | }
91 | bs.WriteByte('\'')
92 | bs.WriteString(t)
93 | bs.WriteByte('\'')
94 | }
95 | bs.WriteString("])")
96 | return bs.String()
97 | }
98 |
99 | // Topics returns the list of topics to subscribe to
100 | func (u *Unsubscribe) Topics() []string {
101 | return u.topics
102 | }
103 |
104 | // UnsubAck is the MQTT UNSUBACK packet
105 | type UnsubAck uint16
106 |
107 | // ParseUnsubAck parses the unsubscribe packet from the given reader.
108 | func ParseUnsubAck(r *mqtt.Reader, b byte, pkLen int) (Packet, error) {
109 | if pkLen != 2 {
110 | return UnsubAck(0), errors.New("malformed UNSUBACK")
111 | }
112 | id, err := r.ReadUint16()
113 | return UnsubAck(id), err
114 | }
115 |
116 | // ID returns the packet ID
117 | func (u UnsubAck) ID() uint16 {
118 | return uint16(u)
119 | }
120 |
121 | // Equals returns true if this packet is equal to the given packet, false if not
122 | func (u UnsubAck) Equals(other interface{}) bool {
123 | return u == other
124 | }
125 |
126 | // String returns a brief string representation of the packet. Suitable for logging
127 | func (u UnsubAck) String() string {
128 | return fmt.Sprintf("UNSUBACK (m%d)", u.ID())
129 | }
130 |
131 | // Write writes the MQTT bits of this packet on the given Writer
132 | func (u UnsubAck) Write(w *mqtt.Writer) {
133 | w.WriteU8(TpUnsubAck)
134 | w.WriteU8(2)
135 | w.WriteU16(uint16(u))
136 | }
137 |
--------------------------------------------------------------------------------
/mqtt/pkg/unsubscribe_test.go:
--------------------------------------------------------------------------------
1 | package pkg_test
2 |
3 | import (
4 | "testing"
5 |
6 | "github.com/tada/mqtt-nats/mqtt/pkg"
7 | )
8 |
9 | func TestParseUnsubscribe(t *testing.T) {
10 | writeReadAndCompare(t, pkg.NewUnsubscribe(23, "some/topic", "some/other"), "UNSUBSCRIBE (m23, ['some/topic', 'some/other'])")
11 | }
12 |
13 | func TestParseUnsubAck(t *testing.T) {
14 | writeReadAndCompare(t, pkg.UnsubAck(23), "UNSUBACK (m23)")
15 | }
16 |
--------------------------------------------------------------------------------
/mqtt/pkg/will.go:
--------------------------------------------------------------------------------
1 | package pkg
2 |
3 | import (
4 | "bytes"
5 | "fmt"
6 | )
7 |
// Will is the optional client will in the MQTT connect packet
type Will struct {
	Topic   string // topic that the will message is published on
	Message []byte // payload of the will message
	QoS     byte   // quality of service level of the will message
	Retain  bool   // retain flag of the will message
}

// Equals returns true if this instance is equal to the given instance, false if not.
//
// Two wills are equal when their retain flag, QoS, topic, and message are equal. The
// method is safe to call with a nil receiver or a nil argument: two nil wills are equal
// and a nil will never equals a non-nil one.
func (w *Will) Equals(ow *Will) bool {
	if w == nil || ow == nil {
		// Avoid dereferencing nil; equal only when both sides are nil.
		return w == ow
	}
	return w.Retain == ow.Retain &&
		w.QoS == ow.QoS &&
		w.Topic == ow.Topic &&
		bytes.Equal(w.Message, ow.Message)
}

// String returns a brief string representation of the will. Suitable for logging. The
// message itself is not included, only its length in bytes.
func (w *Will) String() string {
	retain := 0
	if w.Retain {
		retain = 1
	}
	return fmt.Sprintf("w(r%d, q%d, '%s', ... (%d bytes))", retain, w.QoS, w.Topic, len(w.Message))
}
29 |
--------------------------------------------------------------------------------
/mqtt/reader.go:
--------------------------------------------------------------------------------
1 | package mqtt
2 |
3 | import (
4 | "bytes"
5 | "encoding/binary"
6 | "errors"
7 | "fmt"
8 | "io"
9 | )
10 |
// Reader extends the io.Reader with MQTT specific semantics for reading variable length integers,
// two byte unsigned integers, and length prefixed strings and bytes.
type Reader struct {
	io.Reader
}

// NewReader creates a new Reader that reads from the given io.Reader
func NewReader(r io.Reader) *Reader {
	return &Reader{r}
}

// ReadByte reads and returns the next byte from the input or
// any error encountered. If ReadByte returns an error, no input
// byte was consumed, and the returned byte value is undefined.
//
// A read from the underlying reader that yields zero bytes and a nil error is retried
// rather than reported as a successfully read zero byte.
func (r *Reader) ReadByte() (byte, error) {
	var b [1]byte
	// io.ReadFull keeps reading until the byte is filled or an error occurs.
	if _, err := io.ReadFull(r.Reader, b[:]); err != nil {
		return 0, err
	}
	return b[0], nil
}

// ReadVarInt returns the next variable size unsigned integer from the input stream.
// A "malformed compressed int" error is returned if the value is larger than 0x0FFFFFFF.
//
// An io.ErrUnexpectedEOF is returned if EOF is encountered during the read.
func (r *Reader) ReadVarInt() (int, error) {
	multiplier := 1
	value := 0
	for {
		b, err := r.ReadByte()
		if err != nil {
			if err == io.EOF {
				err = io.ErrUnexpectedEOF
			}
			return 0, err
		}
		// The lower seven bits carry data, the top bit says that more bytes follow.
		value += int(b&0x7f) * multiplier
		if b&0x80 == 0 {
			return value, nil
		}
		multiplier *= 0x80
		if multiplier > 0x200000 {
			// A fourth byte with the continuation bit set
			return 0, errors.New("malformed compressed int")
		}
	}
}

// ReadUint16 reads the next two bytes from the input stream and returns a big endian unsigned integer
//
// An io.ErrUnexpectedEOF is returned if EOF is encountered during the read.
func (r *Reader) ReadUint16() (uint16, error) {
	bs, err := r.ReadExact(2)
	if err != nil {
		return 0, err
	}
	return binary.BigEndian.Uint16(bs), nil
}

// ReadString reads a big endian uint16 from the stream that denotes the number of bytes
// that will follow. It then reads those bytes and returns them as a string. The bytes are
// not validated.
//
// An io.ErrUnexpectedEOF is returned if EOF is encountered during the read.
func (r *Reader) ReadString() (string, error) {
	bs, err := r.ReadBytes()
	if err != nil {
		return "", err
	}
	return string(bs), nil
}

// ReadBytes reads a big endian uint16 from the stream that denotes the number of bytes
// that will follow. It then reads those bytes and returns them. A nil slice is returned
// when the length is zero.
//
// An io.ErrUnexpectedEOF is returned if EOF is encountered during the read.
func (r *Reader) ReadBytes() ([]byte, error) {
	n, err := r.ReadUint16()
	if err != nil || n == 0 {
		return nil, err
	}
	return r.ReadExact(int(n))
}

// ReadExact reads an exact number of bytes into a []byte slice and returns it. The
// returned slice is nil when an error is returned.
//
// An error is returned if n is negative. An io.ErrUnexpectedEOF is returned if EOF is
// encountered during the read.
func (r *Reader) ReadExact(n int) ([]byte, error) {
	if n < 0 {
		// make would panic on a negative length; report it as an error instead.
		return nil, errors.New("negative read length")
	}
	bs := make([]byte, n)
	if _, err := io.ReadFull(r, bs); err != nil {
		if err == io.EOF {
			err = io.ErrUnexpectedEOF
		}
		return nil, err
	}
	return bs, nil
}

// Len returns the number of bytes of the unread portion of the slice. This method will panic
// unless the underlying reader is a bytes.Reader.
func (r *Reader) Len() int {
	br, ok := r.Reader.(*bytes.Reader)
	if !ok {
		// Reader was not set up to read remaining length
		panic(fmt.Errorf("unsupported operation on %T: Len", r.Reader))
	}
	return br.Len()
}

// ReadRemainingBytes returns the remaining bytes of the underlying reader. This method will panic
// unless the underlying reader is a bytes.Reader.
func (r *Reader) ReadRemainingBytes() ([]byte, error) {
	return r.ReadExact(r.Len())
}

// ReadPacket reads exactly pkLen bytes from the underlying reader into a byte slice and creates a new Reader
// that will read from this slice. The new Reader is returned. It is nil when an error is returned.
//
// An io.ErrUnexpectedEOF is returned if EOF is encountered during the read.
func (r *Reader) ReadPacket(pkLen int) (*Reader, error) {
	// Do a bulk read and switch to read from that bulk
	pk, err := r.ReadExact(pkLen)
	if err != nil {
		return nil, err
	}
	return &Reader{bytes.NewReader(pk)}, nil
}
143 |
--------------------------------------------------------------------------------
/mqtt/reader_test.go:
--------------------------------------------------------------------------------
1 | package mqtt
2 |
3 | import (
4 | "bytes"
5 | "io"
6 | "testing"
7 | )
8 |
// noLen is a reader without data. Its only method is Read, so it offers no Len method.
type noLen int

// Read never delivers any bytes and always reports io.EOF.
func (noLen) Read([]byte) (int, error) {
	return 0, io.EOF
}
14 |
15 | func TestReader_Len(t *testing.T) {
16 | l := NewReader(bytes.NewReader([]byte{1, 2, 3})).Len()
17 | if l != 3 {
18 | t.Fatalf("expected length 3 got %d", l)
19 | }
20 | defer func() {
21 | r := recover()
22 | if r == nil {
23 | t.Fatal("expected panic")
24 | }
25 | }()
26 | NewReader(noLen(0)).Len()
27 | }
28 |
29 | func TestReader_ReadBytes(t *testing.T) {
30 | r := NewReader(bytes.NewReader([]byte{0, 2, 'a', 'b'}))
31 | bs, err := r.ReadBytes()
32 | if err != nil {
33 | t.Fatal(err)
34 | }
35 | sbs := string(bs)
36 | if sbs != "ab" {
37 | t.Fatalf(`expected "ab", got %q`, sbs)
38 | }
39 |
40 | // test premature EOF
41 | r = NewReader(bytes.NewReader([]byte{0, 3, 'a', 'b'}))
42 | _, err = r.ReadBytes()
43 | if err != io.ErrUnexpectedEOF {
44 | t.Fatal("expected error")
45 | }
46 |
47 | // test premature EOF nothing after length
48 | r = NewReader(bytes.NewReader([]byte{0, 3}))
49 | _, err = r.ReadBytes()
50 | if err != io.ErrUnexpectedEOF {
51 | t.Fatal("expected error")
52 | }
53 | }
54 |
55 | func TestReader_ReadVarInt(t *testing.T) {
56 | r := NewReader(bytes.NewReader([]byte{0x82, 0xff, 0x3}))
57 | l, err := r.ReadVarInt()
58 | if err != nil {
59 | t.Fatal(err)
60 | }
61 | if l != 0xff82 {
62 | t.Fatalf("expected length 0xff82 got 0x%x", l)
63 | }
64 | _, err = NewReader(bytes.NewReader([]byte{0xff, 0xff, 0xff, 0xff, 0xff})).ReadVarInt()
65 | if err == nil {
66 | t.Fatal("expected error")
67 | }
68 | _, err = NewReader(bytes.NewReader([]byte{0x80})).ReadVarInt()
69 | if err != io.ErrUnexpectedEOF {
70 | t.Fatal("expected error")
71 | }
72 | }
73 |
--------------------------------------------------------------------------------
/mqtt/topic.go:
--------------------------------------------------------------------------------
1 | package mqtt
2 |
3 | import (
4 | "io"
5 | "regexp"
6 | "strings"
7 | )
8 |
const (
	// Characters that are translated between MQTT topics and NATS subjects.
	dot   = rune('.')
	slash = rune('/')
	star  = rune('*')
	plus  = rune('+')
	hash  = rune('#')
	gt    = rune('>')

	// matchSegment is the regular expression that stands in for the MQTT single level
	// wildcard '+'. It matches one or more characters that are not a slash. The caret
	// must be inside the brackets: placed in front of them it is a start-of-text anchor
	// and the expression would only match a run of slashes at the very beginning.
	matchSegment = `[^/]+`

	// matchRest is the regular expression that stands in for the MQTT multi level
	// wildcard '#'. It matches everything that remains.
	matchRest = `.*`
)
19 |
20 | // SubscriptionToRegexp converts an MQTT topic subscription into a regular expression that can be
21 | // used to match topics.
22 | func SubscriptionToRegexp(s string) *regexp.Regexp {
23 | w := strings.Builder{}
24 | for i, p := range strings.Split(s, "/") {
25 | if i > 0 {
26 | _ = w.WriteByte('/')
27 | }
28 | if len(p) == 1 {
29 | switch rune(p[0]) {
30 | case plus:
31 | _, _ = w.WriteString(matchSegment)
32 | continue
33 | case hash:
34 | _, _ = w.WriteString(matchRest)
35 | continue
36 | }
37 | }
38 | _, _ = w.WriteString(regexp.QuoteMeta(p))
39 | }
40 | return regexp.MustCompile(w.String())
41 | }
42 |
43 | // ToNATS converts an MQTT topic to a NATS subject. The following conversions take place
44 | //
45 | // dots become slashes
46 | // slashes become dots
47 | func ToNATS(mqttTopic string) string {
48 | r := strings.NewReader(mqttTopic)
49 | w := strings.Builder{}
50 | for {
51 | c, _, err := r.ReadRune()
52 | if err == io.EOF {
53 | return w.String()
54 | }
55 | switch c {
56 | case dot:
57 | c = slash
58 | case slash:
59 | c = dot
60 | }
61 | _, _ = w.WriteRune(c)
62 | }
63 | }
64 |
65 | // ToNATSSubscription converts the given MQTT subscription into a NATS subscription
66 | func ToNATSSubscription(mqttSub string) string {
67 | r := strings.NewReader(mqttSub)
68 | w := strings.Builder{}
69 | for {
70 | c, _, err := r.ReadRune()
71 | if err == io.EOF {
72 | return w.String()
73 | }
74 | switch c {
75 | case dot:
76 | c = slash
77 | case slash:
78 | c = dot
79 | case star:
80 | c = plus
81 | case plus:
82 | c = star
83 | case hash:
84 | c = gt
85 | case gt:
86 | c = hash
87 | }
88 | _, _ = w.WriteRune(c)
89 | }
90 | }
91 |
92 | // FromNATS converts an MATS subject to a MQTT topic. The following conversions take place
93 | //
94 | // dots become slashes
95 | // slashes become dots
96 | func FromNATS(natsSubject string) string {
97 | // exact same conversion but opposite direction, at least for now
98 | return ToNATS(natsSubject)
99 | }
100 |
101 | // FromNATSSubscription converts the given NATS subscription into a MQTT subscription
102 | func FromNATSSubscription(natsSubject string) string {
103 | // exact same conversion but opposite direction, at least for now
104 | return ToNATSSubscription(natsSubject)
105 | }
106 |
--------------------------------------------------------------------------------
/mqtt/topic_test.go:
--------------------------------------------------------------------------------
1 | package mqtt
2 |
3 | import (
4 | "testing"
5 | )
6 |
7 | func TestFromNATS(t *testing.T) {
8 | tests := []struct {
9 | name string
10 | nats string
11 | want string
12 | }{{
13 | name: "Slash is dot",
14 | nats: "a/b/c",
15 | want: "a.b.c",
16 | }, {
17 | name: "Dot is slash",
18 | nats: "a.b.c",
19 | want: "a/b/c",
20 | },
21 | }
22 | for i := range tests {
23 | tt := tests[i]
24 | t.Run(tt.name, func(t *testing.T) {
25 | if got := FromNATS(tt.nats); got != tt.want {
26 | t.Errorf("FromNATS() = %v, want %v", got, tt.want)
27 | }
28 | })
29 | }
30 | }
31 |
32 | func TestFromNATSSubscription(t *testing.T) {
33 | tests := []struct {
34 | name string
35 | nats string
36 | want string
37 | }{{
38 | name: "Slash is dot",
39 | nats: "a/b/c",
40 | want: "a.b.c",
41 | }, {
42 | name: "Dot is slash",
43 | nats: "a.b.c",
44 | want: "a/b/c",
45 | }, {
46 | name: "Star is plus",
47 | nats: "a.*.b",
48 | want: "a/+/b",
49 | }, {
50 | name: "Plus is star",
51 | nats: "a.+.b",
52 | want: "a/*/b",
53 | }, {
54 | name: "> is #",
55 | nats: "a.b.>",
56 | want: "a/b/#",
57 | }, {
58 | name: "# is >",
59 | nats: "a.b.#",
60 | want: "a/b/>",
61 | }}
62 | for i := range tests {
63 | tt := tests[i]
64 | t.Run(tt.name, func(t *testing.T) {
65 | if got := FromNATSSubscription(tt.nats); got != tt.want {
66 | t.Errorf("FromNATSSubscription() = %v, want %v", got, tt.want)
67 | }
68 | })
69 | }
70 | }
71 |
--------------------------------------------------------------------------------
/mqtt/writer.go:
--------------------------------------------------------------------------------
1 | package mqtt
2 |
3 | import "bytes"
4 |
// Writer extends a bytes.Buffer with MQTT specific semantics for writing variable length integers,
// two byte unsigned integers, and length prefixed strings and bytes.
type Writer struct {
	bytes.Buffer
}

// NewWriter returns a new, empty Writer instance
func NewWriter() *Writer {
	return new(Writer)
}

// WriteU8 appends a single byte to the underlying buffer. It does the same as WriteByte
// but has no error return, since the error that WriteByte returns is always nil.
func (w *Writer) WriteU8(i uint8) {
	_ = w.WriteByte(i)
}

// WriteU16 appends the given uint16 to the underlying buffer as two bytes in big endian order
func (w *Writer) WriteU16(i uint16) {
	hi, lo := byte(i>>8), byte(i)
	_, _ = w.Write([]byte{hi, lo})
}

// WriteString appends the length of the string using WriteU16, followed by the bytes of
// the string. The length is converted to uint16 before it is written.
func (w *Writer) WriteString(s string) {
	w.WriteU16(uint16(len(s)))
	// This method shadows the embedded buffer's WriteString, hence the qualified call.
	_, _ = w.Buffer.WriteString(s)
}

// WriteBytes appends the length of the byte slice using WriteU16, followed by the bytes.
// The length is converted to uint16 before it is written.
func (w *Writer) WriteBytes(bs []byte) {
	w.WriteU16(uint16(len(bs)))
	_, _ = w.Write(bs)
}

// WriteVarInt writes a variable length integer in accordance with the MQTT 3.1.1 specification on the
// underlying buffer. Seven bits are written per byte, least significant group first, and
// the top bit of a byte is set when more bytes follow.
//
// NOTE(review): a negative value never reaches zero when shifted right, so the loop would
// not terminate. Confirm that callers only pass non-negative values.
func (w *Writer) WriteVarInt(value int) {
	for {
		low := byte(value & 0x7f)
		value >>= 7
		if value > 0 {
			low |= 0x80
		}
		w.WriteU8(low)
		if value == 0 {
			return
		}
	}
}
59 |
--------------------------------------------------------------------------------
/mqtt/writer_test.go:
--------------------------------------------------------------------------------
1 | package mqtt
2 |
3 | import (
4 | "bytes"
5 | "testing"
6 | )
7 |
8 | func TestWriteVarInt(t *testing.T) {
9 | ints := []int{
10 | 0, 1, 127, 128, 16383, 16384, 2097151, 2097152, 268435455,
11 | }
12 | lens := []int{
13 | 1, 1, 1, 2, 2, 3, 3, 4, 4,
14 | }
15 | w := NewWriter()
16 | tl := 0
17 | for i, v := range ints {
18 | w.WriteVarInt(v)
19 | tl += lens[i]
20 | if tl != w.Len() {
21 | t.Fatalf("expected len %d, got %d", tl, w.Len())
22 | }
23 | }
24 |
25 | r := Reader{bytes.NewReader(w.Bytes())}
26 | for _, v := range ints {
27 | x, _ := r.ReadVarInt()
28 | if v != x {
29 | t.Fatalf("expected %d, got %d", v, x)
30 | }
31 | }
32 | }
33 |
--------------------------------------------------------------------------------
/test/connect_test.go:
--------------------------------------------------------------------------------
1 | package test
2 |
3 | import (
4 | "bytes"
5 | "testing"
6 | "time"
7 |
8 | "github.com/tada/mqtt-nats/mqtt/pkg"
9 | "github.com/tada/mqtt-nats/test/full"
10 | )
11 |
12 | var packetIDManager = pkg.NewIDManager()
13 |
14 | func nextPacketID() uint16 {
15 | return packetIDManager.NextFreePacketID()
16 | }
17 |
18 | func TestConnect(t *testing.T) {
19 | conn := full.MqttConnect(t, mqttPort)
20 | full.MqttSend(t, conn, pkg.NewConnect(full.NextClientID(), true, 1, nil, nil))
21 | full.MqttExpect(t, conn, pkg.NewConnAck(false, 0))
22 | full.MqttDisconnect(t, conn)
23 | }
24 |
25 | func TestConnect_sessionPresent(t *testing.T) {
26 | conn := full.MqttConnect(t, mqttPort)
27 | c := pkg.NewConnect(full.NextClientID(), false, 1, nil, nil)
28 | full.MqttSend(t, conn, c)
29 | full.MqttExpect(t, conn, pkg.NewConnAck(false, 0))
30 | full.MqttDisconnect(t, conn)
31 |
32 | conn = full.MqttConnect(t, mqttPort)
33 | full.MqttSend(t, conn, c)
34 | full.MqttExpect(t, conn, pkg.NewConnAck(true, 0))
35 | full.MqttDisconnect(t, conn)
36 | }
37 |
38 | func TestConnect_will_qos_0(t *testing.T) {
39 | conn1 := full.MqttConnect(t, mqttPort)
40 | full.MqttSend(t, conn1, pkg.NewConnect(full.NextClientID(), true, 1, nil, nil))
41 | full.MqttExpect(t, conn1, pkg.NewConnAck(false, 0))
42 | mid := nextPacketID()
43 | full.MqttSend(t, conn1, pkg.NewSubscribe(mid, pkg.Topic{Name: "testing/my/will"}))
44 | full.MqttExpect(t, conn1, pkg.NewSubAck(mid, 0))
45 |
46 | conn2 := full.MqttConnect(t, mqttPort)
47 | full.MqttSend(t, conn2,
48 | pkg.NewConnect(full.NextClientID(), true, 1,
49 | &pkg.Will{
50 | Topic: "testing/my/will",
51 | Message: []byte("the will message")}, nil))
52 | full.MqttExpect(t, conn2, pkg.NewConnAck(false, 0))
53 | // forcefully close connection
54 | _ = conn2.Close()
55 |
56 | full.MqttExpect(t, conn1,
57 | func(p pkg.Packet) bool {
58 | pp, ok := p.(*pkg.Publish)
59 | return ok && pp.TopicName() == "testing/my/will" && pp.QoSLevel() == 0 && !pp.IsDup() && !pp.Retain()
60 | })
61 | full.MqttDisconnect(t, conn1)
62 | }
63 |
64 | func TestConnect_will_qos_1(t *testing.T) {
65 | conn := full.MqttConnect(t, mqttPort)
66 | full.MqttSend(t, conn,
67 | pkg.NewConnect(full.NextClientID(), true, 1, &pkg.Will{
68 | Topic: "testing/my/will",
69 | Message: []byte("the will message"),
70 | QoS: 1}, nil))
71 | full.MqttExpect(t, conn, pkg.NewConnAck(false, 0))
72 | // forcefully close connection
73 | _ = conn.Close()
74 |
75 | // Ensure that first package is wasted and a dup is published
76 | time.Sleep(10 * time.Millisecond)
77 |
78 | conn = full.MqttConnectClean(t, mqttPort)
79 | mid := nextPacketID()
80 | full.MqttSend(t, conn, pkg.NewSubscribe(mid, pkg.Topic{Name: "testing/my/will", QoS: 1}))
81 | full.MqttExpect(t, conn,
82 | pkg.NewSubAck(mid, 1),
83 | func(p pkg.Packet) bool {
84 | if pp, ok := p.(*pkg.Publish); ok {
85 | full.MqttSend(t, conn, pkg.PubAck(pp.ID()))
86 | return pp.TopicName() == "testing/my/will" && pp.QoSLevel() == 1 && pp.IsDup() && !pp.Retain()
87 | }
88 | return false
89 | })
90 | full.MqttDisconnect(t, conn)
91 | }
92 |
93 | func TestConnect_will_retain_qos_0(t *testing.T) {
94 | willTopic := "testing/my/will"
95 | c1 := full.MqttConnect(t, mqttPort)
96 | full.MqttSend(t, c1,
97 | pkg.NewConnect(full.NextClientID(), true, 5, &pkg.Will{
98 | Topic: willTopic,
99 | Message: []byte("the will message"),
100 | Retain: true}, nil))
101 | full.MqttExpect(t, c1, pkg.NewConnAck(false, 0))
102 |
103 | // forcefully close connection to make server publish will
104 | _ = c1.Close()
105 | time.Sleep(10 * time.Millisecond) // give bridge time to handle retained
106 |
107 | gotIt := make(chan bool, 1)
108 | go func() {
109 | c2 := full.MqttConnectClean(t, mqttPort)
110 | mid := nextPacketID()
111 | full.MqttSend(t, c2, pkg.NewSubscribe(mid, pkg.Topic{Name: willTopic}))
112 | full.MqttExpect(t, c2, pkg.NewSubAck(mid, 0))
113 | full.MqttExpect(t, c2, func(p pkg.Packet) bool {
114 | pp, ok := p.(*pkg.Publish)
115 | return ok && pp.TopicName() == willTopic && pp.QoSLevel() == 0 && !pp.IsDup() && pp.Retain()
116 | })
117 | gotIt <- true
118 | full.MqttDisconnect(t, c2)
119 | }()
120 |
121 | full.AssertMessageReceived(t, gotIt)
122 |
123 | // check that retained will still exists
124 | c1 = full.MqttConnectClean(t, mqttPort)
125 | mid := nextPacketID()
126 | full.MqttSend(t, c1, pkg.NewSubscribe(mid, pkg.Topic{Name: willTopic}))
127 | full.MqttExpect(t, c1,
128 | pkg.NewSubAck(mid, 0),
129 | func(p pkg.Packet) bool {
130 | pp, ok := p.(*pkg.Publish)
131 | return ok && pp.TopicName() == willTopic && pp.QoSLevel() == 0 && !pp.IsDup() && pp.Retain()
132 | })
133 |
134 | // drop the retained packet
135 | full.MqttSend(t, c1, pkg.NewPublish2(0, willTopic, []byte{}, 0, false, true))
136 | full.MqttDisconnect(t, c1)
137 | }
138 |
139 | func TestConnect_will_retain_qos_1(t *testing.T) {
140 | conn := full.MqttConnect(t, mqttPort)
141 | willTopic := "testing/my/will"
142 | willPayload := []byte("the will message")
143 | full.MqttSend(t, conn,
144 | pkg.NewConnect(full.NextClientID(), true, 5, &pkg.Will{
145 | Topic: willTopic,
146 | Message: willPayload,
147 | QoS: 1,
148 | Retain: true}, nil))
149 | full.MqttExpect(t, conn, pkg.NewConnAck(false, 0))
150 | // forcefully close connection
151 | _ = conn.Close()
152 |
153 | conn = full.MqttConnectClean(t, mqttPort)
154 | mid := nextPacketID()
155 | full.MqttSend(t, conn, pkg.NewSubscribe(mid, pkg.Topic{Name: willTopic, QoS: 1}))
156 | var ackID uint16
157 | full.MqttExpect(t, conn,
158 | pkg.NewSubAck(mid, 1),
159 | func(p pkg.Packet) bool {
160 | if pp, ok := p.(*pkg.Publish); ok {
161 | ackID = pp.ID()
162 | return pp.TopicName() == willTopic && bytes.Equal(pp.Payload(), willPayload) && pp.QoSLevel() == 1 && !pp.IsDup() && pp.Retain()
163 | }
164 | return false
165 | })
166 | full.MqttSend(t, conn, pkg.PubAck(ackID))
167 | full.MqttDisconnect(t, conn)
168 |
169 | conn = full.MqttConnectClean(t, mqttPort)
170 | mid = nextPacketID()
171 | full.MqttSend(t, conn, pkg.NewSubscribe(mid, pkg.Topic{Name: willTopic, QoS: 1}))
172 | full.MqttExpect(t, conn,
173 | pkg.NewSubAck(mid, 1),
174 | func(p pkg.Packet) bool {
175 | if pp, ok := p.(*pkg.Publish); ok {
176 | ackID = pp.ID()
177 | return pp.TopicName() == willTopic && bytes.Equal(pp.Payload(), willPayload) && pp.QoSLevel() == 1 && !pp.IsDup() && pp.Retain()
178 | }
179 | return false
180 | })
181 | full.MqttSend(t, conn, pkg.PubAck(ackID))
182 |
183 | // drop the retained packet
184 | full.MqttSend(t, conn, pkg.NewPublish2(0, willTopic, []byte{}, 1, false, true))
185 | full.MqttDisconnect(t, conn)
186 | }
187 |
// TestConnect_will_retain_qos_1_restart registers a retained QoS 1 will, drops the TCP connection without
// sending a DISCONNECT and then asserts that a fresh subscriber gets the will as a retained publish, both
// before and after the bridge has been restarted.
func TestConnect_will_retain_qos_1_restart(t *testing.T) {
	conn := full.MqttConnect(t, mqttPort)
	willTopic := "testing/my/will"
	willPayload := []byte("the will message")
	full.MqttSend(t, conn,
		pkg.NewConnect(full.NextClientID(), true, 5, &pkg.Will{
			Topic:   willTopic,
			Message: willPayload,
			QoS:     1,
			Retain:  true}, nil))
	full.MqttExpect(t, conn, pkg.NewConnAck(false, 0))
	// forcefully close connection
	_ = conn.Close()
	time.Sleep(50 * time.Millisecond) // give bridge time to publish will

	// first subscriber: expects the will as a retained, non-duplicate QoS 1 publish
	conn = full.MqttConnectClean(t, mqttPort)
	mid := nextPacketID()
	full.MqttSend(t, conn, pkg.NewSubscribe(mid, pkg.Topic{Name: willTopic, QoS: 1}))
	full.MqttExpect(t, conn, pkg.NewSubAck(mid, 1))

	full.MqttExpect(t, conn, func(p pkg.Packet) bool {
		if pp, ok := p.(*pkg.Publish); ok {
			// the publish is acknowledged before its content is checked
			full.MqttSend(t, conn, pkg.PubAck(pp.ID()))
			return pp.TopicName() == willTopic && bytes.Equal(pp.Payload(), willPayload) && pp.QoSLevel() == 1 && !pp.IsDup() && pp.Retain()
		}
		return false
	})
	full.MqttDisconnect(t, conn)

	full.RestartBridge(t, mqttServer)

	// second subscriber, after the restart: the same retained publish is expected again
	conn = full.MqttConnectClean(t, mqttPort)
	mid = nextPacketID()
	full.MqttSend(t, conn, pkg.NewSubscribe(mid, pkg.Topic{Name: willTopic, QoS: 1}))
	full.MqttExpect(t, conn, pkg.NewSubAck(mid, 1))
	full.MqttExpect(t, conn,
		func(p pkg.Packet) bool {
			if pp, ok := p.(*pkg.Publish); ok {
				full.MqttSend(t, conn, pkg.PubAck(pp.ID()))
				return pp.TopicName() == willTopic && bytes.Equal(pp.Payload(), willPayload) && pp.QoSLevel() == 1 && !pp.IsDup() && pp.Retain()
			}
			return false
		})

	// drop the retained packet by publishing an empty retained message on the same topic
	full.MqttSend(t, conn, pkg.NewPublish2(0, willTopic, []byte{}, 1, false, true))
	full.MqttDisconnect(t, conn)
}
236 |
// TestConnect_will_qos_1_restart registers a non-retained QoS 1 will using credentials, drops the TCP
// connection without sending a DISCONNECT, restarts the bridge and then asserts that a new subscriber
// receives the will as a QoS 1 publish that is flagged as a duplicate and not flagged as retained.
func TestConnect_will_qos_1_restart(t *testing.T) {
	conn := full.MqttConnect(t, mqttPort)
	willTopic := "testing/my/will"
	willPayload := []byte("the will message")
	full.MqttSend(t, conn,
		pkg.NewConnect(full.NextClientID(), true, 5,
			&pkg.Will{
				Topic:   willTopic,
				Message: willPayload,
				QoS:     1},
			&pkg.Credentials{
				User:     "bob",
				Password: []byte("password")}))
	full.MqttExpect(t, conn, pkg.NewConnAck(false, 0))
	// forcefully close connection
	_ = conn.Close()
	time.Sleep(50 * time.Millisecond) // give bridge time to publish will
	full.RestartBridge(t, mqttServer)

	conn = full.MqttConnectClean(t, mqttPort)
	mid := nextPacketID()
	full.MqttSend(t, conn, pkg.NewSubscribe(mid, pkg.Topic{Name: willTopic, QoS: 1}))
	// ackID is assigned by the matcher below so that the publish can be acknowledged afterwards
	var ackID uint16
	full.MqttExpect(t, conn,
		pkg.NewSubAck(mid, 1),
		func(p pkg.Packet) bool {
			if pp, ok := p.(*pkg.Publish); ok {
				ackID = pp.ID()
				return pp.TopicName() == willTopic && bytes.Equal(pp.Payload(), willPayload) && pp.QoSLevel() == 1 && pp.IsDup() && !pp.Retain()
			}
			return false
		})
	full.MqttSend(t, conn, pkg.PubAck(ackID))

	// publish an empty retained message on the will topic (the same cleanup step that the retained
	// will tests use)
	full.MqttSend(t, conn, pkg.NewPublish2(0, willTopic, []byte{}, 1, false, true))
	full.MqttDisconnect(t, conn)
}
275 |
276 | func TestPing(t *testing.T) {
277 | conn := full.MqttConnectClean(t, mqttPort)
278 | full.MqttSend(t, conn, pkg.PingRequestSingleton)
279 | full.MqttExpect(t, conn, pkg.PingResponseSingleton)
280 | full.MqttDisconnect(t, conn)
281 | }
282 |
283 | func TestPing_beforeConnect(t *testing.T) {
284 | conn := full.MqttConnect(t, mqttPort)
285 | full.MqttSend(t, conn, pkg.PingRequestSingleton)
286 | full.MqttExpectConnReset(t, conn)
287 | }
288 |
289 | func TestConnect_badProtocolVersion(t *testing.T) {
290 | conn := full.MqttConnect(t, mqttPort)
291 | cp := pkg.NewConnect(full.NextClientID(), true, 1, nil, nil)
292 | cp.SetClientLevel(3)
293 | full.MqttSend(t, conn, cp)
294 | full.MqttExpect(t, conn, pkg.NewConnAck(false, pkg.RtUnacceptableProtocolVersion))
295 | full.MqttDisconnect(t, conn)
296 | }
297 |
298 | func TestConnect_multiple(t *testing.T) {
299 | conn := full.MqttConnectClean(t, mqttPort)
300 | full.MqttSend(t, conn, pkg.NewConnect(full.NextClientID(), true, 1, nil, nil))
301 | full.MqttExpectConnReset(t, conn)
302 | }
303 |
304 | func TestBadPacketLength(t *testing.T) {
305 | conn := full.MqttConnectClean(t, mqttPort)
306 | _, err := conn.Write([]byte{0x01, 0xff, 0xff, 0xff, 0xff})
307 | if err != nil {
308 | t.Fatal(err)
309 | }
310 | full.MqttExpectConnReset(t, conn)
311 | }
312 |
313 | func TestBadPacketType(t *testing.T) {
314 | conn := full.MqttConnectClean(t, mqttPort)
315 | _, err := conn.Write([]byte{0xff, 0x0})
316 | if err != nil {
317 | t.Fatal(err)
318 | }
319 | full.MqttExpectConnReset(t, conn)
320 | }
321 |
--------------------------------------------------------------------------------
/test/full/bridge.go:
--------------------------------------------------------------------------------
1 | // +build citest
2 |
3 | // Package full contains the test utilities that enables full roundtrip testing with both an mqtt-bridge and
4 | // a NATS test server.
5 | package full
6 |
7 | import (
8 | "sync"
9 | "testing"
10 | "time"
11 |
12 | "github.com/nats-io/nats-server/v2/server"
13 | testserver "github.com/nats-io/nats-server/v2/test"
14 | "github.com/tada/mqtt-nats/bridge"
15 | "github.com/tada/mqtt-nats/logger"
16 | )
17 |
// RunBridge creates an in-process mqtt-nats bridge from the given options, starts serving on a separate
// goroutine and returns once the WaitGroup handed to Serve has been released.
//
// An error from bridge.New is returned together with a nil bridge. An error returned by Serve is logged
// using lg.
//
// NOTE(review): the serving goroutine assigns to the same err variable that is returned here and there is
// no synchronization of that variable in this function. Whether a failure in Serve is reliably visible to
// the caller depends on when Serve releases the WaitGroup relative to returning - confirm against
// the implementation of Serve.
func RunBridge(lg logger.Logger, opts *bridge.Options) (bridge.Bridge, error) {
	srv, err := bridge.New(opts, lg)
	if err != nil {
		return nil, err
	}

	// serverReady is released by Serve; Wait below blocks until that has happened
	serverReady := sync.WaitGroup{}
	serverReady.Add(1)
	go func() {
		err = srv.Serve(&serverReady)
		if err != nil {
			lg.Error(err)
		}
	}()
	serverReady.Wait()
	return srv, err
}
36 |
37 | // RestartBridge restarts the given bridge
38 | func RestartBridge(t *testing.T, b bridge.Bridge) {
39 | serverReady := sync.WaitGroup{}
40 | serverReady.Add(1)
41 | go func() {
42 | if err := b.Restart(&serverReady); err != nil {
43 | t.Error(err)
44 | }
45 | }()
46 | serverReady.Wait()
47 | }
48 |
49 | // NATSServerOnPort will run a server on the given port.
50 | func NATSServerOnPort(port int) *server.Server {
51 | opts := testserver.DefaultTestOptions
52 | opts.Port = port
53 | return NATSServerWithOptions(&opts)
54 | }
55 |
56 | // NATSServerWithOptions will run a server with the given options.
57 | func NATSServerWithOptions(opts *server.Options) *server.Server {
58 | return testserver.RunServer(opts)
59 | }
60 |
// AssertMessageReceived waits up to one second for something to be received on the given channel. The
// test is failed using Fatal with the message "expected message did not arrive" if nothing is received
// in that time.
func AssertMessageReceived(t *testing.T, c <-chan bool) {
	t.Helper()
	// Wait time is somewhat arbitrary.
	timeout := time.NewTimer(time.Second)
	defer timeout.Stop()
	select {
	case <-timeout.C:
		t.Fatalf(`expected message did not arrive`)
	case <-c:
	}
}
71 |
// AssertTimeout waits for 10 milliseconds to ensure that nothing is received on the given channel. The
// test is failed using Fatal with the message "unexpected message arrived" if something is received.
func AssertTimeout(t *testing.T, c <-chan bool) {
	t.Helper()
	quiet := time.NewTimer(10 * time.Millisecond)
	defer quiet.Stop()
	select {
	case <-quiet.C:
		// nothing arrived, which is the expected outcome
	case <-c:
		t.Fatalf(`unexpected message arrived`)
	}
}
82 |
--------------------------------------------------------------------------------
/test/full/client.go:
--------------------------------------------------------------------------------
1 | // +build citest
2 |
3 | package full
4 |
5 | import (
6 | "io"
7 | "net"
8 | "strconv"
9 | "strings"
10 | "testing"
11 |
12 | "github.com/nats-io/nuid"
13 | "github.com/tada/mqtt-nats/mqtt"
14 | "github.com/tada/mqtt-nats/mqtt/pkg"
15 | "github.com/tada/mqtt-nats/test/packet"
16 | )
17 |
18 | // NextClientID returns a new unique client ID with the prefix "testclient-"
19 | func NextClientID() string {
20 | return "testclient-" + nuid.Next()
21 | }
22 |
// MqttConnect dials a tcp connection to the given port using an address without a host part and returns
// the connection. The test is failed using Fatal if the dial fails.
func MqttConnect(t *testing.T, port int) net.Conn {
	t.Helper()
	address := ":" + strconv.Itoa(port)
	c, dialErr := net.Dial("tcp", address)
	if dialErr != nil {
		t.Fatal(dialErr)
	}
	return c
}
32 |
33 | // MqttConnectClean establishes a tcp connection to the given port on the default host, sends the
34 | // initial connect packet for a clean session and awaits the CONNACK.
35 | func MqttConnectClean(t *testing.T, port int) net.Conn {
36 | conn := MqttConnect(t, port)
37 | MqttSend(t, conn, pkg.NewConnect(NextClientID(), true, 1, nil, nil))
38 | MqttExpect(t, conn, pkg.NewConnAck(false, 0))
39 | return conn
40 | }
41 |
42 | // MqttDisconnect sends a disconnect packet and closes the connection
43 | func MqttDisconnect(t *testing.T, conn io.WriteCloser) {
44 | t.Helper()
45 | defer func() {
46 | if err := conn.Close(); err != nil {
47 | t.Fatal(err)
48 | }
49 | }()
50 | buf := mqtt.NewWriter()
51 | pkg.DisconnectSingleton.Write(buf)
52 | _, err := conn.Write(buf.Bytes())
53 | if err != nil {
54 | t.Fatal(err)
55 | }
56 | }
57 |
58 | // MqttSend writes the given packets on the given connection
59 | func MqttSend(t *testing.T, conn io.Writer, send ...pkg.Packet) {
60 | t.Helper()
61 | buf := mqtt.NewWriter()
62 | for i := range send {
63 | send[i].Write(buf)
64 | }
65 | _, err := conn.Write(buf.Bytes())
66 | if err != nil {
67 | t.Fatal(err)
68 | }
69 | }
70 |
71 | // MqttExpect will read one packet for each entry in the list of expectations and assert that it is matched
72 | // by that entry. An expectation is either an expected verbatim pkg.Packet or a PacketMatcher function.
73 | func MqttExpect(t *testing.T, conn io.Reader, expectations ...interface{}) {
74 | t.Helper()
75 | for _, e := range expectations {
76 | a := packet.Parse(t, conn)
77 | switch e := e.(type) {
78 | case pkg.Packet:
79 | if !e.Equals(a) {
80 | t.Fatalf("expected '%s', got '%s'", e, a)
81 | }
82 | case func(pkg.Packet) bool:
83 | if !e(a) {
84 | t.Fatalf("packet '%s' does not match packet match function", a)
85 | }
86 | default:
87 | t.Fatalf("a %T is not a valid expectation", e)
88 | }
89 | }
90 | }
91 |
// MqttExpectConnReset makes a read attempt of one byte and expects that it fails with an error whose
// message contains "EOF", "reset", or "forcibly closed". The test is failed using Fatal if the read
// succeeds or fails with some other error.
func MqttExpectConnReset(t *testing.T, conn net.Conn) {
	t.Helper()
	_, err := conn.Read([]byte{0})
	if err != nil {
		msg := err.Error()
		for _, marker := range []string{"EOF", "reset", "forcibly closed"} {
			if strings.Contains(msg, marker) {
				return
			}
		}
	}
	t.Fatalf("connection is not reset: %v", err)
}
109 |
// MqttExpectConnClosed makes a read attempt of one byte and expects that it fails with io.EOF or with an
// error whose message contains "closed". The test is failed using Fatal if the read succeeds or fails
// with some other error.
func MqttExpectConnClosed(t *testing.T, conn net.Conn) {
	t.Helper()
	_, err := conn.Read([]byte{0})
	// The parentheses matter: && binds tighter than ||, so without them err.Error() would be called
	// on a nil error when the read succeeds and the helper would panic instead of failing the test.
	if err != nil && (err == io.EOF || strings.Contains(err.Error(), "closed")) {
		return
	}
	t.Fatalf("connection is not closed: %v", err)
}
119 |
--------------------------------------------------------------------------------
/test/full/nats.go:
--------------------------------------------------------------------------------
1 | // +build citest
2 |
3 | package full
4 |
5 | import (
6 | "strconv"
7 | "testing"
8 |
9 | "github.com/nats-io/nats.go"
10 | )
11 |
12 | // NatsConnect creates a new NATS connection on the given port.
13 | func NatsConnect(t *testing.T, port int) *nats.Conn {
14 | nc, err := nats.Connect(":" + strconv.Itoa(port))
15 | if err != nil {
16 | t.Fatal(err)
17 | }
18 | return nc
19 | }
20 |
--------------------------------------------------------------------------------
/test/main_test.go:
--------------------------------------------------------------------------------
1 | package test
2 |
3 | import (
4 | "os"
5 | "strconv"
6 | "testing"
7 |
8 | "github.com/tada/mqtt-nats/bridge"
9 | "github.com/tada/mqtt-nats/logger"
10 | "github.com/tada/mqtt-nats/test/full"
11 | )
12 |
// mqttServer is the bridge that the tests in this package run against. It is assigned by TestMain.
var mqttServer bridge.Bridge

const (
	// storageFile is removed by TestMain before the bridge is started and is passed to the bridge as
	// its StoragePath option.
	storageFile = "mqtt-nats.json"
	// mqttPort is passed to the bridge as its Port option and is the port that the tests connect to.
	mqttPort = 11883
	// natsPort is the port that the NATS test server is started on.
	natsPort = 14222
	// retainedRequestTopic is passed to the bridge as its RetainedRequestTopic option.
	retainedRequestTopic = "mqtt.retained.request"
)
21 |
22 | func TestMain(m *testing.M) {
23 | _ = os.Remove(storageFile)
24 | natsServer := full.NATSServerOnPort(natsPort)
25 |
26 | // NOTE: Setting level to logger.Debug here is very helpful when authoring and debugging tests but
27 | // it also makes the tests very verbose.
28 | lg := logger.New(logger.Debug, os.Stdout, os.Stderr)
29 |
30 | opts := bridge.Options{
31 | Port: mqttPort,
32 | NATSUrls: ":" + strconv.Itoa(natsPort),
33 | RepeatRate: 50,
34 | RetainedRequestTopic: retainedRequestTopic,
35 | StoragePath: storageFile}
36 | var err error
37 | mqttServer, err = full.RunBridge(lg, &opts)
38 |
39 | var code int
40 | if err == nil {
41 | code = m.Run()
42 | } else {
43 | lg.Error(err)
44 | code = 1
45 | }
46 | natsServer.Shutdown()
47 | if err = mqttServer.Shutdown(); err != nil {
48 | lg.Error(err)
49 | code = 1
50 | }
51 | os.Exit(code)
52 | }
53 |
--------------------------------------------------------------------------------
/test/mock/connection.go:
--------------------------------------------------------------------------------
1 | // +build citest
2 |
3 | // Package mock contains mocking/simulated versions of real runtime types
4 | // primarily used for testing.
5 | //
6 | package mock
7 |
8 | import (
9 | "bytes"
10 | "io"
11 | "net"
12 | "sync"
13 | "time"
14 | )
15 |
// Connection implements net.Conn and allows recording and playback.
//
// An instance must be created by calling NewConnection() and it can thereafter be
// used the same way as a net.Conn returned from net.Dial.
// In addition to the net.Conn API, the Connection supports
// RemoteRead() and RemoteWrite() for what would be the remote end of
// a net.Conn.
//
// The primary intended use case for Connection is to help with unit testing
// logic using a net.Conn.
//
// ReadDeadLine is supported. A write never blocks, so the write deadline is only checked when a write
// starts: a write that starts after the deadline fails with ErrTimeout and all other writes succeed
// until the system is out of memory.
//
// TODO: This Connection does not have a way to simulate
// syscall.ETIMEDOUT error return resulting from a TCP keepAlive
// failure.
//
// TODO: This Connection uses the deadlines for both local and remote
//
type Connection struct {
	input             bytes.Buffer // written by RemoteWrite, consumed by Read
	output            bytes.Buffer // written by Write, consumed by RemoteRead
	closed            bool         // set by Close
	readDeadline      time.Time    // zero means that reads do not time out
	readDeadLineTimer *time.Timer  // fires when the read deadline expires so that blocked readers wake up
	writeDeadline     time.Time    // zero means that writes do not time out
	moreData          *sync.Cond   // Cond to wait on when there is no data to read
	moreRemoteData    *sync.Cond   // Cond to wait on when there is no data to read at remote end
}
46 |
// NewConnection returns a new connection - i.e. comparable to net.Dial() but everything is hardcoded.
//
// A goroutine is started for each connection. It waits for the read deadline timer to fire and then wakes
// all readers that are blocked in Read so that they can detect the expired deadline. The goroutine loops
// without an exit condition, so it is not stopped when the connection is closed.
func NewConnection() *Connection {
	// start out with a stopped timer - this timer is reset when deadline is changed
	mc := &Connection{
		moreData:          sync.NewCond(&sync.Mutex{}),
		moreRemoteData:    sync.NewCond(&sync.Mutex{}),
		readDeadLineTimer: newStoppedTimer(),
	}

	// Run a readDeadLine listener that wakes up those that are blocked reading
	go func() {
		for {
			// wait for timer to fire
			<-mc.readDeadLineTimer.C
			// timer fired, lock and Broadcast to those in blocking read
			// TODO: Fix debug level logging
			// if log.IsLevelEnabled(log.DebugLevel) {
			//	log.Debug("mock.Connection ReadDeadline fired")
			// }
			md := mc.moreData
			md.L.Lock()
			md.Broadcast()
			md.L.Unlock()
		}
	}()
	return mc
}
74 |
// newStoppedTimer creates a timer and stops it before it is returned. The channel of the timer is
// drained in the unlikely case that the timer fired before it could be stopped (which happens when
// stepping over this function in a debugger).
func newStoppedTimer() *time.Timer {
	t := time.NewTimer(1 * time.Second)
	if stopped := t.Stop(); stopped {
		return t
	}
	<-t.C
	return t
}
83 |
// Addr implements the net.Addr interface with the static network "tcp" and the static address "0.0.0.0".
type Addr struct{}

// Network returns the static string "tcp".
func (a *Addr) Network() string {
	return "tcp"
}

// String returns the static string "0.0.0.0".
func (a *Addr) String() string {
	return "0.0.0.0"
}
92 |
93 | // Read reads data from the connection.
94 | // Read can be made to time out and return an Error with Timeout() == true
95 | // after a fixed time limit; see SetDeadline and SetReadDeadline.
96 | func (c *Connection) Read(b []byte) (n int, err error) {
97 | return c.readBufWithLock(b, &c.input, c.moreData)
98 | }
99 |
100 | // RemoteRead reads data from the connection's remote end (this returns what was written with Write)
101 | func (c *Connection) RemoteRead(b []byte) (n int, err error) {
102 | return c.readBufWithLock(b, &c.output, c.moreRemoteData)
103 | }
104 |
105 | // readBufWithLock reads from the buffer and waits for more data if it is empty
106 | func (c *Connection) readBufWithLock(b []byte, buffer *bytes.Buffer, condition *sync.Cond) (n int, err error) {
107 | // TODO: timeout & read of 0 bytes?
108 | for {
109 | condition.L.Lock()
110 | again:
111 | availBytes := buffer.Len()
112 | if c.closed && availBytes == 0 {
113 | condition.L.Unlock()
114 | return 0, io.EOF
115 | }
116 | deadline := c.readDeadline
117 | if !deadline.IsZero() && time.Now().After(deadline) { // while all times are after "zero time", this avoids a system call to Now()
118 | condition.L.Unlock()
119 | return 0, ErrTimeout
120 | }
121 | if availBytes == 0 {
122 | condition.Wait()
123 | goto again
124 | }
125 | n, err = buffer.Read(b)
126 | condition.L.Unlock()
127 | return n, err
128 | }
129 | }
130 |
131 | // Write writes data to the connection.
132 | // Write can be made to time out and return an Error with Timeout() == true
133 | // after a fixed time limit; see SetDeadline and SetWriteDeadline.
134 | func (c *Connection) Write(b []byte) (n int, err error) {
135 | return c.writeBufWithLock(b, &c.output, c.moreRemoteData, c.writeDeadline)
136 | }
137 |
138 | // RemoteWrite writes data to the connection as if something was written on the remote side of
139 | // a connection. (This data will be returned from subsequent Read() operations).
140 | //
141 | func (c *Connection) RemoteWrite(b []byte) (n int, err error) {
142 | return c.writeBufWithLock(b, &c.input, c.moreData, c.writeDeadline)
143 | }
144 |
145 | // writeBufWithLock writes to the buffer, and signals a condition if buffer goes from empty to having content
146 | func (c *Connection) writeBufWithLock(b []byte, buffer *bytes.Buffer, condition *sync.Cond, deadline time.Time) (n int, err error) {
147 | condition.L.Lock()
148 | defer condition.L.Unlock()
149 |
150 | if c.closed {
151 | return 0, io.EOF
152 | }
153 |
154 | if !deadline.IsZero() && time.Now().After(deadline) { // while all times are after "zero time", this avoids a system call to Now()
155 | return 0, ErrTimeout
156 | }
157 |
158 | availBytes := buffer.Len()
159 | n, err = buffer.Write(b)
160 | if availBytes == 0 {
161 | condition.Broadcast()
162 | }
163 | return n, err
164 | }
165 |
// Close closes the connection.
// Blocked Read and RemoteRead operations are unblocked and return io.EOF. Data that is already buffered
// can still be read after the close. Subsequent writes fail with io.EOF. Close always returns nil.
func (c *Connection) Close() error {
	// Both locks are held while the connection is marked closed, the remote lock is taken first.
	c.moreRemoteData.L.Lock()
	defer c.moreRemoteData.L.Unlock()
	c.moreData.L.Lock()
	defer c.moreData.L.Unlock()

	// mark closed
	c.closed = true

	// release blocked readers so they pick up the close
	if c.input.Len() == 0 {
		c.moreData.Broadcast()
	}
	// release blocked remote readers so they pick up the close
	if c.output.Len() == 0 {
		c.moreRemoteData.Broadcast()
	}
	return nil
}
187 |
188 | // LocalAddr returns a hardcoded local network address.
189 | func (c *Connection) LocalAddr() net.Addr {
190 | return &Addr{}
191 | }
192 |
193 | // RemoteAddr returns a hardcoded remote network address.
194 | func (c *Connection) RemoteAddr() net.Addr {
195 | return &Addr{}
196 | }
197 |
198 | // SetDeadline sets the read and write deadlines associated
199 | // with the connection. It is equivalent to calling both
200 | // SetReadDeadline and SetWriteDeadline.
201 | //
202 | // A deadline is an absolute time after which I/O operations
203 | // fail with a timeout (see type Error) instead of
204 | // blocking. The deadline applies to all future and pending
205 | // I/O, not just the immediately following call to Read or
206 | // Write. After a deadline has been exceeded, the connection
207 | // can be refreshed by setting a deadline in the future.
208 | //
209 | // An idle timeout can be implemented by repeatedly extending
210 | // the deadline after successful Read or Write calls.
211 | //
212 | // A zero value for t means I/O operations will not time out.
213 | //
214 | // Note that if a TCP connection has keep-alive turned on,
215 | // which is the default unless overridden by Dialer.KeepAlive
216 | // or ListenConfig.KeepAlive, then a keep-alive failure may
217 | // also return a timeout error. On Unix systems a keep-alive
218 | // failure on I/O can be detected using
219 | // errors.Is(err, syscall.ETIMEDOUT).
220 | //
221 | // TODO: This MockConnection does not have a way to simulate
222 | // syscall.ETIMEOUT error return resulting from a TCP keepAlive
223 | // failure.
224 | //
225 | // TODO: At present the WriteDeadLine does not have any effect.
226 | func (c *Connection) SetDeadline(t time.Time) error {
227 | err := c.SetReadDeadline(t)
228 | if err == nil {
229 | err = c.SetWriteDeadline(t)
230 | }
231 | return err
232 | }
233 |
234 | // SetReadDeadline sets the deadline for future Read calls
235 | // and any currently-blocked Read call.
236 | // A zero value for t means Read will not time out.
237 | func (c *Connection) SetReadDeadline(t time.Time) error {
238 | c.moreData.L.Lock()
239 | defer c.moreData.L.Unlock()
240 | c.readDeadline = t
241 |
242 | // stop the currently ticking (or possibly fired) timer
243 | oldTimer := c.readDeadLineTimer
244 | if !oldTimer.Stop() {
245 | select {
246 | case <-oldTimer.C: // drain if it fired
247 | default: // it was stopped already
248 | }
249 | }
250 | if t.IsZero() {
251 | // No need to wake those that are blocked, simply keep the timer in stopped state
252 | return nil
253 | }
254 | // Set a new duration
255 | oldTimer.Reset(time.Until(t))
256 | return nil
257 | }
258 |
// SetWriteDeadline sets the deadline for future Write and RemoteWrite calls.
// A write that starts after the deadline has passed fails with ErrTimeout.
// A zero value for t means Write will not time out.
// The function always returns nil.
//
// NOTE(review): the deadline is stored without taking any of the connection's locks, and it is read by
// Write and RemoteWrite before they take a lock - confirm that callers do not set the deadline
// concurrently with writes.
func (c *Connection) SetWriteDeadline(t time.Time) error {
	c.writeDeadline = t
	return nil
}
269 |
// RemoteIO is the interface for operations on what Connection.Remote() returns. It combines reading
// and writing at the remote end of a connection with the ability to read a single byte.
type RemoteIO interface {
	io.ReadWriter
	io.ByteReader
}
275 |
276 | // Remote is a io.ReadWriter that adapts the "remote" end of a MockConnection to
277 | // io.ReadWriter
278 | type Remote struct {
279 | conn *Connection
280 | }
281 |
282 | func (r *Remote) Read(b []byte) (int, error) {
283 | return r.conn.RemoteRead(b)
284 | }
285 |
286 | func (r *Remote) Write(b []byte) (int, error) {
287 | return r.conn.RemoteWrite(b)
288 | }
289 |
290 | // ReadByte reads one byte - implements io.ByteReader
291 | func (r *Remote) ReadByte() (byte, error) {
292 | b := make([]byte, 1)
293 | n, err := r.conn.RemoteRead(b)
294 | if err != nil || n != 1 {
295 | return 0, err
296 | }
297 | return b[0], err
298 | }
299 |
300 | // Remote returns a io.ReadWriter for the remote end of this MockConnetion
301 | func (c *Connection) Remote() RemoteIO {
302 | return &Remote{conn: c}
303 | }
304 |
// TimeoutError is the error that is returned for an expired deadline.
// It implements the net.Error interface.
// This type is needed because the error actually returned from net.Conn is an internal data type
//
type TimeoutError struct{}

// Error returns the message "i/o timeout". Implements the net.Error interface.
func (e *TimeoutError) Error() string {
	return "i/o timeout"
}

// Timeout always returns true. Implements the net.Error interface.
func (e *TimeoutError) Timeout() bool {
	return true
}

// Temporary always returns true. Implements the net.Error interface.
func (e *TimeoutError) Temporary() bool {
	return true
}

// ErrTimeout is returned for an expired deadline
var ErrTimeout error = &TimeoutError{}
322 |
--------------------------------------------------------------------------------
/test/mock/connection_test.go:
--------------------------------------------------------------------------------
1 | package mock
2 |
3 | import (
4 | "io"
5 | "net"
6 | "testing"
7 | "time"
8 |
9 | "github.com/tada/mqtt-nats/test/utils"
10 | )
11 |
12 | func Test_MockConnection_implements_net_Conn(t *testing.T) {
13 | defer utils.ShouldNotPanic(t)
14 | _ = net.Conn(NewConnection())
15 | }
16 |
17 | func Test_MockConnection_has_hardocded_local_and_remote_addr(t *testing.T) {
18 | conn := NewConnection()
19 | local := conn.LocalAddr()
20 | remote := conn.RemoteAddr()
21 | utils.CheckEqual("0.0.0.0", local.String(), t)
22 | utils.CheckEqual("tcp", local.Network(), t)
23 | utils.CheckEqual("0.0.0.0", remote.String(), t)
24 | utils.CheckEqual("tcp", remote.Network(), t)
25 | }
26 |
27 | // Test that RemoteWrite can be called without error
28 | func Test_MockConnection_Can_do_RemoteWrite(t *testing.T) {
29 | conn := NewConnection()
30 | n, err := conn.RemoteWrite([]byte("test"))
31 | utils.CheckEqual(4, n, t)
32 | utils.CheckNotError(err, t)
33 | }
34 |
35 | // Test that what is written "remotely" can be read "locally"
36 | func Test_MockConnection_Read_reads_what_was_written_with_RemoteWrite(t *testing.T) {
37 | conn := NewConnection()
38 | n, err := conn.RemoteWrite([]byte("test"))
39 | utils.CheckEqual(4, n, t)
40 | utils.CheckNotError(err, t)
41 | buf := make([]byte, 4)
42 | n, err = conn.Read(buf)
43 | utils.CheckEqual(4, n, t)
44 | utils.CheckNotError(err, t)
45 | }
46 |
47 | // Test that what is written "locally" can be read "remotely"
48 | func Test_MockConnection_ReadRemote_reads_what_was_written_with_Write(t *testing.T) {
49 | conn := NewConnection()
50 | n, err := conn.Write([]byte("test"))
51 | utils.CheckEqual(4, n, t)
52 | utils.CheckNotError(err, t)
53 | buf := make([]byte, 4)
54 | n, err = conn.RemoteRead(buf)
55 | utils.CheckEqual(4, n, t)
56 | utils.CheckNotError(err, t)
57 | }
58 |
// Test_MockConnection_Read_waits_for_data_until_close asserts that a Read on a connection without data
// blocks until the connection is closed and then returns io.EOF. The connection is closed about 200
// milliseconds after the read was started.
func Test_MockConnection_Read_waits_for_data_until_close(t *testing.T) {
	conn := NewConnection()
	readResult := make(chan error)
	timeout := make(chan bool)
	readEndTime := time.Now()
	// closeTime is not the time of the actual close. It is a point in time shortly after the start of
	// the test that a read which did block must have ended after.
	closeTime := time.Now().Add(1 * time.Millisecond) // ensure this is after

	go func() {
		delay := 200 * time.Millisecond
		time.Sleep(delay)
		timeout <- true
	}()
	go func() {
		aByte := make([]byte, 1)
		_, err := conn.Read(aByte)
		// assigned before the send on readResult, and read by the test after the receive
		readEndTime = time.Now()
		readResult <- err
	}()

	// Wait for the timeout
	<-timeout
	if err := conn.Close(); err != nil {
		panic(err)
	}

	// Wait for the read
	err := <-readResult

	// Did they occur in the expected order?
	utils.CheckTrue(readEndTime.After(closeTime), t)
	utils.CheckTrue(err == io.EOF, t)
}
91 |
92 | func Test_MockConnection_Read_waits_until_given_read_deadline(t *testing.T) {
93 | conn := NewConnection()
94 | if err := conn.SetDeadline(time.Now().Add(200 * time.Millisecond)); err != nil {
95 | t.Fatal(err)
96 | }
97 |
98 | aByte := make([]byte, 1)
99 | _, err := conn.Read(aByte)
100 | nerr, ok := err.(net.Error)
101 | if !ok {
102 | t.Fatalf("Expected a net.Error but could not convert it")
103 | }
104 | utils.CheckTrue(nerr.Timeout(), t)
105 | }
106 |
107 | func Test_MockConnection_Read_returns_amount_read_if_buffer_becomes_empty(t *testing.T) {
108 | conn := NewConnection()
109 | oneByte := []byte{1}
110 | n, err := conn.RemoteWrite(oneByte)
111 | utils.CheckEqual(1, n, t)
112 | utils.CheckNotError(err, t)
113 |
114 | threeBytes := make([]byte, 3)
115 | n, err = conn.Read(threeBytes)
116 | utils.CheckEqual(1, n, t)
117 | utils.CheckNotError(err, t)
118 | utils.CheckEqual(byte(1), threeBytes[0], t)
119 |
120 | oneByte = []byte{2}
121 | n, err = conn.RemoteWrite(oneByte)
122 | utils.CheckEqual(1, n, t)
123 | utils.CheckNotError(err, t)
124 |
125 | n, err = conn.Read(threeBytes)
126 | utils.CheckEqual(1, n, t)
127 | utils.CheckNotError(err, t)
128 | utils.CheckEqual(byte(2), threeBytes[0], t)
129 | }
130 |
131 | // TODO: Test Multithreaded reading and writing
132 | // 1. n threads write 1 byte each, n threads read one byte each - all bytes are read
133 |
--------------------------------------------------------------------------------
/test/package.go:
--------------------------------------------------------------------------------
1 | // +build citest
2 |
// Package test contains the in-process functional tests for a fully configured mqtt-nats bridge +
// nats test server combo
5 | package test
6 |
--------------------------------------------------------------------------------
/test/packet/parse.go:
--------------------------------------------------------------------------------
1 | // +build citest
2 |
3 | // Package packet contains a test version of the MQTT packet parser
4 | package packet
5 |
6 | import (
7 | "fmt"
8 | "io"
9 | "testing"
10 |
11 | "github.com/tada/mqtt-nats/mqtt"
12 | "github.com/tada/mqtt-nats/mqtt/pkg"
13 | )
14 |
15 | // Parse is a test helper function that parses the next package from the given reader and returns it. t.Fatal(err)
16 | // will be called if an error occurs when reading or parsing.
17 | func Parse(t *testing.T, rdr io.Reader) pkg.Packet {
18 | t.Helper()
19 | // Read packet type and flags
20 | r := mqtt.NewReader(rdr)
21 | var p pkg.Packet
22 | b, err := r.ReadByte()
23 | if err == nil {
24 | pkgType := b & pkg.TpMask
25 |
26 | // Read packet length
27 | var rl int
28 | rl, err = r.ReadVarInt()
29 | if err == nil {
30 | switch pkgType {
31 | case pkg.TpConnAck:
32 | p, err = pkg.ParseConnAck(r, b, rl)
33 | case pkg.TpDisconnect:
34 | p = pkg.DisconnectSingleton
35 | case pkg.TpPing:
36 | p = pkg.PingRequestSingleton
37 | case pkg.TpPingResp:
38 | p = pkg.PingResponseSingleton
39 | case pkg.TpConnect:
40 | p, err = pkg.ParseConnect(r, b, rl)
41 | case pkg.TpPublish:
42 | p, err = pkg.ParsePublish(r, b, rl)
43 | case pkg.TpPubAck:
44 | p, err = pkg.ParsePubAck(r, b, rl)
45 | case pkg.TpPubRec:
46 | p, err = pkg.ParsePubRec(r, b, rl)
47 | case pkg.TpPubRel:
48 | p, err = pkg.ParsePubRel(r, b, rl)
49 | case pkg.TpPubComp:
50 | p, err = pkg.ParsePubComp(r, b, rl)
51 | case pkg.TpSubscribe:
52 | p, err = pkg.ParseSubscribe(r, b, rl)
53 | case pkg.TpSubAck:
54 | p, err = pkg.ParseSubAck(r, b, rl)
55 | case pkg.TpUnsubscribe:
56 | p, err = pkg.ParseUnsubscribe(r, b, rl)
57 | case pkg.TpUnsubAck:
58 | p, err = pkg.ParseUnsubAck(r, b, rl)
59 | default:
60 | err = fmt.Errorf("received unknown packet type %d", (b&pkg.TpMask)>>4)
61 | }
62 | }
63 | }
64 | if err != nil {
65 | t.Fatal(err)
66 | }
67 | return p
68 | }
69 |
--------------------------------------------------------------------------------
/test/packet/parse_test.go:
--------------------------------------------------------------------------------
1 | package packet
2 |
3 | import (
4 | "bytes"
5 | "testing"
6 |
7 | "github.com/tada/mqtt-nats/test/utils"
8 | )
9 |
10 | func TestParse_unknownPacket(t *testing.T) {
11 | utils.EnsureFailed(t, func(st *testing.T) {
12 | Parse(st, bytes.NewReader([]byte{0xf0, 0}))
13 | }, "unknown packet type 15")
14 | }
15 |
--------------------------------------------------------------------------------
/test/publish_test.go:
--------------------------------------------------------------------------------
1 | package test
2 |
3 | import (
4 | "bytes"
5 | "testing"
6 | "time"
7 |
8 | "github.com/nats-io/nats.go"
9 | "github.com/tada/mqtt-nats/mqtt/pkg"
10 | "github.com/tada/mqtt-nats/test/full"
11 | )
12 |
// TestPublishSubscribe subscribes one client to a topic, publishes to that topic from a second client and
// asserts that the subscriber receives the published packet before AssertMessageReceived times out.
//
// NOTE(review): MqttExpect and MqttDisconnect fail the test using Fatal, and here they are called from a
// goroutine other than the one that runs the test. The testing package documents that FailNow must be
// called from the goroutine running the test.
func TestPublishSubscribe(t *testing.T) {
	topic := "testing/some/topic"
	pp := pkg.SimplePublish(topic, []byte("payload"))
	// buffered so that the subscriber goroutine never blocks on the send
	gotIt := make(chan bool, 1)
	c1 := full.MqttConnectClean(t, mqttPort)
	mid := nextPacketID()
	full.MqttSend(t, c1, pkg.NewSubscribe(mid, pkg.Topic{Name: topic}))
	full.MqttExpect(t, c1, pkg.NewSubAck(mid, 0))
	// the subscription is acknowledged before the publisher is started
	go func() {
		full.MqttExpect(t, c1, pp)
		gotIt <- true
		full.MqttDisconnect(t, c1)
	}()

	c2 := full.MqttConnectClean(t, mqttPort)
	full.MqttSend(t, c2, pp)
	full.MqttDisconnect(t, c2)
	full.AssertMessageReceived(t, gotIt)
}
32 |
33 | func TestPublishSubscribe_qos_1(t *testing.T) {
34 | topic := "testing/some/topic"
35 | mid := nextPacketID()
36 | pp := pkg.NewPublish2(mid, topic, []byte("payload"), 1, false, false)
37 | c1 := full.MqttConnectClean(t, mqttPort)
38 | gotIt := make(chan bool, 1)
39 | go func() {
40 | sid := nextPacketID()
41 | full.MqttSend(t, c1, pkg.NewSubscribe(sid, pkg.Topic{Name: topic, QoS: 1}))
42 | full.MqttExpect(t, c1, pkg.NewSubAck(sid, 1), pp)
43 | full.MqttSend(t, c1, pkg.PubAck(mid))
44 | full.MqttDisconnect(t, c1)
45 | gotIt <- true
46 | }()
47 |
48 | c2 := full.MqttConnectClean(t, mqttPort)
49 | full.MqttSend(t, c2, pp)
50 | full.MqttExpect(t, c2, pkg.PubAck(mid))
51 | full.MqttDisconnect(t, c2)
52 | full.AssertMessageReceived(t, gotIt)
53 | }
54 |
55 | func TestPublishSubscribe_qos_2(t *testing.T) {
56 | topic := "testing/some/topic"
57 | mid := nextPacketID()
58 | pp := pkg.NewPublish2(mid, topic, []byte("payload"), 1, false, false)
59 | c1 := full.MqttConnectClean(t, mqttPort)
60 | gotIt := make(chan bool, 1)
61 | go func() {
62 | sid := nextPacketID()
63 | full.MqttSend(t, c1, pkg.NewSubscribe(sid, pkg.Topic{Name: topic, QoS: 2}))
64 | full.MqttExpect(t, c1, pkg.NewSubAck(sid, 1), pp)
65 | full.MqttSend(t, c1, pkg.PubAck(mid))
66 | full.MqttDisconnect(t, c1)
67 | gotIt <- true
68 | }()
69 |
70 | c2 := full.MqttConnectClean(t, mqttPort)
71 | full.MqttSend(t, c2, pp)
72 | full.MqttExpect(t, c2, pkg.PubAck(mid))
73 | full.MqttDisconnect(t, c2)
74 | full.AssertMessageReceived(t, gotIt)
75 | }
76 |
77 | func TestPublishSubscribe_qos_1_restart(t *testing.T) {
78 | topic := "testing/some/topic"
79 | mid := nextPacketID()
80 | pp := pkg.NewPublish2(mid, topic, []byte("payload"), 1, false, false)
81 |
82 | c1ID := full.NextClientID()
83 | c1 := full.MqttConnect(t, mqttPort)
84 | gotIt := make(chan bool, 1)
85 | go func() {
86 | full.MqttSend(t, c1, pkg.NewConnect(c1ID, false, 1, nil, nil))
87 | full.MqttExpect(t, c1, pkg.NewConnAck(false, 0))
88 |
89 | sid := nextPacketID()
90 | full.MqttSend(t, c1, pkg.NewSubscribe(sid, pkg.Topic{Name: topic, QoS: 1}))
91 | full.MqttExpect(t, c1, pkg.NewSubAck(sid, 1))
92 | gotIt <- true
93 | full.MqttExpect(t, c1, pp)
94 | gotIt <- true
95 | }()
96 |
97 | c2ID := full.NextClientID()
98 | c2 := full.MqttConnect(t, mqttPort)
99 | full.MqttSend(t, c2, pkg.NewConnect(c2ID, false, 1, nil, nil))
100 | full.MqttExpect(t, c2, pkg.NewConnAck(false, 0))
101 | full.AssertMessageReceived(t, gotIt)
102 | full.MqttSend(t, c2, pp)
103 | full.AssertMessageReceived(t, gotIt)
104 |
105 | full.RestartBridge(t, mqttServer)
106 |
107 | // client c1 reestablishes session and sends outstanding ack
108 | c1 = full.MqttConnect(t, mqttPort)
109 | full.MqttSend(t, c1, pkg.NewConnect(c1ID, false, 1, nil, nil))
110 | full.MqttExpect(t, c1, pkg.NewConnAck(true, 0))
111 |
112 | // client c2 reestablishes session and receives outstanding ack
113 | c2 = full.MqttConnect(t, mqttPort)
114 | full.MqttSend(t, c2, pkg.NewConnect(c2ID, false, 1, nil, nil))
115 | full.MqttExpect(t, c2, pkg.NewConnAck(true, 0))
116 |
117 | full.MqttSend(t, c1, pkg.PubAck(mid))
118 | full.MqttExpect(t, c2, pkg.PubAck(mid))
119 |
120 | full.MqttDisconnect(t, c1)
121 | full.MqttDisconnect(t, c2)
122 | }
123 |
124 | func TestMqttPublishNatsSubscribe(t *testing.T) {
125 | pl := []byte("payload")
126 | pp := pkg.SimplePublish("testing/s.o.m.e/topic", pl)
127 | gotIt := make(chan bool, 1)
128 | nc := full.NatsConnect(t, natsPort)
129 | defer nc.Close()
130 |
131 | _, err := nc.Subscribe("testing.s/o/m/e.>", func(m *nats.Msg) {
132 | if !bytes.Equal(pl, m.Data) {
133 | t.Error("nats subscription did not receive expected data")
134 | }
135 | gotIt <- true
136 | })
137 | if err != nil {
138 | t.Fatal(err)
139 | }
140 |
141 | c2 := full.MqttConnectClean(t, mqttPort)
142 | full.MqttSend(t, c2, pp)
143 | full.MqttDisconnect(t, c2)
144 | full.AssertMessageReceived(t, gotIt)
145 | }
146 |
147 | func TestNatsPublishMqttSubscribe(t *testing.T) {
148 | topic := "testing/some/topic"
149 | pl := []byte("payload")
150 | pp := pkg.SimplePublish(topic, pl)
151 |
152 | c1 := full.MqttConnectClean(t, mqttPort)
153 | sid := nextPacketID()
154 | full.MqttSend(t, c1, pkg.NewSubscribe(sid, pkg.Topic{Name: "testing/+/topic"}))
155 | full.MqttExpect(t, c1, pkg.NewSubAck(sid, 0))
156 |
157 | gotIt := make(chan bool, 1)
158 | go func() {
159 | full.MqttExpect(t, c1, pp)
160 | full.MqttDisconnect(t, c1)
161 | gotIt <- true
162 | }()
163 |
164 | nc := full.NatsConnect(t, natsPort)
165 | defer nc.Close()
166 | err := nc.Publish("testing.some.topic", pl)
167 | if err != nil {
168 | t.Fatal(err)
169 | }
170 | full.AssertMessageReceived(t, gotIt)
171 | }
172 |
173 | func TestNatsPublishMqttSubscribe_qos_1(t *testing.T) {
174 | topic := "testing/some/topic"
175 | pl := []byte("payload")
176 |
177 | c1 := full.MqttConnectClean(t, mqttPort)
178 | sid := nextPacketID()
179 | full.MqttSend(t, c1, pkg.NewSubscribe(sid, pkg.Topic{Name: topic, QoS: 1}))
180 | full.MqttExpect(t, c1, pkg.NewSubAck(sid, 1))
181 |
182 | gotIt := make(chan bool, 1)
183 | go func() {
184 | full.MqttExpect(t, c1, func(p pkg.Packet) bool {
185 | if pp, ok := p.(*pkg.Publish); ok {
186 | full.MqttSend(t, c1, pkg.PubAck(pp.ID()))
187 | return pp.TopicName() == topic && bytes.Equal(pp.Payload(), pl) && pp.QoSLevel() == 1
188 | }
189 | return false
190 | })
191 | full.MqttDisconnect(t, c1)
192 | gotIt <- true
193 | }()
194 |
195 | nc := full.NatsConnect(t, natsPort)
196 | defer nc.Close()
197 | _, err := nc.Request("testing.some.topic", pl, 10*time.Millisecond)
198 | if err != nil {
199 | t.Fatal(err)
200 | }
201 | }
202 |
203 | func TestUnubscribe(t *testing.T) {
204 | topic := "testing/some/topic"
205 | pp := pkg.SimplePublish(topic, []byte("payload"))
206 | gotIt := make(chan bool, 1)
207 | c1 := full.MqttConnectClean(t, mqttPort)
208 | mid := nextPacketID()
209 | full.MqttSend(t, c1, pkg.NewSubscribe(mid, pkg.Topic{Name: topic}))
210 | full.MqttExpect(t, c1, pkg.NewSubAck(mid, 0))
211 | go func() {
212 | full.MqttExpect(t, c1, pp)
213 |
214 | uid := nextPacketID()
215 | full.MqttSend(t, c1, pkg.NewUnsubscribe(uid, topic))
216 | full.MqttExpect(t, c1, pkg.UnsubAck(uid))
217 | gotIt <- true
218 | full.MqttExpectConnClosed(t, c1)
219 | gotIt <- true
220 | }()
221 |
222 | c2 := full.MqttConnectClean(t, mqttPort)
223 | full.MqttSend(t, c2, pp)
224 |
225 | // wait for subscriber to consume and unsubscribe
226 | full.AssertMessageReceived(t, gotIt)
227 |
228 | // send again, this should not reach subscriber
229 | full.MqttSend(t, c2, pp)
230 | full.MqttDisconnect(t, c2)
231 | full.AssertTimeout(t, gotIt)
232 | full.MqttDisconnect(t, c1)
233 | }
234 |
235 | func TestPublish_qos_2(t *testing.T) {
236 | conn := full.MqttConnectClean(t, mqttPort)
237 | full.MqttSend(t, conn, pkg.NewPublish2(
238 | nextPacketID(), "testing/some/topic", []byte("payload"), 2, false, false))
239 | full.MqttExpectConnReset(t, conn)
240 | }
241 |
242 | func TestPublish_qos_3(t *testing.T) {
243 | conn := full.MqttConnectClean(t, mqttPort)
244 | full.MqttSend(t, conn, pkg.NewPublish2(
245 | nextPacketID(), "testing/some/topic", []byte("payload"), 3, false, false))
246 | full.MqttExpectConnReset(t, conn)
247 | }
248 |
--------------------------------------------------------------------------------
/test/retained_test.go:
--------------------------------------------------------------------------------
1 | package test
2 |
3 | import (
4 | "encoding/base64"
5 | "encoding/json"
6 | "testing"
7 | "time"
8 |
9 | "github.com/tada/mqtt-nats/mqtt"
10 | "github.com/tada/mqtt-nats/mqtt/pkg"
11 | "github.com/tada/mqtt-nats/test/full"
12 | )
13 |
14 | func decodeRetained(t *testing.T, data []byte) []*pkg.Publish {
15 | t.Helper()
16 | var ms []map[string]string
17 | err := json.Unmarshal(data, &ms)
18 | if err != nil {
19 | t.Fatal(err)
20 | }
21 | ns := make([]*pkg.Publish, len(ms))
22 | for i := range ms {
23 | m := ms[i]
24 | var pl []byte
25 | if ps, ok := m["payload"]; ok {
26 | pl = []byte(ps)
27 | } else if pe, ok := m["payloadEnc"]; ok {
28 | if pl, err = base64.StdEncoding.DecodeString(pe); err != nil {
29 | t.Fatal(err)
30 | }
31 | }
32 | ns[i] = pkg.NewPublish2(0, mqtt.FromNATS(m["subject"]), pl, 0, false, true)
33 | }
34 | return ns
35 | }
36 |
37 | func TestNATS_requestRetained(t *testing.T) {
38 | conn := full.MqttConnectClean(t, mqttPort)
39 | pp1 := pkg.NewPublish2(0, "testing/s.o.m.e/retained/first", []byte("the first retained message"), 0, false, true)
40 | pp2 := pkg.NewPublish2(0, "testing/s.o.m.e/retained/second", []byte("the second retained message"), 0, false, true)
41 | full.MqttSend(t, conn, pp1, pp2)
42 | full.MqttDisconnect(t, conn)
43 |
44 | nc := full.NatsConnect(t, natsPort)
45 | defer nc.Close()
46 |
47 | m, err := nc.Request(retainedRequestTopic, []byte("testing.s/o/m/e.retained.>"), 10*time.Millisecond)
48 | if err != nil {
49 | t.Fatal(err)
50 | }
51 |
52 | pps := decodeRetained(t, m.Data)
53 | if !(len(pps) == 2 && pp1.Equals(pps[0]) && pp2.Equals(pps[1])) {
54 | t.Fatal("unexpected retained publication")
55 | }
56 |
57 | m, err = nc.Request(retainedRequestTopic, []byte("do.not.find.this"), 10*time.Millisecond)
58 | if err != nil {
59 | t.Fatal(err)
60 | }
61 | pps = decodeRetained(t, m.Data)
62 | if len(pps) != 0 {
63 | t.Fatal("unexpected retained publication")
64 | }
65 | }
66 |
--------------------------------------------------------------------------------
/test/tls/connect_test.go:
--------------------------------------------------------------------------------
1 | package tls
2 |
3 | import (
4 | "testing"
5 |
6 | "github.com/tada/mqtt-nats/mqtt/pkg"
7 | "github.com/tada/mqtt-nats/test/full"
8 | )
9 |
10 | func TestConnect(t *testing.T) {
11 | conn := tlsDial(t, mqttPort)
12 | full.MqttSend(t, conn, pkg.NewConnect(full.NextClientID(), true, 1, nil, nil))
13 | full.MqttExpect(t, conn, pkg.NewConnAck(false, 0))
14 | full.MqttDisconnect(t, conn)
15 | }
16 |
--------------------------------------------------------------------------------
/test/tls/main_test.go:
--------------------------------------------------------------------------------
1 | package tls
2 |
3 | import (
4 | "crypto/tls"
5 | "crypto/x509"
6 | "io/ioutil"
7 | "net"
8 | "os"
9 | "strconv"
10 | "testing"
11 |
12 | "github.com/nats-io/nats-server/v2/test"
13 |
14 | "github.com/nats-io/nats.go"
15 | "github.com/tada/mqtt-nats/bridge"
16 | "github.com/tada/mqtt-nats/logger"
17 | "github.com/tada/mqtt-nats/test/full"
18 | )
19 |
20 | var mqttServer bridge.Bridge
21 |
22 | const (
23 | storageFile = "mqtt-nats.json"
24 | mqttPort = 18883
25 | natsPort = 14443
26 | retainedRequestTopic = "mqtt.retained.request"
27 | )
28 |
// tlsDial establishes a tcp tls connection to the given port on the default host.
//
// The CA certificate is read from testdata/ca.pem and the client certificate
// and key from testdata/client.pem and testdata/client-key.pem. Any failure
// to load the certificates or to establish the connection is fatal for the
// test.
func tlsDial(t *testing.T, port int) net.Conn {
	t.Helper()
	pbs, err := ioutil.ReadFile("testdata/ca.pem")
	if err != nil {
		t.Fatal(err)
	}
	roots := x509.NewCertPool()
	// AppendCertsFromPEM reports whether any certificate was parsed. Failing
	// here gives a much clearer message than the handshake error that would
	// otherwise follow from an empty pool.
	if !roots.AppendCertsFromPEM(pbs) {
		t.Fatal("no certificates could be parsed from testdata/ca.pem")
	}
	cert, err := tls.LoadX509KeyPair("testdata/client.pem", "testdata/client-key.pem")
	if err != nil {
		t.Fatal(err)
	}
	conn, err := tls.Dial("tcp", ":"+strconv.Itoa(port), &tls.Config{
		ServerName:   "127.0.0.1",
		RootCAs:      roots,
		Certificates: []tls.Certificate{cert},
	})
	if err != nil {
		t.Fatal(err)
	}
	return conn
}
52 |
53 | func TestMain(m *testing.M) {
54 | _ = os.Remove(storageFile)
55 | natsServer, _ := test.RunServerWithConfig("server.conf")
56 |
57 | natsOpts := []nats.Option{nats.Name("MQTT Bridge")}
58 | natsOpts = append(natsOpts, nats.ClientCert("testdata/client.pem", "testdata/client-key.pem"))
59 | natsOpts = append(natsOpts, nats.RootCAs("testdata/ca.pem"))
60 |
61 | // NOTE: Setting level to logger.Debug here is very helpful when authoring and debugging tests but
62 | // it also makes the tests very verbose.
63 | lg := logger.New(logger.Silent, os.Stdout, os.Stderr)
64 |
65 | opts := bridge.Options{
66 | Port: mqttPort,
67 | NATSUrls: "localhost:" + strconv.Itoa(natsPort),
68 | RepeatRate: 50,
69 | RetainedRequestTopic: retainedRequestTopic,
70 | StoragePath: storageFile,
71 | TLS: true,
72 | TLSCaCert: "testdata/ca.pem",
73 | TLSCert: "testdata/server.pem",
74 | TLSKey: "testdata/server-key.pem",
75 | NATSOpts: natsOpts}
76 |
77 | var err error
78 | mqttServer, err = full.RunBridge(lg, &opts)
79 |
80 | var code int
81 | if err == nil {
82 | code = m.Run()
83 | } else {
84 | lg.Error(err)
85 | code = 1
86 | }
87 | natsServer.Shutdown()
88 | if err = mqttServer.Shutdown(); err != nil {
89 | lg.Error(err)
90 | code = 1
91 | }
92 | os.Exit(code)
93 | }
94 |
--------------------------------------------------------------------------------
/test/tls/package.go:
--------------------------------------------------------------------------------
1 | // +build citest
2 |
3 | // Package tls contains full roundtrip tests using TLS configured connections
4 | package tls
5 |
--------------------------------------------------------------------------------
/test/tls/server.conf:
--------------------------------------------------------------------------------
1 | listen: localhost:14443
2 | tls: {
3 | ca_file: "./testdata/ca.pem"
4 | cert_file: "./testdata/server.pem"
5 | key_file: "./testdata/server-key.pem"
6 | verify: true
7 | }
8 |
--------------------------------------------------------------------------------
/test/tls/testdata/ca-key.pem:
--------------------------------------------------------------------------------
1 | -----BEGIN RSA PRIVATE KEY-----
2 | MIIEpQIBAAKCAQEA8qmvXNCsGLhQeq5/XIZp0bLL8lOqvNQz86U5QQwETnxs1KDU
3 | DquhNhzd2lgvaP98zO8Owkqz/kXIT4R/RxIeZ1BD9gY9iVVMyZar/bI+72DpX4Hg
4 | x7+UC9m34CAfK/AAYGZIgjhpDBobkVaKjVahBWUF7BWon+1xu23XdS99lxThJCTv
5 | G5gfquGFf4eSvhGSt7SRzW5nVP+OpTKxNMMawZa15jh3U7BfrXxDeZfDKFxF2RaK
6 | 6Lwi+ATLLe9wwvEKvCvDT8XPKsm9RUl48pLKRMy73tWitSh9SSFoa5zd49qIteLt
7 | LPKbNWPJLUyVrJzvbv9hVVJe4CKd6zKH+mYQywIDAQABAoIBAQCDGsp0CwnwESTq
8 | I30MMFLbyQ4HTszgWIX5DTtxuVxaSz9BYeMwSeo/ojj6zspOoDp9Pmtq7ZFxv6IJ
9 | 1Dwv2cozZ1pQge6dVEi4YX9rAfKewm1T/IfFY+xIushtfu1Yf8K0Uo66TF/0+eYL
10 | EAardjJpB7u7YbhJL7BS43WVCqOAC+p/nzJAyjETvvJhzUo7LEc7IeLUhAHCfTOX
11 | tUdmEOQnRy9xhgd2wkOwrt4loAxwdZALvbLfl58YtE6FgKeDLTQrNQzsA3a5dB2O
12 | 2N2UbB26RjTo7cshnj09uRuTWCiu9NYzPC/u6L9avPkRcXIgtqHrJ2k7U9MnRXXD
13 | 7lGRWg4BAoGBAPYQte1B8y5fNG/TEC/yV2lNO5zd+Noq/xQXG0huULYUXpaPcxg3
14 | 5+oBY7+Y6suI4XDasLKTz6HFIB414cKg9us0VH1fMJNM6HOqNdjRbR2udJ7fersW
15 | 24CviytYck7bbulDzmJgfxIpdF7zsFkh7MbK/8b9KMSxkBvngCWmk/f7AoGBAPx1
16 | zqk9RUL7DAOBqsbMQi8ldhRapSqOxEf3pKk+HaQ92l96ef+yGMghfhc4uKLTZ9tY
17 | wqFnPbi5Drsgboaoepcy3XYOFB13E+eac6px0l50//iQg5hEHkKdlvspsmVWLOMP
18 | Y0wir5X6Q0Fmw+ORcy0jpEy0CKT5wolZnhmuxuFxAoGBAPYPv92CFaxJiCZK6eUI
19 | cmDa2sIDNtb0KB/u+1ly90MdG3lz+aQ+Q6u9uAHg6Oqf9tDj3860AO3EMloDh78Z
20 | N9H8goDcr7adMdZ4X2ByDKuhyP0WfaSZNud4o7K0v5ob1M1vAPNfi7KdwcEx7ycy
21 | xZQFa8GRZzNKXNGKrpr3+QABAoGBAIhl3dHyGImnuUXruKjPkrKGOtWkY7gqikGX
22 | uo710G38PQ94zJEpV9pIvictrhPKxEHuIrmxXdd/pEXVr+FxBUrLYHt3/8Yrn8Vx
23 | 3Swpcs81x1Y0PeT2aKL1Ia1xScEWXgoPNkbcNqGBJPUg4JUC8IdiylHmswTvK/up
24 | P5IAq9MBAoGAVMOChcBBOAwzsgaVWm9hAcdMfh9gjOF6cBnG2McigrG8RY6E3AYE
25 | /nK4+zOrYfHBo7GL0Gd5JxmNL8Nu2QzqCLdr5cLw8gLbwmfgahU6siJc5GQrOsTS
26 | TEklyqEaDQqnvMYVsLvza/W3K8g6GX8Xo/NOBjYeeFcRPSP94g56JaE=
27 | -----END RSA PRIVATE KEY-----
28 |
--------------------------------------------------------------------------------
/test/tls/testdata/ca.pem:
--------------------------------------------------------------------------------
1 | -----BEGIN CERTIFICATE-----
2 | MIIDhTCCAm2gAwIBAgITM8zKCwt+VQxi6YGDirBkR12XmDANBgkqhkiG9w0BAQsF
3 | ADBbMQswCQYDVQQGEwJTRTEYMBYGA1UECAwPU3RvY2tob2xtcyBMw6RuMQ4wDAYD
4 | VQQHDAVUw6RieTEQMA4GA1UEChMHVGFkYSBBQjEQMA4GA1UEAxMHdGFkYS5zZTAe
5 | Fw0yMDAzMjMyMDQzMDBaFw0yNTAzMjIyMDQzMDBaMFsxCzAJBgNVBAYTAlNFMRgw
6 | FgYDVQQIDA9TdG9ja2hvbG1zIEzDpG4xDjAMBgNVBAcMBVTDpGJ5MRAwDgYDVQQK
7 | EwdUYWRhIEFCMRAwDgYDVQQDEwd0YWRhLnNlMIIBIjANBgkqhkiG9w0BAQEFAAOC
8 | AQ8AMIIBCgKCAQEA8qmvXNCsGLhQeq5/XIZp0bLL8lOqvNQz86U5QQwETnxs1KDU
9 | DquhNhzd2lgvaP98zO8Owkqz/kXIT4R/RxIeZ1BD9gY9iVVMyZar/bI+72DpX4Hg
10 | x7+UC9m34CAfK/AAYGZIgjhpDBobkVaKjVahBWUF7BWon+1xu23XdS99lxThJCTv
11 | G5gfquGFf4eSvhGSt7SRzW5nVP+OpTKxNMMawZa15jh3U7BfrXxDeZfDKFxF2RaK
12 | 6Lwi+ATLLe9wwvEKvCvDT8XPKsm9RUl48pLKRMy73tWitSh9SSFoa5zd49qIteLt
13 | LPKbNWPJLUyVrJzvbv9hVVJe4CKd6zKH+mYQywIDAQABo0IwQDAOBgNVHQ8BAf8E
14 | BAMCAQYwDwYDVR0TAQH/BAUwAwEB/zAdBgNVHQ4EFgQUoZgRiuciAK1H9qeAXuSa
15 | hHINGdIwDQYJKoZIhvcNAQELBQADggEBAA0T6GlzDPE/aBggBPVvJtBNgzsqWuf/
16 | Z3osgtBq0T0Z5pe+1x+lYJaKADb/qgY7gpGuPVOuKJGvFiXAGknSYdR+PQwf7VNZ
17 | WTr89eR4oAS+qwF0hRTi93l8j5rZ1YgP1eUf9L2D6ycGydyoDJbLhIdCyEz9t89r
18 | ZnY5hj9xxA9dRJu3A9MAvBU9gGE0Nwh9BWp3//qrNo0ixgIZS1l6gm0CtBHRn1l/
19 | YwOsYsQv4hOlf4M1CsHIYEtnOX7MDzoDYU+B9JEBvJbNoqLDQQ0XLWlp31+tMImm
20 | N1DNkOEObT9oogVb59PlBb8aHwnfT+wZ56FyVbzrkFSOTjgOX/t427Q=
21 | -----END CERTIFICATE-----
22 |
--------------------------------------------------------------------------------
/test/tls/testdata/client-key.pem:
--------------------------------------------------------------------------------
1 | -----BEGIN RSA PRIVATE KEY-----
2 | MIIEpAIBAAKCAQEAwOhHDGCm6kR9w/UbCcZ25YlBNO8gzOV2OXPT4xfo1x/ADwe2
3 | znf6hoIC+RTe+cFWH2cy/vw3XUy1SqloB7bgh8a+APPBaHI+p8RElE4+zVmhGsXN
4 | y0atdn0hecsJN4z3DrwwC6LdkSnfu1pNg7oO6DEyyewC/wZ7NmLZw+E4R+Vo+/nn
5 | Vri5cHIuwcD7wkr+dgIBeU9G7+YcM2hAjj069EC9ur5lYN0AkwcBhFo+4RXmmChv
6 | yzowvbOtSXWElA7XDCSLzPMjYsZ14fiCpDv/w1Fx2JpYZbbT0HlHfmik2utlocuX
7 | S4q5Q44hcMrv60p//IjRXfEnw/bZVuNbXfHbkQIDAQABAoIBAE8owcbtfnERi/42
8 | fVLkkvOcABsFqZMK8hmfUyqULCLiz4AbbUOKbk512Vx22Qzp7jpSsdV6kAmEKbyq
9 | iZroy3hL8LoZTJtcjiNv4aht901y4y5GTy2EIjhGHs+Ipo6aFOOCC8EqovsnkLyj
10 | 0L0mQ2m6jpnXdF9MPJFTvQKpT9wIKz6HXdBIQZJF7CKVIlXhaWgjc9nmE/pdXfKQ
11 | GsNKLZP6JGs04VfQVkIKO9/6kVeRONVraepl9kv5Wis3qRTP44xZ58lxKqymcBgX
12 | 2HUQsGeLsBK2vOfk1RXc7SwJg8H7h03ftp1nYNUlJWnHA8EeCkJzv+eS6V1j2I9y
13 | 69KgLKUCgYEA92VBIesA8AY3GDqjBqgOH69X5TFUq8ssFbAOA6hhFnORTbBbyHRZ
14 | kEBT0A+tl0U30Z61sLrk2M6gEvTPkyy2rz6B28PkObwtmKOLNjO9eeGevhnnDR2T
15 | bOQ08HNZrBxPiydOK+d8pfVgvp4Xj3qWrHdvBLVMLajamn8ojQFtAKsCgYEAx53f
16 | 2CEdgDns+DHetm2K8aqabESInV3i3o/IuWVVCaqjnGzgqTwxNg/Z9zmJcxvxa8/g
17 | INs8hTLFX2a+8ATXA7azDMI3ffFU0sll/fulA6DLlFZVZnZuxArtnqWiY1X9cQyo
18 | O3N1BfunzqXXlSzmK3QjMQa0mTxPsn+MeTnXLLMCgYAfEjl+8Av7GVy8D0lAYcT8
19 | V8JbR7nRpb/QrX7lGLWw4yzhq/+rCmnhQyMDo6Rytj/PdPZuztpFHJZgKx0S5++9
20 | zMT0fALi+W5kmE24rgDjGOIeEBTDwe4tI/A+Ls6ZXijjWjloLDeshEf1SNe+rm/U
21 | E1//IGID7gwekU/ffclZ5wKBgQC8hg7taUEaZBq4wSistEJAQTa8v/EiZpQoTDVv
22 | WxN4IK+KwY1gZ9e2Tjw18CIvE5nrj5UGkufSiIO9uSTlPDzxZfAuQZL1ICJTPSBV
23 | Qf+zsH30Z6EaNwofno6Sga4fEQxeY2zTURSZhPYUBa7YVWJAcdv2pnWUL1C5rRq3
24 | NvhQXwKBgQCB+wilOcResE0TXBNTXnNvx2/TPGL++BePi9KE/bArX0DpurgwAZ6y
25 | hwEz6yHhtVkAecGuorE82YzPXvWVx3AEOtXnO8AeEneNVAnXX00DbNQLSg+RCenR
26 | cF2ec8KBjbGObndFv5zYZVaHprvz+RW/hCRY51ZwWopFhDkTnSu8bQ==
27 | -----END RSA PRIVATE KEY-----
28 |
--------------------------------------------------------------------------------
/test/tls/testdata/client.pem:
--------------------------------------------------------------------------------
1 | -----BEGIN CERTIFICATE-----
2 | MIIEADCCAuigAwIBAgIUadhpIhM+k/Twso977sUV0ktk32kwDQYJKoZIhvcNAQEL
3 | BQAwWzELMAkGA1UEBhMCU0UxGDAWBgNVBAgMD1N0b2NraG9sbXMgTMOkbjEOMAwG
4 | A1UEBwwFVMOkYnkxEDAOBgNVBAoTB1RhZGEgQUIxEDAOBgNVBAMTB3RhZGEuc2Uw
5 | HhcNMjAwMzIzMjA0MzAwWhcNMjEwMzIzMjA0MzAwWjBmMQswCQYDVQQGEwJTRTEY
6 | MBYGA1UECAwPU3RvY2tob2xtcyBMw6RuMQ4wDAYDVQQHDAVUw6RieTEQMA4GA1UE
7 | ChMHVGFkYSBBQjEbMBkGA1UEAxMSbXF0dGJyaWRnZS50YWRhLnNlMIIBIjANBgkq
8 | hkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAwOhHDGCm6kR9w/UbCcZ25YlBNO8gzOV2
9 | OXPT4xfo1x/ADwe2znf6hoIC+RTe+cFWH2cy/vw3XUy1SqloB7bgh8a+APPBaHI+
10 | p8RElE4+zVmhGsXNy0atdn0hecsJN4z3DrwwC6LdkSnfu1pNg7oO6DEyyewC/wZ7
11 | NmLZw+E4R+Vo+/nnVri5cHIuwcD7wkr+dgIBeU9G7+YcM2hAjj069EC9ur5lYN0A
12 | kwcBhFo+4RXmmChvyzowvbOtSXWElA7XDCSLzPMjYsZ14fiCpDv/w1Fx2JpYZbbT
13 | 0HlHfmik2utlocuXS4q5Q44hcMrv60p//IjRXfEnw/bZVuNbXfHbkQIDAQABo4Gw
14 | MIGtMA4GA1UdDwEB/wQEAwIFoDAdBgNVHSUEFjAUBggrBgEFBQcDAQYIKwYBBQUH
15 | AwIwDAYDVR0TAQH/BAIwADAdBgNVHQ4EFgQUcjlXBU52VMYDbGeK6rtGdD+gs4cw
16 | HwYDVR0jBBgwFoAUoZgRiuciAK1H9qeAXuSahHINGdIwLgYDVR0RBCcwJYIJbG9j
17 | YWxob3N0ggR0YWRhggxob21lLnRhZGEuc2WHBH8AAAEwDQYJKoZIhvcNAQELBQAD
18 | ggEBACXkSDwf8hb5yp80fhmH0yP7cdArEFFisdNvCpdHoWFmvs7O4NK2bkegXP/3
19 | KfMlEom7Ctr01wNdg8rmxspHyVCSTV4h5xywLpXkcrhHzaSAY0hKgxwQABtaUQ0R
20 | oJPtvAJs53M4e6IXuBjdHtWKn1eOdR7/WbnvlTUbmyf3eHin9imkYZHAfT5j3d1m
21 | QfR5K2oX+psrYFUUGkyqFBSDyKl2iSSiC16NyRZ96Ils37KkL55GKUsNDSGLevB+
22 | mTIH4tqVY1CY9PixX/MKl3h9V6GnUkVYKhdW9DgAdmiQUEaPmXvo+R8vRwH8Rfm7
23 | CGzBLfoPu9yiSWaDvCTBcWH4cmo=
24 | -----END CERTIFICATE-----
25 |
--------------------------------------------------------------------------------
/test/tls/testdata/server-key.pem:
--------------------------------------------------------------------------------
1 | -----BEGIN EC PRIVATE KEY-----
2 | MHcCAQEEINM8lQ1SOzLGSC0E1nVW+PUowZTs9LMgoB/gBwtyoV3loAoGCCqGSM49
3 | AwEHoUQDQgAEtd/igiJxP2uUk+lqijBxLQ/ChwN7M/tyznPHE60Xzx9V1UX3Tw9R
4 | 4DS4yy+vUW5ewoSVZ7xZ/oaOQURPg0lRkg==
5 | -----END EC PRIVATE KEY-----
6 |
--------------------------------------------------------------------------------
/test/tls/testdata/server.pem:
--------------------------------------------------------------------------------
1 | -----BEGIN CERTIFICATE-----
2 | MIIC1jCCAb6gAwIBAgIUAIeqiCsgQQ1qgTCB+Wkcpsq1+VIwDQYJKoZIhvcNAQEL
3 | BQAwWzELMAkGA1UEBhMCU0UxGDAWBgNVBAgMD1N0b2NraG9sbXMgTMOkbjEOMAwG
4 | A1UEBwwFVMOkYnkxEDAOBgNVBAoTB1RhZGEgQUIxEDAOBgNVBAMTB3RhZGEuc2Uw
5 | HhcNMjAwMzIzMjA0MzAwWhcNMjUwMzIyMjA0MzAwWjARMQ8wDQYDVQQDEwZTZXJ2
6 | ZXIwWTATBgcqhkjOPQIBBggqhkjOPQMBBwNCAAS13+KCInE/a5ST6WqKMHEtD8KH
7 | A3sz+3LOc8cTrRfPH1XVRfdPD1HgNLjLL69Rbl7ChJVnvFn+ho5BRE+DSVGSo4Gm
8 | MIGjMA4GA1UdDwEB/wQEAwIFoDATBgNVHSUEDDAKBggrBgEFBQcDATAMBgNVHRMB
9 | Af8EAjAAMB0GA1UdDgQWBBTXYQ46vfov5IKoKj1AELNM28uU/zAfBgNVHSMEGDAW
10 | gBShmBGK5yIArUf2p4Be5JqEcg0Z0jAuBgNVHREEJzAlgglsb2NhbGhvc3SCBHRh
11 | ZGGCDGhvbWUudGFkYS5zZYcEfwAAATANBgkqhkiG9w0BAQsFAAOCAQEAw+Cip5Zo
12 | TQdIhBz9Ejm3YYf9A0mch2Q1/fJ4ZpnoK2tBidB21SvSmnwuc7i925t0sJXzK1UR
13 | QSk49IXiVHBb2eIoyUyEffp1Jt9hvcqNKpanBWt0Y7Uta1gwNNhy20jhKWA4HkAJ
14 | D8GglyCRrHaoM8IF28V0wqD+scZjH1EjcNZzBJ0haUrjifoDWcrWruEW3b/KJ0eG
15 | Y3U+LkNzZvRyhpUUt5mCGcH3NfGPigd8XqP1LL8xLYkx0Lwd9McZLQv9EYLROZCE
16 | vRB6ZTSYcVAbGGTo1OrQed8y9c9IKmo8lVl9fMQgl6ICq3tu8mj0c+sc93wVfIy6
17 | 4tVLIbH2PcCM7Q==
18 | -----END CERTIFICATE-----
19 |
--------------------------------------------------------------------------------
/test/utils/checks.go:
--------------------------------------------------------------------------------
1 | // +build citest
2 |
3 | // Package utils contains convenient testing checkers that compare a produced
4 | // value against an expected value (or condition).
5 | // There are value checks like `CheckEqual(expected, produced, t)`, and
6 | // checks that should run deferred like `defer ShouldPanic(t)`.
7 | //
8 | package utils
9 |
10 | import (
11 | "reflect"
12 | "strings"
13 | "testing"
14 | "unsafe"
15 | )
16 |
// CheckEqual compares expected and got using reflect.DeepEqual and ends the
// test with t.Fatalf when they differ.
func CheckEqual(expected interface{}, got interface{}, t *testing.T) {
	t.Helper()
	if reflect.DeepEqual(expected, got) {
		return
	}
	t.Fatalf("Expected: %v, got %v", expected, got)
}
24 |
// isNil reports whether v is an untyped nil or a nil value of a kind that can
// be nil (chan, func, interface, map, pointer, slice or unsafe pointer).
// Values of all other kinds, such as ints, strings and structs, are never nil.
func isNil(v interface{}) bool {
	rf := reflect.ValueOf(v)
	if !rf.IsValid() {
		// reflect.ValueOf(nil) yields the zero Value
		return true
	}
	switch rf.Kind() {
	case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Ptr, reflect.Slice, reflect.UnsafePointer:
		return rf.IsNil()
	}
	// reflect.Value.IsNil panics for any other kind
	return false
}

// CheckNil checks if value is nil and calls t.Fatalf if not. A value whose
// kind cannot be nil (for instance an int or a string) is reported as a
// failure rather than causing a panic.
func CheckNil(got interface{}, t *testing.T) {
	t.Helper()
	if !isNil(got) {
		t.Fatalf("Expected: nil, got %v", got)
	}
}

// CheckNotNil checks if value is not nil and calls t.Fatalf if it is. A value
// whose kind cannot be nil (for instance an int or a string) always passes.
func CheckNotNil(got interface{}, t *testing.T) {
	t.Helper()
	if isNil(got) {
		t.Fatalf("Expected: not nil, got nil")
	}
}
42 |
// CheckError ends the test with t.Fatalf unless got is a non-nil error.
func CheckError(got error, t *testing.T) {
	t.Helper()
	if got != nil {
		return
	}
	t.Fatalf("Expected: error, got %v", got)
}
50 |
// CheckNotError ends the test with t.Fatalf if got is a non-nil error.
func CheckNotError(got error, t *testing.T) {
	t.Helper()
	if got == nil {
		return
	}
	t.Fatalf("Expected: no error, got %v", got)
}
58 |
// CheckTrue ends the test with t.Fatalf unless got is true.
func CheckTrue(got bool, t *testing.T) {
	t.Helper()
	if got {
		return
	}
	t.Fatalf("Expected: true, got %v", got)
}
66 |
// CheckFalse ends the test with t.Fatalf unless got is false.
func CheckFalse(got bool, t *testing.T) {
	t.Helper()
	if !got {
		return
	}
	t.Fatalf("Expected: false, got %v", got)
}
74 |
// EnsureFailed checks that the given test function fails with an Error or Fatal.
//
// The function f is run in its own goroutine against a private, zero valued
// testing.T. If f does not mark that testing.T as failed, then t is marked as
// failed. When substrings are given, each of them must be present in the
// output that f logged on the private testing.T, otherwise t.Fatalf is called.
func EnsureFailed(t *testing.T, f func(t *testing.T), substrings ...string) {
	t.Helper()
	tt := testing.T{}
	rs := reflect.ValueOf(&tt).Elem()
	done := make(chan bool, 1)
	go func() {
		// Fatal and FailNow terminate the goroutine by way of runtime.Goexit,
		// which still runs all deferred calls, so the signal is always sent.
		defer func() { done <- true }()
		f(&tt)
	}()
	<-done
	if !tt.Failed() {
		t.Error("expected the test function to fail but it did not")
	}
	if len(substrings) > 0 {
		// Pick the output bytes from the testing.T using an unsafe.Pointer.
		// NOTE(review): this reads the unexported fields "common" and "output"
		// of testing.T and therefore depends on the internals of the testing
		// package in the Go release in use.
		rf := rs.FieldByName("common").FieldByName("output")
		rf = reflect.NewAt(rf.Type(), unsafe.Pointer(rf.UnsafeAddr())).Elem()
		le := string(rf.Interface().([]byte))
		for _, ss := range substrings {
			if !strings.Contains(le, ss) {
				t.Fatalf("string %q does not contain %q", le, ss)
			}
		}
	}
}
100 |
--------------------------------------------------------------------------------
/test/utils/checks_test.go:
--------------------------------------------------------------------------------
1 | package utils
2 |
3 | import (
4 | "io"
5 | "testing"
6 | )
7 |
8 | func TestCheckEqual(t *testing.T) {
9 | EnsureFailed(t, func(ft *testing.T) {
10 | CheckEqual("a", "b", ft)
11 | })
12 | }
13 |
14 | func TestCheckNil(t *testing.T) {
15 | EnsureFailed(t, func(ft *testing.T) {
16 | CheckNil([]byte{0}, ft)
17 | })
18 | }
19 |
20 | func TestCheckNotNil(t *testing.T) {
21 | EnsureFailed(t, func(ft *testing.T) {
22 | CheckNotNil(nil, ft)
23 | })
24 | }
25 |
26 | func TestCheckError(t *testing.T) {
27 | EnsureFailed(t, func(ft *testing.T) {
28 | CheckError(nil, ft)
29 | })
30 | }
31 |
32 | func TestCheckNotError(t *testing.T) {
33 | EnsureFailed(t, func(ft *testing.T) {
34 | CheckNotError(io.ErrUnexpectedEOF, ft)
35 | })
36 | }
37 |
38 | func TestCheckTrue(t *testing.T) {
39 | EnsureFailed(t, func(ft *testing.T) {
40 | CheckTrue(false, ft)
41 | })
42 | }
43 |
44 | func TestCheckFalse(t *testing.T) {
45 | EnsureFailed(t, func(ft *testing.T) {
46 | CheckFalse(true, ft)
47 | })
48 | }
49 |
--------------------------------------------------------------------------------
/test/utils/logger.go:
--------------------------------------------------------------------------------
1 | // +build citest
2 |
3 | package utils
4 |
5 | import (
6 | "github.com/tada/mqtt-nats/logger"
7 | )
8 |
// The T interface is fulfilled by *testing.T but can be implemented by a T mock if needed.
// It contains the two methods that the test logger calls.
type T interface {
	// Helper is called by the logger before each call to Log.
	Helper()

	// Log receives the level label followed by the logged arguments.
	Log(...interface{})
}
14 |
15 | // NewLogger creates a new Logger configured to log at the given level. The logger uses the given
16 | // testing.T's Log function for all output.
17 | func NewLogger(l logger.Level, t T) logger.Logger {
18 | return &testLogger{l: l, t: t}
19 | }
20 |
21 | type testLogger struct {
22 | l logger.Level
23 | t T
24 | }
25 |
26 | func (t *testLogger) DebugEnabled() bool {
27 | return t.l >= logger.Debug
28 | }
29 |
30 | func (t *testLogger) Debug(args ...interface{}) {
31 | if t.DebugEnabled() {
32 | t.t.Helper()
33 | t.t.Log(addFirst("DEBUG", args)...)
34 | }
35 | }
36 |
37 | func (t *testLogger) ErrorEnabled() bool {
38 | return t.l >= logger.Error
39 | }
40 |
41 | func (t *testLogger) Error(args ...interface{}) {
42 | if t.ErrorEnabled() {
43 | t.t.Helper()
44 | t.t.Log(addFirst("ERROR", args)...)
45 | }
46 | }
47 |
48 | func (t *testLogger) InfoEnabled() bool {
49 | return t.l >= logger.Info
50 | }
51 |
52 | func (t *testLogger) Info(args ...interface{}) {
53 | if t.InfoEnabled() {
54 | t.t.Helper()
55 | t.t.Log(addFirst("INFO ", args)...)
56 | }
57 | }
58 |
// addFirst returns a new slice that holds first followed by all elements of
// args. The args slice itself is not modified.
func addFirst(first string, args []interface{}) []interface{} {
	combined := make([]interface{}, 1, len(args)+1)
	combined[0] = first
	return append(combined, args...)
}
66 |
--------------------------------------------------------------------------------
/test/utils/logger_test.go:
--------------------------------------------------------------------------------
1 | package utils
2 |
3 | import (
4 | "testing"
5 |
6 | "github.com/tada/mqtt-nats/logger"
7 | )
8 |
// mockT is a T implementation that records what is logged on it.
type mockT struct {
	// logEntries holds the argument list of each call to Log, in call order.
	logEntries [][]interface{}
}

// Helper does nothing. It exists to fulfill the T interface.
func (mt *mockT) Helper() {}

// Log appends its argument list as a new entry in logEntries.
func (mt *mockT) Log(args ...interface{}) {
	mt.logEntries = append(mt.logEntries, args)
}
19 |
20 | func Test_testLogger_Debug(t *testing.T) {
21 | tf := mockT{}
22 | lg := NewLogger(logger.Debug, &tf)
23 | CheckTrue(lg.DebugEnabled(), t)
24 | CheckTrue(lg.InfoEnabled(), t)
25 | CheckTrue(lg.ErrorEnabled(), t)
26 | lg.Debug("some stuff")
27 | CheckEqual(1, len(tf.logEntries), t)
28 | le := tf.logEntries[0]
29 | CheckEqual(2, len(le), t)
30 | CheckEqual("DEBUG", le[0], t)
31 | CheckEqual("some stuff", le[1], t)
32 | }
33 |
34 | func Test_testLogger_Info(t *testing.T) {
35 | tf := mockT{}
36 | lg := NewLogger(logger.Info, &tf)
37 | CheckFalse(lg.DebugEnabled(), t)
38 | CheckTrue(lg.InfoEnabled(), t)
39 | CheckTrue(lg.ErrorEnabled(), t)
40 | lg.Info("some stuff")
41 | CheckEqual(1, len(tf.logEntries), t)
42 | le := tf.logEntries[0]
43 | CheckEqual(2, len(le), t)
44 | CheckEqual("INFO ", le[0], t)
45 | CheckEqual("some stuff", le[1], t)
46 | }
47 |
48 | func Test_testLogger_Error(t *testing.T) {
49 | tf := mockT{}
50 | lg := NewLogger(logger.Error, &tf)
51 | CheckFalse(lg.DebugEnabled(), t)
52 | CheckFalse(lg.InfoEnabled(), t)
53 | CheckTrue(lg.ErrorEnabled(), t)
54 | lg.Error("some stuff")
55 | CheckEqual(1, len(tf.logEntries), t)
56 | le := tf.logEntries[0]
57 | CheckEqual(2, len(le), t)
58 | CheckEqual("ERROR", le[0], t)
59 | CheckEqual("some stuff", le[1], t)
60 | }
61 |
// TestClient_String is an empty placeholder; it asserts nothing.
// NOTE(review): nothing named Client is visible in this package, so this looks
// like a leftover stub. Confirm whether it can be removed or should be
// implemented.
func TestClient_String(t *testing.T) {

}
65 |
--------------------------------------------------------------------------------
/test/utils/panics.go:
--------------------------------------------------------------------------------
1 | // +build citest
2 |
3 | package utils
4 |
5 | import "testing"
6 |
// ShouldNotPanic is used to assert that a function does not panic.
// Usage: defer utils.ShouldNotPanic(t) at the point where the rest is expected to not panic.
//
// A panic is recovered and reported with t.Errorf together with the recovered
// value. The function must be deferred directly, since recover only has an
// effect when called by the deferred function itself.
func ShouldNotPanic(t *testing.T) {
	t.Helper()
	if r := recover(); r != nil {
		// include the recovered value so that the failure can be diagnosed
		t.Errorf("Unexpected panic: %v", r)
	}
}
15 |
// ShouldPanic is used to assert that a function does panic.
// Usage: defer utils.ShouldPanic(t) at the point where the rest is expected to panic.
//
// A panic is recovered and swallowed. When there is nothing to recover, an
// error is reported on t.
func ShouldPanic(t *testing.T) {
	t.Helper()
	if recover() != nil {
		return
	}
	t.Error("Expected panic but got none")
}
24 |
--------------------------------------------------------------------------------
/test/utils/panics_test.go:
--------------------------------------------------------------------------------
1 | package utils
2 |
3 | import (
4 | "errors"
5 | "testing"
6 | )
7 |
8 | func TestShouldNotPanic(t *testing.T) {
9 | EnsureFailed(t, func(ft *testing.T) {
10 | defer ShouldNotPanic(ft)
11 | panic(errors.New("but it did"))
12 | })
13 | }
14 |
15 | func TestShouldPanic(t *testing.T) {
16 | EnsureFailed(t, func(ft *testing.T) {
17 | defer ShouldPanic(ft)
18 | // but didn't
19 | })
20 | }
21 |
--------------------------------------------------------------------------------