├── .github └── workflows │ ├── build.yml │ └── ci.yml ├── .gitignore ├── Cargo.toml ├── LICENSE ├── README.md ├── rustfmt.toml ├── src ├── bleed │ ├── mod.rs │ ├── store.rs │ └── writer.rs ├── endpoint │ ├── async_client.rs │ ├── async_server.rs │ ├── client.rs │ ├── detail │ │ ├── accept.rs │ │ ├── connect.rs │ │ └── mod.rs │ ├── mod.rs │ └── server.rs ├── error │ ├── ctrl.rs │ ├── frame.rs │ ├── handshake.rs │ └── mod.rs ├── frame │ ├── flag.rs │ ├── length.rs │ ├── mask.rs │ └── mod.rs ├── handshake │ ├── key.rs │ ├── mod.rs │ ├── request.rs │ └── response.rs ├── lib.rs ├── role │ ├── client.rs │ ├── mod.rs │ └── server.rs └── stream │ ├── async_read.rs │ ├── async_write.rs │ ├── ctrl.rs │ ├── detail │ ├── mod.rs │ ├── read.rs │ └── write.rs │ ├── mod.rs │ ├── read.rs │ ├── special.rs │ ├── state.rs │ └── write.rs └── tests ├── async_bidi_copy.rs ├── async_echo.rs ├── async_handshake.rs ├── async_read_write.rs ├── auto_mask.rs ├── sync_bidi_copy.rs ├── sync_echo.rs ├── sync_handshake.rs └── sync_read_write.rs /.github/workflows/build.yml: -------------------------------------------------------------------------------- 1 | name: build 2 | on: 3 | push: 4 | branches: [ master ] 5 | pull_request: 6 | branches: [ master ] 7 | 8 | env: 9 | CARGO_TERM_COLOR: always 10 | 11 | jobs: 12 | build-corss: 13 | runs-on: ubuntu-latest 14 | strategy: 15 | matrix: 16 | target: 17 | - x86_64-unknown-linux-gnu 18 | - x86_64-unknown-linux-musl 19 | - x86_64-linux-android 20 | - x86_64-pc-windows-gnu 21 | - aarch64-unknown-linux-gnu 22 | - aarch64-unknown-linux-musl 23 | - aarch64-linux-android 24 | steps: 25 | - uses: actions/checkout@v4 26 | - name: install toolchain 27 | uses: dtolnay/rust-toolchain@master 28 | with: 29 | toolchain: nightly 30 | targets: ${{ matrix.target }} 31 | - name: compile 32 | run: cargo build --release --target=${{ matrix.target }} 33 | build-windows: 34 | runs-on: windows-latest 35 | strategy: 36 | matrix: 37 | target: 38 | - x86_64-pc-windows-msvc 39 | 
steps: 40 | - uses: actions/checkout@v4 41 | - name: install toolchain 42 | uses: dtolnay/rust-toolchain@master 43 | with: 44 | toolchain: nightly 45 | targets: ${{ matrix.target }} 46 | - name: compile 47 | run: cargo build --release --target=${{ matrix.target }} 48 | build-apple: 49 | runs-on: macos-latest 50 | strategy: 51 | matrix: 52 | target: 53 | - x86_64-apple-darwin 54 | - aarch64-apple-darwin 55 | - aarch64-apple-ios 56 | steps: 57 | - uses: actions/checkout@v4 58 | - name: install toolchain 59 | uses: dtolnay/rust-toolchain@master 60 | with: 61 | toolchain: nightly 62 | targets: ${{ matrix.target }} 63 | - name: compile 64 | run: cargo build --release --target=${{ matrix.target }} 65 | -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: ci 2 | on: 3 | push: 4 | branches: [ master ] 5 | pull_request: 6 | branches: [ master ] 7 | jobs: 8 | clippy: 9 | runs-on: ubuntu-latest 10 | steps: 11 | - uses: actions/checkout@v4 12 | - uses: dtolnay/rust-toolchain@master 13 | with: 14 | toolchain: nightly 15 | components: clippy 16 | - run: cargo clippy --all-features 17 | test-debug: 18 | runs-on: ubuntu-latest 19 | steps: 20 | - uses: actions/checkout@v4 21 | - uses: dtolnay/rust-toolchain@master 22 | with: 23 | toolchain: nightly 24 | components: clippy 25 | - run: cargo test -v --no-fail-fast --all-features 26 | test-release: 27 | runs-on: ubuntu-latest 28 | steps: 29 | - uses: actions/checkout@v4 30 | - uses: dtolnay/rust-toolchain@master 31 | with: 32 | toolchain: nightly 33 | components: clippy 34 | - run: cargo test -v --no-fail-fast --release --all-features 35 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | /target 2 | /Cargo.lock 3 | .vscode/ 4 | 
-------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "lightws" 3 | version = "0.6.12" 4 | authors = ["zephyr "] 5 | description = "Lightweight websocket implement for stream transmission." 6 | repository = "https://github.com/zephyrchien/lightws" 7 | readme = "README.md" 8 | documentation = "https://docs.rs/lightws" 9 | keywords = ["websocket", "network", "stream", "async"] 10 | edition = "2021" 11 | license = "MIT" 12 | 13 | [features] 14 | default = ["async"] 15 | async = ["tokio"] 16 | unsafe_auto_mask_write = [] 17 | 18 | [dependencies] 19 | cfg-if = "1" 20 | rand = "0.8" 21 | sha1 = "0.10" 22 | base64 = "0.21" 23 | httparse = "1" 24 | tokio = { version = "1", optional = true } 25 | 26 | 27 | [dev-dependencies] 28 | log = "0.4" 29 | env_logger = "0.10" 30 | tokio = { version = "1", features = ["full"] } 31 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 zephyr 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Lightws 2 | 3 | ![Lightws](https://github.com/zephyrchien/lightws/workflows/ci/badge.svg) 4 | ![Lightws](https://github.com/zephyrchien/lightws/workflows/build/badge.svg) 5 | [![Released API docs](https://docs.rs/lightws/badge.svg)](https://docs.rs/lightws) 6 | [![crates.io](https://img.shields.io/crates/v/lightws.svg)](https://crates.io/crates/lightws) 7 | 8 | Lightweight websocket implementation for stream transmission. 9 | 10 | ## Features 11 | 12 | - Avoid heap allocation. 13 | - Avoid buffering frame payload. 14 | - Use vectored-io if available. 15 | - Transparent Read/Write over the underlying IO source.
16 | 17 | ## High-level API 18 | 19 | [role, endpoint, stream] 20 | 21 | Std: 22 | 23 | ```rust 24 | { 25 | // handshake 26 | let mut stream = Endpoint::connect(tcp, buf, host, path)?; 27 | // read some data 28 | stream.read(&mut buf)?; 29 | // write some data 30 | stream.write(&buf)?; 31 | } 32 | ``` 33 | 34 | Async: 35 | 36 | ```rust 37 | { 38 | // handshake 39 | let mut stream = Endpoint::connect_async(tcp, buf, host, path).await?; 40 | // read some data 41 | stream.read(&mut buf).await?; 42 | // write some data 43 | stream.write(&buf).await?; 44 | } 45 | ``` 46 | 47 | ## Low-level API 48 | 49 | [frame, handshake] 50 | 51 | Frame: 52 | 53 | ```rust 54 | { 55 | // encode a frame head 56 | let head = FrameHead::new(...); 57 | let offset = unsafe { 58 | head.encode_unchecked(&mut buf) 59 | }; 60 | 61 | // decode a frame head 62 | let (head, offset) = FrameHead::decode(&buf).unwrap(); 63 | } 64 | ``` 65 | 66 | Handshake: 67 | 68 | ```rust 69 | { 70 | // make a client handshake request 71 | let request = Request::new(b"/ws", b"example.com", b"sec-key.."); 72 | let offset = request.encode(&mut buf).unwrap(); 73 | 74 | // parse a server handshake response 75 | let mut custom_headers = HttpHeader::new_storage(); 76 | let mut response = Response::new_storage(&mut custom_headers); 77 | let offset = response.decode(&buf).unwrap(); 78 | } 79 | ``` 80 | -------------------------------------------------------------------------------- /rustfmt.toml: -------------------------------------------------------------------------------- 1 | max_width = 100 2 | reorder_imports = false 3 | reorder_modules = false 4 | #fn_args_layout = "Vertical" 5 | fn_single_line = true 6 | -------------------------------------------------------------------------------- /src/bleed/mod.rs: -------------------------------------------------------------------------------- 1 | //!
Some dirty work 2 | 3 | mod store; 4 | mod writer; 5 | 6 | pub(crate) use store::Store; 7 | pub(crate) use writer::Writer; 8 | 9 | #[inline] 10 | pub(crate) const unsafe fn slice(slice: &[T], beg: usize, end: usize) -> &[T] { 11 | let ptr = slice.as_ptr().add(beg); 12 | &*std::ptr::slice_from_raw_parts(ptr, end - beg) 13 | } 14 | 15 | #[inline] 16 | pub(crate) const unsafe fn slice_mut(slice: &mut [T], beg: usize, end: usize) -> &mut [T] { 17 | let ptr = slice.as_mut_ptr().add(beg); 18 | &mut *std::ptr::slice_from_raw_parts_mut(ptr, end - beg) 19 | } 20 | 21 | #[inline] 22 | pub(crate) const unsafe fn slice_to_array(slice: &[T]) -> &[T; N] { 23 | &*(slice as *const [T] as *const [T; N]) 24 | } 25 | 26 | #[inline] 27 | #[allow(unused)] 28 | #[allow(invalid_reference_casting)] 29 | #[allow(clippy::mut_from_ref)] 30 | pub(crate) const unsafe fn const_cast(x: &T) -> &mut T { 31 | let const_ptr = x as *const T; 32 | let mut_ptr = const_ptr.cast_mut(); 33 | &mut *mut_ptr 34 | } 35 | 36 | #[cfg(test)] 37 | mod test { 38 | use super::*; 39 | 40 | #[test] 41 | fn unsafe_slice() { 42 | let buf: Vec = std::iter::repeat(rand::random::()).take(1024).collect(); 43 | let mut buf2 = buf.clone(); 44 | 45 | macro_rules! s { 46 | ($beg: expr, $end: expr) => { 47 | assert_eq!(&buf[$beg..$end], unsafe { slice(&buf, $beg, $end) }); 48 | assert_eq!(&buf[$beg..$end], unsafe { 49 | slice_mut(&mut buf2, $beg, $end) 50 | }); 51 | }; 52 | } 53 | 54 | for end in 1..1024 { 55 | for beg in 0..end { 56 | s!(beg, end); 57 | } 58 | } 59 | } 60 | 61 | #[test] 62 | fn unsafe_slice_to_array() { 63 | let buf: Vec = std::iter::repeat(rand::random::()).take(4096).collect(); 64 | 65 | macro_rules! 
s { 66 | ($beg: expr, $len: expr) => { 67 | let slice = &buf[$beg..$beg + $len]; 68 | let array1: [_; $len] = slice.try_into().unwrap(); 69 | let array2: [_; $len] = *unsafe { slice_to_array::<_, $len>(slice) }; 70 | 71 | assert_eq!(array1, array2); 72 | }; 73 | } 74 | 75 | for beg in 0..=2048 { 76 | s!(beg, 2048); 77 | } 78 | } 79 | } 80 | -------------------------------------------------------------------------------- /src/bleed/store.rs: -------------------------------------------------------------------------------- 1 | use super::{slice, slice_mut}; 2 | 3 | /// Buffer on stack. 4 | #[derive(Debug, Clone, Copy)] 5 | pub struct Store { 6 | rd: u8, 7 | wr: u8, 8 | buf: [u8; N], 9 | } 10 | 11 | #[allow(unused)] 12 | impl Store { 13 | #[inline] 14 | pub const fn new() -> Self { 15 | Self { 16 | rd: 0, 17 | wr: 0, 18 | buf: [0; N], 19 | } 20 | } 21 | /// Create a store whose readable region is a copy of `data`. Panics if `data` is longer than `N` or than `u8::MAX` (positions are kept as `u8`). 22 | #[inline] 23 | pub fn new_with_data(data: &[u8]) -> Self { 24 | assert!(data.len() <= N.min(u8::MAX as usize), "data exceeds Store capacity"); let mut buf = [0_u8; N]; 25 | unsafe { // SAFETY: the assert above guarantees `data` fits in `buf`; a fresh local array cannot overlap `data`. 26 | std::ptr::copy_nonoverlapping(data.as_ptr(), buf.as_mut_ptr(), data.len()); 27 | } 28 | Self { 29 | rd: 0, 30 | wr: data.len() as u8, 31 | buf, 32 | } 33 | } 34 | /// Overwrite the store with a copy of `data` and reset the read position. Panics if `data` is longer than `N` or than `u8::MAX`. 35 | #[inline] 36 | pub fn replace_with_data(&mut self, data: &[u8]) { 37 | assert!(data.len() <= N.min(u8::MAX as usize), "data exceeds Store capacity"); unsafe { // SAFETY: the assert above guarantees `data` fits in `self.buf`; `&mut self` cannot alias the shared `data` borrow. 38 | std::ptr::copy_nonoverlapping(data.as_ptr(), self.buf.as_mut_ptr(), data.len()); 39 | } 40 | self.rd = 0; 41 | self.wr = data.len() as u8; 42 | } 43 | 44 | #[inline] 45 | pub const fn rd_pos(&self) -> usize { self.rd as usize } 46 | 47 | #[inline] 48 | pub const fn wr_pos(&self) -> usize { self.wr as usize } 49 | 50 | #[inline] 51 | pub const fn set_rd_pos(&mut self, n: usize) { self.rd = n as u8 } 52 | 53 | #[inline] 54 | pub const fn set_wr_pos(&mut self, n: usize) { self.wr = n as u8 } 55 | 56 | #[inline] 57 | pub const fn advance_rd_pos(&mut self, n: usize) { self.rd += n as u8 } 58 | 59 | #[inline] 60 | pub const fn advance_wr_pos(&mut self, n: usize) { self.wr += n as u8 } 61 | 62 | #[inline] 63 | pub const fn rd_left(&self) -> usize { self.wr as usize -
self.rd as usize } 64 | 65 | #[inline] 66 | pub const fn wr_left(&self) -> usize { N - self.wr as usize } 67 | 68 | #[inline] 69 | pub const fn is_empty(&self) -> bool { self.wr == 0 } 70 | 71 | #[inline] 72 | pub const fn read(&self) -> &[u8] { 73 | unsafe { slice(&self.buf, self.rd as usize, self.wr as usize) } 74 | } 75 | 76 | #[inline] 77 | pub const fn write(&mut self) -> &mut [u8] { 78 | unsafe { slice_mut(&mut self.buf, self.wr as usize, N) } 79 | } 80 | 81 | #[inline] 82 | pub const fn reset(&mut self) { 83 | self.rd = 0; 84 | self.wr = 0; 85 | } 86 | } 87 | 88 | /// Get the whole buffer. 89 | impl AsRef<[u8]> for Store { 90 | #[inline] 91 | fn as_ref(&self) -> &[u8] { &self.buf } 92 | } 93 | 94 | /// Get the whole buffer. 95 | impl AsMut<[u8]> for Store { 96 | #[inline] 97 | fn as_mut(&mut self) -> &mut [u8] { &mut self.buf } 98 | } 99 | 100 | #[cfg(test)] 101 | mod test { 102 | use super::*; 103 | 104 | #[test] 105 | fn unsafe_store() { 106 | let mut store = Store::<14>::new_with_data(b"Hello, "); 107 | assert_eq!(store.read(), b"Hello, "); 108 | store.write().copy_from_slice(b"World!!"); 109 | store.advance_wr_pos(7); 110 | assert_eq!(store.read(), b"Hello, World!!"); 111 | store.advance_rd_pos(7); 112 | assert_eq!(store.read(), b"World!!"); 113 | 114 | store.reset(); 115 | assert_eq!(store.read(), []); 116 | 117 | store.replace_with_data(b"hello, world!!"); 118 | assert_eq!(store.read(), b"hello, world!!"); 119 | store.advance_rd_pos(7); 120 | assert_eq!(store.read(), b"world!!"); 121 | 122 | store.reset(); 123 | assert_eq!(store.read(), []); 124 | } 125 | } 126 | -------------------------------------------------------------------------------- /src/bleed/writer.rs: -------------------------------------------------------------------------------- 1 | use std::marker::PhantomData; 2 | use std::ptr::copy_nonoverlapping; 3 | 4 | pub struct Writer<'a, T> { 5 | ptr: *mut T, 6 | pos: usize, 7 | cap: usize, 8 | _marker: PhantomData<&'a T>, 9 | } 10 | 11 | 
#[allow(unused)] 12 | #[allow(non_camel_case_types)] 13 | #[allow(clippy::builtin_type_shadow)] 14 | impl<'a, u8> Writer<'a, u8> { 15 | #[inline] 16 | pub const fn new(w: &mut [u8]) -> Self { 17 | Writer { 18 | ptr: w.as_mut_ptr(), 19 | pos: 0, 20 | cap: w.len(), 21 | _marker: PhantomData, 22 | } 23 | } 24 | 25 | #[inline] 26 | pub const unsafe fn new_raw(w: *mut u8, pos: usize, cap: usize) -> Self { 27 | Writer { 28 | ptr: w, 29 | pos, 30 | cap, 31 | _marker: PhantomData, 32 | } 33 | } 34 | 35 | #[inline] 36 | pub const fn pos(&self) -> usize { self.pos } 37 | 38 | #[inline] 39 | pub const fn cap(&self) -> usize { self.cap } 40 | 41 | #[inline] 42 | pub const fn remaining(&self) -> usize { self.cap - self.pos } 43 | 44 | #[inline] 45 | pub unsafe fn write_unchecked(&mut self, src: &[u8]) -> usize { 46 | let len = src.len(); 47 | copy_nonoverlapping(src.as_ptr(), self.ptr.add(self.pos), len); 48 | self.pos += len; 49 | len 50 | } 51 | 52 | #[inline] 53 | pub unsafe fn write_byte_unchecked(&mut self, b: u8) { 54 | *self.ptr.add(self.pos) = b; 55 | self.pos += 1; 56 | } 57 | 58 | #[inline] 59 | pub fn write_or_err(&mut self, src: &[u8], f: F) -> Result 60 | where 61 | F: Fn() -> E, 62 | E: std::error::Error, 63 | { 64 | if self.remaining() < src.len() { 65 | Err(f()) 66 | } else { 67 | Ok(unsafe { self.write_unchecked(src) }) 68 | } 69 | } 70 | } 71 | 72 | #[cfg(test)] 73 | mod test { 74 | use super::*; 75 | use std::io::Write; 76 | 77 | #[test] 78 | fn unsafe_write() { 79 | let mut buf = vec![0; 4096]; 80 | let mut buf2 = buf.clone(); 81 | 82 | for i in (1..=1024).filter(|x| 4096 % x == 0) { 83 | let n = 4096 / i; 84 | let data: Vec = std::iter::repeat(rand::random::()).take(i).collect(); 85 | 86 | let mut writer = Writer::new(&mut buf); 87 | let mut write_n = 0; 88 | 89 | for _ in 0..n { 90 | unsafe { writer.write_unchecked(&data[..]) }; 91 | { 92 | let mut writer2 = &mut buf2.as_mut_slice()[write_n..]; 93 | write_n += writer2.write(&data[..]).unwrap(); 94 | } 95 | 
assert_eq!(write_n, writer.pos()); 96 | assert_eq!(&buf, &buf2); 97 | } 98 | } 99 | } 100 | 101 | #[test] 102 | fn unsafe_write_byte() { 103 | let mut buf = vec![0; 4096]; 104 | let mut buf2 = buf.clone(); 105 | 106 | let mut writer = Writer::new(&mut buf); 107 | let mut writer2 = Writer::new(&mut buf2); 108 | 109 | for _ in 0..4096 { 110 | let b: u8 = rand::random(); 111 | unsafe { 112 | writer.write_byte_unchecked(b); 113 | writer2.write_unchecked(&[b]); 114 | assert_eq!(&buf, &buf2); 115 | } 116 | } 117 | } 118 | } 119 | -------------------------------------------------------------------------------- /src/endpoint/async_client.rs: -------------------------------------------------------------------------------- 1 | use std::io::Result; 2 | use std::pin::Pin; 3 | use std::future::poll_fn; 4 | 5 | use tokio::io::{ReadBuf, AsyncRead, AsyncWrite}; 6 | 7 | use super::detail; 8 | use super::Endpoint; 9 | 10 | use crate::role::ClientRole; 11 | use crate::handshake::{HttpHeader, Request, Response}; 12 | use crate::handshake::{new_sec_key, derive_accept_key}; 13 | use crate::error::HandshakeError; 14 | use crate::stream::Stream; 15 | 16 | impl Endpoint { 17 | /// Async version of [`send_request`](Self::send_request). 18 | pub async fn send_request_async<'h, 'b: 'h, const N: usize>( 19 | io: &mut IO, 20 | buf: &mut [u8], 21 | request: &Request<'h, 'b, N>, 22 | ) -> Result { 23 | poll_fn(|cx| { 24 | detail::send_request(io, buf, request, |io, buf| Pin::new(io).poll_write(cx, buf)) 25 | }) 26 | .await 27 | } 28 | 29 | /// Async version of [`recv_response`](Self::recv_response). 30 | /// 31 | /// # Safety 32 | /// 33 | /// Caller must not modify the buffer while `response` is in use, 34 | /// otherwise it is undefined behavior! 
35 | pub async unsafe fn recv_response_async<'h, 'b: 'h, const N: usize>( 36 | io: &mut IO, 37 | buf: &mut [u8], 38 | response: &mut Response<'h, 'b, N>, 39 | ) -> Result { 40 | poll_fn(|cx| { 41 | detail::recv_response(io, buf, response, |io, buf| { 42 | let mut buf = ReadBuf::new(buf); 43 | Pin::new(io) 44 | .poll_read(cx, &mut buf) 45 | .map_ok(|_| buf.filled().len()) 46 | }) 47 | }) 48 | .await 49 | } 50 | 51 | /// Async version of [`connect`](Self::connect). 52 | pub async fn connect_async( 53 | mut io: IO, 54 | buf: &mut [u8], 55 | host: &str, 56 | path: &str, 57 | ) -> Result> { 58 | let sec_key = new_sec_key(); 59 | let sec_accept = derive_accept_key(&sec_key); 60 | 61 | // send 62 | let request = Request::new(path.as_bytes(), host.as_bytes(), &sec_key); 63 | let _ = Self::send_request_async(&mut io, buf, &request).await?; 64 | 65 | // recv 66 | let mut other_headers = HttpHeader::new_storage(); 67 | let mut response = Response::new_storage(&mut other_headers); 68 | // this is safe since we do not modify response. 
69 | let _ = unsafe { Self::recv_response_async(&mut io, buf, &mut response) }.await?; 70 | 71 | // check 72 | if response.sec_accept != sec_accept { 73 | return Err(HandshakeError::SecWebSocketAccept.into()); 74 | } 75 | 76 | Ok(Stream::new(io, Role::new())) 77 | } 78 | } 79 | -------------------------------------------------------------------------------- /src/endpoint/async_server.rs: -------------------------------------------------------------------------------- 1 | use std::io::Result; 2 | use std::pin::Pin; 3 | use std::future::poll_fn; 4 | 5 | use tokio::io::{ReadBuf, AsyncRead, AsyncWrite}; 6 | 7 | use super::detail; 8 | use super::Endpoint; 9 | 10 | use crate::role::ServerRole; 11 | use crate::handshake::{HttpHeader, Request, Response}; 12 | use crate::handshake::derive_accept_key; 13 | use crate::error::HandshakeError; 14 | use crate::stream::Stream; 15 | 16 | impl Endpoint { 17 | /// Async version of [`send_response`](Self::send_response). 18 | pub async fn send_response_async( 19 | io: &mut IO, 20 | buf: &mut [u8], 21 | response: &Response<'_, '_, N>, 22 | ) -> Result { 23 | poll_fn(|cx| { 24 | detail::send_response(io, buf, response, |io, buf| { 25 | Pin::new(io).poll_write(cx, buf) 26 | }) 27 | }) 28 | .await 29 | } 30 | 31 | /// Async version of [`recv_request`](Self::recv_request). 32 | /// 33 | /// # Safety 34 | /// 35 | /// Caller must not modify the buffer while `request` is in use, 36 | /// otherwise it is undefined behavior! 37 | pub async unsafe fn recv_request_async<'h, 'b: 'h, const N: usize>( 38 | io: &mut IO, 39 | buf: &mut [u8], 40 | request: &mut Request<'h, 'b, N>, 41 | ) -> Result { 42 | poll_fn(|cx| { 43 | detail::recv_request(io, buf, request, |io, buf| { 44 | let mut buf = ReadBuf::new(buf); 45 | Pin::new(io) 46 | .poll_read(cx, &mut buf) 47 | .map_ok(|_| buf.filled().len()) 48 | }) 49 | }) 50 | .await 51 | } 52 | 53 | /// Async version of [`accept`](Self::accept). 
54 | pub async fn accept_async( 55 | mut io: IO, 56 | buf: &mut [u8], 57 | host: &str, 58 | path: &str, 59 | ) -> Result> { 60 | // recv 61 | let mut other_headers = HttpHeader::new_storage(); 62 | let mut request = Request::new_storage(&mut other_headers); 63 | // this is safe since we do not modify request. 64 | let _ = unsafe { Self::recv_request_async(&mut io, buf, &mut request) }.await?; 65 | 66 | // check 67 | if request.host != host.as_bytes() { 68 | return Err(HandshakeError::Manual("host mismatch").into()); 69 | } 70 | 71 | if request.path != path.as_bytes() { 72 | return Err(HandshakeError::Manual("path mismatch").into()); 73 | } 74 | 75 | // send 76 | let sec_accept = derive_accept_key(request.sec_key); 77 | let response = Response::new(&sec_accept); 78 | let _ = Self::send_response_async(&mut io, buf, &response).await?; 79 | 80 | Ok(Stream::new(io, Role::new())) 81 | } 82 | } 83 | -------------------------------------------------------------------------------- /src/endpoint/client.rs: -------------------------------------------------------------------------------- 1 | use std::io::{Read, Write, Result}; 2 | use std::task::Poll; 3 | 4 | use super::detail; 5 | use super::Endpoint; 6 | 7 | use crate::role::ClientRole; 8 | use crate::handshake::{HttpHeader, Request, Response}; 9 | use crate::handshake::{new_sec_key, derive_accept_key}; 10 | use crate::error::HandshakeError; 11 | use crate::stream::Stream; 12 | 13 | impl Endpoint { 14 | /// Send websocket upgrade request to IO source, return 15 | /// the number of bytes transmitted. 16 | /// Request data are encoded to the provided buffer. 17 | /// 18 | /// This function will block until all data 19 | /// are written to IO source or an error occurs. 
20 | pub fn send_request( 21 | io: &mut IO, 22 | buf: &mut [u8], 23 | request: &Request<'_, '_, N>, 24 | ) -> Result { 25 | match detail::send_request(io, buf, request, |io, buf| io.write(buf).into()) { 26 | Poll::Ready(x) => x, 27 | Poll::Pending => unreachable!(), 28 | } 29 | } 30 | 31 | /// Receive websocket upgrade response from IO source, return 32 | /// the number of bytes transmitted. 33 | /// Received data are stored in the provided buffer, and parsed 34 | /// as [`Response`]. 35 | /// 36 | /// This function will block on reading data, until there is enough 37 | /// data to parse a response or an error occurs. 38 | /// 39 | /// # Safety 40 | /// 41 | /// Caller must not modify the buffer while `response` is in use, 42 | /// otherwise it is undefined behavior! 43 | pub unsafe fn recv_response<'h, 'b: 'h, const N: usize>( 44 | io: &mut IO, 45 | buf: &mut [u8], 46 | response: &mut Response<'h, 'b, N>, 47 | ) -> Result { 48 | match detail::recv_response(io, buf, response, |io, buf| io.read(buf).into()) { 49 | Poll::Ready(x) => x, 50 | Poll::Pending => unreachable!(), 51 | } 52 | } 53 | 54 | /// Perform a simple websocket client handshake, return a new websocket stream. 55 | /// 56 | /// This function is a combination of [`send_request`](Self::send_request) 57 | /// and [`recv_response`](Self::recv_response), without accessing [`Response`]. 58 | /// It will block until the handshake completes, or an error occurs. 59 | pub fn connect(mut io: IO, buf: &mut [u8], host: &str, path: &str) -> Result> { 60 | let sec_key = new_sec_key(); 61 | let sec_accept = derive_accept_key(&sec_key); 62 | 63 | // send 64 | let request = Request::new(path.as_bytes(), host.as_bytes(), &sec_key); 65 | let _ = Self::send_request(&mut io, buf, &request)?; 66 | 67 | // recv 68 | let mut other_headers = HttpHeader::new_storage(); 69 | let mut response = Response::new_storage(&mut other_headers); 70 | // this is safe since we do not modify response. 
71 | let _ = unsafe { Self::recv_response(&mut io, buf, &mut response) }?; 72 | 73 | // check 74 | if response.sec_accept != sec_accept { 75 | return Err(HandshakeError::SecWebSocketAccept.into()); 76 | } 77 | 78 | Ok(Stream::new(io, Role::new())) 79 | } 80 | } 81 | 82 | #[cfg(test)] 83 | mod test { 84 | use std::error::Error; 85 | use super::*; 86 | use super::super::test::*; 87 | use crate::error::HandshakeError; 88 | use crate::role::Client; 89 | 90 | #[test] 91 | fn send_upgrade_request() { 92 | fn run_limit(limit: usize) { 93 | let mut rw = LimitReadWriter { 94 | rbuf: Vec::new(), 95 | wbuf: Vec::new(), 96 | rlimit: 0, 97 | wlimit: limit, 98 | cursor: 0, 99 | }; 100 | 101 | let request = Request::new(b"/ws", b"www.example.com", b"dGhlIHNhbXBsZSBub25jZQ=="); 102 | 103 | let mut buf = vec![0u8; 1024]; 104 | 105 | let send_n = Endpoint::<_, Client>::send_request(&mut rw, &mut buf, &request).unwrap(); 106 | 107 | assert_eq!(send_n, REQUEST.len()); 108 | assert_eq!(&buf[..send_n], REQUEST); 109 | } 110 | 111 | for i in 1..=256 { 112 | run_limit(i); 113 | } 114 | } 115 | 116 | #[test] 117 | fn recv_upgrade_response() { 118 | fn run_limit(limit: usize) { 119 | let mut rw = LimitReadWriter { 120 | rbuf: Vec::from(RESPONSE), 121 | wbuf: Vec::new(), 122 | rlimit: limit, 123 | wlimit: 0, 124 | cursor: 0, 125 | }; 126 | 127 | let mut buf = vec![0u8; 1024]; 128 | let mut headers = HttpHeader::new_storage(); 129 | let mut response = Response::new_storage(&mut headers); 130 | 131 | let recv_n = 132 | unsafe { Endpoint::<_, Client>::recv_response(&mut rw, &mut buf, &mut response) } 133 | .unwrap(); 134 | 135 | assert_eq!(recv_n, RESPONSE.len()); 136 | assert_eq!(response.sec_accept, b"s3pPLMBiTxaQ9kYGzzhZRbK+xOo="); 137 | drop(response); 138 | assert_eq!(&buf[..recv_n], RESPONSE); 139 | } 140 | 141 | for i in 1..=256 { 142 | run_limit(i); 143 | } 144 | } 145 | 146 | #[test] 147 | fn client_connect() { 148 | // use std::error::Error; 149 | let mut rw = LimitReadWriter { 150 | 
rbuf: Vec::from(RESPONSE), 151 | wbuf: Vec::new(), 152 | rlimit: 1, 153 | wlimit: 1, 154 | cursor: 0, 155 | }; 156 | 157 | let mut buf = vec![0u8; 1024]; 158 | 159 | // sec-websocket-accept mismatch 160 | // since connect uses a random key 161 | let stream = Endpoint::<_, Client>::connect(&mut rw, &mut buf, "example.com", "/"); 162 | if let Err(e) = stream { 163 | let e = e.source().unwrap(); 164 | let e: &HandshakeError = e.downcast_ref().unwrap(); 165 | assert_eq!(*e, HandshakeError::SecWebSocketAccept); 166 | } 167 | } 168 | } 169 | -------------------------------------------------------------------------------- /src/endpoint/detail/accept.rs: -------------------------------------------------------------------------------- 1 | use std::io::Result; 2 | use std::task::{Poll, ready}; 3 | 4 | use crate::handshake::Request; 5 | use crate::handshake::Response; 6 | use crate::error::HandshakeError; 7 | /// Encode `response` into `buf`, then pass it to `write` until every encoded byte is accepted; returns the encoded length. 8 | pub fn send_response<'h, 'b: 'h, F, IO, const N: usize>( 9 | io: &mut IO, 10 | buf: &mut [u8], 11 | response: &Response<'h, 'b, N>, 12 | mut write: F, 13 | ) -> Poll> 14 | where 15 | F: FnMut(&mut IO, &[u8]) -> Poll>, 16 | { 17 | let total = match response.encode(buf) { 18 | Ok(n) => n, 19 | Err(e) => return Poll::Ready(Err(e.into())), 20 | }; 21 | 22 | let mut offset = 0; 23 | // NOTE(review): progress is kept only in this local, so if `write` returns Pending after a partial write, the next call re-encodes and resends from byte 0 — confirm async callers cannot hit this. 24 | while offset < total { 25 | let n = ready!(write(io, &buf[offset..total]))?; 26 | if n == 0 { return Poll::Ready(Err(std::io::Error::from(std::io::ErrorKind::WriteZero))); } // sink accepted nothing: fail like `Write::write_all` instead of looping forever 27 | offset += n; 28 | } 29 | 30 | Poll::Ready(Ok(total)) 31 | } 32 | /// Read from `io` into `buf` via `read` until `request` decodes, returning the number of bytes read; fails on EOF, on a decode error, or when `buf` fills first. Caller must keep `buf` unmodified while `request` is in use. 33 | pub unsafe fn recv_request<'h, 'b: 'h, F, IO, const N: usize>( 34 | io: &mut IO, 35 | buf: &mut [u8], 36 | request: &mut Request<'h, 'b, N>, 37 | mut read: F, 38 | ) -> Poll> 39 | where 40 | F: FnMut(&mut IO, &mut [u8]) -> Poll>, 41 | { 42 | let total = buf.len(); 43 | let mut offset = 0; 44 | // NOTE(review): `offset` restarts at 0 on every call, so after a Pending re-poll new bytes overwrite data already read — confirm async callers cannot hit this. 45 | // WARNING !! I am breaking rust's borrow rules here. 46 | // Caller must not modify the buffer while request is in use.
47 | let buf_const: &'b [u8] = &*(buf as *const [u8]); 48 | 49 | while offset < total { 50 | let n = ready!(read(io, &mut buf[offset..]))?; 51 | 52 | // EOF, no more data 53 | if n == 0 { 54 | return Poll::Ready(Err(HandshakeError::NotEnoughData.into())); 55 | } 56 | 57 | offset += n; 58 | 59 | match request.decode(&buf_const[..offset]) { 60 | Ok(_) => return Poll::Ready(Ok(offset)), 61 | Err(ref e) if *e == HandshakeError::NotEnoughData => continue, 62 | Err(e) => return Poll::Ready(Err(e.into())), 63 | } 64 | } 65 | 66 | // provided buffer is filled, however it could not accommodate the request. 67 | Poll::Ready(Err(HandshakeError::NotEnoughCapacity.into())) 68 | } 69 | -------------------------------------------------------------------------------- /src/endpoint/detail/connect.rs: -------------------------------------------------------------------------------- 1 | use std::io::Result; 2 | use std::task::{Poll, ready}; 3 | 4 | use crate::handshake::Request; 5 | use crate::handshake::Response; 6 | use crate::error::HandshakeError; 7 | /// Encode `request` into `buf`, then pass it to `write` until every encoded byte is accepted; returns the encoded length. 8 | pub fn send_request<'h, 'b: 'h, F, IO, const N: usize>( 9 | io: &mut IO, 10 | buf: &mut [u8], 11 | request: &Request<'h, 'b, N>, 12 | mut write: F, 13 | ) -> Poll> 14 | where 15 | F: FnMut(&mut IO, &[u8]) -> Poll>, 16 | { 17 | let total = match request.encode(buf) { 18 | Ok(n) => n, 19 | Err(e) => return Poll::Ready(Err(e.into())), 20 | }; 21 | 22 | let mut offset = 0; 23 | // NOTE(review): progress is kept only in this local, so if `write` returns Pending after a partial write, the next call re-encodes and resends from byte 0 — confirm async callers cannot hit this. 24 | while offset < total { 25 | let n = ready!(write(io, &buf[offset..total]))?; 26 | if n == 0 { return Poll::Ready(Err(std::io::Error::from(std::io::ErrorKind::WriteZero))); } // sink accepted nothing: fail like `Write::write_all` instead of looping forever 27 | offset += n; 28 | } 29 | 30 | Poll::Ready(Ok(total)) 31 | } 32 | /// Read from `io` into `buf` via `read` until `response` decodes, returning the number of bytes read; fails on EOF, on a decode error, or when `buf` fills first. Caller must keep `buf` unmodified while `response` is in use. 33 | pub unsafe fn recv_response<'h, 'b: 'h, F, IO, const N: usize>( 34 | io: &mut IO, 35 | buf: &mut [u8], 36 | response: &mut Response<'h, 'b, N>, 37 | mut read: F, 38 | ) -> Poll> 39 | where 40 | F: FnMut(&mut IO, &mut [u8]) -> Poll>, 41 | { 42 | let total = buf.len(); 43 | let mut offset = 0; 44 | // NOTE(review): `offset` restarts at 0 on every call, so after a Pending re-poll new bytes overwrite data already read — confirm async callers cannot hit this. 45 | // WARNING !! I am breaking rust's borrow rules here.
46 | // Caller must not modify the buffer while response is in use. 47 | let buf_const: &'b [u8] = &*(buf as *const [u8]); 48 | 49 | while offset < total { 50 | let n = ready!(read(io, &mut buf[offset..]))?; 51 | 52 | // EOF, no more data 53 | if n == 0 { 54 | return Poll::Ready(Err(HandshakeError::NotEnoughData.into())); 55 | } 56 | 57 | offset += n; 58 | 59 | match response.decode(&buf_const[..offset]) { 60 | Ok(_) => return Poll::Ready(Ok(offset)), 61 | Err(ref e) if *e == HandshakeError::NotEnoughData => continue, 62 | Err(e) => return Poll::Ready(Err(e.into())), 63 | } 64 | } 65 | 66 | // provided buffer is filled, however it could not accommodate the response. 67 | Poll::Ready(Err(HandshakeError::NotEnoughCapacity.into())) 68 | } 69 | -------------------------------------------------------------------------------- /src/endpoint/detail/mod.rs: -------------------------------------------------------------------------------- 1 | mod accept; 2 | mod connect; 3 | 4 | pub(super) use accept::{recv_request, send_response}; 5 | pub(super) use connect::{recv_response, send_request}; 6 | -------------------------------------------------------------------------------- /src/endpoint/mod.rs: -------------------------------------------------------------------------------- 1 | //! Websocket endpoint. 2 | //! 3 | //! [`Endpoint`] is used to perform a handshake. It is compatible with 4 | //! both sync and async IO. 5 | //! 6 | //! To open or accept a connection directly, use [`Endpoint::connect`], 7 | //! [`Endpoint::accept`], or their async version. 8 | //! 9 | //! To have detailed control over a handshake, use [`Endpoint::send_request`], 10 | //! [`Endpoint::recv_response`], [`Endpoint::recv_request`], [`Endpoint::send_response`], 11 | //! or their async version. 12 | 13 | mod detail; 14 | mod client; 15 | mod server; 16 | 17 | cfg_if::cfg_if! 
{ 18 | if #[cfg(feature = "tokio")] { 19 | mod async_client; 20 | mod async_server; 21 | } 22 | } 23 | 24 | use std::marker::PhantomData; 25 | 26 | /// Handshake endpoint. 27 | pub struct Endpoint { 28 | _marker: PhantomData, 29 | __marker: PhantomData, 30 | } 31 | 32 | #[cfg(test)] 33 | mod test { 34 | use std::io::{Read, Write, Result}; 35 | 36 | pub const REQUEST: &[u8] = b"\ 37 | GET /ws HTTP/1.1\r\n\ 38 | host: www.example.com\r\n\ 39 | upgrade: websocket\r\n\ 40 | connection: upgrade\r\n\ 41 | sec-websocket-key: dGhlIHNhbXBsZSBub25jZQ==\r\n\ 42 | sec-websocket-version: 13\r\n\r\n"; 43 | 44 | pub const RESPONSE: &[u8] = b"\ 45 | HTTP/1.1 101 Switching Protocols\r\n\ 46 | upgrade: websocket\r\n\ 47 | connection: upgrade\r\n\ 48 | sec-websocket-accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=\r\n\r\n"; 49 | 50 | pub struct LimitReadWriter { 51 | pub rbuf: Vec, 52 | pub wbuf: Vec, 53 | pub rlimit: usize, 54 | pub wlimit: usize, 55 | pub cursor: usize, 56 | } 57 | 58 | impl Read for LimitReadWriter { 59 | fn read(&mut self, mut buf: &mut [u8]) -> Result { 60 | let to_read = std::cmp::min(buf.len(), self.rlimit); 61 | let left_data = self.rbuf.len() - self.cursor; 62 | if left_data == 0 { 63 | return Ok(0); 64 | } 65 | if left_data <= to_read { 66 | buf.write(&self.rbuf[self.cursor..]).unwrap(); 67 | self.cursor = self.rbuf.len(); 68 | return Ok(left_data); 69 | } 70 | 71 | buf.write(&self.rbuf[self.cursor..self.cursor + to_read]) 72 | .unwrap(); 73 | self.cursor += to_read; 74 | Ok(to_read) 75 | } 76 | } 77 | 78 | impl Write for LimitReadWriter { 79 | fn write(&mut self, buf: &[u8]) -> Result { 80 | let len = std::cmp::min(buf.len(), self.wlimit); 81 | self.wbuf.write(&buf[..len]) 82 | } 83 | 84 | fn flush(&mut self) -> Result<()> { Ok(()) } 85 | } 86 | } 87 | -------------------------------------------------------------------------------- /src/endpoint/server.rs: -------------------------------------------------------------------------------- 1 | use std::io::{Read, Write, 
Result}; 2 | use std::task::Poll; 3 | 4 | use super::detail; 5 | use super::Endpoint; 6 | 7 | use crate::role::ServerRole; 8 | use crate::handshake::{HttpHeader, Request, Response}; 9 | use crate::handshake::derive_accept_key; 10 | use crate::error::HandshakeError; 11 | use crate::stream::Stream; 12 | 13 | impl Endpoint { 14 | /// Send websocket upgrade response to IO source, return 15 | /// the number of bytes transmitted. 16 | /// Response data are encoded to the provided buffer. 17 | /// 18 | /// This function will block until all data 19 | /// are written to IO source or an error occurs. 20 | pub fn send_response( 21 | io: &mut IO, 22 | buf: &mut [u8], 23 | response: &Response<'_, '_, N>, 24 | ) -> Result { 25 | match detail::send_response(io, buf, response, |io, buf| io.write(buf).into()) { 26 | Poll::Ready(x) => x, 27 | Poll::Pending => unreachable!(), 28 | } 29 | } 30 | 31 | /// Receive websocket upgrade request from IO source, return 32 | /// the number of bytes transmitted. 33 | /// Received data are stored in the provided buffer, and parsed 34 | /// as [`Request`]. 35 | /// 36 | /// This function will block on reading data, until there is enough 37 | /// data to parse a request or an error occurs. 38 | /// 39 | /// # Safety 40 | /// 41 | /// Caller must not modify the buffer while `request` is in use, 42 | /// otherwise it is undefined behavior! 43 | pub unsafe fn recv_request<'h, 'b: 'h, const N: usize>( 44 | io: &mut IO, 45 | buf: &mut [u8], 46 | request: &mut Request<'h, 'b, N>, 47 | ) -> Result { 48 | match detail::recv_request(io, buf, request, |io, buf| io.read(buf).into()) { 49 | Poll::Ready(x) => x, 50 | Poll::Pending => unreachable!(), 51 | } 52 | } 53 | 54 | /// Perform a simple websocket server handshake, return a new websocket stream. 55 | /// 56 | /// This function is a combination of [`recv_request`](Self::recv_request) 57 | /// and [`send_response`](Self::send_response), without accessing [`Request`]. 
58 | /// It will block until the handshake completes, or an error occurs. 59 | pub fn accept(mut io: IO, buf: &mut [u8], host: &str, path: &str) -> Result> { 60 | // recv 61 | let mut other_headers = HttpHeader::new_storage(); 62 | let mut request = Request::new_storage(&mut other_headers); 63 | // this is safe since we do not modify request. 64 | let _ = unsafe { Self::recv_request(&mut io, buf, &mut request) }?; 65 | 66 | // check 67 | if request.host != host.as_bytes() { 68 | return Err(HandshakeError::Manual("host mismatch").into()); 69 | } 70 | 71 | if request.path != path.as_bytes() { 72 | return Err(HandshakeError::Manual("path mismatch").into()); 73 | } 74 | 75 | // send 76 | let sec_accept = derive_accept_key(request.sec_key); 77 | let response = Response::new(&sec_accept); 78 | let _ = Self::send_response(&mut io, buf, &response)?; 79 | 80 | Ok(Stream::new(io, Role::new())) 81 | } 82 | } 83 | 84 | #[cfg(test)] 85 | mod test { 86 | use super::*; 87 | use super::super::test::*; 88 | use crate::role::Server; 89 | 90 | #[test] 91 | fn send_upgrade_response() { 92 | fn run_limit(limit: usize) { 93 | let mut rw = LimitReadWriter { 94 | rbuf: Vec::new(), 95 | wbuf: Vec::new(), 96 | rlimit: 0, 97 | wlimit: limit, 98 | cursor: 0, 99 | }; 100 | 101 | let response = Response::new(b"s3pPLMBiTxaQ9kYGzzhZRbK+xOo="); 102 | 103 | let mut buf = vec![0u8; 1024]; 104 | 105 | let send_n = 106 | Endpoint::<_, Server>::send_response(&mut rw, &mut buf, &response).unwrap(); 107 | 108 | assert_eq!(send_n, RESPONSE.len()); 109 | assert_eq!(&buf[..send_n], RESPONSE); 110 | } 111 | 112 | for i in 1..=256 { 113 | run_limit(i); 114 | } 115 | } 116 | 117 | #[test] 118 | fn recv_upgrade_request() { 119 | fn run_limit(limit: usize) { 120 | let mut rw = LimitReadWriter { 121 | rbuf: Vec::from(REQUEST), 122 | wbuf: Vec::new(), 123 | rlimit: limit, 124 | wlimit: 0, 125 | cursor: 0, 126 | }; 127 | 128 | let mut buf = vec![0u8; 1024]; 129 | let mut headers = HttpHeader::new_storage(); 130 | 
let mut request = Request::new_storage(&mut headers); 131 | 132 | let recv_n = 133 | unsafe { Endpoint::<_, Server>::recv_request(&mut rw, &mut buf, &mut request) } 134 | .unwrap(); 135 | 136 | assert_eq!(recv_n, REQUEST.len()); 137 | assert_eq!(request.host, b"www.example.com"); 138 | assert_eq!(request.path, b"/ws"); 139 | assert_eq!(request.sec_key, b"dGhlIHNhbXBsZSBub25jZQ=="); 140 | drop(request); 141 | assert_eq!(&buf[..recv_n], REQUEST); 142 | } 143 | 144 | for i in 1..=256 { 145 | run_limit(i); 146 | } 147 | } 148 | 149 | #[test] 150 | fn server_accept() { 151 | // use std::error::Error; 152 | let mut rw = LimitReadWriter { 153 | rbuf: Vec::from(REQUEST), 154 | wbuf: Vec::new(), 155 | rlimit: 1, 156 | wlimit: 1, 157 | cursor: 0, 158 | }; 159 | 160 | let mut buf = vec![0u8; 1024]; 161 | 162 | let _ = Endpoint::<_, Server>::accept(&mut rw, &mut buf, "www.example.com", "/ws"); 163 | } 164 | } 165 | -------------------------------------------------------------------------------- /src/error/ctrl.rs: -------------------------------------------------------------------------------- 1 | use std::fmt::{Display, Formatter}; 2 | 3 | #[derive(Debug, PartialEq, Eq)] 4 | pub enum CtrlError { 5 | SetMaskInWrite, 6 | } 7 | 8 | impl Display for CtrlError { 9 | fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { 10 | use CtrlError::*; 11 | match self { 12 | SetMaskInWrite => write!(f, "Set mask during an incomplete write"), 13 | } 14 | } 15 | } 16 | 17 | // use default impl 18 | impl std::error::Error for CtrlError {} 19 | -------------------------------------------------------------------------------- /src/error/frame.rs: -------------------------------------------------------------------------------- 1 | use std::fmt::{Display, Formatter}; 2 | 3 | #[derive(Debug, PartialEq, Eq)] 4 | pub enum FrameError { 5 | IllegalFin, 6 | 7 | IllegalMask, 8 | 9 | IllegalOpCode, 10 | 11 | IllegalData, 12 | 13 | NotEnoughData, 14 | 15 | NotEnoughCapacity, 16 | 17 | UnsupportedOpcode, 
18 | } 19 | 20 | impl Display for FrameError { 21 | fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { 22 | use FrameError::*; 23 | match self { 24 | IllegalFin => write!(f, "Illegal fin value"), 25 | IllegalMask => write!(f, "Illegal mask value"), 26 | IllegalOpCode => write!(f, "Illegal opcode value"), 27 | IllegalData => write!(f, "Illegal data"), 28 | NotEnoughData => write!(f, "Not enough data to parse"), 29 | NotEnoughCapacity => write!(f, "Not enough space to write to"), 30 | UnsupportedOpcode => write!( 31 | f, 32 | "Unsupported opcode, only support binary, ping, pong, close" 33 | ), 34 | } 35 | } 36 | } 37 | 38 | // use default impl 39 | impl std::error::Error for FrameError {} 40 | -------------------------------------------------------------------------------- /src/error/handshake.rs: -------------------------------------------------------------------------------- 1 | use std::fmt::{Display, Formatter}; 2 | 3 | #[derive(Debug, PartialEq, Eq)] 4 | pub enum HandshakeError { 5 | // http error 6 | HttpVersion, 7 | 8 | HttpMethod, 9 | 10 | HttpSatusCode, 11 | 12 | HttpHost, 13 | 14 | // websocket error 15 | Upgrade, 16 | 17 | Connection, 18 | 19 | SecWebSocketKey, 20 | 21 | SecWebSocketAccept, 22 | 23 | SecWebSocketVersion, 24 | 25 | // other error 26 | 27 | // read 28 | NotEnoughData, 29 | 30 | // write 31 | NotEnoughCapacity, 32 | 33 | Httparse(httparse::Error), 34 | 35 | Manual(&'static str), 36 | } 37 | 38 | impl Display for HandshakeError { 39 | fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { 40 | use HandshakeError::*; 41 | match self { 42 | // http error 43 | HttpVersion => write!(f, "Illegal http version"), 44 | 45 | HttpMethod => write!(f, "Illegal http method"), 46 | 47 | HttpSatusCode => write!(f, "Illegal http status code"), 48 | 49 | HttpHost => write!(f, "Missing http host header"), 50 | 51 | // websocket error 52 | Upgrade => write!(f, "Missing or illegal upgrade header"), 53 | 54 | Connection => write!(f, "Missing or illegal 
connection header"), 55 | 56 | SecWebSocketKey => { 57 | write!(f, "Missing sec-websocket-key header") 58 | } 59 | 60 | SecWebSocketAccept => { 61 | write!(f, "Missing or illegal sec-websocket-accept header") 62 | } 63 | 64 | SecWebSocketVersion => { 65 | write!(f, "Missing or illegal sec-websocket-version") 66 | } 67 | 68 | // other error 69 | NotEnoughData => write!(f, "Not enough data to parse"), 70 | 71 | NotEnoughCapacity => write!(f, "Not enough space to write to"), 72 | 73 | Httparse(e) => write!(f, "Http parse error: {}", e), 74 | 75 | Manual(s) => write!(f, "Manual error: {}", s), 76 | } 77 | } 78 | } 79 | 80 | impl From for HandshakeError { 81 | fn from(e: httparse::Error) -> Self { HandshakeError::Httparse(e) } 82 | } 83 | 84 | impl std::error::Error for HandshakeError { 85 | fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { 86 | if let HandshakeError::Httparse(e) = self { 87 | Some(e) 88 | } else { 89 | None 90 | } 91 | } 92 | } 93 | -------------------------------------------------------------------------------- /src/error/mod.rs: -------------------------------------------------------------------------------- 1 | #![allow(missing_docs)] 2 | //! 
Errors 3 | 4 | mod ctrl; 5 | mod frame; 6 | mod handshake; 7 | 8 | pub use ctrl::CtrlError; 9 | pub use frame::FrameError; 10 | pub use handshake::HandshakeError; 11 | 12 | use std::fmt::{Display, Formatter}; 13 | 14 | #[derive(Debug)] 15 | pub enum Error { 16 | Ctrl(CtrlError), 17 | 18 | Frame(FrameError), 19 | 20 | Handshake(HandshakeError), 21 | } 22 | 23 | impl From for Error { 24 | fn from(e: FrameError) -> Self { Error::Frame(e) } 25 | } 26 | 27 | impl From for Error { 28 | fn from(e: HandshakeError) -> Self { Error::Handshake(e) } 29 | } 30 | 31 | impl Display for Error { 32 | fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { 33 | use Error::*; 34 | match self { 35 | Ctrl(e) => write!(f, "Control error: {}", e), 36 | Frame(e) => write!(f, "Frame error: {}", e), 37 | Handshake(e) => write!(f, "Handshake error: {}", e), 38 | } 39 | } 40 | } 41 | 42 | impl std::error::Error for Error { 43 | fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { 44 | use Error::*; 45 | 46 | match self { 47 | Ctrl(e) => Some(e), 48 | Frame(e) => Some(e), 49 | Handshake(e) => Some(e), 50 | } 51 | } 52 | } 53 | 54 | impl From for std::io::Error { 55 | fn from(e: Error) -> Self { 56 | use std::io::{Error, ErrorKind}; 57 | Error::new(ErrorKind::Other, e) 58 | } 59 | } 60 | 61 | impl From for std::io::Error { 62 | fn from(e: CtrlError) -> Self { Error::Ctrl(e).into() } 63 | } 64 | 65 | impl From for std::io::Error { 66 | fn from(e: FrameError) -> Self { Error::Frame(e).into() } 67 | } 68 | 69 | impl From for std::io::Error { 70 | fn from(e: HandshakeError) -> Self { Error::Handshake(e).into() } 71 | } 72 | -------------------------------------------------------------------------------- /src/frame/flag.rs: -------------------------------------------------------------------------------- 1 | //! Fin flag and opcode. 2 | 3 | use crate::error::FrameError; 4 | 5 | /// Fin flag. 
6 | #[derive(Debug, Clone, Copy, PartialEq, Eq)] 7 | pub enum Fin { 8 | /// a byte with its leading bit set 9 | Y = 0x80, 10 | 11 | /// a byte with its leading bit clear 12 | N = 0x00, 13 | } 14 | 15 | /// Frame opcode. 16 | #[derive(Debug, Clone, Copy, PartialEq, Eq)] 17 | pub enum OpCode { 18 | /// denotes a continuation frame, 0x00 19 | Continue = 0x00, 20 | /// denotes a text frame, 0x01 21 | Text = 0x01, 22 | /// denotes a binary frame, 0x02 23 | Binary = 0x02, 24 | 25 | /// denotes a connection close, 0x08 26 | Close = 0x08, 27 | /// denotes a ping, 0x09 28 | Ping = 0x09, 29 | /// denotes a pong, 0x0a 30 | Pong = 0x0a, 31 | } 32 | 33 | impl Fin { 34 | /// Parse from byte. 35 | #[inline] 36 | pub const fn from_flag(b: u8) -> Result { 37 | let fin = match b & 0xf0 { 38 | 0x80 => Fin::Y, 39 | 0x00 => Fin::N, 40 | _ => return Err(FrameError::IllegalFin), 41 | }; 42 | Ok(fin) 43 | } 44 | } 45 | 46 | impl OpCode { 47 | /// Parse from byte. 48 | #[inline] 49 | pub const fn from_flag(b: u8) -> Result { 50 | use OpCode::*; 51 | let opcode = match b & 0x0f { 52 | 0x00 => Continue, 53 | 0x01 => Text, 54 | 0x02 => Binary, 55 | 0x08 => Close, 56 | 0x09 => Ping, 57 | 0x0a => Pong, 58 | _ => return Err(FrameError::IllegalOpCode), 59 | }; 60 | Ok(opcode) 61 | } 62 | } 63 | 64 | #[cfg(test)] 65 | mod test { 66 | use super::*; 67 | 68 | macro_rules! enc_dec { 69 | ($class: ident $(, $v: expr )+ ) => { 70 | $( 71 | let v = $class::from_flag($v).unwrap(); 72 | assert_eq!(v as u8, $v); 73 | )+ 74 | }; 75 | } 76 | 77 | #[test] 78 | fn fin() { 79 | enc_dec!(Fin, 0x00, 0x80); 80 | } 81 | 82 | #[test] 83 | fn opcode() { 84 | enc_dec!(OpCode, 0x00, 0x01, 0x02, 0x08, 0x09, 0x0a); 85 | } 86 | } 87 | -------------------------------------------------------------------------------- /src/frame/length.rs: -------------------------------------------------------------------------------- 1 | //! Payload length. 2 | 3 | /// Payload length. 
4 | /// 5 | /// Could be 7 bits, 7+16 bits, or 7+64 bits. 6 | #[derive(Debug, Clone, Copy, PartialEq, Eq)] 7 | pub enum PayloadLen { 8 | /// 0 - 125 9 | Standard(u8), 10 | /// 126 - 65535 11 | Extended1(u16), 12 | /// over 65536 13 | Extended2(u64), 14 | } 15 | 16 | impl PayloadLen { 17 | /// Parse from number. 18 | #[inline] 19 | pub const fn from_num(n: u64) -> Self { 20 | if n < 126 { 21 | PayloadLen::Standard(n as u8) 22 | } else if n < 65536 { 23 | PayloadLen::Extended1(n as u16) 24 | } else { 25 | PayloadLen::Extended2(n) 26 | } 27 | } 28 | 29 | /// Convert to number. 30 | #[inline] 31 | pub const fn to_num(self) -> u64 { 32 | use PayloadLen::*; 33 | match self { 34 | Standard(v) => v as u64, 35 | Extended1(v) => v as u64, 36 | Extended2(v) => v, 37 | } 38 | } 39 | 40 | /// Read the flag which indicates the kind of length. 41 | /// 42 | /// If extended length is used, the caller should read the next 2 or 8 bytes 43 | /// to get the real length. 44 | #[inline] 45 | pub const fn from_flag(b: u8) -> Self { 46 | match b & 0x7f { 47 | 126 => PayloadLen::Extended1(0), 48 | 127 => PayloadLen::Extended2(0), 49 | b => PayloadLen::Standard(b), 50 | } 51 | } 52 | 53 | /// Generate the flag byte. 54 | /// If `length <= 125`, it represents the real length. 55 | #[inline] 56 | pub const fn to_flag(&self) -> u8 { 57 | use PayloadLen::*; 58 | match self { 59 | Standard(b) => *b, 60 | Extended1(_) => 126, 61 | Extended2(_) => 127, 62 | } 63 | } 64 | 65 | /// Read as 16-bit length. 66 | #[inline] 67 | pub const fn from_byte2(buf: [u8; 2]) -> Self { PayloadLen::Extended1(u16::from_be_bytes(buf)) } 68 | 69 | /// Read as 64-bit length. 70 | #[inline] 71 | pub const fn from_byte8(buf: [u8; 8]) -> Self { PayloadLen::Extended2(u64::from_be_bytes(buf)) } 72 | 73 | /// Get value, as 8-bit length. 74 | #[inline] 75 | pub const fn to_u8(&self) -> u8 { 76 | match self { 77 | PayloadLen::Standard(v) => *v, 78 | _ => unreachable!(), 79 | } 80 | } 81 | 82 | /// Get value, as 16-bit length. 
83 | #[inline] 84 | pub const fn to_u16(&self) -> u16 { 85 | match self { 86 | PayloadLen::Extended1(v) => *v, 87 | _ => unreachable!(), 88 | } 89 | } 90 | 91 | /// Get value, as 64-bit length. 92 | #[inline] 93 | pub const fn to_u64(&self) -> u64 { 94 | match self { 95 | PayloadLen::Extended2(v) => *v, 96 | _ => unreachable!(), 97 | } 98 | } 99 | } 100 | 101 | #[cfg(test)] 102 | mod test { 103 | use super::*; 104 | 105 | #[test] 106 | fn standard() { 107 | for v in 0..=125_u8 { 108 | let a = PayloadLen::from_flag(v); 109 | let b = PayloadLen::from_num(v as u64); 110 | 111 | assert_eq!(a.to_flag(), v); 112 | assert_eq!(a.to_num(), b.to_num()); 113 | } 114 | } 115 | 116 | #[test] 117 | fn extend1() { 118 | for v in 126..=65535_u16 { 119 | let a = PayloadLen::from_num(v as u64); 120 | let b = PayloadLen::from_byte2(v.to_be_bytes()); 121 | 122 | assert_eq!(a.to_flag(), 126_u8); 123 | assert_eq!(a.to_num(), b.to_num()); 124 | } 125 | } 126 | 127 | #[test] 128 | fn extend2() { 129 | for v in 65536..=100000_u64 { 130 | let a = PayloadLen::from_num(v); 131 | let b = PayloadLen::from_byte8(v.to_be_bytes()); 132 | 133 | assert_eq!(a.to_flag(), 127_u8); 134 | assert_eq!(a.to_num(), b.to_num()); 135 | } 136 | } 137 | } 138 | -------------------------------------------------------------------------------- /src/frame/mask.rs: -------------------------------------------------------------------------------- 1 | //! Mask flag and key. 2 | 3 | use crate::error::FrameError; 4 | 5 | /// Payload mask with a 32-bit key. 6 | /// 7 | /// `Mask::Skip` is used by server side to skip unmask 8 | /// if mask key equals 0. 9 | #[derive(Debug, Clone, Copy, PartialEq, Eq)] 10 | pub enum Mask { 11 | Key([u8; 4]), 12 | Skip, 13 | None, 14 | } 15 | 16 | impl Mask { 17 | /// Read the flag which indicates whether mask is used. 
18 | #[inline] 19 | pub const fn from_flag(b: u8) -> Result { 20 | let mask = match b & 0x80 { 21 | 0x80 => Mask::Skip, 22 | 0x00 => Mask::None, 23 | _ => return Err(FrameError::IllegalMask), 24 | }; 25 | Ok(mask) 26 | } 27 | 28 | /// Get the flag byte. 29 | #[inline] 30 | pub const fn to_flag(&self) -> u8 { 31 | use Mask::*; 32 | match self { 33 | Key(_) | Skip => 0x80, 34 | None => 0x00, 35 | } 36 | } 37 | 38 | /// Get inner mask key. 39 | #[inline] 40 | pub const fn to_key(&self) -> [u8; 4] { 41 | use Mask::*; 42 | match self { 43 | Key(k) => *k, 44 | Skip => [0u8; 4], 45 | None => unreachable!(), 46 | } 47 | } 48 | } 49 | 50 | /// Generate a new random mask key. 51 | #[inline] 52 | pub fn new_mask_key() -> [u8; 4] { rand::random::<[u8; 4]>() } 53 | 54 | /// Mask the buffer, byte by byte. 55 | #[inline] 56 | pub fn apply_mask(key: [u8; 4], buf: &mut [u8]) { 57 | for (i, b) in buf.iter_mut().enumerate() { 58 | *b ^= key[i & 0x03]; 59 | } 60 | } 61 | 62 | /// Mask the buffer, 4 bytes at a time. 
63 | #[inline] 64 | pub fn apply_mask4(key: [u8; 4], buf: &mut [u8]) { 65 | let key4 = u32::from_ne_bytes(key); 66 | 67 | let (prefix, middle, suffix) = unsafe { buf.align_to_mut::() }; 68 | 69 | apply_mask(key, prefix); 70 | 71 | let head = prefix.len() & 3; 72 | let key4 = if head > 0 { 73 | if cfg!(target_endian = "big") { 74 | key4.rotate_left(8 * head as u32) 75 | } else { 76 | key4.rotate_right(8 * head as u32) 77 | } 78 | } else { 79 | key4 80 | }; 81 | for b4 in middle.iter_mut() { 82 | *b4 ^= key4; 83 | } 84 | 85 | apply_mask(key4.to_ne_bytes(), suffix); 86 | } 87 | 88 | #[cfg(test)] 89 | mod test { 90 | use super::*; 91 | 92 | #[test] 93 | fn mask_store() { 94 | for v in [0x00, 0x80] { 95 | assert_eq!(Mask::from_flag(v).unwrap().to_flag(), v); 96 | } 97 | } 98 | 99 | #[test] 100 | fn mask_byte() { 101 | let key: [u8; 4] = rand::random(); 102 | let buf: Vec = std::iter::repeat(rand::random::()).take(1024).collect(); 103 | 104 | assert_eq!(buf.len(), 1024); 105 | 106 | let mut buf2 = buf.clone(); 107 | apply_mask(key, &mut buf2); 108 | apply_mask(key, &mut buf2); 109 | 110 | assert_eq!(buf, buf2); 111 | } 112 | 113 | #[test] 114 | fn mask_byte4() { 115 | for i in 0..4096 { 116 | let key: [u8; 4] = rand::random(); 117 | let buf: Vec = std::iter::repeat(rand::random::()).take(i).collect(); 118 | 119 | assert_eq!(buf.len(), i); 120 | 121 | let mut buf2 = buf.clone(); 122 | apply_mask4(key, &mut buf2); 123 | apply_mask4(key, &mut buf2); 124 | 125 | assert_eq!(buf, buf2); 126 | } 127 | } 128 | } 129 | -------------------------------------------------------------------------------- /src/frame/mod.rs: -------------------------------------------------------------------------------- 1 | //! Websocket data frame. 2 | //! 3 | //! [RFC-6455 Section5](https://datatracker.ietf.org/doc/html/rfc6455#section-5) 4 | //! 5 | //! ```text 6 | //! 0 1 2 3 7 | //! 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 8 | //! 
+-+-+-+-+-------+-+-------------+-------------------------------+ 9 | //! |F|R|R|R| opcode|M| Payload len | Extended payload length | 10 | //! |I|S|S|S| (4) |A| (7) | (16/64) | 11 | //! |N|V|V|V| |S| | (if payload len==126/127) | 12 | //! | |1|2|3| |K| | | 13 | //! +-+-+-+-+-------+-+-------------+ - - - - - - - - - - - - - - - + 14 | //! | Extended payload length continued, if payload len == 127 | 15 | //! + - - - - - - - - - - - - - - - +-------------------------------+ 16 | //! | |Masking-key, if MASK set to 1 | 17 | //! +-------------------------------+-------------------------------+ 18 | //! | Masking-key (continued) | Payload Data | 19 | //! +-------------------------------- - - - - - - - - - - - - - - - + 20 | //! : Payload Data continued ... : 21 | //! + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + 22 | //! | Payload Data continued ... | 23 | //! +---------------------------------------------------------------+ 24 | //! ``` 25 | //! 26 | 27 | pub mod flag; 28 | pub mod length; 29 | pub mod mask; 30 | 31 | pub use flag::{Fin, OpCode}; 32 | pub use length::PayloadLen; 33 | pub use mask::{Mask, new_mask_key, apply_mask4}; 34 | 35 | /// Websocket frame head. 36 | #[derive(Debug, Clone, Copy, PartialEq, Eq)] 37 | pub struct FrameHead { 38 | pub fin: Fin, 39 | pub opcode: OpCode, 40 | pub mask: Mask, 41 | pub length: PayloadLen, 42 | } 43 | 44 | use crate::bleed::Writer; 45 | use crate::bleed::{slice, slice_to_array}; 46 | use crate::error::FrameError; 47 | 48 | impl FrameHead { 49 | /// Constructor. 50 | #[inline] 51 | pub const fn new(fin: Fin, opcode: OpCode, mask: Mask, length: PayloadLen) -> Self { 52 | Self { 53 | fin, 54 | opcode, 55 | mask, 56 | length, 57 | } 58 | } 59 | 60 | /// Encode to provided buffer, return the count of written bytes. 61 | /// 62 | /// Caller should ensure there is enough space to write, 63 | /// otherwise a [`FrameError::NotEnoughCapacity`] error will be returned. 
64 | pub fn encode(&self, buf: &mut [u8]) -> Result { 65 | if buf.len() < 2 { 66 | return Err(FrameError::NotEnoughCapacity); 67 | } 68 | 69 | let mut writer = Writer::new(buf); 70 | 71 | macro_rules! writex { 72 | ($dst: expr) => { 73 | if writer.remaining() < $dst.len() { 74 | return Err(FrameError::NotEnoughCapacity); 75 | } else { 76 | unsafe { 77 | writer.write_unchecked($dst); 78 | } 79 | } 80 | }; 81 | } 82 | 83 | // fin, opcode 84 | let b1 = self.fin as u8 | self.opcode as u8; 85 | 86 | // mask, payload length 87 | let b2 = self.mask.to_flag() | self.length.to_flag(); 88 | 89 | writex!(&[b1, b2]); 90 | 91 | // extended payload length 92 | match &self.length { 93 | PayloadLen::Standard(_) => {} 94 | PayloadLen::Extended1(v) => writex!(&v.to_be_bytes()), 95 | PayloadLen::Extended2(v) => writex!(&v.to_be_bytes()), 96 | }; 97 | 98 | // mask key 99 | match &self.mask { 100 | Mask::Key(k) => writex!(k), 101 | Mask::Skip => writex!(&[0u8; 4]), 102 | Mask::None => {} 103 | }; 104 | 105 | Ok(writer.pos()) 106 | } 107 | 108 | /// Unchecked version of [`encode`](Self::encode). 109 | /// 110 | /// # Safety 111 | /// 112 | /// Caller must ensure there is enough space to write, 113 | /// otherwise it is **Undefined Behavior!** 114 | pub unsafe fn encode_unchecked(&self, buf: &mut [u8]) -> usize { 115 | let mut writer = Writer::new(buf); 116 | 117 | macro_rules! 
writex { 118 | ($dst: expr) => {{ 119 | writer.write_unchecked($dst); 120 | }}; 121 | } 122 | 123 | // fin, opcode 124 | let b1 = self.fin as u8 | self.opcode as u8; 125 | 126 | // mask, payload length 127 | let b2 = self.mask.to_flag() | self.length.to_flag(); 128 | 129 | writex!(&[b1, b2]); 130 | 131 | // extended payload length 132 | match &self.length { 133 | PayloadLen::Standard(_) => {} 134 | PayloadLen::Extended1(v) => writex!(&v.to_be_bytes()), 135 | PayloadLen::Extended2(v) => writex!(&v.to_be_bytes()), 136 | }; 137 | 138 | // mask key 139 | match &self.mask { 140 | Mask::Key(k) => writex!(k), 141 | Mask::Skip => writex!(&[0u8; 4]), 142 | Mask::None => {} 143 | }; 144 | 145 | writer.pos() 146 | } 147 | 148 | /// Parse from provided buffer, returns [`FrameHead`] and the count of read bytes. 149 | /// 150 | /// If there is not enough data to parse, a [`FrameError::NotEnoughData`] error 151 | /// will be returned. 152 | pub fn decode(buf: &[u8]) -> Result<(Self, usize), FrameError> { 153 | if buf.len() < 2 { 154 | return Err(FrameError::NotEnoughData); 155 | } 156 | 157 | let mut n: usize = 2; 158 | 159 | // fin, opcode 160 | let b1 = unsafe { *buf.get_unchecked(0) }; 161 | 162 | // mask, payload length 163 | let b2 = unsafe { *buf.get_unchecked(1) }; 164 | 165 | let fin = Fin::from_flag(b1)?; 166 | let opcode = OpCode::from_flag(b1)?; 167 | 168 | let mut mask = Mask::from_flag(b2)?; 169 | let mut length = PayloadLen::from_flag(b2); 170 | 171 | match length { 172 | PayloadLen::Standard(_) => {} 173 | PayloadLen::Extended1(_) => { 174 | if buf.len() - n < 2 { 175 | return Err(FrameError::NotEnoughData); 176 | } 177 | 178 | length = 179 | PayloadLen::from_byte2(unsafe { *slice_to_array::<_, 2>(slice(buf, 2, 4)) }); 180 | 181 | n += 2; 182 | } 183 | PayloadLen::Extended2(_) => { 184 | if buf.len() - n < 8 { 185 | return Err(FrameError::NotEnoughData); 186 | } 187 | 188 | length = 189 | PayloadLen::from_byte8(unsafe { *slice_to_array::<_, 8>(slice(buf, 2, 10)) 
}); 190 | 191 | n += 8; 192 | } 193 | }; 194 | 195 | match mask { 196 | Mask::None => {} 197 | _ => { 198 | if buf.len() - n < 4 { 199 | return Err(FrameError::NotEnoughData); 200 | } 201 | 202 | let key = *unsafe { slice_to_array::<_, 4>(slice(buf, n, n + 4)) }; 203 | 204 | if key.into_iter().all(|b| b == 0) { 205 | mask = Mask::Skip 206 | } else { 207 | mask = Mask::Key(key) 208 | } 209 | 210 | n += 4; 211 | } 212 | } 213 | 214 | Ok(( 215 | FrameHead { 216 | fin, 217 | opcode, 218 | mask, 219 | length, 220 | }, 221 | n, 222 | )) 223 | } 224 | } 225 | 226 | #[cfg(test)] 227 | mod test { 228 | use super::*; 229 | 230 | #[test] 231 | fn frame_head() { 232 | let head = FrameHead { 233 | fin: Fin::Y, 234 | opcode: OpCode::Binary, 235 | mask: Mask::Key(mask::new_mask_key()), 236 | length: PayloadLen::from_num(4096), 237 | }; 238 | 239 | let head2 = FrameHead { 240 | fin: Fin::N, 241 | opcode: OpCode::Binary, 242 | mask: Mask::Key(mask::new_mask_key()), 243 | length: PayloadLen::from_num(64), 244 | }; 245 | 246 | for head in [head, head2] { 247 | let mut buf = vec![0; 1024]; 248 | 249 | let encode_n = head.encode(&mut buf).unwrap(); 250 | 251 | assert!(encode_n + 128 <= buf.len()); 252 | 253 | let (head2, decode_n) = FrameHead::decode(&buf[0..encode_n + 128]).unwrap(); 254 | 255 | assert_eq!(encode_n, decode_n); 256 | assert_eq!(head, head2); 257 | 258 | let mut buf2 = vec![0; 1024]; 259 | let encode_n2 = unsafe { head2.encode_unchecked(&mut buf2) }; 260 | 261 | assert_eq!(encode_n2, encode_n); 262 | assert_eq!(&buf[0..encode_n], &buf2[0..encode_n2]); 263 | } 264 | } 265 | } 266 | -------------------------------------------------------------------------------- /src/handshake/key.rs: -------------------------------------------------------------------------------- 1 | //! Key exchange. 2 | 3 | use super::GUID; 4 | use base64::Engine; 5 | use base64::engine::general_purpose::STANDARD; 6 | use sha1::{Digest, Sha1}; 7 | 8 | /// Generate a new `sec-websocket-key`. 
9 | #[inline] 10 | pub fn new_sec_key() -> [u8; 24] { 11 | let input: [u8; 16] = rand::random(); 12 | let mut output = [0_u8; 24]; 13 | Engine::encode_slice(&STANDARD, input, &mut output).unwrap(); 14 | output 15 | } 16 | 17 | /// Derive `sec-websocket-accept` from `sec-websocket-key`. 18 | #[inline] 19 | pub fn derive_accept_key(sec_key: &[u8]) -> [u8; 28] { 20 | let mut sha1 = Sha1::default(); 21 | sha1.update(sec_key); 22 | sha1.update(GUID); 23 | let input = sha1.finalize(); 24 | let mut output = [0_u8; 28]; 25 | Engine::encode_slice(&STANDARD, input, &mut output).unwrap(); 26 | output 27 | } 28 | 29 | #[cfg(test)] 30 | mod test { 31 | use super::*; 32 | 33 | #[test] 34 | fn generate_sec_key() { 35 | for _ in 0..=1024 { 36 | // should not panic 37 | new_sec_key(); 38 | } 39 | } 40 | 41 | #[test] 42 | fn derive_sec_key() { 43 | assert_eq!( 44 | &derive_accept_key(b"dGhlIHNhbXBsZSBub25jZQ=="), 45 | b"s3pPLMBiTxaQ9kYGzzhZRbK+xOo=" 46 | ); 47 | } 48 | } 49 | -------------------------------------------------------------------------------- /src/handshake/mod.rs: -------------------------------------------------------------------------------- 1 | //! Websocket handshake. 
2 | 3 | pub mod key; 4 | pub mod request; 5 | pub mod response; 6 | 7 | pub use request::Request; 8 | pub use response::Response; 9 | pub use key::{new_sec_key, derive_accept_key}; 10 | 11 | /// 32 12 | pub const MAX_ALLOW_HEADERS: usize = 32; 13 | 14 | /// Empty header with dummy reference 15 | pub const EMPTY_HEADER: HttpHeader = HttpHeader::new(b"", b""); 16 | 17 | /// 258EAFA5-E914-47DA-95CA-C5AB0DC85B11 18 | pub const GUID: &[u8] = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; 19 | 20 | /// GET 21 | pub const HTTP_METHOD: &[u8] = b"GET"; 22 | 23 | /// HTTP/1.1 24 | pub const HTTP_VERSION: &[u8] = b"HTTP/1.1"; 25 | 26 | /// CRLF 27 | pub const HTTP_LINE_BREAK: &[u8] = b"\r\n"; 28 | 29 | /// A colon + one SP is prefered 30 | pub const HTTP_HEADER_SP: &[u8] = b": "; 31 | 32 | /// HTTP/1.1 101 Switching Protocols 33 | pub const HTTP_STATUS_LINE: &[u8] = b"HTTP/1.1 101 Switching Protocols"; 34 | 35 | /// Http header, take two references 36 | #[allow(clippy::len_without_is_empty)] 37 | #[derive(Debug, Copy, Clone, Eq, PartialEq)] 38 | pub struct HttpHeader<'h> { 39 | pub name: &'h [u8], 40 | pub value: &'h [u8], 41 | } 42 | 43 | // compile time computation 44 | trait HeaderHelper { 45 | const SIZE: usize; 46 | } 47 | 48 | impl<'h> HttpHeader<'h> { 49 | /// Constructor, take provided name and value. 50 | #[inline] 51 | pub const fn new(name: &'h [u8], value: &'h [u8]) -> Self { Self { name, value } } 52 | 53 | /// Total number of bytes(name + value + sp). 54 | #[inline] 55 | pub const fn len(&self) -> usize { 56 | self.name.len() + self.value.len() + HTTP_HEADER_SP.len() + HTTP_LINE_BREAK.len() 57 | } 58 | 59 | /// Create [`MAX_ALLOW_HEADERS`] empty headers. 60 | #[inline] 61 | pub const fn new_storage() -> [HttpHeader<'static>; MAX_ALLOW_HEADERS] { 62 | [EMPTY_HEADER; MAX_ALLOW_HEADERS] 63 | } 64 | 65 | /// Create N empty headers. 
66 | #[inline] 67 | pub const fn new_custom_storage() -> [HttpHeader<'static>; N] { 68 | [EMPTY_HEADER; N] 69 | } 70 | } 71 | 72 | impl Default for HttpHeader<'static> { 73 | fn default() -> Self { EMPTY_HEADER } 74 | } 75 | 76 | impl<'h> std::fmt::Display for HttpHeader<'h> { 77 | fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 78 | use std::str::from_utf8_unchecked; 79 | write!( 80 | f, 81 | "{}: {}", 82 | unsafe { from_utf8_unchecked(self.name) }, 83 | unsafe { from_utf8_unchecked(self.value) } 84 | ) 85 | } 86 | } 87 | 88 | macro_rules! header { 89 | ( $( 90 | $(#[$docs: meta])* 91 | ($hdr: ident => $name: expr, $value: expr); 92 | )+ 93 | ) => { 94 | $( 95 | $(#[$docs])* 96 | pub const $hdr: HttpHeader = HttpHeader::new($name, $value); 97 | )+ 98 | }; 99 | ( $( 100 | ($hdr_name: ident => $name: expr); 101 | )+ 102 | ) => { 103 | $( 104 | pub const $hdr_name: &[u8] = $name; 105 | )+ 106 | }; 107 | } 108 | 109 | macro_rules! write_header { 110 | ($w: expr, $hdr: expr) => { 111 | if $w.remaining() < $hdr.len() { 112 | return Err(HandshakeError::NotEnoughCapacity); 113 | } else { 114 | unsafe { 115 | $w.write_unchecked($hdr.name); 116 | $w.write_unchecked(HTTP_HEADER_SP); 117 | $w.write_unchecked($hdr.value); 118 | $w.write_unchecked(HTTP_LINE_BREAK); 119 | } 120 | } 121 | }; 122 | ($w: expr, $name: expr, $value: expr) => { 123 | write_header!($w, HttpHeader::new($name, $value)); 124 | }; 125 | } 126 | 127 | macro_rules! 
handshake_check { 128 | ($hdr: expr, $e: expr) => { 129 | if $hdr.value.is_empty() { 130 | return Err($e); 131 | } 132 | }; 133 | ($hdr: expr, $value: expr, $e: expr) => { 134 | // header value here is case insensitive 135 | // ref: https://datatracker.ietf.org/doc/html/rfc6455#section-4.1 136 | if $hdr.value.is_empty() || !$hdr.value.eq_ignore_ascii_case($value) { 137 | return Err($e); 138 | } 139 | }; 140 | } 141 | 142 | use write_header; 143 | use handshake_check; 144 | 145 | #[inline] 146 | fn filter_header<'h>( 147 | all: &[httparse::Header<'h>], 148 | required: &mut [HttpHeader<'h>], 149 | other: &mut [HttpHeader<'h>], 150 | ) { 151 | let mut other_iter = other.iter_mut(); 152 | for hdr in all.iter() { 153 | let name = hdr.name.as_bytes(); 154 | 155 | if let Some(h) = required 156 | .iter_mut() 157 | .filter(|h| h.value.is_empty()) 158 | .find(|h| h.name.eq_ignore_ascii_case(name)) 159 | { 160 | h.value = hdr.value; 161 | } else { 162 | let other_hdr = other_iter.next().unwrap(); 163 | other_hdr.name = name; 164 | other_hdr.value = hdr.value; 165 | } 166 | } 167 | } 168 | 169 | /// Static http headers 170 | #[allow(unused)] 171 | pub mod static_headers { 172 | use super::HttpHeader; 173 | // header 174 | header!( 175 | /// host: {host} 176 | (HEADER_HOST => b"host", b""); 177 | 178 | /// upgrade: websocket 179 | (HEADER_UPGRADE => b"upgrade", b""); 180 | 181 | /// connection: upgrade 182 | (HEADER_CONNECTION => b"connection", b""); 183 | 184 | /// sec-websocket-key: {key} 185 | (HEADER_SEC_WEBSOCKET_KEY => b"sec-websocket-key", b""); 186 | 187 | /// sec-websocket-accept: {accept} 188 | (HEADER_SEC_WEBSOCKET_ACCEPT => b"sec-websocket-accept", b""); 189 | 190 | /// sec-webSocket-version: 13 191 | (HEADER_SEC_WEBSOCKET_VERSION => b"sec-webSocket-version", b""); 192 | ); 193 | 194 | // header name 195 | header! 
{ 196 | (HEADER_HOST_NAME => b"host"); 197 | 198 | (HEADER_UPGRADE_NAME => b"upgrade"); 199 | 200 | (HEADER_CONNECTION_NAME => b"connection"); 201 | 202 | (HEADER_SEC_WEBSOCKET_KEY_NAME => b"sec-websocket-key"); 203 | 204 | (HEADER_SEC_WEBSOCKET_ACCEPT_NAME => b"sec-websocket-accept"); 205 | 206 | (HEADER_SEC_WEBSOCKET_VERSION_NAME => b"sec-websocket-version"); 207 | } 208 | 209 | // header value 210 | header! { 211 | (HEADER_UPGRADE_VALUE => b"websocket"); 212 | 213 | (HEADER_CONNECTION_VALUE => b"upgrade"); 214 | 215 | (HEADER_SEC_WEBSOCKET_VERSION_VALUE => b"13"); 216 | } 217 | } 218 | 219 | #[cfg(test)] 220 | mod test { 221 | use rand::prelude::*; 222 | 223 | pub const TEMPLATE_HEADERS: &str = "\ 224 | host: www.example.com\r\n\ 225 | upgrade: websocket\r\n\ 226 | connection: upgrade\r\n\ 227 | sec-websocket-key: dGhlIHNhbXBsZSBub25jZQ==\r\n\ 228 | sec-websocket-accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=\r\n\ 229 | sec-websocket-version: 13"; 230 | 231 | pub fn make_headers(count: usize, max_len: usize, headers: &str) -> String { 232 | fn rand_ascii() -> char { 233 | let x: u8 = thread_rng().gen_range(1..=4); 234 | let ch: u8 = match x { 235 | 1 => thread_rng().gen_range(b'0'..=b'9'), 236 | 2 => thread_rng().gen_range(b'A'..=b'Z'), 237 | 3 => thread_rng().gen_range(b'a'..=b'z'), 238 | 4 => b'-', 239 | _ => unreachable!(), 240 | }; 241 | ch as char 242 | } 243 | 244 | fn rand_str(len: usize) -> String { 245 | let mut s = String::new(); 246 | for _ in 0..len { 247 | s.push(rand_ascii()); 248 | } 249 | s 250 | } 251 | 252 | fn make_header(max_len: usize) -> String { 253 | let mut s = String::with_capacity(256); 254 | let name_len: usize = thread_rng().gen_range(1..=max_len); 255 | let value_len: usize = thread_rng().gen_range(1..=max_len); 256 | s.push_str(&format!( 257 | "{}: {}\r\n", 258 | rand_str(name_len), 259 | rand_str(value_len) 260 | )); 261 | s 262 | } 263 | 264 | let mut s = Vec::::with_capacity(256); 265 | for hdr in headers.split("\r\n") { 266 | 
s.push(format!("{}\r\n", hdr)); 267 | } 268 | for _ in 0..count { 269 | s.push(make_header(max_len)); 270 | } 271 | s.shuffle(&mut thread_rng()); 272 | s.concat() 273 | } 274 | } 275 | -------------------------------------------------------------------------------- /src/handshake/request.rs: -------------------------------------------------------------------------------- 1 | //! Client upgrade request. 2 | //! 3 | //! From [RFC-6455 Section 4.1](https://datatracker.ietf.org/doc/html/rfc6455#section-4.1): 4 | //! 5 | //! Once a connection to the server has been established (including a 6 | //! connection via a proxy or over a TLS-encrypted tunnel), the client 7 | //! MUST send an opening handshake to the server. The handshake consists 8 | //! of an HTTP Upgrade request, along with a list of required and 9 | //! optional header fields. 10 | //! 11 | //! Once the client's opening handshake has been sent, the client MUST 12 | //! wait for a response from the server before sending any further data. 13 | //! 14 | //! Example: 15 | //! 16 | //! ```text 17 | //! GET /path HTTP/1.1 18 | //! host: www.example.com 19 | //! upgrade: websocket 20 | //! connection: upgrade 21 | //! sec-websocket-key: dGhlIHNhbXBsZSBub25jZQ== 22 | //! sec-websocket-version: 13 23 | //! ``` 24 | //! 25 | 26 | use super::{HttpHeader, HeaderHelper}; 27 | use super::{write_header, filter_header}; 28 | use super::handshake_check; 29 | use super::MAX_ALLOW_HEADERS; 30 | use super::{HTTP_METHOD, HTTP_VERSION, HTTP_LINE_BREAK, HTTP_HEADER_SP}; 31 | use super::static_headers::*; 32 | 33 | use crate::bleed::Writer; 34 | use crate::error::HandshakeError; 35 | 36 | /// Http request presentation. 
37 | pub struct Request<'h, 'b: 'h, const N: usize = MAX_ALLOW_HEADERS> { 38 | pub path: &'b [u8], 39 | pub host: &'b [u8], 40 | pub sec_key: &'b [u8], 41 | pub other_headers: &'h mut [HttpHeader<'b>], 42 | } 43 | 44 | impl<'h, 'b: 'h, const N: usize> HeaderHelper for Request<'h, 'b, N> { 45 | const SIZE: usize = N; 46 | } 47 | 48 | impl<'h, 'b: 'h> Request<'h, 'b> { 49 | /// Create a new request without extra headers. 50 | /// This is usually used to send a request. 51 | #[inline] 52 | pub const fn new(path: &'b [u8], host: &'b [u8], sec_key: &'b [u8]) -> Self { 53 | Self { 54 | path, 55 | host, 56 | sec_key, 57 | other_headers: &mut [], 58 | } 59 | } 60 | 61 | /// Create a new request with extra headers. 62 | /// This is usually used to send a request. 63 | #[inline] 64 | pub const fn new_with_headers( 65 | path: &'b [u8], 66 | host: &'b [u8], 67 | sec_key: &'b [u8], 68 | other_headers: &'h mut [HttpHeader<'b>], 69 | ) -> Self { 70 | Self { 71 | path, 72 | host, 73 | sec_key, 74 | other_headers, 75 | } 76 | } 77 | 78 | /// Create with user provided headers storage, other fields are left empty. 79 | /// This is usually used to receive a request. 80 | /// 81 | /// The max decode header size is [`MAX_ALLOW_HEADERS`]. 82 | #[inline] 83 | pub const fn new_storage(other_headers: &'h mut [HttpHeader<'b>]) -> Self { 84 | Self { 85 | path: &[], 86 | host: &[], 87 | sec_key: &[], 88 | other_headers, 89 | } 90 | } 91 | } 92 | 93 | impl<'h, 'b: 'h, const N: usize> Request<'h, 'b, N> { 94 | /// Create with user provided headers storage, other fields are left empty. 95 | /// This is usually used to receive a request. 96 | /// 97 | /// The const generic paramater represents the max decode header size. 
98 | #[inline] 99 | pub const fn new_custom_storage(other_headers: &'h mut [HttpHeader<'b>]) -> Self { 100 | Self { 101 | path: &[], 102 | host: &[], 103 | sec_key: &[], 104 | other_headers, 105 | } 106 | } 107 | 108 | /// Encode to a provided buffer, return the number of written bytes. 109 | /// 110 | /// Necessary headers, including `host`, `upgrade`, `connection`, 111 | /// `sec-websocket-key` and `sec-websocket-version` are written to 112 | /// the buffer, then other headers(if any) are written in order. 113 | /// 114 | /// Caller should make sure there is enough space to write, 115 | /// otherwise a [`HandshakeError::NotEnoughCapacity`] error will be returned. 116 | pub fn encode(&self, buf: &mut [u8]) -> Result { 117 | debug_assert!(buf.len() > 80); 118 | 119 | let mut w = Writer::new(buf); 120 | 121 | // GET {path} HTTP/1.1 122 | unsafe { 123 | w.write_unchecked(HTTP_METHOD); 124 | w.write_byte_unchecked(0x20); 125 | w.write_unchecked(self.path); 126 | w.write_byte_unchecked(0x20); 127 | w.write_unchecked(HTTP_VERSION); 128 | w.write_unchecked(HTTP_LINE_BREAK); 129 | } 130 | 131 | // host: {host} 132 | write_header!(w, HEADER_HOST_NAME, self.host); 133 | 134 | // upgrade: websocket 135 | write_header!(w, HEADER_UPGRADE_NAME, HEADER_UPGRADE_VALUE); 136 | 137 | // connection: upgrade 138 | write_header!(w, HEADER_CONNECTION_NAME, HEADER_CONNECTION_VALUE); 139 | 140 | // sec-websocket-key: {sec_key} 141 | write_header!(w, HEADER_SEC_WEBSOCKET_KEY_NAME, self.sec_key); 142 | 143 | // sec-websocket-version: 13 144 | write_header!( 145 | w, 146 | HEADER_SEC_WEBSOCKET_VERSION_NAME, 147 | HEADER_SEC_WEBSOCKET_VERSION_VALUE 148 | ); 149 | 150 | // other headers 151 | for hdr in self.other_headers.iter() { 152 | write_header!(w, hdr) 153 | } 154 | 155 | // finish with CRLF 156 | w.write_or_err(HTTP_LINE_BREAK, || HandshakeError::NotEnoughCapacity)?; 157 | 158 | Ok(w.pos()) 159 | } 160 | 161 | /// Parse from a provided buffer, save the results, and 162 | /// return the 
number of bytes parsed. 163 | /// 164 | /// Necessary headers, including `host`, `upgrade`, `connection`, 165 | /// `sec-websocket-key` and `sec-websocket-version` are parsed and checked, 166 | /// and stored in the struct. Optional headers 167 | /// (like `sec-websocket-protocol`) are stored in `other_headers`. 168 | /// After the parse, `other_headers` will be shrunk to 169 | /// fit the number of stored headers. 170 | /// 171 | /// Caller should make sure there is enough space 172 | /// (default is [`MAX_ALLOW_HEADERS`]) to store headers, 173 | /// which could be specified by the const generic paramater. 174 | /// If the buffer does not contain a complete http request, 175 | /// a [`HandshakeError::NotEnoughData`] error will be returned. 176 | /// If the required headers(mentioned above) do not pass the check 177 | /// (case insensitive), other corresponding errors will be returned. 178 | pub fn decode(&mut self, buf: &'b [u8]) -> Result { 179 | debug_assert!(self.other_headers.len() >= ::SIZE); 180 | 181 | let mut headers = [httparse::EMPTY_HEADER; N]; 182 | let mut request = httparse::Request::new(&mut headers); 183 | 184 | // return value 185 | let decode_n = match request.parse(buf)? 
{ 186 | httparse::Status::Complete(n) => n, 187 | httparse::Status::Partial => return Err(HandshakeError::NotEnoughData), 188 | }; 189 | 190 | // check method 191 | if request.method.unwrap().as_bytes() != HTTP_METHOD { 192 | return Err(HandshakeError::HttpMethod); 193 | } 194 | 195 | // check version, should be HTTP/1.1 196 | // ref: https://docs.rs/httparse/latest/src/httparse/lib.rs.html#581-596 197 | if request.version.unwrap() != 1_u8 { 198 | return Err(HandshakeError::HttpVersion); 199 | } 200 | 201 | // handle headers below 202 | // headers are shrunk to number of inited headers 203 | // ref: https://docs.rs/httparse/latest/src/httparse/lib.rs.html#757-765 204 | let headers = request.headers; 205 | 206 | let mut required_headers = [ 207 | HEADER_HOST, 208 | HEADER_UPGRADE, 209 | HEADER_CONNECTION, 210 | HEADER_SEC_WEBSOCKET_KEY, 211 | HEADER_SEC_WEBSOCKET_VERSION, 212 | ]; 213 | 214 | // filter required headers, save other headers 215 | filter_header(headers, &mut required_headers, self.other_headers); 216 | 217 | let [host_hdr, upgrade_hdr, connection_hdr, sec_key_hdr, sec_version_hdr] = 218 | required_headers; 219 | 220 | // check missing header 221 | if !required_headers.iter().all(|h| !h.value.is_empty()) { 222 | handshake_check!(host_hdr, HandshakeError::HttpHost); 223 | handshake_check!(upgrade_hdr, HandshakeError::Upgrade); 224 | handshake_check!(connection_hdr, HandshakeError::Connection); 225 | handshake_check!(sec_key_hdr, HandshakeError::SecWebSocketKey); 226 | handshake_check!(sec_version_hdr, HandshakeError::SecWebSocketVersion); 227 | } 228 | 229 | // check header value (case insensitive) 230 | // ref: https://datatracker.ietf.org/doc/html/rfc6455#section-4.1 231 | handshake_check!(upgrade_hdr, HEADER_UPGRADE_VALUE, HandshakeError::Upgrade); 232 | 233 | handshake_check!( 234 | connection_hdr, 235 | HEADER_CONNECTION_VALUE, 236 | HandshakeError::Connection 237 | ); 238 | 239 | handshake_check!( 240 | sec_version_hdr, 241 | 
HEADER_SEC_WEBSOCKET_VERSION_VALUE, 242 | HandshakeError::SecWebSocketVersion 243 | ); 244 | 245 | // save ref 246 | self.path = request.path.unwrap().as_bytes(); 247 | self.host = host_hdr.value; 248 | self.sec_key = sec_key_hdr.value; 249 | 250 | // shrink header reference 251 | let other_header_len = headers.len() - required_headers.len(); 252 | 253 | // remove lifetime here, remember that 254 | // &mut other_headers lives longer than &mut self 255 | let other_headers: &'h mut [HttpHeader<'b>] = 256 | unsafe { &mut *(self.other_headers as *mut _) }; 257 | self.other_headers = unsafe { other_headers.get_unchecked_mut(0..other_header_len) }; 258 | 259 | Ok(decode_n) 260 | } 261 | } 262 | 263 | #[cfg(test)] 264 | mod test { 265 | use super::*; 266 | use super::super::HttpHeader; 267 | use super::super::test::{make_headers, TEMPLATE_HEADERS}; 268 | use rand::prelude::*; 269 | 270 | #[test] 271 | fn client_handshake() { 272 | for i in 0..64 { 273 | let hdr_len: usize = thread_rng().gen_range(1..128); 274 | let headers = format!( 275 | "GET / HTTP/1.1\r\n{}\r\n", 276 | make_headers(i, hdr_len, TEMPLATE_HEADERS) 277 | ); 278 | 279 | let mut other_headers = HttpHeader::new_custom_storage::<1024>(); 280 | let mut request = Request::<1024>::new_custom_storage(&mut other_headers); 281 | let decode_n = request.decode(headers.as_bytes()).unwrap(); 282 | 283 | assert_eq!(decode_n, headers.len()); 284 | assert_eq!(request.path, b"/"); 285 | assert_eq!(request.host, b"www.example.com"); 286 | assert_eq!(request.sec_key, b"dGhlIHNhbXBsZSBub25jZQ=="); 287 | 288 | // other headers 289 | macro_rules! 
match_other { 290 | ($name: expr, $value: expr) => {{ 291 | request 292 | .other_headers 293 | .iter() 294 | .find(|hdr| hdr.name == $name && hdr.value == $value) 295 | .unwrap(); 296 | }}; 297 | } 298 | match_other!(b"sec-websocket-accept", b"s3pPLMBiTxaQ9kYGzzhZRbK+xOo="); 299 | 300 | let mut buf: Vec = vec![0; 0x4000]; 301 | let encode_n = request.encode(&mut buf).unwrap(); 302 | assert_eq!(encode_n, decode_n); 303 | } 304 | } 305 | 306 | #[test] 307 | fn client_handshake2() { 308 | macro_rules! run { 309 | ($host: expr, $path: expr, $sec_key: expr) => {{ 310 | let headers = format!( 311 | "GET {1} HTTP/1.1\r\n{0}\r\n", 312 | make_headers( 313 | 16, 314 | 32, 315 | &format!( 316 | "host: {0}\r\n\ 317 | sec-websocket-key: {1}\r\n\ 318 | upgrade: websocket\r\n\ 319 | connection: upgrade\r\n\ 320 | sec-websocket-version: 13", 321 | $host, $sec_key 322 | ) 323 | ), 324 | $path 325 | ); 326 | 327 | let mut other_headers = HttpHeader::new_storage(); 328 | let mut request = Request::new_storage(&mut other_headers); 329 | let decode_n = request.decode(headers.as_bytes()).unwrap(); 330 | assert_eq!(decode_n, headers.len()); 331 | assert_eq!(request.host, $host.as_bytes()); 332 | assert_eq!(request.path, $path.as_bytes()); 333 | assert_eq!(request.sec_key, $sec_key.as_bytes()); 334 | 335 | let mut buf: Vec = vec![0; 0x4000]; 336 | let encode_n = request.encode(&mut buf).unwrap(); 337 | assert_eq!(encode_n, decode_n); 338 | }}; 339 | } 340 | 341 | run!("host", "/path", "key"); 342 | run!("www.abc.com", "/path/to", "xxxxxx"); 343 | run!("wwww.www.ww.w", "/path/to/to/path", "xxxxxxyyyy"); 344 | } 345 | 346 | // catch errors ... 347 | } 348 | -------------------------------------------------------------------------------- /src/handshake/response.rs: -------------------------------------------------------------------------------- 1 | //! Server upgrade response. 2 | //! 3 | //! From [RFC-6455 Section 4.2](https://datatracker.ietf.org/doc/html/rfc6455#section-4.2): 4 | //! 
5 | //! When a client starts a WebSocket connection, it sends its part of the 6 | //! opening handshake. The server must parse at least part of this 7 | //! handshake in order to obtain the necessary information to generate 8 | //! the server part of the handshake. 9 | //! 10 | //! If the server chooses to accept the incoming connection, it MUST 11 | //! reply with a valid HTTP response. 12 | //! 13 | //! Example: 14 | //! 15 | //! ```text 16 | //! HTTP/1.1 101 Switching Protocols 17 | //! upgrade: websocket 18 | //! connection: upgrade 19 | //! sec-websocket-accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo= 20 | //! ``` 21 | //! 22 | 23 | use super::{HttpHeader, HeaderHelper}; 24 | use super::{write_header, filter_header}; 25 | use super::handshake_check; 26 | use super::MAX_ALLOW_HEADERS; 27 | use super::{HTTP_STATUS_LINE, HTTP_LINE_BREAK, HTTP_HEADER_SP}; 28 | use super::static_headers::*; 29 | 30 | use crate::bleed::Writer; 31 | use crate::error::HandshakeError; 32 | 33 | /// Http response presentation. 34 | pub struct Response<'h, 'b: 'h, const N: usize = MAX_ALLOW_HEADERS> { 35 | pub sec_accept: &'b [u8], 36 | pub other_headers: &'h mut [HttpHeader<'b>], 37 | } 38 | 39 | impl<'h, 'b: 'h, const N: usize> HeaderHelper for Response<'h, 'b, N> { 40 | const SIZE: usize = N; 41 | } 42 | 43 | impl<'h, 'b: 'h> Response<'h, 'b> { 44 | /// Create a new response without extra headers. 45 | /// This is usually used to send a response. 46 | #[inline] 47 | pub const fn new(sec_accept: &'b [u8]) -> Self { 48 | Self { 49 | sec_accept, 50 | other_headers: &mut [], 51 | } 52 | } 53 | 54 | /// Create a new response with extra headers. 55 | /// This is usually used to send a response. 56 | #[inline] 57 | pub const fn new_with_headers( 58 | sec_accept: &'b [u8], 59 | other_headers: &'h mut [HttpHeader<'b>], 60 | ) -> Self { 61 | Self { 62 | sec_accept, 63 | other_headers, 64 | } 65 | } 66 | 67 | /// Create with user provided headers storage, other fields are left empty. 
68 | /// This is usually used to receive a response. 69 | /// 70 | /// The max decode header size is [`MAX_ALLOW_HEADERS`]. 71 | #[inline] 72 | pub const fn new_storage(other_headers: &'h mut [HttpHeader<'b>]) -> Self { 73 | Self { 74 | sec_accept: &[], 75 | other_headers, 76 | } 77 | } 78 | } 79 | 80 | impl<'h, 'b: 'h, const N: usize> Response<'h, 'b, N> { 81 | /// Create with user provided headers storage, other fields are left empty. 82 | /// This is usually used to receive a response. 83 | /// 84 | /// The const generic paramater represents the max decode header size. 85 | #[inline] 86 | pub const fn new_custom_storage(other_headers: &'h mut [HttpHeader<'b>]) -> Self { 87 | Self { 88 | sec_accept: &[], 89 | other_headers, 90 | } 91 | } 92 | 93 | /// Encode to a provided buffer, return the number of written bytes. 94 | /// 95 | /// Necessary headers, including `upgrade`, `connection`, and 96 | /// `sec-websocket-accept` are written to the buffer, 97 | /// then other headers(if any) are written in order. 98 | /// 99 | /// Caller should make sure there is enough space to write, 100 | /// otherwise a [`HandshakeError::NotEnoughCapacity`] error will be returned. 
101 | pub fn encode(&self, buf: &mut [u8]) -> Result { 102 | debug_assert!(buf.len() > 80); 103 | 104 | let mut w = Writer::new(buf); 105 | 106 | // HTTP/1.1 101 Switching Protocols 107 | unsafe { 108 | w.write_unchecked(HTTP_STATUS_LINE); 109 | w.write_unchecked(HTTP_LINE_BREAK); 110 | } 111 | 112 | // upgrade: websocket 113 | write_header!(w, HEADER_UPGRADE_NAME, HEADER_UPGRADE_VALUE); 114 | 115 | // connection: upgrade 116 | write_header!(w, HEADER_CONNECTION_NAME, HEADER_CONNECTION_VALUE); 117 | 118 | // sec-websocket-accept: {sec_accept} 119 | write_header!(w, HEADER_SEC_WEBSOCKET_ACCEPT_NAME, self.sec_accept); 120 | 121 | // other headers 122 | for hdr in self.other_headers.iter() { 123 | write_header!(w, hdr) 124 | } 125 | 126 | // finish with CRLF 127 | w.write_or_err(HTTP_LINE_BREAK, || HandshakeError::NotEnoughCapacity)?; 128 | 129 | Ok(w.pos()) 130 | } 131 | 132 | /// Parse from a provided buffer, save the results, and 133 | /// return the number of bytes parsed. 134 | /// 135 | /// Necessary headers, including `upgrade`, `connection`, and 136 | /// `sec-websocket-version` are parsed and checked, 137 | /// and stored in the struct. Optional headers 138 | /// (like `sec-websocket-protocol`) are stored in `other headers`. 139 | /// After the parse, `other_headers` will be shrunk to 140 | /// fit the number of stored headers. 141 | /// 142 | /// Caller should make sure there is enough space 143 | /// (default is [`MAX_ALLOW_HEADERS`]) to store headers, 144 | /// which could be specified by the const generic paramater. 145 | /// If the buffer does not contain a complete http request, 146 | /// a [`HandshakeError::NotEnoughData`] error will be returned. 147 | /// If the required headers(mentioned above) do not pass the check 148 | /// (case insensitive), other corresponding errors will be returned. 
149 | pub fn decode(&mut self, buf: &'b [u8]) -> Result { 150 | debug_assert!(self.other_headers.len() >= ::SIZE); 151 | 152 | let mut headers = [httparse::EMPTY_HEADER; N]; 153 | let mut response = httparse::Response::new(&mut headers); 154 | 155 | // return value 156 | let decode_n = match response.parse(buf)? { 157 | httparse::Status::Complete(n) => n, 158 | httparse::Status::Partial => return Err(HandshakeError::NotEnoughData), 159 | }; 160 | 161 | // check version, should be HTTP/1.1 162 | // ref: https://docs.rs/httparse/latest/src/httparse/lib.rs.html#581-596 163 | if response.version.unwrap() != 1_u8 { 164 | return Err(HandshakeError::HttpVersion); 165 | } 166 | 167 | // check status code, should be 101 168 | // ref: https://docs.rs/httparse/latest/src/httparse/lib.rs.html#581-596 169 | if response.code.unwrap() != 101_u16 { 170 | return Err(HandshakeError::HttpSatusCode); 171 | } 172 | 173 | // handle headers below 174 | // headers are shrunk to number of inited headers 175 | // ref: https://docs.rs/httparse/latest/src/httparse/lib.rs.html#757-765 176 | let headers = response.headers; 177 | 178 | let mut required_headers = [ 179 | HEADER_UPGRADE, 180 | HEADER_CONNECTION, 181 | HEADER_SEC_WEBSOCKET_ACCEPT, 182 | ]; 183 | 184 | // filter required headers, save other headers 185 | filter_header(headers, &mut required_headers, self.other_headers); 186 | 187 | let [upgrade_hdr, connection_hdr, sec_accept_hdr] = required_headers; 188 | 189 | // check missing header 190 | if !required_headers.iter().all(|h| !h.value.is_empty()) { 191 | handshake_check!(upgrade_hdr, HandshakeError::Upgrade); 192 | handshake_check!(connection_hdr, HandshakeError::Connection); 193 | handshake_check!(sec_accept_hdr, HandshakeError::SecWebSocketAccept); 194 | } 195 | 196 | // check header value (case insensitive) 197 | // ref: https://datatracker.ietf.org/doc/html/rfc6455#section-4.1 198 | handshake_check!(upgrade_hdr, HEADER_UPGRADE_VALUE, HandshakeError::Upgrade); 199 | 200 | 
handshake_check!( 201 | connection_hdr, 202 | HEADER_CONNECTION_VALUE, 203 | HandshakeError::Connection 204 | ); 205 | 206 | // save ref 207 | self.sec_accept = sec_accept_hdr.value; 208 | 209 | // shrink header reference 210 | let other_header_len = headers.len() - required_headers.len(); 211 | 212 | // remove lifetime here, remember that 213 | // &mut other_headers lives longer than &mut self 214 | let other_headers: &'h mut [HttpHeader<'b>] = 215 | unsafe { &mut *(self.other_headers as *mut _) }; 216 | self.other_headers = unsafe { other_headers.get_unchecked_mut(0..other_header_len) }; 217 | 218 | Ok(decode_n) 219 | } 220 | } 221 | 222 | #[cfg(test)] 223 | mod test { 224 | use super::*; 225 | use super::super::HttpHeader; 226 | use super::super::test::{make_headers, TEMPLATE_HEADERS}; 227 | use rand::prelude::*; 228 | 229 | #[test] 230 | fn server_handshake() { 231 | for i in 0..64 { 232 | let hdr_len: usize = thread_rng().gen_range(1..128); 233 | let headers = format!( 234 | "HTTP/1.1 101 Switching Protocols\r\n{}\r\n", 235 | make_headers(i, hdr_len, TEMPLATE_HEADERS) 236 | ); 237 | 238 | let mut other_headers = HttpHeader::new_custom_storage::<1024>(); 239 | let mut response = Response::<1024>::new_custom_storage(&mut other_headers); 240 | let decode_n = response.decode(headers.as_bytes()).unwrap(); 241 | 242 | assert_eq!(decode_n, headers.len()); 243 | assert_eq!(response.sec_accept, b"s3pPLMBiTxaQ9kYGzzhZRbK+xOo="); 244 | 245 | // other headers 246 | macro_rules! 
match_other { 247 | ($name: expr, $value: expr) => {{ 248 | response 249 | .other_headers 250 | .iter() 251 | .find(|hdr| hdr.name == $name && hdr.value == $value) 252 | .unwrap(); 253 | }}; 254 | } 255 | 256 | match_other!(b"host", b"www.example.com"); 257 | match_other!(b"sec-websocket-version", b"13"); 258 | match_other!(b"sec-websocket-key", b"dGhlIHNhbXBsZSBub25jZQ=="); 259 | 260 | let mut buf: Vec = vec![0; 0x4000]; 261 | let encode_n = response.encode(&mut buf).unwrap(); 262 | assert_eq!(encode_n, decode_n); 263 | } 264 | } 265 | 266 | #[test] 267 | fn server_handshake2() { 268 | macro_rules! run { 269 | ($sec_accept: expr) => {{ 270 | let headers = format!( 271 | "HTTP/1.1 101 Switching Protocols\r\n{}\r\n", 272 | make_headers( 273 | 16, 274 | 32, 275 | &format!( 276 | "upgrade: websocket\r\n\ 277 | connection: upgrade\r\n\ 278 | sec-websocket-accept: {}", 279 | $sec_accept 280 | ) 281 | ) 282 | ); 283 | 284 | let mut other_headers = HttpHeader::new_storage(); 285 | let mut response = Response::new_storage(&mut other_headers); 286 | let decode_n = response.decode(headers.as_bytes()).unwrap(); 287 | assert_eq!(decode_n, headers.len()); 288 | assert_eq!(response.sec_accept, $sec_accept.as_bytes()); 289 | 290 | let mut buf: Vec = vec![0; 0x4000]; 291 | let encode_n = response.encode(&mut buf).unwrap(); 292 | assert_eq!(encode_n, decode_n); 293 | }}; 294 | } 295 | 296 | run!("aaa"); 297 | run!("bbbbbbbbbb"); 298 | run!("xxxxxxxxx=="); 299 | } 300 | 301 | // catch errors ... 302 | } 303 | -------------------------------------------------------------------------------- /src/lib.rs: -------------------------------------------------------------------------------- 1 | #![allow(incomplete_features)] 2 | #![allow(clippy::blocks_in_conditions)] 3 | #![feature(read_buf)] 4 | #![feature(core_io_borrowed_buf)] 5 | #![feature(specialization)] 6 | 7 | //! Lightweight websocket implementation for stream transmission. 8 | //! 9 | //! ## Features 10 | //! 11 | //! 
- Avoid heap allocation. 12 | //! - Avoid buffering frame payload. 13 | //! - Use vectored-io if available. 14 | //! - Transparent Read/Write over the underlying IO source. 15 | //! 16 | //! ## High-level API 17 | //! 18 | //! - [`role`] 19 | //! - [`endpoint`] 20 | //! - [`stream`] 21 | //! 22 | //! Std: 23 | //! 24 | //! ```no_run 25 | //! use std::io::{Read, Write}; 26 | //! use std::net::TcpStream; 27 | //! use lightws::role::Client; 28 | //! use lightws::endpoint::Endpoint; 29 | //! fn run_sync() -> std::io::Result<()> { 30 | //! let mut buf = [0u8; 256]; 31 | //! // establish tcp connection 32 | //! let mut tcp = TcpStream::connect("example.com:80")?; 33 | //! // establish ws connection 34 | //! let mut ws = Endpoint::::connect(tcp, &mut buf, "example.com", "/ws")?; 35 | //! // read some data 36 | //! let n = ws.read(&mut buf)?; 37 | //! // write some data 38 | //! let n = ws.write(&buf)?; 39 | //! Ok(()) 40 | //! } 41 | //! ``` 42 | //! 43 | //! Tokio: 44 | //! 45 | //! ```no_run 46 | //! use tokio::net::TcpStream; 47 | //! use tokio::io::{AsyncReadExt, AsyncWriteExt}; 48 | //! use lightws::role::Client; 49 | //! use lightws::endpoint::Endpoint; 50 | //! async fn run_async() -> std::io::Result<()> { 51 | //! let mut buf = [0u8; 256]; 52 | //! // establish tcp connection 53 | //! let mut tcp = TcpStream::connect("example.com:80").await?; 54 | //! // establish ws connection 55 | //! let mut ws = Endpoint::::connect_async(tcp, &mut buf, "example.com", "/ws").await?; 56 | //! // read some data 57 | //! let n = ws.read(&mut buf).await?; 58 | //! // write some data 59 | //! let n = ws.write(&buf).await?; 60 | //! Ok(()) 61 | //! } 62 | //! ``` 63 | //! 64 | //! ## Low-level API 65 | //! 66 | //! - [`frame`] 67 | //! - [`handshake`] 68 | //! 69 | //! Frame: 70 | //! 71 | //! ```no_run 72 | //! use lightws::frame::{FrameHead, Fin, OpCode, PayloadLen, Mask}; 73 | //! { 74 | //! let mut buf = [0u8; 14]; 75 | //! // create a frame head 76 | //! 
let head = FrameHead::new( 77 | //! Fin::N, OpCode::Binary, 78 | //! Mask::None, PayloadLen::from_num(256) 79 | //! ); 80 | //! // encode to buffer 81 | //! let offset = unsafe { 82 | //! head.encode_unchecked(&mut buf) 83 | //! }; 84 | //! // decode from buffer 85 | //! let (head, offset) = FrameHead::decode(&buf).unwrap(); 86 | //! } 87 | //! ``` 88 | //! 89 | //! Handshake: 90 | //! 91 | //! ```no_run 92 | //! use lightws::handshake::{Request, Response, HttpHeader}; 93 | //! { 94 | //! let mut buf = [0u8; 256]; 95 | //! // make a client handshake request 96 | //! let request = Request::new(b"/ws", b"example.com", b"sec-key.."); 97 | //! let offset = request.encode(&mut buf).unwrap(); 98 | //! 99 | //! // parse a server handshake response 100 | //! let mut custom_headers = HttpHeader::new_storage(); 101 | //! let mut response = Response::new_storage(&mut custom_headers); 102 | //! let offset = response.decode(&buf).unwrap(); 103 | //! } 104 | //! ``` 105 | 106 | mod bleed; 107 | 108 | pub mod role; 109 | pub mod error; 110 | pub mod frame; 111 | pub mod stream; 112 | pub mod endpoint; 113 | pub mod handshake; 114 | -------------------------------------------------------------------------------- /src/role/client.rs: -------------------------------------------------------------------------------- 1 | use super::{RoleHelper, ClientRole, AutoMaskClientRole}; 2 | use crate::frame::Mask; 3 | 4 | macro_rules! client_consts { 5 | () => { 6 | const SHORT_FRAME_HEAD_LEN: u8 = 2; 7 | const COMMON_FRAME_HEAD_LEN: u8 = 2 + 2; 8 | const LONG_FRAME_HEAD_LEN: u8 = 2 + 8; 9 | }; 10 | } 11 | 12 | /// Simple client using an empty (fake) mask key. 13 | /// 14 | /// It simply skips masking before writing data.
15 | #[derive(Clone, Copy)] 16 | pub struct Client; 17 | 18 | impl RoleHelper for Client { 19 | client_consts!(); 20 | 21 | #[inline] 22 | fn new() -> Self { Self {} } 23 | 24 | #[inline] 25 | fn mask_key(&self) -> Mask { Mask::Skip } 26 | } 27 | 28 | impl ClientRole for Client {} 29 | 30 | /// Standard client using a random mask key (`new` starts from an all-zero key). 31 | /// 32 | /// With `unsafe_auto_mask_write` feature enabled, it will automatically 33 | /// update its inner mask key and mask payload data before a write. 34 | #[derive(Clone, Copy)] 35 | pub struct StandardClient([u8; 4]); 36 | 37 | impl RoleHelper for StandardClient { 38 | client_consts!(); 39 | 40 | #[inline] 41 | fn new() -> Self { Self([0u8; 4]) } 42 | 43 | #[inline] 44 | fn mask_key(&self) -> Mask { Mask::Key(self.0) } 45 | 46 | #[inline] 47 | fn set_mask_key(&mut self, mask: [u8; 4]) { self.0 = mask; } 48 | } 49 | 50 | impl ClientRole for StandardClient {} 51 | 52 | impl AutoMaskClientRole for StandardClient { 53 | const UPDATE_MASK_KEY: bool = true; 54 | } 55 | 56 | /// Client using a fixed mask key (generated by `crate::frame::new_mask_key` in `new`). 57 | /// 58 | /// With `unsafe_auto_mask_write` feature enabled, it will automatically 59 | /// mask payload data before a write; its inner mask key is not updated. 60 | #[derive(Clone, Copy)] 61 | pub struct FixedMaskClient([u8; 4]); 62 | 63 | impl RoleHelper for FixedMaskClient { 64 | client_consts!(); 65 | 66 | #[inline] 67 | fn new() -> Self { Self(crate::frame::new_mask_key()) } 68 | 69 | #[inline] 70 | fn mask_key(&self) -> Mask { Mask::Key(self.0) } 71 | 72 | #[inline] 73 | fn set_mask_key(&mut self, mask: [u8; 4]) { self.0 = mask; } 74 | } 75 | 76 | impl ClientRole for FixedMaskClient {} 77 | 78 | impl AutoMaskClientRole for FixedMaskClient { 79 | const UPDATE_MASK_KEY: bool = false; 80 | } 81 | -------------------------------------------------------------------------------- /src/role/mod.rs: -------------------------------------------------------------------------------- 1 | //! Markers. 2 | //! 3 | //! 
Markers are used to apply different strategies to clients or servers. 4 | //! 5 | //! For example, `Endpoint::connect` is used to open a connection, 6 | //! and returns `Stream`; `Endpoint` is used to accept 7 | //! a connection and returns `Stream`. 8 | //! 9 | //! Both client and server implement [`RoleHelper`], which indicates frame head length 10 | //! (currently unused), and how to mask payload data. Only client implements [`ClientRole`], 11 | //! and only server implements [`ServerRole`]. 12 | //! 13 | //! Any type that implements these traits will be treated as a `client` or `server`. 14 | 15 | use crate::frame::Mask; 16 | 17 | /// Client or Server marker. 18 | pub trait RoleHelper: Clone + Copy { 19 | const SHORT_FRAME_HEAD_LEN: u8; 20 | const COMMON_FRAME_HEAD_LEN: u8; 21 | const LONG_FRAME_HEAD_LEN: u8; 22 | 23 | fn new() -> Self; 24 | fn mask_key(&self) -> Mask; 25 | // by default this is a no-op (e.g. `Client` and `Server` do not override it) 26 | fn set_mask_key(&mut self, _: [u8; 4]) {} 27 | } 28 | 29 | /// Client marker. 30 | pub trait ClientRole: RoleHelper { 31 | const SHORT_FRAME_HEAD_LEN: u8 = 2; 32 | const COMMON_FRAME_HEAD_LEN: u8 = 2 + 2; 33 | const LONG_FRAME_HEAD_LEN: u8 = 2 + 8; 34 | } 35 | 36 | /// Server marker. 37 | pub trait ServerRole: RoleHelper {} 38 | 39 | /// Client marker for roles whose payload can be masked automatically (see the `unsafe_auto_mask_write` feature). 40 | pub trait AutoMaskClientRole: ClientRole { 41 | const UPDATE_MASK_KEY: bool; 42 | } 43 | 44 | mod server; 45 | mod client; 46 | 47 | pub use server::Server; 48 | pub use client::{Client, StandardClient, FixedMaskClient}; 49 | -------------------------------------------------------------------------------- /src/role/server.rs: -------------------------------------------------------------------------------- 1 | use super::{RoleHelper, ServerRole}; 2 | use crate::frame::Mask; 3 | 4 | /// Standard server.
5 | #[derive(Clone, Copy)] 6 | pub struct Server; 7 | 8 | impl RoleHelper for Server { 9 | const SHORT_FRAME_HEAD_LEN: u8 = 2 + 4; 10 | const COMMON_FRAME_HEAD_LEN: u8 = 2 + 2 + 4; 11 | const LONG_FRAME_HEAD_LEN: u8 = 2 + 8 + 4; 12 | 13 | #[inline] 14 | fn new() -> Self { Self {} } 15 | 16 | /// Server should not mask the payload. 17 | #[inline] 18 | fn mask_key(&self) -> Mask { Mask::None } 19 | } 20 | 21 | impl ServerRole for Server {} 22 | -------------------------------------------------------------------------------- /src/stream/async_read.rs: -------------------------------------------------------------------------------- 1 | use std::io::Result; 2 | use std::pin::Pin; 3 | use std::task::{Poll, Context}; 4 | 5 | use tokio::io::AsyncRead; 6 | use tokio::io::ReadBuf; 7 | 8 | use super::{Stream, RoleHelper, Guarded}; 9 | use super::detail::read_some; 10 | 11 | impl AsyncRead for Stream 12 | where 13 | IO: AsyncRead + Unpin, 14 | Stream: Unpin, 15 | Role: RoleHelper, 16 | { 17 | /// Async version of `Stream::read`. 18 | #[rustfmt::skip] 19 | fn poll_read( 20 | self: Pin<&mut Self>, 21 | cx: &mut Context<'_>, 22 | buf: &mut ReadBuf<'_>, 23 | ) -> Poll> { 24 | read_some(self.get_mut(), |io, buf| { 25 | let mut buf = ReadBuf::new(buf); 26 | Pin::new(io).poll_read(cx, &mut buf) 27 | .map_ok(|_| buf.filled().len()) 28 | }, 29 | buf.initialize_unfilled(), 30 | ).map_ok(|n| buf.advance(n)) 31 | } 32 | } 33 | 34 | impl AsyncRead for Stream 35 | where 36 | IO: AsyncRead + Unpin, 37 | Stream: Unpin, 38 | Role: RoleHelper, 39 | { 40 | /// Async version of `Stream::read`. 41 | /// Continue to read if frame head is not complete.
42 | fn poll_read( 43 | self: Pin<&mut Self>, 44 | cx: &mut Context<'_>, 45 | buf: &mut ReadBuf<'_>, 46 | ) -> Poll> { 47 | let this = self.get_mut(); 48 | 49 | loop { 50 | match read_some( 51 | this, 52 | |io, buf| { 53 | let mut buf = ReadBuf::new(buf); 54 | Pin::new(io) 55 | .poll_read(cx, &mut buf) 56 | .map_ok(|_| buf.filled().len()) 57 | }, 58 | buf.initialize_unfilled(), 59 | ) { 60 | Poll::Ready(Ok(0)) if this.is_read_partial_head() || !this.is_read_end() => { 61 | continue 62 | } 63 | Poll::Ready(Ok(n)) => { 64 | buf.advance(n); 65 | return Poll::Ready(Ok(())); 66 | } 67 | Poll::Ready(Err(e)) => return Poll::Ready(Err(e)), 68 | Poll::Pending => return Poll::Pending, 69 | } 70 | } 71 | } 72 | } 73 | -------------------------------------------------------------------------------- /src/stream/async_write.rs: -------------------------------------------------------------------------------- 1 | use std::io::Result; 2 | use std::pin::Pin; 3 | use std::task::{Poll, Context}; 4 | 5 | use tokio::io::AsyncWrite; 6 | 7 | use super::{Stream, RoleHelper, Guarded}; 8 | use super::detail::write_some; 9 | 10 | impl AsyncWrite for Stream 11 | where 12 | IO: AsyncWrite + Unpin, 13 | Stream: Unpin, 14 | Role: RoleHelper, 15 | { 16 | /// Async version of `Stream::write`. 17 | #[rustfmt::skip] 18 | fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll> { 19 | write_some(self.get_mut(), |io, buf| Pin::new(io).poll_write_vectored(cx, buf), buf) 20 | } 21 | 22 | /// Flush the underlying IO source; the stream itself buffers no payload data. 23 | fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { 24 | Pin::new(&mut self.get_mut().io).poll_flush(cx) 25 | } 26 | 27 | /// Shutdown the underlying IO source.
28 | fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { 29 | Pin::new(&mut self.get_mut().io).poll_shutdown(cx) 30 | } 31 | } 32 | 33 | impl AsyncWrite for Stream 34 | where 35 | IO: AsyncWrite + Unpin, 36 | Stream: Unpin, 37 | Role: RoleHelper, 38 | { 39 | /// Async version of `Stream::write`. 40 | /// Continue to write if frame head is not completely written. An empty `buf` is sent as a single empty frame and then `Ok(0)` is returned, instead of looping forever. 41 | #[rustfmt::skip] 42 | fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll> { 43 | let this = self.get_mut(); 44 | loop { 45 | match write_some(this, |io, buf| Pin::new(io).poll_write_vectored(cx, buf), buf) { 46 | Poll::Ready(Ok(0)) if this.is_write_partial_head() || (!buf.is_empty() && !this.is_write_zero()) => continue, 47 | Poll::Ready(Ok(n)) => return Poll::Ready(Ok(n)), 48 | Poll::Ready(Err(e)) => return Poll::Ready(Err(e)), 49 | Poll::Pending => return Poll::Pending, 50 | } 51 | } 52 | } 53 | 54 | /// Flush the underlying IO source; the stream itself buffers no payload data. 55 | fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { 56 | Pin::new(&mut self.get_mut().io).poll_flush(cx) 57 | } 58 | 59 | /// Shutdown the underlying IO source. 60 | fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { 61 | Pin::new(&mut self.get_mut().io).poll_shutdown(cx) 62 | } 63 | } 64 | -------------------------------------------------------------------------------- /src/stream/ctrl.rs: -------------------------------------------------------------------------------- 1 | use super::Stream; 2 | use super::state::WriteState; 3 | 4 | use crate::frame::Mask; 5 | use crate::role::RoleHelper; 6 | use crate::error::CtrlError; 7 | 8 | impl Stream 9 | where 10 | Role: RoleHelper, 11 | { 12 | /// Get mask for upcoming writes. 13 | #[inline] 14 | pub fn mask_key(&self) -> Mask { self.role.mask_key() } 15 | 16 | /// Set mask for upcoming writes. 17 | /// An attempt to set mask during a write will fail with [`CtrlError::SetMaskInWrite`].
18 | #[inline] 19 | pub fn set_mask_key(&mut self, key: [u8; 4]) -> Result<(), CtrlError> { 20 | // make sure this is a new fresh write 21 | if let WriteState::WriteHead(head) = self.write_state { 22 | if head.is_empty() { 23 | self.role.set_mask_key(key); 24 | return Ok(()); 25 | } 26 | } 27 | Err(CtrlError::SetMaskInWrite) 28 | } 29 | } 30 | -------------------------------------------------------------------------------- /src/stream/detail/mod.rs: -------------------------------------------------------------------------------- 1 | mod read; 2 | mod write; 3 | 4 | pub(super) use read::read_some; 5 | pub(super) use write::write_some; 6 | 7 | #[inline] 8 | fn min_len(buf_len: usize, length: u64) -> usize { 9 | #[cfg(target_pointer_width = "64")] 10 | { 11 | std::cmp::min(buf_len, length as usize) 12 | } 13 | 14 | #[cfg(not(target_pointer_width = "64"))] 15 | { 16 | let next = std::cmp::min(usize::MAX as u64, length) as usize; 17 | std::cmp::min(buf_len, next) 18 | } 19 | } 20 | -------------------------------------------------------------------------------- /src/stream/detail/read.rs: -------------------------------------------------------------------------------- 1 | use std::io::Result; 2 | use std::task::{Poll, ready}; 3 | 4 | use super::min_len; 5 | use super::super::{Stream, RoleHelper}; 6 | use super::super::state::{ReadState, HeadStore}; 7 | 8 | use crate::frame::{FrameHead, Mask, OpCode}; 9 | use crate::frame::mask::apply_mask4; 10 | use crate::error::FrameError; 11 | /* Masking is position-dependent (RFC 6455 section 5.3: payload byte i is XORed with key[i % 4]). After `n` payload bytes of a frame have been unmasked, rotate the key so that the next chunk of the same frame can again be unmasked starting from offset 0. Non-key masks are returned unchanged. */ #[inline] fn rotate_mask(mask: Mask, n: usize) -> Mask { match mask { Mask::Key(mut key) => { key.rotate_left(n % 4); Mask::Key(key) } other => other } } 12 | pub fn read_some( 13 | stream: &mut Stream, 14 | mut read: F, 15 | buf: &mut [u8], 16 | ) -> Poll> 17 | where 18 | F: FnMut(&mut IO, &mut [u8]) -> Poll>, 19 | Role: RoleHelper, 20 | { 21 | debug_assert!(buf.len() >= 14); 22 | 23 | loop { 24 | match stream.read_state { 25 | // always returns 0 26 | ReadState::Eof => return Poll::Ready(Ok(0)), 27 | ReadState::Close => return Poll::Ready(Ok(0)), 28 | // read a new incoming frame 29 | ReadState::ReadHead(head_store) => { 30 | let
head_store_len = head_store.rd_left(); 31 | 32 | // write stored data to user provided buffer 33 | if !head_store.is_empty() { 34 | let (left, _) = buf.split_at_mut(head_store_len); 35 | left.copy_from_slice(head_store.read()); 36 | } 37 | 38 | let read_n = ready!(read(&mut stream.io, &mut buf[head_store_len..]))?; 39 | 40 | // EOF ? 41 | if read_n == 0 { 42 | stream.read_state = ReadState::Eof; 43 | return Poll::Ready(Ok(0)); 44 | } 45 | 46 | stream.read_state = ReadState::ProcessBuf { 47 | beg: 0, 48 | end: read_n + head_store_len, 49 | processed: 0, 50 | } 51 | } 52 | // continue to read data from the same frame 53 | ReadState::ReadData { next, mask } => { 54 | let read_n = ready!(read(&mut stream.io, buf))?; 55 | // EOF ? 56 | if read_n == 0 { 57 | stream.read_state = ReadState::Eof; 58 | return Poll::Ready(Ok(0)); 59 | } 60 | let len = min_len(read_n, next); 61 | // unmask if server receives data from client 62 | // this operation can be skipped if mask key is 0 63 | if let Mask::Key(key) = mask { 64 | apply_mask4(key, &mut buf[..len]) 65 | }; 66 | // read complete ? 67 | if next > read_n as u64 { 68 | // need to read more; rotate the key past the read_n bytes unmasked in this chunk 69 | stream.read_state = ReadState::ReadData { 70 | next: next - read_n as u64, 71 | mask: rotate_mask(mask, read_n), 72 | }; 73 | return Poll::Ready(Ok(read_n)); 74 | } else { 75 | // continue to process 76 | stream.read_state = ReadState::ProcessBuf { 77 | beg: len, 78 | end: read_n, 79 | processed: len, 80 | } 81 | } 82 | } 83 | // continue to read data from a ctrl frame 84 | ReadState::ReadPing { next, mask } => { 85 | let (buf, _) = stream 86 | .heartbeat 87 | .ping_store 88 | .write() 89 | .split_at_mut(next as usize); 90 | let read_n = ready!(read(&mut stream.io, buf))?; 91 | // EOF ?
92 | if read_n == 0 { 93 | stream.read_state = ReadState::Eof; 94 | return Poll::Ready(Ok(0)); 95 | } 96 | // unmask if server receives data from client (only the bytes actually read) 97 | // this operation can be skipped if mask key is 0 98 | if let Mask::Key(key) = mask { 99 | apply_mask4(key, &mut buf[..read_n]); 100 | }; 101 | 102 | stream.heartbeat.ping_store.advance_wr_pos(read_n); 103 | 104 | // read complete ? 105 | if next == read_n as u8 { 106 | stream.heartbeat.is_complete = true; 107 | stream.read_state = ReadState::new(); 108 | } else { 109 | stream.read_state = ReadState::ReadPing { 110 | next: next - read_n as u8, 111 | mask: rotate_mask(mask, read_n), 112 | }; 113 | } 114 | return Poll::Ready(Ok(0)); 115 | } 116 | // handle the read data in user provided buffer 117 | ReadState::ProcessBuf { 118 | mut beg, 119 | end, 120 | mut processed, 121 | } => { 122 | // parse head, fin is ignored 123 | let ( 124 | FrameHead { 125 | opcode, 126 | mask, 127 | length, 128 | .. 129 | }, 130 | parse_n, 131 | ) = match FrameHead::decode(&buf[beg..end]) { 132 | Ok(x) => x, 133 | Err(ref e) if *e == FrameError::NotEnoughData => { 134 | if beg == end { 135 | stream.read_state = ReadState::new(); 136 | } else { 137 | stream.read_state = 138 | ReadState::ReadHead(HeadStore::new_with_data(&buf[beg..end])); 139 | } 140 | return Poll::Ready(Ok(processed)); 141 | } 142 | Err(e) => return Poll::Ready(Err(e.into())), 143 | }; 144 | // point to payload 145 | beg += parse_n; 146 | 147 | // may read a frame without payload 148 | let frame_len = length.to_num(); 149 | let buf_len = end - beg; 150 | let data_len = min_len(buf_len, frame_len); 151 | 152 | match opcode { 153 | // text is not allowed 154 | // we never send a ping, so an incoming pong is rejected as unsupported 155 | OpCode::Text | OpCode::Pong => { 156 | return Poll::Ready(Err(FrameError::UnsupportedOpcode.into())); 157 | } 158 | // ignore fin flag 159 | OpCode::Binary | OpCode::Continue => { 160 | if data_len != 0 { 161 | // unmask payload data from client 162 | if let Mask::Key(key) = mask { 163 |
apply_mask4(key, &mut buf[beg..beg + data_len]); 164 | } 165 | // move forward (SAFETY: processed <= beg and beg + data_len <= end <= buf.len(); ptr::copy allows overlapping ranges) 166 | unsafe { 167 | std::ptr::copy( 168 | buf.as_ptr().add(beg), 169 | buf.as_mut_ptr().add(processed), 170 | data_len, 171 | ); 172 | }; 173 | } 174 | beg += data_len; 175 | processed += data_len; 176 | // need to read more payload; rotate the key past the data_len bytes unmasked here 177 | if frame_len > buf_len as u64 { 178 | stream.read_state = ReadState::ReadData { 179 | next: frame_len - data_len as u64, 180 | mask: rotate_mask(mask, data_len), 181 | }; 182 | return Poll::Ready(Ok(processed)); 183 | } 184 | // continue to process 185 | stream.read_state = ReadState::ProcessBuf { 186 | beg, 187 | end, 188 | processed, 189 | }; 190 | } 191 | OpCode::Ping => { 192 | // a ping frame must not have extended (longer than 125 bytes) payload 193 | if frame_len > 125 { 194 | return Poll::Ready(Err(FrameError::IllegalData.into())); 195 | } 196 | if data_len != 0 { 197 | // unmask payload data from client 198 | if let Mask::Key(key) = mask { 199 | apply_mask4(key, &mut buf[beg..beg + data_len]); 200 | } 201 | // save ping data 202 | stream 203 | .heartbeat 204 | .ping_store 205 | .replace_with_data(&buf[beg..beg + data_len]); 206 | } else { 207 | // no payload 208 | stream.heartbeat.ping_store.reset(); 209 | } 210 | 211 | // processed does not increase; 212 | beg += data_len; 213 | 214 | // need to read more payload; rotate the key past the data_len bytes unmasked here 215 | if frame_len > buf_len as u64 { 216 | stream.heartbeat.is_complete = false; 217 | stream.read_state = ReadState::ReadPing { 218 | next: frame_len as u8 - data_len as u8, 219 | mask: rotate_mask(mask, data_len), 220 | }; 221 | return Poll::Ready(Ok(processed)); 222 | } 223 | // continue to process 224 | stream.heartbeat.is_complete = true; 225 | stream.read_state = ReadState::ProcessBuf { 226 | beg, 227 | end, 228 | processed, 229 | }; 230 | } 231 | OpCode::Close => { 232 | stream.read_state = ReadState::Close; 233 | return Poll::Ready(Ok(processed)); 234 | } 235 | } 236 | } 237 | } 238 | } 239 | } 240 | -------------------------------------------------------------------------------- /src/stream/detail/write.rs: 
-------------------------------------------------------------------------------- 1 | use std::io::Result; 2 | use std::io::IoSlice; 3 | use std::task::{Poll, ready}; 4 | use std::marker::PhantomData; 5 | 6 | use super::min_len; 7 | use super::super::{Stream, RoleHelper}; 8 | use super::super::state::{WriteState, HeadStore}; 9 | 10 | use crate::frame::FrameHead; 11 | use crate::frame::{Fin, OpCode, PayloadLen}; 12 | 13 | pub fn write_some( 14 | stream: &mut Stream, 15 | mut write: F, 16 | buf: &[u8], 17 | ) -> Poll> 18 | where 19 | F: FnMut(&mut IO, &[IoSlice]) -> Poll>, 20 | Role: RoleHelper, 21 | { 22 | match stream.write_state { 23 | // always returns 0 24 | WriteState::WriteZero => Poll::Ready(Ok(0)), 25 | // create a new frame 26 | WriteState::WriteHead(mut head_store) => { 27 | // data frame length depends on provided buffer length 28 | let frame_len = buf.len(); 29 | 30 | if head_store.is_empty() { 31 | // build frame head 32 | // mask payload(this is unsafe) if unsafe_auto_mask_write is activated 33 | WriteFrameHead::::write_data_frame(&mut head_store, &mut stream.role, buf); 34 | } 35 | // frame head(maybe partial) + payload 36 | let iovec = [IoSlice::new(head_store.read()), IoSlice::new(buf)]; 37 | let write_n = ready!(write(&mut stream.io, &iovec))?; 38 | let head_len = head_store.rd_left(); 39 | 40 | // write zero ? 41 | if write_n == 0 { 42 | stream.write_state = WriteState::WriteZero; 43 | return Poll::Ready(Ok(0)); 44 | } 45 | 46 | // frame head is not written completely 47 | if write_n < head_len { 48 | head_store.advance_rd_pos(write_n); 49 | stream.write_state = WriteState::WriteHead(head_store); 50 | return Poll::Ready(Ok(0)); 51 | } 52 | 53 | // frame has been written completely 54 | let write_n = write_n - head_len; 55 | 56 | // all data written ? 
57 | if write_n == frame_len { 58 | stream.write_state = WriteState::new(); 59 | } else { 60 | stream.write_state = WriteState::WriteData((frame_len - write_n) as u64); 61 | } 62 | 63 | Poll::Ready(Ok(write_n)) 64 | } 65 | // continue to write to the same frame 66 | WriteState::WriteData(next) => { 67 | let len = min_len(buf.len(), next); 68 | let write_n = ready!(write(&mut stream.io, &[IoSlice::new(&buf[..len])]))?; 69 | // write zero ? 70 | if write_n == 0 { 71 | stream.write_state = WriteState::WriteZero; 72 | return Poll::Ready(Ok(0)); 73 | } 74 | // all data written ? 75 | if next == write_n as u64 { 76 | stream.write_state = WriteState::new() 77 | } else { 78 | stream.write_state = WriteState::WriteData(next - write_n as u64) 79 | } 80 | Poll::Ready(Ok(write_n)) 81 | } 82 | } 83 | } 84 | 85 | struct WriteFrameHead { 86 | _marker: PhantomData, 87 | } 88 | 89 | trait WriteFrameHeadTrait { 90 | fn write_data_frame(_: &mut HeadStore, _: &mut R, _: &[u8]) {} 91 | } 92 | 93 | // use default impl 94 | impl WriteFrameHeadTrait for WriteFrameHead { 95 | #[inline] 96 | default fn write_data_frame(store: &mut HeadStore, role: &mut Role, buf: &[u8]) { 97 | let head = FrameHead::new( 98 | Fin::Y, 99 | OpCode::Binary, 100 | role.mask_key(), 101 | PayloadLen::from_num(buf.len() as u64), 102 | ); 103 | // The buffer is large enough to accommodate any kind of frame head. 104 | let n = unsafe { head.encode_unchecked(store.as_mut()) }; 105 | store.set_wr_pos(n); 106 | } 107 | } 108 | 109 | cfg_if::cfg_if! 
{ 110 | if #[cfg(feature = "unsafe_auto_mask_write")] { 111 | use crate::role::AutoMaskClientRole; 112 | use crate::bleed::const_cast; 113 | use crate::frame::{Mask, new_mask_key, apply_mask4}; 114 | } 115 | } 116 | 117 | // specialize 118 | #[cfg(feature = "unsafe_auto_mask_write")] 119 | impl WriteFrameHeadTrait for WriteFrameHead { 120 | #[inline] 121 | fn write_data_frame(store: &mut HeadStore, role: &mut Role, buf: &[u8]) { 122 | let key = if Role::UPDATE_MASK_KEY { 123 | let key = new_mask_key(); 124 | role.set_mask_key(key); 125 | key 126 | } else { 127 | role.mask_key().to_key() 128 | }; 129 | 130 | unsafe { 131 | // !! const_cast a immutable reference 132 | let buf = const_cast(buf); 133 | // prevent too aggresive optimizations 134 | let buf = std::hint::black_box(buf); 135 | apply_mask4(key, buf); 136 | } 137 | 138 | // below is the same of default impl 139 | let head = FrameHead::new( 140 | Fin::Y, 141 | OpCode::Binary, 142 | Mask::Key(key), 143 | PayloadLen::from_num(buf.len() as u64), 144 | ); 145 | // The buffer is large enough to accommodate any kind of frame head. 
146 | let n = unsafe { head.encode_unchecked(store.as_mut()) }; 147 | store.set_wr_pos(n); 148 | } 149 | } 150 | 151 | #[cfg(all(test, feature = "unsafe_auto_mask_write"))] 152 | mod test { 153 | use super::*; 154 | use crate::bleed::Store; 155 | use crate::frame::mask::*; 156 | use crate::role::*; 157 | 158 | fn auto_mask(role: &mut R, buf: &[u8]) { 159 | let mut store = Store::new(); 160 | WriteFrameHead::::write_data_frame(&mut store, role, buf) 161 | } 162 | 163 | #[test] 164 | fn auto_mask_active() { 165 | for i in 0..4096 { 166 | let mut buf: Vec = std::iter::repeat(rand::random::()).take(i).collect(); 167 | let buf2 = buf.clone(); 168 | assert_eq!(buf.len(), i); 169 | 170 | let mut role = StandardClient::new(); 171 | 172 | for _ in 0..8 { 173 | auto_mask(&mut role, &buf2); 174 | let key = role.mask_key().to_key(); 175 | apply_mask4(key, &mut buf); 176 | assert_eq!(buf, buf2); 177 | } 178 | } 179 | } 180 | 181 | #[test] 182 | fn auto_mask_active2() { 183 | for i in 0..4096 { 184 | let mut buf: Vec = std::iter::repeat(rand::random::()).take(i).collect(); 185 | let buf2 = buf.clone(); 186 | assert_eq!(buf.len(), i); 187 | 188 | let mut role = FixedMaskClient::new(); 189 | let key = role.mask_key().to_key(); 190 | 191 | for _ in 0..8 { 192 | auto_mask(&mut role, &buf2); 193 | assert_eq!(key, role.mask_key().to_key()); 194 | 195 | apply_mask4(key, &mut buf); 196 | assert_eq!(buf, buf2); 197 | } 198 | } 199 | } 200 | 201 | #[test] 202 | fn auto_mask_inactive() { 203 | for i in 0..4096 { 204 | let buf: Vec = std::iter::repeat(rand::random::()).take(i).collect(); 205 | let buf2 = buf.clone(); 206 | assert_eq!(buf.len(), i); 207 | 208 | let mut client = Client::new(); 209 | let mut server = Server::new(); 210 | 211 | for _ in 0..8 { 212 | auto_mask(&mut client, &buf2); 213 | assert_eq!(buf, buf2); 214 | } 215 | 216 | for _ in 0..8 { 217 | auto_mask(&mut server, &buf2); 218 | assert_eq!(buf, buf2); 219 | } 220 | } 221 | } 222 | } 223 | 
-------------------------------------------------------------------------------- /src/stream/mod.rs: -------------------------------------------------------------------------------- 1 | //! Websocket stream. 2 | //! 3 | //! [`Stream`] is a simple wrapper of the underlying IO source, 4 | //! with small stack buffers to save states. 5 | //! 6 | //! It is transparent to call `Read` or `Write` on Stream: 7 | //! 8 | //! ```ignore 9 | //! { 10 | //! // establish connection, handshake 11 | //! let stream = ... 12 | //! // read some data 13 | //! stream.read(&mut buf)?; 14 | //! // write some data 15 | //! stream.write(&buf)?; 16 | //! } 17 | //! ``` 18 | //! 19 | //! A newly established [`Stream`] is in [`Direct`] (default) mode, where 20 | //! a `Read` or `Write` leads to **at most one** syscall, and 21 | //! an `Ok(0)` will be returned if the frame head is not completely read or written. 22 | //! It can be converted to [`Guarded`] mode with [`Stream::guard`], 23 | //! which wraps `Read` or `Write` in a loop, where `Ok(0)` is handled internally. 24 | //! 25 | //! Stream itself does not buffer any payload data during 26 | //! a `Read` or `Write`, so there is no extra heap allocation. 27 | //! 28 | //! # Masking payload 29 | //! 30 | //! Data read from the stream is automatically unmasked. 31 | //! However, data written to the stream is **NOT** automatically masked, 32 | //! since a `Write` call requires an immutable `&[u8]`. 33 | //! 34 | //! A standard client (e.g. [`StandardClient`](crate::role::StandardClient)) 35 | //! should mask the payload before sending it; 36 | //! a non-standard client (e.g. [`Client`](crate::role::Client)) which holds an empty mask key 37 | //! can simply skip this step. 38 | //! 39 | //! The mask key is prepared by [`ClientRole`](crate::role::ClientRole), 40 | //! which can be set or fetched via [`Stream::set_mask_key`] and [`Stream::mask_key`]. 41 | //! 42 | //! Example: 43 | //! 44 | //! ```no_run 45 | //! use std::io::{Read, Write}; 46 | //!
use std::net::TcpStream; 47 | //! use lightws::role::StandardClient; 48 | //! use lightws::endpoint::Endpoint; 49 | //! use lightws::frame::{new_mask_key, apply_mask4}; 50 | //! fn write_data() -> std::io::Result<()> { 51 | //! let mut buf = [0u8; 256]; 52 | //! let mut tcp = TcpStream::connect("example.com:80")?; 53 | //! let mut ws = Endpoint::::connect(tcp, &mut buf, "example.com", "/ws")?.guard(); 54 | //! 55 | //! // mask data 56 | //! let key = new_mask_key(); 57 | //! apply_mask4(key, &mut buf); 58 | //! 59 | //! // set mask key for next write 60 | //! ws.set_mask_key(key)?; 61 | //! 62 | //! // write all data (`write_all` needs a guarded stream; it panics in `Direct` mode) 63 | //! ws.write_all(&buf)?; 64 | //! Ok(()) 65 | //! } 66 | //! ``` 67 | //! 68 | //! # Automatic masking 69 | //! 70 | //! It is annoying to mask the payload each time before a write, 71 | //! and it will block us from using convenient functions like [`std::io::copy`]. 72 | //! 73 | //! With the `unsafe_auto_mask_write` feature enabled, the provided immutable `&[u8]` will be cast 74 | //! to a mutable `&mut [u8]`, so that payload data can be automatically masked. 75 | //! 76 | //! This feature only has effects on [`AutoMaskClientRole`](crate::role::AutoMaskClientRole), 77 | //! where its inner mask key may be updated (depending on 78 | //! [`AutoMaskClientRole::UPDATE_MASK_KEY`](crate::role::AutoMaskClientRole::UPDATE_MASK_KEY)) 79 | //! and used to mask the payload before each write. 80 | //! Other [`ClientRole`](crate::role::ClientRole) and [`ServerRole`](crate::role::ServerRole) 81 | //! are not affected. Related code lies in `src/stream/detail/write.rs`. 82 | //! 83 | 84 | mod read; 85 | mod write; 86 | 87 | mod ctrl; 88 | mod state; 89 | mod detail; 90 | mod special; 91 | 92 | cfg_if::cfg_if! { 93 | if #[cfg(feature = "async")] { 94 | mod async_read; 95 | mod async_write; 96 | } 97 | } 98 | 99 | use std::marker::PhantomData; 100 | use state::{ReadState, WriteState, HeartBeat}; 101 | use crate::role::RoleHelper; 102 | 103 | /// Direct read or write, where each call leads to at most one syscall.
104 | pub struct Direct {} 105 | 106 | /// Guarded read or write, which retries internally instead of returning `Ok(0)` while a frame head is incomplete. 107 | pub struct Guarded {} 108 | 109 | /// Websocket stream. 110 | /// 111 | /// Depending on `IO`, [`Stream`] implements [`std::io::Read`] and [`std::io::Write`] 112 | /// or [`tokio::io::AsyncRead`] and [`tokio::io::AsyncWrite`]. 113 | /// 114 | /// `Role` decides whether to mask payload data. 115 | /// It is reserved to provide extra information to apply optimizations. 116 | /// 117 | /// See also: `Stream::read`, `Stream::write`. 118 | pub struct Stream { 119 | io: IO, 120 | role: Role, 121 | read_state: ReadState, 122 | write_state: WriteState, 123 | heartbeat: HeartBeat, 124 | __marker: PhantomData, 125 | } 126 | 127 | impl AsRef for Stream { 128 | #[inline] 129 | fn as_ref(&self) -> &IO { &self.io } 130 | } 131 | 132 | impl AsMut for Stream { 133 | #[inline] 134 | fn as_mut(&mut self) -> &mut IO { &mut self.io } 135 | } 136 | 137 | impl std::fmt::Debug for Stream { 138 | fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 139 | f.debug_struct("Stream") 140 | .field("read_state", &self.read_state) 141 | .field("write_state", &self.write_state) 142 | .field("heartbeat", &self.heartbeat) 143 | .finish() 144 | } 145 | } 146 | 147 | impl Stream { 148 | /// Create websocket stream from IO source directly, 149 | /// without a handshake. 150 | #[inline] 151 | pub const fn new(io: IO, role: Role) -> Self { 152 | Stream { 153 | io, 154 | role, 155 | read_state: ReadState::new(), 156 | write_state: WriteState::new(), 157 | heartbeat: HeartBeat::new(), 158 | __marker: PhantomData, 159 | } 160 | } 161 | 162 | /// Convert to a guarded stream.
163 | #[inline] 164 | pub fn guard(self) -> Stream { 165 | Stream { 166 | io: self.io, 167 | role: self.role, 168 | read_state: self.read_state, 169 | write_state: self.write_state, 170 | heartbeat: self.heartbeat, 171 | __marker: PhantomData, 172 | } 173 | } 174 | } 175 | 176 | #[cfg(test)] 177 | mod test { 178 | use super::*; 179 | use std::io::{Read, Write, Result}; 180 | use crate::frame::*; 181 | use crate::role::*; 182 | 183 | pub struct LimitReadWriter { 184 | pub buf: Vec, 185 | pub rlimit: usize, 186 | pub wlimit: usize, 187 | pub cursor: usize, 188 | } 189 | 190 | impl Read for LimitReadWriter { 191 | fn read(&mut self, mut buf: &mut [u8]) -> Result { 192 | let to_read = std::cmp::min(buf.len(), self.rlimit); 193 | let left_data = self.buf.len() - self.cursor; 194 | if left_data == 0 { 195 | return Ok(0); 196 | } 197 | if left_data <= to_read { 198 | buf.write(&self.buf[self.cursor..]).unwrap(); 199 | self.cursor = self.buf.len(); 200 | return Ok(left_data); 201 | } 202 | 203 | buf.write(&self.buf[self.cursor..self.cursor + to_read]) 204 | .unwrap(); 205 | self.cursor += to_read; 206 | Ok(to_read) 207 | } 208 | } 209 | 210 | impl Write for LimitReadWriter { 211 | fn write(&mut self, buf: &[u8]) -> Result { 212 | let len = std::cmp::min(buf.len(), self.wlimit); 213 | self.buf.write(&buf[..len]) 214 | } 215 | 216 | fn flush(&mut self) -> Result<()> { Ok(()) } 217 | } 218 | 219 | pub fn make_head(opcode: OpCode, mask: Mask, len: usize) -> Vec { 220 | let mut tmp = vec![0; 14]; 221 | let head = FrameHead::new(Fin::Y, opcode, mask, PayloadLen::from_num(len as u64)); 222 | 223 | let head_len = head.encode(&mut tmp).unwrap(); 224 | let mut head = Vec::new(); 225 | let write_n = head.write(&tmp[..head_len]).unwrap(); 226 | assert_eq!(write_n, head_len); 227 | head 228 | } 229 | 230 | pub fn make_data(len: usize) -> Vec { 231 | std::iter::repeat(rand::random::()).take(len).collect() 232 | } 233 | 234 | pub fn make_frame(opcode: OpCode, len: usize) -> (Vec, Vec) {
235 | make_frame_with_mask(opcode, R::new().mask_key(), len) 236 | } 237 | 238 | // data is unmasked 239 | pub fn make_frame_with_mask(opcode: OpCode, mask: Mask, len: usize) -> (Vec, Vec) { 240 | let data = make_data(len); 241 | let mut data2 = data.clone(); 242 | 243 | let mut frame = make_head(opcode, mask, len); 244 | let head_len = frame.len(); 245 | 246 | frame.append(&mut data2); 247 | assert_eq!(frame.len(), len + head_len); 248 | 249 | (frame, data) 250 | } 251 | 252 | #[test] 253 | fn read_write_stream() { 254 | fn read_write(rlimit: usize, wlimit: usize, len: usize) { 255 | let io = LimitReadWriter { 256 | buf: Vec::new(), 257 | rlimit, 258 | wlimit, 259 | cursor: 0, 260 | }; 261 | // data written to a client stream should be read as a server stream. 262 | // here we read/write on the same (client/server)stream. 263 | // this is not correct in practice, but our program can still handle it. 264 | let mut stream = Stream::<_, R>::new(io, R::new()); 265 | 266 | let data: Vec = std::iter::repeat(rand::random::()).take(len).collect(); 267 | let mut data2: Vec = Vec::new(); 268 | 269 | let mut buf = vec![0; 0x2000]; 270 | let mut to_write = data.len(); 271 | 272 | while to_write > 0 { 273 | let wbeg = data.len() - to_write; 274 | let n = loop { 275 | let x = stream.write(&data[wbeg..]).unwrap(); 276 | if x != 0 { 277 | break x; 278 | } 279 | }; 280 | 281 | let mut tmp: Vec = Vec::new(); 282 | loop { 283 | // avoid read EOF here 284 | if stream.as_ref().cursor == stream.as_ref().buf.len() { 285 | break; 286 | } 287 | let n = stream.read(&mut buf).unwrap(); 288 | 289 | // if n == 0 && stream.is_read_end() { 290 | // break; 291 | // } 292 | 293 | tmp.write(&buf[..n]).unwrap(); 294 | } 295 | 296 | assert_eq!(tmp.len(), n); 297 | assert_eq!(&data[wbeg..wbeg + n], &tmp); 298 | 299 | to_write -= n; 300 | data2.append(&mut tmp); 301 | } 302 | 303 | assert_eq!(&data, &data2); 304 | } 305 | 306 | for limit in 1..512 { 307 | for len in 1..=256 { 308 | read_write::(limit,
512 - limit, len); 309 | read_write::(limit, 512 - limit, len); 310 | } 311 | } 312 | } 313 | } 314 | -------------------------------------------------------------------------------- /src/stream/read.rs: -------------------------------------------------------------------------------- 1 | use std::io::{Read, Result}; 2 | use std::task::Poll; 3 | 4 | use super::{Stream, RoleHelper, Guarded}; 5 | use super::detail::read_some; 6 | 7 | impl Read for Stream { 8 | /// Read some data from the underlying IO source, 9 | /// returns `Ok(0)` until a complete frame head is present. 10 | /// Caller should ensure the available buffer size is larger 11 | /// than **14** before a read. 12 | /// 13 | /// Reading a control frame (like Ping) returns `Ok(0)`, 14 | /// which could be detected via [`Stream::is_pinged`]. 15 | /// 16 | /// Any read after receiving a `Close` frame or reaching `EOF` 17 | /// will return `Ok(0)`, 18 | /// which could be checked via [`Stream::is_read_end`], 19 | /// [`Stream::is_read_close`], [`Stream::is_read_eof`]. 20 | fn read(&mut self, buf: &mut [u8]) -> Result { 21 | match read_some(self, |io, buf| io.read(buf).into(), buf) { 22 | Poll::Ready(x) => x, 23 | Poll::Pending => unreachable!(), 24 | } 25 | } 26 | 27 | /// **This is NOT supported!** Always panics; convert to a [`Guarded`] stream first. 28 | fn read_to_end(&mut self, _: &mut Vec) -> Result { 29 | panic!("Unsupported"); 30 | } 31 | 32 | /// **This is NOT supported!** Always panics; convert to a [`Guarded`] stream first. 33 | fn read_exact(&mut self, _: &mut [u8]) -> Result<()> { 34 | panic!("Unsupported"); 35 | } 36 | 37 | /// **This is NOT supported!** Always panics; convert to a [`Guarded`] stream first. 38 | fn read_to_string(&mut self, _: &mut String) -> Result { 39 | panic!("Unsupported"); 40 | } 41 | } 42 | 43 | impl Read for Stream { 44 | /// Wrap read in a loop. 45 | /// Keep reading while `Ok(0)` is returned and neither `Close` nor `EOF` has been reached.
46 | fn read(&mut self, buf: &mut [u8]) -> Result { 47 | loop { 48 | match read_some(self, |io, buf| io.read(buf).into(), buf) { 49 | Poll::Ready(Ok(0)) if !self.is_read_end() => { 50 | continue // e.g. incomplete frame head or a consumed control frame (`ReadHead` already implies `!is_read_end()`) 51 | } 52 | Poll::Ready(x) => return x, 53 | Poll::Pending => unreachable!(), 54 | } 55 | } 56 | } 57 | 58 | /// Override the default implementation to keep enough spare capacity 59 | /// for a complete frame head (at most 14 bytes) before every read. 60 | fn read_to_end(&mut self, buf: &mut Vec) -> Result { 61 | use std::io::BorrowedBuf; 62 | use std::io::ErrorKind; 63 | 64 | let start_len = buf.len(); 65 | let start_cap = buf.capacity(); 66 | 67 | let mut initialized = 0; // Extra initialized bytes from previous loop iteration 68 | loop { 69 | if buf.capacity() - buf.len() < 32 { 70 | buf.reserve(32); // keep at least 32 spare bytes, enough for a (max 14-byte) frame head 71 | initialized = 0; } // growing may reallocate, and `Vec` does not promise to preserve spare-capacity bytes 72 | 73 | let mut read_buf: BorrowedBuf<'_> = buf.spare_capacity_mut().into(); 74 | 75 | // SAFETY: These bytes were initialized but not filled in the previous loop (reset to 0 whenever `buf` grows) 76 | unsafe { 77 | read_buf.set_init(initialized); 78 | } 79 | 80 | let mut cursor = read_buf.unfilled(); 81 | match self.read_buf(cursor.reborrow()) { 82 | Ok(()) => {} 83 | Err(e) if e.kind() == ErrorKind::Interrupted => continue, 84 | Err(e) => return Err(e), 85 | } 86 | 87 | if cursor.written() == 0 { 88 | return Ok(buf.len() - start_len); 89 | } 90 | 91 | // store how much was initialized but not filled 92 | initialized = cursor.init_ref().len(); 93 | 94 | // SAFETY: BorrowedBuf's invariants mean this much memory is init 95 | unsafe { 96 | let new_len = read_buf.filled().len() + buf.len(); 97 | buf.set_len(new_len); 98 | } 99 | 100 | if buf.len() == buf.capacity() && buf.capacity() == start_cap { 101 | // The buffer might be an exact fit. 
Let's read into a probe buffer 102 | // and see if it returns `Ok(0)`. If so, we've avoided an 103 | // unnecessary doubling of the capacity.
But if not, append the 104 | // probe buffer to the primary buffer and let its capacity grow. 105 | let mut probe = [0u8; 32]; 106 | 107 | loop { 108 | match self.read(&mut probe) { 109 | Ok(0) => return Ok(buf.len() - start_len), 110 | Ok(n) => { 111 | buf.extend_from_slice(&probe[..n]); 112 | break; 113 | } 114 | Err(ref e) if e.kind() == ErrorKind::Interrupted => continue, 115 | Err(e) => return Err(e), 116 | } 117 | } 118 | } 119 | } 120 | } 121 | } 122 | 123 | #[cfg(test)] 124 | mod test { 125 | use std::io::Read; 126 | use super::*; 127 | use super::super::test::{LimitReadWriter, make_frame}; 128 | use crate::frame::*; 129 | use crate::role::*; 130 | 131 | #[test] 132 | fn read_from_stream() { 133 | fn read(n: usize) { 134 | let (frame, data) = make_frame::(OpCode::Binary, n); 135 | 136 | let mut stream = Stream::new(frame.as_slice(), R2::new()); 137 | 138 | let mut buf = vec![0; n + 14]; 139 | let read_n = stream.read(&mut buf).unwrap(); 140 | 141 | assert_eq!(read_n, n); 142 | assert_eq!(&buf[..n], &data); 143 | } 144 | 145 | for i in 0..=0x2000 { 146 | read::(i); 147 | read::(i); 148 | } 149 | 150 | for i in [65536, 65537, 100000] { 151 | read::(i); 152 | read::(i); 153 | } 154 | } 155 | 156 | #[test] 157 | fn read_from_limit_stream() { 158 | fn read(n: usize, limit: usize) { 159 | let (frame, data) = make_frame::(OpCode::Binary, n); 160 | 161 | let io = LimitReadWriter { 162 | buf: frame, 163 | rlimit: limit, 164 | wlimit: 0, 165 | cursor: 0, 166 | }; 167 | 168 | let mut buf = Vec::new(); 169 | let mut stream = Stream::new(io, R2::new()).guard(); 170 | 171 | let read_n = stream.read_to_end(&mut buf).unwrap(); 172 | 173 | assert_eq!(read_n, n); 174 | assert_eq!(&buf[..n], &data); 175 | } 176 | 177 | for i in 0..=256 { 178 | for limit in 1..=300 { 179 | read::(i, limit); 180 | read::(i, limit); 181 | } 182 | } 183 | 184 | for i in [65536, 65537, 100000] { 185 | for limit in 1..=1024 { 186 | read::(i, limit); 187 | read::(i, limit); 188 | } 189 | } 190 | }
191 | 192 | #[test] 193 | fn read_eof_from_stream() { 194 | fn read() { 195 | let io = LimitReadWriter { 196 | buf: b"EOFFFF:)".to_vec(), 197 | rlimit: 0, 198 | wlimit: 0, 199 | cursor: 0, 200 | }; 201 | let mut stream = Stream::new(io, R::new()); 202 | let mut buf = vec![0; 32]; 203 | let n = stream.read(&mut buf).unwrap(); 204 | assert_eq!(n, 0); 205 | assert!(stream.is_read_end()); 206 | assert!(stream.is_read_eof()); 207 | 208 | let mut stream = stream.guard(); 209 | 210 | let n = stream.read_to_end(&mut buf).unwrap(); 211 | assert_eq!(n, 0); 212 | assert!(stream.is_read_end()); 213 | assert!(stream.is_read_eof()); 214 | } 215 | read::(); 216 | read::(); 217 | } 218 | 219 | #[test] 220 | fn read_close_from_stream() { 221 | fn read(limit: usize) { 222 | let (frame, _) = make_frame::(OpCode::Close, 1); 223 | let io = LimitReadWriter { 224 | buf: frame, 225 | rlimit: limit, 226 | wlimit: 0, 227 | cursor: 0, 228 | }; 229 | 230 | let mut stream = Stream::new(io, R2::new()); 231 | 232 | let mut buf = vec![0; 32]; 233 | 234 | let n = stream.read(&mut buf).unwrap(); 235 | assert_eq!(n, 0); 236 | 237 | let mut stream = stream.guard(); 238 | 239 | let n = stream.read_to_end(&mut buf).unwrap(); 240 | assert_eq!(n, 0); 241 | assert!(stream.is_read_end()); 242 | assert!(stream.is_read_close()); 243 | } 244 | 245 | for i in 1..=32 { 246 | read::(i); 247 | read::(i); 248 | } 249 | } 250 | 251 | #[test] 252 | fn read_ping_from_stream() { 253 | fn read(n: usize, limit: usize) { 254 | let (frame, data) = make_frame::(OpCode::Ping, n); 255 | 256 | let io = LimitReadWriter { 257 | buf: frame, 258 | rlimit: limit, 259 | wlimit: 0, 260 | cursor: 0, 261 | }; 262 | 263 | let mut buf = Vec::new(); 264 | let mut stream = Stream::new(io, R2::new()).guard(); 265 | 266 | let read_n = stream.read_to_end(&mut buf).unwrap(); 267 | 268 | assert_eq!(read_n, 0); 269 | assert_eq!(stream.ping_data(), &data); 270 | } 271 | 272 | for i in 0..=125 { 273 | for limit in 1..=128 { 274 | read::(i,
limit); 275 | read::(i, limit); 276 | } 277 | } 278 | } 279 | 280 | #[test] 281 | fn read_multi_frame_from_stream() { 282 | fn read(n: usize, step: usize, limit: usize) { 283 | let mut len = 0; 284 | let mut frame = Vec::new(); 285 | let mut data = Vec::new(); 286 | 287 | for i in 0..n { 288 | let (mut f, mut d) = make_frame::(OpCode::Binary, step + i * step); 289 | len += d.len(); 290 | frame.append(&mut f); 291 | data.append(&mut d); 292 | assert_eq!(len, (i + 1) * (i + 2) * step / 2); 293 | } 294 | 295 | let (mut close, _) = make_frame::(OpCode::Close, 1); 296 | frame.append(&mut close); 297 | 298 | let io = LimitReadWriter { 299 | buf: frame, 300 | rlimit: limit, 301 | wlimit: 0, 302 | cursor: 0, 303 | }; 304 | 305 | let mut buf = Vec::new(); 306 | let mut stream = Stream::new(io, R2::new()).guard(); 307 | 308 | let read_n = stream.read_to_end(&mut buf).unwrap(); 309 | 310 | assert!(stream.is_read_end()); 311 | assert!(stream.is_read_close()); 312 | assert_eq!(read_n, len); 313 | assert_eq!(&buf[..len], &data); 314 | } 315 | 316 | for n in 1..=20 { 317 | for step in [1, 10, 100, 1000, 10000] { 318 | for limit in [1, 10, 100, 1000, 10000, usize::MAX] { 319 | read::(n, step, limit); 320 | read::(n, step, limit); 321 | } 322 | } 323 | } 324 | } 325 | 326 | #[test] 327 | fn read_multi_ping_from_stream() { 328 | fn read(n: usize, step: usize, limit: usize) { 329 | let mut len = 0; 330 | let mut frame = Vec::new(); 331 | let mut data = Vec::new(); 332 | 333 | for i in 0..n { 334 | let (mut f, d) = make_frame::(OpCode::Ping, step + i * step); 335 | len += d.len(); 336 | frame.append(&mut f); 337 | data = d; 338 | assert_eq!(len, (i + 1) * (i + 2) * step / 2); 339 | } 340 | 341 | let io = LimitReadWriter { 342 | buf: frame, 343 | rlimit: limit, 344 | wlimit: 0, 345 | cursor: 0, 346 | }; 347 | 348 | let mut buf = Vec::new(); 349 | let mut stream = Stream::new(io, R2::new()).guard(); 350 | 351 | let read_n = stream.read_to_end(&mut buf).unwrap(); 352 | 353 |
assert_eq!(read_n, 0); 354 | assert_eq!(stream.ping_data(), &data); 355 | } 356 | 357 | for n in 1..=125 { 358 | for limit in 1..=128 { 359 | read::(n, 1, limit); 360 | read::(n, 1, limit); 361 | } 362 | } 363 | } 364 | } 365 | -------------------------------------------------------------------------------- /src/stream/special.rs: -------------------------------------------------------------------------------- 1 | use super::{Stream, RoleHelper}; 2 | use std::io::Result; 3 | use std::net::TcpStream; 4 | 5 | impl Stream { 6 | /// Creates a new independently owned handle to the underlying IO source. 7 | /// 8 | /// Caution: **states are not shared among instances!** The returned stream starts with fresh read, write and ping state, exactly like one built by [`Stream::new`]. 9 | pub fn try_clone(&self) -> Result { 10 | let io = self.io.try_clone()?; 11 | Ok(Self::new(io, self.role)) 12 | } 13 | } 14 | -------------------------------------------------------------------------------- /src/stream/state.rs: -------------------------------------------------------------------------------- 1 | use super::Stream; 2 | 3 | use crate::frame::Mask; 4 | use crate::bleed::Store; 5 | 6 | /// Store incomplete frame head. 7 | pub(super) type HeadStore = Store<14>; 8 | 9 | /// Store the most recent ping. 10 | pub(super) type PingStore = Store<125>; 11 | 12 | #[derive(Debug)] 13 | pub(super) struct HeartBeat { 14 | pub ping_store: PingStore, 15 | pub is_complete: bool, 16 | } 17 | 18 | impl HeartBeat { 19 | #[inline] 20 | pub const fn new() -> Self { 21 | Self { 22 | ping_store: PingStore::new(), 23 | is_complete: false, 24 | } 25 | } 26 | } 27 | 28 | /// Read state.
29 | #[derive(Debug)] 30 | pub(super) enum ReadState { 31 | ReadHead(HeadStore), 32 | ReadData { 33 | next: u64, 34 | mask: Mask, 35 | }, 36 | ReadPing { 37 | next: u8, 38 | mask: Mask, 39 | }, 40 | ProcessBuf { 41 | beg: usize, 42 | end: usize, 43 | processed: usize, 44 | }, 45 | Eof, 46 | Close, 47 | } 48 | 49 | impl ReadState { 50 | #[inline] 51 | pub const fn new() -> Self { ReadState::ReadHead(Store::new()) } 52 | } 53 | 54 | /// Write state. 55 | #[allow(clippy::enum_variant_names)] 56 | #[derive(Debug)] 57 | pub(super) enum WriteState { 58 | WriteHead(HeadStore), 59 | WriteData(u64), 60 | WriteZero, 61 | } 62 | 63 | impl WriteState { 64 | #[inline] 65 | pub const fn new() -> Self { WriteState::WriteHead(Store::new()) } 66 | } 67 | 68 | /// Check status. 69 | impl Stream { 70 | /// Check if a `Ping` frame is received. 71 | #[inline] 72 | pub const fn is_pinged(&self) -> bool { !self.heartbeat.ping_store.is_empty() } 73 | 74 | /// Check if a `Ping` frame is completely read. 75 | #[inline] 76 | pub const fn is_ping_completed(&self) -> bool { self.heartbeat.is_complete } 77 | 78 | /// Get the most recent ping. 79 | #[inline] 80 | pub const fn ping_data(&self) -> &[u8] { self.heartbeat.ping_store.read() } 81 | 82 | /// Check if `EOF` is reached. 83 | #[inline] 84 | pub const fn is_read_eof(&self) -> bool { matches!(&self.read_state, ReadState::Eof) } 85 | 86 | /// Check if a `Close` frame is received. 87 | #[inline] 88 | pub const fn is_read_close(&self) -> bool { matches!(&self.read_state, ReadState::Close) } 89 | 90 | /// Check if a `Close` frame is received or `EOF` is reached. 91 | #[inline] 92 | pub const fn is_read_end(&self) -> bool { self.is_read_eof() || self.is_read_close() } 93 | 94 | /// Check if a `WriteZero` error occurred. 95 | #[inline] 96 | pub const fn is_write_zero(&self) -> bool { matches!(&self.write_state, WriteState::WriteZero) } 97 | 98 | /// Check if a frame head is partially read. 
99 | #[inline] 100 | pub const fn is_read_partial_head(&self) -> bool { 101 | matches!(&self.read_state, ReadState::ReadHead(..)) 102 | } 103 | 104 | /// Check if frame head is partially written. 105 | #[inline] 106 | pub const fn is_write_partial_head(&self) -> bool { 107 | matches!(&self.write_state, WriteState::WriteHead(..)) 108 | } 109 | } 110 | -------------------------------------------------------------------------------- /src/stream/write.rs: -------------------------------------------------------------------------------- 1 | use std::io::{Write, Result}; 2 | use std::task::Poll; 3 | 4 | use super::{Stream, RoleHelper, Guarded}; 5 | use super::detail::write_some; 6 | 7 | impl Write for Stream { 8 | /// Write some data to the underlying IO source, 9 | /// returns `Ok(0)` until the frame head is completely 10 | /// written. 11 | /// 12 | /// if `WriteZero` occurs, it will also return `Ok(0)`, 13 | /// which could be detected via [`Stream::is_write_zero`]. 14 | /// 15 | /// Frame head will be generated automatically, 16 | /// according to the length of the provided buffer. 17 | /// 18 | /// A standard client should mask payload data before sending it. 19 | fn write(&mut self, buf: &[u8]) -> Result { 20 | match write_some(self, |io, iovec| io.write_vectored(iovec).into(), buf) { 21 | Poll::Ready(x) => x, 22 | Poll::Pending => unreachable!(), 23 | } 24 | } 25 | 26 | /// The writer does not buffer any data, simply flush 27 | /// the underlying IO source. 28 | fn flush(&mut self) -> Result<()> { self.io.flush() } 29 | 30 | /// **This is NOT supported!** 31 | fn write_all(&mut self, _: &[u8]) -> Result<()> { 32 | panic!("Unsupported"); 33 | } 34 | } 35 | 36 | impl Write for Stream { 37 | /// Wrap write in a loop. 38 | /// Continue to write if frame head is not completely written. 
39 | fn write(&mut self, buf: &[u8]) -> Result { 40 | loop { 41 | match write_some(self, |io, iovec| io.write_vectored(iovec).into(), buf) { 42 | Poll::Ready(Ok(0)) if self.is_write_partial_head() || !self.is_write_zero() => { 43 | continue 44 | } 45 | Poll::Ready(x) => return x, 46 | Poll::Pending => unreachable!(), 47 | } 48 | } 49 | } 50 | 51 | /// The writer does not buffer any data, simply flush 52 | /// the underlying IO source. 53 | fn flush(&mut self) -> Result<()> { self.io.flush() } 54 | } 55 | 56 | #[cfg(test)] 57 | mod test { 58 | use super::*; 59 | use super::super::test::*; 60 | use crate::frame::*; 61 | use crate::role::*; 62 | use std::io::Write; 63 | 64 | #[test] 65 | fn write_to_stream() { 66 | fn write(n: usize) { 67 | let (frame, data) = make_frame::(OpCode::Binary, n); 68 | 69 | let io: Vec = Vec::new(); 70 | let mut stream = Stream::new(io, R::new()); 71 | 72 | let write_n = stream.write(&data).unwrap(); 73 | 74 | assert_eq!(write_n, n); 75 | 76 | assert_eq!(stream.as_ref(), &frame); 77 | } 78 | 79 | for i in 1..=0x2000 { 80 | write::(i); 81 | write::(i); 82 | } 83 | 84 | for i in [65536, 65537, 100000] { 85 | write::(i); 86 | write::(i); 87 | } 88 | } 89 | 90 | #[test] 91 | fn write_to_limit_stream() { 92 | fn write(n: usize, limit: usize) { 93 | let (frame, data) = make_frame::(OpCode::Binary, n); 94 | 95 | let io = LimitReadWriter { 96 | buf: Vec::new(), 97 | rlimit: 0, 98 | wlimit: limit, 99 | cursor: 0, 100 | }; 101 | 102 | let mut stream = Stream::new(io, R::new()).guard(); 103 | 104 | stream.write_all(&data).unwrap(); 105 | 106 | assert_eq!(&stream.as_ref().buf, &frame); 107 | } 108 | 109 | for i in 1..=256 { 110 | for limit in 1..=300 { 111 | write::(i, limit); 112 | write::(i, limit); 113 | } 114 | } 115 | 116 | for i in [65536, 65537, 100000] { 117 | for limit in 1..=1024 { 118 | write::(i, limit); 119 | write::(i, limit); 120 | } 121 | } 122 | } 123 | 124 | #[test] 125 | #[cfg(feature = "unsafe_auto_mask_write")] 126 | fn 
write_to_stream_auto_mask_fixed() { 127 | fn write(n: usize) { 128 | let key = new_mask_key(); 129 | 130 | let (mut frame, data) = make_frame_with_mask(OpCode::Binary, Mask::Key(key), n); 131 | 132 | // manually mask frame data 133 | let offset = frame.len() - n; 134 | apply_mask4(key, &mut frame[offset..]); 135 | 136 | let io: Vec = Vec::new(); 137 | let mut stream = Stream::new(io, R::new()); 138 | stream.set_mask_key(key).unwrap(); 139 | 140 | let write_n = stream.write(&data).unwrap(); 141 | 142 | assert_eq!(write_n, n); 143 | 144 | assert_eq!(stream.as_ref(), &frame); 145 | } 146 | for i in 1..=2 { 147 | write::(i); 148 | } 149 | 150 | for i in [65536, 65537, 100000] { 151 | write::(i); 152 | } 153 | } 154 | 155 | #[test] 156 | #[cfg(feature = "unsafe_auto_mask_write")] 157 | fn write_to_limit_stream_auto_mask_fixed() { 158 | fn write(n: usize, limit: usize) { 159 | let key = new_mask_key(); 160 | let (mut frame, data) = make_frame_with_mask(OpCode::Binary, Mask::Key(key), n); 161 | 162 | // manually mask frame data 163 | let offset = frame.len() - n; 164 | apply_mask4(key, &mut frame[offset..]); 165 | 166 | let io = LimitReadWriter { 167 | buf: Vec::new(), 168 | rlimit: 0, 169 | wlimit: limit, 170 | cursor: 0, 171 | }; 172 | 173 | let mut stream = Stream::new(io, R::new()).guard(); 174 | stream.set_mask_key(key).unwrap(); 175 | 176 | stream.write_all(&data).unwrap(); 177 | 178 | assert_eq!(&stream.as_ref().buf, &frame); 179 | } 180 | 181 | for i in 1..=256 { 182 | for limit in 1..=300 { 183 | write::(i, limit); 184 | } 185 | } 186 | 187 | for i in [65536, 65537, 100000] { 188 | for limit in 1..=1024 { 189 | write::(i, limit); 190 | } 191 | } 192 | } 193 | 194 | #[test] 195 | #[cfg(feature = "unsafe_auto_mask_write")] 196 | fn write_to_stream_auto_mask_updated() { 197 | fn write(n: usize) { 198 | let data = make_data(n); 199 | let mut data2 = data.clone(); 200 | 201 | let io: Vec = Vec::new(); 202 | let mut stream = Stream::new(io, R::new()); 203 | 204 | let 
write_n = stream.write(&data).unwrap(); 205 | assert_eq!(write_n, n); 206 | 207 | // manually mask frame data 208 | let key = stream.mask_key().to_key(); 209 | let head = make_head(OpCode::Binary, Mask::Key(key), n); 210 | apply_mask4(key, &mut data2); 211 | 212 | assert_eq!(stream.as_ref()[..head.len()], head); 213 | assert_eq!(stream.as_ref()[head.len()..], data2); 214 | } 215 | for i in 1..=2 { 216 | write::(i); 217 | } 218 | 219 | for i in [65536, 65537, 100000] { 220 | write::(i); 221 | } 222 | } 223 | 224 | #[test] 225 | #[cfg(feature = "unsafe_auto_mask_write")] 226 | fn write_to_limit_stream_auto_mask_updated() { 227 | fn write(n: usize, limit: usize) { 228 | let data = make_data(n); 229 | let mut data2 = data.clone(); 230 | 231 | let io = LimitReadWriter { 232 | buf: Vec::new(), 233 | rlimit: 0, 234 | wlimit: limit, 235 | cursor: 0, 236 | }; 237 | 238 | let mut stream = Stream::new(io, R::new()).guard(); 239 | stream.write_all(&data).unwrap(); 240 | 241 | // manually mask frame data 242 | let key = stream.mask_key().to_key(); 243 | let head = make_head(OpCode::Binary, Mask::Key(key), n); 244 | apply_mask4(key, &mut data2); 245 | 246 | assert_eq!(stream.as_ref().buf[..head.len()], head); 247 | assert_eq!(stream.as_ref().buf[head.len()..], data2); 248 | } 249 | 250 | for i in 1..=256 { 251 | for limit in 1..=300 { 252 | write::(i, limit); 253 | } 254 | } 255 | 256 | for i in [65536, 65537, 100000] { 257 | for limit in 1..=1024 { 258 | write::(i, limit); 259 | } 260 | } 261 | } 262 | } 263 | -------------------------------------------------------------------------------- /tests/async_bidi_copy.rs: -------------------------------------------------------------------------------- 1 | use std::time::Duration; 2 | 3 | use tokio::net::{TcpStream, TcpListener}; 4 | use tokio::io::{AsyncReadExt, AsyncWriteExt}; 5 | 6 | use lightws::endpoint::Endpoint; 7 | use lightws::role::{Client, Server}; 8 | 9 | use log::debug; 10 | 11 | const ADDR1: &str = "127.0.0.1:10000"; 12 
| const ADDR2: &str = "127.0.0.1:20000"; 13 | const HOST: &str = "www.example.com"; 14 | const PATH: &str = "/ws"; 15 | const ECHO_DATA: &[u8] = b"ECHO ECHO ECHO!"; 16 | 17 | macro_rules! gets { 18 | ($b: expr) => { 19 | std::str::from_utf8($b).unwrap() 20 | }; 21 | } 22 | 23 | // addr0(client) <=> addr1(relay) <=> addr2(server) 24 | #[tokio::test] 25 | async fn async_bidi_copy() { 26 | env_logger::init(); 27 | 28 | let lis1 = TcpListener::bind(ADDR1).await.unwrap(); 29 | let lis2 = TcpListener::bind(ADDR2).await.unwrap(); 30 | 31 | let relay = tokio::spawn(async move { 32 | let mut buf = vec![0u8; 1024]; 33 | let (tcp, _) = lis1.accept().await.unwrap(); 34 | debug!("relay: tcp accepted!"); 35 | let ws_local = Endpoint::<_, Server>::accept_async(tcp, &mut buf, HOST, PATH) 36 | .await 37 | .unwrap() 38 | .guard(); 39 | debug!("relay: websocket accepted!"); 40 | 41 | let tcp = TcpStream::connect(ADDR2).await.unwrap(); 42 | debug!("relay: tcp connected!"); 43 | let ws_remote = Endpoint::<_, Client>::connect_async(tcp, &mut buf, HOST, PATH) 44 | .await 45 | .unwrap() 46 | .guard(); 47 | debug!("relay: websocket connected!"); 48 | 49 | // or use tokio::io::copy_bidirectional instead of split + two copy tasks 50 | let (mut ws_local_read, mut ws_local_write) = tokio::io::split(ws_local); 51 | let (mut ws_remote_read, mut ws_remote_write) = tokio::io::split(ws_remote); 52 | 53 | let t1 = tokio::spawn(async move { 54 | let _ = tokio::io::copy(&mut ws_local_read, &mut ws_remote_write).await; 55 | debug!("relay: client close, shutdown"); 56 | let _ = ws_remote_write.shutdown().await; 57 | }); 58 | 59 | let t2 = tokio::spawn(async move { 60 | let _ = tokio::io::copy(&mut ws_remote_read, &mut ws_local_write).await; 61 | debug!("relay: server close"); 62 | }); 63 | 64 | tokio::try_join!(t1, t2).unwrap(); /* propagate a panic from either copy task instead of discarding its JoinError */ 65 | }); 66 | 67 | let server = tokio::spawn(async move { 68 | let mut buf = vec![0u8; 1024]; 69 | let (tcp, _) = lis2.accept().await.unwrap(); 70 | debug!("server: tcp accepted!"); 71 | let mut ws = Endpoint::<_, 
Server>::accept_async(tcp, &mut buf, HOST, PATH) 72 | .await 73 | .unwrap(); 74 | debug!("server: websocket accepted!"); 75 | 76 | loop { 77 | let n = ws.read(&mut buf).await.unwrap(); 78 | if n == 0 && ws.is_read_end() { 79 | debug!("server: close"); 80 | break; 81 | } 82 | debug!("server: echo.."); 83 | let _ = ws.write(&buf[..n]).await.unwrap(); 84 | } 85 | }); 86 | 87 | let client = tokio::spawn(async { 88 | let mut buf = vec![0u8; 1024]; 89 | debug!("client: sleep 500ms.."); 90 | tokio::time::sleep(Duration::from_millis(500)).await; 91 | let tcp = TcpStream::connect(ADDR1).await.unwrap(); 92 | debug!("client: tcp connected!"); 93 | let mut ws = Endpoint::<_, Client>::connect_async(tcp, &mut buf, HOST, PATH) 94 | .await 95 | .unwrap(); 96 | debug!("client: websocket connected!"); 97 | 98 | debug!("client: sleep 500ms.."); 99 | tokio::time::sleep(Duration::from_millis(500)).await; 100 | 101 | for i in 1..=5 { 102 | debug!("client: send[{}]..", i); 103 | let n = ws.write(ECHO_DATA).await.unwrap(); 104 | assert_eq!(n, ECHO_DATA.len()); 105 | 106 | let n = ws.read(&mut buf).await.unwrap(); 107 | debug!("client: receive message: {}", gets!(&buf[..n])); 108 | assert_eq!(n, ECHO_DATA.len()); 109 | assert_eq!(&buf[..n], ECHO_DATA); 110 | } 111 | 112 | debug!("client: close"); 113 | }); 114 | 115 | tokio::try_join!(relay, server, client).unwrap(); /* a panic inside a spawned task (e.g. a failed assert_eq!) only surfaces as a JoinError; unwrap it so the test actually fails */ 116 | } 117 | -------------------------------------------------------------------------------- /tests/async_echo.rs: -------------------------------------------------------------------------------- 1 | use std::time::Duration; 2 | 3 | use tokio::net::{TcpStream, TcpListener}; 4 | use tokio::io::{AsyncReadExt, AsyncWriteExt}; 5 | 6 | use lightws::endpoint::Endpoint; 7 | use lightws::role::{Client, Server}; 8 | 9 | use log::debug; 10 | 11 | const ADDR: &str = "127.0.0.1:10000"; 12 | const HOST: &str = "www.example.com"; 13 | const PATH: &str = "/ws"; 14 | const ECHO_DATA: &[u8] = b"ECHO ECHO ECHO!"; 15 | 16 | macro_rules! 
gets { 17 | ($b: expr) => { 18 | std::str::from_utf8($b).unwrap() 19 | }; 20 | } 21 | 22 | /* The server task echoes every message back; the client sends ECHO_DATA five times and checks each echo. */ #[tokio::test] 23 | async fn async_echo() { 24 | env_logger::init(); 25 | 26 | let lis = TcpListener::bind(ADDR).await.unwrap(); 27 | 28 | let t1 = tokio::spawn(async move { 29 | let mut buf = vec![0u8; 1024]; 30 | let (tcp, _) = lis.accept().await.unwrap(); 31 | debug!("server: tcp accepted!"); 32 | let mut ws = Endpoint::<_, Server>::accept_async(tcp, &mut buf, HOST, PATH) 33 | .await 34 | .unwrap(); 35 | debug!("server: websocket accepted!"); 36 | 37 | loop { 38 | let n = ws.read(&mut buf).await.unwrap(); 39 | if n == 0 && ws.is_read_end() { 40 | debug!("server: close"); 41 | break; 42 | } 43 | debug!("server: echo.."); 44 | let _ = ws.write(&buf[..n]).await.unwrap(); 45 | } 46 | }); 47 | 48 | let t2 = tokio::spawn(async { 49 | let mut buf = vec![0u8; 1024]; 50 | debug!("client: sleep 500ms.."); 51 | tokio::time::sleep(Duration::from_millis(500)).await; 52 | let tcp = TcpStream::connect(ADDR).await.unwrap(); 53 | debug!("client: tcp connected!"); 54 | let mut ws = Endpoint::<_, Client>::connect_async(tcp, &mut buf, HOST, PATH) 55 | .await 56 | .unwrap(); 57 | debug!("client: websocket connected!"); 58 | 59 | for i in 1..=5 { 60 | debug!("client: send[{}]..", i); 61 | let n = ws.write(ECHO_DATA).await.unwrap(); 62 | assert_eq!(n, ECHO_DATA.len()); 63 | 64 | let n = ws.read(&mut buf).await.unwrap(); 65 | debug!("client: receive message: {}", gets!(&buf[..n])); 66 | assert_eq!(n, ECHO_DATA.len()); 67 | assert_eq!(&buf[..n], ECHO_DATA); 68 | } 69 | 70 | debug!("client: close"); 71 | }); 72 | 73 | tokio::try_join!(t1, t2).unwrap(); /* surface spawned-task panics (failed asserts/unwraps) as a test failure */ 74 | } 75 | -------------------------------------------------------------------------------- /tests/async_handshake.rs: -------------------------------------------------------------------------------- 1 | use std::time::Duration; 2 | 3 | use tokio::net::{TcpStream, TcpListener}; 4 | 5 | use lightws::endpoint::Endpoint; 6 | use lightws::role::{Client, 
Server}; 7 | 8 | use log::debug; 9 | 10 | const ADDR: &str = "127.0.0.1:10000"; 11 | const HOST: &str = "www.example.com"; 12 | const PATH: &str = "/ws"; 13 | 14 | /* Only checks that the opening handshake completes on both ends; no messages are exchanged. */ #[tokio::test] 15 | async fn async_handshake() { 16 | env_logger::init(); 17 | 18 | let lis = TcpListener::bind(ADDR).await.unwrap(); 19 | 20 | let t1 = tokio::spawn(async move { 21 | let mut buf = vec![0u8; 1024]; 22 | let (tcp, _) = lis.accept().await.unwrap(); 23 | debug!("server: tcp accepted!"); 24 | let _ = Endpoint::<_, Server>::accept_async(tcp, &mut buf, HOST, PATH) 25 | .await 26 | .unwrap(); 27 | debug!("server: websocket accepted!"); 28 | }); 29 | 30 | let t2 = tokio::spawn(async { 31 | let mut buf = vec![0u8; 1024]; 32 | debug!("client: sleep 500ms.."); 33 | tokio::time::sleep(Duration::from_millis(500)).await; 34 | let tcp = TcpStream::connect(ADDR).await.unwrap(); 35 | debug!("client: tcp connected!"); 36 | let _ = Endpoint::<_, Client>::connect_async(tcp, &mut buf, HOST, PATH) 37 | .await 38 | .unwrap(); 39 | debug!("client: websocket connected!"); 40 | }); 41 | 42 | tokio::try_join!(t1, t2).unwrap(); /* a failed handshake unwrap() inside either task only shows up as a JoinError here */ 43 | } 44 | -------------------------------------------------------------------------------- /tests/async_read_write.rs: -------------------------------------------------------------------------------- 1 | use std::time::Duration; 2 | 3 | use tokio::net::{TcpStream, TcpListener}; 4 | use tokio::io::{AsyncReadExt, AsyncWriteExt}; 5 | 6 | use lightws::endpoint::Endpoint; 7 | use lightws::role::{Client, Server}; 8 | 9 | use log::debug; 10 | 11 | const ADDR: &str = "127.0.0.1:10000"; 12 | const HOST: &str = "www.example.com"; 13 | const PATH: &str = "/ws"; 14 | const PING_DATA: &[u8] = b"PING PING PING!"; 15 | const PONG_DATA: &[u8] = b"PONG PONG PONG!"; 16 | 17 | macro_rules! 
gets { 18 | ($b: expr) => { 19 | std::str::from_utf8($b).unwrap() 20 | }; 21 | } 22 | 23 | /* One round trip: the client sends PING_DATA, the server checks it and replies with PONG_DATA. */ #[tokio::test] 24 | async fn async_read_write() { 25 | env_logger::init(); 26 | 27 | let lis = TcpListener::bind(ADDR).await.unwrap(); 28 | 29 | let t1 = tokio::spawn(async move { 30 | let mut buf = vec![0u8; 1024]; 31 | let (tcp, _) = lis.accept().await.unwrap(); 32 | debug!("server: tcp accepted!"); 33 | let mut ws = Endpoint::<_, Server>::accept_async(tcp, &mut buf, HOST, PATH) 34 | .await 35 | .unwrap(); 36 | debug!("server: websocket accepted!"); 37 | let n = ws.read(&mut buf).await.unwrap(); 38 | debug!("server: receive message: {}", gets!(&buf[..n])); 39 | assert_eq!(n, PING_DATA.len()); 40 | assert_eq!(&buf[..n], PING_DATA); 41 | debug!("server: send.."); 42 | let n = ws.write(PONG_DATA).await.unwrap(); 43 | assert_eq!(n, PONG_DATA.len()); 44 | }); 45 | 46 | let t2 = tokio::spawn(async { 47 | let mut buf = vec![0u8; 1024]; 48 | debug!("client: sleep 500ms.."); 49 | tokio::time::sleep(Duration::from_millis(500)).await; 50 | let tcp = TcpStream::connect(ADDR).await.unwrap(); 51 | debug!("client: tcp connected!"); 52 | let mut ws = Endpoint::<_, Client>::connect_async(tcp, &mut buf, HOST, PATH) 53 | .await 54 | .unwrap(); 55 | debug!("client: websocket connected!"); 56 | 57 | debug!("client: send.."); 58 | let n = ws.write(PING_DATA).await.unwrap(); 59 | assert_eq!(n, PING_DATA.len()); 60 | let n = ws.read(&mut buf).await.unwrap(); 61 | debug!("client: receive message: {}", gets!(&buf[..n])); 62 | assert_eq!(n, PONG_DATA.len()); 63 | assert_eq!(&buf[..n], PONG_DATA); 64 | }); 65 | 66 | tokio::try_join!(t1, t2).unwrap(); /* surface spawned-task panics (failed asserts/unwraps) as a test failure */ 67 | } 68 | -------------------------------------------------------------------------------- /tests/auto_mask.rs: -------------------------------------------------------------------------------- 1 | use std::io::{Read, Write}; 2 | use std::net::{TcpStream, TcpListener}; 3 | use std::time::Duration; 4 | use std::thread; 5 | 6 | use lightws::endpoint::Endpoint;
7 | use lightws::role::{StandardClient, Server}; 8 | 9 | use log::debug; 10 | 11 | const ADDR1: &str = "127.0.0.1:10000"; 12 | const ADDR2: &str = "127.0.0.1:20000"; 13 | const HOST: &str = "www.example.com"; 14 | const PATH: &str = "/ws"; 15 | const ECHO_DATA: &[u8] = b"ECHO ECHO ECHO!"; 16 | 17 | macro_rules! gets { 18 | ($b: expr) => { 19 | std::str::from_utf8($b).unwrap() 20 | }; 21 | } 22 | 23 | // addr0(client) <=> addr1(relay) <=> addr2(server); like sync_bidi_copy, but the relay's upstream side and the client use the StandardClient role 24 | #[test] 25 | fn bidi_copy_auto_mask() { 26 | env_logger::init(); 27 | 28 | let lis1 = TcpListener::bind(ADDR1).unwrap(); 29 | let lis2 = TcpListener::bind(ADDR2).unwrap(); 30 | 31 | let relay = thread::spawn(move || { 32 | let mut buf = vec![0u8; 1024]; 33 | let (tcp, _) = lis1.accept().unwrap(); 34 | debug!("relay: tcp accepted!"); 35 | let ws_local_read = Endpoint::<_, Server>::accept(tcp, &mut buf, HOST, PATH).unwrap(); 36 | debug!("relay: websocket accepted!"); 37 | 38 | let tcp = TcpStream::connect(ADDR2).unwrap(); 39 | debug!("relay: tcp connected!"); 40 | let ws_remote_read = 41 | Endpoint::<_, StandardClient>::connect(tcp, &mut buf, HOST, PATH).unwrap(); 42 | debug!("relay: websocket connected!"); 43 | 44 | /* try_clone() yields a second endpoint (presumably over a cloned TcpStream) so each copy thread below owns one read side and one write side */ let mut ws_local_write = ws_local_read.try_clone().unwrap().guard(); 45 | let mut ws_remote_write = ws_remote_read.try_clone().unwrap().guard(); 46 | 47 | let mut ws_local_read = ws_local_read.guard(); 48 | let mut ws_remote_read = ws_remote_read.guard(); 49 | 50 | let t1 = thread::spawn(move || { 51 | let _ = std::io::copy(&mut ws_local_read, &mut ws_remote_write); 52 | debug!("relay: client close, shutdown"); 53 | /* client-to-server copy finished: shut down the upstream socket so the server's read loop can observe the end */ ws_remote_write 54 | .as_ref() 55 | .shutdown(std::net::Shutdown::Both) 56 | .unwrap(); 57 | }); 58 | 59 | let t2 = thread::spawn(move || { 60 | let _ = std::io::copy(&mut ws_remote_read, &mut ws_local_write); 61 | debug!("relay: server close"); 62 | }); 63 | 64 | t1.join().unwrap(); 65 | t2.join().unwrap(); 66 | }); 67 | 68 | let server = thread::spawn(move || { 69 | let mut buf = 
vec![0u8; 1024]; 70 | let (tcp, _) = lis2.accept().unwrap(); 71 | debug!("server: tcp accepted!"); 72 | let mut ws = Endpoint::<_, Server>::accept(tcp, &mut buf, HOST, PATH).unwrap(); 73 | debug!("server: websocket accepted!"); 74 | 75 | /* the extra is_read_end() check suggests read() may return 0 before the stream has ended; NOTE(review): in that case an empty slice is echoed back */ loop { 76 | let n = ws.read(&mut buf).unwrap(); 77 | if n == 0 && ws.is_read_end() { 78 | debug!("server: close"); 79 | break; 80 | } 81 | debug!("server: echo.."); 82 | let _ = ws.write(&buf[..n]).unwrap(); 83 | } 84 | }); 85 | 86 | let client = thread::spawn(|| { 87 | let mut buf = vec![0u8; 1024]; 88 | debug!("client: sleep 500ms.."); 89 | thread::sleep(Duration::from_millis(500)); /* both listeners are bound before any thread is spawned, so this delay is not needed for connect() to succeed */ 90 | let tcp = TcpStream::connect(ADDR1).unwrap(); 91 | debug!("client: tcp connected!"); 92 | let mut ws = Endpoint::<_, StandardClient>::connect(tcp, &mut buf, HOST, PATH).unwrap(); 93 | debug!("client: websocket connected!"); 94 | 95 | debug!("client: sleep 500ms.."); 96 | thread::sleep(Duration::from_millis(500)); 97 | 98 | for i in 1..=5 { 99 | let echo_data = Vec::from(ECHO_DATA); /* NOTE(review): sends from a fresh owned copy each iteration instead of the const slice; the reason is not visible here */ 100 | debug!("client: send[{}]..", i); 101 | let n = ws.write(&echo_data).unwrap(); 102 | assert_eq!(n, ECHO_DATA.len()); 103 | 104 | let n = ws.read(&mut buf).unwrap(); /* assumes the whole echo arrives in a single read() */ 105 | debug!("client: receive message: {}", gets!(&buf[..n])); 106 | assert_eq!(n, ECHO_DATA.len()); 107 | assert_eq!(&buf[..n], ECHO_DATA); 108 | } 109 | 110 | debug!("client: close"); 111 | }); 112 | 113 | relay.join().unwrap(); 114 | server.join().unwrap(); 115 | client.join().unwrap(); 116 | } 117 | -------------------------------------------------------------------------------- /tests/sync_bidi_copy.rs: -------------------------------------------------------------------------------- 1 | use std::io::{Read, Write}; 2 | use std::net::{TcpStream, TcpListener}; 3 | use std::time::Duration; 4 | use std::thread; 5 | 6 | use lightws::endpoint::Endpoint; 7 | use lightws::role::{Client, Server}; 8 | 9 | use log::debug; 10 | 11 | const ADDR1: &str = "127.0.0.1:10000"; 12 | const ADDR2: &str = 
"127.0.0.1:20000"; 13 | const HOST: &str = "www.example.com"; 14 | const PATH: &str = "/ws"; 15 | const ECHO_DATA: &[u8] = b"ECHO ECHO ECHO!"; 16 | 17 | macro_rules! gets { 18 | ($b: expr) => { 19 | std::str::from_utf8($b).unwrap() 20 | }; 21 | } 22 | 23 | // addr0(client) <=> addr1(relay) <=> addr2(server); the relay pumps each direction on its own thread and the server echoes everything back 24 | #[test] 25 | fn sync_bidi_copy() { 26 | env_logger::init(); 27 | 28 | let lis1 = TcpListener::bind(ADDR1).unwrap(); 29 | let lis2 = TcpListener::bind(ADDR2).unwrap(); 30 | 31 | let relay = thread::spawn(move || { 32 | let mut buf = vec![0u8; 1024]; 33 | let (tcp, _) = lis1.accept().unwrap(); 34 | debug!("relay: tcp accepted!"); 35 | let ws_local_read = Endpoint::<_, Server>::accept(tcp, &mut buf, HOST, PATH).unwrap(); 36 | debug!("relay: websocket accepted!"); 37 | 38 | let tcp = TcpStream::connect(ADDR2).unwrap(); 39 | debug!("relay: tcp connected!"); 40 | let ws_remote_read = Endpoint::<_, Client>::connect(tcp, &mut buf, HOST, PATH).unwrap(); 41 | debug!("relay: websocket connected!"); 42 | 43 | /* try_clone() gives each copy thread its own write handle while the original endpoints become the read handles */ let mut ws_local_write = ws_local_read.try_clone().unwrap().guard(); 44 | let mut ws_remote_write = ws_remote_read.try_clone().unwrap().guard(); 45 | 46 | let mut ws_local_read = ws_local_read.guard(); 47 | let mut ws_remote_read = ws_remote_read.guard(); 48 | 49 | let t1 = thread::spawn(move || { 50 | let _ = std::io::copy(&mut ws_local_read, &mut ws_remote_write); 51 | debug!("relay: client close, shutdown"); 52 | ws_remote_write 53 | .as_ref() 54 | .shutdown(std::net::Shutdown::Both) 55 | .unwrap(); 56 | }); 57 | 58 | let t2 = thread::spawn(move || { 59 | let _ = std::io::copy(&mut ws_remote_read, &mut ws_local_write); 60 | debug!("relay: server close"); 61 | }); 62 | 63 | t1.join().unwrap(); 64 | t2.join().unwrap(); 65 | }); 66 | 67 | let server = thread::spawn(move || { 68 | let mut buf = vec![0u8; 1024]; 69 | let (tcp, _) = lis2.accept().unwrap(); 70 | debug!("server: tcp accepted!"); 71 | let mut ws = Endpoint::<_, Server>::accept(tcp, &mut buf, HOST, 
PATH).unwrap(); 72 | debug!("server: websocket accepted!"); 73 | 74 | loop { 75 | let n = ws.read(&mut buf).unwrap(); 76 | if n == 0 && ws.is_read_end() { 77 | debug!("server: close"); 78 | break; 79 | } 80 | debug!("server: echo.."); 81 | let _ = ws.write(&buf[..n]).unwrap(); 82 | } 83 | }); 84 | 85 | let client = thread::spawn(|| { 86 | let mut buf = vec![0u8; 1024]; 87 | debug!("client: sleep 500ms.."); 88 | thread::sleep(Duration::from_millis(500)); 89 | let tcp = TcpStream::connect(ADDR1).unwrap(); 90 | debug!("client: tcp connected!"); 91 | let mut ws = Endpoint::<_, Client>::connect(tcp, &mut buf, HOST, PATH).unwrap(); 92 | debug!("client: websocket connected!"); 93 | 94 | debug!("client: sleep 500ms.."); 95 | thread::sleep(Duration::from_millis(500)); 96 | 97 | for i in 1..=5 { 98 | debug!("client: send[{}]..", i); 99 | let n = ws.write(ECHO_DATA).unwrap(); 100 | assert_eq!(n, ECHO_DATA.len()); 101 | 102 | let n = ws.read(&mut buf).unwrap(); 103 | debug!("client: receive message: {}", gets!(&buf[..n])); 104 | assert_eq!(n, ECHO_DATA.len()); 105 | assert_eq!(&buf[..n], ECHO_DATA); 106 | } 107 | 108 | debug!("client: close"); 109 | }); 110 | 111 | relay.join().unwrap(); /* join() returns Err if that thread panicked, so a failed assert in any thread fails the test */ 112 | server.join().unwrap(); 113 | client.join().unwrap(); 114 | } 115 | -------------------------------------------------------------------------------- /tests/sync_echo.rs: -------------------------------------------------------------------------------- 1 | use std::io::{Read, Write}; 2 | use std::net::{TcpStream, TcpListener}; 3 | use std::time::Duration; 4 | use std::thread; 5 | 6 | use lightws::endpoint::Endpoint; 7 | use lightws::role::{Client, Server}; 8 | 9 | use log::debug; 10 | 11 | const ADDR: &str = "127.0.0.1:10000"; /* NOTE(review): same fixed port as the other test files; they would collide if run concurrently */ 12 | const HOST: &str = "www.example.com"; 13 | const PATH: &str = "/ws"; 14 | const ECHO_DATA: &[u8] = b"ECHO ECHO ECHO!"; 15 | 16 | macro_rules! 
gets { 17 | ($b: expr) => { 18 | std::str::from_utf8($b).unwrap() 19 | }; 20 | } 21 | 22 | /* The server echoes every message back; the client sends ECHO_DATA five times and checks each echo. */ #[test] 23 | fn sync_echo() { 24 | env_logger::init(); 25 | 26 | let lis = TcpListener::bind(ADDR).unwrap(); 27 | 28 | let t1 = thread::spawn(move || { 29 | let mut buf = vec![0u8; 1024]; 30 | let (tcp, _) = lis.accept().unwrap(); 31 | debug!("server: tcp accepted!"); 32 | let mut ws = Endpoint::<_, Server>::accept(tcp, &mut buf, HOST, PATH).unwrap(); 33 | debug!("server: websocket accepted!"); 34 | 35 | loop { 36 | let n = ws.read(&mut buf).unwrap(); 37 | if n == 0 && ws.is_read_end() { 38 | debug!("server: close"); 39 | break; 40 | } 41 | debug!("server: echo.."); 42 | let _ = ws.write(&buf[..n]).unwrap(); 43 | } 44 | }); 45 | 46 | let t2 = thread::spawn(|| { 47 | let mut buf = vec![0u8; 1024]; 48 | debug!("client: sleep 500ms.."); 49 | thread::sleep(Duration::from_millis(500)); 50 | let tcp = TcpStream::connect(ADDR).unwrap(); 51 | debug!("client: tcp connected!"); 52 | let mut ws = Endpoint::<_, Client>::connect(tcp, &mut buf, HOST, PATH).unwrap(); 53 | debug!("client: websocket connected!"); 54 | 55 | for i in 1..=5 { 56 | debug!("client: send[{}]..", i); 57 | let n = ws.write(ECHO_DATA).unwrap(); 58 | assert_eq!(n, ECHO_DATA.len()); 59 | 60 | let n = ws.read(&mut buf).unwrap(); 61 | debug!("client: receive message: {}", gets!(&buf[..n])); 62 | assert_eq!(n, ECHO_DATA.len()); 63 | assert_eq!(&buf[..n], ECHO_DATA); 64 | } 65 | 66 | debug!("client: close"); 67 | }); 68 | 69 | t1.join().unwrap(); 70 | t2.join().unwrap(); 71 | } 72 | -------------------------------------------------------------------------------- /tests/sync_handshake.rs: -------------------------------------------------------------------------------- 1 | use std::net::{TcpStream, TcpListener}; 2 | use std::time::Duration; 3 | use std::thread; 4 | 5 | use lightws::endpoint::Endpoint; 6 | use lightws::role::{Client, Server}; 7 | 8 | use log::debug; 9 | 10 | const ADDR: &str = "127.0.0.1:10000"; 11 | const 
HOST: &str = "www.example.com"; 12 | const PATH: &str = "/ws"; 13 | 14 | /* Only checks that the opening handshake completes on both ends; no messages are exchanged. */ #[test] 15 | fn sync_handshake() { 16 | env_logger::init(); 17 | 18 | let lis = TcpListener::bind(ADDR).unwrap(); 19 | 20 | let t1 = thread::spawn(move || { 21 | let mut buf = vec![0u8; 1024]; 22 | let (tcp, _) = lis.accept().unwrap(); 23 | debug!("server: tcp accepted!"); 24 | let _ = Endpoint::<_, Server>::accept(tcp, &mut buf, HOST, PATH).unwrap(); 25 | debug!("server: websocket accepted!"); 26 | }); 27 | 28 | let t2 = thread::spawn(|| { 29 | let mut buf = vec![0u8; 1024]; 30 | debug!("client: sleep 500ms.."); 31 | thread::sleep(Duration::from_millis(500)); 32 | let tcp = TcpStream::connect(ADDR).unwrap(); 33 | debug!("client: tcp connected!"); 34 | let _ = Endpoint::<_, Client>::connect(tcp, &mut buf, HOST, PATH).unwrap(); 35 | debug!("client: websocket connected!"); 36 | }); 37 | 38 | t1.join().unwrap(); 39 | t2.join().unwrap(); 40 | } 41 | -------------------------------------------------------------------------------- /tests/sync_read_write.rs: -------------------------------------------------------------------------------- 1 | use std::io::{Read, Write}; 2 | use std::net::{TcpStream, TcpListener}; 3 | use std::time::Duration; 4 | use std::thread; 5 | 6 | use lightws::endpoint::Endpoint; 7 | use lightws::role::{Client, Server}; 8 | 9 | use log::debug; 10 | 11 | const ADDR: &str = "127.0.0.1:10000"; 12 | const HOST: &str = "www.example.com"; 13 | const PATH: &str = "/ws"; 14 | const PING_DATA: &[u8] = b"PING PING PING!"; 15 | const PONG_DATA: &[u8] = b"PONG PONG PONG!"; 16 | 17 | macro_rules! 
gets { 18 | ($b: expr) => { 19 | std::str::from_utf8($b).unwrap() 20 | }; 21 | } 22 | 23 | /* One round trip: the client sends PING_DATA, the server checks it and replies with PONG_DATA. */ #[test] 24 | fn sync_read_write() { 25 | env_logger::init(); 26 | 27 | let lis = TcpListener::bind(ADDR).unwrap(); 28 | 29 | let t1 = thread::spawn(move || { 30 | let mut buf = vec![0u8; 1024]; 31 | let (tcp, _) = lis.accept().unwrap(); 32 | debug!("server: tcp accepted!"); 33 | let mut ws = Endpoint::<_, Server>::accept(tcp, &mut buf, HOST, PATH).unwrap(); 34 | debug!("server: websocket accepted!"); 35 | 36 | let n = ws.read(&mut buf).unwrap(); 37 | debug!("server: receive message: {}", gets!(&buf[..n])); 38 | assert_eq!(n, PING_DATA.len()); 39 | assert_eq!(&buf[..n], PING_DATA); 40 | 41 | debug!("server: send.."); 42 | let n = ws.write(PONG_DATA).unwrap(); 43 | assert_eq!(n, PONG_DATA.len()); 44 | }); 45 | 46 | let t2 = thread::spawn(|| { 47 | let mut buf = vec![0u8; 1024]; 48 | debug!("client: sleep 500ms.."); 49 | thread::sleep(Duration::from_millis(500)); 50 | let tcp = TcpStream::connect(ADDR).unwrap(); 51 | debug!("client: tcp connected!"); 52 | let mut ws = Endpoint::<_, Client>::connect(tcp, &mut buf, HOST, PATH).unwrap(); 53 | debug!("client: websocket connected!"); 54 | 55 | debug!("client: send.."); 56 | let n = ws.write(PING_DATA).unwrap(); 57 | assert_eq!(n, PING_DATA.len()); 58 | 59 | let n = ws.read(&mut buf).unwrap(); 60 | debug!("client: receive message: {}", gets!(&buf[..n])); 61 | assert_eq!(n, PONG_DATA.len()); 62 | assert_eq!(&buf[..n], PONG_DATA); 63 | }); 64 | 65 | t1.join().unwrap(); 66 | t2.join().unwrap(); 67 | } 68 | --------------------------------------------------------------------------------