├── .gitignore ├── Cargo.toml ├── LICENSE ├── README.md └── src ├── codec ├── mod.rs ├── string.rs └── values.rs ├── error.rs ├── fixed_header ├── mod.rs ├── packet_flags.rs └── packet_type.rs ├── lib.rs ├── packet.rs ├── payload ├── connect.rs ├── mod.rs ├── suback.rs └── subscribe.rs ├── qos.rs ├── status.rs └── variable_header ├── connack.rs ├── connect.rs ├── mod.rs ├── packet_identifier.rs └── publish.rs /.gitignore: -------------------------------------------------------------------------------- 1 | /target 2 | **/*.rs.bk 3 | Cargo.lock 4 | .vscode/ -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "embedded-mqtt" 3 | description = "no_std encoder/decoder for MQTT 3.1.1 protocol packets for embedded devices." 4 | version = "0.1.0" 5 | authors = ["Keith Duncan "] 6 | repository = "https://github.com/keithduncan/embedded-mqtt" 7 | license = "MIT" 8 | keywords = ["mqtt"] 9 | categories = ["embedded", "network-programming", "no-std"] 10 | 11 | [features] 12 | std = ["byteorder/std"] 13 | 14 | [dependencies] 15 | byteorder = { version = "1.2", default-features = false } 16 | bitfield = "0.13.1" 17 | 18 | [dev-dependencies] 19 | rayon = "1.0" 20 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) Rajasekharan Vengalil 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following 
conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | 23 | 24 | MIT License 25 | 26 | Copyright (c) Keith Duncan 27 | 28 | Permission is hereby granted, free of charge, to any person obtaining a copy 29 | of this software and associated documentation files (the "Software"), to deal 30 | in the Software without restriction, including without limitation the rights 31 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 32 | copies of the Software, and to permit persons to whom the Software is 33 | furnished to do so, subject to the following conditions: 34 | 35 | The above copyright notice and this permission notice shall be included in all 36 | copies or substantial portions of the Software. 37 | 38 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 39 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 40 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 41 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 42 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 43 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 44 | SOFTWARE. 
45 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # embedded-mqtt 2 | 3 | An encoder/decoder for the MQTT 3.1.1 protocol in pure Rust. 4 | 5 | `no_std` by default, a `std` feature enables extensions. 6 | 7 | Originally forked from https://github.com/avranju/mqttparse and 8 | renamed when I added encode support, the original license and 9 | copyright is preserved in [LICENSE](LICENSE). 10 | -------------------------------------------------------------------------------- /src/codec/mod.rs: -------------------------------------------------------------------------------- 1 | use crate::{ 2 | error::{DecodeError, EncodeError}, 3 | status::Status, 4 | }; 5 | 6 | pub mod string; 7 | pub mod values; 8 | 9 | pub trait Decodable<'a> 10 | where 11 | Self: core::marker::Sized, 12 | { 13 | fn decode(bytes: &'a [u8]) -> Result, DecodeError>; 14 | } 15 | 16 | pub trait Encodable { 17 | fn encoded_len(&self) -> usize; 18 | fn encode(&self, bytes: &mut [u8]) -> Result; 19 | } 20 | -------------------------------------------------------------------------------- /src/codec/string.rs: -------------------------------------------------------------------------------- 1 | use core::{cmp::min, convert::TryFrom, result::Result, str}; 2 | 3 | use crate::{ 4 | error::{DecodeError, EncodeError}, 5 | status::Status, 6 | }; 7 | 8 | use super::{values, Decodable, Encodable}; 9 | 10 | impl<'buf> Decodable<'buf> for &'buf str { 11 | fn decode(bytes: &'buf [u8]) -> Result, DecodeError> { 12 | parse_string(bytes) 13 | } 14 | } 15 | 16 | impl Encodable for str { 17 | fn encoded_len(&self) -> usize { 18 | 2 + self.len() 19 | } 20 | 21 | fn encode(&self, bytes: &mut [u8]) -> Result { 22 | encode_string(self, bytes) 23 | } 24 | } 25 | 26 | pub fn parse_string(bytes: &[u8]) -> Result, DecodeError> { 27 | let offset = 0; 28 | 29 | let (offset, string_len) = 
read!(values::parse_u16, bytes, offset); 30 | 31 | let available = bytes.len() - offset; 32 | 33 | let needed = string_len as usize - min(available, string_len as usize); 34 | if needed > 0 { 35 | return Ok(Status::Partial(needed)); 36 | } 37 | 38 | let val = if string_len > 0 { 39 | // Rust string slices are never in the code point range 0xD800 and 40 | // 0xDFFF which takes care of requirement MQTT-1.5.3-1. str::from_utf8 41 | // will fail if those code points are found in "bytes". 42 | // 43 | // Rust utf-8 decoding also takes care of MQTT-1.5.3-3. U+FEFF does not 44 | // get ignored/stripped off. 45 | str::from_utf8(&bytes[2..(2 + string_len) as usize])? 46 | } else { 47 | "" 48 | }; 49 | 50 | // Requirement MQTT-1.5.3-2 requires that there be no U+0000 code points 51 | // in the string. 52 | if val.chars().any(|ch| ch == '\u{0000}') { 53 | return Err(DecodeError::Utf8); 54 | } 55 | 56 | Ok(Status::Complete(((2 + string_len) as usize, val))) 57 | } 58 | 59 | pub fn encode_string(string: &str, bytes: &mut [u8]) -> Result { 60 | let size = match u16::try_from(string.len()) { 61 | Err(_) => return Err(EncodeError::ValueTooBig), 62 | Ok(s) => s, 63 | }; 64 | 65 | if bytes.len() < (2 + size) as usize { 66 | return Err(EncodeError::OutOfSpace); 67 | } 68 | 69 | values::encode_u16(size, &mut bytes[0..2])?; 70 | (&mut bytes[2..2 + size as usize]).copy_from_slice(string.as_bytes()); 71 | 72 | Ok(2 + size as usize) 73 | } 74 | 75 | #[cfg(test)] 76 | mod tests { 77 | use super::*; 78 | use std::{ 79 | format, 80 | io::{Cursor, Write}, 81 | vec::Vec, 82 | }; 83 | 84 | use byteorder::{BigEndian, ByteOrder}; 85 | 86 | use byteorder::WriteBytesExt; 87 | 88 | #[test] 89 | fn small_buffer() { 90 | assert_eq!(Ok(Status::Partial(2)), parse_string(&[])); 91 | assert_eq!(Ok(Status::Partial(1)), parse_string(&[0])); 92 | 93 | let mut buf = [0u8; 2]; 94 | BigEndian::write_u16(&mut buf, 16); 95 | assert_eq!(Ok(Status::Partial(16)), parse_string(&buf)); 96 | } 97 | 98 | #[test] 99 | fn 
empty_str() { 100 | let mut buf = [0u8; 2]; 101 | BigEndian::write_u16(&mut buf, 0); 102 | assert_eq!(Ok(Status::Complete((2, ""))), parse_string(&buf)); 103 | } 104 | 105 | #[test] 106 | fn parse_str() { 107 | let inp = "don't panic!"; 108 | let mut buf = Cursor::new(Vec::new()); 109 | buf.write_u16::(inp.len() as u16).unwrap(); 110 | buf.write(inp.as_bytes()).unwrap(); 111 | assert_eq!( 112 | Status::Complete((14, inp)), 113 | parse_string(buf.get_ref().as_ref()).unwrap() 114 | ); 115 | } 116 | 117 | #[test] 118 | fn invalid_utf8() { 119 | let inp = [0, 159, 146, 150]; 120 | let mut buf = Cursor::new(Vec::new()); 121 | buf.write_u16::(inp.len() as u16).unwrap(); 122 | buf.write(&inp).unwrap(); 123 | assert_eq!(Err(DecodeError::Utf8), parse_string(buf.get_ref().as_ref())); 124 | } 125 | 126 | #[test] 127 | fn null_utf8() { 128 | let inp = format!("don't {} panic!", '\u{0000}'); 129 | let mut buf = Cursor::new(Vec::new()); 130 | buf.write_u16::(inp.len() as u16).unwrap(); 131 | buf.write(inp.as_bytes()).unwrap(); 132 | assert_eq!(Err(DecodeError::Utf8), parse_string(buf.get_ref().as_ref())); 133 | } 134 | 135 | #[test] 136 | fn encode() { 137 | let mut buf = [0u8; 3]; 138 | let result = encode_string("a", &mut buf[0..3]); 139 | assert_eq!(result, Ok(3)); 140 | assert_eq!(buf, [0b00000000, 0b00000001, 0x61]); 141 | } 142 | } 143 | -------------------------------------------------------------------------------- /src/codec/values.rs: -------------------------------------------------------------------------------- 1 | use core::{cmp::min, convert::TryFrom, result::Result}; 2 | 3 | use crate::{ 4 | error::{DecodeError, EncodeError}, 5 | status::Status, 6 | }; 7 | 8 | use super::{Decodable, Encodable}; 9 | 10 | use byteorder::{BigEndian, ByteOrder}; 11 | 12 | pub fn parse_u8(bytes: &[u8]) -> Result, DecodeError> { 13 | if bytes.is_empty() { 14 | return Ok(Status::Partial(1)); 15 | } 16 | 17 | Ok(Status::Complete((1, bytes[0]))) 18 | } 19 | 20 | pub fn encode_u8(value: 
u8, bytes: &mut [u8]) -> Result { 21 | if bytes.is_empty() { 22 | return Err(EncodeError::OutOfSpace); 23 | } 24 | 25 | bytes[0] = value; 26 | Ok(1) 27 | } 28 | 29 | pub fn parse_u16(bytes: &[u8]) -> Result, DecodeError> { 30 | if bytes.len() < 2 { 31 | return Ok(Status::Partial(2 - bytes.len())); 32 | } 33 | 34 | Ok(Status::Complete((2, BigEndian::read_u16(&bytes[0..2])))) 35 | } 36 | 37 | pub fn encode_u16(value: u16, bytes: &mut [u8]) -> Result { 38 | if bytes.len() < 2 { 39 | return Err(EncodeError::OutOfSpace); 40 | } 41 | 42 | BigEndian::write_u16(&mut bytes[0..2], value); 43 | Ok(2) 44 | } 45 | 46 | impl<'buf> Decodable<'buf> for &'buf [u8] { 47 | fn decode(bytes: &'buf [u8]) -> Result, DecodeError> { 48 | parse_bytes(bytes) 49 | } 50 | } 51 | 52 | impl Encodable for [u8] { 53 | fn encoded_len(&self) -> usize { 54 | 2 + self.len() 55 | } 56 | 57 | fn encode(&self, bytes: &mut [u8]) -> Result { 58 | encode_bytes(self, bytes) 59 | } 60 | } 61 | 62 | pub fn parse_bytes(bytes: &[u8]) -> Result, DecodeError> { 63 | let offset = 0; 64 | let (offset, len) = read!(parse_u16, bytes, offset); 65 | 66 | let available = bytes.len() - offset; 67 | let needed = len as usize - min(available, len as usize); 68 | if needed > 0 { 69 | return Ok(Status::Partial(needed)); 70 | } 71 | let payload = &bytes[offset..offset + len as usize]; 72 | 73 | Ok(Status::Complete((offset + len as usize, payload))) 74 | } 75 | 76 | pub fn encode_bytes(value: &[u8], bytes: &mut [u8]) -> Result { 77 | let size = match u16::try_from(value.len()) { 78 | Err(_) => return Err(EncodeError::ValueTooBig), 79 | Ok(s) => s, 80 | }; 81 | 82 | let offset = encode_u16(size, bytes)?; 83 | 84 | let payload_size = value.len(); 85 | if offset + payload_size > bytes.len() { 86 | return Err(EncodeError::OutOfSpace); 87 | } 88 | 89 | (&mut bytes[offset..offset + payload_size as usize]).copy_from_slice(value); 90 | 91 | Ok(offset + payload_size) 92 | } 93 | 
-------------------------------------------------------------------------------- /src/error.rs: -------------------------------------------------------------------------------- 1 | use core::{convert::From, fmt, str::Utf8Error}; 2 | 3 | use crate::qos; 4 | 5 | #[derive(Copy, Clone, PartialEq, Eq, Debug)] 6 | pub enum DecodeError { 7 | /// Invalid packet type in header 8 | PacketType, 9 | /// Invalid packet type flag in header 10 | PacketFlag, 11 | /// Malformed remaining length in header 12 | RemainingLength, 13 | /// Invalid buffer length 14 | InvalidLength, 15 | /// Invalid UTF-8 encoding 16 | Utf8, 17 | /// Invalid QoS value 18 | InvalidQoS(qos::Error), 19 | /// Invalid protocol level 20 | InvalidProtocolLevel, 21 | /// Invalid connect flag value 22 | InvalidConnectFlag, 23 | /// Invalid Connack flag 24 | InvalidConnackFlag, 25 | /// Invalid Connack Return Code 26 | InvalidConnackReturnCode, 27 | /// Invalid Suback Return Code 28 | InvalidSubackReturnCode, 29 | } 30 | 31 | impl DecodeError { 32 | fn desc(&self) -> &'static str { 33 | match *self { 34 | DecodeError::PacketType => "invalid packet type in header", 35 | DecodeError::PacketFlag => "invalid packet type flag in header", 36 | DecodeError::RemainingLength => "malformed remaining length in header", 37 | DecodeError::InvalidLength => "invalid buffer length", 38 | DecodeError::Utf8 => "invalid utf-8 encoding", 39 | DecodeError::InvalidQoS(_) => "invalid QoS bit pattern", 40 | DecodeError::InvalidProtocolLevel => "invalid protocol level", 41 | DecodeError::InvalidConnectFlag => "invalid connect flag value", 42 | DecodeError::InvalidConnackFlag => "invalid connack flag value", 43 | DecodeError::InvalidConnackReturnCode => "invalid connack return code", 44 | DecodeError::InvalidSubackReturnCode => "invalid suback return code", 45 | } 46 | } 47 | } 48 | 49 | impl fmt::Display for DecodeError { 50 | fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { 51 | f.write_str(self.desc()) 52 | } 53 | } 54 | 55 | 
#[cfg(feature = "std")] 56 | impl ::std::error::Error for DecodeError { 57 | fn description(&self) -> &str { 58 | self.desc() 59 | } 60 | } 61 | 62 | impl From for DecodeError { 63 | fn from(_: Utf8Error) -> Self { 64 | DecodeError::Utf8 65 | } 66 | } 67 | 68 | impl From for DecodeError { 69 | fn from(err: qos::Error) -> Self { 70 | DecodeError::InvalidQoS(err) 71 | } 72 | } 73 | 74 | #[derive(Copy, Clone, PartialEq, Eq, Debug)] 75 | pub enum EncodeError { 76 | /// Not enough space in buffer to encode 77 | OutOfSpace, 78 | /// Value too big for field 79 | ValueTooBig, 80 | } 81 | 82 | impl EncodeError { 83 | fn desc(&self) -> &'static str { 84 | match *self { 85 | EncodeError::OutOfSpace => "not enough space in encode buffer", 86 | EncodeError::ValueTooBig => "value too big to ever be encoded", 87 | } 88 | } 89 | } 90 | 91 | impl fmt::Display for EncodeError { 92 | fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { 93 | f.write_str(self.desc()) 94 | } 95 | } 96 | 97 | #[cfg(feature = "std")] 98 | impl ::std::error::Error for EncodeError { 99 | fn description(&self) -> &str { 100 | self.desc() 101 | } 102 | } 103 | 104 | impl From for EncodeError { 105 | fn from(_err: core::num::TryFromIntError) -> EncodeError { 106 | EncodeError::ValueTooBig 107 | } 108 | } 109 | -------------------------------------------------------------------------------- /src/fixed_header/mod.rs: -------------------------------------------------------------------------------- 1 | use core::result::Result; 2 | 3 | use crate::{ 4 | codec::{self, Decodable, Encodable}, 5 | error::{DecodeError, EncodeError}, 6 | status::Status, 7 | }; 8 | 9 | mod packet_flags; 10 | mod packet_type; 11 | 12 | pub use self::{ 13 | packet_flags::{PacketFlags, PublishFlags}, 14 | packet_type::PacketType, 15 | }; 16 | 17 | #[derive(Copy, Clone, PartialEq, Eq, Debug)] 18 | pub struct FixedHeader { 19 | r#type: PacketType, 20 | flags: PacketFlags, 21 | len: u32, 22 | } 23 | 24 | impl FixedHeader { 25 | pub fn 
new(r#type: PacketType, flags: PacketFlags, len: u32) -> Self { 26 | FixedHeader { r#type, flags, len } 27 | } 28 | 29 | pub fn r#type(&self) -> PacketType { 30 | self.r#type 31 | } 32 | 33 | pub fn flags(&self) -> PacketFlags { 34 | self.flags 35 | } 36 | 37 | pub fn len(&self) -> u32 { 38 | self.len 39 | } 40 | 41 | pub fn is_empty(&self) -> bool { 42 | self.len == 0 43 | } 44 | } 45 | 46 | impl<'buf> Decodable<'buf> for FixedHeader { 47 | fn decode(bytes: &'buf [u8]) -> Result, DecodeError> { 48 | // "bytes" must be at least 2 bytes long to be a valid fixed header 49 | if bytes.len() < 2 { 50 | return Ok(Status::Partial(2 - bytes.len())); 51 | } 52 | 53 | let (r#type, flags) = parse_packet_type(bytes[0])?; 54 | 55 | let offset = 1; 56 | 57 | let (offset, len) = read!(parse_remaining_length, bytes, offset); 58 | 59 | Ok(Status::Complete((offset, Self { r#type, flags, len }))) 60 | } 61 | } 62 | 63 | impl Encodable for FixedHeader { 64 | fn encoded_len(&self) -> usize { 65 | let mut buf = [0u8; 4]; 66 | let u = encode_remaining_length(self.len, &mut buf); 67 | 1 + u 68 | } 69 | 70 | fn encode(&self, bytes: &mut [u8]) -> Result { 71 | let mut offset = 0; 72 | offset += codec::values::encode_u8( 73 | encode_packet_type(self.r#type, self.flags), 74 | &mut bytes[offset..], 75 | )?; 76 | 77 | let mut remaining_length = [0u8; 4]; 78 | let o = encode_remaining_length(self.len, &mut remaining_length); 79 | (&mut bytes[offset..offset + o]).copy_from_slice(&remaining_length[..o]); 80 | offset += o; 81 | 82 | Ok(offset) 83 | } 84 | } 85 | 86 | fn parse_remaining_length(bytes: &[u8]) -> Result, DecodeError> { 87 | let mut multiplier = 1; 88 | let mut value = 0u32; 89 | let mut index = 0; 90 | 91 | loop { 92 | if multiplier > 128 * 128 * 128 { 93 | return Err(DecodeError::RemainingLength); 94 | } 95 | 96 | if index >= bytes.len() { 97 | return Ok(Status::Partial(1)); 98 | } 99 | 100 | let byte = bytes[index]; 101 | index += 1; 102 | 103 | value += (byte & 0b01111111) as u32 * 
multiplier; 104 | 105 | multiplier *= 128; 106 | 107 | if byte & 128 == 0 { 108 | return Ok(Status::Complete((index, value))); 109 | } 110 | } 111 | } 112 | 113 | fn encode_remaining_length(mut len: u32, buf: &mut [u8; 4]) -> usize { 114 | let mut index = 0; 115 | loop { 116 | let mut byte = len as u8 % 128; 117 | len /= 128; 118 | if len > 0 { 119 | byte |= 128; 120 | } 121 | buf[index] = byte; 122 | index += 1; 123 | 124 | if len == 0 { 125 | break index; 126 | } 127 | } 128 | } 129 | 130 | fn parse_packet_type(inp: u8) -> Result<(PacketType, PacketFlags), DecodeError> { 131 | // high 4 bits are the packet type 132 | let packet_type = match (inp & 0xF0) >> 4 { 133 | 1 => PacketType::Connect, 134 | 2 => PacketType::Connack, 135 | 3 => PacketType::Publish, 136 | 4 => PacketType::Puback, 137 | 5 => PacketType::Pubrec, 138 | 6 => PacketType::Pubrel, 139 | 7 => PacketType::Pubcomp, 140 | 8 => PacketType::Subscribe, 141 | 9 => PacketType::Suback, 142 | 10 => PacketType::Unsubscribe, 143 | 11 => PacketType::Unsuback, 144 | 12 => PacketType::Pingreq, 145 | 13 => PacketType::Pingresp, 146 | 14 => PacketType::Disconnect, 147 | _ => return Err(DecodeError::PacketType), 148 | }; 149 | 150 | // low 4 bits represent control flags 151 | let flags = PacketFlags(inp & 0xF); 152 | 153 | validate_flag(packet_type, flags) 154 | } 155 | 156 | fn encode_packet_type(r#type: PacketType, flags: PacketFlags) -> u8 { 157 | let packet_type: u8 = match r#type { 158 | PacketType::Connect => 1, 159 | PacketType::Connack => 2, 160 | PacketType::Publish => 3, 161 | PacketType::Puback => 4, 162 | PacketType::Pubrec => 5, 163 | PacketType::Pubrel => 6, 164 | PacketType::Pubcomp => 7, 165 | PacketType::Subscribe => 8, 166 | PacketType::Suback => 9, 167 | PacketType::Unsubscribe => 10, 168 | PacketType::Unsuback => 11, 169 | PacketType::Pingreq => 12, 170 | PacketType::Pingresp => 13, 171 | PacketType::Disconnect => 14, 172 | }; 173 | 174 | (packet_type << 4) | flags.0 175 | } 176 | 177 | fn 
validate_flag( 178 | packet_type: PacketType, 179 | flags: PacketFlags, 180 | ) -> Result<(PacketType, PacketFlags), DecodeError> { 181 | // for the following packet types, the control flag MUST be zero 182 | const ZERO_TYPES: &[PacketType] = &[ 183 | PacketType::Connect, 184 | PacketType::Connack, 185 | PacketType::Puback, 186 | PacketType::Pubrec, 187 | PacketType::Pubcomp, 188 | PacketType::Suback, 189 | PacketType::Unsuback, 190 | PacketType::Pingreq, 191 | PacketType::Pingresp, 192 | PacketType::Disconnect, 193 | ]; 194 | // for the following packet types, the control flag MUST be 0b0010 195 | const ONE_TYPES: &[PacketType] = &[ 196 | PacketType::Pubrel, 197 | PacketType::Subscribe, 198 | PacketType::Unsubscribe, 199 | ]; 200 | 201 | validate_flag_val(packet_type, flags, ZERO_TYPES, PacketFlags(0b0000)) 202 | .and_then(|_| validate_flag_val(packet_type, flags, ONE_TYPES, PacketFlags(0b0010))) 203 | } 204 | 205 | fn validate_flag_val( 206 | packet_type: PacketType, 207 | flags: PacketFlags, 208 | types: &[PacketType], 209 | expected_flags: PacketFlags, 210 | ) -> Result<(PacketType, PacketFlags), DecodeError> { 211 | if types.iter().any(|&v| v == packet_type) && flags != expected_flags { 212 | return Err(DecodeError::PacketFlag); 213 | } 214 | 215 | Ok((packet_type, flags)) 216 | } 217 | 218 | #[cfg(test)] 219 | mod tests { 220 | use super::*; 221 | use rayon::prelude::*; 222 | use std::format; 223 | 224 | #[test] 225 | fn packet_type() { 226 | let mut inputs: [([u8; 1], PacketType); 14] = [ 227 | ([01 << 4 | 0b0000], PacketType::Connect), 228 | ([02 << 4 | 0b0000], PacketType::Connack), 229 | ([03 << 4 | 0b0000], PacketType::Publish), 230 | ([04 << 4 | 0b0000], PacketType::Puback), 231 | ([05 << 4 | 0b0000], PacketType::Pubrec), 232 | ([06 << 4 | 0b0010], PacketType::Pubrel), 233 | ([07 << 4 | 0b0000], PacketType::Pubcomp), 234 | ([08 << 4 | 0b0010], PacketType::Subscribe), 235 | ([09 << 4 | 0b0000], PacketType::Suback), 236 | ([10 << 4 | 0b0010], 
PacketType::Unsubscribe), 237 | ([11 << 4 | 0b0000], PacketType::Unsuback), 238 | ([12 << 4 | 0b0000], PacketType::Pingreq), 239 | ([13 << 4 | 0b0000], PacketType::Pingresp), 240 | ([14 << 4 | 0b0000], PacketType::Disconnect), 241 | ]; 242 | 243 | for (buf, expected_type) in inputs.iter_mut() { 244 | let expected_flag = PacketFlags(buf[0] & 0xF); 245 | let (packet_type, flag) = parse_packet_type(buf[0]).unwrap(); 246 | assert_eq!(packet_type, *expected_type); 247 | assert_eq!(flag, expected_flag); 248 | } 249 | } 250 | 251 | #[test] 252 | fn bad_packet_type() { 253 | let result = parse_packet_type(15 << 4); 254 | assert_eq!(result, Err(DecodeError::PacketType)); 255 | } 256 | 257 | #[test] 258 | fn bad_zero_flags() { 259 | let mut inputs: [([u8; 1], PacketType); 10] = [ 260 | ([01 << 4 | 1], PacketType::Connect), 261 | ([02 << 4 | 1], PacketType::Connack), 262 | ([04 << 4 | 1], PacketType::Puback), 263 | ([05 << 4 | 1], PacketType::Pubrec), 264 | ([07 << 4 | 1], PacketType::Pubcomp), 265 | ([09 << 4 | 1], PacketType::Suback), 266 | ([11 << 4 | 1], PacketType::Unsuback), 267 | ([12 << 4 | 1], PacketType::Pingreq), 268 | ([13 << 4 | 1], PacketType::Pingresp), 269 | ([14 << 4 | 1], PacketType::Disconnect), 270 | ]; 271 | for (buf, _) in inputs.iter_mut() { 272 | let result = parse_packet_type(buf[0]); 273 | assert_eq!(result, Err(DecodeError::PacketFlag)); 274 | } 275 | } 276 | 277 | #[test] 278 | fn bad_one_flags() { 279 | let mut inputs: [([u8; 1], PacketType); 3] = [ 280 | ([06 << 4 | 0], PacketType::Pubrel), 281 | ([08 << 4 | 0], PacketType::Subscribe), 282 | ([10 << 4 | 0], PacketType::Unsubscribe), 283 | ]; 284 | for (buf, _) in inputs.iter_mut() { 285 | let result = parse_packet_type(buf[0]); 286 | assert_eq!(result, Err(DecodeError::PacketFlag)); 287 | } 288 | } 289 | 290 | #[test] 291 | fn publish_flags() { 292 | for i in 0..15 { 293 | let input = 03 << 4 | i; 294 | let (packet_type, flag) = parse_packet_type(input).unwrap(); 295 | assert_eq!(packet_type, 
PacketType::Publish); 296 | assert_eq!(flag, PacketFlags(i)); 297 | } 298 | } 299 | 300 | #[test] 301 | #[ignore] 302 | fn remaining_length() { 303 | // NOTE: This test can take a while to complete. 304 | let _: u32 = (0u32..(268435455 + 1)) 305 | .into_par_iter() 306 | .map(|i| { 307 | let mut buf = [0u8; 4]; 308 | let expected_offset = encode_remaining_length(i, &mut buf); 309 | let (offset, len) = parse_remaining_length(&buf) 310 | .expect(&format!("Failed for number: {}", i)) 311 | .unwrap(); 312 | assert_eq!(i, len); 313 | assert_eq!(expected_offset, offset); 314 | 0 315 | }) 316 | .sum(); 317 | } 318 | 319 | #[test] 320 | fn bad_remaining_length() { 321 | let buf = [0xFF, 0xFF, 0xFF, 0xFF]; 322 | let result = parse_remaining_length(&buf); 323 | assert_eq!(result, Err(DecodeError::RemainingLength)); 324 | } 325 | 326 | #[test] 327 | fn bad_remaining_length2() { 328 | let buf = [0xFF, 0xFF]; 329 | let result = parse_remaining_length(&buf); 330 | assert_eq!(result, Ok(Status::Partial(1))); 331 | } 332 | 333 | #[test] 334 | fn fixed_header1() { 335 | let buf = [ 336 | 01 << 4 | 0b0000, // PacketType::Connect 337 | 0, // remaining length 338 | ]; 339 | let (offset, header) = FixedHeader::decode(&buf).unwrap().unwrap(); 340 | assert_eq!(offset, 2); 341 | assert_eq!(header.r#type(), PacketType::Connect); 342 | assert_eq!(header.flags(), PacketFlags(0)); 343 | assert_eq!(header.len(), 0); 344 | } 345 | 346 | #[test] 347 | fn fixed_header2() { 348 | let buf = [ 349 | 03 << 4 | 0b0000, // PacketType::Publish 350 | 0x80, // remaining length 351 | 0x80, 352 | 0x80, 353 | 0x1, 354 | ]; 355 | let (offset, header) = FixedHeader::decode(&buf).unwrap().unwrap(); 356 | assert_eq!(offset, 5); 357 | assert_eq!(header.r#type(), PacketType::Publish); 358 | assert_eq!(header.flags(), PacketFlags(0)); 359 | assert_eq!(header.len(), 2097152); 360 | } 361 | 362 | #[test] 363 | fn bad_len() { 364 | let buf = [03 << 4 | 0]; 365 | let result = FixedHeader::decode(&buf); 366 | 
assert_eq!(result, Ok(Status::Partial(1))); 367 | } 368 | } 369 | -------------------------------------------------------------------------------- /src/fixed_header/packet_flags.rs: -------------------------------------------------------------------------------- 1 | use core::{ 2 | convert::{From, TryFrom, TryInto}, 3 | fmt::Debug, 4 | result::Result, 5 | }; 6 | 7 | use crate::qos; 8 | 9 | use bitfield::BitRange; 10 | 11 | #[derive(PartialEq, Eq, Debug, Clone, Copy)] 12 | pub struct PacketFlags(pub u8); 13 | 14 | #[allow(dead_code)] 15 | impl PacketFlags { 16 | pub const CONNECT: PacketFlags = PacketFlags(0b0000); 17 | pub const CONNACK: PacketFlags = PacketFlags(0b0000); 18 | // publish is special 19 | pub const PUBACK: PacketFlags = PacketFlags(0b0000); 20 | pub const PUBREC: PacketFlags = PacketFlags(0b0000); 21 | pub const PUBREL: PacketFlags = PacketFlags(0b0010); 22 | pub const PUBCOMP: PacketFlags = PacketFlags(0b0000); 23 | pub const SUBSCRIBE: PacketFlags = PacketFlags(0b0010); 24 | pub const SUBACK: PacketFlags = PacketFlags(0b0000); 25 | pub const UNSUBSCRIBE: PacketFlags = PacketFlags(0b0010); 26 | pub const UNSUBACK: PacketFlags = PacketFlags(0b0000); 27 | pub const PINGREQ: PacketFlags = PacketFlags(0b0000); 28 | pub const PINGRESP: PacketFlags = PacketFlags(0b0000); 29 | pub const DISCONNECT: PacketFlags = PacketFlags(0b0000); 30 | } 31 | 32 | impl From for PacketFlags { 33 | fn from(flags: PublishFlags) -> Self { 34 | PacketFlags(flags.0) 35 | } 36 | } 37 | 38 | #[derive(PartialEq, Eq, Clone, Copy, Default)] 39 | pub struct PublishFlags(u8); 40 | 41 | bitfield_bitrange! { 42 | struct PublishFlags(u8) 43 | } 44 | 45 | impl PublishFlags { 46 | bitfield_fields! 
{ 47 | bool; 48 | pub dup, set_dup : 3; 49 | pub retain, set_retain : 0; 50 | } 51 | 52 | pub fn qos(&self) -> Result { 53 | let qos_bits: u8 = self.bit_range(2, 1); 54 | qos_bits.try_into() 55 | } 56 | 57 | #[allow(dead_code)] 58 | pub fn set_qos(&mut self, qos: qos::QoS) { 59 | self.set_bit_range(2, 1, u8::from(qos)) 60 | } 61 | } 62 | 63 | impl Debug for PublishFlags { 64 | bitfield_debug! { 65 | struct PublishFlags; 66 | pub dup, _ : 3; 67 | pub into qos::QoS, qos, _ : 2, 1; 68 | pub retain, _ : 0; 69 | } 70 | } 71 | 72 | impl TryFrom for PublishFlags { 73 | type Error = qos::Error; 74 | fn try_from(flags: PacketFlags) -> Result { 75 | let flags = PublishFlags(flags.0); 76 | flags.qos()?; 77 | Ok(flags) 78 | } 79 | } 80 | -------------------------------------------------------------------------------- /src/fixed_header/packet_type.rs: -------------------------------------------------------------------------------- 1 | #[derive(Copy, Clone, PartialEq, Eq, Debug)] 2 | pub enum PacketType { 3 | Connect, 4 | Connack, 5 | Publish, 6 | Puback, 7 | Pubrec, 8 | Pubrel, 9 | Pubcomp, 10 | Subscribe, 11 | Suback, 12 | Unsubscribe, 13 | Unsuback, 14 | Pingreq, 15 | Pingresp, 16 | Disconnect, 17 | } 18 | -------------------------------------------------------------------------------- /src/lib.rs: -------------------------------------------------------------------------------- 1 | #![no_std] 2 | 3 | #[cfg(any(feature = "std", test))] 4 | #[macro_use] 5 | extern crate std; 6 | #[cfg(test)] 7 | extern crate rayon; 8 | 9 | extern crate byteorder; 10 | 11 | #[macro_use] 12 | extern crate bitfield; 13 | 14 | #[macro_use] 15 | pub mod status; 16 | pub mod error; 17 | 18 | pub mod codec; 19 | 20 | pub mod fixed_header; 21 | pub mod packet; 22 | pub mod payload; 23 | pub mod variable_header; 24 | 25 | pub mod qos; 26 | -------------------------------------------------------------------------------- /src/packet.rs: 
--------------------------------------------------------------------------------
use core::{cmp::min, convert::TryFrom, default::Default, result::Result};

use crate::{
    codec::{Decodable, Encodable},
    error::{DecodeError, EncodeError},
    fixed_header::{self, FixedHeader},
    payload::{self, Payload},
    qos,
    status::Status,
    variable_header::{self, VariableHeader},
};

/// A full MQTT packet with fixed header, variable header and payload.
///
/// Variable header and payload are optional for some packet types.
#[derive(Debug)]
#[allow(dead_code)]
pub struct Packet<'a> {
    fixed_header: FixedHeader,
    variable_header: Option<VariableHeader<'a>>,
    payload: Payload<'a>,
}

impl<'a> Packet<'a> {
    /// Create a CONNECT packet.
    ///
    /// # Errors
    ///
    /// `EncodeError::ValueTooBig` if the variable header plus payload length
    /// does not fit in a `u32`.
    pub fn connect(
        variable_header: variable_header::connect::Connect<'a>,
        payload: payload::connect::Connect<'a>,
    ) -> Result<Self, EncodeError> {
        Self::packet(
            fixed_header::PacketType::Connect,
            fixed_header::PacketFlags::CONNECT,
            Some(variable_header::VariableHeader::Connect(variable_header)),
            payload::Payload::Connect(payload),
        )
    }

    /// Create a SUBSCRIBE packet.
    ///
    /// # Errors
    ///
    /// `EncodeError::ValueTooBig` if the variable header plus payload length
    /// does not fit in a `u32`.
    pub fn subscribe(
        variable_header: variable_header::packet_identifier::PacketIdentifier,
        payload: payload::subscribe::Subscribe<'a>,
    ) -> Result<Self, EncodeError> {
        Self::packet(
            fixed_header::PacketType::Subscribe,
            fixed_header::PacketFlags::SUBSCRIBE,
            Some(variable_header::VariableHeader::Subscribe(variable_header)),
            payload::Payload::Subscribe(payload),
        )
    }

    /// Create a PUBLISH packet.
    ///
    /// # Panics
    ///
    /// Panics if `flags` holds an invalid QoS, or if the QoS is above
    /// `AtMostOnce` and the variable header has no packet identifier.
    ///
    /// # Errors
    ///
    /// `EncodeError::ValueTooBig` if the variable header plus payload length
    /// does not fit in a `u32`.
    pub fn publish(
        flags: fixed_header::PublishFlags,
        variable_header: variable_header::publish::Publish<'a>,
        payload: &'a [u8],
    ) -> Result<Self, EncodeError> {
        // TODO encode this using type states
        assert!(
            flags.qos().expect("valid qos") == qos::QoS::AtMostOnce
                || variable_header.packet_identifier().is_some()
        );

        Self::packet(
            fixed_header::PacketType::Publish,
            flags.into(),
            Some(variable_header::VariableHeader::Publish(variable_header)),
            payload::Payload::Bytes(payload),
        )
    }

    /// Create a PUBACK packet with a default (empty) payload.
    ///
    /// # Errors
    ///
    /// `EncodeError::ValueTooBig` if the variable header length does not fit
    /// in a `u32`.
    pub fn puback(
        variable_header: variable_header::packet_identifier::PacketIdentifier,
    ) -> Result<Self, EncodeError> {
        Self::packet(
            fixed_header::PacketType::Puback,
            fixed_header::PacketFlags::PUBACK,
            Some(variable_header::VariableHeader::Puback(variable_header)),
            Default::default(),
        )
    }

    /// Create a PINGREQ packet.
    pub fn pingreq() -> Self {
        Self {
            fixed_header: FixedHeader::new(
                fixed_header::PacketType::Pingreq,
                fixed_header::PacketFlags::PINGREQ,
                0,
            ),
            variable_header: None,
            payload: Default::default(),
        }
    }

    /// Create a PINGRESP packet.
    pub fn pingresp() -> Self {
        Self {
            fixed_header: FixedHeader::new(
                fixed_header::PacketType::Pingresp,
                fixed_header::PacketFlags::PINGRESP,
                0,
            ),
            variable_header: None,
            payload: Default::default(),
        }
    }

    /// Create a packet with the given type, flags, variable header and payload.
    ///
    /// Constructs a fixed header with the appropriate `len` field for the given
    /// variable header and payload.
    fn packet(
        r#type: fixed_header::PacketType,
        flags: fixed_header::PacketFlags,
        variable_header: Option<VariableHeader<'a>>,
        payload: Payload<'a>,
    ) -> Result<Self, EncodeError> {
        // `TryFromIntError` converts into `EncodeError::ValueTooBig`.
        let len = u32::try_from(
            variable_header
                .as_ref()
                .map(VariableHeader::encoded_len)
                .unwrap_or(0)
                + payload.encoded_len(),
        )?;

        Ok(Self {
            fixed_header: FixedHeader::new(r#type, flags, len),
            variable_header,
            payload,
        })
    }

    /// Return a reference to the fixed header of the packet.
    ///
    /// The len field of the returned header will be valid.
    pub fn fixed_header(&self) -> &FixedHeader {
        &self.fixed_header
    }

    /// Return a reference to the variable header of the packet.
    pub fn variable_header(&self) -> &Option<VariableHeader<'a>> {
        &self.variable_header
    }

    /// Return a reference to the payload of the packet.
    pub fn payload(&self) -> &Payload<'a> {
        &self.payload
    }
}

impl<'a> Decodable<'a> for Packet<'a> {
    /// Decode any MQTT packet from a pre-allocated buffer.
    ///
    /// If an unrecoverable error occurs an `Err(x)` is returned, the caller should
    /// disconnect and network connection and discard the contents of the connection
    /// receive buffer. This includes `DecodeError::InvalidLength` when the
    /// variable header is longer than the fixed header's remaining length.
    ///
    /// Decoding may return an `Ok(Status::Partial(x))` in which case the caller
    /// should buffer at most `x` more bytes and then attempt decoding again.
    ///
    /// If decoding succeeds an `Ok(Status::Complete(x))` will be returned
    /// containing the number of bytes read from the buffer and the decoded packet.
    /// The lifetime of the decoded packet is tied to the input buffer.
    fn decode(bytes: &'a [u8]) -> Result<Status<(usize, Self)>, DecodeError> {
        let (fixed_header_offset, fixed_header) = read!(FixedHeader::decode, bytes, 0);
        let remaining_len = fixed_header.len() as usize;

        let (variable_header_consumed, variable_header) = if let Some(result) =
            VariableHeader::decode(
                fixed_header.r#type(),
                fixed_header.flags(),
                &bytes[fixed_header_offset..],
            ) {
            let (variable_header_offset, variable_header) = complete!(result);
            (variable_header_offset, Some(variable_header))
        } else {
            (0, None)
        };

        // A variable header that consumed more than the remaining length is
        // malformed; an unchecked subtraction here would underflow.
        let payload_len = remaining_len
            .checked_sub(variable_header_consumed)
            .ok_or(DecodeError::InvalidLength)?;

        let payload_start = fixed_header_offset + variable_header_consumed;
        let available = bytes.len() - payload_start;
        let needed = payload_len - min(available, payload_len);
        if needed > 0 {
            return Ok(Status::Partial(needed));
        }

        // Equal to `fixed_header_offset + remaining_len`.
        let payload_end = payload_start + payload_len;
        let payload_bytes = &bytes[payload_start..payload_end];

        let payload = match Payload::decode(fixed_header.r#type(), payload_bytes) {
            Some(Err(e)) => return Err(e),
            Some(Ok(Status::Partial(n))) => return Ok(Status::Partial(n)),
            Some(Ok(Status::Complete((_, payload)))) => payload,
            None => payload::Payload::Bytes(payload_bytes),
        };

        Ok(Status::Complete((
            payload_end,
            Self {
                fixed_header,
                variable_header,
                payload,
            },
        )))
    }
}

impl<'a> Encodable for Packet<'a> {
    /// Calculate the exact length of the fully encoded packet.
    ///
    /// The encode buffer will need to hold at least this number of bytes.
    fn encoded_len(&self) -> usize {
        self.fixed_header.encoded_len() + self.fixed_header.len() as usize
    }

    /// Encode a packet for sending over a network connection.
    ///
    /// If encoding fails an `Err(x)` is returned.
222 | /// 223 | /// If encoding succeeds an `Ok(written)` is returned with the number of 224 | /// bytes written to the buffer. 225 | fn encode(&self, bytes: &mut [u8]) -> Result { 226 | let mut offset = 0; 227 | 228 | offset += self.fixed_header.encode(&mut bytes[offset..])?; 229 | if let Some(ref variable_header) = self.variable_header { 230 | offset += variable_header.encode(&mut bytes[offset..])?; 231 | } 232 | offset += self.payload.encode(&mut bytes[offset..])?; 233 | 234 | Ok(offset) 235 | } 236 | } 237 | 238 | #[cfg(test)] 239 | mod tests { 240 | use super::*; 241 | 242 | #[test] 243 | fn encode_publish() { 244 | let payload = b"{}"; 245 | assert_eq!(2, payload.len()); 246 | 247 | let mut publish_flags = fixed_header::PublishFlags::default(); 248 | publish_flags.set_qos(qos::QoS::AtLeastOnce); 249 | let publish_id = 2; 250 | let publish = Packet::publish( 251 | publish_flags, 252 | variable_header::publish::Publish::new("a/b", Some(publish_id)), 253 | payload, 254 | ) 255 | .expect("valid packet"); 256 | 257 | assert_eq!(11, publish.encoded_len()); 258 | assert_eq!(2, publish.fixed_header().encoded_len()); 259 | assert_eq!(9, publish.fixed_header().len()); 260 | assert_eq!( 261 | 7, 262 | publish 263 | .variable_header() 264 | .as_ref() 265 | .expect("variable header") 266 | .encoded_len() 267 | ); 268 | assert_eq!(2, publish.payload().encoded_len()); 269 | } 270 | 271 | #[test] 272 | fn encode_subscribe() { 273 | let subscribe_id = 1; 274 | let sub = Packet::subscribe( 275 | variable_header::packet_identifier::PacketIdentifier::new(subscribe_id), 276 | payload::subscribe::Subscribe::new(&[ 277 | ("c/a", qos::QoS::AtMostOnce), 278 | ("c/b", qos::QoS::AtLeastOnce), 279 | ("c/c", qos::QoS::ExactlyOnce), 280 | ]), 281 | ) 282 | .expect("valid packet"); 283 | 284 | assert_eq!(22, sub.encoded_len()); 285 | assert_eq!(2, sub.fixed_header().encoded_len()); 286 | assert_eq!(20, sub.fixed_header().len()); 287 | assert_eq!( 288 | 2, 289 | sub.variable_header() 290 | 
.as_ref() 291 | .expect("variable header") 292 | .encoded_len() 293 | ); 294 | assert_eq!(18, sub.payload().encoded_len()); 295 | } 296 | } 297 | -------------------------------------------------------------------------------- /src/payload/connect.rs: -------------------------------------------------------------------------------- 1 | #![allow(warnings)] 2 | 3 | use core::result::Result; 4 | 5 | use crate::{ 6 | codec::{self, Decodable, Encodable}, 7 | error::{DecodeError, EncodeError}, 8 | status::Status, 9 | variable_header::connect::Flags, 10 | }; 11 | 12 | #[derive(Debug)] 13 | pub struct Will<'buf> { 14 | topic: &'buf str, 15 | message: &'buf [u8], 16 | } 17 | 18 | impl<'buf> Decodable<'buf> for Will<'buf> { 19 | fn decode(bytes: &'buf [u8]) -> Result)>, DecodeError> { 20 | let offset = 0; 21 | let (offset, topic) = read!(codec::string::parse_string, bytes, offset); 22 | let (offset, message) = read!(codec::values::parse_bytes, bytes, offset); 23 | 24 | Ok(Status::Complete((offset, Will { topic, message }))) 25 | } 26 | } 27 | 28 | impl<'buf> Encodable for Will<'buf> { 29 | fn encoded_len(&self) -> usize { 30 | 2 + self.topic.len() + 2 + self.message.len() 31 | } 32 | 33 | fn encode(&self, bytes: &mut [u8]) -> Result { 34 | let mut offset = 0; 35 | offset += codec::string::encode_string(self.topic, &mut bytes[offset..])?; 36 | offset += codec::values::encode_bytes(self.message, &mut bytes[offset..])?; 37 | Ok(offset) 38 | } 39 | } 40 | 41 | impl<'buf> Will<'buf> { 42 | pub fn new(topic: &'buf str, message: &'buf [u8]) -> Self { 43 | Will { topic, message } 44 | } 45 | } 46 | 47 | #[derive(Debug)] 48 | pub struct Connect<'buf> { 49 | client_id: &'buf str, 50 | will: Option>, 51 | username: Option<&'buf str>, 52 | password: Option<&'buf [u8]>, 53 | } 54 | 55 | impl<'buf> Connect<'buf> { 56 | pub fn new( 57 | client_id: &'buf str, 58 | will: Option>, 59 | username: Option<&'buf str>, 60 | password: Option<&'buf [u8]>, 61 | ) -> Self { 62 | Connect { 63 | 
client_id, 64 | will, 65 | username, 66 | password, 67 | } 68 | } 69 | } 70 | 71 | impl<'buf> Connect<'buf> { 72 | pub fn decode(flags: Flags, bytes: &'buf [u8]) -> Result, DecodeError> { 73 | let offset = 0; 74 | 75 | let (offset, client_id) = read!(codec::string::parse_string, bytes, offset); 76 | 77 | let (offset, will) = if flags.has_will() { 78 | let (offset, will) = read!(Will::decode, bytes, offset); 79 | (offset, Some(will)) 80 | } else { 81 | (offset, None) 82 | }; 83 | 84 | let (offset, username) = if flags.has_username() { 85 | let (offset, username) = read!(codec::string::parse_string, bytes, offset); 86 | (offset, Some(username)) 87 | } else { 88 | (offset, None) 89 | }; 90 | 91 | let (offset, password) = if flags.has_password() { 92 | let (offset, password) = read!(codec::values::parse_bytes, bytes, offset); 93 | (offset, Some(bytes)) 94 | } else { 95 | (offset, None) 96 | }; 97 | 98 | Ok(Status::Complete(( 99 | offset, 100 | Connect { 101 | client_id, 102 | will, 103 | username, 104 | password, 105 | }, 106 | ))) 107 | } 108 | } 109 | 110 | impl<'buf> Encodable for Connect<'buf> { 111 | fn encoded_len(&self) -> usize { 112 | self.client_id.encoded_len() 113 | + self.will.as_ref().map(|w| w.encoded_len()).unwrap_or(0) 114 | + self.username.as_ref().map(|u| u.encoded_len()).unwrap_or(0) 115 | + self.password.as_ref().map(|p| p.encoded_len()).unwrap_or(0) 116 | } 117 | 118 | fn encode(&self, bytes: &mut [u8]) -> Result { 119 | let mut offset = 0; 120 | 121 | offset += codec::string::encode_string(self.client_id, &mut bytes[offset..])?; 122 | 123 | if let Some(ref will) = self.will { 124 | offset += will.encode(&mut bytes[offset..])?; 125 | } 126 | 127 | if let Some(username) = self.username { 128 | offset += codec::string::encode_string(username, &mut bytes[offset..])?; 129 | } 130 | 131 | if let Some(password) = self.password { 132 | offset += codec::values::encode_bytes(password, &mut bytes[offset..])?; 133 | } 134 | 135 | Ok(offset) 136 | } 137 | } 
138 | -------------------------------------------------------------------------------- /src/payload/mod.rs: -------------------------------------------------------------------------------- 1 | use core::{default::Default, result::Result}; 2 | 3 | use crate::{ 4 | codec::{Decodable, Encodable}, 5 | error::{DecodeError, EncodeError}, 6 | fixed_header::PacketType, 7 | status::Status, 8 | }; 9 | 10 | pub mod connect; 11 | pub mod suback; 12 | pub mod subscribe; 13 | 14 | #[derive(Debug)] 15 | pub enum Payload<'a> { 16 | Bytes(&'a [u8]), 17 | Connect(connect::Connect<'a>), 18 | Subscribe(subscribe::Subscribe<'a>), 19 | Suback(suback::Suback<'a>), 20 | } 21 | 22 | impl<'a> Payload<'a> { 23 | pub fn decode( 24 | r#type: PacketType, 25 | bytes: &'a [u8], 26 | ) -> Option, DecodeError>> { 27 | Some(match r#type { 28 | // TODO need to pass the variable header / flags to the payload parser 29 | //PacketType::Connect => Payload::Connect(complete!(connect::Connect::decode(bytes))), 30 | PacketType::Suback => match suback::Suback::decode(bytes) { 31 | Err(e) => Err(e), 32 | Ok(Status::Partial(p)) => Ok(Status::Partial(p)), 33 | Ok(Status::Complete((offset, p))) => { 34 | Ok(Status::Complete((offset, Payload::Suback(p)))) 35 | } 36 | }, 37 | PacketType::Subscribe => match subscribe::Subscribe::decode(bytes) { 38 | Err(e) => Err(e), 39 | Ok(Status::Partial(p)) => Ok(Status::Partial(p)), 40 | Ok(Status::Complete((offset, p))) => { 41 | Ok(Status::Complete((offset, Payload::Subscribe(p)))) 42 | } 43 | }, 44 | _ => return None, 45 | }) 46 | } 47 | } 48 | 49 | impl<'a> Encodable for Payload<'a> { 50 | fn encoded_len(&self) -> usize { 51 | match self { 52 | Payload::Connect(ref c) => c.encoded_len(), 53 | Payload::Subscribe(ref c) => c.encoded_len(), 54 | Payload::Suback(ref c) => c.encoded_len(), 55 | Payload::Bytes(c) => c.len(), 56 | } 57 | } 58 | 59 | fn encode(&self, bytes: &mut [u8]) -> Result { 60 | match self { 61 | Payload::Connect(ref c) => c.encode(bytes), 62 | 
Payload::Subscribe(ref c) => c.encode(bytes), 63 | Payload::Suback(ref c) => c.encode(bytes), 64 | Payload::Bytes(c) => { 65 | if bytes.len() < c.len() { 66 | return Err(EncodeError::OutOfSpace); 67 | } 68 | 69 | (&mut bytes[0..c.len()]).copy_from_slice(c); 70 | 71 | Ok(c.len()) 72 | } 73 | } 74 | } 75 | } 76 | 77 | impl<'a> Default for Payload<'a> { 78 | fn default() -> Self { 79 | Payload::Bytes(&[]) 80 | } 81 | } 82 | -------------------------------------------------------------------------------- /src/payload/suback.rs: -------------------------------------------------------------------------------- 1 | use core::{ 2 | convert::{From, TryFrom, TryInto}, 3 | fmt::Debug, 4 | result::Result, 5 | }; 6 | 7 | use crate::{ 8 | codec::{Decodable, Encodable}, 9 | error::{DecodeError, EncodeError}, 10 | qos, 11 | status::Status, 12 | }; 13 | 14 | use bitfield::BitRange; 15 | 16 | #[derive(PartialEq, Eq, Clone, Copy)] 17 | pub struct ReturnCode(u8); 18 | 19 | bitfield_bitrange! { 20 | struct ReturnCode(u8) 21 | } 22 | 23 | impl ReturnCode { 24 | pub const SUCCESS_QOS_0: ReturnCode = ReturnCode(0b0000_0000); 25 | pub const SUCCESS_QOS_1: ReturnCode = ReturnCode(0b0000_0001); 26 | pub const SUCCESS_QOS_2: ReturnCode = ReturnCode(0b0000_0010); 27 | pub const FAILURE: ReturnCode = ReturnCode(0b1000_0000); 28 | 29 | bitfield_fields! { 30 | bool; 31 | pub failure, set_failure : 7; 32 | } 33 | 34 | pub fn max_qos(&self) -> Result { 35 | let qos_bits: u8 = self.bit_range(1, 0); 36 | qos_bits.try_into() 37 | } 38 | 39 | #[allow(dead_code)] 40 | pub fn set_max_qos(&mut self, qos: qos::QoS) { 41 | self.set_bit_range(1, 0, u8::from(qos)) 42 | } 43 | } 44 | 45 | impl Debug for ReturnCode { 46 | bitfield_debug! 
{ 47 | struct ReturnCode; 48 | pub failure, _ : 7; 49 | pub into QoS, max_qos, _ : 1, 0; 50 | } 51 | } 52 | 53 | impl From for u8 { 54 | fn from(val: ReturnCode) -> u8 { 55 | val.0 56 | } 57 | } 58 | 59 | impl TryFrom for ReturnCode { 60 | type Error = (); 61 | fn try_from(val: u8) -> Result { 62 | if 0b0111_1100 & val != 0 { 63 | return Err(()); 64 | } 65 | 66 | let failure = 0b1000_0000 & val; 67 | let success = 0b0000_0011 & val; 68 | 69 | if (success != 0) && (failure != 0) { 70 | return Err(()); 71 | } 72 | 73 | Ok(ReturnCode(val)) 74 | } 75 | } 76 | 77 | #[derive(PartialEq, Eq, Debug)] 78 | pub struct Suback<'a> { 79 | return_codes: &'a [ReturnCode], 80 | } 81 | 82 | impl<'a> Suback<'a> { 83 | pub fn new(return_codes: &'a [ReturnCode]) -> Self { 84 | Self { return_codes } 85 | } 86 | } 87 | 88 | impl<'a> Decodable<'a> for Suback<'a> { 89 | fn decode(bytes: &'a [u8]) -> Result, DecodeError> { 90 | // Check all the bytes are valid return codes 91 | bytes 92 | .iter() 93 | .fold(Ok(()), |acc, byte| { 94 | acc?; 95 | ReturnCode::try_from(*byte).map(|_| ()) 96 | }) 97 | .map_err(|_| DecodeError::InvalidSubackReturnCode)?; 98 | 99 | let return_codes = unsafe { &*(bytes as *const [u8] as *const [ReturnCode]) }; 100 | 101 | Ok(Status::Complete((bytes.len(), Self { return_codes }))) 102 | } 103 | } 104 | 105 | impl<'a> Encodable for Suback<'a> { 106 | fn encoded_len(&self) -> usize { 107 | self.return_codes.len() 108 | } 109 | 110 | fn encode(&self, bytes: &mut [u8]) -> Result { 111 | if bytes.len() < self.return_codes.len() { 112 | return Err(EncodeError::OutOfSpace); 113 | } 114 | 115 | let return_code_bytes = 116 | unsafe { &*(self.return_codes as *const [ReturnCode] as *const [u8]) }; 117 | 118 | (&mut bytes[..self.return_codes.len()]).copy_from_slice(return_code_bytes); 119 | 120 | Ok(self.return_codes.len()) 121 | } 122 | } 123 | 124 | #[cfg(test)] 125 | mod tests { 126 | use super::*; 127 | 128 | #[test] 129 | fn encode() { 130 | let return_codes = 
[ReturnCode::SUCCESS_QOS_0]; 131 | 132 | let payload = Suback::new(&return_codes[..]); 133 | let mut buf = [0u8; 1]; 134 | let used = payload.encode(&mut buf[..]); 135 | assert_eq!(used, Ok(1)); 136 | assert_eq!(buf, [0b0000_0000]); 137 | } 138 | 139 | #[test] 140 | fn decode() { 141 | let return_code_bytes = [0b1000_0000, 0b0000_0010, 0b0000_0001, 0b0000_0000]; 142 | 143 | let return_codes = [ 144 | ReturnCode::FAILURE, 145 | ReturnCode::SUCCESS_QOS_2, 146 | ReturnCode::SUCCESS_QOS_1, 147 | ReturnCode::SUCCESS_QOS_0, 148 | ]; 149 | 150 | let payload = Suback::decode(&return_code_bytes[..]); 151 | assert_eq!( 152 | payload, 153 | Ok(Status::Complete((4, Suback::new(&return_codes[..])))) 154 | ); 155 | } 156 | } 157 | -------------------------------------------------------------------------------- /src/payload/subscribe.rs: -------------------------------------------------------------------------------- 1 | use core::{ 2 | convert::{From, TryFrom}, 3 | fmt, 4 | iter::Iterator, 5 | result::Result, 6 | }; 7 | 8 | use crate::{ 9 | codec::{self, Decodable, Encodable}, 10 | error::{DecodeError, EncodeError}, 11 | qos, 12 | status::Status, 13 | }; 14 | 15 | pub struct Iter<'a> { 16 | offset: usize, 17 | sub: &'a Subscribe<'a>, 18 | } 19 | 20 | impl<'a> Iter<'a> { 21 | fn new(sub: &'a Subscribe<'a>) -> Self { 22 | Iter { offset: 0, sub } 23 | } 24 | } 25 | 26 | impl<'a> Iterator for Iter<'a> { 27 | type Item = (&'a str, qos::QoS); 28 | fn next(&mut self) -> Option { 29 | match self.sub { 30 | Subscribe::Encode(topics) => { 31 | // Offset is an index into the encode slice 32 | if self.offset >= topics.len() { 33 | return None; 34 | } 35 | 36 | let item = topics[self.offset]; 37 | self.offset += 1; 38 | 39 | Some(item) 40 | } 41 | Subscribe::Decode(bytes) => { 42 | // Offset is a byte offset in the byte slice 43 | if self.offset >= bytes.len() { 44 | return None; 45 | } 46 | 47 | // &bytes[offset..] 
points to a length, string and QoS 48 | let (o, item) = parse_subscription(&bytes[self.offset..]) 49 | .expect("already validated") 50 | .unwrap(); 51 | self.offset += o; 52 | 53 | Some(item) 54 | } 55 | } 56 | } 57 | } 58 | 59 | pub enum Subscribe<'a> { 60 | Encode(&'a [(&'a str, qos::QoS)]), 61 | Decode(&'a [u8]), 62 | } 63 | 64 | impl<'a> Subscribe<'a> { 65 | pub fn new(topics: &'a [(&'a str, qos::QoS)]) -> Self { 66 | Subscribe::Encode(topics) 67 | } 68 | 69 | pub fn topics(&self) -> Iter { 70 | Iter::new(self) 71 | } 72 | } 73 | 74 | impl<'a> fmt::Debug for Subscribe<'a> { 75 | fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { 76 | writeln!(f, "Subscribe {{")?; 77 | self.topics().fold(Ok(()), |acc, (topic, qos)| { 78 | acc?; 79 | writeln!( 80 | f, 81 | " (\n Topic: {:#?},\n QoS: {:#?}\n )", 82 | topic, qos 83 | ) 84 | })?; 85 | write!(f, "}}")?; 86 | 87 | Ok(()) 88 | } 89 | } 90 | 91 | #[allow(clippy::type_complexity)] 92 | fn parse_subscription( 93 | bytes: &[u8], 94 | ) -> Result, DecodeError> { 95 | let offset = 0; 96 | 97 | let (offset, topic) = { 98 | let (o, topic) = complete!(codec::string::parse_string(&bytes[offset..])); 99 | (offset + o, topic) 100 | }; 101 | 102 | let (offset, qos) = { 103 | let (o, qos) = complete!(codec::values::parse_u8(&bytes[offset..])); 104 | let qos = qos::QoS::try_from(qos)?; 105 | (offset + o, qos) 106 | }; 107 | 108 | Ok(Status::Complete((offset, (topic, qos)))) 109 | } 110 | 111 | impl<'a> Decodable<'a> for Subscribe<'a> { 112 | fn decode(bytes: &'a [u8]) -> Result, DecodeError> { 113 | let mut offset = 0; 114 | while offset < bytes.len() { 115 | let o = match parse_subscription(&bytes[offset..]) { 116 | Err(e) => return Err(e), 117 | Ok(Status::Partial(..)) => return Err(DecodeError::InvalidLength), 118 | Ok(Status::Complete((o, _))) => o, 119 | }; 120 | offset += o; 121 | } 122 | 123 | Ok(Status::Complete((bytes.len(), Subscribe::Decode(bytes)))) 124 | } 125 | } 126 | 127 | impl<'a> Encodable for Subscribe<'a> { 
128 | fn encoded_len(&self) -> usize { 129 | self.topics().map(|topic| topic.0.encoded_len() + 1).sum() 130 | } 131 | 132 | fn encode(&self, bytes: &mut [u8]) -> Result { 133 | self.topics().fold(Ok(0), |acc, (topic, qos)| { 134 | let mut offset = acc?; 135 | offset += codec::string::encode_string(topic, &mut bytes[offset..])?; 136 | offset += codec::values::encode_u8(u8::from(qos), &mut bytes[offset..])?; 137 | Ok(offset) 138 | }) 139 | } 140 | } 141 | 142 | #[cfg(test)] 143 | mod tests { 144 | use super::*; 145 | 146 | #[test] 147 | fn decode_literal() { 148 | let topics = [ 149 | ("a", qos::QoS::AtMostOnce), 150 | ("b", qos::QoS::AtLeastOnce), 151 | ("c", qos::QoS::ExactlyOnce), 152 | ]; 153 | 154 | let sub = Subscribe::new(&topics); 155 | 156 | let mut iter = sub.topics(); 157 | 158 | let next = iter.next(); 159 | assert_eq!(next, Some(("a", qos::QoS::AtMostOnce))); 160 | 161 | let next = iter.next(); 162 | assert_eq!(next, Some(("b", qos::QoS::AtLeastOnce))); 163 | 164 | let next = iter.next(); 165 | assert_eq!(next, Some(("c", qos::QoS::ExactlyOnce))); 166 | 167 | let next = iter.next(); 168 | assert_eq!(next, None); 169 | } 170 | 171 | #[test] 172 | fn decode_bytes() { 173 | let bytes = [ 174 | 0b0000_0000, // 1 175 | 0b0000_0001, 176 | 0x61, // 'a' 177 | 0x0000_0000, // AtMostOnce 178 | 0b0000_0000, // 1 179 | 0b0000_0001, 180 | 0x62, // 'b' 181 | 0b0000_0001, // AtLeastOnce 182 | 0b0000_0000, // 1 183 | 0b0000_0001, 184 | 0x63, // 'c' 185 | 0b0000_0010, // ExactlyOnce 186 | ]; 187 | 188 | let (_, sub) = Subscribe::decode(&bytes).expect("valid").unwrap(); 189 | 190 | let mut iter = sub.topics(); 191 | 192 | let next = iter.next(); 193 | assert_eq!(next, Some(("a", qos::QoS::AtMostOnce))); 194 | 195 | let next = iter.next(); 196 | assert_eq!(next, Some(("b", qos::QoS::AtLeastOnce))); 197 | 198 | let next = iter.next(); 199 | assert_eq!(next, Some(("c", qos::QoS::ExactlyOnce))); 200 | 201 | let next = iter.next(); 202 | assert_eq!(next, None); 203 | } 204 | 
205 | #[test] 206 | fn decode_bytes_error() { 207 | let bytes = [ 208 | 0b0000_0000, // 1 209 | 0b0000_0001, 210 | 0x61, // 'a' 211 | 0x0000_0000, // AtMostOnce 212 | 0b0000_0000, // 1 213 | 0b0000_0001, 214 | 0x62, // 'b' 215 | 0b0000_0001, // AtLeastOnce 216 | 0b0000_0000, // 1 217 | 0b0000_0001, 218 | 0x63, // 'c' 219 | 220 | // Intentionally omitted 221 | //0b0000_0010, // ExactlyOnce 222 | // 223 | ]; 224 | 225 | let sub = Subscribe::decode(&bytes); 226 | assert!(sub.is_err()); 227 | assert_eq!(sub.unwrap_err(), DecodeError::InvalidLength); 228 | } 229 | } 230 | -------------------------------------------------------------------------------- /src/qos.rs: -------------------------------------------------------------------------------- 1 | use core::{ 2 | convert::{From, TryFrom}, 3 | fmt, 4 | result::Result, 5 | }; 6 | 7 | #[derive(Copy, Clone, PartialEq, Eq, Debug)] 8 | pub enum QoS { 9 | AtMostOnce, 10 | AtLeastOnce, 11 | ExactlyOnce, 12 | } 13 | 14 | #[derive(PartialEq, Eq, Debug, Clone, Copy)] 15 | pub enum Error { 16 | BadPattern, 17 | } 18 | 19 | impl TryFrom for QoS { 20 | type Error = Error; 21 | 22 | fn try_from(byte: u8) -> Result { 23 | let qos = match byte & 0b11 { 24 | 0b00 => QoS::AtMostOnce, 25 | 0b01 => QoS::AtLeastOnce, 26 | 0b10 => QoS::ExactlyOnce, 27 | _ => return Err(Error::BadPattern), 28 | }; 29 | 30 | Ok(qos) 31 | } 32 | } 33 | 34 | impl From for u8 { 35 | fn from(qos: QoS) -> u8 { 36 | match qos { 37 | QoS::AtMostOnce => 0b00, 38 | QoS::AtLeastOnce => 0b01, 39 | QoS::ExactlyOnce => 0b10, 40 | } 41 | } 42 | } 43 | 44 | impl Error { 45 | fn desc(&self) -> &'static str { 46 | match *self { 47 | Error::BadPattern => "invalid QoS bit pattern", 48 | } 49 | } 50 | } 51 | 52 | impl fmt::Display for Error { 53 | fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { 54 | f.write_str(self.desc()) 55 | } 56 | } 57 | 58 | #[cfg(feature = "std")] 59 | impl ::std::error::Error for Error { 60 | fn description(&self) -> &str { 61 | self.desc() 62 | } 
63 | } 64 | -------------------------------------------------------------------------------- /src/status.rs: -------------------------------------------------------------------------------- 1 | /// The result of a successful parse pass. Taken from the `httparse` crate. 2 | /// 3 | /// `Complete` is used when the buffer contained the complete value. 4 | /// `Partial` is used when parsing did not reach the end of the expected value, 5 | /// but no invalid data was found. 6 | #[derive(Copy, Clone, PartialEq, Eq, Debug)] 7 | pub enum Status { 8 | /// The completed result. 9 | Complete(T), 10 | /// A partial result and how much is needed to continue parsing. 11 | Partial(usize), 12 | } 13 | 14 | impl Status { 15 | /// Convenience method to check if status is complete. 16 | #[inline] 17 | pub fn is_complete(&self) -> bool { 18 | match *self { 19 | Status::Complete(..) => true, 20 | Status::Partial(..) => false, 21 | } 22 | } 23 | 24 | /// Convenience method to check if status is partial. 25 | #[inline] 26 | pub fn is_partial(&self) -> bool { 27 | match *self { 28 | Status::Complete(..) => false, 29 | Status::Partial(..) => true, 30 | } 31 | } 32 | 33 | /// Convenience method to unwrap a Complete value. Panics if the status is 34 | /// `Partial`. 35 | #[inline] 36 | pub fn unwrap(self) -> T { 37 | match self { 38 | Status::Complete(t) => t, 39 | Status::Partial(..) => panic!("Tried to unwrap Status::Partial"), 40 | } 41 | } 42 | } 43 | 44 | #[macro_export] 45 | macro_rules! complete { 46 | ($e:expr) => { 47 | match $e? { 48 | Status::Complete(v) => v, 49 | Status::Partial(x) => return Ok(Status::Partial(x)), 50 | } 51 | }; 52 | } 53 | 54 | macro_rules! read { 55 | ($fn:path, $bytes:expr, $offset:expr) => { 56 | match $fn(&$bytes[$offset..])? 
{ 57 | Status::Complete(v) => ($offset + v.0, v.1), 58 | Status::Partial(x) => return Ok(Status::Partial(x)), 59 | } 60 | }; 61 | } 62 | -------------------------------------------------------------------------------- /src/variable_header/connack.rs: -------------------------------------------------------------------------------- 1 | use core::{ 2 | convert::{TryFrom, TryInto}, 3 | fmt::Debug, 4 | result::Result, 5 | }; 6 | 7 | use crate::{ 8 | codec::{self, Encodable}, 9 | error::{DecodeError, EncodeError}, 10 | fixed_header::PacketFlags, 11 | status::Status, 12 | }; 13 | 14 | use super::HeaderDecode; 15 | 16 | #[derive(PartialEq, Clone, Copy)] 17 | pub struct Flags(u8); 18 | 19 | bitfield_bitrange! { 20 | struct Flags(u8) 21 | } 22 | 23 | impl Flags { 24 | bitfield_fields! { 25 | bool; 26 | pub session_present, _ : 1; 27 | } 28 | } 29 | 30 | impl Debug for Flags { 31 | bitfield_debug! { 32 | struct Flags; 33 | pub session_present, _ : 1; 34 | } 35 | } 36 | 37 | impl TryFrom for Flags { 38 | type Error = (); 39 | fn try_from(from: u8) -> Result { 40 | if 0b11111110 & from != 0 { 41 | Err(()) 42 | } else { 43 | Ok(Flags(from)) 44 | } 45 | } 46 | } 47 | 48 | impl Encodable for Flags { 49 | fn encoded_len(&self) -> usize { 50 | 1 51 | } 52 | 53 | fn encode(&self, bytes: &mut [u8]) -> Result { 54 | if bytes.is_empty() { 55 | return Err(EncodeError::OutOfSpace); 56 | } 57 | 58 | bytes[0] = self.0; 59 | 60 | Ok(1) 61 | } 62 | } 63 | 64 | #[derive(PartialEq, Eq, Debug, Clone, Copy)] 65 | pub enum ReturnCode { 66 | Accepted, 67 | RefusedProtocolVersion, 68 | RefusedClientIdentifier, 69 | RefusedServerUnavailable, 70 | RefusedUsernameOrPassword, 71 | RefusedNotAuthorized, 72 | } 73 | 74 | impl TryFrom for ReturnCode { 75 | type Error = (); 76 | fn try_from(from: u8) -> Result { 77 | Ok(match from { 78 | 0 => ReturnCode::Accepted, 79 | 1 => ReturnCode::RefusedProtocolVersion, 80 | 2 => ReturnCode::RefusedClientIdentifier, 81 | 3 => ReturnCode::RefusedServerUnavailable, 82 | 
4 => ReturnCode::RefusedUsernameOrPassword, 83 | 5 => ReturnCode::RefusedNotAuthorized, 84 | _ => return Err(()), 85 | }) 86 | } 87 | } 88 | 89 | impl Encodable for ReturnCode { 90 | fn encoded_len(&self) -> usize { 91 | 1 92 | } 93 | 94 | fn encode(&self, bytes: &mut [u8]) -> Result { 95 | if bytes.is_empty() { 96 | return Err(EncodeError::OutOfSpace); 97 | } 98 | 99 | let val = match self { 100 | ReturnCode::Accepted => 0, 101 | ReturnCode::RefusedProtocolVersion => 1, 102 | ReturnCode::RefusedClientIdentifier => 2, 103 | ReturnCode::RefusedServerUnavailable => 3, 104 | ReturnCode::RefusedUsernameOrPassword => 4, 105 | ReturnCode::RefusedNotAuthorized => 5, 106 | }; 107 | 108 | bytes[0] = val; 109 | 110 | Ok(1) 111 | } 112 | } 113 | 114 | // VariableHeader for Connack packet 115 | #[derive(PartialEq, Debug)] 116 | pub struct Connack { 117 | flags: Flags, 118 | return_code: ReturnCode, 119 | } 120 | 121 | impl Connack { 122 | pub fn flags(&self) -> Flags { 123 | self.flags 124 | } 125 | 126 | pub fn return_code(&self) -> ReturnCode { 127 | self.return_code 128 | } 129 | } 130 | 131 | impl<'buf> HeaderDecode<'buf> for Connack { 132 | fn decode(_flags: PacketFlags, bytes: &[u8]) -> Result, DecodeError> { 133 | if bytes.len() < 2 { 134 | return Ok(Status::Partial(2 - bytes.len())); 135 | } 136 | 137 | let offset = 0; 138 | 139 | // read connack flags 140 | let (offset, flags) = read!(codec::values::parse_u8, bytes, offset); 141 | let flags = flags 142 | .try_into() 143 | .map_err(|_| DecodeError::InvalidConnackFlag)?; 144 | 145 | // read return code 146 | let (offset, return_code) = read!(codec::values::parse_u8, bytes, offset); 147 | let return_code = return_code 148 | .try_into() 149 | .map_err(|_| DecodeError::InvalidConnackReturnCode)?; 150 | 151 | Ok(Status::Complete((offset, Connack { flags, return_code }))) 152 | } 153 | } 154 | 155 | impl Encodable for Connack { 156 | fn encoded_len(&self) -> usize { 157 | 2 158 | } 159 | 160 | fn encode(&self, bytes: &mut 
[u8]) -> Result<usize, EncodeError> {
        self.flags.encode(&mut bytes[0..])?;
        self.return_code.encode(&mut bytes[1..])?;
        Ok(2)
    }
}

#[cfg(test)]
mod tests {}

--------------------------------------------------------------------------------
/src/variable_header/connect.rs:
--------------------------------------------------------------------------------
use core::{
    convert::{From, TryFrom, TryInto},
    fmt::Debug,
    result::Result,
};

use crate::{
    codec::{self, Encodable},
    error::{DecodeError, EncodeError},
    fixed_header::PacketFlags,
    qos,
    status::Status,
};

use super::HeaderDecode;

use bitfield::BitRange;

/// Protocol names that can be written into a CONNECT variable header.
#[derive(PartialEq, Eq, Debug, Clone, Copy)]
pub enum Protocol {
    MQTT,
}

impl Protocol {
    /// The protocol name as it is written on the wire.
    fn name(self) -> &'static str {
        match self {
            Protocol::MQTT => "MQTT",
        }
    }
}

/// Protocol levels understood by the CONNECT variable header.
#[derive(PartialEq, Eq, Debug, Clone, Copy)]
pub enum Level {
    /// MQTT 3.1.1, encoded on the wire as the byte `4`.
    Level3_1_1,
}

impl TryFrom<u8> for Level {
    type Error = ();

    /// Maps a protocol level byte to a `Level`; only `4` is accepted.
    fn try_from(val: u8) -> Result<Self, Self::Error> {
        if val == 4 {
            Ok(Level::Level3_1_1)
        } else {
            Err(())
        }
    }
}

impl From<Level> for u8 {
    fn from(val: Level) -> u8 {
        match val {
            Level::Level3_1_1 => 4,
        }
    }
}

/// The CONNECT flags byte.
///
/// Bit layout (bit 7 is the most significant): username (7), password (6),
/// will retain (5), will QoS (4..3), will flag (2), clean session (1),
/// reserved (0).
#[derive(PartialEq, Clone, Copy, Default)]
pub struct Flags(u8);

bitfield_bitrange! {
    struct Flags(u8)
}

impl Flags {
    bitfield_fields! {
        bool;
        pub has_username, set_has_username : 7;
        pub has_password, set_has_password : 6;
        pub will_retain, set_will_retain : 5;

        pub has_will, set_has_will_flag : 2;
        pub clean_session, set_clean_session : 1;
    }

    /// Decodes the two will QoS bits (4..3).
    ///
    /// # Errors
    /// Returns the `qos::Error` produced when the bits do not form a valid QoS.
    pub fn will_qos(&self) -> Result<qos::QoS, qos::Error> {
        let qos_bits: u8 = self.bit_range(4, 3);
        qos_bits.try_into()
    }

    #[allow(dead_code)]
    pub fn set_will_qos(&mut self, qos: qos::QoS) {
        self.set_bit_range(4, 3, u8::from(qos))
    }

    /// Checks the rules MQTT 3.1.1 §3.1.2 places on the connect flags byte.
    ///
    /// # Errors
    /// Returns `DecodeError::InvalidConnectFlag` when any rule is violated.
    fn validate(&self) -> Result<(), DecodeError> {
        // [MQTT-3.1.2-3] The reserved bit must be zero.
        let reserved: u8 = self.bit_range(0, 0);
        if reserved != 0 {
            return Err(DecodeError::InvalidConnectFlag);
        }

        // [MQTT-3.1.2-14] The will QoS bits must form a valid QoS.
        let will_qos = self.will_qos().map_err(|e| match e {
            qos::Error::BadPattern => DecodeError::InvalidConnectFlag,
        })?;

        // [MQTT-3.1.2-13], [MQTT-3.1.2-15] Without a will, both the will QoS
        // and the will retain bits must be zero.
        if !self.has_will() && (will_qos != qos::QoS::AtMostOnce || self.will_retain()) {
            return Err(DecodeError::InvalidConnectFlag);
        }

        // [MQTT-3.1.2-22] A password may only be present together with a username.
        if self.has_password() && !self.has_username() {
            return Err(DecodeError::InvalidConnectFlag);
        }

        Ok(())
    }
}

impl From<Flags> for u8 {
    fn from(val: Flags) -> u8 {
        val.0
    }
}

impl Debug for Flags {
    bitfield_debug! {
        struct Flags;
        pub has_username, _ : 7;
        pub has_password, _ : 6;
        pub will_retain, _ : 5;
        pub into QoS, will_qos, _ : 4, 3;
        pub has_will, _ : 2;
        pub clean_session, _ : 1;
    }
}

/// Variable header of a CONNECT packet.
#[derive(PartialEq, Debug)]
pub struct Connect<'buf> {
    name: &'buf str,
    level: Level,
    flags: Flags,
    keep_alive: u16,
}

impl<'buf> Connect<'buf> {
    /// Builds a header whose protocol name is taken from `protocol`.
    pub fn new(protocol: Protocol, level: Level, flags: Flags, keep_alive: u16) -> Self {
        let name = protocol.name();
        Connect {
            name,
            level,
            flags,
            keep_alive,
        }
    }

    /// The protocol name (borrowed from the decoded buffer when decoded).
    pub fn name(&self) -> &str {
        self.name
    }

    /// The protocol level.
    pub fn level(&self) -> Level {
        self.level
    }

    /// The connect flags byte.
    pub fn flags(&self) -> Flags {
        self.flags
    }

    /// The keep alive value.
    pub fn keep_alive(&self) -> u16 {
        self.keep_alive
    }
}

impl<'buf> HeaderDecode<'buf> for Connect<'buf> {
    /// Decodes protocol name, protocol level, connect flags and keep alive.
    ///
    /// # Errors
    /// * `DecodeError::InvalidProtocolLevel` if the level byte is not `4`.
    /// * `DecodeError::InvalidConnectFlag` if the flags byte breaks MQTT 3.1.1
    ///   §3.1.2: reserved bit set, invalid will QoS, will QoS/retain set
    ///   without a will, or password without username.
    fn decode(
        _flags: PacketFlags,
        bytes: &'buf [u8],
    ) -> Result<Status<(usize, Connect<'buf>)>, DecodeError> {
        let offset = 0;

        // read protocol name
        let (offset, name) = read!(codec::string::parse_string, bytes, offset);

        // read protocol level; `Level::try_from` only accepts 3.1.1
        let (offset, level) = read!(codec::values::parse_u8, bytes, offset);
        let level = Level::try_from(level).map_err(|_| DecodeError::InvalidProtocolLevel)?;

        // read and validate connect flags
        let (offset, flags) = read!(codec::values::parse_u8, bytes, offset);
        let flags = Flags(flags);
        flags.validate()?;

        // read keep alive
        let (offset, keep_alive) = read!(codec::values::parse_u16, bytes, offset);

        Ok(Status::Complete((
            offset,
            Connect {
                name,
                level,
                flags,
                keep_alive,
            },
        )))
    }
}

impl<'buf> Encodable for Connect<'buf> {
    /// Encoded protocol name + level (1) + flags (1) + keep alive (2).
    fn encoded_len(&self) -> usize {
        self.name.encoded_len() + 1 + 1 + 2
    }

    fn encode(&self, bytes: &mut [u8]) -> Result<usize, EncodeError> {
        let mut offset = 0;
        offset += codec::string::encode_string(self.name, &mut bytes[offset..])?;
        offset += codec::values::encode_u8(self.level.into(), &mut bytes[offset..])?;
        offset += codec::values::encode_u8(self.flags.into(), &mut bytes[offset..])?;
        offset += codec::values::encode_u16(self.keep_alive, &mut bytes[offset..])?;
        Ok(offset)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// A CONNECT variable header for protocol "MQTT", level 4 and a 10s keep
    /// alive, carrying the given connect flags byte.
    fn connect_bytes(flags: u8) -> [u8; 10] {
        [
            0b0000_0000,
            0b0000_0100, // Protocol Name Length (4)
            b'M',
            b'Q',
            b'T',
            b'T',        // Protocol Name
            0b0000_0100, // Level 4
            flags,       // Connect Flags
            0b0000_0000,
            0b0000_1010, // Keep Alive (10s)
        ]
    }

    #[test]
    fn parse_flags() {
        let flags = Flags(0b11100110);
        assert_eq!(flags.has_username(), true);
        assert_eq!(flags.has_password(), true);
        assert_eq!(flags.will_retain(), true);
        assert_eq!(flags.has_will(), true);
        assert_eq!(flags.clean_session(), true);

        let flags = Flags(0b00000000);
        assert_eq!(flags.has_username(), false);
        assert_eq!(flags.has_password(), false);
        assert_eq!(flags.will_retain(), false);
        assert_eq!(flags.has_will(), false);
        assert_eq!(flags.clean_session(), false);
    }

    #[test]
    fn parse_qos() {
        let flags = Flags(0b00010000);
        assert_eq!(flags.will_qos(), Ok(qos::QoS::ExactlyOnce));

        let flags = Flags(0b00001000);
        assert_eq!(flags.will_qos(), Ok(qos::QoS::AtLeastOnce));

        let flags = Flags(0b00000000);
        assert_eq!(flags.will_qos(), Ok(qos::QoS::AtMostOnce));
    }

    #[test]
    fn parse_connect() {
        let buf = [
            0b00000000, // Protocol Name Length
            0b00000100, //
            0b01001101, // 'M'
            0b01010001, // 'Q'
            0b01010100, // 'T'
            0b01010100, // 'T'
            0b00000100, // Level 4
            0b11001110, // Connect Flags - Username 1
            //               - Password 1
            //               - Will Retain 0
            //               - Will QoS 01
            //               - Will Flag 1
            //               - Clean Session 1
            //               - Reserved 0
            0b00000000, // Keep Alive (10s)
            0b00001010, //
        ];

        let connect = Connect::decode(PacketFlags::CONNECT, &buf);

        assert_eq!(
            connect,
            Ok(Status::Complete((
                10,
                Connect {
                    name: "MQTT",
                    level: Level::Level3_1_1,
                    flags: Flags(0b11001110),
                    keep_alive: 10,
                }
            )))
        );
    }

    #[test]
    fn accept_minimal_flags() {
        let buf = connect_bytes(0b0000_0010);
        assert_eq!(
            Connect::decode(PacketFlags::CONNECT, &buf),
            Ok(Status::Complete((
                10,
                Connect {
                    name: "MQTT",
                    level: Level::Level3_1_1,
                    flags: Flags(0b0000_0010),
                    keep_alive: 10,
                }
            )))
        );
    }

    #[test]
    fn reject_invalid_flags() {
        let invalid = [
            0b0000_0011, // reserved bit set [MQTT-3.1.2-3]
            0b0000_1010, // will QoS 1 without will flag [MQTT-3.1.2-13]
            0b0010_0010, // will retain without will flag [MQTT-3.1.2-15]
            0b0001_1110, // will QoS 3 [MQTT-3.1.2-14]
            0b0100_0010, // password without username [MQTT-3.1.2-22]
        ];

        for &flags in invalid.iter() {
            let buf = connect_bytes(flags);
            assert_eq!(
                Connect::decode(PacketFlags::CONNECT, &buf),
                Err(DecodeError::InvalidConnectFlag),
                "flags {:#010b}",
                flags
            );
        }
    }
}

--------------------------------------------------------------------------------
/src/variable_header/mod.rs:
--------------------------------------------------------------------------------
use core::result::Result;

use crate::{
    codec::Encodable,
    error::{DecodeError, EncodeError},
    fixed_header::{PacketFlags, PacketType},
    status::Status,
};

pub mod connack;
pub mod connect;
pub mod packet_identifier;
pub mod publish;

/// The variable header of every packet type this crate decodes.
#[derive(Debug)]
pub enum VariableHeader<'a> {
    Connect(connect::Connect<'a>),
    Connack(connack::Connack),
    Subscribe(packet_identifier::PacketIdentifier),
    Suback(packet_identifier::PacketIdentifier),
    Publish(publish::Publish<'a>),
    Puback(packet_identifier::PacketIdentifier),
}

/// Decoding of a variable header that may borrow from the input buffer.
pub trait HeaderDecode<'a>
where
    Self: core::marker::Sized,
{
    /// Decodes `Self` from the start of `bytes`, given the fixed-header `flags`.
    ///
    /// A complete result pairs the number of bytes consumed with the value.
    fn decode(flags: PacketFlags, bytes: &'a [u8]) -> Result<Status<(usize, Self)>, DecodeError>;
}

pub type PacketId = u16;

// Generates `VariableHeader::decode`, dispatching on the packet type to the
// matching `HeaderDecode` impl; packet types not listed yield `None`.
macro_rules! decode {
    ($($name:ident, $parser:path;)+) => (
        pub fn decode(r#type: PacketType, flags: PacketFlags, bytes: &'a [u8]) -> Option<Result<Status<(usize, Self)>, DecodeError>> {
            Some(match r#type {
                $(
                    PacketType::$name => $parser(flags, bytes).map(|s| {
                        match s {
                            Status::Complete((offset, var_header)) => {
                                Status::Complete((offset, VariableHeader::$name(var_header)))
                            },
                            Status::Partial(n) => Status::Partial(n),
                        }
                    }),
                )+
                _ => return None,
            })
        }
    )
}

impl<'a> VariableHeader<'a> {
    decode!(
        Connect, connect::Connect::decode;
        Connack, connack::Connack::decode;
        Subscribe, packet_identifier::PacketIdentifier::decode;
        Suback, packet_identifier::PacketIdentifier::decode;
        Publish, publish::Publish::decode;
        Puback, packet_identifier::PacketIdentifier::decode;
    );
}

// Generates the `Encodable` methods by forwarding to the wrapped header.
macro_rules! encode {
    ($($enum:ident;)+) => (
        fn encoded_len(&self) -> usize {
            match self {
                $( &VariableHeader::$enum(ref c) => c.encoded_len(), )+
            }
        }

        fn encode(&self, bytes: &mut [u8]) -> Result<usize, EncodeError> {
            match self {
                $( &VariableHeader::$enum(ref c) => c.encode(bytes), )+
            }
        }
    )
}

impl<'buf> Encodable for VariableHeader<'buf> {
    encode!(
        Connect;
        Connack;
        Subscribe;
        Suback;
        Publish;
        Puback;
    );
}

--------------------------------------------------------------------------------
/src/variable_header/packet_identifier.rs:
--------------------------------------------------------------------------------
use core::result::Result;

use crate::{
    codec::{self, Encodable},
    error::{DecodeError, EncodeError},
    fixed_header::PacketFlags,
    status::Status,
};

use super::{HeaderDecode, PacketId};

/// A two-byte packet identifier; the whole variable header of the SUBSCRIBE,
/// SUBACK and PUBACK packets decoded by `VariableHeader`.
// TODO make this a non-zero u16 when it is stable
#[derive(PartialEq, Debug)]
pub struct PacketIdentifier(PacketId);

impl PacketIdentifier {
    pub fn new(packet_identifier: PacketId) -> Self {
        Self(packet_identifier)
    }

    pub fn packet_identifier(&self) -> PacketId {
        self.0
    }
}

impl<'buf> HeaderDecode<'buf> for PacketIdentifier {
    fn decode(
        _flags: PacketFlags,
        bytes: &'buf [u8],
    ) -> Result<Status<(usize, Self)>, DecodeError> {
        // read the two-byte packet identifier
        // NOTE(review): [MQTT-2.3.1-1] requires a non-zero identifier; zero is
        // currently accepted here.
        let (offset, packet_identifier) = read!(codec::values::parse_u16, bytes, 0);

        Ok(Status::Complete((offset, Self(packet_identifier))))
    }
}

impl Encodable for PacketIdentifier {
    fn encoded_len(&self) -> usize {
        2
    }

    fn encode(&self, bytes: &mut [u8]) -> Result<usize, EncodeError> {
        codec::values::encode_u16(self.0, bytes)
    }
}

#[cfg(test)]
mod tests {}

--------------------------------------------------------------------------------
/src/variable_header/publish.rs:
--------------------------------------------------------------------------------
use core::{convert::TryFrom, result::Result};

use crate::{
    codec::{self, Encodable},
    error::{DecodeError, EncodeError},
    fixed_header::{PacketFlags, PublishFlags},
    qos,
    status::Status,
};

use super::{HeaderDecode, PacketId};

/// Variable header of a PUBLISH packet: a topic name, optionally followed by
/// a two-byte packet identifier.
#[derive(Debug)]
pub struct Publish<'a> {
    topic_name: &'a str,
    packet_identifier: Option<PacketId>,
}

impl<'a> Publish<'a> {
    /// Builds a header from a topic name and an optional packet identifier.
    pub fn new(topic_name: &'a str, packet_identifier: Option<PacketId>) -> Self {
        Publish {
            topic_name,
            packet_identifier,
        }
    }

    /// The topic name, borrowed for the header's lifetime.
    pub fn topic_name(&self) -> &'a str {
        self.topic_name
    }

    /// The packet identifier, if the header carries one.
    pub fn packet_identifier(&self) -> Option<PacketId> {
        self.packet_identifier
    }
}

impl<'a> HeaderDecode<'a> for Publish<'a> {
    /// Decodes the topic name and, unless the publish QoS is "at most once",
    /// the packet identifier that follows it.
    ///
    /// Errors from converting the fixed-header flags or reading their QoS are
    /// propagated with `?`.
    fn decode(flags: PacketFlags, bytes: &'a [u8]) -> Result<Status<(usize, Self)>, DecodeError> {
        let publish_flags = PublishFlags::try_from(flags)?;

        let (topic_end, topic_name) = read!(codec::string::parse_string, bytes, 0);

        // QoS 0 publishes stop right after the topic name.
        if publish_flags.qos()? == qos::QoS::AtMostOnce {
            return Ok(Status::Complete((
                topic_end,
                Publish {
                    topic_name,
                    packet_identifier: None,
                },
            )));
        }

        let (identifier_end, identifier) = read!(codec::values::parse_u16, bytes, topic_end);

        Ok(Status::Complete((
            identifier_end,
            Publish {
                topic_name,
                packet_identifier: Some(identifier),
            },
        )))
    }
}

impl<'a> Encodable for Publish<'a> {
    /// Encoded topic name, plus two bytes when a packet identifier is present.
    fn encoded_len(&self) -> usize {
        let identifier_len = match self.packet_identifier {
            Some(_) => 2,
            None => 0,
        };
        self.topic_name.encoded_len() + identifier_len
    }

    /// Writes the topic name, then the packet identifier if there is one,
    /// returning the number of bytes written.
    fn encode(&self, bytes: &mut [u8]) -> Result<usize, EncodeError> {
        let mut written = self.topic_name.encode(bytes)?;
        if let Some(identifier) = self.packet_identifier {
            written += codec::values::encode_u16(identifier, &mut bytes[written..])?;
        }
        Ok(written)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn encode() {
        let header = Publish::new("a/b", Some(1));
        let expected: [u8; 7] = [
            0x00, 0x03, // topic name length
            0x61, 0x2f, 0x62, // "a/b"
            0x00, 0x01, // packet identifier
        ];

        assert_eq!(header.encoded_len(), expected.len());

        let mut buf = [0u8; 7];
        assert_eq!(header.encode(&mut buf[..]), Ok(expected.len()));
        assert_eq!(buf, expected);
    }
}

--------------------------------------------------------------------------------