├── .github └── workflows │ └── build-and-test.yml ├── .gitignore ├── .travis.yml ├── Cargo.toml ├── LICENSE ├── README.md ├── examples ├── pub-client.rs ├── simple.rs ├── sub-client-async.rs └── sub-client.rs ├── rustfmt.toml └── src ├── control ├── fixed_header.rs ├── mod.rs ├── packet_type.rs └── variable_header │ ├── connect_ack_flags.rs │ ├── connect_flags.rs │ ├── connect_ret_code.rs │ ├── keep_alive.rs │ ├── mod.rs │ ├── packet_identifier.rs │ ├── protocol_level.rs │ ├── protocol_name.rs │ └── topic_name.rs ├── encodable.rs ├── lib.rs ├── packet ├── connack.rs ├── connect.rs ├── disconnect.rs ├── mod.rs ├── pingreq.rs ├── pingresp.rs ├── puback.rs ├── pubcomp.rs ├── publish.rs ├── pubrec.rs ├── pubrel.rs ├── suback.rs ├── subscribe.rs ├── unsuback.rs └── unsubscribe.rs ├── qos.rs ├── topic_filter.rs └── topic_name.rs /.github/workflows/build-and-test.yml: -------------------------------------------------------------------------------- 1 | name: Build & Test 2 | 3 | on: 4 | push: 5 | branches: [master] 6 | pull_request: 7 | branches: [master] 8 | 9 | env: 10 | CARGO_TERM_COLOR: always 11 | 12 | jobs: 13 | build-and-test: 14 | runs-on: ubuntu-latest 15 | 16 | steps: 17 | - uses: actions/checkout@v2 18 | - name: Build 19 | run: cargo build --verbose 20 | - name: Run tests 21 | run: cargo test --verbose 22 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | target 2 | Cargo.lock 3 | .vscode 4 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: rust 2 | rust: 3 | - stable 4 | - nightly 5 | 6 | script: 7 | - cargo test -v 8 | - cargo test --features "tokio-codec" 9 | -------------------------------------------------------------------------------- /Cargo.toml: 
-------------------------------------------------------------------------------- 1 | [package] 2 | authors = ["Y. T. Chung "] 3 | name = "mqtt-protocol" 4 | version = "0.12.0" 5 | license = "MIT/Apache-2.0" 6 | description = "MQTT Protocol Library" 7 | keywords = ["mqtt", "protocol"] 8 | repository = "https://github.com/zonyitoo/mqtt-rs" 9 | documentation = "https://docs.rs/mqtt-protocol" 10 | edition = "2018" 11 | 12 | [dependencies] 13 | byteorder = "1.3" 14 | log = "0.4" 15 | tokio = { version = "1", optional = true } 16 | tokio-util = { version = "0.6", features = ["codec"], optional = true } 17 | bytes = { version = "1.0", optional = true } 18 | thiserror = "1.0" 19 | 20 | [dev-dependencies] 21 | clap = "2" 22 | env_logger = "0.8" 23 | tokio = { version = "1.0", features = ["macros", "rt-multi-thread", "net", "time", "io-util"] } 24 | futures = { version = "0.3" } 25 | uuid = { version = "0.8", features = ["v4"] } 26 | 27 | [features] 28 | tokio-codec = ["tokio", "tokio-util", "bytes"] 29 | default = [] 30 | 31 | [lib] 32 | name = "mqtt" 33 | 34 | [[example]] 35 | name = "sub-client-async" 36 | required-features = ["tokio"] 37 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2015 Y. T. CHUNG 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy of 6 | this software and associated documentation files (the "Software"), to deal in 7 | the Software without restriction, including without limitation the rights to 8 | use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of 9 | the Software, and to permit persons to whom the Software is furnished to do so, 10 | subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS 17 | FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR 18 | COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER 19 | IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN 20 | CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 21 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MQTT-rs 2 | 3 | [![Build Status](https://img.shields.io/travis/zonyitoo/mqtt-rs.svg)](https://travis-ci.org/zonyitoo/mqtt-rs) 4 | ![Build & Test](https://github.com/zonyitoo/mqtt-rs/workflows/Build%20&%20Test/badge.svg) 5 | [![License](https://img.shields.io/github/license/zonyitoo/mqtt-rs.svg)](https://github.com/zonyitoo/mqtt-rs) 6 | [![crates.io](https://img.shields.io/crates/v/mqtt-protocol.svg)](https://crates.io/crates/mqtt-protocol) 7 | [![dependency status](https://deps.rs/repo/github/zonyitoo/mqtt-rs/status.svg)](https://deps.rs/repo/github/zonyitoo/mqtt-rs) 8 | 9 | MQTT protocol library for Rust 10 | 11 | ```toml 12 | [dependencies] 13 | mqtt-protocol = "0.12" 14 | ``` 15 | 16 | ## Usage 17 | 18 | ```rust 19 | extern crate mqtt; 20 | 21 | use std::io::Cursor; 22 | 23 | use mqtt::{Encodable, Decodable}; 24 | use mqtt::packet::{VariablePacket, PublishPacket, QoSWithPacketIdentifier}; 25 | use mqtt::TopicName; 26 | 27 | fn main() { 28 | // Create a new Publish packet 29 | let packet = PublishPacket::new(TopicName::new("mqtt/learning").unwrap(), 30 | QoSWithPacketIdentifier::Level2(10), 31 | "Hello MQTT!"); 32 | 33 | // Encode 34 | let mut buf = Vec::new(); 35 | packet.encode(&mut buf).unwrap(); 36 | println!("Encoded: {:?}", buf); 37 | 38 | // Decode it with known type 39 | let mut dec_buf 
= Cursor::new(&buf[..]); 40 | let decoded = PublishPacket::decode(&mut dec_buf).unwrap(); 41 | println!("Decoded: {:?}", decoded); 42 | assert_eq!(packet, decoded); 43 | 44 | // Auto decode by the fixed header 45 | let mut dec_buf = Cursor::new(&buf[..]); 46 | let auto_decode = VariablePacket::decode(&mut dec_buf).unwrap(); 47 | println!("Variable packet decode: {:?}", auto_decode); 48 | assert_eq!(VariablePacket::PublishPacket(packet), auto_decode); 49 | } 50 | ``` 51 | 52 | ## Note 53 | 54 | * Based on [MQTT 3.1.1](http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html) 55 | -------------------------------------------------------------------------------- /examples/pub-client.rs: -------------------------------------------------------------------------------- 1 | #[macro_use] 2 | extern crate log; 3 | 4 | use std::env; 5 | use std::io::{self, Write}; 6 | use std::net::TcpStream; 7 | use std::thread; 8 | 9 | use clap::{App, Arg}; 10 | 11 | use uuid::Uuid; 12 | 13 | use mqtt::control::variable_header::ConnectReturnCode; 14 | use mqtt::packet::*; 15 | use mqtt::{Decodable, Encodable, QualityOfService}; 16 | use mqtt::{TopicFilter, TopicName}; 17 | 18 | fn generate_client_id() -> String { 19 | format!("/MQTT/rust/{}", Uuid::new_v4()) 20 | } 21 | 22 | fn main() { 23 | // configure logging 24 | env::set_var("RUST_LOG", env::var_os("RUST_LOG").unwrap_or_else(|| "info".into())); 25 | env_logger::init(); 26 | 27 | let matches = App::new("sub-client") 28 | .author("Y. T. 
Chung ") 29 | .arg( 30 | Arg::with_name("SERVER") 31 | .short("S") 32 | .long("server") 33 | .takes_value(true) 34 | .required(true) 35 | .help("MQTT server address (host:port)"), 36 | ) 37 | .arg( 38 | Arg::with_name("SUBSCRIBE") 39 | .short("s") 40 | .long("subscribe") 41 | .takes_value(true) 42 | .multiple(true) 43 | .required(true) 44 | .help("Channel filter to subscribe"), 45 | ) 46 | .arg( 47 | Arg::with_name("USER_NAME") 48 | .short("u") 49 | .long("username") 50 | .takes_value(true) 51 | .help("Login user name"), 52 | ) 53 | .arg( 54 | Arg::with_name("PASSWORD") 55 | .short("p") 56 | .long("password") 57 | .takes_value(true) 58 | .help("Password"), 59 | ) 60 | .arg( 61 | Arg::with_name("CLIENT_ID") 62 | .short("i") 63 | .long("client-identifier") 64 | .takes_value(true) 65 | .help("Client identifier"), 66 | ) 67 | .get_matches(); 68 | 69 | let server_addr = matches.value_of("SERVER").unwrap(); 70 | let client_id = matches 71 | .value_of("CLIENT_ID") 72 | .map(|x| x.to_owned()) 73 | .unwrap_or_else(generate_client_id); 74 | let channel_filters: Vec<(TopicFilter, QualityOfService)> = matches 75 | .values_of("SUBSCRIBE") 76 | .unwrap() 77 | .map(|c| (TopicFilter::new(c.to_string()).unwrap(), QualityOfService::Level0)) 78 | .collect(); 79 | 80 | info!("Connecting to {:?} ... 
", server_addr); 81 | let mut stream = TcpStream::connect(server_addr).unwrap(); 82 | info!("Connected!"); 83 | 84 | info!("Client identifier {:?}", client_id); 85 | let mut conn = ConnectPacket::new(client_id); 86 | conn.set_clean_session(true); 87 | let mut buf = Vec::new(); 88 | conn.encode(&mut buf).unwrap(); 89 | stream.write_all(&buf[..]).unwrap(); 90 | 91 | let connack = ConnackPacket::decode(&mut stream).unwrap(); 92 | trace!("CONNACK {:?}", connack); 93 | 94 | if connack.connect_return_code() != ConnectReturnCode::ConnectionAccepted { 95 | panic!( 96 | "Failed to connect to server, return code {:?}", 97 | connack.connect_return_code() 98 | ); 99 | } 100 | 101 | info!("Applying channel filters {:?} ...", channel_filters); 102 | let sub = SubscribePacket::new(10, channel_filters); 103 | let mut buf = Vec::new(); 104 | sub.encode(&mut buf).unwrap(); 105 | stream.write_all(&buf[..]).unwrap(); 106 | 107 | let channels: Vec = matches 108 | .values_of("SUBSCRIBE") 109 | .unwrap() 110 | .map(|c| TopicName::new(c.to_string()).unwrap()) 111 | .collect(); 112 | 113 | let user_name = matches.value_of("USER_NAME").unwrap_or(""); 114 | 115 | let mut cloned_stream = stream.try_clone().unwrap(); 116 | thread::spawn(move || { 117 | loop { 118 | let packet = match VariablePacket::decode(&mut cloned_stream) { 119 | Ok(pk) => pk, 120 | Err(err) => { 121 | error!("Error in receiving packet {:?}", err); 122 | continue; 123 | } 124 | }; 125 | trace!("PACKET {:?}", packet); 126 | 127 | match packet { 128 | VariablePacket::PingreqPacket(..) => { 129 | let pingresp = PingrespPacket::new(); 130 | info!("Sending Ping response {:?}", pingresp); 131 | pingresp.encode(&mut cloned_stream).unwrap(); 132 | } 133 | VariablePacket::DisconnectPacket(..) 
=> { 134 | break; 135 | } 136 | _ => { 137 | // Ignore other packets in pub client 138 | } 139 | } 140 | } 141 | }); 142 | 143 | let stdin = io::stdin(); 144 | loop { 145 | print!("{}: ", user_name); 146 | io::stdout().flush().unwrap(); 147 | 148 | let mut line = String::new(); 149 | stdin.read_line(&mut line).unwrap(); 150 | 151 | if line.trim_end() == "" { 152 | continue; 153 | } 154 | 155 | let message = format!("{}: {}", user_name, line.trim_end()); 156 | 157 | for chan in &channels { 158 | // let publish_packet = PublishPacket::new(chan.clone(), QoSWithPacketIdentifier::Level0, message.clone()); 159 | let publish_packet = PublishPacketRef::new(chan, QoSWithPacketIdentifier::Level0, message.as_bytes()); 160 | let mut buf = Vec::new(); 161 | publish_packet.encode(&mut buf).unwrap(); 162 | stream.write_all(&buf[..]).unwrap(); 163 | } 164 | } 165 | } 166 | -------------------------------------------------------------------------------- /examples/simple.rs: -------------------------------------------------------------------------------- 1 | use std::io::Cursor; 2 | 3 | use mqtt::packet::{PublishPacket, QoSWithPacketIdentifier, VariablePacket}; 4 | use mqtt::TopicName; 5 | use mqtt::{Decodable, Encodable}; 6 | 7 | fn main() { 8 | // Create a new Publish packet 9 | let packet = PublishPacket::new( 10 | TopicName::new("mqtt/learning").unwrap(), 11 | QoSWithPacketIdentifier::Level2(10), 12 | "Hello MQTT!", 13 | ); 14 | 15 | // Encode 16 | let mut buf = Vec::new(); 17 | packet.encode(&mut buf).unwrap(); 18 | println!("Encoded: {:?}", buf); 19 | 20 | // Decode it with known type 21 | let mut dec_buf = Cursor::new(&buf[..]); 22 | let decoded = PublishPacket::decode(&mut dec_buf).unwrap(); 23 | println!("Decoded: {:?}", decoded); 24 | assert_eq!(packet, decoded); 25 | 26 | // Auto decode by the fixed header 27 | let mut dec_buf = Cursor::new(&buf[..]); 28 | let auto_decode = VariablePacket::decode(&mut dec_buf).unwrap(); 29 | println!("Variable packet decode: {:?}", 
auto_decode); 30 | assert_eq!(VariablePacket::PublishPacket(packet), auto_decode); 31 | } 32 | -------------------------------------------------------------------------------- /examples/sub-client-async.rs: --------------------------------------------------------------------------------
//! Asynchronous subscribing client example.
//!
//! Performs the CONNECT/SUBSCRIBE handshake with blocking I/O, then hands the socket to
//! tokio and concurrently sends keep-alive PINGREQs and logs incoming PUBLISH messages.
//! The program exits once either of those two tasks stops.

use std::env;
use std::io::Write;
use std::net;
use std::str;
use std::time::Duration;

use clap::{App, Arg};
use log::{error, info, trace};

use uuid::Uuid;

use tokio::io::AsyncWriteExt;
use tokio::net::TcpStream;

use mqtt::control::variable_header::ConnectReturnCode;
use mqtt::packet::*;
use mqtt::TopicFilter;
use mqtt::{Decodable, Encodable, QualityOfService};

/// Generates a random client identifier of the form `/MQTT/rust/<uuid-v4>`.
fn generate_client_id() -> String {
    format!("/MQTT/rust/{}", Uuid::new_v4())
}

#[tokio::main]
async fn main() {
    // configure logging
    env::set_var("RUST_LOG", env::var_os("RUST_LOG").unwrap_or_else(|| "info".into()));
    env_logger::init();

    let matches = App::new("sub-client-async")
        .author("Y. T. Chung ")
        .arg(
            Arg::with_name("SERVER")
                .short("S")
                .long("server")
                .takes_value(true)
                .required(true)
                .help("MQTT server address (host:port)"),
        )
        .arg(
            Arg::with_name("SUBSCRIBE")
                .short("s")
                .long("subscribe")
                .takes_value(true)
                .multiple(true)
                .required(true)
                .help("Channel filter to subscribe"),
        )
        .arg(
            Arg::with_name("USER_NAME")
                .short("u")
                .long("username")
                .takes_value(true)
                .help("Login user name"),
        )
        .arg(
            Arg::with_name("PASSWORD")
                .short("p")
                .long("password")
                .takes_value(true)
                .help("Password"),
        )
        .arg(
            Arg::with_name("CLIENT_ID")
                .short("i")
                .long("client-identifier")
                .takes_value(true)
                .help("Client identifier"),
        )
        .get_matches();

    let server_addr = matches.value_of("SERVER").unwrap();
    let client_id = matches
        .value_of("CLIENT_ID")
        .map(|x| x.to_owned())
        .unwrap_or_else(generate_client_id);
    let channel_filters: Vec<(TopicFilter, QualityOfService)> = matches
        .values_of("SUBSCRIBE")
        .unwrap()
        .map(|c| (TopicFilter::new(c.to_string()).unwrap(), QualityOfService::Level0))
        .collect();

    let keep_alive = 10;

    info!("Connecting to {:?} ... ", server_addr);
    let mut stream = net::TcpStream::connect(server_addr).unwrap();
    info!("Connected!");

    info!("Client identifier {:?}", client_id);
    let mut conn = ConnectPacket::new(client_id);
    conn.set_clean_session(true);
    conn.set_keep_alive(keep_alive);
    let mut buf = Vec::new();
    conn.encode(&mut buf).unwrap();
    stream.write_all(&buf[..]).unwrap();

    let connack = ConnackPacket::decode(&mut stream).unwrap();
    trace!("CONNACK {:?}", connack);

    if connack.connect_return_code() != ConnectReturnCode::ConnectionAccepted {
        panic!(
            "Failed to connect to server, return code {:?}",
            connack.connect_return_code()
        );
    }

    info!("Applying channel filters {:?} ...", channel_filters);
    let sub = SubscribePacket::new(10, channel_filters);
    let mut buf = Vec::new();
    sub.encode(&mut buf).unwrap();
    stream.write_all(&buf[..]).unwrap();

    // Block until the broker acknowledges the subscription.
    loop {
        let packet = match VariablePacket::decode(&mut stream) {
            Ok(pk) => pk,
            Err(err) => {
                // A failed decode leaves the stream out of sync, and a closed socket fails
                // on every read, so give up instead of retrying forever.
                error!("Error in receiving packet {:?}", err);
                return;
            }
        };
        trace!("PACKET {:?}", packet);

        if let VariablePacket::SubackPacket(ref ack) = packet {
            if ack.packet_identifier() != 10 {
                panic!("SUBACK packet identifier not match");
            }

            info!("Subscribed!");
            break;
        }
    }

    // connection made, start the async work
    stream.set_nonblocking(true).unwrap();
    let mut stream = TcpStream::from_std(stream).unwrap();
    let (mut mqtt_read, mut mqtt_write) = stream.split();

    let ping_sender = async move {
        loop {
            info!("Sending PINGREQ to broker");

            let pingreq_packet = PingreqPacket::new();

            let mut buf = Vec::new();
            pingreq_packet.encode(&mut buf).unwrap();
            if let Err(err) = mqtt_write.write_all(&buf).await {
                error!("Failed to send PINGREQ {:?}", err);
                break;
            }

            tokio::time::sleep(Duration::from_secs(keep_alive as u64 / 2)).await;
        }
    };

    let receiver = async move {
        loop {
            let packet = match VariablePacket::parse(&mut mqtt_read).await {
                Ok(packet) => packet,
                Err(err) => {
                    error!("Error in receiving packet {:?}", err);
                    break;
                }
            };
            trace!("PACKET {:?}", packet);

            match packet {
                VariablePacket::PingrespPacket(..) => {
                    info!("Received PINGRESP from broker ..");
                }
                VariablePacket::PublishPacket(ref publ) => {
                    let msg = match str::from_utf8(publ.payload()) {
                        Ok(msg) => msg,
                        Err(err) => {
                            error!("Failed to decode publish message {:?}", err);
                            continue;
                        }
                    };
                    info!("PUBLISH ({}): {}", publ.topic_name(), msg);
                }
                _ => {}
            }
        }
    };

    // Finish as soon as either task stops. `join!` would keep waiting on the endless ping
    // loop after the connection has already been closed by the broker.
    tokio::select! {
        _ = ping_sender => {}
        _ = receiver => {}
    }
}
--------------------------------------------------------------------------------
/examples/sub-client.rs:
--------------------------------------------------------------------------------
//! Blocking subscribing client example.
//!
//! Connects to an MQTT broker, subscribes to the given filters and logs every PUBLISH
//! message it receives. When keep-alive is enabled, a background thread sends PINGREQs.

use std::env;
use std::io::Write;
use std::net::TcpStream;
use std::str;
use std::thread;
use std::time::Duration;

use clap::{App, Arg};
use log::{error, info, trace};

use uuid::Uuid;

use mqtt::control::variable_header::ConnectReturnCode;
use mqtt::packet::*;
use mqtt::TopicFilter;
use mqtt::{Decodable, Encodable, QualityOfService};

/// Generates a random client identifier of the form `/MQTT/rust/<uuid-v4>`.
fn generate_client_id() -> String {
    format!("/MQTT/rust/{}", Uuid::new_v4())
}

fn main() {
    // configure logging
    env::set_var("RUST_LOG", env::var_os("RUST_LOG").unwrap_or_else(|| "info".into()));
    env_logger::init();

    let matches = App::new("sub-client")
        .author("Y. T. Chung ")
        .arg(
            Arg::with_name("SERVER")
                .short("S")
                .long("server")
                .takes_value(true)
                .required(true)
                .help("MQTT server address (host:port)"),
        )
        .arg(
            Arg::with_name("SUBSCRIBE")
                .short("s")
                .long("subscribe")
                .takes_value(true)
                .multiple(true)
                .required(true)
                .help("Channel filter to subscribe"),
        )
        .arg(
            Arg::with_name("USER_NAME")
                .short("u")
                .long("username")
                .takes_value(true)
                .help("Login user name"),
        )
        .arg(
            Arg::with_name("PASSWORD")
                .short("p")
                .long("password")
                .takes_value(true)
                .help("Password"),
        )
        .arg(
            Arg::with_name("CLIENT_ID")
                .short("i")
                .long("client-identifier")
                .takes_value(true)
                .help("Client identifier"),
        )
        .get_matches();

    let server_addr = matches.value_of("SERVER").unwrap();
    let client_id = matches
        .value_of("CLIENT_ID")
        .map(|x| x.to_owned())
        .unwrap_or_else(generate_client_id);
    let channel_filters: Vec<(TopicFilter, QualityOfService)> = matches
        .values_of("SUBSCRIBE")
        .unwrap()
        .map(|c| (TopicFilter::new(c.to_string()).unwrap(), QualityOfService::Level0))
        .collect();

    let keep_alive = 10;

    info!("Connecting to {:?} ... ", server_addr);
    let mut stream = TcpStream::connect(server_addr).unwrap();
    info!("Connected!");

    info!("Client identifier {:?}", client_id);
    let mut conn = ConnectPacket::new(client_id);
    conn.set_clean_session(true);
    conn.set_keep_alive(keep_alive);
    let mut buf = Vec::new();
    conn.encode(&mut buf).unwrap();
    stream.write_all(&buf[..]).unwrap();

    let connack = ConnackPacket::decode(&mut stream).unwrap();
    trace!("CONNACK {:?}", connack);

    if connack.connect_return_code() != ConnectReturnCode::ConnectionAccepted {
        panic!(
            "Failed to connect to server, return code {:?}",
            connack.connect_return_code()
        );
    }

    info!("Applying channel filters {:?} ...", channel_filters);
    let sub = SubscribePacket::new(10, channel_filters);
    let mut buf = Vec::new();
    sub.encode(&mut buf).unwrap();
    stream.write_all(&buf[..]).unwrap();

    // Block until the broker acknowledges the subscription.
    loop {
        let packet = match VariablePacket::decode(&mut stream) {
            Ok(pk) => pk,
            Err(err) => {
                // A failed decode leaves the stream out of sync, and a closed socket fails
                // on every read, so give up instead of retrying forever.
                error!("Error in receiving packet {:?}", err);
                return;
            }
        };
        trace!("PACKET {:?}", packet);

        if let VariablePacket::SubackPacket(ref ack) = packet {
            if ack.packet_identifier() != 10 {
                panic!("SUBACK packet identifier not match");
            }

            info!("Subscribed!");
            break;
        }
    }

    if keep_alive > 0 {
        // Ping at 90% of the keep-alive period so the broker never considers the session
        // idle. The interval is computed in milliseconds so small keep-alive values do not
        // truncate to zero. Sleeping between pings replaces the original busy-wait loop,
        // which burned a full CPU core.
        let ping_interval = Duration::from_millis(keep_alive as u64 * 900);
        let mut stream_clone = stream.try_clone().unwrap();
        thread::spawn(move || loop {
            thread::sleep(ping_interval);

            info!("Sending PINGREQ to broker");

            let pingreq_packet = PingreqPacket::new();

            let mut buf = Vec::new();
            pingreq_packet.encode(&mut buf).unwrap();
            if let Err(err) = stream_clone.write_all(&buf[..]) {
                error!("Failed to send PINGREQ {:?}", err);
                break;
            }
        });
    }

    loop {
        let packet = match VariablePacket::decode(&mut stream) {
            Ok(pk) => pk,
            Err(err) => {
                // The connection is closed or out of sync; retrying would spin forever.
                error!("Error in receiving packet {}", err);
                break;
            }
        };
        trace!("PACKET {:?}", packet);

        match packet {
            VariablePacket::PingrespPacket(..) => {
                info!("Receiving PINGRESP from broker ..");
            }
            VariablePacket::PublishPacket(ref publ) => {
                let msg = match str::from_utf8(publ.payload()) {
                    Ok(msg) => msg,
                    Err(err) => {
                        error!("Failed to decode publish message {:?}", err);
                        continue;
                    }
                };
                info!("PUBLISH ({}): {}", publ.topic_name(), msg);
            }
            _ => {}
        }
    }
}
-------------------------------------------------------------------------------- /rustfmt.toml: -------------------------------------------------------------------------------- 1 | edition = "2018" 2 | max_width = 120 3 | reorder_imports = true 4 | use_try_shorthand = true 5 | -------------------------------------------------------------------------------- /src/control/fixed_header.rs: -------------------------------------------------------------------------------- 1 | //!
Fixed header in MQTT 2 | 3 | use std::io::{self, Read, Write}; 4 | 5 | use byteorder::{ReadBytesExt, WriteBytesExt}; 6 | 7 | #[cfg(feature = "tokio")] 8 | use tokio::io::{AsyncRead, AsyncReadExt}; 9 | 10 | use crate::control::packet_type::{PacketType, PacketTypeError}; 11 | use crate::{Decodable, Encodable}; 12 | 13 | /// Fixed header for each MQTT control packet 14 | /// 15 | /// Format: 16 | /// 17 | /// ```plain 18 | /// 7 3 0 19 | /// +--------------------------+--------------------------+ 20 | /// | MQTT Control Packet Type | Flags for each type | 21 | /// +--------------------------+--------------------------+ 22 | /// | Remaining Length ... | 23 | /// +-----------------------------------------------------+ 24 | /// ``` 25 | #[derive(Debug, Clone, Copy, Eq, PartialEq)] 26 | pub struct FixedHeader { 27 | /// Packet Type 28 | pub packet_type: PacketType, 29 | 30 | /// The Remaining Length is the number of bytes remaining within the current packet, 31 | /// including data in the variable header and the payload. The Remaining Length does 32 | /// not include the bytes used to encode the Remaining Length. 33 | pub remaining_length: u32, 34 | } 35 | 36 | impl FixedHeader { 37 | pub fn new(packet_type: PacketType, remaining_length: u32) -> FixedHeader { 38 | debug_assert!(remaining_length <= 0x0FFF_FFFF); 39 | FixedHeader { 40 | packet_type, 41 | remaining_length, 42 | } 43 | } 44 | 45 | #[cfg(feature = "tokio")] 46 | /// Asynchronously parse a single fixed header from an AsyncRead type, such as a network 47 | /// socket. 
48 | /// 49 | /// This requires mqtt-rs to be built with `feature = "tokio"` 50 | pub async fn parse(rdr: &mut A) -> Result { 51 | let type_val = rdr.read_u8().await?; 52 | 53 | let mut remaining_len = 0; 54 | let mut i = 0; 55 | 56 | loop { 57 | let byte = rdr.read_u8().await?; 58 | 59 | remaining_len |= (u32::from(byte) & 0x7F) << (7 * i); 60 | 61 | if i >= 4 { 62 | return Err(FixedHeaderError::MalformedRemainingLength); 63 | } 64 | 65 | if byte & 0x80 == 0 { 66 | break; 67 | } else { 68 | i += 1; 69 | } 70 | } 71 | 72 | match PacketType::from_u8(type_val) { 73 | Ok(packet_type) => Ok(FixedHeader::new(packet_type, remaining_len)), 74 | Err(PacketTypeError::ReservedType(ty, _)) => Err(FixedHeaderError::ReservedType(ty, remaining_len)), 75 | Err(err) => Err(From::from(err)), 76 | } 77 | } 78 | } 79 | 80 | impl Encodable for FixedHeader { 81 | fn encode(&self, wr: &mut W) -> Result<(), io::Error> { 82 | wr.write_u8(self.packet_type.to_u8())?; 83 | 84 | let mut cur_len = self.remaining_length; 85 | loop { 86 | let mut byte = (cur_len & 0x7F) as u8; 87 | cur_len >>= 7; 88 | 89 | if cur_len > 0 { 90 | byte |= 0x80; 91 | } 92 | 93 | wr.write_u8(byte)?; 94 | 95 | if cur_len == 0 { 96 | break; 97 | } 98 | } 99 | 100 | Ok(()) 101 | } 102 | 103 | fn encoded_length(&self) -> u32 { 104 | let rem_size = if self.remaining_length >= 2_097_152 { 105 | 4 106 | } else if self.remaining_length >= 16_384 { 107 | 3 108 | } else if self.remaining_length >= 128 { 109 | 2 110 | } else { 111 | 1 112 | }; 113 | 1 + rem_size 114 | } 115 | } 116 | 117 | impl Decodable for FixedHeader { 118 | type Error = FixedHeaderError; 119 | type Cond = (); 120 | 121 | fn decode_with(rdr: &mut R, _rest: ()) -> Result { 122 | let type_val = rdr.read_u8()?; 123 | let remaining_len = { 124 | let mut cur = 0u32; 125 | for i in 0.. 
{ 126 | let byte = rdr.read_u8()?; 127 | cur |= ((byte as u32) & 0x7F) << (7 * i); 128 | 129 | if i >= 4 { 130 | return Err(FixedHeaderError::MalformedRemainingLength); 131 | } 132 | 133 | if byte & 0x80 == 0 { 134 | break; 135 | } 136 | } 137 | 138 | cur 139 | }; 140 | 141 | match PacketType::from_u8(type_val) { 142 | Ok(packet_type) => Ok(FixedHeader::new(packet_type, remaining_len)), 143 | Err(PacketTypeError::ReservedType(ty, _)) => Err(FixedHeaderError::ReservedType(ty, remaining_len)), 144 | Err(err) => Err(From::from(err)), 145 | } 146 | } 147 | } 148 | 149 | #[derive(Debug, thiserror::Error)] 150 | pub enum FixedHeaderError { 151 | #[error("malformed remaining length")] 152 | MalformedRemainingLength, 153 | #[error("reserved header ({0}, {1})")] 154 | ReservedType(u8, u32), 155 | #[error(transparent)] 156 | PacketTypeError(#[from] PacketTypeError), 157 | #[error(transparent)] 158 | IoError(#[from] io::Error), 159 | } 160 | 161 | #[cfg(test)] 162 | mod test { 163 | use super::*; 164 | 165 | use crate::control::packet_type::{ControlType, PacketType}; 166 | use crate::{Decodable, Encodable}; 167 | use std::io::Cursor; 168 | 169 | #[test] 170 | fn test_encode_fixed_header() { 171 | let header = FixedHeader::new(PacketType::with_default(ControlType::Connect), 321); 172 | let mut buf = Vec::new(); 173 | header.encode(&mut buf).unwrap(); 174 | 175 | let expected = b"\x10\xc1\x02"; 176 | assert_eq!(&expected[..], &buf[..]); 177 | } 178 | 179 | #[test] 180 | fn test_decode_fixed_header() { 181 | let stream = b"\x10\xc1\x02"; 182 | let mut cursor = Cursor::new(&stream[..]); 183 | let header = FixedHeader::decode(&mut cursor).unwrap(); 184 | assert_eq!(header.packet_type, PacketType::with_default(ControlType::Connect)); 185 | assert_eq!(header.remaining_length, 321); 186 | } 187 | 188 | #[test] 189 | #[should_panic] 190 | fn test_decode_too_long_fixed_header() { 191 | let stream = b"\x10\x80\x80\x80\x80\x02"; 192 | let mut cursor = Cursor::new(&stream[..]); 193 | 
FixedHeader::decode(&mut cursor).unwrap(); 194 | } 195 | } 196 | -------------------------------------------------------------------------------- /src/control/mod.rs: -------------------------------------------------------------------------------- 1 | //! Control packets 2 | 3 | pub use self::fixed_header::FixedHeader; 4 | pub use self::packet_type::{ControlType, PacketType}; 5 | pub use self::variable_header::*; 6 | 7 | pub mod fixed_header; 8 | pub mod packet_type; 9 | pub mod variable_header; 10 | -------------------------------------------------------------------------------- /src/control/packet_type.rs: -------------------------------------------------------------------------------- 1 | //! Packet types 2 | 3 | use crate::qos::QualityOfService; 4 | 5 | /// Packet type 6 | // INVARIANT: the high 4 bits of the byte must be a valid control type 7 | #[derive(Debug, Eq, PartialEq, Copy, Clone)] 8 | pub struct PacketType(u8); 9 | 10 | /// Defined control types 11 | #[rustfmt::skip] 12 | #[repr(u8)] 13 | #[derive(Debug, Eq, PartialEq, Copy, Clone)] 14 | pub enum ControlType { 15 | /// Client request to connect to Server 16 | Connect = value::CONNECT, 17 | 18 | /// Connect acknowledgment 19 | ConnectAcknowledgement = value::CONNACK, 20 | 21 | /// Publish message 22 | Publish = value::PUBLISH, 23 | 24 | /// Publish acknowledgment 25 | PublishAcknowledgement = value::PUBACK, 26 | 27 | /// Publish received (assured delivery part 1) 28 | PublishReceived = value::PUBREC, 29 | 30 | /// Publish release (assured delivery part 2) 31 | PublishRelease = value::PUBREL, 32 | 33 | /// Publish complete (assured delivery part 3) 34 | PublishComplete = value::PUBCOMP, 35 | 36 | /// Client subscribe request 37 | Subscribe = value::SUBSCRIBE, 38 | 39 | /// Subscribe acknowledgment 40 | SubscribeAcknowledgement = value::SUBACK, 41 | 42 | /// Unsubscribe request 43 | Unsubscribe = value::UNSUBSCRIBE, 44 | 45 | /// Unsubscribe acknowledgment 46 | UnsubscribeAcknowledgement = 
value::UNSUBACK, 47 | 48 | /// PING request 49 | PingRequest = value::PINGREQ, 50 | 51 | /// PING response 52 | PingResponse = value::PINGRESP, 53 | 54 | /// Client is disconnecting 55 | Disconnect = value::DISCONNECT, 56 | } 57 | 58 | impl ControlType { 59 | #[inline] 60 | fn default_flags(self) -> u8 { 61 | match self { 62 | ControlType::Connect => 0, 63 | ControlType::ConnectAcknowledgement => 0, 64 | 65 | ControlType::Publish => 0, 66 | ControlType::PublishAcknowledgement => 0, 67 | ControlType::PublishReceived => 0, 68 | ControlType::PublishRelease => 0b0010, 69 | ControlType::PublishComplete => 0, 70 | 71 | ControlType::Subscribe => 0b0010, 72 | ControlType::SubscribeAcknowledgement => 0, 73 | 74 | ControlType::Unsubscribe => 0b0010, 75 | ControlType::UnsubscribeAcknowledgement => 0, 76 | 77 | ControlType::PingRequest => 0, 78 | ControlType::PingResponse => 0, 79 | 80 | ControlType::Disconnect => 0, 81 | } 82 | } 83 | } 84 | 85 | impl PacketType { 86 | /// Creates a packet type. Returns None if `flags` is an invalid value for the given 87 | /// ControlType as defined by the [MQTT spec]. 
88 | /// 89 | /// [MQTT spec]: http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Table_2.2_- 90 | pub fn new(t: ControlType, flags: u8) -> Result { 91 | let flags_ok = match t { 92 | ControlType::Publish => { 93 | let qos = (flags & 0b0110) >> 1; 94 | matches!(qos, 0 | 1 | 2) 95 | } 96 | _ => t.default_flags() == flags, 97 | }; 98 | if flags_ok { 99 | Ok(PacketType::new_unchecked(t, flags)) 100 | } else { 101 | Err(InvalidFlag(t, flags)) 102 | } 103 | } 104 | 105 | #[inline] 106 | fn new_unchecked(t: ControlType, flags: u8) -> PacketType { 107 | let byte = (t as u8) << 4 | (flags & 0x0F); 108 | #[allow(unused_unsafe)] 109 | unsafe { 110 | // SAFETY: just constructed from a valid ControlType 111 | PacketType(byte) 112 | } 113 | } 114 | 115 | /// Creates a packet type with default flags 116 | /// 117 | /// 118 | #[inline] 119 | pub fn with_default(t: ControlType) -> PacketType { 120 | let flags = t.default_flags(); 121 | PacketType::new_unchecked(t, flags) 122 | } 123 | 124 | pub(crate) fn publish(qos: QualityOfService) -> PacketType { 125 | PacketType::new_unchecked(ControlType::Publish, (qos as u8) << 1) 126 | } 127 | 128 | #[inline] 129 | pub(crate) fn update_flags(&mut self, upd: impl FnOnce(u8) -> u8) { 130 | let flags = upd(self.flags()); 131 | self.0 = (self.0 & !0x0F) | (flags & 0x0F) 132 | } 133 | 134 | /// To code 135 | #[inline] 136 | pub fn to_u8(self) -> u8 { 137 | self.0 138 | } 139 | 140 | /// From code 141 | pub fn from_u8(val: u8) -> Result { 142 | let type_val = val >> 4; 143 | let flags = val & 0x0F; 144 | 145 | let control_type = get_control_type(type_val).ok_or(PacketTypeError::ReservedType(type_val, flags))?; 146 | Ok(PacketType::new(control_type, flags)?) 
147 | } 148 | 149 | #[inline] 150 | pub fn control_type(self) -> ControlType { 151 | get_control_type(self.0 >> 4).unwrap_or_else(|| { 152 | // SAFETY: this is maintained by the invariant for PacketType 153 | unsafe { std::hint::unreachable_unchecked() } 154 | }) 155 | } 156 | 157 | #[inline] 158 | pub fn flags(self) -> u8 { 159 | self.0 & 0x0F 160 | } 161 | } 162 | 163 | #[inline] 164 | fn get_control_type(val: u8) -> Option { 165 | let typ = match val { 166 | value::CONNECT => ControlType::Connect, 167 | value::CONNACK => ControlType::ConnectAcknowledgement, 168 | 169 | value::PUBLISH => ControlType::Publish, 170 | value::PUBACK => ControlType::PublishAcknowledgement, 171 | value::PUBREC => ControlType::PublishReceived, 172 | value::PUBREL => ControlType::PublishRelease, 173 | value::PUBCOMP => ControlType::PublishComplete, 174 | 175 | value::SUBSCRIBE => ControlType::Subscribe, 176 | value::SUBACK => ControlType::SubscribeAcknowledgement, 177 | 178 | value::UNSUBSCRIBE => ControlType::Unsubscribe, 179 | value::UNSUBACK => ControlType::UnsubscribeAcknowledgement, 180 | 181 | value::PINGREQ => ControlType::PingRequest, 182 | value::PINGRESP => ControlType::PingResponse, 183 | 184 | value::DISCONNECT => ControlType::Disconnect, 185 | 186 | _ => return None, 187 | }; 188 | Some(typ) 189 | } 190 | 191 | /// Parsing packet type errors 192 | #[derive(Debug, thiserror::Error)] 193 | pub enum PacketTypeError { 194 | #[error("reserved type {0:?} (flags {1:#X})")] 195 | ReservedType(u8, u8), 196 | #[error(transparent)] 197 | InvalidFlag(#[from] InvalidFlag), 198 | } 199 | 200 | #[derive(Debug, thiserror::Error)] 201 | #[error("invalid flag for {0:?} ({1:#X})")] 202 | pub struct InvalidFlag(pub ControlType, pub u8); 203 | 204 | #[rustfmt::skip] 205 | mod value { 206 | pub const CONNECT: u8 = 1; 207 | pub const CONNACK: u8 = 2; 208 | pub const PUBLISH: u8 = 3; 209 | pub const PUBACK: u8 = 4; 210 | pub const PUBREC: u8 = 5; 211 | pub const PUBREL: u8 = 6; 212 | pub const 
PUBCOMP: u8 = 7; 213 | pub const SUBSCRIBE: u8 = 8; 214 | pub const SUBACK: u8 = 9; 215 | pub const UNSUBSCRIBE: u8 = 10; 216 | pub const UNSUBACK: u8 = 11; 217 | pub const PINGREQ: u8 = 12; 218 | pub const PINGRESP: u8 = 13; 219 | pub const DISCONNECT: u8 = 14; 220 | } 221 | -------------------------------------------------------------------------------- /src/control/variable_header/connect_ack_flags.rs: -------------------------------------------------------------------------------- 1 | use std::io::{self, Read, Write}; 2 | 3 | use byteorder::{ReadBytesExt, WriteBytesExt}; 4 | 5 | use crate::control::variable_header::VariableHeaderError; 6 | use crate::{Decodable, Encodable}; 7 | 8 | /// Flags in `CONNACK` packet 9 | #[derive(Debug, Eq, PartialEq, Copy, Clone)] 10 | pub struct ConnackFlags { 11 | pub session_present: bool, 12 | } 13 | 14 | impl ConnackFlags { 15 | pub fn empty() -> ConnackFlags { 16 | ConnackFlags { session_present: false } 17 | } 18 | } 19 | 20 | impl Encodable for ConnackFlags { 21 | fn encode(&self, writer: &mut W) -> Result<(), io::Error> { 22 | let code = self.session_present as u8; 23 | writer.write_u8(code) 24 | } 25 | 26 | fn encoded_length(&self) -> u32 { 27 | 1 28 | } 29 | } 30 | 31 | impl Decodable for ConnackFlags { 32 | type Error = VariableHeaderError; 33 | type Cond = (); 34 | 35 | fn decode_with(reader: &mut R, _rest: ()) -> Result { 36 | let code = reader.read_u8()?; 37 | if code & !1 != 0 { 38 | return Err(VariableHeaderError::InvalidReservedFlag); 39 | } 40 | 41 | Ok(ConnackFlags { 42 | session_present: code == 1, 43 | }) 44 | } 45 | } 46 | -------------------------------------------------------------------------------- /src/control/variable_header/connect_flags.rs: -------------------------------------------------------------------------------- 1 | use std::io::{self, Read, Write}; 2 | 3 | use byteorder::{ReadBytesExt, WriteBytesExt}; 4 | 5 | use crate::control::variable_header::VariableHeaderError; 6 | use crate::{Decodable, 
Encodable}; 7 | 8 | /// Flags for `CONNECT` packet 9 | #[derive(Debug, Eq, PartialEq, Copy, Clone)] 10 | pub struct ConnectFlags { 11 | pub user_name: bool, 12 | pub password: bool, 13 | pub will_retain: bool, 14 | pub will_qos: u8, 15 | pub will_flag: bool, 16 | pub clean_session: bool, 17 | // We never use this, but must decode because brokers must verify it's zero per [MQTT-3.1.2-3] 18 | pub reserved: bool, 19 | } 20 | 21 | impl ConnectFlags { 22 | pub fn empty() -> ConnectFlags { 23 | ConnectFlags { 24 | user_name: false, 25 | password: false, 26 | will_retain: false, 27 | will_qos: 0, 28 | will_flag: false, 29 | clean_session: false, 30 | reserved: false, 31 | } 32 | } 33 | } 34 | 35 | impl Encodable for ConnectFlags { 36 | #[rustfmt::skip] 37 | fn encode(&self, writer: &mut W) -> Result<(), io::Error> { 38 | let code = ((self.user_name as u8) << 7) 39 | | ((self.password as u8) << 6) 40 | | ((self.will_retain as u8) << 5) 41 | | ((self.will_qos) << 3) 42 | | ((self.will_flag as u8) << 2) 43 | | ((self.clean_session as u8) << 1); 44 | 45 | writer.write_u8(code) 46 | } 47 | 48 | fn encoded_length(&self) -> u32 { 49 | 1 50 | } 51 | } 52 | 53 | impl Decodable for ConnectFlags { 54 | type Error = VariableHeaderError; 55 | type Cond = (); 56 | 57 | fn decode_with(reader: &mut R, _rest: ()) -> Result { 58 | let code = reader.read_u8()?; 59 | if code & 1 != 0 { 60 | return Err(VariableHeaderError::InvalidReservedFlag); 61 | } 62 | 63 | Ok(ConnectFlags { 64 | user_name: (code & 0b1000_0000) != 0, 65 | password: (code & 0b0100_0000) != 0, 66 | will_retain: (code & 0b0010_0000) != 0, 67 | will_qos: (code & 0b0001_1000) >> 3, 68 | will_flag: (code & 0b0000_0100) != 0, 69 | clean_session: (code & 0b0000_0010) != 0, 70 | reserved: (code & 0b0000_0001) != 0, 71 | }) 72 | } 73 | } 74 | -------------------------------------------------------------------------------- /src/control/variable_header/connect_ret_code.rs: 
-------------------------------------------------------------------------------- 1 | use std::io::{self, Read, Write}; 2 | 3 | use byteorder::{ReadBytesExt, WriteBytesExt}; 4 | 5 | use crate::control::variable_header::VariableHeaderError; 6 | use crate::{Decodable, Encodable}; 7 | 8 | pub const CONNECTION_ACCEPTED: u8 = 0x00; 9 | pub const UNACCEPTABLE_PROTOCOL_VERSION: u8 = 0x01; 10 | pub const IDENTIFIER_REJECTED: u8 = 0x02; 11 | pub const SERVICE_UNAVAILABLE: u8 = 0x03; 12 | pub const BAD_USER_NAME_OR_PASSWORD: u8 = 0x04; 13 | pub const NOT_AUTHORIZED: u8 = 0x05; 14 | 15 | /// Return code for `CONNACK` packet 16 | #[derive(Debug, Eq, PartialEq, Copy, Clone)] 17 | pub enum ConnectReturnCode { 18 | ConnectionAccepted, 19 | UnacceptableProtocolVersion, 20 | IdentifierRejected, 21 | ServiceUnavailable, 22 | BadUserNameOrPassword, 23 | NotAuthorized, 24 | Reserved(u8), 25 | } 26 | 27 | impl ConnectReturnCode { 28 | /// Get the code 29 | pub fn to_u8(self) -> u8 { 30 | match self { 31 | ConnectReturnCode::ConnectionAccepted => CONNECTION_ACCEPTED, 32 | ConnectReturnCode::UnacceptableProtocolVersion => UNACCEPTABLE_PROTOCOL_VERSION, 33 | ConnectReturnCode::IdentifierRejected => IDENTIFIER_REJECTED, 34 | ConnectReturnCode::ServiceUnavailable => SERVICE_UNAVAILABLE, 35 | ConnectReturnCode::BadUserNameOrPassword => BAD_USER_NAME_OR_PASSWORD, 36 | ConnectReturnCode::NotAuthorized => NOT_AUTHORIZED, 37 | ConnectReturnCode::Reserved(r) => r, 38 | } 39 | } 40 | 41 | /// Create `ConnectReturnCode` from code 42 | pub fn from_u8(code: u8) -> ConnectReturnCode { 43 | match code { 44 | CONNECTION_ACCEPTED => ConnectReturnCode::ConnectionAccepted, 45 | UNACCEPTABLE_PROTOCOL_VERSION => ConnectReturnCode::UnacceptableProtocolVersion, 46 | IDENTIFIER_REJECTED => ConnectReturnCode::IdentifierRejected, 47 | SERVICE_UNAVAILABLE => ConnectReturnCode::ServiceUnavailable, 48 | BAD_USER_NAME_OR_PASSWORD => ConnectReturnCode::BadUserNameOrPassword, 49 | NOT_AUTHORIZED => 
ConnectReturnCode::NotAuthorized, 50 | _ => ConnectReturnCode::Reserved(code), 51 | } 52 | } 53 | } 54 | 55 | impl Encodable for ConnectReturnCode { 56 | fn encode(&self, writer: &mut W) -> Result<(), io::Error> { 57 | writer.write_u8(self.to_u8()) 58 | } 59 | 60 | fn encoded_length(&self) -> u32 { 61 | 1 62 | } 63 | } 64 | 65 | impl Decodable for ConnectReturnCode { 66 | type Error = VariableHeaderError; 67 | type Cond = (); 68 | 69 | fn decode_with(reader: &mut R, _rest: ()) -> Result { 70 | reader.read_u8().map(ConnectReturnCode::from_u8).map_err(From::from) 71 | } 72 | } 73 | -------------------------------------------------------------------------------- /src/control/variable_header/keep_alive.rs: -------------------------------------------------------------------------------- 1 | use std::io::{self, Read, Write}; 2 | 3 | use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt}; 4 | 5 | use crate::control::variable_header::VariableHeaderError; 6 | use crate::{Decodable, Encodable}; 7 | 8 | /// Keep alive time interval 9 | #[derive(Debug, Eq, PartialEq, Copy, Clone)] 10 | pub struct KeepAlive(pub u16); 11 | 12 | impl Encodable for KeepAlive { 13 | fn encode(&self, writer: &mut W) -> Result<(), io::Error> { 14 | writer.write_u16::(self.0) 15 | } 16 | 17 | fn encoded_length(&self) -> u32 { 18 | 2 19 | } 20 | } 21 | 22 | impl Decodable for KeepAlive { 23 | type Error = VariableHeaderError; 24 | type Cond = (); 25 | 26 | fn decode_with(reader: &mut R, _rest: ()) -> Result { 27 | reader.read_u16::().map(KeepAlive).map_err(From::from) 28 | } 29 | } 30 | -------------------------------------------------------------------------------- /src/control/variable_header/mod.rs: -------------------------------------------------------------------------------- 1 | //! 
Variable header in MQTT 2 | 3 | use std::io; 4 | use std::string::FromUtf8Error; 5 | 6 | use crate::topic_name::{TopicNameDecodeError, TopicNameError}; 7 | 8 | pub use self::connect_ack_flags::ConnackFlags; 9 | pub use self::connect_flags::ConnectFlags; 10 | pub use self::connect_ret_code::ConnectReturnCode; 11 | pub use self::keep_alive::KeepAlive; 12 | pub use self::packet_identifier::PacketIdentifier; 13 | pub use self::protocol_level::ProtocolLevel; 14 | pub use self::protocol_name::ProtocolName; 15 | pub use self::topic_name::TopicNameHeader; 16 | 17 | mod connect_ack_flags; 18 | mod connect_flags; 19 | mod connect_ret_code; 20 | mod keep_alive; 21 | mod packet_identifier; 22 | pub mod protocol_level; 23 | mod protocol_name; 24 | mod topic_name; 25 | 26 | /// Errors while decoding variable header 27 | #[derive(Debug, thiserror::Error)] 28 | pub enum VariableHeaderError { 29 | #[error(transparent)] 30 | IoError(#[from] io::Error), 31 | #[error("invalid reserved flags")] 32 | InvalidReservedFlag, 33 | #[error(transparent)] 34 | FromUtf8Error(#[from] FromUtf8Error), 35 | #[error(transparent)] 36 | TopicNameError(#[from] TopicNameError), 37 | #[error("invalid protocol version")] 38 | InvalidProtocolVersion, 39 | } 40 | 41 | impl From for VariableHeaderError { 42 | fn from(err: TopicNameDecodeError) -> VariableHeaderError { 43 | match err { 44 | TopicNameDecodeError::IoError(e) => Self::IoError(e), 45 | TopicNameDecodeError::InvalidTopicName(e) => Self::TopicNameError(e), 46 | } 47 | } 48 | } 49 | -------------------------------------------------------------------------------- /src/control/variable_header/packet_identifier.rs: -------------------------------------------------------------------------------- 1 | use std::io::{self, Read, Write}; 2 | 3 | use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt}; 4 | 5 | use crate::control::variable_header::VariableHeaderError; 6 | use crate::{Decodable, Encodable}; 7 | 8 | /// Packet identifier 9 | #[derive(Debug, Eq, 
PartialEq, Copy, Clone)] 10 | pub struct PacketIdentifier(pub u16); 11 | 12 | impl Encodable for PacketIdentifier { 13 | fn encode(&self, writer: &mut W) -> Result<(), io::Error> { 14 | writer.write_u16::(self.0) 15 | } 16 | 17 | fn encoded_length(&self) -> u32 { 18 | 2 19 | } 20 | } 21 | 22 | impl Decodable for PacketIdentifier { 23 | type Error = VariableHeaderError; 24 | type Cond = (); 25 | 26 | fn decode_with(reader: &mut R, _rest: ()) -> Result { 27 | reader.read_u16::().map(PacketIdentifier).map_err(From::from) 28 | } 29 | } 30 | -------------------------------------------------------------------------------- /src/control/variable_header/protocol_level.rs: -------------------------------------------------------------------------------- 1 | //! Protocol level header 2 | 3 | use std::io::{self, Read, Write}; 4 | 5 | use byteorder::{ReadBytesExt, WriteBytesExt}; 6 | 7 | use crate::control::variable_header::VariableHeaderError; 8 | use crate::{Decodable, Encodable}; 9 | 10 | pub const SPEC_3_1_0: u8 = 0x03; 11 | pub const SPEC_3_1_1: u8 = 0x04; 12 | pub const SPEC_5_0: u8 = 0x05; 13 | 14 | /// Protocol level in MQTT (`0x04` in v3.1.1) 15 | #[derive(Debug, Eq, PartialEq, Copy, Clone)] 16 | #[repr(u8)] 17 | pub enum ProtocolLevel { 18 | Version310 = SPEC_3_1_0, 19 | Version311 = SPEC_3_1_1, 20 | Version50 = SPEC_5_0, 21 | } 22 | 23 | impl Encodable for ProtocolLevel { 24 | fn encode(&self, writer: &mut W) -> Result<(), io::Error> { 25 | writer.write_u8(*self as u8) 26 | } 27 | 28 | fn encoded_length(&self) -> u32 { 29 | 1 30 | } 31 | } 32 | 33 | impl Decodable for ProtocolLevel { 34 | type Error = VariableHeaderError; 35 | type Cond = (); 36 | 37 | fn decode_with(reader: &mut R, _rest: ()) -> Result { 38 | reader 39 | .read_u8() 40 | .map_err(From::from) 41 | .map(ProtocolLevel::from_u8) 42 | .and_then(|x| x.ok_or(VariableHeaderError::InvalidProtocolVersion)) 43 | } 44 | } 45 | 46 | impl ProtocolLevel { 47 | pub fn from_u8(n: u8) -> Option { 48 | match n { 49 | 
SPEC_3_1_0 => Some(ProtocolLevel::Version310), 50 | SPEC_3_1_1 => Some(ProtocolLevel::Version311), 51 | SPEC_5_0 => Some(ProtocolLevel::Version50), 52 | _ => None, 53 | } 54 | } 55 | } 56 | -------------------------------------------------------------------------------- /src/control/variable_header/protocol_name.rs: -------------------------------------------------------------------------------- 1 | use std::io::{self, Read, Write}; 2 | 3 | use crate::control::variable_header::VariableHeaderError; 4 | use crate::{Decodable, Encodable}; 5 | 6 | /// Protocol name in variable header 7 | /// 8 | /// # Example 9 | /// 10 | /// ```plain 11 | /// 7 3 0 12 | /// +--------------------------+--------------------------+ 13 | /// | Length MSB (0) | 14 | /// | Length LSB (4) | 15 | /// | 0100 | 1101 | 'M' 16 | /// | 0101 | 0001 | 'Q' 17 | /// | 0101 | 0100 | 'T' 18 | /// | 0101 | 0100 | 'T' 19 | /// +--------------------------+--------------------------+ 20 | /// ``` 21 | #[derive(Debug, Eq, PartialEq, Clone)] 22 | pub struct ProtocolName(pub String); 23 | 24 | impl Encodable for ProtocolName { 25 | fn encode(&self, writer: &mut W) -> Result<(), io::Error> { 26 | (&self.0[..]).encode(writer) 27 | } 28 | 29 | fn encoded_length(&self) -> u32 { 30 | (&self.0[..]).encoded_length() 31 | } 32 | } 33 | 34 | impl Decodable for ProtocolName { 35 | type Error = VariableHeaderError; 36 | type Cond = (); 37 | 38 | fn decode_with(reader: &mut R, _rest: ()) -> Result { 39 | Ok(ProtocolName(Decodable::decode(reader)?)) 40 | } 41 | } 42 | -------------------------------------------------------------------------------- /src/control/variable_header/topic_name.rs: -------------------------------------------------------------------------------- 1 | use std::io::{self, Read, Write}; 2 | 3 | use crate::control::variable_header::VariableHeaderError; 4 | use crate::topic_name::TopicName; 5 | use crate::{Decodable, Encodable}; 6 | 7 | /// Topic name wrapper 8 | #[derive(Debug, Eq, PartialEq, Clone)] 9 
| pub struct TopicNameHeader(TopicName); 10 | 11 | impl TopicNameHeader { 12 | pub fn new(topic_name: String) -> Result { 13 | match TopicName::new(topic_name) { 14 | Ok(h) => Ok(TopicNameHeader(h)), 15 | Err(err) => Err(VariableHeaderError::TopicNameError(err)), 16 | } 17 | } 18 | } 19 | 20 | impl From for TopicName { 21 | fn from(hdr: TopicNameHeader) -> Self { 22 | hdr.0 23 | } 24 | } 25 | 26 | impl Encodable for TopicNameHeader { 27 | fn encode(&self, writer: &mut W) -> Result<(), io::Error> { 28 | (&self.0[..]).encode(writer) 29 | } 30 | 31 | fn encoded_length(&self) -> u32 { 32 | (&self.0[..]).encoded_length() 33 | } 34 | } 35 | 36 | impl Decodable for TopicNameHeader { 37 | type Error = VariableHeaderError; 38 | type Cond = (); 39 | 40 | fn decode_with(reader: &mut R, _rest: ()) -> Result { 41 | TopicNameHeader::new(Decodable::decode(reader)?) 42 | } 43 | } 44 | -------------------------------------------------------------------------------- /src/encodable.rs: -------------------------------------------------------------------------------- 1 | //! Encodable traits 2 | 3 | use std::convert::Infallible; 4 | use std::error::Error; 5 | 6 | use std::io::{self, Read, Write}; 7 | use std::marker::Sized; 8 | 9 | use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt}; 10 | 11 | /// Methods for encoding an Object to bytes according to MQTT specification 12 | pub trait Encodable { 13 | /// Encodes to writer 14 | fn encode(&self, writer: &mut W) -> io::Result<()>; 15 | /// Length of bytes after encoded 16 | fn encoded_length(&self) -> u32; 17 | } 18 | 19 | // impl Encodable for &T { 20 | // fn encode(&self, writer: &mut W) -> io::Result<()> { 21 | // (**self).encode(writer) 22 | // } 23 | // fn encoded_length(&self) -> u32 { 24 | // (**self).encoded_length() 25 | // } 26 | // } 27 | 28 | impl Encodable for Option { 29 | fn encode(&self, writer: &mut W) -> io::Result<()> { 30 | if let Some(this) = self { 31 | this.encode(writer)? 
32 | } 33 | Ok(()) 34 | } 35 | 36 | fn encoded_length(&self) -> u32 { 37 | self.as_ref().map_or(0, |x| x.encoded_length()) 38 | } 39 | } 40 | 41 | /// Methods for decoding bytes to an Object according to MQTT specification 42 | pub trait Decodable: Sized { 43 | type Error: Error; 44 | type Cond; 45 | 46 | /// Decodes object from reader 47 | fn decode(reader: &mut R) -> Result 48 | where 49 | Self::Cond: Default, 50 | { 51 | Self::decode_with(reader, Default::default()) 52 | } 53 | 54 | /// Decodes object with additional data (or hints) 55 | fn decode_with(reader: &mut R, cond: Self::Cond) -> Result; 56 | } 57 | 58 | impl<'a> Encodable for &'a str { 59 | fn encode(&self, writer: &mut W) -> Result<(), io::Error> { 60 | assert!(self.as_bytes().len() <= u16::max_value() as usize); 61 | 62 | writer 63 | .write_u16::(self.as_bytes().len() as u16) 64 | .and_then(|_| writer.write_all(self.as_bytes())) 65 | } 66 | 67 | fn encoded_length(&self) -> u32 { 68 | 2 + self.as_bytes().len() as u32 69 | } 70 | } 71 | 72 | impl<'a> Encodable for &'a [u8] { 73 | fn encode(&self, writer: &mut W) -> Result<(), io::Error> { 74 | writer.write_all(self) 75 | } 76 | 77 | fn encoded_length(&self) -> u32 { 78 | self.len() as u32 79 | } 80 | } 81 | 82 | impl Encodable for String { 83 | fn encode(&self, writer: &mut W) -> Result<(), io::Error> { 84 | (&self[..]).encode(writer) 85 | } 86 | 87 | fn encoded_length(&self) -> u32 { 88 | (&self[..]).encoded_length() 89 | } 90 | } 91 | 92 | impl Decodable for String { 93 | type Error = io::Error; 94 | type Cond = (); 95 | 96 | fn decode_with(reader: &mut R, _rest: ()) -> Result { 97 | let VarBytes(buf) = VarBytes::decode(reader)?; 98 | 99 | String::from_utf8(buf).map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e)) 100 | } 101 | } 102 | 103 | impl Encodable for Vec { 104 | fn encode(&self, writer: &mut W) -> Result<(), io::Error> { 105 | (&self[..]).encode(writer) 106 | } 107 | 108 | fn encoded_length(&self) -> u32 { 109 | 
(&self[..]).encoded_length() 110 | } 111 | } 112 | 113 | impl Decodable for Vec { 114 | type Error = io::Error; 115 | type Cond = Option; 116 | 117 | fn decode_with(reader: &mut R, length: Option) -> Result, io::Error> { 118 | match length { 119 | Some(length) => { 120 | let mut buf = Vec::with_capacity(length as usize); 121 | reader.take(length.into()).read_to_end(&mut buf)?; 122 | Ok(buf) 123 | } 124 | None => { 125 | let mut buf = Vec::new(); 126 | reader.read_to_end(&mut buf)?; 127 | Ok(buf) 128 | } 129 | } 130 | } 131 | } 132 | 133 | impl Encodable for () { 134 | fn encode(&self, _: &mut W) -> Result<(), io::Error> { 135 | Ok(()) 136 | } 137 | 138 | fn encoded_length(&self) -> u32 { 139 | 0 140 | } 141 | } 142 | 143 | impl Decodable for () { 144 | type Error = Infallible; 145 | type Cond = (); 146 | 147 | fn decode_with(_: &mut R, _: ()) -> Result<(), Self::Error> { 148 | Ok(()) 149 | } 150 | } 151 | 152 | /// Bytes that encoded with length 153 | #[derive(Debug, Eq, PartialEq, Clone)] 154 | pub struct VarBytes(pub Vec); 155 | 156 | impl Encodable for VarBytes { 157 | fn encode(&self, writer: &mut W) -> Result<(), io::Error> { 158 | assert!(self.0.len() <= u16::max_value() as usize); 159 | let len = self.0.len() as u16; 160 | writer.write_u16::(len)?; 161 | writer.write_all(&self.0)?; 162 | Ok(()) 163 | } 164 | 165 | fn encoded_length(&self) -> u32 { 166 | 2 + self.0.len() as u32 167 | } 168 | } 169 | 170 | impl Decodable for VarBytes { 171 | type Error = io::Error; 172 | type Cond = (); 173 | fn decode_with(reader: &mut R, _: ()) -> Result { 174 | let length = reader.read_u16::()?; 175 | let mut buf = Vec::with_capacity(length as usize); 176 | reader.take(length.into()).read_to_end(&mut buf)?; 177 | Ok(VarBytes(buf)) 178 | } 179 | } 180 | 181 | #[cfg(test)] 182 | mod test { 183 | use super::*; 184 | 185 | use std::io::Cursor; 186 | 187 | #[test] 188 | fn varbyte_encode() { 189 | let test_var = vec![0, 1, 2, 3, 4, 5]; 190 | let bytes = VarBytes(test_var); 191 | 
192 | assert_eq!(bytes.encoded_length() as usize, 2 + 6); 193 | 194 | let mut buf = Vec::new(); 195 | bytes.encode(&mut buf).unwrap(); 196 | 197 | assert_eq!(&buf, &[0, 6, 0, 1, 2, 3, 4, 5]); 198 | 199 | let mut reader = Cursor::new(buf); 200 | let decoded = VarBytes::decode(&mut reader).unwrap(); 201 | 202 | assert_eq!(decoded, bytes); 203 | } 204 | } 205 | -------------------------------------------------------------------------------- /src/lib.rs: -------------------------------------------------------------------------------- 1 | //! MQTT protocol utilities library 2 | //! 3 | //! Strictly implements protocol of [MQTT v3.1.1](http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html) 4 | //! 5 | //! ## Usage 6 | //! 7 | //! ```rust 8 | //! use std::io::Cursor; 9 | //! 10 | //! use mqtt::{Encodable, Decodable}; 11 | //! use mqtt::packet::{VariablePacket, PublishPacket, QoSWithPacketIdentifier}; 12 | //! use mqtt::TopicName; 13 | //! 14 | //! // Create a new Publish packet 15 | //! let packet = PublishPacket::new(TopicName::new("mqtt/learning").unwrap(), 16 | //! QoSWithPacketIdentifier::Level2(10), 17 | //! b"Hello MQTT!".to_vec()); 18 | //! 19 | //! // Encode 20 | //! let mut buf = Vec::new(); 21 | //! packet.encode(&mut buf).unwrap(); 22 | //! println!("Encoded: {:?}", buf); 23 | //! 24 | //! // Decode it with known type 25 | //! let mut dec_buf = Cursor::new(&buf[..]); 26 | //! let decoded = PublishPacket::decode(&mut dec_buf).unwrap(); 27 | //! println!("Decoded: {:?}", decoded); 28 | //! assert_eq!(packet, decoded); 29 | //! 30 | //! // Auto decode by the fixed header 31 | //! let mut dec_buf = Cursor::new(&buf[..]); 32 | //! let auto_decode = VariablePacket::decode(&mut dec_buf).unwrap(); 33 | //! println!("Variable packet decode: {:?}", auto_decode); 34 | //! assert_eq!(VariablePacket::PublishPacket(packet), auto_decode); 35 | //! 
``` 36 | 37 | pub use self::encodable::{Decodable, Encodable}; 38 | pub use self::qos::QualityOfService; 39 | pub use self::topic_filter::{TopicFilter, TopicFilterRef}; 40 | pub use self::topic_name::{TopicName, TopicNameRef}; 41 | 42 | pub mod control; 43 | pub mod encodable; 44 | pub mod packet; 45 | pub mod qos; 46 | pub mod topic_filter; 47 | pub mod topic_name; 48 | -------------------------------------------------------------------------------- /src/packet/connack.rs: -------------------------------------------------------------------------------- 1 | //! CONNACK 2 | 3 | use std::io::Read; 4 | 5 | use crate::control::variable_header::{ConnackFlags, ConnectReturnCode}; 6 | use crate::control::{ControlType, FixedHeader, PacketType}; 7 | use crate::packet::{DecodablePacket, PacketError}; 8 | use crate::Decodable; 9 | 10 | /// `CONNACK` packet 11 | #[derive(Debug, Eq, PartialEq, Clone)] 12 | pub struct ConnackPacket { 13 | fixed_header: FixedHeader, 14 | flags: ConnackFlags, 15 | ret_code: ConnectReturnCode, 16 | } 17 | 18 | encodable_packet!(ConnackPacket(flags, ret_code)); 19 | 20 | impl ConnackPacket { 21 | pub fn new(session_present: bool, ret_code: ConnectReturnCode) -> ConnackPacket { 22 | ConnackPacket { 23 | fixed_header: FixedHeader::new(PacketType::with_default(ControlType::ConnectAcknowledgement), 2), 24 | flags: ConnackFlags { session_present }, 25 | ret_code, 26 | } 27 | } 28 | 29 | pub fn connack_flags(&self) -> ConnackFlags { 30 | self.flags 31 | } 32 | 33 | pub fn connect_return_code(&self) -> ConnectReturnCode { 34 | self.ret_code 35 | } 36 | } 37 | 38 | impl DecodablePacket for ConnackPacket { 39 | type DecodePacketError = std::convert::Infallible; 40 | 41 | fn decode_packet(reader: &mut R, fixed_header: FixedHeader) -> Result> { 42 | let flags: ConnackFlags = Decodable::decode(reader)?; 43 | let code: ConnectReturnCode = Decodable::decode(reader)?; 44 | 45 | Ok(ConnackPacket { 46 | fixed_header, 47 | flags, 48 | ret_code: code, 49 | }) 50 | } 
51 | } 52 | 53 | #[cfg(test)] 54 | mod test { 55 | use super::*; 56 | 57 | use std::io::Cursor; 58 | 59 | use crate::control::variable_header::ConnectReturnCode; 60 | use crate::{Decodable, Encodable}; 61 | 62 | #[test] 63 | pub fn test_connack_packet_basic() { 64 | let packet = ConnackPacket::new(false, ConnectReturnCode::IdentifierRejected); 65 | 66 | let mut buf = Vec::new(); 67 | packet.encode(&mut buf).unwrap(); 68 | 69 | let mut decode_buf = Cursor::new(buf); 70 | let decoded = ConnackPacket::decode(&mut decode_buf).unwrap(); 71 | 72 | assert_eq!(packet, decoded); 73 | } 74 | } 75 | -------------------------------------------------------------------------------- /src/packet/connect.rs: -------------------------------------------------------------------------------- 1 | //! CONNECT 2 | 3 | use std::io::{self, Read, Write}; 4 | 5 | use crate::control::variable_header::protocol_level::SPEC_3_1_1; 6 | use crate::control::variable_header::{ConnectFlags, KeepAlive, ProtocolLevel, ProtocolName, VariableHeaderError}; 7 | use crate::control::{ControlType, FixedHeader, PacketType}; 8 | use crate::encodable::VarBytes; 9 | use crate::packet::{DecodablePacket, PacketError}; 10 | use crate::topic_name::{TopicName, TopicNameDecodeError, TopicNameError}; 11 | use crate::{Decodable, Encodable}; 12 | 13 | /// `CONNECT` packet 14 | #[derive(Debug, Eq, PartialEq, Clone)] 15 | pub struct ConnectPacket { 16 | fixed_header: FixedHeader, 17 | protocol_name: ProtocolName, 18 | 19 | protocol_level: ProtocolLevel, 20 | flags: ConnectFlags, 21 | keep_alive: KeepAlive, 22 | 23 | payload: ConnectPacketPayload, 24 | } 25 | 26 | encodable_packet!(ConnectPacket(protocol_name, protocol_level, flags, keep_alive, payload)); 27 | 28 | impl ConnectPacket { 29 | pub fn new(client_identifier: C) -> ConnectPacket 30 | where 31 | C: Into, 32 | { 33 | ConnectPacket::with_level("MQTT", client_identifier, SPEC_3_1_1).expect("SPEC_3_1_1 should always be valid") 34 | } 35 | 36 | pub fn 
with_level(protoname: P, client_identifier: C, level: u8) -> Result 37 | where 38 | P: Into, 39 | C: Into, 40 | { 41 | let protocol_level = ProtocolLevel::from_u8(level).ok_or(VariableHeaderError::InvalidProtocolVersion)?; 42 | let mut pk = ConnectPacket { 43 | fixed_header: FixedHeader::new(PacketType::with_default(ControlType::Connect), 0), 44 | protocol_name: ProtocolName(protoname.into()), 45 | protocol_level, 46 | flags: ConnectFlags::empty(), 47 | keep_alive: KeepAlive(0), 48 | payload: ConnectPacketPayload::new(client_identifier.into()), 49 | }; 50 | 51 | pk.fix_header_remaining_len(); 52 | 53 | Ok(pk) 54 | } 55 | 56 | pub fn set_keep_alive(&mut self, keep_alive: u16) { 57 | self.keep_alive = KeepAlive(keep_alive); 58 | } 59 | 60 | pub fn set_user_name(&mut self, name: Option) { 61 | self.flags.user_name = name.is_some(); 62 | self.payload.user_name = name; 63 | self.fix_header_remaining_len(); 64 | } 65 | 66 | pub fn set_will(&mut self, topic_message: Option<(TopicName, Vec)>) { 67 | self.flags.will_flag = topic_message.is_some(); 68 | 69 | self.payload.will = topic_message.map(|(t, m)| (t, VarBytes(m))); 70 | 71 | self.fix_header_remaining_len(); 72 | } 73 | 74 | pub fn set_password(&mut self, password: Option) { 75 | self.flags.password = password.is_some(); 76 | self.payload.password = password; 77 | self.fix_header_remaining_len(); 78 | } 79 | 80 | pub fn set_client_identifier>(&mut self, id: I) { 81 | self.payload.client_identifier = id.into(); 82 | self.fix_header_remaining_len(); 83 | } 84 | 85 | pub fn set_will_retain(&mut self, will_retain: bool) { 86 | self.flags.will_retain = will_retain; 87 | } 88 | 89 | pub fn set_will_qos(&mut self, will_qos: u8) { 90 | assert!(will_qos <= 2); 91 | self.flags.will_qos = will_qos; 92 | } 93 | 94 | pub fn set_clean_session(&mut self, clean_session: bool) { 95 | self.flags.clean_session = clean_session; 96 | } 97 | 98 | pub fn user_name(&self) -> Option<&str> { 99 | self.payload.user_name.as_ref().map(|x| &x[..]) 
100 | } 101 | 102 | pub fn password(&self) -> Option<&str> { 103 | self.payload.password.as_ref().map(|x| &x[..]) 104 | } 105 | 106 | pub fn will(&self) -> Option<(&str, &[u8])> { 107 | self.payload.will.as_ref().map(|(topic, msg)| (&topic[..], &*msg.0)) 108 | } 109 | 110 | pub fn will_retain(&self) -> bool { 111 | self.flags.will_retain 112 | } 113 | 114 | pub fn will_qos(&self) -> u8 { 115 | self.flags.will_qos 116 | } 117 | 118 | pub fn client_identifier(&self) -> &str { 119 | &self.payload.client_identifier[..] 120 | } 121 | 122 | pub fn protocol_name(&self) -> &str { 123 | &self.protocol_name.0 124 | } 125 | 126 | pub fn protocol_level(&self) -> ProtocolLevel { 127 | self.protocol_level 128 | } 129 | 130 | pub fn clean_session(&self) -> bool { 131 | self.flags.clean_session 132 | } 133 | 134 | pub fn keep_alive(&self) -> u16 { 135 | self.keep_alive.0 136 | } 137 | 138 | /// Read back the "reserved" Connect flag bit 0. For compliant implementations this should 139 | /// always be false. 
140 | pub fn reserved_flag(&self) -> bool { 141 | self.flags.reserved 142 | } 143 | } 144 | 145 | impl DecodablePacket for ConnectPacket { 146 | type DecodePacketError = ConnectPacketError; 147 | 148 | fn decode_packet(reader: &mut R, fixed_header: FixedHeader) -> Result> { 149 | let protoname: ProtocolName = Decodable::decode(reader)?; 150 | let protocol_level: ProtocolLevel = Decodable::decode(reader)?; 151 | let flags: ConnectFlags = Decodable::decode(reader)?; 152 | let keep_alive: KeepAlive = Decodable::decode(reader)?; 153 | let payload: ConnectPacketPayload = 154 | Decodable::decode_with(reader, Some(flags)).map_err(PacketError::PayloadError)?; 155 | 156 | Ok(ConnectPacket { 157 | fixed_header, 158 | protocol_name: protoname, 159 | protocol_level, 160 | flags, 161 | keep_alive, 162 | payload, 163 | }) 164 | } 165 | } 166 | 167 | /// Payloads for connect packet 168 | #[derive(Debug, Eq, PartialEq, Clone)] 169 | struct ConnectPacketPayload { 170 | client_identifier: String, 171 | will: Option<(TopicName, VarBytes)>, 172 | user_name: Option, 173 | password: Option, 174 | } 175 | 176 | impl ConnectPacketPayload { 177 | pub fn new(client_identifier: String) -> ConnectPacketPayload { 178 | ConnectPacketPayload { 179 | client_identifier, 180 | will: None, 181 | user_name: None, 182 | password: None, 183 | } 184 | } 185 | } 186 | 187 | impl Encodable for ConnectPacketPayload { 188 | fn encode(&self, writer: &mut W) -> Result<(), io::Error> { 189 | self.client_identifier.encode(writer)?; 190 | 191 | if let Some((will_topic, will_message)) = &self.will { 192 | will_topic.encode(writer)?; 193 | will_message.encode(writer)?; 194 | } 195 | 196 | if let Some(ref user_name) = self.user_name { 197 | user_name.encode(writer)?; 198 | } 199 | 200 | if let Some(ref password) = self.password { 201 | password.encode(writer)?; 202 | } 203 | 204 | Ok(()) 205 | } 206 | 207 | fn encoded_length(&self) -> u32 { 208 | self.client_identifier.encoded_length() 209 | + self 210 | .will 211 
| .as_ref() 212 | .map(|(a, b)| a.encoded_length() + b.encoded_length()) 213 | .unwrap_or(0) 214 | + self.user_name.as_ref().map(|t| t.encoded_length()).unwrap_or(0) 215 | + self.password.as_ref().map(|t| t.encoded_length()).unwrap_or(0) 216 | } 217 | } 218 | 219 | impl Decodable for ConnectPacketPayload { 220 | type Error = ConnectPacketError; 221 | type Cond = Option; 222 | 223 | fn decode_with( 224 | reader: &mut R, 225 | rest: Option, 226 | ) -> Result { 227 | let mut need_will = false; 228 | let mut need_user_name = false; 229 | let mut need_password = false; 230 | 231 | if let Some(r) = rest { 232 | need_will = r.will_flag; 233 | need_user_name = r.user_name; 234 | need_password = r.password; 235 | } 236 | 237 | let ident = String::decode(reader)?; 238 | let will = if need_will { 239 | let topic = TopicName::decode(reader).map_err(|e| match e { 240 | TopicNameDecodeError::IoError(e) => ConnectPacketError::from(e), 241 | TopicNameDecodeError::InvalidTopicName(e) => e.into(), 242 | })?; 243 | let msg = VarBytes::decode(reader)?; 244 | Some((topic, msg)) 245 | } else { 246 | None 247 | }; 248 | let uname = if need_user_name { 249 | Some(String::decode(reader)?) 250 | } else { 251 | None 252 | }; 253 | let pwd = if need_password { 254 | Some(String::decode(reader)?) 
255 | } else { 256 | None 257 | }; 258 | 259 | Ok(ConnectPacketPayload { 260 | client_identifier: ident, 261 | will, 262 | user_name: uname, 263 | password: pwd, 264 | }) 265 | } 266 | } 267 | 268 | #[derive(Debug, thiserror::Error)] 269 | #[error(transparent)] 270 | pub enum ConnectPacketError { 271 | IoError(#[from] io::Error), 272 | TopicNameError(#[from] TopicNameError), 273 | } 274 | 275 | #[cfg(test)] 276 | mod test { 277 | use super::*; 278 | 279 | use std::io::Cursor; 280 | 281 | use crate::{Decodable, Encodable}; 282 | 283 | #[test] 284 | fn test_connect_packet_encode_basic() { 285 | let packet = ConnectPacket::new("12345".to_owned()); 286 | let expected = b"\x10\x11\x00\x04MQTT\x04\x00\x00\x00\x00\x0512345"; 287 | 288 | let mut buf = Vec::new(); 289 | packet.encode(&mut buf).unwrap(); 290 | 291 | assert_eq!(&expected[..], &buf[..]); 292 | } 293 | 294 | #[test] 295 | fn test_connect_packet_decode_basic() { 296 | let encoded_data = b"\x10\x11\x00\x04MQTT\x04\x00\x00\x00\x00\x0512345"; 297 | 298 | let mut buf = Cursor::new(&encoded_data[..]); 299 | let packet = ConnectPacket::decode(&mut buf).unwrap(); 300 | 301 | let expected = ConnectPacket::new("12345".to_owned()); 302 | assert_eq!(expected, packet); 303 | } 304 | 305 | #[test] 306 | fn test_connect_packet_user_name() { 307 | let mut packet = ConnectPacket::new("12345".to_owned()); 308 | packet.set_user_name(Some("mqtt_player".to_owned())); 309 | 310 | let mut buf = Vec::new(); 311 | packet.encode(&mut buf).unwrap(); 312 | 313 | let mut decode_buf = Cursor::new(buf); 314 | let decoded_packet = ConnectPacket::decode(&mut decode_buf).unwrap(); 315 | 316 | assert_eq!(packet, decoded_packet); 317 | } 318 | } 319 | -------------------------------------------------------------------------------- /src/packet/disconnect.rs: -------------------------------------------------------------------------------- 1 | //! 
DISCONNECT 2 | 3 | use std::io::Read; 4 | 5 | use crate::control::{ControlType, FixedHeader, PacketType}; 6 | use crate::packet::{DecodablePacket, PacketError}; 7 | 8 | /// `DISCONNECT` packet 9 | #[derive(Debug, Eq, PartialEq, Clone)] 10 | pub struct DisconnectPacket { 11 | fixed_header: FixedHeader, 12 | } 13 | 14 | encodable_packet!(DisconnectPacket()); 15 | 16 | impl DisconnectPacket { 17 | pub fn new() -> DisconnectPacket { 18 | DisconnectPacket { 19 | fixed_header: FixedHeader::new(PacketType::with_default(ControlType::Disconnect), 0), 20 | } 21 | } 22 | } 23 | 24 | impl Default for DisconnectPacket { 25 | fn default() -> DisconnectPacket { 26 | DisconnectPacket::new() 27 | } 28 | } 29 | 30 | impl DecodablePacket for DisconnectPacket { 31 | type DecodePacketError = std::convert::Infallible; 32 | 33 | fn decode_packet(_reader: &mut R, fixed_header: FixedHeader) -> Result> { 34 | Ok(DisconnectPacket { fixed_header }) 35 | } 36 | } 37 | -------------------------------------------------------------------------------- /src/packet/mod.rs: -------------------------------------------------------------------------------- 1 | //! Specific packets 2 | 3 | use std::error::Error; 4 | use std::fmt::{self, Debug}; 5 | use std::io::{self, Read, Write}; 6 | 7 | #[cfg(feature = "tokio")] 8 | use tokio::io::{AsyncRead, AsyncReadExt}; 9 | 10 | use crate::control::fixed_header::FixedHeaderError; 11 | use crate::control::variable_header::VariableHeaderError; 12 | use crate::control::ControlType; 13 | use crate::control::FixedHeader; 14 | use crate::topic_name::{TopicNameDecodeError, TopicNameError}; 15 | use crate::{Decodable, Encodable}; 16 | 17 | macro_rules! 
encodable_packet { 18 | ($typ:ident($($field:ident),* $(,)?)) => { 19 | impl $crate::packet::EncodablePacket for $typ { 20 | fn fixed_header(&self) -> &$crate::control::fixed_header::FixedHeader { 21 | &self.fixed_header 22 | } 23 | 24 | #[allow(unused)] 25 | fn encode_packet(&self, writer: &mut W) -> ::std::io::Result<()> { 26 | $($crate::encodable::Encodable::encode(&self.$field, writer)?;)* 27 | Ok(()) 28 | } 29 | 30 | fn encoded_packet_length(&self) -> u32 { 31 | $($crate::encodable::Encodable::encoded_length(&self.$field) +)* 32 | 0 33 | } 34 | } 35 | 36 | impl $typ { 37 | #[allow(unused)] 38 | #[inline(always)] 39 | fn fix_header_remaining_len(&mut self) { 40 | self.fixed_header.remaining_length = $crate::packet::EncodablePacket::encoded_packet_length(self); 41 | } 42 | } 43 | }; 44 | } 45 | 46 | pub use self::connack::ConnackPacket; 47 | pub use self::connect::ConnectPacket; 48 | pub use self::disconnect::DisconnectPacket; 49 | pub use self::pingreq::PingreqPacket; 50 | pub use self::pingresp::PingrespPacket; 51 | pub use self::puback::PubackPacket; 52 | pub use self::pubcomp::PubcompPacket; 53 | pub use self::publish::{PublishPacket, PublishPacketRef}; 54 | pub use self::pubrec::PubrecPacket; 55 | pub use self::pubrel::PubrelPacket; 56 | pub use self::suback::SubackPacket; 57 | pub use self::subscribe::SubscribePacket; 58 | pub use self::unsuback::UnsubackPacket; 59 | pub use self::unsubscribe::UnsubscribePacket; 60 | 61 | pub use self::publish::QoSWithPacketIdentifier; 62 | 63 | pub mod connack; 64 | pub mod connect; 65 | pub mod disconnect; 66 | pub mod pingreq; 67 | pub mod pingresp; 68 | pub mod puback; 69 | pub mod pubcomp; 70 | pub mod publish; 71 | pub mod pubrec; 72 | pub mod pubrel; 73 | pub mod suback; 74 | pub mod subscribe; 75 | pub mod unsuback; 76 | pub mod unsubscribe; 77 | 78 | /// A trait representing a packet that can be encoded, when passed as `FooPacket` or as 79 | /// `&FooPacket`. 
Different from [`Encodable`] in that it prevents you from accidentally passing 80 | /// a type intended to be encoded only as a part of a packet and doesn't have a header, e.g. 81 | /// `Vec`. 82 | pub trait EncodablePacket { 83 | /// Get a reference to `FixedHeader`. All MQTT packet must have a fixed header. 84 | fn fixed_header(&self) -> &FixedHeader; 85 | 86 | /// Encodes packet data after fixed header, including variable headers and payload 87 | fn encode_packet(&self, _writer: &mut W) -> io::Result<()> { 88 | Ok(()) 89 | } 90 | 91 | /// Length in bytes for data after fixed header, including variable headers and payload 92 | fn encoded_packet_length(&self) -> u32 { 93 | 0 94 | } 95 | } 96 | 97 | impl Encodable for T { 98 | fn encode(&self, writer: &mut W) -> io::Result<()> { 99 | self.fixed_header().encode(writer)?; 100 | self.encode_packet(writer) 101 | } 102 | 103 | fn encoded_length(&self) -> u32 { 104 | self.fixed_header().encoded_length() + self.encoded_packet_length() 105 | } 106 | } 107 | 108 | pub trait DecodablePacket: EncodablePacket + Sized { 109 | type DecodePacketError: Error + 'static; 110 | 111 | /// Decode packet given a `FixedHeader` 112 | fn decode_packet(reader: &mut R, fixed_header: FixedHeader) -> Result>; 113 | } 114 | 115 | impl Decodable for T { 116 | type Error = PacketError; 117 | type Cond = Option; 118 | 119 | fn decode_with(reader: &mut R, fixed_header: Self::Cond) -> Result { 120 | let fixed_header: FixedHeader = if let Some(hdr) = fixed_header { 121 | hdr 122 | } else { 123 | Decodable::decode(reader)? 124 | }; 125 | 126 | ::decode_packet(reader, fixed_header) 127 | } 128 | } 129 | 130 | /// Parsing errors for packet 131 | #[derive(thiserror::Error)] 132 | #[error(transparent)] 133 | pub enum PacketError

134 | where 135 | P: DecodablePacket, 136 | { 137 | FixedHeaderError(#[from] FixedHeaderError), 138 | VariableHeaderError(#[from] VariableHeaderError), 139 | PayloadError(

::DecodePacketError), 140 | IoError(#[from] io::Error), 141 | TopicNameError(#[from] TopicNameError), 142 | } 143 | 144 | impl

Debug for PacketError

145 | where 146 | P: DecodablePacket, 147 | { 148 | fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { 149 | match *self { 150 | PacketError::FixedHeaderError(ref e) => f.debug_tuple("FixedHeaderError").field(e).finish(), 151 | PacketError::VariableHeaderError(ref e) => f.debug_tuple("VariableHeaderError").field(e).finish(), 152 | PacketError::PayloadError(ref e) => f.debug_tuple("PayloadError").field(e).finish(), 153 | PacketError::IoError(ref e) => f.debug_tuple("IoError").field(e).finish(), 154 | PacketError::TopicNameError(ref e) => f.debug_tuple("TopicNameError").field(e).finish(), 155 | } 156 | } 157 | } 158 | 159 | impl From for PacketError

{ 160 | fn from(e: TopicNameDecodeError) -> Self { 161 | match e { 162 | TopicNameDecodeError::IoError(e) => e.into(), 163 | TopicNameDecodeError::InvalidTopicName(e) => e.into(), 164 | } 165 | } 166 | } 167 | 168 | macro_rules! impl_variable_packet { 169 | ($($name:ident & $errname:ident => $hdr:ident,)+) => { 170 | /// Variable packet 171 | #[derive(Debug, Eq, PartialEq, Clone)] 172 | pub enum VariablePacket { 173 | $( 174 | $name($name), 175 | )+ 176 | } 177 | 178 | #[cfg(feature = "tokio")] 179 | impl VariablePacket { 180 | /// Asynchronously parse a packet from a `tokio::io::AsyncRead` 181 | /// 182 | /// This requires mqtt-rs to be built with `feature = "tokio"` 183 | pub async fn parse(rdr: &mut A) -> Result { 184 | use std::io::Cursor; 185 | let fixed_header = FixedHeader::parse(rdr).await?; 186 | 187 | let mut buffer = vec![0u8; fixed_header.remaining_length as usize]; 188 | rdr.read_exact(&mut buffer).await?; 189 | 190 | decode_with_header(&mut Cursor::new(buffer), fixed_header) 191 | } 192 | } 193 | 194 | #[inline] 195 | fn decode_with_header(rdr: &mut R, fixed_header: FixedHeader) -> Result { 196 | match fixed_header.packet_type.control_type() { 197 | $( 198 | ControlType::$hdr => { 199 | let pk = <$name as DecodablePacket>::decode_packet(rdr, fixed_header)?; 200 | Ok(VariablePacket::$name(pk)) 201 | } 202 | )+ 203 | } 204 | } 205 | 206 | $( 207 | impl From<$name> for VariablePacket { 208 | fn from(pk: $name) -> VariablePacket { 209 | VariablePacket::$name(pk) 210 | } 211 | } 212 | )+ 213 | 214 | // impl Encodable for VariablePacket { 215 | // fn encode(&self, writer: &mut W) -> Result<(), io::Error> { 216 | // match *self { 217 | // $( 218 | // VariablePacket::$name(ref pk) => pk.encode(writer), 219 | // )+ 220 | // } 221 | // } 222 | 223 | // fn encoded_length(&self) -> u32 { 224 | // match *self { 225 | // $( 226 | // VariablePacket::$name(ref pk) => pk.encoded_length(), 227 | // )+ 228 | // } 229 | // } 230 | // } 231 | 232 | impl EncodablePacket 
for VariablePacket { 233 | fn fixed_header(&self) -> &FixedHeader { 234 | match *self { 235 | $( 236 | VariablePacket::$name(ref pk) => pk.fixed_header(), 237 | )+ 238 | } 239 | } 240 | 241 | fn encode_packet(&self, writer: &mut W) -> io::Result<()> { 242 | match *self { 243 | $( 244 | VariablePacket::$name(ref pk) => pk.encode_packet(writer), 245 | )+ 246 | } 247 | } 248 | 249 | fn encoded_packet_length(&self) -> u32 { 250 | match *self { 251 | $( 252 | VariablePacket::$name(ref pk) => pk.encoded_packet_length(), 253 | )+ 254 | } 255 | } 256 | } 257 | 258 | impl Decodable for VariablePacket { 259 | type Error = VariablePacketError; 260 | type Cond = Option; 261 | 262 | fn decode_with(reader: &mut R, fixed_header: Self::Cond) 263 | -> Result { 264 | let fixed_header = match fixed_header { 265 | Some(fh) => fh, 266 | None => { 267 | match FixedHeader::decode(reader) { 268 | Ok(header) => header, 269 | Err(FixedHeaderError::ReservedType(code, length)) => { 270 | let reader = &mut reader.take(length as u64); 271 | let mut buf = Vec::with_capacity(length as usize); 272 | reader.read_to_end(&mut buf)?; 273 | return Err(VariablePacketError::ReservedPacket(code, buf)); 274 | }, 275 | Err(err) => return Err(From::from(err)) 276 | } 277 | } 278 | }; 279 | let reader = &mut reader.take(fixed_header.remaining_length as u64); 280 | 281 | decode_with_header(reader, fixed_header) 282 | } 283 | } 284 | 285 | /// Parsing errors for variable packet 286 | #[derive(Debug, thiserror::Error)] 287 | pub enum VariablePacketError { 288 | #[error(transparent)] 289 | FixedHeaderError(#[from] FixedHeaderError), 290 | #[error("reserved packet type ({0}), [u8, ..{}]", .1.len())] 291 | ReservedPacket(u8, Vec), 292 | #[error(transparent)] 293 | IoError(#[from] io::Error), 294 | $( 295 | #[error(transparent)] 296 | $errname(#[from] PacketError<$name>), 297 | )+ 298 | } 299 | } 300 | } 301 | 302 | impl_variable_packet! 
{ 303 | ConnectPacket & ConnectPacketError => Connect, 304 | ConnackPacket & ConnackPacketError => ConnectAcknowledgement, 305 | 306 | PublishPacket & PublishPacketError => Publish, 307 | PubackPacket & PubackPacketError => PublishAcknowledgement, 308 | PubrecPacket & PubrecPacketError => PublishReceived, 309 | PubrelPacket & PubrelPacketError => PublishRelease, 310 | PubcompPacket & PubcompPacketError => PublishComplete, 311 | 312 | PingreqPacket & PingreqPacketError => PingRequest, 313 | PingrespPacket & PingrespPacketError => PingResponse, 314 | 315 | SubscribePacket & SubscribePacketError => Subscribe, 316 | SubackPacket & SubackPacketError => SubscribeAcknowledgement, 317 | 318 | UnsubscribePacket & UnsubscribePacketError => Unsubscribe, 319 | UnsubackPacket & UnsubackPacketError => UnsubscribeAcknowledgement, 320 | 321 | DisconnectPacket & DisconnectPacketError => Disconnect, 322 | } 323 | 324 | impl VariablePacket { 325 | pub fn new(t: T) -> VariablePacket 326 | where 327 | VariablePacket: From, 328 | { 329 | From::from(t) 330 | } 331 | } 332 | 333 | #[cfg(feature = "tokio-codec")] 334 | mod tokio_codec { 335 | use super::*; 336 | use crate::control::packet_type::{PacketType, PacketTypeError}; 337 | use bytes::{Buf, BufMut, BytesMut}; 338 | use tokio_util::codec; 339 | 340 | pub struct MqttDecoder { 341 | state: DecodeState, 342 | } 343 | 344 | enum DecodeState { 345 | Start, 346 | Packet { length: u32, typ: DecodePacketType }, 347 | } 348 | 349 | #[derive(Copy, Clone)] 350 | enum DecodePacketType { 351 | Standard(PacketType), 352 | Reserved(u8), 353 | } 354 | 355 | impl MqttDecoder { 356 | pub const fn new() -> Self { 357 | MqttDecoder { 358 | state: DecodeState::Start, 359 | } 360 | } 361 | } 362 | 363 | /// Like FixedHeader::decode(), but on a buffer instead of a stream. Returns None if it reaches 364 | /// the end of the buffer before it finishes decoding the header. 
365 | #[inline] 366 | fn decode_header(mut data: &[u8]) -> Option> { 367 | let mut header_size = 0; 368 | macro_rules! read_u8 { 369 | () => {{ 370 | let (&x, rest) = data.split_first()?; 371 | data = rest; 372 | header_size += 1; 373 | x 374 | }}; 375 | } 376 | 377 | let type_val = read_u8!(); 378 | let remaining_len = { 379 | let mut cur = 0u32; 380 | for i in 0.. { 381 | let byte = read_u8!(); 382 | cur |= ((byte as u32) & 0x7F) << (7 * i); 383 | 384 | if i >= 4 { 385 | return Some(Err(FixedHeaderError::MalformedRemainingLength)); 386 | } 387 | 388 | if byte & 0x80 == 0 { 389 | break; 390 | } 391 | } 392 | 393 | cur 394 | }; 395 | 396 | let packet_type = match PacketType::from_u8(type_val) { 397 | Ok(ty) => DecodePacketType::Standard(ty), 398 | Err(PacketTypeError::ReservedType(ty, _)) => DecodePacketType::Reserved(ty), 399 | Err(err) => return Some(Err(err.into())), 400 | }; 401 | Some(Ok((packet_type, remaining_len, header_size))) 402 | } 403 | 404 | impl codec::Decoder for MqttDecoder { 405 | type Item = VariablePacket; 406 | type Error = VariablePacketError; 407 | fn decode(&mut self, src: &mut BytesMut) -> Result, VariablePacketError> { 408 | loop { 409 | match &mut self.state { 410 | DecodeState::Start => match decode_header(&src[..]) { 411 | Some(Ok((typ, length, header_size))) => { 412 | src.advance(header_size); 413 | self.state = DecodeState::Packet { length, typ }; 414 | continue; 415 | } 416 | Some(Err(e)) => return Err(e.into()), 417 | None => return Ok(None), 418 | }, 419 | DecodeState::Packet { length, typ } => { 420 | let length = *length; 421 | if src.remaining() < length as usize { 422 | return Ok(None); 423 | } 424 | let typ = *typ; 425 | 426 | self.state = DecodeState::Start; 427 | 428 | match typ { 429 | DecodePacketType::Standard(typ) => { 430 | let header = FixedHeader { 431 | packet_type: typ, 432 | remaining_length: length, 433 | }; 434 | return decode_with_header(&mut src.reader(), header).map(Some); 435 | } 436 | 
DecodePacketType::Reserved(code) => { 437 | let data = src[..length as usize].to_vec(); 438 | src.advance(length as usize); 439 | return Err(VariablePacketError::ReservedPacket(code, data)); 440 | } 441 | } 442 | } 443 | } 444 | } 445 | } 446 | } 447 | 448 | pub struct MqttEncoder { 449 | _priv: (), 450 | } 451 | 452 | impl MqttEncoder { 453 | pub const fn new() -> Self { 454 | MqttEncoder { _priv: () } 455 | } 456 | } 457 | 458 | impl codec::Encoder for MqttEncoder { 459 | type Error = io::Error; 460 | fn encode(&mut self, packet: T, dst: &mut BytesMut) -> Result<(), io::Error> { 461 | dst.reserve(packet.encoded_length() as usize); 462 | packet.encode(&mut dst.writer()) 463 | } 464 | } 465 | 466 | pub struct MqttCodec { 467 | decode: MqttDecoder, 468 | encode: MqttEncoder, 469 | } 470 | 471 | impl MqttCodec { 472 | pub const fn new() -> Self { 473 | MqttCodec { 474 | decode: MqttDecoder::new(), 475 | encode: MqttEncoder::new(), 476 | } 477 | } 478 | } 479 | 480 | impl codec::Decoder for MqttCodec { 481 | type Item = VariablePacket; 482 | type Error = VariablePacketError; 483 | #[inline] 484 | fn decode(&mut self, src: &mut BytesMut) -> Result, VariablePacketError> { 485 | self.decode.decode(src) 486 | } 487 | } 488 | 489 | impl codec::Encoder for MqttCodec { 490 | type Error = io::Error; 491 | #[inline] 492 | fn encode(&mut self, packet: T, dst: &mut BytesMut) -> Result<(), io::Error> { 493 | self.encode.encode(packet, dst) 494 | } 495 | } 496 | } 497 | 498 | #[cfg(feature = "tokio-codec")] 499 | pub use tokio_codec::{MqttCodec, MqttDecoder, MqttEncoder}; 500 | 501 | #[cfg(test)] 502 | mod test { 503 | use super::*; 504 | 505 | use std::io::Cursor; 506 | 507 | use crate::{Decodable, Encodable}; 508 | 509 | #[test] 510 | fn test_variable_packet_basic() { 511 | let packet = ConnectPacket::new("1234".to_owned()); 512 | 513 | // Wrap it 514 | let var_packet = VariablePacket::new(packet); 515 | 516 | // Encode 517 | let mut buf = Vec::new(); 518 | 
var_packet.encode(&mut buf).unwrap(); 519 | 520 | // Decode 521 | let mut decode_buf = Cursor::new(buf); 522 | let decoded_packet = VariablePacket::decode(&mut decode_buf).unwrap(); 523 | 524 | assert_eq!(var_packet, decoded_packet); 525 | } 526 | 527 | #[cfg(feature = "tokio")] 528 | #[tokio::test] 529 | async fn test_variable_packet_async_parse() { 530 | let packet = ConnectPacket::new("1234".to_owned()); 531 | 532 | // Wrap it 533 | let var_packet = VariablePacket::new(packet); 534 | 535 | // Encode 536 | let mut buf = Vec::new(); 537 | var_packet.encode(&mut buf).unwrap(); 538 | 539 | // Parse 540 | let mut async_buf = buf.as_slice(); 541 | let decoded_packet = VariablePacket::parse(&mut async_buf).await.unwrap(); 542 | 543 | assert_eq!(var_packet, decoded_packet); 544 | } 545 | 546 | #[cfg(feature = "tokio-codec")] 547 | #[tokio::test] 548 | async fn test_variable_packet_framed() { 549 | use crate::{QualityOfService, TopicFilter}; 550 | use futures::{SinkExt, StreamExt}; 551 | use tokio_util::codec::{FramedRead, FramedWrite}; 552 | 553 | let conn_packet = ConnectPacket::new("1234".to_owned()); 554 | let sub_packet = SubscribePacket::new(1, vec![(TopicFilter::new("foo/#").unwrap(), QualityOfService::Level0)]); 555 | 556 | // small, to make sure buffering and stuff works 557 | let (reader, writer) = tokio::io::duplex(8); 558 | 559 | let task = tokio::spawn({ 560 | let (conn_packet, sub_packet) = (conn_packet.clone(), sub_packet.clone()); 561 | async move { 562 | let mut sink = FramedWrite::new(writer, MqttEncoder::new()); 563 | sink.send(conn_packet).await.unwrap(); 564 | sink.send(sub_packet).await.unwrap(); 565 | SinkExt::::flush(&mut sink).await.unwrap(); 566 | } 567 | }); 568 | 569 | let mut stream = FramedRead::new(reader, MqttDecoder::new()); 570 | let decoded_conn = stream.next().await.unwrap().unwrap(); 571 | let decoded_sub = stream.next().await.unwrap().unwrap(); 572 | 573 | task.await.unwrap(); 574 | 575 | assert!(stream.next().await.is_none()); 576 | 
577 | assert_eq!(decoded_conn, conn_packet.into()); 578 | assert_eq!(decoded_sub, sub_packet.into()); 579 | } 580 | } 581 | -------------------------------------------------------------------------------- /src/packet/pingreq.rs: -------------------------------------------------------------------------------- 1 | //! PINGREQ 2 | 3 | use std::io::Read; 4 | 5 | use crate::control::{ControlType, FixedHeader, PacketType}; 6 | use crate::packet::{DecodablePacket, PacketError}; 7 | 8 | /// `PINGREQ` packet 9 | #[derive(Debug, Eq, PartialEq, Clone)] 10 | pub struct PingreqPacket { 11 | fixed_header: FixedHeader, 12 | } 13 | 14 | encodable_packet!(PingreqPacket()); 15 | 16 | impl PingreqPacket { 17 | pub fn new() -> PingreqPacket { 18 | PingreqPacket { 19 | fixed_header: FixedHeader::new(PacketType::with_default(ControlType::PingRequest), 0), 20 | } 21 | } 22 | } 23 | 24 | impl Default for PingreqPacket { 25 | fn default() -> PingreqPacket { 26 | PingreqPacket::new() 27 | } 28 | } 29 | 30 | impl DecodablePacket for PingreqPacket { 31 | type DecodePacketError = std::convert::Infallible; 32 | 33 | fn decode_packet(_reader: &mut R, fixed_header: FixedHeader) -> Result> { 34 | Ok(PingreqPacket { fixed_header }) 35 | } 36 | } 37 | -------------------------------------------------------------------------------- /src/packet/pingresp.rs: -------------------------------------------------------------------------------- 1 | //! 
PINGRESP 2 | 3 | use std::io::Read; 4 | 5 | use crate::control::{ControlType, FixedHeader, PacketType}; 6 | use crate::packet::{DecodablePacket, PacketError}; 7 | 8 | /// `PINGRESP` packet 9 | #[derive(Debug, Eq, PartialEq, Clone)] 10 | pub struct PingrespPacket { 11 | fixed_header: FixedHeader, 12 | } 13 | 14 | encodable_packet!(PingrespPacket()); 15 | 16 | impl PingrespPacket { 17 | pub fn new() -> PingrespPacket { 18 | PingrespPacket { 19 | fixed_header: FixedHeader::new(PacketType::with_default(ControlType::PingResponse), 0), 20 | } 21 | } 22 | } 23 | 24 | impl Default for PingrespPacket { 25 | fn default() -> PingrespPacket { 26 | PingrespPacket::new() 27 | } 28 | } 29 | 30 | impl DecodablePacket for PingrespPacket { 31 | type DecodePacketError = std::convert::Infallible; 32 | 33 | fn decode_packet(_reader: &mut R, fixed_header: FixedHeader) -> Result> { 34 | Ok(PingrespPacket { fixed_header }) 35 | } 36 | } 37 | -------------------------------------------------------------------------------- /src/packet/puback.rs: -------------------------------------------------------------------------------- 1 | //! 
PUBACK 2 | 3 | use std::io::Read; 4 | 5 | use crate::control::variable_header::PacketIdentifier; 6 | use crate::control::{ControlType, FixedHeader, PacketType}; 7 | use crate::packet::{DecodablePacket, PacketError}; 8 | use crate::Decodable; 9 | 10 | /// `PUBACK` packet 11 | #[derive(Debug, Eq, PartialEq, Clone)] 12 | pub struct PubackPacket { 13 | fixed_header: FixedHeader, 14 | packet_identifier: PacketIdentifier, 15 | } 16 | 17 | encodable_packet!(PubackPacket(packet_identifier)); 18 | 19 | impl PubackPacket { 20 | pub fn new(pkid: u16) -> PubackPacket { 21 | PubackPacket { 22 | fixed_header: FixedHeader::new(PacketType::with_default(ControlType::PublishAcknowledgement), 2), 23 | packet_identifier: PacketIdentifier(pkid), 24 | } 25 | } 26 | 27 | pub fn packet_identifier(&self) -> u16 { 28 | self.packet_identifier.0 29 | } 30 | 31 | pub fn set_packet_identifier(&mut self, pkid: u16) { 32 | self.packet_identifier.0 = pkid; 33 | } 34 | } 35 | 36 | impl DecodablePacket for PubackPacket { 37 | type DecodePacketError = std::convert::Infallible; 38 | 39 | fn decode_packet(reader: &mut R, fixed_header: FixedHeader) -> Result> { 40 | let packet_identifier: PacketIdentifier = PacketIdentifier::decode(reader)?; 41 | Ok(PubackPacket { 42 | fixed_header, 43 | packet_identifier, 44 | }) 45 | } 46 | } 47 | -------------------------------------------------------------------------------- /src/packet/pubcomp.rs: -------------------------------------------------------------------------------- 1 | //! 
PUBCOMP 2 | 3 | use std::io::Read; 4 | 5 | use crate::control::variable_header::PacketIdentifier; 6 | use crate::control::{ControlType, FixedHeader, PacketType}; 7 | use crate::packet::{DecodablePacket, PacketError}; 8 | use crate::Decodable; 9 | 10 | /// `PUBCOMP` packet 11 | #[derive(Debug, Eq, PartialEq, Clone)] 12 | pub struct PubcompPacket { 13 | fixed_header: FixedHeader, 14 | packet_identifier: PacketIdentifier, 15 | } 16 | 17 | encodable_packet!(PubcompPacket(packet_identifier)); 18 | 19 | impl PubcompPacket { 20 | pub fn new(pkid: u16) -> PubcompPacket { 21 | PubcompPacket { 22 | fixed_header: FixedHeader::new(PacketType::with_default(ControlType::PublishComplete), 2), 23 | packet_identifier: PacketIdentifier(pkid), 24 | } 25 | } 26 | 27 | pub fn packet_identifier(&self) -> u16 { 28 | self.packet_identifier.0 29 | } 30 | 31 | pub fn set_packet_identifier(&mut self, pkid: u16) { 32 | self.packet_identifier.0 = pkid; 33 | } 34 | } 35 | 36 | impl DecodablePacket for PubcompPacket { 37 | type DecodePacketError = std::convert::Infallible; 38 | 39 | fn decode_packet(reader: &mut R, fixed_header: FixedHeader) -> Result> { 40 | let packet_identifier: PacketIdentifier = PacketIdentifier::decode(reader)?; 41 | Ok(PubcompPacket { 42 | fixed_header, 43 | packet_identifier, 44 | }) 45 | } 46 | } 47 | -------------------------------------------------------------------------------- /src/packet/publish.rs: -------------------------------------------------------------------------------- 1 | //! 
PUBLISH 2 | 3 | use std::io::{self, Read, Write}; 4 | 5 | use crate::control::{FixedHeader, PacketType}; 6 | use crate::packet::{DecodablePacket, PacketError}; 7 | use crate::qos::QualityOfService; 8 | use crate::topic_name::TopicName; 9 | use crate::{control::variable_header::PacketIdentifier, TopicNameRef}; 10 | use crate::{Decodable, Encodable}; 11 | 12 | use super::EncodablePacket; 13 | 14 | /// QoS with identifier pairs 15 | #[derive(Debug, Eq, PartialEq, Ord, PartialOrd, Copy, Clone)] 16 | pub enum QoSWithPacketIdentifier { 17 | Level0, 18 | Level1(u16), 19 | Level2(u16), 20 | } 21 | 22 | impl QoSWithPacketIdentifier { 23 | pub fn new(qos: QualityOfService, id: u16) -> QoSWithPacketIdentifier { 24 | match (qos, id) { 25 | (QualityOfService::Level0, _) => QoSWithPacketIdentifier::Level0, 26 | (QualityOfService::Level1, id) => QoSWithPacketIdentifier::Level1(id), 27 | (QualityOfService::Level2, id) => QoSWithPacketIdentifier::Level2(id), 28 | } 29 | } 30 | 31 | pub fn split(self) -> (QualityOfService, Option) { 32 | match self { 33 | QoSWithPacketIdentifier::Level0 => (QualityOfService::Level0, None), 34 | QoSWithPacketIdentifier::Level1(pkid) => (QualityOfService::Level1, Some(pkid)), 35 | QoSWithPacketIdentifier::Level2(pkid) => (QualityOfService::Level2, Some(pkid)), 36 | } 37 | } 38 | } 39 | 40 | /// `PUBLISH` packet 41 | #[derive(Debug, Eq, PartialEq, Clone)] 42 | pub struct PublishPacket { 43 | fixed_header: FixedHeader, 44 | topic_name: TopicName, 45 | packet_identifier: Option, 46 | payload: Vec, 47 | } 48 | 49 | encodable_packet!(PublishPacket(topic_name, packet_identifier, payload)); 50 | 51 | impl PublishPacket { 52 | pub fn new>>(topic_name: TopicName, qos: QoSWithPacketIdentifier, payload: P) -> PublishPacket { 53 | let (qos, pkid) = qos.split(); 54 | let mut pk = PublishPacket { 55 | fixed_header: FixedHeader::new(PacketType::publish(qos), 0), 56 | topic_name, 57 | packet_identifier: pkid.map(PacketIdentifier), 58 | payload: payload.into(), 59 | 
}; 60 | pk.fix_header_remaining_len(); 61 | pk 62 | } 63 | 64 | pub fn set_dup(&mut self, dup: bool) { 65 | self.fixed_header 66 | .packet_type 67 | .update_flags(|flags| (flags & !(1 << 3)) | (dup as u8) << 3) 68 | } 69 | 70 | pub fn dup(&self) -> bool { 71 | self.fixed_header.packet_type.flags() & 0x80 != 0 72 | } 73 | 74 | pub fn set_qos(&mut self, qos: QoSWithPacketIdentifier) { 75 | let (qos, pkid) = qos.split(); 76 | self.fixed_header 77 | .packet_type 78 | .update_flags(|flags| (flags & !0b0110) | (qos as u8) << 1); 79 | self.packet_identifier = pkid.map(PacketIdentifier); 80 | self.fix_header_remaining_len(); 81 | } 82 | 83 | pub fn qos(&self) -> QoSWithPacketIdentifier { 84 | match self.packet_identifier { 85 | None => QoSWithPacketIdentifier::Level0, 86 | Some(pkid) => { 87 | let qos_val = (self.fixed_header.packet_type.flags() & 0b0110) >> 1; 88 | match qos_val { 89 | 1 => QoSWithPacketIdentifier::Level1(pkid.0), 90 | 2 => QoSWithPacketIdentifier::Level2(pkid.0), 91 | _ => unreachable!(), 92 | } 93 | } 94 | } 95 | } 96 | 97 | pub fn set_retain(&mut self, ret: bool) { 98 | self.fixed_header 99 | .packet_type 100 | .update_flags(|flags| (flags & !0b0001) | (ret as u8)) 101 | } 102 | 103 | pub fn retain(&self) -> bool { 104 | self.fixed_header.packet_type.flags() & 0b0001 != 0 105 | } 106 | 107 | pub fn set_topic_name(&mut self, topic_name: TopicName) { 108 | self.topic_name = topic_name; 109 | self.fix_header_remaining_len(); 110 | } 111 | 112 | pub fn topic_name(&self) -> &str { 113 | &self.topic_name[..] 
114 | } 115 | 116 | pub fn payload(&self) -> &[u8] { 117 | &self.payload 118 | } 119 | 120 | pub fn set_payload>>(&mut self, payload: P) { 121 | self.payload = payload.into(); 122 | self.fix_header_remaining_len(); 123 | } 124 | } 125 | 126 | impl DecodablePacket for PublishPacket { 127 | type DecodePacketError = std::convert::Infallible; 128 | 129 | fn decode_packet(reader: &mut R, fixed_header: FixedHeader) -> Result> { 130 | let topic_name = TopicName::decode(reader)?; 131 | 132 | let qos = (fixed_header.packet_type.flags() & 0b0110) >> 1; 133 | let packet_identifier = if qos > 0 { 134 | Some(PacketIdentifier::decode(reader)?) 135 | } else { 136 | None 137 | }; 138 | 139 | let vhead_len = 140 | topic_name.encoded_length() + packet_identifier.as_ref().map(|x| x.encoded_length()).unwrap_or(0); 141 | let payload_len = fixed_header.remaining_length - vhead_len; 142 | 143 | let payload = Vec::::decode_with(reader, Some(payload_len))?; 144 | 145 | Ok(PublishPacket { 146 | fixed_header, 147 | topic_name, 148 | packet_identifier, 149 | payload, 150 | }) 151 | } 152 | } 153 | 154 | /// `PUBLISH` packet by reference, for encoding only 155 | pub struct PublishPacketRef<'a> { 156 | fixed_header: FixedHeader, 157 | topic_name: &'a TopicNameRef, 158 | packet_identifier: Option, 159 | payload: &'a [u8], 160 | } 161 | 162 | impl<'a> PublishPacketRef<'a> { 163 | pub fn new(topic_name: &'a TopicNameRef, qos: QoSWithPacketIdentifier, payload: &'a [u8]) -> PublishPacketRef<'a> { 164 | let (qos, pkid) = qos.split(); 165 | 166 | let mut pk = PublishPacketRef { 167 | fixed_header: FixedHeader::new(PacketType::publish(qos), 0), 168 | topic_name, 169 | packet_identifier: pkid.map(PacketIdentifier), 170 | payload, 171 | }; 172 | pk.fix_header_remaining_len(); 173 | pk 174 | } 175 | 176 | fn fix_header_remaining_len(&mut self) { 177 | self.fixed_header.remaining_length = 178 | self.topic_name.encoded_length() + self.packet_identifier.encoded_length() + self.payload.encoded_length(); 179 | 
} 180 | } 181 | 182 | impl EncodablePacket for PublishPacketRef<'_> { 183 | fn fixed_header(&self) -> &FixedHeader { 184 | &self.fixed_header 185 | } 186 | 187 | fn encode_packet(&self, writer: &mut W) -> io::Result<()> { 188 | self.topic_name.encode(writer)?; 189 | self.packet_identifier.encode(writer)?; 190 | self.payload.encode(writer) 191 | } 192 | 193 | fn encoded_packet_length(&self) -> u32 { 194 | self.topic_name.encoded_length() + self.packet_identifier.encoded_length() + self.payload.encoded_length() 195 | } 196 | } 197 | 198 | #[cfg(test)] 199 | mod test { 200 | use super::*; 201 | 202 | use std::io::Cursor; 203 | 204 | use crate::topic_name::TopicName; 205 | use crate::{Decodable, Encodable}; 206 | 207 | #[test] 208 | fn test_publish_packet_basic() { 209 | let packet = PublishPacket::new( 210 | TopicName::new("a/b".to_owned()).unwrap(), 211 | QoSWithPacketIdentifier::Level2(10), 212 | b"Hello world!".to_vec(), 213 | ); 214 | 215 | let mut buf = Vec::new(); 216 | packet.encode(&mut buf).unwrap(); 217 | 218 | let mut decode_buf = Cursor::new(buf); 219 | let decoded = PublishPacket::decode(&mut decode_buf).unwrap(); 220 | 221 | assert_eq!(packet, decoded); 222 | } 223 | 224 | #[test] 225 | fn issue56() { 226 | let mut packet = PublishPacket::new( 227 | TopicName::new("topic").unwrap(), 228 | QoSWithPacketIdentifier::Level0, 229 | Vec::new(), 230 | ); 231 | assert_eq!(packet.fixed_header().remaining_length, 7); 232 | 233 | packet.set_qos(QoSWithPacketIdentifier::Level1(1)); 234 | assert_eq!(packet.fixed_header().remaining_length, 9); 235 | } 236 | } 237 | -------------------------------------------------------------------------------- /src/packet/pubrec.rs: -------------------------------------------------------------------------------- 1 | //! 
PUBREC 2 | 3 | use std::io::Read; 4 | 5 | use crate::control::variable_header::PacketIdentifier; 6 | use crate::control::{ControlType, FixedHeader, PacketType}; 7 | use crate::packet::{DecodablePacket, PacketError}; 8 | use crate::Decodable; 9 | 10 | /// `PUBREC` packet 11 | #[derive(Debug, Eq, PartialEq, Clone)] 12 | pub struct PubrecPacket { 13 | fixed_header: FixedHeader, 14 | packet_identifier: PacketIdentifier, 15 | } 16 | 17 | encodable_packet!(PubrecPacket(packet_identifier)); 18 | 19 | impl PubrecPacket { 20 | pub fn new(pkid: u16) -> PubrecPacket { 21 | PubrecPacket { 22 | fixed_header: FixedHeader::new(PacketType::with_default(ControlType::PublishReceived), 2), 23 | packet_identifier: PacketIdentifier(pkid), 24 | } 25 | } 26 | 27 | pub fn packet_identifier(&self) -> u16 { 28 | self.packet_identifier.0 29 | } 30 | 31 | pub fn set_packet_identifier(&mut self, pkid: u16) { 32 | self.packet_identifier.0 = pkid; 33 | } 34 | } 35 | 36 | impl DecodablePacket for PubrecPacket { 37 | type DecodePacketError = std::convert::Infallible; 38 | 39 | fn decode_packet(reader: &mut R, fixed_header: FixedHeader) -> Result> { 40 | let packet_identifier: PacketIdentifier = PacketIdentifier::decode(reader)?; 41 | Ok(PubrecPacket { 42 | fixed_header, 43 | packet_identifier, 44 | }) 45 | } 46 | } 47 | -------------------------------------------------------------------------------- /src/packet/pubrel.rs: -------------------------------------------------------------------------------- 1 | //! 
PUBREL 2 | 3 | use std::io::Read; 4 | 5 | use crate::control::variable_header::PacketIdentifier; 6 | use crate::control::{ControlType, FixedHeader, PacketType}; 7 | use crate::packet::{DecodablePacket, PacketError}; 8 | use crate::Decodable; 9 | 10 | /// `PUBREL` packet 11 | #[derive(Debug, Eq, PartialEq, Clone)] 12 | pub struct PubrelPacket { 13 | fixed_header: FixedHeader, 14 | packet_identifier: PacketIdentifier, 15 | } 16 | 17 | encodable_packet!(PubrelPacket(packet_identifier)); 18 | 19 | impl PubrelPacket { 20 | pub fn new(pkid: u16) -> PubrelPacket { 21 | PubrelPacket { 22 | fixed_header: FixedHeader::new(PacketType::with_default(ControlType::PublishRelease), 2), 23 | packet_identifier: PacketIdentifier(pkid), 24 | } 25 | } 26 | 27 | pub fn packet_identifier(&self) -> u16 { 28 | self.packet_identifier.0 29 | } 30 | 31 | pub fn set_packet_identifier(&mut self, pkid: u16) { 32 | self.packet_identifier.0 = pkid; 33 | } 34 | } 35 | 36 | impl DecodablePacket for PubrelPacket { 37 | type DecodePacketError = std::convert::Infallible; 38 | 39 | fn decode_packet(reader: &mut R, fixed_header: FixedHeader) -> Result> { 40 | let packet_identifier: PacketIdentifier = PacketIdentifier::decode(reader)?; 41 | Ok(PubrelPacket { 42 | fixed_header, 43 | packet_identifier, 44 | }) 45 | } 46 | } 47 | -------------------------------------------------------------------------------- /src/packet/suback.rs: -------------------------------------------------------------------------------- 1 | //! 
SUBACK 2 | 3 | use std::cmp::Ordering; 4 | 5 | use std::io::{self, Read, Write}; 6 | 7 | use byteorder::{ReadBytesExt, WriteBytesExt}; 8 | 9 | use crate::control::variable_header::PacketIdentifier; 10 | use crate::control::{ControlType, FixedHeader, PacketType}; 11 | use crate::packet::{DecodablePacket, PacketError}; 12 | use crate::qos::QualityOfService; 13 | use crate::{Decodable, Encodable}; 14 | 15 | /// Subscribe code 16 | #[repr(u8)] 17 | #[derive(Debug, Eq, PartialEq, Copy, Clone)] 18 | pub enum SubscribeReturnCode { 19 | MaximumQoSLevel0 = 0x00, 20 | MaximumQoSLevel1 = 0x01, 21 | MaximumQoSLevel2 = 0x02, 22 | Failure = 0x80, 23 | } 24 | 25 | impl PartialOrd for SubscribeReturnCode { 26 | fn partial_cmp(&self, other: &Self) -> Option { 27 | use self::SubscribeReturnCode::*; 28 | match (self, other) { 29 | (&Failure, _) => None, 30 | (_, &Failure) => None, 31 | (&MaximumQoSLevel0, &MaximumQoSLevel0) => Some(Ordering::Equal), 32 | (&MaximumQoSLevel1, &MaximumQoSLevel1) => Some(Ordering::Equal), 33 | (&MaximumQoSLevel2, &MaximumQoSLevel2) => Some(Ordering::Equal), 34 | (&MaximumQoSLevel0, _) => Some(Ordering::Less), 35 | (&MaximumQoSLevel1, &MaximumQoSLevel0) => Some(Ordering::Greater), 36 | (&MaximumQoSLevel1, &MaximumQoSLevel2) => Some(Ordering::Less), 37 | (&MaximumQoSLevel2, _) => Some(Ordering::Greater), 38 | } 39 | } 40 | } 41 | 42 | impl From for SubscribeReturnCode { 43 | fn from(qos: QualityOfService) -> Self { 44 | match qos { 45 | QualityOfService::Level0 => SubscribeReturnCode::MaximumQoSLevel0, 46 | QualityOfService::Level1 => SubscribeReturnCode::MaximumQoSLevel1, 47 | QualityOfService::Level2 => SubscribeReturnCode::MaximumQoSLevel2, 48 | } 49 | } 50 | } 51 | 52 | /// `SUBACK` packet 53 | #[derive(Debug, Eq, PartialEq, Clone)] 54 | pub struct SubackPacket { 55 | fixed_header: FixedHeader, 56 | packet_identifier: PacketIdentifier, 57 | payload: SubackPacketPayload, 58 | } 59 | 60 | encodable_packet!(SubackPacket(packet_identifier, payload)); 61 | 
62 | impl SubackPacket { 63 | pub fn new(pkid: u16, subscribes: Vec) -> SubackPacket { 64 | let mut pk = SubackPacket { 65 | fixed_header: FixedHeader::new(PacketType::with_default(ControlType::SubscribeAcknowledgement), 0), 66 | packet_identifier: PacketIdentifier(pkid), 67 | payload: SubackPacketPayload::new(subscribes), 68 | }; 69 | pk.fix_header_remaining_len(); 70 | pk 71 | } 72 | 73 | pub fn packet_identifier(&self) -> u16 { 74 | self.packet_identifier.0 75 | } 76 | 77 | pub fn set_packet_identifier(&mut self, pkid: u16) { 78 | self.packet_identifier.0 = pkid; 79 | } 80 | 81 | pub fn subscribes(&self) -> &[SubscribeReturnCode] { 82 | &self.payload.subscribes[..] 83 | } 84 | } 85 | 86 | impl DecodablePacket for SubackPacket { 87 | type DecodePacketError = SubackPacketError; 88 | 89 | fn decode_packet(reader: &mut R, fixed_header: FixedHeader) -> Result> { 90 | let packet_identifier = PacketIdentifier::decode(reader)?; 91 | let payload: SubackPacketPayload = SubackPacketPayload::decode_with( 92 | reader, 93 | fixed_header.remaining_length - packet_identifier.encoded_length(), 94 | ) 95 | .map_err(PacketError::PayloadError)?; 96 | Ok(SubackPacket { 97 | fixed_header, 98 | packet_identifier, 99 | payload, 100 | }) 101 | } 102 | } 103 | 104 | #[derive(Debug, Eq, PartialEq, Clone)] 105 | struct SubackPacketPayload { 106 | subscribes: Vec, 107 | } 108 | 109 | impl SubackPacketPayload { 110 | pub fn new(subs: Vec) -> SubackPacketPayload { 111 | SubackPacketPayload { subscribes: subs } 112 | } 113 | } 114 | 115 | impl Encodable for SubackPacketPayload { 116 | fn encode(&self, writer: &mut W) -> Result<(), io::Error> { 117 | for code in self.subscribes.iter() { 118 | writer.write_u8(*code as u8)?; 119 | } 120 | 121 | Ok(()) 122 | } 123 | 124 | fn encoded_length(&self) -> u32 { 125 | self.subscribes.len() as u32 126 | } 127 | } 128 | 129 | impl Decodable for SubackPacketPayload { 130 | type Error = SubackPacketError; 131 | type Cond = u32; 132 | 133 | fn 
decode_with(reader: &mut R, payload_len: u32) -> Result { 134 | let mut subs = Vec::new(); 135 | 136 | for _ in 0..payload_len { 137 | let retcode = match reader.read_u8()? { 138 | 0x00 => SubscribeReturnCode::MaximumQoSLevel0, 139 | 0x01 => SubscribeReturnCode::MaximumQoSLevel1, 140 | 0x02 => SubscribeReturnCode::MaximumQoSLevel2, 141 | 0x80 => SubscribeReturnCode::Failure, 142 | code => return Err(SubackPacketError::InvalidSubscribeReturnCode(code)), 143 | }; 144 | 145 | subs.push(retcode); 146 | } 147 | 148 | Ok(SubackPacketPayload::new(subs)) 149 | } 150 | } 151 | 152 | #[derive(Debug, thiserror::Error)] 153 | pub enum SubackPacketError { 154 | #[error(transparent)] 155 | IoError(#[from] io::Error), 156 | #[error("invalid subscribe return code {0}")] 157 | InvalidSubscribeReturnCode(u8), 158 | } 159 | -------------------------------------------------------------------------------- /src/packet/subscribe.rs: -------------------------------------------------------------------------------- 1 | //! 
SUBSCRIBE 2 | 3 | use std::io::{self, Read, Write}; 4 | use std::string::FromUtf8Error; 5 | 6 | use byteorder::{ReadBytesExt, WriteBytesExt}; 7 | 8 | use crate::control::variable_header::PacketIdentifier; 9 | use crate::control::{ControlType, FixedHeader, PacketType}; 10 | use crate::packet::{DecodablePacket, PacketError}; 11 | use crate::topic_filter::{TopicFilter, TopicFilterDecodeError, TopicFilterError}; 12 | use crate::{Decodable, Encodable, QualityOfService}; 13 | 14 | /// `SUBSCRIBE` packet 15 | #[derive(Debug, Eq, PartialEq, Clone)] 16 | pub struct SubscribePacket { 17 | fixed_header: FixedHeader, 18 | packet_identifier: PacketIdentifier, 19 | payload: SubscribePacketPayload, 20 | } 21 | 22 | encodable_packet!(SubscribePacket(packet_identifier, payload)); 23 | 24 | impl SubscribePacket { 25 | pub fn new(pkid: u16, subscribes: Vec<(TopicFilter, QualityOfService)>) -> SubscribePacket { 26 | let mut pk = SubscribePacket { 27 | fixed_header: FixedHeader::new(PacketType::with_default(ControlType::Subscribe), 0), 28 | packet_identifier: PacketIdentifier(pkid), 29 | payload: SubscribePacketPayload::new(subscribes), 30 | }; 31 | pk.fix_header_remaining_len(); 32 | pk 33 | } 34 | 35 | pub fn packet_identifier(&self) -> u16 { 36 | self.packet_identifier.0 37 | } 38 | 39 | pub fn set_packet_identifier(&mut self, pkid: u16) { 40 | self.packet_identifier.0 = pkid; 41 | } 42 | 43 | pub fn subscribes(&self) -> &[(TopicFilter, QualityOfService)] { 44 | &self.payload.subscribes[..] 
45 | } 46 | } 47 | 48 | impl DecodablePacket for SubscribePacket { 49 | type DecodePacketError = SubscribePacketError; 50 | 51 | fn decode_packet(reader: &mut R, fixed_header: FixedHeader) -> Result> { 52 | let packet_identifier: PacketIdentifier = PacketIdentifier::decode(reader)?; 53 | let payload: SubscribePacketPayload = SubscribePacketPayload::decode_with( 54 | reader, 55 | fixed_header.remaining_length - packet_identifier.encoded_length(), 56 | ) 57 | .map_err(PacketError::PayloadError)?; 58 | Ok(SubscribePacket { 59 | fixed_header, 60 | packet_identifier, 61 | payload, 62 | }) 63 | } 64 | } 65 | 66 | /// Payload of subscribe packet 67 | #[derive(Debug, Eq, PartialEq, Clone)] 68 | struct SubscribePacketPayload { 69 | subscribes: Vec<(TopicFilter, QualityOfService)>, 70 | } 71 | 72 | impl SubscribePacketPayload { 73 | pub fn new(subs: Vec<(TopicFilter, QualityOfService)>) -> SubscribePacketPayload { 74 | SubscribePacketPayload { subscribes: subs } 75 | } 76 | } 77 | 78 | impl Encodable for SubscribePacketPayload { 79 | fn encode(&self, writer: &mut W) -> Result<(), io::Error> { 80 | for &(ref filter, ref qos) in self.subscribes.iter() { 81 | filter.encode(writer)?; 82 | writer.write_u8(*qos as u8)?; 83 | } 84 | 85 | Ok(()) 86 | } 87 | 88 | fn encoded_length(&self) -> u32 { 89 | self.subscribes.iter().fold(0, |b, a| b + a.0.encoded_length() + 1) 90 | } 91 | } 92 | 93 | impl Decodable for SubscribePacketPayload { 94 | type Error = SubscribePacketError; 95 | type Cond = u32; 96 | 97 | fn decode_with( 98 | reader: &mut R, 99 | mut payload_len: u32, 100 | ) -> Result { 101 | let mut subs = Vec::new(); 102 | 103 | while payload_len > 0 { 104 | let filter = TopicFilter::decode(reader)?; 105 | let qos = match reader.read_u8()? 
{ 106 | 0 => QualityOfService::Level0, 107 | 1 => QualityOfService::Level1, 108 | 2 => QualityOfService::Level2, 109 | _ => return Err(SubscribePacketError::InvalidQualityOfService), 110 | }; 111 | 112 | payload_len -= filter.encoded_length() + 1; 113 | subs.push((filter, qos)); 114 | } 115 | 116 | Ok(SubscribePacketPayload::new(subs)) 117 | } 118 | } 119 | 120 | #[derive(Debug, thiserror::Error)] 121 | pub enum SubscribePacketError { 122 | #[error(transparent)] 123 | IoError(#[from] io::Error), 124 | #[error(transparent)] 125 | FromUtf8Error(#[from] FromUtf8Error), 126 | #[error("invalid quality of service")] 127 | InvalidQualityOfService, 128 | #[error(transparent)] 129 | TopicFilterError(#[from] TopicFilterError), 130 | } 131 | 132 | impl From for SubscribePacketError { 133 | fn from(e: TopicFilterDecodeError) -> Self { 134 | match e { 135 | TopicFilterDecodeError::IoError(e) => e.into(), 136 | TopicFilterDecodeError::InvalidTopicFilter(e) => e.into(), 137 | } 138 | } 139 | } 140 | -------------------------------------------------------------------------------- /src/packet/unsuback.rs: -------------------------------------------------------------------------------- 1 | //! 
UNSUBACK 2 | 3 | use std::io::Read; 4 | 5 | use crate::control::variable_header::PacketIdentifier; 6 | use crate::control::{ControlType, FixedHeader, PacketType}; 7 | use crate::packet::{DecodablePacket, PacketError}; 8 | use crate::Decodable; 9 | 10 | /// `UNSUBACK` packet 11 | #[derive(Debug, Eq, PartialEq, Clone)] 12 | pub struct UnsubackPacket { 13 | fixed_header: FixedHeader, 14 | packet_identifier: PacketIdentifier, 15 | } 16 | 17 | encodable_packet!(UnsubackPacket(packet_identifier)); 18 | 19 | impl UnsubackPacket { 20 | pub fn new(pkid: u16) -> UnsubackPacket { 21 | UnsubackPacket { 22 | fixed_header: FixedHeader::new(PacketType::with_default(ControlType::UnsubscribeAcknowledgement), 2), 23 | packet_identifier: PacketIdentifier(pkid), 24 | } 25 | } 26 | 27 | pub fn packet_identifier(&self) -> u16 { 28 | self.packet_identifier.0 29 | } 30 | 31 | pub fn set_packet_identifier(&mut self, pkid: u16) { 32 | self.packet_identifier.0 = pkid; 33 | } 34 | } 35 | 36 | impl DecodablePacket for UnsubackPacket { 37 | type DecodePacketError = std::convert::Infallible; 38 | 39 | fn decode_packet(reader: &mut R, fixed_header: FixedHeader) -> Result> { 40 | let packet_identifier: PacketIdentifier = PacketIdentifier::decode(reader)?; 41 | Ok(UnsubackPacket { 42 | fixed_header, 43 | packet_identifier, 44 | }) 45 | } 46 | } 47 | -------------------------------------------------------------------------------- /src/packet/unsubscribe.rs: -------------------------------------------------------------------------------- 1 | //! 
UNSUBSCRIBE 2 | 3 | use std::io::{self, Read, Write}; 4 | use std::string::FromUtf8Error; 5 | 6 | use crate::control::variable_header::PacketIdentifier; 7 | use crate::control::{ControlType, FixedHeader, PacketType}; 8 | use crate::packet::{DecodablePacket, PacketError}; 9 | use crate::topic_filter::{TopicFilter, TopicFilterDecodeError, TopicFilterError}; 10 | use crate::{Decodable, Encodable}; 11 | 12 | /// `UNSUBSCRIBE` packet 13 | #[derive(Debug, Eq, PartialEq, Clone)] 14 | pub struct UnsubscribePacket { 15 | fixed_header: FixedHeader, 16 | packet_identifier: PacketIdentifier, 17 | payload: UnsubscribePacketPayload, 18 | } 19 | 20 | encodable_packet!(UnsubscribePacket(packet_identifier, payload)); 21 | 22 | impl UnsubscribePacket { 23 | pub fn new(pkid: u16, subscribes: Vec) -> UnsubscribePacket { 24 | let mut pk = UnsubscribePacket { 25 | fixed_header: FixedHeader::new(PacketType::with_default(ControlType::Unsubscribe), 0), 26 | packet_identifier: PacketIdentifier(pkid), 27 | payload: UnsubscribePacketPayload::new(subscribes), 28 | }; 29 | pk.fix_header_remaining_len(); 30 | pk 31 | } 32 | 33 | pub fn packet_identifier(&self) -> u16 { 34 | self.packet_identifier.0 35 | } 36 | 37 | pub fn set_packet_identifier(&mut self, pkid: u16) { 38 | self.packet_identifier.0 = pkid; 39 | } 40 | 41 | pub fn subscribes(&self) -> &[TopicFilter] { 42 | &self.payload.subscribes[..] 
43 | } 44 | } 45 | 46 | impl DecodablePacket for UnsubscribePacket { 47 | type DecodePacketError = UnsubscribePacketError; 48 | 49 | fn decode_packet(reader: &mut R, fixed_header: FixedHeader) -> Result> { 50 | let packet_identifier: PacketIdentifier = PacketIdentifier::decode(reader)?; 51 | let payload: UnsubscribePacketPayload = UnsubscribePacketPayload::decode_with( 52 | reader, 53 | fixed_header.remaining_length - packet_identifier.encoded_length(), 54 | ) 55 | .map_err(PacketError::PayloadError)?; 56 | Ok(UnsubscribePacket { 57 | fixed_header, 58 | packet_identifier, 59 | payload, 60 | }) 61 | } 62 | } 63 | 64 | #[derive(Debug, Eq, PartialEq, Clone)] 65 | struct UnsubscribePacketPayload { 66 | subscribes: Vec, 67 | } 68 | 69 | impl UnsubscribePacketPayload { 70 | pub fn new(subs: Vec) -> UnsubscribePacketPayload { 71 | UnsubscribePacketPayload { subscribes: subs } 72 | } 73 | } 74 | 75 | impl Encodable for UnsubscribePacketPayload { 76 | fn encode(&self, writer: &mut W) -> Result<(), io::Error> { 77 | for filter in self.subscribes.iter() { 78 | filter.encode(writer)?; 79 | } 80 | 81 | Ok(()) 82 | } 83 | 84 | fn encoded_length(&self) -> u32 { 85 | self.subscribes.iter().fold(0, |b, a| b + a.encoded_length()) 86 | } 87 | } 88 | 89 | impl Decodable for UnsubscribePacketPayload { 90 | type Error = UnsubscribePacketError; 91 | type Cond = u32; 92 | 93 | fn decode_with( 94 | reader: &mut R, 95 | mut payload_len: u32, 96 | ) -> Result { 97 | let mut subs = Vec::new(); 98 | 99 | while payload_len > 0 { 100 | let filter = TopicFilter::decode(reader)?; 101 | payload_len -= filter.encoded_length(); 102 | subs.push(filter); 103 | } 104 | 105 | Ok(UnsubscribePacketPayload::new(subs)) 106 | } 107 | } 108 | 109 | #[derive(Debug, thiserror::Error)] 110 | #[error(transparent)] 111 | pub enum UnsubscribePacketError { 112 | IoError(#[from] io::Error), 113 | FromUtf8Error(#[from] FromUtf8Error), 114 | TopicFilterError(#[from] TopicFilterError), 115 | } 116 | 117 | impl From for 
UnsubscribePacketError { 118 | fn from(e: TopicFilterDecodeError) -> Self { 119 | match e { 120 | TopicFilterDecodeError::IoError(e) => e.into(), 121 | TopicFilterDecodeError::InvalidTopicFilter(e) => e.into(), 122 | } 123 | } 124 | } 125 | -------------------------------------------------------------------------------- /src/qos.rs: -------------------------------------------------------------------------------- 1 | //! QoS (Quality of Services) 2 | 3 | use crate::packet::publish::QoSWithPacketIdentifier; 4 | 5 | #[repr(u8)] 6 | #[derive(Debug, Eq, PartialEq, Ord, PartialOrd, Copy, Clone)] 7 | pub enum QualityOfService { 8 | Level0 = 0, 9 | Level1 = 1, 10 | Level2 = 2, 11 | } 12 | 13 | impl From for QualityOfService { 14 | fn from(qos: QoSWithPacketIdentifier) -> Self { 15 | match qos { 16 | QoSWithPacketIdentifier::Level0 => QualityOfService::Level0, 17 | QoSWithPacketIdentifier::Level1(_) => QualityOfService::Level1, 18 | QoSWithPacketIdentifier::Level2(_) => QualityOfService::Level2, 19 | } 20 | } 21 | } 22 | 23 | #[cfg(test)] 24 | mod test { 25 | use super::*; 26 | use std::cmp::min; 27 | 28 | #[test] 29 | fn min_qos() { 30 | let q1 = QoSWithPacketIdentifier::Level1(0).into(); 31 | let q2 = QualityOfService::Level2; 32 | assert_eq!(min(q1, q2), q1); 33 | 34 | let q1 = QoSWithPacketIdentifier::Level0.into(); 35 | let q2 = QualityOfService::Level2; 36 | assert_eq!(min(q1, q2), q1); 37 | 38 | let q1 = QoSWithPacketIdentifier::Level2(0).into(); 39 | let q2 = QualityOfService::Level1; 40 | assert_eq!(min(q1, q2), q2); 41 | } 42 | } 43 | -------------------------------------------------------------------------------- /src/topic_filter.rs: -------------------------------------------------------------------------------- 1 | //! 
Topic filter

use std::io::{self, Read, Write};
use std::ops::Deref;

use crate::topic_name::TopicNameRef;
use crate::{Decodable, Encodable};

/// Returns `true` if `topic` is not a valid topic filter.
///
/// A valid filter is non-empty and at most 65535 bytes long. `+` may appear only as
/// a whole level. `#` may appear only as a whole level, and only as the last level.
#[inline]
fn is_invalid_topic_filter(topic: &str) -> bool {
    if topic.is_empty() || topic.len() > 65535 {
        return true;
    }

    let mut found_hash = false;
    for member in topic.split('/') {
        // Nothing may follow a multi-level wildcard.
        if found_hash {
            return true;
        }

        match member {
            "#" => found_hash = true,
            "+" => {}
            _ => {
                // Wildcards embedded inside a level (e.g. "sport+") are not allowed.
                if member.contains(['#', '+']) {
                    return true;
                }
            }
        }
    }

    false
}

/// Topic filter
///
/// Syntax is defined in the [MQTT specification](http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc398718106).
///
/// ```rust
/// use mqtt::{TopicFilter, TopicNameRef};
///
/// let topic_filter = TopicFilter::new("sport/+/player1").unwrap();
/// let matcher = topic_filter.get_matcher();
/// assert!(matcher.is_match(TopicNameRef::new("sport/abc/player1").unwrap()));
/// ```
#[derive(Debug, Eq, PartialEq, Clone, Hash, Ord, PartialOrd)]
pub struct TopicFilter(String);

impl TopicFilter {
    /// Creates a new topic filter from string
    /// Return error if it is not a valid topic filter
    pub fn new<S: Into<String>>(topic: S) -> Result<TopicFilter, TopicFilterError> {
        let topic = topic.into();
        if is_invalid_topic_filter(&topic) {
            Err(TopicFilterError(topic))
        } else {
            Ok(TopicFilter(topic))
        }
    }

    /// Creates a new topic filter from string without validation
    ///
    /// # Safety
    ///
    /// Topic filters' syntax is defined in [MQTT specification](http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc398718106).
    /// Creating a filter from raw string may cause errors
    pub unsafe fn new_unchecked<S: Into<String>>(topic: S) -> TopicFilter {
        TopicFilter(topic.into())
    }
}

impl From<TopicFilter> for String {
    fn from(topic: TopicFilter) -> String {
        topic.0
    }
}

impl Encodable for TopicFilter {
    fn encode<W: Write>(&self, writer: &mut W) -> Result<(), io::Error> {
        (&self.0[..]).encode(writer)
    }

    fn encoded_length(&self) -> u32 {
        (&self.0[..]).encoded_length()
    }
}

impl Decodable for TopicFilter {
    type Error = TopicFilterDecodeError;
    type Cond = ();

    /// Decodes a string and validates it as a topic filter.
    fn decode_with<R: Read>(reader: &mut R, _rest: ()) -> Result<TopicFilter, TopicFilterDecodeError> {
        let topic_filter = String::decode(reader)?;
        Ok(TopicFilter::new(topic_filter)?)
    }
}

impl Deref for TopicFilter {
    type Target = TopicFilterRef;

    fn deref(&self) -> &TopicFilterRef {
        // SAFETY: `self.0` was validated by `TopicFilter::new`, or vouched for by the
        // caller of `TopicFilter::new_unchecked`.
        unsafe { TopicFilterRef::new_unchecked(&self.0) }
    }
}

/// Reference to a `TopicFilter`
#[derive(Debug, Eq, PartialEq, Hash, Ord, PartialOrd)]
#[repr(transparent)]
pub struct TopicFilterRef(str);

impl TopicFilterRef {
    /// Creates a new topic filter from string
    /// Return error if it is not a valid topic filter
    pub fn new<S: AsRef<str> + ?Sized>(topic: &S) -> Result<&TopicFilterRef, TopicFilterError> {
        let topic = topic.as_ref();
        if is_invalid_topic_filter(topic) {
            Err(TopicFilterError(topic.to_owned()))
        } else {
            // SAFETY: `TopicFilterRef` is `#[repr(transparent)]` over `str`, so the cast
            // keeps the layout, and the result borrows from `topic`.
            Ok(unsafe { &*(topic as *const str as *const TopicFilterRef) })
        }
    }

    /// Creates a new topic filter from string without validation
    ///
    /// # Safety
    ///
    /// Topic filters' syntax is defined in [MQTT specification](http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc398718106).
    /// Creating a filter from raw string may cause errors
    pub unsafe fn new_unchecked<S: AsRef<str> + ?Sized>(topic: &S) -> &TopicFilterRef {
        let topic = topic.as_ref();
        // SAFETY: layout-compatible cast, see `TopicFilterRef::new`.
        &*(topic as *const str as *const TopicFilterRef)
    }

    /// Get a matcher
    pub fn get_matcher(&self) -> TopicFilterMatcher<'_> {
        TopicFilterMatcher::new(&self.0)
    }
}

impl Deref for TopicFilterRef {
    type Target = str;

    fn deref(&self) -> &str {
        &self.0
    }
}

/// Error returned for a string that is not a valid topic filter.
#[derive(Debug, thiserror::Error)]
#[error("invalid topic filter ({0})")]
pub struct TopicFilterError(pub String);

/// Errors while parsing topic filters
#[derive(Debug, thiserror::Error)]
#[error(transparent)]
pub enum TopicFilterDecodeError {
    IoError(#[from] io::Error),
    InvalidTopicFilter(#[from] TopicFilterError),
}

/// Matcher for matching topic names with this filter
#[derive(Debug, Copy, Clone)]
pub struct TopicFilterMatcher<'a> {
    topic_filter: &'a str,
}

impl<'a> TopicFilterMatcher<'a> {
    fn new(filter: &'a str) -> TopicFilterMatcher<'a> {
        TopicFilterMatcher { topic_filter: filter }
    }

    /// Check if this filter can match the `topic_name`
    ///
    /// `+` matches exactly one level. `#` matches the remaining levels, including
    /// none, so `sport/#` also matches `sport`.
    pub fn is_match(&self, topic_name: &TopicNameRef) -> bool {
        let mut tn_itr = topic_name.split('/');
        let mut ft_itr = self.topic_filter.split('/');

        // `str::split` always yields at least one (possibly empty) segment.
        let first_ft = ft_itr.next().unwrap();
        let first_tn = tn_itr.next().unwrap();

        // The Server MUST NOT match Topic Filters starting with a wildcard character (# or +)
        // with Topic Names beginning with a $ character [MQTT-4.7.2-1].
        if first_tn.starts_with('$') {
            if first_tn != first_ft {
                return false;
            }
        } else {
            match first_ft {
                // Matches the whole topic
                "#" => return true,
                "+" => {}
                _ => {
                    if first_tn != first_ft {
                        return false;
                    }
                }
            }
        }

        loop {
            match (ft_itr.next(), tn_itr.next()) {
                // Multi-level wildcard swallows whatever levels remain, including none.
                (Some("#"), _) => return true,
                (Some("+"), Some(_)) => {}
                (Some(ft), Some(tn)) => {
                    if ft != tn {
                        return false;
                    }
                }
                // The filter has more (non-`#`) levels than the topic name.
                (Some(_), None) => return false,
                // The topic name has more levels than the filter.
                (None, Some(_)) => return false,
                (None, None) => return true,
            }
        }
    }
}

#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn topic_filter_validate() {
        let topic = "#".to_owned();
        TopicFilter::new(topic).unwrap();

        let topic = "sport/tennis/player1".to_owned();
        TopicFilter::new(topic).unwrap();

        let topic = "sport/tennis/player1/ranking".to_owned();
        TopicFilter::new(topic).unwrap();

        let topic = "sport/tennis/player1/#".to_owned();
        TopicFilter::new(topic).unwrap();

        let topic = "#".to_owned();
        TopicFilter::new(topic).unwrap();

        let topic = "sport/tennis/#".to_owned();
        TopicFilter::new(topic).unwrap();

        let topic = "sport/tennis#".to_owned();
        assert!(TopicFilter::new(topic).is_err());

        let topic = "sport/tennis/#/ranking".to_owned();
        assert!(TopicFilter::new(topic).is_err());

        let topic = "+".to_owned();
        TopicFilter::new(topic).unwrap();

        let topic = "+/tennis/#".to_owned();
        TopicFilter::new(topic).unwrap();

        let topic = "sport+".to_owned();
        assert!(TopicFilter::new(topic).is_err());

        let topic = "sport/+/player1".to_owned();
        TopicFilter::new(topic).unwrap();

        let topic = "+/+".to_owned();
        TopicFilter::new(topic).unwrap();

        let topic = "$SYS/#".to_owned();
        TopicFilter::new(topic).unwrap();

        let topic = "$SYS".to_owned();
        TopicFilter::new(topic).unwrap();
    }

    #[test]
    fn topic_filter_matcher() {
        let filter = TopicFilter::new("sport/#").unwrap();
        let matcher = filter.get_matcher();
        assert!(matcher.is_match(TopicNameRef::new("sport").unwrap()));

        let filter = TopicFilter::new("#").unwrap();
        let matcher = filter.get_matcher();
        assert!(matcher.is_match(TopicNameRef::new("sport").unwrap()));
        assert!(matcher.is_match(TopicNameRef::new("/").unwrap()));
        assert!(matcher.is_match(TopicNameRef::new("abc/def").unwrap()));
        assert!(!matcher.is_match(TopicNameRef::new("$SYS").unwrap()));
        assert!(!matcher.is_match(TopicNameRef::new("$SYS/abc").unwrap()));

        let filter = TopicFilter::new("+/monitor/Clients").unwrap();
        let matcher = filter.get_matcher();
        assert!(!matcher.is_match(TopicNameRef::new("$SYS/monitor/Clients").unwrap()));

        let filter = TopicFilter::new("$SYS/#").unwrap();
        let matcher = filter.get_matcher();
        assert!(matcher.is_match(TopicNameRef::new("$SYS/monitor/Clients").unwrap()));
        assert!(matcher.is_match(TopicNameRef::new("$SYS").unwrap()));

        let filter = TopicFilter::new("$SYS/monitor/+").unwrap();
        let matcher = filter.get_matcher();
        assert!(matcher.is_match(TopicNameRef::new("$SYS/monitor/Clients").unwrap()));
    }
}
-------------------------------------------------------------------------------- /src/topic_name.rs: -------------------------------------------------------------------------------- 1 | //! 
Topic name 2 | 3 | use std::{ 4 | borrow::{Borrow, BorrowMut}, 5 | io::{self, Read, Write}, 6 | ops::{Deref, DerefMut}, 7 | }; 8 | 9 | use crate::{Decodable, Encodable}; 10 | 11 | #[inline] 12 | fn is_invalid_topic_name(topic_name: &str) -> bool { 13 | topic_name.is_empty() || topic_name.as_bytes().len() > 65535 || topic_name.chars().any(|ch| ch == '#' || ch == '+') 14 | } 15 | 16 | /// Topic name 17 | /// 18 | /// 19 | #[derive(Debug, Eq, PartialEq, Clone, Hash, Ord, PartialOrd)] 20 | pub struct TopicName(String); 21 | 22 | impl TopicName { 23 | /// Creates a new topic name from string 24 | /// Return error if the string is not a valid topic name 25 | pub fn new>(topic_name: S) -> Result { 26 | let topic_name = topic_name.into(); 27 | if is_invalid_topic_name(&topic_name) { 28 | Err(TopicNameError(topic_name)) 29 | } else { 30 | Ok(TopicName(topic_name)) 31 | } 32 | } 33 | 34 | /// Creates a new topic name from string without validation 35 | /// 36 | /// # Safety 37 | /// 38 | /// Topic names' syntax is defined in [MQTT specification](http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc398718106). 
39 | /// Creating a name from raw string may cause errors 40 | pub unsafe fn new_unchecked(topic_name: String) -> TopicName { 41 | TopicName(topic_name) 42 | } 43 | } 44 | 45 | impl From for String { 46 | fn from(topic_name: TopicName) -> String { 47 | topic_name.0 48 | } 49 | } 50 | 51 | impl Deref for TopicName { 52 | type Target = TopicNameRef; 53 | 54 | fn deref(&self) -> &TopicNameRef { 55 | unsafe { TopicNameRef::new_unchecked(&self.0) } 56 | } 57 | } 58 | 59 | impl DerefMut for TopicName { 60 | fn deref_mut(&mut self) -> &mut Self::Target { 61 | unsafe { TopicNameRef::new_mut_unchecked(&mut self.0) } 62 | } 63 | } 64 | 65 | impl Borrow for TopicName { 66 | fn borrow(&self) -> &TopicNameRef { 67 | Deref::deref(self) 68 | } 69 | } 70 | 71 | impl BorrowMut for TopicName { 72 | fn borrow_mut(&mut self) -> &mut TopicNameRef { 73 | DerefMut::deref_mut(self) 74 | } 75 | } 76 | 77 | impl Encodable for TopicName { 78 | fn encode(&self, writer: &mut W) -> Result<(), io::Error> { 79 | (&self.0[..]).encode(writer) 80 | } 81 | 82 | fn encoded_length(&self) -> u32 { 83 | (&self.0[..]).encoded_length() 84 | } 85 | } 86 | 87 | impl Decodable for TopicName { 88 | type Error = TopicNameDecodeError; 89 | type Cond = (); 90 | 91 | fn decode_with(reader: &mut R, _rest: ()) -> Result { 92 | let topic_name = String::decode(reader)?; 93 | Ok(TopicName::new(topic_name)?) 
94 | } 95 | } 96 | 97 | #[derive(Debug, thiserror::Error)] 98 | #[error("invalid topic filter ({0})")] 99 | pub struct TopicNameError(pub String); 100 | 101 | /// Errors while parsing topic names 102 | #[derive(Debug, thiserror::Error)] 103 | #[error(transparent)] 104 | pub enum TopicNameDecodeError { 105 | IoError(#[from] io::Error), 106 | InvalidTopicName(#[from] TopicNameError), 107 | } 108 | 109 | /// Reference to a topic name 110 | #[derive(Debug, Eq, PartialEq, Hash, Ord, PartialOrd)] 111 | #[repr(transparent)] 112 | pub struct TopicNameRef(str); 113 | 114 | impl TopicNameRef { 115 | /// Creates a new topic name from string 116 | /// Return error if the string is not a valid topic name 117 | pub fn new + ?Sized>(topic_name: &S) -> Result<&TopicNameRef, TopicNameError> { 118 | let topic_name = topic_name.as_ref(); 119 | if is_invalid_topic_name(topic_name) { 120 | Err(TopicNameError(topic_name.to_owned())) 121 | } else { 122 | Ok(unsafe { &*(topic_name as *const str as *const TopicNameRef) }) 123 | } 124 | } 125 | 126 | /// Creates a new topic name from string 127 | /// Return error if the string is not a valid topic name 128 | pub fn new_mut + ?Sized>(topic_name: &mut S) -> Result<&mut TopicNameRef, TopicNameError> { 129 | let topic_name = topic_name.as_mut(); 130 | if is_invalid_topic_name(topic_name) { 131 | Err(TopicNameError(topic_name.to_owned())) 132 | } else { 133 | Ok(unsafe { &mut *(topic_name as *mut str as *mut TopicNameRef) }) 134 | } 135 | } 136 | 137 | /// Creates a new topic name from string without validation 138 | /// 139 | /// # Safety 140 | /// 141 | /// Topic names' syntax is defined in [MQTT specification](http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc398718106). 
142 | /// Creating a name from raw string may cause errors 143 | pub unsafe fn new_unchecked + ?Sized>(topic_name: &S) -> &TopicNameRef { 144 | let topic_name = topic_name.as_ref(); 145 | &*(topic_name as *const str as *const TopicNameRef) 146 | } 147 | 148 | /// Creates a new topic name from string without validation 149 | /// 150 | /// # Safety 151 | /// 152 | /// Topic names' syntax is defined in [MQTT specification](http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc398718106). 153 | /// Creating a name from raw string may cause errors 154 | pub unsafe fn new_mut_unchecked + ?Sized>(topic_name: &mut S) -> &mut TopicNameRef { 155 | let topic_name = topic_name.as_mut(); 156 | &mut *(topic_name as *mut str as *mut TopicNameRef) 157 | } 158 | 159 | /// Check if this topic name is only for server. 160 | /// 161 | /// Topic names that beginning with a '$' character are reserved for servers 162 | pub fn is_server_specific(&self) -> bool { 163 | self.0.starts_with('$') 164 | } 165 | } 166 | 167 | impl Deref for TopicNameRef { 168 | type Target = str; 169 | 170 | fn deref(&self) -> &str { 171 | &self.0 172 | } 173 | } 174 | 175 | impl ToOwned for TopicNameRef { 176 | type Owned = TopicName; 177 | 178 | fn to_owned(&self) -> Self::Owned { 179 | TopicName(self.0.to_owned()) 180 | } 181 | } 182 | 183 | impl Encodable for TopicNameRef { 184 | fn encode(&self, writer: &mut W) -> Result<(), io::Error> { 185 | (&self.0[..]).encode(writer) 186 | } 187 | 188 | fn encoded_length(&self) -> u32 { 189 | (&self.0[..]).encoded_length() 190 | } 191 | } 192 | 193 | #[cfg(test)] 194 | mod test { 195 | use super::*; 196 | 197 | #[test] 198 | fn topic_name_sys() { 199 | let topic_name = "$SYS".to_owned(); 200 | TopicName::new(topic_name).unwrap(); 201 | 202 | let topic_name = "$SYS/broker/connection/test.cosm-energy/state".to_owned(); 203 | TopicName::new(topic_name).unwrap(); 204 | } 205 | 206 | #[test] 207 | fn topic_name_slash() { 208 | TopicName::new("/").unwrap(); 
209 | } 210 | 211 | #[test] 212 | fn topic_name_basic() { 213 | TopicName::new("/finance").unwrap(); 214 | TopicName::new("/finance//def").unwrap(); 215 | } 216 | } 217 | --------------------------------------------------------------------------------