├── .github ├── stale.yml └── workflows │ └── ci.yml ├── .gitignore ├── Cargo.toml ├── LICENSE ├── README.md ├── examples ├── README.md ├── stub │ ├── .gitignore │ ├── Cargo.toml │ └── src │ │ └── main.rs └── wg-cli │ ├── .gitignore │ ├── Cargo.toml │ └── src │ └── main.rs ├── integration-tests ├── .gitignore ├── justfile ├── suites │ ├── native-tun │ │ ├── Cargo.toml │ │ └── src │ │ │ └── main.rs │ ├── wireguard-to-wiretun │ │ ├── run-test.sh │ │ └── tester │ │ │ ├── .gitignore │ │ │ ├── Cargo.toml │ │ │ └── src │ │ │ └── main.rs │ └── wiretun-to-wiretun │ │ ├── README.md │ │ ├── run-test.sh │ │ └── tester │ │ ├── .gitignore │ │ ├── Cargo.toml │ │ └── src │ │ └── main.rs └── support │ └── wiretun-cli │ ├── .gitignore │ ├── Cargo.toml │ └── src │ ├── main.rs │ └── packet.rs ├── justfile ├── src ├── device │ ├── config.rs │ ├── error.rs │ ├── handle.rs │ ├── inbound.rs │ ├── metrics.rs │ ├── mod.rs │ ├── peer │ │ ├── cidr.rs │ │ ├── handle.rs │ │ ├── handshake.rs │ │ ├── index.rs │ │ ├── mod.rs │ │ ├── monitor.rs │ │ └── session.rs │ ├── rate_limiter.rs │ └── time.rs ├── lib.rs ├── noise │ ├── crypto.rs │ ├── error.rs │ ├── handshake │ │ ├── cookie.rs │ │ ├── initiation.rs │ │ ├── mod.rs │ │ └── response.rs │ ├── mod.rs │ ├── protocol.rs │ └── timestamp.rs ├── tun │ ├── error.rs │ ├── linux │ │ ├── mod.rs │ │ ├── sys.rs │ │ └── tun.rs │ ├── macos │ │ ├── mod.rs │ │ ├── sys.rs │ │ └── tun.rs │ └── mod.rs └── uapi │ ├── connection.rs │ ├── error.rs │ ├── mod.rs │ └── protocol.rs └── tests ├── handshake.rs └── support.rs /.github/stale.yml: -------------------------------------------------------------------------------- 1 | # Number of days of inactivity before an issue becomes stale 2 | daysUntilStale: 60 3 | # Number of days of inactivity before a stale issue is closed 4 | daysUntilClose: 7 5 | # Issues with these labels will never be considered stale 6 | exemptLabels: 7 | - pinned 8 | - security 9 | # Label to use when marking an issue as stale 10 | staleLabel: 
wontfix 11 | # Comment to post when marking an issue as stale. Set to `false` to disable 12 | markComment: > 13 | This issue has been automatically marked as stale because it has not had 14 | recent activity. It will be closed if no further activity occurs. Thank you 15 | for your contributions. 16 | # Comment to post when closing a stale issue. Set to `false` to disable 17 | closeComment: false -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: CI 2 | 3 | on: 4 | push: 5 | branches: [ "master" ] 6 | pull_request: 7 | branches: [ "master" ] 8 | 9 | env: 10 | CARGO_TERM_COLOR: always 11 | 12 | jobs: 13 | lint: 14 | name: Coding style check 15 | runs-on: ubuntu-latest 16 | timeout-minutes: 10 17 | steps: 18 | - name: Checkout Code 19 | uses: actions/checkout@v4 20 | - name: Install rust toolchain 21 | uses: dtolnay/rust-toolchain@master 22 | with: 23 | toolchain: stable 24 | components: clippy 25 | - name: Install just 26 | uses: extractions/setup-just@v2 27 | - name: Run lint 28 | run: | 29 | just lint 30 | 31 | unit-test: 32 | name: Unit test 33 | runs-on: ubuntu-latest 34 | timeout-minutes: 10 35 | steps: 36 | - name: Checkout Code 37 | uses: actions/checkout@v4 38 | - name: Install rust toolchain 39 | uses: dtolnay/rust-toolchain@master 40 | with: 41 | toolchain: stable 42 | components: clippy 43 | - name: Install just 44 | uses: extractions/setup-just@v2 45 | - name: Install cargo-nextest 46 | uses: taiki-e/install-action@nextest 47 | - name: Run tests 48 | run: | 49 | just unit-test 50 | 51 | integration-test: 52 | name: Integration test 53 | strategy: 54 | matrix: 55 | os: [ubuntu-latest, macos-latest] 56 | scenario: [native-tun, wiretun-to-wiretun, wireguard-to-wiretun] 57 | runs-on: ${{ matrix.os }} 58 | timeout-minutes: 30 59 | steps: 60 | - name: Checkout Code 61 | uses: actions/checkout@v4 62 | - name: Install 
rust toolchain 63 | uses: dtolnay/rust-toolchain@master 64 | with: 65 | toolchain: stable 66 | components: clippy 67 | - name: Install just 68 | uses: extractions/setup-just@v2 69 | - name: Install WireGuard tools (Linux) 70 | if: runner.os == 'Linux' 71 | run: | 72 | sudo apt update -y 73 | sudo apt install -y wireguard 74 | - name: Set up Homebrew (MacOS) 75 | if: runner.os == 'macOS' 76 | uses: Homebrew/actions/setup-homebrew@master 77 | - name: Install WireGuard tools (MacOS) 78 | if: runner.os == 'macOS' 79 | run: | 80 | brew install wireguard-tools 81 | - name: Run tests 82 | run: | 83 | just integration-test-${{ matrix.scenario }} 84 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | **/target 2 | /Cargo.lock 3 | .idea 4 | -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "wiretun" 3 | description = "WireGuard Library" 4 | version = "0.5.0" 5 | authors = ["zarvd "] 6 | keywords = ["wireguard", "networking"] 7 | repository = "https://github.com/zarvd/wiretun" 8 | homepage = "https://github.com/zarvd/wiretun" 9 | documentation = "https://docs.rs/wiretun" 10 | license = "Apache-2.0" 11 | edition = "2021" 12 | 13 | [features] 14 | default = ["native"] 15 | native = [] 16 | uapi = [] 17 | 18 | [dependencies] 19 | libc = "0.2" 20 | nix = { version = "0.29", features = ["fs", "ioctl", "socket"] } 21 | socket2 = "0.5" 22 | bytes = "1.6" 23 | regex = "1.10" 24 | rand_core = "0.6" 25 | anyhow = "1.0" 26 | thiserror = "1.0" 27 | tracing = "0.1" 28 | futures = "0.3" 29 | async-trait = "0.1" 30 | tokio = { version = "1.37", features = ["full"] } 31 | tokio-util = "0.7" 32 | chacha20poly1305 = "0.10" 33 | x25519-dalek = { version = "2.0", features = ["reusable_secrets", "static_secrets"] } 34 | 
blake2 = "0.10" 35 | hmac = "0.12" 36 | ip_network = "0.4" 37 | ip_network_table = "0.2" 38 | 39 | [dev-dependencies] 40 | tracing-subscriber = { version = "0.3", features = ["env-filter"] } 41 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 
34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2023 [lodrem ] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # WireTun 2 | 3 | [github](https://github.com/zarvd/wiretun) 4 | [crates.io](https://crates.io/crates/wiretun) 5 | [docs.rs](https://docs.rs/wiretun) 6 | [build status](https://github.com/zarvd/wiretun/actions?query%3Amaster) 7 | [dependency status](https://deps.rs/repo/github/zarvd/wiretun) 8 | 9 | This library provides a cross-platform, asynchronous (with [Tokio](https://tokio.rs/)) [WireGuard](https://www.wireguard.com/) implementation. 10 | 11 | **WARNING**: This library is still in early development and is not ready for production use. 
12 | 13 | ```toml 14 | [dependencies] 15 | wiretun = { version = "*", features = ["uapi"] } 16 | ``` 17 | 18 | ## Example 19 | 20 | ```rust 21 | use wiretun::{Cidr, Device, DeviceConfig, PeerConfig, uapi}; 22 | 23 | #[tokio::main] 24 | async fn main() -> Result<(), Box<dyn std::error::Error>> { 25 | let cfg = DeviceConfig::default() 26 | .listen_port(40001); 27 | let device = Device::native("utun88", cfg).await?; 28 | uapi::bind_and_handle(device.control()).await?; 29 | Ok(()) 30 | } 31 | ``` 32 | 33 | More examples can be found in the [examples](examples) directory. 34 | 35 | ## Minimum supported Rust version (MSRV) 36 | 37 | 1.66.1 38 | 39 | ## License 40 | 41 | This project is licensed under the [Apache 2.0 license](LICENSE). 42 | -------------------------------------------------------------------------------- /examples/README.md: -------------------------------------------------------------------------------- 1 | # Examples 2 | 3 | - [wg-cli](./wg-cli): A CLI that works like `wireguard-go`. 4 | - [stub](./stub): A CLI that will echo back the packets it receives.
-------------------------------------------------------------------------------- /examples/stub/.gitignore: -------------------------------------------------------------------------------- 1 | /target 2 | /Cargo.lock 3 | -------------------------------------------------------------------------------- /examples/stub/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "stub" 3 | version = "0.1.0" 4 | edition = "2021" 5 | publish = false 6 | 7 | [dependencies] 8 | wiretun = { path = "../../", features = ["uapi"] } 9 | async-trait = "0.1" 10 | tokio = { version = "1.27", features = ["full"] } 11 | tracing-subscriber = { version = "0.3", features = ["env-filter"] } 12 | tracing = "0.1" 13 | base64 = "0.21" 14 | pnet = "0.33" -------------------------------------------------------------------------------- /examples/stub/src/main.rs: -------------------------------------------------------------------------------- 1 | use std::error::Error; 2 | use std::sync::Arc; 3 | 4 | use async_trait::async_trait; 5 | use base64::engine::general_purpose::STANDARD as base64Encoding; 6 | use base64::Engine; 7 | use tokio::sync::mpsc; 8 | use tokio::sync::Mutex; 9 | use tracing::{debug, error, info}; 10 | use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; 11 | 12 | use wiretun::{uapi, Cidr, Device, DeviceConfig, PeerConfig, Tun, TunError}; 13 | 14 | fn decode_base64(s: &str) -> Vec { 15 | base64Encoding.decode(s).unwrap() 16 | } 17 | 18 | fn local_private_key() -> [u8; 32] { 19 | decode_base64("eMdnuqCl3u2WjK2Wzfw16y1ddgdkKkzmlXukKL3WnVU=") 20 | .try_into() 21 | .unwrap() 22 | } 23 | 24 | fn peer_public_key() -> [u8; 32] { 25 | decode_base64("Wv/8YAQITWMHhZ0a4qgNNy689546TXVgD+XJefKxzDw=") 26 | .try_into() 27 | .unwrap() 28 | } 29 | #[tokio::main] 30 | async fn main() -> Result<(), Box> { 31 | tracing_subscriber::registry() 32 | .with(tracing_subscriber::EnvFilter::new( 33 | 
std::env::var("RUST_LOG").unwrap_or_else(|_| "debug".into()), 34 | )) 35 | .with(tracing_subscriber::fmt::layer()) 36 | .init(); 37 | 38 | let cfg = DeviceConfig::default() 39 | .listen_port(51871) 40 | .private_key(local_private_key()) 41 | .peer( 42 | PeerConfig::default() 43 | .public_key(peer_public_key()) 44 | .allowed_ip("10.0.0.1".parse::()?), 45 | ); 46 | let tun = StubTun::new(); 47 | let device = Device::with_udp(tun, cfg).await?; 48 | 49 | let ctrl = device.control(); 50 | tokio::spawn(uapi::bind_and_handle(ctrl)); 51 | 52 | tokio::signal::ctrl_c().await?; 53 | device.terminate().await; // stop gracefully 54 | 55 | Ok(()) 56 | } 57 | 58 | #[derive(Clone)] 59 | struct StubTun { 60 | tx: mpsc::Sender>, 61 | rx: Arc>>>, 62 | } 63 | 64 | impl StubTun { 65 | pub fn new() -> Self { 66 | let (tx, rx) = mpsc::channel(128); 67 | let rx = Arc::new(Mutex::new(rx)); 68 | Self { tx, rx } 69 | } 70 | 71 | fn handle(&self, mut buf: Vec) -> Vec { 72 | use pnet::packet::ip::IpNextHeaderProtocols; 73 | use pnet::packet::ipv4::{checksum, MutableIpv4Packet}; 74 | use pnet::packet::udp::{ipv4_checksum, MutableUdpPacket}; 75 | use pnet::packet::Packet; 76 | let mut ipv4 = MutableIpv4Packet::new(&mut buf).unwrap(); 77 | let src_ip = ipv4.get_source(); 78 | let dst_ip = ipv4.get_destination(); 79 | ipv4.set_source(dst_ip); 80 | ipv4.set_destination(src_ip); 81 | 82 | match ipv4.get_next_level_protocol() { 83 | IpNextHeaderProtocols::Udp => { 84 | let mut udp = MutableUdpPacket::owned(ipv4.payload().to_vec()).unwrap(); 85 | let src_port = udp.get_source(); 86 | let dst_port = udp.get_destination(); 87 | udp.set_source(dst_port); 88 | udp.set_destination(src_port); 89 | udp.set_checksum(ipv4_checksum(&udp.to_immutable(), &dst_ip, &src_ip)); 90 | ipv4.set_payload(udp.packet()); 91 | } 92 | _ => { 93 | debug!("Unknown packet type!"); 94 | } 95 | } 96 | 97 | ipv4.set_checksum(checksum(&ipv4.to_immutable())); 98 | 99 | ipv4.packet().to_vec() 100 | } 101 | } 102 | 103 | #[async_trait] 
104 | impl Tun for StubTun { 105 | fn name(&self) -> &str { 106 | "stub" 107 | } 108 | 109 | fn mtu(&self) -> Result { 110 | Ok(1500) 111 | } 112 | 113 | fn set_mtu(&self, _mtu: u16) -> Result<(), TunError> { 114 | Ok(()) 115 | } 116 | 117 | async fn recv(&self) -> Result, TunError> { 118 | let mut rx = self.rx.lock().await; 119 | let rv = rx.recv().await.ok_or(TunError::Closed); 120 | 121 | match &rv { 122 | Ok(buf) => { 123 | info!("recv data[{}] from tun", buf.len()); 124 | } 125 | Err(e) => { 126 | error!("failed to recv data from tun: {:?}", e); 127 | } 128 | } 129 | 130 | rv 131 | } 132 | 133 | async fn send(&self, buf: &[u8]) -> Result<(), TunError> { 134 | info!("recv data[{}] from outbound", buf.len()); 135 | self.tx 136 | .send(self.handle(buf.to_vec())) 137 | .await 138 | .map_err(|_| TunError::Closed) 139 | } 140 | } 141 | -------------------------------------------------------------------------------- /examples/wg-cli/.gitignore: -------------------------------------------------------------------------------- 1 | /target 2 | /Cargo.lock 3 | -------------------------------------------------------------------------------- /examples/wg-cli/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "wg-cli" 3 | version = "0.1.0" 4 | edition = "2021" 5 | publish = false 6 | 7 | [dependencies] 8 | wiretun = { path = "../../", features = ["uapi"] } 9 | tokio = { version = "1.27", features = ["full"] } 10 | tracing-subscriber = { version = "0.3", features = ["env-filter"] } 11 | tracing = "0.1" 12 | base64 = "0.21" 13 | -------------------------------------------------------------------------------- /examples/wg-cli/src/main.rs: -------------------------------------------------------------------------------- 1 | use std::error::Error; 2 | use std::time::Duration; 3 | 4 | use base64::engine::general_purpose::STANDARD as base64Encoding; 5 | use base64::Engine; 6 | use tracing::info; 7 | use 
tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; 8 | 9 | use wiretun::{uapi, Cidr, Device, DeviceConfig, PeerConfig}; 10 | 11 | fn decode_base64(s: &str) -> Vec { 12 | base64Encoding.decode(s).unwrap() 13 | } 14 | 15 | fn local_private_key() -> [u8; 32] { 16 | decode_base64("GDE0rT7tfVGairGhTASn5+ck1mUSqLNyajyMSBFYpVQ=") 17 | .try_into() 18 | .unwrap() 19 | } 20 | 21 | fn peer_public_key() -> [u8; 32] { 22 | decode_base64("ArhPnhqqlroFdP4wca7Yu9PuUR1p+TfMhy9kBewLNjM=") 23 | .try_into() 24 | .unwrap() 25 | } 26 | 27 | #[tokio::main] 28 | async fn main() -> Result<(), Box> { 29 | tracing_subscriber::registry() 30 | .with(tracing_subscriber::EnvFilter::new( 31 | std::env::var("RUST_LOG").unwrap_or_else(|_| "debug".into()), 32 | )) 33 | .with(tracing_subscriber::fmt::layer()) 34 | .init(); 35 | 36 | info!("Starting"); 37 | 38 | let cfg = DeviceConfig::default() 39 | .listen_port(9999) 40 | .private_key(local_private_key()) 41 | .peer( 42 | PeerConfig::default() 43 | .public_key(peer_public_key()) 44 | .endpoint("0.0.0.0:51871".parse()?) 45 | .allowed_ip("10.0.0.2".parse::()?) 
46 | .persistent_keepalive(Duration::from_secs(5)), 47 | ); 48 | 49 | let device = Device::native("utun88", cfg).await?; 50 | 51 | let ctrl = device.control(); 52 | tokio::spawn(async move { 53 | uapi::bind_and_handle(ctrl).await.unwrap(); 54 | }); 55 | 56 | let ctrl = device.control(); 57 | tokio::spawn(async move { 58 | tokio::time::sleep(Duration::from_secs(10)).await; 59 | info!("Updating listen port"); 60 | let _ = ctrl.update_listen_port(9991).await; 61 | }); 62 | 63 | tokio::signal::ctrl_c().await?; 64 | device.terminate().await; // stop gracefully 65 | 66 | Ok(()) 67 | } 68 | -------------------------------------------------------------------------------- /integration-tests/.gitignore: -------------------------------------------------------------------------------- 1 | **/bin 2 | **/log 3 | **/Cargo.lock -------------------------------------------------------------------------------- /integration-tests/justfile: -------------------------------------------------------------------------------- 1 | default: 2 | just --list 3 | 4 | # Build test support 5 | build: 6 | #!/usr/bin/env bash 7 | set -e 8 | pushd support/wiretun-cli 9 | cargo build 10 | popd 11 | 12 | rm -rf bin 13 | mkdir bin 14 | cp support/wiretun-cli/target/debug/wiretun-cli bin/wiretun-cli 15 | 16 | rm -rf suites/wireguard-to-wiretun/bin 17 | mkdir suites/wireguard-to-wiretun/bin 18 | ln -s {{ absolute_path("./bin/wiretun-cli") }} suites/wireguard-to-wiretun/bin/wiretun-cli 19 | 20 | rm -rf suites/wiretun-to-wiretun/bin 21 | mkdir suites/wiretun-to-wiretun/bin 22 | ln -s {{ absolute_path("./bin/wiretun-cli") }} suites/wiretun-to-wiretun/bin/wiretun-cli 23 | 24 | run-tests: test-native-tun test-wireguard-to-wiretun test-wiretun-to-wiretun 25 | 26 | test-native-tun: build 27 | #!/usr/bin/env bash 28 | set -e 29 | pushd suites/native-tun 30 | cargo build 31 | sudo target/debug/wiretun-native-tun 32 | popd 33 | 34 | test-wireguard-to-wiretun: build 35 | #!/usr/bin/env bash 36 | set -e 37 | pushd 
suites/wireguard-to-wiretun 38 | 39 | # build tester 40 | pushd tester 41 | cargo build 42 | popd 43 | 44 | cp tester/target/debug/wireguard-to-wiretun-tester bin/tester 45 | 46 | # Run test 47 | sudo ./run-test.sh 48 | 49 | popd 50 | 51 | test-wiretun-to-wiretun: build 52 | #!/usr/bin/env bash 53 | set -e 54 | pushd suites/wiretun-to-wiretun 55 | 56 | # build tester 57 | pushd tester 58 | cargo build 59 | popd 60 | 61 | cp tester/target/debug/wiretun-to-wiretun-tester bin/tester 62 | 63 | # Run test 64 | sudo ./run-test.sh 65 | 66 | popd -------------------------------------------------------------------------------- /integration-tests/suites/native-tun/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "wiretun-native-tun" 3 | version = "0.1.0" 4 | edition = "2021" 5 | publish = false 6 | 7 | [dependencies] 8 | wiretun = { path = "../../../", features = ["uapi"] } 9 | tokio = { version = "1.37", features = ["full"] } 10 | tracing-subscriber = { version = "0.3", features = ["env-filter"] } 11 | tracing = "0.1" 12 | -------------------------------------------------------------------------------- /integration-tests/suites/native-tun/src/main.rs: -------------------------------------------------------------------------------- 1 | use std::error::Error; 2 | 3 | use tracing::info; 4 | use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; 5 | 6 | use wiretun::{NativeTun, Tun}; 7 | 8 | #[tokio::main] 9 | async fn main() -> Result<(), Box> { 10 | tracing_subscriber::registry() 11 | .with(tracing_subscriber::EnvFilter::new( 12 | std::env::var("RUST_LOG").unwrap_or_else(|_| "debug".into()), 13 | )) 14 | .with(tracing_subscriber::fmt::layer()) 15 | .init(); 16 | 17 | test_set_mtu().await?; 18 | 19 | Ok(()) 20 | } 21 | 22 | async fn test_set_mtu() -> Result<(), Box> { 23 | info!("test_set_mtu"); 24 | 25 | let name = "utun"; 26 | let tun = NativeTun::new(name)?; 27 | tun.set_mtu(1400)?; 28 | 
assert_eq!(tun.mtu()?, 1400); 29 | tun.set_mtu(1500)?; 30 | assert_eq!(tun.mtu()?, 1500); 31 | 32 | Ok(()) 33 | } 34 | -------------------------------------------------------------------------------- /integration-tests/suites/wireguard-to-wiretun/run-test.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -x 4 | set -e 5 | 6 | # Require wireguard-tools to be installed 7 | if ! command -v wg-quick &> /dev/null 8 | then 9 | echo "wg-quick not installed, please install it first" 10 | exit 1 11 | fi 12 | 13 | PIDS=() 14 | cleanup() { 15 | echo "Cleaning up" 16 | for pid in ${PIDS[@]}; do 17 | echo "Killing ${pid}" 18 | sudo kill -9 ${pid} || true # the process may have exited already; keep cleaning up 19 | done 20 | 21 | # stop peer1 22 | wg-quick down ./utun.conf || true 23 | rm -f ./utun.conf 24 | } 25 | trap cleanup EXIT 26 | 27 | 28 | # ============================================ 29 | # Start three wireguard instances: 30 | # 1. Peer1: use `wg-quick` to setup with a native tun 31 | # 2. Peer2: use `wiretun` to setup with stub tun 32 | # 3.
Peer3: use `wiretun` to setup with stub tun (with preshared_key) 33 | # ============================================ 34 | 35 | PEER1_LISTEN_PORT=50081 36 | PEER1_KEY=oLCiGZ7J6eMjpWgBIClVGPccrnopmqIOcia8HnDN/lY= 37 | PEER1_PUB=jNMMQlzMwX0WeeWed9v6lINsBS3PhmF+/4fKbdfNZTA= 38 | 39 | PEER2_LISTEN_PORT=50082 40 | PEER2_NAME=peer2-stub 41 | PEER2_KEY=UGyzBpReHMheRGbwr5vFJ1Yu8Xkkbn5ub3F8w22y3HA= 42 | PEER2_PUB=KlVx32ZygXCBRK2X7ko9qF5FCVfNACzKoAglNnbt1m4= 43 | 44 | PEER3_LISTEN_PORT=50083 45 | PEER3_NAME=peer3-stub 46 | PEER3_KEY=cHpUPuuP4kMccJFQ5KoGJih1UuSzIF6TI5rfiuRCF3U= 47 | PEER3_PUB=h0h2J2HjfBPzLZ31UpkqvtNXYtCjWKT20xccF/B6Wgw= 48 | PEER3_PSK=MSb1Drx0brNic2B2hAtkgKUgd4ypNbDMJZKyB4EFzlg= 49 | 50 | rm -rf log 51 | mkdir log 52 | PEER2_LOG=log/peer2.log 53 | PEER3_LOG=log/peer3.log 54 | 55 | start_peer1() { 56 | cat > utun.conf <<-EOF 57 | [Interface] 58 | Address = 10.11.100.1/32 59 | ListenPort = ${PEER1_LISTEN_PORT} 60 | PrivateKey = ${PEER1_KEY} 61 | 62 | [Peer] 63 | PublicKey = ${PEER2_PUB} 64 | AllowedIPs = 10.11.100.2/32 65 | 66 | [Peer] 67 | PublicKey = ${PEER3_PUB} 68 | AllowedIPs = 10.11.100.3/32 69 | PresharedKey = ${PEER3_PSK} 70 | EOF 71 | 72 | wg-quick up ./utun.conf 73 | } 74 | 75 | start_peer2() { 76 | ./bin/wiretun-cli \ 77 | --mode stub \ 78 | --name ${PEER2_NAME} \ 79 | --private-key ${PEER2_KEY} \ 80 | --listen-port ${PEER2_LISTEN_PORT} &> ${PEER2_LOG} & 81 | PID=$! 82 | PIDS+=(${PID}) 83 | echo "Peer2 PID: ${PID}" 84 | 85 | sleep 5 86 | wg set ${PEER2_NAME} \ 87 | peer ${PEER1_PUB} \ 88 | endpoint 0.0.0.0:${PEER1_LISTEN_PORT} \ 89 | allowed-ips 10.11.100.1/32 90 | } 91 | 92 | start_peer3() { 93 | ./bin/wiretun-cli \ 94 | --mode stub \ 95 | --name ${PEER3_NAME} \ 96 | --private-key ${PEER3_KEY} \ 97 | --listen-port ${PEER3_LISTEN_PORT} &> ${PEER3_LOG} & 98 | PID=$!
99 | PIDS+=(${PID}) 100 | echo "Peer3 PID: ${PID}" 101 | 102 | sleep 5 103 | PEER3_PSK_FILE=$(mktemp) 104 | echo ${PEER3_PSK} > ${PEER3_PSK_FILE} 105 | wg set ${PEER3_NAME} \ 106 | peer ${PEER1_PUB} \ 107 | endpoint 0.0.0.0:${PEER1_LISTEN_PORT} \ 108 | preshared-key ${PEER3_PSK_FILE} \ 109 | allowed-ips 10.11.100.1/32 110 | # `wg set` has read the key by now; don't leave the PSK behind in a temp file 111 | rm -f ${PEER3_PSK_FILE} 112 | } 113 | 114 | start_peers() { 115 | start_peer1 116 | start_peer2 117 | start_peer3 118 | } 119 | 120 | run() { 121 | start_peers 122 | 123 | ./bin/tester 124 | RET=$? 125 | 126 | exit ${RET} 127 | } 128 | 129 | run -------------------------------------------------------------------------------- /integration-tests/suites/wireguard-to-wiretun/tester/.gitignore: -------------------------------------------------------------------------------- 1 | /Cargo.lock 2 | -------------------------------------------------------------------------------- /integration-tests/suites/wireguard-to-wiretun/tester/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "wireguard-to-wiretun-tester" 3 | version = "0.1.0" 4 | edition = "2021" 5 | publish = false 6 | 7 | [dependencies] 8 | tokio = { version = "1.37", features = ["full"] } 9 | tracing-subscriber = { version = "0.3", features = ["env-filter"] } 10 | tracing = "0.1" 11 | rand_core = { version = "0.6", features = ["getrandom"] } 12 | -------------------------------------------------------------------------------- /integration-tests/suites/wireguard-to-wiretun/tester/src/main.rs: -------------------------------------------------------------------------------- 1 | use std::error::Error; 2 | use std::net::{IpAddr, Ipv4Addr}; 3 | 4 | use tracing::info; 5 | use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; 6 | 7 | #[tokio::main] 8 | async fn main() -> Result<(), Box> { 9 | tracing_subscriber::registry() 10 | .with(tracing_subscriber::EnvFilter::new( 11 | std::env::var("RUST_LOG").unwrap_or_else(|_| "debug".into()), 12 | )) 13 |
.with(tracing_subscriber::fmt::layer()) 14 | .init(); 15 | 16 | { 17 | let local_ip = IpAddr::V4(Ipv4Addr::new(10, 11, 100, 1)); 18 | let remote_ip = IpAddr::V4(Ipv4Addr::new(10, 11, 100, 2)); 19 | 20 | info!("=================="); 21 | info!("Running test: simple_udp_echo"); 22 | test_case::simple_udp_echo(local_ip, remote_ip, b"peer2-stub").await?; 23 | 24 | info!("=================="); 25 | info!("Running test: after_rekey_udp_echo"); 26 | test_case::after_rekey_udp_echo(local_ip, remote_ip, b"peer2-stub").await?; 27 | } 28 | 29 | { 30 | let local_ip = IpAddr::V4(Ipv4Addr::new(10, 11, 100, 1)); 31 | let remote_ip = IpAddr::V4(Ipv4Addr::new(10, 11, 100, 3)); 32 | 33 | info!("=================="); 34 | info!("Running test: simple_udp_echo"); 35 | test_case::simple_udp_echo(local_ip, remote_ip, b"peer3-stub").await?; 36 | 37 | info!("=================="); 38 | info!("Running test: after_rekey_udp_echo"); 39 | test_case::after_rekey_udp_echo(local_ip, remote_ip, b"peer3-stub").await?; 40 | } 41 | 42 | Ok(()) 43 | } 44 | 45 | mod test_case { 46 | use std::error::Error; 47 | use std::net::{IpAddr, SocketAddr}; 48 | use std::time::Duration; 49 | 50 | use rand_core::{OsRng, RngCore}; 51 | use tokio::net::UdpSocket; 52 | use tokio::time; 53 | use tracing::{info, instrument}; 54 | 55 | #[instrument(skip_all)] 56 | pub async fn simple_udp_echo( 57 | local_ip: IpAddr, 58 | remote_ip: IpAddr, 59 | prefix: &[u8], 60 | ) -> Result<(), Box> { 61 | let local_addr = SocketAddr::new(local_ip, 45999); 62 | let remote_addr = SocketAddr::new(remote_ip, 46999); 63 | 64 | let socket = UdpSocket::bind(local_addr).await?; 65 | for i in 1..=500 { 66 | info!("[{i}/500] Running test"); 67 | let mut output = [0u8; 1024]; 68 | OsRng.fill_bytes(&mut output); 69 | 70 | socket.send_to(&output, remote_addr).await?; 71 | 72 | let mut input = [0u8; 1024 + 100]; 73 | let (len, addr) = time::timeout(Duration::from_secs(2), socket.recv_from(&mut input)) 74 | .await 75 | .expect("should recv packet in
2 secs")?; 76 | 77 | assert_eq!(addr, remote_addr); 78 | let mut expected = prefix.to_vec(); 79 | expected.extend_from_slice(&output[..]); 80 | assert_eq!(&input[..len], &expected[..]); 81 | info!("[{i}/500] Test passed"); 82 | } 83 | 84 | Ok(()) 85 | } 86 | 87 | #[instrument(skip_all)] 88 | pub async fn after_rekey_udp_echo( 89 | local_ip: IpAddr, 90 | remote_ip: IpAddr, 91 | prefix: &[u8], 92 | ) -> Result<(), Box> { 93 | let local_addr = SocketAddr::new(local_ip, 45999); 94 | let remote_addr = SocketAddr::new(remote_ip, 46999); 95 | 96 | let socket = UdpSocket::bind(local_addr).await?; 97 | time::sleep(Duration::from_secs(120)).await; 98 | 99 | for i in 1..=500 { 100 | info!("[{i}/500] Running test"); 101 | let mut output = [0u8; 1024]; 102 | OsRng.fill_bytes(&mut output); 103 | 104 | socket.send_to(&output, remote_addr).await?; 105 | 106 | let mut input = [0u8; 1024 + 100]; 107 | let (len, addr) = time::timeout(Duration::from_secs(2), socket.recv_from(&mut input)) 108 | .await 109 | .expect("should recv packet in 2 secs")?; 110 | 111 | assert_eq!(addr, remote_addr); 112 | let mut expected = prefix.to_vec(); 113 | expected.extend_from_slice(&output[..]); 114 | assert_eq!(&input[..len], &expected[..]); 115 | info!("[{i}/500] Test passed"); 116 | } 117 | Ok(()) 118 | } 119 | } 120 | -------------------------------------------------------------------------------- /integration-tests/suites/wiretun-to-wiretun/README.md: -------------------------------------------------------------------------------- 1 | # Test Case: WireTun to WireTun 2 | 3 | ## Setup 4 | 5 | - [peer1](../../support/wiretun-cli) is a WireTun instance (`wiretun-cli --mode native`) running on the host machine with a native TUN device. 6 | - [peer2](../../support/wiretun-cli) is a WireTun instance (`wiretun-cli --mode stub`) running on the host machine with a stub TUN, which does not create any TUN device and echoes UDP packets back prefixed with its name. 7 | - [tester](./tester) runs test cases to verify the connectivity between peer1 and peer2.
8 | -------------------------------------------------------------------------------- /integration-tests/suites/wiretun-to-wiretun/run-test.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -x 4 | set -e 5 | 6 | PIDS=() 7 | cleanup() { 8 | echo "Cleaning up" 9 | for pid in ${PIDS[@]}; do 10 | echo "Killing ${pid}" 11 | sudo kill -9 ${pid} || true # the process may have exited already; keep cleaning up 12 | done 13 | } 14 | trap cleanup EXIT 15 | 16 | PEER1_LISTEN_PORT=50081 17 | PEER1_NAME=utun44 18 | PEER1_KEY=oLCiGZ7J6eMjpWgBIClVGPccrnopmqIOcia8HnDN/lY= 19 | PEER1_PUB=jNMMQlzMwX0WeeWed9v6lINsBS3PhmF+/4fKbdfNZTA= 20 | PEER1_ADDR=10.11.101.1/32 21 | 22 | PEER2_LISTEN_PORT=50082 23 | PEER2_NAME=peer2-stub 24 | PEER2_KEY=UGyzBpReHMheRGbwr5vFJ1Yu8Xkkbn5ub3F8w22y3HA= 25 | PEER2_PUB=KlVx32ZygXCBRK2X7ko9qF5FCVfNACzKoAglNnbt1m4= 26 | PEER2_ADDR=10.11.101.2/32 27 | 28 | rm -rf log 29 | mkdir log 30 | PEER1_LOG=log/peer1.log 31 | PEER2_LOG=log/peer2.log 32 | 33 | start_peer1() { 34 | ./bin/wiretun-cli \ 35 | --mode native \ 36 | --name ${PEER1_NAME} \ 37 | --private-key ${PEER1_KEY} \ 38 | --listen-port ${PEER1_LISTEN_PORT} &> ${PEER1_LOG} & 39 | PID=$!
40 | PIDS+=(${PID}) 41 | echo "Peer1 PID: ${PID}" 42 | 43 | sleep 5 44 | cat ${PEER1_LOG} 45 | wg set ${PEER1_NAME} \ 46 | peer ${PEER2_PUB} \ 47 | endpoint 0.0.0.0:${PEER2_LISTEN_PORT} \ 48 | allowed-ips ${PEER2_ADDR} 49 | 50 | if [[ "$OSTYPE" == "darwin"* ]]; then 51 | ifconfig ${PEER1_NAME} inet ${PEER1_ADDR} 10.11.101.1 alias 52 | route -q -n add -inet ${PEER2_ADDR} -interface ${PEER1_NAME} 53 | elif [[ "$OSTYPE" == "linux-gnu"* ]]; then 54 | ip -4 address add ${PEER1_ADDR} dev ${PEER1_NAME} 55 | ip link set mtu 1420 up dev ${PEER1_NAME} 56 | ip -4 route add ${PEER2_ADDR} dev ${PEER1_NAME} 57 | else 58 | echo "Unsupported OS: ${OSTYPE}" 59 | exit 1 60 | fi 61 | } 62 | 63 | start_peer2() { 64 | RUST_BACKTRACE=1 ./bin/wiretun-cli \ 65 | --mode stub \ 66 | --name ${PEER2_NAME} \ 67 | --private-key ${PEER2_KEY} \ 68 | --listen-port ${PEER2_LISTEN_PORT} &> ${PEER2_LOG} & 69 | PID=$! 70 | PIDS+=(${PID}) 71 | echo "Peer2 PID: ${PID}" 72 | 73 | sleep 5 74 | cat ${PEER2_LOG} 75 | wg set ${PEER2_NAME} \ 76 | peer ${PEER1_PUB} \ 77 | endpoint 0.0.0.0:${PEER1_LISTEN_PORT} \ 78 | allowed-ips ${PEER1_ADDR} 79 | } 80 | 81 | start_peers() { 82 | start_peer1 83 | start_peer2 84 | } 85 | 86 | run() { 87 | start_peers 88 | sleep 10 89 | ./bin/tester 90 | RET=$?
91 | exit ${RET} 92 | } 93 | 94 | run -------------------------------------------------------------------------------- /integration-tests/suites/wiretun-to-wiretun/tester/.gitignore: -------------------------------------------------------------------------------- 1 | /Cargo.lock 2 | -------------------------------------------------------------------------------- /integration-tests/suites/wiretun-to-wiretun/tester/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "wiretun-to-wiretun-tester" 3 | version = "0.1.0" 4 | edition = "2021" 5 | publish = false 6 | 7 | [dependencies] 8 | tokio = { version = "1.37", features = ["full"] } 9 | tracing-subscriber = { version = "0.3", features = ["env-filter"] } 10 | tracing = "0.1" 11 | rand_core = { version = "0.6", features = ["getrandom"] } 12 | -------------------------------------------------------------------------------- /integration-tests/suites/wiretun-to-wiretun/tester/src/main.rs: -------------------------------------------------------------------------------- 1 | use std::error::Error; 2 | use std::net::{IpAddr, Ipv4Addr}; 3 | 4 | use tracing::info; 5 | use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; 6 | 7 | #[tokio::main] 8 | async fn main() -> Result<(), Box> { 9 | tracing_subscriber::registry() 10 | .with(tracing_subscriber::EnvFilter::new( 11 | std::env::var("RUST_LOG").unwrap_or_else(|_| "debug".into()), 12 | )) 13 | .with(tracing_subscriber::fmt::layer()) 14 | .init(); 15 | 16 | let local_ip = IpAddr::V4(Ipv4Addr::new(10, 11, 101, 1)); 17 | let remote_ip = IpAddr::V4(Ipv4Addr::new(10, 11, 101, 2)); 18 | 19 | info!("=================="); 20 | info!("Running test: simple_udp_echo"); 21 | test_case::simple_udp_echo(local_ip, remote_ip, b"peer2-stub").await?; 22 | 23 | info!("=================="); 24 | info!("Running test: after_rekey_udp_echo"); 25 | test_case::after_rekey_udp_echo(local_ip, remote_ip,
b"peer2-stub").await?; 26 | 27 | Ok(()) 28 | } 29 | 30 | mod test_case { 31 | use std::error::Error; 32 | use std::net::{IpAddr, SocketAddr}; 33 | use std::time::Duration; 34 | 35 | use rand_core::{OsRng, RngCore}; 36 | use tokio::net::UdpSocket; 37 | use tokio::time; 38 | use tracing::{info, instrument}; 39 | 40 | #[instrument(skip_all)] 41 | pub async fn simple_udp_echo( 42 | local_ip: IpAddr, 43 | remote_ip: IpAddr, 44 | prefix: &[u8], 45 | ) -> Result<(), Box> { 46 | let local_addr = SocketAddr::new(local_ip, 45999); 47 | let remote_addr = SocketAddr::new(remote_ip, 46999); 48 | 49 | let socket = UdpSocket::bind(local_addr).await?; 50 | for i in 1..=500 { 51 | info!("[{i}/500] Running test"); 52 | let mut output = [0u8; 1024]; 53 | OsRng.fill_bytes(&mut output); 54 | 55 | socket.send_to(&output, remote_addr).await?; 56 | 57 | let mut input = [0u8; 1024 + 100]; 58 | let (len, addr) = time::timeout(Duration::from_secs(2), socket.recv_from(&mut input)) 59 | .await 60 | .expect("should recv packet in 2 secs")?; 61 | 62 | assert_eq!(addr, remote_addr); 63 | let mut expected = prefix.to_vec(); 64 | expected.extend_from_slice(&output[..]); 65 | assert_eq!(&input[..len], &expected[..]); 66 | info!("[{i}/500] Test passed"); 67 | } 68 | 69 | Ok(()) 70 | } 71 | 72 | #[instrument(skip_all)] 73 | pub async fn after_rekey_udp_echo( 74 | local_ip: IpAddr, 75 | remote_ip: IpAddr, 76 | prefix: &[u8], 77 | ) -> Result<(), Box> { 78 | let local_addr = SocketAddr::new(local_ip, 45999); 79 | let remote_addr = SocketAddr::new(remote_ip, 46999); 80 | 81 | let socket = UdpSocket::bind(local_addr).await?; 82 | time::sleep(Duration::from_secs(120)).await; 83 | 84 | for i in 1..=500 { 85 | info!("[{i}/500] Running test"); 86 | let mut output = [0u8; 1024]; 87 | OsRng.fill_bytes(&mut output); 88 | 89 | socket.send_to(&output, remote_addr).await?; 90 | 91 | let mut input = [0u8; 1024 + 100]; 92 | let (len, addr) = time::timeout(Duration::from_secs(2), socket.recv_from(&mut input)) 93 |
.await 94 | .expect("should recv packet in 2 secs")?; 95 | 96 | assert_eq!(addr, remote_addr); 97 | let mut expected = prefix.to_vec(); 98 | expected.extend_from_slice(&output[..]); 99 | assert_eq!(&input[..len], &expected[..]); 100 | info!("[{i}/500] Test passed"); 101 | } 102 | Ok(()) 103 | } 104 | } 105 | -------------------------------------------------------------------------------- /integration-tests/support/wiretun-cli/.gitignore: -------------------------------------------------------------------------------- 1 | /Cargo.lock 2 | -------------------------------------------------------------------------------- /integration-tests/support/wiretun-cli/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "wiretun-cli" 3 | version = "0.1.0" 4 | edition = "2021" 5 | publish = false 6 | 7 | [dependencies] 8 | wiretun = { path = "../../../", features = ["uapi"] } 9 | clap = { version = "4.5", features = ["derive"] } 10 | async-trait = "0.1" 11 | tokio = { version = "1.37", features = ["full"] } 12 | tracing-subscriber = { version = "0.3", features = ["env-filter"] } 13 | tracing = "0.1" 14 | base64 = "0.22" 15 | pnet = "0.34" 16 | -------------------------------------------------------------------------------- /integration-tests/support/wiretun-cli/src/main.rs: -------------------------------------------------------------------------------- 1 | mod packet; 2 | 3 | use std::error::Error; 4 | use std::sync::Arc; 5 | 6 | use async_trait::async_trait; 7 | use clap::{Parser, ValueEnum}; 8 | use tokio::sync::{mpsc, Mutex}; 9 | use tracing::{error, info}; 10 | use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; 11 | use wiretun::{uapi, Device, DeviceConfig, Tun, TunError}; 12 | 13 | use packet::echo_udp_packet; 14 | 15 | fn decode_base64(s: &str) -> Result<[u8; 32], String> { 16 | use base64::engine::general_purpose::STANDARD; 17 | use base64::Engine; 18 | STANDARD 19 | .decode(s) 20 |
.map_err(|e| format!("bad base64 format: {}", e))? 21 | .try_into() 22 | .map_err(|e| format!("invalid secret key: {:?}", e)) 23 | } 24 | 25 | #[derive(Debug, Clone, ValueEnum)] 26 | enum Mode { 27 | Native, 28 | Stub, 29 | } 30 | 31 | #[derive(Parser, Debug)] 32 | #[command(author, version, about, long_about = None)] 33 | struct App { 34 | #[arg(long)] 35 | mode: Mode, 36 | 37 | #[arg(long)] 38 | name: String, 39 | 40 | #[arg(long, value_parser = decode_base64)] 41 | private_key: [u8; 32], 42 | 43 | #[arg(long)] 44 | listen_port: u16, 45 | } 46 | 47 | #[tokio::main] 48 | async fn main() -> Result<(), Box> { 49 | tracing_subscriber::registry() 50 | .with(tracing_subscriber::EnvFilter::new( 51 | std::env::var("RUST_LOG").unwrap_or_else(|_| "debug".into()), 52 | )) 53 | .with(tracing_subscriber::fmt::layer()) 54 | .init(); 55 | 56 | let app = App::parse(); 57 | 58 | let cfg = DeviceConfig::default() 59 | .private_key(app.private_key) 60 | .listen_port(app.listen_port); 61 | 62 | match app.mode { 63 | Mode::Native => { 64 | info!("Starting Wiretun device with native tun {}", app.name); 65 | let device = Device::native(&app.name, cfg).await?; 66 | uapi::bind_and_handle(device.control()).await?; 67 | device.terminate().await; 68 | } 69 | Mode::Stub => { 70 | info!("Starting Wiretun device with stub tun {}", app.name); 71 | let device = Device::with_udp(StubTun::new(&app.name), cfg).await?; 72 | uapi::bind_and_handle(device.control()).await?; 73 | device.terminate().await; 74 | } 75 | }; 76 | 77 | Ok(()) 78 | } 79 | 80 | #[derive(Clone)] 81 | struct StubTun { 82 | name: String, 83 | tx: mpsc::Sender>, 84 | rx: Arc>>>, 85 | } 86 | 87 | impl StubTun { 88 | pub fn new(name: &str) -> Self { 89 | let (tx, rx) = mpsc::channel(128); 90 | let rx = Arc::new(Mutex::new(rx)); 91 | let name = name.to_owned(); 92 | Self { name, tx, rx } 93 | } 94 | } 95 | 96 | #[async_trait] 97 | impl Tun for StubTun { 98 | fn name(&self) -> &str { 99 | &self.name 100 | } 101 | 102 | fn mtu(&self) ->
Result { 103 | Ok(1500) 104 | } 105 | 106 | fn set_mtu(&self, _mtu: u16) -> Result<(), TunError> { 107 | Ok(()) 108 | } 109 | 110 | async fn recv(&self) -> Result, TunError> { 111 | let mut rx = self.rx.lock().await; 112 | let rv = rx.recv().await.ok_or(TunError::Closed); 113 | 114 | match &rv { 115 | Ok(buf) => { 116 | info!("recv data[{}] from tun", buf.len()); 117 | } 118 | Err(e) => { 119 | error!("failed to recv data from tun: {:?}", e); 120 | } 121 | } 122 | 123 | rv 124 | } 125 | 126 | async fn send(&self, buf: &[u8]) -> Result<(), TunError> { 127 | info!("recv data[{}] from outbound", buf.len()); 128 | self.tx 129 | .send(echo_udp_packet(buf.to_vec(), self.name.as_bytes())) 130 | .await 131 | .map_err(|_| TunError::Closed) 132 | } 133 | } 134 | -------------------------------------------------------------------------------- /integration-tests/support/wiretun-cli/src/packet.rs: -------------------------------------------------------------------------------- 1 | use pnet::packet::ip::IpNextHeaderProtocols; 2 | use pnet::packet::ipv4::{checksum, Ipv4Packet, MutableIpv4Packet}; 3 | use pnet::packet::udp::{ipv4_checksum, MutableUdpPacket, UdpPacket}; 4 | use pnet::packet::Packet; 5 | use tracing::debug; 6 | 7 | /// Returns a new IPv4 packet with the source and destination addresses swapped. For UDP, the ports are swapped as well and `prefix` is prepended to the payload; for other protocols the payload is left zero-filled. 8 | /// # Arguments 9 | /// * `buf` - IPv4 packet to be echoed; its payload is expected to be a UDP datagram. 10 | /// * `prefix` - Bytes prepended to the echoed UDP payload. 11 | /// # Panics 12 | /// Panics if `buf` cannot be parsed as an IPv4 packet (e.g. it is shorter than an IPv4 header) or, for UDP, if its payload cannot be parsed as a UDP datagram.
13 | pub fn echo_udp_packet(buf: Vec, prefix: &[u8]) -> Vec { 14 | let mut output_ipv4 = MutableIpv4Packet::owned(vec![0; buf.len() + prefix.len()]).unwrap(); 15 | let input_ipv4 = Ipv4Packet::owned(buf).unwrap(); 16 | 17 | output_ipv4.set_source(input_ipv4.get_destination()); 18 | output_ipv4.set_destination(input_ipv4.get_source()); 19 | output_ipv4.set_version(input_ipv4.get_version()); 20 | output_ipv4.set_dscp(input_ipv4.get_dscp()); 21 | output_ipv4.set_flags(input_ipv4.get_flags()); 22 | output_ipv4.set_ecn(input_ipv4.get_ecn()); 23 | output_ipv4.set_header_length(input_ipv4.get_header_length()); 24 | output_ipv4.set_total_length(input_ipv4.get_total_length() + prefix.len() as u16); 25 | output_ipv4.set_identification(input_ipv4.get_identification()); 26 | output_ipv4.set_fragment_offset(input_ipv4.get_fragment_offset()); 27 | output_ipv4.set_next_level_protocol(input_ipv4.get_next_level_protocol()); 28 | output_ipv4.set_options(&input_ipv4.get_options()); 29 | 30 | match output_ipv4.get_next_level_protocol() { 31 | IpNextHeaderProtocols::Udp => { 32 | let input_udp = UdpPacket::owned(input_ipv4.payload().to_vec()).unwrap(); 33 | let mut output_udp = 34 | MutableUdpPacket::owned(vec![0; input_ipv4.payload().len() + prefix.len()]) 35 | .unwrap(); 36 | 37 | output_udp.set_source(input_udp.get_destination()); 38 | output_udp.set_destination(input_udp.get_source()); 39 | output_udp.set_payload(&{ 40 | let mut p = prefix.to_vec(); 41 | p.extend_from_slice(input_udp.payload()); 42 | p 43 | }); 44 | output_udp.set_length(input_udp.get_length() + prefix.len() as u16); 45 | output_udp.set_checksum(ipv4_checksum( 46 | &output_udp.to_immutable(), 47 | &output_ipv4.get_source(), 48 | &output_ipv4.get_destination(), 49 | )); 50 | 51 | output_ipv4.set_payload(output_udp.packet()); 52 | } 53 | _ => { 54 | debug!("Unknown packet type!"); 55 | } 56 | } 57 | 58 | output_ipv4.set_checksum(checksum(&output_ipv4.to_immutable())); 59 | 60 | output_ipv4.packet().to_vec() 61 | } 62 
| -------------------------------------------------------------------------------- /justfile: -------------------------------------------------------------------------------- 1 | default: 2 | just --list 3 | 4 | # Build the project 5 | build: 6 | cargo build 7 | 8 | # Format code with rust 9 | fmt: 10 | cargo fmt 11 | 12 | # Lint code with clippy 13 | lint: 14 | cargo fmt --all -- --check 15 | cargo clippy --all-targets --all-features 16 | 17 | # Run unit tests against the current platform 18 | unit-test: 19 | cargo nextest run 20 | cargo test --doc 21 | 22 | # Run integration tests against the current platform (Require sudo) 23 | integration-test: integration-test-native-tun integration-test-wiretun-to-wiretun integration-test-wireguard-to-wiretun 24 | 25 | # Run integration test scenario: native TUN 26 | integration-test-native-tun: 27 | #!/usr/bin/env bash 28 | set -e 29 | pushd integration-tests 30 | just test-native-tun 31 | popd 32 | 33 | # Run integration test scenario: WireTun to WireTun 34 | integration-test-wiretun-to-wiretun: 35 | #!/usr/bin/env bash 36 | set -e 37 | pushd integration-tests 38 | just test-wiretun-to-wiretun 39 | popd 40 | 41 | # Run integration test scenario: WireGuard to WireTun 42 | integration-test-wireguard-to-wiretun: 43 | #!/usr/bin/env bash 44 | set -e 45 | pushd integration-tests 46 | just test-wireguard-to-wiretun 47 | popd 48 | 49 | -------------------------------------------------------------------------------- /src/device/config.rs: -------------------------------------------------------------------------------- 1 | use std::collections::{HashMap, HashSet}; 2 | use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr}; 3 | use std::time::Duration; 4 | 5 | use super::Cidr; 6 | use crate::noise::crypto::LocalStaticSecret; 7 | 8 | /// Configuration for a device. 9 |
/// 10 | /// # Examples 11 | /// 12 | /// ``` 13 | /// use wiretun::{DeviceConfig, PeerConfig, Cidr}; 14 | /// 15 | /// let cfg = DeviceConfig::default() 16 | /// .listen_port(40001) 17 | /// .private_key([0; 32]) 18 | /// .peer(PeerConfig::default().public_key([0; 32]).allowed_ip("10.0.0.0/24".parse::().unwrap())); 19 | /// ``` 20 | #[derive(Clone)] 21 | pub struct DeviceConfig { 22 | pub private_key: [u8; 32], 23 | pub listen_addrs: (Ipv4Addr, Ipv6Addr), 24 | pub listen_port: u16, 25 | pub fwmark: u32, 26 | pub peers: HashMap<[u8; 32], PeerConfig>, 27 | } 28 | 29 | /// Configuration for a peer. 30 | #[derive(Default, Clone)] 31 | pub struct PeerConfig { 32 | pub public_key: [u8; 32], 33 | pub allowed_ips: HashSet, 34 | pub endpoint: Option, 35 | pub preshared_key: Option<[u8; 32]>, 36 | pub persistent_keepalive: Option, 37 | } 38 | 39 | impl DeviceConfig { 40 | #[inline(always)] 41 | pub fn private_key(mut self, key: [u8; 32]) -> Self { 42 | self.private_key = key; 43 | self 44 | } 45 | 46 | #[inline(always)] 47 | pub fn listen_addr_v4(mut self, addr: Ipv4Addr) -> Self { 48 | self.listen_addrs.0 = addr; 49 | self 50 | } 51 | 52 | #[inline(always)] 53 | pub fn listen_addr_v6(mut self, addr: Ipv6Addr) -> Self { 54 | self.listen_addrs.1 = addr; 55 | self 56 | } 57 | 58 | #[inline(always)] 59 | pub fn listen_port(mut self, port: u16) -> Self { 60 | self.listen_port = port; 61 | self 62 | } 63 | 64 | /// Sets the `fwmark` value of this configuration (defaults to `0`). 65 | #[inline(always)] 66 | pub fn fwmark(mut self, mark: u32) -> Self { 67 | self.fwmark = mark; 68 | self 69 | } 70 | 71 | #[inline(always)] 72 | pub fn peer(mut self, peer: PeerConfig) -> Self { 73 | self.peers.insert(peer.public_key, peer); 74 | self 75 | } 76 | 77 | #[inline(always)] 78 | pub fn local_secret(&self) -> LocalStaticSecret { 79 | LocalStaticSecret::new(self.private_key) 80 | } 81 | } 82 | 83 | impl Default for DeviceConfig { 84 | fn default() -> Self { 85 | Self { 86 | private_key: [0; 32], 87 | listen_addrs: (Ipv4Addr::UNSPECIFIED, Ipv6Addr::UNSPECIFIED), 88 | listen_port: 0, 89 | fwmark: 0, 90 | peers: HashMap::new(), 91 | } 92 | } 93 | } 94 | 95 | impl PeerConfig { 96 |
#[inline(always)] 97 | pub fn public_key(mut self, key: [u8; 32]) -> Self { 98 | self.public_key = key; 99 | self 100 | } 101 | 102 | #[inline(always)] 103 | pub fn allowed_ips>(mut self, ips: impl IntoIterator) -> Self { 104 | self.allowed_ips = ips.into_iter().map(|i| i.into()).collect(); 105 | self 106 | } 107 | 108 | #[inline(always)] 109 | pub fn allowed_ip>(mut self, ip: I) -> Self { 110 | self.allowed_ips.insert(ip.into()); 111 | self 112 | } 113 | 114 | #[inline(always)] 115 | pub fn endpoint(mut self, endpoint: SocketAddr) -> Self { 116 | self.endpoint = Some(endpoint); 117 | self 118 | } 119 | 120 | #[inline(always)] 121 | pub fn preshared_key(mut self, key: [u8; 32]) -> Self { 122 | self.preshared_key = Some(key); 123 | self 124 | } 125 | 126 | #[inline(always)] 127 | pub fn persistent_keepalive(mut self, interval: Duration) -> Self { 128 | self.persistent_keepalive = Some(interval); 129 | self 130 | } 131 | } 132 | -------------------------------------------------------------------------------- /src/device/error.rs: -------------------------------------------------------------------------------- 1 | #[derive(thiserror::Error, Debug)] 2 | pub enum Error { 3 | #[error("IO Error: {0}")] 4 | IO(#[from] std::io::Error), 5 | #[error("Noise protocol error: {0}")] 6 | Noise(#[from] crate::noise::Error), 7 | #[error("Tun error: {0}")] 8 | Tun(#[from] crate::tun::Error), 9 | } 10 | -------------------------------------------------------------------------------- /src/device/handle.rs: -------------------------------------------------------------------------------- 1 | use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; 2 | use std::sync::Arc; 3 | use std::time::Duration; 4 | 5 | use futures::future::join_all; 6 | use tokio::task::JoinHandle; 7 | use tokio_util::sync::CancellationToken; 8 | use tracing::{debug, error, warn}; 9 | 10 | use super::inbound::{Endpoint, Transport}; 11 | use super::peer::InboundEvent; 12 | use super::DeviceInner; 13 | use
crate::noise::crypto::LocalStaticSecret; 14 | use crate::noise::handshake::{Cookie, IncomingInitiation}; 15 | use crate::noise::protocol; 16 | use crate::noise::protocol::Message; 17 | use crate::Tun; 18 | 19 | pub(super) struct DeviceHandle { 20 | token: CancellationToken, 21 | inbound_handles: (CancellationToken, Vec>), 22 | outbound_handles: (CancellationToken, Vec>), 23 | } 24 | 25 | impl DeviceHandle { 26 | pub async fn spawn(token: CancellationToken, inner: Arc>) -> Self 27 | where 28 | T: Tun + 'static, 29 | I: Transport, 30 | { 31 | let mut me = Self { 32 | token: token.clone(), 33 | inbound_handles: (token.child_token(), vec![]), 34 | outbound_handles: (token.child_token(), vec![]), 35 | }; 36 | me.restart_inbound(Arc::clone(&inner)).await; 37 | me.restart_outbound(Arc::clone(&inner)).await; 38 | me 39 | } 40 | 41 | pub async fn restart_inbound(&mut self, inner: Arc>) 42 | where 43 | T: Tun + 'static, 44 | I: Transport, 45 | { 46 | // Stop the previous tasks 47 | { 48 | let handles: Vec<_> = self.inbound_handles.1.drain(..).collect(); 49 | let abort_handles: Vec<_> = handles.iter().map(|h| h.abort_handle()).collect(); 50 | self.inbound_handles.0.cancel(); 51 | if let Err(e) = tokio::time::timeout(Duration::from_secs(5), join_all(handles)).await { 52 | warn!("stopping device inbound loop timeout: {e}"); 53 | for handle in abort_handles { 54 | handle.abort(); 55 | } 56 | } 57 | } 58 | 59 | let token = self.token.child_token(); 60 | let handles = vec![tokio::spawn(loop_inbound( 61 | Arc::clone(&inner), 62 | token.child_token(), 63 | ))]; 64 | self.inbound_handles = (token, handles); 65 | } 66 | 67 | pub async fn restart_outbound(&mut self, inner: Arc>) 68 | where 69 | T: Tun + 'static, 70 | I: Transport, 71 | { 72 | self.outbound_handles.0.cancel(); // Stop the previous tasks first: loop_outbound only returns once its token is cancelled. 73 | let handles: Vec<_> = self.outbound_handles.1.drain(..).collect(); 74 | join_all(handles).await; 75 | let token = self.token.child_token(); 76 | let handles = vec![tokio::spawn(loop_outbound( 77 | Arc::clone(&inner), 78 | token.child_token(),
79 | ))]; 80 | self.outbound_handles = (token, handles); 81 | } 82 | 83 | pub fn abort(&self) { 84 | self.inbound_handles.0.cancel(); 85 | self.outbound_handles.0.cancel(); 86 | } 87 | 88 | pub async fn stop(&mut self) { 89 | self.abort(); 90 | 91 | // Wait until all background tasks are done. 92 | let mut handles = vec![]; 93 | handles.extend(&mut self.inbound_handles.1.drain(..)); 94 | handles.extend(&mut self.outbound_handles.1.drain(..)); 95 | 96 | join_all(handles).await; 97 | } 98 | } 99 | 100 | impl Drop for DeviceHandle { 101 | fn drop(&mut self) { 102 | self.token.cancel(); 103 | } 104 | } 105 | 106 | async fn loop_inbound(inner: Arc>, token: CancellationToken) 107 | where 108 | T: Tun + 'static, 109 | I: Transport, 110 | { 111 | let mut transport = inner.settings.lock().unwrap().inbound.transport(); 112 | debug!("Device Inbound loop for {transport} is UP"); 113 | let (secret, cookie) = { 114 | let settings = inner.settings.lock().unwrap(); 115 | (settings.secret.clone(), Arc::clone(&settings.cookie)) 116 | }; 117 | 118 | loop { 119 | tokio::select!
{ 120 | _ = token.cancelled() => { 121 | debug!("Device Inbound loop for {transport} is DOWN"); 122 | return; 123 | } 124 | data = transport.recv_from() => { 125 | if let Ok((endpoint, payload)) = data { 126 | tick_inbound(Arc::clone(&inner), &secret, Arc::clone(&cookie), endpoint, payload).await; 127 | } 128 | } 129 | } 130 | } 131 | } 132 | 133 | async fn tick_inbound( 134 | inner: Arc>, 135 | secret: &LocalStaticSecret, 136 | cookie: Arc, 137 | endpoint: Endpoint, 138 | payload: Vec, 139 | ) where 140 | T: Tun + 'static, 141 | I: Transport, 142 | { 143 | if Message::is_handshake(&payload) { 144 | if !cookie.validate_mac1(&payload) { 145 | debug!("invalid mac1"); 146 | return; 147 | } 148 | 149 | if !inner.rate_limiter.fetch_token() { 150 | debug!("rate limited"); 151 | if !cookie.validate_mac2(&payload) { 152 | debug!("invalid mac2"); 153 | debug!("try to send cookie reply"); 154 | let reply = cookie.generate_cookie_reply(&payload, endpoint.dst()); 155 | if let Err(e) = endpoint.send(&reply).await { warn!("failed to send cookie reply to {}: {e}", endpoint.dst()); } 156 | return; 157 | } 158 | // Under load, only a handshake carrying a valid mac2 (i.e. a fresh cookie) falls through and is processed below. 159 | } 160 | } 161 | 162 | match Message::parse(&payload) { 163 | Ok(Message::HandshakeInitiation(p)) => { 164 | let initiation = match IncomingInitiation::parse(secret, &p) { Ok(initiation) => initiation, Err(_) => { debug!("invalid handshake initiation from {}", endpoint.dst()); return; } }; 165 | if let Some(peer) = inner.get_peer_by_key(initiation.static_public_key.as_bytes()) { 166 | peer.handle_inbound(InboundEvent::HandshakeInitiation { 167 | endpoint, 168 | initiation, 169 | }) 170 | .await; 171 | } 172 | } 173 | Ok(msg) => { 174 | let receiver_index = match &msg { 175 | Message::HandshakeResponse(p) => p.receiver_index, 176 | Message::CookieReply(p) => p.receiver_index, 177 | Message::TransportData(p) => p.receiver_index, 178 | _ => unreachable!(), 179 | }; 180 | if let Some((session, peer)) = inner.get_session_by_index(receiver_index) { 181 | match msg { 182 | Message::HandshakeResponse(packet) => { 183 | peer.handle_inbound(InboundEvent::HandshakeResponse { 184 | endpoint, 185 | packet, 186 | session, 187 | }) 188 | .await;
189 | } 190 | Message::CookieReply(packet) => { 191 | peer.handle_inbound(InboundEvent::CookieReply { 192 | endpoint, 193 | packet, 194 | session, 195 | }) 196 | .await; 197 | } 198 | Message::TransportData(packet) => { 199 | if packet.counter > protocol::REJECT_AFTER_MESSAGES { 200 | warn!("received too many messages from peer [index={receiver_index}]"); 201 | return; 202 | } 203 | 204 | peer.handle_inbound(InboundEvent::TransportData { 205 | endpoint, 206 | packet, 207 | session, 208 | }) 209 | .await; 210 | } 211 | _ => unreachable!(), 212 | } 213 | } else { 214 | warn!("received message from unknown peer [index={receiver_index}]"); 215 | } 216 | } 217 | Err(e) => { 218 | warn!("failed to parse message type: {:?}", e); 219 | } 220 | } 221 | } 222 | 223 | async fn loop_outbound(inner: Arc>, token: CancellationToken) 224 | where 225 | T: Tun + 'static, 226 | I: Transport, 227 | { 228 | debug!("Device outbound loop is UP"); 229 | loop { 230 | tokio::select! { 231 | _ = token.cancelled() => { 232 | debug!("Device outbound loop is DOWN"); 233 | return; 234 | } 235 | _ = tick_outbound(Arc::clone(&inner)) => {} 236 | } 237 | } 238 | } 239 | 240 | async fn tick_outbound(inner: Arc>) 241 | where 242 | T: Tun + 'static, 243 | I: Transport, 244 | { 245 | const IPV4_HEADER_LEN: usize = 20; 246 | const IPV6_HEADER_LEN: usize = 40; 247 | 248 | match inner.tun.recv().await { 249 | Ok(buf) => { 250 | let dst = { 251 | match buf.first().map_or(0, |b| b & 0xF0) { // an empty read maps to 0 and hits the `n =>` arm instead of panicking 252 | 0x40 if buf.len() < IPV4_HEADER_LEN => return, 253 | 0x40 => { 254 | let addr: [u8; 4] = buf[16..20].try_into().unwrap(); 255 | IpAddr::from(Ipv4Addr::from(addr)) 256 | } 257 | 0x60 if buf.len() < IPV6_HEADER_LEN => return, 258 | 0x60 => { 259 | let addr: [u8; 16] = buf[24..40].try_into().unwrap(); 260 | IpAddr::from(Ipv6Addr::from(addr)) 261 | } 262 | n => { 263 | debug!("unknown IP version: {}", n); 264 | return; 265 | } 266 | } 267 | }; 268 | 269 | debug!("trying to send packet to {}", dst); 270 | 271 | let peer =
inner.peers.lock().unwrap().get_by_ip(dst); 272 | 273 | if let Some(peer) = peer { 274 | debug!("sending packet[{}] to {dst}", buf.len()); 275 | peer.stage_outbound(buf).await 276 | } else { 277 | warn!("no peer found for {dst}"); 278 | } 279 | } 280 | Err(e) => { 281 | error!("TUN read error: {}", e) 282 | } 283 | } 284 | } 285 | -------------------------------------------------------------------------------- /src/device/inbound.rs: -------------------------------------------------------------------------------- 1 | use std::fmt::{Debug, Display, Formatter}; 2 | use std::io; 3 | use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6}; 4 | use std::sync::Arc; 5 | 6 | use async_trait::async_trait; 7 | use socket2::{Domain, Protocol, Type}; 8 | use tokio::net::UdpSocket; 9 | use tracing::{error, info}; 10 | 11 | /// Transport is a trait that represents a network transport. 12 | #[async_trait] 13 | pub trait Transport: Clone + Sync + Send + Unpin + Display + 'static { 14 | /// Binds to the given port and returns a new transport. 15 | /// When the port is 0, the implementation should choose a random port. 16 | async fn bind(ipv4: Ipv4Addr, ipv6: Ipv6Addr, port: u16) -> Result; 17 | 18 | fn ipv4(&self) -> Ipv4Addr; 19 | 20 | fn ipv6(&self) -> Ipv6Addr; 21 | 22 | /// Returns the port that the transport is bound to. 23 | fn port(&self) -> u16; 24 | 25 | /// Sends data to the given endpoint. 26 | async fn send_to(&self, data: &[u8], endpoint: &Endpoint) -> Result<(), io::Error>; 27 | 28 | /// Receives data from the transport.
29 | async fn recv_from(&mut self) -> Result<(Endpoint, Vec), io::Error>; 30 | } 31 | 32 | pub(super) struct Inbound 33 | where 34 | I: Transport, 35 | { 36 | transport: I, 37 | } 38 | 39 | impl Inbound 40 | where 41 | I: Transport, 42 | { 43 | #[inline(always)] 44 | pub fn new(transport: I) -> Self { 45 | Self { transport } 46 | } 47 | 48 | #[inline(always)] 49 | pub fn ipv4(&self) -> Ipv4Addr { 50 | self.transport.ipv4() 51 | } 52 | 53 | #[inline(always)] 54 | pub fn ipv6(&self) -> Ipv6Addr { 55 | self.transport.ipv6() 56 | } 57 | 58 | #[inline(always)] 59 | pub fn port(&self) -> u16 { 60 | self.transport.port() 61 | } 62 | 63 | #[inline(always)] 64 | pub fn transport(&self) -> I { 65 | self.transport.clone() 66 | } 67 | 68 | #[inline(always)] 69 | pub fn endpoint_for(&self, dst: SocketAddr) -> Endpoint { 70 | Endpoint::new(self.transport(), dst) 71 | } 72 | } 73 | 74 | #[derive(Clone)] 75 | pub struct Endpoint { 76 | transport: I, 77 | dst: SocketAddr, 78 | } 79 | 80 | impl Endpoint 81 | where 82 | I: Transport, 83 | { 84 | /// Creates a new endpoint with the given transport and destination. 85 | pub fn new(transport: I, dst: SocketAddr) -> Self { 86 | Self { transport, dst } 87 | } 88 | 89 | /// Sends data to the endpoint. 90 | #[inline] 91 | pub async fn send(&self, buf: &[u8]) -> Result<(), io::Error> { 92 | self.transport.send_to(buf, self).await 93 | } 94 | 95 | /// Returns the destination of the endpoint. 96 | #[inline(always)] 97 | pub fn dst(&self) -> SocketAddr { 98 | self.dst 99 | } 100 | } 101 | 102 | impl Debug for Endpoint { 103 | fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { 104 | f.debug_struct("Endpoint") 105 | .field("dst", &self.dst.to_string()) 106 | .finish() 107 | } 108 | } 109 | 110 | /// UdpTransport is a UDP transport that implements the [`Transport`] trait.
111 | #[derive(Clone)] 112 | pub struct UdpTransport { 113 | port: u16, 114 | ipv4: Arc, 115 | ipv6: Arc, 116 | ipv4_buf: Vec, 117 | ipv6_buf: Vec, 118 | } 119 | 120 | impl UdpTransport { 121 | async fn bind_socket( 122 | ipv4: Ipv4Addr, 123 | ipv6: Ipv6Addr, 124 | port: u16, 125 | ) -> Result<(Arc, Arc, u16), io::Error> { 126 | let max_retry = if port == 0 { 10 } else { 1 }; 127 | let mut err = None; 128 | for _ in 0..max_retry { 129 | let ipv4 = match Self::bind_socket_v4(SocketAddrV4::new(ipv4, port)).await { 130 | Ok(s) => s, 131 | Err(e) => { 132 | err = Some(e); 133 | continue; 134 | } 135 | }; 136 | let port = ipv4.local_addr()?.port(); 137 | let ipv6 = match Self::bind_socket_v6(SocketAddrV6::new(ipv6, port, 0, 0)).await { 138 | Ok(s) => s, 139 | Err(e) => { 140 | err = Some(e); 141 | continue; 142 | } 143 | }; 144 | 145 | return Ok((Arc::new(ipv4), Arc::new(ipv6), port)); 146 | } 147 | let e = err.unwrap(); 148 | error!("Inbound is not able to bind port {port}: {e}"); 149 | Err(e) 150 | } 151 | 152 | async fn bind_socket_v4(addr: SocketAddrV4) -> Result { 153 | let socket = socket2::Socket::new(Domain::IPV4, Type::DGRAM, Some(Protocol::UDP))?; 154 | socket.set_nonblocking(true)?; 155 | socket.set_reuse_address(true)?; 156 | socket.bind(&addr.into())?; 157 | UdpSocket::from_std(std::net::UdpSocket::from(socket)) 158 | } 159 | 160 | async fn bind_socket_v6(addr: SocketAddrV6) -> Result { 161 | let socket = socket2::Socket::new(Domain::IPV6, Type::DGRAM, Some(Protocol::UDP))?; 162 | socket.set_only_v6(true)?; 163 | socket.set_nonblocking(true)?; 164 | socket.set_reuse_address(true)?; 165 | socket.bind(&addr.into())?; 166 | UdpSocket::from_std(std::net::UdpSocket::from(socket)) 167 | } 168 | } 169 | 170 | impl Display for UdpTransport { 171 | fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { 172 | write!( 173 | f, 174 | "UdpTransport[{}/{}]", 175 | self.ipv4.local_addr().unwrap(), 176 | self.ipv6.local_addr().unwrap() 177 | ) 178 | } 179 | } 180 | 181 
| #[async_trait] 182 | impl Transport for UdpTransport { 183 | fn port(&self) -> u16 { 184 | self.port 185 | } 186 | 187 | async fn bind(ipv4: Ipv4Addr, ipv6: Ipv6Addr, port: u16) -> Result { 188 | let (ipv4, ipv6, port) = Self::bind_socket(ipv4, ipv6, port).await?; 189 | info!( 190 | "Listening on {} / {}", 191 | ipv4.local_addr()?, 192 | ipv6.local_addr()? 193 | ); 194 | Ok(Self { 195 | port, 196 | ipv4, 197 | ipv6, 198 | ipv4_buf: vec![], 199 | ipv6_buf: vec![], 200 | }) 201 | } 202 | 203 | async fn send_to(&self, data: &[u8], endpoint: &Endpoint) -> Result<(), io::Error> { 204 | match endpoint.dst { 205 | SocketAddr::V4(_) => self.ipv4.send_to(data, endpoint.dst).await?, 206 | SocketAddr::V6(_) => self.ipv6.send_to(data, endpoint.dst).await?, 207 | }; 208 | Ok(()) 209 | } 210 | 211 | async fn recv_from(&mut self) -> Result<(Endpoint, Vec), io::Error> { 212 | if self.ipv4_buf.is_empty() { 213 | self.ipv4_buf = vec![0u8; 2048]; 214 | } 215 | if self.ipv6_buf.is_empty() { 216 | self.ipv4_buf = vec![0u8; 2048]; 217 | } 218 | 219 | let (data, addr) = tokio::select! 
{ 220 | ret = self.ipv4.recv_from(&mut self.ipv4_buf) => { 221 | let (n, addr) = ret?; 222 | (self.ipv4_buf[..n].to_vec(), addr) 223 | }, 224 | ret = self.ipv6.recv_from(&mut self.ipv6_buf) => { 225 | let (n, addr) = ret?; 226 | (self.ipv6_buf[..n].to_vec(), addr) 227 | }, 228 | }; 229 | 230 | Ok((Endpoint::new(self.clone(), addr), data)) 231 | } 232 | 233 | fn ipv4(&self) -> Ipv4Addr { 234 | if let SocketAddr::V4(addr) = self.ipv4.local_addr().unwrap() { 235 | *addr.ip() 236 | } else { 237 | unreachable!() 238 | } 239 | } 240 | 241 | fn ipv6(&self) -> Ipv6Addr { 242 | if let SocketAddr::V6(addr) = self.ipv6.local_addr().unwrap() { 243 | *addr.ip() 244 | } else { 245 | unreachable!() 246 | } 247 | } 248 | } 249 | -------------------------------------------------------------------------------- /src/device/metrics.rs: -------------------------------------------------------------------------------- 1 | use std::collections::HashMap; 2 | 3 | use super::peer::PeerMetrics; 4 | 5 | pub struct DeviceMetrics { 6 | pub peers: HashMap<[u8; 32], PeerMetrics>, // index by public key 7 | } 8 | -------------------------------------------------------------------------------- /src/device/peer/cidr.rs: -------------------------------------------------------------------------------- 1 | use std::fmt::{Display, Formatter}; 2 | use std::hash::{Hash, Hasher}; 3 | use std::net::IpAddr; 4 | use std::str::FromStr; 5 | 6 | use ip_network::IpNetwork; 7 | use ip_network_table::IpNetworkTable; 8 | 9 | const fn max_mask_for_ip(ip: &IpAddr) -> u8 { 10 | match ip { 11 | IpAddr::V4(_) => 32, 12 | IpAddr::V6(_) => 128, 13 | } 14 | } 15 | 16 | /// Cidr represents a CIDR block. 
17 | /// 18 | /// # Examples 19 | /// 20 | /// ``` 21 | /// use wiretun::Cidr; 22 | /// 23 | /// let cidr = "10.10.0.0/24".parse::().unwrap(); 24 | /// assert_eq!(cidr.to_string(), "10.10.0.0/24"); 25 | /// 26 | /// let cidr = "2001:db8::/32".parse::().unwrap(); 27 | /// assert_eq!(cidr.to_string(), "2001:db8::/32"); 28 | /// 29 | /// let cidr = "10.10.10.0/16".parse::().unwrap(); 30 | /// assert_eq!(cidr.to_string(), "10.10.0.0/16"); // truncated 31 | /// ``` 32 | #[derive(Clone, Copy, Debug)] 33 | pub struct Cidr(IpNetwork); 34 | 35 | impl Cidr { 36 | /// # Panics 37 | /// Panics if the mask is invalid for the given IP address. 38 | pub fn new(ip: IpAddr, mask: u8) -> Self { 39 | Self(IpNetwork::new_truncate(ip, mask).unwrap()) 40 | } 41 | } 42 | 43 | impl Display for Cidr { 44 | fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { 45 | write!(f, "{}/{}", self.0.network_address(), self.0.netmask()) 46 | } 47 | } 48 | 49 | impl From for Cidr { 50 | fn from(value: IpAddr) -> Self { 51 | let mask = max_mask_for_ip(&value); 52 | Self::new(value, mask) 53 | } 54 | } 55 | 56 | impl PartialEq for Cidr { 57 | fn eq(&self, other: &Self) -> bool { 58 | self.0 == other.0 59 | } 60 | } 61 | 62 | impl Hash for Cidr { 63 | fn hash(&self, state: &mut H) { 64 | self.0.hash(state); 65 | } 66 | } 67 | 68 | impl Eq for Cidr {} 69 | 70 | impl FromStr for Cidr { 71 | type Err = ParseCidrError; 72 | 73 | fn from_str(s: &str) -> Result { 74 | if let Some((ip, mask)) = s.split_once('/') { 75 | let ip = IpAddr::from_str(ip).map_err(|_| ParseCidrError::InvalidIp)?; 76 | let mask = u8::from_str(mask).map_err(|_| ParseCidrError::InvalidMask)?; 77 | if mask > max_mask_for_ip(&ip) { 78 | return Err(ParseCidrError::InvalidMask); 79 | } 80 | 81 | Ok(Self::new(ip, mask)) 82 | } else { 83 | let ip = IpAddr::from_str(s).map_err(|_| ParseCidrError::InvalidIp)?; 84 | Ok(Self::from(ip)) 85 | } 86 | } 87 | } 88 | 89 | #[derive(thiserror::Error, Debug, PartialEq)] 90 | pub enum ParseCidrError { 
91 | #[error("invalid ip address")] 92 | InvalidIp, 93 | #[error("invalid mask")] 94 | InvalidMask, 95 | } 96 | 97 | pub(super) struct CidrTable { 98 | table: IpNetworkTable, 99 | } 100 | 101 | impl CidrTable { 102 | pub fn new() -> Self { 103 | Self { 104 | table: IpNetworkTable::new(), 105 | } 106 | } 107 | 108 | pub fn insert(&mut self, cidr: Cidr, value: T) { 109 | self.table.insert(cidr.0, value); 110 | } 111 | 112 | pub fn get_by_ip(&self, ip: IpAddr) -> Option<&T> { 113 | self.table.longest_match(ip).map(|(_, v)| v) 114 | } 115 | 116 | pub fn remove(&mut self, cidr: &Cidr) { 117 | self.table.remove(cidr.0); 118 | } 119 | 120 | pub fn clear(&mut self) { 121 | self.table = IpNetworkTable::new(); 122 | } 123 | } 124 | 125 | #[cfg(test)] 126 | mod tests { 127 | use super::*; 128 | 129 | #[test] 130 | fn test_parse_str_for_cidr() { 131 | let valid_cases = [ 132 | ("10.2.3.4", "10.2.3.4/32"), 133 | ("10.2.3.4/32", "10.2.3.4/32"), 134 | ("10.2.3.4/16", "10.2.0.0/16"), 135 | ("10.2.3.4/24", "10.2.3.0/24"), 136 | ]; 137 | 138 | for (input, expected) in valid_cases { 139 | let cidr = Cidr::from_str(input); 140 | assert!(cidr.is_ok()); 141 | let cidr = cidr.unwrap(); 142 | assert_eq!(cidr.to_string(), expected); 143 | } 144 | 145 | let invalid_cases = [ 146 | ("10.2.3.4.", ParseCidrError::InvalidIp), 147 | ("10.2.3.256", ParseCidrError::InvalidIp), 148 | ("10.0.0.1/33", ParseCidrError::InvalidMask), 149 | ("10.0.0.1/32/", ParseCidrError::InvalidMask), 150 | ]; 151 | 152 | for (input, expected) in invalid_cases { 153 | let cidr = Cidr::from_str(input); 154 | assert!(cidr.is_err()); 155 | assert_eq!(cidr.unwrap_err(), expected); 156 | } 157 | } 158 | 159 | #[test] 160 | fn test_cidr_table_get_by_id() { 161 | let mut table = CidrTable::new(); 162 | table.insert("10.2.3.4/16".parse().unwrap(), 1); 163 | assert_eq!(table.get_by_ip("10.2.0.0".parse().unwrap()), Some(&1)); 164 | assert_eq!(table.get_by_ip("10.2.1.0".parse().unwrap()), Some(&1)); 165 | 
assert_eq!(table.get_by_ip("10.2.255.0".parse().unwrap()), Some(&1)); 166 | 167 | assert_eq!(table.get_by_ip("10.3.0.0".parse().unwrap()), None); 168 | assert_eq!(table.get_by_ip("10.1.0.0".parse().unwrap()), None); 169 | table.insert("10.3.0.0/16".parse().unwrap(), 2); 170 | assert_eq!(table.get_by_ip("10.3.0.0".parse().unwrap()), Some(&2)); 171 | assert_eq!(table.get_by_ip("10.1.0.0".parse().unwrap()), None); 172 | 173 | assert_eq!(table.get_by_ip("10.2.0.0".parse().unwrap()), Some(&1)); 174 | assert_eq!(table.get_by_ip("10.2.1.0".parse().unwrap()), Some(&1)); 175 | assert_eq!(table.get_by_ip("10.2.255.0".parse().unwrap()), Some(&1)); 176 | } 177 | } 178 | -------------------------------------------------------------------------------- /src/device/peer/handle.rs: -------------------------------------------------------------------------------- 1 | use std::sync::Arc; 2 | use std::time::Duration; 3 | 4 | use futures::future::join_all; 5 | use tokio::task::JoinHandle; 6 | use tokio::time; 7 | use tokio_util::sync::CancellationToken; 8 | use tracing::{debug, info, warn}; 9 | 10 | use super::{InboundEvent, InboundRx, OutboundEvent, OutboundRx, Peer, Session}; 11 | use crate::device::{Endpoint, Transport}; 12 | use crate::noise::handshake::IncomingInitiation; 13 | use crate::noise::protocol::{ 14 | self, CookieReply, HandshakeResponse, TransportData, COOKIE_REPLY_PACKET_SIZE, 15 | HANDSHAKE_RESPONSE_PACKET_SIZE, 16 | }; 17 | use crate::Tun; 18 | 19 | pub(crate) struct PeerHandle { 20 | token: CancellationToken, 21 | handles: Vec>, 22 | } 23 | 24 | impl PeerHandle { 25 | pub fn spawn( 26 | token: CancellationToken, 27 | peer: Arc>, 28 | inbound: InboundRx, 29 | outbound: OutboundRx, 30 | ) -> Self 31 | where 32 | T: Tun + 'static, 33 | I: Transport, 34 | { 35 | let handshake_loop = tokio::spawn(loop_handshake(token.child_token(), Arc::clone(&peer))); 36 | let inbound_loop = tokio::spawn(loop_inbound( 37 | token.child_token(), 38 | Arc::clone(&peer), 39 | inbound, 40 | 
)); 41 | let outbound_loop = tokio::spawn(loop_outbound( 42 | token.child_token(), 43 | Arc::clone(&peer), 44 | outbound, 45 | )); 46 | 47 | Self { 48 | token, 49 | handles: vec![handshake_loop, inbound_loop, outbound_loop], 50 | } 51 | } 52 | 53 | /// Cancel the background tasks and wait until they are terminated. 54 | /// If the timeout is reached, the tasks are terminated immediately. 55 | pub async fn cancel(mut self, timeout: Duration) { 56 | self.token.cancel(); 57 | let handles = self.handles.drain(..).collect::>(); 58 | let abort_handles = handles.iter().map(|h| h.abort_handle()).collect::>(); 59 | if let Err(e) = tokio::time::timeout(timeout, join_all(handles)).await { 60 | warn!( 61 | "failed to cancel peer tasks in {}ms: {}", 62 | timeout.as_millis(), 63 | e 64 | ); 65 | for handle in abort_handles { 66 | handle.abort(); 67 | } 68 | } 69 | } 70 | } 71 | 72 | impl Drop for PeerHandle { 73 | fn drop(&mut self) { 74 | self.token.cancel(); 75 | } 76 | } 77 | 78 | async fn loop_handshake(token: CancellationToken, peer: Arc>) 79 | where 80 | T: Tun + 'static, 81 | I: Transport, 82 | { 83 | debug!("Handshake loop for {peer} is UP"); 84 | while !token.is_cancelled() { 85 | if peer.monitor.can_handshake() { 86 | info!("initiating handshake"); 87 | let packet = { 88 | let (next, packet) = peer.handshake.write().unwrap().initiate(); 89 | let mut sessions = peer.sessions.write().unwrap(); 90 | sessions.prepare_uninit(next); 91 | packet 92 | }; 93 | 94 | peer.send_outbound(&packet).await; // send directly 95 | peer.monitor.handshake().initiated(); 96 | } 97 | time::sleep_until(peer.monitor.handshake().will_initiate_in().into()).await; 98 | } 99 | debug!("Handshake loop for {peer} is DOWN"); 100 | } 101 | 102 | // Send to endpoint if connected, otherwise queue for later 103 | async fn loop_outbound(token: CancellationToken, peer: Arc>, mut rx: OutboundRx) 104 | where 105 | T: Tun + 'static, 106 | I: Transport, 107 | { 108 | debug!("Outbound loop for {peer} is UP"); 
109 | 110 | loop { 111 | tokio::select! { 112 | _ = token.cancelled() => break, 113 | _ = time::sleep_until(peer.monitor.keepalive().next_attempt_in(peer.monitor.traffic()).into()) => { 114 | peer.keepalive().await; 115 | } 116 | event = rx.recv() => { 117 | match event { 118 | Some(OutboundEvent::Data(data)) => { 119 | tick_outbound(Arc::clone(&peer), data).await; 120 | } 121 | None => break, 122 | } 123 | } 124 | } 125 | } 126 | 127 | debug!("Outbound loop for {peer} is DOWN"); 128 | } 129 | 130 | async fn tick_outbound(peer: Arc>, data: Vec) 131 | where 132 | T: Tun + 'static, 133 | I: Transport, 134 | { 135 | let session = { peer.sessions.read().unwrap().current().clone() }; 136 | let session = if let Some(s) = session { s } else { return }; 137 | 138 | match session.encrypt_data(&data) { 139 | Ok(packet) => { 140 | let buf = packet.to_bytes(); 141 | peer.send_outbound(&buf).await; 142 | peer.monitor.traffic().outbound(buf.len()); 143 | } 144 | Err(e) => { 145 | warn!("failed to encrypt packet: {}", e); 146 | } 147 | } 148 | } 149 | 150 | // Send to tun if we have a valid session 151 | async fn loop_inbound(token: CancellationToken, peer: Arc>, mut rx: InboundRx) 152 | where 153 | T: Tun + 'static, 154 | I: Transport, 155 | { 156 | debug!("Inbound loop for {peer} is UP"); 157 | 158 | loop { 159 | tokio::select! 
{ 160 | () = token.cancelled() => break, 161 | event = rx.recv() => { 162 | match event { 163 | Some(event) => tick_inbound(Arc::clone(&peer), event).await, 164 | None => break, 165 | } 166 | } 167 | } 168 | } 169 | 170 | debug!("Inbound loop for {peer} is DOWN"); 171 | } 172 | 173 | async fn tick_inbound(peer: Arc>, event: InboundEvent) 174 | where 175 | T: Tun + 'static, 176 | I: Transport, 177 | { 178 | match event { 179 | InboundEvent::HandshakeInitiation { 180 | endpoint, 181 | initiation, 182 | } => inbound::handle_handshake_initiation(Arc::clone(&peer), endpoint, initiation).await, 183 | InboundEvent::HandshakeResponse { 184 | endpoint, 185 | packet, 186 | session, 187 | } => inbound::handle_handshake_response(Arc::clone(&peer), endpoint, packet, session).await, 188 | InboundEvent::CookieReply { 189 | endpoint, 190 | packet, 191 | session, 192 | } => inbound::handle_cookie_reply(Arc::clone(&peer), endpoint, packet, session).await, 193 | InboundEvent::TransportData { 194 | endpoint, 195 | packet, 196 | session, 197 | } => inbound::handle_transport_data(Arc::clone(&peer), endpoint, packet, session).await, 198 | } 199 | } 200 | 201 | mod inbound { 202 | use super::*; 203 | use tracing::error; 204 | 205 | pub(super) async fn handle_handshake_initiation( 206 | peer: Arc>, 207 | endpoint: Endpoint, 208 | initiation: IncomingInitiation, 209 | ) where 210 | T: Tun + 'static, 211 | I: Transport, 212 | { 213 | peer.monitor 214 | .traffic() 215 | .inbound(protocol::HANDSHAKE_INITIATION_PACKET_SIZE); 216 | let ret = { 217 | let mut handshake = peer.handshake.write().unwrap(); 218 | handshake.respond(&initiation) 219 | }; 220 | match ret { 221 | Ok((session, packet)) => { 222 | { 223 | let mut sessions = peer.sessions.write().unwrap(); 224 | sessions.prepare_next(session); 225 | } 226 | peer.update_endpoint(endpoint.clone()); 227 | endpoint.send(&packet).await.unwrap(); 228 | peer.monitor.handshake().initiated(); 229 | } 230 | Err(e) => debug!("failed to respond to 
handshake initiation: {e}"), 231 | } 232 | } 233 | 234 | pub(super) async fn handle_handshake_response( 235 | peer: Arc>, 236 | endpoint: Endpoint, 237 | packet: HandshakeResponse, 238 | _session: Session, 239 | ) where 240 | T: Tun + 'static, 241 | I: Transport, 242 | { 243 | peer.monitor 244 | .traffic() 245 | .inbound(HANDSHAKE_RESPONSE_PACKET_SIZE); 246 | let ret = { 247 | let mut handshake = peer.handshake.write().unwrap(); 248 | handshake.finalize(&packet) 249 | }; 250 | match ret { 251 | Ok(session) => { 252 | let ret = { 253 | let mut sessions = peer.sessions.write().unwrap(); 254 | sessions.complete_uninit(session) 255 | }; 256 | if !ret { 257 | debug!("failed to complete handshake, session not found"); 258 | return; 259 | } 260 | 261 | peer.monitor.handshake().completed(); 262 | info!("handshake completed"); 263 | peer.update_endpoint(endpoint); 264 | peer.stage_outbound(vec![]).await; // let the peer know the session is valid 265 | } 266 | Err(e) => debug!("failed to finalize handshake: {e}"), 267 | } 268 | } 269 | 270 | pub(super) async fn handle_cookie_reply( 271 | peer: Arc>, 272 | _endpoint: Endpoint, 273 | _packet: CookieReply, 274 | _session: Session, 275 | ) where 276 | T: Tun + 'static, 277 | I: Transport, 278 | { 279 | peer.monitor.traffic().inbound(COOKIE_REPLY_PACKET_SIZE); 280 | } 281 | 282 | pub(super) async fn handle_transport_data( 283 | peer: Arc>, 284 | endpoint: Endpoint, 285 | packet: TransportData, 286 | session: Session, 287 | ) where 288 | T: Tun + 'static, 289 | I: Transport, 290 | { 291 | peer.monitor.traffic().inbound(packet.packet_len()); 292 | { 293 | let mut sessions = peer.sessions.write().unwrap(); 294 | if sessions.complete_next(session.clone()) { 295 | info!("handshake completed"); 296 | peer.monitor.handshake().completed(); 297 | } 298 | } 299 | if !session.can_accept(packet.counter) { 300 | debug!("dropping packet due to replay"); 301 | return; 302 | } 303 | 304 | peer.update_endpoint(endpoint); 305 | match 
session.decrypt_data(&packet) { 306 | Ok(data) => { 307 | if data.is_empty() { 308 | // keepalive 309 | return; 310 | } 311 | 312 | debug!("recv data from peer and try to send it to TUN"); 313 | if let Err(e) = peer.tun.send(&data).await { 314 | error!("{peer} failed to send data to tun: {e}"); 315 | } 316 | session.aceept(packet.counter); 317 | } 318 | Err(e) => debug!("failed to decrypt packet: {e}"), 319 | } 320 | } 321 | } 322 | -------------------------------------------------------------------------------- /src/device/peer/handshake.rs: -------------------------------------------------------------------------------- 1 | use super::session::{Session, SessionIndex}; 2 | use crate::noise::protocol::HandshakeResponse; 3 | use crate::noise::{ 4 | crypto::{kdf2, PeerStaticSecret}, 5 | handshake::{ 6 | IncomingInitiation, IncomingResponse, MacGenerator, OutgoingInitiation, OutgoingResponse, 7 | }, 8 | Error, 9 | }; 10 | 11 | enum State { 12 | Uninit, 13 | Initiation(OutgoingInitiation), 14 | } 15 | 16 | pub(super) struct Handshake { 17 | state: State, 18 | secret: PeerStaticSecret, 19 | macs: MacGenerator, 20 | session_index: SessionIndex, 21 | } 22 | 23 | impl Handshake { 24 | pub fn new(secret: PeerStaticSecret, session_index: SessionIndex) -> Self { 25 | let cookie = MacGenerator::new(&secret); 26 | Self { 27 | secret, 28 | session_index, 29 | macs: cookie, 30 | state: State::Uninit, 31 | } 32 | } 33 | 34 | // Prepare HandshakeInitiation packet. 35 | pub fn initiate(&mut self) -> (Session, Vec) { 36 | let sender_index = self.session_index.next_index(); 37 | let (state, payload) = OutgoingInitiation::new(sender_index, &self.secret, &mut self.macs); 38 | let pre = Session::new(self.secret.clone(), sender_index, [0u8; 32], 0, [0u8; 32]); 39 | self.state = State::Initiation(state); 40 | 41 | (pre, payload) 42 | } 43 | 44 | // Receive HandshakeInitiation packet from peer. 
45 | pub fn respond( 46 | &mut self, 47 | initiation: &IncomingInitiation, 48 | ) -> Result<(Session, Vec), Error> { 49 | let local_index = self.session_index.next_index(); 50 | let (state, payload) = 51 | OutgoingResponse::new(initiation, local_index, &self.secret, &mut self.macs); 52 | let (sender_index, receiver_index) = (local_index, initiation.index); 53 | let (receiver_key, sender_key) = kdf2(&state.chaining_key, &[]); 54 | let sess = Session::new( 55 | self.secret.clone(), 56 | sender_index, 57 | sender_key, 58 | receiver_index, 59 | receiver_key, 60 | ); 61 | 62 | Ok((sess, payload)) 63 | } 64 | 65 | pub fn finalize(&mut self, packet: &HandshakeResponse) -> Result { 66 | match &self.state { 67 | State::Initiation(initiation) => { 68 | let state = IncomingResponse::parse(initiation, &self.secret, packet)?; 69 | let (sender_index, receiver_index) = (initiation.index, state.index); 70 | let (sender_key, receiver_key) = kdf2(&state.chaining_key, &[]); 71 | let sess = Session::new( 72 | self.secret.clone(), 73 | sender_index, 74 | sender_key, 75 | receiver_index, 76 | receiver_key, 77 | ); 78 | 79 | Ok(sess) 80 | } 81 | _ => Err(Error::InvalidHandshakeState), 82 | } 83 | } 84 | } 85 | -------------------------------------------------------------------------------- /src/device/peer/index.rs: -------------------------------------------------------------------------------- 1 | use std::collections::{HashMap, HashSet}; 2 | use std::net::IpAddr; 3 | use std::sync::Arc; 4 | use std::time::Duration; 5 | 6 | use tokio::sync::mpsc; 7 | use tokio_util::sync::CancellationToken; 8 | 9 | use super::cidr::{Cidr, CidrTable}; 10 | use super::session::{Session, SessionIndex}; 11 | use super::{Peer, PeerHandle, PeerMetrics}; 12 | use crate::device::inbound::{Endpoint, Transport}; 13 | use crate::noise::crypto::PeerStaticSecret; 14 | use crate::{PeerConfig, Tun}; 15 | 16 | struct PeerEntry 17 | where 18 | T: Tun + 'static, 19 | I: Transport, 20 | { 21 | peer: Arc>, 22 | 
allowed_ips: HashSet, 23 | #[allow(unused)] 24 | handle: PeerHandle, 25 | } 26 | 27 | pub(crate) struct PeerIndex 28 | where 29 | T: Tun + 'static, 30 | I: Transport, 31 | { 32 | token: CancellationToken, 33 | tun: T, 34 | sessions: SessionIndex, 35 | peers: HashMap<[u8; 32], PeerEntry>, 36 | ips: CidrTable>>, 37 | } 38 | 39 | impl PeerIndex 40 | where 41 | T: Tun + 'static, 42 | I: Transport, 43 | { 44 | pub fn new(token: CancellationToken, tun: T) -> Self { 45 | Self { 46 | token, 47 | tun, 48 | peers: HashMap::new(), 49 | sessions: SessionIndex::new(), 50 | ips: CidrTable::new(), 51 | } 52 | } 53 | 54 | pub fn metrics(&self) -> HashMap<[u8; 32], PeerMetrics> { 55 | self.peers 56 | .iter() 57 | .map(|(pub_key, entry)| (*pub_key, entry.peer.metrics())) 58 | .collect() 59 | } 60 | 61 | /// Returns the peer that matches the given public key. 62 | pub fn get_by_key(&self, public_key: &[u8; 32]) -> Option>> { 63 | self.peers.get(public_key).map(|e| Arc::clone(&e.peer)) 64 | } 65 | 66 | /// Returns the peer that matches the given IP address. 67 | pub fn get_by_ip(&self, ip: IpAddr) -> Option>> { 68 | self.ips.get_by_ip(ip).cloned() 69 | } 70 | 71 | /// Returns the peer that matches the index of the session. 
72 | pub fn get_session_by_index(&self, i: u32) -> Option<(Session, Arc>)> { 73 | match self.sessions.get_by_index(i) { 74 | Some(session) => self 75 | .get_by_key(session.secret().public_key().as_bytes()) 76 | .map(|peer| (session, peer)), 77 | None => None, 78 | } 79 | } 80 | 81 | #[inline] 82 | pub fn all(&self) -> Vec>> { 83 | self.peers 84 | .values() 85 | .map(|entry| Arc::clone(&entry.peer)) 86 | .collect() 87 | } 88 | 89 | pub fn insert( 90 | &mut self, 91 | secret: PeerStaticSecret, 92 | allowed_ips: HashSet, 93 | endpoint: Option>, 94 | persistent_keepalive_interval: Option, 95 | ) -> Arc> { 96 | let entry = self 97 | .peers 98 | .entry(secret.public_key().to_bytes()) 99 | .or_insert_with(|| { 100 | let (inbound_tx, inbound_rx) = mpsc::channel(256); 101 | let (outbound_tx, outbound_rx) = mpsc::channel(256); 102 | let peer = Arc::new(Peer::new( 103 | self.tun.clone(), 104 | secret, 105 | self.sessions.clone(), 106 | endpoint, 107 | inbound_tx, 108 | outbound_tx, 109 | persistent_keepalive_interval, 110 | )); 111 | let handle = PeerHandle::spawn( 112 | self.token.child_token(), 113 | Arc::clone(&peer), 114 | inbound_rx, 115 | outbound_rx, 116 | ); 117 | PeerEntry { 118 | peer, 119 | allowed_ips, 120 | handle, 121 | } 122 | }); 123 | 124 | for &cidr in &entry.allowed_ips { 125 | self.ips.insert(cidr, Arc::clone(&entry.peer)); 126 | } 127 | 128 | Arc::clone(&entry.peer) 129 | } 130 | 131 | pub fn update_allowed_ips_by_key( 132 | &mut self, 133 | public_key: &[u8; 32], 134 | allowed_ips: HashSet, 135 | ) -> bool { 136 | if let Some(entry) = self.peers.get_mut(public_key) { 137 | if entry.allowed_ips == allowed_ips { 138 | return false; 139 | } 140 | for cidr in &entry.allowed_ips { 141 | self.ips.remove(cidr); 142 | } 143 | for cidr in allowed_ips.clone() { 144 | self.ips.insert(cidr, Arc::clone(&entry.peer)); 145 | } 146 | entry.allowed_ips = allowed_ips; 147 | true 148 | } else { 149 | false 150 | } 151 | } 152 | 153 | pub fn remove_by_key(&mut self, 
public_key: &[u8; 32]) { 154 | if let Some(entry) = self.peers.remove(public_key) { 155 | tokio::spawn(entry.handle.cancel(Duration::from_secs(5))); 156 | for cidr in entry.allowed_ips { 157 | self.ips.remove(&cidr); 158 | } 159 | self.sessions.remove_by_key(public_key); 160 | } 161 | } 162 | 163 | pub fn clear(&mut self) { 164 | self.peers.drain().for_each(|(_, entry)| { 165 | tokio::spawn(entry.handle.cancel(Duration::from_secs(5))); 166 | }); 167 | self.ips.clear(); 168 | self.sessions.clear(); 169 | } 170 | 171 | pub fn to_config(&self) -> Vec { 172 | self.peers 173 | .values() 174 | .map(|entry| PeerConfig { 175 | public_key: entry.peer.secret().public_key().to_bytes(), 176 | allowed_ips: entry.allowed_ips.clone(), 177 | endpoint: entry.peer.endpoint().map(|endpoint| endpoint.dst()), 178 | preshared_key: Some(*entry.peer.secret().psk()), 179 | persistent_keepalive: None, 180 | }) 181 | .collect() 182 | } 183 | } 184 | 185 | impl Drop for PeerIndex 186 | where 187 | T: Tun + 'static, 188 | P: Transport, 189 | { 190 | fn drop(&mut self) { 191 | self.token.cancel(); 192 | } 193 | } 194 | -------------------------------------------------------------------------------- /src/device/peer/mod.rs: -------------------------------------------------------------------------------- 1 | mod cidr; 2 | mod handle; 3 | pub mod handshake; 4 | mod index; 5 | mod monitor; 6 | mod session; 7 | 8 | pub use cidr::{Cidr, ParseCidrError}; 9 | pub use monitor::PeerMetrics; 10 | 11 | pub(crate) use handle::PeerHandle; 12 | pub(crate) use index::PeerIndex; 13 | pub(crate) use session::Session; 14 | 15 | use std::fmt::{Debug, Display, Formatter}; 16 | use std::sync::RwLock; 17 | use std::time::Duration; 18 | 19 | use tokio::sync::mpsc; 20 | use tracing::{debug, warn}; 21 | 22 | use crate::device::inbound::{Endpoint, Transport}; 23 | use crate::noise::crypto; 24 | use crate::noise::crypto::PeerStaticSecret; 25 | use crate::noise::handshake::IncomingInitiation; 26 | use 
crate::noise::protocol; 27 | use crate::Tun; 28 | use handshake::Handshake; 29 | use monitor::PeerMonitor; 30 | use session::{ActiveSession, SessionIndex}; 31 | 32 | #[derive(Debug)] 33 | pub(crate) enum OutboundEvent { 34 | Data(Vec), 35 | } 36 | 37 | #[derive(Debug)] 38 | pub(crate) enum InboundEvent 39 | where 40 | I: Transport, 41 | { 42 | HandshakeInitiation { 43 | endpoint: Endpoint, 44 | initiation: IncomingInitiation, 45 | }, 46 | HandshakeResponse { 47 | endpoint: Endpoint, 48 | packet: protocol::HandshakeResponse, 49 | session: Session, 50 | }, 51 | CookieReply { 52 | endpoint: Endpoint, 53 | packet: protocol::CookieReply, 54 | session: Session, 55 | }, 56 | TransportData { 57 | endpoint: Endpoint, 58 | packet: protocol::TransportData, 59 | session: Session, 60 | }, 61 | } 62 | 63 | pub(crate) type InboundTx = mpsc::Sender>; 64 | pub(crate) type InboundRx = mpsc::Receiver>; 65 | pub(crate) type OutboundTx = mpsc::Sender; 66 | pub(crate) type OutboundRx = mpsc::Receiver; 67 | 68 | pub(crate) struct Peer 69 | where 70 | T: Tun, 71 | I: Transport, 72 | { 73 | tun: T, 74 | secret: PeerStaticSecret, 75 | sessions: RwLock, 76 | handshake: RwLock, 77 | endpoint: RwLock>>, 78 | inbound: InboundTx, 79 | outbound: OutboundTx, 80 | monitor: PeerMonitor, 81 | } 82 | 83 | impl Peer 84 | where 85 | T: Tun + 'static, 86 | I: Transport, 87 | { 88 | pub(super) fn new( 89 | tun: T, 90 | secret: PeerStaticSecret, 91 | session_index: SessionIndex, 92 | endpoint: Option>, 93 | inbound: InboundTx, 94 | outbound: OutboundTx, 95 | persitent_keepalive_interval: Option, 96 | ) -> Self { 97 | let handshake = RwLock::new(Handshake::new(secret.clone(), session_index.clone())); 98 | let sessions = RwLock::new(ActiveSession::new(session_index)); 99 | let monitor = PeerMonitor::new(persitent_keepalive_interval); 100 | let endpoint = RwLock::new(endpoint); 101 | Self { 102 | tun, 103 | secret, 104 | handshake, 105 | sessions, 106 | inbound, 107 | outbound, 108 | endpoint, 109 | monitor, 
110 | } 111 | } 112 | 113 | /// Stage inbound data from tun. 114 | #[inline] 115 | pub async fn handle_inbound(&self, e: InboundEvent) { 116 | if let Err(e) = self.inbound.send(e).await { 117 | warn!("{} not able to handle inbound: {}", self, e); 118 | } 119 | } 120 | 121 | /// Stage outbound data to be sent to the peer 122 | #[inline] 123 | pub async fn stage_outbound(&self, buf: Vec) { 124 | if let Err(e) = self.outbound.send(OutboundEvent::Data(buf)).await { 125 | warn!("{} not able to stage outbound: {}", self, e); 126 | } 127 | } 128 | 129 | /// Send keepalive packet to the peer if the traffic is idle. 130 | #[inline] 131 | pub async fn keepalive(&self) { 132 | if !self.monitor.keepalive().can(self.monitor.traffic()) { 133 | debug!("{self} not able to send keepalive"); 134 | return; 135 | } 136 | self.monitor.keepalive().attempt(); 137 | debug!("{self} sending keepalive"); 138 | self.stage_outbound(vec![]).await; 139 | } 140 | 141 | /// Update the endpoint of the peer. 142 | /// Could be called by IPC or the inbound loop. 
143 | #[inline] 144 | pub fn update_endpoint(&self, endpoint: Endpoint) { 145 | let mut guard = self.endpoint.write().unwrap(); 146 | let _ = guard.insert(endpoint); 147 | } 148 | 149 | #[inline] 150 | pub fn endpoint(&self) -> Option> { 151 | let endpoint = self.endpoint.read().unwrap(); 152 | endpoint.clone() 153 | } 154 | 155 | #[inline] 156 | pub fn metrics(&self) -> PeerMetrics { 157 | self.monitor.metrics() 158 | } 159 | 160 | #[inline] 161 | pub fn secret(&self) -> PeerStaticSecret { 162 | self.secret.clone() 163 | } 164 | 165 | // send outbound data 166 | #[inline] 167 | async fn send_outbound(&self, buf: &[u8]) { 168 | let endpoint = { self.endpoint.read().unwrap().clone() }; 169 | if let Some(endpoint) = endpoint { 170 | endpoint.send(buf).await.unwrap(); 171 | } else { 172 | debug!("no endpoint to send outbound packet to peer {self}"); 173 | } 174 | } 175 | } 176 | 177 | impl Display for Peer 178 | where 179 | T: Tun + 'static, 180 | I: Transport, 181 | { 182 | fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { 183 | write!( 184 | f, 185 | "Peer({})", 186 | crypto::encode_to_hex(self.secret.public_key().as_bytes()) 187 | ) 188 | } 189 | } 190 | 191 | impl Debug for Peer 192 | where 193 | T: Tun + 'static, 194 | I: Transport, 195 | { 196 | fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { 197 | f.debug_struct("Peer") 198 | .field( 199 | "public_key", 200 | &crypto::encode_to_hex(self.secret.public_key().as_bytes()), 201 | ) 202 | .finish() 203 | } 204 | } 205 | -------------------------------------------------------------------------------- /src/device/peer/monitor.rs: -------------------------------------------------------------------------------- 1 | use std::sync::atomic::{AtomicU64, Ordering}; 2 | use std::time::{Duration, Instant, SystemTime}; 3 | 4 | use crate::device::time::{AtomicInstant, AtomicTimestamp}; 5 | 6 | const REKEY_AFTER_MESSAGES: u64 = 1 << 60; 7 | const REKEY_AFTER_TIME: Duration = Duration::from_secs(120); 8 | const 
REJECT_AFTER_TIME: Duration = Duration::from_secs(180); 9 | const REKEY_ATTEMPT_TIME: Duration = Duration::from_secs(90); 10 | const REKEY_TIMEOUT: Duration = Duration::from_secs(5); 11 | const KEEPALIVE_TIMEOUT: Duration = Duration::from_secs(10); 12 | 13 | pub(super) struct HandshakeMonitor { 14 | last_attempt_at: AtomicInstant, 15 | last_complete_at: AtomicInstant, 16 | last_complete_ts: AtomicTimestamp, 17 | attempt_before: AtomicInstant, 18 | } 19 | 20 | impl HandshakeMonitor { 21 | #[inline] 22 | pub fn new() -> Self { 23 | Self { 24 | last_attempt_at: AtomicInstant::now(), 25 | last_complete_at: AtomicInstant::from_std(Instant::now() - REJECT_AFTER_TIME), 26 | attempt_before: AtomicInstant::now() + REKEY_ATTEMPT_TIME, 27 | last_complete_ts: AtomicTimestamp::zeroed(), 28 | } 29 | } 30 | 31 | #[inline] 32 | pub fn initiated(&self) { 33 | self.last_attempt_at.set_now(); 34 | } 35 | 36 | #[inline] 37 | pub fn will_initiate_in(&self) -> Instant { 38 | if self.is_max_attempt() || self.last_complete_at.elapsed() < REKEY_AFTER_TIME { 39 | return Instant::now() + REKEY_TIMEOUT; 40 | } 41 | 42 | self.last_attempt_at.to_std() + REKEY_TIMEOUT 43 | } 44 | 45 | #[inline] 46 | pub fn completed(&self) { 47 | self.last_complete_at.set_now(); 48 | self.last_complete_ts.set_now(); 49 | self.reset_attempt(); 50 | } 51 | 52 | #[inline] 53 | pub fn reset_attempt(&self) { 54 | self.attempt_before.set_now(); 55 | self.attempt_before.add_duration(REKEY_ATTEMPT_TIME); 56 | } 57 | 58 | #[inline] 59 | fn is_max_attempt(&self) -> bool { 60 | self.attempt_before.to_std() < Instant::now() 61 | } 62 | } 63 | 64 | pub(super) struct TrafficMonitor { 65 | last_sent_at: AtomicInstant, 66 | last_recv_at: AtomicInstant, 67 | tx_messages: AtomicU64, 68 | rx_messages: AtomicU64, 69 | tx_bytes: AtomicU64, 70 | rx_bytes: AtomicU64, 71 | } 72 | 73 | impl TrafficMonitor { 74 | pub fn new() -> Self { 75 | Self { 76 | last_sent_at: AtomicInstant::from_std(Instant::now()), 77 | last_recv_at: 
AtomicInstant::from_std(Instant::now() - REKEY_TIMEOUT), 78 | tx_messages: AtomicU64::new(0), 79 | rx_messages: AtomicU64::new(0), 80 | tx_bytes: AtomicU64::new(0), 81 | rx_bytes: AtomicU64::new(0), 82 | } 83 | } 84 | 85 | #[inline] 86 | pub fn outbound(&self, bytes: usize) { 87 | let n = bytes as _; 88 | self.last_sent_at.set_now(); 89 | self.tx_messages.fetch_add(1, Ordering::Relaxed); 90 | self.tx_bytes.fetch_add(n, Ordering::Relaxed); 91 | } 92 | 93 | #[inline] 94 | pub fn inbound(&self, bytes: usize) { 95 | let n = bytes as _; 96 | self.rx_messages.fetch_add(1, Ordering::Relaxed); 97 | self.rx_bytes.fetch_add(n, Ordering::Relaxed); 98 | } 99 | } 100 | 101 | pub(super) struct KeepAliveMonitor { 102 | last_attempt_at: AtomicInstant, 103 | perisistent_keepalive_interval: Option, 104 | } 105 | 106 | impl KeepAliveMonitor { 107 | pub fn new(persistent_keepalive_interval: Option) -> Self { 108 | Self { 109 | last_attempt_at: AtomicInstant::now(), 110 | perisistent_keepalive_interval: persistent_keepalive_interval, 111 | } 112 | } 113 | 114 | #[inline] 115 | pub fn next_attempt_in(&self, traffic: &TrafficMonitor) -> Instant { 116 | if self.last_attempt_at.elapsed() >= KEEPALIVE_TIMEOUT 117 | && traffic.last_recv_at.to_std() > traffic.last_sent_at.to_std() 118 | { 119 | if traffic.last_recv_at.elapsed() > KEEPALIVE_TIMEOUT { 120 | return Instant::now(); 121 | } else { 122 | return Instant::now() + KEEPALIVE_TIMEOUT - traffic.last_recv_at.elapsed(); 123 | } 124 | } 125 | 126 | self.perisistent_keepalive_interval 127 | .map(|v| self.last_attempt_at.to_std() + v) 128 | .unwrap_or_else(|| Instant::now() + REKEY_AFTER_TIME) 129 | } 130 | 131 | #[inline] 132 | pub fn can(&self, traffic: &TrafficMonitor) -> bool { 133 | self.next_attempt_in(traffic) <= Instant::now() 134 | } 135 | 136 | #[inline] 137 | pub fn attempt(&self) { 138 | self.last_attempt_at.set_now(); 139 | } 140 | } 141 | 142 | pub(super) struct PeerMonitor { 143 | handshake: HandshakeMonitor, 144 | traffic: 
TrafficMonitor, 145 | keepalive: KeepAliveMonitor, 146 | } 147 | 148 | impl PeerMonitor { 149 | pub fn new(persistent_keepalive_interval: Option) -> Self { 150 | Self { 151 | handshake: HandshakeMonitor::new(), 152 | traffic: TrafficMonitor::new(), 153 | keepalive: KeepAliveMonitor::new(persistent_keepalive_interval), 154 | } 155 | } 156 | 157 | #[inline] 158 | pub fn can_handshake(&self) -> bool { 159 | if self.traffic.tx_messages.load(Ordering::Relaxed) >= REKEY_AFTER_MESSAGES { 160 | return true; 161 | } 162 | 163 | if self.handshake.last_complete_at.elapsed() < REKEY_AFTER_TIME { 164 | // An active session exists 165 | return false; 166 | } 167 | 168 | if self.handshake.attempt_before.to_std() 169 | < self.handshake.last_complete_at.to_std() + REKEY_AFTER_TIME 170 | { 171 | self.handshake.reset_attempt(); 172 | } 173 | 174 | self.handshake.last_attempt_at.elapsed() >= REKEY_TIMEOUT 175 | } 176 | 177 | #[inline] 178 | pub fn traffic(&self) -> &TrafficMonitor { 179 | &self.traffic 180 | } 181 | 182 | #[inline] 183 | pub fn handshake(&self) -> &HandshakeMonitor { 184 | &self.handshake 185 | } 186 | 187 | #[inline] 188 | pub fn keepalive(&self) -> &KeepAliveMonitor { 189 | &self.keepalive 190 | } 191 | 192 | #[inline] 193 | pub fn metrics(&self) -> PeerMetrics { 194 | PeerMetrics { 195 | tx_messages: self.traffic.tx_messages.load(Ordering::Relaxed), 196 | rx_messages: self.traffic.rx_messages.load(Ordering::Relaxed), 197 | tx_bytes: self.traffic.tx_bytes.load(Ordering::Relaxed), 198 | rx_bytes: self.traffic.rx_bytes.load(Ordering::Relaxed), 199 | last_handshake_at: self.handshake.last_complete_ts.to_std(), 200 | } 201 | } 202 | } 203 | 204 | pub struct PeerMetrics { 205 | pub tx_messages: u64, 206 | pub rx_messages: u64, 207 | pub tx_bytes: u64, 208 | pub rx_bytes: u64, 209 | pub last_handshake_at: SystemTime, 210 | } 211 | -------------------------------------------------------------------------------- /src/device/rate_limiter.rs: 
-------------------------------------------------------------------------------- 1 | use std::sync::atomic::{AtomicI16, Ordering}; 2 | use std::time::Duration; 3 | 4 | use super::time::AtomicInstant; 5 | 6 | pub(crate) struct RateLimiter { 7 | tokens: u16, 8 | bucket: AtomicI16, 9 | last_at: AtomicInstant, 10 | } 11 | 12 | impl RateLimiter { 13 | pub fn new(tokens: u16) -> Self { 14 | Self { 15 | tokens, 16 | bucket: AtomicI16::new(tokens as _), 17 | last_at: AtomicInstant::now(), 18 | } 19 | } 20 | 21 | pub fn fetch_token(&self) -> bool { 22 | if self.last_at.elapsed() > Duration::from_secs(1) { 23 | self.bucket.store(self.tokens as i16 - 1, Ordering::Relaxed); 24 | self.last_at.set_now(); 25 | true 26 | } else if self.bucket.load(Ordering::Relaxed) > 0 { 27 | self.bucket.fetch_sub(1, Ordering::Relaxed) > 0 28 | } else { 29 | false 30 | } 31 | } 32 | } 33 | 34 | #[cfg(test)] 35 | mod tests { 36 | use super::*; 37 | 38 | #[test] 39 | fn test_ratelimiter_fetch_token() { 40 | let rl = RateLimiter::new(5); 41 | assert!(rl.fetch_token()); 42 | assert!(rl.fetch_token()); 43 | assert!(rl.fetch_token()); 44 | assert!(rl.fetch_token()); 45 | assert!(rl.fetch_token()); 46 | assert!(!rl.fetch_token()); 47 | assert!(!rl.fetch_token()); 48 | assert!(!rl.fetch_token()); 49 | std::thread::sleep(Duration::from_secs(1)); 50 | assert!(rl.fetch_token()); 51 | assert!(rl.fetch_token()); 52 | assert!(rl.fetch_token()); 53 | assert!(rl.fetch_token()); 54 | assert!(rl.fetch_token()); 55 | assert!(!rl.fetch_token()); 56 | } 57 | } 58 | -------------------------------------------------------------------------------- /src/device/time.rs: -------------------------------------------------------------------------------- 1 | use std::ops::Add; 2 | use std::sync::atomic::{AtomicU32, AtomicU64, Ordering}; 3 | use std::time::{Duration, Instant, SystemTime}; 4 | 5 | pub(crate) struct AtomicTimestamp { 6 | secs: AtomicU64, 7 | nanos: AtomicU32, 8 | } 9 | 10 | impl AtomicTimestamp { 11 | 
#[inline(always)] 12 | pub fn zeroed() -> Self { 13 | Self { 14 | secs: AtomicU64::new(0), 15 | nanos: AtomicU32::new(0), 16 | } 17 | } 18 | 19 | #[inline(always)] 20 | pub fn from_std(t: SystemTime) -> Self { 21 | let d = t.duration_since(SystemTime::UNIX_EPOCH).unwrap(); 22 | Self { 23 | secs: AtomicU64::new(d.as_secs()), 24 | nanos: AtomicU32::new(d.subsec_nanos()), 25 | } 26 | } 27 | 28 | #[inline(always)] 29 | pub fn set_now(&self) { 30 | let now = SystemTime::UNIX_EPOCH.elapsed().expect("fetch system time"); 31 | self.secs.store(now.as_secs(), Ordering::Relaxed); 32 | self.nanos.store(now.subsec_nanos(), Ordering::Relaxed); 33 | } 34 | 35 | #[inline(always)] 36 | pub fn timestamp(&self) -> (u64, u32) { 37 | ( 38 | self.secs.load(Ordering::Relaxed), 39 | self.nanos.load(Ordering::Relaxed), 40 | ) 41 | } 42 | 43 | #[inline(always)] 44 | pub fn to_std(&self) -> SystemTime { 45 | let (secs, nanos) = self.timestamp(); 46 | SystemTime::UNIX_EPOCH + Duration::from_secs(secs) + Duration::from_nanos(nanos as _) 47 | } 48 | } 49 | 50 | impl From for SystemTime { 51 | fn from(value: AtomicTimestamp) -> Self { 52 | value.to_std() 53 | } 54 | } 55 | 56 | impl From for AtomicTimestamp { 57 | fn from(value: SystemTime) -> Self { 58 | Self::from_std(value) 59 | } 60 | } 61 | 62 | pub(crate) struct AtomicInstant { 63 | epoch: Instant, 64 | d: AtomicU64, 65 | } 66 | 67 | impl AtomicInstant { 68 | pub fn from_std(epoch: Instant) -> Self { 69 | Self { 70 | epoch, 71 | d: AtomicU64::new(0), 72 | } 73 | } 74 | 75 | #[inline(always)] 76 | pub fn now() -> Self { 77 | Self::from_std(Instant::now()) 78 | } 79 | 80 | #[inline(always)] 81 | pub fn set_now(&self) { 82 | let elpased = self.epoch.elapsed(); 83 | self.d.store(elpased.as_millis() as _, Ordering::Relaxed); 84 | } 85 | 86 | #[inline(always)] 87 | pub fn add_duration(&self, d: Duration) { 88 | self.d.fetch_add(d.as_millis() as _, Ordering::Relaxed); 89 | } 90 | 91 | #[inline(always)] 92 | pub fn elapsed(&self) -> Duration { 93 
| self.to_std().elapsed() 94 | } 95 | 96 | #[inline(always)] 97 | pub fn to_std(&self) -> Instant { 98 | self.epoch + Duration::from_millis(self.d.load(Ordering::Relaxed)) 99 | } 100 | } 101 | 102 | impl From for Instant { 103 | fn from(value: AtomicInstant) -> Self { 104 | value.to_std() 105 | } 106 | } 107 | 108 | impl From for AtomicInstant { 109 | fn from(value: Instant) -> Self { 110 | Self::from_std(value) 111 | } 112 | } 113 | 114 | impl Add for AtomicInstant { 115 | type Output = Self; 116 | 117 | fn add(self, rhs: Duration) -> Self::Output { 118 | self.add_duration(rhs); 119 | self 120 | } 121 | } 122 | 123 | impl Eq for AtomicInstant {} 124 | 125 | impl PartialEq for AtomicInstant { 126 | fn eq(&self, other: &Self) -> bool { 127 | self.to_std().eq(&other.to_std()) 128 | } 129 | } 130 | 131 | impl PartialOrd for AtomicInstant { 132 | fn partial_cmp(&self, other: &Self) -> Option { 133 | Some(self.cmp(other)) 134 | } 135 | } 136 | 137 | impl Ord for AtomicInstant { 138 | fn cmp(&self, other: &Self) -> std::cmp::Ordering { 139 | self.to_std().cmp(&other.to_std()) 140 | } 141 | } 142 | 143 | #[cfg(test)] 144 | mod tests { 145 | use super::*; 146 | 147 | #[test] 148 | fn test_atomic_timestamp() { 149 | let now = SystemTime::now(); 150 | let ts = AtomicTimestamp::from_std(now); 151 | assert_eq!(ts.to_std(), now); 152 | } 153 | 154 | #[test] 155 | fn test_atomic_instant() { 156 | let now = Instant::now(); 157 | let instant = AtomicInstant::from_std(now); 158 | assert_eq!(instant.to_std(), now); 159 | 160 | let now = now + Duration::from_secs(1); 161 | instant.add_duration(Duration::from_secs(1)); 162 | assert_eq!(instant.to_std(), now); 163 | } 164 | } 165 | -------------------------------------------------------------------------------- /src/lib.rs: -------------------------------------------------------------------------------- 1 | #![allow(clippy::map_clone)] // FIXME: Not working properly, fixed in v1.78 2 | #![deny( 3 | warnings, 4 | rust_2018_idioms, 5 | 
clippy::clone_on_ref_ptr, 6 | clippy::dbg_macro, 7 | clippy::enum_glob_use, 8 | clippy::get_unwrap, 9 | clippy::macro_use_imports, 10 | clippy::str_to_string, 11 | clippy::inefficient_to_string, 12 | clippy::too_many_lines, 13 | clippy::or_fun_call 14 | )] 15 | 16 | //! # WireTun 17 | //! 18 | //! WireTun is a user-space WireGuard implementation in Rust. 19 | //! 20 | //! ## What is WireGuard? 21 | //! 22 | //! WireGuard is a modern, high-performance VPN protocol that is designed to be simple to use and easy to configure. 23 | //! It is often used to create secure private networks and build reliable, low-latency connections. 24 | //! 25 | //! ## Features 26 | //! 27 | //! - Implementation of the [WireGuard](https://www.wireguard.com/) protocol in Rust. 28 | //! - Asynchronous I/O using [Tokio](https://tokio.rs/). 29 | //! 30 | //! # Examples 31 | //! 32 | //! ```no_run 33 | //! use wiretun::{Cidr, Device, DeviceConfig, PeerConfig}; 34 | //! 35 | //! #[tokio::main] 36 | //! async fn main() -> Result<(), Box> { 37 | //! let cfg = DeviceConfig::default() 38 | //! .listen_port(40001); 39 | //! let device = Device::native("utun88", cfg).await?; 40 | //! Ok(()) 41 | //! } 42 | 43 | mod device; 44 | pub mod noise; 45 | mod tun; 46 | 47 | pub use device::{ 48 | Cidr, Device, DeviceConfig, DeviceControl, Endpoint, ParseCidrError, PeerConfig, Transport, 49 | UdpTransport, 50 | }; 51 | pub use noise::crypto::{LocalStaticSecret, PeerStaticSecret}; 52 | pub use tun::{Error as TunError, Tun}; 53 | 54 | #[cfg(feature = "native")] 55 | /// Native tun implementation. 
56 | pub use tun::NativeTun; 57 | 58 | #[cfg(feature = "uapi")] 59 | pub mod uapi; 60 | -------------------------------------------------------------------------------- /src/noise/crypto.rs: -------------------------------------------------------------------------------- 1 | use blake2::{ 2 | digest::{FixedOutput, Mac, Update}, 3 | Blake2s256, Blake2sMac, Digest, 4 | }; 5 | use rand_core::OsRng; 6 | 7 | use super::Error; 8 | 9 | pub type PrivateKey = x25519_dalek::StaticSecret; 10 | pub type PublicKey = x25519_dalek::PublicKey; 11 | pub type EphemeralPrivateKey = x25519_dalek::ReusableSecret; 12 | 13 | pub fn encode_to_hex(key: &[u8]) -> String { 14 | use std::fmt::Write; 15 | let mut s = String::with_capacity(key.len() * 2); 16 | for &b in key { 17 | write!(&mut s, "{:02x}", b).unwrap(); 18 | } 19 | s 20 | } 21 | 22 | pub fn decode_from_hex(s: &str) -> Vec { 23 | (0..s.len()) 24 | .step_by(2) 25 | .map(|i| u8::from_str_radix(&s[i..i + 2], 16).unwrap()) 26 | .collect() 27 | } 28 | 29 | #[derive(Clone)] 30 | pub struct LocalStaticSecret { 31 | private: PrivateKey, 32 | public: PublicKey, 33 | } 34 | 35 | impl LocalStaticSecret { 36 | #[inline(always)] 37 | pub fn random() -> Self { 38 | Self::new(PrivateKey::random_from_rng(OsRng).to_bytes()) 39 | } 40 | 41 | #[inline(always)] 42 | pub fn new(private_key: [u8; 32]) -> Self { 43 | let private = PrivateKey::from(private_key); 44 | let public = PublicKey::from(&private); 45 | 46 | Self { private, public } 47 | } 48 | 49 | #[inline(always)] 50 | pub fn with_peer(self, peer_public_key: [u8; 32]) -> PeerStaticSecret { 51 | PeerStaticSecret::new(self, peer_public_key) 52 | } 53 | 54 | #[inline(always)] 55 | pub fn private_key(&self) -> &PrivateKey { 56 | &self.private 57 | } 58 | 59 | #[inline(always)] 60 | pub fn public_key(&self) -> &PublicKey { 61 | &self.public 62 | } 63 | } 64 | 65 | #[derive(Clone)] 66 | pub struct PeerStaticSecret { 67 | local: LocalStaticSecret, 68 | public: PublicKey, 69 | psk: [u8; 32], // 
pre-shared key 70 | } 71 | 72 | impl PeerStaticSecret { 73 | #[inline(always)] 74 | pub fn new(local: LocalStaticSecret, public_key: [u8; 32]) -> Self { 75 | let public = PublicKey::from(public_key); 76 | let psk = [0u8; 32]; 77 | 78 | Self { local, public, psk } 79 | } 80 | 81 | #[inline(always)] 82 | pub fn random_psk() -> [u8; 32] { 83 | x25519_dalek::StaticSecret::random_from_rng(OsRng).to_bytes() 84 | } 85 | 86 | #[inline(always)] 87 | pub fn set_psk(&mut self, psk: [u8; 32]) { 88 | self.psk = psk; 89 | } 90 | 91 | #[inline(always)] 92 | pub fn psk(&self) -> &[u8; 32] { 93 | &self.psk 94 | } 95 | 96 | #[inline(always)] 97 | pub fn local(&self) -> &LocalStaticSecret { 98 | &self.local 99 | } 100 | 101 | #[inline(always)] 102 | pub fn public_key(&self) -> &PublicKey { 103 | &self.public 104 | } 105 | } 106 | 107 | #[inline] 108 | pub fn gen_ephemeral_key() -> (EphemeralPrivateKey, PublicKey) { 109 | let secret = EphemeralPrivateKey::random_from_rng(OsRng); 110 | let public = PublicKey::from(&secret); 111 | (secret, public) 112 | } 113 | 114 | #[inline] 115 | pub fn hash(in1: &[u8], in2: &[u8]) -> [u8; 32] { 116 | Blake2s256::new().chain(in1).chain(in2).finalize().into() 117 | } 118 | 119 | #[inline] 120 | pub fn mac(key: &[u8], in0: &[u8]) -> [u8; 16] { 121 | Blake2sMac::new_from_slice(key) 122 | .unwrap() 123 | .chain(in0) 124 | .finalize_fixed() 125 | .into() 126 | } 127 | 128 | #[inline] 129 | pub fn hmac1(key: &[u8], in0: &[u8]) -> [u8; 32] { 130 | type HmacBlake2s = hmac::SimpleHmac; 131 | HmacBlake2s::new_from_slice(key) 132 | .unwrap() 133 | .chain(in0) 134 | .finalize_fixed() 135 | .into() 136 | } 137 | 138 | #[inline] 139 | pub fn hmac2(key: &[u8], in0: &[u8], in1: &[u8]) -> [u8; 32] { 140 | type HmacBlake2s = hmac::SimpleHmac; 141 | HmacBlake2s::new_from_slice(key) 142 | .unwrap() 143 | .chain(in0) 144 | .chain(in1) 145 | .finalize_fixed() 146 | .into() 147 | } 148 | 149 | #[inline] 150 | pub fn kdf1(key: &[u8], in0: &[u8]) -> [u8; 32] { 151 | 
hmac1(&hmac1(key, in0), &[0x1]) 152 | } 153 | 154 | #[inline] 155 | pub fn kdf2(key: &[u8], in0: &[u8]) -> ([u8; 32], [u8; 32]) { 156 | let prk = hmac1(key, in0); 157 | let t0 = hmac1(&prk, &[0x1]); 158 | let t1 = hmac2(&prk, &t0, &[0x2]); 159 | (t0, t1) 160 | } 161 | 162 | #[inline] 163 | pub fn kdf3(key: &[u8], in0: &[u8]) -> ([u8; 32], [u8; 32], [u8; 32]) { 164 | let prk = hmac1(key, in0); 165 | let t0 = hmac1(&prk, &[0x1]); 166 | let t1 = hmac2(&prk, &t0, &[0x2]); 167 | let t2 = hmac2(&prk, &t1, &[0x3]); 168 | (t0, t1, t2) 169 | } 170 | 171 | #[inline] 172 | pub fn aead_encrypt(key: &[u8], counter: u64, msg: &[u8], aad: &[u8]) -> Result, Error> { 173 | use chacha20poly1305::aead::{Aead, Payload}; 174 | use chacha20poly1305::{KeyInit, Nonce}; 175 | let nonce = { 176 | let mut nonce = [0u8; 12]; 177 | nonce[4..].copy_from_slice(&counter.to_le_bytes()); 178 | nonce 179 | }; 180 | 181 | chacha20poly1305::ChaCha20Poly1305::new_from_slice(key) 182 | .map_err(|_| Error::InvalidKeyLength)? 183 | .encrypt(Nonce::from_slice(&nonce), Payload { msg, aad }) 184 | .map_err(Error::Encryption) 185 | } 186 | 187 | #[inline] 188 | pub fn aead_decrypt(key: &[u8], counter: u64, msg: &[u8], aad: &[u8]) -> Result, Error> { 189 | use chacha20poly1305::aead::{Aead, Payload}; 190 | use chacha20poly1305::{KeyInit, Nonce}; 191 | let nonce = { 192 | let mut nonce = [0u8; 12]; 193 | nonce[4..].copy_from_slice(&counter.to_le_bytes()); 194 | nonce 195 | }; 196 | chacha20poly1305::ChaCha20Poly1305::new_from_slice(key) 197 | .map_err(|_| Error::InvalidKeyLength)? 198 | .decrypt(Nonce::from_slice(&nonce), Payload { msg, aad }) 199 | .map_err(|_| Error::Decryption) 200 | } 201 | 202 | #[inline] 203 | pub fn xaead_encrypt(key: &[u8], nonce: &[u8], msg: &[u8], aad: &[u8]) -> Result, Error> { 204 | use chacha20poly1305::aead::{Aead, Payload}; 205 | use chacha20poly1305::{KeyInit, XNonce}; 206 | chacha20poly1305::XChaCha20Poly1305::new_from_slice(key) 207 | .map_err(|_| Error::InvalidKeyLength)? 
208 | .encrypt(XNonce::from_slice(nonce), Payload { msg, aad }) 209 | .map_err(Error::Encryption) 210 | } 211 | 212 | #[inline] 213 | pub fn xaead_decrypt(key: &[u8], nonce: &[u8], msg: &[u8], aad: &[u8]) -> Result, Error> { 214 | use chacha20poly1305::aead::{Aead, Payload}; 215 | use chacha20poly1305::{KeyInit, XNonce}; 216 | chacha20poly1305::XChaCha20Poly1305::new_from_slice(key) 217 | .map_err(|_| Error::InvalidKeyLength)? 218 | .decrypt(XNonce::from_slice(nonce), Payload { msg, aad }) 219 | .map_err(|_| Error::Decryption) 220 | } 221 | 222 | #[cfg(test)] 223 | mod tests { 224 | use super::*; 225 | 226 | #[test] 227 | fn test_hash() { 228 | assert_eq!( 229 | hash(b"Noise_IKpsk2_25519_ChaChaPoly_BLAKE2s", b""), 230 | [ 231 | 96, 226, 109, 174, 243, 39, 239, 192, 46, 195, 53, 226, 160, 37, 210, 208, 22, 235, 232 | 66, 6, 248, 114, 119, 245, 45, 56, 209, 152, 139, 120, 205, 54, 233 | ] 234 | ) 235 | } 236 | 237 | #[test] 238 | fn test_kdf() { 239 | let cases = [ 240 | ( 241 | "746573742d6b6579", 242 | "746573742d696e707574", 243 | ( 244 | "6f0e5ad38daba1bea8a0d213688736f19763239305e0f58aba697f9ffc41c633", 245 | "df1194df20802a4fe594cde27e92991c8cae66c366e8106aaa937a55fa371e8a", 246 | "fac6e2745a325f5dc5d11a5b165aad08b0ada28e7b4e666b7c077934a4d76c24", 247 | ), 248 | ), 249 | ( 250 | "776972656775617264", 251 | "776972656775617264", 252 | ( 253 | "491d43bbfdaa8750aaf535e334ecbfe5129967cd64635101c566d4caefda96e8", 254 | "1e71a379baefd8a79aa4662212fcafe19a23e2b609a3db7d6bcba8f560e3d25f", 255 | "31e1ae48bddfbe5de38f295e5452b1909a1b4e38e183926af3780b0c1e1f0160", 256 | ), 257 | ), 258 | ( 259 | "", 260 | "", 261 | ( 262 | "8387b46bf43eccfcf349552a095d8315c4055beb90208fb1be23b894bc2ed5d0", 263 | "58a0e5f6faefccf4807bff1f05fa8a9217945762040bcec2f4b4a62bdfe0e86e", 264 | "0ce6ea98ec548f8e281e93e32db65621c45eb18dc6f0a7ad94178610a2f7338e", 265 | ), 266 | ), 267 | ]; 268 | // test kdf1 269 | for (key, input, (t0, _, _)) in cases { 270 | let key = decode_from_hex(key); 271 | let 
input = decode_from_hex(input); 272 | let out = kdf1(&key, &input); 273 | assert_eq!(encode_to_hex(&out), t0); 274 | } 275 | 276 | // test kdf2 277 | for (key, input, (t0, t1, _)) in cases { 278 | let key = decode_from_hex(key); 279 | let input = decode_from_hex(input); 280 | let out = kdf2(&key, &input); 281 | assert_eq!(encode_to_hex(&out.0), t0); 282 | assert_eq!(encode_to_hex(&out.1), t1); 283 | } 284 | 285 | // test kdf3 286 | for (key, input, (t0, t1, t2)) in cases { 287 | let key = decode_from_hex(key); 288 | let input = decode_from_hex(input); 289 | let out = kdf3(&key, &input); 290 | assert_eq!(encode_to_hex(&out.0), t0); 291 | assert_eq!(encode_to_hex(&out.1), t1); 292 | assert_eq!(encode_to_hex(&out.2), t2); 293 | } 294 | } 295 | 296 | #[test] 297 | fn test_aead() { 298 | let key = b"0123456789abcdef0123456789abcdef"; 299 | let aad = b"fedcba9876543210"; 300 | let data = b"foobar"; 301 | let counter = 42; 302 | let encrypted = aead_encrypt(key, counter, data, aad).unwrap(); 303 | assert_eq!( 304 | "3b97d40eb9a5a78385054b7be7027c9661a2031f4f91", 305 | encode_to_hex(&encrypted), 306 | ); 307 | let decrypted = aead_decrypt(key, counter, &encrypted, aad).unwrap(); 308 | assert_eq!(data, &decrypted[..]); 309 | } 310 | 311 | #[test] 312 | fn test_xaead() { 313 | let key = b"0123456789abcdef0123456789abcdef"; 314 | let aad = b"fedcba9876543210"; 315 | let data = b"foobar"; 316 | let nonce = b"0123456789abcdef01234567"; 317 | let encrypted = xaead_encrypt(key, nonce, data, aad).unwrap(); 318 | assert_eq!( 319 | "2f8312b423a80a32585bcf059fbcfeee8063d258f030", 320 | encode_to_hex(&encrypted), 321 | ); 322 | let decrypted = xaead_decrypt(key, nonce, &encrypted, aad).unwrap(); 323 | assert_eq!(data, &decrypted[..]); 324 | } 325 | } 326 | -------------------------------------------------------------------------------- /src/noise/error.rs: -------------------------------------------------------------------------------- 1 | #[derive(thiserror::Error, Debug)] 2 | pub 
enum Error { 3 | #[error("invalid key length")] 4 | InvalidKeyLength, 5 | #[error("unable to encrypt")] 6 | Encryption(chacha20poly1305::aead::Error), 7 | #[error("unable to decrypt")] 8 | Decryption, 9 | #[error("invalid packet")] 10 | InvalidPacket, 11 | #[error("invalid handshake state")] 12 | InvalidHandshakeState, 13 | #[error("receiver index not match")] 14 | ReceiverIndexNotMatch, 15 | } 16 | -------------------------------------------------------------------------------- /src/noise/handshake/cookie.rs: -------------------------------------------------------------------------------- 1 | use std::net::SocketAddr; 2 | use std::sync::Mutex; 3 | use std::time::{Duration, Instant}; 4 | 5 | use bytes::{BufMut, BytesMut}; 6 | use rand_core::{OsRng, RngCore}; 7 | 8 | use super::{LABEL_COOKIE, LABEL_MAC1}; 9 | use crate::noise::crypto::{hash, mac, xaead_encrypt, LocalStaticSecret, PeerStaticSecret}; 10 | 11 | const MESSAGE_TYPE_COOKIE_REPLY: u8 = 3u8; 12 | const PACKET_SIZE: usize = 64; 13 | const COOKIE_LIFETIME: Duration = Duration::from_secs(120); 14 | 15 | pub struct MacGenerator { 16 | peer_mac1_hash: [u8; 32], // pre-compute hash for generating mac1 17 | peer_cookie_hash: [u8; 32], // pre-compute hash for generating mac2 18 | last_cookie: Option<([u8; 16], Instant)>, 19 | } 20 | 21 | impl MacGenerator { 22 | #[inline] 23 | pub fn new(secret: &PeerStaticSecret) -> Self { 24 | let peer_pub = secret.public_key().as_bytes(); 25 | Self { 26 | peer_mac1_hash: hash(&LABEL_MAC1, peer_pub), 27 | peer_cookie_hash: hash(&LABEL_COOKIE, peer_pub), 28 | last_cookie: None, 29 | } 30 | } 31 | 32 | /// Generate mac1 for handshake initiation and response. 33 | #[inline] 34 | pub fn generate_mac1(&self, payload: &[u8]) -> [u8; 16] { 35 | mac(&self.peer_mac1_hash, payload) 36 | } 37 | 38 | /// Generate mac2 for handshake initiation and response. 
39 | #[inline] 40 | pub fn generate_mac2(&self, payload: &[u8]) -> [u8; 16] { 41 | if self.last_cookie.is_none() || self.last_cookie.unwrap().1.elapsed() >= COOKIE_LIFETIME { 42 | [0u8; 16] 43 | } else { 44 | mac(&self.peer_cookie_hash, payload) 45 | } 46 | } 47 | } 48 | 49 | pub struct Cookie { 50 | secret: Mutex>, 51 | cookie_hash: [u8; 32], 52 | mac1_hash: [u8; 32], 53 | } 54 | 55 | impl Cookie { 56 | pub fn new(secret: &LocalStaticSecret) -> Self { 57 | let cookie_hash = hash(&LABEL_COOKIE, secret.public_key().as_bytes()); 58 | let mac1_hash = hash(&LABEL_MAC1, secret.public_key().as_bytes()); 59 | 60 | Self { 61 | secret: Mutex::new(None), 62 | cookie_hash, 63 | mac1_hash, 64 | } 65 | } 66 | 67 | /// Validate mac1 of the payload. 68 | pub fn validate_mac1(&self, payload: &[u8]) -> bool { 69 | let (msg, macs) = payload.split_at(payload.len() - 32); 70 | let (mac1, _mac2) = macs.split_at(16); 71 | 72 | mac1 == mac(&self.mac1_hash, msg) 73 | } 74 | 75 | /// Validate mac2 of the payload. 76 | pub fn validate_mac2(&self, payload: &[u8]) -> bool { 77 | let (msg, macs) = payload.split_at(payload.len() - 32); 78 | let (_mac1, mac2) = macs.split_at(16); 79 | 80 | mac2 == mac(&self.cookie_hash, msg) 81 | } 82 | 83 | pub fn generate_cookie_reply(&self, payload: &[u8], dst: SocketAddr) -> Vec { 84 | let mut buf = BytesMut::with_capacity(PACKET_SIZE); 85 | 86 | buf.put_u32_le(MESSAGE_TYPE_COOKIE_REPLY as _); 87 | buf.put_slice(&payload[4..8]); // receiver index 88 | 89 | let nonce = Self::gen_nonce(); 90 | buf.put_slice(&nonce); 91 | 92 | let mac1 = &payload[payload.len() - 32..payload.len() - 16]; 93 | let msg = { 94 | let secret = self.refresh_secret(); 95 | let dst = Self::encode_dst_addr(dst); 96 | mac(&secret, &dst) 97 | }; 98 | 99 | let cookie = xaead_encrypt(&self.cookie_hash, &nonce, &msg, mac1).unwrap(); 100 | buf.put_slice(&cookie); 101 | buf.freeze().to_vec() 102 | } 103 | 104 | // Refresh the secret if it's expired. 
105 | fn refresh_secret(&self) -> [u8; 32] { 106 | let mut secret = self.secret.lock().unwrap(); 107 | if let Some(v) = secret.as_ref() { 108 | if v.1.elapsed() < COOKIE_LIFETIME { 109 | return v.0; 110 | } 111 | } 112 | 113 | let mut rv = [0u8; 32]; 114 | OsRng.fill_bytes(&mut rv); 115 | secret.replace((rv, Instant::now())); 116 | rv 117 | } 118 | 119 | #[inline] 120 | fn gen_nonce() -> [u8; 24] { 121 | let mut b = [0u8; 24]; 122 | OsRng.fill_bytes(&mut b); 123 | b 124 | } 125 | 126 | #[inline] 127 | fn encode_dst_addr(addr: SocketAddr) -> Vec { 128 | let mut bytes = vec![]; 129 | match addr { 130 | SocketAddr::V4(addr) => { 131 | bytes.extend_from_slice(&addr.ip().octets()); 132 | bytes.extend_from_slice(&addr.port().to_le_bytes()); 133 | } 134 | SocketAddr::V6(addr) => { 135 | bytes.extend_from_slice(&addr.ip().octets()); 136 | bytes.extend_from_slice(&addr.port().to_le_bytes()); 137 | } 138 | }; 139 | bytes 140 | } 141 | } 142 | -------------------------------------------------------------------------------- /src/noise/handshake/initiation.rs: -------------------------------------------------------------------------------- 1 | use bytes::{BufMut, BytesMut}; 2 | 3 | use super::{MacGenerator, CONSTRUCTION, IDENTIFIER}; 4 | use crate::noise::crypto::{EphemeralPrivateKey, LocalStaticSecret, PeerStaticSecret, PublicKey}; 5 | use crate::noise::protocol::HandshakeInitiation; 6 | use crate::noise::{ 7 | crypto::{aead_decrypt, aead_encrypt, gen_ephemeral_key, hash, kdf1, kdf2}, 8 | timestamp::Timestamp, 9 | Error, 10 | }; 11 | 12 | const MESSAGE_TYPE_HANDSHAKE_INITIATION: u8 = 1u8; 13 | const PACKET_SIZE: usize = 148; 14 | 15 | pub struct OutgoingInitiation { 16 | pub index: u32, 17 | pub hash: [u8; 32], 18 | pub chaining_key: [u8; 32], 19 | pub ephemeral_private_key: EphemeralPrivateKey, 20 | } 21 | 22 | impl OutgoingInitiation { 23 | pub fn new( 24 | sender_index: u32, 25 | secret: &PeerStaticSecret, 26 | macs: &mut MacGenerator, 27 | ) -> (Self, Vec) { 28 | let mut 
buf = BytesMut::with_capacity(PACKET_SIZE); 29 | 30 | buf.put_u32_le(MESSAGE_TYPE_HANDSHAKE_INITIATION as _); 31 | buf.put_u32_le(sender_index); 32 | 33 | let c = hash(&CONSTRUCTION, b""); 34 | let h = hash(&hash(&c, &IDENTIFIER), secret.public_key().as_bytes()); 35 | let (ephemeral_pri, ephemeral_pub) = gen_ephemeral_key(); 36 | let c = kdf1(&c, ephemeral_pub.as_bytes()); 37 | buf.put_slice(ephemeral_pub.as_bytes()); // 32 bytes 38 | let h = hash(&h, ephemeral_pub.as_bytes()); 39 | let (c, k) = kdf2( 40 | &c, 41 | ephemeral_pri.diffie_hellman(secret.public_key()).as_bytes(), 42 | ); 43 | let static_key = aead_encrypt(&k, 0, secret.local().public_key().as_bytes(), &h).unwrap(); 44 | buf.put_slice(&static_key); // 32 + 16 bytes 45 | let h = hash(&h, &static_key); 46 | let (c, k) = kdf2( 47 | &c, 48 | secret 49 | .local() 50 | .private_key() 51 | .diffie_hellman(secret.public_key()) 52 | .as_bytes(), 53 | ); 54 | let timestamp = aead_encrypt(&k, 0, Timestamp::now().as_bytes(), &h).unwrap(); 55 | buf.put_slice(×tamp); // 12 + 16 bytes 56 | let h = hash(&h, ×tamp); 57 | 58 | // mac1 and mac2 59 | buf.put_slice(&macs.generate_mac1(&buf)); // 16 bytes 60 | buf.put_slice(&macs.generate_mac2(&buf)); // 16 bytes 61 | 62 | let payload = buf.freeze().to_vec(); 63 | ( 64 | Self { 65 | index: sender_index, 66 | hash: h, 67 | chaining_key: c, 68 | ephemeral_private_key: ephemeral_pri, 69 | }, 70 | payload, 71 | ) 72 | } 73 | } 74 | 75 | #[derive(Debug)] 76 | pub struct IncomingInitiation { 77 | pub index: u32, 78 | pub hash: [u8; 32], 79 | pub chaining_key: [u8; 32], 80 | pub timestamp: Timestamp, 81 | pub ephemeral_public_key: PublicKey, 82 | pub static_public_key: PublicKey, 83 | } 84 | 85 | impl IncomingInitiation { 86 | pub fn parse(secret: &LocalStaticSecret, packet: &HandshakeInitiation) -> Result { 87 | let c = hash(&CONSTRUCTION, b""); 88 | let h = hash(&hash(&c, &IDENTIFIER), secret.public_key().as_bytes()); 89 | let peer_ephemeral_pub = 
PublicKey::from(packet.ephemeral_public_key); 90 | let c = kdf1(&c, &packet.ephemeral_public_key); 91 | let h = hash(&h, &packet.ephemeral_public_key); 92 | let (c, k) = kdf2( 93 | &c, 94 | secret 95 | .private_key() 96 | .diffie_hellman(&peer_ephemeral_pub) 97 | .as_bytes(), 98 | ); 99 | let static_key: [u8; 32] = aead_decrypt(&k, 0, &packet.static_public_key, &h)? 100 | .try_into() 101 | .unwrap(); 102 | let peer_static_pub = PublicKey::from(static_key); 103 | 104 | let h = hash(&h, &packet.static_public_key); 105 | let (c, k) = kdf2( 106 | &c, 107 | secret 108 | .private_key() 109 | .diffie_hellman(&peer_static_pub) 110 | .as_bytes(), 111 | ); 112 | let timestamp: [u8; 12] = aead_decrypt(&k, 0, &packet.timestamp, &h)? 113 | .try_into() 114 | .unwrap(); 115 | let timestamp = Timestamp::from(timestamp); 116 | let h = hash(&h, &packet.timestamp); 117 | Ok(Self { 118 | index: packet.sender_index, 119 | hash: h, 120 | chaining_key: c, 121 | timestamp, 122 | ephemeral_public_key: peer_ephemeral_pub, 123 | static_public_key: peer_static_pub, 124 | }) 125 | } 126 | } 127 | -------------------------------------------------------------------------------- /src/noise/handshake/mod.rs: -------------------------------------------------------------------------------- 1 | mod cookie; 2 | mod initiation; 3 | mod response; 4 | 5 | pub const CONSTRUCTION: [u8; 37] = *b"Noise_IKpsk2_25519_ChaChaPoly_BLAKE2s"; 6 | pub const IDENTIFIER: [u8; 34] = *b"WireGuard v1 zx2c4 Jason@zx2c4.com"; 7 | pub const LABEL_MAC1: [u8; 8] = *b"mac1----"; 8 | pub const LABEL_COOKIE: [u8; 8] = *b"cookie--"; 9 | 10 | pub use cookie::{Cookie, MacGenerator}; 11 | pub use initiation::{IncomingInitiation, OutgoingInitiation}; 12 | pub use response::{IncomingResponse, OutgoingResponse}; 13 | 14 | #[cfg(test)] 15 | mod tests { 16 | use super::*; 17 | use crate::noise::crypto::{LocalStaticSecret, PeerStaticSecret}; 18 | use crate::noise::protocol::{HandshakeInitiation, HandshakeResponse}; 19 | 20 | #[inline] 21 
| fn gen_2_static_key() -> (PeerStaticSecret, PeerStaticSecret) { 22 | let p1_local = LocalStaticSecret::random(); 23 | let p2_local = LocalStaticSecret::random(); 24 | let mut p1_secret = p1_local.clone().with_peer(p2_local.public_key().to_bytes()); 25 | let mut p2_secret = p2_local.with_peer(p1_local.public_key().to_bytes()); 26 | let psk = PeerStaticSecret::random_psk(); 27 | p1_secret.set_psk(psk); 28 | p2_secret.set_psk(psk); 29 | 30 | (p1_secret, p2_secret) 31 | } 32 | 33 | #[test] 34 | fn handshake_initiation() { 35 | let (p1_key, p2_key) = gen_2_static_key(); 36 | let (p1_i, _p2_i) = (42, 88); 37 | let mut p1_cookie = MacGenerator::new(&p2_key); 38 | 39 | let (init_out, payload) = OutgoingInitiation::new(p1_i, &p1_key, &mut p1_cookie); 40 | let packet = HandshakeInitiation::try_from(payload.as_slice()).unwrap(); 41 | let init_in = IncomingInitiation::parse(p2_key.local(), &packet).unwrap(); 42 | 43 | assert_eq!(init_in.index, p1_i); 44 | assert_eq!(init_out.hash, init_in.hash); 45 | assert_eq!(init_out.chaining_key, init_in.chaining_key); 46 | } 47 | 48 | #[test] 49 | fn handshake_response() { 50 | let (p1_key, p2_key) = gen_2_static_key(); 51 | let (p1_i, p2_i) = (42, 88); 52 | let mut p1_cookie = MacGenerator::new(&p2_key); 53 | let mut p2_cookie = MacGenerator::new(&p1_key); 54 | 55 | let (init_out, payload) = OutgoingInitiation::new(p1_i, &p1_key, &mut p1_cookie); 56 | let packet = HandshakeInitiation::try_from(payload.as_slice()).unwrap(); 57 | let init_in = IncomingInitiation::parse(p2_key.local(), &packet).unwrap(); 58 | 59 | assert_eq!(init_out.hash, init_in.hash); 60 | assert_eq!(init_out.chaining_key, init_in.chaining_key); 61 | 62 | let (resp_out, payload) = OutgoingResponse::new(&init_in, p2_i, &p2_key, &mut p2_cookie); 63 | let packet = HandshakeResponse::try_from(payload.as_slice()).unwrap(); 64 | let resp_in = IncomingResponse::parse(&init_out, &p1_key, &packet).unwrap(); 65 | 66 | assert_eq!(resp_in.index, p2_i); 67 | 
assert_eq!(resp_out.chaining_key, resp_in.chaining_key); 68 | assert_eq!(resp_out.hash, resp_in.hash); 69 | } 70 | } 71 | -------------------------------------------------------------------------------- /src/noise/handshake/response.rs: -------------------------------------------------------------------------------- 1 | use bytes::{BufMut, BytesMut}; 2 | 3 | use super::{IncomingInitiation, MacGenerator, OutgoingInitiation}; 4 | use crate::noise::protocol::HandshakeResponse; 5 | use crate::noise::{ 6 | crypto::{ 7 | aead_decrypt, aead_encrypt, gen_ephemeral_key, hash, kdf1, kdf3, EphemeralPrivateKey, 8 | PeerStaticSecret, PublicKey, 9 | }, 10 | Error, 11 | }; 12 | 13 | const MESSAGE_TYPE_HANDSHAKE_RESPONSE: u8 = 2u8; 14 | const PACKET_SIZE: usize = 92; 15 | 16 | pub struct OutgoingResponse { 17 | pub hash: [u8; 32], 18 | pub chaining_key: [u8; 32], 19 | pub ephemeral_private_key: EphemeralPrivateKey, 20 | } 21 | 22 | impl OutgoingResponse { 23 | pub fn new( 24 | initiation: &IncomingInitiation, 25 | local_index: u32, 26 | secret: &PeerStaticSecret, 27 | macs: &mut MacGenerator, 28 | ) -> (Self, Vec) { 29 | let mut buf = BytesMut::with_capacity(PACKET_SIZE); 30 | 31 | buf.put_u32_le(MESSAGE_TYPE_HANDSHAKE_RESPONSE as _); 32 | buf.put_u32_le(local_index); 33 | buf.put_u32_le(initiation.index); 34 | let (ephemeral_pri, ephemeral_pub) = gen_ephemeral_key(); 35 | buf.put_slice(ephemeral_pub.as_bytes()); // 32 bytes 36 | let c = kdf1(&initiation.chaining_key, ephemeral_pub.as_bytes()); 37 | let h = hash(&initiation.hash, ephemeral_pub.as_bytes()); 38 | let c = kdf1( 39 | &c, 40 | ephemeral_pri 41 | .diffie_hellman(&initiation.ephemeral_public_key) 42 | .as_bytes(), 43 | ); 44 | let c = kdf1( 45 | &c, 46 | ephemeral_pri.diffie_hellman(secret.public_key()).as_bytes(), 47 | ); 48 | let (c, t, k) = kdf3(&c, secret.psk()); 49 | let h = hash(&h, &t); 50 | let empty = aead_encrypt(&k, 0, &[], &h).unwrap(); 51 | buf.put_slice(&empty); // 16 bytes 52 | let h = hash(&h, &empty); 53 
| 54 | // mac1 and mac2 55 | buf.put_slice(&macs.generate_mac1(&buf)); // 16 bytes 56 | buf.put_slice(&macs.generate_mac2(&buf)); // 16 bytes 57 | 58 | let payload = buf.freeze().to_vec(); 59 | ( 60 | Self { 61 | hash: h, 62 | chaining_key: c, 63 | ephemeral_private_key: ephemeral_pri, 64 | }, 65 | payload, 66 | ) 67 | } 68 | } 69 | 70 | pub struct IncomingResponse { 71 | pub index: u32, 72 | pub ephemeral_public_key: PublicKey, 73 | pub hash: [u8; 32], 74 | pub chaining_key: [u8; 32], 75 | } 76 | 77 | impl IncomingResponse { 78 | pub fn parse( 79 | initiation: &OutgoingInitiation, 80 | secret: &PeerStaticSecret, 81 | packet: &HandshakeResponse, 82 | ) -> Result { 83 | let peer_ephemeral_pub = PublicKey::from(packet.ephemeral_public_key); 84 | let c = kdf1(&initiation.chaining_key, peer_ephemeral_pub.as_bytes()); 85 | let h = hash(&initiation.hash, peer_ephemeral_pub.as_bytes()); 86 | let c = kdf1( 87 | &c, 88 | initiation 89 | .ephemeral_private_key 90 | .diffie_hellman(&peer_ephemeral_pub) 91 | .as_bytes(), 92 | ); 93 | let c = kdf1( 94 | &c, 95 | secret 96 | .local() 97 | .private_key() 98 | .diffie_hellman(&peer_ephemeral_pub) 99 | .as_bytes(), 100 | ); 101 | let (c, t, k) = kdf3(&c, secret.psk()); 102 | let h = hash(&h, &t); 103 | let empty = aead_decrypt(&k, 0, &packet.empty, &h)?; 104 | if !empty.is_empty() { 105 | return Err(Error::Decryption); 106 | } 107 | let h = hash(&h, &packet.empty); 108 | 109 | Ok(Self { 110 | index: packet.sender_index, 111 | ephemeral_public_key: peer_ephemeral_pub, 112 | hash: h, 113 | chaining_key: c, 114 | }) 115 | } 116 | } 117 | -------------------------------------------------------------------------------- /src/noise/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod crypto; 2 | mod error; 3 | pub mod handshake; 4 | pub mod protocol; 5 | mod timestamp; 6 | 7 | pub use error::Error; 8 | pub use protocol::Message; 9 | 
-------------------------------------------------------------------------------- /src/noise/protocol.rs: -------------------------------------------------------------------------------- 1 | use std::fmt::{Debug, Formatter}; 2 | 3 | const MESSAGE_TYPE_HANDSHAKE_INITIATION: u8 = 1u8; 4 | const MESSAGE_TYPE_HANDSHAKE_RESPONSE: u8 = 2u8; 5 | const MESSAGE_TYPE_COOKIE_REPLY: u8 = 3u8; 6 | const MESSAGE_TYPE_TRANSPORT_DATA: u8 = 4u8; 7 | pub const HANDSHAKE_INITIATION_PACKET_SIZE: usize = 148; 8 | pub const HANDSHAKE_RESPONSE_PACKET_SIZE: usize = 92; 9 | pub const COOKIE_REPLY_PACKET_SIZE: usize = 64; 10 | 11 | pub const REJECT_AFTER_MESSAGES: u64 = u64::MAX - (1 << 13); 12 | 13 | const MIN_PACKET_SIZE: usize = 4; // TODO 14 | 15 | use super::Error; 16 | 17 | pub struct HandshakeInitiation { 18 | pub sender_index: u32, 19 | pub ephemeral_public_key: [u8; 32], 20 | pub static_public_key: [u8; 32 + 16], 21 | pub timestamp: [u8; 12 + 16], 22 | pub mac1: [u8; 16], 23 | pub mac2: [u8; 16], 24 | } 25 | 26 | impl TryFrom<&[u8]> for HandshakeInitiation { 27 | type Error = Error; 28 | 29 | fn try_from(value: &[u8]) -> Result<Self, Self::Error> { 30 | if value.len() != HANDSHAKE_INITIATION_PACKET_SIZE 31 | || value[0..4] != [MESSAGE_TYPE_HANDSHAKE_INITIATION, 0, 0, 0] 32 | { 33 | return Err(Error::InvalidPacket); 34 | } 35 | Ok(Self { 36 | sender_index: u32::from_le_bytes(value[4..8].try_into().unwrap()), 37 | ephemeral_public_key: value[8..40].try_into().unwrap(), 38 | static_public_key: value[40..88].try_into().unwrap(), 39 | timestamp: value[88..116].try_into().unwrap(), 40 | mac1: value[116..132].try_into().unwrap(), 41 | mac2: value[132..148].try_into().unwrap(), 42 | }) 43 | } 44 | } 45 | 46 | impl Debug for HandshakeInitiation { 47 | fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { 48 | f.debug_struct("HandshakeInitiation") 49 | .field("sender_index", &self.sender_index) 50 | .finish() 51 | } 52 | } 53 | 54 | pub struct HandshakeResponse { 55 | pub sender_index: u32, 56 | pub
receiver_index: u32, 57 | pub ephemeral_public_key: [u8; 32], 58 | pub empty: [u8; 16], 59 | pub mac1: [u8; 16], 60 | pub mac2: [u8; 16], 61 | } 62 | 63 | impl TryFrom<&[u8]> for HandshakeResponse { 64 | type Error = Error; 65 | 66 | fn try_from(value: &[u8]) -> Result<Self, Self::Error> { 67 | if value.len() != HANDSHAKE_RESPONSE_PACKET_SIZE 68 | || value[0..4] != [MESSAGE_TYPE_HANDSHAKE_RESPONSE, 0, 0, 0] 69 | { 70 | return Err(Error::InvalidPacket); 71 | } 72 | Ok(Self { 73 | sender_index: u32::from_le_bytes(value[4..8].try_into().unwrap()), 74 | receiver_index: u32::from_le_bytes(value[8..12].try_into().unwrap()), 75 | ephemeral_public_key: value[12..44].try_into().unwrap(), 76 | empty: value[44..60].try_into().unwrap(), 77 | mac1: value[60..76].try_into().unwrap(), 78 | mac2: value[76..92].try_into().unwrap(), 79 | }) 80 | } 81 | } 82 | 83 | impl Debug for HandshakeResponse { 84 | fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { 85 | f.debug_struct("HandshakeResponse") 86 | .field("sender_index", &self.sender_index) 87 | .field("receiver_index", &self.receiver_index) 88 | .finish() 89 | } 90 | } 91 | 92 | pub struct CookieReply { 93 | pub receiver_index: u32, 94 | pub nonce: [u8; 24], 95 | pub cookie: [u8; 16 + 16], 96 | } 97 | 98 | impl TryFrom<&[u8]> for CookieReply { 99 | type Error = Error; 100 | 101 | fn try_from(value: &[u8]) -> Result<Self, Self::Error> { 102 | if value.len() != COOKIE_REPLY_PACKET_SIZE 103 | || value[0..4] != [MESSAGE_TYPE_COOKIE_REPLY, 0, 0, 0] 104 | { 105 | return Err(Error::InvalidPacket); 106 | } 107 | Ok(Self { 108 | receiver_index: u32::from_le_bytes(value[4..8].try_into().unwrap()), 109 | nonce: value[8..32].try_into().unwrap(), 110 | cookie: value[32..64].try_into().unwrap(), 111 | }) 112 | } 113 | } 114 | 115 | impl Debug for CookieReply { 116 | fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { 117 | f.debug_struct("CookieReply") 118 | .field("index", &self.receiver_index) 119 | .field("nonce", &self.nonce) 120 | .finish() 121 | } 122 | } 123 | 
124 | pub struct TransportData { 125 | pub receiver_index: u32, 126 | pub counter: u64, 127 | pub payload: Vec, 128 | } 129 | 130 | impl TransportData { 131 | #[inline] 132 | pub fn packet_len(&self) -> usize { 133 | self.payload.len() + 16 134 | } 135 | } 136 | 137 | impl TransportData { 138 | pub fn to_bytes(&self) -> Vec { 139 | let mut bytes = Vec::with_capacity(self.payload.len() + 16); 140 | bytes.extend_from_slice(&[MESSAGE_TYPE_TRANSPORT_DATA, 0, 0, 0]); 141 | bytes.extend_from_slice(&self.receiver_index.to_le_bytes()); 142 | bytes.extend_from_slice(&self.counter.to_le_bytes()); 143 | bytes.extend_from_slice(&self.payload); 144 | bytes 145 | } 146 | } 147 | 148 | impl Debug for TransportData { 149 | fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { 150 | f.debug_struct("TransportData") 151 | .field("receiver", &self.receiver_index) 152 | .field("counter", &self.counter) 153 | .field("len(payload)", &self.payload.len()) 154 | .finish() 155 | } 156 | } 157 | 158 | impl TryFrom<&[u8]> for TransportData { 159 | type Error = Error; 160 | 161 | fn try_from(value: &[u8]) -> Result { 162 | if value.len() < MIN_PACKET_SIZE || value[0..4] != [MESSAGE_TYPE_TRANSPORT_DATA, 0, 0, 0] { 163 | return Err(Error::InvalidPacket); 164 | } 165 | Ok(Self { 166 | receiver_index: u32::from_le_bytes(value[4..8].try_into().unwrap()), 167 | counter: u64::from_le_bytes(value[8..16].try_into().unwrap()), 168 | payload: value[16..].to_vec(), 169 | }) 170 | } 171 | } 172 | 173 | #[derive(Debug)] 174 | pub enum Message { 175 | HandshakeInitiation(HandshakeInitiation), 176 | HandshakeResponse(HandshakeResponse), 177 | CookieReply(CookieReply), 178 | TransportData(TransportData), 179 | } 180 | 181 | impl Message { 182 | pub fn parse(payload: &[u8]) -> Result { 183 | if payload.len() < MIN_PACKET_SIZE { 184 | return Err(Error::InvalidPacket); 185 | } 186 | let message = match payload[0] { 187 | MESSAGE_TYPE_HANDSHAKE_INITIATION => { 188 | 
Message::HandshakeInitiation(HandshakeInitiation::try_from(payload)?) 189 | } 190 | MESSAGE_TYPE_HANDSHAKE_RESPONSE => { 191 | Message::HandshakeResponse(HandshakeResponse::try_from(payload)?) 192 | } 193 | MESSAGE_TYPE_COOKIE_REPLY => Message::CookieReply(CookieReply::try_from(payload)?), 194 | MESSAGE_TYPE_TRANSPORT_DATA => { 195 | Message::TransportData(TransportData::try_from(payload)?) 196 | } 197 | _ => return Err(Error::InvalidPacket), 198 | }; 199 | 200 | Ok(message) 201 | } 202 | 203 | pub fn is_handshake(payload: &[u8]) -> bool { 204 | match payload[0] { 205 | MESSAGE_TYPE_HANDSHAKE_INITIATION 206 | if payload.len() == HANDSHAKE_INITIATION_PACKET_SIZE => 207 | { 208 | true 209 | } 210 | MESSAGE_TYPE_HANDSHAKE_RESPONSE if payload.len() == HANDSHAKE_RESPONSE_PACKET_SIZE => { 211 | true 212 | } 213 | _ => false, 214 | } 215 | } 216 | } 217 | -------------------------------------------------------------------------------- /src/noise/timestamp.rs: -------------------------------------------------------------------------------- 1 | use std::time::SystemTime; 2 | 3 | const BASE: u64 = 0x400000000000000a; 4 | const WHITENER_MASK: u32 = 0x1000000 - 1; 5 | 6 | #[derive(Debug)] 7 | pub struct Timestamp([u8; 12]); 8 | 9 | impl Timestamp { 10 | fn stamp(t: SystemTime) -> Self { 11 | let d = t.duration_since(SystemTime::UNIX_EPOCH).unwrap(); 12 | 13 | let secs = BASE + d.as_secs(); 14 | let nanos = d.subsec_nanos() & !WHITENER_MASK; 15 | let b = { 16 | let mut dst = [0u8; 12]; 17 | dst[..8].copy_from_slice(&secs.to_be_bytes()); 18 | dst[8..].copy_from_slice(&nanos.to_be_bytes()); 19 | dst 20 | }; 21 | 22 | Self(b) 23 | } 24 | 25 | #[inline(always)] 26 | pub fn now() -> Self { 27 | Self::stamp(SystemTime::now()) 28 | } 29 | 30 | #[inline(always)] 31 | pub fn as_bytes(&self) -> &[u8] { 32 | &self.0 33 | } 34 | } 35 | 36 | impl From<[u8; 12]> for Timestamp { 37 | fn from(b: [u8; 12]) -> Self { 38 | Self(b) 39 | } 40 | } 41 | 42 | impl PartialEq for Timestamp { 43 | fn 
eq(&self, other: &Self) -> bool { 44 | self.0 == other.0 45 | } 46 | } 47 | 48 | impl Eq for Timestamp {} 49 | 50 | impl PartialOrd for Timestamp { 51 | fn partial_cmp(&self, other: &Self) -> Option { 52 | Some(self.cmp(other)) 53 | } 54 | } 55 | 56 | impl Ord for Timestamp { 57 | fn cmp(&self, other: &Self) -> std::cmp::Ordering { 58 | self.0.cmp(&other.0) 59 | } 60 | } 61 | 62 | #[cfg(test)] 63 | mod tests { 64 | use std::time::{Duration, SystemTime}; 65 | 66 | use super::*; 67 | use crate::noise::crypto; 68 | 69 | #[test] 70 | fn test_timestamp() { 71 | let t0 = SystemTime::UNIX_EPOCH 72 | .checked_add(Duration::new(0, 123456789)) 73 | .unwrap(); 74 | 75 | let ts0 = Timestamp::stamp(t0); 76 | assert_eq!(crypto::encode_to_hex(&ts0.0), "400000000000000a07000000"); 77 | 78 | let ts1 = Timestamp::stamp(t0.checked_add(Duration::from_nanos(10)).unwrap()); 79 | assert!(ts0 >= ts1); 80 | 81 | let ts2 = Timestamp::stamp(t0.checked_add(Duration::from_micros(10)).unwrap()); 82 | assert!(ts0 >= ts2); 83 | 84 | let ts3 = Timestamp::stamp(t0.checked_add(Duration::from_millis(1)).unwrap()); 85 | assert!(ts0 >= ts3); 86 | 87 | let ts4 = Timestamp::stamp(t0.checked_add(Duration::from_millis(10)).unwrap()); 88 | assert!(ts0 >= ts4); 89 | 90 | let ts5 = Timestamp::stamp(t0.checked_add(Duration::from_millis(20)).unwrap()); 91 | assert!(ts0 < ts5); 92 | } 93 | } 94 | -------------------------------------------------------------------------------- /src/tun/error.rs: -------------------------------------------------------------------------------- 1 | #[derive(Debug, thiserror::Error)] 2 | pub enum Error { 3 | #[error("interface must be named utun[0-9]*")] 4 | InvalidName, 5 | #[error("system call failed: {0}")] 6 | IO(#[from] std::io::Error), 7 | #[error("system call errno: {0}")] 8 | Sys(#[from] nix::Error), 9 | #[error("invalid IP packet")] 10 | InvalidIpPacket, 11 | #[error("tun closed")] 12 | Closed, 13 | } 14 | 
-------------------------------------------------------------------------------- /src/tun/linux/mod.rs: -------------------------------------------------------------------------------- 1 | mod sys; 2 | mod tun; 3 | 4 | pub use tun::NativeTun; 5 | -------------------------------------------------------------------------------- /src/tun/linux/sys.rs: -------------------------------------------------------------------------------- 1 | use std::mem; 2 | use std::os::fd::{AsRawFd, RawFd}; 3 | 4 | use libc::{__c_anonymous_ifr_ifru, c_char, ifreq}; 5 | use nix::fcntl::{fcntl, FcntlArg, OFlag}; 6 | use nix::sys::socket::{socket, AddressFamily, SockFlag, SockType}; 7 | use nix::{ioctl_read_bad, ioctl_write_ptr_bad}; 8 | 9 | use crate::tun::Error; 10 | 11 | ioctl_write_ptr_bad!(ioctl_tun_set_iff, 0x400454ca, ifreq); 12 | ioctl_read_bad!(ioctl_tun_get_iff, 0x800454d2, ifreq); 13 | ioctl_write_ptr_bad!(ioctl_set_mtu, 0x8922, ifreq); 14 | ioctl_read_bad!(ioctl_get_mtu, 0x8921, ifreq); 15 | 16 | pub fn new_ifreq(name: &str) -> ifreq { 17 | let mut ifr: ifreq = unsafe { mem::zeroed() }; 18 | let ifr_name: Vec = name.as_bytes().iter().map(|c| *c as _).collect(); 19 | ifr.ifr_name[..name.len()].copy_from_slice(&ifr_name); 20 | ifr 21 | } 22 | 23 | pub fn set_nonblocking(fd: RawFd) -> Result<(), Error> { 24 | let flag = fcntl(fd, FcntlArg::F_GETFL) 25 | .map(OFlag::from_bits_retain) 26 | .map_err(Error::Sys)?; 27 | let flag = OFlag::O_NONBLOCK | flag; 28 | fcntl(fd, FcntlArg::F_SETFL(flag)).map_err(Error::Sys)?; 29 | Ok(()) 30 | } 31 | 32 | pub fn set_mtu(name: &str, mtu: u16) -> Result<(), Error> { 33 | let fd = socket( 34 | AddressFamily::Inet, 35 | SockType::Datagram, 36 | SockFlag::empty(), 37 | None, 38 | ) 39 | .map_err(Error::Sys)?; 40 | let mut ifr = new_ifreq(name); 41 | ifr.ifr_ifru = __c_anonymous_ifr_ifru { ifru_mtu: mtu as _ }; 42 | unsafe { ioctl_set_mtu(fd.as_raw_fd(), &ifr) }.map_err(Error::Sys)?; 43 | Ok(()) 44 | } 45 | 46 | pub fn get_mtu(name: &str) -> Result { 47 
| let fd = socket( 48 | AddressFamily::Inet, 49 | SockType::Datagram, 50 | SockFlag::empty(), 51 | None, 52 | ) 53 | .map_err(Error::Sys)?; 54 | let mut ifr = new_ifreq(name); 55 | unsafe { ioctl_get_mtu(fd.as_raw_fd(), &mut ifr) }.map_err(Error::Sys)?; 56 | Ok(unsafe { ifr.ifr_ifru.ifru_mtu as _ }) 57 | } 58 | -------------------------------------------------------------------------------- /src/tun/linux/tun.rs: -------------------------------------------------------------------------------- 1 | use std::io; 2 | use std::os::fd::{AsRawFd, FromRawFd, OwnedFd}; 3 | use std::sync::Arc; 4 | 5 | use async_trait::async_trait; 6 | use bytes::BytesMut; 7 | use libc::{__c_anonymous_ifr_ifru, IFF_NO_PI, IFF_TUN, IFF_VNET_HDR}; 8 | use nix::fcntl::{self, OFlag}; 9 | use nix::sys::stat::Mode; 10 | use tokio::io::unix::AsyncFd; 11 | use tracing::debug; 12 | 13 | use crate::tun::linux::sys::{self, get_mtu, ioctl_tun_set_iff, set_mtu, set_nonblocking}; 14 | use crate::tun::Error; 15 | use crate::Tun; 16 | 17 | const DEVICE_PATH: &str = "/dev/net/tun"; 18 | 19 | #[derive(Clone)] 20 | pub struct NativeTun { 21 | fd: Arc>, 22 | name: String, 23 | } 24 | 25 | impl NativeTun { 26 | pub fn new(name: &str) -> Result { 27 | if name.len() > 16 { 28 | return Err(Error::InvalidName); 29 | } 30 | let fd = fcntl::open(DEVICE_PATH, OFlag::O_RDWR | OFlag::O_CLOEXEC, Mode::empty()) 31 | .map(|fd| unsafe { OwnedFd::from_raw_fd(fd) }) 32 | .map_err(Error::Sys)?; 33 | 34 | let mut ifr = sys::new_ifreq(name); 35 | ifr.ifr_ifru = __c_anonymous_ifr_ifru { 36 | ifru_flags: (IFF_TUN | IFF_NO_PI) as _, 37 | }; 38 | let _ = IFF_VNET_HDR; // TODO: enable 39 | 40 | unsafe { ioctl_tun_set_iff(fd.as_raw_fd(), &ifr) }?; 41 | set_nonblocking(fd.as_raw_fd())?; 42 | 43 | Ok(Self { 44 | fd: Arc::new(AsyncFd::new(fd)?), 45 | name: name.to_owned(), 46 | }) 47 | } 48 | } 49 | 50 | #[async_trait] 51 | impl Tun for NativeTun { 52 | fn name(&self) -> &str { 53 | &self.name 54 | } 55 | 56 | fn mtu(&self) -> Result { 57 
| get_mtu(&self.name) 58 | } 59 | 60 | fn set_mtu(&self, mtu: u16) -> Result<(), Error> { 61 | set_mtu(&self.name, mtu) 62 | } 63 | 64 | async fn recv(&self) -> Result, Error> { 65 | let mut buf = BytesMut::zeroed(1500); 66 | 67 | loop { 68 | let ret = { 69 | let mut guard = self.fd.readable().await?; 70 | guard.try_io(|inner| unsafe { 71 | let ret = libc::read(inner.as_raw_fd(), buf.as_mut_ptr() as _, buf.len()); 72 | if ret < 0 { 73 | Err::(io::Error::last_os_error()) 74 | } else { 75 | Ok(ret as usize) 76 | } 77 | }) 78 | }; 79 | 80 | match ret { 81 | Ok(Ok(n)) => { 82 | debug!("TUN read {} bytes", n); 83 | buf.truncate(n); 84 | return Ok(buf.freeze().to_vec()); 85 | } 86 | Ok(Err(e)) => return Err(e.into()), 87 | _ => continue, 88 | } 89 | } 90 | } 91 | 92 | async fn send(&self, buf: &[u8]) -> Result<(), Error> { 93 | let mut guard = self.fd.writable().await?; 94 | let ret = guard.try_io(|inner| unsafe { 95 | let ret = libc::write(inner.as_raw_fd(), buf.as_ptr() as _, buf.len()); 96 | if ret < 0 { 97 | Err::(io::Error::last_os_error()) 98 | } else { 99 | Ok(ret as usize) 100 | } 101 | }); 102 | 103 | match ret { 104 | Ok(Ok(_)) => return Ok(()), 105 | Ok(Err(e)) => return Err(e.into()), 106 | _ => {} 107 | } 108 | 109 | Ok(()) 110 | } 111 | } 112 | -------------------------------------------------------------------------------- /src/tun/macos/mod.rs: -------------------------------------------------------------------------------- 1 | mod sys; 2 | mod tun; 3 | 4 | pub use tun::NativeTun; 5 | -------------------------------------------------------------------------------- /src/tun/macos/sys.rs: -------------------------------------------------------------------------------- 1 | use std::ffi::CStr; 2 | use std::os::fd::RawFd; 3 | use std::{io, mem, ptr}; 4 | 5 | use libc::*; 6 | use nix::fcntl::{fcntl, FcntlArg, OFlag}; 7 | use nix::{ioctl_read_bad, ioctl_write_ptr_bad}; 8 | 9 | use crate::tun::Error; 10 | 11 | pub const SIOCSIFMTU: u64 = 0x80206934; 12 | pub 
const SIOCGIFMTU: u64 = 0xc0206933; 13 | ioctl_read_bad!(ioctl_get_mtu, SIOCGIFMTU, ifreq); 14 | ioctl_write_ptr_bad!(ioctl_set_mtu, SIOCSIFMTU, ifreq); 15 | 16 | pub const CTRL_NAME: [c_char; MAX_KCTL_NAME] = [ 17 | b'c' as _, b'o' as _, b'm' as _, b'.' as _, b'a' as _, b'p' as _, b'p' as _, b'l' as _, 18 | b'e' as _, b'.' as _, b'n' as _, b'e' as _, b't' as _, b'.' as _, b'u' as _, b't' as _, 19 | b'u' as _, b'n' as _, b'_' as _, b'c' as _, b'o' as _, b'n' as _, b't' as _, b'r' as _, 20 | b'o' as _, b'l' as _, b'\0' as _, b'\0' as _, b'\0' as _, b'\0' as _, b'\0' as _, b'\0' as _, 21 | b'\0' as _, b'\0' as _, b'\0' as _, b'\0' as _, b'\0' as _, b'\0' as _, b'\0' as _, b'\0' as _, 22 | b'\0' as _, b'\0' as _, b'\0' as _, b'\0' as _, b'\0' as _, b'\0' as _, b'\0' as _, b'\0' as _, 23 | b'\0' as _, b'\0' as _, b'\0' as _, b'\0' as _, b'\0' as _, b'\0' as _, b'\0' as _, b'\0' as _, 24 | b'\0' as _, b'\0' as _, b'\0' as _, b'\0' as _, b'\0' as _, b'\0' as _, b'\0' as _, b'\0' as _, 25 | b'\0' as _, b'\0' as _, b'\0' as _, b'\0' as _, b'\0' as _, b'\0' as _, b'\0' as _, b'\0' as _, 26 | b'\0' as _, b'\0' as _, b'\0' as _, b'\0' as _, b'\0' as _, b'\0' as _, b'\0' as _, b'\0' as _, 27 | b'\0' as _, b'\0' as _, b'\0' as _, b'\0' as _, b'\0' as _, b'\0' as _, b'\0' as _, b'\0' as _, 28 | b'\0' as _, b'\0' as _, b'\0' as _, b'\0' as _, b'\0' as _, b'\0' as _, b'\0' as _, b'\0' as _, 29 | ]; 30 | 31 | #[repr(C)] 32 | #[derive(Copy, Clone)] 33 | pub union ifrn { 34 | pub name: [c_char; IFNAMSIZ], 35 | } 36 | 37 | #[repr(C)] 38 | #[derive(Copy, Clone)] 39 | pub struct ifdevmtu { 40 | pub current: c_int, 41 | pub min: c_int, 42 | pub max: c_int, 43 | } 44 | 45 | #[repr(C)] 46 | #[derive(Copy, Clone)] 47 | pub union ifru { 48 | pub addr: sockaddr, 49 | pub dstaddr: sockaddr, 50 | pub broadaddr: sockaddr, 51 | 52 | pub flags: c_short, 53 | pub metric: c_int, 54 | pub mtu: c_int, 55 | pub phys: c_int, 56 | pub media: c_int, 57 | pub intval: c_int, 58 | pub data: *mut c_void, 59 | 
pub devmtu: ifdevmtu, 60 | pub wake_flags: c_uint, 61 | pub route_refcnt: c_uint, 62 | pub cap: [c_int; 2], 63 | pub functional_type: c_uint, 64 | } 65 | 66 | #[repr(C)] 67 | #[derive(Copy, Clone)] 68 | pub struct ifreq { 69 | pub ifrn: ifrn, 70 | pub ifru: ifru, 71 | } 72 | 73 | impl ifreq { 74 | pub fn new(name: &str) -> Self { 75 | let mut me: Self = unsafe { mem::zeroed() }; 76 | unsafe { 77 | ptr::copy_nonoverlapping( 78 | name.as_ptr() as *const libc::c_char, 79 | me.ifrn.name.as_mut_ptr(), 80 | name.len().min(IFNAMSIZ - 1), // clamp: never write past ifrn.name, and leave room for the NUL that zeroed() provides 81 | ) 82 | } 83 | me 84 | } 85 | } 86 | 87 | pub fn set_nonblocking(fd: RawFd) -> Result<(), Error> { 88 | let flag = fcntl(fd, FcntlArg::F_GETFL) 89 | .map(OFlag::from_bits_retain) 90 | .map_err(Error::Sys)?; 91 | let flag = OFlag::O_NONBLOCK | flag; 92 | fcntl(fd, FcntlArg::F_SETFL(flag)).map_err(Error::Sys)?; 93 | Ok(()) 94 | } 95 | 96 | pub unsafe fn get_iface_name(fd: RawFd) -> Result { 97 | const MAX_LEN: usize = 256; 98 | let mut name = [0u8; MAX_LEN]; 99 | let mut name_len: libc::socklen_t = name.len() as _; 100 | if libc::getsockopt( 101 | fd, 102 | libc::SYSPROTO_CONTROL, 103 | libc::UTUN_OPT_IFNAME, 104 | name.as_mut_ptr() as _, 105 | &mut name_len, 106 | ) < 0 107 | { 108 | return Err(io::Error::last_os_error()); 109 | } 110 | Ok(CStr::from_ptr(name.as_ptr() as *const libc::c_char) 111 | .to_string_lossy() 112 | .into()) 113 | } 114 | 115 | #[cfg(test)] 116 | mod tests { 117 | use super::*; 118 | 119 | #[test] 120 | fn test_ctrl_name() { 121 | let expected = { 122 | const CTRL_NAME_IN_BYTES: &[u8] = b"com.apple.net.utun_control"; 123 | let mut name: [c_char; libc::MAX_KCTL_NAME] = [0_i8; libc::MAX_KCTL_NAME]; 124 | name[..CTRL_NAME_IN_BYTES.len()].copy_from_slice( 125 | CTRL_NAME_IN_BYTES 126 | .iter() 127 | .map(|&x| x as _) 128 | .collect::>() 129 | .as_slice(), 130 | ); 131 | name 132 | }; 133 | 134 | assert_eq!(CTRL_NAME, expected); 135 | } 136 | } 137 | --------------------------------------------------------------------------------
/src/tun/macos/tun.rs: -------------------------------------------------------------------------------- 1 | use std::io; 2 | use std::mem::{size_of, size_of_val}; 3 | use std::os::fd::{AsRawFd, FromRawFd, OwnedFd}; 4 | use std::sync::Arc; 5 | 6 | use async_trait::async_trait; 7 | use bytes::{Buf, BytesMut}; 8 | use regex::Regex; 9 | use tokio::io::unix::AsyncFd; 10 | use tracing::debug; 11 | 12 | use super::sys; 13 | use crate::tun::{Error, Tun}; 14 | 15 | #[inline] 16 | fn parse_name(name: &str) -> Result { 17 | if name == "utun" { 18 | return Ok(0); 19 | } 20 | let re = Regex::new(r"^utun([1-9]\d*|0)?$").unwrap(); 21 | if !re.is_match(name) { 22 | return Err(Error::InvalidName); 23 | } 24 | name[4..] 25 | .parse() 26 | .map(|i: u32| i + 1) 27 | .map_err(|_| Error::InvalidName) 28 | } 29 | 30 | #[derive(Debug, Clone)] 31 | pub struct NativeTun { 32 | fd: Arc>, 33 | name: String, 34 | } 35 | 36 | impl NativeTun { 37 | pub fn new(name: &str) -> Result { 38 | let idx = parse_name(name)?; 39 | 40 | let fd = match unsafe { 41 | libc::socket(libc::PF_SYSTEM, libc::SOCK_DGRAM, libc::SYSPROTO_CONTROL) 42 | } { 43 | -1 => return Err(io::Error::last_os_error().into()), 44 | fd => unsafe { OwnedFd::from_raw_fd(fd) }, 45 | }; 46 | 47 | let info = libc::ctl_info { 48 | ctl_id: 0, 49 | ctl_name: sys::CTRL_NAME, 50 | }; 51 | if unsafe { libc::ioctl(fd.as_raw_fd(), libc::CTLIOCGINFO, &info) } < 0 { 52 | return Err(io::Error::last_os_error().into()); 53 | } 54 | 55 | let addr = libc::sockaddr_ctl { 56 | sc_len: size_of::() as _, 57 | sc_family: libc::AF_SYSTEM as _, 58 | ss_sysaddr: libc::AF_SYS_CONTROL as _, 59 | sc_id: info.ctl_id, 60 | sc_unit: idx, 61 | sc_reserved: Default::default(), 62 | }; 63 | if unsafe { 64 | libc::connect( 65 | fd.as_raw_fd(), 66 | &addr as *const libc::sockaddr_ctl as _, 67 | size_of_val(&addr) as _, 68 | ) 69 | } < 0 70 | { 71 | return Err(io::Error::last_os_error().into()); 72 | } 73 | 74 | sys::set_nonblocking(fd.as_raw_fd())?; 75 | 76 | let name = 
unsafe { sys::get_iface_name(fd.as_raw_fd()) }?; 77 | let fd = Arc::new(AsyncFd::new(fd)?); 78 | 79 | Ok(Self { fd, name }) 80 | } 81 | } 82 | 83 | #[async_trait] 84 | impl Tun for NativeTun { 85 | fn name(&self) -> &str { 86 | &self.name 87 | } 88 | 89 | fn set_mtu(&self, mtu: u16) -> Result<(), Error> { 90 | let mut req = sys::ifreq::new(&self.name); 91 | req.ifru.mtu = mtu as _; 92 | unsafe { sys::ioctl_set_mtu(self.fd.as_raw_fd(), &req) }?; 93 | 94 | Ok(()) 95 | } 96 | 97 | fn mtu(&self) -> Result { 98 | let mut req = sys::ifreq::new(&self.name); 99 | 100 | unsafe { sys::ioctl_get_mtu(self.fd.as_raw_fd(), &mut req) }?; 101 | 102 | Ok(unsafe { req.ifru.mtu as _ }) 103 | } 104 | 105 | async fn recv(&self) -> Result, Error> { 106 | let mut buf = BytesMut::zeroed(1500); 107 | 108 | loop { 109 | let ret = { 110 | let mut guard = self.fd.readable().await?; 111 | guard.try_io(|inner| unsafe { 112 | let ret = libc::read(inner.as_raw_fd(), buf.as_mut_ptr() as _, buf.len()); 113 | if ret < 0 { 114 | Err::(io::Error::last_os_error()) 115 | } else { 116 | Ok(ret as usize) 117 | } 118 | }) 119 | }; 120 | 121 | match ret { 122 | Ok(Ok(n)) if n >= 4 => { 123 | debug!("TUN read {} bytes", n); 124 | buf.advance(4); 125 | buf.truncate(n - 4); 126 | return Ok(buf.freeze().to_vec()); 127 | } 128 | Ok(Err(e)) => return Err(e.into()), 129 | _ => continue, 130 | } 131 | } 132 | } 133 | 134 | async fn send(&self, buf: &[u8]) -> Result<(), Error> { 135 | let buf = { 136 | let mut m = vec![0u8; 4 + buf.len()]; 137 | m[3] = match buf.first().map(|b| b >> 4) { // IP version nibble; an empty buffer is rejected instead of panicking 138 | Some(4) => 0x2, 139 | Some(6) => 0x1e, 140 | _ => return Err(Error::InvalidIpPacket), 141 | }; 142 | m[4..].copy_from_slice(buf); 143 | m 144 | }; 145 | 146 | loop { let mut guard = self.fd.writable().await?; 147 | let ret = guard.try_io(|inner| unsafe { 148 | let ret = libc::write(inner.as_raw_fd(), buf.as_ptr() as _, buf.len()); 149 | if ret < 0 { 150 | Err::(io::Error::last_os_error()) 151 | } else { 152 | Ok(ret as usize) 153 | } 154 | }); 155 | 156 | match
ret { 157 | Ok(Ok(_)) => return Ok(()), 158 | Ok(Err(e)) => return Err(e.into()), 159 | _ => continue, // spurious readiness (WouldBlock): wait for writability again instead of reporting success for a packet that was never written 160 | } 161 | 162 | } 163 | } 164 | } 165 | 166 | #[cfg(test)] 167 | mod tests { 168 | use super::*; 169 | 170 | #[test] 171 | fn test_parse_name() { 172 | let success_cases = [("utun", 0), ("utun0", 1), ("utun42", 43)]; 173 | 174 | for (input, expected) in success_cases { 175 | let rv = parse_name(input); 176 | assert!(rv.is_ok()); 177 | assert_eq!(rv.unwrap(), expected); 178 | } 179 | 180 | let failure_cases = ["utun04", "utun007", "utun42foo", "utunfoo", "futun"]; 181 | 182 | for input in failure_cases { 183 | assert!(parse_name(input).is_err()) 184 | } 185 | } 186 | } 187 | -------------------------------------------------------------------------------- /src/tun/mod.rs: -------------------------------------------------------------------------------- 1 | mod error; 2 | pub use error::Error; 3 | 4 | #[cfg(target_os = "macos")] 5 | mod macos; 6 | #[cfg(target_os = "macos")] 7 | pub use macos::NativeTun; 8 | 9 | #[cfg(target_os = "linux")] 10 | mod linux; 11 | #[cfg(target_os = "linux")] 12 | pub use linux::NativeTun; 13 | 14 | use async_trait::async_trait; 15 | 16 | #[async_trait] 17 | pub trait Tun: Send + Sync + Clone { 18 | fn name(&self) -> &str; 19 | fn mtu(&self) -> Result; 20 | fn set_mtu(&self, mtu: u16) -> Result<(), Error>; 21 | async fn recv(&self) -> Result, Error>; 22 | async fn send(&self, buf: &[u8]) -> Result<(), Error>; 23 | } 24 | -------------------------------------------------------------------------------- /src/uapi/connection.rs: -------------------------------------------------------------------------------- 1 | use bytes::Bytes; 2 | use std::collections::HashSet; 3 | use tokio::io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader}; 4 | use tokio::net::unix::{OwnedReadHalf, OwnedWriteHalf}; 5 | use tokio::net::UnixStream; 6 | use tracing::debug; 7 | 8 | use super::{Error, Request, Response, SetDevice, SetPeer}; 9 | use
crate::noise::crypto; 10 | 11 | pub struct Connection { 12 | reader: BufReader, 13 | writer: OwnedWriteHalf, 14 | } 15 | 16 | impl Connection { 17 | pub(super) fn new(socket: UnixStream) -> Self { 18 | let (rh, wh) = socket.into_split(); 19 | Self { 20 | reader: BufReader::new(rh), 21 | writer: wh, 22 | } 23 | } 24 | 25 | /// ## Cancel Safety 26 | /// The method is not cancellation safe. 27 | pub async fn next(&mut self) -> Result { 28 | let mut op = vec![]; 29 | self.reader.read_until(b'\n', &mut op).await?; 30 | 31 | match op.as_slice() { 32 | b"get=1\n" => { 33 | if self.reader.read_u8().await? != b'\n' { 34 | return Err(Error::InvalidProtocol); 35 | } 36 | Ok(Request::Get) 37 | } 38 | b"set=1\n" => { 39 | let mut buf = vec![]; 40 | while self.reader.read_until(b'\n', &mut buf).await? > 1 {} 41 | let s = unsafe { String::from_utf8_unchecked(buf).trim_end().to_owned() }; 42 | 43 | Ok(Request::Set(parse_set_request(&s)?)) 44 | } 45 | _ => Err(Error::InvalidProtocol), 46 | } 47 | } 48 | 49 | /// ## Cancel Safety 50 | /// The method is not cancellation safe. 
51 | pub async fn write(&mut self, resp: Response) { 52 | match resp { 53 | Response::Ok => { 54 | debug!("UAPI: writing ok response"); 55 | self.writer.write_all(b"errno=0\n\n").await.unwrap(); 56 | } 57 | Response::Get(info) => { 58 | let buf: Bytes = info.into(); 59 | self.writer.write_all(buf.as_ref()).await.unwrap(); 60 | } 61 | _ => {} 62 | } 63 | } 64 | } 65 | 66 | #[allow(clippy::too_many_lines)] 67 | fn parse_set_request(s: &str) -> Result { 68 | debug!("UAPI: parsing set request: {:?}", s); 69 | 70 | let mut set_device = SetDevice { 71 | private_key: None, 72 | listen_port: None, 73 | fwmark: None, 74 | replace_peers: false, 75 | peers: vec![], 76 | }; 77 | for line in s.split('\n') { 78 | let (k, v) = line.split_once('=').ok_or(Error::InvalidProtocol)?; 79 | 80 | match k { 81 | "private_key" => { 82 | let mut private_key = [0u8; 32]; 83 | private_key.copy_from_slice(crypto::decode_from_hex(v).as_slice()); 84 | set_device.private_key = Some(private_key); 85 | } 86 | "listen_port" => { 87 | set_device.listen_port = Some(v.parse().map_err(|_| Error::InvalidProtocol)?); 88 | } 89 | "fwmark" => { 90 | set_device.fwmark = Some(v.parse().map_err(|_| Error::InvalidProtocol)?); 91 | } 92 | "replace_peers" => { 93 | if v != "true" { 94 | return Err(Error::InvalidProtocol); 95 | } 96 | set_device.replace_peers = true; 97 | } 98 | "public_key" => { 99 | set_device.peers.push(SetPeer { 100 | public_key: [0u8; 32], 101 | remove: false, 102 | update_only: false, 103 | psk: None, 104 | endpoint: None, 105 | persistent_keepalive_interval: None, 106 | replace_allowed_ips: false, 107 | allowed_ips: HashSet::new(), 108 | }); 109 | 110 | set_device 111 | .peers 112 | .last_mut() 113 | .ok_or(Error::InvalidProtocol)? 
114 | .public_key = crypto::decode_from_hex(v) 115 | .as_slice() 116 | .try_into() 117 | .map_err(|_| Error::InvalidProtocol)?; 118 | } 119 | "remove" => { 120 | if v != "true" { 121 | return Err(Error::InvalidProtocol); 122 | } 123 | 124 | set_device 125 | .peers 126 | .last_mut() 127 | .ok_or(Error::InvalidProtocol)? 128 | .remove = true; 129 | } 130 | "update_only" => { 131 | if v != "true" { 132 | return Err(Error::InvalidProtocol); 133 | } 134 | 135 | set_device 136 | .peers 137 | .last_mut() 138 | .ok_or(Error::InvalidProtocol)? 139 | .update_only = true; 140 | } 141 | "preshared_key" => { 142 | set_device 143 | .peers 144 | .last_mut() 145 | .ok_or(Error::InvalidProtocol)? 146 | .psk = Some( 147 | crypto::decode_from_hex(v) 148 | .as_slice() 149 | .try_into() 150 | .map_err(|_| Error::InvalidProtocol)?, 151 | ); 152 | } 153 | "endpoint" => { 154 | set_device 155 | .peers 156 | .last_mut() 157 | .ok_or(Error::InvalidProtocol)? 158 | .endpoint = Some(v.parse().map_err(|_| Error::InvalidProtocol)?); 159 | } 160 | "persistent_keepalive_interval" => { 161 | set_device 162 | .peers 163 | .last_mut() 164 | .ok_or(Error::InvalidProtocol)? 165 | .persistent_keepalive_interval = 166 | Some(v.parse().map_err(|_| Error::InvalidProtocol)?); 167 | } 168 | "replace_allowed_ips" => { 169 | if v != "true" { 170 | return Err(Error::InvalidProtocol); 171 | } 172 | set_device 173 | .peers 174 | .last_mut() 175 | .ok_or(Error::InvalidProtocol)? 176 | .replace_allowed_ips = true; 177 | } 178 | "allowed_ip" => { 179 | set_device 180 | .peers 181 | .last_mut() 182 | .ok_or(Error::InvalidProtocol)? 
183 | .allowed_ips 184 | .insert(v.parse().map_err(|_| Error::InvalidProtocol)?); 185 | } 186 | _ => return Err(Error::InvalidProtocol), 187 | } 188 | } 189 | 190 | Ok(set_device) 191 | } 192 | 193 | #[cfg(test)] 194 | mod tests { 195 | use super::*; 196 | 197 | #[test] 198 | #[allow(clippy::too_many_lines)] 199 | fn test_parse_set_request() { 200 | let rv = parse_set_request( 201 | "private_key=e84b5a6d2717c1003a13b431570353dbaca9146cf150c5f8575680feba52027a 202 | fwmark=0 203 | listen_port=12912 204 | replace_peers=true 205 | public_key=b85996fecc9c7f1fc6d2572a76eda11d59bcd20be8e543b15ce4bd85a8e75a33 206 | preshared_key=188515093e952f5f22e865cef3012e72f8b5f0b598ac0309d5dacce3b70fcf52 207 | replace_allowed_ips=true 208 | allowed_ip=192.168.4.4/32 209 | endpoint=[abcd:23::33%2]:51820 210 | public_key=58402e695ba1772b1cc9309755f043251ea77fdcf10fbe63989ceb7e19321376 211 | replace_allowed_ips=true 212 | allowed_ip=192.168.4.6/32 213 | persistent_keepalive_interval=111 214 | endpoint=182.122.22.19:3233 215 | public_key=662e14fd594556f522604703340351258903b64f35553763f19426ab2a515c58 216 | endpoint=5.152.198.39:51820 217 | replace_allowed_ips=true 218 | allowed_ip=192.168.4.10/32 219 | allowed_ip=192.168.4.11/32 220 | public_key=e818b58db5274087fcc1be5dc728cf53d3b5726b4cef6b9bab8f8f8c2452c25c 221 | remove=true", 222 | ); 223 | 224 | assert!(rv.is_ok()); 225 | let rv = rv.unwrap(); 226 | assert_eq!( 227 | rv, 228 | SetDevice { 229 | private_key: Some( 230 | crypto::decode_from_hex( 231 | "e84b5a6d2717c1003a13b431570353dbaca9146cf150c5f8575680feba52027a" 232 | ) 233 | .try_into() 234 | .unwrap() 235 | ), 236 | listen_port: Some(12912), 237 | fwmark: Some(0), 238 | replace_peers: true, 239 | peers: vec![ 240 | SetPeer { 241 | public_key: crypto::decode_from_hex( 242 | "b85996fecc9c7f1fc6d2572a76eda11d59bcd20be8e543b15ce4bd85a8e75a33" 243 | ) 244 | .try_into() 245 | .unwrap(), 246 | remove: false, 247 | update_only: false, 248 | psk: Some( 249 | crypto::decode_from_hex( 250 
| "188515093e952f5f22e865cef3012e72f8b5f0b598ac0309d5dacce3b70fcf52" 251 | ) 252 | .try_into() 253 | .unwrap() 254 | ), 255 | endpoint: Some("[abcd:23::33%2]:51820".parse().unwrap()), 256 | persistent_keepalive_interval: None, 257 | replace_allowed_ips: true, 258 | allowed_ips: ["192.168.4.4/32".parse().unwrap()].into_iter().collect(), 259 | }, 260 | SetPeer { 261 | public_key: crypto::decode_from_hex( 262 | "58402e695ba1772b1cc9309755f043251ea77fdcf10fbe63989ceb7e19321376" 263 | ) 264 | .try_into() 265 | .unwrap(), 266 | remove: false, 267 | update_only: false, 268 | psk: None, 269 | endpoint: Some("182.122.22.19:3233".parse().unwrap()), 270 | persistent_keepalive_interval: Some(111), 271 | replace_allowed_ips: true, 272 | allowed_ips: ["192.168.4.6/32".parse().unwrap()].into_iter().collect(), 273 | }, 274 | SetPeer { 275 | public_key: crypto::decode_from_hex( 276 | "662e14fd594556f522604703340351258903b64f35553763f19426ab2a515c58" 277 | ) 278 | .try_into() 279 | .unwrap(), 280 | remove: false, 281 | update_only: false, 282 | psk: None, 283 | endpoint: Some("5.152.198.39:51820".parse().unwrap()), 284 | persistent_keepalive_interval: None, 285 | replace_allowed_ips: true, 286 | allowed_ips: [ 287 | "192.168.4.10/32".parse().unwrap(), 288 | "192.168.4.11/32".parse().unwrap(), 289 | ] 290 | .into_iter() 291 | .collect(), 292 | }, 293 | SetPeer { 294 | public_key: crypto::decode_from_hex( 295 | "e818b58db5274087fcc1be5dc728cf53d3b5726b4cef6b9bab8f8f8c2452c25c" 296 | ) 297 | .try_into() 298 | .unwrap(), 299 | remove: true, 300 | update_only: false, 301 | psk: None, 302 | endpoint: None, 303 | persistent_keepalive_interval: None, 304 | replace_allowed_ips: false, 305 | allowed_ips: [].into_iter().collect(), 306 | } 307 | ], 308 | } 309 | ) 310 | } 311 | } 312 | -------------------------------------------------------------------------------- /src/uapi/error.rs: -------------------------------------------------------------------------------- 1 | #[derive(thiserror::Error, 
Debug)] 2 | pub enum Error { 3 | #[error("invalid protocol")] 4 | InvalidProtocol, 5 | #[error("invalid configuration: {0}")] 6 | InvalidConfiguration(String), 7 | #[error("IO error: {0}")] 8 | IO(#[from] std::io::Error), 9 | } 10 | -------------------------------------------------------------------------------- /src/uapi/mod.rs: -------------------------------------------------------------------------------- 1 | mod connection; 2 | mod error; 3 | mod protocol; 4 | 5 | pub use error::Error; 6 | 7 | use connection::Connection; 8 | use protocol::{GetDevice, GetPeer, Request, Response, SetDevice, SetPeer}; 9 | 10 | use std::path::{Path, PathBuf}; 11 | use std::time::Duration; 12 | 13 | use tokio::net::UnixListener; 14 | use tracing::{debug, error}; 15 | 16 | use crate::device::Transport; 17 | use crate::{DeviceControl, PeerConfig, Tun}; 18 | 19 | const SOCKET_DIR: &str = "/var/run/wireguard"; 20 | 21 | fn socket_path(iface: &str) -> PathBuf { 22 | Path::new(SOCKET_DIR).join(format!("{}.sock", iface)) 23 | } 24 | 25 | pub async fn bind_and_handle(device: DeviceControl) -> Result<(), Error> 26 | where 27 | T: Tun + 'static, 28 | I: Transport, 29 | { 30 | let listener = { 31 | let path = socket_path(device.tun_name()); 32 | debug!("binding uapi unix socket to {:?}", path); 33 | let _ = std::fs::remove_file(&path); // Remove existing socket 34 | let _ = std::fs::create_dir_all(path.parent().unwrap()); // Create socket dir 35 | UnixListener::bind(&path)? 
}; 37 | 38 | loop { 39 | let (socket, _) = listener.accept().await?; 40 | let conn = Connection::new(socket); 41 | let device = device.clone(); 42 | tokio::spawn(handle_connection(conn, device)); 43 | } 44 | } 45 | 46 | async fn handle_connection(mut conn: Connection, device: DeviceControl) 47 | where 48 | T: Tun + 'static, 49 | I: Transport, 50 | { 51 | debug!("UAPI: accepting new connection"); 52 | 53 | loop { 54 | match conn.next().await { 55 | Ok(Request::Get) => match handle_get(device.clone()).await { 56 | Ok(resp) => conn.write(resp).await, 57 | Err(e) => { 58 | error!("Failed to handle get operation: {}", e); 59 | conn.write(Response::Err).await; 60 | } 61 | }, 62 | Ok(Request::Set(req)) => match handle_set(device.clone(), req).await { 63 | Ok(()) => { 64 | conn.write(Response::Ok).await; 65 | } 66 | Err(e) => { 67 | error!("Failed to handle set operation: {}", e); 68 | conn.write(Response::Err).await; 69 | } 70 | }, 71 | Err(e) => { 72 | debug!("UAPI connection error: {}", e); 73 | conn.write(Response::Err).await; 74 | break; 75 | } 76 | } 77 | } 78 | } 79 | 80 | async fn handle_get(device: DeviceControl) -> Result 81 | where 82 | T: Tun + 'static, 83 | I: Transport, 84 | { 85 | debug!("UAPI: received GET request"); 86 | let cfg = device.config(); 87 | let mut metrics = device.metrics(); 88 | let peers = cfg 89 | .peers 90 | .into_values() 91 | .filter_map(|p| { 92 | let m = metrics.peers.remove(&p.public_key)?; // config and metrics are separate snapshots; a peer changed in between is skipped instead of panicking the connection task 93 | Some(GetPeer { 94 | public_key: p.public_key, 95 | psk: p.preshared_key.unwrap_or_default(), 96 | allowed_ips: p.allowed_ips, 97 | endpoint: p.endpoint, 98 | last_handshake_at: m.last_handshake_at, 99 | tx_bytes: m.tx_bytes, 100 | rx_bytes: m.rx_bytes, 101 | persistent_keepalive_interval: p 102 | .persistent_keepalive 103 | .map(|v| v.as_secs() as u32) 104 | .unwrap_or(0), 105 | }) 106 | }) 107 | .collect(); 108 | 109 | Ok(Response::Get(GetDevice { 110 | private_key: cfg.private_key, 111 | listen_port: cfg.listen_port, 112 | fwmark: 0, 113 |
peers, 114 | })) 115 | } 116 | 117 | async fn handle_set(device: DeviceControl, req: SetDevice) -> Result<(), Error> 118 | where 119 | T: Tun + 'static, 120 | I: Transport, 121 | { 122 | debug!("UAPI: received SET request"); 123 | if req.replace_peers { 124 | device.clear_peers(); 125 | } 126 | if let Some(private_key) = req.private_key { 127 | device.update_private_key(private_key); 128 | } 129 | if let Some(port) = req.listen_port { 130 | device.update_listen_port(port).await.map_err(|e| { 131 | error!("Failed to update listen_port: {}", e); 132 | Error::InvalidConfiguration(e.to_string()) 133 | })?; 134 | } 135 | if let Some(_fwmark) = req.fwmark { 136 | // unsupported: fwmark is parsed but currently ignored 137 | } 138 | 139 | let cfg = device.config(); 140 | for peer in req.peers { 141 | if peer.remove { 142 | device.remove_peer(&peer.public_key); 143 | continue; // only this peer is removed; keep applying the remaining peers in the request 144 | } 145 | match cfg.peers.get(&peer.public_key).cloned() { 146 | Some(mut cfg) => { 147 | // to update 148 | if let Some(endpoint) = peer.endpoint { 149 | cfg.endpoint = Some(endpoint); 150 | } 151 | if peer.replace_allowed_ips { 152 | cfg.allowed_ips.clear(); 153 | } 154 | for ip in peer.allowed_ips { 155 | cfg.allowed_ips.insert(ip); 156 | } 157 | if let Some(psk) = peer.psk { 158 | cfg.preshared_key = Some(psk); 159 | } 160 | if let Some(interval) = peer.persistent_keepalive_interval { 161 | cfg.persistent_keepalive = Some(Duration::from_secs(interval as u64)); 162 | } 163 | 164 | device.remove_peer(&peer.public_key); 165 | device.insert_peer(cfg); 166 | } 167 | None if !peer.update_only => { 168 | device.insert_peer(PeerConfig { 169 | public_key: peer.public_key, 170 | allowed_ips: peer.allowed_ips, 171 | endpoint: peer.endpoint, 172 | preshared_key: peer.psk, 173 | persistent_keepalive: peer 174 | .persistent_keepalive_interval 175 | .map(|v| Duration::from_secs(v as u64)), 176 | }); 177 | } 178 | _ => {} 179 | } 180 | } 181 | 182 | Ok(()) 183 | } 184 | 185 | #[cfg(test)] 186 | mod tests { 187 | use super::*; 188 | 189 | #[test] 190 |
fn test_socket_path() { 191 | assert_eq!( 192 | socket_path("wg0").to_string_lossy().as_ref(), 193 | "/var/run/wireguard/wg0.sock", 194 | ) 195 | } 196 | } 197 | -------------------------------------------------------------------------------- /src/uapi/protocol.rs: -------------------------------------------------------------------------------- 1 | use std::collections::HashSet; 2 | use std::net::SocketAddr; 3 | use std::time::SystemTime; 4 | 5 | use bytes::{BufMut, Bytes, BytesMut}; 6 | 7 | use crate::noise::crypto; 8 | use crate::Cidr; 9 | 10 | pub enum Request { 11 | Get, 12 | Set(SetDevice), 13 | } 14 | 15 | pub enum Response { 16 | Ok, 17 | Get(GetDevice), 18 | Err, 19 | } 20 | 21 | #[derive(Debug, Eq, PartialEq)] 22 | pub struct SetDevice { 23 | pub private_key: Option<[u8; 32]>, 24 | pub listen_port: Option, 25 | pub fwmark: Option, 26 | pub replace_peers: bool, 27 | pub peers: Vec, 28 | } 29 | 30 | #[derive(Debug, Eq, PartialEq)] 31 | pub struct SetPeer { 32 | pub public_key: [u8; 32], 33 | pub remove: bool, 34 | pub update_only: bool, 35 | pub psk: Option<[u8; 32]>, 36 | pub endpoint: Option, 37 | pub persistent_keepalive_interval: Option, 38 | pub replace_allowed_ips: bool, 39 | pub allowed_ips: HashSet, 40 | } 41 | 42 | pub struct GetDevice { 43 | pub private_key: [u8; 32], 44 | pub listen_port: u16, 45 | pub fwmark: u32, 46 | pub peers: Vec, 47 | } 48 | 49 | pub struct GetPeer { 50 | pub public_key: [u8; 32], 51 | pub psk: [u8; 32], 52 | pub allowed_ips: HashSet, 53 | pub endpoint: Option, 54 | pub last_handshake_at: SystemTime, 55 | pub tx_bytes: u64, 56 | pub rx_bytes: u64, 57 | pub persistent_keepalive_interval: u32, 58 | } 59 | 60 | impl From for Bytes { 61 | fn from(value: GetDevice) -> Self { 62 | let mut buf = KVBuffer::new(); 63 | if value.private_key != [0u8; 32] { 64 | buf.encode_and_put("private_key", &value.private_key); 65 | } 66 | buf.put_u16("listen_port", value.listen_port); 67 | 68 | if value.fwmark != 0 { 69 | buf.put_u32("fwmark", 
value.fwmark); 70 | } 71 | 72 | for peer in value.peers { 73 | buf.encode_and_put("public_key", &peer.public_key); 74 | buf.encode_and_put("preshared_key", &peer.psk); 75 | for ip in peer.allowed_ips { 76 | buf.put("allowed_ip", &ip.to_string()); 77 | } 78 | if let Some(endpoint) = peer.endpoint { 79 | buf.put("endpoint", &endpoint.to_string()); 80 | } 81 | let d = peer 82 | .last_handshake_at 83 | .duration_since(SystemTime::UNIX_EPOCH) 84 | .unwrap(); 85 | buf.put_u64("last_handshake_time_sec", d.as_secs()); 86 | buf.put_u32("last_handshake_time_nsec", d.subsec_nanos()); 87 | buf.put_u64("tx_bytes", peer.tx_bytes); 88 | buf.put_u64("rx_bytes", peer.rx_bytes); 89 | buf.put_u32( 90 | "persistent_keepalive_interval", 91 | peer.persistent_keepalive_interval, 92 | ); 93 | } 94 | buf.put_u32("protocol_version", 0); 95 | buf.put_u32("errno", 0); 96 | buf.freeze() 97 | } 98 | } 99 | 100 | struct KVBuffer(BytesMut); 101 | 102 | impl KVBuffer { 103 | pub fn new() -> Self { 104 | KVBuffer(BytesMut::new()) 105 | } 106 | 107 | #[inline] 108 | pub fn put(&mut self, key: &str, value: &str) { 109 | self.0.put(format!("{}={}\n", key, value).as_bytes()); 110 | } 111 | 112 | #[inline] 113 | pub fn put_u16(&mut self, key: &str, value: u16) { 114 | self.0.put(format!("{}={}\n", key, value).as_bytes()); 115 | } 116 | 117 | #[inline] 118 | pub fn put_u32(&mut self, key: &str, value: u32) { 119 | self.0.put(format!("{}={}\n", key, value).as_bytes()); 120 | } 121 | 122 | #[inline] 123 | pub fn put_u64(&mut self, key: &str, value: u64) { 124 | self.0.put(format!("{}={}\n", key, value).as_bytes()); 125 | } 126 | 127 | #[inline] 128 | pub fn encode_and_put(&mut self, key: &str, value: &[u8]) { 129 | self.put(key, &crypto::encode_to_hex(value)); 130 | } 131 | 132 | #[inline] 133 | pub fn freeze(mut self) -> Bytes { 134 | self.0.put_slice(b"\n"); 135 | self.0.freeze() 136 | } 137 | } 138 | -------------------------------------------------------------------------------- /tests/handshake.rs: 
-------------------------------------------------------------------------------- 1 | mod support; 2 | 3 | use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr}; 4 | use std::time::Duration; 5 | 6 | use tokio::time; 7 | 8 | use support::*; 9 | use wiretun::noise::protocol::{HandshakeInitiation, TransportData}; 10 | use wiretun::*; 11 | 12 | #[tokio::test] 13 | async fn test_noop_when_no_endpoint() { 14 | let secret = TestKit::gen_local_secret(); 15 | let tun = StubTun::new(); 16 | let transport = StubTransport::bind(Ipv4Addr::UNSPECIFIED, Ipv6Addr::UNSPECIFIED, 0) 17 | .await 18 | .unwrap(); 19 | let cfg = DeviceConfig::default() 20 | .private_key(secret.private_key().to_bytes()) 21 | .peer( 22 | PeerConfig::default().public_key(TestKit::gen_local_secret().public_key().to_bytes()), 23 | ); 24 | let device = Device::with_transport(tun.clone(), transport.clone(), cfg) 25 | .await 26 | .unwrap(); 27 | 28 | let _ctrl = device.control(); 29 | 30 | time::sleep(Duration::from_secs(30)).await; 31 | 32 | assert_eq!(transport.inbound_sent(), 0); 33 | assert_eq!(transport.outbound_sent(), 0); 34 | 35 | assert_eq!(tun.inbound_sent(), 0); 36 | assert_eq!(tun.outbound_sent(), 0); 37 | } 38 | 39 | #[tokio::test] 40 | async fn test_keep_initiation_when_no_response() { 41 | let secret = TestKit::gen_local_secret(); 42 | let tun = StubTun::new(); 43 | let transport = StubTransport::bind(Ipv4Addr::UNSPECIFIED, Ipv6Addr::UNSPECIFIED, 0) 44 | .await 45 | .unwrap(); 46 | let peer_pub = TestKit::gen_local_secret().public_key().to_bytes(); 47 | let peer_endpoint = "10.0.0.1:80".parse().unwrap(); 48 | let cfg = DeviceConfig::default() 49 | .private_key(secret.private_key().to_bytes()) 50 | .peer( 51 | PeerConfig::default() 52 | .public_key(peer_pub) 53 | .endpoint(peer_endpoint), 54 | ); 55 | let device = Device::with_transport(tun.clone(), transport.clone(), cfg) 56 | .await 57 | .unwrap(); 58 | 59 | let _ctrl = device.control(); 60 | 61 | time::sleep(Duration::from_secs(30)).await; 62 | 63 | 
assert_eq!(transport.inbound_sent(), 0); 64 | assert!(transport.outbound_sent() > 0); 65 | 66 | assert_eq!(tun.inbound_sent(), 0); 67 | assert_eq!(tun.outbound_sent(), 0); 68 | 69 | for _ in 0..transport.outbound_sent() { 70 | let (endpoint, data) = transport.fetch_outbound().await; 71 | assert_eq!(endpoint.dst(), peer_endpoint); 72 | let ret = HandshakeInitiation::try_from(data.as_slice()); 73 | assert!(ret.is_ok()); 74 | } 75 | } 76 | 77 | #[tokio::test] 78 | async fn test_complete_handshake() { 79 | use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; 80 | tracing_subscriber::registry() 81 | .with(tracing_subscriber::EnvFilter::new( 82 | std::env::var("RUST_LOG").unwrap_or_else(|_| "debug".into()), 83 | )) 84 | .with(tracing_subscriber::fmt::layer()) 85 | .init(); 86 | 87 | let secret1 = TestKit::gen_local_secret(); 88 | let endpoint1 = "10.10.0.1:6789".parse::().unwrap(); 89 | let endpoint2 = "10.10.0.2:1245".parse::().unwrap(); 90 | let secret2 = TestKit::gen_local_secret(); 91 | let (_device1, tun1, transport1) = { 92 | let tun = StubTun::new(); 93 | let transport = StubTransport::bind(Ipv4Addr::UNSPECIFIED, Ipv6Addr::UNSPECIFIED, 0) 94 | .await 95 | .unwrap(); 96 | let cfg = DeviceConfig::default() 97 | .private_key(secret1.private_key().to_bytes()) 98 | .peer( 99 | PeerConfig::default() 100 | .public_key(secret2.public_key().to_bytes()) 101 | .allowed_ip(endpoint2.ip()) 102 | .endpoint(endpoint2), 103 | ); 104 | let device = Device::with_transport(tun.clone(), transport.clone(), cfg) 105 | .await 106 | .unwrap(); 107 | (device, tun, transport) 108 | }; 109 | let (_device2, tun2, transport2) = { 110 | let tun = StubTun::new(); 111 | let transport = StubTransport::bind(Ipv4Addr::UNSPECIFIED, Ipv6Addr::UNSPECIFIED, 0) 112 | .await 113 | .unwrap(); 114 | let cfg = DeviceConfig::default() 115 | .private_key(secret2.private_key().to_bytes()) 116 | .peer( 117 | PeerConfig::default() 118 | .public_key(secret1.public_key().to_bytes()) 119 | 
.allowed_ip(endpoint1.ip()), 120 | ); 121 | let device = Device::with_transport(tun.clone(), transport.clone(), cfg) 122 | .await 123 | .unwrap(); 124 | (device, tun, transport) 125 | }; 126 | 127 | { 128 | let (t1, t2) = (transport1.clone(), transport2.clone()); 129 | tokio::spawn(async move { 130 | loop { 131 | let (endpoint, data) = t1.fetch_outbound().await; 132 | assert_eq!(endpoint.dst(), endpoint2); 133 | let endpoint = Endpoint::new(t2.clone(), endpoint1); 134 | t2.send_inbound(&data, &endpoint).await; 135 | } 136 | }); 137 | let (t1, t2) = (transport1.clone(), transport2.clone()); 138 | tokio::spawn(async move { 139 | loop { 140 | let (endpoint, data) = t2.fetch_outbound().await; 141 | assert_eq!(endpoint.dst(), endpoint1); 142 | let endpoint = Endpoint::new(t1.clone(), endpoint2); 143 | t1.send_inbound(&data, &endpoint).await; 144 | } 145 | }); 146 | } 147 | 148 | time::sleep(Duration::from_secs(30)).await; 149 | assert_eq!(tun1.inbound_sent(), 0); 150 | assert_eq!(tun1.outbound_sent(), 0); 151 | assert_eq!(tun2.inbound_sent(), 0); 152 | assert_eq!(tun2.outbound_sent(), 0); 153 | 154 | assert!(transport1.inbound_sent() > 0); 155 | assert!(transport1.outbound_sent() > 0); 156 | assert!(transport2.inbound_sent() > 0); 157 | assert!(transport2.outbound_sent() > 0); 158 | 159 | let (mut d1_completed, mut d2_completed) = (false, false); 160 | 161 | for (_, data) in transport1.outbound_recording() { 162 | if TransportData::try_from(data.as_slice()).is_ok() { 163 | d1_completed = true; 164 | break; 165 | } 166 | } 167 | 168 | for (_, data) in transport2.outbound_recording() { 169 | if TransportData::try_from(data.as_slice()).is_ok() { 170 | d2_completed = true; 171 | break; 172 | } 173 | } 174 | 175 | assert!(d1_completed); 176 | assert!(!d2_completed); 177 | } 178 | -------------------------------------------------------------------------------- /tests/support.rs: -------------------------------------------------------------------------------- 1 | 
#![allow(unused)] 2 | 3 | use std::collections::HashMap; 4 | use std::fmt::{Display, Formatter}; 5 | use std::io; 6 | use std::net::{Ipv4Addr, Ipv6Addr}; 7 | use std::sync::atomic::{AtomicU64, Ordering}; 8 | use std::sync::{Arc, Mutex as StdMutex}; 9 | 10 | use async_trait::async_trait; 11 | use rand_core::OsRng; 12 | use tokio::sync::{mpsc, Mutex}; 13 | 14 | use wiretun::noise::crypto::LocalStaticSecret; 15 | use wiretun::*; 16 | 17 | pub struct TestKit {} 18 | 19 | impl TestKit { 20 | #[inline(always)] 21 | pub fn gen_local_secret() -> LocalStaticSecret { 22 | let pri = x25519_dalek::StaticSecret::random_from_rng(OsRng).to_bytes(); 23 | LocalStaticSecret::new(pri) 24 | } 25 | } 26 | 27 | #[derive(Clone)] 28 | pub struct StubTun { 29 | inbound_sent: Arc, 30 | inbound_recording: Arc>>>, 31 | outbound_sent: Arc, 32 | outbound_recording: Arc>>>, 33 | 34 | outbound_tx: mpsc::Sender>, 35 | outbound_rx: Arc>>>, 36 | inbound_tx: mpsc::Sender>, 37 | inbound_rx: Arc>>>, 38 | } 39 | 40 | impl StubTun { 41 | pub fn new() -> Self { 42 | let (inbound_tx, inbound_rx) = mpsc::channel(256); 43 | let (outbound_tx, outbound_rx) = mpsc::channel(256); 44 | Self { 45 | inbound_sent: Arc::new(AtomicU64::new(0)), 46 | inbound_recording: Arc::new(StdMutex::new(vec![])), 47 | outbound_sent: Arc::new(AtomicU64::new(0)), 48 | outbound_recording: Arc::new(StdMutex::new(vec![])), 49 | 50 | outbound_tx, 51 | outbound_rx: Arc::new(Mutex::new(outbound_rx)), 52 | inbound_tx, 53 | inbound_rx: Arc::new(Mutex::new(inbound_rx)), 54 | } 55 | } 56 | 57 | #[inline(always)] 58 | pub fn inbound_sent(&self) -> u64 { 59 | self.outbound_sent.load(Ordering::Relaxed) 60 | } 61 | 62 | #[inline(always)] 63 | pub fn outbound_sent(&self) -> u64 { 64 | self.outbound_sent.load(Ordering::Relaxed) 65 | } 66 | 67 | #[inline(always)] 68 | pub fn inbound_recording(&self) -> Vec> { 69 | self.inbound_recording.lock().unwrap().clone() 70 | } 71 | 72 | #[inline(always)] 73 | pub fn outbound_recording(&self) -> Vec> { 74 | 
self.outbound_recording.lock().unwrap().clone() 75 | } 76 | 77 | pub async fn send_outbound(&self, data: &[u8]) { 78 | self.outbound_sent.fetch_add(1, Ordering::Relaxed); 79 | self.outbound_recording.lock().unwrap().push(data.to_vec()); 80 | self.outbound_tx.send(data.to_vec()).await.unwrap(); 81 | } 82 | 83 | pub async fn fetch_inbound(&self) -> Vec { 84 | let mut rx = self.inbound_rx.lock().await; 85 | rx.recv().await.unwrap() 86 | } 87 | } 88 | 89 | impl Default for StubTun { 90 | fn default() -> Self { 91 | Self::new() 92 | } 93 | } 94 | 95 | #[async_trait] 96 | impl Tun for StubTun { 97 | fn name(&self) -> &str { 98 | "stub" 99 | } 100 | 101 | fn mtu(&self) -> Result { 102 | Ok(1500) 103 | } 104 | 105 | fn set_mtu(&self, _mtu: u16) -> Result<(), TunError> { 106 | Ok(()) 107 | } 108 | 109 | async fn recv(&self) -> Result, TunError> { 110 | let mut rx = self.outbound_rx.lock().await; 111 | let data = rx.recv().await.unwrap(); 112 | Ok(data) 113 | } 114 | 115 | async fn send(&self, buf: &[u8]) -> Result<(), TunError> { 116 | self.inbound_sent.fetch_add(1, Ordering::Relaxed); 117 | self.inbound_recording.lock().unwrap().push(buf.to_vec()); 118 | self.inbound_tx.send(buf.to_vec()).await.unwrap(); 119 | Ok(()) 120 | } 121 | } 122 | 123 | type TransportPacket = (Endpoint, Vec); 124 | 125 | #[derive(Clone)] 126 | pub struct StubTransport { 127 | ipv4: Ipv4Addr, 128 | ipv6: Ipv6Addr, 129 | port: u16, 130 | 131 | inbound_sent: Arc, 132 | inbound_recording: Arc>>, 133 | outbound_sent: Arc, 134 | outbound_recording: Arc>>, 135 | 136 | outbound_tx: mpsc::Sender, 137 | outbound_rx: Arc>>, 138 | inbound_tx: mpsc::Sender, 139 | inbound_rx: Arc>>, 140 | } 141 | 142 | impl StubTransport { 143 | #[inline(always)] 144 | pub fn inbound_sent(&self) -> u64 { 145 | self.inbound_sent.load(Ordering::Relaxed) 146 | } 147 | 148 | #[inline(always)] 149 | pub fn outbound_sent(&self) -> u64 { 150 | self.outbound_sent.load(Ordering::Relaxed) 151 | } 152 | 153 | #[inline(always)] 154 | pub fn 
inbound_recording(&self) -> Vec<(Endpoint, Vec)> { 155 | self.inbound_recording.lock().unwrap().clone() 156 | } 157 | 158 | #[inline(always)] 159 | pub fn outbound_recording(&self) -> Vec<(Endpoint, Vec)> { 160 | self.outbound_recording.lock().unwrap().clone() 161 | } 162 | 163 | pub async fn send_inbound(&self, data: &[u8], endpoint: &Endpoint) { 164 | self.inbound_sent.fetch_add(1, Ordering::Relaxed); 165 | self.inbound_recording 166 | .lock() 167 | .unwrap() 168 | .push((endpoint.clone(), data.to_vec())); 169 | self.inbound_tx 170 | .send((endpoint.clone(), data.to_vec())) 171 | .await 172 | .unwrap(); 173 | } 174 | 175 | pub async fn fetch_outbound(&self) -> (Endpoint, Vec) { 176 | let mut rx = self.outbound_rx.lock().await; 177 | rx.recv().await.unwrap() 178 | } 179 | } 180 | 181 | impl Display for StubTransport { 182 | fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { 183 | write!(f, "StubTransport") 184 | } 185 | } 186 | 187 | #[async_trait] 188 | impl Transport for StubTransport { 189 | async fn bind(ipv4: Ipv4Addr, ipv6: Ipv6Addr, port: u16) -> Result { 190 | let (inbound_tx, inbound_rx) = mpsc::channel(256); 191 | let (outbound_tx, outbound_rx) = mpsc::channel(256); 192 | Ok(Self { 193 | ipv4, 194 | ipv6, 195 | port, 196 | 197 | inbound_sent: Arc::new(AtomicU64::new(0)), 198 | inbound_recording: Arc::new(StdMutex::new(vec![])), 199 | outbound_sent: Arc::new(AtomicU64::new(0)), 200 | outbound_recording: Arc::new(StdMutex::new(vec![])), 201 | 202 | outbound_tx, 203 | outbound_rx: Arc::new(Mutex::new(outbound_rx)), 204 | inbound_tx, 205 | inbound_rx: Arc::new(Mutex::new(inbound_rx)), 206 | }) 207 | } 208 | 209 | fn ipv4(&self) -> Ipv4Addr { 210 | self.ipv4 211 | } 212 | 213 | fn ipv6(&self) -> Ipv6Addr { 214 | self.ipv6 215 | } 216 | 217 | fn port(&self) -> u16 { 218 | self.port 219 | } 220 | 221 | async fn send_to(&self, data: &[u8], endpoint: &Endpoint) -> Result<(), io::Error> { 222 | self.outbound_sent.fetch_add(1, Ordering::Relaxed); 223 | 
self.outbound_recording 224 | .lock() 225 | .unwrap() 226 | .push((endpoint.clone(), data.to_vec())); 227 | self.outbound_tx 228 | .send((endpoint.clone(), data.to_vec())) 229 | .await 230 | .unwrap(); 231 | Ok(()) 232 | } 233 | 234 | async fn recv_from(&mut self) -> Result<(Endpoint, Vec), io::Error> { 235 | let rv = self.inbound_rx.lock().await.recv().await.unwrap(); 236 | Ok(rv) 237 | } 238 | } 239 | --------------------------------------------------------------------------------