├── .editorconfig ├── .github ├── dependabot.yml └── workflows │ └── ci.yml ├── .gitignore ├── CHANGELOG.md ├── Cargo.toml ├── LICENSE-APACHE ├── LICENSE-MIT ├── README.md ├── RELEASING.md ├── examples ├── autobahn_client.rs ├── autobahn_server.rs └── hyper_server.rs ├── rustfmt.toml └── src ├── base.rs ├── connection.rs ├── data.rs ├── extension.rs ├── extension └── deflate.rs ├── handshake.rs ├── handshake ├── client.rs ├── http.rs └── server.rs └── lib.rs /.editorconfig: -------------------------------------------------------------------------------- 1 | root = true 2 | 3 | [*] 4 | charset=utf-8 5 | end_of_line=lf 6 | indent_size=4 7 | indent_style=space 8 | max_line_length=100 9 | 10 | -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | updates: 3 | - package-ecosystem: "cargo" 4 | directory: "/" 5 | schedule: 6 | interval: "weekly" 7 | 8 | - package-ecosystem: "github-actions" 9 | directory: "/" 10 | schedule: 11 | interval: "weekly" 12 | -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: Rust 2 | 3 | on: 4 | push: 5 | # Run jobs when commits are pushed to 6 | # master or release-like branches: 7 | branches: 8 | - master 9 | - release* 10 | pull_request: 11 | # Run jobs for any external PR that wants 12 | # to merge to master, too: 13 | branches: 14 | - master 15 | 16 | env: 17 | CARGO_TERM_COLOR: always 18 | 19 | jobs: 20 | build: 21 | name: Check Code 22 | runs-on: ubuntu-latest 23 | steps: 24 | - name: Checkout sources 25 | uses: actions/checkout@v4.2.2 26 | 27 | - name: Install Rust stable toolchain 28 | uses: actions-rs/toolchain@v1.0.7 29 | with: 30 | profile: minimal 31 | toolchain: stable 32 | override: true 33 | 34 | - name: Rust Cache 35 | uses: 
Swatinem/rust-cache@v2.7.8 36 | 37 | - name: Build 38 | uses: actions-rs/cargo@v1.0.3 39 | with: 40 | command: check 41 | args: --all-targets --all-features 42 | 43 | fmt: 44 | name: Run rustfmt 45 | runs-on: ubuntu-latest 46 | steps: 47 | - name: Checkout sources 48 | uses: actions/checkout@v4.2.2 49 | 50 | - name: Install Rust stable toolchain 51 | uses: actions-rs/toolchain@v1.0.7 52 | with: 53 | profile: minimal 54 | toolchain: stable 55 | override: true 56 | components: clippy, rustfmt 57 | 58 | - name: Rust Cache 59 | uses: Swatinem/rust-cache@v2.7.8 60 | 61 | - name: Cargo fmt 62 | uses: actions-rs/cargo@v1.0.3 63 | with: 64 | command: fmt 65 | args: --all -- --check 66 | 67 | docs: 68 | name: Check Documentation 69 | runs-on: ubuntu-latest 70 | steps: 71 | - name: Checkout sources 72 | uses: actions/checkout@v4.2.2 73 | 74 | - name: Install Rust stable toolchain 75 | uses: actions-rs/toolchain@v1.0.7 76 | with: 77 | profile: minimal 78 | toolchain: stable 79 | override: true 80 | 81 | - name: Rust Cache 82 | uses: Swatinem/rust-cache@v2.7.8 83 | 84 | - name: Check internal documentation links 85 | run: RUSTDOCFLAGS="--deny broken_intra_doc_links" cargo doc --verbose --workspace --no-deps --document-private-items 86 | 87 | tests: 88 | name: Run tests 89 | runs-on: ubuntu-latest 90 | steps: 91 | - name: Checkout sources 92 | uses: actions/checkout@v4.2.2 93 | 94 | - name: Install Rust stable toolchain 95 | uses: actions-rs/toolchain@v1.0.7 96 | with: 97 | profile: minimal 98 | toolchain: stable 99 | override: true 100 | 101 | - name: Rust Cache 102 | uses: Swatinem/rust-cache@v2.7.8 103 | 104 | - name: Cargo build 105 | uses: actions-rs/cargo@v1.0.3 106 | with: 107 | command: build 108 | args: --workspace 109 | 110 | - name: Cargo test 111 | uses: actions-rs/cargo@v1.0.3 112 | with: 113 | command: test 114 | 115 | -------------------------------------------------------------------------------- /.gitignore: 
-------------------------------------------------------------------------------- 1 | target 2 | Cargo.lock 3 | *.dat 4 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # Changelog 2 | 3 | The format is based on [Keep a Changelog]. 4 | 5 | [Keep a Changelog]: http://keepachangelog.com/en/1.0.0/ 6 | 7 | ## 0.8.1 8 | 9 | - [fixed] ignore I/O error after successful close handshake [#115](https://github.com/paritytech/soketto/pull/115) 10 | 11 | ## 0.8.0 12 | 13 | - [changed] move to rust 2021 [#56](https://github.com/paritytech/soketto/pull/56) 14 | - [changed] Replace sha-1 v0.9 with sha1 v0.10 [#62](https://github.com/paritytech/soketto/pull/62) 15 | - [changed] Update hyper requirement from v0.14 to v1.0 [#99](https://github.com/paritytech/soketto/pull/99) 16 | - [changed] Update base64 requirement from 0.13 to 0.22 [#97](https://github.com/paritytech/soketto/pull/97) 17 | - [changed] Bump MSRV to 1.71.1. 
18 | - [fixed] doc typo on Client resource field [#79](https://github.com/paritytech/soketto/pull/79) 19 | 20 | ## 0.7.1 21 | 22 | - [fixed] Advance reader when a too big message is received [#54](https://github.com/paritytech/soketto/pull/54) 23 | 24 | ## 0.7.0 25 | 26 | - [added] Added the `handshake::http` module and example usage at `examples/hyper_server.rs` to make using Soketto in conjunction with libraries that use the `http` types (like Hyper) simpler [#45](https://github.com/paritytech/soketto/pull/45) [#48](https://github.com/paritytech/soketto/pull/48) 27 | - [added] Allow setting custom headers on the client to be sent to WebSocket servers when the opening handshake is performed [#47](https://github.com/paritytech/soketto/pull/47) 28 | 29 | ## 0.6.0 30 | 31 | - [changed] Expose the `Origin` headers from the client handshake on `ClientRequest` [#35](https://github.com/paritytech/soketto/pull/35) 32 | - [changed] Update handshake error to expose a couple of new variants (`IncompleteHttpRequest` and `SecWebSocketKeyInvalidLength`) [#35](https://github.com/paritytech/soketto/pull/35) 33 | - [added] Add `send_text_owned` method to `Sender` as an optimisation when you can pass an owned `String` in [#36](https://github.com/paritytech/soketto/pull/36) 34 | - [updated] Run rustfmt over the repository, and minor tidy up [#41](https://github.com/paritytech/soketto/pull/41) 35 | 36 | ## 0.5.0 37 | 38 | - Update examples to Tokio 1 [#27](https://github.com/paritytech/soketto/pull/27) 39 | - Update deps and remove unnecessary transients [#30](https://github.com/paritytech/soketto/pull/30) 40 | - Add CLOSE reason handling [#31](https://github.com/paritytech/soketto/pull/31) 41 | - Fix handshake with case-sensitive servers [#32](https://github.com/paritytech/soketto/pull/32) 42 | 43 | ## 0.4.2 44 | 45 | - Added connection ID to log output (#21). 46 | - Added `ClientRequest::path` to access the path requested by the client 47 | (See #23 by @mward for details). 
48 | - Updated `sha-1` dependency to 0.9 (#24). 49 | 50 | ## 0.4.1 51 | 52 | - Update some `dev-dependencies`. 53 | 54 | ## 0.4.0 55 | 56 | - Remove all `unsafe` code blocks. 57 | - Remove internal use of `futures::io::BufWriter`. 58 | - `Extension::decode` now takes a `&mut Vec` instead of a `BytesMut`. 59 | - `Incoming::Pong` contains the PONG payload data slice inline. 60 | - `Data` no longer contains application data, but reports only the number 61 | of bytes. The actual data is written directly into the `&mut Vec` 62 | parameter of `Receiver::receive` or `Receiver::receive_data`. 63 | - `Receiver::into_stream` has been removed. 64 | 65 | ## 0.3.2 66 | 67 | - Bugfix release. `Codec::encode_header` contained a hidden assumption that 68 | a `usize` would be 8 bytes long, which is obviously only true on 64-bit 69 | architectures. See #18 for details. 70 | 71 | ## 0.3.1 72 | 73 | - A method `into_inner` to get back the socket has been added to 74 | `handshake::{Client, Server}`. 75 | 76 | ## 0.3.0 77 | 78 | Update to use and work with async/await: 79 | 80 | - `Connection` has been split into a `Sender` and `Receiver` pair with 81 | async methods to send and receive data or control frames such as Pings 82 | or Pongs. 83 | - `connection::into_stream` has been added to get a `futures::stream::Stream` 84 | from a `Receiver`. 85 | - A `connection::Builder` has been added to set up connection parameters. 86 | `handshake::Client` and `handshake::Server` no longer have an 87 | `into_connection` method, but an `into_builder` one which returns the 88 | `Builder` and allows further configuration. 89 | - `base::Data` has been moved to `data`. In addition `data::Incoming` 90 | supports control frame data. 91 | - `base::Codec` no longer implements `Encoder`/`Decoder` traits but has 92 | inherent methods for encoding and decoding websocket frame headers. 93 | - `base::Frame` has been removed. The `base::Codec` only deals with 94 | headers. 
95 | - The `handshake` module contains separate sub-modules for `client` and 96 | `server` handshakes. Some handshake related types have been refactored 97 | slightly. 98 | - `Extension`s `decode` methods work on `&mut BytesMut` parameters 99 | instead of `Data`. For `encode` a new type `Storage` has been added 100 | which unifies different types of data, i.e. shared, unique and owned data. 101 | 102 | ## 0.2.3 103 | 104 | - Maintenance release. 105 | 106 | ## 0.2.2 107 | 108 | - Improved handshake header matching which is now more robust and can cope with 109 | repeated header names and comma separated values. 110 | 111 | ## 0.2.1 112 | 113 | - The DEFLATE extension now allows custom maximum window bits for client and server. 114 | - Fix handling of reserved bits in base codec. 115 | 116 | ## 0.2.0 117 | 118 | - Change `Extension` trait and add an optional DEFLATE extension (RFC 7692). 119 | For now the possibility to use reserved opcodes in extensions is not enabled. 120 | The DEFLATE extension does not support setting of window bits other than 15 121 | currently. 122 | - Limit the max. buffer size in `Connection` (see `Connection::set_max_buffer_size`). 123 | 124 | ## 0.1.0 125 | 126 | Initial release. 127 | -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "soketto" 3 | version = "0.8.1" 4 | authors = ["Parity Technologies ", "Jason Ozias "] 5 | description = "A websocket protocol implementation." 
6 | keywords = ["websocket", "codec", "async", "futures"] 7 | categories = ["network-programming", "asynchronous", "web-programming::websocket"] 8 | license = "Apache-2.0 OR MIT" 9 | readme = "README.md" 10 | repository = "https://github.com/paritytech/soketto" 11 | edition = "2021" 12 | rust-version = "1.71.1" 13 | 14 | [package.metadata.docs.rs] 15 | all-features = true 16 | 17 | [features] 18 | default = [] 19 | deflate = ["flate2"] 20 | 21 | [dependencies] 22 | base64 = { default-features = false, features = ["alloc"], version = "0.22" } 23 | bytes = { default-features = false, version = "1.0" } 24 | flate2 = { default-features = false, features = ["zlib"], optional = true, version = "1.0.13" } 25 | futures = { default-features = false, features = ["bilock", "std", "unstable"], version = "0.3.1" } 26 | httparse = { default-features = false, features = ["std"], version = "1.3.4" } 27 | log = { default-features = false, version = "0.4.8" } 28 | rand = { default-features = false, features = ["std", "std_rng"], version = "0.8" } 29 | sha1 = { default-features = false, version = "0.10" } 30 | http = { version = "1", optional = true } 31 | 32 | [dev-dependencies] 33 | quickcheck = "1" 34 | tokio = { version = "1", features = ["full"] } 35 | tokio-util = { version = "0.7", features = ["compat"] } 36 | tokio-stream = { version = "0.1", features = ["net"] } 37 | http-body-util = "0.1" 38 | hyper = { version = "1.2", features = ["full"] } 39 | hyper-util = { version = "0.1", features = ["tokio"] } 40 | env_logger = "0.11.1" 41 | 42 | [[example]] 43 | name = "hyper_server" 44 | required-features = ["http"] 45 | -------------------------------------------------------------------------------- /LICENSE-APACHE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 
8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. 
For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /LICENSE-MIT: -------------------------------------------------------------------------------- 1 | Copyright (c) 2019 Parity Technologies (UK) Ltd. 2 | Copyright (c) 2016 twist developers 3 | 4 | Permission is hereby granted, free of charge, to any 5 | person obtaining a copy of this software and associated 6 | documentation files (the "Software"), to deal in the 7 | Software without restriction, including without 8 | limitation the rights to use, copy, modify, merge, 9 | publish, distribute, sublicense, and/or sell copies of 10 | the Software, and to permit persons to whom the Software 11 | is furnished to do so, subject to the following 12 | conditions: 13 | 14 | The above copyright notice and this permission notice 15 | shall be included in all copies or substantial portions 16 | of the Software. 17 | 18 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF 19 | ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED 20 | TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A 21 | PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT 22 | SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY 23 | CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION 24 | OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR 25 | IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 26 | DEALINGS IN THE SOFTWARE. 27 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Soketto 2 | 3 | An implementation of the [RFC 6455][1] websocket protocol. 4 | This crate is a heavily modified fork of the [twist][2] crate. 5 | 6 | [1]: https://tools.ietf.org/html/rfc6455 7 | [2]: https://crates.io/crates/twist 8 | 9 | -------------------------------------------------------------------------------- /RELEASING.md: -------------------------------------------------------------------------------- 1 | # Release Checklist 2 | 3 | These steps assume that you've checked out the Soketto repository and are in the root directory of it. 4 | 5 | We also assume that ongoing work done is being merged directly to the `master` branch. 6 | 7 | 1. Ensure that everything you'd like to see released is on the `master` branch. 8 | 9 | 2. Create a release branch off `master`, for example `release-v0.6.0`. The branch name should start with `release` 10 | so that we can target commits with CI. Decide how far the version needs to be bumped based on the changes to date. 11 | If unsure what to bump the version to (e.g. is it a major, minor or patch release), check with the Parity Tools team. 12 | 13 | 3. Check that you're happy with the current documentation. 14 | 15 | ``` 16 | cargo doc --open --all-features 17 | ``` 18 | 19 | CI checks for broken internal links at the moment. 
Optionally you can also confirm that any external links 20 | are still valid like so: 21 | 22 | ``` 23 | cargo install cargo-deadlinks 24 | cargo deadlinks --check-http -- --all-features 25 | ``` 26 | 27 | If there are minor issues with the documentation, they can be fixed in the release branch. 28 | 29 | 4. Bump the crate version in `Cargo.toml` to whatever was decided in step 2. 30 | 31 | 5. Update `CHANGELOG.md` to reflect the difference between this release and the last one. If you're unsure of 32 | what to add, check with the Tools team. 33 | 34 | One way to gain some inspiration on what to write is to look at the [closed PRs](https://github.com/paritytech/soketto/pulls?q=is%3Apr+is%3Aclosed). 35 | 36 | You can also look through the commit history to find the code changes since the last release (e.g. `git log --pretty LAST_VERSION_TAG..HEAD`). 37 | 38 | 6. Commit any of the above changes to the release branch and open a PR in GitHub with a base of `master`. 39 | 40 | 7. Once the branch has been reviewed and passes CI, merge it. 41 | 42 | 8. Now, we're ready to publish the release to crates.io. 43 | 44 | Check out `master`, ensuring we're looking at the latest merge (`git pull`). 45 | 46 | Next, do a dry run to make sure that things seem sane: 47 | ``` 48 | cargo publish --dry-run 49 | ``` 50 | 51 | If we're happy with everything, proceed with the release: 52 | ``` 53 | cargo publish 54 | ``` 55 | 56 | 9. If the release was successful, then tag the commit that we released in the `master` branch with the 57 | version that we just released, for example: 58 | 59 | ``` 60 | git tag v0.6.0 # use the version number you've just published to crates.io, not this one 61 | git push --tags 62 | ``` 63 | 64 | Once this is pushed, go along to [the releases page on GitHub](https://github.com/paritytech/soketto/releases) 65 | and draft a new release which points to the tag you just pushed to `master` above. 
Copy the changelog comments 66 | for the current release into the release description. 67 | 68 | -------------------------------------------------------------------------------- /examples/autobahn_client.rs: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2019 Parity Technologies (UK) Ltd. 2 | // 3 | // Licensed under the Apache License, Version 2.0 4 | // or the MIT 5 | // license , at your 6 | // option. All files in the project carrying such notice may not be copied, 7 | // modified, or distributed except according to those terms. 8 | 9 | // Example to be used with the autobahn test suite, a fully automated test 10 | // suite to verify client and server implementations of the websocket 11 | // protocol. 12 | // 13 | // Once started, the tests can be executed with: wstest -m fuzzingserver 14 | // 15 | // See https://github.com/crossbario/autobahn-testsuite for details. 16 | 17 | use futures::io::{BufReader, BufWriter}; 18 | use soketto::{connection, handshake, BoxedError}; 19 | use std::str::FromStr; 20 | use tokio::net::TcpStream; 21 | use tokio_util::compat::{Compat, TokioAsyncReadCompatExt}; 22 | 23 | const SOKETTO_VERSION: &str = env!("CARGO_PKG_VERSION"); 24 | 25 | #[tokio::main] 26 | async fn main() -> Result<(), BoxedError> { 27 | let n = num_of_cases().await?; 28 | for i in 1..=n { 29 | if let Err(e) = run_case(i).await { 30 | log::error!("case {}: {:?}", i, e) 31 | } 32 | } 33 | update_report().await?; 34 | Ok(()) 35 | } 36 | 37 | async fn num_of_cases() -> Result { 38 | let socket = TcpStream::connect("127.0.0.1:9001").await?; 39 | let mut client = new_client(socket, "/getCaseCount"); 40 | assert!(matches!(client.handshake().await?, handshake::ServerResponse::Accepted { .. 
})); 41 | let (_, mut receiver) = client.into_builder().finish(); 42 | let mut data = Vec::new(); 43 | let kind = receiver.receive_data(&mut data).await?; 44 | assert!(kind.is_text()); 45 | let num = usize::from_str(std::str::from_utf8(&data)?)?; 46 | log::info!("{} cases to run", num); 47 | Ok(num) 48 | } 49 | 50 | async fn run_case(n: usize) -> Result<(), BoxedError> { 51 | log::info!("running case {}", n); 52 | let resource = format!("/runCase?case={}&agent=soketto-{}", n, SOKETTO_VERSION); 53 | let socket = TcpStream::connect("127.0.0.1:9001").await?; 54 | let mut client = new_client(socket, &resource); 55 | assert!(matches!(client.handshake().await?, handshake::ServerResponse::Accepted { .. })); 56 | let (mut sender, mut receiver) = client.into_builder().finish(); 57 | let mut message = Vec::new(); 58 | loop { 59 | message.clear(); 60 | match receiver.receive_data(&mut message).await { 61 | Ok(soketto::Data::Binary(n)) => { 62 | assert_eq!(n, message.len()); 63 | sender.send_binary_mut(&mut message).await?; 64 | sender.flush().await? 65 | } 66 | Ok(soketto::Data::Text(n)) => { 67 | assert_eq!(n, message.len()); 68 | sender.send_text(std::str::from_utf8(&message)?).await?; 69 | sender.flush().await? 70 | } 71 | Err(connection::Error::Closed) => return Ok(()), 72 | Err(e) => return Err(e.into()), 73 | } 74 | } 75 | } 76 | 77 | async fn update_report() -> Result<(), BoxedError> { 78 | log::info!("requesting report generation"); 79 | let resource = format!("/updateReports?agent=soketto-{}", SOKETTO_VERSION); 80 | let socket = TcpStream::connect("127.0.0.1:9001").await?; 81 | let mut client = new_client(socket, &resource); 82 | assert!(matches!(client.handshake().await?, handshake::ServerResponse::Accepted { .. 
})); 83 | client.into_builder().finish().0.close().await?; 84 | Ok(()) 85 | } 86 | 87 | #[cfg(not(feature = "deflate"))] 88 | fn new_client(socket: TcpStream, path: &str) -> handshake::Client<'_, BufReader>>> { 89 | handshake::Client::new(BufReader::new(BufWriter::new(socket.compat())), "127.0.0.1:9001", path) 90 | } 91 | 92 | #[cfg(feature = "deflate")] 93 | fn new_client(socket: TcpStream, path: &str) -> handshake::Client<'_, BufReader>>> { 94 | let socket = BufReader::with_capacity(8 * 1024, BufWriter::with_capacity(64 * 1024, socket.compat())); 95 | let mut client = handshake::Client::new(socket, "127.0.0.1:9001", path); 96 | let deflate = soketto::extension::deflate::Deflate::new(soketto::Mode::Client); 97 | client.add_extension(Box::new(deflate)); 98 | client 99 | } 100 | -------------------------------------------------------------------------------- /examples/autobahn_server.rs: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2019 Parity Technologies (UK) Ltd. 2 | // 3 | // Licensed under the Apache License, Version 2.0 4 | // or the MIT 5 | // license , at your 6 | // option. All files in the project carrying such notice may not be copied, 7 | // modified, or distributed except according to those terms. 8 | 9 | // Example to be used with the autobahn test suite, a fully automated test 10 | // suite to verify client and server implementations of the websocket 11 | // protocol. 12 | // 13 | // Once started, the tests can be executed with: wstest -m fuzzingclient 14 | // 15 | // See https://github.com/crossbario/autobahn-testsuite for details. 
16 | 17 | use futures::io::{BufReader, BufWriter}; 18 | use soketto::{connection, handshake, BoxedError}; 19 | use tokio::net::{TcpListener, TcpStream}; 20 | use tokio_stream::{wrappers::TcpListenerStream, StreamExt}; 21 | use tokio_util::compat::{Compat, TokioAsyncReadCompatExt}; 22 | #[tokio::main] 23 | async fn main() -> Result<(), BoxedError> { 24 | let listener = TcpListener::bind("127.0.0.1:9001").await?; 25 | let mut incoming = TcpListenerStream::new(listener); 26 | while let Some(socket) = incoming.next().await { 27 | let mut server = new_server(socket?); 28 | let key = { 29 | let req = server.receive_request().await?; 30 | req.key() 31 | }; 32 | let accept = handshake::server::Response::Accept { key, protocol: None }; 33 | server.send_response(&accept).await?; 34 | let (mut sender, mut receiver) = server.into_builder().finish(); 35 | let mut message = Vec::new(); 36 | loop { 37 | message.clear(); 38 | match receiver.receive_data(&mut message).await { 39 | Ok(soketto::Data::Binary(n)) => { 40 | assert_eq!(n, message.len()); 41 | sender.send_binary_mut(&mut message).await?; 42 | sender.flush().await? 43 | } 44 | Ok(soketto::Data::Text(n)) => { 45 | assert_eq!(n, message.len()); 46 | if let Ok(txt) = std::str::from_utf8(&message) { 47 | sender.send_text(txt).await?; 48 | sender.flush().await? 
49 | } else { 50 | break; 51 | } 52 | } 53 | Err(connection::Error::Closed) => break, 54 | Err(e) => { 55 | log::error!("connection error: {}", e); 56 | break; 57 | } 58 | } 59 | } 60 | } 61 | Ok(()) 62 | } 63 | 64 | #[cfg(not(feature = "deflate"))] 65 | fn new_server<'a>(socket: TcpStream) -> handshake::Server<'a, BufReader>>> { 66 | handshake::Server::new(BufReader::new(BufWriter::new(socket.compat()))) 67 | } 68 | 69 | #[cfg(feature = "deflate")] 70 | fn new_server<'a>(socket: TcpStream) -> handshake::Server<'a, BufReader>>> { 71 | let socket = BufReader::with_capacity(8 * 1024, BufWriter::with_capacity(16 * 1024, socket.compat())); 72 | let mut server = handshake::Server::new(socket); 73 | let deflate = soketto::extension::deflate::Deflate::new(soketto::Mode::Server); 74 | server.add_extension(Box::new(deflate)); 75 | server 76 | } 77 | -------------------------------------------------------------------------------- /examples/hyper_server.rs: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021 Parity Technologies (UK) Ltd. 2 | // 3 | // Licensed under the Apache License, Version 2.0 4 | // or the MIT 5 | // license , at your 6 | // option. All files in the project carrying such notice may not be copied, 7 | // modified, or distributed except according to those terms. 8 | 9 | // An example of how to use of Soketto alongside Hyper, so that we can handle 10 | // standard HTTP traffic with Hyper, and WebSocket connections with Soketto, on 11 | // the same port. 12 | // 13 | // To try this, start up the example (`cargo run --example hyper_server`) and then 14 | // navigate to localhost:3000 and, in the browser JS console, run: 15 | // 16 | // ``` 17 | // var socket = new WebSocket("ws://localhost:3000"); 18 | // socket.onmessage = function(msg) { console.log(msg) }; 19 | // socket.send("Hello!"); 20 | // ``` 21 | // 22 | // You'll see any messages you send echoed back. 
23 | 24 | use std::net::SocketAddr; 25 | 26 | use futures::io::{BufReader, BufWriter}; 27 | use hyper::server::conn::http1; 28 | use hyper::{body::Bytes, service::service_fn, Request, Response}; 29 | use hyper_util::rt::TokioIo; 30 | use soketto::{ 31 | handshake::http::{is_upgrade_request, Server}, 32 | BoxedError, 33 | }; 34 | use tokio_util::compat::TokioAsyncReadCompatExt; 35 | 36 | type FullBody = http_body_util::Full; 37 | 38 | /// Start up a hyper server. 39 | #[tokio::main] 40 | async fn main() -> Result<(), BoxedError> { 41 | env_logger::init(); 42 | 43 | let addr: SocketAddr = ([127, 0, 0, 1], 3000).into(); 44 | let listener = tokio::net::TcpListener::bind(addr).await?; 45 | 46 | log::info!( 47 | "Listening on http://{:?} — connect and I'll echo back anything you send!", 48 | listener.local_addr().unwrap() 49 | ); 50 | 51 | loop { 52 | let stream = match listener.accept().await { 53 | Ok((stream, addr)) => { 54 | log::info!("Accepting new connection: {addr}"); 55 | stream 56 | } 57 | Err(e) => { 58 | log::error!("Accepting new connection failed: {e}"); 59 | continue; 60 | } 61 | }; 62 | 63 | tokio::spawn(async { 64 | let io = TokioIo::new(stream); 65 | let conn = http1::Builder::new().serve_connection(io, service_fn(handler)); 66 | 67 | // Enable upgrades on the connection for the websocket upgrades to work. 68 | let conn = conn.with_upgrades(); 69 | 70 | // Log any errors that might have occurred during the connection. 71 | if let Err(err) = conn.await { 72 | log::error!("HTTP connection failed {err}"); 73 | } 74 | }); 75 | } 76 | } 77 | 78 | /// Handle incoming HTTP Requests. 79 | async fn handler(req: Request) -> Result, BoxedError> { 80 | if is_upgrade_request(&req) { 81 | // Create a new handshake server. 82 | let mut server = Server::new(); 83 | 84 | // Add any extensions that we want to use. 
85 | #[cfg(feature = "deflate")] 86 | { 87 | let deflate = soketto::extension::deflate::Deflate::new(soketto::Mode::Server); 88 | server.add_extension(Box::new(deflate)); 89 | } 90 | 91 | // Attempt the handshake. 92 | match server.receive_request(&req) { 93 | // The handshake has been successful so far; return the response we're given back 94 | // and spawn a task to handle the long-running WebSocket server: 95 | Ok(response) => { 96 | tokio::spawn(async move { 97 | if let Err(e) = websocket_echo_messages(server, req).await { 98 | log::error!("Error upgrading to websocket connection: {}", e); 99 | } 100 | }); 101 | Ok(response.map(|()| FullBody::default())) 102 | } 103 | // We tried to upgrade and failed early on; tell the client about the failure however we like: 104 | Err(e) => { 105 | log::error!("Could not upgrade connection: {}", e); 106 | Ok(Response::new(FullBody::from("Something went wrong upgrading!"))) 107 | } 108 | } 109 | } else { 110 | // The request wasn't an upgrade request; let's treat it as a standard HTTP request: 111 | Ok(Response::new(FullBody::from("Hello HTTP!"))) 112 | } 113 | } 114 | 115 | /// Echo any messages we get from the client back to them 116 | async fn websocket_echo_messages(server: Server, req: Request) -> Result<(), BoxedError> { 117 | // The negotiation to upgrade to a WebSocket connection has been successful so far. Next, we get back the underlying 118 | // stream using `hyper::upgrade::on`, and hand this to a Soketto server to use to handle the WebSocket communication 119 | // on this socket. 120 | // 121 | // Note: awaiting this won't succeed until the handshake response has been returned to the client, so this must be 122 | // spawned on a separate task so as not to block that response being handed back. 
123 | let stream = hyper::upgrade::on(req).await?; 124 | let io = TokioIo::new(stream); 125 | let stream = BufReader::new(BufWriter::new(io.compat())); 126 | 127 | // Get back a reader and writer that we can use to send and receive websocket messages. 128 | let (mut sender, mut receiver) = server.into_builder(stream).finish(); 129 | 130 | // Echo any received messages back to the client: 131 | let mut message = Vec::new(); 132 | loop { 133 | message.clear(); 134 | match receiver.receive_data(&mut message).await { 135 | Ok(soketto::Data::Binary(n)) => { 136 | assert_eq!(n, message.len()); 137 | sender.send_binary_mut(&mut message).await?; 138 | sender.flush().await? 139 | } 140 | Ok(soketto::Data::Text(n)) => { 141 | assert_eq!(n, message.len()); 142 | if let Ok(txt) = std::str::from_utf8(&message) { 143 | sender.send_text(txt).await?; 144 | sender.flush().await? 145 | } else { 146 | break; 147 | } 148 | } 149 | Err(soketto::connection::Error::Closed) => break, 150 | Err(e) => { 151 | eprintln!("Websocket connection error: {}", e); 152 | break; 153 | } 154 | } 155 | } 156 | 157 | Ok(()) 158 | } 159 | -------------------------------------------------------------------------------- /rustfmt.toml: -------------------------------------------------------------------------------- 1 | hard_tabs = true 2 | max_width = 120 3 | use_small_heuristics = "Max" 4 | edition = "2018" 5 | -------------------------------------------------------------------------------- /src/base.rs: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2019 Parity Technologies (UK) Ltd. 2 | // Copyright (c) 2016 twist developers 3 | // 4 | // Licensed under the Apache License, Version 2.0 5 | // or the MIT 6 | // license , at your 7 | // option. All files in the project carrying such notice may not be copied, 8 | // modified, or distributed except according to those terms. 9 | 10 | // This file is largely based on the original twist implementation. 
11 | // See [frame/base.rs] and [codec/base.rs]. 12 | // 13 | // [frame/base.rs]: https://github.com/rustyhorde/twist/blob/449d8b75c2/src/frame/base.rs 14 | // [codec/base.rs]: https://github.com/rustyhorde/twist/blob/449d8b75c2/src/codec/base.rs 15 | 16 | //! A websocket [base frame][base] codec. 17 | //! 18 | //! [base]: https://tools.ietf.org/html/rfc6455#section-5.2 19 | 20 | use crate::{as_u64, Parsing}; 21 | use std::{fmt, io}; 22 | 23 | /// Max. size of a frame header. 24 | pub(crate) const MAX_HEADER_SIZE: usize = 14; 25 | 26 | /// Max. size of a control frame payload. 27 | pub(crate) const MAX_CTRL_BODY_SIZE: u64 = 125; 28 | 29 | // OpCode ///////////////////////////////////////////////////////////////////////////////////////// 30 | 31 | /// Operation codes defined in [RFC 6455](https://tools.ietf.org/html/rfc6455#section-5.2). 32 | #[derive(Debug, Eq, PartialEq, PartialOrd, Ord, Hash, Clone, Copy)] 33 | pub enum OpCode { 34 | /// A continuation frame of a fragmented message. 35 | Continue, 36 | /// A text data frame. 37 | Text, 38 | /// A binary data frame. 39 | Binary, 40 | /// A close control frame. 41 | Close, 42 | /// A ping control frame. 43 | Ping, 44 | /// A pong control frame. 45 | Pong, 46 | /// A reserved op code. 47 | Reserved3, 48 | /// A reserved op code. 49 | Reserved4, 50 | /// A reserved op code. 51 | Reserved5, 52 | /// A reserved op code. 53 | Reserved6, 54 | /// A reserved op code. 55 | Reserved7, 56 | /// A reserved op code. 57 | Reserved11, 58 | /// A reserved op code. 59 | Reserved12, 60 | /// A reserved op code. 61 | Reserved13, 62 | /// A reserved op code. 63 | Reserved14, 64 | /// A reserved op code. 65 | Reserved15, 66 | } 67 | 68 | impl OpCode { 69 | /// Is this a control opcode? 70 | pub fn is_control(self) -> bool { 71 | if let OpCode::Close | OpCode::Ping | OpCode::Pong = self { 72 | true 73 | } else { 74 | false 75 | } 76 | } 77 | 78 | /// Is this opcode reserved? 
79 | pub fn is_reserved(self) -> bool { 80 | match self { 81 | OpCode::Reserved3 82 | | OpCode::Reserved4 83 | | OpCode::Reserved5 84 | | OpCode::Reserved6 85 | | OpCode::Reserved7 86 | | OpCode::Reserved11 87 | | OpCode::Reserved12 88 | | OpCode::Reserved13 89 | | OpCode::Reserved14 90 | | OpCode::Reserved15 => true, 91 | _ => false, 92 | } 93 | } 94 | } 95 | 96 | impl fmt::Display for OpCode { 97 | fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { 98 | match self { 99 | OpCode::Continue => f.write_str("Continue"), 100 | OpCode::Text => f.write_str("Text"), 101 | OpCode::Binary => f.write_str("Binary"), 102 | OpCode::Close => f.write_str("Close"), 103 | OpCode::Ping => f.write_str("Ping"), 104 | OpCode::Pong => f.write_str("Pong"), 105 | OpCode::Reserved3 => f.write_str("Reserved:3"), 106 | OpCode::Reserved4 => f.write_str("Reserved:4"), 107 | OpCode::Reserved5 => f.write_str("Reserved:5"), 108 | OpCode::Reserved6 => f.write_str("Reserved:6"), 109 | OpCode::Reserved7 => f.write_str("Reserved:7"), 110 | OpCode::Reserved11 => f.write_str("Reserved:11"), 111 | OpCode::Reserved12 => f.write_str("Reserved:12"), 112 | OpCode::Reserved13 => f.write_str("Reserved:13"), 113 | OpCode::Reserved14 => f.write_str("Reserved:14"), 114 | OpCode::Reserved15 => f.write_str("Reserved:15"), 115 | } 116 | } 117 | } 118 | 119 | /// Error returned by `OpCode::try_from` if an unknown opcode 120 | /// number is encountered. 
121 | #[derive(Clone, Debug)] 122 | pub struct UnknownOpCode(()); 123 | 124 | impl fmt::Display for UnknownOpCode { 125 | fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { 126 | f.write_str("unknown opcode") 127 | } 128 | } 129 | 130 | impl std::error::Error for UnknownOpCode {} 131 | 132 | impl TryFrom for OpCode { 133 | type Error = UnknownOpCode; 134 | 135 | fn try_from(val: u8) -> Result { 136 | match val { 137 | 0 => Ok(OpCode::Continue), 138 | 1 => Ok(OpCode::Text), 139 | 2 => Ok(OpCode::Binary), 140 | 3 => Ok(OpCode::Reserved3), 141 | 4 => Ok(OpCode::Reserved4), 142 | 5 => Ok(OpCode::Reserved5), 143 | 6 => Ok(OpCode::Reserved6), 144 | 7 => Ok(OpCode::Reserved7), 145 | 8 => Ok(OpCode::Close), 146 | 9 => Ok(OpCode::Ping), 147 | 10 => Ok(OpCode::Pong), 148 | 11 => Ok(OpCode::Reserved11), 149 | 12 => Ok(OpCode::Reserved12), 150 | 13 => Ok(OpCode::Reserved13), 151 | 14 => Ok(OpCode::Reserved14), 152 | 15 => Ok(OpCode::Reserved15), 153 | _ => Err(UnknownOpCode(())), 154 | } 155 | } 156 | } 157 | 158 | impl From for u8 { 159 | fn from(opcode: OpCode) -> u8 { 160 | match opcode { 161 | OpCode::Continue => 0, 162 | OpCode::Text => 1, 163 | OpCode::Binary => 2, 164 | OpCode::Close => 8, 165 | OpCode::Ping => 9, 166 | OpCode::Pong => 10, 167 | OpCode::Reserved3 => 3, 168 | OpCode::Reserved4 => 4, 169 | OpCode::Reserved5 => 5, 170 | OpCode::Reserved6 => 6, 171 | OpCode::Reserved7 => 7, 172 | OpCode::Reserved11 => 11, 173 | OpCode::Reserved12 => 12, 174 | OpCode::Reserved13 => 13, 175 | OpCode::Reserved14 => 14, 176 | OpCode::Reserved15 => 15, 177 | } 178 | } 179 | } 180 | 181 | // Frame header /////////////////////////////////////////////////////////////////////////////////// 182 | 183 | /// A websocket base frame header, i.e. everything but the payload. 
184 | #[derive(Debug, Clone)] 185 | pub struct Header { 186 | fin: bool, 187 | rsv1: bool, 188 | rsv2: bool, 189 | rsv3: bool, 190 | masked: bool, 191 | opcode: OpCode, 192 | mask: u32, 193 | payload_len: usize, 194 | } 195 | 196 | impl fmt::Display for Header { 197 | fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { 198 | write!( 199 | f, 200 | "({} (fin {}) (rsv {}{}{}) (mask ({} {:x})) (len {}))", 201 | self.opcode, 202 | self.fin as u8, 203 | self.rsv1 as u8, 204 | self.rsv2 as u8, 205 | self.rsv3 as u8, 206 | self.masked as u8, 207 | self.mask, 208 | self.payload_len 209 | ) 210 | } 211 | } 212 | 213 | impl Header { 214 | /// Create a new frame header with a given [`OpCode`]. 215 | pub fn new(oc: OpCode) -> Self { 216 | Header { fin: true, rsv1: false, rsv2: false, rsv3: false, masked: false, opcode: oc, mask: 0, payload_len: 0 } 217 | } 218 | 219 | /// Is the `fin` flag set? 220 | pub fn is_fin(&self) -> bool { 221 | self.fin 222 | } 223 | 224 | /// Set the `fin` flag. 225 | pub fn set_fin(&mut self, fin: bool) -> &mut Self { 226 | self.fin = fin; 227 | self 228 | } 229 | 230 | /// Is the `rsv1` flag set? 231 | pub fn is_rsv1(&self) -> bool { 232 | self.rsv1 233 | } 234 | 235 | /// Set the `rsv1` flag. 236 | pub fn set_rsv1(&mut self, rsv1: bool) -> &mut Self { 237 | self.rsv1 = rsv1; 238 | self 239 | } 240 | 241 | /// Is the `rsv2` flag set? 242 | pub fn is_rsv2(&self) -> bool { 243 | self.rsv2 244 | } 245 | 246 | /// Set the `rsv2` flag. 247 | pub fn set_rsv2(&mut self, rsv2: bool) -> &mut Self { 248 | self.rsv2 = rsv2; 249 | self 250 | } 251 | 252 | /// Is the `rsv3` flag set? 253 | pub fn is_rsv3(&self) -> bool { 254 | self.rsv3 255 | } 256 | 257 | /// Set the `rsv3` flag. 258 | pub fn set_rsv3(&mut self, rsv3: bool) -> &mut Self { 259 | self.rsv3 = rsv3; 260 | self 261 | } 262 | 263 | /// Is the `masked` flag set? 264 | pub fn is_masked(&self) -> bool { 265 | self.masked 266 | } 267 | 268 | /// Set the `masked` flag. 
269 | pub fn set_masked(&mut self, masked: bool) -> &mut Self { 270 | self.masked = masked; 271 | self 272 | } 273 | 274 | /// Get the `opcode`. 275 | pub fn opcode(&self) -> OpCode { 276 | self.opcode 277 | } 278 | 279 | /// Set the `opcode` 280 | pub fn set_opcode(&mut self, opcode: OpCode) -> &mut Self { 281 | self.opcode = opcode; 282 | self 283 | } 284 | 285 | /// Get the `mask`. 286 | pub fn mask(&self) -> u32 { 287 | self.mask 288 | } 289 | 290 | /// Set the `mask` 291 | pub fn set_mask(&mut self, mask: u32) -> &mut Self { 292 | self.mask = mask; 293 | self 294 | } 295 | 296 | /// Get the payload length. 297 | pub fn payload_len(&self) -> usize { 298 | self.payload_len 299 | } 300 | 301 | /// Set the payload length. 302 | pub fn set_payload_len(&mut self, len: usize) -> &mut Self { 303 | self.payload_len = len; 304 | self 305 | } 306 | } 307 | 308 | // Base codec ////////////////////////////////////////////////////////////////////////////////////. 309 | 310 | /// If the payload length byte is 126, the following two bytes represent the 311 | /// actual payload length. 312 | const TWO_EXT: u8 = 126; 313 | 314 | /// If the payload length byte is 127, the following eight bytes represent 315 | /// the actual payload length. 316 | const EIGHT_EXT: u8 = 127; 317 | 318 | /// Codec for encoding/decoding websocket [base] frames. 319 | /// 320 | /// [base]: https://tools.ietf.org/html/rfc6455#section-5.2 321 | #[derive(Debug, Clone)] 322 | pub struct Codec { 323 | /// Maximum size of payload data per frame. 324 | max_data_size: usize, 325 | /// Bits reserved by an extension. 326 | reserved_bits: u8, 327 | /// Scratch buffer used during header encoding. 328 | header_buffer: [u8; MAX_HEADER_SIZE], 329 | } 330 | 331 | impl Default for Codec { 332 | fn default() -> Self { 333 | Codec { max_data_size: 256 * 1024 * 1024, reserved_bits: 0, header_buffer: [0; MAX_HEADER_SIZE] } 334 | } 335 | } 336 | 337 | impl Codec { 338 | /// Create a new base frame codec. 
339 | /// 340 | /// The codec will support decoding payload lengths up to 256 MiB 341 | /// (use `set_max_data_size` to change this value). 342 | pub fn new() -> Self { 343 | Codec::default() 344 | } 345 | 346 | /// Get the configured maximum payload length. 347 | pub fn max_data_size(&self) -> usize { 348 | self.max_data_size 349 | } 350 | 351 | /// Limit the maximum size of payload data to `size` bytes. 352 | pub fn set_max_data_size(&mut self, size: usize) -> &mut Self { 353 | self.max_data_size = size; 354 | self 355 | } 356 | 357 | /// The reserved bits currently configured. 358 | pub fn reserved_bits(&self) -> (bool, bool, bool) { 359 | let r = self.reserved_bits; 360 | (r & 4 == 4, r & 2 == 2, r & 1 == 1) 361 | } 362 | 363 | /// Add to the reserved bits in use. 364 | pub fn add_reserved_bits(&mut self, bits: (bool, bool, bool)) -> &mut Self { 365 | let (r1, r2, r3) = bits; 366 | self.reserved_bits |= (r1 as u8) << 2 | (r2 as u8) << 1 | r3 as u8; 367 | self 368 | } 369 | 370 | /// Reset the reserved bits. 371 | pub fn clear_reserved_bits(&mut self) { 372 | self.reserved_bits = 0 373 | } 374 | 375 | /// Decode a websocket frame header. 
376 | pub fn decode_header(&self, bytes: &[u8]) -> Result, Error> { 377 | if bytes.len() < 2 { 378 | return Ok(Parsing::NeedMore(2 - bytes.len())); 379 | } 380 | 381 | let first = bytes[0]; 382 | let second = bytes[1]; 383 | let mut offset = 2; 384 | 385 | let fin = first & 0x80 != 0; 386 | let opcode = OpCode::try_from(first & 0xF)?; 387 | 388 | if opcode.is_reserved() { 389 | return Err(Error::ReservedOpCode); 390 | } 391 | 392 | if opcode.is_control() && !fin { 393 | return Err(Error::FragmentedControl); 394 | } 395 | 396 | let mut header = Header::new(opcode); 397 | header.set_fin(fin); 398 | 399 | let rsv1 = first & 0x40 != 0; 400 | if rsv1 && (self.reserved_bits & 4 == 0) { 401 | return Err(Error::InvalidReservedBit(1)); 402 | } 403 | header.set_rsv1(rsv1); 404 | 405 | let rsv2 = first & 0x20 != 0; 406 | if rsv2 && (self.reserved_bits & 2 == 0) { 407 | return Err(Error::InvalidReservedBit(2)); 408 | } 409 | header.set_rsv2(rsv2); 410 | 411 | let rsv3 = first & 0x10 != 0; 412 | if rsv3 && (self.reserved_bits & 1 == 0) { 413 | return Err(Error::InvalidReservedBit(3)); 414 | } 415 | header.set_rsv3(rsv3); 416 | header.set_masked(second & 0x80 != 0); 417 | 418 | let len: u64 = match second & 0x7F { 419 | TWO_EXT => { 420 | if bytes.len() < offset + 2 { 421 | return Ok(Parsing::NeedMore(offset + 2 - bytes.len())); 422 | } 423 | let len = u16::from_be_bytes([bytes[offset], bytes[offset + 1]]); 424 | offset += 2; 425 | u64::from(len) 426 | } 427 | EIGHT_EXT => { 428 | if bytes.len() < offset + 8 { 429 | return Ok(Parsing::NeedMore(offset + 8 - bytes.len())); 430 | } 431 | let mut b = [0; 8]; 432 | b.copy_from_slice(&bytes[offset..offset + 8]); 433 | offset += 8; 434 | u64::from_be_bytes(b) 435 | } 436 | n => u64::from(n), 437 | }; 438 | 439 | if len > MAX_CTRL_BODY_SIZE && header.opcode().is_control() { 440 | return Err(Error::InvalidControlFrameLen); 441 | } 442 | 443 | let len: usize = if len > as_u64(self.max_data_size) { 444 | return Err(Error::PayloadTooLarge { 
actual: len, maximum: as_u64(self.max_data_size) }); 445 | } else { 446 | len as usize 447 | }; 448 | 449 | header.set_payload_len(len); 450 | 451 | if header.is_masked() { 452 | if bytes.len() < offset + 4 { 453 | return Ok(Parsing::NeedMore(offset + 4 - bytes.len())); 454 | } 455 | let mut b = [0; 4]; 456 | b.copy_from_slice(&bytes[offset..offset + 4]); 457 | offset += 4; 458 | header.set_mask(u32::from_be_bytes(b)); 459 | } 460 | 461 | Ok(Parsing::Done { value: header, offset }) 462 | } 463 | 464 | /// Encode a websocket frame header. 465 | pub fn encode_header(&mut self, header: &Header) -> &[u8] { 466 | let mut offset = 0; 467 | 468 | let mut first_byte = 0_u8; 469 | if header.is_fin() { 470 | first_byte |= 0x80 471 | } 472 | if header.is_rsv1() { 473 | first_byte |= 0x40 474 | } 475 | if header.is_rsv2() { 476 | first_byte |= 0x20 477 | } 478 | if header.is_rsv3() { 479 | first_byte |= 0x10 480 | } 481 | 482 | let opcode: u8 = header.opcode().into(); 483 | first_byte |= opcode; 484 | 485 | self.header_buffer[offset] = first_byte; 486 | offset += 1; 487 | 488 | let mut second_byte = 0_u8; 489 | if header.is_masked() { 490 | second_byte |= 0x80 491 | } 492 | 493 | let len = header.payload_len(); 494 | 495 | if len < usize::from(TWO_EXT) { 496 | second_byte |= len as u8; 497 | self.header_buffer[offset] = second_byte; 498 | offset += 1; 499 | } else if len <= usize::from(u16::max_value()) { 500 | second_byte |= TWO_EXT; 501 | self.header_buffer[offset] = second_byte; 502 | offset += 1; 503 | self.header_buffer[offset..offset + 2].copy_from_slice(&(len as u16).to_be_bytes()); 504 | offset += 2; 505 | } else { 506 | second_byte |= EIGHT_EXT; 507 | self.header_buffer[offset] = second_byte; 508 | offset += 1; 509 | self.header_buffer[offset..offset + 8].copy_from_slice(&as_u64(len).to_be_bytes()); 510 | offset += 8; 511 | } 512 | 513 | if header.is_masked() { 514 | self.header_buffer[offset..offset + 4].copy_from_slice(&header.mask().to_be_bytes()); 515 | offset += 
4; 516 | } 517 | 518 | &self.header_buffer[..offset] 519 | } 520 | 521 | /// Use the given header's mask and apply it to the data. 522 | pub fn apply_mask(header: &Header, data: &mut [u8]) { 523 | if header.is_masked() { 524 | let mask = header.mask().to_be_bytes(); 525 | for (byte, &key) in data.iter_mut().zip(mask.iter().cycle()) { 526 | *byte ^= key; 527 | } 528 | } 529 | } 530 | } 531 | 532 | /// Error cases the base frame decoder may encounter. 533 | #[non_exhaustive] 534 | #[derive(Debug)] 535 | pub enum Error { 536 | /// An I/O error has been encountered. 537 | Io(io::Error), 538 | /// Some unknown opcode number has been decoded. 539 | UnknownOpCode, 540 | /// The opcode decoded is reserved. 541 | ReservedOpCode, 542 | /// A fragmented control frame (fin bit not set) has been decoded. 543 | FragmentedControl, 544 | /// A control frame with an invalid length code has been decoded. 545 | InvalidControlFrameLen, 546 | /// The reserved bit is invalid. 547 | InvalidReservedBit(u8), 548 | /// The payload length of a frame exceeded the configured maximum. 
549 | PayloadTooLarge { actual: u64, maximum: u64 }, 550 | } 551 | 552 | impl fmt::Display for Error { 553 | fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { 554 | match self { 555 | Error::Io(e) => write!(f, "i/o error: {}", e), 556 | Error::UnknownOpCode => f.write_str("unknown opcode"), 557 | Error::ReservedOpCode => f.write_str("reserved opcode"), 558 | Error::FragmentedControl => f.write_str("fragmented control frame"), 559 | Error::InvalidControlFrameLen => f.write_str("invalid control frame length"), 560 | Error::InvalidReservedBit(n) => write!(f, "invalid reserved bit: {}", n), 561 | Error::PayloadTooLarge { actual, maximum } => { 562 | write!(f, "payload too large: len = {}, maximum = {}", actual, maximum) 563 | } 564 | } 565 | } 566 | } 567 | 568 | impl std::error::Error for Error { 569 | fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { 570 | match self { 571 | Error::Io(e) => Some(e), 572 | Error::UnknownOpCode 573 | | Error::ReservedOpCode 574 | | Error::FragmentedControl 575 | | Error::InvalidControlFrameLen 576 | | Error::InvalidReservedBit(_) 577 | | Error::PayloadTooLarge { .. } => None, 578 | } 579 | } 580 | } 581 | 582 | impl From for Error { 583 | fn from(e: io::Error) -> Self { 584 | Error::Io(e) 585 | } 586 | } 587 | 588 | impl From for Error { 589 | fn from(_: UnknownOpCode) -> Self { 590 | Error::UnknownOpCode 591 | } 592 | } 593 | 594 | // Tests ////////////////////////////////////////////////////////////////////////////////////////// 595 | 596 | #[cfg(test)] 597 | mod test { 598 | use super::{Codec, Error, OpCode}; 599 | use crate::Parsing; 600 | use quickcheck::QuickCheck; 601 | 602 | #[test] 603 | fn decode_partial_header() { 604 | let partial_header: &[u8] = &[0x89]; 605 | assert!(matches! 
{ 606 | Codec::new().decode_header(partial_header), 607 | Ok(Parsing::NeedMore(1)) 608 | }) 609 | } 610 | 611 | #[test] 612 | fn decode_partial_len() { 613 | let partial_length_1: &[u8] = &[0x89, 0xFE, 0x01]; 614 | assert!(matches! { 615 | Codec::new().decode_header(partial_length_1), 616 | Ok(Parsing::NeedMore(1)) 617 | }); 618 | let partial_length_2: &[u8] = &[0x89, 0xFF, 0x01, 0x02, 0x03, 0x04]; 619 | assert!(matches! { 620 | Codec::new().decode_header(partial_length_2), 621 | Ok(Parsing::NeedMore(4)) 622 | }) 623 | } 624 | 625 | #[test] 626 | fn decode_partial_mask() { 627 | let partial_mask: &[u8] = &[0x82, 0xFE, 0x01, 0x02, 0x00, 0x00]; 628 | assert!(matches! { 629 | Codec::new().decode_header(partial_mask), 630 | Ok(Parsing::NeedMore(2)) 631 | }) 632 | } 633 | 634 | #[test] 635 | fn decode_partial_payload() { 636 | let partial_payload: &mut [u8] = &mut [0x82, 0x85, 0x01, 0x02, 0x03, 0x04, 0x00, 0x00]; 637 | if let Ok(Parsing::Done { value, offset }) = Codec::new().decode_header(partial_payload) { 638 | assert_eq!(3, value.payload_len() - (partial_payload.len() - offset)) 639 | } else { 640 | assert!(false) 641 | } 642 | } 643 | 644 | #[test] 645 | fn decode_invalid_control_payload_len() { 646 | // Payload on control frame must be 125 bytes or less. 2nd byte must be 0xFD or less. 647 | let ctrl_payload_len: &[u8] = &[0x89, 0xFE, 0x10, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00]; 648 | assert!(matches! { 649 | Codec::new().decode_header(ctrl_payload_len), 650 | Err(Error::InvalidControlFrameLen) 651 | }) 652 | } 653 | 654 | /// Checking that rsv1, rsv2, and rsv3 bit set returns error. 655 | #[test] 656 | fn decode_reserved() { 657 | // rsv1, rsv2, and rsv3. 658 | let reserved = [0x90, 0xa0, 0xc0]; 659 | for res in &reserved { 660 | let mut buf = [0; 2]; 661 | buf[0] |= *res; 662 | assert!(matches! 
{ 663 | Codec::new().decode_header(&buf), 664 | Err(Error::InvalidReservedBit(_)) 665 | }) 666 | } 667 | } 668 | 669 | /// Checking that a control frame, where fin bit is 0, returns an error. 670 | #[test] 671 | fn decode_fragmented_control() { 672 | let second_bytes = [8, 9, 10]; 673 | for sb in &second_bytes { 674 | let mut buf = [0; 2]; 675 | buf[0] |= *sb; 676 | assert!(matches! { 677 | Codec::new().decode_header(&buf), 678 | Err(Error::FragmentedControl) 679 | }) 680 | } 681 | } 682 | 683 | /// Checking that reserved opcodes return an error. 684 | #[test] 685 | fn decode_reserved_opcodes() { 686 | let reserved = [3, 4, 5, 6, 7, 11, 12, 13, 14, 15]; 687 | for res in &reserved { 688 | let mut buf = [0; 2]; 689 | buf[0] |= 0x80 | *res; 690 | assert!(matches! { 691 | Codec::new().decode_header(&buf), 692 | Err(Error::ReservedOpCode) 693 | }) 694 | } 695 | } 696 | 697 | #[test] 698 | fn decode_ping_no_data() { 699 | let ping_no_data: &mut [u8] = &mut [0x89, 0x80, 0x00, 0x00, 0x00, 0x01]; 700 | let c = Codec::new(); 701 | if let Ok(Parsing::Done { value: header, .. }) = c.decode_header(ping_no_data) { 702 | assert!(header.is_fin()); 703 | assert!(!header.is_rsv1()); 704 | assert!(!header.is_rsv2()); 705 | assert!(!header.is_rsv3()); 706 | assert!(header.opcode() == OpCode::Ping); 707 | assert!(header.payload_len() == 0) 708 | } else { 709 | assert!(false) 710 | } 711 | } 712 | 713 | #[test] 714 | fn reserved_bits() { 715 | fn property(bits: (bool, bool, bool)) -> bool { 716 | let mut c = Codec::new(); 717 | assert_eq!((false, false, false), c.reserved_bits()); 718 | c.add_reserved_bits(bits); 719 | bits == c.reserved_bits() 720 | } 721 | QuickCheck::new().quickcheck(property as fn((bool, bool, bool)) -> bool) 722 | } 723 | } 724 | -------------------------------------------------------------------------------- /src/connection.rs: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2019 Parity Technologies (UK) Ltd. 
2 | // 3 | // Licensed under the Apache License, Version 2.0 4 | // or the MIT 5 | // license , at your 6 | // option. All files in the project carrying such notice may not be copied, 7 | // modified, or distributed except according to those terms. 8 | 9 | //! A persistent websocket connection after the handshake phase, represented 10 | //! as a [`Sender`] and [`Receiver`] pair. 11 | 12 | use crate::data::{ByteSlice125, Data, Incoming}; 13 | use crate::{ 14 | base::{self, Header, OpCode, MAX_HEADER_SIZE}, 15 | extension::Extension, 16 | Parsing, Storage, 17 | }; 18 | use bytes::{Buf, BytesMut}; 19 | use futures::{ 20 | io::{ReadHalf, WriteHalf}, 21 | lock::BiLock, 22 | prelude::*, 23 | }; 24 | use std::{fmt, io, str}; 25 | 26 | /// Accumulated max. size of a complete message. 27 | const MAX_MESSAGE_SIZE: usize = 256 * 1024 * 1024; 28 | 29 | /// Max. size of a single message frame. 30 | const MAX_FRAME_SIZE: usize = MAX_MESSAGE_SIZE; 31 | 32 | /// Is the connection used by a client or server? 33 | #[derive(Copy, Clone, Debug, PartialEq, Eq)] 34 | pub enum Mode { 35 | /// Client-side of a connection (implies masking of payload data). 36 | Client, 37 | /// Server-side of a connection. 38 | Server, 39 | } 40 | 41 | impl Mode { 42 | pub fn is_client(self) -> bool { 43 | if let Mode::Client = self { 44 | true 45 | } else { 46 | false 47 | } 48 | } 49 | 50 | pub fn is_server(self) -> bool { 51 | !self.is_client() 52 | } 53 | } 54 | 55 | /// Connection ID. 56 | #[derive(Clone, Copy, Debug)] 57 | struct Id(u32); 58 | 59 | impl fmt::Display for Id { 60 | fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { 61 | write!(f, "{:08x}", self.0) 62 | } 63 | } 64 | 65 | /// The sending half of a connection. 66 | #[derive(Debug)] 67 | pub struct Sender { 68 | id: Id, 69 | mode: Mode, 70 | codec: base::Codec, 71 | writer: BiLock>, 72 | mask_buffer: Vec, 73 | extensions: BiLock>>, 74 | has_extensions: bool, 75 | } 76 | 77 | /// The receiving half of a connection. 
78 | #[derive(Debug)] 79 | pub struct Receiver { 80 | id: Id, 81 | mode: Mode, 82 | codec: base::Codec, 83 | reader: ReadHalf, 84 | writer: BiLock>, 85 | extensions: BiLock>>, 86 | has_extensions: bool, 87 | buffer: BytesMut, 88 | ctrl_buffer: BytesMut, 89 | max_message_size: usize, 90 | is_closed: bool, 91 | } 92 | 93 | /// A connection builder. 94 | /// 95 | /// Allows configuring certain parameters and extensions before 96 | /// creating the [`Sender`]/[`Receiver`] pair that represents the 97 | /// connection. 98 | #[derive(Debug)] 99 | pub struct Builder { 100 | id: Id, 101 | mode: Mode, 102 | socket: T, 103 | codec: base::Codec, 104 | extensions: Vec>, 105 | buffer: BytesMut, 106 | max_message_size: usize, 107 | } 108 | 109 | impl Builder { 110 | /// Create a new `Builder` from the given async I/O resource and mode. 111 | /// 112 | /// **Note**: Use this type only after a successful [handshake][0]. 113 | /// You can either use this crate's [handshake functionality][1] 114 | /// or perform the handshake by some other means. 115 | /// 116 | /// [0]: https://tools.ietf.org/html/rfc6455#section-4 117 | /// [1]: crate::handshake 118 | pub fn new(socket: T, mode: Mode) -> Self { 119 | let mut codec = base::Codec::default(); 120 | codec.set_max_data_size(MAX_FRAME_SIZE); 121 | Builder { 122 | id: Id(rand::random()), 123 | mode, 124 | socket, 125 | codec, 126 | extensions: Vec::new(), 127 | buffer: BytesMut::new(), 128 | max_message_size: MAX_MESSAGE_SIZE, 129 | } 130 | } 131 | 132 | /// Set a custom buffer to use. 133 | pub fn set_buffer(&mut self, b: BytesMut) { 134 | self.buffer = b 135 | } 136 | 137 | /// Add extensions to use with this connection. 138 | /// 139 | /// Only enabled extensions will be considered. 
140 | pub fn add_extensions(&mut self, extensions: I) 141 | where 142 | I: IntoIterator>, 143 | { 144 | for e in extensions.into_iter().filter(|e| e.is_enabled()) { 145 | log::debug!("{}: using extension: {}", self.id, e.name()); 146 | self.codec.add_reserved_bits(e.reserved_bits()); 147 | self.extensions.push(e) 148 | } 149 | } 150 | 151 | /// Set the maximum size of a complete message. 152 | /// 153 | /// Message fragments will be buffered and concatenated up to this value, 154 | /// i.e. the sum of all message frames payload lengths will not be greater 155 | /// than this maximum. However, extensions may increase the total message 156 | /// size further, e.g. by decompressing the payload data. 157 | pub fn set_max_message_size(&mut self, max: usize) { 158 | self.max_message_size = max 159 | } 160 | 161 | /// Set the maximum size of a single websocket frame payload. 162 | pub fn set_max_frame_size(&mut self, max: usize) { 163 | self.codec.set_max_data_size(max); 164 | } 165 | 166 | /// Create a configured [`Sender`]/[`Receiver`] pair. 167 | pub fn finish(self) -> (Sender, Receiver) { 168 | let (rhlf, whlf) = self.socket.split(); 169 | let (wrt1, wrt2) = BiLock::new(whlf); 170 | let has_extensions = !self.extensions.is_empty(); 171 | let (ext1, ext2) = BiLock::new(self.extensions); 172 | 173 | let recv = Receiver { 174 | id: self.id, 175 | mode: self.mode, 176 | reader: rhlf, 177 | writer: wrt1, 178 | codec: self.codec.clone(), 179 | extensions: ext1, 180 | has_extensions, 181 | buffer: self.buffer, 182 | ctrl_buffer: BytesMut::new(), 183 | max_message_size: self.max_message_size, 184 | is_closed: false, 185 | }; 186 | 187 | let send = Sender { 188 | id: self.id, 189 | mode: self.mode, 190 | writer: wrt2, 191 | mask_buffer: Vec::new(), 192 | codec: self.codec, 193 | extensions: ext2, 194 | has_extensions, 195 | }; 196 | 197 | (send, recv) 198 | } 199 | } 200 | 201 | impl Receiver { 202 | /// Receive the next websocket message. 
203 | /// 204 | /// The received frames forming the complete message will be appended to 205 | /// the given `message` argument. The returned [`Incoming`] value describes 206 | /// the type of data that was received, e.g. binary or textual data. 207 | /// 208 | /// Interleaved PONG frames are returned immediately as `Data::Pong` 209 | /// values. If PONGs are not expected or uninteresting, 210 | /// [`Receiver::receive_data`] may be used instead which skips over PONGs 211 | /// and considers only application payload data. 212 | pub async fn receive(&mut self, message: &mut Vec) -> Result, Error> { 213 | let mut first_fragment_opcode = None; 214 | let mut length: usize = 0; 215 | let message_len = message.len(); 216 | loop { 217 | if self.is_closed { 218 | log::debug!("{}: cannot receive, connection is closed", self.id); 219 | return Err(Error::Closed); 220 | } 221 | 222 | self.ctrl_buffer.clear(); 223 | let mut header = self.receive_header().await?; 224 | log::trace!("{}: recv: {}", self.id, header); 225 | 226 | // Handle control frames: PING, PONG and CLOSE. 227 | if header.opcode().is_control() { 228 | self.read_buffer(&header).await?; 229 | self.ctrl_buffer = self.buffer.split_to(header.payload_len()); 230 | base::Codec::apply_mask(&header, &mut self.ctrl_buffer); 231 | if header.opcode() == OpCode::Pong { 232 | return Ok(Incoming::Pong(&self.ctrl_buffer[..])); 233 | } 234 | if let Some(close_reason) = self.on_control(&header).await? { 235 | log::trace!("{}: recv, incoming CLOSE: {:?}", self.id, close_reason); 236 | return Ok(Incoming::Closed(close_reason)); 237 | } 238 | continue; 239 | } 240 | 241 | length = length.saturating_add(header.payload_len()); 242 | 243 | // Check if total message does not exceed maximum. 244 | if length > self.max_message_size { 245 | log::warn!("{}: accumulated message length exceeds maximum", self.id); 246 | 247 | // Discard bytes that were too large to fit in the buffer. 
248 | discard_bytes(length as u64, &mut self.reader).await?; 249 | return Err(Error::MessageTooLarge { current: length, maximum: self.max_message_size }); 250 | } 251 | 252 | // Get the frame's payload data bytes from buffer or socket. 253 | { 254 | let old_msg_len = message.len(); 255 | 256 | let bytes_to_read = { 257 | let required = header.payload_len(); 258 | let buffered = self.buffer.len(); 259 | 260 | if buffered == 0 { 261 | required 262 | } else if required > buffered { 263 | message.extend_from_slice(&self.buffer); 264 | self.buffer.clear(); 265 | required - buffered 266 | } else { 267 | message.extend_from_slice(&self.buffer.split_to(required)); 268 | 0 269 | } 270 | }; 271 | 272 | if bytes_to_read > 0 { 273 | let n = message.len(); 274 | message.resize(n + bytes_to_read, 0u8); 275 | self.reader.read_exact(&mut message[n..]).await? 276 | } 277 | 278 | debug_assert_eq!(header.payload_len(), message.len() - old_msg_len); 279 | 280 | base::Codec::apply_mask(&header, &mut message[old_msg_len..]); 281 | } 282 | 283 | match (header.is_fin(), header.opcode()) { 284 | (false, OpCode::Continue) => { 285 | // Intermediate message fragment. 286 | if first_fragment_opcode.is_none() { 287 | log::debug!("{}: continue frame while not processing message fragments", self.id); 288 | return Err(Error::UnexpectedOpCode(OpCode::Continue)); 289 | } 290 | continue; 291 | } 292 | (false, oc) => { 293 | // Initial message fragment. 294 | if first_fragment_opcode.is_some() { 295 | log::debug!("{}: initial fragment while processing a fragmented message", self.id); 296 | return Err(Error::UnexpectedOpCode(oc)); 297 | } 298 | first_fragment_opcode = Some(oc); 299 | self.decode_with_extensions(&mut header, message).await?; 300 | continue; 301 | } 302 | (true, OpCode::Continue) => { 303 | // Last message fragment. 
304 | if let Some(oc) = first_fragment_opcode.take() { 305 | header.set_payload_len(message.len()); 306 | log::trace!("{}: last fragment: total length = {} bytes", self.id, message.len()); 307 | self.decode_with_extensions(&mut header, message).await?; 308 | header.set_opcode(oc); 309 | } else { 310 | log::debug!("{}: last continue frame while not processing message fragments", self.id); 311 | return Err(Error::UnexpectedOpCode(OpCode::Continue)); 312 | } 313 | } 314 | (true, oc) => { 315 | // Regular non-fragmented message. 316 | if first_fragment_opcode.is_some() { 317 | log::debug!("{}: regular message while processing fragmented message", self.id); 318 | return Err(Error::UnexpectedOpCode(oc)); 319 | } 320 | self.decode_with_extensions(&mut header, message).await? 321 | } 322 | } 323 | 324 | let num_bytes = message.len() - message_len; 325 | 326 | if header.opcode() == OpCode::Text { 327 | return Ok(Incoming::Data(Data::Text(num_bytes))); 328 | } else { 329 | return Ok(Incoming::Data(Data::Binary(num_bytes))); 330 | } 331 | } 332 | } 333 | 334 | /// Receive the next websocket message, skipping over control frames. 335 | pub async fn receive_data(&mut self, message: &mut Vec) -> Result { 336 | loop { 337 | if let Incoming::Data(d) = self.receive(message).await? { 338 | return Ok(d); 339 | } 340 | } 341 | } 342 | 343 | /// Read the next frame header. 344 | async fn receive_header(&mut self) -> Result { 345 | loop { 346 | match self.codec.decode_header(&self.buffer)? { 347 | Parsing::Done { value: header, offset } => { 348 | debug_assert!(offset <= MAX_HEADER_SIZE); 349 | self.buffer.advance(offset); 350 | return Ok(header); 351 | } 352 | Parsing::NeedMore(n) => crate::read(&mut self.reader, &mut self.buffer, n).await?, 353 | } 354 | } 355 | } 356 | 357 | /// Read the complete payload data into the read buffer. 
358 | async fn read_buffer(&mut self, header: &Header) -> Result<(), Error> { 359 | if header.payload_len() <= self.buffer.len() { 360 | return Ok(()); 361 | } 362 | let i = self.buffer.len(); 363 | let d = header.payload_len() - i; 364 | self.buffer.resize(i + d, 0u8); 365 | self.reader.read_exact(&mut self.buffer[i..]).await?; 366 | Ok(()) 367 | } 368 | 369 | /// Answer incoming control frames. 370 | /// `PING`: replied to immediately with a `PONG` 371 | /// `PONG`: no action 372 | /// `CLOSE`: replied to immediately with a `CLOSE`; returns the [`CloseReason`] 373 | /// All other [`OpCode`]s return [`Error::UnexpectedOpCode`] 374 | async fn on_control(&mut self, header: &Header) -> Result, Error> { 375 | match header.opcode() { 376 | OpCode::Ping => { 377 | let mut answer = Header::new(OpCode::Pong); 378 | let mut unused = Vec::new(); 379 | let mut data = Storage::Unique(&mut self.ctrl_buffer); 380 | write(self.id, self.mode, &mut self.codec, &mut self.writer, &mut answer, &mut data, &mut unused) 381 | .await?; 382 | self.flush().await?; 383 | Ok(None) 384 | } 385 | OpCode::Pong => Ok(None), 386 | OpCode::Close => { 387 | log::trace!("{}: Acknowledging CLOSE to sender", self.id); 388 | let (mut header, reason) = close_answer(&self.ctrl_buffer)?; 389 | // Write back a Close frame 390 | let mut unused = Vec::new(); 391 | if let Some(CloseReason { code, .. 
}) = reason { 392 | let mut data = code.to_be_bytes(); 393 | let mut data = Storage::Unique(&mut data); 394 | let _ = write( 395 | self.id, 396 | self.mode, 397 | &mut self.codec, 398 | &mut self.writer, 399 | &mut header, 400 | &mut data, 401 | &mut unused, 402 | ) 403 | .await; 404 | } else { 405 | let mut data = Storage::Unique(&mut []); 406 | let _ = write( 407 | self.id, 408 | self.mode, 409 | &mut self.codec, 410 | &mut self.writer, 411 | &mut header, 412 | &mut data, 413 | &mut unused, 414 | ) 415 | .await; 416 | } 417 | self.flush().await?; 418 | // Close down the connection but the I/O stream could already be closed and 419 | // we don't want propagate such error to the user if the I/O was already closed. 420 | _ = self.writer.lock().await.close().await; 421 | self.is_closed = true; 422 | Ok(reason) 423 | } 424 | OpCode::Binary 425 | | OpCode::Text 426 | | OpCode::Continue 427 | | OpCode::Reserved3 428 | | OpCode::Reserved4 429 | | OpCode::Reserved5 430 | | OpCode::Reserved6 431 | | OpCode::Reserved7 432 | | OpCode::Reserved11 433 | | OpCode::Reserved12 434 | | OpCode::Reserved13 435 | | OpCode::Reserved14 436 | | OpCode::Reserved15 => Err(Error::UnexpectedOpCode(header.opcode())), 437 | } 438 | } 439 | 440 | /// Apply all extensions to the given header and the internal message buffer. 441 | async fn decode_with_extensions(&mut self, header: &mut Header, message: &mut Vec) -> Result<(), Error> { 442 | if !self.has_extensions { 443 | return Ok(()); 444 | } 445 | for e in self.extensions.lock().await.iter_mut() { 446 | log::trace!("{}: decoding with extension: {}", self.id, e.name()); 447 | e.decode(header, message).map_err(Error::Extension)? 448 | } 449 | Ok(()) 450 | } 451 | 452 | /// Flush the socket buffer. 
453 | async fn flush(&mut self) -> Result<(), Error> { 454 | log::trace!("{}: Receiver flushing connection", self.id); 455 | if self.is_closed { 456 | return Ok(()); 457 | } 458 | self.writer.lock().await.flush().await.or(Err(Error::Closed)) 459 | } 460 | } 461 | 462 | impl Sender { 463 | /// Send a text value over the websocket connection. 464 | pub async fn send_text(&mut self, data: impl AsRef) -> Result<(), Error> { 465 | let mut header = Header::new(OpCode::Text); 466 | self.send_frame(&mut header, &mut Storage::Shared(data.as_ref().as_bytes())).await 467 | } 468 | 469 | /// Send a text value over the websocket connection. 470 | /// 471 | /// This method performs one copy fewer than [`Sender::send_text`]. 472 | pub async fn send_text_owned(&mut self, data: String) -> Result<(), Error> { 473 | let mut header = Header::new(OpCode::Text); 474 | self.send_frame(&mut header, &mut Storage::Owned(data.into_bytes())).await 475 | } 476 | 477 | /// Send some binary data over the websocket connection. 478 | pub async fn send_binary(&mut self, data: impl AsRef<[u8]>) -> Result<(), Error> { 479 | let mut header = Header::new(OpCode::Binary); 480 | self.send_frame(&mut header, &mut Storage::Shared(data.as_ref())).await 481 | } 482 | 483 | /// Send some binary data over the websocket connection. 484 | /// 485 | /// This method performs one copy fewer than [`Sender::send_binary`]. 486 | /// The `data` buffer may be modified by this method, e.g. if masking is necessary. 487 | pub async fn send_binary_mut(&mut self, mut data: impl AsMut<[u8]>) -> Result<(), Error> { 488 | let mut header = Header::new(OpCode::Binary); 489 | self.send_frame(&mut header, &mut Storage::Unique(data.as_mut())).await 490 | } 491 | 492 | /// Ping the remote end. 
493 | pub async fn send_ping(&mut self, data: ByteSlice125<'_>) -> Result<(), Error> { 494 | let mut header = Header::new(OpCode::Ping); 495 | self.write(&mut header, &mut Storage::Shared(data.as_ref())).await 496 | } 497 | 498 | /// Send an unsolicited Pong to the remote. 499 | pub async fn send_pong(&mut self, data: ByteSlice125<'_>) -> Result<(), Error> { 500 | let mut header = Header::new(OpCode::Pong); 501 | self.write(&mut header, &mut Storage::Shared(data.as_ref())).await 502 | } 503 | 504 | /// Flush the socket buffer. 505 | pub async fn flush(&mut self) -> Result<(), Error> { 506 | log::trace!("{}: Sender flushing connection", self.id); 507 | self.writer.lock().await.flush().await.or(Err(Error::Closed)) 508 | } 509 | 510 | /// Send a close message and close the connection. 511 | pub async fn close(&mut self) -> Result<(), Error> { 512 | log::trace!("{}: closing connection", self.id); 513 | let mut header = Header::new(OpCode::Close); 514 | let code = 1000_u16.to_be_bytes(); // 1000 = normal closure 515 | self.write(&mut header, &mut Storage::Shared(&code[..])).await?; 516 | self.flush().await?; 517 | self.writer.lock().await.close().await.or(Err(Error::Closed)) 518 | } 519 | 520 | /// Send arbitrary websocket frames. 521 | /// 522 | /// Before sending, extensions will be applied to header and payload data. 523 | async fn send_frame(&mut self, header: &mut Header, data: &mut Storage<'_>) -> Result<(), Error> { 524 | if !self.has_extensions { 525 | return self.write(header, data).await; 526 | } 527 | 528 | for e in self.extensions.lock().await.iter_mut() { 529 | log::trace!("{}: encoding with extension: {}", self.id, e.name()); 530 | e.encode(header, data).map_err(Error::Extension)? 531 | } 532 | 533 | self.write(header, data).await 534 | } 535 | 536 | /// Write final header and payload data to socket. 537 | /// 538 | /// The data will be masked if necessary. 539 | /// No extensions will be applied to header and payload data. 
540 | async fn write(&mut self, header: &mut Header, data: &mut Storage<'_>) -> Result<(), Error> { 541 | write(self.id, self.mode, &mut self.codec, &mut self.writer, header, data, &mut self.mask_buffer).await 542 | } 543 | } 544 | 545 | /// Write header and payload data to socket. 546 | async fn write( 547 | id: Id, 548 | mode: Mode, 549 | codec: &mut base::Codec, 550 | writer: &mut BiLock>, 551 | header: &mut Header, 552 | data: &mut Storage<'_>, 553 | mask_buffer: &mut Vec, 554 | ) -> Result<(), Error> { 555 | if mode.is_client() { 556 | header.set_masked(true); 557 | header.set_mask(rand::random()); 558 | } 559 | header.set_payload_len(data.as_ref().len()); 560 | 561 | log::trace!("{}: send: {}", id, header); 562 | 563 | let header_bytes = codec.encode_header(&header); 564 | let mut w = writer.lock().await; 565 | w.write_all(&header_bytes).await.or(Err(Error::Closed))?; 566 | 567 | if !header.is_masked() { 568 | return w.write_all(data.as_ref()).await.or(Err(Error::Closed)); 569 | } 570 | 571 | match data { 572 | Storage::Shared(slice) => { 573 | mask_buffer.clear(); 574 | mask_buffer.extend_from_slice(slice); 575 | base::Codec::apply_mask(header, mask_buffer); 576 | w.write_all(mask_buffer).await.or(Err(Error::Closed)) 577 | } 578 | Storage::Unique(slice) => { 579 | base::Codec::apply_mask(header, slice); 580 | w.write_all(slice).await.or(Err(Error::Closed)) 581 | } 582 | Storage::Owned(ref mut bytes) => { 583 | base::Codec::apply_mask(header, bytes); 584 | w.write_all(bytes).await.or(Err(Error::Closed)) 585 | } 586 | } 587 | } 588 | 589 | /// Create a close frame based on the given data. The close frame is echoed back 590 | /// to the sender. 
591 | fn close_answer(data: &[u8]) -> Result<(Header, Option), Error> { 592 | let answer = Header::new(OpCode::Close); 593 | if data.len() < 2 { 594 | return Ok((answer, None)); 595 | } 596 | // Check that the reason string is properly encoded 597 | let descr = std::str::from_utf8(&data[2..])?.into(); 598 | let code = u16::from_be_bytes([data[0], data[1]]); 599 | let reason = CloseReason { code, descr: Some(descr) }; 600 | 601 | // Status codes are defined in 602 | // https://tools.ietf.org/html/rfc6455#section-7.4.1 and 603 | // https://mailarchive.ietf.org/arch/msg/hybi/P_1vbD9uyHl63nbIIbFxKMfSwcM/ 604 | match code { 605 | | 1000 ..= 1003 606 | | 1007 ..= 1011 607 | | 1012 // Service Restart 608 | | 1013 // Try Again Later 609 | | 1015 610 | | 3000 ..= 4999 => Ok((answer, Some(reason))), // acceptable codes 611 | _ => { 612 | // invalid code => protocol error (1002) 613 | Ok((answer, Some(CloseReason { code: 1002, descr: None}))) 614 | } 615 | } 616 | } 617 | 618 | /// Errors which may occur when sending or receiving messages. 619 | #[non_exhaustive] 620 | #[derive(Debug)] 621 | pub enum Error { 622 | /// An I/O error was encountered. 623 | Io(io::Error), 624 | /// The base codec errored. 625 | Codec(base::Error), 626 | /// An extension produced an error while encoding or decoding. 627 | Extension(crate::BoxedError), 628 | /// An unexpected opcode was encountered. 629 | UnexpectedOpCode(OpCode), 630 | /// A close reason was not correctly UTF-8 encoded. 631 | Utf8(str::Utf8Error), 632 | /// The total message payload data size exceeds the configured maximum. 633 | MessageTooLarge { current: usize, maximum: usize }, 634 | /// The connection is closed. 635 | Closed, 636 | } 637 | 638 | /// Reason for closing the connection. 
639 | #[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] 640 | pub struct CloseReason { 641 | pub code: u16, 642 | pub descr: Option, 643 | } 644 | 645 | impl fmt::Display for Error { 646 | fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { 647 | match self { 648 | Error::Io(e) => write!(f, "i/o error: {}", e), 649 | Error::Codec(e) => write!(f, "codec error: {}", e), 650 | Error::Extension(e) => write!(f, "extension error: {}", e), 651 | Error::UnexpectedOpCode(c) => write!(f, "unexpected opcode: {}", c), 652 | Error::Utf8(e) => write!(f, "utf-8 error: {}", e), 653 | Error::MessageTooLarge { current, maximum } => { 654 | write!(f, "message too large: len >= {}, maximum = {}", current, maximum) 655 | } 656 | Error::Closed => f.write_str("connection closed"), 657 | } 658 | } 659 | } 660 | 661 | impl std::error::Error for Error { 662 | fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { 663 | match self { 664 | Error::Io(e) => Some(e), 665 | Error::Codec(e) => Some(e), 666 | Error::Extension(e) => Some(&**e), 667 | Error::Utf8(e) => Some(e), 668 | Error::UnexpectedOpCode(_) | Error::MessageTooLarge { .. } | Error::Closed => None, 669 | } 670 | } 671 | } 672 | 673 | impl From for Error { 674 | fn from(e: io::Error) -> Self { 675 | if e.kind() == io::ErrorKind::UnexpectedEof { 676 | Error::Closed 677 | } else { 678 | Error::Io(e) 679 | } 680 | } 681 | } 682 | 683 | impl From for Error { 684 | fn from(e: str::Utf8Error) -> Self { 685 | Error::Utf8(e) 686 | } 687 | } 688 | 689 | impl From for Error { 690 | fn from(e: base::Error) -> Self { 691 | Error::Codec(e) 692 | } 693 | } 694 | 695 | /// Discard `n` bytes from the underlying reader. 
696 | async fn discard_bytes(n: u64, reader: R) -> Result { 697 | futures::io::copy(&mut reader.take(n), &mut futures::io::sink()).await 698 | } 699 | 700 | #[cfg(test)] 701 | mod tests { 702 | use super::discard_bytes; 703 | use futures::{io::Cursor, AsyncReadExt}; 704 | 705 | #[tokio::test] 706 | async fn discard_bytes_works() { 707 | let bytes: Vec = (0..5).collect(); 708 | let mut cursor = Cursor::new(bytes); 709 | discard_bytes(1_u64, &mut cursor).await.unwrap(); 710 | let mut read = vec![0; 4]; 711 | cursor.read_exact(&mut read).await.unwrap(); 712 | assert_eq!(read, vec![1, 2, 3, 4]); 713 | } 714 | } 715 | -------------------------------------------------------------------------------- /src/data.rs: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2019 Parity Technologies (UK) Ltd. 2 | // 3 | // Licensed under the Apache License, Version 2.0 4 | // or the MIT 5 | // license , at your 6 | // option. All files in the project carrying such notice may not be copied, 7 | // modified, or distributed except according to those terms. 8 | 9 | //! Types describing various forms of payload data. 10 | 11 | use std::fmt; 12 | 13 | use crate::connection::CloseReason; 14 | 15 | /// Data received from the remote end. 16 | #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] 17 | pub enum Incoming<'a> { 18 | /// Text or binary data. 19 | Data(Data), 20 | /// Data sent with a PONG control frame. 21 | Pong(&'a [u8]), 22 | /// The other end closed the connection. 23 | Closed(CloseReason), 24 | } 25 | 26 | impl Incoming<'_> { 27 | /// Is this text or binary data? 28 | pub fn is_data(&self) -> bool { 29 | if let Incoming::Data(_) = self { 30 | true 31 | } else { 32 | false 33 | } 34 | } 35 | 36 | /// Is this a PONG? 37 | pub fn is_pong(&self) -> bool { 38 | if let Incoming::Pong(_) = self { 39 | true 40 | } else { 41 | false 42 | } 43 | } 44 | 45 | /// Is this text data? 
46 | pub fn is_text(&self) -> bool { 47 | if let Incoming::Data(d) = self { 48 | d.is_text() 49 | } else { 50 | false 51 | } 52 | } 53 | 54 | /// Is this binary data? 55 | pub fn is_binary(&self) -> bool { 56 | if let Incoming::Data(d) = self { 57 | d.is_binary() 58 | } else { 59 | false 60 | } 61 | } 62 | } 63 | 64 | #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] 65 | pub enum Data { 66 | /// Textual data (number of bytes). 67 | Text(usize), 68 | /// Binary data (number of bytes). 69 | Binary(usize), 70 | } 71 | 72 | impl Data { 73 | /// Is this text data? 74 | pub fn is_text(&self) -> bool { 75 | if let Data::Text(_) = self { 76 | true 77 | } else { 78 | false 79 | } 80 | } 81 | 82 | /// Is this binary data? 83 | pub fn is_binary(&self) -> bool { 84 | if let Data::Binary(_) = self { 85 | true 86 | } else { 87 | false 88 | } 89 | } 90 | 91 | /// The length of data (number of bytes). 92 | pub fn len(&self) -> usize { 93 | match self { 94 | Data::Text(n) => *n, 95 | Data::Binary(n) => *n, 96 | } 97 | } 98 | } 99 | 100 | /// Wrapper type which restricts the length of its byte slice to 125 bytes. 101 | #[derive(Copy, Clone, Debug)] 102 | pub struct ByteSlice125<'a>(&'a [u8]); 103 | 104 | /// Error, if converting to [`ByteSlice125`] fails. 
105 | #[derive(Copy, Clone, Debug)] 106 | pub struct SliceTooLarge(()); 107 | 108 | impl fmt::Display for SliceTooLarge { 109 | fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { 110 | f.write_str("Slice larger than 125 bytes") 111 | } 112 | } 113 | 114 | impl std::error::Error for SliceTooLarge {} 115 | 116 | impl<'a> TryFrom<&'a [u8]> for ByteSlice125<'a> { 117 | type Error = SliceTooLarge; 118 | 119 | fn try_from(value: &'a [u8]) -> Result { 120 | if value.len() > 125 { 121 | Err(SliceTooLarge(())) 122 | } else { 123 | Ok(ByteSlice125(value)) 124 | } 125 | } 126 | } 127 | 128 | impl AsRef<[u8]> for ByteSlice125<'_> { 129 | fn as_ref(&self) -> &[u8] { 130 | self.0 131 | } 132 | } 133 | -------------------------------------------------------------------------------- /src/extension.rs: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2019 Parity Technologies (UK) Ltd. 2 | // Copyright (c) 2016 twist developers 3 | // 4 | // Licensed under the Apache License, Version 2.0 5 | // or the MIT 6 | // license , at your 7 | // option. All files in the project carrying such notice may not be copied, 8 | // modified, or distributed except according to those terms. 9 | 10 | //! Websocket extensions as per [RFC 6455][rfc6455]. 11 | //! 12 | //! [rfc6455]: https://tools.ietf.org/html/rfc6455#section-9 13 | 14 | #[cfg(feature = "deflate")] 15 | pub mod deflate; 16 | 17 | use crate::{base::Header, BoxedError, Storage}; 18 | use std::{borrow::Cow, fmt}; 19 | 20 | /// A websocket extension as per RFC 6455, section 9. 21 | /// 22 | /// Extensions are invoked during handshake and subsequently during base 23 | /// frame encoding and decoding. The invocation during handshake differs 24 | /// on client and server side. 25 | /// 26 | /// # Server 27 | /// 28 | /// 1. All extensions should consider themselves as disabled but available. 29 | /// 2. 
When receiving a handshake request from a client, for each extension 30 | /// with a matching name, [`Extension::configure`] will be applied to the 31 | /// request parameters. The extension may internally enable itself. 32 | /// 3. When sending back the response, for each extension whose 33 | /// [`Extension::is_enabled`] returns true, the extension name and its 34 | /// parameters (as returned by [`Extension::params`]) will be included in the 35 | /// response. 36 | /// 37 | /// # Client 38 | /// 39 | /// 1. All extensions should consider themselves as disabled but available. 40 | /// 2. When creating the handshake request, all extensions and its parameters 41 | /// (as returned by [`Extension::params`]) will be included in the request. 42 | /// 3. When receiving the response from the server, for every extension with 43 | /// a matching name in the response, [`Extension::configure`] will be applied 44 | /// to the response parameters. The extension may internally enable itself. 45 | /// 46 | /// After this handshake phase, extensions have been configured and are 47 | /// potentially enabled. Enabled extensions can then be used for further base 48 | /// frame processing. 49 | pub trait Extension: std::fmt::Debug { 50 | /// Is this extension enabled? 51 | fn is_enabled(&self) -> bool; 52 | 53 | /// The name of this extension. 54 | fn name(&self) -> &str; 55 | 56 | /// The parameters this extension wants to send for negotiation. 57 | fn params(&self) -> &[Param]; 58 | 59 | /// Configure this extension with the parameters received from negotiation. 60 | fn configure(&mut self, params: &[Param]) -> Result<(), BoxedError>; 61 | 62 | /// Encode a frame, given as frame header and payload data. 63 | fn encode(&mut self, header: &mut Header, data: &mut Storage) -> Result<(), BoxedError>; 64 | 65 | /// Decode a frame. 66 | /// 67 | /// The frame header is given, as well as the accumulated payload data, i.e. 68 | /// the concatenated payload data of all message fragments. 
69 | fn decode(&mut self, header: &mut Header, data: &mut Vec) -> Result<(), BoxedError>; 70 | 71 | /// The reserved bits this extension uses. 72 | fn reserved_bits(&self) -> (bool, bool, bool) { 73 | (false, false, false) 74 | } 75 | } 76 | 77 | impl Extension for Box { 78 | fn is_enabled(&self) -> bool { 79 | (**self).is_enabled() 80 | } 81 | 82 | fn name(&self) -> &str { 83 | (**self).name() 84 | } 85 | 86 | fn params(&self) -> &[Param] { 87 | (**self).params() 88 | } 89 | 90 | fn configure(&mut self, params: &[Param]) -> Result<(), BoxedError> { 91 | (**self).configure(params) 92 | } 93 | 94 | fn encode(&mut self, header: &mut Header, data: &mut Storage) -> Result<(), BoxedError> { 95 | (**self).encode(header, data) 96 | } 97 | 98 | fn decode(&mut self, header: &mut Header, data: &mut Vec) -> Result<(), BoxedError> { 99 | (**self).decode(header, data) 100 | } 101 | 102 | fn reserved_bits(&self) -> (bool, bool, bool) { 103 | (**self).reserved_bits() 104 | } 105 | } 106 | 107 | /// Extension parameter (used for negotiation). 108 | #[derive(Debug, Clone, PartialEq, Eq)] 109 | pub struct Param<'a> { 110 | name: Cow<'a, str>, 111 | value: Option>, 112 | } 113 | 114 | impl<'a> fmt::Display for Param<'a> { 115 | fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { 116 | if let Some(v) = &self.value { 117 | write!(f, "{} = {}", self.name, v) 118 | } else { 119 | write!(f, "{}", self.name) 120 | } 121 | } 122 | } 123 | 124 | impl<'a> Param<'a> { 125 | /// Create a new parameter with the given name. 126 | pub fn new(name: impl Into>) -> Self { 127 | Param { name: name.into(), value: None } 128 | } 129 | 130 | /// Access the parameter name. 131 | pub fn name(&self) -> &str { 132 | &self.name 133 | } 134 | 135 | /// Access the optional parameter value. 136 | pub fn value(&self) -> Option<&str> { 137 | self.value.as_ref().map(|v| v.as_ref()) 138 | } 139 | 140 | /// Set the parameter to the given value. 
141 | pub fn set_value(&mut self, value: Option>>) -> &mut Self { 142 | self.value = value.map(Into::into); 143 | self 144 | } 145 | 146 | /// Turn this parameter into one that owns its values. 147 | pub fn acquire(self) -> Param<'static> { 148 | Param { name: Cow::Owned(self.name.into_owned()), value: self.value.map(|v| Cow::Owned(v.into_owned())) } 149 | } 150 | } 151 | -------------------------------------------------------------------------------- /src/extension/deflate.rs: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2019 Parity Technologies (UK) Ltd. 2 | // 3 | // Licensed under the Apache License, Version 2.0 4 | // or the MIT 5 | // license , at your 6 | // option. All files in the project carrying such notice may not be copied, 7 | // modified, or distributed except according to those terms. 8 | 9 | //! Deflate compression extension mostly conformant with [RFC 7692][rfc7692]. 10 | //! 11 | //! [rfc7692]: https://tools.ietf.org/html/rfc7692 12 | 13 | use crate::{ 14 | as_u64, 15 | base::{Header, OpCode}, 16 | connection::Mode, 17 | extension::{Extension, Param}, 18 | BoxedError, Storage, 19 | }; 20 | use flate2::{write::DeflateDecoder, Compress, Compression, FlushCompress, Status}; 21 | use std::{ 22 | convert::TryInto, 23 | io::{self, Write}, 24 | mem, 25 | }; 26 | 27 | const SERVER_NO_CONTEXT_TAKEOVER: &str = "server_no_context_takeover"; 28 | const SERVER_MAX_WINDOW_BITS: &str = "server_max_window_bits"; 29 | 30 | const CLIENT_NO_CONTEXT_TAKEOVER: &str = "client_no_context_takeover"; 31 | const CLIENT_MAX_WINDOW_BITS: &str = "client_max_window_bits"; 32 | 33 | /// The deflate extension type. 34 | /// 35 | /// The extension does currently not support max. window bits other than the 36 | /// default, which is 15 and will ask for no context takeover during handshake. 
#[derive(Debug)]
pub struct Deflate {
    mode: Mode,
    enabled: bool,
    buffer: Vec<u8>,
    params: Vec<Param<'static>>,
    our_max_window_bits: u8,
    their_max_window_bits: u8,
    await_last_fragment: bool,
}

impl Deflate {
    /// Create a new deflate extension either on client or server side.
    ///
    /// The extension starts out disabled with both window sizes at 15 bits.
    /// In client mode the negotiation offer is pre-populated with
    /// `server_no_context_takeover`, `client_no_context_takeover` and
    /// `client_max_window_bits`; in server mode it starts out empty.
    pub fn new(mode: Mode) -> Self {
        let params = match mode {
            Mode::Server => Vec::new(),
            Mode::Client => vec![
                Param::new(SERVER_NO_CONTEXT_TAKEOVER),
                Param::new(CLIENT_NO_CONTEXT_TAKEOVER),
                Param::new(CLIENT_MAX_WINDOW_BITS),
            ],
        };
        Deflate {
            mode,
            enabled: false,
            buffer: Vec::new(),
            params,
            our_max_window_bits: 15,
            their_max_window_bits: 15,
            await_last_fragment: false,
        }
    }

    /// Set the server's max. window bits.
    ///
    /// The value must be within 9 ..= 15.
    /// The extension must be in client mode.
    ///
    /// By including this parameter, a client limits the LZ77 sliding window
    /// size that the server will use to compress messages. A server accepts
    /// by including the "server_max_window_bits" extension parameter in the
    /// response with the same or smaller value as the offer.
    ///
    /// Calling this method again replaces the previously offered value
    /// instead of adding a second "server_max_window_bits" parameter.
    ///
    /// # Panics
    ///
    /// Panics if the extension is not in client mode or if `max` is outside
    /// of 9 ..= 15.
    pub fn set_max_server_window_bits(&mut self, max: u8) {
        assert!(self.mode == Mode::Client, "setting max. server window bits requires client mode");
        assert!(max > 8 && max <= 15, "max. server window bits have to be within 9 ..= 15");
        self.their_max_window_bits = max; // upper bound of the server's window
        if let Some(p) = self.params.iter_mut().find(|p| p.name() == SERVER_MAX_WINDOW_BITS) {
            // Already offered once: update in place so the offer stays unique.
            p.set_value(Some(max.to_string()));
        } else {
            let mut p = Param::new(SERVER_MAX_WINDOW_BITS);
            p.set_value(Some(max.to_string()));
            self.params.push(p)
        }
    }

    /// Set the client's max. window bits.
    ///
    /// The value must be within 9 ..= 15.
    /// The extension must be in client mode.
94 | /// 95 | /// The parameter informs the server that even if it doesn't include the 96 | /// "client_max_window_bits" extension parameter in the response with a 97 | /// value greater than the one in the negotiation offer or if it doesn't 98 | /// include the extension parameter at all, the client is not going to 99 | /// use an LZ77 sliding window size greater than one given here. 100 | /// The server may also respond with a smaller value which allows the client 101 | /// to reduce its sliding window even more. 102 | pub fn set_max_client_window_bits(&mut self, max: u8) { 103 | assert!(self.mode == Mode::Client, "setting max. client window bits requires client mode"); 104 | assert!(max > 8 && max <= 15, "max. client window bits have to be within 9 ..= 15"); 105 | self.our_max_window_bits = max; // upper bound of the client's window 106 | if let Some(p) = self.params.iter_mut().find(|p| p.name() == CLIENT_MAX_WINDOW_BITS) { 107 | p.set_value(Some(max.to_string())); 108 | } else { 109 | let mut p = Param::new(CLIENT_MAX_WINDOW_BITS); 110 | p.set_value(Some(max.to_string())); 111 | self.params.push(p) 112 | } 113 | } 114 | 115 | fn set_their_max_window_bits(&mut self, p: &Param, expected: Option) -> Result<(), ()> { 116 | if let Some(Ok(v)) = p.value().map(|s| s.parse::()) { 117 | if v < 8 || v > 15 { 118 | log::debug!("invalid {}: {} (expected range: 8 ..= 15)", p.name(), v); 119 | return Err(()); 120 | } 121 | if let Some(x) = expected { 122 | if v > x { 123 | log::debug!("invalid {}: {} (expected: {} <= {})", p.name(), v, v, x); 124 | return Err(()); 125 | } 126 | } 127 | self.their_max_window_bits = std::cmp::max(9, v); 128 | } 129 | Ok(()) 130 | } 131 | } 132 | 133 | impl Extension for Deflate { 134 | fn name(&self) -> &str { 135 | "permessage-deflate" 136 | } 137 | 138 | fn is_enabled(&self) -> bool { 139 | self.enabled 140 | } 141 | 142 | fn params(&self) -> &[Param] { 143 | &self.params 144 | } 145 | 146 | fn configure(&mut self, params: &[Param]) -> 
Result<(), BoxedError> { 147 | match self.mode { 148 | Mode::Server => { 149 | self.params.clear(); 150 | for p in params { 151 | log::trace!("configure server with: {}", p); 152 | match p.name() { 153 | CLIENT_MAX_WINDOW_BITS => { 154 | if self.set_their_max_window_bits(&p, None).is_err() { 155 | // we just accept the client's offer as is => no need to reply 156 | return Ok(()); 157 | } 158 | } 159 | SERVER_MAX_WINDOW_BITS => { 160 | if let Some(Ok(v)) = p.value().map(|s| s.parse::()) { 161 | // The RFC allows 8 to 15 bits, but due to zlib limitations we 162 | // only support 9 to 15. 163 | if v < 9 || v > 15 { 164 | log::debug!("unacceptable server_max_window_bits: {}", v); 165 | return Ok(()); 166 | } 167 | let mut x = Param::new(SERVER_MAX_WINDOW_BITS); 168 | x.set_value(Some(v.to_string())); 169 | self.params.push(x); 170 | self.our_max_window_bits = v; 171 | } else { 172 | log::debug!("invalid server_max_window_bits: {:?}", p.value()); 173 | return Ok(()); 174 | } 175 | } 176 | CLIENT_NO_CONTEXT_TAKEOVER => self.params.push(Param::new(CLIENT_NO_CONTEXT_TAKEOVER)), 177 | SERVER_NO_CONTEXT_TAKEOVER => self.params.push(Param::new(SERVER_NO_CONTEXT_TAKEOVER)), 178 | _ => { 179 | log::debug!("{}: unknown parameter: {}", self.name(), p.name()); 180 | return Ok(()); 181 | } 182 | } 183 | } 184 | } 185 | Mode::Client => { 186 | let mut server_no_context_takeover = false; 187 | for p in params { 188 | log::trace!("configure client with: {}", p); 189 | match p.name() { 190 | SERVER_NO_CONTEXT_TAKEOVER => server_no_context_takeover = true, 191 | CLIENT_NO_CONTEXT_TAKEOVER => {} // must be supported 192 | SERVER_MAX_WINDOW_BITS => { 193 | let expected = Some(self.their_max_window_bits); 194 | if self.set_their_max_window_bits(&p, expected).is_err() { 195 | return Ok(()); 196 | } 197 | } 198 | CLIENT_MAX_WINDOW_BITS => { 199 | if let Some(Ok(v)) = p.value().map(|s| s.parse::()) { 200 | if v < 8 || v > 15 { 201 | log::debug!("unacceptable client_max_window_bits: {}", v); 
202 | return Ok(()); 203 | } 204 | use std::cmp::{max, min}; 205 | // Due to zlib limitations we have to use 9 as a lower bound 206 | // here, even if the server allowed us to go down to 8 bits. 207 | self.our_max_window_bits = min(self.our_max_window_bits, max(9, v)); 208 | } 209 | } 210 | _ => { 211 | log::debug!("{}: unknown parameter: {}", self.name(), p.name()); 212 | return Ok(()); 213 | } 214 | } 215 | } 216 | if !server_no_context_takeover { 217 | log::debug!("{}: server did not confirm no context takeover", self.name()); 218 | return Ok(()); 219 | } 220 | } 221 | } 222 | self.enabled = true; 223 | Ok(()) 224 | } 225 | 226 | fn reserved_bits(&self) -> (bool, bool, bool) { 227 | (true, false, false) 228 | } 229 | 230 | fn decode(&mut self, header: &mut Header, data: &mut Vec) -> Result<(), BoxedError> { 231 | if data.is_empty() { 232 | return Ok(()); 233 | } 234 | 235 | match header.opcode() { 236 | OpCode::Binary | OpCode::Text if header.is_rsv1() => { 237 | if !header.is_fin() { 238 | self.await_last_fragment = true; 239 | log::trace!("deflate: not decoding {}; awaiting last fragment", header); 240 | return Ok(()); 241 | } 242 | log::trace!("deflate: decoding {}", header) 243 | } 244 | OpCode::Continue if header.is_fin() && self.await_last_fragment => { 245 | self.await_last_fragment = false; 246 | log::trace!("deflate: decoding {}", header) 247 | } 248 | _ => { 249 | log::trace!("deflate: not decoding {}", header); 250 | return Ok(()); 251 | } 252 | } 253 | 254 | // Restore LEN and NLEN: 255 | data.extend_from_slice(&[0, 0, 0xFF, 0xFF]); // cf. 
RFC 7692, 7.2.2 256 | 257 | self.buffer.clear(); 258 | let mut decoder = DeflateDecoder::new(&mut self.buffer); 259 | decoder.write_all(&data)?; 260 | decoder.finish()?; 261 | mem::swap(data, &mut self.buffer); 262 | 263 | header.set_rsv1(false); 264 | header.set_payload_len(data.len()); 265 | 266 | Ok(()) 267 | } 268 | 269 | fn encode(&mut self, header: &mut Header, data: &mut Storage) -> Result<(), BoxedError> { 270 | if data.as_ref().is_empty() { 271 | return Ok(()); 272 | } 273 | 274 | if let OpCode::Binary | OpCode::Text = header.opcode() { 275 | log::trace!("deflate: encoding {}", header) 276 | } else { 277 | log::trace!("deflate: not encoding {}", header); 278 | return Ok(()); 279 | } 280 | 281 | self.buffer.clear(); 282 | self.buffer.reserve(data.as_ref().len()); 283 | 284 | let mut encoder = Compress::new_with_window_bits(Compression::fast(), false, self.our_max_window_bits); 285 | 286 | // Compress all input bytes. 287 | while encoder.total_in() < as_u64(data.as_ref().len()) { 288 | let i: usize = encoder.total_in().try_into()?; 289 | match encoder.compress_vec(&data.as_ref()[i..], &mut self.buffer, FlushCompress::None)? { 290 | Status::BufError => self.buffer.reserve(4096), 291 | Status::Ok => continue, 292 | Status::StreamEnd => break, 293 | } 294 | } 295 | 296 | // We need to append an empty deflate block if not there yet (RFC 7692, 7.2.1). 297 | while !self.buffer.ends_with(&[0, 0, 0xFF, 0xFF]) { 298 | self.buffer.reserve(5); // Make sure there is room for the trailing end bytes. 299 | match encoder.compress_vec(&[], &mut self.buffer, FlushCompress::Sync)? { 300 | Status::Ok => continue, 301 | Status::BufError => continue, // more capacity is reserved above 302 | Status::StreamEnd => break, 303 | } 304 | } 305 | 306 | // If we still have not seen the empty deflate block appended, something is wrong. 
307 | if !self.buffer.ends_with(&[0, 0, 0xFF, 0xFF]) { 308 | log::error!("missing 00 00 FF FF"); 309 | return Err(io::Error::new(io::ErrorKind::Other, "missing 00 00 FF FF").into()); 310 | } 311 | 312 | self.buffer.truncate(self.buffer.len() - 4); // Remove 00 00 FF FF; cf. RFC 7692, 7.2.1 313 | 314 | if let Storage::Owned(d) = data { 315 | mem::swap(d, &mut self.buffer) 316 | } else { 317 | *data = Storage::Owned(mem::take(&mut self.buffer)) 318 | } 319 | header.set_rsv1(true); 320 | header.set_payload_len(data.as_ref().len()); 321 | Ok(()) 322 | } 323 | } 324 | -------------------------------------------------------------------------------- /src/handshake.rs: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2019 Parity Technologies (UK) Ltd. 2 | // 3 | // Licensed under the Apache License, Version 2.0 4 | // or the MIT 5 | // license , at your 6 | // option. All files in the project carrying such notice may not be copied, 7 | // modified, or distributed except according to those terms. 8 | 9 | //! Websocket [handshake]s. 10 | //! 11 | //! [handshake]: https://tools.ietf.org/html/rfc6455#section-4 12 | 13 | pub mod client; 14 | #[cfg(feature = "http")] 15 | pub mod http; 16 | pub mod server; 17 | 18 | use crate::extension::{Extension, Param}; 19 | use base64::Engine; 20 | use bytes::BytesMut; 21 | use sha1::{Digest, Sha1}; 22 | use std::{fmt, io, str}; 23 | 24 | pub use client::{Client, ServerResponse}; 25 | pub use server::{ClientRequest, Server}; 26 | 27 | // Defined in RFC 6455 and used to generate the `Sec-WebSocket-Accept` header 28 | // in the server handshake response. 29 | const KEY: &[u8] = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; 30 | 31 | // How many HTTP headers do we support during parsing? 32 | const MAX_NUM_HEADERS: usize = 32; 33 | 34 | // Some HTTP headers we need to check during parsing. 
35 | const SEC_WEBSOCKET_EXTENSIONS: &str = "Sec-WebSocket-Extensions"; 36 | const SEC_WEBSOCKET_PROTOCOL: &str = "Sec-WebSocket-Protocol"; 37 | 38 | /// Check a set of headers contains a specific one. 39 | fn expect_ascii_header(headers: &[httparse::Header], name: &str, ours: &str) -> Result<(), Error> { 40 | enum State { 41 | Init, // Start state 42 | Name, // Header name found 43 | Match, // Header value matches 44 | } 45 | 46 | headers 47 | .iter() 48 | .filter(|h| h.name.eq_ignore_ascii_case(name)) 49 | .fold(Ok(State::Init), |result, header| { 50 | if let Ok(State::Match) = result { 51 | return result; 52 | } 53 | if str::from_utf8(header.value)?.split(',').any(|v| v.trim().eq_ignore_ascii_case(ours)) { 54 | return Ok(State::Match); 55 | } 56 | Ok(State::Name) 57 | }) 58 | .and_then(|state| match state { 59 | State::Init => Err(Error::HeaderNotFound(name.into())), 60 | State::Name => Err(Error::UnexpectedHeader(name.into())), 61 | State::Match => Ok(()), 62 | }) 63 | } 64 | 65 | /// Pick the first header with the given name and apply the given closure to it. 66 | fn with_first_header<'a, F, R>(headers: &[httparse::Header<'a>], name: &str, f: F) -> Result 67 | where 68 | F: Fn(&'a [u8]) -> Result, 69 | { 70 | if let Some(h) = headers.iter().find(|h| h.name.eq_ignore_ascii_case(name)) { 71 | f(h.value) 72 | } else { 73 | Err(Error::HeaderNotFound(name.into())) 74 | } 75 | } 76 | 77 | // Configure all extensions with parsed parameters. 
78 | fn configure_extensions(extensions: &mut [Box], line: &str) -> Result<(), Error> { 79 | for e in line.split(',') { 80 | let mut ext_parts = e.split(';'); 81 | if let Some(name) = ext_parts.next() { 82 | let name = name.trim(); 83 | if let Some(ext) = extensions.iter_mut().find(|x| x.name().eq_ignore_ascii_case(name)) { 84 | let mut params = Vec::new(); 85 | for p in ext_parts { 86 | let mut key_value = p.split('='); 87 | if let Some(key) = key_value.next().map(str::trim) { 88 | let val = key_value.next().map(|v| v.trim().trim_matches('"')); 89 | let mut p = Param::new(key); 90 | p.set_value(val); 91 | params.push(p) 92 | } 93 | } 94 | ext.configure(¶ms).map_err(Error::Extension)? 95 | } 96 | } 97 | } 98 | Ok(()) 99 | } 100 | 101 | // Write all extensions to the given buffer. 102 | fn append_extensions<'a, I>(extensions: I, bytes: &mut BytesMut) 103 | where 104 | I: IntoIterator>, 105 | { 106 | let mut iter = extensions.into_iter().peekable(); 107 | 108 | if iter.peek().is_some() { 109 | bytes.extend_from_slice(b"\r\nSec-WebSocket-Extensions: ") 110 | } 111 | 112 | append_extension_header_value(iter, bytes) 113 | } 114 | 115 | // Write the extension header value to the given buffer. 
116 | fn append_extension_header_value<'a, I>(mut extensions_iter: std::iter::Peekable, bytes: &mut BytesMut) 117 | where 118 | I: Iterator>, 119 | { 120 | while let Some(e) = extensions_iter.next() { 121 | bytes.extend_from_slice(e.name().as_bytes()); 122 | for p in e.params() { 123 | bytes.extend_from_slice(b"; "); 124 | bytes.extend_from_slice(p.name().as_bytes()); 125 | if let Some(v) = p.value() { 126 | bytes.extend_from_slice(b"="); 127 | bytes.extend_from_slice(v.as_bytes()) 128 | } 129 | } 130 | if extensions_iter.peek().is_some() { 131 | bytes.extend_from_slice(b", ") 132 | } 133 | } 134 | } 135 | 136 | // This function takes a 16 byte key (base64 encoded, and so 24 bytes of input) that is expected via 137 | // the `Sec-WebSocket-Key` header during a websocket handshake, and writes the response that's expected 138 | // to be handed back in the response header `Sec-WebSocket-Accept`. 139 | // 140 | // The response is a base64 encoding of a 160bit hash. base64 encoding uses 1 ascii character per 6 bits, 141 | // and 160 / 6 = 26.66 characters. The output is padded with '=' to the nearest 4 characters, so we need 28 142 | // bytes in total for all of the characters. 143 | // 144 | // See https://datatracker.ietf.org/doc/html/rfc6455#section-1.3 for more information on this. 145 | fn generate_accept_key<'k>(key_base64: &WebSocketKey) -> [u8; 28] { 146 | let mut digest = Sha1::new(); 147 | digest.update(key_base64); 148 | digest.update(KEY); 149 | let d = digest.finalize(); 150 | 151 | let mut output_buf = [0; 28]; 152 | let n = base64::engine::general_purpose::STANDARD 153 | .encode_slice(d, &mut output_buf) 154 | .expect("encoding to base64 is exactly 28 bytes; qed"); 155 | debug_assert_eq!(n, 28, "encoding to base64 should be exactly 28 bytes"); 156 | output_buf 157 | } 158 | 159 | /// Enumeration of possible handshake errors. 160 | #[non_exhaustive] 161 | #[derive(Debug)] 162 | pub enum Error { 163 | /// An I/O error has been encountered. 
164 | Io(io::Error), 165 | /// An HTTP version =/= 1.1 was encountered. 166 | UnsupportedHttpVersion, 167 | /// An incomplete HTTP request. 168 | IncompleteHttpRequest, 169 | /// The value of the `Sec-WebSocket-Key` header is of unexpected length. 170 | SecWebSocketKeyInvalidLength(usize), 171 | /// The handshake request was not a GET request. 172 | InvalidRequestMethod, 173 | /// An HTTP header has not been present. 174 | HeaderNotFound(String), 175 | /// An HTTP header value was not expected. 176 | UnexpectedHeader(String), 177 | /// The Sec-WebSocket-Accept header value did not match. 178 | InvalidSecWebSocketAccept, 179 | /// The server returned an extension we did not ask for. 180 | UnsolicitedExtension, 181 | /// The server returned a protocol we did not ask for. 182 | UnsolicitedProtocol, 183 | /// An extension produced an error while encoding or decoding. 184 | Extension(crate::BoxedError), 185 | /// The HTTP entity could not be parsed successfully. 186 | Http(crate::BoxedError), 187 | /// UTF-8 decoding failed. 
188 | Utf8(str::Utf8Error), 189 | } 190 | 191 | impl fmt::Display for Error { 192 | fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { 193 | match self { 194 | Error::Io(e) => write!(f, "i/o error: {}", e), 195 | Error::UnsupportedHttpVersion => f.write_str("http version was not 1.1"), 196 | Error::IncompleteHttpRequest => f.write_str("http request was incomplete"), 197 | Error::SecWebSocketKeyInvalidLength(len) => { 198 | write!(f, "Sec-WebSocket-Key header was {} bytes long, expected 24", len) 199 | } 200 | Error::InvalidRequestMethod => f.write_str("handshake was not a GET request"), 201 | Error::HeaderNotFound(name) => write!(f, "header {} not found", name), 202 | Error::UnexpectedHeader(name) => write!(f, "header {} had an unexpected value", name), 203 | Error::InvalidSecWebSocketAccept => f.write_str("websocket key mismatch"), 204 | Error::UnsolicitedExtension => f.write_str("unsolicited extension returned"), 205 | Error::UnsolicitedProtocol => f.write_str("unsolicited protocol returned"), 206 | Error::Extension(e) => write!(f, "extension error: {}", e), 207 | Error::Http(e) => write!(f, "http parser error: {}", e), 208 | Error::Utf8(e) => write!(f, "utf-8 decoding error: {}", e), 209 | } 210 | } 211 | } 212 | 213 | impl std::error::Error for Error { 214 | fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { 215 | match self { 216 | Error::Io(e) => Some(e), 217 | Error::Extension(e) => Some(&**e), 218 | Error::Http(e) => Some(&**e), 219 | Error::Utf8(e) => Some(e), 220 | Error::UnsupportedHttpVersion 221 | | Error::IncompleteHttpRequest 222 | | Error::SecWebSocketKeyInvalidLength(_) 223 | | Error::InvalidRequestMethod 224 | | Error::HeaderNotFound(_) 225 | | Error::UnexpectedHeader(_) 226 | | Error::InvalidSecWebSocketAccept 227 | | Error::UnsolicitedExtension 228 | | Error::UnsolicitedProtocol => None, 229 | } 230 | } 231 | } 232 | 233 | impl From for Error { 234 | fn from(e: io::Error) -> Self { 235 | Error::Io(e) 236 | } 237 | } 238 | 239 | 
impl From for Error { 240 | fn from(e: str::Utf8Error) -> Self { 241 | Error::Utf8(e) 242 | } 243 | } 244 | 245 | /// Owned value of the `Sec-WebSocket-Key` header. 246 | /// 247 | /// Per [RFC 6455](https://datatracker.ietf.org/doc/html/rfc6455#section-4.1): 248 | /// 249 | /// ```text 250 | /// (...) The value of this header field MUST be a 251 | /// nonce consisting of a randomly selected 16-byte value that has 252 | /// been base64-encoded (see Section 4 of [RFC4648]). (...) 253 | /// ``` 254 | /// 255 | /// Base64 encoding of the nonce produces 24 ASCII bytes, padding included. 256 | pub type WebSocketKey = [u8; 24]; 257 | 258 | #[cfg(test)] 259 | mod tests { 260 | use super::expect_ascii_header; 261 | 262 | #[test] 263 | fn header_match() { 264 | let headers = &[ 265 | httparse::Header { name: "foo", value: b"a,b,c,d" }, 266 | httparse::Header { name: "foo", value: b"x" }, 267 | httparse::Header { name: "foo", value: b"y, z, a" }, 268 | httparse::Header { name: "bar", value: b"xxx" }, 269 | httparse::Header { name: "bar", value: b"sdfsdf 423 42 424" }, 270 | httparse::Header { name: "baz", value: b"123" }, 271 | ]; 272 | 273 | assert!(expect_ascii_header(headers, "foo", "a").is_ok()); 274 | assert!(expect_ascii_header(headers, "foo", "b").is_ok()); 275 | assert!(expect_ascii_header(headers, "foo", "c").is_ok()); 276 | assert!(expect_ascii_header(headers, "foo", "d").is_ok()); 277 | assert!(expect_ascii_header(headers, "foo", "x").is_ok()); 278 | assert!(expect_ascii_header(headers, "foo", "y").is_ok()); 279 | assert!(expect_ascii_header(headers, "foo", "z").is_ok()); 280 | assert!(expect_ascii_header(headers, "foo", "a").is_ok()); 281 | assert!(expect_ascii_header(headers, "bar", "xxx").is_ok()); 282 | assert!(expect_ascii_header(headers, "bar", "sdfsdf 423 42 424").is_ok()); 283 | assert!(expect_ascii_header(headers, "baz", "123").is_ok()); 284 | assert!(expect_ascii_header(headers, "baz", "???").is_err()); 285 | assert!(expect_ascii_header(headers, "???", 
"x").is_err()); 286 | } 287 | } 288 | -------------------------------------------------------------------------------- /src/handshake/client.rs: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2019 Parity Technologies (UK) Ltd. 2 | // 3 | // Licensed under the Apache License, Version 2.0 4 | // or the MIT 5 | // license , at your 6 | // option. All files in the project carrying such notice may not be copied, 7 | // modified, or distributed except according to those terms. 8 | 9 | //! Websocket client [handshake]. 10 | //! 11 | //! [handshake]: https://tools.ietf.org/html/rfc6455#section-4 12 | 13 | use super::{ 14 | append_extensions, configure_extensions, expect_ascii_header, with_first_header, Error, WebSocketKey, KEY, 15 | MAX_NUM_HEADERS, SEC_WEBSOCKET_EXTENSIONS, SEC_WEBSOCKET_PROTOCOL, 16 | }; 17 | use crate::connection::{self, Mode}; 18 | use crate::{extension::Extension, Parsing}; 19 | use base64::Engine; 20 | use bytes::{Buf, BytesMut}; 21 | use futures::prelude::*; 22 | use sha1::{Digest, Sha1}; 23 | use std::{mem, str}; 24 | 25 | pub use httparse::Header; 26 | 27 | const BLOCK_SIZE: usize = 8 * 1024; 28 | 29 | /// Websocket client handshake. 30 | #[derive(Debug)] 31 | pub struct Client<'a, T> { 32 | /// The underlying async I/O resource. 33 | socket: T, 34 | /// The HTTP host to send the handshake to. 35 | host: &'a str, 36 | /// The HTTP host resource. 37 | resource: &'a str, 38 | /// The HTTP headers. 39 | headers: &'a [Header<'a>], 40 | /// A buffer holding the base-64 encoded request nonce. 41 | nonce: WebSocketKey, 42 | /// The protocols to include in the handshake. 43 | protocols: Vec<&'a str>, 44 | /// The extensions the client wishes to include in the request. 45 | extensions: Vec>, 46 | /// Encoding/decoding buffer. 47 | buffer: BytesMut, 48 | } 49 | 50 | impl<'a, T: AsyncRead + AsyncWrite + Unpin> Client<'a, T> { 51 | /// Create a new client handshake for some host and resource. 
52 | pub fn new(socket: T, host: &'a str, resource: &'a str) -> Self { 53 | Client { 54 | socket, 55 | host, 56 | resource, 57 | headers: &[], 58 | nonce: [0; 24], 59 | protocols: Vec::new(), 60 | extensions: Vec::new(), 61 | buffer: BytesMut::new(), 62 | } 63 | } 64 | 65 | /// Override the buffer to use for request/response handling. 66 | pub fn set_buffer(&mut self, b: BytesMut) -> &mut Self { 67 | self.buffer = b; 68 | self 69 | } 70 | 71 | /// Extract the buffer. 72 | pub fn take_buffer(&mut self) -> BytesMut { 73 | mem::take(&mut self.buffer) 74 | } 75 | 76 | /// Set connection headers to a slice. These headers are not checked for validity, 77 | /// the caller of this method is responsible for verification as well as avoiding 78 | /// conflicts with internally set headers. 79 | pub fn set_headers(&mut self, h: &'a [Header]) -> &mut Self { 80 | self.headers = h; 81 | self 82 | } 83 | 84 | /// Add a protocol to be included in the handshake. 85 | pub fn add_protocol(&mut self, p: &'a str) -> &mut Self { 86 | self.protocols.push(p); 87 | self 88 | } 89 | 90 | /// Add an extension to be included in the handshake. 91 | pub fn add_extension(&mut self, e: Box) -> &mut Self { 92 | self.extensions.push(e); 93 | self 94 | } 95 | 96 | /// Get back all extensions. 97 | pub fn drain_extensions(&mut self) -> impl Iterator> + '_ { 98 | self.extensions.drain(..) 99 | } 100 | 101 | /// Initiate client handshake request to server and get back the response. 102 | pub async fn handshake(&mut self) -> Result { 103 | self.buffer.clear(); 104 | self.encode_request(); 105 | self.socket.write_all(&self.buffer).await?; 106 | self.socket.flush().await?; 107 | self.buffer.clear(); 108 | 109 | loop { 110 | crate::read(&mut self.socket, &mut self.buffer, BLOCK_SIZE).await?; 111 | if let Parsing::Done { value, offset } = self.decode_response()? 
{ 112 | self.buffer.advance(offset); 113 | return Ok(value); 114 | } 115 | } 116 | } 117 | 118 | /// Turn this handshake into a [`connection::Builder`]. 119 | pub fn into_builder(mut self) -> connection::Builder { 120 | let mut builder = connection::Builder::new(self.socket, Mode::Client); 121 | builder.set_buffer(self.buffer); 122 | builder.add_extensions(self.extensions.drain(..)); 123 | builder 124 | } 125 | 126 | /// Get out the inner socket of the client. 127 | pub fn into_inner(self) -> T { 128 | self.socket 129 | } 130 | 131 | /// Encode the client handshake as a request, ready to be sent to the server. 132 | fn encode_request(&mut self) { 133 | let nonce: [u8; 16] = rand::random(); 134 | base64::engine::general_purpose::STANDARD 135 | .encode_slice(nonce, &mut self.nonce) 136 | .expect("encoding to base64 is exactly 16 bytes; qed"); 137 | self.buffer.extend_from_slice(b"GET "); 138 | self.buffer.extend_from_slice(self.resource.as_bytes()); 139 | self.buffer.extend_from_slice(b" HTTP/1.1"); 140 | self.buffer.extend_from_slice(b"\r\nHost: "); 141 | self.buffer.extend_from_slice(self.host.as_bytes()); 142 | self.buffer.extend_from_slice(b"\r\nUpgrade: websocket\r\nConnection: Upgrade"); 143 | self.buffer.extend_from_slice(b"\r\nSec-WebSocket-Key: "); 144 | self.buffer.extend_from_slice(&self.nonce); 145 | self.headers.iter().for_each(|h| { 146 | self.buffer.extend_from_slice(b"\r\n"); 147 | self.buffer.extend_from_slice(h.name.as_bytes()); 148 | self.buffer.extend_from_slice(b": "); 149 | self.buffer.extend_from_slice(h.value); 150 | }); 151 | if let Some((last, prefix)) = self.protocols.split_last() { 152 | self.buffer.extend_from_slice(b"\r\nSec-WebSocket-Protocol: "); 153 | for p in prefix { 154 | self.buffer.extend_from_slice(p.as_bytes()); 155 | self.buffer.extend_from_slice(b",") 156 | } 157 | self.buffer.extend_from_slice(last.as_bytes()) 158 | } 159 | append_extensions(&self.extensions, &mut self.buffer); 160 | 
self.buffer.extend_from_slice(b"\r\nSec-WebSocket-Version: 13\r\n\r\n") 161 | } 162 | 163 | /// Decode the server response to this client request. 164 | fn decode_response(&mut self) -> Result, Error> { 165 | let mut header_buf = [httparse::EMPTY_HEADER; MAX_NUM_HEADERS]; 166 | let mut response = httparse::Response::new(&mut header_buf); 167 | 168 | let offset = match response.parse(self.buffer.as_ref()) { 169 | Ok(httparse::Status::Complete(off)) => off, 170 | Ok(httparse::Status::Partial) => return Ok(Parsing::NeedMore(())), 171 | Err(e) => return Err(Error::Http(Box::new(e))), 172 | }; 173 | 174 | if response.version != Some(1) { 175 | return Err(Error::UnsupportedHttpVersion); 176 | } 177 | 178 | match response.code { 179 | Some(101) => (), 180 | Some(code @ (301..=303)) | Some(code @ 307) | Some(code @ 308) => { 181 | // redirect response 182 | let location = 183 | with_first_header(response.headers, "Location", |loc| Ok(String::from(std::str::from_utf8(loc)?)))?; 184 | let response = ServerResponse::Redirect { status_code: code, location }; 185 | return Ok(Parsing::Done { value: response, offset }); 186 | } 187 | other => { 188 | let response = ServerResponse::Rejected { status_code: other.unwrap_or(0) }; 189 | return Ok(Parsing::Done { value: response, offset }); 190 | } 191 | } 192 | 193 | expect_ascii_header(response.headers, "Upgrade", "websocket")?; 194 | expect_ascii_header(response.headers, "Connection", "upgrade")?; 195 | 196 | with_first_header(&response.headers, "Sec-WebSocket-Accept", |theirs| { 197 | let mut digest = Sha1::new(); 198 | digest.update(&self.nonce); 199 | digest.update(KEY); 200 | let ours = base64::engine::general_purpose::STANDARD.encode(digest.finalize()); 201 | if ours.as_bytes() != theirs { 202 | return Err(Error::InvalidSecWebSocketAccept); 203 | } 204 | Ok(()) 205 | })?; 206 | 207 | // Parse `Sec-WebSocket-Extensions` headers. 
208 | 209 | for h in response.headers.iter().filter(|h| h.name.eq_ignore_ascii_case(SEC_WEBSOCKET_EXTENSIONS)) { 210 | configure_extensions(&mut self.extensions, std::str::from_utf8(h.value)?)? 211 | } 212 | 213 | // Match `Sec-WebSocket-Protocol` header. 214 | 215 | let mut selected_proto = None; 216 | if let Some(tp) = response.headers.iter().find(|h| h.name.eq_ignore_ascii_case(SEC_WEBSOCKET_PROTOCOL)) { 217 | if let Some(&p) = self.protocols.iter().find(|x| x.as_bytes() == tp.value) { 218 | selected_proto = Some(String::from(p)) 219 | } else { 220 | return Err(Error::UnsolicitedProtocol); 221 | } 222 | } 223 | 224 | let response = ServerResponse::Accepted { protocol: selected_proto }; 225 | Ok(Parsing::Done { value: response, offset }) 226 | } 227 | } 228 | 229 | /// Handshake response received from the server. 230 | #[derive(Debug)] 231 | pub enum ServerResponse { 232 | /// The server has accepted our request. 233 | Accepted { 234 | /// The protocol (if any) the server has selected. 235 | protocol: Option, 236 | }, 237 | /// The server is redirecting us to some other location. 238 | Redirect { 239 | /// The HTTP response status code. 240 | status_code: u16, 241 | /// The location URL we should go to. 242 | location: String, 243 | }, 244 | /// The server rejected our request. 245 | Rejected { 246 | /// HTTP response status code. 247 | status_code: u16, 248 | }, 249 | } 250 | -------------------------------------------------------------------------------- /src/handshake/http.rs: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021 Parity Technologies (UK) Ltd. 2 | // 3 | // Licensed under the Apache License, Version 2.0 4 | // or the MIT 5 | // license , at your 6 | // option. All files in the project carrying such notice may not be copied, 7 | // modified, or distributed except according to those terms. 8 | 9 | /*! 
10 | This module somewhat mirrors [`crate::handshake::server`], except it's focus is on working 11 | with [`http::Request`] and [`http::Response`] types, making it easier to integrate with 12 | external web servers such as Hyper. 13 | 14 | See `examples/hyper_server.rs` from this crate's repository for example usage. 15 | */ 16 | 17 | use super::{WebSocketKey, SEC_WEBSOCKET_EXTENSIONS}; 18 | use crate::connection::{self, Mode}; 19 | use crate::extension::Extension; 20 | use crate::handshake; 21 | use bytes::BytesMut; 22 | use futures::prelude::*; 23 | use http::{header, HeaderMap, Response}; 24 | use std::mem; 25 | 26 | /// A re-export of [`handshake::Error`]. 27 | pub type Error = handshake::Error; 28 | 29 | /// Websocket handshake server. This is similar to [`handshake::Server`], but it is 30 | /// focused on performing the WebSocket handshake using a provided [`http::Request`], as opposed 31 | /// to decoding the request internally. 32 | pub struct Server { 33 | // Extensions the server supports. 34 | extensions: Vec>, 35 | // Encoding/decoding buffer. 36 | buffer: BytesMut, 37 | } 38 | 39 | impl Server { 40 | /// Create a new server handshake. 41 | pub fn new() -> Self { 42 | Server { extensions: Vec::new(), buffer: BytesMut::new() } 43 | } 44 | 45 | /// Override the buffer to use for request/response handling. 46 | pub fn set_buffer(&mut self, b: BytesMut) -> &mut Self { 47 | self.buffer = b; 48 | self 49 | } 50 | 51 | /// Extract the buffer. 52 | pub fn take_buffer(&mut self) -> BytesMut { 53 | mem::take(&mut self.buffer) 54 | } 55 | 56 | /// Add an extension the server supports. 57 | pub fn add_extension(&mut self, e: Box) -> &mut Self { 58 | self.extensions.push(e); 59 | self 60 | } 61 | 62 | /// Get back all extensions. 63 | pub fn drain_extensions(&mut self) -> impl Iterator> + '_ { 64 | self.extensions.drain(..) 65 | } 66 | 67 | /// Attempt to interpret the provided [`http::Request`] as a WebSocket Upgrade request. 
If successful, this 68 | /// returns an [`http::Response`] that should be returned to the client to complete the handshake. 69 | pub fn receive_request(&mut self, req: &http::Request) -> Result, Error> { 70 | if !is_upgrade_request(&req) { 71 | return Err(Error::InvalidSecWebSocketAccept); 72 | } 73 | 74 | let key = match req.headers().get("Sec-WebSocket-Key") { 75 | Some(key) => key, 76 | None => { 77 | return Err(Error::HeaderNotFound("Sec-WebSocket-Key".into()).into()); 78 | } 79 | }; 80 | 81 | if req.headers().get("Sec-WebSocket-Version").map(|v| v.as_bytes()) != Some(b"13") { 82 | return Err(Error::HeaderNotFound("Sec-WebSocket-Version".into()).into()); 83 | } 84 | 85 | // Pull out the Sec-WebSocket-Key and generate the appropriate response to it. 86 | let key: &WebSocketKey = match key.as_bytes().try_into() { 87 | Ok(key) => key, 88 | Err(_) => return Err(Error::InvalidSecWebSocketAccept), 89 | }; 90 | let accept_key = handshake::generate_accept_key(key); 91 | 92 | // Get extension information out of the request as we'll need this as well. 93 | let extension_config = req 94 | .headers() 95 | .iter() 96 | .filter(|&(name, _)| name.as_str().eq_ignore_ascii_case(SEC_WEBSOCKET_EXTENSIONS)) 97 | .map(|(_, value)| Ok(std::str::from_utf8(value.as_bytes())?.to_string())) 98 | .collect::, Error>>()?; 99 | 100 | // Attempt to set the extension configuration params that the client requested. 101 | for config_str in &extension_config { 102 | handshake::configure_extensions(&mut self.extensions, &config_str)?; 103 | } 104 | 105 | // Build a response that should be sent back to the client to acknowledge the upgrade. 106 | let mut response = Response::builder() 107 | .status(http::StatusCode::SWITCHING_PROTOCOLS) 108 | .header(http::header::CONNECTION, "upgrade") 109 | .header(http::header::UPGRADE, "websocket") 110 | .header("Sec-WebSocket-Accept", &accept_key[..]); 111 | 112 | // Tell the client about the agreed-upon extension configuration. 
We reuse code to build up the 113 | // extension header value, but that does make this a little more clunky. 114 | if !self.extensions.is_empty() { 115 | let mut buf = bytes::BytesMut::new(); 116 | let enabled_extensions = self.extensions.iter().filter(|e| e.is_enabled()).peekable(); 117 | handshake::append_extension_header_value(enabled_extensions, &mut buf); 118 | response = response.header("Sec-WebSocket-Extensions", buf.as_ref()); 119 | } 120 | 121 | let response = response.body(()).expect("bug: failed to build response"); 122 | Ok(response) 123 | } 124 | 125 | /// Turn this handshake into a [`connection::Builder`]. 126 | pub fn into_builder(mut self, socket: T) -> connection::Builder { 127 | let mut builder = connection::Builder::new(socket, Mode::Server); 128 | builder.set_buffer(self.buffer); 129 | builder.add_extensions(self.extensions.drain(..)); 130 | builder 131 | } 132 | } 133 | 134 | /// Check if an [`http::Request`] looks like a valid websocket upgrade request. 135 | pub fn is_upgrade_request(request: &http::Request) -> bool { 136 | header_contains_value(request.headers(), header::CONNECTION, b"upgrade") 137 | && header_contains_value(request.headers(), header::UPGRADE, b"websocket") 138 | } 139 | 140 | // Check if there is a header of the given name containing the wanted value. 
141 | fn header_contains_value(headers: &HeaderMap, header: header::HeaderName, value: &[u8]) -> bool { 142 | pub fn trim(x: &[u8]) -> &[u8] { 143 | let from = match x.iter().position(|x| !x.is_ascii_whitespace()) { 144 | Some(i) => i, 145 | None => return &[], 146 | }; 147 | let to = x.iter().rposition(|x| !x.is_ascii_whitespace()).unwrap(); 148 | &x[from..=to] 149 | } 150 | 151 | for header in headers.get_all(header) { 152 | if header.as_bytes().split(|&c| c == b',').any(|x| trim(x).eq_ignore_ascii_case(value)) { 153 | return true; 154 | } 155 | } 156 | false 157 | } 158 | -------------------------------------------------------------------------------- /src/handshake/server.rs: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2019 Parity Technologies (UK) Ltd. 2 | // 3 | // Licensed under the Apache License, Version 2.0 4 | // or the MIT 5 | // license , at your 6 | // option. All files in the project carrying such notice may not be copied, 7 | // modified, or distributed except according to those terms. 8 | 9 | //! Websocket server [handshake]. 10 | //! 11 | //! [handshake]: https://tools.ietf.org/html/rfc6455#section-4 12 | 13 | use super::{ 14 | append_extensions, configure_extensions, expect_ascii_header, with_first_header, Error, WebSocketKey, 15 | MAX_NUM_HEADERS, SEC_WEBSOCKET_EXTENSIONS, SEC_WEBSOCKET_PROTOCOL, 16 | }; 17 | use crate::connection::{self, Mode}; 18 | use crate::extension::Extension; 19 | use bytes::BytesMut; 20 | use futures::prelude::*; 21 | use std::{mem, str}; 22 | 23 | // Most HTTP servers default to 8KB limit on headers 24 | const MAX_HEADERS_SIZE: usize = 8 * 1024; 25 | const BLOCK_SIZE: usize = 8 * 1024; 26 | 27 | /// Websocket handshake server. 28 | #[derive(Debug)] 29 | pub struct Server<'a, T> { 30 | socket: T, 31 | /// Protocols the server supports. 32 | protocols: Vec<&'a str>, 33 | /// Extensions the server supports. 34 | extensions: Vec>, 35 | /// Encoding/decoding buffer. 
36 | buffer: BytesMut, 37 | } 38 | 39 | impl<'a, T: AsyncRead + AsyncWrite + Unpin> Server<'a, T> { 40 | /// Create a new server handshake. 41 | pub fn new(socket: T) -> Self { 42 | Server { socket, protocols: Vec::new(), extensions: Vec::new(), buffer: BytesMut::new() } 43 | } 44 | 45 | /// Override the buffer to use for request/response handling. 46 | pub fn set_buffer(&mut self, b: BytesMut) -> &mut Self { 47 | self.buffer = b; 48 | self 49 | } 50 | 51 | /// Extract the buffer. 52 | pub fn take_buffer(&mut self) -> BytesMut { 53 | mem::take(&mut self.buffer) 54 | } 55 | 56 | /// Add a protocol the server supports. 57 | pub fn add_protocol(&mut self, p: &'a str) -> &mut Self { 58 | self.protocols.push(p); 59 | self 60 | } 61 | 62 | /// Add an extension the server supports. 63 | pub fn add_extension(&mut self, e: Box) -> &mut Self { 64 | self.extensions.push(e); 65 | self 66 | } 67 | 68 | /// Get back all extensions. 69 | pub fn drain_extensions(&mut self) -> impl Iterator> + '_ { 70 | self.extensions.drain(..) 71 | } 72 | 73 | /// Await an incoming client handshake request. 74 | pub async fn receive_request(&mut self) -> Result, Error> { 75 | self.buffer.clear(); 76 | 77 | let mut skip = 0; 78 | 79 | loop { 80 | crate::read(&mut self.socket, &mut self.buffer, BLOCK_SIZE).await?; 81 | 82 | let limit = std::cmp::min(self.buffer.len(), MAX_HEADERS_SIZE); 83 | 84 | // We don't expect body, so can search for the CRLF headers tail from 85 | // the end of the buffer. 86 | if self.buffer[skip..limit].windows(4).rev().any(|w| w == b"\r\n\r\n") { 87 | break; 88 | } 89 | 90 | // Give up if we've reached the limit. We could emit a specific error here, 91 | // but httparse will produce meaningful error for us regardless. 92 | if limit == MAX_HEADERS_SIZE { 93 | break; 94 | } 95 | 96 | // Skip bytes that did not contain CRLF in the next iteration. 
97 | // If we only read a partial CRLF sequence, we would miss it if we skipped the full buffer 98 | // length, hence backing off the full 4 bytes. 99 | skip = self.buffer.len().saturating_sub(4); 100 | } 101 | 102 | self.decode_request() 103 | } 104 | 105 | /// Respond to the client. 106 | pub async fn send_response(&mut self, r: &Response<'_>) -> Result<(), Error> { 107 | self.buffer.clear(); 108 | self.encode_response(r); 109 | self.socket.write_all(&self.buffer).await?; 110 | self.socket.flush().await?; 111 | self.buffer.clear(); 112 | Ok(()) 113 | } 114 | 115 | /// Turn this handshake into a [`connection::Builder`]. 116 | pub fn into_builder(mut self) -> connection::Builder { 117 | let mut builder = connection::Builder::new(self.socket, Mode::Server); 118 | builder.set_buffer(self.buffer); 119 | builder.add_extensions(self.extensions.drain(..)); 120 | builder 121 | } 122 | 123 | /// Get out the inner socket of the server. 124 | pub fn into_inner(self) -> T { 125 | self.socket 126 | } 127 | 128 | // Decode client handshake request. 
129 | fn decode_request(&mut self) -> Result { 130 | let mut header_buf = [httparse::EMPTY_HEADER; MAX_NUM_HEADERS]; 131 | let mut request = httparse::Request::new(&mut header_buf); 132 | 133 | match request.parse(self.buffer.as_ref()) { 134 | Ok(httparse::Status::Complete(_)) => (), 135 | Ok(httparse::Status::Partial) => return Err(Error::IncompleteHttpRequest), 136 | Err(e) => return Err(Error::Http(Box::new(e))), 137 | }; 138 | if request.method != Some("GET") { 139 | return Err(Error::InvalidRequestMethod); 140 | } 141 | if request.version != Some(1) { 142 | return Err(Error::UnsupportedHttpVersion); 143 | } 144 | 145 | let host = with_first_header(&request.headers, "Host", Ok)?; 146 | 147 | expect_ascii_header(request.headers, "Upgrade", "websocket")?; 148 | expect_ascii_header(request.headers, "Connection", "upgrade")?; 149 | expect_ascii_header(request.headers, "Sec-WebSocket-Version", "13")?; 150 | 151 | let origin = 152 | request.headers.iter().find_map( 153 | |h| { 154 | if h.name.eq_ignore_ascii_case("Origin") { 155 | Some(h.value) 156 | } else { 157 | None 158 | } 159 | }, 160 | ); 161 | let headers = RequestHeaders { host, origin }; 162 | 163 | let ws_key = with_first_header(&request.headers, "Sec-WebSocket-Key", |k| { 164 | WebSocketKey::try_from(k).map_err(|_| Error::SecWebSocketKeyInvalidLength(k.len())) 165 | })?; 166 | 167 | for h in request.headers.iter().filter(|h| h.name.eq_ignore_ascii_case(SEC_WEBSOCKET_EXTENSIONS)) { 168 | configure_extensions(&mut self.extensions, std::str::from_utf8(h.value)?)? 169 | } 170 | 171 | let mut protocols = Vec::new(); 172 | for p in request.headers.iter().filter(|h| h.name.eq_ignore_ascii_case(SEC_WEBSOCKET_PROTOCOL)) { 173 | if let Some(&p) = self.protocols.iter().find(|x| x.as_bytes() == p.value) { 174 | protocols.push(p) 175 | } 176 | } 177 | 178 | let path = request.path.unwrap_or("/"); 179 | 180 | Ok(ClientRequest { ws_key, protocols, path, headers }) 181 | } 182 | 183 | // Encode server handshake response. 
184 | fn encode_response(&mut self, response: &Response<'_>) { 185 | match response { 186 | Response::Accept { key, protocol } => { 187 | let accept_value = super::generate_accept_key(&key); 188 | self.buffer.extend_from_slice( 189 | concat![ 190 | "HTTP/1.1 101 Switching Protocols", 191 | "\r\nServer: soketto-", 192 | env!("CARGO_PKG_VERSION"), 193 | "\r\nUpgrade: websocket", 194 | "\r\nConnection: upgrade", 195 | "\r\nSec-WebSocket-Accept: ", 196 | ] 197 | .as_bytes(), 198 | ); 199 | self.buffer.extend_from_slice(&accept_value); 200 | if let Some(p) = protocol { 201 | self.buffer.extend_from_slice(b"\r\nSec-WebSocket-Protocol: "); 202 | self.buffer.extend_from_slice(p.as_bytes()) 203 | } 204 | append_extensions(self.extensions.iter().filter(|e| e.is_enabled()), &mut self.buffer); 205 | self.buffer.extend_from_slice(b"\r\n\r\n") 206 | } 207 | Response::Reject { status_code } => { 208 | self.buffer.extend_from_slice(b"HTTP/1.1 "); 209 | let (_, reason) = if let Ok(i) = STATUSCODES.binary_search_by_key(status_code, |(n, _)| *n) { 210 | STATUSCODES[i] 211 | } else { 212 | (500, "500 Internal Server Error") 213 | }; 214 | self.buffer.extend_from_slice(reason.as_bytes()); 215 | self.buffer.extend_from_slice(b"\r\n\r\n") 216 | } 217 | } 218 | } 219 | } 220 | 221 | /// Handshake request received from the client. 222 | #[derive(Debug)] 223 | pub struct ClientRequest<'a> { 224 | ws_key: WebSocketKey, 225 | protocols: Vec<&'a str>, 226 | path: &'a str, 227 | headers: RequestHeaders<'a>, 228 | } 229 | 230 | /// Select HTTP headers sent by the client. 231 | #[derive(Debug, Copy, Clone)] 232 | pub struct RequestHeaders<'a> { 233 | /// The [`Host`](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Host) header. 234 | pub host: &'a [u8], 235 | /// The [`Origin`](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Origin) header, if provided. 
236 | pub origin: Option<&'a [u8]>, 237 | } 238 | 239 | impl<'a> ClientRequest<'a> { 240 | /// The `Sec-WebSocket-Key` header nonce value. 241 | pub fn key(&self) -> WebSocketKey { 242 | self.ws_key 243 | } 244 | 245 | /// The protocols the client is proposing. 246 | pub fn protocols(&self) -> impl Iterator { 247 | self.protocols.iter().cloned() 248 | } 249 | 250 | /// The path the client is requesting. 251 | pub fn path(&self) -> &str { 252 | self.path 253 | } 254 | 255 | /// Select HTTP headers sent by the client. 256 | pub fn headers(&self) -> RequestHeaders { 257 | self.headers 258 | } 259 | } 260 | 261 | /// Handshake response the server sends back to the client. 262 | #[derive(Debug)] 263 | pub enum Response<'a> { 264 | /// The server accepts the handshake request. 265 | Accept { key: WebSocketKey, protocol: Option<&'a str> }, 266 | /// The server rejects the handshake request. 267 | Reject { status_code: u16 }, 268 | } 269 | 270 | /// Known status codes and their reason phrases. 271 | const STATUSCODES: &[(u16, &str)] = &[ 272 | (100, "100 Continue"), 273 | (101, "101 Switching Protocols"), 274 | (102, "102 Processing"), 275 | (200, "200 OK"), 276 | (201, "201 Created"), 277 | (202, "202 Accepted"), 278 | (203, "203 Non Authoritative Information"), 279 | (204, "204 No Content"), 280 | (205, "205 Reset Content"), 281 | (206, "206 Partial Content"), 282 | (207, "207 Multi-Status"), 283 | (208, "208 Already Reported"), 284 | (226, "226 IM Used"), 285 | (300, "300 Multiple Choices"), 286 | (301, "301 Moved Permanently"), 287 | (302, "302 Found"), 288 | (303, "303 See Other"), 289 | (304, "304 Not Modified"), 290 | (305, "305 Use Proxy"), 291 | (307, "307 Temporary Redirect"), 292 | (308, "308 Permanent Redirect"), 293 | (400, "400 Bad Request"), 294 | (401, "401 Unauthorized"), 295 | (402, "402 Payment Required"), 296 | (403, "403 Forbidden"), 297 | (404, "404 Not Found"), 298 | (405, "405 Method Not Allowed"), 299 | (406, "406 Not Acceptable"), 300 | (407, "407 
Proxy Authentication Required"), 301 | (408, "408 Request Timeout"), 302 | (409, "409 Conflict"), 303 | (410, "410 Gone"), 304 | (411, "411 Length Required"), 305 | (412, "412 Precondition Failed"), 306 | (413, "413 Payload Too Large"), 307 | (414, "414 URI Too Long"), 308 | (415, "415 Unsupported Media Type"), 309 | (416, "416 Range Not Satisfiable"), 310 | (417, "417 Expectation Failed"), 311 | (418, "418 I'm a teapot"), 312 | (421, "421 Misdirected Request"), 313 | (422, "422 Unprocessable Entity"), 314 | (423, "423 Locked"), 315 | (424, "424 Failed Dependency"), 316 | (426, "426 Upgrade Required"), 317 | (428, "428 Precondition Required"), 318 | (429, "429 Too Many Requests"), 319 | (431, "431 Request Header Fields Too Large"), 320 | (451, "451 Unavailable For Legal Reasons"), 321 | (500, "500 Internal Server Error"), 322 | (501, "501 Not Implemented"), 323 | (502, "502 Bad Gateway"), 324 | (503, "503 Service Unavailable"), 325 | (504, "504 Gateway Timeout"), 326 | (505, "505 HTTP Version Not Supported"), 327 | (506, "506 Variant Also Negotiates"), 328 | (507, "507 Insufficient Storage"), 329 | (508, "508 Loop Detected"), 330 | (510, "510 Not Extended"), 331 | (511, "511 Network Authentication Required"), 332 | ]; 333 | -------------------------------------------------------------------------------- /src/lib.rs: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2019 Parity Technologies (UK) Ltd. 2 | // Copyright (c) 2016 twist developers 3 | // 4 | // Licensed under the Apache License, Version 2.0 5 | // or the MIT 6 | // license , at your 7 | // option. All files in the project carrying such notice may not be copied, 8 | // modified, or distributed except according to those terms. 9 | 10 | //! An implementation of the [RFC 6455][rfc6455] websocket protocol. 11 | //! 12 | //! To begin a websocket connection one first needs to perform a [handshake], 13 | //! 
either as [client] or [server], in order to upgrade from HTTP.
//! Once successful, the client or server can transition to a connection,
//! i.e. a [Sender]/[Receiver] pair and send and receive textual or
//! binary data.
//!
//! **Note**: While it is possible to only receive websocket messages it is
//! not possible to only send websocket messages. Receiving data is required
//! in order to react to control frames such as PING or CLOSE. While those will be
//! answered transparently they have to be received in the first place, so
//! calling [`connection::Receiver::receive`] is imperative.
//!
//! **Note**: None of the `async` methods are safe to cancel so their `Future`s
//! must not be dropped unless they return `Poll::Ready`.
//!
//! # Client example
//!
//! ```no_run
//! # use tokio_util::compat::TokioAsyncReadCompatExt;
//! # async fn doc() -> Result<(), soketto::BoxedError> {
//! use soketto::handshake::{Client, ServerResponse};
//!
//! // First, we need to establish a TCP connection.
//! let socket = tokio::net::TcpStream::connect("...").await?;
//!
//! // Then we configure the client handshake.
//! let mut client = Client::new(socket.compat(), "...", "/");
//!
//! // And finally we perform the handshake and handle the result.
//! let (mut sender, mut receiver) = match client.handshake().await? {
//!     ServerResponse::Accepted { .. } => client.into_builder().finish(),
//!     ServerResponse::Redirect { status_code, location } => unimplemented!("follow location URL"),
//!     ServerResponse::Rejected { status_code } => unimplemented!("handle failure")
//! };
//!
//! // Over the established websocket connection we can send
//! sender.send_text("some text").await?;
//! sender.send_text("some more text").await?;
//! sender.flush().await?;
//!
//! // ... and receive data.
//! let mut data = Vec::new();
//! receiver.receive_data(&mut data).await?;
//!
//! # Ok(())
//! # }
//!
//! ```
//!
//! # Server example
//!
//! ```no_run
//! # use tokio_util::compat::TokioAsyncReadCompatExt;
//! # use tokio_stream::{wrappers::TcpListenerStream, StreamExt};
//! # async fn doc() -> Result<(), soketto::BoxedError> {
//! use soketto::{handshake::{Server, ClientRequest, server::Response}};
//!
//! // First, we listen for incoming connections.
//! let listener = tokio::net::TcpListener::bind("...").await?;
//! let mut incoming = TcpListenerStream::new(listener);
//!
//! while let Some(socket) = incoming.next().await {
//!     // For each incoming connection we perform a handshake.
//!     let mut server = Server::new(socket?.compat());
//!
//!     let websocket_key = {
//!         let req = server.receive_request().await?;
//!         req.key()
//!     };
//!
//!     // Here we accept the client unconditionally.
//!     let accept = Response::Accept { key: websocket_key, protocol: None };
//!     server.send_response(&accept).await?;
//!
//!     // And we can finally transition to a websocket connection.
//!     let (mut sender, mut receiver) = server.into_builder().finish();
//!
//!     let mut data = Vec::new();
//!     let data_type = receiver.receive_data(&mut data).await?;
//!
//!     if data_type.is_text() {
//!         sender.send_text(std::str::from_utf8(&data)?).await?
//!     } else {
//!         sender.send_binary(&data).await?
//!     }
//!
//!     sender.close().await?
//! }
//!
//! # Ok(())
//! # }
//!
//! ```
//!
//! See `examples/hyper_server.rs` from this crate's repository for an example of
//! starting up a WebSocket server alongside a Hyper HTTP server.
//!
//! [client]: handshake::Client
//! [server]: handshake::Server
//! [Sender]: connection::Sender
//! [Receiver]: connection::Receiver
//! [rfc6455]: https://tools.ietf.org/html/rfc6455
//! [handshake]: https://tools.ietf.org/html/rfc6455#section-4

#![forbid(unsafe_code)]

pub mod base;
pub mod connection;
pub mod data;
pub mod extension;
pub mod handshake;

use bytes::BytesMut;
use futures::io::{AsyncRead, AsyncReadExt};
use std::io;

pub use connection::{Mode, Receiver, Sender};
pub use data::{Data, Incoming};

/// A boxed, type-erased error.
pub type BoxedError = Box<dyn std::error::Error + Send + Sync>;

/// A parsing result.
#[derive(Debug, Clone)]
pub enum Parsing<T, N = ()> {
    /// Parsing completed.
    Done {
        /// The parsed value.
        value: T,
        /// The offset into the byte slice that has been consumed.
        offset: usize,
    },
    /// Parsing is incomplete and needs more data.
    NeedMore(N),
}

/// A buffer type used for implementing `Extension`s.
#[derive(Debug)]
pub enum Storage<'a> {
    /// A read-only shared byte slice.
    Shared(&'a [u8]),
    /// A mutable byte slice.
    Unique(&'a mut [u8]),
    /// An owned byte buffer.
    Owned(Vec<u8>),
}

impl AsRef<[u8]> for Storage<'_> {
    /// View the stored bytes as a slice, regardless of how they are held.
    fn as_ref(&self) -> &[u8] {
        match self {
            Storage::Shared(d) => d,
            Storage::Unique(d) => d,
            Storage::Owned(b) => b.as_ref(),
        }
    }
}

/// Helper function to allow casts from `usize` to `u64` only on platforms
/// where the sizes are guaranteed to fit.
#[cfg(any(target_pointer_width = "32", target_pointer_width = "64"))]
const fn as_u64(a: usize) -> u64 {
    a as u64
}

/// Fill the buffer from the given `AsyncRead` impl with up to `max` bytes.
176 | async fn read(reader: &mut R, dest: &mut BytesMut, max: usize) -> io::Result<()> 177 | where 178 | R: AsyncRead + Unpin, 179 | { 180 | let i = dest.len(); 181 | dest.resize(i + max, 0u8); 182 | let n = reader.read(&mut dest[i..]).await?; 183 | dest.truncate(i + n); 184 | if n == 0 { 185 | return Err(io::ErrorKind::UnexpectedEof.into()); 186 | } 187 | log::trace!("read {} bytes", n); 188 | Ok(()) 189 | } 190 | --------------------------------------------------------------------------------