├── .github └── workflows │ └── rust.yml ├── .gitignore ├── .justfile ├── Cargo.toml ├── LICENSE ├── README.md ├── benches └── latency │ ├── endpoint.rs │ ├── main.rs │ └── server.rs ├── docs └── thumbnail.png ├── examples ├── common │ └── mod.rs ├── endpoint_with_timer.rs ├── http_single_connection.rs ├── io_service_dispatch.rs ├── io_service_with_auto_disconnect.rs ├── io_service_with_context.rs ├── io_service_with_direct_selector.rs ├── io_service_without_context.rs ├── polymorphic_endpoints.rs ├── recorded_stream.rs ├── replay_stream.rs └── ws_client.rs ├── rustfmt.toml └── src ├── buffer.rs ├── http └── mod.rs ├── inet.rs ├── lib.rs ├── service ├── endpoint.rs ├── mod.rs ├── node.rs ├── select │ ├── direct.rs │ ├── mio.rs │ └── mod.rs └── time.rs ├── stream ├── buffer.rs ├── file.rs ├── mio.rs ├── mod.rs ├── record.rs ├── replay.rs ├── tcp.rs └── tls.rs ├── util.rs └── ws ├── decoder.rs ├── ds.rs ├── encoder.rs ├── error.rs ├── handshake.rs ├── mod.rs ├── protocol.rs └── util.rs /.github/workflows/rust.yml: -------------------------------------------------------------------------------- 1 | name: Rust 2 | 3 | on: 4 | push: 5 | branches: 6 | - '*' 7 | pull_request: 8 | types: [opened, synchronize, reopened] 9 | branches: 10 | - main 11 | 12 | env: 13 | CARGO_TERM_COLOR: always 14 | RUSTFLAGS: "-Dwarnings" 15 | 16 | jobs: 17 | 18 | msrv: 19 | runs-on: ubuntu-latest 20 | strategy: 21 | matrix: 22 | features: 23 | - "ext,http,ws,mio,openssl" 24 | - "ext,http,ws,mio,rustls-webpki" 25 | - "ext,http,ws,mio,rustls-native" 26 | steps: 27 | - uses: actions/checkout@v4 28 | - uses: taiki-e/install-action@cargo-hack 29 | - run: cargo hack check --rust-version --workspace --all-targets --features=${{ matrix.features }} --ignore-private 30 | 31 | test: 32 | runs-on: ubuntu-latest 33 | strategy: 34 | matrix: 35 | features: 36 | - "ext,http,ws,mio,openssl" 37 | - "ext,http,ws,mio,rustls-webpki" 38 | - "ext,http,ws,mio,rustls-native" 39 | steps: 40 | - uses: 
actions/checkout@v4 41 | - uses: actions/cache@v4 42 | id: cache-dependencies 43 | with: 44 | path: | 45 | ~/.cargo/registry 46 | ~/.cargo/git 47 | target 48 | key: ${{ runner.os }}-cargo-${{ strategy.job-index }}-${{ hashFiles('**/Cargo.toml') }} 49 | - uses: actions-rs/toolchain@v1 50 | with: 51 | profile: minimal 52 | toolchain: stable 53 | override: true 54 | - uses: actions-rs/cargo@v1 55 | with: 56 | command: test 57 | args: --features=${{ matrix.features }} 58 | 59 | fmt: 60 | runs-on: ubuntu-latest 61 | steps: 62 | - uses: actions/checkout@v4 63 | - uses: actions-rs/toolchain@v1 64 | with: 65 | toolchain: stable 66 | override: true 67 | components: rustfmt 68 | - uses: actions-rs/cargo@v1 69 | with: 70 | command: fmt 71 | args: --all -- --check 72 | 73 | clippy: 74 | runs-on: ubuntu-latest 75 | strategy: 76 | matrix: 77 | features: 78 | - "ext,http,ws,mio,openssl" 79 | - "ext,http,ws,mio,rustls-webpki" 80 | - "ext,http,ws,mio,rustls-native" 81 | steps: 82 | - uses: actions/checkout@v4 83 | - uses: actions-rs/toolchain@v1 84 | with: 85 | toolchain: stable 86 | override: true 87 | components: clippy 88 | - uses: actions-rs/cargo@v1 89 | with: 90 | command: clippy 91 | args: --no-deps --all-targets --features=${{ matrix.features }} -- -D warnings 92 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | /.idea 2 | /target 3 | Cargo.lock 4 | /*.rec 5 | scratch 6 | massif.out.* 7 | *.log -------------------------------------------------------------------------------- /.justfile: -------------------------------------------------------------------------------- 1 | # print options 2 | default: 3 | @just --list --unsorted 4 | 5 | # upgrade dependencies (`cargo upgrade` requires cargo-edit) 6 | init: 7 | cargo upgrade --incompatible 8 | cargo update 9 | 10 | # check code 11 | check: 12 | cargo check 13 | cargo fmt --all -- --check 14 | cargo clippy --all-targets --features "ext,http,ws,mio,rustls-webpki" 15 | 16 | # fix code 17 | 
fix: 18 | cargo fmt --all 19 | cargo clippy --allow-dirty --fix --features "ext,http,ws,mio,rustls-webpki" 20 | 21 | # build project 22 | build: 23 | cargo build --features "ext,http,ws,mio,rustls-webpki" 24 | 25 | # execute tests 26 | test: 27 | cargo test --features "ext,http,ws,mio,rustls-webpki" 28 | 29 | # execute benchmarks 30 | bench: 31 | cargo bench --features "ext,http,ws,mio,rustls-webpki" 32 | 33 | -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "boomnet" 3 | version = "0.0.53" 4 | edition = "2024" 5 | license = "MIT" 6 | description = "Framework for building low latency clients on top of TCP." 7 | readme = "./README.md" 8 | documentation = "https://docs.rs/boomnet" 9 | repository = "https://github.com/HaveFunTrading/boomnet" 10 | keywords = ["http", "async", "client", "websocket", "mio"] 11 | categories = ["network-programming", "web-programming::websocket", "web-programming::http-client"] 12 | rust-version = "1.85.1" 13 | 14 | [package.metadata.docs.rs] 15 | all-features = true 16 | 17 | [features] 18 | default = [] 19 | mio = ["dep:mio"] 20 | rustls-native = ["rustls", "rustls-native-certs"] 21 | rustls-webpki = ["rustls", "webpki-roots"] 22 | openssl = ["dep:openssl"] 23 | http = ["dep:http", "httparse", "memchr", "itoa", "smallvec"] 24 | ws = ["rand", "base64", "dep:http", "httparse"] 25 | ext = [] 26 | 27 | [dependencies] 28 | url = "2.5.0" 29 | thiserror = "1.0.50" 30 | log = "0.4.20" 31 | socket2 = { version = "0.5.5", features = ["all"] } 32 | pnet = "0.34.0" 33 | mio = { version = "1", features = ["net", "os-poll"], optional = true } 34 | rustls = { version = "0.22.4", optional = true } 35 | rand = { version = "0.9.1", optional = true } 36 | base64 = { version = "0.21.5", optional = true } 37 | httparse = { version = "1.8.0", optional = true } 38 | http = { version = "1.0.0", optional = true } 39 | openssl = { version = "0.10.70", features = ["vendored"], optional = true } 40 | memchr = { 
version = "2.7.4", optional = true } 41 | itoa = { version = "1.0.15", optional = true } 42 | smallvec = { version = "1.15.0", optional = true} 43 | 44 | [dependencies.webpki-roots] 45 | version = "0.26.0" 46 | optional = true 47 | 48 | [dependencies.rustls-native-certs] 49 | version = "0.7.0" 50 | optional = true 51 | 52 | [dev-dependencies] 53 | anyhow = "1" 54 | env_logger = "0.10.1" 55 | ansi_term = "0.12.1" 56 | tungstenite = "0.26.1" 57 | criterion = "0.5.1" 58 | idle = "0.2.0" 59 | core_affinity = "0.8.1" 60 | 61 | [lints.clippy] 62 | uninit_assumed_init = "allow" 63 | mem_replace_with_uninit = "allow" 64 | 65 | [profile.release] 66 | debug = true 67 | lto = true 68 | codegen-units = 1 69 | 70 | [[bench]] 71 | name = "latency" 72 | path = "benches/latency/main.rs" 73 | harness = false 74 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Tom Brzozowski 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

2 | 3 | [![Build Status]][actions] [![Latest Version]][crates.io] [![Docs Badge]][docs] [![License Badge]][license] 4 | 5 | [Build Status]: https://img.shields.io/endpoint.svg?url=https%3A%2F%2Factions-badge.atrox.dev%2Fhavefuntrading%2Fboomnet%2Fbadge%3Fref%3Dmain&style=flat&label=build&logo=none 6 | [actions]: https://actions-badge.atrox.dev/havefuntrading/boomnet/goto?ref=main 7 | [Latest Version]: https://img.shields.io/crates/v/boomnet.svg 8 | [crates.io]: https://crates.io/crates/boomnet 9 | [Docs Badge]: https://docs.rs/boomnet/badge.svg 10 | [docs]: https://docs.rs/boomnet 11 | [License Badge]: https://img.shields.io/badge/License-MIT-blue.svg 12 | [license]: LICENSE 13 | 14 | ## Overview 15 | BoomNet is a high-performance framework targeting the development of low-latency network applications, 16 | particularly focusing on TCP stream-oriented clients that utilise various protocols. 17 | 18 | ## Installation 19 | Simply declare a dependency on `boomnet` in your `Cargo.toml` and select the desired [features](#features). 20 | ```toml 21 | [dependencies] 22 | boomnet = { version = "0.0.53", features = ["rustls-webpki", "ws", "ext"]} 23 | ``` 24 | 25 | ## Design Principles 26 | 27 | The framework is structured into multiple layers, with each subsequent layer building upon its predecessor, 28 | enhancing functionality and abstraction. 29 | 30 | ### Stream 31 | The first layer defines `stream` as an abstraction over a TCP connection, adhering to the following characteristics. 32 | 33 | * Must implement `Read` and `Write` traits for I/O operations. 34 | * Operates in a non-blocking manner. 35 | * Integrates with TLS using `rustls` or `openssl`. 36 | * Supports recording and replay of network byte streams. 37 | * Allows binding to a specific network interface. 38 | * Facilitates implementation of TCP-oriented client protocols such as WebSocket, HTTP, and FIX. 39 | 40 | Streams are designed to be fully generic, avoiding dynamic dispatch, and can be composed in a flexible way.
41 | 42 | ```rust 43 | let stream: RecordedStream<TlsStream<TcpStream>> = TcpStream::try_from((host, port))? 44 | .into_tls_stream() 45 | .into_default_recorded_stream(); 46 | ``` 47 | 48 | Different protocols can then be applied on top of a stream in order to create a client. 49 | ```rust 50 | let ws: Websocket<RecordedStream<TlsStream<TcpStream>>> = stream.into_websocket("/ws"); 51 | ``` 52 | 53 | ### Selector 54 | `Selector` provides an abstraction over OS-specific mechanisms (like `epoll`) for efficiently monitoring socket readiness events. 55 | Though primarily utilised internally, selectors are crucial for the `IOService` functionality, currently offering both 56 | `mio` and `direct` (no-op) implementations. 57 | 58 | ```rust 59 | let mut io_service = MioSelector::new()?.into_io_service(); 60 | ``` 61 | 62 | ### Service 63 | The last layer manages the lifecycle of endpoints and provides auxiliary services (such as asynchronous DNS resolution and 64 | auto-disconnect) through the `IOService`. 65 | 66 | `Endpoint` serves as a low-level construct for application logic. `IOService` oversees the connection lifecycle within endpoints. 67 | 68 | ## Protocols 69 | The aim is to support a variety of protocols, including WebSocket, HTTP, and FIX. 70 | 71 | ### Websocket 72 | The websocket client protocol complies with the [RFC 6455](https://datatracker.ietf.org/doc/html/rfc6455) specification, 73 | offering the following features. 74 | 75 | * Compatibility with any stream. 76 | * TCP batch-aware frame processing. 77 | * Not blocking on partial frame(s). 78 | * No memory allocations (except to initialise buffers). 79 | * Designed for zero-copy read and write. 80 | * Optional masking of outbound frames. 81 | * Standalone usage or in conjunction with `IOService`. 82 | 83 | ### Http 84 | Provides an HTTP 1.1 client that is compatible with any non-blocking stream and does perform memory allocations.
85 | 86 | ## Example Usage 87 | 88 | The repository contains a comprehensive list of [examples](https://github.com/HaveFunTrading/boomnet/tree/main/examples). 89 | 90 | The following example illustrates how to use multiple websocket connections with `IOService` in order to consume messages from the Binance cryptocurrency 91 | exchange. First, we need to define and implement our `Endpoint`. The framework provides the `TlsWebsocketEndpoint` trait 92 | that we can use. 93 | 94 | ```rust 95 | 96 | struct TradeEndpoint { 97 | id: u32, 98 | connection_info: ConnectionInfo, 99 | ws_endpoint: String, 100 | instrument: &'static str, 101 | } 102 | 103 | impl TradeEndpoint { 104 | pub fn new(id: u32, url: &'static str, instrument: &'static str) -> TradeEndpoint { 105 | let (connection_info, ws_endpoint, _) = boomnet::ws::util::parse_url(url).unwrap(); 106 | Self { id, connection_info, ws_endpoint, instrument, } 107 | } 108 | } 109 | 110 | impl ConnectionInfoProvider for TradeEndpoint { 111 | fn connection_info(&self) -> &ConnectionInfo { 112 | &self.connection_info 113 | } 114 | } 115 | 116 | impl TlsWebsocketEndpoint for TradeEndpoint { 117 | 118 | type Stream = MioStream; 119 | 120 | // called by the IO service whenever a connection has to be established for this endpoint 121 | fn create_websocket(&mut self, addr: SocketAddr) -> io::Result<Option<TlsWebsocket<Self::Stream>>> { 122 | 123 | let mut ws = TcpStream::try_from((&self.connection_info, addr))? 124 | .into_mio_stream() 125 | .into_tls_websocket(&self.ws_endpoint)?; 126 | 127 | // send subscription message 128 | ws.send_text( 129 | true, 130 | Some(format!(r#"{{"method":"SUBSCRIBE","params":["{}@trade"],"id":1}}"#, self.instrument).as_bytes()), 131 | )?; 132 | 133 | Ok(Some(ws)) 134 | } 135 | 136 | #[inline] 137 | fn poll(&mut self, ws: &mut TlsWebsocket<Self::Stream>) -> io::Result<()> { 138 | // iterate over available frames in the current batch 139 | for frame in ws.read_batch()? { 140 | if let WebsocketFrame::Text(fin, data) = frame?
{ 141 | println!("[{}] ({fin}) {}", self.id, String::from_utf8_lossy(data)); 142 | } 143 | } 144 | Ok(()) 145 | } 146 | } 147 | ``` 148 | 149 | After defining the endpoint, it is registered with the `IOService` and polled within an event loop. The service handles 150 | `Endpoint` connection management and reconnection in case of disconnection. 151 | 152 | ```rust 153 | 154 | fn main() -> anyhow::Result<()> { 155 | let mut io_service = MioSelector::new()?.into_io_service(); 156 | 157 | let endpoint_btc = TradeEndpoint::new(0, "wss://stream1.binance.com:443/ws", "btcusdt"); 158 | let endpoint_eth = TradeEndpoint::new(1, "wss://stream2.binance.com:443/ws", "ethusdt"); 159 | let endpoint_xrp = TradeEndpoint::new(2, "wss://stream3.binance.com:443/ws", "xrpusdt"); 160 | 161 | io_service.register(endpoint_btc); 162 | io_service.register(endpoint_eth); 163 | io_service.register(endpoint_xrp); 164 | 165 | loop { 166 | // will never block 167 | io_service.poll()?; 168 | } 169 | } 170 | ``` 171 | 172 | It is often required to expose shared state to the `Endpoint`. This can be achieved with a user-defined `Context`. 173 | 174 | ```rust 175 | struct FeedContext; 176 | 177 | // use the marker trait 178 | impl Context for FeedContext {} 179 | ``` 180 | 181 | When implementing our `TradeEndpoint` we can use the `TlsWebsocketEndpointWithContext` trait instead. 182 | ```rust 183 | impl TlsWebsocketEndpointWithContext<FeedContext> for TradeEndpoint { 184 | type Stream = MioStream; 185 | 186 | fn create_websocket(&mut self, addr: SocketAddr, ctx: &mut FeedContext) -> io::Result<Option<TlsWebsocket<Self::Stream>>> { 187 | // we now have access to context 188 | // ... 189 | } 190 | 191 | #[inline] 192 | fn poll(&mut self, ws: &mut TlsWebsocket<Self::Stream>, ctx: &mut FeedContext) -> io::Result<()> { 193 | // we now have access to context 194 | // ... 195 | Ok(()) 196 | } 197 | } 198 | ``` 199 | 200 | We will also need to create an `IOService` that is `Context`-aware.
201 | 202 | ```rust 203 | let mut context = FeedContext::new(); 204 | let mut io_service = MioSelector::new()?.into_io_service_with_context(&mut context); 205 | ``` 206 | 207 | The `Context` must now be passed to the service `poll` method. 208 | ```rust 209 | loop { 210 | io_service.poll(&mut context)?; 211 | } 212 | ``` 213 | 214 | ## Features 215 | The framework feature set is modular, allowing for tailored functionality based on project needs. 216 | 217 | * [mio](#mio) 218 | * [rustls-native](#rustls-native) 219 | * [rustls-webpki](#rustls-webpki) 220 | * [openssl](#openssl) 221 | * [ext](#ext) 222 | * [ws](#ws) 223 | * [http](#http) 224 | 225 | ### `mio` 226 | Adds a dependency on the `mio` crate and enables `MioSelector` and `MioStream`. 227 | 228 | ### `rustls-native` 229 | Adds a dependency on the `rustls` crate with `rustls-native-certs` and enables `TlsStream` as well as the more flexible `TlsReadyStream`. 230 | 231 | ### `rustls-webpki` 232 | Adds a dependency on the `rustls` crate with `webpki-roots` and enables `TlsStream` as well as the more flexible `TlsReadyStream`. 233 | 234 | ### `openssl` 235 | Adds a dependency on the `openssl` crate and enables `TlsStream` as well as the more flexible `TlsReadyStream`. 236 | 237 | ### `ext` 238 | Adds various extensions that provide blanket trait implementations such as `TlsWebsocketEndpoint`. 239 | 240 | ### `ws` 241 | Adds support for the `Websocket` protocol. 242 | 243 | ### `http` 244 | Adds support for the HTTP/1.1 protocol.
245 | -------------------------------------------------------------------------------- /benches/latency/endpoint.rs: -------------------------------------------------------------------------------- 1 | use boomnet::service::endpoint::{Context, EndpointWithContext}; 2 | use boomnet::stream::buffer::{BufferedStream, IntoBufferedStream}; 3 | use boomnet::stream::tcp::TcpStream; 4 | use boomnet::stream::{ConnectionInfo, ConnectionInfoProvider}; 5 | use boomnet::ws::{IntoWebsocket, Websocket}; 6 | use std::hint::black_box; 7 | use std::net::SocketAddr; 8 | 9 | pub struct TestContext { 10 | pub wants_write: bool, 11 | pub processed: usize, 12 | } 13 | 14 | impl Context for TestContext {} 15 | 16 | impl TestContext { 17 | pub fn new() -> TestContext { 18 | Self { 19 | wants_write: true, 20 | processed: 0, 21 | } 22 | } 23 | } 24 | 25 | pub struct TestEndpoint { 26 | connection_info: ConnectionInfo, 27 | payload: &'static str, 28 | } 29 | 30 | impl ConnectionInfoProvider for TestEndpoint { 31 | fn connection_info(&self) -> &ConnectionInfo { 32 | &self.connection_info 33 | } 34 | } 35 | 36 | impl EndpointWithContext<TestContext> for TestEndpoint { 37 | type Target = Websocket<BufferedStream<TcpStream>>; 38 | 39 | fn create_target(&mut self, addr: SocketAddr, _ctx: &mut TestContext) -> std::io::Result<Option<Self::Target>> { 40 | let ws = self 41 | .connection_info 42 | .clone() 43 | .into_tcp_stream_with_addr(addr)? 44 | .into_default_buffered_stream() 45 | .into_websocket("/"); 46 | Ok(Some(ws)) 47 | } 48 | 49 | fn poll(&mut self, ws: &mut Self::Target, ctx: &mut TestContext) -> std::io::Result<()> { 50 | if ctx.wants_write { 51 | ws.send_text(true, Some(self.payload.as_bytes()))?; 52 | ctx.wants_write = false; 53 | } else { 54 | for frame in ws.read_batch()? 
{ 55 | black_box(frame?); 56 | ctx.processed += 1; 57 | } 58 | } 59 | Ok(()) 60 | } 61 | } 62 | 63 | impl TestEndpoint { 64 | pub fn new(port: u16, payload: &'static str) -> Self { 65 | Self { 66 | connection_info: ConnectionInfo::new("127.0.0.1", port), 67 | payload, 68 | } 69 | } 70 | } 71 | -------------------------------------------------------------------------------- /benches/latency/main.rs: -------------------------------------------------------------------------------- 1 | use ::tungstenite::{Message, connect}; 2 | use criterion::{Criterion, Throughput, criterion_group, criterion_main}; 3 | use tungstenite::Utf8Bytes; 4 | 5 | use crate::endpoint::{TestContext, TestEndpoint}; 6 | use ::boomnet::stream::buffer::IntoBufferedStream; 7 | use ::boomnet::ws::IntoWebsocket; 8 | use boomnet::service::IntoIOServiceWithContext; 9 | use boomnet::service::select::direct::DirectSelector; 10 | use boomnet::stream::ConnectionInfo; 11 | 12 | mod endpoint; 13 | mod server; 14 | 15 | const MSG: &str = unsafe { std::str::from_utf8_unchecked(&[90u8; 256]) }; 16 | 17 | fn boomnet_rtt_benchmark(c: &mut Criterion) { 18 | let mut group = c.benchmark_group("boomnet"); 19 | group.throughput(Throughput::Bytes(MSG.len() as u64)); 20 | 21 | // run server in the background 22 | server::start_on_thread(9002); 23 | 24 | // affinity 25 | core_affinity::set_for_current(core_affinity::CoreId { id: 8 }); 26 | 27 | // setup client 28 | let mut ws = ConnectionInfo::new("127.0.0.1", 9002) 29 | .into_tcp_stream() 30 | .unwrap() 31 | .into_default_buffered_stream() 32 | .into_websocket("/"); 33 | 34 | group.bench_function("boomnet_rtt", |b| { 35 | b.iter(|| { 36 | ws.send_text(true, Some(MSG.as_bytes())).unwrap(); 37 | let mut received = 0; 38 | loop { 39 | for frame in ws.read_batch().unwrap() { 40 | std::hint::black_box(frame.unwrap()); 41 | received += 1; 42 | } 43 | if received == 100 { 44 | break; 45 | } 46 | } 47 | }) 48 | }); 49 | 50 | group.finish(); 51 | } 52 | 53 | fn 
boomnet_rtt_benchmark_io_service(c: &mut Criterion) { 54 | let mut group = c.benchmark_group("boomnet"); 55 | group.throughput(Throughput::Bytes(MSG.len() as u64)); 56 | 57 | // run server in the background 58 | server::start_on_thread(9003); 59 | 60 | // affinity 61 | core_affinity::set_for_current(core_affinity::CoreId { id: 12 }); 62 | 63 | // setup io service 64 | let mut ctx = TestContext::new(); 65 | let mut io_service = DirectSelector::new().unwrap().into_io_service_with_context(&mut ctx); 66 | io_service.register(TestEndpoint::new(9003, MSG)); 67 | 68 | group.bench_function("boomnet_rtt_io_service", |b| { 69 | b.iter(|| { 70 | loop { 71 | io_service.poll(&mut ctx).unwrap(); 72 | if ctx.processed == 100 { 73 | ctx.wants_write = true; 74 | ctx.processed = 0; 75 | break; 76 | } 77 | } 78 | }) 79 | }); 80 | 81 | group.finish(); 82 | } 83 | 84 | fn tungstenite_rtt_benchmark(c: &mut Criterion) { 85 | let mut group = c.benchmark_group("tungstenite"); 86 | group.throughput(Throughput::Bytes(MSG.len() as u64)); 87 | 88 | // run server in the background 89 | server::start_on_thread(9001); 90 | 91 | // affinity 92 | core_affinity::set_for_current(core_affinity::CoreId { id: 10 }); 93 | 94 | // setup client 95 | let (mut ws, _) = connect("ws://127.0.0.1:9001").unwrap(); 96 | 97 | group.bench_function("tungstenite_rtt", |b| { 98 | b.iter(|| { 99 | ws.write(Message::Text(Utf8Bytes::from_static(MSG))).unwrap(); 100 | ws.flush().unwrap(); 101 | 102 | let mut received = 0; 103 | loop { 104 | if let Message::Text(data) = ws.read().unwrap() { 105 | std::hint::black_box(data); 106 | received += 1; 107 | } 108 | if received == 100 { 109 | break; 110 | } 111 | } 112 | }) 113 | }); 114 | 115 | group.finish(); 116 | } 117 | 118 | criterion_group!(benches, boomnet_rtt_benchmark, boomnet_rtt_benchmark_io_service, tungstenite_rtt_benchmark); 119 | criterion_main!(benches); 120 | -------------------------------------------------------------------------------- /benches/latency/server.rs: 
--------------------------------------------------------------------------------
-------------------------------------------------------------------------------- 1 | use std::net::TcpListener; 2 | use std::time::Duration; 3 | 4 | use tungstenite::accept; 5 | 6 | pub fn start_on_thread(port: u16) { 7 | let server = TcpListener::bind(format!("127.0.0.1:{port}")).unwrap(); 8 | std::thread::spawn(move || { 9 | if let Some(stream) = server.incoming().next() { 10 | let mut client = accept(stream.unwrap()).unwrap(); 11 | loop { 12 | let msg = client.read().unwrap(); 13 | for _ in 0..100 { 14 | client.write(msg.clone()).unwrap(); 15 | } 16 | client.flush().unwrap(); 17 | } 18 | } 19 | }); 20 | std::thread::sleep(Duration::from_secs(1)); 21 | } 22 | -------------------------------------------------------------------------------- /docs/thumbnail.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HaveFunTrading/boomnet/69c12132982294a6a27c5f88b7f5399047ac6d35/docs/thumbnail.png -------------------------------------------------------------------------------- /examples/common/mod.rs: -------------------------------------------------------------------------------- 1 | use ansi_term::Color::{Green, Purple, Red, Yellow}; 2 | use boomnet::service::endpoint::Context; 3 | use boomnet::service::endpoint::ws::{TlsWebsocket, TlsWebsocketEndpoint, TlsWebsocketEndpointWithContext}; 4 | use boomnet::stream::mio::{IntoMioStream, MioStream}; 5 | use boomnet::stream::tcp::TcpStream; 6 | use boomnet::stream::tls::{IntoTlsStream, TlsConfigExt}; 7 | use boomnet::stream::{ConnectionInfo, ConnectionInfoProvider}; 8 | use boomnet::ws::{IntoTlsWebsocket, IntoWebsocket, WebsocketFrame}; 9 | use log::info; 10 | use std::io; 11 | use std::net::SocketAddr; 12 | 13 | pub struct FeedContext; 14 | impl Context for FeedContext {} 15 | 16 | impl FeedContext { 17 | #[allow(dead_code)] 18 | pub fn new() -> Self { 19 | Self 20 | } 21 | } 22 | 23 | pub struct TradeEndpoint { 24 | id: u32, 25 | connection_info: 
ConnectionInfo, 26 | instrument: &'static str, 27 | ws_endpoint: String, 28 | subscribe: bool, 29 | } 30 | 31 | impl TradeEndpoint { 32 | #[allow(dead_code)] 33 | pub fn new(id: u32, url: &'static str, net_iface: Option<&'static str>, instrument: &'static str) -> TradeEndpoint { 34 | Self::new_with_subscribe(id, url, net_iface, instrument, true) 35 | } 36 | 37 | pub fn new_with_subscribe( 38 | id: u32, 39 | url: &'static str, 40 | net_iface: Option<&'static str>, 41 | instrument: &'static str, 42 | subscribe: bool, 43 | ) -> TradeEndpoint { 44 | let (mut connection_info, ws_endpoint, _) = boomnet::ws::util::parse_url(url).unwrap(); 45 | if let Some(net_iface) = net_iface { 46 | connection_info = connection_info.with_net_iface_from_name(net_iface); 47 | } 48 | Self { 49 | id, 50 | connection_info, 51 | instrument, 52 | ws_endpoint, 53 | subscribe, 54 | } 55 | } 56 | 57 | pub fn subscribe(&mut self, ws: &mut TlsWebsocket) -> io::Result<()> { 58 | ws.send_text( 59 | true, 60 | Some(format!(r#"{{"method":"SUBSCRIBE","params":["{}@trade"],"id":1}}"#, self.instrument).as_bytes()), 61 | )?; 62 | Ok(()) 63 | } 64 | } 65 | 66 | impl ConnectionInfoProvider for TradeEndpoint { 67 | fn connection_info(&self) -> &ConnectionInfo { 68 | &self.connection_info 69 | } 70 | } 71 | 72 | impl TlsWebsocketEndpoint for TradeEndpoint { 73 | type Stream = MioStream; 74 | 75 | fn create_websocket(&mut self, addr: SocketAddr) -> io::Result>> { 76 | let mut ws = TcpStream::try_from((&self.connection_info, addr))? 77 | .into_mio_stream() 78 | .into_tls_stream_with_config(|cfg| cfg.with_no_cert_verification())? 79 | .into_websocket(&self.ws_endpoint); 80 | 81 | if self.subscribe { 82 | self.subscribe(&mut ws)?; 83 | } 84 | 85 | Ok(Some(ws)) 86 | } 87 | 88 | #[inline] 89 | fn poll(&mut self, ws: &mut TlsWebsocket) -> io::Result<()> { 90 | for frame in ws.read_batch()? { 91 | if let WebsocketFrame::Text(fin, data) = frame? 
{ 92 | match self.id % 4 { 93 | 0 => info!("({fin}) {}", Red.paint(String::from_utf8_lossy(data))), 94 | 1 => info!("({fin}) {}", Green.paint(String::from_utf8_lossy(data))), 95 | 2 => info!("({fin}) {}", Purple.paint(String::from_utf8_lossy(data))), 96 | 3 => info!("({fin}) {}", Yellow.paint(String::from_utf8_lossy(data))), 97 | _ => {} 98 | } 99 | } 100 | } 101 | Ok(()) 102 | } 103 | } 104 | 105 | impl TlsWebsocketEndpointWithContext for TradeEndpoint { 106 | type Stream = MioStream; 107 | 108 | fn create_websocket( 109 | &mut self, 110 | addr: SocketAddr, 111 | _ctx: &mut FeedContext, 112 | ) -> io::Result>> { 113 | let mut ws = TcpStream::try_from((&self.connection_info, addr))? 114 | .into_mio_stream() 115 | .into_tls_websocket(&self.ws_endpoint)?; 116 | 117 | if self.subscribe { 118 | self.subscribe(&mut ws)?; 119 | } 120 | 121 | Ok(Some(ws)) 122 | } 123 | 124 | #[inline] 125 | fn poll(&mut self, ws: &mut TlsWebsocket, _ctx: &mut FeedContext) -> io::Result<()> { 126 | for frame in ws.read_batch()? { 127 | if let WebsocketFrame::Text(fin, data) = frame? 
{ 128 | match self.id % 4 { 129 | 0 => info!("({fin}) {}", Red.paint(String::from_utf8_lossy(data))), 130 | 1 => info!("({fin}) {}", Green.paint(String::from_utf8_lossy(data))), 131 | 2 => info!("({fin}) {}", Purple.paint(String::from_utf8_lossy(data))), 132 | 3 => info!("({fin}) {}", Yellow.paint(String::from_utf8_lossy(data))), 133 | _ => {} 134 | } 135 | } 136 | } 137 | Ok(()) 138 | } 139 | } 140 | -------------------------------------------------------------------------------- /examples/endpoint_with_timer.rs: -------------------------------------------------------------------------------- 1 | use std::io; 2 | use std::net::SocketAddr; 3 | use std::time::{Duration, SystemTime, UNIX_EPOCH}; 4 | 5 | use boomnet::service::IntoIOServiceWithContext; 6 | use boomnet::service::endpoint::Context; 7 | use boomnet::service::endpoint::ws::{TlsWebsocket, TlsWebsocketEndpointWithContext}; 8 | use boomnet::service::select::mio::MioSelector; 9 | use boomnet::stream::mio::{IntoMioStream, MioStream}; 10 | use boomnet::stream::{ConnectionInfo, ConnectionInfoProvider}; 11 | use boomnet::ws::{IntoTlsWebsocket, WebsocketFrame}; 12 | use log::info; 13 | use url::Url; 14 | 15 | /// This example demonstrates how to implement explicit timer inside the endpoint. Since endpoint 16 | /// poll method is called on every cycle by the io service we can implement timer functionality 17 | /// directly inside the endpoint. In this case, the endpoint will keep disconnecting every 10s. 
18 | struct TradeEndpoint { 19 | connection_info: ConnectionInfo, 20 | instrument: &'static str, 21 | next_disconnect_time_ns: u64, 22 | } 23 | 24 | impl TradeEndpoint { 25 | pub fn new(url: &'static str, instrument: &'static str, ctx: &FeedContext) -> TradeEndpoint { 26 | let connection_info = Url::parse(url).try_into().unwrap(); 27 | Self { 28 | connection_info, 29 | instrument, 30 | next_disconnect_time_ns: ctx.current_time_ns() + Duration::from_secs(10).as_nanos() as u64, 31 | } 32 | } 33 | } 34 | 35 | #[derive(Debug)] 36 | struct FeedContext; 37 | 38 | impl Context for FeedContext {} 39 | 40 | impl FeedContext { 41 | pub fn new() -> Self { 42 | Self 43 | } 44 | 45 | pub fn current_time_ns(&self) -> u64 { 46 | SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_nanos() as u64 47 | } 48 | } 49 | 50 | impl ConnectionInfoProvider for TradeEndpoint { 51 | fn connection_info(&self) -> &ConnectionInfo { 52 | &self.connection_info 53 | } 54 | } 55 | 56 | impl TlsWebsocketEndpointWithContext for TradeEndpoint { 57 | type Stream = MioStream; 58 | 59 | fn create_websocket( 60 | &mut self, 61 | addr: SocketAddr, 62 | _ctx: &mut FeedContext, 63 | ) -> io::Result>> { 64 | let mut ws = self 65 | .connection_info 66 | .clone() 67 | .into_tcp_stream_with_addr(addr)? 
68 | .into_mio_stream() 69 | .into_tls_websocket("/ws")?; 70 | 71 | ws.send_text( 72 | true, 73 | Some(format!(r#"{{"method":"SUBSCRIBE","params":["{}@trade"],"id":1}}"#, self.instrument).as_bytes()), 74 | )?; 75 | 76 | Ok(Some(ws)) 77 | } 78 | 79 | #[inline] 80 | fn poll(&mut self, ws: &mut TlsWebsocket, ctx: &mut FeedContext) -> io::Result<()> { 81 | while let Some(Ok(WebsocketFrame::Text(fin, data))) = ws.receive_next() { 82 | info!("({fin}) {}", String::from_utf8_lossy(data)); 83 | } 84 | let now_ns = ctx.current_time_ns(); 85 | if now_ns > self.next_disconnect_time_ns { 86 | self.next_disconnect_time_ns = now_ns + Duration::from_secs(10).as_nanos() as u64; 87 | return Err(io::Error::other("disconnected due to timer")); 88 | } 89 | Ok(()) 90 | } 91 | } 92 | 93 | fn main() -> anyhow::Result<()> { 94 | env_logger::init(); 95 | 96 | let mut ctx = FeedContext::new(); 97 | 98 | let mut io_service = MioSelector::new()?.into_io_service_with_context(&mut ctx); 99 | 100 | let endpoint_btc = TradeEndpoint::new("wss://stream1.binance.com:443/ws", "btcusdt", &ctx); 101 | 102 | io_service.register(endpoint_btc); 103 | 104 | loop { 105 | io_service.poll(&mut ctx)?; 106 | } 107 | } 108 | -------------------------------------------------------------------------------- /examples/http_single_connection.rs: -------------------------------------------------------------------------------- 1 | use boomnet::http::{ConnectionPool, SingleTlsConnectionPool}; 2 | use http::Method; 3 | 4 | fn main() -> anyhow::Result<()> { 5 | let mut client = SingleTlsConnectionPool::new(("fapi.binance.com", 443)).into_http_client(); 6 | 7 | let request = 8 | client.new_request_with_headers(Method::GET, "/fapi/v1/depth?symbol=BTCUSDT", None, move |headers| { 9 | headers["FOO"] = "bar"; 10 | })?; 11 | 12 | // execute in blocking mode (will consume request) 13 | let (status_code, headers, body) = request.block()?; 14 | println!("{}", status_code); 15 | println!("{}", headers); 16 | println!("{}", body); 17 
| 18 | // execute in non-blocking mode: poll the request until the response is complete (no caller-provided buffer is needed) 19 | let mut request = client.new_request(Method::GET, "/fapi/v1/time", None)?; 20 | loop { 21 | if let Some((status_code, headers, body)) = request.poll()? { 22 | println!("{}", status_code); 23 | println!("{}", headers); 24 | println!("{}", body); 25 | break; 26 | } 27 | } 28 | 29 | // once the request is done, polling it again will just return the same data 30 | let (status_code, headers, body) = request.poll()?.unwrap(); 31 | println!("{}", status_code); 32 | println!("{}", headers); 33 | println!("{}", body); 34 | 35 | Ok(()) 36 | } 37 | -------------------------------------------------------------------------------- /examples/io_service_dispatch.rs: -------------------------------------------------------------------------------- 1 | use crate::common::TradeEndpoint; 2 | use boomnet::service::IntoIOService; 3 | use boomnet::service::select::mio::MioSelector; 4 | 5 | #[path = "common/mod.rs"] 6 | mod common; 7 | 8 | fn main() -> anyhow::Result<()> { 9 | env_logger::init(); 10 | 11 | let mut io_service = MioSelector::new()?.into_io_service(); 12 | 13 | let endpoint_xrp = TradeEndpoint::new_with_subscribe(2, "wss://stream3.binance.com:443/ws", None, "xrpusdt", false); 14 | 15 | let handle = io_service.register(endpoint_xrp); 16 | 17 | // we delay the subscription until the endpoint is ready 18 | loop { 19 | let success = io_service.dispatch(handle, |ws, endpoint| { 20 | endpoint.subscribe(ws)?; 21 | Ok(()) 22 | })?; 23 | if success { 24 | break; 25 | } else { 26 | io_service.poll()?; 27 | } 28 | } 29 | 30 | loop { 31 | io_service.poll()?; 32 | } 33 | } 34 | -------------------------------------------------------------------------------- /examples/io_service_with_auto_disconnect.rs: -------------------------------------------------------------------------------- 1 | use crate::common::TradeEndpoint; 2 | use boomnet::service::IntoIOService; 3 | use boomnet::service::select::mio::MioSelector; 4 | use
std::time::Duration; 5 | 6 | #[path = "common/mod.rs"] 7 | mod common; 8 | 9 | fn main() -> anyhow::Result<()> { 10 | env_logger::init(); 11 | 12 | let mut io_service = MioSelector::new()? 13 | .into_io_service() 14 | .with_auto_disconnect(Duration::from_secs(10)); 15 | 16 | let endpoint_btc = TradeEndpoint::new(0, "wss://stream1.binance.com:443/ws", None, "btcusdt"); 17 | 18 | io_service.register(endpoint_btc); 19 | 20 | loop { 21 | io_service.poll()?; 22 | } 23 | } 24 | -------------------------------------------------------------------------------- /examples/io_service_with_context.rs: -------------------------------------------------------------------------------- 1 | use crate::common::{FeedContext, TradeEndpoint}; 2 | use boomnet::service::IntoIOServiceWithContext; 3 | use boomnet::service::select::mio::MioSelector; 4 | 5 | #[path = "common/mod.rs"] 6 | mod common; 7 | 8 | fn main() -> anyhow::Result<()> { 9 | env_logger::init(); 10 | 11 | let mut context = FeedContext::new(); 12 | 13 | let mut io_service = MioSelector::new()?.into_io_service_with_context(&mut context); 14 | 15 | let endpoint_btc = TradeEndpoint::new(0, "wss://stream1.binance.com:443/ws", None, "btcusdt"); 16 | let endpoint_eth = TradeEndpoint::new(1, "wss://stream2.binance.com:443/ws", None, "ethusdt"); 17 | let endpoint_xrp = TradeEndpoint::new(2, "wss://stream3.binance.com:443/ws", None, "xrpusdt"); 18 | 19 | io_service.register(endpoint_btc); 20 | io_service.register(endpoint_eth); 21 | io_service.register(endpoint_xrp); 22 | 23 | loop { 24 | io_service.poll(&mut context)?; 25 | } 26 | } 27 | -------------------------------------------------------------------------------- /examples/io_service_with_direct_selector.rs: -------------------------------------------------------------------------------- 1 | use boomnet::inet::{IntoNetworkInterface, ToSocketAddr}; 2 | use boomnet::service::IntoIOService; 3 | use boomnet::service::endpoint::ws::{TlsWebsocket, TlsWebsocketEndpoint}; 4 | use 
boomnet::service::select::direct::DirectSelector; 5 | use boomnet::stream::{ConnectionInfo, ConnectionInfoProvider, tcp}; 6 | use boomnet::ws::{IntoTlsWebsocket, WebsocketFrame}; 7 | use std::io; 8 | use std::net::SocketAddr; 9 | use url::Url; 10 | 11 | struct TradeEndpoint { 12 | id: u32, 13 | connection_info: ConnectionInfo, 14 | instrument: &'static str, 15 | ws_endpoint: String, 16 | } 17 | 18 | impl TradeEndpoint { 19 | pub fn new(id: u32, url: &'static str, net_iface: Option<&'static str>, instrument: &'static str) -> TradeEndpoint { 20 | let url = Url::parse(url).unwrap(); 21 | let mut connection_info = ConnectionInfo::try_from(url.clone()).unwrap(); 22 | let ws_endpoint = url.path().to_owned(); 23 | let net_iface = net_iface 24 | .and_then(|name| name.into_network_interface()) 25 | .and_then(|iface| iface.to_socket_addr()); 26 | if let Some(net_iface) = net_iface { 27 | connection_info = connection_info.with_net_iface(net_iface); 28 | } 29 | Self { 30 | id, 31 | connection_info, 32 | instrument, 33 | ws_endpoint, 34 | } 35 | } 36 | } 37 | 38 | impl ConnectionInfoProvider for TradeEndpoint { 39 | fn connection_info(&self) -> &ConnectionInfo { 40 | &self.connection_info 41 | } 42 | } 43 | 44 | impl TlsWebsocketEndpoint for TradeEndpoint { 45 | type Stream = tcp::TcpStream; 46 | 47 | fn create_websocket(&mut self, addr: SocketAddr) -> io::Result>> { 48 | let mut ws = self 49 | .connection_info 50 | .clone() 51 | .into_tcp_stream_with_addr(addr)? 52 | .into_tls_websocket(&self.ws_endpoint)?; 53 | ws.send_text( 54 | true, 55 | Some(format!(r#"{{"method":"SUBSCRIBE","params":["{}@trade"],"id":1}}"#, self.instrument).as_bytes()), 56 | )?; 57 | 58 | Ok(Some(ws)) 59 | } 60 | 61 | #[inline] 62 | fn poll(&mut self, ws: &mut TlsWebsocket) -> io::Result<()> { 63 | for frame in ws.read_batch()? { 64 | if let WebsocketFrame::Text(fin, data) = frame? 
{ 65 | println!("[{}] ({fin}) {}", self.id, String::from_utf8_lossy(data)) 66 | } 67 | } 68 | Ok(()) 69 | } 70 | } 71 | 72 | fn main() -> anyhow::Result<()> { 73 | env_logger::init(); 74 | 75 | let mut io_service = DirectSelector::new()?.into_io_service(); 76 | 77 | let endpoint_btc = TradeEndpoint::new(0, "wss://stream1.binance.com:443/ws", None, "btcusdt"); 78 | let endpoint_eth = TradeEndpoint::new(1, "wss://stream2.binance.com:443/ws", None, "ethusdt"); 79 | let endpoint_xrp = TradeEndpoint::new(2, "wss://stream3.binance.com:443/ws", None, "xrpusdt"); 80 | 81 | io_service.register(endpoint_btc); 82 | io_service.register(endpoint_eth); 83 | io_service.register(endpoint_xrp); 84 | 85 | loop { 86 | io_service.poll()?; 87 | } 88 | } 89 | -------------------------------------------------------------------------------- /examples/io_service_without_context.rs: -------------------------------------------------------------------------------- 1 | use crate::common::TradeEndpoint; 2 | use boomnet::service::IntoIOService; 3 | use boomnet::service::select::mio::MioSelector; 4 | 5 | #[path = "common/mod.rs"] 6 | mod common; 7 | 8 | fn main() -> anyhow::Result<()> { 9 | env_logger::init(); 10 | 11 | let mut io_service = MioSelector::new()?.into_io_service(); 12 | 13 | let endpoint_btc = TradeEndpoint::new(0, "wss://stream1.binance.com:443/ws", None, "btcusdt"); 14 | let endpoint_eth = TradeEndpoint::new(1, "wss://stream2.binance.com:443/ws", None, "ethusdt"); 15 | let endpoint_xrp = TradeEndpoint::new(2, "wss://stream3.binance.com:443/ws", None, "xrpusdt"); 16 | 17 | io_service.register(endpoint_btc); 18 | io_service.register(endpoint_eth); 19 | io_service.register(endpoint_xrp); 20 | 21 | loop { 22 | io_service.poll()?; 23 | } 24 | } 25 | -------------------------------------------------------------------------------- /examples/polymorphic_endpoints.rs: -------------------------------------------------------------------------------- 1 | #![allow(unused)] 2 | 3 | use std::io; 
4 | use std::net::{SocketAddr, TcpStream}; 5 | use std::time::Duration; 6 | 7 | use boomnet::inet::{IntoNetworkInterface, ToSocketAddr}; 8 | use boomnet::service::endpoint::Context; 9 | use boomnet::service::endpoint::ws::{TlsWebsocket, TlsWebsocketEndpoint, TlsWebsocketEndpointWithContext}; 10 | use boomnet::service::select::mio::MioSelector; 11 | use boomnet::service::{IntoIOService, IntoIOServiceWithContext}; 12 | use boomnet::stream::mio::{IntoMioStream, MioStream}; 13 | use boomnet::stream::tls::TlsStream; 14 | use boomnet::stream::{BindAndConnect, ConnectionInfo, ConnectionInfoProvider}; 15 | use boomnet::ws::{IntoTlsWebsocket, Websocket, WebsocketFrame}; 16 | use idle::IdleStrategy; 17 | use log::info; 18 | use url::Url; 19 | 20 | enum MarketDataEndpoint { 21 | Trade(TradeEndpoint), 22 | Ticker(TickerEndpoint), 23 | } 24 | 25 | impl ConnectionInfoProvider for MarketDataEndpoint { 26 | fn connection_info(&self) -> &ConnectionInfo { 27 | match self { 28 | MarketDataEndpoint::Ticker(ticker) => ticker.connection_info(), 29 | MarketDataEndpoint::Trade(trade) => trade.connection_info(), 30 | } 31 | } 32 | } 33 | 34 | impl TlsWebsocketEndpoint for MarketDataEndpoint { 35 | type Stream = MioStream; 36 | 37 | fn create_websocket(&mut self, addr: SocketAddr) -> io::Result>>> { 38 | match self { 39 | MarketDataEndpoint::Ticker(ticker) => ticker.create_websocket(addr), 40 | MarketDataEndpoint::Trade(trade) => trade.create_websocket(addr), 41 | } 42 | } 43 | 44 | fn poll(&mut self, ws: &mut Websocket>) -> io::Result<()> { 45 | match self { 46 | MarketDataEndpoint::Ticker(ticker) => TlsWebsocketEndpoint::poll(ticker, ws), 47 | MarketDataEndpoint::Trade(trade) => TlsWebsocketEndpoint::poll(trade, ws), 48 | } 49 | } 50 | } 51 | 52 | struct TradeEndpoint { 53 | id: u32, 54 | connection_info: ConnectionInfo, 55 | instrument: &'static str, 56 | } 57 | 58 | impl TradeEndpoint { 59 | pub fn new(id: u32, url: &'static str, instrument: &'static str) -> TradeEndpoint { 60 | let 
connection_info = Url::parse(url).try_into().unwrap(); 61 | Self { 62 | id, 63 | connection_info, 64 | instrument, 65 | } 66 | } 67 | } 68 | 69 | impl ConnectionInfoProvider for TradeEndpoint { 70 | fn connection_info(&self) -> &ConnectionInfo { 71 | &self.connection_info 72 | } 73 | } 74 | 75 | impl TlsWebsocketEndpoint for TradeEndpoint { 76 | type Stream = MioStream; 77 | 78 | fn create_websocket(&mut self, addr: SocketAddr) -> io::Result>> { 79 | let mut ws = self 80 | .connection_info 81 | .clone() 82 | .into_tcp_stream_with_addr(addr)? 83 | .into_mio_stream() 84 | .into_tls_websocket("/ws")?; 85 | 86 | ws.send_text( 87 | true, 88 | Some(format!(r#"{{"method":"SUBSCRIBE","params":["{}@trade"],"id":1}}"#, self.instrument).as_bytes()), 89 | )?; 90 | 91 | Ok(Some(ws)) 92 | } 93 | 94 | #[inline] 95 | fn poll(&mut self, ws: &mut TlsWebsocket) -> io::Result<()> { 96 | while let Some(Ok(WebsocketFrame::Text(fin, data))) = ws.receive_next() { 97 | info!("({fin}) {}", String::from_utf8_lossy(data)); 98 | } 99 | Ok(()) 100 | } 101 | } 102 | 103 | struct TickerEndpoint { 104 | id: u32, 105 | connection_info: ConnectionInfo, 106 | instrument: &'static str, 107 | } 108 | 109 | impl TickerEndpoint { 110 | pub fn new(id: u32, url: &'static str, instrument: &'static str) -> TickerEndpoint { 111 | let connection_info = Url::parse(url).try_into().unwrap(); 112 | Self { 113 | id, 114 | connection_info, 115 | instrument, 116 | } 117 | } 118 | } 119 | 120 | impl ConnectionInfoProvider for TickerEndpoint { 121 | fn connection_info(&self) -> &ConnectionInfo { 122 | &self.connection_info 123 | } 124 | } 125 | 126 | impl TlsWebsocketEndpoint for TickerEndpoint { 127 | type Stream = MioStream; 128 | 129 | fn create_websocket(&mut self, addr: SocketAddr) -> io::Result>> { 130 | let mut ws = self 131 | .connection_info 132 | .clone() 133 | .into_tcp_stream_with_addr(addr)? 
134 | .into_mio_stream() 135 | .into_tls_websocket("/ws")?; 136 | 137 | ws.send_text( 138 | true, 139 | Some(format!(r#"{{"method":"SUBSCRIBE","params":["{}@ticker"],"id":1}}"#, self.instrument).as_bytes()), 140 | )?; 141 | 142 | Ok(Some(ws)) 143 | } 144 | 145 | #[inline] 146 | fn poll(&mut self, ws: &mut TlsWebsocket) -> io::Result<()> { 147 | while let Some(Ok(WebsocketFrame::Text(fin, data))) = ws.receive_next() { 148 | info!("({fin}) {}", String::from_utf8_lossy(data)); 149 | } 150 | Ok(()) 151 | } 152 | } 153 | 154 | fn main() -> anyhow::Result<()> { 155 | env_logger::init(); 156 | 157 | let mut io_service = MioSelector::new()?.into_io_service(); 158 | 159 | let ticker = MarketDataEndpoint::Ticker(TickerEndpoint::new(0, "wss://stream.binance.com:443/ws", "btcusdt")); 160 | let trade = MarketDataEndpoint::Trade(TradeEndpoint::new(1, "wss://stream.binance.com:443/ws", "ethusdt")); 161 | 162 | io_service.register(ticker); 163 | io_service.register(trade); 164 | 165 | loop { 166 | io_service.poll()?; 167 | } 168 | } 169 | -------------------------------------------------------------------------------- /examples/recorded_stream.rs: -------------------------------------------------------------------------------- 1 | use boomnet::stream::ConnectionInfo; 2 | use boomnet::stream::record::IntoRecordedStream; 3 | use boomnet::stream::tls::IntoTlsStream; 4 | use boomnet::ws::{IntoWebsocket, WebsocketFrame}; 5 | use idle::IdleStrategy; 6 | use std::time::Duration; 7 | 8 | fn main() -> anyhow::Result<()> { 9 | let mut ws = ConnectionInfo::new("stream.binance.com", 9443) 10 | .into_tcp_stream()? 11 | .into_tls_stream()? 12 | .into_default_recorded_stream() 13 | .into_websocket("/ws"); 14 | 15 | ws.send_text(true, Some(r#"{"method":"SUBSCRIBE","params":["btcusdt@trade"],"id":1}"#.to_string().as_bytes()))?; 16 | 17 | let idle = IdleStrategy::Sleep(Duration::from_millis(1)); 18 | 19 | loop { 20 | for frame in ws.read_batch()? 
{ 21 | if let WebsocketFrame::Text(fin, body) = frame? { 22 | println!("({fin}) {}", String::from_utf8_lossy(body)); 23 | } 24 | } 25 | idle.idle(0); 26 | } 27 | } 28 | -------------------------------------------------------------------------------- /examples/replay_stream.rs: -------------------------------------------------------------------------------- 1 | use std::io::ErrorKind::UnexpectedEof; 2 | 3 | use boomnet::stream::replay::ReplayStream; 4 | use boomnet::ws::{Error, IntoWebsocket, WebsocketFrame}; 5 | 6 | fn main() -> anyhow::Result<()> { 7 | let mut ws = ReplayStream::from_file("plain_inbound")?.into_websocket("/ws"); 8 | 9 | fn run Result<(), Error>>(f: F) -> anyhow::Result<()> { 10 | match f() { 11 | Err(Error::IO(io_error)) if io_error.kind() == UnexpectedEof => Ok(()), 12 | Err(err) => Err(err)?, 13 | _ => Ok(()), 14 | } 15 | } 16 | 17 | run(|| { 18 | loop { 19 | for frame in ws.read_batch()? { 20 | if let WebsocketFrame::Text(fin, body) = frame? { 21 | println!("({fin}) {}", String::from_utf8_lossy(body)); 22 | } 23 | } 24 | } 25 | })?; 26 | 27 | Ok(()) 28 | } 29 | -------------------------------------------------------------------------------- /examples/ws_client.rs: -------------------------------------------------------------------------------- 1 | use boomnet::stream::tcp::TcpStream; 2 | use boomnet::stream::tls::IntoTlsStream; 3 | use boomnet::ws::{IntoWebsocket, WebsocketFrame}; 4 | use idle::IdleStrategy; 5 | use std::time::Duration; 6 | 7 | fn main() -> anyhow::Result<()> { 8 | let mut ws = TcpStream::try_from(("stream.binance.com", 9443))? 9 | .into_tls_stream()? 10 | .into_websocket("/ws?timeUnit=microsecond"); 11 | 12 | ws.send_text(true, Some(b"{\"method\":\"SUBSCRIBE\",\"params\":[\"btcusdt@trade\"],\"id\":1}"))?; 13 | 14 | let idle = IdleStrategy::Sleep(Duration::from_millis(1)); 15 | 16 | loop { 17 | for frame in ws.read_batch()? { 18 | if let WebsocketFrame::Text(fin, body) = frame? 
{ 19 | println!("({fin}) {}", String::from_utf8_lossy(body)); 20 | } 21 | } 22 | idle.idle(0); 23 | } 24 | } 25 | -------------------------------------------------------------------------------- /rustfmt.toml: -------------------------------------------------------------------------------- 1 | edition = "2021" 2 | max_width = 120 3 | fn_call_width = 120 4 | reorder_imports = true 5 | -------------------------------------------------------------------------------- /src/buffer.rs: -------------------------------------------------------------------------------- 1 | //! Fixed length buffer for reading data from the network. 2 | //! 3 | //! The buffer should be used when implementing protocols on top of streams. It offers 4 | //! a number of methods to retrieve the bytes with zero-copy semantics. 5 | 6 | use std::io::Read; 7 | use std::{io, ptr}; 8 | 9 | use crate::util::NoBlock; 10 | 11 | const DEFAULT_INITIAL_CAPACITY: usize = 32768; 12 | 13 | #[derive(Debug)] 14 | pub struct ReadBuffer { 15 | inner: Vec, 16 | head: usize, 17 | tail: usize, 18 | } 19 | 20 | /// Reading mode that controls [ReadBuffer::read_from] data limit. 21 | enum ReadMode { 22 | /// Try to read up to one chunk of data. 23 | Chunk, 24 | /// Try to read all available data up to the buffer capacity. 25 | Available, 26 | } 27 | 28 | impl Default for ReadBuffer { 29 | fn default() -> Self { 30 | Self::new() 31 | } 32 | } 33 | 34 | impl ReadBuffer { 35 | pub fn new() -> ReadBuffer { 36 | assert!( 37 | CHUNK_SIZE <= INITIAL_CAPACITY, 38 | "CHUNK_SIZE ({CHUNK_SIZE}) must be less or equal than {INITIAL_CAPACITY}" 39 | ); 40 | Self { 41 | inner: vec![0u8; INITIAL_CAPACITY], 42 | head: 0, 43 | tail: 0, 44 | } 45 | } 46 | 47 | #[inline] 48 | pub const fn available(&self) -> usize { 49 | self.tail - self.head 50 | } 51 | 52 | /// Reads up to `CHUNK_SIZE` into buffer from the provided `stream`. 
If there is no more space 53 | /// available to accommodate the next read of up to chunk size, the buffer will grow by a factor of 2. 54 | #[inline] 55 | pub fn read_from(&mut self, stream: &mut S) -> io::Result<()> { 56 | self.read_from_with_mode(stream, ReadMode::Chunk) 57 | } 58 | 59 | /// Reads all available bytes into buffer from the provided `stream`. If there is no more space 60 | /// available to accommodate the next read of up to `CHUNK_SIZE`, the buffer will grow by a factor of 2. 61 | /// This method is usually preferred to [`ReadBuffer::read_from`] as it takes advantage of all available 62 | /// space in the buffer therefore reducing the number of operating system calls and increasing the throughput. 63 | #[inline] 64 | pub fn read_all_from(&mut self, stream: &mut S) -> io::Result<()> { 65 | self.read_from_with_mode(stream, ReadMode::Available) 66 | } 67 | 68 | #[inline] 69 | fn read_from_with_mode(&mut self, stream: &mut S, read_mode: ReadMode) -> io::Result<()> { 70 | #[cold] 71 | fn grow(buf: &mut Vec) { 72 | buf.resize(buf.len() * 2, 0u8); 73 | } 74 | 75 | #[cold] 76 | fn compact( 77 | buf: &mut ReadBuffer, 78 | ) { 79 | unsafe { ptr::copy(buf.inner.as_ptr().add(buf.head), buf.inner.as_mut_ptr(), buf.available()) } 80 | buf.tail -= buf.head; 81 | buf.head = 0; 82 | } 83 | 84 | // compact 85 | if self.head > 0 && self.available() > 0 { 86 | compact(self); 87 | } 88 | 89 | // clear 90 | if self.head > 0 && self.available() == 0 { 91 | self.head = 0; 92 | self.tail = 0; 93 | } 94 | 95 | // ensure capacity for at least one chunk 96 | if self.tail + CHUNK_SIZE > self.inner.capacity() { 97 | grow(&mut self.inner); 98 | } 99 | 100 | let read = match read_mode { 101 | ReadMode::Chunk => stream.read(&mut self.inner[self.tail..self.tail + CHUNK_SIZE]), 102 | ReadMode::Available => stream.read(&mut self.inner[self.tail..]), 103 | }; 104 | 105 | self.tail += read.no_block()?; 106 | Ok(()) 107 | } 108 | 109 | #[inline] 110 | pub fn consume_next(&mut self, len: 
usize) -> Option<&'static [u8]> { 111 | match self.available() >= len { 112 | true => Some(unsafe { self.consume_next_unchecked(len) }), 113 | false => None, 114 | } 115 | } 116 | 117 | /// # Safety 118 | /// This function should only be called after `available` bytes are known. 119 | /// ```no_run 120 | /// use boomnet::buffer::ReadBuffer; 121 | /// 122 | /// let mut buffer = ReadBuffer::<4096>::new(); 123 | /// if buffer.available() > 10 { 124 | /// unsafe { 125 | /// let view = buffer.consume_next_unchecked(10); 126 | /// } 127 | /// } 128 | #[inline] 129 | pub unsafe fn consume_next_unchecked(&mut self, len: usize) -> &'static [u8] { 130 | unsafe { 131 | let consumed_view = &*ptr::slice_from_raw_parts(self.inner.as_ptr().add(self.head), len); 132 | self.head += len; 133 | consumed_view 134 | } 135 | } 136 | 137 | #[inline] 138 | pub fn consume_next_byte(&mut self) -> Option { 139 | match self.available() >= 1 { 140 | true => Some(unsafe { self.consume_next_byte_unchecked() }), 141 | false => None, 142 | } 143 | } 144 | 145 | /// # Safety 146 | /// This function should only be called after `available` bytes are known. 
147 | /// ```no_run 148 | /// use boomnet::buffer::ReadBuffer; 149 | /// 150 | /// let mut buffer = ReadBuffer::<4096>::new(); 151 | /// if buffer.available() > 0 { 152 | /// unsafe { 153 | /// let byte = buffer.consume_next_byte_unchecked(); 154 | /// } 155 | /// } 156 | #[inline] 157 | pub unsafe fn consume_next_byte_unchecked(&mut self) -> u8 { 158 | unsafe { 159 | let byte = *self.inner.as_ptr().add(self.head); 160 | self.head += 1; 161 | byte 162 | } 163 | } 164 | 165 | #[inline] 166 | pub fn view(&self) -> &[u8] { 167 | &self.inner[self.head..self.tail] 168 | } 169 | 170 | #[inline] 171 | pub fn view_last(&self, len: usize) -> &[u8] { 172 | &self.inner[self.tail - len..self.tail] 173 | } 174 | } 175 | 176 | #[cfg(test)] 177 | mod tests { 178 | use std::io::Cursor; 179 | use std::io::ErrorKind::{UnexpectedEof, WouldBlock}; 180 | 181 | use super::*; 182 | 183 | #[test] 184 | fn should_read_from_stream() { 185 | let mut buf = ReadBuffer::<16>::new(); 186 | assert_eq!(DEFAULT_INITIAL_CAPACITY, buf.inner.len()); 187 | assert_eq!(0, buf.head); 188 | assert_eq!(0, buf.tail); 189 | 190 | let mut stream = Cursor::new(b"hello world!"); 191 | buf.read_from(&mut stream).expect("unable to read from the stream"); 192 | 193 | assert_eq!(12, buf.available()); 194 | assert_eq!(b"hello world!", buf.view()); 195 | 196 | assert_eq!(b"hello ", buf.consume_next(6).unwrap()); 197 | assert_eq!(6, buf.available()); 198 | assert_eq!(b"world!", buf.view()); 199 | 200 | assert_eq!(b"world!", buf.consume_next(6).unwrap()); 201 | assert_eq!(0, buf.available()); 202 | assert_eq!(b"", buf.view()); 203 | 204 | assert_eq!(12, buf.head, "head"); 205 | assert_eq!(12, buf.tail, "tail"); 206 | assert_eq!(0, buf.available()); 207 | 208 | assert_eq!(DEFAULT_INITIAL_CAPACITY, buf.inner.len()); 209 | } 210 | 211 | #[test] 212 | fn should_read_all_from_stream() { 213 | let mut buf = ReadBuffer::<8>::new(); 214 | assert_eq!(DEFAULT_INITIAL_CAPACITY, buf.inner.len()); 215 | assert_eq!(0, buf.head); 216 
| assert_eq!(0, buf.tail); 217 | 218 | let mut stream = Cursor::new(b"hello world!"); 219 | buf.read_all_from(&mut stream).expect("unable to read from the stream"); 220 | 221 | assert_eq!(12, buf.available()); 222 | assert_eq!(b"hello world!", buf.view()); 223 | } 224 | 225 | #[test] 226 | fn should_append_on_multiple_read() { 227 | let mut buf = ReadBuffer::<6>::new(); 228 | assert_eq!(DEFAULT_INITIAL_CAPACITY, buf.inner.len()); 229 | 230 | let mut stream = Cursor::new(b"hello world!"); 231 | 232 | buf.read_from(&mut stream).expect("unable to read from the stream"); 233 | assert_eq!(b"hello ", buf.view()); 234 | 235 | buf.read_from(&mut stream).expect("unable to read from the stream"); 236 | assert_eq!(b"hello world!", buf.view()); 237 | 238 | assert_eq!(DEFAULT_INITIAL_CAPACITY, buf.inner.len()); 239 | } 240 | 241 | #[test] 242 | fn should_clear_on_multiple_read() { 243 | let mut buf = ReadBuffer::<6>::new(); 244 | assert_eq!(DEFAULT_INITIAL_CAPACITY, buf.inner.len()); 245 | 246 | let mut stream = Cursor::new(b"hello world you are amazing!"); 247 | 248 | buf.read_from(&mut stream).expect("unable to read from the stream"); 249 | assert_eq!(b"hello ", buf.view()); 250 | 251 | assert_eq!(b"hello ", buf.consume_next(6).unwrap()); 252 | assert_eq!(0, buf.available()); 253 | assert_eq!(b"", buf.view()); 254 | 255 | buf.read_from(&mut stream).expect("unable to read from the stream"); 256 | assert_eq!(b"world ", buf.view()); 257 | assert_eq!(0, buf.head); 258 | assert_eq!(6, buf.tail); 259 | 260 | assert_eq!(DEFAULT_INITIAL_CAPACITY, buf.inner.len()); 261 | } 262 | 263 | #[test] 264 | fn should_compact_if_any_leftover_before_next_read() { 265 | let mut buf = ReadBuffer::<6>::new(); 266 | assert_eq!(DEFAULT_INITIAL_CAPACITY, buf.inner.len()); 267 | 268 | let mut stream = Cursor::new(b"hello world you are amazing!"); 269 | 270 | buf.read_from(&mut stream).expect("unable to read from the stream"); 271 | assert_eq!(b"hello ", buf.view()); 272 | 273 | assert_eq!(b"he", 
buf.consume_next(2).unwrap()); 274 | assert_eq!(4, buf.available()); 275 | assert_eq!(b"llo ", buf.view()); 276 | 277 | buf.read_from(&mut stream).expect("unable to read from the stream"); 278 | assert_eq!(10, buf.available()); 279 | assert_eq!(b"llo world ", buf.view()); 280 | assert_eq!(0, buf.head); 281 | assert_eq!(10, buf.tail); 282 | 283 | assert_eq!(DEFAULT_INITIAL_CAPACITY, buf.inner.len()); 284 | } 285 | 286 | #[test] 287 | fn should_return_none_if_too_many_bytes_requested_to_view() { 288 | let mut buf = ReadBuffer::<6>::new(); 289 | let mut stream = Cursor::new(b"hello world!"); 290 | buf.read_from(&mut stream).expect("unable to read from the stream"); 291 | 292 | assert_eq!(b"hello ", buf.view()); 293 | assert_eq!(None, buf.consume_next(7)); 294 | } 295 | 296 | #[test] 297 | fn should_return_empty_buffer_if_no_data() { 298 | let buf = ReadBuffer::<6>::new(); 299 | assert_eq!(DEFAULT_INITIAL_CAPACITY, buf.inner.len()); 300 | assert_eq!(b"", buf.view()); 301 | assert_eq!(DEFAULT_INITIAL_CAPACITY, buf.inner.len()); 302 | } 303 | 304 | #[test] 305 | fn should_grow_when_appending() { 306 | let mut buf = ReadBuffer::<1, 8>::new(); 307 | assert_eq!(8, buf.inner.len()); 308 | let mut stream = Cursor::new(b"hello world!"); 309 | while stream.position() < 12 { 310 | buf.read_from(&mut stream).expect("unable to read from the stream"); 311 | } 312 | assert_eq!(b"hello world!", buf.view()); 313 | assert_eq!(16, buf.inner.len()); 314 | } 315 | 316 | #[test] 317 | fn should_handle_reader_with_no_data() { 318 | struct StreamWithNoData; 319 | 320 | impl Read for StreamWithNoData { 321 | fn read(&mut self, _buf: &mut [u8]) -> io::Result { 322 | Err(io::Error::new(WouldBlock, "would block")) 323 | } 324 | } 325 | 326 | let mut stream = StreamWithNoData {}; 327 | let mut buf = ReadBuffer::<8>::new(); 328 | 329 | buf.read_from(&mut stream).expect("unable to read from the stream"); 330 | assert_eq!(b"", buf.view()); 331 | assert_eq!(DEFAULT_INITIAL_CAPACITY, buf.inner.len()); 
332 | } 333 | 334 | #[test] 335 | fn should_propagate_errors() { 336 | struct FaultyStream; 337 | 338 | impl Read for FaultyStream { 339 | fn read(&mut self, _buf: &mut [u8]) -> io::Result { 340 | Err(io::Error::new(UnexpectedEof, "eof")) 341 | } 342 | } 343 | 344 | let mut stream = FaultyStream {}; 345 | let mut buf = ReadBuffer::<8>::new(); 346 | 347 | buf.read_from(&mut stream).expect_err("expected eof error"); 348 | } 349 | 350 | #[test] 351 | fn should_consume_next() { 352 | let mut buf = ReadBuffer::<64>::new(); 353 | let mut stream = Cursor::new(b"hello world!"); 354 | buf.read_from(&mut stream).expect("unable to read from the stream"); 355 | 356 | assert_eq!(b"hello world!", buf.view()); 357 | assert_eq!(b"hello", buf.consume_next(5).unwrap()); 358 | assert_eq!(b" ", buf.consume_next(1).unwrap()); 359 | assert_eq!(b"world!", buf.consume_next(6).unwrap()); 360 | assert_eq!(0, buf.available()) 361 | } 362 | 363 | #[test] 364 | fn should_consume_next_byte() { 365 | let mut buf = ReadBuffer::<64>::new(); 366 | let mut stream = Cursor::new(b"hello world!"); 367 | buf.read_from(&mut stream).expect("unable to read from the stream"); 368 | 369 | assert_eq!(b"hello world!", buf.view()); 370 | assert_eq!(b'h', buf.consume_next_byte().unwrap()); 371 | assert_eq!(b'e', buf.consume_next_byte().unwrap()); 372 | assert_eq!(b'l', buf.consume_next_byte().unwrap()); 373 | assert_eq!(b'l', buf.consume_next_byte().unwrap()); 374 | assert_eq!(b'o', buf.consume_next_byte().unwrap()); 375 | assert_eq!(b' ', buf.consume_next_byte().unwrap()); 376 | assert_eq!(b"world!", buf.consume_next(6).unwrap()); 377 | assert_eq!(0, buf.available()) 378 | } 379 | 380 | #[test] 381 | fn should_view_last() { 382 | let mut buf = ReadBuffer::<64>::new(); 383 | let mut stream = Cursor::new(b"hello world!"); 384 | buf.read_from(&mut stream).expect("unable to read from the stream"); 385 | 386 | assert_eq!(b"hello world!", buf.view()); 387 | assert_eq!(b"world!", buf.view_last(6)); 388 | 
assert_eq!(12, buf.available()) 389 | } 390 | } 391 | -------------------------------------------------------------------------------- /src/http/mod.rs: -------------------------------------------------------------------------------- 1 | //! This module provides a reusable HTTP1.1 client built on top of a generic `ConnectionPool` trait. 2 | //! 3 | //! # Examples 4 | //! 5 | //! ```no_run 6 | //! // Create a TLS connection pool 7 | //! use http::Method; 8 | //! use boomnet::http::{ConnectionPool, HttpClient, SingleTlsConnectionPool}; 9 | //! use boomnet::stream::ConnectionInfo; 10 | //! 11 | //! let mut client = SingleTlsConnectionPool::new(ConnectionInfo::new("example.com", 443)).into_http_client(); 12 | //! 13 | //! // Send a GET request and block until complete 14 | //! let (status, headers, body) = client 15 | //! .new_request(Method::GET, "/", None) 16 | //! .unwrap() 17 | //! .block() 18 | //! .unwrap(); 19 | //! 20 | //! println!("Status: {}", status); 21 | //! println!("Headers: {}", headers); 22 | //! println!("Body: {}", body); 23 | //! 
``` 24 | 25 | use crate::stream::ConnectionInfo; 26 | use crate::stream::buffer::{BufferedStream, IntoBufferedStream}; 27 | use crate::stream::tcp::TcpStream; 28 | use crate::stream::tls::{IntoTlsStream, TlsConfigExt, TlsStream}; 29 | use crate::util::NoBlock; 30 | 31 | use httparse::{EMPTY_HEADER, Response}; 32 | use memchr::arch::all::rabinkarp::Finder; 33 | use std::cell::RefCell; 34 | use std::io; 35 | use std::io::{ErrorKind, Read, Write}; 36 | use std::ops::{Index, IndexMut}; 37 | use std::rc::Rc; 38 | 39 | // re-export 40 | pub use http::Method; 41 | use smallvec::SmallVec; 42 | 43 | type HttpTlsConnection = Connection>>; 44 | 45 | /// Re-usable container to store headers 46 | #[derive(Default)] 47 | pub struct Headers<'a> { 48 | inner: SmallVec<[(&'a str, &'a str); 32]>, 49 | } 50 | 51 | impl<'a> Index<&'a str> for Headers<'a> { 52 | type Output = &'a str; 53 | 54 | // Look up the first matching header (exact, case-sensitive key comparison) 55 | // panics if not found 56 | fn index(&self, key: &'a str) -> &Self::Output { 57 | for pair in &self.inner { 58 | if pair.0 == key { 59 | return &pair.1; 60 | } 61 | } 62 | panic!("no header named `{}`", key); 63 | } 64 | } 65 | 66 | impl<'a> IndexMut<&'a str> for Headers<'a> { 67 | fn index_mut(&mut self, key: &'a str) -> &mut Self::Output { 68 | // we push (key, "") and then hand back a &mut to the `&'a str` slot, so `headers[k] = v` always appends a new entry (an existing `k` is not replaced) 69 | self.inner.push((key, "")); 70 | &mut self.inner.last_mut().unwrap().1 71 | } 72 | } 73 | 74 | impl<'a> Headers<'a> { 75 | /// Append key-value header to the outgoing request.
76 | #[inline] 77 | pub fn insert(&mut self, key: &'a str, value: &'a str) { 78 | self.inner.push((key, value)); 79 | } 80 | 81 | #[inline] 82 | fn is_empty(&self) -> bool { 83 | self.inner.is_empty() 84 | } 85 | 86 | #[inline] 87 | fn iter(&self) -> impl Iterator { 88 | self.inner.iter() 89 | } 90 | 91 | #[inline] 92 | fn clear(&mut self) -> &mut Self { 93 | self.inner.clear(); 94 | self 95 | } 96 | } 97 | 98 | /// A generic HTTP client that uses a pooled connection strategy. 99 | pub struct HttpClient { 100 | connection_pool: Rc>, 101 | headers: Headers<'static>, 102 | } 103 | 104 | impl HttpClient { 105 | /// Create a new HTTP client from the provided pool. 106 | pub fn new(connection_pool: C) -> HttpClient { 107 | Self { 108 | connection_pool: Rc::new(RefCell::new(connection_pool)), 109 | headers: Headers { 110 | inner: SmallVec::with_capacity(32), 111 | }, 112 | } 113 | } 114 | 115 | /// Prepare a request with custom headers and optional body. 116 | /// 117 | /// # Examples 118 | /// 119 | /// ```no_run 120 | /// use http::Method; 121 | /// use boomnet::http::{ConnectionPool, HttpClient, SingleTlsConnectionPool}; 122 | /// use boomnet::stream::ConnectionInfo; 123 | /// 124 | /// let mut client = SingleTlsConnectionPool::new(ConnectionInfo::new("example.com", 443)).into_http_client(); 125 | /// 126 | /// let req = client.new_request_with_headers( 127 | /// Method::POST, 128 | /// "/submit", 129 | /// Some(b"data"), 130 | /// |hdrs| { 131 | /// hdrs["X-Custom"] = "Value"; 132 | /// } 133 | /// ).unwrap(); 134 | /// ``` 135 | pub fn new_request_with_headers( 136 | &mut self, 137 | method: Method, 138 | path: impl AsRef, 139 | body: Option<&[u8]>, 140 | builder: F, 141 | ) -> io::Result> 142 | where 143 | F: FnOnce(&mut Headers), 144 | { 145 | builder(self.headers.clear()); 146 | let conn = self 147 | .connection_pool 148 | .borrow_mut() 149 | .acquire()? 
150 | .ok_or_else(|| io::Error::other("no available connection"))?; 151 | let request = HttpRequest::new(method, path, body, &self.headers, conn, self.connection_pool.clone())?; 152 | Ok(request) 153 | } 154 | 155 | /// Prepare a request with no additional headers and optional body. 156 | /// 157 | /// # Examples 158 | /// 159 | /// ```no_run 160 | /// use http::Method; 161 | /// use boomnet::http::{ConnectionPool , SingleTlsConnectionPool}; 162 | /// use boomnet::stream::ConnectionInfo; 163 | /// 164 | /// let mut client = SingleTlsConnectionPool::new(ConnectionInfo::new("example.com", 443)).into_http_client(); 165 | /// let req = client.new_request( 166 | /// Method::POST, 167 | /// "/submit", 168 | /// Some(b"data"), 169 | /// ).unwrap(); 170 | /// ``` 171 | pub fn new_request( 172 | &mut self, 173 | method: Method, 174 | path: impl AsRef, 175 | body: Option<&[u8]>, 176 | ) -> io::Result> { 177 | self.new_request_with_headers(method, path, body, |_| {}) 178 | } 179 | } 180 | 181 | /// Trait defining a pool of reusable connections. 182 | pub trait ConnectionPool: Sized { 183 | /// Underlying stream type. 184 | type Stream: Read + Write; 185 | 186 | /// Turn this connection pool into http client. 187 | fn into_http_client(self) -> HttpClient { 188 | HttpClient::new(self) 189 | } 190 | 191 | /// Hostname for requests. 192 | fn host(&self) -> &str; 193 | 194 | /// Acquire next free connection, if available. 195 | fn acquire(&mut self) -> io::Result>>; 196 | 197 | /// Release a connection back into the pool. 198 | fn release(&mut self, stream: Option>); 199 | } 200 | 201 | /// A single-connection pool over TLS, reconnecting on demand. 202 | pub struct SingleTlsConnectionPool { 203 | connection_info: ConnectionInfo, 204 | conn: Option, 205 | has_active_connection: bool, 206 | } 207 | 208 | impl SingleTlsConnectionPool { 209 | /// Build a new TLS pool for the given connection info. 
210 | pub fn new(connection_info: impl Into) -> SingleTlsConnectionPool { 211 | Self { 212 | connection_info: connection_info.into(), 213 | conn: None, 214 | has_active_connection: false, 215 | } 216 | } 217 | } 218 | 219 | impl ConnectionPool for SingleTlsConnectionPool { 220 | type Stream = BufferedStream>; 221 | 222 | fn host(&self) -> &str { 223 | self.connection_info.host() 224 | } 225 | 226 | fn acquire(&mut self) -> io::Result>> { 227 | match (self.conn.take(), self.has_active_connection) { 228 | (Some(_), true) => { 229 | // we can at most have one active connection 230 | unreachable!() 231 | } 232 | (Some(stream), false) => { 233 | self.has_active_connection = true; 234 | Ok(Some(stream)) 235 | } 236 | (None, true) => Ok(None), 237 | (None, false) => { 238 | let stream = self 239 | .connection_info 240 | .clone() 241 | .into_tcp_stream()? 242 | .into_tls_stream_with_config(|tls_cfg| tls_cfg.with_no_cert_verification())? 243 | .into_default_buffered_stream(); 244 | self.has_active_connection = true; 245 | Ok(Some(Connection::new(stream))) 246 | } 247 | } 248 | } 249 | 250 | fn release(&mut self, conn: Option>) { 251 | self.has_active_connection = false; 252 | if let Some(conn) = conn { 253 | if !conn.disconnected { 254 | let _ = self.conn.insert(conn); 255 | } 256 | } 257 | } 258 | } 259 | 260 | /// Represents an in-flight HTTP exchange. 
261 | pub struct HttpRequest { 262 | conn: Option>, 263 | pool: Rc>, 264 | state: State, 265 | } 266 | 267 | #[derive(Debug, Eq, PartialEq)] 268 | enum State { 269 | ReadingHeaders, 270 | ReadingBody { 271 | header_len: usize, 272 | content_len: usize, 273 | status_code: u16, 274 | }, 275 | Done { 276 | header_len: usize, 277 | status_code: u16, 278 | }, 279 | } 280 | 281 | impl HttpRequest { 282 | fn new( 283 | method: Method, 284 | path: impl AsRef, 285 | body: Option<&[u8]>, 286 | headers: &Headers, 287 | mut conn: Connection, 288 | pool: Rc>, 289 | ) -> io::Result> { 290 | conn.write_all(method.as_str().as_bytes())?; 291 | conn.write_all(b" ")?; 292 | conn.write_all(path.as_ref().as_bytes())?; 293 | conn.write_all(b" HTTP/1.1\r\nHost: ")?; 294 | conn.write_all(pool.borrow().host().as_bytes())?; 295 | if !headers.is_empty() { 296 | conn.write_all(b"\r\n")?; 297 | for header in headers.iter() { 298 | conn.write_all(header.0.as_bytes())?; 299 | conn.write_all(b": ")?; 300 | conn.write_all(header.1.as_bytes())?; 301 | conn.write_all(b"\r\n")?; 302 | } 303 | if let Some(body) = body { 304 | conn.write_all(b"Content-Length: ")?; 305 | let mut buf = itoa::Buffer::new(); 306 | conn.write_all(buf.format(body.len()).as_bytes())?; 307 | conn.write_all(b"\r\n")?; 308 | } 309 | conn.write_all(b"\r\n")?; 310 | } else if let Some(body) = body { 311 | conn.write_all(b"\r\n")?; 312 | conn.write_all(b"Content-Length: ")?; 313 | let mut buf = itoa::Buffer::new(); 314 | conn.write_all(buf.format(body.as_ref().len()).as_bytes())?; 315 | conn.write_all(b"\r\n\r\n")?; 316 | } else { 317 | conn.write_all(b"\r\n\r\n")?; 318 | } 319 | if let Some(body) = body { 320 | conn.write_all(body)?; 321 | } 322 | conn.flush()?; 323 | Ok(Self { 324 | conn: Some(conn), 325 | pool, 326 | state: State::ReadingHeaders, 327 | }) 328 | } 329 | 330 | /// Block until the full response is available. 
331 | #[inline] 332 | pub fn block(mut self) -> io::Result<(u16, String, String)> { 333 | loop { 334 | if let Some((status_code, headers, body)) = self.poll()? { 335 | return Ok((status_code, headers.to_owned(), body.to_owned())); 336 | } 337 | } 338 | } 339 | 340 | /// Read from the stream and return when complete. Must provide buffer that will hold the response. 341 | /// It's ok to re-use the buffer as long as it's been cleared before using it with a new request. 342 | /// 343 | /// # Example 344 | /// ```no_run 345 | /// use http::Method; 346 | /// use boomnet::http::{ConnectionPool , SingleTlsConnectionPool}; 347 | /// use boomnet::stream::ConnectionInfo; 348 | /// 349 | /// let mut client = SingleTlsConnectionPool::new(ConnectionInfo::new("example.com", 443)).into_http_client(); 350 | /// 351 | /// let mut request = client.new_request_with_headers( 352 | /// Method::POST, 353 | /// "/submit", 354 | /// Some(b"data"), 355 | /// |hdrs| { 356 | /// hdrs["X-Custom"] = "Value"; 357 | /// } 358 | /// ).unwrap(); 359 | /// 360 | /// loop { 361 | /// if let Some((status_code, headers, body)) = request.poll().unwrap() { 362 | /// println!("{}", status_code); 363 | /// println!("{}", headers); 364 | /// println!("{}", body); 365 | /// break; 366 | /// } 367 | /// } 368 | /// 369 | /// ``` 370 | pub fn poll(&mut self) -> io::Result> { 371 | if let Some(conn) = self.conn.as_mut() { 372 | match self.state { 373 | State::ReadingHeaders | State::ReadingBody { .. } => conn.poll()?, 374 | State::Done { .. 
} => {} 375 | } 376 | match self.state { 377 | State::ReadingHeaders => { 378 | if conn.buffer.len() >= 4 { 379 | if let Some(headers_end) = conn.header_finder.find(&conn.buffer, b"\r\n\r\n") { 380 | let header_len = headers_end + 4; 381 | let header_slice = &conn.buffer[..header_len]; 382 | // now parse headers 383 | let mut headers = [EMPTY_HEADER; 32]; 384 | let mut resp = Response::new(&mut headers); 385 | match resp.parse(header_slice) { 386 | Ok(httparse::Status::Complete(_)) => { 387 | let status_code = resp 388 | .code 389 | .ok_or_else(|| io::Error::new(ErrorKind::InvalidData, "missing status code"))?; 390 | let mut content_len = 0; 391 | for header in resp.headers { 392 | if header.name.eq_ignore_ascii_case("Content-Length") { 393 | content_len = std::str::from_utf8(header.value) 394 | .map_err(|e| io::Error::new(ErrorKind::InvalidData, e))? 395 | .parse() 396 | .map_err(|e| io::Error::new(ErrorKind::InvalidData, e))?; 397 | break; 398 | } 399 | } 400 | self.state = State::ReadingBody { 401 | header_len, 402 | content_len, 403 | status_code, 404 | }; 405 | } 406 | Ok(httparse::Status::Partial) => { 407 | return Err(io::Error::new(ErrorKind::InvalidData, "unable to parse headers")); 408 | } 409 | Err(err) => return Err(io::Error::new(ErrorKind::InvalidData, err)), 410 | } 411 | } 412 | } 413 | } 414 | State::ReadingBody { 415 | header_len, 416 | content_len, 417 | status_code, 418 | } => { 419 | let total_len = header_len + content_len; 420 | if conn.buffer.len() >= total_len { 421 | self.state = State::Done { 422 | header_len, 423 | status_code, 424 | }; 425 | } 426 | } 427 | State::Done { 428 | header_len, 429 | status_code, 430 | } => { 431 | let (headers, body) = conn.buffer.split_at(header_len); 432 | let headers = 433 | std::str::from_utf8(headers).map_err(|e| io::Error::new(ErrorKind::InvalidData, e))?; 434 | let body = std::str::from_utf8(body).map_err(|e| io::Error::new(ErrorKind::InvalidData, e))?; 435 | return Ok(Some((status_code, headers, 
body))); 436 | } 437 | } 438 | } 439 | Ok(None) 440 | } 441 | } 442 | 443 | impl Drop for HttpRequest { 444 | fn drop(&mut self) { 445 | if let Some(conn) = self.conn.as_mut() { 446 | conn.buffer.clear(); 447 | } 448 | self.pool.borrow_mut().release(self.conn.take()); 449 | } 450 | } 451 | 452 | /// Connection managed by the `ConnectionPool`. Binds underlying stream together with buffer used 453 | /// for reading data. The reading is performed in chunks with default size of 1024 bytes. 454 | pub struct Connection { 455 | stream: S, 456 | buffer: Vec, 457 | disconnected: bool, 458 | header_finder: Finder, 459 | } 460 | 461 | impl Connection { 462 | #[inline] 463 | fn poll(&mut self) -> io::Result<()> { 464 | if self.disconnected { 465 | return Err(io::Error::new(ErrorKind::NotConnected, "connection closed")); 466 | } 467 | let mut chunk = [0u8; CHUNK_SIZE]; 468 | match self.stream.read(&mut chunk).no_block() { 469 | Ok(read) => { 470 | if read > 0 { 471 | self.buffer.extend_from_slice(&chunk[..read]); 472 | } 473 | Ok(()) 474 | } 475 | Err(err) => { 476 | self.disconnected = true; 477 | Err(err) 478 | } 479 | } 480 | } 481 | } 482 | 483 | impl Write for Connection { 484 | #[inline] 485 | fn write(&mut self, buf: &[u8]) -> io::Result { 486 | self.stream.write(buf) 487 | } 488 | 489 | #[inline] 490 | fn flush(&mut self) -> io::Result<()> { 491 | self.stream.flush() 492 | } 493 | } 494 | 495 | impl Connection { 496 | #[inline] 497 | fn new(stream: S) -> Self { 498 | Self { 499 | stream, 500 | buffer: Vec::with_capacity(CHUNK_SIZE), 501 | disconnected: false, 502 | header_finder: Finder::new(b"\r\n\r\n"), 503 | } 504 | } 505 | } 506 | 507 | #[cfg(test)] 508 | mod tests { 509 | use super::*; 510 | 511 | #[test] 512 | fn should_insert_headers() { 513 | let mut headers = Headers::default(); 514 | 515 | headers["hello"] = "world"; 516 | headers["foo"] = "bar"; 517 | 518 | let mut iter = headers.iter(); 519 | 520 | let (key, value) = iter.next().unwrap(); 521 | 
assert_eq!((&"hello", &"world"), (key, value)); 522 | assert_eq!("world", headers["hello"]); 523 | 524 | let (key, value) = iter.next().unwrap(); 525 | assert_eq!((&"foo", &"bar"), (key, value)); 526 | assert_eq!("bar", headers["foo"]); 527 | 528 | assert!(iter.next().is_none()); 529 | } 530 | } 531 | -------------------------------------------------------------------------------- /src/inet.rs: -------------------------------------------------------------------------------- 1 | //! Utilities related to working with network interfaces. 2 | 3 | use std::net::SocketAddr; 4 | 5 | use pnet::datalink; 6 | use pnet::datalink::NetworkInterface; 7 | 8 | pub trait FromNetworkInterfaceName { 9 | fn from_net_iface_name(iface_name: &str) -> Option; 10 | } 11 | 12 | impl FromNetworkInterfaceName for NetworkInterface { 13 | fn from_net_iface_name(iface_name: &str) -> Option { 14 | datalink::interfaces() 15 | .into_iter() 16 | .find(|iface| iface.name == iface_name) 17 | } 18 | } 19 | 20 | pub trait IntoNetworkInterface { 21 | fn into_network_interface(self) -> Option; 22 | } 23 | 24 | impl IntoNetworkInterface for T 25 | where 26 | T: AsRef, 27 | { 28 | fn into_network_interface(self) -> Option { 29 | NetworkInterface::from_net_iface_name(self.as_ref()) 30 | } 31 | } 32 | 33 | pub trait ToSocketAddr { 34 | fn to_socket_addr(self) -> Option; 35 | } 36 | 37 | impl ToSocketAddr for NetworkInterface { 38 | fn to_socket_addr(self) -> Option { 39 | let ip_addr = self.ips.iter().find(|ip| ip.is_ipv4())?.ip(); 40 | Some(SocketAddr::new(ip_addr, 0)) 41 | } 42 | } 43 | -------------------------------------------------------------------------------- /src/lib.rs: -------------------------------------------------------------------------------- 1 | pub mod buffer; 2 | #[cfg(feature = "http")] 3 | pub mod http; 4 | pub mod inet; 5 | pub mod service; 6 | pub mod stream; 7 | mod util; 8 | #[cfg(feature = "ws")] 9 | pub mod ws; 10 | 
--------------------------------------------------------------------------------
/src/service/endpoint.rs:
--------------------------------------------------------------------------------
//! Entry point for the application logic.

use crate::stream::ConnectionInfoProvider;
use std::io;
use std::net::SocketAddr;

/// Entry point for the application logic. Endpoints are registered and managed by `IOService`.
pub trait Endpoint: ConnectionInfoProvider {
    /// Defines protocol and stream this endpoint operates on.
    type Target;

    /// Used by the `IOService` to create connection upon disconnect by passing resolved `addr`.
    /// If the endpoint does not want to connect at this stage it should return `Ok(None)` and
    /// await the next connection attempt with (possibly) different `addr`.
    ///
    /// An `Err` returned here is propagated out of `IOService::poll`.
    fn create_target(&mut self, addr: SocketAddr) -> io::Result<Option<Self::Target>>;

    /// Called by the `IOService` on each duty cycle. Returning `Err` causes the service to drop
    /// the target and consult [`Endpoint::can_recreate`].
    fn poll(&mut self, target: &mut Self::Target) -> io::Result<()>;

    /// Upon disconnection `IOService` will query the endpoint if the connection can be
    /// recreated. If it returns `false`, the service panics.
    fn can_recreate(&mut self) -> bool {
        true
    }

    /// When `auto_disconnect` is used the service will check with the endpoint before
    /// disconnecting. If `false` is returned the service extends the endpoint's disconnect
    /// deadline by the configured `auto_disconnect` duration instead.
    fn can_auto_disconnect(&mut self) -> bool {
        true
    }
}

/// Marker trait to be applied on user defined `struct` that is registered with `IOService`
/// as context.
pub trait Context {}

/// Entry point for the application logic that exposes user provided [Context].
/// Endpoints are registered and managed by `IOService`.
40 | pub trait EndpointWithContext: ConnectionInfoProvider { 41 | /// Defines protocol and stream this endpoint operates on. 42 | type Target; 43 | 44 | /// Used by the `IOService` to create connection upon disconnect passing resolved `addr` and 45 | /// user provided `Context`. If the endpoint does not want to connect at this stage it should 46 | /// return `Ok(None)` and await the next connection attempt with (possibly) different `addr`. 47 | fn create_target(&mut self, addr: SocketAddr, context: &mut C) -> io::Result>; 48 | 49 | /// Called by the `IOService` on each duty cycle passing user provided `Context`. 50 | fn poll(&mut self, target: &mut Self::Target, context: &mut C) -> io::Result<()>; 51 | 52 | /// Upon disconnection `IOService` will query the endpoint if the connection can be 53 | /// recreated. If not, it will cause program to panic. 54 | fn can_recreate(&mut self, _context: &mut C) -> bool { 55 | true 56 | } 57 | 58 | /// When `auto_disconnect` is used the service will check with the endpoint before 59 | /// disconnecting. If `false` is returned the service will update the endpoint next 60 | /// disconnect time as per the `auto_disconnect` configuration. 
61 | fn can_auto_disconnect(&mut self, _context: &mut C) -> bool { 62 | true 63 | } 64 | } 65 | 66 | #[cfg(all(feature = "ext", feature = "ws", any(feature = "rustls", feature = "openssl")))] 67 | pub mod ws { 68 | use std::io; 69 | use std::io::{Read, Write}; 70 | use std::net::SocketAddr; 71 | 72 | use crate::service::endpoint::{Endpoint, EndpointWithContext}; 73 | use crate::stream::ConnectionInfoProvider; 74 | use crate::stream::tls::TlsStream; 75 | use crate::ws::Websocket; 76 | 77 | pub type TlsWebsocket = Websocket>; 78 | 79 | pub trait TlsWebsocketEndpoint: ConnectionInfoProvider { 80 | type Stream: Read + Write; 81 | 82 | fn create_websocket(&mut self, addr: SocketAddr) -> io::Result>>>; 83 | 84 | fn poll(&mut self, ws: &mut Websocket>) -> io::Result<()>; 85 | 86 | fn can_recreate(&mut self) -> bool { 87 | true 88 | } 89 | 90 | fn can_auto_disconnect(&mut self) -> bool { 91 | true 92 | } 93 | } 94 | 95 | impl Endpoint for T 96 | where 97 | T: TlsWebsocketEndpoint, 98 | { 99 | type Target = Websocket>; 100 | 101 | #[inline] 102 | fn create_target(&mut self, addr: SocketAddr) -> io::Result> { 103 | self.create_websocket(addr) 104 | } 105 | 106 | #[inline] 107 | fn poll(&mut self, target: &mut Self::Target) -> io::Result<()> { 108 | self.poll(target) 109 | } 110 | 111 | #[inline] 112 | fn can_recreate(&mut self) -> bool { 113 | self.can_recreate() 114 | } 115 | 116 | #[inline] 117 | fn can_auto_disconnect(&mut self) -> bool { 118 | self.can_auto_disconnect() 119 | } 120 | } 121 | 122 | pub trait TlsWebsocketEndpointWithContext: ConnectionInfoProvider { 123 | type Stream: Read + Write; 124 | 125 | fn create_websocket( 126 | &mut self, 127 | addr: SocketAddr, 128 | ctx: &mut C, 129 | ) -> io::Result>>>; 130 | 131 | fn poll(&mut self, ws: &mut Websocket>, ctx: &mut C) -> io::Result<()>; 132 | 133 | fn can_recreate(&mut self, _ctx: &mut C) -> bool { 134 | true 135 | } 136 | 137 | fn can_auto_disconnect(&mut self, _ctx: &mut C) -> bool { 138 | true 139 | } 140 | } 
141 | 142 | impl EndpointWithContext for T 143 | where 144 | T: TlsWebsocketEndpointWithContext, 145 | { 146 | type Target = Websocket>; 147 | 148 | #[inline] 149 | fn create_target(&mut self, addr: SocketAddr, context: &mut C) -> io::Result> { 150 | self.create_websocket(addr, context) 151 | } 152 | 153 | #[inline] 154 | fn poll(&mut self, target: &mut Self::Target, context: &mut C) -> io::Result<()> { 155 | self.poll(target, context) 156 | } 157 | 158 | #[inline] 159 | fn can_recreate(&mut self, context: &mut C) -> bool { 160 | self.can_recreate(context) 161 | } 162 | 163 | #[inline] 164 | fn can_auto_disconnect(&mut self, context: &mut C) -> bool { 165 | self.can_auto_disconnect(context) 166 | } 167 | } 168 | } 169 | -------------------------------------------------------------------------------- /src/service/mod.rs: -------------------------------------------------------------------------------- 1 | //! Service to manage multiple endpoint lifecycle. 2 | 3 | use std::collections::{HashMap, VecDeque}; 4 | use std::io; 5 | use std::marker::PhantomData; 6 | use std::net::{SocketAddr, ToSocketAddrs}; 7 | use std::time::Duration; 8 | 9 | use crate::service::endpoint::{Context, Endpoint, EndpointWithContext}; 10 | use crate::service::node::IONode; 11 | use crate::service::select::{Selector, SelectorToken}; 12 | use crate::service::time::{SystemTimeClockSource, TimeSource}; 13 | use crate::stream::ConnectionInfo; 14 | use log::{error, warn}; 15 | 16 | pub mod endpoint; 17 | mod node; 18 | pub mod select; 19 | pub mod time; 20 | 21 | const ENDPOINT_CREATION_THROTTLE_NS: u64 = Duration::from_secs(1).as_nanos() as u64; 22 | 23 | /// Endpoint handle. 24 | #[derive(Debug, Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Default)] 25 | #[repr(transparent)] 26 | pub struct Handle(SelectorToken); 27 | 28 | /// Handles the lifecycle of endpoints (see [`Endpoint`]), which are typically network connections. 
// NOTE(review): the lines below are a collapsed dump of `src/service/mod.rs` (from line 29),
// `src/service/node.rs`, `src/service/select/direct.rs` and the start of
// `src/service/select/mio.rs`. Generic parameter lists appear to have been stripped: for
// example, `pub struct IOService {` uses `S`, `E`, `C` and `TS` without declaring them, and
// `io::Result>` has lost its type argument. Restore them from upstream before treating this text
// as compilable Rust.
29 | /// It uses `SelectService` pattern for managing asynchronous I/O operations. 30 | pub struct IOService { 31 | selector: S, 32 | pending_endpoints: VecDeque<(Handle, E)>, 33 | io_nodes: HashMap>, 34 | next_endpoint_create_time_ns: u64, 35 | context: PhantomData, 36 | auto_disconnect: Option, 37 | time_source: TS, 38 | } 39 | 40 | /// Defines how an instance that implements `SelectService` can be transformed 41 | /// into an [`IOService`], facilitating the management of asynchronous I/O operations. 42 | pub trait IntoIOService { 43 | fn into_io_service(self) -> IOService 44 | where 45 | Self: Selector, 46 | Self: Sized; 47 | } 48 | 49 | /// Defines how an instance that implements [`Selector`] can be transformed 50 | /// into an [`IOService`] with [`Context`], facilitating the management of asynchronous I/O operations. 51 | pub trait IntoIOServiceWithContext { 52 | fn into_io_service_with_context(self, context: &mut C) -> IOService 53 | where 54 | Self: Selector, 55 | Self: Sized; 56 | } 57 | 58 | impl IOService { 59 | /// Creates new instance of [`IOService`]. 60 | pub fn new(selector: S, time_source: TS) -> IOService { 61 | Self { 62 | selector, 63 | pending_endpoints: VecDeque::new(), 64 | io_nodes: HashMap::new(), 65 | next_endpoint_create_time_ns: 0, 66 | context: PhantomData, 67 | auto_disconnect: None, 68 | time_source, 69 | } 70 | } 71 | 72 | /// Specify TTL for each [`Endpoint`] connection. 73 | pub fn with_auto_disconnect(self, auto_disconnect: Duration) -> IOService { 74 | Self { 75 | auto_disconnect: Some(auto_disconnect), 76 | ..self 77 | } 78 | } 79 | 80 | /// Specify custom [`TimeSource`] instead of the default system time source.
// NOTE(review): `deregister` (next line) removes an active node from `io_nodes` without calling
// `self.selector.unregister`. The disconnect paths in `poll` always unregister first. Confirm
// that dropping a still-registered stream is safe for every `Selector` (e.g. `MioSelector`).
81 | pub fn with_time_source(self, time_source: T) -> IOService { 82 | IOService { 83 | time_source, 84 | pending_endpoints: self.pending_endpoints, 85 | context: self.context, 86 | auto_disconnect: self.auto_disconnect, 87 | io_nodes: self.io_nodes, 88 | next_endpoint_create_time_ns: self.next_endpoint_create_time_ns, 89 | selector: self.selector, 90 | } 91 | } 92 | 93 | /// Registers a new [`Endpoint`] with the service and return a handle to it. 94 | pub fn register(&mut self, endpoint: E) -> Handle { 95 | let handle = Handle(self.selector.next_token()); 96 | self.pending_endpoints.push_back((handle, endpoint)); 97 | handle 98 | } 99 | 100 | /// Deregister [`Endpoint`] with the service based on a handle. 101 | pub fn deregister(&mut self, handle: Handle) -> Option { 102 | match self.io_nodes.remove(&handle.0) { 103 | Some(io_node) => Some(io_node.into_endpoint().1), 104 | None => { 105 | let mut index_to_remove = None; 106 | for (index, endpoint) in self.pending_endpoints.iter().enumerate() { 107 | if endpoint.0 == handle { 108 | index_to_remove = Some(index); 109 | break; 110 | } 111 | } 112 | if let Some(index_to_remove) = index_to_remove { 113 | self.pending_endpoints 114 | .remove(index_to_remove) 115 | .map(|(_, endpoint)| endpoint) 116 | } else { 117 | None 118 | } 119 | } 120 | } 121 | } 122 | 123 | /// Return iterator over active endpoints, additionally exposing handle and the stream. 124 | #[inline] 125 | pub fn iter(&self) -> impl Iterator { 126 | self.io_nodes.values().map(|io_node| { 127 | let (stream, (handle, endpoint)) = io_node.as_parts(); 128 | (*handle, stream, endpoint) 129 | }) 130 | } 131 | 132 | /// Return mutable iterator over active endpoints, additionally exposing handle and the stream.
// NOTE(review): in `poll` (`Endpoint` variant, starting below), the pending endpoint is removed
// with `pop_front()` before `Self::resolve_dns(...)?`, `endpoint.create_target(addr)?` and
// `self.selector.register(...)?` run. If any of these returns `Err`, `?` leaves `poll` and the
// popped endpoint is dropped instead of being pushed back. A transient DNS or connect failure
// therefore permanently loses the endpoint, and `next_endpoint_create_time_ns` is not advanced
// on that path. Consider re-queueing the endpoint on error.
133 | #[inline] 134 | pub fn iter_mut(&mut self) -> impl Iterator { 135 | self.io_nodes.values_mut().map(|io_node| { 136 | let (stream, (handle, endpoint)) = io_node.as_parts_mut(); 137 | (*handle, stream, endpoint) 138 | }) 139 | } 140 | 141 | /// Return iterator over pending endpoints. 142 | #[inline] 143 | pub fn pending(&self) -> impl Iterator { 144 | self.pending_endpoints.iter() 145 | } 146 | 147 | #[inline] 148 | fn resolve_dns(connection_info: &ConnectionInfo) -> io::Result { 149 | connection_info 150 | .to_socket_addrs()? 151 | .next() 152 | .ok_or_else(|| io::Error::other("unable to resolve dns address")) 153 | } 154 | } 155 | 156 | impl IOService 157 | where 158 | S: Selector, 159 | E: Endpoint, 160 | TS: TimeSource, 161 | { 162 | /// This method polls all registered endpoints for readiness and performs I/O operations based 163 | /// on the ['Selector'] poll results. It then iterates through all endpoints, either 164 | /// updating existing streams or creating and registering new ones. It uses [`Endpoint::can_recreate`] 165 | /// to determine if the error that occurred during polling is recoverable (typically due to remote peer disconnect). 166 | pub fn poll(&mut self) -> io::Result<()> { 167 | // check for pending endpoints (one at a time & throttled) 168 | if !self.pending_endpoints.is_empty() { 169 | let current_time_ns = self.time_source.current_time_nanos(); 170 | if current_time_ns > self.next_endpoint_create_time_ns { 171 | if let Some((handle, mut endpoint)) = self.pending_endpoints.pop_front() { 172 | let addr = Self::resolve_dns(endpoint.connection_info())?; 173 | match endpoint.create_target(addr)?
// NOTE(review): in the auto-disconnect `retain` below, an expired node whose endpoint declines
// (`can_auto_disconnect()` returns false) has its deadline extended from the *previous*
// deadline (`disconnect_time_ns += auto_disconnect`), not from `current_time_ns`. Both `retain`
// closures also `unwrap()` the result of `self.selector.unregister(io_node)`. They `panic!` with
// "unrecoverable error when polling endpoint" whenever `can_recreate()` is false, which includes
// the auto-disconnect path where no polling error occurred.
{ 174 | Some(stream) => { 175 | let mut io_node = 176 | IONode::new(stream, handle, endpoint, self.auto_disconnect, &self.time_source); 177 | self.selector.register(handle.0, &mut io_node)?; 178 | self.io_nodes.insert(handle.0, io_node); 179 | } 180 | None => self.pending_endpoints.push_back((handle, endpoint)), 181 | } 182 | } 183 | self.next_endpoint_create_time_ns = current_time_ns + ENDPOINT_CREATION_THROTTLE_NS; 184 | } 185 | } 186 | 187 | // check for readiness events 188 | self.selector.poll(&mut self.io_nodes)?; 189 | 190 | // check for auto disconnect if enabled 191 | if let Some(auto_disconnect) = self.auto_disconnect { 192 | let current_time_ns = self.time_source.current_time_nanos(); 193 | self.io_nodes.retain(|_token, io_node| { 194 | let force_disconnect = current_time_ns > io_node.disconnect_time_ns; 195 | if force_disconnect { 196 | // check if we really have to disconnect 197 | return if io_node.as_endpoint_mut().1.can_auto_disconnect() { 198 | warn!("endpoint auto disconnected after {:?}", auto_disconnect); 199 | self.selector.unregister(io_node).unwrap(); 200 | let (handle, mut endpoint) = io_node.endpoint.take().unwrap(); 201 | if endpoint.can_recreate() { 202 | self.pending_endpoints.push_back((handle, endpoint)); 203 | } else { 204 | panic!("unrecoverable error when polling endpoint"); 205 | } 206 | false 207 | } else { 208 | // extend the endpoint TTL 209 | io_node.disconnect_time_ns += auto_disconnect.as_nanos() as u64; 210 | true 211 | }; 212 | } 213 | true 214 | }); 215 | } 216 | 217 | // poll endpoints 218 | self.io_nodes.retain(|_token, io_node| { 219 | let (stream, (_, endpoint)) = io_node.as_parts_mut(); 220 | if let Err(err) = endpoint.poll(stream) { 221 | error!("error when polling endpoint [{}]: {}", endpoint.connection_info().host(), err); 222 | self.selector.unregister(io_node).unwrap(); 223 | let (handle, mut endpoint) = io_node.endpoint.take().unwrap(); 224 | if endpoint.can_recreate() { 225 |
// NOTE(review): the `EndpointWithContext` variant of `poll` that starts below repeats the
// pop-then-`?` sequence (`resolve_dns`, `create_target(addr, context)`, `register`). An `Err`
// from any of them drops the popped endpoint instead of re-queueing it.
self.pending_endpoints.push_back((handle, endpoint)); 226 | } else { 227 | panic!("unrecoverable error when polling endpoint"); 228 | } 229 | return false; 230 | } 231 | true 232 | }); 233 | 234 | Ok(()) 235 | } 236 | 237 | /// Dispatch command to an active endpoint using `handle` and provided `action`. If the 238 | /// endpoint is currently active `true` will be returned and the provided `action` invoked, 239 | /// otherwise this method will return `false` and no `action` will be invoked. 240 | pub fn dispatch(&mut self, handle: Handle, mut action: F) -> io::Result 241 | where 242 | F: FnMut(&mut E::Target, &mut E) -> std::io::Result<()>, 243 | { 244 | match self.io_nodes.get_mut(&handle.0) { 245 | Some(io_node) => { 246 | let (stream, (_, endpoint)) = io_node.as_parts_mut(); 247 | action(stream, endpoint)?; 248 | Ok(true) 249 | } 250 | None => Ok(false), 251 | } 252 | } 253 | } 254 | 255 | impl IOService 256 | where 257 | S: Selector, 258 | C: Context, 259 | E: EndpointWithContext, 260 | TS: TimeSource, 261 | { 262 | /// This method polls all registered endpoints for readiness passing the [`Context`] and performs I/O operations based 263 | /// on the `SelectService` poll results. It then iterates through all endpoints, either 264 | /// updating existing streams or creating and registering new ones. It uses [`Endpoint::can_recreate`] 265 | /// to determine if the error that occurred during polling is recoverable (typically due to remote peer disconnect). 266 | pub fn poll(&mut self, context: &mut C) -> io::Result<()> { 267 | // check for pending endpoints (one at a time & throttled) 268 | if !self.pending_endpoints.is_empty() { 269 | let current_time_ns = self.time_source.current_time_nanos(); 270 | if current_time_ns > self.next_endpoint_create_time_ns { 271 | if let Some((handle, mut endpoint)) = self.pending_endpoints.pop_front() { 272 | let addr = Self::resolve_dns(endpoint.connection_info())?; 273 | match endpoint.create_target(addr, context)?
{ 274 | Some(stream) => { 275 | let mut io_node = 276 | IONode::new(stream, handle, endpoint, self.auto_disconnect, &self.time_source); 277 | self.selector.register(handle.0, &mut io_node)?; 278 | self.io_nodes.insert(handle.0, io_node); 279 | } 280 | None => self.pending_endpoints.push_back((handle, endpoint)), 281 | } 282 | } 283 | self.next_endpoint_create_time_ns = current_time_ns + ENDPOINT_CREATION_THROTTLE_NS; 284 | } 285 | } 286 | 287 | // check for readiness events 288 | self.selector.poll(&mut self.io_nodes)?; 289 | 290 | // check for auto disconnect if enabled 291 | if let Some(auto_disconnect) = self.auto_disconnect { 292 | let current_time_ns = self.time_source.current_time_nanos(); 293 | self.io_nodes.retain(|_token, io_node| { 294 | let force_disconnect = current_time_ns > io_node.disconnect_time_ns; 295 | if force_disconnect { 296 | // check if we really have to disconnect 297 | return if io_node.as_endpoint_mut().1.can_auto_disconnect(context) { 298 | warn!("endpoint auto disconnected after {:?}", auto_disconnect); 299 | self.selector.unregister(io_node).unwrap(); 300 | let (handle, mut endpoint) = io_node.endpoint.take().unwrap(); 301 | if endpoint.can_recreate(context) { 302 | self.pending_endpoints.push_back((handle, endpoint)); 303 | } else { 304 | panic!("unrecoverable error when polling endpoint"); 305 | } 306 | false 307 | } else { 308 | // extend the endpoint TTL 309 | io_node.disconnect_time_ns += auto_disconnect.as_nanos() as u64; 310 | true 311 | }; 312 | } 313 | true 314 | }); 315 | } 316 | 317 | // poll endpoints 318 | self.io_nodes.retain(|_token, io_node| { 319 | let (stream, (_, endpoint)) = io_node.as_parts_mut(); 320 | if let Err(err) = endpoint.poll(stream, context) { 321 | error!("error when polling endpoint [{}]: {}", endpoint.connection_info().host(), err); 322 | self.selector.unregister(io_node).unwrap(); 323 | let (handle, mut endpoint) = io_node.endpoint.take().unwrap(); 324 | if endpoint.can_recreate(context) { 325 |
// NOTE(review): `IONode`'s `as_parts*`, `as_endpoint*` and `into_endpoint` call
// `unwrap_unchecked` on `endpoint`, relying on it never being `None`. `IOService::poll` does
// `take()` it inside `retain` and then either returns `false` or panics (its own `panic!`, or a
// panic from user code in `can_recreate`). If such a panic is caught with `catch_unwind` and the
// service is used again, the node may still be in `io_nodes` with `endpoint == None`, and those
// accessors would be undefined behaviour. A checked `expect` would turn this into a panic. TODO:
// confirm the intended unwind policy.
// NOTE(review): in `IONode::new`, `ts.current_time_nanos() + ttl.as_nanos() as u64` truncates a
// very large `ttl` in the `as u64` cast, and the addition can overflow (panics in debug builds).
self.pending_endpoints.push_back((handle, endpoint)); 326 | } else { 327 | panic!("unrecoverable error when polling endpoint"); 328 | } 329 | return false; 330 | } 331 | true 332 | }); 333 | 334 | Ok(()) 335 | } 336 | 337 | /// Dispatch command to an active endpoint using `handle` and provided `action`. If the 338 | /// endpoint is currently active `true` will be returned and the provided `action` invoked, 339 | /// otherwise this method will return `false` and no `action` will be invoked. This method 340 | /// requires `Context` to be passed and exposes it to the provided `action`. 341 | pub fn dispatch(&mut self, handle: Handle, ctx: &mut C, mut action: F) -> io::Result 342 | where 343 | F: FnMut(&mut E::Target, &mut E, &mut C) -> std::io::Result<()>, 344 | { 345 | match self.io_nodes.get_mut(&handle.0) { 346 | Some(io_node) => { 347 | let (stream, (_, endpoint)) = io_node.as_parts_mut(); 348 | action(stream, endpoint, ctx)?; 349 | Ok(true) 350 | } 351 | None => Ok(false), 352 | } 353 | } 354 | } 355 | -------------------------------------------------------------------------------- /src/service/node.rs: -------------------------------------------------------------------------------- 1 | use crate::service::Handle; 2 | use crate::service::time::TimeSource; 3 | use std::time::Duration; 4 | 5 | pub struct IONode { 6 | pub stream: S, 7 | pub endpoint: Option<(Handle, E)>, 8 | pub disconnect_time_ns: u64, 9 | } 10 | 11 | impl IONode { 12 | pub fn new(stream: S, handle: Handle, endpoint: E, ttl: Option, ts: &TS) -> IONode 13 | where 14 | TS: TimeSource, 15 | { 16 | let disconnect_time_ns = match ttl { 17 | Some(ttl) => ts.current_time_nanos() + ttl.as_nanos() as u64, 18 | None => u64::MAX, 19 | }; 20 | Self { 21 | stream, 22 | endpoint: Some((handle, endpoint)), 23 | disconnect_time_ns, 24 | } 25 | } 26 | 27 | pub fn as_parts(&self) -> (&S, &(Handle, E)) { 28 | // SAFETY: safe to call as endpoint will never be None 29 | unsafe { (&self.stream,
// NOTE(review): `DirectSelector` (starting below) does no readiness tracking: `register`,
// `unregister` and `poll` are no-ops returning `Ok(())`. `IOService::poll` therefore simply
// polls every active endpoint on each cycle.
self.endpoint.as_ref().unwrap_unchecked()) } 30 | } 31 | 32 | pub fn as_parts_mut(&mut self) -> (&mut S, &mut (Handle, E)) { 33 | // SAFETY: safe to call as endpoint will never be None 34 | unsafe { (&mut self.stream, self.endpoint.as_mut().unwrap_unchecked()) } 35 | } 36 | 37 | pub const fn as_stream(&self) -> &S { 38 | &self.stream 39 | } 40 | 41 | pub fn as_stream_mut(&mut self) -> &mut S { 42 | &mut self.stream 43 | } 44 | 45 | pub fn as_endpoint(&self) -> &(Handle, E) { 46 | // SAFETY: safe to call as endpoint will never be None 47 | unsafe { self.endpoint.as_ref().unwrap_unchecked() } 48 | } 49 | 50 | pub fn as_endpoint_mut(&mut self) -> &mut (Handle, E) { 51 | // SAFETY: safe to call as endpoint will never be None 52 | unsafe { self.endpoint.as_mut().unwrap_unchecked() } 53 | } 54 | 55 | pub fn into_endpoint(mut self) -> (Handle, E) { 56 | // SAFETY: safe to call as endpoint will never be None 57 | unsafe { self.endpoint.take().unwrap_unchecked() } 58 | } 59 | } 60 | -------------------------------------------------------------------------------- /src/service/select/direct.rs: -------------------------------------------------------------------------------- 1 | use std::collections::HashMap; 2 | use std::io; 3 | use std::marker::PhantomData; 4 | 5 | use crate::service::endpoint::{Context, Endpoint, EndpointWithContext}; 6 | use crate::service::node::IONode; 7 | use crate::service::select::{Selectable, Selector, SelectorToken}; 8 | use crate::service::time::SystemTimeClockSource; 9 | use crate::service::{IOService, IntoIOService, IntoIOServiceWithContext}; 10 | 11 | pub struct DirectSelector { 12 | next_token: u32, 13 | phantom: PhantomData, 14 | } 15 | 16 | impl DirectSelector { 17 | pub fn new() -> io::Result> { 18 | Ok(Self { 19 | next_token: 0, 20 | phantom: PhantomData, 21 | }) 22 | } 23 | } 24 | 25 | impl Selector for DirectSelector { 26 | type Target = S; 27 | 28 | fn register( 29 | &mut self, 30 | _selector_token: SelectorToken, 31 | _io_node: &mut
// NOTE(review): `DirectSelector::next_token` (below) increments a `u32` with `+= 1`. After
// `u32::MAX` registrations this overflows, panicking in debug builds and wrapping in release
// builds, where it would reuse tokens.
IONode, 32 | ) -> io::Result<()> { 33 | Ok(()) 34 | } 35 | 36 | fn unregister(&mut self, _io_node: &mut IONode) -> io::Result<()> { 37 | Ok(()) 38 | } 39 | 40 | fn poll(&mut self, _io_nodes: &mut HashMap>) -> io::Result<()> { 41 | Ok(()) 42 | } 43 | 44 | fn next_token(&mut self) -> SelectorToken { 45 | let token = self.next_token; 46 | self.next_token += 1; 47 | token 48 | } 49 | } 50 | 51 | impl IntoIOService for DirectSelector { 52 | fn into_io_service(self) -> IOService 53 | where 54 | Self: Selector, 55 | Self: Sized, 56 | { 57 | IOService::new(self, SystemTimeClockSource) 58 | } 59 | } 60 | 61 | impl> IntoIOServiceWithContext for DirectSelector { 62 | fn into_io_service_with_context(self, _ctx: &mut C) -> IOService 63 | where 64 | Self: Selector, 65 | Self: Sized, 66 | { 67 | IOService::new(self, SystemTimeClockSource) 68 | } 69 | } 70 | -------------------------------------------------------------------------------- /src/service/select/mio.rs: -------------------------------------------------------------------------------- 1 | use std::collections::HashMap; 2 | use std::io; 3 | use std::marker::PhantomData; 4 | use std::time::Duration; 5 | 6 | use mio::event::Source; 7 | use mio::{Events, Interest, Poll, Token}; 8 | 9 | use crate::service::endpoint::{Context, Endpoint, EndpointWithContext}; 10 | use crate::service::node::IONode; 11 | use crate::service::select::{Selectable, Selector, SelectorToken}; 12 | use crate::service::time::SystemTimeClockSource; 13 | use crate::service::{IOService, IntoIOService, IntoIOServiceWithContext}; 14 | 15 | const NO_WAIT: Option = Some(Duration::from_millis(0)); 16 | 17 | pub struct MioSelector { 18 | poll: Poll, 19 | events: Events, 20 | next_token: u32, 21 | phantom: PhantomData, 22 | } 23 | 24 | impl MioSelector { 25 | pub fn new() -> io::Result> { 26 | Ok(Self { 27 | poll: Poll::new()?, 28 | events: Events::with_capacity(1024), 29 | next_token: 0, 30 | phantom: PhantomData, 31 | }) 32 | } 33 | } 34 | 35 | impl Selector for
MioSelector { 36 | type Target = S; 37 | 38 | fn register(&mut self, selector_token: SelectorToken, io_node: &mut IONode) -> io::Result<()> { 39 | let token = Token(selector_token as usize); 40 | self.poll 41 | .registry() 42 | .register(io_node.as_stream_mut(), token, Interest::WRITABLE)?; 43 | Ok(()) 44 | } 45 | 46 | fn unregister(&mut self, io_node: &mut IONode) -> io::Result<()> { 47 | self.poll.registry().deregister(io_node.as_stream_mut()) 48 | } 49 | 50 | fn poll(&mut self, io_nodes: &mut HashMap>) -> io::Result<()> { 51 | self.poll.poll(&mut self.events, NO_WAIT)?; 52 | for ev in self.events.iter() { 53 | let token = ev.token(); 54 | let stream = io_nodes 55 | .get_mut(&(token.0 as SelectorToken)) 56 | .ok_or_else(|| io::Error::other("io node not found"))? 57 | .as_stream_mut(); 58 | if ev.is_writable() && stream.connected()? { 59 | stream.make_writable()?; 60 | self.poll.registry().reregister(stream, token, Interest::READABLE)?; 61 | } 62 | if ev.is_readable() { 63 | stream.make_readable()?; 64 | } 65 | } 66 | Ok(()) 67 | } 68 | 69 | #[inline] 70 | fn next_token(&mut self) -> SelectorToken { 71 | let token = self.next_token; 72 | self.next_token += 1; 73 | token 74 | } 75 | } 76 | 77 | impl IntoIOService for MioSelector { 78 | fn into_io_service(self) -> IOService 79 | where 80 | Self: Selector, 81 | Self: Sized, 82 | { 83 | IOService::new(self, SystemTimeClockSource) 84 | } 85 | } 86 | 87 | impl> IntoIOServiceWithContext for MioSelector { 88 | fn into_io_service_with_context(self, _ctx: &mut C) -> IOService 89 | where 90 | Self: Selector, 91 | Self: Sized, 92 | { 93 | IOService::new(self, SystemTimeClockSource) 94 | } 95 | } 96 | -------------------------------------------------------------------------------- /src/service/select/mod.rs: -------------------------------------------------------------------------------- 1 | //! OS specific socket event notification mechanisms like `epoll`. 
2 | 3 | use crate::service::node::IONode; 4 | use std::collections::HashMap; 5 | use std::io; 6 | 7 | pub mod direct; 8 | #[cfg(feature = "mio")] 9 | pub mod mio; 10 | 11 | /// Used to uniquely identify a socket (connection) by the `Selector`. 12 | pub type SelectorToken = u32; 13 | 14 | pub trait Selectable { 15 | fn connected(&mut self) -> io::Result; 16 | 17 | fn make_writable(&mut self) -> io::Result<()>; 18 | 19 | fn make_readable(&mut self) -> io::Result<()>; 20 | } 21 | 22 | pub trait Selector { 23 | type Target: Selectable; 24 | 25 | fn register(&mut self, selector_token: SelectorToken, io_node: &mut IONode) -> io::Result<()>; 26 | 27 | fn unregister(&mut self, io_node: &mut IONode) -> io::Result<()>; 28 | 29 | fn poll(&mut self, io_nodes: &mut HashMap>) -> io::Result<()>; 30 | 31 | fn next_token(&mut self) -> SelectorToken; 32 | } 33 | -------------------------------------------------------------------------------- /src/service/time.rs: -------------------------------------------------------------------------------- 1 | //! Contains time related utilities. 2 | 3 | use std::time::SystemTime; 4 | 5 | /// Trait that provides current time since UNIX epoch. 6 | pub trait TimeSource { 7 | /// Provides current time since UNIX epoch as nanos. 8 | fn current_time_nanos(&self) -> u64; 9 | } 10 | 11 | /// Uses `SystemTime` as [`TimeSource`]. 12 | pub struct SystemTimeClockSource; 13 | 14 | impl TimeSource for SystemTimeClockSource { 15 | #[inline] 16 | fn current_time_nanos(&self) -> u64 { 17 | SystemTime::now() 18 | .duration_since(SystemTime::UNIX_EPOCH) 19 | .unwrap() 20 | .as_nanos() as u64 21 | } 22 | } 23 | -------------------------------------------------------------------------------- /src/stream/buffer.rs: -------------------------------------------------------------------------------- 1 | //! Stream that is buffering data written to it. 
2 | 3 | use crate::service::select::Selectable; 4 | use crate::stream::{ConnectionInfo, ConnectionInfoProvider}; 5 | #[cfg(feature = "mio")] 6 | use mio::{Interest, Registry, Token, event::Source}; 7 | use std::io; 8 | use std::io::{ErrorKind, Read, Write}; 9 | use std::mem::MaybeUninit; 10 | 11 | /// Default buffer size in bytes. 12 | pub const DEFAULT_BUFFER_SIZE: usize = 1024; 13 | 14 | /// Buffers data written to it until explicitly flushed. Useful if you 15 | /// want to reduce the number of operating system calls when writing. If there 16 | /// is no more space in the buffer to accommodate the current write it 17 | /// will return [ErrorKind::WriteZero]. 18 | /// 19 | /// ## Examples 20 | /// 21 | /// Wrap with default BufferedStream`. 22 | /// 23 | /// ``` no_run 24 | /// use boomnet::stream::buffer::IntoBufferedStream; 25 | /// use boomnet::stream::ConnectionInfo; 26 | /// use boomnet::stream::tls::IntoTlsStream; 27 | /// use boomnet::ws::IntoWebsocket; 28 | /// 29 | /// let mut ws = ConnectionInfo::new("stream.binance.com", 9443) 30 | /// .into_tcp_stream().unwrap() 31 | /// .into_tls_stream().unwrap() 32 | /// .into_default_buffered_stream() 33 | /// .into_websocket("/ws"); 34 | /// ``` 35 | /// 36 | /// Specify buffer size when wrapping. 
37 | /// 38 | /// ``` no_run 39 | /// use boomnet::stream::buffer::IntoBufferedStream; 40 | /// use boomnet::stream::ConnectionInfo; 41 | /// use boomnet::stream::tls::IntoTlsStream; 42 | /// use boomnet::ws::IntoWebsocket; 43 | /// 44 | /// let mut ws = ConnectionInfo::new("stream.binance.com", 9443) 45 | /// .into_tcp_stream().unwrap() 46 | /// .into_tls_stream().unwrap() 47 | /// .into_buffered_stream::<512>() 48 | /// .into_websocket("/ws"); 49 | /// ``` 50 | pub struct BufferedStream { 51 | inner: S, 52 | buffer: [u8; N], 53 | cursor: usize, 54 | } 55 | 56 | impl Read for BufferedStream { 57 | fn read(&mut self, buf: &mut [u8]) -> io::Result { 58 | self.inner.read(buf) 59 | } 60 | } 61 | 62 | impl Write for BufferedStream { 63 | fn write(&mut self, buf: &[u8]) -> io::Result { 64 | #[cold] 65 | fn handle_overflow() -> io::Result<()> { 66 | Err(io::Error::new(ErrorKind::WriteZero, "unable to write the whole buffer")) 67 | } 68 | 69 | let len = buf.len(); 70 | let remaining = N - self.cursor; 71 | if len > remaining { 72 | handle_overflow()? 73 | } 74 | self.buffer[self.cursor..self.cursor + len].copy_from_slice(buf); 75 | self.cursor += len; 76 | Ok(len) 77 | } 78 | 79 | fn flush(&mut self) -> io::Result<()> { 80 | self.inner.write_all(&self.buffer[..self.cursor])?; 81 | self.cursor = 0; 82 | self.inner.flush() 83 | } 84 | } 85 | 86 | impl ConnectionInfoProvider for BufferedStream { 87 | fn connection_info(&self) -> &ConnectionInfo { 88 | self.inner.connection_info() 89 | } 90 | } 91 | 92 | /// Trait to convert any stream into `BufferedStream`. 93 | pub trait IntoBufferedStream { 94 | /// Convert into `BufferedStream` and specify buffer length. 95 | fn into_buffered_stream(self) -> BufferedStream; 96 | 97 | /// Convert into `BufferedStream` with default buffer length. 
98 | fn into_default_buffered_stream(self) -> BufferedStream 99 | where 100 | Self: Sized, 101 | { 102 | Self::into_buffered_stream(self) 103 | } 104 | } 105 | 106 | impl IntoBufferedStream for T 107 | where 108 | T: Read + Write + ConnectionInfoProvider, 109 | { 110 | fn into_buffered_stream(self) -> BufferedStream { 111 | unsafe { 112 | BufferedStream { 113 | inner: self, 114 | buffer: MaybeUninit::uninit().assume_init(), 115 | cursor: 0, 116 | } 117 | } 118 | } 119 | } 120 | 121 | impl Selectable for BufferedStream { 122 | fn connected(&mut self) -> io::Result { 123 | self.inner.connected() 124 | } 125 | 126 | fn make_writable(&mut self) -> io::Result<()> { 127 | self.inner.make_writable() 128 | } 129 | 130 | fn make_readable(&mut self) -> io::Result<()> { 131 | self.inner.make_readable() 132 | } 133 | } 134 | 135 | #[cfg(feature = "mio")] 136 | impl Source for BufferedStream { 137 | fn register(&mut self, registry: &Registry, token: Token, interests: Interest) -> io::Result<()> { 138 | registry.register(&mut self.inner, token, interests) 139 | } 140 | 141 | fn reregister(&mut self, registry: &Registry, token: Token, interests: Interest) -> io::Result<()> { 142 | registry.reregister(&mut self.inner, token, interests) 143 | } 144 | 145 | fn deregister(&mut self, registry: &Registry) -> io::Result<()> { 146 | registry.deregister(&mut self.inner) 147 | } 148 | } 149 | -------------------------------------------------------------------------------- /src/stream/file.rs: -------------------------------------------------------------------------------- 1 | use std::cmp::min; 2 | use std::fs::File; 3 | use std::io; 4 | use std::io::ErrorKind::UnexpectedEof; 5 | use std::io::{BufReader, Read, Write}; 6 | 7 | pub struct FileStream(BufReader); 8 | 9 | impl Read for FileStream { 10 | fn read(&mut self, buf: &mut [u8]) -> io::Result { 11 | let up_to = min(buf.len(), CHUNK_SIZE); 12 | 13 | match self.0.read(&mut buf[..up_to]) { 14 | Ok(0) => Err(io::Error::new(UnexpectedEof, 
"eof")), 15 | Ok(n) => Ok(n), 16 | Err(err) => Err(err), 17 | } 18 | } 19 | } 20 | 21 | impl Write for FileStream { 22 | fn write(&mut self, buf: &[u8]) -> io::Result { 23 | Ok(buf.len()) 24 | } 25 | 26 | fn flush(&mut self) -> io::Result<()> { 27 | Ok(()) 28 | } 29 | } 30 | 31 | impl TryFrom<&str> for FileStream { 32 | type Error = io::Error; 33 | 34 | fn try_from(path: &str) -> Result { 35 | let file = File::open(path)?; 36 | let stream = FileStream(BufReader::new(file)); 37 | Ok(stream) 38 | } 39 | } 40 | -------------------------------------------------------------------------------- /src/stream/mio.rs: -------------------------------------------------------------------------------- 1 | //! Stream that can be used together with `MioSelector`. 2 | 3 | use std::io::ErrorKind::{Interrupted, NotConnected, WouldBlock}; 4 | use std::io::{Read, Write}; 5 | use std::{io, net}; 6 | 7 | use crate::service::select::Selectable; 8 | use crate::stream::{ConnectionInfo, ConnectionInfoProvider}; 9 | use mio::event::Source; 10 | use mio::net::TcpStream; 11 | use mio::{Interest, Registry, Token}; 12 | 13 | #[derive(Debug)] 14 | pub struct MioStream { 15 | inner: TcpStream, 16 | connection_info: ConnectionInfo, 17 | connected: bool, 18 | can_read: bool, 19 | can_write: bool, 20 | buffer: Vec, 21 | } 22 | 23 | impl MioStream { 24 | fn new(inner: TcpStream, connection_info: ConnectionInfo) -> MioStream { 25 | Self { 26 | inner, 27 | connection_info, 28 | connected: false, 29 | can_read: false, 30 | can_write: false, 31 | buffer: Vec::with_capacity(4096), 32 | } 33 | } 34 | } 35 | 36 | impl Selectable for MioStream { 37 | fn connected(&mut self) -> io::Result { 38 | if self.connected { 39 | return Ok(true); 40 | } 41 | match self.inner.peer_addr() { 42 | Ok(_) => { 43 | self.connected = true; 44 | // bypassing `can_write` as we can get to this state 45 | // only if the socket is writable 46 | self.inner.write_all(&self.buffer)?; 47 | self.buffer.clear(); 48 | Ok(true) 49 | } 50 | 
Err(err) if err.kind() == NotConnected => Ok(false), 51 | Err(err) if err.kind() == Interrupted => Ok(false), 52 | Err(err) => Err(err), 53 | } 54 | } 55 | 56 | fn make_writable(&mut self) -> io::Result<()> { 57 | self.can_write = true; 58 | Ok(()) 59 | } 60 | 61 | fn make_readable(&mut self) -> io::Result<()> { 62 | self.can_read = true; 63 | Ok(()) 64 | } 65 | } 66 | 67 | impl Source for MioStream { 68 | fn register(&mut self, registry: &Registry, token: Token, interests: Interest) -> io::Result<()> { 69 | registry.register(&mut self.inner, token, interests) 70 | } 71 | 72 | fn reregister(&mut self, registry: &Registry, token: Token, interests: Interest) -> io::Result<()> { 73 | registry.reregister(&mut self.inner, token, interests) 74 | } 75 | 76 | fn deregister(&mut self, registry: &Registry) -> io::Result<()> { 77 | registry.deregister(&mut self.inner) 78 | } 79 | } 80 | 81 | impl Read for MioStream { 82 | fn read(&mut self, buf: &mut [u8]) -> io::Result { 83 | if self.can_read { 84 | let read = self.inner.read(buf)?; 85 | if read < buf.len() { 86 | self.can_read = false; 87 | } 88 | return Ok(read); 89 | } 90 | Err(io::Error::from(WouldBlock)) 91 | } 92 | } 93 | 94 | impl Write for MioStream { 95 | fn write(&mut self, buf: &[u8]) -> io::Result { 96 | if !self.can_write { 97 | self.buffer.extend_from_slice(buf); 98 | return Ok(buf.len()); 99 | } 100 | self.inner.write(buf) 101 | } 102 | 103 | fn flush(&mut self) -> io::Result<()> { 104 | self.inner.flush() 105 | } 106 | } 107 | 108 | impl ConnectionInfoProvider for MioStream { 109 | fn connection_info(&self) -> &ConnectionInfo { 110 | &self.connection_info 111 | } 112 | } 113 | 114 | pub trait IntoMioStream { 115 | fn into_mio_stream(self) -> MioStream; 116 | } 117 | 118 | impl IntoMioStream for T 119 | where 120 | T: Into, 121 | T: ConnectionInfoProvider, 122 | { 123 | fn into_mio_stream(self) -> MioStream { 124 | let connection_info = self.connection_info().clone(); 125 | 
MioStream::new(TcpStream::from_std(self.into()), connection_info) 126 | } 127 | } 128 | -------------------------------------------------------------------------------- /src/stream/mod.rs: -------------------------------------------------------------------------------- 1 | //! Various stream implementations on top of which protocol can be applied. 2 | 3 | use crate::inet::{IntoNetworkInterface, ToSocketAddr}; 4 | use crate::service::select::Selectable; 5 | use socket2::{Domain, Protocol, Socket, Type}; 6 | use std::fmt::{Display, Formatter}; 7 | use std::net::{SocketAddr, TcpStream, ToSocketAddrs}; 8 | use std::{io, vec}; 9 | use url::{ParseError, Url}; 10 | 11 | pub mod buffer; 12 | pub mod file; 13 | #[cfg(feature = "mio")] 14 | pub mod mio; 15 | pub mod record; 16 | pub mod replay; 17 | pub mod tcp; 18 | #[cfg(any(feature = "rustls", feature = "openssl"))] 19 | pub mod tls; 20 | 21 | #[cfg(target_os = "linux")] 22 | const EINPROGRESS: i32 = 115; 23 | #[cfg(target_os = "macos")] 24 | const EINPROGRESS: i32 = 36; 25 | 26 | /// Trait to create `TcpStream` and optionally bind it to a specific network interface and/or cpu 27 | /// before connecting. 28 | /// 29 | /// # Examples 30 | /// 31 | /// Bind to a specific network interface. 32 | /// 33 | /// ```no_run 34 | /// use std::net::TcpStream; 35 | /// use boomnet::inet::{IntoNetworkInterface, ToSocketAddr}; 36 | /// use boomnet::stream::BindAndConnect; 37 | /// 38 | /// let inet = "eth1".into_network_interface().and_then(|inet| inet.to_socket_addr()); 39 | /// let stream = TcpStream::bind_and_connect("stream.binance.com", inet, None).unwrap(); 40 | /// ``` 41 | /// 42 | /// Set `SO_INCOMING_CPU` affinity. 
43 | /// 44 | /// ```no_run 45 | /// use std::net::TcpStream; 46 | /// use boomnet::stream::BindAndConnect; 47 | /// 48 | /// let stream = TcpStream::bind_and_connect("stream.binance.com", None, Some(2)).unwrap(); 49 | /// ``` 50 | pub trait BindAndConnect { 51 | /// Creates `TcpStream` and optionally binds it to network interface and/or CPU before 52 | /// connecting. 53 | /// 54 | /// # Examples 55 | /// 56 | /// Bind to a specific network interface. 57 | /// 58 | /// ```no_run 59 | /// use std::net::TcpStream; 60 | /// use boomnet::inet::{IntoNetworkInterface, ToSocketAddr}; 61 | /// use boomnet::stream::BindAndConnect; 62 | /// 63 | /// let inet = "eth1".into_network_interface().and_then(|inet| inet.to_socket_addr()); 64 | /// let stream = TcpStream::bind_and_connect("stream.binance.com", inet, None).unwrap(); 65 | /// ``` 66 | /// 67 | /// Set `SO_INCOMING_CPU` affinity. 68 | /// 69 | /// ```no_run 70 | /// use std::net::TcpStream; 71 | /// use boomnet::stream::BindAndConnect; 72 | /// 73 | /// let stream = TcpStream::bind_and_connect("stream.binance.com", None, Some(2)).unwrap(); 74 | /// ``` 75 | fn bind_and_connect(addr: A, net_iface: Option, cpu: Option) -> io::Result 76 | where 77 | A: ToSocketAddrs, 78 | { 79 | Self::bind_and_connect_with_socket_config(addr, net_iface, cpu, |_| Ok(())) 80 | } 81 | 82 | /// Creates `TcpStream` and optionally binds it to network interface and/or CPU before 83 | /// connecting. This also accepts user defined `socket_config` closure that will be applied 84 | /// to the socket. 85 | /// 86 | /// # Examples 87 | /// 88 | /// Bind to a specific network interface. 
89 | /// 90 | /// ```no_run 91 | /// use std::net::TcpStream; 92 | /// use boomnet::inet::{IntoNetworkInterface, ToSocketAddr}; 93 | /// use boomnet::stream::BindAndConnect; 94 | /// 95 | /// let inet = "eth1".into_network_interface().and_then(|inet| inet.to_socket_addr()); 96 | /// let stream = TcpStream::bind_and_connect("stream.binance.com", inet, None).unwrap(); 97 | /// ``` 98 | /// 99 | /// Set `SO_INCOMING_CPU` affinity. 100 | /// 101 | /// ```no_run 102 | /// use std::net::TcpStream; 103 | /// use boomnet::stream::BindAndConnect; 104 | /// 105 | /// let stream = TcpStream::bind_and_connect("stream.binance.com", None, Some(2)).unwrap(); 106 | /// ``` 107 | /// 108 | /// Use `socket_config` to enable additional socket options. 109 | /// 110 | /// ```no_run 111 | /// use std::net::TcpStream; 112 | /// use boomnet::stream::BindAndConnect; 113 | /// 114 | /// let stream = TcpStream::bind_and_connect_with_socket_config("stream.binance.com", None, Some(2), |socket| { 115 | /// socket.set_reuse_address(true)?; 116 | /// Ok(()) 117 | /// }).unwrap(); 118 | /// ``` 119 | /// 120 | fn bind_and_connect_with_socket_config( 121 | addr: A, 122 | net_iface: Option, 123 | cpu: Option, 124 | socket_config: F, 125 | ) -> io::Result 126 | where 127 | A: ToSocketAddrs, 128 | F: FnOnce(&Socket) -> io::Result<()>; 129 | } 130 | 131 | impl BindAndConnect for TcpStream { 132 | #[allow(unused_variables)] 133 | fn bind_and_connect_with_socket_config( 134 | addr: A, 135 | net_iface: Option, 136 | cpu: Option, 137 | socket_config: F, 138 | ) -> io::Result 139 | where 140 | A: ToSocketAddrs, 141 | F: FnOnce(&Socket) -> io::Result<()>, 142 | { 143 | // create a socket but do not connect yet 144 | let socket = Socket::new(Domain::IPV4, Type::STREAM, Some(Protocol::TCP))?; 145 | socket.set_nonblocking(true)?; 146 | socket.set_nodelay(true)?; 147 | socket.set_keepalive(true)?; 148 | 149 | // apply custom options 150 | socket_config(&socket)?; 151 | 152 | // optionally bind to a specific 
network interface 153 | if let Some(addr) = net_iface { 154 | socket.bind(&addr.into())?; 155 | } 156 | 157 | // optionally set rx cpu affinity (only on linux) 158 | #[cfg(target_os = "linux")] 159 | if let Some(cpu_affinity) = cpu { 160 | socket.set_cpu_affinity(cpu_affinity)?; 161 | } 162 | 163 | // connect to the remote endpoint 164 | // we can ignore EINPROGRESS error due to non-blocking socket 165 | match socket.connect( 166 | &addr 167 | .to_socket_addrs()? 168 | .next() 169 | .ok_or_else(|| io::Error::other("unable to resolve socket address"))? 170 | .into(), 171 | ) { 172 | Ok(()) => Ok(socket.into()), 173 | Err(err) if err.raw_os_error() == Some(EINPROGRESS) => Ok(socket.into()), 174 | Err(err) => Err(err), 175 | } 176 | } 177 | } 178 | 179 | impl Selectable for TcpStream { 180 | fn connected(&mut self) -> io::Result { 181 | Ok(true) 182 | } 183 | 184 | fn make_writable(&mut self) -> io::Result<()> { 185 | Ok(()) 186 | } 187 | 188 | fn make_readable(&mut self) -> io::Result<()> { 189 | Ok(()) 190 | } 191 | } 192 | 193 | pub trait ConnectionInfoProvider { 194 | fn connection_info(&self) -> &ConnectionInfo; 195 | } 196 | 197 | /// TCP stream connection info. 
198 | #[derive(Debug, Clone, Default)] 199 | pub struct ConnectionInfo { 200 | host: String, 201 | port: u16, 202 | net_iface: Option, 203 | cpu: Option, 204 | socket_config: Option io::Result<()>>, 205 | } 206 | 207 | impl ToSocketAddrs for ConnectionInfo { 208 | type Iter = vec::IntoIter; 209 | 210 | fn to_socket_addrs(&self) -> io::Result { 211 | format!("{}:{}", self.host, self.port).to_socket_addrs() 212 | } 213 | } 214 | 215 | impl Display for ConnectionInfo { 216 | fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { 217 | write!(f, "{}:{}", self.host, self.port) 218 | } 219 | } 220 | 221 | impl TryFrom for ConnectionInfo { 222 | type Error = io::Error; 223 | 224 | fn try_from(url: Url) -> Result { 225 | Ok(ConnectionInfo { 226 | host: url 227 | .host_str() 228 | .ok_or_else(|| io::Error::other("host not present"))? 229 | .to_owned(), 230 | port: url 231 | .port_or_known_default() 232 | .ok_or_else(|| io::Error::other("port not present"))?, 233 | net_iface: None, 234 | cpu: None, 235 | socket_config: None, 236 | }) 237 | } 238 | } 239 | 240 | impl TryFrom> for ConnectionInfo { 241 | type Error = io::Error; 242 | 243 | fn try_from(result: Result) -> Result { 244 | match result { 245 | Ok(url) => Ok(url.try_into()?), 246 | Err(err) => Err(io::Error::other(err)), 247 | } 248 | } 249 | } 250 | 251 | impl From<(&str, u16)> for ConnectionInfo { 252 | fn from(host_and_port: (&str, u16)) -> Self { 253 | let (host, port) = host_and_port; 254 | Self::new(host, port) 255 | } 256 | } 257 | 258 | impl ConnectionInfo { 259 | pub fn new(host: impl AsRef, port: u16) -> Self { 260 | Self { 261 | host: host.as_ref().to_string(), 262 | port, 263 | net_iface: None, 264 | cpu: None, 265 | socket_config: None, 266 | } 267 | } 268 | 269 | pub fn with_net_iface(self, net_iface: SocketAddr) -> Self { 270 | Self { 271 | net_iface: Some(net_iface), 272 | ..self 273 | } 274 | } 275 | 276 | pub fn with_net_iface_from_name(self, net_iface_name: &str) -> Self { 277 | let net_iface = 
net_iface_name 278 | .into_network_interface() 279 | .and_then(|iface| iface.to_socket_addr()) 280 | .unwrap_or_else(|| panic!("invalid network interface: {}", net_iface_name)); 281 | Self { 282 | net_iface: Some(net_iface), 283 | ..self 284 | } 285 | } 286 | 287 | pub fn with_cpu(self, cpu: usize) -> Self { 288 | Self { cpu: Some(cpu), ..self } 289 | } 290 | 291 | pub fn with_socket_config(self, socket_config: fn(&Socket) -> io::Result<()>) -> Self { 292 | Self { 293 | socket_config: Some(socket_config), 294 | ..self 295 | } 296 | } 297 | 298 | pub fn host(&self) -> &str { 299 | &self.host 300 | } 301 | 302 | pub fn port(&self) -> u16 { 303 | self.port 304 | } 305 | 306 | pub fn into_tcp_stream(self) -> io::Result { 307 | let stream = 308 | TcpStream::bind_and_connect_with_socket_config(&self, self.net_iface, self.cpu, |socket| { 309 | match self.socket_config { 310 | Some(f) => f(socket), 311 | None => Ok(()), 312 | } 313 | })?; 314 | Ok(tcp::TcpStream::new(stream, self)) 315 | } 316 | 317 | pub fn into_tcp_stream_with_addr(self, addr: SocketAddr) -> io::Result { 318 | let stream = 319 | TcpStream::bind_and_connect_with_socket_config(addr, self.net_iface, self.cpu, |socket| { 320 | match self.socket_config { 321 | Some(f) => f(socket), 322 | None => Ok(()), 323 | } 324 | })?; 325 | Ok(tcp::TcpStream::new(stream, self)) 326 | } 327 | } 328 | -------------------------------------------------------------------------------- /src/stream/record.rs: -------------------------------------------------------------------------------- 1 | //! Stream that will also record incoming and outgoing data to a file. 2 | //! 
3 | 4 | use crate::stream::{ConnectionInfo, ConnectionInfoProvider}; 5 | use std::fmt::{Debug, Formatter}; 6 | use std::fs::File; 7 | use std::io; 8 | use std::io::{BufWriter, Read, Write}; 9 | 10 | const DEFAULT_RECORDING_NAME: &str = "plain"; 11 | 12 | pub struct Recorder { 13 | inbound: Box, 14 | inbound_seq: Box, 15 | outbound: Box, 16 | } 17 | 18 | impl Recorder { 19 | pub fn new(recording_name: impl AsRef) -> io::Result { 20 | let file_in = format!("{}_inbound.rec", recording_name.as_ref()); 21 | let file_out = format!("{}_outbound.rec", recording_name.as_ref()); 22 | let inbound = Box::new(BufWriter::new(File::create(file_in)?)); 23 | let outbound = Box::new(BufWriter::new(File::create(file_out)?)); 24 | 25 | let file_seq_in = format!("{}_inbound_seq.rec", recording_name.as_ref()); 26 | let inbound_seq = Box::new(BufWriter::new(File::create(file_seq_in)?)); 27 | 28 | Ok(Self { 29 | inbound, 30 | inbound_seq, 31 | outbound, 32 | }) 33 | } 34 | fn record_inbound(&mut self, buf: &[u8], seq: usize) -> io::Result<()> { 35 | self.inbound.write_all(buf)?; 36 | self.inbound.flush()?; 37 | self.inbound_seq.write_all(&seq.to_le_bytes())?; 38 | self.inbound_seq.write_all(&buf.len().to_le_bytes())?; 39 | self.inbound_seq.flush()?; 40 | Ok(()) 41 | } 42 | fn record_outbound(&mut self, buf: &[u8]) -> io::Result<()> { 43 | self.outbound.write_all(buf)?; 44 | self.outbound.flush() 45 | } 46 | } 47 | 48 | pub struct RecordedStream { 49 | inner: S, 50 | recorder: Recorder, 51 | inbound_seq: usize, 52 | } 53 | 54 | impl Debug for RecordedStream { 55 | fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { 56 | f.debug_struct("RecordedStream") 57 | .field("seq", &self.inbound_seq) 58 | .finish() 59 | } 60 | } 61 | 62 | impl RecordedStream { 63 | pub fn new(stream: S, recorder: Recorder) -> RecordedStream { 64 | Self { 65 | inner: stream, 66 | recorder, 67 | inbound_seq: 0, 68 | } 69 | } 70 | } 71 | 72 | impl Read for RecordedStream { 73 | fn read(&mut self, buf: &mut [u8]) 
-> io::Result { 74 | let seq = self.inbound_seq; 75 | self.inbound_seq += 1; 76 | let read = self.inner.read(buf)?; 77 | self.recorder.record_inbound(&buf[..read], seq)?; 78 | Ok(read) 79 | } 80 | } 81 | 82 | impl Write for RecordedStream { 83 | fn write(&mut self, buf: &[u8]) -> io::Result { 84 | let wrote = self.inner.write(buf)?; 85 | self.recorder.record_outbound(&buf[..wrote])?; 86 | Ok(wrote) 87 | } 88 | 89 | fn flush(&mut self) -> io::Result<()> { 90 | self.inner.flush() 91 | } 92 | } 93 | 94 | impl ConnectionInfoProvider for RecordedStream { 95 | fn connection_info(&self) -> &ConnectionInfo { 96 | self.inner.connection_info() 97 | } 98 | } 99 | 100 | pub trait IntoRecordedStream { 101 | fn into_recorded_stream(self, recording_name: impl AsRef) -> RecordedStream 102 | where 103 | Self: Sized; 104 | 105 | fn into_default_recorded_stream(self) -> RecordedStream 106 | where 107 | Self: Sized, 108 | { 109 | self.into_recorded_stream(DEFAULT_RECORDING_NAME) 110 | } 111 | } 112 | 113 | impl IntoRecordedStream for T 114 | where 115 | T: Read + Write, 116 | { 117 | fn into_recorded_stream(self, recording_name: impl AsRef) -> RecordedStream 118 | where 119 | Self: Sized, 120 | { 121 | RecordedStream::new(self, Recorder::new(recording_name).unwrap()) 122 | } 123 | } 124 | -------------------------------------------------------------------------------- /src/stream/replay.rs: -------------------------------------------------------------------------------- 1 | //! Stream that uses file replay. 
2 | 
3 | use crate::stream::{ConnectionInfo, ConnectionInfoProvider};
4 | use std::collections::HashMap;
5 | use std::fmt::{Debug, Formatter};
6 | use std::fs::File;
7 | use std::io;
8 | use std::io::{BufReader, Read, Write};
9 | use std::path::Path;
10 | 
11 | type Sequence = u64;
12 | 
13 | /// Replays a recording: each `read` returns the bytes stored for the current sequence number in
14 | /// the recording's sequence index, and every `write` is accepted and discarded.
15 | pub struct ReplayStream {
16 |     inner: S,
17 |     seq: Sequence,
18 |     last_seq: Sequence,
19 |     bytes_read: HashMap,
20 |     // owned by the stream so `connection_info` can return a reference without leaking memory
21 |     connection_info: ConnectionInfo,
22 | }
23 | 
24 | impl Debug for ReplayStream {
25 |     fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
26 |         f.debug_struct("ReplayStream")
27 |             .field("seq", &self.seq)
28 |             .field("last_seq", &self.last_seq)
29 |             .finish()
30 |     }
31 | }
32 | 
33 | impl ReplayStream> {
34 |     /// Opens `{recording_name}.rec` and its sequence index `{recording_name}_seq.rec`.
35 |     ///
36 |     /// Fails if either file cannot be opened, if the index is truncated, or if it is empty.
37 |     pub fn from_file(recording_name: impl AsRef) -> io::Result>> {
38 |         let recording_file = format!("{}.rec", recording_name.as_ref());
39 |         let seq_file = format!("{}_seq.rec", recording_name.as_ref());
40 | 
41 |         let bytes_read = load_sequence_file(seq_file)?;
42 |         let last_seq = *bytes_read
43 |             .keys()
44 |             .max()
45 |             .ok_or_else(|| io::Error::other("sequence file is empty"))?;
46 | 
47 |         Ok(Self {
48 |             inner: BufReader::new(File::open(recording_file)?),
49 |             seq: 0,
50 |             bytes_read,
51 |             last_seq,
52 |             connection_info: ConnectionInfo::default(),
53 |         })
54 |     }
55 | }
56 | 
57 | impl Read for ReplayStream {
58 |     /// Replays the chunk recorded at the current sequence number and advances the sequence.
59 |     ///
60 |     /// Sequence numbers without recorded bytes yield `WouldBlock`; past the last recorded
61 |     /// sequence every call yields `UnexpectedEof`.
62 |     fn read(&mut self, buf: &mut [u8]) -> io::Result {
63 |         let seq = self.seq;
64 |         if seq > self.last_seq {
65 |             return Err(io::Error::new(io::ErrorKind::UnexpectedEof, "no more data to replay"));
66 |         }
67 |         self.seq += 1;
68 |         let read = *self.bytes_read.get(&seq).unwrap_or(&0);
69 |         if read == 0 {
70 |             return Err(io::Error::new(io::ErrorKind::WouldBlock, ""));
71 |         }
72 | 
73 |         // the recorded chunk must fit in the caller's buffer (slicing would otherwise panic)
74 |         let Some(chunk) = buf.get_mut(..read) else {
75 |             return Err(io::Error::new(
76 |                 io::ErrorKind::InvalidInput,
77 |                 "read buffer is smaller than the recorded chunk",
78 |             ));
79 |         };
80 |         // read_exact retries short reads and reports a truncated recording as UnexpectedEof
81 |         // instead of looping forever when the underlying reader returns 0
82 |         self.inner.read_exact(chunk)?;
83 |         Ok(read)
84 |     }
85 | }
86 | 
87 | impl Write for ReplayStream {
88 |     fn write(&mut self, buf: &[u8]) -> io::Result {
89 |         Ok(buf.len())
90 |     }
91 | 
92 |     fn flush(&mut self) -> io::Result<()> {
93 |         Ok(())
94 |     }
95 | }
96 | 
97 | impl ConnectionInfoProvider for ReplayStream {
98 |     fn connection_info(&self) -> &ConnectionInfo {
99 |         &self.connection_info
100 |     }
101 | }
102 | 
103 | /// Loads the sequence index: consecutive 16-byte records, each a little-endian `u64` sequence
104 | /// number followed by a little-endian `usize` byte count.
105 | fn load_sequence_file(file: impl AsRef) -> io::Result> {
106 |     let mut map = HashMap::new();
107 |     let mut reader = BufReader::new(File::open(file)?);
108 |     let mut bytes = [0u8; 16];
109 |     loop {
110 |         // fill a whole record, tolerating short reads; nothing read means a clean end of file
111 |         let mut filled = 0;
112 |         while filled < bytes.len() {
113 |             match reader.read(&mut bytes[filled..]) {
114 |                 Ok(0) => break,
115 |                 Ok(n) => filled += n,
116 |                 Err(err) if err.kind() == io::ErrorKind::Interrupted => {}
117 |                 Err(err) => return Err(err),
118 |             }
119 |         }
120 |         match filled {
121 |             0 => break,
122 |             16 => {}
123 |             // any partial record (1..=15 bytes) means the index was truncated
124 |             _ => return Err(io::Error::other("incomplete sequence file")),
125 |         }
126 |         let (seq, read) = bytes.split_at(8);
127 |         let seq = u64::from_le_bytes(seq.try_into().map_err(io::Error::other)?);
128 |         let read = usize::from_le_bytes(read.try_into().map_err(io::Error::other)?);
129 |         map.insert(seq, read);
130 |     }
131 |     Ok(map)
132 | }
133 | 
--------------------------------------------------------------------------------
/src/stream/tcp.rs:
--------------------------------------------------------------------------------
1 | //! Wrapper over `std::net::TcpStream`.
2 | 
3 | use crate::service::select::Selectable;
4 | use crate::stream::{ConnectionInfo, ConnectionInfoProvider};
5 | use std::io;
6 | use std::io::{Read, Write};
7 | use std::net::SocketAddr;
8 | 
9 | /// Wraps `std::net::TcpStream` and provides `ConnectionInfo`.
10 | #[derive(Debug)] 11 | pub struct TcpStream { 12 | inner: std::net::TcpStream, 13 | connection_info: ConnectionInfo, 14 | } 15 | 16 | impl From for std::net::TcpStream { 17 | fn from(stream: TcpStream) -> Self { 18 | stream.inner 19 | } 20 | } 21 | 22 | impl TryFrom<(&str, u16)> for TcpStream { 23 | type Error = io::Error; 24 | 25 | fn try_from(host_and_port: (&str, u16)) -> Result { 26 | ConnectionInfo::from(host_and_port).try_into() 27 | } 28 | } 29 | 30 | impl TryFrom for TcpStream { 31 | type Error = io::Error; 32 | 33 | fn try_from(connection_info: ConnectionInfo) -> Result { 34 | connection_info.into_tcp_stream() 35 | } 36 | } 37 | 38 | impl TryFrom<&ConnectionInfo> for TcpStream { 39 | type Error = io::Error; 40 | 41 | fn try_from(connection_info: &ConnectionInfo) -> Result { 42 | connection_info.clone().into_tcp_stream() 43 | } 44 | } 45 | 46 | impl TryFrom<(&ConnectionInfo, SocketAddr)> for TcpStream { 47 | type Error = io::Error; 48 | 49 | fn try_from(conn_and_addr: (&ConnectionInfo, SocketAddr)) -> Result { 50 | let (conn, addr) = conn_and_addr; 51 | conn.clone().into_tcp_stream_with_addr(addr) 52 | } 53 | } 54 | 55 | impl TryFrom<(ConnectionInfo, SocketAddr)> for TcpStream { 56 | type Error = io::Error; 57 | 58 | fn try_from(conn_and_addr: (ConnectionInfo, SocketAddr)) -> Result { 59 | let (conn, addr) = conn_and_addr; 60 | conn.into_tcp_stream_with_addr(addr) 61 | } 62 | } 63 | 64 | impl TcpStream { 65 | pub fn new(stream: std::net::TcpStream, connection_info: ConnectionInfo) -> Self { 66 | Self { 67 | inner: stream, 68 | connection_info, 69 | } 70 | } 71 | } 72 | 73 | impl Read for TcpStream { 74 | fn read(&mut self, buf: &mut [u8]) -> std::io::Result { 75 | self.inner.read(buf) 76 | } 77 | } 78 | 79 | impl Write for TcpStream { 80 | fn write(&mut self, buf: &[u8]) -> std::io::Result { 81 | self.inner.write(buf) 82 | } 83 | 84 | fn flush(&mut self) -> std::io::Result<()> { 85 | self.inner.flush() 86 | } 87 | } 88 | 89 | impl Selectable for 
TcpStream { 90 | fn connected(&mut self) -> io::Result { 91 | Ok(true) 92 | } 93 | 94 | fn make_writable(&mut self) -> io::Result<()> { 95 | Ok(()) 96 | } 97 | 98 | fn make_readable(&mut self) -> io::Result<()> { 99 | Ok(()) 100 | } 101 | } 102 | 103 | impl ConnectionInfoProvider for TcpStream { 104 | fn connection_info(&self) -> &ConnectionInfo { 105 | &self.connection_info 106 | } 107 | } 108 | -------------------------------------------------------------------------------- /src/stream/tls.rs: -------------------------------------------------------------------------------- 1 | //! Provides TLS stream implementation for different backends. 2 | 3 | use crate::service::select::Selectable; 4 | use crate::stream::{ConnectionInfo, ConnectionInfoProvider}; 5 | #[cfg(feature = "openssl")] 6 | pub use __openssl::TlsStream; 7 | #[cfg(all(feature = "rustls", not(feature = "openssl")))] 8 | pub use __rustls::TlsStream; 9 | #[cfg(feature = "mio")] 10 | use mio::{Interest, Registry, Token, event::Source}; 11 | #[cfg(feature = "openssl")] 12 | use openssl::ssl::{SslConnectorBuilder, SslVerifyMode}; 13 | #[cfg(all(feature = "rustls", not(feature = "openssl")))] 14 | use rustls::ClientConfig; 15 | use std::fmt::Debug; 16 | use std::io; 17 | use std::io::{Read, Write}; 18 | 19 | /// Used to configure TLS backend. 20 | pub struct TlsConfig { 21 | #[cfg(all(feature = "rustls", not(feature = "openssl")))] 22 | rustls_config: ClientConfig, 23 | #[cfg(feature = "openssl")] 24 | openssl_config: SslConnectorBuilder, 25 | } 26 | 27 | /// Extension methods for `TlsConfig`. 28 | pub trait TlsConfigExt { 29 | /// Disable certificate verification. 30 | fn with_no_cert_verification(&mut self); 31 | } 32 | 33 | impl TlsConfig { 34 | /// Get reference to the `rustls` configuration object. 
35 | #[cfg(all(feature = "rustls", not(feature = "openssl")))] 36 | pub const fn as_rustls(&self) -> &ClientConfig { 37 | &self.rustls_config 38 | } 39 | 40 | /// Get mutable reference to the `rustls` configuration object. 41 | #[cfg(all(feature = "rustls", not(feature = "openssl")))] 42 | pub const fn as_rustls_mut(&mut self) -> &mut ClientConfig { 43 | &mut self.rustls_config 44 | } 45 | 46 | /// Get reference to the `openssl` configuration object. 47 | #[cfg(feature = "openssl")] 48 | pub const fn as_openssl(&self) -> &SslConnectorBuilder { 49 | &self.openssl_config 50 | } 51 | 52 | /// Get mutable reference to the `openssl` configuration object. 53 | #[cfg(feature = "openssl")] 54 | pub const fn as_openssl_mut(&mut self) -> &mut SslConnectorBuilder { 55 | &mut self.openssl_config 56 | } 57 | } 58 | 59 | impl TlsConfigExt for TlsConfig { 60 | fn with_no_cert_verification(&mut self) { 61 | #[cfg(all(feature = "rustls", not(feature = "openssl")))] 62 | self.rustls_config 63 | .dangerous() 64 | .set_certificate_verifier(std::sync::Arc::new(crate::stream::tls::__rustls::NoCertVerification)); 65 | #[cfg(feature = "openssl")] 66 | self.openssl_config.set_verify(SslVerifyMode::NONE); 67 | } 68 | } 69 | 70 | #[cfg(all(feature = "rustls", not(feature = "openssl")))] 71 | mod __rustls { 72 | use crate::service::select::Selectable; 73 | use crate::stream::tls::TlsConfig; 74 | use crate::stream::{ConnectionInfo, ConnectionInfoProvider}; 75 | use crate::util::NoBlock; 76 | #[cfg(feature = "mio")] 77 | use mio::{Interest, Registry, Token, event::Source}; 78 | use rustls::SignatureScheme::{ 79 | ECDSA_NISTP256_SHA256, ECDSA_NISTP384_SHA384, ECDSA_NISTP521_SHA512, ECDSA_SHA1_Legacy, ED448, ED25519, 80 | RSA_PKCS1_SHA1, RSA_PKCS1_SHA256, RSA_PKCS1_SHA384, RSA_PKCS1_SHA512, RSA_PSS_SHA256, RSA_PSS_SHA384, 81 | RSA_PSS_SHA512, 82 | }; 83 | use rustls::client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier}; 84 | use rustls::pki_types::{CertificateDer, 
ServerName, UnixTime}; 85 | use rustls::{ClientConfig, ClientConnection, DigitallySignedStruct, Error, RootCertStore, SignatureScheme}; 86 | use std::fmt::Debug; 87 | use std::io; 88 | use std::io::{Read, Write}; 89 | 90 | pub struct TlsStream { 91 | inner: S, 92 | tls: ClientConnection, 93 | } 94 | 95 | #[cfg(feature = "mio")] 96 | impl Source for TlsStream { 97 | fn register(&mut self, registry: &Registry, token: Token, interests: Interest) -> io::Result<()> { 98 | registry.register(&mut self.inner, token, interests) 99 | } 100 | 101 | fn reregister(&mut self, registry: &Registry, token: Token, interests: Interest) -> io::Result<()> { 102 | registry.reregister(&mut self.inner, token, interests) 103 | } 104 | 105 | fn deregister(&mut self, registry: &Registry) -> io::Result<()> { 106 | registry.deregister(&mut self.inner) 107 | } 108 | } 109 | 110 | impl Selectable for TlsStream { 111 | fn connected(&mut self) -> io::Result { 112 | self.inner.connected() 113 | } 114 | 115 | fn make_writable(&mut self) -> io::Result<()> { 116 | self.inner.make_writable() 117 | } 118 | 119 | fn make_readable(&mut self) -> io::Result<()> { 120 | self.inner.make_readable() 121 | } 122 | } 123 | 124 | impl Read for TlsStream { 125 | fn read(&mut self, buf: &mut [u8]) -> io::Result { 126 | let (_, _) = self.complete_io()?; 127 | self.tls.reader().read(buf) 128 | } 129 | } 130 | 131 | impl Write for TlsStream { 132 | fn write(&mut self, buf: &[u8]) -> io::Result { 133 | self.tls.writer().write(buf) 134 | } 135 | 136 | fn flush(&mut self) -> io::Result<()> { 137 | self.tls.writer().flush() 138 | } 139 | } 140 | 141 | impl TlsStream { 142 | pub fn wrap_with_config(stream: S, server_name: &str, builder: F) -> io::Result> 143 | where 144 | F: FnOnce(&mut TlsConfig), 145 | { 146 | #[cfg(not(all(feature = "rustls-native-certs", feature = "webpki-roots")))] 147 | let mut root_store = RootCertStore::empty(); 148 | 149 | #[cfg(all(feature = "rustls-native-certs", feature = "webpki-roots"))] 150 | 
let root_store = RootCertStore::empty(); 151 | 152 | #[cfg(all(feature = "webpki-roots", not(feature = "rustls-native-certs")))] 153 | root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned()); 154 | 155 | #[cfg(all(feature = "rustls-native-certs", not(feature = "webpki-roots")))] 156 | { 157 | for cert in rustls_native_certs::load_native_certs().expect("could not load platform certs") { 158 | root_store.add(cert).unwrap(); 159 | } 160 | } 161 | 162 | let config = ClientConfig::builder() 163 | .with_root_certificates(root_store) 164 | .with_no_client_auth(); 165 | 166 | let mut config = TlsConfig { rustls_config: config }; 167 | builder(&mut config); 168 | 169 | let config = std::sync::Arc::new(config.rustls_config); 170 | let server_name = server_name.to_owned().try_into().map_err(io::Error::other)?; 171 | let tls = ClientConnection::new(config, server_name).map_err(io::Error::other)?; 172 | 173 | Ok(Self { inner: stream, tls }) 174 | } 175 | 176 | pub fn wrap(stream: S, server_name: &str) -> io::Result> { 177 | Self::wrap_with_config(stream, server_name, |_| {}) 178 | } 179 | 180 | fn complete_io(&mut self) -> io::Result<(usize, usize)> { 181 | let wrote = if self.tls.wants_write() { 182 | self.tls.write_tls(&mut self.inner)? 
183 | } else { 184 | 0 185 | }; 186 | 187 | let read = if self.tls.wants_read() { 188 | let read = self.tls.read_tls(&mut self.inner).no_block()?; 189 | if read > 0 { 190 | self.tls.process_new_packets().map_err(io::Error::other)?; 191 | } 192 | read 193 | } else { 194 | 0 195 | }; 196 | 197 | Ok((read, wrote)) 198 | } 199 | } 200 | 201 | impl ConnectionInfoProvider for TlsStream { 202 | fn connection_info(&self) -> &ConnectionInfo { 203 | self.inner.connection_info() 204 | } 205 | } 206 | 207 | #[derive(Debug)] 208 | pub(crate) struct NoCertVerification; 209 | 210 | impl ServerCertVerifier for NoCertVerification { 211 | fn verify_server_cert( 212 | &self, 213 | _end_entity: &CertificateDer<'_>, 214 | _intermediates: &[CertificateDer<'_>], 215 | _server_name: &ServerName<'_>, 216 | _ocsp_response: &[u8], 217 | _now: UnixTime, 218 | ) -> Result { 219 | Ok(ServerCertVerified::assertion()) 220 | } 221 | 222 | fn verify_tls12_signature( 223 | &self, 224 | _message: &[u8], 225 | _cert: &CertificateDer<'_>, 226 | _dss: &DigitallySignedStruct, 227 | ) -> Result { 228 | Ok(HandshakeSignatureValid::assertion()) 229 | } 230 | 231 | fn verify_tls13_signature( 232 | &self, 233 | _message: &[u8], 234 | _cert: &CertificateDer<'_>, 235 | _dss: &DigitallySignedStruct, 236 | ) -> Result { 237 | Ok(HandshakeSignatureValid::assertion()) 238 | } 239 | 240 | fn supported_verify_schemes(&self) -> Vec { 241 | vec![ 242 | RSA_PKCS1_SHA1, 243 | ECDSA_SHA1_Legacy, 244 | RSA_PKCS1_SHA256, 245 | ECDSA_NISTP256_SHA256, 246 | RSA_PKCS1_SHA384, 247 | ECDSA_NISTP384_SHA384, 248 | RSA_PKCS1_SHA512, 249 | ECDSA_NISTP521_SHA512, 250 | RSA_PSS_SHA256, 251 | RSA_PSS_SHA384, 252 | RSA_PSS_SHA512, 253 | ED25519, 254 | ED448, 255 | ] 256 | } 257 | } 258 | } 259 | 260 | #[cfg(feature = "openssl")] 261 | mod __openssl { 262 | use crate::service::select::Selectable; 263 | use crate::stream::tls::TlsConfig; 264 | use crate::stream::{ConnectionInfo, ConnectionInfoProvider}; 265 | #[cfg(feature = "mio")] 266 | 
use mio::{Interest, Registry, Token, event::Source}; 267 | use openssl::ssl::{HandshakeError, MidHandshakeSslStream, SslConnector, SslMethod, SslRef, SslStream}; 268 | use openssl::x509::X509VerifyResult; 269 | use std::fmt::Debug; 270 | use std::fs::OpenOptions; 271 | use std::io; 272 | use std::io::ErrorKind::WouldBlock; 273 | use std::io::{Read, Write}; 274 | 275 | #[derive(Debug)] 276 | pub struct TlsStream { 277 | state: State, 278 | } 279 | 280 | #[derive(Debug)] 281 | enum State { 282 | Handshake(Option<(MidHandshakeSslStream, Vec)>), 283 | Stream(SslStream), 284 | } 285 | 286 | impl State { 287 | fn get_stream_mut(&mut self) -> io::Result<&mut S> { 288 | match self { 289 | State::Handshake(stream_and_buf) => match stream_and_buf.as_mut() { 290 | Some((stream, _)) => Ok(stream.get_mut()), 291 | None => Err(io::Error::other("unable to perform TLS handshake")), 292 | }, 293 | State::Stream(stream) => Ok(stream.get_mut()), 294 | } 295 | } 296 | } 297 | 298 | impl ConnectionInfoProvider for State { 299 | fn connection_info(&self) -> &ConnectionInfo { 300 | match self { 301 | State::Handshake(stream_and_buf) => stream_and_buf.as_ref().unwrap().0.get_ref().connection_info(), 302 | State::Stream(stream) => stream.get_ref().connection_info(), 303 | } 304 | } 305 | } 306 | 307 | #[cfg(feature = "mio")] 308 | impl Source for TlsStream { 309 | fn register(&mut self, registry: &Registry, token: Token, interests: Interest) -> io::Result<()> { 310 | registry.register(self.state.get_stream_mut()?, token, interests) 311 | } 312 | 313 | fn reregister(&mut self, registry: &Registry, token: Token, interests: Interest) -> io::Result<()> { 314 | registry.reregister(self.state.get_stream_mut()?, token, interests) 315 | } 316 | 317 | fn deregister(&mut self, registry: &Registry) -> io::Result<()> { 318 | registry.deregister(self.state.get_stream_mut()?) 
319 | } 320 | } 321 | 322 | impl Selectable for TlsStream { 323 | fn connected(&mut self) -> io::Result { 324 | self.state.get_stream_mut()?.connected() 325 | } 326 | 327 | fn make_writable(&mut self) -> io::Result<()> { 328 | self.state.get_stream_mut()?.make_writable() 329 | } 330 | 331 | fn make_readable(&mut self) -> io::Result<()> { 332 | self.state.get_stream_mut()?.make_readable() 333 | } 334 | } 335 | 336 | impl Read for TlsStream { 337 | fn read(&mut self, buf: &mut [u8]) -> io::Result { 338 | match &mut self.state { 339 | State::Handshake(stream_and_buf) => { 340 | if let Some((mid_handshake, buffer)) = stream_and_buf.take() { 341 | return match mid_handshake.handshake() { 342 | Ok(mut ssl_stream) => { 343 | // drain the pending message buffer 344 | ssl_stream.write_all(&buffer)?; 345 | self.state = State::Stream(ssl_stream); 346 | Err(io::Error::from(WouldBlock)) 347 | } 348 | Err(HandshakeError::WouldBlock(mid)) => { 349 | self.state = State::Handshake(Some((mid, buffer))); 350 | Err(io::Error::from(WouldBlock)) 351 | } 352 | Err(err) => match err { 353 | HandshakeError::Failure(stream) => { 354 | let verify = stream.ssl().verify_result(); 355 | if verify != X509VerifyResult::OK { 356 | Err(io::Error::other(format!("{} {}", stream.error(), verify))) 357 | } else { 358 | Err(io::Error::other(stream.error().to_string())) 359 | } 360 | } 361 | _ => Err(io::Error::other("TLS handshake failed")), 362 | }, 363 | }; 364 | } 365 | Err(io::Error::from(WouldBlock)) 366 | } 367 | State::Stream(stream) => stream.read(buf), 368 | } 369 | } 370 | } 371 | 372 | impl Write for TlsStream { 373 | fn write(&mut self, buf: &[u8]) -> io::Result { 374 | match &mut self.state { 375 | State::Handshake(stream_and_buf) => { 376 | let (_, buffer) = stream_and_buf.as_mut().unwrap(); 377 | buffer.extend_from_slice(buf); 378 | Ok(buf.len()) 379 | } 380 | State::Stream(stream) => stream.write(buf), 381 | } 382 | } 383 | 384 | fn flush(&mut self) -> io::Result<()> { 385 | match &mut 
self.state { 386 | State::Handshake(_) => Ok(()), 387 | State::Stream(stream) => stream.flush(), 388 | } 389 | } 390 | } 391 | 392 | impl TlsStream { 393 | pub fn wrap_with_config(stream: S, server_name: &str, configure: F) -> io::Result> 394 | where 395 | F: FnOnce(&mut TlsConfig), 396 | { 397 | let mut builder = SslConnector::builder(SslMethod::tls_client()).map_err(io::Error::other)?; 398 | 399 | if std::env::var("SSLKEYLOGFILE").is_ok() { 400 | builder.set_keylog_callback(default_key_log_callback) 401 | } 402 | 403 | let mut tls_config = TlsConfig { 404 | openssl_config: builder, 405 | }; 406 | configure(&mut tls_config); 407 | 408 | let connector = tls_config.openssl_config.build(); 409 | match connector.connect(server_name, stream) { 410 | Ok(stream) => Ok(Self { 411 | state: State::Stream(stream), 412 | }), 413 | Err(HandshakeError::WouldBlock(mid_handshake)) => Ok(Self { 414 | state: State::Handshake(Some((mid_handshake, Vec::with_capacity(4096)))), 415 | }), 416 | Err(e) => Err(io::Error::other(e.to_string())), 417 | } 418 | } 419 | 420 | pub fn wrap(stream: S, server_name: &str) -> io::Result> { 421 | Self::wrap_with_config(stream, server_name, |_| {}) 422 | } 423 | } 424 | 425 | impl ConnectionInfoProvider for TlsStream { 426 | fn connection_info(&self) -> &ConnectionInfo { 427 | self.state.connection_info() 428 | } 429 | } 430 | 431 | fn default_key_log_callback(_ssl: &SslRef, line: &str) { 432 | let path = std::env::var("SSLKEYLOGFILE").expect("SSLKEYLOGFILE not set"); 433 | let mut file = OpenOptions::new() 434 | .create(true) 435 | .append(true) 436 | .open(path) 437 | .expect("Failed to open SSL key log file"); 438 | 439 | writeln!(file, "{}", line).expect("Failed to write to SSL key log file"); 440 | } 441 | } 442 | 443 | /// Trait to convert underlying stream into [TlsStream]. 444 | pub trait IntoTlsStream { 445 | /// Convert underlying stream into [TlsStream] with default tls config. 
446 | /// 447 | /// ## Examples 448 | /// ```no_run 449 | /// use boomnet::stream::tcp::TcpStream; 450 | /// use boomnet::stream::tls::IntoTlsStream; 451 | /// 452 | /// let tls = TcpStream::try_from(("127.0.0.1", 4222)).unwrap().into_tls_stream(); 453 | /// ``` 454 | fn into_tls_stream(self) -> io::Result> 455 | where 456 | Self: Sized, 457 | { 458 | self.into_tls_stream_with_config(|_| {}) 459 | } 460 | 461 | /// Convert underlying stream into [TlsStream] and modify tls config. The type of`TlsConfig` used 462 | /// will depend on whether `openssl` or `rustls` has been enabled. 463 | /// 464 | /// ## Examples 465 | /// 466 | /// Using `openssl` configure the TLS stream to disable server side certificate verification. 467 | /// ```no_run 468 | /// #[cfg(feature = "openssl")] 469 | /// { 470 | /// use openssl::ssl::SslVerifyMode; 471 | /// { 472 | /// use boomnet::stream::tcp::TcpStream; 473 | /// use boomnet::stream::tls::IntoTlsStream; 474 | /// 475 | /// let tls = TcpStream::try_from(("127.0.0.1", 4222)).unwrap().into_tls_stream_with_config(|config| { 476 | /// config.as_openssl_mut().set_verify(SslVerifyMode::NONE); 477 | /// }); 478 | /// } 479 | /// } 480 | /// ``` 481 | fn into_tls_stream_with_config(self, builder: F) -> io::Result> 482 | where 483 | Self: Sized, 484 | F: FnOnce(&mut TlsConfig); 485 | } 486 | 487 | impl IntoTlsStream for T 488 | where 489 | T: Read + Write + Debug + ConnectionInfoProvider, 490 | { 491 | fn into_tls_stream_with_config(self, builder: F) -> io::Result> 492 | where 493 | Self: Sized, 494 | F: FnOnce(&mut TlsConfig), 495 | { 496 | let server_name = self.connection_info().clone().host; 497 | TlsStream::wrap_with_config(self, &server_name, builder) 498 | } 499 | } 500 | 501 | #[allow(clippy::large_enum_variant)] 502 | pub enum TlsReadyStream { 503 | Plain(S), 504 | Tls(TlsStream), 505 | } 506 | 507 | impl Read for TlsReadyStream { 508 | fn read(&mut self, buf: &mut [u8]) -> io::Result { 509 | match self { 510 | 
TlsReadyStream::Plain(stream) => stream.read(buf), 511 | TlsReadyStream::Tls(stream) => stream.read(buf), 512 | } 513 | } 514 | } 515 | 516 | impl Write for TlsReadyStream { 517 | fn write(&mut self, buf: &[u8]) -> io::Result { 518 | match self { 519 | TlsReadyStream::Plain(stream) => stream.write(buf), 520 | TlsReadyStream::Tls(stream) => stream.write(buf), 521 | } 522 | } 523 | 524 | fn flush(&mut self) -> io::Result<()> { 525 | match self { 526 | TlsReadyStream::Plain(stream) => stream.flush(), 527 | TlsReadyStream::Tls(stream) => stream.flush(), 528 | } 529 | } 530 | } 531 | 532 | impl ConnectionInfoProvider for TlsReadyStream { 533 | fn connection_info(&self) -> &ConnectionInfo { 534 | match self { 535 | TlsReadyStream::Plain(stream) => stream.connection_info(), 536 | TlsReadyStream::Tls(stream) => stream.connection_info(), 537 | } 538 | } 539 | } 540 | 541 | #[cfg(feature = "mio")] 542 | impl Source for TlsReadyStream { 543 | fn register(&mut self, registry: &Registry, token: Token, interests: Interest) -> io::Result<()> { 544 | match self { 545 | TlsReadyStream::Plain(stream) => registry.register(stream, token, interests), 546 | TlsReadyStream::Tls(stream) => registry.register(stream, token, interests), 547 | } 548 | } 549 | 550 | fn reregister(&mut self, registry: &Registry, token: Token, interests: Interest) -> io::Result<()> { 551 | match self { 552 | TlsReadyStream::Plain(stream) => registry.reregister(stream, token, interests), 553 | TlsReadyStream::Tls(stream) => registry.reregister(stream, token, interests), 554 | } 555 | } 556 | 557 | fn deregister(&mut self, registry: &Registry) -> io::Result<()> { 558 | match self { 559 | TlsReadyStream::Plain(stream) => registry.deregister(stream), 560 | TlsReadyStream::Tls(stream) => registry.deregister(stream), 561 | } 562 | } 563 | } 564 | 565 | impl Selectable for TlsReadyStream { 566 | fn connected(&mut self) -> io::Result { 567 | match self { 568 | TlsReadyStream::Plain(stream) => stream.connected(), 569 | 
TlsReadyStream::Tls(stream) => stream.connected(), 570 | } 571 | } 572 | 573 | fn make_writable(&mut self) -> io::Result<()> { 574 | match self { 575 | TlsReadyStream::Plain(stream) => stream.make_writable(), 576 | TlsReadyStream::Tls(stream) => stream.make_writable(), 577 | } 578 | } 579 | 580 | fn make_readable(&mut self) -> io::Result<()> { 581 | match self { 582 | TlsReadyStream::Plain(stream) => stream.make_readable(), 583 | TlsReadyStream::Tls(stream) => stream.make_readable(), 584 | } 585 | } 586 | } 587 | -------------------------------------------------------------------------------- /src/util.rs: -------------------------------------------------------------------------------- 1 | use std::io; 2 | use std::io::ErrorKind::{UnexpectedEof, WouldBlock}; 3 | use std::mem::MaybeUninit; 4 | use std::ptr::copy_nonoverlapping; 5 | 6 | pub trait NoBlock { 7 | type Value; 8 | 9 | fn no_block(self) -> io::Result; 10 | } 11 | 12 | impl NoBlock for io::Result { 13 | type Value = usize; 14 | 15 | fn no_block(self) -> io::Result { 16 | match self { 17 | Ok(0) => Err(io::Error::from(UnexpectedEof)), 18 | Ok(n) => Ok(n), 19 | Err(err) if err.kind() == WouldBlock => Ok(0), 20 | Err(err) => Err(err), 21 | } 22 | } 23 | } 24 | 25 | impl NoBlock for io::Result<()> { 26 | type Value = (); 27 | 28 | fn no_block(self) -> io::Result { 29 | match self { 30 | Ok(()) => Ok(()), 31 | Err(err) if err.kind() == WouldBlock => Ok(()), 32 | Err(err) => Err(err), 33 | } 34 | } 35 | } 36 | 37 | #[inline] 38 | pub const unsafe fn into_array(slice: &[u8]) -> [u8; N] { 39 | unsafe { 40 | let array = MaybeUninit::<[u8; N]>::uninit(); 41 | copy_nonoverlapping(slice.as_ptr(), array.as_ptr() as *mut u8, slice.len()); 42 | array.assume_init() 43 | } 44 | } 45 | -------------------------------------------------------------------------------- /src/ws/decoder.rs: -------------------------------------------------------------------------------- 1 | use std::io; 2 | use std::io::Read; 3 | 4 | use 
crate::util::into_array; 5 | use crate::ws::{Error, ReadBuffer, WebsocketFrame, protocol}; 6 | 7 | #[derive(Debug)] 8 | pub struct Decoder { 9 | buffer: ReadBuffer, 10 | decode_state: DecodeState, 11 | fin: bool, 12 | payload_length: usize, 13 | op_code: u8, 14 | needs_more_data: bool, 15 | } 16 | 17 | #[derive(Debug)] 18 | enum DecodeState { 19 | ReadingHeader, 20 | ReadingPayloadLength, 21 | ReadingExtendedPayloadLength2, 22 | ReadingExtendedPayloadLength8, 23 | ReadingPayload, 24 | } 25 | 26 | impl Decoder { 27 | pub fn new() -> Self { 28 | Self { 29 | buffer: ReadBuffer::new(), 30 | decode_state: DecodeState::ReadingHeader, 31 | fin: false, 32 | op_code: 0, 33 | payload_length: 0, 34 | needs_more_data: true, 35 | } 36 | } 37 | 38 | #[inline] 39 | pub fn read(&mut self, stream: &mut S) -> io::Result<()> { 40 | if self.needs_more_data { 41 | self.buffer.read_all_from(stream)?; 42 | self.needs_more_data = false; 43 | } 44 | Ok(()) 45 | } 46 | 47 | #[inline] 48 | pub fn decode_next(&mut self) -> Result, Error> { 49 | loop { 50 | let available = self.buffer.available(); 51 | match self.decode_state { 52 | DecodeState::ReadingHeader => { 53 | if available > 0 { 54 | // SAFETY: available > 0 55 | let b = unsafe { self.buffer.consume_next_byte_unchecked() }; 56 | let fin = ((b & protocol::FIN_MASK) >> 7) == 1; 57 | let rsv1 = (b & protocol::RSV1_MASK) >> 6; 58 | let rsv2 = (b & protocol::RSV2_MASK) >> 5; 59 | let rsv3 = (b & protocol::RSV3_MASK) >> 4; 60 | if rsv1 + rsv2 + rsv3 != 0 { 61 | return Err(Error::Protocol("non zero RSV value received")); 62 | } 63 | self.fin = fin; 64 | let op_code = b & protocol::OP_CODE_MASK; 65 | self.op_code = op_code; 66 | self.decode_state = DecodeState::ReadingPayloadLength 67 | } else { 68 | break; 69 | } 70 | } 71 | DecodeState::ReadingPayloadLength => { 72 | if available > 0 { 73 | // SAFETY: available > 0 74 | let b = unsafe { self.buffer.consume_next_byte_unchecked() }; 75 | let mask = (b & protocol::MASK_MASK) >> 7; 76 | if mask 
== 1 { 77 | return Err(Error::Protocol("masking bit set on the server frame")); 78 | } 79 | let payload_length = b & protocol::PAYLOAD_LENGTH_MASK; 80 | self.payload_length = payload_length as usize; 81 | match payload_length { 82 | 0..=125 => self.decode_state = DecodeState::ReadingPayload, 83 | 126 => self.decode_state = DecodeState::ReadingExtendedPayloadLength2, 84 | 127 => self.decode_state = DecodeState::ReadingExtendedPayloadLength8, 85 | // we only use 7 bits 86 | _ => unsafe { std::hint::unreachable_unchecked() }, 87 | } 88 | } else { 89 | break; 90 | } 91 | } 92 | DecodeState::ReadingExtendedPayloadLength2 => { 93 | if available >= 2 { 94 | // SAFETY: available >= 2 95 | let bytes = unsafe { self.buffer.consume_next_unchecked(2) }; 96 | // SAFETY: we know bytes length is 2 97 | let payload_length = u16::from_be_bytes(unsafe { into_array(bytes) }); 98 | self.payload_length = payload_length as usize; 99 | self.decode_state = DecodeState::ReadingPayload; 100 | } else { 101 | break; 102 | } 103 | } 104 | DecodeState::ReadingExtendedPayloadLength8 => { 105 | if available >= 8 { 106 | // SAFETY: available >= 8 107 | let bytes = unsafe { self.buffer.consume_next_unchecked(8) }; 108 | // SAFETY: we know bytes length is 8 109 | let payload_length = u64::from_be_bytes(unsafe { into_array(bytes) }); 110 | self.payload_length = payload_length as usize; 111 | self.decode_state = DecodeState::ReadingPayload; 112 | } else { 113 | break; 114 | } 115 | } 116 | DecodeState::ReadingPayload => { 117 | let payload_length = self.payload_length; 118 | if available >= payload_length { 119 | // SAFETY: available >= payload_length 120 | let payload = unsafe { self.buffer.consume_next_unchecked(payload_length) }; 121 | let frame = match self.op_code { 122 | protocol::op::TEXT_FRAME => WebsocketFrame::Text(self.fin, payload), 123 | protocol::op::BINARY_FRAME => WebsocketFrame::Binary(self.fin, payload), 124 | protocol::op::CONTINUATION_FRAME => WebsocketFrame::Continuation(self.fin, 
payload), 125 | protocol::op::PING => WebsocketFrame::Ping(payload), 126 | protocol::op::CONNECTION_CLOSE => WebsocketFrame::Close(payload), 127 | _ => return Err(Error::Protocol("unknown op_code")), 128 | }; 129 | self.decode_state = DecodeState::ReadingHeader; 130 | return Ok(Some(frame)); 131 | } else { 132 | break; 133 | } 134 | } 135 | } 136 | } 137 | 138 | // await for more data 139 | self.needs_more_data = true; 140 | Ok(None) 141 | } 142 | } 143 | -------------------------------------------------------------------------------- /src/ws/ds.rs: -------------------------------------------------------------------------------- 1 | use crate::ws::{Error, State, Websocket, WebsocketFrame}; 2 | use std::io; 3 | 4 | pub trait DataSource { 5 | fn next(&self) -> Result, Error>; 6 | 7 | fn into_stream(self) -> DataSourceStream 8 | where 9 | Self: Sized, 10 | { 11 | DataSourceStream { data_source: self } 12 | } 13 | } 14 | 15 | pub struct DataSourceStream { 16 | data_source: D, 17 | } 18 | 19 | impl Websocket> { 20 | pub fn receive_next(&mut self) -> Result, Error> { 21 | self.stream.data_source.next() 22 | } 23 | } 24 | 25 | impl Websocket { 26 | pub fn from_data_source(data_source: D) -> io::Result>> { 27 | Ok(Websocket { 28 | stream: data_source.into_stream(), 29 | closed: false, 30 | state: State::connection(), 31 | }) 32 | } 33 | } 34 | 35 | #[cfg(test)] 36 | mod tests { 37 | use super::*; 38 | 39 | #[test] 40 | fn should_use_custom_data_source() { 41 | struct CustomDataSource; 42 | 43 | impl DataSource for CustomDataSource { 44 | fn next(&self) -> Result, Error> { 45 | Ok(Some(WebsocketFrame::Text(true, b"foo"))) 46 | } 47 | } 48 | 49 | let mut ws = Websocket::from_data_source(CustomDataSource).unwrap(); 50 | 51 | if let Some(WebsocketFrame::Text(_fin, data)) = ws.receive_next().unwrap() { 52 | assert_eq!(b"foo", data) 53 | } else { 54 | panic!("test failed") 55 | } 56 | } 57 | } 58 | -------------------------------------------------------------------------------- 
/src/ws/encoder.rs: -------------------------------------------------------------------------------- 1 | use std::io; 2 | use std::io::Write; 3 | 4 | use crate::ws::protocol; 5 | 6 | #[inline] 7 | pub fn send(stream: &mut S, fin: bool, op_code: u8, body: Option<&[u8]>) -> io::Result<()> { 8 | let mut header = 0u8; 9 | if fin { 10 | header |= protocol::FIN_MASK; 11 | } 12 | header |= op_code; 13 | stream.write_all(&header.to_be_bytes())?; 14 | let mut payload_length = 0u8; 15 | payload_length |= protocol::MASK_MASK; 16 | if let Some(body) = body { 17 | let len = body.len(); 18 | if len <= 125 { 19 | payload_length |= len as u8; 20 | stream.write_all(&payload_length.to_be_bytes())?; 21 | } else if len <= u16::MAX as usize { 22 | payload_length |= 126; 23 | let extended_payload_length = len as u16; 24 | stream.write_all(&payload_length.to_be_bytes())?; 25 | stream.write_all(&extended_payload_length.to_be_bytes())?; 26 | } else if len <= u64::MAX as usize { 27 | payload_length |= 127; 28 | let extended_payload_length = len as u64; 29 | stream.write_all(&payload_length.to_be_bytes())?; 30 | stream.write_all(&extended_payload_length.to_be_bytes())?; 31 | } 32 | } else { 33 | stream.write_all(&payload_length.to_be_bytes())?; 34 | } 35 | let masking_key = 0u32; 36 | stream.write_all(&masking_key.to_be_bytes())?; 37 | if let Some(body) = body { 38 | // we can send plain text as masking key is set to zero on purpose 39 | // this is done for performance reason as it will make XOR no-op 40 | stream.write_all(body)?; 41 | } 42 | stream.flush()?; 43 | Ok(()) 44 | } 45 | -------------------------------------------------------------------------------- /src/ws/error.rs: -------------------------------------------------------------------------------- 1 | use std::array::TryFromSliceError; 2 | use std::io; 3 | use thiserror::Error; 4 | use url::ParseError; 5 | 6 | #[derive(Error, Debug)] 7 | pub enum Error { 8 | #[error("the peer has sent the close frame: status code {0}, body: 
{1}")] 9 | ReceivedCloseFrame(u16, String), 10 | #[error("websocket protocol error: {0}")] 11 | Protocol(&'static str), 12 | #[error("the websocket is closed and can be dropped")] 13 | Closed, 14 | #[error("IO error: {0}")] 15 | IO(#[from] io::Error), 16 | #[error("url parse error: {0}")] 17 | InvalidUrl(#[from] ParseError), 18 | #[error("slice error: {0}")] 19 | SliceError(#[from] TryFromSliceError), 20 | } 21 | 22 | impl From for io::Error { 23 | fn from(value: Error) -> Self { 24 | io::Error::other(value) 25 | } 26 | } 27 | -------------------------------------------------------------------------------- /src/ws/handshake.rs: -------------------------------------------------------------------------------- 1 | use std::collections::VecDeque; 2 | use std::io; 3 | use std::io::ErrorKind::WouldBlock; 4 | use std::io::{Read, Write}; 5 | 6 | use base64::Engine; 7 | use base64::engine::general_purpose; 8 | use http::StatusCode; 9 | use httparse::Response; 10 | use rand::{Rng, rng}; 11 | 12 | use crate::buffer::ReadBuffer; 13 | use crate::ws::Error; 14 | use crate::ws::handshake::HandshakeState::{Completed, NotStarted, Pending}; 15 | 16 | #[derive(Debug)] 17 | pub struct Handshaker { 18 | buffer: ReadBuffer<1>, 19 | state: HandshakeState, 20 | server_name: String, 21 | endpoint: String, 22 | pending_msg_buffer: VecDeque<(u8, bool, Option>)>, 23 | } 24 | 25 | #[derive(Debug, Copy, Clone, Eq, PartialEq)] 26 | pub enum HandshakeState { 27 | NotStarted, 28 | Pending, 29 | Completed, 30 | } 31 | 32 | impl Handshaker { 33 | pub fn new(server_name: &str, endpoint: &str) -> Self { 34 | Self { 35 | buffer: ReadBuffer::new(), 36 | state: NotStarted, 37 | server_name: server_name.to_string(), 38 | endpoint: endpoint.to_string(), 39 | pending_msg_buffer: VecDeque::with_capacity(256), 40 | } 41 | } 42 | 43 | #[cold] 44 | pub fn read(&mut self, stream: &mut S) -> io::Result<()> { 45 | if self.state == Pending { 46 | self.buffer.read_from(stream)?; 47 | } 48 | Ok(()) 49 | } 50 | 51 | 
#[cold] 52 | pub fn perform_handshake(&mut self, stream: &mut S) -> io::Result<()> { 53 | match self.state { 54 | NotStarted => { 55 | self.send_handshake_request(stream)?; 56 | Err(io::Error::from(WouldBlock)) 57 | } 58 | Pending => { 59 | let available = self.buffer.available(); 60 | if available >= 4 && self.buffer.view_last(4) == b"\r\n\r\n" { 61 | // decode http response 62 | let mut headers = [httparse::EMPTY_HEADER; 64]; 63 | let mut response = Response::new(&mut headers); 64 | response.parse(self.buffer.view()).map_err(io::Error::other)?; 65 | if response.code.unwrap() != StatusCode::SWITCHING_PROTOCOLS.as_u16() { 66 | return Err(io::Error::other("unable to switch protocols")); 67 | } 68 | self.state = Completed; 69 | } 70 | Err(io::Error::from(WouldBlock)) 71 | } 72 | Completed => Ok(()), 73 | } 74 | } 75 | 76 | #[cold] 77 | pub fn buffer_message(&mut self, fin: bool, op: u8, body: Option<&[u8]>) { 78 | let body = body.map(|body| body.to_vec()); 79 | self.pending_msg_buffer.push_back((op, fin, body)) 80 | } 81 | 82 | #[cold] 83 | pub fn drain_pending_message_buffer(&mut self, stream: &mut S, mut send: F) -> Result<(), Error> 84 | where 85 | S: Write, 86 | F: FnMut(&mut S, bool, u8, Option<&[u8]>) -> io::Result<()>, 87 | { 88 | while let Some((op, fin, body)) = self.pending_msg_buffer.pop_front() { 89 | send(stream, fin, op, body.as_deref())?; 90 | } 91 | Ok(()) 92 | } 93 | 94 | fn send_handshake_request(&mut self, stream: &mut S) -> io::Result<()> { 95 | stream.write_all(format!("GET {} HTTP/1.1\r\n", self.endpoint).as_bytes())?; 96 | stream.write_all(format!("Host: {}\r\n", self.server_name).as_bytes())?; 97 | stream.write_all(b"Upgrade: websocket\r\n")?; 98 | stream.write_all(b"Connection: upgrade\r\n")?; 99 | stream.write_all(format!("Sec-WebSocket-Key: {}\r\n", generate_nonce()).as_bytes())?; 100 | stream.write_all(b"Sec-WebSocket-Version: 13\r\n")?; 101 | stream.write_all(b"\r\n")?; 102 | stream.flush()?; 103 | self.state = Pending; 104 | Ok(()) 105 | 
} 106 | } 107 | 108 | fn generate_nonce() -> String { 109 | let mut rng = rng(); 110 | let nonce_bytes: [u8; 16] = rng.random(); 111 | general_purpose::STANDARD.encode(nonce_bytes) 112 | } 113 | -------------------------------------------------------------------------------- /src/ws/mod.rs: -------------------------------------------------------------------------------- 1 | //! Websocket client protocol implementation. 2 | //! 3 | //! ## Examples 4 | //! 5 | //! Create a TLS websocket from a stream. 6 | //!```no_run 7 | //! use std::net::TcpStream; 8 | //! use boomnet::stream::{BindAndConnect, ConnectionInfo}; 9 | //! use boomnet::stream::buffer::IntoBufferedStream; 10 | //! use boomnet::stream::tls::IntoTlsStream; 11 | //! use boomnet::ws::IntoWebsocket; 12 | //! 13 | //! let mut ws = ConnectionInfo::new("stream.binance.com", 9443) 14 | //! .into_tcp_stream().unwrap() 15 | //! .into_tls_stream().unwrap() 16 | //! .into_default_buffered_stream() 17 | //! .into_websocket("/ws"); 18 | //! ``` 19 | //! 20 | //! Quickly create websocket from a valid url (for debugging purposes only). 21 | //! ```no_run 22 | //! use boomnet::ws::TryIntoTlsReadyWebsocket; 23 | //! 24 | //! let mut ws = "wss://stream.binance.com/ws".try_into_tls_ready_websocket().unwrap(); 25 | //! ``` 26 | //! 27 | //! Receive messages in a batch for optimal performance. 28 | //!```no_run 29 | //! use std::io::{Read, Write}; 30 | //! use boomnet::ws::{Websocket, WebsocketFrame}; 31 | //! 32 | //! fn consume_batch(ws: &mut Websocket) -> std::io::Result<()> { 33 | //! for frame in ws.read_batch()? { 34 | //! if let WebsocketFrame::Text(fin, body) = frame? { 35 | //! println!("({fin}) {}", String::from_utf8_lossy(body)); 36 | //! } 37 | //! } 38 | //! Ok(()) 39 | //! } 40 | //! ``` 41 | //! 42 | //! Receive messages at most one at a tine. If possible, use batch mode instead. 43 | //!```no_run 44 | //! use std::io::{Read, Write}; 45 | //! use boomnet::ws::{Websocket, WebsocketFrame}; 46 | //! 47 | //! 
fn consume_individually(ws: &mut Websocket) -> std::io::Result<()> { 48 | //! if let Some(frame) = ws.receive_next() { 49 | //! if let WebsocketFrame::Text(fin, body) = frame? { 50 | //! println!("({fin}) {}", String::from_utf8_lossy(body)); 51 | //! } 52 | //! } 53 | //! Ok(()) 54 | //! } 55 | //! ``` 56 | 57 | use crate::buffer; 58 | use crate::service::select::Selectable; 59 | #[cfg(any(feature = "rustls", feature = "openssl"))] 60 | use crate::stream::tls::{IntoTlsStream, TlsReadyStream, TlsStream}; 61 | use crate::stream::{BindAndConnect, ConnectionInfoProvider}; 62 | use crate::util::NoBlock; 63 | use crate::ws::Error::{Closed, ReceivedCloseFrame}; 64 | use crate::ws::decoder::Decoder; 65 | use crate::ws::handshake::Handshaker; 66 | #[cfg(feature = "mio")] 67 | use mio::{Interest, Registry, Token, event::Source}; 68 | use std::fmt::Debug; 69 | use std::io; 70 | use std::io::ErrorKind::WouldBlock; 71 | use std::io::{Read, Write}; 72 | use thiserror::Error; 73 | use url::Url; 74 | 75 | // re-export 76 | pub use crate::ws::error::Error; 77 | 78 | mod decoder; 79 | pub mod ds; 80 | mod encoder; 81 | mod error; 82 | mod handshake; 83 | mod protocol; 84 | pub mod util; 85 | 86 | type ReadBuffer = buffer::ReadBuffer<4096>; 87 | 88 | /// Supported web socket frame variants. 89 | pub enum WebsocketFrame { 90 | /// Server has sent ping frame that will generate automatic pong response. This frame is not 91 | /// exposed to the user. 92 | Ping(&'static [u8]), 93 | Pong(&'static [u8]), 94 | Text(bool, &'static [u8]), 95 | Binary(bool, &'static [u8]), 96 | Continuation(bool, &'static [u8]), 97 | /// Server has sent close frame. The websocket will be closed as a result. This frame is not 98 | /// exposed to the user. 99 | Close(&'static [u8]), 100 | } 101 | 102 | /// Websocket client that owns underlying stream. 
103 | #[derive(Debug)] 104 | pub struct Websocket { 105 | stream: S, 106 | closed: bool, 107 | state: State, 108 | } 109 | 110 | impl Websocket { 111 | pub fn new(stream: S, server_name: &str, endpoint: &str) -> Websocket { 112 | Self { 113 | stream, 114 | closed: false, 115 | state: State::handshake(server_name, endpoint), 116 | } 117 | } 118 | 119 | /// Checks if the websocket is closed. This can be result of an IO error or the other side 120 | /// sending `WebsocketFrame::Closed`. 121 | pub const fn closed(&self) -> bool { 122 | self.closed 123 | } 124 | 125 | /// Checks if the handshake has completed successfully. If attempt is made to send a message 126 | /// while the handshake is pending the message will be buffered and dispatched once handshake 127 | /// has finished. 128 | #[inline] 129 | pub const fn handshake_complete(&self) -> bool { 130 | match self.state { 131 | State::Handshake(_) => false, 132 | State::Connection(_) => true, 133 | } 134 | } 135 | } 136 | 137 | impl Websocket { 138 | /// Allows to decode and iterate over incoming messages in a batch efficient way. It will perform 139 | /// single network read operation if there is no more data available for processing. It is possible 140 | /// to receive more than one message from a single network read and when no messages are available 141 | /// in the current batch, the iterator will yield `None`. 142 | /// 143 | /// ## Examples 144 | /// 145 | /// Process incoming frames in a batch using iterator, 146 | /// ```no_run 147 | /// use std::io::{Read, Write}; 148 | /// use boomnet::ws::{Websocket, WebsocketFrame}; 149 | /// 150 | /// fn process(ws: &mut Websocket) -> std::io::Result<()> { 151 | /// for frame in ws.read_batch()? { 152 | /// if let (WebsocketFrame::Text(fin, data)) = frame? 
{ 153 | /// println!("({fin}) {}", String::from_utf8_lossy(data)); 154 | /// } 155 | /// } 156 | /// Ok(()) 157 | /// } 158 | /// ``` 159 | /// 160 | /// Read frames one by one without iterator, 161 | /// ```no_run 162 | /// use std::io::{Read, Write}; 163 | /// use boomnet::ws::{Websocket, WebsocketFrame}; 164 | /// 165 | /// fn process(ws: &mut Websocket) -> std::io::Result<()> { 166 | /// let mut batch = ws.read_batch()?; 167 | /// while let Some(frame) = batch.receive_next() { 168 | /// if let (WebsocketFrame::Text(fin, data)) = frame? { 169 | /// println!("({fin}) {}", String::from_utf8_lossy(data)); 170 | /// } 171 | /// } 172 | /// Ok(()) 173 | /// } 174 | /// ``` 175 | #[inline] 176 | pub fn read_batch(&mut self) -> Result, Error> { 177 | match self.state.read(&mut self.stream).no_block() { 178 | Ok(()) => Ok(Batch { websocket: self }), 179 | Err(err) => { 180 | self.closed = true; 181 | Err(err)? 182 | } 183 | } 184 | } 185 | 186 | #[inline] 187 | pub fn receive_next(&mut self) -> Option> { 188 | match self.read_batch() { 189 | Ok(mut batch) => batch.receive_next(), 190 | Err(err) => Some(Err(err)), 191 | } 192 | } 193 | 194 | #[inline] 195 | pub fn send_text(&mut self, fin: bool, body: Option<&[u8]>) -> Result<(), Error> { 196 | self.send(fin, protocol::op::TEXT_FRAME, body) 197 | } 198 | 199 | #[inline] 200 | pub fn send_binary(&mut self, fin: bool, body: Option<&[u8]>) -> Result<(), Error> { 201 | self.send(fin, protocol::op::BINARY_FRAME, body) 202 | } 203 | 204 | #[inline] 205 | pub fn send_pong(&mut self, body: Option<&[u8]>) -> Result<(), Error> { 206 | self.send(true, protocol::op::PONG, body) 207 | } 208 | 209 | #[inline] 210 | pub fn send_ping(&mut self, body: Option<&[u8]>) -> Result<(), Error> { 211 | self.send(true, protocol::op::PING, body) 212 | } 213 | 214 | #[inline] 215 | fn next(&mut self) -> Result, Error> { 216 | self.ensure_not_closed()?; 217 | match self.state.next(&mut self.stream) { 218 | Ok(frame) => Ok(frame), 219 | Err(err) => { 
220 | self.closed = true; 221 | Err(err)? 222 | } 223 | } 224 | } 225 | 226 | #[inline] 227 | fn send(&mut self, fin: bool, op_code: u8, body: Option<&[u8]>) -> Result<(), Error> { 228 | self.ensure_not_closed()?; 229 | match self.state.send(&mut self.stream, fin, op_code, body) { 230 | Ok(()) => Ok(()), 231 | Err(err) => { 232 | self.closed = true; 233 | Err(err)? 234 | } 235 | } 236 | } 237 | 238 | #[inline] 239 | const fn ensure_not_closed(&self) -> Result<(), Error> { 240 | if self.closed { 241 | return Err(Closed); 242 | } 243 | Ok(()) 244 | } 245 | } 246 | 247 | #[cfg(feature = "mio")] 248 | impl Source for Websocket { 249 | fn register(&mut self, registry: &Registry, token: Token, interests: Interest) -> io::Result<()> { 250 | registry.register(&mut self.stream, token, interests) 251 | } 252 | 253 | fn reregister(&mut self, registry: &Registry, token: Token, interests: Interest) -> io::Result<()> { 254 | registry.reregister(&mut self.stream, token, interests) 255 | } 256 | 257 | fn deregister(&mut self, registry: &Registry) -> io::Result<()> { 258 | registry.deregister(&mut self.stream) 259 | } 260 | } 261 | 262 | impl Selectable for Websocket { 263 | fn connected(&mut self) -> io::Result { 264 | self.stream.connected() 265 | } 266 | 267 | fn make_writable(&mut self) -> io::Result<()> { 268 | self.stream.make_writable() 269 | } 270 | 271 | fn make_readable(&mut self) -> io::Result<()> { 272 | self.stream.make_readable() 273 | } 274 | } 275 | 276 | #[derive(Debug)] 277 | enum State { 278 | Handshake(Handshaker), 279 | Connection(Decoder), 280 | } 281 | 282 | impl State { 283 | pub fn handshake(server_name: &str, endpoint: &str) -> Self { 284 | Self::Handshake(Handshaker::new(server_name, endpoint)) 285 | } 286 | 287 | pub fn connection() -> Self { 288 | Self::Connection(Decoder::new()) 289 | } 290 | } 291 | 292 | impl State { 293 | #[inline] 294 | fn read(&mut self, stream: &mut S) -> io::Result<()> { 295 | match self { 296 | State::Handshake(handshake) => 
handshake.read(stream), 297 | State::Connection(decoder) => decoder.read(stream), 298 | } 299 | } 300 | 301 | #[inline] 302 | fn next(&mut self, stream: &mut S) -> Result, Error> { 303 | match self { 304 | State::Handshake(handshake) => match handshake.perform_handshake(stream) { 305 | Ok(()) => { 306 | handshake.drain_pending_message_buffer(stream, encoder::send)?; 307 | *self = State::connection(); 308 | Ok(None) 309 | } 310 | Err(err) if err.kind() == WouldBlock => Ok(None), 311 | Err(err) => Err(err)?, 312 | }, 313 | State::Connection(decoder) => match decoder.decode_next() { 314 | Ok(Some(WebsocketFrame::Ping(payload))) => { 315 | self.send(stream, true, protocol::op::PONG, Some(payload))?; 316 | Ok(None) 317 | } 318 | Ok(Some(WebsocketFrame::Close(payload))) => { 319 | let _ = self.send(stream, true, protocol::op::CONNECTION_CLOSE, Some(payload)); 320 | let (status_code, body) = payload.split_at(std::mem::size_of::()); 321 | let status_code = u16::from_be_bytes(status_code.try_into()?); 322 | let body = String::from_utf8_lossy(body).to_string(); 323 | Err(ReceivedCloseFrame(status_code, body)) 324 | } 325 | Ok(frame) => Ok(frame), 326 | Err(err) => Err(err)?, 327 | }, 328 | } 329 | } 330 | 331 | #[inline] 332 | fn send(&mut self, stream: &mut S, fin: bool, op_code: u8, body: Option<&[u8]>) -> Result<(), Error> { 333 | match self { 334 | State::Handshake(handshake) => { 335 | handshake.buffer_message(fin, op_code, body); 336 | Ok(()) 337 | } 338 | State::Connection(_) => { 339 | encoder::send(stream, fin, op_code, body)?; 340 | Ok(()) 341 | } 342 | } 343 | } 344 | } 345 | 346 | /// Represents a batch of 0 to N websocket frames since the last network read that are ready to be decoded. 
347 | pub struct Batch<'a, S> { 348 | websocket: &'a mut Websocket, 349 | } 350 | 351 | impl<'a, S: Read + Write> IntoIterator for Batch<'a, S> { 352 | type Item = Result; 353 | type IntoIter = BatchIter<'a, S>; 354 | 355 | fn into_iter(self) -> Self::IntoIter { 356 | BatchIter { batch: self } 357 | } 358 | } 359 | 360 | impl Batch<'_, S> { 361 | /// Try to decode next frame from the underlying `Batch`. If no more frames are available it 362 | /// will return `None`. 363 | pub fn receive_next(&mut self) -> Option> { 364 | self.websocket.next().transpose() 365 | } 366 | } 367 | 368 | /// Iterator that owns the current `Batch`. When no more frames are available to be decoded in the buffer 369 | /// it will yield `None`. 370 | pub struct BatchIter<'a, S> { 371 | batch: Batch<'a, S>, 372 | } 373 | 374 | impl Iterator for BatchIter<'_, S> { 375 | type Item = Result; 376 | 377 | fn next(&mut self) -> Option { 378 | self.batch.receive_next() 379 | } 380 | } 381 | 382 | pub trait IntoWebsocket { 383 | fn into_websocket(self, endpoint: &str) -> Websocket 384 | where 385 | Self: Sized; 386 | } 387 | 388 | impl IntoWebsocket for T 389 | where 390 | T: Read + Write + ConnectionInfoProvider, 391 | { 392 | fn into_websocket(self, endpoint: &str) -> Websocket 393 | where 394 | Self: Sized, 395 | { 396 | let host = self.connection_info().host().to_owned(); 397 | Websocket::new(self, &host, endpoint) 398 | } 399 | } 400 | 401 | #[cfg(any(feature = "rustls", feature = "openssl"))] 402 | pub trait IntoTlsWebsocket { 403 | fn into_tls_websocket(self, endpoint: &str) -> io::Result>> 404 | where 405 | Self: Sized; 406 | } 407 | 408 | #[cfg(any(feature = "rustls", feature = "openssl"))] 409 | impl IntoTlsWebsocket for T 410 | where 411 | T: Read + Write + Debug + ConnectionInfoProvider, 412 | { 413 | fn into_tls_websocket(self, endpoint: &str) -> io::Result>> 414 | where 415 | Self: Sized, 416 | { 417 | Ok(self.into_tls_stream()?.into_websocket(endpoint)) 418 | } 419 | } 420 | 421 | 
#[cfg(any(feature = "rustls", feature = "openssl"))] 422 | pub trait TryIntoTlsReadyWebsocket { 423 | fn try_into_tls_ready_websocket(self) -> io::Result>> 424 | where 425 | Self: Sized; 426 | } 427 | 428 | #[cfg(any(feature = "rustls", feature = "openssl"))] 429 | impl TryIntoTlsReadyWebsocket for T 430 | where 431 | T: AsRef, 432 | { 433 | fn try_into_tls_ready_websocket(self) -> io::Result>> 434 | where 435 | Self: Sized, 436 | { 437 | let url = Url::parse(self.as_ref()).map_err(io::Error::other)?; 438 | 439 | let addr = url.socket_addrs(|| match url.scheme() { 440 | "ws" => Some(80), 441 | "wss" => Some(443), 442 | _ => None, 443 | })?; 444 | 445 | let endpoint = match url.query() { 446 | Some(query) => format!("{}?{}", url.path(), query), 447 | None => url.path().to_string(), 448 | }; 449 | 450 | let stream = std::net::TcpStream::bind_and_connect(addr[0], None, None)?; 451 | 452 | let tls_ready_stream = match url.scheme() { 453 | "ws" => Ok(TlsReadyStream::Plain(stream)), 454 | "wss" => Ok(TlsReadyStream::Tls(TlsStream::wrap(stream, url.host_str().unwrap()).unwrap())), 455 | scheme => Err(io::Error::other(format!("unrecognised url scheme: {}", scheme))), 456 | }?; 457 | 458 | Ok(Websocket::new(tls_ready_stream, url.host_str().unwrap(), &endpoint)) 459 | } 460 | } 461 | -------------------------------------------------------------------------------- /src/ws/protocol.rs: -------------------------------------------------------------------------------- 1 | pub const FIN_MASK: u8 = 0b1000_0000; 2 | pub const RSV1_MASK: u8 = 0b0100_0000; 3 | pub const RSV2_MASK: u8 = 0b0010_0000; 4 | pub const RSV3_MASK: u8 = 0b0001_0000; 5 | pub const OP_CODE_MASK: u8 = 0b0000_1111; 6 | pub const MASK_MASK: u8 = 0b1000_0000; 7 | pub const PAYLOAD_LENGTH_MASK: u8 = 0b0111_1111; 8 | 9 | pub mod op { 10 | pub const CONTINUATION_FRAME: u8 = 0x0; 11 | pub const TEXT_FRAME: u8 = 0x1; 12 | pub const BINARY_FRAME: u8 = 0x2; 13 | pub const CONNECTION_CLOSE: u8 = 0x8; 14 | pub const PING: u8 
= 0x9; 15 | pub const PONG: u8 = 0xA; 16 | } 17 | -------------------------------------------------------------------------------- /src/ws/util.rs: -------------------------------------------------------------------------------- 1 | use crate::stream::ConnectionInfo; 2 | use crate::ws::Error; 3 | use std::io; 4 | use url::Url; 5 | 6 | pub fn parse_url(url: &str) -> Result<(ConnectionInfo, String, bool), Error> { 7 | let url = Url::parse(url)?; 8 | let connection_info = ConnectionInfo::try_from(url.clone())?; 9 | let endpoint = match url.query() { 10 | Some(query) => format!("{}?{}", url.path(), query), 11 | None => url.path().to_string(), 12 | }; 13 | let secure = match url.scheme() { 14 | "ws" => false, 15 | "wss" => true, 16 | scheme => Err(io::Error::other(format!("unrecognised url scheme: {}", scheme)))?, 17 | }; 18 | Ok((connection_info, endpoint, secure)) 19 | } 20 | --------------------------------------------------------------------------------