├── .github └── workflows │ └── mqtt-nats_test.yml ├── .gitignore ├── .golangci.yml ├── .idea ├── .gitignore ├── codeStyles │ └── codeStyleConfig.xml ├── inspectionProfiles │ └── Project_Default.xml ├── misc.xml ├── modules.xml ├── mqtt-nats.iml ├── vcs.xml └── watcherTasks.xml ├── LICENSE ├── README.md ├── bridge ├── client.go ├── client_test.go ├── nats.go ├── natspub.go ├── options.go ├── package.go ├── replytopic.go ├── retained.go ├── server.go └── session.go ├── cli ├── mqtt-nats.go └── package.go ├── examples ├── certs │ ├── .gitignore │ ├── ca.json │ ├── config.json │ ├── csr.json │ ├── generate.sh │ └── server.json ├── jmeter │ └── MQTT Pub Sampler.jmx ├── server.conf └── tools │ ├── nats-pub-repeat │ └── nats-pub-repeat.go │ └── nats-sub-reply │ └── nats-sub-reply.go ├── go.mod ├── go.sum ├── logger ├── logger.go └── logger_test.go ├── main.go ├── mqtt ├── package.go ├── pkg │ ├── connect.go │ ├── connect_test.go │ ├── credentials.go │ ├── idmanager.go │ ├── idmanager_test.go │ ├── packet.go │ ├── packet_test.go │ ├── ping.go │ ├── ping_test.go │ ├── publish.go │ ├── publish_test.go │ ├── subscribe.go │ ├── subscribe_test.go │ ├── unsubscribe.go │ ├── unsubscribe_test.go │ └── will.go ├── reader.go ├── reader_test.go ├── topic.go ├── topic_test.go ├── writer.go └── writer_test.go └── test ├── connect_test.go ├── full ├── bridge.go ├── client.go └── nats.go ├── main_test.go ├── mock ├── connection.go └── connection_test.go ├── package.go ├── packet ├── parse.go └── parse_test.go ├── publish_test.go ├── retained_test.go ├── tls ├── connect_test.go ├── main_test.go ├── package.go ├── server.conf └── testdata │ ├── ca-key.pem │ ├── ca.pem │ ├── client-key.pem │ ├── client.pem │ ├── server-key.pem │ └── server.pem └── utils ├── checks.go ├── checks_test.go ├── logger.go ├── logger_test.go ├── panics.go └── panics_test.go /.github/workflows/mqtt-nats_test.yml: -------------------------------------------------------------------------------- 1 | name: MQTT-NATS 
Test 2 | on: [push, pull_request] 3 | jobs: 4 | 5 | test: 6 | name: Test Linux 7 | runs-on: ubuntu-latest 8 | steps: 9 | 10 | - name: Set up Go 1.14 11 | uses: actions/setup-go@v1 12 | with: 13 | go-version: 1.14 14 | id: go 15 | 16 | - name: Check out code into the Go module directory 17 | uses: actions/checkout@v1 18 | 19 | - name: Set up GolangCI-Lint 20 | run: curl -sfL https://install.goreleaser.com/github.com/golangci/golangci-lint.sh | sh -s -- latest 21 | 22 | - name: Lint 23 | run: ./bin/golangci-lint run ./... 24 | 25 | - name: Test 26 | run: go test -v -tags=citest -covermode=atomic -coverpkg=$(go list ./... | grep -v 'nats/test' | tr '\n' ,) -coverprofile coverage.tmp ./... 27 | 28 | # - name: Test Coverage Check 29 | # run: | 30 | # COV=$(go tool cover -func=coverage.tmp | grep -e '^total:\s*(statements)' | awk '{ print $3 }') 31 | # test $COV = '100.0%' || (echo "Expected 100% test coverage, got $COV" && exit 1) 32 | 33 | - name: Send coverage 34 | env: 35 | COVERALLS_TOKEN: ${{ secrets.GITHUB_TOKEN }} 36 | run: | 37 | go get github.com/mattn/goveralls@master 38 | $(go env GOPATH)/bin/goveralls -tags=citest -coverprofile=coverage.tmp -service=github 39 | 40 | test-windows: 41 | name: Test Windows 42 | runs-on: windows-latest 43 | steps: 44 | 45 | - name: Set up Go 1.13 46 | uses: actions/setup-go@v1 47 | with: 48 | go-version: 1.13 49 | id: go 50 | 51 | - name: Check out code into the Go module directory 52 | uses: actions/checkout@v1 53 | 54 | - name: Test 55 | run: go test -v -tags=citest ./... 
56 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .idea/**/workspace.xml 2 | .idea/**/tasks.xml 3 | .idea/**/dictionaries 4 | .idea/**/shelf 5 | 6 | # Bridge persistence file 7 | mqtt-nats.json 8 | 9 | # Binaries for programs and plugins 10 | *.exe 11 | *.exe~ 12 | *.dll 13 | *.so 14 | *.dylib 15 | 16 | # Test binary, build with `go test -c` 17 | *.test 18 | 19 | # Output of the go coverage tool, specifically when used with LiteIDE 20 | *.out 21 | *.tmp 22 | -------------------------------------------------------------------------------- /.golangci.yml: -------------------------------------------------------------------------------- 1 | run: 2 | build-tags: 3 | - citest 4 | 5 | issues: 6 | exclude-use-default: false 7 | exclude-rules: 8 | - path: _test\.go 9 | linters: 10 | - const 11 | - dupl 12 | - gochecknoglobals 13 | - goconst 14 | - golint 15 | - unparam 16 | 17 | linters-settings: 18 | gocyclo: 19 | min-complexity: 35 20 | 21 | gocognit: 22 | min-complexity: 60 23 | 24 | lll: 25 | line-length: 140 26 | tab-width: 2 27 | 28 | misspell: 29 | ignore-words: 30 | - mosquitto 31 | 32 | linters: 33 | disable-all: true 34 | enable: 35 | - bodyclose 36 | - deadcode 37 | - depguard 38 | - dogsled 39 | - dupl 40 | - errcheck 41 | - gochecknoglobals 42 | - gochecknoinits 43 | - gocognit 44 | - goconst 45 | - gocritic 46 | - gocyclo 47 | - gofmt 48 | - goimports 49 | - golint 50 | - gosimple 51 | - govet 52 | - ineffassign 53 | - interfacer 54 | - lll 55 | - maligned 56 | - misspell 57 | - nakedret 58 | - prealloc 59 | - scopelint 60 | - structcheck 61 | - staticcheck 62 | - stylecheck 63 | - typecheck 64 | - unconvert 65 | - unused 66 | - varcheck 67 | - whitespace 68 | - unparam 69 | 70 | # don't enable: 71 | # - funlen 72 | # - godox 73 | # - gosec 74 | -------------------------------------------------------------------------------- 
/.idea/.gitignore: -------------------------------------------------------------------------------- 1 | # Default ignored files 2 | /workspace.xml -------------------------------------------------------------------------------- /.idea/codeStyles/codeStyleConfig.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 5 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/Project_Default.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 29 | -------------------------------------------------------------------------------- /.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 39 | 40 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /.idea/mqtt-nats.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | -------------------------------------------------------------------------------- /.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | -------------------------------------------------------------------------------- /.idea/watcherTasks.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 16 | 28 | 29 | -------------------------------------------------------------------------------- /README.md: 
-------------------------------------------------------------------------------- 1 | The mqtt-nats bridge enables [MQTT](http://mqtt.org/) devices to publish and subscribe to the [NATS](https://nats.io) communication system. 2 | 3 | [![](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://opensource.org/licenses/Apache-2.0) 4 | [![](https://goreportcard.com/badge/github.com/tada/mqtt-nats)](https://goreportcard.com/report/github.com/tada/mqtt-nats) 5 | [![](https://img.shields.io/badge/godoc-reference-blue.svg)](https://godoc.org/github.com/tada/mqtt-nats) 6 | [![](https://github.com/tada/mqtt-nats/workflows/MQTT-NATS%20Test/badge.svg)](https://github.com/tada/mqtt-nats/actions) 7 | [![](https://coveralls.io/repos/github/tada/mqtt-nats/badge.svg?service=github)](https://coveralls.io/github/tada/mqtt-nats) 8 | 9 | ## Project status 10 | Currently under development. 11 | 12 | ## Getting started 13 | 14 | ### Obtaining and running the bridge 15 | You can install the binary using `go get`: 16 | ``` 17 | go get github.com/tada/mqtt-nats 18 | ``` 19 | If you simply run the command `mqtt-nats` without options, it will start an MQTT server on port 1883 that attempts to connect 20 | to a NATS server using the default URL "nats://127.0.0.1:4222". Use: 21 | ``` 22 | mqtt-nats -help 23 | ``` 24 | to get a list of all configurable options. 25 | 26 | ### Run the tests 27 | The test utilities within this codebase are tagged with the special build tag "citest". This tag is required 28 | for most of the tests to build and run, so to run all tests, use: 29 | ``` 30 | go test -tags citest ./... 31 | ``` 32 | 33 | ## Current limitations 34 | - Only MQTT 3.1.1 is supported 35 | - Only QoS levels 0 (at most once) and 1 (at least once) are supported 36 | - The bridge has no way of knowing when new subscriptions are added in the NATS network and hence cannot send retained 37 | messages in response to such subscriptions.
38 | 39 | ## Solutions and workarounds 40 | 41 | ### QoS level 1 is accomplished using the NATS reply-to subject. 42 | When the bridge receives an MQTT publish with QoS = 1 from a client, it forwards that to NATS with a reply-to subject. 43 | A PUBACK is sent to the MQTT client when the reply arrives. Similarly, if an MQTT client subscribes using desired QoS 44 | = 1, then a NATS publish with a reply-to will be considered in need of a PUBACK from the MQTT client. 45 | 46 | ### Retain request from NATS 47 | An MQTT client that subscribes to a topic will immediately receive all retained messages for that topic. The same is 48 | not true for a NATS client simply because the bridge has no way of knowing when a NATS client subscribes to a topic. To 49 | mitigate this, the bridge subscribes to a specific "retained request" topic (configurable option). A NATS client that 50 | publishes to this topic with a subscription string as the payload and a reply-to inbox, will get a JSON encoded reply 51 | containing all messages that matches the subcription string. 52 | -------------------------------------------------------------------------------- /bridge/client.go: -------------------------------------------------------------------------------- 1 | package bridge 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "io" 7 | "net" 8 | "sync" 9 | "time" 10 | 11 | "github.com/nats-io/nats.go" 12 | "github.com/tada/mqtt-nats/logger" 13 | "github.com/tada/mqtt-nats/mqtt" 14 | "github.com/tada/mqtt-nats/mqtt/pkg" 15 | ) 16 | 17 | const ( 18 | // StateInfant is set when the client is created and awaits a Connect packet 19 | StateInfant = byte(iota) 20 | 21 | // StateConnected is set once a successful connection has been established 22 | StateConnected 23 | 24 | // StateDisconnected is set when a disconnect packet arrives or when a non recoverable error occurs. 25 | StateDisconnected 26 | ) 27 | 28 | // A Client represents a connection from a client. 
29 | type Client interface { 30 | // Serve starts the read and write loops and then waits for them to finish which 31 | // normally happens after the receipt of a disconnect packet 32 | Serve() 33 | 34 | // State returns the current client state 35 | State() byte 36 | 37 | // PublishResponse publishes a packet to the client in response to a subscription 38 | PublishResponse(qos byte, pp *pkg.Publish) 39 | 40 | // SetDisconnected will end the read and write loop and eventually cause Serve() to end. 41 | SetDisconnected(error) 42 | } 43 | 44 | type client struct { 45 | server Server 46 | log logger.Logger 47 | mqttConn net.Conn 48 | natsConn *nats.Conn 49 | session Session 50 | connectPacket *pkg.Connect 51 | err error 52 | maxWait time.Duration 53 | natsSubs map[string]*nats.Subscription 54 | writeQueue chan pkg.Packet 55 | stLock sync.RWMutex 56 | subLock sync.Mutex 57 | workers sync.WaitGroup 58 | sessionPresent bool 59 | st byte 60 | } 61 | 62 | // TODO: This should probably be configurable. 63 | const writeQueueSize = 1024 64 | 65 | // NewClient returns a new Client instance with StateInfant state. 
66 | func NewClient(s Server, log logger.Logger, conn net.Conn) Client { 67 | return &client{ 68 | server: s, 69 | log: log, 70 | mqttConn: conn, 71 | natsSubs: make(map[string]*nats.Subscription), 72 | st: StateInfant, 73 | writeQueue: make(chan pkg.Packet, writeQueueSize)} 74 | } 75 | 76 | func (c *client) Serve() { 77 | defer func() { 78 | _ = c.mqttConn.Close() 79 | 80 | if c.natsConn != nil { 81 | c.natsConn.Close() 82 | } 83 | }() 84 | 85 | c.workers.Add(2) 86 | go c.readLoop() 87 | go c.writeLoop() 88 | c.workers.Wait() 89 | 90 | cp := c.connectPacket 91 | if c.err != nil { 92 | c.Error(c.err) 93 | } 94 | if cp == nil { 95 | // No connection was established 96 | c.Debug("client connection could not be established") 97 | } else { 98 | c.Debug("disconnected") 99 | if cp.CleanSession() { 100 | c.server.SessionManager().Remove(cp.ClientID()) 101 | c.Debug("session removed") 102 | } 103 | } 104 | } 105 | 106 | // String returns a text suitable for logging of client messages. 107 | func (c *client) String() string { 108 | switch c.State() { 109 | case StateInfant: 110 | return "Client (not yet connected)" 111 | case StateConnected: 112 | return "Client " + c.connectPacket.ClientID() 113 | default: 114 | return "Client " + c.connectPacket.ClientID() + " (disconnected)" 115 | } 116 | } 117 | 118 | func (c *client) State() byte { 119 | var s byte 120 | c.stLock.RLock() 121 | s = c.st 122 | c.stLock.RUnlock() 123 | return s 124 | } 125 | 126 | func (c *client) StateAndMaxWait() (byte, time.Duration) { 127 | c.stLock.RLock() 128 | s := c.st 129 | m := c.maxWait 130 | c.stLock.RUnlock() 131 | return s, m 132 | } 133 | 134 | func (c *client) setState(newState byte) { 135 | c.stLock.Lock() 136 | c.st = newState 137 | c.stLock.Unlock() 138 | } 139 | 140 | func (c *client) setStateAndMaxWait(newState byte, maxWait time.Duration) { 141 | c.stLock.Lock() 142 | c.st = newState 143 | c.maxWait = maxWait 144 | c.stLock.Unlock() 145 | } 146 | 147 | func (c *client) 
SetDisconnected(err error) { 148 | doit := false 149 | c.stLock.Lock() 150 | if c.st != StateDisconnected { 151 | doit = true 152 | c.st = StateDisconnected 153 | c.maxWait = time.Millisecond 154 | } 155 | c.stLock.Unlock() 156 | 157 | if doit { 158 | if cp := c.connectPacket; cp != nil && cp.HasWill() { 159 | err := c.server.PublishWill(cp.Will(), cp.Credentials()) 160 | if err != nil { 161 | c.Error(err) 162 | } else { 163 | c.Debug("will published to", cp.Will().Topic) 164 | } 165 | } 166 | // This packet will not be sent but it will terminate the write loop once everything else 167 | // has been flushed 168 | c.writeQueue <- pkg.DisconnectSingleton 169 | 170 | if err == io.EOF { 171 | err = io.ErrUnexpectedEOF 172 | } 173 | c.err = err 174 | 175 | // release reader block 176 | _ = c.mqttConn.SetReadDeadline(time.Now().Add(time.Millisecond)) 177 | } 178 | } 179 | 180 | func (c *client) Debug(args ...interface{}) { 181 | if c.log.DebugEnabled() { 182 | c.log.Debug(c.addFirst(args)...) 183 | } 184 | } 185 | 186 | func (c *client) Error(args ...interface{}) { 187 | if c.log.ErrorEnabled() { 188 | c.log.Error(c.addFirst(args)...) 
189 | } 190 | } 191 | 192 | // addFirst prepends the client to the args slice and returns the new slice 193 | func (c *client) addFirst(args []interface{}) []interface{} { 194 | na := make([]interface{}, len(args)+1) 195 | na[0] = c 196 | copy(na[1:], args) 197 | return na 198 | } 199 | 200 | func (c *client) readLoop() { 201 | defer c.workers.Done() 202 | 203 | r := mqtt.NewReader(c.mqttConn) 204 | 205 | var err error 206 | 207 | readNextPacket: 208 | for st, maxWait := c.StateAndMaxWait(); st != StateDisconnected && err == nil; st, maxWait = c.StateAndMaxWait() { 209 | var ( 210 | b byte 211 | rl int 212 | ) 213 | 214 | if maxWait > 0 { 215 | _ = c.mqttConn.SetReadDeadline(time.Now().Add(maxWait)) 216 | } 217 | 218 | // Read packet type and flags 219 | if b, err = r.ReadByte(); err != nil { 220 | break 221 | } 222 | 223 | pkgType := b & pkg.TpMask 224 | switch st { 225 | case StateConnected: 226 | if pkgType == pkg.TpConnect { 227 | err = errors.New("second connect packet") 228 | break readNextPacket 229 | } 230 | case StateInfant: 231 | if pkgType != pkg.TpConnect { 232 | err = errors.New("not connected") 233 | break readNextPacket 234 | } 235 | } 236 | 237 | // Read packet length 238 | if rl, err = r.ReadVarInt(); err != nil { 239 | break 240 | } 241 | 242 | var p pkg.Packet 243 | switch pkgType { 244 | case pkg.TpDisconnect: 245 | // Normal disconnect 246 | // Discard will 247 | c.Debug("received", pkg.DisconnectSingleton) 248 | c.connectPacket.DeleteWill() 249 | break readNextPacket 250 | case pkg.TpPing: 251 | pr := pkg.PingRequestSingleton 252 | c.Debug("received", pr) 253 | c.queueForWrite(pkg.PingResponseSingleton) 254 | case pkg.TpConnect: 255 | if p, err = pkg.ParseConnect(r, b, rl); err == nil { 256 | c.Debug("received", p) 257 | err = c.handleConnect(p.(*pkg.Connect)) 258 | if err == nil { 259 | c.server.ManageClient(c) 260 | } 261 | } 262 | if retCode, ok := err.(pkg.ReturnCode); ok { 263 | c.Debug("received", p, "return code", retCode) 264 | 
c.setState(StateConnected) 265 | c.queueForWrite(pkg.NewConnAck(c.sessionPresent, retCode)) 266 | } 267 | case pkg.TpPublish: 268 | if p, err = pkg.ParsePublish(r, b, rl); err == nil { 269 | c.Debug("received", p) 270 | err = c.natsPublish(c.server.HandleRetain(p.(*pkg.Publish))) 271 | } 272 | case pkg.TpPubAck: 273 | if p, err = pkg.ParsePubAck(r, b, rl); err == nil { 274 | c.Debug("received", p) 275 | id := p.(pkg.PubAck).ID() 276 | c.session.ClientAckReceived(id, c.natsConn) 277 | c.server.ReleasePacketID(id) 278 | } 279 | case pkg.TpPubRec: 280 | if p, err = pkg.ParsePubRec(r, b, rl); err == nil { 281 | c.Debug("received", p) 282 | // TODO: handle PubRec 283 | } 284 | case pkg.TpPubRel: 285 | if p, err = pkg.ParsePubRel(r, b, rl); err == nil { 286 | c.Debug("received", p) 287 | // TODO: handle PubRel 288 | } 289 | case pkg.TpPubComp: 290 | if p, err = pkg.ParsePubComp(r, b, rl); err == nil { 291 | c.Debug("received", p) 292 | // TODO: handle PubComp 293 | } 294 | case pkg.TpSubscribe: 295 | if p, err = pkg.ParseSubscribe(r, b, rl); err == nil { 296 | c.Debug("received", p) 297 | sp := p.(*pkg.Subscribe) 298 | c.natsSubscribe(sp) 299 | c.server.PublishMatching(sp, c) 300 | } 301 | case pkg.TpUnsubscribe: 302 | if p, err = pkg.ParseUnsubscribe(r, b, rl); err == nil { 303 | c.Debug("received", p) 304 | c.natsUnsubscribe(p.(*pkg.Unsubscribe)) 305 | } 306 | default: 307 | err = fmt.Errorf("received unknown packet type %d", (b&pkg.TpMask)>>4) 308 | } 309 | } 310 | c.SetDisconnected(err) 311 | } 312 | 313 | func (c *client) handleConnect(cp *pkg.Connect) error { 314 | var err error 315 | c.connectPacket = cp 316 | c.natsConn, err = c.server.NatsConn(cp.Credentials()) 317 | if err != nil { 318 | // TODO: Different error codes depending on error from NATS 319 | c.Error("NATS connect failed", err) 320 | return pkg.RtServerUnavailable 321 | } 322 | 323 | cid := cp.ClientID() 324 | m := c.server.SessionManager() 325 | c.sessionPresent = false 326 | 327 | if 
cp.CleanSession() { 328 | c.session = m.Create(cid) 329 | } else { 330 | if s := m.Get(cid); s != nil { 331 | c.session = s 332 | c.sessionPresent = true 333 | } else { 334 | c.session = m.Create(cid) 335 | } 336 | } 337 | 338 | var maxWait time.Duration 339 | if cp.KeepAlive() > 0 { 340 | // Max wait between control packets is 1.5 times the keep alive value 341 | maxWait = (cp.KeepAlive() * 3) / 2 342 | } 343 | c.setStateAndMaxWait(StateConnected, maxWait) 344 | c.queueForWrite(pkg.NewConnAck(c.sessionPresent, 0)) 345 | 346 | if cp.CleanSession() { 347 | c.Debug("connected with clean session") 348 | } else { 349 | if c.sessionPresent { 350 | c.Debug("connected using preexisting session") 351 | c.session.RestoreAckSubscriptions(c) 352 | c.session.ResendClientUnack(c) 353 | } else { 354 | c.Debug("connected using new (unclean) session") 355 | } 356 | } 357 | return nil 358 | } 359 | 360 | func (c *client) queueForWrite(p pkg.Packet) { 361 | if c.State() == StateConnected { 362 | c.writeQueue <- p 363 | } 364 | } 365 | 366 | func (c *client) writeLoop() { 367 | defer c.workers.Done() 368 | 369 | bulk := make([]pkg.Packet, writeQueueSize) 370 | 371 | // writer's buffer is reused for each bulk operation 372 | w := mqtt.NewWriter() 373 | 374 | // Each iteration of this loop with pick max writeQueueSize packets from the writeQueue 375 | // and then write those packets on mqtt.Writer (a bytes.Buffer extension). The resulting bytes 376 | // are then written to the connection using one single write on the connection. 
377 | for connected := true; connected; { 378 | bulk[0] = <-c.writeQueue 379 | i := 1 380 | inner: 381 | for ; i < writeQueueSize; i++ { 382 | select { 383 | case p := <-c.writeQueue: 384 | bulk[i] = p 385 | default: 386 | break inner 387 | } 388 | } 389 | w.Reset() 390 | 391 | for n := 0; n < i; n++ { 392 | p := bulk[n] 393 | if p == pkg.DisconnectSingleton { 394 | connected = false 395 | break 396 | } 397 | c.Debug("sending", p) 398 | p.Write(w) 399 | } 400 | bs := w.Bytes() 401 | if len(bs) > 0 { 402 | if _, err := c.mqttConn.Write(bs); err != nil { 403 | if connected { 404 | c.SetDisconnected(err) 405 | } else { 406 | // Drain failed. Log the error 407 | c.Error(err) 408 | } 409 | break 410 | } 411 | } 412 | } 413 | } 414 | 415 | func (c *client) PublishResponse(qos byte, pp *pkg.Publish) { 416 | if qos > 0 { 417 | c.session.ClientAckRequested(pp) 418 | } 419 | c.queueForWrite(pp) 420 | } 421 | -------------------------------------------------------------------------------- /bridge/client_test.go: -------------------------------------------------------------------------------- 1 | package bridge 2 | 3 | import ( 4 | "encoding/json" 5 | "errors" 6 | "fmt" 7 | "io" 8 | "testing" 9 | "time" 10 | 11 | "github.com/tada/mqtt-nats/test/packet" 12 | 13 | "github.com/nats-io/nats.go" 14 | "github.com/tada/mqtt-nats/logger" 15 | "github.com/tada/mqtt-nats/mqtt" 16 | "github.com/tada/mqtt-nats/mqtt/pkg" 17 | "github.com/tada/mqtt-nats/test/mock" 18 | "github.com/tada/mqtt-nats/test/utils" 19 | ) 20 | 21 | type mockServer struct { 22 | sm 23 | pkg.IDManager 24 | nc *nats.Conn 25 | ncError error 26 | willError error 27 | t *testing.T 28 | } 29 | 30 | func (m *mockServer) UnmarshalFromJSON(js *json.Decoder, firstToken json.Token) { 31 | m.t.Helper() 32 | m.t.Fatal("implement me") 33 | } 34 | 35 | func (m *mockServer) MarshalToJSON(io.Writer) { 36 | m.t.Helper() 37 | m.t.Fatal("implement me") 38 | } 39 | 40 | func (m *mockServer) SessionManager() SessionManager { 41 | return 
m 42 | } 43 | 44 | func (m *mockServer) ManageClient(c Client) { 45 | } 46 | 47 | func (m *mockServer) NatsConn(creds *pkg.Credentials) (*nats.Conn, error) { 48 | return m.nc, m.ncError 49 | } 50 | 51 | func (m *mockServer) HandleRetain(pp *pkg.Publish) *pkg.Publish { 52 | return pp 53 | } 54 | 55 | func (m *mockServer) PublishMatching(sp *pkg.Subscribe, c Client) { 56 | } 57 | 58 | func (m *mockServer) PublishWill(will *pkg.Will, creds *pkg.Credentials) error { 59 | return m.willError 60 | } 61 | 62 | func newMockServer(t *testing.T) *mockServer { 63 | return &mockServer{sm: sm{m: make(map[string]Session, 3)}, IDManager: pkg.NewIDManager(), t: t} 64 | } 65 | 66 | func writePacket(t *testing.T, p pkg.Packet, w io.Writer) { 67 | t.Helper() 68 | mw := mqtt.NewWriter() 69 | p.Write(mw) 70 | _, err := w.Write(mw.Bytes()) 71 | utils.CheckNotError(err, t) 72 | } 73 | 74 | var silent = logger.New(logger.Silent, nil, nil) 75 | 76 | // Test_client_String tests that the client String method produces sane output 77 | // in all states of the client (infant, connected, disconnected) 78 | func Test_client_String(t *testing.T) { 79 | conn := mock.NewConnection() 80 | cl := NewClient(newMockServer(t), silent, conn) 81 | done := make(chan bool, 1) 82 | go func() { 83 | cl.Serve() 84 | done <- true 85 | }() 86 | 87 | utils.CheckEqual("Client (not yet connected)", cl.(fmt.Stringer).String(), t) 88 | 89 | rConn := conn.Remote() 90 | writePacket(t, pkg.NewConnect("client-id", false, 1, nil, nil), rConn) 91 | bs := make([]byte, 2) 92 | _, err := rConn.Read(bs) 93 | utils.CheckNotError(err, t) 94 | utils.CheckEqual("Client client-id", cl.(fmt.Stringer).String(), t) 95 | writePacket(t, pkg.DisconnectSingleton, rConn) 96 | <-done 97 | utils.CheckEqual("Client client-id (disconnected)", cl.(fmt.Stringer).String(), t) 98 | } 99 | 100 | // Test_client_natsConnError tests that the server responds with a ConnAck containing 101 | // an pkg.RtServerUnavailable when the client was unable to 
establish a NATS connection. 102 | func Test_client_natsConnError(t *testing.T) { 103 | conn := mock.NewConnection() 104 | rConn := conn.Remote() 105 | ms := newMockServer(t) 106 | ms.ncError = errors.New("unauthorized") 107 | cl := NewClient(ms, silent, conn) 108 | go cl.Serve() 109 | 110 | writePacket(t, pkg.NewConnect("client-id", false, 1, nil, nil), rConn) 111 | ca, ok := packet.Parse(t, rConn).(*pkg.ConnAck) 112 | utils.CheckTrue(ok, t) 113 | utils.CheckEqual(pkg.RtServerUnavailable, ca.ReturnCode(), t) 114 | } 115 | 116 | // Test_client_natsConnError tests that the server responds with a ConnAck containing 117 | // an pkg.RtServerUnavailable when the client was unable to establish a NATS connection. 118 | func Test_client_natsSubscribeError(t *testing.T) { 119 | mt := &collectLogsT{} 120 | conn := mock.NewConnection() 121 | ms := newMockServer(t) 122 | ms.ncError = nil 123 | ms.nc = &nats.Conn{} 124 | cl := NewClient(ms, utils.NewLogger(logger.Error, mt), conn) 125 | go cl.Serve() 126 | 127 | rConn := conn.Remote() 128 | writePacket(t, pkg.NewConnect("client-id", true, 1, nil, nil), rConn) 129 | ca, ok := packet.Parse(t, rConn).(*pkg.ConnAck) 130 | utils.CheckTrue(ok, t) 131 | utils.CheckEqual(pkg.RtAccepted, ca.ReturnCode(), t) 132 | 133 | // Newline is unacceptable in a subject 134 | writePacket(t, pkg.NewSubscribe(1, pkg.Topic{Name: "top\nic"}), rConn) 135 | sa, ok := packet.Parse(t, rConn).(*pkg.SubAck) 136 | utils.CheckTrue(ok, t) 137 | 138 | // Topic return code should be 0x80 to indicate failure 139 | utils.CheckEqual(pkg.NewSubAck(1, 0x80), sa, t) 140 | 141 | // At least one error should be logged (additional caused by forced disconnect) 142 | utils.CheckTrue(len(mt.logEntries) > 0, t) 143 | el := mt.logEntries[0] 144 | utils.CheckEqual(5, len(el), t) 145 | utils.CheckEqual(el[0], "ERROR", t) 146 | utils.CheckTrue(cl == el[1], t) 147 | utils.CheckEqual("NATS subscribe", el[2], t) 148 | utils.CheckEqual("top\nic", el[3], t) 149 | } 150 | 151 | type 
collectLogsT struct { 152 | logEntries [][]interface{} 153 | } 154 | 155 | func (m *collectLogsT) Log(args ...interface{}) { 156 | m.logEntries = append(m.logEntries, args) 157 | } 158 | 159 | func (m *collectLogsT) Helper() { 160 | } 161 | 162 | // Test_client_publishWillError tests that errors during an attempt to publish the will 163 | // provided in the CONNECT package are logged at level logger.Error 164 | func Test_client_publishWillError(t *testing.T) { 165 | mt := &collectLogsT{} 166 | conn := mock.NewConnection() 167 | ms := newMockServer(t) 168 | ms.willError = errors.New("unauthorized") 169 | cl := NewClient(ms, utils.NewLogger(logger.Error, mt), conn) 170 | 171 | done := make(chan bool, 1) 172 | go func() { 173 | cl.Serve() 174 | done <- true 175 | }() 176 | 177 | rConn := conn.Remote() 178 | writePacket(t, pkg.NewConnect("client-id", false, 1, &pkg.Will{ 179 | Topic: "some/will", 180 | Message: []byte("will message")}, nil), rConn) 181 | 182 | ca, ok := packet.Parse(t, rConn).(*pkg.ConnAck) 183 | utils.CheckEqual(pkg.RtAccepted, ca.ReturnCode(), t) 184 | utils.CheckTrue(ok, t) 185 | _ = conn.Close() 186 | <-done 187 | 188 | // At least one error should be logged (additional caused by forced disconnect) 189 | utils.CheckTrue(len(mt.logEntries) > 0, t) 190 | el := mt.logEntries[0] 191 | utils.CheckEqual(len(el), 3, t) 192 | utils.CheckEqual(el[0], "ERROR", t) 193 | utils.CheckTrue(cl == el[1], t) 194 | err, ok := el[2].(error) 195 | utils.CheckTrue(ok, t) 196 | utils.CheckEqual("unauthorized", err.Error(), t) 197 | } 198 | 199 | // Test_client_debugLog checks that the client performs debug logging 200 | func Test_client_debugLog(t *testing.T) { 201 | mt := &collectLogsT{} 202 | conn := mock.NewConnection() 203 | cl := NewClient(newMockServer(t), utils.NewLogger(logger.Debug, mt), conn) 204 | 205 | done := make(chan bool, 1) 206 | go func() { 207 | cl.Serve() 208 | done <- true 209 | }() 210 | 211 | rConn := conn.Remote() 212 | writePacket(t, 
pkg.NewConnect("client-id", false, 1, nil, nil), rConn) 213 | ca, ok := packet.Parse(t, rConn).(*pkg.ConnAck) 214 | utils.CheckTrue(ok, t) 215 | utils.CheckEqual(pkg.RtAccepted, ca.ReturnCode(), t) 216 | writePacket(t, pkg.PubRec(1), rConn) 217 | writePacket(t, pkg.PubRel(2), rConn) 218 | writePacket(t, pkg.PubComp(3), rConn) 219 | writePacket(t, pkg.DisconnectSingleton, rConn) 220 | <-done 221 | 222 | // check that all received packets were logged 223 | cnt := 0 224 | for _, le := range mt.logEntries { 225 | if len(le) == 4 && le[0] == "DEBUG" && le[2] == "received" { 226 | switch le[3].(type) { 227 | case *pkg.Connect, pkg.PubRec, pkg.PubRel, pkg.PubComp: 228 | cnt++ 229 | } 230 | } 231 | } 232 | utils.CheckEqual(4, cnt, t) 233 | } 234 | // writeFailure wraps a mock connection so that Write fails once succeed successful writes have been made; each successful write sends on tick. 235 | type writeFailure struct { 236 | *mock.Connection 237 | succeed uint 238 | tick chan bool 239 | } 240 | 241 | func (c *writeFailure) Write(bs []byte) (int, error) { 242 | if c.succeed == 0 { 243 | return 0, errors.New("write failed") 244 | } 245 | i, err := c.Connection.Write(bs) 246 | c.succeed-- 247 | c.tick <- true 248 | return i, err 249 | } 250 | 251 | // Test_write_failure_when_connected checks that the client propagates write error 252 | func Test_write_failure_when_connected(t *testing.T) { 253 | conn := &writeFailure{Connection: mock.NewConnection()} 254 | mt := &collectLogsT{} 255 | cl := NewClient(newMockServer(t), utils.NewLogger(logger.Error, mt), conn) 256 | 257 | done := make(chan bool, 1) 258 | go func() { 259 | cl.Serve() 260 | done <- true 261 | }() 262 | 263 | writePacket(t, pkg.NewConnect("client-id", false, 1, nil, nil), conn.Remote()) 264 | // Should fail with forced disconnect 265 | select { 266 | case <-done: 267 | case <-time.After(time.Second): 268 | t.Fatal("expected forced disconnect did not occur") 269 | } 270 | 271 | // At least one error should be logged (additional caused by forced disconnect) 272 | utils.CheckTrue(len(mt.logEntries) > 0, t) 273 | el := mt.logEntries[0] 274 |
utils.CheckEqual(len(el), 3, t) 275 | utils.CheckEqual(el[0], "ERROR", t) 276 | utils.CheckTrue(cl == el[1], t) 277 | err, ok := el[2].(error) 278 | utils.CheckTrue(ok, t) 279 | utils.CheckEqual("write failed", err.Error(), t) 280 | } 281 | 282 | // Test_write_failure_during_drain checks that the client logs error that occurs during writeLoop drain 283 | func Test_write_failure_during_drain(t *testing.T) { 284 | conn := &writeFailure{Connection: mock.NewConnection(), succeed: 1, tick: make(chan bool, 1)} 285 | mt := &collectLogsT{} 286 | cl := NewClient(newMockServer(t), utils.NewLogger(logger.Error, mt), conn) 287 | 288 | done := make(chan bool, 1) 289 | go func() { 290 | cl.Serve() 291 | done <- true 292 | }() 293 | 294 | rConn := conn.Remote() 295 | writePacket(t, pkg.NewConnect("client-id", false, 1, nil, nil), rConn) 296 | <-conn.tick 297 | cl.(*client).queueForWrite(pkg.PingResponseSingleton) 298 | cl.(*client).queueForWrite(pkg.DisconnectSingleton) 299 | cl.SetDisconnected(nil) 300 | 301 | // Should fail with forced disconnect 302 | select { 303 | case <-done: 304 | case <-time.After(100 * time.Millisecond): 305 | t.Fatal("expected forced disconnect did not occur") 306 | } 307 | 308 | // At least one error should be logged (additional caused by forced disconnect) 309 | utils.CheckTrue(len(mt.logEntries) > 0, t) 310 | el := mt.logEntries[0] 311 | utils.CheckEqual(len(el), 3, t) 312 | utils.CheckEqual(el[0], "ERROR", t) 313 | utils.CheckTrue(cl == el[1], t) 314 | err, ok := el[2].(error) 315 | utils.CheckTrue(ok, t) 316 | utils.CheckEqual("write failed", err.Error(), t) 317 | } 318 | -------------------------------------------------------------------------------- /bridge/nats.go: -------------------------------------------------------------------------------- 1 | package bridge 2 | 3 | import ( 4 | "errors" 5 | 6 | "github.com/nats-io/nats.go" 7 | "github.com/tada/mqtt-nats/mqtt" 8 | "github.com/tada/mqtt-nats/mqtt/pkg" 9 | ) 10 | // natsPublish forwards the MQTT PUBLISH packet pp to NATS. A DUP packet whose ack is already awaited is ignored. A QoS 1 packet is published with a reply subject derived from the session and the packet, on which an ack subscription is registered first. QoS 2 and invalid levels are rejected with an error. 11 | func (c *client)
natsPublish(pp *pkg.Publish) error { 12 | var err error 13 | if pp.IsDup() { 14 | if c.session.AwaitsAck(pp.ID()) { 15 | // Already waiting for this one 16 | return nil 17 | } 18 | } 19 | 20 | natsSubject := mqtt.ToNATS(pp.TopicName()) 21 | switch pp.QoSLevel() { 22 | case 0: 23 | // Fire and forget 24 | err = c.natsConn.Publish(natsSubject, pp.Payload()) 25 | case 1: 26 | replyTo := NewReplyTopic(c.session, pp).String() // client id, session id, packet id and flags form the reply subject 27 | var sub *nats.Subscription 28 | if sub, err = c.natsSubscribeAck(replyTo); err == nil { 29 | c.session.AckRequested(pp.ID(), sub) 30 | if err = c.natsConn.PublishRequest(natsSubject, replyTo, pp.Payload()); err != nil { 31 | c.cancelNatsSubscriptions(c.session.AckReceived(pp.ID())) // a stale pending ack would make AwaitsAck drop the client's DUP retry 32 | } 33 | } 34 | case 2: 35 | err = errors.New("QoS level 2 is not supported") 36 | default: 37 | err = errors.New("invalid QoS level") 38 | } 39 | return err 40 | } 41 | // natsSubscribeAck subscribes to the reply subject topic. A message received on it acknowledges a QoS 1 publication: the pending ack subscription is removed from the session and cancelled, and a PubAck is queued on c. NOTE(review): c is the client that made the subscription, which may no longer be the session's current client after a reconnect; confirm. 42 | func (c *client) natsSubscribeAck(topic string) (*nats.Subscription, error) { 43 | return c.natsConn.Subscribe(topic, func(m *nats.Msg) { 44 | // Client may have disconnected at this point which is why it is essential to ask 45 | // the session manager for the session based on the replyTo subject 46 | mt := ParseReplyTopic(m.Subject) 47 | if mt != nil { 48 | if s := c.server.SessionManager().Get(mt.ClientID()); s != nil && s.ID() == mt.SessionID() { 49 | c.cancelNatsSubscriptions(s.AckReceived(mt.PacketID())) 50 | c.queueForWrite(pkg.PubAck(mt.PacketID())) 51 | } 52 | } 53 | }) 54 | } 55 | // cancelNatsSubscriptions unsubscribes each of nss. Errors are logged, not returned. 56 | func (c *client) cancelNatsSubscriptions(nss []*nats.Subscription) { 57 | for i := range nss { 58 | ns := nss[i] 59 | if err := ns.Unsubscribe(); err != nil { 60 | c.Error("NATS unsubscribe", ns.Subject, err) 61 | } 62 | } 63 | } 64 | // natsSubscribe replaces any existing NATS subscriptions for the topics of sp with new ones (QoS capped at 1) and queues a SubAck. A topic that cannot be subscribed is reported with return code 0x80. 65 | func (c *client) natsSubscribe(sp *pkg.Subscribe) { 66 | tps := sp.Topics() 67 | nms := make([]string, len(tps)) 68 | qss := make([]byte, len(tps)) 69 | var nss []*nats.Subscription 70 | c.subLock.Lock() 71 | for i := range tps { 72 | tp := tps[i] 73 | nm :=
mqtt.ToNATSSubscription(tp.Name) 74 | nms[i] = nm 75 | qss[i] = tp.QoS 76 | if os := c.natsSubs[nm]; os != nil { 77 | delete(c.natsSubs, nm) 78 | nss = append(nss, os) 79 | } 80 | } 81 | c.subLock.Unlock() 82 | c.cancelNatsSubscriptions(nss) 83 | 84 | nss = make([]*nats.Subscription, 0, len(nms)) 85 | for i := range nms { 86 | nm := nms[i] 87 | qs := qss[i] 88 | if qs > 1 { 89 | qs = 1 90 | qss[i] = 1 91 | } 92 | ns, err := c.natsConn.Subscribe(nm, func(m *nats.Msg) { 93 | c.natsResponse(qs, m) 94 | }) 95 | if err == nil { 96 | nss = append(nss, ns) 97 | } else { 98 | c.Error("NATS subscribe", nm, err) 99 | qss[i] = 0x80 100 | } 101 | } 102 | c.subLock.Lock() 103 | for i := range nss { 104 | ns := nss[i] 105 | c.natsSubs[ns.Subject] = ns 106 | } 107 | c.subLock.Unlock() 108 | c.queueForWrite(pkg.NewSubAck(sp.ID(), qss...)) 109 | } 110 | // natsUnsubscribe cancels the NATS subscriptions for the topics of up and queues an UnsubAck. 111 | func (c *client) natsUnsubscribe(up *pkg.Unsubscribe) { 112 | tps := up.Topics() 113 | nss := make([]*nats.Subscription, 0, len(tps)) 114 | c.subLock.Lock() 115 | for i := range tps { 116 | subj := mqtt.ToNATSSubscription(tps[i]) 117 | if ns := c.natsSubs[subj]; ns != nil { 118 | nss = append(nss, ns) 119 | delete(c.natsSubs, subj) 120 | } 121 | } 122 | c.subLock.Unlock() 123 | c.cancelNatsSubscriptions(nss) 124 | c.queueForWrite(pkg.UnsubAck(up.ID())) 125 | } 126 | // natsResponse forwards the NATS message m to the MQTT client as a PUBLISH. When desiredQoS > 0 and m has a reply subject, the packet id and flags come from a parsable ReplyTopic, or else a fresh packet id with QoS 1 is used. The delivered QoS is the lower of desiredQoS and the packet's QoS. 127 | func (c *client) natsResponse(desiredQoS byte, m *nats.Msg) { 128 | id := uint16(0) 129 | flags := byte(0) 130 | if desiredQoS > 0 && m.Reply != `` { 131 | if mt := ParseReplyTopic(m.Reply); mt != nil { 132 | id = mt.PacketID() 133 | flags = mt.Flags() 134 | } else { 135 | id = c.server.NextFreePacketID() 136 | flags = 2 // QoS level 1 137 | } 138 | } 139 | pp := pkg.NewPublish(id, mqtt.FromNATS(m.Subject), flags, m.Data, false, m.Reply) 140 | qos := desiredQoS 141 | if pp.QoSLevel() < qos { 142 | qos = pp.QoSLevel() 143 | } 144 | c.PublishResponse(qos, pp) 145 | } 146 | --------------------------------------------------------------------------------
/bridge/natspub.go: -------------------------------------------------------------------------------- 1 | package bridge 2 | 3 | import ( 4 | "encoding/json" 5 | "io" 6 | 7 | "github.com/tada/catch/pio" 8 | "github.com/tada/jsonstream" 9 | "github.com/tada/mqtt-nats/mqtt/pkg" 10 | ) 11 | 12 | // natsPub represents a message which originated from this server (such as a client will) that has been 13 | // published to NATS using some given credentials and now awaits a reply. 14 | type natsPub struct { 15 | // pp is the packet that was published 16 | pp *pkg.Publish 17 | 18 | // creds are the client credentials for the publication 19 | creds *pkg.Credentials 20 | } 21 | // MarshalToJSON writes n as a JSON object holding the packet under "m" and, when creds is set, the credentials under "c". 22 | func (n *natsPub) MarshalToJSON(w io.Writer) { 23 | pio.WriteString(w, `{"m":`) 24 | n.pp.MarshalToJSON(w) 25 | if n.creds != nil { 26 | pio.WriteString(w, `,"c":`) 27 | n.creds.MarshalToJSON(w) 28 | } 29 | pio.WriteByte(w, '}') 30 | } 31 | // UnmarshalFromJSON reads the form written by MarshalToJSON. When ReadConsumer returns false for a value, the corresponding field is left nil. 32 | func (n *natsPub) UnmarshalFromJSON(js jsonstream.Decoder, t json.Token) { 33 | jsonstream.AssertDelim(t, '{') 34 | for { 35 | k, ok := js.ReadStringOrEnd('}') 36 | if !ok { 37 | break 38 | } 39 | switch k { 40 | case "m": 41 | n.pp = &pkg.Publish{} 42 | if !js.ReadConsumer(n.pp) { 43 | n.pp = nil 44 | } 45 | case "c": 46 | n.creds = &pkg.Credentials{} 47 | if !js.ReadConsumer(n.creds) { 48 | n.creds = nil 49 | } 50 | } 51 | } 52 | } 53 | -------------------------------------------------------------------------------- /bridge/options.go: -------------------------------------------------------------------------------- 1 | package bridge 2 | 3 | import ( 4 | "crypto/tls" 5 | 6 | "github.com/nats-io/nats.go" 7 | ) 8 | 9 | // Options contains all configuration options for the mqtt-nats bridge. 10 | type Options struct { 11 | // StoragePath is the path to the file where the bridge state is persisted.
Can be empty if no persistence is desired 12 | StoragePath string 13 | 14 | // NATSUrls is a comma separated list of URLs used when connecting to NATS 15 | NATSUrls string 16 | 17 | // RetainedRequestTopic is a NATS topic that a NATS client can publish to after doing a subscribe 18 | // in order to retrieve any messages that are retained for that subscription. The payload must be 19 | // the verbatim NATS subscription, or a comma separated list of subscriptions. Retained messages that match will be published 20 | // to the reply-to topic in the form of a JSON list of objects with a "subject" string and a "payload" 21 | // base64 encoded string 22 | RetainedRequestTopic string 23 | 24 | // Port is the MQTT port 25 | Port int 26 | 27 | // RepeatRate is the delay in milliseconds between publishing packets that originated in this server 28 | // that have QoS > 0 but haven't been acknowledged. 29 | RepeatRate int 30 | 31 | // NATSOpts are options specific to the NATS connection 32 | NATSOpts []nats.Option 33 | 34 | TLSTimeout float64 35 | TLSCert string // server certificate file 36 | TLSKey string // private key for the server certificate 37 | TLSCaCert string // root certificate used to verify client TLS certificates 38 | TLSConfig *tls.Config 39 | TLS bool // enables TLS 40 | TLSVerify bool // enables verification of client TLS certificates 41 | TLSMap bool 42 | 43 | // Debug enables debug level log output 44 | Debug bool 45 | } 46 | -------------------------------------------------------------------------------- /bridge/package.go: -------------------------------------------------------------------------------- 1 | // Package bridge contains the MQTT-NATS bridge server implementation 2 | package bridge 3 | -------------------------------------------------------------------------------- /bridge/replytopic.go: -------------------------------------------------------------------------------- 1 | package bridge 2 | 3 | import ( 4 | "fmt" 5 | "strconv" 6 | "strings" 7 | 8 | "github.com/tada/mqtt-nats/mqtt/pkg" 9 | ) 10 | 11 | // ReplyTopic represents the decoded form of the NATS reply-topic that the bridge uses to track 12 | // messages that are in need of an ACK.
13 | type ReplyTopic struct { 14 | s string // session ID 15 | c string // client ID 16 | p uint16 // packet ID 17 | f byte // packet flags 18 | } 19 | 20 | // NewReplyTopic creates a new ReplyTopic based on a pkg.Publish packet. 21 | func NewReplyTopic(s Session, pp *pkg.Publish) *ReplyTopic { 22 | return &ReplyTopic{c: s.ClientID(), s: s.ID(), p: pp.ID(), f: pp.Flags()} 23 | } 24 | 25 | // ParseReplyTopic creates a new ReplyTopic by parsing a NATS reply-to string of the form produced by String. It returns nil when s does not have that form (a client ID containing '.' yields more than five parts and is therefore not parsed) or when the packet ID or flags do not fit in uint16 and byte. 26 | func ParseReplyTopic(s string) *ReplyTopic { 27 | ps := strings.Split(s, ".") 28 | if len(ps) == 5 && ps[0] == "_INBOX" { 29 | p, err := strconv.ParseUint(ps[3], 10, 16) // range-checked so a bogus ID cannot wrap around in uint16 30 | if err == nil { 31 | var f uint64 32 | f, err = strconv.ParseUint(ps[4], 10, 8) 33 | if err == nil { 34 | return &ReplyTopic{c: ps[1], s: ps[2], p: uint16(p), f: byte(f)} 35 | } 36 | } 37 | } 38 | return nil 39 | } 40 | 41 | // ClientID returns the ID of the client where the message originated 42 | func (r *ReplyTopic) ClientID() string { 43 | return r.c 44 | } 45 | 46 | // SessionID returns the ID of the client session within the mqtt-nats bridge 47 | func (r *ReplyTopic) SessionID() string { 48 | return r.s 49 | } 50 | 51 | // PacketID returns the packet identifier of the original packet 52 | func (r *ReplyTopic) PacketID() uint16 { 53 | return r.p 54 | } 55 | 56 | // Flags returns the packet flags of the original packet 57 | func (r *ReplyTopic) Flags() byte { 58 | return r.f 59 | } 60 | 61 | // String returns the NATS string form of the reply-topic 62 | func (r *ReplyTopic) String() string { 63 | return fmt.Sprintf("_INBOX.%s.%s.%d.%d", r.c, r.s, r.p, r.f) 64 | } 65 | -------------------------------------------------------------------------------- /bridge/retained.go: -------------------------------------------------------------------------------- 1 | package bridge 2 | 3 | import ( 4 | "encoding/json" 5 | "io" 6 | "regexp" 7 | "strings" 8 | "sync" 9 | 10 | "github.com/nats-io/nats.go" 11 | "github.com/tada/catch/pio" 12 | "github.com/tada/jsonstream" 13 | "github.com/tada/mqtt-nats/mqtt" 14 |
"github.com/tada/mqtt-nats/mqtt/pkg" 15 | ) 16 | 17 | type retained struct { 18 | lock sync.RWMutex 19 | msgs map[string]*pkg.Publish 20 | order []string 21 | } 22 | 23 | func (r *retained) Empty() bool { 24 | r.lock.RLock() 25 | empty := len(r.order) == 0 26 | r.lock.RUnlock() 27 | return empty 28 | } 29 | 30 | func (r *retained) MarshalJSON() ([]byte, error) { 31 | return jsonstream.Marshal(r) 32 | } 33 | 34 | func (r *retained) UnmarshalJSON(bs []byte) error { 35 | return jsonstream.Unmarshal(r, bs) 36 | } 37 | 38 | func (r *retained) MarshalToJSON(w io.Writer) { 39 | r.lock.RLock() 40 | defer r.lock.RUnlock() 41 | sep := byte('{') 42 | for _, t := range r.order { 43 | pio.WriteByte(w, sep) 44 | sep = byte(',') 45 | jsonstream.WriteString(w, t) 46 | pio.WriteByte(w, ':') 47 | r.msgs[t].MarshalToJSON(w) 48 | } 49 | if sep == '{' { 50 | pio.WriteByte(w, sep) 51 | } 52 | pio.WriteByte(w, '}') 53 | } 54 | 55 | func (r *retained) UnmarshalFromJSON(js jsonstream.Decoder, t json.Token) { 56 | jsonstream.AssertDelim(t, '{') 57 | r.msgs = make(map[string]*pkg.Publish) 58 | r.order = nil 59 | for { 60 | s, ok := js.ReadStringOrEnd('}') 61 | if !ok { 62 | break 63 | } 64 | p := &pkg.Publish{} 65 | js.ReadConsumer(p) 66 | r.msgs[s] = p 67 | r.order = append(r.order, s) 68 | } 69 | } 70 | 71 | func (r *retained) add(m *pkg.Publish) bool { 72 | r.lock.Lock() 73 | t := m.TopicName() 74 | _, present := r.msgs[t] 75 | r.msgs[t] = m 76 | if !present { 77 | r.order = append(r.order, t) 78 | } 79 | r.lock.Unlock() 80 | return !present 81 | } 82 | 83 | func (r *retained) drop(t string) bool { 84 | dropped := false 85 | r.lock.Lock() 86 | if _, present := r.msgs[t]; present { 87 | delete(r.msgs, t) 88 | o := r.order 89 | last := len(o) - 1 90 | for i := 0; i <= last; i++ { 91 | if o[i] == t { 92 | copy(o[i:], o[i+1:]) 93 | o[last] = `` // allow GC of last 94 | r.order = o[:last] 95 | dropped = true 96 | break 97 | } 98 | } 99 | } 100 | r.lock.Unlock() 101 | return dropped 102 | } 103 
| 104 | func (r *retained) messagesMatchingRetainRequest(m *nats.Msg) ([]*pkg.Publish, []byte) { 105 | natsTopics := strings.Split(string(m.Data), ",") // the payload is a comma separated list of NATS subscriptions 106 | topics := make([]pkg.Topic, len(natsTopics)) 107 | for i := range natsTopics { 108 | topics[i] = pkg.Topic{Name: mqtt.FromNATSSubscription(strings.TrimSpace(natsTopics[i]))} // NATS subjects cannot contain whitespace, so "a.b, c.d" is tolerated 109 | } 110 | return r.matchingMessages(topics) 111 | } 112 | // publishMatching sends each retained message that matches a topic of s to c, using the QoS requested for that topic. 113 | func (r *retained) publishMatching(s *pkg.Subscribe, c Client) { 114 | pps, qs := r.matchingMessages(s.Topics()) 115 | for i := range pps { 116 | pp := pps[i] 117 | c.PublishResponse(qs[i], pp) 118 | } 119 | } 120 | // matchingMessages returns, for each topic of tps in turn, the retained messages whose topic name matches it, paired with that topic's requested QoS. A message that matches several topics is returned once per matching topic. 121 | func (r *retained) matchingMessages(tps []pkg.Topic) ([]*pkg.Publish, []byte) { 122 | tpl := len(tps) 123 | 124 | // Create slices of subscription regexps and desired QoS. One entry for each subscription topic 125 | txs := make([]*regexp.Regexp, tpl) 126 | dqs := make([]byte, tpl) 127 | for i := 0; i < tpl; i++ { 128 | tp := tps[i] 129 | txs[i] = mqtt.SubscriptionToRegexp(tp.Name) 130 | dqs[i] = tp.QoS 131 | } 132 | 133 | // For each subscription topic, extract matching packets and desired QoS 134 | pps := make([]*pkg.Publish, 0) 135 | qs := make([]byte, 0) 136 | 137 | r.lock.RLock() 138 | for i := 0; i < tpl; i++ { 139 | tx := txs[i] 140 | dq := dqs[i] 141 | for _, t := range r.order { 142 | if tx.MatchString(t) { 143 | pps = append(pps, r.msgs[t]) 144 | qs = append(qs, dq) 145 | } 146 | } 147 | } 148 | r.lock.RUnlock() 149 | return pps, qs 150 | } 151 | -------------------------------------------------------------------------------- /bridge/session.go: -------------------------------------------------------------------------------- 1 | package bridge 2 | 3 | import ( 4 | "encoding/json" 5 | "io" 6 | "strconv" 7 | "sync" 8 | 9 | "github.com/tada/catch" 10 | 11 | "github.com/nats-io/nats.go" 12 | "github.com/tada/catch/pio" 13 | "github.com/tada/jsonstream" 14 | "github.com/tada/mqtt-nats/mqtt/pkg" 15 | ) 16 | 17 | // A Session contains data associated with a client ID.
The session might survive client 18 | // connections. 19 | type Session interface { 20 | jsonstream.Consumer 21 | jsonstream.Streamer 22 | 23 | // ID returns an identifier that is unique for this session 24 | ID() string 25 | 26 | // ClientID returns the id of the client that this session belongs to 27 | ClientID() string 28 | 29 | // Destroy the session, unsubscribing any pending ack subscriptions 30 | Destroy() 31 | 32 | // AckRequested remembers the given subscription which represents an awaited ACK 33 | // for the given packetID 34 | AckRequested(uint16, *nats.Subscription) 35 | 36 | // AwaitsAck returns true if a subscription associated with the given packet identifier 37 | // is currently waiting for an Ack. 38 | AwaitsAck(uint16) bool 39 | 40 | // AckReceived deletes the pending ack subscription for the given packet identifier (if any) from the session and returns it. It is up 41 | // to the caller to cancel the returned subscriptions. 42 | AckReceived(uint16) []*nats.Subscription 43 | 44 | // ClientAckRequested remembers a packet, keyed by its id, which has been sent to the client. The packet stems from a NATS 45 | // subscription with QoS level > 0 and the client is now expected to send a PubAck back, which is then 46 | // propagated to the reply-to address. 47 | ClientAckRequested(*pkg.Publish) 48 | 49 | // ClientAckReceived removes the pending client ack for the given packet id and forwards the ACK to the NATS 50 | // reply-to subject of the acknowledged packet. It returns whether or not such an ack was pending 51 | ClientAckReceived(uint16, *nats.Conn) bool 52 | 53 | // ResendClientUnack resends all messages that the client hasn't acknowledged 54 | ResendClientUnack(c *client) 55 | 56 | // RestoreAckSubscriptions is called when a client restores an old session. The method restores subscriptions that 57 | // were persisted and then loaded again.
58 | RestoreAckSubscriptions(c *client) 59 | } 60 | // session is the default Session implementation. prelAwaitsAck holds the reply-to subjects of awaited acks loaded by UnmarshalFromJSON until RestoreAckSubscriptions turns them into subscriptions in awaitsAck. 61 | type session struct { 62 | id string 63 | clientID string 64 | prelAwaitsAck map[uint16]string 65 | awaitsAck map[uint16]*nats.Subscription // awaits ack on reply-to to be propagated to client 66 | awaitsClientAck map[uint16]*pkg.Publish // awaits ack from client to be propagated to nats 67 | awaitsAckLock sync.RWMutex 68 | } 69 | 70 | func (s *session) MarshalJSON() ([]byte, error) { 71 | return jsonstream.Marshal(s) 72 | } 73 | // MarshalToJSON writes the session, including awaited acks that were loaded but not yet restored (see pendingAckSubjects). 74 | func (s *session) MarshalToJSON(w io.Writer) { 75 | pio.WriteString(w, `{"id":`) 76 | jsonstream.WriteString(w, s.id) 77 | pio.WriteString(w, `,"cid":`) 78 | jsonstream.WriteString(w, s.clientID) 79 | s.awaitsAckLock.RLock() 80 | if aw := s.pendingAckSubjects(); len(aw) > 0 { 81 | pio.WriteString(w, `,"awAck":`) 82 | sep := byte('{') 83 | for k, v := range aw { 84 | pio.WriteByte(w, sep) 85 | sep = byte(',') 86 | pio.WriteByte(w, '"') 87 | pio.WriteInt(w, int64(k)) 88 | pio.WriteString(w, `":`) 89 | jsonstream.WriteString(w, v) 90 | } 91 | pio.WriteByte(w, '}') 92 | } 93 | if len(s.awaitsClientAck) > 0 { 94 | pio.WriteString(w, `,"awClientAck":`) 95 | sep := byte('{') 96 | for k, v := range s.awaitsClientAck { 97 | pio.WriteByte(w, sep) 98 | sep = byte(',') 99 | pio.WriteByte(w, '"') 100 | pio.WriteInt(w, int64(k)) 101 | pio.WriteString(w, `":`) 102 | v.MarshalToJSON(w) 103 | } 104 | pio.WriteByte(w, '}') 105 | } 106 | pio.WriteByte(w, '}') 107 | s.awaitsAckLock.RUnlock() 108 | } 109 | 110 | func (s *session) UnmarshalFromJSON(js jsonstream.Decoder, t json.Token) { 111 | jsonstream.AssertDelim(t, '{') 112 | for { 113 | k, ok := js.ReadStringOrEnd('}') 114 | if !ok { 115 | break 116 | } 117 | switch k { 118 | case "id": 119 | s.id = js.ReadString() 120 | case "cid": 121 | s.clientID = js.ReadString() 122 | case "awAck": 123 | js.ReadDelim('{') 124 | for { 125 | k, ok = js.ReadStringOrEnd('}') 126 | if !ok { 127 | break 128 | } 129 | if s.prelAwaitsAck == nil { 130 | s.prelAwaitsAck =
make(map[uint16]string) 131 | } 132 | i, err := strconv.ParseUint(k, 10, 16) 133 | if err != nil { 134 | panic(catch.Error(err)) 135 | } 136 | s.prelAwaitsAck[uint16(i)] = js.ReadString() 137 | } 138 | case "awClientAck": 139 | js.ReadDelim('{') 140 | for { 141 | k, ok = js.ReadStringOrEnd('}') 142 | if !ok { 143 | break 144 | } 145 | if s.awaitsClientAck == nil { 146 | s.awaitsClientAck = make(map[uint16]*pkg.Publish) 147 | } 148 | i, err := strconv.ParseUint(k, 10, 16) 149 | if err != nil { 150 | panic(catch.Error(err)) 151 | } 152 | pp := &pkg.Publish{} 153 | if js.ReadConsumer(pp) { 154 | s.awaitsClientAck[uint16(i)] = pp 155 | } 156 | } 157 | } 158 | } 159 | } 160 | 161 | func (s *session) RestoreAckSubscriptions(c *client) { 162 | s.awaitsAckLock.Lock() 163 | prel := s.prelAwaitsAck 164 | s.prelAwaitsAck = nil 165 | s.awaitsAckLock.Unlock() 166 | for k, v := range prel { 167 | if sb, err := c.natsSubscribeAck(v); err != nil { 168 | c.Error(err) 169 | } else { 170 | s.AckRequested(k, sb) 171 | } 172 | } 173 | } 174 | 175 | func (s *session) AckReceived(packetID uint16) []*nats.Subscription { 176 | var nss []*nats.Subscription 177 | s.awaitsAckLock.Lock() 178 | if s.awaitsAck != nil { 179 | if sb, awaits := s.awaitsAck[packetID]; awaits { 180 | nss = append(nss, sb) 181 | delete(s.awaitsAck, packetID) 182 | } 183 | } 184 | s.awaitsAckLock.Unlock() 185 | return nss 186 | } 187 | 188 | func (s *session) AckRequested(packetID uint16, sb *nats.Subscription) { 189 | s.awaitsAckLock.Lock() 190 | if s.awaitsAck == nil { 191 | s.awaitsAck = make(map[uint16]*nats.Subscription) 192 | } 193 | s.awaitsAck[packetID] = sb 194 | s.awaitsAckLock.Unlock() 195 | } 196 | 197 | func (s *session) AwaitsAck(packetID uint16) bool { 198 | awaits := false 199 | s.awaitsAckLock.RLock() 200 | if s.awaitsAck != nil { 201 | _, awaits = s.awaitsAck[packetID] 202 | } 203 | s.awaitsAckLock.RUnlock() 204 | return awaits 205 | } 206 | 207 | func (s *session) ClientAckReceived(packetID uint16, c *nats.Conn) bool { 208 | var pp *pkg.Publish 209 |
s.awaitsAckLock.Lock() 210 | if s.awaitsClientAck != nil { 211 | var found bool 212 | if pp, found = s.awaitsClientAck[packetID]; found { 213 | delete(s.awaitsClientAck, packetID) 214 | } 215 | } 216 | s.awaitsAckLock.Unlock() 217 | if pp != nil { 218 | _ = c.Publish(pp.NatsReplyTo(), []byte{0}) 219 | return true 220 | } 221 | return false 222 | } 223 | 224 | func (s *session) ClientAckRequested(pp *pkg.Publish) { 225 | s.awaitsAckLock.Lock() 226 | if s.awaitsClientAck == nil { 227 | s.awaitsClientAck = make(map[uint16]*pkg.Publish) 228 | } 229 | s.awaitsClientAck[pp.ID()] = pp 230 | s.awaitsAckLock.Unlock() 231 | } 232 | 233 | func (s *session) ResendClientUnack(c *client) { 234 | s.awaitsAckLock.RLock() 235 | as := make([]*pkg.Publish, 0, len(s.awaitsClientAck)) 236 | for _, a := range s.awaitsClientAck { 237 | as = append(as, a) 238 | } 239 | s.awaitsAckLock.RUnlock() 240 | for i := range as { 241 | a := as[i] 242 | c.PublishResponse(a.QoSLevel(), a) 243 | } 244 | } 245 | 246 | func (s *session) ID() string { 247 | return s.id 248 | } 249 | 250 | func (s *session) ClientID() string { 251 | return s.clientID 252 | } 253 | 254 | func (s *session) Destroy() { 255 | // Unsubscribe all pending subscriptions 256 | s.awaitsAckLock.Lock() 257 | if s.awaitsAck != nil { 258 | for _, sb := range s.awaitsAck { 259 | _ = sb.Unsubscribe() 260 | } 261 | s.awaitsAck = nil 262 | } 263 | s.awaitsAckLock.Unlock() 264 | } 265 | 266 | // A SessionManager manages sessions. 267 | type SessionManager interface { 268 | // Create creates a new session for the given clientID.
Any previous session registered for 269 | // the given id is discarded 270 | Create(clientID string) Session 271 | 272 | // Get returns an existing session for the given clientID or nil if no such session exists 273 | Get(clientID string) Session 274 | 275 | // Remove removes any session for the given clientID 276 | Remove(clientID string) 277 | } 278 | // sm is the default SessionManager. seed is incremented to generate a unique ID for each created session. 279 | type sm struct { 280 | lock sync.RWMutex 281 | seed uint32 282 | m map[string]Session 283 | } 284 | 285 | func (m *sm) Get(clientID string) Session { 286 | var s Session 287 | m.lock.RLock() 288 | s = m.m[clientID] 289 | m.lock.RUnlock() 290 | return s 291 | } 292 | 293 | func (m *sm) Create(clientID string) Session { 294 | m.lock.Lock() 295 | m.seed++ 296 | s := &session{id: `s` + strconv.Itoa(int(m.seed)), clientID: clientID} 297 | m.m[clientID] = s // NOTE(review): unlike Remove, a replaced session is not Destroy()ed here; confirm callers Remove it first, otherwise its ack subscriptions stay active 298 | m.lock.Unlock() 299 | return s 300 | } 301 | 302 | func (m *sm) MarshalToJSON(w io.Writer) { 303 | m.lock.RLock() 304 | defer m.lock.RUnlock() 305 | 306 | pio.WriteString(w, `{"seed":`) 307 | pio.WriteInt(w, int64(m.seed)) 308 | if len(m.m) > 0 { 309 | pio.WriteString(w, `,"sessions":`) 310 | sep := byte('{') 311 | for k, v := range m.m { 312 | pio.WriteByte(w, sep) 313 | sep = ',' 314 | jsonstream.WriteString(w, k) 315 | pio.WriteByte(w, ':') 316 | v.MarshalToJSON(w) 317 | } 318 | pio.WriteByte(w, '}') 319 | } 320 | pio.WriteByte(w, '}') 321 | } 322 | 323 | func (m *sm) UnmarshalFromJSON(js jsonstream.Decoder, t json.Token) { 324 | jsonstream.AssertDelim(t, '{') 325 | for { 326 | k, ok := js.ReadStringOrEnd('}') 327 | if !ok { 328 | break 329 | } 330 | switch k { 331 | case "sessions": 332 | js.ReadDelim('{') 333 | for { 334 | k, ok = js.ReadStringOrEnd('}') 335 | if !ok { 336 | break 337 | } 338 | s := &session{} 339 | js.ReadConsumer(s) 340 | m.m[k] = s 341 | } 342 | case "seed": 343 | m.seed = uint32(js.ReadInt()) 344 | } 345 | } 346 | } 347 | 348 | func (m *sm) Remove(clientID string) { 349 | var s Session 350 | m.lock.Lock() 351 | s = m.m[clientID]
352 | delete(m.m, clientID) 353 | m.lock.Unlock() 354 | if s != nil { 355 | s.Destroy() 356 | } 357 | } 358 | 359 | // pendingAckSubjects returns the reply-to subjects of all acks that the session awaits, keyed by packet id. Besides 360 | // the live subscriptions it includes entries loaded by UnmarshalFromJSON that RestoreAckSubscriptions has not yet 361 | // resubscribed, so that persisting the session again does not lose them. The caller must hold awaitsAckLock (a read lock suffices). 362 | func (s *session) pendingAckSubjects() map[uint16]string { 363 | subjects := make(map[uint16]string, len(s.prelAwaitsAck)+len(s.awaitsAck)) 364 | for k, v := range s.prelAwaitsAck { 365 | subjects[k] = v 366 | } 367 | for k, sb := range s.awaitsAck { 368 | subjects[k] = sb.Subject 369 | } 370 | return subjects 371 | } 372 | -------------------------------------------------------------------------------- /cli/mqtt-nats.go: -------------------------------------------------------------------------------- 1 | // +build !citest 2 | 3 | package cli 4 | 5 | import ( 6 | "flag" 7 | "io" 8 | 9 | "github.com/nats-io/nats.go" 10 | "github.com/tada/mqtt-nats/bridge" 11 | "github.com/tada/mqtt-nats/logger" 12 | ) 13 | 14 | // Bridge parses the command line arguments of args into a bridge.Options instance and then starts the 15 | // bridge with those options. 16 | func Bridge(args []string, stdout, stderr io.Writer) int { 17 | fs := flag.NewFlagSet(args[0], flag.ExitOnError) 18 | fs.SetOutput(stderr) 19 | 20 | var ( 21 | printHelp bool 22 | natsClientCert string 23 | natsClientKey string 24 | natsRootCAs string 25 | natsCredsFile string 26 | ) 27 | opts := &bridge.Options{} 28 | fs.StringVar(&opts.NATSUrls, "natsurl", nats.DefaultURL, "NATS server URLs separated by comma") 29 | fs.IntVar(&opts.Port, "port", 0, "MQTT Port to listen on (defaults to 1883 or 8883 with TLS)") 30 | fs.BoolVar(&printHelp, "h", false, "") 31 | fs.BoolVar(&printHelp, "help", false, "Print this help") 32 | fs.IntVar(&opts.RepeatRate, "repeatrate", 5000, "time in milliseconds between each publish of unacknowledged messages") 33 | // persistence 34 | fs.StringVar(&opts.StoragePath, "storage", "mqtt-nats.json", "path to json file where server state is persisted") 35 | 36 | fs.BoolVar(&opts.Debug, "D", false, "Enable Debug logging") 37 | fs.BoolVar(&opts.Debug, "debug", false, "Enable Debug logging") 38 | 39 | // tls 40 | fs.BoolVar(&opts.TLS, "tls", false, "Enable TLS.
If true, the -tlscert and -tlskey options are mandatory") 41 | fs.StringVar(&opts.TLSCert, "tlscert", "", "Server certificate file") 42 | fs.StringVar(&opts.TLSKey, "tlskey", "", "Private key for server certificate") 43 | 44 | // options to verify client certificate 45 | fs.BoolVar(&opts.TLSVerify, "tlsverify", false, 46 | "Enable verification of client TLS certificate. If true, the -tlscacert option is mandatory") 47 | fs.StringVar(&opts.TLSCaCert, "tlscacert", "", "Root Certificate for verification of client TLS certificate") 48 | 49 | fs.StringVar(&natsCredsFile, "nats-creds", "", "User Credentials File used when bridge connects to NATS") 50 | // tls when connecting to the NATS server 51 | fs.StringVar(&natsClientKey, "nats-key", "", "Private Key used by the bridge when connecting to NATS") 52 | fs.StringVar(&natsClientCert, "nats-cert", "", "Client Certificate used by the bridge when connecting to NATS") 53 | fs.StringVar(&natsRootCAs, "nats-cacert", "", "Root CA Certificate used by the bridge to verify the NATS server") 54 | 55 | _ = fs.Parse(args[1:]) 56 | if printHelp { 57 | fs.SetOutput(stdout) 58 | fs.PrintDefaults() 59 | return 0 60 | } 61 | 62 | if opts.TLS { 63 | if opts.TLSCert == "" || opts.TLSKey == "" { 64 | _, _ = io.WriteString(stderr, "both -tlscert and -tlskey must be given when tls is enabled\n") 65 | return 2 66 | } 67 | 68 | if opts.Port == 0 { 69 | opts.Port = 8883 70 | } 71 | } else if opts.Port == 0 { 72 | opts.Port = 1883 73 | } 74 | 75 | opts.NATSOpts = []nats.Option{nats.Name("MQTT Bridge")} 76 | if natsCredsFile != "" { 77 | opts.NATSOpts = append(opts.NATSOpts, nats.UserCredentials(natsCredsFile)) 78 | } 79 | if natsClientCert != "" || natsClientKey != "" { 80 | if natsClientCert == "" || natsClientKey == "" { 81 | _, _ = io.WriteString(stderr, "both -nats-cert and -nats-key must be given to enable client verification\n") 82 | return 2 83 | } 84 | opts.NATSOpts = append(opts.NATSOpts, nats.ClientCert(natsClientCert, natsClientKey)) 85 |
} 86 | if natsRootCAs != `` { 87 | opts.NATSOpts = append(opts.NATSOpts, nats.RootCAs(natsRootCAs)) 88 | } 89 | 90 | lg := logger.New(logger.Debug, stdout, stderr) // NOTE(review): always Debug level although -debug sets opts.Debug; confirm that bridge.New applies the level filtering 91 | s, err := bridge.New(opts, lg) 92 | if err == nil { 93 | err = s.Serve(nil) 94 | } 95 | 96 | if err != nil { 97 | lg.Error(err) 98 | return 1 99 | } 100 | return 0 101 | } 102 | -------------------------------------------------------------------------------- /cli/package.go: -------------------------------------------------------------------------------- 1 | // Package cli contains the Command Line Interface 2 | package cli 3 | -------------------------------------------------------------------------------- /examples/certs/.gitignore: -------------------------------------------------------------------------------- 1 | *.pem 2 | *.csr 3 | -------------------------------------------------------------------------------- /examples/certs/ca.json: -------------------------------------------------------------------------------- 1 | { 2 | "CN": "tada.se", 3 | "key": { 4 | "algo": "rsa", 5 | "size": 2048 6 | }, 7 | "names": [ 8 | { 9 | "C": "SE", 10 | "L": "Täby", 11 | "O": "Tada AB", 12 | "ST": "Stockholms Län" 13 | } 14 | ] 15 | } -------------------------------------------------------------------------------- /examples/certs/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "signing": { 3 | "default": { 4 | "expiry": "43800h" 5 | }, 6 | "profiles": { 7 | "server": { 8 | "expiry": "43800h", 9 | "usages": [ 10 | "signing", 11 | "digital signature", 12 | "key encipherment", 13 | "server auth" 14 | ] 15 | } 16 | } 17 | } 18 | } -------------------------------------------------------------------------------- /examples/certs/csr.json: -------------------------------------------------------------------------------- 1 | { 2 | "hosts": [ 3 | "127.0.0.1", 4 | "localhost", 5 | "tada", 6 | "home.tada.se" 7 | ], 8 | "CN": "mqtt-nats.tada.se", 9 | "key": { 10 |
"algo": "rsa", 11 | "size": 2048 12 | }, 13 | "names": [ 14 | { 15 | "C": "SE", 16 | "L": "Täby", 17 | "O": "Tada AB", 18 | "ST": "Stockholms Län" 19 | } 20 | ] 21 | } 22 | -------------------------------------------------------------------------------- /examples/certs/generate.sh: -------------------------------------------------------------------------------- 1 | # Generate root CA 2 | cfssl gencert -initca ca.json | cfssljson -bare ca 3 | 4 | # Generate server cert 5 | cfssl gencert -ca=ca.pem -ca-key=ca-key.pem -config=config.json -profile=server server.json | cfssljson -bare server 6 | 7 | # Generate client cert + key 8 | cfssl gencert -ca ca.pem -ca-key ca-key.pem csr.json | cfssljson -bare client 9 | -------------------------------------------------------------------------------- /examples/certs/server.json: -------------------------------------------------------------------------------- 1 | { 2 | "CN": "Server", 3 | "hosts": [ 4 | "127.0.0.1", 5 | "localhost", 6 | "tada", 7 | "home.tada.se" 8 | ] 9 | } 10 | -------------------------------------------------------------------------------- /examples/jmeter/MQTT Pub Sampler.jmx: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | false 7 | true 8 | false 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | continue 17 | 18 | false 19 | 10 20 | 21 | 1 22 | 1 23 | false 24 | 25 | 26 | true 27 | 28 | 29 | 30 | 127.0.0.1 31 | 1883 32 | 3.1.1 33 | 10 34 | TCP 35 | false 36 | 37 | 38 | 39 | 40 | 41 | 42 | conn_ 43 | true 44 | 300 45 | 0 46 | 0 47 | 48 | 49 | 50 | my/test/topic 51 | 1 52 | true 53 | String 54 | 1024 55 | This is my topic 56 | 57 | 58 | 59 | 60 | 61 | 62 | -------------------------------------------------------------------------------- /examples/server.conf: -------------------------------------------------------------------------------- 1 | listen: 0.0.0.0:4222 2 | tls: { 3 | ca_file: "./certs/ca.pem" 4 | cert_file: "./certs/server.pem" 5 | key_file: 
"./certs/server-key.pem" 6 | verify: true 7 | } 8 | -------------------------------------------------------------------------------- /examples/tools/nats-pub-repeat/nats-pub-repeat.go: -------------------------------------------------------------------------------- 1 | // +build !citest 2 | 3 | // Copyright 2012-2019 The NATS Authors 4 | // Licensed under the Apache License, Version 2.0 (the "License"); 5 | // you may not use this file except in compliance with the License. 6 | // You may obtain a copy of the License at 7 | // 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | // 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | // See the License for the specific language governing permissions and 14 | // limitations under the License. 15 | 16 | package main 17 | 18 | import ( 19 | "bytes" 20 | "flag" 21 | "log" 22 | "os" 23 | "strconv" 24 | "time" 25 | 26 | "github.com/nats-io/nats.go" 27 | ) 28 | 29 | // NOTE: Can test with demo servers. 30 | // nats-pub-repeat -s demo.nats.io 31 | // nats-pub-repeat -s demo.nats.io:4443 (TLS version) 32 | 33 | func showUsageAndExit(exitcode int) { 34 | flag.PrintDefaults() 35 | os.Exit(exitcode) 36 | } 37 | 38 | func main() { 39 | var urls = flag.String("s", nats.DefaultURL, "The nats server URLs (separated by comma)") 40 | var userCreds = flag.String("creds", "", "User Credentials File") 41 | var repeat = flag.Int("repeat", 1, "Repeat count") 42 | var askReply = flag.Int("askreply", 0, "seconds to wait for reply. 
0 means don't wait") 43 | var showHelp = flag.Bool("h", false, "Show help message") 44 | var rootCA = flag.String("cacert", "", "TLS root certificate") 45 | var clientKey = flag.String("key", "", "TLS key") 46 | var clientCert = flag.String("cert", "", "TLS certificate") 47 | 48 | log.SetFlags(0) 49 | flag.Parse() 50 | 51 | if *showHelp { 52 | showUsageAndExit(0) 53 | } 54 | 55 | args := flag.Args() 56 | if len(args) != 2 { 57 | showUsageAndExit(1) 58 | } 59 | 60 | // Connect Options. 61 | opts := []nats.Option{nats.Name("NATS Sample Publisher")} 62 | if *rootCA != "" { 63 | opts = append(opts, nats.RootCAs(*rootCA)) 64 | } 65 | if *clientCert != "" { 66 | opts = append(opts, nats.ClientCert(*clientCert, *clientKey)) 67 | } 68 | 69 | // Use UserCredentials 70 | if *userCreds != "" { 71 | opts = append(opts, nats.UserCredentials(*userCreds)) 72 | } 73 | 74 | // Connect to NATS 75 | nc, err := nats.Connect(*urls, opts...) 76 | if err != nil { 77 | log.Fatal(err) 78 | } 79 | defer nc.Close() 80 | subj, msg := args[0], args[1] 81 | 82 | buf := bytes.Buffer{} 83 | for i := 0; i < *repeat; i++ { 84 | buf.Reset() 85 | buf.WriteString(msg) 86 | buf.WriteByte(' ') 87 | buf.WriteString(strconv.Itoa(i)) 88 | if *askReply != 0 { 89 | // A failed request (e.g. a reply timeout) is reported but does not abort the remaining repeats 90 | if _, err := nc.Request(subj, buf.Bytes(), time.Duration(*askReply)*time.Second); err != nil { 91 | log.Printf("Request #%d on [%s] failed: %v", i, subj, err) 92 | } 93 | } else if err := nc.Publish(subj, buf.Bytes()); err != nil { 94 | log.Fatal(err) 95 | } 96 | } 97 | if err := nc.Flush(); err != nil { 98 | log.Fatal(err) 99 | } 100 | 101 | if err := nc.LastError(); err != nil { 102 | log.Fatal(err) 103 | } 104 | log.Printf("Published [%s] : '%s'\n", subj, msg) 105 | } 106 | -------------------------------------------------------------------------------- /examples/tools/nats-sub-reply/nats-sub-reply.go: -------------------------------------------------------------------------------- 1 | // +build !citest 2 | 3 | // Copyright 2012-2019 The NATS Authors 4 | // Licensed under the Apache License, Version 2.0 (the "License"); 5 | // you may not use this file except in compliance with the License.
6 | // You may obtain a copy of the License at 7 | // 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | // 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | // See the License for the specific language governing permissions and 14 | // limitations under the License. 15 | 16 | package main 17 | 18 | import ( 19 | "flag" 20 | "log" 21 | "os" 22 | "runtime" 23 | "time" 24 | 25 | "github.com/nats-io/nats.go" 26 | ) 27 | 28 | // NOTE: Can test with demo servers. 29 | // nats-sub -s demo.nats.io 30 | // nats-sub -s demo.nats.io:4443 (TLS version) 31 | 32 | func usage() { 33 | log.Printf("Usage: nats-sub-reply [-s server] [-creds file] [-t] \n") 34 | flag.PrintDefaults() 35 | } 36 | 37 | func showUsageAndExit(exitcode int) { 38 | usage() 39 | os.Exit(exitcode) 40 | } 41 | 42 | func printMsg(m *nats.Msg, i int) { 43 | log.Printf("[#%d] Received on [%s]: '%s'", i, m.Subject, string(m.Data)) 44 | } 45 | 46 | func main() { 47 | var urls = flag.String("s", nats.DefaultURL, "The nats server URLs (separated by comma)") 48 | var userCreds = flag.String("creds", "", "User Credentials File") 49 | var showTime = flag.Bool("t", false, "Display timestamps") 50 | var showHelp = flag.Bool("h", false, "Show help message") 51 | var rootCA = flag.String("cacert", "", "TLS root certificate") 52 | var clientKey = flag.String("key", "", "TLS key") 53 | var clientCert = flag.String("cert", "", "TLS certificate") 54 | 55 | log.SetFlags(0) 56 | flag.Usage = usage 57 | flag.Parse() 58 | 59 | if *showHelp { 60 | showUsageAndExit(0) 61 | } 62 | 63 | args := flag.Args() 64 | if len(args) != 1 { 65 | showUsageAndExit(1) 66 | } 67 | 68 | // Connect Options. 
69 | opts := []nats.Option{nats.Name("NATS Sample Subscriber")} 70 | opts = setupConnOptions(opts) 71 | if *rootCA != "" { 72 | opts = append(opts, nats.RootCAs(*rootCA)) 73 | } 74 | if *clientCert != "" { 75 | opts = append(opts, nats.ClientCert(*clientCert, *clientKey)) 76 | } 77 | 78 | // Use UserCredentials 79 | if *userCreds != "" { 80 | opts = append(opts, nats.UserCredentials(*userCreds)) 81 | } 82 | 83 | // Connect to NATS 84 | nc, err := nats.Connect(*urls, opts...) 85 | if err != nil { 86 | log.Fatal(err) 87 | } 88 | 89 | subj, i := args[0], 0 90 | 91 | _, err = nc.Subscribe(subj, func(msg *nats.Msg) { 92 | i++ 93 | printMsg(msg, i) 94 | if msg.Reply != `` { 95 | log.Printf("Replying to %s", msg.Reply) 96 | if rerr := msg.Respond([]byte{0}); rerr != nil { 97 | log.Printf("Reply to %s failed: %v", msg.Reply, rerr) 98 | } 99 | } 100 | }) 101 | if err != nil { 102 | log.Fatal(err) 103 | } 104 | if err = nc.Flush(); err != nil { 105 | log.Fatal(err) 106 | } 107 | 108 | if err := nc.LastError(); err != nil { 109 | log.Fatal(err) 110 | } 111 | 112 | log.Printf("Listening on [%s]", subj) 113 | if *showTime { 114 | log.SetFlags(log.LstdFlags) 115 | } 116 | 117 | runtime.Goexit() 118 | } 119 | 120 | func setupConnOptions(opts []nats.Option) []nats.Option { 121 | totalWait := 10 * time.Minute 122 | reconnectDelay := time.Second 123 | 124 | opts = append(opts, nats.ReconnectWait(reconnectDelay)) 125 | opts = append(opts, nats.MaxReconnects(int(totalWait/reconnectDelay))) 126 | opts = append(opts, nats.DisconnectHandler(func(nc *nats.Conn) { 127 | log.Printf("Disconnected: will attempt reconnects for %.0fm", totalWait.Minutes()) 128 | })) 129 | opts = append(opts, nats.ReconnectHandler(func(nc *nats.Conn) { 130 | log.Printf("Reconnected [%s]", nc.ConnectedUrl()) 131 | })) 132 | opts = append(opts, nats.ClosedHandler(func(nc *nats.Conn) { 133 | log.Fatalf("Exiting: %v", nc.LastError()) 134 | })) 135 | return opts 136 | } 137 | -------------------------------------------------------------------------------- /go.mod:
-------------------------------------------------------------------------------- 1 | module github.com/tada/mqtt-nats 2 | 3 | go 1.14 4 | 5 | require ( 6 | github.com/golang/protobuf v1.3.5 // indirect 7 | github.com/nats-io/nats-server/v2 v2.1.4 8 | github.com/nats-io/nats.go v1.9.1 9 | github.com/nats-io/nuid v1.0.1 10 | github.com/tada/catch v0.0.0-20200501140707-b8b11d55b4e6 11 | github.com/tada/jsonstream v0.0.0-20200501141504-4d34829515db 12 | golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550 // indirect 13 | ) 14 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= 2 | github.com/golang/protobuf v1.3.5 h1:F768QJ1E9tib+q5Sc8MkdJi1RxLTbRcTf8LJV56aRls= 3 | github.com/golang/protobuf v1.3.5/go.mod h1:6O5/vntMXwX2lRkT1hjjk0nAC1IDOTvTlVgjlRvqsdk= 4 | github.com/nats-io/jwt v0.3.0 h1:xdnzwFETV++jNc4W1mw//qFyJGb2ABOombmZJQS4+Qo= 5 | github.com/nats-io/jwt v0.3.0/go.mod h1:fRYCDE99xlTsqUzISS1Bi75UBJ6ljOJQOAAu5VglpSg= 6 | github.com/nats-io/jwt v0.3.2 h1:+RB5hMpXUUA2dfxuhBTEkMOrYmM+gKIZYS1KjSostMI= 7 | github.com/nats-io/jwt v0.3.2/go.mod h1:/euKqTS1ZD+zzjYrY7pseZrTtWQSjujC7xjPc8wL6eU= 8 | github.com/nats-io/nats-server/v2 v2.1.4 h1:BILRnsJ2Yb/fefiFbBWADpViGF69uh4sxe8poVDQ06g= 9 | github.com/nats-io/nats-server/v2 v2.1.4/go.mod h1:Jw1Z28soD/QasIA2uWjXyM9El1jly3YwyFOuR8tH1rg= 10 | github.com/nats-io/nats.go v1.9.1 h1:ik3HbLhZ0YABLto7iX80pZLPw/6dx3T+++MZJwLnMrQ= 11 | github.com/nats-io/nats.go v1.9.1/go.mod h1:ZjDU1L/7fJ09jvUSRVBR2e7+RnLiiIQyqyzEE/Zbp4w= 12 | github.com/nats-io/nkeys v0.1.0 h1:qMd4+pRHgdr1nAClu+2h/2a5F2TmKcCzjCDazVgRoX4= 13 | github.com/nats-io/nkeys v0.1.0/go.mod h1:xpnFELMwJABBLVhffcfd1MZx6VsNRFpEugbxziKVo7w= 14 | github.com/nats-io/nkeys v0.1.3 h1:6JrEfig+HzTH85yxzhSVbjHRJv9cn0p6n3IngIcM5/k= 15 | github.com/nats-io/nkeys v0.1.3/go.mod 
h1:xpnFELMwJABBLVhffcfd1MZx6VsNRFpEugbxziKVo7w= 16 | github.com/nats-io/nuid v1.0.1 h1:5iA8DT8V7q8WK2EScv2padNa/rTESc1KdnPw4TC2paw= 17 | github.com/nats-io/nuid v1.0.1/go.mod h1:19wcPz3Ph3q0Jbyiqsd0kePYG7A95tJPxeL+1OSON2c= 18 | github.com/tada/catch v0.0.0-20200501140707-b8b11d55b4e6 h1:FOmtz4bkMV7ArdaFX9Ev8Mw1frMpw4XfTj8sAb4XprE= 19 | github.com/tada/catch v0.0.0-20200501140707-b8b11d55b4e6/go.mod h1:mL60x4NqUvoa7GzNDLlmi6IyIX8eRqQzkgcD2cs8dWM= 20 | github.com/tada/jsonstream v0.0.0-20200501141504-4d34829515db h1:VzNg3u3uJj5rNoop6te/tK4u/FPh2ytFv7IdbGBTlIE= 21 | github.com/tada/jsonstream v0.0.0-20200501141504-4d34829515db/go.mod h1:MzoAgsR5aJ/WBHQG8XdOmC6EP1Rpk1DpI+li6mgdDf0= 22 | golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= 23 | golang.org/x/crypto v0.0.0-20190701094942-4def268fd1a4 h1:HuIa8hRrWRSrqYzx1qI49NNxhdi2PrY7gxVSq1JjLDc= 24 | golang.org/x/crypto v0.0.0-20190701094942-4def268fd1a4/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= 25 | golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550 h1:ObdrDkeb4kJdCP557AjRjq69pTHfNouLtWZG7j9rPN8= 26 | golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= 27 | golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= 28 | golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= 29 | golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 30 | golang.org/x/sys v0.0.0-20190726091711-fc99dfbffb4e h1:D5TXcfTk7xF7hvieo4QErS3qqCB4teTffacDWr7CI+0= 31 | golang.org/x/sys v0.0.0-20190726091711-fc99dfbffb4e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 32 | golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= 33 | -------------------------------------------------------------------------------- /logger/logger.go: 
-------------------------------------------------------------------------------- 1 | // Package logger contains a logger interface and an implementation that is based on the standard log.Logger 2 | package logger 3 | 4 | import ( 5 | "io" 6 | "log" 7 | ) 8 | 9 | // A Logger logs information using a log level 10 | type Logger interface { 11 | // DebugEnabled returns true if debug level logging is enabled 12 | DebugEnabled() bool 13 | 14 | // Debug logs at debug level. Arguments are handled in the manner of fmt.Println. 15 | Debug(...interface{}) 16 | 17 | // ErrorEnabled returns true if error level logging is enabled 18 | ErrorEnabled() bool 19 | 20 | // Error logs at error level. Arguments are handled in the manner of fmt.Println. 21 | Error(...interface{}) 22 | 23 | // InfoEnabled returns true if info level logging is enabled 24 | InfoEnabled() bool 25 | 26 | // Info logs at info level. Arguments are handled in the manner of fmt.Println. 27 | Info(...interface{}) 28 | } 29 | 30 | // Level determines which kinds of logging are enabled. Each level also enables all levels below it, except Silent. 31 | type Level int 32 | 33 | const ( 34 | // Silent means that all logging is disabled 35 | Silent = Level(iota) 36 | 37 | // Error means that only error logging is enabled 38 | Error 39 | 40 | // Info means that error and info logging is enabled 41 | Info 42 | 43 | // Debug means that all logging is enabled 44 | Debug 45 | ) 46 | 47 | type silent int // silent is the Logger returned for the Silent level; every method is a no-op or reports false 48 | 49 | func (silent) Debug(...interface{}) { 50 | } 51 | func (silent) DebugEnabled() bool { 52 | return false 53 | } 54 | func (silent) Error(...interface{}) { 55 | } 56 | func (silent) ErrorEnabled() bool { 57 | return false 58 | } 59 | func (silent) Info(...interface{}) { 60 | } 61 | func (silent) InfoEnabled() bool { 62 | return false 63 | } 64 | 65 | type writer struct { // writer holds one *log.Logger per level; a nil logger means that level is disabled 66 | debug *log.Logger 67 | info *log.Logger 68 | err *log.Logger 69 | } 70 | 71 | func (l *writer) Debug(args ...interface{}) { 72 | if l.debug != nil { 73 | l.debug.Println(args...)
74 | } 75 | } 76 | 77 | func (l *writer) DebugEnabled() bool { 78 | return l.debug != nil 79 | } 80 | 81 | func (l *writer) Error(args ...interface{}) { 82 | if l.err != nil { 83 | l.err.Println(args...) 84 | } 85 | } 86 | 87 | func (l *writer) ErrorEnabled() bool { 88 | return l.err != nil 89 | } 90 | 91 | func (l *writer) Info(args ...interface{}) { 92 | if l.info != nil { 93 | l.info.Println(args...) 94 | } 95 | } 96 | 97 | func (l *writer) InfoEnabled() bool { 98 | return l.info != nil 99 | } 100 | 101 | // New returns a logger that is based on the standard log.Logger. Debug and info output is written to out, error output to err. A level outside Silent..Debug yields a logger with every level disabled. 102 | func New(level Level, out, err io.Writer) Logger { 103 | if level == Silent { 104 | return silent(0) 105 | } 106 | 107 | l := &writer{} 108 | switch level { 109 | case Debug: 110 | l.debug = log.New(out, "DEBUG ", log.LstdFlags) 111 | fallthrough 112 | case Info: 113 | l.info = log.New(out, "INFO ", log.LstdFlags) 114 | fallthrough 115 | case Error: 116 | l.err = log.New(err, "ERROR ", log.LstdFlags) 117 | } 118 | return l 119 | } 120 | -------------------------------------------------------------------------------- /logger/logger_test.go: -------------------------------------------------------------------------------- 1 | package logger 2 | 3 | import ( 4 | "bytes" 5 | "fmt" 6 | "regexp" 7 | "testing" 8 | ) 9 | 10 | func TestLogger_New(t *testing.T) { 11 | l := New(Silent, nil, nil) 12 | if _, ok := l.(silent); !ok { 13 | t.Error("silent level did not result in silent logger") 14 | } 15 | } 16 | 17 | var msgEx = regexp.MustCompile(`^([A-Z]+)\s+\S+\s+\S+\s+(.*)\n$`) 18 | 19 | func assertEqual(t *testing.T, e, a string) { 20 | t.Helper() 21 | if e != a { 22 | t.Fatalf("expected '%s', got '%s'", e, a) 23 | } 24 | } 25 | 26 | func checkLogOutput(t *testing.T, ls, ms string, bf fmt.Stringer) { 27 | t.Helper() 28 | // msgEx reduces a single log line to "LEVEL message" by dropping the date and time written by log.LstdFlags. 29 | // Output that does not match is returned unchanged, so the comparison fails instead of silently passing. 30 | // The regexp must therefore stay in sync with the prefixes and flags used by New. 31 | assertEqual(t, ls+" "+ms, msgEx.ReplaceAllString(bf.String(), "$1 $2")) 32 | } 33 | 34 | func TestLogger_Debug(t *testing.T)
{ 35 | o := bytes.Buffer{} 36 | e := bytes.Buffer{} 37 | l := New(Debug, &o, &e) 38 | m := "some message" 39 | if !(l.ErrorEnabled() && l.InfoEnabled() && l.DebugEnabled()) { 40 | t.Fatal("wrong levels enabled") 41 | } 42 | l.Debug(m) 43 | if e.Len() > 0 { 44 | t.Error("debug log produced output on stderr") 45 | } 46 | l.Error(m) 47 | checkLogOutput(t, "DEBUG", m, &o) 48 | checkLogOutput(t, "ERROR", m, &e) 49 | } 50 | 51 | func TestLogger_Info(t *testing.T) { 52 | o := bytes.Buffer{} 53 | e := bytes.Buffer{} 54 | l := New(Info, &o, &e) 55 | m := "some message" 56 | if !(l.ErrorEnabled() && l.InfoEnabled() && !l.DebugEnabled()) { 57 | t.Fatal("wrong levels enabled") 58 | } 59 | l.Info(m) 60 | if e.Len() > 0 { 61 | t.Error("info log produced output on stderr") 62 | } 63 | l.Error(m) 64 | checkLogOutput(t, "INFO", m, &o) 65 | checkLogOutput(t, "ERROR", m, &e) 66 | } 67 | 68 | func TestLogger_Error(t *testing.T) { 69 | o := bytes.Buffer{} 70 | e := bytes.Buffer{} 71 | l := New(Error, &o, &e) 72 | m := "some message" 73 | if !(l.ErrorEnabled() && !l.InfoEnabled() && !l.DebugEnabled()) { 74 | t.Fatal("wrong levels enabled") 75 | } 76 | l.Error(m) 77 | l.Info(m) 78 | if o.Len() > 0 { 79 | t.Error("error log produced output on stdout") 80 | } 81 | checkLogOutput(t, "ERROR", m, &e) 82 | } 83 | 84 | func TestLogger_Silent(t *testing.T) { 85 | o := bytes.Buffer{} 86 | e := bytes.Buffer{} 87 | l := New(Silent, &o, &e) 88 | m := "some message" 89 | if l.ErrorEnabled() || l.InfoEnabled() || l.DebugEnabled() { 90 | t.Fatal("wrong levels enabled") 91 | } 92 | l.Debug(m) 93 | l.Info(m) 94 | l.Error(m) 95 | if o.Len() > 0 { 96 | t.Error("silent log produced output on stdout") 97 | } 98 | if e.Len() > 0 { 99 | t.Error("silent log produced output on stderr") 100 | } 101 | } 102 | -------------------------------------------------------------------------------- /main.go: -------------------------------------------------------------------------------- 1 | // +build !citest 2 | 3 | 
package main 4 | 5 | import ( 6 | "os" 7 | 8 | "github.com/tada/mqtt-nats/cli" 9 | ) 10 | 11 | func main() { 12 | os.Exit(cli.Bridge(os.Args, os.Stdout, os.Stderr)) 13 | } 14 | -------------------------------------------------------------------------------- /mqtt/package.go: -------------------------------------------------------------------------------- 1 | // Package mqtt contains the MQTT reader, writer and MQTT to NATS topic conversion utilities 2 | package mqtt 3 | -------------------------------------------------------------------------------- /mqtt/pkg/connect.go: -------------------------------------------------------------------------------- 1 | package pkg 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "time" 7 | 8 | "github.com/tada/mqtt-nats/mqtt" 9 | ) 10 | 11 | const ( 12 | protoName = "MQTT" 13 | 14 | cleanSessionFlag = byte(0x02) 15 | willFlag = byte(0x04) 16 | willQoS = byte(0x18) 17 | willRetainFlag = byte(0x20) 18 | passwordFlag = byte(0x40) 19 | userNameFlag = byte(0x80) 20 | ) 21 | 22 | // ReturnCode used in response to a CONNECT 23 | type ReturnCode byte 24 | 25 | func (r ReturnCode) Error() string { 26 | switch r { 27 | case RtAccepted: 28 | return "accepted" 29 | case RtUnacceptableProtocolVersion: 30 | return "unacceptable protocol version" 31 | case RtIdentifierRejected: 32 | return "identifier rejected" 33 | case RtServerUnavailable: 34 | return "server unavailable" 35 | case RtBadUserNameOrPassword: 36 | return "bad user name or password" 37 | case RtNotAuthorized: 38 | return "not authorized" 39 | default: 40 | return "unknown error" 41 | } 42 | } 43 | 44 | const ( 45 | // RtAccepted Connection Accepted 46 | RtAccepted = ReturnCode(iota) 47 | 48 | // RtUnacceptableProtocolVersion The Server does not support the level of the MQTT protocol requested by the Client 49 | RtUnacceptableProtocolVersion 50 | 51 | // RtIdentifierRejected The Client identifier is correct UTF-8 but not allowed by the Server 52 | RtIdentifierRejected 53 | 54 | // 
RtServerUnavailable The Network Connection has been made but the MQTT service is unavailable 55 | RtServerUnavailable 56 | 57 | // RtBadUserNameOrPassword The data in the user name or password is malformed 58 | RtBadUserNameOrPassword 59 | 60 | // RtNotAuthorized The Client is not authorized to connect 61 | RtNotAuthorized 62 | ) 63 | 64 | // Connect is the MQTT connect packet 65 | type Connect struct { 66 | clientID string 67 | creds *Credentials 68 | will *Will 69 | keepAlive uint16 70 | clientLevel byte 71 | flags byte 72 | } 73 | 74 | // NewConnect creates a new MQTT connect packet 75 | func NewConnect(clientID string, cleanSession bool, keepAlive uint16, will *Will, creds *Credentials) *Connect { 76 | flags := byte(0) 77 | if cleanSession { 78 | flags |= cleanSessionFlag 79 | } 80 | if will != nil { 81 | flags |= willFlag | (will.QoS << 3) 82 | if will.Retain { 83 | flags |= willRetainFlag 84 | } 85 | } 86 | if creds != nil { 87 | if len(creds.Password) > 0 { 88 | flags |= passwordFlag 89 | } 90 | if len(creds.User) > 0 { 91 | flags |= userNameFlag 92 | } 93 | } 94 | 95 | return &Connect{ 96 | clientID: clientID, 97 | creds: creds, 98 | keepAlive: keepAlive, 99 | clientLevel: 0x4, 100 | flags: flags, 101 | will: will} 102 | } 103 | 104 | // ParseConnect parses the connect packet from the given reader. 
105 | func ParseConnect(r *mqtt.Reader, _ byte, pkLen int) (Packet, error) { 106 | var err error 107 | if r, err = r.ReadPacket(pkLen); err != nil { 108 | return nil, err 109 | } 110 | 111 | // Protocol Name 112 | var proto string 113 | if proto, err = r.ReadString(); err != nil { 114 | return nil, err 115 | } 116 | if proto != protoName { 117 | return nil, fmt.Errorf(`expected connect packet with protocol name "MQTT", got "%s"`, proto) 118 | } 119 | 120 | c := &Connect{} 121 | 122 | // Protocol Level 123 | if c.clientLevel, err = r.ReadByte(); err != nil { 124 | return nil, err 125 | } 126 | if c.clientLevel != 0x4 { 127 | return c, RtUnacceptableProtocolVersion 128 | } 129 | 130 | // Connect Flags 131 | if c.flags, err = r.ReadByte(); err != nil { 132 | return nil, err 133 | } 134 | 135 | // Keep Alive 136 | if c.keepAlive, err = r.ReadUint16(); err != nil { 137 | return nil, err 138 | } 139 | 140 | // Payload starts here 141 | 142 | // Client Identifier 143 | if c.clientID, err = r.ReadString(); err != nil { 144 | return nil, err 145 | } 146 | 147 | // Will 148 | if c.HasWill() { 149 | c.will = &Will{QoS: (c.flags & willQoS) >> 3, Retain: (c.flags & willRetainFlag) != 0} 150 | if c.will.Topic, err = r.ReadString(); err != nil { 151 | return nil, err 152 | } 153 | if c.will.Message, err = r.ReadBytes(); err != nil { 154 | return nil, err 155 | } 156 | } 157 | 158 | // User Name 159 | if (c.flags & (userNameFlag | passwordFlag)) != 0 { 160 | c.creds = &Credentials{} 161 | if c.HasUserName() { 162 | if c.creds.User, err = r.ReadString(); err != nil { 163 | return nil, err 164 | } 165 | } 166 | // Password 167 | if c.HasPassword() { 168 | if c.creds.Password, err = r.ReadBytes(); err != nil { 169 | return nil, err 170 | } 171 | } 172 | } 173 | return c, nil 174 | } 175 | 176 | // Equals returns true if this packet is equal to the given packet, false if not 177 | func (c *Connect) Equals(other interface{}) bool { 178 | oc, ok := other.(*Connect) 179 | return ok && 
180 | c.keepAlive == oc.keepAlive && 181 | c.clientLevel == oc.clientLevel && 182 | c.flags == oc.flags && 183 | c.clientID == oc.clientID && 184 | (c.will == oc.will || (c.will != nil && c.will.Equals(oc.will))) && 185 | (c.creds == oc.creds || (c.creds != nil && c.creds.Equals(oc.creds))) 186 | } 187 | 188 | // SetClientLevel sets the client level. Intended for testing purposes 189 | func (c *Connect) SetClientLevel(cl byte) { 190 | c.clientLevel = cl 191 | } 192 | 193 | // Write writes the MQTT bits of this packet on the given Writer 194 | func (c *Connect) Write(w *mqtt.Writer) { 195 | pkLen := 2 + len(protoName) + 196 | 1 + // clientLevel 197 | 1 + // flags 198 | 2 + // keepAlive 199 | 2 + len(c.clientID) 200 | 201 | if c.HasWill() { 202 | pkLen += 2 + len(c.will.Topic) 203 | pkLen += 2 + len(c.will.Message) 204 | } 205 | if c.HasUserName() { 206 | pkLen += 2 + len(c.creds.User) 207 | } 208 | if c.HasPassword() { 209 | pkLen += 2 + len(c.creds.Password) 210 | } 211 | 212 | w.WriteU8(TpConnect) 213 | w.WriteVarInt(pkLen) 214 | w.WriteString(protoName) 215 | w.WriteU8(c.clientLevel) 216 | w.WriteU8(c.flags) 217 | w.WriteU16(c.keepAlive) 218 | w.WriteString(c.clientID) 219 | if c.HasWill() { 220 | w.WriteString(c.will.Topic) 221 | w.WriteBytes(c.will.Message) 222 | } 223 | if c.HasUserName() { 224 | w.WriteString(c.creds.User) 225 | } 226 | if c.HasPassword() { 227 | w.WriteBytes(c.creds.Password) 228 | } 229 | } 230 | 231 | // String returns a brief string representation of the packet. 
Suitable for logging 232 | func (c *Connect) String() string { 233 | will := "" 234 | if c.HasWill() { 235 | will = ", " + c.Will().String() 236 | } 237 | return fmt.Sprintf("CONNECT (c%d, k%d, u%d, p%d%s)", 238 | (c.flags&cleanSessionFlag)>>1, 239 | time.Duration(c.keepAlive), 240 | (c.flags&userNameFlag)>>7, 241 | (c.flags&passwordFlag)>>6, 242 | will) 243 | } 244 | 245 | // CleanSession returns true if the connection is requesting a clean session 246 | func (c *Connect) CleanSession() bool { 247 | return (c.flags & cleanSessionFlag) != 0 248 | } 249 | 250 | // ClientID returns the id provided by the client 251 | func (c *Connect) ClientID() string { 252 | return c.clientID 253 | } 254 | 255 | // HasPassword returns true if the connection contains a password 256 | func (c *Connect) HasPassword() bool { 257 | return (c.flags & passwordFlag) != 0 258 | } 259 | 260 | // HasUserName returns true if the connection contains a user name 261 | func (c *Connect) HasUserName() bool { 262 | return (c.flags & userNameFlag) != 0 263 | } 264 | 265 | // HasWill returns true if the connection contains a will 266 | func (c *Connect) HasWill() bool { 267 | return (c.flags & willFlag) != 0 268 | } 269 | 270 | // KeepAlive returns the desired keep alive duration 271 | func (c *Connect) KeepAlive() time.Duration { 272 | return time.Duration(c.keepAlive) * time.Second 273 | } 274 | 275 | // Credentials returns the user name and password credentials or nil 276 | func (c *Connect) Credentials() *Credentials { 277 | return c.creds 278 | } 279 | 280 | // Will returns the client will or nil 281 | func (c *Connect) Will() *Will { 282 | return c.will 283 | } 284 | 285 | // DeleteWill clears will and all flags that are associated with the will 286 | func (c *Connect) DeleteWill() { 287 | c.flags &^= willFlag | willQoS | willRetainFlag 288 | c.will = nil 289 | } 290 | 291 | // ConnAck is the MQTT CONNACK packet sent in response to a CONNECT 292 | type ConnAck struct { 293 | flags byte 294 | 
returnCode byte 295 | } 296 | 297 | // NewConnAck creates an CONNACK packet 298 | func NewConnAck(sessionPresent bool, returnCode ReturnCode) Packet { 299 | flags := byte(0x00) 300 | if sessionPresent { 301 | flags |= 0x01 302 | } 303 | return &ConnAck{flags: flags, returnCode: byte(returnCode)} 304 | } 305 | 306 | // ParseConnAck parses a CONNACK packet 307 | func ParseConnAck(r *mqtt.Reader, _ byte, pkLen int) (Packet, error) { 308 | var err error 309 | if pkLen != 2 { 310 | return nil, errors.New("malformed CONNACK") 311 | } 312 | var bs []byte 313 | bs, err = r.ReadExact(2) 314 | if err != nil { 315 | return nil, err 316 | } 317 | return &ConnAck{flags: bs[0], returnCode: bs[1]}, nil 318 | } 319 | 320 | // Equals returns true if this packet is equal to the given packet, false if not 321 | func (a *ConnAck) Equals(other interface{}) bool { 322 | ac, ok := other.(*ConnAck) 323 | return ok && *a == *ac 324 | } 325 | 326 | // ReturnCode returns the return code from the server 327 | func (a *ConnAck) ReturnCode() ReturnCode { 328 | return ReturnCode(a.returnCode) 329 | } 330 | 331 | // String returns a brief string representation of the packet. 
Suitable for logging 332 | func (a *ConnAck) String() string { 333 | return fmt.Sprintf("CONNACK (s%d, rt%d)", a.flags, a.returnCode) 334 | } 335 | 336 | // Write writes the MQTT bits of this packet on the given Writer 337 | func (a *ConnAck) Write(w *mqtt.Writer) { 338 | w.WriteU8(TpConnAck) 339 | w.WriteU8(2) 340 | w.WriteU8(a.flags) 341 | w.WriteU8(a.returnCode) 342 | } 343 | 344 | // The Disconnect type represents the MQTT DISCONNECT packet 345 | type Disconnect int 346 | 347 | // DisconnectSingleton is the one and only instance of the Disconnect type 348 | const DisconnectSingleton = Disconnect(0) 349 | 350 | // Equals returns true if this packet is equal to the given packet, false if not 351 | func (Disconnect) Equals(other interface{}) bool { 352 | return other == DisconnectSingleton 353 | } 354 | 355 | // String returns a brief string representation of the packet. Suitable for logging 356 | func (Disconnect) String() string { 357 | return "DISCONNECT" 358 | } 359 | 360 | // Write writes the MQTT bits of this packet on the given Writer 361 | func (Disconnect) Write(w *mqtt.Writer) { 362 | w.WriteU8(TpDisconnect) 363 | w.WriteU8(0) 364 | } 365 | -------------------------------------------------------------------------------- /mqtt/pkg/connect_test.go: -------------------------------------------------------------------------------- 1 | package pkg_test 2 | 3 | import ( 4 | "bytes" 5 | "testing" 6 | 7 | "github.com/tada/mqtt-nats/test/utils" 8 | 9 | "github.com/tada/mqtt-nats/mqtt" 10 | 11 | "github.com/tada/mqtt-nats/mqtt/pkg" 12 | ) 13 | 14 | func TestParseConnect(t *testing.T) { 15 | c1 := pkg.NewConnect(`cid`, true, 5, &pkg.Will{ 16 | Topic: "my/will", 17 | Message: []byte("the will"), 18 | QoS: 1, 19 | Retain: false, 20 | }, &pkg.Credentials{User: "bob", Password: []byte("password")}) 21 | writeReadAndCompare(t, c1, "CONNECT (c1, k5, u1, p1, w(r0, q1, 'my/will', ... 
(8 bytes)))") 22 | } 23 | 24 | func TestParseConnect_badLen(t *testing.T) { 25 | _, err := pkg.ParseConnect(mqtt.NewReader(bytes.NewReader([]byte{})), pkg.TpConnAck, 20) 26 | utils.CheckError(err, t) 27 | } 28 | 29 | func TestParseConnect_badProto(t *testing.T) { 30 | w := mqtt.NewWriter() 31 | w.WriteU16(5) 32 | w.WriteU8('N') 33 | w.WriteU8('O') 34 | w.WriteU8('N') 35 | w.WriteU8('O') 36 | _, err := pkg.ParseConnect(mqtt.NewReader(bytes.NewReader(w.Bytes())), pkg.TpConnAck, 6) 37 | utils.CheckError(err, t) 38 | } 39 | 40 | func TestParseConnect_illegalProto(t *testing.T) { 41 | w := mqtt.NewWriter() 42 | w.WriteString("NONO") 43 | _, err := pkg.ParseConnect(mqtt.NewReader(bytes.NewReader(w.Bytes())), pkg.TpConnAck, 6) 44 | utils.CheckError(err, t) 45 | } 46 | 47 | func TestParseConnect_badClientLevel(t *testing.T) { 48 | w := mqtt.NewWriter() 49 | w.WriteString("MQTT") 50 | _, err := pkg.ParseConnect(mqtt.NewReader(bytes.NewReader(w.Bytes())), pkg.TpConnAck, 6) 51 | utils.CheckError(err, t) 52 | } 53 | 54 | func TestParseConnect_illegalClientLevel(t *testing.T) { 55 | w := mqtt.NewWriter() 56 | w.WriteString("MQTT") 57 | w.WriteU8(2) 58 | _, err := pkg.ParseConnect(mqtt.NewReader(bytes.NewReader(w.Bytes())), pkg.TpConnAck, 7) 59 | utils.CheckError(err, t) 60 | } 61 | 62 | func TestParseConnect_badConnectFlags(t *testing.T) { 63 | w := mqtt.NewWriter() 64 | w.WriteString("MQTT") 65 | w.WriteU8(4) 66 | _, err := pkg.ParseConnect(mqtt.NewReader(bytes.NewReader(w.Bytes())), pkg.TpConnAck, 7) 67 | utils.CheckError(err, t) 68 | } 69 | 70 | func TestParseConnect_badKeepAlive(t *testing.T) { 71 | w := mqtt.NewWriter() 72 | w.WriteString("MQTT") 73 | w.WriteU8(4) 74 | w.WriteU8(0) 75 | _, err := pkg.ParseConnect(mqtt.NewReader(bytes.NewReader(w.Bytes())), pkg.TpConnAck, 8) 76 | utils.CheckError(err, t) 77 | } 78 | 79 | func TestParseConnect_badClientID(t *testing.T) { 80 | w := mqtt.NewWriter() 81 | w.WriteString("MQTT") 82 | w.WriteU8(4) 83 | w.WriteU8(0) 84 | 
w.WriteU16(5) 85 | _, err := pkg.ParseConnect(mqtt.NewReader(bytes.NewReader(w.Bytes())), pkg.TpConnAck, 10) 86 | utils.CheckError(err, t) 87 | } 88 | 89 | func TestParseConnect_badWillTopic(t *testing.T) { 90 | w := mqtt.NewWriter() 91 | w.WriteString("MQTT") 92 | w.WriteU8(4) 93 | w.WriteU8(0x04) 94 | w.WriteU16(5) 95 | w.WriteString("cid") 96 | _, err := pkg.ParseConnect(mqtt.NewReader(bytes.NewReader(w.Bytes())), pkg.TpConnAck, 15) 97 | utils.CheckError(err, t) 98 | } 99 | 100 | func TestParseConnect_badWillMessage(t *testing.T) { 101 | w := mqtt.NewWriter() 102 | w.WriteString("MQTT") 103 | w.WriteU8(4) 104 | w.WriteU8(0x04) 105 | w.WriteU16(5) 106 | w.WriteString("cid") 107 | w.WriteString("wtp") 108 | _, err := pkg.ParseConnect(mqtt.NewReader(bytes.NewReader(w.Bytes())), pkg.TpConnAck, 20) 109 | utils.CheckError(err, t) 110 | } 111 | 112 | func TestParseConnect_badUser(t *testing.T) { 113 | w := mqtt.NewWriter() 114 | w.WriteString("MQTT") 115 | w.WriteU8(4) 116 | w.WriteU8(0x80) 117 | w.WriteU16(5) 118 | w.WriteString("cid") 119 | _, err := pkg.ParseConnect(mqtt.NewReader(bytes.NewReader(w.Bytes())), pkg.TpConnAck, 15) 120 | utils.CheckError(err, t) 121 | } 122 | 123 | func TestParseConnect_badPw(t *testing.T) { 124 | w := mqtt.NewWriter() 125 | w.WriteString("MQTT") 126 | w.WriteU8(4) 127 | w.WriteU8(0x40) 128 | w.WriteU16(5) 129 | w.WriteString("cid") 130 | _, err := pkg.ParseConnect(mqtt.NewReader(bytes.NewReader(w.Bytes())), pkg.TpConnAck, 15) 131 | utils.CheckError(err, t) 132 | } 133 | 134 | func TestParseConnAck(t *testing.T) { 135 | writeReadAndCompare(t, pkg.NewConnAck(false, 1), "CONNACK (s0, rt1)") 136 | } 137 | 138 | func TestParseConnAck_badLen(t *testing.T) { 139 | _, err := pkg.ParseConnAck(mqtt.NewReader(bytes.NewReader([]byte{})), pkg.TpConnAck, 0) 140 | utils.CheckError(err, t) 141 | } 142 | 143 | func TestParseConnAck_badBytes(t *testing.T) { 144 | _, err := pkg.ParseConnAck(mqtt.NewReader(bytes.NewReader([]byte{1})), pkg.TpConnAck, 2) 
145 | utils.CheckError(err, t) 146 | } 147 | 148 | func TestParseDisconnect(t *testing.T) { 149 | writeReadAndCompare(t, pkg.DisconnectSingleton, "DISCONNECT") 150 | } 151 | 152 | func TestReturnCode_Error(t *testing.T) { 153 | tests := []struct { 154 | name string 155 | code pkg.ReturnCode 156 | want string 157 | }{ 158 | { 159 | name: "RtAccepted", 160 | code: pkg.RtAccepted, 161 | want: "accepted", 162 | }, 163 | { 164 | name: "RtUnacceptableProtocolVersion", 165 | code: pkg.RtUnacceptableProtocolVersion, 166 | want: "unacceptable protocol version", 167 | }, 168 | { 169 | name: "RtIdentifierRejected", 170 | code: pkg.RtIdentifierRejected, 171 | want: "identifier rejected", 172 | }, 173 | { 174 | name: "RtServerUnavailable", 175 | code: pkg.RtServerUnavailable, 176 | want: "server unavailable", 177 | }, 178 | { 179 | name: "RtBadUserNameOrPassword", 180 | code: pkg.RtBadUserNameOrPassword, 181 | want: "bad user name or password", 182 | }, 183 | { 184 | name: "RtNotAuthorized", 185 | code: pkg.RtNotAuthorized, 186 | want: "not authorized", 187 | }, 188 | { 189 | name: "RtNotAuthorized", 190 | code: pkg.ReturnCode(99), 191 | want: "unknown error", 192 | }, 193 | } 194 | for i := range tests { 195 | tt := tests[i] 196 | t.Run(tt.name, func(t *testing.T) { 197 | if got := tt.code.Error(); got != tt.want { 198 | t.Errorf("ReturnCode.Error() = %v, want %v", got, tt.want) 199 | } 200 | }) 201 | } 202 | } 203 | -------------------------------------------------------------------------------- /mqtt/pkg/credentials.go: -------------------------------------------------------------------------------- 1 | package pkg 2 | 3 | import ( 4 | "bytes" 5 | "encoding/base64" 6 | "encoding/json" 7 | "io" 8 | 9 | "github.com/tada/catch" 10 | 11 | "github.com/tada/catch/pio" 12 | "github.com/tada/jsonstream" 13 | ) 14 | 15 | // Credentials are user credentials that originates from an MQTT CONNECT packet. 
16 | type Credentials struct { 17 | User string 18 | Password []byte 19 | } 20 | 21 | // MarshalToJSON streams the JSON encoded form of this instance onto the given io.Writer 22 | func (c *Credentials) MarshalToJSON(w io.Writer) { 23 | pio.WriteByte(w, '{') 24 | if c.User != "" { 25 | pio.WriteString(w, `"u":`) 26 | jsonstream.WriteString(w, c.User) 27 | } 28 | if c.Password != nil { // a nil password omits "p"; an empty non-nil one is written as "" 29 | if c.User != "" { 30 | pio.WriteByte(w, ',') 31 | } 32 | pio.WriteString(w, `"p":`) 33 | jsonstream.WriteString(w, base64.StdEncoding.EncodeToString(c.Password)) 34 | } 35 | pio.WriteByte(w, '}') 36 | } 37 | 38 | // UnmarshalFromJSON initializes this instance from the token stream provided by the jsonstream.Decoder. The 39 | // first token has already been read and is passed as an argument. 40 | func (c *Credentials) UnmarshalFromJSON(js jsonstream.Decoder, t json.Token) { 41 | jsonstream.AssertDelim(t, '{') 42 | for { 43 | k, ok := js.ReadStringOrEnd('}') 44 | if !ok { 45 | break 46 | } 47 | switch k { // NOTE(review): unknown keys are not consumed, so their value would be read as the next key 48 | case "u": 49 | c.User = js.ReadString() 50 | case "p": 51 | p, err := base64.StdEncoding.DecodeString(js.ReadString()) 52 | if err != nil { 53 | panic(catch.Error(err)) 54 | } 55 | c.Password = p 56 | } 57 | } 58 | } 59 | 60 | // Equals returns true if this instance is equal to the given instance, false if not. oc must not be nil. 61 | func (c *Credentials) Equals(oc *Credentials) bool { 62 | return c.User == oc.User && bytes.Equal(c.Password, oc.Password) 63 | } 64 | -------------------------------------------------------------------------------- /mqtt/pkg/idmanager.go: -------------------------------------------------------------------------------- 1 | package pkg 2 | 3 | import ( 4 | "encoding/json" 5 | "io" 6 | "sync" 7 | 8 | "github.com/tada/catch/pio" 9 | "github.com/tada/jsonstream" 10 | ) 11 | 12 | // An IDManager manages packet IDs and ensures their uniqueness by maintaining a list of 13 | // IDs that are in use 14 | type IDManager interface { 15 | // NextFreePacketID allocates and returns the
next free packet ID 16 | NextFreePacketID() uint16 17 | 18 | // ReleasePacketID releases a previously allocated packet ID 19 | ReleasePacketID(uint16) 20 | } 21 | 22 | type idManager struct { 23 | pkgIDLock sync.Mutex 24 | inFlight map[uint16]bool 25 | nextFreePkgID uint16 26 | } 27 | 28 | // NewIDManager creates a new IDManager 29 | func NewIDManager() IDManager { 30 | return &idManager{nextFreePkgID: 1, inFlight: make(map[uint16]bool, 37)} 31 | } 32 | 33 | func (s *idManager) NextFreePacketID() uint16 { 34 | s.pkgIDLock.Lock() 35 | defer s.pkgIDLock.Unlock() // release only after the return value is read; reading nextFreePkgID after Unlock races with concurrent callers 36 | s.nextFreePkgID++ 37 | if s.nextFreePkgID == 0 { 38 | // counter flipped over and zero is not a valid ID 39 | s.nextFreePkgID++ 40 | } 41 | for s.inFlight[s.nextFreePkgID] { // NOTE(review): never terminates if all 65535 non-zero IDs are in flight 42 | s.nextFreePkgID++ 43 | if s.nextFreePkgID == 0 { 44 | // counter flipped over and zero is not a valid ID 45 | s.nextFreePkgID++ 46 | } 47 | } 48 | s.inFlight[s.nextFreePkgID] = true 49 | return s.nextFreePkgID 50 | } 51 | 52 | func (s *idManager) ReleasePacketID(id uint16) { 53 | s.pkgIDLock.Lock() 54 | delete(s.inFlight, id) 55 | s.pkgIDLock.Unlock() 56 | } 57 | 58 | func (s *idManager) MarshalToJSON(w io.Writer) { 59 | var ( 60 | nf uint16 61 | inf []uint16 62 | ) 63 | 64 | // take a snapshot of things in flight 65 | s.pkgIDLock.Lock() 66 | nf = s.nextFreePkgID 67 | inf = make([]uint16, len(s.inFlight)) 68 | i := 0 69 | for k := range s.inFlight { // map iteration: IDs are emitted in no particular order 70 | inf[i] = k 71 | i++ 72 | } 73 | s.pkgIDLock.Unlock() 74 | 75 | pio.WriteString(w, `{"next":`) 76 | pio.WriteInt(w, int64(nf)) 77 | if len(inf) > 0 { 78 | pio.WriteString(w, `,"inFlight":[`) 79 | for i := range inf { 80 | if i > 0 { 81 | pio.WriteByte(w, ',') 82 | } 83 | pio.WriteInt(w, int64(inf[i])) 84 | } 85 | pio.WriteByte(w, ']') 86 | } 87 | pio.WriteByte(w, '}') 88 | } 89 | 90 | func (s *idManager) UnmarshalFromJSON(js jsonstream.Decoder, t json.Token) { 91 | s.inFlight = make(map[uint16]bool, 37) // (re)create: the receiver may be a zero idManager whose map is nil 92 | jsonstream.AssertDelim(t, '{') 93 | for { 94 | k, ok :=
js.ReadStringOrEnd('}') 95 | if !ok { 96 | break 97 | } 98 | switch k { // NOTE(review): unknown keys are not consumed 99 | case "next": 100 | s.nextFreePkgID = uint16(js.ReadInt()) 101 | case "inFlight": 102 | js.ReadDelim('[') 103 | for { 104 | i, ok := js.ReadIntOrEnd(']') 105 | if !ok { 106 | break 107 | } 108 | s.inFlight[uint16(i)] = true 109 | } 110 | } 111 | } 112 | } 113 | -------------------------------------------------------------------------------- /mqtt/pkg/idmanager_test.go: -------------------------------------------------------------------------------- 1 | package pkg 2 | 3 | import ( 4 | "math" 5 | "testing" 6 | 7 | "github.com/tada/jsonstream" 8 | "github.com/tada/mqtt-nats/test/utils" 9 | ) 10 | 11 | func TestIdManager_NextFreePacketID_flip(t *testing.T) { 12 | idm := NewIDManager().(*idManager) 13 | idm.nextFreePkgID = math.MaxUint16 - 1 14 | utils.CheckEqual(uint16(math.MaxUint16), idm.NextFreePacketID(), t) 15 | utils.CheckEqual(uint16(1), idm.NextFreePacketID(), t) 16 | } 17 | 18 | func TestIdManager_NextFreePacketID_flipInFlight(t *testing.T) { 19 | idm := NewIDManager().(*idManager) 20 | idm.nextFreePkgID = math.MaxUint16 - 2 21 | idm.inFlight[uint16(math.MaxUint16)] = true 22 | utils.CheckEqual(uint16(math.MaxUint16-1), idm.NextFreePacketID(), t) 23 | utils.CheckEqual(uint16(1), idm.NextFreePacketID(), t) 24 | } 25 | 26 | func TestIdManager_NextFreePacketID_json(t *testing.T) { 27 | idm := NewIDManager().(*idManager) 28 | idm.nextFreePkgID = math.MaxUint16 - 3 29 | idm.inFlight[uint16(math.MaxUint16-1)] = true 30 | idm.inFlight[uint16(math.MaxUint16)] = true 31 | idm.inFlight[uint16(1)] = true 32 | bs, err := jsonstream.Marshal(idm) 33 | utils.CheckNotError(err, t) 34 | idm = &idManager{} 35 | utils.CheckNotError(jsonstream.Unmarshal(idm, bs), t) 36 | utils.CheckEqual(uint16(math.MaxUint16-2), idm.NextFreePacketID(), t) 37 | utils.CheckEqual(uint16(2), idm.NextFreePacketID(), t) 38 | } 39 | -------------------------------------------------------------------------------- 
/mqtt/pkg/packet.go: -------------------------------------------------------------------------------- 1 | // Package pkg contains the MQTT packet structures 2 | package pkg 3 | 4 | import "github.com/tada/mqtt-nats/mqtt" 5 | 6 | const ( 7 | // TpConnect is the MQTT CONNECT type 8 | TpConnect = 0x10 9 | 10 | // TpConnAck is the MQTT CONNACK type 11 | TpConnAck = 0x20 12 | 13 | // TpPublish is the MQTT PUBLISH type 14 | TpPublish = 0x30 15 | 16 | // TpPubAck is the MQTT PUBACK type 17 | TpPubAck = 0x40 18 | 19 | // TpPubRec is the MQTT PUBREC type 20 | TpPubRec = 0x50 21 | 22 | // TpPubRel is the MQTT PUBREL type 23 | TpPubRel = 0x60 24 | 25 | // TpPubComp is the MQTT PUBCOMP type 26 | TpPubComp = 0x70 27 | 28 | // TpSubscribe is the MQTT SUBSCRIBE type 29 | TpSubscribe = 0x80 30 | 31 | // TpSubAck is the MQTT SUBACK type 32 | TpSubAck = 0x90 33 | 34 | // TpUnsubscribe is the MQTT UNSUBSCRIBE type 35 | TpUnsubscribe = 0xa0 36 | 37 | // TpUnsubAck is the MQTT UNSUBACK type 38 | TpUnsubAck = 0xb0 39 | 40 | // TpPing is the MQTT PINGREQ type 41 | TpPing = 0xc0 42 | 43 | // TpPingResp is the MQTT PINGRESP type 44 | TpPingResp = 0xd0 45 | 46 | // TpDisconnect is the MQTT DISCONNECT type 47 | TpDisconnect = 0xe0 48 | 49 | // TpMask is the bitmask that selects the packet type (high nibble) of the fixed-header byte 50 | TpMask = 0xf0 51 | ) 52 | 53 | // The Packet interface is implemented by all MQTT packet types 54 | type Packet interface { 55 | // Equals returns true if this packet is equal to the given packet, false if not 56 | Equals(other interface{}) bool 57 | 58 | // Write writes the MQTT bits of this packet on the given Writer 59 | Write(w *mqtt.Writer) 60 | } 61 | -------------------------------------------------------------------------------- /mqtt/pkg/packet_test.go: -------------------------------------------------------------------------------- 1 | package pkg_test 2 | 3 | import ( 4 | "bytes" 5 | "fmt" 6 | "testing" 7 | 8 | "github.com/tada/mqtt-nats/mqtt" 9 | "github.com/tada/mqtt-nats/mqtt/pkg" 10 | 
"github.com/tada/mqtt-nats/test/packet" 11 | ) 12 | 13 | func writeReadAndCompare(t *testing.T, p pkg.Packet, ex string) { 14 | w := &mqtt.Writer{} 15 | p.Write(w) 16 | p2 := packet.Parse(t, bytes.NewReader(w.Bytes())) 17 | if !p.Equals(p2) { 18 | t.Fatal(p, "!=", p2) 19 | } 20 | ac := p.(fmt.Stringer).String() 21 | if ex != ac { 22 | t.Errorf("expected '%s' got '%s'", ex, ac) 23 | } 24 | } 25 | -------------------------------------------------------------------------------- /mqtt/pkg/ping.go: -------------------------------------------------------------------------------- 1 | package pkg 2 | 3 | import "github.com/tada/mqtt-nats/mqtt" 4 | 5 | // The PingRequest type represents the MQTT PINGREQ packet 6 | type PingRequest int 7 | 8 | // PingRequestSingleton is the one and only instance of the PingRequest type 9 | const PingRequestSingleton = PingRequest(0) 10 | 11 | // Equals returns true if this packet is equal to the given packet, false if not 12 | func (PingRequest) Equals(other interface{}) bool { 13 | return other == PingRequestSingleton 14 | } 15 | 16 | // String returns a brief string representation of the packet. Suitable for logging 17 | func (PingRequest) String() string { 18 | return "PINGREQ" 19 | } 20 | 21 | // Write writes the MQTT bits of this packet on the given Writer 22 | func (PingRequest) Write(w *mqtt.Writer) { 23 | w.WriteU8(TpPing) 24 | w.WriteU8(0) // remaining length: PINGREQ carries no variable header or payload 25 | } 26 | 27 | // The PingResponse type represents the MQTT PINGRESP packet 28 | type PingResponse int 29 | 30 | // PingResponseSingleton is the one and only instance of the PingResponse type 31 | const PingResponseSingleton = PingResponse(0) 32 | 33 | // Equals returns true if this packet is equal to the given packet, false if not 34 | func (PingResponse) Equals(other interface{}) bool { 35 | return other == PingResponseSingleton 36 | } 37 | 38 | // String returns a brief string representation of the packet.
Suitable for logging 39 | func (PingResponse) String() string { 40 | return "PINGRESP" 41 | } 42 | 43 | // Write writes the MQTT bits of this packet on the given Writer 44 | func (PingResponse) Write(w *mqtt.Writer) { 45 | w.WriteU8(TpPingResp) 46 | w.WriteU8(0) // remaining length: PINGRESP carries no variable header or payload 47 | } 48 | -------------------------------------------------------------------------------- /mqtt/pkg/ping_test.go: -------------------------------------------------------------------------------- 1 | package pkg_test 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/tada/mqtt-nats/mqtt/pkg" 7 | ) 8 | 9 | func TestParsePingReq(t *testing.T) { 10 | writeReadAndCompare(t, pkg.PingRequestSingleton, "PINGREQ") 11 | } 12 | 13 | func TestParsePingResp(t *testing.T) { 14 | writeReadAndCompare(t, pkg.PingResponseSingleton, "PINGRESP") 15 | } 16 | -------------------------------------------------------------------------------- /mqtt/pkg/publish.go: -------------------------------------------------------------------------------- 1 | package pkg 2 | 3 | import ( 4 | "bytes" 5 | "encoding/base64" 6 | "encoding/json" 7 | "errors" 8 | "fmt" 9 | "io" 10 | 11 | "github.com/tada/catch" 12 | 13 | "github.com/tada/catch/pio" 14 | "github.com/tada/jsonstream" 15 | "github.com/tada/mqtt-nats/mqtt" 16 | ) 17 | 18 | const ( 19 | // PublishRetain is the bit representing MQTT PUBLISH "retain" flag 20 | PublishRetain = 0x01 21 | 22 | // PublishQoS is the mask for the MQTT PUBLISH "quality of service" bits 23 | PublishQoS = 0x06 24 | 25 | // PublishDup is the bit representing MQTT PUBLISH "dup" flag 26 | PublishDup = 0x08 27 | ) 28 | 29 | // The Publish type represents the MQTT PUBLISH packet 30 | type Publish struct { 31 | name string 32 | replyTo string 33 | payload []byte 34 | id uint16 35 | flags byte 36 | sentByUs bool // set if the message originated from this server (happens when a client will is published) 37 | } 38 | 39 | // SimplePublish creates a new Publish packet with all flags zero and 
no reply 40 | func
SimplePublish(topic string, payload []byte) *Publish { 41 | return &Publish{name: topic, payload: payload} 42 | } 43 | 44 | // NewPublish creates a new Publish packet 45 | func NewPublish(id uint16, topic string, flags byte, payload []byte, sentByUs bool, natsReplyTo string) *Publish { 46 | return &Publish{ 47 | name: topic, 48 | replyTo: natsReplyTo, 49 | payload: payload, 50 | id: id, 51 | flags: flags, 52 | sentByUs: sentByUs, 53 | } 54 | } 55 | 56 | // NewPublish2 creates a new Publish packet, composing its flags from a QoS level and the dup/retain booleans 57 | func NewPublish2(id uint16, topic string, payload []byte, qos byte, dup bool, retain bool) *Publish { 58 | flags := byte(0) 59 | if qos > 0 { 60 | flags |= (qos << 1) & PublishQoS // mask so an out-of-range qos cannot spill into the dup bit 61 | } 62 | if dup { 63 | flags |= PublishDup 64 | } 65 | if retain { 66 | flags |= PublishRetain 67 | } 68 | return &Publish{id: id, flags: flags, name: topic, payload: payload, sentByUs: false, replyTo: ""} 69 | } 70 | 71 | // ParsePublish parses the publish packet from the given reader. 72 | func ParsePublish(r *mqtt.Reader, flags byte, pkLen int) (Packet, error) { 73 | var err error 74 | if r, err = r.ReadPacket(pkLen); err != nil { 75 | return nil, err 76 | } 77 | 78 | pp := &Publish{flags: flags & 0xf} // keep only the low nibble (the PUBLISH flags) of the header byte 79 | if pp.name, err = r.ReadString(); err != nil { 80 | return nil, err 81 | } 82 | 83 | if pp.QoSLevel() > 0 { 84 | if pp.id, err = r.ReadUint16(); err != nil { 85 | return nil, err 86 | } 87 | } 88 | if pp.payload, err = r.ReadRemainingBytes(); err != nil { 89 | return nil, err 90 | } 91 | return pp, nil 92 | } 93 | 94 | // Equals returns true if this packet is equal to the given packet, false if not 95 | func (p *Publish) Equals(other interface{}) bool { 96 | op, ok := other.(*Publish) 97 | return ok && 98 | p.id == op.id && 99 | p.flags == op.flags && 100 | p.sentByUs == op.sentByUs && 101 | p.name == op.name && 102 | p.replyTo 
== op.replyTo && 103 | bytes.Equal(p.payload, op.payload) 104 | } 105 | 106 | // Flags returns the packet flags 107 | func (p *Publish) Flags() byte { 108 | return p.flags
109 | } 110 | 111 | // ID returns the MQTT Packet Identifier. The identifier is only valid if QoS > 0 112 | func (p *Publish) ID() uint16 { 113 | return p.id 114 | } 115 | 116 | // IsDup returns true if the packet is a duplicate of a previously sent packet 117 | func (p *Publish) IsDup() bool { 118 | return (p.flags & PublishDup) != 0 119 | } 120 | 121 | // SetDup sets the dup flag of the packet 122 | func (p *Publish) SetDup() { 123 | p.flags |= PublishDup 124 | } 125 | 126 | // IsPrintableASCII returns true if the given bytes are constrained to the ASCII 7-bit character set and 127 | // contain no control characters (0x00-0x1f and DEL, 0x7f). 128 | func IsPrintableASCII(bs []byte) bool { 129 | for i := range bs { 130 | c := bs[i] 131 | if c < 32 || c > 126 { // 127 (DEL) is a control character, not printable 132 | return false 133 | } 134 | } 135 | return true 136 | } 137 | 138 | // MarshalToJSON marshals the packet as a JSON object onto the given writer. The sentByUs flag is not written. 139 | func (p *Publish) MarshalToJSON(w io.Writer) { 140 | pio.WriteString(w, `{"flags":`) 141 | pio.WriteInt(w, int64(p.flags)) 142 | pio.WriteString(w, `,"id":`) 143 | pio.WriteInt(w, int64(p.id)) 144 | pio.WriteString(w, `,"name":`) 145 | jsonstream.WriteString(w, p.name) 146 | if p.replyTo != "" { 147 | pio.WriteString(w, `,"replyTo":`) 148 | jsonstream.WriteString(w, p.replyTo) 149 | } 150 | if len(p.payload) > 0 { 151 | if IsPrintableASCII(p.payload) { 152 | pio.WriteString(w, `,"payload":`) 153 | jsonstream.WriteString(w, string(p.payload)) 154 | } else { 155 | pio.WriteString(w, `,"payloadEnc":`) 156 | jsonstream.WriteString(w, base64.StdEncoding.EncodeToString(p.payload)) 157 | } 158 | } 159 | pio.WriteByte(w, '}') 160 | } 161 | 162 | // NatsReplyTo returns the NATS replyTo subject.
Only valid when the packet represents something 163 | // received from NATS due to a client subscribing to a topic with QoS level > 0 164 | func (p *Publish) NatsReplyTo() string { 165 | return p.replyTo 166 | } 167 | 168 | // Payload returns the payload of the published message 169 | func (p *Publish) Payload() []byte { 170 | return p.payload 171 | } 172 | 173 | // QoSLevel returns the quality of service level which is 0, 1 or 2 (the two QoS bits can also encode 3, which ParsePublish does not reject). 174 | func (p *Publish) QoSLevel() byte { 175 | return (p.flags & PublishQoS) >> 1 176 | } 177 | 178 | // ResetRetain resets the retain flag 179 | func (p *Publish) ResetRetain() { 180 | p.flags &^= PublishRetain 181 | } 182 | 183 | // Retain returns the retain flag setting 184 | func (p *Publish) Retain() bool { 185 | return (p.flags & PublishRetain) != 0 186 | } 187 | 188 | // String returns a brief string representation of the packet. Suitable for logging 189 | func (p *Publish) String() string { 190 | // layout borrowed from mosquitto_sub log output 191 | return fmt.Sprintf("PUBLISH (d%d, q%d, r%b, m%d, '%s', ... (%d bytes))", 192 | (p.flags&0x08)>>3, 193 | p.QoSLevel(), 194 | p.flags&0x01, 195 | p.ID(), 196 | p.name, 197 | len(p.payload)) 198 | } 199 | 200 | // TopicName returns the name of the topic 201 | func (p *Publish) TopicName() string { 202 | return p.name 203 | } 204 | 205 | // UnmarshalFromJSON expects the given token to be the object start '{'. If it is, the rest 206 | // of the object is unmarshalled into the receiver. The method panics if any errors are detected 207 | // (an invalid base64 "payloadEnc" value panics with a catch.Error). 208 | // 209 | // See the jsonstream package for more info.
210 | func (p *Publish) UnmarshalFromJSON(js jsonstream.Decoder, t json.Token) { 211 | jsonstream.AssertDelim(t, '{') 212 | for { 213 | s, ok := js.ReadStringOrEnd('}') 214 | if !ok { 215 | break 216 | } 217 | switch s { // NOTE(review): unknown keys are not consumed; sentByUs is never restored 218 | case "flags": 219 | p.flags = byte(js.ReadInt()) 220 | case "id": 221 | p.id = uint16(js.ReadInt()) 222 | case "name": 223 | p.name = js.ReadString() 224 | case "replyTo": 225 | p.replyTo = js.ReadString() 226 | case "payload": 227 | p.payload = []byte(js.ReadString()) 228 | case "payloadEnc": 229 | var err error 230 | p.payload, err = base64.StdEncoding.DecodeString(js.ReadString()) 231 | if err != nil { 232 | panic(catch.Error(err)) 233 | } 234 | } 235 | } 236 | } 237 | 238 | // Write writes the MQTT bits of this packet on the given Writer 239 | func (p *Publish) Write(w *mqtt.Writer) { 240 | w.WriteU8(TpPublish | (p.flags & 0x0f)) // mask: flags from NewPublish or JSON are not range-checked and must not alter the packet type 241 | pkLen := 2 + len(p.name) + len(p.payload) // 2-byte topic length prefix + topic + payload 242 | if p.QoSLevel() > 0 { 243 | pkLen += 2 244 | } 245 | w.WriteVarInt(pkLen) 246 | w.WriteString(p.name) 247 | if p.QoSLevel() > 0 { 248 | w.WriteU16(p.id) 249 | } 250 | _, _ = w.Write(p.payload) 251 | } 252 | 253 | // The PubAck type represents the MQTT PUBACK packet 254 | type PubAck uint16 255 | 256 | // ParsePubAck parses a PUBACK packet 257 | func ParsePubAck(r *mqtt.Reader, _ byte, pkLen int) (Packet, error) { 258 | if pkLen != 2 { 259 | return PubAck(0), errors.New("malformed PUBACK") 260 | } 261 | id, err := r.ReadUint16() 262 | return PubAck(id), err 263 | } 264 | 265 | // Equals returns true if this packet is equal to the given packet, false if not 266 | func (p PubAck) Equals(other interface{}) bool { 267 | return p == other 268 | } 269 | 270 | // ID returns the packet ID 271 | func (p PubAck) ID() uint16 { 272 | return uint16(p) 273 | } 274 | 275 | // 
String returns a brief string representation of the packet.
Suitable for logging 276 | func (p PubAck) String() string { 277 | return fmt.Sprintf("PUBACK (m%d)", p.ID()) 278 | } 279 | 280 | // Write writes the MQTT bits of this packet on the given Writer 281 | func (p PubAck) Write(w *mqtt.Writer) { 282 | w.WriteU8(TpPubAck) 283 | w.WriteU8(2) 284 | w.WriteU16(uint16(p)) 285 | } 286 | 287 | // The PubRec type represents the MQTT PUBREC packet 288 | type PubRec uint16 289 | 290 | // ParsePubRec parses a PUBREC packet 291 | func ParsePubRec(r *mqtt.Reader, _ byte, pkLen int) (Packet, error) { 292 | if pkLen != 2 { 293 | return PubRec(0), errors.New("malformed PUBREC") 294 | } 295 | id, err := r.ReadUint16() 296 | return PubRec(id), err 297 | } 298 | 299 | // Equals returns true if this packet is equal to the given packet, false if not 300 | func (p PubRec) Equals(other interface{}) bool { 301 | return p == other 302 | } 303 | 304 | // ID returns the packet ID 305 | func (p PubRec) ID() uint16 { 306 | return uint16(p) 307 | } 308 | 309 | // String returns a brief string representation of the packet.
Suitable for logging 310 | func (p PubRec) String() string { 311 | return fmt.Sprintf("PUBREC (m%d)", p.ID()) 312 | } 313 | 314 | // Write writes the MQTT bits of this packet on the given Writer 315 | func (p PubRec) Write(w *mqtt.Writer) { 316 | w.WriteU8(TpPubRec) 317 | w.WriteU8(2) 318 | w.WriteU16(uint16(p)) 319 | } 320 | 321 | // The PubRel type represents the MQTT PUBREL packet 322 | type PubRel uint16 323 | 324 | // ParsePubRel parses a PUBREL packet. The fixed-header flags are not validated. 325 | func ParsePubRel(r *mqtt.Reader, _ byte, pkLen int) (Packet, error) { 326 | if pkLen != 2 { 327 | return PubRel(0), errors.New("malformed PUBREL") 328 | } 329 | id, err := r.ReadUint16() 330 | return PubRel(id), err 331 | } 332 | 333 | // Equals returns true if this packet is equal to the given packet, false if not 334 | func (p PubRel) Equals(other interface{}) bool { 335 | return p == other 336 | } 337 | 338 | // ID returns the packet ID 339 | func (p PubRel) ID() uint16 { 340 | return uint16(p) 341 | } 342 | 343 | // String returns a brief string representation of the packet.
Suitable for logging 344 | func (p PubRel) String() string { 345 | return fmt.Sprintf("PUBREL (m%d)", p.ID()) 346 | } 347 | 348 | // Write writes the MQTT bits of this packet on the given Writer 349 | func (p PubRel) Write(w *mqtt.Writer) { 350 | w.WriteU8(TpPubRel | 2) // MQTT 3.1.1 requires the PUBREL fixed-header flags to be 0010; peers treat 0x60 as malformed 351 | w.WriteU8(2) 352 | w.WriteU16(uint16(p)) 353 | } 354 | 355 | // The PubComp type represents the MQTT PUBCOMP packet 356 | type PubComp uint16 357 | 358 | // ParsePubComp parses a PUBCOMP packet 359 | func ParsePubComp(r *mqtt.Reader, _ byte, pkLen int) (Packet, error) { 360 | if pkLen != 2 { 361 | return PubComp(0), errors.New("malformed PUBCOMP") 362 | } 363 | id, err := r.ReadUint16() 364 | return PubComp(id), err 365 | } 366 | 367 | // Equals returns true if this packet is equal to the given packet, false if not 368 | func (p PubComp) Equals(other interface{}) bool { 369 | return p == other 370 | } 371 | 372 | // ID returns the packet ID 373 | func (p PubComp) ID() uint16 { 374 | return uint16(p) 375 | } 376 | 377 | // String returns a brief string representation of the packet. 
Suitable for logging 378 | func (p PubComp) String() string { 379 | return fmt.Sprintf("PUBCOMP (m%d)", p.ID()) 380 | } 381 | 382 | // Write writes the MQTT bits of this packet on the given Writer 383 | func (p PubComp) Write(w *mqtt.Writer) { 384 | w.WriteU8(TpPubComp) 385 | w.WriteU8(2) 386 | w.WriteU16(uint16(p)) 387 | } 388 | -------------------------------------------------------------------------------- /mqtt/pkg/publish_test.go: -------------------------------------------------------------------------------- 1 | package pkg_test 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/tada/jsonstream" 7 | "github.com/tada/mqtt-nats/mqtt/pkg" 8 | ) 9 | 10 | func TestParsePublish(t *testing.T) { 11 | writeReadAndCompare(t, pkg.NewPublish(23, "some/topic", 2, []byte(`the "message"`), false, ""), 12 | "PUBLISH (d0, q1, r0, m23, 'some/topic', ...
(13 bytes))") 13 | } 14 | 15 | func TestParsePubAck(t *testing.T) { 16 | writeReadAndCompare(t, pkg.PubAck(23), "PUBACK (m23)") 17 | } 18 | 19 | func TestParsePubRec(t *testing.T) { 20 | writeReadAndCompare(t, pkg.PubRec(23), "PUBREC (m23)") 21 | } 22 | 23 | func TestParsePubRel(t *testing.T) { 24 | writeReadAndCompare(t, pkg.PubRel(23), "PUBREL (m23)") 25 | } 26 | 27 | func TestParsePubComp(t *testing.T) { 28 | writeReadAndCompare(t, pkg.PubComp(23), "PUBCOMP (m23)") 29 | } 30 | 31 | func TestPublish_MarshalToJSON(t *testing.T) { 32 | p1 := pkg.NewPublish(23, "some/topic", 2, []byte(`the "message"`), false, "") 33 | bs, err := jsonstream.Marshal(p1) 34 | if err != nil { 35 | t.Fatal(err) 36 | } 37 | 38 | p2 := &pkg.Publish{} 39 | err = jsonstream.Unmarshal(p2, bs) 40 | if err != nil { 41 | t.Fatal(err) 42 | } 43 | if !p1.Equals(p2) { 44 | t.Fatal(p1, "!=", p2) 45 | } 46 | } 47 | 48 | func TestPublish_MarshalToJSON_nonUTF(t *testing.T) { 49 | p1 := pkg.NewPublish(23, "some/topic", 2, []byte{0, 1, 2, 3, 5}, false, "") 50 | bs, err := jsonstream.Marshal(p1) 51 | if err != nil { 52 | t.Fatal(err) 53 | } 54 | 55 | p2 := &pkg.Publish{} 56 | err = jsonstream.Unmarshal(p2, bs) 57 | if err != nil { 58 | t.Fatal(err) 59 | } 60 | if !p1.Equals(p2) { 61 | t.Fatal(p1, "!=", p2) 62 | } 63 | } 64 | -------------------------------------------------------------------------------- /mqtt/pkg/subscribe.go: -------------------------------------------------------------------------------- 1 | package pkg 2 | 3 | import ( 4 | "bytes" 5 | "errors" 6 | "strconv" 7 | 8 | "github.com/tada/mqtt-nats/mqtt" 9 | ) 10 | 11 | // Topic is an MQTT Topic subscription with name and desired quality of service 12 | type Topic struct { 13 | // Name is the Topic Name 14 | Name string 15 | 16 | // QoS Quality of Service, will be 0, 1, or 2.
17 | QoS byte 18 | } 19 | 20 | // Subscribe is the MQTT subscribe packet 21 | type Subscribe struct { 22 | id uint16 23 | topics []Topic 24 | } 25 | 26 | const fixedSubscribeFlags = 2 27 | 28 | // NewSubscribe creates a new MQTT subscribe packet 29 | func NewSubscribe(id uint16, topics ...Topic) *Subscribe { 30 | return &Subscribe{id: id, topics: topics} 31 | } 32 | 33 | // ParseSubscribe parses the subscribe packet from the given reader. 34 | func ParseSubscribe(r *mqtt.Reader, b byte, pkLen int) (Packet, error) { 35 | if (b & 0xf) != fixedSubscribeFlags { 36 | return nil, errors.New("malformed subscribe header") 37 | } 38 | 39 | var err error 40 | if r, err = r.ReadPacket(pkLen); err != nil { 41 | return nil, err 42 | } 43 | 44 | sp := &Subscribe{} 45 | if sp.id, err = r.ReadUint16(); err != nil { 46 | return nil, err 47 | } 48 | 49 | for r.Len() > 0 { // NOTE(review): a SUBSCRIBE without any topic filter is accepted here 50 | t := Topic{} 51 | if t.Name, err = r.ReadString(); err != nil { 52 | return nil, err 53 | } 54 | if t.QoS, err = r.ReadByte(); err != nil { 55 | return nil, err 56 | } 57 | if t.QoS > 2 { 58 | return nil, errors.New("malformed subscribed topic QoS") 59 | } 60 | sp.topics = append(sp.topics, t) 61 | } 62 | return sp, nil 63 | } 64 | 65 | // ID returns the MQTT Packet Identifier 66 | func (s *Subscribe) ID() uint16 { 67 | return s.id 68 | } 69 | 70 | // Equals returns true if this packet is equal to the given packet, false if not 71 | func (s *Subscribe) Equals(other interface{}) bool { 72 | if os, ok := other.(*Subscribe); ok && s.id == os.id && len(s.topics) == len(os.topics) { 73 | for i := range s.topics { 74 | if s.topics[i] != os.topics[i] { 75 | return false 76 | } 77 | } 78 | return true 79 | } 80 | return false 81 | } 82 | 83 | // String returns a brief string representation of the packet.
Suitable for logging 84 | func (s *Subscribe) String() string { 85 | bs := bytes.NewBufferString("SUBSCRIBE (m") 86 | bs.WriteString(strconv.Itoa(int(s.ID()))) 87 | bs.WriteString(", ") 88 | wt := func(t Topic) { 89 | bs.WriteByte('q') 90 | bs.WriteString(strconv.Itoa(int(t.QoS))) 91 | bs.WriteString(", '") 92 | bs.WriteString(t.Name) 93 | bs.WriteByte('\'') 94 | } 95 | if len(s.topics) != 1 { 96 | bs.WriteByte('[') 97 | for i, t := range s.topics { 98 | if i > 0 { 99 | bs.WriteString(", ") 100 | } 101 | bs.WriteByte('(') 102 | wt(t) 103 | bs.WriteByte(')') 104 | } 105 | bs.WriteByte(']') 106 | } else { 107 | wt(s.topics[0]) 108 | } 109 | bs.WriteByte(')') 110 | return bs.String() 111 | } 112 | 113 | // Topics returns the list of topics to subscribe to 114 | func (s *Subscribe) Topics() []Topic { 115 | return s.topics 116 | } 117 | 118 | // Write writes the MQTT bits of this packet on the given Writer 119 | func (s *Subscribe) Write(w *mqtt.Writer) { 120 | pkLen := 2 // id 121 | for i := range s.topics { 122 | pkLen += 3 + len(s.topics[i].Name) // 2-byte length prefix + name + 1 QoS byte 123 | } 124 | w.WriteU8(TpSubscribe | fixedSubscribeFlags) 125 | w.WriteVarInt(pkLen) 126 | w.WriteU16(s.id) 127 | for i := range s.topics { 128 | t := s.topics[i] 129 | w.WriteString(t.Name) 130 | w.WriteU8(t.QoS) 131 | } 132 | } 133 | 134 | // SubAck is the MQTT SUBACK packet sent in response to a SUBSCRIBE 135 | type SubAck struct { 136 | id uint16 137 | topicReturns []byte 138 | } 139 | 140 | // NewSubAck creates a SUBACK packet 141 | func NewSubAck(id uint16, topicReturns ...byte) *SubAck { 142 | return &SubAck{id: id, topicReturns: topicReturns} 143 | } 144 | 145 | // ParseSubAck parses a SUBACK packet 146 | func ParseSubAck(r *mqtt.Reader, _ byte, pkLen int) (Packet, error) { 147 | var err error 148 | if r, err = r.ReadPacket(pkLen); err != nil { 149 | return nil, err 150 | } 151 | s := &SubAck{} 152 | if s.id, err = r.ReadUint16(); err != nil { 153 | return nil, err 154 | } 155 | if 
s.topicReturns, err =
r.ReadExact(pkLen - 2); err != nil { 156 | return nil, err 157 | } 158 | return s, nil 159 | } 160 | 161 | // Equals returns true if this packet is equal to the given packet, false if not 162 | func (s *SubAck) Equals(other interface{}) bool { 163 | os, ok := other.(*SubAck) 164 | return ok && s.id == os.id && bytes.Equal(s.topicReturns, os.topicReturns) 165 | } 166 | 167 | // ID returns the packet ID 168 | func (s *SubAck) ID() uint16 { 169 | return s.id 170 | } 171 | 172 | // String returns a brief string representation of the packet. Suitable for logging 173 | func (s *SubAck) String() string { 174 | bs := bytes.NewBufferString("SUBACK (m") 175 | bs.WriteString(strconv.Itoa(int(s.ID()))) 176 | bs.WriteString(", ") 177 | if len(s.topicReturns) != 1 { 178 | bs.WriteByte('[') 179 | for i, t := range s.topicReturns { 180 | if i > 0 { 181 | bs.WriteString(", ") 182 | } 183 | bs.WriteString("rc") 184 | bs.WriteString(strconv.Itoa(int(t))) 185 | } 186 | bs.WriteByte(']') 187 | } else { 188 | bs.WriteString("rc") 189 | bs.WriteString(strconv.Itoa(int(s.topicReturns[0]))) 190 | } 191 | bs.WriteByte(')') 192 | return bs.String() 193 | } 194 | 195 | // TopicReturns returns the return code of each subscribed topic 196 | func (s *SubAck) TopicReturns() []byte { 197 | return s.topicReturns 198 | } 199 | 200 | // Write writes the MQTT bits of this packet on the given Writer 201 | func (s *SubAck) Write(w *mqtt.Writer) { 202 | w.WriteU8(TpSubAck) 203 | w.WriteVarInt(2 + len(s.topicReturns)) 204 | w.WriteU16(s.id) 205 | _, _ = w.Write(s.topicReturns) 206 | } 207 | -------------------------------------------------------------------------------- /mqtt/pkg/subscribe_test.go: -------------------------------------------------------------------------------- 1 | package pkg_test 2 | 3 | import ( 4 | "bytes" 5 | "testing" 6 | 7 | "github.com/tada/mqtt-nats/test/packet" 8 | 9 | "github.com/tada/mqtt-nats/mqtt/pkg" 10 | 11 | "github.com/tada/mqtt-nats/mqtt" 12 | 
"github.com/tada/mqtt-nats/test/utils" 13 | ) 14 | 15 | func TestParseSubscribe(t *testing.T) { 16 | writeReadAndCompare(t, pkg.NewSubscribe(23, pkg.Topic{Name: "some/topic", QoS: 1}), 17 | "SUBSCRIBE (m23, q1, 'some/topic')") 18 | writeReadAndCompare(t, pkg.NewSubscribe(23, pkg.Topic{Name: "some/topic", QoS: 0}, pkg.Topic{Name: "some/other"}), 19 | "SUBSCRIBE (m23, [(q0, 'some/topic'), (q0, 'some/other')])") 20 | } 21 | 22 | func TestParseSubscribe_badFlags(t *testing.T) { 23 | utils.EnsureFailed(t, func(st *testing.T) { 24 | packet.Parse(st, mqtt.NewReader(bytes.NewReader([]byte{pkg.TpSubscribe | 1, 0}))) 25 | }) 26 | } 27 | 28 | func TestParseSubscribe_badLen(t *testing.T) { 29 | _, err := pkg.ParseSubscribe(mqtt.NewReader(bytes.NewReader([]byte{})), pkg.TpSubscribe|2, 8) 30 | utils.CheckError(err, t) 31 | } 32 | 33 | func TestParseSubscribe_badId(t *testing.T) { 34 | w := mqtt.NewWriter() 35 | w.WriteU16(28) 36 | _, err := pkg.ParseSubscribe(mqtt.NewReader(bytes.NewReader(w.Bytes()[:1])), pkg.TpSubscribe|2, 1) 37 | utils.CheckError(err, t) 38 | } 39 | 40 | func TestParseSubscribe_badCount(t *testing.T) { 41 | w := mqtt.NewWriter() 42 | w.WriteU16(28) 43 | _, err := pkg.ParseSubscribe(mqtt.NewReader(bytes.NewReader(append(w.Bytes(), 3))), pkg.TpSubscribe|2, 3) 44 | utils.CheckError(err, t) 45 | } 46 | 47 | func TestParseSubscribe_badTopic(t *testing.T) { 48 | w := mqtt.NewWriter() 49 | w.WriteU16(28) 50 | w.WriteU16(1) 51 | _, err := pkg.ParseSubscribe(mqtt.NewReader(bytes.NewReader(w.Bytes())), pkg.TpSubscribe|2, 4) 52 | utils.CheckError(err, t) 53 | } 54 | 55 | func TestParseSubscribe_badQoS(t *testing.T) { 56 | w := mqtt.NewWriter() 57 | w.WriteU16(28) 58 | w.WriteString("tpc") 59 | _, err := pkg.ParseSubscribe(mqtt.NewReader(bytes.NewReader(w.Bytes())), pkg.TpSubscribe|2, 7) 60 | utils.CheckError(err, t) 61 | } 62 | 63 | func TestParseSubscribe_invalidQoS(t *testing.T) { 64 | w := mqtt.NewWriter() 65 | w.WriteU16(28) 66 | w.WriteString("tpc") 67 | w.WriteU8(3) 68 | _, err :=
pkg.ParseSubscribe(mqtt.NewReader(bytes.NewReader(w.Bytes())), pkg.TpSubscribe|2, 8) 69 | utils.CheckError(err, t) 70 | } 71 | 72 | func TestSubscribe_Equals(t *testing.T) { 73 | a := pkg.NewSubscribe(32, pkg.Topic{Name: "a"}, pkg.Topic{Name: "b", QoS: 1}) 74 | b := pkg.NewSubscribe(32, pkg.Topic{Name: "a"}, pkg.Topic{Name: "b", QoS: 1}) 75 | utils.CheckTrue(a.Equals(b), t) 76 | 77 | b = pkg.NewSubscribe(32, pkg.Topic{Name: "a"}, pkg.Topic{Name: "b", QoS: 2}) 78 | utils.CheckFalse(a.Equals(b), t) 79 | 80 | c := pkg.NewSubAck(32, 0, 2) 81 | utils.CheckFalse(a.Equals(c), t) 82 | } 83 | 84 | func TestParseSubAck(t *testing.T) { 85 | writeReadAndCompare(t, pkg.NewSubAck(23, 1), "SUBACK (m23, rc1)") 86 | writeReadAndCompare(t, pkg.NewSubAck(23, 1, 1), "SUBACK (m23, [rc1, rc1])") 87 | } 88 | 89 | func TestSubAck_Equals(t *testing.T) { 90 | a := pkg.NewSubAck(32, 1, 2, 0) 91 | b := pkg.NewSubAck(32, 1, 2, 0) 92 | utils.CheckTrue(a.Equals(b), t) 93 | 94 | b = pkg.NewSubAck(32, 1, 2, 1) 95 | utils.CheckFalse(a.Equals(b), t) 96 | 97 | c := pkg.NewSubscribe(32, pkg.Topic{Name: "a"}, pkg.Topic{Name: "b", QoS: 1}) 98 | utils.CheckFalse(a.Equals(c), t) 99 | } 100 | -------------------------------------------------------------------------------- /mqtt/pkg/unsubscribe.go: -------------------------------------------------------------------------------- 1 | package pkg 2 | 3 | import ( 4 | "bytes" 5 | "errors" 6 | "fmt" 7 | "strconv" 8 | 9 | "github.com/tada/mqtt-nats/mqtt" 10 | ) 11 | 12 | // Unsubscribe is the MQTT UNSUBSCRIBE packet 13 | type Unsubscribe struct { 14 | id uint16 15 | topics []string 16 | } 17 | 18 | // NewUnsubscribe creates a new Unsubscribe packet 19 | func NewUnsubscribe(id uint16, topics ...string) *Unsubscribe { 20 | return &Unsubscribe{id: id, topics: topics} 21 | } 22 | 23 | // ParseUnsubscribe parses the unsubscribe packet from the given reader.
24 | func ParseUnsubscribe(r *mqtt.Reader, b byte, pkLen int) (Packet, error) { 25 | if (b & 0xf) != 2 { 26 | return nil, errors.New("malformed unsubscribe header") 27 | } 28 | 29 | var err error 30 | if r, err = r.ReadPacket(pkLen); err != nil { 31 | return nil, err 32 | } 33 | 34 | up := &Unsubscribe{} 35 | if up.id, err = r.ReadUint16(); err != nil { 36 | return nil, err 37 | } 38 | 39 | for r.Len() > 0 { 40 | var name string 41 | if name, err = r.ReadString(); err != nil { 42 | return nil, err 43 | } 44 | up.topics = append(up.topics, name) 45 | } 46 | return up, nil 47 | } 48 | 49 | // Write writes the MQTT bits of this packet on the given Writer 50 | func (u *Unsubscribe) Write(w *mqtt.Writer) { 51 | pkLen := 2 // packet id 52 | tps := u.topics 53 | for i := range tps { 54 | pkLen += 2 + len(tps[i]) 55 | } 56 | w.WriteU8(TpUnsubscribe | 2) 57 | w.WriteVarInt(pkLen) 58 | w.WriteU16(u.id) 59 | for i := range tps { 60 | w.WriteString(tps[i]) 61 | } 62 | } 63 | 64 | // Equals returns true if this packet is equal to the given packet, false if not 65 | func (u *Unsubscribe) Equals(other interface{}) bool { 66 | if os, ok := other.(*Unsubscribe); ok && u.id == os.id && len(u.topics) == len(os.topics) { 67 | for i := range u.topics { 68 | if u.topics[i] != os.topics[i] { 69 | return false 70 | } 71 | } 72 | return true 73 | } 74 | return false 75 | } 76 | 77 | // ID returns the MQTT Packet Identifier 78 | func (u *Unsubscribe) ID() uint16 { 79 | return u.id 80 | } 81 | 82 | // String returns a brief string representation of the packet. 
Suitable for logging 83 | func (u *Unsubscribe) String() string { 84 | bs := bytes.NewBufferString("UNSUBSCRIBE (m") 85 | bs.WriteString(strconv.Itoa(int(u.ID()))) 86 | bs.WriteString(", [") 87 | for i, t := range u.topics { 88 | if i > 0 { 89 | bs.WriteString(", ") 90 | } 91 | bs.WriteByte('\'') 92 | bs.WriteString(t) 93 | bs.WriteByte('\'') 94 | } 95 | bs.WriteString("])") 96 | return bs.String() 97 | } 98 | 99 | // Topics returns the list of topics to subscribe to 100 | func (u *Unsubscribe) Topics() []string { 101 | return u.topics 102 | } 103 | 104 | // UnsubAck is the MQTT UNSUBACK packet 105 | type UnsubAck uint16 106 | 107 | // ParseUnsubAck parses the unsubscribe packet from the given reader. 108 | func ParseUnsubAck(r *mqtt.Reader, b byte, pkLen int) (Packet, error) { 109 | if pkLen != 2 { 110 | return UnsubAck(0), errors.New("malformed UNSUBACK") 111 | } 112 | id, err := r.ReadUint16() 113 | return UnsubAck(id), err 114 | } 115 | 116 | // ID returns the packet ID 117 | func (u UnsubAck) ID() uint16 { 118 | return uint16(u) 119 | } 120 | 121 | // Equals returns true if this packet is equal to the given packet, false if not 122 | func (u UnsubAck) Equals(other interface{}) bool { 123 | return u == other 124 | } 125 | 126 | // String returns a brief string representation of the packet. 
Suitable for logging 127 | func (u UnsubAck) String() string { 128 | return fmt.Sprintf("UNSUBACK (m%d)", u.ID()) 129 | } 130 | 131 | // Write writes the MQTT bits of this packet on the given Writer 132 | func (u UnsubAck) Write(w *mqtt.Writer) { 133 | w.WriteU8(TpUnsubAck) 134 | w.WriteU8(2) 135 | w.WriteU16(uint16(u)) 136 | } 137 | -------------------------------------------------------------------------------- /mqtt/pkg/unsubscribe_test.go: -------------------------------------------------------------------------------- 1 | package pkg_test 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/tada/mqtt-nats/mqtt/pkg" 7 | ) 8 | 9 | func TestParseUnsubscribe(t *testing.T) { 10 | writeReadAndCompare(t, pkg.NewUnsubscribe(23, "some/topic", "some/other"), "UNSUBSCRIBE (m23, ['some/topic', 'some/other'])") 11 | } 12 | 13 | func TestParseUnsubAck(t *testing.T) { 14 | writeReadAndCompare(t, pkg.UnsubAck(23), "UNSUBACK (m23)") 15 | } 16 | -------------------------------------------------------------------------------- /mqtt/pkg/will.go: -------------------------------------------------------------------------------- 1 | package pkg 2 | 3 | import ( 4 | "bytes" 5 | "fmt" 6 | ) 7 | 8 | // Will is the optional client will in the MQTT connect packet 9 | type Will struct { 10 | Topic string 11 | Message []byte 12 | QoS byte 13 | Retain bool 14 | } 15 | 16 | // Equals returns true if this instance is equal to the given instance, false if not 17 | func (w *Will) Equals(ow *Will) bool { 18 | return w.Retain == ow.Retain && w.QoS == ow.QoS && w.Topic == ow.Topic && bytes.Equal(w.Message, ow.Message) 19 | } 20 | 21 | // String returns a brief string representation of the will. Suitable for logging 22 | func (w *Will) String() string { 23 | r := 0 24 | if w.Retain { 25 | r = 1 26 | } 27 | return fmt.Sprintf("w(r%d, q%d, '%s', ... 
(%d bytes))", r, w.QoS, w.Topic, len(w.Message)) 28 | } 29 | -------------------------------------------------------------------------------- /mqtt/reader.go: -------------------------------------------------------------------------------- 1 | package mqtt 2 | 3 | import ( 4 | "bytes" 5 | "encoding/binary" 6 | "errors" 7 | "fmt" 8 | "io" 9 | ) 10 | 11 | // Reader extends the io.Reader with MQTT specific semantics for reading variable length integers, 12 | // two byte unsigned integers, and length prefixed strings and bytes. 13 | type Reader struct { 14 | io.Reader 15 | } 16 | 17 | // NewReader creates a new Reader that reads from the given io.Reader 18 | func NewReader(r io.Reader) *Reader { 19 | return &Reader{r} 20 | } 21 | 22 | // ReadByte reads and returns the next byte from the input or 23 | // any error encountered. If ReadByte returns an error, no input 24 | // byte was consumed, and the returned byte value is undefined. 25 | func (r *Reader) ReadByte() (byte, error) { 26 | b := []byte{0} 27 | n, err := r.Read(b) 28 | if n == 1 { 29 | return b[0], nil 30 | } 31 | return 0, err 32 | } 33 | 34 | // ReadVarInt returns the next variable size unsigned integer from the input stream. 35 | // A "malformed compressed int" error is returned if the value is larger than 0x0FFFFFFF. 36 | // 37 | // An io.ErrUnexpectedEOF is returned if EOF is encountered during the read. 
38 | func (r *Reader) ReadVarInt() (int, error) { 39 | m := 1 40 | v := 0 41 | for { 42 | b, err := r.ReadByte() 43 | if err != nil { 44 | if err == io.EOF { 45 | err = io.ErrUnexpectedEOF 46 | } 47 | return 0, err 48 | } 49 | v += int(b&0x7f) * m 50 | if (b & 0x80) == 0 { 51 | return v, nil 52 | } 53 | m *= 0x80 54 | if m > 0x200000 { 55 | return 0, errors.New("malformed compressed int") 56 | } 57 | } 58 | } 59 | 60 | // ReadUint16 reads the next two bytes from the input stream and returns a big endian unsigned integer 61 | // 62 | // An io.ErrUnexpectedEOF is returned if EOF is encountered during the read. 63 | func (r *Reader) ReadUint16() (uint16, error) { 64 | var v uint16 65 | bs, err := r.ReadExact(2) 66 | if err == nil { 67 | v = binary.BigEndian.Uint16(bs) 68 | } 69 | return v, err 70 | } 71 | 72 | // ReadString will reads a big endian uint16 from the stream that denotes the number of bytes 73 | // that will follow. It then reads those bytes and returns them as a UTF8 encoded string. 74 | // 75 | // An io.ErrUnexpectedEOF is returned if EOF is encountered during the read. 76 | func (r *Reader) ReadString() (string, error) { 77 | var s string 78 | bs, err := r.ReadBytes() 79 | if err == nil { 80 | s = string(bs) 81 | } 82 | return s, err 83 | } 84 | 85 | // ReadBytes will reads a big endian uint16 from the stream that denotes the number of bytes 86 | // that will follow. It then reads those bytes and returns them. 87 | // 88 | // An io.ErrUnexpectedEOF is returned if EOF is encountered during the read. 89 | func (r *Reader) ReadBytes() ([]byte, error) { 90 | var bs []byte 91 | l, err := r.ReadUint16() 92 | if l > 0 && err == nil { 93 | bs, err = r.ReadExact(int(l)) 94 | } 95 | return bs, err 96 | } 97 | 98 | // ReadExact reads an exact number of bytes into a []byte slice and returns it. 99 | // 100 | // An io.ErrUnexpectedEOF is returned if EOF is encountered during the read. 
101 | func (r *Reader) ReadExact(n int) ([]byte, error) { 102 | bs := make([]byte, n) 103 | _, err := io.ReadFull(r, bs) 104 | if err != nil { 105 | bs = nil 106 | if err == io.EOF { 107 | err = io.ErrUnexpectedEOF 108 | } 109 | } 110 | return bs, err 111 | } 112 | 113 | // Len returns the number of bytes of the unread portion of the slice. This method will panic 114 | // unless the underlying reader is a bytes.Reader. 115 | func (r *Reader) Len() int { 116 | if br, ok := r.Reader.(*bytes.Reader); ok { 117 | return br.Len() 118 | } 119 | 120 | // Reader was not set up to read remaining length 121 | panic(fmt.Errorf("unsupported operation on %T: Len", r.Reader)) 122 | } 123 | 124 | // ReadRemainingBytes returns the remaining bytes of the underlying reader. This method will panic 125 | // unless the underlying reader is a bytes.Reader. 126 | func (r *Reader) ReadRemainingBytes() ([]byte, error) { 127 | return r.ReadExact(r.Len()) 128 | } 129 | 130 | // ReadPacket reads exactly pkLen bytes from the underlying reader into a byte slice and creates a new Reader 131 | // that will read from this slice. The new Reader is returned. 132 | // 133 | // An io.ErrUnexpectedEOF is returned if EOF is encountered during the read. 
134 | func (r *Reader) ReadPacket(pkLen int) (*Reader, error) { 135 | // Do a bulk read and switch to read from that bulk 136 | var rdr *Reader 137 | pk, err := r.ReadExact(pkLen) 138 | if err == nil { 139 | rdr = &Reader{bytes.NewReader(pk)} 140 | } 141 | return rdr, err 142 | } 143 | -------------------------------------------------------------------------------- /mqtt/reader_test.go: -------------------------------------------------------------------------------- 1 | package mqtt 2 | 3 | import ( 4 | "bytes" 5 | "io" 6 | "testing" 7 | ) 8 | 9 | type noLen int 10 | 11 | func (noLen) Read([]byte) (int, error) { 12 | return 0, io.EOF 13 | } 14 | 15 | func TestReader_Len(t *testing.T) { 16 | l := NewReader(bytes.NewReader([]byte{1, 2, 3})).Len() 17 | if l != 3 { 18 | t.Fatalf("expected length 3 got %d", l) 19 | } 20 | defer func() { 21 | r := recover() 22 | if r == nil { 23 | t.Fatal("expected panic") 24 | } 25 | }() 26 | NewReader(noLen(0)).Len() 27 | } 28 | 29 | func TestReader_ReadBytes(t *testing.T) { 30 | r := NewReader(bytes.NewReader([]byte{0, 2, 'a', 'b'})) 31 | bs, err := r.ReadBytes() 32 | if err != nil { 33 | t.Fatal(err) 34 | } 35 | sbs := string(bs) 36 | if sbs != "ab" { 37 | t.Fatalf(`expected "ab", got %q`, sbs) 38 | } 39 | 40 | // test premature EOF 41 | r = NewReader(bytes.NewReader([]byte{0, 3, 'a', 'b'})) 42 | _, err = r.ReadBytes() 43 | if err != io.ErrUnexpectedEOF { 44 | t.Fatal("expected error") 45 | } 46 | 47 | // test premature EOF nothing after length 48 | r = NewReader(bytes.NewReader([]byte{0, 3})) 49 | _, err = r.ReadBytes() 50 | if err != io.ErrUnexpectedEOF { 51 | t.Fatal("expected error") 52 | } 53 | } 54 | 55 | func TestReader_ReadVarInt(t *testing.T) { 56 | r := NewReader(bytes.NewReader([]byte{0x82, 0xff, 0x3})) 57 | l, err := r.ReadVarInt() 58 | if err != nil { 59 | t.Fatal(err) 60 | } 61 | if l != 0xff82 { 62 | t.Fatalf("expected length 0xff82 got 0x%x", l) 63 | } 64 | _, err = NewReader(bytes.NewReader([]byte{0xff, 0xff, 0xff, 
0xff, 0xff})).ReadVarInt() 65 | if err == nil { 66 | t.Fatal("expected error") 67 | } 68 | _, err = NewReader(bytes.NewReader([]byte{0x80})).ReadVarInt() 69 | if err != io.ErrUnexpectedEOF { 70 | t.Fatal("expected error") 71 | } 72 | } 73 | -------------------------------------------------------------------------------- /mqtt/topic.go: -------------------------------------------------------------------------------- 1 | package mqtt 2 | 3 | import ( 4 | "io" 5 | "regexp" 6 | "strings" 7 | ) 8 | 9 | const ( 10 | dot = rune('.') 11 | slash = rune('/') 12 | star = rune('*') 13 | plus = rune('+') 14 | hash = rune('#') 15 | gt = rune('>') 16 | matchSegment = `^[/]+` 17 | matchRest = `.*` 18 | ) 19 | 20 | // SubscriptionToRegexp converts an MQTT topic subscription into a regular expression that can be 21 | // used to match topics. 22 | func SubscriptionToRegexp(s string) *regexp.Regexp { 23 | w := strings.Builder{} 24 | for i, p := range strings.Split(s, "/") { 25 | if i > 0 { 26 | _ = w.WriteByte('/') 27 | } 28 | if len(p) == 1 { 29 | switch rune(p[0]) { 30 | case plus: 31 | _, _ = w.WriteString(matchSegment) 32 | continue 33 | case hash: 34 | _, _ = w.WriteString(matchRest) 35 | continue 36 | } 37 | } 38 | _, _ = w.WriteString(regexp.QuoteMeta(p)) 39 | } 40 | return regexp.MustCompile(w.String()) 41 | } 42 | 43 | // ToNATS converts an MQTT topic to a NATS subject. 
The following conversions take place 44 | // 45 | // dots become slashes 46 | // slashes become dots 47 | func ToNATS(mqttTopic string) string { 48 | r := strings.NewReader(mqttTopic) 49 | w := strings.Builder{} 50 | for { 51 | c, _, err := r.ReadRune() 52 | if err == io.EOF { 53 | return w.String() 54 | } 55 | switch c { 56 | case dot: 57 | c = slash 58 | case slash: 59 | c = dot 60 | } 61 | _, _ = w.WriteRune(c) 62 | } 63 | } 64 | 65 | // ToNATSSubscription converts the given MQTT subscription into a NATS subscription 66 | func ToNATSSubscription(mqttSub string) string { 67 | r := strings.NewReader(mqttSub) 68 | w := strings.Builder{} 69 | for { 70 | c, _, err := r.ReadRune() 71 | if err == io.EOF { 72 | return w.String() 73 | } 74 | switch c { 75 | case dot: 76 | c = slash 77 | case slash: 78 | c = dot 79 | case star: 80 | c = plus 81 | case plus: 82 | c = star 83 | case hash: 84 | c = gt 85 | case gt: 86 | c = hash 87 | } 88 | _, _ = w.WriteRune(c) 89 | } 90 | } 91 | 92 | // FromNATS converts an MATS subject to a MQTT topic. 
The following conversions take place 93 | // 94 | // dots become slashes 95 | // slashes become dots 96 | func FromNATS(natsSubject string) string { 97 | // exact same conversion but opposite direction, at least for now 98 | return ToNATS(natsSubject) 99 | } 100 | 101 | // FromNATSSubscription converts the given NATS subscription into a MQTT subscription 102 | func FromNATSSubscription(natsSubject string) string { 103 | // exact same conversion but opposite direction, at least for now 104 | return ToNATSSubscription(natsSubject) 105 | } 106 | -------------------------------------------------------------------------------- /mqtt/topic_test.go: -------------------------------------------------------------------------------- 1 | package mqtt 2 | 3 | import ( 4 | "testing" 5 | ) 6 | 7 | func TestFromNATS(t *testing.T) { 8 | tests := []struct { 9 | name string 10 | nats string 11 | want string 12 | }{{ 13 | name: "Slash is dot", 14 | nats: "a/b/c", 15 | want: "a.b.c", 16 | }, { 17 | name: "Dot is slash", 18 | nats: "a.b.c", 19 | want: "a/b/c", 20 | }, 21 | } 22 | for i := range tests { 23 | tt := tests[i] 24 | t.Run(tt.name, func(t *testing.T) { 25 | if got := FromNATS(tt.nats); got != tt.want { 26 | t.Errorf("FromNATS() = %v, want %v", got, tt.want) 27 | } 28 | }) 29 | } 30 | } 31 | 32 | func TestFromNATSSubscription(t *testing.T) { 33 | tests := []struct { 34 | name string 35 | nats string 36 | want string 37 | }{{ 38 | name: "Slash is dot", 39 | nats: "a/b/c", 40 | want: "a.b.c", 41 | }, { 42 | name: "Dot is slash", 43 | nats: "a.b.c", 44 | want: "a/b/c", 45 | }, { 46 | name: "Star is plus", 47 | nats: "a.*.b", 48 | want: "a/+/b", 49 | }, { 50 | name: "Plus is star", 51 | nats: "a.+.b", 52 | want: "a/*/b", 53 | }, { 54 | name: "> is #", 55 | nats: "a.b.>", 56 | want: "a/b/#", 57 | }, { 58 | name: "# is >", 59 | nats: "a.b.#", 60 | want: "a/b/>", 61 | }} 62 | for i := range tests { 63 | tt := tests[i] 64 | t.Run(tt.name, func(t *testing.T) { 65 | if got := 
FromNATSSubscription(tt.nats); got != tt.want { 66 | t.Errorf("FromNATSSubscription() = %v, want %v", got, tt.want) 67 | } 68 | }) 69 | } 70 | } 71 | -------------------------------------------------------------------------------- /mqtt/writer.go: -------------------------------------------------------------------------------- 1 | package mqtt 2 | 3 | import "bytes" 4 | 5 | // Writer extends a bytes.Buffer with MQTT specific semantics for writing variable length integers, 6 | // two byte unsigned integers, and length prefixed strings and bytes. 7 | type Writer struct { 8 | bytes.Buffer 9 | } 10 | 11 | // NewWriter returns a new Writer instance 12 | func NewWriter() *Writer { 13 | return &Writer{} 14 | } 15 | 16 | // WriteU8 writes a byte on the underlying buffer. This is the same as calling WriteByte but there 17 | // is no error return as opposed to WriteByte which returns the error type although it is always nil. 18 | func (w *Writer) WriteU8(i uint8) { 19 | _ = w.WriteByte(i) 20 | } 21 | 22 | // WriteU16 writes the big endian two bytes of the given uint16 on the underlying buffer 23 | func (w *Writer) WriteU16(i uint16) { 24 | w.WriteU8(byte(i >> 8)) 25 | w.WriteU8(byte(i)) 26 | } 27 | 28 | // WriteString first writes the length of the string using WriteU16 and then the strings bytes. 29 | func (w *Writer) WriteString(s string) { 30 | t := len(s) 31 | w.WriteU16(uint16(t)) 32 | for i := 0; i < t; i++ { 33 | w.WriteU8(s[i]) 34 | } 35 | } 36 | 37 | // WriteBytes first writes the length of the byte slice using WriteU16 and then the bytes. 38 | func (w *Writer) WriteBytes(bs []byte) { 39 | t := len(bs) 40 | w.WriteU16(uint16(t)) 41 | _, _ = w.Write(bs) 42 | } 43 | 44 | // WriteVarInt writes a variable length integer in accordance with the MQTT 3.1.1 specification on the 45 | // underlying buffer. 
46 | func (w *Writer) WriteVarInt(value int) { 47 | for { 48 | b := byte(value & 0x7f) 49 | value >>= 7 50 | if value > 0 { 51 | b |= 0x80 52 | } 53 | w.WriteU8(b) 54 | if value == 0 { 55 | break 56 | } 57 | } 58 | } 59 | -------------------------------------------------------------------------------- /mqtt/writer_test.go: -------------------------------------------------------------------------------- 1 | package mqtt 2 | 3 | import ( 4 | "bytes" 5 | "testing" 6 | ) 7 | 8 | func TestWriteVarInt(t *testing.T) { 9 | ints := []int{ 10 | 0, 1, 127, 128, 16383, 16384, 2097151, 2097152, 268435455, 11 | } 12 | lens := []int{ 13 | 1, 1, 1, 2, 2, 3, 3, 4, 4, 14 | } 15 | w := NewWriter() 16 | tl := 0 17 | for i, v := range ints { 18 | w.WriteVarInt(v) 19 | tl += lens[i] 20 | if tl != w.Len() { 21 | t.Fatalf("expected len %d, got %d", tl, w.Len()) 22 | } 23 | } 24 | 25 | r := Reader{bytes.NewReader(w.Bytes())} 26 | for _, v := range ints { 27 | x, _ := r.ReadVarInt() 28 | if v != x { 29 | t.Fatalf("expected %d, got %d", v, x) 30 | } 31 | } 32 | } 33 | -------------------------------------------------------------------------------- /test/connect_test.go: -------------------------------------------------------------------------------- 1 | package test 2 | 3 | import ( 4 | "bytes" 5 | "testing" 6 | "time" 7 | 8 | "github.com/tada/mqtt-nats/mqtt/pkg" 9 | "github.com/tada/mqtt-nats/test/full" 10 | ) 11 | 12 | var packetIDManager = pkg.NewIDManager() 13 | 14 | func nextPacketID() uint16 { 15 | return packetIDManager.NextFreePacketID() 16 | } 17 | 18 | func TestConnect(t *testing.T) { 19 | conn := full.MqttConnect(t, mqttPort) 20 | full.MqttSend(t, conn, pkg.NewConnect(full.NextClientID(), true, 1, nil, nil)) 21 | full.MqttExpect(t, conn, pkg.NewConnAck(false, 0)) 22 | full.MqttDisconnect(t, conn) 23 | } 24 | 25 | func TestConnect_sessionPresent(t *testing.T) { 26 | conn := full.MqttConnect(t, mqttPort) 27 | c := pkg.NewConnect(full.NextClientID(), false, 1, nil, nil) 28 | 
full.MqttSend(t, conn, c) 29 | full.MqttExpect(t, conn, pkg.NewConnAck(false, 0)) 30 | full.MqttDisconnect(t, conn) 31 | 32 | conn = full.MqttConnect(t, mqttPort) 33 | full.MqttSend(t, conn, c) 34 | full.MqttExpect(t, conn, pkg.NewConnAck(true, 0)) 35 | full.MqttDisconnect(t, conn) 36 | } 37 | 38 | func TestConnect_will_qos_0(t *testing.T) { 39 | conn1 := full.MqttConnect(t, mqttPort) 40 | full.MqttSend(t, conn1, pkg.NewConnect(full.NextClientID(), true, 1, nil, nil)) 41 | full.MqttExpect(t, conn1, pkg.NewConnAck(false, 0)) 42 | mid := nextPacketID() 43 | full.MqttSend(t, conn1, pkg.NewSubscribe(mid, pkg.Topic{Name: "testing/my/will"})) 44 | full.MqttExpect(t, conn1, pkg.NewSubAck(mid, 0)) 45 | 46 | conn2 := full.MqttConnect(t, mqttPort) 47 | full.MqttSend(t, conn2, 48 | pkg.NewConnect(full.NextClientID(), true, 1, 49 | &pkg.Will{ 50 | Topic: "testing/my/will", 51 | Message: []byte("the will message")}, nil)) 52 | full.MqttExpect(t, conn2, pkg.NewConnAck(false, 0)) 53 | // forcefully close connection 54 | _ = conn2.Close() 55 | 56 | full.MqttExpect(t, conn1, 57 | func(p pkg.Packet) bool { 58 | pp, ok := p.(*pkg.Publish) 59 | return ok && pp.TopicName() == "testing/my/will" && pp.QoSLevel() == 0 && !pp.IsDup() && !pp.Retain() 60 | }) 61 | full.MqttDisconnect(t, conn1) 62 | } 63 | 64 | func TestConnect_will_qos_1(t *testing.T) { 65 | conn := full.MqttConnect(t, mqttPort) 66 | full.MqttSend(t, conn, 67 | pkg.NewConnect(full.NextClientID(), true, 1, &pkg.Will{ 68 | Topic: "testing/my/will", 69 | Message: []byte("the will message"), 70 | QoS: 1}, nil)) 71 | full.MqttExpect(t, conn, pkg.NewConnAck(false, 0)) 72 | // forcefully close connection 73 | _ = conn.Close() 74 | 75 | // Ensure that first package is wasted and a dup is published 76 | time.Sleep(10 * time.Millisecond) 77 | 78 | conn = full.MqttConnectClean(t, mqttPort) 79 | mid := nextPacketID() 80 | full.MqttSend(t, conn, pkg.NewSubscribe(mid, pkg.Topic{Name: "testing/my/will", QoS: 1})) 81 | full.MqttExpect(t, 
conn, 82 | pkg.NewSubAck(mid, 1), 83 | func(p pkg.Packet) bool { 84 | if pp, ok := p.(*pkg.Publish); ok { 85 | full.MqttSend(t, conn, pkg.PubAck(pp.ID())) 86 | return pp.TopicName() == "testing/my/will" && pp.QoSLevel() == 1 && pp.IsDup() && !pp.Retain() 87 | } 88 | return false 89 | }) 90 | full.MqttDisconnect(t, conn) 91 | } 92 | 93 | func TestConnect_will_retain_qos_0(t *testing.T) { 94 | willTopic := "testing/my/will" 95 | c1 := full.MqttConnect(t, mqttPort) 96 | full.MqttSend(t, c1, 97 | pkg.NewConnect(full.NextClientID(), true, 5, &pkg.Will{ 98 | Topic: willTopic, 99 | Message: []byte("the will message"), 100 | Retain: true}, nil)) 101 | full.MqttExpect(t, c1, pkg.NewConnAck(false, 0)) 102 | 103 | // forcefully close connection to make server publish will 104 | _ = c1.Close() 105 | time.Sleep(10 * time.Millisecond) // give bridge time to handle retained 106 | 107 | gotIt := make(chan bool, 1) 108 | go func() { 109 | c2 := full.MqttConnectClean(t, mqttPort) 110 | mid := nextPacketID() 111 | full.MqttSend(t, c2, pkg.NewSubscribe(mid, pkg.Topic{Name: willTopic})) 112 | full.MqttExpect(t, c2, pkg.NewSubAck(mid, 0)) 113 | full.MqttExpect(t, c2, func(p pkg.Packet) bool { 114 | pp, ok := p.(*pkg.Publish) 115 | return ok && pp.TopicName() == willTopic && pp.QoSLevel() == 0 && !pp.IsDup() && pp.Retain() 116 | }) 117 | gotIt <- true 118 | full.MqttDisconnect(t, c2) 119 | }() 120 | 121 | full.AssertMessageReceived(t, gotIt) 122 | 123 | // check that retained will still exists 124 | c1 = full.MqttConnectClean(t, mqttPort) 125 | mid := nextPacketID() 126 | full.MqttSend(t, c1, pkg.NewSubscribe(mid, pkg.Topic{Name: willTopic})) 127 | full.MqttExpect(t, c1, 128 | pkg.NewSubAck(mid, 0), 129 | func(p pkg.Packet) bool { 130 | pp, ok := p.(*pkg.Publish) 131 | return ok && pp.TopicName() == willTopic && pp.QoSLevel() == 0 && !pp.IsDup() && pp.Retain() 132 | }) 133 | 134 | // drop the retained packet 135 | full.MqttSend(t, c1, pkg.NewPublish2(0, willTopic, []byte{}, 0, false, 
true)) 136 | full.MqttDisconnect(t, c1) 137 | } 138 | 139 | func TestConnect_will_retain_qos_1(t *testing.T) { 140 | conn := full.MqttConnect(t, mqttPort) 141 | willTopic := "testing/my/will" 142 | willPayload := []byte("the will message") 143 | full.MqttSend(t, conn, 144 | pkg.NewConnect(full.NextClientID(), true, 5, &pkg.Will{ 145 | Topic: willTopic, 146 | Message: willPayload, 147 | QoS: 1, 148 | Retain: true}, nil)) 149 | full.MqttExpect(t, conn, pkg.NewConnAck(false, 0)) 150 | // forcefully close connection 151 | _ = conn.Close() 152 | 153 | conn = full.MqttConnectClean(t, mqttPort) 154 | mid := nextPacketID() 155 | full.MqttSend(t, conn, pkg.NewSubscribe(mid, pkg.Topic{Name: willTopic, QoS: 1})) 156 | var ackID uint16 157 | full.MqttExpect(t, conn, 158 | pkg.NewSubAck(mid, 1), 159 | func(p pkg.Packet) bool { 160 | if pp, ok := p.(*pkg.Publish); ok { 161 | ackID = pp.ID() 162 | return pp.TopicName() == willTopic && bytes.Equal(pp.Payload(), willPayload) && pp.QoSLevel() == 1 && !pp.IsDup() && pp.Retain() 163 | } 164 | return false 165 | }) 166 | full.MqttSend(t, conn, pkg.PubAck(ackID)) 167 | full.MqttDisconnect(t, conn) 168 | 169 | conn = full.MqttConnectClean(t, mqttPort) 170 | mid = nextPacketID() 171 | full.MqttSend(t, conn, pkg.NewSubscribe(mid, pkg.Topic{Name: willTopic, QoS: 1})) 172 | full.MqttExpect(t, conn, 173 | pkg.NewSubAck(mid, 1), 174 | func(p pkg.Packet) bool { 175 | if pp, ok := p.(*pkg.Publish); ok { 176 | ackID = pp.ID() 177 | return pp.TopicName() == willTopic && bytes.Equal(pp.Payload(), willPayload) && pp.QoSLevel() == 1 && !pp.IsDup() && pp.Retain() 178 | } 179 | return false 180 | }) 181 | full.MqttSend(t, conn, pkg.PubAck(ackID)) 182 | 183 | // drop the retained packet 184 | full.MqttSend(t, conn, pkg.NewPublish2(0, willTopic, []byte{}, 1, false, true)) 185 | full.MqttDisconnect(t, conn) 186 | } 187 | 188 | func TestConnect_will_retain_qos_1_restart(t *testing.T) { 189 | conn := full.MqttConnect(t, mqttPort) 190 | willTopic := 
"testing/my/will" 191 | willPayload := []byte("the will message") 192 | full.MqttSend(t, conn, 193 | pkg.NewConnect(full.NextClientID(), true, 5, &pkg.Will{ 194 | Topic: willTopic, 195 | Message: willPayload, 196 | QoS: 1, 197 | Retain: true}, nil)) 198 | full.MqttExpect(t, conn, pkg.NewConnAck(false, 0)) 199 | // forcefully close connection 200 | _ = conn.Close() 201 | time.Sleep(50 * time.Millisecond) // give bridge time to publish will 202 | 203 | conn = full.MqttConnectClean(t, mqttPort) 204 | mid := nextPacketID() 205 | full.MqttSend(t, conn, pkg.NewSubscribe(mid, pkg.Topic{Name: willTopic, QoS: 1})) 206 | full.MqttExpect(t, conn, pkg.NewSubAck(mid, 1)) 207 | 208 | full.MqttExpect(t, conn, func(p pkg.Packet) bool { 209 | if pp, ok := p.(*pkg.Publish); ok { 210 | full.MqttSend(t, conn, pkg.PubAck(pp.ID())) 211 | return pp.TopicName() == willTopic && bytes.Equal(pp.Payload(), willPayload) && pp.QoSLevel() == 1 && !pp.IsDup() && pp.Retain() 212 | } 213 | return false 214 | }) 215 | full.MqttDisconnect(t, conn) 216 | 217 | full.RestartBridge(t, mqttServer) 218 | 219 | conn = full.MqttConnectClean(t, mqttPort) 220 | mid = nextPacketID() 221 | full.MqttSend(t, conn, pkg.NewSubscribe(mid, pkg.Topic{Name: willTopic, QoS: 1})) 222 | full.MqttExpect(t, conn, pkg.NewSubAck(mid, 1)) 223 | full.MqttExpect(t, conn, 224 | func(p pkg.Packet) bool { 225 | if pp, ok := p.(*pkg.Publish); ok { 226 | full.MqttSend(t, conn, pkg.PubAck(pp.ID())) 227 | return pp.TopicName() == willTopic && bytes.Equal(pp.Payload(), willPayload) && pp.QoSLevel() == 1 && !pp.IsDup() && pp.Retain() 228 | } 229 | return false 230 | }) 231 | 232 | // drop the retained packet 233 | full.MqttSend(t, conn, pkg.NewPublish2(0, willTopic, []byte{}, 1, false, true)) 234 | full.MqttDisconnect(t, conn) 235 | } 236 | 237 | func TestConnect_will_qos_1_restart(t *testing.T) { 238 | conn := full.MqttConnect(t, mqttPort) 239 | willTopic := "testing/my/will" 240 | willPayload := []byte("the will message") 241 | 
full.MqttSend(t, conn, 242 | pkg.NewConnect(full.NextClientID(), true, 5, 243 | &pkg.Will{ 244 | Topic: willTopic, 245 | Message: willPayload, 246 | QoS: 1}, 247 | &pkg.Credentials{ 248 | User: "bob", 249 | Password: []byte("password")})) 250 | full.MqttExpect(t, conn, pkg.NewConnAck(false, 0)) 251 | // forcefully close connection 252 | _ = conn.Close() 253 | time.Sleep(50 * time.Millisecond) // give bridge time to publish will 254 | full.RestartBridge(t, mqttServer) 255 | 256 | conn = full.MqttConnectClean(t, mqttPort) 257 | mid := nextPacketID() 258 | full.MqttSend(t, conn, pkg.NewSubscribe(mid, pkg.Topic{Name: willTopic, QoS: 1})) 259 | var ackID uint16 260 | full.MqttExpect(t, conn, 261 | pkg.NewSubAck(mid, 1), 262 | func(p pkg.Packet) bool { 263 | if pp, ok := p.(*pkg.Publish); ok { 264 | ackID = pp.ID() 265 | return pp.TopicName() == willTopic && bytes.Equal(pp.Payload(), willPayload) && pp.QoSLevel() == 1 && pp.IsDup() && !pp.Retain() 266 | } 267 | return false 268 | }) 269 | full.MqttSend(t, conn, pkg.PubAck(ackID)) 270 | 271 | // drop the retained packet 272 | full.MqttSend(t, conn, pkg.NewPublish2(0, willTopic, []byte{}, 1, false, true)) 273 | full.MqttDisconnect(t, conn) 274 | } 275 | 276 | func TestPing(t *testing.T) { 277 | conn := full.MqttConnectClean(t, mqttPort) 278 | full.MqttSend(t, conn, pkg.PingRequestSingleton) 279 | full.MqttExpect(t, conn, pkg.PingResponseSingleton) 280 | full.MqttDisconnect(t, conn) 281 | } 282 | 283 | func TestPing_beforeConnect(t *testing.T) { 284 | conn := full.MqttConnect(t, mqttPort) 285 | full.MqttSend(t, conn, pkg.PingRequestSingleton) 286 | full.MqttExpectConnReset(t, conn) 287 | } 288 | 289 | func TestConnect_badProtocolVersion(t *testing.T) { 290 | conn := full.MqttConnect(t, mqttPort) 291 | cp := pkg.NewConnect(full.NextClientID(), true, 1, nil, nil) 292 | cp.SetClientLevel(3) 293 | full.MqttSend(t, conn, cp) 294 | full.MqttExpect(t, conn, pkg.NewConnAck(false, pkg.RtUnacceptableProtocolVersion)) 295 | 
full.MqttDisconnect(t, conn) 296 | } 297 | 298 | func TestConnect_multiple(t *testing.T) { 299 | conn := full.MqttConnectClean(t, mqttPort) 300 | full.MqttSend(t, conn, pkg.NewConnect(full.NextClientID(), true, 1, nil, nil)) 301 | full.MqttExpectConnReset(t, conn) 302 | } 303 | 304 | func TestBadPacketLength(t *testing.T) { 305 | conn := full.MqttConnectClean(t, mqttPort) 306 | _, err := conn.Write([]byte{0x01, 0xff, 0xff, 0xff, 0xff}) 307 | if err != nil { 308 | t.Fatal(err) 309 | } 310 | full.MqttExpectConnReset(t, conn) 311 | } 312 | 313 | func TestBadPacketType(t *testing.T) { 314 | conn := full.MqttConnectClean(t, mqttPort) 315 | _, err := conn.Write([]byte{0xff, 0x0}) 316 | if err != nil { 317 | t.Fatal(err) 318 | } 319 | full.MqttExpectConnReset(t, conn) 320 | } 321 | -------------------------------------------------------------------------------- /test/full/bridge.go: -------------------------------------------------------------------------------- 1 | // +build citest 2 | 3 | // Package full contains the test utilities that enables full roundtrip testing with both an mqtt-bridge and 4 | // a NATS test server. 5 | package full 6 | 7 | import ( 8 | "sync" 9 | "testing" 10 | "time" 11 | 12 | "github.com/nats-io/nats-server/v2/server" 13 | testserver "github.com/nats-io/nats-server/v2/test" 14 | "github.com/tada/mqtt-nats/bridge" 15 | "github.com/tada/mqtt-nats/logger" 16 | ) 17 | 18 | // RunBridge starts an in-process mqtt-nats bridge configured with the given options. 
19 | func RunBridge(lg logger.Logger, opts *bridge.Options) (bridge.Bridge, error) { 20 | srv, err := bridge.New(opts, lg) 21 | if err != nil { 22 | return nil, err 23 | } 24 | 25 | serverReady := sync.WaitGroup{} 26 | serverReady.Add(1) 27 | go func() { 28 | err = srv.Serve(&serverReady) 29 | if err != nil { 30 | lg.Error(err) 31 | } 32 | }() 33 | serverReady.Wait() 34 | return srv, err 35 | } 36 | 37 | // RestartBridge restarts the given bridge 38 | func RestartBridge(t *testing.T, b bridge.Bridge) { 39 | serverReady := sync.WaitGroup{} 40 | serverReady.Add(1) 41 | go func() { 42 | if err := b.Restart(&serverReady); err != nil { 43 | t.Error(err) 44 | } 45 | }() 46 | serverReady.Wait() 47 | } 48 | 49 | // NATSServerOnPort will run a server on the given port. 50 | func NATSServerOnPort(port int) *server.Server { 51 | opts := testserver.DefaultTestOptions 52 | opts.Port = port 53 | return NATSServerWithOptions(&opts) 54 | } 55 | 56 | // NATSServerWithOptions will run a server with the given options. 57 | func NATSServerWithOptions(opts *server.Options) *server.Server { 58 | return testserver.RunServer(opts) 59 | } 60 | 61 | // AssertMessageReceived waits for a boolean to be received on the given channel for one second and then 62 | // bails out with a Fatal message "expected message did not arrive" 63 | func AssertMessageReceived(t *testing.T, c <-chan bool) { 64 | t.Helper() 65 | select { 66 | case <-c: 67 | case <-time.After(time.Second): // Wait time is somewhat arbitrary. 68 | t.Fatalf(`expected message did not arrive`) 69 | } 70 | } 71 | 72 | // AssertTimeout waits for 10 milliseconds to ensure that no boolean is received on the given channel and 73 | // then bails out with a Fatal message "unexpected message arrived". 
74 | func AssertTimeout(t *testing.T, c <-chan bool) { 75 | t.Helper() 76 | select { 77 | case <-c: 78 | t.Fatalf(`unexpected message arrived`) 79 | case <-time.After(10 * time.Millisecond): 80 | } 81 | } 82 | -------------------------------------------------------------------------------- /test/full/client.go: -------------------------------------------------------------------------------- 1 | // +build citest 2 | 3 | package full 4 | 5 | import ( 6 | "io" 7 | "net" 8 | "strconv" 9 | "strings" 10 | "testing" 11 | 12 | "github.com/nats-io/nuid" 13 | "github.com/tada/mqtt-nats/mqtt" 14 | "github.com/tada/mqtt-nats/mqtt/pkg" 15 | "github.com/tada/mqtt-nats/test/packet" 16 | ) 17 | 18 | // NextClientID returns a new unique client ID with the prefix "testclient-" 19 | func NextClientID() string { 20 | return "testclient-" + nuid.Next() 21 | } 22 | 23 | // MqttConnect establishes a tcp connection to the given port on the default host 24 | func MqttConnect(t *testing.T, port int) net.Conn { 25 | t.Helper() 26 | conn, err := net.Dial("tcp", ":"+strconv.Itoa(port)) 27 | if err != nil { 28 | t.Fatal(err) 29 | } 30 | return conn 31 | } 32 | 33 | // MqttConnectClean establishes a tcp connection to the given port on the default host, sends the 34 | // initial connect packet for a clean session and awaits the CONNACK. 
35 | func MqttConnectClean(t *testing.T, port int) net.Conn { 36 | conn := MqttConnect(t, port) 37 | MqttSend(t, conn, pkg.NewConnect(NextClientID(), true, 1, nil, nil)) 38 | MqttExpect(t, conn, pkg.NewConnAck(false, 0)) 39 | return conn 40 | } 41 | 42 | // MqttDisconnect sends a disconnect packet and closes the connection 43 | func MqttDisconnect(t *testing.T, conn io.WriteCloser) { 44 | t.Helper() 45 | defer func() { 46 | if err := conn.Close(); err != nil { 47 | t.Fatal(err) 48 | } 49 | }() 50 | buf := mqtt.NewWriter() 51 | pkg.DisconnectSingleton.Write(buf) 52 | _, err := conn.Write(buf.Bytes()) 53 | if err != nil { 54 | t.Fatal(err) 55 | } 56 | } 57 | 58 | // MqttSend writes the given packets on the given connection 59 | func MqttSend(t *testing.T, conn io.Writer, send ...pkg.Packet) { 60 | t.Helper() 61 | buf := mqtt.NewWriter() 62 | for i := range send { 63 | send[i].Write(buf) 64 | } 65 | _, err := conn.Write(buf.Bytes()) 66 | if err != nil { 67 | t.Fatal(err) 68 | } 69 | } 70 | 71 | // MqttExpect will read one packet for each entry in the list of expectations and assert that it is matched 72 | // by that entry. An expectation is either an expected verbatim pkg.Packet or a PacketMatcher function. 
73 | func MqttExpect(t *testing.T, conn io.Reader, expectations ...interface{}) { 74 | t.Helper() 75 | for _, e := range expectations { 76 | a := packet.Parse(t, conn) 77 | switch e := e.(type) { 78 | case pkg.Packet: 79 | if !e.Equals(a) { 80 | t.Fatalf("expected '%s', got '%s'", e, a) 81 | } 82 | case func(pkg.Packet) bool: 83 | if !e(a) { 84 | t.Fatalf("packet '%s' does not match packet match function", a) 85 | } 86 | default: 87 | t.Fatalf("a %T is not a valid expectation", e) 88 | } 89 | } 90 | } 91 | 92 | // MqttExpectConnReset will make a read attempt and expect that it fails with an error 93 | func MqttExpectConnReset(t *testing.T, conn net.Conn) { 94 | t.Helper() 95 | _, err := conn.Read([]byte{0}) 96 | if err != nil { 97 | if strings.Contains(err.Error(), "EOF") { 98 | return 99 | } 100 | if strings.Contains(err.Error(), "reset") { 101 | return 102 | } 103 | if strings.Contains(err.Error(), "forcibly closed") { 104 | return 105 | } 106 | } 107 | t.Fatalf("connection is not reset: %v", err) 108 | } 109 | 110 | // MqttExpectConnClosed will make a read attempt and expect that it fails with an error 111 | func MqttExpectConnClosed(t *testing.T, conn net.Conn) { 112 | t.Helper() 113 | _, err := conn.Read([]byte{0}) 114 | if err != nil && err == io.EOF || strings.Contains(err.Error(), "closed") { 115 | return 116 | } 117 | t.Fatalf("connection is not closed: %v", err) 118 | } 119 | -------------------------------------------------------------------------------- /test/full/nats.go: -------------------------------------------------------------------------------- 1 | // +build citest 2 | 3 | package full 4 | 5 | import ( 6 | "strconv" 7 | "testing" 8 | 9 | "github.com/nats-io/nats.go" 10 | ) 11 | 12 | // NatsConnect creates a new NATS connection on the given port. 
13 | func NatsConnect(t *testing.T, port int) *nats.Conn { 14 | nc, err := nats.Connect(":" + strconv.Itoa(port)) 15 | if err != nil { 16 | t.Fatal(err) 17 | } 18 | return nc 19 | } 20 | -------------------------------------------------------------------------------- /test/main_test.go: -------------------------------------------------------------------------------- 1 | package test 2 | 3 | import ( 4 | "os" 5 | "strconv" 6 | "testing" 7 | 8 | "github.com/tada/mqtt-nats/bridge" 9 | "github.com/tada/mqtt-nats/logger" 10 | "github.com/tada/mqtt-nats/test/full" 11 | ) 12 | 13 | var mqttServer bridge.Bridge 14 | 15 | const ( 16 | storageFile = "mqtt-nats.json" 17 | mqttPort = 11883 18 | natsPort = 14222 19 | retainedRequestTopic = "mqtt.retained.request" 20 | ) 21 | 22 | func TestMain(m *testing.M) { 23 | _ = os.Remove(storageFile) 24 | natsServer := full.NATSServerOnPort(natsPort) 25 | 26 | // NOTE: Setting level to logger.Debug here is very helpful when authoring and debugging tests but 27 | // it also makes the tests very verbose. 28 | lg := logger.New(logger.Debug, os.Stdout, os.Stderr) 29 | 30 | opts := bridge.Options{ 31 | Port: mqttPort, 32 | NATSUrls: ":" + strconv.Itoa(natsPort), 33 | RepeatRate: 50, 34 | RetainedRequestTopic: retainedRequestTopic, 35 | StoragePath: storageFile} 36 | var err error 37 | mqttServer, err = full.RunBridge(lg, &opts) 38 | 39 | var code int 40 | if err == nil { 41 | code = m.Run() 42 | } else { 43 | lg.Error(err) 44 | code = 1 45 | } 46 | natsServer.Shutdown() 47 | if err = mqttServer.Shutdown(); err != nil { 48 | lg.Error(err) 49 | code = 1 50 | } 51 | os.Exit(code) 52 | } 53 | -------------------------------------------------------------------------------- /test/mock/connection.go: -------------------------------------------------------------------------------- 1 | // +build citest 2 | 3 | // Package mock contains mocking/simulated versions of real runtime types 4 | // primarily used for testing. 
5 | // 6 | package mock 7 | 8 | import ( 9 | "bytes" 10 | "io" 11 | "net" 12 | "sync" 13 | "time" 14 | ) 15 | 16 | // Connection implements net.Conn and allows recording and playback. 17 | // 18 | // An instance must be created by calling NewConnection() and it can thereafter be 19 | // used the same way as a net.Conn returned from net.Dial. 20 | // In addition to the net.Conn API, the Connection supports 21 | // RemoteRead() and RemoteWrite() for what would be the remote end of 22 | // a net.Conn. 23 | // 24 | // The primary intended use case for Connection is to help with unit testing 25 | // logic using a net.Conn. 26 | // 27 | // ReadDeadLine is supported (TODO: currently WriteDeadLine is ignored as all writes will succeed until 28 | // the system is out of memory). 29 | // 30 | // TODO: This Connection does not have a way to simulate 31 | // syscall.ETIMEOUT error return resulting from a TCP keepAlive 32 | // failure. 33 | // 34 | // TODO: This Connection uses the deadlines for both local and remote 35 | // 36 | type Connection struct { 37 | input bytes.Buffer 38 | output bytes.Buffer 39 | closed bool 40 | readDeadline time.Time 41 | readDeadLineTimer *time.Timer 42 | writeDeadline time.Time 43 | moreData *sync.Cond // Cond to wait on when there is no data to read 44 | moreRemoteData *sync.Cond // Cond to wait on when there is no data to read at remote end 45 | } 46 | 47 | // NewConnection returns a new connection - i.e. 
comparable to net.Dial() but everything is hardcoded 48 | func NewConnection() *Connection { 49 | // start out with a stopped timer - this timer is reset when deadline is changed 50 | mc := &Connection{ 51 | moreData: sync.NewCond(&sync.Mutex{}), 52 | moreRemoteData: sync.NewCond(&sync.Mutex{}), 53 | readDeadLineTimer: newStoppedTimer(), 54 | } 55 | 56 | // Run a readDeadLine listener that wakes up those that are blocked reading 57 | go func() { 58 | for { 59 | // wait for timer to fire 60 | <-mc.readDeadLineTimer.C 61 | // timer fired, lock and Broadcast to those in blocking read 62 | // TODO: Fix debug level logging 63 | // if log.IsLevelEnabled(log.DebugLevel) { 64 | // log.Debug("mock.Connection ReadDeadline fired") 65 | // } 66 | md := mc.moreData 67 | md.L.Lock() 68 | md.Broadcast() 69 | md.L.Unlock() 70 | } 71 | }() 72 | return mc 73 | } 74 | 75 | // newStoppedTimer returns a stopped timer 76 | func newStoppedTimer() *time.Timer { 77 | timer := time.NewTimer(1 * time.Second) 78 | if !timer.Stop() { 79 | <-timer.C // drain if it fired (very unlikely, but happens when stepping over this in debugging) 80 | } 81 | return timer 82 | } 83 | 84 | // Addr implements net.Addr interface and is a staic "tcp" "0.0.0.0" 85 | type Addr struct{} 86 | 87 | // Network returns a static "tcp" for MockConnectionAddr 88 | func (a *Addr) Network() string { return "tcp" } 89 | 90 | // String returns a static "0.0.0.0" for MockConnectionAddr 91 | func (a *Addr) String() string { return "0.0.0.0" } 92 | 93 | // Read reads data from the connection. 94 | // Read can be made to time out and return an Error with Timeout() == true 95 | // after a fixed time limit; see SetDeadline and SetReadDeadline. 
96 | func (c *Connection) Read(b []byte) (n int, err error) { 97 | return c.readBufWithLock(b, &c.input, c.moreData) 98 | } 99 | 100 | // RemoteRead reads data from the connection's remote end (this returns what was written with Write) 101 | func (c *Connection) RemoteRead(b []byte) (n int, err error) { 102 | return c.readBufWithLock(b, &c.output, c.moreRemoteData) 103 | } 104 | 105 | // readBufWithLock reads from the buffer and waits for more data if it is empty 106 | func (c *Connection) readBufWithLock(b []byte, buffer *bytes.Buffer, condition *sync.Cond) (n int, err error) { 107 | // TODO: timeout & read of 0 bytes? 108 | for { 109 | condition.L.Lock() 110 | again: 111 | availBytes := buffer.Len() 112 | if c.closed && availBytes == 0 { 113 | condition.L.Unlock() 114 | return 0, io.EOF 115 | } 116 | deadline := c.readDeadline 117 | if !deadline.IsZero() && time.Now().After(deadline) { // while all times are after "zero time", this avoids a system call to Now() 118 | condition.L.Unlock() 119 | return 0, ErrTimeout 120 | } 121 | if availBytes == 0 { 122 | condition.Wait() 123 | goto again 124 | } 125 | n, err = buffer.Read(b) 126 | condition.L.Unlock() 127 | return n, err 128 | } 129 | } 130 | 131 | // Write writes data to the connection. 132 | // Write can be made to time out and return an Error with Timeout() == true 133 | // after a fixed time limit; see SetDeadline and SetWriteDeadline. 134 | func (c *Connection) Write(b []byte) (n int, err error) { 135 | return c.writeBufWithLock(b, &c.output, c.moreRemoteData, c.writeDeadline) 136 | } 137 | 138 | // RemoteWrite writes data to the connection as if something was written on the remote side of 139 | // a connection. (This data will be returned from subsequent Read() operations). 
140 | // 141 | func (c *Connection) RemoteWrite(b []byte) (n int, err error) { 142 | return c.writeBufWithLock(b, &c.input, c.moreData, c.writeDeadline) 143 | } 144 | 145 | // writeBufWithLock writes to the buffer, and signals a condition if buffer goes from empty to having content 146 | func (c *Connection) writeBufWithLock(b []byte, buffer *bytes.Buffer, condition *sync.Cond, deadline time.Time) (n int, err error) { 147 | condition.L.Lock() 148 | defer condition.L.Unlock() 149 | 150 | if c.closed { 151 | return 0, io.EOF 152 | } 153 | 154 | if !deadline.IsZero() && time.Now().After(deadline) { // while all times are after "zero time", this avoids a system call to Now() 155 | return 0, ErrTimeout 156 | } 157 | 158 | availBytes := buffer.Len() 159 | n, err = buffer.Write(b) 160 | if availBytes == 0 { 161 | condition.Broadcast() 162 | } 163 | return n, err 164 | } 165 | 166 | // Close closes the connection. 167 | // Any blocked Read or Write operations will be unblocked and return errors. 168 | func (c *Connection) Close() error { 169 | c.moreRemoteData.L.Lock() 170 | defer c.moreRemoteData.L.Unlock() 171 | c.moreData.L.Lock() 172 | defer c.moreData.L.Unlock() 173 | 174 | // mark closed 175 | c.closed = true 176 | 177 | // release blocked readers so they pick up the close 178 | if c.input.Len() == 0 { 179 | c.moreData.Broadcast() 180 | } 181 | // release blocked remote readers so they pick up the close 182 | if c.output.Len() == 0 { 183 | c.moreRemoteData.Broadcast() 184 | } 185 | return nil 186 | } 187 | 188 | // LocalAddr returns a hardcoded local network address. 189 | func (c *Connection) LocalAddr() net.Addr { 190 | return &Addr{} 191 | } 192 | 193 | // RemoteAddr returns a hardcoded remote network address. 194 | func (c *Connection) RemoteAddr() net.Addr { 195 | return &Addr{} 196 | } 197 | 198 | // SetDeadline sets the read and write deadlines associated 199 | // with the connection. 
It is equivalent to calling both 200 | // SetReadDeadline and SetWriteDeadline. 201 | // 202 | // A deadline is an absolute time after which I/O operations 203 | // fail with a timeout (see type Error) instead of 204 | // blocking. The deadline applies to all future and pending 205 | // I/O, not just the immediately following call to Read or 206 | // Write. After a deadline has been exceeded, the connection 207 | // can be refreshed by setting a deadline in the future. 208 | // 209 | // An idle timeout can be implemented by repeatedly extending 210 | // the deadline after successful Read or Write calls. 211 | // 212 | // A zero value for t means I/O operations will not time out. 213 | // 214 | // Note that if a TCP connection has keep-alive turned on, 215 | // which is the default unless overridden by Dialer.KeepAlive 216 | // or ListenConfig.KeepAlive, then a keep-alive failure may 217 | // also return a timeout error. On Unix systems a keep-alive 218 | // failure on I/O can be detected using 219 | // errors.Is(err, syscall.ETIMEDOUT). 220 | // 221 | // TODO: This MockConnection does not have a way to simulate 222 | // syscall.ETIMEOUT error return resulting from a TCP keepAlive 223 | // failure. 224 | // 225 | // TODO: At present the WriteDeadLine does not have any effect. 226 | func (c *Connection) SetDeadline(t time.Time) error { 227 | err := c.SetReadDeadline(t) 228 | if err == nil { 229 | err = c.SetWriteDeadline(t) 230 | } 231 | return err 232 | } 233 | 234 | // SetReadDeadline sets the deadline for future Read calls 235 | // and any currently-blocked Read call. 236 | // A zero value for t means Read will not time out. 
237 | func (c *Connection) SetReadDeadline(t time.Time) error { 238 | c.moreData.L.Lock() 239 | defer c.moreData.L.Unlock() 240 | c.readDeadline = t 241 | 242 | // stop the currently ticking (or possibly fired) timer 243 | oldTimer := c.readDeadLineTimer 244 | if !oldTimer.Stop() { 245 | select { 246 | case <-oldTimer.C: // drain if it fired 247 | default: // it was stopped already 248 | } 249 | } 250 | if t.IsZero() { 251 | // No need to wake those that are blocked, simply keep the timer in stopped state 252 | return nil 253 | } 254 | // Set a new duration 255 | oldTimer.Reset(time.Until(t)) 256 | return nil 257 | } 258 | 259 | // SetWriteDeadline sets the deadline for future Write calls 260 | // and any currently-blocked Write call. 261 | // Even if write times out, it may return n > 0, indicating that 262 | // some of the data was successfully written. 263 | // A zero value for t means Write will not time out. 264 | // TODO: At present this WriteDeadLine does not have any effect. 265 | func (c *Connection) SetWriteDeadline(t time.Time) error { 266 | c.writeDeadline = t 267 | return nil 268 | } 269 | 270 | // RemoteIO is an interface for operations on what MockConnection.Remote() returns 271 | type RemoteIO interface { 272 | io.ReadWriter 273 | io.ByteReader 274 | } 275 | 276 | // Remote is a io.ReadWriter that adapts the "remote" end of a MockConnection to 277 | // io.ReadWriter 278 | type Remote struct { 279 | conn *Connection 280 | } 281 | 282 | func (r *Remote) Read(b []byte) (int, error) { 283 | return r.conn.RemoteRead(b) 284 | } 285 | 286 | func (r *Remote) Write(b []byte) (int, error) { 287 | return r.conn.RemoteWrite(b) 288 | } 289 | 290 | // ReadByte reads one byte - implements io.ByteReader 291 | func (r *Remote) ReadByte() (byte, error) { 292 | b := make([]byte, 1) 293 | n, err := r.conn.RemoteRead(b) 294 | if err != nil || n != 1 { 295 | return 0, err 296 | } 297 | return b[0], err 298 | } 299 | 300 | // Remote returns a io.ReadWriter for the remote 
end of this MockConnetion 301 | func (c *Connection) Remote() RemoteIO { 302 | return &Remote{conn: c} 303 | } 304 | 305 | // TimeoutError is returned for an expired deadline. 306 | // It implements net.Error interface 307 | // This type is needed because the error actually returned from net.Conn is an internal data type 308 | // 309 | type TimeoutError struct{} 310 | 311 | // Error Implement the net.Error interface. 312 | func (e *TimeoutError) Error() string { return "i/o timeout" } 313 | 314 | // Timeout Implement the net.Error interface. 315 | func (e *TimeoutError) Timeout() bool { return true } 316 | 317 | // Temporary Implement the net.Error interface. 318 | func (e *TimeoutError) Temporary() bool { return true } 319 | 320 | // ErrTimeout is returned for an expired deadline 321 | var ErrTimeout error = &TimeoutError{} 322 | -------------------------------------------------------------------------------- /test/mock/connection_test.go: -------------------------------------------------------------------------------- 1 | package mock 2 | 3 | import ( 4 | "io" 5 | "net" 6 | "testing" 7 | "time" 8 | 9 | "github.com/tada/mqtt-nats/test/utils" 10 | ) 11 | 12 | func Test_MockConnection_implements_net_Conn(t *testing.T) { 13 | defer utils.ShouldNotPanic(t) 14 | _ = net.Conn(NewConnection()) 15 | } 16 | 17 | func Test_MockConnection_has_hardocded_local_and_remote_addr(t *testing.T) { 18 | conn := NewConnection() 19 | local := conn.LocalAddr() 20 | remote := conn.RemoteAddr() 21 | utils.CheckEqual("0.0.0.0", local.String(), t) 22 | utils.CheckEqual("tcp", local.Network(), t) 23 | utils.CheckEqual("0.0.0.0", remote.String(), t) 24 | utils.CheckEqual("tcp", remote.Network(), t) 25 | } 26 | 27 | // Test that RemoteWrite can be called without error 28 | func Test_MockConnection_Can_do_RemoteWrite(t *testing.T) { 29 | conn := NewConnection() 30 | n, err := conn.RemoteWrite([]byte("test")) 31 | utils.CheckEqual(4, n, t) 32 | utils.CheckNotError(err, t) 33 | } 34 | 35 | // 
Test that what is written "remotely" can be read "locally" 36 | func Test_MockConnection_Read_reads_what_was_written_with_RemoteWrite(t *testing.T) { 37 | conn := NewConnection() 38 | n, err := conn.RemoteWrite([]byte("test")) 39 | utils.CheckEqual(4, n, t) 40 | utils.CheckNotError(err, t) 41 | buf := make([]byte, 4) 42 | n, err = conn.Read(buf) 43 | utils.CheckEqual(4, n, t) 44 | utils.CheckNotError(err, t) 45 | } 46 | 47 | // Test that what is written "locally" can be read "remotely" 48 | func Test_MockConnection_ReadRemote_reads_what_was_written_with_Write(t *testing.T) { 49 | conn := NewConnection() 50 | n, err := conn.Write([]byte("test")) 51 | utils.CheckEqual(4, n, t) 52 | utils.CheckNotError(err, t) 53 | buf := make([]byte, 4) 54 | n, err = conn.RemoteRead(buf) 55 | utils.CheckEqual(4, n, t) 56 | utils.CheckNotError(err, t) 57 | } 58 | 59 | func Test_MockConnection_Read_waits_for_data_until_close(t *testing.T) { 60 | conn := NewConnection() 61 | readResult := make(chan error) 62 | timeout := make(chan bool) 63 | readEndTime := time.Now() 64 | closeTime := time.Now().Add(1 * time.Millisecond) // ensure this is after 65 | 66 | go func() { 67 | delay := 200 * time.Millisecond 68 | time.Sleep(delay) 69 | timeout <- true 70 | }() 71 | go func() { 72 | aByte := make([]byte, 1) 73 | _, err := conn.Read(aByte) 74 | readEndTime = time.Now() 75 | readResult <- err 76 | }() 77 | 78 | // Wait for the timeout 79 | <-timeout 80 | if err := conn.Close(); err != nil { 81 | panic(err) 82 | } 83 | 84 | // Wait for the read 85 | err := <-readResult 86 | 87 | // Did they occur in the expected order? 
88 | utils.CheckTrue(readEndTime.After(closeTime), t) 89 | utils.CheckTrue(err == io.EOF, t) 90 | } 91 | 92 | func Test_MockConnection_Read_waits_until_given_read_deadline(t *testing.T) { 93 | conn := NewConnection() 94 | if err := conn.SetDeadline(time.Now().Add(200 * time.Millisecond)); err != nil { 95 | t.Fatal(err) 96 | } 97 | 98 | aByte := make([]byte, 1) 99 | _, err := conn.Read(aByte) 100 | nerr, ok := err.(net.Error) 101 | if !ok { 102 | t.Fatalf("Expected a net.Error but could not convert it") 103 | } 104 | utils.CheckTrue(nerr.Timeout(), t) 105 | } 106 | 107 | func Test_MockConnection_Read_returns_amount_read_if_buffer_becomes_empty(t *testing.T) { 108 | conn := NewConnection() 109 | oneByte := []byte{1} 110 | n, err := conn.RemoteWrite(oneByte) 111 | utils.CheckEqual(1, n, t) 112 | utils.CheckNotError(err, t) 113 | 114 | threeBytes := make([]byte, 3) 115 | n, err = conn.Read(threeBytes) 116 | utils.CheckEqual(1, n, t) 117 | utils.CheckNotError(err, t) 118 | utils.CheckEqual(byte(1), threeBytes[0], t) 119 | 120 | oneByte = []byte{2} 121 | n, err = conn.RemoteWrite(oneByte) 122 | utils.CheckEqual(1, n, t) 123 | utils.CheckNotError(err, t) 124 | 125 | n, err = conn.Read(threeBytes) 126 | utils.CheckEqual(1, n, t) 127 | utils.CheckNotError(err, t) 128 | utils.CheckEqual(byte(2), threeBytes[0], t) 129 | } 130 | 131 | // TODO: Test Multithreaded reading and writing 132 | // 1. 
n threads write 1 byte each, n threads read one byte each - all bytes are read 133 | -------------------------------------------------------------------------------- /test/package.go: -------------------------------------------------------------------------------- 1 | // +build citest 2 | 3 | // Package test contains the in-process functional tests for a fully configure mqtt-nats bridge + 4 | // nats test server combo 5 | package test 6 | -------------------------------------------------------------------------------- /test/packet/parse.go: -------------------------------------------------------------------------------- 1 | // +build citest 2 | 3 | // Package packet contains a test version of the MQTT packet parser 4 | package packet 5 | 6 | import ( 7 | "fmt" 8 | "io" 9 | "testing" 10 | 11 | "github.com/tada/mqtt-nats/mqtt" 12 | "github.com/tada/mqtt-nats/mqtt/pkg" 13 | ) 14 | 15 | // Parse is a test helper function that parses the next package from the given reader and returns it. t.Fatal(err) 16 | // will be called if an error occurs when reading or parsing. 
17 | func Parse(t *testing.T, rdr io.Reader) pkg.Packet { 18 | t.Helper() 19 | // Read packet type and flags 20 | r := mqtt.NewReader(rdr) 21 | var p pkg.Packet 22 | b, err := r.ReadByte() 23 | if err == nil { 24 | pkgType := b & pkg.TpMask 25 | 26 | // Read packet length 27 | var rl int 28 | rl, err = r.ReadVarInt() 29 | if err == nil { 30 | switch pkgType { 31 | case pkg.TpConnAck: 32 | p, err = pkg.ParseConnAck(r, b, rl) 33 | case pkg.TpDisconnect: 34 | p = pkg.DisconnectSingleton 35 | case pkg.TpPing: 36 | p = pkg.PingRequestSingleton 37 | case pkg.TpPingResp: 38 | p = pkg.PingResponseSingleton 39 | case pkg.TpConnect: 40 | p, err = pkg.ParseConnect(r, b, rl) 41 | case pkg.TpPublish: 42 | p, err = pkg.ParsePublish(r, b, rl) 43 | case pkg.TpPubAck: 44 | p, err = pkg.ParsePubAck(r, b, rl) 45 | case pkg.TpPubRec: 46 | p, err = pkg.ParsePubRec(r, b, rl) 47 | case pkg.TpPubRel: 48 | p, err = pkg.ParsePubRel(r, b, rl) 49 | case pkg.TpPubComp: 50 | p, err = pkg.ParsePubComp(r, b, rl) 51 | case pkg.TpSubscribe: 52 | p, err = pkg.ParseSubscribe(r, b, rl) 53 | case pkg.TpSubAck: 54 | p, err = pkg.ParseSubAck(r, b, rl) 55 | case pkg.TpUnsubscribe: 56 | p, err = pkg.ParseUnsubscribe(r, b, rl) 57 | case pkg.TpUnsubAck: 58 | p, err = pkg.ParseUnsubAck(r, b, rl) 59 | default: 60 | err = fmt.Errorf("received unknown packet type %d", (b&pkg.TpMask)>>4) 61 | } 62 | } 63 | } 64 | if err != nil { 65 | t.Fatal(err) 66 | } 67 | return p 68 | } 69 | -------------------------------------------------------------------------------- /test/packet/parse_test.go: -------------------------------------------------------------------------------- 1 | package packet 2 | 3 | import ( 4 | "bytes" 5 | "testing" 6 | 7 | "github.com/tada/mqtt-nats/test/utils" 8 | ) 9 | 10 | func TestParse_unknownPacket(t *testing.T) { 11 | utils.EnsureFailed(t, func(st *testing.T) { 12 | Parse(st, bytes.NewReader([]byte{0xf0, 0})) 13 | }, "unknown packet type 15") 14 | } 15 | 
-------------------------------------------------------------------------------- /test/publish_test.go: -------------------------------------------------------------------------------- 1 | package test 2 | 3 | import ( 4 | "bytes" 5 | "testing" 6 | "time" 7 | 8 | "github.com/nats-io/nats.go" 9 | "github.com/tada/mqtt-nats/mqtt/pkg" 10 | "github.com/tada/mqtt-nats/test/full" 11 | ) 12 | 13 | func TestPublishSubscribe(t *testing.T) { 14 | topic := "testing/some/topic" 15 | pp := pkg.SimplePublish(topic, []byte("payload")) 16 | gotIt := make(chan bool, 1) 17 | c1 := full.MqttConnectClean(t, mqttPort) 18 | mid := nextPacketID() 19 | full.MqttSend(t, c1, pkg.NewSubscribe(mid, pkg.Topic{Name: topic})) 20 | full.MqttExpect(t, c1, pkg.NewSubAck(mid, 0)) 21 | go func() { 22 | full.MqttExpect(t, c1, pp) 23 | gotIt <- true 24 | full.MqttDisconnect(t, c1) 25 | }() 26 | 27 | c2 := full.MqttConnectClean(t, mqttPort) 28 | full.MqttSend(t, c2, pp) 29 | full.MqttDisconnect(t, c2) 30 | full.AssertMessageReceived(t, gotIt) 31 | } 32 | 33 | func TestPublishSubscribe_qos_1(t *testing.T) { 34 | topic := "testing/some/topic" 35 | mid := nextPacketID() 36 | pp := pkg.NewPublish2(mid, topic, []byte("payload"), 1, false, false) 37 | c1 := full.MqttConnectClean(t, mqttPort) 38 | gotIt := make(chan bool, 1) 39 | go func() { 40 | sid := nextPacketID() 41 | full.MqttSend(t, c1, pkg.NewSubscribe(sid, pkg.Topic{Name: topic, QoS: 1})) 42 | full.MqttExpect(t, c1, pkg.NewSubAck(sid, 1), pp) 43 | full.MqttSend(t, c1, pkg.PubAck(mid)) 44 | full.MqttDisconnect(t, c1) 45 | gotIt <- true 46 | }() 47 | 48 | c2 := full.MqttConnectClean(t, mqttPort) 49 | full.MqttSend(t, c2, pp) 50 | full.MqttExpect(t, c2, pkg.PubAck(mid)) 51 | full.MqttDisconnect(t, c2) 52 | full.AssertMessageReceived(t, gotIt) 53 | } 54 | 55 | func TestPublishSubscribe_qos_2(t *testing.T) { 56 | topic := "testing/some/topic" 57 | mid := nextPacketID() 58 | pp := pkg.NewPublish2(mid, topic, []byte("payload"), 1, false, false) 59 | c1 := 
full.MqttConnectClean(t, mqttPort) 60 | gotIt := make(chan bool, 1) 61 | go func() { 62 | sid := nextPacketID() 63 | full.MqttSend(t, c1, pkg.NewSubscribe(sid, pkg.Topic{Name: topic, QoS: 2})) 64 | full.MqttExpect(t, c1, pkg.NewSubAck(sid, 1), pp) 65 | full.MqttSend(t, c1, pkg.PubAck(mid)) 66 | full.MqttDisconnect(t, c1) 67 | gotIt <- true 68 | }() 69 | 70 | c2 := full.MqttConnectClean(t, mqttPort) 71 | full.MqttSend(t, c2, pp) 72 | full.MqttExpect(t, c2, pkg.PubAck(mid)) 73 | full.MqttDisconnect(t, c2) 74 | full.AssertMessageReceived(t, gotIt) 75 | } 76 | 77 | func TestPublishSubscribe_qos_1_restart(t *testing.T) { 78 | topic := "testing/some/topic" 79 | mid := nextPacketID() 80 | pp := pkg.NewPublish2(mid, topic, []byte("payload"), 1, false, false) 81 | 82 | c1ID := full.NextClientID() 83 | c1 := full.MqttConnect(t, mqttPort) 84 | gotIt := make(chan bool, 1) 85 | go func() { 86 | full.MqttSend(t, c1, pkg.NewConnect(c1ID, false, 1, nil, nil)) 87 | full.MqttExpect(t, c1, pkg.NewConnAck(false, 0)) 88 | 89 | sid := nextPacketID() 90 | full.MqttSend(t, c1, pkg.NewSubscribe(sid, pkg.Topic{Name: topic, QoS: 1})) 91 | full.MqttExpect(t, c1, pkg.NewSubAck(sid, 1)) 92 | gotIt <- true 93 | full.MqttExpect(t, c1, pp) 94 | gotIt <- true 95 | }() 96 | 97 | c2ID := full.NextClientID() 98 | c2 := full.MqttConnect(t, mqttPort) 99 | full.MqttSend(t, c2, pkg.NewConnect(c2ID, false, 1, nil, nil)) 100 | full.MqttExpect(t, c2, pkg.NewConnAck(false, 0)) 101 | full.AssertMessageReceived(t, gotIt) 102 | full.MqttSend(t, c2, pp) 103 | full.AssertMessageReceived(t, gotIt) 104 | 105 | full.RestartBridge(t, mqttServer) 106 | 107 | // client c1 reestablishes session and sends outstanding ack 108 | c1 = full.MqttConnect(t, mqttPort) 109 | full.MqttSend(t, c1, pkg.NewConnect(c1ID, false, 1, nil, nil)) 110 | full.MqttExpect(t, c1, pkg.NewConnAck(true, 0)) 111 | 112 | // client c2 reestablishes session and receives outstanding ack 113 | c2 = full.MqttConnect(t, mqttPort) 114 | full.MqttSend(t, 
c2, pkg.NewConnect(c2ID, false, 1, nil, nil)) 115 | full.MqttExpect(t, c2, pkg.NewConnAck(true, 0)) 116 | 117 | full.MqttSend(t, c1, pkg.PubAck(mid)) 118 | full.MqttExpect(t, c2, pkg.PubAck(mid)) 119 | 120 | full.MqttDisconnect(t, c1) 121 | full.MqttDisconnect(t, c2) 122 | } 123 | 124 | func TestMqttPublishNatsSubscribe(t *testing.T) { 125 | pl := []byte("payload") 126 | pp := pkg.SimplePublish("testing/s.o.m.e/topic", pl) 127 | gotIt := make(chan bool, 1) 128 | nc := full.NatsConnect(t, natsPort) 129 | defer nc.Close() 130 | 131 | _, err := nc.Subscribe("testing.s/o/m/e.>", func(m *nats.Msg) { 132 | if !bytes.Equal(pl, m.Data) { 133 | t.Error("nats subscription did not receive expected data") 134 | } 135 | gotIt <- true 136 | }) 137 | if err != nil { 138 | t.Fatal(err) 139 | } 140 | 141 | c2 := full.MqttConnectClean(t, mqttPort) 142 | full.MqttSend(t, c2, pp) 143 | full.MqttDisconnect(t, c2) 144 | full.AssertMessageReceived(t, gotIt) 145 | } 146 | 147 | func TestNatsPublishMqttSubscribe(t *testing.T) { 148 | topic := "testing/some/topic" 149 | pl := []byte("payload") 150 | pp := pkg.SimplePublish(topic, pl) 151 | 152 | c1 := full.MqttConnectClean(t, mqttPort) 153 | sid := nextPacketID() 154 | full.MqttSend(t, c1, pkg.NewSubscribe(sid, pkg.Topic{Name: "testing/+/topic"})) 155 | full.MqttExpect(t, c1, pkg.NewSubAck(sid, 0)) 156 | 157 | gotIt := make(chan bool, 1) 158 | go func() { 159 | full.MqttExpect(t, c1, pp) 160 | full.MqttDisconnect(t, c1) 161 | gotIt <- true 162 | }() 163 | 164 | nc := full.NatsConnect(t, natsPort) 165 | defer nc.Close() 166 | err := nc.Publish("testing.some.topic", pl) 167 | if err != nil { 168 | t.Fatal(err) 169 | } 170 | full.AssertMessageReceived(t, gotIt) 171 | } 172 | 173 | func TestNatsPublishMqttSubscribe_qos_1(t *testing.T) { 174 | topic := "testing/some/topic" 175 | pl := []byte("payload") 176 | 177 | c1 := full.MqttConnectClean(t, mqttPort) 178 | sid := nextPacketID() 179 | full.MqttSend(t, c1, pkg.NewSubscribe(sid, pkg.Topic{Name: 
topic, QoS: 1})) 180 | full.MqttExpect(t, c1, pkg.NewSubAck(sid, 1)) 181 | 182 | gotIt := make(chan bool, 1) 183 | go func() { 184 | full.MqttExpect(t, c1, func(p pkg.Packet) bool { 185 | if pp, ok := p.(*pkg.Publish); ok { 186 | full.MqttSend(t, c1, pkg.PubAck(pp.ID())) 187 | return pp.TopicName() == topic && bytes.Equal(pp.Payload(), pl) && pp.QoSLevel() == 1 188 | } 189 | return false 190 | }) 191 | full.MqttDisconnect(t, c1) 192 | gotIt <- true 193 | }() 194 | 195 | nc := full.NatsConnect(t, natsPort) 196 | defer nc.Close() 197 | _, err := nc.Request("testing.some.topic", pl, 10*time.Millisecond) 198 | if err != nil { 199 | t.Fatal(err) 200 | } 201 | } 202 | 203 | func TestUnubscribe(t *testing.T) { 204 | topic := "testing/some/topic" 205 | pp := pkg.SimplePublish(topic, []byte("payload")) 206 | gotIt := make(chan bool, 1) 207 | c1 := full.MqttConnectClean(t, mqttPort) 208 | mid := nextPacketID() 209 | full.MqttSend(t, c1, pkg.NewSubscribe(mid, pkg.Topic{Name: topic})) 210 | full.MqttExpect(t, c1, pkg.NewSubAck(mid, 0)) 211 | go func() { 212 | full.MqttExpect(t, c1, pp) 213 | 214 | uid := nextPacketID() 215 | full.MqttSend(t, c1, pkg.NewUnsubscribe(uid, topic)) 216 | full.MqttExpect(t, c1, pkg.UnsubAck(uid)) 217 | gotIt <- true 218 | full.MqttExpectConnClosed(t, c1) 219 | gotIt <- true 220 | }() 221 | 222 | c2 := full.MqttConnectClean(t, mqttPort) 223 | full.MqttSend(t, c2, pp) 224 | 225 | // wait for subscriber to consume and unsubscribe 226 | full.AssertMessageReceived(t, gotIt) 227 | 228 | // send again, this should not reach subscriber 229 | full.MqttSend(t, c2, pp) 230 | full.MqttDisconnect(t, c2) 231 | full.AssertTimeout(t, gotIt) 232 | full.MqttDisconnect(t, c1) 233 | } 234 | 235 | func TestPublish_qos_2(t *testing.T) { 236 | conn := full.MqttConnectClean(t, mqttPort) 237 | full.MqttSend(t, conn, pkg.NewPublish2( 238 | nextPacketID(), "testing/some/topic", []byte("payload"), 2, false, false)) 239 | full.MqttExpectConnReset(t, conn) 240 | } 241 | 242 | 
func TestPublish_qos_3(t *testing.T) { 243 | conn := full.MqttConnectClean(t, mqttPort) 244 | full.MqttSend(t, conn, pkg.NewPublish2( 245 | nextPacketID(), "testing/some/topic", []byte("payload"), 3, false, false)) 246 | full.MqttExpectConnReset(t, conn) 247 | } 248 | -------------------------------------------------------------------------------- /test/retained_test.go: -------------------------------------------------------------------------------- 1 | package test 2 | 3 | import ( 4 | "encoding/base64" 5 | "encoding/json" 6 | "testing" 7 | "time" 8 | 9 | "github.com/tada/mqtt-nats/mqtt" 10 | "github.com/tada/mqtt-nats/mqtt/pkg" 11 | "github.com/tada/mqtt-nats/test/full" 12 | ) 13 | 14 | func decodeRetained(t *testing.T, data []byte) []*pkg.Publish { 15 | t.Helper() 16 | var ms []map[string]string 17 | err := json.Unmarshal(data, &ms) 18 | if err != nil { 19 | t.Fatal(err) 20 | } 21 | ns := make([]*pkg.Publish, len(ms)) 22 | for i := range ms { 23 | m := ms[i] 24 | var pl []byte 25 | if ps, ok := m["payload"]; ok { 26 | pl = []byte(ps) 27 | } else if pe, ok := m["payloadEnc"]; ok { 28 | if pl, err = base64.StdEncoding.DecodeString(pe); err != nil { 29 | t.Fatal(err) 30 | } 31 | } 32 | ns[i] = pkg.NewPublish2(0, mqtt.FromNATS(m["subject"]), pl, 0, false, true) 33 | } 34 | return ns 35 | } 36 | 37 | func TestNATS_requestRetained(t *testing.T) { 38 | conn := full.MqttConnectClean(t, mqttPort) 39 | pp1 := pkg.NewPublish2(0, "testing/s.o.m.e/retained/first", []byte("the first retained message"), 0, false, true) 40 | pp2 := pkg.NewPublish2(0, "testing/s.o.m.e/retained/second", []byte("the second retained message"), 0, false, true) 41 | full.MqttSend(t, conn, pp1, pp2) 42 | full.MqttDisconnect(t, conn) 43 | 44 | nc := full.NatsConnect(t, natsPort) 45 | defer nc.Close() 46 | 47 | m, err := nc.Request(retainedRequestTopic, []byte("testing.s/o/m/e.retained.>"), 10*time.Millisecond) 48 | if err != nil { 49 | t.Fatal(err) 50 | } 51 | 52 | pps := decodeRetained(t, m.Data) 53 
| if !(len(pps) == 2 && pp1.Equals(pps[0]) && pp2.Equals(pps[1])) { 54 | t.Fatal("unexpected retained publication") 55 | } 56 | 57 | m, err = nc.Request(retainedRequestTopic, []byte("do.not.find.this"), 10*time.Millisecond) 58 | if err != nil { 59 | t.Fatal(err) 60 | } 61 | pps = decodeRetained(t, m.Data) 62 | if len(pps) != 0 { 63 | t.Fatal("unexpected retained publication") 64 | } 65 | } 66 | -------------------------------------------------------------------------------- /test/tls/connect_test.go: -------------------------------------------------------------------------------- 1 | package tls 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/tada/mqtt-nats/mqtt/pkg" 7 | "github.com/tada/mqtt-nats/test/full" 8 | ) 9 | 10 | func TestConnect(t *testing.T) { 11 | conn := tlsDial(t, mqttPort) 12 | full.MqttSend(t, conn, pkg.NewConnect(full.NextClientID(), true, 1, nil, nil)) 13 | full.MqttExpect(t, conn, pkg.NewConnAck(false, 0)) 14 | full.MqttDisconnect(t, conn) 15 | } 16 | -------------------------------------------------------------------------------- /test/tls/main_test.go: -------------------------------------------------------------------------------- 1 | package tls 2 | 3 | import ( 4 | "crypto/tls" 5 | "crypto/x509" 6 | "io/ioutil" 7 | "net" 8 | "os" 9 | "strconv" 10 | "testing" 11 | 12 | "github.com/nats-io/nats-server/v2/test" 13 | 14 | "github.com/nats-io/nats.go" 15 | "github.com/tada/mqtt-nats/bridge" 16 | "github.com/tada/mqtt-nats/logger" 17 | "github.com/tada/mqtt-nats/test/full" 18 | ) 19 | 20 | var mqttServer bridge.Bridge 21 | 22 | const ( 23 | storageFile = "mqtt-nats.json" 24 | mqttPort = 18883 25 | natsPort = 14443 26 | retainedRequestTopic = "mqtt.retained.request" 27 | ) 28 | 29 | // tlsDial establishes a tcp tls connection to the given port on the default host 30 | func tlsDial(t *testing.T, port int) net.Conn { 31 | t.Helper() 32 | pbs, err := ioutil.ReadFile("testdata/ca.pem") 33 | if err != nil { 34 | t.Fatal(err) 35 | } 36 | roots := 
x509.NewCertPool() 37 | roots.AppendCertsFromPEM(pbs) 38 | cert, err := tls.LoadX509KeyPair("testdata/client.pem", "testdata/client-key.pem") 39 | if err != nil { 40 | t.Fatal(err) 41 | } 42 | conn, err := tls.Dial("tcp", ":"+strconv.Itoa(port), &tls.Config{ 43 | ServerName: "127.0.0.1", 44 | RootCAs: roots, 45 | Certificates: []tls.Certificate{cert}, 46 | }) 47 | if err != nil { 48 | t.Fatal(err) 49 | } 50 | return conn 51 | } 52 | 53 | func TestMain(m *testing.M) { 54 | _ = os.Remove(storageFile) 55 | natsServer, _ := test.RunServerWithConfig("server.conf") 56 | 57 | natsOpts := []nats.Option{nats.Name("MQTT Bridge")} 58 | natsOpts = append(natsOpts, nats.ClientCert("testdata/client.pem", "testdata/client-key.pem")) 59 | natsOpts = append(natsOpts, nats.RootCAs("testdata/ca.pem")) 60 | 61 | // NOTE: Setting level to logger.Debug here is very helpful when authoring and debugging tests but 62 | // it also makes the tests very verbose. 63 | lg := logger.New(logger.Silent, os.Stdout, os.Stderr) 64 | 65 | opts := bridge.Options{ 66 | Port: mqttPort, 67 | NATSUrls: "localhost:" + strconv.Itoa(natsPort), 68 | RepeatRate: 50, 69 | RetainedRequestTopic: retainedRequestTopic, 70 | StoragePath: storageFile, 71 | TLS: true, 72 | TLSCaCert: "testdata/ca.pem", 73 | TLSCert: "testdata/server.pem", 74 | TLSKey: "testdata/server-key.pem", 75 | NATSOpts: natsOpts} 76 | 77 | var err error 78 | mqttServer, err = full.RunBridge(lg, &opts) 79 | 80 | var code int 81 | if err == nil { 82 | code = m.Run() 83 | } else { 84 | lg.Error(err) 85 | code = 1 86 | } 87 | natsServer.Shutdown() 88 | if err = mqttServer.Shutdown(); err != nil { 89 | lg.Error(err) 90 | code = 1 91 | } 92 | os.Exit(code) 93 | } 94 | -------------------------------------------------------------------------------- /test/tls/package.go: -------------------------------------------------------------------------------- 1 | // +build citest 2 | 3 | // Package tls contains full roundtrip tests using TLS configured 
connections 4 | package tls 5 | -------------------------------------------------------------------------------- /test/tls/server.conf: -------------------------------------------------------------------------------- 1 | listen: localhost:14443 2 | tls: { 3 | ca_file: "./testdata/ca.pem" 4 | cert_file: "./testdata/server.pem" 5 | key_file: "./testdata/server-key.pem" 6 | verify: true 7 | } 8 | -------------------------------------------------------------------------------- /test/tls/testdata/ca-key.pem: -------------------------------------------------------------------------------- 1 | -----BEGIN RSA PRIVATE KEY----- 2 | MIIEpQIBAAKCAQEA8qmvXNCsGLhQeq5/XIZp0bLL8lOqvNQz86U5QQwETnxs1KDU 3 | DquhNhzd2lgvaP98zO8Owkqz/kXIT4R/RxIeZ1BD9gY9iVVMyZar/bI+72DpX4Hg 4 | x7+UC9m34CAfK/AAYGZIgjhpDBobkVaKjVahBWUF7BWon+1xu23XdS99lxThJCTv 5 | G5gfquGFf4eSvhGSt7SRzW5nVP+OpTKxNMMawZa15jh3U7BfrXxDeZfDKFxF2RaK 6 | 6Lwi+ATLLe9wwvEKvCvDT8XPKsm9RUl48pLKRMy73tWitSh9SSFoa5zd49qIteLt 7 | LPKbNWPJLUyVrJzvbv9hVVJe4CKd6zKH+mYQywIDAQABAoIBAQCDGsp0CwnwESTq 8 | I30MMFLbyQ4HTszgWIX5DTtxuVxaSz9BYeMwSeo/ojj6zspOoDp9Pmtq7ZFxv6IJ 9 | 1Dwv2cozZ1pQge6dVEi4YX9rAfKewm1T/IfFY+xIushtfu1Yf8K0Uo66TF/0+eYL 10 | EAardjJpB7u7YbhJL7BS43WVCqOAC+p/nzJAyjETvvJhzUo7LEc7IeLUhAHCfTOX 11 | tUdmEOQnRy9xhgd2wkOwrt4loAxwdZALvbLfl58YtE6FgKeDLTQrNQzsA3a5dB2O 12 | 2N2UbB26RjTo7cshnj09uRuTWCiu9NYzPC/u6L9avPkRcXIgtqHrJ2k7U9MnRXXD 13 | 7lGRWg4BAoGBAPYQte1B8y5fNG/TEC/yV2lNO5zd+Noq/xQXG0huULYUXpaPcxg3 14 | 5+oBY7+Y6suI4XDasLKTz6HFIB414cKg9us0VH1fMJNM6HOqNdjRbR2udJ7fersW 15 | 24CviytYck7bbulDzmJgfxIpdF7zsFkh7MbK/8b9KMSxkBvngCWmk/f7AoGBAPx1 16 | zqk9RUL7DAOBqsbMQi8ldhRapSqOxEf3pKk+HaQ92l96ef+yGMghfhc4uKLTZ9tY 17 | wqFnPbi5Drsgboaoepcy3XYOFB13E+eac6px0l50//iQg5hEHkKdlvspsmVWLOMP 18 | Y0wir5X6Q0Fmw+ORcy0jpEy0CKT5wolZnhmuxuFxAoGBAPYPv92CFaxJiCZK6eUI 19 | cmDa2sIDNtb0KB/u+1ly90MdG3lz+aQ+Q6u9uAHg6Oqf9tDj3860AO3EMloDh78Z 20 | N9H8goDcr7adMdZ4X2ByDKuhyP0WfaSZNud4o7K0v5ob1M1vAPNfi7KdwcEx7ycy 21 | 
xZQFa8GRZzNKXNGKrpr3+QABAoGBAIhl3dHyGImnuUXruKjPkrKGOtWkY7gqikGX 22 | uo710G38PQ94zJEpV9pIvictrhPKxEHuIrmxXdd/pEXVr+FxBUrLYHt3/8Yrn8Vx 23 | 3Swpcs81x1Y0PeT2aKL1Ia1xScEWXgoPNkbcNqGBJPUg4JUC8IdiylHmswTvK/up 24 | P5IAq9MBAoGAVMOChcBBOAwzsgaVWm9hAcdMfh9gjOF6cBnG2McigrG8RY6E3AYE 25 | /nK4+zOrYfHBo7GL0Gd5JxmNL8Nu2QzqCLdr5cLw8gLbwmfgahU6siJc5GQrOsTS 26 | TEklyqEaDQqnvMYVsLvza/W3K8g6GX8Xo/NOBjYeeFcRPSP94g56JaE= 27 | -----END RSA PRIVATE KEY----- 28 | -------------------------------------------------------------------------------- /test/tls/testdata/ca.pem: -------------------------------------------------------------------------------- 1 | -----BEGIN CERTIFICATE----- 2 | MIIDhTCCAm2gAwIBAgITM8zKCwt+VQxi6YGDirBkR12XmDANBgkqhkiG9w0BAQsF 3 | ADBbMQswCQYDVQQGEwJTRTEYMBYGA1UECAwPU3RvY2tob2xtcyBMw6RuMQ4wDAYD 4 | VQQHDAVUw6RieTEQMA4GA1UEChMHVGFkYSBBQjEQMA4GA1UEAxMHdGFkYS5zZTAe 5 | Fw0yMDAzMjMyMDQzMDBaFw0yNTAzMjIyMDQzMDBaMFsxCzAJBgNVBAYTAlNFMRgw 6 | FgYDVQQIDA9TdG9ja2hvbG1zIEzDpG4xDjAMBgNVBAcMBVTDpGJ5MRAwDgYDVQQK 7 | EwdUYWRhIEFCMRAwDgYDVQQDEwd0YWRhLnNlMIIBIjANBgkqhkiG9w0BAQEFAAOC 8 | AQ8AMIIBCgKCAQEA8qmvXNCsGLhQeq5/XIZp0bLL8lOqvNQz86U5QQwETnxs1KDU 9 | DquhNhzd2lgvaP98zO8Owkqz/kXIT4R/RxIeZ1BD9gY9iVVMyZar/bI+72DpX4Hg 10 | x7+UC9m34CAfK/AAYGZIgjhpDBobkVaKjVahBWUF7BWon+1xu23XdS99lxThJCTv 11 | G5gfquGFf4eSvhGSt7SRzW5nVP+OpTKxNMMawZa15jh3U7BfrXxDeZfDKFxF2RaK 12 | 6Lwi+ATLLe9wwvEKvCvDT8XPKsm9RUl48pLKRMy73tWitSh9SSFoa5zd49qIteLt 13 | LPKbNWPJLUyVrJzvbv9hVVJe4CKd6zKH+mYQywIDAQABo0IwQDAOBgNVHQ8BAf8E 14 | BAMCAQYwDwYDVR0TAQH/BAUwAwEB/zAdBgNVHQ4EFgQUoZgRiuciAK1H9qeAXuSa 15 | hHINGdIwDQYJKoZIhvcNAQELBQADggEBAA0T6GlzDPE/aBggBPVvJtBNgzsqWuf/ 16 | Z3osgtBq0T0Z5pe+1x+lYJaKADb/qgY7gpGuPVOuKJGvFiXAGknSYdR+PQwf7VNZ 17 | WTr89eR4oAS+qwF0hRTi93l8j5rZ1YgP1eUf9L2D6ycGydyoDJbLhIdCyEz9t89r 18 | ZnY5hj9xxA9dRJu3A9MAvBU9gGE0Nwh9BWp3//qrNo0ixgIZS1l6gm0CtBHRn1l/ 19 | YwOsYsQv4hOlf4M1CsHIYEtnOX7MDzoDYU+B9JEBvJbNoqLDQQ0XLWlp31+tMImm 20 | N1DNkOEObT9oogVb59PlBb8aHwnfT+wZ56FyVbzrkFSOTjgOX/t427Q= 21 | -----END 
CERTIFICATE----- 22 | -------------------------------------------------------------------------------- /test/tls/testdata/client-key.pem: -------------------------------------------------------------------------------- 1 | -----BEGIN RSA PRIVATE KEY----- 2 | MIIEpAIBAAKCAQEAwOhHDGCm6kR9w/UbCcZ25YlBNO8gzOV2OXPT4xfo1x/ADwe2 3 | znf6hoIC+RTe+cFWH2cy/vw3XUy1SqloB7bgh8a+APPBaHI+p8RElE4+zVmhGsXN 4 | y0atdn0hecsJN4z3DrwwC6LdkSnfu1pNg7oO6DEyyewC/wZ7NmLZw+E4R+Vo+/nn 5 | Vri5cHIuwcD7wkr+dgIBeU9G7+YcM2hAjj069EC9ur5lYN0AkwcBhFo+4RXmmChv 6 | yzowvbOtSXWElA7XDCSLzPMjYsZ14fiCpDv/w1Fx2JpYZbbT0HlHfmik2utlocuX 7 | S4q5Q44hcMrv60p//IjRXfEnw/bZVuNbXfHbkQIDAQABAoIBAE8owcbtfnERi/42 8 | fVLkkvOcABsFqZMK8hmfUyqULCLiz4AbbUOKbk512Vx22Qzp7jpSsdV6kAmEKbyq 9 | iZroy3hL8LoZTJtcjiNv4aht901y4y5GTy2EIjhGHs+Ipo6aFOOCC8EqovsnkLyj 10 | 0L0mQ2m6jpnXdF9MPJFTvQKpT9wIKz6HXdBIQZJF7CKVIlXhaWgjc9nmE/pdXfKQ 11 | GsNKLZP6JGs04VfQVkIKO9/6kVeRONVraepl9kv5Wis3qRTP44xZ58lxKqymcBgX 12 | 2HUQsGeLsBK2vOfk1RXc7SwJg8H7h03ftp1nYNUlJWnHA8EeCkJzv+eS6V1j2I9y 13 | 69KgLKUCgYEA92VBIesA8AY3GDqjBqgOH69X5TFUq8ssFbAOA6hhFnORTbBbyHRZ 14 | kEBT0A+tl0U30Z61sLrk2M6gEvTPkyy2rz6B28PkObwtmKOLNjO9eeGevhnnDR2T 15 | bOQ08HNZrBxPiydOK+d8pfVgvp4Xj3qWrHdvBLVMLajamn8ojQFtAKsCgYEAx53f 16 | 2CEdgDns+DHetm2K8aqabESInV3i3o/IuWVVCaqjnGzgqTwxNg/Z9zmJcxvxa8/g 17 | INs8hTLFX2a+8ATXA7azDMI3ffFU0sll/fulA6DLlFZVZnZuxArtnqWiY1X9cQyo 18 | O3N1BfunzqXXlSzmK3QjMQa0mTxPsn+MeTnXLLMCgYAfEjl+8Av7GVy8D0lAYcT8 19 | V8JbR7nRpb/QrX7lGLWw4yzhq/+rCmnhQyMDo6Rytj/PdPZuztpFHJZgKx0S5++9 20 | zMT0fALi+W5kmE24rgDjGOIeEBTDwe4tI/A+Ls6ZXijjWjloLDeshEf1SNe+rm/U 21 | E1//IGID7gwekU/ffclZ5wKBgQC8hg7taUEaZBq4wSistEJAQTa8v/EiZpQoTDVv 22 | WxN4IK+KwY1gZ9e2Tjw18CIvE5nrj5UGkufSiIO9uSTlPDzxZfAuQZL1ICJTPSBV 23 | Qf+zsH30Z6EaNwofno6Sga4fEQxeY2zTURSZhPYUBa7YVWJAcdv2pnWUL1C5rRq3 24 | NvhQXwKBgQCB+wilOcResE0TXBNTXnNvx2/TPGL++BePi9KE/bArX0DpurgwAZ6y 25 | hwEz6yHhtVkAecGuorE82YzPXvWVx3AEOtXnO8AeEneNVAnXX00DbNQLSg+RCenR 26 | cF2ec8KBjbGObndFv5zYZVaHprvz+RW/hCRY51ZwWopFhDkTnSu8bQ== 27 | 
-----END RSA PRIVATE KEY----- 28 | -------------------------------------------------------------------------------- /test/tls/testdata/client.pem: -------------------------------------------------------------------------------- 1 | -----BEGIN CERTIFICATE----- 2 | MIIEADCCAuigAwIBAgIUadhpIhM+k/Twso977sUV0ktk32kwDQYJKoZIhvcNAQEL 3 | BQAwWzELMAkGA1UEBhMCU0UxGDAWBgNVBAgMD1N0b2NraG9sbXMgTMOkbjEOMAwG 4 | A1UEBwwFVMOkYnkxEDAOBgNVBAoTB1RhZGEgQUIxEDAOBgNVBAMTB3RhZGEuc2Uw 5 | HhcNMjAwMzIzMjA0MzAwWhcNMjEwMzIzMjA0MzAwWjBmMQswCQYDVQQGEwJTRTEY 6 | MBYGA1UECAwPU3RvY2tob2xtcyBMw6RuMQ4wDAYDVQQHDAVUw6RieTEQMA4GA1UE 7 | ChMHVGFkYSBBQjEbMBkGA1UEAxMSbXF0dGJyaWRnZS50YWRhLnNlMIIBIjANBgkq 8 | hkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAwOhHDGCm6kR9w/UbCcZ25YlBNO8gzOV2 9 | OXPT4xfo1x/ADwe2znf6hoIC+RTe+cFWH2cy/vw3XUy1SqloB7bgh8a+APPBaHI+ 10 | p8RElE4+zVmhGsXNy0atdn0hecsJN4z3DrwwC6LdkSnfu1pNg7oO6DEyyewC/wZ7 11 | NmLZw+E4R+Vo+/nnVri5cHIuwcD7wkr+dgIBeU9G7+YcM2hAjj069EC9ur5lYN0A 12 | kwcBhFo+4RXmmChvyzowvbOtSXWElA7XDCSLzPMjYsZ14fiCpDv/w1Fx2JpYZbbT 13 | 0HlHfmik2utlocuXS4q5Q44hcMrv60p//IjRXfEnw/bZVuNbXfHbkQIDAQABo4Gw 14 | MIGtMA4GA1UdDwEB/wQEAwIFoDAdBgNVHSUEFjAUBggrBgEFBQcDAQYIKwYBBQUH 15 | AwIwDAYDVR0TAQH/BAIwADAdBgNVHQ4EFgQUcjlXBU52VMYDbGeK6rtGdD+gs4cw 16 | HwYDVR0jBBgwFoAUoZgRiuciAK1H9qeAXuSahHINGdIwLgYDVR0RBCcwJYIJbG9j 17 | YWxob3N0ggR0YWRhggxob21lLnRhZGEuc2WHBH8AAAEwDQYJKoZIhvcNAQELBQAD 18 | ggEBACXkSDwf8hb5yp80fhmH0yP7cdArEFFisdNvCpdHoWFmvs7O4NK2bkegXP/3 19 | KfMlEom7Ctr01wNdg8rmxspHyVCSTV4h5xywLpXkcrhHzaSAY0hKgxwQABtaUQ0R 20 | oJPtvAJs53M4e6IXuBjdHtWKn1eOdR7/WbnvlTUbmyf3eHin9imkYZHAfT5j3d1m 21 | QfR5K2oX+psrYFUUGkyqFBSDyKl2iSSiC16NyRZ96Ils37KkL55GKUsNDSGLevB+ 22 | mTIH4tqVY1CY9PixX/MKl3h9V6GnUkVYKhdW9DgAdmiQUEaPmXvo+R8vRwH8Rfm7 23 | CGzBLfoPu9yiSWaDvCTBcWH4cmo= 24 | -----END CERTIFICATE----- 25 | -------------------------------------------------------------------------------- /test/tls/testdata/server-key.pem: -------------------------------------------------------------------------------- 1 | 
-----BEGIN EC PRIVATE KEY----- 2 | MHcCAQEEINM8lQ1SOzLGSC0E1nVW+PUowZTs9LMgoB/gBwtyoV3loAoGCCqGSM49 3 | AwEHoUQDQgAEtd/igiJxP2uUk+lqijBxLQ/ChwN7M/tyznPHE60Xzx9V1UX3Tw9R 4 | 4DS4yy+vUW5ewoSVZ7xZ/oaOQURPg0lRkg== 5 | -----END EC PRIVATE KEY----- 6 | -------------------------------------------------------------------------------- /test/tls/testdata/server.pem: -------------------------------------------------------------------------------- 1 | -----BEGIN CERTIFICATE----- 2 | MIIC1jCCAb6gAwIBAgIUAIeqiCsgQQ1qgTCB+Wkcpsq1+VIwDQYJKoZIhvcNAQEL 3 | BQAwWzELMAkGA1UEBhMCU0UxGDAWBgNVBAgMD1N0b2NraG9sbXMgTMOkbjEOMAwG 4 | A1UEBwwFVMOkYnkxEDAOBgNVBAoTB1RhZGEgQUIxEDAOBgNVBAMTB3RhZGEuc2Uw 5 | HhcNMjAwMzIzMjA0MzAwWhcNMjUwMzIyMjA0MzAwWjARMQ8wDQYDVQQDEwZTZXJ2 6 | ZXIwWTATBgcqhkjOPQIBBggqhkjOPQMBBwNCAAS13+KCInE/a5ST6WqKMHEtD8KH 7 | A3sz+3LOc8cTrRfPH1XVRfdPD1HgNLjLL69Rbl7ChJVnvFn+ho5BRE+DSVGSo4Gm 8 | MIGjMA4GA1UdDwEB/wQEAwIFoDATBgNVHSUEDDAKBggrBgEFBQcDATAMBgNVHRMB 9 | Af8EAjAAMB0GA1UdDgQWBBTXYQ46vfov5IKoKj1AELNM28uU/zAfBgNVHSMEGDAW 10 | gBShmBGK5yIArUf2p4Be5JqEcg0Z0jAuBgNVHREEJzAlgglsb2NhbGhvc3SCBHRh 11 | ZGGCDGhvbWUudGFkYS5zZYcEfwAAATANBgkqhkiG9w0BAQsFAAOCAQEAw+Cip5Zo 12 | TQdIhBz9Ejm3YYf9A0mch2Q1/fJ4ZpnoK2tBidB21SvSmnwuc7i925t0sJXzK1UR 13 | QSk49IXiVHBb2eIoyUyEffp1Jt9hvcqNKpanBWt0Y7Uta1gwNNhy20jhKWA4HkAJ 14 | D8GglyCRrHaoM8IF28V0wqD+scZjH1EjcNZzBJ0haUrjifoDWcrWruEW3b/KJ0eG 15 | Y3U+LkNzZvRyhpUUt5mCGcH3NfGPigd8XqP1LL8xLYkx0Lwd9McZLQv9EYLROZCE 16 | vRB6ZTSYcVAbGGTo1OrQed8y9c9IKmo8lVl9fMQgl6ICq3tu8mj0c+sc93wVfIy6 17 | 4tVLIbH2PcCM7Q== 18 | -----END CERTIFICATE----- 19 | -------------------------------------------------------------------------------- /test/utils/checks.go: -------------------------------------------------------------------------------- 1 | // +build citest 2 | 3 | // Package utils contains convenient testing checkers that compare a produced 4 | // value against an expected value (or condition). 
5 | // There are value checks like `CheckEqual(expected, produced, t)``, and 6 | // checks that should run deferred like `defer ShouldPanic(t)`. 7 | // 8 | package utils 9 | 10 | import ( 11 | "reflect" 12 | "strings" 13 | "testing" 14 | "unsafe" 15 | ) 16 | 17 | // CheckEqual checks if two values are deeply equal and calls t.Fatalf if not 18 | func CheckEqual(expected interface{}, got interface{}, t *testing.T) { 19 | t.Helper() 20 | if !reflect.DeepEqual(expected, got) { 21 | t.Fatalf("Expected: %v, got %v", expected, got) 22 | } 23 | } 24 | 25 | // CheckNil checks if value is nil 26 | func CheckNil(got interface{}, t *testing.T) { 27 | t.Helper() 28 | rf := reflect.ValueOf(got) 29 | if rf.IsValid() && !rf.IsNil() { 30 | t.Fatalf("Expected: nil, got %v", got) 31 | } 32 | } 33 | 34 | // CheckNotNil checks if value is not nil 35 | func CheckNotNil(got interface{}, t *testing.T) { 36 | t.Helper() 37 | rf := reflect.ValueOf(got) 38 | if !rf.IsValid() || rf.IsNil() { 39 | t.Fatalf("Expected: not nil, got nil") 40 | } 41 | } 42 | 43 | // CheckError checks if there is an error 44 | func CheckError(got error, t *testing.T) { 45 | t.Helper() 46 | if got == nil { 47 | t.Fatalf("Expected: error, got %v", got) 48 | } 49 | } 50 | 51 | // CheckNotError checks if error value is not nil 52 | func CheckNotError(got error, t *testing.T) { 53 | t.Helper() 54 | if got != nil { 55 | t.Fatalf("Expected: no error, got %v", got) 56 | } 57 | } 58 | 59 | // CheckTrue checks if value is true 60 | func CheckTrue(got bool, t *testing.T) { 61 | t.Helper() 62 | if !got { 63 | t.Fatalf("Expected: true, got %v", got) 64 | } 65 | } 66 | 67 | // CheckFalse checks if value is false 68 | func CheckFalse(got bool, t *testing.T) { 69 | t.Helper() 70 | if got { 71 | t.Fatalf("Expected: false, got %v", got) 72 | } 73 | } 74 | 75 | // EnsureFailed checks that the given test function fails with an Error or Fatal 76 | func EnsureFailed(t *testing.T, f func(t *testing.T), substrings ...string) { 77 | tt := 
testing.T{} 78 | rs := reflect.ValueOf(&tt).Elem() 79 | x := make(chan bool, 1) 80 | go func() { 81 | defer func() { x <- true }() // GoExit runs all deferred calls 82 | f(&tt) 83 | }() 84 | <-x 85 | if !tt.Failed() { 86 | t.Fail() 87 | } 88 | if len(substrings) > 0 { 89 | // Pick the output bytes from the testing.T using an unsafe.Pointer. 90 | rf := rs.FieldByName("common").FieldByName("output") 91 | rf = reflect.NewAt(rf.Type(), unsafe.Pointer(rf.UnsafeAddr())).Elem() 92 | le := string(rf.Interface().([]byte)) 93 | for _, ss := range substrings { 94 | if !strings.Contains(le, ss) { 95 | t.Fatalf("string %q does not contain %q", le, ss) 96 | } 97 | } 98 | } 99 | } 100 | -------------------------------------------------------------------------------- /test/utils/checks_test.go: -------------------------------------------------------------------------------- 1 | package utils 2 | 3 | import ( 4 | "io" 5 | "testing" 6 | ) 7 | 8 | func TestCheckEqual(t *testing.T) { 9 | EnsureFailed(t, func(ft *testing.T) { 10 | CheckEqual("a", "b", ft) 11 | }) 12 | } 13 | 14 | func TestCheckNil(t *testing.T) { 15 | EnsureFailed(t, func(ft *testing.T) { 16 | CheckNil([]byte{0}, ft) 17 | }) 18 | } 19 | 20 | func TestCheckNotNil(t *testing.T) { 21 | EnsureFailed(t, func(ft *testing.T) { 22 | CheckNotNil(nil, ft) 23 | }) 24 | } 25 | 26 | func TestCheckError(t *testing.T) { 27 | EnsureFailed(t, func(ft *testing.T) { 28 | CheckError(nil, ft) 29 | }) 30 | } 31 | 32 | func TestCheckNotError(t *testing.T) { 33 | EnsureFailed(t, func(ft *testing.T) { 34 | CheckNotError(io.ErrUnexpectedEOF, ft) 35 | }) 36 | } 37 | 38 | func TestCheckTrue(t *testing.T) { 39 | EnsureFailed(t, func(ft *testing.T) { 40 | CheckTrue(false, ft) 41 | }) 42 | } 43 | 44 | func TestCheckFalse(t *testing.T) { 45 | EnsureFailed(t, func(ft *testing.T) { 46 | CheckFalse(true, ft) 47 | }) 48 | } 49 | -------------------------------------------------------------------------------- /test/utils/logger.go: 
-------------------------------------------------------------------------------- 1 | // +build citest 2 | 3 | package utils 4 | 5 | import ( 6 | "github.com/tada/mqtt-nats/logger" 7 | ) 8 | 9 | // The T interface is fulfilled by *testing.T but can be implemented by a T mock if needed. 10 | type T interface { 11 | Helper() 12 | Log(...interface{}) 13 | } 14 | 15 | // NewLogger creates a new Logger configured to log at the given level. The logger uses the given 16 | // testing.T's Log function for all output. 17 | func NewLogger(l logger.Level, t T) logger.Logger { 18 | return &testLogger{l: l, t: t} 19 | } 20 | 21 | type testLogger struct { 22 | l logger.Level 23 | t T 24 | } 25 | 26 | func (t *testLogger) DebugEnabled() bool { 27 | return t.l >= logger.Debug 28 | } 29 | 30 | func (t *testLogger) Debug(args ...interface{}) { 31 | if t.DebugEnabled() { 32 | t.t.Helper() 33 | t.t.Log(addFirst("DEBUG", args)...) 34 | } 35 | } 36 | 37 | func (t *testLogger) ErrorEnabled() bool { 38 | return t.l >= logger.Error 39 | } 40 | 41 | func (t *testLogger) Error(args ...interface{}) { 42 | if t.ErrorEnabled() { 43 | t.t.Helper() 44 | t.t.Log(addFirst("ERROR", args)...) 45 | } 46 | } 47 | 48 | func (t *testLogger) InfoEnabled() bool { 49 | return t.l >= logger.Info 50 | } 51 | 52 | func (t *testLogger) Info(args ...interface{}) { 53 | if t.InfoEnabled() { 54 | t.t.Helper() 55 | t.t.Log(addFirst("INFO ", args)...) 
56 | } 57 | } 58 | 59 | // addFirst prepends the client to the args slice and returns the new slice 60 | func addFirst(first string, args []interface{}) []interface{} { 61 | na := make([]interface{}, len(args)+1) 62 | na[0] = first 63 | copy(na[1:], args) 64 | return na 65 | } 66 | -------------------------------------------------------------------------------- /test/utils/logger_test.go: -------------------------------------------------------------------------------- 1 | package utils 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/tada/mqtt-nats/logger" 7 | ) 8 | 9 | type mockT struct { 10 | logEntries [][]interface{} 11 | } 12 | 13 | func (m *mockT) Helper() { 14 | } 15 | 16 | func (m *mockT) Log(args ...interface{}) { 17 | m.logEntries = append(m.logEntries, args) 18 | } 19 | 20 | func Test_testLogger_Debug(t *testing.T) { 21 | tf := mockT{} 22 | lg := NewLogger(logger.Debug, &tf) 23 | CheckTrue(lg.DebugEnabled(), t) 24 | CheckTrue(lg.InfoEnabled(), t) 25 | CheckTrue(lg.ErrorEnabled(), t) 26 | lg.Debug("some stuff") 27 | CheckEqual(1, len(tf.logEntries), t) 28 | le := tf.logEntries[0] 29 | CheckEqual(2, len(le), t) 30 | CheckEqual("DEBUG", le[0], t) 31 | CheckEqual("some stuff", le[1], t) 32 | } 33 | 34 | func Test_testLogger_Info(t *testing.T) { 35 | tf := mockT{} 36 | lg := NewLogger(logger.Info, &tf) 37 | CheckFalse(lg.DebugEnabled(), t) 38 | CheckTrue(lg.InfoEnabled(), t) 39 | CheckTrue(lg.ErrorEnabled(), t) 40 | lg.Info("some stuff") 41 | CheckEqual(1, len(tf.logEntries), t) 42 | le := tf.logEntries[0] 43 | CheckEqual(2, len(le), t) 44 | CheckEqual("INFO ", le[0], t) 45 | CheckEqual("some stuff", le[1], t) 46 | } 47 | 48 | func Test_testLogger_Error(t *testing.T) { 49 | tf := mockT{} 50 | lg := NewLogger(logger.Error, &tf) 51 | CheckFalse(lg.DebugEnabled(), t) 52 | CheckFalse(lg.InfoEnabled(), t) 53 | CheckTrue(lg.ErrorEnabled(), t) 54 | lg.Error("some stuff") 55 | CheckEqual(1, len(tf.logEntries), t) 56 | le := tf.logEntries[0] 57 | CheckEqual(2, 
len(le), t) 58 | CheckEqual("ERROR", le[0], t) 59 | CheckEqual("some stuff", le[1], t) 60 | } 61 | 62 | func TestClient_String(t *testing.T) { 63 | 64 | } 65 | -------------------------------------------------------------------------------- /test/utils/panics.go: -------------------------------------------------------------------------------- 1 | // +build citest 2 | 3 | package utils 4 | 5 | import "testing" 6 | 7 | // ShouldNotPanic is used to assert that a function does not panic 8 | // Usage: defer testutils.ShouldNotPanic(t) at the point where the rest is expected to not panic 9 | func ShouldNotPanic(t *testing.T) { 10 | t.Helper() 11 | if r := recover(); r != nil { 12 | t.Error("Unexpected panic") 13 | } 14 | } 15 | 16 | // ShouldPanic is used to assert that a function does panic 17 | // Usage: defer testutils.ShouldNotPanic(t) at the point where the rest is expected to panic 18 | func ShouldPanic(t *testing.T) { 19 | t.Helper() 20 | if r := recover(); r == nil { 21 | t.Error("Expected panic but got none") 22 | } 23 | } 24 | -------------------------------------------------------------------------------- /test/utils/panics_test.go: -------------------------------------------------------------------------------- 1 | package utils 2 | 3 | import ( 4 | "errors" 5 | "testing" 6 | ) 7 | 8 | func TestShouldNotPanic(t *testing.T) { 9 | EnsureFailed(t, func(ft *testing.T) { 10 | defer ShouldNotPanic(ft) 11 | panic(errors.New("but it did")) 12 | }) 13 | } 14 | 15 | func TestShouldPanic(t *testing.T) { 16 | EnsureFailed(t, func(ft *testing.T) { 17 | defer ShouldPanic(ft) 18 | // but didn't 19 | }) 20 | } 21 | --------------------------------------------------------------------------------