├── .github └── workflows │ └── build.yml ├── .gitignore ├── .golangci.yml ├── Dockerfile ├── LICENSE.md ├── README-CN.md ├── README-JP.md ├── README.md ├── clients.go ├── clients_test.go ├── cmd ├── docker │ └── main.go └── main.go ├── config.yaml ├── config ├── config.go └── config_test.go ├── examples ├── auth │ ├── basic │ │ └── main.go │ └── encoded │ │ ├── auth.json │ │ ├── auth.yaml │ │ └── main.go ├── benchmark │ └── main.go ├── config │ ├── config.json │ ├── config.yaml │ └── main.go ├── debug │ └── main.go ├── direct │ └── main.go ├── hooks │ └── main.go ├── paho.testing │ └── main.go ├── persistence │ ├── badger │ │ └── main.go │ ├── bolt │ │ └── main.go │ ├── pebble │ │ └── main.go │ └── redis │ │ └── main.go ├── tcp │ └── main.go ├── tls │ └── main.go └── websocket │ └── main.go ├── go.mod ├── go.sum ├── hooks.go ├── hooks ├── auth │ ├── allow_all.go │ ├── allow_all_test.go │ ├── auth.go │ ├── auth_test.go │ ├── ledger.go │ └── ledger_test.go ├── debug │ └── debug.go └── storage │ ├── badger │ ├── badger.go │ └── badger_test.go │ ├── bolt │ ├── bolt.go │ └── bolt_test.go │ ├── pebble │ ├── pebble.go │ └── pebble_test.go │ ├── redis │ ├── redis.go │ └── redis_test.go │ ├── storage.go │ └── storage_test.go ├── hooks_test.go ├── inflight.go ├── inflight_test.go ├── listeners ├── http_healthcheck.go ├── http_healthcheck_test.go ├── http_sysinfo.go ├── http_sysinfo_test.go ├── listeners.go ├── listeners_test.go ├── mock.go ├── mock_test.go ├── net.go ├── net_test.go ├── tcp.go ├── tcp_test.go ├── unixsock.go ├── unixsock_test.go ├── websocket.go └── websocket_test.go ├── mempool ├── bufpool.go └── bufpool_test.go ├── packets ├── codec.go ├── codec_test.go ├── codes.go ├── codes_test.go ├── fixedheader.go ├── fixedheader_test.go ├── packets.go ├── packets_test.go ├── properties.go ├── properties_test.go ├── tpackets.go └── tpackets_test.go ├── server.go ├── server_test.go ├── system ├── system.go └── system_test.go ├── topics.go └── topics_test.go 
/.github/workflows/build.yml: -------------------------------------------------------------------------------- 1 | name: build 2 | 3 | on: [push, pull_request] 4 | 5 | jobs: 6 | build: 7 | runs-on: ubuntu-latest 8 | steps: 9 | - uses: actions/checkout@v3 10 | - name: Set up Go 11 | uses: actions/setup-go@v3 12 | with: 13 | go-version: 1.21 14 | - name: Vet 15 | run: go vet ./... 16 | - name: Test 17 | run: go test -race ./... && echo true 18 | 19 | coverage: 20 | name: Test with Coverage 21 | runs-on: ubuntu-latest 22 | steps: 23 | - name: Set up Go 24 | uses: actions/setup-go@v3 25 | with: 26 | go-version: '1.21' 27 | - name: Check out code 28 | uses: actions/checkout@v3 29 | - name: Install dependencies 30 | run: | 31 | go mod download 32 | - name: Run Unit tests 33 | run: | 34 | go test -race -covermode atomic -coverprofile=covprofile ./... 35 | - name: Install goveralls 36 | run: go install github.com/mattn/goveralls@latest 37 | - name: Send coverage 38 | env: 39 | COVERALLS_TOKEN: ${{ secrets.GITHUB_TOKEN }} 40 | run: goveralls -coverprofile=covprofile -service=github 41 | 42 | docker: 43 | if: github.repository == 'mochi-mqtt/server' && startsWith(github.ref, 'refs/tags/v') 44 | runs-on: ubuntu-latest 45 | needs: build 46 | steps: 47 | - name: Checkout 48 | uses: actions/checkout@v4 49 | - name: Docker meta 50 | id: meta 51 | uses: docker/metadata-action@v5 52 | with: 53 | images: mochimqtt/server 54 | tags: | 55 | type=semver,pattern={{version}} 56 | type=semver,pattern={{major}}.{{minor}} 57 | type=raw,value=latest,enable=${{ endsWith(github.ref, 'main') }} 58 | - name: Login to Docker Hub 59 | uses: docker/login-action@v3 60 | with: 61 | username: ${{ secrets.DOCKERHUB_USERNAME }} 62 | password: ${{ secrets.DOCKERHUB_TOKEN }} 63 | - name: Set up Docker Buildx 64 | uses: docker/setup-buildx-action@v3 65 | - name: Build and push 66 | uses: docker/build-push-action@v5 67 | with: 68 | context: . 
69 | file: ./Dockerfile 70 | platforms: linux/amd64,linux/arm64 71 | push: true 72 | tags: ${{ steps.meta.outputs.tags }} 73 | labels: ${{ steps.meta.outputs.labels }} 74 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | cmd/mqtt 2 | .DS_Store 3 | *.db 4 | .idea 5 | vendor -------------------------------------------------------------------------------- /.golangci.yml: -------------------------------------------------------------------------------- 1 | linters: 2 | disable-all: false 3 | fix: false # Fix found issues (if it's supported by the linter). 4 | enable: 5 | # - asasalint 6 | # - asciicheck 7 | # - bidichk 8 | # - bodyclose 9 | # - containedctx 10 | # - contextcheck 11 | #- cyclop 12 | # - deadcode 13 | - decorder 14 | # - depguard 15 | # - dogsled 16 | # - dupl 17 | - durationcheck 18 | # - errchkjson 19 | # - errname 20 | - errorlint 21 | # - execinquery 22 | # - exhaustive 23 | # - exhaustruct 24 | # - exportloopref 25 | #- forcetypeassert 26 | #- forbidigo 27 | #- funlen 28 | #- gci 29 | # - gochecknoglobals 30 | # - gochecknoinits 31 | # - gocognit 32 | # - goconst 33 | # - gocritic 34 | - gocyclo 35 | - godot 36 | # - godox 37 | # - goerr113 38 | # - gofmt 39 | # - gofumpt 40 | # - goheader 41 | - goimports 42 | # - golint 43 | # - gomnd 44 | # - gomoddirectives 45 | # - gomodguard 46 | # - goprintffuncname 47 | - gosec 48 | - gosimple 49 | - govet 50 | # - grouper 51 | # - ifshort 52 | - importas 53 | - ineffassign 54 | # - interfacebloat 55 | # - interfacer 56 | # - ireturn 57 | # - lll 58 | # - maintidx 59 | # - makezero 60 | - maligned 61 | - misspell 62 | # - nakedret 63 | # - nestif 64 | # - nilerr 65 | # - nilnil 66 | # - nlreturn 67 | # - noctx 68 | # - nolintlint 69 | # - nonamedreturns 70 | # - nosnakecase 71 | # - nosprintfhostport 72 | # - paralleltest 73 | # - prealloc 74 | # - predeclared 75 | # - promlinter 
76 | - reassign 77 | # - revive 78 | # - rowserrcheck 79 | # - scopelint 80 | # - sqlclosecheck 81 | # - staticcheck 82 | # - structcheck 83 | # - stylecheck 84 | # - tagliatelle 85 | # - tenv 86 | # - testpackage 87 | # - thelper 88 | - tparallel 89 | # - typecheck 90 | - unconvert 91 | - unparam 92 | - unused 93 | - usestdlibvars 94 | # - varcheck 95 | # - varnamelen 96 | - wastedassign 97 | - whitespace 98 | # - wrapcheck 99 | # - wsl 100 | disable: 101 | - errcheck 102 | 103 | 104 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM golang:1.21.0-alpine3.18 AS builder 2 | 3 | RUN apk update 4 | RUN apk add git 5 | 6 | WORKDIR /app 7 | 8 | COPY go.mod ./ 9 | COPY go.sum ./ 10 | RUN go mod download 11 | 12 | COPY . ./ 13 | 14 | RUN go build -o /app/mochi ./cmd/docker 15 | 16 | FROM alpine 17 | 18 | WORKDIR / 19 | COPY --from=builder /app/mochi . 20 | 21 | ENTRYPOINT [ "/mochi" ] 22 | # Pass flags only: Go's flag.Parse stops at the first non-flag argument, so a 23 | # leading path here would cause --config to be silently ignored. 24 | CMD ["--config", "config.yaml"] -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | 2 | The MIT License (MIT) 3 | 4 | Copyright (c) 2023 Mochi-MQTT Organisation 5 | Copyright (c) 2019, 2022, 2023 Jonathan Blake (mochi-co) 6 | 7 | Permission is hereby granted, free of charge, to any person obtaining a copy 8 | of this software and associated documentation files (the "Software"), to deal 9 | in the Software without restriction, including without limitation the rights 10 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | copies of the Software, and to permit persons to whom the Software is 12 | furnished to do so, subject to the following conditions: 13 | 14 | The above copyright notice and this permission notice shall be included in all 15 | copies or substantial portions of the
Software. 16 | 17 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 20 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 23 | SOFTWARE. 24 | -------------------------------------------------------------------------------- /cmd/docker/main.go: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: MIT 2 | // SPDX-FileCopyrightText: 2023 mochi-mqtt 3 | // SPDX-FileContributor: dgduncan, mochi-co 4 | 5 | package main 6 | 7 | import ( 8 | "flag" 9 | "github.com/mochi-mqtt/server/v2/config" 10 | "log" 11 | "log/slog" 12 | "os" 13 | "os/signal" 14 | "syscall" 15 | 16 | mqtt "github.com/mochi-mqtt/server/v2" 17 | ) 18 | 19 | func main() { 20 | slog.SetDefault(slog.New(slog.NewTextHandler(os.Stdout, nil))) // set basic logger to ensure logs before configuration are in a consistent format 21 | 22 | configFile := flag.String("config", "config.yaml", "path to mochi config yaml or json file") 23 | flag.Parse() 24 | 25 | sigs := make(chan os.Signal, 1) 26 | done := make(chan bool, 1) 27 | signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM) 28 | go func() { 29 | <-sigs 30 | done <- true 31 | }() 32 | 33 | configBytes, err := os.ReadFile(*configFile) 34 | if err != nil { 35 | log.Fatal(err) 36 | } 37 | 38 | options, err := config.FromBytes(configBytes) 39 | if err != nil { 40 | log.Fatal(err) 41 | } 42 | 43 | server := mqtt.New(options) 44 | 45 | go func() { 46 | err := server.Serve() 47 | if err != nil { 48 | log.Fatal(err) 49 | } 50 | }() 51 | 52 | <-done 53 | server.Log.Warn("caught signal, stopping...") 54 | _ = server.Close() 55 | 
server.Log.Info("mochi mqtt shutdown complete") 56 | } 57 | -------------------------------------------------------------------------------- /cmd/main.go: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: MIT 2 | // SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co 3 | // SPDX-FileContributor: mochi-co 4 | 5 | package main 6 | 7 | import ( 8 | "crypto/tls" 9 | "flag" 10 | "log" 11 | "os" 12 | "os/signal" 13 | "syscall" 14 | 15 | mqtt "github.com/mochi-mqtt/server/v2" 16 | "github.com/mochi-mqtt/server/v2/hooks/auth" 17 | "github.com/mochi-mqtt/server/v2/listeners" 18 | ) 19 | 20 | func main() { 21 | tcpAddr := flag.String("tcp", ":1883", "network address for TCP listener") 22 | wsAddr := flag.String("ws", ":1882", "network address for Websocket listener") 23 | infoAddr := flag.String("info", ":8080", "network address for web info dashboard listener") 24 | tlsCertFile := flag.String("tls-cert-file", "", "TLS certificate file") 25 | tlsKeyFile := flag.String("tls-key-file", "", "TLS key file") 26 | flag.Parse() 27 | 28 | sigs := make(chan os.Signal, 1) 29 | done := make(chan bool, 1) 30 | signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM) 31 | go func() { 32 | <-sigs 33 | done <- true 34 | }() 35 | 36 | var tlsConfig *tls.Config 37 | 38 | if tlsCertFile != nil && tlsKeyFile != nil && *tlsCertFile != "" && *tlsKeyFile != "" { 39 | cert, err := tls.LoadX509KeyPair(*tlsCertFile, *tlsKeyFile) 40 | if err != nil { 41 | return 42 | } 43 | tlsConfig = &tls.Config{ 44 | Certificates: []tls.Certificate{cert}, 45 | } 46 | } 47 | 48 | server := mqtt.New(nil) 49 | _ = server.AddHook(new(auth.AllowHook), nil) 50 | 51 | tcp := listeners.NewTCP(listeners.Config{ 52 | ID: "t1", 53 | Address: *tcpAddr, 54 | TLSConfig: tlsConfig, 55 | }) 56 | err := server.AddListener(tcp) 57 | if err != nil { 58 | log.Fatal(err) 59 | } 60 | 61 | ws := listeners.NewWebsocket(listeners.Config{ 62 | ID: "ws1", 63 | Address: *wsAddr, 64 | 
}) 65 | err = server.AddListener(ws) 66 | if err != nil { 67 | log.Fatal(err) 68 | } 69 | 70 | stats := listeners.NewHTTPStats( 71 | listeners.Config{ 72 | ID: "info", 73 | Address: *infoAddr, 74 | }, 75 | server.Info, 76 | ) 77 | err = server.AddListener(stats) 78 | if err != nil { 79 | log.Fatal(err) 80 | } 81 | 82 | go func() { 83 | err := server.Serve() 84 | if err != nil { 85 | log.Fatal(err) 86 | } 87 | }() 88 | 89 | <-done 90 | server.Log.Warn("caught signal, stopping...") 91 | _ = server.Close() 92 | server.Log.Info("mochi mqtt shutdown complete") 93 | } 94 | -------------------------------------------------------------------------------- /config.yaml: -------------------------------------------------------------------------------- 1 | listeners: 2 | - type: "tcp" 3 | id: "tcp1" 4 | address: ":1883" 5 | - type: "ws" 6 | id: "ws1" 7 | address: ":1882" 8 | - type: "sysinfo" 9 | id: "stats" 10 | address: ":1880" 11 | hooks: 12 | auth: 13 | allow_all: true 14 | options: 15 | inline_client: true 16 | -------------------------------------------------------------------------------- /config/config.go: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: MIT 2 | // SPDX-FileCopyrightText: 2023 mochi-mqtt, mochi-co 3 | // SPDX-FileContributor: mochi-co 4 | 5 | package config 6 | 7 | import ( 8 | "encoding/json" 9 | "log/slog" 10 | "os" 11 | 12 | "github.com/mochi-mqtt/server/v2/hooks/auth" 13 | "github.com/mochi-mqtt/server/v2/hooks/debug" 14 | "github.com/mochi-mqtt/server/v2/hooks/storage/badger" 15 | "github.com/mochi-mqtt/server/v2/hooks/storage/bolt" 16 | "github.com/mochi-mqtt/server/v2/hooks/storage/pebble" 17 | "github.com/mochi-mqtt/server/v2/hooks/storage/redis" 18 | "github.com/mochi-mqtt/server/v2/listeners" 19 | "gopkg.in/yaml.v3" 20 | 21 | mqtt "github.com/mochi-mqtt/server/v2" 22 | ) 23 | 24 | // config defines the structure of configuration data to be parsed from a config source. 
25 | type config struct { 26 | Options mqtt.Options 27 | Listeners []listeners.Config `yaml:"listeners" json:"listeners"` 28 | HookConfigs HookConfigs `yaml:"hooks" json:"hooks"` 29 | LoggingConfig LoggingConfig `yaml:"logging" json:"logging"` 30 | } 31 | 32 | type LoggingConfig struct { 33 | Level string 34 | } 35 | 36 | func (lc LoggingConfig) ToLogger() *slog.Logger { 37 | var level slog.Level 38 | if err := level.UnmarshalText([]byte(lc.Level)); err != nil { 39 | level = slog.LevelInfo 40 | } 41 | 42 | leveler := new(slog.LevelVar) 43 | leveler.Set(level) 44 | return slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{ 45 | Level: leveler, 46 | })) 47 | } 48 | 49 | // HookConfigs contains configurations to enable individual hooks. 50 | type HookConfigs struct { 51 | Auth *HookAuthConfig `yaml:"auth" json:"auth"` 52 | Storage *HookStorageConfig `yaml:"storage" json:"storage"` 53 | Debug *debug.Options `yaml:"debug" json:"debug"` 54 | } 55 | 56 | // HookAuthConfig contains configurations for the auth hook. 57 | type HookAuthConfig struct { 58 | Ledger auth.Ledger `yaml:"ledger" json:"ledger"` 59 | AllowAll bool `yaml:"allow_all" json:"allow_all"` 60 | } 61 | 62 | // HookStorageConfig contains configurations for the different storage hooks. 63 | type HookStorageConfig struct { 64 | Badger *badger.Options `yaml:"badger" json:"badger"` 65 | Bolt *bolt.Options `yaml:"bolt" json:"bolt"` 66 | Pebble *pebble.Options `yaml:"pebble" json:"pebble"` 67 | Redis *redis.Options `yaml:"redis" json:"redis"` 68 | } 69 | 70 | // ToHooks converts Hook file configurations into Hooks to be added to the server. 71 | func (hc HookConfigs) ToHooks() []mqtt.HookLoadConfig { 72 | var hlc []mqtt.HookLoadConfig 73 | 74 | if hc.Auth != nil { 75 | hlc = append(hlc, hc.toHooksAuth()...) 76 | } 77 | 78 | if hc.Storage != nil { 79 | hlc = append(hlc, hc.toHooksStorage()...) 
80 | } 81 | 82 | if hc.Debug != nil { 83 | hlc = append(hlc, mqtt.HookLoadConfig{ 84 | Hook: new(debug.Hook), 85 | Config: hc.Debug, 86 | }) 87 | } 88 | 89 | return hlc 90 | } 91 | 92 | // toHooksAuth converts auth hook configurations into auth hooks. 93 | func (hc HookConfigs) toHooksAuth() []mqtt.HookLoadConfig { 94 | var hlc []mqtt.HookLoadConfig 95 | if hc.Auth.AllowAll { 96 | hlc = append(hlc, mqtt.HookLoadConfig{ 97 | Hook: new(auth.AllowHook), 98 | }) 99 | } else { 100 | hlc = append(hlc, mqtt.HookLoadConfig{ 101 | Hook: new(auth.Hook), 102 | Config: &auth.Options{ 103 | Ledger: &auth.Ledger{ // avoid copying sync.Locker 104 | Users: hc.Auth.Ledger.Users, 105 | Auth: hc.Auth.Ledger.Auth, 106 | ACL: hc.Auth.Ledger.ACL, 107 | }, 108 | }, 109 | }) 110 | } 111 | return hlc 112 | } 113 | 114 | // toHooksAuth converts storage hook configurations into storage hooks. 115 | func (hc HookConfigs) toHooksStorage() []mqtt.HookLoadConfig { 116 | var hlc []mqtt.HookLoadConfig 117 | if hc.Storage.Badger != nil { 118 | hlc = append(hlc, mqtt.HookLoadConfig{ 119 | Hook: new(badger.Hook), 120 | Config: hc.Storage.Badger, 121 | }) 122 | } 123 | 124 | if hc.Storage.Bolt != nil { 125 | hlc = append(hlc, mqtt.HookLoadConfig{ 126 | Hook: new(bolt.Hook), 127 | Config: hc.Storage.Bolt, 128 | }) 129 | } 130 | 131 | if hc.Storage.Redis != nil { 132 | hlc = append(hlc, mqtt.HookLoadConfig{ 133 | Hook: new(redis.Hook), 134 | Config: hc.Storage.Redis, 135 | }) 136 | } 137 | 138 | if hc.Storage.Pebble != nil { 139 | hlc = append(hlc, mqtt.HookLoadConfig{ 140 | Hook: new(pebble.Hook), 141 | Config: hc.Storage.Pebble, 142 | }) 143 | } 144 | return hlc 145 | } 146 | 147 | // FromBytes unmarshals a byte slice of JSON or YAML config data into a valid server options value. 148 | // Any hooks configurations are converted into Hooks using the toHooks methods in this package. 
149 | func FromBytes(b []byte) (*mqtt.Options, error) { 150 | c := new(config) 151 | o := mqtt.Options{} 152 | 153 | if len(b) == 0 { 154 | return nil, nil 155 | } 156 | 157 | if b[0] == '{' { 158 | err := json.Unmarshal(b, c) 159 | if err != nil { 160 | return nil, err 161 | } 162 | } else { 163 | err := yaml.Unmarshal(b, c) 164 | if err != nil { 165 | return nil, err 166 | } 167 | } 168 | 169 | o = c.Options 170 | o.Hooks = c.HookConfigs.ToHooks() 171 | o.Listeners = c.Listeners 172 | o.Logger = c.LoggingConfig.ToLogger() 173 | 174 | return &o, nil 175 | } 176 | -------------------------------------------------------------------------------- /config/config_test.go: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: MIT 2 | // SPDX-FileCopyrightText: 2023 mochi-mqtt, mochi-co 3 | // SPDX-FileContributor: mochi-co 4 | 5 | package config 6 | 7 | import ( 8 | "log/slog" 9 | "os" 10 | "testing" 11 | 12 | "github.com/stretchr/testify/require" 13 | 14 | "github.com/mochi-mqtt/server/v2/hooks/auth" 15 | "github.com/mochi-mqtt/server/v2/hooks/storage/badger" 16 | "github.com/mochi-mqtt/server/v2/hooks/storage/bolt" 17 | "github.com/mochi-mqtt/server/v2/hooks/storage/pebble" 18 | "github.com/mochi-mqtt/server/v2/hooks/storage/redis" 19 | "github.com/mochi-mqtt/server/v2/listeners" 20 | 21 | mqtt "github.com/mochi-mqtt/server/v2" 22 | ) 23 | 24 | var ( 25 | yamlBytes = []byte(` 26 | listeners: 27 | - type: "tcp" 28 | id: "file-tcp1" 29 | address: ":1883" 30 | hooks: 31 | auth: 32 | allow_all: true 33 | options: 34 | client_net_write_buffer_size: 2048 35 | capabilities: 36 | minimum_protocol_version: 3 37 | compatibilities: 38 | restore_sys_info_on_restart: true 39 | `) 40 | 41 | jsonBytes = []byte(`{ 42 | "listeners": [ 43 | { 44 | "type": "tcp", 45 | "id": "file-tcp1", 46 | "address": ":1883" 47 | } 48 | ], 49 | "hooks": { 50 | "auth": { 51 | "allow_all": true 52 | } 53 | }, 54 | "options": { 55 | 
"client_net_write_buffer_size": 2048, 56 | "capabilities": { 57 | "minimum_protocol_version": 3, 58 | "compatibilities": { 59 | "restore_sys_info_on_restart": true 60 | } 61 | } 62 | } 63 | } 64 | `) 65 | parsedOptions = mqtt.Options{ 66 | Listeners: []listeners.Config{ 67 | { 68 | Type: listeners.TypeTCP, 69 | ID: "file-tcp1", 70 | Address: ":1883", 71 | }, 72 | }, 73 | Hooks: []mqtt.HookLoadConfig{ 74 | { 75 | Hook: new(auth.AllowHook), 76 | }, 77 | }, 78 | ClientNetWriteBufferSize: 2048, 79 | Capabilities: &mqtt.Capabilities{ 80 | MinimumProtocolVersion: 3, 81 | Compatibilities: mqtt.Compatibilities{ 82 | RestoreSysInfoOnRestart: true, 83 | }, 84 | }, 85 | Logger: slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{ 86 | Level: new(slog.LevelVar), 87 | })), 88 | } 89 | ) 90 | 91 | func TestFromBytesEmptyL(t *testing.T) { 92 | _, err := FromBytes([]byte{}) 93 | require.NoError(t, err) 94 | } 95 | 96 | func TestFromBytesYAML(t *testing.T) { 97 | o, err := FromBytes(yamlBytes) 98 | require.NoError(t, err) 99 | require.Equal(t, parsedOptions, *o) 100 | } 101 | 102 | func TestFromBytesYAMLError(t *testing.T) { 103 | _, err := FromBytes(append(yamlBytes, 'a')) 104 | require.Error(t, err) 105 | } 106 | 107 | func TestFromBytesJSON(t *testing.T) { 108 | o, err := FromBytes(jsonBytes) 109 | require.NoError(t, err) 110 | require.Equal(t, parsedOptions, *o) 111 | } 112 | 113 | func TestFromBytesJSONError(t *testing.T) { 114 | _, err := FromBytes(append(jsonBytes, 'a')) 115 | require.Error(t, err) 116 | } 117 | 118 | func TestToHooksAuthAllowAll(t *testing.T) { 119 | hc := HookConfigs{ 120 | Auth: &HookAuthConfig{ 121 | AllowAll: true, 122 | }, 123 | } 124 | 125 | th := hc.toHooksAuth() 126 | expect := []mqtt.HookLoadConfig{ 127 | {Hook: new(auth.AllowHook)}, 128 | } 129 | require.Equal(t, expect, th) 130 | } 131 | 132 | func TestToHooksAuthAllowLedger(t *testing.T) { 133 | hc := HookConfigs{ 134 | Auth: &HookAuthConfig{ 135 | Ledger: auth.Ledger{ 136 | Auth: 
auth.AuthRules{ 137 | {Username: "peach", Password: "password1", Allow: true}, 138 | }, 139 | }, 140 | }, 141 | } 142 | 143 | th := hc.toHooksAuth() 144 | expect := []mqtt.HookLoadConfig{ 145 | { 146 | Hook: new(auth.Hook), 147 | Config: &auth.Options{ 148 | Ledger: &auth.Ledger{ // avoid copying sync.Locker 149 | Auth: auth.AuthRules{ 150 | {Username: "peach", Password: "password1", Allow: true}, 151 | }, 152 | }, 153 | }, 154 | }, 155 | } 156 | require.Equal(t, expect, th) 157 | } 158 | 159 | func TestToHooksStorageBadger(t *testing.T) { 160 | hc := HookConfigs{ 161 | Storage: &HookStorageConfig{ 162 | Badger: &badger.Options{ 163 | Path: "badger", 164 | }, 165 | }, 166 | } 167 | 168 | th := hc.toHooksStorage() 169 | expect := []mqtt.HookLoadConfig{ 170 | { 171 | Hook: new(badger.Hook), 172 | Config: hc.Storage.Badger, 173 | }, 174 | } 175 | 176 | require.Equal(t, expect, th) 177 | } 178 | 179 | func TestToHooksStorageBolt(t *testing.T) { 180 | hc := HookConfigs{ 181 | Storage: &HookStorageConfig{ 182 | Bolt: &bolt.Options{ 183 | Path: "bolt", 184 | Bucket: "mochi", 185 | }, 186 | }, 187 | } 188 | 189 | th := hc.toHooksStorage() 190 | expect := []mqtt.HookLoadConfig{ 191 | { 192 | Hook: new(bolt.Hook), 193 | Config: hc.Storage.Bolt, 194 | }, 195 | } 196 | 197 | require.Equal(t, expect, th) 198 | } 199 | 200 | func TestToHooksStorageRedis(t *testing.T) { 201 | hc := HookConfigs{ 202 | Storage: &HookStorageConfig{ 203 | Redis: &redis.Options{ 204 | Username: "test", 205 | }, 206 | }, 207 | } 208 | 209 | th := hc.toHooksStorage() 210 | expect := []mqtt.HookLoadConfig{ 211 | { 212 | Hook: new(redis.Hook), 213 | Config: hc.Storage.Redis, 214 | }, 215 | } 216 | 217 | require.Equal(t, expect, th) 218 | } 219 | 220 | func TestToHooksStoragePebble(t *testing.T) { 221 | hc := HookConfigs{ 222 | Storage: &HookStorageConfig{ 223 | Pebble: &pebble.Options{ 224 | Path: "pebble", 225 | }, 226 | }, 227 | } 228 | 229 | th := hc.toHooksStorage() 230 | expect := 
[]mqtt.HookLoadConfig{ 231 | { 232 | Hook: new(pebble.Hook), 233 | Config: hc.Storage.Pebble, 234 | }, 235 | } 236 | 237 | require.Equal(t, expect, th) 238 | } 239 | -------------------------------------------------------------------------------- /examples/auth/basic/main.go: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: MIT 2 | // SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co 3 | // SPDX-FileContributor: mochi-co 4 | 5 | package main 6 | 7 | import ( 8 | "log" 9 | "os" 10 | "os/signal" 11 | "syscall" 12 | 13 | mqtt "github.com/mochi-mqtt/server/v2" 14 | "github.com/mochi-mqtt/server/v2/hooks/auth" 15 | "github.com/mochi-mqtt/server/v2/listeners" 16 | ) 17 | 18 | func main() { 19 | sigs := make(chan os.Signal, 1) 20 | done := make(chan bool, 1) 21 | signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM) 22 | go func() { 23 | <-sigs 24 | done <- true 25 | }() 26 | 27 | authRules := &auth.Ledger{ 28 | Auth: auth.AuthRules{ // Auth disallows all by default 29 | {Username: "peach", Password: "password1", Allow: true}, 30 | {Username: "melon", Password: "password2", Allow: true}, 31 | {Remote: "127.0.0.1:*", Allow: true}, 32 | {Remote: "localhost:*", Allow: true}, 33 | }, 34 | ACL: auth.ACLRules{ // ACL allows all by default 35 | {Remote: "127.0.0.1:*"}, // local superuser allow all 36 | { 37 | // user melon can read and write to their own topic 38 | Username: "melon", Filters: auth.Filters{ 39 | "melon/#": auth.ReadWrite, 40 | "updates/#": auth.WriteOnly, // can write to updates, but can't read updates from others 41 | }, 42 | }, 43 | { 44 | // Otherwise, no clients have publishing permissions 45 | Filters: auth.Filters{ 46 | "#": auth.ReadOnly, 47 | "updates/#": auth.Deny, 48 | }, 49 | }, 50 | }, 51 | } 52 | 53 | // you may also find this useful... 
54 | // d, _ := authRules.ToYAML() 55 | // d, _ := authRules.ToJSON() 56 | // fmt.Println(string(d)) 57 | 58 | server := mqtt.New(nil) 59 | err := server.AddHook(new(auth.Hook), &auth.Options{ 60 | Ledger: authRules, 61 | }) 62 | if err != nil { 63 | log.Fatal(err) 64 | } 65 | 66 | tcp := listeners.NewTCP(listeners.Config{ 67 | ID: "t1", 68 | Address: ":1883", 69 | }) 70 | err = server.AddListener(tcp) 71 | if err != nil { 72 | log.Fatal(err) 73 | } 74 | 75 | go func() { 76 | err := server.Serve() 77 | if err != nil { 78 | log.Fatal(err) 79 | } 80 | }() 81 | 82 | <-done 83 | server.Log.Warn("caught signal, stopping...") 84 | _ = server.Close() 85 | server.Log.Info("main.go finished") 86 | } 87 | -------------------------------------------------------------------------------- /examples/auth/encoded/auth.json: -------------------------------------------------------------------------------- 1 | { 2 | "auth": [ 3 | { 4 | "username": "peach", 5 | "password": "password1", 6 | "allow": true 7 | }, 8 | { 9 | "username": "melon", 10 | "password": "password2", 11 | "allow": true 12 | }, 13 | { 14 | "remote": "127.0.0.1:*", 15 | "allow": false 16 | }, 17 | { 18 | "remote": "localhost:*", 19 | "allow": false 20 | } 21 | ], 22 | "acl": [ 23 | { 24 | "remote": "127.0.0.1:*" 25 | }, 26 | { 27 | "username": "melon", 28 | "filters": { 29 | "melon/#": 3, 30 | "updates/#": 2 31 | } 32 | }, 33 | { 34 | "filters": { 35 | "#": 1, 36 | "updates/#": 0 37 | } 38 | } 39 | ] 40 | } -------------------------------------------------------------------------------- /examples/auth/encoded/auth.yaml: -------------------------------------------------------------------------------- 1 | auth: 2 | - username: peach 3 | password: password1 4 | allow: true 5 | - username: melon 6 | password: password2 7 | allow: true 8 | # - remote: 127.0.0.1:* 9 | # allow: true 10 | # - remote: localhost:* 11 | # allow: true 12 | acl: 13 | # 0 = deny, 1 = read only, 2 = write only, 3 = read and write 14 | - remote: 
127.0.0.1:* 15 | - username: melon 16 | filters: 17 | melon/#: 3 18 | updates/#: 2 19 | - filters: 20 | '#': 1 21 | updates/#: 0 -------------------------------------------------------------------------------- /examples/auth/encoded/main.go: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: MIT 2 | // SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co 3 | // SPDX-FileContributor: mochi-co 4 | 5 | package main 6 | 7 | import ( 8 | "flag" 9 | "log" 10 | "os" 11 | "os/signal" 12 | "syscall" 13 | 14 | mqtt "github.com/mochi-mqtt/server/v2" 15 | "github.com/mochi-mqtt/server/v2/hooks/auth" 16 | "github.com/mochi-mqtt/server/v2/listeners" 17 | ) 18 | 19 | func main() { 20 | sigs := make(chan os.Signal, 1) 21 | done := make(chan bool, 1) 22 | signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM) 23 | go func() { 24 | <-sigs 25 | done <- true 26 | }() 27 | 28 | // You can also run from top-level server.go folder: 29 | // go run examples/auth/encoded/main.go --path=examples/auth/encoded/auth.yaml 30 | // go run examples/auth/encoded/main.go --path=examples/auth/encoded/auth.json 31 | path := flag.String("path", "auth.yaml", "path to data auth file") 32 | flag.Parse() 33 | 34 | // Get ledger from yaml file 35 | data, err := os.ReadFile(*path) 36 | if err != nil { 37 | log.Fatal(err) 38 | } 39 | 40 | server := mqtt.New(nil) 41 | err = server.AddHook(new(auth.Hook), &auth.Options{ 42 | Data: data, // build ledger from byte slice, yaml or json 43 | }) 44 | if err != nil { 45 | log.Fatal(err) 46 | } 47 | 48 | tcp := listeners.NewTCP(listeners.Config{ 49 | ID: "t1", 50 | Address: ":1883", 51 | }) 52 | err = server.AddListener(tcp) 53 | if err != nil { 54 | log.Fatal(err) 55 | } 56 | 57 | go func() { 58 | err := server.Serve() 59 | if err != nil { 60 | log.Fatal(err) 61 | } 62 | }() 63 | 64 | <-done 65 | server.Log.Warn("caught signal, stopping...") 66 | _ = server.Close() 67 | server.Log.Info("main.go finished") 68 | } 
69 | -------------------------------------------------------------------------------- /examples/benchmark/main.go: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: MIT 2 | // SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co 3 | // SPDX-FileContributor: mochi-co 4 | 5 | package main 6 | 7 | import ( 8 | "flag" 9 | "log" 10 | "os" 11 | "os/signal" 12 | "syscall" 13 | 14 | mqtt "github.com/mochi-mqtt/server/v2" 15 | "github.com/mochi-mqtt/server/v2/hooks/auth" 16 | "github.com/mochi-mqtt/server/v2/listeners" 17 | ) 18 | 19 | func main() { 20 | tcpAddr := flag.String("tcp", ":1883", "network address for TCP listener") 21 | flag.Parse() 22 | 23 | sigs := make(chan os.Signal, 1) 24 | done := make(chan bool, 1) 25 | signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM) 26 | go func() { 27 | <-sigs 28 | done <- true 29 | }() 30 | 31 | server := mqtt.New(nil) 32 | server.Options.Capabilities.MaximumClientWritesPending = 16 * 1024 33 | _ = server.AddHook(new(auth.AllowHook), nil) 34 | 35 | tcp := listeners.NewTCP(listeners.Config{ 36 | ID: "t1", 37 | Address: *tcpAddr, 38 | }) 39 | err := server.AddListener(tcp) 40 | if err != nil { 41 | log.Fatal(err) 42 | } 43 | 44 | go func() { 45 | err := server.Serve() 46 | if err != nil { 47 | log.Fatal(err) 48 | } 49 | }() 50 | 51 | <-done 52 | server.Log.Warn("caught signal, stopping...") 53 | _ = server.Close() 54 | server.Log.Info("main.go finished") 55 | } 56 | -------------------------------------------------------------------------------- /examples/config/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "listeners": [ 3 | { 4 | "type": "tcp", 5 | "id": "file-tcp1", 6 | "address": ":1883" 7 | }, 8 | { 9 | "type": "ws", 10 | "id": "file-websocket", 11 | "address": ":1882" 12 | }, 13 | { 14 | "type": "healthcheck", 15 | "id": "file-healthcheck", 16 | "address": ":1880" 17 | } 18 | ], 19 | "hooks": { 20 | "debug": { 
21 | "enable": true 22 | }, 23 | "storage": { 24 | "pebble": { 25 | "path": "pebble.db", 26 | "mode": "NoSync" 27 | }, 28 | "badger": { 29 | "path": "badger.db", 30 | "gc_interval": 3, 31 | "gc_discard_ratio": 0.5 32 | }, 33 | "bolt": { 34 | "path": "bolt.db", 35 | "bucket": "mochi" 36 | }, 37 | "redis": { 38 | "h_prefix": "mc", 39 | "username": "mochi", 40 | "password": "melon", 41 | "address": "localhost:6379", 42 | "database": 1 43 | } 44 | }, 45 | "auth": { 46 | "allow_all": false, 47 | "ledger": { 48 | "auth": [ 49 | { 50 | "username": "peach", 51 | "password": "password1", 52 | "allow": true 53 | } 54 | ], 55 | "acl": [ 56 | { 57 | "remote": "127.0.0.1:*" 58 | }, 59 | { 60 | "username": "melon", 61 | "filters": { 62 | "melon/#": 3, 63 | "updates/#": 2 64 | } 65 | } 66 | ] 67 | } 68 | } 69 | }, 70 | "options": { 71 | "client_net_write_buffer_size": 2048, 72 | "client_net_read_buffer_size": 2048, 73 | "sys_topic_resend_interval": 10, 74 | "inline_client": true, 75 | "capabilities": { 76 | "maximum_message_expiry_interval": 100, 77 | "maximum_client_writes_pending": 8192, 78 | "maximum_session_expiry_interval": 86400, 79 | "maximum_packet_size": 0, 80 | "receive_maximum": 1024, 81 | "maximum_inflight": 8192, 82 | "topic_alias_maximum": 65535, 83 | "shared_sub_available": 1, 84 | "minimum_protocol_version": 3, 85 | "maximum_qos": 2, 86 | "retain_available": 1, 87 | "wildcard_sub_available": 1, 88 | "sub_id_available": 1, 89 | "compatibilities": { 90 | "obscure_not_authorized": true, 91 | "passive_client_disconnect": false, 92 | "always_return_response_info": false, 93 | "restore_sys_info_on_restart": false, 94 | "no_inherited_properties_on_ack": false 95 | } 96 | } 97 | } 98 | } -------------------------------------------------------------------------------- /examples/config/config.yaml: -------------------------------------------------------------------------------- 1 | listeners: 2 | - type: "tcp" 3 | id: "file-tcp1" 4 | address: ":1883" 5 | - type: "ws" 6 | id: 
"file-websocket" 7 | address: ":1882" 8 | - type: "healthcheck" 9 | id: "file-healthcheck" 10 | address: ":1880" 11 | hooks: 12 | debug: 13 | enable: true 14 | storage: 15 | badger: 16 | path: badger.db 17 | gc_interval: 3 18 | gc_discard_ratio: 0.5 19 | pebble: 20 | path: pebble.db 21 | mode: "NoSync" 22 | bolt: 23 | path: bolt.db 24 | bucket: "mochi" 25 | redis: 26 | h_prefix: "mc" 27 | username: "mochi" 28 | password: "melon" 29 | address: "localhost:6379" 30 | database: 1 31 | auth: 32 | allow_all: false 33 | ledger: 34 | auth: 35 | - username: peach 36 | password: password1 37 | allow: true 38 | acl: 39 | - remote: 127.0.0.1:* 40 | - username: melon 41 | filters: 42 | melon/#: 3 43 | updates/#: 2 44 | options: 45 | client_net_write_buffer_size: 2048 46 | client_net_read_buffer_size: 2048 47 | sys_topic_resend_interval: 10 48 | inline_client: true 49 | capabilities: 50 | maximum_message_expiry_interval: 100 51 | maximum_client_writes_pending: 8192 52 | maximum_session_expiry_interval: 86400 53 | maximum_packet_size: 0 54 | receive_maximum: 1024 55 | maximum_inflight: 8192 56 | topic_alias_maximum: 65535 57 | shared_sub_available: 1 58 | minimum_protocol_version: 3 59 | maximum_qos: 2 60 | retain_available: 1 61 | wildcard_sub_available: 1 62 | sub_id_available: 1 63 | compatibilities: 64 | obscure_not_authorized: true 65 | passive_client_disconnect: false 66 | always_return_response_info: false 67 | restore_sys_info_on_restart: false 68 | no_inherited_properties_on_ack: false 69 | logging: 70 | level: INFO 71 | -------------------------------------------------------------------------------- /examples/config/main.go: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: MIT 2 | // SPDX-FileCopyrightText: 2023 mochi-mqtt, mochi-co 3 | // SPDX-FileContributor: mochi-co 4 | 5 | package main 6 | 7 | import ( 8 | "github.com/mochi-mqtt/server/v2/config" 9 | "log" 10 | "os" 11 | "os/signal" 12 | "syscall" 13 | 
14 | mqtt "github.com/mochi-mqtt/server/v2" 15 | ) 16 | 17 | func main() { 18 | sigs := make(chan os.Signal, 1) 19 | done := make(chan bool, 1) 20 | signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM) 21 | go func() { 22 | <-sigs 23 | done <- true 24 | }() 25 | 26 | configBytes, err := os.ReadFile("config.json") 27 | if err != nil { 28 | log.Fatal(err) 29 | } 30 | 31 | options, err := config.FromBytes(configBytes) 32 | if err != nil { 33 | log.Fatal(err) 34 | } 35 | 36 | server := mqtt.New(options) 37 | 38 | go func() { 39 | err := server.Serve() 40 | if err != nil { 41 | log.Fatal(err) 42 | } 43 | }() 44 | 45 | <-done 46 | server.Log.Warn("caught signal, stopping...") 47 | _ = server.Close() 48 | server.Log.Info("main.go finished") 49 | } 50 | -------------------------------------------------------------------------------- /examples/debug/main.go: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: MIT 2 | // SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co 3 | // SPDX-FileContributor: mochi-co 4 | 5 | package main 6 | 7 | import ( 8 | "log" 9 | "log/slog" 10 | "os" 11 | "os/signal" 12 | "syscall" 13 | 14 | mqtt "github.com/mochi-mqtt/server/v2" 15 | "github.com/mochi-mqtt/server/v2/hooks/auth" 16 | "github.com/mochi-mqtt/server/v2/hooks/debug" 17 | "github.com/mochi-mqtt/server/v2/listeners" 18 | ) 19 | 20 | func main() { 21 | sigs := make(chan os.Signal, 1) 22 | done := make(chan bool, 1) 23 | signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM) 24 | go func() { 25 | <-sigs 26 | done <- true 27 | }() 28 | 29 | server := mqtt.New(nil) 30 | 31 | level := new(slog.LevelVar) 32 | server.Log = slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{ 33 | Level: level, 34 | })) 35 | level.Set(slog.LevelDebug) 36 | 37 | err := server.AddHook(new(debug.Hook), &debug.Options{ 38 | // ShowPacketData: true, 39 | }) 40 | if err != nil { 41 | log.Fatal(err) 42 | } 43 | 44 | err = 
server.AddHook(new(auth.AllowHook), nil) 45 | if err != nil { 46 | log.Fatal(err) 47 | } 48 | 49 | tcp := listeners.NewTCP(listeners.Config{ 50 | ID: "t1", 51 | Address: ":1883", 52 | }) 53 | err = server.AddListener(tcp) 54 | if err != nil { 55 | log.Fatal(err) 56 | } 57 | 58 | go func() { 59 | err := server.Serve() 60 | if err != nil { 61 | log.Fatal(err) 62 | } 63 | }() 64 | 65 | <-done 66 | server.Log.Warn("caught signal, stopping...") 67 | _ = server.Close() 68 | server.Log.Info("main.go finished") 69 | } 70 | -------------------------------------------------------------------------------- /examples/direct/main.go: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: MIT 2 | // SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co 3 | // SPDX-FileContributor: mochi-co 4 | 5 | package main 6 | 7 | import ( 8 | "log" 9 | "os" 10 | "os/signal" 11 | "syscall" 12 | "time" 13 | 14 | "github.com/mochi-mqtt/server/v2/hooks/auth" 15 | 16 | mqtt "github.com/mochi-mqtt/server/v2" 17 | "github.com/mochi-mqtt/server/v2/packets" 18 | ) 19 | 20 | func main() { 21 | sigs := make(chan os.Signal, 1) 22 | done := make(chan bool, 1) 23 | signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM) 24 | go func() { 25 | <-sigs 26 | done <- true 27 | }() 28 | 29 | server := mqtt.New(&mqtt.Options{ 30 | InlineClient: true, // you must enable inline client to use direct publishing and subscribing. 31 | }) 32 | _ = server.AddHook(new(auth.AllowHook), nil) 33 | 34 | // Start the server 35 | go func() { 36 | err := server.Serve() 37 | if err != nil { 38 | log.Fatal(err) 39 | } 40 | }() 41 | 42 | // Demonstration of using an inline client to directly subscribe to a topic and receive a message when 43 | // that subscription is activated. The inline subscription method uses the same internal subscription logic 44 | // as used for external (normal) clients. 
45 | go func() { 46 | // Inline subscriptions can also receive retained messages on subscription. 47 | _ = server.Publish("direct/retained", []byte("retained message"), true, 0) 48 | _ = server.Publish("direct/alternate/retained", []byte("some other retained message"), true, 0) 49 | 50 | // Subscribe to a filter and handle any received messages via a callback function. 51 | callbackFn := func(cl *mqtt.Client, sub packets.Subscription, pk packets.Packet) { 52 | server.Log.Info("inline client received message from subscription", "client", cl.ID, "subscriptionId", sub.Identifier, "topic", pk.TopicName, "payload", string(pk.Payload)) 53 | } 54 | server.Log.Info("inline client subscribing") 55 | _ = server.Subscribe("direct/#", 1, callbackFn) 56 | _ = server.Subscribe("direct/#", 2, callbackFn) 57 | }() 58 | 59 | // There is a shorthand convenience function, Publish, for easily sending publish packets if you are not 60 | // concerned with creating your own packets. If you want to have more control over your packets, you can 61 | //directly inject a packet of any kind into the broker. See examples/hooks/main.go for usage. 62 | go func() { 63 | for range time.Tick(time.Second * 3) { 64 | err := server.Publish("direct/publish", []byte("scheduled message"), false, 0) 65 | if err != nil { 66 | server.Log.Error("server.Publish", "error", err) 67 | } 68 | server.Log.Info("main.go issued direct message to direct/publish") 69 | } 70 | }() 71 | 72 | go func() { 73 | time.Sleep(time.Second * 10) 74 | // Unsubscribe from the same filter to stop receiving messages. 
75 | server.Log.Info("inline client unsubscribing") 76 | _ = server.Unsubscribe("direct/#", 1) 77 | }() 78 | 79 | <-done 80 | server.Log.Warn("caught signal, stopping...") 81 | _ = server.Close() 82 | server.Log.Info("main.go finished") 83 | } 84 | -------------------------------------------------------------------------------- /examples/hooks/main.go: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: MIT 2 | // SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co 3 | // SPDX-FileContributor: mochi-co 4 | 5 | package main 6 | 7 | import ( 8 | "bytes" 9 | "fmt" 10 | "log" 11 | "os" 12 | "os/signal" 13 | "syscall" 14 | "time" 15 | 16 | mqtt "github.com/mochi-mqtt/server/v2" 17 | "github.com/mochi-mqtt/server/v2/hooks/auth" 18 | "github.com/mochi-mqtt/server/v2/listeners" 19 | "github.com/mochi-mqtt/server/v2/packets" 20 | ) 21 | 22 | func main() { 23 | sigs := make(chan os.Signal, 1) 24 | done := make(chan bool, 1) 25 | signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM) 26 | go func() { 27 | <-sigs 28 | done <- true 29 | }() 30 | 31 | server := mqtt.New(&mqtt.Options{ 32 | InlineClient: true, // you must enable inline client to use direct publishing and subscribing. 33 | }) 34 | 35 | _ = server.AddHook(new(auth.AllowHook), nil) 36 | tcp := listeners.NewTCP(listeners.Config{ 37 | ID: "t1", 38 | Address: ":1883", 39 | }) 40 | err := server.AddListener(tcp) 41 | if err != nil { 42 | log.Fatal(err) 43 | } 44 | 45 | // Add custom hook (ExampleHook) to the server 46 | err = server.AddHook(new(ExampleHook), &ExampleHookOptions{ 47 | Server: server, 48 | }) 49 | 50 | if err != nil { 51 | log.Fatal(err) 52 | } 53 | 54 | // Start the server 55 | go func() { 56 | err := server.Serve() 57 | if err != nil { 58 | log.Fatal(err) 59 | } 60 | }() 61 | 62 | // Demonstration of directly publishing messages to a topic via the 63 | // `server.Publish` method. 
Subscribe to `direct/publish` using your 64 | // MQTT client to see the messages. 65 | go func() { 66 | cl := server.NewClient(nil, "local", "inline", true) 67 | for range time.Tick(time.Second * 1) { 68 | err := server.InjectPacket(cl, packets.Packet{ 69 | FixedHeader: packets.FixedHeader{ 70 | Type: packets.Publish, 71 | }, 72 | TopicName: "direct/publish", 73 | Payload: []byte("injected scheduled message"), 74 | }) 75 | if err != nil { 76 | server.Log.Error("server.InjectPacket", "error", err) 77 | } 78 | server.Log.Info("main.go injected packet to direct/publish") 79 | } 80 | }() 81 | 82 | // There is also a shorthand convenience function, Publish, for easily sending 83 | // publish packets if you are not concerned with creating your own packets. 84 | go func() { 85 | for range time.Tick(time.Second * 5) { 86 | err := server.Publish("direct/publish", []byte("packet scheduled message"), false, 0) 87 | if err != nil { 88 | server.Log.Error("server.Publish", "error", err) 89 | } 90 | server.Log.Info("main.go issued direct message to direct/publish") 91 | } 92 | }() 93 | 94 | <-done 95 | server.Log.Warn("caught signal, stopping...") 96 | _ = server.Close() 97 | server.Log.Info("main.go finished") 98 | } 99 | 100 | // Options contains configuration settings for the hook. 
101 | type ExampleHookOptions struct { 102 | Server *mqtt.Server 103 | } 104 | // ExampleHook is a demonstration hook which logs client connection, subscription and publish events. 105 | type ExampleHook struct { 106 | mqtt.HookBase 107 | config *ExampleHookOptions 108 | } 109 | 110 | func (h *ExampleHook) ID() string { 111 | return "events-example" 112 | } 113 | 114 | func (h *ExampleHook) Provides(b byte) bool { 115 | return bytes.Contains([]byte{ 116 | mqtt.OnConnect, 117 | mqtt.OnDisconnect, 118 | mqtt.OnSubscribed, 119 | mqtt.OnUnsubscribed, 120 | mqtt.OnPublished, 121 | mqtt.OnPublish, 122 | }, []byte{b}) 123 | } 124 | // Init stores the hook configuration. It returns mqtt.ErrInvalidConfigType (instead of panicking) when config is nil, is not an *ExampleHookOptions, or carries no Server. 125 | func (h *ExampleHook) Init(config any) error { 126 | h.Log.Info("initialised") 127 | if _, ok := config.(*ExampleHookOptions); !ok { 128 | return mqtt.ErrInvalidConfigType 129 | } 130 | // Cannot panic: the check above already rejected a nil config and any other type. 131 | h.config = config.(*ExampleHookOptions) 132 | if h.config == nil || h.config.Server == nil { 133 | return mqtt.ErrInvalidConfigType 134 | } 135 | return nil 136 | } 137 | 138 | // subscribeCallback handles messages for subscribed topics 139 | func (h *ExampleHook) subscribeCallback(cl *mqtt.Client, sub packets.Subscription, pk packets.Packet) { 140 | h.Log.Info("hook subscribed message", "client", cl.ID, "topic", pk.TopicName) 141 | } 142 | 143 | func (h *ExampleHook) OnConnect(cl *mqtt.Client, pk packets.Packet) error { 144 | h.Log.Info("client connected", "client", cl.ID) 145 | 146 | // Example demonstrating how to subscribe to a topic within the hook.
147 | h.config.Server.Subscribe("hook/direct/publish", 1, h.subscribeCallback) 148 | 149 | // Example demonstrating how to publish a message within the hook 150 | err := h.config.Server.Publish("hook/direct/publish", []byte("packet hook message"), false, 0) 151 | if err != nil { 152 | h.Log.Error("hook.publish", "error", err) 153 | } 154 | 155 | return nil 156 | } 157 | 158 | func (h *ExampleHook) OnDisconnect(cl *mqtt.Client, err error, expire bool) { 159 | if err != nil { 160 | h.Log.Info("client disconnected", "client", cl.ID, "expire", expire, "error", err) 161 | } else { 162 | h.Log.Info("client disconnected", "client", cl.ID, "expire", expire) 163 | } 164 | 165 | } 166 | 167 | func (h *ExampleHook) OnSubscribed(cl *mqtt.Client, pk packets.Packet, reasonCodes []byte) { 168 | h.Log.Info(fmt.Sprintf("subscribed qos=%v", reasonCodes), "client", cl.ID, "filters", pk.Filters) 169 | } 170 | 171 | func (h *ExampleHook) OnUnsubscribed(cl *mqtt.Client, pk packets.Packet) { 172 | h.Log.Info("unsubscribed", "client", cl.ID, "filters", pk.Filters) 173 | } 174 | 175 | func (h *ExampleHook) OnPublish(cl *mqtt.Client, pk packets.Packet) (packets.Packet, error) { 176 | h.Log.Info("received from client", "client", cl.ID, "payload", string(pk.Payload)) 177 | 178 | pkx := pk 179 | if string(pk.Payload) == "hello" { 180 | pkx.Payload = []byte("hello world") 181 | h.Log.Info("received modified packet from client", "client", cl.ID, "payload", string(pkx.Payload)) 182 | } 183 | 184 | return pkx, nil 185 | } 186 | 187 | func (h *ExampleHook) OnPublished(cl *mqtt.Client, pk packets.Packet) { 188 | h.Log.Info("published to client", "client", cl.ID, "payload", string(pk.Payload)) 189 | } 190 | -------------------------------------------------------------------------------- /examples/paho.testing/main.go: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: MIT 2 | // SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co 3 | // 
SPDX-FileContributor: mochi-co 4 | 5 | package main 6 | 7 | import ( 8 | "bytes" 9 | "log" 10 | "os" 11 | "os/signal" 12 | "syscall" 13 | 14 | mqtt "github.com/mochi-mqtt/server/v2" 15 | "github.com/mochi-mqtt/server/v2/listeners" 16 | "github.com/mochi-mqtt/server/v2/packets" 17 | ) 18 | 19 | func main() { 20 | sigs := make(chan os.Signal, 1) 21 | done := make(chan bool, 1) 22 | signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM) 23 | go func() { 24 | <-sigs 25 | done <- true 26 | }() 27 | 28 | server := mqtt.New(nil) 29 | server.Options.Capabilities.Compatibilities.ObscureNotAuthorized = true 30 | server.Options.Capabilities.Compatibilities.PassiveClientDisconnect = true 31 | server.Options.Capabilities.Compatibilities.NoInheritedPropertiesOnAck = true 32 | 33 | _ = server.AddHook(new(pahoAuthHook), nil) 34 | tcp := listeners.NewTCP(listeners.Config{ 35 | ID: "t1", 36 | Address: ":1883", 37 | }) 38 | err := server.AddListener(tcp) 39 | if err != nil { 40 | log.Fatal(err) 41 | } 42 | 43 | go func() { 44 | err := server.Serve() 45 | if err != nil { 46 | log.Fatal(err) 47 | } 48 | }() 49 | 50 | <-done 51 | server.Log.Warn("caught signal, stopping...") 52 | _ = server.Close() 53 | server.Log.Info("main.go finished") 54 | } 55 | 56 | type pahoAuthHook struct { 57 | mqtt.HookBase 58 | } 59 | 60 | func (h *pahoAuthHook) ID() string { 61 | return "allow-all-auth" 62 | } 63 | 64 | func (h *pahoAuthHook) Provides(b byte) bool { 65 | return bytes.Contains([]byte{ 66 | mqtt.OnConnectAuthenticate, 67 | mqtt.OnConnect, 68 | mqtt.OnACLCheck, 69 | }, []byte{b}) 70 | } 71 | 72 | func (h *pahoAuthHook) OnConnectAuthenticate(cl *mqtt.Client, pk packets.Packet) bool { 73 | return true 74 | } 75 | 76 | func (h *pahoAuthHook) OnACLCheck(cl *mqtt.Client, topic string, write bool) bool { 77 | return topic != "test/nosubscribe" 78 | } 79 | 80 | func (h *pahoAuthHook) OnConnect(cl *mqtt.Client, pk packets.Packet) error { 81 | // Handle paho test_server_keep_alive 82 | if 
pk.Connect.Keepalive == 120 && pk.Connect.Clean { 83 | cl.State.Keepalive = 60 84 | cl.State.ServerKeepalive = true 85 | } 86 | return nil 87 | } 88 | -------------------------------------------------------------------------------- /examples/persistence/badger/main.go: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: MIT 2 | // SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co 3 | // SPDX-FileContributor: mochi-co, werbenhu 4 | 5 | package main 6 | 7 | import ( 8 | "log" 9 | "os" 10 | "os/signal" 11 | "syscall" 12 | 13 | badgerdb "github.com/dgraph-io/badger/v4" 14 | mqtt "github.com/mochi-mqtt/server/v2" 15 | "github.com/mochi-mqtt/server/v2/hooks/auth" 16 | "github.com/mochi-mqtt/server/v2/hooks/storage/badger" 17 | "github.com/mochi-mqtt/server/v2/listeners" 18 | ) 19 | 20 | func main() { 21 | badgerPath := ".badger" 22 | defer os.RemoveAll(badgerPath) // remove the example badger files at the end 23 | 24 | sigs := make(chan os.Signal, 1) 25 | done := make(chan bool, 1) 26 | signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM) 27 | go func() { 28 | <-sigs 29 | done <- true 30 | }() 31 | 32 | server := mqtt.New(nil) 33 | _ = server.AddHook(new(auth.AllowHook), nil) 34 | 35 | badgerOpts := badgerdb.DefaultOptions(badgerPath) // BadgerDB options. Adjust according to your actual scenario. 36 | badgerOpts.ValueLogFileSize = 100 * (1 << 20) // Set the default size of the log file to 100 MB. 37 | 38 | // AddHook adds a BadgerDB hook to the server with the specified options. 39 | // GcInterval specifies the interval at which BadgerDB garbage collection process runs. 40 | // Refer to https://dgraph.io/docs/badger/get-started/#garbage-collection for more information. 41 | err := server.AddHook(new(badger.Hook), &badger.Options{ 42 | Path: badgerPath, 43 | 44 | // Set the interval for garbage collection. Adjust according to your actual scenario. 
45 | GcInterval: 5 * 60, 46 | 47 | // GcDiscardRatio specifies the ratio of log discard compared to the maximum possible log discard. 48 | // Setting it to a higher value would result in fewer space reclaims, while setting it to a lower value 49 | // would result in more space reclaims at the cost of increased activity on the LSM tree. 50 | // discardRatio must be in the range (0.0, 1.0), both endpoints excluded, otherwise, it will be set to the default value of 0.5. 51 | // Adjust according to your actual scenario. 52 | GcDiscardRatio: 0.5, 53 | 54 | Options: &badgerOpts, 55 | }) 56 | if err != nil { 57 | log.Fatal(err) 58 | } 59 | 60 | tcp := listeners.NewTCP(listeners.Config{ 61 | ID: "t1", 62 | Address: ":1883", 63 | }) 64 | err = server.AddListener(tcp) 65 | if err != nil { 66 | log.Fatal(err) 67 | } 68 | 69 | go func() { 70 | err := server.Serve() 71 | if err != nil { 72 | log.Fatal(err) 73 | } 74 | }() 75 | 76 | <-done 77 | server.Log.Warn("caught signal, stopping...") 78 | _ = server.Close() 79 | server.Log.Info("main.go finished") 80 | } 81 | -------------------------------------------------------------------------------- /examples/persistence/bolt/main.go: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: MIT 2 | // SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co 3 | // SPDX-FileContributor: mochi-co, werbenhu 4 | 5 | package main 6 | 7 | import ( 8 | "log" 9 | "os" 10 | "os/signal" 11 | "syscall" 12 | "time" 13 | 14 | mqtt "github.com/mochi-mqtt/server/v2" 15 | "github.com/mochi-mqtt/server/v2/hooks/auth" 16 | "github.com/mochi-mqtt/server/v2/hooks/storage/bolt" 17 | "github.com/mochi-mqtt/server/v2/listeners" 18 | "go.etcd.io/bbolt" 19 | ) 20 | 21 | func main() { 22 | boltPath := ".bolt" 23 | defer os.RemoveAll(boltPath) // remove the example db files at the end 24 | 25 | sigs := make(chan os.Signal, 1) 26 | done := make(chan bool, 1) 27 | signal.Notify(sigs, syscall.SIGINT, 
syscall.SIGTERM) 28 | go func() { 29 | <-sigs 30 | done <- true 31 | }() 32 | 33 | server := mqtt.New(nil) 34 | _ = server.AddHook(new(auth.AllowHook), nil) 35 | 36 | err := server.AddHook(new(bolt.Hook), &bolt.Options{ 37 | Path: boltPath, 38 | Options: &bbolt.Options{ 39 | Timeout: 500 * time.Millisecond, 40 | }, 41 | }) 42 | if err != nil { 43 | log.Fatal(err) 44 | } 45 | 46 | tcp := listeners.NewTCP(listeners.Config{ 47 | ID: "t1", 48 | Address: ":1883", 49 | }) 50 | err = server.AddListener(tcp) 51 | if err != nil { 52 | log.Fatal(err) 53 | } 54 | 55 | go func() { 56 | err := server.Serve() 57 | if err != nil { 58 | log.Fatal(err) 59 | } 60 | }() 61 | 62 | <-done 63 | server.Log.Warn("caught signal, stopping...") 64 | _ = server.Close() 65 | server.Log.Info("main.go finished") 66 | } 67 | -------------------------------------------------------------------------------- /examples/persistence/pebble/main.go: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: MIT 2 | // SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co, werbenhu 3 | // SPDX-FileContributor: werbenhu 4 | 5 | package main 6 | 7 | import ( 8 | "log" 9 | "os" 10 | "os/signal" 11 | "syscall" 12 | 13 | mqtt "github.com/mochi-mqtt/server/v2" 14 | "github.com/mochi-mqtt/server/v2/hooks/auth" 15 | "github.com/mochi-mqtt/server/v2/hooks/storage/pebble" 16 | "github.com/mochi-mqtt/server/v2/listeners" 17 | ) 18 | 19 | func main() { 20 | pebblePath := ".pebble" 21 | defer os.RemoveAll(pebblePath) // remove the example pebble files at the end 22 | 23 | sigs := make(chan os.Signal, 1) 24 | done := make(chan bool, 1) 25 | signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM) 26 | go func() { 27 | <-sigs 28 | done <- true 29 | }() 30 | 31 | server := mqtt.New(nil) 32 | _ = server.AddHook(new(auth.AllowHook), nil) 33 | 34 | err := server.AddHook(new(pebble.Hook), &pebble.Options{ 35 | Path: pebblePath, 36 | Mode: pebble.NoSync, 37 | }) 38 | if err != nil 
{ 39 | log.Fatal(err) 40 | } 41 | 42 | tcp := listeners.NewTCP(listeners.Config{ 43 | ID: "t1", 44 | Address: ":1883", 45 | }) 46 | err = server.AddListener(tcp) 47 | if err != nil { 48 | log.Fatal(err) 49 | } 50 | 51 | go func() { 52 | err := server.Serve() 53 | if err != nil { 54 | log.Fatal(err) 55 | } 56 | }() 57 | 58 | <-done 59 | server.Log.Warn("caught signal, stopping...") 60 | _ = server.Close() 61 | server.Log.Info("main.go finished") 62 | } 63 | -------------------------------------------------------------------------------- /examples/persistence/redis/main.go: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: MIT 2 | // SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co 3 | // SPDX-FileContributor: mochi-co 4 | 5 | package main 6 | 7 | import ( 8 | "log" 9 | "log/slog" 10 | "os" 11 | "os/signal" 12 | "syscall" 13 | 14 | mqtt "github.com/mochi-mqtt/server/v2" 15 | "github.com/mochi-mqtt/server/v2/hooks/auth" 16 | "github.com/mochi-mqtt/server/v2/hooks/storage/redis" 17 | "github.com/mochi-mqtt/server/v2/listeners" 18 | 19 | rv8 "github.com/go-redis/redis/v8" 20 | ) 21 | 22 | func main() { 23 | sigs := make(chan os.Signal, 1) 24 | done := make(chan bool, 1) 25 | signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM) 26 | go func() { 27 | <-sigs 28 | done <- true 29 | }() 30 | 31 | server := mqtt.New(nil) 32 | _ = server.AddHook(new(auth.AllowHook), nil) 33 | 34 | level := new(slog.LevelVar) 35 | server.Log = slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{ 36 | Level: level, 37 | })) 38 | level.Set(slog.LevelDebug) 39 | 40 | err := server.AddHook(new(redis.Hook), &redis.Options{ 41 | Options: &rv8.Options{ 42 | Addr: "localhost:6379", // default redis address 43 | Password: "", // your password 44 | DB: 0, // your redis db 45 | }, 46 | }) 47 | if err != nil { 48 | log.Fatal(err) 49 | } 50 | 51 | tcp := listeners.NewTCP(listeners.Config{ 52 | ID: "t1", 53 | Address: ":1883", 54 | }) 
55 | err = server.AddListener(tcp) 56 | if err != nil { 57 | log.Fatal(err) 58 | } 59 | 60 | go func() { 61 | err := server.Serve() 62 | if err != nil { 63 | log.Fatal(err) 64 | } 65 | }() 66 | 67 | <-done 68 | server.Log.Warn("caught signal, stopping...") 69 | _ = server.Close() 70 | server.Log.Info("main.go finished") 71 | } 72 | -------------------------------------------------------------------------------- /examples/tcp/main.go: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: MIT 2 | // SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co 3 | // SPDX-FileContributor: mochi-co 4 | 5 | package main 6 | 7 | import ( 8 | "log" 9 | "os" 10 | "os/signal" 11 | "syscall" 12 | 13 | mqtt "github.com/mochi-mqtt/server/v2" 14 | "github.com/mochi-mqtt/server/v2/hooks/auth" 15 | "github.com/mochi-mqtt/server/v2/listeners" 16 | ) 17 | 18 | func main() { 19 | sigs := make(chan os.Signal, 1) 20 | done := make(chan bool, 1) 21 | signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM) 22 | go func() { 23 | <-sigs 24 | done <- true 25 | }() 26 | 27 | // An example of configuring various server options... 28 | options := &mqtt.Options{ 29 | // InflightTTL: 60 * 15, // Set an example custom 15-min TTL for inflight messages 30 | } 31 | 32 | server := mqtt.New(options) 33 | 34 | // For security reasons, the default implementation disallows all connections. 35 | // If you want to allow all connections, you must specifically allow it. 
36 | err := server.AddHook(new(auth.AllowHook), nil) 37 | if err != nil { 38 | log.Fatal(err) 39 | } 40 | 41 | tcp := listeners.NewTCP(listeners.Config{ 42 | ID: "t1", 43 | Address: ":1883", 44 | }) 45 | err = server.AddListener(tcp) 46 | if err != nil { 47 | log.Fatal(err) 48 | } 49 | 50 | go func() { 51 | err := server.Serve() 52 | if err != nil { 53 | log.Fatal(err) 54 | } 55 | }() 56 | 57 | <-done 58 | server.Log.Warn("caught signal, stopping...") 59 | _ = server.Close() 60 | server.Log.Info("main.go finished") 61 | } 62 | -------------------------------------------------------------------------------- /examples/tls/main.go: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: MIT 2 | // SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co 3 | // SPDX-FileContributor: mochi-co 4 | 5 | package main 6 | 7 | import ( 8 | "crypto/tls" 9 | "log" 10 | "os" 11 | "os/signal" 12 | "syscall" 13 | 14 | mqtt "github.com/mochi-mqtt/server/v2" 15 | "github.com/mochi-mqtt/server/v2/hooks/auth" 16 | "github.com/mochi-mqtt/server/v2/listeners" 17 | ) 18 | 19 | var ( 20 | testCertificate = []byte(`-----BEGIN CERTIFICATE----- 21 | MIIB/zCCAWgCCQDm3jV+lSF1AzANBgkqhkiG9w0BAQsFADBEMQswCQYDVQQGEwJB 22 | VTETMBEGA1UECAwKU29tZS1TdGF0ZTERMA8GA1UECgwITW9jaGkgQ28xDTALBgNV 23 | BAsMBE1RVFQwHhcNMjAwMTA0MjAzMzQyWhcNMjEwMTAzMjAzMzQyWjBEMQswCQYD 24 | VQQGEwJBVTETMBEGA1UECAwKU29tZS1TdGF0ZTERMA8GA1UECgwITW9jaGkgQ28x 25 | DTALBgNVBAsMBE1RVFQwgZ8wDQYJKoZIhvcNAQEBBQADgY0AMIGJAoGBAKz2bUz3 26 | AOymssVLuvSOEbQ/sF8C/Ill8nRTd7sX9WBIxHJZf+gVn8lQ4BTQ0NchLDRIlpbi 27 | OuZgktpd6ba8sIfVM4jbVprctky5tGsyHRFwL/GAycCtKwvuXkvcwSwLvB8b29EI 28 | MLQ/3vNnYuC3eZ4qqxlODJgRsfQ7mUNB8zkLAgMBAAEwDQYJKoZIhvcNAQELBQAD 29 | gYEAiMoKnQaD0F/J332arGvcmtbHmF2XZp/rGy3dooPug8+OPUSAJY9vTfxJwOsQ 30 | qN1EcI+kIgrGxzA3VRfVYV8gr7IX+fUYfVCaPGcDCfPvo/Ihu757afJRVvpafWgy 31 | zSpDZYu6C62h3KSzMJxffDjy7/2t8oYbTzkLSamsHJJjLZw= 32 | -----END CERTIFICATE-----`) 33 | 34 | testPrivateKey = 
[]byte(`-----BEGIN RSA PRIVATE KEY----- 35 | MIICXAIBAAKBgQCs9m1M9wDsprLFS7r0jhG0P7BfAvyJZfJ0U3e7F/VgSMRyWX/o 36 | FZ/JUOAU0NDXISw0SJaW4jrmYJLaXem2vLCH1TOI21aa3LZMubRrMh0RcC/xgMnA 37 | rSsL7l5L3MEsC7wfG9vRCDC0P97zZ2Lgt3meKqsZTgyYEbH0O5lDQfM5CwIDAQAB 38 | AoGBAKlmVVirFqmw/qhDaqD4wBg0xI3Zw/Lh+Vu7ICoK5hVeT6DbTW3GOBAY+M8K 39 | UXBSGhQ+/9ZZTmyyK0JZ9nw2RAG3lONU6wS41pZhB7F4siatZfP/JJfU6p+ohe8m 40 | n22hTw4brY/8E/tjuki9T5e2GeiUPBhjbdECkkVXMYBPKDZhAkEA5h/b/HBcsIZZ 41 | mL2d3dyWkXR/IxngQa4NH3124M8MfBqCYXPLgD7RDI+3oT/uVe+N0vu6+7CSMVx6 42 | INM67CuE0QJBAMBpKW54cfMsMya3CM1BfdPEBzDT5kTMqxJ7ez164PHv9CJCnL0Z 43 | AuWgM/p2WNbAF1yHNxw1eEfNbUWwVX2yhxsCQEtnMQvcPWLSAtWbe/jQaL2scGQt 44 | /F9JCp/A2oz7Cto3TXVlHc8dxh3ZkY/ShOO/pLb3KOODjcOCy7mpvOrZr6ECQH32 45 | WoFPqImhrfryaHi3H0C7XFnC30S7GGOJIy0kfI7mn9St9x50eUkKj/yv7YjpSGHy 46 | w0lcV9npyleNEOqxLXECQBL3VRGCfZfhfFpL8z+5+HPKXw6FxWr+p5h8o3CZ6Yi3 47 | OJVN3Mfo6mbz34wswrEdMXn25MzAwbhFQvCVpPZrFwc= 48 | -----END RSA PRIVATE KEY-----`) 49 | ) 50 | 51 | func main() { 52 | sigs := make(chan os.Signal, 1) 53 | done := make(chan bool, 1) 54 | signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM) 55 | go func() { 56 | <-sigs 57 | done <- true 58 | }() 59 | 60 | // Load tls cert from your cert file 61 | cert, err := tls.LoadX509KeyPair("replace_your_cert.pem", "replace_your_cert.key") 62 | 63 | //cert, err := tls.X509KeyPair(testCertificate, testPrivateKey) 64 | if err != nil { 65 | log.Fatal(err) 66 | } 67 | 68 | // Basic TLS Config 69 | tlsConfig := &tls.Config{ 70 | Certificates: []tls.Certificate{cert}, 71 | } 72 | 73 | // Optionally, if you want clients to authenticate only with certs issued by your CA, 74 | // you might want to use something like this: 75 | // certPool := x509.NewCertPool() 76 | // _ = certPool.AppendCertsFromPEM(caCertPem) 77 | // tlsConfig := &tls.Config{ 78 | // ClientCAs: certPool, 79 | // ClientAuth: tls.RequireAndVerifyClientCert, 80 | // } 81 | 82 | server := mqtt.New(nil) 83 | _ = server.AddHook(new(auth.AllowHook), nil) 84 
| 85 | tcp := listeners.NewTCP(listeners.Config{ 86 | ID: "t1", 87 | Address: ":1883", 88 | TLSConfig: tlsConfig, 89 | }) 90 | err = server.AddListener(tcp) 91 | if err != nil { 92 | log.Fatal(err) 93 | } 94 | 95 | ws := listeners.NewWebsocket(listeners.Config{ 96 | ID: "ws1", 97 | Address: ":1882", 98 | TLSConfig: tlsConfig, 99 | }) 100 | err = server.AddListener(ws) 101 | if err != nil { 102 | log.Fatal(err) 103 | } 104 | 105 | stats := listeners.NewHTTPStats( 106 | listeners.Config{ 107 | ID: "stats", 108 | Address: ":8080", 109 | TLSConfig: tlsConfig, 110 | }, server.Info, 111 | ) 112 | err = server.AddListener(stats) 113 | if err != nil { 114 | log.Fatal(err) 115 | } 116 | 117 | go func() { 118 | err := server.Serve() 119 | if err != nil { 120 | log.Fatal(err) 121 | } 122 | }() 123 | 124 | <-done 125 | server.Log.Warn("caught signal, stopping...") 126 | _ = server.Close() 127 | server.Log.Info("main.go finished") 128 | } 129 | -------------------------------------------------------------------------------- /examples/websocket/main.go: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: MIT 2 | // SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co 3 | // SPDX-FileContributor: mochi-co 4 | 5 | package main 6 | 7 | import ( 8 | "log" 9 | "os" 10 | "os/signal" 11 | "syscall" 12 | 13 | mqtt "github.com/mochi-mqtt/server/v2" 14 | "github.com/mochi-mqtt/server/v2/hooks/auth" 15 | "github.com/mochi-mqtt/server/v2/listeners" 16 | ) 17 | 18 | func main() { 19 | sigs := make(chan os.Signal, 1) 20 | done := make(chan bool, 1) 21 | signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM) 22 | go func() { 23 | <-sigs 24 | done <- true 25 | }() 26 | 27 | server := mqtt.New(nil) 28 | _ = server.AddHook(new(auth.AllowHook), nil) 29 | 30 | ws := listeners.NewWebsocket(listeners.Config{ 31 | ID: "ws1", 32 | Address: ":1882", 33 | }) 34 | err := server.AddListener(ws) 35 | if err != nil { 36 | log.Fatal(err) 37 | } 38 | 39 | 
go func() { 40 | err := server.Serve() 41 | if err != nil { 42 | log.Fatal(err) 43 | } 44 | }() 45 | 46 | <-done 47 | server.Log.Warn("caught signal, stopping...") 48 | _ = server.Close() 49 | server.Log.Info("main.go finished") 50 | } 51 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/mochi-mqtt/server/v2 2 | 3 | go 1.21 4 | 5 | require ( 6 | github.com/alicebob/miniredis/v2 v2.23.0 7 | github.com/cockroachdb/pebble v1.1.0 8 | github.com/dgraph-io/badger/v4 v4.2.0 9 | github.com/go-redis/redis/v8 v8.11.5 10 | github.com/gorilla/websocket v1.5.0 11 | github.com/jinzhu/copier v0.3.5 12 | github.com/rs/xid v1.4.0 13 | github.com/stretchr/testify v1.8.1 14 | go.etcd.io/bbolt v1.3.5 15 | gopkg.in/yaml.v3 v3.0.1 16 | ) 17 | 18 | require ( 19 | github.com/DataDog/zstd v1.4.5 // indirect 20 | github.com/alicebob/gopher-json v0.0.0-20200520072559-a9ecdc9d1d3a // indirect 21 | github.com/beorn7/perks v1.0.1 // indirect 22 | github.com/cespare/xxhash/v2 v2.2.0 // indirect 23 | github.com/cockroachdb/errors v1.11.1 // indirect 24 | github.com/cockroachdb/logtags v0.0.0-20230118201751-21c54148d20b // indirect 25 | github.com/cockroachdb/redact v1.1.5 // indirect 26 | github.com/cockroachdb/tokenbucket v0.0.0-20230807174530-cc333fc44b06 // indirect 27 | github.com/davecgh/go-spew v1.1.1 // indirect 28 | github.com/dgraph-io/ristretto v0.1.1 // indirect 29 | github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect 30 | github.com/dustin/go-humanize v1.0.0 // indirect 31 | github.com/getsentry/sentry-go v0.18.0 // indirect 32 | github.com/gogo/protobuf v1.3.2 // indirect 33 | github.com/golang/glog v1.2.4 // indirect 34 | github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e // indirect 35 | github.com/golang/protobuf v1.5.2 // indirect 36 | github.com/golang/snappy v0.0.4 // indirect 37 | 
github.com/google/flatbuffers v1.12.1 // indirect 38 | github.com/klauspost/compress v1.15.15 // indirect 39 | github.com/kr/pretty v0.3.1 // indirect 40 | github.com/kr/text v0.2.0 // indirect 41 | github.com/matttproud/golang_protobuf_extensions v1.0.2-0.20181231171920-c182affec369 // indirect 42 | github.com/pkg/errors v0.9.1 // indirect 43 | github.com/pmezard/go-difflib v1.0.0 // indirect 44 | github.com/prometheus/client_golang v1.12.0 // indirect 45 | github.com/prometheus/client_model v0.2.1-0.20210607210712-147c58e9608a // indirect 46 | github.com/prometheus/common v0.32.1 // indirect 47 | github.com/prometheus/procfs v0.7.3 // indirect 48 | github.com/rogpeppe/go-internal v1.9.0 // indirect 49 | github.com/yuin/gopher-lua v0.0.0-20210529063254-f4c35e4016d9 // indirect 50 | go.opencensus.io v0.22.5 // indirect 51 | golang.org/x/exp v0.0.0-20230626212559-97b1e661b5df // indirect 52 | golang.org/x/net v0.33.0 // indirect 53 | golang.org/x/sys v0.28.0 // indirect 54 | golang.org/x/text v0.21.0 // indirect 55 | google.golang.org/protobuf v1.33.0 // indirect 56 | ) 57 | -------------------------------------------------------------------------------- /hooks/auth/allow_all.go: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: MIT 2 | // SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co 3 | // SPDX-FileContributor: mochi-co 4 | 5 | package auth 6 | 7 | import ( 8 | "bytes" 9 | 10 | "github.com/mochi-mqtt/server/v2" 11 | "github.com/mochi-mqtt/server/v2/packets" 12 | ) 13 | 14 | // AllowHook is an authentication hook which allows connection access 15 | // for all users and read and write access to all topics. 16 | type AllowHook struct { 17 | mqtt.HookBase 18 | } 19 | 20 | // ID returns the ID of the hook. 21 | func (h *AllowHook) ID() string { 22 | return "allow-all-auth" 23 | } 24 | 25 | // Provides indicates which hook methods this hook provides. 
26 | func (h *AllowHook) Provides(b byte) bool { 27 | return bytes.Contains([]byte{ 28 | mqtt.OnConnectAuthenticate, 29 | mqtt.OnACLCheck, 30 | }, []byte{b}) 31 | } 32 | 33 | // OnConnectAuthenticate returns true/allowed for all requests. 34 | func (h *AllowHook) OnConnectAuthenticate(cl *mqtt.Client, pk packets.Packet) bool { 35 | return true 36 | } 37 | 38 | // OnACLCheck returns true/allowed for all checks. 39 | func (h *AllowHook) OnACLCheck(cl *mqtt.Client, topic string, write bool) bool { 40 | return true 41 | } 42 | -------------------------------------------------------------------------------- /hooks/auth/allow_all_test.go: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: MIT 2 | // SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co 3 | // SPDX-FileContributor: mochi-co 4 | 5 | package auth 6 | 7 | import ( 8 | "testing" 9 | 10 | "github.com/mochi-mqtt/server/v2" 11 | "github.com/mochi-mqtt/server/v2/packets" 12 | "github.com/stretchr/testify/require" 13 | ) 14 | 15 | func TestAllowAllID(t *testing.T) { 16 | h := new(AllowHook) 17 | require.Equal(t, "allow-all-auth", h.ID()) 18 | } 19 | 20 | func TestAllowAllProvides(t *testing.T) { 21 | h := new(AllowHook) 22 | require.True(t, h.Provides(mqtt.OnACLCheck)) 23 | require.True(t, h.Provides(mqtt.OnConnectAuthenticate)) 24 | require.False(t, h.Provides(mqtt.OnPublished)) 25 | } 26 | 27 | func TestAllowAllOnConnectAuthenticate(t *testing.T) { 28 | h := new(AllowHook) 29 | require.True(t, h.OnConnectAuthenticate(new(mqtt.Client), packets.Packet{})) 30 | } 31 | 32 | func TestAllowAllOnACLCheck(t *testing.T) { 33 | h := new(AllowHook) 34 | require.True(t, h.OnACLCheck(new(mqtt.Client), "any", true)) 35 | } 36 | -------------------------------------------------------------------------------- /hooks/auth/auth.go: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: MIT 2 | // 
SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co 3 | // SPDX-FileContributor: mochi-co 4 | 5 | package auth 6 | 7 | import ( 8 | "bytes" 9 | 10 | mqtt "github.com/mochi-mqtt/server/v2" 11 | "github.com/mochi-mqtt/server/v2/packets" 12 | ) 13 | 14 | // Options contains the configuration/rules data for the auth ledger. 15 | type Options struct { 16 | Data []byte 17 | Ledger *Ledger 18 | } 19 | 20 | // Hook is an authentication hook which implements an auth ledger. 21 | type Hook struct { 22 | mqtt.HookBase 23 | config *Options 24 | ledger *Ledger 25 | } 26 | 27 | // ID returns the ID of the hook. 28 | func (h *Hook) ID() string { 29 | return "auth-ledger" 30 | } 31 | 32 | // Provides indicates which hook methods this hook provides. 33 | func (h *Hook) Provides(b byte) bool { 34 | return bytes.Contains([]byte{ 35 | mqtt.OnConnectAuthenticate, 36 | mqtt.OnACLCheck, 37 | }, []byte{b}) 38 | } 39 | 40 | // Init configures the hook with the auth ledger to be used for checking. 41 | func (h *Hook) Init(config any) error { 42 | if _, ok := config.(*Options); !ok && config != nil { 43 | return mqtt.ErrInvalidConfigType 44 | } 45 | 46 | if config == nil { 47 | config = new(Options) 48 | } 49 | 50 | h.config = config.(*Options) 51 | 52 | var err error 53 | if h.config.Ledger != nil { 54 | h.ledger = h.config.Ledger 55 | } else if len(h.config.Data) > 0 { 56 | h.ledger = new(Ledger) 57 | err = h.ledger.Unmarshal(h.config.Data) 58 | } 59 | if err != nil { 60 | return err 61 | } 62 | 63 | if h.ledger == nil { 64 | h.ledger = &Ledger{ 65 | Auth: AuthRules{}, 66 | ACL: ACLRules{}, 67 | } 68 | } 69 | 70 | h.Log.Info("loaded auth rules", 71 | "authentication", len(h.ledger.Auth), 72 | "acl", len(h.ledger.ACL)) 73 | 74 | return nil 75 | } 76 | 77 | // OnConnectAuthenticate returns true if the connecting client has rules which provide access 78 | // in the auth ledger. 
79 | func (h *Hook) OnConnectAuthenticate(cl *mqtt.Client, pk packets.Packet) bool { 80 | if _, ok := h.ledger.AuthOk(cl, pk); ok { 81 | return true 82 | } 83 | 84 | h.Log.Info("client failed authentication check", 85 | "username", string(pk.Connect.Username), 86 | "remote", cl.Net.Remote) 87 | return false 88 | } 89 | 90 | // OnACLCheck returns true if the connecting client has matching read or write access to subscribe 91 | // or publish to a given topic. 92 | func (h *Hook) OnACLCheck(cl *mqtt.Client, topic string, write bool) bool { 93 | if _, ok := h.ledger.ACLOk(cl, topic, write); ok { 94 | return true 95 | } 96 | 97 | h.Log.Debug("client failed allowed ACL check", 98 | "client", cl.ID, 99 | "username", string(cl.Properties.Username), 100 | "topic", topic) 101 | 102 | return false 103 | } 104 | -------------------------------------------------------------------------------- /hooks/auth/auth_test.go: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: MIT 2 | // SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co 3 | // SPDX-FileContributor: mochi-co 4 | 5 | package auth 6 | 7 | import ( 8 | "log/slog" 9 | "os" 10 | "testing" 11 | 12 | mqtt "github.com/mochi-mqtt/server/v2" 13 | "github.com/mochi-mqtt/server/v2/packets" 14 | "github.com/stretchr/testify/require" 15 | ) 16 | 17 | var logger = slog.New(slog.NewTextHandler(os.Stdout, nil)) 18 | 19 | // func teardown(t *testing.T, path string, h *Hook) { 20 | // h.Stop() 21 | // } 22 | 23 | func TestBasicID(t *testing.T) { 24 | h := new(Hook) 25 | require.Equal(t, "auth-ledger", h.ID()) 26 | } 27 | 28 | func TestBasicProvides(t *testing.T) { 29 | h := new(Hook) 30 | require.True(t, h.Provides(mqtt.OnACLCheck)) 31 | require.True(t, h.Provides(mqtt.OnConnectAuthenticate)) 32 | require.False(t, h.Provides(mqtt.OnPublish)) 33 | } 34 | 35 | func TestBasicInitBadConfig(t *testing.T) { 36 | h := new(Hook) 37 | h.SetOpts(logger, nil) 38 | 39 | err := 
h.Init(map[string]any{}) 40 | require.Error(t, err) 41 | } 42 | 43 | func TestBasicInitDefaultConfig(t *testing.T) { 44 | h := new(Hook) 45 | h.SetOpts(logger, nil) 46 | 47 | err := h.Init(nil) 48 | require.NoError(t, err) 49 | } 50 | 51 | func TestBasicInitWithLedgerPointer(t *testing.T) { 52 | h := new(Hook) 53 | h.SetOpts(logger, nil) 54 | 55 | ln := &Ledger{ 56 | Auth: []AuthRule{ 57 | { 58 | Remote: "127.0.0.1", 59 | Allow: true, 60 | }, 61 | }, 62 | ACL: []ACLRule{ 63 | { 64 | Remote: "127.0.0.1", 65 | Filters: Filters{ 66 | "#": ReadWrite, 67 | }, 68 | }, 69 | }, 70 | } 71 | 72 | err := h.Init(&Options{ 73 | Ledger: ln, 74 | }) 75 | 76 | require.NoError(t, err) 77 | require.Same(t, ln, h.ledger) 78 | } 79 | 80 | func TestBasicInitWithLedgerJSON(t *testing.T) { 81 | h := new(Hook) 82 | h.SetOpts(logger, nil) 83 | 84 | require.Nil(t, h.ledger) 85 | err := h.Init(&Options{ 86 | Data: ledgerJSON, 87 | }) 88 | 89 | require.NoError(t, err) 90 | require.Equal(t, ledgerStruct.Auth[0].Username, h.ledger.Auth[0].Username) 91 | require.Equal(t, ledgerStruct.ACL[0].Client, h.ledger.ACL[0].Client) 92 | } 93 | 94 | func TestBasicInitWithLedgerYAML(t *testing.T) { 95 | h := new(Hook) 96 | h.SetOpts(logger, nil) 97 | 98 | require.Nil(t, h.ledger) 99 | err := h.Init(&Options{ 100 | Data: ledgerYAML, 101 | }) 102 | 103 | require.NoError(t, err) 104 | require.Equal(t, ledgerStruct.Auth[0].Username, h.ledger.Auth[0].Username) 105 | require.Equal(t, ledgerStruct.ACL[0].Client, h.ledger.ACL[0].Client) 106 | } 107 | 108 | func TestBasicInitWithLedgerBadDAta(t *testing.T) { 109 | h := new(Hook) 110 | h.SetOpts(logger, nil) 111 | 112 | require.Nil(t, h.ledger) 113 | err := h.Init(&Options{ 114 | Data: []byte("fdsfdsafasd"), 115 | }) 116 | 117 | require.Error(t, err) 118 | } 119 | 120 | func TestOnConnectAuthenticate(t *testing.T) { 121 | h := new(Hook) 122 | h.SetOpts(logger, nil) 123 | 124 | ln := new(Ledger) 125 | ln.Auth = checkLedger.Auth 126 | ln.ACL = checkLedger.ACL 127 | err 
:= h.Init( 128 | &Options{ 129 | Ledger: ln, 130 | }, 131 | ) 132 | 133 | require.NoError(t, err) 134 | 135 | require.True(t, h.OnConnectAuthenticate( 136 | &mqtt.Client{ 137 | Properties: mqtt.ClientProperties{ 138 | Username: []byte("mochi"), 139 | }, 140 | }, 141 | packets.Packet{Connect: packets.ConnectParams{Password: []byte("melon")}}, 142 | )) 143 | 144 | require.False(t, h.OnConnectAuthenticate( 145 | &mqtt.Client{ 146 | Properties: mqtt.ClientProperties{ 147 | Username: []byte("mochi"), 148 | }, 149 | }, 150 | packets.Packet{Connect: packets.ConnectParams{Password: []byte("bad-pass")}}, 151 | )) 152 | 153 | require.False(t, h.OnConnectAuthenticate( 154 | &mqtt.Client{}, 155 | packets.Packet{}, 156 | )) 157 | } 158 | 159 | func TestOnACL(t *testing.T) { 160 | h := new(Hook) 161 | h.SetOpts(logger, nil) 162 | 163 | ln := new(Ledger) 164 | ln.Auth = checkLedger.Auth 165 | ln.ACL = checkLedger.ACL 166 | err := h.Init( 167 | &Options{ 168 | Ledger: ln, 169 | }, 170 | ) 171 | 172 | require.NoError(t, err) 173 | 174 | require.True(t, h.OnACLCheck( 175 | &mqtt.Client{ 176 | Properties: mqtt.ClientProperties{ 177 | Username: []byte("mochi"), 178 | }, 179 | }, 180 | "mochi/info", 181 | true, 182 | )) 183 | 184 | require.False(t, h.OnACLCheck( 185 | &mqtt.Client{ 186 | Properties: mqtt.ClientProperties{ 187 | Username: []byte("mochi"), 188 | }, 189 | }, 190 | "d/j/f", 191 | true, 192 | )) 193 | 194 | require.True(t, h.OnACLCheck( 195 | &mqtt.Client{ 196 | Properties: mqtt.ClientProperties{ 197 | Username: []byte("mochi"), 198 | }, 199 | }, 200 | "readonly", 201 | false, 202 | )) 203 | 204 | require.False(t, h.OnACLCheck( 205 | &mqtt.Client{ 206 | Properties: mqtt.ClientProperties{ 207 | Username: []byte("mochi"), 208 | }, 209 | }, 210 | "readonly", 211 | true, 212 | )) 213 | } 214 | -------------------------------------------------------------------------------- /hooks/auth/ledger.go: -------------------------------------------------------------------------------- 1 
| // SPDX-License-Identifier: MIT 2 | // SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co 3 | // SPDX-FileContributor: mochi-co 4 | 5 | package auth 6 | 7 | import ( 8 | "encoding/json" 9 | "strings" 10 | "sync" 11 | 12 | "gopkg.in/yaml.v3" 13 | 14 | "github.com/mochi-mqtt/server/v2" 15 | "github.com/mochi-mqtt/server/v2/packets" 16 | ) 17 | 18 | const ( 19 | Deny Access = iota // user cannot access the topic 20 | ReadOnly // user can only subscribe to the topic 21 | WriteOnly // user can only publish to the topic 22 | ReadWrite // user can both publish and subscribe to the topic 23 | ) 24 | 25 | // Access determines the read/write privileges for an ACL rule. 26 | type Access byte 27 | 28 | // Users contains a map of access rules for specific users, keyed on username. 29 | type Users map[string]UserRule 30 | 31 | // UserRule defines a set of access rules for a specific user. 32 | type UserRule struct { 33 | Username RString `json:"username,omitempty" yaml:"username,omitempty"` // the username of a user 34 | Password RString `json:"password,omitempty" yaml:"password,omitempty"` // the password of a user 35 | ACL Filters `json:"acl,omitempty" yaml:"acl,omitempty"` // filters to match, if desired 36 | Disallow bool `json:"disallow,omitempty" yaml:"disallow,omitempty"` // allow or disallow the user 37 | } 38 | 39 | // AuthRules defines generic access rules applicable to all users. 
40 | type AuthRules []AuthRule 41 | 42 | type AuthRule struct { 43 | Client RString `json:"client,omitempty" yaml:"client,omitempty"` // the id of a connecting client 44 | Username RString `json:"username,omitempty" yaml:"username,omitempty"` // the username of a user 45 | Remote RString `json:"remote,omitempty" yaml:"remote,omitempty"` // remote address or 46 | Password RString `json:"password,omitempty" yaml:"password,omitempty"` // the password of a user 47 | Allow bool `json:"allow,omitempty" yaml:"allow,omitempty"` // allow or disallow the users 48 | } 49 | 50 | // ACLRules defines generic topic or filter access rules applicable to all users. 51 | type ACLRules []ACLRule 52 | 53 | // ACLRule defines access rules for a specific topic or filter. 54 | type ACLRule struct { 55 | Client RString `json:"client,omitempty" yaml:"client,omitempty"` // the id of a connecting client 56 | Username RString `json:"username,omitempty" yaml:"username,omitempty"` // the username of a user 57 | Remote RString `json:"remote,omitempty" yaml:"remote,omitempty"` // remote address or 58 | Filters Filters `json:"filters,omitempty" yaml:"filters,omitempty"` // filters to match 59 | } 60 | 61 | // Filters is a map of Access rules keyed on filter. 62 | type Filters map[RString]Access 63 | 64 | // RString is a rule value string. 65 | type RString string 66 | 67 | // Matches returns true if the rule matches a given string. 68 | func (r RString) Matches(a string) bool { 69 | rr := string(r) 70 | if r == "" || r == "*" || a == rr { 71 | return true 72 | } 73 | 74 | i := strings.Index(rr, "*") 75 | if i > 0 && len(a) > i && strings.Compare(rr[:i], a[:i]) == 0 { 76 | return true 77 | } 78 | 79 | return false 80 | } 81 | 82 | // FilterMatches returns true if a filter matches a topic rule. 
83 | func (r RString) FilterMatches(a string) bool { 84 | _, ok := MatchTopic(string(r), a) 85 | return ok 86 | } 87 | 88 | // MatchTopic checks if a given topic matches a filter, accounting for filter 89 | // wildcards. Eg. filter /a/b/+/c == topic a/b/d/c. 90 | func MatchTopic(filter string, topic string) (elements []string, matched bool) { 91 | filterParts := strings.Split(filter, "/") 92 | topicParts := strings.Split(topic, "/") 93 | 94 | elements = make([]string, 0) 95 | for i := 0; i < len(filterParts); i++ { 96 | if i >= len(topicParts) { 97 | matched = false 98 | return 99 | } 100 | 101 | if filterParts[i] == "+" { 102 | elements = append(elements, topicParts[i]) 103 | continue 104 | } 105 | 106 | if filterParts[i] == "#" { 107 | matched = true 108 | elements = append(elements, strings.Join(topicParts[i:], "/")) 109 | return 110 | } 111 | 112 | if filterParts[i] != topicParts[i] { 113 | matched = false 114 | return 115 | } 116 | } 117 | 118 | return elements, true 119 | } 120 | 121 | // Ledger is an auth ledger containing access rules for users and topics. 122 | type Ledger struct { 123 | sync.Mutex `json:"-" yaml:"-"` 124 | Users Users `json:"users" yaml:"users"` 125 | Auth AuthRules `json:"auth" yaml:"auth"` 126 | ACL ACLRules `json:"acl" yaml:"acl"` 127 | } 128 | 129 | // Update updates the internal values of the ledger. 130 | func (l *Ledger) Update(ln *Ledger) { 131 | l.Lock() 132 | defer l.Unlock() 133 | l.Auth = ln.Auth 134 | l.ACL = ln.ACL 135 | } 136 | 137 | // AuthOk returns true if the rules indicate the user is allowed to authenticate. 138 | func (l *Ledger) AuthOk(cl *mqtt.Client, pk packets.Packet) (n int, ok bool) { 139 | // If the users map is set, always check for a predefined user first instead 140 | // of iterating through global rules. 
141 | if l.Users != nil { 142 | if u, ok := l.Users[string(cl.Properties.Username)]; ok && 143 | u.Password != "" && 144 | u.Password == RString(pk.Connect.Password) { 145 | return 0, !u.Disallow 146 | } 147 | } 148 | 149 | // If there's no users map, or no user was found, attempt to find a matching 150 | // rule (which may also contain a user). 151 | for n, rule := range l.Auth { 152 | if rule.Client.Matches(cl.ID) && 153 | rule.Username.Matches(string(cl.Properties.Username)) && 154 | rule.Password.Matches(string(pk.Connect.Password)) && 155 | rule.Remote.Matches(cl.Net.Remote) { 156 | return n, rule.Allow 157 | } 158 | } 159 | 160 | return 0, false 161 | } 162 | 163 | // ACLOk returns true if the rules indicate the user is allowed to read or write to 164 | // a specific filter or topic respectively, based on the `write` bool. 165 | func (l *Ledger) ACLOk(cl *mqtt.Client, topic string, write bool) (n int, ok bool) { 166 | // If the users map is set, always check for a predefined user first instead 167 | // of iterating through global rules. 
168 | if l.Users != nil { 169 | if u, ok := l.Users[string(cl.Properties.Username)]; ok && len(u.ACL) > 0 { 170 | for filter, access := range u.ACL { 171 | if filter.FilterMatches(topic) { 172 | if !write && (access == ReadOnly || access == ReadWrite) { 173 | return n, true 174 | } else if write && (access == WriteOnly || access == ReadWrite) { 175 | return n, true 176 | } else { 177 | return n, false 178 | } 179 | } 180 | } 181 | } 182 | } 183 | 184 | for n, rule := range l.ACL { 185 | if rule.Client.Matches(cl.ID) && 186 | rule.Username.Matches(string(cl.Properties.Username)) && 187 | rule.Remote.Matches(cl.Net.Remote) { 188 | if len(rule.Filters) == 0 { 189 | return n, true 190 | } 191 | 192 | if write { 193 | for filter, access := range rule.Filters { 194 | if access == WriteOnly || access == ReadWrite { 195 | if filter.FilterMatches(topic) { 196 | return n, true 197 | } 198 | } 199 | } 200 | } 201 | 202 | if !write { 203 | for filter, access := range rule.Filters { 204 | if access == ReadOnly || access == ReadWrite { 205 | if filter.FilterMatches(topic) { 206 | return n, true 207 | } 208 | } 209 | } 210 | } 211 | 212 | for filter := range rule.Filters { 213 | if filter.FilterMatches(topic) { 214 | return n, false 215 | } 216 | } 217 | } 218 | } 219 | 220 | return 0, true 221 | } 222 | 223 | // ToJSON encodes the values into a JSON string. 224 | func (l *Ledger) ToJSON() (data []byte, err error) { 225 | return json.Marshal(l) 226 | } 227 | 228 | // ToYAML encodes the values into a YAML string. 229 | func (l *Ledger) ToYAML() (data []byte, err error) { 230 | return yaml.Marshal(l) 231 | } 232 | 233 | // Unmarshal decodes a JSON or YAML string (such as a rule config from a file) into a struct. 
234 | func (l *Ledger) Unmarshal(data []byte) error { 235 | l.Lock() 236 | defer l.Unlock() 237 | if len(data) == 0 { 238 | return nil 239 | } 240 | 241 | if data[0] == '{' { 242 | return json.Unmarshal(data, l) 243 | } 244 | 245 | return yaml.Unmarshal(data, &l) 246 | } 247 | -------------------------------------------------------------------------------- /hooks/debug/debug.go: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: MIT 2 | // SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co 3 | // SPDX-FileContributor: mochi-co 4 | 5 | package debug 6 | 7 | import ( 8 | "fmt" 9 | "log/slog" 10 | "strings" 11 | 12 | mqtt "github.com/mochi-mqtt/server/v2" 13 | "github.com/mochi-mqtt/server/v2/hooks/storage" 14 | "github.com/mochi-mqtt/server/v2/packets" 15 | ) 16 | 17 | // Options contains configuration settings for the debug output. 18 | type Options struct { 19 | Enable bool `yaml:"enable" json:"enable"` // non-zero field for enabling hook using file-based config 20 | ShowPacketData bool `yaml:"show_packet_data" json:"show_packet_data"` // include decoded packet data (default false) 21 | ShowPings bool `yaml:"show_pings" json:"show_pings"` // show ping requests and responses (default false) 22 | ShowPasswords bool `yaml:"show_passwords" json:"show_passwords"` // show connecting user passwords (default false) 23 | } 24 | 25 | // Hook is a debugging hook which logs additional low-level information from the server. 26 | type Hook struct { 27 | mqtt.HookBase 28 | config *Options 29 | Log *slog.Logger 30 | } 31 | 32 | // ID returns the ID of the hook. 33 | func (h *Hook) ID() string { 34 | return "debug" 35 | } 36 | 37 | // Provides indicates that this hook provides all methods. 38 | func (h *Hook) Provides(b byte) bool { 39 | return true 40 | } 41 | 42 | // Init is called when the hook is initialized. 
43 | func (h *Hook) Init(config any) error { 44 | if _, ok := config.(*Options); !ok && config != nil { 45 | return mqtt.ErrInvalidConfigType 46 | } 47 | 48 | if config == nil { 49 | config = new(Options) 50 | } 51 | 52 | h.config = config.(*Options) 53 | 54 | return nil 55 | } 56 | 57 | // SetOpts is called when the hook receives inheritable server parameters. 58 | func (h *Hook) SetOpts(l *slog.Logger, opts *mqtt.HookOptions) { 59 | h.Log = l 60 | h.Log.Debug("", "method", "SetOpts") 61 | } 62 | 63 | // Stop is called when the hook is stopped. 64 | func (h *Hook) Stop() error { 65 | h.Log.Debug("", "method", "Stop") 66 | return nil 67 | } 68 | 69 | // OnStarted is called when the server starts. 70 | func (h *Hook) OnStarted() { 71 | h.Log.Debug("", "method", "OnStarted") 72 | } 73 | 74 | // OnStopped is called when the server stops. 75 | func (h *Hook) OnStopped() { 76 | h.Log.Debug("", "method", "OnStopped") 77 | } 78 | 79 | // OnPacketRead is called when a new packet is received from a client. 80 | func (h *Hook) OnPacketRead(cl *mqtt.Client, pk packets.Packet) (packets.Packet, error) { 81 | if (pk.FixedHeader.Type == packets.Pingresp || pk.FixedHeader.Type == packets.Pingreq) && !h.config.ShowPings { 82 | return pk, nil 83 | } 84 | 85 | h.Log.Debug(fmt.Sprintf("%s << %s", strings.ToUpper(packets.PacketNames[pk.FixedHeader.Type]), cl.ID), "m", h.packetMeta(pk)) 86 | return pk, nil 87 | } 88 | 89 | // OnPacketSent is called when a packet is sent to a client. 90 | func (h *Hook) OnPacketSent(cl *mqtt.Client, pk packets.Packet, b []byte) { 91 | if (pk.FixedHeader.Type == packets.Pingresp || pk.FixedHeader.Type == packets.Pingreq) && !h.config.ShowPings { 92 | return 93 | } 94 | 95 | h.Log.Debug(fmt.Sprintf("%s >> %s", strings.ToUpper(packets.PacketNames[pk.FixedHeader.Type]), cl.ID), "m", h.packetMeta(pk)) 96 | } 97 | 98 | // OnRetainMessage is called when a published message is retained (or retain deleted/modified). 
99 | func (h *Hook) OnRetainMessage(cl *mqtt.Client, pk packets.Packet, r int64) { 100 | h.Log.Debug("retained message on topic", "m", h.packetMeta(pk)) 101 | } 102 | 103 | // OnQosPublish is called when a publish packet with Qos is issued to a subscriber. 104 | func (h *Hook) OnQosPublish(cl *mqtt.Client, pk packets.Packet, sent int64, resends int) { 105 | h.Log.Debug("inflight out", "m", h.packetMeta(pk)) 106 | } 107 | 108 | // OnQosComplete is called when the Qos flow for a message has been completed. 109 | func (h *Hook) OnQosComplete(cl *mqtt.Client, pk packets.Packet) { 110 | h.Log.Debug("inflight complete", "m", h.packetMeta(pk)) 111 | } 112 | 113 | // OnQosDropped is called the Qos flow for a message expires. 114 | func (h *Hook) OnQosDropped(cl *mqtt.Client, pk packets.Packet) { 115 | h.Log.Debug("inflight dropped", "m", h.packetMeta(pk)) 116 | } 117 | 118 | // OnLWTSent is called when a Will Message has been issued from a disconnecting client. 119 | func (h *Hook) OnLWTSent(cl *mqtt.Client, pk packets.Packet) { 120 | h.Log.Debug("sent lwt for client", "method", "OnLWTSent", "client", cl.ID) 121 | } 122 | 123 | // OnRetainedExpired is called when the server clears expired retained messages. 124 | func (h *Hook) OnRetainedExpired(filter string) { 125 | h.Log.Debug("retained message expired", "method", "OnRetainedExpired", "topic", filter) 126 | } 127 | 128 | // OnClientExpired is called when the server clears an expired client. 129 | func (h *Hook) OnClientExpired(cl *mqtt.Client) { 130 | h.Log.Debug("client session expired", "method", "OnClientExpired", "client", cl.ID) 131 | } 132 | 133 | // StoredClients is called when the server restores clients from a store. 134 | func (h *Hook) StoredClients() (v []storage.Client, err error) { 135 | h.Log.Debug("", "method", "StoredClients") 136 | 137 | return v, nil 138 | } 139 | 140 | // StoredSubscriptions is called when the server restores subscriptions from a store. 
141 | func (h *Hook) StoredSubscriptions() (v []storage.Subscription, err error) { 142 | h.Log.Debug("", "method", "StoredSubscriptions") 143 | return v, nil 144 | } 145 | 146 | // StoredRetainedMessages is called when the server restores retained messages from a store. 147 | func (h *Hook) StoredRetainedMessages() (v []storage.Message, err error) { 148 | h.Log.Debug("", "method", "StoredRetainedMessages") 149 | return v, nil 150 | } 151 | 152 | // StoredInflightMessages is called when the server restores inflight messages from a store. 153 | func (h *Hook) StoredInflightMessages() (v []storage.Message, err error) { 154 | h.Log.Debug("", "method", "StoredInflightMessages") 155 | return v, nil 156 | } 157 | 158 | // StoredSysInfo is called when the server restores system info from a store. 159 | func (h *Hook) StoredSysInfo() (v storage.SystemInfo, err error) { 160 | h.Log.Debug("", "method", "StoredSysInfo") 161 | 162 | return v, nil 163 | } 164 | 165 | // packetMeta adds additional type-specific metadata to the debug logs. 
166 | func (h *Hook) packetMeta(pk packets.Packet) map[string]any { 167 | m := map[string]any{} 168 | switch pk.FixedHeader.Type { 169 | case packets.Connect: 170 | m["id"] = pk.Connect.ClientIdentifier 171 | m["clean"] = pk.Connect.Clean 172 | m["keepalive"] = pk.Connect.Keepalive 173 | m["version"] = pk.ProtocolVersion 174 | m["username"] = string(pk.Connect.Username) 175 | if h.config.ShowPasswords { 176 | m["password"] = string(pk.Connect.Password) 177 | } 178 | if pk.Connect.WillFlag { 179 | m["will_topic"] = pk.Connect.WillTopic 180 | m["will_payload"] = string(pk.Connect.WillPayload) 181 | } 182 | case packets.Publish: 183 | m["topic"] = pk.TopicName 184 | m["payload"] = string(pk.Payload) 185 | m["raw"] = pk.Payload 186 | m["qos"] = pk.FixedHeader.Qos 187 | m["id"] = pk.PacketID 188 | case packets.Connack: 189 | fallthrough 190 | case packets.Disconnect: 191 | fallthrough 192 | case packets.Puback: 193 | fallthrough 194 | case packets.Pubrec: 195 | fallthrough 196 | case packets.Pubrel: 197 | fallthrough 198 | case packets.Pubcomp: 199 | m["id"] = pk.PacketID 200 | m["reason"] = int(pk.ReasonCode) 201 | if pk.ReasonCode > packets.CodeSuccess.Code && pk.ProtocolVersion == 5 { 202 | m["reason_string"] = pk.Properties.ReasonString 203 | } 204 | case packets.Subscribe: 205 | f := map[string]int{} 206 | ids := map[string]int{} 207 | for _, v := range pk.Filters { 208 | f[v.Filter] = int(v.Qos) 209 | ids[v.Filter] = v.Identifier 210 | } 211 | m["filters"] = f 212 | m["subids"] = f 213 | 214 | case packets.Unsubscribe: 215 | f := []string{} 216 | for _, v := range pk.Filters { 217 | f = append(f, v.Filter) 218 | } 219 | m["filters"] = f 220 | case packets.Suback: 221 | fallthrough 222 | case packets.Unsuback: 223 | r := []int{} 224 | for _, v := range pk.ReasonCodes { 225 | r = append(r, int(v)) 226 | } 227 | m["reasons"] = r 228 | case packets.Auth: 229 | // tbd 230 | } 231 | 232 | if h.config.ShowPacketData { 233 | m["packet"] = pk 234 | } 235 | 236 | return m 
237 | } 238 | -------------------------------------------------------------------------------- /hooks/storage/storage.go: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: MIT 2 | // SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co 3 | // SPDX-FileContributor: mochi-co 4 | 5 | package storage 6 | 7 | import ( 8 | "encoding/json" 9 | "errors" 10 | 11 | "github.com/mochi-mqtt/server/v2/packets" 12 | "github.com/mochi-mqtt/server/v2/system" 13 | ) 14 | 15 | const ( 16 | SubscriptionKey = "SUB" // unique key to denote Subscriptions in a store 17 | SysInfoKey = "SYS" // unique key to denote server system information in a store 18 | RetainedKey = "RET" // unique key to denote retained messages in a store 19 | InflightKey = "IFM" // unique key to denote inflight messages in a store 20 | ClientKey = "CL" // unique key to denote clients in a store 21 | ) 22 | 23 | var ( 24 | // ErrDBFileNotOpen indicates that the file database (e.g. bolt/badger) wasn't open for reading. 25 | ErrDBFileNotOpen = errors.New("db file not open") 26 | ) 27 | 28 | // Serializable is an interface for objects that can be serialized and deserialized. 29 | type Serializable interface { 30 | UnmarshalBinary([]byte) error 31 | MarshalBinary() (data []byte, err error) 32 | } 33 | 34 | // Client is a storable representation of an MQTT client. 
35 | type Client struct { 36 | Will ClientWill `json:"will"` // will topic and payload data if applicable 37 | Properties ClientProperties `json:"properties"` // the connect properties for the client 38 | Username []byte `json:"username"` // the username of the client 39 | ID string `json:"id" storm:"id"` // the client id / storage key 40 | T string `json:"t"` // the data type (client) 41 | Remote string `json:"remote"` // the remote address of the client 42 | Listener string `json:"listener"` // the listener the client connected on 43 | ProtocolVersion byte `json:"protocolVersion"` // mqtt protocol version of the client 44 | Clean bool `json:"clean"` // if the client requested a clean start/session 45 | } 46 | 47 | // ClientProperties contains a limited set of the mqtt v5 properties specific to a client connection. 48 | type ClientProperties struct { 49 | AuthenticationData []byte `json:"authenticationData,omitempty"` 50 | User []packets.UserProperty `json:"user,omitempty"` 51 | AuthenticationMethod string `json:"authenticationMethod,omitempty"` 52 | SessionExpiryInterval uint32 `json:"sessionExpiryInterval,omitempty"` 53 | MaximumPacketSize uint32 `json:"maximumPacketSize,omitempty"` 54 | ReceiveMaximum uint16 `json:"receiveMaximum,omitempty"` 55 | TopicAliasMaximum uint16 `json:"topicAliasMaximum,omitempty"` 56 | SessionExpiryIntervalFlag bool `json:"sessionExpiryIntervalFlag,omitempty"` 57 | RequestProblemInfo byte `json:"requestProblemInfo,omitempty"` 58 | RequestProblemInfoFlag bool `json:"requestProblemInfoFlag,omitempty"` 59 | RequestResponseInfo byte `json:"requestResponseInfo,omitempty"` 60 | } 61 | 62 | // ClientWill contains a will message for a client, and limited mqtt v5 properties. 
63 | type ClientWill struct { 64 | Payload []byte `json:"payload,omitempty"` 65 | User []packets.UserProperty `json:"user,omitempty"` 66 | TopicName string `json:"topicName,omitempty"` 67 | Flag uint32 `json:"flag,omitempty"` 68 | WillDelayInterval uint32 `json:"willDelayInterval,omitempty"` 69 | Qos byte `json:"qos,omitempty"` 70 | Retain bool `json:"retain,omitempty"` 71 | } 72 | 73 | // MarshalBinary encodes the values into a json string. 74 | func (d Client) MarshalBinary() (data []byte, err error) { 75 | return json.Marshal(d) 76 | } 77 | 78 | // UnmarshalBinary decodes a json string into a struct. 79 | func (d *Client) UnmarshalBinary(data []byte) error { 80 | if len(data) == 0 { 81 | return nil 82 | } 83 | return json.Unmarshal(data, d) 84 | } 85 | 86 | // Message is a storable representation of an MQTT message (specifically publish). 87 | type Message struct { 88 | Properties MessageProperties `json:"properties"` // - 89 | Payload []byte `json:"payload"` // the message payload (if retained) 90 | T string `json:"t,omitempty"` // the data type 91 | ID string `json:"id,omitempty" storm:"id"` // the storage key 92 | Client string `json:"client,omitempty"` // the client id the message is for 93 | Origin string `json:"origin,omitempty"` // the id of the client who sent the message 94 | TopicName string `json:"topic_name,omitempty"` // the topic the message was sent to (if retained) 95 | FixedHeader packets.FixedHeader `json:"fixedheader"` // the header properties of the message 96 | Created int64 `json:"created,omitempty"` // the time the message was created in unixtime 97 | Sent int64 `json:"sent,omitempty"` // the last time the message was sent (for retries) in unixtime (if inflight) 98 | PacketID uint16 `json:"packet_id,omitempty"` // the unique id of the packet (if inflight) 99 | } 100 | 101 | // MessageProperties contains a limited subset of mqtt v5 properties specific to publish messages. 
102 | type MessageProperties struct { 103 | CorrelationData []byte `json:"correlationData,omitempty"` 104 | SubscriptionIdentifier []int `json:"subscriptionIdentifier,omitempty"` 105 | User []packets.UserProperty `json:"user,omitempty"` 106 | ContentType string `json:"contentType,omitempty"` 107 | ResponseTopic string `json:"responseTopic,omitempty"` 108 | MessageExpiryInterval uint32 `json:"messageExpiry,omitempty"` 109 | TopicAlias uint16 `json:"topicAlias,omitempty"` 110 | PayloadFormat byte `json:"payloadFormat,omitempty"` 111 | PayloadFormatFlag bool `json:"payloadFormatFlag,omitempty"` 112 | } 113 | 114 | // MarshalBinary encodes the values into a json string. 115 | func (d Message) MarshalBinary() (data []byte, err error) { 116 | return json.Marshal(d) 117 | } 118 | 119 | // UnmarshalBinary decodes a json string into a struct. 120 | func (d *Message) UnmarshalBinary(data []byte) error { 121 | if len(data) == 0 { 122 | return nil 123 | } 124 | return json.Unmarshal(data, d) 125 | } 126 | 127 | // ToPacket converts a storage.Message to a standard packet. 128 | func (d *Message) ToPacket() packets.Packet { 129 | pk := packets.Packet{ 130 | FixedHeader: d.FixedHeader, 131 | PacketID: d.PacketID, 132 | TopicName: d.TopicName, 133 | Payload: d.Payload, 134 | Origin: d.Origin, 135 | Created: d.Created, 136 | Properties: packets.Properties{ 137 | PayloadFormat: d.Properties.PayloadFormat, 138 | PayloadFormatFlag: d.Properties.PayloadFormatFlag, 139 | MessageExpiryInterval: d.Properties.MessageExpiryInterval, 140 | ContentType: d.Properties.ContentType, 141 | ResponseTopic: d.Properties.ResponseTopic, 142 | CorrelationData: d.Properties.CorrelationData, 143 | SubscriptionIdentifier: d.Properties.SubscriptionIdentifier, 144 | TopicAlias: d.Properties.TopicAlias, 145 | User: d.Properties.User, 146 | }, 147 | } 148 | 149 | // Return a deep copy of the packet data otherwise the slices will 150 | // continue pointing at the values from the storage packet. 
151 | pk = pk.Copy(true) 152 | pk.FixedHeader.Dup = d.FixedHeader.Dup 153 | 154 | return pk 155 | } 156 | 157 | // Subscription is a storable representation of an MQTT subscription. 158 | type Subscription struct { 159 | T string `json:"t,omitempty"` 160 | ID string `json:"id,omitempty" storm:"id"` 161 | Client string `json:"client,omitempty"` 162 | Filter string `json:"filter"` 163 | Identifier int `json:"identifier,omitempty"` 164 | RetainHandling byte `json:"retain_handling,omitempty"` 165 | Qos byte `json:"qos"` 166 | RetainAsPublished bool `json:"retain_as_pub,omitempty"` 167 | NoLocal bool `json:"no_local,omitempty"` 168 | } 169 | 170 | // MarshalBinary encodes the values into a json string. 171 | func (d Subscription) MarshalBinary() (data []byte, err error) { 172 | return json.Marshal(d) 173 | } 174 | 175 | // UnmarshalBinary decodes a json string into a struct. 176 | func (d *Subscription) UnmarshalBinary(data []byte) error { 177 | if len(data) == 0 { 178 | return nil 179 | } 180 | return json.Unmarshal(data, d) 181 | } 182 | 183 | // SystemInfo is a storable representation of the system information values. 184 | type SystemInfo struct { 185 | system.Info // embed the system info struct 186 | T string `json:"t"` // the data type 187 | ID string `json:"id" storm:"id"` // the storage key 188 | } 189 | 190 | // MarshalBinary encodes the values into a json string. 191 | func (d SystemInfo) MarshalBinary() (data []byte, err error) { 192 | return json.Marshal(d) 193 | } 194 | 195 | // UnmarshalBinary decodes a json string into a struct. 
func (d *SystemInfo) UnmarshalBinary(data []byte) error {
	// Empty input is treated as a no-op: d is left unchanged and no error is returned.
	if len(data) == 0 {
		return nil
	}
	return json.Unmarshal(data, d)
}
--------------------------------------------------------------------------------
/hooks/storage/storage_test.go:
--------------------------------------------------------------------------------
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
// SPDX-FileContributor: mochi-co

package storage

import (
	"testing"
	"time"

	"github.com/mochi-mqtt/server/v2/packets"
	"github.com/mochi-mqtt/server/v2/system"
	"github.com/stretchr/testify/require"
)

// Fixture pairs: each *Struct value and its expected JSON encoding. []byte
// fields appear base64-encoded in the JSON (e.g. "YWJj" is "abc").
var (
	clientStruct = Client{
		ID:       "test",
		T:        "client",
		Remote:   "remote",
		Listener: "listener",
		Username: []byte("mochi"),
		Clean:    true,
		Properties: ClientProperties{
			SessionExpiryInterval:     2,
			SessionExpiryIntervalFlag: true,
			AuthenticationMethod:      "a",
			AuthenticationData:        []byte("test"),
			RequestProblemInfo:        1,
			RequestProblemInfoFlag:    true,
			RequestResponseInfo:       1,
			ReceiveMaximum:            128,
			TopicAliasMaximum:         256,
			User: []packets.UserProperty{
				{Key: "k", Val: "v"},
			},
			MaximumPacketSize: 120,
		},
		Will: ClientWill{
			Qos:               1,
			Payload:           []byte("abc"),
			TopicName:         "a/b/c",
			Flag:              1,
			Retain:            true,
			WillDelayInterval: 2,
			User: []packets.UserProperty{
				{Key: "k2", Val: "v2"},
			},
		},
	}
	clientJSON = []byte(`{"will":{"payload":"YWJj","user":[{"k":"k2","v":"v2"}],"topicName":"a/b/c","flag":1,"willDelayInterval":2,"qos":1,"retain":true},"properties":{"authenticationData":"dGVzdA==","user":[{"k":"k","v":"v"}],"authenticationMethod":"a","sessionExpiryInterval":2,"maximumPacketSize":120,"receiveMaximum":128,"topicAliasMaximum":256,"sessionExpiryIntervalFlag":true,"requestProblemInfo":1,"requestProblemInfoFlag":true,"requestResponseInfo":1},"username":"bW9jaGk=","id":"test","t":"client","remote":"remote","listener":"listener","protocolVersion":0,"clean":true}`)

	messageStruct = Message{
		T:       "message",
		Payload: []byte("payload"),
		FixedHeader: packets.FixedHeader{
			Remaining: 2,
			Type:      3,
			Qos:       1,
			Dup:       true,
			Retain:    true,
		},
		ID:        "id",
		Origin:    "mochi",
		TopicName: "topic",
		Properties: MessageProperties{
			PayloadFormat:          1,
			PayloadFormatFlag:      true,
			MessageExpiryInterval:  20,
			ContentType:            "type",
			ResponseTopic:          "a/b/r",
			CorrelationData:        []byte("r"),
			SubscriptionIdentifier: []int{1},
			TopicAlias:             2,
			User: []packets.UserProperty{
				{Key: "k2", Val: "v2"},
			},
		},
		Created:  time.Date(2019, time.September, 21, 1, 2, 3, 4, time.UTC).Unix(),
		Sent:     time.Date(2019, time.September, 21, 1, 2, 3, 4, time.UTC).Unix(),
		PacketID: 100,
	}
	messageJSON = []byte(`{"properties":{"correlationData":"cg==","subscriptionIdentifier":[1],"user":[{"k":"k2","v":"v2"}],"contentType":"type","responseTopic":"a/b/r","messageExpiry":20,"topicAlias":2,"payloadFormat":1,"payloadFormatFlag":true},"payload":"cGF5bG9hZA==","t":"message","id":"id","origin":"mochi","topic_name":"topic","fixedheader":{"remaining":2,"type":3,"qos":1,"dup":true,"retain":true},"created":1569027723,"sent":1569027723,"packet_id":100}`)

	subscriptionStruct = Subscription{
		T:      "subscription",
		ID:     "id",
		Client: "mochi",
		Filter: "a/b/c",
		Qos:    1,
	}
	subscriptionJSON = []byte(`{"t":"subscription","id":"id","client":"mochi","filter":"a/b/c","qos":1}`)

	sysInfoStruct = SystemInfo{
		T:  "info",
		ID: "id",
		Info: system.Info{
			Version:          "2.0.0",
			Started:          1,
			Uptime:           2,
			BytesReceived:    3,
			BytesSent:        4,
			ClientsConnected: 5,
			ClientsMaximum:   7,
			MessagesReceived: 10,
			MessagesSent:     11,
			MessagesDropped:  20,
			PacketsReceived:  12,
			PacketsSent:      13,
			Retained:         15,
			Inflight:         16,
			InflightDropped:  17,
		},
	}
	sysInfoJSON = []byte(`{"version":"2.0.0","started":1,"time":0,"uptime":2,"bytes_received":3,"bytes_sent":4,"clients_connected":5,"clients_disconnected":0,"clients_maximum":7,"clients_total":0,"messages_received":10,"messages_sent":11,"messages_dropped":20,"retained":15,"inflight":16,"inflight_dropped":17,"subscriptions":0,"packets_received":12,"packets_sent":13,"memory_alloc":0,"threads":0,"t":"info","id":"id"}`)
)

// Marshal tests use JSONEq so field ordering in the encoding does not matter.
func TestClientMarshalBinary(t *testing.T) {
	data, err := clientStruct.MarshalBinary()
	require.NoError(t, err)
	require.JSONEq(t, string(clientJSON), string(data))
}

func TestClientUnmarshalBinary(t *testing.T) {
	d := clientStruct
	err := d.UnmarshalBinary(clientJSON)
	require.NoError(t, err)
	require.Equal(t, clientStruct, d)
}

// Empty input must leave the zero value untouched.
func TestClientUnmarshalBinaryEmpty(t *testing.T) {
	d := Client{}
	err := d.UnmarshalBinary([]byte{})
	require.NoError(t, err)
	require.Equal(t, Client{}, d)
}

func TestMessageMarshalBinary(t *testing.T) {
	data, err := messageStruct.MarshalBinary()
	require.NoError(t, err)
	require.JSONEq(t, string(messageJSON), string(data))
}

func TestMessageUnmarshalBinary(t *testing.T) {
	d := messageStruct
	err := d.UnmarshalBinary(messageJSON)
	require.NoError(t, err)
	require.Equal(t, messageStruct, d)
}

func TestMessageUnmarshalBinaryEmpty(t *testing.T) {
	d := Message{}
	err := d.UnmarshalBinary([]byte{})
	require.NoError(t, err)
	require.Equal(t, Message{}, d)
}

func TestSubscriptionMarshalBinary(t *testing.T) {
	data, err := subscriptionStruct.MarshalBinary()
	require.NoError(t, err)
	require.JSONEq(t, string(subscriptionJSON), string(data))
}

func TestSubscriptionUnmarshalBinary(t *testing.T) {
	d := subscriptionStruct
	err := d.UnmarshalBinary(subscriptionJSON)
	require.NoError(t, err)
	require.Equal(t, subscriptionStruct, d)
}

func TestSubscriptionUnmarshalBinaryEmpty(t *testing.T) {
	d := Subscription{}
	err := d.UnmarshalBinary([]byte{})
	require.NoError(t, err)
	require.Equal(t, Subscription{}, d)
}

func TestSysInfoMarshalBinary(t *testing.T) {
	data, err := sysInfoStruct.MarshalBinary()
	require.NoError(t, err)
	require.JSONEq(t, string(sysInfoJSON), string(data))
}

func TestSysInfoUnmarshalBinary(t *testing.T) {
	d := sysInfoStruct
	err := d.UnmarshalBinary(sysInfoJSON)
	require.NoError(t, err)
	require.Equal(t, sysInfoStruct, d)
}

func TestSysInfoUnmarshalBinaryEmpty(t *testing.T) {
	d := SystemInfo{}
	err := d.UnmarshalBinary([]byte{})
	require.NoError(t, err)
	require.Equal(t, SystemInfo{}, d)
}

// ToPacket must carry over the fixed header (including Dup), ids, payload and
// the publish properties; Sent and Client are not part of the packet.
func TestMessageToPacket(t *testing.T) {
	d := messageStruct
	pk := d.ToPacket()

	require.Equal(t, packets.Packet{
		Payload: []byte("payload"),
		FixedHeader: packets.FixedHeader{
			Remaining: d.FixedHeader.Remaining,
			Type:      d.FixedHeader.Type,
			Qos:       d.FixedHeader.Qos,
			Dup:       d.FixedHeader.Dup,
			Retain:    d.FixedHeader.Retain,
		},
		Origin:    d.Origin,
		TopicName: d.TopicName,
		Properties: packets.Properties{
			PayloadFormat:          d.Properties.PayloadFormat,
			PayloadFormatFlag:      d.Properties.PayloadFormatFlag,
			MessageExpiryInterval:  d.Properties.MessageExpiryInterval,
			ContentType:            d.Properties.ContentType,
			ResponseTopic:          d.Properties.ResponseTopic,
			CorrelationData:        d.Properties.CorrelationData,
			SubscriptionIdentifier: d.Properties.SubscriptionIdentifier,
			TopicAlias:             d.Properties.TopicAlias,
			User:                   d.Properties.User,
		},
		PacketID: 100,
		Created:  d.Created,
	}, pk)

}
--------------------------------------------------------------------------------
/inflight.go:
--------------------------------------------------------------------------------
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2023 mochi-mqtt, mochi-co
// SPDX-FileContributor: mochi-co

package mqtt

import (
	"sort"
	"sync"
	"sync/atomic"

	"github.com/mochi-mqtt/server/v2/packets"
)

// Inflight tracks in-flight packets.Packet values keyed on packet id, together
// with the inbound/outbound flow-control quotas. The map is guarded by the
// embedded RWMutex; the quota fields are accessed through sync/atomic.
type Inflight struct {
	sync.RWMutex
	internal            map[uint16]packets.Packet // internal contains the inflight packets
	receiveQuota        int32                     // remaining inbound qos quota for flow control
	sendQuota           int32                     // remaining outbound qos quota for flow control
	maximumReceiveQuota int32                     // maximum allowed receive quota
	maximumSendQuota    int32                     // maximum allowed send quota
}

// NewInflights returns a new instance of an Inflight packets map.
// Both quotas start at zero until ResetReceiveQuota/ResetSendQuota are called.
func NewInflights() *Inflight {
	return &Inflight{
		internal: map[uint16]packets.Packet{},
	}
}

// Set adds or updates an inflight packet by packet id.
// It returns true when the id was not already present (a new entry), and false
// when an existing entry was overwritten.
func (i *Inflight) Set(m packets.Packet) bool {
	i.Lock()
	defer i.Unlock()

	_, ok := i.internal[m.PacketID]
	i.internal[m.PacketID] = m
	return !ok
}

// Get returns an inflight packet by packet id.
43 | func (i *Inflight) Get(id uint16) (packets.Packet, bool) { 44 | i.RLock() 45 | defer i.RUnlock() 46 | 47 | if m, ok := i.internal[id]; ok { 48 | return m, true 49 | } 50 | 51 | return packets.Packet{}, false 52 | } 53 | 54 | // Len returns the size of the inflight messages map. 55 | func (i *Inflight) Len() int { 56 | i.RLock() 57 | defer i.RUnlock() 58 | return len(i.internal) 59 | } 60 | 61 | // Clone returns a new instance of Inflight with the same message data. 62 | // This is used when transferring inflights from a taken-over session. 63 | func (i *Inflight) Clone() *Inflight { 64 | c := NewInflights() 65 | i.RLock() 66 | defer i.RUnlock() 67 | for k, v := range i.internal { 68 | c.internal[k] = v 69 | } 70 | return c 71 | } 72 | 73 | // GetAll returns all the inflight messages. 74 | func (i *Inflight) GetAll(immediate bool) []packets.Packet { 75 | i.RLock() 76 | defer i.RUnlock() 77 | 78 | m := []packets.Packet{} 79 | for _, v := range i.internal { 80 | if !immediate || (immediate && v.Expiry < 0) { 81 | m = append(m, v) 82 | } 83 | } 84 | 85 | sort.Slice(m, func(i, j int) bool { 86 | return uint16(m[i].Created) < uint16(m[j].Created) 87 | }) 88 | 89 | return m 90 | } 91 | 92 | // NextImmediate returns the next inflight packet which is indicated to be sent immediately. 93 | // This typically occurs when the quota has been exhausted, and we need to wait until new quota 94 | // is free to continue sending. 95 | func (i *Inflight) NextImmediate() (packets.Packet, bool) { 96 | i.RLock() 97 | defer i.RUnlock() 98 | 99 | m := i.GetAll(true) 100 | if len(m) > 0 { 101 | return m[0], true 102 | } 103 | 104 | return packets.Packet{}, false 105 | } 106 | 107 | // Delete removes an in-flight message from the map. Returns true if the message existed. 
108 | func (i *Inflight) Delete(id uint16) bool { 109 | i.Lock() 110 | defer i.Unlock() 111 | 112 | _, ok := i.internal[id] 113 | delete(i.internal, id) 114 | 115 | return ok 116 | } 117 | 118 | // TakeRecieveQuota reduces the receive quota by 1. 119 | func (i *Inflight) DecreaseReceiveQuota() { 120 | if atomic.LoadInt32(&i.receiveQuota) > 0 { 121 | atomic.AddInt32(&i.receiveQuota, -1) 122 | } 123 | } 124 | 125 | // TakeRecieveQuota increases the receive quota by 1. 126 | func (i *Inflight) IncreaseReceiveQuota() { 127 | if atomic.LoadInt32(&i.receiveQuota) < atomic.LoadInt32(&i.maximumReceiveQuota) { 128 | atomic.AddInt32(&i.receiveQuota, 1) 129 | } 130 | } 131 | 132 | // ResetReceiveQuota resets the receive quota to the maximum allowed value. 133 | func (i *Inflight) ResetReceiveQuota(n int32) { 134 | atomic.StoreInt32(&i.receiveQuota, n) 135 | atomic.StoreInt32(&i.maximumReceiveQuota, n) 136 | } 137 | 138 | // DecreaseSendQuota reduces the send quota by 1. 139 | func (i *Inflight) DecreaseSendQuota() { 140 | if atomic.LoadInt32(&i.sendQuota) > 0 { 141 | atomic.AddInt32(&i.sendQuota, -1) 142 | } 143 | } 144 | 145 | // IncreaseSendQuota increases the send quota by 1. 146 | func (i *Inflight) IncreaseSendQuota() { 147 | if atomic.LoadInt32(&i.sendQuota) < atomic.LoadInt32(&i.maximumSendQuota) { 148 | atomic.AddInt32(&i.sendQuota, 1) 149 | } 150 | } 151 | 152 | // ResetSendQuota resets the send quota to the maximum allowed value. 
153 | func (i *Inflight) ResetSendQuota(n int32) { 154 | atomic.StoreInt32(&i.sendQuota, n) 155 | atomic.StoreInt32(&i.maximumSendQuota, n) 156 | } 157 | -------------------------------------------------------------------------------- /inflight_test.go: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: MIT 2 | // SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co 3 | // SPDX-FileContributor: mochi-co 4 | 5 | package mqtt 6 | 7 | import ( 8 | "sync/atomic" 9 | "testing" 10 | 11 | "github.com/mochi-mqtt/server/v2/packets" 12 | "github.com/stretchr/testify/require" 13 | ) 14 | 15 | func TestInflightSet(t *testing.T) { 16 | cl, _, _ := newTestClient() 17 | 18 | r := cl.State.Inflight.Set(packets.Packet{PacketID: 1}) 19 | require.True(t, r) 20 | require.NotNil(t, cl.State.Inflight.internal[1]) 21 | require.NotEqual(t, 0, cl.State.Inflight.internal[1].PacketID) 22 | 23 | r = cl.State.Inflight.Set(packets.Packet{PacketID: 1}) 24 | require.False(t, r) 25 | } 26 | 27 | func TestInflightGet(t *testing.T) { 28 | cl, _, _ := newTestClient() 29 | cl.State.Inflight.Set(packets.Packet{PacketID: 2}) 30 | 31 | msg, ok := cl.State.Inflight.Get(2) 32 | require.True(t, ok) 33 | require.NotEqual(t, 0, msg.PacketID) 34 | } 35 | 36 | func TestInflightGetAllAndImmediate(t *testing.T) { 37 | cl, _, _ := newTestClient() 38 | cl.State.Inflight.Set(packets.Packet{PacketID: 1, Created: 1}) 39 | cl.State.Inflight.Set(packets.Packet{PacketID: 2, Created: 2}) 40 | cl.State.Inflight.Set(packets.Packet{PacketID: 3, Created: 3, Expiry: -1}) 41 | cl.State.Inflight.Set(packets.Packet{PacketID: 4, Created: 4, Expiry: -1}) 42 | cl.State.Inflight.Set(packets.Packet{PacketID: 5, Created: 5}) 43 | 44 | require.Equal(t, []packets.Packet{ 45 | {PacketID: 1, Created: 1}, 46 | {PacketID: 2, Created: 2}, 47 | {PacketID: 3, Created: 3, Expiry: -1}, 48 | {PacketID: 4, Created: 4, Expiry: -1}, 49 | {PacketID: 5, Created: 5}, 50 | }, 
cl.State.Inflight.GetAll(false)) 51 | 52 | require.Equal(t, []packets.Packet{ 53 | {PacketID: 3, Created: 3, Expiry: -1}, 54 | {PacketID: 4, Created: 4, Expiry: -1}, 55 | }, cl.State.Inflight.GetAll(true)) 56 | } 57 | 58 | func TestInflightLen(t *testing.T) { 59 | cl, _, _ := newTestClient() 60 | cl.State.Inflight.Set(packets.Packet{PacketID: 2}) 61 | require.Equal(t, 1, cl.State.Inflight.Len()) 62 | } 63 | 64 | func TestInflightClone(t *testing.T) { 65 | cl, _, _ := newTestClient() 66 | cl.State.Inflight.Set(packets.Packet{PacketID: 2}) 67 | require.Equal(t, 1, cl.State.Inflight.Len()) 68 | 69 | cloned := cl.State.Inflight.Clone() 70 | require.NotNil(t, cloned) 71 | require.NotSame(t, cloned, cl.State.Inflight) 72 | } 73 | 74 | func TestInflightDelete(t *testing.T) { 75 | cl, _, _ := newTestClient() 76 | 77 | cl.State.Inflight.Set(packets.Packet{PacketID: 3}) 78 | require.NotNil(t, cl.State.Inflight.internal[3]) 79 | 80 | r := cl.State.Inflight.Delete(3) 81 | require.True(t, r) 82 | require.Equal(t, uint16(0), cl.State.Inflight.internal[3].PacketID) 83 | 84 | _, ok := cl.State.Inflight.Get(3) 85 | require.False(t, ok) 86 | 87 | r = cl.State.Inflight.Delete(3) 88 | require.False(t, r) 89 | } 90 | 91 | func TestResetReceiveQuota(t *testing.T) { 92 | i := NewInflights() 93 | require.Equal(t, int32(0), atomic.LoadInt32(&i.maximumReceiveQuota)) 94 | require.Equal(t, int32(0), atomic.LoadInt32(&i.receiveQuota)) 95 | i.ResetReceiveQuota(6) 96 | require.Equal(t, int32(6), atomic.LoadInt32(&i.maximumReceiveQuota)) 97 | require.Equal(t, int32(6), atomic.LoadInt32(&i.receiveQuota)) 98 | } 99 | 100 | func TestReceiveQuota(t *testing.T) { 101 | i := NewInflights() 102 | i.receiveQuota = 4 103 | i.maximumReceiveQuota = 5 104 | require.Equal(t, int32(5), atomic.LoadInt32(&i.maximumReceiveQuota)) 105 | require.Equal(t, int32(4), atomic.LoadInt32(&i.receiveQuota)) 106 | 107 | // Return 1 108 | i.IncreaseReceiveQuota() 109 | require.Equal(t, int32(5), 
atomic.LoadInt32(&i.maximumReceiveQuota)) 110 | require.Equal(t, int32(5), atomic.LoadInt32(&i.receiveQuota)) 111 | 112 | // Try to go over max limit 113 | i.IncreaseReceiveQuota() 114 | require.Equal(t, int32(5), atomic.LoadInt32(&i.maximumReceiveQuota)) 115 | require.Equal(t, int32(5), atomic.LoadInt32(&i.receiveQuota)) 116 | 117 | // Reset to max 1 118 | i.ResetReceiveQuota(1) 119 | require.Equal(t, int32(1), atomic.LoadInt32(&i.maximumReceiveQuota)) 120 | require.Equal(t, int32(1), atomic.LoadInt32(&i.receiveQuota)) 121 | 122 | // Take 1 123 | i.DecreaseReceiveQuota() 124 | require.Equal(t, int32(1), atomic.LoadInt32(&i.maximumReceiveQuota)) 125 | require.Equal(t, int32(0), atomic.LoadInt32(&i.receiveQuota)) 126 | 127 | // Try to go below zero 128 | i.DecreaseReceiveQuota() 129 | require.Equal(t, int32(1), atomic.LoadInt32(&i.maximumReceiveQuota)) 130 | require.Equal(t, int32(0), atomic.LoadInt32(&i.receiveQuota)) 131 | } 132 | 133 | func TestResetSendQuota(t *testing.T) { 134 | i := NewInflights() 135 | require.Equal(t, int32(0), atomic.LoadInt32(&i.maximumSendQuota)) 136 | require.Equal(t, int32(0), atomic.LoadInt32(&i.sendQuota)) 137 | i.ResetSendQuota(6) 138 | require.Equal(t, int32(6), atomic.LoadInt32(&i.maximumSendQuota)) 139 | require.Equal(t, int32(6), atomic.LoadInt32(&i.sendQuota)) 140 | } 141 | 142 | func TestSendQuota(t *testing.T) { 143 | i := NewInflights() 144 | i.sendQuota = 4 145 | i.maximumSendQuota = 5 146 | require.Equal(t, int32(5), atomic.LoadInt32(&i.maximumSendQuota)) 147 | require.Equal(t, int32(4), atomic.LoadInt32(&i.sendQuota)) 148 | 149 | // Return 1 150 | i.IncreaseSendQuota() 151 | require.Equal(t, int32(5), atomic.LoadInt32(&i.maximumSendQuota)) 152 | require.Equal(t, int32(5), atomic.LoadInt32(&i.sendQuota)) 153 | 154 | // Try to go over max limit 155 | i.IncreaseSendQuota() 156 | require.Equal(t, int32(5), atomic.LoadInt32(&i.maximumSendQuota)) 157 | require.Equal(t, int32(5), atomic.LoadInt32(&i.sendQuota)) 158 | 159 | // 
Reset to max 1 160 | i.ResetSendQuota(1) 161 | require.Equal(t, int32(1), atomic.LoadInt32(&i.maximumSendQuota)) 162 | require.Equal(t, int32(1), atomic.LoadInt32(&i.sendQuota)) 163 | 164 | // Take 1 165 | i.DecreaseSendQuota() 166 | require.Equal(t, int32(1), atomic.LoadInt32(&i.maximumSendQuota)) 167 | require.Equal(t, int32(0), atomic.LoadInt32(&i.sendQuota)) 168 | 169 | // Try to go below zero 170 | i.DecreaseSendQuota() 171 | require.Equal(t, int32(1), atomic.LoadInt32(&i.maximumSendQuota)) 172 | require.Equal(t, int32(0), atomic.LoadInt32(&i.sendQuota)) 173 | } 174 | 175 | func TestNextImmediate(t *testing.T) { 176 | cl, _, _ := newTestClient() 177 | cl.State.Inflight.Set(packets.Packet{PacketID: 1, Created: 1}) 178 | cl.State.Inflight.Set(packets.Packet{PacketID: 2, Created: 2}) 179 | cl.State.Inflight.Set(packets.Packet{PacketID: 3, Created: 3, Expiry: -1}) 180 | cl.State.Inflight.Set(packets.Packet{PacketID: 4, Created: 4, Expiry: -1}) 181 | cl.State.Inflight.Set(packets.Packet{PacketID: 5, Created: 5}) 182 | 183 | pk, ok := cl.State.Inflight.NextImmediate() 184 | require.True(t, ok) 185 | require.Equal(t, packets.Packet{PacketID: 3, Created: 3, Expiry: -1}, pk) 186 | 187 | r := cl.State.Inflight.Delete(3) 188 | require.True(t, r) 189 | 190 | pk, ok = cl.State.Inflight.NextImmediate() 191 | require.True(t, ok) 192 | require.Equal(t, packets.Packet{PacketID: 4, Created: 4, Expiry: -1}, pk) 193 | 194 | r = cl.State.Inflight.Delete(4) 195 | require.True(t, r) 196 | 197 | _, ok = cl.State.Inflight.NextImmediate() 198 | require.False(t, ok) 199 | } 200 | -------------------------------------------------------------------------------- /listeners/http_healthcheck.go: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: MIT 2 | // SPDX-FileCopyrightText: 2023 mochi-mqtt, mochi-co 3 | // SPDX-FileContributor: Derek Duncan 4 | 5 | package listeners 6 | 7 | import ( 8 | "context" 9 | "log/slog" 10 | 
"net/http" 11 | "sync" 12 | "sync/atomic" 13 | "time" 14 | ) 15 | 16 | const TypeHealthCheck = "healthcheck" 17 | 18 | // HTTPHealthCheck is a listener for providing an HTTP healthcheck endpoint. 19 | type HTTPHealthCheck struct { 20 | sync.RWMutex 21 | id string // the internal id of the listener 22 | address string // the network address to bind to 23 | config Config // configuration values for the listener 24 | listen *http.Server // the http server 25 | end uint32 // ensure the close methods are only called once 26 | } 27 | 28 | // NewHTTPHealthCheck initializes and returns a new HTTP listener, listening on an address. 29 | func NewHTTPHealthCheck(config Config) *HTTPHealthCheck { 30 | return &HTTPHealthCheck{ 31 | id: config.ID, 32 | address: config.Address, 33 | config: config, 34 | } 35 | } 36 | 37 | // ID returns the id of the listener. 38 | func (l *HTTPHealthCheck) ID() string { 39 | return l.id 40 | } 41 | 42 | // Address returns the address of the listener. 43 | func (l *HTTPHealthCheck) Address() string { 44 | return l.address 45 | } 46 | 47 | // Protocol returns the address of the listener. 48 | func (l *HTTPHealthCheck) Protocol() string { 49 | if l.listen != nil && l.listen.TLSConfig != nil { 50 | return "https" 51 | } 52 | 53 | return "http" 54 | } 55 | 56 | // Init initializes the listener. 57 | func (l *HTTPHealthCheck) Init(_ *slog.Logger) error { 58 | mux := http.NewServeMux() 59 | mux.HandleFunc("/healthcheck", func(w http.ResponseWriter, r *http.Request) { 60 | if r.Method != http.MethodGet { 61 | w.WriteHeader(http.StatusMethodNotAllowed) 62 | } 63 | }) 64 | l.listen = &http.Server{ 65 | ReadTimeout: 5 * time.Second, 66 | WriteTimeout: 5 * time.Second, 67 | Addr: l.address, 68 | Handler: mux, 69 | } 70 | 71 | if l.config.TLSConfig != nil { 72 | l.listen.TLSConfig = l.config.TLSConfig 73 | } 74 | 75 | return nil 76 | } 77 | 78 | // Serve starts listening for new connections and serving responses. 
79 | func (l *HTTPHealthCheck) Serve(establish EstablishFn) { 80 | if l.listen.TLSConfig != nil { 81 | _ = l.listen.ListenAndServeTLS("", "") 82 | } else { 83 | _ = l.listen.ListenAndServe() 84 | } 85 | } 86 | 87 | // Close closes the listener and any client connections. 88 | func (l *HTTPHealthCheck) Close(closeClients CloseFn) { 89 | l.Lock() 90 | defer l.Unlock() 91 | 92 | if atomic.CompareAndSwapUint32(&l.end, 0, 1) { 93 | ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) 94 | defer cancel() 95 | _ = l.listen.Shutdown(ctx) 96 | } 97 | 98 | closeClients(l.id) 99 | } 100 | -------------------------------------------------------------------------------- /listeners/http_healthcheck_test.go: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: MIT 2 | // SPDX-FileCopyrightText: 2023 mochi-mqtt, mochi-co 3 | // SPDX-FileContributor: Derek Duncan 4 | 5 | package listeners 6 | 7 | import ( 8 | "io" 9 | "net/http" 10 | "testing" 11 | "time" 12 | 13 | "github.com/stretchr/testify/require" 14 | ) 15 | 16 | func TestNewHTTPHealthCheck(t *testing.T) { 17 | l := NewHTTPHealthCheck(basicConfig) 18 | require.Equal(t, basicConfig.ID, l.id) 19 | require.Equal(t, basicConfig.Address, l.address) 20 | } 21 | 22 | func TestHTTPHealthCheckID(t *testing.T) { 23 | l := NewHTTPHealthCheck(basicConfig) 24 | require.Equal(t, basicConfig.ID, l.ID()) 25 | } 26 | 27 | func TestHTTPHealthCheckAddress(t *testing.T) { 28 | l := NewHTTPHealthCheck(basicConfig) 29 | require.Equal(t, basicConfig.Address, l.Address()) 30 | } 31 | 32 | func TestHTTPHealthCheckProtocol(t *testing.T) { 33 | l := NewHTTPHealthCheck(basicConfig) 34 | require.Equal(t, "http", l.Protocol()) 35 | } 36 | 37 | func TestHTTPHealthCheckTLSProtocol(t *testing.T) { 38 | l := NewHTTPHealthCheck(tlsConfig) 39 | _ = l.Init(logger) 40 | require.Equal(t, "https", l.Protocol()) 41 | } 42 | 43 | func TestHTTPHealthCheckInit(t *testing.T) { 44 | l := 
NewHTTPHealthCheck(basicConfig) 45 | err := l.Init(logger) 46 | require.NoError(t, err) 47 | 48 | require.NotNil(t, l.listen) 49 | require.Equal(t, basicConfig.Address, l.listen.Addr) 50 | } 51 | 52 | func TestHTTPHealthCheckServeAndClose(t *testing.T) { 53 | // setup http stats listener 54 | l := NewHTTPHealthCheck(basicConfig) 55 | err := l.Init(logger) 56 | require.NoError(t, err) 57 | 58 | o := make(chan bool) 59 | go func(o chan bool) { 60 | l.Serve(MockEstablisher) 61 | o <- true 62 | }(o) 63 | 64 | time.Sleep(time.Millisecond) 65 | 66 | // call healthcheck 67 | resp, err := http.Get("http://localhost" + testAddr + "/healthcheck") 68 | require.NoError(t, err) 69 | require.NotNil(t, resp) 70 | 71 | defer resp.Body.Close() 72 | _, err = io.ReadAll(resp.Body) 73 | require.NoError(t, err) 74 | 75 | // ensure listening is closed 76 | var closed bool 77 | l.Close(func(id string) { 78 | closed = true 79 | }) 80 | 81 | require.Equal(t, true, closed) 82 | 83 | _, err = http.Get("http://localhost/healthcheck" + testAddr + "/healthcheck") 84 | require.Error(t, err) 85 | <-o 86 | } 87 | 88 | func TestHTTPHealthCheckServeAndCloseMethodNotAllowed(t *testing.T) { 89 | // setup http stats listener 90 | l := NewHTTPHealthCheck(basicConfig) 91 | err := l.Init(logger) 92 | require.NoError(t, err) 93 | 94 | o := make(chan bool) 95 | go func(o chan bool) { 96 | l.Serve(MockEstablisher) 97 | o <- true 98 | }(o) 99 | 100 | time.Sleep(time.Millisecond) 101 | 102 | // make disallowed method type http request 103 | resp, err := http.Post("http://localhost"+testAddr+"/healthcheck", "application/json", http.NoBody) 104 | require.NoError(t, err) 105 | require.NotNil(t, resp) 106 | 107 | defer resp.Body.Close() 108 | _, err = io.ReadAll(resp.Body) 109 | require.NoError(t, err) 110 | 111 | // ensure listening is closed 112 | var closed bool 113 | l.Close(func(id string) { 114 | closed = true 115 | }) 116 | 117 | require.Equal(t, true, closed) 118 | 119 | _, err = 
http.Post("http://localhost/healthcheck"+testAddr+"/healthcheck", "application/json", http.NoBody) 120 | require.Error(t, err) 121 | <-o 122 | } 123 | 124 | func TestHTTPHealthCheckServeTLSAndClose(t *testing.T) { 125 | l := NewHTTPHealthCheck(tlsConfig) 126 | err := l.Init(logger) 127 | require.NoError(t, err) 128 | 129 | o := make(chan bool) 130 | go func(o chan bool) { 131 | l.Serve(MockEstablisher) 132 | o <- true 133 | }(o) 134 | 135 | time.Sleep(time.Millisecond) 136 | l.Close(MockCloser) 137 | } 138 | -------------------------------------------------------------------------------- /listeners/http_sysinfo.go: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: MIT 2 | // SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co 3 | // SPDX-FileContributor: mochi-co 4 | 5 | package listeners 6 | 7 | import ( 8 | "context" 9 | "encoding/json" 10 | "io" 11 | "log/slog" 12 | "net/http" 13 | "sync" 14 | "sync/atomic" 15 | "time" 16 | 17 | "github.com/mochi-mqtt/server/v2/system" 18 | ) 19 | 20 | const TypeSysInfo = "sysinfo" 21 | 22 | // HTTPStats is a listener for presenting the server $SYS stats on a JSON http endpoint. 23 | type HTTPStats struct { 24 | sync.RWMutex 25 | id string // the internal id of the listener 26 | address string // the network address to bind to 27 | config Config // configuration values for the listener 28 | listen *http.Server // the http server 29 | sysInfo *system.Info // pointers to the server data 30 | log *slog.Logger // server logger 31 | end uint32 // ensure the close methods are only called once 32 | } 33 | 34 | // NewHTTPStats initializes and returns a new HTTP listener, listening on an address. 35 | func NewHTTPStats(config Config, sysInfo *system.Info) *HTTPStats { 36 | return &HTTPStats{ 37 | sysInfo: sysInfo, 38 | id: config.ID, 39 | address: config.Address, 40 | config: config, 41 | } 42 | } 43 | 44 | // ID returns the id of the listener. 
45 | func (l *HTTPStats) ID() string { 46 | return l.id 47 | } 48 | 49 | // Address returns the address of the listener. 50 | func (l *HTTPStats) Address() string { 51 | return l.address 52 | } 53 | 54 | // Protocol returns the address of the listener. 55 | func (l *HTTPStats) Protocol() string { 56 | if l.listen != nil && l.listen.TLSConfig != nil { 57 | return "https" 58 | } 59 | 60 | return "http" 61 | } 62 | 63 | // Init initializes the listener. 64 | func (l *HTTPStats) Init(log *slog.Logger) error { 65 | l.log = log 66 | mux := http.NewServeMux() 67 | mux.HandleFunc("/", l.jsonHandler) 68 | l.listen = &http.Server{ 69 | ReadTimeout: 5 * time.Second, 70 | WriteTimeout: 5 * time.Second, 71 | Addr: l.address, 72 | Handler: mux, 73 | } 74 | 75 | if l.config.TLSConfig != nil { 76 | l.listen.TLSConfig = l.config.TLSConfig 77 | } 78 | 79 | return nil 80 | } 81 | 82 | // Serve starts listening for new connections and serving responses. 83 | func (l *HTTPStats) Serve(establish EstablishFn) { 84 | 85 | var err error 86 | if l.listen.TLSConfig != nil { 87 | err = l.listen.ListenAndServeTLS("", "") 88 | } else { 89 | err = l.listen.ListenAndServe() 90 | } 91 | 92 | // After the listener has been shutdown, no need to print the http.ErrServerClosed error. 93 | if err != nil && atomic.LoadUint32(&l.end) == 0 { 94 | l.log.Error("failed to serve.", "error", err, "listener", l.id) 95 | } 96 | } 97 | 98 | // Close closes the listener and any client connections. 99 | func (l *HTTPStats) Close(closeClients CloseFn) { 100 | l.Lock() 101 | defer l.Unlock() 102 | 103 | if atomic.CompareAndSwapUint32(&l.end, 0, 1) { 104 | ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) 105 | defer cancel() 106 | _ = l.listen.Shutdown(ctx) 107 | } 108 | 109 | closeClients(l.id) 110 | } 111 | 112 | // jsonHandler is an HTTP handler which outputs the $SYS stats as JSON. 
113 | func (l *HTTPStats) jsonHandler(w http.ResponseWriter, req *http.Request) { 114 | info := *l.sysInfo.Clone() 115 | 116 | out, err := json.MarshalIndent(info, "", "\t") 117 | if err != nil { 118 | _, _ = io.WriteString(w, err.Error()) 119 | } 120 | 121 | _, _ = w.Write(out) 122 | } 123 | -------------------------------------------------------------------------------- /listeners/http_sysinfo_test.go: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: MIT 2 | // SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co 3 | // SPDX-FileContributor: mochi-co 4 | 5 | package listeners 6 | 7 | import ( 8 | "encoding/json" 9 | "io" 10 | "net/http" 11 | "testing" 12 | "time" 13 | 14 | "github.com/mochi-mqtt/server/v2/system" 15 | 16 | "github.com/stretchr/testify/require" 17 | ) 18 | 19 | func TestNewHTTPStats(t *testing.T) { 20 | l := NewHTTPStats(basicConfig, nil) 21 | require.Equal(t, "t1", l.id) 22 | require.Equal(t, testAddr, l.address) 23 | } 24 | 25 | func TestHTTPStatsID(t *testing.T) { 26 | l := NewHTTPStats(basicConfig, nil) 27 | require.Equal(t, "t1", l.ID()) 28 | } 29 | 30 | func TestHTTPStatsAddress(t *testing.T) { 31 | l := NewHTTPStats(basicConfig, nil) 32 | require.Equal(t, testAddr, l.Address()) 33 | } 34 | 35 | func TestHTTPStatsProtocol(t *testing.T) { 36 | l := NewHTTPStats(basicConfig, nil) 37 | require.Equal(t, "http", l.Protocol()) 38 | } 39 | 40 | func TestHTTPStatsTLSProtocol(t *testing.T) { 41 | l := NewHTTPStats(tlsConfig, nil) 42 | _ = l.Init(logger) 43 | require.Equal(t, "https", l.Protocol()) 44 | } 45 | 46 | func TestHTTPStatsInit(t *testing.T) { 47 | sysInfo := new(system.Info) 48 | l := NewHTTPStats(basicConfig, sysInfo) 49 | err := l.Init(logger) 50 | require.NoError(t, err) 51 | 52 | require.NotNil(t, l.sysInfo) 53 | require.Equal(t, sysInfo, l.sysInfo) 54 | require.NotNil(t, l.listen) 55 | require.Equal(t, testAddr, l.listen.Addr) 56 | } 57 | 58 | func 
TestHTTPStatsServeAndClose(t *testing.T) { 59 | sysInfo := &system.Info{ 60 | Version: "test", 61 | } 62 | 63 | // setup http stats listener 64 | l := NewHTTPStats(basicConfig, sysInfo) 65 | err := l.Init(logger) 66 | require.NoError(t, err) 67 | 68 | o := make(chan bool) 69 | go func(o chan bool) { 70 | l.Serve(MockEstablisher) 71 | o <- true 72 | }(o) 73 | 74 | time.Sleep(time.Millisecond) 75 | 76 | // get body from stats address 77 | resp, err := http.Get("http://localhost" + testAddr) 78 | require.NoError(t, err) 79 | require.NotNil(t, resp) 80 | 81 | defer resp.Body.Close() 82 | body, err := io.ReadAll(resp.Body) 83 | require.NoError(t, err) 84 | 85 | // decode body from json and check data 86 | v := new(system.Info) 87 | err = json.Unmarshal(body, v) 88 | require.NoError(t, err) 89 | require.Equal(t, "test", v.Version) 90 | 91 | // ensure listening is closed 92 | var closed bool 93 | l.Close(func(id string) { 94 | closed = true 95 | }) 96 | 97 | require.Equal(t, true, closed) 98 | 99 | _, err = http.Get("http://localhost" + testAddr) 100 | require.Error(t, err) 101 | <-o 102 | } 103 | 104 | func TestHTTPStatsServeTLSAndClose(t *testing.T) { 105 | sysInfo := &system.Info{ 106 | Version: "test", 107 | } 108 | 109 | l := NewHTTPStats(tlsConfig, sysInfo) 110 | 111 | err := l.Init(logger) 112 | require.NoError(t, err) 113 | 114 | o := make(chan bool) 115 | go func(o chan bool) { 116 | l.Serve(MockEstablisher) 117 | o <- true 118 | }(o) 119 | 120 | time.Sleep(time.Millisecond) 121 | l.Close(MockCloser) 122 | } 123 | 124 | func TestHTTPStatsFailedToServe(t *testing.T) { 125 | sysInfo := &system.Info{ 126 | Version: "test", 127 | } 128 | 129 | // setup http stats listener 130 | config := basicConfig 131 | config.Address = "wrong_addr" 132 | l := NewHTTPStats(config, sysInfo) 133 | err := l.Init(logger) 134 | require.NoError(t, err) 135 | 136 | o := make(chan bool) 137 | go func(o chan bool) { 138 | l.Serve(MockEstablisher) 139 | o <- true 140 | }(o) 141 | 142 | <-o 
143 | // ensure listening is closed 144 | var closed bool 145 | l.Close(func(id string) { 146 | closed = true 147 | }) 148 | require.Equal(t, true, closed) 149 | } 150 | -------------------------------------------------------------------------------- /listeners/listeners.go: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: MIT 2 | // SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co 3 | // SPDX-FileContributor: mochi-co 4 | 5 | package listeners 6 | 7 | import ( 8 | "crypto/tls" 9 | "net" 10 | "sync" 11 | 12 | "log/slog" 13 | ) 14 | 15 | // Config contains configuration values for a listener. 16 | type Config struct { 17 | Type string 18 | ID string 19 | Address string 20 | // TLSConfig is a tls.Config configuration to be used with the listener. See examples folder for basic and mutual-tls use. 21 | TLSConfig *tls.Config 22 | } 23 | 24 | // EstablishFn is a callback function for establishing new clients. 25 | type EstablishFn func(id string, c net.Conn) error 26 | 27 | // CloseFn is a callback function for closing all listener clients. 28 | type CloseFn func(id string) 29 | 30 | // Listener is an interface for network listeners. A network listener listens 31 | // for incoming client connections and adds them to the server. 32 | type Listener interface { 33 | Init(*slog.Logger) error // open the network address 34 | Serve(EstablishFn) // starting actively listening for new connections 35 | ID() string // return the id of the listener 36 | Address() string // the address of the listener 37 | Protocol() string // the protocol in use by the listener 38 | Close(CloseFn) // stop and close the listener 39 | } 40 | 41 | // Listeners contains the network listeners for the broker. 42 | type Listeners struct { 43 | ClientsWg sync.WaitGroup // a waitgroup that waits for all clients in all listeners to finish. 44 | internal map[string]Listener // a map of active listeners. 
45 | sync.RWMutex 46 | } 47 | 48 | // New returns a new instance of Listeners. 49 | func New() *Listeners { 50 | return &Listeners{ 51 | internal: map[string]Listener{}, 52 | } 53 | } 54 | 55 | // Add adds a new listener to the listeners map, keyed on id. 56 | func (l *Listeners) Add(val Listener) { 57 | l.Lock() 58 | defer l.Unlock() 59 | l.internal[val.ID()] = val 60 | } 61 | 62 | // Get returns the value of a listener if it exists. 63 | func (l *Listeners) Get(id string) (Listener, bool) { 64 | l.RLock() 65 | defer l.RUnlock() 66 | val, ok := l.internal[id] 67 | return val, ok 68 | } 69 | 70 | // Len returns the length of the listeners map. 71 | func (l *Listeners) Len() int { 72 | l.RLock() 73 | defer l.RUnlock() 74 | return len(l.internal) 75 | } 76 | 77 | // Delete removes a listener from the internal map. 78 | func (l *Listeners) Delete(id string) { 79 | l.Lock() 80 | defer l.Unlock() 81 | delete(l.internal, id) 82 | } 83 | 84 | // Serve starts a listener serving from the internal map. 85 | func (l *Listeners) Serve(id string, establisher EstablishFn) { 86 | l.RLock() 87 | defer l.RUnlock() 88 | listener := l.internal[id] 89 | 90 | go func(e EstablishFn) { 91 | listener.Serve(e) 92 | }(establisher) 93 | } 94 | 95 | // ServeAll starts all listeners serving from the internal map. 96 | func (l *Listeners) ServeAll(establisher EstablishFn) { 97 | l.RLock() 98 | i := 0 99 | ids := make([]string, len(l.internal)) 100 | for id := range l.internal { 101 | ids[i] = id 102 | i++ 103 | } 104 | l.RUnlock() 105 | 106 | for _, id := range ids { 107 | l.Serve(id, establisher) 108 | } 109 | } 110 | 111 | // Close stops a listener from the internal map. 112 | func (l *Listeners) Close(id string, closer CloseFn) { 113 | l.RLock() 114 | defer l.RUnlock() 115 | if listener, ok := l.internal[id]; ok { 116 | listener.Close(closer) 117 | } 118 | } 119 | 120 | // CloseAll iterates and closes all registered listeners. 
func (l *Listeners) CloseAll(closer CloseFn) {
	// Snapshot the ids under the read lock; Close re-acquires the read lock
	// for each listener it closes.
	l.RLock()
	i := 0
	ids := make([]string, len(l.internal))
	for id := range l.internal {
		ids[i] = id
		i++
	}
	l.RUnlock()

	for _, id := range ids {
		l.Close(id, closer)
	}
	// Block until every client tracked by ClientsWg has finished.
	l.ClientsWg.Wait()
}

--------------------------------------------------------------------------------
/listeners/listeners_test.go:
--------------------------------------------------------------------------------
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
// SPDX-FileContributor: mochi-co

package listeners

import (
	"crypto/tls"
	"log"
	"os"
	"testing"
	"time"

	"log/slog"

	"github.com/stretchr/testify/require"
)

const testAddr = ":22222"

var (
	basicConfig = Config{ID: "t1", Address: testAddr}
	// tlsConfigBasic is still nil when this initializer runs; init() below
	// assigns tlsConfig.TLSConfig once the key pair has been parsed.
	tlsConfig = Config{ID: "t1", Address: testAddr, TLSConfig: tlsConfigBasic}

	logger = slog.New(slog.NewTextHandler(os.Stdout, nil))

	testCertificate = []byte(`-----BEGIN CERTIFICATE-----
MIIB/zCCAWgCCQDm3jV+lSF1AzANBgkqhkiG9w0BAQsFADBEMQswCQYDVQQGEwJB
VTETMBEGA1UECAwKU29tZS1TdGF0ZTERMA8GA1UECgwITW9jaGkgQ28xDTALBgNV
BAsMBE1RVFQwHhcNMjAwMTA0MjAzMzQyWhcNMjEwMTAzMjAzMzQyWjBEMQswCQYD
VQQGEwJBVTETMBEGA1UECAwKU29tZS1TdGF0ZTERMA8GA1UECgwITW9jaGkgQ28x
DTALBgNVBAsMBE1RVFQwgZ8wDQYJKoZIhvcNAQEBBQADgY0AMIGJAoGBAKz2bUz3
AOymssVLuvSOEbQ/sF8C/Ill8nRTd7sX9WBIxHJZf+gVn8lQ4BTQ0NchLDRIlpbi
OuZgktpd6ba8sIfVM4jbVprctky5tGsyHRFwL/GAycCtKwvuXkvcwSwLvB8b29EI
MLQ/3vNnYuC3eZ4qqxlODJgRsfQ7mUNB8zkLAgMBAAEwDQYJKoZIhvcNAQELBQAD
gYEAiMoKnQaD0F/J332arGvcmtbHmF2XZp/rGy3dooPug8+OPUSAJY9vTfxJwOsQ
qN1EcI+kIgrGxzA3VRfVYV8gr7IX+fUYfVCaPGcDCfPvo/Ihu757afJRVvpafWgy
zSpDZYu6C62h3KSzMJxffDjy7/2t8oYbTzkLSamsHJJjLZw=
-----END CERTIFICATE-----`)

	testPrivateKey = []byte(`-----BEGIN RSA PRIVATE KEY-----
MIICXAIBAAKBgQCs9m1M9wDsprLFS7r0jhG0P7BfAvyJZfJ0U3e7F/VgSMRyWX/o
FZ/JUOAU0NDXISw0SJaW4jrmYJLaXem2vLCH1TOI21aa3LZMubRrMh0RcC/xgMnA
rSsL7l5L3MEsC7wfG9vRCDC0P97zZ2Lgt3meKqsZTgyYEbH0O5lDQfM5CwIDAQAB
AoGBAKlmVVirFqmw/qhDaqD4wBg0xI3Zw/Lh+Vu7ICoK5hVeT6DbTW3GOBAY+M8K
UXBSGhQ+/9ZZTmyyK0JZ9nw2RAG3lONU6wS41pZhB7F4siatZfP/JJfU6p+ohe8m
n22hTw4brY/8E/tjuki9T5e2GeiUPBhjbdECkkVXMYBPKDZhAkEA5h/b/HBcsIZZ
mL2d3dyWkXR/IxngQa4NH3124M8MfBqCYXPLgD7RDI+3oT/uVe+N0vu6+7CSMVx6
INM67CuE0QJBAMBpKW54cfMsMya3CM1BfdPEBzDT5kTMqxJ7ez164PHv9CJCnL0Z
AuWgM/p2WNbAF1yHNxw1eEfNbUWwVX2yhxsCQEtnMQvcPWLSAtWbe/jQaL2scGQt
/F9JCp/A2oz7Cto3TXVlHc8dxh3ZkY/ShOO/pLb3KOODjcOCy7mpvOrZr6ECQH32
WoFPqImhrfryaHi3H0C7XFnC30S7GGOJIy0kfI7mn9St9x50eUkKj/yv7YjpSGHy
w0lcV9npyleNEOqxLXECQBL3VRGCfZfhfFpL8z+5+HPKXw6FxWr+p5h8o3CZ6Yi3
OJVN3Mfo6mbz34wswrEdMXn25MzAwbhFQvCVpPZrFwc=
-----END RSA PRIVATE KEY-----`)

	tlsConfigBasic *tls.Config
)

// init parses the embedded test key pair into tlsConfigBasic and patches it
// into tlsConfig; the test binary aborts if the pair does not parse.
func init() {
	cert, err := tls.X509KeyPair(testCertificate, testPrivateKey)
	if err != nil {
		log.Fatal(err)
	}

	// Basic TLS Config
	tlsConfigBasic = &tls.Config{
		MinVersion:   tls.VersionTLS12,
		Certificates: []tls.Certificate{cert},
	}
	tlsConfig.TLSConfig = tlsConfigBasic
}

func TestNew(t *testing.T) {
	l := New()
	require.NotNil(t, l.internal)
}

func TestAddListener(t *testing.T) {
	l := New()
	l.Add(NewMockListener("t1", testAddr))
	require.Contains(t, l.internal, "t1")
}

func TestGetListener(t *testing.T) {
	l := New()
	l.Add(NewMockListener("t1", testAddr))
	l.Add(NewMockListener("t2", testAddr))
	require.Contains(t, l.internal, "t1")
	require.Contains(t, l.internal, "t2")

	g, ok := l.Get("t1")
	require.True(t, ok)
	require.Equal(t, g.ID(), "t1")
}

func TestLenListener(t *testing.T) {
	l := New()
	l.Add(NewMockListener("t1", testAddr))
	l.Add(NewMockListener("t2", testAddr))
	require.Contains(t, l.internal, "t1")
	require.Contains(t, l.internal, "t2")
	require.Equal(t, 2, l.Len())
}

func TestDeleteListener(t *testing.T) {
	l := New()
	l.Add(NewMockListener("t1", testAddr))
	require.Contains(t, l.internal, "t1")
	l.Delete("t1")
	_, ok := l.Get("t1")
	require.False(t, ok)
	require.Nil(t, l.internal["t1"])
}

func TestServeListener(t *testing.T) {
	l := New()
	l.Add(NewMockListener("t1", testAddr))
	l.Serve("t1", MockEstablisher)
	// Serve runs in a goroutine; the short sleep gives it time to flip Serving.
	time.Sleep(time.Millisecond)
	require.True(t, l.internal["t1"].(*MockListener).IsServing())

	l.Close("t1", MockCloser)
	require.False(t, l.internal["t1"].(*MockListener).IsServing())
}

func TestServeAllListeners(t *testing.T) {
	l := New()
	l.Add(NewMockListener("t1", testAddr))
	l.Add(NewMockListener("t2", testAddr))
	l.Add(NewMockListener("t3", testAddr))
	l.ServeAll(MockEstablisher)
	time.Sleep(time.Millisecond)

	require.True(t, l.internal["t1"].(*MockListener).IsServing())
	require.True(t, l.internal["t2"].(*MockListener).IsServing())
	require.True(t, l.internal["t3"].(*MockListener).IsServing())

	l.Close("t1", MockCloser)
	l.Close("t2", MockCloser)
	l.Close("t3", MockCloser)

	require.False(t, l.internal["t1"].(*MockListener).IsServing())
	require.False(t, l.internal["t2"].(*MockListener).IsServing())
	require.False(t, l.internal["t3"].(*MockListener).IsServing())
}

func TestCloseListener(t *testing.T) {
	l := New()
	mocked := NewMockListener("t1", testAddr)
	l.Add(mocked)
	l.Serve("t1", MockEstablisher)
	time.Sleep(time.Millisecond)
	var closed bool
	l.Close("t1", func(id string) {
		closed = true
	})
	require.True(t, closed)
}

func TestCloseAllListeners(t *testing.T) {
	l := New()
	l.Add(NewMockListener("t1", testAddr))
	l.Add(NewMockListener("t2", testAddr))
	l.Add(NewMockListener("t3", testAddr))
	l.ServeAll(MockEstablisher)
	time.Sleep(time.Millisecond)
	require.True(t, l.internal["t1"].(*MockListener).IsServing())
	require.True(t, l.internal["t2"].(*MockListener).IsServing())
	require.True(t, l.internal["t3"].(*MockListener).IsServing())

	closed := make(map[string]bool)
	l.CloseAll(func(id string) {
		closed[id] = true
	})
	require.Contains(t, closed, "t1")
	require.Contains(t, closed, "t2")
	require.Contains(t, closed, "t3")
	require.True(t, closed["t1"])
	require.True(t, closed["t2"])
	require.True(t, closed["t3"])
}

--------------------------------------------------------------------------------
/listeners/mock.go:
--------------------------------------------------------------------------------
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
// SPDX-FileContributor: mochi-co

package listeners

import (
	"fmt"
	"net"
	"sync"

	"log/slog"
)

const TypeMock = "mock"

// MockEstablisher is a function signature which can be used in testing.
// It ignores its arguments and always reports success.
func MockEstablisher(id string, c net.Conn) error {
	return nil
}

// MockCloser is a function signature which can be used in testing.
// It is a no-op.
func MockCloser(id string) {}

// MockListener is a mock listener for establishing client connections.
26 | type MockListener struct { 27 | sync.RWMutex 28 | id string // the id of the listener 29 | address string // the network address the listener binds to 30 | Config *Config // configuration for the listener 31 | done chan bool // indicate the listener is done 32 | Serving bool // indicate the listener is serving 33 | Listening bool // indiciate the listener is listening 34 | ErrListen bool // throw an error on listen 35 | } 36 | 37 | // NewMockListener returns a new instance of MockListener. 38 | func NewMockListener(id, address string) *MockListener { 39 | return &MockListener{ 40 | id: id, 41 | address: address, 42 | done: make(chan bool), 43 | } 44 | } 45 | 46 | // Serve serves the mock listener. 47 | func (l *MockListener) Serve(establisher EstablishFn) { 48 | l.Lock() 49 | l.Serving = true 50 | l.Unlock() 51 | 52 | for range l.done { 53 | return 54 | } 55 | } 56 | 57 | // Init initializes the listener. 58 | func (l *MockListener) Init(log *slog.Logger) error { 59 | if l.ErrListen { 60 | return fmt.Errorf("listen failure") 61 | } 62 | 63 | l.Lock() 64 | defer l.Unlock() 65 | l.Listening = true 66 | return nil 67 | } 68 | 69 | // ID returns the id of the mock listener. 70 | func (l *MockListener) ID() string { 71 | return l.id 72 | } 73 | 74 | // Address returns the address of the listener. 75 | func (l *MockListener) Address() string { 76 | return l.address 77 | } 78 | 79 | // Protocol returns the address of the listener. 80 | func (l *MockListener) Protocol() string { 81 | return "mock" 82 | } 83 | 84 | // Close closes the mock listener. 85 | func (l *MockListener) Close(closer CloseFn) { 86 | l.Lock() 87 | defer l.Unlock() 88 | l.Serving = false 89 | closer(l.id) 90 | close(l.done) 91 | } 92 | 93 | // IsServing indicates whether the mock listener is serving. 94 | func (l *MockListener) IsServing() bool { 95 | l.Lock() 96 | defer l.Unlock() 97 | return l.Serving 98 | } 99 | 100 | // IsListening indicates whether the mock listener is listening. 
// Listening is read under the listener's lock because Init writes it under
// the same lock.
func (l *MockListener) IsListening() bool {
	l.Lock()
	defer l.Unlock()
	return l.Listening
}

--------------------------------------------------------------------------------
/listeners/mock_test.go:
--------------------------------------------------------------------------------
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
// SPDX-FileContributor: mochi-co

package listeners

import (
	"net"
	"testing"
	"time"

	"github.com/stretchr/testify/require"
)

func TestMockEstablisher(t *testing.T) {
	// net.Pipe provides an in-memory net.Conn; MockEstablisher ignores it.
	_, w := net.Pipe()
	err := MockEstablisher("t1", w)
	require.NoError(t, err)
	_ = w.Close()
}

func TestNewMockListener(t *testing.T) {
	mocked := NewMockListener("t1", testAddr)
	require.Equal(t, "t1", mocked.id)
	require.Equal(t, testAddr, mocked.address)
}
func TestMockListenerID(t *testing.T) {
	mocked := NewMockListener("t1", testAddr)
	require.Equal(t, "t1", mocked.ID())
}

func TestMockListenerAddress(t *testing.T) {
	mocked := NewMockListener("t1", testAddr)
	require.Equal(t, testAddr, mocked.Address())
}
func TestMockListenerProtocol(t *testing.T) {
	mocked := NewMockListener("t1", testAddr)
	require.Equal(t, "mock", mocked.Protocol())
}

func TestNewMockListenerIsListening(t *testing.T) {
	mocked := NewMockListener("t1", testAddr)
	require.Equal(t, false, mocked.IsListening())
}

func TestNewMockListenerIsServing(t *testing.T) {
	mocked := NewMockListener("t1", testAddr)
	require.Equal(t, false, mocked.IsServing())
}

func TestNewMockListenerInit(t *testing.T) {
	mocked := NewMockListener("t1", testAddr)
	require.Equal(t, "t1", mocked.id)
	require.Equal(t, testAddr, mocked.address)

	require.Equal(t, false, mocked.IsListening())
	err := mocked.Init(nil)
	require.NoError(t, err)
	require.Equal(t, true, mocked.IsListening())
}

func TestNewMockListenerInitFailure(t *testing.T) {
	mocked := NewMockListener("t1", testAddr)
	mocked.ErrListen = true
	err := mocked.Init(nil)
	require.Error(t, err)
}

func TestMockListenerServe(t *testing.T) {
	mocked := NewMockListener("t1", testAddr)
	require.Equal(t, false, mocked.IsServing())

	o := make(chan bool)
	go func(o chan bool) {
		mocked.Serve(MockEstablisher)
		o <- true
	}(o)

	time.Sleep(time.Millisecond) // easy non-channel wait for start of serving
	require.Equal(t, true, mocked.IsServing())

	var closed bool
	mocked.Close(func(id string) {
		closed = true
	})
	require.Equal(t, true, closed)
	// Serve returns only after Close has released it.
	<-o

	_ = mocked.Init(nil)
}

func TestMockListenerClose(t *testing.T) {
	mocked := NewMockListener("t1", testAddr)
	var closed bool
	mocked.Close(func(id string) {
		closed = true
	})
	require.Equal(t, true, closed)
}

--------------------------------------------------------------------------------
/listeners/net.go:
--------------------------------------------------------------------------------
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2023 mochi-mqtt, mochi-co
// SPDX-FileContributor: Jeroen Rinzema

package listeners

import (
	"net"
	"sync"
	"sync/atomic"

	"log/slog"
)

// Net is a listener for establishing client connections on basic TCP protocol.
16 | type Net struct { // [MQTT-4.2.0-1] 17 | mu sync.Mutex 18 | listener net.Listener // a net.Listener which will listen for new clients 19 | id string // the internal id of the listener 20 | log *slog.Logger // server logger 21 | end uint32 // ensure the close methods are only called once 22 | } 23 | 24 | // NewNet initialises and returns a listener serving incoming connections on the given net.Listener 25 | func NewNet(id string, listener net.Listener) *Net { 26 | return &Net{ 27 | id: id, 28 | listener: listener, 29 | } 30 | } 31 | 32 | // ID returns the id of the listener. 33 | func (l *Net) ID() string { 34 | return l.id 35 | } 36 | 37 | // Address returns the address of the listener. 38 | func (l *Net) Address() string { 39 | return l.listener.Addr().String() 40 | } 41 | 42 | // Protocol returns the network of the listener. 43 | func (l *Net) Protocol() string { 44 | return l.listener.Addr().Network() 45 | } 46 | 47 | // Init initializes the listener. 48 | func (l *Net) Init(log *slog.Logger) error { 49 | l.log = log 50 | return nil 51 | } 52 | 53 | // Serve starts waiting for new TCP connections, and calls the establish 54 | // connection callback for any received. 55 | func (l *Net) Serve(establish EstablishFn) { 56 | for { 57 | if atomic.LoadUint32(&l.end) == 1 { 58 | return 59 | } 60 | 61 | conn, err := l.listener.Accept() 62 | if err != nil { 63 | return 64 | } 65 | 66 | if atomic.LoadUint32(&l.end) == 0 { 67 | go func() { 68 | err = establish(l.id, conn) 69 | if err != nil { 70 | l.log.Warn("", "error", err) 71 | } 72 | }() 73 | } 74 | } 75 | } 76 | 77 | // Close closes the listener and any client connections. 
func (l *Net) Close(closeClients CloseFn) {
	l.mu.Lock()
	defer l.mu.Unlock()

	// Only the first Close flips end from 0 to 1 and notifies closeClients;
	// Serve checks end and stops handing out connections.
	if atomic.CompareAndSwapUint32(&l.end, 0, 1) {
		closeClients(l.id)
	}

	// The wrapped listener is closed on every call, which unblocks a pending
	// Accept in Serve. Any error from Close is discarded.
	if l.listener != nil {
		err := l.listener.Close()
		if err != nil {
			return
		}
	}
}

--------------------------------------------------------------------------------
/listeners/net_test.go:
--------------------------------------------------------------------------------
package listeners

import (
	"errors"
	"net"
	"testing"
	"time"

	"github.com/stretchr/testify/require"
)

func TestNewNet(t *testing.T) {
	// Port 0 lets the OS pick a free port for each test.
	n, err := net.Listen("tcp", "127.0.0.1:0")
	require.NoError(t, err)

	l := NewNet("t1", n)
	require.Equal(t, "t1", l.id)
}

func TestNetID(t *testing.T) {
	n, err := net.Listen("tcp", "127.0.0.1:0")
	require.NoError(t, err)

	l := NewNet("t1", n)
	require.Equal(t, "t1", l.ID())
}

func TestNetAddress(t *testing.T) {
	n, err := net.Listen("tcp", "127.0.0.1:0")
	require.NoError(t, err)

	l := NewNet("t1", n)
	require.Equal(t, n.Addr().String(), l.Address())
}

func TestNetProtocol(t *testing.T) {
	n, err := net.Listen("tcp", "127.0.0.1:0")
	require.NoError(t, err)

	l := NewNet("t1", n)
	require.Equal(t, "tcp", l.Protocol())
}

func TestNetInit(t *testing.T) {
	n, err := net.Listen("tcp", "127.0.0.1:0")
	require.NoError(t, err)

	l := NewNet("t1", n)
	err = l.Init(logger)
	l.Close(MockCloser)
	require.NoError(t, err)
}

func TestNetServeAndClose(t *testing.T) {
	n, err := net.Listen("tcp", "127.0.0.1:0")
	require.NoError(t, err)

	l := NewNet("t1", n)
	err = l.Init(logger)
	require.NoError(t, err)

	o := make(chan bool)
	go func(o chan bool) {
		l.Serve(MockEstablisher)
		o <- true
	}(o)

	time.Sleep(time.Millisecond)

	var closed bool
	l.Close(func(id string) {
		closed = true
	})

	require.True(t, closed)
	<-o

	l.Close(MockCloser)      // coverage: close closed
	l.Serve(MockEstablisher) // coverage: serve closed
}

func TestNetEstablishThenEnd(t *testing.T) {
	n, err := net.Listen("tcp", "127.0.0.1:0")
	require.NoError(t, err)

	l := NewNet("t1", n)
	err = l.Init(logger)
	require.NoError(t, err)

	o := make(chan bool)
	established := make(chan bool)
	go func() {
		l.Serve(func(id string, c net.Conn) error {
			established <- true
			return errors.New("ending") // return an error to exit immediately
		})
		o <- true
	}()

	time.Sleep(time.Millisecond)
	_, _ = net.Dial("tcp", n.Addr().String())
	require.Equal(t, true, <-established)
	l.Close(MockCloser)
	<-o
}

--------------------------------------------------------------------------------
/listeners/tcp.go:
--------------------------------------------------------------------------------
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
// SPDX-FileContributor: mochi-co

package listeners

import (
	"crypto/tls"
	"net"
	"sync"
	"sync/atomic"

	"log/slog"
)

const TypeTCP = "tcp"

// TCP is a listener for establishing client connections on basic TCP protocol.
19 | type TCP struct { // [MQTT-4.2.0-1] 20 | sync.RWMutex 21 | id string // the internal id of the listener 22 | address string // the network address to bind to 23 | listen net.Listener // a net.Listener which will listen for new clients 24 | config Config // configuration values for the listener 25 | log *slog.Logger // server logger 26 | end uint32 // ensure the close methods are only called once 27 | } 28 | 29 | // NewTCP initializes and returns a new TCP listener, listening on an address. 30 | func NewTCP(config Config) *TCP { 31 | return &TCP{ 32 | id: config.ID, 33 | address: config.Address, 34 | config: config, 35 | } 36 | } 37 | 38 | // ID returns the id of the listener. 39 | func (l *TCP) ID() string { 40 | return l.id 41 | } 42 | 43 | // Address returns the address of the listener. 44 | func (l *TCP) Address() string { 45 | if l.listen != nil { 46 | return l.listen.Addr().String() 47 | } 48 | return l.address 49 | } 50 | 51 | // Protocol returns the address of the listener. 52 | func (l *TCP) Protocol() string { 53 | return "tcp" 54 | } 55 | 56 | // Init initializes the listener. 57 | func (l *TCP) Init(log *slog.Logger) error { 58 | l.log = log 59 | 60 | var err error 61 | if l.config.TLSConfig != nil { 62 | l.listen, err = tls.Listen("tcp", l.address, l.config.TLSConfig) 63 | } else { 64 | l.listen, err = net.Listen("tcp", l.address) 65 | } 66 | 67 | return err 68 | } 69 | 70 | // Serve starts waiting for new TCP connections, and calls the establish 71 | // connection callback for any received. 72 | func (l *TCP) Serve(establish EstablishFn) { 73 | for { 74 | if atomic.LoadUint32(&l.end) == 1 { 75 | return 76 | } 77 | 78 | conn, err := l.listen.Accept() 79 | if err != nil { 80 | return 81 | } 82 | 83 | if atomic.LoadUint32(&l.end) == 0 { 84 | go func() { 85 | err = establish(l.id, conn) 86 | if err != nil { 87 | l.log.Warn("", "error", err) 88 | } 89 | }() 90 | } 91 | } 92 | } 93 | 94 | // Close closes the listener and any client connections. 
func (l *TCP) Close(closeClients CloseFn) {
	l.Lock()
	defer l.Unlock()

	// Only the first Close flips end from 0 to 1 and notifies closeClients;
	// Serve checks end and stops handing out connections.
	if atomic.CompareAndSwapUint32(&l.end, 0, 1) {
		closeClients(l.id)
	}

	// The bound listener (if Init created one) is closed on every call, which
	// unblocks a pending Accept in Serve. Any error from Close is discarded.
	if l.listen != nil {
		err := l.listen.Close()
		if err != nil {
			return
		}
	}
}

--------------------------------------------------------------------------------
/listeners/tcp_test.go:
--------------------------------------------------------------------------------
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
// SPDX-FileContributor: mochi-co

package listeners

import (
	"errors"
	"net"
	"testing"
	"time"

	"github.com/stretchr/testify/require"
)

func TestNewTCP(t *testing.T) {
	l := NewTCP(basicConfig)
	require.Equal(t, "t1", l.id)
	require.Equal(t, testAddr, l.address)
}

func TestTCPID(t *testing.T) {
	l := NewTCP(basicConfig)
	require.Equal(t, "t1", l.ID())
}

func TestTCPAddress(t *testing.T) {
	l := NewTCP(basicConfig)
	require.Equal(t, testAddr, l.Address())
}

func TestTCPProtocol(t *testing.T) {
	l := NewTCP(basicConfig)
	require.Equal(t, "tcp", l.Protocol())
}

func TestTCPProtocolTLS(t *testing.T) {
	l := NewTCP(tlsConfig)
	_ = l.Init(logger)
	defer l.listen.Close()
	require.Equal(t, "tcp", l.Protocol())
}

func TestTCPInit(t *testing.T) {
	l := NewTCP(basicConfig)
	err := l.Init(logger)
	l.Close(MockCloser)
	require.NoError(t, err)

	l2 := NewTCP(tlsConfig)
	err = l2.Init(logger)
	l2.Close(MockCloser)
	require.NoError(t, err)
	require.NotNil(t, l2.config.TLSConfig)
}

func TestTCPServeAndClose(t *testing.T) {
	l := NewTCP(basicConfig)
	err := l.Init(logger)
	require.NoError(t, err)

	o := make(chan bool)
	go func(o chan bool) {
		l.Serve(MockEstablisher)
		o <- true
	}(o)

	time.Sleep(time.Millisecond)

	var closed bool
	l.Close(func(id string) {
		closed = true
	})

	require.True(t, closed)
	<-o

	l.Close(MockCloser)      // coverage: close closed
	l.Serve(MockEstablisher) // coverage: serve closed
}

func TestTCPServeTLSAndClose(t *testing.T) {
	l := NewTCP(tlsConfig)
	err := l.Init(logger)
	require.NoError(t, err)

	o := make(chan bool)
	go func(o chan bool) {
		l.Serve(MockEstablisher)
		o <- true
	}(o)

	time.Sleep(time.Millisecond)

	var closed bool
	l.Close(func(id string) {
		closed = true
	})

	require.Equal(t, true, closed)
	<-o
}

func TestTCPEstablishThenEnd(t *testing.T) {
	l := NewTCP(basicConfig)
	err := l.Init(logger)
	require.NoError(t, err)

	o := make(chan bool)
	established := make(chan bool)
	go func() {
		l.Serve(func(id string, c net.Conn) error {
			established <- true
			return errors.New("ending") // return an error to exit immediately
		})
		o <- true
	}()

	time.Sleep(time.Millisecond)
	_, _ = net.Dial("tcp", l.listen.Addr().String())
	require.Equal(t, true, <-established)
	l.Close(MockCloser)
	<-o
}

--------------------------------------------------------------------------------
/listeners/unixsock.go:
--------------------------------------------------------------------------------
// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co
// SPDX-FileContributor: jason@zgwit.com

package listeners

import (
	"net"
	"os"
	"sync"
	"sync/atomic"

	"log/slog"
)

const TypeUnix = "unix"

// UnixSock is a listener for establishing client connections on basic UnixSock protocol.
19 | type UnixSock struct { 20 | sync.RWMutex 21 | id string // the internal id of the listener. 22 | address string // the network address to bind to. 23 | config Config // configuration values for the listener 24 | listen net.Listener // a net.Listener which will listen for new clients. 25 | log *slog.Logger // server logger 26 | end uint32 // ensure the close methods are only called once. 27 | } 28 | 29 | // NewUnixSock initializes and returns a new UnixSock listener, listening on an address. 30 | func NewUnixSock(config Config) *UnixSock { 31 | return &UnixSock{ 32 | id: config.ID, 33 | address: config.Address, 34 | config: config, 35 | } 36 | } 37 | 38 | // ID returns the id of the listener. 39 | func (l *UnixSock) ID() string { 40 | return l.id 41 | } 42 | 43 | // Address returns the address of the listener. 44 | func (l *UnixSock) Address() string { 45 | return l.address 46 | } 47 | 48 | // Protocol returns the address of the listener. 49 | func (l *UnixSock) Protocol() string { 50 | return "unix" 51 | } 52 | 53 | // Init initializes the listener. 54 | func (l *UnixSock) Init(log *slog.Logger) error { 55 | l.log = log 56 | 57 | var err error 58 | _ = os.Remove(l.address) 59 | l.listen, err = net.Listen("unix", l.address) 60 | return err 61 | } 62 | 63 | // Serve starts waiting for new UnixSock connections, and calls the establish 64 | // connection callback for any received. 65 | func (l *UnixSock) Serve(establish EstablishFn) { 66 | for { 67 | if atomic.LoadUint32(&l.end) == 1 { 68 | return 69 | } 70 | 71 | conn, err := l.listen.Accept() 72 | if err != nil { 73 | return 74 | } 75 | 76 | if atomic.LoadUint32(&l.end) == 0 { 77 | go func() { 78 | err = establish(l.id, conn) 79 | if err != nil { 80 | l.log.Warn("", "error", err) 81 | } 82 | }() 83 | } 84 | } 85 | } 86 | 87 | // Close closes the listener and any client connections. 
88 | func (l *UnixSock) Close(closeClients CloseFn) { 89 | l.Lock() 90 | defer l.Unlock() 91 | 92 | if atomic.CompareAndSwapUint32(&l.end, 0, 1) { 93 | closeClients(l.id) 94 | } 95 | 96 | if l.listen != nil { 97 | err := l.listen.Close() 98 | if err != nil { 99 | return 100 | } 101 | } 102 | } 103 | -------------------------------------------------------------------------------- /listeners/unixsock_test.go: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: MIT 2 | // SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co 3 | // SPDX-FileContributor: jason@zgwit.com 4 | 5 | package listeners 6 | 7 | import ( 8 | "errors" 9 | "net" 10 | "testing" 11 | "time" 12 | 13 | "github.com/stretchr/testify/require" 14 | ) 15 | 16 | const testUnixAddr = "mochi.sock" 17 | 18 | var ( 19 | unixConfig = Config{ID: "t1", Address: testUnixAddr} 20 | ) 21 | 22 | func TestNewUnixSock(t *testing.T) { 23 | l := NewUnixSock(unixConfig) 24 | require.Equal(t, "t1", l.id) 25 | require.Equal(t, testUnixAddr, l.address) 26 | } 27 | 28 | func TestUnixSockID(t *testing.T) { 29 | l := NewUnixSock(unixConfig) 30 | require.Equal(t, "t1", l.ID()) 31 | } 32 | 33 | func TestUnixSockAddress(t *testing.T) { 34 | l := NewUnixSock(unixConfig) 35 | require.Equal(t, testUnixAddr, l.Address()) 36 | } 37 | 38 | func TestUnixSockProtocol(t *testing.T) { 39 | l := NewUnixSock(unixConfig) 40 | require.Equal(t, "unix", l.Protocol()) 41 | } 42 | 43 | func TestUnixSockInit(t *testing.T) { 44 | l := NewUnixSock(unixConfig) 45 | err := l.Init(logger) 46 | l.Close(MockCloser) 47 | require.NoError(t, err) 48 | 49 | t2Config := unixConfig 50 | t2Config.ID = "t2" 51 | l2 := NewUnixSock(t2Config) 52 | err = l2.Init(logger) 53 | l2.Close(MockCloser) 54 | require.NoError(t, err) 55 | } 56 | 57 | func TestUnixSockServeAndClose(t *testing.T) { 58 | l := NewUnixSock(unixConfig) 59 | err := l.Init(logger) 60 | require.NoError(t, err) 61 | 62 | o := make(chan bool) 63 | 
go func(o chan bool) { 64 | l.Serve(MockEstablisher) 65 | o <- true 66 | }(o) 67 | 68 | time.Sleep(time.Millisecond) 69 | 70 | var closed bool 71 | l.Close(func(id string) { 72 | closed = true 73 | }) 74 | 75 | require.True(t, closed) 76 | <-o 77 | 78 | l.Close(MockCloser) // coverage: close closed 79 | l.Serve(MockEstablisher) // coverage: serve closed 80 | } 81 | 82 | func TestUnixSockEstablishThenEnd(t *testing.T) { 83 | l := NewUnixSock(unixConfig) 84 | err := l.Init(logger) 85 | require.NoError(t, err) 86 | 87 | o := make(chan bool) 88 | established := make(chan bool) 89 | go func() { 90 | l.Serve(func(id string, c net.Conn) error { 91 | established <- true 92 | return errors.New("ending") // return an error to exit immediately 93 | }) 94 | o <- true 95 | }() 96 | 97 | time.Sleep(time.Millisecond) 98 | _, _ = net.Dial("unix", l.listen.Addr().String()) 99 | require.Equal(t, true, <-established) 100 | l.Close(MockCloser) 101 | <-o 102 | } 103 | -------------------------------------------------------------------------------- /listeners/websocket.go: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: MIT 2 | // SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co 3 | // SPDX-FileContributor: mochi-co 4 | 5 | package listeners 6 | 7 | import ( 8 | "context" 9 | "errors" 10 | "io" 11 | "net" 12 | "net/http" 13 | "sync" 14 | "sync/atomic" 15 | "time" 16 | 17 | "log/slog" 18 | 19 | "github.com/gorilla/websocket" 20 | ) 21 | 22 | const TypeWS = "ws" 23 | 24 | var ( 25 | // ErrInvalidMessage indicates that a message payload was not valid. 26 | ErrInvalidMessage = errors.New("message type not binary") 27 | ) 28 | 29 | // Websocket is a listener for establishing websocket connections. 
30 | type Websocket struct { // [MQTT-4.2.0-1] 31 | sync.RWMutex 32 | id string // the internal id of the listener 33 | address string // the network address to bind to 34 | config Config // configuration values for the listener 35 | listen *http.Server // a http server for serving websocket connections 36 | log *slog.Logger // server logger 37 | establish EstablishFn // the server's establish connection handler 38 | upgrader *websocket.Upgrader // upgrade the incoming http/tcp connection to a websocket compliant connection. 39 | end uint32 // ensure the close methods are only called once 40 | } 41 | 42 | // NewWebsocket initializes and returns a new Websocket listener, listening on an address. 43 | func NewWebsocket(config Config) *Websocket { 44 | return &Websocket{ 45 | id: config.ID, 46 | address: config.Address, 47 | config: config, 48 | upgrader: &websocket.Upgrader{ 49 | Subprotocols: []string{"mqtt"}, 50 | CheckOrigin: func(r *http.Request) bool { 51 | return true 52 | }, 53 | }, 54 | } 55 | } 56 | 57 | // ID returns the id of the listener. 58 | func (l *Websocket) ID() string { 59 | return l.id 60 | } 61 | 62 | // Address returns the address of the listener. 63 | func (l *Websocket) Address() string { 64 | return l.address 65 | } 66 | 67 | // Protocol returns the address of the listener. 68 | func (l *Websocket) Protocol() string { 69 | if l.config.TLSConfig != nil { 70 | return "wss" 71 | } 72 | 73 | return "ws" 74 | } 75 | 76 | // Init initializes the listener. 77 | func (l *Websocket) Init(log *slog.Logger) error { 78 | l.log = log 79 | 80 | mux := http.NewServeMux() 81 | mux.HandleFunc("/", l.handler) 82 | l.listen = &http.Server{ 83 | Addr: l.address, 84 | Handler: mux, 85 | TLSConfig: l.config.TLSConfig, 86 | ReadTimeout: 60 * time.Second, 87 | WriteTimeout: 60 * time.Second, 88 | } 89 | 90 | return nil 91 | } 92 | 93 | // handler upgrades and handles an incoming websocket connection. 
94 | func (l *Websocket) handler(w http.ResponseWriter, r *http.Request) { 95 | c, err := l.upgrader.Upgrade(w, r, nil) 96 | if err != nil { 97 | return 98 | } 99 | defer c.Close() 100 | 101 | err = l.establish(l.id, &wsConn{Conn: c.UnderlyingConn(), c: c}) 102 | if err != nil { 103 | l.log.Warn("", "error", err) 104 | } 105 | } 106 | 107 | // Serve starts waiting for new Websocket connections, and calls the connection 108 | // establishment callback for any received. 109 | func (l *Websocket) Serve(establish EstablishFn) { 110 | var err error 111 | l.establish = establish 112 | 113 | if l.listen.TLSConfig != nil { 114 | err = l.listen.ListenAndServeTLS("", "") 115 | } else { 116 | err = l.listen.ListenAndServe() 117 | } 118 | 119 | // After the listener has been shutdown, no need to print the http.ErrServerClosed error. 120 | if err != nil && atomic.LoadUint32(&l.end) == 0 { 121 | l.log.Error("failed to serve.", "error", err, "listener", l.id) 122 | } 123 | } 124 | 125 | // Close closes the listener and any client connections. 126 | func (l *Websocket) Close(closeClients CloseFn) { 127 | l.Lock() 128 | defer l.Unlock() 129 | 130 | if atomic.CompareAndSwapUint32(&l.end, 0, 1) { 131 | ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) 132 | defer cancel() 133 | _ = l.listen.Shutdown(ctx) 134 | } 135 | 136 | closeClients(l.id) 137 | } 138 | 139 | // wsConn is a websocket connection which satisfies the net.Conn interface. 140 | type wsConn struct { 141 | net.Conn 142 | c *websocket.Conn 143 | 144 | // reader for the current message (can be nil) 145 | r io.Reader 146 | } 147 | 148 | // Read reads the next span of bytes from the websocket connection and returns the number of bytes read. 
149 | func (ws *wsConn) Read(p []byte) (int, error) { 150 | if ws.r == nil { 151 | op, r, err := ws.c.NextReader() 152 | if err != nil { 153 | return 0, err 154 | } 155 | 156 | if op != websocket.BinaryMessage { 157 | err = ErrInvalidMessage 158 | return 0, err 159 | } 160 | 161 | ws.r = r 162 | } 163 | 164 | var n int 165 | for { 166 | // buffer is full, return what we've read so far 167 | if n == len(p) { 168 | return n, nil 169 | } 170 | 171 | br, err := ws.r.Read(p[n:]) 172 | n += br 173 | if err != nil { 174 | // when ANY error occurs, we consider this the end of the current message (either because it really is, via 175 | // io.EOF, or because something bad happened, in which case we want to drop the remainder) 176 | ws.r = nil 177 | 178 | if errors.Is(err, io.EOF) { 179 | err = nil 180 | } 181 | return n, err 182 | } 183 | } 184 | } 185 | 186 | // Write writes bytes to the websocket connection. 187 | func (ws *wsConn) Write(p []byte) (int, error) { 188 | err := ws.c.WriteMessage(websocket.BinaryMessage, p) 189 | if err != nil { 190 | return 0, err 191 | } 192 | 193 | return len(p), nil 194 | } 195 | 196 | // Close signals the underlying websocket conn to close. 
197 | func (ws *wsConn) Close() error { 198 | return ws.Conn.Close() 199 | } 200 | -------------------------------------------------------------------------------- /listeners/websocket_test.go: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: MIT 2 | // SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co 3 | // SPDX-FileContributor: mochi-co 4 | 5 | package listeners 6 | 7 | import ( 8 | "net" 9 | "net/http" 10 | "net/http/httptest" 11 | "strings" 12 | "testing" 13 | "time" 14 | 15 | "github.com/gorilla/websocket" 16 | "github.com/stretchr/testify/require" 17 | ) 18 | 19 | func TestNewWebsocket(t *testing.T) { 20 | l := NewWebsocket(basicConfig) 21 | require.Equal(t, "t1", l.id) 22 | require.Equal(t, testAddr, l.address) 23 | } 24 | 25 | func TestWebsocketID(t *testing.T) { 26 | l := NewWebsocket(basicConfig) 27 | require.Equal(t, "t1", l.ID()) 28 | } 29 | 30 | func TestWebsocketAddress(t *testing.T) { 31 | l := NewWebsocket(basicConfig) 32 | require.Equal(t, testAddr, l.Address()) 33 | } 34 | 35 | func TestWebsocketProtocol(t *testing.T) { 36 | l := NewWebsocket(basicConfig) 37 | require.Equal(t, "ws", l.Protocol()) 38 | } 39 | 40 | func TestWebsocketProtocolTLS(t *testing.T) { 41 | l := NewWebsocket(tlsConfig) 42 | require.Equal(t, "wss", l.Protocol()) 43 | } 44 | 45 | func TestWebsocketInit(t *testing.T) { 46 | l := NewWebsocket(basicConfig) 47 | require.Nil(t, l.listen) 48 | err := l.Init(logger) 49 | require.NoError(t, err) 50 | require.NotNil(t, l.listen) 51 | } 52 | 53 | func TestWebsocketServeAndClose(t *testing.T) { 54 | l := NewWebsocket(basicConfig) 55 | _ = l.Init(logger) 56 | 57 | o := make(chan bool) 58 | go func(o chan bool) { 59 | l.Serve(MockEstablisher) 60 | o <- true 61 | }(o) 62 | 63 | time.Sleep(time.Millisecond) 64 | 65 | var closed bool 66 | l.Close(func(id string) { 67 | closed = true 68 | }) 69 | 70 | require.True(t, closed) 71 | <-o 72 | } 73 | 74 | func 
TestWebsocketServeTLSAndClose(t *testing.T) { 75 | l := NewWebsocket(tlsConfig) 76 | err := l.Init(logger) 77 | require.NoError(t, err) 78 | 79 | o := make(chan bool) 80 | go func(o chan bool) { 81 | l.Serve(MockEstablisher) 82 | o <- true 83 | }(o) 84 | 85 | time.Sleep(time.Millisecond) 86 | var closed bool 87 | l.Close(func(id string) { 88 | closed = true 89 | }) 90 | require.Equal(t, true, closed) 91 | <-o 92 | } 93 | 94 | func TestWebsocketFailedToServe(t *testing.T) { 95 | config := tlsConfig 96 | config.Address = "wrong_addr" 97 | l := NewWebsocket(config) 98 | err := l.Init(logger) 99 | require.NoError(t, err) 100 | 101 | o := make(chan bool) 102 | go func(o chan bool) { 103 | l.Serve(MockEstablisher) 104 | o <- true 105 | }(o) 106 | 107 | <-o 108 | var closed bool 109 | l.Close(func(id string) { 110 | closed = true 111 | }) 112 | require.Equal(t, true, closed) 113 | } 114 | 115 | func TestWebsocketUpgrade(t *testing.T) { 116 | l := NewWebsocket(basicConfig) 117 | _ = l.Init(logger) 118 | 119 | e := make(chan bool) 120 | l.establish = func(id string, c net.Conn) error { 121 | e <- true 122 | return nil 123 | } 124 | 125 | s := httptest.NewServer(http.HandlerFunc(l.handler)) 126 | ws, _, err := websocket.DefaultDialer.Dial("ws"+strings.TrimPrefix(s.URL, "http"), nil) 127 | require.NoError(t, err) 128 | require.Equal(t, true, <-e) 129 | 130 | s.Close() 131 | _ = ws.Close() 132 | } 133 | 134 | func TestWebsocketConnectionReads(t *testing.T) { 135 | l := NewWebsocket(basicConfig) 136 | _ = l.Init(nil) 137 | 138 | recv := make(chan []byte) 139 | l.establish = func(id string, c net.Conn) error { 140 | var out []byte 141 | for { 142 | buf := make([]byte, 2048) 143 | n, err := c.Read(buf) 144 | require.NoError(t, err) 145 | out = append(out, buf[:n]...) 
146 | if n < 2048 { 147 | break 148 | } 149 | } 150 | 151 | recv <- out 152 | return nil 153 | } 154 | 155 | s := httptest.NewServer(http.HandlerFunc(l.handler)) 156 | ws, _, err := websocket.DefaultDialer.Dial("ws"+strings.TrimPrefix(s.URL, "http"), nil) 157 | require.NoError(t, err) 158 | 159 | pkt := make([]byte, 3000) // make sure this is >2048 160 | for i := 0; i < len(pkt); i++ { 161 | pkt[i] = byte(i % 100) 162 | } 163 | 164 | err = ws.WriteMessage(websocket.BinaryMessage, pkt) 165 | require.NoError(t, err) 166 | 167 | got := <-recv 168 | require.Equal(t, 3000, len(got)) 169 | require.Equal(t, pkt, got) 170 | 171 | s.Close() 172 | _ = ws.Close() 173 | } 174 | -------------------------------------------------------------------------------- /mempool/bufpool.go: -------------------------------------------------------------------------------- 1 | package mempool 2 | 3 | import ( 4 | "bytes" 5 | "sync" 6 | ) 7 | 8 | var bufPool = NewBuffer(0) 9 | 10 | // GetBuffer takes a Buffer from the default buffer pool 11 | func GetBuffer() *bytes.Buffer { return bufPool.Get() } 12 | 13 | // PutBuffer returns Buffer to the default buffer pool 14 | func PutBuffer(x *bytes.Buffer) { bufPool.Put(x) } 15 | 16 | type BufferPool interface { 17 | Get() *bytes.Buffer 18 | Put(x *bytes.Buffer) 19 | } 20 | 21 | // NewBuffer returns a buffer pool. The max specify the max capacity of the Buffer the pool will 22 | // return. If the Buffer becoomes large than max, it will no longer be returned to the pool. If 23 | // max <= 0, no limit will be enforced. 24 | func NewBuffer(max int) BufferPool { 25 | if max > 0 { 26 | return newBufferWithCap(max) 27 | } 28 | 29 | return newBuffer() 30 | } 31 | 32 | // Buffer is a Buffer pool. 33 | type Buffer struct { 34 | pool *sync.Pool 35 | } 36 | 37 | func newBuffer() *Buffer { 38 | return &Buffer{ 39 | pool: &sync.Pool{ 40 | New: func() any { return new(bytes.Buffer) }, 41 | }, 42 | } 43 | } 44 | 45 | // Get a Buffer from the pool. 
46 | func (b *Buffer) Get() *bytes.Buffer { 47 | return b.pool.Get().(*bytes.Buffer) 48 | } 49 | 50 | // Put the Buffer back into pool. It resets the Buffer for reuse. 51 | func (b *Buffer) Put(x *bytes.Buffer) { 52 | x.Reset() 53 | b.pool.Put(x) 54 | } 55 | 56 | // BufferWithCap is a Buffer pool that 57 | type BufferWithCap struct { 58 | bp *Buffer 59 | max int 60 | } 61 | 62 | func newBufferWithCap(max int) *BufferWithCap { 63 | return &BufferWithCap{ 64 | bp: newBuffer(), 65 | max: max, 66 | } 67 | } 68 | 69 | // Get a Buffer from the pool. 70 | func (b *BufferWithCap) Get() *bytes.Buffer { 71 | return b.bp.Get() 72 | } 73 | 74 | // Put the Buffer back into the pool if the capacity doesn't exceed the limit. It resets the Buffer 75 | // for reuse. 76 | func (b *BufferWithCap) Put(x *bytes.Buffer) { 77 | if x.Cap() > b.max { 78 | return 79 | } 80 | b.bp.Put(x) 81 | } 82 | -------------------------------------------------------------------------------- /mempool/bufpool_test.go: -------------------------------------------------------------------------------- 1 | package mempool 2 | 3 | import ( 4 | "bytes" 5 | "reflect" 6 | "runtime/debug" 7 | "testing" 8 | 9 | "github.com/stretchr/testify/require" 10 | ) 11 | 12 | func TestNewBuffer(t *testing.T) { 13 | defer debug.SetGCPercent(debug.SetGCPercent(-1)) 14 | bp := NewBuffer(1000) 15 | require.Equal(t, "*mempool.BufferWithCap", reflect.TypeOf(bp).String()) 16 | 17 | bp = NewBuffer(0) 18 | require.Equal(t, "*mempool.Buffer", reflect.TypeOf(bp).String()) 19 | 20 | bp = NewBuffer(-1) 21 | require.Equal(t, "*mempool.Buffer", reflect.TypeOf(bp).String()) 22 | } 23 | 24 | func TestBuffer(t *testing.T) { 25 | defer debug.SetGCPercent(debug.SetGCPercent(-1)) 26 | Size := 101 27 | 28 | bp := NewBuffer(0) 29 | buf := bp.Get() 30 | 31 | for i := 0; i < Size; i++ { 32 | buf.WriteByte('a') 33 | } 34 | 35 | bp.Put(buf) 36 | buf = bp.Get() 37 | require.Equal(t, 0, buf.Len()) 38 | } 39 | 40 | func TestBufferWithCap(t *testing.T) { 41 
| defer debug.SetGCPercent(debug.SetGCPercent(-1)) 42 | Size := 101 43 | bp := NewBuffer(100) 44 | buf := bp.Get() 45 | 46 | for i := 0; i < Size; i++ { 47 | buf.WriteByte('a') 48 | } 49 | 50 | bp.Put(buf) 51 | buf = bp.Get() 52 | require.Equal(t, 0, buf.Len()) 53 | require.Equal(t, 0, buf.Cap()) 54 | } 55 | 56 | func BenchmarkBufferPool(b *testing.B) { 57 | bp := NewBuffer(0) 58 | 59 | b.ResetTimer() 60 | for i := 0; i < b.N; i++ { 61 | b := bp.Get() 62 | b.WriteString("this is a test") 63 | bp.Put(b) 64 | } 65 | } 66 | 67 | func BenchmarkBufferPoolWithCapLarger(b *testing.B) { 68 | bp := NewBuffer(64 * 1024) 69 | 70 | b.ResetTimer() 71 | for i := 0; i < b.N; i++ { 72 | b := bp.Get() 73 | b.WriteString("this is a test") 74 | bp.Put(b) 75 | } 76 | } 77 | 78 | func BenchmarkBufferPoolWithCapLesser(b *testing.B) { 79 | bp := NewBuffer(10) 80 | 81 | b.ResetTimer() 82 | for i := 0; i < b.N; i++ { 83 | b := bp.Get() 84 | b.WriteString("this is a test") 85 | bp.Put(b) 86 | } 87 | } 88 | 89 | func BenchmarkBufferWithoutPool(b *testing.B) { 90 | b.ResetTimer() 91 | for i := 0; i < b.N; i++ { 92 | b := new(bytes.Buffer) 93 | b.WriteString("this is a test") 94 | _ = b 95 | } 96 | } 97 | -------------------------------------------------------------------------------- /packets/codec.go: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: MIT 2 | // SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co 3 | // SPDX-FileContributor: mochi-co 4 | 5 | package packets 6 | 7 | import ( 8 | "bytes" 9 | "encoding/binary" 10 | "io" 11 | "unicode/utf8" 12 | "unsafe" 13 | ) 14 | 15 | // bytesToString provides a zero-alloc no-copy byte to string conversion. 16 | // via https://github.com/golang/go/issues/25484#issuecomment-391415660 17 | func bytesToString(bs []byte) string { 18 | return *(*string)(unsafe.Pointer(&bs)) 19 | } 20 | 21 | // decodeUint16 extracts the value of two bytes from a byte array. 
22 | func decodeUint16(buf []byte, offset int) (uint16, int, error) { 23 | if len(buf) < offset+2 { 24 | return 0, 0, ErrMalformedOffsetUintOutOfRange 25 | } 26 | 27 | return binary.BigEndian.Uint16(buf[offset : offset+2]), offset + 2, nil 28 | } 29 | 30 | // decodeUint32 extracts the value of four bytes from a byte array. 31 | func decodeUint32(buf []byte, offset int) (uint32, int, error) { 32 | if len(buf) < offset+4 { 33 | return 0, 0, ErrMalformedOffsetUintOutOfRange 34 | } 35 | 36 | return binary.BigEndian.Uint32(buf[offset : offset+4]), offset + 4, nil 37 | } 38 | 39 | // decodeString extracts a string from a byte array, beginning at an offset. 40 | func decodeString(buf []byte, offset int) (string, int, error) { 41 | b, n, err := decodeBytes(buf, offset) 42 | if err != nil { 43 | return "", 0, err 44 | } 45 | 46 | if !validUTF8(b) { // [MQTT-1.5.4-1] [MQTT-3.1.3-5] 47 | return "", 0, ErrMalformedInvalidUTF8 48 | } 49 | 50 | return bytesToString(b), n, nil 51 | } 52 | 53 | // validUTF8 checks if the byte array contains valid UTF-8 characters. 54 | func validUTF8(b []byte) bool { 55 | return utf8.Valid(b) && bytes.IndexByte(b, 0x00) == -1 // [MQTT-1.5.4-1] [MQTT-1.5.4-2] 56 | } 57 | 58 | // decodeBytes extracts a byte array from a byte array, beginning at an offset. Used primarily for message payloads. 59 | func decodeBytes(buf []byte, offset int) ([]byte, int, error) { 60 | length, next, err := decodeUint16(buf, offset) 61 | if err != nil { 62 | return make([]byte, 0), 0, err 63 | } 64 | 65 | if next+int(length) > len(buf) { 66 | return make([]byte, 0), 0, ErrMalformedOffsetBytesOutOfRange 67 | } 68 | 69 | return buf[next : next+int(length)], next + int(length), nil 70 | } 71 | 72 | // decodeByte extracts the value of a byte from a byte array. 
73 | func decodeByte(buf []byte, offset int) (byte, int, error) { 74 | if len(buf) <= offset { 75 | return 0, 0, ErrMalformedOffsetByteOutOfRange 76 | } 77 | return buf[offset], offset + 1, nil 78 | } 79 | 80 | // decodeByteBool extracts the value of a byte from a byte array and returns a bool. 81 | func decodeByteBool(buf []byte, offset int) (bool, int, error) { 82 | if len(buf) <= offset { 83 | return false, 0, ErrMalformedOffsetBoolOutOfRange 84 | } 85 | return 1&buf[offset] > 0, offset + 1, nil 86 | } 87 | 88 | // encodeBool returns a byte instead of a bool. 89 | func encodeBool(b bool) byte { 90 | if b { 91 | return 1 92 | } 93 | return 0 94 | } 95 | 96 | // encodeBytes encodes a byte array to a byte array. Used primarily for message payloads. 97 | func encodeBytes(val []byte) []byte { 98 | // In most circumstances the number of bytes being encoded is small. 99 | // Setting the cap to a low amount allows us to account for those without 100 | // triggering allocation growth on append unless we need to. 101 | buf := make([]byte, 2, 32) 102 | binary.BigEndian.PutUint16(buf, uint16(len(val))) 103 | return append(buf, val...) 104 | } 105 | 106 | // encodeUint16 encodes a uint16 value to a byte array. 107 | func encodeUint16(val uint16) []byte { 108 | buf := make([]byte, 2) 109 | binary.BigEndian.PutUint16(buf, val) 110 | return buf 111 | } 112 | 113 | // encodeUint32 encodes a uint16 value to a byte array. 114 | func encodeUint32(val uint32) []byte { 115 | buf := make([]byte, 4) 116 | binary.BigEndian.PutUint32(buf, val) 117 | return buf 118 | } 119 | 120 | // encodeString encodes a string to a byte array. 121 | func encodeString(val string) []byte { 122 | // Like encodeBytes, we set the cap to a small number to avoid 123 | // triggering allocation growth on append unless we absolutely need to. 124 | buf := make([]byte, 2, 32) 125 | binary.BigEndian.PutUint16(buf, uint16(len(val))) 126 | return append(buf, []byte(val)...) 
127 | } 128 | 129 | // encodeLength writes length bits for the header. 130 | func encodeLength(b *bytes.Buffer, length int64) { 131 | // 1.5.5 Variable Byte Integer encode non-normative 132 | // https://docs.oasis-open.org/mqtt/mqtt/v5.0/os/mqtt-v5.0-os.html#_Toc3901027 133 | for { 134 | eb := byte(length % 128) 135 | length /= 128 136 | if length > 0 { 137 | eb |= 0x80 138 | } 139 | b.WriteByte(eb) 140 | if length == 0 { 141 | break // [MQTT-1.5.5-1] 142 | } 143 | } 144 | } 145 | 146 | func DecodeLength(b io.ByteReader) (n, bu int, err error) { 147 | // see 1.5.5 Variable Byte Integer decode non-normative 148 | // https://docs.oasis-open.org/mqtt/mqtt/v5.0/os/mqtt-v5.0-os.html#_Toc3901027 149 | var multiplier uint32 150 | var value uint32 151 | bu = 1 152 | for { 153 | eb, err := b.ReadByte() 154 | if err != nil { 155 | return 0, bu, err 156 | } 157 | 158 | value |= uint32(eb&127) << multiplier 159 | if value > 268435455 { 160 | return 0, bu, ErrMalformedVariableByteInteger 161 | } 162 | 163 | if (eb & 128) == 0 { 164 | break 165 | } 166 | 167 | multiplier += 7 168 | bu++ 169 | } 170 | 171 | return int(value), bu, nil 172 | } 173 | -------------------------------------------------------------------------------- /packets/codes.go: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: MIT 2 | // SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co 3 | // SPDX-FileContributor: mochi-co 4 | 5 | package packets 6 | 7 | // Code contains a reason code and reason string for a response. 8 | type Code struct { 9 | Reason string 10 | Code byte 11 | } 12 | 13 | // String returns the readable reason for a code. 14 | func (c Code) String() string { 15 | return c.Reason 16 | } 17 | 18 | // Error returns the readable reason for a code. 19 | func (c Code) Error() string { 20 | return c.Reason 21 | } 22 | 23 | var ( 24 | // QosCodes indicates the reason codes for each Qos byte. 
25 | QosCodes = map[byte]Code{ 26 | 0: CodeGrantedQos0, 27 | 1: CodeGrantedQos1, 28 | 2: CodeGrantedQos2, 29 | } 30 | 31 | CodeSuccessIgnore = Code{Code: 0x00, Reason: "ignore packet"} 32 | CodeSuccess = Code{Code: 0x00, Reason: "success"} 33 | CodeDisconnect = Code{Code: 0x00, Reason: "disconnected"} 34 | CodeGrantedQos0 = Code{Code: 0x00, Reason: "granted qos 0"} 35 | CodeGrantedQos1 = Code{Code: 0x01, Reason: "granted qos 1"} 36 | CodeGrantedQos2 = Code{Code: 0x02, Reason: "granted qos 2"} 37 | CodeDisconnectWillMessage = Code{Code: 0x04, Reason: "disconnect with will message"} 38 | CodeNoMatchingSubscribers = Code{Code: 0x10, Reason: "no matching subscribers"} 39 | CodeNoSubscriptionExisted = Code{Code: 0x11, Reason: "no subscription existed"} 40 | CodeContinueAuthentication = Code{Code: 0x18, Reason: "continue authentication"} 41 | CodeReAuthenticate = Code{Code: 0x19, Reason: "re-authenticate"} 42 | ErrUnspecifiedError = Code{Code: 0x80, Reason: "unspecified error"} 43 | ErrMalformedPacket = Code{Code: 0x81, Reason: "malformed packet"} 44 | ErrMalformedProtocolName = Code{Code: 0x81, Reason: "malformed packet: protocol name"} 45 | ErrMalformedProtocolVersion = Code{Code: 0x81, Reason: "malformed packet: protocol version"} 46 | ErrMalformedFlags = Code{Code: 0x81, Reason: "malformed packet: flags"} 47 | ErrMalformedKeepalive = Code{Code: 0x81, Reason: "malformed packet: keepalive"} 48 | ErrMalformedPacketID = Code{Code: 0x81, Reason: "malformed packet: packet identifier"} 49 | ErrMalformedTopic = Code{Code: 0x81, Reason: "malformed packet: topic"} 50 | ErrMalformedWillTopic = Code{Code: 0x81, Reason: "malformed packet: will topic"} 51 | ErrMalformedWillPayload = Code{Code: 0x81, Reason: "malformed packet: will message"} 52 | ErrMalformedUsername = Code{Code: 0x81, Reason: "malformed packet: username"} 53 | ErrMalformedPassword = Code{Code: 0x81, Reason: "malformed packet: password"} 54 | ErrMalformedQos = Code{Code: 0x81, Reason: "malformed packet: qos"} 55 | 
ErrMalformedOffsetUintOutOfRange = Code{Code: 0x81, Reason: "malformed packet: offset uint out of range"} 56 | ErrMalformedOffsetBytesOutOfRange = Code{Code: 0x81, Reason: "malformed packet: offset bytes out of range"} 57 | ErrMalformedOffsetByteOutOfRange = Code{Code: 0x81, Reason: "malformed packet: offset byte out of range"} 58 | ErrMalformedOffsetBoolOutOfRange = Code{Code: 0x81, Reason: "malformed packet: offset boolean out of range"} 59 | ErrMalformedInvalidUTF8 = Code{Code: 0x81, Reason: "malformed packet: invalid utf-8 string"} 60 | ErrMalformedVariableByteInteger = Code{Code: 0x81, Reason: "malformed packet: variable byte integer out of range"} 61 | ErrMalformedBadProperty = Code{Code: 0x81, Reason: "malformed packet: unknown property"} 62 | ErrMalformedProperties = Code{Code: 0x81, Reason: "malformed packet: properties"} 63 | ErrMalformedWillProperties = Code{Code: 0x81, Reason: "malformed packet: will properties"} 64 | ErrMalformedSessionPresent = Code{Code: 0x81, Reason: "malformed packet: session present"} 65 | ErrMalformedReasonCode = Code{Code: 0x81, Reason: "malformed packet: reason code"} 66 | ErrProtocolViolation = Code{Code: 0x82, Reason: "protocol violation"} 67 | ErrProtocolViolationProtocolName = Code{Code: 0x82, Reason: "protocol violation: protocol name"} 68 | ErrProtocolViolationProtocolVersion = Code{Code: 0x82, Reason: "protocol violation: protocol version"} 69 | ErrProtocolViolationReservedBit = Code{Code: 0x82, Reason: "protocol violation: reserved bit not 0"} 70 | ErrProtocolViolationFlagNoUsername = Code{Code: 0x82, Reason: "protocol violation: username flag set but no value"} 71 | ErrProtocolViolationFlagNoPassword = Code{Code: 0x82, Reason: "protocol violation: password flag set but no value"} 72 | ErrProtocolViolationUsernameNoFlag = Code{Code: 0x82, Reason: "protocol violation: username set but no flag"} 73 | ErrProtocolViolationPasswordNoFlag = Code{Code: 0x82, Reason: "protocol violation: username set but no flag"} 74 | 
ErrProtocolViolationPasswordTooLong = Code{Code: 0x82, Reason: "protocol violation: password too long"} 75 | ErrProtocolViolationUsernameTooLong = Code{Code: 0x82, Reason: "protocol violation: username too long"} 76 | ErrProtocolViolationNoPacketID = Code{Code: 0x82, Reason: "protocol violation: missing packet id"} 77 | ErrProtocolViolationSurplusPacketID = Code{Code: 0x82, Reason: "protocol violation: surplus packet id"} 78 | ErrProtocolViolationQosOutOfRange = Code{Code: 0x82, Reason: "protocol violation: qos out of range"} 79 | ErrProtocolViolationSecondConnect = Code{Code: 0x82, Reason: "protocol violation: second connect packet"} 80 | ErrProtocolViolationZeroNonZeroExpiry = Code{Code: 0x82, Reason: "protocol violation: non-zero expiry"} 81 | ErrProtocolViolationRequireFirstConnect = Code{Code: 0x82, Reason: "protocol violation: first packet must be connect"} 82 | ErrProtocolViolationWillFlagNoPayload = Code{Code: 0x82, Reason: "protocol violation: will flag no payload"} 83 | ErrProtocolViolationWillFlagSurplusRetain = Code{Code: 0x82, Reason: "protocol violation: will flag surplus retain"} 84 | ErrProtocolViolationSurplusWildcard = Code{Code: 0x82, Reason: "protocol violation: topic contains wildcards"} 85 | ErrProtocolViolationSurplusSubID = Code{Code: 0x82, Reason: "protocol violation: contained subscription identifier"} 86 | ErrProtocolViolationInvalidTopic = Code{Code: 0x82, Reason: "protocol violation: invalid topic"} 87 | ErrProtocolViolationInvalidSharedNoLocal = Code{Code: 0x82, Reason: "protocol violation: invalid shared no local"} 88 | ErrProtocolViolationNoFilters = Code{Code: 0x82, Reason: "protocol violation: must contain at least one filter"} 89 | ErrProtocolViolationInvalidReason = Code{Code: 0x82, Reason: "protocol violation: invalid reason"} 90 | ErrProtocolViolationOversizeSubID = Code{Code: 0x82, Reason: "protocol violation: oversize subscription id"} 91 | ErrProtocolViolationDupNoQos = Code{Code: 0x82, Reason: "protocol violation: dup true 
with no qos"} 92 | ErrProtocolViolationUnsupportedProperty = Code{Code: 0x82, Reason: "protocol violation: unsupported property"} 93 | ErrProtocolViolationNoTopic = Code{Code: 0x82, Reason: "protocol violation: no topic or alias"} 94 | ErrImplementationSpecificError = Code{Code: 0x83, Reason: "implementation specific error"} 95 | ErrRejectPacket = Code{Code: 0x83, Reason: "packet rejected"} 96 | ErrUnsupportedProtocolVersion = Code{Code: 0x84, Reason: "unsupported protocol version"} 97 | ErrClientIdentifierNotValid = Code{Code: 0x85, Reason: "client identifier not valid"} 98 | ErrClientIdentifierTooLong = Code{Code: 0x85, Reason: "client identifier too long"} 99 | ErrBadUsernameOrPassword = Code{Code: 0x86, Reason: "bad username or password"} 100 | ErrNotAuthorized = Code{Code: 0x87, Reason: "not authorized"} 101 | ErrServerUnavailable = Code{Code: 0x88, Reason: "server unavailable"} 102 | ErrServerBusy = Code{Code: 0x89, Reason: "server busy"} 103 | ErrBanned = Code{Code: 0x8A, Reason: "banned"} 104 | ErrServerShuttingDown = Code{Code: 0x8B, Reason: "server shutting down"} 105 | ErrBadAuthenticationMethod = Code{Code: 0x8C, Reason: "bad authentication method"} 106 | ErrKeepAliveTimeout = Code{Code: 0x8D, Reason: "keep alive timeout"} 107 | ErrSessionTakenOver = Code{Code: 0x8E, Reason: "session takeover"} 108 | ErrTopicFilterInvalid = Code{Code: 0x8F, Reason: "topic filter invalid"} 109 | ErrTopicNameInvalid = Code{Code: 0x90, Reason: "topic name invalid"} 110 | ErrPacketIdentifierInUse = Code{Code: 0x91, Reason: "packet identifier in use"} 111 | ErrPacketIdentifierNotFound = Code{Code: 0x92, Reason: "packet identifier not found"} 112 | ErrReceiveMaximum = Code{Code: 0x93, Reason: "receive maximum exceeded"} 113 | ErrTopicAliasInvalid = Code{Code: 0x94, Reason: "topic alias invalid"} 114 | ErrPacketTooLarge = Code{Code: 0x95, Reason: "packet too large"} 115 | ErrMessageRateTooHigh = Code{Code: 0x96, Reason: "message rate too high"} 116 | ErrQuotaExceeded = 
Code{Code: 0x97, Reason: "quota exceeded"} 117 | ErrPendingClientWritesExceeded = Code{Code: 0x97, Reason: "too many pending writes"} 118 | ErrAdministrativeAction = Code{Code: 0x98, Reason: "administrative action"} 119 | ErrPayloadFormatInvalid = Code{Code: 0x99, Reason: "payload format invalid"} 120 | ErrRetainNotSupported = Code{Code: 0x9A, Reason: "retain not supported"} 121 | ErrQosNotSupported = Code{Code: 0x9B, Reason: "qos not supported"} 122 | ErrUseAnotherServer = Code{Code: 0x9C, Reason: "use another server"} 123 | ErrServerMoved = Code{Code: 0x9D, Reason: "server moved"} 124 | ErrSharedSubscriptionsNotSupported = Code{Code: 0x9E, Reason: "shared subscriptions not supported"} 125 | ErrConnectionRateExceeded = Code{Code: 0x9F, Reason: "connection rate exceeded"} 126 | ErrMaxConnectTime = Code{Code: 0xA0, Reason: "maximum connect time"} 127 | ErrSubscriptionIdentifiersNotSupported = Code{Code: 0xA1, Reason: "subscription identifiers not supported"} 128 | ErrWildcardSubscriptionsNotSupported = Code{Code: 0xA2, Reason: "wildcard subscriptions not supported"} 129 | ErrInlineSubscriptionHandlerInvalid = Code{Code: 0xA3, Reason: "inline subscription handler not valid."} 130 | 131 | // MQTTv3 specific bytes. 132 | Err3UnsupportedProtocolVersion = Code{Code: 0x01} 133 | Err3ClientIdentifierNotValid = Code{Code: 0x02} 134 | Err3ServerUnavailable = Code{Code: 0x03} 135 | ErrMalformedUsernameOrPassword = Code{Code: 0x04} 136 | Err3NotAuthorized = Code{Code: 0x05} 137 | 138 | // V5CodesToV3 maps MQTTv5 Connack reason codes to MQTTv3 return codes. 139 | // This is required because MQTTv3 has different return byte specification. 
140 | // See http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc385349257 141 | V5CodesToV3 = map[Code]Code{ 142 | ErrUnsupportedProtocolVersion: Err3UnsupportedProtocolVersion, 143 | ErrClientIdentifierNotValid: Err3ClientIdentifierNotValid, 144 | ErrServerUnavailable: Err3ServerUnavailable, 145 | ErrMalformedUsername: ErrMalformedUsernameOrPassword, 146 | ErrMalformedPassword: ErrMalformedUsernameOrPassword, 147 | ErrBadUsernameOrPassword: Err3NotAuthorized, 148 | } 149 | ) 150 | -------------------------------------------------------------------------------- /packets/codes_test.go: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: MIT 2 | // SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co 3 | // SPDX-FileContributor: mochi-co 4 | 5 | package packets 6 | 7 | import ( 8 | "testing" 9 | 10 | "github.com/stretchr/testify/require" 11 | ) 12 | 13 | func TestCodesString(t *testing.T) { 14 | c := Code{ 15 | Reason: "test", 16 | Code: 0x1, 17 | } 18 | 19 | require.Equal(t, "test", c.String()) 20 | } 21 | 22 | func TestCodesError(t *testing.T) { 23 | c := Code{ 24 | Reason: "error", 25 | Code: 0x1, 26 | } 27 | 28 | require.Equal(t, "error", error(c).Error()) 29 | } 30 | -------------------------------------------------------------------------------- /packets/fixedheader.go: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: MIT 2 | // SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co 3 | // SPDX-FileContributor: mochi-co 4 | 5 | package packets 6 | 7 | import ( 8 | "bytes" 9 | ) 10 | 11 | // FixedHeader contains the values of the fixed header portion of the MQTT packet. 12 | type FixedHeader struct { 13 | Remaining int `json:"remaining"` // the number of remaining bytes in the payload. 14 | Type byte `json:"type"` // the type of the packet (PUBLISH, SUBSCRIBE, etc) from bits 7 - 4 (byte 1).
15 | Qos byte `json:"qos"` // indicates the quality of service expected. 16 | Dup bool `json:"dup"` // indicates if the packet was already sent at an earlier time. 17 | Retain bool `json:"retain"` // whether the message should be retained. 18 | } 19 | 20 | // Encode writes the fixed header to buf: the packet type and flag bits as a single byte, followed by the encoded remaining length. 21 | func (fh *FixedHeader) Encode(buf *bytes.Buffer) { 22 | buf.WriteByte(fh.Type<<4 | encodeBool(fh.Dup)<<3 | fh.Qos<<1 | encodeBool(fh.Retain)) 23 | encodeLength(buf, int64(fh.Remaining)) 24 | } 25 | 26 | // Decode extracts the packet type and flag bits from the first fixed header byte hb, returning an error if the flags are not valid for that packet type. It does not set Remaining. 27 | func (fh *FixedHeader) Decode(hb byte) error { 28 | fh.Type = hb >> 4 // Get the packet type from the upper 4 bits. 29 | 30 | switch fh.Type { 31 | case Publish: 32 | if (hb>>1)&0x01 > 0 && (hb>>1)&0x02 > 0 { // both QoS bits set (QoS 3) is not allowed 33 | return ErrProtocolViolationQosOutOfRange // [MQTT-3.3.1-4] 34 | } 35 | 36 | fh.Dup = (hb>>3)&0x01 > 0 // is duplicate 37 | fh.Qos = (hb >> 1) & 0x03 // qos flag 38 | fh.Retain = hb&0x01 > 0 // is retain flag 39 | case Pubrel: 40 | fallthrough 41 | case Subscribe: 42 | fallthrough 43 | case Unsubscribe: 44 | if (hb>>0)&0x01 != 0 || (hb>>1)&0x01 != 1 || (hb>>2)&0x01 != 0 || (hb>>3)&0x01 != 0 { // [MQTT-3.6.1-1] [MQTT-3.8.1-1] [MQTT-3.10.1-1] 45 | return ErrMalformedFlags 46 | } 47 | 48 | fh.Qos = (hb >> 1) & 0x03 49 | default: 50 | if (hb>>0)&0x01 != 0 || 51 | (hb>>1)&0x01 != 0 || 52 | (hb>>2)&0x01 != 0 || 53 | (hb>>3)&0x01 != 0 { // [MQTT-3.8.3-5] [MQTT-3.14.1-1] [MQTT-3.15.1-1] 54 | return ErrMalformedFlags 55 | } 56 | } 57 | 58 | if fh.Qos == 0 && fh.Dup { 59 | return ErrProtocolViolationDupNoQos // [MQTT-3.3.1-2] 60 | } 61 | 62 | return nil 63 | } 64 | -------------------------------------------------------------------------------- /packets/fixedheader_test.go: 
-------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: MIT 2 | // SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co 3 | // SPDX-FileContributor: mochi-co 4 | 5 |
package packets 6 | 7 | import ( 8 | "bytes" 9 | "testing" 10 | 11 | "github.com/stretchr/testify/require" 12 | ) 13 | 14 | type fixedHeaderTable struct { 15 | desc string 16 | rawBytes []byte 17 | header FixedHeader 18 | packetError bool 19 | expect error 20 | } 21 | 22 | var fixedHeaderExpected = []fixedHeaderTable{ 23 | { 24 | desc: "connect", 25 | rawBytes: []byte{Connect << 4, 0x00}, 26 | header: FixedHeader{Type: Connect, Dup: false, Qos: 0, Retain: false, Remaining: 0}, 27 | }, 28 | { 29 | desc: "connack", 30 | rawBytes: []byte{Connack << 4, 0x00}, 31 | header: FixedHeader{Type: Connack, Dup: false, Qos: 0, Retain: false, Remaining: 0}, 32 | }, 33 | { 34 | desc: "publish", 35 | rawBytes: []byte{Publish << 4, 0x00}, 36 | header: FixedHeader{Type: Publish, Dup: false, Qos: 0, Retain: false, Remaining: 0}, 37 | }, 38 | { 39 | desc: "publish qos 1", 40 | rawBytes: []byte{Publish<<4 | 1<<1, 0x00}, 41 | header: FixedHeader{Type: Publish, Dup: false, Qos: 1, Retain: false, Remaining: 0}, 42 | }, 43 | { 44 | desc: "publish qos 1 retain", 45 | rawBytes: []byte{Publish<<4 | 1<<1 | 1, 0x00}, 46 | header: FixedHeader{Type: Publish, Dup: false, Qos: 1, Retain: true, Remaining: 0}, 47 | }, 48 | { 49 | desc: "publish qos 2", 50 | rawBytes: []byte{Publish<<4 | 2<<1, 0x00}, 51 | header: FixedHeader{Type: Publish, Dup: false, Qos: 2, Retain: false, Remaining: 0}, 52 | }, 53 | { 54 | desc: "publish qos 2 retain", 55 | rawBytes: []byte{Publish<<4 | 2<<1 | 1, 0x00}, 56 | header: FixedHeader{Type: Publish, Dup: false, Qos: 2, Retain: true, Remaining: 0}, 57 | }, 58 | { 59 | desc: "publish dup qos 0", 60 | rawBytes: []byte{Publish<<4 | 1<<3, 0x00}, 61 | header: FixedHeader{Type: Publish, Dup: true, Qos: 0, Retain: false, Remaining: 0}, 62 | expect: ErrProtocolViolationDupNoQos, 63 | }, 64 | { 65 | desc: "publish dup qos 0 retain", 66 | rawBytes: []byte{Publish<<4 | 1<<3 | 1, 0x00}, 67 | header: FixedHeader{Type: Publish, Dup: true, Qos: 0, Retain: true, Remaining: 0}, 68 | expect:
ErrProtocolViolationDupNoQos, 69 | }, 70 | { 71 | desc: "publish dup qos 1 retain", 72 | rawBytes: []byte{Publish<<4 | 1<<3 | 1<<1 | 1, 0x00}, 73 | header: FixedHeader{Type: Publish, Dup: true, Qos: 1, Retain: true, Remaining: 0}, 74 | }, 75 | { 76 | desc: "publish dup qos 2 retain", 77 | rawBytes: []byte{Publish<<4 | 1<<3 | 2<<1 | 1, 0x00}, 78 | header: FixedHeader{Type: Publish, Dup: true, Qos: 2, Retain: true, Remaining: 0}, 79 | }, 80 | { 81 | desc: "puback", 82 | rawBytes: []byte{Puback << 4, 0x00}, 83 | header: FixedHeader{Type: Puback, Dup: false, Qos: 0, Retain: false, Remaining: 0}, 84 | }, 85 | { 86 | desc: "pubrec", 87 | rawBytes: []byte{Pubrec << 4, 0x00}, 88 | header: FixedHeader{Type: Pubrec, Dup: false, Qos: 0, Retain: false, Remaining: 0}, 89 | }, 90 | { 91 | desc: "pubrel", 92 | rawBytes: []byte{Pubrel<<4 | 1<<1, 0x00}, 93 | header: FixedHeader{Type: Pubrel, Dup: false, Qos: 1, Retain: false, Remaining: 0}, 94 | }, 95 | { 96 | desc: "pubcomp", 97 | rawBytes: []byte{Pubcomp << 4, 0x00}, 98 | header: FixedHeader{Type: Pubcomp, Dup: false, Qos: 0, Retain: false, Remaining: 0}, 99 | }, 100 | { 101 | desc: "subscribe", 102 | rawBytes: []byte{Subscribe<<4 | 1<<1, 0x00}, 103 | header: FixedHeader{Type: Subscribe, Dup: false, Qos: 1, Retain: false, Remaining: 0}, 104 | }, 105 | { 106 | desc: "suback", 107 | rawBytes: []byte{Suback << 4, 0x00}, 108 | header: FixedHeader{Type: Suback, Dup: false, Qos: 0, Retain: false, Remaining: 0}, 109 | }, 110 | { 111 | desc: "unsubscribe", 112 | rawBytes: []byte{Unsubscribe<<4 | 1<<1, 0x00}, 113 | header: FixedHeader{Type: Unsubscribe, Dup: false, Qos: 1, Retain: false, Remaining: 0}, 114 | }, 115 | { 116 | desc: "unsuback", 117 | rawBytes: []byte{Unsuback << 4, 0x00}, 118 | header: FixedHeader{Type: Unsuback, Dup: false, Qos: 0, Retain: false, Remaining: 0}, 119 | }, 120 | { 121 | desc: "pingreq", 122 | rawBytes: []byte{Pingreq << 4, 0x00}, 123 | header: FixedHeader{Type: Pingreq, Dup: false, Qos: 0, Retain: false,
Remaining: 0}, 124 | }, 125 | { 126 | desc: "pingresp", 127 | rawBytes: []byte{Pingresp << 4, 0x00}, 128 | header: FixedHeader{Type: Pingresp, Dup: false, Qos: 0, Retain: false, Remaining: 0}, 129 | }, 130 | { 131 | desc: "disconnect", 132 | rawBytes: []byte{Disconnect << 4, 0x00}, 133 | header: FixedHeader{Type: Disconnect, Dup: false, Qos: 0, Retain: false, Remaining: 0}, 134 | }, 135 | { 136 | desc: "auth", 137 | rawBytes: []byte{Auth << 4, 0x00}, 138 | header: FixedHeader{Type: Auth, Dup: false, Qos: 0, Retain: false, Remaining: 0}, 139 | }, 140 | 141 | // remaining length 142 | { 143 | desc: "remaining length 10", 144 | rawBytes: []byte{Publish << 4, 0x0a}, 145 | header: FixedHeader{Type: Publish, Dup: false, Qos: 0, Retain: false, Remaining: 10}, 146 | }, 147 | { 148 | desc: "remaining length 512", 149 | rawBytes: []byte{Publish << 4, 0x80, 0x04}, 150 | header: FixedHeader{Type: Publish, Dup: false, Qos: 0, Retain: false, Remaining: 512}, 151 | }, 152 | { 153 | desc: "remaining length 978", 154 | rawBytes: []byte{Publish << 4, 0xd2, 0x07}, 155 | header: FixedHeader{Type: Publish, Dup: false, Qos: 0, Retain: false, Remaining: 978}, 156 | }, 157 | { 158 | desc: "remaining length 20102", 159 | rawBytes: []byte{Publish << 4, 0x86, 0x9d, 0x01}, 160 | header: FixedHeader{Type: Publish, Dup: false, Qos: 0, Retain: false, Remaining: 20102}, 161 | }, 162 | { 163 | desc: "remaining length oversize", 164 | rawBytes: []byte{Publish << 4, 0xd5, 0x86, 0xf9, 0x9e, 0x01}, 165 | header: FixedHeader{Type: Publish, Dup: false, Qos: 0, Retain: false, Remaining: 333333333}, 166 | packetError: true, 167 | }, 168 | 169 | // Invalid flags for packet 170 | { 171 | desc: "invalid type dup is true", 172 | rawBytes: []byte{Connect<<4 | 1<<3, 0x00}, 173 | header: FixedHeader{Type: Connect, Dup: true, Qos: 0, Retain: false, Remaining: 0}, 174 | expect: ErrMalformedFlags, 175 | }, 176 | { 177 | desc: "invalid type qos is 1", 178 | rawBytes: []byte{Connect<<4 | 1<<1, 0x00}, 179 | header:
FixedHeader{Type: Connect, Dup: false, Qos: 1, Retain: false, Remaining: 0}, 180 | expect: ErrMalformedFlags, 181 | }, 182 | { 183 | desc: "invalid type retain is true", 184 | rawBytes: []byte{Connect<<4 | 1, 0x00}, 185 | header: FixedHeader{Type: Connect, Dup: false, Qos: 0, Retain: true, Remaining: 0}, 186 | expect: ErrMalformedFlags, 187 | }, 188 | { 189 | desc: "invalid publish qos bits 1 + 2 set", 190 | rawBytes: []byte{Publish<<4 | 1<<1 | 1<<2, 0x00}, 191 | header: FixedHeader{Type: Publish}, 192 | expect: ErrProtocolViolationQosOutOfRange, 193 | }, 194 | { 195 | desc: "invalid pubrel bits 3,2,1,0 should be 0,0,1,0", 196 | rawBytes: []byte{Pubrel<<4 | 1<<2 | 1<<0, 0x00}, 197 | header: FixedHeader{Type: Pubrel, Qos: 1}, 198 | expect: ErrMalformedFlags, 199 | }, 200 | { 201 | desc: "invalid subscribe bits 3,2,1,0 should be 0,0,1,0", 202 | rawBytes: []byte{Subscribe<<4 | 1<<2, 0x00}, 203 | header: FixedHeader{Type: Subscribe, Qos: 1}, 204 | expect: ErrMalformedFlags, 205 | }, 206 | } 207 | 208 | func TestFixedHeaderEncode(t *testing.T) { 209 | for _, wanted := range fixedHeaderExpected { 210 | t.Run(wanted.desc, func(t *testing.T) { 211 | buf := new(bytes.Buffer) 212 | wanted.header.Encode(buf) 213 | if wanted.expect == nil { 214 | require.Equal(t, len(wanted.rawBytes), len(buf.Bytes())) 215 | require.EqualValues(t, wanted.rawBytes, buf.Bytes()) 216 | } 217 | }) 218 | } 219 | } 220 | 221 | func TestFixedHeaderDecode(t *testing.T) { 222 | for _, wanted := range fixedHeaderExpected { 223 | t.Run(wanted.desc, func(t *testing.T) { 224 | fh := new(FixedHeader) 225 | err := fh.Decode(wanted.rawBytes[0]) 226 | if wanted.expect != nil { 227 | require.Equal(t, wanted.expect, err) 228 | } else { 229 | require.NoError(t, err) 230 | require.Equal(t, wanted.header.Type, fh.Type) 231 | require.Equal(t, wanted.header.Dup, fh.Dup) 232 | require.Equal(t, wanted.header.Qos, fh.Qos) 233 | require.Equal(t, wanted.header.Retain, fh.Retain) 234 | } 235 | }) 236 | } 237 | } 238 |
-------------------------------------------------------------------------------- /packets/properties_test.go: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: MIT 2 | // SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co 3 | // SPDX-FileContributor: mochi-co 4 | 5 | package packets 6 | 7 | import ( 8 | "bytes" 9 | "testing" 10 | 11 | "github.com/stretchr/testify/require" 12 | ) 13 | 14 | var ( 15 | propertiesStruct = Properties{ 16 | PayloadFormat: byte(1), // UTF-8 Format 17 | PayloadFormatFlag: true, 18 | MessageExpiryInterval: uint32(2), 19 | ContentType: "text/plain", 20 | ResponseTopic: "a/b/c", 21 | CorrelationData: []byte("data"), 22 | SubscriptionIdentifier: []int{322122}, 23 | SessionExpiryInterval: uint32(120), 24 | SessionExpiryIntervalFlag: true, 25 | AssignedClientID: "mochi-v5", 26 | ServerKeepAlive: uint16(20), 27 | ServerKeepAliveFlag: true, 28 | AuthenticationMethod: "SHA-1", 29 | AuthenticationData: []byte("auth-data"), 30 | RequestProblemInfo: byte(1), 31 | RequestProblemInfoFlag: true, 32 | WillDelayInterval: uint32(600), 33 | RequestResponseInfo: byte(1), 34 | ResponseInfo: "response", 35 | ServerReference: "mochi-2", 36 | ReasonString: "reason", 37 | ReceiveMaximum: uint16(500), 38 | TopicAliasMaximum: uint16(999), 39 | TopicAlias: uint16(3), 40 | TopicAliasFlag: true, 41 | MaximumQos: byte(1), 42 | MaximumQosFlag: true, 43 | RetainAvailable: byte(1), 44 | RetainAvailableFlag: true, 45 | User: []UserProperty{ 46 | { 47 | Key: "hello", 48 | Val: "世界", 49 | }, 50 | { 51 | Key: "key2", 52 | Val: "value2", 53 | }, 54 | }, 55 | MaximumPacketSize: uint32(32000), 56 | WildcardSubAvailable: byte(1), 57 | WildcardSubAvailableFlag: true, 58 | SubIDAvailable: byte(1), 59 | SubIDAvailableFlag: true, 60 | SharedSubAvailable: byte(1), 61 | SharedSubAvailableFlag: true, 62 | } 63 | 64 | propertiesBytes = []byte{ 65 | 172, 1, // property length as a VBI (172); each (vbi:N) below is the running byte count after that property 66 | 67 | // Payload Format (1) (vbi:2) 68 | 1, 1, 69 | 70 | //
Message Expiry (2) (vbi:7) 71 | 2, 0, 0, 0, 2, 72 | 73 | // Content Type (3) (vbi:20) 74 | 3, 75 | 0, 10, 't', 'e', 'x', 't', '/', 'p', 'l', 'a', 'i', 'n', 76 | 77 | // Response Topic (8) (vbi:28) 78 | 8, 79 | 0, 5, 'a', '/', 'b', '/', 'c', 80 | 81 | // Correlation Data (9) (vbi:35) 82 | 9, 83 | 0, 4, 'd', 'a', 't', 'a', 84 | 85 | // Subscription Identifier (11) (vbi:39) 86 | 11, 87 | 202, 212, 19, // VBI encoding of 322122 88 | 89 | // Session Expiry Interval (17) (vbi:44) 90 | 17, 91 | 0, 0, 0, 120, 92 | 93 | // Assigned Client ID (18) (vbi:55) 94 | 18, 95 | 0, 8, 'm', 'o', 'c', 'h', 'i', '-', 'v', '5', 96 | 97 | // Server Keep Alive (19) (vbi:58) 98 | 19, 99 | 0, 20, 100 | 101 | // Authentication Method (21) (vbi:66) 102 | 21, 103 | 0, 5, 'S', 'H', 'A', '-', '1', 104 | 105 | // Authentication Data (22) (vbi:78) 106 | 22, 107 | 0, 9, 'a', 'u', 't', 'h', '-', 'd', 'a', 't', 'a', 108 | 109 | // Request Problem Info (23) (vbi:80) 110 | 23, 1, 111 | 112 | // Will Delay Interval (24) (vbi:85) 113 | 24, 114 | 0, 0, 2, 88, 115 | 116 | // Request Response Info (25) (vbi:87) 117 | 25, 1, 118 | 119 | // Response Info (26) (vbi:98) 120 | 26, 121 | 0, 8, 'r', 'e', 's', 'p', 'o', 'n', 's', 'e', 122 | 123 | // Server Reference (28) (vbi:108) 124 | 28, 125 | 0, 7, 'm', 'o', 'c', 'h', 'i', '-', '2', 126 | 127 | // Reason String (31) (vbi:117) 128 | 31, 129 | 0, 6, 'r', 'e', 'a', 's', 'o', 'n', 130 | 131 | // Receive Maximum (33) (vbi:120) 132 | 33, 133 | 1, 244, 134 | 135 | // Topic Alias Maximum (34) (vbi:123) 136 | 34, 137 | 3, 231, 138 | 139 | // Topic Alias (35) (vbi:126) 140 | 35, 141 | 0, 3, 142 | 143 | // Maximum Qos (36) (vbi:128) 144 | 36, 1, 145 | 146 | // Retain Available (37) (vbi:130) 147 | 37, 1, 148 | 149 | // User Properties (38) (vbi:161) 150 | 38, 151 | 0, 5, 'h', 'e', 'l', 'l', 'o', 152 | 0, 6, 228, 184, 150, 231, 149, 140, 153 | 38, 154 | 0, 4, 'k', 'e', 'y', '2', 155 | 0, 6, 'v', 'a', 'l', 'u', 'e', '2', 156 | 157 | // Maximum Packet Size (39) (vbi:166) 158 | 39, 159 | 0, 0, 125,
0, 160 | 161 | // Wildcard Subscriptions Available (40) (vbi:168) 162 | 40, 1, 163 | 164 | // Subscription ID Available (41) (vbi:170) 165 | 41, 1, 166 | 167 | // Shared Subscriptions Available (42) (vbi:172) 168 | 42, 1, 169 | } 170 | ) 171 | // init marks every property as valid for the Reserved packet type, which the tests below use to encode/decode the full property set. 172 | func init() { 173 | validPacketProperties[PropPayloadFormat][Reserved] = 1 174 | validPacketProperties[PropMessageExpiryInterval][Reserved] = 1 175 | validPacketProperties[PropContentType][Reserved] = 1 176 | validPacketProperties[PropResponseTopic][Reserved] = 1 177 | validPacketProperties[PropCorrelationData][Reserved] = 1 178 | validPacketProperties[PropSubscriptionIdentifier][Reserved] = 1 179 | validPacketProperties[PropSessionExpiryInterval][Reserved] = 1 180 | validPacketProperties[PropAssignedClientID][Reserved] = 1 181 | validPacketProperties[PropServerKeepAlive][Reserved] = 1 182 | validPacketProperties[PropAuthenticationMethod][Reserved] = 1 183 | validPacketProperties[PropAuthenticationData][Reserved] = 1 184 | validPacketProperties[PropRequestProblemInfo][Reserved] = 1 185 | validPacketProperties[PropWillDelayInterval][Reserved] = 1 186 | validPacketProperties[PropRequestResponseInfo][Reserved] = 1 187 | validPacketProperties[PropResponseInfo][Reserved] = 1 188 | validPacketProperties[PropServerReference][Reserved] = 1 189 | validPacketProperties[PropReasonString][Reserved] = 1 190 | validPacketProperties[PropReceiveMaximum][Reserved] = 1 191 | validPacketProperties[PropTopicAliasMaximum][Reserved] = 1 192 | validPacketProperties[PropTopicAlias][Reserved] = 1 193 | validPacketProperties[PropMaximumQos][Reserved] = 1 194 | validPacketProperties[PropRetainAvailable][Reserved] = 1 195 | validPacketProperties[PropUser][Reserved] = 1 196 | validPacketProperties[PropMaximumPacketSize][Reserved] = 1 197 | validPacketProperties[PropWildcardSubAvailable][Reserved] = 1 198 | 
validPacketProperties[PropSubIDAvailable][Reserved] = 1 199 | validPacketProperties[PropSharedSubAvailable][Reserved] = 1 200 | } 201 | 202 |
func TestEncodeProperties(t *testing.T) { 203 | props := propertiesStruct 204 | b := bytes.NewBuffer([]byte{}) 205 | props.Encode(Reserved, Mods{AllowResponseInfo: true}, b, 0) 206 | require.Equal(t, propertiesBytes, b.Bytes()) 207 | } 208 | 209 | func TestEncodePropertiesDisallowProblemInfo(t *testing.T) { 210 | props := propertiesStruct 211 | b := bytes.NewBuffer([]byte{}) 212 | props.Encode(Reserved, Mods{DisallowProblemInfo: true}, b, 0) 213 | require.NotEqual(t, propertiesBytes, b.Bytes()) 214 | require.False(t, bytes.Contains(b.Bytes(), []byte{31, 0, 6})) 215 | require.False(t, bytes.Contains(b.Bytes(), []byte{38, 0, 5})) 216 | require.False(t, bytes.Contains(b.Bytes(), []byte{26, 0, 8})) 217 | } 218 | 219 | func TestEncodePropertiesDisallowResponseInfo(t *testing.T) { 220 | props := propertiesStruct 221 | b := bytes.NewBuffer([]byte{}) 222 | props.Encode(Reserved, Mods{AllowResponseInfo: false}, b, 0) 223 | require.NotEqual(t, propertiesBytes, b.Bytes()) 224 | // With AllowResponseInfo false, Response Information (26) must not be encoded; bytes.Contains is used because require.NotContains cannot search a []byte for a sub-slice. 225 | require.False(t, bytes.Contains(b.Bytes(), []byte{26, 0, 8})) 226 | } 227 | 228 | func TestEncodePropertiesNil(t *testing.T) { 229 | type tmp struct { 230 | p *Properties 231 | } 232 | 233 | pr := tmp{} 234 | b := bytes.NewBuffer([]byte{}) 235 | pr.p.Encode(Reserved, Mods{}, b, 0) 236 | require.Equal(t, []byte{}, b.Bytes()) 237 | } 238 | 239 | func TestEncodeZeroProperties(t *testing.T) { 240 | // [MQTT-2.2.2-1] If there are no properties, this MUST be indicated by including a Property Length of zero.
241 | props := new(Properties) 242 | b := bytes.NewBuffer([]byte{}) 243 | props.Encode(Reserved, Mods{AllowResponseInfo: true}, b, 0) 244 | require.Equal(t, []byte{0x00}, b.Bytes()) 245 | } 246 | 247 | func TestDecodeProperties(t *testing.T) { 248 | b := bytes.NewBuffer(propertiesBytes) 249 | 250 | props := new(Properties) 251 | n, err := props.Decode(Reserved, b) 252 | require.NoError(t, err) 253 | require.Equal(t, 172+2, n) 254 | require.EqualValues(t, propertiesStruct, *props) 255 | } 256 | 257 | func TestDecodePropertiesNil(t *testing.T) { 258 | b := bytes.NewBuffer(propertiesBytes) 259 | 260 | type tmp struct { 261 | p *Properties 262 | } 263 | 264 | pr := tmp{} 265 | n, err := pr.p.Decode(Reserved, b) 266 | require.NoError(t, err) 267 | require.Equal(t, 0, n) 268 | } 269 | 270 | func TestDecodePropertiesBadInitialVBI(t *testing.T) { 271 | b := bytes.NewBuffer([]byte{255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255}) 272 | props := new(Properties) 273 | _, err := props.Decode(Reserved, b) 274 | require.Error(t, err) 275 | require.ErrorIs(t, err, ErrMalformedVariableByteInteger) 276 | } 277 | 278 | func TestDecodePropertiesZeroLengthVBI(t *testing.T) { 279 | b := bytes.NewBuffer([]byte{0}) 280 | props := new(Properties) 281 | _, err := props.Decode(Reserved, b) 282 | require.NoError(t, err) 283 | require.Equal(t, new(Properties), props) 284 | } 285 | 286 | func TestDecodePropertiesBadKeyByte(t *testing.T) { 287 | b := bytes.NewBuffer([]byte{64, 1}) 288 | props := new(Properties) 289 | _, err := props.Decode(Reserved, b) 290 | require.Error(t, err) 291 | require.ErrorIs(t, err, ErrMalformedOffsetByteOutOfRange) 292 | } 293 | 294 | func TestDecodePropertiesInvalidForPacket(t *testing.T) { 295 | b := bytes.NewBuffer([]byte{1, 99}) 296 | props := new(Properties) 297 | _, err := props.Decode(Reserved, b) 298 | require.Error(t, err) 299 | require.ErrorIs(t, err, ErrProtocolViolationUnsupportedProperty) 300 | } 301 | 302 | func
TestDecodePropertiesGeneralFailure(t *testing.T) { 303 | b := bytes.NewBuffer([]byte{10, 11, 202, 212, 19}) 304 | props := new(Properties) 305 | _, err := props.Decode(Reserved, b) 306 | require.Error(t, err) 307 | } 308 | 309 | func TestDecodePropertiesBadSubscriptionID(t *testing.T) { 310 | b := bytes.NewBuffer([]byte{10, 11, 255, 255, 255, 255, 255, 255, 255, 255}) 311 | props := new(Properties) 312 | _, err := props.Decode(Reserved, b) 313 | require.Error(t, err) 314 | } 315 | 316 | func TestDecodePropertiesBadUserProps(t *testing.T) { 317 | b := bytes.NewBuffer([]byte{10, 38, 255, 255, 255, 255, 255, 255, 255, 255}) 318 | props := new(Properties) 319 | _, err := props.Decode(Reserved, b) 320 | require.Error(t, err) 321 | } 322 | 323 | func TestCopyProperties(t *testing.T) { 324 | require.EqualValues(t, propertiesStruct, propertiesStruct.Copy(true)) 325 | } 326 | 327 | func TestCopyPropertiesNoTransfer(t *testing.T) { 328 | pkA := propertiesStruct 329 | pkB := pkA.Copy(false) 330 | 331 | // Properties which should never be transferred from one connection to another 332 | require.Equal(t, uint16(0), pkB.TopicAlias) 333 | } 334 | -------------------------------------------------------------------------------- /packets/tpackets_test.go: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: MIT 2 | // SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co 3 | // SPDX-FileContributor: mochi-co 4 | 5 | package packets 6 | 7 | import ( 8 | "testing" 9 | 10 | "github.com/stretchr/testify/require" 11 | ) 12 | 13 | func encodeTestOK(wanted TPacketCase) bool { 14 | if wanted.RawBytes == nil { 15 | return false 16 | } 17 | if wanted.Group != "" && wanted.Group != "encode" { 18 | return false 19 | } 20 | return true 21 | } 22 | 23 | func decodeTestOK(wanted TPacketCase) bool { 24 | if wanted.Group != "" && wanted.Group != "decode" { 25 | return false 26 | } 27 | return true 28 | } 29 | 30 | func TestTPacketCaseGet(t
*testing.T) { 31 | require.Equal(t, TPacketData[Connect][1], TPacketData[Connect].Get(TConnectMqtt311)) 32 | require.Equal(t, TPacketCase{}, TPacketData[Connect].Get(byte(128))) 33 | } 34 | -------------------------------------------------------------------------------- /system/system.go: -------------------------------------------------------------------------------- 1 | // SPDX-License-Identifier: MIT 2 | // SPDX-FileCopyrightText: 2022 mochi-mqtt, mochi-co 3 | // SPDX-FileContributor: mochi-co 4 | 5 | package system 6 | 7 | import "sync/atomic" 8 | 9 | // Info contains atomic counters and values for various server statistics 10 | // commonly found in $SYS topics (and others). 11 | // based on https://github.com/mqtt/mqtt.org/wiki/SYS-Topics 12 | type Info struct { 13 | Version string `json:"version"` // the current version of the server 14 | Started int64 `json:"started"` // the time the server started in unix seconds 15 | Time int64 `json:"time"` // current time on the server 16 | Uptime int64 `json:"uptime"` // the number of seconds the server has been online 17 | BytesReceived int64 `json:"bytes_received"` // total number of bytes received since the broker started 18 | BytesSent int64 `json:"bytes_sent"` // total number of bytes sent since the broker started 19 | ClientsConnected int64 `json:"clients_connected"` // number of currently connected clients 20 | ClientsDisconnected int64 `json:"clients_disconnected"` // total number of persistent clients (with clean session disabled) that are registered at the broker but are currently disconnected 21 | ClientsMaximum int64 `json:"clients_maximum"` // maximum number of active clients that have been connected 22 | ClientsTotal int64 `json:"clients_total"` // total number of connected and disconnected clients with a persistent session currently connected and registered 23 | MessagesReceived int64 `json:"messages_received"` // total number of publish messages received 24 | MessagesSent int64 `json:"messages_sent"` //
total number of publish messages sent 25 | MessagesDropped int64 `json:"messages_dropped"` // total number of publish messages dropped to slow subscriber 26 | Retained int64 `json:"retained"` // total number of retained messages active on the broker 27 | Inflight int64 `json:"inflight"` // the number of messages currently in-flight 28 | InflightDropped int64 `json:"inflight_dropped"` // the number of inflight messages which were dropped 29 | Subscriptions int64 `json:"subscriptions"` // total number of subscriptions active on the broker 30 | PacketsReceived int64 `json:"packets_received"` // total number of packets of any type received since the broker started 31 | PacketsSent int64 `json:"packets_sent"` // total number of messages of any type sent since the broker started 32 | MemoryAlloc int64 `json:"memory_alloc"` // memory currently allocated 33 | Threads int64 `json:"threads"` // number of active goroutines, named as threads for platform ambiguity 34 | } 35 | 36 | // Clone returns a new *Info copying Version directly and reading every int64 counter with atomic.LoadInt64. 37 | func (i *Info) Clone() *Info { 38 | return &Info{ 39 | Version: i.Version, 40 | Started: atomic.LoadInt64(&i.Started), 41 | Time: atomic.LoadInt64(&i.Time), 42 | Uptime: atomic.LoadInt64(&i.Uptime), 43 | BytesReceived: atomic.LoadInt64(&i.BytesReceived), 44 | BytesSent: atomic.LoadInt64(&i.BytesSent), 45 | ClientsConnected: atomic.LoadInt64(&i.ClientsConnected), 46 | ClientsMaximum: atomic.LoadInt64(&i.ClientsMaximum), 47 | ClientsTotal: atomic.LoadInt64(&i.ClientsTotal), 48 | ClientsDisconnected: atomic.LoadInt64(&i.ClientsDisconnected), 49 | MessagesReceived: atomic.LoadInt64(&i.MessagesReceived), 50 | MessagesSent: atomic.LoadInt64(&i.MessagesSent), 51 | MessagesDropped: atomic.LoadInt64(&i.MessagesDropped), 52 | Retained: atomic.LoadInt64(&i.Retained), 53 | Inflight: atomic.LoadInt64(&i.Inflight), 54 | InflightDropped: atomic.LoadInt64(&i.InflightDropped), 55 | Subscriptions: atomic.LoadInt64(&i.Subscriptions), 56 | PacketsReceived:
atomic.LoadInt64(&i.PacketsReceived), 57 | PacketsSent: atomic.LoadInt64(&i.PacketsSent), 58 | MemoryAlloc: atomic.LoadInt64(&i.MemoryAlloc), 59 | Threads: atomic.LoadInt64(&i.Threads), 60 | } 61 | } 62 | -------------------------------------------------------------------------------- /system/system_test.go: -------------------------------------------------------------------------------- 1 | package system 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/require" 7 | ) 8 | 9 | func TestClone(t *testing.T) { 10 | o := &Info{ 11 | Version: "version", 12 | Started: 1, 13 | Time: 2, 14 | Uptime: 3, 15 | BytesReceived: 4, 16 | BytesSent: 5, 17 | ClientsConnected: 6, 18 | ClientsMaximum: 7, 19 | ClientsTotal: 8, 20 | ClientsDisconnected: 9, 21 | MessagesReceived: 10, 22 | MessagesSent: 11, 23 | MessagesDropped: 20, 24 | Retained: 12, 25 | Inflight: 13, 26 | InflightDropped: 14, 27 | Subscriptions: 15, 28 | PacketsReceived: 16, 29 | PacketsSent: 17, 30 | MemoryAlloc: 18, 31 | Threads: 19, 32 | } 33 | 34 | n := o.Clone() 35 | require.NotSame(t, o, n) // Clone must return a distinct *Info, not the receiver 36 | require.Equal(t, o, n) 37 | } 38 | --------------------------------------------------------------------------------