├── .github └── workflows │ └── CI.yaml ├── .gitignore ├── CHANGELOG.md ├── Cargo.toml ├── LICENSE ├── README.md └── src └── lib.rs /.github/workflows/CI.yaml: --------------------------------------------------------------------------------
# CI for axum-tungstenite: lint + format (beta), docs with intra-doc link checking (stable),
# and tests on stable and beta.
#
# Third-party actions are pinned to release tags instead of a moving branch such as `master`,
# so upstream branch changes cannot silently alter or break the pipeline.
#
# NOTE(review): actions-rs/toolchain and actions-rs/cargo appear to be unmaintained (archived
# upstream) - TODO confirm and consider migrating, e.g. to dtolnay/rust-toolchain plus plain
# `run: cargo ...` steps.
name: CI

env:
  CARGO_TERM_COLOR: always

on:
  push:
    branches:
      - main
  pull_request: {}

jobs:
  check:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
      - uses: actions-rs/toolchain@v1
        with:
          toolchain: beta
          override: true
          profile: minimal
          components: clippy, rustfmt
      - uses: Swatinem/rust-cache@v2
      - name: Check
        run: |
          cargo clippy --all --all-targets --all-features
      - name: rustfmt
        run: |
          cargo fmt --all -- --check

  check-docs:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
      - uses: actions-rs/toolchain@v1
        with:
          toolchain: stable
          override: true
          profile: minimal
      - uses: Swatinem/rust-cache@v2
      - name: cargo doc
        env:
          # Rustdoc lints live in the `rustdoc::` namespace; the bare
          # `broken-intra-doc-links` name is a deprecated alias that only triggers a
          # renamed-lint warning.
          RUSTDOCFLAGS: "-D rustdoc::broken_intra_doc_links"
        run: cargo doc --all-features --no-deps

  test-versions:
    needs: check
    runs-on: ubuntu-latest
    strategy:
      matrix:
        rust: [stable, beta]
    steps:
      - uses: actions/checkout@v4
      - uses: actions-rs/toolchain@v1
        with:
          toolchain: ${{ matrix.rust }}
          override: true
          profile: minimal
      - uses: Swatinem/rust-cache@v2
      - name: Run tests
        uses: actions-rs/cargo@v1
        with:
          command: test
          args: --all --all-features --all-targets

  test-docs:
    needs: check
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
      - uses: actions-rs/toolchain@v1
        with:
          toolchain: stable
          override: true
          profile: minimal
      - uses: Swatinem/rust-cache@v2
      - name: Run doc tests
        uses: actions-rs/cargo@v1
        with:
          command: test
args: --all-features --doc 82 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | /target 2 | /Cargo.lock 3 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # Changelog 2 | 3 | All notable changes to this project will be documented in this file. 4 | 5 | The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), 6 | and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). 7 | 8 | # Unreleased 9 | 10 | - None. 11 | 12 | # 0.3.0 (02. August, 2023) 13 | 14 | - **changed:** Update to tokio-tungstenite 0.20 ([#9]) 15 | 16 | [#9]: https://github.com/davidpdrsn/axum-tungstenite/pull/9 17 | 18 | # 0.2.0 (10. December, 2022) 19 | 20 | - **changed:** Update to axum-core 0.3, which requires axum 0.6 ([#6]) 21 | - **changed:** Update to tokio-tungstenite 0.18 ([#6]) 22 | - **added:** Allow configuration of client frame masking ([#3]) 23 | - **added:** Add `on_failed_upgrade` callback to `WebSocketUpgrade` ([#7]) 24 | 25 | [#3]: https://github.com/davidpdrsn/axum-tungstenite/pull/3 26 | [#6]: https://github.com/davidpdrsn/axum-tungstenite/pull/6 27 | [#7]: https://github.com/davidpdrsn/axum-tungstenite/pull/7 28 | 29 | # 0.1.0 (15. May, 2022) 30 | 31 | - Initial release.
32 | -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "axum-tungstenite" 3 | version = "0.3.0" 4 | categories = ["asynchronous", "network-programming", "web-programming"] 5 | description = "WebSocket connections for axum directly using tungstenite" 6 | edition = "2021" 7 | homepage = "https://github.com/davidpdrsn/axum-tungstenite" 8 | keywords = ["http", "web", "framework"] 9 | license = "MIT" 10 | readme = "README.md" 11 | repository = "https://github.com/davidpdrsn/axum-tungstenite" 12 | 13 | [dependencies] 14 | async-trait = "0.1.59" 15 | axum-core = "0.3.0" 16 | base64 = "0.21.0" 17 | bytes = "1.3.0" 18 | futures-util = { version = "0.3.25", default-features = false, features = ["alloc"] } 19 | http = "0.2.8" 20 | http-body = "0.4.5" 21 | hyper = "0.14.23" 22 | sha-1 = "0.10.1" 23 | tokio = { version = "1.23.0", features = ["rt"] } 24 | tokio-tungstenite = "0.20.0" 25 | 26 | [dev-dependencies] 27 | axum = "0.6.1" 28 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2022 David Pedersen 2 | 3 | Permission is hereby granted, free of charge, to any 4 | person obtaining a copy of this software and associated 5 | documentation files (the "Software"), to deal in the 6 | Software without restriction, including without 7 | limitation the rights to use, copy, modify, merge, 8 | publish, distribute, sublicense, and/or sell copies of 9 | the Software, and to permit persons to whom the Software 10 | is furnished to do so, subject to the following 11 | conditions: 12 | 13 | The above copyright notice and this permission notice 14 | shall be included in all copies or substantial portions 15 | of the Software. 
16 | 17 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF 18 | ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED 19 | TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A 20 | PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT 21 | SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY 22 | CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION 23 | OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR 24 | IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 25 | DEALINGS IN THE SOFTWARE. 26 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # axum-tungstenite 2 | 3 | WebSocket connections for [axum] directly using [tungstenite]. 4 | 5 | [![CI](https://github.com/davidpdrsn/axum-tungstenite/actions/workflows/CI.yaml/badge.svg)](https://github.com/davidpdrsn/axum-tungstenite/actions/workflows/CI.yaml) 6 | [![Crates.io](https://img.shields.io/crates/v/axum-tungstenite)](https://crates.io/crates/axum-tungstenite) 7 | [![Documentation](https://docs.rs/axum-tungstenite/badge.svg)](https://docs.rs/axum-tungstenite) 8 | 9 | More information about this crate can be found in the [crate documentation][docs]. 10 | 11 | # Differences from `axum::extract::ws` 12 | 13 | axum already supports WebSockets through [`axum::extract::ws`]. However the fact that axum uses 14 | tungstenite under the hood is a private implementation detail. Thus axum doesn't directly 15 | expose types from tungstenite, such as [`tungstenite::Error`] and [`tungstenite::Message`]. 16 | This allows axum to update to a new major version of tungstenite in a new minor version of 17 | axum, which leads to greater API stability. 18 | 19 | This library works differently as it directly uses the types from tungstenite in its public 20 | API. That makes some things simpler but also means axum-tungstenite will receive a new major 21 | version when tungstenite does. 
22 | 23 | # Which should you choose? 24 | 25 | By default you should use `axum::extract::ws` unless you specifically need something from 26 | tungstenite and don't mind keeping up with additional breaking changes. 27 | 28 | ## Safety 29 | 30 | This crate uses `#![forbid(unsafe_code)]` to ensure everything is implemented in 31 | 100% safe Rust. 32 | 33 | ## License 34 | 35 | This project is licensed under the [MIT license][license]. 36 | 37 | [docs]: https://docs.rs/axum-tungstenite 38 | [license]: https://github.com/davidpdrsn/axum-tungstenite/blob/main/LICENSE 39 | [axum]: https://crates.io/crates/axum 40 | [tungstenite]: https://crates.io/crates/tungstenite 41 | [`axum::extract::ws`]: https://docs.rs/axum/latest/axum/extract/ws/index.html 42 | [`tungstenite::Error`]: https://docs.rs/tungstenite/latest/tungstenite/error/enum.Error.html 43 | [`tungstenite::Message`]: https://docs.rs/tungstenite/latest/tungstenite/enum.Message.html 44 | -------------------------------------------------------------------------------- /src/lib.rs: -------------------------------------------------------------------------------- 1 | //! WebSocket connections for [axum] directly using [tungstenite]. 2 | //! 3 | //! # Differences from `axum::extract::ws` 4 | //! 5 | //! axum already supports WebSockets through [`axum::extract::ws`]. However the fact that axum uses 6 | //! tungstenite under the hood is a private implementation detail. Thus axum doesn't directly 7 | //! expose types from tungstenite, such as [`tungstenite::Error`] and [`tungstenite::Message`]. 8 | //! This allows axum to update to a new major version of tungstenite in a new minor version of 9 | //! axum, which leads to greater API stability. 10 | //! 11 | //! This library works differently as it directly uses the types from tungstenite in its public 12 | //! API. That makes some things simpler but also means axum-tungstenite will receive a new major 13 | //! version when tungstenite does. 14 | //! 15 | //! 
# Which should you choose? 16 | //! 17 | //! By default you should use `axum::extract::ws` unless you specifically need something from 18 | //! tungstenite and don't mind keeping up with additional breaking changes. 19 | //! 20 | //! # Example 21 | //! 22 | //! ``` 23 | //! use axum::{ 24 | //! routing::get, 25 | //! response::IntoResponse, 26 | //! Router, 27 | //! }; 28 | //! use axum_tungstenite::{WebSocketUpgrade, WebSocket}; 29 | //! 30 | //! let app = Router::new().route("/ws", get(handler)); 31 | //! 32 | //! async fn handler(ws: WebSocketUpgrade) -> impl IntoResponse { 33 | //! ws.on_upgrade(handle_socket) 34 | //! } 35 | //! 36 | //! async fn handle_socket(mut socket: WebSocket) { 37 | //! while let Some(msg) = socket.recv().await { 38 | //! let msg = if let Ok(msg) = msg { 39 | //! msg 40 | //! } else { 41 | //! // client disconnected 42 | //! return; 43 | //! }; 44 | //! 45 | //! if socket.send(msg).await.is_err() { 46 | //! // client disconnected 47 | //! return; 48 | //! } 49 | //! } 50 | //! } 51 | //! # async { 52 | //! # axum::Server::bind(&"".parse().unwrap()).serve(app.into_make_service()).await.unwrap(); 53 | //! # }; 54 | //! ``` 55 | //! 56 | //! [axum]: https://crates.io/crates/axum 57 | //! [tungstenite]: https://crates.io/crates/tungstenite 58 | //! [`axum::extract::ws`]: https://docs.rs/axum/latest/axum/extract/ws/index.html 59 | //! [`tungstenite::Error`]: https://docs.rs/tungstenite/latest/tungstenite/error/enum.Error.html 60 | //! 
[`tungstenite::Message`]: https://docs.rs/tungstenite/latest/tungstenite/enum.Message.html 61 | 62 | #![warn( 63 | clippy::all, 64 | clippy::dbg_macro, 65 | clippy::todo, 66 | clippy::empty_enum, 67 | clippy::enum_glob_use, 68 | clippy::mem_forget, 69 | clippy::unused_self, 70 | clippy::filter_map_next, 71 | clippy::needless_continue, 72 | clippy::needless_borrow, 73 | clippy::match_wildcard_for_single_variants, 74 | clippy::if_let_mutex, 75 | clippy::mismatched_target_os, 76 | clippy::await_holding_lock, 77 | clippy::match_on_vec_items, 78 | clippy::imprecise_flops, 79 | clippy::suboptimal_flops, 80 | clippy::lossy_float_literal, 81 | clippy::rest_pat_in_fully_bound_structs, 82 | clippy::fn_params_excessive_bools, 83 | clippy::exit, 84 | clippy::inefficient_to_string, 85 | clippy::linkedlist, 86 | clippy::macro_use_imports, 87 | clippy::option_option, 88 | clippy::verbose_file_reads, 89 | clippy::unnested_or_patterns, 90 | clippy::str_to_string, 91 | rust_2018_idioms, 92 | future_incompatible, 93 | nonstandard_style, 94 | missing_debug_implementations, 95 | missing_docs 96 | )] 97 | #![deny(unreachable_pub, private_in_public)] 98 | #![allow(elided_lifetimes_in_paths, clippy::type_complexity)] 99 | #![forbid(unsafe_code)] 100 | #![cfg_attr(docsrs, feature(doc_auto_cfg, doc_cfg))] 101 | #![cfg_attr(test, allow(clippy::float_cmp))] 102 | 103 | use self::rejection::*; 104 | use async_trait::async_trait; 105 | use axum_core::{ 106 | extract::FromRequestParts, 107 | response::{IntoResponse, Response}, 108 | }; 109 | use bytes::Bytes; 110 | use futures_util::{ 111 | sink::{Sink, SinkExt}, 112 | stream::{Stream, StreamExt}, 113 | }; 114 | use http::{ 115 | header::{self, HeaderMap, HeaderName, HeaderValue}, 116 | request::Parts, 117 | Method, StatusCode, 118 | }; 119 | use hyper::upgrade::{OnUpgrade, Upgraded}; 120 | use sha1::{Digest, Sha1}; 121 | use std::{ 122 | borrow::Cow, 123 | future::Future, 124 | pin::Pin, 125 | task::{Context, Poll}, 126 | }; 127 | use 
tokio_tungstenite::{
    tungstenite::protocol::{self, WebSocketConfig},
    WebSocketStream,
};

// Re-export tungstenite's error and message types so callers can name them through this
// crate; `no_inline` makes rustdoc list them as re-exports instead of copying their docs.
#[doc(no_inline)]
pub use tokio_tungstenite::tungstenite::error::{
    CapacityError, Error, ProtocolError, TlsError, UrlError,
};
#[doc(no_inline)]
pub use tokio_tungstenite::tungstenite::Message;

/// Extractor for establishing WebSocket connections.
///
/// Extraction succeeds only for `GET` requests whose `Connection` header contains `upgrade`,
/// whose `Upgrade` header is `websocket`, whose `Sec-WebSocket-Version` header is `13`, and
/// which carry a `Sec-WebSocket-Key` header. Otherwise the request is rejected with a
/// [`WebSocketUpgradeRejection`](rejection::WebSocketUpgradeRejection).
///
/// See the [module docs](self) for an example.
// NOTE(review): generic argument lists look stripped in this copy of the file. The field
// `on_failed_upgrade` has type `F`, but no type parameter is declared, and both `Option,`
// fields have lost their element type (presumably `Option<HeaderValue>`). Restore from
// upstream before building.
#[derive(Debug)]
pub struct WebSocketUpgrade {
    /// Configuration handed to tungstenite when the upgraded socket is created.
    config: WebSocketConfig,
    /// The chosen protocol sent in the `Sec-WebSocket-Protocol` header of the response.
    ///
    /// Starts as `None`; only set by `protocols` when a client offer matches.
    protocol: Option,
    /// The request's `Sec-WebSocket-Key` header. It is signed into the response's
    /// `Sec-WebSocket-Accept` header.
    sec_websocket_key: HeaderValue,
    /// Pending connection upgrade, taken out of the request extensions.
    on_upgrade: OnUpgrade,
    /// Callback invoked with the error if awaiting `on_upgrade` fails.
    on_failed_upgrade: F,
    /// Raw `Sec-WebSocket-Protocol` header from the request, consulted by `protocols`.
    sec_websocket_protocol: Option,
}

impl WebSocketUpgrade {
    /// The target minimum size of the write buffer to reach before writing the data
    /// to the underlying stream.
    ///
    /// The default value is 128 KiB.
    ///
    /// If set to `0` each message will be eagerly written to the underlying stream.
    /// It is often better to allow them to buffer a little, hence the default value.
    ///
    /// Note: [`flush`](SinkExt::flush) will always fully write the buffer regardless.
    pub fn write_buffer_size(mut self, size: usize) -> Self {
        self.config.write_buffer_size = size;
        self
    }

    /// The max size of the write buffer in bytes. Setting this can provide backpressure
    /// in the case the write buffer is filling up due to write errors.
    ///
    /// The default value is unlimited.
    ///
    /// Note: The write buffer only builds up past [`write_buffer_size`](Self::write_buffer_size)
    /// when writes to the underlying stream are failing. So the **write buffer cannot
    /// fill up if you are not observing write errors even if not flushing**.
    ///
    /// Note: Should always be at least [`write_buffer_size + 1 message`](Self::write_buffer_size)
    /// and probably a little more depending on error handling strategy.
    pub fn max_write_buffer_size(mut self, max: usize) -> Self {
        self.config.max_write_buffer_size = max;
        self
    }

    /// Set the maximum message size (defaults to 64 megabytes).
    ///
    /// The value is stored as `Some(max)`; this builder offers no way to reset it to `None`.
    pub fn max_message_size(mut self, max: usize) -> Self {
        self.config.max_message_size = Some(max);
        self
    }

    /// Set the maximum frame size (defaults to 16 megabytes).
    ///
    /// The value is stored as `Some(max)`; this builder offers no way to reset it to `None`.
    pub fn max_frame_size(mut self, max: usize) -> Self {
        self.config.max_frame_size = Some(max);
        self
    }

    /// Allow the server to accept unmasked frames (defaults to `false`).
    pub fn accept_unmasked_frames(mut self, accept: bool) -> Self {
        self.config.accept_unmasked_frames = accept;
        self
    }

    /// Set the known protocols.
    ///
    /// If the request's `Sec-WebSocket-Protocol` header offers a protocol that matches any of
    /// them, the upgrade response will include a `Sec-WebSocket-Protocol` header containing
    /// that protocol name. [`WebSocket::protocol`] will also return that name.
    ///
    /// The protocols should be listed in decreasing order of preference: if the client offers
    /// multiple protocols that the server could support, the server will pick the first one in
    /// this list.
    ///
    /// If the header is present but nothing matches, no protocol is selected. This also clears
    /// a protocol chosen by an earlier call. If the request has no such header, or it cannot
    /// be read as a string, the current selection is left unchanged.
    // NOTE(review): generic parameters look stripped in this copy. `I` is never declared, and
    // `Into>` has lost its target type. It is presumably `Into<Cow<'static, str>>`, since the
    // match below handles `Cow::Owned`/`Cow::Borrowed` and `from_static` needs a
    // `&'static str`. Confirm against upstream.
    pub fn protocols(mut self, protocols: I) -> Self
    where
        I: IntoIterator,
        I::Item: Into>,
    {
        // Negotiate only if the client sent a `Sec-WebSocket-Protocol` header that converts to
        // a `&str`; otherwise `self.protocol` is left as it was.
        if let Some(req_protocols) = self
            .sec_websocket_protocol
            .as_ref()
            .and_then(|p| p.to_str().ok())
        {
            // Walk the server's list in order and keep the first entry equal to one of the
            // client's comma-separated, whitespace-trimmed offers, so server order decides.
            // When nothing matches, the assignment stores `None`.
            self.protocol = protocols
                .into_iter()
                .map(Into::into)
                .find(|protocol| {
                    req_protocols
                        .split(',')
                        .any(|req_protocol| req_protocol.trim() == protocol)
                })
                .map(|protocol| match protocol {
                    // The matched name equals a trimmed slice of a header value that already
                    // passed `to_str`, so it is expected to be a valid header value here.
                    Cow::Owned(s) => HeaderValue::from_str(&s).unwrap(),
                    Cow::Borrowed(s) => HeaderValue::from_static(s),
                });
        }

        self
    }

    /// Finalize upgrading the connection and call the provided callback with
    /// the stream.
    ///
    /// When using `WebSocketUpgrade`, the response produced by this method
    /// should be returned from the handler. See the [module docs](self) for an
    /// example.
    ///
    /// The returned response is `101 Switching Protocols` with these headers:
    /// - `Connection: upgrade`
    /// - `Upgrade: websocket`
    /// - `Sec-WebSocket-Accept`, derived from the request's key
    /// - `Sec-WebSocket-Protocol`, if one was negotiated via [`protocols`](Self::protocols)
    ///
    /// The upgrade itself is awaited on a task started with `tokio::spawn`. If the upgrade
    /// fails, the error is passed to the [`on_failed_upgrade`](Self::on_failed_upgrade)
    /// callback and `callback` is never called.
    // NOTE(review): generics look stripped in this copy. None of `F`, `Fut` or `C` is declared,
    // and `Fut: Future` has lost its `Output` binding. Confirm against upstream.
    pub fn on_upgrade(self, callback: F) -> Response
    where
        F: FnOnce(WebSocket) -> Fut + Send + 'static,
        Fut: Future + Send + 'static,
        C: OnFailedUpdgrade,
    {
        // Move out the pieces the background task needs.
        let on_upgrade = self.on_upgrade;
        let config = self.config;
        let on_failed_upgrade = self.on_failed_upgrade;

        // Cloned because `self.protocol` is used again below for the response header.
        let protocol = self.protocol.clone();

        // The upgrade is awaited in a spawned task so that this method can build and return
        // the 101 response immediately.
        tokio::spawn(async move {
            let upgraded = match on_upgrade.await {
                Ok(upgraded) => upgraded,
                Err(err) => {
                    // No socket is created; the error goes to the configured callback.
                    on_failed_upgrade.call(err);
                    return;
                }
            };

            // Wrap the raw upgraded connection as the server side, using the builder config.
            let socket =
                WebSocketStream::from_raw_socket(upgraded, protocol::Role::Server, Some(config))
                    .await;
            let socket = WebSocket {
                inner: socket,
                protocol,
            };
            callback(socket).await;
        });

        // Clippy flags `HeaderValue` consts as interior-mutable. These are built with
        // `from_static` and inlined at each use, so the lint is silenced.
        #[allow(clippy::declare_interior_mutable_const)]
        const UPGRADE: HeaderValue = HeaderValue::from_static("upgrade");
        #[allow(clippy::declare_interior_mutable_const)]
        const WEBSOCKET: HeaderValue =
            HeaderValue::from_static("websocket");

        let mut headers = HeaderMap::new();
        headers.insert(header::CONNECTION, UPGRADE);
        headers.insert(header::UPGRADE, WEBSOCKET);
        // The accept value is derived from the client's `Sec-WebSocket-Key` by `sign`.
        headers.insert(
            header::SEC_WEBSOCKET_ACCEPT,
            sign(self.sec_websocket_key.as_bytes()),
        );

        // Only echoed back when `protocols` found a match.
        if let Some(protocol) = self.protocol {
            headers.insert(header::SEC_WEBSOCKET_PROTOCOL, protocol);
        }

        (StatusCode::SWITCHING_PROTOCOLS, headers).into_response()
    }

    /// Provide a callback to call if upgrading the connection fails.
    ///
    /// The connection upgrade is performed in a background task. If that fails this callback
    /// will be called.
    ///
    /// By default any errors will be silently ignored.
    ///
    /// All other state (configuration, negotiated protocol, key and pending upgrade) is carried
    /// over to the returned extractor unchanged.
    ///
    /// # Example
    ///
    /// ```
    /// use axum::response::Response;
    /// use axum_tungstenite::WebSocketUpgrade;
    ///
    /// async fn handler(ws: WebSocketUpgrade) -> Response {
    ///     ws.on_failed_upgrade(|error| {
    ///         report_error(error);
    ///     })
    ///     .on_upgrade(|socket| async { /* ... */ })
    /// }
    /// #
    /// # fn report_error(_: hyper::Error) {}
    /// ```
    // NOTE(review): the `<C2>` parameter and the return type's type argument look stripped in
    // this copy. Confirm against upstream.
    pub fn on_failed_upgrade(self, callback: C2) -> WebSocketUpgrade
    where
        C2: OnFailedUpdgrade,
    {
        WebSocketUpgrade {
            config: self.config,
            protocol: self.protocol,
            sec_websocket_key: self.sec_websocket_key,
            on_upgrade: self.on_upgrade,
            on_failed_upgrade: callback,
            sec_websocket_protocol: self.sec_websocket_protocol,
        }
    }
}

// NOTE(review): the impl's `<S>` parameter and the `Result<...>` return type look stripped in
// this copy (`S` is used but never declared). Restore from upstream before building.
#[async_trait]
impl FromRequestParts for WebSocketUpgrade
where
    S: Sync,
{
    type Rejection = WebSocketUpgradeRejection;

    /// Validate the WebSocket handshake headers and capture what `on_upgrade` needs.
    ///
    /// Checks run in this order, and the first failure is returned as the rejection:
    /// 1. the method is `GET`;
    /// 2. `Connection` contains `upgrade`;
    /// 3. `Upgrade` is `websocket`;
    /// 4. `Sec-WebSocket-Version` is `13`;
    /// 5. `Sec-WebSocket-Key` is present.
    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result {
        if parts.method != Method::GET {
            return Err(MethodNotGet.into());
        }

        if !header_contains(parts, header::CONNECTION, "upgrade") {
            return Err(InvalidConnectionHeader.into());
        }

        if !header_eq(parts, header::UPGRADE, "websocket") {
            return Err(InvalidUpgradeHeader.into());
        }

        if !header_eq(parts, header::SEC_WEBSOCKET_VERSION, "13") {
            return Err(InvalidWebSocketVersionHeader.into());
        }

        // The key is removed from the request headers, not just read. Only its presence is
        // checked here, not its format.
        let sec_websocket_key = if let Some(key) = parts.headers.remove(header::SEC_WEBSOCKET_KEY) {
            key
        } else {
            return Err(WebSocketKeyHeaderMissing.into());
        };

        // NOTE(review): this `.unwrap()` panics, instead of rejecting the request, when
        // `parts.extensions` holds no upgrade handle. Whether that can happen depends on the
        // server stack (TODO confirm, e.g. requests not served through hyper). A dedicated
        // rejection variant would be safer.
        // The turbofish type also looks stripped in this copy; it is presumably
        // `remove::<OnUpgrade>()`, matching the `on_upgrade: OnUpgrade` field.
        let on_upgrade = parts.extensions.remove::().unwrap();

        // Read without removing it (unlike the key above); consulted later by `protocols`.
        let sec_websocket_protocol = parts.headers.get(header::SEC_WEBSOCKET_PROTOCOL).cloned();

        Ok(Self {
            // Start from the default tungstenite config, with no negotiated protocol and the
            // no-op failure callback. The builder methods adjust these.
            config: Default::default(),
            protocol: None,
            sec_websocket_key,
            on_upgrade,
            on_failed_upgrade: DefaultOnFailedUpdgrade,
            sec_websocket_protocol,
        })
    }
}

/// Return `true` if header `key` is present and its bytes equal `value`, ignoring ASCII case.
fn header_eq(req: &Parts, key: HeaderName, value: &'static str) -> bool {
    if let Some(header) = req.headers.get(&key) {
        header.as_bytes().eq_ignore_ascii_case(value.as_bytes())
    } else {
        false
    }
}

/// Return `true` if header `key` is present, is valid UTF-8 and, once ASCII-lowercased,
/// contains `value` as a substring.
///
/// Only the header is lowercased, so `value` must be passed in lowercase to ever match.
/// This is a plain substring search, not a match against comma-separated tokens.
fn header_contains(req: &Parts, key: HeaderName, value: &'static str) -> bool {
    let header = if let Some(header) = req.headers.get(&key) {
        header
    } else {
        return false;
    };

    if let Ok(header) = std::str::from_utf8(header.as_bytes()) {
        header.to_ascii_lowercase().contains(value)
    } else {
        false
    }
}

/// A stream of WebSocket messages.
///
/// Implements [`Stream`] for receiving and [`Sink`] for sending, both by delegating to the
/// wrapped tokio-tungstenite stream.
// NOTE(review): type arguments look stripped in this copy (`WebSocketStream`, `Option,`).
#[derive(Debug)]
pub struct WebSocket {
    /// The tokio-tungstenite stream over the upgraded connection.
    inner: WebSocketStream,
    /// Subprotocol negotiated during the handshake, if any.
    protocol: Option,
}

impl WebSocket {
    /// Consume `self` and get the inner [`tokio_tungstenite::WebSocketStream`].
    pub fn into_inner(self) -> WebSocketStream {
        self.inner
    }

    /// Receive another message.
    ///
    /// Returns `None` if the stream has closed.
    pub async fn recv(&mut self) -> Option> {
        // `StreamExt::next`, via this type's `Stream` implementation.
        self.next().await
    }

    /// Send a message.
    ///
    /// Goes through `SinkExt::send` on the inner stream, which also flushes.
    pub async fn send(&mut self, msg: Message) -> Result<(), Error> {
        self.inner.send(msg).await
    }

    /// Gracefully close this WebSocket.
    ///
    /// Consumes the socket. The inner stream's `close` is called with `None`, so no close
    /// code or reason is supplied.
    pub async fn close(mut self) -> Result<(), Error> {
        self.inner.close(None).await
    }

    /// Return the selected WebSocket subprotocol, if one has been chosen.
    ///
    /// This is the protocol negotiated by [`WebSocketUpgrade::protocols`] before the upgrade.
    pub fn protocol(&self) -> Option<&HeaderValue> {
        self.protocol.as_ref()
    }
}

// Receiving: polling is forwarded to the inner stream. `poll_next_unpin` here and
// `Pin::new` in the `Sink` impl only compile because the inner stream is `Unpin`, so no pin
// projection is needed.
// NOTE(review): type arguments look stripped in this copy (`Result;`, `Poll>`).
impl Stream for WebSocket {
    type Item = Result;

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> {
        self.inner.poll_next_unpin(cx)
    }
}

// Sending: every method forwards unchanged to the inner stream's `Sink` implementation.
// NOTE(review): the `Sink<Message>` type argument and the `Poll<...>` payloads look stripped
// in this copy.
impl Sink for WebSocket {
    type Error = Error;

    fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> {
        Pin::new(&mut self.inner).poll_ready(cx)
    }

    fn start_send(mut self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
        Pin::new(&mut self.inner).start_send(item)
    }

    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> {
        Pin::new(&mut self.inner).poll_flush(cx)
    }

    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> {
        Pin::new(&mut self.inner).poll_close(cx)
    }
}

/// Compute the `Sec-WebSocket-Accept` header value for a client's `Sec-WebSocket-Key`.
///
/// The result is `base64(SHA-1(key ++ GUID))`. The GUID is the fixed string that RFC 6455
/// defines for the opening handshake.
fn sign(key: &[u8]) -> HeaderValue {
    // Brings `encode` into scope for the base64 engine.
    use base64::engine::Engine as _;

    let mut sha1 = Sha1::default();
    sha1.update(key);
    sha1.update(&b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"[..]);
    let b64 = Bytes::from(base64::engine::general_purpose::STANDARD.encode(sha1.finalize()));
    // Standard base64 output is plain ASCII (`A-Z`, `a-z`, `0-9`, `+`, `/`, `=`), which is
    // always a valid header value, so this cannot fail.
    HeaderValue::from_maybe_shared(b64).expect("base64 is a valid value")
}

/// What to do when a connection upgrade fails.
///
/// Implemented for every `FnOnce(hyper::Error) + Send + 'static` closure, and by
/// [`DefaultOnFailedUpdgrade`], which ignores the error.
///
/// See [`WebSocketUpgrade::on_failed_upgrade`] for more details.
// NOTE(review): the name is misspelled (`Updgrade`). It is public API, so renaming it would be
// a breaking change. A correctly spelled re-export could be added alongside it, e.g.
// `pub use crate::OnFailedUpdgrade as OnFailedUpgrade;`.
pub trait OnFailedUpdgrade: Send + 'static {
    /// Call the callback with the error produced by the failed upgrade.
    fn call(self, error: hyper::Error);
}

// Blanket impl so that plain closures can be passed to `on_failed_upgrade`.
// NOTE(review): the impl's `<F>` parameter looks stripped in this copy.
impl OnFailedUpdgrade for F
where
    F: FnOnce(hyper::Error) + Send + 'static,
{
    fn call(self, error: hyper::Error) {
        self(error)
    }
}

/// The default `OnFailedUpdgrade` used by `WebSocketUpgrade`.
///
/// It simply ignores the error.
// `#[non_exhaustive]` prevents code outside this crate from constructing the value directly.
#[non_exhaustive]
#[derive(Debug)]
pub struct DefaultOnFailedUpdgrade;

impl OnFailedUpdgrade for DefaultOnFailedUpdgrade {
    // Intentionally a no-op: by default, failed upgrades are dropped silently.
    #[inline]
    fn call(self, _error: hyper::Error) {}
}

pub mod rejection {
    //! WebSocket specific rejections.
    //!
    //! Each rejection converts into a response with a fixed status code and message body.
    //! Each also implements `Display` (printing that same message) and `std::error::Error`.

    use super::*;

    // Defines one rejection type. It generates:
    // - a unit struct, marked `#[non_exhaustive]` so it can only be built inside this crate,
    //   with any doc attributes passed in forwarded onto it;
    // - an `IntoResponse` impl that responds with `$status` and `$body`;
    // - a `Display` impl that prints `$body`;
    // - an empty `std::error::Error` impl.
    macro_rules! define_rejection {
        (
            #[status = $status:ident]
            #[body = $body:expr]
            $(#[$m:meta])*
            pub struct $name:ident;
        ) => {
            $(#[$m])*
            #[derive(Debug)]
            #[non_exhaustive]
            pub struct $name;

            impl IntoResponse for $name {
                fn into_response(self) -> Response {
                    (http::StatusCode::$status, $body).into_response()
                }
            }

            impl std::fmt::Display for $name {
                fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                    write!(f, "{}", $body)
                }
            }

            impl std::error::Error for $name {}
        };
    }

    define_rejection! {
        #[status = METHOD_NOT_ALLOWED]
        #[body = "Request method must be `GET`"]
        /// Rejection type for [`WebSocketUpgrade`](super::WebSocketUpgrade).
        ///
        /// Returned when the request method is not `GET`.
        pub struct MethodNotGet;
    }

    define_rejection! {
        #[status = BAD_REQUEST]
        #[body = "Connection header did not include 'upgrade'"]
        /// Rejection type for [`WebSocketUpgrade`](super::WebSocketUpgrade).
        ///
        /// Returned when the `Connection` header is missing or does not contain `upgrade`.
        pub struct InvalidConnectionHeader;
    }

    define_rejection! {
        #[status = BAD_REQUEST]
        #[body = "`Upgrade` header did not include 'websocket'"]
        /// Rejection type for [`WebSocketUpgrade`](super::WebSocketUpgrade).
        ///
        /// Returned when the `Upgrade` header is missing or is not `websocket` (ignoring
        /// ASCII case).
        pub struct InvalidUpgradeHeader;
    }

    define_rejection! {
        #[status = BAD_REQUEST]
        #[body = "`Sec-WebSocket-Version` header did not include '13'"]
        /// Rejection type for [`WebSocketUpgrade`](super::WebSocketUpgrade).
        ///
        /// Returned when the `Sec-WebSocket-Version` header is missing or is not `13`.
        pub struct InvalidWebSocketVersionHeader;
    }

    define_rejection! {
        #[status = BAD_REQUEST]
        #[body = "`Sec-WebSocket-Key` header missing"]
        /// Rejection type for [`WebSocketUpgrade`](super::WebSocketUpgrade).
        ///
        /// Returned when the request has no `Sec-WebSocket-Key` header.
        pub struct WebSocketKeyHeaderMissing;
    }

    // Defines an enum that wraps each listed rejection in a same-named variant. It generates:
    // - `From` conversions from each inner type, so `.into()` works in the extractor;
    // - `IntoResponse`, `Display` and `Error::source` impls, all delegating to the wrapped
    //   rejection.
    macro_rules! composite_rejection {
        (
            $(#[$m:meta])*
            pub enum $name:ident {
                $($variant:ident),+
                $(,)?
            }
        ) => {
            $(#[$m])*
            #[derive(Debug)]
            #[non_exhaustive]
            pub enum $name {
                $(
                    #[allow(missing_docs)]
                    $variant($variant)
                ),+
            }

            impl IntoResponse for $name {
                fn into_response(self) -> Response {
                    match self {
                        $(
                            Self::$variant(inner) => inner.into_response(),
                        )+
                    }
                }
            }

            $(
                impl From<$variant> for $name {
                    fn from(inner: $variant) -> Self {
                        Self::$variant(inner)
                    }
                }
            )+

            impl std::fmt::Display for $name {
                fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                    match self {
                        $(
                            Self::$variant(inner) => write!(f, "{}", inner),
                        )+
                    }
                }
            }

            impl std::error::Error for $name {
                fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
                    match self {
                        $(
                            Self::$variant(inner) => Some(inner),
                        )+
                    }
                }
            }
        };
    }

    composite_rejection! {
        /// Rejection used for [`WebSocketUpgrade`](super::WebSocketUpgrade).
        ///
        /// Contains one variant for each way the [`WebSocketUpgrade`](super::WebSocketUpgrade)
        /// extractor can fail.
        pub enum WebSocketUpgradeRejection {
            MethodNotGet,
            InvalidConnectionHeader,
            InvalidUpgradeHeader,
            InvalidWebSocketVersionHeader,
            WebSocketKeyHeaderMissing,
        }
    }
}