├── .gitignore ├── .DS_Store ├── src ├── .DS_Store ├── cmd │ ├── unknown.rs │ ├── get.rs │ ├── publish.rs │ ├── mod.rs │ ├── set.rs │ └── subscribe.rs ├── bin │ ├── server.rs │ └── cli.rs ├── shutdown.rs ├── lib.rs ├── buffer.rs ├── parse.rs ├── blocking_client.rs ├── frame.rs ├── connection.rs ├── client.rs ├── server.rs └── db.rs ├── examples ├── chat.rs ├── pub.rs ├── hello_world.rs └── sub.rs ├── .github └── workflows │ └── rust.yml ├── README.md ├── Cargo.toml └── LICENSE /.gitignore: -------------------------------------------------------------------------------- 1 | /target 2 | /Cargo.lock 3 | -------------------------------------------------------------------------------- /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Yumcoder-dev/mini-telegram/HEAD/.DS_Store -------------------------------------------------------------------------------- /src/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Yumcoder-dev/mini-telegram/HEAD/src/.DS_Store -------------------------------------------------------------------------------- /examples/chat.rs: -------------------------------------------------------------------------------- 1 | #[tokio::main] 2 | async fn main() { 3 | unimplemented!(); 4 | } 5 | -------------------------------------------------------------------------------- /.github/workflows/rust.yml: -------------------------------------------------------------------------------- 1 | name: Rust 2 | 3 | on: 4 | push: 5 | branches: [ "main" ] 6 | pull_request: 7 | branches: [ "main" ] 8 | 9 | env: 10 | CARGO_TERM_COLOR: always 11 | 12 | jobs: 13 | build: 14 | 15 | runs-on: ubuntu-latest 16 | 17 | steps: 18 | - uses: actions/checkout@v3 19 | - name: Build 20 | run: cargo build --verbose 21 | - name: Run tests 22 | run: cargo test --verbose 23 | 
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # mini-telegram 2 | 3 | `mini-telegram` is an unofficial, monolithic, idiomatic implementation of [MTProto](https://core.telegram.org/mtproto) (telegram) server built with [Rust](https://www.rust-lang.org) that is compatible with all telegram clients (web, android, iOS, desktop). 4 | 5 | **Disclaimer** Please don't use mini-telegram in high scale production. The intent of this project is to provide an MVP (minimum viable product) of an MTProto server. 6 | 7 | ## Enterprise version 8 | 9 | - TODO 10 | 11 | ## Run `mini-telegram` server 12 | 13 | - TODO 14 | 15 | ## Connect `android client` 16 | 17 | - TODO 18 | 19 | ## Connect `iOS client` 20 | 21 | - TODO 22 | 23 | ## Connect `Web client` 24 | 25 | - TODO 26 | 27 | ## Connect `desktop client` 28 | 29 | - TODO -------------------------------------------------------------------------------- /examples/pub.rs: -------------------------------------------------------------------------------- 1 | //! Publish to a telegram channel example. 2 | //! 3 | //! A simple client that connects to a mini-telegram server, and 4 | //! publishes a message on `foo` channel 5 | //! 6 | //! You can test this out by running: 7 | //! 8 | //! cargo run --bin mini-telegram-server 9 | //! 10 | //! Then in another terminal run: 11 | //! 12 | //! cargo run --example sub 13 | //! 14 | //! And then in another terminal run: 15 | //! 16 | //! cargo run --example pub 17 | 18 | #![warn(rust_2018_idioms)] 19 | 20 | use mini_telegram::{client, Result}; 21 | 22 | #[tokio::main] 23 | async fn main() -> Result<()> { 24 | // Open a connection to the mini-telegram address. 
25 | let mut client = client::connect("127.0.0.1:6379").await?; 26 | 27 | // publish message `bar` on channel foo 28 | client.publish("foo", "bar".into()).await?; 29 | 30 | Ok(()) 31 | } 32 | -------------------------------------------------------------------------------- /examples/hello_world.rs: -------------------------------------------------------------------------------- 1 | //! Hello world server. 2 | //! 3 | //! A simple client that connects to a mini-telegram server, sets key "hello" with value "world", 4 | //! and gets it from the server after 5 | //! 6 | //! You can test this out by running: 7 | //! 8 | //! cargo run --bin mini-telegram-server 9 | //! 10 | //! And then in another terminal run: 11 | //! 12 | //! cargo run --example hello_world 13 | 14 | #![warn(rust_2018_idioms)] 15 | 16 | use mini_telegram::{client, Result}; 17 | 18 | #[tokio::main] 19 | pub async fn main() -> Result<()> { 20 | // Open a connection to the mini-telegram address. 21 | let mut client = client::connect("127.0.0.1:6379").await?; 22 | 23 | // Set the key "hello" with value "world" 24 | client.set("hello", "world".into()).await?; 25 | 26 | // Get key "hello" 27 | let result = client.get("hello").await?; 28 | 29 | println!("got value from the server; success={:?}", result.is_some()); 30 | 31 | Ok(()) 32 | } 33 | -------------------------------------------------------------------------------- /examples/sub.rs: -------------------------------------------------------------------------------- 1 | //! Subscribe to a telegram channel example. 2 | //! 3 | //! A simple client that connects to a mini-telegram server, subscribes to "foo" and "bar" channels 4 | //! and awaits messages published on those channels 5 | //! 6 | //! You can test this out by running: 7 | //! 8 | //! cargo run --bin mini-telegram-server 9 | //! 10 | //! Then in another terminal run: 11 | //! 12 | //! cargo run --example sub 13 | //! 14 | //! And then in another terminal run: 15 | //! 16 | //! 
cargo run --example pub 17 | 18 | #![warn(rust_2018_idioms)] 19 | 20 | use mini_telegram::{client, Result}; 21 | 22 | #[tokio::main] 23 | pub async fn main() -> Result<()> { 24 | // Open a connection to the mini-telegram address. 25 | let client = client::connect("127.0.0.1:6379").await?; 26 | 27 | // subscribe to channel foo 28 | let mut subscriber = client.subscribe(vec!["foo".into()]).await?; 29 | 30 | // await messages on channel foo 31 | if let Some(msg) = subscriber.next_message().await? { 32 | println!( 33 | "got message from the channel: {}; message = {:?}", 34 | msg.channel, msg.content 35 | ); 36 | } 37 | 38 | Ok(()) 39 | } 40 | -------------------------------------------------------------------------------- /src/cmd/unknown.rs: -------------------------------------------------------------------------------- 1 | use crate::{Connection, Frame}; 2 | 3 | use tracing::{debug, instrument}; 4 | 5 | /// Represents an "unknown" command. This is not a real MTProto command. 6 | #[derive(Debug)] 7 | pub struct Unknown { 8 | command_name: String, 9 | } 10 | 11 | impl Unknown { 12 | /// Create a new `Unknown` command which responds to unknown commands 13 | /// issued by clients 14 | pub(crate) fn new(key: impl ToString) -> Unknown { 15 | Unknown { 16 | command_name: key.to_string(), 17 | } 18 | } 19 | 20 | /// Returns the command name 21 | pub(crate) fn get_name(&self) -> &str { 22 | &self.command_name 23 | } 24 | 25 | /// Responds to the client, indicating the command is not recognized. 26 | /// 27 | /// This usually means the command is not yet implemented by `mini-telegram`. 
28 | #[instrument(skip(self, dst))] 29 | pub(crate) async fn apply(self, dst: &mut Connection) -> crate::Result<()> { 30 | let response = Frame::Error(format!("ERR unknown command '{}'", self.command_name)); 31 | 32 | debug!(?response); 33 | 34 | dst.write_frame(&response).await?; 35 | Ok(()) 36 | } 37 | } 38 | -------------------------------------------------------------------------------- /src/bin/server.rs: -------------------------------------------------------------------------------- 1 | //! mini-telegram server. 2 | //! 3 | //! This file is the entry point for the server implemented in the library. It 4 | //! performs command line parsing and passes the arguments on to 5 | //! `mini_telegram::server`. 6 | //! 7 | //! The `structopt` crate is used for parsing arguments. 8 | 9 | use mini_telegram::{server, DEFAULT_PORT}; 10 | 11 | use structopt::StructOpt; 12 | use tokio::net::TcpListener; 13 | use tokio::signal; 14 | 15 | #[tokio::main] 16 | pub async fn main() -> mini_telegram::Result<()> { 17 | // enable logging 18 | // see https://docs.rs/tracing for more info 19 | tracing_subscriber::fmt::try_init()?; 20 | 21 | let cli = Cli::from_args(); 22 | let port = cli.port.as_deref().unwrap_or(DEFAULT_PORT); 23 | 24 | // Bind a TCP listener 25 | let listener = TcpListener::bind(&format!("127.0.0.1:{}", port)).await?; 26 | 27 | server::run(listener, signal::ctrl_c()).await; 28 | 29 | Ok(()) 30 | } 31 | 32 | #[derive(StructOpt, Debug)] 33 | #[structopt(name = "mini-telegram-server", version = env!("CARGO_PKG_VERSION"), author = env!("CARGO_PKG_AUTHORS"), about = "A MTProto server")] 34 | struct Cli { 35 | #[structopt(name = "port", long = "--port")] 36 | port: Option, 37 | } 38 | -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "mini-telegram" 3 | version = "0.1.17" 4 | edition = "2021" 5 | license = "Apache-2.0" 6 | readme 
= "README.md" 7 | authors = ["Yumcoder "] 8 | keywords = ["telegram", "mtproto", "mtproto-server", "tokio-rs", "tokio"] 9 | repository = "https://github.com/YumcoderCom/mini-telegram" 10 | description = "mini-telegram is an unofficial, monolithic, idiomatic implementation of MTProto (telegram) server built with Rust that compatible with all telegram clients (web, android, iOS, desktop)." 11 | 12 | # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html 13 | [dependencies] 14 | async-stream = "0.3.0" 15 | atoi = "2.0.0" 16 | bytes = "1" 17 | structopt = "0.3.14" 18 | tokio = { version = "1", features = ["full"] } 19 | tokio-stream = "0.1" 20 | tracing = "0.1.13" 21 | tracing-futures = { version = "0.2.3" } 22 | tracing-subscriber = "0.3.16" 23 | futures = "0.3.21" 24 | 25 | [dev-dependencies] 26 | # Enable test-utilities in dev mode only. This is mostly for tests. 27 | tokio = { version = "1", features = ["test-util"] } 28 | 29 | [[bin]] 30 | name = "mini-telegram-cli" 31 | path = "src/bin/cli.rs" 32 | 33 | [[bin]] 34 | name = "mini-telegram-server" 35 | path = "src/bin/server.rs" 36 | -------------------------------------------------------------------------------- /src/shutdown.rs: -------------------------------------------------------------------------------- 1 | use tokio::sync::broadcast; 2 | 3 | /// Listens for the server shutdown signal. 4 | /// 5 | /// Shutdown is signalled using a `broadcast::Receiver`. Only a single value is 6 | /// ever sent. Once a value has been sent via the broadcast channel, the server 7 | /// should shutdown. 8 | /// 9 | /// The `Shutdown` struct listens for the signal and tracks that the signal has 10 | /// been received. Callers may query for whether the shutdown signal has been 11 | /// received or not. 
12 | #[derive(Debug)] 13 | pub(crate) struct Shutdown { 14 | /// `true` if the shutdown signal has been received 15 | shutdown: bool, 16 | 17 | /// The receive half of the channel used to listen for shutdown. 18 | notify: broadcast::Receiver<()>, 19 | } 20 | 21 | impl Shutdown { 22 | /// Create a new `Shutdown` backed by the given `broadcast::Receiver`. 23 | pub(crate) fn new(notify: broadcast::Receiver<()>) -> Shutdown { 24 | Shutdown { 25 | shutdown: false, 26 | notify, 27 | } 28 | } 29 | 30 | /// Returns `true` if the shutdown signal has been received. 31 | pub(crate) fn is_shutdown(&self) -> bool { 32 | self.shutdown 33 | } 34 | 35 | /// Receive the shutdown notice, waiting if necessary. 36 | pub(crate) async fn recv(&mut self) { 37 | // If the shutdown signal has already been received, then return 38 | // immediately. 39 | if self.shutdown { 40 | return; 41 | } 42 | 43 | // Cannot receive a "lag error" as only one value is ever sent. 44 | let _ = self.notify.recv().await; 45 | 46 | // Remember that the signal has been received. 47 | self.shutdown = true; 48 | } 49 | } 50 | 51 | #[cfg(test)] 52 | mod tests { 53 | use super::*; 54 | 55 | #[tokio::test] 56 | async fn test_shutdown() { 57 | let (notify_shutdown, _) = broadcast::channel(1); 58 | let mut shutdown = Shutdown::new(notify_shutdown.subscribe()); 59 | assert!(!shutdown.is_shutdown()); 60 | 61 | tokio::spawn(async move { notify_shutdown.send(()) }); 62 | 63 | shutdown.recv().await; 64 | assert!(shutdown.is_shutdown()); 65 | 66 | shutdown.recv().await; 67 | assert!(shutdown.is_shutdown()); 68 | } 69 | } 70 | -------------------------------------------------------------------------------- /src/lib.rs: -------------------------------------------------------------------------------- 1 | //! A minimal (i.e. very incomplete) implementation of a MTProto server and 2 | //! client. 3 | //! 4 | //! The purpose of this project is to provide a larger example of an 5 | //! asynchronous Rust project built with Tokio. 
Do not attempt to run this in 6 | //! production... seriously. 7 | //! 8 | //! # Layout 9 | //! 10 | //! The library is structured such that it can be used with guides. There are 11 | //! modules that are public that probably would not be public in a "real" MTProto 12 | //! client library. 13 | //! 14 | //! The major components are: 15 | //! 16 | //! * `server`: MTProto server implementation. Includes a single `run` function 17 | //! that takes a `TcpListener` and starts accepting MTProto client connections. 18 | //! 19 | //! * `client`: an asynchronous MTProto client implementation. 20 | //! 21 | //! * `cmd`: implementations of the supported MTProto commands(APIs). 22 | //! 23 | //! * `frame`: represents a single MTProto protocol frame. A frame is used as an 24 | //! intermediate representation between a "command" and the byte 25 | //! representation. 26 | 27 | #![feature(test)] 28 | #![feature(cursor_remaining)] 29 | #![feature(assert_matches)] 30 | 31 | pub mod blocking_client; 32 | pub mod client; 33 | 34 | pub mod cmd; 35 | pub use cmd::Command; 36 | 37 | mod connection; 38 | pub use connection::Connection; 39 | 40 | pub mod frame; 41 | pub use frame::Frame; 42 | 43 | mod db; 44 | use db::Db; 45 | use db::DbDropGuard; 46 | 47 | mod parse; 48 | use parse::{Parse, ParseError}; 49 | 50 | pub mod server; 51 | 52 | mod buffer; 53 | pub use buffer::{buffer, Buffer}; 54 | 55 | mod shutdown; 56 | use shutdown::Shutdown; 57 | 58 | /// Default port that a MTProto server listens on. 59 | /// 60 | /// Used if no port is specified. 61 | pub const DEFAULT_PORT: &str = "6379"; 62 | 63 | /// Error returned by most functions. 64 | /// 65 | /// When writing a real application, one might want to consider a specialized 66 | /// error handling crate or defining an error type as an `enum` of causes. 67 | /// However, for our example, using a boxed `std::error::Error` is sufficient. 68 | /// 69 | /// For performance reasons, boxing is avoided in any hot path. 
For example, in 70 | /// `parse`, a custom error `enum` is defined. This is because the error is hit 71 | /// and handled during normal execution when a partial frame is received on a 72 | /// socket. `std::error::Error` is implemented for `parse::Error` which allows 73 | /// it to be converted to `Box`. 74 | pub type Error = Box; 75 | 76 | /// A specialized `Result` type for mini-telegram operations. 77 | /// 78 | /// This is defined as a convenience. 79 | pub type Result = std::result::Result; 80 | -------------------------------------------------------------------------------- /src/cmd/get.rs: -------------------------------------------------------------------------------- 1 | use crate::{Connection, Db, Frame, Parse}; 2 | 3 | use bytes::Bytes; 4 | use tracing::{debug, instrument}; 5 | 6 | /// Get the value of key. 7 | /// 8 | /// If the key does not exist the special value nil is returned. An error is 9 | /// returned if the value stored at key is not a string, because GET only 10 | /// handles string values. 11 | #[derive(Debug)] 12 | pub struct Get { 13 | /// Name of the key to get 14 | key: String, 15 | } 16 | 17 | impl Get { 18 | /// Create a new `Get` command which fetches `key`. 19 | pub fn new(key: impl ToString) -> Get { 20 | Get { 21 | key: key.to_string(), 22 | } 23 | } 24 | 25 | /// Get the key 26 | pub fn key(&self) -> &str { 27 | &self.key 28 | } 29 | 30 | /// Parse a `Get` instance from a received frame. 31 | /// 32 | /// The `Parse` argument provides a cursor-like API to read fields from the 33 | /// `Frame`. At this point, the entire frame has already been received from 34 | /// the socket. 35 | /// 36 | /// The `GET` string has already been consumed. 37 | /// 38 | /// # Returns 39 | /// 40 | /// Returns the `Get` value on success. If the frame is malformed, `Err` is 41 | /// returned. 42 | /// 43 | /// # Format 44 | /// 45 | /// Expects an array frame containing two entries. 
46 | /// 47 | /// ```text 48 | /// GET key 49 | /// ``` 50 | pub(crate) fn parse_frames(parse: &mut Parse) -> crate::Result { 51 | // The `GET` string has already been consumed. The next value is the 52 | // name of the key to get. If the next value is not a string or the 53 | // input is fully consumed, then an error is returned. 54 | let key = parse.next_string()?; 55 | 56 | Ok(Get { key }) 57 | } 58 | 59 | /// Apply the `Get` command to the specified `Db` instance. 60 | /// 61 | /// The response is written to `dst`. This is called by the server in order 62 | /// to execute a received command. 63 | #[instrument(skip(self, db, dst))] 64 | pub(crate) async fn apply(self, db: &Db, dst: &mut Connection) -> crate::Result<()> { 65 | // Get the value from the shared database state 66 | let response = if let Some(value) = db.get(&self.key) { 67 | // If a value is present, it is written to the client in "bulk" 68 | // format. 69 | Frame::Bulk(value) 70 | } else { 71 | // If there is no value, `Null` is written. 72 | Frame::Null 73 | }; 74 | 75 | debug!(?response); 76 | 77 | // Write the response back to the client 78 | dst.write_frame(&response).await?; 79 | 80 | Ok(()) 81 | } 82 | 83 | /// Converts the command into an equivalent `Frame`. 84 | /// 85 | /// This is called by the client when encoding a `Get` command to send to 86 | /// the server. 
87 | pub(crate) fn into_frame(self) -> Frame { 88 | let mut frame = Frame::array(); 89 | frame.push_bulk(Bytes::from("get".as_bytes())); 90 | frame.push_bulk(Bytes::from(self.key.into_bytes())); 91 | frame 92 | } 93 | } 94 | -------------------------------------------------------------------------------- /src/bin/cli.rs: -------------------------------------------------------------------------------- 1 | use mini_telegram::{client, DEFAULT_PORT}; 2 | 3 | use bytes::Bytes; 4 | use std::num::ParseIntError; 5 | use std::str; 6 | use std::time::Duration; 7 | use structopt::StructOpt; 8 | 9 | #[derive(StructOpt, Debug)] 10 | #[structopt(name = "mini-telegram-cli", version = env!("CARGO_PKG_VERSION"), author = env!("CARGO_PKG_AUTHORS"), about = "Issue MTProto commands")] 11 | struct Cli { 12 | #[structopt(subcommand)] 13 | command: Command, 14 | 15 | #[structopt(name = "hostname", long = "--host", default_value = "127.0.0.1")] 16 | host: String, 17 | 18 | #[structopt(name = "port", long = "--port", default_value = DEFAULT_PORT)] 19 | port: String, 20 | } 21 | 22 | #[derive(StructOpt, Debug)] 23 | enum Command { 24 | /// Get the value of key. 25 | Get { 26 | /// Name of key to get 27 | key: String, 28 | }, 29 | /// Set key to hold the string value. 30 | Set { 31 | /// Name of key to set 32 | key: String, 33 | 34 | /// Value to set. 35 | #[structopt(parse(from_str = bytes_from_str))] 36 | value: Bytes, 37 | 38 | /// Expire the value after specified amount of time 39 | #[structopt(parse(try_from_str = duration_from_ms_str))] 40 | expires: Option, 41 | }, 42 | } 43 | 44 | /// Entry point for CLI tool. 45 | /// 46 | /// The `[tokio::main]` annotation signals that the Tokio runtime should be 47 | /// started when the function is called. The body of the function is executed 48 | /// within the newly spawned runtime. 49 | /// 50 | /// `flavor = "current_thread"` is used here to avoid spawning background 51 | /// threads. 
The CLI tool use case benefits more by being lighter instead of 52 | /// multi-threaded. 53 | #[tokio::main(flavor = "current_thread")] 54 | async fn main() -> mini_telegram::Result<()> { 55 | // Enable logging 56 | tracing_subscriber::fmt::try_init()?; 57 | 58 | // Parse command line arguments 59 | let cli = Cli::from_args(); 60 | 61 | // Get the remote address to connect to 62 | let addr = format!("{}:{}", cli.host, cli.port); 63 | 64 | // Establish a connection 65 | let mut client = client::connect(&addr).await?; 66 | 67 | // Process the requested command 68 | match cli.command { 69 | Command::Get { key } => { 70 | if let Some(value) = client.get(&key).await? { 71 | if let Ok(string) = str::from_utf8(&value) { 72 | println!("\"{}\"", string); 73 | } else { 74 | println!("{:?}", value); 75 | } 76 | } else { 77 | println!("(nil)"); 78 | } 79 | } 80 | Command::Set { 81 | key, 82 | value, 83 | expires: None, 84 | } => { 85 | client.set(&key, value).await?; 86 | println!("OK"); 87 | } 88 | Command::Set { 89 | key, 90 | value, 91 | expires: Some(expires), 92 | } => { 93 | client.set_expires(&key, value, expires).await?; 94 | println!("OK"); 95 | } 96 | } 97 | 98 | Ok(()) 99 | } 100 | 101 | fn duration_from_ms_str(src: &str) -> Result { 102 | let ms = src.parse::()?; 103 | Ok(Duration::from_millis(ms)) 104 | } 105 | 106 | fn bytes_from_str(src: &str) -> Bytes { 107 | Bytes::from(src.to_string()) 108 | } 109 | -------------------------------------------------------------------------------- /src/cmd/publish.rs: -------------------------------------------------------------------------------- 1 | use crate::{Connection, Db, Frame, Parse}; 2 | 3 | use bytes::Bytes; 4 | 5 | /// Posts a message to the given channel. 6 | /// 7 | /// Send a message into a channel without any knowledge of individual consumers. 8 | /// Consumers may subscribe to channels in order to receive the messages. 9 | /// 10 | /// Channel names have no relation to the key-value namespace. 
Publishing on a 11 | /// channel named "foo" has no relation to setting the "foo" key. 12 | #[derive(Debug)] 13 | pub struct Publish { 14 | /// Name of the channel on which the message should be published. 15 | channel: String, 16 | 17 | /// The message to publish. 18 | message: Bytes, 19 | } 20 | 21 | impl Publish { 22 | /// Create a new `Publish` command which sends `message` on `channel`. 23 | pub(crate) fn new(channel: impl ToString, message: Bytes) -> Publish { 24 | Publish { 25 | channel: channel.to_string(), 26 | message, 27 | } 28 | } 29 | 30 | /// Parse a `Publish` instance from a received frame. 31 | /// 32 | /// The `Parse` argument provides a cursor-like API to read fields from the 33 | /// `Frame`. At this point, the entire frame has already been received from 34 | /// the socket. 35 | /// 36 | /// The `PUBLISH` string has already been consumed. 37 | /// 38 | /// # Returns 39 | /// 40 | /// On success, the `Publish` value is returned. If the frame is malformed, 41 | /// `Err` is returned. 42 | /// 43 | /// # Format 44 | /// 45 | /// Expects an array frame containing three entries. 46 | /// 47 | /// ```text 48 | /// PUBLISH channel message 49 | /// ``` 50 | pub(crate) fn parse_frames(parse: &mut Parse) -> crate::Result { 51 | // The `PUBLISH` string has already been consumed. Extract the `channel` 52 | // and `message` values from the frame. 53 | // 54 | // The `channel` must be a valid string. 55 | let channel = parse.next_string()?; 56 | 57 | // The `message` is arbitrary bytes. 58 | let message = parse.next_bytes()?; 59 | 60 | Ok(Publish { channel, message }) 61 | } 62 | 63 | /// Apply the `Publish` command to the specified `Db` instance. 64 | /// 65 | /// The response is written to `dst`. This is called by the server in order 66 | /// to execute a received command. 
67 | pub(crate) async fn apply(self, db: &Db, dst: &mut Connection) -> crate::Result<()> { 68 | // The shared state contains the `tokio::sync::broadcast::Sender` for 69 | // all active channels. Calling `db.publish` dispatches the message into 70 | // the appropriate channel. 71 | // 72 | // The number of subscribers currently listening on the channel is 73 | // returned. This does not mean that `num_subscribers` subscribers will 74 | // receive the message. Subscribers may drop before receiving the 75 | // message. Given this, `num_subscribers` should only be used as a 76 | // "hint". 77 | let num_subscribers = db.publish(&self.channel, self.message); 78 | 79 | // The number of subscribers is returned as the response to the publish 80 | // request. 81 | let response = Frame::Integer(num_subscribers as u64); 82 | 83 | // Write the frame to the client. 84 | dst.write_frame(&response).await?; 85 | 86 | Ok(()) 87 | } 88 | 89 | /// Converts the command into an equivalent `Frame`. 90 | /// 91 | /// This is called by the client when encoding a `Publish` command to send 92 | /// to the server. 93 | pub(crate) fn into_frame(self) -> Frame { 94 | let mut frame = Frame::array(); 95 | frame.push_bulk(Bytes::from("publish".as_bytes())); 96 | frame.push_bulk(Bytes::from(self.channel.into_bytes())); 97 | frame.push_bulk(self.message); 98 | 99 | frame 100 | } 101 | } 102 | -------------------------------------------------------------------------------- /src/cmd/mod.rs: -------------------------------------------------------------------------------- 1 | mod get; 2 | pub use get::Get; 3 | 4 | mod publish; 5 | pub use publish::Publish; 6 | 7 | mod set; 8 | pub use set::Set; 9 | 10 | mod subscribe; 11 | pub use subscribe::{Subscribe, Unsubscribe}; 12 | 13 | mod unknown; 14 | pub use unknown::Unknown; 15 | 16 | use crate::{Connection, Db, Frame, Parse, ParseError, Shutdown}; 17 | 18 | /// Enumeration of supported MTProto commands. 
19 | /// 20 | /// Methods called on `Command` are delegated to the command implementation. 21 | #[derive(Debug)] 22 | pub enum Command { 23 | Get(Get), 24 | Publish(Publish), 25 | Set(Set), 26 | Subscribe(Subscribe), 27 | Unsubscribe(Unsubscribe), 28 | Unknown(Unknown), 29 | } 30 | 31 | impl Command { 32 | /// Parse a command from a received frame. 33 | /// 34 | /// The `Frame` must represent a command supported by `mini-telegram` and 35 | /// be the array variant. 36 | /// 37 | /// # Returns 38 | /// 39 | /// On success, the command value is returned, otherwise, `Err` is returned. 40 | pub fn from_frame(frame: Frame) -> crate::Result { 41 | // The frame value is decorated with `Parse`. `Parse` provides a 42 | // "cursor" like API which makes parsing the command easier. 43 | // 44 | // The frame value must be an array variant. Any other frame variants 45 | // result in an error being returned. 46 | let mut parse = Parse::new(frame)?; 47 | 48 | // All commands begin with the command name as a string. The name 49 | // is read and converted to lowercase in order to do case-insensitive 50 | // matching. 51 | let command_name = parse.next_string()?.to_lowercase(); 52 | 53 | // Match the command name, delegating the rest of the parsing to the 54 | // specific command. 55 | let command = match &command_name[..] { 56 | "get" => Command::Get(Get::parse_frames(&mut parse)?), 57 | "publish" => Command::Publish(Publish::parse_frames(&mut parse)?), 58 | "set" => Command::Set(Set::parse_frames(&mut parse)?), 59 | "subscribe" => Command::Subscribe(Subscribe::parse_frames(&mut parse)?), 60 | "unsubscribe" => Command::Unsubscribe(Unsubscribe::parse_frames(&mut parse)?), 61 | _ => { 62 | // The command is not recognized and an Unknown command is 63 | // returned. 64 | // 65 | // `return` is called here to skip the `finish()` call below. As 66 | // the command is not recognized, there are most likely 67 | // unconsumed fields remaining in the `Parse` instance. 
68 | return Ok(Command::Unknown(Unknown::new(command_name))); 69 | } 70 | }; 71 | 72 | // Check if there is any remaining unconsumed fields in the `Parse` 73 | // value. If fields remain, this indicates an unexpected frame format 74 | // and an error is returned. 75 | parse.finish()?; 76 | 77 | // The command has been successfully parsed 78 | Ok(command) 79 | } 80 | 81 | /// Apply the command to the specified `Db` instance. 82 | /// 83 | /// The response is written to `dst`. This is called by the server in order 84 | /// to execute a received command. 85 | pub(crate) async fn apply( 86 | self, 87 | db: &Db, 88 | dst: &mut Connection, 89 | shutdown: &mut Shutdown, 90 | ) -> crate::Result<()> { 91 | use Command::*; 92 | 93 | match self { 94 | Get(cmd) => cmd.apply(db, dst).await, 95 | Publish(cmd) => cmd.apply(db, dst).await, 96 | Set(cmd) => cmd.apply(db, dst).await, 97 | Subscribe(cmd) => cmd.apply(db, dst, shutdown).await, 98 | Unknown(cmd) => cmd.apply(dst).await, 99 | // `Unsubscribe` cannot be applied. It may only be received from the 100 | // context of a `Subscribe` command. 
101 | Unsubscribe(_) => Err("`Unsubscribe` is unsupported in this context".into()), 102 | } 103 | } 104 | 105 | /// Returns the command name 106 | pub(crate) fn get_name(&self) -> &str { 107 | match self { 108 | Command::Get(_) => "get", 109 | Command::Publish(_) => "pub", 110 | Command::Set(_) => "set", 111 | Command::Subscribe(_) => "subscribe", 112 | Command::Unsubscribe(_) => "unsubscribe", 113 | Command::Unknown(cmd) => cmd.get_name(), 114 | } 115 | } 116 | } 117 | -------------------------------------------------------------------------------- /src/buffer.rs: -------------------------------------------------------------------------------- 1 | use crate::client::Client; 2 | use crate::Result; 3 | 4 | use bytes::Bytes; 5 | use tokio::sync::mpsc::{channel, Receiver, Sender}; 6 | use tokio::sync::oneshot; 7 | 8 | /// Create a new client request buffer 9 | /// 10 | /// The `Client` performs MTProto commands directly on the TCP connection. Only a 11 | /// single request may be in-flight at a given time and operations require 12 | /// mutable access to the `Client` handle. This prevents using a single MTProto 13 | /// connection from multiple Tokio tasks. 14 | /// 15 | /// The strategy for dealing with this class of problem is to spawn a dedicated 16 | /// Tokio task to manage the MTProto connection and using "message passing" to 17 | /// operate on the connection. Commands are pushed into a channel. The 18 | /// connection task pops commands off of the channel and applies them to the 19 | /// MTProto connection. When the response is received, it is forwarded to the 20 | /// original requester. 21 | /// 22 | /// The returned `Buffer` handle may be cloned before passing the new handle to 23 | /// separate tasks. 24 | pub fn buffer(client: Client) -> Buffer { 25 | // Setting the message limit to a hard coded value of 32. in a real-app, the 26 | // buffer size should be configurable, but we don't need to do that here. 
27 | let (tx, rx) = channel(32); 28 | 29 | // Spawn a task to process requests for the connection. 30 | tokio::spawn(async move { run(client, rx).await }); 31 | 32 | // Return the `Buffer` handle. 33 | Buffer { tx } 34 | } 35 | 36 | // Enum used to message pass the requested command from the `Buffer` handle 37 | #[derive(Debug)] 38 | enum Command { 39 | Get(String), 40 | Set(String, Bytes), 41 | } 42 | 43 | // Message type sent over the channel to the connection task. 44 | // 45 | // `Command` is the command to forward to the connection. 46 | // 47 | // `oneshot::Sender` is a channel type that sends a **single** value. It is used 48 | // here to send the response received from the connection back to the original 49 | // requester. 50 | type Message = (Command, oneshot::Sender>>); 51 | 52 | /// Receive commands sent through the channel and forward them to client. The 53 | /// response is returned back to the caller via a `oneshot`. 54 | async fn run(mut client: Client, mut rx: Receiver) { 55 | // Repeatedly pop messages from the channel. A return value of `None` 56 | // indicates that all `Buffer` handles have dropped and there will never be 57 | // another message sent on the channel. 58 | while let Some((cmd, tx)) = rx.recv().await { 59 | // The command is forwarded to the connection 60 | let response = match cmd { 61 | Command::Get(key) => client.get(&key).await, 62 | Command::Set(key, value) => client.set(&key, value).await.map(|_| None), 63 | }; 64 | 65 | // Send the response back to the caller. 66 | // 67 | // Failing to send the message indicates the `rx` half dropped 68 | // before receiving the message. This is a normal runtime event. 69 | let _ = tx.send(response); 70 | } 71 | } 72 | 73 | #[derive(Clone)] 74 | pub struct Buffer { 75 | tx: Sender, 76 | } 77 | 78 | impl Buffer { 79 | /// Get the value of a key. 80 | /// 81 | /// Same as `Client::get` but requests are **buffered** until the associated 82 | /// connection has the ability to send the request. 
83 | pub async fn get(&mut self, key: &str) -> Result> { 84 | // Initialize a new `Get` command to send via the channel. 85 | let get = Command::Get(key.into()); 86 | 87 | // Initialize a new oneshot to be used to receive the response back from the connection. 88 | let (tx, rx) = oneshot::channel(); 89 | 90 | // Send the request 91 | self.tx.send((get, tx)).await?; 92 | 93 | // Await the response 94 | match rx.await { 95 | Ok(res) => res, 96 | Err(err) => Err(err.into()), 97 | } 98 | } 99 | 100 | /// Set `key` to hold the given `value`. 101 | /// 102 | /// Same as `Client::set` but requests are **buffered** until the associated 103 | /// connection has the ability to send the request 104 | pub async fn set(&mut self, key: &str, value: Bytes) -> Result<()> { 105 | // Initialize a new `Set` command to send via the channel. 106 | let set = Command::Set(key.into(), value); 107 | 108 | // Initialize a new oneshot to be used to receive the response back from the connection. 109 | let (tx, rx) = oneshot::channel(); 110 | 111 | // Send the request 112 | self.tx.send((set, tx)).await?; 113 | 114 | // Await the response 115 | match rx.await { 116 | Ok(res) => res.map(|_| ()), 117 | Err(err) => Err(err.into()), 118 | } 119 | } 120 | } 121 | -------------------------------------------------------------------------------- /src/cmd/set.rs: -------------------------------------------------------------------------------- 1 | use crate::cmd::{Parse, ParseError}; 2 | use crate::{Connection, Db, Frame}; 3 | 4 | use bytes::Bytes; 5 | use std::time::Duration; 6 | use tracing::{debug, instrument}; 7 | 8 | /// Set `key` to hold the string `value`. 9 | /// 10 | /// If `key` already holds a value, it is overwritten, regardless of its type. 11 | /// Any previous time to live associated with the key is discarded on successful 12 | /// SET operation. 
13 | /// 14 | /// # Options 15 | /// 16 | /// Currently, the following options are supported: 17 | /// 18 | /// * EX `seconds` -- Set the specified expire time, in seconds. 19 | /// * PX `milliseconds` -- Set the specified expire time, in milliseconds. 20 | #[derive(Debug)] 21 | pub struct Set { 22 | /// the lookup key 23 | key: String, 24 | 25 | /// the value to be stored 26 | value: Bytes, 27 | 28 | /// When to expire the key 29 | expire: Option, 30 | } 31 | 32 | impl Set { 33 | /// Create a new `Set` command which sets `key` to `value`. 34 | /// 35 | /// If `expire` is `Some`, the value should expire after the specified 36 | /// duration. 37 | pub fn new(key: impl ToString, value: Bytes, expire: Option) -> Set { 38 | Set { 39 | key: key.to_string(), 40 | value, 41 | expire, 42 | } 43 | } 44 | 45 | /// Get the key 46 | pub fn key(&self) -> &str { 47 | &self.key 48 | } 49 | 50 | /// Get the value 51 | pub fn value(&self) -> &Bytes { 52 | &self.value 53 | } 54 | 55 | /// Get the expire 56 | pub fn expire(&self) -> Option { 57 | self.expire 58 | } 59 | 60 | /// Parse a `Set` instance from a received frame. 61 | /// 62 | /// The `Parse` argument provides a cursor-like API to read fields from the 63 | /// `Frame`. At this point, the entire frame has already been received from 64 | /// the socket. 65 | /// 66 | /// The `SET` string has already been consumed. 67 | /// 68 | /// # Returns 69 | /// 70 | /// Returns the `Set` value on success. If the frame is malformed, `Err` is 71 | /// returned. 72 | /// 73 | /// # Format 74 | /// 75 | /// Expects an array frame containing at least 3 entries. 76 | /// 77 | /// ```text 78 | /// SET key value [EX seconds|PX milliseconds] 79 | /// ``` 80 | pub(crate) fn parse_frames(parse: &mut Parse) -> crate::Result { 81 | use ParseError::EndOfStream; 82 | 83 | // Read the key to set. This is a required field 84 | let key = parse.next_string()?; 85 | 86 | // Read the value to set. This is a required field. 
87 | let value = parse.next_bytes()?; 88 | 89 | // The expiration is optional. If nothing else follows, then it is 90 | // `None`. 91 | let mut expire = None; 92 | 93 | // Attempt to parse another string. 94 | match parse.next_string() { 95 | Ok(s) if s.to_uppercase() == "EX" => { 96 | // An expiration is specified in seconds. The next value is an 97 | // integer. 98 | let secs = parse.next_int()?; 99 | expire = Some(Duration::from_secs(secs)); 100 | } 101 | Ok(s) if s.to_uppercase() == "PX" => { 102 | // An expiration is specified in milliseconds. The next value is 103 | // an integer. 104 | let ms = parse.next_int()?; 105 | expire = Some(Duration::from_millis(ms)); 106 | } 107 | // Currently, mini-telegram does not support any of the other SET 108 | // options. An error here results in the connection being 109 | // terminated. Other connections will continue to operate normally. 110 | Ok(_) => return Err("currently `SET` only supports the expiration option".into()), 111 | // The `EndOfStream` error indicates there is no further data to 112 | // parse. In this case, it is a normal run time situation and 113 | // indicates there are no specified `SET` options. 114 | Err(EndOfStream) => {} 115 | // All other errors are bubbled up, resulting in the connection 116 | // being terminated. 117 | Err(err) => return Err(err.into()), 118 | } 119 | 120 | Ok(Set { key, value, expire }) 121 | } 122 | 123 | /// Apply the `Set` command to the specified `Db` instance. 124 | /// 125 | /// The response is written to `dst`. This is called by the server in order 126 | /// to execute a received command. 127 | #[instrument(skip(self, db, dst))] 128 | pub(crate) async fn apply(self, db: &Db, dst: &mut Connection) -> crate::Result<()> { 129 | // Set the value in the shared database state. 130 | db.set(self.key, self.value, self.expire); 131 | 132 | // Create a success response and write it to `dst`. 
133 | let response = Frame::Simple("OK".to_string()); 134 | debug!(?response); 135 | dst.write_frame(&response).await?; 136 | 137 | Ok(()) 138 | } 139 | 140 | /// Converts the command into an equivalent `Frame`. 141 | /// 142 | /// This is called by the client when encoding a `Set` command to send to 143 | /// the server. 144 | pub(crate) fn into_frame(self) -> Frame { 145 | let mut frame = Frame::array(); 146 | frame.push_bulk(Bytes::from("set".as_bytes())); 147 | frame.push_bulk(Bytes::from(self.key.into_bytes())); 148 | frame.push_bulk(self.value); 149 | if let Some(ms) = self.expire { 150 | // Expirations in the protocol can be specified in two ways 151 | // 1. SET key value EX seconds 152 | // 2. SET key value PX milliseconds 153 | // We use the second option because it allows greater precision and 154 | // src/bin/cli.rs parses the expiration argument as milliseconds 155 | // in duration_from_ms_str() 156 | frame.push_bulk(Bytes::from("px".as_bytes())); 157 | frame.push_int(ms.as_millis() as u64); 158 | } 159 | frame 160 | } 161 | } 162 | -------------------------------------------------------------------------------- /src/parse.rs: -------------------------------------------------------------------------------- 1 | use crate::Frame; 2 | 3 | use bytes::Bytes; 4 | use std::{fmt, str, vec}; 5 | 6 | /// Utility for parsing a command 7 | /// 8 | /// Commands are represented as array frames. Each entry in the frame is a 9 | /// "token". A `Parse` is initialized with the array frame and provides a 10 | /// cursor-like API. Each command struct includes a `parse_frames` method that 11 | /// uses a `Parse` to extract its fields. 12 | #[derive(Debug)] 13 | pub(crate) struct Parse { 14 | /// Array frame iterator. 15 | parts: vec::IntoIter, 16 | } 17 | 18 | /// Error encountered while parsing a frame. 19 | /// 20 | /// Only `EndOfStream` errors are handled at runtime. All other errors result in 21 | /// the connection being terminated.
22 | #[derive(Debug)] 23 | pub(crate) enum ParseError { 24 | /// Attempting to extract a value failed due to the frame being fully 25 | /// consumed. 26 | EndOfStream, 27 | 28 | /// All other errors 29 | Other(crate::Error), 30 | } 31 | 32 | impl Parse { 33 | /// Create a new `Parse` to parse the contents of `frame`. 34 | /// 35 | /// Returns `Err` if `frame` is not an array frame. 36 | pub(crate) fn new(frame: Frame) -> Result { 37 | let array = match frame { 38 | Frame::Array(array) => array, 39 | frame => return Err(format!("protocol error; expected array, got {:?}", frame).into()), 40 | }; 41 | 42 | Ok(Parse { 43 | parts: array.into_iter(), 44 | }) 45 | } 46 | 47 | /// Return the next entry. Array frames are arrays of frames, so the next 48 | /// entry is a frame. 49 | fn next(&mut self) -> Result { 50 | self.parts.next().ok_or(ParseError::EndOfStream) 51 | } 52 | 53 | /// Return the next entry as a string. 54 | /// 55 | /// If the next entry cannot be represented as a String, then an error is returned. 56 | pub(crate) fn next_string(&mut self) -> Result { 57 | match self.next()? { 58 | // Both `Simple` and `Bulk` representation may be strings. Strings 59 | // are parsed to UTF-8. 60 | // 61 | // While errors are stored as strings, they are considered separate 62 | // types. 63 | Frame::Simple(s) => Ok(s), 64 | Frame::Bulk(data) => str::from_utf8(&data[..]) 65 | .map(|s| s.to_string()) 66 | .map_err(|_| "protocol error; invalid string".into()), 67 | frame => Err(format!( 68 | "protocol error; expected simple frame or bulk frame, got {:?}", 69 | frame 70 | ) 71 | .into()), 72 | } 73 | } 74 | 75 | /// Return the next entry as raw bytes. 76 | /// 77 | /// If the next entry cannot be represented as raw bytes, an error is 78 | /// returned. 79 | pub(crate) fn next_bytes(&mut self) -> Result { 80 | match self.next()? { 81 | // Both `Simple` and `Bulk` representation may be raw bytes. 
82 | // 83 | // Although errors are stored as strings and could be represented as 84 | // raw bytes, they are considered separate types. 85 | Frame::Simple(s) => Ok(Bytes::from(s.into_bytes())), 86 | Frame::Bulk(data) => Ok(data), 87 | frame => Err(format!( 88 | "protocol error; expected simple frame or bulk frame, got {:?}", 89 | frame 90 | ) 91 | .into()), 92 | } 93 | } 94 | 95 | /// Return the next entry as an integer. 96 | /// 97 | /// This includes `Simple`, `Bulk`, and `Integer` frame types. `Simple` and 98 | /// `Bulk` frame types are parsed. 99 | /// 100 | /// If the next entry cannot be represented as an integer, then an error is 101 | /// returned. 102 | pub(crate) fn next_int(&mut self) -> Result { 103 | use atoi::atoi; 104 | 105 | const MSG: &str = "protocol error; invalid number"; 106 | 107 | match self.next()? { 108 | // An integer frame type is already stored as an integer. 109 | Frame::Integer(v) => Ok(v), 110 | // Simple and bulk frames must be parsed as integers. If the parsing 111 | // fails, an error is returned. 
112 | Frame::Simple(data) => atoi::(data.as_bytes()).ok_or_else(|| MSG.into()), 113 | Frame::Bulk(data) => atoi::(&data).ok_or_else(|| MSG.into()), 114 | frame => Err(format!("protocol error; expected int frame but got {:?}", frame).into()), 115 | } 116 | } 117 | 118 | /// Ensure there are no more entries in the array 119 | pub(crate) fn finish(&mut self) -> Result<(), ParseError> { 120 | if self.parts.next().is_none() { 121 | Ok(()) 122 | } else { 123 | Err("protocol error; expected end of frame, but there was more".into()) 124 | } 125 | } 126 | } 127 | 128 | impl From for ParseError { 129 | fn from(src: String) -> ParseError { 130 | ParseError::Other(src.into()) 131 | } 132 | } 133 | 134 | impl From<&str> for ParseError { 135 | fn from(src: &str) -> ParseError { 136 | src.to_string().into() 137 | } 138 | } 139 | 140 | impl fmt::Display for ParseError { 141 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 142 | match self { 143 | ParseError::EndOfStream => "protocol error; unexpected end of stream".fmt(f), 144 | ParseError::Other(err) => err.fmt(f), 145 | } 146 | } 147 | } 148 | 149 | impl std::error::Error for ParseError {} 150 | 151 | #[cfg(test)] 152 | mod tests { 153 | use super::*; 154 | 155 | #[tokio::test] 156 | async fn test_parse() { 157 | let mut frame = Frame::array(); 158 | frame.push_bulk(Bytes::from("set")); 159 | frame.push_bulk(Bytes::from("key")); 160 | frame.push_int(100); 161 | let mut parser = Parse::new(frame).unwrap(); 162 | 163 | let cmd = parser.next_string().unwrap(); 164 | assert_eq!(cmd, "set"); 165 | 166 | let key = parser.next_string().unwrap(); 167 | assert_eq!(key, "key"); 168 | 169 | let value = parser.next_int().unwrap(); 170 | assert_eq!(value, 100); 171 | 172 | assert_eq!(parser.finish().unwrap(), ()); 173 | } 174 | } 175 | -------------------------------------------------------------------------------- /src/blocking_client.rs: -------------------------------------------------------------------------------- 1 | //! 
Minimal blocking MTProto client implementation 2 | //! 3 | //! Provides a blocking connect and methods for issuing the supported commands. 4 | 5 | use bytes::Bytes; 6 | use std::time::Duration; 7 | use tokio::net::ToSocketAddrs; 8 | use tokio::runtime::Runtime; 9 | 10 | pub use crate::client::Message; 11 | 12 | /// Established connection with a MTProto server. 13 | /// 14 | /// Backed by a single `TcpStream`, `BlockingClient` provides basic network 15 | /// client functionality (no pooling, retrying, ...). Connections are 16 | /// established using the [`connect`](fn@connect) function. 17 | /// 18 | /// Requests are issued using the various methods of `BlockingClient`. 19 | pub struct BlockingClient { 20 | /// The asynchronous `Client`. 21 | inner: crate::client::Client, 22 | 23 | /// A `current_thread` runtime for executing operations on the asynchronous 24 | /// client in a blocking manner. 25 | rt: Runtime, 26 | } 27 | 28 | /// A client that has entered pub/sub mode. 29 | /// 30 | /// Once clients subscribe to a channel, they may only perform pub/sub related 31 | /// commands. The `BlockingClient` type is transitioned to a 32 | /// `BlockingSubscriber` type in order to prevent non-pub/sub methods from being 33 | /// called. 34 | pub struct BlockingSubscriber { 35 | /// The asynchronous `Subscriber`. 36 | inner: crate::client::Subscriber, 37 | 38 | /// A `current_thread` runtime for executing operations on the asynchronous 39 | /// `Subscriber` in a blocking manner. 40 | rt: Runtime, 41 | } 42 | 43 | /// The iterator returned by `BlockingSubscriber::into_iter`. 44 | struct SubscriberIterator { 45 | /// The asynchronous `Subscriber`. 46 | inner: crate::client::Subscriber, 47 | 48 | /// A `current_thread` runtime for executing operations on the asynchronous 49 | /// `Subscriber` in a blocking manner. 50 | rt: Runtime, 51 | } 52 | 53 | /// Establish a connection with the MTProto server located at `addr`.
54 | /// 55 | /// `addr` may be any type that can be asynchronously converted to a 56 | /// `SocketAddr`. This includes `SocketAddr` and strings. The `ToSocketAddrs` 57 | /// trait is the Tokio version and not the `std` version. 58 | /// 59 | /// # Examples 60 | /// 61 | /// ```no_run 62 | /// use mini_telegram::blocking_client; 63 | /// 64 | /// fn main() { 65 | /// let client = match blocking_client::connect("localhost:6379") { 66 | /// Ok(client) => client, 67 | /// Err(_) => panic!("failed to establish connection"), 68 | /// }; 69 | /// # drop(client); 70 | /// } 71 | /// ``` 72 | pub fn connect(addr: T) -> crate::Result { 73 | let rt = tokio::runtime::Builder::new_current_thread() 74 | .enable_all() 75 | .build()?; 76 | 77 | let inner = rt.block_on(crate::client::connect(addr))?; 78 | 79 | Ok(BlockingClient { inner, rt }) 80 | } 81 | 82 | impl BlockingClient { 83 | /// Get the value of key. 84 | /// 85 | /// If the key does not exist the special value `None` is returned. 86 | /// 87 | /// # Examples 88 | /// 89 | /// Demonstrates basic usage. 90 | /// 91 | /// ```no_run 92 | /// use mini_telegram::blocking_client; 93 | /// 94 | /// fn main() { 95 | /// let mut client = blocking_client::connect("localhost:6379").unwrap(); 96 | /// 97 | /// let val = client.get("foo").unwrap(); 98 | /// println!("Got = {:?}", val); 99 | /// } 100 | /// ``` 101 | pub fn get(&mut self, key: &str) -> crate::Result> { 102 | self.rt.block_on(self.inner.get(key)) 103 | } 104 | 105 | /// Set `key` to hold the given `value`. 106 | /// 107 | /// The `value` is associated with `key` until it is overwritten by the next 108 | /// call to `set` or it is removed. 109 | /// 110 | /// If key already holds a value, it is overwritten. Any previous time to 111 | /// live associated with the key is discarded on successful SET operation. 112 | /// 113 | /// # Examples 114 | /// 115 | /// Demonstrates basic usage. 
116 | /// 117 | /// ```no_run 118 | /// use mini_telegram::blocking_client; 119 | /// 120 | /// fn main() { 121 | /// let mut client = blocking_client::connect("localhost:6379").unwrap(); 122 | /// 123 | /// client.set("foo", "bar".into()).unwrap(); 124 | /// 125 | /// // Getting the value immediately works 126 | /// let val = client.get("foo").unwrap().unwrap(); 127 | /// assert_eq!(val, "bar"); 128 | /// } 129 | /// ``` 130 | pub fn set(&mut self, key: &str, value: Bytes) -> crate::Result<()> { 131 | self.rt.block_on(self.inner.set(key, value)) 132 | } 133 | 134 | /// Set `key` to hold the given `value`. The value expires after `expiration` 135 | /// 136 | /// The `value` is associated with `key` until one of the following: 137 | /// - it expires. 138 | /// - it is overwritten by the next call to `set`. 139 | /// - it is removed. 140 | /// 141 | /// If key already holds a value, it is overwritten. Any previous time to 142 | /// live associated with the key is discarded on a successful SET operation. 143 | /// 144 | /// # Examples 145 | /// 146 | /// Demonstrates basic usage. This example is not **guaranteed** to always 147 | /// work as it relies on time based logic and assumes the client and server 148 | /// stay relatively synchronized in time. The real world tends to not be so 149 | /// favorable. 
150 | /// 151 | /// ```no_run 152 | /// use mini_telegram::blocking_client; 153 | /// use std::thread; 154 | /// use std::time::Duration; 155 | /// 156 | /// fn main() { 157 | /// let ttl = Duration::from_millis(500); 158 | /// let mut client = blocking_client::connect("localhost:6379").unwrap(); 159 | /// 160 | /// client.set_expires("foo", "bar".into(), ttl).unwrap(); 161 | /// 162 | /// // Getting the value immediately works 163 | /// let val = client.get("foo").unwrap().unwrap(); 164 | /// assert_eq!(val, "bar"); 165 | /// 166 | /// // Wait for the TTL to expire 167 | /// thread::sleep(ttl); 168 | /// 169 | /// let val = client.get("foo").unwrap(); 170 | /// assert!(val.is_none()); 171 | /// } 172 | /// ``` 173 | pub fn set_expires( 174 | &mut self, 175 | key: &str, 176 | value: Bytes, 177 | expiration: Duration, 178 | ) -> crate::Result<()> { 179 | self.rt 180 | .block_on(self.inner.set_expires(key, value, expiration)) 181 | } 182 | 183 | /// Posts `message` to the given `channel`. 184 | /// 185 | /// Returns the number of subscribers currently listening on the channel. 186 | /// There is no guarantee that these subscribers receive the message as they 187 | /// may disconnect at any time. 188 | /// 189 | /// # Examples 190 | /// 191 | /// Demonstrates basic usage. 192 | /// 193 | /// ```no_run 194 | /// use mini_telegram::blocking_client; 195 | /// 196 | /// fn main() { 197 | /// let mut client = blocking_client::connect("localhost:6379").unwrap(); 198 | /// 199 | /// let val = client.publish("foo", "bar".into()).unwrap(); 200 | /// println!("Got = {:?}", val); 201 | /// } 202 | /// ``` 203 | pub fn publish(&mut self, channel: &str, message: Bytes) -> crate::Result { 204 | self.rt.block_on(self.inner.publish(channel, message)) 205 | } 206 | 207 | /// Subscribes the client to the specified channels. 208 | /// 209 | /// Once a client issues a subscribe command, it may no longer issue any 210 | /// non-pub/sub commands.
The function consumes `self` and returns a 211 | /// `BlockingSubscriber`. 212 | /// 213 | /// The `BlockingSubscriber` value is used to receive messages as well as 214 | /// manage the list of channels the client is subscribed to. 215 | pub fn subscribe(self, channels: Vec) -> crate::Result { 216 | let subscriber = self.rt.block_on(self.inner.subscribe(channels))?; 217 | Ok(BlockingSubscriber { 218 | inner: subscriber, 219 | rt: self.rt, 220 | }) 221 | } 222 | } 223 | 224 | impl BlockingSubscriber { 225 | /// Returns the set of channels currently subscribed to. 226 | pub fn get_subscribed(&self) -> &[String] { 227 | self.inner.get_subscribed() 228 | } 229 | 230 | /// Receive the next message published on a subscribed channel, waiting if 231 | /// necessary. 232 | /// 233 | /// `None` indicates the subscription has been terminated. 234 | pub fn next_message(&mut self) -> crate::Result> { 235 | self.rt.block_on(self.inner.next_message()) 236 | } 237 | 238 | /// Convert the subscriber into an `Iterator` yielding new messages published 239 | /// on subscribed channels. 
240 | pub fn into_iter(self) -> impl Iterator> { 241 | SubscriberIterator { 242 | inner: self.inner, 243 | rt: self.rt, 244 | } 245 | } 246 | 247 | /// Subscribe to a list of new channels 248 | pub fn subscribe(&mut self, channels: &[String]) -> crate::Result<()> { 249 | self.rt.block_on(self.inner.subscribe(channels)) 250 | } 251 | 252 | /// Unsubscribe from a list of channels 253 | pub fn unsubscribe(&mut self, channels: &[String]) -> crate::Result<()> { 254 | self.rt.block_on(self.inner.unsubscribe(channels)) 255 | } 256 | } 257 | 258 | impl Iterator for SubscriberIterator { 259 | type Item = crate::Result; 260 | 261 | fn next(&mut self) -> Option> { 262 | self.rt.block_on(self.inner.next_message()).transpose() 263 | } 264 | } 265 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License.
25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. 
You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. 
(Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. -------------------------------------------------------------------------------- /src/cmd/subscribe.rs: -------------------------------------------------------------------------------- 1 | use crate::cmd::{Parse, ParseError, Unknown}; 2 | use crate::{Command, Connection, Db, Frame, Shutdown}; 3 | 4 | use bytes::Bytes; 5 | use std::pin::Pin; 6 | use tokio::select; 7 | use tokio::sync::broadcast; 8 | use tokio_stream::{Stream, StreamExt, StreamMap}; 9 | 10 | /// Subscribes the client to one or more channels. 11 | /// 12 | /// Once the client enters the subscribed state, it is not supposed to issue any 13 | /// other commands, except for additional SUBSCRIBE, PSUBSCRIBE, UNSUBSCRIBE, 14 | /// PUNSUBSCRIBE, PING and QUIT commands. 15 | #[derive(Debug)] 16 | pub struct Subscribe { 17 | channels: Vec, 18 | } 19 | 20 | /// Unsubscribes the client from one or more channels. 21 | /// 22 | /// When no channels are specified, the client is unsubscribed from all the 23 | /// previously subscribed channels. 
24 | #[derive(Clone, Debug)] 25 | pub struct Unsubscribe { 26 | channels: Vec, 27 | } 28 | 29 | /// Stream of messages. The stream receives messages from the 30 | /// `broadcast::Receiver`. We use `stream!` to create a `Stream` that consumes 31 | /// messages. Because `stream!` values cannot be named, we box the stream using 32 | /// a trait object. 33 | type Messages = Pin + Send>>; 34 | 35 | impl Subscribe { 36 | /// Creates a new `Subscribe` command to listen on the specified channels. 37 | pub(crate) fn new(channels: &[String]) -> Subscribe { 38 | Subscribe { 39 | channels: channels.to_vec(), 40 | } 41 | } 42 | 43 | /// Parse a `Subscribe` instance from a received frame. 44 | /// 45 | /// The `Parse` argument provides a cursor-like API to read fields from the 46 | /// `Frame`. At this point, the entire frame has already been received from 47 | /// the socket. 48 | /// 49 | /// The `SUBSCRIBE` string has already been consumed. 50 | /// 51 | /// # Returns 52 | /// 53 | /// On success, the `Subscribe` value is returned. If the frame is 54 | /// malformed, `Err` is returned. 55 | /// 56 | /// # Format 57 | /// 58 | /// Expects an array frame containing two or more entries. 59 | /// 60 | /// ```text 61 | /// SUBSCRIBE channel [channel ...] 62 | /// ``` 63 | pub(crate) fn parse_frames(parse: &mut Parse) -> crate::Result { 64 | use ParseError::EndOfStream; 65 | 66 | // The `SUBSCRIBE` string has already been consumed. At this point, 67 | // there is one or more strings remaining in `parse`. These represent 68 | // the channels to subscribe to. 69 | // 70 | // Extract the first string. If there is none, the frame is 71 | // malformed and the error is bubbled up. 72 | let mut channels = vec![parse.next_string()?]; 73 | 74 | // Now, the remainder of the frame is consumed. Each value must be a 75 | // string or the frame is malformed. Once all values in the frame have 76 | // been consumed, the command is fully parsed.
77 | loop { 78 | match parse.next_string() { 79 | // A string has been consumed from the `parse`, push it into the 80 | // list of channels to subscribe to. 81 | Ok(s) => channels.push(s), 82 | // The `EndOfStream` error indicates there is no further data to 83 | // parse. 84 | Err(EndOfStream) => break, 85 | // All other errors are bubbled up, resulting in the connection 86 | // being terminated. 87 | Err(err) => return Err(err.into()), 88 | } 89 | } 90 | 91 | Ok(Subscribe { channels }) 92 | } 93 | 94 | /// Apply the `Subscribe` command to the specified `Db` instance. 95 | /// 96 | /// This function is the entry point and includes the initial list of 97 | /// channels to subscribe to. Additional `subscribe` and `unsubscribe` 98 | /// commands may be received from the client and the list of subscriptions 99 | /// are updated accordingly. 100 | pub(crate) async fn apply( 101 | mut self, 102 | db: &Db, 103 | dst: &mut Connection, 104 | shutdown: &mut Shutdown, 105 | ) -> crate::Result<()> { 106 | // Each individual channel subscription is handled using a 107 | // `sync::broadcast` channel. Messages are then fanned out to all 108 | // clients currently subscribed to the channels. 109 | // 110 | // An individual client may subscribe to multiple channels and may 111 | // dynamically add and remove channels from its subscription set. To 112 | // handle this, a `StreamMap` is used to track active subscriptions. The 113 | // `StreamMap` merges messages from individual broadcast channels as 114 | // they are received. 115 | let mut subscriptions = StreamMap::new(); 116 | 117 | loop { 118 | // `self.channels` is used to track additional channels to subscribe 119 | // to. When new `SUBSCRIBE` commands are received during the 120 | // execution of `apply`, the new channels are pushed onto this vec. 121 | for channel_name in self.channels.drain(..) 
{ 122 | subscribe_to_channel(channel_name, &mut subscriptions, db, dst).await?; 123 | } 124 | 125 | // Wait for one of the following to happen: 126 | // 127 | // - Receive a message from one of the subscribed channels. 128 | // - Receive a subscribe or unsubscribe command from the client. 129 | // - A server shutdown signal. 130 | select! { 131 | // Receive messages from subscribed channels 132 | Some((channel_name, msg)) = subscriptions.next() => { 133 | dst.write_frame(&make_message_frame(channel_name, msg)).await?; 134 | } 135 | res = dst.read_frame() => { 136 | let frame = match res? { 137 | Some(frame) => frame, 138 | // This happens if the remote client has disconnected. 139 | None => return Ok(()) 140 | }; 141 | 142 | handle_command( 143 | frame, 144 | &mut self.channels, 145 | &mut subscriptions, 146 | dst, 147 | ).await?; 148 | } 149 | _ = shutdown.recv() => { 150 | return Ok(()); 151 | } 152 | }; 153 | } 154 | } 155 | 156 | /// Converts the command into an equivalent `Frame`. 157 | /// 158 | /// This is called by the client when encoding a `Subscribe` command to send 159 | /// to the server. 160 | pub(crate) fn into_frame(self) -> Frame { 161 | let mut frame = Frame::array(); 162 | frame.push_bulk(Bytes::from("subscribe".as_bytes())); 163 | for channel in self.channels { 164 | frame.push_bulk(Bytes::from(channel.into_bytes())); 165 | } 166 | frame 167 | } 168 | } 169 | 170 | async fn subscribe_to_channel( 171 | channel_name: String, 172 | subscriptions: &mut StreamMap, 173 | db: &Db, 174 | dst: &mut Connection, 175 | ) -> crate::Result<()> { 176 | let mut rx = db.subscribe(channel_name.clone()); 177 | 178 | // Subscribe to the channel. 179 | let rx = Box::pin(async_stream::stream! { 180 | loop { 181 | match rx.recv().await { 182 | Ok(msg) => yield msg, 183 | // If we lagged in consuming messages, just resume. 
184 | Err(broadcast::error::RecvError::Lagged(_)) => {} 185 | Err(_) => break, 186 | } 187 | } 188 | }); 189 | 190 | // Track subscription in this client's subscription set. 191 | subscriptions.insert(channel_name.clone(), rx); 192 | 193 | // Respond with the successful subscription 194 | let response = make_subscribe_frame(channel_name, subscriptions.len()); 195 | dst.write_frame(&response).await?; 196 | 197 | Ok(()) 198 | } 199 | 200 | /// Handle a command received while inside `Subscribe::apply`. Only subscribe 201 | /// and unsubscribe commands are permitted in this context. 202 | /// 203 | /// Any new subscriptions are appended to `subscribe_to` instead of modifying 204 | /// `subscriptions`. 205 | async fn handle_command( 206 | frame: Frame, 207 | subscribe_to: &mut Vec, 208 | subscriptions: &mut StreamMap, 209 | dst: &mut Connection, 210 | ) -> crate::Result<()> { 211 | // A command has been received from the client. 212 | // 213 | // Only `SUBSCRIBE` and `UNSUBSCRIBE` commands are permitted 214 | // in this context. 215 | match Command::from_frame(frame)? { 216 | Command::Subscribe(subscribe) => { 217 | // The `apply` method will subscribe to the channels we add to this 218 | // vector. 219 | subscribe_to.extend(subscribe.channels.into_iter()); 220 | } 221 | Command::Unsubscribe(mut unsubscribe) => { 222 | // If no channels are specified, this requests unsubscribing from 223 | // **all** channels. To implement this, the `unsubscribe.channels` 224 | // vec is populated with the list of channels currently subscribed 225 | // to. 
226 | if unsubscribe.channels.is_empty() { 227 | unsubscribe.channels = subscriptions 228 | .keys() 229 | .map(|channel_name| channel_name.to_string()) 230 | .collect(); 231 | } 232 | 233 | for channel_name in unsubscribe.channels { 234 | subscriptions.remove(&channel_name); 235 | 236 | let response = make_unsubscribe_frame(channel_name, subscriptions.len()); 237 | dst.write_frame(&response).await?; 238 | } 239 | } 240 | command => { 241 | let cmd = Unknown::new(command.get_name()); 242 | cmd.apply(dst).await?; 243 | } 244 | } 245 | Ok(()) 246 | } 247 | 248 | /// Creates the response to a subscribe request. 249 | /// 250 | /// All of these functions take the `channel_name` as a `String` instead of 251 | /// a `&str` since `Bytes::from` can reuse the allocation in the `String`, and 252 | /// taking a `&str` would require copying the data. This allows the caller to 253 | /// decide whether to clone the channel name or not. 254 | fn make_subscribe_frame(channel_name: String, num_subs: usize) -> Frame { 255 | let mut response = Frame::array(); 256 | response.push_bulk(Bytes::from_static(b"subscribe")); 257 | response.push_bulk(Bytes::from(channel_name)); 258 | response.push_int(num_subs as u64); 259 | response 260 | } 261 | 262 | /// Creates the response to an unsubscribe request. 263 | fn make_unsubscribe_frame(channel_name: String, num_subs: usize) -> Frame { 264 | let mut response = Frame::array(); 265 | response.push_bulk(Bytes::from_static(b"unsubscribe")); 266 | response.push_bulk(Bytes::from(channel_name)); 267 | response.push_int(num_subs as u64); 268 | response 269 | } 270 | 271 | /// Creates a message informing the client about a new message on a channel that 272 | /// the client subscribes to.
273 | fn make_message_frame(channel_name: String, msg: Bytes) -> Frame { 274 | let mut response = Frame::array(); 275 | response.push_bulk(Bytes::from_static(b"message")); 276 | response.push_bulk(Bytes::from(channel_name)); 277 | response.push_bulk(msg); 278 | response 279 | } 280 | 281 | impl Unsubscribe { 282 | /// Create a new `Unsubscribe` command with the given `channels`. 283 | pub(crate) fn new(channels: &[String]) -> Unsubscribe { 284 | Unsubscribe { 285 | channels: channels.to_vec(), 286 | } 287 | } 288 | 289 | /// Parse a `Unsubscribe` instance from a received frame. 290 | /// 291 | /// The `Parse` argument provides a cursor-like API to read fields from the 292 | /// `Frame`. At this point, the entire frame has already been received from 293 | /// the socket. 294 | /// 295 | /// The `UNSUBSCRIBE` string has already been consumed. 296 | /// 297 | /// # Returns 298 | /// 299 | /// On success, the `Unsubscribe` value is returned. If the frame is 300 | /// malformed, `Err` is returned. 301 | /// 302 | /// # Format 303 | /// 304 | /// Expects an array frame containing at least one entry. 305 | /// 306 | /// ```text 307 | /// UNSUBSCRIBE [channel [channel ...]] 308 | /// ``` 309 | pub(crate) fn parse_frames(parse: &mut Parse) -> Result { 310 | use ParseError::EndOfStream; 311 | 312 | // There may be no channels listed, so start with an empty vec. 313 | let mut channels = vec![]; 314 | 315 | // Each entry in the frame must be a string or the frame is malformed. 316 | // Once all values in the frame have been consumed, the command is fully 317 | // parsed. 318 | loop { 319 | match parse.next_string() { 320 | // A string has been consumed from the `parse`, push it into the 321 | // list of channels to unsubscribe from. 322 | Ok(s) => channels.push(s), 323 | // The `EndOfStream` error indicates there is no further data to 324 | // parse. 
325 | Err(EndOfStream) => break, 326 | // All other errors are bubbled up, resulting in the connection 327 | // being terminated. 328 | Err(err) => return Err(err), 329 | } 330 | } 331 | 332 | Ok(Unsubscribe { channels }) 333 | } 334 | 335 | /// Converts the command into an equivalent `Frame`. 336 | /// 337 | /// This is called by the client when encoding an `Unsubscribe` command to 338 | /// send to the server. 339 | pub(crate) fn into_frame(self) -> Frame { 340 | let mut frame = Frame::array(); 341 | frame.push_bulk(Bytes::from("unsubscribe".as_bytes())); 342 | 343 | for channel in self.channels { 344 | frame.push_bulk(Bytes::from(channel.into_bytes())); 345 | } 346 | 347 | frame 348 | } 349 | } 350 | -------------------------------------------------------------------------------- /src/frame.rs: -------------------------------------------------------------------------------- 1 | //! Provides a type representing a protocol frame as well as utilities for 2 | //! parsing frames from a byte array. 3 | 4 | use bytes::{Buf, Bytes}; 5 | use std::convert::TryInto; 6 | use std::fmt; 7 | use std::io::Cursor; 8 | use std::num::TryFromIntError; 9 | use std::string::FromUtf8Error; 10 | 11 | /// A frame in the MTProto protocol. 12 | #[derive(Clone, Debug)] 13 | pub enum Frame { 14 | Simple(String), 15 | Error(String), 16 | Integer(u64), 17 | Bulk(Bytes), 18 | Null, 19 | // Clients send commands to the server using Array 20 | Array(Vec), 21 | } 22 | 23 | #[derive(Debug)] 24 | pub enum Error { 25 | /// Not enough data is available to parse a message 26 | Incomplete, 27 | 28 | /// Invalid message encoding 29 | Other(crate::Error), 30 | } 31 | 32 | impl Frame { 33 | /// Returns an empty array 34 | pub(crate) fn array() -> Frame { 35 | Frame::Array(vec![]) 36 | } 37 | 38 | /// Push a "bulk" frame into the array. `self` must be an Array frame. 
39 | /// 40 | /// # Panics 41 | /// 42 | /// panics if `self` is not an array 43 | pub(crate) fn push_bulk(&mut self, bytes: Bytes) { 44 | match self { 45 | Frame::Array(vec) => { 46 | vec.push(Frame::Bulk(bytes)); 47 | } 48 | _ => panic!("not an array frame"), 49 | } 50 | } 51 | 52 | /// Push an "integer" frame into the array. `self` must be an Array frame. 53 | /// 54 | /// # Panics 55 | /// 56 | /// panics if `self` is not an array 57 | pub(crate) fn push_int(&mut self, value: u64) { 58 | match self { 59 | Frame::Array(vec) => { 60 | vec.push(Frame::Integer(value)); 61 | } 62 | _ => panic!("not an array frame"), 63 | } 64 | } 65 | 66 | /// Checks if an entire message can be decoded from `src` 67 | pub fn check(src: &mut Cursor<&[u8]>) -> Result<(), Error> { 68 | match get_u8(src)? { 69 | b'+' => { 70 | get_line(src)?; 71 | Ok(()) 72 | } 73 | b'-' => { 74 | get_line(src)?; // "-Error message\r\n" 75 | Ok(()) 76 | } 77 | b':' => { 78 | let _ = get_decimal(src)?; // ":1000\r\n" 79 | Ok(()) 80 | } 81 | b'$' => { 82 | // Bulk Strings are encoded in the following way: 83 | // - A "$" byte followed by the number of bytes composing the 84 | // string (a prefixed length), terminated by CRLF. 85 | // - The actual string data. 86 | // - A final CRLF. 87 | if b'-' == peek_u8(src)? { 88 | // Bulk Strings can also be used in order to signal non-existence 89 | // of a value using a special format that is used to represent a Null value. 90 | // In this special format the length is -1, and there is no data, 91 | // so a Null is represented as: "$-1\r\n" 92 | 93 | // Skip '-1\r\n' 94 | skip(src, 4) 95 | } else { 96 | // Read the bulk string 97 | let len: usize = get_decimal(src)?.try_into()?; // need to impl From 98 | 99 | // skip that number of bytes + 2 (\r\n). 
100 | skip(src, len + 2) 101 | } 102 | } 103 | b'*' => { 104 | // Arrays are sent using the following format: 105 | // - A "*" character as the first byte, followed by the number of 106 | // elements in the array as a decimal number, followed by CRLF. 107 | 108 | let len = get_decimal(src)?; // get array length 109 | // An array with 5 elements: 110 | // *5\r\n 111 | // :1\r\n 112 | // :2\r\n 113 | // :3\r\n 114 | // :4\r\n 115 | // $6\r\n 116 | // foobar\r\n 117 | for _ in 0..len { 118 | Frame::check(src)?; 119 | } 120 | 121 | Ok(()) 122 | } 123 | actual => Err(format!("protocol error; invalid frame type byte `{}`", actual).into()), // impl From<&str> 124 | } 125 | } 126 | 127 | /// The message has already been validated with `check`. 128 | pub fn parse(src: &mut Cursor<&[u8]>) -> Result { 129 | match get_u8(src)? { 130 | b'+' => { 131 | // Read the line and convert it to `Vec` 132 | let line = get_line(src)?.to_vec(); 133 | 134 | // Convert the line to a String 135 | let string = String::from_utf8(line)?; // impl From 136 | 137 | Ok(Frame::Simple(string)) 138 | } 139 | b'-' => { 140 | // Read the line and convert it to `Vec` 141 | let line = get_line(src)?.to_vec(); 142 | 143 | // Convert the line to a String 144 | let string = String::from_utf8(line)?; 145 | 146 | Ok(Frame::Error(string)) 147 | } 148 | b':' => { 149 | let len = get_decimal(src)?; 150 | Ok(Frame::Integer(len)) 151 | } 152 | b'$' => { 153 | if b'-' == peek_u8(src)? { 154 | let line = get_line(src)?; 155 | 156 | if line != b"-1" { 157 | return Err("protocol error; invalid frame format".into()); 158 | } 159 | 160 | Ok(Frame::Null) 161 | } else { 162 | // Read the bulk string 163 | let len = get_decimal(src)?.try_into()?; 164 | let n = len + 2; 165 | 166 | if src.remaining() < n { 167 | return Err(Error::Incomplete); 168 | } 169 | 170 | let data = Bytes::copy_from_slice(&src.chunk()[..len]); 171 | 172 | // skip that number of bytes + 2 (\r\n). 
173 | skip(src, n)?; 174 | 175 | Ok(Frame::Bulk(data)) 176 | } 177 | } 178 | b'*' => { 179 | let len = get_decimal(src)?.try_into()?; 180 | let mut out = Vec::with_capacity(len); 181 | 182 | for _ in 0..len { 183 | out.push(Frame::parse(src)?); 184 | } 185 | 186 | Ok(Frame::Array(out)) 187 | } 188 | _ => unimplemented!(), 189 | } 190 | } 191 | 192 | /// Converts the frame to an "unexpected frame" error 193 | pub(crate) fn to_error(&self) -> crate::Error { 194 | format!("unexpected frame: {}", self).into() // impl fmt::Display 195 | } 196 | } 197 | 198 | // Frame == &str 199 | impl PartialEq<&str> for Frame { 200 | fn eq(&self, other: &&str) -> bool { 201 | match self { 202 | Frame::Simple(s) => s.eq(other), 203 | Frame::Bulk(s) => s.eq(other), 204 | _ => false, 205 | } 206 | } 207 | } 208 | 209 | impl fmt::Display for Frame { 210 | fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { 211 | use std::str; 212 | 213 | match self { 214 | Frame::Simple(response) => response.fmt(fmt), 215 | Frame::Error(msg) => write!(fmt, "error: {}", msg), 216 | Frame::Integer(num) => num.fmt(fmt), 217 | Frame::Bulk(msg) => match str::from_utf8(msg) { 218 | Ok(string) => string.fmt(fmt), 219 | Err(_) => write!(fmt, "{:?}", msg), 220 | }, 221 | Frame::Null => "(nil)".fmt(fmt), 222 | Frame::Array(parts) => { 223 | for (i, part) in parts.iter().enumerate() { 224 | if i > 0 { 225 | write!(fmt, " ")?; 226 | } 227 | part.fmt(fmt)?; 228 | } 229 | 230 | Ok(()) 231 | } 232 | } 233 | } 234 | } 235 | 236 | fn peek_u8(src: &mut Cursor<&[u8]>) -> Result { 237 | if !src.has_remaining() { 238 | return Err(Error::Incomplete); 239 | } 240 | 241 | Ok(src.chunk()[0]) 242 | } 243 | 244 | fn get_u8(src: &mut Cursor<&[u8]>) -> Result { 245 | if !src.has_remaining() { 246 | return Err(Error::Incomplete); 247 | } 248 | 249 | Ok(src.get_u8()) 250 | } 251 | 252 | fn skip(src: &mut Cursor<&[u8]>, n: usize) -> Result<(), Error> { 253 | if src.remaining() < n { 254 | return Err(Error::Incomplete); 255 | }
256 | 257 | src.advance(n); 258 | Ok(()) 259 | } 260 | 261 | /// Read a new-line terminated decimal 262 | fn get_decimal(src: &mut Cursor<&[u8]>) -> Result { 263 | use atoi::atoi; 264 | 265 | let line = get_line(src)?; 266 | 267 | atoi::(line).ok_or_else(|| "protocol error; invalid frame format".into()) 268 | } 269 | 270 | /// Find a line 271 | fn get_line<'a>(src: &mut Cursor<&'a [u8]>) -> Result<&'a [u8], Error> { 272 | // Scan the bytes directly 273 | let start = src.position() as usize; 274 | // Scan to the second to last byte 275 | let end = src.get_ref().len() - 1; 276 | 277 | for i in start..end { 278 | if src.get_ref()[i] == b'\r' && src.get_ref()[i + 1] == b'\n' { 279 | // We found a line, update the position to be *after* the \n 280 | src.set_position((i + 2) as u64); 281 | 282 | // Return the line 283 | return Ok(&src.get_ref()[start..i]); 284 | } 285 | } 286 | 287 | Err(Error::Incomplete) 288 | } 289 | 290 | impl From for Error { 291 | fn from(src: String) -> Error { 292 | Error::Other(src.into()) 293 | } 294 | } 295 | 296 | impl From<&str> for Error { 297 | fn from(src: &str) -> Error { 298 | src.to_string().into() 299 | } 300 | } 301 | 302 | impl From for Error { 303 | fn from(_src: TryFromIntError) -> Error { 304 | "protocol error; invalid frame format".into() 305 | } 306 | } 307 | 308 | impl From for Error { 309 | fn from(_src: FromUtf8Error) -> Error { 310 | "protocol error; invalid frame format".into() 311 | } 312 | } 313 | 314 | impl std::error::Error for Error {} 315 | 316 | impl fmt::Display for Error { 317 | fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { 318 | match self { 319 | Error::Incomplete => "stream ended early".fmt(fmt), 320 | Error::Other(err) => err.fmt(fmt), 321 | } 322 | } 323 | } 324 | 325 | #[cfg(test)] 326 | mod tests { 327 | use super::*; 328 | 329 | #[test] 330 | fn test_cursor() { 331 | use std::io::prelude::*; 332 | use std::io::SeekFrom; 333 | 334 | let mut buff = Cursor::new(vec![1, 2, 3, 4, 5]); 335 | 
assert_eq!(buff.position(), 0); 336 | buff.seek(SeekFrom::Current(2)).unwrap(); 337 | assert_eq!(buff.position(), 2); 338 | 339 | buff.seek(SeekFrom::Current(-1)).unwrap(); 340 | assert_eq!(buff.position(), 1); 341 | 342 | buff.set_position(2); 343 | assert_eq!(buff.position(), 2); 344 | assert_eq!(buff.remaining_slice(), &[3, 4, 5]); 345 | 346 | buff.set_position(10); 347 | assert!(buff.is_empty()); 348 | 349 | buff.set_position(2); 350 | assert!(buff.has_remaining()); 351 | 352 | buff.advance(1); 353 | assert_eq!(buff.position(), 3); 354 | assert!(buff.has_remaining()); 355 | 356 | assert_eq!(buff.get_ref().len(), 5); 357 | } 358 | 359 | #[test] 360 | fn test_cursor_get_line() { 361 | let stream = &b"+10\r\n+20\r\n"[..]; 362 | let mut buff = Cursor::new(stream); 363 | 364 | let res10 = get_line(&mut buff).unwrap(); 365 | assert_eq!(String::from_utf8(res10.to_vec()).unwrap(), "+10"); 366 | 367 | let res20 = get_line(&mut buff).unwrap(); 368 | assert_eq!(String::from_utf8(res20.to_vec()).unwrap(), "+20"); 369 | } 370 | 371 | #[test] 372 | fn test_make_frame() { 373 | let mut get_frame = Frame::array(); 374 | get_frame.push_bulk(Bytes::from("get")); 375 | get_frame.push_bulk(Bytes::from("key")); 376 | 377 | assert_eq!(get_frame.to_string(), "get key"); 378 | // println!("{:?}", get_frame); // Array([Bulk(b"get"), Bulk(b"key")]) 379 | 380 | let mut set_frame = Frame::array(); 381 | set_frame.push_bulk(Bytes::from("set")); 382 | set_frame.push_bulk(Bytes::from("key")); 383 | set_frame.push_int(100); 384 | assert_eq!(set_frame.to_string(), "set key 100"); 385 | } 386 | 387 | #[test] 388 | fn test_frame_check() { 389 | let tests = vec![ 390 | (&b"+OK\r\n"[..], Ok(())), // (stream, expected result) 391 | ( 392 | &b"#\r\n"[..], 393 | Err(Error::Other( 394 | "protocol error; invalid frame type byte `35`".into(), 395 | )), 396 | ), 397 | (&b""[..], Err(Error::Incomplete)), 398 | (&b"-Error message\r\n"[..], Ok(())), // error message 399 | (&b":1000\r\n"[..], Ok(())), // an 
integer 400 | (&b"$6\r\nfoobar\r\n"[..], Ok(())), // string "foobar" 401 | (&b"$0\r\n\r\n"[..], Ok(())), // empty string 402 | (&b"$-1\r\n"[..], Ok(())), // null string 403 | (&b"*0\r\n"[..], Ok(())), // an empty Array 404 | ( 405 | // an array of two Bulk Strings "foo" and "bar" 406 | &b"*2\r\n$3\r\nfoo\r\n$3\r\nbar\r\n"[..], 407 | Ok(()), 408 | ), 409 | ( 410 | // an Array of three integers 411 | &b"*3\r\n:1\r\n:2\r\n:3\r\n"[..], 412 | Ok(()), 413 | ), 414 | ( 415 | // a list of four integers and a bulk string 416 | &b"*5\r\n:1\r\n:2\r\n:3\r\n:4\r\n$6\r\nfoobar\r\n"[..], 417 | Ok(()), 418 | ), 419 | ( 420 | // an Array containing a Null element ["foo",nil,"bar"] 421 | &b"*3\r\n$3\r\nfoo\r\n$-1\r\n$3\r\nbar\r\n"[..], 422 | Ok(()), 423 | ), 424 | ]; 425 | 426 | for t in tests.iter() { 427 | let stream = t.0; 428 | let wanted = &t.1; 429 | let mut buff = Cursor::new(stream); 430 | let res = Frame::check(&mut buff); 431 | 432 | match res { 433 | Ok(_) => assert!(wanted.is_ok()), 434 | Err(e) => assert_eq!(wanted.as_ref().unwrap_err().to_string(), e.to_string()), 435 | } 436 | } 437 | } 438 | 439 | #[test] 440 | fn test_frame_parse() { 441 | // for easier comparison, strings are used as the expected results 442 | let tests = vec![ 443 | (&b"+OK\r\n"[..], Ok("OK")), // (stream, expected result) 444 | (&b""[..], Err(Error::Incomplete)), 445 | ( 446 | // error message 447 | &b"-Error message\r\n"[..], 448 | Ok("error: Error message"), 449 | ), 450 | (&b":1000\r\n"[..], Ok("1000")), // an integer 451 | ( 452 | // string "foobar" 453 | &b"$6\r\nfoobar\r\n"[..], 454 | Ok("foobar"), 455 | ), 456 | ( 457 | // empty string 458 | &b"$0\r\n\r\n"[..], 459 | Ok(""), 460 | ), 461 | (&b"$-1\r\n"[..], Ok("(nil)")), // null string 462 | ( 463 | // an empty Array 464 | &b"*0\r\n"[..], 465 | Ok(""), 466 | ), 467 | ( 468 | // an array of two Bulk Strings "foo" and "bar" 469 | &b"*2\r\n$3\r\nfoo\r\n$3\r\nbar\r\n"[..], 470 | Ok("foo bar"), 471 | ), 472 | ( 473 | // an Array of three integers 474 |
&b"*3\r\n:1\r\n:2\r\n:3\r\n"[..], 475 | Ok("1 2 3"), 476 | ), 477 | ( 478 | // a list of four integers and a bulk string 479 | &b"*5\r\n:1\r\n:2\r\n:3\r\n:4\r\n$6\r\nfoobar\r\n"[..], 480 | Ok("1 2 3 4 foobar"), 481 | ), 482 | ( 483 | // an Array containing a Null element ["foo",nil,"bar"] 484 | &b"*3\r\n$3\r\nfoo\r\n$-1\r\n$3\r\nbar\r\n"[..], 485 | Ok("foo (nil) bar"), 486 | ), 487 | ]; 488 | 489 | for t in tests.iter() { 490 | let stream = t.0; 491 | let wanted = &t.1; 492 | let mut buff = Cursor::new(stream); 493 | let res = Frame::parse(&mut buff); 494 | match res { 495 | Ok(x) => { 496 | assert_eq!(x.to_string(), wanted.as_ref().unwrap().to_string()); 497 | } 498 | Err(e) => assert_eq!(wanted.as_ref().unwrap_err().to_string(), e.to_string()), 499 | } 500 | } 501 | } 502 | } 503 | -------------------------------------------------------------------------------- /src/connection.rs: -------------------------------------------------------------------------------- 1 | use crate::frame::{self, Frame}; 2 | 3 | use bytes::{Buf, BytesMut}; 4 | use std::io::{self, Cursor}; 5 | use tokio::io::{AsyncReadExt, AsyncWriteExt, BufWriter}; 6 | use tokio::net::TcpStream; 7 | 8 | /// Send and receive `Frame` values from a remote peer. 9 | /// 10 | /// When implementing networking protocols, a message on that protocol is 11 | /// often composed of several smaller messages known as frames. The purpose of 12 | /// `Connection` is to read and write frames on the underlying `TcpStream`. 13 | /// 14 | /// To read frames, the `Connection` uses an internal buffer, which is filled 15 | /// up until there are enough bytes to create a full frame. Once this happens, 16 | /// the `Connection` creates the frame and returns it to the caller. 17 | /// 18 | /// When sending frames, the frame is first encoded into the write buffer. 19 | /// The contents of the write buffer are then written to the socket. 20 | #[derive(Debug)] 21 | pub struct Connection { 22 | // The `TcpStream`. 
It is decorated with a `BufWriter`, which provides write 23 | // level buffering. The `BufWriter` implementation provided by Tokio is 24 | // sufficient for our needs. 25 | stream: BufWriter, 26 | 27 | // The buffer for reading frames. 28 | buffer: BytesMut, 29 | } 30 | 31 | impl Connection { 32 | /// Create a new `Connection`, backed by `socket`. Read and write buffers 33 | /// are initialized. 34 | pub fn new(socket: TcpStream) -> Connection { 35 | Connection { 36 | stream: BufWriter::new(socket), 37 | // Default to a 4KB read buffer. For the use case of mini telegram, 38 | // this is fine. However, real applications will want to tune this 39 | // value to their specific use case. There is a high likelihood that 40 | // a larger read buffer will work better. 41 | buffer: BytesMut::with_capacity(4 * 1024), 42 | } 43 | } 44 | 45 | /// Read a single `Frame` value from the underlying stream. 46 | /// 47 | /// The function waits until it has retrieved enough data to parse a frame. 48 | /// Any data remaining in the read buffer after the frame has been parsed is 49 | /// kept there for the next call to `read_frame`. 50 | /// 51 | /// # Returns 52 | /// 53 | /// On success, the received frame is returned. If the `TcpStream` 54 | /// is closed in a way that doesn't break a frame in half, it returns 55 | /// `None`. Otherwise, an error is returned. 56 | pub async fn read_frame(&mut self) -> crate::Result> { 57 | loop { 58 | // Attempt to parse a frame from the buffered data. If enough data 59 | // has been buffered, the frame is returned. 60 | if let Some(frame) = self.parse_frame()? { 61 | return Ok(Some(frame)); 62 | } 63 | 64 | // There is not enough buffered data to read a frame. Attempt to 65 | // read more data from the socket. 66 | // 67 | // On success, the number of bytes is returned. `0` indicates "end 68 | // of stream". 69 | if 0 == self.stream.read_buf(&mut self.buffer).await? { 70 | // The remote closed the connection. 
For this to be a clean 71 | // shutdown, there should be no data in the read buffer. If 72 | // there is, this means that the peer closed the socket while 73 | // sending a frame. 74 | if self.buffer.is_empty() { 75 | return Ok(None); 76 | } else { 77 | return Err("connection reset by peer".into()); 78 | } 79 | } 80 | } 81 | } 82 | 83 | /// Tries to parse a frame from the buffer. If the buffer contains enough 84 | /// data, the frame is returned and the data removed from the buffer. If not 85 | /// enough data has been buffered yet, `Ok(None)` is returned. If the 86 | /// buffered data does not represent a valid frame, `Err` is returned. 87 | fn parse_frame(&mut self) -> crate::Result> { 88 | use frame::Error::Incomplete; 89 | 90 | // Cursor is used to track the "current" location in the 91 | // buffer. Cursor also implements `Buf` from the `bytes` crate 92 | // which provides a number of helpful utilities for working 93 | // with bytes. 94 | let mut buf = Cursor::new(&self.buffer[..]); 95 | 96 | // The first step is to check if enough data has been buffered to parse 97 | // a single frame. This step is usually much faster than doing a full 98 | // parse of the frame, and allows us to skip allocating data structures 99 | // to hold the frame data unless we know the full frame has been 100 | // received. 101 | match Frame::check(&mut buf) { 102 | Ok(_) => { 103 | // The `check` function will have advanced the cursor until the 104 | // end of the frame. Since the cursor had position set to zero 105 | // before `Frame::check` was called, we obtain the length of the 106 | // frame by checking the cursor position. 107 | let len = buf.position() as usize; 108 | 109 | // Reset the position to zero before passing the cursor to 110 | // `Frame::parse`. 111 | buf.set_position(0); 112 | 113 | // Parse the frame from the buffer. This allocates the necessary 114 | // structures to represent the frame and returns the frame 115 | // value. 
116 | // 117 | // If the encoded frame representation is invalid, an error is 118 | // returned. This should terminate the **current** connection 119 | // but should not impact any other connected client. 120 | let frame = Frame::parse(&mut buf)?; 121 | 122 | // Discard the parsed data from the read buffer. 123 | // 124 | // When `advance` is called on the read buffer, all of the data 125 | // up to `len` is discarded. The details of how this works is 126 | // left to `BytesMut`. This is often done by moving an internal 127 | // cursor, but it may be done by reallocating and copying data. 128 | self.buffer.advance(len); 129 | 130 | // Return the parsed frame to the caller. 131 | Ok(Some(frame)) 132 | } 133 | // There is not enough data present in the read buffer to parse a 134 | // single frame. We must wait for more data to be received from the 135 | // socket. Reading from the socket will be done in the statement 136 | // after this `match`. 137 | // 138 | // We do not want to return `Err` from here as this "error" is an 139 | // expected runtime condition. 140 | Err(Incomplete) => Ok(None), 141 | // An error was encountered while parsing the frame. The connection 142 | // is now in an invalid state. Returning `Err` from here will result 143 | // in the connection being closed. 144 | Err(e) => Err(e.into()), 145 | } 146 | } 147 | 148 | /// Write a single `Frame` value to the underlying stream. 149 | /// 150 | /// The `Frame` value is written to the socket using the various `write_*` 151 | /// functions provided by `AsyncWrite`. Calling these functions directly on 152 | /// a `TcpStream` is **not** advised, as this will result in a large number of 153 | /// syscalls. However, it is fine to call these functions on a *buffered* 154 | /// write stream. The data will be written to the buffer. Once the buffer is 155 | /// full, it is flushed to the underlying socket. 
156 | pub async fn write_frame(&mut self, frame: &Frame) -> io::Result<()> { 157 | // Arrays are encoded by encoding each entry. All other frame types are 158 | // considered literals. For now, mini-telegram is not able to encode 159 | // recursive frame structures. See below for more details. 160 | match frame { 161 | Frame::Array(val) => { 162 | // Encode the frame type prefix. For an array, it is `*`. 163 | self.stream.write_u8(b'*').await?; 164 | 165 | // Encode the length of the array. 166 | self.write_decimal(val.len() as u64).await?; 167 | 168 | // Iterate and encode each entry in the array. 169 | for entry in &**val { 170 | self.write_value(entry).await?; 171 | } 172 | } 173 | // The frame type is a literal. Encode the value directly. 174 | _ => self.write_value(frame).await?, 175 | } 176 | 177 | // Ensure the encoded frame is written to the socket. The calls above 178 | // are to the buffered stream and writes. Calling `flush` writes the 179 | // remaining contents of the buffer to the socket. 
180 | self.stream.flush().await 181 | } 182 | 183 | /// Write a frame literal to the stream 184 | async fn write_value(&mut self, frame: &Frame) -> io::Result<()> { 185 | match frame { 186 | Frame::Simple(val) => { 187 | self.stream.write_u8(b'+').await?; 188 | self.stream.write_all(val.as_bytes()).await?; 189 | self.stream.write_all(b"\r\n").await?; 190 | } 191 | Frame::Error(val) => { 192 | self.stream.write_u8(b'-').await?; 193 | self.stream.write_all(val.as_bytes()).await?; 194 | self.stream.write_all(b"\r\n").await?; 195 | } 196 | Frame::Integer(val) => { 197 | self.stream.write_u8(b':').await?; 198 | self.write_decimal(*val).await?; 199 | } 200 | Frame::Null => { 201 | self.stream.write_all(b"$-1\r\n").await?; 202 | } 203 | Frame::Bulk(val) => { 204 | let len = val.len(); 205 | 206 | self.stream.write_u8(b'$').await?; 207 | self.write_decimal(len as u64).await?; 208 | self.stream.write_all(val).await?; 209 | self.stream.write_all(b"\r\n").await?; 210 | } 211 | // Encoding an `Array` from within a value cannot be done using a 212 | // recursive strategy. In general, async fns do not support 213 | // recursion. Mini-telegram has not needed to encode nested arrays yet, 214 | // so for now it is skipped. 
215 | Frame::Array(_val) => unreachable!(), 216 | } 217 | 218 | Ok(()) 219 | } 220 | 221 | /// Write a decimal frame to the stream 222 | async fn write_decimal(&mut self, val: u64) -> io::Result<()> { 223 | use std::io::Write; 224 | 225 | // Convert the value to a string 226 | let mut buf = [0u8; 20]; 227 | let mut buf = Cursor::new(&mut buf[..]); 228 | write!(&mut buf, "{}", val)?; 229 | 230 | let pos = buf.position() as usize; 231 | self.stream.write_all(&buf.get_ref()[..pos]).await?; 232 | self.stream.write_all(b"\r\n").await?; 233 | 234 | Ok(()) 235 | } 236 | } 237 | 238 | #[cfg(test)] 239 | mod tests { 240 | use super::*; 241 | use bytes::{BufMut, BytesMut}; 242 | use futures::future::join_all; 243 | use std::str; 244 | use tokio::io::{AsyncReadExt, AsyncWriteExt, BufWriter}; 245 | use tokio::net::{TcpListener, TcpStream}; 246 | use tokio::time::Instant; 247 | use tokio::try_join; 248 | 249 | #[tokio::test] 250 | async fn test_tcp_stream() { 251 | let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); 252 | let addr = listener.local_addr().unwrap(); 253 | 254 | let server = tokio::spawn(async move { 255 | let mut stream = listener.accept().await.unwrap().0; // (stream, addr) 256 | let mut buf = [0]; 257 | let _ = stream.read(&mut buf).await.unwrap(); 258 | assert_eq!(buf[0], 144); 259 | // println!("server terminated!"); 260 | }); 261 | 262 | let client = tokio::spawn(async move { 263 | let mut stream = TcpStream::connect(addr).await.unwrap(); 264 | let _ = stream.write_all(&[144]).await.unwrap(); 265 | }); 266 | 267 | try_join!(server, client).unwrap(); 268 | } 269 | 270 | #[tokio::test] 271 | async fn test_tcp_stream_buf_writer() { 272 | let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); 273 | let addr = listener.local_addr().unwrap(); 274 | const N: usize = 10240; 275 | 276 | let server = tokio::spawn(async move { 277 | let mut handles: Vec> = Vec::new(); 278 | for _ in 0..2 { 279 | let mut stream = listener.accept().await.unwrap().0; 
// (stream, addr) 280 | handles.push(tokio::spawn(async move { 281 | let mut buf = [0; 10]; 282 | for _ in 0..N { 283 | let _ = stream.read(&mut buf).await.unwrap(); 284 | assert_eq!(str::from_utf8(&buf).unwrap(), "some bytes"); 285 | } 286 | // println!("handler thread terminated: {}", t); 287 | })); 288 | } 289 | let _ = join_all(handles).await; 290 | println!("server terminated!"); 291 | }); 292 | 293 | let client_tcp_stream = tokio::spawn(async move { 294 | let mut stream = TcpStream::connect(addr).await.unwrap(); 295 | let now = Instant::now(); 296 | for _ in 0..N { 297 | let _ = stream.write_all(b"some bytes").await.unwrap(); 298 | } 299 | let tcp_stream_time_consumption = now.elapsed(); 300 | // println!("tcp_stream:{:?}", tcp_stream_time_consumption); 301 | tcp_stream_time_consumption 302 | }); 303 | 304 | let client_buf_writer = tokio::spawn(async move { 305 | let stream = TcpStream::connect(addr).await.unwrap(); 306 | // `BufWriter` can improve the speed of programs that make *small* and 307 | // *repeated* write calls to the same file or network socket. It does not 308 | // help when writing very large amounts at once, or writing just one or a few 309 | // times. It also provides no advantage when writing to a destination that is 310 | // in memory, like a `Vec`. 
311 | let mut stream = BufWriter::new(stream); 312 | let now = Instant::now(); 313 | for _ in 0..N { 314 | let _ = stream.write_all(b"some bytes").await.unwrap(); 315 | } 316 | let buf_writer_time_consumption = now.elapsed(); 317 | // println!("buf_writer:{:?}", buf_writer_time_consumption); 318 | buf_writer_time_consumption 319 | }); 320 | 321 | let (_, tcp_stream_time_consumption, buf_writer_time_consumption) = 322 | try_join!(server, client_tcp_stream, client_buf_writer).unwrap(); 323 | 324 | assert!(buf_writer_time_consumption < tcp_stream_time_consumption); 325 | } 326 | 327 | #[tokio::test] 328 | async fn test_bytes_mut_growth() { 329 | // BytesMut’s BufMut implementation will implicitly grow its buffer 330 | // as necessary. However, explicitly reserving the required space 331 | // up-front before a series of inserts will be more efficient. 332 | let mut buf = BytesMut::with_capacity(10); 333 | let addr_a = format!("{:p}", buf.as_ptr()); 334 | buf.put(&b"yumcoder"[..]); 335 | let addr_b = format!("{:p}", buf.as_ptr()); 336 | assert_eq!(addr_a, addr_b); 337 | buf.put(&b"more content to expand the current buffer!"[..]); 338 | let addr_c = format!("{:p}", buf.as_ptr()); 339 | assert_ne!(addr_c, addr_b); 340 | } 341 | 342 | #[tokio::test] 343 | async fn test_read_frame() { 344 | let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); 345 | let addr = listener.local_addr().unwrap(); 346 | 347 | let server = tokio::spawn(async move { 348 | let stream = listener.accept().await.unwrap().0; // (stream, addr) 349 | let mut connection = Connection::new(stream); 350 | let cmd = connection.read_frame().await.unwrap(); 351 | if let Some(x) = cmd { 352 | assert_eq!(x.to_string(), "OK"); 353 | } 354 | 355 | let cmd = connection.read_frame().await.unwrap_err(); 356 | let err = frame::Error::from("protocol error; invalid frame type byte `33`"); 357 | assert_eq!(cmd.to_string(), err.to_string()); 358 | // println!("server terminated!"); 359 | }); 360 | 361 | let client 
= tokio::spawn(async move { 362 | let mut stream = TcpStream::connect(addr).await.unwrap(); 363 | let _ = stream.write_all(&b"+OK\r\n"[..]).await.unwrap(); 364 | let _ = stream.write_all(&b"!"[..]).await.unwrap(); 365 | }); 366 | 367 | try_join!(server, client).unwrap(); 368 | } 369 | 370 | #[tokio::test] 371 | async fn test_write_frame() { 372 | let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); 373 | let addr = listener.local_addr().unwrap(); 374 | 375 | let server = tokio::spawn(async move { 376 | let stream = listener.accept().await.unwrap().0; // (stream, addr) 377 | let mut connection = Connection::new(stream); 378 | let cmd = connection.read_frame().await.unwrap(); 379 | if let Some(x) = cmd { 380 | assert_eq!(x.to_string(), "OK"); 381 | connection.write_frame(&x).await.unwrap(); 382 | } 383 | // println!("server terminated!"); 384 | }); 385 | 386 | let client = tokio::spawn(async move { 387 | let mut stream = TcpStream::connect(addr).await.unwrap(); 388 | // for simplicity using Connection only on server side 389 | let _ = stream.write_all(&b"+OK\r\n"[..]).await.unwrap(); 390 | let mut buf = [0; 5]; 391 | let _ = stream.read(&mut buf).await.unwrap(); 392 | assert_eq!(str::from_utf8(&buf).unwrap(), "+OK\r\n"); 393 | }); 394 | 395 | try_join!(server, client).unwrap(); 396 | } 397 | } 398 | -------------------------------------------------------------------------------- /src/client.rs: -------------------------------------------------------------------------------- 1 | //! Minimal MTProto client implementation 2 | //! 3 | //! Provides an async connect and methods for issuing the supported commands. 
4 | 5 | use crate::cmd::{Get, Publish, Set, Subscribe, Unsubscribe}; 6 | use crate::{Connection, Frame}; 7 | 8 | use async_stream::try_stream; 9 | use bytes::Bytes; 10 | use std::io::{Error, ErrorKind}; 11 | use std::time::Duration; 12 | use tokio::net::{TcpStream, ToSocketAddrs}; 13 | use tokio_stream::Stream; 14 | use tracing::{debug, instrument}; 15 | 16 | /// Established connection with a MTProto server. 17 | /// 18 | /// Backed by a single `TcpStream`, `Client` provides basic network client 19 | /// functionality (no pooling, retrying, ...). Connections are established using 20 | /// the [`connect`](fn@connect) function. 21 | /// 22 | /// Requests are issued using the various methods of `Client`. 23 | pub struct Client { 24 | /// The TCP connection decorated with the MTProto protocol encoder / decoder 25 | /// implemented using a buffered `TcpStream`. 26 | /// 27 | /// When `Listener` receives an inbound connection, the `TcpStream` is 28 | /// passed to `Connection::new`, which initializes the associated buffers. 29 | /// `Connection` allows the handler to operate at the "frame" level and keep 30 | /// the byte level protocol parsing details encapsulated in `Connection`. 31 | connection: Connection, 32 | } 33 | 34 | /// A client that has entered pub/sub mode. 35 | /// 36 | /// Once clients subscribe to a channel, they may only perform pub/sub related 37 | /// commands. The `Client` type is transitioned to a `Subscriber` type in order 38 | /// to prevent non-pub/sub methods from being called. 39 | pub struct Subscriber { 40 | /// The subscribed client. 41 | client: Client, 42 | 43 | /// The set of channels to which the `Subscriber` is currently subscribed. 44 | subscribed_channels: Vec, 45 | } 46 | 47 | /// A message received on a subscribed channel. 48 | #[derive(Debug, Clone)] 49 | pub struct Message { 50 | pub channel: String, 51 | pub content: Bytes, 52 | } 53 | 54 | /// Establish a connection with the MTProto server located at `addr`. 
55 | /// 56 | /// `addr` may be any type that can be asynchronously converted to a 57 | /// `SocketAddr`. This includes `SocketAddr` and strings. The `ToSocketAddrs` 58 | /// trait is the Tokio version and not the `std` version. 59 | /// 60 | /// # Examples 61 | /// 62 | /// ```no_run 63 | /// use mini_telegram::client; 64 | /// 65 | /// #[tokio::main] 66 | /// async fn main() { 67 | /// let client = match client::connect("localhost:6379").await { 68 | /// Ok(client) => client, 69 | /// Err(_) => panic!("failed to establish connection"), 70 | /// }; 71 | /// # drop(client); 72 | /// } 73 | /// ``` 74 | /// 75 | pub async fn connect(addr: T) -> crate::Result { 76 | // The `addr` argument is passed directly to `TcpStream::connect`. This 77 | // performs any asynchronous DNS lookup and attempts to establish the TCP 78 | // connection. An error at either step returns an error, which is then 79 | // bubbled up to the caller of `mini_telegram` connect. 80 | let socket = TcpStream::connect(addr).await?; 81 | 82 | // Initialize the connection state. This allocates read/write buffers to 83 | // perform telegram protocol frame parsing. 84 | let connection = Connection::new(socket); 85 | 86 | Ok(Client { connection }) 87 | } 88 | 89 | impl Client { 90 | /// Get the value of key. 91 | /// 92 | /// If the key does not exist the special value `None` is returned. 93 | /// 94 | /// # Examples 95 | /// 96 | /// Demonstrates basic usage. 97 | /// 98 | /// ```no_run 99 | /// use mini_telegram::client; 100 | /// 101 | /// #[tokio::main] 102 | /// async fn main() { 103 | /// let mut client = client::connect("localhost:6379").await.unwrap(); 104 | /// 105 | /// let val = client.get("foo").await.unwrap(); 106 | /// println!("Got = {:?}", val); 107 | /// } 108 | /// ``` 109 | #[instrument(skip(self))] 110 | pub async fn get(&mut self, key: &str) -> crate::Result> { 111 | // Create a `Get` command for the `key` and convert it to a frame.
112 | let frame = Get::new(key).into_frame(); 113 | 114 | debug!(request = ?frame); 115 | 116 | // Write the frame to the socket. This writes the full frame to the 117 | // socket, waiting if necessary. 118 | self.connection.write_frame(&frame).await?; 119 | 120 | // Wait for the response from the server 121 | // 122 | // Both `Simple` and `Bulk` frames are accepted. `Null` represents the 123 | // key not being present and `None` is returned. 124 | match self.read_response().await? { 125 | Frame::Simple(value) => Ok(Some(value.into())), 126 | Frame::Bulk(value) => Ok(Some(value)), 127 | Frame::Null => Ok(None), 128 | frame => Err(frame.to_error()), 129 | } 130 | } 131 | 132 | /// Set `key` to hold the given `value`. 133 | /// 134 | /// The `value` is associated with `key` until it is overwritten by the next 135 | /// call to `set` or it is removed. 136 | /// 137 | /// If key already holds a value, it is overwritten. Any previous time to 138 | /// live associated with the key is discarded on successful SET operation. 139 | /// 140 | /// # Examples 141 | /// 142 | /// Demonstrates basic usage. 143 | /// 144 | /// ```no_run 145 | /// use mini_telegram::client; 146 | /// 147 | /// #[tokio::main] 148 | /// async fn main() { 149 | /// let mut client = client::connect("localhost:6379").await.unwrap(); 150 | /// 151 | /// client.set("foo", "bar".into()).await.unwrap(); 152 | /// 153 | /// // Getting the value immediately works 154 | /// let val = client.get("foo").await.unwrap().unwrap(); 155 | /// assert_eq!(val, "bar"); 156 | /// } 157 | /// ``` 158 | #[instrument(skip(self))] 159 | pub async fn set(&mut self, key: &str, value: Bytes) -> crate::Result<()> { 160 | // Create a `Set` command and pass it to `set_cmd`. A separate method is 161 | // used to set a value with an expiration. The common parts of both 162 | // functions are implemented by `set_cmd`. 163 | self.set_cmd(Set::new(key, value, None)).await 164 | } 165 | 166 | /// Set `key` to hold the given `value`. 
The value expires after `expiration`. 167 | /// 168 | /// The `value` is associated with `key` until one of the following: 169 | /// - it expires. 170 | /// - it is overwritten by the next call to `set`. 171 | /// - it is removed. 172 | /// 173 | /// If key already holds a value, it is overwritten. Any previous time to 174 | /// live associated with the key is discarded on a successful SET operation. 175 | /// 176 | /// # Examples 177 | /// 178 | /// Demonstrates basic usage. This example is not **guaranteed** to always 179 | /// work as it relies on time based logic and assumes the client and server 180 | /// stay relatively synchronized in time. The real world tends to not be so 181 | /// favorable. 182 | /// 183 | /// ```no_run 184 | /// use mini_telegram::client; 185 | /// use tokio::time; 186 | /// use std::time::Duration; 187 | /// 188 | /// #[tokio::main] 189 | /// async fn main() { 190 | /// let ttl = Duration::from_millis(500); 191 | /// let mut client = client::connect("localhost:6379").await.unwrap(); 192 | /// 193 | /// client.set_expires("foo", "bar".into(), ttl).await.unwrap(); 194 | /// 195 | /// // Getting the value immediately works 196 | /// let val = client.get("foo").await.unwrap().unwrap(); 197 | /// assert_eq!(val, "bar"); 198 | /// 199 | /// // Wait for the TTL to expire 200 | /// time::sleep(ttl).await; 201 | /// 202 | /// let val = client.get("foo").await.unwrap(); 203 | /// assert!(val.is_none()); 204 | /// } 205 | /// ``` 206 | #[instrument(skip(self))] 207 | pub async fn set_expires( 208 | &mut self, 209 | key: &str, 210 | value: Bytes, 211 | expiration: Duration, 212 | ) -> crate::Result<()> { 213 | // Create a `Set` command and pass it to `set_cmd`. A separate method is 214 | // used to set a value with an expiration. The common parts of both 215 | // functions are implemented by `set_cmd`. 216 | self.set_cmd(Set::new(key, value, Some(expiration))).await 217 | } 218 | 219 | /// The core `SET` logic, used by both `set` and `set_expires`.
220 | async fn set_cmd(&mut self, cmd: Set) -> crate::Result<()> { 221 | // Convert the `Set` command into a frame 222 | let frame = cmd.into_frame(); 223 | 224 | debug!(request = ?frame); 225 | 226 | // Write the frame to the socket. This writes the full frame to the 227 | // socket, waiting if necessary. 228 | self.connection.write_frame(&frame).await?; 229 | 230 | // Wait for the response from the server. On success, the server 231 | // responds simply with `OK`. Any other response indicates an error. 232 | match self.read_response().await? { 233 | Frame::Simple(response) if response == "OK" => Ok(()), 234 | frame => Err(frame.to_error()), 235 | } 236 | } 237 | 238 | /// Posts `message` to the given `channel`. 239 | /// 240 | /// Returns the number of subscribers currently listening on the channel. 241 | /// There is no guarantee that these subscribers receive the message as they 242 | /// may disconnect at any time. 243 | /// 244 | /// # Examples 245 | /// 246 | /// Demonstrates basic usage. 247 | /// 248 | /// ```no_run 249 | /// use mini_telegram::client; 250 | /// 251 | /// #[tokio::main] 252 | /// async fn main() { 253 | /// let mut client = client::connect("localhost:6379").await.unwrap(); 254 | /// 255 | /// let val = client.publish("foo", "bar".into()).await.unwrap(); 256 | /// println!("Got = {:?}", val); 257 | /// } 258 | /// ``` 259 | #[instrument(skip(self))] 260 | pub async fn publish(&mut self, channel: &str, message: Bytes) -> crate::Result { 261 | // Convert the `Publish` command into a frame 262 | let frame = Publish::new(channel, message).into_frame(); 263 | 264 | debug!(request = ?frame); 265 | 266 | // Write the frame to the socket 267 | self.connection.write_frame(&frame).await?; 268 | 269 | // Read the response 270 | match self.read_response().await? { 271 | Frame::Integer(response) => Ok(response), 272 | frame => Err(frame.to_error()), 273 | } 274 | } 275 | 276 | /// Subscribes the client to the specified channels. 
277 | /// 278 | /// Once a client issues a subscribe command, it may no longer issue any 279 | /// non-pub/sub commands. The function consumes `self` and returns a `Subscriber`. 280 | /// 281 | /// The `Subscriber` value is used to receive messages as well as manage the 282 | /// list of channels the client is subscribed to. 283 | #[instrument(skip(self))] 284 | pub async fn subscribe(mut self, channels: Vec) -> crate::Result { 285 | // Issue the subscribe command to the server and wait for confirmation. 286 | // The client will then have been transitioned into the "subscriber" 287 | // state and may only issue pub/sub commands from that point on. 288 | self.subscribe_cmd(&channels).await?; 289 | 290 | // Return the `Subscriber` type 291 | Ok(Subscriber { 292 | client: self, 293 | subscribed_channels: channels, 294 | }) 295 | } 296 | 297 | /// The core `SUBSCRIBE` logic, used by misc subscribe fns 298 | async fn subscribe_cmd(&mut self, channels: &[String]) -> crate::Result<()> { 299 | // Convert the `Subscribe` command into a frame 300 | let frame = Subscribe::new(&channels).into_frame(); 301 | 302 | debug!(request = ?frame); 303 | 304 | // Write the frame to the socket 305 | self.connection.write_frame(&frame).await?; 306 | 307 | // For each channel being subscribed to, the server responds with a 308 | // message confirming subscription to that channel. 309 | for channel in channels { 310 | // Read the response 311 | let response = self.read_response().await?; 312 | 313 | // Verify it is confirmation of subscription. 314 | match response { 315 | Frame::Array(ref frame) => match frame.as_slice() { 316 | // The server responds with an array frame in the form of: 317 | // 318 | // ``` 319 | // [ "subscribe", channel, num-subscribed ] 320 | // ``` 321 | // 322 | // where channel is the name of the channel and 323 | // num-subscribed is the number of channels that the client 324 | // is currently subscribed to. 325 | [subscribe, schannel, ..] 
326 | if *subscribe == "subscribe" && *schannel == channel => {} 327 | _ => return Err(response.to_error()), 328 | }, 329 | frame => return Err(frame.to_error()), 330 | }; 331 | } 332 | 333 | Ok(()) 334 | } 335 | 336 | /// Reads a response frame from the socket. 337 | /// 338 | /// If an `Error` frame is received, it is converted to `Err`. 339 | async fn read_response(&mut self) -> crate::Result { 340 | let response = self.connection.read_frame().await?; 341 | 342 | debug!(?response); 343 | 344 | match response { 345 | // Error frames are converted to `Err` 346 | Some(Frame::Error(msg)) => Err(msg.into()), 347 | Some(frame) => Ok(frame), 348 | None => { 349 | // Receiving `None` here indicates the server has closed the 350 | // connection without sending a frame. This is unexpected and is 351 | // represented as a "connection reset by peer" error. 352 | let err = Error::new(ErrorKind::ConnectionReset, "connection reset by server"); 353 | 354 | Err(err.into()) 355 | } 356 | } 357 | } 358 | } 359 | 360 | impl Subscriber { 361 | /// Returns the set of channels currently subscribed to. 362 | pub fn get_subscribed(&self) -> &[String] { 363 | &self.subscribed_channels 364 | } 365 | 366 | /// Receive the next message published on a subscribed channel, waiting if 367 | /// necessary. 368 | /// 369 | /// `None` indicates the subscription has been terminated. 370 | pub async fn next_message(&mut self) -> crate::Result> { 371 | match self.client.connection.read_frame().await? 
{ 372 | Some(mframe) => { 373 | debug!(?mframe); 374 | 375 | match mframe { 376 | Frame::Array(ref frame) => match frame.as_slice() { 377 | [message, channel, content] if *message == "message" => Ok(Some(Message { 378 | channel: channel.to_string(), 379 | content: Bytes::from(content.to_string()), 380 | })), 381 | _ => Err(mframe.to_error()), 382 | }, 383 | frame => Err(frame.to_error()), 384 | } 385 | } 386 | None => Ok(None), 387 | } 388 | } 389 | 390 | /// Convert the subscriber into a `Stream` yielding new messages published 391 | /// on subscribed channels. 392 | /// 393 | /// `Subscriber` does not implement stream itself as doing so with safe code 394 | /// is non trivial. The usage of async/await would require a manual Stream 395 | /// implementation to use `unsafe` code. Instead, a conversion function is 396 | /// provided and the returned stream is implemented with the help of the 397 | /// `async-stream` crate. 398 | pub fn into_stream(mut self) -> impl Stream> { 399 | // Uses the `try_stream` macro from the `async-stream` crate. Generators 400 | // are not stable in Rust. The crate uses a macro to simulate generators 401 | // on top of async/await. There are limitations, so read the 402 | // documentation there. 403 | try_stream! { 404 | while let Some(message) = self.next_message().await? { 405 | yield message; 406 | } 407 | } 408 | } 409 | 410 | /// Subscribe to a list of new channels 411 | #[instrument(skip(self))] 412 | pub async fn subscribe(&mut self, channels: &[String]) -> crate::Result<()> { 413 | // Issue the subscribe command 414 | self.client.subscribe_cmd(channels).await?; 415 | 416 | // Update the set of subscribed channels. 
417 | self.subscribed_channels 418 | .extend(channels.iter().map(Clone::clone)); 419 | 420 | Ok(()) 421 | } 422 | 423 | /// Unsubscribe to a list of new channels 424 | #[instrument(skip(self))] 425 | pub async fn unsubscribe(&mut self, channels: &[String]) -> crate::Result<()> { 426 | let frame = Unsubscribe::new(&channels).into_frame(); 427 | 428 | debug!(request = ?frame); 429 | 430 | // Write the frame to the socket 431 | self.client.connection.write_frame(&frame).await?; 432 | 433 | // if the input channel list is empty, server acknowledges as unsubscribing 434 | // from all subscribed channels, so we assert that the unsubscribe list received 435 | // matches the client subscribed one 436 | let num = if channels.is_empty() { 437 | self.subscribed_channels.len() 438 | } else { 439 | channels.len() 440 | }; 441 | 442 | // Read the response 443 | for _ in 0..num { 444 | let response = self.client.read_response().await?; 445 | 446 | match response { 447 | Frame::Array(ref frame) => match frame.as_slice() { 448 | [unsubscribe, channel, ..] if *unsubscribe == "unsubscribe" => { 449 | let len = self.subscribed_channels.len(); 450 | 451 | if len == 0 { 452 | // There must be at least one channel 453 | return Err(response.to_error()); 454 | } 455 | 456 | // unsubscribed channel should exist in the subscribed list at this point 457 | self.subscribed_channels.retain(|c| *channel != &c[..]); 458 | 459 | // Only a single channel should be removed from the 460 | // list of subscribed channels. 461 | if self.subscribed_channels.len() != len - 1 { 462 | return Err(response.to_error()); 463 | } 464 | } 465 | _ => return Err(response.to_error()), 466 | }, 467 | frame => return Err(frame.to_error()), 468 | }; 469 | } 470 | 471 | Ok(()) 472 | } 473 | } 474 | -------------------------------------------------------------------------------- /src/server.rs: -------------------------------------------------------------------------------- 1 | //! 
Minimal MTProto server implementation 2 | //! 3 | //! Provides an async `run` function that listens for inbound connections, 4 | //! spawning a task per connection. 5 | 6 | use crate::{Command, Connection, Db, DbDropGuard, Shutdown}; 7 | 8 | use std::future::Future; 9 | use std::sync::Arc; 10 | use tokio::net::{TcpListener, TcpStream}; 11 | use tokio::sync::{broadcast, mpsc, Semaphore}; 12 | use tokio::time::{self, Duration}; 13 | use tracing::{debug, error, info, instrument}; 14 | 15 | /// Server listener state. Created in the `run` call. It includes a `run` method 16 | /// which performs the TCP listening and initialization of per-connection state. 17 | #[derive(Debug)] 18 | struct Listener { 19 | /// Shared database handle. 20 | /// 21 | /// Contains the key / value store as well as the broadcast channels for 22 | /// pub/sub. 23 | /// 24 | /// This holds a wrapper around an `Arc`. The internal `Db` can be 25 | /// retrieved and passed into the per connection state (`Handler`). 26 | db_holder: DbDropGuard, 27 | 28 | /// TCP listener supplied by the `run` caller. 29 | listener: TcpListener, 30 | 31 | /// Limit the max number of connections. 32 | /// 33 | /// A `Semaphore` is used to limit the max number of connections. Before 34 | /// attempting to accept a new connection, a permit is acquired from the 35 | /// semaphore. If none are available, the listener waits for one. 36 | /// 37 | /// When handlers complete processing a connection, the permit is returned 38 | /// to the semaphore. 39 | limit_connections: Arc, 40 | 41 | /// Broadcasts a shutdown signal to all active connections. 42 | /// 43 | /// The initial `shutdown` trigger is provided by the `run` caller. The 44 | /// server is responsible for gracefully shutting down active connections. 45 | /// When a connection task is spawned, it is passed a broadcast receiver 46 | /// handle. When a graceful shutdown is initiated, a `()` value is sent via 47 | /// the broadcast::Sender. 
Each active connection receives it, reaches a 48 | /// safe terminal state, and completes the task. 49 | notify_shutdown: broadcast::Sender<()>, 50 | 51 | /// Used as part of the graceful shutdown process to wait for client 52 | /// connections to complete processing. 53 | /// 54 | /// Tokio channels are closed once all `Sender` handles go out of scope. 55 | /// When a channel is closed, the receiver receives `None`. This is 56 | /// leveraged to detect all connection handlers completing. When a 57 | /// connection handler is initialized, it is assigned a clone of 58 | /// `shutdown_complete_tx`. When the listener shuts down, it drops the 59 | /// sender held by this `shutdown_complete_tx` field. Once all handler tasks 60 | /// complete, all clones of the `Sender` are also dropped. This results in 61 | /// `shutdown_complete_rx.recv()` completing with `None`. At this point, it 62 | /// is safe to exit the server process. 63 | shutdown_complete_rx: mpsc::Receiver<()>, 64 | shutdown_complete_tx: mpsc::Sender<()>, 65 | } 66 | 67 | /// Per-connection handler. Reads requests from `connection` and applies the 68 | /// commands to `db`. 69 | #[derive(Debug)] 70 | struct Handler { 71 | /// Shared database handle. 72 | /// 73 | /// When a command is received from `connection`, it is applied with `db`. 74 | /// The implementation of the command is in the `cmd` module. Each command 75 | /// will need to interact with `db` in order to complete the work. 76 | db: Db, 77 | 78 | /// The TCP connection decorated with the MTProto protocol encoder / decoder 79 | /// implemented using a buffered `TcpStream`. 80 | /// 81 | /// When `Listener` receives an inbound connection, the `TcpStream` is 82 | /// passed to `Connection::new`, which initializes the associated buffers. 83 | /// `Connection` allows the handler to operate at the "frame" level and keep 84 | /// the byte level protocol parsing details encapsulated in `Connection`. 
85 | connection: Connection, 86 | 87 | /// Max connection semaphore. 88 | /// 89 | /// When the handler is dropped, a permit is returned to this semaphore. If 90 | /// the listener is waiting for connections to close, it will be notified of 91 | /// the newly available permit and resume accepting connections. 92 | limit_connections: Arc, 93 | 94 | /// Listen for shutdown notifications. 95 | /// 96 | /// A wrapper around the `broadcast::Receiver` paired with the sender in 97 | /// `Listener`. The connection handler processes requests from the 98 | /// connection until the peer disconnects **or** a shutdown notification is 99 | /// received from `shutdown`. In the latter case, any in-flight work being 100 | /// processed for the peer is continued until it reaches a safe state, at 101 | /// which point the connection is terminated. 102 | shutdown: Shutdown, 103 | 104 | /// Not used directly. Instead, when `Handler` is dropped...? 105 | _shutdown_complete: mpsc::Sender<()>, 106 | } 107 | 108 | /// Maximum number of concurrent connections the mini-telegram server will accept. 109 | /// 110 | /// When this limit is reached, the server will stop accepting connections until 111 | /// an active connection terminates. 112 | /// 113 | /// A real application will want to make this value configurable, but for this 114 | /// example, it is hard coded. 115 | /// 116 | /// This is also set to a pretty low value to discourage using this in 117 | /// production (you'd think that all the disclaimers would make it obvious that 118 | /// this is not a serious project... but I thought that about mini-http as 119 | /// well). 120 | const MAX_CONNECTIONS: usize = 250; 121 | 122 | /// Run the mini-telegram server. 123 | /// 124 | /// Accepts connections from the supplied listener. For each inbound connection, 125 | /// a task is spawned to handle that connection. The server runs until the 126 | /// `shutdown` future completes, at which point the server shuts down 127 | /// gracefully. 
128 | /// 129 | /// `tokio::signal::ctrl_c()` can be used as the `shutdown` argument. This will 130 | /// listen for a SIGINT signal. 131 | pub async fn run(listener: TcpListener, shutdown: impl Future) { 132 | // When the provided `shutdown` future completes, we must send a shutdown 133 | // message to all active connections. We use a broadcast channel for this 134 | // purpose. The call below ignores the receiver of the broadcast pair, and when 135 | // a receiver is needed, the subscribe() method on the sender is used to create 136 | // one. 137 | let (notify_shutdown, _) = broadcast::channel(1); 138 | let (shutdown_complete_tx, shutdown_complete_rx) = mpsc::channel(1); 139 | 140 | // Initialize the listener state 141 | let mut server = Listener { 142 | listener, 143 | db_holder: DbDropGuard::new(), 144 | limit_connections: Arc::new(Semaphore::new(MAX_CONNECTIONS)), 145 | notify_shutdown, 146 | shutdown_complete_tx, 147 | shutdown_complete_rx, 148 | }; 149 | 150 | // Concurrently run the server and listen for the `shutdown` signal. The 151 | // server task runs until an error is encountered, so under normal 152 | // circumstances, this `select!` statement runs until the `shutdown` signal 153 | // is received. 154 | // 155 | // `select!` statements are written in the form of: 156 | // 157 | // ``` 158 | // = => 159 | // ``` 160 | // 161 | // All `` statements are executed concurrently. Once the **first** 162 | // op completes, its associated `` is 163 | // performed. 164 | // 165 | // The `select! macro is a foundational building block for writing 166 | // asynchronous Rust. See the API docs for more details: 167 | // 168 | // https://docs.rs/tokio/*/tokio/macro.select.html 169 | tokio::select! { 170 | res = server.run() => { 171 | // If an error is received here, accepting connections from the TCP 172 | // listener failed multiple times and the server is giving up and 173 | // shutting down. 
174 | // 175 | // Errors encountered when handling individual connections do not 176 | // bubble up to this point. 177 | if let Err(err) = res { 178 | error!(cause = %err, "failed to accept"); 179 | } 180 | } 181 | _ = shutdown => { 182 | // The shutdown signal has been received. 183 | info!("shutting down"); 184 | } 185 | } 186 | 187 | // Extract the `shutdown_complete` receiver and transmitter, then 188 | // explicitly drop the transmitter (`shutdown_complete_tx`). This is important, as the 189 | // `.await` below would otherwise never complete. 190 | let Listener { 191 | mut shutdown_complete_rx, 192 | shutdown_complete_tx, 193 | notify_shutdown, 194 | .. 195 | } = server; 196 | 197 | // When `notify_shutdown` is dropped, all tasks which have `subscribe`d will 198 | // receive the shutdown signal and can exit 199 | drop(notify_shutdown); 200 | // Drop final `Sender` so the `Receiver` below can complete 201 | drop(shutdown_complete_tx); 202 | 203 | // Wait for all active connections to finish processing. As the `Sender` 204 | // handle held by the listener has been dropped above, the only remaining 205 | // `Sender` instances are held by connection handler tasks. When those drop, 206 | // the `mpsc` channel will close and `recv()` will return `None`. 207 | let _ = shutdown_complete_rx.recv().await; 208 | } 209 | 210 | impl Listener { 211 | /// Run the server 212 | /// 213 | /// Listen for inbound connections. For each inbound connection, spawn a 214 | /// task to process that connection. 215 | /// 216 | /// # Errors 217 | /// 218 | /// Returns `Err` if accepting returns an error. This can happen for a 219 | /// number of reasons that resolve over time. For example, if the underlying 220 | /// operating system has reached an internal limit for max number of 221 | /// sockets, accept will fail. 222 | /// 223 | /// The process is not able to detect when a transient error resolves 224 | /// itself. 
One strategy for handling this is to implement a back off 225 | /// strategy, which is what we do here. 226 | async fn run(&mut self) -> crate::Result<()> { 227 | info!("accepting inbound connections"); 228 | 229 | loop { 230 | // Wait for a permit to become available 231 | // 232 | // `acquire` returns a permit that is bound via a lifetime to the 233 | // semaphore. When the permit value is dropped, it is automatically 234 | // returned to the semaphore. This is convenient in many cases. 235 | // However, in this case, the permit must be returned in a different 236 | // task than it is acquired in (the handler task). To do this, we 237 | // "forget" the permit, which drops the permit value **without** 238 | // incrementing the semaphore's permits. Then, in the handler task 239 | // we manually add a new permit when processing completes. 240 | // 241 | // `acquire()` returns `Err` when the semaphore has been closed. We 242 | // don't ever close the semaphore, so `unwrap()` is safe. 243 | self.limit_connections.acquire().await.unwrap().forget(); 244 | 245 | // Accept a new socket. This will attempt to perform error handling. 246 | // The `accept` method internally attempts to recover errors, so an 247 | // error here is non-recoverable. 248 | let socket = self.accept().await?; 249 | 250 | // Create the necessary per-connection handler state. 251 | let mut handler = Handler { 252 | // Get a handle to the shared database. 253 | db: self.db_holder.db(), 254 | 255 | // Initialize the connection state. This allocates read/write 256 | // buffers to perform MTProto protocol frame parsing. 257 | connection: Connection::new(socket), 258 | 259 | // The connection state needs a handle to the max connections 260 | // semaphore. When the handler is done processing the 261 | // connection, a permit is added back to the semaphore. 262 | limit_connections: self.limit_connections.clone(), 263 | 264 | // Receive shutdown notifications. 
265 | shutdown: Shutdown::new(self.notify_shutdown.subscribe()), 266 | 267 | // Notifies the receiver half once all clones are 268 | // dropped. 269 | _shutdown_complete: self.shutdown_complete_tx.clone(), 270 | }; 271 | 272 | // Spawn a new task to process the connections. Tokio tasks are like 273 | // asynchronous green threads and are executed concurrently. 274 | tokio::spawn(async move { 275 | // Process the connection. If an error is encountered, log it. 276 | if let Err(err) = handler.run().await { 277 | error!(cause = ?err, "connection error"); 278 | } 279 | // call drop for handler 280 | }); 281 | } 282 | } 283 | 284 | /// Accept an inbound connection. 285 | /// 286 | /// Errors are handled by backing off and retrying. An exponential backoff 287 | /// strategy is used. After the first failure, the task waits for 1 second. 288 | /// After the second failure, the task waits for 2 seconds. Each subsequent 289 | /// failure doubles the wait time. If accepting fails on the 6th try after 290 | /// waiting for 64 seconds, then this function returns with an error. 291 | async fn accept(&mut self) -> crate::Result { 292 | let mut backoff = 1; 293 | 294 | // Try to accept a few times 295 | loop { 296 | // Perform the accept operation. If a socket is successfully 297 | // accepted, return it. Otherwise, save the error. 298 | match self.listener.accept().await { 299 | Ok((socket, _)) => return Ok(socket), 300 | Err(err) => { 301 | if backoff > 64 { 302 | // Accept has failed too many times. Return the error. 303 | return Err(err.into()); 304 | } 305 | } 306 | } 307 | 308 | // Pause execution until the back off period elapses. 309 | time::sleep(Duration::from_secs(backoff)).await; 310 | 311 | // Double the back off 312 | backoff *= 2; 313 | } 314 | } 315 | } 316 | 317 | impl Handler { 318 | /// Process a single connection. 319 | /// 320 | /// Request frames are read from the socket and processed. Responses are 321 | /// written back to the socket. 
322 | /// 323 | /// 324 | /// When the shutdown signal is received, the connection is processed until 325 | /// it reaches a safe state, at which point it is terminated. 326 | #[instrument(skip(self))] 327 | async fn run(&mut self) -> crate::Result<()> { 328 | // As long as the shutdown signal has not been received, try to read a 329 | // new request frame. 330 | while !self.shutdown.is_shutdown() { 331 | // While reading a request frame, also listen for the shutdown 332 | // signal. 333 | let maybe_frame = tokio::select! { 334 | res = self.connection.read_frame() => res?, 335 | _ = self.shutdown.recv() => { 336 | // If a shutdown signal is received, return from `run`. 337 | // This will result in the task terminating. 338 | return Ok(()); 339 | } 340 | }; 341 | 342 | // If `None` is returned from `read_frame()` then the peer closed 343 | // the socket. There is no further work to do and the task can be 344 | // terminated. 345 | let frame = match maybe_frame { 346 | Some(frame) => frame, 347 | None => return Ok(()), 348 | }; 349 | 350 | // Convert the frame into a command struct. This returns an 351 | // error if the frame is not a valid command or it is an 352 | // unsupported command. 353 | let cmd = Command::from_frame(frame)?; 354 | 355 | // Logs the `cmd` object. The syntax here is a shorthand provided by 356 | // the `tracing` crate. It can be thought of as similar to: 357 | // 358 | // ``` 359 | // debug!(cmd = format!("{:?}", cmd)); 360 | // ``` 361 | // 362 | // `tracing` provides structured logging, so information is "logged" 363 | // as key-value pairs. 364 | debug!(?cmd); 365 | 366 | // Perform the work needed to apply the command. This may mutate the 367 | // database state as a result. 368 | // 369 | // The connection is passed into the apply function which allows the 370 | // command to write response frames directly to the connection. In 371 | // the case of pub/sub, multiple frames may be sent back to the 372 | // peer. 
373 | cmd.apply(&self.db, &mut self.connection, &mut self.shutdown) 374 | .await?; 375 | } 376 | 377 | Ok(()) 378 | } 379 | } 380 | 381 | impl Drop for Handler { 382 | fn drop(&mut self) { 383 | // Add a permit back to the semaphore. 384 | // 385 | // Doing so unblocks the listener if the max number of 386 | // connections has been reached. 387 | // 388 | // This is done in a `Drop` implementation in order to guarantee that 389 | // the permit is added even if the task handling the connection panics. 390 | // If `add_permits` were called at the end of the `run` function and some 391 | // bug caused a panic, the permit would never be returned to the 392 | // semaphore. 393 | self.limit_connections.add_permits(1); 394 | } 395 | } 396 | 397 | #[cfg(test)] 398 | mod tests { 399 | use super::*; 400 | 401 | #[tokio::test] 402 | async fn test_note_about_move_in_run() { 403 | let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); 404 | 405 | let (notify_shutdown, _) = broadcast::channel(1); 406 | let (shutdown_complete_tx, shutdown_complete_rx) = mpsc::channel(1); 407 | 408 | // TODO: why needs to pass shutdown_complete_rx to listener 409 | let server = Listener { 410 | listener, 411 | db_holder: DbDropGuard::new(), 412 | limit_connections: Arc::new(Semaphore::new(MAX_CONNECTIONS)), 413 | notify_shutdown, 414 | shutdown_complete_tx, 415 | shutdown_complete_rx, 416 | }; 417 | 418 | // uncomment the following lines to get error 419 | // drop(notify_shutdown); // value used here after move 420 | // drop(shutdown_complete_tx); // value used here after move 421 | // let _ = shutdown_complete_rx.recv().await; // value used here after move 422 | 423 | let Listener { 424 | mut shutdown_complete_rx, 425 | shutdown_complete_tx, 426 | notify_shutdown, 427 | .. 
428 | } = server; 429 | 430 | drop(notify_shutdown); // value used here after move 431 | drop(shutdown_complete_tx); // value used here after move 432 | 433 | let _ = shutdown_complete_rx.recv().await; // value used here after move 434 | } 435 | 436 | #[tokio::test] 437 | async fn test_why_we_need_db_drop_guard() { 438 | // 1) when Listener drop (after run function terminated) 439 | // 2) then the db field (DbDropGuard) also drop 440 | // 3) back ground purge task is stopped 441 | } 442 | } 443 | -------------------------------------------------------------------------------- /src/db.rs: -------------------------------------------------------------------------------- 1 | use bytes::Bytes; 2 | use std::collections::{BTreeMap, HashMap}; 3 | use std::sync::{Arc, Mutex}; 4 | use std::time::Duration; 5 | use tokio::sync::{broadcast, Notify}; 6 | use tokio::time::{self, Instant}; 7 | use tracing::debug; 8 | 9 | /// A wrapper around a `Db` instance. This exists to allow orderly cleanup 10 | /// of the `Db` by signalling the background purge task to shut down when 11 | /// this struct is dropped. 12 | #[derive(Debug)] 13 | pub(crate) struct DbDropGuard { 14 | /// The `Db` instance that will be shut down when this `DbHolder` struct 15 | /// is dropped. 16 | db: Db, 17 | } 18 | 19 | /// Server state shared across all connections. 20 | /// 21 | /// `Db` contains a `HashMap` storing the key/value data and all 22 | /// `broadcast::Sender` values for active pub/sub channels. 23 | /// 24 | /// A `Db` instance is a handle to shared state. Cloning `Db` is shallow and 25 | /// only incurs an atomic ref count increment. 26 | /// 27 | /// When a `Db` value is created, a background task is spawned. This task is 28 | /// used to expire values after the requested duration has elapsed. The task 29 | /// runs until all instances of `Db` are dropped, at which point the task 30 | /// terminates. 31 | #[derive(Debug, Clone)] 32 | pub(crate) struct Db { 33 | /// Handle to shared state. 
The background task will also have an 34 | /// `Arc`. 35 | shared: Arc, 36 | } 37 | 38 | #[derive(Debug)] 39 | struct Shared { 40 | /// The shared state is guarded by a mutex. This is a `std::sync::Mutex` and 41 | /// not a Tokio mutex. This is because there are no asynchronous operations 42 | /// being performed while holding the mutex. Additionally, the critical 43 | /// sections are very small. 44 | /// 45 | /// A Tokio mutex is mostly intended to be used when locks need to be held 46 | /// across `.await` yield points. All other cases are **usually** best 47 | /// served by a std mutex. If the critical section does not include any 48 | /// async operations but is long (CPU intensive or performing blocking 49 | /// operations), then the entire operation, including waiting for the mutex, 50 | /// is considered a "blocking" operation and `tokio::task::spawn_blocking` 51 | /// should be used. 52 | state: Mutex, 53 | 54 | /// Notifies the background task handling entry expiration. The background 55 | /// task waits on this to be notified, then checks for expired values or the 56 | /// shutdown signal. 57 | background_task: Notify, 58 | } 59 | 60 | /// Entry in the key-value store 61 | #[derive(Debug, PartialEq)] 62 | struct Entry { 63 | /// Uniquely identifies this entry. 64 | id: u64, 65 | 66 | /// Stored data 67 | data: Bytes, // A cheaply cloneable and sliceable chunk of contiguous memory. 68 | 69 | /// Instant at which the entry expires and should be removed from the 70 | /// database. 71 | expires_at: Option, 72 | } 73 | 74 | #[derive(Debug, Default)] 75 | struct State { 76 | /// The key-value data. We are not trying to do anything fancy so a 77 | /// `std::collections::HashMap` works fine. 78 | entries: HashMap, 79 | 80 | /// The pub/sub key-space. telegram uses a **separate** key space for key-value 81 | /// and pub/sub. `mini-telegram` handles this by using a separate `HashMap`. 82 | pub_sub: HashMap>, 83 | 84 | /// Tracks key TTLs. 
85 | /// 86 | /// A `BTreeMap` is used to maintain expirations sorted by when they expire. 87 | /// This allows the background task to iterate this map to find the value 88 | /// expiring next. 89 | /// 90 | /// While highly unlikely, it is possible for more than one expiration to be 91 | /// created for the same instant. Because of this, the `Instant` is 92 | /// insufficient for the key. A unique expiration identifier (`u64`) is used 93 | /// to break these ties. 94 | expirations: BTreeMap<(Instant, u64), String>, 95 | 96 | /// Identifier to use for the next expiration. Each expiration is associated 97 | /// with a unique identifier. See above for why. 98 | next_id: u64, 99 | 100 | /// True when the Db instance is shutting down. This happens when all `Db` 101 | /// values drop. Setting this to `true` signals to the background task to 102 | /// exit. 103 | shutdown: bool, 104 | } 105 | 106 | impl DbDropGuard { 107 | /// Create a new `DbHolder`, wrapping a `Db` instance. When this is dropped 108 | /// the `Db`'s purge task will be shut down. 109 | pub(crate) fn new() -> DbDropGuard { 110 | DbDropGuard { db: Db::new() } 111 | } 112 | 113 | /// Get the shared database. Internally, this is an 114 | /// `Arc`, so a clone only increments the ref count. 115 | pub(crate) fn db(&self) -> Db { 116 | self.db.clone() 117 | } 118 | } 119 | 120 | impl Drop for DbDropGuard { 121 | fn drop(&mut self) { 122 | // Signal the 'Db' instance to shut down the task that purges expired keys 123 | self.db.shutdown_purge_task(); 124 | } 125 | } 126 | 127 | impl Db { 128 | /// Create a new, empty, `Db` instance. Allocates shared state and spawns a 129 | /// background task to manage key expiration. 130 | pub(crate) fn new() -> Db { 131 | let shared = Arc::new(Shared { 132 | state: Default::default(), 133 | background_task: Notify::new(), 134 | }); 135 | 136 | // Start the background task. 
137 | tokio::spawn(purge_expired_tasks(shared.clone())); 138 | 139 | Db { shared } 140 | } 141 | 142 | /// Get the value associated with a key. 143 | /// 144 | /// Returns `None` if there is no value associated with the key. This may be 145 | /// due to never having assigned a value to the key or a previously assigned 146 | /// value expired. 147 | pub(crate) fn get(&self, key: &str) -> Option { 148 | // Acquire the lock, get the entry and clone the value. 149 | // 150 | // Because data is stored using `Bytes`, a clone here is a shallow 151 | // clone. Data is not copied. 152 | let state = self.shared.state.lock().unwrap(); 153 | // map uses to convert Option<&Bytes> to Option 154 | state.entries.get(key).map(|entry| entry.data.clone()) 155 | } 156 | 157 | /// Set the value associated with a key along with an optional expiration 158 | /// Duration. 159 | /// 160 | /// If a value is already associated with the key, it is removed. 161 | pub(crate) fn set(&self, key: String, value: Bytes, expire: Option) { 162 | let mut state = self.shared.state.lock().unwrap(); 163 | 164 | // Get and increment the next insertion ID. Guarded by the lock, this 165 | // ensures a unique identifier is associated with each `set` operation. 166 | let id = state.next_id; 167 | state.next_id += 1; 168 | 169 | // If this `set` becomes the key that expires **next**, the background 170 | // task needs to be notified so it can update its state. 171 | // 172 | // Whether or not the task needs to be notified is computed during the 173 | // `set` routine. 174 | let mut notify = false; 175 | // note that expire.map() function invokes if expire has a value (not equal to None). 176 | let expires_at = expire.map(|duration| { 177 | // `Instant` at which the key expires. 178 | let when = Instant::now() + duration; 179 | 180 | // Only notify the worker task if the newly inserted expiration is the 181 | // **next** key to evict. In this case, the worker needs to be woken up 182 | // to update its state. 
183 | notify = state 184 | .next_expiration() 185 | .map(|expiration| expiration > when) 186 | // why default is true? 187 | // if next_expiration() returns None (background worker in wait state) and 188 | // expire is not None, so reschedule the background worker 189 | .unwrap_or(true); 190 | 191 | // Track the expiration. 192 | state.expirations.insert((when, id), key.clone()); 193 | when 194 | }); 195 | 196 | // Insert the entry into the `HashMap`. 197 | // let mut map = HashMap::new(); 198 | // assert_eq!(map.insert(37, "a"), None); 199 | // assert_eq!(map.insert(37, "b"), Some("a")); 200 | let prev = state.entries.insert( 201 | key, 202 | Entry { 203 | id, 204 | data: value, 205 | expires_at, 206 | }, 207 | ); 208 | 209 | // If there was a value previously associated with the key **and** it 210 | // had an expiration time. The associated entry in the `expirations` map 211 | // must also be removed. This avoids leaking data. 212 | if let Some(prev) = prev { 213 | if let Some(when) = prev.expires_at { 214 | // clear expiration 215 | state.expirations.remove(&(when, prev.id)); 216 | } 217 | } 218 | 219 | // Release the mutex before notifying the background task. This helps 220 | // reduce contention by avoiding the background task waking up only to 221 | // be unable to acquire the mutex due to this function still holding it. 222 | drop(state); 223 | 224 | if notify { 225 | // Finally, only notify the background task if it needs to update 226 | // its state to reflect a new expiration. 227 | self.shared.background_task.notify_one(); 228 | } 229 | } 230 | 231 | /// Returns a `Receiver` for the requested channel. 232 | /// 233 | /// The returned `Receiver` is used to receive values broadcast by `PUBLISH` 234 | /// commands. 
235 | pub(crate) fn subscribe(&self, key: String) -> broadcast::Receiver { 236 | use std::collections::hash_map::Entry; 237 | 238 | // Acquire the mutex 239 | let mut state = self.shared.state.lock().unwrap(); 240 | 241 | // If there is no entry for the requested channel, then create a new 242 | // broadcast channel and associate it with the key. If one already 243 | // exists, return an associated receiver. 244 | match state.pub_sub.entry(key) { 245 | Entry::Occupied(e) => e.get().subscribe(), 246 | Entry::Vacant(e) => { 247 | // No broadcast channel exists yet, so create one. 248 | // 249 | // The channel is created with a capacity of `1024` messages. A 250 | // message is stored in the channel until **all** subscribers 251 | // have seen it. This means that a slow subscriber could result 252 | // in messages being held indefinitely. 253 | // 254 | // When the channel's capacity fills up, publishing will result 255 | // in old messages being dropped. This prevents slow consumers 256 | // from blocking the entire system. 257 | let (tx, rx) = broadcast::channel(1024); 258 | e.insert(tx); 259 | rx 260 | } 261 | } 262 | } 263 | 264 | /// Publish a message to the channel. Returns the number of subscribers 265 | /// listening on the channel. 266 | pub(crate) fn publish(&self, key: &str, value: Bytes) -> usize { 267 | let state = self.shared.state.lock().unwrap(); 268 | 269 | state 270 | .pub_sub 271 | .get(key) 272 | // On a successful message send on the broadcast channel, the number 273 | // of subscribers is returned. An error indicates there are no 274 | // receivers, in which case, `0` should be returned. 275 | .map(|tx| tx.send(value).unwrap_or(0)) 276 | // If there is no entry for the channel key, then there are no 277 | // subscribers. In this case, return `0`. 278 | .unwrap_or(0) 279 | } 280 | 281 | /// Signals the purge background task to shut down. This is called by the 282 | /// `DbDropGuard`'s `Drop` implementation. 
283 | fn shutdown_purge_task(&self) { 284 | // The background task must be signaled to shut down. This is done by 285 | // setting `State::shutdown` to `true` and signalling the task. 286 | let mut state = self.shared.state.lock().unwrap(); 287 | state.shutdown = true; 288 | 289 | // Drop the lock before signalling the background task. This helps 290 | // reduce lock contention by ensuring the background task doesn't 291 | // wake up only to be unable to acquire the mutex. 292 | drop(state); 293 | self.shared.background_task.notify_one(); 294 | } 295 | } 296 | 297 | impl Shared { 298 | /// Purge all expired keys and return the `Instant` at which the **next** 299 | /// key will expire. The background task will sleep until this instant. 300 | fn purge_expired_keys(&self) -> Option { 301 | let mut state = self.state.lock().unwrap(); 302 | if state.shutdown { 303 | // The database is shutting down. 304 | // All handles to the shared state have dropped. 305 | // So, the background task should exit immediately. 306 | return None; 307 | } 308 | 309 | // This is needed to make the borrow checker happy. In short, `lock()` 310 | // returns a `MutexGuard` and not a `&mut State`. The borrow checker is 311 | // not able to see "through" the mutex guard and determine that it is 312 | // safe to access both `state.expirations` and `state.entries` mutably, 313 | // so we get a "real" mutable reference to `State` outside of the loop. 314 | // 315 | // - while ... state.expirations.iter().next() <-- next(): state.entries state.expirations 316 | // - state.entries.remove(key); <-- remove(): state.entries to state.entries 317 | let state = &mut *state; 318 | 319 | // Find all keys scheduled to expire **before** now. 320 | let now = Instant::now(); 321 | while let Some((&(when, id), key)) = state.expirations.iter().next() { 322 | // * 323 | if when > now { 324 | // Done purging, `when` is the instant at which the next key 325 | // expires. The worker task will wait until this instant. 
326 | return Some(when); 327 | } 328 | 329 | // The key expired, remove it 330 | state.entries.remove(key); 331 | state.expirations.remove(&(when, id)); 332 | } 333 | 334 | None 335 | } 336 | 337 | /// Returns `true` if the database is shutting down 338 | /// 339 | /// The `shutdown` flag is set when all `Db` values have dropped, indicating 340 | /// that the shared state can no longer be accessed. 341 | fn is_shutdown(&self) -> bool { 342 | self.state.lock().unwrap().shutdown 343 | } 344 | } 345 | 346 | impl State { 347 | /// Returns `Option` of the next expiration key 348 | fn next_expiration(&self) -> Option { 349 | self.expirations 350 | .keys() 351 | .next() 352 | .map(|expiration| expiration.0) 353 | } 354 | } 355 | 356 | /// Routine executed by the background task 357 | /// 358 | /// Wait to be notified. On notification, purge any expired keys from the shared 359 | /// state handle. If 'shutdown' is set, terminate the task. 360 | async fn purge_expired_tasks(shared: Arc) { 361 | // If the shutdown flag is set, then the task should exit. 362 | while !shared.is_shutdown() { 363 | // Purge all keys that are expired. 364 | // The function returns the instant at which the **next** key will expire. 365 | // The worker should wait until the instant has passed then purge again. 366 | if let Some(when) = shared.purge_expired_keys() { 367 | // Wait until the next key expires **or** until the background task 368 | // is notified. If the task is notified, then it must reload its 369 | // state as new keys have been set to expire early. This is done by 370 | // looping. 371 | tokio::select! { 372 | _ = time::sleep_until(when) => {} 373 | _ = shared.background_task.notified() => {} 374 | } 375 | } else { 376 | // There are no keys expiring in the future. 
// wait until the task is notified
            shared.background_task.notified().await;
        }
    }

    debug!("Purge background task shut down")
}

/// Exploratory tests for the std / tokio / bytes primitives used in this
/// module, plus smoke tests for `State`, `Entry` and `Db`.
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;
    extern crate test;
    // use test::Bencher;

    /// Values inserted into a `HashMap` can be read back by key.
    #[test]
    fn test_hash_map() {
        use std::collections::HashMap;

        let mut map = HashMap::new();
        map.insert("a", 1);
        map.insert("b", 2);
        map.insert("c", 3);

        assert_eq!(map.get(&"a"), Some(&1));
        assert_eq!(map.get(&"b"), Some(&2));
        assert_eq!(map.get(&"c"), Some(&3));
    }

    /// Every subscriber of a broadcast channel observes every sent value.
    #[tokio::test]
    async fn test_broadcast() {
        //                      ┌──────────────┐
        //                 ┌────►   thread x   │
        //                 │    └──────────────┘
        // ┌────────────┐  │
        // │  thread z  ├──┤
        // └────────────┘  │    ┌──────────────┐
        //    [10,20]      └────►   thread y   │
        //                      └──────────────┘
        //
        let (tx, mut rx1) = broadcast::channel(16);
        let mut rx2 = tx.subscribe();

        // thread x
        let thread_x = tokio::spawn(async move {
            assert_eq!(rx1.recv().await.unwrap(), 10);
            assert_eq!(rx1.recv().await.unwrap(), 20);
        });
        // thread y
        let thread_y = tokio::spawn(async move {
            assert_eq!(rx2.recv().await.unwrap(), 10);
            assert_eq!(rx2.recv().await.unwrap(), 20);
        });

        // thread z
        tx.send(10).unwrap();
        tx.send(20).unwrap();

        // Join both receivers: a detached task would be dropped when the test
        // returns, and a failing assertion inside it would go unnoticed.
        thread_x.await.unwrap();
        thread_y.await.unwrap();
    }

    /// A slow receiver on a full broadcast channel only sees the newest
    /// `capacity` values.
    #[tokio::test]
    async fn test_broadcast_capacity_fills_up() {
        use tokio::sync::broadcast::error::RecvError;
        use tokio::time::{sleep, Duration};

        let (sender, mut receiver) = broadcast::channel(2);

        let handler = tokio::spawn(async move {
            // Give the sender time to overrun the channel.
            sleep(Duration::from_millis(500)).await;

            let mut received = Vec::new();
            // Expected sequence: one `Lagged` error, then the two retained values.
            for _ in 0..3 {
                match receiver.recv().await {
                    Ok(value) => received.push(value),
                    // The receiver fell behind; the next `recv` resumes at the
                    // oldest value still retained by the channel.
                    Err(RecvError::Lagged(_)) => continue,
                    Err(_) => break,
                }
            }
            received
        });

        for i in 0..10 {
            sender.send(i).expect("the receiver is still alive");
        }

        // Only the last 2 (channel capacity) values are received.
        assert_eq!(handler.await.unwrap(), vec![8, 9]);
    }

    /// A value behind `Arc<Mutex<_>>` can be mutated from another thread.
    #[test]
    fn test_arc_mutex_lock_1() {
        use std::sync::{Arc, Mutex};
        use std::thread;

        let mutex = Arc::new(Mutex::new(0));
        let c_mutex = Arc::clone(&mutex);

        thread::spawn(move || {
            *c_mutex.lock().unwrap() = 10;
        })
        .join()
        .expect("thread::spawn failed");

        assert_eq!(*mutex.lock().unwrap(), 10);
    }

    /// Several threads increment a mutex-protected counter; the thread that
    /// performs the last increment signals completion over a channel.
    #[test]
    fn test_arc_mutex_lock_2() {
        use std::sync::mpsc::channel;
        use std::sync::{Arc, Mutex};
        use std::thread;

        const N: usize = 10;

        // Spawn a few threads to increment a shared variable (non-atomically), and
        // let the main thread know once all increments are done.
        //
        // Here we're using an Arc to share memory among threads, and the data inside
        // the Arc is protected with a mutex.
        let data = Arc::new(Mutex::new(0));

        let (tx, rx) = channel();
        for _ in 0..N {
            let (data, tx) = (Arc::clone(&data), tx.clone());
            thread::spawn(move || {
                // The shared state can only be accessed once the lock is held.
                // Our non-atomic increment is safe because we're the only thread
                // which can access the shared state when the lock is held.
                //
                // We unwrap() the return value to assert that we are not expecting
                // threads to ever fail while holding the lock.
                let mut data = data.lock().unwrap();
                *data += 1;
                if *data == N {
                    tx.send(()).unwrap();
                }
                // the lock is unlocked here when `data` goes out of scope.
            });
        }

        // Blocks until the last increment has been reported.
        rx.recv().unwrap();
    }

    /// At least the slept duration elapses between two `Instant` readings.
    #[tokio::test]
    async fn test_instant() {
        // TODO: Add an example of the difference between std::Instant and tokio::Instant
        use tokio::time::{sleep, Duration, Instant};

        let now = Instant::now();
        sleep(Duration::from_secs(1)).await;
        let new_now = Instant::now();
        assert!(new_now.duration_since(now) >= Duration::from_secs(1));
    }

    /// A tuple can be destructured into its individual fields.
    #[test]
    fn test_tuple() {
        let tuple = (1, "hello", 4.5, true);
        let (a, b, c, d) = tuple;

        assert_eq!(a, 1);
        assert_eq!(b, "hello");
        assert_eq!(c, 4.5);
        assert!(d);
    }

    /// A task parked on `Notify::notified()` resumes after `notify_one()`.
    #[tokio::test]
    async fn test_notify() {
        use std::sync::atomic::{AtomicUsize, Ordering};

        // `Notify` can be thought of as a [`Semaphore`] starting with 0 permits.
        // [`notified().await`] waits for a permit to become available, and [`notify_one()`]
        // sets a permit **if there currently are no available permits**.
        let notify = Arc::new(Notify::new());
        let notify2 = notify.clone();

        // https://doc.rust-lang.org/rust-by-example/std/rc.html
        let shared = Arc::new(AtomicUsize::new(1));
        let shared_cloned = Arc::clone(&shared);

        let handle = tokio::spawn(async move {
            notify2.notified().await;
            shared_cloned.store(10, Ordering::SeqCst);
        });
        notify.notify_one();

        handle.await.unwrap(); // wait until the worker task ends
        assert_eq!(10, shared.load(Ordering::SeqCst));
    }

    /// Once all permits are held, `try_acquire` reports `NoPermits`.
    #[tokio::test]
    async fn test_semaphore() {
        use tokio::sync::{Semaphore, TryAcquireError};

        let semaphore = Semaphore::new(3);
        // The permits are bound to names so they stay held until the end of the
        // test. (Calling `.forget()` on a permit instead would never give it
        // back to the semaphore.)
        let _a_permit = semaphore.acquire().await.unwrap(); // 3 -> 2
        let _two_permits = semaphore.acquire_many(2).await.unwrap(); // 2 -> 0
        assert_eq!(semaphore.available_permits(), 0);

        let permit_attempt = semaphore.try_acquire();
        assert_eq!(permit_attempt.err(), Some(TryAcquireError::NoPermits));
    }

    /// `Option::map` transforms `Some` and passes `None` through untouched.
    #[test]
    fn test_option_map() {
        let test_cases = [Some(Duration::from_secs(1)), None];
        let now = std::time::Instant::now();

        for ttl in &test_cases {
            let expires_at = ttl.map(|duration| now + duration);
            match ttl {
                Some(duration) => assert_eq!(expires_at, Some(now + *duration)),
                None => assert!(expires_at.is_none()),
            }
        }
    }

    /// An mpsc channel used as a wait group: `recv()` completes once every
    /// sender clone has been dropped.
    #[tokio::test]
    async fn test_wait_group() {
        use std::sync::atomic::{AtomicUsize, Ordering};
        use tokio::sync::mpsc::channel;

        const N: usize = 10;
        static GLOBAL_THREAD_COUNT: AtomicUsize = AtomicUsize::new(0);

        // Create a new wait group. Nothing is ever sent on the channel.
        let (send, mut recv) = channel::<()>(1);

        for _ in 0..N {
            let sender = send.clone();
            tokio::spawn(async move {
                GLOBAL_THREAD_COUNT.fetch_add(1, Ordering::SeqCst);
                drop(sender); // release the sender owned by the current task
            });
        }

        // Wait for the tasks to finish.
        //
        // We drop our sender first because the recv() call otherwise
        // sleeps forever.
        drop(send);

        // When every sender has gone out of scope, the recv call
        // returns `None`. We ignore the result.
        let _ = recv.recv().await;

        assert_eq!(GLOBAL_THREAD_COUNT.load(Ordering::SeqCst), N)
    }

    /// A default `State` is empty, and its fields can be written and read back.
    #[test]
    fn test_using_state_structure() {
        let mut state: State = Default::default();
        assert!(state.entries.is_empty());
        assert_eq!(state.entries.len(), 0);

        assert!(state.pub_sub.is_empty());
        assert_eq!(state.pub_sub.len(), 0);

        assert!(state.expirations.is_empty());
        assert_eq!(state.expirations.len(), 0);

        assert_eq!(state.next_id, 0);
        assert!(!state.shutdown);

        let key = String::from("1"); // "1".into()
        let entry = Entry {
            id: 1,
            data: Bytes::from_static(b"hello"),
            expires_at: None,
        };
        // insert
        state.entries.insert(key, entry);
        assert_eq!(state.entries.len(), 1);
        // get
        let entry = state.entries.get("1").unwrap();
        assert_eq!(entry.id, 1);
        assert_eq!(entry.data, Bytes::from_static(b"hello"));
        assert_eq!(entry.expires_at, None);

        state.next_id = 10;
        assert_eq!(state.next_id, 10);

        state.shutdown = true;
        assert!(state.shutdown);
    }

    /// Concurrent tasks can each bump `next_id` through the shared state lock.
    #[tokio::test]
    async fn test_db() {
        use tokio::sync::mpsc::channel;

        let db = Db::new();
        const N: usize = 10;

        // Create a new wait group. Nothing is ever sent on the channel.
        let (send, mut recv) = channel::<()>(1);

        for _ in 0..N {
            let shared = db.shared.clone();
            let sender = send.clone();
            tokio::spawn(async move {
                shared.state.lock().unwrap().next_id += 1;
                drop(sender); // release the sender owned by the current task
            });
        }

        // Wait for the tasks to finish.
        //
        // We drop our sender first because the recv() call otherwise
        // sleeps forever.
        drop(send);

        // When every sender has gone out of scope, the recv call
        // returns `None`. We ignore the result.
        let _ = recv.recv().await;
        assert_eq!(db.shared.state.lock().unwrap().next_id as usize, N)
    }

    /// `Option::map` on `None` returns `None` without running the closure.
    #[tokio::test]
    async fn test_map_on_none() {
        let db = Db::new();
        let state = db.shared.state.lock().unwrap();
        let expire: Option<Duration> = None;

        // map on a None value returns None (without running the closure)
        // see the map implementation in option.rs
        let notify = expire.map(|duration| {
            let when = Instant::now() + duration;
            state
                .next_expiration()
                .map(|expiration| expiration > when)
                .unwrap_or(true)
        });
        assert_eq!(notify, None)
    }

    /// Cloning `Bytes` yields a distinct handle that points at the same data.
    #[test]
    fn test_bytes_shallow_clone() {
        use bytes::Buf;
        //    Arc ptrs                   +---------+
        //    ________________________ / | Bytes 2 |
        //   /                           +---------+
        //  /          +-----------+     |         |
        // |_________/ |  Bytes 1  |     |         |
        // |           +-----------+     |         |
        // |           |           | ___/ data     | tail
        // |      data |      tail |/              |
        // v           v           v               v
        // +-----+---------------------------------+-----+
        // | Arc |     |           |               |     |
        // +-----+---------------------------------+-----+
        let b = Bytes::from_static(b"hello yumcoder!");
        let b_clone = b.clone();

        // The two handles are different objects ...
        assert!(!std::ptr::eq(&b, &b_clone));
        // ... but their first byte lives at the same address.
        assert_eq!(b.chunk().as_ptr(), b_clone.chunk().as_ptr());
    }

    /// Cloning a `String` copies the data into a separate buffer.
    #[test]
    fn test_string_clone() {
        let s = "Hello World!".to_string();
        let s_clone = s.clone();
        assert_ne!(s.as_ptr(), s_clone.as_ptr());
    }
}

--------------------------------------------------------------------------------