├── .gitignore ├── Cargo.toml ├── LICENSE.md ├── README.md └── src ├── error.rs ├── lib.rs ├── secure.rs └── websocket ├── builder.rs ├── frame.rs ├── handshake.rs ├── mod.rs ├── parsed_addr.rs ├── split.rs └── stream.rs /.gitignore: -------------------------------------------------------------------------------- 1 | /target 2 | Cargo.lock 3 | -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | authors = ["imranmaj <49664304+imranmaj@users.noreply.github.com>"] 3 | categories = [ 4 | "asynchronous", 5 | "network-programming", 6 | "web-programming", 7 | "web-programming::websocket", 8 | ] 9 | description = "A WebSocket client implementation." 10 | documentation = "http://docs.rs/websockets" 11 | edition = "2018" 12 | keywords = ["websocket", "websockets", "async", "tokio", "io"] 13 | license = "MIT" 14 | name = "websockets" 15 | readme = "README.md" 16 | repository = "https://github.com/imranmaj/websockets" 17 | version = "0.3.0" 18 | 19 | # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html 20 | 21 | [dependencies] 22 | base64 = "0.12.3" 23 | flume = "0.10.7" 24 | futures = "0.3.5" 25 | native-tls = "0.2.6" 26 | rand = "0.7.3" 27 | rand_chacha = "0.2.2" 28 | regex = "1.3.9" 29 | sha-1 = "0.9.1" 30 | thiserror = "1.0.20" 31 | tokio = { version = "1.9", features = ["net", "io-util"] } 32 | tokio-native-tls = "0.3.0" 33 | url = "2.1.1" 34 | 35 | [dev-dependencies] 36 | tokio = { version = "1.9", features = ["rt-multi-thread", "macros"] } 37 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Imran Majeed 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and 
associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # WebSockets 2 | 3 | [github](https://github.com/imranmaj/websockets) [crates.io](https://crates.io/crates/websockets) [docs.rs](https://docs.rs/websockets) 4 | 5 | A WebSocket client implementation. 6 | 7 | ```rust 8 | use websockets::WebSocket; 9 | 10 | let mut ws = WebSocket::connect("wss://echo.websocket.org/").await?; 11 | ws.send_text("foo".to_string()).await?; 12 | ws.receive().await?; 13 | ws.close(None).await?; 14 | ``` 15 | 16 | ## Features 17 | 18 | * Simple API 19 | * Async/await (tokio runtime) 20 | * TLS support (automatically detected) 21 | 22 | ## Usage 23 | 24 | The `WebSocket` type manages the WebSocket connection. 25 | Use it to connect, send, and receive data. 26 | Data is sent and received through `Frame`s. 27 | 28 | ## License 29 | 30 | This project is licensed under the MIT license. 
31 | 32 | ## Credits 33 | 34 | * Thank you to [@thsioutas](https://github.com/thsioutas) for adding support for custom TLS configuration 35 | * Thank you to [@secana](https://github.com/secana) for making the write half `Send` 36 | 
--------------------------------------------------------------------------------
/src/error.rs:
--------------------------------------------------------------------------------
use native_tls::Error as NativeTlsError;
use std::io::Error as IoError;
use thiserror::Error;
use url::ParseError;

/// The possible error types from the WebSocket connection.
///
/// Variants that wrap a lower-level error mark it with `#[source]`, so the
/// wrapped error is returned by [`std::error::Error::source`] and error
/// reporters can print the full cause chain.
#[derive(Error, Debug)]
pub enum WebSocketError {
    // connection errors
    /// Error connecting using TCP
    #[error("could not connect using TCP")]
    TcpConnectionError(#[source] IoError),
    /// Error connecting using TLS
    #[error("could not connect using TLS")]
    TlsConnectionError(#[source] NativeTlsError),
    /// Error building WebSocket with given TLS configuration
    #[error("could not build WebSocket with given TLS configuration")]
    TlsBuilderError(#[source] NativeTlsError),
    /// Error creating a TLS configuration (such as in method calls on
    /// [`TlsCertificate`](crate::secure::TlsCertificate) or
    /// [`TlsIdentity`](crate::secure::TlsIdentity))
    #[error("error with TLS configuration")]
    TlsConfigurationError(#[source] NativeTlsError),
    /// Attempted to use the WebSocket when it is already closed
    #[error("websocket is already closed")]
    WebSocketClosedError,
    /// Error shutting down the internal stream
    #[error("error shutting down stream")]
    ShutdownError(#[source] IoError),

    // handshake errors
    /// Invalid handshake response from the server
    #[error("invalid handshake response")]
    InvalidHandshakeError,
    /// The server rejected the handshake request
    #[error("server rejected handshake")]
    HandshakeFailedError {
        /// Status code from the server's handshake response
        status_code: String,
        /// Headers from the server's handshake response
        headers: Vec<(String, String)>,
        /// Body of the server's handshake response, if any
        body: Option<String>,
    },

    // frame errors
    /// Attempted to use a control frame whose payload is more than 125 bytes
    #[error("control frame has payload larger than 125 bytes")]
    ControlFrameTooLargeError,
    /// Attempted to use a frame whose payload is too large
    #[error("payload is too large")]
    PayloadTooLargeError,
    /// Received an invalid frame
    #[error("received frame is invalid")]
    InvalidFrameError,
    /// Received a masked frame from the server
    #[error("received masked frame")]
    ReceivedMaskedFrameError,

    // url errors
    /// URL could not be parsed
    #[error("url could not be parsed")]
    ParseError(#[source] ParseError),
    /// URL has invalid WebSocket scheme (use "ws" or "wss")
    #[error(r#"invalid websocket scheme (use "ws" or "wss")"#)]
    SchemeError,
    /// URL host is invalid or missing
    #[error("invalid or missing host")]
    HostError,
    /// URL port is invalid
    #[error("invalid or unknown port")]
    PortError,
    /// Could not parse URL into SocketAddrs
    #[error("could not parse into SocketAddrs")]
    SocketAddrError(#[source] IoError),
    /// Could not resolve the URL's domain
    #[error("could not resolve domain")]
    ResolutionError,

    // reading and writing
    /// Error reading from WebSocket
    #[error("could not read from WebSocket")]
    ReadError(#[source] IoError),
    /// Error writing to WebSocket
    #[error("could not write to WebSocket")]
    WriteError(#[source] IoError),

    // splitting
    /// Issue with mpsc channel
    #[error("error using channel")]
    ChannelError,
}

--------------------------------------------------------------------------------
/src/lib.rs:
--------------------------------------------------------------------------------
//! A WebSocket client implementation.
//!
//! [github](https://github.com/imranmaj/websockets) [crates.io](https://crates.io/crates/websockets) [docs.rs](https://docs.rs/websockets)
//!
//! ```rust
//! # use websockets::WebSocketError;
//! use websockets::WebSocket;
//!
//! # #[tokio::main]
//! # async fn main() -> Result<(), WebSocketError> {
//! let mut ws = WebSocket::connect("wss://echo.websocket.org/").await?;
//! ws.send_text("foo".to_string()).await?;
//! ws.receive().await?;
//! ws.close(None).await?;
//! # Ok(())
//! # }
//! ```
//!
//! ## Features
//!
//! * Simple API
//! * Async/await (tokio runtime)
//! * TLS support (automatically detected)
//!
//! ## Usage
//!
//! The [`WebSocket`] type manages the WebSocket connection.
//! Use it to connect, send, and receive data.
//! Data is sent and received through [`Frame`]s.
//!
//! ## License
//!
//! This project is licensed under the MIT license.
34 | 35 | #![forbid( 36 | unsafe_code, 37 | missing_debug_implementations, 38 | missing_docs, 39 | missing_debug_implementations 40 | )] 41 | 42 | mod error; 43 | pub mod secure; 44 | mod websocket; 45 | 46 | pub use error::WebSocketError; 47 | pub use websocket::frame::Frame; 48 | pub use websocket::split::{WebSocketReadHalf, WebSocketWriteHalf}; 49 | pub use websocket::{builder::WebSocketBuilder, WebSocket}; 50 | 51 | #[cfg(test)] 52 | mod tests { 53 | use crate::*; 54 | 55 | #[tokio::test] 56 | async fn echo_length_0_to_125() { 57 | let mut ws = WebSocket::connect("ws://echo.websocket.org/") 58 | .await 59 | .unwrap(); 60 | let message = "a".repeat(3).to_string(); 61 | ws.send_text(message.clone()).await.unwrap(); 62 | let received_frame = ws.receive().await.unwrap(); 63 | let received_message = received_frame.as_text().unwrap().0.clone(); 64 | assert_eq!(message, received_message); 65 | } 66 | 67 | #[tokio::test] 68 | async fn echo_length_126_to_u16_max() { 69 | let mut ws = WebSocket::connect("ws://echo.websocket.org/") 70 | .await 71 | .unwrap(); 72 | let message = "a".repeat(300).to_string(); 73 | ws.send_text(message.clone()).await.unwrap(); 74 | let received_frame = ws.receive().await.unwrap(); 75 | let received_message = received_frame.as_text().unwrap().0.clone(); 76 | assert_eq!(message, received_message); 77 | } 78 | 79 | #[tokio::test] 80 | async fn echo_length_u16_max_to_u64_max() { 81 | let mut ws = WebSocket::connect("ws://echo.websocket.org/") 82 | .await 83 | .unwrap(); 84 | let message = "a".repeat(66000).to_string(); 85 | ws.send_text(message.clone()).await.unwrap(); 86 | let received_frame = ws.receive().await.unwrap(); 87 | let received_message = received_frame.as_text().unwrap().0.clone(); 88 | assert_eq!(message, received_message); 89 | } 90 | 91 | #[tokio::test] 92 | async fn echo_tls() { 93 | let mut ws = WebSocket::connect("wss://echo.websocket.org/") 94 | .await 95 | .unwrap(); 96 | let message = "a".repeat(66000).to_string(); 97 | 
ws.send_text(message.clone()).await.unwrap(); 98 | let received_frame = ws.receive().await.unwrap(); 99 | let received_message = received_frame.as_text().unwrap().0.clone(); 100 | assert_eq!(message, received_message); 101 | } 102 | 103 | #[tokio::test] 104 | async fn close() { 105 | let mut ws = WebSocket::connect("wss://echo.websocket.org") 106 | .await 107 | .unwrap(); 108 | ws.close(Some((1000, String::new()))).await.unwrap(); 109 | let status_code = ws.receive().await.unwrap().as_close().unwrap().0; 110 | assert_eq!(status_code, 1000); 111 | } 112 | 113 | #[tokio::test] 114 | async fn bad_scheme() { 115 | let resp = WebSocket::connect("http://echo.websocket.org").await; 116 | if let Ok(_) = resp { 117 | panic!("expected to fail with bad scheme"); 118 | } 119 | } 120 | } 121 | -------------------------------------------------------------------------------- /src/secure.rs: -------------------------------------------------------------------------------- 1 | //! Types used to customize a secure WebSocket connection; used as arguments to 2 | //! methods on the [`WebSocketBuilder`](crate::WebSocketBuilder). 3 | 4 | use std::fmt::{Debug, Error as FmtError, Formatter}; 5 | 6 | pub use native_tls::Protocol as TlsProtocol; 7 | use native_tls::{Certificate, Identity}; 8 | 9 | use crate::error::WebSocketError; 10 | 11 | // Wrapper types are necessary because the methods need to return 12 | // Result<_, WebSocketError> not Result<_, NativeTlsError>. 13 | // Documentation is copied from native_tls. 14 | 15 | /// An X509 certificate. 16 | #[derive(Clone)] 17 | pub struct TlsCertificate(pub(crate) Certificate); 18 | 19 | impl Debug for TlsCertificate { 20 | fn fmt(&self, f: &mut Formatter) -> Result<(), FmtError> { 21 | f.write_str("TlsCertificate") 22 | } 23 | } 24 | 25 | impl TlsCertificate { 26 | /// Parses a DER-formatted X509 certificate. 
27 | pub fn from_der(der: &[u8]) -> Result { 28 | Ok(Self( 29 | Certificate::from_der(der).map_err(|e| WebSocketError::TlsConfigurationError(e))?, 30 | )) 31 | } 32 | 33 | /// Parses a PEM-formatted X509 certificate. 34 | pub fn from_pem(pem: &[u8]) -> Result { 35 | Ok(Self( 36 | Certificate::from_pem(pem).map_err(|e| WebSocketError::TlsConfigurationError(e))?, 37 | )) 38 | } 39 | 40 | /// Returns the DER-encoded representation of this certificate. 41 | pub fn to_der(&self) -> Result, WebSocketError> { 42 | self.0 43 | .to_der() 44 | .map_err(|e| WebSocketError::TlsConfigurationError(e)) 45 | } 46 | } 47 | 48 | /// A cryptographic identity. 49 | /// 50 | /// An identity is an X509 certificate along with its corresponding private key and chain of certificates to a trusted 51 | /// root. 52 | #[derive(Clone)] 53 | pub struct TlsIdentity(pub(crate) Identity); 54 | 55 | impl Debug for TlsIdentity { 56 | fn fmt(&self, f: &mut Formatter) -> Result<(), FmtError> { 57 | f.write_str("TlsIdentity") 58 | } 59 | } 60 | 61 | impl TlsIdentity { 62 | /// Parses a DER-formatted PKCS #12 archive, using the specified password to decrypt the key. 63 | /// 64 | /// The archive should contain a leaf certificate and its private key, as well any intermediate 65 | /// certificates that should be sent to clients to allow them to build a chain to a trusted 66 | /// root. The chain certificates should be in order from the leaf certificate towards the root. 
67 | /// 68 | /// PKCS #12 archives typically have the file extension `.p12` or `.pfx`, and can be created 69 | /// with the OpenSSL `pkcs12` tool: 70 | /// 71 | /// ```bash 72 | /// openssl pkcs12 -export -out identity.pfx -inkey key.pem -in cert.pem -certfile chain_certs.pem 73 | /// ``` 74 | pub fn from_pkcs12(der: &[u8], password: &str) -> Result { 75 | Ok(Self( 76 | Identity::from_pkcs12(der, password) 77 | .map_err(|e| WebSocketError::TlsConfigurationError(e))?, 78 | )) 79 | } 80 | } 81 | -------------------------------------------------------------------------------- /src/websocket/builder.rs: -------------------------------------------------------------------------------- 1 | use std::convert::TryFrom; 2 | use std::fmt::{Debug, Error as FmtError, Formatter}; 3 | 4 | use native_tls::{ 5 | TlsConnector as NativeTlsTlsConnector, TlsConnectorBuilder as NativeTlsTlsConnectorBuilder, 6 | }; 7 | use rand::SeedableRng; 8 | use rand_chacha::ChaCha20Rng; 9 | use tokio::io::{self, BufReader, BufWriter}; 10 | use tokio::net::TcpStream; 11 | 12 | use super::handshake::Handshake; 13 | use super::parsed_addr::ParsedAddr; 14 | use super::split::{WebSocketReadHalf, WebSocketWriteHalf}; 15 | use super::stream::Stream; 16 | use super::FrameType; 17 | use super::WebSocket; 18 | use crate::error::WebSocketError; 19 | use crate::secure::{TlsCertificate, TlsIdentity, TlsProtocol}; 20 | 21 | /// A builder used to customize the WebSocket handshake. 22 | /// 23 | /// Handshake headers as well as subprotocols can be added and removed. 24 | /// Methods prefixed with `tls_` allow for the customization of a secure 25 | /// WebSocket connection. 
26 | /// 27 | /// ``` 28 | /// # use websockets::{WebSocket, WebSocketError}; 29 | /// # #[tokio::main] 30 | /// # async fn main() -> Result<(), WebSocketError> { 31 | /// let mut ws = WebSocket::builder() 32 | /// .add_subprotocol("wamp") 33 | /// .connect("wss://echo.websocket.org") 34 | /// .await?; 35 | /// # Ok(()) 36 | /// # } 37 | /// ``` 38 | pub struct WebSocketBuilder { 39 | additional_handshake_headers: Vec<(String, String)>, 40 | subprotocols: Vec, 41 | tls_connector_builder: NativeTlsTlsConnectorBuilder, 42 | } 43 | 44 | impl Debug for WebSocketBuilder { 45 | fn fmt(&self, f: &mut Formatter) -> Result<(), FmtError> { 46 | f.write_str("WebSocketBuilder") 47 | } 48 | } 49 | 50 | impl WebSocketBuilder { 51 | pub(super) fn new() -> Self { 52 | Self { 53 | additional_handshake_headers: Vec::new(), 54 | subprotocols: Vec::new(), 55 | tls_connector_builder: NativeTlsTlsConnector::builder(), 56 | } 57 | } 58 | 59 | /// Builds a [`WebSocket`] using this builder, then connects to a URL 60 | /// (and performs the WebSocket handshake). 61 | /// 62 | /// After calling this method, no more methods should be called on this builder. 63 | pub async fn connect(&mut self, url: &str) -> Result { 64 | let parsed_addr = ParsedAddr::try_from(url)?; 65 | 66 | let stream = Stream::Plain( 67 | TcpStream::connect(parsed_addr.addr) 68 | .await 69 | .map_err(|e| WebSocketError::TcpConnectionError(e))?, 70 | ); 71 | let stream = match &parsed_addr.scheme[..] { 72 | // https://tools.ietf.org/html/rfc6455#section-11.1.1 73 | "ws" => stream, 74 | // https://tools.ietf.org/html/rfc6455#section-11.1.2 75 | "wss" => { 76 | let tls_config = self 77 | .tls_connector_builder 78 | .build() 79 | .map_err(|e| WebSocketError::TlsBuilderError(e))?; 80 | stream.into_tls(&parsed_addr.host, tls_config).await? 
81 | } 82 | _ => return Err(WebSocketError::SchemeError), 83 | }; 84 | let (read_half, write_half) = io::split(stream); 85 | let (sender, receiver) = flume::unbounded(); 86 | let mut ws = WebSocket { 87 | read_half: WebSocketReadHalf { 88 | stream: BufReader::new(read_half), 89 | last_frame_type: FrameType::default(), 90 | sender, 91 | }, 92 | write_half: WebSocketWriteHalf { 93 | shutdown: false, 94 | sent_closed: false, 95 | stream: BufWriter::new(write_half), 96 | rng: ChaCha20Rng::from_entropy(), 97 | receiver, 98 | }, 99 | accepted_subprotocol: None, 100 | handshake_response_headers: None, 101 | }; 102 | 103 | // perform opening handshake 104 | let handshake = Handshake::new( 105 | &parsed_addr, 106 | &self.additional_handshake_headers, 107 | &self.subprotocols, 108 | ); 109 | handshake.send_request(&mut ws).await?; 110 | match handshake.check_response(&mut ws).await { 111 | Ok(_) => Ok(ws), 112 | Err(e) => { 113 | ws.shutdown().await?; 114 | Err(e) 115 | } 116 | } 117 | } 118 | 119 | /// Adds a header to be sent in the WebSocket handshake. 120 | pub fn add_header(&mut self, header_name: &str, header_value: &str) -> &mut Self { 121 | // https://tools.ietf.org/html/rfc6455#section-4.2.2 122 | self.additional_handshake_headers 123 | .push((header_name.to_string(), header_value.to_string())); 124 | self 125 | } 126 | 127 | /// Removes a header which would be sent in the WebSocket handshake. 128 | pub fn remove_header(&mut self, header_name: &str) -> &mut Self { 129 | // https://tools.ietf.org/html/rfc6455#section-4.2.2 130 | self.additional_handshake_headers 131 | .retain(|header| header.0 != header_name); 132 | self 133 | } 134 | 135 | /// Adds a subprotocol to the list of subprotocols to be sent in the 136 | /// WebSocket handshake. The server may select a subprotocol from this list. 137 | /// If it does, the selected subprotocol can be found using the 138 | /// [`WebSocket::accepted_subprotocol()`] method. 
139 | pub fn add_subprotocol(&mut self, subprotocol: &str) -> &mut Self { 140 | // https://tools.ietf.org/html/rfc6455#section-1.9 141 | self.subprotocols.push(subprotocol.to_string()); 142 | self 143 | } 144 | 145 | /// Removes a subprotocol from the list of subprotocols that would be sent 146 | /// in the WebSocket handshake. 147 | pub fn remove_subprotocol(&mut self, subprotocol: &str) -> &mut Self { 148 | // https://tools.ietf.org/html/rfc6455#section-1.9 149 | self.subprotocols.retain(|s| s != subprotocol); 150 | self 151 | } 152 | 153 | /// Controls the use of certificate validation. Defaults to false. 154 | pub fn tls_danger_accept_invalid_certs(&mut self, accept_invalid_certs: bool) -> &mut Self { 155 | self.tls_connector_builder 156 | .danger_accept_invalid_certs(accept_invalid_certs); 157 | self 158 | } 159 | 160 | /// Controls the use of hostname verification. Defaults to false. 161 | pub fn tls_danger_accept_invalid_hostnames( 162 | &mut self, 163 | accept_invalid_hostnames: bool, 164 | ) -> &mut Self { 165 | self.tls_connector_builder 166 | .danger_accept_invalid_hostnames(accept_invalid_hostnames); 167 | self 168 | } 169 | 170 | /// Adds a certificate to the set of roots that the connector will trust. 171 | /// The connector will use the system's trust root by default. This method can be used to add 172 | /// to that set when communicating with servers not trusted by the system. 173 | /// Defaults to an empty set. 174 | pub fn tls_add_root_certificate(&mut self, cert: TlsCertificate) -> &mut Self { 175 | self.tls_connector_builder.add_root_certificate(cert.0); 176 | self 177 | } 178 | 179 | /// Controls the use of built-in system certificates during certificate validation. 180 | /// Defaults to false -- built-in system certs will be used. 
181 | pub fn tls_disable_built_in_roots(&mut self, disable: bool) -> &mut Self { 182 | self.tls_connector_builder.disable_built_in_roots(disable); 183 | self 184 | } 185 | 186 | /// Sets the identity to be used for client certificate authentication. 187 | pub fn tls_identity(&mut self, identity: TlsIdentity) -> &mut Self { 188 | self.tls_connector_builder.identity(identity.0); 189 | self 190 | } 191 | 192 | /// Sets the maximum supported TLS protocol version. 193 | /// A value of None enables support for the newest protocols supported by the implementation. 194 | /// Defaults to None. 195 | pub fn tls_max_protocol_version(&mut self, protocol: Option) -> &mut Self { 196 | self.tls_connector_builder.max_protocol_version(protocol); 197 | self 198 | } 199 | 200 | /// Sets the minimum supported TLS protocol version. 201 | /// A value of None enables support for the oldest protocols supported by the implementation. 202 | /// Defaults to Some(Protocol::Tlsv10). 203 | pub fn tls_min_protocol_version(&mut self, protocol: Option) -> &mut Self { 204 | self.tls_connector_builder.min_protocol_version(protocol); 205 | self 206 | } 207 | 208 | /// Controls the use of Server Name Indication (SNI). 209 | /// Defaults to true. 
210 | pub fn tls_use_sni(&mut self, use_sni: bool) -> &mut Self { 211 | self.tls_connector_builder.use_sni(use_sni); 212 | self 213 | } 214 | } 215 | -------------------------------------------------------------------------------- /src/websocket/frame.rs: -------------------------------------------------------------------------------- 1 | use std::convert::TryInto; 2 | 3 | use rand::RngCore; 4 | use tokio::io::{AsyncReadExt, AsyncWriteExt}; 5 | 6 | use super::split::{WebSocketReadHalf, WebSocketWriteHalf}; 7 | use super::FrameType; 8 | #[allow(unused_imports)] // for intra doc links 9 | use super::WebSocket; 10 | use crate::error::WebSocketError; 11 | 12 | const U16_MAX_MINUS_ONE: usize = (u16::MAX - 1) as usize; 13 | const U16_MAX: usize = u16::MAX as usize; 14 | const U64_MAX_MINUS_ONE: usize = (u64::MAX - 1) as usize; 15 | 16 | // https://tools.ietf.org/html/rfc6455#section-5.2 17 | /// Data which is sent and received through the WebSocket connection. 18 | /// 19 | /// # Sending 20 | /// 21 | /// To send a Frame, you can construct it normally and use the [`WebSocket::send()`] method, 22 | /// or use the convenience methods for each frame type 23 | /// ([`send_text()`](WebSocket::send_text()), [`send_binary()`](WebSocket::send_binary()), 24 | /// [`close()`](WebSocket::close()), [`send_ping()`](WebSocket::send_ping()), 25 | /// and [`send_pong()`](WebSocket::send_pong())). 26 | /// 27 | /// # Receiving 28 | /// 29 | /// Frames can be received through the [`WebSocket::receive()`] method. 30 | /// To extract the underlying data from a received Frame, 31 | /// you can `match` or use the convenience methods—for example, for text frames, 32 | /// you can use the method [`as_text`](Frame::as_text()) to get an immutable reference 33 | /// to the data, [`as_text_mut`](Frame::as_text_mut()) to get a mutable reference to the data, 34 | /// or [`into_text`](Frame::into_text()) to get ownership of the data. 
///
/// # Fragmentation
///
/// As per the WebSocket protocol, frames can actually be fragments in a larger message
/// (see [https://tools.ietf.org/html/rfc6455#section-5.4](https://tools.ietf.org/html/rfc6455#section-5.4)).
/// However, the maximum frame size allowed by the WebSocket protocol is larger
/// than what can be stored in a `Vec`. Therefore, no strategy for splitting messages
/// into Frames is provided by this library.
///
/// If you would like to use fragmentation manually, this can be done by setting
/// the `continuation` and `fin` flags on the `Text` and `Binary` variants.
/// `continuation` signifies that the Frame is a Continuation frame in the message,
/// and `fin` signifies that the Frame is the final frame in the message
/// (see the above linked RFC for more details).
///
/// For example:
///
/// * a message made of a single Frame: `continuation = false`, `fin = true`;
/// * a message made of several Frames:
///   * the first Frame: `continuation = false`, `fin = false`;
///   * every Frame between the first and the last: `continuation = true`, `fin = false`;
///   * the last Frame: `continuation = true`, `fin = true`.
#[derive(Debug, Clone)]
pub enum Frame {
    /// A Text frame
    Text {
        /// The payload for the Text frame
        payload: String,
        /// Whether the Text frame is a continuation frame in the message
        continuation: bool,
        /// Whether the Text frame is the final frame in the message
        fin: bool,
    },
    /// A Binary frame
    Binary {
        /// The payload for the Binary frame
        payload: Vec<u8>,
        /// Whether the Binary frame is a continuation frame in the message
        continuation: bool,
        /// Whether the Binary frame is the final frame in the message
        fin: bool,
    },
    /// A Close frame
    Close {
        /// The payload for the Close frame: a status code and a reason,
        /// or `None` for a Close frame without a body
        payload: Option<(u16, String)>,
    },
    /// A Ping frame
    Ping {
        /// The payload for the Ping frame
        payload: Option<Vec<u8>>,
    },
    /// A Pong frame
    Pong {
        /// The payload for the Pong frame
        payload: Option<Vec<u8>>,
    },
}

impl Frame {
    /// Constructs a Text frame from the given payload.
    /// `continuation` will be `false` and `fin` will be `true`.
    /// This can be modified by chaining [`Frame::set_continuation()`] or [`Frame::set_fin()`].
    pub fn text(payload: String) -> Self {
        Self::Text {
            payload,
            continuation: false,
            fin: true,
        }
    }

    /// Returns whether the frame is a Text frame.
    pub fn is_text(&self) -> bool {
        matches!(self, Self::Text { .. })
    }

    /// Attempts to interpret the frame as a Text frame,
    /// returning references to its payload, `continuation` flag and `fin` flag
    /// if it is, and None otherwise.
113 | pub fn as_text(&self) -> Option<(&String, &bool, &bool)> { 114 | match self { 115 | Self::Text { 116 | payload, 117 | continuation, 118 | fin, 119 | } => Some((payload, continuation, fin)), 120 | _ => None, 121 | } 122 | } 123 | /// Attempts to interpret the frame as a Text frame, 124 | /// returning a mutable reference to the underlying data if it is, 125 | /// and None otherwise. 126 | pub fn as_text_mut(&mut self) -> Option<(&mut String, &mut bool, &mut bool)> { 127 | match self { 128 | Self::Text { 129 | payload, 130 | continuation, 131 | fin, 132 | } => Some((payload, continuation, fin)), 133 | _ => None, 134 | } 135 | } 136 | 137 | /// Attempts to interpret the frame as a Text frame, 138 | /// consuming and returning the underlying data if it is, 139 | /// and returning None otherwise. 140 | pub fn into_text(self) -> Option<(String, bool, bool)> { 141 | match self { 142 | Self::Text { 143 | payload, 144 | continuation, 145 | fin, 146 | } => Some((payload, continuation, fin)), 147 | _ => None, 148 | } 149 | } 150 | 151 | /// Constructs a Binary frame from the given payload. 152 | /// `continuation` will be `false` and `fin` will be `true`. 153 | /// This can be modified by chaining [`Frame::set_continuation()`] or [`Frame::set_fin()`]. 154 | pub fn binary(payload: Vec) -> Self { 155 | Self::Binary { 156 | payload, 157 | continuation: false, 158 | fin: true, 159 | } 160 | } 161 | 162 | /// Returns whether the frame is a Binary frame. 163 | pub fn is_binary(&self) -> bool { 164 | self.as_binary().is_some() 165 | } 166 | 167 | /// Attempts to interpret the frame as a Binary frame, 168 | /// returning a reference to the underlying data if it is, 169 | /// and None otherwise. 
170 | pub fn as_binary(&self) -> Option<(&Vec, &bool, &bool)> { 171 | match self { 172 | Self::Binary { 173 | payload, 174 | continuation, 175 | fin, 176 | } => Some((payload, continuation, fin)), 177 | _ => None, 178 | } 179 | } 180 | 181 | /// Attempts to interpret the frame as a Binary frame, 182 | /// returning a mutable reference to the underlying data if it is, 183 | /// and None otherwise. 184 | pub fn as_binary_mut(&mut self) -> Option<(&mut Vec, &mut bool, &mut bool)> { 185 | match self { 186 | Self::Binary { 187 | payload, 188 | continuation, 189 | fin, 190 | } => Some((payload, continuation, fin)), 191 | _ => None, 192 | } 193 | } 194 | 195 | /// Attempts to interpret the frame as a Binary frame, 196 | /// consuming and returning the underlying data if it is, 197 | /// and returning None otherwise. 198 | pub fn into_binary(self) -> Option<(Vec, bool, bool)> { 199 | match self { 200 | Self::Binary { 201 | payload, 202 | continuation, 203 | fin, 204 | } => Some((payload, continuation, fin)), 205 | _ => None, 206 | } 207 | } 208 | 209 | /// Constructs a Close frame from the given payload. 210 | pub fn close(payload: Option<(u16, String)>) -> Self { 211 | Self::Close { payload } 212 | } 213 | 214 | /// Returns whether the frame is a Close frame. 215 | pub fn is_close(&self) -> bool { 216 | self.as_close().is_some() 217 | } 218 | 219 | /// Attempts to interpret the frame as a Close frame, 220 | /// returning a reference to the underlying data if it is, 221 | /// and None otherwise. 222 | pub fn as_close(&self) -> Option<&(u16, String)> { 223 | match self { 224 | Self::Close { payload } => payload.as_ref(), 225 | _ => None, 226 | } 227 | } 228 | 229 | /// Attempts to interpret the frame as a Close frame, 230 | /// returning a mutable reference to the underlying data if it is, 231 | /// and None otherwise. 
232 | pub fn as_close_mut(&mut self) -> Option<&mut (u16, String)> { 233 | match self { 234 | Self::Close { payload } => payload.as_mut(), 235 | _ => None, 236 | } 237 | } 238 | 239 | /// Attempts to interpret the frame as a Close frame, 240 | /// consuming and returning the underlying data if it is, 241 | /// and returning None otherwise. 242 | pub fn into_close(self) -> Option<(u16, String)> { 243 | match self { 244 | Self::Close { payload } => payload, 245 | _ => None, 246 | } 247 | } 248 | 249 | /// Constructs a Ping frame from the given payload. 250 | pub fn ping(payload: Option>) -> Self { 251 | Self::Ping { payload } 252 | } 253 | 254 | /// Returns whether the frame is a Ping frame. 255 | pub fn is_ping(&self) -> bool { 256 | self.as_ping().is_some() 257 | } 258 | 259 | /// Attempts to interpret the frame as a Ping frame, 260 | /// returning a reference to the underlying data if it is, 261 | /// and None otherwise. 262 | pub fn as_ping(&self) -> Option<&Vec> { 263 | match self { 264 | Self::Ping { payload } => payload.as_ref(), 265 | _ => None, 266 | } 267 | } 268 | 269 | /// Attempts to interpret the frame as a Ping frame, 270 | /// returning a mutable reference to the underlying data if it is, 271 | /// and None otherwise. 272 | pub fn as_ping_mut(&mut self) -> Option<&mut Vec> { 273 | match self { 274 | Self::Ping { payload } => payload.as_mut(), 275 | _ => None, 276 | } 277 | } 278 | 279 | /// Attempts to interpret the frame as a Ping frame, 280 | /// consuming and returning the underlying data if it is, 281 | /// and returning None otherwise. 282 | pub fn into_ping(self) -> Option> { 283 | match self { 284 | Self::Ping { payload } => payload, 285 | _ => None, 286 | } 287 | } 288 | 289 | /// Constructs a Pong frame from the given payload. 290 | pub fn pong(payload: Option>) -> Self { 291 | Self::Pong { payload } 292 | } 293 | 294 | /// Returns whether the frame is a Pong frame. 
295 | pub fn is_pong(&self) -> bool { 296 | self.as_pong().is_some() 297 | } 298 | 299 | /// Attempts to interpret the frame as a Pong frame, 300 | /// returning a reference to the underlying data if it is, 301 | /// and None otherwise. 302 | pub fn as_pong(&self) -> Option<&Vec> { 303 | match self { 304 | Self::Pong { payload } => payload.as_ref(), 305 | _ => None, 306 | } 307 | } 308 | 309 | /// Attempts to interpret the frame as a Pong frame, 310 | /// returning a mutable reference to the underlying data if it is, 311 | /// and None otherwise. 312 | pub fn as_pong_mut(&mut self) -> Option<&mut Vec> { 313 | match self { 314 | Self::Pong { payload } => payload.as_mut(), 315 | _ => None, 316 | } 317 | } 318 | 319 | /// Attempts to interpret the frame as a Pong frame, 320 | /// consuming and returning the underlying data if it is, 321 | /// and returning None otherwise. 322 | pub fn into_pong(self) -> Option> { 323 | match self { 324 | Self::Pong { payload } => payload, 325 | _ => None, 326 | } 327 | } 328 | 329 | /// Modifies the frame to set `continuation` to the desired value. 330 | /// If the frame is not a Text or Binary frame, no operation is performed. 331 | pub fn set_continuation(self, continuation: bool) -> Self { 332 | match self { 333 | Self::Text { payload, fin, .. } => Self::Text { 334 | payload, 335 | continuation, 336 | fin, 337 | }, 338 | Self::Binary { payload, fin, .. } => Self::Binary { 339 | payload, 340 | continuation, 341 | fin, 342 | }, 343 | _ => self, 344 | } 345 | } 346 | 347 | /// Modifies the frame to set `fin` to the desired value. 348 | /// If the frame is not a Text or Binary frame, no operation is performed. 349 | pub fn set_fin(self, fin: bool) -> Self { 350 | match self { 351 | Self::Text { 352 | payload, 353 | continuation, 354 | .. 355 | } => Self::Text { 356 | payload, 357 | continuation, 358 | fin, 359 | }, 360 | Self::Binary { 361 | payload, 362 | continuation, 363 | .. 
364 | } => Self::Binary { 365 | payload, 366 | continuation, 367 | fin, 368 | }, 369 | _ => self, 370 | } 371 | } 372 | 373 | pub(super) async fn send( 374 | self, 375 | write_half: &mut WebSocketWriteHalf, 376 | ) -> Result<(), WebSocketError> { 377 | // calculate before moving payload out of self 378 | let is_control = self.is_control(); 379 | let opcode = self.opcode(); 380 | let fin = self.fin(); 381 | 382 | let mut payload = match self { 383 | // https://tools.ietf.org/html/rfc6455#section-5.6 384 | Self::Text { payload, .. } => payload.into_bytes(), 385 | Self::Binary { payload, .. } => payload, 386 | // https://tools.ietf.org/html/rfc6455#section-5.5.1 387 | Self::Close { 388 | payload: Some((status_code, reason)), 389 | } => { 390 | let mut payload = status_code.to_be_bytes().to_vec(); 391 | payload.append(&mut reason.into_bytes()); 392 | payload 393 | } 394 | Self::Close { payload: None } => Vec::new(), 395 | // https://tools.ietf.org/html/rfc6455#section-5.5.2 396 | Self::Ping { payload } => payload.unwrap_or(Vec::new()), 397 | // https://tools.ietf.org/html/rfc6455#section-5.5.3 398 | Self::Pong { payload } => payload.unwrap_or(Vec::new()), 399 | }; 400 | // control frame cannot be longer than 125 bytes: https://tools.ietf.org/html/rfc6455#section-5.5 401 | if is_control && payload.len() > 125 { 402 | return Err(WebSocketError::ControlFrameTooLargeError); 403 | } 404 | 405 | // set payload len: https://tools.ietf.org/html/rfc6455#section-5.2 406 | let mut raw_frame = Vec::with_capacity(payload.len() + 14); 407 | raw_frame.push(opcode + fin); 408 | let mut payload_len_data = match payload.len() { 409 | 0..=125 => (payload.len() as u8).to_be_bytes().to_vec(), 410 | 126..=U16_MAX_MINUS_ONE => { 411 | let mut payload_len_data = vec![126]; 412 | payload_len_data.extend_from_slice(&(payload.len() as u16).to_be_bytes()); 413 | payload_len_data 414 | } 415 | U16_MAX..=U64_MAX_MINUS_ONE => { 416 | let mut payload_len_data = vec![127]; 417 | 
payload_len_data.extend_from_slice(&(payload.len() as u64).to_be_bytes()); 418 | payload_len_data 419 | } 420 | _ => return Err(WebSocketError::PayloadTooLargeError), 421 | }; 422 | payload_len_data[0] += 0b10000000; // set masking bit: https://tools.ietf.org/html/rfc6455#section-5.3 423 | raw_frame.append(&mut payload_len_data); 424 | 425 | // payload masking: https://tools.ietf.org/html/rfc6455#section-5.3 426 | let mut masking_key = vec![0; 4]; 427 | write_half.rng.fill_bytes(&mut masking_key); 428 | for (i, byte) in payload.iter_mut().enumerate() { 429 | *byte = *byte ^ (masking_key[i % 4]); 430 | } 431 | raw_frame.append(&mut masking_key); 432 | 433 | raw_frame.append(&mut payload); 434 | 435 | write_half 436 | .stream 437 | .write_all(&raw_frame) 438 | .await 439 | .map_err(|e| WebSocketError::WriteError(e))?; 440 | write_half 441 | .stream 442 | .flush() 443 | .await 444 | .map_err(|e| WebSocketError::WriteError(e))?; 445 | Ok(()) 446 | } 447 | 448 | fn is_control(&self) -> bool { 449 | // control frames: https://tools.ietf.org/html/rfc6455#section-5.5 450 | match self { 451 | Self::Text { .. } => false, 452 | Self::Binary { .. } => false, 453 | Self::Close { .. } => true, 454 | Self::Ping { .. } => true, 455 | Self::Pong { .. } => true, 456 | } 457 | } 458 | 459 | fn opcode(&self) -> u8 { 460 | // opcodes: https://tools.ietf.org/html/rfc6455#section-5.2 461 | match self { 462 | Self::Text { continuation, .. } => { 463 | if *continuation { 464 | 0x0 465 | } else { 466 | 0x1 467 | } 468 | } 469 | Self::Binary { continuation, .. } => { 470 | if *continuation { 471 | 0x0 472 | } else { 473 | 0x2 474 | } 475 | } 476 | Self::Close { .. } => 0x8, 477 | Self::Ping { .. } => 0x9, 478 | Self::Pong { .. } => 0xA, 479 | } 480 | } 481 | 482 | fn fin(&self) -> u8 { 483 | // fin bit: https://tools.ietf.org/html/rfc6455#section-5.2 484 | match self { 485 | Self::Text { fin, .. } => (*fin as u8) << 7, 486 | Self::Binary { fin, .. 
} => (*fin as u8) << 7, 487 | Self::Close { .. } => 0b10000000, 488 | Self::Ping { .. } => 0b10000000, 489 | Self::Pong { .. } => 0b10000000, 490 | } 491 | } 492 | 493 | pub(super) async fn read_from_websocket( 494 | read_half: &mut WebSocketReadHalf, 495 | ) -> Result { 496 | // https://tools.ietf.org/html/rfc6455#section-5.2 497 | let fin_and_opcode = read_half 498 | .stream 499 | .read_u8() 500 | .await 501 | .map_err(|e| WebSocketError::ReadError(e))?; 502 | let fin: bool = fin_and_opcode & 0b10000000_u8 != 0; 503 | let opcode = fin_and_opcode & 0b00001111_u8; 504 | 505 | let mask_and_payload_len_first_byte = read_half 506 | .stream 507 | .read_u8() 508 | .await 509 | .map_err(|e| WebSocketError::ReadError(e))?; 510 | let masked = mask_and_payload_len_first_byte & 0b10000000_u8 != 0; 511 | if masked { 512 | // server to client frames should not be masked 513 | return Err(WebSocketError::ReceivedMaskedFrameError); 514 | } 515 | let payload_len_first_byte = mask_and_payload_len_first_byte & 0b01111111_u8; 516 | let payload_len = match payload_len_first_byte { 517 | 0..=125 => payload_len_first_byte as usize, 518 | 126 => read_half 519 | .stream 520 | .read_u16() 521 | .await 522 | .map_err(|e| WebSocketError::ReadError(e))? as usize, 523 | 127 => read_half 524 | .stream 525 | .read_u64() 526 | .await 527 | .map_err(|e| WebSocketError::ReadError(e))? 
as usize, 528 | _ => unreachable!(), 529 | }; 530 | 531 | let mut payload = vec![0; payload_len]; 532 | read_half 533 | .stream 534 | .read_exact(&mut payload) 535 | .await 536 | .map_err(|e| WebSocketError::ReadError(e))?; 537 | 538 | match opcode { 539 | 0x0 => match read_half.last_frame_type { 540 | FrameType::Text => Ok(Self::Text { 541 | payload: String::from_utf8(payload) 542 | .map_err(|_e| WebSocketError::InvalidFrameError)?, 543 | continuation: true, 544 | fin, 545 | }), 546 | FrameType::Binary => Ok(Self::Binary { 547 | payload, 548 | continuation: true, 549 | fin, 550 | }), 551 | FrameType::Control => Err(WebSocketError::InvalidFrameError), 552 | }, 553 | 0x1 => Ok(Self::Text { 554 | payload: String::from_utf8(payload) 555 | .map_err(|_e| WebSocketError::InvalidFrameError)?, 556 | continuation: false, 557 | fin, 558 | }), 559 | 0x2 => Ok(Self::Binary { 560 | payload, 561 | continuation: false, 562 | fin, 563 | }), 564 | // reserved range 565 | 0x3..=0x7 => Err(WebSocketError::InvalidFrameError), 566 | 0x8 if payload_len == 0 => Ok(Self::Close { payload: None }), 567 | // if there is a payload it must have a u16 status code 568 | 0x8 if payload_len < 2 => Err(WebSocketError::InvalidFrameError), 569 | 0x8 => { 570 | let (status_code, reason) = payload.split_at(2); 571 | let status_code = u16::from_be_bytes( 572 | status_code 573 | .try_into() 574 | .map_err(|_e| WebSocketError::InvalidFrameError)?, 575 | ); 576 | Ok(Self::Close { 577 | payload: Some(( 578 | status_code, 579 | String::from_utf8(reason.to_vec()) 580 | .map_err(|_e| WebSocketError::InvalidFrameError)?, 581 | )), 582 | }) 583 | } 584 | 0x9 if payload_len == 0 => Ok(Self::Ping { payload: None }), 585 | 0x9 => Ok(Self::Ping { 586 | payload: Some(payload), 587 | }), 588 | 0xA if payload_len == 0 => Ok(Self::Pong { payload: None }), 589 | 0xA => Ok(Self::Pong { 590 | payload: Some(payload), 591 | }), 592 | // reserved range 593 | 0xB..=0xFF => Err(WebSocketError::InvalidFrameError), 594 | } 595 | 
} // end of `read_from_websocket`
} // end of `impl Frame`

impl From<String> for Frame {
    fn from(s: String) -> Self {
        Self::text(s)
    }
}

impl From<Vec<u8>> for Frame {
    fn from(v: Vec<u8>) -> Self {
        Self::binary(v)
    }
}
--------------------------------------------------------------------------------
/src/websocket/handshake.rs:
--------------------------------------------------------------------------------
use rand::{RngCore, SeedableRng};
use rand_chacha::ChaCha20Rng;
use regex::Regex;
use sha1::{Digest, Sha1};
use tokio::io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt};

use super::parsed_addr::ParsedAddr;
use super::WebSocket;
use crate::error::WebSocketError;

/// GUID appended to the client key to derive `Sec-WebSocket-Accept`
/// (https://tools.ietf.org/html/rfc6455#section-1.3).
const GUID: &str = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";

/// Client side of the WebSocket opening handshake.
#[derive(Debug)]
pub(super) struct Handshake {
    path: String,
    host: String,
    key: String,
    version: usize,
    additional_headers: Vec<(String, String)>,
    subprotocols: Vec<String>,
}

impl Handshake {
    /// Prepares a handshake for `parsed_addr` with a freshly generated random key.
    pub(super) fn new(
        parsed_addr: &ParsedAddr,
        additional_handshake_headers: &Vec<(String, String)>,
        subprotocols: &Vec<String>,
    ) -> Self {
        // the key is 16 random bytes, base64-encoded: https://tools.ietf.org/html/rfc6455#section-4.1
        let mut rand_bytes = vec![0; 16];
        let mut rng = ChaCha20Rng::from_entropy();
        rng.fill_bytes(&mut rand_bytes);
        let key = base64::encode(rand_bytes);
        Self {
            path: parsed_addr.path.clone(),
            host: parsed_addr.host.clone(),
            key,
            // todo: support more versions
            version: 13,
            additional_headers: additional_handshake_headers.clone(),
            subprotocols: subprotocols.clone(),
        }
    }

    /// Writes the HTTP upgrade request to the connection and flushes it.
    pub(super) async fn send_request(&self, ws: &mut WebSocket) -> Result<(), WebSocketError> {
        // https://tools.ietf.org/html/rfc6455#section-1.3
        // https://tools.ietf.org/html/rfc6455#section-4.1
        let mut headers = vec![
            ("Host".to_string(), self.host.clone()),
            ("Upgrade".to_string(), "websocket".to_string()),
            ("Connection".to_string(), "Upgrade".to_string()),
            ("Sec-WebSocket-Key".to_string(), self.key.clone()),
            (
                "Sec-Websocket-Version".to_string(),
                self.version.to_string(),
            ),
        ];
        if !self.subprotocols.is_empty() {
            headers.push((
                "Sec-WebSocket-Protocol".to_string(),
                self.subprotocols.join(", "),
            ));
        }
        headers.extend(self.additional_headers.iter().cloned());

        let mut req = format!("GET {} HTTP/1.1\r\n", self.path);
        for (field, value) in headers {
            req.push_str(&format!("{}: {}\r\n", field, value));
        }
        req.push_str("\r\n"); // end of request
        ws.write_half
            .stream
            .write_all(req.as_bytes())
            .await
            .map_err(WebSocketError::WriteError)?;
        ws.write_half
            .stream
            .flush()
            .await
            .map_err(WebSocketError::WriteError)?;
        Ok(())
    }

    /// Reads and validates the server's handshake response.
    ///
    /// On success, records the accepted subprotocol (if any) and the response
    /// headers on `ws`. A non-101 status yields `HandshakeFailedError` carrying the
    /// headers and, when `Content-Length` is present, the body; any other validation
    /// failure yields `InvalidHandshakeError`.
    pub(super) async fn check_response(&self, ws: &mut WebSocket) -> Result<(), WebSocketError> {
        // https://tools.ietf.org/html/rfc6455#section-1.3
        // https://tools.ietf.org/html/rfc6455#section-4.2.2
        // the reason phrase may be empty: https://tools.ietf.org/html/rfc7230#section-3.1.2
        let status_line_regex = Regex::new(r"HTTP/\d+\.\d+ (?P<status_code>\d{3}) .*\r\n")
            .expect("status line regex is valid");
        let mut status_line = String::new();

        ws.read_half
            .stream
            .read_line(&mut status_line)
            .await
            .map_err(WebSocketError::ReadError)?;
        let captures = status_line_regex
            .captures(&status_line)
            .ok_or(WebSocketError::InvalidHandshakeError)?;
        let status_code = &captures["status_code"];

        let mut headers = Vec::new();
        let headers_regex = Regex::new(r"(?P<field>.+?):\s*(?P<value>.*?)\s*\r\n")
            .expect("header regex is valid");
        loop {
            let mut header = String::new();
            ws.read_half
                .stream
                .read_line(&mut header)
                .await
                .map_err(WebSocketError::ReadError)?;
            match headers_regex.captures(&header) {
                Some(captures) => {
                    headers.push((captures["field"].to_string(), captures["value"].to_string()));
                }
                None => break, // field is empty, so the header is finished (we got double crlf)
            }
        }

        // check status code
        if status_code != "101" {
            let body = match header_value(&headers, "content-length") {
                Some(content_length) => {
                    let body_length = content_length
                        .parse::<usize>()
                        .map_err(|_e| WebSocketError::InvalidHandshakeError)?;
                    let mut body = vec![0; body_length];
                    ws.read_half
                        .stream
                        .read_exact(&mut body)
                        .await
                        .map_err(WebSocketError::ReadError)?;
                    Some(
                        String::from_utf8(body)
                            .map_err(|_e| WebSocketError::InvalidHandshakeError)?,
                    )
                }
                None => None,
            };
            return Err(WebSocketError::HandshakeFailedError {
                status_code: status_code.to_string(),
                headers,
                body,
            });
        }

        // check upgrade field: must be exactly "websocket" (case-insensitive)
        let upgrade =
            header_value(&headers, "upgrade").ok_or(WebSocketError::InvalidHandshakeError)?;
        if !upgrade.eq_ignore_ascii_case("websocket") {
            return Err(WebSocketError::InvalidHandshakeError);
        }

        // check connection field: it is a comma-separated token list that must
        // contain "Upgrade" (e.g. "keep-alive, Upgrade" is valid)
        let connection =
            header_value(&headers, "connection").ok_or(WebSocketError::InvalidHandshakeError)?;
        if !connection
            .split(',')
            .any(|token| token.trim().eq_ignore_ascii_case("upgrade"))
        {
            return Err(WebSocketError::InvalidHandshakeError);
        }

        // check extensions: none were requested, so none may be accepted
        if header_value(&headers, "sec-websocket-extensions").is_some() {
            return Err(WebSocketError::InvalidHandshakeError);
        }

        // check subprotocols: the server may only accept one that was requested
        // (which is impossible when none were requested)
        if let Some(subprotocol) = header_value(&headers, "sec-websocket-protocol") {
            if self.subprotocols.iter().any(|requested| requested == subprotocol) {
                ws.accepted_subprotocol = Some(subprotocol.to_string());
            } else {
                return Err(WebSocketError::InvalidHandshakeError);
            }
        }

        // validate key: base64(SHA-1(key + GUID))
        let accept_key = header_value(&headers, "sec-websocket-accept")
            .ok_or(WebSocketError::InvalidHandshakeError)?;
        let mut test_key = self.key.clone();
        test_key.push_str(GUID);
        let hashed: [u8; 20] = Sha1::digest(test_key.as_bytes()).into();
        let calculated_accept_key = base64::encode(hashed);
        if accept_key != calculated_accept_key.as_str() {
            return Err(WebSocketError::InvalidHandshakeError);
        }

        ws.handshake_response_headers = Some(headers);
        Ok(())
    }
}

/// Returns the value of the first header whose field name matches `name`
/// ASCII case-insensitively (HTTP field names are case-insensitive).
fn header_value<'a>(headers: &'a [(String, String)], name: &str) -> Option<&'a str> {
    headers
        .iter()
        .find(|(field, _value)| field.eq_ignore_ascii_case(name))
        .map(|(_field, value)| value.as_str())
}
--------------------------------------------------------------------------------
/src/websocket/mod.rs:
--------------------------------------------------------------------------------
pub mod builder;
pub mod frame;
mod handshake;
mod parsed_addr;
pub mod split;
mod stream;

use crate::error::WebSocketError;
use builder::WebSocketBuilder;
use frame::Frame;
use split::{WebSocketReadHalf, WebSocketWriteHalf};

/// Type of data message that a continuation frame (opcode 0x0) would continue.
/// `Control`, the default, marks that there is no data message to continue.
#[derive(Debug)]
enum FrameType {
    Text,
    Binary,
    Control,
}

impl Default for FrameType {
    fn default() -> Self {
        Self::Control
    }
}

/// Manages the WebSocket connection; used to connect, send data, and receive data.
27 | /// 28 | /// Connect with [`WebSocket::connect()`]: 29 | /// 30 | /// ``` 31 | /// # use websockets::{WebSocket, WebSocketError}; 32 | /// # #[tokio::main] 33 | /// # async fn main() -> Result<(), WebSocketError> { 34 | /// let mut ws = WebSocket::connect("wss://echo.websocket.org/").await?; 35 | /// # Ok(()) 36 | /// # } 37 | /// ``` 38 | /// 39 | /// Cuustomize the handshake using a [`WebSocketBuilder`] obtained from [`WebSocket::builder()`]: 40 | /// 41 | /// ``` 42 | /// # use websockets::{WebSocket, WebSocketError}; 43 | /// # #[tokio::main] 44 | /// # async fn main() -> Result<(), WebSocketError> { 45 | /// let mut ws = WebSocket::builder() 46 | /// .add_subprotocol("wamp") 47 | /// .connect("wss://echo.websocket.org") 48 | /// .await?; 49 | /// # Ok(()) 50 | /// # } 51 | /// ``` 52 | /// 53 | /// Use the `WebSocket::send*` methods to send frames: 54 | /// 55 | /// ``` 56 | /// # use websockets::{WebSocket, WebSocketError}; 57 | /// # #[tokio::main] 58 | /// # async fn main() -> Result<(), WebSocketError> { 59 | /// # let mut ws = WebSocket::connect("wss://echo.websocket.org") 60 | /// # .await?; 61 | /// ws.send_text("foo".to_string()).await?; 62 | /// # Ok(()) 63 | /// # } 64 | /// ``` 65 | /// 66 | /// Use [`WebSocket::receive()`] to receive frames: 67 | /// 68 | /// ``` 69 | /// # use websockets::{WebSocket, WebSocketError, Frame}; 70 | /// # #[tokio::main] 71 | /// # async fn main() -> Result<(), WebSocketError> { 72 | /// # let mut ws = WebSocket::connect("wss://echo.websocket.org") 73 | /// # .await?; 74 | /// # ws.send_text("foo".to_string()).await?; 75 | /// if let Frame::Text { payload: received_msg, .. } = ws.receive().await? 
{ 76 | /// // echo.websocket.org echoes text frames 77 | /// assert_eq!(received_msg, "foo".to_string()); 78 | /// } 79 | /// # else { panic!() } 80 | /// # Ok(()) 81 | /// # } 82 | /// ``` 83 | /// 84 | /// Close the connection with [`WebSocket::close()`]: 85 | /// 86 | /// ``` 87 | /// # use websockets::{WebSocket, WebSocketError, Frame}; 88 | /// # #[tokio::main] 89 | /// # async fn main() -> Result<(), WebSocketError> { 90 | /// # let mut ws = WebSocket::connect("wss://echo.websocket.org") 91 | /// # .await?; 92 | /// ws.close(Some((1000, String::new()))).await?; 93 | /// if let Frame::Close{ payload: Some((status_code, _reason)) } = ws.receive().await? { 94 | /// assert_eq!(status_code, 1000); 95 | /// } 96 | /// # Ok(()) 97 | /// # } 98 | /// ``` 99 | /// 100 | /// # Splitting 101 | /// 102 | /// To facilitate simulataneous reads and writes, the `WebSocket` can be split 103 | /// into a [read half](WebSocketReadHalf) and a [write half](WebSocketWriteHalf). 104 | /// The read half allows frames to be received, while the write half 105 | /// allows frames to be sent. 106 | /// 107 | /// If the read half receives a Ping or Close frame, it needs to send a 108 | /// Pong or echo the Close frame and close the WebSocket, respectively. 109 | /// The write half is notified of these events, but it cannot act on them 110 | /// unless it is flushed. Events can be explicitly [`flush`](WebSocketWriteHalf::flush())ed, 111 | /// but sending a frame will also flush events. If frames are not being 112 | /// sent frequently, consider explicitly flushing events. 113 | /// 114 | /// Flushing is done automatically if you are using the the `WebSocket` type by itself. 
115 | #[derive(Debug)] 116 | pub struct WebSocket { 117 | read_half: WebSocketReadHalf, 118 | write_half: WebSocketWriteHalf, 119 | accepted_subprotocol: Option, 120 | handshake_response_headers: Option>, 121 | } 122 | 123 | impl WebSocket { 124 | /// Constructs a [`WebSocketBuilder`], which can be used to customize 125 | /// the WebSocket handshake. 126 | pub fn builder() -> WebSocketBuilder { 127 | WebSocketBuilder::new() 128 | } 129 | 130 | /// Connects to a URL (and performs the WebSocket handshake). 131 | pub async fn connect(url: &str) -> Result { 132 | WebSocketBuilder::new().connect(url).await 133 | } 134 | 135 | /// Receives a [`Frame`] over the WebSocket connection. 136 | /// 137 | /// If the received frame is a Ping frame, a Pong frame will be sent. 138 | /// If the received frame is a Close frame, an echoed Close frame 139 | /// will be sent and the WebSocket will close. 140 | pub async fn receive(&mut self) -> Result { 141 | let received_frame = self.read_half.receive().await?; 142 | self.write_half.flush().await?; 143 | Ok(received_frame) 144 | } 145 | 146 | /// Receives a [`Frame`] over the WebSocket connection **without handling incoming frames.** 147 | /// For example, receiving a Ping frame will not queue a Pong frame to be sent, 148 | /// and receiving a Close frame will not queue a Close frame to be sent nor close 149 | /// the connection. 150 | /// 151 | /// To automatically handle incoming frames, use the [`receive()`](WebSocket::receive()) 152 | /// method instead. 153 | pub async fn receive_without_handling(&mut self) -> Result { 154 | self.read_half.receive_without_handling().await 155 | } 156 | 157 | /// Sends an already constructed [`Frame`] over the WebSocket connection. 158 | pub async fn send(&mut self, frame: Frame) -> Result<(), WebSocketError> { 159 | self.write_half.send(frame).await 160 | } 161 | 162 | /// Sends a Text frame over the WebSocket connection, constructed 163 | /// from passed arguments. 
`continuation` will be `false` and `fin` will be `true`. 164 | /// To use a custom `continuation` or `fin`, construct a [`Frame`] and use 165 | /// [`WebSocket::send()`]. 166 | pub async fn send_text(&mut self, payload: String) -> Result<(), WebSocketError> { 167 | self.write_half.send_text(payload).await 168 | } 169 | 170 | /// Sends a Binary frame over the WebSocket connection, constructed 171 | /// from passed arguments. `continuation` will be `false` and `fin` will be `true`. 172 | /// To use a custom `continuation` or `fin`, construct a [`Frame`] and use 173 | /// [`WebSocket::send()`]. 174 | pub async fn send_binary(&mut self, payload: Vec) -> Result<(), WebSocketError> { 175 | self.write_half.send_binary(payload).await 176 | } 177 | 178 | /// Sends a Close frame over the WebSocket connection, constructed 179 | /// from passed arguments, and closes the WebSocket connection. 180 | /// This method will attempt to wait for an echoed Close frame, 181 | /// which is returned. 182 | pub async fn close(&mut self, payload: Option<(u16, String)>) -> Result<(), WebSocketError> { 183 | self.write_half.close(payload).await 184 | } 185 | 186 | /// Sends a Ping frame over the WebSocket connection, constructed 187 | /// from passed arguments. 188 | pub async fn send_ping(&mut self, payload: Option>) -> Result<(), WebSocketError> { 189 | self.write_half.send_ping(payload).await 190 | } 191 | 192 | /// Sends a Pong frame over the WebSocket connection, constructed 193 | /// from passed arguments. 194 | pub async fn send_pong(&mut self, payload: Option>) -> Result<(), WebSocketError> { 195 | self.write_half.send_pong(payload).await 196 | } 197 | 198 | /// Shuts down the WebSocket connection **without sending a Close frame**. 199 | /// It is recommended to use the [`close()`](WebSocket::close()) method instead. 
200 | pub async fn shutdown(&mut self) -> Result<(), WebSocketError> { 201 | self.write_half.shutdown().await 202 | } 203 | 204 | /// Splits the WebSocket into a read half and a write half, which can be used separately. 205 | /// [Accepted subprotocol](WebSocket::accepted_subprotocol()) 206 | /// and [handshake response headers](WebSocket::handshake_response_headers()) data 207 | /// will be lost. 208 | pub fn split(self) -> (WebSocketReadHalf, WebSocketWriteHalf) { 209 | (self.read_half, self.write_half) 210 | } 211 | 212 | /// Joins together a split read half and write half to reconstruct a WebSocket. 213 | pub fn join(read_half: WebSocketReadHalf, write_half: WebSocketWriteHalf) -> Self { 214 | Self { 215 | read_half, 216 | write_half, 217 | accepted_subprotocol: None, 218 | handshake_response_headers: None, 219 | } 220 | } 221 | 222 | /// Returns the subprotocol that was accepted by the server during the handshake, 223 | /// if any. This data will be lost if the WebSocket is [`split`](WebSocket::split()). 224 | pub fn accepted_subprotocol(&self) -> &Option { 225 | // https://tools.ietf.org/html/rfc6455#section-1.9 226 | &self.accepted_subprotocol 227 | } 228 | 229 | /// Returns the headers that were returned by the server during the handshake. 230 | /// This data will be lost if the WebSocket is [`split`](WebSocket::split()). 
pub fn handshake_response_headers(&self) -> &Option<Vec<(String, String)>> {
    // https://tools.ietf.org/html/rfc6455#section-4.2.2
    &self.handshake_response_headers
}
} // end of `impl WebSocket`
--------------------------------------------------------------------------------
/src/websocket/parsed_addr.rs:
--------------------------------------------------------------------------------
use std::convert::TryFrom;
use std::net::{SocketAddr, ToSocketAddrs};

use url::Url;

use crate::WebSocketError;

/// The pieces of a WebSocket URL needed to connect and perform the handshake.
#[derive(Debug)]
pub(super) struct ParsedAddr {
    pub scheme: String,
    /// Host as written in the URL (IPv6 literals keep their brackets).
    pub host: String,
    /// Request target for the handshake: the path plus `?query` when present.
    pub path: String,
    pub addr: SocketAddr,
}

impl TryFrom<&str> for ParsedAddr {
    type Error = WebSocketError;

    /// Parses `url` and resolves its host to the first matching socket address.
    ///
    /// NOTE(review): resolution uses the blocking `ToSocketAddrs` lookup.
    fn try_from(url: &str) -> Result<Self, Self::Error> {
        let parsed_url = Url::parse(url).map_err(WebSocketError::ParseError)?;
        let scheme = parsed_url.scheme();
        let host = parsed_url.host_str().ok_or(WebSocketError::HostError)?;
        let port = parsed_url
            .port_or_known_default()
            .ok_or(WebSocketError::PortError)?;

        // resource name = path, plus "?" and the query when the query is non-empty
        // (https://tools.ietf.org/html/rfc6455#section-3)
        let mut path = parsed_url.path().to_string();
        if let Some(query) = parsed_url.query().filter(|query| !query.is_empty()) {
            path.push('?');
            path.push_str(query);
        }

        // `host_str()` keeps the brackets of an IPv6 literal ("[::1]"), which
        // `to_socket_addrs()` cannot parse, so IPv6 addresses are built directly
        let addr = match parsed_url.host() {
            Some(url::Host::Ipv6(ip)) => SocketAddr::from((ip, port)),
            _ => (host, port)
                .to_socket_addrs()
                .map_err(WebSocketError::SocketAddrError)?
                .next()
                .ok_or(WebSocketError::ResolutionError)?,
        };
        Ok(ParsedAddr {
            scheme: scheme.to_string(),
            host: host.to_string(),
            path,
            addr,
        })
    }
}
--------------------------------------------------------------------------------
/src/websocket/split.rs:
--------------------------------------------------------------------------------
use flume::{Receiver, Sender};
use rand_chacha::ChaCha20Rng;
use tokio::io::{AsyncWriteExt, BufReader, BufWriter, ReadHalf, WriteHalf};

use super::frame::Frame;
use super::stream::Stream;
use super::FrameType;
#[allow(unused_imports)] // for intra doc links
use super::WebSocket;
use crate::error::WebSocketError;

/// Events sent from the read half to the write half
#[derive(Debug)]
pub(super) enum Event {
    SendPongFrame(Frame),
    SendCloseFrameAndShutdown(Frame),
}

/// The read half of a WebSocket connection, generated from [`WebSocket::split()`].
/// This half can only receive frames.
#[derive(Debug)]
pub struct WebSocketReadHalf {
    pub(super) stream: BufReader<ReadHalf<Stream>>,
    pub(super) last_frame_type: FrameType,
    pub(super) sender: Sender<Event>,
}

impl WebSocketReadHalf {
    /// Receives a [`Frame`] over the WebSocket connection.
    ///
    /// If the received frame is a Ping frame, an event to send a Pong frame will be queued.
    /// If the received frame is a Close frame, an event to send a Close frame
    /// will be queued and the WebSocket will close. However, events are not
    /// acted upon unless flushed (see the documentation on the [`WebSocket`](WebSocket#splitting)
    /// type for more details).
36 | pub async fn receive(&mut self) -> Result { 37 | let frame = self.receive_without_handling().await?; 38 | // handle incoming frames 39 | match &frame { 40 | // echo ping frame (https://tools.ietf.org/html/rfc6455#section-5.5.2) 41 | Frame::Ping { payload } => { 42 | let pong = Frame::Pong { 43 | payload: payload.clone(), 44 | }; 45 | self.sender 46 | .send(Event::SendPongFrame(pong)) 47 | .map_err(|_e| WebSocketError::ChannelError)?; 48 | } 49 | // echo close frame and shutdown (https://tools.ietf.org/html/rfc6455#section-1.4) 50 | Frame::Close { payload } => { 51 | let close = Frame::Close { 52 | payload: payload 53 | .as_ref() 54 | .map(|(status_code, _reason)| (status_code.clone(), String::new())), 55 | }; 56 | self.sender 57 | .send(Event::SendCloseFrameAndShutdown(close)) 58 | .map_err(|_e| WebSocketError::ChannelError)?; 59 | } 60 | _ => (), 61 | } 62 | Ok(frame) 63 | } 64 | 65 | /// Receives a [`Frame`] over the WebSocket connection **without handling incoming frames.** 66 | /// For example, receiving a Ping frame will not queue a Pong frame to be sent, 67 | /// and receiving a Close frame will not queue a Close frame to be sent nor close 68 | /// the connection. 69 | /// 70 | /// To automatically handle incoming frames, use the [`receive()`](WebSocketReadHalf::receive()) 71 | /// method instead. 72 | pub async fn receive_without_handling(&mut self) -> Result { 73 | let frame = Frame::read_from_websocket(self).await?; 74 | // remember last data frame type in case we get continuation frames (https://tools.ietf.org/html/rfc6455#section-5.2) 75 | match frame { 76 | Frame::Text { .. } => self.last_frame_type = FrameType::Text, 77 | Frame::Binary { .. } => self.last_frame_type = FrameType::Binary, 78 | _ => (), 79 | }; 80 | Ok(frame) 81 | } 82 | } 83 | 84 | /// The write half of a WebSocket connection, generated from [`WebSocket::split()`]. 85 | /// This half can only send frames. 
#[derive(Debug)]
pub struct WebSocketWriteHalf {
    /// Checked before flushing events and before sending frames.
    pub(super) shutdown: bool,
    /// Whether a Close frame has been sent; `send` refuses further frames once set.
    pub(super) sent_closed: bool,
    pub(super) stream: BufWriter<WriteHalf<Stream>>,
    pub(super) rng: ChaCha20Rng,
    pub(super) receiver: Receiver<Event>,
}

impl WebSocketWriteHalf {
    /// Flushes incoming events from the read half. If the read half received a Ping frame,
    /// a Pong frame will be sent. If the read half received a Close frame,
    /// an echoed Close frame will be sent (unless a Close frame was already sent)
    /// and the WebSocket will close.
    /// See the documentation on the [`WebSocket`](WebSocket#splitting) type for more details
    /// about events.
    pub async fn flush(&mut self) -> Result<(), WebSocketError> {
        while let Ok(event) = self.receiver.try_recv() {
            if self.shutdown {
                break;
            }
            match event {
                Event::SendPongFrame(frame) => self.send_without_events_check(frame).await?,
                Event::SendCloseFrameAndShutdown(frame) => {
                    // https://tools.ietf.org/html/rfc6455#section-5.5.1: an endpoint that
                    // receives a Close frame without having sent one must reply with a
                    // Close frame; if we already sent one, the closing handshake is complete.
                    if !self.sent_closed {
                        self.send_without_events_check(frame).await?;
                        self.sent_closed = true;
                    }
                    self.shutdown().await?;
                    // nothing more can be sent after shutting down
                    return Ok(());
                }
            };
        }
        Ok(())
    }

    /// Sends an already constructed [`Frame`] over the WebSocket connection.
    ///
    /// This method will flush incoming events.
    /// See the documentation on the [`WebSocket`](WebSocket#splitting) type for more details
    /// about events.
127 | pub async fn send(&mut self, frame: Frame) -> Result<(), WebSocketError> { 128 | self.flush().await?; 129 | if self.shutdown || self.sent_closed { 130 | return Err(WebSocketError::WebSocketClosedError); 131 | } 132 | self.send_without_events_check(frame).await 133 | } 134 | 135 | /// Sends an already constructed [`Frame`] over the WebSocket connection 136 | /// without flushing incoming events from the read half. 137 | /// See the documentation on the [`WebSocket`](WebSocket#splitting) type for more details 138 | /// about events. 139 | async fn send_without_events_check(&mut self, frame: Frame) -> Result<(), WebSocketError> { 140 | frame.send(self).await?; 141 | Ok(()) 142 | } 143 | 144 | /// Sends a Text frame over the WebSocket connection, constructed 145 | /// from passed arguments. `continuation` will be `false` and `fin` will be `true`. 146 | /// To use a custom `continuation` or `fin`, construct a [`Frame`] and use 147 | /// [`WebSocketWriteHalf::send()`]. 148 | /// 149 | /// This method will flush incoming events. 150 | /// See the documentation on the [`WebSocket`](WebSocket#splitting) type for more details 151 | /// about events. 152 | pub async fn send_text(&mut self, payload: String) -> Result<(), WebSocketError> { 153 | // https://tools.ietf.org/html/rfc6455#section-5.6 154 | self.send(Frame::text(payload)).await 155 | } 156 | 157 | /// Sends a Binary frame over the WebSocket connection, constructed 158 | /// from passed arguments. `continuation` will be `false` and `fin` will be `true`. 159 | /// To use a custom `continuation` or `fin`, construct a [`Frame`] and use 160 | /// [`WebSocketWriteHalf::send()`]. 161 | /// 162 | /// This method will flush incoming events. 163 | /// See the documentation on the [`WebSocket`](WebSocket#splitting) type for more details 164 | /// about events. 
165 | pub async fn send_binary(&mut self, payload: Vec) -> Result<(), WebSocketError> { 166 | // https://tools.ietf.org/html/rfc6455#section-5.6 167 | self.send(Frame::binary(payload)).await 168 | } 169 | 170 | /// Shuts down the WebSocket connection **without sending a Close frame**. 171 | /// It is recommended to use the [`close()`](WebSocketWriteHalf::close()) method instead. 172 | pub async fn shutdown(&mut self) -> Result<(), WebSocketError> { 173 | self.stream 174 | .shutdown() 175 | .await 176 | .map_err(|e| WebSocketError::ShutdownError(e))?; 177 | // indicates that a closed frame has been sent, so no more frames should be sent, 178 | // but the underlying stream is not technically closed (closing the stream 179 | // would prevent a Close frame from being received by the read half) 180 | self.sent_closed = true; 181 | Ok(()) 182 | } 183 | 184 | /// Sends a Close frame over the WebSocket connection, constructed 185 | /// from passed arguments, and closes the WebSocket connection. 186 | /// 187 | /// As per the WebSocket protocol, the server should send a Close frame in response 188 | /// upon receiving a Close frame. Although the write half will be closed, 189 | /// the server's echoed Close frame can be read from the still open read half. 190 | /// 191 | /// This method will flush incoming events. 192 | /// See the documentation on the [`WebSocket`](WebSocket#splitting) type for more details 193 | /// about events. 194 | pub async fn close(&mut self, payload: Option<(u16, String)>) -> Result<(), WebSocketError> { 195 | // https://tools.ietf.org/html/rfc6455#section-5.5.1 196 | self.send(Frame::Close { payload }).await?; 197 | // self.shutdown().await?; 198 | Ok(()) 199 | } 200 | 201 | /// Sends a Ping frame over the WebSocket connection, constructed 202 | /// from passed arguments. 203 | /// 204 | /// This method will flush incoming events. 
205 | /// See the documentation on the [`WebSocket`](WebSocket#splitting) type for more details 206 | /// about events. 207 | pub async fn send_ping(&mut self, payload: Option>) -> Result<(), WebSocketError> { 208 | // https://tools.ietf.org/html/rfc6455#section-5.5.2 209 | self.send(Frame::Ping { payload }).await 210 | } 211 | 212 | /// Sends a Pong frame over the WebSocket connection, constructed 213 | /// from passed arguments. 214 | /// 215 | /// This method will flush incoming events. 216 | /// See the documentation on the [`WebSocket`](WebSocket#splitting) type for more details 217 | /// about events. 218 | pub async fn send_pong(&mut self, payload: Option>) -> Result<(), WebSocketError> { 219 | // https://tools.ietf.org/html/rfc6455#section-5.5.3 220 | self.send(Frame::Pong { payload }).await 221 | } 222 | } 223 | 224 | #[cfg(test)] 225 | mod tests { 226 | use super::*; 227 | 228 | #[test] 229 | fn assert_send_sync() 230 | where 231 | WebSocketReadHalf: Send + Sync, 232 | WebSocketWriteHalf: Send + Sync, 233 | { 234 | } 235 | } 236 | -------------------------------------------------------------------------------- /src/websocket/stream.rs: -------------------------------------------------------------------------------- 1 | use native_tls::TlsConnector as NativeTlsTlsConnector; 2 | use std::io::Error as IoError; 3 | use std::pin::Pin; 4 | use std::task::{Context, Poll}; 5 | use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; 6 | use tokio::net::TcpStream; 7 | use tokio_native_tls::{TlsConnector as TokioTlsConnector, TlsStream}; 8 | 9 | use crate::error::WebSocketError; 10 | 11 | #[derive(Debug)] 12 | pub(super) enum Stream { 13 | Plain(TcpStream), 14 | Tls(TlsStream), 15 | } 16 | 17 | impl Stream { 18 | pub(super) async fn into_tls( 19 | self, 20 | host: &str, 21 | tls_connector: NativeTlsTlsConnector, 22 | ) -> Result { 23 | match self { 24 | Self::Plain(tcp_stream) => { 25 | let connector: TokioTlsConnector = tls_connector.into(); 26 | let tls_stream = 
connector 27 | .connect(host, tcp_stream) 28 | .await 29 | .map_err(|e| WebSocketError::TlsConnectionError(e))?; 30 | Ok(Stream::Tls(tls_stream)) 31 | } 32 | Self::Tls(_) => Ok(self), 33 | } 34 | } 35 | 36 | // pub(super) fn get_ref(&self) -> &TcpStream { 37 | // match self { 38 | // Self::Plain(tcp_stream) => tcp_stream, 39 | // Self::Tls(tls_stream) => tls_stream.get_ref().get_ref().get_ref(), 40 | // } 41 | // } 42 | 43 | // pub(super) fn get_mut(&mut self) -> &mut TcpStream { 44 | // match self { 45 | // Self::Plain(tcp_stream) => tcp_stream, 46 | // Self::Tls(tls_stream) => tls_stream.get_mut().get_mut().get_mut(), 47 | // } 48 | // } 49 | } 50 | 51 | impl AsyncRead for Stream { 52 | fn poll_read( 53 | self: Pin<&mut Self>, 54 | cx: &mut Context<'_>, 55 | buf: &mut ReadBuf, 56 | ) -> Poll> { 57 | match self.get_mut() { 58 | Self::Plain(tcp_stream) => Pin::new(tcp_stream).poll_read(cx, buf), 59 | Self::Tls(tls_stream) => Pin::new(tls_stream).poll_read(cx, buf), 60 | } 61 | } 62 | } 63 | 64 | impl AsyncWrite for Stream { 65 | fn poll_write( 66 | self: Pin<&mut Self>, 67 | cx: &mut Context<'_>, 68 | buf: &[u8], 69 | ) -> Poll> { 70 | match self.get_mut() { 71 | Self::Plain(tcp_stream) => Pin::new(tcp_stream).poll_write(cx, buf), 72 | Self::Tls(tls_stream) => Pin::new(tls_stream).poll_write(cx, buf), 73 | } 74 | } 75 | 76 | fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { 77 | match self.get_mut() { 78 | Self::Plain(tcp_stream) => Pin::new(tcp_stream).poll_flush(cx), 79 | Self::Tls(tls_stream) => Pin::new(tls_stream).poll_flush(cx), 80 | } 81 | } 82 | 83 | fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { 84 | match self.get_mut() { 85 | Self::Plain(tcp_stream) => Pin::new(tcp_stream).poll_shutdown(cx), 86 | Self::Tls(tls_stream) => Pin::new(tls_stream).poll_shutdown(cx), 87 | } 88 | } 89 | } 90 | --------------------------------------------------------------------------------