├── .github ├── scripts │ └── da_monitor.sh └── workflows │ ├── audit.yml │ └── ci.yml ├── .gitignore ├── .vscode └── settings.json ├── Cargo.toml ├── README.md ├── mqttrust ├── Cargo.toml └── src │ ├── encoding │ ├── mod.rs │ └── v4 │ │ ├── connect.rs │ │ ├── decoder.rs │ │ ├── encoder.rs │ │ ├── mod.rs │ │ ├── packet.rs │ │ ├── publish.rs │ │ ├── subscribe.rs │ │ └── utils.rs │ ├── fmt.rs │ └── lib.rs └── mqttrust_core ├── Cargo.toml ├── examples ├── aws_device_advisor.rs ├── common │ ├── clock.rs │ ├── credentials.rs │ ├── mod.rs │ └── network.rs ├── echo.rs └── secrets │ ├── .gitignore │ ├── identity.pfx │ └── root-ca.pem └── src ├── client.rs ├── eventloop.rs ├── fmt.rs ├── lib.rs ├── max_payload.rs ├── options.rs ├── packet.rs └── state.rs /.github/scripts/da_monitor.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | # 3 | # This script is written in bash in conjunction with the AWS CLI so 4 | # you can be aware of the APIs in use in a simplified way and optimize 5 | # the run in the programming language of your choice. 6 | # 7 | 8 | STATUS_PASS=PASS 9 | STATUS_FAIL=FAIL 10 | STATUS_RUNNING=RUNNING 11 | STATUS_PENDING=PENDING 12 | STATUS_STOPPING=STOPPING 13 | STATUS_STOPPED=STOPPED 14 | STATUS_PASS_WITH_WARNINGS=PASS_WITH_WARNINGS 15 | STATUS_ERROR=ERROR 16 | 17 | # The suite definition ID should be set in the AWS CodeBuild environment. 18 | suite_definition_id=$1 19 | 20 | # The suite should be run prior to this script being invoked. 21 | suite_run_id=$2 22 | 23 | # The PID of the binary being tested. 
24 | pid=$3 25 | 26 | STATUS_FILE=/tmp/myapp-run-$$.status 27 | IN_PROGRESS=1 28 | MONITOR_STATUS=0 29 | function report_status { 30 | 31 | number_groups=$(jq -r ".testResult.groups | length" ${STATUS_FILE}) 32 | 33 | echo NUMBER TEST GROUPS: ${number_groups} 34 | 35 | for gn in $(seq 0 $((number_groups-1))); do 36 | number_tests=$(jq -r ".testResult.groups[$gn].tests | length" ${STATUS_FILE}) 37 | echo GROUP $((gn+1)) NUMBER OF TESTS: ${number_tests} 38 | 39 | for tcn in $(seq 0 $((number_tests-1))); do 40 | tcname=$(jq -r ".testResult.groups[$gn].tests[$tcn].testCaseDefinitionName" ${STATUS_FILE}) 41 | tcstatus=$(jq -r ".testResult.groups[$gn].tests[$tcn].status" ${STATUS_FILE}) 42 | echo ${tcname} ${tcstatus} 43 | done 44 | done 45 | } 46 | 47 | while test ${IN_PROGRESS} == 1; do 48 | # Fetch the current status and stash in /tmp so we can use it throughout the status fetch process. 49 | 50 | aws iotdeviceadvisor get-suite-run \ 51 | --suite-definition-id ${suite_definition_id} \ 52 | --suite-run-id ${suite_run_id} --output json > ${STATUS_FILE} 53 | 54 | # Identify the overall test status. If FAIL or PASS, emit the status 55 | # and exit here with the related error code (PASS=0, FAIL=1). 56 | # Otherwise continue and provide overall test group and test case 57 | # status. 58 | 59 | overall_status=$(jq -r ".status" ${STATUS_FILE}) 60 | 61 | echo OVERALL STATUS: ${overall_status} 62 | 63 | report_status 64 | 65 | if test x"${overall_status}" == x${STATUS_FAIL}; then 66 | MONITOR_STATUS=1 67 | IN_PROGRESS=0 68 | elif test x"${overall_status}" == x${STATUS_PASS}; then 69 | MONITOR_STATUS=0 70 | IN_PROGRESS=0 71 | elif test x"${overall_status}" == x${STATUS_STOPPING}; then 72 | MONITOR_STATUS=1 73 | IN_PROGRESS=0 74 | elif test x"${overall_status}" == x${STATUS_STOPPED}; then 75 | MONITOR_STATUS=1 76 | IN_PROGRESS=0 77 | elif { ps -p $pid > /dev/null; }; [ "$?" = 1 ]; then 78 | echo Binary is not running any more? 
79 | 80 | MONITOR_STATUS=1 81 | IN_PROGRESS=0 82 | else 83 | echo Sleeping 10 seconds for the next status. 84 | sleep 10 85 | fi 86 | done 87 | rm ${STATUS_FILE} 88 | exit ${MONITOR_STATUS} -------------------------------------------------------------------------------- /.github/workflows/audit.yml: -------------------------------------------------------------------------------- 1 | name: Security audit 2 | on: 3 | push: 4 | paths: 5 | - '**/Cargo.toml' 6 | - '**/Cargo.lock' 7 | jobs: 8 | security_audit: 9 | runs-on: ubuntu-latest 10 | steps: 11 | - uses: actions/checkout@v1 12 | - uses: actions-rs/audit-check@v1 13 | with: 14 | token: ${{ secrets.GITHUB_TOKEN }} 15 | -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: CI 2 | 3 | on: 4 | push: 5 | branches: 6 | - master 7 | pull_request: 8 | 9 | jobs: 10 | cancel_previous_runs: 11 | name: Cancel previous runs 12 | runs-on: ubuntu-latest 13 | steps: 14 | - uses: styfle/cancel-workflow-action@0.4.1 15 | with: 16 | access_token: ${{ secrets.GITHUB_TOKEN }} 17 | 18 | test: 19 | name: Test 20 | runs-on: ubuntu-latest 21 | steps: 22 | - name: Checkout source code 23 | uses: actions/checkout@v2 24 | 25 | - name: Install Rust 26 | uses: actions-rs/toolchain@v1 27 | with: 28 | profile: minimal 29 | toolchain: stable 30 | target: thumbv7m-none-eabi 31 | override: true 32 | 33 | - name: Build 34 | uses: actions-rs/cargo@v1 35 | with: 36 | command: build 37 | args: --all --target thumbv7m-none-eabi 38 | 39 | - name: Test 40 | uses: actions-rs/cargo@v1 41 | with: 42 | command: test 43 | args: --lib --features "log" 44 | integration: 45 | name: Integration Test 46 | runs-on: ubuntu-latest 47 | needs: test 48 | steps: 49 | - name: Checkout source code 50 | uses: actions/checkout@v2 51 | 52 | - name: Install Rust 53 | uses: actions-rs/toolchain@v1 54 | with: 55 | profile: minimal 56 | 
toolchain: stable 57 | target: thumbv7m-none-eabi 58 | override: true 59 | 60 | - name: Run integration test 61 | uses: actions-rs/cargo@v1 62 | with: 63 | command: run 64 | args: --features=log --example echo 65 | 66 | # device_advisor: 67 | # name: AWS IoT Device Advisor 68 | # runs-on: ubuntu-latest 69 | # needs: test 70 | # env: 71 | # AWS_EC2_METADATA_DISABLED: true 72 | # AWS_DEFAULT_REGION: ${{ secrets.MGMT_AWS_DEFAULT_REGION }} 73 | # AWS_ACCESS_KEY_ID: ${{ secrets.MGMT_AWS_ACCESS_KEY_ID }} 74 | # AWS_SECRET_ACCESS_KEY: ${{ secrets.MGMT_AWS_SECRET_ACCESS_KEY }} 75 | # SUITE_ID: greb3uy2wtq3 76 | # THING_ARN: arn:aws:iot:eu-west-1:274906834921:thing/mqttrust 77 | # CERTIFICATE_ARN: arn:aws:iot:eu-west-1:274906834921:cert/e7280d8d316b58da3058037a2c1730d9eb15de50e96f4d47e54ea655266b76db 78 | # steps: 79 | # - name: Checkout 80 | # uses: actions/checkout@v1 81 | 82 | # - name: Install Rust 83 | # uses: actions-rs/toolchain@v1 84 | # with: 85 | # profile: minimal 86 | # toolchain: stable 87 | # override: true 88 | 89 | # - name: Get AWS_HOSTNAME 90 | # id: hostname 91 | # run: | 92 | # hostname=$(aws iotdeviceadvisor get-endpoint --output text --query endpoint) 93 | # ret=$? 94 | # echo "::set-output name=AWS_HOSTNAME::$hostname" 95 | # exit $ret 96 | 97 | # - name: Build test binary 98 | # uses: actions-rs/cargo@v1 99 | # env: 100 | # AWS_HOSTNAME: ${{ steps.hostname.outputs.AWS_HOSTNAME }} 101 | # with: 102 | # command: build 103 | # args: --features=log --example aws_device_advisor --release 104 | 105 | # - name: Start test suite 106 | # id: test_suite 107 | # run: | 108 | # suite_id=$(aws iotdeviceadvisor start-suite-run --suite-definition-id ${{ env.SUITE_ID }} --suite-run-configuration "primaryDevice={thingArn=${{ env.THING_ARN }},certificateArn=${{ env.CERTIFICATE_ARN }}}" --output text --query suiteRunId) 109 | # ret=$? 
110 | # echo "::set-output name=SUITE_RUN_ID::$suite_id" 111 | # exit $ret 112 | 113 | # - name: Execute test binary 114 | # id: binary 115 | # env: 116 | # DEVICE_ADVISOR_PASSWORD: ${{ secrets.DEVICE_ADVISOR_PASSWORD }} 117 | # RUST_LOG: trace 118 | # run: | 119 | # nohup ./target/release/examples/aws_device_advisor > device_advisor_integration.log & 120 | # echo "::set-output name=PID::$!" 121 | 122 | # - name: Monitor test run 123 | # run: | 124 | # chmod +x ./.github/scripts/da_monitor.sh 125 | # echo ${{ env.SUITE_ID }} ${{ steps.test_suite.outputs.SUITE_RUN_ID }} ${{ steps.binary.outputs.PID }} 126 | # ./.github/scripts/da_monitor.sh ${{ env.SUITE_ID }} ${{ steps.test_suite.outputs.SUITE_RUN_ID }} ${{ steps.binary.outputs.PID }} 127 | 128 | # - name: Kill test binary process 129 | # if: ${{ always() }} 130 | # run: kill ${{ steps.binary.outputs.PID }} || true 131 | 132 | # - name: Log binary output 133 | # if: ${{ always() }} 134 | # run: cat device_advisor_integration.log 135 | 136 | # - name: Stop test suite 137 | # if: ${{ failure() }} 138 | # run: aws iotdeviceadvisor stop-suite-run --suite-definition-id ${{ env.SUITE_ID }} --suite-run-id ${{ steps.test_suite.outputs.SUITE_RUN_ID }} 139 | 140 | rustfmt: 141 | name: rustfmt 142 | runs-on: ubuntu-latest 143 | steps: 144 | - name: Checkout source code 145 | uses: actions/checkout@v2 146 | 147 | - name: Install Rust 148 | uses: actions-rs/toolchain@v1 149 | with: 150 | profile: minimal 151 | toolchain: nightly 152 | override: true 153 | components: rustfmt 154 | 155 | - name: Run rustfmt 156 | uses: actions-rs/cargo@v1 157 | with: 158 | command: fmt 159 | args: --all -- --check --verbose 160 | 161 | # tomlfmt: 162 | # name: tomlfmt 163 | # runs-on: ubuntu-latest 164 | # steps: 165 | # - name: Checkout source code 166 | # uses: actions/checkout@v2 167 | 168 | # - name: Install Rust 169 | # uses: actions-rs/toolchain@v1 170 | # with: 171 | # profile: minimal 172 | # toolchain: nightly 173 | # override: true 174 
| 175 | # - name: Install tomlfmt 176 | # uses: actions-rs/install@v0.1 177 | # with: 178 | # crate: cargo-tomlfmt 179 | # version: latest 180 | # use-tool-cache: true 181 | 182 | # - name: Run Tomlfmt 183 | # uses: actions-rs/cargo@v1 184 | # with: 185 | # command: tomlfmt 186 | # args: --dryrun 187 | 188 | clippy: 189 | name: clippy 190 | runs-on: ubuntu-latest 191 | env: 192 | CLIPPY_PARAMS: -W clippy::all -W clippy::pedantic -W clippy::nursery -W clippy::cargo 193 | steps: 194 | - name: Checkout source code 195 | uses: actions/checkout@v2 196 | 197 | - name: Install Rust 198 | uses: actions-rs/toolchain@v1 199 | with: 200 | profile: minimal 201 | toolchain: stable 202 | override: true 203 | components: clippy 204 | 205 | - name: Run clippy 206 | uses: actions-rs/clippy-check@v1 207 | with: 208 | token: ${{ secrets.GITHUB_TOKEN }} 209 | args: --features "log" -- ${{ env.CLIPPY_PARAMS }} 210 | # grcov: 211 | # name: Coverage 212 | # runs-on: ubuntu-latest 213 | # steps: 214 | # - name: Checkout source code 215 | # uses: actions/checkout@v2 216 | 217 | # - name: Install Rust 218 | # uses: actions-rs/toolchain@v1 219 | # with: 220 | # profile: minimal 221 | # toolchain: nightly 222 | # target: thumbv7m-none-eabi 223 | # override: true 224 | 225 | # - name: Install grcov 226 | # uses: actions-rs/cargo@v1 227 | # # uses: actions-rs/install@v0.1 228 | # with: 229 | # # crate: grcov 230 | # # version: latest 231 | # # use-tool-cache: true 232 | # command: install 233 | # args: grcov --git https://github.com/mozilla/grcov 234 | 235 | # - name: Test 236 | # uses: actions-rs/cargo@v1 237 | # with: 238 | # command: test 239 | # args: --lib --no-fail-fast --features "log" 240 | # env: 241 | # CARGO_INCREMENTAL: "0" 242 | # RUSTFLAGS: "-Zprofile -Ccodegen-units=1 -Copt-level=0 -Coverflow-checks=off -Cpanic=unwind -Zpanic_abort_tests" 243 | # RUSTDOCFLAGS: "-Zprofile -Ccodegen-units=1 -Cinline-threshold=0 -Coverflow-checks=off -Cpanic=unwind -Zpanic_abort_tests" 244 | 245 | # 
- name: Generate coverage data 246 | # id: grcov 247 | # # uses: actions-rs/grcov@v0.1 248 | # run: | 249 | # grcov target/debug/ \ 250 | # --branch \ 251 | # --llvm \ 252 | # --source-dir . \ 253 | # --output-file lcov.info \ 254 | # --ignore='/**' \ 255 | # --ignore='C:/**' \ 256 | # --ignore='../**' \ 257 | # --ignore-not-existing \ 258 | # --excl-line "#\\[derive\\(" \ 259 | # --excl-br-line "(#\\[derive\\()|(debug_assert)" \ 260 | # --excl-start "#\\[cfg\\(test\\)\\]" \ 261 | # --excl-br-start "#\\[cfg\\(test\\)\\]" \ 262 | # --commit-sha ${{ github.sha }} \ 263 | # --service-job-id ${{ github.job }} \ 264 | # --service-name "GitHub Actions" \ 265 | # --service-number ${{ github.run_id }} 266 | # - name: Upload coverage as artifact 267 | # uses: actions/upload-artifact@v2 268 | # with: 269 | # name: lcov.info 270 | # # path: ${{ steps.grcov.outputs.report }} 271 | # path: lcov.info 272 | 273 | # - name: Upload coverage to codecov.io 274 | # uses: codecov/codecov-action@v1 275 | # with: 276 | # # file: ${{ steps.grcov.outputs.report }} 277 | # file: lcov.info 278 | # fail_ci_if_error: true 279 | docs: 280 | name: Documentation 281 | runs-on: ubuntu-latest 282 | steps: 283 | - name: Checkout source code 284 | uses: actions/checkout@v2 285 | with: 286 | persist-credentials: false 287 | 288 | - name: Install Rust 289 | uses: actions-rs/toolchain@v1 290 | with: 291 | profile: minimal 292 | toolchain: nightly 293 | override: true 294 | 295 | - name: Build documentation 296 | uses: actions-rs/cargo@v1 297 | with: 298 | command: doc 299 | args: --verbose --no-deps 300 | 301 | # - name: Finalize documentation 302 | # run: | 303 | # CRATE_NAME=$(echo '${{ github.repository }}' | tr '[:upper:]' '[:lower:]' | cut -f2 -d"/") 304 | # echo "" > target/doc/index.html 305 | # touch target/doc/.nojekyll 306 | # - name: Upload as artifact 307 | # uses: actions/upload-artifact@v2 308 | # with: 309 | # name: Documentation 310 | # path: target/doc 311 | 312 | # - name: Deploy 313 | 
# uses: JamesIves/github-pages-deploy-action@releases/v3 314 | # with: 315 | # ACCESS_TOKEN: ${{ secrets.GH_PAT }} 316 | # BRANCH: gh-pages 317 | # FOLDER: target/doc 318 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | **/*.rs.bk 2 | .#* 3 | .gdb_history 4 | Cargo.lock 5 | target/ 6 | -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | // override the default setting (`cargo check --all-targets`) which produces the following error 3 | // "can't find crate for `test`" when the default compilation target is a no_std target 4 | // with these changes RA will call `cargo check --bins` on save 5 | "rust-analyzer.checkOnSave.allTargets": false, 6 | "rust-analyzer.cargo.target": "x86_64-unknown-linux-gnu", 7 | "rust-analyzer.diagnostics.disabled": [ 8 | "unresolved-import" 9 | ], 10 | "rust-analyzer.cargo.features": ["log"] 11 | } -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [workspace] 2 | members = [ 3 | "mqttrust", 4 | "mqttrust_core", 5 | ] 6 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MQTT Client for Embedded devices 2 | 3 | > no_std, no_alloc crate implementing secure MQTT Client capabilities. 
4 | 5 | ![Test][test] 6 | [![Code coverage][codecov-badge]][codecov] 7 | ![No Std][no-std-badge] 8 | [![Crates.io Version][crates-io-badge]][crates-io] 9 | [![Crates.io Downloads][crates-io-download-badge]][crates-io-download] 10 | 11 | This crate is heavily inspired by the great work in [rumqttc](https://github.com/bytebeamio/rumqtt/tree/master/rumqttc). 12 | 13 | ## Tests 14 | 15 | > The crate is covered by tests. These tests can be run with `cargo test --tests --all-features`, and are run by CI on every push to master and on every pull request. 16 | 17 | ## License 18 | 19 | Licensed under either of 20 | 21 | - Apache License, Version 2.0 ([LICENSE-APACHE](LICENSE-APACHE) or 22 | http://www.apache.org/licenses/LICENSE-2.0) 23 | - MIT license ([LICENSE-MIT](LICENSE-MIT) or http://opensource.org/licenses/MIT) 24 | 25 | at your option. 26 | 27 | ### Contribution 28 | 29 | Unless you explicitly state otherwise, any contribution intentionally submitted 30 | for inclusion in the work by you, as defined in the Apache-2.0 license, shall be 31 | dual licensed as above, without any additional terms or conditions.
32 | 33 | 34 | 35 | [test]: https://github.com/BlackbirdHQ/mqttrust/actions/workflows/ci.yml/badge.svg 36 | [no-std-badge]: https://img.shields.io/badge/no__std-yes-blue 37 | [codecov-badge]: https://codecov.io/gh/BlackbirdHQ/mqttrust/branch/master/graph/badge.svg 38 | [codecov]: https://codecov.io/gh/BlackbirdHQ/mqttrust 39 | [crates-io]: https://crates.io/crates/mqttrust 40 | [crates-io-badge]: https://img.shields.io/crates/v/mqttrust.svg?maxAge=3600 41 | [crates-io-download]: https://crates.io/crates/mqttrust 42 | [crates-io-download-badge]: https://img.shields.io/crates/d/mqttrust.svg?maxAge=3600 43 | 44 | -------------------------------------------------------------------------------- /mqttrust/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "mqttrust" 3 | version = "0.6.0" 4 | authors = ["Mathias Koch "] 5 | description = "MQTT Client " 6 | readme = "../README.md" 7 | keywords = ["mqtt", "no-std"] 8 | categories = ["embedded", "no-std"] 9 | license = "MIT OR Apache-2.0" 10 | repository = "https://github.com/BlackbirdHQ/mqttrust" 11 | edition = "2018" 12 | documentation = "https://docs.rs/mqttrust" 13 | 14 | [lib] 15 | name = "mqttrust" 16 | 17 | [badges] 18 | maintenance = { status = "actively-developed" } 19 | 20 | [dependencies] 21 | heapless = { version = "^0.7" } 22 | serde = { version = "1.0", features = ["derive"], optional = true } 23 | 24 | log = { version = "^0.4", default-features = false, optional = true } 25 | defmt = { version = "^0.3", optional = true } 26 | 27 | [features] 28 | default = [] 29 | 30 | defmt-impl = ["defmt", "heapless/defmt-impl"] 31 | 32 | derive = ["serde", "heapless/serde"] 33 | -------------------------------------------------------------------------------- /mqttrust/src/encoding/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod v4; 2 | --------------------------------------------------------------------------------
/mqttrust/src/encoding/v4/connect.rs: -------------------------------------------------------------------------------- 1 | use super::{decoder::*, encoder::*, *}; 2 | 3 | /// Protocol version. 4 | /// 5 | /// Sent in [`Connect`] packet. 6 | /// 7 | /// [`Connect`]: struct.Connect.html 8 | #[derive(Debug, Clone, Copy, PartialEq, Eq)] 9 | pub enum Protocol { 10 | /// [MQTT 3.1.1] is the most commonly implemented version. 11 | /// 12 | /// [MQTT 3.1.1]: https://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html 13 | /// [MQTT 5]: https://docs.oasis-open.org/mqtt/mqtt/v5.0/os/mqtt-v5.0-os.html 14 | MQTT311, 15 | /// MQIsdp, aka SCADA are pre-standardisation names of MQTT. It should mostly conform to MQTT 16 | /// 3.1.1, but you should watch out for implementation discrepancies. 17 | MQIsdp, 18 | } 19 | impl Protocol { 20 | pub(crate) fn new(name: &str, level: u8) -> Result { 21 | match (name, level) { 22 | ("MQIsdp", 3) => Ok(Protocol::MQIsdp), 23 | ("MQTT", 4) => Ok(Protocol::MQTT311), 24 | _ => Err(Error::InvalidProtocol(name.into(), level)), 25 | } 26 | } 27 | pub(crate) fn from_buffer(buf: &[u8], offset: &mut usize) -> Result { 28 | let protocol_name = read_str(buf, offset)?; 29 | let protocol_level = buf[*offset]; 30 | *offset += 1; 31 | 32 | Protocol::new(protocol_name, protocol_level) 33 | } 34 | pub(crate) fn to_buffer(&self, buf: &mut [u8], offset: &mut usize) -> Result { 35 | match self { 36 | Protocol::MQTT311 => { 37 | let slice = &[0u8, 4, b'M', b'Q', b'T', b'T', 4]; 38 | for &byte in slice { 39 | write_u8(buf, offset, byte)?; 40 | } 41 | Ok(slice.len()) 42 | } 43 | Protocol::MQIsdp => { 44 | let slice = &[0u8, 4, b'M', b'Q', b'i', b's', b'd', b'p', 4]; 45 | for &byte in slice { 46 | write_u8(buf, offset, byte)?; 47 | } 48 | Ok(slice.len()) 49 | } 50 | } 51 | } 52 | } 53 | 54 | /// Message that the server should publish when the client disconnects. 55 | /// 56 | /// Sent by the client in the [Connect] packet. [MQTT 3.1.3.3]. 
57 | /// 58 | /// [Connect]: struct.Connect.html 59 | /// [MQTT 3.1.3.3]: http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc398718031 60 | #[derive(Debug, Clone, PartialEq)] 61 | pub struct LastWill<'a> { 62 | pub topic: &'a str, 63 | pub message: &'a [u8], 64 | pub qos: QoS, 65 | pub retain: bool, 66 | } 67 | 68 | /// Sucess value of a [Connack] packet. 69 | /// 70 | /// See [MQTT 3.2.2.3] for interpretations. 71 | /// 72 | /// [Connack]: struct.Connack.html 73 | /// [MQTT 3.2.2.3]: http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc398718035 74 | #[derive(Debug, Clone, Copy, PartialEq)] 75 | #[cfg_attr(feature = "defmt-impl", derive(defmt::Format))] 76 | pub enum ConnectReturnCode { 77 | Accepted, 78 | RefusedProtocolVersion, 79 | RefusedIdentifierRejected, 80 | ServerUnavailable, 81 | BadUsernamePassword, 82 | NotAuthorized, 83 | } 84 | impl ConnectReturnCode { 85 | fn as_u8(&self) -> u8 { 86 | match *self { 87 | ConnectReturnCode::Accepted => 0, 88 | ConnectReturnCode::RefusedProtocolVersion => 1, 89 | ConnectReturnCode::RefusedIdentifierRejected => 2, 90 | ConnectReturnCode::ServerUnavailable => 3, 91 | ConnectReturnCode::BadUsernamePassword => 4, 92 | ConnectReturnCode::NotAuthorized => 5, 93 | } 94 | } 95 | pub(crate) fn from_u8(byte: u8) -> Result { 96 | match byte { 97 | 0 => Ok(ConnectReturnCode::Accepted), 98 | 1 => Ok(ConnectReturnCode::RefusedProtocolVersion), 99 | 2 => Ok(ConnectReturnCode::RefusedIdentifierRejected), 100 | 3 => Ok(ConnectReturnCode::ServerUnavailable), 101 | 4 => Ok(ConnectReturnCode::BadUsernamePassword), 102 | 5 => Ok(ConnectReturnCode::NotAuthorized), 103 | n => Err(Error::InvalidConnectReturnCode(n)), 104 | } 105 | } 106 | } 107 | 108 | /// Connect packet ([MQTT 3.1]). 
109 | /// 110 | /// [MQTT 3.1]: http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc398718028 111 | #[derive(Debug, Clone, PartialEq)] 112 | pub struct Connect<'a> { 113 | pub protocol: Protocol, 114 | pub keep_alive: u16, 115 | pub client_id: &'a str, 116 | pub clean_session: bool, 117 | pub last_will: Option>, 118 | pub username: Option<&'a str>, 119 | pub password: Option<&'a [u8]>, 120 | } 121 | 122 | /// Connack packet ([MQTT 3.2]). 123 | /// 124 | /// [MQTT 3.2]: http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc398718033 125 | #[derive(Debug, Clone, Copy, PartialEq)] 126 | pub struct Connack { 127 | pub session_present: bool, 128 | pub code: ConnectReturnCode, 129 | } 130 | 131 | impl<'a> Connect<'a> { 132 | pub(crate) fn from_buffer(buf: &'a [u8], offset: &mut usize) -> Result { 133 | let protocol = Protocol::from_buffer(buf, offset)?; 134 | 135 | let connect_flags = buf[*offset]; 136 | let keep_alive = ((buf[*offset + 1] as u16) << 8) | buf[*offset + 2] as u16; 137 | *offset += 3; 138 | 139 | let client_id = read_str(buf, offset)?; 140 | 141 | let last_will = if connect_flags & 0b100 != 0 { 142 | let will_topic = read_str(buf, offset)?; 143 | let will_message = read_bytes(buf, offset)?; 144 | let will_qod = QoS::from_u8((connect_flags & 0b11000) >> 3)?; 145 | Some(LastWill { 146 | topic: will_topic, 147 | message: will_message, 148 | qos: will_qod, 149 | retain: (connect_flags & 0b00100000) != 0, 150 | }) 151 | } else { 152 | None 153 | }; 154 | 155 | let username = if connect_flags & 0b10000000 != 0 { 156 | Some(read_str(buf, offset)?) 157 | } else { 158 | None 159 | }; 160 | 161 | let password = if connect_flags & 0b01000000 != 0 { 162 | Some(read_bytes(buf, offset)?) 
163 | } else { 164 | None 165 | }; 166 | 167 | let clean_session = (connect_flags & 0b10) != 0; 168 | 169 | Ok(Connect { 170 | protocol, 171 | keep_alive, 172 | client_id, 173 | clean_session, 174 | last_will, 175 | username, 176 | password, 177 | }) 178 | } 179 | 180 | pub(crate) fn len(&self) -> usize { 181 | let mut length: usize = 6 + 1 + 1; // NOTE: protocol_name(6) + protocol_level(1) + flags(1); 182 | length += 2 + self.client_id.len(); 183 | length += 2; // keep alive 184 | if let Some(username) = self.username { 185 | length += username.len(); 186 | length += 2; 187 | }; 188 | if let Some(password) = self.password { 189 | length += password.len(); 190 | length += 2; 191 | }; 192 | if let Some(last_will) = &self.last_will { 193 | length += last_will.message.len(); 194 | length += last_will.topic.len(); 195 | length += 4; 196 | }; 197 | length 198 | } 199 | 200 | pub(crate) fn to_buffer(&self, buf: &mut [u8], offset: &mut usize) -> Result { 201 | let header: u8 = 0b00010000; 202 | let mut connect_flags: u8 = 0b00000000; 203 | if self.clean_session { 204 | connect_flags |= 0b10; 205 | }; 206 | if self.username.is_some() { 207 | connect_flags |= 0b10000000; 208 | }; 209 | if self.password.is_some() { 210 | connect_flags |= 0b01000000; 211 | }; 212 | if let Some(last_will) = &self.last_will { 213 | connect_flags |= 0b00000100; 214 | connect_flags |= last_will.qos.as_u8() << 3; 215 | if last_will.retain { 216 | connect_flags |= 0b00100000; 217 | }; 218 | }; 219 | let length = self.len(); 220 | check_remaining(buf, offset, length + 1)?; 221 | 222 | // NOTE: putting data into buffer. 223 | write_u8(buf, offset, header)?; 224 | 225 | let write_len = write_length(buf, offset, length)? 
+ 1; 226 | self.protocol.to_buffer(buf, offset)?; 227 | 228 | write_u8(buf, offset, connect_flags)?; 229 | write_u16(buf, offset, self.keep_alive)?; 230 | 231 | write_string(buf, offset, self.client_id)?; 232 | 233 | if let Some(last_will) = &self.last_will { 234 | write_string(buf, offset, last_will.topic)?; 235 | write_bytes(buf, offset, &last_will.message)?; 236 | }; 237 | 238 | if let Some(username) = self.username { 239 | write_string(buf, offset, username)?; 240 | }; 241 | if let Some(password) = self.password { 242 | write_bytes(buf, offset, password)?; 243 | }; 244 | // NOTE: END 245 | Ok(write_len) 246 | } 247 | } 248 | 249 | impl Connack { 250 | pub(crate) fn from_buffer(buf: &[u8], offset: &mut usize) -> Result { 251 | let flags = buf[*offset]; 252 | let return_code = buf[*offset + 1]; 253 | *offset += 2; 254 | Ok(Connack { 255 | session_present: (flags & 0b1 == 1), 256 | code: ConnectReturnCode::from_u8(return_code)?, 257 | }) 258 | } 259 | pub(crate) fn to_buffer(&self, buf: &mut [u8], offset: &mut usize) -> Result { 260 | check_remaining(buf, offset, 4)?; 261 | let header: u8 = 0b00100000; 262 | let length: u8 = 2; 263 | let mut flags: u8 = 0b00000000; 264 | if self.session_present { 265 | flags |= 0b1; 266 | }; 267 | let rc = self.code.as_u8(); 268 | write_u8(buf, offset, header)?; 269 | write_u8(buf, offset, length)?; 270 | write_u8(buf, offset, flags)?; 271 | write_u8(buf, offset, rc)?; 272 | Ok(4) 273 | } 274 | } 275 | -------------------------------------------------------------------------------- /mqttrust/src/encoding/v4/decoder.rs: -------------------------------------------------------------------------------- 1 | use super::*; 2 | 3 | pub fn decode_slice(buf: &[u8]) -> Result>, Error> { 4 | let mut offset = 0; 5 | if let Some((header, remaining_len)) = read_header(buf, &mut offset)? 
{ 6 | let r = read_packet(header, remaining_len, buf, &mut offset)?; 7 | Ok(Some(r)) 8 | } else { 9 | // Don't have a full packet 10 | Ok(None) 11 | } 12 | } 13 | 14 | fn read_packet<'a>( 15 | header: Header, 16 | remaining_len: usize, 17 | buf: &'a [u8], 18 | offset: &mut usize, 19 | ) -> Result, Error> { 20 | Ok(match header.typ { 21 | PacketType::Pingreq => Packet::Pingreq, 22 | PacketType::Pingresp => Packet::Pingresp, 23 | PacketType::Disconnect => Packet::Disconnect, 24 | PacketType::Connect => Connect::from_buffer(buf, offset)?.into(), 25 | PacketType::Connack => Connack::from_buffer(buf, offset)?.into(), 26 | PacketType::Publish => Publish::from_buffer(&header, remaining_len, buf, offset)?.into(), 27 | PacketType::Puback => Packet::Puback(Pid::from_buffer(buf, offset)?), 28 | PacketType::Pubrec => Packet::Pubrec(Pid::from_buffer(buf, offset)?), 29 | PacketType::Pubrel => Packet::Pubrel(Pid::from_buffer(buf, offset)?), 30 | PacketType::Pubcomp => Packet::Pubcomp(Pid::from_buffer(buf, offset)?), 31 | PacketType::Subscribe => Subscribe::from_buffer(remaining_len, buf, offset)?.into(), 32 | PacketType::Suback => Suback::from_buffer(remaining_len, buf, offset)?.into(), 33 | PacketType::Unsubscribe => Unsubscribe::from_buffer(remaining_len, buf, offset)?.into(), 34 | PacketType::Unsuback => Packet::Unsuback(Pid::from_buffer(buf, offset)?), 35 | }) 36 | } 37 | 38 | /// Read the parsed header and remaining_len from the buffer. Only return Some() and advance the 39 | /// buffer position if there is enough data in the buffer to read the full packet. 
40 | pub fn read_header(buf: &[u8], offset: &mut usize) -> Result, Error> { 41 | let mut len: usize = 0; 42 | for pos in 0..=3 { 43 | if buf.len() > *offset + pos + 1 { 44 | let byte = buf[*offset + pos + 1]; 45 | len += (byte as usize & 0x7F) << (pos * 7); 46 | if (byte & 0x80) == 0 { 47 | // Continuation bit == 0, length is parsed 48 | if buf.len() < *offset + 2 + pos + len { 49 | // Won't be able to read full packet 50 | return Ok(None); 51 | } 52 | // Parse header byte, skip past the header, and return 53 | let header = Header::new(buf[*offset])?; 54 | *offset += pos + 2; 55 | return Ok(Some((header, len))); 56 | } 57 | } else { 58 | // Couldn't read full length 59 | return Ok(None); 60 | } 61 | } 62 | // Continuation byte == 1 four times, that's illegal. 63 | Err(Error::InvalidHeader) 64 | } 65 | 66 | #[derive(Debug, Clone, Copy, PartialEq, Eq)] 67 | pub struct Header { 68 | pub typ: PacketType, 69 | pub dup: bool, 70 | pub qos: QoS, 71 | pub retain: bool, 72 | } 73 | impl Header { 74 | pub fn new(hd: u8) -> Result { 75 | let (typ, flags_ok) = match hd >> 4 { 76 | 1 => (PacketType::Connect, hd & 0b1111 == 0), 77 | 2 => (PacketType::Connack, hd & 0b1111 == 0), 78 | 3 => (PacketType::Publish, true), 79 | 4 => (PacketType::Puback, hd & 0b1111 == 0), 80 | 5 => (PacketType::Pubrec, hd & 0b1111 == 0), 81 | 6 => (PacketType::Pubrel, hd & 0b1111 == 0b0010), 82 | 7 => (PacketType::Pubcomp, hd & 0b1111 == 0), 83 | 8 => (PacketType::Subscribe, hd & 0b1111 == 0b0010), 84 | 9 => (PacketType::Suback, hd & 0b1111 == 0), 85 | 10 => (PacketType::Unsubscribe, hd & 0b1111 == 0b0010), 86 | 11 => (PacketType::Unsuback, hd & 0b1111 == 0), 87 | 12 => (PacketType::Pingreq, hd & 0b1111 == 0), 88 | 13 => (PacketType::Pingresp, hd & 0b1111 == 0), 89 | 14 => (PacketType::Disconnect, hd & 0b1111 == 0), 90 | _ => (PacketType::Connect, false), 91 | }; 92 | if !flags_ok { 93 | return Err(Error::InvalidHeader); 94 | } 95 | Ok(Header { 96 | typ, 97 | dup: hd & 0b1000 != 0, 98 | qos: 
QoS::from_u8((hd & 0b110) >> 1)?, 99 | retain: hd & 1 == 1, 100 | }) 101 | } 102 | } 103 | 104 | pub(crate) fn read_str<'a>(buf: &'a [u8], offset: &mut usize) -> Result<&'a str, Error> { 105 | core::str::from_utf8(read_bytes(buf, offset)?).map_err(|_| Error::InvalidString) 106 | } 107 | 108 | pub(crate) fn read_bytes<'a>(buf: &'a [u8], offset: &mut usize) -> Result<&'a [u8], Error> { 109 | if buf[*offset..].len() < 2 { 110 | return Err(Error::InvalidLength); 111 | } 112 | let len = ((buf[*offset] as usize) << 8) | buf[*offset + 1] as usize; 113 | 114 | *offset += 2; 115 | if len > buf[*offset..].len() { 116 | Err(Error::InvalidLength) 117 | } else { 118 | let bytes = &buf[*offset..*offset + len]; 119 | *offset += len; 120 | Ok(bytes) 121 | } 122 | } 123 | -------------------------------------------------------------------------------- /mqttrust/src/encoding/v4/encoder.rs: -------------------------------------------------------------------------------- 1 | use super::{Error, Packet}; 2 | 3 | /// Encode a [Packet] enum into a u8 slice. 4 | /// 5 | /// ``` 6 | /// # use mqttrust::encoding::v4::*; 7 | /// // Instantiate a `Packet` to encode. 8 | /// let packet = Publish { 9 | /// dup: false, 10 | /// qos: QoS::AtMostOnce, 11 | /// retain: false, 12 | /// topic_name: "test", 13 | /// payload: b"hello", 14 | /// pid: None, 15 | /// }.into(); 16 | /// 17 | /// // Allocate buffer (should be appropriately-sized or able to grow as needed). 18 | /// let mut buf = [0u8; 1024]; 19 | /// 20 | /// // Write bytes corresponding to `&Packet` into the `BytesMut`. 
21 | /// let len = encode_slice(&packet, &mut buf).expect("failed encoding"); 22 | /// assert_eq!(&buf[..len], &[0b00110000, 11, 23 | /// 0, 4, b't', b'e', b's', b't', 24 | /// b'h', b'e', b'l', b'l', b'o']); 25 | /// ``` 26 | /// 27 | /// [Packet]: ../enum.Packet.html 28 | pub fn encode_slice(packet: &Packet, buf: &mut [u8]) -> Result { 29 | let mut offset = 0; 30 | 31 | match packet { 32 | Packet::Connect(connect) => connect.to_buffer(buf, &mut offset), 33 | Packet::Connack(connack) => connack.to_buffer(buf, &mut offset), 34 | Packet::Publish(publish) => publish.to_buffer(buf, &mut offset), 35 | Packet::Puback(pid) => { 36 | check_remaining(buf, &mut offset, 4)?; 37 | let header: u8 = 0b01000000; 38 | let length: u8 = 2; 39 | write_u8(buf, &mut offset, header)?; 40 | write_u8(buf, &mut offset, length)?; 41 | pid.to_buffer(buf, &mut offset)?; 42 | Ok(4) 43 | } 44 | Packet::Pubrec(pid) => { 45 | check_remaining(buf, &mut offset, 4)?; 46 | let header: u8 = 0b01010000; 47 | let length: u8 = 2; 48 | write_u8(buf, &mut offset, header)?; 49 | write_u8(buf, &mut offset, length)?; 50 | pid.to_buffer(buf, &mut offset)?; 51 | Ok(4) 52 | } 53 | Packet::Pubrel(pid) => { 54 | check_remaining(buf, &mut offset, 4)?; 55 | let header: u8 = 0b01100010; 56 | let length: u8 = 2; 57 | write_u8(buf, &mut offset, header)?; 58 | write_u8(buf, &mut offset, length)?; 59 | pid.to_buffer(buf, &mut offset)?; 60 | Ok(4) 61 | } 62 | Packet::Pubcomp(pid) => { 63 | check_remaining(buf, &mut offset, 4)?; 64 | let header: u8 = 0b01110000; 65 | let length: u8 = 2; 66 | write_u8(buf, &mut offset, header)?; 67 | write_u8(buf, &mut offset, length)?; 68 | pid.to_buffer(buf, &mut offset)?; 69 | Ok(4) 70 | } 71 | Packet::Subscribe(subscribe) => subscribe.to_buffer(buf, &mut offset), 72 | Packet::Suback(suback) => suback.to_buffer(buf, &mut offset), 73 | Packet::Unsubscribe(unsub) => unsub.to_buffer(buf, &mut offset), 74 | Packet::Unsuback(pid) => { 75 | check_remaining(buf, &mut offset, 4)?; 76 | let 
header: u8 = 0b10110000; 77 | let length: u8 = 2; 78 | write_u8(buf, &mut offset, header)?; 79 | write_u8(buf, &mut offset, length)?; 80 | pid.to_buffer(buf, &mut offset)?; 81 | Ok(4) 82 | } 83 | Packet::Pingreq => { 84 | check_remaining(buf, &mut offset, 2)?; 85 | let header: u8 = 0b11000000; 86 | let length: u8 = 0; 87 | write_u8(buf, &mut offset, header)?; 88 | write_u8(buf, &mut offset, length)?; 89 | Ok(2) 90 | } 91 | Packet::Pingresp => { 92 | check_remaining(buf, &mut offset, 2)?; 93 | let header: u8 = 0b11010000; 94 | let length: u8 = 0; 95 | write_u8(buf, &mut offset, header)?; 96 | write_u8(buf, &mut offset, length)?; 97 | Ok(2) 98 | } 99 | Packet::Disconnect => { 100 | check_remaining(buf, &mut offset, 2)?; 101 | let header: u8 = 0b11100000; 102 | let length: u8 = 0; 103 | write_u8(buf, &mut offset, header)?; 104 | write_u8(buf, &mut offset, length)?; 105 | Ok(2) 106 | } 107 | } 108 | } 109 | 110 | /// Check wether buffer has `len` bytes of write capacity left. Use this to return a clean 111 | /// Result::Err instead of panicking. 
112 | pub(crate) fn check_remaining(buf: &mut [u8], offset: &mut usize, len: usize) -> Result<(), Error> { 113 | if buf[*offset..].len() < len { 114 | Err(Error::WriteZero) 115 | } else { 116 | Ok(()) 117 | } 118 | } 119 | 120 | /// http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc398718023 121 | pub(crate) fn write_length(buf: &mut [u8], offset: &mut usize, len: usize) -> Result { 122 | let write_len = match len { 123 | 0..=127 => { 124 | check_remaining(buf, offset, len + 1)?; 125 | len + 1 126 | } 127 | 128..=16383 => { 128 | check_remaining(buf, offset, len + 2)?; 129 | len + 2 130 | } 131 | 16384..=2097151 => { 132 | check_remaining(buf, offset, len + 3)?; 133 | len + 3 134 | } 135 | 2097152..=268435455 => { 136 | check_remaining(buf, offset, len + 4)?; 137 | len + 4 138 | } 139 | _ => return Err(Error::InvalidLength), 140 | }; 141 | let mut done = false; 142 | let mut x = len; 143 | while !done { 144 | let mut byte = (x % 128) as u8; 145 | x /= 128; 146 | if x > 0 { 147 | byte |= 128; 148 | } 149 | write_u8(buf, offset, byte)?; 150 | done = x == 0; 151 | } 152 | Ok(write_len) 153 | } 154 | 155 | pub(crate) fn write_u8(buf: &mut [u8], offset: &mut usize, val: u8) -> Result<(), Error> { 156 | buf[*offset] = val; 157 | *offset += 1; 158 | Ok(()) 159 | } 160 | 161 | pub(crate) fn write_u16(buf: &mut [u8], offset: &mut usize, val: u16) -> Result<(), Error> { 162 | write_u8(buf, offset, (val >> 8) as u8)?; 163 | write_u8(buf, offset, (val & 0xFF) as u8) 164 | } 165 | 166 | pub(crate) fn write_bytes(buf: &mut [u8], offset: &mut usize, bytes: &[u8]) -> Result<(), Error> { 167 | write_u16(buf, offset, bytes.len() as u16)?; 168 | 169 | for &byte in bytes { 170 | write_u8(buf, offset, byte)?; 171 | } 172 | Ok(()) 173 | } 174 | 175 | pub(crate) fn write_string(buf: &mut [u8], offset: &mut usize, string: &str) -> Result<(), Error> { 176 | write_bytes(buf, offset, string.as_bytes()) 177 | } 178 | 
-------------------------------------------------------------------------------- /mqttrust/src/encoding/v4/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod connect; 2 | pub mod decoder; 3 | pub mod encoder; 4 | pub mod packet; 5 | pub mod publish; 6 | pub mod subscribe; 7 | pub mod utils; 8 | 9 | pub use { 10 | connect::{Connack, Connect, ConnectReturnCode, LastWill, Protocol}, 11 | decoder::decode_slice, 12 | encoder::encode_slice, 13 | packet::{Packet, PacketType}, 14 | publish::Publish, 15 | subscribe::{Suback, Subscribe, SubscribeReturnCodes, SubscribeTopic, Unsubscribe}, 16 | utils::{Error, Pid, QoS, QosPid}, 17 | }; 18 | -------------------------------------------------------------------------------- /mqttrust/src/encoding/v4/packet.rs: -------------------------------------------------------------------------------- 1 | use super::*; 2 | 3 | /// https://docs.solace.com/MQTT-311-Prtl-Conformance-Spec/MQTT%20Control%20Packets.htm#_Toc430864901 4 | const FIXED_HEADER_LEN: usize = 5; 5 | const PID_LEN: usize = 2; 6 | 7 | /// Base enum for all MQTT packet types. 8 | /// 9 | /// This is the main type you'll be interacting with, as an output of [`decode_slice()`] and an input of 10 | /// [`encode()`]. Most variants can be constructed directly without using methods. 
11 | /// 12 | /// ``` 13 | /// # use mqttrust::encoding::v4::*; 14 | /// # use core::convert::TryFrom; 15 | /// // Simplest form 16 | /// let pkt = Packet::Connack(Connack { session_present: false, 17 | /// code: ConnectReturnCode::Accepted }); 18 | /// // Using `Into` trait 19 | /// let publish = Publish { dup: false, 20 | /// qos: QoS::AtMostOnce, 21 | /// retain: false, 22 | /// pid: None, 23 | /// topic_name: "to/pic", 24 | /// payload: b"payload" }; 25 | /// let pkt: Packet = publish.into(); 26 | /// // Identifyer-only packets 27 | /// let pkt = Packet::Puback(Pid::try_from(42).unwrap()); 28 | /// ``` 29 | /// 30 | /// [`encode()`]: fn.encode.html 31 | /// [`decode_slice()`]: fn.decode_slice.html 32 | #[derive(Debug, Clone, PartialEq)] 33 | pub enum Packet<'a> { 34 | /// [MQTT 3.1](http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc398718028) 35 | Connect(Connect<'a>), 36 | /// [MQTT 3.2](http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc398718033) 37 | Connack(Connack), 38 | /// [MQTT 3.3](http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc398718037) 39 | Publish(Publish<'a>), 40 | /// [MQTT 3.4](http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc398718043) 41 | Puback(Pid), 42 | /// [MQTT 3.5](http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc398718048) 43 | Pubrec(Pid), 44 | /// [MQTT 3.6](http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc398718053) 45 | Pubrel(Pid), 46 | /// [MQTT 3.7](http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc398718058) 47 | Pubcomp(Pid), 48 | /// [MQTT 3.8](http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc398718063) 49 | Subscribe(Subscribe<'a>), 50 | /// [MQTT 3.9](http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc398718068) 51 | Suback(Suback<'a>), 52 | /// [MQTT 
3.10](http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc398718072) 53 | Unsubscribe(Unsubscribe<'a>), 54 | /// [MQTT 3.11](http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc398718077) 55 | Unsuback(Pid), 56 | /// [MQTT 3.12](http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc398718081) 57 | Pingreq, 58 | /// [MQTT 3.13](http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc398718086) 59 | Pingresp, 60 | /// [MQTT 3.14](http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc398718090) 61 | Disconnect, 62 | } 63 | 64 | impl<'a> Packet<'a> { 65 | /// Return the packet type variant. 66 | /// 67 | /// This can be used for matching, categorising, debuging, etc. Most users will match directly 68 | /// on `Packet` instead. 69 | pub fn get_type(&self) -> PacketType { 70 | match self { 71 | Packet::Connect(_) => PacketType::Connect, 72 | Packet::Connack(_) => PacketType::Connack, 73 | Packet::Publish(_) => PacketType::Publish, 74 | Packet::Puback(_) => PacketType::Puback, 75 | Packet::Pubrec(_) => PacketType::Pubrec, 76 | Packet::Pubrel(_) => PacketType::Pubrel, 77 | Packet::Pubcomp(_) => PacketType::Pubcomp, 78 | Packet::Subscribe(_) => PacketType::Subscribe, 79 | Packet::Suback(_) => PacketType::Suback, 80 | Packet::Unsubscribe(_) => PacketType::Unsubscribe, 81 | Packet::Unsuback(_) => PacketType::Unsuback, 82 | Packet::Pingreq => PacketType::Pingreq, 83 | Packet::Pingresp => PacketType::Pingresp, 84 | Packet::Disconnect => PacketType::Disconnect, 85 | } 86 | } 87 | 88 | pub fn len(&self) -> usize { 89 | let variable_len = match self { 90 | Packet::Connect(c) => c.len(), 91 | Packet::Connack(_) => 2, 92 | Packet::Publish(p) => p.len(), 93 | Packet::Puback(_) 94 | | Packet::Pubrec(_) 95 | | Packet::Pubrel(_) 96 | | Packet::Pubcomp(_) 97 | | Packet::Unsuback(_) => PID_LEN, 98 | Packet::Suback(_) => PID_LEN + 1, 99 | Packet::Subscribe(s) => s.len(), 100 | Packet::Unsubscribe(u) => 
u.len(), 101 | Packet::Pingreq | Packet::Pingresp | Packet::Disconnect => 0, 102 | }; 103 | 104 | FIXED_HEADER_LEN + variable_len 105 | } 106 | } 107 | 108 | macro_rules! packet_from_borrowed { 109 | ($($t:ident),+) => { 110 | $( 111 | impl<'a> From<$t<'a>> for Packet<'a> { 112 | fn from(p: $t<'a>) -> Self { 113 | Packet::$t(p) 114 | } 115 | } 116 | )+ 117 | } 118 | } 119 | macro_rules! packet_from { 120 | ($($t:ident),+) => { 121 | $( 122 | impl<'a> From<$t> for Packet<'a> { 123 | fn from(p: $t) -> Self { 124 | Packet::$t(p) 125 | } 126 | } 127 | )+ 128 | } 129 | } 130 | 131 | packet_from_borrowed!(Suback, Connect, Publish, Subscribe, Unsubscribe); 132 | packet_from!(Connack); 133 | 134 | /// Packet type variant, without the associated data. 135 | #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] 136 | #[cfg_attr(feature = "defmt-impl", derive(defmt::Format))] 137 | pub enum PacketType { 138 | Connect, 139 | Connack, 140 | Publish, 141 | Puback, 142 | Pubrec, 143 | Pubrel, 144 | Pubcomp, 145 | Subscribe, 146 | Suback, 147 | Unsubscribe, 148 | Unsuback, 149 | Pingreq, 150 | Pingresp, 151 | Disconnect, 152 | } 153 | -------------------------------------------------------------------------------- /mqttrust/src/encoding/v4/publish.rs: -------------------------------------------------------------------------------- 1 | use super::{decoder::*, encoder::*, *}; 2 | 3 | /// Publish packet ([MQTT 3.3]). 
4 | /// 5 | /// [MQTT 3.3]: http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc398718037 6 | #[derive(Debug, Clone, PartialEq)] 7 | pub struct Publish<'a> { 8 | pub dup: bool, 9 | pub qos: QoS, 10 | pub pid: Option, 11 | pub retain: bool, 12 | pub topic_name: &'a str, 13 | pub payload: &'a [u8], 14 | } 15 | 16 | impl<'a> Publish<'a> { 17 | pub(crate) fn from_buffer( 18 | header: &Header, 19 | remaining_len: usize, 20 | buf: &'a [u8], 21 | offset: &mut usize, 22 | ) -> Result { 23 | let payload_end = *offset + remaining_len; 24 | let topic_name = read_str(buf, offset)?; 25 | 26 | let (qos, pid) = match header.qos { 27 | QoS::AtMostOnce => (QoS::AtMostOnce, None), 28 | QoS::AtLeastOnce => (QoS::AtLeastOnce, Some(Pid::from_buffer(buf, offset)?)), 29 | QoS::ExactlyOnce => (QoS::ExactlyOnce, Some(Pid::from_buffer(buf, offset)?)), 30 | }; 31 | 32 | Ok(Publish { 33 | dup: header.dup, 34 | qos, 35 | pid, 36 | retain: header.retain, 37 | topic_name, 38 | payload: &buf[*offset..payload_end], 39 | }) 40 | } 41 | 42 | pub(crate) fn len(&self) -> usize { 43 | // Length: topic (2+len) + pid (0/2) + payload (len) 44 | 2 + self.topic_name.len() 45 | + match self.qos { 46 | QoS::AtMostOnce => 0, 47 | _ => 2, 48 | } 49 | + self.payload.len() 50 | } 51 | 52 | pub(crate) fn to_buffer(&self, buf: &mut [u8], offset: &mut usize) -> Result { 53 | // Header 54 | let mut header: u8 = match self.qos { 55 | QoS::AtMostOnce => 0b00110000, 56 | QoS::AtLeastOnce => 0b00110010, 57 | QoS::ExactlyOnce => 0b00110100, 58 | }; 59 | if self.dup { 60 | header |= 0b00001000_u8; 61 | }; 62 | if self.retain { 63 | header |= 0b00000001_u8; 64 | }; 65 | check_remaining(buf, offset, 1)?; 66 | write_u8(buf, offset, header)?; 67 | 68 | let length = self.len(); 69 | let write_len = write_length(buf, offset, length)? 
+ 1; 70 | 71 | // Topic 72 | write_string(buf, offset, self.topic_name)?; 73 | 74 | // Pid to be overwritten later on 75 | match self.qos { 76 | QoS::AtMostOnce => (), 77 | QoS::AtLeastOnce => { 78 | write_u16(buf, offset, 0)?; 79 | } 80 | QoS::ExactlyOnce => { 81 | write_u16(buf, offset, 0)?; 82 | } 83 | } 84 | 85 | // Payload 86 | for &byte in self.payload { 87 | write_u8(buf, offset, byte)?; 88 | } 89 | 90 | Ok(write_len) 91 | } 92 | } 93 | -------------------------------------------------------------------------------- /mqttrust/src/encoding/v4/subscribe.rs: -------------------------------------------------------------------------------- 1 | use core::marker::PhantomData; 2 | 3 | use super::{decoder::*, encoder::*, *}; 4 | 5 | /// Subscribe topic. 6 | /// 7 | /// [Subscribe] packets contain a `Vec` of those. 8 | /// 9 | /// [Subscribe]: struct.Subscribe.html 10 | #[derive(Debug, Clone, PartialEq)] 11 | pub struct SubscribeTopic<'a> { 12 | pub topic_path: &'a str, 13 | pub qos: QoS, 14 | } 15 | 16 | impl<'a> FromBuffer<'a> for SubscribeTopic<'a> { 17 | type Item = Self; 18 | 19 | fn from_buffer(buf: &'a [u8], offset: &mut usize) -> Result { 20 | let topic_path = read_str(buf, offset)?; 21 | let qos = QoS::from_u8(buf[*offset])?; 22 | *offset += 1; 23 | Ok(SubscribeTopic { topic_path, qos }) 24 | } 25 | } 26 | 27 | impl<'a> FromBuffer<'a> for &'a str { 28 | type Item = Self; 29 | 30 | fn from_buffer(buf: &'a [u8], offset: &mut usize) -> Result { 31 | read_str(buf, offset) 32 | } 33 | } 34 | 35 | pub trait FromBuffer<'a> { 36 | type Item; 37 | fn from_buffer(buf: &'a [u8], offset: &mut usize) -> Result; 38 | } 39 | 40 | /// Subscribe return value. 41 | /// 42 | /// [Suback] packets contain a `Vec` of those. 
43 | /// 44 | /// [Suback]: struct.Subscribe.html 45 | #[derive(Debug, Clone, Copy, PartialEq, Eq)] 46 | pub enum SubscribeReturnCodes { 47 | Success(QoS), 48 | Failure, 49 | } 50 | 51 | impl<'a> FromBuffer<'a> for SubscribeReturnCodes { 52 | type Item = Self; 53 | 54 | fn from_buffer(buf: &'a [u8], offset: &mut usize) -> Result { 55 | let code = buf[*offset]; 56 | *offset += 1; 57 | 58 | if code == 0x80 { 59 | Ok(SubscribeReturnCodes::Failure) 60 | } else { 61 | Ok(SubscribeReturnCodes::Success(QoS::from_u8(code)?)) 62 | } 63 | } 64 | } 65 | 66 | impl SubscribeReturnCodes { 67 | pub(crate) fn as_u8(&self) -> u8 { 68 | match *self { 69 | SubscribeReturnCodes::Failure => 0x80, 70 | SubscribeReturnCodes::Success(qos) => qos.as_u8(), 71 | } 72 | } 73 | } 74 | 75 | #[derive(Debug, Clone, PartialEq)] 76 | pub enum List<'a, T> { 77 | Owned(&'a [T]), 78 | Lazy(LazyList<'a, T>), 79 | } 80 | 81 | impl<'a, T> List<'a, T> 82 | where 83 | T: FromBuffer<'a, Item = T>, 84 | { 85 | pub fn len(&self) -> usize { 86 | match self { 87 | List::Owned(data) => data.len(), 88 | List::Lazy(data) => { 89 | let mut len = 0; 90 | let mut offset = 0; 91 | while T::from_buffer(data.0, &mut offset).is_ok() { 92 | len += 1; 93 | } 94 | len 95 | } 96 | } 97 | } 98 | } 99 | 100 | impl<'a, T> IntoIterator for &'a List<'a, T> 101 | where 102 | T: FromBuffer<'a, Item = T> + Clone, 103 | { 104 | type Item = T; 105 | 106 | type IntoIter = ListIter<'a, T>; 107 | 108 | fn into_iter(self) -> Self::IntoIter { 109 | ListIter { 110 | list: self, 111 | index: 0, 112 | } 113 | } 114 | } 115 | 116 | #[derive(Debug, Clone, PartialEq)] 117 | pub struct LazyList<'a, T>(&'a [u8], PhantomData); 118 | 119 | pub struct ListIter<'a, T> { 120 | list: &'a List<'a, T>, 121 | index: usize, 122 | } 123 | 124 | impl<'a, T> Iterator for ListIter<'a, T> 125 | where 126 | T: FromBuffer<'a, Item = T> + Clone, 127 | { 128 | type Item = T; 129 | 130 | fn next(&mut self) -> Option { 131 | match self.list { 132 | List::Owned(data) 
=> { 133 | // FIXME: Can we get rid of this clone? 134 | let item = data.get(self.index).cloned(); 135 | self.index += 1; 136 | item 137 | } 138 | List::Lazy(data) => T::from_buffer(data.0, &mut self.index).ok(), 139 | } 140 | } 141 | } 142 | 143 | /// Subscribe packet ([MQTT 3.8]). 144 | /// 145 | /// [MQTT 3.8]: http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc398718063 146 | #[derive(Debug, Clone, PartialEq)] 147 | pub struct Subscribe<'a> { 148 | pid: Option, 149 | topics: List<'a, SubscribeTopic<'a>>, 150 | } 151 | 152 | /// Subsack packet ([MQTT 3.9]). 153 | /// 154 | /// [MQTT 3.9]: http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc398718068 155 | #[derive(Debug, Clone, PartialEq)] 156 | pub struct Suback<'a> { 157 | pub pid: Pid, 158 | pub return_codes: &'a [SubscribeReturnCodes], 159 | } 160 | 161 | /// Unsubscribe packet ([MQTT 3.10]). 162 | /// 163 | /// [MQTT 3.10]: http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc398718072 164 | #[derive(Debug, Clone, PartialEq)] 165 | pub struct Unsubscribe<'a> { 166 | pub pid: Option, 167 | pub topics: List<'a, &'a str>, 168 | } 169 | 170 | impl<'a> Subscribe<'a> { 171 | pub fn new(topics: &'a [SubscribeTopic<'a>]) -> Self { 172 | Self { 173 | pid: None, 174 | topics: List::Owned(topics), 175 | } 176 | } 177 | 178 | pub fn topics(&self) -> impl Iterator> { 179 | self.topics.into_iter() 180 | } 181 | 182 | pub fn pid(&self) -> Option { 183 | self.pid 184 | } 185 | 186 | pub(crate) fn from_buffer( 187 | remaining_len: usize, 188 | buf: &'a [u8], 189 | offset: &mut usize, 190 | ) -> Result { 191 | let payload_end = *offset + remaining_len; 192 | let pid = Pid::from_buffer(buf, offset)?; 193 | 194 | Ok(Subscribe { 195 | pid: Some(pid), 196 | topics: List::Lazy(LazyList(&buf[*offset..payload_end], PhantomData)), 197 | }) 198 | } 199 | 200 | /// Length: pid(2) + topic.for_each(2+len + qos(1)) 201 | pub(crate) fn len(&self) -> usize { 202 | let mut length = 2; 
203 | for topic in self.topics() { 204 | length += topic.topic_path.len() + 2 + 1; 205 | } 206 | length 207 | } 208 | 209 | pub(crate) fn to_buffer(&self, buf: &mut [u8], offset: &mut usize) -> Result { 210 | let header: u8 = 0b10000010; 211 | check_remaining(buf, offset, 1)?; 212 | write_u8(buf, offset, header)?; 213 | 214 | let write_len = write_length(buf, offset, self.len())? + 1; 215 | 216 | // Pid 217 | self.pid.unwrap_or_default().to_buffer(buf, offset)?; 218 | 219 | // Topics 220 | for topic in self.topics() { 221 | write_string(buf, offset, topic.topic_path)?; 222 | write_u8(buf, offset, topic.qos.as_u8())?; 223 | } 224 | 225 | Ok(write_len) 226 | } 227 | } 228 | 229 | impl<'a> Unsubscribe<'a> { 230 | pub fn new(topics: &'a [&'a str]) -> Self { 231 | Self { 232 | pid: None, 233 | topics: List::Owned(topics), 234 | } 235 | } 236 | 237 | pub fn topics(&self) -> impl Iterator { 238 | self.topics.into_iter() 239 | } 240 | 241 | pub fn pid(&self) -> Option { 242 | self.pid 243 | } 244 | 245 | pub(crate) fn from_buffer( 246 | remaining_len: usize, 247 | buf: &'a [u8], 248 | offset: &mut usize, 249 | ) -> Result { 250 | let payload_end = *offset + remaining_len; 251 | let pid = Pid::from_buffer(buf, offset)?; 252 | 253 | Ok(Unsubscribe { 254 | pid: Some(pid), 255 | topics: List::Lazy(LazyList(&buf[*offset..payload_end], PhantomData)), 256 | }) 257 | } 258 | 259 | /// Length: pid(2) + topic.for_each(2+len) 260 | pub(crate) fn len(&self) -> usize { 261 | let mut length = 2; 262 | for topic in self.topics() { 263 | length += 2 + topic.len(); 264 | } 265 | length 266 | } 267 | 268 | pub(crate) fn to_buffer(&self, buf: &mut [u8], offset: &mut usize) -> Result { 269 | let header: u8 = 0b10100010; 270 | 271 | check_remaining(buf, offset, 1)?; 272 | write_u8(buf, offset, header)?; 273 | 274 | let write_len = write_length(buf, offset, self.len())? 
+ 1; 275 | 276 | // Pid 277 | self.pid.unwrap_or_default().to_buffer(buf, offset)?; 278 | 279 | for topic in self.topics() { 280 | write_string(buf, offset, topic)?; 281 | } 282 | Ok(write_len) 283 | } 284 | } 285 | 286 | impl<'a> Suback<'a> { 287 | pub(crate) fn from_buffer( 288 | _remaining_len: usize, 289 | buf: &'a [u8], 290 | offset: &mut usize, 291 | ) -> Result { 292 | // FIXME: 293 | // let payload_end = *offset + remaining_len; 294 | let pid = Pid::from_buffer(buf, offset)?; 295 | 296 | // let mut return_codes = LimitedVec::new(); 297 | // while *offset < payload_end { 298 | // let _res = return_codes.push(SubscribeReturnCodes::from_buffer(buf, offset)?); 299 | // } 300 | 301 | Ok(Suback { 302 | pid, 303 | return_codes: &[], 304 | }) 305 | } 306 | 307 | pub(crate) fn to_buffer(&self, buf: &mut [u8], offset: &mut usize) -> Result { 308 | let header: u8 = 0b10010000; 309 | let length = 2 + self.return_codes.len(); 310 | check_remaining(buf, offset, 1)?; 311 | write_u8(buf, offset, header)?; 312 | 313 | let write_len = write_length(buf, offset, length)? + 1; 314 | self.pid.to_buffer(buf, offset)?; 315 | for rc in self.return_codes { 316 | write_u8(buf, offset, rc.as_u8())?; 317 | } 318 | Ok(write_len) 319 | } 320 | } 321 | -------------------------------------------------------------------------------- /mqttrust/src/encoding/v4/utils.rs: -------------------------------------------------------------------------------- 1 | use super::encoder::write_u16; 2 | use core::{convert::TryFrom, fmt, num::NonZeroU16}; 3 | 4 | #[cfg(feature = "derive")] 5 | use serde::{Deserialize, Serialize}; 6 | 7 | /// Errors returned by [`encode()`] and [`decode()`]. 8 | /// 9 | /// [`encode()`]: fn.encode.html 10 | /// [`decode()`]: fn.decode.html 11 | #[derive(Debug, Clone, PartialEq, Eq)] 12 | #[cfg_attr(feature = "defmt-impl", derive(defmt::Format))] 13 | pub enum Error { 14 | /// Not enough space in the write buffer. 
15 | /// 16 | /// It is the caller's responsiblity to pass a big enough buffer to `encode()`. 17 | WriteZero, 18 | /// Tried to encode or decode a ProcessIdentifier==0. 19 | InvalidPid(u16), 20 | /// Tried to decode a QoS > 2. 21 | InvalidQos(u8), 22 | /// Tried to decode a ConnectReturnCode > 5. 23 | InvalidConnectReturnCode(u8), 24 | /// Tried to decode an unknown protocol. 25 | InvalidProtocol(heapless::String<10>, u8), 26 | /// Tried to decode an invalid fixed header (packet type, flags, or remaining_length). 27 | InvalidHeader, 28 | /// Trying to encode/decode an invalid length. 29 | /// 30 | /// The difference with `WriteZero`/`UnexpectedEof` is that it refers to an invalid/corrupt 31 | /// length rather than a buffer size issue. 32 | InvalidLength, 33 | /// Trying to decode a non-utf8 string. 34 | InvalidString, 35 | } 36 | 37 | impl fmt::Display for Error { 38 | fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { 39 | write!(f, "{:?}", self) 40 | } 41 | } 42 | 43 | /// Packet Identifier. 44 | /// 45 | /// For packets with [`QoS::AtLeastOne` or `QoS::ExactlyOnce`] delivery. 46 | /// 47 | /// ```rust 48 | /// # use mqttrust::encoding::v4::{Packet, Pid, QosPid}; 49 | /// # use std::convert::TryFrom; 50 | /// #[derive(Default)] 51 | /// struct Session { 52 | /// pid: Pid, 53 | /// } 54 | /// impl Session { 55 | /// pub fn next_pid(&mut self) -> Pid { 56 | /// self.pid = self.pid + 1; 57 | /// self.pid 58 | /// } 59 | /// } 60 | /// 61 | /// let mut sess = Session::default(); 62 | /// assert_eq!(2, sess.next_pid().get()); 63 | /// assert_eq!(Pid::try_from(3).unwrap(), sess.next_pid()); 64 | /// ``` 65 | /// 66 | /// The spec ([MQTT-2.3.1-1], [MQTT-2.2.1-3]) disallows a pid of 0. 
67 | /// 68 | /// [`QoS::AtLeastOne` or `QoS::ExactlyOnce`]: enum.QoS.html 69 | /// [MQTT-2.3.1-1]: https://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc398718025 70 | /// [MQTT-2.2.1-3]: https://docs.oasis-open.org/mqtt/mqtt/v5.0/os/mqtt-v5.0-os.html#_Toc3901026 71 | #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] 72 | #[cfg_attr(feature = "defmt-impl", derive(defmt::Format))] 73 | #[cfg_attr(feature = "derive", derive(Serialize, Deserialize))] 74 | pub struct Pid(NonZeroU16); 75 | impl Pid { 76 | /// Returns a new `Pid` with value `1`. 77 | pub fn new() -> Self { 78 | Pid(NonZeroU16::new(1).unwrap()) 79 | } 80 | 81 | /// Get the `Pid` as a raw `u16`. 82 | pub fn get(self) -> u16 { 83 | self.0.get() 84 | } 85 | 86 | pub(crate) fn from_buffer(buf: &[u8], offset: &mut usize) -> Result { 87 | let pid = ((buf[*offset] as u16) << 8) | buf[*offset + 1] as u16; 88 | *offset += 2; 89 | Self::try_from(pid) 90 | } 91 | 92 | pub fn to_buffer(self, buf: &mut [u8], offset: &mut usize) -> Result<(), Error> { 93 | write_u16(buf, offset, self.get()) 94 | } 95 | } 96 | 97 | impl Default for Pid { 98 | fn default() -> Pid { 99 | Pid::new() 100 | } 101 | } 102 | 103 | impl core::ops::Add for Pid { 104 | type Output = Pid; 105 | 106 | /// Adding a `u16` to a `Pid` will wrap around and avoid 0. 107 | fn add(self, u: u16) -> Pid { 108 | let n = match self.get().overflowing_add(u) { 109 | (n, false) => n, 110 | (n, true) => n + 1, 111 | }; 112 | Pid(NonZeroU16::new(n).unwrap()) 113 | } 114 | } 115 | 116 | impl core::ops::Sub for Pid { 117 | type Output = Pid; 118 | 119 | /// Adding a `u16` to a `Pid` will wrap around and avoid 0. 120 | fn sub(self, u: u16) -> Pid { 121 | let n = match self.get().overflowing_sub(u) { 122 | (0, _) => core::u16::MAX, 123 | (n, false) => n, 124 | (n, true) => n - 1, 125 | }; 126 | Pid(NonZeroU16::new(n).unwrap()) 127 | } 128 | } 129 | 130 | impl From for u16 { 131 | /// Convert `Pid` to `u16`. 
132 | fn from(p: Pid) -> Self { 133 | p.0.get() 134 | } 135 | } 136 | 137 | impl TryFrom for Pid { 138 | type Error = Error; 139 | 140 | /// Convert `u16` to `Pid`. Will fail for value 0. 141 | fn try_from(u: u16) -> Result { 142 | match NonZeroU16::new(u) { 143 | Some(nz) => Ok(Pid(nz)), 144 | None => Err(Error::InvalidPid(u)), 145 | } 146 | } 147 | } 148 | 149 | /// Packet delivery [Quality of Service] level. 150 | /// 151 | /// [Quality of Service]: http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc398718099 152 | #[derive(Debug, Clone, Copy, PartialEq, Eq)] 153 | #[cfg_attr(feature = "derive", derive(Serialize, Deserialize))] 154 | #[cfg_attr(feature = "defmt-impl", derive(defmt::Format))] 155 | pub enum QoS { 156 | /// `QoS 0`. No ack needed. 157 | AtMostOnce, 158 | /// `QoS 1`. One ack needed. 159 | AtLeastOnce, 160 | /// `QoS 2`. Two acks needed. 161 | ExactlyOnce, 162 | } 163 | 164 | impl QoS { 165 | pub(crate) fn as_u8(&self) -> u8 { 166 | match *self { 167 | QoS::AtMostOnce => 0, 168 | QoS::AtLeastOnce => 1, 169 | QoS::ExactlyOnce => 2, 170 | } 171 | } 172 | 173 | pub(crate) fn from_u8(byte: u8) -> Result { 174 | match byte { 175 | 0 => Ok(QoS::AtMostOnce), 176 | 1 => Ok(QoS::AtLeastOnce), 177 | 2 => Ok(QoS::ExactlyOnce), 178 | n => Err(Error::InvalidQos(n)), 179 | } 180 | } 181 | } 182 | 183 | /// Combined [`QoS`]/[`Pid`]. 184 | /// 185 | /// Used only in [`Publish`] packets. 186 | /// 187 | /// [`Publish`]: struct.Publish.html 188 | /// [`QoS`]: enum.QoS.html 189 | /// [`Pid`]: struct.Pid.html 190 | #[derive(Debug, Clone, Copy, PartialEq, Eq)] 191 | #[cfg_attr(feature = "derive", derive(Serialize, Deserialize))] 192 | #[cfg_attr(feature = "defmt-impl", derive(defmt::Format))] 193 | pub enum QosPid { 194 | AtMostOnce, 195 | AtLeastOnce(Pid), 196 | ExactlyOnce(Pid), 197 | } 198 | 199 | impl QosPid { 200 | /// Extract the [`Pid`] from a `QosPid`, if any. 
201 | /// 202 | /// [`Pid`]: struct.Pid.html 203 | pub fn pid(self) -> Option { 204 | match self { 205 | QosPid::AtMostOnce => None, 206 | QosPid::AtLeastOnce(p) => Some(p), 207 | QosPid::ExactlyOnce(p) => Some(p), 208 | } 209 | } 210 | 211 | /// Extract the [`QoS`] from a `QosPid`. 212 | /// 213 | /// [`QoS`]: enum.QoS.html 214 | pub fn qos(self) -> QoS { 215 | match self { 216 | QosPid::AtMostOnce => QoS::AtMostOnce, 217 | QosPid::AtLeastOnce(_) => QoS::AtLeastOnce, 218 | QosPid::ExactlyOnce(_) => QoS::ExactlyOnce, 219 | } 220 | } 221 | } 222 | 223 | #[cfg(test)] 224 | mod test { 225 | use core::convert::TryFrom; 226 | use std::vec; 227 | 228 | use crate::encoding::v4::Pid; 229 | 230 | #[test] 231 | fn pid_add_sub() { 232 | let t: Vec<(u16, u16, u16, u16)> = vec![ 233 | (2, 1, 1, 3), 234 | (100, 1, 99, 101), 235 | (1, 1, core::u16::MAX, 2), 236 | (1, 2, core::u16::MAX - 1, 3), 237 | (1, 3, core::u16::MAX - 2, 4), 238 | (core::u16::MAX, 1, core::u16::MAX - 1, 1), 239 | (core::u16::MAX, 2, core::u16::MAX - 2, 2), 240 | (10, core::u16::MAX, 10, 10), 241 | (10, 0, 10, 10), 242 | (1, 0, 1, 1), 243 | (core::u16::MAX, 0, core::u16::MAX, core::u16::MAX), 244 | ]; 245 | for (cur, d, prev, next) in t { 246 | let sub = Pid::try_from(cur).unwrap() - d; 247 | let add = Pid::try_from(cur).unwrap() + d; 248 | assert_eq!(prev, sub.get(), "{} - {} should be {}", cur, d, prev); 249 | assert_eq!(next, add.get(), "{} + {} should be {}", cur, d, next); 250 | } 251 | } 252 | } 253 | -------------------------------------------------------------------------------- /mqttrust/src/fmt.rs: -------------------------------------------------------------------------------- 1 | // MIT License 2 | 3 | // Copyright (c) 2020 Dario Nieuwenhuis 4 | 5 | // Permission is hereby granted, free of charge, to any person obtaining a copy 6 | // of this software and associated documentation files (the "Software"), to deal 7 | // in the Software without restriction, including without limitation the rights 8 | 
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | // copies of the Software, and to permit persons to whom the Software is 10 | // furnished to do so, subject to the following conditions: 11 | 12 | // The above copyright notice and this permission notice shall be included in all 13 | // copies or substantial portions of the Software. 14 | 15 | // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | // SOFTWARE. 22 | 23 | #![macro_use] 24 | #![allow(unused_macros)] 25 | 26 | #[cfg(all(feature = "defmt", feature = "log"))] 27 | compile_error!("You may not enable both `defmt` and `log` features."); 28 | 29 | macro_rules! assert { 30 | ($($x:tt)*) => { 31 | { 32 | #[cfg(not(feature = "defmt"))] 33 | ::core::assert!($($x)*); 34 | #[cfg(feature = "defmt")] 35 | ::defmt::assert!($($x)*); 36 | } 37 | }; 38 | } 39 | 40 | macro_rules! assert_eq { 41 | ($($x:tt)*) => { 42 | { 43 | #[cfg(not(feature = "defmt"))] 44 | ::core::assert_eq!($($x)*); 45 | #[cfg(feature = "defmt")] 46 | ::defmt::assert_eq!($($x)*); 47 | } 48 | }; 49 | } 50 | 51 | macro_rules! assert_ne { 52 | ($($x:tt)*) => { 53 | { 54 | #[cfg(not(feature = "defmt"))] 55 | ::core::assert_ne!($($x)*); 56 | #[cfg(feature = "defmt")] 57 | ::defmt::assert_ne!($($x)*); 58 | } 59 | }; 60 | } 61 | 62 | macro_rules! debug_assert { 63 | ($($x:tt)*) => { 64 | { 65 | #[cfg(not(feature = "defmt"))] 66 | ::core::debug_assert!($($x)*); 67 | #[cfg(feature = "defmt")] 68 | ::defmt::debug_assert!($($x)*); 69 | } 70 | }; 71 | } 72 | 73 | macro_rules! 
debug_assert_eq { 74 | ($($x:tt)*) => { 75 | { 76 | #[cfg(not(feature = "defmt"))] 77 | ::core::debug_assert_eq!($($x)*); 78 | #[cfg(feature = "defmt")] 79 | ::defmt::debug_assert_eq!($($x)*); 80 | } 81 | }; 82 | } 83 | 84 | macro_rules! debug_assert_ne { 85 | ($($x:tt)*) => { 86 | { 87 | #[cfg(not(feature = "defmt"))] 88 | ::core::debug_assert_ne!($($x)*); 89 | #[cfg(feature = "defmt")] 90 | ::defmt::debug_assert_ne!($($x)*); 91 | } 92 | }; 93 | } 94 | 95 | macro_rules! todo { 96 | ($($x:tt)*) => { 97 | { 98 | #[cfg(not(feature = "defmt"))] 99 | ::core::todo!($($x)*); 100 | #[cfg(feature = "defmt")] 101 | ::defmt::todo!($($x)*); 102 | } 103 | }; 104 | } 105 | 106 | macro_rules! unreachable { 107 | ($($x:tt)*) => { 108 | { 109 | #[cfg(not(feature = "defmt"))] 110 | ::core::unreachable!($($x)*); 111 | #[cfg(feature = "defmt")] 112 | ::defmt::unreachable!($($x)*); 113 | } 114 | }; 115 | } 116 | 117 | macro_rules! panic { 118 | ($($x:tt)*) => { 119 | { 120 | #[cfg(not(feature = "defmt"))] 121 | ::core::panic!($($x)*); 122 | #[cfg(feature = "defmt")] 123 | ::defmt::panic!($($x)*); 124 | } 125 | }; 126 | } 127 | 128 | macro_rules! trace { 129 | ($s:literal $(, $x:expr)* $(,)?) => { 130 | { 131 | #[cfg(feature = "log")] 132 | ::log::trace!($s $(, $x)*); 133 | #[cfg(feature = "defmt")] 134 | ::defmt::trace!($s $(, $x)*); 135 | #[cfg(not(any(feature = "log", feature="defmt")))] 136 | let _ = ($( & $x ),*); 137 | } 138 | }; 139 | } 140 | 141 | macro_rules! debug { 142 | ($s:literal $(, $x:expr)* $(,)?) => { 143 | { 144 | #[cfg(feature = "log")] 145 | ::log::debug!($s $(, $x)*); 146 | #[cfg(feature = "defmt")] 147 | ::defmt::debug!($s $(, $x)*); 148 | #[cfg(not(any(feature = "log", feature="defmt")))] 149 | let _ = ($( & $x ),*); 150 | } 151 | }; 152 | } 153 | 154 | macro_rules! info { 155 | ($s:literal $(, $x:expr)* $(,)?) 
=> { 156 | { 157 | #[cfg(feature = "log")] 158 | ::log::info!($s $(, $x)*); 159 | #[cfg(feature = "defmt")] 160 | ::defmt::info!($s $(, $x)*); 161 | #[cfg(not(any(feature = "log", feature="defmt")))] 162 | let _ = ($( & $x ),*); 163 | } 164 | }; 165 | } 166 | 167 | macro_rules! warn { 168 | ($s:literal $(, $x:expr)* $(,)?) => { 169 | { 170 | #[cfg(feature = "log")] 171 | ::log::warn!($s $(, $x)*); 172 | #[cfg(feature = "defmt")] 173 | ::defmt::warn!($s $(, $x)*); 174 | #[cfg(not(any(feature = "log", feature="defmt")))] 175 | let _ = ($( & $x ),*); 176 | } 177 | }; 178 | } 179 | 180 | macro_rules! error { 181 | ($s:literal $(, $x:expr)* $(,)?) => { 182 | { 183 | #[cfg(feature = "log")] 184 | ::log::error!($s $(, $x)*); 185 | #[cfg(feature = "defmt")] 186 | ::defmt::error!($s $(, $x)*); 187 | #[cfg(not(any(feature = "log", feature="defmt")))] 188 | let _ = ($( & $x ),*); 189 | } 190 | }; 191 | } 192 | 193 | #[cfg(feature = "defmt")] 194 | macro_rules! unwrap { 195 | ($($x:tt)*) => { 196 | ::defmt::unwrap!($($x)*) 197 | }; 198 | } 199 | 200 | #[cfg(not(feature = "defmt"))] 201 | macro_rules! unwrap { 202 | ($arg:expr) => { 203 | match $crate::fmt::Try::into_result($arg) { 204 | ::core::result::Result::Ok(t) => t, 205 | ::core::result::Result::Err(e) => { 206 | ::core::panic!("unwrap of `{}` failed: {:?}", ::core::stringify!($arg), e); 207 | } 208 | } 209 | }; 210 | ($arg:expr, $($msg:expr),+ $(,)? 
) => { 211 | match $crate::fmt::Try::into_result($arg) { 212 | ::core::result::Result::Ok(t) => t, 213 | ::core::result::Result::Err(e) => { 214 | ::core::panic!("unwrap of `{}` failed: {}: {:?}", ::core::stringify!($arg), ::core::format_args!($($msg,)*), e); 215 | } 216 | } 217 | } 218 | } 219 | 220 | #[derive(Debug, Copy, Clone, Eq, PartialEq)] 221 | pub struct NoneError; 222 | 223 | pub trait Try { 224 | type Ok; 225 | type Error; 226 | fn into_result(self) -> Result; 227 | } 228 | 229 | impl Try for Option { 230 | type Ok = T; 231 | type Error = NoneError; 232 | 233 | #[inline] 234 | fn into_result(self) -> Result { 235 | self.ok_or(NoneError) 236 | } 237 | } 238 | 239 | impl Try for Result { 240 | type Ok = T; 241 | type Error = E; 242 | 243 | #[inline] 244 | fn into_result(self) -> Self { 245 | self 246 | } 247 | } 248 | -------------------------------------------------------------------------------- /mqttrust/src/lib.rs: -------------------------------------------------------------------------------- 1 | #![cfg_attr(not(test), no_std)] 2 | 3 | // This mod MUST go first, so that the others see its macros. 
4 | pub(crate) mod fmt; 5 | 6 | pub mod encoding; 7 | 8 | pub use encoding::v4::{ 9 | subscribe::SubscribeTopic, utils::QoS, Packet, Publish, Subscribe, Unsubscribe, 10 | }; 11 | /// Errors returned by [`Mqtt`] operations. 12 | #[derive(Debug, Clone, Copy, PartialEq, Eq)] 13 | #[cfg_attr(feature = "defmt-impl", derive(defmt::Format))] 14 | pub enum MqttError { 15 | /// Queue full, cannot send/receive more packets 16 | Full, 17 | /// RefCell borrow fault 18 | Borrow, 19 | /// Needed resource is unavailable 20 | Unavailable, 21 | } 22 | /// MQTT client interface: implementors provide `send` and `client_id`; `publish`, `subscribe` and `unsubscribe` are built on top of `send`. 23 | pub trait Mqtt { 24 | fn send(&self, packet: Packet<'_>) -> Result<(), MqttError>; 25 | 26 | fn client_id(&self) -> &str; 27 | /// Builds a PUBLISH packet for `topic_name`/`payload` at `qos` (with `dup` and `retain` false and `pid: None`) and hands it to `send`. 28 | fn publish(&self, topic_name: &str, payload: &[u8], qos: QoS) -> Result<(), MqttError> { 29 | let packet = Packet::Publish(Publish { 30 | dup: false, 31 | qos, 32 | pid: None, 33 | retain: false, 34 | topic_name, 35 | payload, 36 | }); 37 | 38 | self.send(packet) 39 | } 40 | /// Builds a SUBSCRIBE packet for `topics` and hands it to `send`. 41 | fn subscribe(&self, topics: &[SubscribeTopic<'_>]) -> Result<(), MqttError> { 42 | let packet = Packet::Subscribe(Subscribe::new(topics)); 43 | self.send(packet) 44 | } 45 | /// Builds an UNSUBSCRIBE packet for `topics` and hands it to `send`. 
46 | fn unsubscribe(&self, topics: &[&str]) -> Result<(), MqttError> { 47 | let packet = Packet::Unsubscribe(Unsubscribe::new(topics)); 48 | self.send(packet) 49 | } 50 | } 51 | -------------------------------------------------------------------------------- /mqttrust_core/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "mqttrust_core" 3 | version = "0.6.0" 4 | authors = ["Mathias Koch "] 5 | description = "MQTT Client " 6 | readme = "../README.md" 7 | keywords = ["mqtt", "no-std"] 8 | categories = ["embedded", "no-std"] 9 | license = "MIT OR Apache-2.0" 10 | repository = "https://github.com/BlackbirdHQ/mqttrust" 11 | edition = "2018" 12 | documentation = "https://docs.rs/mqttrust_core" 13 | 14 | [lib] 15 | name = "mqttrust_core" 16 | 17 | [[example]] 18 | name = "echo" 19 | required-features = ["log"] 20 | 21 | [[example]] 22 |
name = "aws_device_advisor" 23 | required-features = ["log"] 24 | 25 | [badges] 26 | maintenance = { status = "actively-developed" } 27 | 28 | [dependencies] 29 | embedded-nal = "0.6.0" 30 | nb = "^1" 31 | heapless = { version = "^0.7", features = ["serde", "x86-sync-pool"] } 32 | mqttrust = { version = "^0.6.0", path = "../mqttrust" } 33 | bbqueue = "0.5" 34 | fugit = { version = "0.3" } 35 | fugit-timer = "0.1.2" 36 | 37 | log = { version = "^0.4", default-features = false, optional = true } 38 | defmt = { version = "^0.3", optional = true } 39 | 40 | [dev-dependencies] 41 | native-tls = { version = "^0.2" } 42 | dns-lookup = "1.0.3" 43 | env_logger = "0.9.0" 44 | 45 | [features] 46 | default = ["max_payload_size_4096"] 47 | max_payload_size_2048 = [] 48 | max_payload_size_4096 = [] 49 | max_payload_size_8192 = [] 50 | 51 | 52 | std = [] 53 | 54 | defmt-impl = [ 55 | "defmt", 56 | "mqttrust/defmt-impl", 57 | "heapless/defmt-impl", 58 | "fugit/defmt", 59 | ] 60 | -------------------------------------------------------------------------------- /mqttrust_core/examples/aws_device_advisor.rs: -------------------------------------------------------------------------------- 1 | mod common; 2 | 3 | use mqttrust::{Mqtt, QoS, SubscribeTopic}; 4 | use mqttrust_core::bbqueue::BBBuffer; 5 | use mqttrust_core::{EventLoop, MqttOptions, Notification}; 6 | 7 | use common::clock::SysClock; 8 | use common::network::Network; 9 | use native_tls::TlsConnector; 10 | use std::sync::Arc; 11 | use std::thread; 12 | 13 | use crate::common::credentials; 14 | 15 | static mut Q: BBBuffer<{ 1024 * 6 }> = BBBuffer::new(); 16 | 17 | fn main() { 18 | env_logger::init(); 19 | 20 | let (p, c) = unsafe { Q.try_split_framed().unwrap() }; 21 | 22 | let hostname = credentials::HOSTNAME.unwrap(); 23 | 24 | let connector = TlsConnector::builder() 25 | .identity(credentials::identity()) 26 | .add_root_certificate(credentials::root_ca()) 27 | .build() 28 | .unwrap(); 29 | 30 | let mut network = 
Network::new_tls(connector, String::from(hostname)); 31 | 32 | let thing_name = "mqttrust"; 33 | 34 | let mut mqtt_eventloop = EventLoop::new( 35 | c, 36 | SysClock::new(), 37 | MqttOptions::new(thing_name, hostname.into(), 8883), 38 | ); 39 | 40 | let mqtt_client = mqttrust_core::Client::new(p, thing_name); 41 | 42 | thread::Builder::new() 43 | .name("eventloop".to_string()) 44 | .spawn(move || loop { 45 | match nb::block!(mqtt_eventloop.connect(&mut network)) { 46 | Err(_) => continue, 47 | Ok(true) => { 48 | log::info!("Successfully connected to broker"); 49 | } 50 | Ok(false) => {} 51 | } 52 | 53 | match mqtt_eventloop.yield_event(&mut network) { 54 | Ok(Notification::Publish(_)) => {} 55 | Ok(n) => { 56 | log::trace!("{:?}", n); 57 | } 58 | _ => {} 59 | } 60 | }) 61 | .unwrap(); 62 | 63 | loop { 64 | thread::sleep(std::time::Duration::from_millis(5000)); 65 | mqtt_client 66 | .subscribe(&[SubscribeTopic { 67 | topic_path: format!("{}/device/advisor", thing_name).as_str(), 68 | qos: QoS::AtLeastOnce, 69 | }]) 70 | .unwrap(); 71 | 72 | mqtt_client 73 | .publish( 74 | format!("{}/device/advisor/hello", thing_name).as_str(), 75 | format!("Hello from {}", thing_name).as_bytes(), 76 | QoS::AtLeastOnce, 77 | ) 78 | .unwrap(); 79 | } 80 | } 81 | -------------------------------------------------------------------------------- /mqttrust_core/examples/common/clock.rs: -------------------------------------------------------------------------------- 1 | use fugit::TimerInstantU32; 2 | use std::{ 3 | convert::Infallible, 4 | time::{SystemTime, UNIX_EPOCH}, 5 | }; 6 | pub struct SysClock { 7 | start_time: u32, 8 | countdown_end: Option, 9 | } 10 | 11 | impl SysClock { 12 | pub fn new() -> Self { 13 | Self { 14 | start_time: Self::epoch(), 15 | countdown_end: None, 16 | } 17 | } 18 | 19 | pub fn epoch() -> u32 { 20 | SystemTime::now() 21 | .duration_since(UNIX_EPOCH) 22 | .expect("Time went backwards") 23 | .as_millis() as u32 24 | } 25 | 26 | pub fn now(&self) -> u32 { 27 | 
Self::epoch() - self.start_time 28 | } 29 | } 30 | 31 | impl fugit_timer::Timer<1000> for SysClock { 32 | type Error = Infallible; 33 | 34 | fn now(&mut self) -> fugit::TimerInstantU32<1000> { 35 | TimerInstantU32::from_ticks(SysClock::now(self)) 36 | } 37 | 38 | fn start(&mut self, duration: fugit::TimerDurationU32<1000>) -> Result<(), Self::Error> { 39 | todo!() 40 | } 41 | 42 | fn cancel(&mut self) -> Result<(), Self::Error> { 43 | todo!() 44 | } 45 | 46 | fn wait(&mut self) -> nb::Result<(), Self::Error> { 47 | todo!() 48 | } 49 | } 50 | -------------------------------------------------------------------------------- /mqttrust_core/examples/common/credentials.rs: -------------------------------------------------------------------------------- 1 | use std::env; 2 | 3 | use native_tls::{Certificate, Identity}; 4 | 5 | pub fn identity() -> Identity { 6 | let pw = env::var("DEVICE_ADVISOR_PASSWORD").unwrap(); 7 | Identity::from_pkcs12(include_bytes!("../secrets/identity.pfx"), pw.as_str()).unwrap() 8 | } 9 | 10 | pub fn root_ca() -> Certificate { 11 | Certificate::from_pem(include_bytes!("../secrets/root-ca.pem")).unwrap() 12 | } 13 | 14 | pub const HOSTNAME: Option<&'static str> = option_env!("AWS_HOSTNAME"); 15 | -------------------------------------------------------------------------------- /mqttrust_core/examples/common/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod clock; 2 | pub mod credentials; 3 | pub mod network; 4 | -------------------------------------------------------------------------------- /mqttrust_core/examples/common/network.rs: -------------------------------------------------------------------------------- 1 | use embedded_nal::{AddrType, Dns, IpAddr, SocketAddr, TcpClientStack}; 2 | use native_tls::{TlsConnector, TlsStream}; 3 | use std::io::{Read, Write}; 4 | use std::marker::PhantomData; 5 | use std::net::TcpStream; 6 | 7 | use dns_lookup::{lookup_addr, lookup_host}; 8 | 9 | /// An 
std::io::Error compatible error type returned when an operation is requested in the wrong 10 | /// sequence (where the "right" is create a socket, connect, any receive/send, and possibly close). 11 | #[derive(Debug)] 12 | struct OutOfOrder; 13 | 14 | impl std::fmt::Display for OutOfOrder { 15 | fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { 16 | write!(f, "Out of order operations requested") 17 | } 18 | } 19 | 20 | impl std::error::Error for OutOfOrder {} 21 | 22 | impl Into> for OutOfOrder { 23 | fn into(self) -> std::io::Result { 24 | Err(std::io::Error::new( 25 | std::io::ErrorKind::NotConnected, 26 | OutOfOrder, 27 | )) 28 | } 29 | } 30 | 31 | pub struct Network { 32 | tls_connector: Option<(TlsConnector, String)>, 33 | _sec: PhantomData, 34 | } 35 | 36 | impl Network> { 37 | pub fn new_tls(tls_connector: TlsConnector, hostname: String) -> Self { 38 | Self { 39 | tls_connector: Some((tls_connector, hostname)), 40 | _sec: PhantomData, 41 | } 42 | } 43 | } 44 | 45 | impl Network { 46 | pub fn new() -> Self { 47 | Self { 48 | tls_connector: None, 49 | _sec: PhantomData, 50 | } 51 | } 52 | } 53 | /// Maps `WouldBlock`/`TimedOut` I/O errors to `nb::Error::WouldBlock`; any other error becomes `nb::Error::Other`. 54 | pub(crate) fn to_nb(e: std::io::Error) -> nb::Error { 55 | use std::io::ErrorKind::{TimedOut, WouldBlock}; 56 | match e.kind() { 57 | WouldBlock | TimedOut => nb::Error::WouldBlock, 58 | _ => e.into(), 59 | } 60 | } 61 | 62 | pub struct TcpSocket { 63 | pub stream: Option, 64 | } 65 | 66 | impl TcpSocket { 67 | pub fn new() -> Self { 68 | TcpSocket { stream: None } 69 | } 70 | 71 | pub fn get_running(&mut self) -> std::io::Result<&mut T> { 72 | match self.stream { 73 | Some(ref mut s) => Ok(s), 74 | _ => OutOfOrder.into(), 75 | } 76 | } 77 | } 78 | 79 | impl Dns for Network { 80 | type Error = (); 81 | // Lookup failures (and a missing IPv4 result) are reported as `nb::Error::Other(())` instead of panicking. 82 | fn get_host_by_address( 83 | &mut self, 84 | ip_addr: IpAddr, 85 | ) -> nb::Result, Self::Error> { 86 | let ip: std::net::IpAddr = format!("{}", ip_addr).parse().map_err(|_| nb::Error::Other(()))?; 87 | let host = lookup_addr(&ip).map_err(|_| nb::Error::Other(()))?; 88 |
Ok(heapless::String::from(host.as_str())) 89 | } 90 | fn get_host_by_name( 91 | &mut self, 92 | hostname: &str, 93 | _addr_type: AddrType, 94 | ) -> nb::Result { 95 | let ips: Vec = lookup_host(hostname).map_err(|_| nb::Error::Other(()))?; 96 | let ip = ips 97 | .iter() 98 | .find(|s| matches!(s, std::net::IpAddr::V4(_))) 99 | .ok_or(nb::Error::Other(()))?; 100 | format!("{}", ip).parse().map_err(|_| nb::Error::Other(())) 101 | } 102 | } 103 | 104 | impl TcpClientStack for Network> { 105 | type Error = std::io::Error; 106 | type TcpSocket = TcpSocket>; 107 | 108 | fn socket(&mut self) -> Result { 109 | Ok(TcpSocket::new()) 110 | } 111 | 112 | fn receive( 113 | &mut self, 114 | network: &mut Self::TcpSocket, 115 | buf: &mut [u8], 116 | ) -> nb::Result { 117 | let socket = network.get_running()?; 118 | socket.read(buf).map_err(to_nb) 119 | } 120 | 121 | fn send( 122 | &mut self, 123 | network: &mut Self::TcpSocket, 124 | buf: &[u8], 125 | ) -> nb::Result { 126 | let socket = network.get_running()?; 127 | socket.write(buf).map_err(to_nb) 128 | } 129 | 130 | fn is_connected(&mut self, network: &Self::TcpSocket) -> Result { 131 | Ok(network.stream.is_some()) 132 | } 133 | 134 | fn connect( 135 | &mut self, 136 | network: &mut Self::TcpSocket, 137 | remote: SocketAddr, 138 | ) -> nb::Result<(), Self::Error> { 139 | let soc = TcpStream::connect(format!("{}", remote))?; 140 | 141 | let (connector, hostname) = self.tls_connector.as_ref().unwrap(); 142 | 143 | let mut tls_stream = connector.connect(hostname, soc).map_err(|e| match e { 144 | native_tls::HandshakeError::Failure(_) => nb::Error::Other(std::io::Error::new( 145 | std::io::ErrorKind::Other, 146 | "Failed TLS handshake", 147 | )), 148 | native_tls::HandshakeError::WouldBlock(_) => nb::Error::WouldBlock, 149 | })?; 150 | 151 | tls_stream.get_mut().set_nonblocking(true)?; 152 | network.stream.replace(tls_stream); 153 | 154 | Ok(()) 155 | } 156 | 157 | fn close(&mut self, _network: Self::TcpSocket) -> Result<(), Self::Error> { 158 | // No-op: Socket gets closed
when it is freed 159 | // 160 | // Could wrap it in an Option, but really that'll only make things messier; users will 161 | // probably drop the socket anyway after closing, and can't expect it to be usable with 162 | // this API. 163 | Ok(()) 164 | } 165 | } 166 | 167 | impl TcpClientStack for Network { 168 | type Error = std::io::Error; 169 | type TcpSocket = TcpSocket; 170 | 171 | fn socket(&mut self) -> Result { 172 | Ok(TcpSocket::new()) 173 | } 174 | 175 | fn receive( 176 | &mut self, 177 | network: &mut Self::TcpSocket, 178 | buf: &mut [u8], 179 | ) -> nb::Result { 180 | let socket = network.get_running()?; 181 | socket.read(buf).map_err(to_nb) 182 | } 183 | 184 | fn send( 185 | &mut self, 186 | network: &mut Self::TcpSocket, 187 | buf: &[u8], 188 | ) -> nb::Result { 189 | let socket = network.get_running()?; 190 | socket.write(buf).map_err(to_nb) 191 | } 192 | 193 | fn is_connected(&mut self, network: &Self::TcpSocket) -> Result { 194 | Ok(network.stream.is_some()) 195 | } 196 | 197 | fn connect( 198 | &mut self, 199 | network: &mut Self::TcpSocket, 200 | remote: SocketAddr, 201 | ) -> nb::Result<(), Self::Error> { 202 | let mut soc = TcpStream::connect(format!("{}", remote))?; 203 | soc.set_nonblocking(true)?; 204 | network.stream.replace(soc); 205 | 206 | Ok(()) 207 | } 208 | 209 | fn close(&mut self, _network: Self::TcpSocket) -> Result<(), Self::Error> { 210 | // No-op: Socket gets closed when it is freed 211 | // 212 | // Could wrap it in an Option, but really that'll only make things messier; users will 213 | // probably drop the socket anyway after closing, and can't expect it to be usable with 214 | // this API.
215 | Ok(()) 216 | } 217 | } 218 | -------------------------------------------------------------------------------- /mqttrust_core/examples/echo.rs: -------------------------------------------------------------------------------- 1 | mod common; 2 | 3 | use mqttrust::{QoS, SubscribeTopic}; 4 | use mqttrust_core::{bbqueue::BBBuffer, EventLoop, Mqtt, MqttOptions, Notification}; 5 | 6 | use common::clock::SysClock; 7 | use common::network::Network; 8 | use std::thread; 9 | 10 | static mut Q: BBBuffer<{ 1024 * 6 }> = BBBuffer::new(); 11 | const MSG_CNT: u32 = 5; 12 | 13 | fn main() { 14 | env_logger::init(); 15 | 16 | let (p, c) = unsafe { Q.try_split_framed().unwrap() }; 17 | 18 | let mut network = Network::new(); 19 | 20 | let client_id = "mqtt_test_client_id"; 21 | 22 | // Connect to broker.hivemq.com:1883 23 | let mut mqtt_eventloop = EventLoop::new( 24 | c, 25 | SysClock::new(), 26 | MqttOptions::new(client_id, "broker.hivemq.com".into(), 1883), 27 | ); 28 | 29 | let mqtt_client = mqttrust_core::Client::new(p, client_id); 30 | 31 | nb::block!(mqtt_eventloop.connect(&mut network)).expect("Failed to connect to MQTT"); 32 | 33 | let handle = thread::Builder::new() 34 | .name("eventloop".to_string()) 35 | .spawn(move || { 36 | let mut receive_cnt = 0; 37 | while receive_cnt < MSG_CNT { 38 | match mqtt_eventloop.yield_event(&mut network) { 39 | Ok(Notification::Publish(publish)) => { 40 | log::debug!("Received {:?}", publish); 41 | receive_cnt += 1; 42 | } 43 | Ok(n) => { 44 | log::debug!("{:?}", n); 45 | } 46 | _ => {} 47 | } 48 | } 49 | receive_cnt 50 | }) 51 | .unwrap(); 52 | 53 | mqtt_client 54 | .subscribe(&[ 55 | SubscribeTopic { 56 | topic_path: "mqttrust/tester/subscriber", 57 | qos: QoS::AtLeastOnce, 58 | }, 59 | SubscribeTopic { 60 | topic_path: "mqttrust/tester/subscriber2", 61 | qos: QoS::AtLeastOnce, 62 | }, 63 | ]) 64 | .expect("Failed to subscribe to topics!"); 65 | 66 | let mut send_cnt = 0; 67 | 68 | while send_cnt < MSG_CNT { 69 | log::debug!("Sending
{}", send_cnt); 70 | mqtt_client 71 | .publish( 72 | "mqttrust/tester/subscriber", 73 | format!("{{\"count\": {} }}", send_cnt).as_bytes(), 74 | QoS::AtLeastOnce, 75 | ) 76 | .expect("Failed to publish"); 77 | 78 | send_cnt += 1; 79 | thread::sleep(std::time::Duration::from_millis(5000)); 80 | } 81 | 82 | let receive_cnt = handle.join().expect("Receiving thread failed!"); 83 | 84 | assert_eq!(receive_cnt, send_cnt); 85 | 86 | println!("Success!"); 87 | } 88 | -------------------------------------------------------------------------------- /mqttrust_core/examples/secrets/.gitignore: -------------------------------------------------------------------------------- 1 | # Ignore everything 2 | * 3 | 4 | # But not these files... 5 | !.gitignore 6 | !identity.pfx 7 | !root-ca.pem -------------------------------------------------------------------------------- /mqttrust_core/examples/secrets/identity.pfx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FactbirdHQ/mqttrust/aed58e41c36b31e45cf11eef42c74ab9a31a4cf4/mqttrust_core/examples/secrets/identity.pfx -------------------------------------------------------------------------------- /mqttrust_core/examples/secrets/root-ca.pem: -------------------------------------------------------------------------------- 1 | -----BEGIN CERTIFICATE----- 2 | MIIDQTCCAimgAwIBAgITBmyfz5m/jAo54vB4ikPmljZbyjANBgkqhkiG9w0BAQsF 3 | ADA5MQswCQYDVQQGEwJVUzEPMA0GA1UEChMGQW1hem9uMRkwFwYDVQQDExBBbWF6 4 | b24gUm9vdCBDQSAxMB4XDTE1MDUyNjAwMDAwMFoXDTM4MDExNzAwMDAwMFowOTEL 5 | MAkGA1UEBhMCVVMxDzANBgNVBAoTBkFtYXpvbjEZMBcGA1UEAxMQQW1hem9uIFJv 6 | b3QgQ0EgMTCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBALJ4gHHKeNXj 7 | ca9HgFB0fW7Y14h29Jlo91ghYPl0hAEvrAIthtOgQ3pOsqTQNroBvo3bSMgHFzZM 8 | 9O6II8c+6zf1tRn4SWiw3te5djgdYZ6k/oI2peVKVuRF4fn9tBb6dNqcmzU5L/qw 9 | IFAGbHrQgLKm+a/sRxmPUDgH3KKHOVj4utWp+UhnMJbulHheb4mjUcAwhmahRWa6 10 | VOujw5H5SNz/0egwLX0tdHA114gk957EWW67c4cX8jJGKLhD+rcdqsq08p8kDi1L 11 |
93FcXmn/6pUCyziKrlA4b9v7LWIbxcceVOF34GfID5yHI9Y/QCB/IIDEgEw+OyQm 12 | jgSubJrIqg0CAwEAAaNCMEAwDwYDVR0TAQH/BAUwAwEB/zAOBgNVHQ8BAf8EBAMC 13 | AYYwHQYDVR0OBBYEFIQYzIU07LwMlJQuCFmcx7IQTgoIMA0GCSqGSIb3DQEBCwUA 14 | A4IBAQCY8jdaQZChGsV2USggNiMOruYou6r4lK5IpDB/G/wkjUu0yKGX9rbxenDI 15 | U5PMCCjjmCXPI6T53iHTfIUJrU6adTrCC2qJeHZERxhlbI1Bjjt/msv0tadQ1wUs 16 | N+gDS63pYaACbvXy8MWy7Vu33PqUXHeeE6V/Uq2V8viTO96LXFvKWlJbYK8U90vv 17 | o/ufQJVtMVT8QtPHRh8jrdkPSHCa2XV4cdFyQzR1bldZwgJcJmApzyMZFo6IQ6XU 18 | 5MsI+yMRQ+hDKXJioaldXgjUkK642M4UwtBV8ob2xJNDd2ZhwLnoQdeXeGADbkpy 19 | rqXRfboQnoZsG4q5WTP468SQvvG5 20 | -----END CERTIFICATE----- 21 | -------------------------------------------------------------------------------- /mqttrust_core/src/client.rs: -------------------------------------------------------------------------------- 1 | use bbqueue::framed::FrameProducer; 2 | use core::cell::RefCell; 3 | use core::ops::DerefMut; 4 | use mqttrust::{ 5 | encoding::v4::{encoder::encode_slice, Packet}, 6 | Mqtt, MqttError, 7 | }; 8 | /// MQTT Client 9 | /// 10 | /// This client is merely a convenience wrapper around a 11 | /// `heapless::spsc::Producer`, making it easier to send certain MQTT packet 12 | /// types, and maintaining a common reference to a client id. Also it implements 13 | /// the [`Mqtt`] trait. 14 | /// 15 | /// **Lifetimes**: 16 | /// - `'a`: Lifetime of the queue for exchanging packets between the client and 17 | /// [Eventloop](crate::eventloop::EventLoop). This must have the same lifetime as the corresponding 18 | /// Consumer. Usually `'static`. 19 | /// - `'b`: Lifetime of `client_id` str. 20 | /// 21 | /// **Generics**: 22 | /// - `L`: Length of the queue for exchanging packets between the client and 23 | /// [Eventloop](crate::eventloop::EventLoop). 24 | /// The length is in bytes and it must be chosen long enough to contain serialized MQTT packets and 25 | /// [FrameProducer](bbqueue::framed) header bytes. 
26 | /// For example a MQTT packet with 30 bytes payload and 20 bytes topic name takes 59 bytes to serialize 27 | /// into MQTT frame plus ~2 bytes (depending on grant length) for [FrameProducer](bbqueue::framed) header. 28 | /// For rough calculation `payload_len + topic_name + 15` can be used to determine 29 | /// how many bytes one packet consumes. 30 | /// Packets are read out from queue only when [Eventloop::yield_event](crate::eventloop::EventLoop::yield_event) is called. 31 | /// Therefore make sure that queue length is long enough to contain multiple packets if you want to call 32 | /// [send](Client::send) multiple times in the row. 33 | pub struct Client<'a, 'b, const L: usize> { 34 | client_id: &'b str, 35 | producer: Option>>, 36 | } 37 | 38 | impl<'a, 'b, const L: usize> Client<'a, 'b, L> { 39 | pub fn new(producer: FrameProducer<'a, L>, client_id: &'b str) -> Self { 40 | Self { 41 | client_id, 42 | producer: Some(RefCell::new(producer)), 43 | } 44 | } 45 | 46 | /// Release `FrameProducer` 47 | /// 48 | /// This can be called before dropping `Client` to get back original `FrameProducer`. 
49 | pub fn release_queue(&mut self) -> Option> { 50 | match self.producer.take() { 51 | Some(prod) => Some(prod.into_inner()), 52 | None => None, 53 | } 54 | } 55 | } 56 | 57 | impl<'a, 'b, const L: usize> Mqtt for Client<'a, 'b, L> { 58 | fn client_id(&self) -> &str { 59 | &self.client_id 60 | } 61 | 62 | fn send(&self, packet: Packet<'_>) -> Result<(), MqttError> { 63 | match &self.producer { 64 | Some(producer) => { 65 | let mut prod = producer.try_borrow_mut().map_err(|_| MqttError::Borrow)?; 66 | let max_size = packet.len(); 67 | let mut grant = prod.grant(max_size).map_err(|_| MqttError::Full)?; 68 | let len = encode_slice(&packet, grant.deref_mut()).map_err(|_| MqttError::Full)?; 69 | grant.commit(len); 70 | Ok(()) 71 | } 72 | None => Err(MqttError::Unavailable), 73 | } 74 | } 75 | } 76 | -------------------------------------------------------------------------------- /mqttrust_core/src/eventloop.rs: -------------------------------------------------------------------------------- 1 | use crate::max_payload::MAX_PAYLOAD_SIZE; 2 | use crate::options::Broker; 3 | use crate::packet::SerializedPacket; 4 | use crate::state::{MqttConnectionStatus, MqttState}; 5 | use crate::{EventError, MqttOptions, NetworkError, Notification}; 6 | use bbqueue::framed::FrameConsumer; 7 | use core::convert::Infallible; 8 | use core::ops::DerefMut; 9 | use core::ops::RangeTo; 10 | use embedded_nal::{AddrType, Dns, SocketAddr, TcpClientStack}; 11 | use fugit::ExtU32; 12 | use heapless::{String, Vec}; 13 | use mqttrust::encoding::v4::{decode_slice, encode_slice, Connect, Packet, Protocol}; 14 | 15 | pub struct EventLoop<'a, 'b, S, O, const TIMER_HZ: u32, const L: usize> 16 | where 17 | O: fugit_timer::Timer, 18 | { 19 | /// Current state of the connection 20 | pub(crate) state: MqttState, 21 | /// Last outgoing packet time 22 | pub(crate) last_outgoing_timer: O, 23 | /// Options of the current mqtt connection 24 | pub options: MqttOptions<'b>, 25 | /// Request stream 26 | pub(crate) 
requests: Option>, 27 | network_handle: NetworkHandle, 28 | } 29 | 30 | impl<'a, 'b, S, O, const TIMER_HZ: u32, const L: usize> EventLoop<'a, 'b, S, O, TIMER_HZ, L> 31 | where 32 | O: fugit_timer::Timer, 33 | { 34 | pub fn new( 35 | requests: FrameConsumer<'a, L>, 36 | outgoing_timer: O, 37 | options: MqttOptions<'b>, 38 | ) -> Self { 39 | Self { 40 | state: MqttState::new(), 41 | last_outgoing_timer: outgoing_timer, 42 | options, 43 | requests: Some(requests), 44 | network_handle: NetworkHandle::new(), 45 | } 46 | } 47 | 48 | /// Release `FrameConsumer` 49 | /// 50 | /// This can be called before dropping `EventLoop` to get back original `FrameConsumer`. 51 | pub fn release_queue(&mut self) -> Option> { 52 | self.requests.take() 53 | } 54 | 55 | pub fn connect + ?Sized>( 56 | &mut self, 57 | network: &mut N, 58 | ) -> nb::Result { 59 | // connect to the broker 60 | match self.network_handle.is_connected(network) { 61 | Ok(false) => { 62 | // Socket is present, but not connected. Usually this implies 63 | // that the socket is closed for writes. Disconnect to close & 64 | // recycle the socket. 65 | 66 | // Fallthrough to allow reading mqtt client error codes, unless 67 | // MQTT is actually connected 68 | if matches!( 69 | self.state.connection_status, 70 | MqttConnectionStatus::Connected 71 | ) { 72 | warn!("Socket cleanup!"); 73 | self.disconnect(network); 74 | return Err(EventError::Network(NetworkError::SocketClosed).into()); 75 | } 76 | } 77 | Err(_) => { 78 | // We have no socket present at all 79 | self.network_handle 80 | .connect(network, self.options.broker()) 81 | .map_err(EventError::Network)?; 82 | debug!("Network connected!"); 83 | 84 | self.state.connection_status = MqttConnectionStatus::Disconnected; 85 | } 86 | Ok(true) => { 87 | // Socket is there, and is connected. 
Proceed to make sure MQTT is connected 88 | } 89 | } 90 | 91 | self.mqtt_connect(network).map_err(|e| { 92 | e.map(|e| { 93 | if matches!( 94 | e, 95 | EventError::Network(_) | EventError::MqttState(_) | EventError::Timeout 96 | ) { 97 | debug!("Disconnecting!"); 98 | self.disconnect(network); 99 | } 100 | e 101 | }) 102 | }) 103 | } 104 | 105 | fn should_handle_request(&mut self) -> bool { 106 | let qos_space = self.state.outgoing_pub.len() < self.state.outgoing_pub.capacity(); 107 | 108 | // TODO: 109 | // let qos_0 = if let Some(_) = self.requests.read() { 110 | // p.qos == QoS::AtMostOnce 111 | // } else { 112 | // false 113 | // }; 114 | 115 | // qos_0 || (self.requests.ready() && qos_space) 116 | qos_space 117 | } 118 | 119 | /// Selects an event from the client's requests, incoming packets from the 120 | /// broker and keepalive ping cycle. 121 | fn select_event + ?Sized>( 122 | &mut self, 123 | network: &mut N, 124 | ) -> nb::Result { 125 | let now = self.last_outgoing_timer.now(); 126 | 127 | // Handle a request 128 | if self.should_handle_request() { 129 | match &mut self.requests { 130 | Some(requests) => { 131 | if let Some(mut grant) = requests.read() { 132 | let mut packet = SerializedPacket(grant.deref_mut()); 133 | match self.state.handle_outgoing_request(&mut packet, &now) { 134 | Ok(()) => { 135 | self.network_handle.send(network, packet.to_inner())?; 136 | grant.release(); 137 | return Err(nb::Error::WouldBlock); 138 | } 139 | Err(crate::state::StateError::MaxMessagesInflight) => {} 140 | Err(e) => return Err(nb::Error::Other(e.into())), 141 | } 142 | } 143 | } 144 | None => return Err(nb::Error::Other(EventError::RequestsNotAvailable)), 145 | } 146 | } 147 | 148 | if self 149 | .state 150 | .last_ping_entry() 151 | .or_insert(now) 152 | .has_elapsed(&now, self.options.keep_alive_ms().millis()) 153 | { 154 | // Handle keepalive ping 155 | let packet = self 156 | .state 157 | .handle_outgoing_packet(Packet::Pingreq) 158 | 
.map_err(EventError::from)?; 159 | self.network_handle.send_packet(network, &packet)?; 160 | self.state.last_ping_entry().insert(now); 161 | return Err(nb::Error::WouldBlock); 162 | } 163 | 164 | // Handle an incoming packet 165 | let (notification, packet) = self 166 | .network_handle 167 | .receive(network) 168 | .map_err(|e| e.map(EventError::Network))? 169 | .decode(&mut self.state)?; 170 | 171 | // Handle `ack` of newly received incoming packet, if relevant 172 | if let Some(packet) = packet { 173 | self.network_handle.send_packet(network, &packet)?; 174 | } 175 | 176 | // By comparing the current time, select pending non-zero QoS publish 177 | // requests staying longer than the retry interval, and handle their 178 | // retrial. 179 | for (pid, inflight) in self.state.retries(now, 10.secs()) { 180 | warn!("Retrying PID {:?}", pid); 181 | // Update inflight's timestamp for later retrials 182 | inflight.last_touch_entry().insert(now); 183 | let packet = inflight.packet(*pid).map_err(EventError::from)?; 184 | self.network_handle.send(network, &packet)?; 185 | } 186 | 187 | notification.ok_or(nb::Error::WouldBlock) 188 | } 189 | 190 | /// Yields notification from events. All the error raised while processing 191 | /// event is reported as an `Ok` value of `Notification::Abort`. 
192 | #[must_use = "Eventloop should be iterated over a loop to make progress"] 193 | pub fn yield_event + ?Sized>( 194 | &mut self, 195 | network: &mut N, 196 | ) -> nb::Result { 197 | if self.network_handle.socket.is_none() { 198 | return Ok(Notification::Abort(EventError::Network( 199 | NetworkError::NoSocket, 200 | ))); 201 | } 202 | 203 | self.select_event(network).or_else(|e| match e { 204 | nb::Error::WouldBlock => Err(nb::Error::WouldBlock), 205 | nb::Error::Other(e) => { 206 | debug!("Disconnecting from an event error"); 207 | self.disconnect(network); 208 | Ok(Notification::Abort(e)) 209 | } 210 | }) 211 | } 212 | 213 | pub fn disconnect + ?Sized>(&mut self, network: &mut N) { 214 | self.state.connection_status = MqttConnectionStatus::Disconnected; 215 | if let Some(socket) = self.network_handle.socket.take() { 216 | network.close(socket).ok(); 217 | } 218 | } 219 | 220 | fn mqtt_connect + ?Sized>( 221 | &mut self, 222 | network: &mut N, 223 | ) -> nb::Result { 224 | match self.state.connection_status { 225 | MqttConnectionStatus::Connected => Ok(false), 226 | MqttConnectionStatus::Disconnected => { 227 | info!("MQTT connecting.."); 228 | let now = self.last_outgoing_timer.now(); 229 | self.state.last_ping_entry().insert(now); 230 | 231 | self.state.await_pingresp = false; 232 | self.network_handle.rx_buf.init(); 233 | 234 | let (username, password) = self.options.credentials(); 235 | 236 | let connect = Packet::Connect(Connect { 237 | protocol: Protocol::MQTT311, 238 | keep_alive: (self.options.keep_alive_ms() / 1000) as u16, 239 | client_id: self.options.client_id(), 240 | clean_session: self.options.clean_session(), 241 | last_will: self.options.last_will(), 242 | username, 243 | password, 244 | }); 245 | 246 | // mqtt connection with timeout 247 | self.network_handle.send_packet(network, &connect)?; 248 | self.state.handle_outgoing_connect(); 249 | Err(nb::Error::WouldBlock) 250 | } 251 | MqttConnectionStatus::Handshake => { 252 | let now = 
self.last_outgoing_timer.now(); 253 | 254 | if self 255 | .state 256 | .last_ping_entry() 257 | .or_insert(now) 258 | .has_elapsed(&now, 50.secs()) 259 | { 260 | return Err(nb::Error::Other(EventError::Timeout)); 261 | } 262 | 263 | self.network_handle 264 | .receive(network) 265 | .map_err(|e| e.map(EventError::Network))? 266 | .decode(&mut self.state) 267 | .and_then(|(n, p)| { 268 | if n.is_none() && p.is_none() { 269 | return Err(nb::Error::WouldBlock); 270 | } 271 | Ok(n.map(|n| n == Notification::ConnAck).unwrap_or(false)) 272 | }) 273 | } 274 | } 275 | } 276 | } 277 | 278 | struct NetworkHandle { 279 | /// Network socket 280 | socket: Option, 281 | tx_buf: heapless::Vec, 282 | rx_buf: PacketBuffer, 283 | } 284 | 285 | impl NetworkHandle { 286 | fn lookup_host + ?Sized>( 287 | network: &mut N, 288 | broker: Broker, 289 | port: u16, 290 | ) -> Result<(String<256>, SocketAddr), NetworkError> { 291 | match broker { 292 | Broker::Hostname(h) => { 293 | let socket_addr = SocketAddr::new( 294 | network.get_host_by_name(h, AddrType::IPv4).map_err(|_e| { 295 | info!("Failed to resolve IP!"); 296 | NetworkError::DnsLookupFailed 297 | })?, 298 | port, 299 | ); 300 | Ok((String::from(h), socket_addr)) 301 | } 302 | Broker::IpAddr(ip) => { 303 | let socket_addr = SocketAddr::new(ip, port); 304 | let domain = network.get_host_by_address(ip).map_err(|_e| { 305 | info!("Failed to resolve hostname!"); 306 | NetworkError::DnsLookupFailed 307 | })?; 308 | 309 | Ok((domain, socket_addr)) 310 | } 311 | } 312 | } 313 | 314 | fn new() -> Self { 315 | Self { 316 | socket: None, 317 | tx_buf: heapless::Vec::new(), 318 | rx_buf: PacketBuffer::new(), 319 | } 320 | } 321 | 322 | /// Checks if this socket is present and connected. Raises `NetworkError` when 323 | /// the socket is present and in its error state. 
324 | fn is_connected + ?Sized>( 325 | &self, 326 | network: &mut N, 327 | ) -> Result { 328 | match self.socket { 329 | Some(ref socket) => network 330 | .is_connected(socket) 331 | .map_err(|_e| NetworkError::SocketClosed), 332 | None => Err(NetworkError::SocketClosed), 333 | } 334 | } 335 | 336 | fn connect + ?Sized>( 337 | &mut self, 338 | network: &mut N, 339 | broker: (Broker, u16), 340 | ) -> Result<(), NetworkError> { 341 | let socket = match self.socket.as_mut() { 342 | None => { 343 | let socket = network.socket().map_err(|_e| NetworkError::SocketOpen)?; 344 | self.socket.get_or_insert(socket) 345 | } 346 | Some(socket) => socket, 347 | }; 348 | 349 | let (broker, port) = broker; 350 | let (_hostname, socket_addr) = NetworkHandle::::lookup_host(network, broker, port)?; 351 | 352 | nb::block!(network.connect(socket, socket_addr)).map_err(|_| { 353 | if let Some(socket) = self.socket.take() { 354 | network.close(socket).ok(); 355 | } 356 | NetworkError::SocketConnect 357 | }) 358 | } 359 | 360 | pub fn send_packet<'d, N: TcpClientStack + ?Sized>( 361 | &mut self, 362 | network: &mut N, 363 | pkt: &Packet, 364 | ) -> Result { 365 | self.tx_buf.clear(); 366 | self.tx_buf 367 | .resize_default(self.tx_buf.capacity()) 368 | .unwrap_or_else(|()| unreachable!("Input length equals to the current capacity.")); 369 | 370 | let size = encode_slice(&pkt, self.tx_buf.as_mut()).map_err(EventError::Encoding)?; 371 | 372 | let socket = self 373 | .socket 374 | .as_mut() 375 | .ok_or(EventError::Network(NetworkError::NoSocket))?; 376 | 377 | let length = nb::block!(network.send(socket, &self.tx_buf[..size])).map_err(|_| { 378 | error!("[send] NetworkError::Write"); 379 | EventError::Network(NetworkError::Write) 380 | })?; 381 | 382 | Ok(length) 383 | } 384 | 385 | pub fn send<'d, N: TcpClientStack + ?Sized>( 386 | &mut self, 387 | network: &mut N, 388 | pkt: &[u8], 389 | ) -> Result { 390 | let socket = self 391 | .socket 392 | .as_mut() 393 | 
.ok_or(EventError::Network(NetworkError::NoSocket))?; 394 | 395 | let length = nb::block!(network.send(socket, &pkt)).map_err(|_| { 396 | error!("[send] NetworkError::Write"); 397 | EventError::Network(NetworkError::Write) 398 | })?; 399 | 400 | Ok(length) 401 | } 402 | 403 | fn receive + ?Sized>( 404 | &mut self, 405 | network: &mut N, 406 | ) -> nb::Result, NetworkError> { 407 | let socket = self.socket.as_mut().ok_or(NetworkError::NoSocket)?; 408 | 409 | self.rx_buf.receive(socket, network)?; 410 | 411 | Ok(PacketDecoder::new(&mut self.rx_buf)) 412 | } 413 | } 414 | 415 | /// A placeholder that keeps a buffer and constructs a packet incrementally. 416 | /// Given that underlying `TcpClientStack` throws `WouldBlock` in a non-blocking 417 | /// manner, its packet construction won't block either. 418 | #[derive(Debug)] 419 | struct PacketBuffer { 420 | range: RangeTo, 421 | buffer: Vec, 422 | } 423 | 424 | impl PacketBuffer { 425 | fn new() -> Self { 426 | let range = ..0; 427 | let buffer = Vec::new(); 428 | let mut buf = Self { range, buffer }; 429 | buf.init(); 430 | buf 431 | } 432 | 433 | /// Fills the buffer with all 0s 434 | fn init(&mut self) { 435 | self.range.end = 0; 436 | self.buffer.clear(); 437 | self.buffer 438 | .resize(self.buffer.capacity(), 0x00u8) 439 | .unwrap_or_else(|()| unreachable!("Length equals to the current capacity.")); 440 | } 441 | 442 | /// Returns a remaining fresh part of the buffer. 443 | fn buffer(&mut self) -> &mut [u8] { 444 | let range = self.range.end..; 445 | self.buffer[range].as_mut() 446 | } 447 | 448 | /// After decoding a packet, overwrite the used bytes by shifting the buffer 449 | /// by its length. Assumes the length fits within the buffer's capacity. 
450 | fn rotate(&mut self, length: usize) { 451 | self.buffer.copy_within(length.., 0); 452 | self.range.end -= length; 453 | self.buffer.truncate(self.buffer.capacity() - length); 454 | self.buffer 455 | .resize(self.buffer.capacity(), 0) 456 | .unwrap_or_else(|()| unreachable!("Length equals to the current capacity.")); 457 | } 458 | 459 | /// Receives bytes from a network socket in non-blocking mode. If incoming 460 | /// bytes found, the range gets extended covering them. 461 | fn receive(&mut self, socket: &mut S, network: &mut N) -> nb::Result<(), NetworkError> 462 | where 463 | N: TcpClientStack + ?Sized, 464 | { 465 | let buffer = self.buffer(); 466 | let len = network.receive(socket, buffer).map_err(|e| { 467 | if matches!(e, nb::Error::WouldBlock) { 468 | nb::Error::WouldBlock 469 | } else { 470 | error!("[receive] NetworkError::Read"); 471 | nb::Error::Other(NetworkError::Read) 472 | } 473 | })?; 474 | self.range.end += len; 475 | Ok(()) 476 | } 477 | } 478 | 479 | /// Provides contextual information for decoding packets. If an incoming packet 480 | /// is well-formed and has a packet type the underlying state expects, returns a 481 | /// notification. On an error, cleans up its buffer state. 482 | struct PacketDecoder<'a> { 483 | packet_buffer: &'a mut PacketBuffer, 484 | is_err: Option, 485 | } 486 | 487 | impl<'a> PacketDecoder<'a> { 488 | fn new(packet_buffer: &'a mut PacketBuffer) -> Self { 489 | Self { 490 | packet_buffer, 491 | is_err: None, 492 | } 493 | } 494 | 495 | // https://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc398718023 496 | fn packet_length(&self) -> Option { 497 | // The result of earlier decode_slice failed with an error or incomplete 498 | // packet. 499 | if self.is_err.unwrap_or(true) { 500 | return None; 501 | } 502 | 503 | // The buffer contains a valid packet. 
504 | self.packet_buffer 505 | .buffer 506 | .iter() 507 | .skip(1) 508 | .take(4) 509 | .scan(true, |continuation, byte| { 510 | let has_successor = byte & 0x80 != 0x00; 511 | let length = (byte & 0x7f) as usize; 512 | if *continuation { 513 | *continuation = has_successor; 514 | length.into() 515 | } else { 516 | // Short-circuit 517 | None 518 | } 519 | }) 520 | .enumerate() 521 | .fold(1, |acc, (i, length)| { 522 | acc + 1 + length * 0x80_usize.pow(i as u32) 523 | }) 524 | .into() 525 | } 526 | 527 | fn decode( 528 | mut self, 529 | state: &mut MqttState, 530 | ) -> nb::Result<(Option, Option>), EventError> { 531 | let buffer = self.packet_buffer.buffer[self.packet_buffer.range].as_ref(); 532 | match decode_slice(buffer) { 533 | Err(e) => { 534 | self.is_err.replace(true); 535 | error!("Packet decode error!"); 536 | 537 | Err(EventError::Encoding(e).into()) 538 | } 539 | Ok(Some(packet)) => { 540 | self.is_err.replace(false); 541 | state 542 | .handle_incoming_packet(packet) 543 | .map_err(EventError::from) 544 | .map_err(nb::Error::from) 545 | } 546 | Ok(None) => Err(nb::Error::WouldBlock), 547 | } 548 | } 549 | } 550 | 551 | impl<'a> Drop for PacketDecoder<'a> { 552 | fn drop(&mut self) { 553 | if let Some(is_err) = self.is_err { 554 | if is_err { 555 | self.packet_buffer.init(); 556 | } else { 557 | let length = self 558 | .packet_length() 559 | .unwrap_or_else(|| unreachable!("A valid packet has a non-zero length.")); 560 | self.packet_buffer.rotate(length); 561 | } 562 | } 563 | } 564 | } 565 | 566 | #[cfg(test)] 567 | mod tests { 568 | use super::*; 569 | use crate::state::{BoxedPublish, Inflight, StartTime}; 570 | use bbqueue::BBBuffer; 571 | use fugit::TimerInstantU32; 572 | use heapless::pool::singleton::Pool; 573 | use mqttrust::encoding::v4::{Connack, ConnectReturnCode, Error as EncodingError, Pid}; 574 | use mqttrust::{Publish, QoS}; 575 | 576 | #[derive(Debug)] 577 | struct ClockMock { 578 | ticks: u32, 579 | } 580 | 581 | impl 
fugit_timer::Timer<1000> for ClockMock { 582 | type Error = (); 583 | 584 | fn now(&mut self) -> fugit::TimerInstantU32<1000> { 585 | fugit::TimerInstantU32::from_ticks(self.ticks) 586 | } 587 | 588 | fn start(&mut self, _duration: fugit::TimerDurationU32<1000>) -> Result<(), Self::Error> { 589 | todo!() 590 | } 591 | 592 | fn cancel(&mut self) -> Result<(), Self::Error> { 593 | todo!() 594 | } 595 | 596 | fn wait(&mut self) -> nb::Result<(), Self::Error> { 597 | todo!() 598 | } 599 | } 600 | 601 | struct MockNetwork { 602 | pub should_fail_read: bool, 603 | pub should_fail_write: bool, 604 | } 605 | 606 | impl Dns for MockNetwork { 607 | type Error = (); 608 | 609 | fn get_host_by_name( 610 | &mut self, 611 | _hostname: &str, 612 | _addr_type: embedded_nal::AddrType, 613 | ) -> nb::Result { 614 | unimplemented!() 615 | } 616 | fn get_host_by_address( 617 | &mut self, 618 | _addr: embedded_nal::IpAddr, 619 | ) -> nb::Result, Self::Error> { 620 | unimplemented!() 621 | } 622 | } 623 | 624 | impl TcpClientStack for MockNetwork { 625 | type TcpSocket = (); 626 | type Error = (); 627 | 628 | fn socket(&mut self) -> Result { 629 | Ok(()) 630 | } 631 | 632 | fn connect( 633 | &mut self, 634 | _socket: &mut Self::TcpSocket, 635 | _remote: embedded_nal::SocketAddr, 636 | ) -> nb::Result<(), Self::Error> { 637 | Ok(()) 638 | } 639 | 640 | fn is_connected(&mut self, _socket: &Self::TcpSocket) -> Result { 641 | Ok(true) 642 | } 643 | 644 | fn send( 645 | &mut self, 646 | _socket: &mut Self::TcpSocket, 647 | buffer: &[u8], 648 | ) -> nb::Result { 649 | if self.should_fail_write { 650 | Err(nb::Error::Other(())) 651 | } else { 652 | Ok(buffer.len()) 653 | } 654 | } 655 | 656 | fn receive( 657 | &mut self, 658 | _socket: &mut Self::TcpSocket, 659 | buffer: &mut [u8], 660 | ) -> nb::Result { 661 | if self.should_fail_read { 662 | Err(nb::Error::Other(())) 663 | } else { 664 | let connack = Packet::Connack(Connack { 665 | session_present: false, 666 | code: 
ConnectReturnCode::Accepted, 667 | }); 668 | let size = encode_slice(&connack, buffer).unwrap(); 669 | Ok(size) 670 | } 671 | } 672 | 673 | fn close(&mut self, _socket: Self::TcpSocket) -> Result<(), Self::Error> { 674 | Ok(()) 675 | } 676 | } 677 | 678 | #[test] 679 | fn success_receive_multiple_packets() { 680 | let mut state = MqttState::<1000>::new(); 681 | const LEN: usize = 1024 * 10; 682 | static mut PUBLISH_MEM: [u8; LEN] = [0u8; LEN]; 683 | BoxedPublish::grow(unsafe { &mut PUBLISH_MEM }); 684 | 685 | let mut rx_buf = PacketBuffer::new(); 686 | let connack = Connack { 687 | session_present: false, 688 | code: ConnectReturnCode::Accepted, 689 | }; 690 | let publish = Publish { 691 | dup: false, 692 | qos: QoS::AtLeastOnce, 693 | pid: Some(Pid::new()), 694 | retain: false, 695 | topic_name: "test/topic", 696 | payload: &[0xff; { 1003 + 3 * 1024 }], 697 | }; 698 | 699 | let connack_len = encode_slice(&Packet::from(connack), rx_buf.buffer()).unwrap(); 700 | rx_buf.range.end += connack_len; 701 | let publish_len = encode_slice(&Packet::from(publish.clone()), rx_buf.buffer()).unwrap(); 702 | rx_buf.range.end += publish_len; 703 | assert_eq!(rx_buf.range.end, rx_buf.buffer.capacity()); 704 | 705 | // Decode the first Connack packet on the Handshake state. 706 | state.connection_status = MqttConnectionStatus::Handshake; 707 | let (n, p) = PacketDecoder::new(&mut rx_buf).decode(&mut state).unwrap(); 708 | assert_eq!(n, Some(Notification::ConnAck)); 709 | assert_eq!(p, None); 710 | 711 | let mut pkg = SerializedPacket(&mut rx_buf.buffer[rx_buf.range]); 712 | pkg.set_pid(Pid::new()).unwrap(); 713 | 714 | // Decode the second Publish packet on the Connected state. 
715 | assert_eq!(state.connection_status, MqttConnectionStatus::Connected); 716 | let (n, p) = PacketDecoder::new(&mut rx_buf).decode(&mut state).unwrap(); 717 | let publish_notification = match n { 718 | Some(Notification::Publish(p)) => p, 719 | _ => panic!(), 720 | }; 721 | assert_eq!(&publish_notification.payload, publish.payload); 722 | assert_eq!(p, Some(Packet::Puback(Pid::default()))); 723 | assert_eq!(rx_buf.range.end, 0); 724 | assert!((0..4096).all(|i| rx_buf.buffer[i] == 0)); 725 | } 726 | 727 | #[test] 728 | fn failure_receive_multiple_packets() { 729 | let mut state = MqttState::<1000>::new(); 730 | const LEN: usize = 1024 * 10; 731 | static mut PUBLISH_MEM: [u8; LEN] = [0u8; LEN]; 732 | BoxedPublish::grow(unsafe { &mut PUBLISH_MEM }); 733 | 734 | let mut rx_buf = PacketBuffer::new(); 735 | let connack_malformed = Connack { 736 | session_present: false, 737 | code: ConnectReturnCode::Accepted, 738 | }; 739 | let publish = Publish { 740 | dup: false, 741 | qos: QoS::AtLeastOnce, 742 | pid: Some(Pid::new()), 743 | retain: false, 744 | topic_name: "test/topic", 745 | payload: &[0xff; { 1003 + 3 * 1024 }], 746 | }; 747 | 748 | let connack_malformed_len = 749 | encode_slice(&Packet::from(connack_malformed), rx_buf.buffer()).unwrap(); 750 | rx_buf.buffer()[3] = 6; // An invalid connect return code. 751 | rx_buf.range.end += connack_malformed_len; 752 | let publish_len = encode_slice(&Packet::from(publish.clone()), rx_buf.buffer()).unwrap(); 753 | rx_buf.range.end += publish_len; 754 | assert_eq!(rx_buf.range.end, rx_buf.buffer.capacity()); 755 | 756 | // When a packet is malformed, we cannot tell its length. The decoder 757 | // discards the entire buffer. 
758 | state.connection_status = MqttConnectionStatus::Handshake; 759 | match PacketDecoder::new(&mut rx_buf).decode(&mut state) { 760 | Ok((_, _)) | Err(nb::Error::WouldBlock) => panic!(), 761 | Err(nb::Error::Other(e)) => { 762 | assert_eq!( 763 | e, 764 | EventError::Encoding(EncodingError::InvalidConnectReturnCode(6)) 765 | ) 766 | } 767 | } 768 | assert_eq!(state.connection_status, MqttConnectionStatus::Handshake); 769 | assert_eq!(rx_buf.range.end, 0); 770 | assert!((0..4096).all(|i| rx_buf.buffer[i] == 0)); 771 | } 772 | 773 | #[test] 774 | fn retry_behaviour() { 775 | static mut Q: BBBuffer<{ 1024 * 10 }> = BBBuffer::new(); 776 | 777 | let mut network = MockNetwork { 778 | should_fail_read: false, 779 | should_fail_write: false, 780 | }; 781 | 782 | let (_p, c) = unsafe { Q.try_split_framed().unwrap() }; 783 | let mut event = EventLoop::new( 784 | c, 785 | ClockMock { ticks: 0 }, 786 | MqttOptions::new("client", Broker::Hostname(""), 8883), 787 | ); 788 | 789 | let now = StartTime::new(TimerInstantU32::from_ticks(0)); 790 | 791 | let topic = "hello/world"; 792 | let payload = &[1, 2, 3]; 793 | 794 | let publish = Publish { 795 | qos: QoS::AtLeastOnce, 796 | pid: Some(Pid::new()), 797 | payload, 798 | dup: false, 799 | retain: false, 800 | topic_name: topic, 801 | }; 802 | 803 | let mut rx_buf = PacketBuffer::new(); 804 | let publish_len = encode_slice(&Packet::from(publish.clone()), rx_buf.buffer()).unwrap(); 805 | rx_buf.range.end += publish_len; 806 | 807 | event 808 | .state 809 | .outgoing_pub 810 | .insert(2, Inflight::new(now, &rx_buf.buffer[..rx_buf.range.end])) 811 | .unwrap(); 812 | 813 | event.state.connection_status = MqttConnectionStatus::Handshake; 814 | event.network_handle.socket = Some(()); 815 | 816 | event.connect(&mut network).unwrap(); 817 | } 818 | } 819 | -------------------------------------------------------------------------------- /mqttrust_core/src/fmt.rs: 
-------------------------------------------------------------------------------- 1 | // MIT License 2 | 3 | // Copyright (c) 2020 Dario Nieuwenhuis 4 | 5 | // Permission is hereby granted, free of charge, to any person obtaining a copy 6 | // of this software and associated documentation files (the "Software"), to deal 7 | // in the Software without restriction, including without limitation the rights 8 | // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | // copies of the Software, and to permit persons to whom the Software is 10 | // furnished to do so, subject to the following conditions: 11 | 12 | // The above copyright notice and this permission notice shall be included in all 13 | // copies or substantial portions of the Software. 14 | 15 | // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | // SOFTWARE. 22 | /* Crate-internal wrappers for assertion, panic, logging and `unwrap!` macros that dispatch to `core`, `log` or `defmt` depending on the enabled features; enabling both `defmt` and `log` is rejected at compile time. */ 23 | #![macro_use] 24 | #![allow(unused_macros)] 25 | 26 | #[cfg(all(feature = "defmt", feature = "log"))] 27 | compile_error!("You may not enable both `defmt` and `log` features."); 28 | /* Assertion and panic wrappers: each forwards its tokens to the `::defmt` macro of the same name when the `defmt` feature is enabled, and to the `::core` macro otherwise. */ 29 | macro_rules! assert { 30 | ($($x:tt)*) => { 31 | { 32 | #[cfg(not(feature = "defmt"))] 33 | ::core::assert!($($x)*); 34 | #[cfg(feature = "defmt")] 35 | ::defmt::assert!($($x)*); 36 | } 37 | }; 38 | } 39 | 40 | macro_rules! assert_eq { 41 | ($($x:tt)*) => { 42 | { 43 | #[cfg(not(feature = "defmt"))] 44 | ::core::assert_eq!($($x)*); 45 | #[cfg(feature = "defmt")] 46 | ::defmt::assert_eq!($($x)*); 47 | } 48 | }; 49 | } 50 | 51 | macro_rules!
assert_ne { 52 | ($($x:tt)*) => { 53 | { 54 | #[cfg(not(feature = "defmt"))] 55 | ::core::assert_ne!($($x)*); 56 | #[cfg(feature = "defmt")] 57 | ::defmt::assert_ne!($($x)*); 58 | } 59 | }; 60 | } 61 | 62 | macro_rules! debug_assert { 63 | ($($x:tt)*) => { 64 | { 65 | #[cfg(not(feature = "defmt"))] 66 | ::core::debug_assert!($($x)*); 67 | #[cfg(feature = "defmt")] 68 | ::defmt::debug_assert!($($x)*); 69 | } 70 | }; 71 | } 72 | 73 | macro_rules! debug_assert_eq { 74 | ($($x:tt)*) => { 75 | { 76 | #[cfg(not(feature = "defmt"))] 77 | ::core::debug_assert_eq!($($x)*); 78 | #[cfg(feature = "defmt")] 79 | ::defmt::debug_assert_eq!($($x)*); 80 | } 81 | }; 82 | } 83 | 84 | macro_rules! debug_assert_ne { 85 | ($($x:tt)*) => { 86 | { 87 | #[cfg(not(feature = "defmt"))] 88 | ::core::debug_assert_ne!($($x)*); 89 | #[cfg(feature = "defmt")] 90 | ::defmt::debug_assert_ne!($($x)*); 91 | } 92 | }; 93 | } 94 | 95 | macro_rules! todo { 96 | ($($x:tt)*) => { 97 | { 98 | #[cfg(not(feature = "defmt"))] 99 | ::core::todo!($($x)*); 100 | #[cfg(feature = "defmt")] 101 | ::defmt::todo!($($x)*); 102 | } 103 | }; 104 | } 105 | 106 | macro_rules! unreachable { 107 | ($($x:tt)*) => { 108 | { 109 | #[cfg(not(feature = "defmt"))] 110 | ::core::unreachable!($($x)*); 111 | #[cfg(feature = "defmt")] 112 | ::defmt::unreachable!($($x)*); 113 | } 114 | }; 115 | } 116 | 117 | macro_rules! panic { 118 | ($($x:tt)*) => { 119 | { 120 | #[cfg(not(feature = "defmt"))] 121 | ::core::panic!($($x)*); 122 | #[cfg(feature = "defmt")] 123 | ::defmt::panic!($($x)*); 124 | } 125 | }; 126 | } 127 | /* Logging wrappers: forward to `::log` (feature `log`) or `::defmt` (feature `defmt`); with neither feature enabled they only borrow each argument, presumably so call sites do not trigger unused-variable warnings. */ 128 | macro_rules! trace { 129 | ($s:literal $(, $x:expr)* $(,)?) => { 130 | { 131 | #[cfg(feature = "log")] 132 | ::log::trace!($s $(, $x)*); 133 | #[cfg(feature = "defmt")] 134 | ::defmt::trace!($s $(, $x)*); 135 | #[cfg(not(any(feature = "log", feature="defmt")))] 136 | let _ = ($( & $x ),*); 137 | } 138 | }; 139 | } 140 | 141 | macro_rules! debug { 142 | ($s:literal $(, $x:expr)* $(,)?)
=> { 143 | { 144 | #[cfg(feature = "log")] 145 | ::log::debug!($s $(, $x)*); 146 | #[cfg(feature = "defmt")] 147 | ::defmt::debug!($s $(, $x)*); 148 | #[cfg(not(any(feature = "log", feature="defmt")))] 149 | let _ = ($( & $x ),*); 150 | } 151 | }; 152 | } 153 | 154 | macro_rules! info { 155 | ($s:literal $(, $x:expr)* $(,)?) => { 156 | { 157 | #[cfg(feature = "log")] 158 | ::log::info!($s $(, $x)*); 159 | #[cfg(feature = "defmt")] 160 | ::defmt::info!($s $(, $x)*); 161 | #[cfg(not(any(feature = "log", feature="defmt")))] 162 | let _ = ($( & $x ),*); 163 | } 164 | }; 165 | } 166 | 167 | macro_rules! warn { 168 | ($s:literal $(, $x:expr)* $(,)?) => { 169 | { 170 | #[cfg(feature = "log")] 171 | ::log::warn!($s $(, $x)*); 172 | #[cfg(feature = "defmt")] 173 | ::defmt::warn!($s $(, $x)*); 174 | #[cfg(not(any(feature = "log", feature="defmt")))] 175 | let _ = ($( & $x ),*); 176 | } 177 | }; 178 | } 179 | 180 | macro_rules! error { 181 | ($s:literal $(, $x:expr)* $(,)?) => { 182 | { 183 | #[cfg(feature = "log")] 184 | ::log::error!($s $(, $x)*); 185 | #[cfg(feature = "defmt")] 186 | ::defmt::error!($s $(, $x)*); 187 | #[cfg(not(any(feature = "log", feature="defmt")))] 188 | let _ = ($( & $x ),*); 189 | } 190 | }; 191 | } 192 | /* `unwrap!`: with `defmt` it delegates to `::defmt::unwrap!`; otherwise it converts the argument through the local `Try` trait and, on the error side, panics with the stringified expression, the optional formatted message, and the error value. */ 193 | #[cfg(feature = "defmt")] 194 | macro_rules! unwrap { 195 | ($($x:tt)*) => { 196 | ::defmt::unwrap!($($x)*) 197 | }; 198 | } 199 | 200 | #[cfg(not(feature = "defmt"))] 201 | macro_rules! unwrap { 202 | ($arg:expr) => { 203 | match $crate::fmt::Try::into_result($arg) { 204 | ::core::result::Result::Ok(t) => t, 205 | ::core::result::Result::Err(e) => { 206 | ::core::panic!("unwrap of `{}` failed: {:?}", ::core::stringify!($arg), e); 207 | } 208 | } 209 | }; 210 | ($arg:expr, $($msg:expr),+ $(,)?
) => { 211 | match $crate::fmt::Try::into_result($arg) { 212 | ::core::result::Result::Ok(t) => t, 213 | ::core::result::Result::Err(e) => { 214 | ::core::panic!("unwrap of `{}` failed: {}: {:?}", ::core::stringify!($arg), ::core::format_args!($($msg,)*), e); 215 | } 216 | } 217 | } 218 | } 219 | 220 | #[derive(Debug, Copy, Clone, Eq, PartialEq)] 221 | pub struct NoneError; 222 | /* Local helper trait used by the non-defmt `unwrap!` to treat `Option` (whose `None` becomes `NoneError`) and `Result` (passed through unchanged) uniformly via `into_result`. */ 223 | pub trait Try { 224 | type Ok; 225 | type Error; 226 | fn into_result(self) -> Result; 227 | } 228 | 229 | impl Try for Option { 230 | type Ok = T; 231 | type Error = NoneError; 232 | 233 | #[inline] 234 | fn into_result(self) -> Result { 235 | self.ok_or(NoneError) 236 | } 237 | } 238 | 239 | impl Try for Result { 240 | type Ok = T; 241 | type Error = E; 242 | 243 | #[inline] 244 | fn into_result(self) -> Self { 245 | self 246 | } 247 | } 248 | -------------------------------------------------------------------------------- /mqttrust_core/src/lib.rs: -------------------------------------------------------------------------------- 1 | #![cfg_attr(not(test), no_std)] 2 | 3 | #[cfg(feature = "std")] 4 | extern crate std; 5 | 6 | // This mod MUST go first, so that the others see its macros.
7 | pub(crate) mod fmt; 8 | 9 | mod client; 10 | mod eventloop; 11 | mod max_payload; 12 | mod options; 13 | mod packet; 14 | mod state; 15 | 16 | pub use bbqueue; 17 | 18 | pub use client::Client; 19 | use core::convert::TryFrom; 20 | pub use eventloop::EventLoop; 21 | use heapless::{String, Vec}; 22 | use max_payload::MAX_PAYLOAD_SIZE; 23 | pub use mqttrust::encoding::v4::{Pid, Publish, QoS, QosPid, Suback}; 24 | pub use mqttrust::*; 25 | pub use options::{Broker, MqttOptions}; 26 | use state::StateError; 27 | 28 | #[derive(Debug, PartialEq)] 29 | #[cfg_attr(feature = "defmt-impl", derive(defmt::Format))] 30 | pub struct PublishNotification { 31 | pub dup: bool, 32 | pub qospid: QoS, 33 | pub retain: bool, 34 | pub topic_name: String<256>, 35 | pub payload: Vec, 36 | } 37 | 38 | /// Includes incoming packets from the network and other interesting events 39 | /// happening in the eventloop 40 | #[derive(Debug, PartialEq)] 41 | // #[cfg_attr(feature = "defmt-impl", derive(defmt::Format))] 42 | pub enum Notification { 43 | /// Incoming connection acknowledge 44 | ConnAck, 45 | /// Incoming publish from the broker 46 | #[cfg(not(feature = "std"))] 47 | Publish(heapless::pool::singleton::Box), 48 | #[cfg(feature = "std")] 49 | Publish(std::boxed::Box), 50 | /// Incoming puback from the broker 51 | Puback(Pid), 52 | /// Incoming pubrec from the broker 53 | Pubrec(Pid), 54 | /// Incoming pubcomp from the broker 55 | Pubcomp(Pid), 56 | // TODO: 57 | // Suback(Suback), 58 | /// Incoming suback from the broker 59 | Suback(Pid), 60 | /// Incoming unsuback from the broker 61 | Unsuback(Pid), 62 | // Eventloop error 63 | Abort(EventError), 64 | } 65 | 66 | impl<'a> TryFrom> for PublishNotification { 67 | type Error = StateError; 68 | 69 | fn try_from(p: Publish<'a>) -> Result { 70 | Ok(PublishNotification { 71 | dup: p.dup, 72 | qospid: p.qos, 73 | retain: p.retain, 74 | topic_name: String::from(p.topic_name), 75 | payload: Vec::from_slice(p.payload).map_err(|_| { 76 | 
error!("Failed to convert payload to notification!"); 77 | StateError::PayloadEncoding 78 | })?, 79 | }) 80 | } 81 | } 82 | 83 | /// Critical errors during eventloop polling 84 | #[derive(Debug, PartialEq)] 85 | #[cfg_attr(feature = "defmt-impl", derive(defmt::Format))] 86 | pub enum EventError { 87 | MqttState(StateError), 88 | Timeout, 89 | Encoding(mqttrust::encoding::v4::Error), 90 | Network(NetworkError), 91 | BufferSize, 92 | Clock, 93 | RequestsNotAvailable, 94 | } 95 | 96 | #[derive(Debug, PartialEq)] 97 | #[cfg_attr(feature = "defmt-impl", derive(defmt::Format))] 98 | pub enum NetworkError { 99 | Read, 100 | Write, 101 | NoSocket, 102 | SocketOpen, 103 | SocketConnect, 104 | SocketClosed, 105 | DnsLookupFailed, 106 | } 107 | 108 | impl From for EventError { 109 | fn from(e: mqttrust::encoding::v4::Error) -> Self { 110 | EventError::Encoding(e) 111 | } 112 | } 113 | 114 | impl From for EventError { 115 | fn from(e: StateError) -> Self { 116 | EventError::MqttState(e) 117 | } 118 | } 119 | -------------------------------------------------------------------------------- /mqttrust_core/src/max_payload.rs: -------------------------------------------------------------------------------- 1 | #[cfg(not(any( 2 | feature = "max_payload_size_2048", 3 | feature = "max_payload_size_4096", 4 | feature = "max_payload_size_8192" 5 | )))] 6 | pub const MAX_PAYLOAD_SIZE: usize = 4096; 7 | 8 | #[cfg(feature = "max_payload_size_2048")] 9 | pub const MAX_PAYLOAD_SIZE: usize = 2048; 10 | 11 | #[cfg(feature = "max_payload_size_4096")] 12 | pub const MAX_PAYLOAD_SIZE: usize = 4096; 13 | 14 | #[cfg(feature = "max_payload_size_8192")] 15 | pub const MAX_PAYLOAD_SIZE: usize = 8192; 16 | -------------------------------------------------------------------------------- /mqttrust_core/src/options.rs: -------------------------------------------------------------------------------- 1 | use embedded_nal::{IpAddr, Ipv4Addr}; 2 | use mqttrust::encoding::v4::LastWill; 3 | 4 | #[derive(Clone, 
Debug, PartialEq)] 5 | pub enum Broker<'a> { 6 | Hostname(&'a str), 7 | IpAddr(IpAddr), 8 | } 9 | 10 | impl<'a> From<&'a str> for Broker<'a> { 11 | fn from(s: &'a str) -> Self { 12 | Broker::Hostname(s) 13 | } 14 | } 15 | 16 | impl<'a> From for Broker<'a> { 17 | fn from(ip: IpAddr) -> Self { 18 | Broker::IpAddr(ip) 19 | } 20 | } 21 | 22 | impl<'a> From for Broker<'a> { 23 | fn from(ip: Ipv4Addr) -> Self { 24 | Broker::IpAddr(ip.into()) 25 | } 26 | } 27 | 28 | /// Options to configure the behaviour of mqtt connection 29 | /// 30 | /// **Lifetimes**: 31 | /// - 'a: The lifetime of option fields, not referenced in any MQTT packets at any point 32 | /// - 'b: The lifetime of the packet fields, backed by a slice buffer 33 | #[derive(Clone, Debug)] 34 | pub struct MqttOptions<'a> { 35 | /// broker address that you want to connect to 36 | broker_addr: Broker<'a>, 37 | /// broker port 38 | port: u16, 39 | /// keep alive time to send pingreq to broker when the connection is idle 40 | keep_alive_ms: u32, 41 | /// clean (or) persistent session 42 | clean_session: bool, 43 | /// client identifier 44 | client_id: &'a str, 45 | // alpn settings 46 | // alpn: Option>>, 47 | /// username and password 48 | credentials: Option<(&'a str, &'a [u8])>, 49 | // Minimum delay time between consecutive outgoing packets 50 | // throttle: Duration, 51 | /// Last will that will be issued on unexpected disconnect 52 | last_will: Option>, 53 | } 54 | 55 | impl<'a> MqttOptions<'a> { 56 | /// New mqtt options 57 | pub fn new(id: &'a str, broker: Broker<'a>, port: u16) -> MqttOptions<'a> { 58 | if id.starts_with(' ') || id.is_empty() { 59 | panic!("Invalid client id") 60 | } 61 | 62 | MqttOptions { 63 | broker_addr: broker, 64 | port, 65 | keep_alive_ms: 60_000, 66 | clean_session: true, 67 | client_id: id, 68 | // alpn: None, 69 | credentials: None, 70 | // throttle: Duration::from_micros(0), 71 | last_will: None, 72 | } 73 | } 74 | 75 | /// Broker address 76 | pub fn broker(&self) -> (Broker, 
u16) { 77 | (self.broker_addr.clone(), self.port) 78 | } 79 | 80 | /// Broker address 81 | pub fn set_broker(self, broker: Broker<'a>) -> Self { 82 | Self { 83 | broker_addr: broker, 84 | ..self 85 | } 86 | } 87 | 88 | pub fn set_port(self, port: u16) -> Self { 89 | Self { port, ..self } 90 | } 91 | 92 | pub fn set_last_will(self, will: LastWill<'a>) -> Self { 93 | Self { 94 | last_will: Some(will), 95 | ..self 96 | } 97 | } 98 | 99 | pub fn last_will(&self) -> Option> { 100 | self.last_will.clone() 101 | } 102 | 103 | // pub fn set_alpn(self, alpn: Vec>) -> Self { 104 | // Self { 105 | // alpn: Some(alpn), 106 | // ..self 107 | // } 108 | // } 109 | 110 | // pub fn alpn(&self) -> Option>> { 111 | // self.alpn.clone() 112 | // } 113 | 114 | /// Set number of seconds after which client should ping the broker 115 | /// if there is no other data exchange 116 | pub fn set_keep_alive(self, secs: u16) -> Self { 117 | if secs < 5 { 118 | panic!("Keep alives should be >= 5 secs"); 119 | } 120 | 121 | Self { 122 | keep_alive_ms: secs as u32 * 1000, 123 | ..self 124 | } 125 | } 126 | 127 | /// Keep alive time 128 | pub fn keep_alive_ms(&self) -> u32 { 129 | self.keep_alive_ms 130 | } 131 | 132 | /// Client identifier 133 | pub fn client_id(&self) -> &'a str { 134 | self.client_id 135 | } 136 | 137 | /// `clean_session = true` removes all the state from queues & instructs the broker 138 | /// to clean all the client state when client disconnects. 139 | /// 140 | /// When set `false`, broker will hold the client state and performs pending 141 | /// operations on the client when reconnection with same `client_id` 142 | /// happens. Local queue state is also held to retransmit packets after reconnection. 
143 | pub fn set_clean_session(self, clean_session: bool) -> Self { 144 | Self { 145 | clean_session, 146 | ..self 147 | } 148 | } 149 | 150 | /// Clean session 151 | pub fn clean_session(&self) -> bool { 152 | self.clean_session 153 | } 154 | 155 | /// Username and password 156 | pub fn set_credentials(self, username: &'a str, password: &'a [u8]) -> Self { 157 | Self { 158 | credentials: Some((username, password)), 159 | ..self 160 | } 161 | } 162 | 163 | /// Security options 164 | pub fn credentials(&self) -> (Option<&'a str>, Option<&'a [u8]>) { 165 | if let Some((username, password)) = self.credentials { 166 | (Some(username), Some(password)) 167 | } else { 168 | (None, None) 169 | } 170 | } 171 | 172 | // /// Enables throttling and sets outoing message rate to the specified 'rate' 173 | // pub fn set_throttle(self, duration: Duration) -> Self { 174 | // self.throttle = duration; 175 | // self 176 | // } 177 | 178 | // /// Outgoing message rate 179 | // pub fn throttle(&self) -> Duration { 180 | // self.throttle 181 | // } 182 | } 183 | 184 | #[cfg(test)] 185 | mod test { 186 | use super::{Ipv4Addr, MqttOptions}; 187 | use embedded_nal::{IpAddr, Ipv6Addr}; 188 | use mqttrust::{encoding::v4::LastWill, QoS}; 189 | 190 | #[test] 191 | #[should_panic] 192 | fn client_id_starts_with_space() { 193 | let _mqtt_opts = MqttOptions::new(" client_a", Ipv4Addr::new(127, 0, 0, 1).into(), 1883) 194 | .set_clean_session(true); 195 | } 196 | 197 | #[test] 198 | #[should_panic] 199 | fn no_client_id() { 200 | let _mqtt_opts = 201 | MqttOptions::new("", Ipv4Addr::localhost().into(), 1883).set_clean_session(true); 202 | } 203 | 204 | #[test] 205 | fn broker() { 206 | let opts = MqttOptions::new("client_a", Ipv4Addr::localhost().into(), 1883); 207 | assert_eq!(opts.broker_addr, Ipv4Addr::localhost().into()); 208 | assert_eq!(opts.port, 1883); 209 | assert_eq!(opts.broker(), (Ipv4Addr::localhost().into(), 1883)); 210 | assert_eq!( 211 | MqttOptions::new("client_a", 
"localhost".into(), 1883).broker_addr, 212 | "localhost".into() 213 | ); 214 | assert_eq!( 215 | MqttOptions::new("client_a", IpAddr::V4(Ipv4Addr::localhost()).into(), 1883) 216 | .broker_addr, 217 | IpAddr::V4(Ipv4Addr::localhost()).into() 218 | ); 219 | assert_eq!( 220 | MqttOptions::new("client_a", IpAddr::V6(Ipv6Addr::localhost()).into(), 1883) 221 | .broker_addr, 222 | IpAddr::V6(Ipv6Addr::localhost()).into() 223 | ); 224 | } 225 | 226 | #[test] 227 | fn client_id() { 228 | let opts = MqttOptions::new("client_a", Ipv4Addr::localhost().into(), 1883); 229 | assert_eq!(opts.client_id(), "client_a"); 230 | } 231 | 232 | #[test] 233 | fn keep_alive_ms() { 234 | let opts = MqttOptions::new("client_a", Ipv4Addr::localhost().into(), 1883); 235 | assert_eq!(opts.keep_alive_ms, 60_000); 236 | assert_eq!(opts.set_keep_alive(120).keep_alive_ms(), 120_000); 237 | } 238 | 239 | #[test] 240 | #[should_panic] 241 | fn keep_alive_panic() { 242 | let opts = MqttOptions::new("client_a", Ipv4Addr::localhost().into(), 1883); 243 | assert_eq!(opts.keep_alive_ms, 60_000); 244 | assert_eq!(opts.set_keep_alive(4).keep_alive_ms(), 120_000); 245 | } 246 | 247 | #[test] 248 | fn last_will() { 249 | let opts = MqttOptions::new("client_a", Ipv4Addr::localhost().into(), 1883); 250 | assert_eq!(opts.last_will, None); 251 | let will = LastWill { 252 | topic: "topic", 253 | message: b"Will message", 254 | qos: QoS::AtLeastOnce, 255 | retain: false, 256 | }; 257 | assert_eq!(opts.set_last_will(will.clone()).last_will(), Some(will)); 258 | } 259 | 260 | #[test] 261 | fn clean_session() { 262 | let opts = MqttOptions::new("client_a", Ipv4Addr::localhost().into(), 1883); 263 | assert_eq!(opts.clean_session, true); 264 | assert_eq!(opts.set_clean_session(false).clean_session(), false); 265 | } 266 | 267 | #[test] 268 | fn credentials() { 269 | let opts = MqttOptions::new("client_a", Ipv4Addr::localhost().into(), 1883); 270 | assert_eq!(opts.credentials, None); 271 | assert_eq!(opts.credentials(), 
(None, None)); 272 | assert_eq!( 273 | opts.set_credentials("some_user", &[]).credentials(), 274 | (Some("some_user"), Some(&b""[..])) 275 | ); 276 | } 277 | } 278 | -------------------------------------------------------------------------------- /mqttrust_core/src/packet.rs: -------------------------------------------------------------------------------- 1 | use mqttrust::encoding::v4::{ 2 | decoder::{read_header, Header}, 3 | packet::PacketType, 4 | utils::{Pid, QoS}, 5 | }; 6 | 7 | use crate::state::StateError; 8 | 9 | pub struct SerializedPacket<'a>(pub &'a mut [u8]); 10 | 11 | impl<'a> SerializedPacket<'a> { 12 | pub fn header(&self) -> Result { 13 | Header::new(self.0[0]).map_err(|_| StateError::InvalidHeader) 14 | } 15 | 16 | pub fn set_pid(&mut self, pid: Pid) -> Result<(), StateError> { 17 | let mut offset = 0; 18 | let (header, _) = read_header(self.0, &mut offset) 19 | .map_err(|_| StateError::InvalidHeader)? 20 | .ok_or(StateError::InvalidHeader)?; 21 | 22 | match (header.typ, header.qos) { 23 | (PacketType::Publish, QoS::AtLeastOnce | QoS::ExactlyOnce) => { 24 | if self.0[offset..].len() < 2 { 25 | return Err(StateError::InvalidHeader); 26 | } 27 | let len = ((self.0[offset] as usize) << 8) | self.0[offset + 1] as usize; 28 | 29 | offset += 2; 30 | if len > self.0[offset..].len() { 31 | return Err(StateError::InvalidHeader); 32 | } else { 33 | offset += len; 34 | } 35 | } 36 | ( 37 | PacketType::Subscribe 38 | | PacketType::Unsubscribe 39 | | PacketType::Suback 40 | | PacketType::Puback 41 | | PacketType::Pubrec 42 | | PacketType::Pubrel 43 | | PacketType::Pubcomp 44 | | PacketType::Unsuback, 45 | _, 46 | ) => {} 47 | _ => return Ok(()), 48 | } 49 | 50 | pid.to_buffer(&mut self.0, &mut offset) 51 | .map_err(|_| StateError::PidMissing) 52 | } 53 | 54 | pub fn to_inner(self) -> &'a mut [u8] { 55 | self.0 56 | } 57 | } 58 | 59 | #[cfg(test)] 60 | mod tests { 61 | use core::convert::TryFrom; 62 | 63 | use mqttrust::{ 64 | encoding::v4::{decode_slice, 
encode_slice}, 65 | Packet, Publish, Subscribe, SubscribeTopic, 66 | }; 67 | 68 | use super::*; 69 | 70 | #[test] 71 | fn set_publish_pid() { 72 | let publish = Packet::Publish(Publish { 73 | dup: false, 74 | qos: QoS::ExactlyOnce, 75 | pid: None, 76 | retain: false, 77 | topic_name: "test", 78 | payload: b"Whatup", 79 | }); 80 | 81 | let buf = &mut [0u8; 2048]; 82 | let len = encode_slice(&publish, buf).unwrap(); 83 | 84 | let mut ser_packet = SerializedPacket(&mut buf[..len]); 85 | 86 | let header = ser_packet.header().unwrap(); 87 | 88 | assert_eq!(header.typ, PacketType::Publish); 89 | assert_eq!(header.qos, QoS::ExactlyOnce); 90 | 91 | ser_packet.set_pid(Pid::try_from(54).unwrap()).unwrap(); 92 | 93 | let p = decode_slice(ser_packet.to_inner()).unwrap(); 94 | 95 | assert_eq!( 96 | p, 97 | Some(Packet::Publish(Publish { 98 | dup: false, 99 | qos: QoS::ExactlyOnce, 100 | pid: Some(Pid::try_from(54).unwrap()), 101 | retain: false, 102 | topic_name: "test", 103 | payload: b"Whatup", 104 | })) 105 | ) 106 | } 107 | 108 | #[test] 109 | fn set_subscribe_pid() { 110 | let subscribe = Packet::Subscribe(Subscribe::new(&[SubscribeTopic { 111 | topic_path: "AWESOME", 112 | qos: QoS::AtLeastOnce, 113 | }])); 114 | 115 | let buf = &mut [0u8; 2048]; 116 | let len = encode_slice(&subscribe, buf).unwrap(); 117 | 118 | let mut ser_packet = SerializedPacket(&mut buf[..len]); 119 | 120 | let header = ser_packet.header().unwrap(); 121 | 122 | assert_eq!(header.typ, PacketType::Subscribe); 123 | assert_eq!(header.qos, QoS::AtLeastOnce); 124 | 125 | ser_packet.set_pid(Pid::try_from(65).unwrap()).unwrap(); 126 | 127 | let p = decode_slice(ser_packet.to_inner()).unwrap(); 128 | 129 | match p { 130 | Some(Packet::Subscribe(p)) => { 131 | assert_eq!(p.pid(), Some(Pid::try_from(65).unwrap())); 132 | assert_eq!( 133 | p.topics().next(), 134 | Some(SubscribeTopic { 135 | topic_path: "AWESOME", 136 | qos: QoS::AtLeastOnce, 137 | }) 138 | ); 139 | } 140 | _ => panic!(), 141 | } 142 | } 143 | 
} 144 | -------------------------------------------------------------------------------- /mqttrust_core/src/state.rs: -------------------------------------------------------------------------------- 1 | use crate::packet::SerializedPacket; 2 | use crate::Notification; 3 | #[cfg(not(feature = "std"))] 4 | use crate::PublishNotification; 5 | use core::convert::TryInto; 6 | use fugit::TimerDurationU32; 7 | use fugit::TimerInstantU32; 8 | #[cfg(not(feature = "std"))] 9 | use heapless::{pool, pool::singleton::Pool}; 10 | use heapless::{FnvIndexMap, FnvIndexSet, IndexMap, IndexSet}; 11 | use mqttrust::encoding::v4::*; 12 | 13 | #[allow(unused)] 14 | #[derive(Debug, Clone, Copy, PartialEq, Eq)] 15 | #[cfg_attr(feature = "defmt-impl", derive(defmt::Format))] 16 | pub enum MqttConnectionStatus { 17 | Handshake, 18 | Connected, 19 | Disconnected, 20 | } 21 | 22 | #[derive(Debug, PartialEq)] 23 | #[cfg_attr(feature = "defmt-impl", derive(defmt::Format))] 24 | pub enum StateError { 25 | /// Broker's error reply to client's connect packet 26 | Connect(ConnectReturnCode), 27 | /// Invalid state for a given operation 28 | InvalidState, 29 | /// Received a packet (ack) which isn't asked for 30 | Unsolicited, 31 | /// Last pingreq isn't acked 32 | AwaitPingResp, 33 | /// Received a wrong packet while waiting for another packet 34 | WrongPacket, 35 | PayloadEncoding, 36 | InvalidUtf8, 37 | /// The maximum number of messages allowed to be simultaneously in-flight has been reached. 38 | MaxMessagesInflight, 39 | /// Non-zero QoS publications require PID 40 | PidMissing, 41 | InvalidHeader, 42 | } 43 | 44 | #[cfg(not(feature = "std"))] 45 | pool!( 46 | #[allow(non_upper_case_globals)] 47 | BoxedPublish: PublishNotification 48 | ); 49 | 50 | /// State of the mqtt connection. 
51 | /// 52 | /// Methods will just modify the state of the object without doing any network 53 | /// operations This abstracts the functionality better so that it's easy to 54 | /// switch between synchronous code, tokio (or) async/await 55 | pub struct MqttState { 56 | /// Connection status 57 | pub connection_status: MqttConnectionStatus, 58 | /// Status of last ping 59 | pub await_pingresp: bool, 60 | /// Packet id of the last outgoing packet 61 | pub last_pid: Pid, 62 | /// Outgoing QoS 1, 2 publishes which aren't acked yet 63 | pub(crate) outgoing_pub: FnvIndexMap, 2>, 64 | /// Packet ids of released QoS 2 publishes 65 | pub outgoing_rel: FnvIndexSet, 66 | /// Packet ids on incoming QoS 2 publishes 67 | pub incoming_pub: FnvIndexSet, 68 | last_ping: StartTime, 69 | } 70 | 71 | impl MqttState { 72 | /// Creates new mqtt state. Same state should be used during a 73 | /// connection for persistent sessions while new state should 74 | /// instantiated for clean sessions 75 | pub fn new() -> Self { 76 | #[cfg(not(feature = "std"))] 77 | { 78 | const LEN: usize = core::mem::size_of::>() 79 | + core::mem::align_of::>() 80 | - (core::mem::size_of::>() 81 | % core::mem::align_of::>()); 82 | 83 | static mut PUBLISH_MEM: [u8; LEN] = [0u8; LEN]; 84 | BoxedPublish::grow(unsafe { &mut PUBLISH_MEM }); 85 | } 86 | 87 | MqttState { 88 | connection_status: MqttConnectionStatus::Disconnected, 89 | await_pingresp: false, 90 | last_pid: Pid::new(), 91 | 92 | outgoing_pub: IndexMap::new(), 93 | outgoing_rel: IndexSet::new(), 94 | incoming_pub: IndexSet::new(), 95 | last_ping: StartTime::default(), 96 | } 97 | } 98 | 99 | /// Consolidates handling of all outgoing mqtt packet logic. 
Returns a 100 | /// packet which should be put on to the network by the eventloop 101 | pub fn handle_outgoing_packet<'b>( 102 | &mut self, 103 | packet: Packet<'b>, 104 | ) -> Result, StateError> { 105 | match packet { 106 | Packet::Pingreq => self.handle_outgoing_ping(), 107 | _ => unreachable!(), 108 | } 109 | } 110 | 111 | /// Consolidates handling of all outgoing mqtt packet logic. Returns a 112 | /// packet which should be put on to the network by the eventloop 113 | pub fn handle_outgoing_request( 114 | &mut self, 115 | request: &mut SerializedPacket<'_>, 116 | now: &TimerInstantU32, 117 | ) -> Result<(), StateError> { 118 | match request.header()?.typ { 119 | PacketType::Publish => self.handle_outgoing_publish(request, now)?, 120 | PacketType::Subscribe => { 121 | let pid = self.next_pid(); 122 | trace!("Sending Subscribe({:?})", pid); 123 | request.set_pid(pid)? 124 | } 125 | PacketType::Unsubscribe => { 126 | let pid = self.next_pid(); 127 | trace!("Sending Unsubscribe({:?})", pid); 128 | request.set_pid(pid)? 129 | } 130 | _ => unreachable!(), 131 | } 132 | 133 | Ok(()) 134 | } 135 | 136 | /// Consolidates handling of all incoming mqtt packets. Returns a 137 | /// `Notification` which for the user to consume and `Packet` which for the 138 | /// eventloop to put on the network E.g For incoming QoS1 publish packet, 139 | /// this method returns (Publish, Puback). 
Publish packet will be forwarded 140 | /// to user and Pubck packet will be written to network 141 | pub fn handle_incoming_packet<'b>( 142 | &mut self, 143 | packet: Packet<'b>, 144 | ) -> Result<(Option, Option>), StateError> { 145 | match packet { 146 | Packet::Connack(connack) => self 147 | .handle_incoming_connack(connack) 148 | .map(|()| (Notification::ConnAck.into(), None)), 149 | Packet::Pingresp => self.handle_incoming_pingresp(), 150 | Packet::Publish(publish) => self.handle_incoming_publish(publish), 151 | Packet::Suback(suback) => self.handle_incoming_suback(suback), 152 | Packet::Unsuback(pid) => self.handle_incoming_unsuback(pid), 153 | Packet::Puback(pid) => self.handle_incoming_puback(pid), 154 | Packet::Pubrec(pid) => self.handle_incoming_pubrec(pid), 155 | Packet::Pubrel(pid) => self.handle_incoming_pubrel(pid), 156 | Packet::Pubcomp(pid) => self.handle_incoming_pubcomp(pid), 157 | _ => { 158 | error!("Invalid incoming packet!"); 159 | Ok((None, None)) 160 | } 161 | } 162 | } 163 | 164 | /// Adds next packet identifier to QoS 1 and 2 publish packets and returns 165 | /// it by wrapping publish in packet 166 | fn handle_outgoing_publish( 167 | &mut self, 168 | request: &mut SerializedPacket<'_>, 169 | now: &TimerInstantU32, 170 | ) -> Result<(), StateError> { 171 | match request.header()?.qos { 172 | QoS::AtMostOnce => { 173 | trace!("Sending Publish({:?})", QoS::AtMostOnce); 174 | } 175 | QoS::AtLeastOnce => { 176 | let pid = self.next_pid(); 177 | trace!("Sending Publish({:?}, {:?})", pid, QoS::AtLeastOnce); 178 | self.outgoing_pub 179 | .insert(pid.get(), Inflight::new(StartTime::new(*now), &request.0)) 180 | .map_err(|_| StateError::MaxMessagesInflight)?; 181 | request.set_pid(pid)?; 182 | } 183 | QoS::ExactlyOnce => { 184 | let pid = self.next_pid(); 185 | trace!("Sending Publish({:?}, {:?})", pid, QoS::ExactlyOnce); 186 | self.outgoing_pub 187 | .insert(pid.get(), Inflight::new(StartTime::new(*now), &request.0)) 188 | .map_err(|_| 
StateError::MaxMessagesInflight)?; 189 | request.set_pid(pid)?; 190 | } 191 | } 192 | Ok(()) 193 | } 194 | 195 | /// Iterates through the list of stored publishes and removes the publish 196 | /// with the matching packet identifier. Removal is now a O(n) operation. 197 | /// This should be usually ok in case of acks due to ack ordering in normal 198 | /// conditions. But in cases where the broker doesn't guarantee the order of 199 | /// acks, the performance won't be optimal 200 | fn handle_incoming_puback( 201 | &mut self, 202 | pid: Pid, 203 | ) -> Result<(Option, Option>), StateError> { 204 | if self.outgoing_pub.contains_key(&pid.get()) { 205 | let _publish = self.outgoing_pub.remove(&pid.get()); 206 | 207 | let request = None; 208 | let notification = Some(Notification::Puback(pid)); 209 | trace!("Received Puback({:?})", pid); 210 | Ok((notification, request)) 211 | } else { 212 | error!("Unsolicited puback packet: {:?}", pid.get()); 213 | // Err(StateError::Unsolicited) 214 | Ok((None, None)) 215 | } 216 | } 217 | 218 | fn handle_incoming_suback<'a>( 219 | &mut self, 220 | suback: Suback<'a>, 221 | ) -> Result<(Option, Option>), StateError> { 222 | let request = None; 223 | trace!("Received Suback({:?})", suback.pid); 224 | // TODO: Add suback packet info here 225 | let notification = Some(Notification::Suback(suback.pid)); 226 | Ok((notification, request)) 227 | } 228 | 229 | fn handle_incoming_unsuback( 230 | &mut self, 231 | pid: Pid, 232 | ) -> Result<(Option, Option>), StateError> { 233 | let request = None; 234 | trace!("Received Unsuback({:?})", pid); 235 | let notification = Some(Notification::Unsuback(pid)); 236 | Ok((notification, request)) 237 | } 238 | 239 | /// Iterates through the list of stored publishes and removes the publish with the 240 | /// matching packet identifier. Removal is now a O(n) operation. This should be 241 | /// usually ok in case of acks due to ack ordering in normal conditions. 
But in cases 242 | /// where the broker doesn't guarantee the order of acks, the performance won't be optimal 243 | fn handle_incoming_pubrec( 244 | &mut self, 245 | pid: Pid, 246 | ) -> Result<(Option, Option>), StateError> { 247 | if self.outgoing_pub.contains_key(&pid.get()) { 248 | self.outgoing_pub.remove(&pid.get()); 249 | self.outgoing_rel 250 | .insert(pid.get()) 251 | .map_err(|_| StateError::InvalidState)?; 252 | 253 | let reply = Some(Packet::Pubrel(pid)); 254 | let notification = Some(Notification::Pubrec(pid)); 255 | Ok((notification, reply)) 256 | } else { 257 | error!("Unsolicited pubrec packet: {:?}", pid.get()); 258 | // Err(StateError::Unsolicited) 259 | Ok((None, None)) 260 | } 261 | } 262 | 263 | /// Results in a publish notification in all the QoS cases. Replys with an ack 264 | /// in case of QoS1 and Replys rec in case of QoS while also storing the message 265 | fn handle_incoming_publish<'b>( 266 | &mut self, 267 | publish: Publish<'b>, 268 | ) -> Result<(Option, Option>), StateError> { 269 | let qospid = (publish.qos, publish.pid); 270 | 271 | #[cfg(not(feature = "std"))] 272 | let boxed_publish = BoxedPublish::alloc().unwrap(); 273 | #[cfg(not(feature = "std"))] 274 | let notification = Notification::Publish(boxed_publish.init(publish.try_into().unwrap())); 275 | 276 | #[cfg(feature = "std")] 277 | let notification = Notification::Publish(std::boxed::Box::new(publish.try_into().unwrap())); 278 | 279 | let request = match qospid { 280 | (QoS::AtMostOnce, _) => None, 281 | (QoS::AtLeastOnce, Some(pid)) => Some(Packet::Puback(pid)), 282 | (QoS::ExactlyOnce, Some(pid)) => { 283 | self.incoming_pub.insert(pid.get()).map_err(|_| { 284 | error!("Failed to insert incoming pub!"); 285 | StateError::InvalidState 286 | })?; 287 | 288 | Some(Packet::Pubrec(pid)) 289 | } 290 | _ => return Err(StateError::InvalidHeader), 291 | }; 292 | Ok((Some(notification), request)) 293 | } 294 | 295 | fn handle_incoming_pubrel( 296 | &mut self, 297 | pid: Pid, 298 | 
) -> Result<(Option, Option>), StateError> { 299 | if self.incoming_pub.contains(&pid.get()) { 300 | self.incoming_pub.remove(&pid.get()); 301 | let reply = Packet::Pubcomp(pid); 302 | Ok((None, Some(reply))) 303 | } else { 304 | error!("Unsolicited pubrel packet: {:?}", pid.get()); 305 | // Err(StateError::Unsolicited) 306 | Ok((None, None)) 307 | } 308 | } 309 | 310 | fn handle_incoming_pubcomp( 311 | &mut self, 312 | pid: Pid, 313 | ) -> Result<(Option, Option>), StateError> { 314 | if self.outgoing_rel.contains(&pid.get()) { 315 | self.outgoing_rel.remove(&pid.get()); 316 | let notification = Some(Notification::Pubcomp(pid)); 317 | let reply = None; 318 | Ok((notification, reply)) 319 | } else { 320 | error!("Unsolicited pubcomp packet: {:?}", pid.get()); 321 | // Err(StateError::Unsolicited) 322 | Ok((None, None)) 323 | } 324 | } 325 | 326 | /// check when the last control packet/pingreq packet is received and return 327 | /// the status which tells if keep alive time has exceeded 328 | /// NOTE: status will be checked for zero keepalive times also 329 | fn handle_outgoing_ping<'b>(&mut self) -> Result, StateError> { 330 | // raise error if last ping didn't receive ack 331 | if self.await_pingresp { 332 | error!("Error awaiting for last ping response"); 333 | return Err(StateError::AwaitPingResp); 334 | } 335 | 336 | self.await_pingresp = true; 337 | 338 | trace!("Sending Pingreq"); 339 | 340 | Ok(Packet::Pingreq) 341 | } 342 | 343 | fn handle_incoming_pingresp( 344 | &mut self, 345 | ) -> Result<(Option, Option>), StateError> { 346 | self.await_pingresp = false; 347 | trace!("Received Pingresp"); 348 | Ok((None, None)) 349 | } 350 | 351 | pub(crate) fn handle_outgoing_connect(&mut self) { 352 | self.connection_status = MqttConnectionStatus::Handshake; 353 | } 354 | 355 | pub fn handle_incoming_connack(&mut self, connack: Connack) -> Result<(), StateError> { 356 | match connack.code { 357 | ConnectReturnCode::Accepted 358 | if self.connection_status == 
MqttConnectionStatus::Handshake => 359 | { 360 | debug!("MQTT connected!"); 361 | self.connection_status = MqttConnectionStatus::Connected; 362 | Ok(()) 363 | } 364 | ConnectReturnCode::Accepted 365 | if self.connection_status != MqttConnectionStatus::Handshake => 366 | { 367 | error!( 368 | "Invalid state. Expected = {:?}, Current = {:?}", 369 | MqttConnectionStatus::Handshake, 370 | self.connection_status 371 | ); 372 | self.connection_status = MqttConnectionStatus::Disconnected; 373 | Err(StateError::InvalidState) 374 | } 375 | code => { 376 | error!("Connection failed. Connection error = {:?}", code as u8); 377 | self.connection_status = MqttConnectionStatus::Disconnected; 378 | Err(StateError::Connect(code)) 379 | } 380 | } 381 | } 382 | 383 | fn next_pid(&mut self) -> Pid { 384 | self.last_pid = self.last_pid + 1; 385 | self.last_pid 386 | } 387 | 388 | pub(crate) fn last_ping_entry(&mut self) -> &mut StartTime { 389 | &mut self.last_ping 390 | } 391 | 392 | pub(crate) fn retries( 393 | &mut self, 394 | now: TimerInstantU32, 395 | interval: TimerDurationU32, 396 | ) -> impl Iterator)> + '_ { 397 | self.outgoing_pub 398 | .iter_mut() 399 | .filter(move |(_, inflight)| inflight.last_touch.has_elapsed(&now, interval)) 400 | } 401 | } 402 | 403 | #[derive(Clone, Copy, Debug, PartialEq, Eq)] 404 | pub struct StartTime(Option>); 405 | 406 | impl Default for StartTime { 407 | fn default() -> Self { 408 | Self(None) 409 | } 410 | } 411 | 412 | impl StartTime { 413 | pub fn new(start_time: TimerInstantU32) -> Self { 414 | Self(start_time.into()) 415 | } 416 | 417 | pub fn or_insert(&mut self, now: TimerInstantU32) -> &mut Self { 418 | self.0.get_or_insert(now); 419 | self 420 | } 421 | 422 | pub fn insert(&mut self, now: TimerInstantU32) { 423 | self.0.replace(now); 424 | } 425 | } 426 | 427 | impl StartTime { 428 | /// Check whether an interval has elapsed since this start time. 
429 | pub fn has_elapsed( 430 | &self, 431 | now: &TimerInstantU32, 432 | interval: TimerDurationU32, 433 | ) -> bool { 434 | if let Some(start_time) = self.0 { 435 | let elapse_time = start_time + interval; 436 | elapse_time <= *now 437 | } else { 438 | false 439 | } 440 | } 441 | } 442 | 443 | /// Client publication message data. 444 | #[derive(Debug)] 445 | pub(crate) struct Inflight { 446 | /// A publish of non-zero QoS. 447 | publish: heapless::Vec, 448 | /// A timestmap used for retry and expiry. 449 | last_touch: StartTime, 450 | } 451 | 452 | impl Inflight { 453 | pub(crate) fn new(last_touch: StartTime, publish: &[u8]) -> Self { 454 | assert!( 455 | !matches!( 456 | decoder::Header::new(publish[0]).unwrap().qos, 457 | QoS::AtMostOnce 458 | ), 459 | "Only non-zero QoSs are allowed." 460 | ); 461 | Self { 462 | publish: heapless::Vec::from_slice(publish).unwrap(), 463 | last_touch, 464 | } 465 | } 466 | 467 | pub(crate) fn last_touch_entry(&mut self) -> &mut StartTime { 468 | &mut self.last_touch 469 | } 470 | } 471 | 472 | impl Inflight { 473 | pub(crate) fn packet<'b>(&'b mut self, pid: u16) -> Result<&'b [u8], StateError> { 474 | let pid = pid.try_into().map_err(|_| StateError::PayloadEncoding)?; 475 | let mut packet = SerializedPacket(self.publish.as_mut()); 476 | packet.set_pid(pid)?; 477 | Ok(packet.to_inner()) 478 | } 479 | } 480 | 481 | #[cfg(test)] 482 | mod test { 483 | use super::{BoxedPublish, MqttConnectionStatus, MqttState, Packet, StateError}; 484 | use crate::{packet::SerializedPacket, Notification}; 485 | use core::convert::TryFrom; 486 | use fugit::TimerInstantU32; 487 | use heapless::pool::singleton::Pool; 488 | use mqttrust::{ 489 | encoding::v4::{decode_slice, encode_slice, Pid}, 490 | Publish, QoS, 491 | }; 492 | 493 | fn build_publish<'a>(qos: QoS, pid: Option) -> Publish<'a> { 494 | let topic = "hello/world"; 495 | let payload = &[1, 2, 3]; 496 | 497 | let pid = match qos { 498 | QoS::AtMostOnce => None, 499 | QoS::AtLeastOnce => 
pid.and_then(|p| Pid::try_from(p).ok()), 500 | QoS::ExactlyOnce => pid.and_then(|p| Pid::try_from(p).ok()), 501 | }; 502 | 503 | Publish { 504 | qos, 505 | pid, 506 | payload, 507 | dup: false, 508 | retain: false, 509 | topic_name: topic, 510 | } 511 | } 512 | 513 | fn build_mqttstate() -> MqttState<1000> { 514 | let state = MqttState::new(); 515 | const LEN: usize = 1024 * 10; 516 | static mut PUBLISH_MEM: [u8; LEN] = [0u8; LEN]; 517 | BoxedPublish::grow(unsafe { &mut PUBLISH_MEM }); 518 | state 519 | } 520 | 521 | #[test] 522 | fn handle_outgoing_requests() { 523 | let buf = &mut [0u8; 256]; 524 | let now = TimerInstantU32::from_ticks(0); 525 | let mut mqtt = build_mqttstate(); 526 | 527 | // Publish 528 | let publish = Packet::Publish(build_publish(QoS::AtMostOnce, None)); 529 | 530 | let len = encode_slice(&publish, buf).unwrap(); 531 | 532 | // Packet id shouldn't be set and publish shouldn't be saved in queue 533 | mqtt.handle_outgoing_request(&mut SerializedPacket(&mut buf[..len]), &now) 534 | .unwrap(); 535 | // assert_eq!(publish_out.qos, QoS::AtMostOnce); 536 | // assert_eq!(mqtt.outgoing_pub.len(), 0); 537 | 538 | // // Subscribe 539 | // let subscribe = SubscribeRequest { 540 | // topics: Vec::from_slice(&[ 541 | // SubscribeTopic { 542 | // topic_path: String::from("some/topic"), 543 | // qos: QoS::AtLeastOnce, 544 | // }, 545 | // SubscribeTopic { 546 | // topic_path: String::from("some/other/topic"), 547 | // qos: QoS::ExactlyOnce, 548 | // }, 549 | // ]) 550 | // .unwrap(), 551 | // }; 552 | 553 | // // Packet id should be set and subscribe shouldn't be saved in publish queue 554 | // mqtt.handle_outgoing_request(subscribe.try_into().unwrap(), buf, &now) 555 | // .unwrap(); 556 | // let mut topics_iter = subscribe_out.topics.iter(); 557 | 558 | // assert_eq!(subscribe_out.pid, Pid::try_from(2).unwrap()); 559 | // assert_eq!( 560 | // topics_iter.next(), 561 | // Some(&SubscribeTopic { 562 | // qos: QoS::AtLeastOnce, 563 | // topic_path: 
String::from("some/topic") 564 | // }) 565 | // ); 566 | // assert_eq!( 567 | // topics_iter.next(), 568 | // Some(&SubscribeTopic { 569 | // qos: QoS::ExactlyOnce, 570 | // topic_path: String::from("some/other/topic") 571 | // }) 572 | // ); 573 | // assert_eq!(topics_iter.next(), None); 574 | // assert_eq!(mqtt.outgoing_pub.len(), 0); 575 | 576 | // // Unsubscribe 577 | // let unsubscribe = UnsubscribeRequest { 578 | // topics: Vec::from_slice(&[ 579 | // String::from("some/topic"), 580 | // String::from("some/other/topic"), 581 | // ]) 582 | // .unwrap(), 583 | // }; 584 | 585 | // // Packet id should be set and subscribe shouldn't be saved in publish queue 586 | // let unsubscribe_out = 587 | // match mqtt.handle_outgoing_request(unsubscribe.try_into().unwrap(), buf, &now) { 588 | // Ok(Packet::Unsubscribe(p)) => p, 589 | // _ => panic!("Invalid packet. Should've been a unsubscribe packet"), 590 | // }; 591 | // let mut topics_iter = unsubscribe_out.topics.iter(); 592 | 593 | // assert_eq!(unsubscribe_out.pid, Pid::try_from(3).unwrap()); 594 | // assert_eq!(topics_iter.next(), Some(&String::from("some/topic"))); 595 | // assert_eq!(topics_iter.next(), Some(&String::from("some/other/topic"))); 596 | // assert_eq!(topics_iter.next(), None); 597 | // assert_eq!(mqtt.outgoing_pub.len(), 0); 598 | } 599 | 600 | #[test] 601 | fn outgoing_publish_handle_should_set_pid_correctly_and_add_publish_to_queue_correctly() { 602 | let buf = &mut [0u8; 256]; 603 | let now = TimerInstantU32::from_ticks(0); 604 | 605 | let mut mqtt = build_mqttstate(); 606 | 607 | // QoS0 Publish 608 | let publish = Packet::Publish(build_publish(QoS::AtMostOnce, None)); 609 | let len = encode_slice(&publish, buf).unwrap(); 610 | let mut pkg = SerializedPacket(&mut buf[..len]); 611 | 612 | // Packet id shouldn't be set and publish shouldn't be saved in queue 613 | mqtt.handle_outgoing_publish(&mut pkg, &now).unwrap(); 614 | 615 | let publish_out = match decode_slice(pkg.to_inner()).unwrap() { 616 
Some(Packet::Publish(p)) => p,
            _ => panic!(),
        };
        assert_eq!(publish_out.qos, QoS::AtMostOnce);
        assert_eq!(mqtt.outgoing_pub.len(), 0);

        // QoS1 Publish
        let publish = Packet::Publish(build_publish(QoS::AtLeastOnce, None));
        let len = encode_slice(&publish, buf).unwrap();
        let mut pkg = SerializedPacket(&mut buf[..len]);

        // Packet id should be set and publish should be saved in queue
        mqtt.handle_outgoing_publish(&mut pkg, &now).unwrap();
        let publish_out = match decode_slice(pkg.to_inner()).unwrap() {
            Some(Packet::Publish(p)) => p,
            _ => panic!(),
        };
        assert_eq!(publish_out.qos, QoS::AtLeastOnce);
        assert_eq!(publish_out.pid, Some(Pid::try_from(2).unwrap()));
        assert_eq!(mqtt.outgoing_pub.len(), 1);
    }

    /// Of three incoming publishes (QoS 0, 1 and 2), only the QoS 2
    /// (`ExactlyOnce`) one is recorded in `incoming_pub`.
    #[test]
    fn incoming_publish_should_be_added_to_queue_correctly() {
        let mut mqtt = build_mqttstate();

        // QoS0, 1, 2 Publishes
        let publish1 = build_publish(QoS::AtMostOnce, Some(1));
        let publish2 = build_publish(QoS::AtLeastOnce, Some(2));
        let publish3 = build_publish(QoS::ExactlyOnce, Some(3));

        mqtt.handle_incoming_publish(publish1).unwrap();
        mqtt.handle_incoming_publish(publish2).unwrap();
        mqtt.handle_incoming_publish(publish3).unwrap();

        // Only the QoS2 publish should be added to the queue.
        assert_eq!(mqtt.incoming_pub.len(), 1);
        assert!(mqtt.incoming_pub.contains(&3));
    }

    /// An incoming QoS 2 publish is surfaced to the user as a
    /// `Notification::Publish` and answered on the network with a PUBREC
    /// carrying the same packet id.
    #[test]
    fn incoming_qos2_publish_should_send_rec_to_network_and_publish_to_user() {
        let mut mqtt = build_mqttstate();
        let publish = build_publish(QoS::ExactlyOnce, Some(1));

        let (notification, request) = mqtt.handle_incoming_publish(publish).unwrap();

        match notification {
            Some(Notification::Publish(publish)) => assert_eq!(publish.qospid, QoS::ExactlyOnce),
            _ => panic!("Invalid notification: {:?}", notification),
        }

        match request {
            Some(Packet::Pubrec(pid)) => assert_eq!(pid.get(), 1),
            _ => panic!("Invalid network request: {:?}", request),
        }
    }

    /// A QoS 1 publish sent through `handle_outgoing_publish` is kept in
    /// `outgoing_pub` (these tests expect the first QoS>0 publish to get pid 2)
    /// until the matching PUBACK removes it.
    #[test]
    fn incoming_puback_should_remove_correct_publish_from_queue() {
        let mut mqtt = build_mqttstate();
        let buf = &mut [0u8; 256];
        let now = TimerInstantU32::from_ticks(0);

        let publish1 = Packet::Publish(build_publish(QoS::AtLeastOnce, None));
        let len = encode_slice(&publish1, buf).unwrap();
        let mut pkg1 = SerializedPacket(&mut buf[..len]);
        mqtt.handle_outgoing_publish(&mut pkg1, &now).unwrap();

        assert_eq!(mqtt.outgoing_pub.len(), 1);

        // The stored copy must still decode as the original QoS1 publish.
        let backup = mqtt.outgoing_pub.get_mut(&2).unwrap().packet(1).unwrap();
        let publish_out = match decode_slice(backup).unwrap() {
            Some(Packet::Publish(p)) => p,
            _ => panic!(),
        };
        assert_eq!(publish_out.qos, QoS::AtLeastOnce);

        mqtt.handle_incoming_puback(Pid::try_from(2).unwrap())
            .unwrap();
        assert_eq!(mqtt.outgoing_pub.len(), 0);
    }

    /// PUBREC for an outgoing QoS 2 publish moves its pid out of
    /// `outgoing_pub` and into the release queue `outgoing_rel`.
    #[test]
    fn incoming_pubrec_should_release_correct_publish_from_queue_and_add_releaseid_to_rel_queue() {
        let mut mqtt = build_mqttstate();
        let buf = &mut [0u8; 256];
        let now = TimerInstantU32::from_ticks(0);

        let publish = Packet::Publish(build_publish(QoS::ExactlyOnce, None));
        let len = encode_slice(&publish, buf).unwrap();
        let mut pkg = SerializedPacket(&mut buf[..len]);
        mqtt.handle_outgoing_publish(&mut pkg, &now).unwrap();

        mqtt.handle_incoming_pubrec(Pid::try_from(2).unwrap())
            .unwrap();
        assert_eq!(mqtt.outgoing_pub.len(), 0);
        assert_eq!(mqtt.outgoing_rel.len(), 1);

        // check if the element's pid is 2
        assert!(mqtt.outgoing_rel.contains(&2));
    }

    /// PUBREC for an outgoing QoS 2 publish is answered with a PUBREL for the
    /// same pid.
    ///
    /// NOTE: contrary to `nothing_to_user` in the test name, the assertions
    /// below expect the user to receive a `Notification::Pubrec`.
    #[test]
    fn incoming_pubrec_should_send_release_to_network_and_nothing_to_user() {
        let mut mqtt = build_mqttstate();
        let buf = &mut [0u8; 256];
        let now = TimerInstantU32::from_ticks(0);
        let pid = Pid::try_from(2).unwrap();
        assert_eq!(pid.get(), 2);

        let publish = Packet::Publish(build_publish(QoS::ExactlyOnce, None));
        let len = encode_slice(&publish, buf).unwrap();
        let mut pkg = SerializedPacket(&mut buf[..len]);
        mqtt.handle_outgoing_publish(&mut pkg, &now).unwrap();

        let (notification, request) = mqtt.handle_incoming_pubrec(pid).unwrap();

        assert_eq!(notification, Some(Notification::Pubrec(pid)));
        assert_eq!(request, Some(Packet::Pubrel(pid)));
    }

    /// PUBREL for a previously received QoS 2 publish is answered with a
    /// PUBCOMP for the same pid and produces no user notification.
    #[test]
    fn incoming_pubrel_should_send_comp_to_network_and_nothing_to_user() {
        let mut mqtt = build_mqttstate();
        let publish = build_publish(QoS::ExactlyOnce, Some(1));

        let pid = Pid::try_from(1).unwrap();
        assert_eq!(pid.get(), 1);

        mqtt.handle_incoming_publish(publish).unwrap();

        let (notification, request) = mqtt.handle_incoming_pubrel(pid).unwrap();
        assert_eq!(notification, None);
        assert_eq!(request, Some(Packet::Pubcomp(pid)));
    }

    /// PUBCOMP completes the outgoing QoS 2 handshake and removes the pid
    /// from the release queue `outgoing_rel`.
    #[test]
    fn incoming_pubcomp_should_release_correct_pid_from_release_queue() {
        let mut mqtt = build_mqttstate();
        let buf = &mut [0u8; 256];
        let now = TimerInstantU32::from_ticks(0);
        let publish = Packet::Publish(build_publish(QoS::ExactlyOnce, None));
        let len = encode_slice(&publish, buf).unwrap();
        let mut pkg = SerializedPacket(&mut buf[..len]);

        let pid = Pid::try_from(2).unwrap();

        mqtt.handle_outgoing_publish(&mut pkg, &now).unwrap();
        mqtt.handle_incoming_pubrec(pid).unwrap();

        // After PUBREC the pid has moved from the publish queue into the
        // release queue; establish that precondition before PUBCOMP.
        assert_eq!(mqtt.outgoing_pub.len(), 0);
        assert!(mqtt.outgoing_rel.contains(&2));

        mqtt.handle_incoming_pubcomp(pid).unwrap();

        // Checking `outgoing_pub` alone would pass even if PUBCOMP were a
        // no-op (PUBREC already emptied it), so also verify the release
        // queue has been cleared.
        assert_eq!(mqtt.outgoing_pub.len(), 0);
        assert_eq!(mqtt.outgoing_rel.len(), 0);
    }

    /// A second ping without an intervening PINGRESP fails with
    /// `StateError::AwaitPingResp`, even when other traffic (a publish and
    /// its PUBACK) was exchanged in between.
    #[test]
    fn outgoing_ping_handle_should_throw_errors_for_no_pingresp() {
        let mut mqtt = build_mqttstate();
        let buf = &mut [0u8; 256];
        let now = TimerInstantU32::from_ticks(0);
        mqtt.connection_status = MqttConnectionStatus::Connected;
        assert_eq!(mqtt.handle_outgoing_ping(), Ok(Packet::Pingreq));
        assert!(mqtt.await_pingresp);

        // network activity other than pingresp
        let publish = Packet::Publish(build_publish(QoS::AtLeastOnce, None));
        let len = encode_slice(&publish, buf).unwrap();
        let mut pkg = SerializedPacket(&mut buf[..len]);

        mqtt.handle_outgoing_publish(&mut pkg, &now).unwrap();
        mqtt.handle_incoming_packet(Packet::Puback(Pid::try_from(2).unwrap()))
            .unwrap();

        // should throw error because we didn't get pingresp for previous ping
        assert_eq!(mqtt.handle_outgoing_ping(), Err(StateError::AwaitPingResp));
    }

    /// Receiving PINGRESP clears `await_pingresp`, so the next ping succeeds.
    #[test]
    fn outgoing_ping_handle_should_succeed_if_pingresp_is_received() {
        let mut mqtt = build_mqttstate();

        mqtt.connection_status = MqttConnectionStatus::Connected;

        // should ping
        assert_eq!(mqtt.handle_outgoing_ping(), Ok(Packet::Pingreq));
        assert!(mqtt.await_pingresp);
        assert_eq!(
            mqtt.handle_incoming_packet(Packet::Pingresp),
            Ok((None, None))
        );
        assert!(!mqtt.await_pingresp);

        // should ping
        assert_eq!(mqtt.handle_outgoing_ping(), Ok(Packet::Pingreq));
        assert!(mqtt.await_pingresp);
    }
}
--------------------------------------------------------------------------------