├── .gitignore ├── .travis.yml ├── Gopkg.lock ├── Gopkg.toml ├── LICENSE ├── README.md ├── bin ├── check_fmt.sh ├── check_mocks.sh ├── cover.sh ├── coveralls.sh ├── install_tools.sh ├── regenerate_mocks.sh └── test.sh ├── capability ├── light.go ├── light_test.go ├── power.go └── power_test.go ├── cli ├── .gitignore ├── control.go ├── discovery.go └── main.go ├── client.go ├── client_test.go ├── common ├── device.go ├── event.go ├── logger.go └── types.go ├── device ├── basedevice.go ├── basedevice_test.go ├── classification.go ├── classification_test.go ├── device.go ├── mocks │ ├── Device.go │ └── RefreshThrottle.go ├── powerplug.go ├── powerplug_test.go ├── product │ ├── product.go │ └── product_test.go ├── rthrottle │ ├── mocks │ │ └── RefreshThrottle.go │ ├── refresh_throttle.go │ └── refresh_throttle_test.go ├── yeelight.go └── yeelight_test.go ├── go.mod ├── go.sum ├── protocol ├── mocks │ └── Protocol.go ├── packet │ ├── crypto.go │ ├── crypto_test.go │ ├── mocks │ │ └── Crypto.go │ ├── packet.go │ └── packet_test.go ├── protocol.go ├── protocol_test.go ├── tokens │ ├── token_store.go │ ├── token_store_test.go │ └── tokens.example.txt └── transport │ ├── inbound.go │ ├── mocks │ ├── Conn.go │ ├── Inbound.go │ ├── InboundConn.go │ ├── Outbound.go │ ├── OutboundConn.go │ └── Transport.go │ ├── outbound.go │ └── transport.go ├── simulator ├── .gitignore ├── README.md ├── capability │ ├── capability.go │ ├── info.go │ ├── light.go │ └── power.go ├── device │ ├── device.go │ ├── powerplug.go │ └── yeelight.go └── main.go ├── subscription ├── common │ ├── common.go │ └── mocks │ │ ├── Subscription.go │ │ └── SubscriptionTarget.go ├── subscription.go ├── subscription │ ├── subscription.go │ └── subscription_test.go └── target │ ├── subscription_target.go │ └── subscription_target_test.go ├── tokencapture └── main.go └── tools └── wireshark └── miio.lua /.gitignore: -------------------------------------------------------------------------------- 1 | *.DS_Store 2 | 
.idea 3 | coverage.html 4 | vendor 5 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: go 2 | go: 3 | - stable 4 | 5 | env: 6 | - PATH=$PATH:$GOPATH/bin GO111MODULE=on 7 | 8 | before_install: 9 | - ./bin/install_tools.sh 10 | 11 | install: 12 | - go mod download 13 | 14 | script: 15 | - ./bin/check_fmt.sh 16 | - ./bin/check_mocks.sh 17 | - ./bin/coveralls.sh 18 | -------------------------------------------------------------------------------- /Gopkg.lock: -------------------------------------------------------------------------------- 1 | # This file is autogenerated, do not edit; changes may be undone by the next 'dep ensure'. 2 | 3 | 4 | [[projects]] 5 | digest = "1:15d017551627c8bb091bde628215b2861bed128855343fdd570c62d08871f6e1" 6 | name = "github.com/alecthomas/kingpin" 7 | packages = ["."] 8 | pruneopts = "" 9 | revision = "947dcec5ba9c011838740e680966fd7087a71d0d" 10 | version = "v2.2.6" 11 | 12 | [[projects]] 13 | branch = "master" 14 | digest = "1:a74730e052a45a3fab1d310fdef2ec17ae3d6af16228421e238320846f2aaec8" 15 | name = "github.com/alecthomas/template" 16 | packages = [ 17 | ".", 18 | "parse", 19 | ] 20 | pruneopts = "" 21 | revision = "a0175ee3bccc567396460bf5acd36800cb10c49c" 22 | 23 | [[projects]] 24 | branch = "master" 25 | digest = "1:8483994d21404c8a1d489f6be756e25bfccd3b45d65821f25695577791a08e68" 26 | name = "github.com/alecthomas/units" 27 | packages = ["."] 28 | pruneopts = "" 29 | revision = "2efee857e7cfd4f3d0138cc3cbb1b4966962b93a" 30 | 31 | [[projects]] 32 | branch = "master" 33 | digest = "1:afaa6de27e2d86b66cf71d55096f00e32b2ef40ec3349b535555aa81c77bc7d3" 34 | name = "github.com/benbjohnson/clock" 35 | packages = ["."] 36 | pruneopts = "" 37 | revision = "7dc76406b6d3c05b5f71a86293cbcf3c4ea03b19" 38 | 39 | [[projects]] 40 | digest = 
"1:56c130d885a4aacae1dd9c7b71cfe39912c7ebc1ff7d2b46083c8812996dc43b" 41 | name = "github.com/davecgh/go-spew" 42 | packages = ["spew"] 43 | pruneopts = "" 44 | revision = "346938d642f2ec3594ed81d874461961cd0faa76" 45 | version = "v1.1.0" 46 | 47 | [[projects]] 48 | branch = "master" 49 | digest = "1:f880d3b7dfb23226b5601f4e904c60b15cd3c890708957409bbfc67f2720f8f4" 50 | name = "github.com/lunixbochs/struc" 51 | packages = ["."] 52 | pruneopts = "" 53 | revision = "02e4c2afbb2ac4bae6876f52c8273fc4cf5a4b0a" 54 | 55 | [[projects]] 56 | digest = "1:256484dbbcd271f9ecebc6795b2df8cad4c458dd0f5fd82a8c2fa0c29f233411" 57 | name = "github.com/pmezard/go-difflib" 58 | packages = ["difflib"] 59 | pruneopts = "" 60 | revision = "792786c7400a136282c1664665ae0a8db921c6c2" 61 | version = "v1.0.0" 62 | 63 | [[projects]] 64 | digest = "1:7f569d906bdd20d906b606415b7d794f798f91a62fcfb6a4daa6d50690fb7a3f" 65 | name = "github.com/satori/go.uuid" 66 | packages = ["."] 67 | pruneopts = "" 68 | revision = "f58768cc1a7a7e77a3bd49e98cdd21419399b6a3" 69 | version = "v1.2.0" 70 | 71 | [[projects]] 72 | digest = "1:3fcbf733a8d810a21265a7f2fe08a3353db2407da052b233f8b204b5afc03d9b" 73 | name = "github.com/sirupsen/logrus" 74 | packages = ["."] 75 | pruneopts = "" 76 | revision = "3e01752db0189b9157070a0e1668a620f9a85da2" 77 | version = "v1.0.6" 78 | 79 | [[projects]] 80 | digest = "1:711eebe744c0151a9d09af2315f0bb729b2ec7637ef4c410fa90a18ef74b65b6" 81 | name = "github.com/stretchr/objx" 82 | packages = ["."] 83 | pruneopts = "" 84 | revision = "477a77ecc69700c7cdeb1fa9e129548e1c1c393c" 85 | version = "v0.1.1" 86 | 87 | [[projects]] 88 | digest = "1:c587772fb8ad29ad4db67575dad25ba17a51f072ff18a22b4f0257a4d9c24f75" 89 | name = "github.com/stretchr/testify" 90 | packages = [ 91 | "assert", 92 | "mock", 93 | ] 94 | pruneopts = "" 95 | revision = "f35b8ab0b5a2cef36673838d662e249dd9c94686" 96 | version = "v1.2.2" 97 | 98 | [[projects]] 99 | branch = "master" 100 | digest = 
"1:cae234a803b78380e4d769db6036b9fcc8c08ed4ff862571ffc1a958edc1f629" 101 | name = "golang.org/x/crypto" 102 | packages = ["ssh/terminal"] 103 | pruneopts = "" 104 | revision = "c126467f60eb25f8f27e5a981f32a87e3965053f" 105 | 106 | [[projects]] 107 | branch = "master" 108 | digest = "1:dd631ee90bd2e7aa16b6e094217d77a797684b52811374c948c695cbb46b5bbb" 109 | name = "golang.org/x/sys" 110 | packages = [ 111 | "unix", 112 | "windows", 113 | ] 114 | pruneopts = "" 115 | revision = "bd9dbc187b6e1dacfdd2722a87e83093c2d7bd6e" 116 | 117 | [solve-meta] 118 | analyzer-name = "dep" 119 | analyzer-version = 1 120 | input-imports = [ 121 | "github.com/alecthomas/kingpin", 122 | "github.com/benbjohnson/clock", 123 | "github.com/lunixbochs/struc", 124 | "github.com/satori/go.uuid", 125 | "github.com/sirupsen/logrus", 126 | "github.com/stretchr/testify/assert", 127 | "github.com/stretchr/testify/mock", 128 | ] 129 | solver-name = "gps-cdcl" 130 | solver-version = 1 131 | -------------------------------------------------------------------------------- /Gopkg.toml: -------------------------------------------------------------------------------- 1 | # Gopkg.toml example 2 | # 3 | # Refer to https://github.com/golang/dep/blob/master/docs/Gopkg.toml.md 4 | # for detailed Gopkg.toml documentation. 
5 | # 6 | # required = ["github.com/user/thing/cmd/thing"] 7 | # ignored = ["github.com/user/project/pkgX", "bitbucket.org/user/project/pkgA/pkgY"] 8 | # 9 | # [[constraint]] 10 | # name = "github.com/user/project" 11 | # version = "1.0.0" 12 | # 13 | # [[constraint]] 14 | # name = "github.com/user/project2" 15 | # branch = "dev" 16 | # source = "github.com/myfork/project2" 17 | # 18 | # [[override]] 19 | # name = "github.com/x/y" 20 | # version = "2.4.0" 21 | 22 | 23 | [[constraint]] 24 | branch = "master" 25 | name = "github.com/benbjohnson/clock" 26 | 27 | [[constraint]] 28 | branch = "master" 29 | name = "github.com/lunixbochs/struc" 30 | 31 | [[constraint]] 32 | name = "github.com/satori/go.uuid" 33 | version = "1.2.0" 34 | 35 | [[constraint]] 36 | name = "github.com/sirupsen/logrus" 37 | version = "1.0.4" 38 | 39 | [[constraint]] 40 | name = "github.com/stretchr/testify" 41 | version = "1.2.1" 42 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 Nick Whyte 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # miio-go 2 | 3 | [![Coverage Status](https://coveralls.io/repos/github/nickw444/miio-go/badge.svg?branch=master)](https://coveralls.io/github/nickw444/miio-go?branch=master) 4 | 5 | An implementation of the miIO home protocol by Xiaomi, written in Golang. Heavily inspired by: 6 | 7 | - [The JavaScript Implementation](https://github.com/aholstenson/miio) 8 | - [Protocol Specification](https://github.com/OpenMiHome/mihome-binary-protocol) 9 | - [API design in this LIFX client implementation (pdf/golifx)](https://github.com/pdf/golifx) 10 | 11 | This implementation has been designed with the following concerns: 12 | - Testability 13 | - Development without a miIO device handy (or performing any real network operations) 14 | - A simple event-based API. 15 | 16 | ## Supported Devices 17 | At the moment, only the following devices are officially supported by this library. Feel free to 18 | [submit a pull request](https://github.com/nickw444/miio-go/pulls); I'd be more than happy to have more devices supported by this library. 19 | 20 | - Xiaomi Mi Smart WiFi Socket (v1 - no USB) (chuangmi.plug.m1) 21 | - Xiaomi Yeelight (yeelink.light.color1) 22 | 23 | 24 | ## Simulator 25 | 26 | A device simulator/emulator exists in the [simulator](simulator/) package. It takes 27 | advantage of the low-level network layer used to communicate with real devices to emulate 28 | hardware devices. 29 | 30 | [Give it a try!](simulator/) 31 | 32 | ## Tokens 33 | Documentation coming soon... 34 | 35 | ## Examples 36 | Documentation coming soon...
37 | 38 | ## CLI 39 | 40 | A CLI exists to allow controlling devices using this library. 41 | 42 | ``` 43 | usage: miio-go CLI [] [ ...] 44 | 45 | CLI application to manually test miio-go functionality 46 | 47 | Flags: 48 | --help Show context-sensitive help (also try --help-long and --help-man). 49 | --local Send broadcast to 127.0.0.1 instead of 255.255.255.255 (For use with locally hosted simulator) 50 | --log-level=warn Set MiiO to a specific log level 51 | 52 | Commands: 53 | help [...] 54 | Show help. 55 | 56 | 57 | control brightness 58 | Set device brightness 59 | 60 | 61 | control power 62 | Set device power 63 | 64 | 65 | control color hsv 66 | Set color using HSV values 67 | 68 | 69 | control color rgb 70 | Set color using RGB values 71 | 72 | 73 | discover 74 | Discover devices on the local network 75 | 76 | ``` 77 | -------------------------------------------------------------------------------- /bin/check_fmt.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | set -eu 3 | 4 | ROOT_DIR=$(cd "$(dirname "$0")"/.. && pwd) 5 | cd ${ROOT_DIR} 6 | 7 | go fmt ./... 8 | if ! git diff --exit-code HEAD; then 9 | echo "Code is not formatted, please run: go fmt ./..." 10 | exit 1 11 | fi 12 | -------------------------------------------------------------------------------- /bin/check_mocks.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | set -eu 3 | 4 | ROOT_DIR=$(cd "$(dirname "$0")"/.. && pwd) 5 | cd ${ROOT_DIR} 6 | 7 | ./bin/regenerate_mocks.sh 8 | if ! git diff --exit-code HEAD -- '*.go'; then 9 | echo "Mocks are not up to date, please run: make mocks" 10 | exit 1 11 | fi 12 | -------------------------------------------------------------------------------- /bin/cover.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | set -eu 3 | ROOT_DIR=$(cd "$(dirname "$0")"/.. 
&& pwd) 4 | 5 | usage() { 6 | local prog 7 | prog=$(basename "$0") 8 | cat <> 16 & 0xff 113 | green = int(*m) >> 8 & 0xff 114 | blue = int(*m) & 0xff 115 | return 116 | } 117 | 118 | func (m *miioRGB) SetComponents(red int, green int, blue int) { 119 | i := 0 120 | i |= red << 16 121 | i |= green << 8 122 | i |= blue 123 | *m = miioRGB(i) 124 | } 125 | -------------------------------------------------------------------------------- /capability/light_test.go: -------------------------------------------------------------------------------- 1 | package capability 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/nickw444/miio-go/protocol/transport" 7 | transportMocks "github.com/nickw444/miio-go/protocol/transport/mocks" 8 | subscriptionMocks "github.com/nickw444/miio-go/subscription/common/mocks" 9 | "github.com/stretchr/testify/assert" 10 | "github.com/stretchr/testify/mock" 11 | ) 12 | 13 | func TestMiiORGB_GetComponents1(t *testing.T) { 14 | m := miioRGB(0xffffff) 15 | r, g, b := m.GetComponents() 16 | assert.Equal(t, 255, r) 17 | assert.Equal(t, 255, g) 18 | assert.Equal(t, 255, b) 19 | } 20 | 21 | func TestMiiORGB_GetComponents2(t *testing.T) { 22 | m := miioRGB(0xff7f0f) 23 | r, g, b := m.GetComponents() 24 | assert.Equal(t, 255, r) 25 | assert.Equal(t, 127, g) 26 | assert.Equal(t, 15, b) 27 | } 28 | 29 | func TestMiiORGB_SetComponents1(t *testing.T) { 30 | m := miioRGB(0) 31 | m.SetComponents(255, 255, 255) 32 | assert.Equal(t, miioRGB(0xffffff), m) 33 | } 34 | 35 | func TestMiiORGB_SetComponents2(t *testing.T) { 36 | m := miioRGB(0) 37 | m.SetComponents(255, 127, 15) 38 | assert.Equal(t, miioRGB(0xff7f0f), m) 39 | } 40 | 41 | func Light_SetUp() (tt struct { 42 | light *Light 43 | outbound *transportMocks.Outbound 44 | target *subscriptionMocks.SubscriptionTarget 45 | }) { 46 | tt.target = new(subscriptionMocks.SubscriptionTarget) 47 | tt.outbound = new(transportMocks.Outbound) 48 | tt.light = NewLight(tt.target, tt.outbound) 49 | return 50 | } 51 | 52 | func 
TestLight_Update(t *testing.T) { 53 | tt := Light_SetUp() 54 | 55 | tt.outbound.On("CallAndDeserialize", mock.AnythingOfType("string"), mock.AnythingOfType("[]string"), mock.Anything). 56 | Return(nil). 57 | Run(func(args mock.Arguments) { 58 | resp := args.Get(2).(*transport.Response) 59 | resp.Result = []interface{}{"100", "3", "12345", "128", "100"} 60 | }) 61 | tt.target.On("Publish", mock.Anything).Return(nil).Once() 62 | 63 | err := tt.light.Update() 64 | assert.NoError(t, err) 65 | tt.target.AssertExpectations(t) 66 | } 67 | 68 | func TestLight_UpdateNoChanges(t *testing.T) { 69 | tt := Light_SetUp() 70 | 71 | tt.outbound.On("CallAndDeserialize", mock.AnythingOfType("string"), mock.AnythingOfType("[]string"), mock.Anything). 72 | Return(nil). 73 | Run(func(args mock.Arguments) { 74 | resp := args.Get(2).(*transport.Response) 75 | resp.Result = []interface{}{"0", "0", "0", "0", "0"} 76 | }) 77 | err := tt.light.Update() 78 | assert.NoError(t, err) 79 | tt.target.AssertExpectations(t) 80 | } 81 | 82 | func TestLight_SetRGB(t *testing.T) { 83 | tt := Light_SetUp() 84 | tt.outbound.On("Call", "set_rgb", []interface{}{16777215}).Return(nil, nil) 85 | tt.target.On("Publish", mock.Anything).Return(nil).Once() 86 | 87 | err := tt.light.SetRGB(255, 255, 255) 88 | assert.NoError(t, err) 89 | tt.target.AssertExpectations(t) 90 | } 91 | 92 | func TestLight_SetHSV(t *testing.T) { 93 | tt := Light_SetUp() 94 | tt.outbound.On("Call", "set_hsv", []interface{}{120, 77}).Return(nil, nil) 95 | tt.target.On("Publish", mock.Anything).Return(nil).Once() 96 | 97 | err := tt.light.SetHSV(120, 77) 98 | assert.NoError(t, err) 99 | tt.target.AssertExpectations(t) 100 | } 101 | 102 | func TestLight_SetBrightness(t *testing.T) { 103 | tt := Light_SetUp() 104 | tt.outbound.On("Call", "set_bright", []interface{}{55}).Return(nil, nil) 105 | tt.target.On("Publish", mock.Anything).Return(nil).Once() 106 | 107 | err := tt.light.SetBrightness(55) 108 | assert.NoError(t, err) 109 | 
tt.target.AssertExpectations(t) 110 | } 111 | -------------------------------------------------------------------------------- /capability/power.go: -------------------------------------------------------------------------------- 1 | package capability 2 | 3 | import ( 4 | "github.com/nickw444/miio-go/common" 5 | "github.com/nickw444/miio-go/protocol/transport" 6 | "github.com/nickw444/miio-go/subscription" 7 | ) 8 | 9 | type Power struct { 10 | subscriptionTarget subscription.SubscriptionTarget 11 | outbound transport.Outbound 12 | powerState common.PowerState 13 | } 14 | 15 | type PowerResponse struct { 16 | Result []common.PowerState `json:"result"` 17 | } 18 | 19 | func NewPower(target subscription.SubscriptionTarget, transport transport.Outbound) *Power { 20 | return &Power{ 21 | subscriptionTarget: target, 22 | outbound: transport, 23 | powerState: common.PowerStateUnknown, 24 | } 25 | } 26 | 27 | func (p *Power) SetPower(state common.PowerState) error { 28 | _, err := p.outbound.Call("set_power", []string{string(state)}) 29 | if err != nil { 30 | return err 31 | } 32 | 33 | // TODO NW: Use the value from the response here. 
34 | p.powerState = state 35 | return p.subscriptionTarget.Publish(common.EventUpdatePower{p.powerState}) 36 | } 37 | 38 | func (p *Power) Update() error { 39 | resp := PowerResponse{} 40 | err := p.outbound.CallAndDeserialize("get_prop", []string{"power"}, &resp) 41 | if err != nil { 42 | return err 43 | } 44 | 45 | if resp.Result[0] != p.powerState { 46 | p.powerState = resp.Result[0] 47 | p.subscriptionTarget.Publish(common.EventUpdatePower{p.powerState}) 48 | } 49 | 50 | return nil 51 | } 52 | -------------------------------------------------------------------------------- /capability/power_test.go: -------------------------------------------------------------------------------- 1 | package capability 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/nickw444/miio-go/common" 7 | transportMocks "github.com/nickw444/miio-go/protocol/transport/mocks" 8 | subscriptionMocks "github.com/nickw444/miio-go/subscription/common/mocks" 9 | "github.com/stretchr/testify/assert" 10 | "github.com/stretchr/testify/mock" 11 | ) 12 | 13 | func Power_SetUp() (tt struct { 14 | power *Power 15 | outbound *transportMocks.Outbound 16 | target *subscriptionMocks.SubscriptionTarget 17 | }) { 18 | tt.target = new(subscriptionMocks.SubscriptionTarget) 19 | tt.outbound = new(transportMocks.Outbound) 20 | tt.power = NewPower(tt.target, tt.outbound) 21 | return 22 | } 23 | 24 | // Ensure an event is emitted on power state change 25 | func TestPower_Update(t *testing.T) { 26 | tt := Power_SetUp() 27 | 28 | tt.outbound.On("CallAndDeserialize", mock.AnythingOfType("string"), mock.AnythingOfType("[]string"), mock.Anything). 29 | Return(nil). 
30 | Run(func(args mock.Arguments) { 31 | resp := args.Get(2).(*PowerResponse) 32 | resp.Result = []common.PowerState{common.PowerStateOn} 33 | }) 34 | tt.target.On("Publish", mock.Anything).Return(nil).Once() 35 | 36 | err := tt.power.Update() 37 | assert.NoError(t, err) 38 | tt.target.AssertExpectations(t) 39 | } 40 | 41 | // Should bubbdle outbound errors 42 | func TestPower_Update2(t *testing.T) { 43 | tt := Power_SetUp() 44 | 45 | tt.outbound.On("CallAndDeserialize", mock.AnythingOfType("string"), mock.AnythingOfType("[]string"), mock.Anything). 46 | Return(assert.AnError) 47 | 48 | err := tt.power.Update() 49 | assert.Error(t, err) 50 | } 51 | 52 | // Ensure an event is emitted on SetPower. 53 | func TestPower_SetPower(t *testing.T) { 54 | tt := Power_SetUp() 55 | 56 | tt.outbound.On("Call", mock.Anything, mock.Anything).Return(nil, nil) 57 | tt.target.On("Publish", mock.Anything).Return(nil).Once() 58 | 59 | err := tt.power.SetPower(common.PowerStateOn) 60 | assert.NoError(t, err) 61 | tt.target.AssertExpectations(t) 62 | } 63 | 64 | // Ensure outbound is called on SetPower 65 | func TestPower_SetPower2(t *testing.T) { 66 | tt := Power_SetUp() 67 | 68 | tt.target.On("Publish", mock.Anything).Return(nil) 69 | tt.outbound. 70 | On("Call", "set_power", []string{common.PowerStateOn}). 71 | Return(nil, nil) 72 | 73 | err := tt.power.SetPower(common.PowerStateOn) 74 | assert.NoError(t, err) 75 | 76 | tt.outbound.AssertNumberOfCalls(t, "Call", 1) 77 | tt.outbound.AssertExpectations(t) 78 | } 79 | 80 | // Should bubble outbound errors 81 | func TestPower_SetPower3(t *testing.T) { 82 | tt := Power_SetUp() 83 | 84 | tt.outbound. 85 | On("Call", "set_power", []string{common.PowerStateOn}). 
86 | Return(nil, assert.AnError) 87 | 88 | err := tt.power.SetPower(common.PowerStateOn) 89 | assert.Error(t, err) 90 | } 91 | -------------------------------------------------------------------------------- /cli/.gitignore: -------------------------------------------------------------------------------- 1 | cli 2 | 3 | -------------------------------------------------------------------------------- /cli/control.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "fmt" 5 | "time" 6 | 7 | "github.com/alecthomas/kingpin" 8 | "github.com/nickw444/miio-go/capability" 9 | "github.com/nickw444/miio-go/common" 10 | "github.com/nickw444/miio-go/device" 11 | ) 12 | 13 | var sharedDevice common.Device 14 | 15 | func findDevice(deviceId uint32, timeout time.Duration) (common.Device, error) { 16 | var timeoutCh <-chan time.Time 17 | timeoutCh = time.After(timeout) 18 | sub, err := sharedClient.NewSubscription() 19 | if err != nil { 20 | panic(err) 21 | } 22 | events := sub.Events() 23 | defer sub.Close() 24 | 25 | for { 26 | select { 27 | case event := <-events: 28 | switch event.(type) { 29 | case common.EventNewDevice: 30 | dev := event.(common.EventNewDevice).Device 31 | if dev.ID() == deviceId { 32 | return dev, nil 33 | } 34 | } 35 | case <-timeoutCh: 36 | return nil, fmt.Errorf("Timed out whilst connecting to device with id %d", deviceId) 37 | } 38 | } 39 | } 40 | 41 | func installControl(app *kingpin.Application) { 42 | controlCmd := app.Command("control", "Control lights") 43 | deviceId := controlCmd.Flag("device-id", "The ID of the device to control").Required().Uint32() 44 | 45 | controlCmd.Action(func(ctx *kingpin.ParseContext) (err error) { 46 | sharedDevice, err = findDevice(*deviceId, time.Second*5) 47 | return 48 | }) 49 | 50 | installBrightness(controlCmd) 51 | installPower(controlCmd) 52 | installColor(controlCmd) 53 | } 54 | 55 | func installBrightness(parent *kingpin.CmdClause) { 56 | cmd 
:= parent.Command("brightness", "Set device brightness") 57 | brightness := cmd.Arg("brightness", "The brightness to set (between 0-100)").Required().Int() 58 | cmd.Action(func(ctx *kingpin.ParseContext) error { 59 | var light *capability.Light 60 | 61 | switch sharedDevice.(type) { 62 | case *device.Yeelight: 63 | light = sharedDevice.(*device.Yeelight).Light 64 | default: 65 | return fmt.Errorf("Device with type %T cannot have brightness adjusted", sharedDevice) 66 | } 67 | 68 | return light.SetBrightness(*brightness) 69 | }) 70 | } 71 | 72 | func installPower(parent *kingpin.CmdClause) { 73 | cmd := parent.Command("power", "Set device power") 74 | state := cmd.Arg("state", "The power state (on/off)").Required().Enum("on", "off") 75 | cmd.Action(func(ctx *kingpin.ParseContext) error { 76 | var power *capability.Power 77 | 78 | switch sharedDevice.(type) { 79 | case *device.Yeelight: 80 | power = sharedDevice.(*device.Yeelight).Power 81 | case *device.PowerPlug: 82 | power = sharedDevice.(*device.PowerPlug).Power 83 | default: 84 | return fmt.Errorf("Device with type %T cannot have brightness adjusted", sharedDevice) 85 | } 86 | 87 | return power.SetPower(common.PowerState(*state)) 88 | }) 89 | } 90 | 91 | func installColor(parent *kingpin.CmdClause) { 92 | cmd := parent.Command("color", "Set device color") 93 | 94 | hsv := cmd.Command("hsv", "Set color using HSV values") 95 | hue := hsv.Arg("hue", "Hue to set (0-360)").Required().Int() 96 | sat := hsv.Arg("saturation", "Saturation to set (0-100)").Required().Int() 97 | 98 | rgb := cmd.Command("rgb", "Set color using RGB values") 99 | red := rgb.Arg("red", "Red value to set (0-255)").Required().Int() 100 | green := rgb.Arg("green", "Green value to set (0-255)").Required().Int() 101 | blue := rgb.Arg("blue", "Blue value to set (0-255)").Required().Int() 102 | 103 | var light *capability.Light 104 | 105 | rgb.Action(func(ctx *kingpin.ParseContext) error { 106 | return light.SetRGB(*red, *green, *blue) 107 | }) 108 | 
109 | hsv.Action(func(ctx *kingpin.ParseContext) error { 110 | return light.SetHSV(*hue, *sat) 111 | }) 112 | 113 | cmd.Action(func(ctx *kingpin.ParseContext) error { 114 | switch sharedDevice.(type) { 115 | case *device.Yeelight: 116 | light = sharedDevice.(*device.Yeelight).Light 117 | default: 118 | return fmt.Errorf("Device with type %T cannot have brightness adjusted", sharedDevice) 119 | } 120 | return nil 121 | }) 122 | } 123 | -------------------------------------------------------------------------------- /cli/discovery.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "encoding/hex" 5 | "fmt" 6 | "time" 7 | 8 | "github.com/alecthomas/kingpin" 9 | "github.com/nickw444/miio-go/common" 10 | ) 11 | 12 | func installDiscovery(app *kingpin.Application) { 13 | cmd := app.Command("discover", "Discover devices on the local network") 14 | cmd.Action(func(ctx *kingpin.ParseContext) error { 15 | sharedClient.SetDiscoveryInterval(time.Second * 2) 16 | 17 | sub, err := sharedClient.NewSubscription() 18 | if err != nil { 19 | panic(err) 20 | } 21 | for event := range sub.Events() { 22 | switch event.(type) { 23 | case common.EventNewDevice: 24 | dev := event.(common.EventNewDevice).Device 25 | go writeDeviceInfo(dev) 26 | case common.EventNewMaskedDevice: 27 | deviceId := event.(common.EventNewMaskedDevice).DeviceID 28 | go writeMaskedDeviceInfo(deviceId) 29 | 30 | } 31 | } 32 | return nil 33 | }) 34 | } 35 | 36 | func writeDeviceInfo(dev common.Device) { 37 | deviceInfo, _ := dev.GetInfo() 38 | fmt.Println("-------------") 39 | fmt.Println("Discovered new device:") 40 | fmt.Printf("ID: %d\n", dev.ID()) 41 | fmt.Printf("Firmware Version: %s\n", deviceInfo.FirmwareVersion) 42 | fmt.Printf("Hardware Version: %s\n", deviceInfo.HardwareVersion) 43 | fmt.Printf("Mac Address: %s\n", deviceInfo.MacAddress) 44 | fmt.Printf("Model: %s\n", deviceInfo.Model) 45 | fmt.Printf("Token: %s\n", 
hex.EncodeToString(dev.GetToken())) 46 | fmt.Println("-------------") 47 | } 48 | 49 | func writeMaskedDeviceInfo(deviceId uint32) { 50 | fmt.Println("-------------") 51 | fmt.Println("Discovered new device with masked token:") 52 | fmt.Printf("ID: %d\n", deviceId) 53 | fmt.Println("You must manually retreive this token in order to communicate with the device.") 54 | fmt.Println("-------------") 55 | } 56 | -------------------------------------------------------------------------------- /cli/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "net" 5 | "os" 6 | 7 | "github.com/alecthomas/kingpin" 8 | "github.com/nickw444/miio-go" 9 | "github.com/nickw444/miio-go/common" 10 | "github.com/nickw444/miio-go/protocol" 11 | "github.com/nickw444/miio-go/protocol/tokens" 12 | "github.com/sirupsen/logrus" 13 | ) 14 | 15 | var sharedClient *miio.Client 16 | 17 | func createClient(local bool) (*miio.Client, error) { 18 | addr := net.IPv4bcast 19 | if local { 20 | addr = net.IPv4(127, 0, 0, 1) 21 | } 22 | 23 | tokenStore, err := tokens.FromFile("tokens.txt") 24 | if err != nil { 25 | panic(err) 26 | } 27 | 28 | proto, err := protocol.NewProtocol(protocol.ProtocolConfig{ 29 | BroadcastIP: addr, 30 | TokenStore: tokenStore, 31 | }) 32 | if err != nil { 33 | return nil, err 34 | } 35 | 36 | return miio.NewClientWithProtocol(proto) 37 | } 38 | 39 | func main() { 40 | app := kingpin.New("miio-go CLI", "CLI application to manually test miio-go functionality") 41 | local := app.Flag("local", "Send broadcast to 127.0.0.1 instead of 255.255.255.255 (For use with locally hosted simulator)").Bool() 42 | logLevel := app.Flag("log-level", "Set MiiO to a specific log level").Default("warn").Enum("debug", "warn", "info", "error") 43 | 44 | installControl(app) 45 | installDiscovery(app) 46 | 47 | app.Action(func(ctx *kingpin.ParseContext) error { 48 | level, _ := logrus.ParseLevel(*logLevel) 49 | l := logrus.New() 50 
| l.SetLevel(level) 51 | common.SetLogger(l) 52 | 53 | var err error 54 | sharedClient, err = createClient(*local) 55 | return err 56 | }) 57 | 58 | kingpin.MustParse(app.Parse(os.Args[1:])) 59 | } 60 | -------------------------------------------------------------------------------- /client.go: -------------------------------------------------------------------------------- 1 | package miio 2 | 3 | import ( 4 | "net" 5 | "sync" 6 | "time" 7 | 8 | "github.com/nickw444/miio-go/common" 9 | "github.com/nickw444/miio-go/protocol" 10 | "github.com/nickw444/miio-go/protocol/tokens" 11 | "github.com/nickw444/miio-go/subscription" 12 | ) 13 | 14 | type Client struct { 15 | sync.RWMutex 16 | subscription.SubscriptionTarget 17 | 18 | protocol protocol.Protocol 19 | discoveryInterval time.Duration 20 | quitChan chan struct{} 21 | events chan interface{} 22 | } 23 | 24 | // NewClient creates a new default Client with the protocol. 25 | func NewClient() (*Client, error) { 26 | tokenStore, err := tokens.FromFile("tokens.txt") 27 | if err != nil { 28 | return nil, err 29 | } 30 | 31 | protocolConfig := protocol.ProtocolConfig{ 32 | BroadcastIP: net.IPv4bcast, 33 | TokenStore: tokenStore, 34 | } 35 | 36 | p, err := protocol.NewProtocol(protocolConfig) 37 | if err != nil { 38 | return nil, err 39 | } 40 | 41 | return NewClientWithProtocol(p) 42 | } 43 | 44 | func NewClientWithProtocol(protocol protocol.Protocol) (*Client, error) { 45 | c := &Client{ 46 | SubscriptionTarget: subscription.NewTarget(), 47 | protocol: protocol, 48 | quitChan: make(chan struct{}), 49 | } 50 | 51 | c.SetDiscoveryInterval(time.Second * 15) 52 | 53 | return c, c.init() 54 | } 55 | 56 | func (c *Client) init() error { 57 | if err := c.subscribe(); err != nil { 58 | return err 59 | } 60 | return c.discover() 61 | } 62 | 63 | func (c *Client) SetDiscoveryInterval(interval time.Duration) { 64 | c.discoveryInterval = interval 65 | c.protocol.SetExpiryTime(interval * 2) 66 | } 67 | 68 | func (c *Client) 
discover() error { 69 | if c.discoveryInterval == 0 { 70 | common.Log.Debugf("Discovery interval is zero, discovery will only be performed once") 71 | return c.protocol.Discover() 72 | } 73 | 74 | _ = c.protocol.Discover() 75 | 76 | go func() { 77 | c.RLock() 78 | tick := time.Tick(c.discoveryInterval) 79 | c.RUnlock() 80 | for { 81 | select { 82 | case <-c.quitChan: 83 | common.Log.Debugf("Quitting discovery loop") 84 | return 85 | default: 86 | } 87 | select { 88 | case <-c.quitChan: 89 | common.Log.Debugf("Quitting discovery loop") 90 | return 91 | case <-tick: 92 | common.Log.Debugf("Performing discovery") 93 | _ = c.protocol.Discover() 94 | } 95 | } 96 | }() 97 | 98 | return nil 99 | } 100 | 101 | // Proxy events from protocol level 102 | func (c *Client) subscribe() error { 103 | sub, err := c.protocol.NewSubscription() 104 | if err != nil { 105 | return err 106 | } 107 | 108 | go func() { 109 | for event := range sub.Events() { 110 | c.Publish(event) 111 | } 112 | }() 113 | return nil 114 | } 115 | -------------------------------------------------------------------------------- /client_test.go: -------------------------------------------------------------------------------- 1 | package miio 2 | -------------------------------------------------------------------------------- /common/device.go: -------------------------------------------------------------------------------- 1 | package common 2 | 3 | import "github.com/nickw444/miio-go/subscription" 4 | 5 | type DeviceInfo struct { 6 | FirmwareVersion string `json:"fw_ver"` 7 | HardwareVersion string `json:"hw_ver"` 8 | MacAddress string `json:"mac"` 9 | Model string `json:"model"` 10 | } 11 | 12 | type Device interface { 13 | subscription.SubscriptionTarget 14 | 15 | ID() uint32 16 | GetLabel() (string, error) 17 | GetInfo() (DeviceInfo, error) 18 | GetToken() []byte 19 | } 20 | -------------------------------------------------------------------------------- /common/event.go: 
-------------------------------------------------------------------------------- 1 | package common 2 | 3 | type EventNewDevice struct { 4 | Device Device 5 | } 6 | 7 | type EventNewMaskedDevice struct { 8 | DeviceID uint32 9 | } 10 | 11 | type EventExpiredDevice struct { 12 | Device Device 13 | } 14 | 15 | type EventUpdatePower struct { 16 | PowerState PowerState 17 | } 18 | 19 | type EventUpdateLight struct { 20 | Brightness int 21 | 22 | ColorMode int // 1: rgb mode, 2: color temperature mode, 3: hsv mode 23 | RGB struct { 24 | Red int 25 | Green int 26 | Blue int 27 | } 28 | Hue int 29 | Saturation int 30 | } 31 | -------------------------------------------------------------------------------- /common/logger.go: -------------------------------------------------------------------------------- 1 | package common 2 | 3 | import "github.com/sirupsen/logrus" 4 | 5 | var ( 6 | Log *logrus.Logger = logrus.New() 7 | ) 8 | 9 | func init() { 10 | Log.SetLevel(logrus.WarnLevel) 11 | } 12 | 13 | func SetLogger(logger *logrus.Logger) { 14 | Log = logger 15 | } 16 | -------------------------------------------------------------------------------- /common/types.go: -------------------------------------------------------------------------------- 1 | package common 2 | 3 | type PowerState string 4 | 5 | const ( 6 | PowerStateUnknown PowerState = "" 7 | PowerStateOn = "on" 8 | PowerStateOff = "off" 9 | ) 10 | -------------------------------------------------------------------------------- /device/basedevice.go: -------------------------------------------------------------------------------- 1 | package device 2 | 3 | import ( 4 | "time" 5 | 6 | "github.com/nickw444/miio-go/common" 7 | "github.com/nickw444/miio-go/device/product" 8 | "github.com/nickw444/miio-go/device/rthrottle" 9 | "github.com/nickw444/miio-go/protocol/packet" 10 | "github.com/nickw444/miio-go/protocol/transport" 11 | "github.com/nickw444/miio-go/subscription" 12 | ) 13 | 14 | // baseDevice implements the Device 
interface. 15 | type baseDevice struct { 16 | subscription.SubscriptionTarget 17 | 18 | refreshThrottle rthrottle.RefreshThrottle 19 | outbound transport.Outbound 20 | 21 | product product.Product 22 | id uint32 23 | provisional bool 24 | seen time.Time 25 | token []byte 26 | } 27 | 28 | type InfoResponse struct { 29 | Result common.DeviceInfo `json:"result"` 30 | ID uint32 `json:"ID"` 31 | } 32 | 33 | func New(deviceId uint32, transport transport.Outbound, seen time.Time, token []byte) Device { 34 | throttle := rthrottle.NewRefreshThrottle(time.Second * 5) 35 | b := &baseDevice{ 36 | SubscriptionTarget: subscription.NewTarget(), 37 | 38 | refreshThrottle: throttle, 39 | outbound: transport, 40 | id: deviceId, 41 | seen: seen, 42 | token: token, 43 | } 44 | b.init() 45 | return b 46 | } 47 | 48 | func (b *baseDevice) init() { 49 | b.product = product.Unknown 50 | b.provisional = true 51 | } 52 | 53 | func (b *baseDevice) ID() uint32 { 54 | return b.id 55 | } 56 | 57 | func (b *baseDevice) GetLabel() (string, error) { 58 | return "", nil 59 | } 60 | 61 | func (b *baseDevice) Handle(pkt *packet.Packet) error { 62 | common.Log.Debugf("Handling packet at base_device") 63 | b.seen = pkt.Meta.DecodeTime 64 | return b.outbound.Handle(pkt) 65 | } 66 | 67 | func (b *baseDevice) Close() error { 68 | err := b.SubscriptionTarget.CloseAllSubscriptions() 69 | b.refreshThrottle.Close() 70 | return err 71 | } 72 | 73 | func (b *baseDevice) Seen() time.Time { 74 | return b.seen 75 | } 76 | 77 | func (b *baseDevice) Provisional() bool { 78 | return b.provisional 79 | } 80 | 81 | func (b *baseDevice) SetProvisional(provisional bool) { 82 | b.provisional = provisional 83 | } 84 | 85 | func (b *baseDevice) GetProduct() (product.Product, error) { 86 | resp := InfoResponse{} 87 | err := b.outbound.CallAndDeserialize("miIO.info", nil, &resp) 88 | if err != nil { 89 | return product.Unknown, err 90 | } 91 | 92 | return product.GetModel(resp.Result.Model) 93 | } 94 | 95 | func (b 
*baseDevice) GetInfo() (common.DeviceInfo, error) { 96 | resp := InfoResponse{} 97 | err := b.outbound.CallAndDeserialize("miIO.info", nil, &resp) 98 | return resp.Result, err 99 | } 100 | 101 | func (b *baseDevice) Discover() error { 102 | return b.outbound.Send(packet.NewHello()) 103 | } 104 | 105 | func (b *baseDevice) NewSubscription() (subscription.Subscription, error) { 106 | sub, err := b.SubscriptionTarget.NewSubscription() 107 | b.refreshThrottle.Start() 108 | return sub, err 109 | } 110 | 111 | func (b *baseDevice) RemoveSubscription(s subscription.Subscription) (err error) { 112 | err = b.SubscriptionTarget.RemoveSubscription(s) 113 | if !b.HasSubscribers() { 114 | b.refreshThrottle.Stop() 115 | } 116 | return 117 | } 118 | 119 | func (b *baseDevice) RefreshThrottle() <-chan struct{} { 120 | return b.refreshThrottle.Chan() 121 | } 122 | 123 | func (b *baseDevice) Outbound() transport.Outbound { 124 | return b.outbound 125 | } 126 | 127 | func (b *baseDevice) GetToken() []byte { 128 | return b.token 129 | } 130 | -------------------------------------------------------------------------------- /device/basedevice_test.go: -------------------------------------------------------------------------------- 1 | package device 2 | 3 | import ( 4 | "bytes" 5 | "testing" 6 | "time" 7 | 8 | "github.com/benbjohnson/clock" 9 | deviceMocks "github.com/nickw444/miio-go/device/mocks" 10 | "github.com/nickw444/miio-go/device/product" 11 | "github.com/nickw444/miio-go/protocol/packet" 12 | transportMocks "github.com/nickw444/miio-go/protocol/transport/mocks" 13 | subscriptionMocks "github.com/nickw444/miio-go/subscription/common/mocks" 14 | "github.com/stretchr/testify/assert" 15 | "github.com/stretchr/testify/mock" 16 | ) 17 | 18 | func BaseDevice_SetUp() (ret struct { 19 | deviceId uint32 20 | clk *clock.Mock 21 | subTgt *subscriptionMocks.SubscriptionTarget 22 | outbound *transportMocks.Outbound 23 | rThrottle *deviceMocks.RefreshThrottle 24 | device *baseDevice 25 | }) 
{ 26 | ret.deviceId = 10 27 | ret.clk = clock.NewMock() 28 | ret.subTgt = &subscriptionMocks.SubscriptionTarget{} 29 | ret.outbound = &transportMocks.Outbound{} 30 | ret.rThrottle = &deviceMocks.RefreshThrottle{} 31 | ret.device = &baseDevice{ 32 | SubscriptionTarget: ret.subTgt, 33 | refreshThrottle: ret.rThrottle, 34 | outbound: ret.outbound, 35 | id: ret.deviceId, 36 | seen: ret.clk.Now(), 37 | } 38 | ret.device.init() 39 | 40 | // Tick the clock 41 | ret.clk.Add(time.Second * 5) 42 | 43 | return 44 | } 45 | 46 | // Handle an incoming packet via the outbound 47 | func TestBaseDevice_Handle(t *testing.T) { 48 | tt := BaseDevice_SetUp() 49 | 50 | pkt := packet.New(tt.deviceId, bytes.Repeat([]byte{0xfa}, 16), 0xAAA, bytes.Repeat([]byte{0xCA}, 10)) 51 | tt.outbound.On("Handle", pkt).Return(nil) 52 | 53 | tt.device.Handle(pkt) 54 | tt.outbound.AssertExpectations(t) 55 | } 56 | 57 | // Handle an incoming packet should update seen time to packet decode time. 58 | func TestBaseDevice_Handle2(t *testing.T) { 59 | tt := BaseDevice_SetUp() 60 | pkt := packet.New(tt.deviceId, bytes.Repeat([]byte{0xfa}, 16), 0xAAA, bytes.Repeat([]byte{0xCA}, 10)) 61 | pkt.Meta.DecodeTime = tt.clk.Now() 62 | 63 | tt.outbound.On("Handle", pkt).Return(nil) 64 | tt.device.Handle(pkt) 65 | 66 | assert.EqualValues(t, pkt.Meta.DecodeTime, tt.device.seen) 67 | } 68 | 69 | // Closes subscriptions and refreshThrottle on close 70 | func TestBaseDevice_Close(t *testing.T) { 71 | tt := BaseDevice_SetUp() 72 | 73 | tt.subTgt.On("CloseAllSubscriptions").Return(nil) 74 | tt.rThrottle.On("Close") 75 | tt.device.Close() 76 | 77 | tt.subTgt.AssertExpectations(t) 78 | tt.rThrottle.AssertExpectations(t) 79 | } 80 | 81 | // Sets / Gets Provisional value 82 | func TestBaseDevice_Provisional(t *testing.T) { 83 | tt := BaseDevice_SetUp() 84 | 85 | assert.True(t, tt.device.provisional) 86 | tt.device.SetProvisional(false) 87 | assert.False(t, tt.device.provisional) 88 | tt.device.SetProvisional(true) 89 | 
assert.True(t, tt.device.provisional) 90 | } 91 | 92 | func BaseDevice_GetProduct_Setup(outbound *transportMocks.Outbound) { 93 | outbound.On("CallAndDeserialize", "miIO.info", mock.Anything, mock.Anything). 94 | Return(nil). 95 | Run(func(args mock.Arguments) { 96 | resp := args.Get(2).(*InfoResponse) 97 | resp.Result.Model = "chuangmi.plug.m1" 98 | }) 99 | } 100 | 101 | // GetProduct performs a miIO.info via outbound 102 | func TestBaseDevice_GetProduct(t *testing.T) { 103 | tt := BaseDevice_SetUp() 104 | BaseDevice_GetProduct_Setup(tt.outbound) 105 | 106 | _, err := tt.device.GetProduct() 107 | assert.NoError(t, err) 108 | tt.outbound.AssertExpectations(t) 109 | } 110 | 111 | // GetProduct returns the appropriate product 112 | func TestBaseDevice_GetProduct2(t *testing.T) { 113 | tt := BaseDevice_SetUp() 114 | BaseDevice_GetProduct_Setup(tt.outbound) 115 | 116 | p, err := tt.device.GetProduct() 117 | assert.NoError(t, err) 118 | assert.Equal(t, product.PowerPlug, p) 119 | } 120 | 121 | // Discover sends a hello packet via outbound 122 | func TestBaseDevice_Discover(t *testing.T) { 123 | tt := BaseDevice_SetUp() 124 | 125 | tt.outbound.On("Send", packet.NewHello()).Return(nil) 126 | tt.device.Discover() 127 | tt.outbound.AssertExpectations(t) 128 | 129 | } 130 | 131 | func BaseDevice_NewSubscription_Setup(throttle *deviceMocks.RefreshThrottle, target *subscriptionMocks.SubscriptionTarget) { 132 | throttle.On("Start") 133 | target.On("NewSubscription").Return(nil, nil) 134 | } 135 | 136 | // starts refreshThrottle when a new subscription is created 137 | func TestBaseDevice_NewSubscription(t *testing.T) { 138 | tt := BaseDevice_SetUp() 139 | BaseDevice_NewSubscription_Setup(tt.rThrottle, tt.subTgt) 140 | 141 | _, err := tt.device.NewSubscription() 142 | assert.NoError(t, err) 143 | tt.rThrottle.AssertExpectations(t) 144 | } 145 | 146 | // creates a new subscription 147 | func TestBaseDevice_NewSubscription2(t *testing.T) { 148 | tt := BaseDevice_SetUp() 149 | 
BaseDevice_NewSubscription_Setup(tt.rThrottle, tt.subTgt) 150 | 151 | _, err := tt.device.NewSubscription() 152 | assert.NoError(t, err) 153 | tt.subTgt.AssertExpectations(t) 154 | } 155 | 156 | // stops refresh throttle if the last subscription is closed 157 | func TestBaseDevice_RemoveSubscription(t *testing.T) { 158 | tt := BaseDevice_SetUp() 159 | 160 | tt.subTgt.On("RemoveSubscription", mock.Anything).Return(nil).Times(2) 161 | tt.rThrottle.On("Stop").Once() 162 | 163 | tt.subTgt.On("HasSubscribers").Return(true).Once() 164 | err := tt.device.RemoveSubscription(nil) 165 | assert.NoError(t, err) 166 | 167 | tt.subTgt.On("HasSubscribers").Return(false).Once() 168 | err = tt.device.RemoveSubscription(nil) 169 | assert.NoError(t, err) 170 | 171 | tt.subTgt.AssertExpectations(t) 172 | tt.rThrottle.AssertExpectations(t) 173 | } 174 | -------------------------------------------------------------------------------- /device/classification.go: -------------------------------------------------------------------------------- 1 | package device 2 | 3 | import ( 4 | "fmt" 5 | 6 | "github.com/nickw444/miio-go/device/product" 7 | ) 8 | 9 | // Classify determines the underlying product of the device and returns an 10 | // appropriate device implementation. 
11 | func Classify(dev Device) (Device, error) { 12 | if !dev.Provisional() { 13 | return dev, nil 14 | } 15 | 16 | p, err := dev.GetProduct() 17 | if err != nil { 18 | return nil, err 19 | } 20 | 21 | defer dev.SetProvisional(false) 22 | 23 | switch p { 24 | case product.Yeelight: 25 | return NewYeelight(dev), nil 26 | case product.PowerPlug: 27 | return NewPowerPlug(dev), nil 28 | default: 29 | return nil, fmt.Errorf("Classify: Unknown device type") 30 | } 31 | } 32 | -------------------------------------------------------------------------------- /device/classification_test.go: -------------------------------------------------------------------------------- 1 | package device 2 | 3 | import ( 4 | "testing" 5 | 6 | deviceMocks "github.com/nickw444/miio-go/device/mocks" 7 | "github.com/nickw444/miio-go/device/product" 8 | "github.com/stretchr/testify/assert" 9 | ) 10 | 11 | // Device is set to non-provisional after classify. 12 | func TestClassify1(t *testing.T) { 13 | baseDev := &deviceMocks.Device{} 14 | baseDev.On("Provisional").Return(true) 15 | baseDev.On("GetProduct").Return(product.Unknown, nil) 16 | baseDev.On("SetProvisional", false).Once() 17 | 18 | Classify(baseDev) 19 | baseDev.AssertExpectations(t) 20 | } 21 | 22 | // Non-provisional devices are not classified 23 | func TestClassify2(t *testing.T) { 24 | baseDev := &deviceMocks.Device{} 25 | baseDev.On("Provisional").Return(false) 26 | 27 | dev, err := Classify(baseDev) 28 | assert.NoError(t, err) 29 | assert.Equal(t, baseDev, dev) 30 | } 31 | 32 | func Classify_SetUp(product product.Product) *deviceMocks.Device { 33 | dev := &deviceMocks.Device{} 34 | dev.On("Provisional").Return(true) 35 | dev.On("GetProduct").Return(product, nil) 36 | dev.On("SetProvisional", false) 37 | dev.On("Outbound").Return(nil) 38 | return dev 39 | } 40 | 41 | func TestClassify_PowerPlug(t *testing.T) { 42 | baseDev := Classify_SetUp(product.PowerPlug) 43 | baseDev.On("RefreshThrottle").Return(nil) 44 | 45 | dev, err := 
Classify(baseDev) 46 | 47 | assert.NoError(t, err) 48 | assert.IsType(t, &PowerPlug{}, dev) 49 | } 50 | 51 | func TestClassify_Yeelight(t *testing.T) { 52 | baseDev := Classify_SetUp(product.Yeelight) 53 | baseDev.On("RefreshThrottle").Return(nil) 54 | 55 | dev, err := Classify(baseDev) 56 | 57 | assert.NoError(t, err) 58 | assert.IsType(t, &Yeelight{}, dev) 59 | } 60 | 61 | func TestClassify_Unknown(t *testing.T) { 62 | baseDev := Classify_SetUp(product.Unknown) 63 | 64 | _, err := Classify(baseDev) 65 | 66 | assert.Error(t, err) 67 | } 68 | -------------------------------------------------------------------------------- /device/device.go: -------------------------------------------------------------------------------- 1 | package device 2 | 3 | import ( 4 | "time" 5 | 6 | "github.com/nickw444/miio-go/common" 7 | "github.com/nickw444/miio-go/device/product" 8 | "github.com/nickw444/miio-go/protocol/packet" 9 | "github.com/nickw444/miio-go/protocol/transport" 10 | ) 11 | 12 | type Device interface { 13 | common.Device 14 | 15 | Handle(*packet.Packet) error 16 | Close() error 17 | Seen() time.Time 18 | Provisional() bool 19 | SetProvisional(bool) 20 | GetProduct() (product.Product, error) 21 | Discover() error 22 | RefreshThrottle() <-chan struct{} 23 | Outbound() transport.Outbound 24 | } 25 | -------------------------------------------------------------------------------- /device/mocks/Device.go: -------------------------------------------------------------------------------- 1 | // Code generated by mockery v1.0.0. DO NOT EDIT. 
2 | 3 | package mocks 4 | 5 | import ( 6 | common "github.com/nickw444/miio-go/common" 7 | 8 | mock "github.com/stretchr/testify/mock" 9 | 10 | packet "github.com/nickw444/miio-go/protocol/packet" 11 | 12 | product "github.com/nickw444/miio-go/device/product" 13 | 14 | subscriptioncommon "github.com/nickw444/miio-go/subscription/common" 15 | 16 | time "time" 17 | 18 | transport "github.com/nickw444/miio-go/protocol/transport" 19 | ) 20 | 21 | // Device is an autogenerated mock type for the Device type 22 | type Device struct { 23 | mock.Mock 24 | } 25 | 26 | // Close provides a mock function with given fields: 27 | func (_m *Device) Close() error { 28 | ret := _m.Called() 29 | 30 | var r0 error 31 | if rf, ok := ret.Get(0).(func() error); ok { 32 | r0 = rf() 33 | } else { 34 | r0 = ret.Error(0) 35 | } 36 | 37 | return r0 38 | } 39 | 40 | // CloseAllSubscriptions provides a mock function with given fields: 41 | func (_m *Device) CloseAllSubscriptions() error { 42 | ret := _m.Called() 43 | 44 | var r0 error 45 | if rf, ok := ret.Get(0).(func() error); ok { 46 | r0 = rf() 47 | } else { 48 | r0 = ret.Error(0) 49 | } 50 | 51 | return r0 52 | } 53 | 54 | // Discover provides a mock function with given fields: 55 | func (_m *Device) Discover() error { 56 | ret := _m.Called() 57 | 58 | var r0 error 59 | if rf, ok := ret.Get(0).(func() error); ok { 60 | r0 = rf() 61 | } else { 62 | r0 = ret.Error(0) 63 | } 64 | 65 | return r0 66 | } 67 | 68 | // GetInfo provides a mock function with given fields: 69 | func (_m *Device) GetInfo() (common.DeviceInfo, error) { 70 | ret := _m.Called() 71 | 72 | var r0 common.DeviceInfo 73 | if rf, ok := ret.Get(0).(func() common.DeviceInfo); ok { 74 | r0 = rf() 75 | } else { 76 | r0 = ret.Get(0).(common.DeviceInfo) 77 | } 78 | 79 | var r1 error 80 | if rf, ok := ret.Get(1).(func() error); ok { 81 | r1 = rf() 82 | } else { 83 | r1 = ret.Error(1) 84 | } 85 | 86 | return r0, r1 87 | } 88 | 89 | // GetLabel provides a mock function with given 
fields: 90 | func (_m *Device) GetLabel() (string, error) { 91 | ret := _m.Called() 92 | 93 | var r0 string 94 | if rf, ok := ret.Get(0).(func() string); ok { 95 | r0 = rf() 96 | } else { 97 | r0 = ret.Get(0).(string) 98 | } 99 | 100 | var r1 error 101 | if rf, ok := ret.Get(1).(func() error); ok { 102 | r1 = rf() 103 | } else { 104 | r1 = ret.Error(1) 105 | } 106 | 107 | return r0, r1 108 | } 109 | 110 | // GetProduct provides a mock function with given fields: 111 | func (_m *Device) GetProduct() (product.Product, error) { 112 | ret := _m.Called() 113 | 114 | var r0 product.Product 115 | if rf, ok := ret.Get(0).(func() product.Product); ok { 116 | r0 = rf() 117 | } else { 118 | r0 = ret.Get(0).(product.Product) 119 | } 120 | 121 | var r1 error 122 | if rf, ok := ret.Get(1).(func() error); ok { 123 | r1 = rf() 124 | } else { 125 | r1 = ret.Error(1) 126 | } 127 | 128 | return r0, r1 129 | } 130 | 131 | // GetToken provides a mock function with given fields: 132 | func (_m *Device) GetToken() []byte { 133 | ret := _m.Called() 134 | 135 | var r0 []byte 136 | if rf, ok := ret.Get(0).(func() []byte); ok { 137 | r0 = rf() 138 | } else { 139 | if ret.Get(0) != nil { 140 | r0 = ret.Get(0).([]byte) 141 | } 142 | } 143 | 144 | return r0 145 | } 146 | 147 | // Handle provides a mock function with given fields: _a0 148 | func (_m *Device) Handle(_a0 *packet.Packet) error { 149 | ret := _m.Called(_a0) 150 | 151 | var r0 error 152 | if rf, ok := ret.Get(0).(func(*packet.Packet) error); ok { 153 | r0 = rf(_a0) 154 | } else { 155 | r0 = ret.Error(0) 156 | } 157 | 158 | return r0 159 | } 160 | 161 | // HasSubscribers provides a mock function with given fields: 162 | func (_m *Device) HasSubscribers() bool { 163 | ret := _m.Called() 164 | 165 | var r0 bool 166 | if rf, ok := ret.Get(0).(func() bool); ok { 167 | r0 = rf() 168 | } else { 169 | r0 = ret.Get(0).(bool) 170 | } 171 | 172 | return r0 173 | } 174 | 175 | // ID provides a mock function with given fields: 176 | func (_m 
*Device) ID() uint32 { 177 | ret := _m.Called() 178 | 179 | var r0 uint32 180 | if rf, ok := ret.Get(0).(func() uint32); ok { 181 | r0 = rf() 182 | } else { 183 | r0 = ret.Get(0).(uint32) 184 | } 185 | 186 | return r0 187 | } 188 | 189 | // NewSubscription provides a mock function with given fields: 190 | func (_m *Device) NewSubscription() (subscriptioncommon.Subscription, error) { 191 | ret := _m.Called() 192 | 193 | var r0 subscriptioncommon.Subscription 194 | if rf, ok := ret.Get(0).(func() subscriptioncommon.Subscription); ok { 195 | r0 = rf() 196 | } else { 197 | if ret.Get(0) != nil { 198 | r0 = ret.Get(0).(subscriptioncommon.Subscription) 199 | } 200 | } 201 | 202 | var r1 error 203 | if rf, ok := ret.Get(1).(func() error); ok { 204 | r1 = rf() 205 | } else { 206 | r1 = ret.Error(1) 207 | } 208 | 209 | return r0, r1 210 | } 211 | 212 | // Outbound provides a mock function with given fields: 213 | func (_m *Device) Outbound() transport.Outbound { 214 | ret := _m.Called() 215 | 216 | var r0 transport.Outbound 217 | if rf, ok := ret.Get(0).(func() transport.Outbound); ok { 218 | r0 = rf() 219 | } else { 220 | if ret.Get(0) != nil { 221 | r0 = ret.Get(0).(transport.Outbound) 222 | } 223 | } 224 | 225 | return r0 226 | } 227 | 228 | // Provisional provides a mock function with given fields: 229 | func (_m *Device) Provisional() bool { 230 | ret := _m.Called() 231 | 232 | var r0 bool 233 | if rf, ok := ret.Get(0).(func() bool); ok { 234 | r0 = rf() 235 | } else { 236 | r0 = ret.Get(0).(bool) 237 | } 238 | 239 | return r0 240 | } 241 | 242 | // Publish provides a mock function with given fields: event 243 | func (_m *Device) Publish(event interface{}) error { 244 | ret := _m.Called(event) 245 | 246 | var r0 error 247 | if rf, ok := ret.Get(0).(func(interface{}) error); ok { 248 | r0 = rf(event) 249 | } else { 250 | r0 = ret.Error(0) 251 | } 252 | 253 | return r0 254 | } 255 | 256 | // RefreshThrottle provides a mock function with given fields: 257 | func (_m 
*Device) RefreshThrottle() <-chan struct{} { 258 | ret := _m.Called() 259 | 260 | var r0 <-chan struct{} 261 | if rf, ok := ret.Get(0).(func() <-chan struct{}); ok { 262 | r0 = rf() 263 | } else { 264 | if ret.Get(0) != nil { 265 | r0 = ret.Get(0).(<-chan struct{}) 266 | } 267 | } 268 | 269 | return r0 270 | } 271 | 272 | // RemoveSubscription provides a mock function with given fields: s 273 | func (_m *Device) RemoveSubscription(s subscriptioncommon.Subscription) error { 274 | ret := _m.Called(s) 275 | 276 | var r0 error 277 | if rf, ok := ret.Get(0).(func(subscriptioncommon.Subscription) error); ok { 278 | r0 = rf(s) 279 | } else { 280 | r0 = ret.Error(0) 281 | } 282 | 283 | return r0 284 | } 285 | 286 | // Seen provides a mock function with given fields: 287 | func (_m *Device) Seen() time.Time { 288 | ret := _m.Called() 289 | 290 | var r0 time.Time 291 | if rf, ok := ret.Get(0).(func() time.Time); ok { 292 | r0 = rf() 293 | } else { 294 | r0 = ret.Get(0).(time.Time) 295 | } 296 | 297 | return r0 298 | } 299 | 300 | // SetProvisional provides a mock function with given fields: _a0 301 | func (_m *Device) SetProvisional(_a0 bool) { 302 | _m.Called(_a0) 303 | } 304 | -------------------------------------------------------------------------------- /device/mocks/RefreshThrottle.go: -------------------------------------------------------------------------------- 1 | // Code generated by mockery v1.0.0. DO NOT EDIT. 
2 | 3 | package mocks 4 | 5 | import mock "github.com/stretchr/testify/mock" 6 | 7 | // RefreshThrottle is an autogenerated mock type for the RefreshThrottle type 8 | type RefreshThrottle struct { 9 | mock.Mock 10 | } 11 | 12 | // Chan provides a mock function with given fields: 13 | func (_m *RefreshThrottle) Chan() <-chan struct{} { 14 | ret := _m.Called() 15 | 16 | var r0 <-chan struct{} 17 | if rf, ok := ret.Get(0).(func() <-chan struct{}); ok { 18 | r0 = rf() 19 | } else { 20 | if ret.Get(0) != nil { 21 | r0 = ret.Get(0).(<-chan struct{}) 22 | } 23 | } 24 | 25 | return r0 26 | } 27 | 28 | // Close provides a mock function with given fields: 29 | func (_m *RefreshThrottle) Close() { 30 | _m.Called() 31 | } 32 | 33 | // Start provides a mock function with given fields: 34 | func (_m *RefreshThrottle) Start() { 35 | _m.Called() 36 | } 37 | 38 | // Stop provides a mock function with given fields: 39 | func (_m *RefreshThrottle) Stop() { 40 | _m.Called() 41 | } 42 | -------------------------------------------------------------------------------- /device/powerplug.go: -------------------------------------------------------------------------------- 1 | package device 2 | 3 | import ( 4 | "github.com/nickw444/miio-go/capability" 5 | "github.com/nickw444/miio-go/common" 6 | ) 7 | 8 | type PowerPlug struct { 9 | Device 10 | *capability.Power 11 | } 12 | 13 | func NewPowerPlug(device Device) *PowerPlug { 14 | dev := &PowerPlug{ 15 | Device: device, 16 | Power: capability.NewPower(device, device.Outbound()), 17 | } 18 | go dev.refresh() 19 | return dev 20 | } 21 | 22 | func (p *PowerPlug) refresh() { 23 | for range p.RefreshThrottle() { 24 | _ = p.Power.Update() 25 | } 26 | 27 | common.Log.Debug("Device refresh closed.") 28 | } 29 | -------------------------------------------------------------------------------- /device/powerplug_test.go: -------------------------------------------------------------------------------- 1 | package device 2 | 3 | // TODO (NW): Write Tests 4 
| -------------------------------------------------------------------------------- /device/product/product.go: -------------------------------------------------------------------------------- 1 | package product 2 | 3 | import "fmt" 4 | 5 | // Product identifies a supported device family. 6 | type Product uint16 7 | 8 | // Products are distinct bit flags (1, 2, 4). The zero value matches none of 9 | // them, so an unset Product is never mistaken for PowerPlug. 10 | const ( 11 | PowerPlug Product = 1 << iota 12 | Yeelight 13 | Unknown 14 | ) 15 | 16 | // GetModel maps a model name, such as the one reported by miIO.info, to a 17 | // Product. Unrecognised names yield Unknown together with an error. 18 | func GetModel(modelName string) (Product, error) { 19 | switch modelName { 20 | case "chuangmi.plug.m1": 21 | return PowerPlug, nil 22 | case "yeelink.light.color1": 23 | return Yeelight, nil 24 | default: 25 | return Unknown, fmt.Errorf("Unknown product for device type %s", modelName) 26 | } 27 | } 28 | -------------------------------------------------------------------------------- /device/product/product_test.go: -------------------------------------------------------------------------------- 1 | package product 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/assert" 7 | ) 8 | 9 | func TestGetModel(t *testing.T) { 10 | p, err := GetModel("chuangmi.plug.m1") 11 | assert.NoError(t, err) 12 | assert.Equal(t, PowerPlug, p) 13 | 14 | p, err = GetModel("fake") 15 | assert.Error(t, err) 16 | assert.Equal(t, Unknown, p) 17 | } 18 | -------------------------------------------------------------------------------- /device/rthrottle/mocks/RefreshThrottle.go: -------------------------------------------------------------------------------- 1 | // Code generated by mockery v1.0.0. DO NOT EDIT.
2 | 3 | package mocks 4 | 5 | import mock "github.com/stretchr/testify/mock" 6 | 7 | // RefreshThrottle is an autogenerated mock type for the RefreshThrottle type 8 | type RefreshThrottle struct { 9 | mock.Mock 10 | } 11 | 12 | // Chan provides a mock function with given fields: 13 | func (_m *RefreshThrottle) Chan() <-chan struct{} { 14 | ret := _m.Called() 15 | 16 | var r0 <-chan struct{} 17 | if rf, ok := ret.Get(0).(func() <-chan struct{}); ok { 18 | r0 = rf() 19 | } else { 20 | if ret.Get(0) != nil { 21 | r0 = ret.Get(0).(<-chan struct{}) 22 | } 23 | } 24 | 25 | return r0 26 | } 27 | 28 | // Close provides a mock function with given fields: 29 | func (_m *RefreshThrottle) Close() { 30 | _m.Called() 31 | } 32 | 33 | // Start provides a mock function with given fields: 34 | func (_m *RefreshThrottle) Start() { 35 | _m.Called() 36 | } 37 | 38 | // Stop provides a mock function with given fields: 39 | func (_m *RefreshThrottle) Stop() { 40 | _m.Called() 41 | } 42 | -------------------------------------------------------------------------------- /device/rthrottle/refresh_throttle.go: -------------------------------------------------------------------------------- 1 | package rthrottle 2 | 3 | import ( 4 | "time" 5 | 6 | "github.com/benbjohnson/clock" 7 | ) 8 | 9 | // RefreshThrottle emits refresh requests on Chan: one as soon as Start is 10 | // called and one per tick after that, until Stop. Close also closes Chan. 11 | type RefreshThrottle interface { 12 | Chan() <-chan struct{} 13 | Stop() 14 | Start() 15 | Close() 16 | } 17 | 18 | // TickerFactory creates the ticker that paces refresh requests. 19 | type TickerFactory func() *clock.Ticker 20 | 21 | type refreshThrottle struct { 22 | tickerFactory TickerFactory 23 | ch chan struct{} 24 | 25 | quitChan chan struct{} 26 | doneChan chan struct{} // closed by the refresh goroutine when it exits 27 | ticker *clock.Ticker 28 | } 29 | 30 | // NewRefreshThrottle returns a RefreshThrottle that, once started, requests 31 | // a refresh every refreshInterval. Chan is unbuffered. 
32 | func NewRefreshThrottle(refreshInterval time.Duration) RefreshThrottle { 33 | c := clock.New() 34 | return &refreshThrottle{ 35 | tickerFactory: func() *clock.Ticker { return c.Ticker(refreshInterval) }, 36 | ch: make(chan struct{}), 37 | } 38 | } 39 | 40 | func (r *refreshThrottle) Chan() <-chan struct{} { 41 | return r.ch 42 | } 43 | 44 | // Start begins emitting refresh requests; it is a no-op while already started. 45 | func (r *refreshThrottle) Start() { 46 | if r.ticker == nil { 47 | r.quitChan = make(chan struct{}) 48 | r.doneChan = make(chan struct{}) 49 | r.ticker = r.tickerFactory() 50 | // Hand the goroutine its own ticker and channels: Stop nils r.ticker and a 51 | // later Start replaces r.quitChan, so reading the fields could panic or leak. 52 | go r.refresh(r.ticker, r.quitChan, r.doneChan) 53 | } 54 | } 55 | 56 | // Stop halts refresh requests and waits for the refresh goroutine to exit, so 57 | // no send on Chan is in flight when it returns. It is a no-op if not started. 58 | func (r *refreshThrottle) Stop() { 59 | if r.ticker != nil { 60 | r.ticker.Stop() 61 | r.ticker = nil 62 | close(r.quitChan) 63 | <-r.doneChan 64 | } 65 | } 66 | 67 | // Close stops the throttle and closes Chan. Since Stop has waited for the 68 | // refresh goroutine, closing Chan cannot race with a pending send. 69 | func (r *refreshThrottle) Close() { 70 | r.Stop() 71 | close(r.ch) 72 | } 73 | 74 | // refresh requests a refresh immediately and then on every tick until quit is 75 | // closed. Every send also watches quit, so an idle consumer cannot block 76 | // shutdown. done is closed on return. 
77 | func (r *refreshThrottle) refresh(ticker *clock.Ticker, quit <-chan struct{}, done chan<- struct{}) { 78 | defer close(done) 79 | 80 | // Request a refresh immediately. 81 | if !r.send(quit) { 82 | return 83 | } 84 | 85 | for { 86 | // Check quit first so a pending tick cannot delay shutdown. 87 | select { 88 | case <-quit: 89 | return 90 | default: 91 | } 92 | 93 | select { 94 | case <-quit: 95 | return 96 | case <-ticker.C: 97 | if !r.send(quit) { 98 | return 99 | } 100 | } 101 | } 102 | } 103 | 104 | // send delivers one refresh request; it returns false if quit closed first. 105 | func (r *refreshThrottle) send(quit <-chan struct{}) bool { 106 | select { 107 | case r.ch <- struct{}{}: 108 | return true 109 | case <-quit: 110 | return false 111 | } 112 | } 113 | -------------------------------------------------------------------------------- /device/rthrottle/refresh_throttle_test.go: -------------------------------------------------------------------------------- 1 | package rthrottle 2 | 3 | import ( 4 | "sync" 5 | "testing" 6 | "time" 7 | 8 | "github.com/benbjohnson/clock" 9 | "github.com/stretchr/testify/assert" 10 | ) 11 | 12 | func RefreshThrottle_Setup() (tt struct { 13 | refreshInterval time.Duration 14 | clk *clock.Mock 15 | throttle *refreshThrottle 16 | }) { 17 | tt.refreshInterval = 5 * time.Second 18 | tt.clk = clock.NewMock() 19 | tt.throttle = &refreshThrottle{ 20 | tickerFactory: func() *clock.Ticker { return tt.clk.Ticker(tt.refreshInterval) }, 21 | ch: make(chan struct{}, 2), 22 | } 23 | 24 | return 25 | } 26 | 27 | //does not close Chan until Close is called 28 | func TestRefreshThrottle_ChanDoesNotCloseUntilClose(t *testing.T) { 29 | tt := RefreshThrottle_Setup() 30 | 31 | wg := sync.WaitGroup{} 32 | ch := tt.throttle.Chan() 33 | closed := false 34 | 35 | wg.Add(1) 36 | 37 | go func() { 38 | ticks := 0 39 | for range ch { 40 | ticks++ 41 | } 42 | 43 | assert.True(t, ticks > 1, "Expected at least 2 ticks before channel close.") 44 | 45 | if !closed { 46 | t.Error("Expected channel to not have been closed yet.") 47 | } 48 | 49 |
wg.Done() 50 | }() 51 | 52 | tt.throttle.Start() 53 | tt.clk.Add(tt.refreshInterval) 54 | 55 | closed = true 56 | tt.throttle.Close() 57 | wg.Wait() 58 | } 59 | 60 | // Test to ensure that stop kills all clock ticks 61 | func TestRefreshThrottle_Stop(t *testing.T) { 62 | tt := RefreshThrottle_Setup() 63 | 64 | ch := tt.throttle.Chan() 65 | tt.throttle.Start() 66 | <-ch // Clear initial tick. 67 | tt.clk.Add(tt.refreshInterval) 68 | race(t, ch) 69 | 70 | tt.throttle.Stop() 71 | tt.clk.Add(tt.refreshInterval) 72 | assert.Len(t, ch, 0) 73 | 74 | tt.clk.Add(tt.refreshInterval) 75 | assert.Len(t, ch, 0) 76 | } 77 | 78 | // Test to ensure that start starts the throttle ticking. 79 | func TestRefreshThrottle_Start(t *testing.T) { 80 | tt := RefreshThrottle_Setup() 81 | 82 | ch := tt.throttle.Chan() 83 | assert.Empty(t, ch) 84 | tt.throttle.Start() 85 | 86 | // Expect an initial event. 87 | race(t, ch) 88 | 89 | // Expect an event after refresh interval 90 | tt.clk.Add(tt.refreshInterval) 91 | race(t, ch) 92 | 93 | // Expect no more ticks until refresh interval 94 | assert.Len(t, ch, 0) 95 | tt.clk.Add(tt.refreshInterval - time.Second) 96 | assert.Len(t, ch, 0) 97 | tt.clk.Add(time.Second) 98 | assert.Len(t, ch, 1) 99 | } 100 | 101 | // Test to ensure that start starts the throttle ticking after being stopped. 102 | func TestRefreshThrottle_StartAfterStop(t *testing.T) { 103 | tt := RefreshThrottle_Setup() 104 | 105 | ch := tt.throttle.Chan() 106 | tt.throttle.Start() 107 | <-ch // Clear initial tick. 108 | tt.clk.Add(tt.refreshInterval) 109 | race(t, ch) 110 | 111 | tt.throttle.Stop() 112 | tt.clk.Add(tt.refreshInterval) 113 | assert.Len(t, ch, 0) 114 | 115 | tt.clk.Add(tt.refreshInterval) 116 | assert.Len(t, ch, 0) 117 | 118 | tt.throttle.Start() 119 | <-ch // Clear initial tick.
120 | tt.clk.Add(tt.refreshInterval) 121 | race(t, ch) 122 | 123 | tt.clk.Add(tt.refreshInterval) 124 | assert.Len(t, ch, 1) 125 | } 126 | 127 | func race(t *testing.T, ch <-chan struct{}) { 128 | select { 129 | case <-ch: 130 | return 131 | case <-time.After(10 * time.Second): 132 | t.Error("Timed out whilst waiting for channel.") 133 | } 134 | } 135 | -------------------------------------------------------------------------------- /device/yeelight.go: -------------------------------------------------------------------------------- 1 | package device 2 | 3 | import ( 4 | "github.com/nickw444/miio-go/capability" 5 | "github.com/nickw444/miio-go/common" 6 | ) 7 | 8 | type Yeelight struct { 9 | Device 10 | *capability.Light 11 | *capability.Power 12 | } 13 | 14 | func NewYeelight(device Device) *Yeelight { 15 | dev := &Yeelight{ 16 | Device: device, 17 | Power: capability.NewPower(device, device.Outbound()), 18 | Light: capability.NewLight(device, device.Outbound()), 19 | } 20 | go dev.refresh() 21 | return dev 22 | } 23 | 24 | func (p *Yeelight) refresh() { 25 | for range p.RefreshThrottle() { 26 | _ = p.Power.Update() 27 | _ = p.Light.Update() 28 | } 29 | 30 | common.Log.Debug("Device refresh closed.") 31 | } 32 | -------------------------------------------------------------------------------- /device/yeelight_test.go: -------------------------------------------------------------------------------- 1 | package device 2 | 3 | // TODO (NW): Write Tests 4 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/nickw444/miio-go 2 | 3 | go 1.12 4 | 5 | require ( 6 | github.com/alecthomas/kingpin v2.2.6+incompatible 7 | github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751 8 | github.com/alecthomas/units v0.0.0-20190717042225-c3de453c63f4 9 | github.com/benbjohnson/clock v0.0.0-20161215174838-7dc76406b6d3 10 |
github.com/davecgh/go-spew v1.1.1 11 | github.com/haya14busa/goverage v0.0.0-20180129164344-eec3514a20b5 // indirect 12 | github.com/konsorten/go-windows-terminal-sequences v1.0.2 // indirect 13 | github.com/lunixbochs/struc v0.0.0-20190326164542-a9e4041416c2 14 | github.com/mattn/goveralls v0.0.2 // indirect 15 | github.com/pmezard/go-difflib v1.0.0 16 | github.com/satori/go.uuid v1.2.0 17 | github.com/sirupsen/logrus v1.4.2 18 | github.com/stretchr/objx v0.2.0 19 | github.com/stretchr/testify v1.4.0 20 | github.com/vektra/mockery v0.0.0-20181123154057-e78b021dcbb5 // indirect 21 | golang.org/x/crypto v0.0.0-20190820162420-60c769a6c586 22 | golang.org/x/net v0.0.0-20190813141303-74dc4d7220e7 // indirect 23 | golang.org/x/sys v0.0.0-20190825160603-fb81701db80f 24 | golang.org/x/text v0.3.2 // indirect 25 | golang.org/x/tools v0.0.0-20190825031127-d72b05d2b1b6 // indirect 26 | gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 // indirect 27 | ) 28 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/alecthomas/kingpin v2.2.6+incompatible h1:5svnBTFgJjZvGKyYBtMB0+m5wvrbUHiqye8wRJMlnYI= 2 | github.com/alecthomas/kingpin v2.2.6+incompatible/go.mod h1:59OFYbFVLKQKq+mqrL6Rw5bR0c3ACQaawgXx0QYndlE= 3 | github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc h1:cAKDfWh5VpdgMhJosfJnn5/FoN2SRZ4p7fJNX58YPaU= 4 | github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= 5 | github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751 h1:JYp7IbQjafoB+tBA3gMyHYHrpOtNuDiK/uB5uXxq5wM= 6 | github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= 7 | github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf h1:qet1QNfXsQxTZqLG4oE62mJzwPIB8+Tee4RNCL9ulrY= 8 | github.com/alecthomas/units 
v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0= 9 | github.com/alecthomas/units v0.0.0-20190717042225-c3de453c63f4 h1:Hs82Z41s6SdL1CELW+XaDYmOH4hkBN4/N9og/AsOv7E= 10 | github.com/alecthomas/units v0.0.0-20190717042225-c3de453c63f4/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0= 11 | github.com/benbjohnson/clock v0.0.0-20161215174838-7dc76406b6d3 h1:wOysYcIdqv3WnvwqFFzrYCFALPED7qkUGaLXu359GSc= 12 | github.com/benbjohnson/clock v0.0.0-20161215174838-7dc76406b6d3/go.mod h1:UMqtWQTnOe4byzwe7Zhwh8f8s+36uszN51sJrSIZlTE= 13 | github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= 14 | github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 15 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= 16 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 17 | github.com/haya14busa/goverage v0.0.0-20180129164344-eec3514a20b5 h1:FdBGmSkD2QpQzRWup//SGObvWf2nq89zj9+ta9OvI3A= 18 | github.com/haya14busa/goverage v0.0.0-20180129164344-eec3514a20b5/go.mod h1:0YZ2wQSuwviXXXGUiK6zXzskyBLAbLXhamxzcFHSLoM= 19 | github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= 20 | github.com/konsorten/go-windows-terminal-sequences v1.0.2/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= 21 | github.com/lunixbochs/struc v0.0.0-20180408203800-02e4c2afbb2a h1:axFx97V2Lyke5LbeygrJlzc07mwVhHt2ZHeI/Nv8Aq4= 22 | github.com/lunixbochs/struc v0.0.0-20180408203800-02e4c2afbb2a/go.mod h1:iOJu9pApjjmEmNq7PqlA5R9mDu/HMF5EM3llWKX/TyA= 23 | github.com/lunixbochs/struc v0.0.0-20190326164542-a9e4041416c2 h1:xvBq0/ARZLqmB57m6jds017I+KtXPcsKBHv6dUUac4A= 24 | github.com/lunixbochs/struc v0.0.0-20190326164542-a9e4041416c2/go.mod h1:iOJu9pApjjmEmNq7PqlA5R9mDu/HMF5EM3llWKX/TyA= 25 | github.com/mattn/goveralls v0.0.2 h1:7eJB6EqsPhRVxvwEXGnqdO2sJI0PTsrWoTMXEk9/OQc= 26 | 
github.com/mattn/goveralls v0.0.2/go.mod h1:8d1ZMHsd7fW6IRPKQh46F2WRpyib5/X4FOpevwGNQEw= 27 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 28 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 29 | github.com/satori/go.uuid v1.2.0 h1:0uYX9dsZ2yD7q2RtLRtPSdGDWzjeM3TbMJP9utgA0ww= 30 | github.com/satori/go.uuid v1.2.0/go.mod h1:dA0hQrYB0VpLJoorglMZABFdXlWrHn1NEOzdhQKdks0= 31 | github.com/sirupsen/logrus v1.0.6 h1:hcP1GmhGigz/O7h1WVUM5KklBp1JoNS9FggWKdj/j3s= 32 | github.com/sirupsen/logrus v1.0.6/go.mod h1:pMByvHTf9Beacp5x1UXfOR9xyW/9antXMhjMPG0dEzc= 33 | github.com/sirupsen/logrus v1.4.2 h1:SPIRibHv4MatM3XXNO2BJeFLZwZ2LvZgfQ5+UNI2im4= 34 | github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= 35 | github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= 36 | github.com/stretchr/objx v0.1.1 h1:2vfRuCMp5sSVIDSqO8oNnWJq7mPa6KVP3iPIwFBuy8A= 37 | github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= 38 | github.com/stretchr/objx v0.2.0 h1:Hbg2NidpLE8veEBkEZTL3CvlkUIVzuU9jDplZO54c48= 39 | github.com/stretchr/objx v0.2.0/go.mod h1:qt09Ya8vawLte6SNmTgCsAVtYtaKzEcn8ATUoHMkEqE= 40 | github.com/stretchr/testify v1.2.2 h1:bSDNvY7ZPG5RlJ8otE/7V6gMiyenm9RtJ7IUVIAoJ1w= 41 | github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= 42 | github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= 43 | github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk= 44 | github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= 45 | github.com/vektra/mockery v0.0.0-20181123154057-e78b021dcbb5 h1:Xim2mBRFdXzXmKRO8DJg/FJtn/8Fj9NOEpO6+WuMPmk= 46 | github.com/vektra/mockery v0.0.0-20181123154057-e78b021dcbb5/go.mod h1:ppEjwdhyy7Y31EnHRDm1JkChoC7LXIJ7Ex0VYLWtZtQ= 47 | golang.org/x/crypto 
v0.0.0-20180723164146-c126467f60eb/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= 48 | golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2 h1:VklqNMn3ovrHsnt90PveolxSbWFaJdECFbxSq0Mqo2M= 49 | golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= 50 | golang.org/x/crypto v0.0.0-20190820162420-60c769a6c586 h1:7KByu05hhLed2MO29w7p1XfZvZ13m8mub3shuVftRs0= 51 | golang.org/x/crypto v0.0.0-20190820162420-60c769a6c586/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= 52 | golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= 53 | golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= 54 | golang.org/x/net v0.0.0-20190813141303-74dc4d7220e7/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= 55 | golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= 56 | golang.org/x/sys v0.0.0-20180727230415-bd9dbc187b6e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= 57 | golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a h1:1BGLXjeY4akVXGgbC9HugT3Jv3hCI0z56oJR5vAMgBU= 58 | golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= 59 | golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 60 | golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 61 | golang.org/x/sys v0.0.0-20190825160603-fb81701db80f h1:LCxigP8q3fPRGNVYndYsyHnF0zRrvcoVwZMfb8iQZe4= 62 | golang.org/x/sys v0.0.0-20190825160603-fb81701db80f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 63 | golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= 64 | golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= 65 | golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod 
h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= 66 | golang.org/x/tools v0.0.0-20181112210238-4b1f3b6b1646/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= 67 | golang.org/x/tools v0.0.0-20190825031127-d72b05d2b1b6 h1:l//7Uxa3g2EhYJdlFcJzXf/oPre+P1nB/0cxBRSZGWU= 68 | golang.org/x/tools v0.0.0-20190825031127-d72b05d2b1b6/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= 69 | golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= 70 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 71 | gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 72 | gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= 73 | gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= 74 | -------------------------------------------------------------------------------- /protocol/mocks/Protocol.go: -------------------------------------------------------------------------------- 1 | // Code generated by mockery v1.0.0. DO NOT EDIT. 
2 | 3 | package mocks 4 | 5 | import ( 6 | common "github.com/nickw444/miio-go/subscription/common" 7 | mock "github.com/stretchr/testify/mock" 8 | 9 | time "time" 10 | ) 11 | 12 | // Protocol is an autogenerated mock type for the Protocol type 13 | type Protocol struct { 14 | mock.Mock 15 | } 16 | 17 | // CloseAllSubscriptions provides a mock function with given fields: 18 | func (_m *Protocol) CloseAllSubscriptions() error { 19 | ret := _m.Called() 20 | 21 | var r0 error 22 | if rf, ok := ret.Get(0).(func() error); ok { 23 | r0 = rf() 24 | } else { 25 | r0 = ret.Error(0) 26 | } 27 | 28 | return r0 29 | } 30 | 31 | // Discover provides a mock function with given fields: 32 | func (_m *Protocol) Discover() error { 33 | ret := _m.Called() 34 | 35 | var r0 error 36 | if rf, ok := ret.Get(0).(func() error); ok { 37 | r0 = rf() 38 | } else { 39 | r0 = ret.Error(0) 40 | } 41 | 42 | return r0 43 | } 44 | 45 | // HasSubscribers provides a mock function with given fields: 46 | func (_m *Protocol) HasSubscribers() bool { 47 | ret := _m.Called() 48 | 49 | var r0 bool 50 | if rf, ok := ret.Get(0).(func() bool); ok { 51 | r0 = rf() 52 | } else { 53 | r0 = ret.Get(0).(bool) 54 | } 55 | 56 | return r0 57 | } 58 | 59 | // NewSubscription provides a mock function with given fields: 60 | func (_m *Protocol) NewSubscription() (common.Subscription, error) { 61 | ret := _m.Called() 62 | 63 | var r0 common.Subscription 64 | if rf, ok := ret.Get(0).(func() common.Subscription); ok { 65 | r0 = rf() 66 | } else { 67 | if ret.Get(0) != nil { 68 | r0 = ret.Get(0).(common.Subscription) 69 | } 70 | } 71 | 72 | var r1 error 73 | if rf, ok := ret.Get(1).(func() error); ok { 74 | r1 = rf() 75 | } else { 76 | r1 = ret.Error(1) 77 | } 78 | 79 | return r0, r1 80 | } 81 | 82 | // Publish provides a mock function with given fields: event 83 | func (_m *Protocol) Publish(event interface{}) error { 84 | ret := _m.Called(event) 85 | 86 | var r0 error 87 | if rf, ok := ret.Get(0).(func(interface{}) 
error); ok { 88 | r0 = rf(event) 89 | } else { 90 | r0 = ret.Error(0) 91 | } 92 | 93 | return r0 94 | } 95 | 96 | // RemoveSubscription provides a mock function with given fields: s 97 | func (_m *Protocol) RemoveSubscription(s common.Subscription) error { 98 | ret := _m.Called(s) 99 | 100 | var r0 error 101 | if rf, ok := ret.Get(0).(func(common.Subscription) error); ok { 102 | r0 = rf(s) 103 | } else { 104 | r0 = ret.Error(0) 105 | } 106 | 107 | return r0 108 | } 109 | 110 | // SetExpiryTime provides a mock function with given fields: duration 111 | func (_m *Protocol) SetExpiryTime(duration time.Duration) { 112 | _m.Called(duration) 113 | } 114 | -------------------------------------------------------------------------------- /protocol/packet/crypto.go: -------------------------------------------------------------------------------- 1 | package packet 2 | 3 | import ( 4 | "bytes" 5 | "crypto/aes" 6 | "crypto/cipher" 7 | "crypto/md5" 8 | "fmt" 9 | "time" 10 | 11 | "github.com/benbjohnson/clock" 12 | ) 13 | 14 | type Crypto interface { 15 | VerifyPacket(pkt *Packet) error 16 | Decrypt(data []byte) ([]byte, error) 17 | Encrypt(data []byte) ([]byte, error) 18 | NewPacket(data []byte) (*Packet, error) 19 | } 20 | 21 | type crypto struct { 22 | iv []byte 23 | key []byte 24 | deviceId uint32 25 | deviceToken []byte 26 | initialStamp uint32 27 | stampTime time.Time 28 | clock clock.Clock 29 | } 30 | 31 | func NewCrypto(deviceID uint32, deviceToken []byte, initialStamp uint32, stampTime time.Time, clock clock.Clock) ( 32 | Crypto, error) { 33 | 34 | hash := md5.New() 35 | _, err := hash.Write(deviceToken) 36 | if err != nil { 37 | return nil, err 38 | } 39 | key := hash.Sum(nil) 40 | 41 | hash = md5.New() 42 | _, err = hash.Write(key) 43 | if err != nil { 44 | return nil, err 45 | } 46 | _, err = hash.Write(deviceToken) 47 | if err != nil { 48 | return nil, err 49 | } 50 | iv := hash.Sum(nil) 51 | 52 | return &crypto{ 53 | deviceId: deviceID, 54 | deviceToken: 
deviceToken, 55 | initialStamp: initialStamp, 56 | stampTime: stampTime, 57 | clock: clock, 58 | 59 | iv: iv, 60 | key: key, 61 | }, nil 62 | } 63 | 64 | func (c *crypto) VerifyPacket(pkt *Packet) error { 65 | // Verify the checksum. 66 | return pkt.Verify(c.deviceToken) 67 | } 68 | 69 | func (c *crypto) Decrypt(data []byte) ([]byte, error) { 70 | block, err := aes.NewCipher(c.key) 71 | if err != nil { 72 | return nil, err 73 | } 74 | stream := cipher.NewCBCDecrypter(block, c.iv) 75 | decrypted := make([]byte, len(data)) 76 | stream.CryptBlocks(decrypted, data) 77 | 78 | return pkcs5Unpad(decrypted, block.BlockSize()) 79 | } 80 | 81 | func (c *crypto) Encrypt(data []byte) ([]byte, error) { 82 | block, err := aes.NewCipher(c.key) 83 | if err != nil { 84 | return nil, err 85 | } 86 | 87 | data = pkcs5Pad(data, block.BlockSize()) 88 | stream := cipher.NewCBCEncrypter(block, c.iv) 89 | 90 | encrypted := make([]byte, len(data)) 91 | stream.CryptBlocks(encrypted, []byte(data)) 92 | return encrypted, nil 93 | } 94 | 95 | func (c *crypto) getStamp() uint32 { 96 | return uint32(c.clock.Now().Sub(c.stampTime).Seconds()) + c.initialStamp 97 | } 98 | 99 | func (c *crypto) NewPacket(data []byte) (*Packet, error) { 100 | encrypted, err := c.Encrypt(data) 101 | if err != nil { 102 | return nil, err 103 | } 104 | 105 | stamp := c.getStamp() 106 | 107 | p := New(c.deviceId, c.deviceToken, stamp, encrypted) 108 | err = p.WriteChecksum() 109 | 110 | if err != nil { 111 | return nil, err 112 | } 113 | 114 | return p, nil 115 | } 116 | 117 | // Pad using PKCS5 padding scheme. 118 | func pkcs5Pad(data []byte, blockSize int) []byte { 119 | length := len(data) 120 | padLength := (blockSize - (length % blockSize)) 121 | pad := bytes.Repeat([]byte{byte(padLength)}, padLength) 122 | return append(data, pad...) 123 | } 124 | 125 | // Unpad using PKCS5 padding scheme. 
126 | func pkcs5Unpad(data []byte, blockSize int) ([]byte, error) { 127 | srcLen := len(data) 128 | paddingLen := int(data[srcLen-1]) 129 | if paddingLen >= srcLen || paddingLen > blockSize { 130 | return nil, fmt.Errorf("Padding size error whilst decrypting payload. Src Length: %d, Padding Length: %d, Block Size: %d", srcLen, paddingLen, blockSize) 131 | } 132 | return data[:srcLen-paddingLen], nil 133 | } 134 | -------------------------------------------------------------------------------- /protocol/packet/crypto_test.go: -------------------------------------------------------------------------------- 1 | package packet 2 | 3 | import ( 4 | "bytes" 5 | "testing" 6 | 7 | "github.com/stretchr/testify/assert" 8 | ) 9 | 10 | // TODO (NW): Write more tests 11 | 12 | func TestPkcs5Pad1(t *testing.T) { 13 | data := bytes.Repeat([]byte{0xff}, 16) 14 | result := pkcs5Pad(data, 16) 15 | assert.Len(t, result, 32) 16 | assert.Equal(t, append(data, bytes.Repeat([]byte{16}, 16)...), result) 17 | } 18 | 19 | func TestPkcs5Pad2(t *testing.T) { 20 | data := bytes.Repeat([]byte{0xff}, 15) 21 | result := pkcs5Pad(data, 16) 22 | assert.Len(t, result, 16) 23 | assert.Equal(t, append(data, 0x01), result) 24 | } 25 | 26 | func TestPkcs5Unpad1(t *testing.T) { 27 | data := bytes.Repeat([]byte{0xff}, 16) 28 | padded := append(data, bytes.Repeat([]byte{16}, 16)...) 
29 | 30 | result, err := pkcs5Unpad(padded, 16) 31 | assert.NoError(t, err) 32 | assert.Len(t, result, 16) 33 | assert.Equal(t, data, result) 34 | } 35 | 36 | func TestPkcs5Unpad2(t *testing.T) { 37 | data := bytes.Repeat([]byte{0xff}, 15) 38 | padded := append(data, 0x01) 39 | 40 | result, err := pkcs5Unpad(padded, 16) 41 | assert.NoError(t, err) 42 | assert.Len(t, result, 15) 43 | assert.Equal(t, data, result) 44 | } 45 | -------------------------------------------------------------------------------- /protocol/packet/mocks/Crypto.go: -------------------------------------------------------------------------------- 1 | // Code generated by mockery v1.0.0. DO NOT EDIT. 2 | 3 | package mocks 4 | 5 | import ( 6 | packet "github.com/nickw444/miio-go/protocol/packet" 7 | mock "github.com/stretchr/testify/mock" 8 | ) 9 | 10 | // Crypto is an autogenerated mock type for the Crypto type 11 | type Crypto struct { 12 | mock.Mock 13 | } 14 | 15 | // Decrypt provides a mock function with given fields: data 16 | func (_m *Crypto) Decrypt(data []byte) ([]byte, error) { 17 | ret := _m.Called(data) 18 | 19 | var r0 []byte 20 | if rf, ok := ret.Get(0).(func([]byte) []byte); ok { 21 | r0 = rf(data) 22 | } else { 23 | if ret.Get(0) != nil { 24 | r0 = ret.Get(0).([]byte) 25 | } 26 | } 27 | 28 | var r1 error 29 | if rf, ok := ret.Get(1).(func([]byte) error); ok { 30 | r1 = rf(data) 31 | } else { 32 | r1 = ret.Error(1) 33 | } 34 | 35 | return r0, r1 36 | } 37 | 38 | // Encrypt provides a mock function with given fields: data 39 | func (_m *Crypto) Encrypt(data []byte) ([]byte, error) { 40 | ret := _m.Called(data) 41 | 42 | var r0 []byte 43 | if rf, ok := ret.Get(0).(func([]byte) []byte); ok { 44 | r0 = rf(data) 45 | } else { 46 | if ret.Get(0) != nil { 47 | r0 = ret.Get(0).([]byte) 48 | } 49 | } 50 | 51 | var r1 error 52 | if rf, ok := ret.Get(1).(func([]byte) error); ok { 53 | r1 = rf(data) 54 | } else { 55 | r1 = ret.Error(1) 56 | } 57 | 58 | return r0, r1 59 | } 60 | 61 | // 
NewPacket provides a mock function with given fields: data 62 | func (_m *Crypto) NewPacket(data []byte) (*packet.Packet, error) { 63 | ret := _m.Called(data) 64 | 65 | var r0 *packet.Packet 66 | if rf, ok := ret.Get(0).(func([]byte) *packet.Packet); ok { 67 | r0 = rf(data) 68 | } else { 69 | if ret.Get(0) != nil { 70 | r0 = ret.Get(0).(*packet.Packet) 71 | } 72 | } 73 | 74 | var r1 error 75 | if rf, ok := ret.Get(1).(func([]byte) error); ok { 76 | r1 = rf(data) 77 | } else { 78 | r1 = ret.Error(1) 79 | } 80 | 81 | return r0, r1 82 | } 83 | 84 | // VerifyPacket provides a mock function with given fields: pkt 85 | func (_m *Crypto) VerifyPacket(pkt *packet.Packet) error { 86 | ret := _m.Called(pkt) 87 | 88 | var r0 error 89 | if rf, ok := ret.Get(0).(func(*packet.Packet) error); ok { 90 | r0 = rf(pkt) 91 | } else { 92 | r0 = ret.Error(0) 93 | } 94 | 95 | return r0 96 | } 97 | -------------------------------------------------------------------------------- /protocol/packet/packet.go: -------------------------------------------------------------------------------- 1 | package packet 2 | 3 | import ( 4 | "bytes" 5 | "crypto/md5" 6 | "encoding/hex" 7 | "fmt" 8 | "time" 9 | 10 | "net" 11 | 12 | "github.com/lunixbochs/struc" 13 | ) 14 | 15 | const checksumLengthBytes = 16 16 | 17 | // See https://github.com/OpenMiHome/mihome-binary-protocol/blob/master/doc/PROTOCOL.md for 18 | // documentation 19 | type Header struct { 20 | Magic uint16 21 | Length uint16 22 | F1 uint32 // Unknown field. 23 | DeviceID uint32 24 | Stamp uint32 25 | Checksum []byte `struc:"[16]byte"` 26 | } 27 | 28 | // Meta provides (optional) additional context about incoming packets. 
29 | type Meta struct { 30 | DecodeTime time.Time 31 | Addr *net.UDPAddr 32 | } 33 | 34 | type Packet struct { 35 | Meta Meta 36 | Header Header 37 | Data []byte 38 | } 39 | 40 | func (p *Packet) Serialize() []byte { 41 | var buf bytes.Buffer 42 | err := struc.Pack(&buf, &p.Header) 43 | if err != nil { 44 | panic(err) 45 | } 46 | 47 | buf.Write(p.Data) 48 | return buf.Bytes() 49 | } 50 | 51 | func (p *Packet) CalcChecksum() ([]byte, error) { 52 | h := md5.New() 53 | _, err := h.Write(p.Serialize()) 54 | if err != nil { 55 | return nil, err 56 | } 57 | 58 | return h.Sum(nil), nil 59 | } 60 | 61 | func (p *Packet) WriteChecksum() error { 62 | checksum, err := p.CalcChecksum() 63 | if err != nil { 64 | return err 65 | } 66 | 67 | p.Header.Checksum = checksum 68 | return nil 69 | } 70 | 71 | func (p *Packet) DataLength() int { 72 | return len(p.Data) 73 | } 74 | 75 | func (p *Packet) HasZeroChecksum() bool { 76 | for _, b := range p.Header.Checksum { 77 | if b != 0 { 78 | return false 79 | } 80 | } 81 | return true 82 | } 83 | 84 | func (p *Packet) Verify(deviceToken []byte) error { 85 | var tmpPacket Packet 86 | tmpPacket = *p 87 | tmpPacket.Header.Checksum = deviceToken 88 | 89 | calculated, err := tmpPacket.CalcChecksum() 90 | if err != nil { 91 | return err 92 | } 93 | 94 | if !bytes.Equal(calculated, p.Header.Checksum) { 95 | return fmt.Errorf("Checksum could not be verified. 
Expected %s, got %s.", 96 | hex.EncodeToString(calculated), hex.EncodeToString(p.Header.Checksum)) 97 | } 98 | return nil 99 | } 100 | 101 | func Decode(data []byte, addr *net.UDPAddr) (*Packet, error) { 102 | meta := Meta{DecodeTime: time.Now(), Addr: addr} 103 | header := Header{} 104 | struc.Unpack(bytes.NewBuffer(data[:32]), &header) 105 | 106 | p := &Packet{ 107 | Meta: meta, 108 | Header: header, 109 | Data: data[32:], 110 | } 111 | return p, nil 112 | } 113 | 114 | // New creates a new packet 115 | func New(deviceId uint32, deviceToken []byte, stamp uint32, data []byte) *Packet { 116 | header := Header{ 117 | Magic: 0x2131, 118 | Length: uint16(32 + len(data)), 119 | F1: 0x0, 120 | DeviceID: deviceId, 121 | Stamp: stamp, 122 | Checksum: deviceToken, 123 | } 124 | 125 | p := &Packet{ 126 | Header: header, 127 | Data: data, 128 | } 129 | return p 130 | } 131 | 132 | func NewHello() *Packet { 133 | checksum := bytes.Repeat([]byte{0xff}, 16) 134 | return &Packet{ 135 | Header: Header{ 136 | Magic: 0x2131, 137 | Length: 0x0020, 138 | F1: 0xffffffff, 139 | DeviceID: 0xffffffff, 140 | Stamp: 0xffffffff, 141 | Checksum: checksum, 142 | }, 143 | } 144 | } 145 | -------------------------------------------------------------------------------- /protocol/packet/packet_test.go: -------------------------------------------------------------------------------- 1 | package packet 2 | 3 | import ( 4 | "bytes" 5 | "encoding/binary" 6 | "net" 7 | "testing" 8 | 9 | "github.com/stretchr/testify/assert" 10 | ) 11 | 12 | var payload []byte 13 | var deviceToken []byte 14 | var packet *Packet 15 | 16 | func setUp(t *testing.T) func(t *testing.T) { 17 | payload = []byte("Hello World") 18 | deviceToken = bytes.Repeat([]byte{0xFF, 0x00}, checksumLengthBytes/2) 19 | packet = New(0xAAAABBBB, deviceToken, 0xCCCCDDDD, payload) 20 | return func(t *testing.T) { 21 | payload = nil 22 | deviceToken = nil 23 | packet = nil 24 | } 25 | } 26 | 27 | // Ensure a known packet decodes and re-serializes 
to the same value 28 | func TestDecode(t *testing.T) { 29 | tearDown := setUp(t) 30 | defer tearDown(t) 31 | 32 | data := packet.Serialize() 33 | newPkt, err := Decode(data, &net.UDPAddr{}) 34 | assert.NoError(t, err) 35 | 36 | // Ensure the new packet serializes the same 37 | newData := newPkt.Serialize() 38 | assert.Equal(t, data, newData) 39 | } 40 | 41 | // Ensure that serialize orders fields correctly 42 | func TestPacket_Serialize(t *testing.T) { 43 | tearDown := setUp(t) 44 | defer tearDown(t) 45 | 46 | data := packet.Serialize() 47 | assert.Equal(t, []byte{0x21, 0x31}, data[0:2]) 48 | assert.Equal(t, uint16(len(payload)+32), binary.BigEndian.Uint16(data[2:4])) 49 | assert.Equal(t, uint32(0), binary.BigEndian.Uint32(data[4:8])) 50 | assert.Equal(t, uint32(0xAAAABBBB), binary.BigEndian.Uint32(data[8:12])) 51 | assert.Equal(t, uint32(0xCCCCDDDD), binary.BigEndian.Uint32(data[12:16])) 52 | assert.Equal(t, deviceToken, data[16:32]) 53 | } 54 | 55 | // Ensure that CalcChecksum outputs a known good value. 56 | func TestPacket_CalcChecksum(t *testing.T) { 57 | tearDown := setUp(t) 58 | defer tearDown(t) 59 | 60 | chk, err := packet.CalcChecksum() 61 | assert.NoError(t, err) 62 | 63 | assert.Equal(t, []byte{0x41, 0x7a, 0x35, 0xb4, 0x21, 0x5c, 0x64, 0xad, 0xd5, 0xe0, 0xcd, 0x3f, 0x51, 0x47, 0xf5, 0xc2}, chk) 64 | } 65 | 66 | // Ensure that the written checksum is the same value that is calculated. 67 | func TestPacket_WriteChecksum(t *testing.T) { 68 | tearDown := setUp(t) 69 | defer tearDown(t) 70 | 71 | packet.Header.Checksum = bytes.Repeat([]byte{0xFF}, checksumLengthBytes) 72 | 73 | chk, err := packet.CalcChecksum() 74 | assert.NoError(t, err) 75 | 76 | err = packet.WriteChecksum() 77 | assert.NoError(t, err) 78 | assert.Equal(t, chk, packet.Header.Checksum) 79 | 80 | } 81 | 82 | // Ensure that the packet data length returned is expected. 
83 | func TestPacket_DataLength(t *testing.T) { 84 | tearDown := setUp(t) 85 | defer tearDown(t) 86 | 87 | assert.Equal(t, len(payload), packet.DataLength()) 88 | } 89 | 90 | // Test verification with a malformed packet checksum. 91 | func TestPacket_VerifyFail1(t *testing.T) { 92 | tearDown := setUp(t) 93 | defer tearDown(t) 94 | 95 | err := packet.WriteChecksum() 96 | assert.NoError(t, err) 97 | 98 | // Mutate the checksum 99 | packet.Header.Checksum[0]++ 100 | 101 | err = packet.Verify(deviceToken) 102 | assert.NotNil(t, err) 103 | } 104 | 105 | // Test verification with a malformed packet checksum. 106 | func TestPacket_VerifyFail2(t *testing.T) { 107 | tearDown := setUp(t) 108 | defer tearDown(t) 109 | 110 | err := packet.WriteChecksum() 111 | assert.NoError(t, err) 112 | 113 | // Mutate the data 114 | packet.Data[0]++ 115 | 116 | err = packet.Verify(deviceToken) 117 | assert.NotNil(t, err) 118 | } 119 | 120 | // Test verification with a known good packet. 121 | func TestPacket_VerifySuccess(t *testing.T) { 122 | tearDown := setUp(t) 123 | defer tearDown(t) 124 | 125 | err := packet.WriteChecksum() 126 | assert.NoError(t, err) 127 | 128 | err = packet.Verify(deviceToken) 129 | assert.NoError(t, err) 130 | } 131 | 132 | // Ensure that calls to Verify does not mutate the packet header. 
133 | func TestPacket_VerifyNoMutation(t *testing.T) { 134 | tearDown := setUp(t) 135 | defer tearDown(t) 136 | 137 | packet.WriteChecksum() 138 | 139 | before := packet.Serialize() 140 | packet.Verify(deviceToken) 141 | 142 | after := packet.Serialize() 143 | assert.Equal(t, before, after) 144 | } 145 | -------------------------------------------------------------------------------- /protocol/protocol.go: -------------------------------------------------------------------------------- 1 | package protocol 2 | 3 | import ( 4 | "net" 5 | "sync" 6 | "time" 7 | 8 | "github.com/benbjohnson/clock" 9 | "github.com/nickw444/miio-go/common" 10 | "github.com/nickw444/miio-go/device" 11 | "github.com/nickw444/miio-go/protocol/packet" 12 | "github.com/nickw444/miio-go/protocol/tokens" 13 | "github.com/nickw444/miio-go/protocol/transport" 14 | "github.com/nickw444/miio-go/subscription" 15 | ) 16 | 17 | type Protocol interface { 18 | subscription.SubscriptionTarget 19 | 20 | Discover() error 21 | SetExpiryTime(duration time.Duration) 22 | } 23 | 24 | type protocol struct { 25 | subscription.SubscriptionTarget 26 | port int 27 | expireAfter time.Duration 28 | clock clock.Clock 29 | lastDiscovery time.Time 30 | tokenStore tokens.TokenStore 31 | 32 | broadcastDev device.Device 33 | quitChan chan struct{} 34 | devicesMutex sync.RWMutex 35 | devices map[uint32]device.Device 36 | ignoredDevices map[uint32]bool 37 | 38 | transport transport.Transport 39 | deviceFactory DeviceFactory 40 | cryptoFactory CryptoFactory 41 | } 42 | 43 | type DeviceFactory func(deviceId uint32, outbound transport.Outbound, seen time.Time, token []byte) device.Device 44 | type CryptoFactory func(deviceID uint32, deviceToken []byte, initialStamp uint32, stampTime time.Time) (packet.Crypto, error) 45 | 46 | type ProtocolConfig struct { 47 | // Required config 48 | BroadcastIP net.IP 49 | TokenStore tokens.TokenStore 50 | 51 | // Optional config 52 | ListenPort int // Defaults to a random system-assigned port 
if not provided. 53 | } 54 | 55 | func NewProtocol(c ProtocolConfig) (Protocol, error) { 56 | clk := clock.New() 57 | var listenAddr *net.UDPAddr 58 | if c.ListenPort != 0 { 59 | listenAddr = &net.UDPAddr{Port: c.ListenPort} 60 | } 61 | 62 | s, err := net.ListenUDP("udp", listenAddr) 63 | if err != nil { 64 | return nil, err 65 | } 66 | 67 | t := transport.NewTransport(s) 68 | deviceFactory := func(deviceId uint32, outbound transport.Outbound, seen time.Time, token []byte) device.Device { 69 | return device.New(deviceId, outbound, seen, token) 70 | } 71 | cryptoFactory := func(deviceID uint32, deviceToken []byte, initialStamp uint32, stampTime time.Time) (packet.Crypto, error) { 72 | return packet.NewCrypto(deviceID, deviceToken, initialStamp, stampTime, clk) 73 | } 74 | 75 | addr := &net.UDPAddr{ 76 | IP: c.BroadcastIP, 77 | Port: 54321, 78 | } 79 | broadcastDev := deviceFactory(0, t.NewOutbound(nil, addr), time.Time{}, nil) 80 | 81 | p := newProtocol(clk, t, deviceFactory, cryptoFactory, subscription.NewTarget(), broadcastDev, c.TokenStore) 82 | p.start() 83 | return p, nil 84 | } 85 | 86 | func newProtocol(c clock.Clock, transport transport.Transport, deviceFactory DeviceFactory, 87 | crptoFactory CryptoFactory, target subscription.SubscriptionTarget, broadcastDev device.Device, 88 | tokenStore tokens.TokenStore) *protocol { 89 | 90 | p := &protocol{ 91 | SubscriptionTarget: target, 92 | transport: transport, 93 | deviceFactory: deviceFactory, 94 | cryptoFactory: crptoFactory, 95 | clock: c, 96 | quitChan: make(chan struct{}), 97 | devices: make(map[uint32]device.Device), 98 | broadcastDev: broadcastDev, 99 | tokenStore: tokenStore, 100 | ignoredDevices: make(map[uint32]bool), 101 | } 102 | return p 103 | } 104 | 105 | func (p *protocol) start() { 106 | go p.dispatcher() 107 | } 108 | 109 | func (p *protocol) SetExpiryTime(duration time.Duration) { 110 | p.expireAfter = duration 111 | } 112 | 113 | func (p *protocol) dispatcher() { 114 | pkts := 
p.transport.Inbound().Packets() 115 | for { 116 | select { 117 | case <-p.quitChan: 118 | return 119 | default: 120 | } 121 | 122 | select { 123 | case <-p.quitChan: 124 | return 125 | case pkt := <-pkts: 126 | go p.process(pkt) 127 | } 128 | } 129 | } 130 | 131 | func (p *protocol) Discover() error { 132 | common.Log.Debugf("Running discovery...") 133 | 134 | if p.lastDiscovery.After(time.Time{}) { 135 | // If the device has not been seen recently, it should be expired. 136 | cutoff := time.Now().Add(p.expireAfter * -1) 137 | var expiredDevices []device.Device 138 | p.devicesMutex.RLock() 139 | for _, dev := range p.devices { 140 | if dev.Seen().Before(cutoff) { 141 | common.Log.Debugf("Device %d is stale. Last Seen at %s", dev.ID(), dev.Seen()) 142 | expiredDevices = append(expiredDevices, dev) 143 | } 144 | } 145 | p.devicesMutex.RUnlock() 146 | 147 | for _, dev := range expiredDevices { 148 | common.Log.Debugf("Removing expired device with id %d.", dev.ID()) 149 | p.removeDevice(dev.ID()) 150 | dev.Close() 151 | err := p.Publish(common.EventExpiredDevice{dev}) 152 | if err != nil { 153 | common.Log.Warn(err) 154 | } 155 | } 156 | } 157 | if err := p.broadcastDev.Discover(); err != nil { 158 | return err 159 | } 160 | 161 | p.lastDiscovery = time.Now() 162 | return nil 163 | } 164 | func (p *protocol) process(pkt *packet.Packet) { 165 | common.Log.Debugf("Processing incoming packet from %s", pkt.Meta.Addr) 166 | if ok, _ := p.ignoredDevices[pkt.Header.DeviceID]; ok { 167 | return 168 | } 169 | 170 | dev := p.getDevice(pkt.Header.DeviceID) 171 | if dev == nil && pkt.DataLength() == 0 { 172 | // Device response to a Hello packet. 173 | common.Log.Debugf("Device with id %d responded to Hello packet.", pkt.Header.DeviceID) 174 | 175 | deviceToken := pkt.Header.Checksum 176 | if pkt.HasZeroChecksum() { 177 | token, err := p.tokenStore.GetToken(pkt.Header.DeviceID) 178 | if err != nil { 179 | common.Log.Warnf("Device with id %d is not revealing its token. 
You must manually collect this token and add it to the store.", pkt.Header.DeviceID) 180 | p.ignoredDevices[pkt.Header.DeviceID] = true 181 | p.Publish(common.EventNewMaskedDevice{DeviceID: pkt.Header.DeviceID}) 182 | return 183 | } else { 184 | common.Log.Debugf("Loaded token for device %d from store", pkt.Header.DeviceID) 185 | deviceToken = token 186 | } 187 | } 188 | 189 | crypto, err := p.cryptoFactory(pkt.Header.DeviceID, deviceToken, pkt.Header.Stamp, 190 | pkt.Meta.DecodeTime) 191 | if err != nil { 192 | panic(err) 193 | } 194 | 195 | t := p.transport.NewOutbound(crypto, pkt.Meta.Addr) 196 | baseDev := p.deviceFactory(pkt.Header.DeviceID, t, pkt.Meta.DecodeTime, deviceToken) 197 | 198 | // Store the provisional device for now to ensure it can handle subsequent 199 | // packets that may occur during classification. 200 | p.addDevice(baseDev) 201 | 202 | common.Log.Infof("Classifying device...") 203 | dev, err := device.Classify(baseDev) 204 | if err != nil { 205 | panic(err) 206 | } 207 | 208 | // Store the specific device and publish a new device event. 209 | p.addDevice(dev) 210 | p.Publish(common.EventNewDevice{Device: dev}) 211 | } else if dev != nil { 212 | // Known device. Handle the incoming packet. 213 | err := dev.Handle(pkt) 214 | if err != nil { 215 | common.Log.Errorf("Unable to process packet %v for device %d. Error %s", pkt, dev.ID(), err) 216 | } 217 | } else { 218 | common.Log.Errorf("Unable to process packet %v. 
Device unknown.", pkt) 219 | } 220 | } 221 | 222 | func (p *protocol) removeDevice(id uint32) { 223 | p.devicesMutex.Lock() 224 | delete(p.devices, id) 225 | p.devicesMutex.Unlock() 226 | } 227 | 228 | func (p *protocol) addDevice(dev device.Device) { 229 | p.devicesMutex.Lock() 230 | p.devices[dev.ID()] = dev 231 | p.devicesMutex.Unlock() 232 | } 233 | 234 | func (p *protocol) getDevice(id uint32) device.Device { 235 | p.devicesMutex.RLock() 236 | dev, ok := p.devices[id] 237 | p.devicesMutex.RUnlock() 238 | if !ok { 239 | return nil 240 | } 241 | return dev 242 | } 243 | -------------------------------------------------------------------------------- /protocol/protocol_test.go: -------------------------------------------------------------------------------- 1 | package protocol 2 | 3 | import ( 4 | "net" 5 | "sync" 6 | "testing" 7 | "time" 8 | 9 | "github.com/benbjohnson/clock" 10 | "github.com/nickw444/miio-go/device" 11 | deviceMocks "github.com/nickw444/miio-go/device/mocks" 12 | "github.com/nickw444/miio-go/protocol/packet" 13 | "github.com/nickw444/miio-go/protocol/tokens" 14 | "github.com/nickw444/miio-go/protocol/transport" 15 | transportMocks "github.com/nickw444/miio-go/protocol/transport/mocks" 16 | subscriptionMocks "github.com/nickw444/miio-go/subscription/common/mocks" 17 | "github.com/stretchr/testify/assert" 18 | "github.com/stretchr/testify/mock" 19 | ) 20 | 21 | func Protocol_SetUp() (tt struct { 22 | clk *clock.Mock 23 | transport *mockTransport 24 | deviceFactory DeviceFactory 25 | cryptoFactory CryptoFactory 26 | subscriptionTarget *subscriptionMocks.SubscriptionTarget 27 | protocol *protocol 28 | devices []*deviceMocks.Device 29 | broadcastDevice *deviceMocks.Device 30 | }) { 31 | tt.clk = clock.NewMock() 32 | tt.transport = &mockTransport{new(transportMocks.Inbound)} 33 | tt.subscriptionTarget = new(subscriptionMocks.SubscriptionTarget) 34 | tt.deviceFactory = func(deviceId uint32, outbound transport.Outbound, seen time.Time, token []byte) 
device.Device { 35 | d := &deviceMocks.Device{} 36 | tt.devices = append(tt.devices, d) 37 | return d 38 | } 39 | tt.cryptoFactory = func(deviceID uint32, deviceToken []byte, initialStamp uint32, stampTime time.Time) (packet.Crypto, error) { 40 | return nil, nil 41 | } 42 | tt.broadcastDevice = &deviceMocks.Device{} 43 | tt.broadcastDevice.On("Discover").Return(nil) 44 | tt.protocol = newProtocol(tt.clk, tt.transport, tt.deviceFactory, tt.cryptoFactory, tt.subscriptionTarget, 45 | tt.broadcastDevice, tokens.New()) 46 | return 47 | } 48 | 49 | // Ensure that the broadcast device has Discover called on it. 50 | func TestProtocol_Discover(t *testing.T) { 51 | tt := Protocol_SetUp() 52 | 53 | err := tt.protocol.Discover() 54 | assert.NoError(t, err) 55 | tt.broadcastDevice.AssertCalled(t, "Discover") 56 | } 57 | 58 | // Ensure that inbound's Packets method is called. 59 | func TestProtocol_dispatcher(t *testing.T) { 60 | tt := Protocol_SetUp() 61 | wg := sync.WaitGroup{} 62 | wg.Add(1) 63 | 64 | ch := make(chan *packet.Packet) 65 | // Hack to convert the channel to a read-only channel (what the mock expects) 66 | ro := func(c chan *packet.Packet) <-chan *packet.Packet { 67 | return c 68 | } 69 | 70 | tt.transport.inbound.On("Packets").Return(ro(ch)).Run(func(args mock.Arguments) { 71 | wg.Done() 72 | }) 73 | tt.protocol.start() 74 | wg.Wait() 75 | tt.transport.inbound.AssertExpectations(t) 76 | } 77 | 78 | type mockTransport struct { 79 | inbound *transportMocks.Inbound 80 | } 81 | 82 | func (m *mockTransport) Inbound() transport.Inbound { 83 | return m.inbound 84 | } 85 | 86 | func (*mockTransport) NewOutbound(crypto packet.Crypto, dest net.Addr) transport.Outbound { 87 | return &transportMocks.Outbound{} 88 | } 89 | 90 | func (*mockTransport) Close() error { 91 | return nil 92 | } 93 | -------------------------------------------------------------------------------- /protocol/tokens/token_store.go: 
-------------------------------------------------------------------------------- 1 | package tokens 2 | 3 | import ( 4 | "bufio" 5 | "encoding/hex" 6 | "fmt" 7 | "os" 8 | "strconv" 9 | "strings" 10 | ) 11 | 12 | type TokenStore interface { 13 | LoadFile(inputPath string) error 14 | WriteFile(outputPath string) error 15 | GetToken(deviceId uint32) ([]byte, error) 16 | AddDevice(deviceId uint32, token []byte) error 17 | RemoveDevice(deviceId uint32) 18 | } 19 | 20 | type tokenStore struct { 21 | tokens map[uint32][]byte 22 | } 23 | 24 | func New() TokenStore { 25 | return &tokenStore{ 26 | tokens: make(map[uint32][]byte), 27 | } 28 | } 29 | 30 | func FromFile(filePath string) (TokenStore, error) { 31 | store := New() 32 | err := store.LoadFile(filePath) 33 | return store, err 34 | } 35 | 36 | func (t *tokenStore) LoadFile(inputPath string) error { 37 | if _, err := os.Stat(inputPath); os.IsNotExist(err) { 38 | // File doesn't exist, so don't load anything. 39 | return nil 40 | } 41 | 42 | f, err := os.Open(inputPath) 43 | defer f.Close() 44 | 45 | if err != nil { 46 | return err 47 | } 48 | 49 | scanner := bufio.NewScanner(f) 50 | scanner.Split(bufio.ScanLines) 51 | 52 | for scanner.Scan() { 53 | line := strings.TrimSpace(scanner.Text()) 54 | 55 | if len(line) == 0 { 56 | continue 57 | } 58 | 59 | if strings.HasPrefix(line, "#") { 60 | continue 61 | } 62 | 63 | splitLine := strings.Split(line, "=") 64 | if len(splitLine) != 2 { 65 | return fmt.Errorf("Malformed line: %s", line) 66 | } 67 | 68 | deviceId, err := strconv.ParseUint(splitLine[0], 10, 32) 69 | if err != nil { 70 | return fmt.Errorf("Malformed line: %s", line) 71 | } 72 | 73 | token, err := hex.DecodeString(splitLine[1]) 74 | if err != nil { 75 | return fmt.Errorf("Malformed line: %s", line) 76 | } 77 | 78 | t.tokens[uint32(deviceId)] = token 79 | } 80 | 81 | return nil 82 | } 83 | 84 | func (t *tokenStore) WriteFile(outputPath string) error { 85 | f, err := os.Create(outputPath) 86 | if err != nil { 87 | 
return err 88 | } 89 | defer f.Close() 90 | 91 | for deviceId, tokenBytes := range t.tokens { 92 | tokenStr := hex.EncodeToString(tokenBytes) 93 | _, err := f.WriteString(fmt.Sprintf("%d=%s\n", deviceId, tokenStr)) 94 | if err != nil { 95 | return err 96 | } 97 | } 98 | return nil 99 | } 100 | 101 | func (t *tokenStore) GetToken(deviceId uint32) ([]byte, error) { 102 | if val, ok := t.tokens[deviceId]; ok { 103 | return val, nil 104 | } 105 | return nil, fmt.Errorf("Device ID %d does not exist in token store", deviceId) 106 | } 107 | 108 | func (t *tokenStore) AddDevice(deviceId uint32, token []byte) error { 109 | t.tokens[deviceId] = token 110 | return nil 111 | } 112 | 113 | func (t *tokenStore) RemoveDevice(deviceId uint32) { 114 | delete(t.tokens, deviceId) 115 | } 116 | -------------------------------------------------------------------------------- /protocol/tokens/token_store_test.go: -------------------------------------------------------------------------------- 1 | package tokens 2 | 3 | import ( 4 | "testing" 5 | 6 | "encoding/hex" 7 | 8 | "github.com/stretchr/testify/assert" 9 | ) 10 | 11 | func TestFromFile(t *testing.T) { 12 | store, err := FromFile("tokens.example.txt") 13 | assert.NoError(t, err) 14 | token, err := store.GetToken(123456) 15 | assert.NoError(t, err) 16 | assert.Equal(t, "ffffffffffffffffffffffffffffffff", hex.EncodeToString(token)) 17 | token, err = store.GetToken(111222) 18 | assert.NoError(t, err) 19 | assert.Equal(t, "badcafefffffffffffffffffffffffff", hex.EncodeToString(token)) 20 | } 21 | -------------------------------------------------------------------------------- /protocol/tokens/tokens.example.txt: -------------------------------------------------------------------------------- 1 | # MiiO Go TokenStore data file. 
2 | # Format: 3 | # deviceID=deviceToken (hex) 4 | 5 | 123456=ffffffffffffffffffffffffffffffff 6 | 111222=badcafefffffffffffffffffffffffff 7 | -------------------------------------------------------------------------------- /protocol/transport/inbound.go: -------------------------------------------------------------------------------- 1 | package transport 2 | 3 | import ( 4 | "net" 5 | 6 | "github.com/nickw444/miio-go/protocol/packet" 7 | ) 8 | 9 | // An inbound transport is a channel'ed abstraction around a net.UDPConn. 10 | // Provides an abstraction around inbound packets on the network to allow 11 | // using the existing protocol implementation without a miIO network handy. 12 | // Consumers of this interface should never close the underlying UDP 13 | // connection without first calling Stop(). 14 | type Inbound interface { 15 | Packets() <-chan *packet.Packet 16 | Stop() error 17 | } 18 | 19 | type inbound struct { 20 | socket InboundConn 21 | packets chan *packet.Packet 22 | quitChan chan struct{} 23 | stopped bool 24 | } 25 | 26 | // InboundConn is an abstraction around net.UDPConn to allow 27 | // mocking during tests. 28 | type InboundConn interface { 29 | ReadFromUDP(b []byte) (int, *net.UDPAddr, error) 30 | } 31 | 32 | func NewInbound(socket InboundConn) Inbound { 33 | return newInbound(socket) 34 | } 35 | 36 | func newInbound(socket InboundConn) *inbound { 37 | i := &inbound{ 38 | socket: socket, 39 | packets: make(chan *packet.Packet), 40 | quitChan: make(chan struct{}), 41 | stopped: false, 42 | } 43 | go i.reader() 44 | return i 45 | } 46 | 47 | // A goroutine that continuously pulls data from the given UDP socket 48 | // and decodes inbound packets. 49 | func (i *inbound) reader() { 50 | for { 51 | select { 52 | case <-i.quitChan: 53 | return 54 | default: 55 | buf := make([]byte, 1024) 56 | n, addr, err := i.socket.ReadFromUDP(buf) 57 | 58 | if i.stopped { 59 | // No need to process this packet as we have been stopped. 
60 | return 61 | } 62 | 63 | if err != nil { 64 | // TODO NW remove panic 65 | panic(err) 66 | continue 67 | } 68 | 69 | pkt, err := packet.Decode(buf[:n], addr) 70 | if err != nil { 71 | // TODO NW remove panic 72 | panic(err) 73 | continue 74 | } 75 | 76 | i.packets <- pkt 77 | } 78 | } 79 | } 80 | 81 | func (i *inbound) Packets() <-chan *packet.Packet { 82 | return i.packets 83 | } 84 | 85 | func (i *inbound) Stop() error { 86 | close(i.quitChan) 87 | i.stopped = true 88 | return nil 89 | } 90 | -------------------------------------------------------------------------------- /protocol/transport/mocks/Conn.go: -------------------------------------------------------------------------------- 1 | // Code generated by mockery v1.0.0. DO NOT EDIT. 2 | 3 | package mocks 4 | 5 | import ( 6 | net "net" 7 | 8 | mock "github.com/stretchr/testify/mock" 9 | ) 10 | 11 | // Conn is an autogenerated mock type for the Conn type 12 | type Conn struct { 13 | mock.Mock 14 | } 15 | 16 | // Close provides a mock function with given fields: 17 | func (_m *Conn) Close() error { 18 | ret := _m.Called() 19 | 20 | var r0 error 21 | if rf, ok := ret.Get(0).(func() error); ok { 22 | r0 = rf() 23 | } else { 24 | r0 = ret.Error(0) 25 | } 26 | 27 | return r0 28 | } 29 | 30 | // ReadFromUDP provides a mock function with given fields: b 31 | func (_m *Conn) ReadFromUDP(b []byte) (int, *net.UDPAddr, error) { 32 | ret := _m.Called(b) 33 | 34 | var r0 int 35 | if rf, ok := ret.Get(0).(func([]byte) int); ok { 36 | r0 = rf(b) 37 | } else { 38 | r0 = ret.Get(0).(int) 39 | } 40 | 41 | var r1 *net.UDPAddr 42 | if rf, ok := ret.Get(1).(func([]byte) *net.UDPAddr); ok { 43 | r1 = rf(b) 44 | } else { 45 | if ret.Get(1) != nil { 46 | r1 = ret.Get(1).(*net.UDPAddr) 47 | } 48 | } 49 | 50 | var r2 error 51 | if rf, ok := ret.Get(2).(func([]byte) error); ok { 52 | r2 = rf(b) 53 | } else { 54 | r2 = ret.Error(2) 55 | } 56 | 57 | return r0, r1, r2 58 | } 59 | 60 | // WriteTo provides a mock function with given 
fields: _a0, _a1 61 | func (_m *Conn) WriteTo(_a0 []byte, _a1 net.Addr) (int, error) { 62 | ret := _m.Called(_a0, _a1) 63 | 64 | var r0 int 65 | if rf, ok := ret.Get(0).(func([]byte, net.Addr) int); ok { 66 | r0 = rf(_a0, _a1) 67 | } else { 68 | r0 = ret.Get(0).(int) 69 | } 70 | 71 | var r1 error 72 | if rf, ok := ret.Get(1).(func([]byte, net.Addr) error); ok { 73 | r1 = rf(_a0, _a1) 74 | } else { 75 | r1 = ret.Error(1) 76 | } 77 | 78 | return r0, r1 79 | } 80 | -------------------------------------------------------------------------------- /protocol/transport/mocks/Inbound.go: -------------------------------------------------------------------------------- 1 | // Code generated by mockery v1.0.0. DO NOT EDIT. 2 | 3 | package mocks 4 | 5 | import ( 6 | packet "github.com/nickw444/miio-go/protocol/packet" 7 | mock "github.com/stretchr/testify/mock" 8 | ) 9 | 10 | // Inbound is an autogenerated mock type for the Inbound type 11 | type Inbound struct { 12 | mock.Mock 13 | } 14 | 15 | // Packets provides a mock function with given fields: 16 | func (_m *Inbound) Packets() <-chan *packet.Packet { 17 | ret := _m.Called() 18 | 19 | var r0 <-chan *packet.Packet 20 | if rf, ok := ret.Get(0).(func() <-chan *packet.Packet); ok { 21 | r0 = rf() 22 | } else { 23 | if ret.Get(0) != nil { 24 | r0 = ret.Get(0).(<-chan *packet.Packet) 25 | } 26 | } 27 | 28 | return r0 29 | } 30 | 31 | // Stop provides a mock function with given fields: 32 | func (_m *Inbound) Stop() error { 33 | ret := _m.Called() 34 | 35 | var r0 error 36 | if rf, ok := ret.Get(0).(func() error); ok { 37 | r0 = rf() 38 | } else { 39 | r0 = ret.Error(0) 40 | } 41 | 42 | return r0 43 | } 44 | -------------------------------------------------------------------------------- /protocol/transport/mocks/InboundConn.go: -------------------------------------------------------------------------------- 1 | // Code generated by mockery v1.0.0. DO NOT EDIT. 
2 | 3 | package mocks 4 | 5 | import ( 6 | net "net" 7 | 8 | mock "github.com/stretchr/testify/mock" 9 | ) 10 | 11 | // InboundConn is an autogenerated mock type for the InboundConn type 12 | type InboundConn struct { 13 | mock.Mock 14 | } 15 | 16 | // ReadFromUDP provides a mock function with given fields: b 17 | func (_m *InboundConn) ReadFromUDP(b []byte) (int, *net.UDPAddr, error) { 18 | ret := _m.Called(b) 19 | 20 | var r0 int 21 | if rf, ok := ret.Get(0).(func([]byte) int); ok { 22 | r0 = rf(b) 23 | } else { 24 | r0 = ret.Get(0).(int) 25 | } 26 | 27 | var r1 *net.UDPAddr 28 | if rf, ok := ret.Get(1).(func([]byte) *net.UDPAddr); ok { 29 | r1 = rf(b) 30 | } else { 31 | if ret.Get(1) != nil { 32 | r1 = ret.Get(1).(*net.UDPAddr) 33 | } 34 | } 35 | 36 | var r2 error 37 | if rf, ok := ret.Get(2).(func([]byte) error); ok { 38 | r2 = rf(b) 39 | } else { 40 | r2 = ret.Error(2) 41 | } 42 | 43 | return r0, r1, r2 44 | } 45 | -------------------------------------------------------------------------------- /protocol/transport/mocks/Outbound.go: -------------------------------------------------------------------------------- 1 | // Code generated by mockery v1.0.0. DO NOT EDIT. 
2 | 3 | package mocks 4 | 5 | import ( 6 | packet "github.com/nickw444/miio-go/protocol/packet" 7 | mock "github.com/stretchr/testify/mock" 8 | ) 9 | 10 | // Outbound is an autogenerated mock type for the Outbound type 11 | type Outbound struct { 12 | mock.Mock 13 | } 14 | 15 | // Call provides a mock function with given fields: method, params 16 | func (_m *Outbound) Call(method string, params interface{}) ([]byte, error) { 17 | ret := _m.Called(method, params) 18 | 19 | var r0 []byte 20 | if rf, ok := ret.Get(0).(func(string, interface{}) []byte); ok { 21 | r0 = rf(method, params) 22 | } else { 23 | if ret.Get(0) != nil { 24 | r0 = ret.Get(0).([]byte) 25 | } 26 | } 27 | 28 | var r1 error 29 | if rf, ok := ret.Get(1).(func(string, interface{}) error); ok { 30 | r1 = rf(method, params) 31 | } else { 32 | r1 = ret.Error(1) 33 | } 34 | 35 | return r0, r1 36 | } 37 | 38 | // CallAndDeserialize provides a mock function with given fields: method, params, resp 39 | func (_m *Outbound) CallAndDeserialize(method string, params interface{}, resp interface{}) error { 40 | ret := _m.Called(method, params, resp) 41 | 42 | var r0 error 43 | if rf, ok := ret.Get(0).(func(string, interface{}, interface{}) error); ok { 44 | r0 = rf(method, params, resp) 45 | } else { 46 | r0 = ret.Error(0) 47 | } 48 | 49 | return r0 50 | } 51 | 52 | // Handle provides a mock function with given fields: pkt 53 | func (_m *Outbound) Handle(pkt *packet.Packet) error { 54 | ret := _m.Called(pkt) 55 | 56 | var r0 error 57 | if rf, ok := ret.Get(0).(func(*packet.Packet) error); ok { 58 | r0 = rf(pkt) 59 | } else { 60 | r0 = ret.Error(0) 61 | } 62 | 63 | return r0 64 | } 65 | 66 | // Send provides a mock function with given fields: _a0 67 | func (_m *Outbound) Send(_a0 *packet.Packet) error { 68 | ret := _m.Called(_a0) 69 | 70 | var r0 error 71 | if rf, ok := ret.Get(0).(func(*packet.Packet) error); ok { 72 | r0 = rf(_a0) 73 | } else { 74 | r0 = ret.Error(0) 75 | } 76 | 77 | return r0 78 | } 79 | 
-------------------------------------------------------------------------------- /protocol/transport/mocks/OutboundConn.go: -------------------------------------------------------------------------------- 1 | // Code generated by mockery v1.0.0. DO NOT EDIT. 2 | 3 | package mocks 4 | 5 | import ( 6 | net "net" 7 | 8 | mock "github.com/stretchr/testify/mock" 9 | ) 10 | 11 | // OutboundConn is an autogenerated mock type for the OutboundConn type 12 | type OutboundConn struct { 13 | mock.Mock 14 | } 15 | 16 | // WriteTo provides a mock function with given fields: _a0, _a1 17 | func (_m *OutboundConn) WriteTo(_a0 []byte, _a1 net.Addr) (int, error) { 18 | ret := _m.Called(_a0, _a1) 19 | 20 | var r0 int 21 | if rf, ok := ret.Get(0).(func([]byte, net.Addr) int); ok { 22 | r0 = rf(_a0, _a1) 23 | } else { 24 | r0 = ret.Get(0).(int) 25 | } 26 | 27 | var r1 error 28 | if rf, ok := ret.Get(1).(func([]byte, net.Addr) error); ok { 29 | r1 = rf(_a0, _a1) 30 | } else { 31 | r1 = ret.Error(1) 32 | } 33 | 34 | return r0, r1 35 | } 36 | -------------------------------------------------------------------------------- /protocol/transport/mocks/Transport.go: -------------------------------------------------------------------------------- 1 | // Code generated by mockery v1.0.0. DO NOT EDIT. 
2 | 3 | package mocks 4 | 5 | import ( 6 | net "net" 7 | 8 | mock "github.com/stretchr/testify/mock" 9 | 10 | packet "github.com/nickw444/miio-go/protocol/packet" 11 | 12 | transport "github.com/nickw444/miio-go/protocol/transport" 13 | ) 14 | 15 | // Transport is an autogenerated mock type for the Transport type 16 | type Transport struct { 17 | mock.Mock 18 | } 19 | 20 | // Close provides a mock function with given fields: 21 | func (_m *Transport) Close() error { 22 | ret := _m.Called() 23 | 24 | var r0 error 25 | if rf, ok := ret.Get(0).(func() error); ok { 26 | r0 = rf() 27 | } else { 28 | r0 = ret.Error(0) 29 | } 30 | 31 | return r0 32 | } 33 | 34 | // Inbound provides a mock function with given fields: 35 | func (_m *Transport) Inbound() transport.Inbound { 36 | ret := _m.Called() 37 | 38 | var r0 transport.Inbound 39 | if rf, ok := ret.Get(0).(func() transport.Inbound); ok { 40 | r0 = rf() 41 | } else { 42 | if ret.Get(0) != nil { 43 | r0 = ret.Get(0).(transport.Inbound) 44 | } 45 | } 46 | 47 | return r0 48 | } 49 | 50 | // NewOutbound provides a mock function with given fields: crypto, dest 51 | func (_m *Transport) NewOutbound(crypto packet.Crypto, dest net.Addr) transport.Outbound { 52 | ret := _m.Called(crypto, dest) 53 | 54 | var r0 transport.Outbound 55 | if rf, ok := ret.Get(0).(func(packet.Crypto, net.Addr) transport.Outbound); ok { 56 | r0 = rf(crypto, dest) 57 | } else { 58 | if ret.Get(0) != nil { 59 | r0 = ret.Get(0).(transport.Outbound) 60 | } 61 | } 62 | 63 | return r0 64 | } 65 | -------------------------------------------------------------------------------- /protocol/transport/outbound.go: -------------------------------------------------------------------------------- 1 | package transport 2 | 3 | import ( 4 | "encoding/hex" 5 | "encoding/json" 6 | "fmt" 7 | "net" 8 | "time" 9 | 10 | "sync" 11 | 12 | "github.com/benbjohnson/clock" 13 | "github.com/nickw444/miio-go/common" 14 | "github.com/nickw444/miio-go/protocol/packet" 15 | ) 16 | 17 | 
type OutboundConn interface { 18 | WriteTo([]byte, net.Addr) (int, error) 19 | } 20 | 21 | // Outbound transport is an abstraction around a net.UDPConn for outbound interaction with 22 | // a networked miIO device. Consumers should never close the underlying socket and continue 23 | // to use the service. Outbound also provides retry and timeout logic. 24 | type Outbound interface { 25 | // Handle handles incoming packets and triggers waiting continuations. 26 | Handle(pkt *packet.Packet) error 27 | // Call makes a call, waits for a Response and returns the raw bytes returned. 28 | Call(method string, params interface{}) ([]byte, error) 29 | // CallAndDeserialize makes a call, waits for a Response and deserialises the JSON 30 | // payload into `ret`. 31 | CallAndDeserialize(method string, params interface{}, resp interface{}) error 32 | // Send will send a raw packet without waiting for a Response. 33 | Send(packet *packet.Packet) error 34 | } 35 | 36 | type outbound struct { 37 | maxRetries int 38 | timeout time.Duration 39 | 40 | clock clock.Clock 41 | crypto packet.Crypto 42 | 43 | dest net.Addr 44 | socket OutboundConn 45 | 46 | nextReqID uint32 47 | continuationsMutex sync.RWMutex 48 | continuations map[uint32]chan []byte 49 | } 50 | 51 | func NewOutbound(crypto packet.Crypto, dest net.Addr, socket OutboundConn) Outbound { 52 | return newOutbound(10, time.Millisecond*200, clock.New(), crypto, dest, socket) 53 | } 54 | 55 | func newOutbound(maxRetries int, timeout time.Duration, clock clock.Clock, crypto packet.Crypto, 56 | dest net.Addr, socket OutboundConn) *outbound { 57 | return &outbound{ 58 | maxRetries: maxRetries, 59 | timeout: timeout, 60 | clock: clock, 61 | crypto: crypto, 62 | dest: dest, 63 | socket: socket, 64 | 65 | nextReqID: 1, 66 | continuations: make(map[uint32]chan []byte), 67 | } 68 | } 69 | 70 | func (o *outbound) Handle(pkt *packet.Packet) error { 71 | if pkt.Header.Length <= 32 { 72 | return nil 73 | } 74 | 75 | err := 
o.crypto.VerifyPacket(pkt) 76 | if err != nil { 77 | panic(err) 78 | } 79 | 80 | data, err := o.crypto.Decrypt(pkt.Data) 81 | if err != nil { 82 | panic(err) 83 | } 84 | 85 | resp := Response{} 86 | err = json.Unmarshal(data, &resp) 87 | if err != nil { 88 | return err 89 | } 90 | 91 | // Lookup the Response ID and pass data to the appropriate continuation goroutine. 92 | o.continuationsMutex.RLock() 93 | if ch, ok := o.continuations[resp.ID]; ok { 94 | common.Log.Debugf("Callback with ID %d was reconciled", resp.ID) 95 | ch <- data 96 | } else { 97 | common.Log.Debugf("Unable to reconcile callback for resp id %d", resp.ID) 98 | } 99 | o.continuationsMutex.RUnlock() 100 | 101 | return nil 102 | } 103 | 104 | func (o *outbound) Call(method string, params interface{}) ([]byte, error) { 105 | // Setup a continuation channel 106 | o.continuationsMutex.Lock() 107 | requestId := o.nextReqID 108 | o.nextReqID++ 109 | ch := make(chan []byte) 110 | o.continuations[requestId] = ch 111 | o.continuationsMutex.Unlock() 112 | 113 | // Ensure we cleanup. 
114 | defer func() { 115 | o.continuationsMutex.Lock() 116 | delete(o.continuations, requestId) 117 | close(ch) 118 | o.continuationsMutex.Unlock() 119 | }() 120 | 121 | for i := 0; i < o.maxRetries+1; i++ { 122 | // Perform the call 123 | err := o.call(requestId, method, params) 124 | if err != nil { 125 | return nil, err 126 | } 127 | 128 | select { 129 | case data := <-ch: 130 | return data, nil 131 | case <-o.clock.After(o.timeout): 132 | common.Log.Debugf("Timed out whilst waiting for Response.") 133 | continue 134 | } 135 | } 136 | 137 | err := fmt.Errorf("Max retries exceeded whilst sending Request to device %s", o.dest) 138 | common.Log.Error(err) 139 | return nil, err 140 | } 141 | 142 | func (o *outbound) CallAndDeserialize(method string, params interface{}, ret interface{}) error { 143 | resp, err := o.Call(method, params) 144 | if err != nil { 145 | return err 146 | } 147 | 148 | return json.Unmarshal(resp, ret) 149 | } 150 | 151 | func (o *outbound) Send(packet *packet.Packet) error { 152 | common.Log.Debugf("Sending packet with checksum: %s", hex.EncodeToString(packet.Header.Checksum)) 153 | _, err := o.socket.WriteTo(packet.Serialize(), o.dest) 154 | return err 155 | } 156 | 157 | // Call out to the device, but don't wait for a Response. 
158 | func (o *outbound) call(requestId uint32, method string, params interface{}) (err error) { 159 | data, err := json.Marshal(Request{ 160 | ID: requestId, 161 | Method: method, 162 | Params: params, 163 | }) 164 | if err != nil { 165 | return 166 | } 167 | 168 | p, err := o.crypto.NewPacket(data) 169 | if err != nil { 170 | return 171 | } 172 | 173 | err = o.Send(p) 174 | return 175 | } 176 | 177 | type Response struct { 178 | ID uint32 `json:"id"` 179 | Result interface{} `json:"result"` 180 | } 181 | 182 | type Request struct { 183 | ID uint32 `json:"id"` 184 | Method string `json:"method"` 185 | Params interface{} `json:"params"` 186 | } 187 | -------------------------------------------------------------------------------- /protocol/transport/transport.go: -------------------------------------------------------------------------------- 1 | package transport 2 | 3 | import ( 4 | "net" 5 | 6 | "github.com/nickw444/miio-go/protocol/packet" 7 | ) 8 | 9 | type Conn interface { 10 | InboundConn 11 | OutboundConn 12 | Close() error 13 | } 14 | 15 | type Transport interface { 16 | Inbound() Inbound 17 | NewOutbound(crypto packet.Crypto, dest net.Addr) Outbound 18 | Close() error 19 | } 20 | 21 | type transport struct { 22 | inbound Inbound 23 | outbounds []Outbound 24 | socket Conn 25 | } 26 | 27 | func NewTransport(socket Conn) Transport { 28 | return &transport{ 29 | socket: socket, 30 | } 31 | } 32 | 33 | func (t *transport) Inbound() Inbound { 34 | if t.inbound == nil { 35 | t.inbound = NewInbound(t.socket) 36 | } 37 | return t.inbound 38 | } 39 | 40 | func (t *transport) NewOutbound(crypto packet.Crypto, dest net.Addr) Outbound { 41 | o := NewOutbound(crypto, dest, t.socket) 42 | t.outbounds = append(t.outbounds, o) 43 | return o 44 | } 45 | 46 | func (t *transport) Close() error { 47 | err := t.inbound.Stop() 48 | if err != nil { 49 | return err 50 | } 51 | err = t.socket.Close() 52 | return err 53 | } 54 | 
-------------------------------------------------------------------------------- /simulator/.gitignore: -------------------------------------------------------------------------------- 1 | simulator 2 | -------------------------------------------------------------------------------- /simulator/README.md: -------------------------------------------------------------------------------- 1 | # MiiO Simulator 2 | 3 | This simulator has been created to allow easier development when away from 4 | hardware devices and to test the entire integration of the stack, rather 5 | than the unit testing currently available in this library. 6 | 7 | ``` 8 | usage: simulator [] [] 9 | 10 | Flags: 11 | --help Show context-sensitive help (also try --help-long and --help-man). 12 | --device-id=12341234 Device ID for the simulated device 13 | --device-token=00ff00ff00ff00ff00ff00ff00ff00ff 14 | The device token to use for encrypted payloads 15 | --(no-)reveal-token Whether or not to reveal the device token 16 | 17 | Args: 18 | [] Device to simulate 19 | 20 | ``` 21 | 22 | ### Available Devices 23 | 24 | - `powerplug` 25 | - `yeelight` 26 | 27 | Devices are a work in progress. All devices are built from a collection of capabilities.
28 | Available capabilities include: 29 | 30 | - power 31 | - info 32 | 33 | ## Building 34 | 35 | ``` 36 | go build 37 | ``` 38 | -------------------------------------------------------------------------------- /simulator/capability/capability.go: -------------------------------------------------------------------------------- 1 | package capability 2 | 3 | type Capability interface { 4 | MaybeGetProp(propName string) (handled bool, value interface{}, err error) 5 | MaybeHandle(method string, params interface{}) (handled bool, data interface{}, err error) 6 | } 7 | -------------------------------------------------------------------------------- /simulator/capability/info.go: -------------------------------------------------------------------------------- 1 | package capability 2 | 3 | import "github.com/nickw444/miio-go/common" 4 | 5 | type Info struct { 6 | Model string 7 | } 8 | 9 | func (i *Info) MaybeGetProp(propName string) (handled bool, value interface{}, err error) { 10 | return false, nil, nil 11 | } 12 | 13 | func (i *Info) MaybeHandle(method string, params interface{}) (handled bool, data interface{}, err error) { 14 | if method == "miIO.info" { 15 | info := common.DeviceInfo{ 16 | Model: i.Model, 17 | FirmwareVersion: "SIM_0", 18 | MacAddress: "00:00:00:00:00:00", 19 | HardwareVersion: "SIM_0", 20 | } 21 | return true, info, nil 22 | } 23 | 24 | return false, nil, nil 25 | } 26 | -------------------------------------------------------------------------------- /simulator/capability/light.go: -------------------------------------------------------------------------------- 1 | package capability 2 | 3 | type Light struct { 4 | brightness int 5 | color int 6 | } 7 | 8 | func (l *Light) MaybeGetProp(propName string) (handled bool, value interface{}, err error) { 9 | switch propName { 10 | case "bright": 11 | return true, l.brightness, nil 12 | case "hsv": 13 | return true, 0, nil 14 | case "rgb": 15 | return true, 0, nil 16 | default: 17 | return false, 
nil, nil 18 | } 19 | } 20 | 21 | func (l *Light) MaybeHandle(method string, params interface{}) (handled bool, data interface{}, err error) { 22 | switch method { 23 | case "set_bright": 24 | return true, nil, nil 25 | case "set_rgb": 26 | return true, nil, nil 27 | case "set_hsv": 28 | return true, nil, nil 29 | default: 30 | return false, nil, nil 31 | } 32 | } 33 | -------------------------------------------------------------------------------- /simulator/capability/power.go: -------------------------------------------------------------------------------- 1 | package capability 2 | 3 | type Power struct { 4 | power bool 5 | } 6 | 7 | func (p *Power) MaybeGetProp(propName string) (handled bool, value interface{}, err error) { 8 | if propName == "power" { 9 | var power string 10 | if p.power { 11 | power = "on" 12 | } else { 13 | power = "off" 14 | } 15 | return true, power, nil 16 | } 17 | 18 | return false, nil, nil 19 | } 20 | 21 | func (p *Power) MaybeHandle(method string, params interface{}) (handled bool, data interface{}, err error) { 22 | if method == "set_power" { 23 | value := params.([]interface{})[0].(string) 24 | if value == "on" { 25 | p.power = true 26 | } else if value == "off" { 27 | p.power = false 28 | } 29 | return true, nil, nil 30 | } 31 | return false, nil, nil 32 | } 33 | -------------------------------------------------------------------------------- /simulator/device/device.go: -------------------------------------------------------------------------------- 1 | package device 2 | 3 | import ( 4 | "bytes" 5 | "encoding/json" 6 | "fmt" 7 | "time" 8 | 9 | "github.com/benbjohnson/clock" 10 | "github.com/nickw444/miio-go/protocol/packet" 11 | "github.com/nickw444/miio-go/protocol/transport" 12 | "github.com/nickw444/miio-go/simulator/capability" 13 | "github.com/sirupsen/logrus" 14 | ) 15 | 16 | var log = logrus.New() 17 | 18 | type SimulatedDevice interface { 19 | HandlePacket(pkt *packet.Packet) (*packet.Packet, error) 20 | 
HandleDiscover(pkt *packet.Packet) (*packet.Packet, error) 21 | } 22 | 23 | type BaseDevice struct { 24 | capabilities []capability.Capability 25 | crypto packet.Crypto 26 | deviceToken []byte 27 | deviceID uint32 28 | revealToken bool 29 | } 30 | 31 | func NewBaseDevice(deviceID uint32, deviceToken []byte, revealToken bool) (*BaseDevice, error) { 32 | crypto, err := packet.NewCrypto(deviceID, deviceToken, 1, time.Now(), clock.New()) 33 | if err != nil { 34 | return nil, err 35 | } 36 | return &BaseDevice{ 37 | capabilities: []capability.Capability{}, 38 | deviceID: deviceID, 39 | deviceToken: deviceToken, 40 | crypto: crypto, 41 | revealToken: revealToken, 42 | }, nil 43 | } 44 | 45 | func (b *BaseDevice) DecodeRequest(pkt *packet.Packet) (*transport.Request, error) { 46 | err := b.crypto.VerifyPacket(pkt) 47 | if err != nil { 48 | panic(err) 49 | } 50 | 51 | data, err := b.crypto.Decrypt(pkt.Data) 52 | if err != nil { 53 | panic(err) 54 | } 55 | 56 | request := transport.Request{} 57 | err = json.Unmarshal(data, &request) 58 | return &request, err 59 | } 60 | 61 | func (b *BaseDevice) PackResponse(response interface{}) (*packet.Packet, error) { 62 | data, err := json.Marshal(&response) 63 | if err != nil { 64 | return nil, err 65 | } 66 | 67 | log.Infof("Response Data: %s", string(data)) 68 | return b.crypto.NewPacket(data) 69 | } 70 | 71 | func (b *BaseDevice) HandleDiscover(pkt *packet.Packet) (*packet.Packet, error) { 72 | var checksumValue []byte 73 | if b.revealToken { 74 | checksumValue = b.deviceToken 75 | } else { 76 | checksumValue = bytes.Repeat([]byte{0x00}, 16) 77 | } 78 | return packet.New(b.deviceID, checksumValue, 1, []byte{}), nil 79 | } 80 | 81 | func (b *BaseDevice) getPropFromCapabilities(propName string) (interface{}, error) { 82 | for _, c := range b.capabilities { 83 | handled, result, err := c.MaybeGetProp(propName) 84 | if err != nil { 85 | return nil, err 86 | } 87 | 88 | if handled { 89 | return result, nil 90 | } 91 | } 92 | 93 | return 
nil, fmt.Errorf("No capabilities available to return data for get_prop '%s'", propName) 94 | } 95 | 96 | func (b *BaseDevice) HandlePacket(pkt *packet.Packet) (*packet.Packet, error) { 97 | req, err := b.DecodeRequest(pkt) 98 | if err != nil { 99 | return nil, err 100 | } 101 | 102 | log.Infof("Request received. ID=%d method=%s, params=%s", req.ID, req.Method, req.Params) 103 | 104 | switch req.Method { 105 | case "get_prop": 106 | props := req.Params.([]interface{}) 107 | retProps := []interface{}{} 108 | 109 | for _, prop := range props { 110 | propName := prop.(string) 111 | value, err := b.getPropFromCapabilities(propName) 112 | if err != nil { 113 | return nil, err 114 | } 115 | 116 | retProps = append(retProps, value) 117 | } 118 | 119 | return b.PackResponse(transport.Response{ 120 | ID: req.ID, 121 | Result: retProps, 122 | }) 123 | 124 | default: 125 | for _, c := range b.capabilities { 126 | handled, result, err := c.MaybeHandle(req.Method, req.Params) 127 | if err != nil { 128 | return nil, err 129 | } 130 | if handled { 131 | return b.PackResponse(transport.Response{ 132 | ID: req.ID, 133 | Result: result, 134 | }) 135 | } 136 | } 137 | 138 | log.Warnf("No capabilities able to handle method %s", req.Method) 139 | } 140 | 141 | return nil, nil 142 | } 143 | 144 | func (b *BaseDevice) AddCapability(c capability.Capability) { 145 | b.capabilities = append(b.capabilities, c) 146 | } 147 | -------------------------------------------------------------------------------- /simulator/device/powerplug.go: -------------------------------------------------------------------------------- 1 | package device 2 | 3 | import ( 4 | "github.com/nickw444/miio-go/simulator/capability" 5 | ) 6 | 7 | type SimulatedPowerPlug struct { 8 | *BaseDevice 9 | } 10 | 11 | func NewSimulatedPowerPlug(baseDevice *BaseDevice) *SimulatedPowerPlug { 12 | baseDevice.AddCapability(&capability.Info{ 13 | Model: "chuangmi.plug.m1", 14 | }) 15 | baseDevice.AddCapability(&capability.Power{}) 16 
| return &SimulatedPowerPlug{ 17 | BaseDevice: baseDevice, 18 | } 19 | } 20 | -------------------------------------------------------------------------------- /simulator/device/yeelight.go: -------------------------------------------------------------------------------- 1 | package device 2 | 3 | import ( 4 | "github.com/nickw444/miio-go/simulator/capability" 5 | ) 6 | 7 | type SimulatedYeelight struct { 8 | *BaseDevice 9 | } 10 | 11 | func NewSimulatedYeelight(baseDevice *BaseDevice) *SimulatedYeelight { 12 | baseDevice.AddCapability(&capability.Info{ 13 | Model: "yeelink.light.color1", 14 | }) 15 | baseDevice.AddCapability(&capability.Power{}) 16 | baseDevice.AddCapability(&capability.Light{}) 17 | 18 | return &SimulatedYeelight{ 19 | BaseDevice: baseDevice, 20 | } 21 | } 22 | -------------------------------------------------------------------------------- /simulator/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "bytes" 5 | "encoding/hex" 6 | "net" 7 | 8 | "github.com/alecthomas/kingpin" 9 | "github.com/nickw444/miio-go/protocol/packet" 10 | "github.com/nickw444/miio-go/protocol/transport" 11 | "github.com/nickw444/miio-go/simulator/device" 12 | "github.com/sirupsen/logrus" 13 | ) 14 | 15 | var log = logrus.New() 16 | 17 | func main() { 18 | defaultToken := bytes.Repeat([]byte{0x00, 0xff}, 8) 19 | var ( 20 | deviceType = kingpin.Arg("device", "Device to simulate").Default("yeelight").Enum("yeelight", "powerplug") 21 | deviceId = kingpin.Flag("device-id", "Device ID for the simulated device").Default("12341234").Uint32() 22 | deviceToken = kingpin.Flag("device-token", "The device token to use for encrypted payloads").Default(hex.EncodeToString(defaultToken)).HexBytes() 23 | revealToken = kingpin.Flag("reveal-token", "Whether or not to reveal the device token").Default("true").Bool() 24 | ) 25 | 26 | kingpin.Parse() 27 | 28 | listenAddr := &net.UDPAddr{Port: 54321} 29 | s, err := 
net.ListenUDP("udp", listenAddr) 30 | if err != nil { 31 | panic(err) 32 | } 33 | 34 | inbound := transport.NewInbound(s) 35 | log.Infof("Creating device with id=%d token=%s revealToken=%t", 36 | *deviceId, hex.EncodeToString(*deviceToken), *revealToken) 37 | baseDev, err := device.NewBaseDevice(*deviceId, *deviceToken, *revealToken) 38 | if err != nil { 39 | panic(err) 40 | } 41 | 42 | var dev device.SimulatedDevice 43 | if *deviceType == "yeelight" { 44 | dev = device.NewSimulatedYeelight(baseDev) 45 | } else if *deviceType == "powerplug" { 46 | dev = device.NewSimulatedPowerPlug(baseDev) 47 | } else { 48 | panic("Unknown device type.") 49 | } 50 | 51 | for pkt := range inbound.Packets() { 52 | var resp *packet.Packet 53 | var err error 54 | if pkt.Header.DeviceID == 0xffffffff { 55 | log.Info("Discovery packet received") 56 | resp, err = dev.HandleDiscover(pkt) 57 | } else { 58 | resp, err = dev.HandlePacket(pkt) 59 | } 60 | 61 | if err != nil { 62 | panic(err) 63 | } 64 | if resp != nil { 65 | s.WriteToUDP(resp.Serialize(), pkt.Meta.Addr) 66 | } 67 | } 68 | } 69 | -------------------------------------------------------------------------------- /subscription/common/common.go: -------------------------------------------------------------------------------- 1 | package common 2 | 3 | type Subscription interface { 4 | ID() string 5 | Events() <-chan interface{} 6 | Write(event interface{}) error 7 | Close() error 8 | } 9 | 10 | type SubscriptionTarget interface { 11 | HasSubscribers() bool 12 | Publish(event interface{}) error 13 | NewSubscription() (Subscription, error) 14 | RemoveSubscription(s Subscription) error 15 | CloseAllSubscriptions() error 16 | } 17 | -------------------------------------------------------------------------------- /subscription/common/mocks/Subscription.go: -------------------------------------------------------------------------------- 1 | // Code generated by mockery v1.0.0. DO NOT EDIT. 
2 | 3 | package mocks 4 | 5 | import mock "github.com/stretchr/testify/mock" 6 | 7 | // Subscription is an autogenerated mock type for the Subscription type 8 | type Subscription struct { 9 | mock.Mock 10 | } 11 | 12 | // Close provides a mock function with given fields: 13 | func (_m *Subscription) Close() error { 14 | ret := _m.Called() 15 | 16 | var r0 error 17 | if rf, ok := ret.Get(0).(func() error); ok { 18 | r0 = rf() 19 | } else { 20 | r0 = ret.Error(0) 21 | } 22 | 23 | return r0 24 | } 25 | 26 | // Events provides a mock function with given fields: 27 | func (_m *Subscription) Events() <-chan interface{} { 28 | ret := _m.Called() 29 | 30 | var r0 <-chan interface{} 31 | if rf, ok := ret.Get(0).(func() <-chan interface{}); ok { 32 | r0 = rf() 33 | } else { 34 | if ret.Get(0) != nil { 35 | r0 = ret.Get(0).(<-chan interface{}) 36 | } 37 | } 38 | 39 | return r0 40 | } 41 | 42 | // ID provides a mock function with given fields: 43 | func (_m *Subscription) ID() string { 44 | ret := _m.Called() 45 | 46 | var r0 string 47 | if rf, ok := ret.Get(0).(func() string); ok { 48 | r0 = rf() 49 | } else { 50 | r0 = ret.Get(0).(string) 51 | } 52 | 53 | return r0 54 | } 55 | 56 | // Write provides a mock function with given fields: event 57 | func (_m *Subscription) Write(event interface{}) error { 58 | ret := _m.Called(event) 59 | 60 | var r0 error 61 | if rf, ok := ret.Get(0).(func(interface{}) error); ok { 62 | r0 = rf(event) 63 | } else { 64 | r0 = ret.Error(0) 65 | } 66 | 67 | return r0 68 | } 69 | -------------------------------------------------------------------------------- /subscription/common/mocks/SubscriptionTarget.go: -------------------------------------------------------------------------------- 1 | // Code generated by mockery v1.0.0. DO NOT EDIT. 
2 | 3 | package mocks 4 | 5 | import ( 6 | common "github.com/nickw444/miio-go/subscription/common" 7 | mock "github.com/stretchr/testify/mock" 8 | ) 9 | 10 | // SubscriptionTarget is an autogenerated mock type for the SubscriptionTarget type 11 | type SubscriptionTarget struct { 12 | mock.Mock 13 | } 14 | 15 | // CloseAllSubscriptions provides a mock function with given fields: 16 | func (_m *SubscriptionTarget) CloseAllSubscriptions() error { 17 | ret := _m.Called() 18 | 19 | var r0 error 20 | if rf, ok := ret.Get(0).(func() error); ok { 21 | r0 = rf() 22 | } else { 23 | r0 = ret.Error(0) 24 | } 25 | 26 | return r0 27 | } 28 | 29 | // HasSubscribers provides a mock function with given fields: 30 | func (_m *SubscriptionTarget) HasSubscribers() bool { 31 | ret := _m.Called() 32 | 33 | var r0 bool 34 | if rf, ok := ret.Get(0).(func() bool); ok { 35 | r0 = rf() 36 | } else { 37 | r0 = ret.Get(0).(bool) 38 | } 39 | 40 | return r0 41 | } 42 | 43 | // NewSubscription provides a mock function with given fields: 44 | func (_m *SubscriptionTarget) NewSubscription() (common.Subscription, error) { 45 | ret := _m.Called() 46 | 47 | var r0 common.Subscription 48 | if rf, ok := ret.Get(0).(func() common.Subscription); ok { 49 | r0 = rf() 50 | } else { 51 | if ret.Get(0) != nil { 52 | r0 = ret.Get(0).(common.Subscription) 53 | } 54 | } 55 | 56 | var r1 error 57 | if rf, ok := ret.Get(1).(func() error); ok { 58 | r1 = rf() 59 | } else { 60 | r1 = ret.Error(1) 61 | } 62 | 63 | return r0, r1 64 | } 65 | 66 | // Publish provides a mock function with given fields: event 67 | func (_m *SubscriptionTarget) Publish(event interface{}) error { 68 | ret := _m.Called(event) 69 | 70 | var r0 error 71 | if rf, ok := ret.Get(0).(func(interface{}) error); ok { 72 | r0 = rf(event) 73 | } else { 74 | r0 = ret.Error(0) 75 | } 76 | 77 | return r0 78 | } 79 | 80 | // RemoveSubscription provides a mock function with given fields: s 81 | func (_m *SubscriptionTarget) RemoveSubscription(s 
common.Subscription) error { 82 | ret := _m.Called(s) 83 | 84 | var r0 error 85 | if rf, ok := ret.Get(0).(func(common.Subscription) error); ok { 86 | r0 = rf(s) 87 | } else { 88 | r0 = ret.Error(0) 89 | } 90 | 91 | return r0 92 | } 93 | -------------------------------------------------------------------------------- /subscription/subscription.go: -------------------------------------------------------------------------------- 1 | package subscription 2 | 3 | import ( 4 | "github.com/nickw444/miio-go/subscription/common" 5 | "github.com/nickw444/miio-go/subscription/target" 6 | ) 7 | 8 | func NewTarget() common.SubscriptionTarget { 9 | return target.NewTarget() 10 | } 11 | 12 | type SubscriptionTarget = common.SubscriptionTarget 13 | type Subscription = common.Subscription 14 | -------------------------------------------------------------------------------- /subscription/subscription/subscription.go: -------------------------------------------------------------------------------- 1 | package subscription 2 | 3 | import ( 4 | "errors" 5 | "sync" 6 | "time" 7 | 8 | "github.com/nickw444/miio-go/subscription/common" 9 | "github.com/satori/go.uuid" 10 | ) 11 | 12 | const ( 13 | defaultTimeout = 2 * time.Second 14 | chanSize = 16 15 | ) 16 | 17 | var ( 18 | ErrClosed = errors.New("Subscription is already closed.") 19 | ErrTimeout = errors.New("Timed out.") 20 | ) 21 | 22 | type subscription struct { 23 | id uuid.UUID 24 | wg sync.WaitGroup 25 | quitChan chan struct{} 26 | events chan interface{} 27 | target common.SubscriptionTarget 28 | } 29 | 30 | func NewSubscription(target common.SubscriptionTarget) common.Subscription { 31 | return &subscription{ 32 | id: uuid.NewV4(), 33 | events: make(chan interface{}, chanSize), 34 | quitChan: make(chan struct{}), 35 | target: target, 36 | } 37 | } 38 | 39 | func (s *subscription) ID() string { 40 | return s.id.String() 41 | } 42 | 43 | func (s *subscription) Events() <-chan interface{} { 44 | return s.events 45 | } 46 | 47 | 
func (s *subscription) Write(event interface{}) error { 48 | s.wg.Add(1) 49 | defer s.wg.Done() 50 | timeout := time.After(defaultTimeout) 51 | select { 52 | case <-s.quitChan: 53 | return ErrClosed 54 | default: 55 | } 56 | select { 57 | case <-s.quitChan: 58 | return ErrClosed 59 | case s.events <- event: 60 | return nil 61 | case <-timeout: 62 | // TODO NW Warnings 63 | //Log.Debugf("Timeout on subscription %s", s.ID) 64 | return ErrTimeout 65 | } 66 | } 67 | 68 | func (s *subscription) Close() error { 69 | select { 70 | case <-s.quitChan: 71 | // TODO NW Warnings 72 | //Log.Warnf("Subscription %s already closed", s.ID) 73 | return ErrClosed 74 | default: 75 | close(s.quitChan) 76 | s.wg.Wait() 77 | close(s.events) 78 | } 79 | return s.target.RemoveSubscription(s) 80 | } 81 | -------------------------------------------------------------------------------- /subscription/subscription/subscription_test.go: -------------------------------------------------------------------------------- 1 | package subscription 2 | -------------------------------------------------------------------------------- /subscription/target/subscription_target.go: -------------------------------------------------------------------------------- 1 | package target 2 | 3 | import ( 4 | "github.com/nickw444/miio-go/subscription/common" 5 | "github.com/nickw444/miio-go/subscription/subscription" 6 | ) 7 | 8 | type subscriptionTarget struct { 9 | subscriptions map[string]common.Subscription 10 | } 11 | 12 | func NewTarget() common.SubscriptionTarget { 13 | return &subscriptionTarget{ 14 | subscriptions: make(map[string]common.Subscription), 15 | } 16 | } 17 | 18 | func (t *subscriptionTarget) HasSubscribers() bool { 19 | return len(t.subscriptions) > 0 20 | } 21 | 22 | func (t *subscriptionTarget) Publish(event interface{}) error { 23 | for _, sub := range t.subscriptions { 24 | err := sub.Write(event) 25 | if err != nil { 26 | return err 27 | } 28 | } 29 | return nil 30 | } 31 | 32 | func (t 
*subscriptionTarget) NewSubscription() (common.Subscription, error) { 33 | sub := subscription.NewSubscription(t) 34 | t.subscriptions[sub.ID()] = sub 35 | return sub, nil 36 | } 37 | 38 | func (t *subscriptionTarget) RemoveSubscription(s common.Subscription) error { 39 | delete(t.subscriptions, s.ID()) 40 | return nil 41 | } 42 | 43 | func (t *subscriptionTarget) CloseAllSubscriptions() error { 44 | for _, sub := range t.subscriptions { 45 | err := t.RemoveSubscription(sub) 46 | if err != nil { 47 | return err 48 | } 49 | err = sub.Close() 50 | if err != nil { 51 | return err 52 | } 53 | } 54 | 55 | return nil 56 | } 57 | -------------------------------------------------------------------------------- /subscription/target/subscription_target_test.go: -------------------------------------------------------------------------------- 1 | package target 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/nickw444/miio-go/subscription/common" 7 | "github.com/nickw444/miio-go/subscription/common/mocks" 8 | "github.com/stretchr/testify/assert" 9 | "github.com/stretchr/testify/mock" 10 | ) 11 | 12 | func SubscriptionTarget_Setup() (tt struct { 13 | target *subscriptionTarget 14 | }) { 15 | tt.target = &subscriptionTarget{ 16 | subscriptions: make(map[string]common.Subscription), 17 | } 18 | return 19 | } 20 | 21 | func TestSubscriptionTarget_HasSubscribers(t *testing.T) { 22 | tt := SubscriptionTarget_Setup() 23 | assert.False(t, tt.target.HasSubscribers()) 24 | 25 | s, err := tt.target.NewSubscription() 26 | assert.NoError(t, err) 27 | assert.True(t, tt.target.HasSubscribers()) 28 | 29 | s.Close() 30 | assert.NoError(t, err) 31 | assert.False(t, tt.target.HasSubscribers()) 32 | } 33 | 34 | // Ensure that an event is published to all subscriptions 35 | func TestSubscriptionTarget_Publish(t *testing.T) { 36 | tt := SubscriptionTarget_Setup() 37 | 38 | sub1 := &mocks.Subscription{} 39 | sub2 := &mocks.Subscription{} 40 | 41 | sub1.On("Write", mock.Anything).Return(nil).Once() 
42 | sub2.On("Write", mock.Anything).Return(nil).Once() 43 | 44 | tt.target.subscriptions = map[string]common.Subscription{ 45 | "01": sub1, 46 | "02": sub2, 47 | } 48 | 49 | err := tt.target.Publish(struct{}{}) 50 | assert.NoError(t, err) 51 | sub1.AssertExpectations(t) 52 | sub2.AssertExpectations(t) 53 | } 54 | 55 | // Ensure new subscriptions are tracked 56 | func TestSubscriptionTarget_NewSubscription(t *testing.T) { 57 | tt := SubscriptionTarget_Setup() 58 | 59 | assert.Len(t, tt.target.subscriptions, 0) 60 | s, err := tt.target.NewSubscription() 61 | assert.NoError(t, err) 62 | 63 | assert.Len(t, tt.target.subscriptions, 1) 64 | assert.Equal(t, s, tt.target.subscriptions[s.ID()]) 65 | } 66 | 67 | // Ensure subscriptions are correctly removed. 68 | func TestSubscriptionTarget_RemoveSubscription(t *testing.T) { 69 | tt := SubscriptionTarget_Setup() 70 | 71 | assert.Len(t, tt.target.subscriptions, 0) 72 | s1, err := tt.target.NewSubscription() 73 | assert.NoError(t, err) 74 | assert.Len(t, tt.target.subscriptions, 1) 75 | s2, err := tt.target.NewSubscription() 76 | assert.NoError(t, err) 77 | assert.Len(t, tt.target.subscriptions, 2) 78 | 79 | tt.target.RemoveSubscription(s1) 80 | assert.Len(t, tt.target.subscriptions, 1) 81 | tt.target.RemoveSubscription(s2) 82 | assert.Len(t, tt.target.subscriptions, 0) 83 | } 84 | 85 | // Ensure all subscriptions are closed 86 | func TestSubscriptionTarget_CloseAllSubscriptions(t *testing.T) { 87 | tt := SubscriptionTarget_Setup() 88 | 89 | sub1 := &mocks.Subscription{} 90 | sub2 := &mocks.Subscription{} 91 | sub1.On("Write", mock.Anything).Return(nil).Once() 92 | sub2.On("Write", mock.Anything).Return(nil).Once() 93 | sub1.On("ID").Return("01").Once() 94 | sub2.On("ID").Return("02").Once() 95 | sub1.On("Close").Return(nil).Once() 96 | sub2.On("Close").Return(nil).Once() 97 | tt.target.subscriptions = map[string]common.Subscription{ 98 | "01": sub1, 99 | "02": sub2, 100 | } 101 | 102 | err := 
tt.target.CloseAllSubscriptions() 103 | assert.NoError(t, err) 104 | 105 | assert.Len(t, tt.target.subscriptions, 0) 106 | } 107 | -------------------------------------------------------------------------------- /tokencapture/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "net" 5 | 6 | "time" 7 | 8 | "flag" 9 | 10 | "encoding/hex" 11 | 12 | "github.com/nickw444/miio-go/common" 13 | "github.com/nickw444/miio-go/device" 14 | "github.com/nickw444/miio-go/protocol/packet" 15 | "github.com/nickw444/miio-go/protocol/tokens" 16 | "github.com/nickw444/miio-go/protocol/transport" 17 | "github.com/sirupsen/logrus" 18 | ) 19 | 20 | var ( 21 | log = logrus.New() 22 | ignoredDevices = make(map[uint32]bool) 23 | t transport.Transport 24 | quitChan chan struct{} 25 | broadcastDev device.Device 26 | tokenStore tokens.TokenStore 27 | 28 | tokenStoreFile = flag.String("file", "tokens.txt", "Path to the token store to update") 29 | ) 30 | 31 | func main() { 32 | miioDebug := flag.Bool("miio-debug", false, "Enable miio debug") 33 | flag.Parse() 34 | 35 | if *miioDebug { 36 | miioLogger := logrus.New() 37 | miioLogger.SetLevel(logrus.DebugLevel) 38 | common.SetLogger(miioLogger) 39 | } 40 | 41 | var err error 42 | tokenStore, err = tokens.FromFile(*tokenStoreFile) 43 | if err != nil { 44 | log.Panic(err) 45 | } 46 | 47 | var listenAddr *net.UDPAddr 48 | s, err := net.ListenUDP("udp", listenAddr) 49 | if err != nil { 50 | panic(err) 51 | } 52 | t = transport.NewTransport(s) 53 | 54 | addr := &net.UDPAddr{ 55 | IP: net.IPv4(255, 255, 255, 255), 56 | Port: 54321, 57 | } 58 | devTsp := t.NewOutbound(nil, addr) 59 | broadcastDev = device.New(0, devTsp, time.Time{}, nil) 60 | 61 | go dispatcher() 62 | 63 | tick := time.Tick(5 * time.Second) 64 | broadcastDev.Discover() 65 | for { 66 | select { 67 | case <-quitChan: 68 | return 69 | default: 70 | } 71 | select { 72 | case <-quitChan: 73 | return 74 | case <-tick: 
75 | broadcastDev.Discover() 76 | } 77 | } 78 | } 79 | 80 | func dispatcher() { 81 | pkts := t.Inbound().Packets() 82 | for { 83 | select { 84 | case <-quitChan: 85 | return 86 | default: 87 | } 88 | 89 | select { 90 | case <-quitChan: 91 | return 92 | case pkt := <-pkts: 93 | go process(pkt) 94 | } 95 | } 96 | } 97 | 98 | func process(pkt *packet.Packet) { 99 | if _, ok := ignoredDevices[pkt.Header.DeviceID]; ok { 100 | return 101 | } 102 | 103 | if pkt.DataLength() == 0 { 104 | if pkt.HasZeroChecksum() { 105 | log.Warnf("Device with Id %d is not revealing its token. Reset this device and connect to its network to retrieve the token. Ignoring it.", pkt.Header.DeviceID) 106 | ignoredDevices[pkt.Header.DeviceID] = true 107 | return 108 | } else { 109 | tokenStore.AddDevice(pkt.Header.DeviceID, pkt.Header.Checksum) 110 | err := tokenStore.WriteFile(*tokenStoreFile) 111 | if err != nil { 112 | log.Panic(err) 113 | } 114 | log.Infof("Got token for device with id %d: %s", pkt.Header.DeviceID, hex.EncodeToString(pkt.Header.Checksum)) 115 | } 116 | } 117 | } 118 | -------------------------------------------------------------------------------- /tools/wireshark/miio.lua: -------------------------------------------------------------------------------- 1 | -- A Wireshark dissector written in for https://github.com/diacritic/wssdl. 2 | -- Install by adding to ~/.config/wireshark/plugins along with wssdl. 3 | 4 | local wssdl = require 'wssdl' 5 | 6 | miio = wssdl.packet 7 | { 8 | magic : u16() 9 | : hex(); 10 | length : u16(); 11 | unknown : u32() 12 | : hex(); 13 | deviceId : u32() 14 | : hex(); 15 | stamp : u32(); 16 | checksum : bytes(16); 17 | data : payload(magic); 18 | } 19 | 20 | wssdl.dissect { 21 | udp.port:set { 22 | [54321] = miio:proto('miio', 'Xiaomi Mi Home Binary Protocol') 23 | } 24 | } 25 | --------------------------------------------------------------------------------