├── .github └── workflows │ ├── buildAndRelease.yml │ ├── docker-image-build.yml │ ├── docker-image-v1.0.0.yml │ ├── docker-image-v1.1.0.yml │ └── docker-image-v1.1.1.yml ├── .gitignore ├── CONTRIBUTING.md ├── Dockerfile ├── LICENSE ├── Makefile ├── README.md ├── README_ZH.md ├── cmd ├── gmqctl │ ├── command │ │ ├── command.go │ │ └── gen-plugin │ │ │ ├── command.go │ │ │ ├── command_test.go │ │ │ └── tmpl.go │ ├── main.go │ └── version.go └── gmqttd │ ├── certs │ ├── ca.crt │ ├── server.crt │ └── server.key │ ├── command │ ├── reload.go │ └── start.go │ ├── config_unix.go │ ├── config_windows.go │ ├── default_config.yml │ ├── gmqtt.sh │ ├── main.go │ ├── main_pprof.go │ ├── plugins.go │ ├── thingspanel.yml │ └── version.go ├── config ├── api.go ├── api_test.go ├── config.go ├── config_mock.go ├── config_test.go ├── mqtt.go ├── persistence.go ├── testdata │ ├── config.yml │ ├── default_values.yml │ └── default_values_expected.yml └── topic_alias.go ├── go.mod ├── go.sum ├── go_test.sh ├── message.go ├── mock_gen.sh ├── persistence ├── encoding │ ├── binary.go │ └── redis.go ├── memory.go ├── memory_test.go ├── queue │ ├── elem.go │ ├── elem_mock.go │ ├── elem_test.go │ ├── error.go │ ├── mem │ │ └── mem.go │ ├── queue.go │ ├── queue_mock.go │ ├── redis │ │ └── redis.go │ └── test │ │ └── test_suite.go ├── redis.go ├── redis_test.go ├── session │ ├── mem │ │ └── store.go │ ├── redis │ │ └── store.go │ ├── session.go │ ├── session_mock.go │ └── test │ │ └── test_suite.go ├── subscription │ ├── mem │ │ ├── topic_trie.go │ │ ├── topic_trie_test.go │ │ └── trie_db.go │ ├── redis │ │ ├── subscription.go │ │ └── subscription_test.go │ ├── subscription.go │ ├── subscription_mock.go │ └── test │ │ └── test_suite.go └── unack │ ├── mem │ └── mem.go │ ├── redis │ └── redis.go │ ├── test │ └── test_suite.go │ ├── unack.go │ └── unack_mock.go ├── pkg ├── bitmap │ ├── bitmap.go │ └── bitmap_test.go ├── codes │ └── codes.go ├── packets │ ├── auth.go │ ├── auth_test.go │ ├── 
connack.go │ ├── connack_test.go │ ├── connect.go │ ├── connect_test.go │ ├── disconnect.go │ ├── disconnect_test.go │ ├── packets.go │ ├── packets_mock.go │ ├── packets_test.go │ ├── ping_test.go │ ├── pingreq.go │ ├── pingresp.go │ ├── properties.go │ ├── puback.go │ ├── puback_test.go │ ├── pubcomp.go │ ├── pubcomp_test.go │ ├── publish.go │ ├── publish_test.go │ ├── pubrec.go │ ├── pubrec_test.go │ ├── pubrel.go │ ├── pubrel_test.go │ ├── suback.go │ ├── suback_test.go │ ├── subscribe.go │ ├── subscribe_test.go │ ├── unsuback.go │ ├── unsuback_test.go │ ├── unsubscribe.go │ └── unsubscribe_test.go └── pidfile │ ├── pidfile.go │ ├── pidfile_darwin.go │ ├── pidfile_test.go │ ├── pidfile_unix.go │ └── pidfile_windows.go ├── plugin ├── .DS_Store ├── README.md ├── admin │ ├── README.md │ ├── admin.go │ ├── client.go │ ├── client.pb.go │ ├── client.pb.gw.go │ ├── client_grpc.pb.go │ ├── client_test.go │ ├── config.go │ ├── hooks.go │ ├── protos │ │ ├── client.proto │ │ ├── proto_gen.sh │ │ ├── publish.proto │ │ └── subscription.proto │ ├── publish.go │ ├── publish.pb.go │ ├── publish.pb.gw.go │ ├── publish_grpc.pb.go │ ├── publish_test.go │ ├── store.go │ ├── subscription.go │ ├── subscription.pb.go │ ├── subscription.pb.gw.go │ ├── subscription_grpc.pb.go │ ├── subscription_test.go │ ├── swagger │ │ ├── client.swagger.json │ │ ├── publish.swagger.json │ │ └── subscription.swagger.json │ ├── utils.go │ └── utils_test.go ├── auth │ ├── README.md │ ├── account.pb.go │ ├── account.pb.gw.go │ ├── account_grpc.pb.go │ ├── account_grpc.pb_mock.go │ ├── auth.go │ ├── auth_test.go │ ├── config.go │ ├── grpc_handler.go │ ├── grpc_handler_test.go │ ├── hooks.go │ ├── hooks_test.go │ ├── protos │ │ ├── account.proto │ │ └── proto_gen.sh │ ├── swagger │ │ └── account.swagger.json │ └── testdata │ │ ├── gmqtt_password.yml │ │ ├── gmqtt_password_duplicated.yml │ │ └── gmqtt_password_save.yml ├── federation │ ├── README.md │ ├── config.go │ ├── config_test.go │ ├── examples │ │ ├── 
join_node3_config.yml │ │ ├── node1_config.yml │ │ └── node2_config.yml │ ├── federation.go │ ├── federation.pb.go │ ├── federation.pb.gw.go │ ├── federation.pb_mock.go │ ├── federation_grpc.pb.go │ ├── federation_grpc.pb_mock.go │ ├── federation_test.go │ ├── hooks.go │ ├── hooks_test.go │ ├── membership.go │ ├── membership_mock.go │ ├── peer.go │ ├── peer_mock.go │ ├── peer_test.go │ ├── protos │ │ ├── federation.proto │ │ └── proto_gen.sh │ └── swagger │ │ └── federation.swagger.json ├── prometheus │ ├── README.md │ ├── config.go │ ├── hooks.go │ └── prometheus.go └── thingspanel │ ├── config.go │ ├── db.go │ ├── hooks.go │ ├── mqtt.go │ ├── other.go │ ├── thingspanel.go │ └── util │ ├── check_pub_topic.go │ ├── check_pub_topic_test.go │ ├── check_sub_topic.go │ └── check_sub_topic_test.go ├── plugin_generate.go ├── plugin_imports.yml ├── retained ├── interface.go ├── interface_mock.go └── trie │ ├── retain_trie.go │ ├── trie_db.go │ └── trie_db_test.go ├── script └── main.go ├── server ├── api_registrar.go ├── api_registrar_test.go ├── client.go ├── client_mock.go ├── client_test.go ├── hook.go ├── limiter.go ├── limiter_test.go ├── options.go ├── persistence.go ├── persistence_mock.go ├── plugin.go ├── plugin_mock.go ├── publish_service.go ├── queue_notifier.go ├── server.go ├── server_mock.go ├── server_test.go ├── service.go ├── service_mock.go ├── stats.go ├── stats_mock.go ├── testdata │ ├── ca.pem │ ├── extfile.cnf │ ├── openssl.conf │ ├── server-cert.pem │ ├── server-key.pem │ └── test_gen.sh ├── topic_alias.go └── topic_alias_mock.go ├── session.go ├── subscription.go └── topicalias └── fifo ├── fifo.go └── fifo_test.go /.github/workflows/docker-image-build.yml: -------------------------------------------------------------------------------- 1 | name: Docker Image Build 2 | 3 | on: 4 | release: 5 | types: [published] 6 | workflow_dispatch: 7 | 8 | jobs: 9 | build: 10 | runs-on: ubuntu-latest 11 | steps: 12 | - name: 检出代码 13 | uses: actions/checkout@v3 
14 | with: 15 | fetch-depth: 0 16 | 17 | - name: 获取版本号 18 | id: get_version 19 | run: | 20 | VERSION=$(git describe --tags --abbrev=0 2>/dev/null || echo 'latest') 21 | echo "VERSION=$VERSION" >> $GITHUB_ENV 22 | # 添加仓库名小写转换 23 | echo "OWNER_LC=${GITHUB_REPOSITORY_OWNER,,}" >> $GITHUB_ENV 24 | 25 | - name: 登录镜像仓库 26 | run: | 27 | echo "${{ secrets.DOCKERHUB_TOKEN }}" | docker login -u ${{ secrets.DOCKERHUB_USERNAME }} --password-stdin 28 | echo "${{ secrets.GITHUB_TOKEN }}" | docker login ghcr.io -u ${{ github.repository_owner }} --password-stdin 29 | echo "${{ secrets.IMAGE_PASS }}" | docker login registry.cn-hangzhou.aliyuncs.com -u ${{ secrets.IMAGE_USER }} --password-stdin 30 | 31 | - name: 设置 Docker Buildx 32 | uses: docker/setup-buildx-action@v1 33 | 34 | - name: 构建并推送到 GitHub/DockerHub 35 | uses: docker/build-push-action@v4 36 | with: 37 | context: . 38 | push: true 39 | tags: | 40 | thingspanel/thingspanel-gmqtt:${{ env.VERSION }} 41 | ghcr.io/${{ env.OWNER_LC }}/thingspanel-gmqtt:${{ env.VERSION }} 42 | 43 | - name: 推送到阿里云 44 | run: | 45 | docker pull ghcr.io/${{ env.OWNER_LC }}/thingspanel-gmqtt:${{ env.VERSION }} 46 | docker tag ghcr.io/${{ env.OWNER_LC }}/thingspanel-gmqtt:${{ env.VERSION }} registry.cn-hangzhou.aliyuncs.com/thingspanel/thingspanel-gmqtt:${{ env.VERSION }} 47 | docker push registry.cn-hangzhou.aliyuncs.com/thingspanel/thingspanel-gmqtt:${{ env.VERSION }} -------------------------------------------------------------------------------- /.github/workflows/docker-image-v1.0.0.yml: -------------------------------------------------------------------------------- 1 | name: Docker Image CI-1.0.0 2 | 3 | on: 4 | workflow_dispatch: 5 | 6 | jobs: 7 | 8 | build: 9 | 10 | runs-on: ubuntu-latest 11 | 12 | steps: 13 | - uses: actions/checkout@v3 14 | - name: Login to DockerHub 15 | uses: docker/login-action@v1 16 | with: 17 | username: ${{ secrets.DOCKERHUB_USERNAME }} 18 | password: ${{ secrets.DOCKERHUB_TOKEN }} 19 | - name: Build and push 20 | id: 
docker_build 21 | uses: docker/build-push-action@v2 22 | with: 23 | push: true 24 | tags: thingspanel/thingspanel-gmqtt:v1.0.0 25 | 26 | - name: Login to Aliyuncs Docker Hub 27 | uses: docker/login-action@v2.2.0 28 | with: 29 | registry: registry.cn-hangzhou.aliyuncs.com 30 | username: ${{ secrets.IMAGE_USER }} 31 | password: ${{ secrets.IMAGE_PASS }} 32 | logout: false 33 | 34 | - name: Use Skopeo Tools Sync Image to Aliyuncs Docker Hub 35 | run: | 36 | skopeo copy docker://docker.io/thingspanel/thingspanel-gmqtt:v1.0.0 docker://registry.cn-hangzhou.aliyuncs.com/thingspanel/thingspanel-gmqtt:v1.0.0 37 | -------------------------------------------------------------------------------- /.github/workflows/docker-image-v1.1.0.yml: -------------------------------------------------------------------------------- 1 | name: Docker Image CI-1.1.0 2 | 3 | on: 4 | workflow_dispatch: 5 | 6 | jobs: 7 | 8 | build: 9 | 10 | runs-on: ubuntu-latest 11 | 12 | steps: 13 | - uses: actions/checkout@v3 14 | - name: Login to DockerHub 15 | uses: docker/login-action@v1 16 | with: 17 | username: ${{ secrets.DOCKERHUB_USERNAME }} 18 | password: ${{ secrets.DOCKERHUB_TOKEN }} 19 | - name: Build and push 20 | id: docker_build 21 | uses: docker/build-push-action@v2 22 | with: 23 | push: true 24 | tags: thingspanel/thingspanel-gmqtt:v1.1.0 25 | 26 | - name: Login to Aliyuncs Docker Hub 27 | uses: docker/login-action@v2.2.0 28 | with: 29 | registry: registry.cn-hangzhou.aliyuncs.com 30 | username: ${{ secrets.IMAGE_USER }} 31 | password: ${{ secrets.IMAGE_PASS }} 32 | logout: false 33 | 34 | - name: Use Skopeo Tools Sync Image to Aliyuncs Docker Hub 35 | run: | 36 | skopeo copy docker://docker.io/thingspanel/thingspanel-gmqtt:v1.1.0 docker://registry.cn-hangzhou.aliyuncs.com/thingspanel/thingspanel-gmqtt:v1.1.0 37 | -------------------------------------------------------------------------------- /.github/workflows/docker-image-v1.1.1.yml: 
-------------------------------------------------------------------------------- 1 | name: Docker Image CI-1.1.1 2 | 3 | on: 4 | workflow_dispatch: 5 | 6 | jobs: 7 | build: 8 | runs-on: ubuntu-latest 9 | steps: 10 | - uses: actions/checkout@v3 11 | 12 | - name: Set lowercase owner name 13 | run: | 14 | echo "OWNER_LC=${GITHUB_REPOSITORY_OWNER,,}" >>${GITHUB_ENV} 15 | 16 | - name: Login to DockerHub 17 | uses: docker/login-action@v1 18 | with: 19 | username: ${{ secrets.DOCKERHUB_USERNAME }} 20 | password: ${{ secrets.DOCKERHUB_TOKEN }} 21 | 22 | - name: Login to GitHub Container Registry 23 | uses: docker/login-action@v2 24 | with: 25 | registry: ghcr.io 26 | username: ${{ github.repository_owner }} 27 | password: ${{ secrets.GITHUB_TOKEN }} 28 | 29 | - name: Build and push 30 | id: docker_build 31 | uses: docker/build-push-action@v2 32 | with: 33 | push: true 34 | tags: | 35 | thingspanel/thingspanel-gmqtt:v1.1.1 36 | ghcr.io/${{ env.OWNER_LC }}/thingspanel-gmqtt:v1.1.1 -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | /.idea 2 | /vendor -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to Gmqtt 2 | We welcome contributions to Gmqtt of any kind including documentation, plugin, test, bug reports, issues, feature requests, typo fix, etc. 3 | 4 | If you want to write some code, but don't know where to start or what you might want to do, take a look at the [Unplanned](https://github.com/DrmagicE/gmqtt/milestone/2) milestone. 5 | 6 | ## Contributing Code 7 | Feel free submit a pull request. Any pull request must be related to one or more open issues. 
8 | If you are submitting a complex feature, it is recommended to open up a discussion or design proposal on issue track to get feedback before you start. 9 | 10 | ### Code Style 11 | Gmqtt is a Go project, it is recommended to follow the [CodeReviewComments](https://github.com/golang/go/wiki/CodeReviewComments) guidelines. 12 | When you’re ready to create a pull request, be sure to: 13 | * Have unit test for the new code. 14 | * Run [goimport](https://godoc.org/golang.org/x/tools/cmd/goimports). 15 | * Run go test -race ./... 16 | * Build the project with race detection enable (go build -race .), and pass both V3 and V5 test cases (except test_flow_control2 [#68](https://github.com/eclipse/paho.mqtt.testing/issues/68)) in [paho.mqtt.testing](https://github.com/eclipse/paho.mqtt.testing/tree/master/interoperability) 17 | 18 | 19 | ### Testing 20 | Testing is really important for building a robust application. Any new code or changes must come with unit test. 21 | 22 | #### Mocking 23 | Gmqtt uses [GoMock](https://github.com/golang/mock) to generate mock codes. The mock file must begin with the source file name and ends with `_mock.go`. For example, the following command will generate the mock file for `client.go` 24 | ```bash 25 | mockgen -source=server/client.go -destination=/usr/local/gopath/src/github.com/DrmagicE/gmqtt/server/client_mock.go -package=server -self_package=github.com/DrmagicE/gmqtt/server 26 | ``` 27 | #### Assertion 28 | Please use [testify](https://github.com/stretchr/testify) for easy assertion. 29 | 30 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | # syntax=docker/dockerfile:1 2 | FROM golang:alpine AS builder 3 | WORKDIR $GOPATH/src/app 4 | ADD . 
./ 5 | ENV GO111MODULE on 6 | ENV GOPROXY="https://goproxy.io" 7 | WORKDIR $GOPATH/src/app/cmd/gmqttd 8 | RUN go build 9 | 10 | FROM alpine:3.12 11 | WORKDIR /gmqttd 12 | # RUN apk update && apk add --no-cache tzdata 13 | COPY --from=builder /go/src/app/cmd/gmqttd . 14 | EXPOSE 1883 8883 8082 8083 8084 15 | RUN chmod +x gmqttd 16 | RUN pwd 17 | RUN ls -lrt 18 | ENTRYPOINT ["./gmqttd", "start", "-c", "/gmqttd/default_config.yml"] 19 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 DrmagicE 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
-------------------------------------------------------------------------------- /cmd/gmqctl/command/command.go: -------------------------------------------------------------------------------- 1 | package command 2 | 3 | import ( 4 | "github.com/spf13/cobra" 5 | 6 | gen_plugin "github.com/DrmagicE/gmqtt/cmd/gmqctl/command/gen-plugin" 7 | ) 8 | 9 | // Gen is the command for code generator. 10 | var Gen = &cobra.Command{ 11 | Use: "gen", 12 | Short: "Code generator", 13 | } 14 | 15 | func init() { 16 | Gen.AddCommand(gen_plugin.Command) 17 | } 18 | -------------------------------------------------------------------------------- /cmd/gmqctl/command/gen-plugin/command_test.go: -------------------------------------------------------------------------------- 1 | package gen_plugin 2 | 3 | import ( 4 | "os" 5 | "testing" 6 | 7 | "github.com/stretchr/testify/assert" 8 | ) 9 | 10 | func TestRun(t *testing.T) { 11 | defer os.RemoveAll("./testdata/") 12 | a := assert.New(t) 13 | var err error 14 | os.Mkdir("./testdata/", 0777) 15 | name = "test" 16 | hooksStr = "OnBasicAuth" 17 | config = true 18 | output = "./testdata" 19 | a.Nil(run(nil, nil)) 20 | _, err = os.Stat("./testdata/test.go") 21 | a.Nil(err) 22 | _, err = os.Stat("./testdata/config.go") 23 | a.Nil(err) 24 | _, err = os.Stat("./testdata/hooks.go") 25 | a.Nil(err) 26 | a.NotNil(run(nil, nil)) 27 | } 28 | 29 | func TestRunEmptyHooks(t *testing.T) { 30 | defer os.RemoveAll("./testdata/") 31 | a := assert.New(t) 32 | var err error 33 | os.Mkdir("./testdata/", 0777) 34 | name = "test" 35 | config = true 36 | output = "./testdata" 37 | a.Nil(run(nil, nil)) 38 | _, err = os.Stat("./testdata/test.go") 39 | a.Nil(err) 40 | _, err = os.Stat("./testdata/config.go") 41 | a.Nil(err) 42 | _, err = os.Stat("./testdata/hooks.go") 43 | a.Nil(err) 44 | } 45 | 46 | func TestRunNoConfig(t *testing.T) { 47 | a := assert.New(t) 48 | var err error 49 | defer os.RemoveAll("./testdata/") 50 | os.Mkdir("./testdata", 0777) 51 | name = 
"test" 52 | hooksStr = "OnBasicAuth" 53 | config = false 54 | output = "./testdata" 55 | run(nil, nil) 56 | _, err = os.Stat("./testdata/test.go") 57 | a.Nil(err) 58 | _, err = os.Stat("./testdata/config.go") 59 | a.True(os.IsNotExist(err)) 60 | _, err = os.Stat("./testdata/hooks.go") 61 | a.Nil(err) 62 | } 63 | 64 | func TestValidateHookName(t *testing.T) { 65 | var tt = []struct { 66 | name string 67 | hooks string 68 | rs []string 69 | valid bool 70 | }{ 71 | { 72 | name: "valid", 73 | hooks: "OnSubscribe, OnSubscribed", 74 | rs: []string{"OnSubscribe", "OnSubscribed"}, 75 | valid: true, 76 | }, 77 | { 78 | name: "invalid", 79 | hooks: "OnAbc,OnDEF", 80 | rs: nil, 81 | valid: false, 82 | }, 83 | } 84 | for _, v := range tt { 85 | t.Run(v.name, func(t *testing.T) { 86 | a := assert.New(t) 87 | got, err := ValidateHooks(v.hooks) 88 | if v.valid { 89 | a.Nil(err) 90 | } else { 91 | a.NotNil(err) 92 | } 93 | a.Equal(v.rs, got) 94 | }) 95 | } 96 | 97 | } 98 | -------------------------------------------------------------------------------- /cmd/gmqctl/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "fmt" 5 | "os" 6 | 7 | "github.com/spf13/cobra" 8 | 9 | "github.com/DrmagicE/gmqtt/cmd/gmqctl/command" 10 | ) 11 | 12 | var ( 13 | rootCmd = &cobra.Command{ 14 | Use: "gmqctl", 15 | Long: "gmqctl is a command line tool for gmqtt", 16 | Version: Version, 17 | } 18 | ) 19 | 20 | func init() { 21 | rootCmd.AddCommand(command.Gen) 22 | } 23 | 24 | func must(err error) { 25 | if err != nil { 26 | fmt.Println(err) 27 | os.Exit(1) 28 | } 29 | } 30 | 31 | func main() { 32 | must(rootCmd.Execute()) 33 | } 34 | -------------------------------------------------------------------------------- /cmd/gmqctl/version.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | // The git commit that was compiled. This will be filled in by the compiler. 
4 | var GitCommit string 5 | 6 | // The main version number that is being run at the moment. 7 | const Version = "" 8 | -------------------------------------------------------------------------------- /cmd/gmqttd/certs/ca.crt: -------------------------------------------------------------------------------- 1 | -----BEGIN CERTIFICATE----- 2 | MIIDtTCCAp2gAwIBAgIUaL6CZOK/XH/5Ec4b7UdA4AIGipkwDQYJKoZIhvcNAQEL 3 | BQAwajELMAkGA1UEBhMCQ04xEjAQBgNVBAgMCUd1YW5nRG9uZzERMA8GA1UEBwwI 4 | U2hlblpoZW4xEDAOBgNVBAoMB0NvbXBhbnkxEDAOBgNVBAsMB0dhdGV3YXkxEDAO 5 | BgNVBAMMB1Jvb3QgQ0EwHhcNMjMwNTI1MDIyNDI2WhcNMzMwNTIyMDIyNDI2WjBq 6 | MQswCQYDVQQGEwJDTjESMBAGA1UECAwJR3VhbmdEb25nMREwDwYDVQQHDAhTaGVu 7 | WmhlbjEQMA4GA1UECgwHQ29tcGFueTEQMA4GA1UECwwHR2F0ZXdheTEQMA4GA1UE 8 | AwwHUm9vdCBDQTCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBAMJBnDMF 9 | +6I/+9UQzpr6qaaYRwID1aRWm9AzT1MAIxKyQDxfIIIeOTKyyHrKUJ2u9rACuc7u 10 | b6Yd/J11O7Ptf4vU0BZrPd51etFA13aajqQhQ3PMfETUEsovnYB91E8Sc4YZHunL 11 | yeQlwo0hyySOsXbM2MvPnpbCVeSGQ5YW7QAPvCFmiwQpldazIXT8ArMvDIdXfi2V 12 | PtlOGZeYnsRRwimWp675edr+IkdbBTEctWyk7t7AVgq14BxidvB0+3s+O0ufdovx 13 | e8MbxOmVmZj7rUVd/GtY0pZbhqnZEfkxHEmugFRZf9soxRkS4EReiTkg7za4VnFn 14 | nDJuZy/PchacKj8CAwEAAaNTMFEwHQYDVR0OBBYEFDGz6XMgaWnrsLiHNh4l3bJd 15 | iNQIMB8GA1UdIwQYMBaAFDGz6XMgaWnrsLiHNh4l3bJdiNQIMA8GA1UdEwEB/wQF 16 | MAMBAf8wDQYJKoZIhvcNAQELBQADggEBALLnWYs5VqnfiJv1kWdYrMeguy6M5T4v 17 | AS6i0klO6P3/s5ER92ql5MxIEA069cZ9OTORcd3istBlY3DrMh/7NQnTvgmqBUYK 18 | FygzZEsUUngyb5TH6HT7NTSK29d+KXqmiUH/ASIWwXu6VtUflxbjDcWqU45ipk+X 19 | r4eucSvTloCS69IWQ0y7g9Hh63qJT+2x8lbFFB9f6CdsvAACwONTU+m2NjUa8xxV 20 | 1K0d6rJYOouMg+oA+hVgfXkxX7csq7tPYE3M5rXBXIhcNdltwCcZd/vpQTbN3i2E 21 | lWfScTmbcLpfJ5byZQzrOndLsXXVzc99YpQYWW5TpMg1VeC36jb4GZs= 22 | -----END CERTIFICATE----- 23 | -------------------------------------------------------------------------------- /cmd/gmqttd/certs/server.crt: -------------------------------------------------------------------------------- 1 | -----BEGIN CERTIFICATE----- 2 | 
MIIDZTCCAk0CFAqBc91yHPZFhpthdYyKM8ZQGsOjMA0GCSqGSIb3DQEBCwUAMGox 3 | CzAJBgNVBAYTAkNOMRIwEAYDVQQIDAlHdWFuZ0RvbmcxETAPBgNVBAcMCFNoZW5a 4 | aGVuMRAwDgYDVQQKDAdDb21wYW55MRAwDgYDVQQLDAdHYXRld2F5MRAwDgYDVQQD 5 | DAdSb290IENBMB4XDTIzMDUyNTAyMjUxOFoXDTMzMDUyMjAyMjUxOFowdDELMAkG 6 | A1UEBhMCQ04xEjAQBgNVBAgMCUd1YW5nRG9uZzERMA8GA1UEBwwIU2hlblpoZW4x 7 | DzANBgNVBAoMBlNlcnZlcjEQMA4GA1UECwwHR2F0ZXdheTEbMBkGA1UEAwwSZGV2 8 | LnRoaW5nc3BhbmVsLmNuMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEA 9 | 329CnPH2YnFZc7iNwC8vV4ClYixu+Yun4oupJ2v9B0xDqCdNb/EL2vC6piYd89BC 10 | Sqjl0ERncAK3XDPyJLmEnonsKcJ1rMYIuF7zf2GSWDZ7Ybj1zmChATAuvyobSEVx 11 | 9gskrxoi7w9pWP8yBtWN/RubBHd2lzm6hgmUeFOI4jFwNCnRH4/thmtyB0w56KYy 12 | N4/KJ6JhlZGneQ2gsVdbH8iFnw8n3zPIF7F79IFRa04s8dUvSS7CdR3QQeInOboZ 13 | KYPcjr498i4S9MmWAkjSsd15tcv4mQdjk5wVa44VPXzeoRNba+gyFKHvPrhxdKtY 14 | OGJoTeAc0RS+LBPXexjJXQIDAQABMA0GCSqGSIb3DQEBCwUAA4IBAQAaLqfsL5B9 15 | lGuJPFXya9UUPUg1ZOBR0Z0itOOqFhtYYWsxBIdAfXAq+rbX2BrfilXHClVeyLpw 16 | Rc4vXBxgG1LoIKpGY8OvNUDpoDtZgtqcveRmFYeMsmMsHvZTRNg8ftz1EsATN3Rb 17 | OWXTfuZ6kouwF94lZ8B34h08cNhQt60XwyL965+0I2biiMQluZ4m8vJkrvMSZN1R 18 | yq5wAqw3it0Z2S5TjNOSNw8fKUeJnp3MFrtgMcBWzHqZPSpJRG3Coqic2Y22Qd2I 19 | gc5/cys6wfBo26MzB9JniZxlERGMaHd3H26pZeA6CnOITVe3cbydZCKPBl10XdAb 20 | 21K1IcYcMd1A 21 | -----END CERTIFICATE----- 22 | -------------------------------------------------------------------------------- /cmd/gmqttd/certs/server.key: -------------------------------------------------------------------------------- 1 | -----BEGIN PRIVATE KEY----- 2 | MIIEvgIBADANBgkqhkiG9w0BAQEFAASCBKgwggSkAgEAAoIBAQDfb0Kc8fZicVlz 3 | uI3ALy9XgKViLG75i6fii6kna/0HTEOoJ01v8Qva8LqmJh3z0EJKqOXQRGdwArdc 4 | M/IkuYSeiewpwnWsxgi4XvN/YZJYNnthuPXOYKEBMC6/KhtIRXH2CySvGiLvD2lY 5 | /zIG1Y39G5sEd3aXObqGCZR4U4jiMXA0KdEfj+2Ga3IHTDnopjI3j8onomGVkad5 6 | DaCxV1sfyIWfDyffM8gXsXv0gVFrTizx1S9JLsJ1HdBB4ic5uhkpg9yOvj3yLhL0 7 | yZYCSNKx3Xm1y/iZB2OTnBVrjhU9fN6hE1tr6DIUoe8+uHF0q1g4YmhN4BzRFL4s 8 | 
E9d7GMldAgMBAAECggEACHhq0ZFYslWyuvDZqtn/FPWjD2w4zswNEskRohV1c7Pf 9 | 8r1+sYo1VVj/8nGKfCY+hR8PC0y0kSoBMoUjnmHMFciLw++Dz4d4aCjtWTxlNGPo 10 | XIWLCAZgxC9D7mpVu3Eqh1XXz62ReneemkmeZ2TsK5bC0zIGNGhzWkZ9suwTyeNc 11 | v2N67H/boYP9/t2CSU6GUJpSSzB2Jns3WzPlayQwUp548zIQleBQNcWF2Vn3pqWX 12 | 4bphiCx6xJ/Tckkit2NZZ3kLItW1NUdrz6cj2QmOathowx69zgF4NBV1xQqcJF7i 13 | mzlfmmwznmhFquarRkqeClnAktY4KKRYCA7HXK5P3wKBgQD+Er/aBf3k+wwSiINY 14 | Qb2HiDOjPaypaGm9lD0W2prwMUDfwMINXwkYaypG++xNpBde6bQ7cp2KgXgx+pDJ 15 | /AhtgwHIRz/zQdC5r+R1ay/wp33pyEWZTFvAYB9xg8kuc/va1WulHEG+1p6KfoyP 16 | TVUrtzQFT72xEKoKIKHm+bTEpwKBgQDhIQeiTQ+SoScPp6OeBqbCyG+V4riH6Jjp 17 | l9GOa5CG4AXCpoEOwqMhNshDdf/Qc/wsOkgUmU86afo0tCt/Svqx5xE9jxHj9ecN 18 | A0OqVbqbauMUiSopBinj2j0HY6JdVdcN/DBBpAHxw0DzmJ8QD8B5yXv7lYUUv+a+ 19 | g2be/jROWwKBgQC+V7mXUunVRCbVM6SC2C1vfiCBaVETUX/2YTorBvcQfzXE65n1 20 | fn9H5fE0YMO1nvtLRfaamtFf1IMBnmAekkyWDpGlQ4uraGFA239iYDz4I+L24+0Q 21 | Xd5XDyw/VKXBwW6rkTwl5Dd1C0CXLRuMuDjYmVXFrOnF32AkWjIw4l4E/QKBgEL5 22 | mZTOSii8Kqu8mq+DmQ7vpEq6BV3hc8RityQgmgGWGgCbml9yxic2bgOr0iwIpWfe 23 | +tyt82UUbCxLwXkALG7KqFVg/9iKqm8znmjJUle0R9QvLkzAGaxAm9Fb8czEodL2 24 | SMDucumixeryZ7fWh9NzfqANDmdq49Gfs/X5OERrAoGBAMVeMwImOE6fG2PVCCbQ 25 | cLb+uMzm51ywvihQ5Xjw1pZ3IHs0LDCP7u0FBnFy22HX4vNxvu9YyJE5uydsyNzA 26 | B15V9KN9Mq62dRfhYOq5s5ebBOou0o9dEaBoMppxUA7rGEFKozxXQ3wOBbEq2InZ 27 | u4gXXqqpxuMwjjOX2u9Gsvx7 28 | -----END PRIVATE KEY----- 29 | -------------------------------------------------------------------------------- /cmd/gmqttd/command/reload.go: -------------------------------------------------------------------------------- 1 | package command 2 | 3 | import ( 4 | "io/ioutil" 5 | "os" 6 | "strconv" 7 | "syscall" 8 | 9 | "github.com/pkg/errors" 10 | "github.com/spf13/cobra" 11 | 12 | "github.com/DrmagicE/gmqtt/config" 13 | ) 14 | 15 | // NewReloadCommand creates a *cobra.Command object for reload command. 
16 | func NewReloadCommand() *cobra.Command { 17 | cmd := &cobra.Command{ 18 | Use: "reload", 19 | Short: "Reload gmqtt broker", 20 | Run: func(cmd *cobra.Command, args []string) { 21 | var c config.Config 22 | var err error 23 | c, err = config.ParseConfig(ConfigFile) 24 | if os.IsNotExist(err) { 25 | c = config.DefaultConfig() 26 | } else { 27 | must(err) 28 | } 29 | b, err := ioutil.ReadFile(c.PidFile) 30 | must(errors.Wrap(err, "read pid file error")) 31 | pid, err := strconv.Atoi(string(b)) 32 | must(errors.Wrap(err, "read pid file error")) 33 | p, err := os.FindProcess(pid) 34 | must(errors.Wrap(err, "find process error")) 35 | err = p.Signal(syscall.SIGHUP) 36 | must(err) 37 | }, 38 | } 39 | return cmd 40 | } 41 | -------------------------------------------------------------------------------- /cmd/gmqttd/config_unix.go: -------------------------------------------------------------------------------- 1 | //go:build !windows 2 | // +build !windows 3 | 4 | package main 5 | 6 | var ( 7 | DefaultConfigDir = "/gmqttd" 8 | ) 9 | 10 | func getDefaultConfigDir() (string, error) { 11 | return DefaultConfigDir, nil 12 | } 13 | -------------------------------------------------------------------------------- /cmd/gmqttd/config_windows.go: -------------------------------------------------------------------------------- 1 | // +build windows 2 | 3 | package main 4 | 5 | import ( 6 | "os" 7 | "path/filepath" 8 | ) 9 | 10 | func getDefaultConfigDir() (string, error) { 11 | return filepath.Join(os.Getenv("programdata"), "gmqtt"), nil 12 | } 13 | -------------------------------------------------------------------------------- /cmd/gmqttd/gmqtt.sh: -------------------------------------------------------------------------------- 1 | go run . 
start -c default_config.yml -------------------------------------------------------------------------------- /cmd/gmqttd/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "fmt" 5 | "os" 6 | "path" 7 | 8 | "github.com/spf13/cobra" 9 | 10 | "github.com/DrmagicE/gmqtt/cmd/gmqttd/command" 11 | _ "github.com/DrmagicE/gmqtt/persistence" 12 | _ "github.com/DrmagicE/gmqtt/plugin/prometheus" 13 | _ "github.com/DrmagicE/gmqtt/topicalias/fifo" 14 | ) 15 | 16 | var ( 17 | rootCmd = &cobra.Command{ 18 | Use: "gmqttd", 19 | Long: "Gmqtt is a MQTT broker that fully implements MQTT V5.0 and V3.1.1 protocol", 20 | Version: Version, 21 | } 22 | ) 23 | 24 | func must(err error) { 25 | if err != nil { 26 | fmt.Fprint(os.Stderr, err.Error()) 27 | os.Exit(1) 28 | } 29 | } 30 | 31 | func init() { 32 | configDir, err := getDefaultConfigDir() 33 | must(err) 34 | command.ConfigFile = path.Join(configDir, "gmqttd.yml") 35 | rootCmd.PersistentFlags().StringVarP(&command.ConfigFile, "config", "c", command.ConfigFile, "The configuration file path") 36 | rootCmd.AddCommand(command.NewStartCmd()) 37 | //rootCmd.AddCommand(command.NewReloadCommand()) 38 | } 39 | 40 | func main() { 41 | if err := rootCmd.Execute(); err != nil { 42 | fmt.Fprint(os.Stderr, err.Error()) 43 | os.Exit(1) 44 | } 45 | } 46 | -------------------------------------------------------------------------------- /cmd/gmqttd/main_pprof.go: -------------------------------------------------------------------------------- 1 | // +build pprof 2 | 3 | package main 4 | 5 | import ( 6 | _ "net/http/pprof" 7 | ) 8 | 9 | func init() { 10 | enablePprof = true 11 | rootCmd.PersistentFlags().StringVar(&pprofAddr, "pprof_addr", pprofAddr, "The listening address for the pprof http server") 12 | } 13 | -------------------------------------------------------------------------------- /cmd/gmqttd/plugins.go: 
-------------------------------------------------------------------------------- 1 | //go:generate sh -c "cd ../../ && go run plugin_generate.go" 2 | // generated by plugin_generate.go; DO NOT EDIT 3 | 4 | package main 5 | 6 | import ( 7 | _ "github.com/DrmagicE/gmqtt/plugin/admin" 8 | _ "github.com/DrmagicE/gmqtt/plugin/auth" 9 | _ "github.com/DrmagicE/gmqtt/plugin/federation" 10 | _ "github.com/DrmagicE/gmqtt/plugin/prometheus" 11 | _ "github.com/DrmagicE/gmqtt/plugin/thingspanel" 12 | ) 13 | -------------------------------------------------------------------------------- /cmd/gmqttd/thingspanel.yml: -------------------------------------------------------------------------------- 1 | db: 2 | redis: 3 | # redis 连接字符串 4 | conn: 127.0.0.1:6379 5 | # redis 数据库号 6 | db_num: 1 7 | # redis 密码 8 | password: "redis" 9 | psql: 10 | # psqladdr: "127.0.0.1" 11 | # psqladdr: "47.251.45.205" 12 | # psqlport: 5432 13 | # psqldb: irrigate 14 | # psqluser: postgres 15 | # psqlpass: postgresThingsPanel2022 16 | 17 | #社区版 18 | # psqladdr: "47.92.253.145" 19 | psqladdr: "127.0.0.1" 20 | psqlport: 5432 # 默认5432 21 | psqldb: ThingsPanel 22 | psqluser: postgres 23 | psqlpass: postgresThingsPanel 24 | 25 | mqtt: 26 | # root用户的密码 27 | broker: localhost:1883 28 | password: "root" 29 | plugin_password: "plugin" 30 | -------------------------------------------------------------------------------- /cmd/gmqttd/version.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | // The git commit that was compiled. This will be filled in by the compiler. 4 | var GitCommit string 5 | 6 | // The main version number that is being run at the moment. 
7 | const Version = ""
8 | 
--------------------------------------------------------------------------------
/config/api.go:
--------------------------------------------------------------------------------
1 | package config
2 | 
3 | import (
4 | 	"fmt"
5 | 	"net"
6 | 	"strings"
7 | )
8 | 
9 | // API is the configuration for API server.
10 | // The API server uses gRPC-gateway to provide both gRPC and HTTP endpoints.
11 | type API struct {
12 | 	// GRPC is the gRPC endpoint configuration.
13 | 	GRPC []*Endpoint `yaml:"grpc"`
14 | 	// HTTP is the HTTP endpoint configuration.
15 | 	HTTP []*Endpoint `yaml:"http"`
16 | }
17 | 
18 | // Endpoint represents a gRPC or HTTP server endpoint.
19 | type Endpoint struct {
20 | 	// Address is the bind address of the endpoint.
21 | 	// Format: [tcp://][<host>]:<port> or unix://<socket path>
22 | 	// e.g :
23 | 	//  * unix:///var/run/gmqttd.sock
24 | 	//  * tcp://127.0.0.1:8080
25 | 	//  * :8081 (equal to tcp://:8081)
26 | 	Address string `yaml:"address"`
27 | 	// Map maps the HTTP endpoint to gRPC endpoint.
28 | 	// Must be set if the endpoint is representing a HTTP endpoint.
29 | 	Map string `yaml:"map"`
30 | 	// TLS is the tls configuration.
31 | 	TLS *TLSOptions `yaml:"tls"`
32 | }
33 | 
34 | // DefaultAPI is the default value of API: no gRPC or HTTP endpoints.
35 | var DefaultAPI API
36 | 
37 | // validateAddress checks that address is a non-empty endpoint address in the
38 | // format documented on Endpoint.Address. fieldName is only used in errors.
39 | //
40 | // An address without a "<scheme>://" prefix is treated as a TCP address.
41 | // Only the "tcp" and "unix" schemes are accepted: a tcp address must be a
42 | // valid host:port pair and a unix address must carry a non-empty socket path.
43 | func (a API) validateAddress(address string, fieldName string) error {
44 | 	if address == "" {
45 | 		return fmt.Errorf("%s cannot be empty", fieldName)
46 | 	}
47 | 	// SplitN never returns an empty slice for a non-empty input, so the only
48 | 	// question is whether an explicit scheme was given.
49 | 	scheme, target := "tcp", address
50 | 	if parts := strings.SplitN(address, "://", 2); len(parts) == 2 {
51 | 		scheme, target = parts[0], parts[1]
52 | 	}
53 | 	switch scheme {
54 | 	case "tcp":
55 | 		if _, _, err := net.SplitHostPort(target); err != nil {
56 | 			return fmt.Errorf("invalid %s: %s", fieldName, err.Error())
57 | 		}
58 | 	case "unix":
59 | 		if target == "" {
60 | 			return fmt.Errorf("invalid %s: empty unix socket path", fieldName)
61 | 		}
62 | 	default:
63 | 		return fmt.Errorf("invalid %s schema: %s", fieldName, scheme)
64 | 	}
65 | 	return nil
66 | }
67 | 
68 | // Validate checks every configured endpoint. Each gRPC endpoint needs a valid
69 | // Address; each HTTP endpoint needs a valid Address and a valid Map (the gRPC
70 | // endpoint it is mapped to). The first invalid value found is reported.
71 | func (a API) Validate() error {
72 | 	for _, v := range a.GRPC {
73 | 		if err := a.validateAddress(v.Address, "endpoint"); err != nil {
74 | 			return err
75 | 		}
76 | 	}
77 | 	for _, v := range a.HTTP {
78 | 		if err := a.validateAddress(v.Address, "endpoint"); err != nil {
79 | 			return err
80 | 		}
81 | 		if err := a.validateAddress(v.Map, "map"); err != nil {
82 | 			return err
83 | 		}
84 | 	}
85 | 	return nil
86 | }
87 | 
--------------------------------------------------------------------------------
/config/api_test.go:
--------------------------------------------------------------------------------
1 | package config
2 | 
3 | import (
4 | 	"testing"
5 | 
6 | 	"github.com/stretchr/testify/assert"
7 | )
8 | 
9 | func TestAPI_Validate(t *testing.T) {
10 | 	a := assert.New(t)
11 | 
12 | 	tt := []struct {
13 | 		cfg   API
14 | 		valid bool
15 | 	}{
16 | 		{
17 | 			cfg: API{
18 | 				GRPC: []*Endpoint{
19 | 					{
20 | 						Address: "udp://127.0.0.1",
21 | 					},
22 | 				},
23 | 				HTTP: []*Endpoint{
24 | 					{},
25 | 				},
26 | 			},
27 | 			valid: false,
28 | 		},
29 | 		{
30 | 			cfg: API{
31 | 				GRPC: []*Endpoint{
32 | 					{
33 | 						Address: "tcp://127.0.0.1:1234",
34 | 					},
35 | 				},
36 | 				HTTP: []*Endpoint{
37 | 					{
38 | 						Address: "udp://127.0.0.1",
39 | 					},
40 | 				},
41 | 			},
42 | 			valid: false,
43 | 		},
44 | 		{
45 | 			cfg: API{
46 | 				GRPC: []*Endpoint{
47 | 					{
48 | 						Address: "tcp://127.0.0.1:1234",
49 | 					},
50 | 				},
51 | 			},
52 | 			valid: true,
53 | 		},
54 | 		{
55 | 			cfg: API{
56 | 				GRPC: []*Endpoint{
57 | 					{
58 | 						Address: "tcp://127.0.0.1:1234",
59 | 					},
60 | 				},
61 | 				HTTP: []*Endpoint{
62 | 					{
63 | 						Address: "tcp://127.0.0.1:1235",
64 | 					},
65 | 				},
66 | 			},
67 | 			valid: false,
68 | 		},
69 | 		{
70 | 			cfg: API{
71 | 				GRPC: []*Endpoint{
72 | 					{
73 | 						Address: "unix:///var/run/gmqttd.sock",
74 | 					},
75 | 				},
76 | 				HTTP: []*Endpoint{
77 | 					{
78 | 						Address: "tcp://127.0.0.1:1235",
79 | 						Map:     "unix:///var/run/gmqttd.sock",
80 | 					},
81 | 				},
82 | 			},
83 | 			valid: true,
84 | 		},
85 | 		{
86 | 			// A unix endpoint must name a socket path.
87 | 			cfg: API{
88 | 				GRPC: []*Endpoint{
89 | 					{
90 | 						Address: "unix://",
91 | 					},
92 | 				},
93 | 			},
94 | 			valid: false,
95 | 		},
96 | 	}
97 | 	for _, v := range tt {
98 | 		err := v.cfg.Validate()
99 | 		if v.valid {
100 | 			a.NoError(err)
101 | 		} else {
102 | 			a.Error(err)
103 | 		}
104 | 	}
105 | 
106 | }
107 | 
--------------------------------------------------------------------------------
/config/config_mock.go:
--------------------------------------------------------------------------------
1 | // Code generated by MockGen. DO NOT EDIT.
2 | // Source: config/config.go
3 | 
4 | // Package config is a generated GoMock package.
5 | package config
6 | 
7 | import (
8 | 	gomock "github.com/golang/mock/gomock"
9 | 	reflect "reflect"
10 | )
11 | 
12 | // MockConfiguration is a mock of Configuration interface
13 | type MockConfiguration struct {
14 | 	ctrl     *gomock.Controller
15 | 	recorder *MockConfigurationMockRecorder
16 | }
17 | 
18 | // MockConfigurationMockRecorder is the mock recorder for MockConfiguration
19 | type MockConfigurationMockRecorder struct {
20 | 	mock *MockConfiguration
21 | }
22 | 
23 | // NewMockConfiguration creates a new mock instance
24 | func NewMockConfiguration(ctrl *gomock.Controller) *MockConfiguration {
25 | 	mock := &MockConfiguration{ctrl: ctrl}
26 | 	mock.recorder = &MockConfigurationMockRecorder{mock}
27 | 	return mock
28 | }
29 | 
30 | // EXPECT returns an object that allows the caller to indicate expected use
31 | func (m *MockConfiguration) EXPECT() *MockConfigurationMockRecorder {
32 | 	return m.recorder
33 | }
34 | 
35 | // Validate mocks base method
36 | func (m *MockConfiguration) Validate() error {
37 | 	m.ctrl.T.Helper()
38 | 	ret := m.ctrl.Call(m, "Validate")
39 | 	ret0, _ := ret[0].(error)
40 | 	return ret0
41 | }
42 | 
43 | // Validate indicates an expected call of Validate
44 | func (mr *MockConfigurationMockRecorder) Validate() *gomock.Call {
45 | 	mr.mock.ctrl.T.Helper()
46 | 	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Validate", reflect.TypeOf((*MockConfiguration)(nil).Validate))
47 | }
48 | 
49 | // UnmarshalYAML mocks base method
50 | func (m *MockConfiguration) UnmarshalYAML(unmarshal func(interface{}) error) error {
51 | 	m.ctrl.T.Helper()
52 | 	ret := m.ctrl.Call(m, "UnmarshalYAML", unmarshal)
53 | 	ret0, _ := ret[0].(error)
54 | 	return ret0
55 | }
56 | 
57 | // UnmarshalYAML indicates an expected call of UnmarshalYAML
58 | func (mr *MockConfigurationMockRecorder) UnmarshalYAML(unmarshal interface{}) *gomock.Call {
59 | 	mr.mock.ctrl.T.Helper()
60 | 	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UnmarshalYAML", reflect.TypeOf((*MockConfiguration)(nil).UnmarshalYAML), unmarshal)
61 | }
62 | 
--------------------------------------------------------------------------------
/config/config_test.go:
--------------------------------------------------------------------------------
1 | package config
2 | 
3 | import (
4 | 	"testing"
5 | 
6 | 	"github.com/stretchr/testify/assert"
7 | )
8 | 
9 | func TestParseConfig(t *testing.T) {
10 | 	var tt = []struct {
11 | 		caseName string
12 | 		fileName string
13 | 		hasErr   bool
14 | 		expected Config
15 | 	}{
16 | 		{
17 | 			caseName: "defaultConfig",
18 | 			fileName: "",
19 | 			hasErr:   false,
20 | 			expected: DefaultConfig(),
21 | 		},
22 | 	}
23 | 
24 | 	for _, v := range tt {
25 | 		t.Run(v.caseName, func(t *testing.T) {
26 | 			a := assert.New(t)
27 | 			c, err := ParseConfig(v.fileName)
28 | 			if v.hasErr {
29 | 				a.NotNil(err)
30 | 			} else {
31 | 				a.Nil(err)
32 | 			}
33 | 			a.Equal(v.expected, c)
34 | 		})
35 | 	}
36 | }
37 | 
--------------------------------------------------------------------------------
/config/persistence.go:
--------------------------------------------------------------------------------
1 | package config
2 | 
3 | import (
4 | 	"net"
5 | 	"time"
6 | 
7 | 	"github.com/pkg/errors"
8 | )
9 | 
10 | // PersistenceType is the name of a persistence backend.
11 | type PersistenceType = string
12 | 
13 | const (
14 | 	// PersistenceTypeMemory selects the in-memory persistence backend.
15 | 	PersistenceTypeMemory PersistenceType = "memory"
16 | 	// PersistenceTypeRedis selects the redis persistence backend.
17 | 	PersistenceTypeRedis PersistenceType = "redis"
18 | )
19 | 
20 | var (
21 | 	defaultMaxActive = uint(0)
22 | 	defaultMaxIdle   = uint(1000)
23 | 	// DefaultPersistenceConfig is the default value of Persistence
24 | 	DefaultPersistenceConfig = Persistence{
25 | 		Type: PersistenceTypeMemory,
26 | 		Redis: RedisPersistence{
27 | 			Addr:        "127.0.0.1:6379",
28 | 			Password:    "",
29 | 			Database:    0,
30 | 			MaxIdle:     &defaultMaxIdle,
31 | 			MaxActive:   &defaultMaxActive,
32 | 			IdleTimeout: 240 * time.Second,
33 | 		},
34 | 	}
35 | )
36 | 
37 | // Persistence is the config of backend persistence.
38 | type Persistence struct {
39 | 	// Type is the persistence type.
40 | 	// If empty, use "memory" as default.
41 | 	Type PersistenceType `yaml:"type"`
42 | 	// Redis is the redis configuration and must be set when Type == "redis".
43 | 	Redis RedisPersistence `yaml:"redis"`
44 | }
45 | 
46 | // RedisPersistence is the configuration of redis persistence.
47 | type RedisPersistence struct {
48 | 	// Addr is the redis server address.
49 | 	// If empty, use "127.0.0.1:6379" as default.
50 | 	Addr string `yaml:"addr"`
51 | 	// Password is the redis password.
52 | 	Password string `yaml:"password"`
53 | 	// Database is the number of the redis database to be connected.
54 | 	Database uint `yaml:"database"`
55 | 	// MaxIdle is the maximum number of idle connections in the pool.
56 | 	// If nil, use 1000 as default.
57 | 	// This value will pass to redis.Pool.MaxIdle.
58 | 	MaxIdle *uint `yaml:"max_idle"`
59 | 	// MaxActive is the maximum number of connections allocated by the pool at a given time.
60 | 	// If nil, use 0 as default.
61 | 	// If zero, there is no limit on the number of connections in the pool.
62 | 	// This value will pass to redis.Pool.MaxActive.
63 | 	MaxActive *uint `yaml:"max_active"`
64 | 	// IdleTimeout closes connections after remaining idle for this duration. If the value
65 | 	// is zero, then idle connections are not closed. Applications should set
66 | 	// the timeout to a value less than the server's timeout.
67 | 	// If zero, use 240 * time.Second as default.
68 | 	// This value will pass to redis.Pool.IdleTimeout.
69 | 	IdleTimeout time.Duration `yaml:"idle_timeout"`
70 | }
71 | 
72 | // Validate checks the persistence configuration. Type must be one of the
73 | // supported backends. The redis section is only checked when Type is
74 | // "redis", since it is only required in that case.
75 | func (p *Persistence) Validate() error {
76 | 	switch p.Type {
77 | 	case PersistenceTypeMemory:
78 | 		return nil
79 | 	case PersistenceTypeRedis:
80 | 		if _, _, err := net.SplitHostPort(p.Redis.Addr); err != nil {
81 | 			return errors.Wrap(err, "invalid redis addr")
82 | 		}
83 | 		// Redis.Database is unsigned, so it can never be negative and needs
84 | 		// no range check here.
85 | 		return nil
86 | 	default:
87 | 		return errors.Errorf("invalid persistence type: %q", p.Type)
88 | 	}
89 | }
90 | 
--------------------------------------------------------------------------------
/config/testdata/config.yml:
--------------------------------------------------------------------------------
1 | listeners:
2 |   - address: ":1883"
3 |     websocket:
4 |       path: "/"
5 |   - address: ":1234"
6 | 
7 | mqtt:
8 |   session_expiry: 1m
9 |   message_expiry: 1m
10 |   max_packet_size: 200
11 |   server_receive_maximum: 65535
12 |   max_keepalive: 0 # unlimited
13 |   topic_alias_maximum: 0 # 0 means not Supported
14 |   subscription_identifier_available: true
15 |   wildcard_subscription_available: true
16 |   shared_subscription_available: true
17 |   maximum_qos: 2
18 |   retain_available: true
19 |   max_queued_messages: 1000
20 |   max_inflight: 32
21 |   max_awaiting_rel: 100
22 |   queue_qos0_messages: true
23 |   delivery_mode: overlap # overlap or onlyonce
24 |   allow_zero_length_clientid: true
25 | 
26 | log:
27 |   level: debug # debug | info | warning | error
28 | 
29 | 
30 | 
31 | 
32 | 
--------------------------------------------------------------------------------
/config/testdata/default_values.yml:
--------------------------------------------------------------------------------
1 | listeners:
2 | mqtt:
3 | log:
4 | 
5 | 
6 | 
7 | 
8 | 
--------------------------------------------------------------------------------
/config/testdata/default_values_expected.yml:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ThingsPanel/thingspanel-gmqtt/25e5e350779c0c5a32e4945648cf60f68d5adb17/config/testdata/default_values_expected.yml
--------------------------------------------------------------------------------
/config/topic_alias.go:
--------------------------------------------------------------------------------
1 | package config
2 | 
3 | // TopicAliasType is the name of a topic alias manager implementation.
4 | type TopicAliasType = string
5 | 
6 | const (
7 | 	// TopicAliasMgrTypeFIFO is the "fifo" topic alias manager type.
8 | 	TopicAliasMgrTypeFIFO TopicAliasType = "fifo"
9 | )
10 | 
11 | var (
12 | 	// DefaultTopicAliasManager is the default value of TopicAliasManager
13 | 	DefaultTopicAliasManager = TopicAliasManager{
14 | 		Type: TopicAliasMgrTypeFIFO,
15 | 	}
16 | )
17 | 
18 | // TopicAliasManager is the config of the topic alias manager.
19 | type TopicAliasManager struct {
20 | 	// Type is the topic alias manager type, e.g. "fifo".
21 | 	Type TopicAliasType `yaml:"type"`
22 | }
23 | 
--------------------------------------------------------------------------------
/go.mod:
--------------------------------------------------------------------------------
1 | module github.com/DrmagicE/gmqtt
2 | 
3 | go 1.14
4 | 
5 | require (
6 | 	github.com/eclipse/paho.mqtt.golang v1.4.2
7 | 	github.com/golang/mock v1.6.0
8 | 	github.com/golang/protobuf v1.5.3
9 | 	github.com/gomodule/redigo v1.8.2
10 | 	github.com/google/uuid v1.3.0
11 | 	github.com/gorilla/websocket v1.4.2
12 | 	github.com/grpc-ecosystem/go-grpc-middleware v1.0.0
13 | 	github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0
14 | 	github.com/grpc-ecosystem/grpc-gateway v1.16.0
15 | 	github.com/hashicorp/go-sockaddr v1.0.0
16 | 	github.com/hashicorp/logutils v1.0.0
17 | 	github.com/hashicorp/serf v0.9.5
18 | 	github.com/iancoleman/strcase v0.1.2
19 | 	github.com/onsi/ginkgo v1.16.5 // indirect
20 | 	github.com/onsi/gomega v1.30.0 // indirect
21 | 	github.com/pkg/errors v0.8.1
22 | 	github.com/prometheus/client_golang v1.4.0
23 | 	github.com/spf13/cobra v1.0.0
24 | 	github.com/spf13/viper v1.4.0
25 | 	github.com/stretchr/testify v1.8.1
26 | 	go.uber.org/zap v1.13.0
27 | 	golang.org/x/crypto v0.14.0
28 | 	golang.org/x/sys v0.13.0
29 | 	google.golang.org/genproto v0.0.0-20221201204527-e3fa12d562f3
30 | 	google.golang.org/grpc v1.50.1
31 | 	google.golang.org/protobuf v1.28.1
32 | 	gopkg.in/redis.v5 v5.2.9
33 | 	gopkg.in/yaml.v2 v2.4.0
34 | 	gorm.io/driver/postgres v1.5.4
35 | 	gorm.io/gorm v1.25.5
36 | )
37 | 
--------------------------------------------------------------------------------
/go_test.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | 
3 | set -e
4 | go test -race ./... -coverprofile=coverage.txt -covermode=atomic
5 | sed -i -e '/.pb./d' coverage.txt
6 | sed -i -e '/_mock/d' coverage.txt
7 | sed -i -e '/example/d' coverage.txt
8 | sed -i -e '/_darwin/d' coverage.txt
9 | sed -i -e '/_windows/d' coverage.txt
--------------------------------------------------------------------------------
/mock_gen.sh:
--------------------------------------------------------------------------------
1 | mockgen -source=config/config.go -destination=./config/config_mock.go -package=config -self_package=github.com/DrmagicE/gmqtt/config
2 | mockgen -source=persistence/queue/elem.go -destination=./persistence/queue/elem_mock.go -package=queue -self_package=github.com/DrmagicE/gmqtt/persistence/queue
3 | mockgen -source=persistence/queue/queue.go -destination=./persistence/queue/queue_mock.go -package=queue -self_package=github.com/DrmagicE/gmqtt/persistence/queue
4 | mockgen -source=persistence/session/session.go -destination=./persistence/session/session_mock.go -package=session -self_package=github.com/DrmagicE/gmqtt/persistence/session
5 | mockgen -source=persistence/subscription/subscription.go -destination=./persistence/subscription/subscription_mock.go -package=subscription -self_package=github.com/DrmagicE/gmqtt/persistence/subscription
6 | mockgen -source=persistence/unack/unack.go -destination=./persistence/unack/unack_mock.go -package=unack -self_package=github.com/DrmagicE/gmqtt/persistence/unack
7 | mockgen -source=pkg/packets/packets.go -destination=./pkg/packets/packets_mock.go -package=packets -self_package=github.com/DrmagicE/gmqtt/pkg/packets
8 | mockgen -source=plugin/auth/account_grpc.pb.go -destination=./plugin/auth/account_grpc.pb_mock.go -package=auth -self_package=github.com/DrmagicE/gmqtt/plugin/auth
9 | mockgen -source=plugin/federation/federation.pb.go -destination=./plugin/federation/federation.pb_mock.go -package=federation -self_package=github.com/DrmagicE/gmqtt/plugin/federation
10 | mockgen -source=plugin/federation/peer.go -destination=./plugin/federation/peer_mock.go -package=federation -self_package=github.com/DrmagicE/gmqtt/plugin/federation
11 | mockgen -source=plugin/federation/membership.go -destination=./plugin/federation/membership_mock.go -package=federation -self_package=github.com/DrmagicE/gmqtt/plugin/federation
12 | mockgen -source=retained/interface.go -destination=./retained/interface_mock.go -package=retained -self_package=github.com/DrmagicE/gmqtt/retained
13 | mockgen -source=server/client.go -destination=./server/client_mock.go -package=server -self_package=github.com/DrmagicE/gmqtt/server
14 | mockgen -source=server/persistence.go -destination=./server/persistence_mock.go -package=server -self_package=github.com/DrmagicE/gmqtt/server
15 | mockgen -source=server/plugin.go -destination=./server/plugin_mock.go -package=server -self_package=github.com/DrmagicE/gmqtt/server
16 | mockgen -source=server/server.go -destination=./server/server_mock.go -package=server -self_package=github.com/DrmagicE/gmqtt/server
17 | mockgen -source=server/service.go -destination=./server/service_mock.go -package=server -self_package=github.com/DrmagicE/gmqtt/server
18 | mockgen -source=server/stats.go -destination=./server/stats_mock.go -package=server -self_package=github.com/DrmagicE/gmqtt/server
19 | mockgen -source=server/topic_alias.go -destination=./server/topic_alias_mock.go -package=server -self_package=github.com/DrmagicE/gmqtt/server
20 | 
21 | # reflection mode.
22 | # gRPC streaming mock issue: https://github.com/golang/mock/pull/163
23 | mockgen -package=federation -destination=./plugin/federation/federation_grpc.pb_mock.go github.com/DrmagicE/gmqtt/plugin/federation FederationClient,Federation_EventStreamClient
24 | 
--------------------------------------------------------------------------------
/persistence/encoding/binary.go:
--------------------------------------------------------------------------------
1 | package encoding
2 | 
3 | import (
4 | 	"bytes"
5 | 	"encoding/binary"
6 | 	"errors"
7 | 	"io"
8 | )
9 | 
10 | // WriteUint16 writes i to w as 2 big-endian bytes.
11 | func WriteUint16(w *bytes.Buffer, i uint16) {
12 | 	w.WriteByte(byte(i >> 8))
13 | 	w.WriteByte(byte(i))
14 | }
15 | 
16 | // WriteBool writes b to w as a single byte: 1 for true, 0 for false.
17 | func WriteBool(w *bytes.Buffer, b bool) {
18 | 	if b {
19 | 		w.WriteByte(1)
20 | 	} else {
21 | 		w.WriteByte(0)
22 | 	}
23 | }
24 | 
25 | // ReadBool reads one byte from r; any non-zero byte decodes as true.
26 | func ReadBool(r *bytes.Buffer) (bool, error) {
27 | 	b, err := r.ReadByte()
28 | 	if err != nil {
29 | 		return false, err
30 | 	}
31 | 	return b != 0, nil
32 | }
33 | 
34 | // WriteString writes s to w prefixed with its length as a big-endian uint16.
35 | //
36 | // NOTE(review): the length prefix is truncated to 16 bits, so a value longer
37 | // than 65535 bytes produces output that ReadString cannot decode correctly.
38 | // Callers must keep s within that limit.
39 | func WriteString(w *bytes.Buffer, s []byte) {
40 | 	WriteUint16(w, uint16(len(s)))
41 | 	w.Write(s)
42 | }
43 | 
44 | // ReadString reads a 2-byte big-endian length prefix followed by that many
45 | // bytes, as written by WriteString. If r holds fewer bytes than required it
46 | // returns the io.ReadFull error (io.EOF or io.ErrUnexpectedEOF).
47 | func ReadString(r *bytes.Buffer) (b []byte, err error) {
48 | 	l := make([]byte, 2)
49 | 	_, err = io.ReadFull(r, l)
50 | 	if err != nil {
51 | 		return nil, err
52 | 	}
53 | 	length := int(binary.BigEndian.Uint16(l))
54 | 	payload := make([]byte, length)
55 | 
56 | 	_, err = io.ReadFull(r, payload)
57 | 	if err != nil {
58 | 		return nil, err
59 | 	}
60 | 	return payload, nil
61 | }
62 | 
63 | // WriteUint32 writes i to w as 4 big-endian bytes.
64 | func WriteUint32(w *bytes.Buffer, i uint32) {
65 | 	w.WriteByte(byte(i >> 24))
66 | 	w.WriteByte(byte(i >> 16))
67 | 	w.WriteByte(byte(i >> 8))
68 | 	w.WriteByte(byte(i))
69 | }
70 | 
71 | // ReadUint16 reads 2 big-endian bytes from r. It returns an error, without
72 | // consuming input, if fewer than 2 bytes remain.
73 | func ReadUint16(r *bytes.Buffer) (uint16, error) {
74 | 	if r.Len() < 2 {
75 | 		return 0, errors.New("invalid length")
76 | 	}
77 | 	return binary.BigEndian.Uint16(r.Next(2)), nil
78 | }
79 | 
80 | // ReadUint32 reads 4 big-endian bytes from r. It returns an error, without
81 | // consuming input, if fewer than 4 bytes remain.
82 | func ReadUint32(r *bytes.Buffer) (uint32, error) {
83 | 	if r.Len() < 4 {
84 | 		return 0, errors.New("invalid length")
85 | 	}
86 | 	return binary.BigEndian.Uint32(r.Next(4)), nil
87 | }
88 | 
--------------------------------------------------------------------------------
/persistence/memory.go:
--------------------------------------------------------------------------------
1 | package persistence
2 | 
3 | import (
4 | 	"github.com/DrmagicE/gmqtt/config"
5 | 	"github.com/DrmagicE/gmqtt/persistence/queue"
6 | 	mem_queue "github.com/DrmagicE/gmqtt/persistence/queue/mem"
7 | 	"github.com/DrmagicE/gmqtt/persistence/session"
8 | 	mem_session "github.com/DrmagicE/gmqtt/persistence/session/mem"
9 | 	"github.com/DrmagicE/gmqtt/persistence/subscription"
10 | 	mem_sub "github.com/DrmagicE/gmqtt/persistence/subscription/mem"
11 | 	"github.com/DrmagicE/gmqtt/persistence/unack"
12 | 	mem_unack "github.com/DrmagicE/gmqtt/persistence/unack/mem"
13 | 	"github.com/DrmagicE/gmqtt/server"
14 | )
15 | 
16 | func init() {
17 | 	server.RegisterPersistenceFactory("memory", NewMemory)
18 | }
19 | 
20 | // NewMemory returns a server.Persistence whose stores keep all state in
21 | // process memory. The config argument is not used.
22 | func NewMemory(config config.Config) (server.Persistence, error) {
23 | 	return &memory{}, nil
24 | }
25 | 
26 | // memory implements server.Persistence on top of the in-memory stores.
27 | // It holds no resources, so Open and Close are no-ops.
28 | type memory struct {
29 | }
30 | 
31 | func (m *memory) NewUnackStore(config config.Config, clientID string) (unack.Store, error) {
32 | 	return mem_unack.New(mem_unack.Options{
33 | 		ClientID: clientID,
34 | 	}), nil
35 | }
36 | 
37 | func (m *memory) NewSessionStore(config config.Config) (session.Store, error) {
38 | 	return mem_session.New(), nil
39 | }
40 | 
41 | func (m *memory) Open() error {
42 | 	return nil
43 | }
44 | 
45 | func (m *memory) NewQueueStore(config config.Config, defaultNotifier queue.Notifier, clientID string) (queue.Store, error) {
46 | 	return mem_queue.New(mem_queue.Options{
47 | 		MaxQueuedMsg:    config.MQTT.MaxQueuedMsg,
48 | 		InflightExpiry:  config.MQTT.InflightExpiry,
49 | 		ClientID:        clientID,
50 | 		DefaultNotifier: defaultNotifier,
51 | 	})
52 | }
53 | 
54 | func (m *memory) NewSubscriptionStore(config config.Config) (subscription.Store, error) {
55 | 	return mem_sub.NewStore(), nil
56 | }
57 | 
58 | func (m *memory) Close() error {
59 | 	return nil
60 | }
61 | 
--------------------------------------------------------------------------------
/persistence/memory_test.go:
--------------------------------------------------------------------------------
1 | package persistence
2 | 
3 | import (
4 | 	"testing"
5 | 
6 | 	"github.com/stretchr/testify/assert"
7 | 	"github.com/stretchr/testify/suite"
8 | 
9 | 	"github.com/DrmagicE/gmqtt/config"
10 | 	queue_test "github.com/DrmagicE/gmqtt/persistence/queue/test"
11 | 	sess_test "github.com/DrmagicE/gmqtt/persistence/session/test"
12 | 	"github.com/DrmagicE/gmqtt/persistence/subscription"
13 | 	sub_test "github.com/DrmagicE/gmqtt/persistence/subscription/test"
14 | 	unack_test "github.com/DrmagicE/gmqtt/persistence/unack/test"
15 | 	"github.com/DrmagicE/gmqtt/server"
16 | )
17 | 
18 | type MemorySuite struct {
19 | 	suite.Suite
20 | 	new server.NewPersistence
21 | 	p   server.Persistence
22 | }
23 | 
24 | func (s *MemorySuite) TestQueue() {
25 | 	a := assert.New(s.T())
26 | 	qs, err := s.p.NewQueueStore(queue_test.TestServerConfig, queue_test.TestNotifier, queue_test.TestClientID)
27 | 	a.Nil(err)
28 | 	queue_test.TestQueue(s.T(), qs)
29 | }
30 | func (s *MemorySuite) TestSubscription() {
31 | 	newFn := func() subscription.Store {
32 | 		st, err := s.p.NewSubscriptionStore(queue_test.TestServerConfig)
33 | 		if err != nil {
34 | 			panic(err)
35 | 		}
36 | 		return st
37 | 	}
38 | 	sub_test.TestSuite(s.T(), newFn)
39 | }
40 | 
41 | func (s *MemorySuite) TestSession() {
42 | 	a := assert.New(s.T())
43 | 	st, err := s.p.NewSessionStore(queue_test.TestServerConfig)
44 | 	a.Nil(err)
45 | 	sess_test.TestSuite(s.T(), st)
46 | }
47 | 
48 | func (s *MemorySuite) TestUnack() {
49 | 	a := assert.New(s.T())
50 | 	st, err := s.p.NewUnackStore(unack_test.TestServerConfig, unack_test.TestClientID)
51 | 	a.Nil(err)
52 | 	unack_test.TestSuite(s.T(), st)
53 | }
54 | 
55 | func TestMemory(t *testing.T) {
56 | 	p, err := NewMemory(config.Config{})
57 | 	if err != nil {
58 | 		t.Fatal(err.Error())
59 | 	}
60 | 	suite.Run(t, &MemorySuite{
61 | 		p: p,
62 | 	})
63 | }
64 | 
--------------------------------------------------------------------------------
/persistence/queue/elem.go:
--------------------------------------------------------------------------------
1 | package queue
2 | 
3 | import (
4 | 	"bytes"
5 | 	"encoding/binary"
6 | 	"errors"
7 | 	"time"
8 | 
9 | 	"github.com/DrmagicE/gmqtt"
10 | 	"github.com/DrmagicE/gmqtt/persistence/encoding"
11 | 	"github.com/DrmagicE/gmqtt/pkg/packets"
12 | )
13 | 
14 | // MessageWithID is a queued message that carries a packet identifier.
15 | type MessageWithID interface {
16 | 	ID() packets.PacketID
17 | 	SetID(id packets.PacketID)
18 | }
19 | 
20 | // Publish is a queued PUBLISH message.
21 | type Publish struct {
22 | 	*gmqtt.Message
23 | }
24 | 
25 | // ID returns the packet identifier of the message.
26 | func (p *Publish) ID() packets.PacketID {
27 | 	return p.PacketID
28 | }
29 | 
30 | // SetID sets the packet identifier of the message.
31 | func (p *Publish) SetID(id packets.PacketID) {
32 | 	p.PacketID = id
33 | }
34 | 
35 | // Pubrel is a queued PUBREL packet.
36 | type Pubrel struct {
37 | 	PacketID packets.PacketID
38 | }
39 | 
40 | // ID returns the packet identifier of the PUBREL packet.
41 | func (p *Pubrel) ID() packets.PacketID {
42 | 	return p.PacketID
43 | }
44 | 
45 | // SetID sets the packet identifier of the PUBREL packet.
46 | func (p *Pubrel) SetID(id packets.PacketID) {
47 | 	p.PacketID = id
48 | }
49 | 
50 | // Elem represents the element store in the queue.
51 | type Elem struct {
52 | 	// At represents the entry time.
53 | 	At time.Time
54 | 	// Expiry represents the expiry time.
55 | 	// Empty means never expire.
56 | 	Expiry time.Time
57 | 	MessageWithID
58 | }
59 | 
60 | // Encoded Elem layout. The header is elemHeaderLen (19) bytes:
61 | //
62 | //	bytes  0-7   At, Unix seconds, big-endian
63 | //	byte   8     unused (always 0)
64 | //	bytes  9-16  Expiry, Unix seconds, big-endian
65 | //	byte   17    unused (always 0)
66 | //	byte   18    identifier (elemIdentPublish or elemIdentPubrel)
67 | //
68 | // It is followed by the encoded Publish or Pubrel. The unused bytes are kept
69 | // so that previously encoded elements still decode.
70 | const (
71 | 	elemAtOffset     = 0
72 | 	elemExpiryOffset = 9
73 | 	elemIdentOffset  = 18
74 | 	elemHeaderLen    = 19
75 | 
76 | 	elemIdentPublish byte = 0
77 | 	elemIdentPubrel  byte = 1
78 | )
79 | 
80 | // Encode encodes the publish structure into bytes and write it to the buffer
81 | func (p *Publish) Encode(b *bytes.Buffer) {
82 | 	encoding.EncodeMessage(p.Message, b)
83 | }
84 | 
85 | // Decode decodes a message written by Publish.Encode from b.
86 | func (p *Publish) Decode(b *bytes.Buffer) (err error) {
87 | 	msg, err := encoding.DecodeMessage(b)
88 | 	if err != nil {
89 | 		return err
90 | 	}
91 | 	p.Message = msg
92 | 	return nil
93 | }
94 | 
95 | // Encode writes the packet identifier of the pubrel as a big-endian uint16.
96 | func (p *Pubrel) Encode(b *bytes.Buffer) {
97 | 	encoding.WriteUint16(b, p.PacketID)
98 | }
99 | 
100 | // Decode reads the packet identifier written by Pubrel.Encode.
101 | func (p *Pubrel) Decode(b *bytes.Buffer) (err error) {
102 | 	p.PacketID, err = encoding.ReadUint16(b)
103 | 	return
104 | }
105 | 
106 | // Encode encodes the elem using the layout documented on the elem* constants.
107 | // If MessageWithID is neither *Publish nor *Pubrel, the result is empty.
108 | func (e *Elem) Encode() []byte {
109 | 	b := bytes.NewBuffer(make([]byte, 0, 100))
110 | 	header := make([]byte, elemHeaderLen)
111 | 	binary.BigEndian.PutUint64(header[elemAtOffset:], uint64(e.At.Unix()))
112 | 	binary.BigEndian.PutUint64(header[elemExpiryOffset:], uint64(e.Expiry.Unix()))
113 | 	switch m := e.MessageWithID.(type) {
114 | 	case *Publish:
115 | 		header[elemIdentOffset] = elemIdentPublish
116 | 		b.Write(header)
117 | 		m.Encode(b)
118 | 	case *Pubrel:
119 | 		header[elemIdentOffset] = elemIdentPubrel
120 | 		b.Write(header)
121 | 		m.Encode(b)
122 | 	}
123 | 	return b.Bytes()
124 | }
125 | 
126 | // Decode decodes an elem produced by Encode. It returns an error if b is
127 | // shorter than the header or carries an unknown identifier; otherwise it
128 | // returns the error, if any, from decoding the message that follows.
129 | func (e *Elem) Decode(b []byte) (err error) {
130 | 	if len(b) < elemHeaderLen {
131 | 		return errors.New("invalid input length")
132 | 	}
133 | 	e.At = time.Unix(int64(binary.BigEndian.Uint64(b[elemAtOffset:])), 0)
134 | 	e.Expiry = time.Unix(int64(binary.BigEndian.Uint64(b[elemExpiryOffset:])), 0)
135 | 	payload := bytes.NewBuffer(b[elemHeaderLen:])
136 | 	switch b[elemIdentOffset] {
137 | 	case elemIdentPublish:
138 | 		p := &Publish{}
139 | 		err = p.Decode(payload)
140 | 		e.MessageWithID = p
141 | 	case elemIdentPubrel:
142 | 		p := &Pubrel{}
143 | 		err = p.Decode(payload)
144 | 		e.MessageWithID = p
145 | 	default:
146 | 		return errors.New("invalid identifier")
147 | 	}
148 | 	return
149 | }
150 | 
--------------------------------------------------------------------------------
/persistence/queue/elem_mock.go:
--------------------------------------------------------------------------------
1 | // Code generated by MockGen. DO NOT EDIT.
2 | // Source: persistence/queue/elem.go
3 | 
4 | // Package queue is a generated GoMock package.
5 | package queue
6 | 
7 | import (
8 | 	packets "github.com/DrmagicE/gmqtt/pkg/packets"
9 | 	gomock "github.com/golang/mock/gomock"
10 | 	reflect "reflect"
11 | )
12 | 
13 | // MockMessageWithID is a gomock mock of the MessageWithID interface.
14 | type MockMessageWithID struct {
15 | 	ctrl     *gomock.Controller
16 | 	recorder *MockMessageWithIDMockRecorder
17 | }
18 | 
19 | // MockMessageWithIDMockRecorder records expected calls on a MockMessageWithID.
20 | type MockMessageWithIDMockRecorder struct {
21 | 	mock *MockMessageWithID
22 | }
23 | 
24 | // NewMockMessageWithID creates a new mock instance bound to ctrl.
25 | func NewMockMessageWithID(ctrl *gomock.Controller) *MockMessageWithID {
26 | 	mock := &MockMessageWithID{ctrl: ctrl}
27 | 	mock.recorder = &MockMessageWithIDMockRecorder{mock}
28 | 	return mock
29 | }
30 | 
31 | // EXPECT returns the recorder used to declare the calls m is expected to receive.
32 | func (m *MockMessageWithID) EXPECT() *MockMessageWithIDMockRecorder {
33 | 	return m.recorder
34 | }
35 | 
36 | // ID mocks the base method; it returns the value configured via EXPECT().ID().
37 | func (m *MockMessageWithID) ID() packets.PacketID {
38 | 	m.ctrl.T.Helper()
39 | 	ret := m.ctrl.Call(m, "ID")
40 | 	ret0, _ := ret[0].(packets.PacketID)
41 | 	return ret0
42 | }
43 | 
44 | // ID indicates an expected call of ID.
45 | func (mr *MockMessageWithIDMockRecorder) ID() *gomock.Call {
46 | 	mr.mock.ctrl.T.Helper()
47 | 	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ID", reflect.TypeOf((*MockMessageWithID)(nil).ID))
48 | }
49 | 
50 | // SetID mocks the base method; it only reports the call to the controller.
51 | func (m *MockMessageWithID) SetID(id packets.PacketID) {
52 | 	m.ctrl.T.Helper()
53 | 	m.ctrl.Call(m, "SetID", id)
54 | }
55 | 
56 | // SetID indicates an expected call of SetID.
57 | func (mr *MockMessageWithIDMockRecorder) SetID(id interface{}) *gomock.Call {
58 | 	mr.mock.ctrl.T.Helper()
59 | 	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetID", reflect.TypeOf((*MockMessageWithID)(nil).SetID), id)
60 | }
61 | 
--------------------------------------------------------------------------------
/persistence/queue/elem_test.go:
--------------------------------------------------------------------------------
/persistence/queue/elem_test.go: -------------------------------------------------------------------------------- 1 | package queue 2 | 3 | import ( 4 | "testing" 5 | "time" 6 | 7 | "github.com/stretchr/testify/assert" 8 | 9 | "github.com/DrmagicE/gmqtt" 10 | "github.com/DrmagicE/gmqtt/pkg/packets" 11 | ) 12 | 13 | func assertElemEqual(a *assert.Assertions, expected, actual *Elem) { 14 | expected.At = time.Unix(expected.At.Unix(), 0) 15 | expected.Expiry = time.Unix(expected.Expiry.Unix(), 0) 16 | actual.At = time.Unix(actual.At.Unix(), 0) 17 | actual.Expiry = time.Unix(actual.Expiry.Unix(), 0) 18 | a.Equal(expected, actual) 19 | } 20 | 21 | func TestElem_Encode_Publish(t *testing.T) { 22 | a := assert.New(t) 23 | e := &Elem{ 24 | At: time.Now(), 25 | MessageWithID: &Publish{ 26 | Message: &gmqtt.Message{ 27 | Dup: false, 28 | QoS: 2, 29 | Retained: false, 30 | Topic: "/mytopic", 31 | Payload: []byte("payload"), 32 | PacketID: 2, 33 | ContentType: "type", 34 | CorrelationData: nil, 35 | MessageExpiry: 1, 36 | PayloadFormat: packets.PayloadFormatString, 37 | ResponseTopic: "", 38 | SubscriptionIdentifier: []uint32{1, 2}, 39 | UserProperties: []packets.UserProperty{ 40 | { 41 | K: []byte("1"), 42 | V: []byte("2"), 43 | }, { 44 | K: []byte("3"), 45 | V: []byte("4"), 46 | }, 47 | }, 48 | }, 49 | }, 50 | } 51 | rs := e.Encode() 52 | de := &Elem{} 53 | err := de.Decode(rs) 54 | a.Nil(err) 55 | assertElemEqual(a, e, de) 56 | } 57 | func TestElem_Encode_Pubrel(t *testing.T) { 58 | a := assert.New(t) 59 | e := &Elem{ 60 | At: time.Unix(time.Now().Unix(), 0), 61 | MessageWithID: &Pubrel{ 62 | PacketID: 2, 63 | }, 64 | } 65 | rs := e.Encode() 66 | de := &Elem{} 67 | err := de.Decode(rs) 68 | a.Nil(err) 69 | assertElemEqual(a, e, de) 70 | } 71 | 72 | func Benchmark_Encode_Publish(b *testing.B) { 73 | for i := 0; i < b.N; i++ { 74 | e := &Elem{ 75 | At: time.Unix(time.Now().Unix(), 0), 76 | MessageWithID: &Publish{ 77 | Message: &gmqtt.Message{ 78 | Dup: false, 79 | QoS: 2, 80 
| Retained: false, 81 | Topic: "/mytopic", 82 | Payload: []byte("payload"), 83 | PacketID: 2, 84 | ContentType: "type", 85 | CorrelationData: nil, 86 | MessageExpiry: 1, 87 | PayloadFormat: packets.PayloadFormatString, 88 | ResponseTopic: "", 89 | SubscriptionIdentifier: []uint32{1, 2}, 90 | UserProperties: []packets.UserProperty{ 91 | { 92 | K: []byte("1"), 93 | V: []byte("2"), 94 | }, { 95 | K: []byte("3"), 96 | V: []byte("4"), 97 | }, 98 | }, 99 | }, 100 | }, 101 | } 102 | e.Encode() 103 | } 104 | } 105 | -------------------------------------------------------------------------------- /persistence/queue/error.go: -------------------------------------------------------------------------------- 1 | package queue 2 | 3 | import ( 4 | "errors" 5 | ) 6 | 7 | var ( 8 | ErrClosed = errors.New("queue has been closed") 9 | ErrDropExceedsMaxPacketSize = errors.New("maximum packet size exceeded") 10 | ErrDropQueueFull = errors.New("the message queue is full") 11 | ErrDropExpired = errors.New("the message is expired") 12 | ErrDropExpiredInflight = errors.New("the inflight message is expired") 13 | ) 14 | 15 | // InternalError wraps the error of the backend storage. 16 | type InternalError struct { 17 | // Err is the error return by the backend storage. 
18 | Err error 19 | } 20 | 21 | func (i *InternalError) Error() string { 22 | return i.Err.Error() 23 | } 24 | -------------------------------------------------------------------------------- /persistence/redis.go: -------------------------------------------------------------------------------- 1 | package persistence 2 | 3 | import ( 4 | redigo "github.com/gomodule/redigo/redis" 5 | 6 | "github.com/DrmagicE/gmqtt/config" 7 | "github.com/DrmagicE/gmqtt/persistence/queue" 8 | redis_queue "github.com/DrmagicE/gmqtt/persistence/queue/redis" 9 | "github.com/DrmagicE/gmqtt/persistence/session" 10 | redis_sess "github.com/DrmagicE/gmqtt/persistence/session/redis" 11 | "github.com/DrmagicE/gmqtt/persistence/subscription" 12 | redis_sub "github.com/DrmagicE/gmqtt/persistence/subscription/redis" 13 | "github.com/DrmagicE/gmqtt/persistence/unack" 14 | redis_unack "github.com/DrmagicE/gmqtt/persistence/unack/redis" 15 | "github.com/DrmagicE/gmqtt/server" 16 | ) 17 | 18 | func init() { 19 | server.RegisterPersistenceFactory("redis", NewRedis) 20 | } 21 | 22 | func NewRedis(config config.Config) (server.Persistence, error) { 23 | return &redis{ 24 | config: config, 25 | }, nil 26 | } 27 | 28 | type redis struct { 29 | pool *redigo.Pool 30 | config config.Config 31 | onMsgDropped server.OnMsgDropped 32 | } 33 | 34 | func (r *redis) NewUnackStore(config config.Config, clientID string) (unack.Store, error) { 35 | return redis_unack.New(redis_unack.Options{ 36 | ClientID: clientID, 37 | Pool: r.pool, 38 | }), nil 39 | } 40 | 41 | func (r *redis) NewSessionStore(config config.Config) (session.Store, error) { 42 | return redis_sess.New(r.pool), nil 43 | } 44 | 45 | func newPool(config config.Config) *redigo.Pool { 46 | return &redigo.Pool{ 47 | // Dial or DialContext must be set. When both are set, DialContext takes precedence over Dial. 
48 | Dial: func() (redigo.Conn, error) { 49 | c, err := redigo.Dial("tcp", config.Persistence.Redis.Addr) 50 | if err != nil { 51 | return nil, err 52 | } 53 | if pswd := config.Persistence.Redis.Password; pswd != "" { 54 | if _, err := c.Do("AUTH", pswd); err != nil { 55 | c.Close() 56 | return nil, err 57 | } 58 | } 59 | if _, err := c.Do("SELECT", config.Persistence.Redis.Database); err != nil { 60 | c.Close() 61 | return nil, err 62 | } 63 | return c, nil 64 | }, 65 | } 66 | } 67 | func (r *redis) Open() error { 68 | r.pool = newPool(r.config) 69 | r.pool.MaxIdle = int(*r.config.Persistence.Redis.MaxIdle) 70 | r.pool.MaxActive = int(*r.config.Persistence.Redis.MaxActive) 71 | r.pool.IdleTimeout = r.config.Persistence.Redis.IdleTimeout 72 | conn := r.pool.Get() 73 | defer conn.Close() 74 | // Test the connection 75 | _, err := conn.Do("PING") 76 | 77 | return err 78 | } 79 | 80 | func (r *redis) NewQueueStore(config config.Config, defaultNotifier queue.Notifier, clientID string) (queue.Store, error) { 81 | return redis_queue.New(redis_queue.Options{ 82 | MaxQueuedMsg: config.MQTT.MaxQueuedMsg, 83 | InflightExpiry: config.MQTT.InflightExpiry, 84 | ClientID: clientID, 85 | Pool: r.pool, 86 | DefaultNotifier: defaultNotifier, 87 | }) 88 | } 89 | 90 | func (r *redis) NewSubscriptionStore(config config.Config) (subscription.Store, error) { 91 | return redis_sub.New(r.pool), nil 92 | } 93 | 94 | func (r *redis) Close() error { 95 | return r.pool.Close() 96 | } 97 | -------------------------------------------------------------------------------- /persistence/session/mem/store.go: -------------------------------------------------------------------------------- 1 | package mem 2 | 3 | import ( 4 | "sync" 5 | 6 | "github.com/DrmagicE/gmqtt" 7 | "github.com/DrmagicE/gmqtt/persistence/session" 8 | ) 9 | 10 | var _ session.Store = (*Store)(nil) 11 | 12 | func New() *Store { 13 | return &Store{ 14 | mu: sync.Mutex{}, 15 | sess: make(map[string]*gmqtt.Session), 16 | } 17 | } 
18 | 19 | // Store is an in-memory session.Store; every access to sess is guarded by mu. 20 | type Store struct { 21 | mu sync.Mutex 22 | sess map[string]*gmqtt.Session 23 | } 24 | 25 | // Set stores session under its ClientID, replacing any existing entry. 26 | func (s *Store) Set(session *gmqtt.Session) error { 27 | s.mu.Lock() 28 | defer s.mu.Unlock() 29 | s.sess[session.ClientID] = session 30 | return nil 31 | } 32 | 33 | // Remove deletes the session stored for clientID, if any. 34 | func (s *Store) Remove(clientID string) error { 35 | s.mu.Lock() 36 | defer s.mu.Unlock() 37 | delete(s.sess, clientID) 38 | return nil 39 | } 40 | 41 | // Get returns the session stored for clientID, or nil if there is none. 42 | func (s *Store) Get(clientID string) (*gmqtt.Session, error) { 43 | s.mu.Lock() 44 | defer s.mu.Unlock() 45 | return s.sess[clientID], nil 46 | } 47 | 48 | // GetAll returns a snapshot of all stored sessions in unspecified order. 49 | func (s *Store) GetAll() ([]*gmqtt.Session, error) { 50 | s.mu.Lock() 51 | defer s.mu.Unlock() 52 | all := make([]*gmqtt.Session, 0, len(s.sess)) 53 | for _, v := range s.sess { 54 | all = append(all, v) 55 | } 56 | return all, nil 57 | } 58 | 59 | // SetSessionExpiry updates the expiry interval of the session stored for 60 | // clientID. Unknown client IDs are silently ignored. 61 | func (s *Store) SetSessionExpiry(clientID string, expiry uint32) error { 62 | s.mu.Lock() 63 | defer s.mu.Unlock() 64 | if sess, ok := s.sess[clientID]; ok { 65 | sess.ExpiryInterval = expiry 66 | } 67 | return nil 68 | } 69 | 70 | // Iterate calls fn for each stored session, in unspecified order, until fn 71 | // returns false. The store lock is held for the whole iteration, so fn must 72 | // not call back into this Store (sync.Mutex is not reentrant). 73 | func (s *Store) Iterate(fn session.IterateFn) error { 74 | s.mu.Lock() 75 | defer s.mu.Unlock() 76 | for _, v := range s.sess { 77 | cont := fn(v) 78 | if !cont { 79 | break 80 | } 81 | } 82 | return nil 83 | } 84 | -------------------------------------------------------------------------------- /persistence/session/session.go: -------------------------------------------------------------------------------- 1 | package session 2 | 3 | import ( 4 | "github.com/DrmagicE/gmqtt" 5 | ) 6 | 7 | // IterateFn is the callback function used by Iterate() 8 | // Return false means to stop the iteration.
9 | type IterateFn func(session *gmqtt.Session) bool 10 | 11 | type Store interface { 12 | Set(session *gmqtt.Session) error 13 | Remove(clientID string) error 14 | Get(clientID string) (*gmqtt.Session, error) 15 | Iterate(fn IterateFn) error 16 | SetSessionExpiry(clientID string, expiry uint32) error 17 | } 18 | -------------------------------------------------------------------------------- /persistence/session/test/test_suite.go: -------------------------------------------------------------------------------- 1 | package test 2 | 3 | import ( 4 | "testing" 5 | "time" 6 | 7 | "github.com/golang/mock/gomock" 8 | "github.com/stretchr/testify/assert" 9 | 10 | "github.com/DrmagicE/gmqtt" 11 | "github.com/DrmagicE/gmqtt/persistence/session" 12 | ) 13 | 14 | func TestSuite(t *testing.T, store session.Store) { 15 | a := assert.New(t) 16 | ctrl := gomock.NewController(t) 17 | defer ctrl.Finish() 18 | var tt = []*gmqtt.Session{ 19 | { 20 | ClientID: "client", 21 | Will: &gmqtt.Message{ 22 | Topic: "topicA", 23 | Payload: []byte("abc"), 24 | }, 25 | WillDelayInterval: 1, 26 | ConnectedAt: time.Unix(1, 0), 27 | ExpiryInterval: 2, 28 | }, { 29 | ClientID: "client2", 30 | Will: nil, 31 | WillDelayInterval: 0, 32 | ConnectedAt: time.Unix(2, 0), 33 | ExpiryInterval: 0, 34 | }, 35 | } 36 | for _, v := range tt { 37 | a.Nil(store.Set(v)) 38 | } 39 | for _, v := range tt { 40 | sess, err := store.Get(v.ClientID) 41 | a.Nil(err) 42 | a.EqualValues(v, sess) 43 | } 44 | var sess []*gmqtt.Session 45 | err := store.Iterate(func(session *gmqtt.Session) bool { 46 | sess = append(sess, session) 47 | return true 48 | }) 49 | a.Nil(err) 50 | a.ElementsMatch(sess, tt) 51 | } 52 | -------------------------------------------------------------------------------- /persistence/subscription/redis/subscription_test.go: -------------------------------------------------------------------------------- 1 | package redis 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/assert" 7 | 8 | 
"github.com/DrmagicE/gmqtt" 9 | ) 10 | 11 | func TestEncodeDecodeSubscription(t *testing.T) { 12 | a := assert.New(t) 13 | tt := []*gmqtt.Subscription{ 14 | { 15 | ShareName: "shareName", 16 | TopicFilter: "filter", 17 | ID: 1, 18 | QoS: 1, 19 | NoLocal: false, 20 | RetainAsPublished: false, 21 | RetainHandling: 0, 22 | }, { 23 | ShareName: "", 24 | TopicFilter: "abc", 25 | ID: 0, 26 | QoS: 2, 27 | NoLocal: false, 28 | RetainAsPublished: true, 29 | RetainHandling: 1, 30 | }, 31 | } 32 | 33 | for _, v := range tt { 34 | b := EncodeSubscription(v) 35 | sub, err := DecodeSubscription(b) 36 | a.Nil(err) 37 | a.Equal(v, sub) 38 | } 39 | } 40 | -------------------------------------------------------------------------------- /persistence/unack/mem/mem.go: -------------------------------------------------------------------------------- 1 | package mem 2 | 3 | import ( 4 | "github.com/DrmagicE/gmqtt/persistence/unack" 5 | "github.com/DrmagicE/gmqtt/pkg/packets" 6 | ) 7 | 8 | var _ unack.Store = (*Store)(nil) 9 | 10 | type Store struct { 11 | clientID string 12 | unackpublish map[packets.PacketID]struct{} 13 | } 14 | 15 | type Options struct { 16 | ClientID string 17 | } 18 | 19 | func New(opts Options) *Store { 20 | return &Store{ 21 | clientID: opts.ClientID, 22 | unackpublish: make(map[packets.PacketID]struct{}), 23 | } 24 | } 25 | 26 | func (s *Store) Init(cleanStart bool) error { 27 | if cleanStart { 28 | s.unackpublish = make(map[packets.PacketID]struct{}) 29 | } 30 | return nil 31 | } 32 | 33 | func (s *Store) Set(id packets.PacketID) (bool, error) { 34 | if _, ok := s.unackpublish[id]; ok { 35 | return true, nil 36 | } 37 | s.unackpublish[id] = struct{}{} 38 | return false, nil 39 | } 40 | 41 | func (s *Store) Remove(id packets.PacketID) error { 42 | delete(s.unackpublish, id) 43 | return nil 44 | } 45 | -------------------------------------------------------------------------------- /persistence/unack/redis/redis.go: 
-------------------------------------------------------------------------------- 1 | package redis 2 | 3 | import ( 4 | "github.com/gomodule/redigo/redis" 5 | 6 | "github.com/DrmagicE/gmqtt/persistence/unack" 7 | "github.com/DrmagicE/gmqtt/pkg/packets" 8 | ) 9 | 10 | const ( 11 | unackPrefix = "unack:" 12 | ) 13 | 14 | var _ unack.Store = (*Store)(nil) 15 | // Store is a redis-backed unack.Store that also keeps a local cache of the packet ids it has recorded. 16 | type Store struct { 17 | clientID string 18 | pool *redis.Pool 19 | unackpublish map[packets.PacketID]struct{} 20 | } 21 | // Options configures a Store. 22 | type Options struct { 23 | ClientID string 24 | Pool *redis.Pool 25 | } 26 | // New returns a Store for opts.ClientID that uses connections from opts.Pool. 27 | func New(opts Options) *Store { 28 | return &Store{ 29 | clientID: opts.ClientID, 30 | pool: opts.Pool, 31 | unackpublish: make(map[packets.PacketID]struct{}), 32 | } 33 | } 34 | // getKey returns the redis hash key that holds the unacked packet ids of clientID. 35 | func getKey(clientID string) string { 36 | return unackPrefix + clientID 37 | } 38 | func (s *Store) Init(cleanStart bool) error { 39 | if cleanStart { 40 | c := s.pool.Get() 41 | defer c.Close() 42 | s.unackpublish = make(map[packets.PacketID]struct{}) 43 | _, err := c.Do("del", getKey(s.clientID)) 44 | if err != nil { 45 | return err 46 | } 47 | } 48 | return nil 49 | } 50 | // Set records id as unacknowledged and reports whether it was already recorded, either in the local cache or in the redis hash of this client. 51 | func (s *Store) Set(id packets.PacketID) (bool, error) { 52 | // Fast path: the id is already known from the local cache. 53 | if _, ok := s.unackpublish[id]; ok { 54 | return true, nil 55 | } 56 | c := s.pool.Get() 57 | defer c.Close() 58 | added, err := redis.Int(c.Do("hset", getKey(s.clientID), id, 1)) // HSET replies with the number of newly added fields; 0 means the id was already persisted, e.g. by a previous connection of this client. 59 | if err != nil { 60 | return false, err 61 | } 62 | s.unackpublish[id] = struct{}{} 63 | return added == 0, nil 64 | } 65 | // Remove deletes id from both the redis hash and the local cache. 66 | func (s *Store) Remove(id packets.PacketID) error { 67 | c := s.pool.Get() 68 | defer c.Close() 69 | _, err := c.Do("hdel", getKey(s.clientID), id) 70 | if err != nil { 71 | return err 72 | } 73 | delete(s.unackpublish, id) 74 | return nil 75 | } 76 | -------------------------------------------------------------------------------- /persistence/unack/test/test_suite.go: -------------------------------------------------------------------------------- 1 | package test 2 | 3 |
import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/assert" 7 | 8 | "github.com/DrmagicE/gmqtt/config" 9 | "github.com/DrmagicE/gmqtt/persistence/unack" 10 | "github.com/DrmagicE/gmqtt/pkg/packets" 11 | ) 12 | 13 | var ( 14 | TestServerConfig = config.Config{} 15 | cid = "cid" 16 | TestClientID = cid 17 | ) 18 | 19 | func TestSuite(t *testing.T, store unack.Store) { 20 | a := assert.New(t) 21 | a.Nil(store.Init(false)) 22 | for i := packets.PacketID(1); i < 10; i++ { 23 | rs, err := store.Set(i) 24 | a.Nil(err) 25 | a.False(rs) 26 | rs, err = store.Set(i) 27 | a.Nil(err) 28 | a.True(rs) 29 | err = store.Remove(i) 30 | a.Nil(err) 31 | rs, err = store.Set(i) 32 | a.Nil(err) 33 | a.False(rs) 34 | 35 | } 36 | a.Nil(store.Init(false)) 37 | for i := packets.PacketID(1); i < 10; i++ { 38 | rs, err := store.Set(i) 39 | a.Nil(err) 40 | a.True(rs) 41 | err = store.Remove(i) 42 | a.Nil(err) 43 | rs, err = store.Set(i) 44 | a.Nil(err) 45 | a.False(rs) 46 | } 47 | a.Nil(store.Init(true)) 48 | for i := packets.PacketID(1); i < 10; i++ { 49 | rs, err := store.Set(i) 50 | a.Nil(err) 51 | a.False(rs) 52 | } 53 | 54 | } 55 | -------------------------------------------------------------------------------- /persistence/unack/unack.go: -------------------------------------------------------------------------------- 1 | package unack 2 | 3 | import ( 4 | "github.com/DrmagicE/gmqtt/pkg/packets" 5 | ) 6 | 7 | // Store represents a unack store for one client. 8 | // Unack store is used to persist the unacknowledged qos2 messages. 9 | type Store interface { 10 | // Init will be called when the client connect. 11 | // If cleanStart set to true, the implementation should remove any associated data in backend store. 12 | // If it set to false, the implementation should retrieve the associated data from backend store. 13 | Init(cleanStart bool) error 14 | // Set sets the given id into store. 15 | // The return boolean indicates whether the id exist. 
16 | Set(id packets.PacketID) (bool, error) 17 | // Remove removes the given id from store. 18 | Remove(id packets.PacketID) error 19 | } 20 | -------------------------------------------------------------------------------- /persistence/unack/unack_mock.go: -------------------------------------------------------------------------------- 1 | // Code generated by MockGen. DO NOT EDIT. 2 | // Source: persistence/unack/unack.go 3 | 4 | // Package unack is a generated GoMock package. 5 | package unack 6 | 7 | import ( 8 | packets "github.com/DrmagicE/gmqtt/pkg/packets" 9 | gomock "github.com/golang/mock/gomock" 10 | reflect "reflect" 11 | ) 12 | 13 | // MockStore is a mock of Store interface 14 | type MockStore struct { 15 | ctrl *gomock.Controller 16 | recorder *MockStoreMockRecorder 17 | } 18 | 19 | // MockStoreMockRecorder is the mock recorder for MockStore 20 | type MockStoreMockRecorder struct { 21 | mock *MockStore 22 | } 23 | 24 | // NewMockStore creates a new mock instance 25 | func NewMockStore(ctrl *gomock.Controller) *MockStore { 26 | mock := &MockStore{ctrl: ctrl} 27 | mock.recorder = &MockStoreMockRecorder{mock} 28 | return mock 29 | } 30 | 31 | // EXPECT returns an object that allows the caller to indicate expected use 32 | func (m *MockStore) EXPECT() *MockStoreMockRecorder { 33 | return m.recorder 34 | } 35 | 36 | // Init mocks base method 37 | func (m *MockStore) Init(cleanStart bool) error { 38 | m.ctrl.T.Helper() 39 | ret := m.ctrl.Call(m, "Init", cleanStart) 40 | ret0, _ := ret[0].(error) 41 | return ret0 42 | } 43 | 44 | // Init indicates an expected call of Init 45 | func (mr *MockStoreMockRecorder) Init(cleanStart interface{}) *gomock.Call { 46 | mr.mock.ctrl.T.Helper() 47 | return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Init", reflect.TypeOf((*MockStore)(nil).Init), cleanStart) 48 | } 49 | 50 | // Set mocks base method 51 | func (m *MockStore) Set(id packets.PacketID) (bool, error) { 52 | m.ctrl.T.Helper() 53 | ret := m.ctrl.Call(m, 
"Set", id) 54 | ret0, _ := ret[0].(bool) 55 | ret1, _ := ret[1].(error) 56 | return ret0, ret1 57 | } 58 | 59 | // Set indicates an expected call of Set 60 | func (mr *MockStoreMockRecorder) Set(id interface{}) *gomock.Call { 61 | mr.mock.ctrl.T.Helper() 62 | return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Set", reflect.TypeOf((*MockStore)(nil).Set), id) 63 | } 64 | 65 | // Remove mocks base method 66 | func (m *MockStore) Remove(id packets.PacketID) error { 67 | m.ctrl.T.Helper() 68 | ret := m.ctrl.Call(m, "Remove", id) 69 | ret0, _ := ret[0].(error) 70 | return ret0 71 | } 72 | 73 | // Remove indicates an expected call of Remove 74 | func (mr *MockStoreMockRecorder) Remove(id interface{}) *gomock.Call { 75 | mr.mock.ctrl.T.Helper() 76 | return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Remove", reflect.TypeOf((*MockStore)(nil).Remove), id) 77 | } 78 | -------------------------------------------------------------------------------- /pkg/bitmap/bitmap.go: -------------------------------------------------------------------------------- 1 | package bitmap 2 | 3 | //MaxSize 最大支持的大小 4 | const MaxSize = uint16(65535) 5 | 6 | //Bitmap Bitmap结构体 7 | type Bitmap struct { 8 | vals []byte 9 | size uint16 10 | } 11 | 12 | //New 初始化一个Bitmap 13 | func New(size uint16) *Bitmap { 14 | if size == 0 || size >= MaxSize { 15 | size = MaxSize 16 | } else if remainder := size % 8; remainder != 0 { 17 | size += 8 - remainder 18 | } 19 | return &Bitmap{size: size, vals: make([]byte, size>>3+1)} 20 | } 21 | 22 | //Size 返回Bitmap大小 23 | func (b *Bitmap) Size() uint16 { 24 | return b.size 25 | } 26 | 27 | //Set 将offset位置的值设置为value(0/1) 28 | func (b *Bitmap) Set(offset uint16, value uint8) bool { 29 | if b.size < offset { 30 | return false 31 | } 32 | 33 | index, pos := offset>>3, offset&0x07 34 | 35 | if value == 0 { 36 | b.vals[index] &^= 0x01 << pos 37 | } else { 38 | b.vals[index] |= 0x01 << pos 39 | } 40 | 41 | return true 42 | } 43 | 44 | //Get 获取offset位置处的value值 45 | func 
(b *Bitmap) Get(offset uint16) uint8 { 46 | if b.size < offset { 47 | return 0 48 | } 49 | 50 | index, pos := offset>>3, offset&0x07 51 | 52 | return (b.vals[index] >> pos) & 0x01 53 | } 54 | -------------------------------------------------------------------------------- /pkg/bitmap/bitmap_test.go: -------------------------------------------------------------------------------- 1 | package bitmap 2 | 3 | import ( 4 | "testing" 5 | ) 6 | 7 | func TestBitmap(t *testing.T) { 8 | // Round-trip set/clear at a low offset and at the upper bound (offset == Size()). 9 | size := uint16(MaxSize) 10 | b := New(size) 11 | if b.Size() != size { 12 | t.Fatalf("wrong size %d", size) 13 | } 14 | 15 | b.Set(1, 1) 16 | if b.Get(1) != 1 { 17 | t.Fatalf("wrong value at bit %d", 1) 18 | } 19 | // Clearing bit 1 must be observed at bit 1 itself, not at an unrelated offset. 20 | b.Set(1, 0) 21 | if b.Get(1) != 0 { 22 | t.Fatalf("wrong value at bit %d", 1) 23 | } 24 | 25 | b.Set(size, 1) 26 | if b.Get(size) != 1 { 27 | t.Fatalf("wrong value at bit %d", size) 28 | } 29 | 30 | b.Set(size, 0) 31 | if b.Get(size) != 0 { 32 | t.Fatalf("wrong value at bit %d", size) 33 | } 34 | 35 | b.Set(MaxSize, 1) 36 | v := b.Get(MaxSize) 37 | if v != 1 { 38 | t.Fatalf("wrong value %d", v) 39 | } 40 | } 41 | -------------------------------------------------------------------------------- /pkg/packets/auth.go: -------------------------------------------------------------------------------- 1 | package packets 2 | 3 | import ( 4 | "bytes" 5 | "fmt" 6 | "io" 7 | 8 | "github.com/DrmagicE/gmqtt/pkg/codes" 9 | ) 10 | // Auth represents the MQTT AUTH packet. 11 | type Auth struct { 12 | FixHeader *FixHeader 13 | Code byte 14 | Properties *Properties 15 | } 16 | // String returns a human-readable summary of the packet. 17 | func (a *Auth) String() string { 18 | return fmt.Sprintf("Auth, Code: %v, Properties: %s", a.Code, a.Properties) 19 | } 20 | // Pack encodes the packet struct into bytes and writes it into io.Writer. The reason code and properties are omitted when the code is codes.Success and Properties is nil. 21 | func (a *Auth) Pack(w io.Writer) error { 22 | a.FixHeader = &FixHeader{PacketType: AUTH, Flags: FlagReserved} 23 | bufw := &bytes.Buffer{} 24 | if a.Code != codes.Success || a.Properties != nil { 25 | bufw.WriteByte(a.Code) 26 | a.Properties.Pack(bufw, AUTH) 27 | } 28 | a.FixHeader.RemainLength = bufw.Len() 29 | err :=
a.FixHeader.Pack(w) 30 | if err != nil { 31 | return err 32 | } 33 | _, err = bufw.WriteTo(w) 34 | return err 35 | } 36 | // Unpack reads the packet bytes from io.Reader and decodes them into the packet struct. A zero remaining length decodes as codes.Success without reading any properties. 37 | func (a *Auth) Unpack(r io.Reader) error { 38 | if a.FixHeader.RemainLength == 0 { 39 | a.Code = codes.Success 40 | return nil 41 | } 42 | restBuffer := make([]byte, a.FixHeader.RemainLength) 43 | _, err := io.ReadFull(r, restBuffer) 44 | if err != nil { 45 | return codes.ErrMalformed 46 | } 47 | bufr := bytes.NewBuffer(restBuffer) 48 | a.Code, err = bufr.ReadByte() 49 | if err != nil { 50 | return codes.ErrMalformed 51 | } 52 | if !ValidateCode(AUTH, a.Code) { 53 | return codes.ErrProtocol 54 | } 55 | a.Properties = &Properties{} 56 | return a.Properties.Unpack(bufr, AUTH) 57 | } 58 | // NewAuthPacket returns an Auth instance decoded from the given FixHeader and io.Reader. 59 | func NewAuthPacket(fh *FixHeader, r io.Reader) (*Auth, error) { 60 | p := &Auth{FixHeader: fh} 61 | // Reject packets whose fixed-header flags are not FlagReserved [MQTT-2.2.2-2]. 62 | if fh.Flags != FlagReserved { 63 | return nil, codes.ErrMalformed 64 | } 65 | err := p.Unpack(r) 66 | if err != nil { 67 | return nil, err 68 | } 69 | return p, err 70 | } 71 | -------------------------------------------------------------------------------- /pkg/packets/auth_test.go: -------------------------------------------------------------------------------- 1 | package packets 2 | 3 | import ( 4 | "bytes" 5 | "testing" 6 | 7 | "github.com/DrmagicE/gmqtt/pkg/codes" 8 | "github.com/stretchr/testify/assert" 9 | ) 10 | 11 | func TestReadWriteAuthPacket(t *testing.T) { 12 | tt := []struct { 13 | testname string 14 | code codes.Code 15 | properties *Properties 16 | want []byte 17 | }{ 18 | { 19 | testname: "omit properties when code = 0", 20 | code: codes.Success, 21 | properties: nil, 22 | want: []byte{0xF0, 0}, 23 | }, 24 | { 25 | testname: "code = 0 with properties", 26 | code: codes.Success, 27 | properties: &Properties{ 28 | ReasonString: []byte("a"), 29 | }, 30 | want: []byte{0xF0, 6, 0, 4, 0x1F, 0, 1, 'a'}, 31 | }, { 32 | testname: "code != 0 with properties", 33 | code: codes.NotAuthorized, 34 | properties:
&Properties{}, 35 | want: []byte{0xF0, 2, codes.NotAuthorized, 0}, 36 | }, 37 | } 38 | 39 | for _, v := range tt { 40 | t.Run(v.testname, func(t *testing.T) { 41 | a := assert.New(t) 42 | b := make([]byte, 0, 2048) 43 | buf := bytes.NewBuffer(b) 44 | au := &Auth{ 45 | Properties: v.properties, 46 | Code: v.code, 47 | } 48 | err := NewWriter(buf).WriteAndFlush(au) 49 | a.Nil(err) 50 | a.Equal(v.want, buf.Bytes()) 51 | 52 | bufr := bytes.NewBuffer(buf.Bytes()) 53 | p, err := NewReader(bufr).ReadPacket() 54 | a.Nil(err) 55 | rp := p.(*Auth) 56 | 57 | a.Equal(v.code, rp.Code) 58 | a.Equal(v.properties, rp.Properties) 59 | 60 | }) 61 | } 62 | 63 | } 64 | -------------------------------------------------------------------------------- /pkg/packets/connack.go: -------------------------------------------------------------------------------- 1 | package packets 2 | 3 | import ( 4 | "bytes" 5 | "fmt" 6 | "io" 7 | 8 | "github.com/DrmagicE/gmqtt/pkg/codes" 9 | ) 10 | 11 | // Connack represents the MQTT Connack packet 12 | type Connack struct { 13 | Version Version 14 | FixHeader *FixHeader 15 | Code codes.Code 16 | SessionPresent bool 17 | Properties *Properties 18 | } 19 | 20 | func (c *Connack) String() string { 21 | return fmt.Sprintf("Connack, Version: %v, Code:%v, SessionPresent:%v, Properties: %s", c.Version, c.Code, c.SessionPresent, c.Properties) 22 | } 23 | 24 | // Pack encodes the packet struct into bytes and writes it into io.Writer. 
25 | func (c *Connack) Pack(w io.Writer) error { 26 | var err error 27 | c.FixHeader = &FixHeader{PacketType: CONNACK, Flags: FlagReserved} 28 | bufw := &bytes.Buffer{} 29 | if c.SessionPresent { 30 | bufw.WriteByte(1) 31 | } else { 32 | bufw.WriteByte(0) 33 | } 34 | bufw.WriteByte(c.Code) 35 | if c.Version == Version5 { 36 | c.Properties.Pack(bufw, CONNACK) 37 | } 38 | c.FixHeader.RemainLength = bufw.Len() 39 | err = c.FixHeader.Pack(w) 40 | if err != nil { 41 | return err 42 | } 43 | _, err = bufw.WriteTo(w) 44 | return err 45 | } 46 | 47 | // Unpack reads the packet bytes from io.Reader and decodes them into the packet struct. The reason code is validated and properties are decoded only when c.Version is Version5. 48 | func (c *Connack) Unpack(r io.Reader) error { 49 | restBuffer := make([]byte, c.FixHeader.RemainLength) 50 | _, err := io.ReadFull(r, restBuffer) 51 | if err != nil { 52 | return codes.ErrMalformed 53 | } 54 | bufr := bytes.NewBuffer(restBuffer) 55 | sp, err := bufr.ReadByte() 56 | if err != nil || (127 & (sp >> 1)) > 0 { // a missing flags byte, or any of bits 7-1 set, is malformed 57 | return codes.ErrMalformed 58 | } 59 | c.SessionPresent = sp == 1 60 | 61 | code, err := bufr.ReadByte() 62 | if err != nil { 63 | return codes.ErrMalformed 64 | } 65 | 66 | c.Code = code 67 | if c.Version == Version5 { 68 | if !ValidateCode(CONNACK, code) { 69 | return codes.ErrProtocol 70 | } 71 | c.Properties = &Properties{} 72 | return c.Properties.Unpack(bufr, CONNACK) 73 | } 74 | return nil 75 | 76 | } 77 | 78 | // NewConnackPacket returns a Connack instance decoded from the given FixHeader and io.Reader according to the given protocol version. 79 | func NewConnackPacket(fh *FixHeader, version Version, r io.Reader) (*Connack, error) { 80 | p := &Connack{FixHeader: fh, Version: version} 81 | if fh.Flags != FlagReserved { 82 | return nil, codes.ErrMalformed 83 | } 84 | err := p.Unpack(r) 85 | if err != nil { 86 | return nil, err 87 | } 88 | return p, err 89 | } 90 | -------------------------------------------------------------------------------- /pkg/packets/connect_test.go: -------------------------------------------------------------------------------- 1 | package packets 2
| 3 | import ( 4 | "bytes" 5 | "testing" 6 | 7 | "github.com/stretchr/testify/assert" 8 | 9 | "github.com/DrmagicE/gmqtt/pkg/codes" 10 | ) 11 | 12 | func TestReadConnectPacketErr_V5(t *testing.T) { 13 | // [MQTT-3.1.2-3] The server MUST validate that the reserved flag (bit 0) of the CONNECT packet is 0, and MUST disconnect the client if it is not. 14 | a := assert.New(t) 15 | 16 | b := []byte{16, 12, 0, 4, 'M', 'Q', 'T', 'T', 05, 01, 00, 02, 31, 32} 17 | buf := bytes.NewBuffer(b) 18 | r := NewReader(buf) 19 | r.SetVersion(Version5) 20 | connectPacket, err := r.ReadPacket() 21 | a.Nil(connectPacket) 22 | a.Error(codes.ErrMalformed, err) 23 | // NOTE(review): a.Error(codes.ErrMalformed, err) treats err as a failure message and only asserts that codes.ErrMalformed is non-nil, so it never inspects err; consider asserting on err itself. 24 | } 25 | func TestReadConnectPacketErr_V311(t *testing.T) { 26 | // [MQTT-3.1.2-3] The server MUST validate that the reserved flag (bit 0) of the CONNECT packet is 0, and MUST disconnect the client if it is not. NOTE(review): the a.Error call below only checks that codes.ErrMalformed is non-nil, not err. 27 | a := assert.New(t) 28 | b := []byte{16, 12, 0, 4, 'M', 'Q', 'T', 'T', 04, 01, 00, 02, 31, 32} 29 | buf := bytes.NewBuffer(b) 30 | connectPacket, err := NewReader(buf).ReadPacket() 31 | a.Nil(connectPacket) 32 | a.Error(codes.ErrMalformed, err) 33 | } 34 | 35 | func TestReadConnect_V31(t *testing.T) { 36 | a := assert.New(t) 37 | b := []byte{0x10, 0x0f, 0, 0x06, 'M', 'Q', 'I', 's', 'd', 'p', 0x03, 0x02, 0x00, 0x0a, 0x00, 0x01, 0x74} 38 | buf := bytes.NewBuffer(b) 39 | connectPacket, err := NewReader(buf).ReadPacket() 40 | a.NoError(err) 41 | a.EqualValues(10, connectPacket.(*Connect).KeepAlive) 42 | } 43 | -------------------------------------------------------------------------------- /pkg/packets/disconnect.go: -------------------------------------------------------------------------------- 1 | package packets 2 | 3 | import ( 4 | "bytes" 5 | "fmt" 6 | "io" 7 | 8 | "github.com/DrmagicE/gmqtt/pkg/codes" 9 | ) 10 | 11 | // Disconnect represents the MQTT Disconnect packet 12 | type Disconnect struct { 13 | Version Version 14 | FixHeader *FixHeader 15 | // V5 16 | Code codes.Code 17 | Properties *Properties 18 | } 19 | 20 | func (d *Disconnect) String() string { 21 | return fmt.Sprintf("Disconnect, Version: %v, Code: %v, Properties: %s", d.Version, d.Code, d.Properties) 22 |
} 23 | 24 | // Pack encodes the packet struct into bytes and writes it into io.Writer. 25 | func (d *Disconnect) Pack(w io.Writer) error { 26 | var err error 27 | d.FixHeader = &FixHeader{PacketType: DISCONNECT, Flags: FlagReserved} 28 | if IsVersion3X(d.Version) { 29 | d.FixHeader.RemainLength = 0 30 | return d.FixHeader.Pack(w) 31 | } 32 | bufw := &bytes.Buffer{} 33 | if d.Code != codes.Success || d.Properties != nil { 34 | bufw.WriteByte(d.Code) 35 | d.Properties.Pack(bufw, DISCONNECT) 36 | } 37 | d.FixHeader.RemainLength = bufw.Len() 38 | err = d.FixHeader.Pack(w) 39 | if err != nil { 40 | return err 41 | } 42 | _, err = bufw.WriteTo(w) 43 | return err 44 | } 45 | 46 | // Unpack read the packet bytes from io.Reader and decodes it into the packet struct. 47 | func (d *Disconnect) Unpack(r io.Reader) error { 48 | restBuffer := make([]byte, d.FixHeader.RemainLength) 49 | _, err := io.ReadFull(r, restBuffer) 50 | if err != nil { 51 | return codes.ErrMalformed 52 | } 53 | if d.Version == Version5 { 54 | d.Properties = &Properties{} 55 | bufr := bytes.NewBuffer(restBuffer) 56 | if d.FixHeader.RemainLength == 0 { 57 | d.Code = codes.Success 58 | return nil 59 | } 60 | d.Code, err = bufr.ReadByte() 61 | if err != nil { 62 | return codes.ErrMalformed 63 | } 64 | if !ValidateCode(DISCONNECT, d.Code) { 65 | return codes.ErrProtocol 66 | } 67 | return d.Properties.Unpack(bufr, DISCONNECT) 68 | } 69 | return nil 70 | } 71 | 72 | // NewDisConnectPackets returns a Disconnect instance by the given FixHeader and io.Reader 73 | func NewDisConnectPackets(fh *FixHeader, version Version, r io.Reader) (*Disconnect, error) { 74 | if fh.Flags != 0 { 75 | return nil, codes.ErrMalformed 76 | } 77 | p := &Disconnect{FixHeader: fh, Version: version} 78 | err := p.Unpack(r) 79 | if err != nil { 80 | return nil, err 81 | } 82 | return p, nil 83 | } 84 | -------------------------------------------------------------------------------- /pkg/packets/disconnect_test.go: 
-------------------------------------------------------------------------------- 1 | package packets 2 | 3 | import ( 4 | "bytes" 5 | "testing" 6 | 7 | "github.com/DrmagicE/gmqtt/pkg/codes" 8 | "github.com/stretchr/testify/assert" 9 | ) 10 | 11 | func TestReadWriteDisconnectPacket_V5(t *testing.T) { 12 | tt := []struct { 13 | testname string 14 | code codes.Code 15 | properties *Properties 16 | want []byte 17 | }{ 18 | { 19 | testname: "omit properties when code = 0", 20 | code: codes.Success, 21 | properties: &Properties{}, 22 | want: []byte{0xe0, 0x02, 0x00, 0x00}, 23 | }, 24 | { 25 | testname: "code = 0 with properties", 26 | code: codes.Success, 27 | properties: &Properties{ 28 | ReasonString: []byte("a"), 29 | }, 30 | want: []byte{0xe0, 6, 0, 4, 0x1f, 0, 1, 'a'}, 31 | }, { 32 | testname: "code != 0 with properties", 33 | code: codes.NotAuthorized, 34 | properties: &Properties{}, 35 | want: []byte{0xe0, 2, codes.NotAuthorized, 0}, 36 | }, 37 | } 38 | 39 | for _, v := range tt { 40 | t.Run(v.testname, func(t *testing.T) { 41 | a := assert.New(t) 42 | b := make([]byte, 0, 2048) 43 | buf := bytes.NewBuffer(b) 44 | dis := &Disconnect{ 45 | Properties: v.properties, 46 | Code: v.code, 47 | } 48 | err := NewWriter(buf).WriteAndFlush(dis) 49 | a.Nil(err) 50 | a.Equal(v.want, buf.Bytes()) 51 | 52 | bufr := bytes.NewBuffer(buf.Bytes()) 53 | r := NewReader(bufr) 54 | r.SetVersion(Version5) 55 | p, err := r.ReadPacket() 56 | a.Nil(err) 57 | rp := p.(*Disconnect) 58 | 59 | a.Equal(v.code, rp.Code) 60 | a.Equal(v.properties, rp.Properties) 61 | 62 | }) 63 | } 64 | 65 | } 66 | 67 | func TestReadDisconnect_V311(t *testing.T) { 68 | a := assert.New(t) 69 | b := []byte{0xe0, 0} 70 | buf := bytes.NewBuffer(b) 71 | packet, err := NewReader(buf).ReadPacket() 72 | a.Nil(err) 73 | 74 | _, ok := packet.(*Disconnect) 75 | a.True(ok) 76 | } 77 | 78 | func TestWriteDisconnect_V311(t *testing.T) { 79 | a := assert.New(t) 80 | disconnect := &Disconnect{Version: Version311} 81 | buf := 
bytes.NewBuffer(make([]byte, 0, 2048)) 82 | err := NewWriter(buf).WriteAndFlush(disconnect) 83 | a.Nil(err) 84 | want := []byte{0xe0, 0} 85 | a.Equal(want, buf.Bytes()) 86 | } 87 | -------------------------------------------------------------------------------- /pkg/packets/packets_mock.go: -------------------------------------------------------------------------------- 1 | // Code generated by MockGen. DO NOT EDIT. 2 | // Source: pkg/packets/packets.go 3 | 4 | // Package packets is a generated GoMock package. 5 | package packets 6 | 7 | import ( 8 | gomock "github.com/golang/mock/gomock" 9 | io "io" 10 | reflect "reflect" 11 | ) 12 | 13 | // MockPacket is a mock of Packet interface 14 | type MockPacket struct { 15 | ctrl *gomock.Controller 16 | recorder *MockPacketMockRecorder 17 | } 18 | 19 | // MockPacketMockRecorder is the mock recorder for MockPacket 20 | type MockPacketMockRecorder struct { 21 | mock *MockPacket 22 | } 23 | 24 | // NewMockPacket creates a new mock instance 25 | func NewMockPacket(ctrl *gomock.Controller) *MockPacket { 26 | mock := &MockPacket{ctrl: ctrl} 27 | mock.recorder = &MockPacketMockRecorder{mock} 28 | return mock 29 | } 30 | 31 | // EXPECT returns an object that allows the caller to indicate expected use 32 | func (m *MockPacket) EXPECT() *MockPacketMockRecorder { 33 | return m.recorder 34 | } 35 | 36 | // Pack mocks base method 37 | func (m *MockPacket) Pack(w io.Writer) error { 38 | m.ctrl.T.Helper() 39 | ret := m.ctrl.Call(m, "Pack", w) 40 | ret0, _ := ret[0].(error) 41 | return ret0 42 | } 43 | 44 | // Pack indicates an expected call of Pack 45 | func (mr *MockPacketMockRecorder) Pack(w interface{}) *gomock.Call { 46 | mr.mock.ctrl.T.Helper() 47 | return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Pack", reflect.TypeOf((*MockPacket)(nil).Pack), w) 48 | } 49 | 50 | // Unpack mocks base method 51 | func (m *MockPacket) Unpack(r io.Reader) error { 52 | m.ctrl.T.Helper() 53 | ret := m.ctrl.Call(m, "Unpack", r) 54 | ret0, _ := 
ret[0].(error) 55 | return ret0 56 | } 57 | 58 | // Unpack indicates an expected call of Unpack 59 | func (mr *MockPacketMockRecorder) Unpack(r interface{}) *gomock.Call { 60 | mr.mock.ctrl.T.Helper() 61 | return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Unpack", reflect.TypeOf((*MockPacket)(nil).Unpack), r) 62 | } 63 | 64 | // String mocks base method 65 | func (m *MockPacket) String() string { 66 | m.ctrl.T.Helper() 67 | ret := m.ctrl.Call(m, "String") 68 | ret0, _ := ret[0].(string) 69 | return ret0 70 | } 71 | 72 | // String indicates an expected call of String 73 | func (mr *MockPacketMockRecorder) String() *gomock.Call { 74 | mr.mock.ctrl.T.Helper() 75 | return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "String", reflect.TypeOf((*MockPacket)(nil).String)) 76 | } 77 | -------------------------------------------------------------------------------- /pkg/packets/ping_test.go: -------------------------------------------------------------------------------- 1 | package packets 2 | 3 | import ( 4 | "bytes" 5 | "reflect" 6 | "testing" 7 | ) 8 | 9 | func TestReadPingreq(t *testing.T) { 10 | b := []byte{0xc0, 0} 11 | buf := bytes.NewBuffer(b) 12 | packet, err := NewReader(buf).ReadPacket() 13 | if err != nil { 14 | t.Fatalf("unexpected error: %s", err.Error()) 15 | } 16 | if _, ok := packet.(*Pingreq); !ok { 17 | t.Fatalf("Packet Type error,want %v,got %v", reflect.TypeOf(&Pingreq{}), reflect.TypeOf(packet)) 18 | } 19 | } 20 | 21 | func TestWritePingreq(t *testing.T) { 22 | req := &Pingreq{} 23 | buf := bytes.NewBuffer(make([]byte, 0, 2048)) 24 | err := NewWriter(buf).WriteAndFlush(req) 25 | if err != nil { 26 | t.Fatalf("unexpected error: %s", err.Error()) 27 | } 28 | want := []byte{0xc0, 0} 29 | if !bytes.Equal(buf.Bytes(), want) { 30 | t.Fatalf("write error,want %v, got %v", want, buf.Bytes()) 31 | } 32 | } 33 | 34 | func TestReadPingresp(t *testing.T) { 35 | b := []byte{0xd0, 0} 36 | buf := bytes.NewBuffer(b) 37 | packet, err := 
NewReader(buf).ReadPacket() 38 | if err != nil { 39 | t.Fatalf("unexpected error: %s", err.Error()) 40 | } 41 | if _, ok := packet.(*Pingresp); !ok { 42 | t.Fatalf("Packet Type error,want %v,got %v", reflect.TypeOf(&Pingresp{}), reflect.TypeOf(packet)) 43 | } 44 | } 45 | 46 | func TestWritePingresp(t *testing.T) { 47 | 48 | resp := &Pingresp{} 49 | buf := bytes.NewBuffer(make([]byte, 0, 2048)) 50 | err := NewWriter(buf).WriteAndFlush(resp) 51 | if err != nil { 52 | t.Fatalf("unexpected error: %s", err.Error()) 53 | } 54 | want := []byte{0xd0, 0} 55 | if !bytes.Equal(buf.Bytes(), want) { 56 | t.Fatalf("write error,want %v, got %v", want, buf.Bytes()) 57 | } 58 | } 59 | -------------------------------------------------------------------------------- /pkg/packets/pingreq.go: -------------------------------------------------------------------------------- 1 | package packets 2 | 3 | import ( 4 | "fmt" 5 | "io" 6 | 7 | "github.com/DrmagicE/gmqtt/pkg/codes" 8 | ) 9 | 10 | // Pingreq represents the MQTT Pingreq packet 11 | type Pingreq struct { 12 | FixHeader *FixHeader 13 | } 14 | 15 | func (p *Pingreq) String() string { 16 | return fmt.Sprintf("Pingreq") 17 | } 18 | 19 | // NewPingreqPacket returns a Pingreq instance by the given FixHeader and io.Reader 20 | func NewPingreqPacket(fh *FixHeader, r io.Reader) (*Pingreq, error) { 21 | if fh.Flags != FlagReserved { 22 | return nil, codes.ErrMalformed 23 | } 24 | p := &Pingreq{FixHeader: fh} 25 | err := p.Unpack(r) 26 | if err != nil { 27 | return nil, err 28 | } 29 | return p, nil 30 | } 31 | 32 | // NewPingresp returns a Pingresp struct 33 | func (p *Pingreq) NewPingresp() *Pingresp { 34 | fh := &FixHeader{PacketType: PINGRESP, Flags: 0, RemainLength: 0} 35 | return &Pingresp{FixHeader: fh} 36 | } 37 | 38 | // Pack encodes the packet struct into bytes and writes it into io.Writer. 
39 | func (p *Pingreq) Pack(w io.Writer) error { 40 | p.FixHeader = &FixHeader{PacketType: PINGREQ, Flags: 0, RemainLength: 0} 41 | return p.FixHeader.Pack(w) 42 | } 43 | 44 | // Unpack validates the body of a decoded PINGREQ. The packet has no body, so r is not read and any non-zero remaining length yields codes.ErrMalformed. 45 | func (p *Pingreq) Unpack(r io.Reader) error { 46 | if p.FixHeader.RemainLength != 0 { 47 | return codes.ErrMalformed 48 | } 49 | return nil 50 | } 51 | -------------------------------------------------------------------------------- /pkg/packets/pingresp.go: -------------------------------------------------------------------------------- 1 | package packets 2 | 3 | import ( 4 | "fmt" 5 | "io" 6 | 7 | "github.com/DrmagicE/gmqtt/pkg/codes" 8 | ) 9 | 10 | // Pingresp represents the MQTT Pingresp packet; it consists of a fixed header only. 11 | type Pingresp struct { 12 | FixHeader *FixHeader 13 | } 14 | // String implements fmt.Stringer and returns the packet name. 15 | func (p *Pingresp) String() string { 16 | return fmt.Sprintf("Pingresp") 17 | } 18 | 19 | // Pack writes the PINGRESP fixed header (flags 0, remaining length 0) to w; there is no variable header or payload. 20 | func (p *Pingresp) Pack(w io.Writer) error { 21 | p.FixHeader = &FixHeader{PacketType: PINGRESP, Flags: 0, RemainLength: 0} 22 | return p.FixHeader.Pack(w) 23 | } 24 | 25 | // Unpack validates the body of a decoded PINGRESP. The packet has no body, so r is not read and any non-zero remaining length yields codes.ErrMalformed.
26 | func (p *Pingresp) Unpack(r io.Reader) error { 27 | if p.FixHeader.RemainLength != 0 { 28 | return codes.ErrMalformed 29 | } 30 | return nil 31 | } 32 | 33 | // NewPingrespPacket decodes a Pingresp from fh and r. It returns codes.ErrMalformed if the fixed-header flags are not FlagReserved or the remaining length is not 0. 34 | func NewPingrespPacket(fh *FixHeader, r io.Reader) (*Pingresp, error) { 35 | if fh.Flags != FlagReserved { 36 | return nil, codes.ErrMalformed 37 | } 38 | p := &Pingresp{FixHeader: fh} 39 | err := p.Unpack(r) 40 | if err != nil { 41 | return nil, err 42 | } 43 | return p, nil 44 | } 45 | -------------------------------------------------------------------------------- /pkg/packets/puback.go: -------------------------------------------------------------------------------- 1 | package packets 2 | 3 | import ( 4 | "bytes" 5 | "fmt" 6 | "io" 7 | 8 | "github.com/DrmagicE/gmqtt/pkg/codes" 9 | ) 10 | 11 | // Puback represents the MQTT Puback packet. 12 | type Puback struct { 13 | Version Version 14 | FixHeader *FixHeader 15 | PacketID PacketID 16 | // V5 only: Code and Properties are encoded and decoded only when Version is Version5. 17 | Code codes.Code 18 | Properties *Properties 19 | } 20 | // String implements fmt.Stringer; the output includes the version, packet identifier and properties. 21 | func (p *Puback) String() string { 22 | return fmt.Sprintf("Puback, Version: %v, Pid: %v, Properties: %s", p.Version, p.PacketID, p.Properties) 23 | } 24 | 25 | // NewPubackPacket decodes a Puback for the given protocol version from fh and r, returning the error reported by Unpack if decoding fails. 26 | func NewPubackPacket(fh *FixHeader, version Version, r io.Reader) (*Puback, error) { 27 | p := &Puback{FixHeader: fh, Version: version} 28 | err := p.Unpack(r) 29 | if err != nil { 30 | return nil, err 31 | } 32 | return p, nil 33 | } 34 | 35 | // Pack encodes p and writes it to w. For Version5 the reason code and properties follow the packet identifier unless Code is codes.Success and Properties is nil.
36 | func (p *Puback) Pack(w io.Writer) error { 37 | p.FixHeader = &FixHeader{PacketType: PUBACK, Flags: FlagReserved} 38 | bufw := &bytes.Buffer{} 39 | writeUint16(bufw, p.PacketID) 40 | if p.Version == Version5 && (p.Code != codes.Success || p.Properties != nil) { 41 | bufw.WriteByte(p.Code) 42 | p.Properties.Pack(bufw, PUBACK) 43 | 44 | } 45 | p.FixHeader.RemainLength = bufw.Len() 46 | err := p.FixHeader.Pack(w) 47 | if err != nil { 48 | return err 49 | } 50 | _, err = bufw.WriteTo(w) 51 | return err 52 | } 53 | 54 | // Unpack reads the remaining PUBACK bytes from r and decodes them into p. A remaining length of 2 means only the packet identifier is present and the reason code is codes.Success; for Version5 a reason code and optional properties follow. 55 | func (p *Puback) Unpack(r io.Reader) error { 56 | restBuffer := make([]byte, p.FixHeader.RemainLength) 57 | _, err := io.ReadFull(r, restBuffer) 58 | if err != nil { 59 | return codes.ErrMalformed 60 | } 61 | bufr := bytes.NewBuffer(restBuffer) 62 | 63 | p.PacketID, err = readUint16(bufr) 64 | if err != nil { 65 | return err 66 | } 67 | if p.FixHeader.RemainLength == 2 { 68 | p.Code = codes.Success 69 | return nil 70 | } 71 | 72 | if p.Version == Version5 { 73 | p.Properties = &Properties{} 74 | if p.Code, err = bufr.ReadByte(); err != nil { 75 | return codes.ErrMalformed 76 | } 77 | if !ValidateCode(PUBACK, p.Code) { 78 | return codes.ErrProtocol 79 | } 80 | // MQTT 5.0 section 3.4.2.2.1: if the remaining length is less than 4 there is no property length and 0 is used, so a PUBACK carrying only a reason code is valid. 81 | if bufr.Len() == 0 { 82 | return nil 83 | } 84 | if err := p.Properties.Unpack(bufr, PUBACK); err != nil { 85 | return err 86 | } 87 | } 88 | return nil 89 | } 90 | -------------------------------------------------------------------------------- /pkg/packets/puback_test.go: -------------------------------------------------------------------------------- 1 | package packets 2 | 3 | import ( 4 | "bytes" 5 | "io" 6 | "reflect" 7 | "testing" 8 | 9 | "github.com/DrmagicE/gmqtt/pkg/codes" 10 | "github.com/stretchr/testify/assert" 11 | ) 12 | // TestReadWritePubackPacket_V5 encodes v5 PUBACK packets, compares them with the expected bytes and decodes them back. 13 | func TestReadWritePubackPacket_V5(t *testing.T) { 14 | tt := []struct { 15 | testname string 16 | pid PacketID 17 | code codes.Code 18 | properties *Properties 19 | want []byte 20 | }{ 21 | { 22 | testname: "omit
properties when code = 0", 23 | pid: 10, 24 | code: codes.Success, 25 | properties: nil, 26 | want: []byte{64, 2, 0, 10}, 27 | }, 28 | { 29 | testname: "code = 0 with properties", 30 | pid: 10, 31 | code: codes.Success, 32 | properties: &Properties{ 33 | ReasonString: []byte("a"), 34 | }, 35 | want: []byte{64, 8, 0, 10, 0, 4, 0x1f, 0, 1, 'a'}, 36 | }, { 37 | testname: "code != 0 with properties", 38 | pid: 10, 39 | code: codes.NotAuthorized, 40 | properties: &Properties{}, 41 | want: []byte{64, 4, 0, 10, codes.NotAuthorized, 0}, 42 | }, 43 | } 44 | 45 | for _, v := range tt { 46 | t.Run(v.testname, func(t *testing.T) { 47 | a := assert.New(t) 48 | b := make([]byte, 0, 2048) 49 | buf := bytes.NewBuffer(b) 50 | puback := &Puback{ 51 | Version: Version5, 52 | PacketID: v.pid, 53 | Properties: v.properties, 54 | Code: v.code, 55 | } 56 | err := NewWriter(buf).WriteAndFlush(puback) 57 | a.Nil(err) 58 | a.Equal(v.want, buf.Bytes()) 59 | 60 | bufr := bytes.NewBuffer(buf.Bytes()) 61 | r := NewReader(bufr) 62 | r.SetVersion(Version5) 63 | p, err := r.ReadPacket() 64 | a.Nil(err) 65 | rp := p.(*Puback) 66 | 67 | a.Equal(v.code, rp.Code) 68 | a.Equal(v.properties, rp.Properties) 69 | a.Equal(v.pid, rp.PacketID) 70 | 71 | }) 72 | } 73 | 74 | } 75 | 76 | func TestWritePubackPacket_V311(t *testing.T) { 77 | a := assert.New(t) 78 | b := make([]byte, 0, 2048) 79 | buf := bytes.NewBuffer(b) 80 | pid := uint16(65535) 81 | puback := &Puback{ 82 | Version: Version311, 83 | PacketID: pid, 84 | } 85 | err := NewWriter(buf).WriteAndFlush(puback) 86 | a.Nil(err) 87 | packet, err := NewReader(buf).ReadPacket() 88 | a.Nil(err) 89 | _, err = buf.ReadByte() 90 | a.Equal(io.EOF, err) 91 | 92 | if p, ok := packet.(*Puback); ok { 93 | a.EqualValues(pid, p.PacketID) 94 | } else { 95 | t.Fatalf("Packet type error,want %v,got %v", reflect.TypeOf(&Puback{}), reflect.TypeOf(packet)) 96 | } 97 | 98 | } 99 | 100 | func TestReadPubackPacket(t *testing.T) { 101 | a := assert.New(t) 102 | pubackBytes :=
bytes.NewBuffer([]byte{64, 2, 0, 1}) 103 | packet, err := NewReader(pubackBytes).ReadPacket() 104 | a.Nil(err) 105 | 106 | if p, ok := packet.(*Puback); ok { 107 | a.EqualValues(1, p.PacketID) 108 | } else { 109 | t.Fatalf("Packet Type error,want %v,got %v", reflect.TypeOf(&Puback{}), reflect.TypeOf(packet)) 110 | } 111 | } 112 | 113 | // TestReadPubackPacket_V5_NoProperties checks that a v5 PUBACK with remaining length 3 (packet identifier and reason code, no property length) is accepted, as allowed by MQTT 5.0 section 3.4.2.2.1. 114 | func TestReadPubackPacket_V5_NoProperties(t *testing.T) { 115 | a := assert.New(t) 116 | r := NewReader(bytes.NewBuffer([]byte{64, 3, 0, 1, codes.NotAuthorized})) 117 | r.SetVersion(Version5) 118 | packet, err := r.ReadPacket() 119 | a.Nil(err) 120 | p, ok := packet.(*Puback) 121 | if !ok { 122 | t.Fatalf("Packet Type error,want %v,got %v", reflect.TypeOf(&Puback{}), reflect.TypeOf(packet)) 123 | } 124 | a.EqualValues(1, p.PacketID) 125 | a.EqualValues(codes.NotAuthorized, p.Code) 126 | } 127 | -------------------------------------------------------------------------------- /pkg/packets/pubcomp.go: -------------------------------------------------------------------------------- 1 | package packets 2 | 3 | import ( 4 | "bytes" 5 | "fmt" 6 | "io" 7 | 8 | "github.com/DrmagicE/gmqtt/pkg/codes" 9 | ) 10 | 11 | // Pubcomp represents the MQTT Pubcomp packet. 12 | type Pubcomp struct { 13 | Version Version 14 | FixHeader *FixHeader 15 | PacketID PacketID 16 | Code byte 17 | Properties *Properties 18 | } 19 | // String implements fmt.Stringer; the output includes the version, packet identifier and properties. 20 | func (p *Pubcomp) String() string { 21 | return fmt.Sprintf("Pubcomp, Version: %v, Pid: %v, Properties: %s", p.Version, p.PacketID, p.Properties) 22 | } 23 | 24 | // NewPubcompPacket decodes a Pubcomp for the given protocol version from fh and r, returning the error reported by Unpack if decoding fails. 25 | func NewPubcompPacket(fh *FixHeader, version Version, r io.Reader) (*Pubcomp, error) { 26 | p := &Pubcomp{FixHeader: fh, Version: version} 27 | err := p.Unpack(r) 28 | if err != nil { 29 | return nil, err 30 | } 31 | return p, nil 32 | } 33 | 34 | // Pack encodes p and writes it to w. For Version5 the reason code and properties follow the packet identifier unless Code is codes.Success and Properties is nil.
35 | func (p *Pubcomp) Pack(w io.Writer) error { 36 | p.FixHeader = &FixHeader{PacketType: PUBCOMP, Flags: FlagReserved} 37 | bufw := &bytes.Buffer{} 38 | writeUint16(bufw, p.PacketID) 39 | if p.Version == Version5 && (p.Code != codes.Success || p.Properties != nil) { 40 | bufw.WriteByte(p.Code) 41 | p.Properties.Pack(bufw, PUBCOMP) 42 | } 43 | p.FixHeader.RemainLength = bufw.Len() 44 | err := p.FixHeader.Pack(w) 45 | if err != nil { 46 | return err 47 | } 48 | _, err = bufw.WriteTo(w) 49 | return err 50 | } 51 | 52 | // Unpack reads the remaining PUBCOMP bytes from r and decodes them into p. A remaining length of 2 means only the packet identifier is present and the reason code is codes.Success; for Version5 a reason code and optional properties follow. 53 | func (p *Pubcomp) Unpack(r io.Reader) error { 54 | restBuffer := make([]byte, p.FixHeader.RemainLength) 55 | _, err := io.ReadFull(r, restBuffer) 56 | if err != nil { 57 | return codes.ErrMalformed 58 | } 59 | bufr := bytes.NewBuffer(restBuffer) 60 | p.PacketID, err = readUint16(bufr) 61 | if err != nil { 62 | return err 63 | } 64 | if p.FixHeader.RemainLength == 2 { 65 | p.Code = codes.Success 66 | return nil 67 | } 68 | if p.Version == Version5 { 69 | p.Properties = &Properties{} 70 | if p.Code, err = bufr.ReadByte(); err != nil { 71 | return codes.ErrMalformed 72 | } 73 | if !ValidateCode(PUBCOMP, p.Code) { 74 | return codes.ErrProtocol 75 | } 76 | // MQTT 5.0 section 3.7.2.2.1: if the remaining length is less than 4 there is no property length and 0 is used, so a PUBCOMP carrying only a reason code is valid. 77 | if bufr.Len() == 0 { 78 | return nil 79 | } 80 | return p.Properties.Unpack(bufr, PUBCOMP) 81 | } 82 | return nil 83 | } 84 | -------------------------------------------------------------------------------- /pkg/packets/pubcomp_test.go: -------------------------------------------------------------------------------- 1 | package packets 2 | 3 | import ( 4 | "bytes" 5 | "io" 6 | "reflect" 7 | "testing" 8 | 9 | "github.com/DrmagicE/gmqtt/pkg/codes" 10 | "github.com/stretchr/testify/assert" 11 | ) 12 | // TestReadWritePubcompPacket_v5 encodes v5 PUBCOMP packets, compares them with the expected bytes and decodes them back. 13 | func TestReadWritePubcompPacket_v5(t *testing.T) { 14 | tt := []struct { 15 | testname string 16 | pid PacketID 17 | code codes.Code 18 | properties *Properties 19 | want []byte 20 | }{ 21 | { 22 | testname: "omit properties when code = 0", 23 | pid: 10, 24 |
code: codes.Success, 25 | properties: nil, 26 | want: []byte{112, 2, 0, 10}, 27 | }, 28 | { 29 | testname: "code = 0 with properties", 30 | pid: 10, 31 | code: codes.Success, 32 | properties: &Properties{ 33 | ReasonString: []byte("a"), 34 | }, 35 | want: []byte{112, 8, 0, 10, 0, 4, 0x1f, 0, 1, 'a'}, 36 | }, { 37 | testname: "code != 0 with properties", 38 | pid: 10, 39 | code: codes.NotAuthorized, 40 | properties: &Properties{}, 41 | want: []byte{112, 4, 0, 10, codes.NotAuthorized, 0}, 42 | }, 43 | } 44 | 45 | for _, v := range tt { 46 | t.Run(v.testname, func(t *testing.T) { 47 | a := assert.New(t) 48 | b := make([]byte, 0, 2048) 49 | buf := bytes.NewBuffer(b) 50 | pubcomp := &Pubcomp{ 51 | Version: Version5, 52 | PacketID: v.pid, 53 | Properties: v.properties, 54 | Code: v.code, 55 | } 56 | err := NewWriter(buf).WriteAndFlush(pubcomp) 57 | a.Nil(err) 58 | a.Equal(v.want, buf.Bytes()) 59 | 60 | bufr := bytes.NewBuffer(buf.Bytes()) 61 | r := NewReader(bufr) 62 | r.SetVersion(Version5) 63 | p, err := r.ReadPacket() 64 | a.Nil(err) 65 | rp := p.(*Pubcomp) 66 | 67 | a.Equal(v.code, rp.Code) 68 | a.Equal(v.properties, rp.Properties) 69 | a.Equal(v.pid, rp.PacketID) 70 | 71 | }) 72 | } 73 | 74 | } 75 | 76 | func TestWritePubcompPacket_V311(t *testing.T) { 77 | a := assert.New(t) 78 | b := make([]byte, 0, 2048) 79 | buf := bytes.NewBuffer(b) 80 | pid := uint16(65535) 81 | pubcomp := &Pubcomp{ 82 | Version: Version311, 83 | PacketID: pid, 84 | } 85 | err := NewWriter(buf).WriteAndFlush(pubcomp) 86 | a.Nil(err) 87 | packet, err := NewReader(buf).ReadPacket() 88 | a.Nil(err) 89 | _, err = buf.ReadByte() 90 | a.Equal(io.EOF, err) 91 | 92 | if p, ok := packet.(*Pubcomp); ok { 93 | a.EqualValues(pid, p.PacketID) 94 | } else { 95 | t.Fatalf("Packet type error,want %v,got %v", reflect.TypeOf(&Pubcomp{}), reflect.TypeOf(packet)) 96 | } 97 | 98 | } 99 | 100 | func TestReadPubcompPacket_V311(t *testing.T) { 101 | a := assert.New(t) 102 | pubcompBytes := bytes.NewBuffer([]byte{0x70,
2, 0, 1}) 103 | packet, err := NewReader(pubcompBytes).ReadPacket() 104 | a.Nil(err) 105 | if p, ok := packet.(*Pubcomp); ok { 106 | a.EqualValues(1, p.PacketID) 107 | } else { 108 | t.Fatalf("Packet Type error,want %v,got %v", reflect.TypeOf(&Pubcomp{}), reflect.TypeOf(packet)) 109 | } 110 | } 111 | -------------------------------------------------------------------------------- /pkg/packets/pubrec.go: -------------------------------------------------------------------------------- 1 | package packets 2 | 3 | import ( 4 | "bytes" 5 | "fmt" 6 | "io" 7 | 8 | "github.com/DrmagicE/gmqtt/pkg/codes" 9 | ) 10 | 11 | // Pubrec represents the MQTT Pubrec packet. 12 | type Pubrec struct { 13 | Version Version 14 | FixHeader *FixHeader 15 | PacketID PacketID 16 | // V5 only: Code and Properties are encoded and decoded only when Version is Version5. 17 | Code byte 18 | Properties *Properties 19 | } 20 | // String implements fmt.Stringer; the output includes the version, reason code, packet identifier and properties. 21 | func (p *Pubrec) String() string { 22 | return fmt.Sprintf("Pubrec, Version: %v, Code %v, Pid: %v, Properties: %s", p.Version, p.Code, p.PacketID, p.Properties) 23 | } 24 | 25 | // NewPubrecPacket decodes a Pubrec for the given protocol version from fh and r, returning the error reported by Unpack if decoding fails. 26 | func NewPubrecPacket(fh *FixHeader, version Version, r io.Reader) (*Pubrec, error) { 27 | p := &Pubrec{FixHeader: fh, Version: version} 28 | err := p.Unpack(r) 29 | if err != nil { 30 | return nil, err 31 | } 32 | return p, nil 33 | } 34 | 35 | // NewPubrel returns the Pubrel that answers this Pubrec in the QoS 2 flow, carrying the same packet identifier. 36 | func (p *Pubrec) NewPubrel() *Pubrel { 37 | pub := &Pubrel{FixHeader: &FixHeader{PacketType: PUBREL, Flags: FlagPubrel}} 38 | pub.PacketID = p.PacketID 39 | return pub 40 | } 41 | 42 | // Pack encodes p and writes it to w. For Version5 the reason code and properties follow the packet identifier unless Code is codes.Success and Properties is nil.
43 | func (p *Pubrec) Pack(w io.Writer) error { 44 | p.FixHeader = &FixHeader{PacketType: PUBREC, Flags: FlagReserved} 45 | bufw := &bytes.Buffer{} 46 | writeUint16(bufw, p.PacketID) 47 | if p.Version == Version5 && (p.Code != codes.Success || p.Properties != nil) { 48 | bufw.WriteByte(p.Code) 49 | p.Properties.Pack(bufw, PUBREC) 50 | } 51 | p.FixHeader.RemainLength = bufw.Len() 52 | err := p.FixHeader.Pack(w) 53 | if err != nil { 54 | return err 55 | } 56 | _, err = bufw.WriteTo(w) 57 | return err 58 | } 59 | 60 | // Unpack reads the remaining PUBREC bytes from r and decodes them into p. A remaining length of 2 means only the packet identifier is present and the reason code is codes.Success; for Version5 a reason code and optional properties follow. 61 | func (p *Pubrec) Unpack(r io.Reader) error { 62 | restBuffer := make([]byte, p.FixHeader.RemainLength) 63 | _, err := io.ReadFull(r, restBuffer) 64 | if err != nil { 65 | return codes.ErrMalformed 66 | } 67 | bufr := bytes.NewBuffer(restBuffer) 68 | p.PacketID, err = readUint16(bufr) 69 | if err != nil { 70 | return err 71 | } 72 | if p.FixHeader.RemainLength == 2 { 73 | p.Code = codes.Success 74 | return nil 75 | } 76 | if p.Version == Version5 { 77 | p.Properties = &Properties{} 78 | if p.Code, err = bufr.ReadByte(); err != nil { 79 | return codes.ErrMalformed 80 | } 81 | if !ValidateCode(PUBREC, p.Code) { 82 | return codes.ErrProtocol 83 | } 84 | // MQTT 5.0 section 3.5.2.2.1: if the remaining length is less than 4 there is no property length and 0 is used, so a PUBREC carrying only a reason code is valid. 85 | if bufr.Len() == 0 { 86 | return nil 87 | } 88 | return p.Properties.Unpack(bufr, PUBREC) 89 | } 90 | return nil 91 | } 92 | -------------------------------------------------------------------------------- /pkg/packets/pubrel.go: -------------------------------------------------------------------------------- 1 | package packets 2 | 3 | import ( 4 | "bytes" 5 | "fmt" 6 | "io" 7 | 8 | "github.com/DrmagicE/gmqtt/pkg/codes" 9 | ) 10 | 11 | // Pubrel represents the MQTT Pubrel packet. 12 | type Pubrel struct { 13 | FixHeader *FixHeader 14 | PacketID PacketID 15 | // V5 fields; Pubrel has no Version, so a non-success Code or non-nil Properties is always encoded by Pack. 16 | Code codes.Code 17 | Properties *Properties 18 | } 19 | // String implements fmt.Stringer; the output includes the reason code, packet identifier and properties. 20 | func (p *Pubrel) String() string { 21 | return fmt.Sprintf("Pubrel, Code: %v, Pid: %v, Properties: %s", p.Code, p.PacketID,
p.Properties) 22 | } 23 | 24 | // NewPubrelPacket decodes a Pubrel from fh and r, returning the error reported by Unpack if decoding fails. 25 | func NewPubrelPacket(fh *FixHeader, r io.Reader) (*Pubrel, error) { 26 | p := &Pubrel{FixHeader: fh} 27 | err := p.Unpack(r) 28 | if err != nil { 29 | return nil, err 30 | } 31 | return p, nil 32 | } 33 | 34 | // NewPubcomp returns the Pubcomp that completes the QoS 2 flow for this Pubrel, carrying the same packet identifier. 35 | func (p *Pubrel) NewPubcomp() *Pubcomp { 36 | pub := &Pubcomp{FixHeader: &FixHeader{PacketType: PUBCOMP, Flags: FlagReserved, RemainLength: 2}} 37 | pub.PacketID = p.PacketID 38 | return pub 39 | } 40 | 41 | // Pack encodes p and writes it to w. The reason code and properties follow the packet identifier unless Code is codes.Success and Properties is nil. 42 | func (p *Pubrel) Pack(w io.Writer) error { 43 | p.FixHeader = &FixHeader{PacketType: PUBREL, Flags: FlagPubrel} 44 | bufw := &bytes.Buffer{} 45 | writeUint16(bufw, p.PacketID) 46 | if p.Code != codes.Success || p.Properties != nil { 47 | bufw.WriteByte(p.Code) 48 | p.Properties.Pack(bufw, PUBREL) 49 | } 50 | p.FixHeader.RemainLength = bufw.Len() 51 | err := p.FixHeader.Pack(w) 52 | if err != nil { 53 | return err 54 | } 55 | _, err = bufw.WriteTo(w) 56 | return err 57 | } 58 | 59 | // Unpack reads the remaining PUBREL bytes from r and decodes them into p. A remaining length of 2 means only the packet identifier is present and the reason code is codes.Success; otherwise a reason code and optional properties follow.
60 | func (p *Pubrel) Unpack(r io.Reader) error { 61 | restBuffer := make([]byte, p.FixHeader.RemainLength) 62 | _, err := io.ReadFull(r, restBuffer) 63 | if err != nil { 64 | return codes.ErrMalformed 65 | } 66 | bufr := bytes.NewBuffer(restBuffer) 67 | p.PacketID, err = readUint16(bufr) 68 | if err != nil { 69 | return err 70 | } 71 | if p.FixHeader.RemainLength == 2 { 72 | p.Code = codes.Success 73 | return nil 74 | } 75 | p.Properties = &Properties{} 76 | if p.Code, err = bufr.ReadByte(); err != nil { 77 | return codes.ErrMalformed 78 | } 79 | if !ValidateCode(PUBREL, p.Code) { 80 | return codes.ErrProtocol 81 | } 82 | // MQTT 5.0 section 3.6.2.2.1: if the remaining length is less than 4 there is no property length and 0 is used, so a PUBREL carrying only a reason code is valid. 83 | if bufr.Len() == 0 { 84 | return nil 85 | } 86 | return p.Properties.Unpack(bufr, PUBREL) 87 | } 88 | -------------------------------------------------------------------------------- /pkg/packets/pubrel_test.go: -------------------------------------------------------------------------------- 1 | package packets 2 | 3 | import ( 4 | "bytes" 5 | "testing" 6 | 7 | "github.com/DrmagicE/gmqtt/pkg/codes" 8 | "github.com/stretchr/testify/assert" 9 | ) 10 | 11 | func TestReadWritePubrelPacket(t *testing.T) { 12 | tt := []struct { 13 | testname string 14 | pid PacketID 15 | code codes.Code 16 | properties *Properties 17 | want []byte 18 | }{ 19 | { 20 | testname: "omit properties when code = 0", 21 | pid: 10, 22 | code: codes.Success, 23 | properties: nil, 24 | want: []byte{98, 2, 0, 10}, 25 | }, 26 | { 27 | testname: "code = 0 with properties", 28 | pid: 10, 29 | code: codes.Success, 30 | properties: &Properties{ 31 | ReasonString: []byte("a"), 32 | }, 33 | want: []byte{98, 8, 0, 10, 0, 4, 0x1f, 0, 1, 'a'}, 34 | }, { 35 | testname: "code != 0 with properties", 36 | pid: 10, 37 | code: codes.NotAuthorized, 38 | properties: &Properties{}, 39 | want: []byte{98, 4, 0, 10, codes.NotAuthorized, 0}, 40 | }, 41 | } 42 | 43 | for _, v := range tt { 44 | t.Run(v.testname, func(t *testing.T) { 45 | a := assert.New(t) 46 | b := make([]byte, 0, 2048) 47 | buf := bytes.NewBuffer(b) 48 | pubrel := &Pubrel{ 49 | PacketID: v.pid, 50
| Properties: v.properties, 51 | Code: v.code, 52 | } 53 | err := NewWriter(buf).WriteAndFlush(pubrel) 54 | a.Nil(err) 55 | a.Equal(v.want, buf.Bytes()) 56 | 57 | bufr := bytes.NewBuffer(buf.Bytes()) 58 | 59 | p, err := NewReader(bufr).ReadPacket() 60 | a.Nil(err) 61 | rp := p.(*Pubrel) 62 | 63 | a.Equal(v.code, rp.Code) 64 | a.Equal(v.properties, rp.Properties) 65 | a.Equal(v.pid, rp.PacketID) 66 | 67 | }) 68 | } 69 | 70 | } 71 | 72 | func TestPubrel_NewPubcomp(t *testing.T) { 73 | a := assert.New(t) 74 | pid := uint16(10) 75 | pubrel := &Pubrel{ 76 | PacketID: pid, 77 | } 78 | pubcomp := pubrel.NewPubcomp() 79 | a.Equal(pid, pubcomp.PacketID) 80 | } 81 | -------------------------------------------------------------------------------- /pkg/packets/suback.go: -------------------------------------------------------------------------------- 1 | package packets 2 | 3 | import ( 4 | "bytes" 5 | "fmt" 6 | "io" 7 | 8 | "github.com/DrmagicE/gmqtt/pkg/codes" 9 | ) 10 | 11 | // Suback represents the MQTT Suback packet. 12 | type Suback struct { 13 | Version Version 14 | FixHeader *FixHeader 15 | PacketID PacketID 16 | Payload []codes.Code 17 | Properties *Properties 18 | } 19 | // String implements fmt.Stringer; the output includes the version, packet identifier, return codes and properties. 20 | func (p *Suback) String() string { 21 | return fmt.Sprintf("Suback,Version: %v, Pid: %v, Payload: %v, Properties: %s", p.Version, p.PacketID, p.Payload, p.Properties) 22 | } 23 | 24 | // NewSubackPacket decodes a Suback for the given protocol version from fh and r. It returns codes.ErrMalformed if the fixed-header flags are not FlagReserved; if Unpack fails, the partially decoded packet is returned together with the error. 25 | func NewSubackPacket(fh *FixHeader, version Version, r io.Reader) (*Suback, error) { 26 | p := &Suback{FixHeader: fh, Version: version} 27 | // The SUBACK fixed-header flags are reserved and must equal FlagReserved (MQTT 3.1.1 section 2.2.2). 28 | if fh.Flags != FlagReserved { 29 | return nil, codes.ErrMalformed 30 | } 31 | err := p.Unpack(r) 32 | return p, err 33 | } 34 | 35 | // Pack encodes p and writes it to w: the packet identifier, then (Version5 only) the properties, then one return code per entry in Payload.
36 | func (p *Suback) Pack(w io.Writer) error { 37 | p.FixHeader = &FixHeader{PacketType: SUBACK, Flags: FlagReserved} 38 | bufw := &bytes.Buffer{} 39 | writeUint16(bufw, p.PacketID) 40 | if p.Version == Version5 { 41 | p.Properties.Pack(bufw, SUBACK) 42 | } 43 | bufw.Write(p.Payload) 44 | p.FixHeader.RemainLength = bufw.Len() 45 | err := p.FixHeader.Pack(w) 46 | if err != nil { 47 | return err 48 | } 49 | _, err = bufw.WriteTo(w) 50 | return err 51 | 52 | } 53 | 54 | // Unpack reads the remaining SUBACK bytes from r and decodes them into p: the packet identifier, the properties (Version5 only), then at least one return code, each checked with ValidateCode. A missing return code yields codes.ErrMalformed; an invalid one yields codes.ErrProtocol. 55 | func (p *Suback) Unpack(r io.Reader) error { 56 | restBuffer := make([]byte, p.FixHeader.RemainLength) 57 | _, err := io.ReadFull(r, restBuffer) 58 | if err != nil { 59 | return codes.ErrMalformed 60 | } 61 | bufr := bytes.NewBuffer(restBuffer) 62 | 63 | p.PacketID, err = readUint16(bufr) 64 | if err != nil { 65 | return codes.ErrMalformed 66 | } 67 | if p.Version == Version5 { 68 | p.Properties = &Properties{} 69 | err = p.Properties.Unpack(bufr, SUBACK) 70 | if err != nil { 71 | return err 72 | } 73 | } 74 | for { 75 | b, err := bufr.ReadByte() 76 | if err != nil { 77 | return codes.ErrMalformed 78 | } 79 | if !ValidateCode(SUBACK, b) { 80 | return codes.ErrProtocol 81 | } 82 | p.Payload = append(p.Payload, b) 83 | if bufr.Len() == 0 { 84 | return nil 85 | } 86 | } 87 | } 88 | -------------------------------------------------------------------------------- /pkg/packets/suback_test.go: -------------------------------------------------------------------------------- 1 | package packets 2 | 3 | import ( 4 | "bytes" 5 | "reflect" 6 | "testing" 7 | 8 | "github.com/DrmagicE/gmqtt/pkg/codes" 9 | "github.com/stretchr/testify/assert" 10 | ) 11 | // TestReadWriteSubackPacket_V5 encodes v5 SUBACK packets, compares them with the expected bytes and decodes them back. 12 | func TestReadWriteSubackPacket_V5(t *testing.T) { 13 | a := assert.New(t) 14 | tt := []struct { 15 | pid PacketID 16 | codes []codes.Code 17 | properties *Properties 18 | want []byte 19 | }{ 20 | { 21 | pid: 10, 22 | codes: []codes.Code{codes.Success}, 23 | properties:
&Properties{}, 24 | want: []byte{0x90, 4, 25 | 0, 10, // pid 26 | 0, // properties 27 | 0, //code 28 | }, 29 | }, 30 | { 31 | pid: 10, 32 | codes: []codes.Code{codes.Success, codes.GrantedQoS1}, 33 | properties: &Properties{ 34 | ReasonString: []byte("a"), 35 | }, 36 | want: []byte{0x90, 9, 37 | 0, 10, // pid 38 | 4, // properties 39 | 0x1f, 0, 1, 'a', 40 | 0, 0x01, //codes 41 | }, 42 | }, 43 | { 44 | pid: 10, 45 | codes: []codes.Code{codes.GrantedQoS1, codes.GrantedQoS2}, 46 | properties: &Properties{}, 47 | want: []byte{0x90, 5, 48 | 0, 10, // pid 49 | 0, // properties 50 | 0x01, 0x02, //codes 51 | }, 52 | }, 53 | } 54 | 55 | for _, v := range tt { 56 | 57 | b := make([]byte, 0, 2048) 58 | buf := bytes.NewBuffer(b) 59 | pkg := &Suback{ 60 | Version: Version5, 61 | PacketID: v.pid, 62 | Properties: v.properties, 63 | Payload: v.codes, 64 | } 65 | err := NewWriter(buf).WriteAndFlush(pkg) 66 | a.Nil(err) 67 | a.Equal(v.want, buf.Bytes()) 68 | 69 | bufr := bytes.NewBuffer(buf.Bytes()) 70 | 71 | r := NewReader(bufr) 72 | r.SetVersion(Version5) 73 | p, err := r.ReadPacket() 74 | a.Nil(err) 75 | rp := p.(*Suback) 76 | a.Equal(v.codes, rp.Payload) 77 | a.Equal(v.properties, rp.Properties) 78 | a.Equal(v.pid, rp.PacketID) 79 | 80 | } 81 | 82 | } 83 | // TestReadSuback_V311 decodes a v3.1.1 SUBACK carrying three return codes. 84 | func TestReadSuback_V311(t *testing.T) { 85 | a := assert.New(t) 86 | subackBytes := bytes.NewBuffer([]byte{0x90, 5, //FixHeader 87 | 0, 10, //packetID 88 | 0, 1, 2, //payload 89 | }) 90 | packet, err := NewReader(subackBytes).ReadPacket() 91 | a.Nil(err) 92 | if p, ok := packet.(*Suback); ok { 93 | a.EqualValues(10, p.PacketID) 94 | a.Equal([]byte{0, 1, 2}, p.Payload) 95 | } else { 96 | t.Fatalf("Packet Type error,want %v,got %v", reflect.TypeOf(&Suback{}), reflect.TypeOf(packet)) 97 | } 98 | } 99 | -------------------------------------------------------------------------------- /pkg/packets/unsuback.go: -------------------------------------------------------------------------------- 1 | package packets 2 | 3 | import (
4 | "bytes" 5 | "fmt" 6 | "io" 7 | 8 | "github.com/DrmagicE/gmqtt/pkg/codes" 9 | ) 10 | 11 | // Unsuback represents the MQTT Unsuback packet. 12 | type Unsuback struct { 13 | Version Version 14 | FixHeader *FixHeader 15 | PacketID PacketID 16 | Properties *Properties 17 | Payload []codes.Code 18 | } 19 | // String implements fmt.Stringer; the output includes the version, packet identifier, reason codes and properties. 20 | func (p *Unsuback) String() string { 21 | return fmt.Sprintf("Unsuback, Version: %v, Pid: %v, Payload: %v, Properties: %s", p.Version, p.PacketID, p.Payload, p.Properties) 22 | } 23 | 24 | // Pack encodes p and writes it to w: the packet identifier, then (Version5 only) the properties, then the bytes of Payload, which are written for every version. 25 | func (p *Unsuback) Pack(w io.Writer) error { 26 | p.FixHeader = &FixHeader{PacketType: UNSUBACK, Flags: FlagReserved} 27 | bufw := &bytes.Buffer{} 28 | writeUint16(bufw, p.PacketID) 29 | if p.Version == Version5 { 30 | p.Properties.Pack(bufw, UNSUBACK) 31 | } 32 | bufw.Write(p.Payload) 33 | p.FixHeader.RemainLength = bufw.Len() 34 | err := p.FixHeader.Pack(w) 35 | if err != nil { 36 | return err 37 | } 38 | _, err = bufw.WriteTo(w) 39 | return err 40 | } 41 | 42 | // Unpack reads the remaining UNSUBACK bytes from r and decodes them into p. For MQTT 3.x only the packet identifier is read; otherwise the properties and at least one reason code follow.
43 | func (p *Unsuback) Unpack(r io.Reader) error { 44 | restBuffer := make([]byte, p.FixHeader.RemainLength) 45 | _, err := io.ReadFull(r, restBuffer) 46 | if err != nil { 47 | return codes.ErrMalformed 48 | } 49 | bufr := bytes.NewBuffer(restBuffer) 50 | p.PacketID, err = readUint16(bufr) 51 | if err != nil { 52 | return err 53 | } 54 | if IsVersion3X(p.Version) { 55 | return nil 56 | } 57 | // Past this point the packet is not MQTT 3.x: properties come first, then at least one reason code (a missing one yields codes.ErrMalformed). 58 | p.Properties = &Properties{} 59 | err = p.Properties.Unpack(bufr, UNSUBACK) 60 | if err != nil { 61 | return err 62 | } 63 | for { 64 | b, err := bufr.ReadByte() 65 | if err != nil { 66 | return codes.ErrMalformed 67 | } 68 | if p.Version == Version5 && !ValidateCode(UNSUBACK, b) { 69 | return codes.ErrProtocol 70 | } 71 | p.Payload = append(p.Payload, b) 72 | if bufr.Len() == 0 { 73 | return nil 74 | } 75 | } 76 | } 77 | 78 | // NewUnsubackPacket decodes an Unsuback for the given protocol version from fh and r. It returns codes.ErrMalformed if the fixed-header flags are not FlagReserved, or the error reported by Unpack. 79 | func NewUnsubackPacket(fh *FixHeader, version Version, r io.Reader) (*Unsuback, error) { 80 | p := &Unsuback{FixHeader: fh, Version: version} 81 | if fh.Flags != FlagReserved { 82 | return nil, codes.ErrMalformed 83 | } 84 | err := p.Unpack(r) 85 | if err != nil { 86 | return nil, err 87 | } 88 | return p, nil 89 | } 90 | -------------------------------------------------------------------------------- /pkg/packets/unsuback_test.go: -------------------------------------------------------------------------------- 1 | package packets 2 | 3 | import ( 4 | "bytes" 5 | "testing" 6 | 7 | "github.com/DrmagicE/gmqtt/pkg/codes" 8 | "github.com/stretchr/testify/assert" 9 | ) 10 | // TestReadWriteUnsubackPacket_V5 encodes v5 UNSUBACK packets, compares them with the expected bytes and decodes them back. 11 | func TestReadWriteUnsubackPacket_V5(t *testing.T) { 12 | a := assert.New(t) 13 | tt := []struct { 14 | pid PacketID 15 | codes []codes.Code 16 | properties *Properties 17 | want []byte 18 | }{ 19 | { 20 | pid: 10, 21 | codes: []codes.Code{codes.Success}, 22 | properties: &Properties{}, 23 | want: []byte{0xb0, 4, 24 | 0, 10, // pid 25 | 0, // properties 26 | 0, //code 27 | }, 28 | }, 29 |
{ 30 | pid: 10, 31 | codes: []codes.Code{codes.Success, codes.NoSubscriptionExisted}, 32 | properties: &Properties{ 33 | ReasonString: []byte("a"), 34 | }, 35 | want: []byte{0xb0, 9, 36 | 0, 10, // pid 37 | 4, // properties 38 | 0x1f, 0, 1, 'a', 39 | 0, 0x11, //codes 40 | }, 41 | }, 42 | } 43 | for _, v := range tt { 44 | 45 | b := make([]byte, 0, 2048) 46 | buf := bytes.NewBuffer(b) 47 | pkg := &Unsuback{ 48 | Version: Version5, 49 | PacketID: v.pid, 50 | Properties: v.properties, 51 | Payload: v.codes, 52 | } 53 | err := NewWriter(buf).WriteAndFlush(pkg) 54 | a.Nil(err) 55 | a.Equal(v.want, buf.Bytes()) 56 | 57 | bufr := bytes.NewBuffer(buf.Bytes()) 58 | 59 | r := NewReader(bufr) 60 | r.SetVersion(Version5) 61 | p, err := r.ReadPacket() 62 | a.Nil(err) 63 | rp := p.(*Unsuback) 64 | 65 | a.EqualValues(v.codes, rp.Payload) 66 | a.Equal(v.properties, rp.Properties) 67 | a.Equal(v.pid, rp.PacketID) 68 | 69 | } 70 | 71 | } 72 | // TestReadWriteUnsubackPacket_V311 checks that a v3.1.1 UNSUBACK carries only the packet identifier and decodes back. 73 | func TestReadWriteUnsubackPacket_V311(t *testing.T) { 74 | a := assert.New(t) 75 | tt := []struct { 76 | pid PacketID 77 | want []byte 78 | }{ 79 | { 80 | pid: 10, 81 | want: []byte{0xb0, 2, 82 | 0, 10, // pid 83 | }, 84 | }, 85 | } 86 | for _, v := range tt { 87 | 88 | b := make([]byte, 0, 2048) 89 | buf := bytes.NewBuffer(b) 90 | pkg := &Unsuback{ 91 | Version: Version311, 92 | PacketID: v.pid, 93 | } 94 | err := NewWriter(buf).WriteAndFlush(pkg) 95 | a.Nil(err) 96 | a.Equal(v.want, buf.Bytes()) 97 | 98 | bufr := bytes.NewBuffer(buf.Bytes()) 99 | 100 | r := NewReader(bufr) 101 | r.SetVersion(Version311) 102 | p, err := r.ReadPacket() 103 | a.Nil(err) 104 | rp := p.(*Unsuback) 105 | a.Equal(v.pid, rp.PacketID) 106 | 107 | } 108 | } 109 | -------------------------------------------------------------------------------- /pkg/packets/unsubscribe.go: -------------------------------------------------------------------------------- 1 | package packets 2 | 3 | import ( 4 | "bytes" 5 | "fmt" 6 | "io" 7 | 8 | "github.com/DrmagicE/gmqtt/pkg/codes" 9
| ) 10 | 11 | // Unsubscribe represents the MQTT Unsubscribe packet. 12 | type Unsubscribe struct { 13 | Version Version 14 | FixHeader *FixHeader 15 | PacketID PacketID 16 | Topics []string 17 | Properties *Properties 18 | } 19 | // String implements fmt.Stringer; the output includes the version, packet identifier, topic filters and properties. 20 | func (u *Unsubscribe) String() string { 21 | return fmt.Sprintf("Unsubscribe, Version: %v, Pid: %v, Topics: %v, Properties: %s", u.Version, u.PacketID, u.Topics, u.Properties) 22 | } 23 | 24 | // NewUnSubBack returns the Unsuback acknowledging this Unsubscribe, with the same packet identifier and version. For Version5 it holds one zero-valued reason code per topic filter. 25 | func (u *Unsubscribe) NewUnSubBack() *Unsuback { 26 | fh := &FixHeader{PacketType: UNSUBACK, Flags: 0} 27 | unSuback := &Unsuback{FixHeader: fh, PacketID: u.PacketID, Version: u.Version} 28 | if unSuback.Version == Version5 { 29 | unSuback.Payload = make([]codes.Code, len(u.Topics)) 30 | } 31 | return unSuback 32 | } 33 | 34 | // NewUnsubscribePacket decodes an Unsubscribe for the given protocol version from fh and r. It returns codes.ErrMalformed if the fixed-header flags are not FlagUnsubscribe, or the error reported by Unpack. 35 | func NewUnsubscribePacket(fh *FixHeader, version Version, r io.Reader) (*Unsubscribe, error) { 36 | p := &Unsubscribe{FixHeader: fh, Version: version} 37 | // The UNSUBSCRIBE fixed-header flags must equal FlagUnsubscribe [MQTT-3.10.1-1]. 38 | if fh.Flags != FlagUnsubscribe { 39 | return nil, codes.ErrMalformed 40 | } 41 | err := p.Unpack(r) 42 | if err != nil { 43 | return nil, err 44 | } 45 | return p, nil 46 | } 47 | 48 | // Pack encodes u and writes it to w: the packet identifier, then (Version5 only) the properties, then every topic filter as a UTF-8 string.
49 | func (u *Unsubscribe) Pack(w io.Writer) error { 50 | u.FixHeader = &FixHeader{PacketType: UNSUBSCRIBE, Flags: FlagUnsubscribe} 51 | bufw := &bytes.Buffer{} 52 | writeUint16(bufw, u.PacketID) 53 | if u.Version == Version5 { 54 | u.Properties.Pack(bufw, UNSUBSCRIBE) 55 | } 56 | for _, topic := range u.Topics { 57 | writeUTF8String(bufw, []byte(topic)) 58 | } 59 | u.FixHeader.RemainLength = bufw.Len() 60 | err := u.FixHeader.Pack(w) 61 | if err != nil { 62 | return err 63 | } 64 | _, err = bufw.WriteTo(w) 65 | return err 66 | } 67 | 68 | // Unpack reads the remaining UNSUBSCRIBE bytes from r and decodes them into u: the packet identifier, the properties (Version5 only), then at least one topic filter, each checked with ValidTopicFilter (codes.ErrProtocol if invalid). 69 | func (u *Unsubscribe) Unpack(r io.Reader) error { 70 | restBuffer := make([]byte, u.FixHeader.RemainLength) 71 | _, err := io.ReadFull(r, restBuffer) 72 | if err != nil { 73 | return codes.ErrMalformed 74 | } 75 | bufr := bytes.NewBuffer(restBuffer) 76 | u.PacketID, err = readUint16(bufr) 77 | if err != nil { 78 | return err 79 | } 80 | // Only v5 UNSUBSCRIBE packets carry properties before the topic filters. 81 | if u.Version == Version5 { 82 | u.Properties = &Properties{} 83 | if err := u.Properties.Unpack(bufr, UNSUBSCRIBE); err != nil { 84 | return err 85 | } 86 | } 87 | for { 88 | topicFilter, err := readUTF8String(true, bufr) 89 | if err != nil { 90 | return err 91 | } 92 | if !ValidTopicFilter(true, topicFilter) { 93 | return codes.ErrProtocol 94 | } 95 | u.Topics = append(u.Topics, string(topicFilter)) 96 | if bufr.Len() == 0 { 97 | return nil 98 | } 99 | } 100 | } 101 | -------------------------------------------------------------------------------- /pkg/packets/unsubscribe_test.go: -------------------------------------------------------------------------------- 1 | package packets 2 | 3 | import ( 4 | "bytes" 5 | "encoding/binary" 6 | "reflect" 7 | "testing" 8 | 9 | "github.com/stretchr/testify/assert" 10 | ) 11 | // TestReadWriteUnsubscribe_V5 decodes a v5 UNSUBSCRIBE with two topic filters and checks that re-encoding it reproduces the original bytes. 12 | func TestReadWriteUnsubscribe_V5(t *testing.T) { 13 | a := assert.New(t) 14 | firstByte := byte(0xa2) 15 | pid := []byte{0, 10} 16 | properties := []byte{0} 17 | topicFilter1 :=
[]byte("/topic/A") 18 | topicFilter1Bytes, _, _ := EncodeUTF8String(topicFilter1) 19 | 20 | topicFilter2 := []byte("/topic/B") 21 | topicFilter2Bytes, _, _ := EncodeUTF8String(topicFilter2) 22 | 23 | pb := appendPacket(firstByte, pid, properties, topicFilter1Bytes, topicFilter2Bytes) 24 | 25 | unsubBytes := bytes.NewBuffer(pb) 26 | 27 | var packet Packet 28 | var err error 29 | t.Run("unpack", func(t *testing.T) { 30 | r := NewReader(unsubBytes) 31 | r.SetVersion(Version5) 32 | packet, err = r.ReadPacket() 33 | a.Nil(err) 34 | if p, ok := packet.(*Unsubscribe); ok { 35 | a.Equal(binary.BigEndian.Uint16(pid), p.PacketID) 36 | a.EqualValues(topicFilter1, p.Topics[0]) 37 | a.EqualValues(topicFilter2, p.Topics[1]) 38 | a.Len(p.Topics, 2) 39 | } else { 40 | t.Fatalf("Packet Type error,want %v,got %v", reflect.TypeOf(&Unsubscribe{}), reflect.TypeOf(packet)) 41 | } 42 | }) 43 | 44 | t.Run("pack", func(t *testing.T) { 45 | bufw := &bytes.Buffer{} 46 | err = packet.Pack(bufw) 47 | a.Nil(err) 48 | a.Equal(pb, bufw.Bytes()) 49 | }) 50 | 51 | } 52 | // TestUnsubscribeNoTopics_V5 checks that a v5 UNSUBSCRIBE without any topic filter is rejected. 53 | func TestUnsubscribeNoTopics_V5(t *testing.T) { 54 | a := assert.New(t) 55 | firstByte := byte(0xa2) 56 | pid := []byte{0, 10} 57 | properties := []byte{0} 58 | pb := appendPacket(firstByte, pid, properties) 59 | unsubBytes := bytes.NewBuffer(pb) 60 | r := NewReader(unsubBytes) 61 | r.SetVersion(Version5) 62 | packet, err := r.ReadPacket() 63 | a.NotNil(err) 64 | a.Nil(packet) 65 | } 66 | // TestReadWriteUnsubscribe_V311 builds a v3.1.1 UNSUBSCRIBE (no properties) with two topic filters and decodes it. 67 | func TestReadWriteUnsubscribe_V311(t *testing.T) { 68 | a := assert.New(t) 69 | firstByte := byte(0xa2) 70 | pid := []byte{0, 10} 71 | topicFilter1 := []byte("/topic/A") 72 | topicFilter1Bytes, _, _ := EncodeUTF8String(topicFilter1) 73 | 74 | topicFilter2 := []byte("/topic/B") 75 | topicFilter2Bytes, _, _ := EncodeUTF8String(topicFilter2) 76 | 77 | pb := appendPacket(firstByte, pid, topicFilter1Bytes, topicFilter2Bytes) 78 | 79 | unsubBytes := bytes.NewBuffer(pb) 80 | 81 | var packet Packet 82 | var err error 83 | t.Run("unpack", func(t
*testing.T) { 84 | r := NewReader(unsubBytes) 85 | r.SetVersion(Version311) 86 | packet, err = r.ReadPacket() 87 | a.Nil(err) 88 | if p, ok := packet.(*Unsubscribe); ok { 89 | a.Equal(binary.BigEndian.Uint16(pid), p.PacketID) 90 | a.EqualValues(topicFilter1, p.Topics[0]) 91 | a.EqualValues(topicFilter2, p.Topics[1]) 92 | a.Len(p.Topics, 2) 93 | } else { 94 | t.Fatalf("Packet Type error,want %v,got %v", reflect.TypeOf(&Unsubscribe{}), reflect.TypeOf(packet)) 95 | } 96 | }) 97 | 98 | t.Run("pack", func(t *testing.T) { 99 | bufw := &bytes.Buffer{} 100 | err = packet.Pack(bufw) 101 | a.Nil(err) 102 | a.Equal(pb, bufw.Bytes()) 103 | }) 104 | 105 | } 106 | -------------------------------------------------------------------------------- /pkg/pidfile/pidfile.go: -------------------------------------------------------------------------------- 1 | // Package pidfile provides structure and helper functions to create and remove 2 | // PID file. A PID file is usually a file used to store the process ID of a 3 | // running process. 4 | package pidfile 5 | 6 | import ( 7 | "fmt" 8 | "io/ioutil" 9 | "os" 10 | "path/filepath" 11 | "strconv" 12 | "strings" 13 | ) 14 | 15 | // PIDFile is a file used to store the process ID of a running process. 16 | type PIDFile struct { 17 | path string 18 | } 19 | 20 | func checkPIDFileAlreadyExists(path string) error { 21 | if pidByte, err := ioutil.ReadFile(path); err == nil { 22 | pidString := strings.TrimSpace(string(pidByte)) 23 | if pid, err := strconv.Atoi(pidString); err == nil { 24 | if processExists(pid) { 25 | return fmt.Errorf("pid file found, ensure gmqtt is not running or delete %s", path) 26 | } 27 | } 28 | } 29 | return nil 30 | } 31 | 32 | // New creates a PIDfile using the specified path. 
33 | func New(path string) (*PIDFile, error) { 34 | if err := checkPIDFileAlreadyExists(path); err != nil { 35 | return nil, err 36 | } 37 | // Note MkdirAll returns nil if a directory already exists 38 | if err := os.MkdirAll(filepath.Dir(path), os.FileMode(0755)); err != nil { 39 | return nil, err 40 | } 41 | if err := ioutil.WriteFile(path, []byte(fmt.Sprintf("%d", os.Getpid())), 0644); err != nil { 42 | return nil, err 43 | } 44 | 45 | return &PIDFile{path: path}, nil 46 | } 47 | 48 | // remove removes the PIDFile. 49 | func (file PIDFile) Remove() error { 50 | return os.Remove(file.path) 51 | } 52 | -------------------------------------------------------------------------------- /pkg/pidfile/pidfile_darwin.go: -------------------------------------------------------------------------------- 1 | // +build darwin 2 | 3 | package pidfile 4 | 5 | import ( 6 | "golang.org/x/sys/unix" 7 | ) 8 | 9 | func processExists(pid int) bool { 10 | // OS X does not have a proc filesystem. 11 | // Use kill -0 pid to judge if the process exists. 
12 | err := unix.Kill(pid, 0) 13 | return err == nil 14 | } 15 | -------------------------------------------------------------------------------- /pkg/pidfile/pidfile_test.go: -------------------------------------------------------------------------------- 1 | package pidfile 2 | 3 | import ( 4 | "io/ioutil" 5 | "os" 6 | "path/filepath" 7 | "testing" 8 | ) 9 | 10 | func TestNewAndRemove(t *testing.T) { 11 | dir, err := ioutil.TempDir(os.TempDir(), "test-pidfile") 12 | if err != nil { 13 | t.Fatal("Could not create test directory") 14 | } 15 | 16 | path := filepath.Join(dir, "testfile") 17 | file, err := New(path) 18 | if err != nil { 19 | t.Fatal("Could not create test file", err) 20 | } 21 | 22 | _, err = New(path) 23 | if err == nil { 24 | t.Fatal("Test file creation not blocked") 25 | } 26 | 27 | if err := file.Remove(); err != nil { 28 | t.Fatal("Could not delete created test file") 29 | } 30 | } 31 | 32 | func TestRemoveInvalidPath(t *testing.T) { 33 | file := PIDFile{path: filepath.Join("foo", "bar")} 34 | 35 | if err := file.Remove(); err == nil { 36 | t.Fatal("Non-existing file doesn't give an error on delete") 37 | } 38 | } 39 | -------------------------------------------------------------------------------- /pkg/pidfile/pidfile_unix.go: -------------------------------------------------------------------------------- 1 | // +build !windows,!darwin 2 | 3 | package pidfile 4 | 5 | import ( 6 | "os" 7 | "path/filepath" 8 | "strconv" 9 | ) 10 | 11 | func processExists(pid int) bool { 12 | if _, err := os.Stat(filepath.Join("/proc", strconv.Itoa(pid))); err == nil { 13 | return true 14 | } 15 | return false 16 | } 17 | -------------------------------------------------------------------------------- /pkg/pidfile/pidfile_windows.go: -------------------------------------------------------------------------------- 1 | package pidfile 2 | 3 | import ( 4 | "golang.org/x/sys/windows" 5 | ) 6 | 7 | const ( 8 | processQueryLimitedInformation = 0x1000 9 | 10 | 
stillActive = 259 11 | ) 12 | 13 | func processExists(pid int) bool { 14 | h, err := windows.OpenProcess(processQueryLimitedInformation, false, uint32(pid)) 15 | if err != nil { 16 | return false 17 | } 18 | var c uint32 19 | err = windows.GetExitCodeProcess(h, &c) 20 | windows.Close(h) 21 | if err != nil { 22 | return c == stillActive 23 | } 24 | return true 25 | } 26 | -------------------------------------------------------------------------------- /plugin/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ThingsPanel/thingspanel-gmqtt/25e5e350779c0c5a32e4945648cf60f68d5adb17/plugin/.DS_Store -------------------------------------------------------------------------------- /plugin/README.md: -------------------------------------------------------------------------------- 1 | # Plugin 2 | [Gmqtt插件机制详解](https://juejin.cn/post/6908305981923409934) 3 | ## How to write plugins 4 | Gmqtt uses code generator to generate plugin template. 5 | 6 | ## 1. Install the CLI tool 7 | ```bash 8 | # run under gmqtt project root directory. 9 | go install ./cmd/gmqctl 10 | ``` 11 | ## 2. Run `gmqctl gen plugin` 12 | ```bash 13 | $ gmqctl gen plugin --help 14 | code generator 15 | 16 | Usage: 17 | gmqctl gen plugin [flags] 18 | 19 | Examples: 20 | The following command will generate a code template for the 'awesome' plugin, which makes use of OnBasicAuth and OnSubscribe hook and enables the configuration in ./plugin directory. 21 | 22 | gmqctl gen plugin -n awesome -H OnBasicAuth,OnSubscribe -c true -o ./plugin 23 | 24 | Flags: 25 | -c, --config Whether the plugin needs a configuration. 26 | -h, --help help for plugin 27 | -H, --hooks string The hooks use by the plugin, multiple hooks are separated by ',' 28 | -n, --name string The plugin name. 29 | -o, --output string The output directory. 30 | ``` 31 | 32 | ## 3. Edit `plugin_imports.yml` 33 | Append your plugin name to `plugin_imports.yml`. 
34 | ```yaml 35 | packages: 36 | - admin 37 | - prometheus 38 | - federation 39 | - auth 40 | # for external plugin, use full import path 41 | # - github.com/DrmagicE/gmqtt/plugin/prometheu 42 | ``` 43 | 44 | ## 4. Run `go generate ./...` 45 | Run `go generate ./...` under the project root directory. The command will recreate the `./cmd/gmqttd/plugins.go` file, 46 | which is needed during the compile time. -------------------------------------------------------------------------------- /plugin/admin/README.md: -------------------------------------------------------------------------------- 1 | # admin 2 | 3 | Admin plugin use [grpc-gateway](https://github.com/grpc-ecosystem/grpc-gateway) to provide both REST HTTP and GRPC APIs for integration with external systems. 4 | 5 | # API Doc 6 | 7 | See [swagger](https://github.com/DrmagicE/gmqtt/blob/master/plugin/admin/swagger) 8 | 9 | # Examples 10 | 11 | ## List Clients 12 | ```bash 13 | $ curl 127.0.0.1:8083/v1/clients 14 | ``` 15 | Response: 16 | ```json 17 | { 18 | "clients": [ 19 | { 20 | "client_id": "ab", 21 | "username": "", 22 | "keep_alive": 60, 23 | "version": 4, 24 | "remote_addr": "127.0.0.1:51637", 25 | "local_addr": "127.0.0.1:1883", 26 | "connected_at": "2020-12-12T12:26:36Z", 27 | "disconnected_at": null, 28 | "session_expiry": 7200, 29 | "max_inflight": 100, 30 | "inflight_len": 0, 31 | "max_queue": 100, 32 | "queue_len": 0, 33 | "subscriptions_current": 0, 34 | "subscriptions_total": 0, 35 | "packets_received_bytes": "54", 36 | "packets_received_nums": "3", 37 | "packets_send_bytes": "8", 38 | "packets_send_nums": "2", 39 | "message_dropped": "0" 40 | } 41 | ], 42 | "total_count": 1 43 | } 44 | ``` 45 | 46 | ## Filter Subscriptions 47 | ```bash 48 | $ curl 127.0.0.1:8083/v1/filter_subscriptions?filter_type=1,2,3&match_type=1&topic_name=/a 49 | ``` 50 | This curl is able to filter the subscription that the topic name is equal to "/a". 
51 | 52 | Response: 53 | ```json 54 | { 55 | "subscriptions": [ 56 | { 57 | "topic_name": "/a", 58 | "id": 0, 59 | "qos": 1, 60 | "no_local": false, 61 | "retain_as_published": false, 62 | "retain_handling": 0, 63 | "client_id": "ab" 64 | } 65 | ] 66 | } 67 | ``` 68 | 69 | ## Publish Message 70 | ```bash 71 | $ curl -X POST 127.0.0.1:8083/v1/publish -d '{"topic_name":"a","payload":"test","qos":1}' 72 | ``` 73 | This curl will publish the message to the broker.The broker will check if there are matched topics and 74 | send the message to the subscribers, just like received a message from a MQTT client. -------------------------------------------------------------------------------- /plugin/admin/admin.go: -------------------------------------------------------------------------------- 1 | package admin 2 | 3 | import ( 4 | "go.uber.org/zap" 5 | 6 | "github.com/DrmagicE/gmqtt/config" 7 | "github.com/DrmagicE/gmqtt/server" 8 | ) 9 | 10 | var _ server.Plugin = (*Admin)(nil) 11 | 12 | const Name = "admin" 13 | 14 | func init() { 15 | server.RegisterPlugin(Name, New) 16 | } 17 | 18 | func New(config config.Config) (server.Plugin, error) { 19 | return &Admin{}, nil 20 | } 21 | 22 | var log *zap.Logger 23 | 24 | // Admin providers gRPC and HTTP API that enables the external system to interact with the broker. 
25 | type Admin struct { 26 | statsReader server.StatsReader 27 | publisher server.Publisher 28 | clientService server.ClientService 29 | store *store 30 | } 31 | 32 | func (a *Admin) registerHTTP(g server.APIRegistrar) (err error) { 33 | err = g.RegisterHTTPHandler(RegisterClientServiceHandlerFromEndpoint) 34 | if err != nil { 35 | return err 36 | } 37 | err = g.RegisterHTTPHandler(RegisterSubscriptionServiceHandlerFromEndpoint) 38 | if err != nil { 39 | return err 40 | } 41 | err = g.RegisterHTTPHandler(RegisterPublishServiceHandlerFromEndpoint) 42 | if err != nil { 43 | return err 44 | } 45 | return nil 46 | } 47 | 48 | func (a *Admin) Load(service server.Server) error { 49 | log = server.LoggerWithField(zap.String("plugin", Name)) 50 | apiRegistrar := service.APIRegistrar() 51 | RegisterClientServiceServer(apiRegistrar, &clientService{a: a}) 52 | RegisterSubscriptionServiceServer(apiRegistrar, &subscriptionService{a: a}) 53 | RegisterPublishServiceServer(apiRegistrar, &publisher{a: a}) 54 | err := a.registerHTTP(apiRegistrar) 55 | if err != nil { 56 | return err 57 | } 58 | a.statsReader = service.StatsManager() 59 | a.store = newStore(a.statsReader, service.GetConfig()) 60 | a.store.subscriptionService = service.SubscriptionService() 61 | a.publisher = service.Publisher() 62 | a.clientService = service.ClientService() 63 | return nil 64 | } 65 | 66 | func (a *Admin) Unload() error { 67 | return nil 68 | } 69 | 70 | func (a *Admin) Name() string { 71 | return Name 72 | } 73 | -------------------------------------------------------------------------------- /plugin/admin/client.go: -------------------------------------------------------------------------------- 1 | package admin 2 | 3 | import ( 4 | "context" 5 | 6 | "github.com/golang/protobuf/ptypes/empty" 7 | ) 8 | 9 | type clientService struct { 10 | a *Admin 11 | } 12 | 13 | func (c *clientService) mustEmbedUnimplementedClientServiceServer() { 14 | return 15 | } 16 | 17 | // List lists clients information 
which the session is valid in the broker (both connected and disconnected). 18 | func (c *clientService) List(ctx context.Context, req *ListClientRequest) (*ListClientResponse, error) { 19 | page, pageSize := GetPage(req.Page, req.PageSize) 20 | clients, total, err := c.a.store.GetClients(page, pageSize) 21 | if err != nil { 22 | return &ListClientResponse{}, err 23 | } 24 | return &ListClientResponse{ 25 | Clients: clients, 26 | TotalCount: total, 27 | }, nil 28 | } 29 | 30 | // Get returns the client information for given request client id. 31 | func (c *clientService) Get(ctx context.Context, req *GetClientRequest) (*GetClientResponse, error) { 32 | if req.ClientId == "" { 33 | return nil, ErrInvalidArgument("client_id", "") 34 | } 35 | client := c.a.store.GetClientByID(req.ClientId) 36 | if client == nil { 37 | return nil, ErrNotFound 38 | } 39 | return &GetClientResponse{ 40 | Client: client, 41 | }, nil 42 | } 43 | 44 | // Delete force disconnect. 45 | func (c *clientService) Delete(ctx context.Context, req *DeleteClientRequest) (*empty.Empty, error) { 46 | if req.ClientId == "" { 47 | return nil, ErrInvalidArgument("client_id", "") 48 | } 49 | if req.CleanSession { 50 | c.a.clientService.TerminateSession(req.ClientId) 51 | } else { 52 | client := c.a.clientService.GetClient(req.ClientId) 53 | if client != nil { 54 | client.Close() 55 | } 56 | } 57 | return &empty.Empty{}, nil 58 | } 59 | -------------------------------------------------------------------------------- /plugin/admin/config.go: -------------------------------------------------------------------------------- 1 | package admin 2 | 3 | import ( 4 | "errors" 5 | "net" 6 | ) 7 | 8 | // Config is the configuration for the admin plugin. 9 | type Config struct { 10 | HTTP HTTPConfig `yaml:"http"` 11 | GRPC GRPCConfig `yaml:"grpc"` 12 | } 13 | 14 | // HTTPConfig is the configuration for http endpoint. 15 | type HTTPConfig struct { 16 | // Enable indicates whether to expose http endpoint. 
17 | Enable bool `yaml:"enable"` 18 | // Addr is the address that the http server listen on. 19 | Addr string `yaml:"http_addr"` 20 | } 21 | 22 | // GRPCConfig is the configuration for gRPC endpoint. 23 | type GRPCConfig struct { 24 | // Addr is the address that the gRPC server listen on. 25 | Addr string `yaml:"http_addr"` 26 | } 27 | 28 | // Validate validates the configuration, and return an error if it is invalid. 29 | func (c *Config) Validate() error { 30 | if c.HTTP.Enable { 31 | _, _, err := net.SplitHostPort(c.HTTP.Addr) 32 | if err != nil { 33 | return errors.New("invalid http_addr") 34 | } 35 | } 36 | _, _, err := net.SplitHostPort(c.GRPC.Addr) 37 | if err != nil { 38 | return errors.New("invalid grpc_addr") 39 | } 40 | return nil 41 | } 42 | 43 | // DefaultConfig is the default configuration. 44 | var DefaultConfig = Config{ 45 | HTTP: HTTPConfig{ 46 | Enable: true, 47 | Addr: "127.0.0.1:8083", 48 | }, 49 | GRPC: GRPCConfig{ 50 | Addr: "unix://./gmqttd.sock", 51 | }, 52 | } 53 | 54 | func (c *Config) UnmarshalYAML(unmarshal func(interface{}) error) error { 55 | type cfg Config 56 | var v = &struct { 57 | Admin cfg `yaml:"admin"` 58 | }{ 59 | Admin: cfg(DefaultConfig), 60 | } 61 | if err := unmarshal(v); err != nil { 62 | return err 63 | } 64 | emptyGRPC := GRPCConfig{} 65 | if v.Admin.GRPC == emptyGRPC { 66 | v.Admin.GRPC = DefaultConfig.GRPC 67 | } 68 | emptyHTTP := HTTPConfig{} 69 | if v.Admin.HTTP == emptyHTTP { 70 | v.Admin.HTTP = DefaultConfig.HTTP 71 | } 72 | empty := cfg(Config{}) 73 | if v.Admin == empty { 74 | v.Admin = cfg(DefaultConfig) 75 | } 76 | *c = Config(v.Admin) 77 | return nil 78 | } 79 | -------------------------------------------------------------------------------- /plugin/admin/hooks.go: -------------------------------------------------------------------------------- 1 | package admin 2 | 3 | import ( 4 | "context" 5 | 6 | "github.com/DrmagicE/gmqtt" 7 | "github.com/DrmagicE/gmqtt/server" 8 | ) 9 | 10 | func (a *Admin) 
HookWrapper() server.HookWrapper { 11 | return server.HookWrapper{ 12 | OnSessionCreatedWrapper: a.OnSessionCreatedWrapper, 13 | OnSessionResumedWrapper: a.OnSessionResumedWrapper, 14 | OnClosedWrapper: a.OnClosedWrapper, 15 | OnSessionTerminatedWrapper: a.OnSessionTerminatedWrapper, 16 | OnSubscribedWrapper: a.OnSubscribedWrapper, 17 | OnUnsubscribedWrapper: a.OnUnsubscribedWrapper, 18 | } 19 | } 20 | 21 | func (a *Admin) OnSessionCreatedWrapper(pre server.OnSessionCreated) server.OnSessionCreated { 22 | return func(ctx context.Context, client server.Client) { 23 | pre(ctx, client) 24 | a.store.addClient(client) 25 | } 26 | } 27 | 28 | func (a *Admin) OnSessionResumedWrapper(pre server.OnSessionResumed) server.OnSessionResumed { 29 | return func(ctx context.Context, client server.Client) { 30 | pre(ctx, client) 31 | a.store.addClient(client) 32 | } 33 | } 34 | 35 | func (a *Admin) OnClosedWrapper(pre server.OnClosed) server.OnClosed { 36 | return func(ctx context.Context, client server.Client, err error) { 37 | pre(ctx, client, err) 38 | a.store.setClientDisconnected(client.ClientOptions().ClientID) 39 | } 40 | } 41 | 42 | func (a *Admin) OnSessionTerminatedWrapper(pre server.OnSessionTerminated) server.OnSessionTerminated { 43 | return func(ctx context.Context, clientID string, reason server.SessionTerminatedReason) { 44 | pre(ctx, clientID, reason) 45 | a.store.removeClient(clientID) 46 | } 47 | } 48 | 49 | func (a *Admin) OnSubscribedWrapper(pre server.OnSubscribed) server.OnSubscribed { 50 | return func(ctx context.Context, client server.Client, subscription *gmqtt.Subscription) { 51 | pre(ctx, client, subscription) 52 | a.store.addSubscription(client.ClientOptions().ClientID, subscription) 53 | } 54 | } 55 | 56 | func (a *Admin) OnUnsubscribedWrapper(pre server.OnUnsubscribed) server.OnUnsubscribed { 57 | return func(ctx context.Context, client server.Client, topicName string) { 58 | pre(ctx, client, topicName) 59 | 
a.store.removeSubscription(client.ClientOptions().ClientID, topicName) 60 | } 61 | } 62 | -------------------------------------------------------------------------------- /plugin/admin/protos/client.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package gmqtt.admin.api; 4 | option go_package = ".;admin"; 5 | 6 | import "google/api/annotations.proto"; 7 | import "google/protobuf/empty.proto"; 8 | import "google/protobuf/timestamp.proto"; 9 | 10 | message ListClientRequest { 11 | uint32 page_size = 1; 12 | uint32 page = 2; 13 | } 14 | 15 | message ListClientResponse { 16 | repeated Client clients = 1; 17 | uint32 total_count = 2; 18 | } 19 | 20 | message GetClientRequest { 21 | string client_id = 1; 22 | } 23 | 24 | message GetClientResponse { 25 | Client client = 1; 26 | } 27 | 28 | 29 | message DeleteClientRequest { 30 | string client_id = 1; 31 | bool clean_session = 2; 32 | } 33 | 34 | message Client { 35 | string client_id =1; 36 | string username = 2; 37 | int32 keep_alive = 3; 38 | int32 version = 4; 39 | string remote_addr=5; 40 | string local_addr=6; 41 | google.protobuf.Timestamp connected_at = 7; 42 | google.protobuf.Timestamp disconnected_at = 8; 43 | uint32 session_expiry = 9; 44 | uint32 max_inflight = 10; 45 | uint32 inflight_len = 11; 46 | uint32 max_queue=12; 47 | uint32 queue_len=13; 48 | uint32 subscriptions_current = 14; 49 | uint32 subscriptions_total = 15; 50 | uint64 packets_received_bytes = 16; 51 | uint64 packets_received_nums = 17; 52 | uint64 packets_send_bytes = 18; 53 | uint64 packets_send_nums = 19; 54 | uint64 message_dropped = 20; 55 | } 56 | 57 | 58 | service ClientService { 59 | // List clients 60 | rpc List (ListClientRequest) returns (ListClientResponse){ 61 | option (google.api.http) = { 62 | get: "/v1/clients" 63 | }; 64 | } 65 | // Get the client for given client id. 66 | // Return NotFound error when client not found. 
67 | rpc Get (GetClientRequest) returns (GetClientResponse){ 68 | option (google.api.http) = { 69 | get: "/v1/clients/{client_id}" 70 | }; 71 | } 72 | // Disconnect the client for given client id. 73 | rpc Delete (DeleteClientRequest) returns (google.protobuf.Empty) { 74 | option (google.api.http) = { 75 | delete: "/v1/clients/{client_id}" 76 | }; 77 | } 78 | } 79 | -------------------------------------------------------------------------------- /plugin/admin/protos/proto_gen.sh: -------------------------------------------------------------------------------- 1 | protoc -I. \ 2 | -I$GOPATH/src/github.com/grpc-ecosystem/grpc-gateway \ 3 | -I$GOPATH/src/github.com/grpc-ecosystem/grpc-gateway/third_party/googleapis \ 4 | --go-grpc_out=../ \ 5 | --go_out=../ \ 6 | --grpc-gateway_out=../ \ 7 | --swagger_out=../swagger \ 8 | *.proto -------------------------------------------------------------------------------- /plugin/admin/protos/publish.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package gmqtt.admin.api; 4 | option go_package = ".;admin"; 5 | 6 | import "google/api/annotations.proto"; 7 | import "google/protobuf/empty.proto"; 8 | 9 | message PublishRequest { 10 | string topic_name = 1; 11 | string payload = 2; 12 | uint32 qos = 3; 13 | bool retained = 4; 14 | // the following fields are using in v5 client. 
15 | string content_type = 5; 16 | string correlation_data = 6; 17 | uint32 message_expiry = 7; 18 | uint32 payload_format = 8; 19 | string response_topic = 9; 20 | repeated UserProperties user_properties = 10; 21 | } 22 | 23 | message UserProperties { 24 | bytes K = 1; 25 | bytes V = 2; 26 | } 27 | 28 | service PublishService { 29 | // Publish message to broker 30 | rpc Publish (PublishRequest) returns (google.protobuf.Empty){ 31 | option (google.api.http) = { 32 | post: "/v1/publish" 33 | body:"*" 34 | }; 35 | } 36 | } 37 | -------------------------------------------------------------------------------- /plugin/admin/publish.go: -------------------------------------------------------------------------------- 1 | package admin 2 | 3 | import ( 4 | "context" 5 | 6 | "github.com/golang/protobuf/ptypes/empty" 7 | 8 | "github.com/DrmagicE/gmqtt" 9 | "github.com/DrmagicE/gmqtt/pkg/packets" 10 | ) 11 | 12 | type publisher struct { 13 | a *Admin 14 | } 15 | 16 | func (p *publisher) mustEmbedUnimplementedPublishServiceServer() { 17 | return 18 | } 19 | 20 | // Publish publishes a message into broker. 
21 | func (p *publisher) Publish(ctx context.Context, req *PublishRequest) (resp *empty.Empty, err error) { 22 | if !packets.ValidTopicName(false, []byte(req.TopicName)) { 23 | return nil, ErrInvalidArgument("topic_name", "") 24 | } 25 | if req.Qos > uint32(packets.Qos2) { 26 | return nil, ErrInvalidArgument("qos", "") 27 | } 28 | if req.PayloadFormat != 0 && req.PayloadFormat != 1 { 29 | return nil, ErrInvalidArgument("payload_format", "") 30 | } 31 | if req.ResponseTopic != "" && !packets.ValidV5Topic([]byte(req.ResponseTopic)) { 32 | return nil, ErrInvalidArgument("response_topic", "") 33 | } 34 | var userPpt []packets.UserProperty 35 | for _, v := range req.UserProperties { 36 | userPpt = append(userPpt, packets.UserProperty{ 37 | K: v.K, 38 | V: v.V, 39 | }) 40 | } 41 | 42 | p.a.publisher.Publish(&gmqtt.Message{ 43 | Dup: false, 44 | QoS: byte(req.Qos), 45 | Retained: req.Retained, 46 | Topic: req.TopicName, 47 | Payload: []byte(req.Payload), 48 | ContentType: req.ContentType, 49 | CorrelationData: []byte(req.CorrelationData), 50 | MessageExpiry: req.MessageExpiry, 51 | PayloadFormat: packets.PayloadFormat(req.PayloadFormat), 52 | ResponseTopic: req.ResponseTopic, 53 | UserProperties: userPpt, 54 | }) 55 | return &empty.Empty{}, nil 56 | } 57 | -------------------------------------------------------------------------------- /plugin/admin/utils_test.go: -------------------------------------------------------------------------------- 1 | package admin 2 | 3 | import ( 4 | "container/list" 5 | "strconv" 6 | "testing" 7 | 8 | "github.com/stretchr/testify/assert" 9 | ) 10 | 11 | func TestIndexer(t *testing.T) { 12 | a := assert.New(t) 13 | i := NewIndexer() 14 | for j := 0; j < 100; j++ { 15 | i.Set(strconv.Itoa(j), j) 16 | a.EqualValues(j, i.GetByID(strconv.Itoa(j)).Value) 17 | } 18 | a.EqualValues(100, i.Len()) 19 | 20 | var jj int 21 | i.Iterate(func(elem *list.Element) { 22 | v := elem.Value.(int) 23 | a.Equal(jj, v) 24 | jj++ 25 | }, 0, uint(i.Len())) 26 | 27 
| e := i.Remove("5") 28 | a.Equal(5, e.Value.(int)) 29 | 30 | var rs []int 31 | i.Iterate(func(elem *list.Element) { 32 | rs = append(rs, elem.Value.(int)) 33 | }, 4, 2) 34 | // 5 is removed 35 | a.Equal([]int{4, 6}, rs) 36 | 37 | } 38 | -------------------------------------------------------------------------------- /plugin/auth/README.md: -------------------------------------------------------------------------------- 1 | # Auth 2 | 3 | Auth plugin provides a simple username/password authentication mechanism. 4 | 5 | # API Doc 6 | 7 | See [swagger](https://github.com/DrmagicE/gmqtt/blob/master/plugin/auth/swagger) 8 | -------------------------------------------------------------------------------- /plugin/auth/config.go: -------------------------------------------------------------------------------- 1 | package auth 2 | 3 | import ( 4 | "errors" 5 | "fmt" 6 | ) 7 | 8 | type hashType = string 9 | 10 | const ( 11 | Plain hashType = "plain" 12 | MD5 = "md5" 13 | SHA256 = "sha256" 14 | Bcrypt = "bcrypt" 15 | ) 16 | 17 | var ValidateHashType = []string{ 18 | Plain, MD5, SHA256, Bcrypt, 19 | } 20 | 21 | // Config is the configuration for the auth plugin. 22 | type Config struct { 23 | // PasswordFile is the file to store username and password. 24 | PasswordFile string `yaml:"password_file"` 25 | // Hash is the password hash algorithm. 26 | // Possible values: plain | md5 | sha256 | bcrypt 27 | Hash string `yaml:"hash"` 28 | } 29 | 30 | // validate validates the configuration, and return an error if it is invalid. 31 | func (c *Config) Validate() error { 32 | if c.PasswordFile == "" { 33 | return errors.New("password_file must be set") 34 | } 35 | for _, v := range ValidateHashType { 36 | if v == c.Hash { 37 | return nil 38 | } 39 | } 40 | return fmt.Errorf("invalid hash type: %s", c.Hash) 41 | } 42 | 43 | // DefaultConfig is the default configuration. 
44 | var DefaultConfig = Config{ 45 | Hash: MD5, 46 | PasswordFile: "./gmqtt_password.yml", 47 | } 48 | 49 | func (c *Config) UnmarshalYAML(unmarshal func(interface{}) error) error { 50 | type cfg Config 51 | var v = &struct { 52 | Auth cfg `yaml:"auth"` 53 | }{ 54 | Auth: cfg(DefaultConfig), 55 | } 56 | if err := unmarshal(v); err != nil { 57 | return err 58 | } 59 | empty := cfg(Config{}) 60 | if v.Auth == empty { 61 | v.Auth = cfg(DefaultConfig) 62 | } 63 | *c = Config(v.Auth) 64 | return nil 65 | } 66 | -------------------------------------------------------------------------------- /plugin/auth/hooks.go: -------------------------------------------------------------------------------- 1 | package auth 2 | 3 | import ( 4 | "context" 5 | 6 | "go.uber.org/zap" 7 | 8 | "github.com/DrmagicE/gmqtt/pkg/codes" 9 | "github.com/DrmagicE/gmqtt/pkg/packets" 10 | "github.com/DrmagicE/gmqtt/server" 11 | ) 12 | 13 | func (a *Auth) HookWrapper() server.HookWrapper { 14 | return server.HookWrapper{ 15 | OnBasicAuthWrapper: a.OnBasicAuthWrapper, 16 | } 17 | } 18 | 19 | func (a *Auth) OnBasicAuthWrapper(pre server.OnBasicAuth) server.OnBasicAuth { 20 | return func(ctx context.Context, client server.Client, req *server.ConnectRequest) (err error) { 21 | err = pre(ctx, client, req) 22 | if err != nil { 23 | return err 24 | } 25 | ok, err := a.validate(string(req.Connect.Username), string(req.Connect.Password)) 26 | if err != nil { 27 | return err 28 | } 29 | if !ok { 30 | log.Debug("authentication failed", zap.String("username", string(req.Connect.Username))) 31 | v := client.Version() 32 | if packets.IsVersion3X(v) { 33 | return &codes.Error{ 34 | Code: codes.V3NotAuthorized, 35 | } 36 | } 37 | if packets.IsVersion5(v) { 38 | return &codes.Error{ 39 | Code: codes.NotAuthorized, 40 | } 41 | } 42 | } 43 | return nil 44 | } 45 | } 46 | -------------------------------------------------------------------------------- /plugin/auth/hooks_test.go: 
-------------------------------------------------------------------------------- 1 | package auth 2 | 3 | import ( 4 | "context" 5 | "testing" 6 | 7 | "github.com/golang/mock/gomock" 8 | "github.com/stretchr/testify/assert" 9 | 10 | "github.com/DrmagicE/gmqtt/config" 11 | "github.com/DrmagicE/gmqtt/pkg/packets" 12 | "github.com/DrmagicE/gmqtt/server" 13 | ) 14 | 15 | func TestAuth_OnBasicAuthWrapper(t *testing.T) { 16 | a := assert.New(t) 17 | ctrl := gomock.NewController(t) 18 | defer ctrl.Finish() 19 | 20 | path := "./testdata/gmqtt_password.yml" 21 | cfg := DefaultConfig 22 | cfg.PasswordFile = path 23 | cfg.Hash = Plain 24 | auth, err := New(config.Config{ 25 | Plugins: map[string]config.Configuration{ 26 | "auth": &cfg, 27 | }, 28 | }) 29 | mockClient := server.NewMockClient(ctrl) 30 | mockClient.EXPECT().Version().Return(packets.Version311).AnyTimes() 31 | a.Nil(err) 32 | a.Nil(auth.Load(nil)) 33 | au := auth.(*Auth) 34 | var preCalled bool 35 | fn := au.OnBasicAuthWrapper(func(ctx context.Context, client server.Client, req *server.ConnectRequest) (err error) { 36 | preCalled = true 37 | return nil 38 | }) 39 | // pass 40 | a.Nil(fn(context.Background(), mockClient, &server.ConnectRequest{ 41 | Connect: &packets.Connect{ 42 | Username: []byte("u1"), 43 | Password: []byte("p1"), 44 | }, 45 | })) 46 | a.True(preCalled) 47 | 48 | // fail 49 | a.NotNil(fn(context.Background(), mockClient, &server.ConnectRequest{ 50 | Connect: &packets.Connect{ 51 | Username: []byte("u1"), 52 | Password: []byte("p11"), 53 | }, 54 | })) 55 | 56 | a.Nil(au.Unload()) 57 | } 58 | -------------------------------------------------------------------------------- /plugin/auth/protos/account.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package gmqtt.auth.api; 4 | option go_package = ".;auth"; 5 | 6 | import "google/api/annotations.proto"; 7 | import "google/protobuf/empty.proto"; 8 | 9 | message ListAccountsRequest { 
10 | uint32 page_size = 1; 11 | uint32 page = 2; 12 | } 13 | 14 | message ListAccountsResponse { 15 | repeated Account accounts = 1; 16 | uint32 total_count = 2; 17 | } 18 | 19 | message GetAccountRequest { 20 | string username = 1; 21 | } 22 | 23 | message GetAccountResponse { 24 | Account account =1; 25 | } 26 | 27 | message UpdateAccountRequest { 28 | string username = 1; 29 | string password = 2; 30 | } 31 | 32 | message Account { 33 | string username = 1; 34 | string password = 2; 35 | } 36 | 37 | message DeleteAccountRequest { 38 | string username = 1; 39 | } 40 | 41 | service AccountService { 42 | // List all accounts 43 | rpc List (ListAccountsRequest) returns (ListAccountsResponse){ 44 | option (google.api.http) = { 45 | get: "/v1/accounts" 46 | }; 47 | } 48 | 49 | // Get the account for given username. 50 | // Return NotFound error when account not found. 51 | rpc Get (GetAccountRequest) returns (GetAccountResponse){ 52 | option (google.api.http) = { 53 | get: "/v1/accounts/{username}" 54 | }; 55 | } 56 | // Update the password for the account. 57 | // This API will create the account if not exists. 58 | rpc Update(UpdateAccountRequest) returns (google.protobuf.Empty) { 59 | option (google.api.http) = { 60 | post: "/v1/accounts/{username}" 61 | body:"*" 62 | }; 63 | } 64 | // Delete the account for given username 65 | rpc Delete (DeleteAccountRequest) returns (google.protobuf.Empty) { 66 | option (google.api.http) = { 67 | delete: "/v1/accounts/{username}" 68 | }; 69 | } 70 | } 71 | -------------------------------------------------------------------------------- /plugin/auth/protos/proto_gen.sh: -------------------------------------------------------------------------------- 1 | protoc -I. 
\ 2 | -I$GOPATH/src/github.com/grpc-ecosystem/grpc-gateway \ 3 | -I$GOPATH/src/github.com/grpc-ecosystem/grpc-gateway/third_party/googleapis \ 4 | --go-grpc_out=../ \ 5 | --go_out=../ \ 6 | --grpc-gateway_out=../ \ 7 | --swagger_out=../swagger \ 8 | *.proto -------------------------------------------------------------------------------- /plugin/auth/testdata/gmqtt_password.yml: -------------------------------------------------------------------------------- 1 | - username: u1 2 | password: p1 3 | - username: u2 4 | password: p2 -------------------------------------------------------------------------------- /plugin/auth/testdata/gmqtt_password_duplicated.yml: -------------------------------------------------------------------------------- 1 | - username: u1 2 | password: p1 3 | - username: u1 4 | password: p1 -------------------------------------------------------------------------------- /plugin/auth/testdata/gmqtt_password_save.yml: -------------------------------------------------------------------------------- 1 | - username: u1 2 | password: p1 3 | - username: u2 4 | password: p2 -------------------------------------------------------------------------------- /plugin/federation/examples/join_node3_config.yml: -------------------------------------------------------------------------------- 1 | listeners: 2 | - address: ":1885" 3 | api: 4 | grpc: 5 | - address: "tcp://127.0.0.1:8284" 6 | http: 7 | - address: "tcp://127.0.0.1:8283" 8 | map: "tcp://127.0.0.1:8284" # The backend gRPC server endpoint 9 | mqtt: 10 | session_expiry: 2h 11 | session_expiry_check_timer: 20s 12 | message_expiry: 2h 13 | max_packet_size: 268435456 14 | server_receive_maximum: 100 15 | max_keepalive: 60 16 | topic_alias_maximum: 10 17 | subscription_identifier_available: true 18 | wildcard_subscription_available: true 19 | shared_subscription_available: true 20 | maximum_qos: 2 21 | retain_available: true 22 | max_queued_messages: 10000 23 | max_inflight: 1000 24 | queue_qos0_messages: 
true 25 | delivery_mode: onlyonce # overlap or onlyonce 26 | allow_zero_length_clientid: true 27 | 28 | plugins: 29 | federation: 30 | # node_name is the unique identifier for the node in the federation. Defaults to hostname. 31 | node_name: node3 32 | # fed_addr is the gRPC server listening address for the federation internal communication. Defaults to :8901 33 | fed_addr: :8931 34 | # advertise_fed_addr is used to change the federation gRPC server address that we advertise to other nodes in the cluster. 35 | # Defaults to "fed_addr".However, in some cases, there may be a routable address that cannot be bound. 36 | # If the port is missing, the default federation port (8901) will be used. 37 | advertise_fed_addr: :8931 38 | # gossip_addr is the address that the gossip will listen on, It is used for both UDP and TCP gossip. Defaults to :8902 39 | gossip_addr: :8932 40 | # retry_join is the address of other nodes to join upon starting up. 41 | # If port is missing, the default gossip port (8902) will be used. 42 | #retry_join: 43 | # - 127.0.0.1:8912 44 | # rejoin_after_leave will be pass to "RejoinAfterLeave" in serf configuration. 45 | # It controls our interaction with the snapshot file. 46 | # When set to false (default), a leave causes a Serf to not rejoin the cluster until an explicit join is received. 47 | # If this is set to true, we ignore the leave, and rejoin the cluster on start. 48 | rejoin_after_leave: false 49 | # snapshot_path will be pass to "SnapshotPath" in serf configuration. 50 | # When Serf is started with a snapshot,it will attempt to join all the previously known nodes until one 51 | # succeeds and will also avoid replaying old user events. 52 | snapshot_path: 53 | 54 | # plugin loading orders 55 | plugin_order: 56 | # Uncomment auth to enable authentication. 
57 | # - auth 58 | #- prometheus 59 | #- admin 60 | - federation 61 | log: 62 | level: debug # debug | info | warn | error 63 | format: text # json | text 64 | # whether to dump MQTT packet in debug level 65 | dump_packet: false 66 | 67 | 68 | 69 | 70 | -------------------------------------------------------------------------------- /plugin/federation/examples/node1_config.yml: -------------------------------------------------------------------------------- 1 | listeners: 2 | - address: ":1883" 3 | api: 4 | grpc: 5 | - address: "tcp://127.0.0.1:8084" 6 | http: 7 | - address: "tcp://127.0.0.1:8083" 8 | map: "tcp://127.0.0.1:8084" # The backend gRPC server endpoint 9 | mqtt: 10 | session_expiry: 2h 11 | session_expiry_check_timer: 20s 12 | message_expiry: 2h 13 | max_packet_size: 268435456 14 | server_receive_maximum: 100 15 | max_keepalive: 60 16 | topic_alias_maximum: 10 17 | subscription_identifier_available: true 18 | wildcard_subscription_available: true 19 | shared_subscription_available: true 20 | maximum_qos: 2 21 | retain_available: true 22 | max_queued_messages: 10000 23 | max_inflight: 1000 24 | queue_qos0_messages: true 25 | delivery_mode: onlyonce # overlap or onlyonce 26 | allow_zero_length_clientid: true 27 | 28 | plugins: 29 | federation: 30 | # node_name is the unique identifier for the node in the federation. Defaults to hostname. 31 | node_name: node1 32 | # fed_addr is the gRPC server listening address for the federation internal communication. Defaults to :8901 33 | fed_addr: :8901 34 | # advertise_fed_addr is used to change the federation gRPC server address that we advertise to other nodes in the cluster. 35 | # Defaults to "fed_addr".However, in some cases, there may be a routable address that cannot be bound. 36 | # If the port is missing, the default federation port (8901) will be used. 37 | advertise_fed_addr: :8901 38 | # gossip_addr is the address that the gossip will listen on, It is used for both UDP and TCP gossip. 
Defaults to :8902 39 | gossip_addr: :8902 40 | # retry_join is the address of other nodes to join upon starting up. 41 | # If port is missing, the default gossip port (8902) will be used. 42 | retry_join: 43 | # Change 127.0.0.1 to real routable ip address if you run gmqtt in multiple nodes. 44 | - 127.0.0.1:8912 45 | # rejoin_after_leave will be pass to "RejoinAfterLeave" in serf configuration. 46 | # It controls our interaction with the snapshot file. 47 | # When set to false (default), a leave causes a Serf to not rejoin the cluster until an explicit join is received. 48 | # If this is set to true, we ignore the leave, and rejoin the cluster on start. 49 | rejoin_after_leave: false 50 | # snapshot_path will be pass to "SnapshotPath" in serf configuration. 51 | # When Serf is started with a snapshot,it will attempt to join all the previously known nodes until one 52 | # succeeds and will also avoid replaying old user events. 53 | snapshot_path: 54 | 55 | # plugin loading orders 56 | plugin_order: 57 | # Uncomment auth to enable authentication. 
58 | # - auth 59 | #- prometheus 60 | #- admin 61 | - federation 62 | log: 63 | level: debug # debug | info | warn | error 64 | format: text # json | text 65 | # whether to dump MQTT packet in debug level 66 | dump_packet: false 67 | 68 | 69 | 70 | 71 | -------------------------------------------------------------------------------- /plugin/federation/examples/node2_config.yml: -------------------------------------------------------------------------------- 1 | listeners: 2 | - address: ":1884" 3 | api: 4 | grpc: 5 | - address: "tcp://127.0.0.1:8184" 6 | http: 7 | - address: "tcp://127.0.0.1:8183" 8 | map: "tcp://127.0.0.1:8184" # The backend gRPC server endpoint 9 | mqtt: 10 | session_expiry: 2h 11 | session_expiry_check_timer: 20s 12 | message_expiry: 2h 13 | max_packet_size: 268435456 14 | server_receive_maximum: 100 15 | max_keepalive: 60 16 | topic_alias_maximum: 10 17 | subscription_identifier_available: true 18 | wildcard_subscription_available: true 19 | shared_subscription_available: true 20 | maximum_qos: 2 21 | retain_available: true 22 | max_queued_messages: 10000 23 | max_inflight: 1000 24 | queue_qos0_messages: true 25 | delivery_mode: onlyonce # overlap or onlyonce 26 | allow_zero_length_clientid: true 27 | 28 | plugins: 29 | federation: 30 | # node_name is the unique identifier for the node in the federation. Defaults to hostname. 31 | node_name: node2 32 | # fed_addr is the gRPC server listening address for the federation internal communication. Defaults to :8901 33 | fed_addr: :8911 34 | # advertise_fed_addr is used to change the federation gRPC server address that we advertise to other nodes in the cluster. 35 | # Defaults to "fed_addr".However, in some cases, there may be a routable address that cannot be bound. 36 | # If the port is missing, the default federation port (8901) will be used. 37 | advertise_fed_addr: :8911 38 | # gossip_addr is the address that the gossip will listen on, It is used for both UDP and TCP gossip. 
Defaults to :8902 39 | gossip_addr: :8912 40 | # retry_join is the address of other nodes to join upon starting up. 41 | # If port is missing, the default gossip port (8902) will be used. 42 | retry_join: 43 | # Change 127.0.0.1 to real routable ip address if you run gmqtt in multiple nodes. 44 | - 127.0.0.1:8902 45 | # rejoin_after_leave will be pass to "RejoinAfterLeave" in serf configuration. 46 | # It controls our interaction with the snapshot file. 47 | # When set to false (default), a leave causes a Serf to not rejoin the cluster until an explicit join is received. 48 | # If this is set to true, we ignore the leave, and rejoin the cluster on start. 49 | rejoin_after_leave: false 50 | # snapshot_path will be pass to "SnapshotPath" in serf configuration. 51 | # When Serf is started with a snapshot,it will attempt to join all the previously known nodes until one 52 | # succeeds and will also avoid replaying old user events. 53 | snapshot_path: 54 | 55 | # plugin loading orders 56 | plugin_order: 57 | # Uncomment auth to enable authentication. 58 | # - auth 59 | #- prometheus 60 | #- admin 61 | - federation 62 | log: 63 | level: debug # debug | info | warn | error 64 | format: text # json | text 65 | # whether to dump MQTT packet in debug level 66 | dump_packet: false 67 | 68 | 69 | 70 | 71 | -------------------------------------------------------------------------------- /plugin/federation/federation.pb_mock.go: -------------------------------------------------------------------------------- 1 | // Code generated by MockGen. DO NOT EDIT. 2 | // Source: plugin/federation/federation.pb.go 3 | 4 | // Package federation is a generated GoMock package. 
5 | package federation 6 | 7 | import ( 8 | gomock "github.com/golang/mock/gomock" 9 | reflect "reflect" 10 | ) 11 | 12 | // MockisEvent_Event is a mock of isEvent_Event interface 13 | type MockisEvent_Event struct { 14 | ctrl *gomock.Controller 15 | recorder *MockisEvent_EventMockRecorder 16 | } 17 | 18 | // MockisEvent_EventMockRecorder is the mock recorder for MockisEvent_Event 19 | type MockisEvent_EventMockRecorder struct { 20 | mock *MockisEvent_Event 21 | } 22 | 23 | // NewMockisEvent_Event creates a new mock instance 24 | func NewMockisEvent_Event(ctrl *gomock.Controller) *MockisEvent_Event { 25 | mock := &MockisEvent_Event{ctrl: ctrl} 26 | mock.recorder = &MockisEvent_EventMockRecorder{mock} 27 | return mock 28 | } 29 | 30 | // EXPECT returns an object that allows the caller to indicate expected use 31 | func (m *MockisEvent_Event) EXPECT() *MockisEvent_EventMockRecorder { 32 | return m.recorder 33 | } 34 | 35 | // isEvent_Event mocks base method 36 | func (m *MockisEvent_Event) isEvent_Event() { 37 | m.ctrl.T.Helper() 38 | m.ctrl.Call(m, "isEvent_Event") 39 | } 40 | 41 | // isEvent_Event indicates an expected call of isEvent_Event 42 | func (mr *MockisEvent_EventMockRecorder) isEvent_Event() *gomock.Call { 43 | mr.mock.ctrl.T.Helper() 44 | return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "isEvent_Event", reflect.TypeOf((*MockisEvent_Event)(nil).isEvent_Event)) 45 | } 46 | -------------------------------------------------------------------------------- /plugin/federation/membership.go: -------------------------------------------------------------------------------- 1 | package federation 2 | 3 | import ( 4 | "time" 5 | 6 | "github.com/google/uuid" 7 | "github.com/hashicorp/serf/serf" 8 | "go.uber.org/zap" 9 | ) 10 | 11 | // iSerf is the interface for *serf.Serf. 12 | // It is used for test. 
13 | type iSerf interface { 14 | Join(existing []string, ignoreOld bool) (int, error) 15 | RemoveFailedNode(node string) error 16 | Leave() error 17 | Members() []serf.Member 18 | Shutdown() error 19 | } 20 | 21 | var servePeerEventStream = func(p *peer) { 22 | p.serveEventStream() 23 | } 24 | 25 | func (f *Federation) startSerf(t *time.Timer) error { 26 | defer func() { 27 | t.Reset(f.config.RetryInterval) 28 | }() 29 | if _, err := f.serf.Join(f.config.RetryJoin, true); err != nil { 30 | return err 31 | } 32 | go f.eventHandler() 33 | return nil 34 | } 35 | 36 | func (f *Federation) eventHandler() { 37 | for { 38 | select { 39 | case evt := <-f.serfEventCh: 40 | switch evt.EventType() { 41 | case serf.EventMemberJoin: 42 | f.nodeJoin(evt.(serf.MemberEvent)) 43 | case serf.EventMemberLeave, serf.EventMemberFailed, serf.EventMemberReap: 44 | f.nodeFail(evt.(serf.MemberEvent)) 45 | case serf.EventUser: 46 | case serf.EventMemberUpdate: 47 | // TODO 48 | case serf.EventQuery: // Ignore 49 | default: 50 | } 51 | case <-f.exit: 52 | f.memberMu.Lock() 53 | for _, v := range f.peers { 54 | v.stop() 55 | } 56 | f.memberMu.Unlock() 57 | return 58 | } 59 | } 60 | } 61 | 62 | func (f *Federation) nodeJoin(member serf.MemberEvent) { 63 | f.memberMu.Lock() 64 | defer f.memberMu.Unlock() 65 | for _, v := range member.Members { 66 | if v.Name == f.nodeName { 67 | continue 68 | } 69 | log.Info("member joined", zap.String("node_name", v.Name)) 70 | if _, ok := f.peers[v.Name]; !ok { 71 | p := &peer{ 72 | fed: f, 73 | member: v, 74 | exit: make(chan struct{}), 75 | sessionID: uuid.New().String(), 76 | queue: newEventQueue(), 77 | localName: f.nodeName, 78 | } 79 | f.peers[v.Name] = p 80 | go servePeerEventStream(p) 81 | } 82 | } 83 | } 84 | 85 | func (f *Federation) nodeFail(member serf.MemberEvent) { 86 | f.memberMu.Lock() 87 | defer f.memberMu.Unlock() 88 | for _, v := range member.Members { 89 | if v.Name == f.nodeName { 90 | continue 91 | } 92 | if p, ok := f.peers[v.Name]; ok 
{ 93 | log.Error("node failed, close stream client", zap.String("node_name", v.Name)) 94 | p.stop() 95 | delete(f.peers, v.Name) 96 | _ = f.fedSubStore.UnsubscribeAll(v.Name) 97 | f.sessionMgr.del(v.Name) 98 | } 99 | } 100 | } 101 | -------------------------------------------------------------------------------- /plugin/federation/protos/proto_gen.sh: -------------------------------------------------------------------------------- 1 | protoc -I. \ 2 | -I$GOPATH/src/github.com/grpc-ecosystem/grpc-gateway \ 3 | -I$GOPATH/src/github.com/grpc-ecosystem/grpc-gateway/third_party/googleapis \ 4 | --go-grpc_out=../ \ 5 | --go_out=../ \ 6 | --grpc-gateway_out=../ \ 7 | --swagger_out=../swagger \ 8 | *.proto 9 | -------------------------------------------------------------------------------- /plugin/prometheus/README.md: -------------------------------------------------------------------------------- 1 | # Prometheus 2 | `Prometheus` implements the prometheus exporter for gmqtt. 3 | Default URL: 127.0.0.1:8082/metrics 4 | 5 | # Metrics 6 | 7 | metric name | Type | Labels 8 | ---|---|--- 9 | gmqtt_clients_connected_total | Counter | 10 | gmqtt_messages_dropped_total | Counter | qos: qos of the dropped message 11 | gmqtt_packets_received_bytes_total | Counter | type: type of the packet 12 | gmqtt_packets_received_total | Counter | type: type of the packet 13 | gmqtt_packets_sent_bytes_total | Counter | type: type of the packet 14 | gmqtt_packets_sent_total | Counter | type: type of the packet 15 | gmqtt_sessions_created_total | Counter | 16 | gmqtt_sessions_terminated_total | Counter | reason: the reason of termination. 
(expired|taken_over|normal) 17 | gmqtt_sessions_active_current | Gauge | 18 | gmqtt_sessions_expired_total | Counter | 19 | gmqtt_sessions_inactive_current | Gauge | 20 | gmqtt_subscriptions_current | Gauge | 21 | gmqtt_subscriptions_total | Counter | 22 | gmqtt_messages_queued_current | Gauge | 23 | gmqtt_messages_received_total | Counter | qos: qos of the message 24 | gmqtt_messages_sent_total | Counter | qos: qos of the message -------------------------------------------------------------------------------- /plugin/prometheus/config.go: -------------------------------------------------------------------------------- 1 | package prometheus 2 | 3 | import ( 4 | "errors" 5 | "net" 6 | ) 7 | 8 | // Config is the configuration for the prometheus plugin. 9 | type Config struct { 10 | // ListenAddress is the address that the exporter will listen on. 11 | ListenAddress string `yaml:"listen_address"` 12 | // Path is the exporter url path. 13 | Path string `yaml:"path"` 14 | } 15 | 16 | // Validate validates the configuration, and return an error if it is invalid. 17 | func (c *Config) Validate() error { 18 | _, _, err := net.SplitHostPort(c.ListenAddress) 19 | if err != nil { 20 | return errors.New("invalid listen_address") 21 | } 22 | return nil 23 | } 24 | 25 | // DefaultConfig is the default configuration. 
26 | var DefaultConfig = Config{ 27 | ListenAddress: ":8082", 28 | Path: "/metrics", 29 | } 30 | 31 | func (c *Config) UnmarshalYAML(unmarshal func(interface{}) error) error { 32 | type cfg Config 33 | var v = &struct { 34 | Prometheus cfg `yaml:"prometheus"` 35 | }{ 36 | Prometheus: cfg(DefaultConfig), 37 | } 38 | if err := unmarshal(v); err != nil { 39 | return err 40 | } 41 | empty := cfg(Config{}) 42 | if v.Prometheus == empty { 43 | v.Prometheus = cfg(DefaultConfig) 44 | } 45 | *c = Config(v.Prometheus) 46 | return nil 47 | } 48 | -------------------------------------------------------------------------------- /plugin/prometheus/hooks.go: -------------------------------------------------------------------------------- 1 | package prometheus 2 | 3 | import ( 4 | "github.com/DrmagicE/gmqtt/server" 5 | ) 6 | 7 | func (p *Prometheus) HookWrapper() server.HookWrapper { 8 | return server.HookWrapper{} 9 | } 10 | -------------------------------------------------------------------------------- /plugin/thingspanel/config.go: -------------------------------------------------------------------------------- 1 | package thingspanel 2 | 3 | // Config is the configuration for the thingspanel plugin. 4 | type Config struct { 5 | // add your config fields 6 | } 7 | 8 | // Validate validates the configuration, and return an error if it is invalid. 9 | func (c *Config) Validate() error { 10 | return nil 11 | } 12 | 13 | // DefaultConfig is the default configuration. 
14 | var DefaultConfig = Config{} 15 | 16 | func (c *Config) UnmarshalYAML(unmarshal func(interface{}) error) error { 17 | return nil 18 | } 19 | -------------------------------------------------------------------------------- /plugin/thingspanel/mqtt.go: -------------------------------------------------------------------------------- 1 | package thingspanel 2 | 3 | import ( 4 | "fmt" 5 | "time" 6 | 7 | mqtt "github.com/eclipse/paho.mqtt.golang" 8 | "github.com/spf13/viper" 9 | ) 10 | 11 | type MqttClient struct { 12 | Client mqtt.Client 13 | IsFlag bool 14 | } 15 | 16 | var DefaultMqttClient *MqttClient = &MqttClient{} 17 | 18 | func (c *MqttClient) MqttInit() error { 19 | opts := mqtt.NewClientOptions() 20 | opts.SetUsername("root") 21 | password := viper.GetString("mqtt.password") 22 | opts.SetPassword(password) 23 | addr := viper.GetString("mqtt.broker") 24 | if addr == "" { 25 | addr = "localhost:1883" 26 | } 27 | opts.AddBroker(addr) 28 | // 干净会话 29 | opts.SetCleanSession(true) 30 | // 失败重连 31 | opts.SetAutoReconnect(true) 32 | opts.SetConnectRetryInterval(1 * time.Second) // 初始连接重试间隔 33 | opts.SetMaxReconnectInterval(200 * time.Second) // 丢失连接后的最大重试间隔 34 | 35 | opts.SetOrderMatters(false) //设置消息的顺序 36 | //opts.OnConnectionLost = connectLostHandler 37 | opts.SetOnConnectHandler(func(c mqtt.Client) { 38 | fmt.Println("Mqtt客户端已连接") 39 | }) 40 | opts.SetClientID("thingspanel-gmqtt-client") 41 | c.Client = mqtt.NewClient(opts) 42 | // 等待连接成功 43 | // 等待连接成功 44 | for { 45 | if token := c.Client.Connect(); token.Wait() && token.Error() != nil { 46 | fmt.Println("Mqtt客户端连接失败(", addr, "),等待重连...") 47 | time.Sleep(1 * time.Second) 48 | } else { 49 | fmt.Println("Mqtt客户端连接成功") 50 | c.IsFlag = true 51 | break 52 | } 53 | } 54 | return nil 55 | } 56 | 57 | func (c *MqttClient) SendData(topic string, data []byte) error { 58 | defer func() { 59 | if err := recover(); err != nil { 60 | fmt.Println("【SendData】异常捕捉:", err) 61 | return 62 | } 63 | }() 64 | //go func() { 65 | 
Log.Info("检查MqttClIent连接状态...") 66 | if !c.IsFlag { 67 | i := 1 68 | for { 69 | fmt.Println("等待...", i) 70 | if i == 10 || c.IsFlag { 71 | break 72 | } 73 | time.Sleep(1 * time.Second) 74 | i++ 75 | } 76 | } 77 | Log.Info("发送设备状态...") 78 | token := c.Client.Publish(topic, 0, false, string(data)) 79 | if !token.WaitTimeout(5 * time.Second) { 80 | Log.Warn("发送设备状态超时") 81 | } else if err := token.Error(); err != nil { 82 | Log.Warn("发送设备状态失败: " + err.Error()) 83 | } 84 | Log.Info("发送设备状态完成") 85 | //}() 86 | return nil 87 | } 88 | -------------------------------------------------------------------------------- /plugin/thingspanel/thingspanel.go: -------------------------------------------------------------------------------- 1 | package thingspanel 2 | 3 | import ( 4 | "fmt" 5 | "io/ioutil" 6 | "log" 7 | "net/http" 8 | "strings" 9 | 10 | "go.uber.org/zap" 11 | 12 | "github.com/DrmagicE/gmqtt/config" 13 | "github.com/DrmagicE/gmqtt/server" 14 | "github.com/spf13/viper" 15 | ) 16 | 17 | var _ server.Plugin = (*Thingspanel)(nil) 18 | 19 | const Name = "thingspanel" 20 | 21 | func init() { 22 | log.Println("系统配置文件初始化...") 23 | viper.SetEnvPrefix("GMQTT") 24 | viper.AutomaticEnv() 25 | viper.SetEnvKeyReplacer(strings.NewReplacer(".", "_")) 26 | viper.SetConfigName("thingspanel") 27 | viper.SetConfigType("yml") 28 | viper.AddConfigPath(".") 29 | err := viper.ReadInConfig() 30 | if err != nil { 31 | panic(fmt.Errorf("failed to read configuration file: %s", err)) 32 | } 33 | log.Println("系统配置文件初始化完成") 34 | Init() //启动数据库和redis 35 | go DefaultMqttClient.MqttInit() 36 | server.RegisterPlugin(Name, New) 37 | config.RegisterDefaultPluginConfig(Name, &DefaultConfig) 38 | } 39 | 40 | func New(config config.Config) (server.Plugin, error) { 41 | //panic("implement me") 42 | return &Thingspanel{}, nil 43 | } 44 | 45 | var Log *zap.Logger 46 | 47 | type Thingspanel struct { 48 | } 49 | 50 | func (t *Thingspanel) Load(service server.Server) error { 51 | Log = 
server.LoggerWithField(zap.String("plugin", Name)) 52 | return nil 53 | } 54 | 55 | func (t *Thingspanel) Unload() error { 56 | return nil 57 | } 58 | 59 | func (t *Thingspanel) Name() string { 60 | return Name 61 | } 62 | 63 | // 不用 64 | func (t *Thingspanel) UpdateStatus(accessToken string, status string) { 65 | url := "/api/device/status" 66 | method := "POST" 67 | payload := strings.NewReader(`"accessToken": "` + accessToken + `","values":{"status": "` + status + `"}}`) 68 | client := &http.Client{} 69 | req, err := http.NewRequest(method, url, payload) 70 | if err != nil { 71 | fmt.Println(err) 72 | return 73 | } 74 | req.Header.Add("Content-Type", "application/json") 75 | res, err := client.Do(req) 76 | if err != nil { 77 | fmt.Println(err) 78 | return 79 | } 80 | defer res.Body.Close() 81 | body, err := ioutil.ReadAll(res.Body) 82 | if err != nil { 83 | fmt.Println(err) 84 | return 85 | } 86 | fmt.Println(string(body)) 87 | } 88 | -------------------------------------------------------------------------------- /plugin/thingspanel/util/check_pub_topic.go: -------------------------------------------------------------------------------- 1 | package util 2 | 3 | import ( 4 | "strings" 5 | ) 6 | 7 | // 物联网设备主题规则列表 8 | var pubList = []string{ 9 | "devices/telemetry", // 遥测上报 10 | "devices/attributes/+", // 属性上报 11 | "devices/event/+", // 事件上报 12 | "ota/devices/progress", // 设备升级进度更新 13 | "devices/attributes/set/response/+", // 属性设置响应上报 14 | "devices/command/response/+", // 命令响应上报 15 | 16 | "gateway/telemetry", // 设备遥测(网关) 17 | "gateway/attributes/+", // 属性上报 (网关) 18 | "gateway/event/+", // 事件上报 (网关) 19 | "gateway/attributes/set/response/+", // 属性设置响应上报 (网关) 20 | "gateway/command/response/+", // 命令响应上报 (网关) 21 | 22 | "devices/register", //网关子设备注册 23 | "devices/config/down", //设备配置下载 24 | 25 | "+/up", //心智悦喷淋一体机上行数据 26 | } 27 | 28 | // MQTT 通配符 29 | const mqttWildcard = "+" 30 | 31 | // ValidateTopic 检查一个主题是否符合pubList里的任何一种模式 32 | func ValidateTopic(topic string) 
bool { 33 | for _, pattern := range pubList { 34 | if matchesPattern(topic, pattern) { 35 | return true 36 | } 37 | } 38 | return false 39 | } 40 | 41 | // matchesPattern 检查一个主题是否符合给定的模式 42 | func matchesPattern(topic, pattern string) bool { 43 | topicParts := strings.Split(topic, "/") 44 | patternParts := strings.Split(pattern, "/") 45 | 46 | // 如果主题和模式部分的长度不一致,则不匹配 47 | if len(topicParts) != len(patternParts) { 48 | return false 49 | } 50 | 51 | // 检查是否直接匹配或者是通配符 52 | for i := range topicParts { 53 | if patternParts[i] != mqttWildcard && topicParts[i] != patternParts[i] { 54 | return false 55 | } 56 | } 57 | 58 | return true 59 | } 60 | -------------------------------------------------------------------------------- /plugin/thingspanel/util/check_pub_topic_test.go: -------------------------------------------------------------------------------- 1 | package util 2 | 3 | import "testing" 4 | 5 | func TestValidateTopic(t *testing.T) { 6 | var cases = []struct { 7 | input string 8 | want bool 9 | }{ 10 | {"devices/telemetry", true}, 11 | {"devices/attributes/test", true}, 12 | {"devices/event/test", true}, 13 | {"gateway/attributes/test", true}, 14 | {"gateway/event/test", true}, 15 | {"devices/telemetry/test", false}, 16 | {"devices/test/telemetry", false}, 17 | {"devices_test", false}, 18 | {"", false}, 19 | {"xxxx/up", true}, 20 | } 21 | 22 | for _, c := range cases { 23 | got := ValidateTopic(c.input) 24 | if got != c.want { 25 | t.Errorf("ValidateTopic(%q) == %v, want %v", c.input, got, c.want) 26 | } 27 | } 28 | } 29 | -------------------------------------------------------------------------------- /plugin/thingspanel/util/check_sub_topic.go: -------------------------------------------------------------------------------- 1 | package util 2 | 3 | import ( 4 | "strings" 5 | ) 6 | 7 | var subList = []string{ 8 | "devices/telemetry/control/{device_number}", //订阅平台下发的控制 9 | "devices/telemetry/control/{device_number}/+", //订阅平台下发的控制 10 | 
"devices/attributes/set/{device_number}/+", //订阅平台下发的属性设置 11 | "devices/attributes/get/{device_number}", //订阅平台对属性的请求 12 | "devices/command/{device_number}/+", //订阅命令 13 | 14 | "ota/devices/infrom/{device_number}", //接收升级任务(固件升级相关) 15 | 16 | "devices/attributes/response/{device_number}/+", //订阅平台收到属性的响应 17 | "devices/event/response/{device_number}/+", //接收平台收到事件的响应 18 | 19 | "gateway/telemetry/control/{device_number}", //订阅平台下发的控制(网关) 20 | "gateway/attributes/set/{device_number}/+", //订阅平台下发的属性设置(网关) 21 | "gateway/attributes/get/{device_number}", //订阅平台对属性的请求(网关) 22 | "gateway/command/{device_number}/+", //订阅命令(网关) 23 | 24 | "gateway/attributes/response/{device_number}/+", //订阅平台收到属性的响应(网关) 25 | "gateway/event/response/{device_number}/+", //接收平台收到事件的响应(网关) 26 | 27 | "{device_number}/down", //心智悦喷淋一体机下行数据 28 | 29 | "devices/register/response/+", //网关子设备注册平台回复 30 | "devices/config/down/response/+", //设备配置下载平台回复 31 | } 32 | 33 | // ValidateTopic 检查一个主题是否符合subList里的任何一种模式 34 | func ValidateSubTopic(topic string) bool { 35 | for _, pattern := range subList { 36 | if matchesPatternSub(topic, pattern) { 37 | return true 38 | } 39 | } 40 | return false 41 | } 42 | 43 | // matchesPattern 检查一个主题是否符合给定的模式 44 | func matchesPatternSub(topic, pattern string) bool { 45 | topicParts := strings.Split(topic, "/") 46 | patternParts := strings.Split(pattern, "/") 47 | 48 | // 如果主题和模式部分的长度不一致,则不匹配 49 | if len(topicParts) != len(patternParts) { 50 | return false 51 | } 52 | 53 | // 检查每个部分 54 | for i := range topicParts { 55 | switch patternParts[i] { 56 | case "{device_number}": 57 | // {device_number} 部分不能是+或者# 58 | if topicParts[i] == "+" || topicParts[i] == "#" { 59 | return false 60 | } 61 | case "+": 62 | // +部分不可以是#通配符,可以是其他任意字符包括+通配符 63 | if topicParts[i] == "#" { 64 | return false 65 | } 66 | default: 67 | // 其他部分必须相等 68 | if topicParts[i] != patternParts[i] { 69 | return false 70 | } 71 | } 72 | } 73 | 74 | return true 75 | } 76 | 
-------------------------------------------------------------------------------- /plugin/thingspanel/util/check_sub_topic_test.go: -------------------------------------------------------------------------------- 1 | package util 2 | 3 | import "testing" 4 | 5 | func TestValidateSubTopic(t *testing.T) { 6 | var cases = []struct { 7 | input string 8 | want bool 9 | }{ 10 | {"devices/telemetry", false}, 11 | {"devices/telemetry/xxxxxx/+", true}, 12 | {"devices/attributes/set/xxxxxx/+", true}, 13 | {"devices/attributes/set/+/+", false}, 14 | {"devices/attributes/get/xxxxxx", true}, 15 | {"devices/command/xxxxx/+", true}, 16 | {"ota/devices/infrom/xxxxx", true}, 17 | {"evices/attributes/response/xxxx/+", true}, 18 | {"devices/event/response/xxxxxx/+", true}, 19 | {"", false}, 20 | {"001/down", true}, 21 | } 22 | 23 | for _, c := range cases { 24 | got := ValidateSubTopic(c.input) 25 | if got != c.want { 26 | t.Errorf("ValidateSubTopic(%q) == %v, want %v", c.input, got, c.want) 27 | } 28 | } 29 | } 30 | -------------------------------------------------------------------------------- /plugin_generate.go: -------------------------------------------------------------------------------- 1 | // +build ignore 2 | 3 | package main 4 | 5 | import ( 6 | "bytes" 7 | "go/format" 8 | "io" 9 | "io/ioutil" 10 | "log" 11 | "strings" 12 | "text/template" 13 | 14 | "gopkg.in/yaml.v2" 15 | ) 16 | 17 | var tmpl = `//go:generate sh -c "cd ../../ && go run plugin_generate.go" 18 | // generated by plugin_generate.go; DO NOT EDIT 19 | 20 | package main 21 | 22 | import ( 23 | {{- range $index, $element := .}} 24 | _ "{{$element}}" 25 | {{- end}} 26 | ) 27 | ` 28 | 29 | const ( 30 | pluginFile = "./cmd/gmqttd/plugins.go" 31 | pluginCfg = "plugin_imports.yml" 32 | importPath = "github.com/DrmagicE/gmqtt/plugin" 33 | ) 34 | 35 | type ymlCfg struct { 36 | Packages []string `yaml:"packages"` 37 | } 38 | 39 | func main() { 40 | b, err := ioutil.ReadFile(pluginCfg) 41 | if err != nil { 42 | 
log.Fatalf("ReadFile error %s", err) 43 | return 44 | } 45 | 46 | var cfg ymlCfg 47 | err = yaml.Unmarshal(b, &cfg) 48 | if err != nil { 49 | log.Fatalf("Unmarshal error: %s", err) 50 | return 51 | } 52 | t, err := template.New("plugin_gen").Parse(tmpl) 53 | if err != nil { 54 | log.Fatalf("Parse template error: %s", err) 55 | return 56 | } 57 | 58 | for k, v := range cfg.Packages { 59 | if !strings.Contains(v, "/") { 60 | cfg.Packages[k] = importPath + "/" + v 61 | } 62 | } 63 | 64 | if err != nil && err != io.EOF { 65 | log.Fatalf("read error: %s", err) 66 | return 67 | } 68 | buf := &bytes.Buffer{} 69 | err = t.Execute(buf, cfg.Packages) 70 | if err != nil { 71 | log.Fatalf("excute template error: %s", err) 72 | return 73 | } 74 | rs, err := format.Source(buf.Bytes()) 75 | if err != nil { 76 | log.Fatalf("format error: %s", err) 77 | return 78 | } 79 | err = ioutil.WriteFile(pluginFile, rs, 0666) 80 | if err != nil { 81 | log.Fatalf("writeFile error: %s", err) 82 | return 83 | } 84 | return 85 | } 86 | -------------------------------------------------------------------------------- /plugin_imports.yml: -------------------------------------------------------------------------------- 1 | packages: 2 | - admin 3 | - prometheus 4 | - federation 5 | - auth 6 | - thingspanel 7 | # for external plugin, use full import path 8 | # - github.com/DrmagicE/gmqtt/plugin/prometheus -------------------------------------------------------------------------------- /retained/interface.go: -------------------------------------------------------------------------------- 1 | package retained 2 | 3 | import ( 4 | "github.com/DrmagicE/gmqtt" 5 | ) 6 | 7 | // IterateFn is the callback function used by iterate() 8 | // Return false means to stop the iteration. 9 | type IterateFn func(message *gmqtt.Message) bool 10 | 11 | // Store is the interface used by gmqtt.server and external logic to handler the operations of retained messages. 
12 | // User can get the implementation from gmqtt.Server interface. 13 | // This interface provides the ability for extensions to interact with the retained message store. 14 | // Notice: 15 | // This methods will not trigger any gmqtt hooks. 16 | type Store interface { 17 | // GetRetainedMessage returns the message that equals the passed topic. 18 | GetRetainedMessage(topicName string) *gmqtt.Message 19 | // ClearAll clears all retained messages. 20 | ClearAll() 21 | // AddOrReplace adds or replaces a retained message. 22 | AddOrReplace(message *gmqtt.Message) 23 | // remove removes a retained message. 24 | Remove(topicName string) 25 | // GetMatchedMessages returns the retained messages that match the passed topic filter. 26 | GetMatchedMessages(topicFilter string) []*gmqtt.Message 27 | // Iterate iterate all retained messages. The callback is called once for each message. 28 | // If callback return false, the iteration will be stopped. 29 | // Notice: 30 | // The results are not sorted in any way, no ordering of any kind is guaranteed. 31 | // This method will walk through all retained messages, 32 | // so this will be a expensive operation if there are a large number of retained messages. 33 | Iterate(fn IterateFn) 34 | } 35 | -------------------------------------------------------------------------------- /retained/trie/trie_db.go: -------------------------------------------------------------------------------- 1 | package trie 2 | 3 | import ( 4 | "sync" 5 | 6 | "github.com/DrmagicE/gmqtt" 7 | "github.com/DrmagicE/gmqtt/retained" 8 | ) 9 | 10 | // trieDB implement the retain.Store, it use trie tree to store retain messages . 
11 | type trieDB struct { 12 | sync.RWMutex 13 | userTrie *topicTrie 14 | systemTrie *topicTrie 15 | } 16 | 17 | func (t *trieDB) Iterate(fn retained.IterateFn) { 18 | t.RLock() 19 | defer t.RUnlock() 20 | if !t.userTrie.preOrderTraverse(fn) { 21 | return 22 | } 23 | t.systemTrie.preOrderTraverse(fn) 24 | } 25 | 26 | func (t *trieDB) getTrie(topicName string) *topicTrie { 27 | if isSystemTopic(topicName) { 28 | return t.systemTrie 29 | } 30 | return t.userTrie 31 | } 32 | 33 | // GetRetainedMessage return the retain message of the given topic name. 34 | // return nil if the topic name not exists 35 | func (t *trieDB) GetRetainedMessage(topicName string) *gmqtt.Message { 36 | t.RLock() 37 | defer t.RUnlock() 38 | node := t.getTrie(topicName).find(topicName) 39 | if node != nil { 40 | return node.msg.Copy() 41 | } 42 | return nil 43 | } 44 | 45 | // ClearAll clear all retain messages. 46 | func (t *trieDB) ClearAll() { 47 | t.Lock() 48 | defer t.Unlock() 49 | t.systemTrie = newTopicTrie() 50 | t.userTrie = newTopicTrie() 51 | } 52 | 53 | // AddOrReplace add or replace a retain message. 54 | func (t *trieDB) AddOrReplace(message *gmqtt.Message) { 55 | t.Lock() 56 | defer t.Unlock() 57 | t.getTrie(message.Topic).addRetainMsg(message.Topic, message) 58 | } 59 | 60 | // remove remove the retain message of the topic name. 61 | func (t *trieDB) Remove(topicName string) { 62 | t.Lock() 63 | defer t.Unlock() 64 | t.getTrie(topicName).remove(topicName) 65 | } 66 | 67 | // GetMatchedMessages returns all messages that match the topic filter. 
68 | func (t *trieDB) GetMatchedMessages(topicFilter string) []*gmqtt.Message { 69 | t.RLock() 70 | defer t.RUnlock() 71 | return t.getTrie(topicFilter).getMatchedMessages(topicFilter) 72 | } 73 | 74 | func NewStore() *trieDB { 75 | return &trieDB{ 76 | userTrie: newTopicTrie(), 77 | systemTrie: newTopicTrie(), 78 | } 79 | } 80 | -------------------------------------------------------------------------------- /script/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "fmt" 5 | "time" 6 | 7 | mqtt "github.com/eclipse/paho.mqtt.golang" 8 | ) 9 | 10 | // mqtt服务器地址 11 | var BROKEN string = "127.0.0.1:1883" 12 | 13 | //每隔多久发送一次消息,单位s 14 | var LOOP_TIME int = 5 15 | 16 | // 模拟设备的数量 17 | var DEVICE_NUM int = 1000 18 | 19 | // 发送主题 20 | var TOPIC string = "device/attributes" 21 | 22 | // 发送的数据 23 | var PAYLOAD string = `{"temperature": 20, "humidity": 60}` 24 | 25 | // 消息质量 26 | var QOS int = 0 27 | 28 | func main() { 29 | fmt.Println("开始执行脚本...") 30 | // 循环创建mqtt客户端并每个客户端循环发送消息 31 | MqttPublishLoopClient(TOPIC, PAYLOAD, QOS) 32 | } 33 | 34 | // 新增mqtt客户端连接 35 | func MqttClient(clientId string) (mqtt.Client, error) { 36 | // 掉线重连 37 | var connectLostHandler mqtt.ConnectionLostHandler = func(c mqtt.Client, err error) { 38 | fmt.Printf("("+clientId+")Mqtt Connect lost: %v", err) 39 | i := 0 40 | for { 41 | time.Sleep(5 * time.Second) 42 | if !c.IsConnectionOpen() { 43 | i++ 44 | fmt.Println("("+clientId+")Mqtt客户端掉线重连...", i) 45 | if token := c.Connect(); token.Wait() && token.Error() != nil { 46 | fmt.Println("(" + clientId + ")Mqtt客户端连接失败...") 47 | } else { 48 | break 49 | } 50 | } else { 51 | //subscribe(msgProc1, gatewayMsgProc) 52 | break 53 | } 54 | } 55 | } 56 | opts := mqtt.NewClientOptions() 57 | opts.SetClientID(clientId) 58 | opts.AddBroker(BROKEN) 59 | opts.SetAutoReconnect(true) 60 | opts.SetOrderMatters(false) 61 | opts.OnConnectionLost = connectLostHandler 62 | 
opts.SetOnConnectHandler(func(c mqtt.Client) { 63 | fmt.Println("Mqtt客户端已连接(" + clientId + ")") 64 | }) 65 | reconnec_number := 0 66 | c := mqtt.NewClient(opts) 67 | // 异步建立连接,失败重连 68 | for { 69 | 70 | if token := c.Connect(); token.Wait() && token.Error() != nil { 71 | reconnec_number++ 72 | fmt.Println("链接错误错误说明:", token.Error().Error()) 73 | fmt.Println("Mqtt客户端连接失败("+clientId+")...重试", reconnec_number) 74 | } else { 75 | MqttPublishLoop(TOPIC, PAYLOAD, QOS, c) 76 | fmt.Println("Mqtt客户端连接成功(" + clientId + ")") 77 | break 78 | } 79 | time.Sleep(5 * time.Second) 80 | } 81 | return c, nil 82 | // 1.连接mqtt服务器 83 | // 2.发布消息 84 | // 3.断开连接 85 | } 86 | 87 | // 发送mqtt消息 88 | func MqttPublish(topic string, payload string, qos int, c mqtt.Client) { 89 | cc := c.OptionsReader() 90 | token := c.Publish(topic, byte(qos), false, payload) 91 | token.Wait() 92 | fmt.Printf("%s发送消息成功,topic:%s, payload:%s\n", cc.ClientID(), topic, payload) 93 | } 94 | 95 | // 循环发送mqtt消息 96 | func MqttPublishLoop(topic string, payload string, qos int, c mqtt.Client) { 97 | for { 98 | MqttPublish(topic, payload, qos, c) 99 | time.Sleep(time.Duration(LOOP_TIME) * time.Second) 100 | } 101 | } 102 | 103 | // 循环创建mqtt客户端并循环发送消息 104 | func MqttPublishLoopClient(topic string, payload string, qos int) { 105 | //循环生成100个clientId 106 | for i := 0; i < DEVICE_NUM; i++ { 107 | clientId := fmt.Sprintf("client_%d", i) 108 | go MqttClient(clientId) 109 | } 110 | time.Sleep(100 * time.Second) 111 | } 112 | -------------------------------------------------------------------------------- /server/limiter_test.go: -------------------------------------------------------------------------------- 1 | package server 2 | 3 | import ( 4 | "testing" 5 | "time" 6 | 7 | "github.com/stretchr/testify/assert" 8 | 9 | "github.com/DrmagicE/gmqtt/pkg/packets" 10 | ) 11 | 12 | func Test_packetIDLimiter(t *testing.T) { 13 | a := assert.New(t) 14 | p := newPacketIDLimiter(10) 15 | ids := p.pollPacketIDs(20) 16 | a.Len(ids, 10) 17 | 
a.Equal([]packets.PacketID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, ids) 18 | 19 | p.batchRelease([]packets.PacketID{7, 8, 9}) 20 | 21 | ids = p.pollPacketIDs(4) 22 | a.Len(ids, 3) 23 | a.Equal([]packets.PacketID{11, 12, 13}, ids) 24 | 25 | c := make(chan struct{}) 26 | go func() { 27 | p.pollPacketIDs(1) 28 | c <- struct{}{} 29 | }() 30 | select { 31 | case <-c: 32 | t.Fatal("pollPacketIDs should be blocked") 33 | case <-time.After(1 * time.Second): 34 | } 35 | p.close() 36 | a.Nil(p.pollPacketIDs(10)) 37 | } 38 | 39 | func Test_packetIDLimiterMax(t *testing.T) { 40 | a := assert.New(t) 41 | p := newPacketIDLimiter(65535) 42 | ids := p.pollPacketIDs(65535) 43 | a.Len(ids, 65535) 44 | p.batchRelease([]packets.PacketID{1, 2, 3, 65535}) 45 | a.Equal([]packets.PacketID{1, 2, 3}, p.pollPacketIDs(3)) 46 | a.Equal([]packets.PacketID{65535}, p.pollPacketIDs(3)) 47 | 48 | } 49 | -------------------------------------------------------------------------------- /server/options.go: -------------------------------------------------------------------------------- 1 | package server 2 | 3 | import ( 4 | "net" 5 | 6 | "github.com/DrmagicE/gmqtt/config" 7 | "github.com/DrmagicE/gmqtt/retained" 8 | "go.uber.org/zap" 9 | ) 10 | 11 | type Options func(srv *server) 12 | 13 | // WithConfig set the config of the server 14 | func WithConfig(config config.Config) Options { 15 | return func(srv *server) { 16 | srv.config = config 17 | } 18 | } 19 | 20 | // WithTCPListener set tcp listener(s) of the server. Default listen on :1883. 21 | func WithTCPListener(lns ...net.Listener) Options { 22 | return func(srv *server) { 23 | srv.tcpListener = append(srv.tcpListener, lns...) 24 | } 25 | } 26 | 27 | // WithWebsocketServer set websocket server(s) of the server. 28 | func WithWebsocketServer(ws ...*WsServer) Options { 29 | return func(srv *server) { 30 | srv.websocketServer = ws 31 | } 32 | } 33 | 34 | // WithPlugin set plugin(s) of the server. 
func WithPlugin(plugin ...Plugin) Options {
	return func(srv *server) {
		srv.plugins = append(srv.plugins, plugin...)
	}
}

// WithHook sets the hooks of the server. Notice: WithPlugin() will overwrite hooks.
func WithHook(hooks Hooks) Options {
	return func(srv *server) {
		srv.hooks = hooks
	}
}

// WithLogger sets the logger. Note that it assigns the package-level zaplog
// rather than a field of srv.
func WithLogger(logger *zap.Logger) Options {
	return func(srv *server) {
		zaplog = logger
	}
}

// WithRetainedStore sets the retained message store of the server, replacing
// the server's retainedDB.
func WithRetainedStore(store retained.Store) Options {
	return func(srv *server) {
		srv.retainedDB = store
	}
}
--------------------------------------------------------------------------------
/server/persistence.go:
--------------------------------------------------------------------------------
package server

import (
	"github.com/DrmagicE/gmqtt/config"
	"github.com/DrmagicE/gmqtt/persistence/queue"
	"github.com/DrmagicE/gmqtt/persistence/session"
	"github.com/DrmagicE/gmqtt/persistence/subscription"
	"github.com/DrmagicE/gmqtt/persistence/unack"
)

// NewPersistence is the constructor of a Persistence backend.
type NewPersistence func(config config.Config) (Persistence, error)

// Persistence is a storage backend. It creates per-client queue and unack
// stores (both take a clientID) as well as subscription and session stores.
type Persistence interface {
	Open() error
	NewQueueStore(config config.Config, defaultNotifier queue.Notifier, clientID string) (queue.Store, error)
	NewSubscriptionStore(config config.Config) (subscription.Store, error)
	NewSessionStore(config config.Config) (session.Store, error)
	NewUnackStore(config config.Config, clientID string) (unack.Store, error)
	Close() error
}
--------------------------------------------------------------------------------
/server/plugin.go:
--------------------------------------------------------------------------------
package server

import (
	"github.com/DrmagicE/gmqtt/config"
)

// HookWrapper groups all hook wrapper functions a plugin can provide.
type HookWrapper struct {
	OnBasicAuthWrapper         OnBasicAuthWrapper
	OnEnhancedAuthWrapper      OnEnhancedAuthWrapper
	OnConnectedWrapper         OnConnectedWrapper
	OnReAuthWrapper            OnReAuthWrapper
	OnSessionCreatedWrapper    OnSessionCreatedWrapper
	OnSessionResumedWrapper    OnSessionResumedWrapper
	OnSessionTerminatedWrapper OnSessionTerminatedWrapper
	OnSubscribeWrapper         OnSubscribeWrapper
	OnSubscribedWrapper        OnSubscribedWrapper
	OnUnsubscribeWrapper       OnUnsubscribeWrapper
	OnUnsubscribedWrapper      OnUnsubscribedWrapper
	OnMsgArrivedWrapper        OnMsgArrivedWrapper
	OnMsgDroppedWrapper        OnMsgDroppedWrapper
	OnDeliveredWrapper         OnDeliveredWrapper
	OnClosedWrapper            OnClosedWrapper
	OnAcceptWrapper            OnAcceptWrapper
	OnStopWrapper              OnStopWrapper
	OnWillPublishWrapper       OnWillPublishWrapper
	OnWillPublishedWrapper     OnWillPublishedWrapper
}

// NewPlugin is the constructor of a plugin.
type NewPlugin func(config config.Config) (Plugin, error)

// Plugin is the interface that every plugin must implement.
type Plugin interface {
	// Load will be called in server.Run(). If it returns an error, the server will panic.
	Load(service Server) error
	// Unload will be called when the server is shut down; the returned error is only logged.
	Unload() error
	// HookWrapper returns all hook wrappers used by the plugin.
	// Return an empty wrapper if the plugin does not need any hooks.
	HookWrapper() HookWrapper
	// Name returns the plugin name.
	Name() string
}
--------------------------------------------------------------------------------
/server/plugin_mock.go:
--------------------------------------------------------------------------------
// Code generated by MockGen. DO NOT EDIT.
// Source: server/plugin.go

// Package server is a generated GoMock package.
5 | package server 6 | 7 | import ( 8 | gomock "github.com/golang/mock/gomock" 9 | reflect "reflect" 10 | ) 11 | 12 | // MockPlugin is a mock of Plugin interface 13 | type MockPlugin struct { 14 | ctrl *gomock.Controller 15 | recorder *MockPluginMockRecorder 16 | } 17 | 18 | // MockPluginMockRecorder is the mock recorder for MockPlugin 19 | type MockPluginMockRecorder struct { 20 | mock *MockPlugin 21 | } 22 | 23 | // NewMockPlugin creates a new mock instance 24 | func NewMockPlugin(ctrl *gomock.Controller) *MockPlugin { 25 | mock := &MockPlugin{ctrl: ctrl} 26 | mock.recorder = &MockPluginMockRecorder{mock} 27 | return mock 28 | } 29 | 30 | // EXPECT returns an object that allows the caller to indicate expected use 31 | func (m *MockPlugin) EXPECT() *MockPluginMockRecorder { 32 | return m.recorder 33 | } 34 | 35 | // Load mocks base method 36 | func (m *MockPlugin) Load(service Server) error { 37 | m.ctrl.T.Helper() 38 | ret := m.ctrl.Call(m, "Load", service) 39 | ret0, _ := ret[0].(error) 40 | return ret0 41 | } 42 | 43 | // Load indicates an expected call of Load 44 | func (mr *MockPluginMockRecorder) Load(service interface{}) *gomock.Call { 45 | mr.mock.ctrl.T.Helper() 46 | return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Load", reflect.TypeOf((*MockPlugin)(nil).Load), service) 47 | } 48 | 49 | // Unload mocks base method 50 | func (m *MockPlugin) Unload() error { 51 | m.ctrl.T.Helper() 52 | ret := m.ctrl.Call(m, "Unload") 53 | ret0, _ := ret[0].(error) 54 | return ret0 55 | } 56 | 57 | // Unload indicates an expected call of Unload 58 | func (mr *MockPluginMockRecorder) Unload() *gomock.Call { 59 | mr.mock.ctrl.T.Helper() 60 | return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Unload", reflect.TypeOf((*MockPlugin)(nil).Unload)) 61 | } 62 | 63 | // HookWrapper mocks base method 64 | func (m *MockPlugin) HookWrapper() HookWrapper { 65 | m.ctrl.T.Helper() 66 | ret := m.ctrl.Call(m, "HookWrapper") 67 | ret0, _ := ret[0].(HookWrapper) 68 | return ret0 69 | 
} 70 | 71 | // HookWrapper indicates an expected call of HookWrapper 72 | func (mr *MockPluginMockRecorder) HookWrapper() *gomock.Call { 73 | mr.mock.ctrl.T.Helper() 74 | return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HookWrapper", reflect.TypeOf((*MockPlugin)(nil).HookWrapper)) 75 | } 76 | 77 | // Name mocks base method 78 | func (m *MockPlugin) Name() string { 79 | m.ctrl.T.Helper() 80 | ret := m.ctrl.Call(m, "Name") 81 | ret0, _ := ret[0].(string) 82 | return ret0 83 | } 84 | 85 | // Name indicates an expected call of Name 86 | func (mr *MockPluginMockRecorder) Name() *gomock.Call { 87 | mr.mock.ctrl.T.Helper() 88 | return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Name", reflect.TypeOf((*MockPlugin)(nil).Name)) 89 | } 90 | -------------------------------------------------------------------------------- /server/publish_service.go: -------------------------------------------------------------------------------- 1 | package server 2 | 3 | import "github.com/DrmagicE/gmqtt" 4 | 5 | type publishService struct { 6 | server *server 7 | } 8 | 9 | func (p *publishService) Publish(message *gmqtt.Message) { 10 | p.server.mu.Lock() 11 | p.server.deliverMessage("", message, defaultIterateOptions(message.Topic)) 12 | p.server.mu.Unlock() 13 | } 14 | -------------------------------------------------------------------------------- /server/queue_notifier.go: -------------------------------------------------------------------------------- 1 | package server 2 | 3 | import ( 4 | "context" 5 | 6 | "go.uber.org/zap" 7 | 8 | "github.com/DrmagicE/gmqtt" 9 | "github.com/DrmagicE/gmqtt/persistence/queue" 10 | ) 11 | 12 | // queueNotifier implements queue.Notifier interface. 13 | type queueNotifier struct { 14 | dropHook OnMsgDropped 15 | sts *statsManager 16 | cli *client 17 | } 18 | 19 | // defaultNotifier is used to init the notifier when using a persistent session store (e.g redis) which can load session data 20 | // while bootstrapping. 
21 | func defaultNotifier(dropHook OnMsgDropped, sts *statsManager, clientID string) *queueNotifier { 22 | return &queueNotifier{ 23 | dropHook: dropHook, 24 | sts: sts, 25 | cli: &client{opts: &ClientOptions{ClientID: clientID}, status: Connected + 1}, 26 | } 27 | } 28 | 29 | func (q *queueNotifier) notifyDropped(msg *gmqtt.Message, err error) { 30 | cid := q.cli.opts.ClientID 31 | zaplog.Warn("message dropped", zap.String("client_id", cid), zap.Error(err)) 32 | q.sts.messageDropped(msg.QoS, q.cli.opts.ClientID, err) 33 | if q.dropHook != nil { 34 | q.dropHook(context.Background(), cid, msg, err) 35 | } 36 | } 37 | 38 | func (q *queueNotifier) NotifyDropped(elem *queue.Elem, err error) { 39 | cid := q.cli.opts.ClientID 40 | if err == queue.ErrDropExpiredInflight && q.cli.IsConnected() { 41 | q.cli.pl.release(elem.ID()) 42 | } 43 | if pub, ok := elem.MessageWithID.(*queue.Publish); ok { 44 | q.notifyDropped(pub.Message, err) 45 | } else { 46 | zaplog.Warn("message dropped", zap.String("client_id", cid), zap.Error(err)) 47 | } 48 | } 49 | 50 | func (q *queueNotifier) NotifyInflightAdded(delta int) { 51 | cid := q.cli.opts.ClientID 52 | if delta > 0 { 53 | q.sts.addInflight(cid, uint64(delta)) 54 | } 55 | if delta < 0 { 56 | q.sts.decInflight(cid, uint64(-delta)) 57 | } 58 | 59 | } 60 | 61 | func (q *queueNotifier) NotifyMsgQueueAdded(delta int) { 62 | cid := q.cli.opts.ClientID 63 | if delta > 0 { 64 | q.sts.addQueueLen(cid, uint64(delta)) 65 | } 66 | if delta < 0 { 67 | q.sts.decQueueLen(cid, uint64(-delta)) 68 | } 69 | } 70 | -------------------------------------------------------------------------------- /server/service.go: -------------------------------------------------------------------------------- 1 | package server 2 | 3 | import ( 4 | "github.com/DrmagicE/gmqtt" 5 | "github.com/DrmagicE/gmqtt/persistence/session" 6 | "github.com/DrmagicE/gmqtt/persistence/subscription" 7 | "github.com/DrmagicE/gmqtt/retained" 8 | ) 9 | 10 | // Publisher provides the 
ability to Publish messages to the broker. 11 | type Publisher interface { 12 | // Publish Publish a message to broker. 13 | // Calling this method will not trigger OnMsgArrived hook. 14 | Publish(message *gmqtt.Message) 15 | } 16 | 17 | // ClientIterateFn is the callback function used by ClientService.IterateClient 18 | // Return false means to stop the iteration. 19 | type ClientIterateFn = func(client Client) bool 20 | 21 | // ClientService provides the ability to query and close clients. 22 | type ClientService interface { 23 | IterateSession(fn session.IterateFn) error 24 | GetSession(clientID string) (*gmqtt.Session, error) 25 | GetClient(clientID string) Client 26 | IterateClient(fn ClientIterateFn) 27 | TerminateSession(clientID string) 28 | } 29 | 30 | // SubscriptionService providers the ability to query and add/delete subscriptions. 31 | type SubscriptionService interface { 32 | // Subscribe adds subscriptions to a specific client. 33 | // Notice: 34 | // This method will succeed even if the client is not exists, the subscriptions 35 | // will affect the new client with the client id. 36 | Subscribe(clientID string, subscriptions ...*gmqtt.Subscription) (rs subscription.SubscribeResult, err error) 37 | // Unsubscribe removes subscriptions of a specific client. 38 | Unsubscribe(clientID string, topics ...string) error 39 | // UnsubscribeAll removes all subscriptions of a specific client. 40 | UnsubscribeAll(clientID string) error 41 | // Iterate iterates all subscriptions. The callback is called once for each subscription. 42 | // If callback return false, the iteration will be stopped. 43 | // Notice: 44 | // The results are not sorted in any way, no ordering of any kind is guaranteed. 45 | // This method will walk through all subscriptions, 46 | // so it is a very expensive operation. Do not call it frequently. 
47 | Iterate(fn subscription.IterateFn, options subscription.IterationOptions) 48 | subscription.StatsReader 49 | } 50 | 51 | // RetainedService providers the ability to query and add/delete retained messages. 52 | type RetainedService interface { 53 | retained.Store 54 | } 55 | -------------------------------------------------------------------------------- /server/stats_mock.go: -------------------------------------------------------------------------------- 1 | // Code generated by MockGen. DO NOT EDIT. 2 | // Source: server/stats.go 3 | 4 | // Package server is a generated GoMock package. 5 | package server 6 | 7 | import ( 8 | gomock "github.com/golang/mock/gomock" 9 | reflect "reflect" 10 | ) 11 | 12 | // MockStatsReader is a mock of StatsReader interface 13 | type MockStatsReader struct { 14 | ctrl *gomock.Controller 15 | recorder *MockStatsReaderMockRecorder 16 | } 17 | 18 | // MockStatsReaderMockRecorder is the mock recorder for MockStatsReader 19 | type MockStatsReaderMockRecorder struct { 20 | mock *MockStatsReader 21 | } 22 | 23 | // NewMockStatsReader creates a new mock instance 24 | func NewMockStatsReader(ctrl *gomock.Controller) *MockStatsReader { 25 | mock := &MockStatsReader{ctrl: ctrl} 26 | mock.recorder = &MockStatsReaderMockRecorder{mock} 27 | return mock 28 | } 29 | 30 | // EXPECT returns an object that allows the caller to indicate expected use 31 | func (m *MockStatsReader) EXPECT() *MockStatsReaderMockRecorder { 32 | return m.recorder 33 | } 34 | 35 | // GetGlobalStats mocks base method 36 | func (m *MockStatsReader) GetGlobalStats() GlobalStats { 37 | m.ctrl.T.Helper() 38 | ret := m.ctrl.Call(m, "GetGlobalStats") 39 | ret0, _ := ret[0].(GlobalStats) 40 | return ret0 41 | } 42 | 43 | // GetGlobalStats indicates an expected call of GetGlobalStats 44 | func (mr *MockStatsReaderMockRecorder) GetGlobalStats() *gomock.Call { 45 | mr.mock.ctrl.T.Helper() 46 | return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetGlobalStats", 
reflect.TypeOf((*MockStatsReader)(nil).GetGlobalStats)) 47 | } 48 | 49 | // GetClientStats mocks base method 50 | func (m *MockStatsReader) GetClientStats(clientID string) (ClientStats, bool) { 51 | m.ctrl.T.Helper() 52 | ret := m.ctrl.Call(m, "GetClientStats", clientID) 53 | ret0, _ := ret[0].(ClientStats) 54 | ret1, _ := ret[1].(bool) 55 | return ret0, ret1 56 | } 57 | 58 | // GetClientStats indicates an expected call of GetClientStats 59 | func (mr *MockStatsReaderMockRecorder) GetClientStats(clientID interface{}) *gomock.Call { 60 | mr.mock.ctrl.T.Helper() 61 | return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClientStats", reflect.TypeOf((*MockStatsReader)(nil).GetClientStats), clientID) 62 | } 63 | -------------------------------------------------------------------------------- /server/testdata/ca.pem: -------------------------------------------------------------------------------- 1 | -----BEGIN CERTIFICATE----- 2 | MIIC0zCCAbsCFHNchtacwmLUKBWHbVAe0M+2H+tEMA0GCSqGSIb3DQEBCwUAMCUx 3 | CzAJBgNVBAYTAkNOMRYwFAYDVQQDDA1kcm1hZ2ljLmxvY2FsMCAXDTIxMDEyNDEz 4 | NDUxOVoYDzIxMjAxMjMxMTM0NTE5WjAlMQswCQYDVQQGEwJDTjEWMBQGA1UEAwwN 5 | ZHJtYWdpYy5sb2NhbDCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBAL3v 6 | sYOylxpCCNWLyLOjL+smnZgFsbt7PL9wxOJOgTFVesVV/mRnlydn9Ism9ERCIBHF 7 | yfsX6lnOKkqGisoTt5DuBphwJeZjSJYOjTIgQdcVbLyvspPN1+no2qAO/jv1Fsg6 8 | WXmq6lkEc1LPE+fVlQG9pl6ypBdrrCzGKFtfEI+B3nuDIlLhzt2avZ4RmaFZjQJW 9 | WtoWHN56ujoZLzUVv+tjc/wgmMTCA36TYS5jBWXOPfqvg0hYRBysqBZACu2jZ/R0 10 | qeCvLwJemUlECpiNbEn2w9ApEKlyM58ArXlVLixkYZxEVa+Ai6q7aYRpNvg3gKZw 11 | R/9zWPn/0t8u4Z7GwykCAwEAATANBgkqhkiG9w0BAQsFAAOCAQEAESviZdpHauCg 12 | 2ir8kn314rMK9QrK/nt60z+Cd5FkaFKiHQUuD+obXzri5R2qzHNZLJdOmpzaI+1e 13 | tGHJ1jh0J1ShMDDr9qA/CknBM3r/dzDHneNb8B0xFxOABI5vcywG/xM8Dv/dBIuF 14 | PuWvjvw7EJI3i6Vy2tR885ksDB/ucoNSpWevdXDJdoUxA88vNgt1nMzMy4+IGBYf 15 | TwqK++T3V5DCGQj+24eYgiShHAIchbgoB8F+Suvseo8kd9kFsoORToqmoTM1J4JM 16 | K4u8Xvh5sRnbo7ViwcKPAD3fI14z0mqFCObllp6ynib6WAAGztW2F4khyY4WTHH9 17 | 
MNE9/UMb6Q== 18 | -----END CERTIFICATE----- 19 | -------------------------------------------------------------------------------- /server/testdata/extfile.cnf: -------------------------------------------------------------------------------- 1 | subjectAltName = DNS:drmagic.local,IP:127.0.0.1 2 | extendedKeyUsage = serverAuth 3 | -------------------------------------------------------------------------------- /server/testdata/openssl.conf: -------------------------------------------------------------------------------- 1 | [req] 2 | distinguished_name = req_distinguished_name 3 | prompt = no 4 | 5 | [req_distinguished_name] 6 | C = CN 7 | CN = drmagic.local -------------------------------------------------------------------------------- /server/testdata/server-cert.pem: -------------------------------------------------------------------------------- 1 | -----BEGIN CERTIFICATE----- 2 | MIIDBDCCAeygAwIBAgIUFxnmuzcIANni8M9czNSrB3dPRSIwDQYJKoZIhvcNAQEL 3 | BQAwJTELMAkGA1UEBhMCQ04xFjAUBgNVBAMMDWRybWFnaWMubG9jYWwwIBcNMjEw 4 | MTI0MTM0NTE5WhgPMjEyMDEyMzExMzQ1MTlaMBgxFjAUBgNVBAMMDWRybWFnaWMu 5 | bG9jYWwwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQDW37XX6xuVLB4e 6 | KEJfCJ53bb4hJYE4bLidC0a47pUoLQ+eQF0Dp6vF+d79d35Vq9WJJ2AceZO8s1Zf 7 | 41aGWlioFk8M6TjFvIxEQMdeDkmsnsCGjNBfWHgESv3OZm2pIF2Hww9aNIz2f24K 8 | QBcRYc7wYKYYyfBscMa+aPg7PkaK9FUU/eV+QJHLMElYEwa/vicJ1qeNjlyBczvY 9 | ckDn2TQdk6eOkwSn6ZlsjGlK0B5cb3dL3UCUBeZjLvEOiB4oXXiBuHCnjhrjoJxL 10 | SbZ9MPxQgN/Np34PBjR7dXxIKnLJzk8E+fD7o+yNcUo9Y4JPiJXom6isiwWSIVSa 11 | fP23DGxhAgMBAAGjNzA1MB4GA1UdEQQXMBWCDWRybWFnaWMubG9jYWyHBH8AAAEw 12 | EwYDVR0lBAwwCgYIKwYBBQUHAwEwDQYJKoZIhvcNAQELBQADggEBALM8lVW82KRr 13 | XVh829urMs6emCjeYhdqHFk8QyX48IprOTBFmTrVFNfD8zcX6NlhsPxPFjsWy5ND 14 | E7T0qROQ0x4/9oe8Hr6+wm1qXfSD02aBop+67WBBUFI2bGm44ZSCeEL/1GACaZry 15 | h+knJAQQp5+mHszQDz2XaqzUOE6tfa9guRUHo9GVO9oIJdP/DjaT9XpsNdHczZdD 16 | 1H4Yweit61JaiizA1nMJ1LT0mq8P780InbTgj1r/WgfVhlO1CZ6L3IxGLEUxHg3y 17 | 
TFRG6Z82rrxi1DA20NZPeB3nDTS7IeIEDpYIn88olTLStiKN6du805nsZ/5+clNs 18 | Fn4N5RMn2vc= 19 | -----END CERTIFICATE----- 20 | -------------------------------------------------------------------------------- /server/testdata/server-key.pem: -------------------------------------------------------------------------------- 1 | -----BEGIN RSA PRIVATE KEY----- 2 | MIIEowIBAAKCAQEA1t+11+sblSweHihCXwied22+ISWBOGy4nQtGuO6VKC0PnkBd 3 | A6erxfne/Xd+VavViSdgHHmTvLNWX+NWhlpYqBZPDOk4xbyMREDHXg5JrJ7AhozQ 4 | X1h4BEr9zmZtqSBdh8MPWjSM9n9uCkAXEWHO8GCmGMnwbHDGvmj4Oz5GivRVFP3l 5 | fkCRyzBJWBMGv74nCdanjY5cgXM72HJA59k0HZOnjpMEp+mZbIxpStAeXG93S91A 6 | lAXmYy7xDogeKF14gbhwp44a46CcS0m2fTD8UIDfzad+DwY0e3V8SCpyyc5PBPnw 7 | +6PsjXFKPWOCT4iV6JuorIsFkiFUmnz9twxsYQIDAQABAoIBAFsPx8LPsorPfZwO 8 | N8KKpo26hn8Jo+/Ds6Fqa/hns/Ko1huc705jOprWQDhu8a1g+0f61fJ7W6722b4d 9 | XEfn9faWLb4tAJBcTZ2HTnZ/2506UiEzgANIPOSk21cjdYndW4XzlogGCU9Vxc62 10 | RpBpQQgCDaInwqpSSQfc+IYy6DZuekPCbm3hBhEF9grY9j2/QrVHKjbx9rGA6jD3 11 | 31FjL1SqGrkgduef2EW9geoduVSGnEyYU0CoVZb/es51c/5rAzv91eC21egu0OPq 12 | XPFtDM5Gz/4iC9wQ7k2EDF5LiKR49DKmAJM4FSRZKCqDYY7NtFHILXC9S5rMeuMQ 13 | 1mVnknkCgYEA/ntEJyIc8h/QzKuLi5ydMZzIXmF6aKW1Gz0ZwMKKDhWMiNLyp+Kc 14 | N2RXdTlmyZNkzcnm0/SmUyaoC5o9NJIVg7HfdyYhgsN/MBHJ9Q+ISHbAogcVugk0 15 | 3CRZ9c3kyVkqiLJeY194/rpI/S7m+/VkeyNqcwBedb2/CDrJP0q0Rk8CgYEA2Cfw 16 | /XVYG3TqziDqWonZtINc275yP5ecw4N9qXIjuoqH6L7N/MzC52QfkpJFoO+bcGxe 17 | umg4mjFA67RCCpFLY2jhh4nS4bAwh3bM+EXgnL1rrgPAz4ZasnsbXEcivP+7a4SZ 18 | pKRT/20CUjimeykhZxzAZcvuENCkGA+WmpWVJk8CgYEA+hIFkfMKwL+U/lsgoMwB 19 | CKzJlT1y/XzA8IhlUy+YXGi+lgG9ZE7iNeiLrO0AXdtSdosOIoDKJPHatrQVqyBW 20 | tfhH4Rz+VzJnPMRuUju2L4dKmq4dopfDcwThxhNS3K2bh4LIEBzUmHRUnz/EyhmF 21 | aSAPTf0x1b/lBmBGPMTbTC8CgYAJmDRFO9kuVtE5VxKv9CB6t73+bwSpN/SYZRTF 22 | 2bAmTpHbzeRczUX1eWdBXUbD7v7KTbUitw+UII2OKNEpoOtkvToNhxuaMvTkfmx4 23 | tLlUm7/U2IvNalxKQdakEPBEzWEnU5pySW0FEHSi66rQGrJF3mvX2OZ3TpuKCd8Y 24 | e31EVwKBgDWP9poeeli+0+EqbsZtP7lGN63yp4XYMYim9zKojmAu6tuAUkUCbIzn 25 | 
5Jdeoe6I7LgTpnahC+c+a8/4JuPKrYeWV/Tf9R+zatGyYzZ6W4qzlpYJaE6zFAJq 26 | b+mx/By39E+WH3bRQBhcsp1hmxclFhd4KWLdqm5+Zsycu7JvUSjY 27 | -----END RSA PRIVATE KEY----- 28 | -------------------------------------------------------------------------------- /server/testdata/test_gen.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # ca key 3 | openssl genrsa -out ca-key.pem 2048 4 | # ca certificate 5 | openssl req -new -x509 -days 36500 -key ca-key.pem -out ca.pem -config openssl.conf 6 | # server key 7 | openssl genrsa -out server-key.pem 2048 8 | # server csr 9 | openssl req -subj "/CN=drmagic.local" -new -key server-key.pem -out server.csr 10 | # sign the public key with our CA 11 | openssl x509 -req -days 36500 -in server.csr -CA ca.pem -CAkey ca-key.pem \ 12 | -CAcreateserial -out server-cert.pem -extfile extfile.cnf 13 | 14 | rm ./ca.srl ./ca-key.pem ./server.csr -------------------------------------------------------------------------------- /server/topic_alias.go: -------------------------------------------------------------------------------- 1 | package server 2 | 3 | import ( 4 | "github.com/DrmagicE/gmqtt/config" 5 | "github.com/DrmagicE/gmqtt/pkg/packets" 6 | ) 7 | 8 | type NewTopicAliasManager func(config config.Config, maxAlias uint16, clientID string) TopicAliasManager 9 | 10 | // TopicAliasManager manage the topic alias for a V5 client. 11 | // see topicalias/fifo for more details. 12 | type TopicAliasManager interface { 13 | // Check return the alias number and whether the alias exist. 14 | // For examples: 15 | // If the Publish alias exist and the manager decides to use the alias, it return the alias number and true. 16 | // If the Publish alias exist, but the manager decides not to use alias, it return 0 and true. 17 | // If the Publish alias not exist and the manager decides to assign a new alias, it return the new alias and false. 
18 | // If the Publish alias not exist, but the manager decides not to assign alias, it return the 0 and false. 19 | Check(publish *packets.Publish) (alias uint16, exist bool) 20 | } 21 | -------------------------------------------------------------------------------- /server/topic_alias_mock.go: -------------------------------------------------------------------------------- 1 | // Code generated by MockGen. DO NOT EDIT. 2 | // Source: server/topic_alias.go 3 | 4 | // Package server is a generated GoMock package. 5 | package server 6 | 7 | import ( 8 | packets "github.com/DrmagicE/gmqtt/pkg/packets" 9 | gomock "github.com/golang/mock/gomock" 10 | reflect "reflect" 11 | ) 12 | 13 | // MockTopicAliasManager is a mock of TopicAliasManager interface 14 | type MockTopicAliasManager struct { 15 | ctrl *gomock.Controller 16 | recorder *MockTopicAliasManagerMockRecorder 17 | } 18 | 19 | // MockTopicAliasManagerMockRecorder is the mock recorder for MockTopicAliasManager 20 | type MockTopicAliasManagerMockRecorder struct { 21 | mock *MockTopicAliasManager 22 | } 23 | 24 | // NewMockTopicAliasManager creates a new mock instance 25 | func NewMockTopicAliasManager(ctrl *gomock.Controller) *MockTopicAliasManager { 26 | mock := &MockTopicAliasManager{ctrl: ctrl} 27 | mock.recorder = &MockTopicAliasManagerMockRecorder{mock} 28 | return mock 29 | } 30 | 31 | // EXPECT returns an object that allows the caller to indicate expected use 32 | func (m *MockTopicAliasManager) EXPECT() *MockTopicAliasManagerMockRecorder { 33 | return m.recorder 34 | } 35 | 36 | // Check mocks base method 37 | func (m *MockTopicAliasManager) Check(publish *packets.Publish) (uint16, bool) { 38 | m.ctrl.T.Helper() 39 | ret := m.ctrl.Call(m, "Check", publish) 40 | ret0, _ := ret[0].(uint16) 41 | ret1, _ := ret[1].(bool) 42 | return ret0, ret1 43 | } 44 | 45 | // Check indicates an expected call of Check 46 | func (mr *MockTopicAliasManagerMockRecorder) Check(publish interface{}) *gomock.Call { 47 | 
mr.mock.ctrl.T.Helper() 48 | return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Check", reflect.TypeOf((*MockTopicAliasManager)(nil).Check), publish) 49 | } 50 | -------------------------------------------------------------------------------- /session.go: -------------------------------------------------------------------------------- 1 | package gmqtt 2 | 3 | import ( 4 | "time" 5 | ) 6 | 7 | // Session represents a MQTT session. 8 | type Session struct { 9 | // ClientID represents the client id. 10 | ClientID string 11 | // Will is the will message of the client, can be nil if there is no will message. 12 | Will *Message 13 | // WillDelayInterval represents the Will Delay Interval in seconds 14 | WillDelayInterval uint32 15 | // ConnectedAt is the session create time. 16 | ConnectedAt time.Time 17 | // ExpiryInterval represents the Session Expiry Interval in seconds 18 | ExpiryInterval uint32 19 | } 20 | 21 | // IsExpired return whether the session is expired 22 | func (s *Session) IsExpired(now time.Time) bool { 23 | return s.ConnectedAt.Add(time.Duration(s.ExpiryInterval) * time.Second).Before(now) 24 | } 25 | -------------------------------------------------------------------------------- /subscription.go: -------------------------------------------------------------------------------- 1 | package gmqtt 2 | 3 | import ( 4 | "errors" 5 | 6 | "github.com/DrmagicE/gmqtt/pkg/packets" 7 | ) 8 | 9 | // Subscription represents a subscription in gmqtt. 10 | type Subscription struct { 11 | // ShareName is the share name of a shared subscription. 12 | // set to "" if it is a non-shared subscription. 13 | ShareName string 14 | // TopicFilter is the topic filter which does not include the share name. 15 | TopicFilter string 16 | // ID is the subscription identifier 17 | ID uint32 18 | // The following fields are Subscription Options. 
19 | // See: https://docs.oasis-open.org/mqtt/mqtt/v5.0/os/mqtt-v5.0-os.html#_Toc3901169 20 | 21 | // QoS is the qos level of the Subscription. 22 | QoS packets.QoS 23 | // NoLocal is the No Local option. 24 | NoLocal bool 25 | // RetainAsPublished is the Retain As Published option. 26 | RetainAsPublished bool 27 | // RetainHandling the Retain Handling option. 28 | RetainHandling byte 29 | } 30 | 31 | // GetFullTopicName returns the full topic name of the subscription. 32 | func (s *Subscription) GetFullTopicName() string { 33 | if s.ShareName != "" { 34 | return "$share/" + s.ShareName + "/" + s.TopicFilter 35 | } 36 | return s.TopicFilter 37 | } 38 | 39 | // Copy makes a copy of subscription. 40 | func (s *Subscription) Copy() *Subscription { 41 | return &Subscription{ 42 | ShareName: s.ShareName, 43 | TopicFilter: s.TopicFilter, 44 | ID: s.ID, 45 | QoS: s.QoS, 46 | NoLocal: s.NoLocal, 47 | RetainAsPublished: s.RetainAsPublished, 48 | RetainHandling: s.RetainHandling, 49 | } 50 | } 51 | 52 | // Validate returns whether the subscription is valid. 53 | // If you can ensure the subscription is valid then just skip the validation. 
54 | func (s *Subscription) Validate() error { 55 | if !packets.ValidV5Topic([]byte(s.GetFullTopicName())) { 56 | return errors.New("invalid topic name") 57 | } 58 | if s.QoS > 2 { 59 | return errors.New("invalid qos") 60 | } 61 | if s.RetainHandling != 0 && s.RetainHandling != 1 && s.RetainHandling != 2 { 62 | return errors.New("invalid retain handling") 63 | } 64 | return nil 65 | } 66 | -------------------------------------------------------------------------------- /topicalias/fifo/fifo.go: -------------------------------------------------------------------------------- 1 | package fifo 2 | 3 | import ( 4 | "container/list" 5 | 6 | "github.com/DrmagicE/gmqtt/config" 7 | "github.com/DrmagicE/gmqtt/pkg/packets" 8 | "github.com/DrmagicE/gmqtt/server" 9 | ) 10 | 11 | var _ server.TopicAliasManager = (*Queue)(nil) 12 | 13 | func init() { 14 | server.RegisterTopicAliasMgrFactory("fifo", New) 15 | } 16 | 17 | // New is the constructor of Queue. 18 | func New(config config.Config, maxAlias uint16, clientID string) server.TopicAliasManager { 19 | return &Queue{ 20 | clientID: clientID, 21 | topicAlias: &topicAlias{ 22 | max: int(maxAlias), 23 | alias: list.New(), 24 | index: make(map[string]uint16), 25 | }, 26 | } 27 | } 28 | 29 | // Queue is the fifo queue which store all topic alias for one client 30 | type Queue struct { 31 | clientID string 32 | topicAlias *topicAlias 33 | } 34 | type topicAlias struct { 35 | max int 36 | alias *list.List 37 | // topic name => alias 38 | index map[string]uint16 39 | } 40 | type aliasElem struct { 41 | topic string 42 | alias uint16 43 | } 44 | 45 | func (q *Queue) Check(publish *packets.Publish) (alias uint16, exist bool) { 46 | topicName := string(publish.TopicName) 47 | // alias exist 48 | if a, ok := q.topicAlias.index[topicName]; ok { 49 | return a, true 50 | } 51 | l := q.topicAlias.alias.Len() 52 | // alias has been exhausted 53 | if l == q.topicAlias.max { 54 | first := q.topicAlias.alias.Front() 55 | elem := 
first.Value.(*aliasElem) 56 | q.topicAlias.alias.Remove(first) 57 | delete(q.topicAlias.index, elem.topic) 58 | alias = elem.alias 59 | } else { 60 | alias = uint16(l + 1) 61 | } 62 | q.topicAlias.alias.PushBack(&aliasElem{ 63 | topic: topicName, 64 | alias: alias, 65 | }) 66 | q.topicAlias.index[topicName] = alias 67 | return 68 | } 69 | -------------------------------------------------------------------------------- /topicalias/fifo/fifo_test.go: -------------------------------------------------------------------------------- 1 | package fifo 2 | 3 | import ( 4 | "strconv" 5 | "testing" 6 | 7 | "github.com/golang/mock/gomock" 8 | "github.com/stretchr/testify/assert" 9 | 10 | "github.com/DrmagicE/gmqtt/config" 11 | "github.com/DrmagicE/gmqtt/pkg/packets" 12 | ) 13 | 14 | func TestQueue(t *testing.T) { 15 | a := assert.New(t) 16 | ctrl := gomock.NewController(t) 17 | defer ctrl.Finish() 18 | 19 | cid := "clientID" 20 | max := uint16(10) 21 | q := New(config.DefaultConfig(), max, cid).(*Queue) 22 | for i := uint16(1); i <= max; i++ { 23 | alias, ok := q.Check(&packets.Publish{ 24 | TopicName: []byte(strconv.Itoa(int(i))), 25 | }) 26 | a.Equal(i, alias) 27 | a.False(ok) 28 | } 29 | alias := uint16(1) 30 | for e := q.topicAlias.alias.Front(); e != nil; e = e.Next() { 31 | elem := e.Value.(*aliasElem) 32 | a.Equal(alias, elem.alias) 33 | a.Equal(strconv.Itoa(int(alias)), elem.topic) 34 | alias++ 35 | } 36 | a.Equal(10, q.topicAlias.alias.Len()) 37 | 38 | // alias exist 39 | alias, ok := q.Check(&packets.Publish{TopicName: []byte("1")}) 40 | a.True(ok) 41 | a.EqualValues(1, alias) 42 | 43 | alias, ok = q.Check(&packets.Publish{TopicName: []byte("not exist")}) 44 | a.False(ok) 45 | a.EqualValues(1, alias) 46 | 47 | } 48 | --------------------------------------------------------------------------------