├── .gitignore ├── LICENSE ├── README.md ├── client.go ├── clientstate.go ├── decoder_lowmem.go ├── definitions.go ├── encode.go ├── example_test.go ├── go.mod ├── mqtt.go ├── mqtt_test.go ├── rxtx.go ├── subscriptions.go └── travis.yml /.gitignore: -------------------------------------------------------------------------------- 1 | # Binaries for programs and plugins 2 | *.exe 3 | *.exe~ 4 | *.dll 5 | *.so 6 | *.dylib 7 | 8 | # Test binary, built with `go test -c` 9 | *.test 10 | 11 | # Output of the go coverage tool, specifically when used with LiteIDE 12 | *.out 13 | .vscode 14 | # Dependency directories (remove the comment below to include it) 15 | # vendor/ 16 | 17 | coverage.txt 18 | local_test.go -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Patricio Whittingslow 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![Go Report Card](https://goreportcard.com/badge/github.com/soypat/natiu-mqtt)](https://goreportcard.com/report/github.com/soypat/natiu-mqtt) 2 | [![GoDoc](https://godoc.org/github.com/soypat/natiu-mqtt?status.svg)](https://godoc.org/github.com/soypat/natiu-mqtt) 3 | [![codecov](https://codecov.io/gh/soypat/natiu-mqtt/branch/main/graph/badge.svg)](https://codecov.io/gh/soypat/natiu-mqtt/branch/main) 4 | 5 | # natiu-mqtt 6 | ### A dead-simple, extensible and correct MQTT implementation. 7 | 8 | **Natiu**: Means *mosquito* in the [Guaraní language](https://en.wikipedia.org/wiki/Guarani_language), a language spoken primarily in Paraguay. Commonly written as ñati'û or ñati'ũ. 9 | 10 | ## Highlights 11 | * **Modular** 12 | * Client implementation leaves allocating parts up to the [`Decoder`](./mqtt.go) interface type. Users can choose to use non-allocating or allocating implementations of the 3-method interface. 13 | * [`RxTx`](./rxtx.go) type lets one build an MQTT implementation from scratch for any transport. No server/client logic is defined at this level. 14 | 15 | * **No unneeded allocations**: The PUBLISH application message is not handled by this library; the user receives an `io.Reader` with the underlying transport bytes. This prevents allocations on the `natiu-mqtt` side. 16 | * **V3.1.1**: Compliant with [MQTT version 3.1.1](http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html) for QoS0 interactions. _QoS1 and QoS2 are WIP._ 17 | * **No external dependencies**: Nada. Nope.
18 | * **Data-oriented design**: Minimizes abstractions or objects for the data on the wire. 19 | * **Fuzz tested, robust**: Decoding implementation fuzzed to prevent adversarial user input from crashing the application (95% coverage). 20 | * **Simplicity**: A simple base package yields simple implementations for different transports. See [Implementations section](#implementations). 21 | * **Runtime-what?**: Unlike other MQTT implementations: **no** channels, **no** interface conversions, **no** goroutines—as little runtimey stuff as possible. You get the best of Go's concrete types when using Natiu's API. Why? Because MQTT deserialization and serialization are an *embarrassingly* serial and concrete problem. 22 | 23 | ## Goals 24 | This implementation will have a simple embedded-systems implementation in the package 25 | top level. This implementation will be transport-agnostic and non-concurrent. This will make it far easier to modify and reason about. The transport-dependent implementations will have their own subpackages: one package for TCP transport, another for UART, PPP, etc. 26 | 27 | * Minimal, if any, heap allocations. 28 | * Support for TCP transport. 29 | * User owns payload bytes. 30 | 31 | ## Implementations 32 | - [natiu-wsocket](https://github.com/soypat/natiu-wsocket): MQTT via **WebSockets**. Tested with the [moscajs/aedes broker server](https://github.com/moscajs/aedes). 33 | 34 | ## Examples 35 | API subject to change before v1.0.0 release. 36 | 37 | ### Example use of `Client` 38 | 39 | ```go 40 | // Create new client. 41 | client := mqtt.NewClient(mqtt.ClientConfig{ 42 | Decoder: mqtt.DecoderNoAlloc{make([]byte, 1500)}, 43 | OnPub: func(_ mqtt.Header, _ mqtt.VariablesPublish, r io.Reader) error { 44 | message, _ := io.ReadAll(r) 45 | log.Println("received message:", string(message)) 46 | return nil 47 | }, 48 | }) 49 | 50 | // Get a transport for MQTT packets.
51 | const defaultMQTTPort = ":1883" 52 | conn, err := net.Dial("tcp", "127.0.0.1"+defaultMQTTPort) 53 | if err != nil { 54 | fmt.Println(err) 55 | return 56 | } 57 | 58 | // Prepare for CONNECT interaction with server. 59 | var varConn mqtt.VariablesConnect 60 | varConn.SetDefaultMQTT([]byte("salamanca")) 61 | ctx, cancel := context.WithTimeout(context.Background(), 4*time.Second) 62 | err = client.Connect(ctx, conn, &varConn) // Connect to server. 63 | cancel() 64 | if err != nil { 65 | // Error or loop until connect success. 66 | log.Fatalf("connect attempt failed: %v\n", err) 67 | } 68 | 69 | // Ping forever until error. 70 | for { 71 | ctx, cancel := context.WithTimeout(context.Background(), time.Second) 72 | pingErr := client.Ping(ctx) 73 | cancel() 74 | if pingErr != nil { 75 | log.Fatal("ping error: ", pingErr, " with disconnect reason:", client.Err()) 76 | } 77 | log.Println("ping success!") 78 | } 79 | ``` 80 | 81 | ## Why not just use paho? 82 | 83 | Some issues with Eclipse's Paho implementation: 84 | * [Inherent data races on the API side](https://github.com/eclipse/paho.mqtt.golang/issues/550). The implementation is so notoriously hard to modify that this issue has remained frozen. 85 | * Calling Client.Disconnect when the client is already disconnected blocks indefinitely and can cause deadlock or spin with Paho's implementation. 86 | * If there is an issue with the network and Reconnect is enabled, then Paho's Reconnect spins. There is no way to prevent this. 87 | * Interfaces are used for ALL data types. This is not necessary and makes it difficult to work with, since there is no in-IDE documentation on interface methods. 88 | * No lower-level abstraction of MQTT for use in embedded systems with non-TCP transport. 89 | * Uses the `any` interface for the payload, which could simply be a byte slice... 90 | 91 | I found these issues after a 2-hour dev session. There will undoubtedly be more if I were to try to actually get it working...
92 | -------------------------------------------------------------------------------- /client.go: -------------------------------------------------------------------------------- 1 | package mqtt 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "io" 7 | "net" 8 | "sync" 9 | "time" 10 | ) 11 | 12 | var ( 13 | errDisconnected = errors.New("natiu-mqtt: disconnected") 14 | ) 15 | 16 | // Client is an asynchronous MQTT v3.1.1 client implementation which is 17 | // safe for concurrent use. 18 | type Client struct { 19 | cs clientState 20 | 21 | rxlock sync.Mutex 22 | rx Rx 23 | 24 | txlock sync.Mutex 25 | tx Tx 26 | } 27 | 28 | // ClientConfig is used to configure a new Client. 29 | type ClientConfig struct { 30 | // If Decoder is nil, NewClient uses a DecoderNoAlloc with a 4 KiB buffer. 31 | Decoder Decoder 32 | // OnPub is executed on every PUBLISH message received. Do not call 33 | // HandleNext or other client methods from within this function. 34 | OnPub func(pubHead Header, varPub VariablesPublish, r io.Reader) error 35 | // TODO: add a backoff algorithm callback here so clients can roll their own. 36 | } 37 | 38 | // NewClient creates a new MQTT client with the configuration parameters provided. 39 | // If no Decoder is provided a DecoderNoAlloc will be used. 40 | func NewClient(cfg ClientConfig) *Client { 41 | var onPub func(rx *Rx, varPub VariablesPublish, r io.Reader) error 42 | if cfg.OnPub != nil { 43 | onPub = func(rx *Rx, varPub VariablesPublish, r io.Reader) error { 44 | return cfg.OnPub(rx.LastReceivedHeader, varPub, r) 45 | } 46 | } 47 | if cfg.Decoder == nil { 48 | cfg.Decoder = DecoderNoAlloc{UserBuffer: make([]byte, 4*1024)} 49 | } 50 | c := &Client{cs: clientState{closeErr: errors.New("yet to connect")}} 51 | c.rx.RxCallbacks, c.tx.TxCallbacks = c.cs.callbacks(onPub) 52 | c.rx.userDecoder = cfg.Decoder 53 | return c 54 | } 55 | 56 | // HandleNext reads from the wire and decodes MQTT packets.
57 | // If bytes are read and the decoder fails to read a packet the whole 58 | // client fails and disconnects. 59 | // HandleNext returns an error if the OnPub callback passed in the ClientConfig returns an error, 60 | // if a packet is malformed, if the transport reports EOF or closure, or on any read error while the client is not connected. 61 | // If HandleNext returns an error the client will be in a disconnected state. 62 | func (c *Client) HandleNext() error { 63 | n, err := c.readNextWrapped() 64 | if err != nil && c.IsConnected() { 65 | if n != 0 || errors.Is(err, io.EOF) || errors.Is(err, net.ErrClosed) { 66 | // We disconnect if: 67 | // - We've read a malformed packet: n!=0 68 | // - We receive an error signalling end of data (EOF) or closing of network connection. 69 | // We don't want to disconnect if we've read 0 bytes and get a timeout error (there may be more data in future) 70 | c.cs.OnDisconnect(err) 71 | c.txlock.Lock() 72 | c.tx.WriteSimple(PacketDisconnect) // Try to write disconnect but don't hold your breath. This is probably useless. 73 | c.txlock.Unlock() 74 | } else { 75 | // None of the above cases: stay connected and ignore the error, but print it. 76 | println("ignoring error", err.Error()) 77 | err = nil 78 | } 79 | } 80 | return err 81 | } 82 | 83 | // readNextWrapped is a separate function so rxlock is held for the minimum amount of time. 84 | func (c *Client) readNextWrapped() (int, error) { 85 | c.rxlock.Lock() 86 | defer c.rxlock.Unlock() 87 | if !c.IsConnected() && c.cs.LastTx().IsZero() { // LastTx locks cs.mu; reading cs.lastTx directly races with OnSuccessfulTx. 88 | // Client disconnected and not expecting to receive packets back. 89 | return 0, errDisconnected 90 | } 91 | return c.rx.ReadNextPacket() 92 | } 93 | 94 | // StartConnect sends a CONNECT packet over the transport and does not wait for a 95 | // CONNACK response. Client is not guaranteed to be connected after a call to this function.
96 | func (c *Client) StartConnect(rwc io.ReadWriteCloser, vc *VariablesConnect) error { 97 | c.txlock.Lock() // Lock order is txlock then rxlock, the same as Disconnect, to avoid an AB/BA deadlock. 98 | defer c.txlock.Unlock() 99 | c.rxlock.Lock() 100 | defer c.rxlock.Unlock() 101 | if c.cs.IsConnected() { // Check before installing rwc so a live connection's transport is left untouched. 102 | return errors.New("already connected; disconnect before connecting") 103 | } 104 | c.tx.SetTxTransport(rwc) 105 | c.rx.SetRxTransport(rwc) 106 | return c.tx.WriteConnect(vc) 107 | } 108 | 109 | // Connect sends a CONNECT packet over the transport and waits for a 110 | // CONNACK response from the server. The client is connected if the returned error is nil. 111 | func (c *Client) Connect(ctx context.Context, rwc io.ReadWriteCloser, vc *VariablesConnect) error { 112 | err := c.StartConnect(rwc, vc) 113 | if err != nil { 114 | return err 115 | } 116 | backoff := newBackoff() 117 | for !c.IsConnected() && ctx.Err() == nil { 118 | backoff.Miss() 119 | err := c.HandleNext() 120 | if err != nil { 121 | return err 122 | } 123 | } 124 | if c.IsConnected() { 125 | return nil 126 | } 127 | return ctx.Err() 128 | } 129 | 130 | // IsConnected returns true if there still has been no disconnect event or an 131 | // unrecoverable error encountered during decoding. 132 | // A Connected client may send and receive MQTT messages. 133 | func (c *Client) IsConnected() bool { return c.cs.IsConnected() } 134 | 135 | // Disconnect sends an MQTT DISCONNECT packet and closes the transport. Future 136 | // calls to Err will return the argument userErr. It panics if userErr is nil. 137 | func (c *Client) Disconnect(userErr error) error { 138 | if userErr == nil { 139 | panic("nil error argument to Disconnect") 140 | } 141 | c.txlock.Lock() 142 | defer c.txlock.Unlock() 143 | if !c.IsConnected() { 144 | return errDisconnected 145 | } 146 | c.cs.OnDisconnect(userErr) 147 | err := c.tx.WriteSimple(PacketDisconnect) 148 | if errors.Is(err, io.EOF) || errors.Is(err, net.ErrClosed) { 149 | err = nil // If EOF or network closed simply exit.
150 | } 151 | c.rxlock.Lock() // Taken after txlock, the same order StartConnect uses. 152 | defer c.rxlock.Unlock() 153 | c.rx.rxTrp.Close() 154 | c.tx.txTrp.Close() 155 | return err 156 | } 157 | 158 | // StartSubscribe begins subscription to argument topics. 159 | func (c *Client) StartSubscribe(vsub VariablesSubscribe) error { 160 | if err := vsub.Validate(); err != nil { 161 | return err 162 | } 163 | c.txlock.Lock() 164 | defer c.txlock.Unlock() 165 | if !c.IsConnected() { 166 | return errDisconnected 167 | } 168 | if c.AwaitingSuback() { // TODO(soypat): Allow multiple subscriptions to be queued. 169 | return errors.New("tried to subscribe while still awaiting suback") 170 | } else if err := c.cs.RegisterSubscribe(vsub); err != nil { // Stores pendingSubs under cs.mu so OnSuback never races this write. 171 | return err 172 | } 173 | return c.tx.WriteSubscribe(vsub) 174 | } 175 | 176 | // Subscribe writes a SUBSCRIBE packet over the network and waits for the server 177 | // to respond with a SUBACK, for the context to end, or for the connection to be lost (errDisconnected). 178 | func (c *Client) Subscribe(ctx context.Context, vsub VariablesSubscribe) error { 179 | session := c.ConnectedAt() 180 | err := c.StartSubscribe(vsub) 181 | if err != nil { 182 | return err 183 | } 184 | backoff := newBackoff() 185 | for c.cs.PendingSublen() != 0 && ctx.Err() == nil && session.Equal(c.ConnectedAt()) { 186 | backoff.Miss() 187 | c.HandleNext() // A failure here disconnects the client, which the session check below reports. 188 | } 189 | if !session.Equal(c.ConnectedAt()) { 190 | // Connection lost or replaced before the SUBACK arrived; disconnecting also clears pendingSubs. 191 | return errDisconnected 192 | } 193 | return ctx.Err() 194 | } 195 | 196 | // SubscribedTopics returns list of topics the client successfully subscribed to. 197 | // Returns a copy of a slice so is safe for concurrent use. 198 | func (c *Client) SubscribedTopics() []string { 199 | c.cs.mu.Lock() 200 | defer c.cs.mu.Unlock() 201 | return append([]string{}, c.cs.activeSubs...) 202 | } 203 | 204 | // PublishPayload sends a PUBLISH packet over the network on the topic defined by 205 | // varPub. Only QoS0 is currently supported.
206 | func (c *Client) PublishPayload(flags PacketFlags, varPub VariablesPublish, payload []byte) error { 207 | if err := varPub.Validate(); err != nil { 208 | return err 209 | } 210 | qos := flags.QoS() 211 | if qos != QoS0 { 212 | return errors.New("only supports QoS0") 213 | } 214 | c.txlock.Lock() 215 | defer c.txlock.Unlock() 216 | if !c.IsConnected() { 217 | return errDisconnected 218 | } 219 | return c.tx.WritePublishPayload(newHeader(PacketPublish, flags, uint32(varPub.Size(qos)+len(payload))), varPub, payload) 220 | } 221 | 222 | // Err returns the cause of the client's disconnection, or nil while connected. 223 | func (c *Client) Err() error { 224 | return c.cs.Err() 225 | } 226 | 227 | // StartPing writes a PINGREQ packet over the network without blocking waiting for response. 228 | func (c *Client) StartPing() error { 229 | c.txlock.Lock() 230 | defer c.txlock.Unlock() 231 | if !c.IsConnected() { 232 | return errDisconnected 233 | } 234 | err := c.tx.WriteSimple(PacketPingreq) 235 | if err == nil { 236 | c.cs.PingSent() // Flag the fact that a ping has been sent successfully. 237 | } 238 | return err 239 | } 240 | 241 | // Ping writes a PINGREQ packet over the network and blocks until the PINGRESP arrives, 242 | // the context ends, or the connection is lost (errDisconnected). It uses an exponential 243 | // backoff algorithm to time checks on the status of the ping. 244 | func (c *Client) Ping(ctx context.Context) error { 245 | session := c.ConnectedAt() 246 | err := c.StartPing() 247 | if err != nil { 248 | return err 249 | } 250 | pingTime := c.cs.LastPingTime() 251 | if pingTime.IsZero() { 252 | return nil // Ping completed. 253 | } 254 | backoff := newBackoff() 255 | for pingTime.Equal(c.cs.LastPingTime()) && ctx.Err() == nil && session.Equal(c.ConnectedAt()) { 256 | backoff.Miss() 257 | c.HandleNext() // A failure here disconnects the client, which the session check below reports.
258 | } 259 | if !session.Equal(c.ConnectedAt()) { 260 | // Connection lost or replaced before the PINGRESP arrived; disconnecting also clears the pending ping. 261 | return errDisconnected 262 | } 263 | return ctx.Err() 264 | } 265 | 266 | // AwaitingPingresp checks if a ping sent over the wire had no response received back. 267 | func (c *Client) AwaitingPingresp() bool { return c.cs.AwaitingPingresp() } 268 | 269 | // ConnectedAt returns the time the client managed to successfully connect. If 270 | // client is disconnected ConnectedAt returns the zero-value for time.Time. 271 | func (c *Client) ConnectedAt() time.Time { return c.cs.ConnectedAt() } 272 | 273 | // AwaitingSuback checks if a subscribe request sent over the wire had no suback received back. 274 | // Returns false if client is disconnected. 275 | func (c *Client) AwaitingSuback() bool { return c.cs.AwaitingSuback() } 276 | 277 | // LastRx returns the time the last packet was received at. 278 | // If Client is disconnected LastRx returns the zero value of time.Time. 279 | func (c *Client) LastRx() time.Time { return c.cs.LastRx() } 280 | 281 | // LastTx returns the time the last successful packet transmission finished at. 282 | // A "successful" transmission does not necessarily mean the packet was received on the other end. 283 | // If Client is disconnected LastTx returns the zero value of time.Time. 284 | func (c *Client) LastTx() time.Time { return c.cs.LastTx() } 285 | 286 | func newBackoff() exponentialBackoff { // Polling backoff capped at 500ms, used by Connect, Subscribe and Ping. 287 | return exponentialBackoff{ 288 | MaxWait: 500 * time.Millisecond, 289 | } 290 | } 291 | 292 | // exponentialBackoff implements an [Exponential Backoff] 293 | // delay algorithm to avoid saturating the network or processor 294 | // with failing tasks. An exponentialBackoff with a non-zero MaxWait is ready for use. 295 | // 296 | // [Exponential Backoff]: https://en.wikipedia.org/wiki/Exponential_backoff 297 | type exponentialBackoff struct { 298 | // Wait defines the amount of time that Miss will wait on next call. 299 | Wait time.Duration 300 | // MaxWait is the maximum allowable value for Wait; it must be non-zero.
301 | MaxWait time.Duration 302 | // StartWait is the value that Wait takes after a call to Hit. 303 | StartWait time.Duration 304 | // ExpMinusOne is the shift performed on Wait minus one, so the zero value performs a shift of 1. 305 | ExpMinusOne uint32 306 | } 307 | 308 | // Hit sets eb.Wait to the StartWait value. It panics if MaxWait is zero. 309 | func (eb *exponentialBackoff) Hit() { 310 | if eb.MaxWait == 0 { 311 | panic("MaxWait cannot be zero") 312 | } 313 | eb.Wait = eb.StartWait 314 | } 315 | 316 | // Miss sleeps for eb.Wait, then sets eb.Wait to (eb.Wait|1)<<(ExpMinusOne+1) capped at MaxWait. It panics if MaxWait is zero. 317 | func (eb *exponentialBackoff) Miss() { 318 | const k = 1 319 | wait := eb.Wait 320 | maxWait := eb.MaxWait 321 | exp := eb.ExpMinusOne + 1 322 | if maxWait == 0 { 323 | panic("MaxWait cannot be zero") 324 | } 325 | time.Sleep(wait) 326 | wait |= time.Duration(k) 327 | wait <<= exp 328 | if wait > maxWait { 329 | wait = maxWait 330 | } 331 | eb.Wait = wait 332 | } 333 | -------------------------------------------------------------------------------- /clientstate.go: -------------------------------------------------------------------------------- 1 | package mqtt 2 | 3 | import ( 4 | "errors" 5 | "io" 6 | "sync" 7 | "time" 8 | ) 9 | 10 | type clientState struct { 11 | mu sync.Mutex 12 | lastRx time.Time 13 | lastTx time.Time 14 | connectedAt time.Time 15 | activeSubs []string 16 | // pendingPingreq is non-zero when a PINGREQ was received from the server and we need to reply. 17 | pendingPingreq time.Time 18 | // pendingPingresp holds the send time of our outstanding PINGREQ; it is zeroed when the PINGRESP arrives. 19 | pendingPingresp time.Time 20 | // closeErr stores the reason for disconnection. 21 | closeErr error 22 | pendingSubs VariablesSubscribe // SUBSCRIBE awaiting its SUBACK; no topic filters when none is pending. 23 | } 24 | 25 | // onConnect is meant to be called on opening a new connection to delete 26 | // previous connection state. It does not lock cs.mu; the caller must hold it.
27 | func (cs *clientState) onConnect(t time.Time) { 28 | cs.closeErr = nil 29 | if cs.activeSubs == nil { 30 | cs.activeSubs = make([]string, 0, 2) // Only the capacity matters: the slice is emptied below either way. 31 | } 32 | cs.activeSubs = cs.activeSubs[:0] // Forget the previous connection's subscriptions but keep the backing array. 33 | cs.lastRx = t 34 | cs.connectedAt = t 35 | cs.pendingSubs = VariablesSubscribe{} 36 | } 37 | 38 | // OnDisconnect locks cs.mu, records err as the reason for disconnection and clears the 39 | // connection timestamps and pending ping/subscribe state. It panics if err is nil. 40 | func (cs *clientState) OnDisconnect(err error) { 41 | cs.mu.Lock() 42 | defer cs.mu.Unlock() 43 | cs.onDisconnect(err) 44 | } 45 | 46 | //go:inline 47 | func (cs *clientState) onDisconnect(err error) { // Lock-free variant of OnDisconnect: the caller must hold cs.mu. 48 | if err == nil { 49 | panic("onDisconnect expects non-nil error") 50 | } 51 | cs.closeErr = err 52 | cs.connectedAt = time.Time{} 53 | cs.lastRx = time.Time{} 54 | cs.lastTx = time.Time{} 55 | cs.pendingPingreq = time.Time{} 56 | cs.pendingPingresp = time.Time{} 57 | cs.pendingSubs = VariablesSubscribe{} 58 | } 59 | 60 | // callbacks returns the Rx and Tx callbacks necessary for a clientState to function automatically.
61 | // The onPub callback may be nil (NewClient passes nil when ClientConfig.OnPub is unset) and is installed unchanged as RxCallbacks.OnPub. 62 | func (cs *clientState) callbacks(onPub func(rx *Rx, varPub VariablesPublish, r io.Reader) error) (RxCallbacks, TxCallbacks) { 63 | return RxCallbacks{ 64 | OnConnack: func(r *Rx, vc VariablesConnack) error { 65 | connTime := time.Now() 66 | cs.mu.Lock() 67 | defer cs.mu.Unlock() 68 | cs.lastRx = connTime 69 | if cs.closeErr == nil { 70 | return errors.New("connack received while connected") 71 | } 72 | if vc.ReturnCode != 0 { 73 | return vc.ReturnCode 74 | } 75 | cs.onConnect(connTime) 76 | return nil 77 | }, 78 | OnPub: onPub, 79 | OnSuback: func(r *Rx, vs VariablesSuback) error { 80 | rxTime := time.Now() 81 | cs.mu.Lock() 82 | defer cs.mu.Unlock() 83 | cs.lastRx = rxTime 84 | if len(vs.ReturnCodes) != len(cs.pendingSubs.TopicFilters) { 85 | return errors.New("got mismatched number of return codes compared to pending client subscriptions") 86 | } 87 | for i, qos := range vs.ReturnCodes { 88 | if qos != QoSSubfail { 89 | if qos > cs.pendingSubs.TopicFilters[i].QoS { // The server may grant a lower QoS than requested (MQTT 3.1.1 section 3.8.4), never a higher one. 90 | return errors.New("granted QoS exceeds requested QoS for topic") 91 | } 92 | cs.activeSubs = append(cs.activeSubs, string(cs.pendingSubs.TopicFilters[i].TopicFilter)) 93 | } 94 | } 95 | cs.pendingSubs.TopicFilters = cs.pendingSubs.TopicFilters[:0] 96 | return nil 97 | }, 98 | OnOther: func(rx *Rx, packetIdentifier uint16) (err error) { 99 | tp := rx.LastReceivedHeader.Type() 100 | rxTime := time.Now() 101 | cs.mu.Lock() 102 | defer cs.mu.Unlock() 103 | cs.lastRx = rxTime 104 | switch tp { 105 | case PacketDisconnect: 106 | err = errDisconnected 107 | case PacketPingreq: 108 | cs.pendingPingreq = rxTime 109 | case PacketPingresp: 110 | cs.pendingPingresp = time.Time{} // got the response, we can unflag.
111 | default: 112 | println("unexpected packet type: ", tp.String()) 113 | } 114 | if err != nil { 115 | cs.onDisconnect(err) 116 | } 117 | return err 118 | }, 119 | OnRxError: func(r *Rx, err error) { 120 | cs.OnDisconnect(err) // The other Rx callbacks lock cs.mu themselves, so it is not held here; take it. 121 | }, 122 | }, TxCallbacks{ 123 | OnTxError: func(tx *Tx, err error) { 124 | cs.OnDisconnect(err) // Like OnSuccessfulTx, this runs without cs.mu held, so take the lock. 125 | }, 126 | OnSuccessfulTx: func(tx *Tx) { 127 | cs.mu.Lock() 128 | defer cs.mu.Unlock() 129 | cs.lastTx = time.Now() 130 | }, 131 | } 132 | } 133 | 134 | // IsConnected returns true if the client is currently connected. 135 | func (cs *clientState) IsConnected() bool { 136 | cs.mu.Lock() 137 | defer cs.mu.Unlock() 138 | if cs.connectedAt.IsZero() != (cs.closeErr != nil) { 139 | panic("assertion failed: bug in natiu-mqtt clientState implementation") 140 | } 141 | return cs.closeErr == nil 142 | } 143 | 144 | // Err returns the error that caused the MQTT connection to finish. 145 | // Returns nil if currently connected. 146 | func (cs *clientState) Err() error { 147 | cs.mu.Lock() 148 | defer cs.mu.Unlock() 149 | if cs.connectedAt.IsZero() != (cs.closeErr != nil) { 150 | panic("assertion failed: bug in natiu-mqtt clientState implementation") 151 | } 152 | return cs.closeErr 153 | } 154 | 155 | // PendingResponse returns true if the client is waiting on the server for a response.
156 | func (cs *clientState) PendingResponse() bool { 157 | cs.mu.Lock() 158 | defer cs.mu.Unlock() 159 | // Waiting on the server means an unanswered SUBSCRIBE or PINGREQ of ours. pendingPingreq 160 | // instead marks a server PINGREQ that this side still has to answer, so it is not checked. 161 | return cs.closeErr == nil && (len(cs.pendingSubs.TopicFilters) > 0 || !cs.pendingPingresp.IsZero()) 162 | } 163 | 164 | // AwaitingPingresp reports whether a PINGREQ was sent and its PINGRESP has not arrived yet. 165 | func (cs *clientState) AwaitingPingresp() bool { 166 | cs.mu.Lock() 167 | defer cs.mu.Unlock() 168 | return !cs.pendingPingresp.IsZero() 169 | } 170 | 171 | // AwaitingSuback reports whether a SUBSCRIBE is still waiting for its SUBACK. 172 | func (cs *clientState) AwaitingSuback() bool { 173 | cs.mu.Lock() 174 | defer cs.mu.Unlock() 175 | return cs.awaitingSuback() 176 | } 177 | 178 | // awaitingSuback is the lock-free variant of AwaitingSuback; the caller must hold cs.mu. 179 | func (cs *clientState) awaitingSuback() bool { 180 | return len(cs.pendingSubs.TopicFilters) > 0 181 | } 182 | 183 | // RegisterSubscribe stores vsub.Copy() as the pending subscription. It fails if vsub 184 | // has no topic filters or if a previous subscription is still awaiting its SUBACK. 185 | func (cs *clientState) RegisterSubscribe(vsub VariablesSubscribe) error { 186 | if len(vsub.TopicFilters) == 0 { 187 | return errors.New("need at least one topic to subscribe") 188 | } 189 | cs.mu.Lock() 190 | defer cs.mu.Unlock() 191 | if cs.awaitingSuback() { 192 | return errors.New("tried to register subscribe while awaiting suback") 193 | } 194 | cs.pendingSubs = vsub.Copy() 195 | return nil 196 | } 197 | 198 | // LastPingTime returns the send time of the outstanding PINGREQ, or the zero 199 | // time when no PINGRESP is pending. 200 | func (cs *clientState) LastPingTime() time.Time { 201 | cs.mu.Lock() 202 | defer cs.mu.Unlock() 203 | return cs.pendingPingresp 204 | } 205 | 206 | // PendingSublen returns the number of topic filters still awaiting a SUBACK. 207 | func (cs *clientState) PendingSublen() int { 208 | cs.mu.Lock() 209 | defer cs.mu.Unlock() 210 | return len(cs.pendingSubs.TopicFilters) 211 | } 212 | 213 | // ConnectedAt returns when the current connection was established, or the zero time if disconnected. 214 | func (cs *clientState) ConnectedAt() time.Time { 215 | cs.mu.Lock() 216 | defer cs.mu.Unlock() 217 | return cs.connectedAt 218 | } 219 | 220 | // LastTx returns the time recorded by the last successful transmission; it is reset to zero on disconnection. 221 | func (cs *clientState) LastTx() time.Time { 222 | cs.mu.Lock() 223 | defer cs.mu.Unlock() 224 | return cs.lastTx 225 | } 226 | 227 | // PingSent records the current time as the send time of an outstanding PINGREQ. 228 | func (cs *clientState) PingSent() { 229 | cs.mu.Lock() 230 | defer cs.mu.Unlock() 231 | cs.pendingPingresp = time.Now() 232 | } 233 | 234 | // LastRx returns the time the last packet was received; it is reset to zero on disconnection. 235 | func (cs *clientState) LastRx() time.Time { 236 | cs.mu.Lock() 237 | defer cs.mu.Unlock() 238 | return cs.lastRx 239 | } 240 | -------------------------------------------------------------------------------- 
/decoder_lowmem.go: -------------------------------------------------------------------------------- 1 | package mqtt 2 | 3 | import ( 4 | "errors" 5 | "io" 6 | ) 7 | 8 | // DecoderNoAlloc implements the [Decoder] interface for unmarshalling variable headers 9 | // of MQTT packets. This particular implementation avoids heap allocations to ensure 10 | // minimal memory usage during decoding. The UserBuffer is used up to its length. 11 | // Decoded strings are sub-slices of UserBuffer, so each Decode call invalidates 12 | // strings decoded in previous calls. 13 | // Needless to say, this implementation is NOT safe for concurrent use. 14 | // Calls that allocate strings or bytes are contained in the [Decoder] interface. 15 | type DecoderNoAlloc struct { 16 | UserBuffer []byte 17 | } 18 | 19 | // DecodeConnect implements [Decoder] interface. 20 | func (d DecoderNoAlloc) DecodeConnect(r io.Reader) (varConn VariablesConnect, n int, err error) { 21 | payloadDst := d.UserBuffer 22 | var ngot int 23 | varConn.Protocol, n, err = decodeMQTTString(r, payloadDst) 24 | if err != nil { 25 | return VariablesConnect{}, n, err 26 | } 27 | payloadDst = payloadDst[len(varConn.Protocol):] 28 | varConn.ProtocolLevel, err = decodeByte(r) 29 | if err != nil { 30 | return VariablesConnect{}, n, err 31 | } 32 | n++ 33 | flags, err := decodeByte(r) 34 | if err != nil { 35 | return VariablesConnect{}, n, err 36 | } 37 | n++ 38 | if flags&1 != 0 { // [MQTT-3.1.2-3].
39 | return VariablesConnect{}, n, errors.New("reserved bit set in CONNECT flag") 40 | } 41 | userNameFlag := flags&(1<<7) != 0 42 | passwordFlag := flags&(1<<6) != 0 43 | varConn.WillRetain = flags&(1<<5) != 0 44 | varConn.WillQoS = QoSLevel(flags>>3) & 0b11 45 | willFlag := flags&(1<<2) != 0 46 | varConn.CleanSession = flags&(1<<1) != 0 47 | if passwordFlag && !userNameFlag { 48 | return VariablesConnect{}, n, errors.New("username flag must be set to use password flag") 49 | } 50 | 51 | varConn.KeepAlive, ngot, err = decodeUint16(r) 52 | n += ngot 53 | if err != nil { 54 | return VariablesConnect{}, n, err 55 | } 56 | varConn.ClientID, ngot, err = decodeMQTTString(r, payloadDst) // NOTE(review): decodeMQTTString rejects zero-length strings, yet MQTT 3.1.1 lets a server accept an empty ClientId; confirm intent. 57 | n += ngot // Count consumed bytes before the error check, like every other field. 58 | if err != nil { 59 | return VariablesConnect{}, n, err 60 | } 61 | payloadDst = payloadDst[len(varConn.ClientID):] 62 | 63 | if willFlag { 64 | varConn.WillTopic, ngot, err = decodeMQTTString(r, payloadDst) 65 | n += ngot 66 | if err != nil { 67 | return VariablesConnect{}, n, err 68 | } 69 | payloadDst = payloadDst[len(varConn.WillTopic):] 70 | varConn.WillMessage, ngot, err = decodeMQTTString(r, payloadDst) 71 | n += ngot 72 | if err != nil { 73 | return VariablesConnect{}, n, err 74 | } 75 | payloadDst = payloadDst[len(varConn.WillMessage):] 76 | } 77 | 78 | if userNameFlag { 79 | // Username and password. 80 | varConn.Username, ngot, err = decodeMQTTString(r, payloadDst) 81 | n += ngot 82 | if err != nil { 83 | return VariablesConnect{}, n, err 84 | } 85 | if passwordFlag { 86 | payloadDst = payloadDst[len(varConn.Username):] 87 | varConn.Password, ngot, err = decodeMQTTString(r, payloadDst) 88 | n += ngot 89 | if err != nil { 90 | return VariablesConnect{}, n, err 91 | } 92 | } 93 | } 94 | return varConn, n, nil 95 | } 96 | 97 | // DecodePublish implements [Decoder] interface.
98 | func (d DecoderNoAlloc) DecodePublish(r io.Reader, qos QoSLevel) (_ VariablesPublish, n int, err error) { 99 | topic, n, err := decodeMQTTString(r, d.UserBuffer) 100 | if err != nil { 101 | return VariablesPublish{}, n, err 102 | } 103 | var PI uint16 104 | if qos == 1 || qos == 2 { 105 | var ngot int 106 | // In these cases PUBLISH contains a variable header which must be decoded. 107 | PI, ngot, err = decodeUint16(r) 108 | n += ngot 109 | if err != nil { // && !errors.Is(err, io.EOF) TODO(soypat): Investigate if it is necessary to guard against io.EOFs on packet ends. 110 | return VariablesPublish{}, n, err 111 | } 112 | } 113 | return VariablesPublish{TopicName: topic, PacketIdentifier: PI}, n, nil 114 | } 115 | 116 | // DecodeSubscribe implements [Decoder] interface. 117 | func (d DecoderNoAlloc) DecodeSubscribe(r io.Reader, remainingLen uint32) (varSub VariablesSubscribe, n int, err error) { 118 | payloadDst := d.UserBuffer 119 | varSub.PacketIdentifier, n, err = decodeUint16(r) 120 | if err != nil { 121 | return VariablesSubscribe{}, n, err 122 | } 123 | for n < int(remainingLen) { 124 | hotTopic, ngot, err := decodeMQTTString(r, payloadDst) 125 | n += ngot 126 | if err != nil { 127 | return VariablesSubscribe{}, n, err 128 | } 129 | payloadDst = payloadDst[len(hotTopic):] // Advance past the stored topic only; ngot also counts the 2-byte length prefix and may exceed len(payloadDst). 130 | qos, err := decodeByte(r) 131 | if err != nil { 132 | return VariablesSubscribe{}, n, err 133 | } 134 | n++ 135 | varSub.TopicFilters = append(varSub.TopicFilters, SubscribeRequest{TopicFilter: hotTopic, QoS: QoSLevel(qos)}) 136 | } 137 | return varSub, n, nil 138 | } 139 | 140 | // DecodeUnsubscribe implements [Decoder] interface.
141 | func (d DecoderNoAlloc) DecodeUnsubscribe(r io.Reader, remainingLength uint32) (varUnsub VariablesUnsubscribe, n int, err error) { 142 | payloadDst := d.UserBuffer 143 | varUnsub.PacketIdentifier, n, err = decodeUint16(r) 144 | if err != nil { 145 | return VariablesUnsubscribe{}, n, err 146 | } 147 | for n < int(remainingLength) { 148 | coldTopic, ngot, err := decodeMQTTString(r, payloadDst) 149 | n += ngot 150 | if err != nil { 151 | return VariablesUnsubscribe{}, n, err 152 | } 153 | payloadDst = payloadDst[len(coldTopic):] // Advance past the stored topic only (ngot includes the 2-byte length prefix). 154 | varUnsub.Topics = append(varUnsub.Topics, coldTopic) 155 | } 156 | return varUnsub, n, nil 157 | } 158 | 159 | // decodeConnack decodes a connack packet. It is the responsibility of the caller to handle a non-zero [ConnectReturnCode]. 160 | func decodeConnack(r io.Reader) (VariablesConnack, int, error) { 161 | var buf [2]byte 162 | n, err := readFull(r, buf[:]) 163 | if err != nil && !(n == len(buf) && errors.Is(err, io.EOF)) { // An io.EOF right after both bytes still yields a complete CONNACK. 164 | return VariablesConnack{}, n, err 165 | } 166 | varConnack := VariablesConnack{AckFlags: buf[0], ReturnCode: ConnectReturnCode(buf[1])} 167 | if err = varConnack.validate(); err != nil { 168 | return VariablesConnack{}, n, err 169 | } 170 | return varConnack, n, nil 171 | } 172 | 173 | // decodeSuback decodes a SUBACK packet. 174 | func decodeSuback(r io.Reader, remainingLen uint32) (varSuback VariablesSuback, n int, err error) { 175 | varSuback.PacketIdentifier, n, err = decodeUint16(r) 176 | if err != nil { 177 | return VariablesSuback{}, n, err 178 | } 179 | for n < int(remainingLen) { 180 | qos, err := decodeByte(r) 181 | if err != nil { 182 | return VariablesSuback{}, n, err 183 | } 184 | n++ 185 | varSuback.ReturnCodes = append(varSuback.ReturnCodes, QoSLevel(qos)) 186 | } 187 | return varSuback, n, nil 188 | } 189 | 190 | // decodeRemainingLength decodes the Remaining Length variable length integer 191 | // in MQTT fixed headers. The encoding is 1 to 4 bytes long; n reports how many bytes were read.
192 | func decodeRemainingLength(r io.Reader) (value uint32, n int, err error) { 193 | multiplier := uint32(1) 194 | for i := 0; i < maxRemainingLengthSize && multiplier <= 128*128*128; i++ { 195 | encodedByte, err := decodeByte(r) 196 | if err != nil { 197 | return 0, n, err 198 | } 199 | n++ 200 | value += uint32(encodedByte&127) * multiplier 201 | if encodedByte&128 == 0 { 202 | return value, n, nil 203 | } 204 | multiplier *= 128 205 | 206 | } 207 | return 0, n, errors.New("malformed remaining length") 208 | } 209 | 210 | func readFull(src io.Reader, dst []byte) (int, error) { 211 | // Like io.ReadFull, keep reading until dst is full, but return the reader's error 212 | // unmodified: io.ReadFull would turn a short read into io.ErrUnexpectedEOF and drop an 213 | // io.EOF that arrives with the final bytes, a case the callers here handle themselves. 214 | n := 0 215 | for n < len(dst) { 216 | ngot, err := src.Read(dst[n:]) 217 | n += ngot 218 | if err != nil { 219 | return n, err 220 | } 221 | } 222 | return n, nil 223 | } 224 | 225 | // decodeMQTTString unmarshals a string from r into buffer's start. The unmarshalled 226 | // string can be at most len(buffer). buffer must be at least of length 2. 227 | // decodeMQTTString only returns a non-nil string on a successful decode.
228 | func decodeMQTTString(r io.Reader, buffer []byte) ([]byte, int, error) { 229 | if len(buffer) < 2 { 230 | return nil, 0, ErrUserBufferFull 231 | } 232 | stringLength, n, err := decodeUint16(r) 233 | if err != nil { 234 | return nil, n, err 235 | } 236 | if stringLength == 0 { 237 | return nil, n, errors.New("zero length MQTT string") 238 | } 239 | if stringLength > uint16(len(buffer)) { 240 | return nil, n, ErrUserBufferFull // errors.New("buffer too small for string of length " + strconv.FormatUint(uint64(stringLength), 10)) 241 | } 242 | ngot, err := readFull(r, buffer[:stringLength]) 243 | n += ngot 244 | if err != nil && errors.Is(err, io.EOF) && uint16(ngot) == stringLength { 245 | err = nil // MQTT string was read successfully albeit with an EOF right at the end. 246 | } 247 | return buffer[:stringLength], n, err 248 | } 249 | 250 | func decodeByte(r io.Reader) (value byte, err error) { 251 | var vbuf [1]byte 252 | n, err := r.Read(vbuf[:]) 253 | if err != nil && errors.Is(err, io.EOF) && n == 1 { 254 | err = nil // Byte was read successfully albeit with an EOF. 255 | } 256 | return vbuf[0], err 257 | } 258 | 259 | func decodeUint16(r io.Reader) (value uint16, n int, err error) { 260 | var vbuf [2]byte 261 | n, err = readFull(r, vbuf[:]) 262 | if err != nil && errors.Is(err, io.EOF) && n == 2 { 263 | err = nil // integer was read successfully albeit with an EOF. 264 | } 265 | return uint16(vbuf[0])<<8 | uint16(vbuf[1]), n, err 266 | } 267 | -------------------------------------------------------------------------------- /definitions.go: -------------------------------------------------------------------------------- 1 | /* 2 | package mqtt implements MQTT v3.1.1 protocol providing users of this package with 3 | low level decoding and encoding primitives and complete documentation sufficient 4 | to grapple with the concepts of the MQTT protocol. 5 | 6 | If you are new to MQTT start by reading definitions.go. 
7 | */ 8 | package mqtt 9 | 10 | const ( 11 | // Accepted protocol level as per MQTT v3.1.1. This goes in the CONNECT variable header. 12 | DefaultProtocolLevel = 4 13 | // Accepted protocol as per MQTT v3.1.1. This goes in the CONNECT variable header. 14 | DefaultProtocol = "MQTT" 15 | // Size on wire after being encoded. 16 | maxRemainingLengthSize = 4 17 | // Max value Remaining Length can take 0xfff_ffff. When encoded over the wire this value yields 0xffff_ff7f. 18 | maxRemainingLengthValue = 0xfff_ffff 19 | ) 20 | 21 | // Reserved flags for PUBREL, SUBSCRIBE and UNSUBSCRIBE packet types. 22 | // This is effectively a PUBLISH flag with QoS1 set and no DUP or RETAIN bits. 23 | const PacketFlagsPubrelSubUnsub PacketFlags = 0b10 24 | 25 | // PacketType represents the 4 MSB bits in the first byte in an MQTT fixed header. 26 | // takes on values 1..14. PacketType and PacketFlags are present in all MQTT packets. 27 | type PacketType byte 28 | 29 | const ( 30 | // 0 Forbidden/Reserved 31 | _ PacketType = iota 32 | // A CONNECT packet is sent from Client to Server, it is a Client request to connect to a Server. 33 | // After a network connection is established by a client to a server at the transport layer, the first 34 | // packet sent from the client to the server must be a Connect packet. 35 | // A Client can only send the CONNECT Packet once over a Network Connection. 36 | // The CONNECT packet contains a 10 byte variable header and a 37 | // payload determined by flags present in variable header. See [VariablesConnect]. 0x10. 38 | PacketConnect 39 | // The CONNACK Packet is the packet sent by the Server in response to a CONNECT Packet received from a Client. 40 | // The first packet sent from the Server to the Client MUST be a CONNACK Packet 41 | // The payload contains a 2 byte variable header and no payload. 0x20. 42 | PacketConnack 43 | // A PUBLISH Control Packet is sent from a Client to a Server or from Server to a Client to transport an Application Message. 
44 | // It's payload contains a variable header with a MQTT encoded string for the topic name and a packet identifier. 45 | // The payload may or may not contain a Application Message that is being published. The length of this Message 46 | // can be calculated by subtracting the length of the variable header from the Remaining Length field that is in the Fixed Header. 0x3?. 47 | PacketPublish 48 | // A PUBACK Packet is the response to a PUBLISH Packet with QoS level 1. It's Variable header contains the packet identifier. No payload. 0x40. 49 | PacketPuback 50 | // A PUBREC Packet is the response to a PUBLISH Packet with QoS 2. It is the second packet of the QoS 2 protocol exchange. It's Variable header contains the packet identifier. No payload. 0x50. 51 | PacketPubrec 52 | // A PUBREL Packet is the response to a PUBREC Packet. It is the third packet of the QoS 2 protocol exchange. It's Variable header contains the packet identifier. No payload. 0x62. 53 | PacketPubrel 54 | // The PUBCOMP Packet is the response to a PUBREL Packet. It is the fourth and final packet of the QoS 2 protocol exchange. It's Variable header contains the packet identifier. No payload. 0x70. 55 | PacketPubcomp 56 | // The SUBSCRIBE Packet is sent from the Client to the Server to create one or more Subscriptions. 57 | // Each Subscription registers a Client’s interest in one or more Topics. The Server sends PUBLISH 58 | // Packets to the Client in order to forward Application Messages that were published to Topics that match these Subscriptions. 59 | // The SUBSCRIBE Packet also specifies (for each Subscription) the maximum QoS with which the Server can 60 | // send Application Messages to the Client. 61 | // The variable header of a subscribe topic contains the packet identifier. The payload contains a list of topic filters, see [VariablesSubscribe]. 0x82. 
62 | PacketSubscribe 63 | // A SUBACK Packet is sent by the Server to the Client to confirm receipt and processing of a SUBSCRIBE Packet. 64 | // The variable header contains the packet identifier. The payload contains a list of octet return codes for each subscription requested by client, see [VariablesSuback]. 0x90. 65 | PacketSuback 66 | // An UNSUBSCRIBE Packet is sent by the Client to the Server, to unsubscribe from topics. 67 | // The variable header contains the packet identifier. Its payload contains a list of mqtt encoded strings corresponding to unsubscribed topics, see [VariablesUnsubscribe]. 0xa2. 68 | PacketUnsubscribe 69 | // The UNSUBACK Packet is sent by the Server to the Client to confirm receipt of an UNSUBSCRIBE Packet. 70 | // The variable header contains the packet identifier. It has no payload. 0xb0. 71 | PacketUnsuback 72 | // The PINGREQ Packet is sent from a Client to the Server. It can be used to: 73 | // - Indicate to the Server that the Client is alive in the absence of any other Control Packets being sent from the Client to the Server. 74 | // - Request that the Server responds to confirm that it is alive. 75 | // - Exercise the network to indicate that the Network Connection is active. 76 | // No payload or variable header. 0xc0. 77 | PacketPingreq 78 | // A PINGRESP Packet is sent by the Server to the Client in response to a PINGREQ Packet. It indicates that the Server is alive. 79 | // No payload or variable header. 0xd0. 80 | PacketPingresp 81 | // The DISCONNECT Packet is the final Control Packet sent from the Client to the Server. It indicates that the Client is disconnecting cleanly. 82 | // No payload or variable header. 0xe0. 83 | PacketDisconnect 84 | ) 85 | 86 | // QoSLevel represents the Quality of Service specified by the client. 87 | // The server can choose to provide or reject requested QoS. The values 88 | // of QoS range from 0 to 2, each representing a different methodology for 89 | // message delivery guarantees. 
type QoSLevel uint8

// QoS indicates the level of assurance for packet delivery.
const (
	// QoS0 at most once delivery. Arrives either once or not at all. Depends on capabilities of underlying network.
	QoS0 QoSLevel = iota
	// QoS1 at least once delivery. Ensures message arrives at receiver at least once.
	QoS1
	// QoS2 Exactly once delivery. Highest quality service. For use when neither loss nor duplication of messages are acceptable.
	// There is an increased overhead associated with this quality of service.
	QoS2
	// Reserved, must not be used.
	reservedQoS3
	// QoSSubfail marks a failure in SUBACK. This value cannot be encoded into a header
	// and is only returned upon an unsuccessful subscribe to a topic in an SUBACK packet.
	QoSSubfail QoSLevel = 0x80
)

// ConnectReturnCode represents the CONNACK return code, which is the second byte in the variable header.
// It indicates if the connection was successful (0 value) or if the connection attempt failed on the server side.
// ConnectReturnCode also implements the error interface and can be returned on a failed connection.
type ConnectReturnCode uint8

// CONNACK return codes as defined in MQTT v3.1.1 section 3.2.2.3.
// Values from minInvalidReturnCode (6) up to 255 are reserved.
const (
	// ReturnCodeConnAccepted (0): the connection was accepted.
	ReturnCodeConnAccepted ConnectReturnCode = iota
	// ReturnCodeUnnaceptableProtocol (1): the Server does not support the level of the MQTT protocol requested by the Client.
	ReturnCodeUnnaceptableProtocol
	// ReturnCodeIdentifierRejected (2): the Client identifier is correct UTF-8 but not allowed by the Server.
	ReturnCodeIdentifierRejected
	// ReturnCodeServerUnavailable (3): the Network Connection has been made but the MQTT service is unavailable.
	ReturnCodeServerUnavailable
	// ReturnCodeBadUserCredentials (4): the data in the user name or password is malformed.
	ReturnCodeBadUserCredentials
	// ReturnCodeUnauthorized (5): the Client is not authorized to connect.
	ReturnCodeUnauthorized
	// minInvalidReturnCode is the first return code value not defined by MQTT v3.1.1.
	minInvalidReturnCode
)

// Error implements the error interface for a non-zero return code.
124 | func (rc ConnectReturnCode) Error() string { return rc.String() } 125 | -------------------------------------------------------------------------------- /encode.go: -------------------------------------------------------------------------------- 1 | package mqtt 2 | 3 | import ( 4 | "bytes" 5 | "encoding/binary" 6 | "errors" 7 | "io" 8 | "math" 9 | ) 10 | 11 | // Encode encodes the header into the argument writer. It will encode up to a maximum 12 | // of 7 bytes, which is the max length header in MQTT v3.1. 13 | func (h Header) Encode(w io.Writer) (n int, err error) { 14 | if h.RemainingLength > maxRemainingLengthValue { 15 | return 0, errors.New("remaining length too large for MQTT v3.1.1 spec") 16 | } 17 | var headerBuf [5]byte 18 | n = h.Put(headerBuf[:]) 19 | return writeFull(w, headerBuf[:n]) 20 | } 21 | 22 | func (h Header) Put(buf []byte) int { 23 | _ = buf[4] 24 | buf[0] = h.firstByte 25 | return encodeRemainingLength(h.RemainingLength, buf[1:]) + 1 26 | } 27 | 28 | func encodeMQTTString(w io.Writer, s []byte) (int, error) { 29 | length := len(s) 30 | if length == 0 { 31 | return 0, errors.New("cannot encode MQTT string of length 0") 32 | } 33 | if length > math.MaxUint16 { 34 | return 0, errors.New("cannot encode MQTT string of length > MaxUint16 or length 0") 35 | } 36 | n, err := encodeUint16(w, uint16(len(s))) 37 | if err != nil { 38 | return n, err 39 | } 40 | n2, err := writeFull(w, s) 41 | n += n2 42 | if err != nil { 43 | return n, err 44 | } 45 | return n, nil 46 | } 47 | 48 | // encodeRemainingLength encodes between 1 to 4 bytes. 49 | func encodeRemainingLength(remlen uint32, b []byte) (n int) { 50 | if remlen > maxRemainingLengthValue { 51 | panic("remaining length too large. " + bugReportLink) 52 | } 53 | if remlen < 128 { 54 | // Fast path for small remaining lengths. Also the implementation below is not correct for remaining length = 0. 
55 | b[0] = byte(remlen) 56 | return 1 57 | } 58 | 59 | for n = 0; remlen > 0; n++ { 60 | encoded := byte(remlen % 128) 61 | remlen /= 128 62 | if remlen > 0 { 63 | encoded |= 128 64 | } 65 | b[n] = encoded 66 | } 67 | return n 68 | } 69 | 70 | // All encode{PacketType} functions encode only their variable header. 71 | 72 | // encodeConnect encodes a CONNECT packet variable header over w given connVars. Does not encode 73 | // either the fixed header or the Packet Payload. 74 | func encodeConnect(w io.Writer, varConn *VariablesConnect) (n int, err error) { 75 | // Begin encoding variable header buffer. 76 | var varHeaderBuf [10]byte 77 | // Set protocol name 'MQTT' and protocol level 4. 78 | n += copy(varHeaderBuf[:], "\x00\x04MQTT\x04") // writes 7 bytes. 79 | varHeaderBuf[n] = varConn.Flags() 80 | varHeaderBuf[n+1] = byte(varConn.KeepAlive >> 8) // MSB 81 | varHeaderBuf[n+2] = byte(varConn.KeepAlive) // LSB 82 | // n+=3 // We've written 10 bytes exactly if all went well up to here. 83 | n, err = w.Write(varHeaderBuf[:]) 84 | if err == nil && n != 10 { 85 | return n, errors.New("single write did not complete for encoding, use larger underlying buffer") 86 | } 87 | if err != nil { 88 | return n, err 89 | } 90 | // Begin Encoding payload contents. First field is ClientID. 91 | ngot, err := encodeMQTTString(w, varConn.ClientID) 92 | n += ngot 93 | if err != nil { 94 | return n, err 95 | } 96 | 97 | if varConn.WillFlag() { 98 | ngot, err = encodeMQTTString(w, varConn.WillTopic) 99 | n += ngot 100 | if err != nil { 101 | return n, err 102 | } 103 | ngot, err = encodeMQTTString(w, varConn.WillMessage) 104 | n += ngot 105 | if err != nil { 106 | return n, err 107 | } 108 | } 109 | 110 | if len(varConn.Username) != 0 { 111 | // Username and password. 
112 | ngot, err = encodeMQTTString(w, varConn.Username) 113 | n += ngot 114 | if err != nil { 115 | return n, err 116 | } 117 | if len(varConn.Password) != 0 { 118 | ngot, err = encodeMQTTString(w, varConn.Password) 119 | n += ngot 120 | if err != nil { 121 | return n, err 122 | } 123 | } 124 | } 125 | return n, nil 126 | } 127 | 128 | func encodeConnack(w io.Writer, varConn VariablesConnack) (int, error) { 129 | var buf [2]byte 130 | buf[0] = varConn.AckFlags 131 | buf[1] = byte(varConn.ReturnCode) 132 | return writeFull(w, buf[:]) 133 | } 134 | 135 | // encodePublish encodes PUBLISH packet variable header. Does not encode fixed header or user payload. 136 | func encodePublish(w io.Writer, qos QoSLevel, varPub VariablesPublish) (n int, err error) { 137 | n, err = encodeMQTTString(w, varPub.TopicName) 138 | if err != nil { 139 | return n, err 140 | } 141 | if qos != QoS0 { 142 | ngot, err := encodeUint16(w, varPub.PacketIdentifier) 143 | n += ngot 144 | if err != nil { 145 | return n, err 146 | } 147 | } 148 | return n, err 149 | } 150 | 151 | func encodeByte(w io.Writer, value byte) (n int, err error) { 152 | var vbuf [1]byte 153 | vbuf[0] = value 154 | return w.Write(vbuf[:]) 155 | } 156 | 157 | func encodeUint16(w io.Writer, value uint16) (n int, err error) { 158 | var vbuf [2]byte 159 | binary.BigEndian.PutUint16(vbuf[:], value) 160 | return writeFull(w, vbuf[:]) 161 | } 162 | 163 | func encodeSubscribe(w io.Writer, varSub VariablesSubscribe) (n int, err error) { 164 | if len(varSub.TopicFilters) == 0 { 165 | return 0, errors.New("payload of SUBSCRIBE must contain at least one topic filter / QoS pair") 166 | } 167 | n, err = encodeUint16(w, varSub.PacketIdentifier) 168 | if err != nil { 169 | return n, err 170 | } 171 | var vbuf [1]byte 172 | for _, hotTopic := range varSub.TopicFilters { 173 | ngot, err := encodeMQTTString(w, hotTopic.TopicFilter) 174 | n += ngot 175 | if err != nil { 176 | return n, err 177 | } 178 | vbuf[0] = byte(hotTopic.QoS & 0b11) 179 | 
ngot, err = w.Write(vbuf[:1]) 180 | n += ngot 181 | if err != nil { 182 | return n, err 183 | } 184 | } 185 | return n, nil 186 | } 187 | 188 | func encodeSuback(w io.Writer, varSuback VariablesSuback) (n int, err error) { 189 | n, err = encodeUint16(w, varSuback.PacketIdentifier) 190 | if err != nil { 191 | return n, err 192 | } 193 | for _, qos := range varSuback.ReturnCodes { 194 | if !qos.IsValid() && qos != QoSSubfail { // Suback can encode a subfail. 195 | panic("encodeSuback received an invalid QoS return code. " + bugReportLink) 196 | } 197 | ngot, err := encodeByte(w, byte(qos)) 198 | n += ngot 199 | if err != nil { 200 | return n, err 201 | } 202 | } 203 | return n, nil 204 | } 205 | 206 | func encodeUnsubscribe(w io.Writer, varUnsub VariablesUnsubscribe) (n int, err error) { 207 | if len(varUnsub.Topics) == 0 { 208 | return 0, errors.New("payload of UNSUBSCRIBE must contain at least one topic") 209 | } 210 | n, err = encodeUint16(w, varUnsub.PacketIdentifier) 211 | if err != nil { 212 | return n, err 213 | } 214 | for _, coldTopic := range varUnsub.Topics { 215 | ngot, err := encodeMQTTString(w, coldTopic) 216 | n += ngot 217 | if err != nil { 218 | return n, err 219 | } 220 | } 221 | return n, nil 222 | } 223 | 224 | // Pings and DISCONNECT do not have variable headers so no encoders here. 225 | 226 | func writeFull(dst io.Writer, src []byte) (int, error) { 227 | // dataPtr := 0 228 | n, err := dst.Write(src) 229 | if err == nil && n != len(src) { 230 | // TODO(soypat): Avoid heavy heap allocation by implementing lightweight algorithm here. 
231 | var buffer [256]byte 232 | i, err := io.CopyBuffer(dst, bytes.NewBuffer(src[n:]), buffer[:]) 233 | return n + int(i), err 234 | } 235 | return n, err 236 | } 237 | 238 | // bool to uint8 239 | // 240 | //go:inline 241 | func b2u8(b bool) uint8 { 242 | if b { 243 | return 1 244 | } 245 | return 0 246 | } 247 | -------------------------------------------------------------------------------- /example_test.go: -------------------------------------------------------------------------------- 1 | package mqtt_test 2 | 3 | import ( 4 | "bytes" 5 | "context" 6 | "errors" 7 | "fmt" 8 | "io" 9 | "log" 10 | "math/rand" 11 | "net" 12 | "time" 13 | 14 | mqtt "github.com/soypat/natiu-mqtt" 15 | ) 16 | 17 | func ExampleClient_concurrent() { 18 | // Create new client. 19 | received := make(chan []byte, 10) 20 | client := mqtt.NewClient(mqtt.ClientConfig{ 21 | Decoder: mqtt.DecoderNoAlloc{make([]byte, 1500)}, 22 | OnPub: func(_ mqtt.Header, _ mqtt.VariablesPublish, r io.Reader) error { 23 | message, _ := io.ReadAll(r) 24 | if len(message) > 0 { 25 | select { 26 | case received <- message: 27 | default: 28 | // If channel is full we ignore message. 29 | } 30 | } 31 | log.Println("received message:", string(message)) 32 | return nil 33 | }, 34 | }) 35 | const TOPICNAME = "/mqttnerds" 36 | // Set the connection parameters and set the Client ID to "salamanca". 37 | var varConn mqtt.VariablesConnect 38 | varConn.SetDefaultMQTT([]byte("salamanca")) 39 | rng := rand.New(rand.NewSource(1)) 40 | 41 | // Define an inline function that connects the MQTT client automatically. 42 | // Is inline so it is contained within example. 43 | tryConnect := func() error { 44 | // Get a transport for MQTT packets using the local host and default MQTT port (1883). 
45 | conn, err := net.Dial("tcp", "127.0.0.1:1883") 46 | if err != nil { 47 | return err 48 | } 49 | ctx, cancel := context.WithTimeout(context.Background(), 4*time.Second) 50 | defer cancel() 51 | err = client.Connect(ctx, conn, &varConn) // Connect to server. 52 | if err != nil { 53 | return err 54 | } 55 | 56 | // On succesful connection subscribe to topic. 57 | ctx, cancel = context.WithTimeout(context.Background(), 4*time.Second) 58 | defer cancel() 59 | vsub := mqtt.VariablesSubscribe{ 60 | TopicFilters: []mqtt.SubscribeRequest{ 61 | {TopicFilter: []byte(TOPICNAME), QoS: mqtt.QoS0}, // Only support QoS0 for now. 62 | }, 63 | PacketIdentifier: uint16(rng.Int31()), 64 | } 65 | return client.Subscribe(ctx, vsub) 66 | } 67 | 68 | // Attempt first connection and fail immediately if that does not work. 69 | err := tryConnect() 70 | if err != nil { 71 | log.Println(err) 72 | return 73 | } 74 | 75 | // Call read goroutine. Read goroutine will also handle reconnection 76 | // when client disconnects. 77 | go func() { 78 | for { 79 | if !client.IsConnected() { 80 | time.Sleep(time.Second) 81 | tryConnect() 82 | continue 83 | } 84 | err = client.HandleNext() 85 | if err != nil { 86 | log.Println("HandleNext failed:", err) 87 | } 88 | } 89 | }() 90 | 91 | // Call Write goroutine and create a channel to serialize messages 92 | // that we want to send out. 93 | pubFlags, _ := mqtt.NewPublishFlags(mqtt.QoS0, false, false) 94 | varPub := mqtt.VariablesPublish{ 95 | TopicName: []byte(TOPICNAME), 96 | } 97 | txQueue := make(chan []byte, 10) 98 | go func() { 99 | for { 100 | if !client.IsConnected() { 101 | time.Sleep(time.Second) 102 | continue 103 | } 104 | message := <-txQueue 105 | varPub.PacketIdentifier = uint16(rng.Int()) 106 | // Loop until message is sent successfully. This guarantees 107 | // all messages are sent, even in events of disconnect. 
108 | for { 109 | err := client.PublishPayload(pubFlags, varPub, message) 110 | if err == nil { 111 | break 112 | } 113 | time.Sleep(time.Second) 114 | } 115 | } 116 | }() 117 | 118 | // Main program logic. 119 | for { 120 | message := <-received 121 | // We transform the message and send it back out. 122 | fields := bytes.Fields(message) 123 | message = bytes.Join(fields, []byte(",")) 124 | txQueue <- message 125 | } 126 | } 127 | 128 | func ExampleClient() { 129 | // Create new client with default settings. 130 | client := mqtt.NewClient(mqtt.ClientConfig{}) 131 | 132 | // Get a transport for MQTT packets. 133 | const defaultMQTTPort = ":1883" 134 | conn, err := net.Dial("tcp", "test.mosquitto.org"+defaultMQTTPort) 135 | if err != nil { 136 | fmt.Println(err) 137 | return 138 | } 139 | 140 | // Prepare for CONNECT interaction with server. 141 | var varConn mqtt.VariablesConnect 142 | varConn.SetDefaultMQTT([]byte("salamanca")) 143 | ctx, cancel := context.WithTimeout(context.Background(), 4*time.Second) 144 | err = client.Connect(ctx, conn, &varConn) // Connect to server. 145 | cancel() 146 | if err != nil { 147 | // Error or loop until connect success. 148 | log.Fatalf("connect attempt failed: %v\n", err) 149 | } 150 | fmt.Println("connection success") 151 | 152 | defer func() { 153 | err := client.Disconnect(errors.New("end of test")) 154 | if err != nil { 155 | fmt.Println("disconnect failed:", err) 156 | } 157 | }() 158 | 159 | // Ping forever until error. 160 | ctx, cancel = context.WithTimeout(context.Background(), time.Second) 161 | pingErr := client.Ping(ctx) 162 | cancel() 163 | if pingErr != nil { 164 | log.Fatal("ping error: ", pingErr, " with disconnect reason:", client.Err()) 165 | } 166 | fmt.Println("ping success!") 167 | // Output: 168 | // connection success 169 | // ping success! 
170 | } 171 | 172 | func ExampleRxTx() { 173 | const defaultMQTTPort = ":1883" 174 | conn, err := net.Dial("tcp", "127.0.0.1"+defaultMQTTPort) 175 | if err != nil { 176 | log.Fatal(err) 177 | } 178 | rxtx, err := mqtt.NewRxTx(conn, mqtt.DecoderNoAlloc{UserBuffer: make([]byte, 1500)}) 179 | if err != nil { 180 | log.Fatal(err) 181 | } 182 | rxtx.RxCallbacks.OnConnack = func(rt *mqtt.Rx, vc mqtt.VariablesConnack) error { 183 | log.Printf("%v received, SP=%v, rc=%v", rt.LastReceivedHeader.String(), vc.SessionPresent(), vc.ReturnCode.String()) 184 | return nil 185 | } 186 | // PacketFlags set automatically for all packets that are not PUBLISH. So set to 0. 187 | varConnect := mqtt.VariablesConnect{ 188 | ClientID: []byte("salamanca"), 189 | Protocol: []byte("MQTT"), 190 | ProtocolLevel: 4, 191 | KeepAlive: 60, 192 | CleanSession: true, 193 | WillMessage: []byte("MQTT is okay, I guess"), 194 | WillTopic: []byte("mqttnerds"), 195 | WillRetain: true, 196 | } 197 | err = rxtx.WriteConnect(&varConnect) 198 | if err != nil { 199 | log.Fatal(err) 200 | } 201 | } 202 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/soypat/natiu-mqtt 2 | 3 | go 1.19 4 | -------------------------------------------------------------------------------- /mqtt.go: -------------------------------------------------------------------------------- 1 | package mqtt 2 | 3 | import ( 4 | "errors" 5 | "io" 6 | "strconv" 7 | ) 8 | 9 | // Decoder provides an abstraction for an MQTT variable header decoding implementation. 10 | // This is because heap allocations are necessary to be able to decode any MQTT packet. 11 | // Some compile targets are restrictive in terms of memory usage, so the best decoder for the situation may differ. 
12 | type Decoder interface { 13 | // TODO(soypat): The CONNACK and SUBACK decoders can probably be excluded 14 | // from this interface since they do not need heap allocations, or if they 15 | // do end uf allocating their allocations are short lived, within scope of function. 16 | 17 | // DecodeConnack(r io.Reader) (VariablesConnack, int, error) 18 | 19 | // DecodeSuback(r io.Reader, remainingLen uint32) (VariablesSuback, int, error) 20 | 21 | DecodePublish(r io.Reader, qos QoSLevel) (VariablesPublish, int, error) 22 | DecodeConnect(r io.Reader) (VariablesConnect, int, error) 23 | DecodeSubscribe(r io.Reader, remainingLen uint32) (VariablesSubscribe, int, error) 24 | DecodeUnsubscribe(r io.Reader, remainingLength uint32) (VariablesUnsubscribe, int, error) 25 | } 26 | 27 | const bugReportLink = "Please report bugs at https://github.com/soypat/natiu-mqtt/issues/new " 28 | 29 | var ( 30 | errQoS0NoDup = errors.New("DUP must be 0 for all QoS0 [MQTT-3.3.1-2]") 31 | errEmptyTopic = errors.New("empty topic") 32 | errGotZeroPI = errors.New("packet identifier must be nonzero for packet type") 33 | 34 | // natiu-mqtt depends on user provided buffers for string and byte slice allocation. 35 | // If a buffer is too small for the incoming strings or for marshalling a subscription topic 36 | // then the implementation should return this error. 37 | ErrUserBufferFull = errors.New("natiu-mqtt: user buffer full") 38 | // ErrBadRemainingLen is passed to Rx's OnRxError after decoding a header with a 39 | // remaining length that does not conform to MQTT v3.1.1 packet specifications. 40 | ErrBadRemainingLen = errors.New("natiu-mqtt: MQTT v3.1.1 bad remaining length") 41 | ) 42 | 43 | // Header represents the bytes preceding the payload in an MQTT packet. 44 | // This commonly called the Fixed Header, although this Header type also contains 45 | // PacketIdentifier, which is part of the Variable Header and may or may not be present 46 | // in an MQTT packet. 
47 | type Header struct { 48 | RemainingLength uint32 49 | // firstByte contains packet type in MSB bits 7-4 and flags in LSB bits 3-0. 50 | firstByte byte 51 | } 52 | 53 | // Size returns the size of the header as encoded over the wire. If the remaining 54 | // length is invalid Size returns 0. 55 | func (h Header) Size() (sz int) { 56 | rl := h.RemainingLength 57 | switch { 58 | case rl <= 0x7F: 59 | sz = 2 60 | case rl <= 0xff7f: 61 | sz = 3 62 | case rl <= 0xffff_7f: 63 | sz = 4 64 | case rl < maxRemainingLengthValue: 65 | sz = 5 66 | default: 67 | // sz = 0 // Not needed since sz's default value is zero. 68 | } 69 | return sz 70 | } 71 | 72 | // HasPacketIdentifier returns true if the MQTT packet has a 2 octet packet identifier number. 73 | func (h Header) HasPacketIdentifier() bool { 74 | tp := h.Type() 75 | qos := h.Flags().QoS() 76 | if tp == PacketPublish && (qos == QoS1 || qos == QoS2) { 77 | return true 78 | } 79 | noPI := tp == PacketConnect || tp == PacketConnack || 80 | tp == PacketPingreq || tp == PacketPingresp || tp == PacketDisconnect || tp == PacketPublish 81 | return tp != 0 && tp < 15 && !noPI 82 | } 83 | 84 | // PacketFlags represents the LSB 4 bits in the first byte in an MQTT fixed header. 85 | // PacketFlags takes on select values in range 1..15. PacketType and PacketFlags are present in all MQTT packets. 86 | type PacketFlags uint8 87 | 88 | // QoS returns the PUBLISH QoSLevel in pf which varies between 0..2. 89 | // PUBREL, UNSUBSCRIBE and SUBSCRIBE packets MUST have QoS1 set by standard. 90 | // Other packets will have a QoS1 set. 91 | func (pf PacketFlags) QoS() QoSLevel { return QoSLevel((pf >> 1) & 0b11) } 92 | 93 | // QoS returns true if the PUBLISH Retain bit is set. This typically is set by the client 94 | // to indicate the packet must be preserved after a Session ends which is to say Retained packets do not form part of Session state. 
95 | func (pf PacketFlags) Retain() bool { return pf&1 != 0 } 96 | 97 | // Dup returns true if the DUP flag bit is set. 98 | // If the DUP flag is set to 0, it indicates that this is the first occasion that the Client or Server has attempted to send this MQTT PUBLISH Packet. 99 | func (pf PacketFlags) Dup() bool { return pf&(1<<3) != 0 } 100 | 101 | // String returns a pretty string representation of pf. Allocates memory. 102 | func (pf PacketFlags) String() string { 103 | if pf > 15 { 104 | return "invalid packet flags" 105 | } 106 | s := pf.QoS().String() 107 | if pf.Dup() { 108 | s += "/DUP" 109 | } 110 | if pf.Retain() { 111 | s += "/RET" 112 | } 113 | return s 114 | } 115 | 116 | // NewPublishFlags returns PUBLISH packet flags and an error if the flags were 117 | // to create a malformed packet according to MQTT specification. 118 | func NewPublishFlags(qos QoSLevel, dup, retain bool) (PacketFlags, error) { 119 | if qos > QoS2 { 120 | return 0, errors.New("invalid QoS") 121 | } 122 | if dup && qos == QoS0 { 123 | return 0, errQoS0NoDup 124 | } 125 | return PacketFlags(b2u8(retain) | (b2u8(dup) << 3) | uint8(qos<<1)), nil 126 | } 127 | 128 | // NewHeader creates a new Header for a packetType and returns an error if invalid 129 | // arguments are passed in. It will set expected reserved flags for non-PUBLISH packets. 130 | func NewHeader(packetType PacketType, packetFlags PacketFlags, remainingLen uint32) (Header, error) { 131 | if packetType != PacketPublish { 132 | // Set reserved flag for non-publish packets. 
133 | ctlBit := b2u8(packetType == PacketPubrel || packetType == PacketSubscribe || packetType == PacketUnsubscribe) 134 | packetFlags = PacketFlags(ctlBit << 1) 135 | } 136 | if packetFlags > 15 { 137 | return Header{}, errors.New("packet flags exceeds 4 bit range 0..15") 138 | } 139 | if packetType > 15 { 140 | return Header{}, errors.New("packet type exceeds 4 bit range 0..15") 141 | } 142 | h := newHeader(packetType, packetFlags, remainingLen) 143 | if err := h.Validate(); err != nil { 144 | return Header{}, err 145 | } 146 | return h, nil 147 | } 148 | 149 | // newHeader returns a header with the argument type, flags and remaining length. 150 | // For internal use. This function performs no validation whatsoever. 151 | func newHeader(pt PacketType, pf PacketFlags, rlen uint32) Header { 152 | return Header{ 153 | firstByte: byte(pt)<<4 | byte(pf), 154 | RemainingLength: rlen, 155 | } 156 | } 157 | 158 | // Validate returns an error if the Header contains malformed data. This usually means 159 | // the header has bits set that contradict "MUST" statements in MQTT's protocol specification. 160 | func (h Header) Validate() error { 161 | pflags := h.Flags() 162 | ptype := h.Type() 163 | err := ptype.validateFlags(pflags) 164 | if err != nil { 165 | return err 166 | } 167 | if ptype == PacketPublish { 168 | dup := pflags.Dup() 169 | qos := pflags.QoS() 170 | if qos > QoS2 { 171 | return errors.New("invalid QoS") 172 | } 173 | if dup && qos == QoS0 { 174 | return errQoS0NoDup 175 | } 176 | } 177 | return nil 178 | } 179 | 180 | // Flags returns the MQTT packet flags in the fixed header. Important mainly for PUBLISH packets. 181 | func (h Header) Flags() PacketFlags { return PacketFlags(h.firstByte & 0b1111) } 182 | 183 | // Type returns the packet type with no validation. 184 | func (h Header) Type() PacketType { return PacketType(h.firstByte >> 4) } 185 | 186 | // String returns a pretty-string representation of h. Allocates memory. 
187 | func (h Header) String() string { 188 | return h.Type().String() + " " + h.Flags().String() + " remlen: 0x" + strconv.FormatUint(uint64(h.RemainingLength), 16) 189 | } 190 | 191 | // PacketType lists in definitions.go 192 | 193 | func (p PacketType) validateFlags(flag4bits PacketFlags) error { 194 | onlyBit1Set := flag4bits&^(1<<1) == 0 195 | isControlPacket := p == PacketPubrel || p == PacketSubscribe || p == PacketUnsubscribe 196 | if p == PacketPublish || (onlyBit1Set && isControlPacket) || (!isControlPacket && flag4bits == 0) { 197 | return nil 198 | } 199 | if isControlPacket { 200 | return errors.New("control packet bit not set (0b0010)") 201 | } 202 | return errors.New("expected 0b0000 flag for packet type") 203 | } 204 | 205 | // String returns a string representation of the packet type, stylized with all caps 206 | // i.e: "PUBREL", "CONNECT". Does not allocate memory. 207 | func (p PacketType) String() string { 208 | if p > 15 { 209 | return "impossible packet type value" // Exceeds 4 bit value. 210 | } 211 | var s string 212 | switch p { 213 | // First two cases are reserved packets according to MQTT v3.1.1. 214 | case 15: 215 | s = "RESERVED(15)" 216 | case 0: 217 | s = "RESERVED(0)" 218 | case PacketConnect: 219 | s = "CONNECT" 220 | case PacketConnack: 221 | s = "CONNACK" 222 | case PacketPuback: 223 | s = "PUBACK" 224 | case PacketPubcomp: 225 | s = "PUBCOMP" 226 | case PacketPublish: 227 | s = "PUBLISH" 228 | case PacketPubrec: 229 | s = "PUBREC" 230 | case PacketPubrel: 231 | s = "PUBREL" 232 | case PacketSubscribe: 233 | s = "SUBSCRIBE" 234 | case PacketUnsubscribe: 235 | s = "UNSUBSCRIBE" 236 | case PacketUnsuback: 237 | s = "UNSUBACK" 238 | case PacketSuback: 239 | s = "SUBACK" 240 | case PacketPingresp: 241 | s = "PINGRESP" 242 | case PacketPingreq: 243 | s = "PINGREQ" 244 | case PacketDisconnect: 245 | s = "DISCONNECT" 246 | default: 247 | panic("unreachable") // Caught during fuzzing lets hope. 
248 | } 249 | return s 250 | } 251 | 252 | // QoSLevel defined in definitions.go 253 | 254 | // IsValid returns true if qos is a valid Quality of Service. 255 | func (qos QoSLevel) IsValid() bool { return qos <= QoS2 } 256 | 257 | // String returns a pretty-string representation of qos i.e: "QoS0". Does not allocate memory. 258 | func (qos QoSLevel) String() (s string) { 259 | switch qos { 260 | case QoS0: 261 | s = "QoS0" 262 | case QoS1: 263 | s = "QoS1" 264 | case QoS2: 265 | s = "QoS2" 266 | case QoSSubfail: 267 | s = "QoS subscribe failure" 268 | case reservedQoS3: 269 | s = "invalid: use of reserved QoS3" 270 | default: 271 | s = "undefined QoS" 272 | } 273 | return s 274 | } 275 | 276 | // Packet specific functions 277 | 278 | // VariablesConnect all strings in the variable header must be UTF-8 encoded 279 | // except password which may be binary data. 280 | type VariablesConnect struct { 281 | // Must be present and unique to the server. UTF-8 encoded string 282 | // between 1 and 23 bytes in length although some servers may allow larger ClientIDs. 283 | ClientID []byte 284 | // By default will be set to 'MQTT' protocol if nil, which is v3.1 compliant. 285 | Protocol []byte 286 | Username []byte 287 | // For password to be used username must also be set. See [MQTT-3.1.2-22]. 288 | Password []byte 289 | WillTopic []byte 290 | WillMessage []byte 291 | // KeepAlive is a interval measured in seconds. it is the maximum time interval that is 292 | // permitted to elapse between the point at which the Client finishes transmitting one 293 | // Control Packet and the point it starts sending the next. 294 | KeepAlive uint16 295 | // By default if set to 0 will use Protocol level 4, which is v3.1 compliant 296 | ProtocolLevel byte 297 | // This bit specifies if the Will Message is to be Retained when it is published. 298 | WillRetain bool 299 | CleanSession bool 300 | // These two bits specify the QoS level to be used when publishing the Will Message. 
301 | WillQoS QoSLevel 302 | } 303 | 304 | // Size returns size-on-wire of the CONNECT variable header generated by vs. 305 | func (vc *VariablesConnect) Size() (sz int) { 306 | sz += mqttStringSize(vc.Username) 307 | if len(vc.Username) != 0 { 308 | sz += mqttStringSize(vc.Password) // Make sure password is only added when username is enabled. 309 | } 310 | if vc.WillFlag() { 311 | // If will flag set then these two strings are obligatory but may be zero lengthed. 312 | sz += len(vc.WillTopic) + len(vc.WillMessage) + 4 313 | } 314 | sz += len(vc.ClientID) + len(vc.Protocol) + 4 315 | return sz + 1 + 2 + 1 // Add Connect flags (1), Protocol level (1) and keepalive (2). 316 | } 317 | 318 | // StringsLen returns length of all strings in variable header before being encoded. 319 | // StringsLen is useful to know how much of the user's buffer was consumed during decoding. 320 | func (vc *VariablesConnect) StringsLen() (n int) { 321 | if len(vc.Username) != 0 { 322 | n += len(vc.Password) // Make sure password is only added when username is enabled. 323 | } 324 | if vc.WillFlag() { 325 | n += len(vc.WillTopic) + len(vc.WillMessage) 326 | } 327 | return len(vc.ClientID) + len(vc.Protocol) + len(vc.Username) 328 | } 329 | 330 | // Flags returns the eighth CONNECT packet byte. 331 | func (vc *VariablesConnect) Flags() byte { 332 | willFlag := vc.WillFlag() 333 | hasUsername := len(vc.Username) != 0 334 | return b2u8(hasUsername)<<7 | b2u8(hasUsername && len(vc.Password) != 0)<<6 | // See [MQTT-3.1.2-22]. 335 | b2u8(vc.WillRetain)<<5 | byte(vc.WillQoS&0b11)<<3 | 336 | b2u8(willFlag)<<2 | b2u8(vc.CleanSession)<<1 337 | } 338 | 339 | // WillFlag returns true if CONNECT packet will have a will topic and a will message, which means setting Will Flag bit to 1. 
340 | func (vc *VariablesConnect) WillFlag() bool { 341 | return len(vc.WillTopic) != 0 && len(vc.WillMessage) != 0 342 | } 343 | 344 | // VarConnack TODO 345 | 346 | // VariablesPublish represents the variable header of a PUBLISH packet. It does not 347 | // include the payload with the topic data. 348 | type VariablesPublish struct { 349 | // Must be present as utf-8 encoded string with NO wildcard characters. 350 | // The server may override the TopicName on response according to matching process [Section 4.7] 351 | TopicName []byte 352 | // Only present (non-zero) in QoS level 1 or 2. 353 | PacketIdentifier uint16 354 | } 355 | 356 | func (vp VariablesPublish) Validate() error { 357 | if vp.PacketIdentifier == 0 { 358 | return errGotZeroPI 359 | } else if len(vp.TopicName) == 0 { 360 | return errEmptyTopic 361 | } 362 | return nil 363 | } 364 | 365 | // Size returns size-on-wire of the PUBLISH variable header generated by vp. 366 | // It takes the packet QoS as an argument as it decides whether there's a Packet Identifier in the header. 367 | func (vp VariablesPublish) Size(qos QoSLevel) int { 368 | if qos != 0 { 369 | return len(vp.TopicName) + 2 + 2 // QoS1 and QoS2 include a 2 octet packet identifier. 370 | } 371 | return len(vp.TopicName) + 2 // No packet identifier, only the topic string. 372 | } 373 | 374 | // StringsLen returns length of all strings in variable header before being encoded. 375 | // StringsLen is useful to know how much of the user's buffer was consumed during decoding. 376 | func (vp VariablesPublish) StringsLen() int { return len(vp.TopicName) } 377 | 378 | // VariablesSubscribe represents the variable header of a SUBSCRIBE packet. 379 | // It encodes the topic filters requested by a Client and the desired QoS for each topic. 380 | type VariablesSubscribe struct { 381 | TopicFilters []SubscribeRequest 382 | PacketIdentifier uint16 383 | } 384 | 385 | // Size returns size-on-wire of the SUBSCRIBE variable header generated by vs. 
386 | func (vs VariablesSubscribe) Size() (sz int) { 387 | for _, sub := range vs.TopicFilters { 388 | sz += len(sub.TopicFilter) + 2 + 1 389 | } 390 | return sz + 2 // Add packet ID. 391 | } 392 | 393 | // StringsLen returns length of all strings in variable header before being encoded. 394 | // StringsLen is useful to know how much of the user's buffer was consumed during decoding. 395 | func (vs VariablesSubscribe) StringsLen() (n int) { 396 | for _, sub := range vs.TopicFilters { 397 | n += len(sub.TopicFilter) 398 | } 399 | return n 400 | } 401 | 402 | // SubscribeRequest is relevant only to SUBSCRIBE packets where several SubscribeRequest 403 | // each encode a topic filter that is to be matched on the server side and a desired 404 | // QoS for each matched topic. 405 | type SubscribeRequest struct { 406 | // utf8 encoded topic or match pattern for topic filter. 407 | TopicFilter []byte 408 | // The desired QoS level. 409 | QoS QoSLevel 410 | } 411 | 412 | // VariablesSuback represents the variable header of a SUBACK packet. 413 | type VariablesSuback struct { 414 | // Each return code corresponds to a topic filter in the SUBSCRIBE 415 | // packet being acknowledged. These MUST match the order of said SUBSCRIBE packet. 416 | // A return code can indicate failure using QoSSubfail. 417 | ReturnCodes []QoSLevel 418 | PacketIdentifier uint16 419 | } 420 | 421 | func (vs VariablesSuback) Validate() error { 422 | if vs.PacketIdentifier == 0 { 423 | return errGotZeroPI 424 | } 425 | for _, rc := range vs.ReturnCodes { 426 | if !rc.IsValid() && rc != QoSSubfail { 427 | return errors.New("invalid QoS") 428 | } 429 | } 430 | return nil 431 | } 432 | 433 | // Size returns size-on-wire of the SUBACK variable header generated by vs. 434 | func (vs VariablesSuback) Size() (sz int) { return len(vs.ReturnCodes) + 2 } 435 | 436 | // VariablesUnsubscribe represents the variable header of a UNSUBSCRIBE packet. 
437 | type VariablesUnsubscribe struct { 438 | Topics [][]byte 439 | PacketIdentifier uint16 440 | } 441 | 442 | // Size returns size-on-wire of the UNSUBSCRIBE variable header generated by vu. 443 | func (vu VariablesUnsubscribe) Size() (sz int) { 444 | for _, coldTopic := range vu.Topics { 445 | sz += len(coldTopic) + 2 446 | } 447 | return sz + 2 448 | } 449 | 450 | // StringsLen returns length of all strings in variable header before being encoded. 451 | // StringsLen is useful to know how much of the user's buffer was consumed during decoding. 452 | func (vu VariablesUnsubscribe) StringsLen() (n int) { 453 | for _, sub := range vu.Topics { 454 | n += len(sub) 455 | } 456 | return n 457 | } 458 | 459 | type VariablesConnack struct { 460 | // Octet with SP (Session Present) on LSB bit0. 461 | AckFlags uint8 462 | // Octet 463 | ReturnCode ConnectReturnCode 464 | } 465 | 466 | // String returns a pretty-string representation of CONNACK variable header. 467 | func (vc VariablesConnack) String() string { 468 | sp := vc.SessionPresent() 469 | if vc.AckFlags&^1 != 0 { 470 | return "forbidden connack ack flag bit set" 471 | } else if sp && vc.ReturnCode != 0 { 472 | return "invalid SP and return code combination" 473 | } 474 | s := "CONNACK: " + vc.ReturnCode.String() 475 | if sp { 476 | s += " (session present)" 477 | } 478 | return s 479 | } 480 | 481 | // Size returns size-on-wire of the CONNACK variable header generated by vs. 482 | func (vc VariablesConnack) Size() (sz int) { return 1 + 1 } 483 | 484 | // SessionPresent returns true if the SP bit is set in the CONNACK Ack flags. This bit indicates whether 485 | // the ClientID already has a session on the server. 486 | // - If server accepts a connection with CleanSession set to 1 the server MUST set SP to 0 (false). 487 | // - If server accepts a connection with CleanSession set to 0 SP depends on whether the server 488 | // already has stored a Session state for the supplied Client ID. 
If the server has stored a Session 489 | // then SP MUST set to 1, else MUST set to 0. 490 | // 491 | // In both cases above this is in addition to returning a zero CONNACK return code. If the CONNACK return code 492 | // is non-zero then SP MUST set to 0. 493 | func (vc VariablesConnack) SessionPresent() bool { return vc.AckFlags&1 != 0 } 494 | 495 | // validate provides early validation of CONNACK variables. 496 | func (vc VariablesConnack) validate() error { 497 | if vc.AckFlags&^1 != 0 { 498 | return errors.New("CONNACK Ack flag bits 7-1 must be set to 0") 499 | } 500 | return nil 501 | } 502 | 503 | // ConnectReturnCode defined in definitions.go 504 | 505 | // String returns a pretty-string representation of rc indicating if 506 | // the connection was accepted or the human-readable error if present. 507 | func (rc ConnectReturnCode) String() (s string) { 508 | switch rc { 509 | default: 510 | s = "unknown CONNACK return code" 511 | case ReturnCodeConnAccepted: 512 | s = "connection accepted" 513 | case ReturnCodeUnnaceptableProtocol: 514 | s = "unacceptable protocol version" 515 | case ReturnCodeIdentifierRejected: 516 | s = "client identifier rejected" 517 | case ReturnCodeBadUserCredentials: 518 | s = "bad username and/or password" 519 | case ReturnCodeUnauthorized: 520 | s = "client unauthorized" 521 | } 522 | return s 523 | } 524 | 525 | // DecodeHeader receives transp, an io.ByteReader that reads from an underlying arbitrary 526 | // transport protocol. transp should start returning the first byte of the MQTT packet. 527 | // Decode header returns the decoded header and any error that prevented it from 528 | // reading the entire header as specified by the MQTT v3.1 protocol. 529 | func DecodeHeader(transp io.Reader) (Header, int, error) { 530 | // Start parsing fixed header. 
531 | firstByte, err := decodeByte(transp) 532 | if err != nil { 533 | return Header{}, 0, err 534 | } 535 | n := 1 536 | rlen, ngot, err := decodeRemainingLength(transp) 537 | n += ngot 538 | if err != nil { 539 | return Header{}, n, err 540 | } 541 | packetType := PacketType(firstByte >> 4) 542 | if packetType == 0 || packetType > PacketDisconnect { 543 | return Header{}, n, errors.New("invalid packet type") 544 | } 545 | packetFlags := PacketFlags(firstByte & 0b1111) 546 | if err := packetType.validateFlags(packetFlags); err != nil { 547 | // Early validation. 548 | return Header{}, n, err 549 | } 550 | hdr := Header{ 551 | firstByte: firstByte, 552 | RemainingLength: rlen, 553 | } 554 | return hdr, n, nil 555 | } 556 | 557 | // mqttStringSize returns the size on wire occupied 558 | // by an *OPTIONAL* MQTT encoded string. If string is zero length returns 0. 559 | func mqttStringSize(b []byte) int { 560 | lb := len(b) 561 | if lb > 0 { 562 | return lb + 2 563 | } 564 | return 0 565 | } 566 | 567 | // SetDefaultMQTT sets required fields, like the ClientID, Protocol and Protocol level fields. 568 | // If KeepAlive is zero, is set to 60 (one minute). If Protocol field is not set to "MQTT" then memory is allocated for it. 569 | // Clean session is also set to true. 
570 | func (vc *VariablesConnect) SetDefaultMQTT(clientID []byte) { 571 | vc.ClientID = clientID 572 | if string(vc.Protocol) != DefaultProtocol { 573 | vc.Protocol = make([]byte, len(DefaultProtocol)) 574 | copy(vc.Protocol, DefaultProtocol) 575 | } 576 | vc.ProtocolLevel = DefaultProtocolLevel 577 | if vc.KeepAlive == 0 { 578 | vc.KeepAlive = 60 579 | } 580 | vc.CleanSession = true 581 | } 582 | 583 | func (vs *VariablesSubscribe) Validate() error { 584 | if len(vs.TopicFilters) == 0 { 585 | return errors.New("no topic filters in VariablesSubscribe") 586 | } 587 | for _, v := range vs.TopicFilters { 588 | if !v.QoS.IsValid() { 589 | return errors.New("invalid QoS in VariablesSubscribe") 590 | } else if len(v.TopicFilter) == 0 { 591 | return errors.New("got empty topic filter in VariablesSubscribe") 592 | } 593 | } 594 | return nil 595 | } 596 | 597 | // Copy copies the subscribe variables optimizing for memory space savings. 598 | func (vs *VariablesSubscribe) Copy() VariablesSubscribe { 599 | vscp := VariablesSubscribe{ 600 | TopicFilters: make([]SubscribeRequest, len(vs.TopicFilters)), 601 | PacketIdentifier: vs.PacketIdentifier, 602 | } 603 | blen := 0 604 | for i := range vs.TopicFilters { 605 | blen += len(vs.TopicFilters[i].TopicFilter) 606 | } 607 | buf := make([]byte, blen) 608 | blen = 0 609 | for i := range vs.TopicFilters { 610 | vscp.TopicFilters[i].TopicFilter = buf[blen : blen+len(vs.TopicFilters[i].TopicFilter)] 611 | blen += copy(vscp.TopicFilters[i].TopicFilter, vs.TopicFilters[i].TopicFilter) 612 | vscp.TopicFilters[i].QoS = vs.TopicFilters[i].QoS 613 | } 614 | return vscp 615 | } 616 | -------------------------------------------------------------------------------- /mqtt_test.go: -------------------------------------------------------------------------------- 1 | package mqtt 2 | 3 | import ( 4 | "bytes" 5 | "context" 6 | "errors" 7 | "fmt" 8 | "io" 9 | "math" 10 | "net" 11 | "testing" 12 | "time" 13 | ) 14 | 15 | const TCPServer = 
"test.mosquitto.org:1883" 16 | 17 | func TestMQTTConnect(t *testing.T) { 18 | const ( 19 | clientID = "natiu-test" 20 | topic = "abc" 21 | payload = "hello world!" 22 | testTimeout = 3 * time.Second 23 | ) 24 | tcpaddr, err := net.ResolveTCPAddr("tcp", TCPServer) 25 | if err != nil { 26 | t.Fatal(err) 27 | } 28 | conn, err := net.DialTCP("tcp", nil, tcpaddr) 29 | if err != nil { 30 | t.Fatal(err) 31 | } 32 | c := NewClient(ClientConfig{OnPub: func(pubHead Header, varPub VariablesPublish, r io.Reader) error { 33 | t.Log(pubHead, varPub) 34 | return nil 35 | }}) 36 | ctx, cancel := context.WithTimeout(context.Background(), testTimeout) 37 | defer cancel() 38 | var varconn VariablesConnect 39 | varconn.SetDefaultMQTT([]byte(clientID)) 40 | err = c.Connect(ctx, conn, &varconn) 41 | if err != nil { 42 | t.Error(err) 43 | } 44 | pid := (uint16(time.Now().UnixMilli()) % 512) + 2 // Ensure packet ID greater than 1 that cant overflow. 45 | err = c.Subscribe(ctx, VariablesSubscribe{ 46 | PacketIdentifier: pid, 47 | TopicFilters: []SubscribeRequest{{ 48 | TopicFilter: []byte(topic), 49 | QoS: QoS0, 50 | }}, 51 | }) 52 | if err != nil { 53 | t.Error(err) 54 | } 55 | flags, _ := NewPublishFlags(QoS0, false, false) 56 | varPub := VariablesPublish{ 57 | TopicName: []byte(topic), 58 | PacketIdentifier: pid + 1, 59 | } 60 | err = c.PublishPayload(flags, varPub, []byte(payload)) 61 | if err != nil { 62 | t.Error(err) 63 | } 64 | } 65 | 66 | // NewRxTx creates a new RxTx. Before use user must configure OnX fields by setting a function 67 | // to perform an action each time a packet is received. After a call to transport.Close() 68 | // all future calls must return errors until the transport is replaced with [RxTx.SetTransport]. 
69 | func NewRxTx(transport io.ReadWriteCloser, decoder Decoder) (*RxTx, error) { 70 | if transport == nil || decoder == nil { 71 | return nil, errors.New("got nil transport io.ReadWriteCloser or nil Decoder") 72 | } 73 | cc := &RxTx{ 74 | Rx: Rx{ 75 | rxTrp: transport, 76 | userDecoder: decoder, 77 | }, 78 | Tx: Tx{txTrp: transport}, 79 | } 80 | return cc, nil 81 | } 82 | 83 | // RxTx implements a bare minimum MQTT v3.1.1 protocol transport layer handler. 84 | // If there is an error during read/write of a packet the transport is closed 85 | // and a new transport must be set with [RxTx.SetTransport]. 86 | // An RxTx will not validate data before encoding, that is up to the caller, it 87 | // will validate incoming data according to MQTT's specification. Malformed packets 88 | // will be rejected and the connection will be closed immediately with a call to [RxTx.OnError]. 89 | type RxTx struct { 90 | Tx 91 | Rx 92 | } 93 | 94 | // ShallowCopy shallow copies rxtx and underlying transports and encoders/decoders. Does not copy callbacks over. 95 | func (rxtx *RxTx) ShallowCopy() *RxTx { 96 | return &RxTx{ 97 | Tx: *rxtx.Tx.ShallowCopy(), 98 | Rx: *rxtx.Rx.ShallowCopy(), 99 | } 100 | } 101 | 102 | // SetTransport sets the rxtx's reader and writer. 103 | func (rxtx *RxTx) SetTransport(transport io.ReadWriteCloser) { 104 | rxtx.rxTrp = transport 105 | rxtx.txTrp = transport 106 | } 107 | 108 | func FuzzRxTxReadNextPacket(f *testing.F) { 109 | const maxSize = 1500 110 | testCases := [][]byte{ 111 | // Typical connect packet. 112 | []byte("\x10\x1e\x00\x04MQTT\x04\xec\x00<\x00\x020w\x00\x02Bw\x00\x02Aw\x00\x02Cw\x00\x02Dw"), 113 | // Typical connack 114 | []byte("\x02\x01\x04"), 115 | // A Publish packet 116 | []byte(";\x8e\x01\x00&now-for-something-completely-different\xff\xffertytgbhjjhundsaip;vf[oniw[aondmiksfvoWDNFOEWOPndsafr;poulikujyhtgbfrvdcsxzaesxt dfcgvfhbg kjnlkm/'."), 117 | // A subscribe packet. 
118 | []byte("\x824\xff\xff\x00\tfavorites\x02\x00\tthe-clash\x02\x00\x0falways-watching\x02\x00\x05k-pop\x02"), 119 | // Unsubscribe packet 120 | []byte("\xa2$\xff\xff\x00\x06topic1\x00\x06topic2\x00\x06topic3\x00\bsemperfi"), 121 | // Suback packet. 122 | []byte("\x90\b\xff\xff\x00\x01\x00\x02\x80\x01"), 123 | // Pubrel packet. 124 | []byte("b\x02\f\xa0"), 125 | } 126 | testCases = append(testCases, fuzzCorpus...) 127 | for _, tc := range testCases { 128 | f.Add(tc) // Provide seed corpus. 129 | } 130 | 131 | f.Fuzz(func(t *testing.T, a []byte) { 132 | if len(a) == 0 || len(a) > maxSize { 133 | return 134 | } 135 | buf := newLoopbackTransport() 136 | _, err := writeFull(buf, a) 137 | if err != nil { 138 | t.Fatal(err) 139 | } 140 | rxtx, err := NewRxTx(buf, DecoderNoAlloc{make([]byte, maxSize+10)}) 141 | if err != nil { 142 | t.Fatal(err) 143 | } 144 | rxtx.ReadNextPacket() 145 | }) 146 | _ = testCases 147 | } 148 | 149 | func TestHeaderLoopback(t *testing.T) { 150 | pubQoS0flag, err := NewPublishFlags(QoS0, false, true) 151 | if err != nil { 152 | t.Fatal(err) 153 | } 154 | for _, header := range []struct { 155 | tp PacketType 156 | flags PacketFlags 157 | 158 | remlen uint32 159 | }{ 160 | {tp: PacketPubrel}, 161 | {tp: PacketPingreq}, 162 | {tp: PacketPublish, flags: pubQoS0flag}, 163 | {tp: PacketConnect}, 164 | } { 165 | h, err := NewHeader(header.tp, header.flags, header.remlen) 166 | if err != nil { 167 | t.Fatal(err) 168 | } 169 | if h.RemainingLength != header.remlen { 170 | t.Error("remaining length mismatch") 171 | } 172 | flagsGot := h.Flags() 173 | if header.tp == PacketPublish && flagsGot != header.flags { 174 | t.Error("publish flag mismatch", flagsGot, header.flags) 175 | } 176 | typeGot := h.Type() 177 | if typeGot != header.tp { 178 | t.Error("type mismatch") 179 | } 180 | } 181 | } 182 | 183 | func TestRxTxBadPacketRxErrors(t *testing.T) { 184 | rxtx, err := NewRxTx(&testTransport{}, DecoderNoAlloc{UserBuffer: make([]byte, 1500)}) 185 | if err 
!= nil { 186 | t.Fatal(err) 187 | } 188 | for _, test := range []struct { 189 | reason string 190 | rx []byte 191 | }{ 192 | {"no contents", []byte("")}, 193 | {"EOF during fixed header", []byte("\x01")}, 194 | {"forbidden packet type 0", []byte("\x00\x00")}, 195 | {"forbidden packet type 15", []byte("\xf0\x00")}, 196 | {"missing CONNECT var header and bad remaining length", []byte("\x10\x0a")}, 197 | {"missing CONNECT var header", []byte("\x10\x00")}, 198 | {"missing CONNACK var header", []byte("\x20\x00")}, 199 | {"missing PUBLISH var header", []byte("\x30\x00")}, 200 | {"missing PUBACK var header", []byte("\x40\x00")}, 201 | {"missing SUBSCRIBE var header", []byte("\x80\x00")}, 202 | {"missing SUBACK var header", []byte("\x90\x00")}, 203 | {"missing UNSUBSCRIBE var header", []byte("\xa0\x00")}, 204 | {"missing UNSUBACK var header", []byte("\xb0\x00")}, 205 | } { 206 | buf := newLoopbackTransport() 207 | rxtx.SetTransport(buf) 208 | n, err := buf.Write(test.rx) 209 | if err != nil || n != len(test.rx) { 210 | t.Fatal("all bytes not written or error:", err) 211 | } 212 | _, err = rxtx.ReadNextPacket() 213 | if err == nil { 214 | t.Error("expected error for case:", test.reason) 215 | } 216 | } 217 | } 218 | 219 | func TestHasPacketIdentifer(t *testing.T) { 220 | const ( 221 | qos0Flag = PacketFlags(QoS0 << 1) 222 | qos1Flag = PacketFlags(QoS1 << 1) 223 | qos2Flag = PacketFlags(QoS2 << 1) 224 | ) 225 | for _, test := range []struct { 226 | h Header 227 | expect bool 228 | }{ 229 | {h: newHeader(PacketConnect, 0, 0), expect: false}, 230 | {h: newHeader(PacketConnack, 0, 0), expect: false}, 231 | {h: newHeader(PacketPublish, qos0Flag, 0), expect: false}, 232 | {h: newHeader(PacketPublish, qos1Flag, 0), expect: true}, 233 | {h: newHeader(PacketPublish, qos2Flag, 0), expect: true}, 234 | {h: newHeader(PacketPuback, 0, 0), expect: true}, 235 | {h: newHeader(PacketPubrec, 0, 0), expect: true}, 236 | {h: newHeader(PacketPubrel, 0, 0), expect: true}, 237 | {h: 
newHeader(PacketPubcomp, 0, 0), expect: true}, 238 | {h: newHeader(PacketUnsubscribe, 0, 0), expect: true}, 239 | {h: newHeader(PacketUnsuback, 0, 0), expect: true}, 240 | {h: newHeader(PacketPingreq, 0, 0), expect: false}, 241 | {h: newHeader(PacketPingresp, 0, 0), expect: false}, 242 | {h: newHeader(PacketDisconnect, 0, 0), expect: false}, 243 | } { 244 | t.Log("tested ", test.h.String()) 245 | got := test.h.HasPacketIdentifier() 246 | if got != test.expect { 247 | t.Errorf("%s: got %v, expected %v", test.h.String(), got, test.expect) 248 | } 249 | } 250 | } 251 | 252 | func TestVariablesConnectFlags(t *testing.T) { 253 | getFlags := func(flag byte) (username, password, willRetain, willFlag, cleanSession, reserved bool, qos QoSLevel) { 254 | return flag&(1<<7) != 0, flag&(1<<6) != 0, flag&(1<<5) != 0, flag&(1<<2) != 0, flag&(1<<1) != 0, flag&1 != 0, QoSLevel(flag>>3) & 0b11 255 | } 256 | var connect VariablesConnect 257 | connect.SetDefaultMQTT([]byte("salamanca")) 258 | flags := connect.Flags() 259 | usr, pwd, wR, wF, cs, forbidden, qos := getFlags(flags) 260 | if qos != QoS0 { 261 | t.Error("QoS0 default, got ", qos.String()) 262 | } 263 | if usr || pwd { 264 | t.Error("expected no password or user on default flags") 265 | } 266 | if wR { 267 | t.Error("will retain set") 268 | } 269 | if wF { 270 | t.Error("will flag set") 271 | } 272 | if !cs { 273 | t.Error("clean session not set") 274 | } 275 | if forbidden { 276 | t.Error("forbidden bit set") 277 | } 278 | if DefaultProtocolLevel != connect.ProtocolLevel { 279 | t.Error("protocol level mismatch") 280 | } 281 | if DefaultProtocol != string(connect.Protocol) { 282 | t.Error("protocol mismatch") 283 | } 284 | connect.WillQoS = QoS2 285 | connect.Username = []byte("inigo") 286 | connect.Password = []byte("123") 287 | connect.CleanSession = false 288 | usr, pwd, wR, wF, cs, forbidden, qos = getFlags(connect.Flags()) 289 | if qos != QoS2 { 290 | t.Error("QoS0 default, got ", qos.String()) 291 | } 292 | if !usr { 
293 | t.Error("username flag not ok") 294 | } 295 | if !pwd { 296 | t.Error("password flag not ok") 297 | } 298 | if wR { 299 | t.Error("will retain set") 300 | } 301 | if wF { 302 | t.Error("will flag set") 303 | } 304 | if cs { 305 | t.Error("clean session set") 306 | } 307 | if forbidden { 308 | t.Error("forbidden bit set") 309 | } 310 | } 311 | 312 | func TestHeaderSize(t *testing.T) { 313 | for _, test := range []struct { 314 | h Header 315 | expect int 316 | }{ 317 | {h: newHeader(1, 0, 0), expect: 2}, 318 | {h: newHeader(1, 0, 1), expect: 2}, 319 | {h: newHeader(1, 0, 2), expect: 2}, 320 | {h: newHeader(1, 0, 128), expect: 3}, 321 | {h: newHeader(1, 0, 0xffff), expect: 4}, 322 | {h: newHeader(1, 0, 0xffff_ff), expect: 5}, 323 | {h: newHeader(1, 0, 0xffff_ffff), expect: 0}, // bad remaining length 324 | } { 325 | got := test.h.Size() 326 | if got != test.expect { 327 | t.Error("size mismatch for remlen:", test.h.RemainingLength, got, test.expect) 328 | } 329 | } 330 | } 331 | 332 | func TestHeaderEncodeDecodeLoopback(t *testing.T) { 333 | var b bytes.Buffer 334 | for _, test := range []struct { 335 | desc string 336 | h Header 337 | expect int 338 | }{ 339 | {desc: "max remlen", h: newHeader(1, 0, maxRemainingLengthValue), expect: 5}, // TODO(soypat): must support up to maxRemainingLengthValue remaining length. 
340 | {desc: "bad: maxremlen+1", h: newHeader(1, 0, maxRemainingLengthValue+1), expect: 0}, // bad remaining length 341 | {desc: "remlen=0", h: newHeader(1, 0, 0), expect: 2}, 342 | {desc: "remlen=1", h: newHeader(1, 0, 1), expect: 2}, 343 | {desc: "remlen=2", h: newHeader(1, 0, 2), expect: 2}, 344 | {desc: "medium remlen", h: newHeader(1, 0, 128), expect: 3}, 345 | {desc: "big remlen", h: newHeader(1, 0, 0xffff), expect: 4}, 346 | } { 347 | hdr := test.h 348 | nencode, err := hdr.Encode(&b) 349 | if hdr.RemainingLength > maxRemainingLengthValue { 350 | if err == nil { 351 | t.Errorf("%s: expected error for malformed packet", test.desc) 352 | } 353 | continue 354 | } 355 | if err != nil { 356 | t.Fatal(err) 357 | } 358 | gotHdr, ndecode, err := DecodeHeader(&b) 359 | if err != nil { 360 | t.Fatalf("%s:decoded %d byte for %+v: %v", test.desc, ndecode, hdr, err) 361 | } 362 | if nencode != ndecode { 363 | t.Errorf("%s: number of bytes encoded (%d) not match decoded (%d)", test.desc, nencode, ndecode) 364 | } 365 | if nencode != test.expect { 366 | t.Errorf("%s: expected to encode %d bytes, encoded %d: %s", test.desc, test.expect, nencode, hdr) 367 | } 368 | if hdr != gotHdr { 369 | t.Errorf("%s: header mismatch in values encode:%+v; decode:%+v", test.desc, hdr, gotHdr) 370 | } 371 | } 372 | } 373 | 374 | func TestVariablesConnectSize(t *testing.T) { 375 | var varConn VariablesConnect 376 | varConn.SetDefaultMQTT([]byte("salamanca")) 377 | varConn.WillQoS = QoS1 378 | varConn.WillRetain = true 379 | varConn.WillMessage = []byte("Hello, my name is Inigo Montoya. You killed my father. Prepare to die.") 380 | varConn.WillTopic = []byte("great-movies") 381 | varConn.Username = []byte("Inigo") 382 | varConn.Password = []byte("\x00\x01\x02\x03flab\xff\x7f\xff") 383 | got := varConn.Size() 384 | expect, err := encodeConnect(io.Discard, &varConn) 385 | if err != nil { 386 | t.Fatal(err) 387 | } 388 | if got != expect { 389 | t.Errorf("Size returned %d. 
encoding CONNECT variable header yielded %d", got, expect)
	}
}

// TestRxTxLoopback writes one packet of each supported kind through an
// in-memory loopback transport and reads it back with the same RxTx, checking
// the decoded header, variable header and the byte count ReadNextPacket reports.
func TestRxTxLoopback(t *testing.T) {
	// The same RxTx writes to and reads from buf, so every packet written below
	// is immediately available to the following ReadNextPacket call.
	buf := newLoopbackTransport()
	rxtx, err := NewRxTx(buf, DecoderNoAlloc{make([]byte, 1500)})
	if err != nil {
		t.Fatal(err)
	}
	rxtx.SetTransport(buf)
	//
	// Send CONNECT packet over wire.
	//
	{
		var varConn VariablesConnect
		varConn.SetDefaultMQTT([]byte("0w"))
		varConn.WillQoS = QoS1
		varConn.WillRetain = true
		varConn.WillMessage = []byte("Aw")
		varConn.WillTopic = []byte("Bw")
		varConn.Username = []byte("Cw")
		varConn.Password = []byte("Dw")
		remlen := uint32(varConn.Size())
		expectHeader := newHeader(PacketConnect, 0, remlen)
		err = rxtx.WriteConnect(&varConn)
		if err != nil {
			t.Fatal(err)
		}
		// We now prepare to receive CONNECT packet on other side.
		callbackExecuted := false
		rxtx.RxCallbacks.OnConnect = func(rt *Rx, vc *VariablesConnect) error {
			if rt.LastReceivedHeader != expectHeader {
				t.Errorf("rxtx header mismatch, expect:%v, rxed:%v", expectHeader.String(), rt.LastReceivedHeader.String())
			}
			varEqual(t, &varConn, vc)
			callbackExecuted = true
			return nil
		}
		// Read packet that is on the "wire"
		n, err := rxtx.ReadNextPacket()
		if err != nil {
			t.Fatal(err)
		}
		expectSize := expectHeader.Size() + varConn.Size()
		if n != expectSize {
			t.Errorf("read %v bytes, expected to read %v bytes", n, expectSize)
		}
		if !callbackExecuted {
			t.Error("OnConnect callback not executed")
		}
	}
	if t.Failed() {
		return // fix first clause before continuing.
	}
	buf.rw.Reset()
	//
	// Send CONNACK packet over wire.
	//
	{
		varConnck := VariablesConnack{
			AckFlags:   1,                      // SP set
			ReturnCode: ReturnCodeConnAccepted, // Accepted since SP set.
		}
		err = rxtx.WriteConnack(varConnck)
		if err != nil {
			t.Fatal(err)
		}
		expectHeader := newHeader(PacketConnack, 0, uint32(varConnck.Size()))
		callbackExecuted := false
		rxtx.RxCallbacks.OnConnack = func(rt *Rx, vc VariablesConnack) error {
			if rt.LastReceivedHeader != expectHeader {
				t.Errorf("rxtx header mismatch, expect:%v, rxed:%v", expectHeader.String(), rt.LastReceivedHeader.String())
			}
			varEqual(t, varConnck, vc)
			callbackExecuted = true
			return nil
		}

		n, err := rxtx.ReadNextPacket()
		if err != nil {
			t.Fatal(err)
		}
		expectSize := expectHeader.Size() + varConnck.Size()
		if n != expectSize {
			t.Errorf("read %v bytes, expected to read %v bytes", n, expectSize)
		}
		if !callbackExecuted {
			t.Error("OnConnack callback not executed")
		}
	}

	//
	// Send PUBLISH QoS1 packet over wire.
	//
	{
		const pubQos = QoS1
		publishPayload := []byte("PL")
		varPublish := VariablesPublish{
			TopicName:        []byte("TOP"),
			PacketIdentifier: math.MaxUint16,
		}
		pubflags, err := NewPublishFlags(pubQos, true, true)
		if err != nil {
			t.Fatal(err)
		}
		expectedRemainingLen := uint32(varPublish.Size(pubQos) + len(publishPayload))
		publishHeader := newHeader(PacketPublish, pubflags, expectedRemainingLen)
		err = rxtx.WritePublishPayload(publishHeader, varPublish, publishPayload)
		if err != nil {
			t.Fatal(err)
		}
		callbackExecuted := false
		rxtx.RxCallbacks.OnPub = func(rt *Rx, vp VariablesPublish, r io.Reader) error {
			b, err := io.ReadAll(r)
			if err != nil {
				t.Fatal(err)
			}
			if !bytes.Equal(b, publishPayload) {
				t.Error("got different payloads!")
			}
			if rt.LastReceivedHeader != publishHeader {
				t.Errorf("rxtx header mismatch, txed:%v, rxed:%v", publishHeader.String(), rt.LastReceivedHeader.String())
			}
			if vp.Size(pubQos) != varPublish.Size(pubQos) {
				t.Errorf("mismatch between publish variable sizes")
			}
			varEqual(t, varPublish, vp)
			callbackExecuted = true
			return nil
		}

		n, err := rxtx.ReadNextPacket()
		if err != nil {
			t.Fatal(err)
		}
		// ReadNextPacket counts fixed and variable header bytes only; the payload
		// is consumed by OnPub through the limited reader and is not included.
		expectSize := publishHeader.Size() + varPublish.Size(pubQos)
		if n != expectSize {
			t.Errorf("read %v bytes, expected to read %v bytes", n, expectSize)
		}
		if !callbackExecuted {
			t.Error("OnPub callback not executed")
		}
	}
	//
	// Send PUBLISH QoS0 packet over wire.
	//
	{
		const pubQos = QoS0
		publishPayload := []byte("\xa6\x32")
		varPublish := VariablesPublish{
			TopicName:        []byte("pressure"),
			PacketIdentifier: 0, // No packet ID for QoS0.
		}
		pubflags, err := NewPublishFlags(pubQos, false, false)
		if err != nil {
			t.Fatal(err)
		}
		expectedRemainingLen := uint32(varPublish.Size(pubQos) + len(publishPayload))
		publishHeader := newHeader(PacketPublish, pubflags, expectedRemainingLen)
		err = rxtx.WritePublishPayload(publishHeader, varPublish, publishPayload)
		if err != nil {
			t.Fatal(err)
		}
		callbackExecuted := false
		rxtx.RxCallbacks.OnPub = func(rt *Rx, vp VariablesPublish, r io.Reader) error {
			b, err := io.ReadAll(r)
			if err != nil {
				t.Fatal(err)
			}
			if !bytes.Equal(b, publishPayload) {
				t.Error("got different payloads!")
			}
			if rt.LastReceivedHeader != publishHeader {
				t.Errorf("rxtx header mismatch, txed:%v, rxed:%v", publishHeader.String(), rt.LastReceivedHeader.String())
			}
			if vp.Size(pubQos) != varPublish.Size(pubQos) {
				t.Errorf("mismatch between publish variable sizes")
			}
			varEqual(t, varPublish, vp)
			callbackExecuted = true
			return nil
		}

		n, err := rxtx.ReadNextPacket()
		if err != nil {
			t.Fatal(err)
		}
		// Payload bytes are read by OnPub and are not part of n.
		expectSize := publishHeader.Size() + varPublish.Size(pubQos)
		if n != expectSize {
			t.Errorf("read %v bytes, expected to read %v bytes", n, expectSize)
		}
		if !callbackExecuted {
			t.Error("OnPub callback not executed")
		}
	}

	//
	// Send PUBLISH packet over wire and ignore packet.
	//
	{
		const pubQos = QoS1
		publishPayload := []byte("ertytgbhjjhundsaip;vf[oniw[aondmiksfvoWDNFOEWOPndsafr;poulikujyhtgbfrvdcsxzaesxt dfcgvfhbg kjnlkm/'.")
		varPublish := VariablesPublish{
			TopicName:        []byte("now-for-something-completely-different"),
			PacketIdentifier: math.MaxUint16,
		}
		pubflags, err := NewPublishFlags(pubQos, true, true)
		if err != nil {
			t.Fatal(err)
		}

		publishHeader := newHeader(PacketPublish, pubflags, uint32(varPublish.Size(pubQos)+len(publishPayload)))
		err = rxtx.WritePublishPayload(publishHeader, varPublish, publishPayload)
		if err != nil {
			t.Fatal(err)
		}
		rxtx.RxCallbacks.OnPub = nil

		n, err := rxtx.ReadNextPacket()
		if err != nil {
			t.Fatal(err)
		}
		expectSize := publishHeader.Size() + varPublish.Size(pubQos)
		if n != expectSize {
			t.Errorf("read %v bytes, expected to read %v bytes", n, expectSize)
		}

	}

	//
	// Send SUBSCRIBE packet over wire.
	//
	{

		varsub := VariablesSubscribe{
			PacketIdentifier: math.MaxUint16,
			TopicFilters: []SubscribeRequest{
				{TopicFilter: []byte("favorites"), QoS: QoS2},
				{TopicFilter: []byte("the-clash"), QoS: QoS2},
				{TopicFilter: []byte("always-watching"), QoS: QoS2},
				{TopicFilter: []byte("k-pop"), QoS: QoS2},
			},
		}
		err = rxtx.WriteSubscribe(varsub)
		if err != nil {
			t.Fatal(err)
		}

		expectHeader := newHeader(PacketSubscribe, PacketFlagsPubrelSubUnsub, uint32(varsub.Size()))
		callbackExecuted := false
		rxtx.RxCallbacks.OnSub = func(rt *Rx, vs VariablesSubscribe) error {
			if rt.LastReceivedHeader != expectHeader {
				t.Errorf("rxtx header mismatch, expect:%v, rxed:%v", expectHeader.String(), rt.LastReceivedHeader.String())
			}
			varEqual(t, varsub, vs)
			callbackExecuted = true
			return nil
		}

		n, err := rxtx.ReadNextPacket()
		if err != nil {
			t.Fatal(err)
		}
		expectSize := expectHeader.Size() + varsub.Size()
		if n != expectSize {
			t.Errorf("read %v bytes, expected to read %v bytes", n, expectSize)
		}
		if !callbackExecuted {
			t.Error("OnSub callback not executed")
		}
	}

	//
	// Send UNSUBSCRIBE packet over wire.
	//
	{
		callbackExecuted := false
		varunsub := VariablesUnsubscribe{
			PacketIdentifier: math.MaxUint16,
			Topics:           bytes.Fields([]byte("topic1 topic2 topic3 semperfi")),
		}
		err = rxtx.WriteUnsubscribe(varunsub)
		if err != nil {
			t.Fatal(err)
		}
		expectHeader := newHeader(PacketUnsubscribe, PacketFlagsPubrelSubUnsub, uint32(varunsub.Size()))
		rxtx.RxCallbacks.OnUnsub = func(rt *Rx, vu VariablesUnsubscribe) error {
			if rt.LastReceivedHeader != expectHeader {
				t.Errorf("rxtx header mismatch, expect:%v, rxed:%v", expectHeader.String(), rt.LastReceivedHeader.String())
			}
			varEqual(t, varunsub, vu)
			callbackExecuted = true
			return nil
		}

		n, err := rxtx.ReadNextPacket()
		if err != nil {
			t.Fatal(err)
		}
		expectSize := expectHeader.Size() + varunsub.Size()
		if n != expectSize {
			t.Errorf("read %v bytes, expected to read %v bytes", n, expectSize)
		}
		if !callbackExecuted {
			t.Error("OnUnsub callback not executed")
		}
	}

	//
	// Send SUBACK packet over wire.
	//
	{
		callbackExecuted := false
		varSuback := VariablesSuback{
			PacketIdentifier: math.MaxUint16,
			ReturnCodes:      []QoSLevel{QoS0, QoS1, QoS0, QoS2, QoSSubfail, QoS1},
		}
		err = rxtx.WriteSuback(varSuback)
		if err != nil {
			t.Fatal(err)
		}
		expectHeader := newHeader(PacketSuback, 0, uint32(varSuback.Size()))
		rxtx.RxCallbacks.OnSuback = func(rt *Rx, vu VariablesSuback) error {
			if rt.LastReceivedHeader != expectHeader {
				t.Errorf("rxtx header mismatch, expect:%v, rxed:%v", expectHeader.String(), rt.LastReceivedHeader.String())
			}
			varEqual(t, varSuback, vu)
			callbackExecuted = true
			return nil
		}

		n, err := rxtx.ReadNextPacket()
		if err != nil {
			t.Fatal(err)
		}
		expectSize := expectHeader.Size() + varSuback.Size()
		if n != expectSize {
			t.Errorf("read %v bytes, expected to read %v bytes", n, expectSize)
		}
		if !callbackExecuted {
			t.Error("OnSuback callback not executed")
		}
	}

	//
	// Send PUBREL packet over wire.
	//
	{
		callbackExecuted := false
		txPI := uint16(3232)
		expectHeader := newHeader(PacketPubrel, PacketFlagsPubrelSubUnsub, 2)
		err = rxtx.WriteIdentified(PacketPubrel, txPI)
		if err != nil {
			t.Fatal(err)
		}

		rxtx.RxCallbacks.OnOther = func(rt *Rx, gotPI uint16) error {
			if rt.LastReceivedHeader != expectHeader {
				t.Errorf("rxtx header mismatch, expect:%v, rxed:%v", expectHeader.String(), rt.LastReceivedHeader.String())
			}
			if gotPI != txPI {
				t.Error("mismatch of packet identifiers", gotPI, txPI)
			}
			callbackExecuted = true
			return nil
		}

		n, err := rxtx.ReadNextPacket()
		if err != nil {
			t.Fatal(err)
		}
		expectSize := expectHeader.Size() + 2
		if n != expectSize {
			t.Errorf("read %v bytes, expected to read %v bytes", n, expectSize)
		}
		if !callbackExecuted {
			t.Error("OnOther callback not executed")
		}
	}

	//
	// Send PINGREQ packet over wire.
	//
	{
		callbackExecuted := false
		expectHeader := newHeader(PacketPingreq, 0, 0)
		err = rxtx.WriteSimple(PacketPingreq)
		if err != nil {
			t.Fatal(err)
		}

		rxtx.RxCallbacks.OnOther = func(rt *Rx, gotPI uint16) error {
			if rt.LastReceivedHeader != expectHeader {
				t.Errorf("rxtx header mismatch, expect:%v, rxed:%v", expectHeader.String(), rt.LastReceivedHeader.String())
			}
			if gotPI != 0 {
				t.Error("mismatch of packet identifiers", gotPI, 0)
			}
			callbackExecuted = true
			return nil
		}

		n, err := rxtx.ReadNextPacket()
		if err != nil {
			t.Fatal(err)
		}
		expectSize := expectHeader.Size()
		if n != expectSize {
			t.Errorf("read %v bytes, expected to read %v bytes", n, expectSize)
		}
		if !callbackExecuted {
			t.Error("OnOther callback not executed")
		}
	}
	err = rxtx.CloseRx() // closes both.
	if err != nil {
		t.Error(err)
	}
}

// newLoopbackTransport returns a transport whose reads return the bytes
// previously written to it, backed by a single bytes.Buffer.
func newLoopbackTransport() *testTransport {
	var _buf bytes.Buffer
	return &testTransport{&_buf}
}

// testTransport is an in-memory io.ReadWriteCloser used by tests. After Close,
// Read and Write return io.ErrClosedPipe.
type testTransport struct {
	rw *bytes.Buffer
}

// Close drops the backing buffer so later Read/Write calls fail.
func (t *testTransport) Close() error {
	t.rw = nil
	return nil
}

// Read reads from the backing buffer, or fails with io.ErrClosedPipe once closed.
func (t *testTransport) Read(p []byte) (int, error) {
	if t.rw == nil {
		return 0, io.ErrClosedPipe
	}
	return t.rw.Read(p)
}

// Write appends to the backing buffer, or fails with io.ErrClosedPipe once closed.
func (t *testTransport) Write(p []byte) (int, error) {
	if t.rw == nil {
		return 0, io.ErrClosedPipe
	}
	return t.rw.Write(p)
}

// varEqual reports a test error for every field of a that differs from b.
// a and b must hold the same VariablesPACKET type; VariablesConnect must be
// passed as a pointer.
842 | func varEqual(t *testing.T, a, b any) { 843 | switch va := a.(type) { 844 | case *VariablesConnect: 845 | // Make name distinct to va to catch bugs easier. 846 | veebee := b.(*VariablesConnect) 847 | if va.CleanSession != veebee.CleanSession { 848 | t.Error("clean session mismatch") 849 | } 850 | if va.ProtocolLevel != veebee.ProtocolLevel { 851 | t.Error("protocol level mismatch") 852 | } 853 | if va.KeepAlive != veebee.KeepAlive { 854 | t.Error("willQoS mismatch") 855 | } 856 | if va.WillQoS != veebee.WillQoS { 857 | t.Error("willQoS mismatch") 858 | } 859 | if !bytes.Equal(va.ClientID, veebee.ClientID) { 860 | t.Error("client id mismatch") 861 | } 862 | if !bytes.Equal(va.Protocol, veebee.Protocol) { 863 | t.Error("protocol mismatch") 864 | } 865 | if !bytes.Equal(va.Password, veebee.Password) { 866 | t.Error("password mismatch") 867 | } 868 | if !bytes.Equal(va.Username, veebee.Username) { 869 | t.Error("username mismatch") 870 | } 871 | if !bytes.Equal(va.WillMessage, veebee.WillMessage) { 872 | t.Error("will message mismatch") 873 | } 874 | if !bytes.Equal(va.WillTopic, veebee.WillTopic) { 875 | t.Error("will topic mismatch") 876 | } 877 | 878 | case VariablesConnack: 879 | vb := b.(VariablesConnack) 880 | if va != vb { 881 | t.Error("CONNACK not equal:", va, vb) 882 | } 883 | 884 | case VariablesPublish: 885 | vb := b.(VariablesPublish) 886 | if !bytes.Equal(va.TopicName, vb.TopicName) { 887 | t.Error("publish topic names mismatch") 888 | } 889 | if va.PacketIdentifier != vb.PacketIdentifier { 890 | t.Error("packet id mismatch") 891 | } 892 | 893 | case VariablesSuback: 894 | vb := b.(VariablesSuback) 895 | if va.PacketIdentifier != vb.PacketIdentifier { 896 | t.Error("SUBACK packet identifier mismatch") 897 | } 898 | for i, rca := range va.ReturnCodes { 899 | rcb := vb.ReturnCodes[i] 900 | if rca != rcb { 901 | t.Errorf("SUBACK %dth return code mismatch, %s! 
= %s", i, rca, rcb) 902 | } 903 | } 904 | 905 | case VariablesSubscribe: 906 | vb := b.(VariablesSubscribe) 907 | if va.PacketIdentifier != vb.PacketIdentifier { 908 | t.Error("SUBSCRIBE packet identifier mismatch") 909 | } 910 | for i, hotopicA := range va.TopicFilters { 911 | hotTopicB := vb.TopicFilters[i] 912 | if hotopicA.QoS != hotTopicB.QoS { 913 | t.Errorf("SUBSCRIBE %dth QoS mismatch, %s! = %s", i, hotopicA.QoS, hotTopicB.QoS) 914 | } 915 | if !bytes.Equal(hotopicA.TopicFilter, hotTopicB.TopicFilter) { 916 | t.Errorf("SUBSCRIBE %dth topic filter mismatch, %s! = %s", i, string(hotopicA.TopicFilter), string(hotTopicB.TopicFilter)) 917 | } 918 | } 919 | 920 | case VariablesUnsubscribe: 921 | vb := b.(VariablesUnsubscribe) 922 | if va.PacketIdentifier != vb.PacketIdentifier { 923 | t.Error("UNSUBSCRIBE packet identifier mismatch", va.PacketIdentifier, vb.PacketIdentifier) 924 | } 925 | for i, coldtopicA := range va.Topics { 926 | coldTopicB := vb.Topics[i] 927 | if !bytes.Equal(coldtopicA, coldTopicB) { 928 | t.Errorf("UNSUBSCRIBE %dth topic mismatch, %s! 
= %s", i, coldtopicA, coldTopicB) 929 | } 930 | } 931 | 932 | default: 933 | panic(fmt.Sprintf("%T undefined in varEqual", va)) 934 | } 935 | } 936 | 937 | var fuzzCorpus = [][]byte{ 938 | []byte("00\x0000"), 939 | []byte("\x90\xa7000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"), 940 | []byte("\xa2A00\x00\x06000000\x00\x06000000\x00\b00000000\x00\x06000000\x00\x06000000\x00\x06000000\x00\b000000000"), 941 | []byte("\x900000000000000000000"), 942 | []byte("\x100\x00\x0400000\xec00\x00\x0200\x00\x0200\x0000"), 943 | []byte("\x82000"), 944 | []byte("\x900000000000000000"), 945 | []byte("\x90000000000000000000"), 946 | []byte("\x100\x00\x0400000\x8000\x00\x0200\x0000"), 947 | []byte("\x100\x00\xbf00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"), 948 | []byte("\xa000"), 949 | []byte("20\x0000"), 950 | []byte("\x82000\x00\t0000000000\x00\x020000"), 951 | []byte("\x100\x00\x0400000$00\x00\x0200\x0000"), 952 | []byte("00\x0400000000000000000000000000000000000000000000000000000000000000"), 953 | []byte("\x900000"), 954 | []byte("\xa00"), 955 | []byte("0"), 956 | []byte(" 00"), 957 | []byte("0\xfe\xff\xff"), 958 | []byte("a0"), 959 | []byte("A0"), 960 | []byte("\x100\x0200"), 961 | []byte("\x100\x00\x0400000"), 962 | []byte("\x100\x00\x0400000$00\x0000"), 963 | []byte("\x100\x00\x04000000"), 964 | []byte("\xa2 00\x00000000000000000000000000000000000"), 965 | []byte("\xa2000\x00\x06000000\x00\x060000000"), 966 | []byte("00\x00\x000"), 967 | []byte("\x100\x00\x0400000\xec00\x00\x0200\x00\x0200\x00\x0200\x00\x0200\x00"), 968 | []byte("\x82000\x00\x0200000"), 969 | []byte("\x820"), 970 | []byte("\x9000000"), 971 | 
[]byte("0\xee\xff\xff"), 972 | []byte("0\x8e\x01\x00 00000000000000000000000000000000000000000"), 973 | []byte(" 0\x00"), 974 | []byte("0\xff\xc4"), 975 | []byte("\x90\xa700000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"), 976 | []byte("\x9000000000000000000000000000000000000"), 977 | []byte("\x82000\x000"), 978 | []byte(" 0"), 979 | []byte("\x100\x00\x0400000B"), 980 | []byte("\x82000\x0000"), 981 | []byte("\x90000"), 982 | []byte("000"), 983 | []byte("0\xff0\x00\t0000000000"), 984 | []byte("\x100\x00\x040000000000"), 985 | []byte("0\xbb0\x0100"), 986 | []byte("\x9000000000000"), 987 | []byte(""), 988 | []byte("\x100\x00\x04000001"), 989 | []byte("\x100\x00\x0400000$00\x05\xe20"), 990 | []byte("b0"), 991 | []byte("\xa2A00\x00\x06000000\x00\x06000000\x00\b00000000\x00\x06000000\x00\x06000000\x00\x06000000\x00\b00000000"), 992 | []byte("\x90B0000000000000000000000000000000000000000000000000000000000000000000"), 993 | []byte("\x100\x00\x0400000000\x0000"), 994 | []byte("\x900"), 995 | []byte("\x90\xa700000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"), 996 | []byte("\xa2A00\x00\x06000000\x00\x06000000\x00\b00000000\x00\x06000000\x00\x06000000\x00\x06000000\x00\b00000000\x0000"), 997 | []byte("00\x0400"), 998 | []byte("\x100"), 999 | } 1000 | -------------------------------------------------------------------------------- /rxtx.go: -------------------------------------------------------------------------------- 1 | package mqtt 2 | 3 | import ( 4 | "bytes" 5 | "encoding/binary" 6 | "errors" 7 | "io" 8 | ) 9 | 10 | // Rx implements a bare minimum MQTT v3.1.1 protocol transport layer handler. 
11 | // Packages are received by calling [Rx.ReadNextPacket] and setting the callback 12 | // in Rx corresponding to the expected packet. 13 | // Rx will perform basic validation of input data according to MQTT's specification. 14 | // If there is an error after reading the first byte of a packet the transport is closed 15 | // and a new transport must be set with [Rx.SetRxTransport]. 16 | // If OnRxError is set the underlying transport is not automatically closed and 17 | // it becomes the callback's responsibility to close the transport. 18 | // 19 | // Not safe for concurrent use. 20 | type Rx struct { 21 | // Transport over which packets are read and written to. 22 | // Not exported since RxTx type might be composed of embedded Rx and Tx types in future. TBD. 23 | rxTrp io.ReadCloser 24 | RxCallbacks RxCallbacks 25 | // User defined decoder for allocating packets. 26 | userDecoder Decoder 27 | // ScratchBuf is lazily allocated to exhaust Publish payloads when received and no 28 | // OnPub callback is set. 29 | ScratchBuf []byte 30 | // LastReceivedHeader contains the last correctly read header. 31 | LastReceivedHeader Header 32 | // LimitedReader field prevents a heap allocation in ReadNext since passing 33 | // a stack allocated LimitedReader into RxCallbacks.OnPub will escape inconditionally. 34 | packetLimitReader io.LimitedReader 35 | } 36 | 37 | // RxCallbacks groups all functionality executed on data receipt, both successful 38 | // and unsuccessful. 39 | type RxCallbacks struct { 40 | // Functions below can access the Header of the message via RxTx.LastReceivedHeader. 41 | // All these functions block RxTx.ReadNextPacket. 42 | OnConnect func(*Rx, *VariablesConnect) error // Receives pointer because of large struct! 43 | // OnConnack is called on a CONNACK packet receipt. 44 | OnConnack func(*Rx, VariablesConnack) error 45 | // OnPub is called on PUBLISH packet receive. 
The [io.Reader] points to the transport's reader 46 | // and is limited to read the amount of bytes in the payload as given by RemainingLength. 47 | // One may calculate amount of bytes in the reader like so: 48 | // payloadLen := rx.LastReceivedHeader.RemainingLength - varPub.Size() 49 | // It is important to note the reader `r` will be invalidated on the next incoming publish packet, 50 | // calling r after this point will result in undefined behaviour. 51 | OnPub func(rx *Rx, varPub VariablesPublish, r io.Reader) error 52 | // OnOther takes in the Header of received packet and a packet identifier uint16 if present. 53 | // OnOther receives PUBACK, PUBREC, PUBREL, PUBCOMP, UNSUBACK packets containing non-zero packet identfiers 54 | // and DISCONNECT, PINGREQ, PINGRESP packets with no packet identifier. 55 | OnOther func(rx *Rx, packetIdentifier uint16) error 56 | OnSub func(*Rx, VariablesSubscribe) error 57 | OnSuback func(*Rx, VariablesSuback) error 58 | OnUnsub func(*Rx, VariablesUnsubscribe) error 59 | // OnRxError is called if an error is encountered during decoding of packet. 60 | // If it is set then it becomes the responsibility of the callback to close the transport. 61 | OnRxError func(*Rx, error) 62 | } 63 | 64 | // SetRxTransport sets the rx's reader. 65 | func (rx *Rx) SetRxTransport(transport io.ReadCloser) { 66 | rx.rxTrp = transport 67 | } 68 | 69 | // Close closes the underlying transport. 70 | func (rx *Rx) CloseRx() error { return rx.rxTrp.Close() } 71 | func (rx *Rx) rxErrHandler(err error) { 72 | if rx.RxCallbacks.OnRxError != nil { 73 | rx.RxCallbacks.OnRxError(rx, err) 74 | } else { 75 | rx.CloseRx() 76 | } 77 | } 78 | 79 | // ReadNextPacket reads the next packet in the transport. If it fails after reading a 80 | // non-zero amount of bytes it closes the transport and the underlying transport must be reset. 
81 | func (rx *Rx) ReadNextPacket() (int, error) { 82 | if rx.rxTrp == nil { 83 | return 0, errors.New("nil transport") 84 | } 85 | rx.LastReceivedHeader = Header{} 86 | hdr, n, err := DecodeHeader(rx.rxTrp) 87 | if err != nil { 88 | if n > 0 { 89 | rx.rxErrHandler(err) 90 | } 91 | return n, err 92 | } 93 | rx.LastReceivedHeader = hdr 94 | var ( 95 | packetType = hdr.Type() 96 | ngot int 97 | packetIdentifier uint16 98 | ) 99 | switch packetType { 100 | case PacketPublish: 101 | packetFlags := hdr.Flags() 102 | qos := packetFlags.QoS() 103 | var vp VariablesPublish 104 | vp, ngot, err = rx.userDecoder.DecodePublish(rx.rxTrp, qos) 105 | n += ngot 106 | if err != nil { 107 | break 108 | } 109 | payloadLen := int(hdr.RemainingLength) - ngot 110 | rx.packetLimitReader = io.LimitedReader{R: rx.rxTrp, N: int64(payloadLen)} 111 | if rx.RxCallbacks.OnPub != nil { 112 | err = rx.RxCallbacks.OnPub(rx, vp, &rx.packetLimitReader) 113 | } else { 114 | err = rx.exhaustReader(&rx.packetLimitReader) 115 | } 116 | 117 | if rx.packetLimitReader.N != 0 && err == nil { 118 | err = errors.New("expected OnPub to completely read payload") 119 | break 120 | } 121 | 122 | case PacketConnack: 123 | if hdr.RemainingLength != 2 { 124 | err = ErrBadRemainingLen 125 | break 126 | } 127 | var vc VariablesConnack 128 | vc, ngot, err = decodeConnack(rx.rxTrp) 129 | n += ngot 130 | if err != nil { 131 | break 132 | } 133 | if rx.RxCallbacks.OnConnack != nil { 134 | err = rx.RxCallbacks.OnConnack(rx, vc) 135 | } 136 | 137 | case PacketConnect: 138 | // if hdr.RemainingLength != 0 { // TODO(soypat): What's the minimum RL for CONNECT? 
139 | // err = ErrBadRemainingLen 140 | // break 141 | // } 142 | var vc VariablesConnect 143 | vc, ngot, err = rx.userDecoder.DecodeConnect(rx.rxTrp) 144 | n += ngot 145 | if err != nil { 146 | break 147 | } 148 | if rx.RxCallbacks.OnConnect != nil { 149 | err = rx.RxCallbacks.OnConnect(rx, &vc) 150 | } 151 | 152 | case PacketSuback: 153 | if hdr.RemainingLength < 2 { 154 | err = ErrBadRemainingLen 155 | break 156 | } 157 | var vsbck VariablesSuback 158 | vsbck, ngot, err = decodeSuback(rx.rxTrp, hdr.RemainingLength) 159 | n += ngot 160 | if err != nil { 161 | break 162 | } 163 | if rx.RxCallbacks.OnSuback != nil { 164 | err = rx.RxCallbacks.OnSuback(rx, vsbck) 165 | } 166 | 167 | case PacketSubscribe: 168 | var vsbck VariablesSubscribe 169 | vsbck, ngot, err = rx.userDecoder.DecodeSubscribe(rx.rxTrp, hdr.RemainingLength) 170 | n += ngot 171 | if err != nil { 172 | break 173 | } 174 | if rx.RxCallbacks.OnSub != nil { 175 | err = rx.RxCallbacks.OnSub(rx, vsbck) 176 | } 177 | 178 | case PacketUnsubscribe: 179 | var vunsub VariablesUnsubscribe 180 | vunsub, ngot, err = rx.userDecoder.DecodeUnsubscribe(rx.rxTrp, hdr.RemainingLength) 181 | n += ngot 182 | if err != nil { 183 | break 184 | } 185 | if rx.RxCallbacks.OnUnsub != nil { 186 | err = rx.RxCallbacks.OnUnsub(rx, vunsub) 187 | } 188 | 189 | case PacketPuback, PacketPubrec, PacketPubrel, PacketPubcomp, PacketUnsuback: 190 | if hdr.RemainingLength != 2 { 191 | err = ErrBadRemainingLen 192 | break 193 | } 194 | // Only PI, no payload. 195 | packetIdentifier, ngot, err = decodeUint16(rx.rxTrp) 196 | n += ngot 197 | if err != nil { 198 | break 199 | } 200 | if rx.RxCallbacks.OnOther != nil { 201 | err = rx.RxCallbacks.OnOther(rx, packetIdentifier) 202 | } 203 | 204 | case PacketDisconnect, PacketPingreq, PacketPingresp: 205 | if hdr.RemainingLength != 0 { 206 | err = ErrBadRemainingLen 207 | break 208 | } 209 | // No payload or variable header. 
210 | if rx.RxCallbacks.OnOther != nil { 211 | err = rx.RxCallbacks.OnOther(rx, packetIdentifier) 212 | } 213 | 214 | default: 215 | // Header Decode should return an error on incorrect packet type receive. 216 | // This could be tested via fuzzing. 217 | panic("unreachable") 218 | } 219 | 220 | if err != nil { 221 | rx.rxErrHandler(err) 222 | } 223 | return n, err 224 | } 225 | 226 | // RxTransport returns the underlying transport handler. It may be nil. 227 | func (rx *Rx) RxTransport() io.ReadCloser { 228 | return rx.rxTrp 229 | } 230 | 231 | // ShallowCopy shallow copies rx and underlying transport and decoder. Does not copy callbacks over. 232 | func (rx *Rx) ShallowCopy() *Rx { 233 | return &Rx{rxTrp: rx.rxTrp, userDecoder: rx.userDecoder} 234 | } 235 | 236 | func (rx *Rx) exhaustReader(r io.Reader) (err error) { 237 | if len(rx.ScratchBuf) == 0 { 238 | rx.ScratchBuf = make([]byte, 1024) // Lazy initialization when needed. 239 | } 240 | for err == nil { 241 | _, err = r.Read(rx.ScratchBuf[:]) 242 | } 243 | if errors.Is(err, io.EOF) { 244 | return nil 245 | } 246 | return err 247 | } 248 | 249 | // Tx implements a bare minimum MQTT v3.1.1 protocol transport layer handler for transmitting packets. 250 | // If there is an error during read/write of a packet the transport is closed 251 | // and a new transport must be set with [Tx.SetTxTransport]. 252 | // A Tx will not validate data before encoding, that is up to the caller, Malformed packets 253 | // will be rejected and the connection will be closed immediately. If OnTxError is 254 | // set then the underlying transport is not closed and it becomes responsibility 255 | // of the callback to close the transport. 256 | type Tx struct { 257 | txTrp io.WriteCloser 258 | TxCallbacks TxCallbacks 259 | buffer bytes.Buffer 260 | } 261 | 262 | // TxCallbacks groups functionality executed on transmission success or failure 263 | // of an MQTT packet. 
264 | type TxCallbacks struct { 265 | // OnTxError is called if an error is encountered during encoding. If it is set 266 | // then it becomes the responsibility of the callback to close Tx's transport. 267 | OnTxError func(*Tx, error) 268 | // OnSuccessfulTx is called after a MQTT packet is fully written to the underlying transport. 269 | OnSuccessfulTx func(*Tx) 270 | } 271 | 272 | // TxTransport returns the underlying transport handler. It may be nil. 273 | func (tx *Tx) TxTransport() io.WriteCloser { 274 | return tx.txTrp 275 | } 276 | 277 | // SetTxTransport sets the tx's writer. 278 | func (tx *Tx) SetTxTransport(transport io.WriteCloser) { 279 | tx.txTrp = transport 280 | } 281 | 282 | // WriteConnack writes a CONNECT packet over the transport. 283 | func (tx *Tx) WriteConnect(varConn *VariablesConnect) error { 284 | if tx.txTrp == nil { 285 | return errors.New("nil transport") 286 | } 287 | buffer := &tx.buffer 288 | buffer.Reset() 289 | h := newHeader(PacketConnect, 0, uint32(varConn.Size())) 290 | _, err := h.Encode(buffer) 291 | if err != nil { 292 | return err 293 | } 294 | _, err = encodeConnect(buffer, varConn) 295 | if err != nil { 296 | return err 297 | } 298 | n, err := buffer.WriteTo(tx.txTrp) 299 | if err != nil && n > 0 { 300 | tx.prepClose(err) 301 | } else if tx.TxCallbacks.OnSuccessfulTx != nil && err == nil { 302 | tx.TxCallbacks.OnSuccessfulTx(tx) 303 | } 304 | return err 305 | } 306 | 307 | // WriteConnack writes a CONNACK packet over the transport. 
308 | func (tx *Tx) WriteConnack(varConnack VariablesConnack) error { 309 | if tx.txTrp == nil { 310 | return errors.New("nil transport") 311 | } 312 | buffer := &tx.buffer 313 | buffer.Reset() 314 | h := newHeader(PacketConnack, 0, uint32(varConnack.Size())) 315 | _, err := h.Encode(buffer) 316 | if err != nil { 317 | return err 318 | } 319 | _, err = encodeConnack(buffer, varConnack) 320 | if err != nil { 321 | return err 322 | } 323 | n, err := buffer.WriteTo(tx.txTrp) 324 | if err != nil && n > 0 { 325 | tx.prepClose(err) 326 | } else if tx.TxCallbacks.OnSuccessfulTx != nil && err == nil { 327 | tx.TxCallbacks.OnSuccessfulTx(tx) 328 | } 329 | return err 330 | } 331 | 332 | // WritePublishPayload writes a PUBLISH packet over the transport along with the 333 | // Application Message in the payload. payload can be zero-length. 334 | func (tx *Tx) WritePublishPayload(h Header, varPub VariablesPublish, payload []byte) error { 335 | if tx.txTrp == nil { 336 | return errors.New("nil transport") 337 | } 338 | buffer := &tx.buffer 339 | buffer.Reset() 340 | qos := h.Flags().QoS() 341 | h.RemainingLength = uint32(varPub.Size(qos) + len(payload)) 342 | _, err := h.Encode(buffer) 343 | if err != nil { 344 | return err 345 | } 346 | _, err = encodePublish(buffer, qos, varPub) 347 | if err != nil { 348 | return err 349 | } 350 | _, err = writeFull(buffer, payload) 351 | if err != nil { 352 | return err 353 | } 354 | n, err := buffer.WriteTo(tx.txTrp) 355 | if err != nil && n > 0 { 356 | tx.prepClose(err) 357 | } else if tx.TxCallbacks.OnSuccessfulTx != nil && err == nil { 358 | tx.TxCallbacks.OnSuccessfulTx(tx) 359 | } 360 | return err 361 | } 362 | 363 | // WriteSubscribe writes an SUBSCRIBE packet over the transport. 
364 | func (tx *Tx) WriteSubscribe(varSub VariablesSubscribe) error { 365 | if tx.txTrp == nil { 366 | return errors.New("nil transport") 367 | } 368 | buffer := &tx.buffer 369 | buffer.Reset() 370 | h := newHeader(PacketSubscribe, PacketFlagsPubrelSubUnsub, uint32(varSub.Size())) 371 | _, err := h.Encode(buffer) 372 | if err != nil { 373 | return err 374 | } 375 | _, err = encodeSubscribe(buffer, varSub) 376 | if err != nil { 377 | return err 378 | } 379 | n, err := buffer.WriteTo(tx.txTrp) 380 | if err != nil && n > 0 { 381 | tx.prepClose(err) 382 | } else if tx.TxCallbacks.OnSuccessfulTx != nil && err == nil { 383 | tx.TxCallbacks.OnSuccessfulTx(tx) 384 | } 385 | return err 386 | } 387 | 388 | // WriteSuback writes an UNSUBACK packet over the transport. 389 | func (tx *Tx) WriteSuback(varSub VariablesSuback) error { 390 | if tx.txTrp == nil { 391 | return errors.New("nil transport") 392 | } 393 | if err := varSub.Validate(); err != nil { 394 | return err 395 | } 396 | buffer := &tx.buffer 397 | buffer.Reset() 398 | h := newHeader(PacketSuback, 0, uint32(varSub.Size())) 399 | _, err := h.Encode(buffer) 400 | if err != nil { 401 | return err 402 | } 403 | _, err = encodeSuback(buffer, varSub) 404 | if err != nil { 405 | return err 406 | } 407 | n, err := buffer.WriteTo(tx.txTrp) 408 | if err != nil && n > 0 { 409 | tx.prepClose(err) 410 | } else if tx.TxCallbacks.OnSuccessfulTx != nil && err == nil { 411 | tx.TxCallbacks.OnSuccessfulTx(tx) 412 | } 413 | return err 414 | } 415 | 416 | // WriteUnsubscribe writes an UNSUBSCRIBE packet over the transport. 
417 | func (tx *Tx) WriteUnsubscribe(varUnsub VariablesUnsubscribe) error { 418 | if tx.txTrp == nil { 419 | return errors.New("nil transport") 420 | } 421 | buffer := &tx.buffer 422 | buffer.Reset() 423 | h := newHeader(PacketUnsubscribe, PacketFlagsPubrelSubUnsub, uint32(varUnsub.Size())) 424 | _, err := h.Encode(buffer) 425 | if err != nil { 426 | return err 427 | } 428 | _, err = encodeUnsubscribe(buffer, varUnsub) 429 | if err != nil { 430 | return err 431 | } 432 | n, err := buffer.WriteTo(tx.txTrp) 433 | if err != nil && n > 0 { 434 | tx.prepClose(err) 435 | } else if tx.TxCallbacks.OnSuccessfulTx != nil && err == nil { 436 | tx.TxCallbacks.OnSuccessfulTx(tx) 437 | } 438 | return err 439 | } 440 | 441 | // WriteIdentified writes PUBACK, PUBREC, PUBREL, PUBCOMP, UNSUBACK packets containing non-zero packet identfiers 442 | // It automatically sets the RemainingLength field to 2. 443 | func (tx *Tx) WriteIdentified(packetType PacketType, packetIdentifier uint16) (err error) { 444 | if tx.txTrp == nil { 445 | return errors.New("nil transport") 446 | } 447 | if packetIdentifier == 0 { 448 | return errGotZeroPI 449 | } 450 | // This packet has special QoS1 flag. 
451 | isPubrelSubUnsub := packetType == PacketPubrel 452 | if !(isPubrelSubUnsub || packetType == PacketPuback || packetType == PacketPubrec || 453 | packetType == PacketPubcomp || packetType == PacketUnsuback) { 454 | return errors.New("expected a packet type from PUBACK|PUBREL|PUBCOMP|UNSUBACK") 455 | } 456 | 457 | var buf [5 + 2]byte 458 | n := newHeader(packetType, PacketFlags(b2u8(isPubrelSubUnsub)<<1), 2).Put(buf[:]) 459 | binary.BigEndian.PutUint16(buf[n:], packetIdentifier) 460 | n, err = writeFull(tx.txTrp, buf[:n+2]) 461 | 462 | if err != nil && n > 0 { 463 | tx.prepClose(err) 464 | } else if tx.TxCallbacks.OnSuccessfulTx != nil && err == nil { 465 | tx.TxCallbacks.OnSuccessfulTx(tx) 466 | } 467 | return err 468 | } 469 | 470 | // WriteSimple facilitates easy sending of the 2 octet DISCONNECT, PINGREQ, PINGRESP packets. 471 | // If the packet is not one of these then an error is returned. 472 | // It also returns an error with encoding step if there was one. 473 | func (tx *Tx) WriteSimple(packetType PacketType) (err error) { 474 | if tx.txTrp == nil { 475 | return errors.New("nil transport") 476 | } 477 | isValid := packetType == PacketDisconnect || packetType == PacketPingreq || packetType == PacketPingresp 478 | if !isValid { 479 | return errors.New("expected packet type from PINGREQ|PINGRESP|DISCONNECT") 480 | } 481 | n, err := newHeader(packetType, 0, 0).Encode(tx.txTrp) 482 | if err != nil && n > 0 { 483 | tx.prepClose(err) 484 | } else if tx.TxCallbacks.OnSuccessfulTx != nil && err == nil { 485 | tx.TxCallbacks.OnSuccessfulTx(tx) 486 | } 487 | return err 488 | } 489 | 490 | // Close closes the underlying tranport and returns an error if any. 
491 | func (tx *Tx) CloseTx() error { return tx.txTrp.Close() } 492 | 493 | func (tx *Tx) prepClose(err error) { 494 | if tx.TxCallbacks.OnTxError != nil { 495 | tx.TxCallbacks.OnTxError(tx, err) 496 | } else { 497 | tx.txTrp.Close() 498 | } 499 | } 500 | 501 | // ShallowCopy shallow copies rx and underlying transport and encoder. Does not copy callbacks over. 502 | func (tx *Tx) ShallowCopy() *Tx { 503 | return &Tx{txTrp: tx.txTrp} 504 | } 505 | -------------------------------------------------------------------------------- /subscriptions.go: -------------------------------------------------------------------------------- 1 | package mqtt 2 | 3 | import ( 4 | "errors" 5 | "strings" 6 | ) 7 | 8 | // Subscriptions is a WIP. 9 | 10 | // subscriptions provides clients and servers with a way to manage requested 11 | // topic published messages. subscriptions is an abstraction over state, not 12 | // input/output operations, so calls to Subscribe should not write bytes over a transport. 13 | type subscriptions interface { 14 | // Subscribe takes a []byte slice to make it explicit and abundantly clear that 15 | // Subscriptions is in charge of the memory corresponding to subscription topics. 16 | // This is to say that Subscriptions should copy topic contents into its own memory 17 | // storage mechanism or allocate the topic on the heap. 18 | Subscribe(topic []byte) error 19 | 20 | // Successfully matched topics are stored in the userBuffer and returned 21 | // as a slice of byte slices. 22 | 23 | // Match finds all subscribers to a topic or a filter. 24 | Match(topicFilter string, userBuffer []byte) ([][]byte, error) 25 | 26 | Unsubscribe(topicFilter string, userBuffer []byte) ([][]byte, error) 27 | } 28 | 29 | // TODO(soypat): Add AVL tree implementation like the one in github.com/soypat/go-canard, supposedly is best data structure for this [citation needed]. 
30 | 31 | var _ subscriptions = subscriptionsMap{} 32 | 33 | // subscriptionsMap implements Subscriptions interface with a map. 34 | // It performs allocations. 35 | type subscriptionsMap map[string]struct{} 36 | 37 | func (sm subscriptionsMap) Subscribe(topic []byte) error { 38 | tp := string(topic) 39 | if _, ok := sm[tp]; ok { 40 | return errors.New("topic already exists in subscriptions") 41 | } 42 | sm[tp] = struct{}{} 43 | return nil 44 | } 45 | 46 | func (sm subscriptionsMap) Unsubscribe(topicFilter string, userBuffer []byte) (matched [][]byte, err error) { 47 | return sm.match(topicFilter, userBuffer, true) 48 | } 49 | 50 | func (sm subscriptionsMap) Match(topicFilter string, userBuffer []byte) (matched [][]byte, err error) { 51 | return sm.match(topicFilter, userBuffer, false) 52 | } 53 | 54 | func (sm subscriptionsMap) match(topicFilter string, userBuffer []byte, deleteMatches bool) (matched [][]byte, err error) { 55 | n := 0 // Bytes copied into userBuffer. 56 | filterParts := strings.Split(topicFilter, "/") 57 | if err := validateWildcards(filterParts); err != nil { 58 | return nil, err 59 | } 60 | 61 | _, hasNonWildSub := sm[topicFilter] 62 | if hasNonWildSub { 63 | if len(topicFilter) > len(userBuffer) { 64 | return nil, ErrUserBufferFull 65 | } 66 | n += copy(userBuffer, topicFilter) 67 | matched = append(matched, userBuffer[:n]) 68 | userBuffer = userBuffer[n:] 69 | if deleteMatches { 70 | delete(sm, topicFilter) 71 | } 72 | } 73 | 74 | for k := range sm { 75 | parts := strings.Split(k, "/") 76 | if matches(filterParts, parts) { 77 | if len(k) > len(userBuffer) { 78 | return matched, ErrUserBufferFull 79 | } 80 | n += copy(userBuffer, k) 81 | matched = append(matched, userBuffer[:n]) 82 | userBuffer = userBuffer[n:] 83 | if deleteMatches { 84 | delete(sm, k) 85 | } 86 | } 87 | } 88 | return matched, nil 89 | } 90 | 91 | func matches(filter, topicParts []string) bool { 92 | i := 0 93 | for i < len(topicParts) { 94 | // topic is longer, no match 95 | 
if i >= len(filter) { 96 | return false 97 | } 98 | // matched up to here, and now the wildcard says "all others will match" 99 | if filter[i] == "#" { 100 | return true 101 | } 102 | // text does not match, and there wasn't a + to excuse it 103 | if topicParts[i] != filter[i] && filter[i] != "+" { 104 | return false 105 | } 106 | i++ 107 | } 108 | 109 | // make finance/stock/ibm/# match finance/stock/ibm 110 | return i == len(filter)-1 && filter[len(filter)-1] == "#" || i == len(filter) 111 | } 112 | 113 | func isWildcard(topic string) bool { 114 | return strings.IndexByte(topic, '#') >= 0 || strings.IndexByte(topic, '+') >= 0 115 | } 116 | 117 | func validateWildcards(wildcards []string) error { 118 | for i, part := range wildcards { 119 | // catch things like finance# 120 | if isWildcard(part) && len(part) != 1 { 121 | return errors.New("malformed wildcard of style \"finance#\"") 122 | } 123 | isSingle := len(part) == 1 && part[0] == '#' 124 | // # can only occur as the last part 125 | if isSingle && i != len(wildcards)-1 { 126 | return errors.New("last wildcard is single \"#\"") 127 | } 128 | } 129 | return nil 130 | } 131 | -------------------------------------------------------------------------------- /travis.yml: -------------------------------------------------------------------------------- 1 | language: go 2 | 3 | go: 4 | - 1.19.x 5 | - tip 6 | 7 | before_install: 8 | - go mod tidy 9 | 10 | script: 11 | - go test -coverprofile=coverage.txt -covermode=atomic 12 | 13 | after_success: 14 | - bash <(curl -s https://codecov.io/bash) --------------------------------------------------------------------------------