├── go.mod ├── integration ├── README.md ├── docker-compose.yml └── integration_test.go ├── LICENSE ├── .github └── workflows │ └── go.yml ├── mqtttest ├── mqtttest_test.go └── mqtttest.go ├── README.md ├── example_test.go ├── mqtt_test.go ├── cmd └── mqttc │ └── main.go ├── mqtt.go ├── client_test.go ├── request.go ├── request_test.go └── client.go /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/pascaldekloe/mqtt 2 | 3 | go 1.21 4 | -------------------------------------------------------------------------------- /integration/README.md: -------------------------------------------------------------------------------- 1 | # Integration 2 | 3 | ## Testing 4 | 5 | Run `docker-compose up --exit-code-from test` to validate the client against 6 | various MQTT implementations. 7 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | To the extent possible under law, Pascal S. de Kloe has waived all 2 | copyright and related or neighboring rights to MQTT🤖. This work is 3 | published from The Netherlands. 4 | 5 | https://creativecommons.org/publicdomain/zero/1.0/legalcode 6 | -------------------------------------------------------------------------------- /.github/workflows/go.yml: -------------------------------------------------------------------------------- 1 | name: Go 2 | 3 | on: 4 | push: 5 | branches: [ "master" ] 6 | pull_request: 7 | branches: [ "master" ] 8 | 9 | jobs: 10 | 11 | build: 12 | runs-on: ubuntu-latest 13 | steps: 14 | - uses: actions/checkout@v3 15 | 16 | - name: Set up Go 17 | uses: actions/setup-go@v3 18 | with: 19 | go-version: '1.21' 20 | 21 | - name: Build 22 | run: go build -v ./cmd/... 23 | 24 | - name: Test 25 | run: go test -v ./... 
26 | 27 | - name: Integration Test 28 | run: cd integration && docker compose up --exit-code-from test 29 | -------------------------------------------------------------------------------- /integration/docker-compose.yml: -------------------------------------------------------------------------------- 1 | version: "3" 2 | services: 3 | test: 4 | image: golang:1.21 5 | mem_limit: 500m 6 | volumes: 7 | - ..:/code 8 | command: go test -v -race 9 | working_dir: /code/integration 10 | depends_on: 11 | activemq: 12 | condition: service_healthy 13 | environment: 14 | - MQTT_HOSTS=activemq emqx hivemq mosquitto rumqttd vernemq 15 | 16 | activemq: 17 | image: apache/activemq-artemis:2.39.0 18 | mem_limit: 500m 19 | # TCP "connection refused" error before HTTP service 20 | healthcheck: 21 | test: ["CMD", "curl", "http://localhost:8161"] 22 | interval: 2s 23 | timeout: 1s 24 | retries: 10 25 | 26 | emqx: 27 | image: emqx/emqx:5.7.2 28 | mem_limit: 500m 29 | hivemq: 30 | image: hivemq/hivemq-ce:2024.7 31 | mem_limit: 2000m 32 | mosquitto: 33 | image: eclipse-mosquitto:1.6.15 34 | mem_limit: 500m 35 | rumqttd: 36 | image: bytebeamio/rumqttd:0.19.0 37 | mem_limit: 100m 38 | vernemq: 39 | image: vernemq/vernemq:1.13.0 40 | mem_limit: 1000m 41 | environment: 42 | - DOCKER_VERNEMQ_ACCEPT_EULA=yes 43 | - DOCKER_VERNEMQ_ALLOW_ANONYMOUS=on 44 | -------------------------------------------------------------------------------- /mqtttest/mqtttest_test.go: -------------------------------------------------------------------------------- 1 | package mqtttest_test 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | "testing" 7 | "time" 8 | 9 | "github.com/pascaldekloe/mqtt" 10 | "github.com/pascaldekloe/mqtt/mqtttest" 11 | ) 12 | 13 | // Signatures 14 | var ( 15 | client mqtt.Client 16 | subscribe = client.Subscribe 17 | unsubscribe = client.Unsubscribe 18 | publish = client.Publish 19 | publishEnqueued = client.PublishAtLeastOnce 20 | readSlices = client.ReadSlices 21 | ) 22 | 23 | // Won't compile 
on failure. 24 | func TestSignatureMatch(t *testing.T) { 25 | var c mqtt.Client 26 | // check dupe assumptions 27 | subscribe = c.SubscribeLimitAtMostOnce 28 | subscribe = c.SubscribeLimitAtLeastOnce 29 | publishEnqueued = c.PublishExactlyOnce 30 | 31 | // check fits 32 | readSlices = mqtttest.NewReadSlicesStub(mqtttest.Transfer{}) 33 | readSlices = mqtttest.NewReadSlicesMock(t) 34 | publish = mqtttest.NewPublishMock(t) 35 | publish = mqtttest.NewPublishStub(nil) 36 | publishEnqueued = mqtttest.NewPublishExchangeStub(nil) 37 | subscribe = mqtttest.NewSubscribeMock(t) 38 | subscribe = mqtttest.NewSubscribeStub(nil) 39 | unsubscribe = mqtttest.NewUnsubscribeMock(t) 40 | unsubscribe = mqtttest.NewUnsubscribeStub(nil) 41 | } 42 | 43 | func ExampleNewPublishExchangeStub() { 44 | PublishExchange := mqtttest.NewPublishExchangeStub(nil, 45 | mqtttest.ExchangeBlock{Delay: time.Millisecond}, 46 | errors.New("test storage failure"), 47 | ) 48 | 49 | exchange, err := PublishExchange([]byte("Hi!"), "announce") 50 | if err != nil { 51 | fmt.Println("publish error:", err) 52 | return 53 | } 54 | for err := range exchange { 55 | fmt.Println("exchange error:", err) 56 | } 57 | // Output: 58 | // exchange error: test storage failure 59 | } 60 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MQTT🤖 2 | 3 | ## About 4 | 5 | MQTT is a protocol for message queueing over a network. This project provides a 6 | client library for the Go programming language. Message-delivery guarantees are 7 | maintained at all costs, even when in error. The client recovers from failure 8 | with automatic reconnects. Message transfers in both directions do zero-copy. 9 | 10 | The development was kindly sponsored by [Northvolt](https://northvolt.com) as a 11 | gift to the open-source community. Many 🤖 navigated over unreliable connections 12 | since then. 
13 | 14 | This is free and unencumbered software released into the 15 | [public domain](https://creativecommons.org/publicdomain/zero/1.0). 16 | 17 | [![Go Reference](https://pkg.go.dev/badge/github.com/pascaldekloe/mqtt.svg)](https://pkg.go.dev/github.com/pascaldekloe/mqtt) 18 | [![Build Status](https://github.com/pascaldekloe/mqtt/actions/workflows/go.yml/badge.svg)](https://github.com/pascaldekloe/mqtt/actions/workflows/go.yml) 19 | 20 | 21 | ## Usage 22 | 23 | Client instantiation validates the configuration only. Network management itself 24 | operates from the *read routine*. 25 | 26 | ```go 27 | client, err := mqtt.VolatileSession("demo-client", &mqtt.Config{ 28 | Dialer: mqtt.NewDialer("tcp", "mq1.example.com:1883"), 29 | PauseTimeout: 4 * time.Second, 30 | CleanSession: true, 31 | }) 32 | if err != nil { 33 | log.Fatal("exit on broken setup: ", err) 34 | } 35 | ``` 36 | 37 | A *read routine* sees inbound messages from any of the subscribed topics. 38 | 39 | ```go 40 | for { 41 | message, topic, err := client.ReadSlices() 42 | if err != nil { 43 | log.Print(err) 44 | <-client.ReadBackoff(err) 45 | continue 46 | } 47 | 48 | r, _ := utf8.DecodeLastRune(message) 49 | switch r { 50 | case '℃', '℉': 51 | log.Printf("%q at %q", message, topic) 52 | } 53 | } 54 | ``` 55 | 56 | The client supports confirmed message delivery with full progress disclosure. 57 | Message transfers without confirmation can be as simple as the following. 58 | 59 | ```go 60 | err := client.Publish(ctx.Done(), []byte("20.8℃"), "bedroom") 61 | if err != nil { 62 | log.Print("thermostat update lost: ", err) 63 | } 64 | ``` 65 | 66 | See the [examples](https://pkg.go.dev/github.com/pascaldekloe/mqtt#pkg-examples) 67 | from the package documentation for more detail. 68 | 69 | 70 | ## Command-Line Client 71 | 72 | Run `go install github.com/pascaldekloe/mqtt/cmd/mqttc@latest` to build the 73 | binary. 
74 | 75 | ``` 76 | NAME 77 | mqttc — MQTT broker access 78 | 79 | SYNOPSIS 80 | mqttc [options] address 81 | 82 | DESCRIPTION 83 | The command connects to the address argument, with an option to 84 | publish a message and/or subscribe with topic filters. 85 | 86 | When the address does not specify a port, then the defaults are 87 | applied, which is 1883 for plain connections and 8883 for TLS. 88 | 89 | OPTIONS 90 | -ca file 91 | Amend the trusted certificate authorities with a PEM file. 92 | -cert file 93 | Use a client certificate from a PEM file (with a corresponding 94 | -key option). 95 | -client identifier 96 | Use a specific client identifier. (default "generated") 97 | -key file 98 | Use a private key (matching the client certificate) from a PEM 99 | file. 100 | -net name 101 | Select the network by name. Valid alternatives include tcp4, 102 | tcp6 and unix. (default "tcp") 103 | -pass file 104 | The file content is used as a password. 105 | -prefix string 106 | Print a string before each inbound message. 107 | -publish topic 108 | Send a message to a topic. The payload is read from standard 109 | input. 110 | -quiet 111 | Suppress all output to standard error. Error reporting is 112 | deduced to the exit code only. 113 | -quote 114 | Print inbound topics and messages as quoted strings. 115 | -server name 116 | Use a specific server name with TLS 117 | -subscribe filter 118 | Listen with a topic filter. Inbound messages are printed to 119 | standard output until interrupted by a signal(3). Multiple 120 | -subscribe options may be applied together. 121 | -suffix string 122 | Print a string after each inbound message. (default "\n") 123 | -timeout duration 124 | Network operation expiry. (default 4s) 125 | -tls 126 | Secure the connection with TLS. 127 | -topic 128 | Print the respective topic of each inbound message. 129 | -user name 130 | The user name may be used by the broker for authentication 131 | and/or authorization purposes. 
132 | -verbose 133 | Produces more output to standard error for debug purposes. 134 | 135 | EXIT STATUS 136 | (0) no error 137 | (1) MQTT operational error 138 | (2) illegal command invocation 139 | (5) connection refused: unacceptable protocol version 140 | (6) connection refused: identifier rejected 141 | (7) connection refused: server unavailable 142 | (8) connection refused: bad username or password 143 | (9) connection refused: not authorized 144 | (130) close on SIGINT 145 | (143) disconnect on SIGTERM 146 | 147 | EXAMPLES 148 | Send a message: 149 | 150 | echo "hello" | mqttc -publish chat/misc localhost 151 | 152 | Print messages: 153 | 154 | mqttc -subscribe "news/#" -prefix "📥 " :1883 155 | 156 | Health check: 157 | 158 | mqttc -tls q1.example.com:8883 || echo "exit $?" 159 | 160 | BUGS 161 | Report bugs at . 162 | 163 | SEE ALSO 164 | mosquitto_pub(1) 165 | ``` 166 | 167 | 168 | ## Standard Compliance 169 | 170 | The implementation follows version 3.1.1 of the 171 | [OASIS specification](http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html) 172 | in a strict manner. Support for the original 173 | [IBM-specification](https://public.dhe.ibm.com/software/dw/webservices/ws-mqtt/mqtt-v3r1.html) 174 | may be added at some point in time. 175 | 176 | There are no plans to support protocol version 5. Version 3 is lean and well 177 | suited for IOT. The additions in version 5 may be more of a fit for backend 178 | computing. 179 | 180 | See the [Broker wiki](https://github.com/pascaldekloe/mqtt/wiki/Brokers) for 181 | implementation specifics. The most notable offender is AWS IoT. Additions to 182 | `integration/docker-compose.yml` are welcome. 
183 | -------------------------------------------------------------------------------- /example_test.go: -------------------------------------------------------------------------------- 1 | package mqtt_test 2 | 3 | import ( 4 | "context" 5 | "crypto/tls" 6 | "errors" 7 | "fmt" 8 | "log" 9 | "time" 10 | 11 | "github.com/pascaldekloe/mqtt" 12 | "github.com/pascaldekloe/mqtt/mqtttest" 13 | ) 14 | 15 | // Publish is a method from mqtt.Client. 16 | var Publish func(quit <-chan struct{}, message []byte, topic string) error 17 | 18 | // PublishAtLeastOnce is a method from mqtt.Client. 19 | var PublishAtLeastOnce func(message []byte, topic string) (ack <-chan error, err error) 20 | 21 | // Subscribe is a method from mqtt.Client. 22 | var Subscribe func(quit <-chan struct{}, topicFilters ...string) error 23 | 24 | // Online is a method from mqtt.Client. 25 | var Online func() <-chan struct{} 26 | 27 | func init() { 28 | PublishAtLeastOnce = mqtttest.NewPublishExchangeStub(nil) 29 | Subscribe = mqtttest.NewSubscribeStub(nil) 30 | Online = func() <-chan struct{} { return nil } 31 | } 32 | 33 | // Some brokers permit authentication with TLS client certificates. 
34 | func ExampleNewTLSDialer_clientCertificate() { 35 | certPEM := []byte(`-----BEGIN CERTIFICATE----- 36 | MIIBhTCCASugAwIBAgIQIRi6zePL6mKjOipn+dNuaTAKBggqhkjOPQQDAjASMRAw 37 | DgYDVQQKEwdBY21lIENvMB4XDTE3MTAyMDE5NDMwNloXDTE4MTAyMDE5NDMwNlow 38 | EjEQMA4GA1UEChMHQWNtZSBDbzBZMBMGByqGSM49AgEGCCqGSM49AwEHA0IABD0d 39 | 7VNhbWvZLWPuj/RtHFjvtJBEwOkhbN/BnnE8rnZR8+sbwnc/KhCk3FhnpHZnQz7B 40 | 5aETbbIgmuvewdjvSBSjYzBhMA4GA1UdDwEB/wQEAwICpDATBgNVHSUEDDAKBggr 41 | BgEFBQcDATAPBgNVHRMBAf8EBTADAQH/MCkGA1UdEQQiMCCCDmxvY2FsaG9zdDo1 42 | NDUzgg4xMjcuMC4wLjE6NTQ1MzAKBggqhkjOPQQDAgNIADBFAiEA2zpJEPQyz6/l 43 | Wf86aX6PepsntZv2GYlA5UpabfT2EZICICpJ5h/iI+i341gBmLiAFQOyTDT+/wQc 44 | 6MF9+Yw1Yy0t 45 | -----END CERTIFICATE-----`) 46 | keyPEM := []byte(`-----BEGIN EC PRIVATE KEY----- 47 | MHcCAQEEIIrYSSNQFaA2Hwf1duRSxKtLYX5CB04fSeQ6tF1aY/PuoAoGCCqGSM49 48 | AwEHoUQDQgAEPR3tU2Fta9ktY+6P9G0cWO+0kETA6SFs38GecTyudlHz6xvCdz8q 49 | EKTcWGekdmdDPsHloRNtsiCa697B2O9IFA== 50 | -----END EC PRIVATE KEY-----`) 51 | 52 | cert, err := tls.X509KeyPair(certPEM, keyPEM) 53 | if err != nil { 54 | log.Fatal(err) 55 | } 56 | mqtt.NewTLSDialer("tcp", "mq1.example.com:8883", &tls.Config{ 57 | Certificates: []tls.Certificate{cert}, 58 | }) 59 | // Output: 60 | } 61 | 62 | // It is good practice to install the client from main. 
63 | func ExampleClient_setup() { 64 | client, err := mqtt.VolatileSession("demo-client", &mqtt.Config{ 65 | Dialer: mqtt.NewDialer("tcp", "localhost:1883"), 66 | PauseTimeout: 4 * time.Second, 67 | }) 68 | if err != nil { 69 | log.Fatal("exit on broken setup: ", err) 70 | } 71 | 72 | // launch read-routine 73 | go func() { 74 | var big *mqtt.BigMessage 75 | for { 76 | message, topic, err := client.ReadSlices() 77 | switch { 78 | case err == nil: 79 | // do something with inbound message 80 | log.Printf("📥 %q: %q", topic, message) 81 | 82 | case errors.As(err, &big): 83 | log.Printf("📥 %q: %d byte message omitted", 84 | big.Topic, big.Size) 85 | 86 | default: 87 | log.Print(err) 88 | wait := client.ReadBackoff(err) 89 | if wait == nil { 90 | return // terminated 91 | } 92 | <-wait 93 | } 94 | } 95 | }() 96 | 97 | // Install each method in use as a package variable. Such setup is 98 | // compatible with the tools from the mqtttest subpackage. 99 | Publish = client.Publish 100 | // Output: 101 | } 102 | 103 | // Demonstrates all error scenario and the respective recovery options. 104 | func ExampleClient_PublishAtLeastOnce_critical() { 105 | for { 106 | exchange, err := PublishAtLeastOnce([]byte("🍸🆘"), "demo/alert") 107 | switch { 108 | case err == nil: 109 | fmt.Println("alert submitted…") 110 | break 111 | 112 | case mqtt.IsDeny(err), errors.Is(err, mqtt.ErrClosed): 113 | fmt.Println("🚨 alert not send:", err) 114 | return 115 | 116 | case errors.Is(err, mqtt.ErrMax): 117 | fmt.Println("⚠️ alert submission hold-up:", err) 118 | time.Sleep(time.Second / 4) 119 | continue 120 | 121 | default: 122 | fmt.Println("⚠️ alert submission blocked on persistence malfunction:", err) 123 | time.Sleep(4 * time.Second) 124 | continue 125 | } 126 | 127 | for err := range exchange { 128 | if errors.Is(err, mqtt.ErrClosed) { 129 | fmt.Println("🚨 alert exchange suspended:", err) 130 | // An AdoptSession may continue the transaction. 
131 | return 132 | } 133 | 134 | fmt.Println("⚠️ alert request transfer interrupted:", err) 135 | } 136 | fmt.Println("alert acknowledged ✓") 137 | break 138 | } 139 | 140 | // Output: 141 | // alert submitted… 142 | // alert acknowledged ✓ 143 | } 144 | 145 | // The switch lists all possible outcomes of a SUBSCRIBE request. 146 | func ExampleClient_Subscribe_scenario() { 147 | ctx, cancel := context.WithTimeout(context.Background(), time.Minute) 148 | defer cancel() 149 | 150 | for { 151 | err := Subscribe(ctx.Done(), "demo/+") 152 | switch { 153 | case err == nil: 154 | fmt.Println("subscribe confirmed by broker") 155 | return 156 | 157 | case errors.As(err, new(mqtt.SubscribeError)): 158 | fmt.Println("subscribe failed by broker") 159 | return 160 | 161 | case mqtt.IsDeny(err): // illegal topic filter 162 | panic(err) // unreachable for string literal 163 | 164 | case errors.Is(err, mqtt.ErrClosed): 165 | fmt.Println("no subscribe with closed client") 166 | return 167 | 168 | case errors.Is(err, mqtt.ErrCanceled): 169 | fmt.Println("no subscribe with quit before submision") 170 | return 171 | 172 | case errors.Is(err, mqtt.ErrAbandoned): 173 | fmt.Println("subscribe in limbo with quit after submission") 174 | return 175 | 176 | case errors.Is(err, mqtt.ErrDown): 177 | fmt.Println("no subscribe without connection") 178 | select { 179 | case <-Online(): 180 | fmt.Println("subscribe retry with new connection") 181 | case <-ctx.Done(): 182 | fmt.Println("subscribe expired before reconnect") 183 | return 184 | } 185 | 186 | case errors.Is(err, mqtt.ErrSubmit), errors.Is(err, mqtt.ErrBreak): 187 | fmt.Println("subscribe in limbo with transit failure") 188 | select { 189 | case <-Online(): 190 | fmt.Println("subscribe retry with new connection") 191 | case <-ctx.Done(): 192 | fmt.Println("subscribe expired before reconnect") 193 | return 194 | } 195 | 196 | case errors.Is(err, mqtt.ErrMax): 197 | fmt.Println("no subscribe with too many requests") 198 | 
time.Sleep(time.Second) // backoff 199 | 200 | default: // unreachable 201 | fmt.Println("unknown subscribe state:", err) 202 | time.Sleep(time.Second) // backoff 203 | } 204 | } 205 | // Output: 206 | // subscribe confirmed by broker 207 | } 208 | -------------------------------------------------------------------------------- /mqtt_test.go: -------------------------------------------------------------------------------- 1 | package mqtt 2 | 3 | import ( 4 | "bytes" 5 | "context" 6 | "errors" 7 | "net" 8 | "sort" 9 | "testing" 10 | ) 11 | 12 | func TestConstants(t *testing.T) { 13 | if want := 268_435_455; packetMax != want { 14 | t.Errorf("got packetMax %d, want %d", packetMax, want) 15 | } 16 | if want := 65_535; stringMax != want { 17 | t.Errorf("got stringMax %d, want %d", stringMax, want) 18 | } 19 | } 20 | 21 | func TestErrorClasses(t *testing.T) { 22 | if IsDeny(nil) { 23 | t.Error("IsDeny got true for nil") 24 | } 25 | if IsEnd(nil) { 26 | t.Error("IsEnd true for nil") 27 | } 28 | 29 | for _, err := range denyErrs { 30 | if !IsDeny(err) { 31 | t.Error("IsDeny got false for error:", err) 32 | } 33 | if IsEnd(err) { 34 | t.Error("IsEnd got true for error:", err) 35 | } 36 | } 37 | 38 | for _, err := range endErrs { 39 | if IsDeny(err) { 40 | t.Error("IsDenry got true for error: ", err) 41 | } 42 | if !IsEnd(err) { 43 | t.Error("IsEnd got false for error:", err) 44 | } 45 | } 46 | 47 | var transitErrs = []error{ErrSubmit, ErrBreak} 48 | for _, err := range transitErrs { 49 | if IsDeny(err) { 50 | t.Error("IsDenry got true for error: ", err) 51 | } 52 | if IsEnd(err) { 53 | t.Error("IsEnd got true for error:", err) 54 | } 55 | } 56 | } 57 | 58 | func TestNewConnectReq(t *testing.T) { 59 | c := &Config{ 60 | Dialer: func(context.Context) (net.Conn, error) { 61 | return nil, errors.New("dialer call not allowed for test") 62 | }, 63 | UserName: "me", 64 | Password: []byte{'?'}, 65 | CleanSession: true, 66 | KeepAlive: 3600, 67 | } 68 | c.Will.Topic = "☯️" 69 | 
c.Will.Message = []byte("☠") 70 | c.Will.Retain = true 71 | c.Will.AtLeastOnce = true 72 | c.Will.ExactlyOnce = true 73 | 74 | got := c.newConnectReq([]byte("#🤖")) 75 | want := []byte{0x10, 37, 0, 4, 'M', 'Q', 'T', 'T', 4, 0b1111_0110, 0x0e, 0x10, 76 | 0, 5, '#', 0xF0, 0x9F, 0xA4, 0x96, 77 | 0, 6, 0xe2, 0x98, 0xaf, 0xef, 0xb8, 0x8f, 78 | 0, 3, 0xe2, 0x98, 0xa0, 79 | 0, 2, 'm', 'e', 80 | 0, 1, '?'} 81 | if !bytes.Equal(got, want) { 82 | t.Errorf("full session config got %#x, want %#x", got, want) 83 | } 84 | } 85 | 86 | func TestPesistenceEmpty(t *testing.T) { 87 | t.Run("volatile", func(t *testing.T) { 88 | testPersistenceEmpty(t, newVolatile()) 89 | }) 90 | t.Run("fileSystem", func(t *testing.T) { 91 | testPersistenceEmpty(t, FileSystem(t.TempDir())) 92 | }) 93 | } 94 | 95 | func testPersistenceEmpty(t *testing.T, p Persistence) { 96 | if data, err := p.Load(42); err != nil { 97 | t.Error("Load got error:", err) 98 | } else if data != nil { 99 | t.Errorf("Load got %#x, want nil", data) 100 | } 101 | 102 | if err := p.Delete(42); err != nil { 103 | t.Error("Delete got error:", err) 104 | } 105 | 106 | if keys, err := p.List(); err != nil { 107 | t.Error("List got error:", err) 108 | } else if len(keys) != 0 { 109 | t.Errorf("List got keys %d", keys) 110 | } 111 | } 112 | 113 | func TestPersistence(t *testing.T) { 114 | t.Run("volatile", func(t *testing.T) { 115 | testPersistence(t, newVolatile()) 116 | }) 117 | t.Run("fileSystem", func(t *testing.T) { 118 | testPersistence(t, FileSystem(t.TempDir())) 119 | }) 120 | } 121 | 122 | func testPersistence(t *testing.T, p Persistence) { 123 | for i := 0; i < 3; i++ { 124 | bufs := make(net.Buffers, i+1) 125 | for j := range bufs { 126 | bufs[j] = make([]byte, j+1) 127 | for k := range bufs[j] { 128 | bufs[j][k] = byte('a' + k) 129 | } 130 | } 131 | 132 | err := p.Save(uint(i), bufs) 133 | if err != nil { 134 | t.Errorf("Save %d got error: %s", i, err) 135 | } 136 | } 137 | 138 | if keys, err := p.List(); err != nil { 139 
| t.Error("List got error:", err) 140 | } else { 141 | // order undefined 142 | ints := make([]int, len(keys)) 143 | for i := range keys { 144 | ints[i] = int(keys[i]) 145 | } 146 | sort.Ints(ints) 147 | if len(ints) != 3 || ints[0] != 0 || ints[1] != 1 || ints[2] != 2 { 148 | t.Errorf("List got %d, want %d", ints, []int{0, 1, 2}) 149 | } 150 | } 151 | 152 | if data, err := p.Load(0); err != nil { 153 | t.Error("Load 0 got error:", err) 154 | } else if want := "a"; string(data) != want { 155 | t.Errorf("Load 0 got %q, want %q", data, want) 156 | } 157 | if data, err := p.Load(1); err != nil { 158 | t.Error("Load 1 got error:", err) 159 | } else if want := "aab"; string(data) != want { 160 | t.Errorf("Load 1 got %q, want %q", data, want) 161 | } 162 | if data, err := p.Load(2); err != nil { 163 | t.Error("Load 2 got error:", err) 164 | } else if want := "aababc"; string(data) != want { 165 | t.Errorf("Load 2 got %q, want %q", data, want) 166 | } 167 | } 168 | 169 | func TestPersistenceUpdate(t *testing.T) { 170 | t.Run("volatile", func(t *testing.T) { 171 | testPersistenceUpdate(t, newVolatile()) 172 | }) 173 | t.Run("fileSystem", func(t *testing.T) { 174 | testPersistenceUpdate(t, FileSystem(t.TempDir())) 175 | }) 176 | } 177 | 178 | func testPersistenceUpdate(t *testing.T, p Persistence) { 179 | err := p.Save(0, net.Buffers{[]byte("ab"), []byte("cd")}) 180 | if err != nil { 181 | t.Fatal("Save new 0 got error:", err) 182 | } 183 | err = p.Save(42, net.Buffers{[]byte("ef")}) 184 | if err != nil { 185 | t.Fatal("Save new 42 got error:", err) 186 | } 187 | err = p.Save(0, net.Buffers{[]byte("12")}) 188 | if err != nil { 189 | t.Fatal("Save update 0 got error:", err) 190 | } 191 | err = p.Save(42, net.Buffers{[]byte("34"), []byte("56")}) 192 | if err != nil { 193 | t.Fatal("Save update 42 got error:", err) 194 | } 195 | 196 | if data, err := p.Load(0); err != nil { 197 | t.Error("Load 0 got error:", err) 198 | } else if want := "12"; string(data) != want { 199 | 
t.Errorf("Load 0 got %#v, want %#v", data, want) 200 | } 201 | if data, err := p.Load(42); err != nil { 202 | t.Error("Load 42 got error:", err) 203 | } else if want := "3456"; string(data) != want { 204 | t.Errorf("Load 42 got %#v, want %#v", data, want) 205 | } 206 | } 207 | 208 | func TestPersistenceDelete(t *testing.T) { 209 | t.Run("volatile", func(t *testing.T) { 210 | testPersistenceDelete(t, newVolatile()) 211 | }) 212 | t.Run("fileSystem", func(t *testing.T) { 213 | testPersistenceDelete(t, FileSystem(t.TempDir())) 214 | }) 215 | } 216 | 217 | func testPersistenceDelete(t *testing.T, p Persistence) { 218 | err := p.Save(0, net.Buffers{[]byte("ab"), []byte("cd")}) 219 | if err != nil { 220 | t.Fatal("Save new 0 got error:", err) 221 | } 222 | err = p.Save(42, net.Buffers{[]byte("ef")}) 223 | if err != nil { 224 | t.Fatal("Save new 42 got error:", err) 225 | } 226 | err = p.Save(42, net.Buffers{[]byte("gh")}) 227 | if err != nil { 228 | t.Fatal("Save update 42 got error:", err) 229 | } 230 | err = p.Save(99, net.Buffers{[]byte("ij")}) 231 | if err != nil { 232 | t.Fatal("Save new 99 got error:", err) 233 | } 234 | 235 | if err := p.Delete(42); err != nil { 236 | t.Error("Delete 42 got error:", err) 237 | } 238 | if err := p.Delete(0); err != nil { 239 | t.Error("Delete 0 got error:", err) 240 | } 241 | if keys, err := p.List(); err != nil { 242 | t.Error("List got error:", err) 243 | } else if len(keys) != 1 || keys[0] != 99 { 244 | t.Errorf("List got %d, want %d", keys, []uint{99}) 245 | } 246 | } 247 | -------------------------------------------------------------------------------- /integration/integration_test.go: -------------------------------------------------------------------------------- 1 | package integration 2 | 3 | import ( 4 | "encoding/binary" 5 | "errors" 6 | "fmt" 7 | "net" 8 | "os" 9 | "strings" 10 | "sync" 11 | "testing" 12 | "time" 13 | 14 | "github.com/pascaldekloe/mqtt" 15 | ) 16 | 17 | func hosts(tb testing.TB) []string { 18 | s, ok := 
os.LookupEnv("MQTT_HOSTS") 19 | if !ok { 20 | tb.Skip("no test targets without MQTT_HOSTS environment variable") 21 | } 22 | return strings.Fields(s) 23 | } 24 | 25 | // BatchSize is a reasonable number of messages which should not cause any of 26 | // them to be dropped (by the broker) when send sequentially. 27 | const batchSize = 128 28 | 29 | // SendBatch publishes a total of batchSize messages, with 8-byte payloads, 30 | // containing msgOffset until msgOffset + batchSize − 1. It pushes for the 31 | // maximum number of in-flight messages allowed by the mqtt.Client. 32 | func sendBatch(t testing.TB, publish func([]byte, string) (<-chan error, error), msgOffset uint64) { 33 | retry := time.NewTicker(time.Microsecond) 34 | defer retry.Stop() 35 | 36 | // confirmation channels 37 | var exchanges [batchSize]<-chan error 38 | topic := t.Name() 39 | 40 | // publish each 41 | for i := uint64(0); i < batchSize; { 42 | var msg [8]byte 43 | binary.BigEndian.PutUint64(msg[:], msgOffset+i) 44 | ch, err := publish(msg[:], topic) 45 | switch { 46 | case err == nil: 47 | exchanges[i] = ch 48 | i++ 49 | 50 | case errors.Is(err, mqtt.ErrMax): 51 | <-retry.C 52 | 53 | default: 54 | t.Error("publish batch abort on:", err) 55 | return 56 | } 57 | } 58 | 59 | // read confirmations 60 | for i := range exchanges { 61 | for err := range exchanges[i] { 62 | t.Error("publish exchange abort on:", err) 63 | return 64 | } 65 | } 66 | } 67 | 68 | // ReceiveBatch reads a sendBatch sequence. 
69 | func receiveBatch(t testing.TB, messages <-chan uint64, offset uint64) { 70 | for i := uint64(0); i < batchSize; i++ { 71 | got, ok := <-messages 72 | want := offset + i 73 | if !ok { 74 | t.Errorf("receive stopped, want message # %d", want) 75 | return 76 | } 77 | if got != want { 78 | t.Errorf("got message # %d, want # %d", got, want) 79 | } 80 | } 81 | } 82 | 83 | func exchangeN(t testing.TB, n uint64, publish func([]byte, string) (<-chan error, error), messages <-chan uint64) { 84 | if n%batchSize != 0 { 85 | t.Fatalf("exchange count %d must be a multiple of %d", n, batchSize) 86 | } 87 | 88 | for i := uint64(0); i < n; i += batchSize { 89 | done := make(chan struct{}) 90 | go func() { 91 | defer close(done) 92 | sendBatch(t, publish, i) 93 | }() 94 | receiveBatch(t, messages, i) 95 | <-done 96 | if t.Failed() { 97 | return 98 | } 99 | } 100 | } 101 | 102 | func TestRoundtrip(t *testing.T) { 103 | for _, host := range hosts(t) { 104 | t.Run(host, func(t *testing.T) { 105 | testRoundtripHost(t, host) 106 | }) 107 | } 108 | } 109 | 110 | func testRoundtripHost(t *testing.T, host string) { 111 | // client instantiation 112 | clientID := t.Name() 113 | config := mqtt.Config{ 114 | Dialer: mqtt.NewDialer("tcp", net.JoinHostPort(host, "1883")), 115 | PauseTimeout: time.Second, 116 | CleanSession: true, 117 | AtLeastOnceMax: 9, 118 | ExactlyOnceMax: 9, 119 | } 120 | 121 | // target specifics 122 | switch host { 123 | case "activemq": 124 | config.UserName = "artemis" 125 | config.Password = []byte("artemis") 126 | 127 | case "rumqttd": 128 | config.KeepAlive = 20 129 | clientID = strings.Replace(clientID, "/", "-", 1) 130 | 131 | case "volantmq": 132 | config.UserName = "testuser" 133 | config.Password = []byte("testpassword") 134 | } 135 | 136 | client, err := mqtt.VolatileSession(clientID, &config) 137 | if err != nil { 138 | t.Fatal("client instantiation:", err) 139 | } 140 | 141 | testRoundtripClient(t, client) 142 | } 143 | 144 | func testRoundtripClient(t 
*testing.T, client *mqtt.Client) { 145 | const messageN = 16384 + batchSize // overflows mqtt.publishIDMask 146 | 147 | // receive streams 148 | atMostOnceMessages := make(chan uint64) 149 | atLeastOnceMessages := make(chan uint64) 150 | exactlyOnceMessages := make(chan uint64) 151 | 152 | // The read routine continues until Client Close. 153 | // Errors are discarded once the channel buffer is full. 154 | readDone := make(chan error, 60) 155 | go func() { 156 | defer close(readDone) 157 | 158 | defer close(atMostOnceMessages) 159 | defer close(atLeastOnceMessages) 160 | defer close(exactlyOnceMessages) 161 | 162 | for { 163 | message, topic, err := client.ReadSlices() 164 | if err != nil { 165 | wait := client.ReadBackoff(err) 166 | if wait == nil { 167 | return // terminated 168 | } 169 | 170 | select { 171 | case readDone <- err: 172 | default: // discard 173 | } 174 | 175 | <-wait 176 | continue 177 | } 178 | 179 | if len(message) != 8 { 180 | select { 181 | case readDone <- fmt.Errorf("unexpected message %#x on topic %q", message, topic): 182 | default: // discard 183 | } 184 | continue 185 | } 186 | seqNo := binary.BigEndian.Uint64(message) 187 | 188 | switch s := string(topic); { 189 | case strings.HasSuffix(s, "/at-most-once"): 190 | atMostOnceMessages <- seqNo 191 | case strings.HasSuffix(s, "/at-least-once"): 192 | atLeastOnceMessages <- seqNo 193 | case strings.HasSuffix(s, "/exactly-once"): 194 | exactlyOnceMessages <- seqNo 195 | default: 196 | select { 197 | case readDone <- fmt.Errorf("message # %d on unexpected topic %q", seqNo, topic): 198 | default: // discard 199 | } 200 | } 201 | } 202 | }() 203 | 204 | // test each QoS in parallel 205 | var testGroup sync.WaitGroup 206 | testGroup.Add(3) 207 | 208 | t.Run("at-most-once", func(t *testing.T) { 209 | defer testGroup.Done() 210 | t.Parallel() 211 | 212 | <-client.Online() 213 | t.Log("client online") 214 | 215 | err := client.Subscribe(nil, t.Name()) 216 | if err != nil { 217 | t.Fatal(err) 218 | } 
219 | exchange := make(chan error) 220 | close(exchange) 221 | exchangeN(t, messageN, func(message []byte, topic string) (<-chan error, error) { 222 | err := client.Publish(nil, message, topic) 223 | return exchange, err 224 | }, atMostOnceMessages) 225 | }) 226 | 227 | t.Run("at-least-once", func(t *testing.T) { 228 | defer testGroup.Done() 229 | t.Parallel() 230 | 231 | <-client.Online() 232 | t.Log("client online") 233 | 234 | err := client.Subscribe(nil, t.Name()) 235 | if err != nil { 236 | t.Fatal(err) 237 | } 238 | exchangeN(t, messageN, client.PublishAtLeastOnce, atLeastOnceMessages) 239 | }) 240 | 241 | t.Run("exactly-once", func(t *testing.T) { 242 | defer testGroup.Done() 243 | t.Parallel() 244 | 245 | <-client.Online() 246 | t.Log("client online") 247 | 248 | err := client.Subscribe(nil, t.Name()) 249 | if err != nil { 250 | t.Fatal(err) 251 | } 252 | exchangeN(t, messageN, client.PublishExactlyOnce, exactlyOnceMessages) 253 | }) 254 | 255 | t.Run("clean-exit", func(t *testing.T) { 256 | t.Parallel() 257 | 258 | <-client.Online() 259 | t.Log("client online") 260 | 261 | testGroup.Wait() 262 | 263 | t.Log("disconnect request") 264 | err := client.Disconnect(nil) 265 | if err != nil { 266 | t.Error(err) 267 | } 268 | 269 | <-client.Offline() 270 | t.Log("client offline") 271 | 272 | for err := range readDone { 273 | t.Error(err) 274 | } 275 | }) 276 | } 277 | -------------------------------------------------------------------------------- /mqtttest/mqtttest.go: -------------------------------------------------------------------------------- 1 | // Package mqtttest provides utilities for MQTT testing. 2 | package mqtttest 3 | 4 | import ( 5 | "bytes" 6 | "errors" 7 | "sync/atomic" 8 | "testing" 9 | "time" 10 | 11 | "github.com/pascaldekloe/mqtt" 12 | ) 13 | 14 | // Transfer defines a message exchange. 
type Transfer struct {
	Message []byte // payload
	Topic   string // destination
	Err     error  // result
}

// NewReadSlicesStub returns a new stub for mqtt.Client ReadSlices with a fixed
// return value.
func NewReadSlicesStub(fix Transfer) func() (message, topic []byte, err error) {
	return func() (message, topic []byte, err error) {
		// use copies to prevent some hard to trace issues
		message = make([]byte, len(fix.Message))
		copy(message, fix.Message)
		topic = []byte(fix.Topic)
		return message, topic, fix.Err
	}
}

// NewReadSlicesMock returns a new mock for mqtt.Client ReadSlices, which
// returns the Transfers in order of appearance. Any calls beyond the want
// list are flagged as test errors, and so are wanted calls which never
// happen (checked during test cleanup).
func NewReadSlicesMock(t testing.TB, want ...Transfer) func() (message, topic []byte, err error) {
	t.Helper()

	// wantIndex counts consumed entries; atomic for concurrent reads
	var wantIndex uint64

	t.Cleanup(func() {
		t.Helper()

		if n := uint64(len(want)) - atomic.LoadUint64(&wantIndex); n > 0 {
			t.Errorf("want %d more MQTT ReadSlices", n)
		}
	})

	return func() (message, topic []byte, err error) {
		t.Helper()

		i := atomic.AddUint64(&wantIndex, 1) - 1
		if i >= uint64(len(want)) {
			err = errors.New("unwanted MQTT ReadSlices")
			t.Error(err)
			return
		}

		return NewReadSlicesStub(want[i])()
	}
}

// NewPublishMock returns a new mock for mqtt.Client Publish, which compares the
// invocation with want in order of appearance.
64 | func NewPublishMock(t testing.TB, want ...Transfer) func(quit <-chan struct{}, message []byte, topic string) error { 65 | t.Helper() 66 | 67 | var wantIndex uint64 68 | 69 | t.Cleanup(func() { 70 | if n := uint64(len(want)) - atomic.LoadUint64(&wantIndex); n > 0 { 71 | t.Errorf("want %d more MQTT publishes", n) 72 | } 73 | }) 74 | 75 | return func(quit <-chan struct{}, message []byte, topic string) error { 76 | t.Helper() 77 | 78 | select { 79 | case <-quit: 80 | return mqtt.ErrCanceled 81 | default: 82 | break 83 | } 84 | 85 | i := atomic.AddUint64(&wantIndex, 1) - 1 86 | if i >= uint64(len(want)) { 87 | t.Errorf("unwanted MQTT publish of %#x to %q", message, topic) 88 | return nil 89 | } 90 | transfer := want[i] 91 | 92 | if !bytes.Equal(message, transfer.Message) && topic != transfer.Topic { 93 | t.Errorf("got MQTT publish of %#x to %q, want %#x to %q", message, topic, transfer.Message, transfer.Topic) 94 | } 95 | return transfer.Err 96 | } 97 | } 98 | 99 | // NewPublishStub returns a new stub for mqtt.Client Publish with a fixed return 100 | // value. 101 | func NewPublishStub(fix error) func(quit <-chan struct{}, message []byte, topic string) error { 102 | return func(quit <-chan struct{}, message []byte, topic string) error { 103 | select { 104 | case <-quit: 105 | return mqtt.ErrCanceled 106 | default: 107 | return fix 108 | } 109 | } 110 | } 111 | 112 | // ExchangeBlock prevents exchange <-chan error submission. 113 | type ExchangeBlock struct { 114 | Delay time.Duration // zero defaults to indefinite 115 | } 116 | 117 | // Error implements the standard error interface. 118 | func (b ExchangeBlock) Error() string { 119 | return "mqtttest: ExchangeBlock used as an error" 120 | } 121 | 122 | // NewPublishExchangeStub returns a stub for mqtt.Client PublishAtLeastOnce or 123 | // PublishExactlyOnce with a fixed return value. 124 | // 125 | // The exchangeFix errors are applied to the exchange return, with an option for 126 | // ExchangeBlock entries. 
// An mqtt.ErrClosed in the exchangeFix keeps the
// exchange channel open (without an extra ExchangeBlock entry).
func NewPublishExchangeStub(errFix error, exchangeFix ...error) func(message []byte, topic string) (exchange <-chan error, err error) {
	// Misconfiguration is a bug in the test code itself, hence panic
	// rather than error returns.
	if errFix != nil && len(exchangeFix) != 0 {
		panic("exchangeFix entries with non-nil errFix")
	}
	var block ExchangeBlock
	for i, err := range exchangeFix {
		switch {
		case err == nil:
			panic("nil entry in exchangeFix")
		case errors.Is(err, mqtt.ErrClosed):
			// ErrClosed is terminal; nothing may follow it.
			if i+1 < len(exchangeFix) {
				panic("followup on mqtt.ErrClosed exchangeFix entry")
			}
		case errors.As(err, &block):
			// An indefinite block never yields; nothing may follow it.
			if block.Delay == 0 && i+1 < len(exchangeFix) {
				panic("followup on indefinite ExchangeBlock exchangeFix entry")
			}
		}
	}

	return func(message []byte, topic string) (exchange <-chan error, err error) {
		if errFix != nil {
			return nil, errFix
		}

		// The buffer covers all entries, so the feed routine below never
		// blocks on a plain error submission.
		ch := make(chan error, len(exchangeFix))
		go func() {
			var block ExchangeBlock
			for _, err := range exchangeFix {
				switch {
				default:
					ch <- err
				case errors.Is(err, mqtt.ErrClosed):
					ch <- err
					return // without close
				case errors.As(err, &block):
					if block.Delay == 0 {
						return // without close
					}
					time.Sleep(block.Delay)
				}
			}
			close(ch)
		}()
		return ch, nil
	}
}

// NewSubscribeStub returns a stub for mqtt.Client Subscribe with a fixed return
// value.
func NewSubscribeStub(fix error) func(quit <-chan struct{}, topicFilters ...string) error {
	return newSubscribeStub("subscribe", fix)
}

// NewUnsubscribeStub returns a stub for mqtt.Client Unsubscribe with a fixed
// return value.
func NewUnsubscribeStub(fix error) func(quit <-chan struct{}, topicFilters ...string) error {
	return newSubscribeStub("unsubscribe", fix)
}

// newSubscribeStub implements both NewSubscribeStub and NewUnsubscribeStub.
// The name is either "subscribe" or "unsubscribe" for use in messages.
func newSubscribeStub(name string, fix error) func(quit <-chan struct{}, topicFilters ...string) error {
	return func(quit <-chan struct{}, topicFilters ...string) error {
		if len(topicFilters) == 0 {
			// TODO(pascaldekloe): move validation to internal
			// package and then return appropriate errors here.
			panic("MQTT " + name + " without topic filters")
		}
		select {
		case <-quit:
			return mqtt.ErrCanceled
		default:
			break
		}
		return fix
	}
}

// Filter defines a subscription exchange.
type Filter struct {
	Topics []string // order is ignored
	Err    error    // result
}

// NewSubscribeMock returns a new mock for mqtt.Client Subscribe, which compares
// the invocation with want in order of appearance.
func NewSubscribeMock(t testing.TB, want ...Filter) func(quit <-chan struct{}, topicFilters ...string) error {
	t.Helper()
	return newSubscribeMock("subscribe", t, want...)
}

// NewUnsubscribeMock returns a new mock for mqtt.Client Unsubscribe, which
// compares the invocation with want in order of appearance.
func NewUnsubscribeMock(t testing.TB, want ...Filter) func(quit <-chan struct{}, topicFilters ...string) error {
	t.Helper()
	return newSubscribeMock("unsubscribe", t, want...)
223 | } 224 | 225 | func newSubscribeMock(name string, t testing.TB, want ...Filter) func(quit <-chan struct{}, topicFilters ...string) error { 226 | t.Helper() 227 | 228 | var wantIndex uint64 229 | 230 | t.Cleanup(func() { 231 | if n := uint64(len(want)) - atomic.LoadUint64(&wantIndex); n > 0 { 232 | t.Errorf("want %d more MQTT %ss", n, name) 233 | } 234 | }) 235 | 236 | return func(quit <-chan struct{}, topicFilters ...string) error { 237 | t.Helper() 238 | if len(topicFilters) == 0 { 239 | t.Fatalf("MQTT %s without topic filters", name) 240 | } 241 | select { 242 | case <-quit: 243 | return mqtt.ErrCanceled 244 | default: 245 | break 246 | } 247 | 248 | i := atomic.AddUint64(&wantIndex, 1) - 1 249 | if i >= uint64(len(want)) { 250 | t.Errorf("unwanted MQTT %s of %q", name, topicFilters) 251 | } 252 | filter := want[i] 253 | 254 | todo := make(map[string]struct{}, len(filter.Topics)) 255 | for _, topic := range filter.Topics { 256 | todo[topic] = struct{}{} 257 | } 258 | var wrong []string 259 | for _, filter := range topicFilters { 260 | if _, ok := todo[filter]; ok { 261 | delete(todo, filter) 262 | } else { 263 | wrong = append(wrong, filter) 264 | } 265 | } 266 | if len(wrong) != 0 { 267 | t.Errorf("unwanted MQTT %s of %q (out of %q)", name, wrong, filter.Topics) 268 | } 269 | if len(todo) != 0 { 270 | var miss []string 271 | for filter := range todo { 272 | miss = append(miss, filter) 273 | } 274 | t.Errorf("no MQTT %s of %q (out of %q)", name, miss, filter.Topics) 275 | } 276 | 277 | return filter.Err 278 | } 279 | } 280 | -------------------------------------------------------------------------------- /cmd/mqttc/main.go: -------------------------------------------------------------------------------- 1 | // Package main provides a command-line utility. 
2 | package main 3 | 4 | import ( 5 | "context" 6 | "crypto/tls" 7 | "crypto/x509" 8 | "encoding/pem" 9 | "errors" 10 | "flag" 11 | "fmt" 12 | "io" 13 | "log" 14 | "net" 15 | "os" 16 | "os/signal" 17 | "syscall" 18 | "time" 19 | 20 | "github.com/pascaldekloe/mqtt" 21 | ) 22 | 23 | // ANSI escape codes for markup. 24 | const ( 25 | bold = "\x1b[1m" 26 | italic = "\x1b[3m" 27 | clear = "\x1b[0m" 28 | ) 29 | 30 | // Name of the invoked executable. 31 | var name = os.Args[0] 32 | 33 | var subscribeFlags []string 34 | 35 | func init() { 36 | flag.Func("subscribe", "Listen with a topic `filter`. Inbound messages are printed to\n"+italic+"standard output"+clear+" until interrupted by a signal(3). Multiple\n"+bold+"-subscribe"+clear+" options may be applied together.", func(value string) error { 37 | subscribeFlags = append(subscribeFlags, value) 38 | return nil 39 | }) 40 | } 41 | 42 | const generatedLabel = "generated" 43 | 44 | var ( 45 | publishFlag = flag.String("publish", "", "Send a message to a `topic`. The payload is read from "+italic+"standard\ninput"+clear+".") 46 | 47 | timeoutFlag = flag.Duration("timeout", 4*time.Second, "Network operation expiry.") 48 | netFlag = flag.String("net", "tcp", "Select the network by `name`. 
Valid alternatives include tcp4,\ntcp6 and unix.") 49 | 50 | tlsFlag = flag.Bool("tls", false, "Secure the connection with TLS.") 51 | serverFlag = flag.String("server", "", "Use a specific server `name` with TLS") 52 | caFlag = flag.String("ca", "", "Amend the trusted certificate authorities with a PEM `file`.") 53 | certFlag = flag.String("cert", "", "Use a client certificate from a PEM `file` (with a corresponding\n"+bold+"-key"+clear+" option).") 54 | keyFlag = flag.String("key", "", "Use a private key (matching the client certificate) from a PEM\n`file`.") 55 | 56 | userFlag = flag.String("user", "", "The user `name` may be used by the broker for authentication\nand/or authorization purposes.") 57 | passFlag = flag.String("pass", "", "The `file` content is used as a password.") 58 | 59 | clientFlag = flag.String("client", generatedLabel, "Use a specific client `identifier`.") 60 | 61 | prefixFlag = flag.String("prefix", "", "Print a `string` before each inbound message.") 62 | suffixFlag = flag.String("suffix", "\n", "Print a `string` after each inbound message.") 63 | topicFlag = flag.Bool("topic", false, "Print the respective topic of each inbound message.") 64 | quoteFlag = flag.Bool("quote", false, "Print inbound topics and messages as quoted strings.") 65 | 66 | quietFlag = flag.Bool("quiet", false, "Suppress all output to "+italic+"standard error"+clear+". Error reporting is\ndeduced to the exit code only.") 67 | verboseFlag = flag.Bool("verbose", false, "Produces more output to "+italic+"standard error"+clear+" for debug purposes.") 68 | ) 69 | 70 | // Config collects the command arguments. 
71 | func Config() (clientID string, config *mqtt.Config) { 72 | var addr string 73 | switch args := flag.Args(); { 74 | case len(args) == 0: 75 | printManual() 76 | os.Exit(2) 77 | 78 | case len(args) == 1: 79 | addr = args[0] 80 | 81 | default: 82 | log.Printf("%s: multiple address arguments %q", name, args) 83 | os.Exit(2) 84 | } 85 | 86 | var TLS *tls.Config 87 | if *tlsFlag { 88 | TLS = new(tls.Config) 89 | } 90 | 91 | if *serverFlag != "" { 92 | if TLS == nil { 93 | log.Fatal(name, ": -server requires -tls option") 94 | } 95 | TLS.ServerName = *serverFlag 96 | } 97 | 98 | switch { 99 | case *certFlag != "" && *keyFlag != "": 100 | if TLS == nil { 101 | log.Fatal(name, ": -cert requires -tls option") 102 | } 103 | 104 | certPEM, err := os.ReadFile(*certFlag) 105 | if err != nil { 106 | log.Fatal(err) 107 | } 108 | keyPEM, err := os.ReadFile(*keyFlag) 109 | if err != nil { 110 | log.Fatal(err) 111 | } 112 | cert, err := tls.X509KeyPair(certPEM, keyPEM) 113 | if err != nil { 114 | log.Fatal(name, ": unusable -cert and -key content; ", err) 115 | } 116 | TLS.Certificates = append(TLS.Certificates, cert) 117 | 118 | case *certFlag != "": 119 | log.Fatal(name, ": -cert requires -key option") 120 | case *keyFlag != "": 121 | log.Fatal(name, ": -key requires -cert option") 122 | } 123 | 124 | if *caFlag != "" { 125 | if TLS == nil { 126 | log.Fatal(name, ": -ca requires -tls option") 127 | } 128 | 129 | if certs, err := x509.SystemCertPool(); err != nil { 130 | log.Print(name, ": system certificates unavailable; ", err) 131 | TLS.RootCAs = x509.NewCertPool() 132 | } else { 133 | TLS.RootCAs = certs 134 | } 135 | 136 | text, err := os.ReadFile(*caFlag) 137 | if err != nil { 138 | log.Fatal(err) 139 | } 140 | for n := 1; ; n++ { 141 | var block *pem.Block 142 | block, text = pem.Decode(text) 143 | if block == nil { 144 | break 145 | } 146 | if block.Type != "CERTIFICATE" || len(block.Headers) != 0 { 147 | log.Printf("%s: ignoring PEM block № %d of type %q in %s", name, 
n, block.Type, *caFlag) 148 | continue 149 | } 150 | cert, err := x509.ParseCertificate(block.Bytes) 151 | if err != nil { 152 | log.Printf("%s: ignoring PEM block № %d in %s; %s", name, n, *caFlag, err) 153 | continue 154 | } 155 | TLS.RootCAs.AddCert(cert) 156 | } 157 | } 158 | 159 | if _, _, err := net.SplitHostPort(addr); err != nil { 160 | port := "1883" 161 | if TLS != nil { 162 | port = "8883" 163 | } 164 | addr = net.JoinHostPort(addr, port) 165 | } 166 | 167 | clientID = *clientFlag 168 | if clientID == generatedLabel { 169 | clientID = "mqttc(1)-" + time.Now().In(time.UTC).Format(time.RFC3339Nano) 170 | } 171 | 172 | config = &mqtt.Config{ 173 | PauseTimeout: *timeoutFlag, 174 | UserName: *userFlag, 175 | } 176 | if *passFlag != "" { 177 | bytes, err := os.ReadFile(*passFlag) 178 | if err != nil { 179 | log.Fatal(err) 180 | } 181 | config.Password = bytes 182 | } 183 | 184 | if TLS != nil { 185 | config.Dialer = mqtt.NewTLSDialer(*netFlag, addr, TLS) 186 | } else { 187 | config.Dialer = mqtt.NewDialer(*netFlag, addr) 188 | } 189 | return 190 | } 191 | 192 | var exitStatus = make(chan int, 1) 193 | 194 | func setExitStatusOnce(code int) { 195 | select { 196 | case exitStatus <- code: 197 | default: 198 | } 199 | } 200 | 201 | func main() { 202 | log.SetFlags(0) 203 | flag.Usage = printManual 204 | flag.Parse() 205 | if *quietFlag { 206 | log.SetOutput(io.Discard) 207 | } 208 | 209 | clientID, config := Config() 210 | client, err := mqtt.VolatileSession(clientID, config) 211 | if err != nil { 212 | log.Fatal(err) 213 | } 214 | 215 | go applySignals(client) 216 | 217 | // broker exchange 218 | go func() { 219 | // maybe PUBLISH 220 | if *publishFlag != "" && !publish(client, *publishFlag) { 221 | return 222 | } 223 | // maybe SUBSCRIBE 224 | if len(subscribeFlags) != 0 && !subscribe(client, subscribeFlags) { 225 | return 226 | } 227 | // PING when no PUBLISH and no SUBSCRIBE 228 | if *publishFlag == "" && len(subscribeFlags) == 0 && !ping(client) { 229 | 
return 230 | } 231 | // DISCONNECT when no SUBSCRIBE 232 | if len(subscribeFlags) == 0 && disconnect(client) { 233 | setExitStatusOnce(0) 234 | } 235 | }() 236 | 237 | // read routine 238 | var big *mqtt.BigMessage 239 | for { 240 | message, topic, err := client.ReadSlices() 241 | switch { 242 | case err == nil: 243 | printMessage(message, topic) 244 | 245 | case errors.As(err, &big): 246 | message, err = big.ReadAll() 247 | if err != nil { 248 | log.Print(err) 249 | os.Exit(1) 250 | } 251 | printMessage(message, big.Topic) 252 | 253 | case errors.Is(err, mqtt.ErrClosed): 254 | os.Exit(<-exitStatus) 255 | 256 | case errors.Is(err, mqtt.ErrProtocolLevel): 257 | os.Exit(5) 258 | case errors.Is(err, mqtt.ErrClientID): 259 | os.Exit(6) 260 | case errors.Is(err, mqtt.ErrUnavailable): 261 | os.Exit(7) 262 | case errors.Is(err, mqtt.ErrAuthBad): 263 | os.Exit(8) 264 | case errors.Is(err, mqtt.ErrAuth): 265 | os.Exit(9) 266 | 267 | default: 268 | log.Print(err) 269 | os.Exit(1) 270 | } 271 | } 272 | } 273 | 274 | func printMessage(message, topic interface{}) { 275 | switch { 276 | case *topicFlag && *quoteFlag: 277 | fmt.Printf("%q%s%q%s", topic, *prefixFlag, message, *suffixFlag) 278 | case *topicFlag: 279 | fmt.Printf("%s%s%s%s", topic, *prefixFlag, message, *suffixFlag) 280 | case *quoteFlag: 281 | fmt.Printf("%s%q%s", *prefixFlag, message, *suffixFlag) 282 | default: 283 | fmt.Printf("%s%s%s", *prefixFlag, message, *suffixFlag) 284 | } 285 | } 286 | 287 | func publish(client *mqtt.Client, topic string) (ok bool) { 288 | // messages of 256 MiB get an mqtt.IsDeny 289 | const bufMax = 256 * 1024 * 1024 290 | message, err := io.ReadAll(io.LimitReader(os.Stdin, bufMax)) 291 | if err != nil { 292 | log.Fatal(name, ": ", err) 293 | } 294 | ctx, cancel := context.WithTimeout(context.Background(), *timeoutFlag) 295 | defer cancel() 296 | err = client.Publish(ctx.Done(), message, topic) 297 | if err != nil { 298 | onReqErr(err, client) 299 | return false 300 | } 301 | if 
*verboseFlag {
		// NOTE(review): logs *publishFlag rather than the topic
		// parameter; identical at the current call site — confirm
		// before reuse.
		log.Printf("%s: published %d bytes to %q", name, len(message), *publishFlag)
	}
	return true
}

// subscribe requests the topic filters at the broker, with the -timeout flag
// as expiry. Failure is passed to onReqErr, and reported with a false return.
func subscribe(client *mqtt.Client, filters []string) (ok bool) {
	ctx, cancel := context.WithTimeout(context.Background(), *timeoutFlag)
	defer cancel()
	err := client.Subscribe(ctx.Done(), filters...)
	if err != nil {
		onReqErr(err, client)
		return false
	}
	if *verboseFlag {
		// NOTE(review): logs len(subscribeFlags) rather than
		// len(filters); identical at the current call site.
		log.Printf("%s: subscribed to %d topic filters", name, len(subscribeFlags))
	}
	return true
}

// ping exchanges a PINGREQ/PINGRESP with the broker, with the -timeout flag
// as expiry. Failure is passed to onReqErr, and reported with a false return.
func ping(client *mqtt.Client) (ok bool) {
	// ping exchange
	ctx, cancel := context.WithTimeout(context.Background(), *timeoutFlag)
	defer cancel()
	err := client.Ping(ctx.Done())
	if err != nil {
		onReqErr(err, client)
		return false
	}
	if *verboseFlag {
		log.Printf("%s: ping OK", name)
	}
	return true
}

// disconnect sends a graceful DISCONNECT request, with the -timeout flag as
// expiry. Failure is passed to onReqErr, and reported with a false return.
func disconnect(client *mqtt.Client) (ok bool) {
	ctx, cancel := context.WithTimeout(context.Background(), *timeoutFlag)
	defer cancel()
	err := client.Disconnect(ctx.Done())
	if err != nil {
		onReqErr(err, client)
		return false
	}
	if *verboseFlag {
		log.Printf("%s: disconnected", name)
	}
	return true
}

// onReqErr handles a request failure: it logs the error, arranges exit code 1,
// and closes the client, which in turn terminates the read loop in main.
func onReqErr(err error, client *mqtt.Client) {
	if errors.Is(err, mqtt.ErrClosed) {
		return // already done for
	}
	log.Print(err)
	setExitStatusOnce(1)
	err = client.Close()
	if err != nil {
		log.Print(err)
	}
}

// applySignals terminates on SIGINT (abrupt close, exit 130) and on SIGTERM
// (graceful disconnect, exit 143), conform convention.
func applySignals(client *mqtt.Client) {
	signals := make(chan os.Signal, 1)
	signal.Notify(signals, syscall.SIGINT, syscall.SIGTERM)
	for sig := range signals {
		switch sig {
		case syscall.SIGINT:
			log.Print(name, ": SIGINT received")
			setExitStatusOnce(130)
			err := client.Close()
			if err != nil {
				log.Print(err)
373 | } 374 | 375 | case syscall.SIGTERM: 376 | log.Print(name, ": SIGTERM received") 377 | if disconnect(client) { 378 | setExitStatusOnce(143) 379 | } 380 | } 381 | } 382 | } 383 | 384 | func printManual() { 385 | if *quietFlag { 386 | return 387 | } 388 | 389 | log.Print(bold + "NAME\n\t" + name + clear + " \u2014 MQTT broker access\n" + 390 | "\n" + 391 | bold + "SYNOPSIS\n" + 392 | "\t" + bold + name + clear + " [options] address\n" + 393 | "\n" + 394 | bold + "DESCRIPTION" + clear + "\n" + 395 | "\tThe command connects to the address argument, with an option to\n" + 396 | "\tpublish a message and/or subscribe with topic filters.\n" + 397 | "\n" + 398 | "\tWhen the address does not specify a port, then the defaults are\n" + 399 | "\tapplied, which is 1883 for plain connections and 8883 for TLS.\n" + 400 | "\n" + 401 | bold + "OPTIONS" + clear + "\n", 402 | ) 403 | 404 | flag.PrintDefaults() 405 | 406 | log.Print("\n" + bold + "EXIT STATUS" + clear + "\n" + 407 | "\t(0) no error\n" + 408 | "\t(1) MQTT operational error\n" + 409 | "\t(2) illegal command invocation\n" + 410 | "\t(5) connection refused: unacceptable protocol version\n" + 411 | "\t(6) connection refused: identifier rejected\n" + 412 | "\t(7) connection refused: server unavailable\n" + 413 | "\t(8) connection refused: bad username or password\n" + 414 | "\t(9) connection refused: not authorized\n" + 415 | "\t(130) close on SIGINT\n" + 416 | "\t(143) disconnect on SIGTERM\n" + 417 | "\n" + 418 | 419 | bold + "EXAMPLES" + clear + "\n" + 420 | "\tSend a message:\n" + 421 | "\n" + 422 | "\t\techo \"hello\" | " + name + " -publish chat/misc localhost\n" + 423 | "\n" + 424 | "\tPrint messages:\n" + 425 | "\n" + 426 | "\t\t" + name + " -subscribe \"news/#\" -prefix \"📥 \" :1883\n" + 427 | "\n" + 428 | "\tHealth check:\n" + 429 | "\n" + 430 | "\t\t" + name + " -tls q1.example.com:8883 || echo \"exit $?\"\n" + 431 | "\n" + 432 | 433 | bold + "BUGS" + clear + "\n" + 434 | "\tReport bugs at .\n" + 435 | "\n" + 
436 | 437 | "SEE ALSO" + clear + "\n\tmosquitto_pub(1)\n", 438 | ) 439 | } 440 | -------------------------------------------------------------------------------- /mqtt.go: -------------------------------------------------------------------------------- 1 | // Package mqtt provides a client for the Message Queuing Telemetry Transport 2 | // protocol. 3 | // 4 | // http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html 5 | // 6 | // Publish does a fire-and-forget submission. ErrClosed, ErrDown, ErrCanceled 7 | // and any IsDeny all imply no request submission. Any other error is guaranteed 8 | // to be an ErrSubmit, which leaves the request status in limbo. Disconnect is 9 | // similar, yet it won't IsDeny. 10 | // 11 | // Subscribe and Unsubscribe await response from the broker. ErrClosed, ErrDown, 12 | // ErrMax, ErrCanceled and any IsDeny all imply no request submission. A broker 13 | // may fail subscription, which is represented by SubscribeError. Any other 14 | // error is guaranteed to be a an ErrSubmit, ErrBreak or ErrAbandoned, which 15 | // leaves the request status in limbo. Ping is similar, yet it won't IsDeny. 16 | // 17 | // PublishAtLeastOnce and PublishExactlyOnce enqueue requests to a Persistence. 18 | // Errors (either ErrClosed, ErrMax, any IsDeny, or any Save return) all imply 19 | // that the message was dropped. Once persisted, the Client will execute the 20 | // transfer with endless retries. 21 | package mqtt 22 | 23 | import ( 24 | "encoding/binary" 25 | "errors" 26 | "fmt" 27 | "hash/fnv" 28 | "net" 29 | "os" 30 | "strconv" 31 | "strings" 32 | "sync" 33 | "sync/atomic" 34 | "unicode/utf8" 35 | ) 36 | 37 | // Control packets have a 4-bit type code in the first byte. 
38 | const ( 39 | typeRESERVED0 = iota 40 | typeCONNECT 41 | typeCONNACK 42 | typePUBLISH 43 | typePUBACK // QOS level 1 confirm 44 | typePUBREC // QOS level 2 confirm, part Ⅰ 45 | typePUBREL // QOS level 2 confirm, part Ⅱ 46 | typePUBCOMP // QOS level 2 confirm, part Ⅲ 47 | typeSUBSCRIBE 48 | typeSUBACK 49 | typeUNSUBSCRIBE 50 | typeUNSUBACK 51 | typePINGREQ 52 | typePINGRESP 53 | typeDISCONNECT 54 | typeRESERVED15 55 | ) 56 | 57 | // The quality-of-service is defined as a numeric level. 58 | const ( 59 | atMostOnceLevel = iota // fire and forget 60 | atLeastOnceLevel // network round trip + persistence 61 | exactlyOnceLevel // two network round trips + persistence 62 | ) 63 | 64 | // PUBLISH packets may have flags in the header. 65 | const ( 66 | dupeFlag = 0b1000 // retry of an earlier attempt 67 | retainFlag = 0b0001 // store for future subscribers 68 | ) 69 | 70 | // Some packet types do not carry any payload. 71 | var ( 72 | packetDISCONNECT = []byte{typeDISCONNECT << 4, 0} 73 | packetPINGREQ = []byte{typePINGREQ << 4, 0} 74 | ) 75 | 76 | // Capacity limitations are defined by their respective size prefix. 77 | const ( 78 | // See MQTT Version 3.1.1, table 2.4: “Size of Remaining Length field”. 79 | packetMax = 1<<(4*7) - 1 // 4-byte varint 80 | 81 | // “Unless stated otherwise all UTF-8 encoded strings can have any 82 | // length in the range 0 to 65535 bytes.” 83 | // — MQTT Version 3.1.1, subsection 1.5.3 84 | stringMax = 1<<16 - 1 // 16-bit size prefixes 85 | ) 86 | 87 | var ( 88 | // ErrPacketMax enforces packetMax. 89 | errPacketMax = errors.New("mqtt: packet reached 256 MiB limit") 90 | // ErrStringMax enforces stringMax. 
	errStringMax = errors.New("mqtt: string reached 64 KiB limit")

	errUTF8 = errors.New("mqtt: invalid UTF-8 byte sequence")
	errNull = errors.New("mqtt: string contains null character")
	errZero = errors.New("mqtt: string is empty")
)

// stringCheck validates s against the constraints for UTF-8 encoded strings
// from MQTT Version 3.1.1, subsection 1.5.3. The return is nil, errStringMax,
// errUTF8 or errNull.
func stringCheck(s string) error {
	if len(s) > stringMax {
		return errStringMax
	}

	// “The character data in a UTF-8 encoded string MUST be well-formed
	// UTF-8 as defined by the Unicode specification and restated in RFC
	// 3629.”
	if !utf8.ValidString(s) {
		return errUTF8
	}

	// “A UTF-8 encoded string MUST NOT include an encoding of the null
	// character U+0000.”
	// — MQTT Version 3.1.1, conformance statement MQTT-1.5.3-2
	if strings.IndexByte(s, 0) >= 0 {
		return errNull
	}

	// Characters 0x01–0x1F, 0x7F–0x9F, and the code points defined in the
	// Unicode specification to be non-characters are all tolerated, yet the
	// receiver may decide otherwise, and it may close the connection.
	return nil
}

// “All Topic Names and Topic Filters MUST be at least one character long.”
// — MQTT Version 3.1.1, conformance statement MQTT-4.7.3-1
func topicCheck(s string) error {
	if s == "" {
		return errZero
	}
	return stringCheck(s)
}

// NonNilIsAny is a slightly more optimal errors.Is for multiple targets.
// Matches are each assumed to be comparable.
func nonNilIsAny(err error, matches []error) bool {
	// more holds pending branches from multi-error wrappers
	// (Unwrap() []error), visited after the current chain ends
	var more []error
	for {
		// direct comparison first, then the custom Is method, per
		// the errors.Is contract
		withIs, hasIs := err.(interface{ Is(error) bool })
		for _, match := range matches {
			if err == match || hasIs && withIs.Is(match) {
				return true
			}
		}

		switch u := err.(type) {
		case interface{ Unwrap() error }:
			err = u.Unwrap()
			if err != nil {
				continue
			}

		case interface{ Unwrap() []error }:
			wrapped := u.Unwrap()
			if more == nil {
				// ensure append (up next) copies, just in case
				more = wrapped[:len(wrapped):len(wrapped)]
			} else {
				more = append(more, wrapped...)
			}
		}

		// current chain exhausted; pop the next pending branch
		if len(more) == 0 {
			return false
		}
		err = more[len(more)-1]
		more = more[:len(more)-1]
	}
}

var denyErrs = []error{errPacketMax, errStringMax, errUTF8, errNull, errZero, errSubscribeNone, errUnsubscribeNone}

// IsDeny returns whether execution got rejected by a Client based on validation
// constraints, such as size limits or invalid UTF-8. The rejection is permanent
// in such case. Retries are futile.
func IsDeny(err error) bool {
	return err != nil && nonNilIsAny(err, denyErrs)
}

var endErrs = []error{ErrClosed, ErrCanceled, ErrAbandoned}

// IsEnd returns whether execution got rejected by a Client based on lifespan,
// which is true for ErrClosed, ErrCanceled and ErrAbandoned. The rejection is
// permanent in such case. Retries are futile.
func IsEnd(err error) bool {
	return err != nil && nonNilIsAny(err, endErrs)
}

// ConnectReturn is the response code from CONNACK.
type connectReturn byte

// Connect return errors are predefined reasons for a broker to deny a connect
// request. IsConnectionRefused returns true for each of these.
192 | const ( 193 | accepted connectReturn = iota 194 | 195 | // ErrProtocolLevel means that the broker does not support the level of 196 | // the MQTT protocol requested by the Client. 197 | ErrProtocolLevel 198 | 199 | // ErrClientID means that the client identifier is correct UTF-8 but not 200 | // allowed by the broker. 201 | ErrClientID 202 | 203 | // ErrUnavailable means that the network connection has been made but 204 | // the MQTT service is unavailable. 205 | ErrUnavailable 206 | 207 | // ErrAuthBad means that the data in the user name or password is 208 | // malformed. 209 | ErrAuthBad 210 | 211 | // ErrAuth means that the client is not authorized to connect. 212 | ErrAuth 213 | ) 214 | 215 | // Error implements the standard error interface. 216 | func (code connectReturn) Error() string { 217 | const refuse = "mqtt: connection refused: " 218 | 219 | switch code { 220 | case accepted: 221 | return "mqtt: connect return code 0 “Connection Accepted” used as an error" 222 | case ErrProtocolLevel: 223 | return refuse + "unacceptable protocol version" 224 | case ErrClientID: 225 | return refuse + "client identifier rejected" 226 | case ErrUnavailable: 227 | return refuse + "server unavailable" 228 | case ErrAuthBad: 229 | return refuse + "bad user name or password" 230 | case ErrAuth: 231 | return refuse + "not authorized" 232 | default: 233 | return fmt.Sprintf(refuse+"connect return code %d reserved for future use", code) 234 | } 235 | } 236 | 237 | // IsConnectionRefused returns whether the broker denied a connect request from 238 | // the Client. 239 | func IsConnectionRefused(err error) bool { 240 | var code connectReturn 241 | if errors.As(err, &code) { 242 | return code != accepted 243 | } 244 | return false 245 | } 246 | 247 | // Persistence keys correspond to MQTT packet identifiers. 248 | const ( 249 | // The 16-bit packet identifiers on inbound and outbound requests each 250 | // have their own namespace. 
// The most-significant bit from a key makes
	// the distinction.
	remoteIDKeyFlag = 1 << 16

	// Packet identifier zero is not in use by the protocol.
	clientIDKey = 0
)

// Persistence tracks the session state as a key–value store. An instance may
// serve only one Client at a time.
//
// Values are addressed by a 17-bit key, mask 0x1ffff. The minimum size is 12 B.
// The maximum size is 256 MiB + 17 B. Clients apply integrity checks all round.
//
// Multiple goroutines may invoke methods on a Persistence simultaneously.
type Persistence interface {
	// Load resolves the value of a key. A nil return means “not found”.
	Load(key uint) ([]byte, error)

	// Save defines the value of a key.
	Save(key uint, value net.Buffers) error

	// Delete clears the value of a key, whether it existed or not. Failures
	// will be overwritten eventually due to the limited address space.
	Delete(key uint) error

	// List enumerates all available keys in any order.
	List() (keys []uint, err error)
}

// Volatile is an in-memory Persistence.
type volatile struct {
	sync.Mutex                  // guards perKey
	perKey     map[uint][]byte  // value per key
}

func newVolatile() Persistence {
	return &volatile{perKey: make(map[uint][]byte)}
}

// Load implements the Persistence interface. Note that the return aliases the
// stored slice; callers are presumed not to modify it — confirm at use sites.
func (m *volatile) Load(key uint) ([]byte, error) {
	m.Lock()
	defer m.Unlock()
	return m.perKey[key], nil
}

// Save implements the Persistence interface.
298 | func (m *volatile) Save(key uint, value net.Buffers) error { 299 | var n int 300 | for _, buf := range value { 301 | n += len(buf) 302 | } 303 | bytes := make([]byte, n) 304 | i := 0 305 | for _, buf := range value { 306 | i += copy(bytes[i:], buf) 307 | } 308 | 309 | m.Lock() 310 | defer m.Unlock() 311 | m.perKey[key] = bytes 312 | return nil 313 | } 314 | 315 | // Delete implements the Persistence interface. 316 | func (m *volatile) Delete(key uint) error { 317 | m.Lock() 318 | defer m.Unlock() 319 | delete(m.perKey, key) 320 | return nil 321 | } 322 | 323 | // List implements the Persistence interface. 324 | func (m *volatile) List() (keys []uint, err error) { 325 | m.Lock() 326 | defer m.Unlock() 327 | keys = make([]uint, 0, len(m.perKey)) 328 | for k := range m.perKey { 329 | keys = append(keys, k) 330 | } 331 | return keys, nil 332 | } 333 | 334 | type fileSystem string 335 | 336 | // FileSystem stores values per file in a directory. Callers must ensure the 337 | // availability, including write permission for the user. The empty string 338 | // defaults to the working directory. 339 | func FileSystem(dir string) Persistence { 340 | if dir != "" && dir[len(dir)-1] != os.PathSeparator { 341 | dir += string([]rune{os.PathSeparator}) 342 | } 343 | return fileSystem(dir) 344 | } 345 | 346 | func (dir fileSystem) file(key uint) string { 347 | return fmt.Sprintf("%s%05x", dir, key) 348 | } 349 | 350 | func (dir fileSystem) spoolFile(key uint) string { 351 | return fmt.Sprintf("%s%05x.spool", dir, key) 352 | } 353 | 354 | // Load implements the Persistence interface. 355 | func (dir fileSystem) Load(key uint) ([]byte, error) { 356 | value, err := os.ReadFile(dir.file(key)) 357 | if err != nil && errors.Is(err, os.ErrNotExist) { 358 | return nil, nil 359 | } 360 | return value, err 361 | } 362 | 363 | // Save implements the Persistence interface. 
364 | func (dir fileSystem) Save(key uint, value net.Buffers) error { 365 | f, err := os.Create(dir.spoolFile(key)) 366 | if err != nil { 367 | return err 368 | } 369 | _, err = value.WriteTo(f) 370 | // ⚠️ inverse error checks 371 | if err == nil { 372 | err = f.Sync() 373 | } 374 | f.Close() 375 | if err == nil { 376 | err = os.Rename(f.Name(), dir.file(key)) 377 | } 378 | if err == nil { 379 | return nil // OK 380 | } 381 | 382 | if removeErr := os.Remove(f.Name()); removeErr != nil { 383 | err = fmt.Errorf("%w, AND file leak: %w", err, removeErr) 384 | } 385 | return err 386 | } 387 | 388 | // Delete implements the Persistence interface. 389 | func (dir fileSystem) Delete(key uint) error { 390 | err := os.Remove(dir.file(key)) 391 | if err == nil || errors.Is(err, os.ErrNotExist) { 392 | return nil 393 | } 394 | return err 395 | } 396 | 397 | // List implements the Persistence interface. 398 | func (dir fileSystem) List() (keys []uint, err error) { 399 | f, err := os.Open(string(dir)) 400 | if err != nil { 401 | return nil, err 402 | } 403 | defer f.Close() 404 | names, err := f.Readdirnames(0) 405 | if err != nil { 406 | return nil, err 407 | } 408 | 409 | keys = make([]uint, 0, len(names)) 410 | for _, name := range names { 411 | if len(name) != 5 { 412 | continue 413 | } 414 | u, err := strconv.ParseUint(name, 16, 17) 415 | if err != nil { 416 | continue 417 | } 418 | keys = append(keys, uint(u)) 419 | } 420 | return keys, nil 421 | } 422 | 423 | // ruggedPersistence applies a sequence number plus integrity checks to a 424 | // delegate. 425 | type ruggedPersistence struct { 426 | Persistence // delegate 427 | 428 | // Content is ordered based on this sequence number. The clientIDKey 429 | // will have zero, as it is set [Save] only once as the first thing. 430 | seqNo atomic.Uint64 431 | } 432 | 433 | // Load implements the Persistence interface. 
434 | func (r *ruggedPersistence) Load(key uint) ([]byte, error) { 435 | value, err := r.Persistence.Load(key) 436 | switch { 437 | case err != nil: 438 | return nil, err 439 | case value == nil: 440 | return nil, nil 441 | default: 442 | value, _, err := decodeValue(value) 443 | if err != nil { 444 | return nil, fmt.Errorf("%w; record %#x unavailable", err, key) 445 | } 446 | return value, err 447 | } 448 | } 449 | 450 | // Save implements the Persistence interface. 451 | func (r *ruggedPersistence) Save(key uint, value net.Buffers) error { 452 | return r.Persistence.Save(key, encodeValue(value, r.seqNo.Add(1))) 453 | } 454 | 455 | func encodeValue(packet net.Buffers, seqNo uint64) net.Buffers { 456 | digest := fnv.New32a() 457 | for _, buf := range packet { 458 | digest.Write(buf) 459 | } 460 | var buf [12]byte 461 | binary.LittleEndian.PutUint64(buf[:8], seqNo) 462 | digest.Write(buf[:8]) 463 | binary.BigEndian.PutUint32(buf[8:], digest.Sum32()) 464 | return append(packet, buf[:]) 465 | } 466 | 467 | func decodeValue(buf []byte) (packet []byte, seqNo uint64, _ error) { 468 | if len(buf) < 12 { 469 | return nil, 0, errors.New("mqtt: persisted value truncated") 470 | } 471 | digest := fnv.New32a() 472 | digest.Write(buf[:len(buf)-4]) 473 | if digest.Sum32() != binary.BigEndian.Uint32(buf[len(buf)-4:]) { 474 | return nil, 0, errors.New("mqtt: persisted value corrupt") 475 | } 476 | return buf[:len(buf)-12], binary.LittleEndian.Uint64(buf[len(buf)-12:]), nil 477 | } 478 | -------------------------------------------------------------------------------- /client_test.go: -------------------------------------------------------------------------------- 1 | package mqtt_test 2 | 3 | import ( 4 | "bytes" 5 | "context" 6 | "encoding/hex" 7 | "errors" 8 | "io" 9 | "net" 10 | "sync/atomic" 11 | "testing" 12 | "time" 13 | 14 | "github.com/pascaldekloe/mqtt" 15 | "github.com/pascaldekloe/mqtt/mqtttest" 16 | ) 17 | 18 | const testTimeout = time.Second 19 | 20 | // VerifyClient 
reads client with assertions and timeouts. 21 | func verifyClient(t *testing.T, client *mqtt.Client, want ...mqtttest.Transfer) { 22 | // extra offline state check 23 | select { 24 | case <-client.Online(): 25 | t.Fatal("online signal receive before ReadSlices") 26 | default: 27 | break 28 | } 29 | select { 30 | case <-client.Offline(): 31 | break 32 | default: 33 | t.Fatal("offline signal blocked before ReadSlices") 34 | } 35 | 36 | timeoutOver := make(chan struct{}) 37 | timeout := time.AfterFunc(2*time.Second, func() { 38 | defer close(timeoutOver) 39 | 40 | t.Error("test timeout; closing Client now…") 41 | err := client.Close() 42 | if err != nil { 43 | t.Error("Close error:", err) 44 | } 45 | }) 46 | 47 | readRoutineDone := testRoutine(t, func() { 48 | defer func() { 49 | if timeout.Stop() { 50 | close(timeoutOver) 51 | } 52 | }() 53 | 54 | const readSlicesMax = 10 55 | for n := 0; n < readSlicesMax; n++ { 56 | message, topic, err := client.ReadSlices() 57 | 58 | if errors.Is(err, errLastTestConn) { 59 | t.Log("backoff on:", err) 60 | time.Sleep(10 * time.Millisecond) 61 | continue 62 | } 63 | 64 | if errors.Is(err, mqtt.ErrClosed) { 65 | for i := range want { 66 | if want[i].Err != nil { 67 | t.Errorf("client closed, want ReadSlices error: %v", want[i].Err) 68 | } else { 69 | t.Errorf("client closed, want ReadSlices message %#.1000x @ %q", want[i].Message, want[i].Topic) 70 | } 71 | } 72 | 73 | return 74 | } 75 | 76 | if len(want) == 0 { 77 | t.Errorf("ReadSlices got message %q, topic %q, and error %q, want ErrClosed", 78 | message, topic, err) 79 | continue 80 | } 81 | 82 | var big *mqtt.BigMessage 83 | if errors.As(err, &big) { 84 | t.Log("got BigMessage") 85 | topic = []byte(big.Topic) 86 | message, err = big.ReadAll() 87 | } 88 | 89 | if err != nil { 90 | if want[0].Err == nil { 91 | t.Errorf("ReadSlices got error %q, want message %#.200x @ %q", err, want[0].Message, want[0].Topic) 92 | } else if !errors.Is(err, want[0].Err) && err.Error() != 
want[0].Err.Error() { 93 | t.Errorf("ReadSlices got error %q, want errors.Is %q", err, want[0].Err) 94 | } 95 | } else { 96 | if want[0].Err != nil { 97 | t.Errorf("ReadSlices got message %#.200x @ %q, want error %q", message, topic, want[0].Err) 98 | } else if !bytes.Equal(message, want[0].Message) || string(topic) != want[0].Topic { 99 | t.Errorf("ReadSlices got message %#.200x @ %q, want %#.200x @ %q", message, topic, want[0].Message, want[0].Topic) 100 | } 101 | } 102 | 103 | want = want[1:] // move to next in line 104 | } 105 | 106 | t.Errorf("test abort after %d ReadSlices", readSlicesMax) 107 | }) 108 | 109 | t.Cleanup(func() { 110 | err := client.Close() 111 | if err != nil { 112 | t.Error("client close error:", err) 113 | } 114 | 115 | // extra offline state check 116 | select { 117 | case <-client.Online(): 118 | t.Error("online signal receive after client close") 119 | default: 120 | break 121 | } 122 | select { 123 | case <-client.Offline(): 124 | break 125 | default: 126 | t.Error("offline signal blocked after client close") 127 | } 128 | 129 | // no routine leaks 130 | <-readRoutineDone 131 | <-timeoutOver 132 | }) 133 | } 134 | 135 | // NewTestClient returns a new Client which dials to a pipe. 136 | func newTestClient(t *testing.T, want ...mqtttest.Transfer) (*mqtt.Client, net.Conn, <-chan struct{}) { 137 | // type of test is slow in general 138 | t.Parallel() 139 | // start timers after Parallel branche 140 | ctx, cancel := context.WithTimeout(context.Background(), testTimeout) 141 | t.Cleanup(cancel) 142 | testDeadline, _ := ctx.Deadline() 143 | 144 | clientConn, brokerConn := net.Pipe() 145 | // expire I/O mock before tests timeout 146 | brokerConn.SetDeadline(testDeadline.Add(-200 * time.Millisecond)) 147 | 148 | client := newTestClientDials(t, clientConn) 149 | verifyClient(t, client, want...) 150 | return client, brokerConn, ctx.Done() 151 | } 152 | 153 | // NewTestClientOnline completes the CONNECT from a newTestClient. 
154 | func newTestClientOnline(t *testing.T, want ...mqtttest.Transfer) (*mqtt.Client, net.Conn, <-chan struct{}) { 155 | client, conn, testTimeout := newTestClient(t, want...) 156 | wantConnectExchange(t, conn) 157 | wantOnline(t, client, testTimeout) 158 | return client, conn, testTimeout 159 | } 160 | 161 | // NewTestClientRedial verifies a newTestClient to dial twice during the test. 162 | func newTestClientRedial(t *testing.T, want ...mqtttest.Transfer) (*mqtt.Client, [2]net.Conn, <-chan struct{}) { 163 | // type of test is slow in general 164 | t.Parallel() 165 | // start timers after Parallel branche 166 | ctx, cancel := context.WithTimeout(context.Background(), testTimeout) 167 | t.Cleanup(cancel) 168 | testDeadline, _ := ctx.Deadline() 169 | 170 | var clientConns, brokerConns [2]net.Conn 171 | for i := range clientConns { 172 | clientConns[i], brokerConns[i] = net.Pipe() 173 | // expire I/O mocks before tests timeout 174 | brokerConns[i].SetDeadline(testDeadline.Add(-200 * time.Millisecond)) 175 | } 176 | 177 | client := newTestClientDials(t, clientConns[:]...) 178 | verifyClient(t, client, want...) 179 | 180 | return client, brokerConns, ctx.Done() 181 | } 182 | 183 | // NewTestClientOnlineRedial completes the CONNECT from a newTestClientRedial. 184 | func newTestClientOnlineRedial(t *testing.T, want ...mqtttest.Transfer) (*mqtt.Client, [2]net.Conn, <-chan struct{}) { 185 | client, conns, testTimeout := newTestClientRedial(t, want...) 
186 | wantConnectExchange(t, conns[0]) 187 | wantOnline(t, client, testTimeout) 188 | return client, conns, testTimeout 189 | } 190 | 191 | func wantConnectExchange(t *testing.T, conn net.Conn) { 192 | wantPacketHex(t, conn, "100c00044d515454040000000000") // CONNECT 193 | sendPacketHex(t, conn, "20020000") // CONNACK 194 | } 195 | 196 | func wantOnline(t *testing.T, client *mqtt.Client, timeout <-chan struct{}) { 197 | select { 198 | case <-client.Online(): 199 | break 200 | case <-timeout: 201 | t.Fatal("test timeout before Online") 202 | } 203 | } 204 | 205 | func newTestClientDials(t *testing.T, conns ...net.Conn) *mqtt.Client { 206 | client, err := mqtt.VolatileSession("", &mqtt.Config{ 207 | PauseTimeout: time.Second / 4, 208 | AtLeastOnceMax: 2, 209 | ExactlyOnceMax: 2, 210 | Dialer: newDialerMock(t, conns...), 211 | }) 212 | if err != nil { 213 | t.Fatal("volatile session error:", err) 214 | } 215 | return client 216 | } 217 | 218 | var errLastTestConn = errors.New("Dialer mock exhausted: all connections served") 219 | 220 | // NewDialerMock returns a new dialer which returns the conns in order of 221 | // appearance. The test fails on fewer dials. 
222 | func newDialerMock(t *testing.T, conns ...net.Conn) mqtt.Dialer { 223 | t.Helper() 224 | 225 | var dialN atomic.Uint64 226 | 227 | t.Cleanup(func() { 228 | n := dialN.Load() 229 | if n < uint64(len(conns)) && !t.Failed() { 230 | t.Errorf("got %d Dialer invocations, want %d", n, len(conns)) 231 | } 232 | }) 233 | 234 | return func(context.Context) (net.Conn, error) { 235 | n := dialN.Add(1) 236 | t.Log("Dial #", n) 237 | if n > uint64(len(conns)) { 238 | return nil, errLastTestConn 239 | } 240 | 241 | return conns[n-1], nil 242 | } 243 | } 244 | 245 | func TestClose(t *testing.T) { 246 | client, err := mqtt.VolatileSession("test-client", &mqtt.Config{ 247 | Dialer: func(ctx context.Context) (net.Conn, error) { 248 | <-ctx.Done() 249 | return nil, ctx.Err() 250 | }, 251 | PauseTimeout: time.Second / 4, 252 | }) 253 | if err != nil { 254 | t.Fatal("volatile session error:", err) 255 | } 256 | 257 | online := client.Online() 258 | select { 259 | case <-online: 260 | t.Error("online signal receive on initial state") 261 | default: 262 | break 263 | } 264 | offline := client.Offline() 265 | select { 266 | case <-offline: 267 | break 268 | default: 269 | t.Error("offline signal blocked on initial state") 270 | } 271 | 272 | // Race because we can. ™️ 273 | for n := 0; n < 3; n++ { 274 | go func() { 275 | err := client.Close() 276 | if err != nil { 277 | t.Error("got close error:", err) 278 | } 279 | }() 280 | } 281 | 282 | _, _, err = client.ReadSlices() 283 | if !errors.Is(err, mqtt.ErrClosed) { 284 | t.Fatalf("ReadSlices got error %q, want an ErrClosed", err) 285 | } 286 | 287 | // Run twice to ensure the semaphores ain't leaking. 
288 | for roundN := 1; roundN <= 2; roundN++ { 289 | err = client.Subscribe(nil, "x") 290 | if !errors.Is(err, mqtt.ErrClosed) { 291 | t.Errorf("Subscribe round %d got error %q, want an ErrClosed", roundN, err) 292 | } 293 | err = client.Unsubscribe(nil, "x") 294 | if !errors.Is(err, mqtt.ErrClosed) { 295 | t.Errorf("Unsubscribe round %d got error %q, want an ErrClosed", roundN, err) 296 | } 297 | err = client.Publish(nil, nil, "x") 298 | if !errors.Is(err, mqtt.ErrClosed) { 299 | t.Errorf("Publish round %d got error %q, want an ErrClosed", roundN, err) 300 | } 301 | err = client.PublishRetained(nil, nil, "x") 302 | if !errors.Is(err, mqtt.ErrClosed) { 303 | t.Errorf("PublishRetained round %d got error %q, want an ErrClosed", roundN, err) 304 | } 305 | _, err = client.PublishAtLeastOnce(nil, "x") 306 | if !errors.Is(err, mqtt.ErrClosed) { 307 | t.Errorf("PublishAtLeastOnce round %d got error %q, want an ErrClosed", roundN, err) 308 | } 309 | _, err = client.PublishAtLeastOnceRetained(nil, "x") 310 | if !errors.Is(err, mqtt.ErrClosed) { 311 | t.Errorf("PublishAtLeastOnceRetained round %d got error %q, want an ErrClosed", roundN, err) 312 | } 313 | _, err = client.PublishExactlyOnce(nil, "x") 314 | if !errors.Is(err, mqtt.ErrClosed) { 315 | t.Errorf("PublishExactlyOnce round %d got error %q, want an ErrClosed", roundN, err) 316 | } 317 | _, err = client.PublishExactlyOnceRetained(nil, "x") 318 | if !errors.Is(err, mqtt.ErrClosed) { 319 | t.Errorf("PublishExactlyOnceRetained round %d got error %q, want an ErrClosed", roundN, err) 320 | } 321 | err = client.Ping(nil) 322 | if !errors.Is(err, mqtt.ErrClosed) { 323 | t.Errorf("Ping round %d got error %q, want an ErrClosed", roundN, err) 324 | } 325 | err = client.Disconnect(nil) 326 | if !errors.Is(err, mqtt.ErrClosed) { 327 | t.Errorf("Disconnect round %d got error %q, want an ErrClosed", roundN, err) 328 | } 329 | _, _, err = client.ReadSlices() 330 | if !errors.Is(err, mqtt.ErrClosed) { 331 | t.Fatalf("ReadSlices 
round %d got error %q, want an ErrClosed", roundN, err) 332 | } 333 | } 334 | 335 | select { 336 | case <-online: 337 | t.Error("online signal receive") 338 | default: 339 | break 340 | } 341 | select { 342 | case <-offline: 343 | break 344 | default: 345 | t.Error("offline signal blocked") 346 | } 347 | } 348 | 349 | func TestDown(t *testing.T) { 350 | brokerEnd, clientEnd := net.Pipe() 351 | brokerEnd.SetDeadline(time.Now().Add(800 * time.Millisecond)) 352 | 353 | var dialN int 354 | client, err := mqtt.VolatileSession("", &mqtt.Config{ 355 | Dialer: func(context.Context) (net.Conn, error) { 356 | dialN++ 357 | if dialN > 1 { 358 | return nil, errors.New("no more connections for test") 359 | } 360 | return clientEnd, nil 361 | }, 362 | PauseTimeout: time.Second / 4, 363 | AtLeastOnceMax: 2, 364 | ExactlyOnceMax: 2, 365 | }) 366 | if err != nil { 367 | t.Fatal("volatile session error:", err) 368 | } 369 | 370 | brokerMockDone := testRoutine(t, func() { 371 | wantPacketHex(t, brokerEnd, "100c00044d515454040000000000") // CONNECT 372 | sendPacketHex(t, brokerEnd, "20020003") // CONNACK 373 | }) 374 | 375 | message, topic, err := client.ReadSlices() 376 | if !errors.Is(err, mqtt.ErrUnavailable) { 377 | t.Fatalf("ReadSlices got (%q, %q, %q), want an ErrUnavailable", message, topic, err) 378 | } 379 | if !mqtt.IsConnectionRefused(err) { 380 | t.Errorf("ReadSlices error %q is not an IsConnectionRefused", err) 381 | } 382 | <-brokerMockDone 383 | 384 | // Run twice to ensure the semaphores ain't leaking. 
385 | for roundN := 1; roundN <= 2; roundN++ { 386 | err := client.Subscribe(nil, "x") 387 | if !errors.Is(err, mqtt.ErrDown) { 388 | t.Errorf("Subscribe round %d got error %q, want an ErrDown", roundN, err) 389 | } 390 | err = client.Unsubscribe(nil, "x") 391 | if !errors.Is(err, mqtt.ErrDown) { 392 | t.Errorf("Unsubscribe round %d got error %q, want an ErrDown", roundN, err) 393 | } 394 | err = client.Publish(nil, nil, "x") 395 | if !errors.Is(err, mqtt.ErrDown) { 396 | t.Errorf("Publish round %d got error %q, want an ErrDown", roundN, err) 397 | } 398 | err = client.PublishRetained(nil, nil, "x") 399 | if !errors.Is(err, mqtt.ErrDown) { 400 | t.Errorf("PublishRetained round %d got error %q, want an ErrDown", roundN, err) 401 | } 402 | _, err = client.PublishAtLeastOnce(nil, "x") 403 | if roundN > 1 { 404 | if !errors.Is(err, mqtt.ErrMax) { 405 | t.Errorf("PublishAtLeastOnce round %d got error %q, want an ErrMax", roundN, err) 406 | } 407 | } else if err != nil { 408 | t.Errorf("PublishAtLeastOnce round %d got error %q", roundN, err) 409 | } 410 | _, err = client.PublishAtLeastOnceRetained(nil, "x") 411 | if roundN > 1 { 412 | if !errors.Is(err, mqtt.ErrMax) { 413 | t.Errorf("PublishAtLeastOnceRetained round %d got error %q, want an ErrMax", roundN, err) 414 | } 415 | } else if err != nil { 416 | t.Errorf("PublishAtLeastOnceRetained round %d got error %q", roundN, err) 417 | } 418 | _, err = client.PublishExactlyOnce(nil, "x") 419 | if roundN > 1 { 420 | if !errors.Is(err, mqtt.ErrMax) { 421 | t.Errorf("PublishExactlyOnce round %d got error %q, want an ErrMax", roundN, err) 422 | } 423 | } else if err != nil { 424 | t.Errorf("PublishExactlyOnce round %d got error %q", roundN, err) 425 | } 426 | _, err = client.PublishExactlyOnceRetained(nil, "x") 427 | if roundN > 1 { 428 | if !errors.Is(err, mqtt.ErrMax) { 429 | t.Errorf("PublishExactlyOnceRetained round %d got error %q, want an ErrMax", roundN, err) 430 | } 431 | } else if err != nil { 432 | 
t.Errorf("PublishExactlyOnceRetained round %d got error %q", roundN, err) 433 | } 434 | err = client.Ping(nil) 435 | if !errors.Is(err, mqtt.ErrDown) { 436 | t.Errorf("Ping round %d got error %q, want an ErrDown", roundN, err) 437 | } 438 | } 439 | } 440 | 441 | func TestReceivePublishAtLeastOnce(t *testing.T) { 442 | _, conn, _ := newTestClientOnline(t, 443 | mqtttest.Transfer{Message: []byte("hello"), Topic: "greet"}, 444 | ) 445 | 446 | sendPacketHex(t, conn, hex.EncodeToString([]byte{ 447 | 0x32, 14, 448 | 0, 5, 'g', 'r', 'e', 'e', 't', 449 | 0xab, 0xcd, // packet identifier 450 | 'h', 'e', 'l', 'l', 'o'})) 451 | wantPacketHex(t, conn, "4002abcd") // PUBACK 452 | } 453 | 454 | func TestReceivePublishExactlyOnce(t *testing.T) { 455 | _, conn, _ := newTestClientOnline(t, 456 | mqtttest.Transfer{Message: []byte("hello"), Topic: "greet"}, 457 | ) 458 | 459 | sendPacketHex(t, conn, hex.EncodeToString([]byte{ 460 | 0x34, 14, 461 | 0, 5, 'g', 'r', 'e', 'e', 't', 462 | 0xab, 0xcd, // packet identifier 463 | 'h', 'e', 'l', 'l', 'o'})) 464 | wantPacketHex(t, conn, "5002abcd") // PUBREC 465 | sendPacketHex(t, conn, "6002abcd") // PUBREL 466 | wantPacketHex(t, conn, "7002abcd") // PUBCOMP 467 | } 468 | 469 | func TestReceivePublishAtLeastOnceBig(t *testing.T) { 470 | const bigN = 256 * 1024 471 | _, conn, _ := newTestClientOnline(t, 472 | mqtttest.Transfer{Message: bytes.Repeat([]byte{'A'}, bigN), Topic: "bam"}, 473 | ) 474 | 475 | sendPacketHex(t, conn, "32"+ // publish at least once 476 | "878010"+ // size varint 7 + bigN 477 | "000362616d"+ // topic 478 | "abcd") // packet identifier 479 | _, err := conn.Write(bytes.Repeat([]byte{'A'}, bigN)) 480 | if err != nil { 481 | t.Fatal("payload submission error:", err) 482 | } 483 | wantPacketHex(t, conn, "4002abcd") // PUBACK 484 | } 485 | 486 | func testRoutine(t *testing.T, f func()) (done <-chan struct{}) { 487 | t.Helper() 488 | ch := make(chan struct{}) 489 | go func() { 490 | defer close(ch) 491 | f() 492 | }() 493 | 
t.Cleanup(func() { 494 | t.Helper() 495 | select { 496 | case <-ch: 497 | break // OK 498 | default: 499 | t.Error("test routine leak") 500 | } 501 | }) 502 | return ch 503 | } 504 | 505 | func sendPacketHex(t *testing.T, conn net.Conn, send string) { 506 | t.Helper() 507 | t.Logf("send %s…", typeLabelFromHex(send[0])) 508 | packet, err := hex.DecodeString(send) 509 | if err != nil { 510 | t.Fatalf("test has malformed packet data 0x%s: %s", send, err) 511 | } 512 | _, err = conn.Write(packet) 513 | if err != nil { 514 | t.Fatalf("broker write 0x%s error: %s", send, err) 515 | } 516 | } 517 | 518 | func wantPacketHex(t *testing.T, conn net.Conn, want string) { 519 | t.Helper() 520 | t.Logf("want %s…", typeLabelFromHex(want[0])) 521 | var buf [128]byte 522 | _, err := io.ReadFull(conn, buf[:2]) 523 | if err != nil { 524 | t.Fatalf("broker read error %q, want 0x%s", err, want) 525 | } 526 | if buf[1] > 126 { 527 | t.Fatalf("packet %#x… too big for test, want 0x%s", buf[:2], want) 528 | } 529 | n, err := io.ReadFull(conn, buf[2:2+buf[1]]) 530 | if err != nil { 531 | t.Fatalf("broker read error %q after %#x, want 0x%s", err, buf[:2+n], want) 532 | } 533 | got := hex.EncodeToString(buf[:2+n]) 534 | if want != got { 535 | t.Errorf("broker got packet 0x%s, want 0x%s", got, want) 536 | } 537 | } 538 | 539 | func typeLabelFromHex(char byte) string { 540 | switch char { 541 | case '0': 542 | return "RESERVED0" 543 | case '1': 544 | return "CONNECT" 545 | case '2': 546 | return "CONNACK" 547 | case '3': 548 | return "PUBLISH" 549 | case '4': 550 | return "PUBACK" 551 | case '5': 552 | return "PUBREC" 553 | case '6': 554 | return "PUBREL" 555 | case '7': 556 | return "PUBCOMP" 557 | case '8': 558 | return "SUBSCRIBE" 559 | case '9': 560 | return "SUBACK" 561 | case 'a', 'A': 562 | return "UNSUBSCRIBE" 563 | case 'b', 'B': 564 | return "UNSUBACK" 565 | case 'c', 'C': 566 | return "PINGREQ" 567 | case 'd', 'D': 568 | return "PINGRESP" 569 | case 'e', 'E': 570 | return 
"DISCONNECT" 571 | case 'f', 'F': 572 | return "RESERVED15" 573 | default: 574 | panic("not a hex character") 575 | } 576 | } 577 | -------------------------------------------------------------------------------- /request.go: -------------------------------------------------------------------------------- 1 | package mqtt 2 | 3 | import ( 4 | "encoding/binary" 5 | "errors" 6 | "fmt" 7 | "net" 8 | "sort" 9 | "sync" 10 | "time" 11 | ) 12 | 13 | // ErrMax denies a request on transit capacity, which prevents the Client from 14 | // blocking. Ping has a limit of 1 slot. Subscribe and Unsubscribe share a large 15 | // number of slots. PublishAtLeastOnce and PublishExactlyOnce each have a limit 16 | // defined by Config. A plain Publish (at most once) has no limit. 17 | var ErrMax = errors.New("mqtt: maximum number of pending requests reached") 18 | 19 | // ErrCanceled means that a quit signal got applied before the request was send. 20 | // The transacion never happened, as opposed to ErrAbandoned. 21 | var ErrCanceled = errors.New("mqtt: request canceled before submission") 22 | 23 | // ErrAbandoned means that a quit signal got applied after the request was send. 24 | // The broker received the request, yet the result/response remains unknown. 25 | var ErrAbandoned = errors.New("mqtt: request abandoned after submission") 26 | 27 | // ErrSubmit signals that the connection was lost during outbound transfer. The 28 | // status of the execution remains unknown, because there is no telling how much 29 | // of the payload actually reached the broker. Connection loss after submision 30 | // causes ErrBreak instead. 31 | var ErrSubmit = errors.New("mqtt: connection fatal during submission") 32 | 33 | // ErrBreak signals that the connection was lost after a request was send, and 34 | // before a response was received. The status of the execution remains unknown, 35 | // similar to ErrSubmit. 
36 | var ErrBreak = errors.New("mqtt: connection fatal while awaiting response") 37 | 38 | // BufSize should fit topic names with a bit of overhead. 39 | const bufSize = 128 40 | 41 | // BufPool is used to construct packets for submission. 42 | // Append will allocate the appropriate amount on overflows. 43 | // The PUBLISH messages are not copied into these buffers. 44 | var bufPool = sync.Pool{New: func() interface{} { return new([bufSize]byte) }} 45 | 46 | var denyAndEndErrs = append(append( 47 | make([]error, 0, len(denyErrs)+len(endErrs)), 48 | denyErrs..., 49 | ), 50 | endErrs..., 51 | ) 52 | 53 | // Backoff returns a channel which is closed once the Client can (and possibly 54 | // should) retry the error. The return is nil when retries are not applicable, 55 | // i.e., any IsDeny, IsEnd, or SubscribeError gets a nil channel which blocks. 56 | func (c *Client) Backoff(err error) <-chan struct{} { 57 | switch { 58 | case err == nil || nonNilIsAny(err, denyAndEndErrs): 59 | return nil 60 | 61 | case errors.Is(err, ErrMax): 62 | const timeout = time.Second / 4 63 | 64 | // shared backoff for too many requests 65 | for { 66 | shared := c.backoffOnMax.Load() 67 | if shared != nil { 68 | // maybe shared already expired 69 | select { 70 | case <-*shared: 71 | break // replace with new up next 72 | default: 73 | return *shared 74 | } 75 | } 76 | // shared nil or expired 77 | 78 | new := make(chan struct{}) 79 | var readEnd <-chan struct{} = new 80 | if c.backoffOnMax.CompareAndSwap(shared, &readEnd) { 81 | // launch Timer only within singleton guarantee 82 | time.AfterFunc(timeout, func() { close(new) }) 83 | return readEnd 84 | } 85 | // another goroutine placed a new backoff first 86 | } 87 | 88 | case errors.As(err, new(SubscribeError)): 89 | // server failed request 90 | return nil 91 | 92 | default: 93 | // connection was down, or it just broke 94 | return c.Online() // await reconnect 95 | } 96 | } 97 | 98 | // Ping makes a roundtrip to validate the 
connection. Client allows only one 99 | // ping at a time. Redundant requests get ErrMax. 100 | // 101 | // Quit is optional, as nil just blocks. Appliance of quit will strictly result 102 | // in either ErrCanceled or ErrAbandoned. 103 | func (c *Client) Ping(quit <-chan struct{}) error { 104 | // install callback 105 | done := make(chan error, 1) 106 | select { 107 | case c.pingAck <- done: 108 | break // OK 109 | default: 110 | return fmt.Errorf("%w; PING unavailable", ErrMax) 111 | } 112 | 113 | // submit transaction 114 | if err := c.write(quit, packetPINGREQ); err != nil { 115 | select { 116 | case <-c.pingAck: // unlock 117 | default: // picked up by unrelated pong 118 | } 119 | if errors.Is(err, ErrSubmit) { 120 | return fmt.Errorf("%w; PING in limbo", err) 121 | } 122 | return fmt.Errorf("%w; PING not send", err) 123 | } 124 | 125 | select { 126 | case err := <-done: 127 | return err 128 | case <-quit: 129 | select { 130 | case <-c.pingAck: // unlock 131 | return fmt.Errorf("%w; PING not confirmed", ErrAbandoned) 132 | default: // picked up in mean time 133 | return <-done 134 | } 135 | } 136 | } 137 | 138 | func (c *Client) onPINGRESP() error { 139 | if len(c.peek) != 0 { 140 | return fmt.Errorf("%w: PINGRESP with %d byte remaining length", errProtoReset, len(c.peek)) 141 | } 142 | select { 143 | case ack := <-c.pingAck: 144 | close(ack) 145 | default: 146 | break // tolerates wandering pong 147 | } 148 | return nil 149 | } 150 | 151 | // SubscribeError holds one or more topic filters which were failed by the broker. 152 | // The element order matches the originating request's. 153 | type SubscribeError []string 154 | 155 | // Error implements the standard error interface. 156 | func (e SubscribeError) Error() string { 157 | return fmt.Sprintf("mqtt: broker failed %d topic filters", len(e)) 158 | } 159 | 160 | // Reject no-ops to prevent programming mistakes. 
161 | var ( 162 | // “The payload of a SUBSCRIBE packet MUST contain at least one 163 | // Topic Filter / QoS pair. A SUBSCRIBE packet with no payload 164 | // is a protocol violation.” 165 | // — MQTT Version 3.1.1, conformance statement MQTT-3.8.3-3 166 | errSubscribeNone = errors.New("mqtt: SUBSCRIBE without topic filters denied") 167 | 168 | // “The Payload of an UNSUBSCRIBE packet MUST contain at least 169 | // one Topic Filter. An UNSUBSCRIBE packet with no payload is a 170 | // protocol violation.” 171 | // — MQTT Version 3.1.1, conformance statement MQTT-3.10.3-2 172 | errUnsubscribeNone = errors.New("mqtt: UNSUBSCRIBE without topic filters denied") 173 | ) 174 | 175 | // A total for four types of client requests require a 16-bit packet identifier, 176 | // namely SUBSCRIBE, UNSUBSCRIBE and PUBLISH at-least-once or exactly-once. 177 | // The outbound identifiers are assigned in segments per type. The non-zero 178 | // prefixes/spaces also prevent use of the reserved packet identifier zero. 179 | const ( 180 | // A 14-bit address space allows for up to 16,384 pending transactions. 181 | publishIDMask = 0x3fff 182 | // The most-significant bit flags an ordered transaction for publish. 183 | // The second most-significant bit distinguises the QOS level. 184 | atLeastOnceIDSpace = 0x8000 185 | exactlyOnceIDSpace = 0xc000 186 | 187 | // A 13-bit address space allows for up to 8,192 pending transactions. 188 | unorderedIDMask = 0x1fff 189 | subscribeIDSpace = 0x6000 190 | unsubscribeIDSpace = 0x4000 191 | ) 192 | 193 | // ErrPacketIDSpace signals a response packet with an identifier outside of the 194 | // respective address spaces, defined by subscribeIDSpace, unsubscribeIDSpace, 195 | // atLeastOnceIDSpace and exactlyOnceIDSpace. This extra check has a potential 196 | // to detect corruptions which would otherwise go unnoticed. 
197 | var errPacketIDSpace = fmt.Errorf("%w: packet ID space mismatch", errProtoReset) 198 | 199 | // UnorderedTxs tracks outbound transactions without sequence contraints. 200 | type unorderedTxs struct { 201 | sync.Mutex 202 | n uint // counter is permitted to overflow 203 | perPacketID map[uint16]unorderedCallback // transit state 204 | } 205 | 206 | type unorderedCallback struct { 207 | done chan<- error 208 | topicFilters []string 209 | } 210 | 211 | // StartTx assigns a slot for either a subscribe or an unsubscribe. 212 | // The filter slice is nil for unsubscribes only. 213 | func (txs *unorderedTxs) startTx(topicFilters []string) (packetID uint16, done <-chan error, err error) { 214 | var space uint 215 | if topicFilters == nil { 216 | space = unsubscribeIDSpace 217 | } else { 218 | space = subscribeIDSpace 219 | } 220 | 221 | // Only one response error can be applied on done. 222 | ch := make(chan error, 1) 223 | 224 | txs.Lock() 225 | defer txs.Unlock() 226 | 227 | // By using only a small window of the actual space we 228 | // minimise any overlap risks with ErrAbandoned cases. 229 | if len(txs.perPacketID) > unorderedIDMask>>4 { 230 | return 0, nil, ErrMax 231 | } 232 | 233 | // Find a free identifier with the sequence counter. 234 | for { 235 | packetID = uint16(txs.n&unorderedIDMask | space) 236 | txs.n++ 237 | if _, ok := txs.perPacketID[packetID]; ok { 238 | // Such collision indicates a very late response. 239 | continue // just skips the identifier 240 | } 241 | txs.perPacketID[packetID] = unorderedCallback{ 242 | topicFilters: topicFilters, 243 | done: ch, 244 | } 245 | return packetID, ch, nil 246 | } 247 | } 248 | 249 | // EndTx releases a slot. The filter slice is nil for unsubscribe requests. 
func (txs *unorderedTxs) endTx(packetID uint16) (done chan<- error, topicFilters []string) {
	txs.Lock()
	defer txs.Unlock()
	// An absent entry yields the zero value, i.e., a nil done channel.
	callback := txs.perPacketID[packetID]
	delete(txs.perPacketID, packetID)
	return callback.done, callback.topicFilters
}

// breakAll fails every pending transaction with ErrBreak.
func (txs *unorderedTxs) breakAll() {
	txs.Lock()
	defer txs.Unlock()
	for packetID, callback := range txs.perPacketID {
		delete(txs.perPacketID, packetID)
		// won't block: done is buffered for one error (see startTx)
		callback.done <- fmt.Errorf("%w; subscription change not confirmed", ErrBreak)
	}
}

// Subscribe requests subscription for all topics that match any of the filter
// arguments.
//
// Quit is optional, as nil just blocks. Appliance of quit will strictly result
// in either ErrCanceled or ErrAbandoned.
func (c *Client) Subscribe(quit <-chan struct{}, topicFilters ...string) error {
	return c.subscribeLevel(quit, topicFilters, exactlyOnceLevel)
}

// SubscribeLimitAtMostOnce is like Subscribe, but it limits message reception
// to quality-of-service level 0—fire and forget.
func (c *Client) SubscribeLimitAtMostOnce(quit <-chan struct{}, topicFilters ...string) error {
	return c.subscribeLevel(quit, topicFilters, atMostOnceLevel)
}

// SubscribeLimitAtLeastOnce is like Subscribe, but it limits message reception
// to quality-of-service level 1—acknowledged transfer.
func (c *Client) SubscribeLimitAtLeastOnce(quit <-chan struct{}, topicFilters ...string) error {
	return c.subscribeLevel(quit, topicFilters, atLeastOnceLevel)
}

// subscribeLevel sends a SUBSCRIBE request with levelMax as the requested
// quality-of-service ceiling, and it awaits the SUBACK confirmation.
func (c *Client) subscribeLevel(quit <-chan struct{}, topicFilters []string, levelMax byte) error {
	if len(topicFilters) == 0 {
		return errSubscribeNone
	}
	// 2 bytes packet ID, plus a 2-byte length prefix and 1 QoS byte per filter
	size := 2 + len(topicFilters)*3
	for _, s := range topicFilters {
		if err := topicCheck(s); err != nil {
			return fmt.Errorf("%w; SUBSCRIBE request denied on topic filter", err)
		}
		size += len(s)
	}
	if size > packetMax {
		return fmt.Errorf("%w; SUBSCRIBE request denied", errPacketMax)
	}

	// slot assignment
	packetID, done, err := c.unorderedTxs.startTx(topicFilters)
	if err != nil {
		return fmt.Errorf("%w; SUBSCRIBE unavailable", err)
	}

	// request packet composition
	buf := bufPool.Get().(*[bufSize]byte)
	defer bufPool.Put(buf)
	packet := append(buf[:0], typeSUBSCRIBE<<4|atLeastOnceLevel<<1)
	// encode the remaining length as a variable-length quantity
	l := uint(size)
	for ; l > 0x7f; l >>= 7 {
		packet = append(packet, byte(l|0x80))
	}
	packet = append(packet, byte(l))
	packet = append(packet, byte(packetID>>8), byte(packetID))
	for _, s := range topicFilters {
		packet = append(packet, byte(len(s)>>8), byte(len(s)))
		packet = append(packet, s...)
		packet = append(packet, levelMax)
	}

	// network submission
	if err = c.write(quit, packet); err != nil {
		c.unorderedTxs.endTx(packetID) // releases slot
		if errors.Is(err, ErrSubmit) {
			return fmt.Errorf("%w; SUBSCRIBE in limbo", err)
		}
		return fmt.Errorf("%w; SUBSCRIBE not send", err)
	}

	// await SUBACK (through onSUBACK); nil quit blocks indefinitely
	select {
	case err := <-done:
		return err
	case <-quit:
		c.unorderedTxs.endTx(packetID) // releases slot
		return fmt.Errorf("%w; SUBSCRIBE not confirmed", ErrAbandoned)
	}
}

// onSUBACK validates a SUBACK packet (in c.peek) and completes the matching
// subscribe transaction. A non-nil return resets the connection.
func (c *Client) onSUBACK() error {
	if len(c.peek) < 3 {
		return fmt.Errorf("%w: SUBACK with %d byte remaining length", errProtoReset, len(c.peek))
	}
	packetID := binary.BigEndian.Uint16(c.peek)
	switch {
	case packetID == 0:
		return errPacketIDZero
	case packetID&^unorderedIDMask != subscribeIDSpace:
		return errPacketIDSpace
	}

	// one return code per requested topic filter
	returnCodes := c.peek[2:]
	var failN int
	for _, code := range returnCodes {
		switch code {
		case atMostOnceLevel, atLeastOnceLevel, exactlyOnceLevel:
			break
		case 0x80:
			failN++
		default:
			return fmt.Errorf("%w: SUBACK with illegal return code %#02x", errProtoReset, code)
		}
	}

	// commit
	done, topicFilters := c.unorderedTxs.endTx(packetID)
	if done == nil { // hopefully due ErrAbandoned
		return nil
	}

	// “The SUBACK Packet sent by the Server to the Client MUST contain a
	// return code for each Topic Filter/QoS pair. …”
	// — MQTT Version 3.1.1, conformance statement MQTT-3.8.4-5
	if len(topicFilters) != len(returnCodes) {
		done <- fmt.Errorf("mqtt: %d return codes for SUBSCRIBE with %d topic filters", len(returnCodes), len(topicFilters))
		return errProtoReset
	}

	if failN != 0 {
		// collect the denied topic filters into one error
		var err SubscribeError
		for i, code := range returnCodes {
			if code == 0x80 {
				err = append(err, topicFilters[i])
			}
		}
		done <- err
	}
	close(done)
	return nil
}

// Unsubscribe requests subscription cancelation for each of the filter
// arguments.
//
// Quit is optional, as nil just blocks. Appliance of quit will strictly result
// in either ErrCanceled or ErrAbandoned.
func (c *Client) Unsubscribe(quit <-chan struct{}, topicFilters ...string) error {
	if len(topicFilters) == 0 {
		return errUnsubscribeNone
	}
	// 2 bytes packet ID, plus a 2-byte length prefix per filter
	size := 2 + len(topicFilters)*2
	for _, s := range topicFilters {
		size += len(s)
		if err := topicCheck(s); err != nil {
			return fmt.Errorf("%w; UNSUBSCRIBE request denied on topic filter", err)
		}
	}
	if size > packetMax {
		return fmt.Errorf("%w; UNSUBSCRIBE request denied", errPacketMax)
	}

	// slot assignment
	packetID, done, err := c.unorderedTxs.startTx(nil)
	if err != nil {
		return fmt.Errorf("%w; UNSUBSCRIBE unavailable", err)
	}

	// request packet composition
	buf := bufPool.Get().(*[bufSize]byte)
	defer bufPool.Put(buf)
	// header
	packet := append(buf[:0], typeUNSUBSCRIBE<<4|atLeastOnceLevel<<1)
	// encode the remaining length as a variable-length quantity
	l := uint(size)
	for ; l > 0x7f; l >>= 7 {
		packet = append(packet, byte(l|0x80))
	}
	packet = append(packet, byte(l))
	packet = append(packet, byte(packetID>>8), byte(packetID))
	// payload
	for _, s := range topicFilters {
		packet = append(packet, byte(len(s)>>8), byte(len(s)))
		packet = append(packet, s...)
	}

	// network submission
	if err = c.write(quit, packet); err != nil {
		c.unorderedTxs.endTx(packetID) // releases slot
		if errors.Is(err, ErrSubmit) {
			return fmt.Errorf("%w; UNSUBSCRIBE in limbo", err)
		}
		return fmt.Errorf("%w; UNSUBSCRIBE not send", err)
	}

	// await UNSUBACK (through onUNSUBACK); nil quit blocks indefinitely
	select {
	case err := <-done:
		return err
	case <-quit:
		c.unorderedTxs.endTx(packetID) // releases slot
		return fmt.Errorf("%w; UNSUBSCRIBE not confirmed", ErrAbandoned)
	}
}

// onUNSUBACK validates an UNSUBACK packet (in c.peek) and completes the
// matching unsubscribe transaction. A non-nil return resets the connection.
func (c *Client) onUNSUBACK() error {
	if len(c.peek) != 2 {
		return fmt.Errorf("%w: UNSUBACK with %d byte remaining length", errProtoReset, len(c.peek))
	}
	packetID := binary.BigEndian.Uint16(c.peek)
	switch {
	case packetID == 0:
		return errPacketIDZero
	case packetID&^unorderedIDMask != unsubscribeIDSpace:
		return errPacketIDSpace
	}
	done, _ := c.unorderedTxs.endTx(packetID)
	if done != nil { // nil happens after quit (ErrAbandoned)
		close(done)
	}
	return nil
}

// orderedTxs tracks outbound transactions with sequence constraints.
// The counters are allowed to overflow.
type orderedTxs struct {
	Acked     uint // confirm count for PublishAtLeastOnce
	Received  uint // confirm count 1/2 for PublishExactlyOnce
	Completed uint // confirm count 2/2 for PublishExactlyOnce
}

// Publish delivers the message with an “at most once” guarantee.
// Subscribers may or may not receive the message when subject to error.
// This delivery method is the most efficient option.
//
// Quit is optional, as nil just blocks. Appliance of quit will strictly result
// in ErrCanceled.
488 | func (c *Client) Publish(quit <-chan struct{}, message []byte, topic string) error { 489 | return c.publish(quit, message, topic, typePUBLISH<<4) 490 | } 491 | 492 | // PublishRetained is like Publish, but the broker must store the message, so 493 | // that it can be delivered to future subscribers whose subscriptions match the 494 | // topic name. The broker may choose to discard the message at any time though. 495 | // Uppon reception, the broker must discard any message previously retained for 496 | // the topic name. 497 | func (c *Client) PublishRetained(quit <-chan struct{}, message []byte, topic string) error { 498 | return c.publish(quit, message, topic, typePUBLISH<<4|retainFlag) 499 | } 500 | 501 | func (c *Client) publish(quit <-chan struct{}, message []byte, topic string, head byte) error { 502 | buf := bufPool.Get().(*[bufSize]byte) 503 | defer bufPool.Put(buf) 504 | packet, err := publishPacket(buf, message, topic, 0, head) 505 | if err != nil { 506 | return err 507 | } 508 | 509 | err = c.writeBuffers(quit, packet) 510 | if err != nil { 511 | if errors.Is(err, ErrSubmit) { 512 | return fmt.Errorf("%w; PUBLISH in limbo", err) 513 | } 514 | return fmt.Errorf("%w; PUBLISH not send", err) 515 | } 516 | return nil 517 | } 518 | 519 | // PublishAtLeastOnce delivers the message with an “at least once” guarantee. 520 | // Subscribers may receive the message more than once when subject to error. 521 | // This delivery method requires a response transmission plus persistence on 522 | // both client-side and broker-side. 523 | // 524 | // The exchange channel is closed uppon receival confirmation by the broker. 525 | // ErrClosed leaves the channel blocked (with no further input). 
526 | func (c *Client) PublishAtLeastOnce(message []byte, topic string) (exchange <-chan error, err error) { 527 | buf := bufPool.Get().(*[bufSize]byte) 528 | defer bufPool.Put(buf) 529 | packet, err := publishPacket(buf, message, topic, atLeastOnceIDSpace, typePUBLISH<<4|atLeastOnceLevel<<1) 530 | if err != nil { 531 | return nil, err 532 | } 533 | return c.submitPersisted(packet, c.atLeastOnce) 534 | } 535 | 536 | // PublishAtLeastOnceRetained is like PublishAtLeastOnce, but the broker must 537 | // store the message, so that it can be delivered to future subscribers whose 538 | // subscriptions match the topic name. When a new subscription is established, 539 | // the last retained message, if any, on each matching topic name must be sent 540 | // to the subscriber. 541 | func (c *Client) PublishAtLeastOnceRetained(message []byte, topic string) (exchange <-chan error, err error) { 542 | buf := bufPool.Get().(*[bufSize]byte) 543 | defer bufPool.Put(buf) 544 | packet, err := publishPacket(buf, message, topic, atLeastOnceIDSpace, typePUBLISH<<4|atLeastOnceLevel<<1|retainFlag) 545 | if err != nil { 546 | return nil, err 547 | } 548 | return c.submitPersisted(packet, c.atLeastOnce) 549 | } 550 | 551 | // PublishExactlyOnce delivers the message with an “exactly once” guarantee. 552 | // This delivery method eliminates the duplicate-delivery risk from 553 | // PublishAtLeastOnce at the expense of an additional network roundtrip. 
func (c *Client) PublishExactlyOnce(message []byte, topic string) (exchange <-chan error, err error) {
	buf := bufPool.Get().(*[bufSize]byte)
	defer bufPool.Put(buf)
	packet, err := publishPacket(buf, message, topic, exactlyOnceIDSpace, typePUBLISH<<4|exactlyOnceLevel<<1)
	if err != nil {
		return nil, err
	}
	return c.submitPersisted(packet, c.exactlyOnce)
}

// PublishExactlyOnceRetained is like PublishExactlyOnce, but the broker must
// store the message, so that it can be delivered to future subscribers whose
// subscriptions match the topic name. When a new subscription is established,
// the last retained message, if any, on each matching topic name must be sent
// to the subscriber.
func (c *Client) PublishExactlyOnceRetained(message []byte, topic string) (exchange <-chan error, err error) {
	buf := bufPool.Get().(*[bufSize]byte)
	defer bufPool.Put(buf)
	packet, err := publishPacket(buf, message, topic, exactlyOnceIDSpace, typePUBLISH<<4|exactlyOnceLevel<<1|retainFlag)
	if err != nil {
		return nil, err
	}
	return c.submitPersisted(packet, c.exactlyOnce)
}

var errNoConn = errors.New("mqtt: not connected")

// submitPersisted assigns the next sequence number to packet, persists it, and
// submits it when no earlier submissions are pending. The exchange channel
// always receives the final outcome, even when the network write fails here
// (in which case the packet remains enqueued for resubmission).
func (c *Client) submitPersisted(packet net.Buffers, out outbound) (exchange <-chan error, err error) {
	// lock sequence
	seq, ok := <-out.seqSem
	if !ok {
		return nil, ErrClosed
	}
	defer func() {
		out.seqSem <- seq // unlock with updated
	}()

	// number of accepted-but-not-submitted packets ahead of this one
	nBefore := seq.acceptN - seq.submitN

	// persist
	done, err := c.applySeqNoAndEnqueue(packet, seq.acceptN, out)
	if err != nil {
		return nil, err
	}
	seq.acceptN++

	// submit
	if nBefore != 0 {
		// buffered channel won't block
		done <- fmt.Errorf("mqtt: %d earlier submissions pending; PUBLISH enqueued",
			nBefore)
	} else {
		err = c.writeBuffers(c.Offline(), packet)
		if err != nil {
			// canceled implies offline
			if errors.Is(err, ErrCanceled) {
				err = errNoConn
			}
			// buffered channel won't block
			done <- fmt.Errorf("%w; PUBLISH enqueued", err)
		} else {
			// seq.submitN = seq.acceptN
			seq.submitN++
		}
	}

	return done, nil
}

// applySeqNoAndEnqueue overwrites the packet-ID placeholder in packet with
// seqNo (masked into the respective ID space), saves the packet, and claims a
// slot on the outbound queue. Callers must hold the out.seqSem lock.
func (c *Client) applySeqNoAndEnqueue(packet net.Buffers, seqNo uint, out outbound) (done chan error, err error) {
	if cap(out.queue) == len(out.queue) {
		return nil, fmt.Errorf("%w; PUBLISH unavailable", ErrMax)
	}

	// apply sequence number to packet; the packet ID occupies
	// the last two bytes of the first buffer (see publishPacket)
	buf := packet[0]
	i := len(buf) - 2
	packetID := uint(binary.BigEndian.Uint16(buf[i:]))
	packetID |= seqNo & publishIDMask
	binary.BigEndian.PutUint16(buf[i:], uint16(packetID))

	err = c.persistence.Save(packetID, packet)
	if err != nil {
		return nil, fmt.Errorf("%w; PUBLISH dropped", err)
	}

	done = make(chan error, 2) // receives at most 1 write error + ErrClosed
	out.queue <- done          // won't block due ErrMax check
	return done, nil
}

// publishPacket composes a PUBLISH packet in buf. A non-zero packetID reserves
// two bytes at the end of the first buffer as an identifier placeholder, which
// applySeqNoAndEnqueue fills in later. The message is passed as a separate
// buffer to avoid a copy.
func publishPacket(buf *[bufSize]byte, message []byte, topic string, packetID uint, head byte) (net.Buffers, error) {
	if err := topicCheck(topic); err != nil {
		return nil, fmt.Errorf("%w; PUBLISH request denied on topic", err)
	}
	size := 2 + len(topic) + len(message)
	if packetID != 0 {
		size += 2
	}
	if size < 0 || size > packetMax {
		return nil, fmt.Errorf("%w; PUBLISH request denied", errPacketMax)
	}

	packet := append(buf[:0], head)
	// encode the remaining length as a variable-length quantity
	l := uint(size)
	for ; l > 0x7f; l >>= 7 {
		packet = append(packet, byte(l|0x80))
	}
	packet = append(packet, byte(l))
	packet = append(packet, byte(len(topic)>>8), byte(len(topic)))
	packet = append(packet, topic...)
	if packetID != 0 {
		packet = append(packet, byte(packetID>>8), byte(packetID))
	}
	return net.Buffers{packet, message}, nil
}

// onPUBACK applies the confirm of a PublishAtLeastOnce.
// A non-nil return resets the connection.
func (c *Client) onPUBACK() error {
	// parse packet
	if len(c.peek) != 2 {
		return fmt.Errorf("%w: PUBACK with %d byte remaining length", errProtoReset, len(c.peek))
	}
	packetID := uint(binary.BigEndian.Uint16(c.peek))

	// match identifier; confirms must arrive in submission order
	expect := c.orderedTxs.Acked&publishIDMask | atLeastOnceIDSpace
	switch {
	case packetID == 0:
		return errPacketIDZero
	case packetID&^publishIDMask != atLeastOnceIDSpace:
		return errPacketIDSpace
	case expect != packetID:
		return fmt.Errorf("%w: PUBACK %#04x while %#04x next in line", errProtoReset, packetID, expect)
	case len(c.atLeastOnce.queue) == 0:
		return fmt.Errorf("%w: PUBACK precedes PUBLISH", errProtoReset)
	}

	// ceil transaction
	err := c.persistence.Delete(packetID)
	if err != nil {
		return err // causes resubmission of PUBLISH
	}
	c.orderedTxs.Acked++
	close(<-c.atLeastOnce.queue)
	return nil
}

// onPUBREC applies the first confirm of a PublishExactlyOnce.
// A non-nil return resets the connection.
func (c *Client) onPUBREC() error {
	// parse packet
	if len(c.peek) != 2 {
		return fmt.Errorf("%w: PUBREC with %d byte remaining length", errProtoReset, len(c.peek))
	}
	packetID := uint(binary.BigEndian.Uint16(c.peek))

	// match identifier; confirms must arrive in submission order
	expect := c.orderedTxs.Received&publishIDMask | exactlyOnceIDSpace
	switch {
	case packetID == 0:
		return errPacketIDZero
	case packetID&^publishIDMask != exactlyOnceIDSpace:
		return errPacketIDSpace
	case packetID != expect:
		return fmt.Errorf("%w: PUBREC %#04x while %#04x next in line", errProtoReset, packetID, expect)
	case int(c.Received-c.Completed) >= len(c.exactlyOnce.queue):
		return fmt.Errorf("%w: PUBREC precedes PUBLISH", errProtoReset)
	}

	// Use pendingAck as a buffer here. The PUBREL replaces the PUBLISH
	// record in persistence before the counter update commits the receipt.
	c.pendingAck = append(c.pendingAck[:0], typePUBREL<<4|atLeastOnceLevel<<1, 2, byte(packetID>>8), byte(packetID))
	err := c.persistence.Save(packetID, net.Buffers{c.pendingAck})
	if err != nil {
		c.pendingAck = c.pendingAck[:0]
		return err // causes resubmission of PUBLISH (from persistence)
	}
	c.orderedTxs.Received++

	err = c.write(nil, c.pendingAck)
	if err != nil {
		return err // keeps pendingAck to retry
	}
	c.pendingAck = c.pendingAck[:0]
	return nil
}

// onPUBCOMP applies the second (and final) confirm of a PublishExactlyOnce.
// A non-nil return resets the connection.
func (c *Client) onPUBCOMP() error {
	// parse packet
	if len(c.peek) != 2 {
		return fmt.Errorf("%w: PUBCOMP with %d byte remaining length", errProtoReset, len(c.peek))
	}
	packetID := uint(binary.BigEndian.Uint16(c.peek))

	// match identifier; confirms must arrive in submission order
	expect := c.orderedTxs.Completed&publishIDMask | exactlyOnceIDSpace
	switch {
	case packetID == 0:
		return errPacketIDZero
	case packetID&^publishIDMask != exactlyOnceIDSpace:
		return errPacketIDSpace
	case packetID != expect:
		return fmt.Errorf("%w: PUBCOMP %#04x while %#04x next in line", errProtoReset, packetID, expect)
	case c.orderedTxs.Completed >= c.orderedTxs.Received || len(c.exactlyOnce.queue) == 0:
		return fmt.Errorf("%w: PUBCOMP precedes PUBREL", errProtoReset)
	}

	// ceil transaction
	err := c.persistence.Delete(packetID)
	if err != nil {
		return err // causes resubmission of PUBREL (from Persistence)
	}
	c.orderedTxs.Completed++
	close(<-c.exactlyOnce.queue)
	return nil
}

// InitSession configures the Persistence for first use. Brokers use clientID to
// uniquely identify the session. The session may be continued with AdoptSession
// on another Client.
//
// An error implies either a broken setup or Persistence failure. Connection
// issues, if any, are reported by ReadSlices.
func InitSession(clientID string, p Persistence, c *Config) (*Client, error) {
	return initSession(clientID, &ruggedPersistence{Persistence: p}, c)
}

// VolatileSession operates solely in-memory. This setup is recommended for
// delivery with the “at most once” guarantee [Publish], and for reception
// without the “exactly once” guarantee [SubscribeLimitAtLeastOnce], and for
// testing.
//
// Brokers use clientID to uniquely identify the session. Volatile sessions may
// be continued by using the same clientID again. Use CleanSession to prevent
// reuse of an existing state.
//
// An error implies a broken setup. Connection issues, if any, are reported by
// ReadSlices.
func VolatileSession(clientID string, c *Config) (*Client, error) {
	return initSession(clientID, newVolatile(), c)
}

// initSession validates the arguments, verifies that p carries no previous
// state, and installs the client identifier record.
func initSession(clientID string, p Persistence, c *Config) (*Client, error) {
	if err := stringCheck(clientID); err != nil {
		return nil, fmt.Errorf("%w; illegal client identifier", err)
	}
	if err := c.valid(); err != nil {
		return nil, err
	}

	// empty check
	keys, err := p.List()
	if err != nil {
		return nil, err
	}
	if len(keys) != 0 {
		return nil, errors.New("mqtt: init on non-empty persistence")
	}

	// install
	err = p.Save(clientIDKey, net.Buffers{[]byte(clientID)})
	if err != nil {
		return nil, err
	}

	return newClient(p, c), nil
}

// AdoptSession continues with a Persistence which had an InitSession already.
//
// A fatal implies either a broken setup or persistence failure. Connection
// issues, if any, are reported by ReadSlices. The Client recovers from corrupt
// states (in Persistence) automatically with warn entries.
func AdoptSession(p Persistence, c *Config) (client *Client, warn []error, fatal error) {
	if err := c.valid(); err != nil {
		return nil, warn, err
	}

	keys, err := p.List()
	if err != nil {
		return nil, warn, err
	}

	// storage includes a sequence number
	storeOrderPerKey := make(map[uint]uint64, len(keys))

	// “When a Client reconnects with CleanSession set to 0, both the Client
	// and Server MUST re-send any unacknowledged PUBLISH Packets (where QoS
	// > 0) and PUBREL Packets using their original Packet Identifiers.”
	// — MQTT Version 3.1.1, conformance statement MQTT-4.4.0-1
	var publishAtLeastOnceKeys, publishExactlyOnceKeys, publishReleaseKeys []uint
	for _, key := range keys {
		if key == clientIDKey || key&remoteIDKeyFlag != 0 {
			continue
		}
		value, err := p.Load(key)
		if err != nil {
			return nil, warn, err
		}

		packet, storageSeqNo, err := decodeValue(value)
		if err != nil {
			// corrupt record: drop it and continue with a warning
			delErr := p.Delete(key)
			if delErr != nil {
				warn = append(warn, fmt.Errorf("%w; record %#x not deleted: %w", err, key, delErr))
			} else {
				warn = append(warn, fmt.Errorf("%w; record %#x deleted", err, key))
			}

			continue
		}

		storeOrderPerKey[key] = storageSeqNo

		// categorise by packet type and identifier space
		switch packet[0] >> 4 {
		case typePUBLISH:
			switch key &^ publishIDMask {
			case atLeastOnceIDSpace:
				publishAtLeastOnceKeys = append(publishAtLeastOnceKeys, key)
			case exactlyOnceIDSpace:
				publishExactlyOnceKeys = append(publishExactlyOnceKeys, key)
			}
		case typePUBREL:
			publishReleaseKeys = append(publishReleaseKeys, key)
		}
	}

	// sort by persistence sequence number
	sort.Slice(publishAtLeastOnceKeys, func(i, j int) (less bool) {
		return storeOrderPerKey[publishAtLeastOnceKeys[i]] < storeOrderPerKey[publishAtLeastOnceKeys[j]]
	})
	sort.Slice(publishExactlyOnceKeys, func(i, j int) (less bool) {
		return storeOrderPerKey[publishExactlyOnceKeys[i]] < storeOrderPerKey[publishExactlyOnceKeys[j]]
	})
	sort.Slice(publishReleaseKeys, func(i, j int) (less bool) {
		return storeOrderPerKey[publishReleaseKeys[i]] < storeOrderPerKey[publishReleaseKeys[j]]
	})
	// ensure continuous sequence
	publishAtLeastOnceKeys = cleanSequence(publishAtLeastOnceKeys, "PUBLISH at-least-once", &warn)
	publishExactlyOnceKeys = cleanSequence(publishExactlyOnceKeys, "PUBLISH exactly-once", &warn)
	publishReleaseKeys = cleanSequence(publishReleaseKeys, "PUBREL", &warn)
	if len(publishExactlyOnceKeys) != 0 && len(publishReleaseKeys) != 0 {
		// The last PUBREL must directly precede the first PUBLISH
		// (modulo address-space wrap-around).
		n := publishExactlyOnceKeys[0] & publishIDMask
		p := publishReleaseKeys[len(publishReleaseKeys)-1] & publishIDMask
		if n-p != 1 && !(n == 0 && p == publishIDMask) {
			// NOTE(review): the warning claims the PUBREL records were
			// dropped, yet publishReleaseKeys is not actually truncated
			// here — confirm whether that is intentional.
			warn = append(warn, fmt.Errorf("mqtt: PUBREL %#x–%#x dropped ☠️ due gap until PUBLISH %#x",
				publishReleaseKeys[0], publishReleaseKeys[len(publishReleaseKeys)-1], publishExactlyOnceKeys[0]))
		}
	}

	// instantiate client
	if n := len(publishAtLeastOnceKeys); n > c.AtLeastOnceMax {
		return nil, warn, fmt.Errorf("mqtt: %d AtLeastOnceMax is less than the %d pending in session", c.AtLeastOnceMax, n)
	}
	if n := len(publishExactlyOnceKeys) + len(publishReleaseKeys); n > c.ExactlyOnceMax {
		return nil, warn, fmt.Errorf("mqtt: %d ExactlyOnceMax is less than the %d pending in session", c.ExactlyOnceMax, n)
	}
	client = newClient(&ruggedPersistence{Persistence: p}, c)

	// check for outbound publish pending confirmation
	if keys = publishAtLeastOnceKeys; len(keys) != 0 {
		// install sequence counts; txs.Acked < seq.acceptN
		// and: seq.acceptN − txs.Acked ≤ publishIDMask

		client.orderedTxs.Acked = keys[0] & publishIDMask
		last := keys[len(keys)-1] & publishIDMask
		if last < client.orderedTxs.Acked {
			// range overflows address space
			last += publishIDMask + 1
		}
		seq := <-client.atLeastOnce.seqSem
		seq.acceptN = last + 1
		// BUG(pascaldekloe):
		// AdoptSession assumes that all publish-at-least-once packets
		// were submitted before already. Persisting the actual state
		// after each network submission seems like an overkill for
		// the DUP flag to be more accurate.
		seq.submitN = seq.acceptN
		client.atLeastOnce.seqSem <- seq
	}

	// check for outbound publish pending confirmation
	if publishKeys, releaseKeys := publishExactlyOnceKeys, publishReleaseKeys; len(publishKeys) != 0 || len(releaseKeys) != 0 {
		// install sequence counts; txs.Completed < seq.acceptN
		// and: txs.Completed ≤ txs.Received ≤ seq.acceptN
		// and: seq.acceptN − txs.Completed ≤ publishIDMask

		txs := &client.orderedTxs
		if len(releaseKeys) == 0 { // implies len(publishKeys) != 0
			txs.Completed = publishKeys[0] & publishIDMask
			txs.Received = txs.Completed
		} else {
			txs.Completed = releaseKeys[0] & publishIDMask
			txs.Received = releaseKeys[len(releaseKeys)-1]&publishIDMask + 1
			if txs.Received < txs.Completed {
				// range overflows address space
				txs.Received += publishIDMask + 1
			}
		}

		var last uint
		if len(publishKeys) != 0 {
			last = publishKeys[len(publishKeys)-1] & publishIDMask
		} else {
			last = releaseKeys[len(releaseKeys)-1] & publishIDMask
		}
		if last < txs.Received {
			// range overflows address space
			last += publishIDMask + 1
		}
		seq := <-client.exactlyOnce.seqSem
		seq.acceptN = last + 1
		// BUG(pascaldekloe):
		// AdoptSession assumes that all publish-exactly-once packets
		// were submitted before already. Persisting the actual state
		// after each network submission seems like an overkill for
		// the DUP flag to be more accurate.
		seq.submitN = seq.acceptN
		client.exactlyOnce.seqSem <- seq
	}

	// install callback placeholders; won't block due Max check above
	for range publishAtLeastOnceKeys {
		client.atLeastOnce.queue <- make(chan<- error, 1)
	}
	for range publishExactlyOnceKeys {
		client.exactlyOnce.queue <- make(chan<- error, 1)
	}
	for range publishReleaseKeys {
		client.exactlyOnce.queue <- make(chan<- error, 1)
	}

	return client, warn, nil
}

// cleanSequence drops keys from the head (with a warn entry) until the
// remainder forms one continuous run of packet identifiers, allowing for
// wrap-around at the end of the address space. Name labels the warnings.
func cleanSequence(keys []uint, name string, warn *[]error) []uint {
	for i := 1; i < len(keys); i++ {
		n := keys[i] & publishIDMask
		p := keys[i-1] & publishIDMask
		if n-p == 1 || n == 0 && p == publishIDMask {
			continue
		}

		*warn = append(*warn, fmt.Errorf("mqtt: %s %#x–%#x dropped ☠️ due gap until %#x", name, keys[0], keys[i-1], keys[i]))

		// restart the scan after the gap
		keys = keys[i:]
		i = 0
	}
	return keys
}
--------------------------------------------------------------------------------
/request_test.go:
--------------------------------------------------------------------------------
package mqtt_test

import (
	"context"
	"encoding/hex"
	"errors"
	"io"
	"net"
	"os"
	"path/filepath"
	"strings"
	"testing"
	"time"

	"github.com/pascaldekloe/mqtt"
	"github.com/pascaldekloe/mqtt/mqtttest"
)

// TestBackoff_ErrMax verifies that concurrent Backoff calls for a retriable
// error all share one wait channel, and that the channel eventually expires.
func TestBackoff_ErrMax(t *testing.T) {
	client, _, testTimeout := newTestClientOnline(t)

	const parallelism = 5
	backoffs := make(chan (<-chan struct{}), parallelism)

	launch := make(chan struct{})
	for i := 0; i < parallelism; i++ {
		go func() {
			<-launch // race start
			backoffs <- client.Backoff(mqtt.ErrMax)
		}()
	}
	close(launch)

	first := <-backoffs
	if first == nil {
		t.Fatal("got no backoff for retriable error")
	}
	select {
	case <-first:
		t.Fatal("backoff expired on arrival")
	default:
		break // OK
	}

	// match first against all others
	for i := 1; i < parallelism; i++ {
		switch <-backoffs {
		case nil:
			t.Fatal("got no backoff for retriable error")
		case first:
			break // OK
		default:
			t.Errorf("got another wait channel; want all the same")
		}
	}

	t.Log("await backoff channel")
	select {
	case <-first:
		break // good
	case <-testTimeout:
		t.Error("test timeout before backoff expiry")
	}
}

// TestPing verifies a plain PINGREQ/PINGRESP exchange.
func TestPing(t *testing.T) {
	client, conn, testTimeout := newTestClientOnline(t)
	pingDone := testRoutine(t, func() {
		err := client.Ping(testTimeout)
		if err != nil {
			t.Error("ping got error:", err)
		}
	})
	wantPacketHex(t, conn, "c000") // PINGREQ
	sendPacketHex(t, conn, "d000") // PINGRESP
	<-pingDone
}

// Ping should await the first connect attempt.
func TestPing_beforeConnect(t *testing.T) {
	client, conn, testTimeout := newTestClient(t)
	pingDone := testRoutine(t, func() {
		err := client.Ping(testTimeout)
		if err != nil {
			t.Error("ping got error:", err)
		}
	})
	// CONNECT slowdown causes wait scenario
	time.Sleep(100 * time.Millisecond)
	wantConnectExchange(t, conn)
	wantPacketHex(t, conn, "c000") // PINGREQ
	sendPacketHex(t, conn, "d000") // PINGRESP
	<-pingDone
}

// Ping should signal ErrDown when the first connect attempt fails.
func TestPing_failedConnect(t *testing.T) {
	client, conns, testTimeout := newTestClientRedial(t,
		mqtttest.Transfer{Err: mqtt.ErrUnavailable})

	pingDone := testRoutine(t, func() {
		err := client.Ping(testTimeout)
		if !errors.Is(err, mqtt.ErrDown) {
			t.Errorf("ping got error %v, want ErrDown", err)
		}

		// second ping when Online again
		wantOnline(t, client, testTimeout)
		err = client.Ping(testTimeout)
		if err != nil {
			t.Error("ping got error after Online signal:", err)
		}
	})

	// fail first connect
	wantPacketHex(t, conns[0], "100c00044d515454040000000000") // CONNECT
	sendPacketHex(t, conns[0], "20020003")                     // CONNACK
	time.Sleep(200 * time.Millisecond)                         // stay ErrDown
	// accept second connect + ping exchange
	wantConnectExchange(t, conns[1])
	wantPacketHex(t, conns[1], "c000") // PINGREQ
	sendPacketHex(t, conns[1], "d000") // PINGRESP
	<-pingDone
}

// Ping should get a timeout error on stale request submission.
127 | func TestPing_reqTimeout(t *testing.T) { 128 | client, conns, testTimeout := newTestClientOnlineRedial(t) 129 | 130 | pingDone := testRoutine(t, func() { 131 | err := client.Ping(testTimeout) 132 | var e net.Error 133 | if !errors.As(err, &e) || !e.Timeout() { 134 | t.Errorf("got error %v, want a Timeout net.Error", err) 135 | return 136 | } 137 | switch client.Backoff(err) { 138 | case nil: 139 | t.Error("no backoff for ping error") 140 | case client.Online(): 141 | break // OK 142 | default: 143 | t.Error("backoff for ping error does not equal Online channel") 144 | } 145 | }) 146 | 147 | // read first byte 148 | var buf [1]byte 149 | _, err := io.ReadFull(conns[0], buf[:]) 150 | if err != nil { 151 | t.Fatal("broker read error:", err) 152 | } 153 | if buf[0] != 0xC0 { 154 | t.Errorf("want PINGREQ head 0xC0, got %#x", buf[0]) 155 | } 156 | t.Log("first connection abandoned after partial read") 157 | // check reconnect 158 | wantPacketHex(t, conns[1], "100c00044d515454040000000000") // CONNECT 159 | t.Log("second connection abandoned after connect request") 160 | <-pingDone 161 | } 162 | 163 | // Subscribe should await the first connect attempt. 164 | func TestSubscribe_beforeConnect(t *testing.T) { 165 | client, conn, testTimeout := newTestClient(t) 166 | subscribeDone := testRoutine(t, func() { 167 | err := client.Subscribe(testTimeout, "u/noi", "u/shin") 168 | if err != nil { 169 | t.Error("subscribe got error:", err) 170 | } 171 | }) 172 | // CONNECT slowdown causes wait scenario 173 | time.Sleep(100 * time.Millisecond) 174 | wantConnectExchange(t, conn) 175 | wantPacketHex(t, conn, hex.EncodeToString([]byte{ 176 | 0x82, 19, 177 | 0x60, 0x00, // packet identifier 178 | 0, 5, 'u', '/', 'n', 'o', 'i', 179 | 2, // max QOS 180 | 0, 6, 'u', '/', 's', 'h', 'i', 'n', 181 | 2, // max QOS 182 | })) 183 | sendPacketHex(t, conn, "900460000102") // SUBACK 184 | <-subscribeDone 185 | } 186 | 187 | // Subscribe should get a timeout error on stale request submission. 
188 | func TestSubscribe_reqTimeout(t *testing.T) { 189 | client, conns, testTimeout := newTestClientOnlineRedial(t) 190 | 191 | subscribeDone := testRoutine(t, func() { 192 | err := client.Subscribe(testTimeout, "x") 193 | var e net.Error 194 | if !errors.As(err, &e) || !e.Timeout() { 195 | t.Errorf("got error %v, want a Timeout net.Error", err) 196 | return 197 | } 198 | switch client.Backoff(err) { 199 | case nil: 200 | t.Error("no backoff for subscribe error") 201 | case client.Online(): 202 | break // OK 203 | default: 204 | t.Error("backoff for subscribe error does not equal Online channel") 205 | } 206 | }) 207 | 208 | // read first byte 209 | var buf [1]byte 210 | _, err := io.ReadFull(conns[0], buf[:]) 211 | if err != nil { 212 | t.Fatal("broker read error:", err) 213 | } 214 | if buf[0] != 0x82 { 215 | t.Errorf("want SUBSCRIBE head 0x82, got %#x", buf[0]) 216 | } 217 | t.Log("first connection abandoned after partial read") 218 | // check reconnect 219 | wantPacketHex(t, conns[1], "100c00044d515454040000000000") // CONNECT 220 | t.Log("second connection abandoned after connect request") 221 | <-subscribeDone 222 | } 223 | 224 | // Unsubscribe should await the first connect attempt. 
225 | func TestUnsubscribe_beforeConnect(t *testing.T) { 226 | client, conn, testTimeout := newTestClient(t) 227 | unsubscribeDone := testRoutine(t, func() { 228 | err := client.Unsubscribe(testTimeout, "u/noi", "u/shin") 229 | if err != nil { 230 | t.Errorf("got error %q [%T]", err, err) 231 | } 232 | }) 233 | // CONNECT slowdown causes wait scenario 234 | time.Sleep(100 * time.Millisecond) 235 | wantConnectExchange(t, conn) 236 | wantPacketHex(t, conn, hex.EncodeToString([]byte{ 237 | 0xa2, 17, 238 | 0x40, 0x00, // packet identifier 239 | 0, 5, 'u', '/', 'n', 'o', 'i', 240 | 0, 6, 'u', '/', 's', 'h', 'i', 'n', 241 | })) 242 | sendPacketHex(t, conn, "b0024000") // UNSUBACK 243 | <-unsubscribeDone 244 | } 245 | 246 | // Unsubscribe should get a timeout error on stale request submission. 247 | func TestUnsubscribe_reqTimeout(t *testing.T) { 248 | client, conns, testTimeout := newTestClientOnlineRedial(t) 249 | 250 | unsubscribeDone := testRoutine(t, func() { 251 | err := client.Unsubscribe(testTimeout, "x") 252 | var e net.Error 253 | if !errors.As(err, &e) || !e.Timeout() { 254 | t.Errorf("unsubscribe got error %v, want a Timeout net.Error", err) 255 | } 256 | switch client.Backoff(err) { 257 | case nil: 258 | t.Error("no backoff for unsubscribe error") 259 | case client.Online(): 260 | break // OK 261 | default: 262 | t.Error("backoff for unsubscribe error does not equal Online channel") 263 | } 264 | }) 265 | 266 | // read first byte 267 | var buf [1]byte 268 | _, err := io.ReadFull(conns[0], buf[:]) 269 | if err != nil { 270 | t.Fatal("broker read error:", err) 271 | } 272 | if buf[0] != 0xa2 { 273 | t.Errorf("want UNSUBSCRIBE head 0xa2, got %#x", buf[0]) 274 | } 275 | t.Log("first connection abandoned after partial read") 276 | // check reconnect 277 | wantPacketHex(t, conns[1], "100c00044d515454040000000000") // CONNECT 278 | t.Log("second connection abandoned after connect request") 279 | <-unsubscribeDone 280 | } 281 | 282 | // Publish should await the first 
connect attempt. 283 | func TestPublish_beforeConnect(t *testing.T) { 284 | client, conn, testTimeout := newTestClient(t) 285 | publishDone := testRoutine(t, func() { 286 | err := client.Publish(testTimeout, []byte("hello"), "greet") 287 | if err != nil { 288 | t.Error("publish got error:", err) 289 | } 290 | }) 291 | // CONNECT slowdown causes wait scenario 292 | time.Sleep(100 * time.Millisecond) 293 | wantConnectExchange(t, conn) 294 | wantPacketHex(t, conn, hex.EncodeToString([]byte{ 295 | 0x30, 12, 296 | 0, 5, 'g', 'r', 'e', 'e', 't', 297 | 'h', 'e', 'l', 'l', 'o'})) 298 | <-publishDone 299 | } 300 | 301 | // Publish should get a timeout error on stale request submission. 302 | func TestPublish_reqTimeout(t *testing.T) { 303 | client, conns, testTimeout := newTestClientOnlineRedial(t) 304 | 305 | publishDone := testRoutine(t, func() { 306 | err := client.Publish(testTimeout, []byte{'x'}, "y") 307 | var e net.Error 308 | if !errors.As(err, &e) || !e.Timeout() { 309 | t.Errorf("got error %q [%T], want a Timeout net.Error", err, err) 310 | return 311 | } 312 | switch client.Backoff(err) { 313 | case nil: 314 | t.Error("no backoff for publish error") 315 | case client.Online(): 316 | break // OK 317 | default: 318 | t.Error("backoff for publish error does not equal Online channel") 319 | } 320 | }) 321 | 322 | // read first byte 323 | var buf [1]byte 324 | _, err := io.ReadFull(conns[0], buf[:]) 325 | if err != nil { 326 | t.Fatal("broker read error:", err) 327 | } 328 | if buf[0] != 0x30 { 329 | t.Errorf("want PUBLISH head 0x30, got %#x", buf[0]) 330 | } 331 | t.Log("first connection abandoned after partial read") 332 | // check reconnect 333 | wantPacketHex(t, conns[1], "100c00044d515454040000000000") 334 | t.Log("second connection abandoned after connect request") 335 | <-publishDone 336 | } 337 | 338 | func TestPublishAtLeastOnce(t *testing.T) { 339 | client, conn, testTimeout := newTestClientOnline(t) 340 | publishDone := testRoutine(t, func() { 341 | 
exchange, err := client.PublishAtLeastOnce([]byte("hello"), "greet") 342 | if err != nil { 343 | t.Error("publish got error:", err) 344 | } 345 | verifyExchange(t, testTimeout, exchange, nil) 346 | }) 347 | wantPacketHex(t, conn, hex.EncodeToString([]byte{ 348 | 0x32, 14, 349 | 0, 5, 'g', 'r', 'e', 'e', 't', 350 | 0x80, 0x00, // packet identifier 351 | 'h', 'e', 'l', 'l', 'o'})) 352 | sendPacketHex(t, conn, "40028000") // PUBACK 353 | <-publishDone 354 | } 355 | 356 | // Publish should enqueue and continue once online. 357 | func TestPublishAtLeastOnce_beforeConnect(t *testing.T) { 358 | client, conn, testTimeout := newTestClient(t) 359 | publishDone := testRoutine(t, func() { 360 | exchange, err := client.PublishAtLeastOnce([]byte("hello"), "greet") 361 | if err != nil { 362 | t.Error("publish got error:", err) 363 | } 364 | verifyExchange(t, testTimeout, exchange, 365 | "mqtt: not connected; PUBLISH enqueued", nil) 366 | }) 367 | // CONNECT slowdown causes wait scenario 368 | time.Sleep(100 * time.Millisecond) 369 | wantConnectExchange(t, conn) 370 | wantPacketHex(t, conn, hex.EncodeToString([]byte{ 371 | 0x32, 14, 372 | 0, 5, 'g', 'r', 'e', 'e', 't', 373 | 0x80, 0x00, // packet identifier 374 | 'h', 'e', 'l', 'l', 'o'})) 375 | sendPacketHex(t, conn, "40028000") // PUBACK 376 | <-publishDone 377 | } 378 | 379 | // Publish should get a timeout error on stale request submission and resubmit 380 | // after reconnect. 
381 | func TestPublishAtLeastOnce_reqTimeout(t *testing.T) { 382 | client, conns, testTimeout := newTestClientOnlineRedial(t) 383 | publishDone := testRoutine(t, func() { 384 | exchange, err := client.PublishAtLeastOnce([]byte{'x'}, "y") 385 | if err != nil { 386 | t.Error("publish got error:", err) 387 | } 388 | verifyExchangeTimeout(t, testTimeout, exchange) 389 | }) 390 | 391 | // read first byte 392 | var buf [1]byte 393 | _, err := io.ReadFull(conns[0], buf[:]) 394 | if err != nil { 395 | t.Fatal("broker read error:", err) 396 | } 397 | if buf[0] != 0x32 { 398 | t.Errorf("want PUBLISH head 0x32, got %#x", buf[0]) 399 | } 400 | t.Log("first connection abandoned after partial read") 401 | // check reconnect 402 | wantConnectExchange(t, conns[1]) 403 | wantPacketHex(t, conns[1], hex.EncodeToString([]byte{ 404 | 0x32, 6, 405 | 00, 01, 'y', 406 | 0x80, 0x00, // packet identifier 407 | 'x'})) 408 | sendPacketHex(t, conns[1], "40028000") // PUBACK 409 | <-publishDone 410 | } 411 | 412 | // Publish should get ErrDown notification and resubmit after reconnect. 
413 | func TestPublishAtLeastOnce_whileDown(t *testing.T) { 414 | client, conns, testTimeout := newTestClientRedial(t, 415 | mqtttest.Transfer{Err: mqtt.ErrUnavailable}) 416 | 417 | publishDone := testRoutine(t, func() { 418 | // await connection refusal 419 | time.Sleep(100 * time.Millisecond) 420 | // publish should enqueue with ErrDown notification 421 | exchange1, err := client.PublishAtLeastOnce([]byte("x"), "y") 422 | if err != nil { 423 | t.Error("first publish got error:", err) 424 | return 425 | } 426 | verifyExchange(t, testTimeout, exchange1, mqtt.ErrDown, nil) 427 | 428 | // check recovery 429 | exchange2, err := client.PublishAtLeastOnce([]byte("a"), "b") 430 | if err != nil { 431 | t.Fatal("second publish got error:", err) 432 | } 433 | verifyExchange(t, testTimeout, exchange2, nil) 434 | }) 435 | 436 | // fail first connect 437 | wantPacketHex(t, conns[0], "100c00044d515454040000000000") // CONNECT 438 | sendPacketHex(t, conns[0], "20020003") // CONNACK 439 | time.Sleep(200 * time.Millisecond) // give ErrDown some time 440 | // accept second connect + two publish exchanges 441 | wantConnectExchange(t, conns[1]) 442 | wantPacketHex(t, conns[1], "3206000179800078") // PUBLISH #1 443 | sendPacketHex(t, conns[1], "40028000") // PUBACK #1 444 | wantPacketHex(t, conns[1], "3206000162800161") // PUBLISH #2 445 | sendPacketHex(t, conns[1], "40028001") // PUBACK #2 446 | <-publishDone 447 | } 448 | 449 | // TestPublishAtLeastOnce_restart sends three messages with QOS 1. The broker 450 | // simulation will do all of the following: 451 | // 452 | // A. Receive message #1 453 | // B. Receive message #2 454 | // C. Acknowledge mesage #1 455 | // D. Partially receive message #3 456 | // 457 | // Then the session is continued with a new client. It must automatically send 458 | // message #2 and #3 again. 
func TestPublishAtLeastOnce_restart(t *testing.T) {
	t.Parallel()
	dir := t.TempDir() // persistence location

	// start timers after Parallel branch
	ctx, cancel := context.WithTimeout(context.Background(), 2*testTimeout)
	defer cancel()
	deadline, _ := ctx.Deadline()

	// request packets
	const (
		publish1Hex     = "3206000178800031" // '1' (0x31) @ 'x' (0x78)
		publish2Hex     = "3206000178800132" // '2' (0x32) @ 'x' (0x78)
		publish3Hex     = "3206000178800233" // '3' (0x33) @ 'x' (0x78)
		publish2DupeHex = "3a06000178800132" // with duplicate [DUP] flag
		publish3DupeHex = "3a06000178800233" // with duplicate [DUP] flag
	)

	clientConn, brokerConn := net.Pipe()
	// expire I/O mock before tests timeout
	brokerConn.SetDeadline(deadline.Add(-200 * time.Millisecond))
	client, err := mqtt.InitSession("test-client", mqtt.FileSystem(dir), &mqtt.Config{
		PauseTimeout:   time.Second / 4,
		AtLeastOnceMax: 3,
		Dialer:         newDialerMock(t, clientConn),
	})
	if err != nil {
		t.Fatal("init session got error:", err)
	}
	verifyClient(t, client)

	publishDone := testRoutine(t, func() {
		select {
		case <-client.Online():
			break // OK
		case <-ctx.Done():
			t.Error("test timeout before Online")
			return
		}
		exchange1, err := client.PublishAtLeastOnce([]byte{'1'}, "x")
		if err != nil {
			t.Errorf("publish #1 got error %q [%T]", err, err)
		}
		exchange2, err := client.PublishAtLeastOnce([]byte{'2'}, "x")
		if err != nil {
			t.Errorf("publish #2 got error %q [%T]", err, err)
		}
		exchange3, err := client.PublishAtLeastOnce([]byte{'3'}, "x")
		if err != nil {
			t.Errorf("publish #3 got error %q [%T]", err, err)
		}
		// #1 gets its PUBACK; #2 and #3 are cut off by Close below
		verifyExchange(t, ctx.Done(), exchange1, nil)
		verifyExchange(t, ctx.Done(), exchange2, mqtt.ErrClosed)
		verifyExchange(t, ctx.Done(), exchange3, mqtt.ErrSubmit, mqtt.ErrClosed)
	})

	wantPacketHex(t, brokerConn, "101700044d51545404000000000b746573742d636c69656e74")
	sendPacketHex(t, brokerConn, "20020000") // CONNACK accept
	wantPacketHex(t, brokerConn, publish1Hex)
	wantPacketHex(t, brokerConn, publish2Hex)
	sendPacketHex(t, brokerConn, "40028000") // PUBACK #1
	// read first byte of publish3
	var buf [1]byte
	switch _, err := io.ReadFull(brokerConn, buf[:]); {
	case err != nil:
		t.Fatal("broker read error:", err)
	case buf[0] != 0x32:
		t.Errorf("want PUBLISH head 0x32, got %#x", buf[0])
	}
	err = client.Close()
	if err != nil {
		t.Error("close got error:", err)
	}
	<-publishDone

	// verify persistence; seals compatibility
	publish2File := filepath.Join(dir, "08001") // named after its packet ID
	publish3File := filepath.Join(dir, "08002")
	if bytes, err := os.ReadFile(publish2File); err != nil {
		t.Error("publish #2 file:", err)
	} else {
		gotHex := hex.EncodeToString(bytes)
		// packet + sequence number + checksum:
		const wantHex = publish2Hex + "0300000000000000" + "c0dcafa6"
		if gotHex != wantHex {
			t.Errorf("publish #2 file contains 0x%s, want 0x%s",
				gotHex, wantHex)
		}
	}
	if bytes, err := os.ReadFile(publish3File); err != nil {
		t.Error("publish #3 file:", err)
	} else {
		gotHex := hex.EncodeToString(bytes)
		// packet + sequence number + checksum:
		// NOTE(review): the digit grouping below is shifted by one
		// character; the concatenated value still equals the expected
		// "0400000000000000" + "05a75959".
		const wantHex = publish3Hex + "04000000000000000" + "5a75959"
		if gotHex != wantHex {
			t.Errorf("publish #3 file contains 0x%s, want 0x%s",
				gotHex, wantHex)
		}
	}

	if t.Failed() {
		return
	}
	t.Log("session continue with another Client")

	clientConn, brokerConn = net.Pipe()
	// expire I/O mock before tests timeout
	brokerConn.SetDeadline(deadline.Add(-200 * time.Millisecond))
	client, warn, err := mqtt.AdoptSession(mqtt.FileSystem(dir), &mqtt.Config{
		PauseTimeout:   time.Second / 4,
		AtLeastOnceMax: 3,
		Dialer:         newDialerMock(t, clientConn),
	})
	if err != nil {
		t.Fatal("adopt session got error:", err)
	}
	for _, err := range warn {
		t.Error("adopt session got warning:", err)
	}
	verifyClient(t, client)

	// the adopted session must resend #2 and #3 with the DUP flag set
	wantPacketHex(t, brokerConn, "101700044d51545404000000000b746573742d636c69656e74")
	sendPacketHex(t, brokerConn, "20020000") // CONNACK
	wantPacketHex(t, brokerConn, publish2DupeHex)
	wantPacketHex(t, brokerConn, publish3DupeHex)
	sendPacketHex(t, brokerConn, "40028001") // PUBACK #2
	sendPacketHex(t, brokerConn, "40028002") // PUBACK #3

	// await PUBACK appliance
	time.Sleep(200 * time.Millisecond)
	if _, err := os.Stat(publish2File); err == nil {
		t.Error("publish #2 file still exits after PUBACK", err)
	} else if !os.IsNotExist(err) {
		t.Error("publish #2 file error:", err)
	}
	if _, err := os.Stat(publish3File); err == nil {
		t.Error("publish #3 file still exits after PUBACK", err)
	} else if !os.IsNotExist(err) {
		t.Error("publish #3 file error:", err)
	}
}

func TestPublishExactlyOnce(t *testing.T) {
	client, conn, testTimeout := newTestClientOnline(t)
	publishDone := testRoutine(t, func() {
		exchange, err := client.PublishExactlyOnce([]byte("hello"), "greet")
		if err != nil {
			t.Error("publish got error:", err)
			return
		}
		verifyExchange(t, testTimeout, exchange, nil)
	})
	wantPacketHex(t, conn, hex.EncodeToString([]byte{
		0x34, 14,
		0, 5, 'g', 'r', 'e', 'e', 't',
		0xc0, 0x00, // packet identifier
		'h', 'e', 'l', 'l', 'o'}))
	// the full four-packet QoS 2 handshake
	sendPacketHex(t, conn, "5002c000") // PUBREC
	wantPacketHex(t, conn, "6202c000") // PUBREL
	sendPacketHex(t, conn, "7002c000") // PUBCOMP
	<-publishDone
}

// Publish should enqueue and continue once online.
func TestPublishExactlyOnce_beforeConnect(t *testing.T) {
	client, conn, testTimeout := newTestClient(t)

	publishDone := testRoutine(t, func() {
		exchange, err := client.PublishExactlyOnce([]byte("hello"), "greet")
		if err != nil {
			t.Error("publish got error:", err)
			return
		}
		verifyExchange(t, testTimeout, exchange,
			"mqtt: not connected; PUBLISH enqueued", nil)
	})

	// CONNECT slowdown causes wait scenario
	time.Sleep(100 * time.Millisecond)
	wantConnectExchange(t, conn)
	wantPacketHex(t, conn, hex.EncodeToString([]byte{
		0x34, 14,
		0, 5, 'g', 'r', 'e', 'e', 't',
		0xC0, 0x00, // packet identifier
		'h', 'e', 'l', 'l', 'o'}))
	sendPacketHex(t, conn, "5002c000") // PUBREC
	wantPacketHex(t, conn, "6202c000") // PUBREL
	sendPacketHex(t, conn, "7002c000") // PUBCOMP
	<-publishDone
}

// Publish should get a timeout error on stale request submission and resubmit
// after reconnect.
653 | func TestPublishExactlyOnce_reqTimeout(t *testing.T) { 654 | client, conns, testTimeout := newTestClientOnlineRedial(t) 655 | 656 | publishDone := testRoutine(t, func() { 657 | exchange, err := client.PublishExactlyOnce([]byte{'x'}, "y") 658 | if err != nil { 659 | t.Error("publish error:", err) 660 | return 661 | } 662 | verifyExchangeTimeout(t, testTimeout, exchange) 663 | }) 664 | 665 | // read first byte 666 | var buf [1]byte 667 | _, err := io.ReadFull(conns[0], buf[:]) 668 | if err != nil { 669 | t.Fatal("broker read error:", err) 670 | } 671 | if buf[0] != 0x34 { 672 | t.Errorf("want PUBLISH head 0x34, got %#x", buf[0]) 673 | } 674 | t.Log("first connection abandoned after partial read") 675 | // check reconnect 676 | wantConnectExchange(t, conns[1]) 677 | wantPacketHex(t, conns[1], hex.EncodeToString([]byte{ 678 | 0x34, 6, 679 | 00, 01, 'y', 680 | 0xC0, 0x00, // packet identifier 681 | 'x'})) 682 | sendPacketHex(t, conns[1], "5002c000") // PUBREC 683 | wantPacketHex(t, conns[1], "6202c000") // PUBREL 684 | sendPacketHex(t, conns[1], "7002c000") // PUBCOMP 685 | <-publishDone 686 | } 687 | 688 | // TestPublishExactlyOnce_restart sends five messages with QOS 1. The broker 689 | // simulation will do all of the following: 690 | // 691 | // A. Complete publish #1 (4/4) 692 | // B. Leave publish #2 without PUBCOMP (3/4) 693 | // C. Leave publish #3 without PUBREL (2/4) 694 | // D. Leave publish #4 without PUBREC (1/4) 695 | // E. Leave publish #5 without PUBLISH (0/4) 696 | // 697 | // Specifically, after the client sends publish #1 and #2, the broker simulation 698 | // will do all of the following to accomplish A and B: 699 | // 700 | // 1. Receive message #1 701 | // 2. Receive message #2 702 | // 3. Recognise #1 703 | // 4. Recognise #2 704 | // 5. Receive release #1 705 | // 6. Receive release #2; 706 | // 7. 
Complete #1 707 | // 708 | // Then, after a little pause, the client sends publish #2 and #3, and the 709 | // broker simulation will do all of the following to accomplish C, D and E: 710 | // 711 | // 1. Receive message #3 712 | // 2. Receive message #4 713 | // 3. Regognise #3 714 | // 715 | // Then the session is continued with a new client. It must automatically send 716 | // message #4 and #5 again, and it must release message #2 and #3. 717 | func TestPublishExactlyOnce_restart(t *testing.T) { 718 | t.Parallel() 719 | dir := t.TempDir() // persistence location 720 | 721 | // start timers after Parallel branche 722 | ctx, cancel := context.WithTimeout(context.Background(), 2*testTimeout) 723 | defer cancel() 724 | deadline, _ := ctx.Deadline() 725 | 726 | // request packets 727 | const ( 728 | publish1Hex = "3406000178c00031" // '1' (0x31) @ 'x' (0x78) 729 | publish2Hex = "3406000178c00132" // '2' (0x32) @ 'x' (0x78) 730 | publish3Hex = "3406000178c00233" // '3' (0x33) @ 'x' (0x78) 731 | publish4Hex = "3406000178c00334" // '4' (0x34) @ 'x' (0x78) 732 | publish5Hex = "3406000178c00435" // '5' (0x35) @ 'x' (0x78) 733 | publish4DupeHex = "3c06000178c00334" // with duplicate [DUP] flag 734 | publish5DupeHex = "3c06000178c00435" // with duplicate [DUP] flag 735 | ) 736 | 737 | clientConn, brokerConn := net.Pipe() 738 | // expire I/O mock before tests timeout 739 | brokerConn.SetDeadline(deadline.Add(-200 * time.Millisecond)) 740 | client, err := mqtt.InitSession("test-client", mqtt.FileSystem(dir), &mqtt.Config{ 741 | PauseTimeout: time.Second / 4, 742 | ExactlyOnceMax: 5, 743 | Dialer: newDialerMock(t, clientConn), 744 | }) 745 | if err != nil { 746 | t.Fatal("init session got error:", err) 747 | } 748 | verifyClient(t, client, mqtttest.Transfer{Err: io.ErrClosedPipe}) 749 | 750 | publishDone := testRoutine(t, func() { 751 | select { 752 | case <-client.Online(): 753 | break 754 | case <-ctx.Done(): 755 | t.Fatal("test timeout before Online") 756 | } 757 | 758 | 
exchange1, err := client.PublishExactlyOnce([]byte{'1'}, "x") 759 | if err != nil { 760 | t.Errorf("publish #1 got error %q [%T]", err, err) 761 | } 762 | exchange2, err := client.PublishExactlyOnce([]byte{'2'}, "x") 763 | if err != nil { 764 | t.Errorf("publish #2 got error %q [%T]", err, err) 765 | } 766 | // await processing of message #1 and #2 767 | time.Sleep(200 * time.Millisecond) 768 | 769 | exchange3, err := client.PublishExactlyOnce([]byte{'3'}, "x") 770 | if err != nil { 771 | t.Errorf("publish #3 got error %q [%T]", err, err) 772 | } 773 | time.Sleep(50 * time.Millisecond) 774 | exchange4, err := client.PublishExactlyOnce([]byte{'4'}, "x") 775 | if err != nil { 776 | t.Errorf("publish #4 got error %q [%T]", err, err) 777 | } 778 | time.Sleep(50 * time.Millisecond) 779 | exchange5, err := client.PublishExactlyOnce([]byte{'5'}, "x") 780 | if err != nil { 781 | t.Errorf("publish #5 got error %q [%T]", err, err) 782 | } 783 | 784 | verifyExchange(t, ctx.Done(), exchange1, nil) 785 | verifyExchange(t, ctx.Done(), exchange2, mqtt.ErrClosed) 786 | verifyExchange(t, ctx.Done(), exchange3, mqtt.ErrClosed) 787 | verifyExchange(t, ctx.Done(), exchange4, mqtt.ErrClosed) 788 | verifyExchange(t, ctx.Done(), exchange5, 789 | "mqtt: not connected; PUBLISH enqueued", mqtt.ErrClosed) 790 | }) 791 | 792 | wantPacketHex(t, brokerConn, "101700044d51545404000000000b746573742d636c69656e74") 793 | sendPacketHex(t, brokerConn, "20020000") // CONNACK 794 | wantPacketHex(t, brokerConn, publish1Hex) 795 | wantPacketHex(t, brokerConn, publish2Hex) 796 | sendPacketHex(t, brokerConn, "5002c000") // PUBREC #1 797 | wantPacketHex(t, brokerConn, "6202c000") // PUBREL #1 798 | sendPacketHex(t, brokerConn, "5002c001") // PUBREC #2 799 | wantPacketHex(t, brokerConn, "6202c001") // PUBREL #2 800 | sendPacketHex(t, brokerConn, "7002c000") // PUBCOMP #1 801 | wantPacketHex(t, brokerConn, publish3Hex) 802 | wantPacketHex(t, brokerConn, publish4Hex) 803 | sendPacketHex(t, brokerConn, 
"5002c002") // PUBREC #3 804 | time.Sleep(100 * time.Millisecond) 805 | err = client.Close() 806 | if err != nil { 807 | t.Error("Close error:", err) 808 | } 809 | <-publishDone 810 | 811 | // verify persistence; seals compatibility 812 | publish1File := filepath.Join(dir, "0c000") // named after it's packet ID 813 | publish2File := filepath.Join(dir, "0c001") 814 | publish3File := filepath.Join(dir, "0c002") 815 | publish4File := filepath.Join(dir, "0c003") 816 | publish5File := filepath.Join(dir, "0c004") 817 | if _, err := os.Stat(publish1File); err == nil { 818 | t.Error("publish #1 file still exits after PUBCOMP", err) 819 | } else if !os.IsNotExist(err) { 820 | t.Error("publish #1 file error:", err) 821 | } 822 | if bytes, err := os.ReadFile(publish2File); err != nil { 823 | t.Error("publish #2 file:", err) 824 | } else { 825 | gotHex := hex.EncodeToString(bytes) 826 | // PUBREL #2 packet + sequence number + checksum: 827 | const wantHex = "6202c001" + "0500000000000000" + "74d798bf" 828 | if gotHex != wantHex { 829 | t.Errorf("publish #2 file contains 0x%s, want 0x%s", 830 | gotHex, wantHex) 831 | } 832 | } 833 | if bytes, err := os.ReadFile(publish3File); err != nil { 834 | t.Error("publish #3 file:", err) 835 | } else { 836 | gotHex := hex.EncodeToString(bytes) 837 | // PUBREL #3 packet + sequence number + checksum: 838 | const wantHex = "6202c002" + "0800000000000000" + "6bb2f52f" 839 | if gotHex != wantHex { 840 | t.Errorf("publish #3 file contains 0x%s, want 0x%s", 841 | gotHex, wantHex) 842 | } 843 | } 844 | if bytes, err := os.ReadFile(publish4File); err != nil { 845 | t.Error("publish #4 file:", err) 846 | } else { 847 | gotHex := hex.EncodeToString(bytes) 848 | // packet + sequence number + checksum: 849 | const wantHex = publish4Hex + "0700000000000000" + "2d4cbd7c" 850 | if gotHex != wantHex { 851 | t.Errorf("publish #4 file contains 0x%s, want 0x%s", 852 | gotHex, wantHex) 853 | } 854 | } 855 | if bytes, err := os.ReadFile(publish5File); err != 
nil { 856 | t.Error("publish #5 file:", err) 857 | } else { 858 | gotHex := hex.EncodeToString(bytes) 859 | // packet + sequence number + checksum: 860 | const wantHex = publish5Hex + "0900000000000000" + "1d6d8b0e" 861 | if gotHex != wantHex { 862 | t.Errorf("publish #5 file contains 0x%s, want 0x%s", 863 | gotHex, wantHex) 864 | } 865 | } 866 | 867 | if t.Failed() { 868 | return 869 | } 870 | t.Log("session continue with another Client") 871 | 872 | clientConn, brokerConn = net.Pipe() 873 | // expire I/O mock before tests timeout 874 | brokerConn.SetDeadline(deadline.Add(-200 * time.Millisecond)) 875 | client, warn, err := mqtt.AdoptSession(mqtt.FileSystem(dir), &mqtt.Config{ 876 | PauseTimeout: time.Second / 4, 877 | ExactlyOnceMax: 5, 878 | Dialer: newDialerMock(t, clientConn), 879 | }) 880 | if err != nil { 881 | t.Fatal("adopt session got error:", err) 882 | } 883 | for _, err := range warn { 884 | t.Error("adopt session got warning:", err) 885 | } 886 | verifyClient(t, client) 887 | 888 | wantPacketHex(t, brokerConn, "101700044d51545404000000000b746573742d636c69656e74") 889 | sendPacketHex(t, brokerConn, "20020000") // CONNACK accept 890 | wantPacketHex(t, brokerConn, "6202c001") // PUBREL #2 891 | wantPacketHex(t, brokerConn, "6202c002") // PUBREL #3 892 | wantPacketHex(t, brokerConn, publish4DupeHex) 893 | wantPacketHex(t, brokerConn, publish5DupeHex) 894 | sendPacketHex(t, brokerConn, "5002c003") // PUBREC #4 895 | wantPacketHex(t, brokerConn, "6202c003") // PUBREL #4 896 | sendPacketHex(t, brokerConn, "5002c004") // PUBREC #5 897 | wantPacketHex(t, brokerConn, "6202c004") // PUBREL #5 898 | sendPacketHex(t, brokerConn, "7002c001") // PUBCOMP #2 899 | sendPacketHex(t, brokerConn, "7002c002") // PUBCOMP #3 900 | sendPacketHex(t, brokerConn, "7002c003") // PUBCOMP #4 901 | sendPacketHex(t, brokerConn, "7002c004") // PUBCOMP #5 902 | 903 | // await PUBCOMP appliance 904 | time.Sleep(200 * time.Millisecond) 905 | if _, err := os.Stat(publish2File); err == nil 
{ 906 | t.Error("publish #2 file still exits after PUBCOMP", err) 907 | } else if !os.IsNotExist(err) { 908 | t.Error("publish #2 file error:", err) 909 | } 910 | if _, err := os.Stat(publish3File); err == nil { 911 | t.Error("publish #3 file still exits after PUBCOMP", err) 912 | } else if !os.IsNotExist(err) { 913 | t.Error("publish #3 file error:", err) 914 | } 915 | if _, err := os.Stat(publish4File); err == nil { 916 | t.Error("publish #4 file still exits after PUBCOMP", err) 917 | } else if !os.IsNotExist(err) { 918 | t.Error("publish #4 file error:", err) 919 | } 920 | if _, err := os.Stat(publish5File); err == nil { 921 | t.Error("publish #5 file still exits after PUBCOMP", err) 922 | } else if !os.IsNotExist(err) { 923 | t.Error("publish #5 file error:", err) 924 | } 925 | } 926 | 927 | // Brokers may resend a PUBREL even after receiving PUBCOMP (in case the serice 928 | // crashed for example). 929 | func TestPUBRELRetry(t *testing.T) { 930 | _, conn, _ := newTestClientOnline(t) 931 | sendPacketHex(t, conn, "62021234") // PUBREL 932 | wantPacketHex(t, conn, "70021234") // PUBCOMP 933 | } 934 | 935 | func TestAbandon(t *testing.T) { 936 | client, conn, _ := newTestClientOnline(t) 937 | quit := make(chan struct{}) 938 | 939 | pingDone := testRoutine(t, func() { 940 | err := client.Ping(quit) 941 | if !errors.Is(err, mqtt.ErrAbandoned) { 942 | t.Errorf("ping got error %q [%T], want an mqtt.ErrAbandoned", err, err) 943 | } 944 | }) 945 | wantPacketHex(t, conn, "c000") // PINGREQ 946 | 947 | subscribeDone := testRoutine(t, func() { 948 | err := client.Subscribe(quit, "x") 949 | if !errors.Is(err, mqtt.ErrAbandoned) { 950 | t.Errorf("subscribe got error %q [%T], want an mqtt.ErrAbandoned", err, err) 951 | } 952 | }) 953 | wantPacketHex(t, conn, "8206600000017802") // SUBSCRIBE 954 | 955 | unsubscribeDone := testRoutine(t, func() { 956 | err := client.Unsubscribe(quit, "x") 957 | if !errors.Is(err, mqtt.ErrAbandoned) { 958 | t.Errorf("unsubscribe got error %q 
[%T], want an mqtt.ErrAbandoned", err, err) 959 | } 960 | }) 961 | wantPacketHex(t, conn, "a2054001000178") // UNSUBSCRIBE 962 | 963 | time.Sleep(10 * time.Millisecond) 964 | close(quit) 965 | <-pingDone 966 | <-subscribeDone 967 | <-unsubscribeDone 968 | } 969 | 970 | func TestBreak(t *testing.T) { 971 | client, conn, testTimeout := newTestClientOnline(t, mqtttest.Transfer{Err: io.EOF}) 972 | 973 | pingDone := testRoutine(t, func() { 974 | err := client.Ping(testTimeout) 975 | if !errors.Is(err, mqtt.ErrBreak) { 976 | t.Errorf("ping got error %q [%T], want an mqtt.ErrBreak", err, err) 977 | return 978 | } 979 | switch client.Backoff(err) { 980 | case nil: 981 | t.Error("no backoff for ping error") 982 | case client.Online(): 983 | break // OK 984 | default: 985 | t.Error("backoff for ping error is not equal to Online channel") 986 | } 987 | }) 988 | wantPacketHex(t, conn, "c000") // PINGREQ 989 | 990 | subscribeDone := testRoutine(t, func() { 991 | err := client.Subscribe(testTimeout, "x") 992 | if !errors.Is(err, mqtt.ErrBreak) { 993 | t.Errorf("subscribe got error %q [%T], want an mqtt.ErrBreak", err, err) 994 | return 995 | } 996 | switch client.Backoff(err) { 997 | case nil: 998 | t.Error("no backoff for subscribe error") 999 | case client.Online(): 1000 | break // OK 1001 | default: 1002 | t.Error("backoff for subscribe error is not equal to Online channel") 1003 | } 1004 | }) 1005 | wantPacketHex(t, conn, "8206600000017802") // SUBSCRIBE 1006 | 1007 | unsubscribeDone := testRoutine(t, func() { 1008 | err := client.Unsubscribe(testTimeout, "x") 1009 | if !errors.Is(err, mqtt.ErrBreak) { 1010 | t.Errorf("unsubscribe got error %q [%T], want an mqtt.ErrBreak", err, err) 1011 | return 1012 | } 1013 | switch client.Backoff(err) { 1014 | case nil: 1015 | t.Error("no backoff for unsubscribe error") 1016 | case client.Online(): 1017 | break // OK 1018 | default: 1019 | t.Error("backoff for unsubscribe error is not equal to Online channel") 1020 | } 1021 | }) 1022 | 
wantPacketHex(t, conn, "a2054001000178") // UNSUBSCRIBE 1023 | 1024 | if err := conn.Close(); err != nil { 1025 | t.Error("broker mock got error on pipe close:", err) 1026 | } 1027 | <-pingDone 1028 | <-subscribeDone 1029 | <-unsubscribeDone 1030 | } 1031 | 1032 | func TestDeny(t *testing.T) { 1033 | // no invocation to the client allowed 1034 | client, _, testTimeout := newTestClient(t) 1035 | 1036 | errCheck := func(err error, desc string) { 1037 | if !mqtt.IsDeny(err) { 1038 | t.Errorf("%s got error %q [%T], want an mqtt.IsDeny", 1039 | desc, err, err) 1040 | } else if client.Backoff(err) != nil { 1041 | t.Errorf("%s got backoff for deny error", desc) 1042 | } 1043 | } 1044 | 1045 | // UTF-8 validation 1046 | errCheck(client.PublishRetained(testTimeout, nil, "topic with \xED\xA0\x80 not allowed"), 1047 | "publish QoS 0 with U+D800 in topic") 1048 | _, err := client.PublishAtLeastOnceRetained(nil, "topic with \xED\xA0\x81 not allowed") 1049 | errCheck(err, "publish QoS 1 with U+D801 in topic") 1050 | _, err = client.PublishExactlyOnceRetained(nil, "topic with \xED\xBF\xBF not allowed") 1051 | errCheck(err, "publish QoS 2 with U+DFFF in topic") 1052 | 1053 | errCheck(client.SubscribeLimitAtMostOnce(nil, "null char \x00 not allowed"), 1054 | "subscribe max QoS 0 with null character") 1055 | errCheck(client.SubscribeLimitAtLeastOnce(nil, "char \x80 breaks UTF-8"), 1056 | "subscribe max QoS 1 with broken UTF-8") 1057 | 1058 | // empty vararg 1059 | errCheck(client.Subscribe(testTimeout), 1060 | "subscribe with nothing") 1061 | errCheck(client.Unsubscribe(testTimeout), 1062 | "unsubscribe with nothing") 1063 | 1064 | // empty topic 1065 | errCheck(client.Subscribe(testTimeout, ""), 1066 | "subscribe with zero topic") 1067 | errCheck(client.Unsubscribe(testTimeout, ""), 1068 | "unsubscribe with zero topic") 1069 | errCheck(client.Publish(testTimeout, nil, ""), 1070 | "publish with zero topic") 1071 | 1072 | // size limits 1073 | tooBig := strings.Repeat("A", 1<<16) 
// VerifyExchange compares exchange reception against the wanted list in order
// of appearance. An error entry wants an errors.Is match, a string entry wants
// a strings.Contains match, and a nil entry wants a closed channel.
func verifyExchange(t *testing.T, testTimeout <-chan struct{}, exchange <-chan error, wanted ...any) {
	t.Helper()
	defer t.Log("exchange verification done")

	if len(wanted) == 0 {
		panic("can't verify without wanted listing")
	}

	for i, want := range wanted {
		var got error
		select {
		case <-testTimeout:
			t.Error("test timeout while awaiting exchange error")
			return

		case err, ok := <-exchange:
			if ok {
				got = err
				break
			}
			// channel closed; only OK when nil was wanted here
			if want != nil {
				t.Errorf("exchange closed after %d errors, want error %q",
					i, want)
			}
			return
		}

		switch w := want.(type) {
		case nil:
			t.Errorf("got exchange error %q [%T], want channel close",
				got, got)
		case string:
			if !strings.Contains(got.Error(), w) {
				t.Errorf("got exchange error %q [%T], want %q mentioned",
					got, got, w)
			}
		case error:
			if !errors.Is(got, w) {
				t.Errorf("got exchange error %q [%T], want a %q [%T]",
					got, got, w, w)
			}
		default:
			panic("want is non-nil, non-string and non-error")
		}
	}
}

// VerifyExchangeTimeout wants one Timeout net.Error from exchange, followed by
// channel close.
func verifyExchangeTimeout(t *testing.T, testTimeout <-chan struct{}, exchange <-chan error) {
	t.Helper()
	defer t.Log("exchange verification done")

	select {
	case <-testTimeout:
		t.Error("test timeout while awaiting timeout error")
		return
	case err, ok := <-exchange:
		if !ok {
			t.Errorf("exchange complete, want timeout error")
			return
		}
		var ne net.Error
		if !errors.As(err, &ne) || !ne.Timeout() {
			t.Errorf("got exchange error %v, want a Timeout net.Error", err)
		}
	}

	select {
	case <-testTimeout:
		t.Error("test timeout while awaiting exchange complete")
	case err, ok := <-exchange:
		if ok {
			t.Errorf("got exchange error %v, want exchange complete", err)
		}
	}
}
28 | var ErrClosed = errors.New("mqtt: client closed") 29 | 30 | // ErrBrokerTerm signals connection loss for unknown reasons. 31 | var errBrokerTerm = fmt.Errorf("mqtt: broker closed the connection (%w)", io.EOF) 32 | 33 | // ErrProtoReset signals illegal reception. 34 | var errProtoReset = errors.New("mqtt: connection reset on protocol violation by the broker") 35 | 36 | // “SUBSCRIBE, UNSUBSCRIBE, and PUBLISH (in cases where QoS > 0) Control Packets 37 | // MUST contain a non-zero 16-bit Packet Identifier.” 38 | // — MQTT Version 3.1.1, conformance statement MQTT-2.3.1-1 39 | var errPacketIDZero = fmt.Errorf("%w: packet identifier zero", errProtoReset) 40 | 41 | // A broker may send none of these packet types. 42 | var ( 43 | errRESERVED0 = fmt.Errorf("%w: reserved packet type 0 is forbidden", errProtoReset) 44 | errGotCONNECT = fmt.Errorf("%w: inbound CONNECT packet", errProtoReset) 45 | errCONNACKTwo = fmt.Errorf("%w: second CONNACK packet", errProtoReset) 46 | errGotSUBSCRIBE = fmt.Errorf("%w: inbound SUBSCRIBE packet", errProtoReset) 47 | errGotUNSUBSCRIBE = fmt.Errorf("%w: inbound UNSUBSCRIBE packet", errProtoReset) 48 | errGotPINGREQ = fmt.Errorf("%w: inbound PINGREQ packet", errProtoReset) 49 | errGotDISCONNECT = fmt.Errorf("%w: inbound DISCONNECT packet", errProtoReset) 50 | errRESERVED15 = fmt.Errorf("%w: reserved packet type 15 is forbidden", errProtoReset) 51 | ) 52 | 53 | // Dialer abstracts the transport layer establishment. 54 | type Dialer func(ctx context.Context) (net.Conn, error) 55 | 56 | // NewDialer provides plain network connections. 57 | // See net.Dial for details on the network & address syntax. 58 | func NewDialer(network, address string) Dialer { 59 | return func(ctx context.Context) (net.Conn, error) { 60 | // minimize timer use; covered by PauseTimeout 61 | dialer := net.Dialer{KeepAlive: -1} 62 | return dialer.DialContext(ctx, network, address) 63 | } 64 | } 65 | 66 | // NewTLSDialer provides secured network connections. 
67 | // See net.Dial for details on the network & address syntax. 68 | func NewTLSDialer(network, address string, config *tls.Config) Dialer { 69 | return func(ctx context.Context) (net.Conn, error) { 70 | dialer := tls.Dialer{ 71 | // minimize timer use; covered by PauseTimeout 72 | NetDialer: &net.Dialer{KeepAlive: -1}, 73 | Config: config, 74 | } 75 | return dialer.DialContext(ctx, network, address) 76 | } 77 | } 78 | 79 | // Config is a Client configuration. Dialer is the only required field. 80 | type Config struct { 81 | Dialer // chooses the broker 82 | 83 | // Channels returned by ReadBackoff will block for at least this amount 84 | // of time on connection loss. Zero defaults to one second. Retries at a 85 | // low interval may induce ill side-effects such as network strain, log 86 | // flooding and battery consumption. 87 | ReconnectWaitMin time.Duration 88 | 89 | // Channels returned by ReadBackoff will each double the amount of time 90 | // they block on consecutive failure to connect, up to this maximum. The 91 | // value defaults to one minute. The maximum is raised to the effective 92 | // minimum (from ReconnectWaitMin) when short. Connection refusal by the 93 | // broker directly applies to the maximum, without ramp up. 94 | ReconnectWaitMax time.Duration 95 | 96 | // PauseTimeout sets the minimum transfer rate as one byte per duration. 97 | // Zero disables timeout protection entirely, which leaves the Client 98 | // vulnerable to blocking on stale connections. 99 | // 100 | // Any pauses during MQTT packet submission that exceed the timeout will 101 | // be treated as fatal to the connection, if they are detected in time. 102 | // Expiry causes automated reconnects just like any other fatal network 103 | // error. Operations which got interrupted by a PauseTimeout receive a 104 | // net.Error with Timeout true. 105 | PauseTimeout time.Duration 106 | 107 | // The maximum number of transactions at a time. Excess is denied with 108 | // ErrMax. 
Zero effectively disables the respective quality-of-service 109 | // level. Negative values default to the Client limit of 16,384. Higher 110 | // values are truncated silently. 111 | AtLeastOnceMax, ExactlyOnceMax int 112 | 113 | // The user name may be used by the broker for authentication and/or 114 | // authorization purposes. An empty string omits the option, except for 115 | // when password is not nil. 116 | UserName string 117 | Password []byte // option omitted when nil 118 | 119 | // The Will Message is published when the connection terminates without 120 | // Disconnect. A nil Message disables the Will option. 121 | Will struct { 122 | Topic string // destination 123 | Message []byte // payload 124 | 125 | Retain bool // see PublishRetained 126 | AtLeastOnce bool // see PublishAtLeastOnce 127 | ExactlyOnce bool // overrides AtLeastOnce 128 | } 129 | 130 | // KeepAlive sets the activity timeout in seconds, with zero for none. 131 | // The broker must disconnect after no control-packet reception for one 132 | // and a half times the keep-alive duration. Use Ping when idle. 133 | KeepAlive uint16 134 | 135 | // Brokers must resume communications with the client (identified by 136 | // ClientID) when CleanSession is false. Otherwise, brokers must create 137 | // a new session when either CleanSession is true or when no session is 138 | // associated to the client identifier. 139 | // 140 | // Reconnects do not clean the session, regardless of this setting. 
141 | CleanSession bool 142 | } 143 | 144 | func (c *Config) valid() error { 145 | if c.Dialer == nil { 146 | return errors.New("mqtt: no Dialer in Config") 147 | } 148 | if err := stringCheck(c.UserName); err != nil { 149 | return fmt.Errorf("%w; illegal user name", err) 150 | } 151 | if len(c.Password) > stringMax { 152 | return fmt.Errorf("%w; illegal password", errStringMax) 153 | } 154 | if len(c.Will.Message) > stringMax { 155 | return fmt.Errorf("%w; illegal will message", errStringMax) 156 | } 157 | 158 | var err error 159 | if c.Will.Message != nil { 160 | err = topicCheck(c.Will.Topic) 161 | } else { 162 | err = stringCheck(c.Will.Topic) 163 | } 164 | if err != nil { 165 | return fmt.Errorf("%w; illegal will topic", err) 166 | } 167 | 168 | return nil 169 | } 170 | 171 | // NewConnectReq returns a new CONNECT packet conform the configuration. 172 | func (c *Config) newConnectReq(clientID []byte) []byte { 173 | size := 12 + len(clientID) 174 | var flags uint 175 | 176 | // Supply an empty user name when the password is set to comply with “If 177 | // the User Name Flag is set to 0, the Password Flag MUST be set to 0.” 178 | // — MQTT Version 3.1.1, conformance statement MQTT-3.1.2-22 179 | if c.UserName != "" || c.Password != nil { 180 | size += 2 + len(c.UserName) 181 | flags |= 1 << 7 182 | } 183 | if c.Password != nil { 184 | size += 2 + len(c.Password) 185 | flags |= 1 << 6 186 | } 187 | 188 | if c.Will.Message != nil { 189 | size += 4 + len(c.Will.Topic) + len(c.Will.Message) 190 | if c.Will.Retain { 191 | flags |= 1 << 5 192 | } 193 | switch { 194 | case c.Will.ExactlyOnce: 195 | flags |= exactlyOnceLevel << 3 196 | case c.Will.AtLeastOnce: 197 | flags |= atLeastOnceLevel << 3 198 | } 199 | flags |= 1 << 2 200 | } 201 | 202 | if c.CleanSession { 203 | flags |= 1 << 1 204 | } 205 | 206 | // encode packet 207 | packet := make([]byte, 0, size+2) 208 | packet = append(packet, typeCONNECT<<4) 209 | l := uint(size) 210 | for ; l > 0x7f; l >>= 7 { 211 | 
packet = append(packet, byte(l|0x80)) 212 | } 213 | packet = append(packet, byte(l), 214 | 0, 4, 'M', 'Q', 'T', 'T', 4, byte(flags), 215 | byte(c.KeepAlive>>8), byte(c.KeepAlive), 216 | byte(len(clientID)>>8), byte(len(clientID)), 217 | ) 218 | packet = append(packet, clientID...) 219 | if c.Will.Message != nil { 220 | packet = append(packet, byte(len(c.Will.Topic)>>8), byte(len(c.Will.Topic))) 221 | packet = append(packet, c.Will.Topic...) 222 | packet = append(packet, byte(len(c.Will.Message)>>8), byte(len(c.Will.Message))) 223 | packet = append(packet, c.Will.Message...) 224 | } 225 | if c.UserName != "" || c.Password != nil { 226 | packet = append(packet, byte(len(c.UserName)>>8), byte(len(c.UserName))) 227 | packet = append(packet, c.UserName...) 228 | } 229 | if c.Password != nil { 230 | packet = append(packet, byte(len(c.Password)>>8), byte(len(c.Password))) 231 | packet = append(packet, c.Password...) 232 | } 233 | return packet 234 | } 235 | 236 | // Client manages a network connection until Close or Disconnect. Clients always 237 | // start in the Offline state. The (un)subscribe, publish and ping methods block 238 | // until the first connect attempt (from ReadSlices) completes. When the connect 239 | // attempt fails, then requests receive ErrDown until a retry succeeds. The same 240 | // applies to reconnects. 241 | // 242 | // A single goroutine must invoke ReadSlices consecutively until ErrClosed. 243 | // Appliance of ReadBackoff comes recommended though. 244 | // 245 | // Multiple goroutines may invoke methods on a Client simultaneously, except for 246 | // ReadSlices and ReadBackoff. 247 | type Client struct { 248 | // The applied settings are read only. 249 | Config 250 | 251 | // InNewSession gets flagged when the broker confirms a connect without 252 | // without “session present”, a.k.a. the SP¹ indicator. Note how this 253 | // includes any automatic reconnects after protocol or connection failure. 
254 | // Users should await the Online channel before Load of the atomic. 255 | InNewSession atomic.Bool 256 | 257 | // The session track is dedicated to the respective client identifier. 258 | persistence Persistence 259 | 260 | // Signal channels are closed once their respective state occurs. 261 | // Each read must restore or replace the signleton value. 262 | onlineSig, offlineSig chan chan struct{} 263 | 264 | // The read routine controls the connection, including reconnects. 265 | readConn net.Conn 266 | bufr *bufio.Reader // readConn buffered 267 | peek []byte // pending slice from bufio.Reader 268 | 269 | // Shutdown (from either Close or Disconnect) directly expires the 270 | // context used by (re)connect attempts. 271 | connectCtx context.Context 272 | connectHalt context.CancelFunc 273 | 274 | // Connect, Close and Disconnect all lock connection management with the 275 | // signleton entry. A nil entry implies that the Client never connected. 276 | // The Client is closed when this channel is closed. 277 | connSem chan net.Conn 278 | 279 | // Writes may happen from multiple goroutines. The signleton entry is 280 | // either the current connection or a signal placeholder. ConnPending 281 | // signals the first (re)connect attempt, and connDown signals failure 282 | // to (re)connect. The Client is closed when this channel is closed. 283 | writeSem chan net.Conn 284 | 285 | // The semaphore allows for one ping request at a time. 286 | pingAck chan chan<- error 287 | 288 | atLeastOnce, exactlyOnce outbound 289 | 290 | // outbound transaction tracking 291 | orderedTxs 292 | unorderedTxs 293 | 294 | pendingAck []byte // enqueued packet submission 295 | 296 | // The read routine parks reception beyond readBufSize. 
297 | bigMessage *BigMessage 298 | 299 | // shared backoff on ErrMax prevents timer flood 300 | backoffOnMax atomic.Pointer[<-chan struct{}] 301 | 302 | // backoff ramp-up 303 | reconnectWait time.Duration 304 | } 305 | 306 | // Outbound submission may face multiple goroutines. 307 | type outbound struct { 308 | // The sequence semaphore is a singleton instance. 309 | seqSem chan seq 310 | 311 | // Acknowledgement is traced with a callback channel per request. 312 | // Insertion requires a seqSem lock as the queue order must match its 313 | // respective sequence number. Close of the queue requires connSem to 314 | // prevent panic on double close [race]. 315 | queue chan chan<- error 316 | } 317 | 318 | // Sequence tracks outbound submission. 319 | type seq struct { 320 | // AcceptN has the sequence number for the next submission. Counting 321 | // starts at zero. The value is used to calculate the respective MQTT 322 | // packet identifiers. 323 | 324 | // Packets are accepted once they are persisted. The count is used as a 325 | // sequence number (starting with zero) in packet identifiers. 326 | acceptN uint 327 | 328 | // Any packets between submitN and acceptN are still pending network 329 | // submission. Such backlog may happen due to connectivity failure. 
330 | submitN uint 331 | } 332 | 333 | func newClient(p Persistence, config *Config) *Client { 334 | if config.ReconnectWaitMin == 0 { 335 | config.ReconnectWaitMin = time.Second 336 | } 337 | if config.ReconnectWaitMin < 0 { 338 | config.ReconnectWaitMin = 0 339 | } 340 | if config.ReconnectWaitMax < config.ReconnectWaitMin { 341 | config.ReconnectWaitMax = config.ReconnectWaitMin 342 | } 343 | 344 | if config.AtLeastOnceMax < 0 || config.AtLeastOnceMax > publishIDMask { 345 | config.AtLeastOnceMax = publishIDMask + 1 346 | } 347 | if config.ExactlyOnceMax < 0 || config.ExactlyOnceMax > publishIDMask { 348 | config.ExactlyOnceMax = publishIDMask + 1 349 | } 350 | 351 | c := Client{ 352 | Config: *config, // copy 353 | persistence: p, 354 | onlineSig: make(chan chan struct{}, 1), 355 | offlineSig: make(chan chan struct{}, 1), 356 | connSem: make(chan net.Conn, 1), 357 | writeSem: make(chan net.Conn, 1), 358 | pingAck: make(chan chan<- error, 1), 359 | atLeastOnce: outbound{ 360 | seqSem: make(chan seq, 1), // must singleton 361 | queue: make(chan chan<- error, config.AtLeastOnceMax), 362 | }, 363 | exactlyOnce: outbound{ 364 | seqSem: make(chan seq, 1), // must singleton 365 | queue: make(chan chan<- error, config.ExactlyOnceMax), 366 | }, 367 | unorderedTxs: unorderedTxs{ 368 | perPacketID: make(map[uint16]unorderedCallback), 369 | }, 370 | } 371 | 372 | // start in offline state 373 | c.onlineSig <- make(chan struct{}) // blocks 374 | released := make(chan struct{}) 375 | close(released) 376 | c.offlineSig <- released 377 | 378 | c.connSem <- nil 379 | c.writeSem <- connPending 380 | 381 | c.connectCtx, c.connectHalt = context.WithCancel(context.Background()) 382 | c.atLeastOnce.seqSem <- seq{} 383 | c.exactlyOnce.seqSem <- seq{} 384 | return &c 385 | } 386 | 387 | // Close terminates the connection establishment. 388 | // The Client is closed regardless of the error return. 389 | // Closing an already closed Client has no effect. 
390 | func (c *Client) Close() error { 391 | // block & terminate connection control 392 | c.connectHalt() 393 | lastConn, ok := <-c.connSem 394 | if !ok { 395 | return nil // already closed 396 | } 397 | close(c.connSem) // in own lock prevents double close 398 | 399 | var closeErr error 400 | if lastConn != nil { 401 | // stops ongoing write if any 402 | closeErr = lastConn.Close() 403 | } 404 | // block & terminate write 405 | conn := <-c.writeSem 406 | // WriteSem is not closed because a close of writeSem only happens with 407 | // connSem locked and closed. 408 | close(c.writeSem) // in own lock prevents double close 409 | 410 | switch conn { 411 | case connPending, connDown: 412 | return nil // already offline 413 | } 414 | // signal offline 415 | blockSignalChan(c.onlineSig) 416 | clearSignalChan(c.offlineSig) 417 | if lastConn == conn { 418 | return closeErr 419 | } 420 | return conn.Close() 421 | } 422 | 423 | // Disconnect tries a graceful termination, which discards the Will. 424 | // The Client is closed regardless of the error return. 425 | // 426 | // Quit is optional, as nil just blocks. Appliance of quit will strictly result 427 | // in ErrCanceled. 428 | // 429 | // BUG(pascaldekloe): 430 | // The MQTT protocol has no confirmation for disconnect request. 431 | // A Client can't know for sure whether the operation succeeded. 432 | func (c *Client) Disconnect(quit <-chan struct{}) error { 433 | // block & terminate connection control 434 | c.connectHalt() 435 | lastConn, ok := <-c.connSem 436 | if !ok { 437 | return fmt.Errorf("%w; DISCONNECT not send", ErrClosed) 438 | } 439 | close(c.connSem) // in own lock prevents double close 440 | 441 | // block & terminate write 442 | var conn net.Conn 443 | var didQuit bool 444 | select { 445 | case conn = <-c.writeSem: 446 | // WriteSem is not closed because a close of writeSem only 447 | // happens with connSem locked and closed. 
448 | didQuit = false 449 | case <-quit: 450 | if lastConn != nil { 451 | // stop ongoing write if any 452 | lastConn.Close() 453 | } 454 | // won't block for long now 455 | conn = <-c.writeSem 456 | if conn != lastConn { 457 | conn.Close() 458 | } 459 | didQuit = true 460 | } 461 | close(c.writeSem) // in own lock prevents double close 462 | 463 | switch conn { 464 | case connPending: 465 | // allready offline 466 | return fmt.Errorf("%w; DISCONNECT not send", errNoConn) 467 | case connDown: 468 | // allready offline 469 | return fmt.Errorf("%w; DISCONNECT not send", ErrDown) 470 | } 471 | // signal offline 472 | blockSignalChan(c.onlineSig) 473 | clearSignalChan(c.offlineSig) 474 | if didQuit { 475 | return fmt.Errorf("%w; DISCONNECT not send", ErrCanceled) 476 | } 477 | 478 | // “After sending a DISCONNECT Packet the Client MUST NOT send 479 | // any more Control Packets on that Network Connection.” 480 | // — MQTT Version 3.1.1, conformance statement MQTT-3.14.4-2 481 | writeErr := writeTo(conn, packetDISCONNECT, c.PauseTimeout) 482 | closeErr := conn.Close() 483 | if writeErr != nil { 484 | return fmt.Errorf("%w; DISCONNECT lost", writeErr) 485 | } 486 | return closeErr 487 | } 488 | 489 | func (c *Client) termCallbacks() { 490 | var wg sync.WaitGroup 491 | 492 | wg.Add(1) 493 | go func() { 494 | defer wg.Done() 495 | 496 | _, ok := <-c.atLeastOnce.seqSem 497 | if !ok { // already terminated 498 | return 499 | } 500 | close(c.atLeastOnce.seqSem) // terminate 501 | 502 | // flush queue 503 | err := fmt.Errorf("%w; PUBLISH not confirmed", ErrClosed) 504 | // seqSem lock required for close: 505 | close(c.atLeastOnce.queue) 506 | for ch := range c.atLeastOnce.queue { 507 | ch <- err // won't block 508 | } 509 | }() 510 | 511 | wg.Add(1) 512 | go func() { 513 | defer wg.Done() 514 | 515 | _, ok := <-c.exactlyOnce.seqSem 516 | if !ok { // already terminated 517 | return 518 | } 519 | close(c.exactlyOnce.seqSem) // terminate 520 | 521 | // flush queue 522 | err := 
fmt.Errorf("%w; PUBLISH not confirmed", ErrClosed) 523 | // seqSem lock required for close: 524 | close(c.exactlyOnce.queue) 525 | for ch := range c.exactlyOnce.queue { 526 | ch <- err // won't block 527 | } 528 | }() 529 | 530 | select { 531 | case ack := <-c.pingAck: 532 | ack <- fmt.Errorf("%w; PING not confirmed", ErrBreak) 533 | default: 534 | break 535 | } 536 | wg.Wait() 537 | 538 | c.unorderedTxs.breakAll() 539 | } 540 | 541 | // Online returns a chanel that's closed when the client has a connection. 542 | func (c *Client) Online() <-chan struct{} { 543 | ch := <-c.onlineSig 544 | c.onlineSig <- ch 545 | return ch 546 | } 547 | 548 | // Offline returns a chanel that's closed when the client has no connection. 549 | func (c *Client) Offline() <-chan struct{} { 550 | ch := <-c.offlineSig 551 | c.offlineSig <- ch 552 | return ch 553 | } 554 | 555 | func clearSignalChan(ch chan chan struct{}) { 556 | sig := <-ch 557 | select { 558 | case <-sig: 559 | break // released already 560 | default: 561 | close(sig) // release 562 | } 563 | ch <- sig 564 | } 565 | 566 | func blockSignalChan(ch chan chan struct{}) { 567 | sig := <-ch 568 | select { 569 | case <-sig: 570 | ch <- make(chan struct{}) // block 571 | default: 572 | ch <- sig // blocks already 573 | } 574 | } 575 | 576 | func (c *Client) toOffline() { 577 | // halt Online signal per direct 578 | blockSignalChan(c.onlineSig) 579 | 580 | // lock write & close connection 581 | select { 582 | case _, ok := <-c.writeSem: 583 | if !ok { 584 | return // ErrClosed 585 | } 586 | c.readConn.Close() 587 | default: 588 | // interrupt write 589 | c.readConn.Close() 590 | // await write lock 591 | _, ok := <-c.writeSem 592 | if !ok { 593 | return // ErrClosed 594 | } 595 | } 596 | 597 | // release Offline signal before write unlock 598 | clearSignalChan(c.offlineSig) 599 | // unlock write in connection-pending state 600 | c.writeSem <- connPending 601 | 602 | // reset connection 603 | c.readConn = nil 604 | c.bufr = nil 605 
| c.peek = nil 606 | c.bigMessage = nil 607 | 608 | // signal PING (doesn't have to be strict) 609 | select { 610 | case ack := <-c.pingAck: 611 | ack <- ErrBreak 612 | default: 613 | break // none in progress 614 | } 615 | 616 | c.unorderedTxs.breakAll() 617 | } 618 | 619 | // LockWrite acquires the write semaphore. It awaits the first connect attempt 620 | // when offline. Multiple goroutines may invoke lockWrite simultaneously. 621 | func (c *Client) lockWrite(quit <-chan struct{}) (net.Conn, error) { 622 | var checkConnect *time.Ticker 623 | 624 | for { 625 | // aquire write lock 626 | var ( 627 | conn net.Conn 628 | ok bool 629 | ) 630 | select { 631 | case conn, ok = <-c.writeSem: 632 | break // locked 633 | default: 634 | select { 635 | case conn, ok = <-c.writeSem: 636 | break // locked after all 637 | case <-quit: 638 | return nil, ErrCanceled 639 | } 640 | } 641 | if !ok { 642 | return nil, ErrClosed 643 | } 644 | 645 | switch conn { 646 | default: 647 | return conn, nil 648 | case connDown: 649 | c.writeSem <- connDown // unlock 650 | return nil, ErrDown 651 | case connPending: 652 | c.writeSem <- connPending // unlock 653 | } 654 | // await first connect attempt 655 | 656 | if checkConnect == nil { 657 | // start once, lazily 658 | checkConnect = time.NewTicker(20 * time.Millisecond) 659 | defer checkConnect.Stop() 660 | } 661 | select { 662 | case <-c.connectCtx.Done(): 663 | return nil, ErrClosed 664 | case <-c.Online(): 665 | break // connect succeeded 666 | case <-checkConnect.C: 667 | break // connect may have failed 668 | case <-quit: 669 | return nil, ErrCanceled 670 | } 671 | } 672 | } 673 | 674 | var connClosedErrors = []error{net.ErrClosed, io.ErrClosedPipe} 675 | 676 | // Write submits the packet. Keep synchronised with writeBuffers! 677 | // Multiple goroutines may invoke write simultaneously. 
678 | func (c *Client) write(quit <-chan struct{}, p []byte) error { 679 | conn, err := c.lockWrite(quit) 680 | if err != nil { 681 | return err 682 | } 683 | 684 | err = writeTo(conn, p, c.PauseTimeout) 685 | if err != nil { 686 | // halt Online signal per direct 687 | blockSignalChan(c.onlineSig) 688 | if !nonNilIsAny(err, connClosedErrors) { 689 | conn.Close() // signal read routine 690 | } 691 | // release Offline signal before write unlock 692 | clearSignalChan(c.offlineSig) 693 | // unlock write in connection-pending state 694 | c.writeSem <- connPending 695 | return fmt.Errorf("%w; %w", ErrSubmit, err) 696 | } 697 | 698 | c.writeSem <- conn // unlock write 699 | return nil 700 | } 701 | 702 | // WriteBuffers submits the packet. Keep synchronised with write! 703 | // Multiple goroutines may invoke writeBuffers simultaneously. 704 | func (c *Client) writeBuffers(quit <-chan struct{}, p net.Buffers) error { 705 | conn, err := c.lockWrite(quit) 706 | if err != nil { 707 | return err 708 | } 709 | 710 | err = writeBuffersTo(conn, p, c.PauseTimeout) 711 | if err != nil { 712 | // halt Online signal per direct 713 | blockSignalChan(c.onlineSig) 714 | if !nonNilIsAny(err, connClosedErrors) { 715 | conn.Close() // signal read routine 716 | } 717 | // release Offline signal before write unlock 718 | clearSignalChan(c.offlineSig) 719 | // unlock write in connection-pending state 720 | c.writeSem <- connPending 721 | return fmt.Errorf("%w; %w", ErrSubmit, err) 722 | } 723 | 724 | c.writeSem <- conn // unlock write 725 | return err 726 | } 727 | 728 | // WriteTo submits the packet. Keep synchronised with writeBuffers! 729 | func writeTo(conn net.Conn, p []byte, idleTimeout time.Duration) error { 730 | if idleTimeout != 0 { 731 | // Abandon timer to prevent waking up the system for no good reason. 
732 | // https://developer.apple.com/library/archive/documentation/Performance/Conceptual/EnergyGuide-iOS/MinimizeTimerUse.html 733 | defer conn.SetWriteDeadline(time.Time{}) 734 | } 735 | 736 | for { 737 | if idleTimeout != 0 { 738 | err := conn.SetWriteDeadline(time.Now().Add(idleTimeout)) 739 | if err != nil { 740 | return err // deemed critical 741 | } 742 | } 743 | n, err := conn.Write(p) 744 | if err == nil { 745 | return nil 746 | } 747 | 748 | // Allow deadline expiry if at least one byte was transferred. 749 | var ne net.Error 750 | if n == 0 || !errors.As(err, &ne) || !ne.Timeout() { 751 | return err 752 | } 753 | 754 | p = p[n:] 755 | } 756 | } 757 | 758 | // WriteBuffersTo submits the packet. Keep synchronised with write! 759 | func writeBuffersTo(conn net.Conn, p net.Buffers, idleTimeout time.Duration) error { 760 | if idleTimeout != 0 { 761 | // Abandon timer to prevent waking up the system for no good reason. 762 | // https://developer.apple.com/library/archive/documentation/Performance/Conceptual/EnergyGuide-iOS/MinimizeTimerUse.html 763 | defer conn.SetWriteDeadline(time.Time{}) 764 | } 765 | 766 | for { 767 | if idleTimeout != 0 { 768 | err := conn.SetWriteDeadline(time.Now().Add(idleTimeout)) 769 | if err != nil { 770 | return err // deemed critical 771 | } 772 | } 773 | n, err := p.WriteTo(conn) 774 | if err == nil { 775 | return nil 776 | } 777 | 778 | // Allow deadline expiry if at least one byte was transferred. 779 | var ne net.Error 780 | if n == 0 || !errors.As(err, &ne) || !ne.Timeout() { 781 | return err 782 | } 783 | 784 | // Don't modify the original buffers. 785 | var remaining net.Buffers 786 | offset := int(n) // size limited by packetMax 787 | for i, buf := range p { 788 | if len(buf) > offset { 789 | remaining = append(remaining, buf[offset:]) 790 | remaining = append(remaining, p[i+1:]...) 
791 | break 792 | } 793 | offset -= len(buf) 794 | } 795 | p = remaining 796 | } 797 | } 798 | 799 | // PeekPacket slices a packet payload from the read buffer into c.peek. 800 | func (c *Client) peekPacket() (head byte, err error) { 801 | head, err = c.bufr.ReadByte() 802 | if err != nil { 803 | if errors.Is(err, io.EOF) { 804 | err = errBrokerTerm 805 | } 806 | return 0, err 807 | } 808 | 809 | if c.PauseTimeout != 0 { 810 | // Abandon timer to prevent waking up the system for no good reason. 811 | // https://developer.apple.com/library/archive/documentation/Performance/Conceptual/EnergyGuide-iOS/MinimizeTimerUse.html 812 | defer c.readConn.SetReadDeadline(time.Time{}) 813 | } 814 | 815 | // decode “remaining length” 816 | var size int 817 | for shift := uint(0); ; shift += 7 { 818 | if c.bufr.Buffered() == 0 && c.PauseTimeout != 0 { 819 | err := c.readConn.SetReadDeadline(time.Now().Add(c.PauseTimeout)) 820 | if err != nil { 821 | return 0, err // deemed critical 822 | } 823 | } 824 | b, err := c.bufr.ReadByte() 825 | if err != nil { 826 | if errors.Is(err, io.EOF) { 827 | err = io.ErrUnexpectedEOF 828 | } 829 | return 0, fmt.Errorf("mqtt: header from packet %#b incomplete: %w", head, err) 830 | } 831 | size |= int(b&0x7f) << shift 832 | if b&0x80 == 0 { 833 | break 834 | } 835 | if shift > 21 { 836 | return 0, fmt.Errorf("%w: remaining length encoding from packet %#b exceeds 4 bytes", errProtoReset, head) 837 | } 838 | } 839 | 840 | // slice payload form read buffer 841 | for { 842 | if c.bufr.Buffered() < size && c.PauseTimeout != 0 { 843 | err := c.readConn.SetReadDeadline(time.Now().Add(c.PauseTimeout)) 844 | if err != nil { 845 | return 0, err // deemed critical 846 | } 847 | } 848 | 849 | lastN := len(c.peek) 850 | c.peek, err = c.bufr.Peek(size) 851 | switch { 852 | case err == nil: // OK 853 | return head, err 854 | case head>>4 == typePUBLISH && errors.Is(err, bufio.ErrBufferFull): 855 | return head, &BigMessage{Client: c, Size: size} 856 | } 857 | 858 
| // Allow deadline expiry if at least one byte was transferred. 859 | var ne net.Error 860 | if len(c.peek) > lastN && errors.As(err, &ne) && ne.Timeout() { 861 | continue 862 | } 863 | 864 | if errors.Is(err, io.EOF) { 865 | err = io.ErrUnexpectedEOF 866 | } 867 | return 0, fmt.Errorf("mqtt: got %d out of %d bytes from packet %#b: %w", 868 | len(c.peek), size, head, err) 869 | } 870 | } 871 | 872 | // Discard skips n bytes from the network connection. 873 | func (c *Client) discard(n int) error { 874 | if c.PauseTimeout != 0 { 875 | // Abandon timer to prevent waking up the system for no good reason. 876 | // https://developer.apple.com/library/archive/documentation/Performance/Conceptual/EnergyGuide-iOS/MinimizeTimerUse.html 877 | defer c.readConn.SetReadDeadline(time.Time{}) 878 | } 879 | 880 | for { 881 | if c.PauseTimeout != 0 { 882 | err := c.readConn.SetReadDeadline(time.Now().Add(c.PauseTimeout)) 883 | if err != nil { 884 | return err // deemed critical 885 | } 886 | } 887 | 888 | done, err := c.bufr.Discard(n) 889 | if err == nil { 890 | return nil 891 | } 892 | 893 | // Allow deadline expiry if at least one byte was transferred. 894 | var ne net.Error 895 | if done != 0 && errors.As(err, &ne) && ne.Timeout() { 896 | n -= done 897 | continue 898 | } 899 | 900 | return fmt.Errorf("mqtt: %d bytes remaining of packet discard: %w", 901 | n, err) 902 | } 903 | } 904 | 905 | // Connect installs the transport layer. 906 | // 907 | // The current connection must be closed in case of a reconnect. 908 | func (c *Client) connect() error { 909 | previousConn, ok := <-c.connSem // locks connection control 910 | if !ok { 911 | return ErrClosed 912 | } 913 | // No need for further closed channel checks as the 914 | // connSem lock is required to close any of them. 915 | 916 | config := c.Config // copy 917 | // Reconnects shouldn't reset the session. 
918 | if previousConn != nil { 919 | config.CleanSession = false 920 | } 921 | conn, bufr, err := c.dialAndConnect(&config) 922 | if err != nil { 923 | // ErrDown after failed connect 924 | <-c.writeSem 925 | c.writeSem <- connDown 926 | 927 | c.connSem <- previousConn // unlock 928 | return err 929 | } 930 | 931 | // lock sequences until resubmission (checks) complete 932 | atLeastOnceSeq := <-c.atLeastOnce.seqSem 933 | exactlyOnceSeq := <-c.exactlyOnce.seqSem 934 | 935 | // lock write in sequence locks, conform submitPersisted 936 | <-c.writeSem 937 | 938 | c.connSem <- conn // unlock (for interruption of resends) 939 | 940 | err = c.resend(conn, c.orderedTxs.Acked, &atLeastOnceSeq, atLeastOnceIDSpace) 941 | c.atLeastOnce.seqSem <- atLeastOnceSeq // unlock 942 | if err != nil { 943 | c.exactlyOnce.seqSem <- exactlyOnceSeq // unlock 944 | conn.Close() 945 | c.writeSem <- connDown 946 | return err 947 | } 948 | err = c.resend(conn, c.orderedTxs.Completed, &exactlyOnceSeq, exactlyOnceIDSpace) 949 | c.exactlyOnce.seqSem <- exactlyOnceSeq // unlock 950 | if err != nil { 951 | conn.Close() 952 | c.writeSem <- connDown 953 | return err 954 | } 955 | 956 | // halt Offline signal before connection release 957 | blockSignalChan(c.offlineSig) 958 | // connection release 959 | c.writeSem <- conn 960 | c.readConn = conn 961 | c.bufr = bufr 962 | // release Online signal after connection release 963 | clearSignalChan(c.onlineSig) 964 | // reset backoff ramp-up 965 | c.reconnectWait = 0 966 | return nil 967 | } 968 | 969 | func (c *Client) dialAndConnect(config *Config) (net.Conn, *bufio.Reader, error) { 970 | clientID, err := c.persistence.Load(clientIDKey) 971 | if err != nil { 972 | return nil, nil, err 973 | } 974 | 975 | // establish network connection 976 | ctx := c.connectCtx 977 | if c.PauseTimeout != 0 { 978 | var cancel context.CancelFunc 979 | ctx, cancel = context.WithTimeout(ctx, c.PauseTimeout) 980 | defer cancel() 981 | } 982 | conn, err := c.Dialer(ctx) 983 | if 
err != nil { 984 | switch { 985 | case errors.Is(err, context.Canceled): 986 | return nil, nil, fmt.Errorf("%w; dial cancelled", ErrClosed) 987 | case errors.Is(err, context.DeadlineExceeded): 988 | return nil, nil, fmt.Errorf("mqtt: dial timeout (after %s)", 989 | c.PauseTimeout) 990 | } 991 | return nil, nil, fmt.Errorf("mqtt: no connect: %w", err) 992 | } 993 | // “After a Network Connection is established by a Client to a Server, 994 | // the first Packet sent from the Client to the Server MUST be a CONNECT 995 | // Packet.” 996 | // — MQTT Version 3.1.1, conformance statement MQTT-3.1.0-1 997 | 998 | // The connection context applies to the handshake too, as we don't want 999 | // slow handshakes to block a shutdown from either Close or Disconnect. 1000 | stopConnClose := context.AfterFunc(c.connectCtx, func() { 1001 | conn.Close() 1002 | }) 1003 | bufr, err := c.handshake(conn, config, clientID) 1004 | if !stopConnClose() { 1005 | // connect context canceled (by either Close or Disconnect) 1006 | return nil, nil, ErrClosed 1007 | } 1008 | if err != nil { 1009 | closeErr := conn.Close() 1010 | return nil, nil, errors.Join(err, closeErr) 1011 | } 1012 | return conn, bufr, nil 1013 | } 1014 | 1015 | // Resend submits any and all pending since seqNoOffset. Sequence numbers count 1016 | // from zero. Each sequence number is one less than the respective accept count 1017 | // was at the time. 
func (c *Client) resend(conn net.Conn, seqNoOffset uint, seq *seq, space uint) error {
	for seqNo := seqNoOffset; seqNo < seq.acceptN; seqNo++ {
		// persistence keys are packet identifiers within the ID space
		key := seqNo&publishIDMask | space
		packet, err := c.persistence.Load(uint(key))
		if err != nil {
			return err
		}
		if packet == nil {
			return fmt.Errorf("mqtt: persistence key %#04x gone missing 👻", key)
		}

		// Packets submitted before count as potential duplicates.
		if seqNo < seq.submitN && packet[0]>>4 == typePUBLISH {
			packet[0] |= dupeFlag
		}

		err = writeTo(conn, packet, c.PauseTimeout)
		if err != nil {
			return err
		}

		// progress submission counter past this sequence number
		if seqNo >= seq.submitN {
			seq.submitN = seqNo + 1
		}
	}
	return nil
}

// handshake submits a CONNECT packet conform config, and it validates the
// CONNACK response from conn. The returned bufio.Reader wraps conn with the
// CONNACK consumed.
func (c *Client) handshake(conn net.Conn, config *Config, clientID []byte) (*bufio.Reader, error) {
	// send request
	err := writeTo(conn, config.newConnectReq(clientID), c.PauseTimeout)
	if err != nil {
		return nil, fmt.Errorf("mqtt: connection fatal during CONNECT submission; %w", err)
	}

	r := bufio.NewReaderSize(conn, readBufSize)

	// Apply the timeout to the "entire" 4-byte response.
	if c.PauseTimeout != 0 {
		err := conn.SetReadDeadline(time.Now().Add(c.PauseTimeout))
		if err != nil {
			return nil, err // deemed critical
		}
		defer conn.SetReadDeadline(time.Time{})
	}

	// “The first packet sent from the Server to the Client MUST be a
	// CONNACK Packet.”
	// — MQTT Version 3.1.1, conformance statement MQTT-3.2.0-1
	packet, err := r.Peek(4)
	// Check the header before err: a smaller packet may cause timeout
	// errors, while a header mismatch is the more precise diagnosis.
	if len(packet) > 1 && (packet[0] != typeCONNACK<<4 || packet[1] != 2) {
		return nil, fmt.Errorf("%w: want fixed CONNACK header 0x2002, got %#x", errProtoReset, packet)
	}
	if err != nil {
		if errors.Is(err, io.EOF) {
			err = errBrokerTerm
		}
		return nil, fmt.Errorf("%w; CONNECT not confirmed", err)
	}

	// Check the return code first to prevent confusion with flag appliance.
	//
	// “If a server sends a CONNACK packet containing a non-zero return code
	// it MUST set Session Present to 0.”
	// — MQTT Version 3.1.1, conformance statement MQTT-3.2.2-4
	if r := connectReturn(packet[3]); r != accepted {
		return nil, r
	}

	switch flags := packet[2]; flags {
	// “Bits 7-1 are reserved and MUST be set to 0.”
	default:
		return nil, fmt.Errorf("%w: CONNACK with reserved flags %#b",
			errProtoReset, flags)

	// “Bit 0 (SP1) is the Session Present Flag.”
	case 1:
		// “If the Server accepts a connection with CleanSession set to
		// 1, the Server MUST set Session Present to 0 in the CONNACK …”
		// — MQTT Version 3.1.1, conformance statement MQTT-3.2.2-1
		if config.CleanSession {
			return nil, fmt.Errorf("%w: CONNACK with session-present for clean-session request",
				errProtoReset)
		}

		// don't clear InNewSession (on reconnects)

	case 0:
		c.InNewSession.Store(true)
	}

	r.Discard(len(packet)) // no errors guaranteed
	return r, nil
}

// closed is a reusable, pre-closed channel for the no-backoff case.
var closed = make(chan struct{})

func init() {
	close(closed)
}

// ReadBackoff returns a channel which is closed once ReadSlices should be
// invoked again. The return is nil when ReadSlices was fatal, i.e., ErrClosed
// gets a nil channel which blocks. Idle time on connection loss is subject to
// ReconnectWaitMin and ReconnectWaitMax from Config.
func (c *Client) ReadBackoff(err error) <-chan struct{} {
	var idle time.Duration
	switch {
	case err == nil, c.bigMessage != nil:
		return closed // no backoff

	case errors.Is(err, ErrClosed):
		return nil // blocks

	case c.readConn != nil:
		// error came from Persistence ☠️
		idle = time.Second

	case IsConnectionRefused(err):
		// documented behaviour
		idle = c.ReconnectWaitMax

	default:
		// need reconnect
		idle = c.reconnectWait
		idle = max(idle, c.ReconnectWaitMin)
		idle = min(idle, c.ReconnectWaitMax)
		// exponential ramp-up is documented behaviour
		c.reconnectWait = idle * 2
	}

	wait := make(chan struct{})
	time.AfterFunc(idle, func() { close(wait) })
	return wait
}

// ReadSlices should be invoked consecutively from a single goroutine until
// ErrClosed. Each invocation acknowledges ownership of the previous return.
//
// Both message and topic are sliced from a read buffer. The bytes stop being
// valid at the next read. BigMessage leaves memory allocation beyond the read
// buffer as a choice to the consumer.
//
// Slow processing of the return may freeze the Client. Blocking operations
// require counter measures for stability:
//
//   - Start a goroutine with a copy of message and/or topic.
//   - Start a goroutine with the bytes parsed/unmarshalled.
//   - Persist message and/or topic. Then, continue from there.
//   - Apply low timeouts in a strict manner.
//
// Invocation should apply some backoff after errors other than BigMessage.
// Use of ReadBackoff comes recommended. See the Client example for a setup.
func (c *Client) ReadSlices() (message, topic []byte, err error) {
	message, topic, err = c.readSlices()
	switch {
	case err == c.bigMessage: // either nil or BigMessage
		break
	case errors.Is(err, ErrClosed):
		// no more reads to come; release any pending callbacks
		c.termCallbacks()
	}
	return
}

// readSlices does the actual work for ReadSlices.
func (c *Client) readSlices() (message, topic []byte, err error) {
	// auto connect
	if c.readConn == nil {
		if err = c.connect(); err != nil {
			return nil, nil, err
		}
	}

	// flush big message if any
	if c.bigMessage != nil {
		remaining := c.bigMessage.Size
		c.bigMessage = nil

		err = c.discard(remaining)
		if err != nil {
			c.toOffline()
			return nil, nil, err
		}
	}

	// skip previous packet, if any
	c.bufr.Discard(len(c.peek)) // no error guaranteed
	c.peek = nil

	// acknowledge previous packet, if any
	if len(c.pendingAck) != 0 {
		// BUG(pascaldekloe):
		// Save errors from Persistence may cause duplicate reception
		// of messages with the “exactly once” guarantee, but only in
		// a follow-up with AdoptSession, and only if the Client which
		// encountered the Persistence failure went down before its
		// automatic-recovery (from ReadSlices) succeeded.
		if c.pendingAck[0]>>4 == typePUBREC {
			// record reception to filter duplicates on reconnect
			key := uint(binary.BigEndian.Uint16(c.pendingAck[2:4])) | remoteIDKeyFlag
			err = c.persistence.Save(key, net.Buffers{c.pendingAck})
			if err != nil {
				return nil, nil, err
			}
		}
		err := c.write(nil, c.pendingAck)
		if err != nil {
			c.toOffline()
			return nil, nil, err // keeps pendingAck to retry
		}

		c.pendingAck = c.pendingAck[:0]
	}

	// process packets until a PUBLISH appears
	for {
		head, err := c.peekPacket()
		switch {
		case err == nil:
			break

		case errors.Is(err, net.ErrClosed) || errors.Is(err, io.ErrClosedPipe):
			// closed by either Close, Disconnect, or failed write
			c.toOffline()
			if err := c.connect(); err != nil {
				return nil, nil, err
			}
			continue // with new connection

		case errors.As(err, &c.bigMessage):
			// keys + topic under readBufSize thus in c.peek
			partialMessage, topic, err := c.onPUBLISH(head)
			if err != nil {
				if err != errDupe {
					c.toOffline()
					return nil, nil, err
				}

				// can just skip the already received
				payloadSize := c.bigMessage.Size
				c.bigMessage = nil
				c.peek = nil
				err := c.discard(payloadSize)
				if err != nil {
					return nil, nil, err
				}
				continue
			}

			// serve big message (as error)
			c.bigMessage.Topic = string(topic) // copy
			beforeMessage := readBufSize - len(partialMessage)
			c.bigMessage.Size -= beforeMessage
			c.peek = nil
			c.bufr.Discard(beforeMessage) // no errors guaranteed
			return nil, nil, c.bigMessage

		default:
			c.toOffline()
			return nil, nil, err
		}

		// dispatch conform the packet-type bits
		switch head >> 4 {
		case typeRESERVED0:
			err = errRESERVED0
		case typeCONNECT:
			err = errGotCONNECT
		case typeCONNACK:
			err = errCONNACKTwo
		case typePUBLISH:
			message, topic, err = c.onPUBLISH(head)
			if err == nil {
				return message, topic, nil
			}
			if err == errDupe {
				err = nil // can just skip
			}
		case typePUBACK:
			err = c.onPUBACK()
		case typePUBREC:
			err = c.onPUBREC()
		case typePUBREL:
			err = c.onPUBREL()
		case typePUBCOMP:
			err = c.onPUBCOMP()
		case typeSUBSCRIBE:
			err = errGotSUBSCRIBE
		case typeSUBACK:
			err = c.onSUBACK()
		case typeUNSUBSCRIBE:
			err = errGotUNSUBSCRIBE
		case typeUNSUBACK:
			err = c.onUNSUBACK()
		case typePINGREQ:
			err = errGotPINGREQ
		case typePINGRESP:
			err = c.onPINGRESP()
		case typeDISCONNECT:
			err = errGotDISCONNECT
		case typeRESERVED15:
			err = errRESERVED15
		}
		if err != nil {
			c.toOffline()
			return nil, nil, err
		}

		// no errors guaranteed
		c.bufr.Discard(len(c.peek))
	}
}

// BigMessage signals reception beyond the read buffer capacity.
// Receivers may or may not allocate the memory with ReadAll.
// The next ReadSlices will acknowledge reception either way.
type BigMessage struct {
	*Client        // source
	Topic   string // destination
	Size    int    // byte count
}

// Error implements the standard error interface.
func (e *BigMessage) Error() string {
	return fmt.Sprintf("mqtt: %d-byte message exceeds read buffer capacity", e.Size)
}

// ReadAll returns the message in a new/dedicated buffer. Messages can be read
// once at most. Read fails on second attempt. Read also fails after followup by
// another ReadSlices.
func (e *BigMessage) ReadAll() ([]byte, error) {
	// Reading is allowed only while this instance is the pending one.
	if e.Client.bigMessage != e {
		return nil, errors.New("mqtt: read window expired for a big message")
	}
	e.Client.bigMessage = nil

	message := make([]byte, e.Size)
	_, err := io.ReadFull(e.Client.bufr, message)
	if err != nil {
		return nil, err
	}
	return message, nil
}

// errDupe marks a PUBLISH which was already received, as deduced from the
// presence of its (remote) packet identifier in Persistence.
var errDupe = errors.New("mqtt: duplicate reception")

// onPUBLISH slices an inbound message from Client.peek.
func (c *Client) onPUBLISH(head byte) (message, topic []byte, err error) {
	if len(c.peek) < 2 {
		return nil, nil, fmt.Errorf("%w: PUBLISH with %d byte remaining length", errProtoReset, len(c.peek))
	}
	// i indexes the end of the topic string (2-byte length prefix)
	i := int(uint(binary.BigEndian.Uint16(c.peek))) + 2
	if i > len(c.peek) {
		return nil, nil, fmt.Errorf("%w: PUBLISH topic exceeds remaining length", errProtoReset)
	}
	topic = c.peek[2:i]

	// quality-of-service level from the flag bits
	switch head & 0b0110 {
	case atMostOnceLevel << 1:
		break

	case atLeastOnceLevel << 1:
		if len(c.peek) < i+2 {
			return nil, nil, fmt.Errorf("%w: PUBLISH packet identifier exceeds remaining length", errProtoReset)
		}
		packetID := binary.BigEndian.Uint16(c.peek[i:])
		if packetID == 0 {
			return nil, nil, errPacketIDZero
		}
		i += 2

		// enqueue PUBACK for next call
		if len(c.pendingAck) != 0 {
			return nil, nil, fmt.Errorf("mqtt: internal error: ack %#x pending during PUBLISH at least once reception", c.pendingAck)
		}
		c.pendingAck = append(c.pendingAck, typePUBACK<<4, 2, byte(packetID>>8), byte(packetID))

	case exactlyOnceLevel << 1:
		if len(c.peek) < i+2 {
			return nil, nil, fmt.Errorf("%w: PUBLISH packet identifier exceeds remaining length", errProtoReset)
		}
		packetID := uint(binary.BigEndian.Uint16(c.peek[i:]))
		if packetID == 0 {
			return nil, nil, errPacketIDZero
		}
		i += 2

		// A persisted entry under the remote key means the message was
		// already received (and acknowledged with PUBREC) before.
		bytes, err := c.persistence.Load(packetID | remoteIDKeyFlag)
		if err != nil {
			return nil, nil, err
		}
		if bytes != nil {
			return nil, nil, errDupe
		}

		// enqueue PUBREC for next call
		if len(c.pendingAck) != 0 {
			return nil, nil, fmt.Errorf("mqtt: internal error: ack %#x pending during PUBLISH exactly once reception", c.pendingAck)
		}
		c.pendingAck = append(c.pendingAck, typePUBREC<<4, 2, byte(packetID>>8), byte(packetID))

	default:
		return nil, nil, fmt.Errorf("%w: PUBLISH with reserved quality-of-service level 3", errProtoReset)
	}

	return c.peek[i:], topic, nil
}

// onPUBREL applies the second round-trip for “exactly-once” reception.
func (c *Client) onPUBREL() error {
	if len(c.peek) != 2 {
		return fmt.Errorf("%w: PUBREL with %d byte remaining length", errProtoReset, len(c.peek))
	}
	packetID := uint(binary.BigEndian.Uint16(c.peek))
	if packetID == 0 {
		return errPacketIDZero
	}

	// drop the duplicate-reception guard entry
	err := c.persistence.Delete(packetID | remoteIDKeyFlag)
	if err != nil {
		return err // causes resubmission of PUBREL
	}
	// Use pendingAck as a buffer here.
	if len(c.pendingAck) != 0 {
		return fmt.Errorf("mqtt: internal error: ack %#x pending during PUBREL reception", c.pendingAck)
	}
	c.pendingAck = append(c.pendingAck, typePUBCOMP<<4, 2, byte(packetID>>8), byte(packetID))
	err = c.write(nil, c.pendingAck)
	if err != nil {
		return err // causes resubmission of PUBCOMP
	}
	c.pendingAck = c.pendingAck[:0]
	return nil
}

// The write semaphore may hold a connSignal when not connected.
1451 | const ( 1452 | connPending connSignal = iota // first (re)connect attempt 1453 | connDown // failed (re)connect attempt 1454 | ) 1455 | 1456 | // ConnSignal is a net.Conn. 1457 | type connSignal int 1458 | 1459 | const connSignalInvoke = "signal invoked as connection" 1460 | 1461 | func (c connSignal) Read(b []byte) (n int, err error) { panic(connSignalInvoke) } 1462 | func (c connSignal) Write(b []byte) (int, error) { panic(connSignalInvoke) } 1463 | func (c connSignal) Close() error { panic(connSignalInvoke) } 1464 | func (c connSignal) LocalAddr() net.Addr { panic(connSignalInvoke) } 1465 | func (c connSignal) RemoteAddr() net.Addr { panic(connSignalInvoke) } 1466 | func (c connSignal) SetDeadline(time.Time) error { panic(connSignalInvoke) } 1467 | func (c connSignal) SetReadDeadline(time.Time) error { panic(connSignalInvoke) } 1468 | func (c connSignal) SetWriteDeadline(time.Time) error { panic(connSignalInvoke) } 1469 | --------------------------------------------------------------------------------