├── .gitignore
├── Cargo.toml
├── LICENSE
├── README.md
├── benches
│   └── bench.rs
├── examples
│   └── echo.rs
└── src
    ├── builder.rs
    ├── config.rs
    ├── error.rs
    ├── frame.rs
    ├── lib.rs
    └── mux.rs
/.gitignore:
--------------------------------------------------------------------------------
1 | private
2 | target
3 | Cargo.lock
--------------------------------------------------------------------------------
/Cargo.toml:
--------------------------------------------------------------------------------
1 | [package]
2 | name = "async_smux"
3 | version = "0.3.4"
4 | authors = ["black-binary "]
5 | description = "Asynchronous smux multiplexing library"
6 | license = "MIT"
7 | edition = "2021"
8 | 
9 | [dependencies]
10 | futures = "0.3"
11 | tokio = { version = "1", features = ["io-util", "sync", "time"] }
12 | tokio-util = { version = "0.7", features = ["codec"] }
13 | bytes = "1.9"
14 | rand = "0.8"
15 | log = "0.4"
16 | parking_lot = "0.12"
17 | pin-project = "1.1"
18 | thiserror = "2.0"
19 | futures-sink = "0.3"
20 | 
21 | [dev-dependencies]
22 | tokio = { version = "1", features = ["full"] }
23 | env_logger = "0.11"
24 | criterion = "0.5"
25 | pprof = { version = "0.14", features = ["flamegraph"] }
26 | lazy_static = "1.4"
27 | 
28 | [profile.release]
29 | debug = true
30 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2020 Black Binary
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # async-smux
2 | 
3 | [crates.io](https://crates.io/crates/async_smux)
4 | 
5 | A lightweight and fast asynchronous [smux](https://github.com/xtaci/smux) (Simple MUltipleXing) library for the Tokio async runtime.
6 | 
7 | ![img](https://raw.githubusercontent.com/xtaci/smux/master/mux.jpg)
8 | 
9 | `async-smux` consumes any connection implementing `AsyncRead + AsyncWrite + Unpin`, such as a `TcpStream` or a `TlsStream`. `MuxBuilder` turns that connection into a `MuxConnector`, a `MuxAcceptor` and a `MuxWorker` future, over which you may then open and accept multiple `MuxStream`s (up to about 2^31 concurrent streams per side). `MuxStream` implements `AsyncRead + AsyncWrite + Unpin` as well.
10 | 
11 | ## Benchmark
12 | 
13 | Here is a simple benchmark result from my local machine, compared with the original version of smux (written in Go).
14 | 
15 | | Implementation    | Throughput (TCP) | Handshake  |
16 | | ----------------- | ---------------- | ---------- |
17 | | smux (go)         | 0.4854 GiB/s     | 17.070 K/s |
18 | | async-smux (rust) | 1.0550 GiB/s     | 81.774 K/s |
19 | 
20 | Run `cargo bench` to test it yourself. Check out the `/benches` directory for more details.
21 | 
22 | ## Laziness
23 | 
24 | No thread or task is spawned by this library itself; it only hands you a few plain `Future`s. The `MuxWorker` returned by `MuxBuilder::build()` must be spawned (e.g. with `tokio::spawn`) or polled by you to keep the multiplexer running.
25 | 
26 | `MuxConnector`, `MuxAcceptor` and `MuxStream` are completely lazy and will DO NOTHING if you don't `poll()` them.
27 | 
28 | Any polling operation, including `.read()`, `.write()`, `accept()` and `connect()`, drives them forward.
29 | 
30 | ## Specification
31 | 
32 | ```text
33 | VERSION(1B) | CMD(1B) | LENGTH(2B) | STREAMID(4B) | DATA(LENGTH)
34 | 
35 | VERSION: 1
36 | 
37 | CMD:
38 |     SYN(0)
39 |     FIN(1)
40 |     PSH(2)
41 |     NOP(3)
42 | 
43 | STREAMID: Unique stream ID; odd for client-initiated streams, even for server-initiated streams
44 | 
45 | LENGTH and STREAMID are encoded as little-endian integers
46 | ```
47 | 
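48 | For illustration, here is a minimal sketch of how a PSH frame could be assembled with the `bytes` crate. It simply mirrors the layout above (and the codec in `src/frame.rs`); the helper function itself is hypothetical and not part of the crate's public API:
49 | 
50 | ```rust
51 | use bytes::{BufMut, BytesMut};
52 | 
53 | /// Illustrative only: build VERSION | CMD | LENGTH | STREAMID | DATA for a PSH frame.
54 | fn encode_push_frame(stream_id: u32, payload: &[u8]) -> BytesMut {
55 |     assert!(payload.len() <= 0xffff, "payload must fit into LENGTH(2B)");
56 |     let mut buf = BytesMut::with_capacity(8 + payload.len());
57 |     buf.put_u8(1); // VERSION = 1
58 |     buf.put_u8(2); // CMD = PSH(2)
59 |     buf.put_u16_le(payload.len() as u16); // LENGTH, little-endian
60 |     buf.put_u32_le(stream_id); // STREAMID, little-endian
61 |     buf.put_slice(payload); // DATA
62 |     buf
63 | }
64 | ```
65 | 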
66 | ## Example
67 | 
68 | ```rust
69 | use std::time::Duration;
70 | 
71 | use async_smux::MuxBuilder;
72 | use tokio::{
73 |     io::{AsyncReadExt, AsyncWriteExt},
74 |     net::{TcpListener, TcpStream},
75 | };
76 | 
77 | async fn echo_server() {
78 |     let listener = TcpListener::bind("127.0.0.1:12345").await.unwrap();
79 |     let (stream, _) = listener.accept().await.unwrap();
80 |     let (_, mut acceptor, worker) = MuxBuilder::server().with_connection(stream).build();
81 |     tokio::spawn(worker);
82 |     while let Some(mut mux_stream) = acceptor.accept().await {
83 |         let mut buf = [0u8; 1024];
84 |         let size = mux_stream.read(&mut buf).await.unwrap();
85 |         mux_stream.write_all(&buf[..size]).await.unwrap();
86 |         mux_stream.flush().await.unwrap();
87 |     }
88 | }
89 | 
90 | #[tokio::main]
91 | async fn main() {
92 |     tokio::spawn(echo_server());
93 |     tokio::time::sleep(Duration::from_secs(1)).await;
94 | 
95 |     let stream = TcpStream::connect("127.0.0.1:12345").await.unwrap();
96 |     let (connector, _, worker) = MuxBuilder::client().with_connection(stream).build();
97 |     tokio::spawn(worker);
98 | 
99 |     for i in 0..100 {
100 |         let mut mux_stream = connector.connect().unwrap();
101 |         let mut buf = [0u8; 1024];
102 |         mux_stream.write_all(b"hello").await.unwrap();
103 |         let size = mux_stream.read(&mut buf).await.unwrap();
104 |         let reply = String::from_utf8(buf[..size].to_vec()).unwrap();
105 |         println!("{}: {}", i, reply);
106 |     }
107 | }
108 | ```
109 | 
--------------------------------------------------------------------------------
/benches/bench.rs:
--------------------------------------------------------------------------------
1 | #![feature(test)]
2 | 
3 | use tokio::net::{TcpListener, TcpStream};
4 | 
5 | extern crate test;
6 | 
7 | pub fn add_two(a: i32) -> i32 {
8 |     a + 2
9 | }
10 | 
11 | async fn get_tcp_pair() -> (TcpStream, TcpStream) {
12 |     let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
13 |     let addr = listener.local_addr().unwrap();
14 |     let h = tokio::spawn(async move {
15 |         let (a, _) = listener.accept().await.unwrap();
16 |         a
17 |     });
18 | 
19 |     let b = TcpStream::connect(addr).await.unwrap();
20 |     let a = h.await.unwrap();
21 |     a.set_nodelay(true).unwrap();
22 |     b.set_nodelay(true).unwrap();
23 |     (a, b)
24 | }
25 | 
26 | #[cfg(test)]
27 | mod tests {
28 |     use super::*;
29 |     use async_smux::{MuxBuilder, MuxStream};
30 |     use test::Bencher;
31 |     use tokio::{
32 |         io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt},
33 |         runtime::Runtime,
34 |     };
35 | 
36 |     lazy_static::lazy_static!
{ 37 | static ref RT: Runtime = tokio::runtime::Builder::new_multi_thread().enable_all().build().unwrap(); 38 | } 39 | 40 | fn get_mux_pair() -> (MuxStream, MuxStream) { 41 | RT.block_on(async { 42 | let (a, b) = get_tcp_pair().await; 43 | let (connector, _, worker) = MuxBuilder::client().with_connection(a).build(); 44 | RT.spawn(worker); 45 | let (_, mut acceptor, worker) = MuxBuilder::server().with_connection(b).build(); 46 | RT.spawn(worker); 47 | let a = connector.connect().unwrap(); 48 | let b = acceptor.accept().await.unwrap(); 49 | (a, b) 50 | }) 51 | } 52 | 53 | #[inline] 54 | async fn send(data: &[u8], a: &mut T) { 55 | a.write_all(data).await.unwrap(); 56 | a.flush().await.unwrap(); 57 | } 58 | 59 | #[inline] 60 | async fn recv(buf: &mut [u8], a: &mut T) -> std::io::Result<()> { 61 | a.read_exact(buf).await?; 62 | Ok(()) 63 | } 64 | 65 | const DATA_SIZE: usize = 0x20000; 66 | 67 | fn bench_send( 68 | bencher: &mut Bencher, 69 | mut a: T, 70 | mut b: T, 71 | ) { 72 | let data = vec![0; DATA_SIZE]; 73 | let mut buf = vec![0; DATA_SIZE]; 74 | RT.spawn(async move { 75 | loop { 76 | if recv(&mut buf, &mut b).await.is_err() { 77 | break; 78 | } 79 | } 80 | }); 81 | bencher.bytes = DATA_SIZE as u64; 82 | 83 | // Warm up 84 | for _ in 0..10 { 85 | RT.block_on(async { 86 | send(&data, &mut a).await; 87 | }); 88 | } 89 | 90 | for _ in 0..10 { 91 | bencher.iter(|| { 92 | RT.block_on(async { 93 | send(&data, &mut a).await; 94 | }); 95 | }); 96 | } 97 | } 98 | 99 | #[bench] 100 | fn bench_tcp_send(bencher: &mut Bencher) { 101 | let (a, b) = RT.block_on(async { get_tcp_pair().await }); 102 | bench_send(bencher, a, b); 103 | } 104 | 105 | #[bench] 106 | fn bench_mux_send(bencher: &mut Bencher) { 107 | let (a, b) = get_mux_pair(); 108 | bench_send(bencher, a, b); 109 | } 110 | } 111 | -------------------------------------------------------------------------------- /examples/echo.rs: -------------------------------------------------------------------------------- 1 | use std::time::Duration; 2 | 3 | use async_smux::MuxBuilder; 4 | use tokio::{ 5 | io::{AsyncReadExt, AsyncWriteExt}, 6 | net::{TcpListener, TcpStream}, 7 | }; 8 | 9 | async fn echo_server() { 10 | let listener = TcpListener::bind("127.0.0.1:12345").await.unwrap(); 11 | let (stream, _) = listener.accept().await.unwrap(); 12 | 13 | let (_, mut acceptor, worker) = MuxBuilder::server().with_connection(stream).build(); 14 | tokio::spawn(worker); 15 | 16 | println!("server launched"); 17 | while let Some(mut mux_stream) = acceptor.accept().await { 18 | println!("accepted mux stream {}", mux_stream.get_stream_id()); 19 | 20 | let mut buf = [0u8; 100]; 21 | let size = mux_stream.read(&mut buf).await.unwrap(); 22 | mux_stream.write_all(&buf[..size]).await.unwrap(); 23 | mux_stream.flush().await.unwrap(); 24 | mux_stream.shutdown().await.unwrap(); 25 | } 26 | } 27 | 28 | #[tokio::main] 29 | async fn main() { 30 | tokio::spawn(echo_server()); 31 | tokio::time::sleep(Duration::from_secs(3)).await; 32 | 33 | let stream = TcpStream::connect("127.0.0.1:12345").await.unwrap(); 34 | let (connector, _, worker) = MuxBuilder::client().with_connection(stream).build(); 35 | tokio::spawn(worker); 36 | 37 | for i in 0..10 { 38 | let mut mux_stream = connector.connect().unwrap(); 39 | let mut buf = [0u8; 5]; 40 | mux_stream.write_all(b"hello").await.unwrap(); 41 | mux_stream.read_exact(&mut buf).await.unwrap(); 42 | let reply = String::from_utf8(buf[..].to_vec()).unwrap(); 43 | println!("{}: reply = {}", i, reply); 44 | mux_stream.shutdown().await.unwrap(); 
45 | } 46 | } 47 | -------------------------------------------------------------------------------- /src/builder.rs: -------------------------------------------------------------------------------- 1 | use std::num::{NonZeroU64, NonZeroUsize}; 2 | 3 | use crate::{ 4 | config::{MuxConfig, StreamIdType}, 5 | mux::{MuxWorker, TokioConn}, 6 | mux_connection, MuxAcceptor, MuxConnector, 7 | }; 8 | 9 | pub struct WithConnection { 10 | config: MuxConfig, 11 | connection: T, 12 | } 13 | 14 | pub struct WithConfig { 15 | config: MuxConfig, 16 | } 17 | 18 | pub struct Begin {} 19 | 20 | pub struct MuxBuilder { 21 | state: State, 22 | } 23 | 24 | impl MuxBuilder { 25 | pub fn client() -> MuxBuilder { 26 | MuxBuilder { 27 | state: WithConfig { 28 | config: MuxConfig { 29 | stream_id_type: StreamIdType::Odd, 30 | keep_alive_interval: None, 31 | idle_timeout: None, 32 | max_tx_queue: NonZeroUsize::new(1024).unwrap(), 33 | max_rx_queue: NonZeroUsize::new(1024).unwrap(), 34 | }, 35 | }, 36 | } 37 | } 38 | 39 | pub fn server() -> MuxBuilder { 40 | MuxBuilder { 41 | state: WithConfig { 42 | config: MuxConfig { 43 | stream_id_type: StreamIdType::Even, 44 | keep_alive_interval: None, 45 | idle_timeout: None, 46 | max_tx_queue: NonZeroUsize::new(1024).unwrap(), 47 | max_rx_queue: NonZeroUsize::new(1024).unwrap(), 48 | }, 49 | }, 50 | } 51 | } 52 | } 53 | 54 | impl MuxBuilder { 55 | pub fn with_keep_alive_interval(&mut self, interval_secs: NonZeroU64) -> &mut Self { 56 | self.state.config.keep_alive_interval = Some(interval_secs); 57 | self 58 | } 59 | 60 | pub fn with_idle_timeout(&mut self, timeout_secs: NonZeroU64) -> &mut Self { 61 | self.state.config.idle_timeout = Some(timeout_secs); 62 | self 63 | } 64 | 65 | pub fn with_max_tx_queue(&mut self, size: NonZeroUsize) -> &mut Self { 66 | self.state.config.max_tx_queue = size; 67 | self 68 | } 69 | 70 | pub fn with_max_rx_queue(&mut self, size: NonZeroUsize) -> &mut Self { 71 | self.state.config.max_rx_queue = size; 72 | self 73 | } 74 | 75 | pub fn with_connection( 76 | &mut self, 77 | connection: T, 78 | ) -> MuxBuilder> { 79 | MuxBuilder { 80 | state: WithConnection { 81 | config: self.state.config, 82 | connection, 83 | }, 84 | } 85 | } 86 | } 87 | 88 | impl MuxBuilder> { 89 | pub fn build(self) -> (MuxConnector, MuxAcceptor, MuxWorker) { 90 | mux_connection(self.state.connection, self.state.config) 91 | } 92 | } 93 | -------------------------------------------------------------------------------- /src/config.rs: -------------------------------------------------------------------------------- 1 | use std::num::{NonZeroU64, NonZeroUsize}; 2 | 3 | #[derive(Clone, Copy, Debug)] 4 | pub enum StreamIdType { 5 | Even = 0, 6 | Odd = 1, 7 | } 8 | 9 | #[derive(Clone, Copy, Debug)] 10 | pub struct MuxConfig { 11 | pub stream_id_type: StreamIdType, 12 | pub keep_alive_interval: Option, 13 | pub idle_timeout: Option, 14 | pub max_tx_queue: NonZeroUsize, 15 | pub max_rx_queue: NonZeroUsize, 16 | } 17 | -------------------------------------------------------------------------------- /src/error.rs: -------------------------------------------------------------------------------- 1 | use thiserror::Error; 2 | 3 | use crate::config::StreamIdType; 4 | 5 | #[derive(Debug, Error)] 6 | pub enum MuxError { 7 | #[error("Invalid command {0}")] 8 | InvalidCommand(u8), 9 | #[error("Invalid version {0}")] 10 | InvalidVersion(u8), 11 | #[error("Payload too large {0}")] 12 | PayloadTooLarge(usize), 13 | #[error("Duplicated stream id {0}")] 14 | DuplicatedStreamId(u32), 15 | 16 | 
#[error("Invalid stream ID from peer: {0}, local stream ID type: {0:?}")] 17 | InvalidPeerStreamIdType(u32, StreamIdType), 18 | 19 | #[error("Too many streams")] 20 | TooManyStreams, 21 | 22 | #[error("Inner connection closed")] 23 | ConnectionClosed, 24 | #[error("Mux stream closed: {0:x}")] 25 | StreamClosed(u32), 26 | 27 | #[error("IO error: {0}")] 28 | IoError(#[from] std::io::Error), 29 | } 30 | 31 | pub type MuxResult = Result; 32 | -------------------------------------------------------------------------------- /src/frame.rs: -------------------------------------------------------------------------------- 1 | use bytes::{Buf, BufMut, Bytes, BytesMut}; 2 | use tokio_util::codec::{Decoder, Encoder}; 3 | 4 | use std::io::Cursor; 5 | 6 | use crate::error::{MuxError, MuxResult}; 7 | 8 | pub const SMUX_VERSION: u8 = 1; 9 | pub const HEADER_SIZE: usize = 8; 10 | pub const MAX_PAYLOAD_SIZE: usize = 0xffff; 11 | 12 | #[derive(Eq, PartialEq, Debug, Clone, Copy)] 13 | pub(crate) enum MuxCommand { 14 | Sync = 0, 15 | Finish = 1, 16 | Push = 2, 17 | Nop = 3, 18 | } 19 | 20 | impl TryFrom for MuxCommand { 21 | type Error = MuxError; 22 | 23 | fn try_from(value: u8) -> Result { 24 | match value { 25 | 0 => Ok(MuxCommand::Sync), 26 | 1 => Ok(MuxCommand::Finish), 27 | 2 => Ok(MuxCommand::Push), 28 | 3 => Ok(MuxCommand::Nop), 29 | _ => Err(MuxError::InvalidCommand(value)), 30 | } 31 | } 32 | } 33 | 34 | #[derive(Copy, Clone, Eq, PartialEq, Debug)] 35 | pub(crate) struct MuxFrameHeader { 36 | pub version: u8, 37 | pub command: MuxCommand, 38 | pub length: u16, 39 | pub stream_id: u32, 40 | } 41 | 42 | impl MuxFrameHeader { 43 | #[inline] 44 | fn encode(&self, buf: &mut BytesMut) { 45 | buf.put_u8(self.version); 46 | buf.put_u8(self.command as u8); 47 | buf.put_u16_le(self.length); 48 | buf.put_u32_le(self.stream_id); 49 | } 50 | 51 | #[inline] 52 | fn decode(buf: &[u8]) -> MuxResult { 53 | let mut cursor = Cursor::new(buf); 54 | let version = cursor.get_u8(); 55 | if version != SMUX_VERSION { 56 | return Err(MuxError::InvalidVersion(version)); 57 | } 58 | let command = MuxCommand::try_from(cursor.get_u8())?; 59 | let length = cursor.get_u16_le(); 60 | let stream_id = cursor.get_u32_le(); 61 | Ok(Self { 62 | version, 63 | command, 64 | length, 65 | stream_id, 66 | }) 67 | } 68 | } 69 | 70 | #[derive(Clone)] 71 | pub(crate) struct MuxFrame { 72 | pub header: MuxFrameHeader, 73 | pub payload: Bytes, 74 | } 75 | 76 | impl MuxFrame { 77 | pub fn new(command: MuxCommand, stream_id: u32, payload: Bytes) -> Self { 78 | assert!(payload.len() <= MAX_PAYLOAD_SIZE); 79 | Self { 80 | header: MuxFrameHeader { 81 | version: SMUX_VERSION, 82 | command, 83 | length: payload.len() as u16, 84 | stream_id, 85 | }, 86 | payload, 87 | } 88 | } 89 | } 90 | 91 | pub(crate) struct MuxCodec {} 92 | 93 | impl Decoder for MuxCodec { 94 | type Item = MuxFrame; 95 | type Error = MuxError; 96 | 97 | fn decode(&mut self, src: &mut BytesMut) -> Result, Self::Error> { 98 | src.reserve(HEADER_SIZE + MAX_PAYLOAD_SIZE + HEADER_SIZE); 99 | 100 | if src.len() < HEADER_SIZE { 101 | return Ok(None); 102 | } 103 | let header = MuxFrameHeader::decode(src)?; 104 | let len = header.length as usize; 105 | if src.len() < HEADER_SIZE + len { 106 | return Ok(None); 107 | } 108 | src.advance(HEADER_SIZE); 109 | let payload = src.split_to(len).freeze(); 110 | 111 | debug_assert!(payload.len() == len); 112 | let frame = MuxFrame { header, payload }; 113 | 114 | Ok(Some(frame)) 115 | } 116 | } 117 | 118 | impl Encoder for MuxCodec { 119 | type Error = 
MuxError;
120 | 
121 |     fn encode(&mut self, item: MuxFrame, dst: &mut BytesMut) -> Result<(), Self::Error> {
122 |         if item.header.version != SMUX_VERSION {
123 |             return Err(MuxError::InvalidVersion(item.header.version));
124 |         }
125 | 
126 |         if item.payload.len() > MAX_PAYLOAD_SIZE {
127 |             return Err(MuxError::PayloadTooLarge(item.payload.len()));
128 |         }
129 | 
130 |         item.header.encode(dst);
131 |         dst.put_slice(&item.payload);
132 | 
133 |         Ok(())
134 |     }
135 | }
136 | 
--------------------------------------------------------------------------------
/src/lib.rs:
--------------------------------------------------------------------------------
1 | //! A lightweight and fast asynchronous [smux](https://github.com/xtaci/smux) (Simple MUltipleXing) library for the Tokio async runtime.
2 | //! # Quickstart
3 | //!
4 | //! ## Server
5 | //! ```ignore
6 | //! // Initialize a stream implementing `AsyncRead + AsyncWrite`, e.g. a TcpStream
7 | //! let tcp_connection = ...
8 | //! // Build a smux server to multiplex the tcp stream using `MuxBuilder`
9 | //! let (connector, acceptor, worker) = MuxBuilder::server().with_connection(tcp_connection).build();
10 | //! // Spawn the smux worker (or a worker `future`, more precisely)
11 | //! // The worker keeps running and dispatches smux frames until you drop (or close) all streams, acceptors and connectors
12 | //! tokio::spawn(worker);
13 | //!
14 | //! // Now we are ready to go!
15 | //! // Both client and server can open and accept bi-directional streams
16 | //! let outgoing_stream = connector.connect().unwrap();
17 | //! let incoming_stream = acceptor.accept().await.unwrap();
18 | //!
19 | //! // Just use these smux streams like normal tcp streams :)
20 | //! incoming_stream.read(...).await.unwrap();
21 | //! incoming_stream.write_all(...).await.unwrap();
22 | //! ```
23 | //! ## Client
24 | //! ```ignore
25 | //! let tcp_connection = ...
26 | //! // Just like the server side, except that we call the `client()` function this time
27 | //! let (connector, acceptor, worker) = MuxBuilder::client().with_connection(tcp_connection).build();
28 | //! tokio::spawn(worker);
29 | //!
30 | //! let outgoing_stream1 = connector.connect().unwrap();
31 | //! ...
32 | //! 
``` 33 | 34 | pub mod builder; 35 | pub mod config; 36 | pub mod error; 37 | pub(crate) mod frame; 38 | pub(crate) mod mux; 39 | 40 | pub use builder::MuxBuilder; 41 | pub use config::{MuxConfig, StreamIdType}; 42 | pub use mux::{mux_connection, MuxAcceptor, MuxConnector, MuxStream}; 43 | 44 | #[cfg(test)] 45 | mod tests { 46 | use std::{future::poll_fn, num::NonZeroU64, pin::Pin, task::Poll, time::Duration}; 47 | 48 | use rand::RngCore; 49 | use tokio::{ 50 | io::{AsyncRead, AsyncReadExt, AsyncWriteExt, ReadBuf}, 51 | net::{TcpListener, TcpStream}, 52 | }; 53 | 54 | use crate::{builder::MuxBuilder, frame::MAX_PAYLOAD_SIZE, mux::TokioConn, MuxStream}; 55 | 56 | async fn get_tcp_pair() -> (TcpStream, TcpStream) { 57 | let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); 58 | let addr = listener.local_addr().unwrap(); 59 | let h = tokio::spawn(async move { 60 | let (a, _) = listener.accept().await.unwrap(); 61 | a 62 | }); 63 | 64 | let b = TcpStream::connect(addr).await.unwrap(); 65 | let a = h.await.unwrap(); 66 | (a, b) 67 | } 68 | 69 | async fn test_stream(mut a: MuxStream, mut b: MuxStream) { 70 | const LEN: usize = MAX_PAYLOAD_SIZE + 0x200; 71 | let mut data1 = vec![0; LEN]; 72 | let mut data2 = vec![0; LEN]; 73 | rand::thread_rng().fill_bytes(&mut data1); 74 | rand::thread_rng().fill_bytes(&mut data2); 75 | 76 | let mut buf = vec![0; LEN]; 77 | 78 | a.write_all(&data1).await.unwrap(); 79 | a.flush().await.unwrap(); 80 | b.write_all(&data2).await.unwrap(); 81 | b.flush().await.unwrap(); 82 | 83 | a.read_exact(&mut buf).await.unwrap(); 84 | assert_eq!(buf, data2); 85 | b.read_exact(&mut buf).await.unwrap(); 86 | assert_eq!(buf, data1); 87 | 88 | a.write_all(&data1).await.unwrap(); 89 | a.flush().await.unwrap(); 90 | b.read_exact(&mut buf[..LEN / 2]).await.unwrap(); 91 | b.read_exact(&mut buf[LEN / 2..]).await.unwrap(); 92 | assert_eq!(buf, data1); 93 | 94 | a.write_all(&data1[..LEN / 2]).await.unwrap(); 95 | a.flush().await.unwrap(); 96 | b.read_exact(&mut buf[..LEN / 2]).await.unwrap(); 97 | assert_eq!(buf[..LEN / 2], data1[..LEN / 2]); 98 | 99 | a.shutdown().await.unwrap(); 100 | b.shutdown().await.unwrap(); 101 | } 102 | 103 | #[tokio::test(flavor = "multi_thread")] 104 | async fn test_tcp() { 105 | let (a, b) = get_tcp_pair().await; 106 | let (connector_a, mut acceptor_a, worker_a) = 107 | MuxBuilder::client().with_connection(a).build(); 108 | let (connector_b, mut acceptor_b, worker_b) = 109 | MuxBuilder::server().with_connection(b).build(); 110 | tokio::spawn(worker_a); 111 | tokio::spawn(worker_b); 112 | 113 | let stream1 = connector_a.clone().connect().unwrap(); 114 | let stream2 = acceptor_b.accept().await.unwrap(); 115 | test_stream(stream1, stream2).await; 116 | 117 | let stream1 = connector_b.connect().unwrap(); 118 | let stream2 = acceptor_a.accept().await.unwrap(); 119 | test_stream(stream1, stream2).await; 120 | 121 | assert_eq!(connector_a.get_num_streams(), 0); 122 | assert_eq!(connector_b.get_num_streams(), 0); 123 | 124 | let mut streams1 = vec![]; 125 | let mut streams2 = vec![]; 126 | const STREAM_NUM: usize = 0x1000; 127 | for _ in 0..STREAM_NUM { 128 | let stream = connector_a.connect().unwrap(); 129 | streams1.push(stream); 130 | } 131 | for _ in 0..STREAM_NUM { 132 | let stream = acceptor_b.accept().await.unwrap(); 133 | streams2.push(stream); 134 | } 135 | 136 | let handles = streams1 137 | .into_iter() 138 | .zip(streams2.into_iter()) 139 | .map(|(a, b)| { 140 | tokio::spawn(async move { 141 | test_stream(a, b).await; 142 | }) 143 | }) 144 | 
.collect::>(); 145 | 146 | for h in handles { 147 | h.await.unwrap(); 148 | } 149 | 150 | assert_eq!(connector_a.get_num_streams(), 0); 151 | assert_eq!(connector_b.get_num_streams(), 0); 152 | } 153 | 154 | #[tokio::test(flavor = "multi_thread")] 155 | async fn test_worker_drop() { 156 | let (a, b) = get_tcp_pair().await; 157 | let (connector_a, mut acceptor_a, worker_a) = 158 | MuxBuilder::client().with_connection(a).build(); 159 | let (connector_b, mut acceptor_b, worker_b) = 160 | MuxBuilder::server().with_connection(b).build(); 161 | let mut stream1 = connector_a.connect().unwrap(); 162 | let h1 = tokio::spawn(async move { 163 | let mut buf = vec![0; 0x100]; 164 | stream1.read_exact(&mut buf).await.unwrap_err(); 165 | }); 166 | 167 | drop(worker_a); 168 | drop(worker_b); 169 | 170 | assert!(connector_a.connect().is_err()); 171 | assert!(connector_b.connect().is_err()); 172 | assert!(acceptor_a.accept().await.is_none()); 173 | assert!(acceptor_b.accept().await.is_none()); 174 | h1.await.unwrap(); 175 | } 176 | 177 | #[tokio::test] 178 | async fn test_shutdown() { 179 | let (a, b) = get_tcp_pair().await; 180 | let (connector_a, acceptor_a, worker_a) = MuxBuilder::client().with_connection(a).build(); 181 | let (connector_b, mut acceptor_b, worker_b) = 182 | MuxBuilder::server().with_connection(b).build(); 183 | tokio::spawn(worker_a); 184 | tokio::spawn(worker_b); 185 | 186 | let mut stream1 = connector_a.connect().unwrap(); 187 | let mut stream2 = acceptor_b.accept().await.unwrap(); 188 | 189 | let data = [1, 2, 3, 4]; 190 | stream2.write_all(&data).await.unwrap(); 191 | stream2.shutdown().await.unwrap(); 192 | 193 | tokio::time::sleep(Duration::from_secs(1)).await; 194 | 195 | stream1.write_all(&[0, 1, 2, 3]).await.unwrap_err(); 196 | stream1.flush().await.unwrap_err(); 197 | let mut buf = vec![0; 4]; 198 | stream1.read_exact(&mut buf).await.unwrap(); 199 | assert_eq!(buf, data); 200 | assert_eq!(stream1.read(&mut buf).await.unwrap(), 0); 201 | 202 | drop(acceptor_a); 203 | let mut stream = connector_b.connect().unwrap(); 204 | assert_eq!(stream.read(&mut buf).await.unwrap(), 0); 205 | stream.flush().await.unwrap_err(); 206 | stream.shutdown().await.unwrap(); 207 | 208 | let mut stream1 = connector_a.connect().unwrap(); 209 | let mut stream2 = acceptor_b.accept().await.unwrap(); 210 | stream1.write_all(&data).await.unwrap(); 211 | stream1.flush().await.unwrap(); 212 | drop(stream1); 213 | tokio::time::sleep(Duration::from_secs(1)).await; 214 | 215 | let mut buf = vec![0; 4]; 216 | stream2.read_exact(&mut buf).await.unwrap(); 217 | assert!(buf == data); 218 | stream2.read_exact(&mut buf).await.unwrap_err(); 219 | stream2.write_all(&data).await.unwrap_err(); 220 | } 221 | 222 | #[tokio::test] 223 | async fn test_timeout() { 224 | let (a, b) = get_tcp_pair().await; 225 | let (connector_a, _, worker_a) = MuxBuilder::client() 226 | .with_idle_timeout(NonZeroU64::new(3).unwrap()) 227 | .with_connection(a) 228 | .build(); 229 | let (_, mut acceptor_b, worker_b) = MuxBuilder::server().with_connection(b).build(); 230 | tokio::spawn(async move { 231 | worker_a.await.unwrap(); 232 | }); 233 | tokio::spawn(async move { 234 | worker_b.await.unwrap(); 235 | }); 236 | 237 | let mut stream1 = connector_a.connect().unwrap(); 238 | let mut stream2 = acceptor_b.accept().await.unwrap(); 239 | tokio::time::sleep(Duration::from_secs(1)).await; 240 | assert!(!stream1.is_closed()); 241 | assert!(!stream2.is_closed()); 242 | 243 | tokio::time::sleep(Duration::from_secs(5)).await; 244 | 245 | 
assert!(stream1.is_closed()); 246 | assert!(stream2.is_closed()); 247 | } 248 | 249 | #[tokio::test] 250 | async fn test_recv_block() { 251 | let (a, b) = get_tcp_pair().await; 252 | let (connector_a, _, worker_a) = MuxBuilder::client().with_connection(a).build(); 253 | let (_, mut acceptor_b, worker_b) = MuxBuilder::server() 254 | .with_max_rx_queue(12.try_into().unwrap()) 255 | .with_connection(b) 256 | .build(); 257 | tokio::spawn(async move { 258 | worker_a.await.unwrap(); 259 | }); 260 | tokio::spawn(async move { 261 | worker_b.await.unwrap(); 262 | }); 263 | 264 | let mut stream_x1 = connector_a.connect().unwrap(); 265 | let mut stream_x2 = acceptor_b.accept().await.unwrap(); 266 | 267 | let mut stream_y1 = connector_a.connect().unwrap(); 268 | let mut stream_y2 = acceptor_b.accept().await.unwrap(); 269 | 270 | let data = &[1, 2, 3, 4]; 271 | for _ in 0..3 { 272 | stream_x1.write_all(data).await.unwrap(); 273 | } 274 | // stream_x is full now 275 | stream_y1.write_all(data).await.unwrap(); 276 | 277 | // stream_y should be blocked unless x incoming bytes is handled 278 | poll_fn(|cx| { 279 | let mut buf = [0; 128]; 280 | let mut buf = ReadBuf::new(&mut buf); 281 | let res = Pin::new(&mut stream_y2).poll_read(cx, &mut buf); 282 | assert!(res.is_pending()); 283 | Poll::Ready(()) 284 | }) 285 | .await; 286 | 287 | let mut buf = [0; 4]; 288 | for _ in 0..3 { 289 | stream_x2.read_exact(&mut buf).await.unwrap(); 290 | assert_eq!(&buf, data); 291 | } 292 | 293 | // stream_y is avaliable now 294 | poll_fn(|cx| { 295 | let mut buf_arr = [0; 128]; 296 | let mut buf = ReadBuf::new(&mut buf_arr); 297 | let res = Pin::new(&mut stream_y2).poll_read(cx, &mut buf); 298 | assert!(res.is_ready()); 299 | Poll::Ready(()) 300 | }) 301 | .await; 302 | } 303 | 304 | #[tokio::test] 305 | async fn test_connection_drop() { 306 | let (a, b) = get_tcp_pair().await; 307 | let (connector_a, _, worker_a) = MuxBuilder::client().with_connection(a).build(); 308 | let (_, mut acceptor_b, worker_b) = MuxBuilder::server().with_connection(b).build(); 309 | tokio::spawn(worker_a); 310 | tokio::spawn(worker_b); 311 | 312 | let mut _stream1 = connector_a.connect().unwrap(); 313 | let mut stream2 = acceptor_b.accept().await.unwrap(); 314 | 315 | drop(_stream1); 316 | tokio::time::sleep(Duration::from_secs(1)).await; 317 | 318 | assert!(stream2.write_all(b"1234").await.is_err()); 319 | } 320 | 321 | #[tokio::test] 322 | async fn test_inner_shutdown() { 323 | let (a, b) = get_tcp_pair().await; 324 | 325 | let (connector_a, mut acceptor_a, worker_a) = 326 | MuxBuilder::client().with_connection(a).build(); 327 | let (connector_b, mut acceptor_b, worker_b) = 328 | MuxBuilder::server().with_connection(b).build(); 329 | 330 | let a_res = tokio::spawn(worker_a); 331 | drop(worker_b); 332 | tokio::time::sleep(Duration::from_secs(2)).await; 333 | 334 | assert!(connector_b.connect().is_err()); 335 | assert!(acceptor_b.accept().await.is_none()); 336 | 337 | drop(connector_b); 338 | drop(acceptor_b); 339 | 340 | tokio::time::sleep(Duration::from_secs(2)).await; 341 | assert!(connector_a.connect().is_err()); 342 | assert!(acceptor_a.accept().await.is_none()); 343 | a_res.await.unwrap().unwrap_err(); 344 | } 345 | } 346 | -------------------------------------------------------------------------------- /src/mux.rs: -------------------------------------------------------------------------------- 1 | use std::{ 2 | collections::{HashMap, VecDeque}, 3 | io::ErrorKind, 4 | num::Wrapping, 5 | pin::Pin, 6 | sync::Arc, 7 | task::{Context, Poll, 
Waker}, 8 | time::Duration, 9 | }; 10 | 11 | use bytes::{Buf, Bytes}; 12 | use futures::{future::poll_fn, ready, Future, FutureExt, SinkExt, Stream, StreamExt}; 13 | use futures_sink::Sink; 14 | use log::debug; 15 | use parking_lot::Mutex; 16 | use std::io as StdIo; 17 | use tokio::{ 18 | io::{self, AsyncRead, AsyncWrite}, 19 | time::{interval, Interval}, 20 | }; 21 | use tokio_util::codec::Framed; 22 | 23 | use crate::{ 24 | config::{MuxConfig, StreamIdType}, 25 | error::{MuxError, MuxResult}, 26 | frame::{MuxCodec, MuxCommand, MuxFrame, MAX_PAYLOAD_SIZE}, 27 | }; 28 | 29 | pub trait TokioConn: AsyncRead + AsyncWrite + Unpin {} 30 | 31 | impl TokioConn for T where T: AsyncRead + AsyncWrite + Unpin {} 32 | 33 | pub fn mux_connection( 34 | connection: T, 35 | config: MuxConfig, 36 | ) -> (MuxConnector, MuxAcceptor, MuxWorker) { 37 | let inner = Framed::new(connection, MuxCodec {}); 38 | let state = Arc::new(Mutex::new(MuxState { 39 | inner, 40 | handles: HashMap::new(), 41 | accept_queue: VecDeque::new(), 42 | accept_waker: None, 43 | tx_queue: VecDeque::with_capacity(config.max_tx_queue.get()), 44 | should_tx_waker: None, 45 | rx_consumed_waker: None, 46 | closed: false, 47 | accept_closed: false, 48 | stream_id_hint: Wrapping(config.stream_id_type as u32), 49 | stream_id_type: config.stream_id_type, 50 | idle_timeout: config.idle_timeout.map_or(1, |num| num.get()), 51 | max_tx_queue: config.max_tx_queue.get(), 52 | max_rx_queue: config.max_rx_queue.get(), 53 | })); 54 | ( 55 | MuxConnector { 56 | state: state.clone(), 57 | }, 58 | MuxAcceptor { 59 | state: state.clone(), 60 | }, 61 | MuxWorker { 62 | dispatcher: MuxDispatcher { 63 | state: state.clone(), 64 | }, 65 | sender: MuxSender { 66 | state: state.clone(), 67 | }, 68 | timer: MuxTimer { 69 | state, 70 | interval: interval(Duration::from_millis(500)), 71 | keep_alive_interval: config.keep_alive_interval.map(|a| interval(Duration::from_secs(a.get()))), 72 | idle_timeout_enabled: config.idle_timeout.is_some(), 73 | }, 74 | }, 75 | ) 76 | } 77 | 78 | pub struct MuxConnector { 79 | state: Arc>>, 80 | } 81 | 82 | impl MuxConnector { 83 | pub fn connect(&self) -> MuxResult> { 84 | let mut state = self.state.lock(); 85 | state.check_closed()?; 86 | 87 | let stream_id = state.alloc_stream_id()?; 88 | state.process_sync(stream_id, Direction::Tx)?; 89 | let frame = MuxFrame::new(MuxCommand::Sync, stream_id, Bytes::new()); 90 | state.enqueue_frame_global(frame); 91 | state.notify_should_tx(); 92 | 93 | let stream = MuxStream { 94 | stream_id, 95 | state: self.state.clone(), 96 | read_buffer: None, 97 | }; 98 | Ok(stream) 99 | } 100 | 101 | pub async fn close(&mut self) -> MuxResult<()> { 102 | poll_fn(|cx| { 103 | let mut state = self.state.lock(); 104 | state.close(); 105 | state.inner.poll_close_unpin(cx) 106 | }) 107 | .await?; 108 | Ok(()) 109 | } 110 | 111 | pub fn get_num_streams(&self) -> usize { 112 | self.state.lock().handles.len() 113 | } 114 | } 115 | 116 | impl Clone for MuxConnector { 117 | fn clone(&self) -> Self { 118 | Self { 119 | state: self.state.clone(), 120 | } 121 | } 122 | } 123 | 124 | pub struct MuxAcceptor { 125 | state: Arc>>, 126 | } 127 | 128 | impl Drop for MuxAcceptor { 129 | fn drop(&mut self) { 130 | self.state.lock().accept_closed = true; 131 | } 132 | } 133 | 134 | impl MuxAcceptor { 135 | pub async fn accept(&mut self) -> Option> { 136 | self.next().await 137 | } 138 | } 139 | 140 | impl Stream for MuxAcceptor { 141 | type Item = MuxStream; 142 | 143 | fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) 
-> Poll> { 144 | let mut state = self.state.lock(); 145 | if state.check_closed().is_err() { 146 | return Poll::Ready(None); 147 | } 148 | 149 | if let Some(stream) = state.accept_queue.pop_front() { 150 | Poll::Ready(Some(stream)) 151 | } else { 152 | state.register_accept_waker(cx); 153 | Poll::Pending 154 | } 155 | } 156 | } 157 | 158 | struct MuxTimer { 159 | state: Arc>>, 160 | interval: Interval, 161 | 162 | keep_alive_interval: Option, 163 | 164 | idle_timeout_enabled: bool, 165 | } 166 | 167 | impl Future for MuxTimer { 168 | type Output = MuxResult<()>; 169 | 170 | fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { 171 | loop { 172 | ready!(self.interval.poll_tick(cx)); 173 | self.interval.reset(); 174 | 175 | // Ping check 176 | let mut is_ping_send_needs = false; 177 | if let Some(keep_alive_interval) = self.keep_alive_interval.as_mut() { 178 | if keep_alive_interval.poll_tick(cx).is_ready() { 179 | keep_alive_interval.reset(); 180 | 181 | is_ping_send_needs = true; 182 | } 183 | } 184 | 185 | let mut state = self.state.lock(); 186 | 187 | // Ping send 188 | if is_ping_send_needs { 189 | state.enqueue_frame_global(MuxFrame::new(MuxCommand::Nop, 0, Bytes::new())); 190 | state.notify_should_tx(); 191 | } 192 | 193 | // Clean timeout streams 194 | if self.idle_timeout_enabled { 195 | let dead_ids = state 196 | .handles 197 | .iter_mut() 198 | .filter_map(|(id, h)| { 199 | if h.idle_interval.poll_tick(cx).is_ready() { 200 | Some(*id) 201 | } else { 202 | None 203 | } 204 | }) 205 | .collect::>(); 206 | 207 | for stream_id in dead_ids { 208 | state.try_mark_finish(stream_id); 209 | state.send_finish(stream_id); 210 | state.notify_rx_consumed(); 211 | } 212 | } 213 | } 214 | } 215 | } 216 | 217 | struct MuxSender { 218 | state: Arc>>, 219 | } 220 | 221 | impl Future for MuxSender { 222 | type Output = MuxResult<()>; 223 | 224 | fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { 225 | loop { 226 | let mut state = self.state.lock(); 227 | state.check_closed()?; 228 | ready!(state.poll_flush_frames(cx)).inspect_err(|_| state.close())?; 229 | ready!(state.poll_flush_inner(cx)).inspect_err(|_| state.close())?; 230 | ready!(state.poll_should_tx(cx)); 231 | } 232 | } 233 | } 234 | 235 | impl Drop for MuxSender { 236 | fn drop(&mut self) { 237 | self.state.lock().close(); 238 | } 239 | } 240 | 241 | struct MuxDispatcher { 242 | state: Arc>>, 243 | } 244 | 245 | impl Drop for MuxDispatcher { 246 | fn drop(&mut self) { 247 | self.state.lock().close(); 248 | } 249 | } 250 | 251 | impl Future for MuxDispatcher { 252 | type Output = MuxResult<()>; 253 | 254 | fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { 255 | loop { 256 | let mut state = self.state.lock(); 257 | state.check_closed()?; 258 | 259 | ready!(state.poll_ready_rx_consumed(cx)); // Can stuck here forever, be careful 260 | 261 | let frame = ready!(state.poll_next_frame(cx)).inspect_err(|_| state.close())?; 262 | match frame.header.command { 263 | MuxCommand::Sync => { 264 | if state.accept_closed { 265 | state.send_finish(frame.header.stream_id); 266 | continue; 267 | } 268 | 269 | state.process_sync(frame.header.stream_id, Direction::Rx)?; 270 | 271 | let stream = MuxStream { 272 | stream_id: frame.header.stream_id, 273 | state: self.state.clone(), 274 | read_buffer: None, 275 | }; 276 | state.accept_queue.push_back(stream); 277 | state.notify_accept_stream(); 278 | } 279 | MuxCommand::Finish => { 280 | state.try_mark_finish(frame.header.stream_id); 281 | } 282 | MuxCommand::Push => { 283 | let 
stream_id = frame.header.stream_id; 284 | if !state.recv_push(frame) { 285 | state.send_finish(stream_id); 286 | } 287 | } 288 | MuxCommand::Nop => { 289 | // Do nothing 290 | } 291 | } 292 | } 293 | } 294 | } 295 | 296 | pub struct MuxWorker { 297 | dispatcher: MuxDispatcher, 298 | sender: MuxSender, 299 | timer: MuxTimer, 300 | } 301 | 302 | impl Future for MuxWorker { 303 | type Output = MuxResult<()>; 304 | 305 | fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { 306 | let _ = self.timer.poll_unpin(cx); 307 | 308 | if self.dispatcher.poll_unpin(cx)?.is_ready() { 309 | return Poll::Ready(Ok(())); 310 | } 311 | 312 | if self.sender.poll_unpin(cx)?.is_ready() { 313 | return Poll::Ready(Ok(())); 314 | } 315 | 316 | Poll::Pending 317 | } 318 | } 319 | 320 | pub struct MuxStream { 321 | stream_id: u32, 322 | state: Arc>>, 323 | read_buffer: Option, 324 | } 325 | 326 | impl Drop for MuxStream { 327 | fn drop(&mut self) { 328 | let mut state = self.state.lock(); 329 | if !state.is_closed(self.stream_id) { 330 | // The user did not call `shutdown()` 331 | state.enqueue_frame_global(MuxFrame::new( 332 | MuxCommand::Finish, 333 | self.stream_id, 334 | Bytes::new(), 335 | )); 336 | state.notify_should_tx(); 337 | } 338 | state.remove_stream(self.stream_id); 339 | } 340 | } 341 | 342 | impl AsyncRead for MuxStream { 343 | fn poll_read( 344 | mut self: Pin<&mut Self>, 345 | cx: &mut Context<'_>, 346 | buf: &mut io::ReadBuf<'_>, 347 | ) -> Poll> { 348 | loop { 349 | if let Some(read_buffer) = &mut self.read_buffer { 350 | if read_buffer.len() <= buf.remaining() { 351 | buf.put_slice(read_buffer); 352 | self.read_buffer = None; 353 | } else { 354 | let len = buf.remaining(); 355 | buf.put_slice(&read_buffer[..len]); 356 | read_buffer.advance(len); 357 | } 358 | return Poll::Ready(Ok(())); 359 | } 360 | 361 | let frame = ready!(self.state.lock().poll_read_stream_data(cx, self.stream_id)) 362 | .map_err(mux_to_io_err)?; 363 | 364 | if let Some(frame) = frame { 365 | debug_assert_eq!(frame.header.command, MuxCommand::Push); 366 | self.read_buffer = Some(frame.payload); 367 | } else { 368 | // EOF 369 | return Poll::Ready(Ok(())); 370 | } 371 | } 372 | } 373 | } 374 | 375 | #[inline] 376 | fn mux_to_io_err(e: MuxError) -> StdIo::Error { 377 | StdIo::Error::new(ErrorKind::Other, e) 378 | } 379 | 380 | #[inline] 381 | fn new_io_err(kind: ErrorKind, reason: &str) -> StdIo::Error { 382 | StdIo::Error::new(kind, reason) 383 | } 384 | 385 | impl AsyncWrite for MuxStream { 386 | fn poll_write( 387 | self: Pin<&mut Self>, 388 | cx: &mut Context<'_>, 389 | buf: &[u8], 390 | ) -> Poll> { 391 | let mut state = self.state.lock(); 392 | if state.is_closed(self.stream_id) { 393 | return Poll::Ready(Err(new_io_err( 394 | StdIo::ErrorKind::ConnectionReset, 395 | "stream tx is already closed", 396 | ))); 397 | } 398 | 399 | ready!(state.poll_stream_write_ready(cx, self.stream_id)).map_err(mux_to_io_err)?; 400 | 401 | let mut write_buffer = Bytes::copy_from_slice(buf); 402 | while !write_buffer.is_empty() { 403 | let len = write_buffer.len().min(MAX_PAYLOAD_SIZE); 404 | let payload = write_buffer.split_to(len); 405 | let frame = MuxFrame::new(MuxCommand::Push, self.stream_id, payload); 406 | state.enqueue_frame_stream(self.stream_id, frame); 407 | } 408 | state.notify_should_tx(); 409 | Poll::Ready(Ok(buf.len())) 410 | } 411 | 412 | fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { 413 | let mut state = self.state.lock(); 414 | if state.is_closed(self.stream_id) { 415 | return 
Poll::Ready(Err(new_io_err( 416 | StdIo::ErrorKind::ConnectionReset, 417 | "stream tx is already closed", 418 | ))); 419 | } 420 | 421 | state 422 | .poll_flush_stream_frames(cx, self.stream_id) 423 | .map_err(mux_to_io_err) 424 | } 425 | 426 | fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { 427 | loop { 428 | let mut state = self.state.lock(); 429 | ready!(state 430 | .poll_flush_stream_frames(cx, self.stream_id) 431 | .map_err(mux_to_io_err))?; 432 | 433 | if state.is_closed(self.stream_id) { 434 | return Poll::Ready(Ok(())); 435 | } 436 | 437 | state.try_mark_finish(self.stream_id); 438 | state.send_finish(self.stream_id); 439 | } 440 | } 441 | } 442 | 443 | impl MuxStream { 444 | pub fn is_closed(&mut self) -> bool { 445 | self.state.lock().is_closed(self.stream_id) 446 | } 447 | 448 | pub fn get_stream_id(&self) -> u32 { 449 | self.stream_id 450 | } 451 | } 452 | 453 | struct StreamHandle { 454 | closed: bool, 455 | 456 | tx_queue: VecDeque, 457 | tx_done_waker: Option, 458 | 459 | rx_queue: VecDeque, 460 | rx_ready_waker: Option, 461 | 462 | idle_interval: Interval, 463 | } 464 | 465 | impl StreamHandle { 466 | fn new(idle_interval: Interval) -> Self { 467 | Self { 468 | closed: false, 469 | tx_queue: VecDeque::with_capacity(128), 470 | tx_done_waker: None, 471 | rx_queue: VecDeque::with_capacity(128), 472 | rx_ready_waker: None, 473 | idle_interval, 474 | } 475 | } 476 | 477 | #[inline] 478 | fn register_tx_done_waker(&mut self, cx: &Context<'_>) { 479 | self.tx_done_waker = Some(cx.waker().clone()); 480 | } 481 | 482 | #[inline] 483 | fn register_rx_ready_waker(&mut self, cx: &Context<'_>) { 484 | self.rx_ready_waker = Some(cx.waker().clone()); 485 | } 486 | 487 | #[inline] 488 | fn notify_rx_ready(&mut self) { 489 | if let Some(waker) = self.rx_ready_waker.take() { 490 | waker.wake(); 491 | } 492 | } 493 | 494 | #[inline] 495 | fn notify_tx_done(&mut self) { 496 | if let Some(waker) = self.tx_done_waker.take() { 497 | waker.wake(); 498 | } 499 | } 500 | } 501 | 502 | #[derive(Debug, Clone, Copy)] 503 | enum Direction { 504 | Tx, 505 | Rx, 506 | } 507 | 508 | struct MuxState { 509 | inner: Framed, 510 | handles: HashMap, 511 | 512 | accept_queue: VecDeque>, 513 | accept_waker: Option, 514 | 515 | tx_queue: VecDeque, 516 | should_tx_waker: Option, 517 | rx_consumed_waker: Option, 518 | 519 | closed: bool, 520 | accept_closed: bool, 521 | 522 | stream_id_hint: Wrapping, 523 | stream_id_type: StreamIdType, 524 | 525 | idle_timeout: u64, 526 | 527 | max_tx_queue: usize, 528 | max_rx_queue: usize, 529 | } 530 | 531 | impl Drop for MuxState { 532 | fn drop(&mut self) { 533 | debug!("mux state dropped"); 534 | } 535 | } 536 | 537 | impl MuxState { 538 | fn alloc_stream_id(&mut self) -> MuxResult { 539 | if self.handles.len() >= (u32::MAX / 2) as usize { 540 | return Err(MuxError::TooManyStreams); 541 | } 542 | 543 | loop { 544 | self.stream_id_hint += 2; 545 | 546 | if !self.handles.contains_key(&self.stream_id_hint.0) { 547 | break; 548 | } 549 | } 550 | 551 | Ok(self.stream_id_hint.0) 552 | } 553 | 554 | #[inline] 555 | fn remove_stream(&mut self, stream_id: u32) { 556 | self.handles.remove(&stream_id).unwrap(); 557 | // Rx queue may change 558 | self.notify_rx_consumed(); 559 | } 560 | 561 | fn send_finish(&mut self, stream_id: u32) { 562 | self.enqueue_frame_global(MuxFrame::new(MuxCommand::Finish, stream_id, Bytes::new())); 563 | self.notify_should_tx(); 564 | } 565 | 566 | fn process_sync(&mut self, stream_id: u32, dir: Direction) -> MuxResult<()> { 567 | if 
self.handles.contains_key(&stream_id) { 568 | return Err(MuxError::DuplicatedStreamId(stream_id)); 569 | } 570 | 571 | let from_peer = matches!(dir, Direction::Rx); 572 | if (stream_id % 2 != self.stream_id_type as u32) ^ from_peer { 573 | return Err(MuxError::InvalidPeerStreamIdType( 574 | stream_id, 575 | self.stream_id_type, 576 | )); 577 | } 578 | 579 | let handle = StreamHandle::new(self.get_idle_interval()); 580 | self.handles.insert(stream_id, handle); 581 | Ok(()) 582 | } 583 | 584 | #[inline] 585 | fn try_mark_finish(&mut self, stream_id: u32) { 586 | if let Some(h) = self.handles.get_mut(&stream_id) { 587 | h.closed = true; 588 | h.notify_rx_ready(); 589 | h.notify_tx_done(); 590 | } 591 | } 592 | 593 | fn recv_push(&mut self, frame: MuxFrame) -> bool { 594 | if let Some(handle) = self.handles.get_mut(&frame.header.stream_id) { 595 | handle.rx_queue.push_back(frame); 596 | handle.notify_rx_ready(); 597 | handle.idle_interval.reset(); 598 | true 599 | } else { 600 | false 601 | } 602 | } 603 | 604 | #[inline] 605 | fn get_rx_pending(&mut self) -> usize { 606 | self.handles 607 | .values() 608 | .filter(|h| !h.closed) 609 | .map(|h| h.rx_queue.len()) 610 | .sum() 611 | } 612 | 613 | fn poll_ready_rx_consumed(&mut self, cx: &Context<'_>) -> Poll<()> { 614 | let pending = self.get_rx_pending(); 615 | if pending > self.max_rx_queue { 616 | self.register_rx_consumed_waker(cx); 617 | Poll::Pending 618 | } else { 619 | Poll::Ready(()) 620 | } 621 | } 622 | 623 | fn is_closed(&self, stream_id: u32) -> bool { 624 | self.handles.get(&stream_id).unwrap().closed 625 | } 626 | 627 | fn poll_next_frame(&mut self, cx: &mut Context<'_>) -> Poll> { 628 | if let Some(r) = ready!(self.inner.poll_next_unpin(cx)) { 629 | let frame = r?; 630 | Poll::Ready(Ok(frame)) 631 | } else { 632 | Poll::Ready(Err(MuxError::ConnectionClosed)) 633 | } 634 | } 635 | 636 | #[inline] 637 | fn pin_inner(&mut self) -> Pin<&mut Framed> { 638 | Pin::new(&mut self.inner) 639 | } 640 | 641 | fn poll_write_ready(&mut self, cx: &mut Context<'_>) -> Poll> { 642 | ready!(self.pin_inner().poll_ready(cx))?; 643 | Poll::Ready(Ok(())) 644 | } 645 | 646 | fn write_frame(&mut self, frame: MuxFrame) -> MuxResult<()> { 647 | self.pin_inner().start_send(frame)?; 648 | Ok(()) 649 | } 650 | 651 | fn poll_read_stream_data( 652 | &mut self, 653 | cx: &mut Context<'_>, 654 | stream_id: u32, 655 | ) -> Poll>> { 656 | let handle = self.handles.get_mut(&stream_id).unwrap(); 657 | if let Some(f) = handle.rx_queue.pop_front() { 658 | self.notify_rx_consumed(); // Rx queue packet consumed 659 | Poll::Ready(Ok(Some(f))) 660 | } else if self.closed { 661 | Poll::Ready(Err(MuxError::ConnectionClosed)) 662 | } else if handle.closed { 663 | // EOF 664 | Poll::Ready(Ok(None)) 665 | } else { 666 | // No further packets, just wait 667 | handle.register_rx_ready_waker(cx); 668 | Poll::Pending 669 | } 670 | } 671 | 672 | #[inline] 673 | fn get_idle_interval(&self) -> Interval { 674 | interval(Duration::from_secs(self.idle_timeout)) 675 | } 676 | 677 | fn poll_stream_write_ready(&mut self, cx: &Context<'_>, stream_id: u32) -> Poll> { 678 | self.check_closed()?; 679 | let handle = self.handles.get_mut(&stream_id).unwrap(); 680 | if handle.tx_queue.len() > self.max_tx_queue { 681 | // A stream's tx queue is full 682 | handle.register_tx_done_waker(cx); 683 | // Notify the worker to transfer data now 684 | self.notify_should_tx(); 685 | Poll::Pending 686 | } else { 687 | Poll::Ready(Ok(())) 688 | } 689 | } 690 | 691 | fn enqueue_frame_stream(&mut self, 
stream_id: u32, frame: MuxFrame) { 692 | let interval = self.get_idle_interval(); 693 | let handle = self.handles.get_mut(&stream_id).unwrap(); 694 | handle.tx_queue.push_back(frame); 695 | handle.idle_interval = interval; 696 | } 697 | 698 | #[inline] 699 | fn enqueue_frame_global(&mut self, frame: MuxFrame) { 700 | self.tx_queue.push_back(frame); 701 | } 702 | 703 | #[inline] 704 | fn register_should_tx_waker(&mut self, cx: &Context<'_>) { 705 | self.should_tx_waker = Some(cx.waker().clone()); 706 | } 707 | 708 | #[inline] 709 | fn register_rx_consumed_waker(&mut self, cx: &Context<'_>) { 710 | self.rx_consumed_waker = Some(cx.waker().clone()); 711 | } 712 | 713 | #[inline] 714 | fn notify_should_tx(&mut self) { 715 | if let Some(waker) = self.should_tx_waker.take() { 716 | waker.wake(); 717 | } 718 | } 719 | 720 | #[inline] 721 | fn notify_rx_consumed(&mut self) { 722 | if let Some(waker) = self.rx_consumed_waker.take() { 723 | waker.wake(); 724 | } 725 | } 726 | 727 | fn poll_flush_stream_frames( 728 | &mut self, 729 | cx: &mut Context<'_>, 730 | stream_id: u32, 731 | ) -> Poll> { 732 | self.check_closed()?; 733 | let handle = self.handles.get_mut(&stream_id).unwrap(); 734 | if handle.tx_queue.is_empty() { 735 | Poll::Ready(Ok(())) 736 | } else { 737 | handle.register_tx_done_waker(cx); 738 | self.notify_should_tx(); 739 | Poll::Pending 740 | } 741 | } 742 | 743 | #[inline] 744 | fn register_accept_waker(&mut self, cx: &Context<'_>) { 745 | self.accept_waker = Some(cx.waker().clone()); 746 | } 747 | 748 | #[inline] 749 | fn notify_accept_stream(&mut self) { 750 | if let Some(waker) = self.accept_waker.take() { 751 | waker.wake(); 752 | } 753 | } 754 | 755 | fn close(&mut self) { 756 | self.closed = true; 757 | // Wake up everyone 758 | self.notify_accept_stream(); 759 | self.notify_rx_consumed(); 760 | self.notify_should_tx(); 761 | for (_, h) in self.handles.iter_mut() { 762 | h.closed = true; 763 | h.notify_rx_ready(); 764 | h.notify_tx_done(); 765 | } 766 | } 767 | 768 | fn check_closed(&self) -> MuxResult<()> { 769 | if self.closed { 770 | Err(MuxError::ConnectionClosed) 771 | } else { 772 | Ok(()) 773 | } 774 | } 775 | 776 | fn poll_flush_frames(&mut self, cx: &mut Context<'_>) -> Poll> { 777 | // Global queue 778 | // Flush control frames first 779 | while !self.tx_queue.is_empty() { 780 | ready!(self.poll_write_ready(cx))?; 781 | let frame = self.tx_queue.pop_front().unwrap(); 782 | self.write_frame(frame)?; 783 | } 784 | 785 | // Stream queues 786 | // Flush pending stream packets 787 | for (_, h) in self 788 | .handles 789 | .iter_mut() 790 | .filter(|(_, h)| !h.tx_queue.is_empty()) 791 | { 792 | while !h.tx_queue.is_empty() { 793 | ready!(Pin::new(&mut self.inner).poll_ready(cx))?; 794 | Pin::new(&mut self.inner).start_send(h.tx_queue.pop_front().unwrap())?; 795 | h.notify_tx_done(); 796 | } 797 | } 798 | 799 | Poll::Ready(Ok(())) 800 | } 801 | 802 | fn poll_flush_inner(&mut self, cx: &mut Context<'_>) -> Poll> { 803 | self.inner.poll_flush_unpin(cx) 804 | } 805 | 806 | fn poll_should_tx(&mut self, cx: &mut Context<'_>) -> Poll<()> { 807 | if self.tx_queue.is_empty() && self.handles.iter().all(|(_, h)| h.tx_queue.is_empty()) { 808 | self.register_should_tx_waker(cx); 809 | Poll::Pending 810 | } else { 811 | Poll::Ready(()) 812 | } 813 | } 814 | } 815 | --------------------------------------------------------------------------------