├── .gitignore ├── Cargo.lock ├── Cargo.toml ├── LICENSE ├── README.md ├── architecture.svg ├── netns.sh └── src ├── configuration ├── config.rs ├── error.rs ├── mod.rs └── uapi │ ├── get.rs │ ├── mod.rs │ └── set.rs ├── main.rs ├── platform ├── dummy │ ├── endpoint.rs │ ├── mod.rs │ ├── tun │ │ ├── dummy.rs │ │ ├── mod.rs │ │ └── void.rs │ └── udp.rs ├── endpoint.rs ├── linux │ ├── mod.rs │ ├── tun.rs │ ├── uapi.rs │ └── udp.rs ├── mod.rs ├── tun.rs ├── uapi.rs └── udp.rs ├── util.rs └── wireguard ├── constants.rs ├── handshake ├── device.rs ├── macs.rs ├── messages.rs ├── mod.rs ├── noise.rs ├── peer.rs ├── ratelimiter.rs ├── tests.rs ├── timestamp.rs └── types.rs ├── mod.rs ├── peer.rs ├── queue.rs ├── router ├── anti_replay.rs ├── constants.rs ├── device.rs ├── ip.rs ├── messages.rs ├── mod.rs ├── peer.rs ├── queue.rs ├── receive.rs ├── route.rs ├── send.rs ├── tests │ ├── bench.rs │ ├── mod.rs │ └── tests.rs ├── types.rs └── worker.rs ├── tests.rs ├── timers.rs ├── types.rs ├── wireguard.rs └── workers.rs /.gitignore: -------------------------------------------------------------------------------- 1 | /target 2 | **/*.rs.bk 3 | proptest-regressions/ 4 | Cargo.lock 5 | .idea/ -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | authors = ["Mathias Hall-Andersen "] 3 | edition = "2018" 4 | license = "MIT" 5 | name = "wireguard-rs" 6 | version = "0.1.4" 7 | 8 | [dependencies] 9 | aead = "^0.3" 10 | arraydeque = "0.4.5" 11 | blake2 = "^0.9" 12 | byteorder = "1.3" 13 | chacha20poly1305 = "^0.7" 14 | clear_on_drop = "0.2.3" 15 | cpuprofiler = {version = "*", optional = true} 16 | crossbeam-channel = "^0.5" 17 | dashmap = "^4.0" 18 | digest = "^0.9" 19 | env_logger = "^0.8" 20 | generic-array = "^0.14" 21 | hex = "0.4" 22 | hjul = "0.2.2" 23 | hmac = "^0.10" 24 | log = {version = "0.4", features = ["max_level_trace", 
"release_max_level_info"]} 25 | num_cpus = "^1.10" 26 | parking_lot = "^0.11" 27 | rand = "^0.7" 28 | rand_core = "^0.5" 29 | ring = "0.16" 30 | spin = "0.7" 31 | zerocopy = "0.3" 32 | 33 | [dependencies.treebitmap] 34 | package = "ip_network_table-deps-treebitmap" 35 | version = "0.5.0" 36 | 37 | [target.'cfg(unix)'.dependencies] 38 | libc = "^0.2" 39 | 40 | [dependencies.x25519-dalek] 41 | version = "^1.1" 42 | 43 | [dependencies.subtle] 44 | version = "^2.4" 45 | #features = ["nightly"] 46 | 47 | [features] 48 | profiler = ["cpuprofiler"] 49 | start_up = [] 50 | 51 | [dev-dependencies] 52 | pnet = "^0.27" 53 | proptest = "^0.10" 54 | rand_chacha = "^0.2" 55 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Mathias Hall-Andersen 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Rust Implementation of WireGuard 2 | 3 | ## Usage 4 | 5 | Most Linux kernel WireGuard users are used to adding an interface with `ip link add wg0 type wireguard`. 6 | With wireguard-rs, instead simply run: 7 | 8 | $ wireguard-rs wg0 9 | 10 | This will create an interface and fork into the background. To remove the interface, use the usual `ip link del wg0`, 11 | or if your system does not support removing interfaces directly, you may instead remove the control socket via 12 | `rm -f /var/run/wireguard/wg0.sock`, which will result in wireguard-rs shutting down. 13 | 14 | When an interface is running, you may use `wg(8)` to configure it, as well as the usual `ip(8)` and `ifconfig(8)` commands. 15 | 16 | ## Platforms 17 | 18 | ### Linux 19 | 20 | This will run on Linux; 21 | however YOU SHOULD NOT RUN THIS ON LINUX. Instead use the kernel module; see the installation page for instructions. 22 | 23 | ### Windows 24 | 25 | Coming soon. 26 | 27 | ### FreeBSD 28 | 29 | Coming soon. 30 | 31 | ### OpenBSD 32 | 33 | Coming soon. 34 | 35 | ## Building 36 | 37 | The wireguard-rs project is targeting the current nightly (although it should also build with stable Rust). 38 | 39 | To build wireguard-rs (on supported platforms): 40 | 41 | 1. Obtain nightly `cargo` and `rustc` through [rustup](https://rustup.rs/) 42 | 2. Clone the repository: `git clone https://git.zx2c4.com/wireguard-rs`. 43 | 3. Run `cargo build --release` from inside the `wireguard-rs` directory. 
44 | 45 | ## Architecture 46 | 47 | This section is intended for those wishing to read/contribute to the code. 48 | 49 | WireGuard Rust has a separation of concerns similar to that of many other implementations of cryptographic transports: 50 | separating the handshake code from the packet protector. 51 | The handshake module implements an authenticated key exchange (NoiseIK), 52 | which provides key material that is then consumed by the router module (packet protector), 53 | which is responsible for the actual encapsulation of transport messages (IP packets). 54 | This is illustrated below: 55 | 56 | ![Structure](architecture.svg) 57 | -------------------------------------------------------------------------------- /src/configuration/config.rs: -------------------------------------------------------------------------------- 1 | use std::mem; 2 | use std::net::{IpAddr, SocketAddr}; 3 | use std::sync::atomic::Ordering; 4 | use std::sync::{Arc, Mutex, MutexGuard}; 5 | use std::time::{Duration, SystemTime}; 6 | 7 | use x25519_dalek::{PublicKey, StaticSecret}; 8 | 9 | use super::udp::Owner; 10 | use super::*; 11 | 12 | // The goal of the configuration interface is, among others, 13 | // to hide the IO implementations (over which the WG device is generic) 14 | // from the configuration and UAPI code. 15 | // 16 | // Furthermore, it forms a simpler interface for embedding WireGuard in other applications 17 | // and hides the complex types of the implementation from the host application.
18 | 19 | /// Describes a snapshot of the state of a peer 20 | pub struct PeerState { 21 | pub rx_bytes: u64, 22 | pub tx_bytes: u64, 23 | pub last_handshake_time: Option<(u64, u64)>, 24 | pub public_key: PublicKey, 25 | pub allowed_ips: Vec<(IpAddr, u32)>, 26 | pub endpoint: Option, 27 | pub persistent_keepalive_interval: u64, 28 | pub preshared_key: [u8; 32], // 0^32 is the "default value" (though treated like any other psk) 29 | } 30 | 31 | pub struct WireGuardConfig(Arc>>); 32 | 33 | struct Inner { 34 | wireguard: WireGuard, 35 | port: u16, 36 | bind: Option, 37 | fwmark: Option, 38 | } 39 | // lock() panics if the mutex was poisoned by a panicking holder 40 | impl WireGuardConfig { 41 | fn lock(&self) -> MutexGuard> { 42 | self.0.lock().unwrap() 43 | } 44 | } 45 | 46 | impl WireGuardConfig { 47 | pub fn new(wg: WireGuard) -> WireGuardConfig { 48 | WireGuardConfig(Arc::new(Mutex::new(Inner { 49 | wireguard: wg, 50 | port: 0, 51 | bind: None, 52 | fwmark: None, 53 | }))) 54 | } 55 | } 56 | // clones share the same underlying state (the Arc is cloned, not the Inner) 57 | impl Clone for WireGuardConfig { 58 | fn clone(&self) -> Self { 59 | WireGuardConfig(self.0.clone()) 60 | } 61 | } 62 | 63 | /// Exposed configuration interface 64 | pub trait Configuration { 65 | fn up(&self, mtu: usize) -> Result<(), ConfigError>; 66 | 67 | fn down(&self); 68 | 69 | /// Updates the private key of the device 70 | /// 71 | /// # Arguments 72 | /// 73 | /// - `sk`: The new private key (or None, if the private key should be cleared) 74 | fn set_private_key(&self, sk: Option); 75 | 76 | /// Returns the private key of the device 77 | /// 78 | /// # Returns 79 | /// 80 | /// The private key if set, otherwise None.
81 | fn get_private_key(&self) -> Option; 82 | 83 | /// Returns the protocol version of the device 84 | /// 85 | /// # Returns 86 | /// 87 | /// An integer indicating the protocol version 88 | fn get_protocol_version(&self) -> usize; 89 | /// Sets the UDP port to listen on 90 | fn set_listen_port(&self, port: u16) -> Result<(), ConfigError>; 91 | 92 | /// Set the firewall mark (or similar, depending on platform) 93 | /// 94 | /// # Arguments 95 | /// 96 | /// - `mark`: The fwmark value 97 | /// 98 | /// # Returns 99 | /// 100 | /// An error if this operation is not supported by the underlying 101 | /// "bind" implementation. 102 | fn set_fwmark(&self, mark: Option) -> Result<(), ConfigError>; 103 | 104 | /// Removes all peers from the device 105 | fn replace_peers(&self); 106 | 107 | /// Remove the peer from the device 108 | /// 109 | /// # Arguments 110 | /// 111 | /// - `peer`: The public key of the peer to remove 112 | /// 113 | /// # Returns 114 | /// 115 | /// If the peer does not exist this operation is a noop 116 | fn remove_peer(&self, peer: &PublicKey); 117 | 118 | /// Adds a new peer to the device 119 | /// 120 | /// # Arguments 121 | /// 122 | /// - `peer`: The public key of the peer to add 123 | /// 124 | /// # Returns 125 | /// 126 | /// A bool indicating if the peer was added.
127 | /// 128 | /// If the peer already exists this operation is a noop 129 | fn add_peer(&self, peer: &PublicKey) -> bool; 130 | 131 | /// Update the psk of a peer 132 | /// 133 | /// # Arguments 134 | /// 135 | /// - `peer`: The public key of the peer 136 | /// - `psk`: The new psk (0^32 is the default value) 137 | /// 138 | /// # Notes 139 | /// 140 | /// No error is reported if no such peer exists 141 | fn set_preshared_key(&self, peer: &PublicKey, psk: [u8; 32]); 142 | 143 | /// Update the endpoint of the peer 144 | /// 145 | /// # Arguments 146 | /// 147 | /// - `peer`: The public key of the peer 148 | /// - `addr`: The new endpoint address 149 | fn set_endpoint(&self, peer: &PublicKey, addr: SocketAddr); 150 | 151 | /// Update the persistent keepalive interval of the peer 152 | /// 153 | /// # Arguments 154 | /// 155 | /// - `peer`: The public key of the peer 156 | /// - `secs`: The new keepalive interval in seconds 157 | fn set_persistent_keepalive_interval(&self, peer: &PublicKey, secs: u64); 158 | 159 | /// Remove all allowed IPs from the peer 160 | /// 161 | /// # Arguments 162 | /// 163 | /// - `peer`: The public key of the peer 164 | /// 165 | /// # Notes 166 | /// 167 | /// No error is reported if no such peer exists 168 | fn replace_allowed_ips(&self, peer: &PublicKey); 169 | 170 | /// Add a new allowed subnet to the peer 171 | /// 172 | /// # Arguments 173 | /// 174 | /// - `peer`: The public key of the peer 175 | /// - `ip`: The network address of the subnet 176 | /// - `masklen`: The prefix length of the subnet (CIDR notation) 177 | /// 178 | /// # Notes 179 | /// 180 | /// No error is reported if the peer does not exist 181 | fn add_allowed_ip(&self, peer: &PublicKey, ip: IpAddr, masklen: u32); 182 | /// Returns the port of the bound UDP listener, if any 183 | fn get_listen_port(&self) -> Option; 184 | 185 | /// Returns the state of all peers 186 | /// 187 | /// # Returns 188 | /// 189 | /// A list of structures describing the state of each peer 190 | fn get_peers(&self) -> Vec; 191 | /// Returns the configured firewall mark, if any 192 | fn get_fwmark(&self) -> Option; 193 | } 194 | // (Re)binds the UDP listener on cfg.port and installs it on the device; FailedToBind if binding fails 195 | fn start_listener( 196 | mut cfg: MutexGuard>, 197 | ) -> Result<(), ConfigError> { 198 | cfg.bind = None; 199 | 200 | // create new listener 201 |
let (mut readers, writer, mut owner) = match B::bind(cfg.port) { 202 | Ok(r) => r, 203 | Err(_) => { 204 | return Err(ConfigError::FailedToBind); 205 | } 206 | }; 207 | 208 | // set fwmark (best effort; a fresh socket carries no mark, so only a configured mark is worth reporting) 209 | if owner.set_fwmark(cfg.fwmark).is_err() && cfg.fwmark.is_some() { log::warn!("configuration, failed to set fwmark {:?} on new listener", cfg.fwmark); } 210 | 211 | // set writer on WireGuard 212 | cfg.wireguard.set_writer(writer); 213 | 214 | // add readers 215 | while let Some(reader) = readers.pop() { 216 | cfg.wireguard.add_udp_reader(reader); 217 | } 218 | 219 | // create new UDP state 220 | cfg.bind = Some(owner); 221 | Ok(()) 222 | } 223 | 224 | impl Configuration for WireGuardConfig { 225 | fn up(&self, mtu: usize) -> Result<(), ConfigError> { 226 | log::info!("configuration, set device up"); 227 | let cfg = self.lock(); 228 | cfg.wireguard.up(mtu); 229 | start_listener(cfg) 230 | } 231 | 232 | fn down(&self) { 233 | log::info!("configuration, set device down"); 234 | let mut cfg = self.lock(); 235 | cfg.wireguard.down(); 236 | cfg.bind = None; 237 | } 238 | 239 | fn get_fwmark(&self) -> Option { 240 | self.lock().fwmark 241 | } 242 | 243 | fn set_private_key(&self, sk: Option) { 244 | log::info!("configuration, set private key"); 245 | self.lock().wireguard.set_key(sk) 246 | } 247 | 248 | fn get_private_key(&self) -> Option { 249 | self.lock().wireguard.get_sk() 250 | } 251 | 252 | fn get_protocol_version(&self) -> usize { 253 | 1 254 | } 255 | 256 | fn get_listen_port(&self) -> Option { 257 | let st = self.lock(); 258 | log::trace!("Config, Get listen port, bound: {}", st.bind.is_some()); 259 | st.bind.as_ref().map(|bind| bind.get_port()) 260 | } 261 | 262 | fn set_listen_port(&self, port: u16) -> Result<(), ConfigError> { 263 | log::trace!("Config, Set listen port: {:?}", port); 264 | 265 | // update port and take old bind 266 | let mut cfg = self.lock(); 267 | let bound: bool = { 268 | let old = mem::replace(&mut cfg.bind, None); 269 | cfg.port = port; 270 | old.is_some() 271 | }; 272 | 273 | // restart listener if bound 274 |
start_listener(cfg) 276 | } else { 277 | Ok(()) 278 | } 279 | } 280 | 281 | fn set_fwmark(&self, mark: Option) -> Result<(), ConfigError> { 282 | log::trace!("Config, Set fwmark: {:?}", mark); 283 | let mut cfg = self.lock(); 284 | // apply to the bound socket (if any), then remember the mark for get_fwmark() and later rebinds 285 | let res = match cfg.bind.as_mut() { 286 | Some(bind) => bind.set_fwmark(mark).map_err(|_| ConfigError::IOError), 287 | None => Ok(()), 288 | }; 289 | if res.is_ok() { 290 | cfg.fwmark = mark; 291 | } 292 | res 293 | } 294 | 295 | fn replace_peers(&self) { 296 | self.lock().wireguard.clear_peers(); 297 | } 298 | 299 | fn remove_peer(&self, peer: &PublicKey) { 300 | self.lock().wireguard.remove_peer(peer); 301 | } 302 | 303 | fn add_peer(&self, peer: &PublicKey) -> bool { 304 | self.lock().wireguard.add_peer(*peer) 305 | } 306 | 307 | fn set_preshared_key(&self, peer: &PublicKey, psk: [u8; 32]) { 308 | self.lock().wireguard.set_psk(*peer, psk); 309 | } 310 | 311 | fn set_endpoint(&self, peer: &PublicKey, addr: SocketAddr) { 312 | if let Some(peer) = self.lock().wireguard.peers.read().get(peer) { 313 | peer.set_endpoint(B::Endpoint::from_address(addr)); 314 | } 315 | } 316 | 317 | fn set_persistent_keepalive_interval(&self, peer: &PublicKey, secs: u64) { 318 | if let Some(peer) = self.lock().wireguard.peers.read().get(peer) { 319 | peer.opaque().set_persistent_keepalive_interval(secs); 320 | } 321 | } 322 | 323 | fn replace_allowed_ips(&self, peer: &PublicKey) { 324 | if let Some(peer) = self.lock().wireguard.peers.read().get(peer) { 325 | peer.remove_allowed_ips(); 326 | } 327 | } 328 | 329 | fn add_allowed_ip(&self, peer: &PublicKey, ip: IpAddr, masklen: u32) { 330 | if let Some(peer) = self.lock().wireguard.peers.read().get(peer) { 331 | peer.add_allowed_ip(ip, masklen); 332 | } 333 | } 334 | 335 | /* 336 | 337 | 338 | pub fn list_peers( 339 | &self, 340 | ) -> Vec<( 341 | PublicKey, 342 | router::PeerHandle, T::Writer, B::Writer>, 343 | )> { 344 | let peers = self.peers.read(); 345 | let mut list = Vec::with_capacity(peers.len()); 346 | for (k, v) in peers.iter() { 347 |
debug_assert!(k.as_bytes() == v.opaque().pk.as_bytes()); 348 | list.push((k.clone(), v.clone())); 349 | } 350 | list 351 | } 352 | */ 353 | // peers for which get_psk() returns None are left out of the snapshot 354 | fn get_peers(&self) -> Vec { 355 | let cfg = self.lock(); 356 | let peers = cfg.wireguard.peers.read(); 357 | let mut state = Vec::with_capacity(peers.len()); 358 | 359 | for (pk, p) in peers.iter() { 360 | // convert the system time to (secs, nano) since epoch 361 | let last_handshake_time = (*p.walltime_last_handshake.lock()).map(|t| { 362 | let duration = t 363 | .duration_since(SystemTime::UNIX_EPOCH) 364 | .unwrap_or_else(|_| Duration::from_secs(0)); 365 | (duration.as_secs(), duration.subsec_nanos() as u64) 366 | }); 367 | 368 | if let Some(psk) = cfg.wireguard.get_psk(&pk) { 369 | // extract state into PeerState 370 | state.push(PeerState { 371 | preshared_key: psk, 372 | endpoint: p.get_endpoint(), 373 | rx_bytes: p.rx_bytes.load(Ordering::Relaxed), 374 | tx_bytes: p.tx_bytes.load(Ordering::Relaxed), 375 | persistent_keepalive_interval: p.get_keepalive_interval(), 376 | allowed_ips: p.list_allowed_ips(), 377 | last_handshake_time, 378 | public_key: pk, 379 | }) 380 | } 381 | } 382 | state 383 | } 384 | } 385 | -------------------------------------------------------------------------------- /src/configuration/error.rs: -------------------------------------------------------------------------------- 1 | use std::error::Error; 2 | use std::fmt; 3 | 4 | #[cfg(unix)] 5 | use libc::*; 6 | 7 | #[derive(Debug)] 8 | pub enum ConfigError { 9 | FailedToBind, 10 | InvalidHexValue, 11 | InvalidPortNumber, 12 | InvalidFwmark, 13 | InvalidKey, 14 | InvalidSocketAddr, 15 | InvalidKeepaliveInterval, 16 | InvalidAllowedIp, 17 | InvalidOperation, 18 | LineTooLong, 19 | IOError, 20 | UnsupportedValue, 21 | UnsupportedProtocolVersion, 22 | } 23 | 24 | impl fmt::Display for ConfigError { 25 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 26 | write!(f, "ConfigError(errno = {})", self.errno()) 27 | } 28 | } 29 | 30 | impl
Error for ConfigError { 31 | fn description(&self) -> &str { 32 | "" 33 | } 34 | 35 | fn source(&self) -> Option<&(dyn Error + 'static)> { 36 | None 37 | } 38 | } 39 | // Maps each error to the errno value reported to UAPI clients ("errno=N") 40 | #[cfg(unix)] 41 | impl ConfigError { 42 | pub fn errno(&self) -> i32 { 43 | // TODO: obtain the correct errno values 44 | match self { 45 | // insufficient perms 46 | ConfigError::FailedToBind => EPERM, 47 | 48 | // parsing of value failed 49 | ConfigError::InvalidHexValue => EINVAL, 50 | ConfigError::InvalidPortNumber => EINVAL, 51 | ConfigError::InvalidFwmark => EINVAL, 52 | ConfigError::InvalidSocketAddr => EINVAL, 53 | ConfigError::InvalidKeepaliveInterval => EINVAL, 54 | ConfigError::InvalidAllowedIp => EINVAL, 55 | ConfigError::InvalidOperation => EINVAL, 56 | ConfigError::UnsupportedValue => EINVAL, 57 | 58 | // other protocol errors 59 | ConfigError::LineTooLong => EPROTO, 60 | ConfigError::InvalidKey => EPROTO, 61 | ConfigError::UnsupportedProtocolVersion => EPROTO, 62 | 63 | // IO 64 | ConfigError::IOError => EIO, 65 | } 66 | } 67 | } 68 | -------------------------------------------------------------------------------- /src/configuration/mod.rs: -------------------------------------------------------------------------------- 1 | mod config; 2 | mod error; 3 | pub mod uapi; 4 | 5 | use super::platform::Endpoint; 6 | use super::platform::{tun, udp}; 7 | use super::wireguard::WireGuard; 8 | 9 | pub use error::ConfigError; 10 | 11 | pub use config::Configuration; 12 | pub use config::WireGuardConfig; 13 | -------------------------------------------------------------------------------- /src/configuration/uapi/get.rs: -------------------------------------------------------------------------------- 1 | use std::io; 2 | 3 | use super::Configuration; 4 | /// Writes the UAPI `get` response body (interface fields, then one section per peer) to `writer` 5 | pub fn serialize(writer: &mut W, config: &C) -> io::Result<()> { 6 | let mut write = |key: &'static str, value: String| { 7 | debug_assert!(value.is_ascii()); 8 | debug_assert!(key.is_ascii()); 9 | log::trace!("UAPI: return : {}={}", key,
value); 10 | writer.write_all(key.as_ref())?; 11 | writer.write_all(b"=")?; 12 | writer.write_all(value.as_ref())?; 13 | writer.write_all(b"\n") 14 | }; 15 | 16 | // serialize interface (write errors are propagated, as for the peers below) 17 | if let Some(sk) = config.get_private_key() { 18 | write("private_key", hex::encode(sk.to_bytes()))?; 19 | } 20 | 21 | if let Some(port) = config.get_listen_port() { 22 | write("listen_port", port.to_string())?; 23 | } 24 | 25 | if let Some(fwmark) = config.get_fwmark() { 26 | write("fwmark", fwmark.to_string())?; 27 | } 28 | 29 | // serialize all peers 30 | let mut peers = config.get_peers(); 31 | while let Some(p) = peers.pop() { 32 | write("public_key", hex::encode(p.public_key.as_bytes()))?; 33 | write("preshared_key", hex::encode(p.preshared_key))?; 34 | write("rx_bytes", p.rx_bytes.to_string())?; 35 | write("tx_bytes", p.tx_bytes.to_string())?; 36 | write( 37 | "persistent_keepalive_interval", 38 | p.persistent_keepalive_interval.to_string(), 39 | )?; 40 | 41 | if let Some((secs, nsecs)) = p.last_handshake_time { 42 | write("last_handshake_time_sec", secs.to_string())?; 43 | write("last_handshake_time_nsec", nsecs.to_string())?; 44 | } 45 | 46 | if let Some(endpoint) = p.endpoint { 47 | write("endpoint", endpoint.to_string())?; 48 | } 49 | 50 | for (ip, cidr) in p.allowed_ips { 51 | write("allowed_ip", ip.to_string() + "/" + &cidr.to_string())?; 52 | } 53 | } 54 | 55 | Ok(()) 56 | } 57 | -------------------------------------------------------------------------------- /src/configuration/uapi/mod.rs: -------------------------------------------------------------------------------- 1 | mod get; 2 | mod set; 3 | 4 | use std::io::{Read, Write}; 5 | 6 | use super::{ConfigError, Configuration}; 7 | 8 | use get::serialize; 9 | use set::LineParser; 10 | 11 | const MAX_LINE_LENGTH: usize = 256; 12 | /// Serves one UAPI request on `stream` (get=1 or set=1) and replies with "errno=N" 13 | pub fn handle(stream: &mut S, config: &C) { 14 | fn operation( 15 | stream: &mut S, 16 | config: &C, 17 | ) -> Result<(), ConfigError> { 18 | // read string up to maximum length (why is this not in std?)
19 | fn readline(reader: &mut R) -> Result { 20 | let mut m: [u8; 1] = [0u8]; 21 | let mut l: String = String::with_capacity(MAX_LINE_LENGTH); 22 | while reader.read_exact(&mut m).is_ok() { 23 | let c = m[0] as char; 24 | if c == '\n' { 25 | log::trace!("UAPI, line: {}", l); 26 | return Ok(l); 27 | }; 28 | l.push(c); 29 | if l.len() > MAX_LINE_LENGTH { 30 | return Err(ConfigError::LineTooLong); 31 | } 32 | } 33 | Err(ConfigError::IOError) 34 | } 35 | 36 | // split into (key, value) pair; a line without '=' is malformed, not too long 37 | fn keypair(ln: &str) -> Result<(&str, &str), ConfigError> { 38 | let mut split = ln.splitn(2, '='); 39 | match (split.next(), split.next()) { 40 | (Some(key), Some(value)) => Ok((key, value)), 41 | _ => Err(ConfigError::InvalidKey), 42 | } 43 | } 44 | 45 | // read operation line 46 | match readline(stream)?.as_str() { 47 | "get=1" => { 48 | log::debug!("UAPI, Get operation"); 49 | serialize(stream, config).map_err(|_| ConfigError::IOError) 50 | } 51 | "set=1" => { 52 | log::debug!("UAPI, Set operation"); 53 | let mut parser = LineParser::new(config); 54 | loop { 55 | let ln = readline(stream)?; 56 | if ln.is_empty() { 57 | break; 58 | } 59 | let (k, v) = keypair(ln.as_str())?; 60 | parser.parse_line(k, v)?; 61 | } 62 | parser.parse_line("", "") 63 | } 64 | _ => Err(ConfigError::InvalidOperation), 65 | } 66 | } 67 | 68 | // process operation 69 | let res = operation(stream, config); 70 | log::debug!("UAPI, Result of operation: {:?}", res); 71 | 72 | // return errno (write_all, so that a short write cannot truncate the reply) 73 | let _ = stream.write_all("errno=".as_ref()); 74 | let _ = stream.write_all( 75 | match res { 76 | Err(e) => e.errno().to_string(), 77 | Ok(()) => "0".to_owned(), 78 | } 79 | .as_ref(), 80 | ); 81 | let _ = stream.write_all("\n\n".as_ref()); 82 | } 83 | -------------------------------------------------------------------------------- /src/configuration/uapi/set.rs: -------------------------------------------------------------------------------- 1 | use hex::FromHex; 2 | use std::net::{IpAddr, SocketAddr}; 3 | use
subtle::ConstantTimeEq; 4 | use x25519_dalek::{PublicKey, StaticSecret}; 5 | 6 | use super::{ConfigError, Configuration}; 7 | // which section of a set=1 transaction is currently being parsed 8 | enum ParserState { 9 | Peer(ParsedPeer), 10 | Interface, 11 | } 12 | // settings accumulated for one peer section; applied by flush_peer 13 | struct ParsedPeer { 14 | public_key: PublicKey, 15 | update_only: bool, 16 | allowed_ips: Vec<(IpAddr, u32)>, 17 | remove: bool, 18 | preshared_key: Option<[u8; 32]>, 19 | replace_allowed_ips: bool, 20 | persistent_keepalive_interval: Option, 21 | protocol_version: Option, 22 | endpoint: Option, 23 | } 24 | /// Incremental parser for the key=value lines of a UAPI set=1 operation 25 | pub struct LineParser<'a, C: Configuration> { 26 | config: &'a C, 27 | state: ParserState, 28 | } 29 | 30 | impl<'a, C: Configuration> LineParser<'a, C> { 31 | pub fn new(config: &'a C) -> LineParser<'a, C> { 32 | LineParser { 33 | config, 34 | state: ParserState::Interface, 35 | } 36 | } 37 | // starts a fresh peer section for the hex-encoded public key in `value` 38 | fn new_peer(value: &str) -> Result { 39 | match <[u8; 32]>::from_hex(value) { 40 | Ok(pk) => Ok(ParserState::Peer(ParsedPeer { 41 | public_key: PublicKey::from(pk), 42 | remove: false, 43 | update_only: false, 44 | allowed_ips: vec![], 45 | preshared_key: None, 46 | replace_allowed_ips: false, 47 | persistent_keepalive_interval: None, 48 | protocol_version: None, 49 | endpoint: None, 50 | })), 51 | Err(_) => Err(ConfigError::InvalidHexValue), 52 | } 53 | } 54 | /// Applies one key=value line; the caller passes ("", "") at the end of the transaction to flush the pending peer 55 | pub fn parse_line(&mut self, key: &str, value: &str) -> Result<(), ConfigError> { 56 | #[cfg(debug)] 57 | { 58 | if key.len() > 0 { 59 | log::debug!("UAPI: {}={}", key, value); 60 | } 61 | } 62 | 63 | // flush peer updates to configuration; returns the error to report to the client, if any 64 | fn flush_peer(config: &C, peer: &ParsedPeer) -> Option { 65 | // validate first, so that a rejected peer section is not partially applied 66 | if let Some(version) = peer.protocol_version { 67 | log::trace!("flush peer, set protocol_version {}", version); 68 | if version == 0 || version > config.get_protocol_version() { 69 | return Some(ConfigError::UnsupportedProtocolVersion); 70 | } 71 | } 72 | 73 | if peer.remove { 74 | log::trace!("flush peer, remove peer"); 75 | config.remove_peer(&peer.public_key); 76 | return None; 77 | }
78 | if !peer.update_only { 79 | log::trace!("flush peer, add peer"); 80 | config.add_peer(&peer.public_key); 81 | } 82 | // replace_allowed_ips=true: drop the peer's existing allowed IPs before adding the new ones 83 | if peer.replace_allowed_ips { 84 | log::trace!("flush peer, replace allowed_ips"); 85 | config.replace_allowed_ips(&peer.public_key); 86 | } 87 | for (ip, cidr) in &peer.allowed_ips { 88 | log::trace!("flush peer, add allowed_ips : {}/{}", ip.to_string(), cidr); 89 | config.add_allowed_ip(&peer.public_key, *ip, *cidr); 90 | } 91 | if let Some(psk) = peer.preshared_key { 92 | log::trace!("flush peer, set preshared_key {}", hex::encode(psk)); 93 | config.set_preshared_key(&peer.public_key, psk); 94 | } 95 | if let Some(secs) = peer.persistent_keepalive_interval { 96 | log::trace!("flush peer, set persistent_keepalive_interval {}", secs); 97 | config.set_persistent_keepalive_interval(&peer.public_key, secs); 98 | } 99 | if let Some(endpoint) = peer.endpoint { 100 | log::trace!("flush peer, set endpoint {}", endpoint.to_string()); 101 | config.set_endpoint(&peer.public_key, endpoint); 102 | } 103 | None 104 | }; 105 | 106 | // parse line and update parser state 107 | match self.state { 108 | // configure the interface 109 | ParserState::Interface => match key { 110 | // opt: set private key 111 | "private_key" => match <[u8; 32]>::from_hex(value) { 112 | Ok(sk) => { 113 | self.config.set_private_key(if sk.ct_eq(&[0u8; 32]).into() { 114 | None 115 | } else { 116 | Some(StaticSecret::from(sk)) 117 | }); 118 | Ok(()) 119 | } 120 | Err(_) => Err(ConfigError::InvalidHexValue), 121 | }, 122 | 123 | // opt: set listen port 124 | "listen_port" => match value.parse() { 125 | Ok(port) => { 126 | self.config.set_listen_port(port)?; 127 | Ok(()) 128 | } 129 | Err(_) => Err(ConfigError::InvalidPortNumber), 130 | }, 131 | 132 | // opt: set fwmark 133 | "fwmark" => match value.parse() { 134 | Ok(fwmark) => { 135 | self.config 136 | .set_fwmark(if fwmark == 0 { None } else { Some(fwmark) })?; 137 | Ok(()) 138 | } 139 | Err(_) => Err(ConfigError::InvalidFwmark), 140 | }, 141 | 142 | // opt: remove all peers 143 | "replace_peers" => match value { 144 | "true" => { 145 |
for p in self.config.get_peers() { 146 | self.config.remove_peer(&p.public_key) 147 | } 148 | Ok(()) 149 | } 150 | _ => Err(ConfigError::UnsupportedValue), 151 | }, 152 | 153 | // opt: transition to peer configuration 154 | "public_key" => { 155 | self.state = Self::new_peer(value)?; 156 | Ok(()) 157 | } 158 | 159 | // ignore (end of transcript) 160 | "" => Ok(()), 161 | 162 | // unknown key 163 | _ => Err(ConfigError::InvalidKey), 164 | }, 165 | 166 | // configure peers 167 | ParserState::Peer(ref mut peer) => match key { 168 | // opt: new peer (the current peer is flushed first; its errors are reported) 169 | "public_key" => { 170 | if let Some(err) = flush_peer(self.config, &peer) { return Err(err); } 171 | self.state = Self::new_peer(value)?; 172 | Ok(()) 173 | } 174 | 175 | // opt: remove peer 176 | "remove" => { 177 | peer.remove = true; 178 | Ok(()) 179 | } 180 | 181 | // opt: update only 182 | "update_only" => { 183 | peer.update_only = true; 184 | Ok(()) 185 | } 186 | 187 | // opt: set preshared key 188 | "preshared_key" => match <[u8; 32]>::from_hex(value) { 189 | Ok(psk) => { 190 | peer.preshared_key = Some(psk); 191 | Ok(()) 192 | } 193 | Err(_) => Err(ConfigError::InvalidHexValue), 194 | }, 195 | 196 | // opt: set endpoint 197 | "endpoint" => match value.parse() { 198 | Ok(endpoint) => { 199 | peer.endpoint = Some(endpoint); 200 | Ok(()) 201 | } 202 | Err(_) => Err(ConfigError::InvalidSocketAddr), 203 | }, 204 | 205 | // opt: set persistent keepalive interval 206 | "persistent_keepalive_interval" => match value.parse() { 207 | Ok(secs) => { 208 | peer.persistent_keepalive_interval = Some(secs); 209 | Ok(()) 210 | } 211 | Err(_) => Err(ConfigError::InvalidKeepaliveInterval), 212 | }, 213 | 214 | // opt replace allowed ips 215 | "replace_allowed_ips" => { 216 | peer.replace_allowed_ips = true; 217 | peer.allowed_ips.clear(); 218 | Ok(()) 219 | } 220 | 221 | // opt add allowed ips (prefix length may not exceed 32 for IPv4 or 128 for IPv6) 222 | "allowed_ip" => { 223 | let mut split = value.splitn(2, '/'); 224 | let addr: Option<IpAddr> = split.next().and_then(|x| x.parse().ok()); 225 | let cidr: Option<u32> = split.next().and_then(|x| x.parse().ok());
226 | match (addr, cidr) { 227 | (Some(addr), Some(cidr)) if cidr <= (if addr.is_ipv4() { 32 } else { 128 }) => { 228 | peer.allowed_ips.push((addr, cidr)); 229 | Ok(()) 230 | } 231 | _ => Err(ConfigError::InvalidAllowedIp), 232 | } 233 | } 234 | 235 | // set protocol version of peer 236 | "protocol_version" => { 237 | let parse_res: Result = value.parse(); 238 | match parse_res { 239 | Ok(version) => { 240 | peer.protocol_version = Some(version); 241 | Ok(()) 242 | } 243 | Err(_) => Err(ConfigError::UnsupportedProtocolVersion), 244 | } 245 | } 246 | 247 | // flush (used at end of transcript) 248 | "" => { 249 | log::trace!("UAPI, Set, processes end of transaction"); 250 | // report flush errors (e.g. an unsupported protocol_version) to the client 251 | match flush_peer(self.config, &peer) { Some(err) => Err(err), None => Ok(()) } 252 | } 253 | 254 | // unknown key 255 | _ => Err(ConfigError::InvalidKey), 256 | }, 257 | } 258 | } 259 | } 260 | -------------------------------------------------------------------------------- /src/main.rs: -------------------------------------------------------------------------------- 1 | #![cfg_attr(feature = "unstable", feature(test))] 2 | 3 | extern crate alloc; 4 | 5 | #[cfg(feature = "profiler")] 6 | extern crate cpuprofiler; 7 | 8 | #[cfg(feature = "profiler")] 9 | use cpuprofiler::PROFILER; 10 | 11 | mod configuration; 12 | mod platform; 13 | mod wireguard; 14 | 15 | mod util; 16 | 17 | use std::env; 18 | use std::process::exit; 19 | use std::thread; 20 | 21 | use configuration::Configuration; 22 | 23 | use platform::tun::{PlatformTun, Status}; 24 | use platform::uapi::{BindUAPI, PlatformUAPI}; 25 | use platform::*; 26 | 27 | use wireguard::WireGuard; 28 | 29 | #[cfg(feature = "profiler")] 30 | fn profiler_stop() { 31 | println!("Stopping profiler"); 32 | PROFILER.lock().unwrap().stop().unwrap(); 33 | } 34 | 35 | #[cfg(not(feature = "profiler"))] 36 | fn profiler_stop() {} 37 | 38 | #[cfg(feature = "profiler")] 39 | fn profiler_start(name: &str) { 40 | use std::path::Path; 41 | 42 | // find first available path to save profiler output 43 | let mut n = 0; 44 | loop { 45 | let
path = format!("./{}-{}.profile", name, n); 46 | if !Path::new(path.as_str()).exists() { 47 | println!("Starting profiler: {}", path); 48 | PROFILER.lock().unwrap().start(path).unwrap(); 49 | break; 50 | }; 51 | n += 1; 52 | } 53 | } 54 | 55 | fn main() { 56 | // parse command line arguments 57 | let mut name = None; 58 | let mut drop_privileges = true; 59 | let mut foreground = false; 60 | let mut args = env::args(); 61 | 62 | // skip path (argv[0]) 63 | args.next(); 64 | for arg in args { 65 | match arg.as_str() { 66 | "--foreground" | "-f" => { 67 | foreground = true; 68 | } 69 | "--disable-drop-privileges" => { 70 | drop_privileges = false; 71 | } 72 | dev => name = Some(dev.to_owned()), 73 | } 74 | } 75 | 76 | // unwrap device name 77 | let name = match name { 78 | None => { 79 | eprintln!("No device name supplied"); 80 | exit(-1); 81 | } 82 | Some(name) => name, 83 | }; 84 | 85 | // create UAPI socket 86 | let uapi = plt::UAPI::bind(name.as_str()).unwrap_or_else(|e| { 87 | eprintln!("Failed to create UAPI listener: {}", e); 88 | exit(-2); 89 | }); 90 | 91 | // create TUN device 92 | let (mut readers, writer, status) = plt::Tun::create(name.as_str()).unwrap_or_else(|e| { 93 | eprintln!("Failed to create TUN device: {}", e); 94 | exit(-3); 95 | }); 96 | 97 | // drop privileges 98 | if drop_privileges { 99 | match util::drop_privileges() { 100 | Ok(_) => (), 101 | Err(e) => { 102 | eprintln!("Failed to drop privileges: {}", e); 103 | exit(-4); 104 | } 105 | } 106 | } 107 | 108 | // daemonize to background 109 | if !foreground { 110 | match util::daemonize() { 111 | Ok(_) => (), 112 | Err(e) => { 113 | eprintln!("Failed to daemonize: {}", e); 114 | exit(-5); 115 | } 116 | } 117 | } 118 | 119 | // start logging 120 | env_logger::builder() 121 | .try_init() 122 | .expect("Failed to initialize event logger"); 123 | 124 | log::info!("Starting {} WireGuard device.", name); 125 | 126 | // start profiler (if enabled) 127 | #[cfg(feature = "profiler")] 128 |
profiler_start(name.as_str()); 129 | 130 | // create WireGuard device 131 | let wg: WireGuard = WireGuard::new(writer); 132 | 133 | // add all Tun readers 134 | while let Some(reader) = readers.pop() { 135 | wg.add_tun_reader(reader); 136 | } 137 | 138 | // wrap in configuration interface 139 | let cfg = configuration::WireGuardConfig::new(wg.clone()); 140 | 141 | // start Tun event thread 142 | { 143 | let cfg = cfg.clone(); 144 | let mut status = status; 145 | thread::spawn(move || loop { 146 | match status.event() { 147 | Err(e) => { 148 | log::info!("Tun device error {}", e); 149 | profiler_stop(); 150 | exit(0); 151 | } 152 | Ok(tun::TunEvent::Up(mtu)) => { 153 | log::info!("Tun up (mtu = {})", mtu); 154 | if let Err(e) = cfg.up(mtu) { log::error!("Failed to bring device up: {}", e); } 155 | } 156 | Ok(tun::TunEvent::Down) => { 157 | log::info!("Tun down"); 158 | cfg.down(); 159 | } 160 | } 161 | }); 162 | } 163 | 164 | // start UAPI server 165 | thread::spawn(move || loop { 166 | // accept and handle UAPI config connections 167 | match uapi.connect() { 168 | Ok(mut stream) => { 169 | let cfg = cfg.clone(); 170 | thread::spawn(move || { 171 | configuration::uapi::handle(&mut stream, &cfg); 172 | }); 173 | } 174 | Err(err) => { 175 | log::info!("UAPI connection error: {}", err); 176 | profiler_stop(); 177 | exit(-1); 178 | } 179 | } 180 | }); 181 | 182 | // block until all tun readers closed 183 | wg.wait(); 184 | profiler_stop(); 185 | } 186 | -------------------------------------------------------------------------------- /src/platform/dummy/endpoint.rs: -------------------------------------------------------------------------------- 1 | use std::net::SocketAddr; 2 | 3 | use super::super::Endpoint; 4 | 5 | #[derive(Clone, Copy)] 6 | pub struct UnitEndpoint {} 7 | 8 | impl Endpoint for UnitEndpoint { 9 | fn from_address(_: SocketAddr) -> UnitEndpoint { 10 | UnitEndpoint {} 11 | } 12 | 13 | fn into_address(&self) -> SocketAddr { 14 | "127.0.0.1:8080".parse().unwrap() 15 | } 16 | 17 | fn clear_src(&mut
self) {} 18 | } 19 | 20 | impl UnitEndpoint { 21 | pub fn new() -> UnitEndpoint { 22 | UnitEndpoint {} 23 | } 24 | } 25 | -------------------------------------------------------------------------------- /src/platform/dummy/mod.rs: -------------------------------------------------------------------------------- 1 | mod endpoint; 2 | mod tun; 3 | mod udp; 4 | 5 | /* A pure dummy platform available during "test-time" 6 | * 7 | * The use of the dummy platform is to enable unit testing of full WireGuard, 8 | * the configuration interface and the UAPI parser. 9 | */ 10 | 11 | pub use endpoint::*; 12 | pub use tun::*; 13 | pub use udp::*; 14 | -------------------------------------------------------------------------------- /src/platform/dummy/tun/dummy.rs: -------------------------------------------------------------------------------- 1 | // This provides a mock tunnel interface, 2 | // which enables unit tests where WireGuard interfaces 3 | // are configured to match each other, so that a full test of: 4 | // 5 | // - Handshake 6 | // - Transport encryption/decryption 7 | // 8 | // can be executed. 9 | 10 | use super::*; 11 | 12 | use std::cmp::min; 13 | use std::error::Error; 14 | use std::fmt; 15 | use std::sync::mpsc::{sync_channel, Receiver, SyncSender}; 16 | use std::sync::Mutex; 17 | use std::thread; 18 | use std::time::Duration; 19 | 20 | use hex; 21 | use log::debug; 22 | use rand::rngs::OsRng; 23 | use rand::Rng; 24 | 25 | pub struct TunTest {} 26 | 27 | // Represents the "other end" (kernel/OS end) of the TUN connection: 28 | // 29 | // Used to send/receive packets to the mock WireGuard interface.
30 | pub struct TunFakeIO { 31 | id: u32, 32 | store: bool, 33 | tx: SyncSender>, 34 | rx: Receiver>, 35 | } 36 | 37 | pub struct TunReader { 38 | id: u32, 39 | rx: Receiver>, 40 | } 41 | 42 | pub struct TunWriter { 43 | id: u32, 44 | store: bool, 45 | tx: Mutex>>, 46 | } 47 | 48 | impl fmt::Display for TunFakeIO { 49 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 50 | write!(f, "FakeIO({})", self.id) 51 | } 52 | } 53 | 54 | impl fmt::Display for TunReader { 55 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 56 | write!(f, "TunReader({})", self.id) 57 | } 58 | } 59 | 60 | impl fmt::Display for TunWriter { 61 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 62 | write!(f, "TunWriter({})", self.id) 63 | } 64 | } 65 | 66 | pub struct TunStatus { 67 | first: bool, 68 | } 69 | 70 | impl Error for TunError { 71 | fn description(&self) -> &str { 72 | "Generic Tun Error" 73 | } 74 | 75 | fn source(&self) -> Option<&(dyn Error + 'static)> { 76 | None 77 | } 78 | } 79 | 80 | impl fmt::Display for TunError { 81 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 82 | write!(f, "Not Possible") 83 | } 84 | } 85 | 86 | impl Reader for TunReader { 87 | type Error = TunError; 88 | 89 | fn read(&self, buf: &mut [u8], offset: usize) -> Result { 90 | match self.rx.recv() { 91 | Ok(msg) => { 92 | let n = min(buf.len() - offset, msg.len()); 93 | buf[offset..offset + n].copy_from_slice(&msg[..n]); 94 | debug!( 95 | "dummy::TUN({}) : read ({}, {})", 96 | self.id, 97 | n, 98 | hex::encode(&buf[offset..offset + n]) 99 | ); 100 | Ok(n) 101 | } 102 | Err(_) => Err(TunError::Disconnected), 103 | } 104 | } 105 | } 106 | 107 | impl Writer for TunWriter { 108 | type Error = TunError; 109 | 110 | fn write(&self, src: &[u8]) -> Result<(), Self::Error> { 111 | debug!( 112 | "dummy::TUN({}) : write ({}, {})", 113 | self.id, 114 | src.len(), 115 | hex::encode(src) 116 | ); 117 | if self.store { 118 | let m = src.to_owned(); 119 | match 
self.tx.lock().unwrap().send(m) { 120 | Ok(_) => Ok(()), 121 | Err(_) => Err(TunError::Disconnected), 122 | } 123 | } else { 124 | Ok(()) 125 | } 126 | } 127 | } 128 | 129 | impl Status for TunStatus { 130 | type Error = TunError; 131 | 132 | fn event(&mut self) -> Result { 133 | if self.first { 134 | self.first = false; 135 | return Ok(TunEvent::Up(1420)); 136 | } 137 | 138 | loop { 139 | thread::sleep(Duration::from_secs(60 * 60)); 140 | } 141 | } 142 | } 143 | 144 | impl Tun for TunTest { 145 | type Writer = TunWriter; 146 | type Reader = TunReader; 147 | type Error = TunError; 148 | } 149 | 150 | impl TunFakeIO { 151 | pub fn write(&self, msg: Vec) { 152 | if self.store { 153 | self.tx.send(msg).unwrap(); 154 | } 155 | } 156 | 157 | pub fn read(&self) -> Vec { 158 | self.rx.recv().unwrap() 159 | } 160 | } 161 | 162 | impl TunTest { 163 | pub fn create(store: bool) -> (TunFakeIO, TunReader, TunWriter, TunStatus) { 164 | let (tx1, rx1) = if store { 165 | sync_channel(32) 166 | } else { 167 | sync_channel(1) 168 | }; 169 | let (tx2, rx2) = if store { 170 | sync_channel(32) 171 | } else { 172 | sync_channel(1) 173 | }; 174 | 175 | let id: u32 = OsRng.gen(); 176 | 177 | let fake = TunFakeIO { 178 | id, 179 | tx: tx1, 180 | rx: rx2, 181 | store, 182 | }; 183 | let reader = TunReader { id, rx: rx1 }; 184 | let writer = TunWriter { 185 | id, 186 | tx: Mutex::new(tx2), 187 | store, 188 | }; 189 | let status = TunStatus { first: true }; 190 | (fake, reader, writer, status) 191 | } 192 | } 193 | 194 | impl PlatformTun for TunTest { 195 | type Status = TunStatus; 196 | 197 | fn create(_name: &str) -> Result<(Vec, Self::Writer, Self::Status), Self::Error> { 198 | Err(TunError::Disconnected) 199 | } 200 | } 201 | -------------------------------------------------------------------------------- /src/platform/dummy/tun/mod.rs: -------------------------------------------------------------------------------- 1 | use super::super::tun::*; 2 | 3 | mod dummy; 4 | mod void; 5 | 6 | 
#[derive(Debug)] 7 | pub enum TunError { 8 | Disconnected, 9 | } 10 | 11 | pub use dummy::*; 12 | pub use void::*; 13 | -------------------------------------------------------------------------------- /src/platform/dummy/tun/void.rs: -------------------------------------------------------------------------------- 1 | /* 2 | // This code provides a "void" implementation of the tunnel interface: 3 | // The implementation never reads and immediately discards any write without error 4 | // 5 | // This is used during benchmarking and profiling of the inbound path. 6 | 7 | use super::*; 8 | 9 | pub struct VoidTun {} 10 | 11 | pub struct VoidReader {} 12 | 13 | pub struct VoidWriter {} 14 | 15 | impl Tun for VoidTun { 16 | type Writer = VoidWriter; 17 | type Reader = VoidReader; 18 | type Error = TunError; 19 | } 20 | 21 | 22 | impl Reader for VodReader { 23 | type Error = TunError; 24 | 25 | fn write(&self, src: &[u8]) -> Result<(), Self::Error> { 26 | debug!( 27 | "dummy::TUN({}) : write ({}, {})", 28 | self.id, 29 | src.len(), 30 | hex::encode(src) 31 | ); 32 | if self.store { 33 | let m = src.to_owned(); 34 | match self.tx.lock().unwrap().send(m) { 35 | Ok(_) => Ok(()), 36 | Err(_) => Err(TunError::Disconnected), 37 | } 38 | } else { 39 | Ok(()) 40 | } 41 | } 42 | } 43 | */ 44 | -------------------------------------------------------------------------------- /src/platform/dummy/udp.rs: -------------------------------------------------------------------------------- 1 | use hex; 2 | use std::error::Error; 3 | use std::fmt; 4 | use std::marker; 5 | 6 | use log::debug; 7 | use rand::rngs::OsRng; 8 | use rand::Rng; 9 | 10 | use std::sync::mpsc::{sync_channel, Receiver, SyncSender}; 11 | use std::sync::Arc; 12 | use std::sync::Mutex; 13 | 14 | use super::super::udp::*; 15 | 16 | use super::UnitEndpoint; 17 | 18 | pub struct VoidOwner {} 19 | 20 | #[derive(Debug)] 21 | pub enum BindError { 22 | Disconnected, 23 | } 24 | 25 | impl Error for BindError { 26 | fn 
description(&self) -> &str { 27 | "Generic Bind Error" 28 | } 29 | 30 | fn source(&self) -> Option<&(dyn Error + 'static)> { 31 | None 32 | } 33 | } 34 | 35 | impl fmt::Display for BindError { 36 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 37 | match self { 38 | BindError::Disconnected => write!(f, "PairBind disconnected"), 39 | } 40 | } 41 | } 42 | 43 | #[derive(Clone, Copy)] 44 | pub struct VoidBind {} 45 | 46 | impl Reader for VoidBind { 47 | type Error = BindError; 48 | 49 | fn read(&self, _buf: &mut [u8]) -> Result<(usize, UnitEndpoint), Self::Error> { 50 | Ok((0, UnitEndpoint {})) 51 | } 52 | } 53 | 54 | impl Writer for VoidBind { 55 | type Error = BindError; 56 | 57 | fn write(&self, _buf: &[u8], _dst: &mut UnitEndpoint) -> Result<(), Self::Error> { 58 | Ok(()) 59 | } 60 | } 61 | 62 | impl UDP for VoidBind { 63 | type Error = BindError; 64 | type Endpoint = UnitEndpoint; 65 | 66 | type Reader = VoidBind; 67 | type Writer = VoidBind; 68 | } 69 | 70 | impl VoidBind { 71 | pub fn new() -> VoidBind { 72 | VoidBind {} 73 | } 74 | } 75 | 76 | /* Pair Bind */ 77 | 78 | #[derive(Clone)] 79 | pub struct PairReader { 80 | id: u32, 81 | recv: Arc>>>, 82 | _marker: marker::PhantomData, 83 | } 84 | 85 | impl Reader for PairReader { 86 | type Error = BindError; 87 | fn read(&self, buf: &mut [u8]) -> Result<(usize, UnitEndpoint), Self::Error> { 88 | let vec = self 89 | .recv 90 | .lock() 91 | .unwrap() 92 | .recv() 93 | .map_err(|_| BindError::Disconnected)?; 94 | let len = vec.len(); 95 | buf[..len].copy_from_slice(&vec[..]); 96 | debug!( 97 | "dummy({}): read ({}, {})", 98 | self.id, 99 | len, 100 | hex::encode(&buf[..len]) 101 | ); 102 | Ok((len, UnitEndpoint {})) 103 | } 104 | } 105 | 106 | impl Writer for PairWriter { 107 | type Error = BindError; 108 | fn write(&self, buf: &[u8], _dst: &mut UnitEndpoint) -> Result<(), Self::Error> { 109 | debug!( 110 | "dummy({}): write ({}, {})", 111 | self.id, 112 | buf.len(), 113 | hex::encode(buf) 114 | ); 115 | 
let owned = buf.to_owned(); 116 | match self.send.lock().unwrap().send(owned) { 117 | Err(_) => Err(BindError::Disconnected), 118 | Ok(_) => Ok(()), 119 | } 120 | } 121 | } 122 | 123 | #[derive(Clone)] 124 | pub struct PairWriter { 125 | id: u32, 126 | send: Arc>>>, 127 | _marker: marker::PhantomData, 128 | } 129 | 130 | #[derive(Clone)] 131 | pub struct PairBind {} 132 | 133 | impl PairBind { 134 | pub fn pair() -> ( 135 | (PairReader, PairWriter), 136 | (PairReader, PairWriter), 137 | ) { 138 | let id1: u32 = OsRng.gen(); 139 | let id2: u32 = OsRng.gen(); 140 | 141 | let (tx1, rx1) = sync_channel(128); 142 | let (tx2, rx2) = sync_channel(128); 143 | ( 144 | ( 145 | PairReader { 146 | id: id1, 147 | recv: Arc::new(Mutex::new(rx1)), 148 | _marker: marker::PhantomData, 149 | }, 150 | PairWriter { 151 | id: id1, 152 | send: Arc::new(Mutex::new(tx2)), 153 | _marker: marker::PhantomData, 154 | }, 155 | ), 156 | ( 157 | PairReader { 158 | id: id2, 159 | recv: Arc::new(Mutex::new(rx2)), 160 | _marker: marker::PhantomData, 161 | }, 162 | PairWriter { 163 | id: id2, 164 | send: Arc::new(Mutex::new(tx1)), 165 | _marker: marker::PhantomData, 166 | }, 167 | ), 168 | ) 169 | } 170 | } 171 | 172 | impl UDP for PairBind { 173 | type Error = BindError; 174 | type Endpoint = UnitEndpoint; 175 | type Reader = PairReader; 176 | type Writer = PairWriter; 177 | } 178 | 179 | impl Owner for VoidOwner { 180 | type Error = BindError; 181 | 182 | fn set_fwmark(&mut self, _value: Option) -> Result<(), Self::Error> { 183 | Ok(()) 184 | } 185 | 186 | fn get_port(&self) -> u16 { 187 | 0 188 | } 189 | } 190 | 191 | impl PlatformUDP for PairBind { 192 | type Owner = VoidOwner; 193 | fn bind(_port: u16) -> Result<(Vec, Self::Writer, Self::Owner), Self::Error> { 194 | Err(BindError::Disconnected) 195 | } 196 | } 197 | -------------------------------------------------------------------------------- /src/platform/endpoint.rs: 
-------------------------------------------------------------------------------- 1 | use std::net::SocketAddr; 2 | 3 | pub trait Endpoint: Send + 'static { 4 | fn from_address(addr: SocketAddr) -> Self; 5 | fn into_address(&self) -> SocketAddr; 6 | fn clear_src(&mut self); 7 | } 8 | -------------------------------------------------------------------------------- /src/platform/linux/mod.rs: -------------------------------------------------------------------------------- 1 | mod tun; 2 | mod uapi; 3 | mod udp; 4 | 5 | pub use tun::LinuxTun as Tun; 6 | pub use uapi::LinuxUAPI as UAPI; 7 | pub use udp::LinuxUDP as UDP; 8 | -------------------------------------------------------------------------------- /src/platform/linux/tun.rs: -------------------------------------------------------------------------------- 1 | use super::super::tun::*; 2 | 3 | use std::error::Error; 4 | use std::fmt; 5 | use std::mem; 6 | use std::os::raw::c_short; 7 | use std::os::unix::io::RawFd; 8 | 9 | const TUNSETIFF: u64 = 0x4004_54ca; 10 | const CLONE_DEVICE_PATH: &[u8] = b"/dev/net/tun\0"; 11 | 12 | #[repr(C)] 13 | struct Ifreq { 14 | name: [u8; libc::IFNAMSIZ], 15 | flags: c_short, 16 | _pad: [u8; 64], 17 | } 18 | 19 | // man 7 rtnetlink 20 | // Layout from: https://elixir.bootlin.com/linux/latest/source/include/uapi/linux/rtnetlink.h#L516 21 | #[repr(C)] 22 | struct IfInfomsg { 23 | ifi_family: libc::c_uchar, 24 | __ifi_pad: libc::c_uchar, 25 | ifi_type: libc::c_ushort, 26 | ifi_index: libc::c_int, 27 | ifi_flags: libc::c_uint, 28 | ifi_change: libc::c_uint, 29 | } 30 | 31 | pub struct LinuxTun {} 32 | 33 | pub struct LinuxTunReader { 34 | fd: RawFd, 35 | } 36 | 37 | pub struct LinuxTunWriter { 38 | fd: RawFd, 39 | } 40 | 41 | pub struct LinuxTunStatus { 42 | events: Vec, 43 | index: i32, 44 | name: [u8; libc::IFNAMSIZ], 45 | fd: RawFd, 46 | } 47 | 48 | #[derive(Debug)] 49 | pub enum LinuxTunError { 50 | InvalidTunDeviceName, 51 | FailedToOpenCloneDevice, 52 | SetIFFIoctlFailed, 53 | 
GetMTUIoctlFailed, 54 | NetlinkFailure, 55 | Closed, // TODO 56 | } 57 | 58 | impl fmt::Display for LinuxTunError { 59 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 60 | match self { 61 | LinuxTunError::InvalidTunDeviceName => write!(f, "Invalid name (too long)"), 62 | LinuxTunError::FailedToOpenCloneDevice => { 63 | write!(f, "Failed to obtain fd for clone device") 64 | } 65 | LinuxTunError::SetIFFIoctlFailed => { 66 | write!(f, "set_iff ioctl failed (insufficient permissions?)") 67 | } 68 | LinuxTunError::Closed => write!(f, "The tunnel has been closed"), 69 | LinuxTunError::GetMTUIoctlFailed => write!(f, "ifmtu ioctl failed"), 70 | LinuxTunError::NetlinkFailure => write!(f, "Netlink listener error"), 71 | } 72 | } 73 | } 74 | 75 | impl Error for LinuxTunError { 76 | fn source(&self) -> Option<&(dyn Error + 'static)> { 77 | unimplemented!() 78 | } 79 | 80 | fn description(&self) -> &str { 81 | unimplemented!() 82 | } 83 | } 84 | 85 | impl Reader for LinuxTunReader { 86 | type Error = LinuxTunError; 87 | 88 | fn read(&self, buf: &mut [u8], offset: usize) -> Result { 89 | /* 90 | debug_assert!( 91 | offset < buf.len(), 92 | "There is no space for the body of the read" 93 | ); 94 | */ 95 | let n: isize = 96 | unsafe { libc::read(self.fd, buf[offset..].as_mut_ptr() as _, buf.len() - offset) }; 97 | if n < 0 { 98 | Err(LinuxTunError::Closed) 99 | } else { 100 | // conversion is safe 101 | Ok(n as usize) 102 | } 103 | } 104 | } 105 | 106 | impl Writer for LinuxTunWriter { 107 | type Error = LinuxTunError; 108 | 109 | fn write(&self, src: &[u8]) -> Result<(), Self::Error> { 110 | match unsafe { libc::write(self.fd, src.as_ptr() as _, src.len() as _) } { 111 | -1 => Err(LinuxTunError::Closed), 112 | _ => Ok(()), 113 | } 114 | } 115 | } 116 | 117 | fn get_ifindex(name: &[u8; libc::IFNAMSIZ]) -> i32 { 118 | debug_assert_eq!( 119 | name[libc::IFNAMSIZ - 1], 120 | 0, 121 | "name buffer not null-terminated" 122 | ); 123 | 124 | let name = *name; 125 | let idx = 
unsafe { 126 | let ptr: *const libc::c_char = mem::transmute(&name); 127 | libc::if_nametoindex(ptr) 128 | }; 129 | idx as i32 130 | } 131 | 132 | fn get_mtu(name: &[u8; libc::IFNAMSIZ]) -> Result { 133 | #[repr(C)] 134 | struct arg { 135 | name: [u8; libc::IFNAMSIZ], 136 | mtu: u32, 137 | } 138 | 139 | debug_assert_eq!( 140 | name[libc::IFNAMSIZ - 1], 141 | 0, 142 | "name buffer not null-terminated" 143 | ); 144 | 145 | // create socket 146 | let fd = unsafe { libc::socket(libc::AF_INET, libc::SOCK_DGRAM, 0) }; 147 | if fd < 0 { 148 | return Err(LinuxTunError::GetMTUIoctlFailed); 149 | } 150 | 151 | // do SIOCGIFMTU ioctl 152 | let buf = arg { 153 | name: *name, 154 | mtu: 0, 155 | }; 156 | let err = unsafe { 157 | let ptr: &libc::c_void = &*(&buf as *const _ as *const libc::c_void); 158 | libc::ioctl(fd, libc::SIOCGIFMTU, ptr) 159 | }; 160 | 161 | // close socket 162 | unsafe { libc::close(fd) }; 163 | 164 | // handle error from ioctl 165 | if err != 0 { 166 | return Err(LinuxTunError::GetMTUIoctlFailed); 167 | } 168 | 169 | // upcast to usize 170 | Ok(buf.mtu as usize) 171 | } 172 | 173 | impl Status for LinuxTunStatus { 174 | type Error = LinuxTunError; 175 | 176 | fn event(&mut self) -> Result { 177 | const DONE: u16 = libc::NLMSG_DONE as u16; 178 | const ERROR: u16 = libc::NLMSG_ERROR as u16; 179 | const INFO_SIZE: usize = mem::size_of::(); 180 | const HDR_SIZE: usize = mem::size_of::(); 181 | 182 | let mut buf = [0u8; 1 << 12]; 183 | log::debug!("netlink, fetch event (fd = {})", self.fd); 184 | loop { 185 | // attempt to return a buffered event 186 | if let Some(event) = self.events.pop() { 187 | return Ok(event); 188 | } 189 | 190 | // read message 191 | let size: libc::ssize_t = 192 | unsafe { libc::recv(self.fd, mem::transmute(&mut buf), buf.len(), 0) }; 193 | if size < 0 { 194 | break Err(LinuxTunError::NetlinkFailure); 195 | } 196 | 197 | // cut buffer to size 198 | let size: usize = size as usize; 199 | let mut remain = &buf[..size]; 200 | 
log::debug!("netlink, received message ({} bytes)", size); 201 | 202 | // handle messages 203 | while remain.len() >= HDR_SIZE { 204 | // extract the header 205 | assert!(remain.len() > HDR_SIZE); 206 | let hdr: libc::nlmsghdr = unsafe { 207 | let mut hdr = [0u8; HDR_SIZE]; 208 | hdr.copy_from_slice(&remain[..HDR_SIZE]); 209 | mem::transmute(hdr) 210 | }; 211 | 212 | // upcast length 213 | let body: &[u8] = &remain[HDR_SIZE..]; 214 | let msg_len: usize = hdr.nlmsg_len as usize; 215 | assert!(msg_len <= remain.len(), "malformed netlink message"); 216 | 217 | // handle message body 218 | match hdr.nlmsg_type { 219 | DONE => break, 220 | ERROR => break, 221 | libc::RTM_NEWLINK => { 222 | // extract info struct 223 | if body.len() < INFO_SIZE { 224 | return Err(LinuxTunError::NetlinkFailure); 225 | } 226 | let info: IfInfomsg = unsafe { 227 | let mut info = [0u8; INFO_SIZE]; 228 | info.copy_from_slice(&body[..INFO_SIZE]); 229 | mem::transmute(info) 230 | }; 231 | 232 | // trace log 233 | log::trace!( 234 | "netlink, IfInfomsg{{ family = {}, type = {}, index = {}, flags = {}, change = {}}}", 235 | info.ifi_family, 236 | info.ifi_type, 237 | info.ifi_index, 238 | info.ifi_flags, 239 | info.ifi_change, 240 | ); 241 | debug_assert_eq!(info.__ifi_pad, 0); 242 | 243 | if info.ifi_index == self.index { 244 | // handle up / down 245 | if info.ifi_flags & (libc::IFF_UP as u32) != 0 { 246 | let mtu = get_mtu(&self.name)?; 247 | log::trace!("netlink, up event, mtu = {}", mtu); 248 | self.events.push(TunEvent::Up(mtu)); 249 | } else { 250 | log::trace!("netlink, down event"); 251 | self.events.push(TunEvent::Down); 252 | } 253 | } 254 | } 255 | _ => (), 256 | }; 257 | 258 | // go to next message 259 | remain = &remain[msg_len..]; 260 | } 261 | } 262 | } 263 | } 264 | 265 | impl LinuxTunStatus { 266 | const RTNLGRP_LINK: libc::c_uint = 1; 267 | const RTNLGRP_IPV4_IFADDR: libc::c_uint = 5; 268 | const RTNLGRP_IPV6_IFADDR: libc::c_uint = 9; 269 | 270 | fn new(name: [u8; 
libc::IFNAMSIZ]) -> Result { 271 | // create netlink socket 272 | let fd = unsafe { libc::socket(libc::AF_NETLINK, libc::SOCK_RAW, libc::NETLINK_ROUTE) }; 273 | if fd < 0 { 274 | return Err(LinuxTunError::Closed); 275 | } 276 | 277 | // prepare address (specify groups) 278 | let groups = (1 << (Self::RTNLGRP_LINK - 1)) 279 | | (1 << (Self::RTNLGRP_IPV4_IFADDR - 1)) 280 | | (1 << (Self::RTNLGRP_IPV6_IFADDR - 1)); 281 | 282 | let mut sockaddr: libc::sockaddr_nl = unsafe { mem::zeroed() }; 283 | sockaddr.nl_family = libc::AF_NETLINK as u16; 284 | sockaddr.nl_groups = groups; 285 | sockaddr.nl_pid = 0; 286 | 287 | // attempt to bind 288 | let res = unsafe { 289 | libc::bind( 290 | fd, 291 | mem::transmute(&mut sockaddr), 292 | mem::size_of::() as u32, 293 | ) 294 | }; 295 | 296 | if res != 0 { 297 | Err(LinuxTunError::Closed) 298 | } else { 299 | Ok(LinuxTunStatus { 300 | events: vec![ 301 | #[cfg(feature = "start_up")] 302 | TunEvent::Up(1500), 303 | ], 304 | index: get_ifindex(&name), 305 | fd, 306 | name, 307 | }) 308 | } 309 | } 310 | } 311 | 312 | impl Tun for LinuxTun { 313 | type Writer = LinuxTunWriter; 314 | type Reader = LinuxTunReader; 315 | type Error = LinuxTunError; 316 | } 317 | 318 | impl PlatformTun for LinuxTun { 319 | type Status = LinuxTunStatus; 320 | 321 | #[allow(clippy::type_complexity)] 322 | fn create(name: &str) -> Result<(Vec, Self::Writer, Self::Status), Self::Error> { 323 | // construct request struct 324 | let mut req = Ifreq { 325 | name: [0u8; libc::IFNAMSIZ], 326 | flags: (libc::IFF_TUN | libc::IFF_NO_PI) as c_short, 327 | _pad: [0u8; 64], 328 | }; 329 | 330 | // sanity check length of device name 331 | let bs = name.as_bytes(); 332 | if bs.len() > libc::IFNAMSIZ - 1 { 333 | return Err(LinuxTunError::InvalidTunDeviceName); 334 | } 335 | req.name[..bs.len()].copy_from_slice(bs); 336 | 337 | // open clone device 338 | let fd: RawFd = match unsafe { libc::open(CLONE_DEVICE_PATH.as_ptr() as _, libc::O_RDWR) } { 339 | -1 => return 
Err(LinuxTunError::FailedToOpenCloneDevice), 340 | fd => fd, 341 | }; 342 | assert!(fd >= 0); 343 | 344 | // create TUN device 345 | if unsafe { libc::ioctl(fd, TUNSETIFF as _, &req) } < 0 { 346 | return Err(LinuxTunError::SetIFFIoctlFailed); 347 | } 348 | 349 | // create PlatformTunMTU instance 350 | Ok(( 351 | vec![LinuxTunReader { fd }], // TODO: use multi-queue for Linux 352 | LinuxTunWriter { fd }, 353 | LinuxTunStatus::new(req.name)?, 354 | )) 355 | } 356 | } 357 | -------------------------------------------------------------------------------- /src/platform/linux/uapi.rs: -------------------------------------------------------------------------------- 1 | use super::super::uapi::*; 2 | 3 | use std::fs; 4 | use std::io; 5 | use std::os::unix::net::{UnixListener, UnixStream}; 6 | 7 | const SOCK_DIR: &str = "/var/run/wireguard/"; 8 | 9 | pub struct LinuxUAPI {} 10 | 11 | impl PlatformUAPI for LinuxUAPI { 12 | type Error = io::Error; 13 | type Bind = UnixListener; 14 | 15 | fn bind(name: &str) -> Result { 16 | let socket_path = format!("{}{}.sock", SOCK_DIR, name); 17 | let _ = fs::create_dir_all(SOCK_DIR); 18 | let _ = fs::remove_file(&socket_path); 19 | UnixListener::bind(socket_path) 20 | } 21 | } 22 | 23 | impl BindUAPI for UnixListener { 24 | type Stream = UnixStream; 25 | type Error = io::Error; 26 | 27 | fn connect(&self) -> Result { 28 | let (stream, _) = self.accept()?; 29 | Ok(stream) 30 | } 31 | } 32 | -------------------------------------------------------------------------------- /src/platform/mod.rs: -------------------------------------------------------------------------------- 1 | mod endpoint; 2 | 3 | pub mod tun; 4 | pub mod uapi; 5 | pub mod udp; 6 | 7 | pub use endpoint::Endpoint; 8 | 9 | #[cfg(target_os = "linux")] 10 | pub mod linux; 11 | 12 | #[cfg(test)] 13 | pub mod dummy; 14 | 15 | #[cfg(target_os = "linux")] 16 | pub use linux as plt; 17 | -------------------------------------------------------------------------------- 
/src/platform/tun.rs: -------------------------------------------------------------------------------- 1 | use std::error::Error; 2 | 3 | pub enum TunEvent { 4 | Up(usize), // interface is up (supply MTU) 5 | Down, // interface is down 6 | } 7 | 8 | pub trait Status: Send + 'static { 9 | type Error: Error; 10 | 11 | /// Returns status updates for the interface 12 | /// When the status is unchanged the method blocks 13 | fn event(&mut self) -> Result; 14 | } 15 | 16 | pub trait Writer: Send + Sync + 'static { 17 | type Error: Error; 18 | 19 | /// Receive a cryptkey routed IP packet 20 | /// 21 | /// # Arguments 22 | /// 23 | /// - src: Buffer containing the IP packet to be written 24 | /// 25 | /// # Returns 26 | /// 27 | /// Unit type or an error 28 | fn write(&self, src: &[u8]) -> Result<(), Self::Error>; 29 | } 30 | 31 | pub trait Reader: Send + 'static { 32 | type Error: Error; 33 | 34 | /// Reads an IP packet into dst[offset:] from the tunnel device 35 | /// 36 | /// The reason for providing space for a prefix 37 | /// is to efficiently accommodate platforms on which the packet is prefaced by a header. 38 | /// This space is later used to construct the transport message inplace. 39 | /// 40 | /// # Arguments 41 | /// 42 | /// - buf: Destination buffer (enough space for MTU bytes + header) 43 | /// - offset: Offset for the beginning of the IP packet 44 | /// 45 | /// # Returns 46 | /// 47 | /// The size of the IP packet (ignoring the header) or an std::error::Error instance: 48 | fn read(&self, buf: &mut [u8], offset: usize) -> Result; 49 | } 50 | 51 | pub trait Tun: Send + Sync + 'static { 52 | type Writer: Writer; 53 | type Reader: Reader; 54 | type Error: Error; 55 | } 56 | 57 | /// On some platforms the application can create the TUN device itself. 
58 | pub trait PlatformTun: Tun { 59 | type Status: Status; 60 | 61 | #[allow(clippy::type_complexity)] 62 | fn create(name: &str) -> Result<(Vec, Self::Writer, Self::Status), Self::Error>; 63 | } 64 | -------------------------------------------------------------------------------- /src/platform/uapi.rs: -------------------------------------------------------------------------------- 1 | use std::error::Error; 2 | use std::io::{Read, Write}; 3 | 4 | pub trait BindUAPI { 5 | type Stream: Read + Write; 6 | type Error: Error; 7 | 8 | fn connect(&self) -> Result; 9 | } 10 | 11 | pub trait PlatformUAPI { 12 | type Error: Error; 13 | type Bind: BindUAPI; 14 | 15 | fn bind(name: &str) -> Result; 16 | } 17 | -------------------------------------------------------------------------------- /src/platform/udp.rs: -------------------------------------------------------------------------------- 1 | use super::Endpoint; 2 | use std::error::Error; 3 | 4 | pub trait Reader: Send + Sync { 5 | type Error: Error; 6 | 7 | fn read(&self, buf: &mut [u8]) -> Result<(usize, E), Self::Error>; 8 | } 9 | 10 | pub trait Writer: Send + Sync + 'static { 11 | type Error: Error; 12 | 13 | fn write(&self, buf: &[u8], dst: &mut E) -> Result<(), Self::Error>; 14 | } 15 | 16 | pub trait UDP: Send + Sync + 'static { 17 | type Error: Error; 18 | type Endpoint: Endpoint; 19 | 20 | /* Until Rust gets type equality constraints these have to be generic */ 21 | type Writer: Writer; 22 | type Reader: Reader; 23 | } 24 | 25 | /// On platforms where fwmark can be set and the 26 | /// implementation can bind to a new port during later configuration (UAPI support), 27 | /// this type provides the ability to set the fwmark and close the socket (by dropping the instance) 28 | pub trait Owner: Send { 29 | type Error: Error; 30 | 31 | fn get_port(&self) -> u16; 32 | 33 | fn set_fwmark(&mut self, value: Option) -> Result<(), Self::Error>; 34 | } 35 | 36 | /// On some platforms the application can itself bind to a 
socket.
/// This enables configuration using the UAPI interface.
pub trait PlatformUDP: UDP {
    type Owner: Owner;

    /// Bind to a new port, returning the reader/writer and
    /// an associated instance of the owner type, which closes the UDP socket upon "drop"
    /// and enables configuration of the fwmark value.
    #[allow(clippy::type_complexity)]
    fn bind(port: u16) -> Result<(Vec<Self::Reader>, Self::Writer, Self::Owner), Self::Error>;
}

--------------------------------------------------------------------------------
/src/util.rs:
--------------------------------------------------------------------------------
use std::cmp::Ordering;
use std::fmt;
use std::process::exit;
use std::ptr;

use libc::{
    c_char, chdir, chroot, fork, getpwnam, getuid, setgid, setgroups, setsid, setuid, umask,
};

#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone)]
pub enum DaemonizeError {
    Fork,
    SetSession,
    SetGroup,
    SetUser,
    Chroot,
    Chdir,
}

impl fmt::Display for DaemonizeError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match *self {
            DaemonizeError::Fork => "unable to fork",
            DaemonizeError::SetSession => "unable to create new process session",
            DaemonizeError::SetGroup => "unable to set group (drop privileges)",
            DaemonizeError::SetUser => "unable to set user (drop privileges)",
            DaemonizeError::Chroot => "unable to enter chroot jail",
            DaemonizeError::Chdir => "failed to change directory",
        }
        .fmt(f)
    }
}

/// Forks; the parent exits immediately with status 0 and the child returns `Ok(())`.
fn fork_and_exit() -> Result<(), DaemonizeError> {
    let pid = unsafe { fork() };
    match pid.cmp(&0) {
        Ordering::Less => Err(DaemonizeError::Fork),
        Ordering::Equal => Ok(()),
        Ordering::Greater => exit(0),
    }
}

/// Detaches the process from its controlling terminal (fork, setsid, fork).
pub fn daemonize() -> Result<(), DaemonizeError> {
    // fork from the original parent
    fork_and_exit()?;

    // avoid killing the child when this parent dies
    if unsafe { setsid() } < 0 {
        return Err(DaemonizeError::SetSession);
    }

    // fork again to create orphan
    fork_and_exit()
}

/// Drops root privileges by switching to the `nobody` user.
///
/// When running as root, the process is first chrooted into `/tmp`. Supplementary
/// groups are then cleared, and the group and user ids are set to nobody's.
///
/// # Errors
///
/// Returns the `DaemonizeError` variant of the step that failed.
pub fn drop_privileges() -> Result<(), DaemonizeError> {
    // retrieve nobody's uid & gid
    let usr = unsafe { getpwnam("nobody\x00".as_ptr() as *const c_char) };
    if usr.is_null() {
        return Err(DaemonizeError::SetGroup);
    }
    // copy the ids out of libc's static passwd buffer before any further libc calls
    // SAFETY: `usr` was checked to be non-null above.
    let (nobody_uid, nobody_gid) = unsafe { ((*usr).pw_uid, (*usr).pw_gid) };

    // change root directory
    let uid = unsafe { getuid() };
    if uid == 0 && unsafe { chroot("/tmp\x00".as_ptr() as *const c_char) } != 0 {
        return Err(DaemonizeError::Chroot);
    }

    // set umask for files
    unsafe { umask(0) };

    // change directory
    if unsafe { chdir("/\x00".as_ptr() as *const c_char) } != 0 {
        return Err(DaemonizeError::Chdir);
    }

    // clear supplementary groups while still privileged; otherwise the process
    // keeps e.g. root's groups after setgid/setuid. A non-root caller without
    // CAP_SETGID cannot do this, and the setgid call below decides its outcome.
    if unsafe { setgroups(0, ptr::null()) } != 0 && uid == 0 {
        return Err(DaemonizeError::SetGroup);
    }

    // set group id to nobody
    if unsafe { setgid(nobody_gid) } != 0 {
        return Err(DaemonizeError::SetGroup);
    }

    // set user id to nobody (must come last: afterwards the group can no longer change)
    if unsafe { setuid(nobody_uid) } != 0 {
        Err(DaemonizeError::SetUser)
    } else {
        Ok(())
    }
}

--------------------------------------------------------------------------------
/src/wireguard/constants.rs:
--------------------------------------------------------------------------------
use std::time::Duration;
use std::u64;

pub const REKEY_AFTER_MESSAGES: u64 = 1 << 60;
pub const REJECT_AFTER_MESSAGES: u64 = u64::MAX - (1 << 4);

pub const REKEY_AFTER_TIME: Duration = Duration::from_secs(120);
pub const REJECT_AFTER_TIME: Duration = Duration::from_secs(180);
pub const REKEY_ATTEMPT_TIME: Duration = Duration::from_secs(90);
pub const REKEY_TIMEOUT: Duration = Duration::from_secs(5);
pub const KEEPALIVE_TIMEOUT: Duration = Duration::from_secs(10);

pub const MAX_TIMER_HANDSHAKES: usize =
    (REKEY_ATTEMPT_TIME.as_secs() / REKEY_TIMEOUT.as_secs()) as usize;

// Semantics:
// Maximum number of buffered handshake
requests 18 | // (either from outside message or handshake requests triggered locally) 19 | pub const MAX_QUEUED_INCOMING_HANDSHAKES: usize = 4096; 20 | 21 | // Semantics: 22 | // When the number of queued handshake requests exceeds this number 23 | // the device is considered under load and DoS mitigation is triggered. 24 | pub const THRESHOLD_UNDER_LOAD: usize = MAX_QUEUED_INCOMING_HANDSHAKES / 8; 25 | 26 | // Semantics: 27 | // When a device is detected to go under load, 28 | // it will remain under load for at least the following duration. 29 | pub const DURATION_UNDER_LOAD: Duration = Duration::from_secs(1); 30 | 31 | // Semantics: 32 | // The payload of transport messages are padded to this multiple 33 | pub const MESSAGE_PADDING_MULTIPLE: usize = 16; 34 | 35 | // Semantics: 36 | // Longest possible duration of any WireGuard timer 37 | pub const TIMER_MAX_DURATION: Duration = Duration::from_secs(200); 38 | 39 | // Semantics: 40 | // Resolution of the timer-wheel 41 | pub const TIMERS_TICK: Duration = Duration::from_millis(100); 42 | 43 | // Semantics: 44 | // Resulting number of slots in the wheel 45 | pub const TIMERS_SLOTS: usize = (TIMER_MAX_DURATION.as_micros() / TIMERS_TICK.as_micros()) as usize; 46 | 47 | // Performance: 48 | // Initial capacity of timer-wheel (grows to accommodate more timers). 49 | pub const TIMERS_CAPACITY: usize = 16; 50 | 51 | /* A long duration (compared to the WireGuard time constants), 52 | * used in places to avoid Option by instead using a long "expired" Instant: 53 | * (Instant::now() - TIME_HORIZON) 54 | * 55 | * Note, this duration need not fit inside the timer wheel. 
56 | */ 57 | pub const TIME_HORIZON: Duration = Duration::from_secs(TIMER_MAX_DURATION.as_secs() * 2); 58 | -------------------------------------------------------------------------------- /src/wireguard/handshake/macs.rs: -------------------------------------------------------------------------------- 1 | use generic_array::GenericArray; 2 | use rand_core::{CryptoRng, RngCore}; 3 | use spin::RwLock; 4 | use std::time::{Duration, Instant}; 5 | 6 | // types to coalesce into bytes 7 | use std::net::SocketAddr; 8 | use x25519_dalek::PublicKey; 9 | 10 | // AEAD 11 | 12 | use aead::{Aead, NewAead, Payload}; 13 | use chacha20poly1305::XChaCha20Poly1305; 14 | 15 | // MAC 16 | use blake2::Blake2s; 17 | use subtle::ConstantTimeEq; 18 | 19 | use super::messages::{CookieReply, MacsFooter, TYPE_COOKIE_REPLY}; 20 | use super::types::HandshakeError; 21 | 22 | const LABEL_MAC1: &[u8] = b"mac1----"; 23 | const LABEL_COOKIE: &[u8] = b"cookie--"; 24 | 25 | const SIZE_COOKIE: usize = 16; 26 | const SIZE_SECRET: usize = 32; 27 | const SIZE_MAC: usize = 16; // blake2s-mac128 28 | const SIZE_TAG: usize = 16; // xchacha20poly1305 tag 29 | 30 | const COOKIE_UPDATE_INTERVAL: Duration = Duration::from_secs(120); 31 | 32 | macro_rules! HASH { 33 | ( $($input:expr),* ) => {{ 34 | use blake2::Digest; 35 | let mut hsh = Blake2s::new(); 36 | $( 37 | hsh.update($input); 38 | )* 39 | hsh.finalize() 40 | }}; 41 | } 42 | 43 | macro_rules! MAC { 44 | ( $key:expr, $($input:expr),* ) => {{ 45 | use blake2::VarBlake2s; 46 | use blake2::digest::{Update, VariableOutput}; 47 | let mut tag = [0u8; SIZE_MAC]; 48 | let mut mac = VarBlake2s::new_keyed($key, SIZE_MAC); 49 | $( 50 | mac.update($input); 51 | )* 52 | mac.finalize_variable(|buf| tag.copy_from_slice(buf)); 53 | tag 54 | }}; 55 | } 56 | 57 | macro_rules! 
XSEAL { 58 | ($key:expr, $nonce:expr, $ad:expr, $pt:expr, $ct:expr) => {{ 59 | let ct = XChaCha20Poly1305::new(GenericArray::from_slice($key)) 60 | .encrypt( 61 | GenericArray::from_slice($nonce), 62 | Payload { msg: $pt, aad: $ad }, 63 | ) 64 | .unwrap(); 65 | debug_assert_eq!(ct.len(), $pt.len() + SIZE_TAG); 66 | $ct.copy_from_slice(&ct); 67 | }}; 68 | } 69 | 70 | macro_rules! XOPEN { 71 | ($key:expr, $nonce:expr, $ad:expr, $pt:expr, $ct:expr) => {{ 72 | debug_assert_eq!($ct.len(), $pt.len() + SIZE_TAG); 73 | XChaCha20Poly1305::new(GenericArray::from_slice($key)) 74 | .decrypt( 75 | GenericArray::from_slice($nonce), 76 | Payload { msg: $ct, aad: $ad }, 77 | ) 78 | .map_err(|_| HandshakeError::DecryptionFailure) 79 | .map(|pt| $pt.copy_from_slice(&pt)) 80 | }}; 81 | } 82 | 83 | struct Cookie { 84 | value: [u8; 16], 85 | birth: Instant, 86 | } 87 | 88 | pub struct Generator { 89 | mac1_key: [u8; 32], 90 | cookie_key: [u8; 32], // xchacha20poly key for opening cookie response 91 | last_mac1: Option<[u8; 16]>, 92 | cookie: Option, 93 | } 94 | 95 | fn addr_to_mac_bytes(addr: &SocketAddr) -> Vec { 96 | match addr { 97 | SocketAddr::V4(addr) => { 98 | let mut res = Vec::with_capacity(4 + 2); 99 | res.extend(&addr.ip().octets()); 100 | res.extend(&addr.port().to_le_bytes()); 101 | res 102 | } 103 | SocketAddr::V6(addr) => { 104 | let mut res = Vec::with_capacity(16 + 2); 105 | res.extend(&addr.ip().octets()); 106 | res.extend(&addr.port().to_le_bytes()); 107 | res 108 | } 109 | } 110 | } 111 | 112 | impl Generator { 113 | /// Initalize a new mac field generator 114 | /// 115 | /// # Arguments 116 | /// 117 | /// - pk: The public key of the peer to which the generator is associated 118 | /// 119 | /// # Returns 120 | /// 121 | /// A freshly initated generator 122 | pub fn new(pk: PublicKey) -> Generator { 123 | Generator { 124 | mac1_key: HASH!(LABEL_MAC1, pk.as_bytes()).into(), 125 | cookie_key: HASH!(LABEL_COOKIE, pk.as_bytes()).into(), 126 | last_mac1: None, 127 | 
cookie: None, 128 | } 129 | } 130 | 131 | /// Process a CookieReply message 132 | /// 133 | /// # Arguments 134 | /// 135 | /// - reply: CookieReply to process 136 | /// 137 | /// # Returns 138 | /// 139 | /// Can fail if the cookie reply fails to validate 140 | /// (either indicating that it is outdated or malformed) 141 | pub fn process(&mut self, reply: &CookieReply) -> Result<(), HandshakeError> { 142 | let mac1 = self.last_mac1.ok_or(HandshakeError::InvalidState)?; 143 | let mut tau = [0u8; SIZE_COOKIE]; 144 | #[allow(clippy::unnecessary_mut_passed)] 145 | XOPEN!( 146 | &self.cookie_key, // key 147 | &reply.f_nonce, // nonce 148 | &mac1, // ad 149 | &mut tau, // pt 150 | &reply.f_cookie // ct || tag 151 | )?; 152 | self.cookie = Some(Cookie { 153 | birth: Instant::now(), 154 | value: tau, 155 | }); 156 | Ok(()) 157 | } 158 | 159 | /// Generate both mac fields for an inner message 160 | /// 161 | /// # Arguments 162 | /// 163 | /// - inner: A byteslice representing the inner message to be covered 164 | /// - macs: The destination mac footer for the resulting macs 165 | pub fn generate(&mut self, inner: &[u8], macs: &mut MacsFooter) { 166 | macs.f_mac1 = MAC!(&self.mac1_key, inner); 167 | macs.f_mac2 = match &self.cookie { 168 | Some(cookie) => { 169 | if cookie.birth.elapsed() > COOKIE_UPDATE_INTERVAL { 170 | self.cookie = None; 171 | [0u8; SIZE_MAC] 172 | } else { 173 | MAC!(&cookie.value, inner, macs.f_mac1) 174 | } 175 | } 176 | None => [0u8; SIZE_MAC], 177 | }; 178 | self.last_mac1 = Some(macs.f_mac1); 179 | } 180 | } 181 | 182 | struct Secret { 183 | value: [u8; 32], 184 | birth: Instant, 185 | } 186 | 187 | pub struct Validator { 188 | mac1_key: [u8; 32], // mac1 key, derived from device public key 189 | cookie_key: [u8; 32], // xchacha20poly key for sealing cookie response 190 | secret: RwLock, 191 | } 192 | 193 | impl Validator { 194 | pub fn new(pk: PublicKey) -> Validator { 195 | Validator { 196 | mac1_key: HASH!(LABEL_MAC1, pk.as_bytes()).into(), 197 
| cookie_key: HASH!(LABEL_COOKIE, pk.as_bytes()).into(), 198 | secret: RwLock::new(Secret { 199 | value: [0u8; SIZE_SECRET], 200 | birth: Instant::now() - Duration::new(86400, 0), 201 | }), 202 | } 203 | } 204 | 205 | fn get_tau(&self, src: &[u8]) -> Option<[u8; SIZE_COOKIE]> { 206 | let secret = self.secret.read(); 207 | if secret.birth.elapsed() < COOKIE_UPDATE_INTERVAL { 208 | Some(MAC!(&secret.value, src)) 209 | } else { 210 | None 211 | } 212 | } 213 | 214 | fn get_set_tau(&self, rng: &mut R, src: &[u8]) -> [u8; SIZE_COOKIE] { 215 | // check if current value is still valid 216 | { 217 | let secret = self.secret.read(); 218 | if secret.birth.elapsed() < COOKIE_UPDATE_INTERVAL { 219 | return MAC!(&secret.value, src); 220 | }; 221 | } 222 | 223 | // take write lock, check again 224 | { 225 | let mut secret = self.secret.write(); 226 | if secret.birth.elapsed() < COOKIE_UPDATE_INTERVAL { 227 | return MAC!(&secret.value, src); 228 | }; 229 | 230 | // set new random cookie secret 231 | rng.fill_bytes(&mut secret.value); 232 | secret.birth = Instant::now(); 233 | MAC!(&secret.value, src) 234 | } 235 | } 236 | 237 | pub fn create_cookie_reply( 238 | &self, 239 | rng: &mut R, 240 | receiver: u32, // receiver id of incoming message 241 | src: &SocketAddr, // source address of incoming message 242 | macs: &MacsFooter, // footer of incoming message 243 | msg: &mut CookieReply, // resulting cookie reply 244 | ) { 245 | let src = addr_to_mac_bytes(src); 246 | msg.f_type.set(TYPE_COOKIE_REPLY as u32); 247 | msg.f_receiver.set(receiver); 248 | rng.fill_bytes(&mut msg.f_nonce); 249 | XSEAL!( 250 | &self.cookie_key, // key 251 | &msg.f_nonce, // nonce 252 | &macs.f_mac1, // ad 253 | &self.get_set_tau(rng, &src), // pt 254 | &mut msg.f_cookie // ct || tag 255 | ); 256 | } 257 | 258 | /// Check the mac1 field against the inner message 259 | /// 260 | /// # Arguments 261 | /// 262 | /// - inner: The inner message covered by the mac1 field 263 | /// - macs: The mac footer 264 | pub 
fn check_mac1(&self, inner: &[u8], macs: &MacsFooter) -> Result<(), HandshakeError> { 265 | let valid_mac1: bool = MAC!(&self.mac1_key, inner).ct_eq(&macs.f_mac1).into(); 266 | if !valid_mac1 { 267 | Err(HandshakeError::InvalidMac1) 268 | } else { 269 | Ok(()) 270 | } 271 | } 272 | 273 | pub fn check_mac2(&self, inner: &[u8], src: &SocketAddr, macs: &MacsFooter) -> bool { 274 | let src = addr_to_mac_bytes(src); 275 | match self.get_tau(&src) { 276 | Some(tau) => MAC!(&tau, inner, macs.f_mac1).ct_eq(&macs.f_mac2).into(), 277 | None => false, 278 | } 279 | } 280 | } 281 | 282 | #[cfg(test)] 283 | mod tests { 284 | use super::*; 285 | use proptest::prelude::*; 286 | use rand::rngs::OsRng; 287 | use x25519_dalek::StaticSecret; 288 | 289 | fn new_validator_generator() -> (Validator, Generator) { 290 | let sk = StaticSecret::new(&mut OsRng); 291 | let pk = PublicKey::from(&sk); 292 | (Validator::new(pk), Generator::new(pk)) 293 | } 294 | 295 | proptest! { 296 | #[test] 297 | fn test_cookie_reply(inner1 : Vec, inner2 : Vec, receiver : u32) { 298 | let mut msg = CookieReply::default(); 299 | let mut macs = MacsFooter::default(); 300 | let src = "192.0.2.16:8080".parse().unwrap(); 301 | let (validator, mut generator) = new_validator_generator(); 302 | 303 | // generate mac1 for first message 304 | generator.generate(&inner1[..], &mut macs); 305 | assert_ne!(macs.f_mac1, [0u8; SIZE_MAC], "mac1 should be set"); 306 | assert_eq!(macs.f_mac2, [0u8; SIZE_MAC], "mac2 should not be set"); 307 | 308 | // check validity of mac1 309 | validator.check_mac1(&inner1[..], &macs).expect("mac1 of inner1 did not validate"); 310 | assert_eq!(validator.check_mac2(&inner1[..], &src, &macs), false, "mac2 of inner2 did not validate"); 311 | validator.create_cookie_reply(&mut OsRng, receiver, &src, &macs, &mut msg); 312 | 313 | // consume cookie reply 314 | generator.process(&msg).expect("failed to process CookieReply"); 315 | 316 | // generate mac2 & mac2 for second message 317 | 
generator.generate(&inner2[..], &mut macs); 318 | assert_ne!(macs.f_mac1, [0u8; SIZE_MAC], "mac1 should be set"); 319 | assert_ne!(macs.f_mac2, [0u8; SIZE_MAC], "mac2 should be set"); 320 | 321 | // check validity of mac1 and mac2 322 | validator.check_mac1(&inner2[..], &macs).expect("mac1 of inner2 did not validate"); 323 | assert!(validator.check_mac2(&inner2[..], &src, &macs), "mac2 of inner2 did not validate"); 324 | } 325 | } 326 | } 327 | -------------------------------------------------------------------------------- /src/wireguard/handshake/messages.rs: -------------------------------------------------------------------------------- 1 | #[cfg(test)] 2 | use hex; 3 | 4 | #[cfg(test)] 5 | use std::fmt; 6 | 7 | use std::mem; 8 | 9 | use byteorder::LittleEndian; 10 | use zerocopy::byteorder::U32; 11 | use zerocopy::{AsBytes, ByteSlice, FromBytes, LayoutVerified}; 12 | 13 | use super::types::*; 14 | 15 | const SIZE_MAC: usize = 16; 16 | const SIZE_TAG: usize = 16; // poly1305 tag 17 | const SIZE_XNONCE: usize = 24; // xchacha20 nonce 18 | const SIZE_COOKIE: usize = 16; // 19 | const SIZE_X25519_POINT: usize = 32; // x25519 public key 20 | const SIZE_TIMESTAMP: usize = 12; 21 | 22 | pub const TYPE_INITIATION: u32 = 1; 23 | pub const TYPE_RESPONSE: u32 = 2; 24 | pub const TYPE_COOKIE_REPLY: u32 = 3; 25 | 26 | const fn max(a: usize, b: usize) -> usize { 27 | let m: usize = (a > b) as usize; 28 | m * a + (1 - m) * b 29 | } 30 | 31 | pub const MAX_HANDSHAKE_MSG_SIZE: usize = max( 32 | max(mem::size_of::(), mem::size_of::()), 33 | mem::size_of::(), 34 | ); 35 | 36 | /* Handshake messsages */ 37 | 38 | #[repr(packed)] 39 | #[derive(Copy, Clone, FromBytes, AsBytes)] 40 | pub struct Response { 41 | pub noise: NoiseResponse, // inner message covered by macs 42 | pub macs: MacsFooter, 43 | } 44 | 45 | #[repr(packed)] 46 | #[derive(Copy, Clone, FromBytes, AsBytes)] 47 | pub struct Initiation { 48 | pub noise: NoiseInitiation, // inner message covered by macs 49 | pub macs: 
MacsFooter, 50 | } 51 | 52 | #[repr(packed)] 53 | #[derive(Copy, Clone, FromBytes, AsBytes)] 54 | pub struct CookieReply { 55 | pub f_type: U32, 56 | pub f_receiver: U32, 57 | pub f_nonce: [u8; SIZE_XNONCE], 58 | pub f_cookie: [u8; SIZE_COOKIE + SIZE_TAG], 59 | } 60 | 61 | /* Inner sub-messages */ 62 | 63 | #[repr(packed)] 64 | #[derive(Copy, Clone, FromBytes, AsBytes)] 65 | pub struct MacsFooter { 66 | pub f_mac1: [u8; SIZE_MAC], 67 | pub f_mac2: [u8; SIZE_MAC], 68 | } 69 | 70 | #[repr(packed)] 71 | #[derive(Copy, Clone, FromBytes, AsBytes)] 72 | pub struct NoiseInitiation { 73 | pub f_type: U32, 74 | pub f_sender: U32, 75 | pub f_ephemeral: [u8; SIZE_X25519_POINT], 76 | pub f_static: [u8; SIZE_X25519_POINT + SIZE_TAG], 77 | pub f_timestamp: [u8; SIZE_TIMESTAMP + SIZE_TAG], 78 | } 79 | 80 | #[repr(packed)] 81 | #[derive(Copy, Clone, FromBytes, AsBytes)] 82 | pub struct NoiseResponse { 83 | pub f_type: U32, 84 | pub f_sender: U32, 85 | pub f_receiver: U32, 86 | pub f_ephemeral: [u8; SIZE_X25519_POINT], 87 | pub f_empty: [u8; SIZE_TAG], 88 | } 89 | 90 | /* Zero copy parsing of handshake messages */ 91 | 92 | impl Initiation { 93 | pub fn parse(bytes: B) -> Result, HandshakeError> { 94 | let msg: LayoutVerified = 95 | LayoutVerified::new(bytes).ok_or(HandshakeError::InvalidMessageFormat)?; 96 | 97 | if msg.noise.f_type.get() != (TYPE_INITIATION as u32) { 98 | return Err(HandshakeError::InvalidMessageFormat); 99 | } 100 | 101 | Ok(msg) 102 | } 103 | } 104 | 105 | impl Response { 106 | pub fn parse(bytes: B) -> Result, HandshakeError> { 107 | let msg: LayoutVerified = 108 | LayoutVerified::new(bytes).ok_or(HandshakeError::InvalidMessageFormat)?; 109 | 110 | if msg.noise.f_type.get() != (TYPE_RESPONSE as u32) { 111 | return Err(HandshakeError::InvalidMessageFormat); 112 | } 113 | 114 | Ok(msg) 115 | } 116 | } 117 | 118 | impl CookieReply { 119 | pub fn parse(bytes: B) -> Result, HandshakeError> { 120 | let msg: LayoutVerified = 121 | 
LayoutVerified::new(bytes).ok_or(HandshakeError::InvalidMessageFormat)?; 122 | 123 | if msg.f_type.get() != (TYPE_COOKIE_REPLY as u32) { 124 | return Err(HandshakeError::InvalidMessageFormat); 125 | } 126 | 127 | Ok(msg) 128 | } 129 | } 130 | 131 | /* Default values */ 132 | 133 | impl Default for Response { 134 | fn default() -> Self { 135 | Self { 136 | noise: Default::default(), 137 | macs: Default::default(), 138 | } 139 | } 140 | } 141 | 142 | impl Default for Initiation { 143 | fn default() -> Self { 144 | Self { 145 | noise: Default::default(), 146 | macs: Default::default(), 147 | } 148 | } 149 | } 150 | 151 | impl Default for CookieReply { 152 | fn default() -> Self { 153 | Self { 154 | f_type: >::new(TYPE_COOKIE_REPLY as u32), 155 | f_receiver: >::ZERO, 156 | f_nonce: [0u8; SIZE_XNONCE], 157 | f_cookie: [0u8; SIZE_COOKIE + SIZE_TAG], 158 | } 159 | } 160 | } 161 | 162 | impl Default for MacsFooter { 163 | fn default() -> Self { 164 | Self { 165 | f_mac1: [0u8; SIZE_MAC], 166 | f_mac2: [0u8; SIZE_MAC], 167 | } 168 | } 169 | } 170 | 171 | impl Default for NoiseInitiation { 172 | fn default() -> Self { 173 | Self { 174 | f_type: >::new(TYPE_INITIATION as u32), 175 | f_sender: >::ZERO, 176 | f_ephemeral: [0u8; SIZE_X25519_POINT], 177 | f_static: [0u8; SIZE_X25519_POINT + SIZE_TAG], 178 | f_timestamp: [0u8; SIZE_TIMESTAMP + SIZE_TAG], 179 | } 180 | } 181 | } 182 | 183 | impl Default for NoiseResponse { 184 | fn default() -> Self { 185 | Self { 186 | f_type: >::new(TYPE_RESPONSE as u32), 187 | f_sender: >::ZERO, 188 | f_receiver: >::ZERO, 189 | f_ephemeral: [0u8; SIZE_X25519_POINT], 190 | f_empty: [0u8; SIZE_TAG], 191 | } 192 | } 193 | } 194 | 195 | /* Debug formatting (for testing purposes) */ 196 | 197 | #[cfg(test)] 198 | impl fmt::Debug for Initiation { 199 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 200 | write!(f, "Initiation {{ {:?} || {:?} }}", self.noise, self.macs) 201 | } 202 | } 203 | 204 | #[cfg(test)] 205 | impl fmt::Debug for 
Response { 206 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 207 | write!(f, "Response {{ {:?} || {:?} }}", self.noise, self.macs) 208 | } 209 | } 210 | 211 | #[cfg(test)] 212 | impl fmt::Debug for CookieReply { 213 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 214 | write!( 215 | f, 216 | "CookieReply {{ type = {}, receiver = {}, nonce = {}, cookie = {} }}", 217 | self.f_type, 218 | self.f_receiver, 219 | hex::encode(&self.f_nonce[..]), 220 | hex::encode(&self.f_cookie[..]), 221 | ) 222 | } 223 | } 224 | 225 | #[cfg(test)] 226 | impl fmt::Debug for NoiseInitiation { 227 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 228 | write!(f, 229 | "NoiseInitiation {{ type = {}, sender = {}, ephemeral = {}, static = {}, timestamp = {} }}", 230 | self.f_type.get(), 231 | self.f_sender.get(), 232 | hex::encode(&self.f_ephemeral[..]), 233 | hex::encode(&self.f_static[..]), 234 | hex::encode(&self.f_timestamp[..]), 235 | ) 236 | } 237 | } 238 | 239 | #[cfg(test)] 240 | impl fmt::Debug for NoiseResponse { 241 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 242 | write!(f, 243 | "NoiseResponse {{ type = {}, sender = {}, receiver = {}, ephemeral = {}, empty = |{} }}", 244 | self.f_type, 245 | self.f_sender, 246 | self.f_receiver, 247 | hex::encode(&self.f_ephemeral[..]), 248 | hex::encode(&self.f_empty[..]) 249 | ) 250 | } 251 | } 252 | 253 | #[cfg(test)] 254 | impl fmt::Debug for MacsFooter { 255 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 256 | write!( 257 | f, 258 | "Macs {{ mac1 = {}, mac2 = {} }}", 259 | hex::encode(&self.f_mac1[..]), 260 | hex::encode(&self.f_mac2[..]) 261 | ) 262 | } 263 | } 264 | 265 | /* Equality (for testing purposes) */ 266 | 267 | #[cfg(test)] 268 | macro_rules! 
eq_as_bytes { 269 | ($type:path) => { 270 | impl PartialEq for $type { 271 | fn eq(&self, other: &Self) -> bool { 272 | self.as_bytes() == other.as_bytes() 273 | } 274 | } 275 | impl Eq for $type {} 276 | }; 277 | } 278 | 279 | #[cfg(test)] 280 | eq_as_bytes!(Initiation); 281 | 282 | #[cfg(test)] 283 | eq_as_bytes!(Response); 284 | 285 | #[cfg(test)] 286 | eq_as_bytes!(CookieReply); 287 | 288 | #[cfg(test)] 289 | eq_as_bytes!(MacsFooter); 290 | 291 | #[cfg(test)] 292 | eq_as_bytes!(NoiseInitiation); 293 | 294 | #[cfg(test)] 295 | eq_as_bytes!(NoiseResponse); 296 | 297 | #[cfg(test)] 298 | mod tests { 299 | use super::*; 300 | 301 | #[test] 302 | fn message_response_identity() { 303 | let mut msg: Response = Default::default(); 304 | 305 | msg.noise.f_sender.set(146252); 306 | msg.noise.f_receiver.set(554442); 307 | msg.noise.f_ephemeral = [ 308 | 0xc1, 0x66, 0x0a, 0x0c, 0xdc, 0x0f, 0x6c, 0x51, 0x0f, 0xc2, 0xcc, 0x51, 0x52, 0x0c, 309 | 0xde, 0x1e, 0xf7, 0xf1, 0xca, 0x90, 0x86, 0x72, 0xad, 0x67, 0xea, 0x89, 0x45, 0x44, 310 | 0x13, 0x56, 0x52, 0x1f, 311 | ]; 312 | msg.noise.f_empty = [ 313 | 0x60, 0x0e, 0x1e, 0x95, 0x41, 0x6b, 0x52, 0x05, 0xa2, 0x09, 0xe1, 0xbf, 0x40, 0x05, 314 | 0x2f, 0xde, 315 | ]; 316 | msg.macs.f_mac1 = [ 317 | 0xf2, 0xad, 0x40, 0xb5, 0xf7, 0xde, 0x77, 0x35, 0x89, 0x19, 0xb7, 0x5c, 0xf9, 0x54, 318 | 0x69, 0x29, 319 | ]; 320 | msg.macs.f_mac2 = [ 321 | 0x4f, 0xd2, 0x1b, 0xfe, 0x77, 0xe6, 0x2e, 0xc9, 0x07, 0xe2, 0x87, 0x17, 0xbb, 0xe5, 322 | 0xdf, 0xbb, 323 | ]; 324 | 325 | let buf: Vec = msg.as_bytes().to_vec(); 326 | let msg_p = Response::parse(&buf[..]).unwrap(); 327 | assert_eq!(msg, *msg_p.into_ref()); 328 | } 329 | 330 | #[test] 331 | fn message_initiate_identity() { 332 | let mut msg: Initiation = Default::default(); 333 | 334 | msg.noise.f_sender.set(575757); 335 | msg.noise.f_ephemeral = [ 336 | 0xc1, 0x66, 0x0a, 0x0c, 0xdc, 0x0f, 0x6c, 0x51, 0x0f, 0xc2, 0xcc, 0x51, 0x52, 0x0c, 337 | 0xde, 0x1e, 0xf7, 0xf1, 0xca, 0x90, 0x86, 0x72, 0xad, 
0x67, 0xea, 0x89, 0x45, 0x44, 338 | 0x13, 0x56, 0x52, 0x1f, 339 | ]; 340 | msg.noise.f_static = [ 341 | 0xdc, 0x33, 0x90, 0x15, 0x8f, 0x82, 0x3e, 0x06, 0x44, 0xa0, 0xde, 0x4c, 0x15, 0x6c, 342 | 0x5d, 0xa4, 0x65, 0x99, 0xf6, 0x6c, 0xa1, 0x14, 0x77, 0xf9, 0xeb, 0x6a, 0xec, 0xc3, 343 | 0x3c, 0xda, 0x47, 0xe1, 0x45, 0xac, 0x8d, 0x43, 0xea, 0x1b, 0x2f, 0x02, 0x45, 0x5d, 344 | 0x86, 0x37, 0xee, 0x83, 0x6b, 0x42, 345 | ]; 346 | msg.noise.f_timestamp = [ 347 | 0x4f, 0x1c, 0x60, 0xec, 0x0e, 0xf6, 0x36, 0xf0, 0x78, 0x28, 0x57, 0x42, 0x60, 0x0e, 348 | 0x1e, 0x95, 0x41, 0x6b, 0x52, 0x05, 0xa2, 0x09, 0xe1, 0xbf, 0x40, 0x05, 0x2f, 0xde, 349 | ]; 350 | msg.macs.f_mac1 = [ 351 | 0xf2, 0xad, 0x40, 0xb5, 0xf7, 0xde, 0x77, 0x35, 0x89, 0x19, 0xb7, 0x5c, 0xf9, 0x54, 352 | 0x69, 0x29, 353 | ]; 354 | msg.macs.f_mac2 = [ 355 | 0x4f, 0xd2, 0x1b, 0xfe, 0x77, 0xe6, 0x2e, 0xc9, 0x07, 0xe2, 0x87, 0x17, 0xbb, 0xe5, 356 | 0xdf, 0xbb, 357 | ]; 358 | 359 | let buf: Vec = msg.as_bytes().to_vec(); 360 | let msg_p = Initiation::parse(&buf[..]).unwrap(); 361 | assert_eq!(msg, *msg_p.into_ref()); 362 | } 363 | } 364 | -------------------------------------------------------------------------------- /src/wireguard/handshake/mod.rs: -------------------------------------------------------------------------------- 1 | /* Implementation of the: 2 | * 3 | * Noise_IKpsk2_25519_ChaChaPoly_BLAKE2s 4 | * 5 | * Protocol pattern, see: http://www.noiseprotocol.org/noise.html. 6 | * For documentation. 
7 | */ 8 | 9 | mod device; 10 | mod macs; 11 | mod messages; 12 | mod noise; 13 | mod peer; 14 | mod ratelimiter; 15 | mod timestamp; 16 | mod types; 17 | 18 | #[cfg(test)] 19 | mod tests; 20 | 21 | // publicly exposed interface 22 | 23 | pub use device::Device; 24 | pub use messages::{MAX_HANDSHAKE_MSG_SIZE, TYPE_COOKIE_REPLY, TYPE_INITIATION, TYPE_RESPONSE}; 25 | -------------------------------------------------------------------------------- /src/wireguard/handshake/peer.rs: -------------------------------------------------------------------------------- 1 | use spin::Mutex; 2 | 3 | use std::mem; 4 | use std::time::{Duration, Instant}; 5 | 6 | use generic_array::typenum::U32; 7 | use generic_array::GenericArray; 8 | 9 | use x25519_dalek::PublicKey; 10 | use x25519_dalek::StaticSecret; 11 | 12 | use clear_on_drop::clear::Clear; 13 | 14 | use super::device::Device; 15 | use super::macs; 16 | use super::timestamp; 17 | use super::types::*; 18 | 19 | const TIME_BETWEEN_INITIATIONS: Duration = Duration::from_millis(20); 20 | 21 | // Represents the state of a peer. 22 | // 23 | // This type is only for internal use and not exposed. 24 | pub(super) struct Peer { 25 | // opaque type which identifies a peer 26 | pub opaque: O, 27 | 28 | // mutable state 29 | pub state: Mutex, 30 | pub timestamp: Mutex>, 31 | pub last_initiation_consumption: Mutex>, 32 | 33 | // state related to DoS mitigation fields 34 | pub macs: Mutex, 35 | 36 | // constant state 37 | pub ss: [u8; 32], // precomputed DH(static, static) 38 | pub psk: Psk, // psk of peer 39 | } 40 | 41 | pub enum State { 42 | Reset, 43 | InitiationSent { 44 | local: u32, // local id assigned 45 | eph_sk: StaticSecret, 46 | hs: GenericArray, 47 | ck: GenericArray, 48 | }, 49 | } 50 | 51 | impl Drop for State { 52 | fn drop(&mut self) { 53 | if let State::InitiationSent { hs, ck, .. 
} = self { 54 | // eph_sk already cleared by dalek-x25519 55 | hs.clear(); 56 | ck.clear(); 57 | } 58 | } 59 | } 60 | 61 | impl Peer { 62 | pub fn new(pk: PublicKey, ss: [u8; 32], opaque: O) -> Self { 63 | Self { 64 | opaque, 65 | macs: Mutex::new(macs::Generator::new(pk)), 66 | state: Mutex::new(State::Reset), 67 | timestamp: Mutex::new(None), 68 | last_initiation_consumption: Mutex::new(None), 69 | ss, 70 | psk: [0u8; 32], 71 | } 72 | } 73 | 74 | pub fn reset_state(&self) -> Option { 75 | match mem::replace(&mut *self.state.lock(), State::Reset) { 76 | State::InitiationSent { local, .. } => Some(local), 77 | _ => None, 78 | } 79 | } 80 | 81 | /// Set the mutable state of the peer conditioned on the timestamp being newer 82 | /// 83 | /// # Arguments 84 | /// 85 | /// * st_new - The updated state of the peer 86 | /// * ts_new - The associated timestamp 87 | pub fn check_replay_flood( 88 | &self, 89 | device: &Device, 90 | timestamp_new: ×tamp::TAI64N, 91 | ) -> Result<(), HandshakeError> { 92 | let mut state = self.state.lock(); 93 | let mut timestamp = self.timestamp.lock(); 94 | let mut last_initiation_consumption = self.last_initiation_consumption.lock(); 95 | 96 | // check replay attack 97 | if let Some(timestamp_old) = *timestamp { 98 | if !timestamp::compare(×tamp_old, ×tamp_new) { 99 | return Err(HandshakeError::OldTimestamp); 100 | } 101 | }; 102 | 103 | // check flood attack 104 | if let Some(last) = *last_initiation_consumption { 105 | if last.elapsed() < TIME_BETWEEN_INITIATIONS { 106 | return Err(HandshakeError::InitiationFlood); 107 | } 108 | } 109 | 110 | // reset state 111 | if let State::InitiationSent { local, .. 
} = *state { 112 | device.release(local) 113 | } 114 | 115 | // update replay & flood protection 116 | *state = State::Reset; 117 | *timestamp = Some(*timestamp_new); 118 | *last_initiation_consumption = Some(Instant::now()); 119 | Ok(()) 120 | } 121 | } 122 | -------------------------------------------------------------------------------- /src/wireguard/handshake/ratelimiter.rs: -------------------------------------------------------------------------------- 1 | use std::collections::HashMap; 2 | use std::net::IpAddr; 3 | use std::sync::atomic::{AtomicBool, Ordering}; 4 | use std::sync::{Arc, Condvar, Mutex}; 5 | use std::thread; 6 | use std::time::{Duration, Instant}; 7 | 8 | const PACKETS_PER_SECOND: u64 = 20; 9 | const PACKETS_BURSTABLE: u64 = 5; 10 | const PACKET_COST: u64 = 1_000_000_000 / PACKETS_PER_SECOND; 11 | const MAX_TOKENS: u64 = PACKET_COST * PACKETS_BURSTABLE; 12 | 13 | const GC_INTERVAL: Duration = Duration::from_secs(1); 14 | 15 | struct Entry { 16 | pub last_time: Instant, 17 | pub tokens: u64, 18 | } 19 | 20 | pub struct RateLimiter(Arc); 21 | 22 | struct RateLimiterInner { 23 | gc_running: AtomicBool, 24 | gc_dropped: (Mutex, Condvar), 25 | table: spin::RwLock>>, 26 | } 27 | 28 | impl Drop for RateLimiter { 29 | fn drop(&mut self) { 30 | // wake up & terminate any lingering GC thread 31 | let &(ref lock, ref cvar) = &self.0.gc_dropped; 32 | let mut dropped = lock.lock().unwrap(); 33 | *dropped = true; 34 | cvar.notify_all(); 35 | } 36 | } 37 | 38 | impl RateLimiter { 39 | pub fn new() -> Self { 40 | #[allow(clippy::mutex_atomic)] 41 | RateLimiter(Arc::new(RateLimiterInner { 42 | gc_dropped: (Mutex::new(false), Condvar::new()), 43 | gc_running: AtomicBool::from(false), 44 | table: spin::RwLock::new(HashMap::new()), 45 | })) 46 | } 47 | 48 | pub fn allow(&self, addr: &IpAddr) -> bool { 49 | // check if allowed 50 | let allowed = { 51 | // check for existing entry (only requires read lock) 52 | if let Some(entry) = self.0.table.read().get(addr) { 
53 | // update existing entry 54 | let mut entry = entry.lock(); 55 | 56 | // add tokens earned since last time 57 | entry.tokens = MAX_TOKENS 58 | .min(entry.tokens + u64::from(entry.last_time.elapsed().subsec_nanos())); 59 | entry.last_time = Instant::now(); 60 | 61 | // subtract cost of packet 62 | if entry.tokens > PACKET_COST { 63 | entry.tokens -= PACKET_COST; 64 | return true; 65 | } else { 66 | return false; 67 | } 68 | } 69 | 70 | // add new entry (write lock) 71 | self.0.table.write().insert( 72 | *addr, 73 | spin::Mutex::new(Entry { 74 | last_time: Instant::now(), 75 | tokens: MAX_TOKENS - PACKET_COST, 76 | }), 77 | ); 78 | true 79 | }; 80 | 81 | // check that GC thread is scheduled 82 | if !self.0.gc_running.swap(true, Ordering::Relaxed) { 83 | let limiter = self.0.clone(); 84 | thread::spawn(move || { 85 | let &(ref lock, ref cvar) = &limiter.gc_dropped; 86 | let mut dropped = lock.lock().unwrap(); 87 | while !*dropped { 88 | // garbage collect 89 | { 90 | let mut tw = limiter.table.write(); 91 | tw.retain(|_, ref mut entry| { 92 | entry.lock().last_time.elapsed() <= GC_INTERVAL 93 | }); 94 | if tw.len() == 0 { 95 | limiter.gc_running.store(false, Ordering::Relaxed); 96 | return; 97 | } 98 | } 99 | 100 | // wait until stopped or new GC (~1 every sec) 101 | let res = cvar.wait_timeout(dropped, GC_INTERVAL).unwrap(); 102 | dropped = res.0; 103 | } 104 | }); 105 | } 106 | 107 | allowed 108 | } 109 | } 110 | 111 | #[cfg(test)] 112 | mod tests { 113 | use super::*; 114 | use std; 115 | 116 | struct Result { 117 | allowed: bool, 118 | text: &'static str, 119 | wait: Duration, 120 | } 121 | 122 | #[test] 123 | fn test_ratelimiter() { 124 | let ratelimiter = RateLimiter::new(); 125 | let mut expected = vec![]; 126 | let ips = vec![ 127 | "127.0.0.1".parse().unwrap(), 128 | "192.168.1.1".parse().unwrap(), 129 | "172.167.2.3".parse().unwrap(), 130 | "97.231.252.215".parse().unwrap(), 131 | "248.97.91.167".parse().unwrap(), 132 | 
"188.208.233.47".parse().unwrap(), 133 | "104.2.183.179".parse().unwrap(), 134 | "72.129.46.120".parse().unwrap(), 135 | "2001:0db8:0a0b:12f0:0000:0000:0000:0001".parse().unwrap(), 136 | "f5c2:818f:c052:655a:9860:b136:6894:25f0".parse().unwrap(), 137 | "b2d7:15ab:48a7:b07c:a541:f144:a9fe:54fc".parse().unwrap(), 138 | "a47b:786e:1671:a22b:d6f9:4ab0:abc7:c918".parse().unwrap(), 139 | "ea1e:d155:7f7a:98fb:2bf5:9483:80f6:5445".parse().unwrap(), 140 | "3f0e:54a2:f5b4:cd19:a21d:58e1:3746:84c4".parse().unwrap(), 141 | ]; 142 | 143 | for _ in 0..PACKETS_BURSTABLE { 144 | expected.push(Result { 145 | allowed: true, 146 | wait: Duration::new(0, 0), 147 | text: "initial burst", 148 | }); 149 | } 150 | 151 | expected.push(Result { 152 | allowed: false, 153 | wait: Duration::new(0, 0), 154 | text: "after burst", 155 | }); 156 | 157 | expected.push(Result { 158 | allowed: true, 159 | wait: Duration::new(0, PACKET_COST as u32), 160 | text: "filling tokens for single packet", 161 | }); 162 | 163 | expected.push(Result { 164 | allowed: false, 165 | wait: Duration::new(0, 0), 166 | text: "not having refilled enough", 167 | }); 168 | 169 | expected.push(Result { 170 | allowed: true, 171 | wait: Duration::new(0, 2 * PACKET_COST as u32), 172 | text: "filling tokens for 2 * packet burst", 173 | }); 174 | 175 | expected.push(Result { 176 | allowed: true, 177 | wait: Duration::new(0, 0), 178 | text: "second packet in 2 packet burst", 179 | }); 180 | 181 | expected.push(Result { 182 | allowed: false, 183 | wait: Duration::new(0, 0), 184 | text: "packet following 2 packet burst", 185 | }); 186 | 187 | for item in expected { 188 | std::thread::sleep(item.wait); 189 | for ip in ips.iter() { 190 | if ratelimiter.allow(&ip) != item.allowed { 191 | panic!( 192 | "test failed for {} on {}. 
expected: {}, got: {}", 193 | ip, item.text, item.allowed, !item.allowed 194 | ) 195 | } 196 | } 197 | } 198 | } 199 | } 200 | -------------------------------------------------------------------------------- /src/wireguard/handshake/tests.rs: -------------------------------------------------------------------------------- 1 | use super::*; 2 | 3 | use std::net::SocketAddr; 4 | use std::thread; 5 | use std::time::Duration; 6 | 7 | use hex; 8 | 9 | use rand::rngs::OsRng; 10 | use rand_core::{CryptoRng, RngCore}; 11 | 12 | use x25519_dalek::PublicKey; 13 | use x25519_dalek::StaticSecret; 14 | 15 | use super::messages::{Initiation, Response}; 16 | 17 | fn setup_devices( 18 | rng1: &mut R, 19 | rng2: &mut R, 20 | rng3: &mut R, 21 | ) -> (PublicKey, Device, PublicKey, Device) { 22 | // generate new key pairs 23 | 24 | let sk1 = StaticSecret::new(rng1); 25 | let pk1 = PublicKey::from(&sk1); 26 | 27 | let sk2 = StaticSecret::new(rng2); 28 | let pk2 = PublicKey::from(&sk2); 29 | 30 | // pick random psk 31 | 32 | let mut psk = [0u8; 32]; 33 | rng3.fill_bytes(&mut psk[..]); 34 | 35 | // initialize devices on both ends 36 | 37 | let mut dev1 = Device::new(); 38 | let mut dev2 = Device::new(); 39 | 40 | dev1.set_sk(Some(sk1)); 41 | dev2.set_sk(Some(sk2)); 42 | 43 | dev1.add(pk2, O::default()).unwrap(); 44 | dev2.add(pk1, O::default()).unwrap(); 45 | 46 | dev1.set_psk(pk2, psk).unwrap(); 47 | dev2.set_psk(pk1, psk).unwrap(); 48 | 49 | (pk1, dev1, pk2, dev2) 50 | } 51 | 52 | fn wait() { 53 | thread::sleep(Duration::from_millis(20)); 54 | } 55 | 56 | /* Test longest possible handshake interaction (7 messages): 57 | * 58 | * 1. I -> R (initiation) 59 | * 2. I <- R (cookie reply) 60 | * 3. I -> R (initiation) 61 | * 4. I <- R (response) 62 | * 5. I -> R (cookie reply) 63 | * 6. I -> R (initiation) 64 | * 7. 
I <- R (response) 65 | */ 66 | #[test] 67 | fn handshake_under_load() { 68 | let (_pk1, dev1, pk2, dev2): (_, Device, _, _) = 69 | setup_devices(&mut OsRng, &mut OsRng, &mut OsRng); 70 | 71 | let src1: SocketAddr = "172.16.0.1:8080".parse().unwrap(); 72 | let src2: SocketAddr = "172.16.0.2:7070".parse().unwrap(); 73 | 74 | // 1. device-1 : create first initiation 75 | let msg_init = dev1.begin(&mut OsRng, &pk2).unwrap(); 76 | 77 | // 2. device-2 : responds with CookieReply 78 | let msg_cookie = match dev2.process(&mut OsRng, &msg_init, Some(src1)).unwrap() { 79 | (None, Some(msg), None) => msg, 80 | _ => panic!("unexpected response"), 81 | }; 82 | 83 | // device-1 : processes CookieReply (no response) 84 | match dev1.process(&mut OsRng, &msg_cookie, Some(src2)).unwrap() { 85 | (None, None, None) => (), 86 | _ => panic!("unexpected response"), 87 | } 88 | 89 | // avoid initiation flood detection 90 | wait(); 91 | 92 | // 3. device-1 : create second initiation 93 | let msg_init = dev1.begin(&mut OsRng, &pk2).unwrap(); 94 | 95 | // 4. device-2 : responds with noise response 96 | let msg_response = match dev2.process(&mut OsRng, &msg_init, Some(src1)).unwrap() { 97 | (Some(_), Some(msg), Some(kp)) => { 98 | assert_eq!(kp.initiator, false); 99 | msg 100 | } 101 | _ => panic!("unexpected response"), 102 | }; 103 | 104 | // 5. device-1 : responds with CookieReply 105 | let msg_cookie = match dev1.process(&mut OsRng, &msg_response, Some(src2)).unwrap() { 106 | (None, Some(msg), None) => msg, 107 | _ => panic!("unexpected response"), 108 | }; 109 | 110 | // device-2 : processes CookieReply (no response) 111 | match dev2.process(&mut OsRng, &msg_cookie, Some(src1)).unwrap() { 112 | (None, None, None) => (), 113 | _ => panic!("unexpected response"), 114 | } 115 | 116 | // avoid initiation flood detection 117 | wait(); 118 | 119 | // 6. device-1 : create third initiation 120 | let msg_init = dev1.begin(&mut OsRng, &pk2).unwrap(); 121 | 122 | // 7. 
device-2 : responds with noise response 123 | let (msg_response, kp1) = match dev2.process(&mut OsRng, &msg_init, Some(src1)).unwrap() { 124 | (Some(_), Some(msg), Some(kp)) => { 125 | assert_eq!(kp.initiator, false); 126 | (msg, kp) 127 | } 128 | _ => panic!("unexpected response"), 129 | }; 130 | 131 | // device-1 : process noise response 132 | let kp2 = match dev1.process(&mut OsRng, &msg_response, Some(src2)).unwrap() { 133 | (Some(_), None, Some(kp)) => { 134 | assert_eq!(kp.initiator, true); 135 | kp 136 | } 137 | _ => panic!("unexpected response"), 138 | }; 139 | 140 | assert_eq!(kp1.send, kp2.recv); 141 | assert_eq!(kp1.recv, kp2.send); 142 | } 143 | 144 | #[test] 145 | fn handshake_no_load() { 146 | let (pk1, mut dev1, pk2, mut dev2): (_, Device, _, _) = 147 | setup_devices(&mut OsRng, &mut OsRng, &mut OsRng); 148 | 149 | // do a few handshakes (every handshake should succeed) 150 | 151 | for i in 0..10 { 152 | println!("handshake : {}", i); 153 | 154 | // create initiation 155 | 156 | let msg1 = dev1.begin(&mut OsRng, &pk2).unwrap(); 157 | 158 | println!("msg1 = {} : {} bytes", hex::encode(&msg1[..]), msg1.len()); 159 | println!( 160 | "msg1 = {:?}", 161 | Initiation::parse(&msg1[..]).expect("failed to parse initiation") 162 | ); 163 | 164 | // process initiation and create response 165 | 166 | let (_, msg2, ks_r) = dev2 167 | .process(&mut OsRng, &msg1, None) 168 | .expect("failed to process initiation"); 169 | 170 | let ks_r = ks_r.unwrap(); 171 | let msg2 = msg2.unwrap(); 172 | 173 | println!("msg2 = {} : {} bytes", hex::encode(&msg2[..]), msg2.len()); 174 | println!( 175 | "msg2 = {:?}", 176 | Response::parse(&msg2[..]).expect("failed to parse response") 177 | ); 178 | 179 | assert!(!ks_r.initiator, "Responders key-pair is confirmed"); 180 | 181 | // process response and obtain confirmed key-pair 182 | 183 | let (_, msg3, ks_i) = dev1 184 | .process(&mut OsRng, &msg2, None) 185 | .expect("failed to process response"); 186 | let ks_i = ks_i.unwrap(); 
187 | 188 | assert!(msg3.is_none(), "Returned message after response"); 189 | assert!(ks_i.initiator, "Initiators key-pair is not confirmed"); 190 | 191 | assert_eq!(ks_i.send, ks_r.recv, "KeyI.send != KeyR.recv"); 192 | assert_eq!(ks_i.recv, ks_r.send, "KeyI.recv != KeyR.send"); 193 | 194 | dev1.release(ks_i.local_id()); 195 | dev2.release(ks_r.local_id()); 196 | 197 | // avoid initiation flood detection 198 | wait(); 199 | } 200 | 201 | dev1.remove(&pk2).unwrap(); 202 | dev2.remove(&pk1).unwrap(); 203 | } 204 | -------------------------------------------------------------------------------- /src/wireguard/handshake/timestamp.rs: -------------------------------------------------------------------------------- 1 | use std::time::{SystemTime, UNIX_EPOCH}; 2 | 3 | pub type TAI64N = [u8; 12]; 4 | 5 | const TAI64_EPOCH: u64 = 0x400000000000000a; 6 | 7 | pub const ZERO: TAI64N = [0u8; 12]; 8 | 9 | pub fn now() -> TAI64N { 10 | // get system time as duration 11 | let sysnow = SystemTime::now(); 12 | let delta = sysnow.duration_since(UNIX_EPOCH).unwrap(); 13 | 14 | // convert to tai64n 15 | let tai64_secs = delta.as_secs() + TAI64_EPOCH; 16 | let tai64_nano = delta.subsec_nanos(); 17 | 18 | // serialize 19 | let mut res = [0u8; 12]; 20 | res[..8].copy_from_slice(&tai64_secs.to_be_bytes()[..]); 21 | res[8..].copy_from_slice(&tai64_nano.to_be_bytes()[..]); 22 | res 23 | } 24 | 25 | pub fn compare(old: &TAI64N, new: &TAI64N) -> bool { 26 | for i in 0..12 { 27 | if new[i] > old[i] { 28 | return true; 29 | } 30 | } 31 | false 32 | } 33 | -------------------------------------------------------------------------------- /src/wireguard/handshake/types.rs: -------------------------------------------------------------------------------- 1 | use super::super::types::KeyPair; 2 | 3 | use std::error::Error; 4 | use std::fmt; 5 | 6 | /* Internal types for the noise IKpsk2 implementation */ 7 | 8 | // config error 9 | 10 | #[derive(Debug)] 11 | pub struct ConfigError(String); 12 | 13 | 
impl ConfigError { 14 | pub fn new(s: &str) -> Self { 15 | ConfigError(s.to_string()) 16 | } 17 | } 18 | 19 | impl fmt::Display for ConfigError { 20 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 21 | write!(f, "ConfigError({})", self.0) 22 | } 23 | } 24 | 25 | impl Error for ConfigError { 26 | fn description(&self) -> &str { 27 | &self.0 28 | } 29 | 30 | fn source(&self) -> Option<&(dyn Error + 'static)> { 31 | None 32 | } 33 | } 34 | 35 | // handshake error 36 | 37 | #[derive(Debug)] 38 | pub enum HandshakeError { 39 | DecryptionFailure, 40 | UnknownPublicKey, 41 | UnknownReceiverId, 42 | InvalidMessageFormat, 43 | InvalidSharedSecret, 44 | OldTimestamp, 45 | InvalidState, 46 | InvalidMac1, 47 | RateLimited, 48 | InitiationFlood, 49 | } 50 | 51 | impl fmt::Display for HandshakeError { 52 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 53 | match self { 54 | HandshakeError::InvalidSharedSecret => write!(f, "Zero shared secret"), 55 | HandshakeError::DecryptionFailure => write!(f, "Failed to AEAD:OPEN"), 56 | HandshakeError::UnknownPublicKey => write!(f, "Unknown public key"), 57 | HandshakeError::UnknownReceiverId => { 58 | write!(f, "Receiver id not allocated to any handshake") 59 | } 60 | HandshakeError::InvalidMessageFormat => write!(f, "Invalid handshake message format"), 61 | HandshakeError::OldTimestamp => write!(f, "Timestamp is less/equal to the newest"), 62 | HandshakeError::InvalidState => write!(f, "Message does not apply to handshake state"), 63 | HandshakeError::InvalidMac1 => write!(f, "Message has invalid mac1 field"), 64 | HandshakeError::RateLimited => write!(f, "Message was dropped by rate limiter"), 65 | HandshakeError::InitiationFlood => { 66 | write!(f, "Message was dropped because of initiation flood") 67 | } 68 | } 69 | } 70 | } 71 | 72 | impl Error for HandshakeError { 73 | fn description(&self) -> &str { 74 | "Generic Handshake Error" 75 | } 76 | 77 | fn source(&self) -> Option<&(dyn Error + 'static)> { 78 | None 
79 | } 80 | } 81 | 82 | pub type Output<'a, O> = ( 83 | Option<&'a O>, // external identifier associated with peer 84 | Option>, // message to send 85 | Option, // resulting key-pair of successful handshake 86 | ); 87 | 88 | // preshared key 89 | 90 | pub type Psk = [u8; 32]; 91 | -------------------------------------------------------------------------------- /src/wireguard/mod.rs: -------------------------------------------------------------------------------- 1 | /// The wireguard sub-module represents a full, pure, WireGuard implementation: 2 | /// 3 | /// The WireGuard device described here does not depend on particular IO implementations 4 | /// or UAPI, and can be instantiated in unit-tests with the dummy IO implementation. 5 | /// 6 | /// The code at this level serves to "glue" the handshake state-machine 7 | /// and the crypto-key router code together, 8 | /// e.g. every WireGuard peer consists of one handshake peer and one router peer. 9 | mod constants; 10 | mod handshake; 11 | mod peer; 12 | mod queue; 13 | mod router; 14 | mod timers; 15 | mod types; 16 | mod workers; 17 | 18 | #[cfg(test)] 19 | mod tests; 20 | 21 | #[allow(clippy::module_inception)] 22 | mod wireguard; 23 | 24 | // represents a WireGuard interface 25 | pub use wireguard::WireGuard; 26 | 27 | #[cfg(test)] 28 | use super::platform::dummy; 29 | 30 | use super::platform::{tun, udp, Endpoint}; 31 | use types::KeyPair; 32 | -------------------------------------------------------------------------------- /src/wireguard/peer.rs: -------------------------------------------------------------------------------- 1 | use super::timers::Timers; 2 | 3 | use super::tun::Tun; 4 | use super::udp::UDP; 5 | 6 | use super::constants::REKEY_TIMEOUT; 7 | use super::wireguard::WireGuard; 8 | use super::workers::HandshakeJob; 9 | 10 | use std::fmt; 11 | use std::sync::atomic::{AtomicBool, AtomicU64, Ordering}; 12 | use std::time::{Instant, SystemTime}; 13 | 14 | use spin::{Mutex, RwLock, RwLockReadGuard, 
RwLockWriteGuard}; 15 | 16 | use x25519_dalek::PublicKey; 17 | 18 | pub struct PeerInner { 19 | // internal id (for logging) 20 | pub id: u64, 21 | 22 | // wireguard device state 23 | pub wg: WireGuard, 24 | 25 | // TODO: eliminate 26 | pub pk: PublicKey, 27 | 28 | // handshake state 29 | pub walltime_last_handshake: Mutex>, /* walltime for last handshake (for UAPI status) */ 30 | pub last_handshake_sent: Mutex, // instant for last handshake 31 | pub handshake_queued: AtomicBool, // is a handshake job currently queued? 32 | 33 | // stats and configuration 34 | pub rx_bytes: AtomicU64, // received bytes 35 | pub tx_bytes: AtomicU64, // transmitted bytes 36 | 37 | // timer model 38 | pub timers: RwLock, 39 | } 40 | 41 | impl PeerInner { 42 | /* Queue a handshake request for the parallel workers 43 | * (if one does not already exist) 44 | * 45 | * The function is ratelimited. 46 | */ 47 | pub fn packet_send_handshake_initiation(&self) { 48 | log::trace!("{} : packet_send_handshake_initiation", self); 49 | 50 | // the function is rate limited 51 | { 52 | let mut lhs = self.last_handshake_sent.lock(); 53 | if lhs.elapsed() < REKEY_TIMEOUT { 54 | log::trace!("{} : packet_send_handshake_initiation, rate-limited!", self); 55 | return; 56 | } 57 | *lhs = Instant::now(); 58 | } 59 | 60 | // create a new handshake job for the peer 61 | if !self.handshake_queued.swap(true, Ordering::SeqCst) { 62 | self.wg.pending.fetch_add(1, Ordering::SeqCst); 63 | self.wg.queue.send(HandshakeJob::New(self.pk)); 64 | log::trace!( 65 | "{} : packet_send_handshake_initiation, handshake queued", 66 | self 67 | ); 68 | } else { 69 | log::trace!( 70 | "{} : packet_send_handshake_initiation, handshake already queued", 71 | self 72 | ); 73 | } 74 | } 75 | 76 | #[inline(always)] 77 | pub fn timers(&self) -> RwLockReadGuard { 78 | self.timers.read() 79 | } 80 | 81 | #[inline(always)] 82 | pub fn timers_mut(&self) -> RwLockWriteGuard { 83 | self.timers.write() 84 | } 85 | } 86 | 87 | impl fmt::Display 
for PeerInner { 88 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 89 | write!(f, "peer(id = {})", self.id) 90 | } 91 | } 92 | -------------------------------------------------------------------------------- /src/wireguard/queue.rs: -------------------------------------------------------------------------------- 1 | use crossbeam_channel::{bounded, Receiver, Sender}; 2 | use std::sync::Mutex; 3 | 4 | pub struct ParallelQueue { 5 | queue: Mutex>>, 6 | } 7 | 8 | impl ParallelQueue { 9 | /// Create a new ParallelQueue instance 10 | /// 11 | /// # Arguments 12 | /// 13 | /// - `queues`: number of readers 14 | /// - `capacity`: capacity of each internal queue 15 | pub fn new(queues: usize, capacity: usize) -> (Self, Vec>) { 16 | let mut receivers = Vec::with_capacity(queues); 17 | let (tx, rx) = bounded(capacity); 18 | for _ in 0..queues { 19 | receivers.push(rx.clone()); 20 | } 21 | ( 22 | ParallelQueue { 23 | queue: Mutex::new(Some(tx)), 24 | }, 25 | receivers, 26 | ) 27 | } 28 | 29 | pub fn send(&self, v: T) { 30 | if let Some(s) = self.queue.lock().unwrap().as_ref() { 31 | let _ = s.send(v); 32 | } 33 | } 34 | 35 | pub fn close(&self) { 36 | *self.queue.lock().unwrap() = None; 37 | } 38 | } 39 | -------------------------------------------------------------------------------- /src/wireguard/router/anti_replay.rs: -------------------------------------------------------------------------------- 1 | use core::mem; 2 | 3 | // Implementation of RFC 6479. 
4 | // https://tools.ietf.org/html/rfc6479 5 | 6 | #[cfg(target_pointer_width = "64")] 7 | type Word = u64; 8 | 9 | #[cfg(target_pointer_width = "64")] 10 | const REDUNDANT_BIT_SHIFTS: usize = 6; 11 | 12 | #[cfg(target_pointer_width = "32")] 13 | type Word = u32; 14 | 15 | #[cfg(target_pointer_width = "32")] 16 | const REDUNDANT_BIT_SHIFTS: usize = 5; 17 | 18 | const SIZE_OF_WORD: usize = mem::size_of::() * 8; 19 | 20 | const BITMAP_BITLEN: usize = 2048; 21 | const BITMAP_LEN: usize = BITMAP_BITLEN / SIZE_OF_WORD; 22 | const BITMAP_INDEX_MASK: u64 = BITMAP_LEN as u64 - 1; 23 | const BITMAP_LOC_MASK: u64 = (SIZE_OF_WORD - 1) as u64; 24 | const WINDOW_SIZE: u64 = (BITMAP_BITLEN - SIZE_OF_WORD) as u64; 25 | 26 | pub struct AntiReplay { 27 | bitmap: [Word; BITMAP_LEN], 28 | last: u64, 29 | } 30 | 31 | impl Default for AntiReplay { 32 | fn default() -> Self { 33 | AntiReplay::new() 34 | } 35 | } 36 | 37 | impl AntiReplay { 38 | pub fn new() -> Self { 39 | debug_assert_eq!(1 << REDUNDANT_BIT_SHIFTS, SIZE_OF_WORD); 40 | debug_assert_eq!(BITMAP_BITLEN % SIZE_OF_WORD, 0); 41 | AntiReplay { 42 | last: 0, 43 | bitmap: [0; BITMAP_LEN], 44 | } 45 | } 46 | 47 | // Returns true if check is passed, i.e., not a replay or too old. 48 | // 49 | // Unlike RFC 6479, zero is allowed. 50 | fn check(&self, seq: u64) -> bool { 51 | // Larger is always good. 52 | if seq > self.last { 53 | return true; 54 | } 55 | 56 | if self.last - seq > WINDOW_SIZE { 57 | return false; 58 | } 59 | 60 | let bit_location = seq & BITMAP_LOC_MASK; 61 | let index = (seq >> REDUNDANT_BIT_SHIFTS) & BITMAP_INDEX_MASK; 62 | 63 | self.bitmap[index as usize] & (1 << bit_location) == 0 64 | } 65 | 66 | // Should only be called if check returns true. 
67 | fn update_store(&mut self, seq: u64) { 68 | debug_assert!(self.check(seq)); 69 | 70 | let index = seq >> REDUNDANT_BIT_SHIFTS; 71 | 72 | if seq > self.last { 73 | let index_cur = self.last >> REDUNDANT_BIT_SHIFTS; 74 | let diff = index - index_cur; 75 | 76 | if diff >= BITMAP_LEN as u64 { 77 | self.bitmap = [0; BITMAP_LEN]; 78 | } else { 79 | for i in 0..diff { 80 | let real_index = (index_cur + i + 1) & BITMAP_INDEX_MASK; 81 | self.bitmap[real_index as usize] = 0; 82 | } 83 | } 84 | 85 | self.last = seq; 86 | } 87 | 88 | let index = index & BITMAP_INDEX_MASK; 89 | let bit_location = seq & BITMAP_LOC_MASK; 90 | self.bitmap[index as usize] |= 1 << bit_location; 91 | } 92 | 93 | /// Checks and marks a sequence number in the replay filter 94 | /// 95 | /// # Arguments 96 | /// 97 | /// - seq: Sequence number check for replay and add to filter 98 | /// 99 | /// # Returns 100 | /// 101 | /// Ok(()) if sequence number is valid (not marked and not behind the moving window). 102 | /// Err if the sequence number is invalid (already marked or "too old"). 103 | pub fn update(&mut self, seq: u64) -> bool { 104 | if self.check(seq) { 105 | self.update_store(seq); 106 | true 107 | } else { 108 | false 109 | } 110 | } 111 | } 112 | 113 | #[cfg(test)] 114 | mod tests { 115 | use super::*; 116 | 117 | #[test] 118 | fn anti_replay() { 119 | let mut ar = AntiReplay::new(); 120 | 121 | for i in 0..20000 { 122 | assert!(ar.update(i)); 123 | } 124 | 125 | for i in (0..20000).rev() { 126 | assert!(!ar.check(i)); 127 | } 128 | 129 | assert!(ar.update(65536)); 130 | for i in (65536 - WINDOW_SIZE)..65535 { 131 | assert!(ar.update(i)); 132 | } 133 | 134 | for i in (65536 - 10 * WINDOW_SIZE)..65535 { 135 | assert!(!ar.check(i)); 136 | } 137 | 138 | assert!(ar.update(66000)); 139 | for i in 65537..66000 { 140 | assert!(ar.update(i)); 141 | } 142 | for i in 65537..66000 { 143 | assert_eq!(ar.update(i), false); 144 | } 145 | 146 | // Test max u64. 
147 | let next = u64::max_value(); 148 | assert!(ar.update(next)); 149 | assert!(!ar.check(next)); 150 | for i in (next - WINDOW_SIZE)..next { 151 | assert!(ar.update(i)); 152 | } 153 | for i in (next - 20 * WINDOW_SIZE)..next { 154 | assert!(!ar.check(i)); 155 | } 156 | } 157 | } 158 | -------------------------------------------------------------------------------- /src/wireguard/router/constants.rs: -------------------------------------------------------------------------------- 1 | // WireGuard semantics constants 2 | 3 | pub const MAX_QUEUED_PACKETS: usize = 1024; 4 | 5 | // performance constants 6 | 7 | pub const PARALLEL_QUEUE_SIZE: usize = 4 * MAX_QUEUED_PACKETS; 8 | 9 | pub const INORDER_QUEUE_SIZE: usize = MAX_QUEUED_PACKETS; 10 | -------------------------------------------------------------------------------- /src/wireguard/router/device.rs: -------------------------------------------------------------------------------- 1 | use std::collections::HashMap; 2 | use std::ops::Deref; 3 | use std::sync::atomic::AtomicBool; 4 | use std::sync::Arc; 5 | use std::thread; 6 | 7 | use spin::{Mutex, RwLock}; 8 | use zerocopy::LayoutVerified; 9 | 10 | use super::anti_replay::AntiReplay; 11 | 12 | use super::constants::PARALLEL_QUEUE_SIZE; 13 | use super::messages::{TransportHeader, TYPE_TRANSPORT}; 14 | use super::peer::{new_peer, Peer, PeerHandle}; 15 | use super::types::{Callbacks, RouterError}; 16 | use super::SIZE_MESSAGE_PREFIX; 17 | 18 | use super::receive::ReceiveJob; 19 | use super::route::RoutingTable; 20 | use super::worker::{worker, JobUnion}; 21 | 22 | use super::super::{tun, udp, Endpoint, KeyPair}; 23 | use super::ParallelQueue; 24 | 25 | pub struct DeviceInner> { 26 | // inbound writer (TUN) 27 | pub(super) inbound: T, 28 | 29 | // outbound writer (Bind) 30 | pub(super) outbound: RwLock<(bool, Option)>, 31 | 32 | // routing 33 | #[allow(clippy::type_complexity)] 34 | pub(super) recv: RwLock>>>, /* receiver id -> decryption state */ 35 | pub(super) 
table: RoutingTable>, 36 | 37 | // work queue 38 | pub(super) work: ParallelQueue>, 39 | } 40 | 41 | pub struct EncryptionState { 42 | pub(super) keypair: Arc, // keypair 43 | pub(super) nonce: u64, // next available nonce 44 | } 45 | 46 | pub struct DecryptionState> { 47 | pub(super) keypair: Arc, 48 | pub(super) confirmed: AtomicBool, 49 | pub(super) protector: Mutex, 50 | pub(super) peer: Peer, 51 | } 52 | 53 | pub struct Device> { 54 | inner: Arc>, 55 | } 56 | 57 | impl> Clone for Device { 58 | fn clone(&self) -> Self { 59 | Device { 60 | inner: self.inner.clone(), 61 | } 62 | } 63 | } 64 | 65 | impl> PartialEq 66 | for Device 67 | { 68 | fn eq(&self, other: &Self) -> bool { 69 | Arc::ptr_eq(&self.inner, &other.inner) 70 | } 71 | } 72 | 73 | impl> Eq for Device {} 74 | 75 | impl> Deref for Device { 76 | type Target = DeviceInner; 77 | fn deref(&self) -> &Self::Target { 78 | &self.inner 79 | } 80 | } 81 | 82 | pub struct DeviceHandle> { 83 | state: Device, // reference to device state 84 | handles: Vec>, // join handles for workers 85 | } 86 | 87 | impl> Drop 88 | for DeviceHandle 89 | { 90 | fn drop(&mut self) { 91 | log::debug!("router: dropping device"); 92 | 93 | // close worker queue 94 | self.state.work.close(); 95 | 96 | // join all worker threads 97 | while let Some(handle) = self.handles.pop() { 98 | handle.thread().unpark(); 99 | handle.join().unwrap(); 100 | } 101 | log::debug!("router: joined with all workers from pool"); 102 | } 103 | } 104 | 105 | impl> DeviceHandle { 106 | pub fn new(num_workers: usize, tun: T) -> DeviceHandle { 107 | let (work, mut consumers) = ParallelQueue::new(num_workers, PARALLEL_QUEUE_SIZE); 108 | let device = Device { 109 | inner: Arc::new(DeviceInner { 110 | work, 111 | inbound: tun, 112 | outbound: RwLock::new((true, None)), 113 | recv: RwLock::new(HashMap::new()), 114 | table: RoutingTable::new(), 115 | }), 116 | }; 117 | 118 | // start worker threads 119 | let mut threads = Vec::with_capacity(num_workers); 120 | while 
let Some(rx) = consumers.pop() { 121 | threads.push(thread::spawn(move || worker(rx))); 122 | } 123 | debug_assert!(num_workers > 0, "zero worker threads"); 124 | debug_assert_eq!( 125 | threads.len(), 126 | num_workers, 127 | "workers does not match consumers" 128 | ); 129 | 130 | // return exported device handle 131 | DeviceHandle { 132 | state: device, 133 | handles: threads, 134 | } 135 | } 136 | 137 | pub fn send_raw(&self, msg: &[u8], dst: &mut E) -> Result<(), B::Error> { 138 | let bind = self.state.outbound.read(); 139 | if bind.0 { 140 | if let Some(bind) = bind.1.as_ref() { 141 | return bind.write(msg, dst); 142 | } 143 | } 144 | Ok(()) 145 | } 146 | 147 | /// Brings the router down. 148 | /// When the router is brought down it: 149 | /// - Prevents transmission of outbound messages. 150 | pub fn down(&self) { 151 | self.state.outbound.write().0 = false; 152 | } 153 | 154 | /// Brints the router up 155 | /// When the router is brought up it enables the transmission of outbound messages. 156 | pub fn up(&self) { 157 | self.state.outbound.write().0 = true; 158 | } 159 | 160 | /// A new secret key has been set for the device. 161 | /// According to WireGuard semantics, this should cause all "sending" keys to be discarded. 162 | pub fn clear_sending_keys(&self) { 163 | log::debug!("Clear sending keys"); 164 | // TODO: Implement. Consider: The device does not have an explicit list of peers 165 | } 166 | 167 | /// Adds a new peer to the device 168 | /// 169 | /// # Returns 170 | /// 171 | /// A atomic ref. 
counted peer (with liftime matching the device) 172 | pub fn new_peer(&self, opaque: C::Opaque) -> PeerHandle { 173 | new_peer(self.state.clone(), opaque) 174 | } 175 | 176 | /// Cryptkey routes and sends a plaintext message (IP packet) 177 | /// 178 | /// # Arguments 179 | /// 180 | /// - msg: IP packet to crypt-key route 181 | pub fn send(&self, msg: Vec) -> Result<(), RouterError> { 182 | debug_assert!(msg.len() > SIZE_MESSAGE_PREFIX); 183 | log::trace!( 184 | "send, packet = {}", 185 | hex::encode(&msg[SIZE_MESSAGE_PREFIX..]) 186 | ); 187 | 188 | // ignore header prefix (for in-place transport message construction) 189 | let packet = &msg[SIZE_MESSAGE_PREFIX..]; 190 | 191 | // lookup peer based on IP packet destination address 192 | let peer = self 193 | .state 194 | .table 195 | .get_route(packet) 196 | .ok_or(RouterError::NoCryptoKeyRoute)?; 197 | 198 | // schedule for encryption and transmission to peer 199 | peer.send(msg, true); 200 | Ok(()) 201 | } 202 | 203 | /// Receive an encrypted transport message 204 | /// 205 | /// # Arguments 206 | /// 207 | /// - src: Source address of the packet 208 | /// - msg: Encrypted transport message 209 | /// 210 | /// # Returns 211 | pub fn recv(&self, src: E, msg: Vec) -> Result<(), RouterError> { 212 | log::trace!("receive, src: {}", src.into_address()); 213 | 214 | // parse / cast 215 | let (header, _) = match LayoutVerified::new_from_prefix(&msg[..]) { 216 | Some(v) => v, 217 | None => { 218 | return Err(RouterError::MalformedTransportMessage); 219 | } 220 | }; 221 | 222 | let header: LayoutVerified<&[u8], TransportHeader> = header; 223 | 224 | debug_assert!( 225 | header.f_type.get() == TYPE_TRANSPORT as u32, 226 | "this should be checked by the message type multiplexer" 227 | ); 228 | 229 | log::trace!( 230 | "handle transport message: (receiver = {}, counter = {})", 231 | header.f_receiver, 232 | header.f_counter 233 | ); 234 | 235 | // lookup peer based on receiver id 236 | let dec = self.state.recv.read(); 237 | 
let dec = dec 238 | .get(&header.f_receiver.get()) 239 | .ok_or(RouterError::UnknownReceiverId)?; 240 | 241 | // create inbound job 242 | let job = ReceiveJob::new(msg, dec.clone(), src); 243 | 244 | // 1. add to sequential queue (drop if full) 245 | // 2. then add to parallel work queue (wait if full) 246 | if dec.peer.inbound.push(job.clone()) { 247 | self.state.work.send(JobUnion::Inbound(job)); 248 | } 249 | Ok(()) 250 | } 251 | 252 | /// Set outbound writer 253 | pub fn set_outbound_writer(&self, new: B) { 254 | self.state.outbound.write().1 = Some(new); 255 | } 256 | } 257 | -------------------------------------------------------------------------------- /src/wireguard/router/ip.rs: -------------------------------------------------------------------------------- 1 | use core::mem; 2 | 3 | use byteorder::BigEndian; 4 | use zerocopy::byteorder::U16; 5 | use zerocopy::LayoutVerified; 6 | use zerocopy::{AsBytes, FromBytes}; 7 | 8 | pub const VERSION_IP4: u8 = 4; 9 | pub const VERSION_IP6: u8 = 6; 10 | 11 | #[repr(packed)] 12 | #[derive(Copy, Clone, FromBytes, AsBytes)] 13 | pub struct IPv4Header { 14 | _f_space1: [u8; 2], 15 | pub f_total_len: U16, 16 | _f_space2: [u8; 8], 17 | pub f_source: [u8; 4], 18 | pub f_destination: [u8; 4], 19 | } 20 | 21 | #[repr(packed)] 22 | #[derive(Copy, Clone, FromBytes, AsBytes)] 23 | pub struct IPv6Header { 24 | _f_space1: [u8; 4], 25 | pub f_len: U16, 26 | _f_space2: [u8; 2], 27 | pub f_source: [u8; 16], 28 | pub f_destination: [u8; 16], 29 | } 30 | 31 | #[inline(always)] 32 | pub fn inner_length(packet: &[u8]) -> Option { 33 | match packet.get(0)? 
>> 4 { 34 | VERSION_IP4 => { 35 | let (header, _): (LayoutVerified<&[u8], IPv4Header>, _) = 36 | LayoutVerified::new_from_prefix(packet)?; 37 | 38 | Some(header.f_total_len.get() as usize) 39 | } 40 | VERSION_IP6 => { 41 | // check length and cast to IPv6 header 42 | let (header, _): (LayoutVerified<&[u8], IPv6Header>, _) = 43 | LayoutVerified::new_from_prefix(packet)?; 44 | 45 | Some(header.f_len.get() as usize + mem::size_of::()) 46 | } 47 | _ => None, 48 | } 49 | } 50 | -------------------------------------------------------------------------------- /src/wireguard/router/messages.rs: -------------------------------------------------------------------------------- 1 | use byteorder::LittleEndian; 2 | use zerocopy::byteorder::{U32, U64}; 3 | use zerocopy::{AsBytes, FromBytes}; 4 | 5 | pub const TYPE_TRANSPORT: u32 = 4; 6 | 7 | #[repr(packed)] 8 | #[derive(Copy, Clone, FromBytes, AsBytes)] 9 | pub struct TransportHeader { 10 | pub f_type: U32, 11 | pub f_receiver: U32, 12 | pub f_counter: U64, 13 | } 14 | -------------------------------------------------------------------------------- /src/wireguard/router/mod.rs: -------------------------------------------------------------------------------- 1 | mod anti_replay; 2 | mod constants; 3 | mod device; 4 | mod ip; 5 | mod messages; 6 | mod peer; 7 | mod route; 8 | mod types; 9 | 10 | mod queue; 11 | mod receive; 12 | mod send; 13 | mod worker; 14 | 15 | #[cfg(test)] 16 | mod tests; 17 | 18 | use messages::TransportHeader; 19 | 20 | use super::constants::REJECT_AFTER_MESSAGES; 21 | use super::queue::ParallelQueue; 22 | use super::types::*; 23 | 24 | use core::mem; 25 | 26 | pub const SIZE_TAG: usize = 16; 27 | pub const SIZE_MESSAGE_PREFIX: usize = mem::size_of::(); 28 | pub const CAPACITY_MESSAGE_POSTFIX: usize = SIZE_TAG; 29 | 30 | pub const fn message_data_len(payload: usize) -> usize { 31 | payload + mem::size_of::() + SIZE_TAG 32 | } 33 | 34 | pub use device::DeviceHandle as Device; 35 | pub use 
messages::TYPE_TRANSPORT; 36 | pub use peer::PeerHandle; 37 | pub use types::Callbacks; 38 | -------------------------------------------------------------------------------- /src/wireguard/router/queue.rs: -------------------------------------------------------------------------------- 1 | use arraydeque::ArrayDeque; 2 | use spin::Mutex; 3 | 4 | use core::mem; 5 | use core::sync::atomic::{AtomicUsize, Ordering}; 6 | 7 | use super::constants::INORDER_QUEUE_SIZE; 8 | 9 | pub trait SequentialJob { 10 | fn is_ready(&self) -> bool; 11 | 12 | fn sequential_work(self); 13 | } 14 | 15 | pub trait ParallelJob: Sized + SequentialJob { 16 | fn queue(&self) -> &Queue; 17 | 18 | fn parallel_work(&self); 19 | } 20 | 21 | pub struct Queue { 22 | contenders: AtomicUsize, 23 | queue: Mutex>, 24 | 25 | #[cfg(debug)] 26 | _flag: Mutex<()>, 27 | } 28 | 29 | impl Queue { 30 | pub fn new() -> Queue { 31 | Queue { 32 | contenders: AtomicUsize::new(0), 33 | queue: Mutex::new(ArrayDeque::new()), 34 | 35 | #[cfg(debug)] 36 | _flag: Mutex::new(()), 37 | } 38 | } 39 | 40 | pub fn push(&self, job: J) -> bool { 41 | self.queue.lock().push_back(job).is_ok() 42 | } 43 | 44 | pub fn consume(&self) { 45 | // check if we are the first contender 46 | let pos = self.contenders.fetch_add(1, Ordering::SeqCst); 47 | if pos > 0 { 48 | assert!(usize::max_value() > pos, "contenders overflow"); 49 | return; 50 | } 51 | 52 | // enter the critical section 53 | let mut contenders = 1; // myself 54 | while contenders > 0 { 55 | // check soundness in debug builds 56 | #[cfg(debug)] 57 | let _flag = self 58 | ._flag 59 | .try_lock() 60 | .expect("contenders should ensure mutual exclusion"); 61 | 62 | // handle every ready element 63 | loop { 64 | let mut queue = self.queue.lock(); 65 | 66 | // check if front job is ready 67 | match queue.front() { 68 | None => break, 69 | Some(job) => { 70 | if !job.is_ready() { 71 | break; 72 | } 73 | } 74 | }; 75 | 76 | // take the job out of the queue 77 | let job = 
queue.pop_front().unwrap(); 78 | debug_assert!(job.is_ready()); 79 | mem::drop(queue); 80 | 81 | // process element 82 | job.sequential_work(); 83 | } 84 | 85 | #[cfg(debug)] 86 | mem::drop(_flag); 87 | 88 | // decrease contenders 89 | contenders = self.contenders.fetch_sub(contenders, Ordering::SeqCst) - contenders; 90 | } 91 | } 92 | } 93 | 94 | #[cfg(test)] 95 | mod tests { 96 | use super::*; 97 | 98 | use std::thread; 99 | 100 | use std::sync::Arc; 101 | use std::time::Duration; 102 | 103 | use rand::thread_rng; 104 | use rand::Rng; 105 | 106 | #[test] 107 | fn test_consume_queue() { 108 | struct TestJob { 109 | cnt: Arc, 110 | wait_sequential: Duration, 111 | } 112 | 113 | impl SequentialJob for TestJob { 114 | fn is_ready(&self) -> bool { 115 | true 116 | } 117 | 118 | fn sequential_work(self) { 119 | thread::sleep(self.wait_sequential); 120 | self.cnt.fetch_add(1, Ordering::SeqCst); 121 | } 122 | } 123 | 124 | fn hammer(queue: &Arc>, cnt: Arc) -> usize { 125 | let mut jobs = 0; 126 | let mut rng = thread_rng(); 127 | for _ in 0..10_000 { 128 | if rng.gen() { 129 | let wait_sequential: u64 = rng.gen(); 130 | let wait_sequential = wait_sequential % 1000; 131 | 132 | let wait_parallel: u64 = rng.gen(); 133 | let wait_parallel = wait_parallel % 1000; 134 | 135 | thread::sleep(Duration::from_micros(wait_parallel)); 136 | 137 | queue.push(TestJob { 138 | cnt: cnt.clone(), 139 | wait_sequential: Duration::from_micros(wait_sequential), 140 | }); 141 | jobs += 1; 142 | } else { 143 | queue.consume(); 144 | } 145 | } 146 | queue.consume(); 147 | jobs 148 | } 149 | 150 | let queue = Arc::new(Queue::new()); 151 | let counter = Arc::new(AtomicUsize::new(0)); 152 | 153 | // repeatedly apply operations randomly from concurrent threads 154 | let other = { 155 | let queue = queue.clone(); 156 | let counter = counter.clone(); 157 | thread::spawn(move || hammer(&queue, counter)) 158 | }; 159 | let mut jobs = hammer(&queue, counter.clone()); 160 | 161 | // wait, consume and 
check empty 162 | jobs += other.join().unwrap(); 163 | assert_eq!(queue.queue.lock().len(), 0, "elements left in queue"); 164 | assert_eq!( 165 | jobs, 166 | counter.load(Ordering::Acquire), 167 | "did not consume every job" 168 | ); 169 | } 170 | 171 | /* Fuzz the Queue */ 172 | #[test] 173 | fn test_fuzz_queue() { 174 | struct TestJob {} 175 | 176 | impl SequentialJob for TestJob { 177 | fn is_ready(&self) -> bool { 178 | true 179 | } 180 | 181 | fn sequential_work(self) {} 182 | } 183 | 184 | fn hammer(queue: &Arc>) { 185 | let mut rng = thread_rng(); 186 | for _ in 0..1_000_000 { 187 | if rng.gen() { 188 | queue.push(TestJob {}); 189 | } else { 190 | queue.consume(); 191 | } 192 | } 193 | } 194 | 195 | let queue = Arc::new(Queue::new()); 196 | 197 | // repeatedly apply operations randomly from concurrent threads 198 | let other = { 199 | let queue = queue.clone(); 200 | thread::spawn(move || hammer(&queue)) 201 | }; 202 | hammer(&queue); 203 | 204 | // wait, consume and check empty 205 | other.join().unwrap(); 206 | queue.consume(); 207 | assert_eq!(queue.queue.lock().len(), 0); 208 | } 209 | } 210 | -------------------------------------------------------------------------------- /src/wireguard/router/receive.rs: -------------------------------------------------------------------------------- 1 | use super::device::DecryptionState; 2 | use super::ip::inner_length; 3 | use super::messages::TransportHeader; 4 | use super::queue::{ParallelJob, Queue, SequentialJob}; 5 | use super::types::Callbacks; 6 | use super::{REJECT_AFTER_MESSAGES, SIZE_TAG}; 7 | 8 | use super::super::{tun, udp, Endpoint}; 9 | 10 | use alloc::sync::Arc; 11 | use core::sync::atomic::{AtomicBool, Ordering}; 12 | use ring::aead::{Aad, LessSafeKey, Nonce, UnboundKey, CHACHA20_POLY1305}; 13 | use spin::Mutex; 14 | use zerocopy::{AsBytes, LayoutVerified}; 15 | 16 | struct Inner> { 17 | ready: AtomicBool, // job status 18 | buffer: Mutex<(Option, Vec)>, // endpoint & ciphertext buffer 19 | state: 
Arc>, // decryption state (keys and replay protector) 20 | } 21 | 22 | pub struct ReceiveJob>( 23 | Arc>, 24 | ); 25 | 26 | impl> Clone 27 | for ReceiveJob 28 | { 29 | fn clone(&self) -> ReceiveJob { 30 | ReceiveJob(self.0.clone()) 31 | } 32 | } 33 | 34 | impl> ReceiveJob { 35 | pub fn new( 36 | buffer: Vec, 37 | state: Arc>, 38 | endpoint: E, 39 | ) -> ReceiveJob { 40 | ReceiveJob(Arc::new(Inner { 41 | ready: AtomicBool::new(false), 42 | buffer: Mutex::new((Some(endpoint), buffer)), 43 | state, 44 | })) 45 | } 46 | } 47 | 48 | impl> ParallelJob 49 | for ReceiveJob 50 | { 51 | fn queue(&self) -> &Queue { 52 | &self.0.state.peer.inbound 53 | } 54 | 55 | /* The parallel section of an incoming job: 56 | * 57 | * - Decryption. 58 | * - Crypto-key routing lookup. 59 | * 60 | * Note: We truncate the message buffer to 0 bytes in case of authentication failure 61 | * or crypto-key routing failure (attempted impersonation). 62 | * 63 | * Note: We cannot do replay protection in the parallel job, 64 | * since this can cause dropping of packets (leaving the window) due to scheduling. 
65 | */ 66 | fn parallel_work(&self) { 67 | debug_assert_eq!( 68 | self.is_ready(), 69 | false, 70 | "doing parallel work on completed job" 71 | ); 72 | log::trace!("processing parallel receive job"); 73 | 74 | // decrypt 75 | { 76 | // closure for locking 77 | let job = &self.0; 78 | let peer = &job.state.peer; 79 | let mut msg = job.buffer.lock(); 80 | 81 | // process buffer 82 | let ok = (|| { 83 | // cast to header followed by payload 84 | let (header, packet): (LayoutVerified<&mut [u8], TransportHeader>, &mut [u8]) = 85 | match LayoutVerified::new_from_prefix(&mut msg.1[..]) { 86 | Some(v) => v, 87 | None => return false, 88 | }; 89 | 90 | // create nonce object 91 | let mut nonce = [0u8; 12]; 92 | debug_assert_eq!(nonce.len(), CHACHA20_POLY1305.nonce_len()); 93 | nonce[4..].copy_from_slice(header.f_counter.as_bytes()); 94 | let nonce = Nonce::assume_unique_for_key(nonce); 95 | // do the weird ring AEAD dance 96 | let key = LessSafeKey::new( 97 | UnboundKey::new(&CHACHA20_POLY1305, &job.state.keypair.recv.key[..]).unwrap(), 98 | ); 99 | 100 | // attempt to open (and authenticate) the body 101 | match key.open_in_place(nonce, Aad::empty(), packet) { 102 | Ok(_) => (), 103 | Err(_) => return false, 104 | } 105 | 106 | // check that counter not after reject 107 | if header.f_counter.get() >= REJECT_AFTER_MESSAGES { 108 | return false; 109 | } 110 | 111 | // check crypto-key router 112 | packet.len() == SIZE_TAG || peer.device.table.check_route(&peer, &packet) 113 | })(); 114 | 115 | // remove message in case of failure: 116 | // to indicate failure and avoid later accidental use of unauthenticated data. 
117 | if !ok { 118 | msg.1.truncate(0); 119 | } 120 | }; 121 | 122 | // mark ready 123 | self.0.ready.store(true, Ordering::Release); 124 | } 125 | } 126 | 127 | impl> SequentialJob 128 | for ReceiveJob 129 | { 130 | fn is_ready(&self) -> bool { 131 | self.0.ready.load(Ordering::Acquire) 132 | } 133 | 134 | fn sequential_work(self) { 135 | debug_assert_eq!( 136 | self.is_ready(), 137 | true, 138 | "doing sequential work on an incomplete job" 139 | ); 140 | log::trace!("processing sequential receive job"); 141 | 142 | let job = &self.0; 143 | let peer = &job.state.peer; 144 | let mut msg = job.buffer.lock(); 145 | let endpoint = msg.0.take(); 146 | 147 | // cast transport header 148 | let (header, packet): (LayoutVerified<&[u8], TransportHeader>, &[u8]) = 149 | match LayoutVerified::new_from_prefix(&msg.1[..]) { 150 | Some(v) => v, 151 | None => { 152 | // also covers authentication failure (will fail to parse header) 153 | return; 154 | } 155 | }; 156 | 157 | // check for replay 158 | if !job.state.protector.lock().update(header.f_counter.get()) { 159 | log::debug!("inbound worker: replay detected"); 160 | return; 161 | } 162 | 163 | // check for confirms key 164 | if !job.state.confirmed.swap(true, Ordering::SeqCst) { 165 | log::debug!("inbound worker: message confirms key"); 166 | peer.confirm_key(&job.state.keypair); 167 | } 168 | 169 | // update endpoint 170 | *peer.endpoint.lock() = endpoint; 171 | 172 | // check if should be written to TUN 173 | // (keep-alive and malformed packets will have no inner length) 174 | if let Some(inner) = inner_length(packet) { 175 | if inner + SIZE_TAG <= packet.len() { 176 | let _ = peer.device.inbound.write(&packet[..inner]).map_err(|e| { 177 | log::debug!("failed to write inbound packet to TUN: {:?}", e); 178 | }); 179 | } 180 | } 181 | 182 | // trigger callback 183 | C::recv(&peer.opaque, msg.1.len(), true, &job.state.keypair); 184 | } 185 | } 186 | 
-------------------------------------------------------------------------------- /src/wireguard/router/route.rs: -------------------------------------------------------------------------------- 1 | use super::ip::*; 2 | 3 | // TODO: no_std alternatives 4 | use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; 5 | 6 | use spin::RwLock; 7 | use treebitmap::address::Address; 8 | use treebitmap::IpLookupTable; 9 | use zerocopy::LayoutVerified; 10 | 11 | /* Functions for obtaining and validating "cryptokey" routes */ 12 | 13 | pub struct RoutingTable { 14 | ipv4: RwLock>, 15 | ipv6: RwLock>, 16 | } 17 | 18 | impl RoutingTable { 19 | pub fn new() -> Self { 20 | RoutingTable { 21 | ipv4: RwLock::new(IpLookupTable::new()), 22 | ipv6: RwLock::new(IpLookupTable::new()), 23 | } 24 | } 25 | 26 | // collect keys mapping to the given value 27 | fn collect(table: &IpLookupTable, value: &T) -> Vec<(A, u32)> 28 | where 29 | A: Address, 30 | { 31 | let mut res = Vec::new(); 32 | for (ip, cidr, v) in table.iter() { 33 | if v == value { 34 | res.push((ip, cidr)) 35 | } 36 | } 37 | res 38 | } 39 | 40 | pub fn insert(&self, ip: IpAddr, cidr: u32, value: T) { 41 | match ip { 42 | IpAddr::V4(v4) => self.ipv4.write().insert(v4.mask(cidr), cidr, value), 43 | IpAddr::V6(v6) => self.ipv6.write().insert(v6.mask(cidr), cidr, value), 44 | }; 45 | } 46 | 47 | pub fn list(&self, value: &T) -> Vec<(IpAddr, u32)> { 48 | let mut res = vec![]; 49 | res.extend( 50 | Self::collect(&*self.ipv4.read(), value) 51 | .into_iter() 52 | .map(|(ip, cidr)| (IpAddr::V4(ip), cidr)), 53 | ); 54 | res.extend( 55 | Self::collect(&*self.ipv6.read(), value) 56 | .into_iter() 57 | .map(|(ip, cidr)| (IpAddr::V6(ip), cidr)), 58 | ); 59 | res 60 | } 61 | 62 | pub fn remove(&self, value: &T) { 63 | let mut v4 = self.ipv4.write(); 64 | for (ip, cidr) in Self::collect(&*v4, value) { 65 | v4.remove(ip, cidr); 66 | } 67 | 68 | let mut v6 = self.ipv6.write(); 69 | for (ip, cidr) in Self::collect(&*v6, value) { 70 | v6.remove(ip, cidr); 71 
| } 72 | } 73 | 74 | #[inline(always)] 75 | pub fn get_route(&self, packet: &[u8]) -> Option { 76 | match packet.get(0)? >> 4 { 77 | VERSION_IP4 => { 78 | // check length and cast to IPv4 header 79 | let (header, _): (LayoutVerified<&[u8], IPv4Header>, _) = 80 | LayoutVerified::new_from_prefix(packet)?; 81 | 82 | log::trace!( 83 | "router, get route for IPv4 destination: {:?}", 84 | Ipv4Addr::from(header.f_destination) 85 | ); 86 | 87 | // check IPv4 source address 88 | self.ipv4 89 | .read() 90 | .longest_match(Ipv4Addr::from(header.f_destination)) 91 | .map(|(_, _, p)| p.clone()) 92 | } 93 | VERSION_IP6 => { 94 | // check length and cast to IPv6 header 95 | let (header, _): (LayoutVerified<&[u8], IPv6Header>, _) = 96 | LayoutVerified::new_from_prefix(packet)?; 97 | 98 | log::trace!( 99 | "router, get route for IPv6 destination: {:?}", 100 | Ipv6Addr::from(header.f_destination) 101 | ); 102 | 103 | // check IPv6 source address 104 | self.ipv6 105 | .read() 106 | .longest_match(Ipv6Addr::from(header.f_destination)) 107 | .map(|(_, _, p)| p.clone()) 108 | } 109 | v => { 110 | log::trace!("router, invalid IP version {}", v); 111 | None 112 | } 113 | } 114 | } 115 | 116 | #[inline(always)] 117 | pub fn check_route(&self, peer: &T, packet: &[u8]) -> bool { 118 | match packet.get(0).map(|v| v >> 4) { 119 | Some(VERSION_IP4) => LayoutVerified::new_from_prefix(packet) 120 | .and_then(|(header, _): (LayoutVerified<&[u8], IPv4Header>, _)| { 121 | self.ipv4 122 | .read() 123 | .longest_match(Ipv4Addr::from(header.f_source)) 124 | .map(|(_, _, p)| p == peer) 125 | }) 126 | .is_some(), 127 | 128 | Some(VERSION_IP6) => LayoutVerified::new_from_prefix(packet) 129 | .and_then(|(header, _): (LayoutVerified<&[u8], IPv6Header>, _)| { 130 | self.ipv6 131 | .read() 132 | .longest_match(Ipv6Addr::from(header.f_source)) 133 | .map(|(_, _, p)| p == peer) 134 | }) 135 | .is_some(), 136 | _ => false, 137 | } 138 | } 139 | } 140 | 
-------------------------------------------------------------------------------- /src/wireguard/router/send.rs: -------------------------------------------------------------------------------- 1 | use super::messages::{TransportHeader, TYPE_TRANSPORT}; 2 | use super::peer::Peer; 3 | use super::queue::{ParallelJob, Queue, SequentialJob}; 4 | use super::types::Callbacks; 5 | use super::KeyPair; 6 | use super::{REJECT_AFTER_MESSAGES, SIZE_TAG}; 7 | 8 | use super::super::{tun, udp, Endpoint}; 9 | 10 | use alloc::sync::Arc; 11 | use core::sync::atomic::{AtomicBool, Ordering}; 12 | 13 | use ring::aead::{Aad, LessSafeKey, Nonce, UnboundKey, CHACHA20_POLY1305}; 14 | use spin::Mutex; 15 | use zerocopy::{AsBytes, LayoutVerified}; 16 | 17 | struct Inner> { 18 | ready: AtomicBool, 19 | buffer: Mutex>, 20 | counter: u64, 21 | keypair: Arc, 22 | peer: Peer, 23 | } 24 | 25 | pub struct SendJob>( 26 | Arc>, 27 | ); 28 | 29 | impl> Clone for SendJob { 30 | fn clone(&self) -> SendJob { 31 | SendJob(self.0.clone()) 32 | } 33 | } 34 | 35 | impl> SendJob { 36 | pub fn new( 37 | buffer: Vec, 38 | counter: u64, 39 | keypair: Arc, 40 | peer: Peer, 41 | ) -> SendJob { 42 | SendJob(Arc::new(Inner { 43 | buffer: Mutex::new(buffer), 44 | counter, 45 | keypair, 46 | peer, 47 | ready: AtomicBool::new(false), 48 | })) 49 | } 50 | } 51 | 52 | impl> ParallelJob 53 | for SendJob 54 | { 55 | fn queue(&self) -> &Queue { 56 | &self.0.peer.outbound 57 | } 58 | 59 | fn parallel_work(&self) { 60 | debug_assert_eq!( 61 | self.is_ready(), 62 | false, 63 | "doing parallel work on completed job" 64 | ); 65 | log::trace!("processing parallel send job"); 66 | 67 | // encrypt body 68 | { 69 | // make space for the tag 70 | let job = &*self.0; 71 | let mut msg = job.buffer.lock(); 72 | msg.extend([0u8; SIZE_TAG].iter()); 73 | 74 | // cast to header (should never fail) 75 | let (mut header, packet): (LayoutVerified<&mut [u8], TransportHeader>, &mut [u8]) = 76 | LayoutVerified::new_from_prefix(&mut msg[..]) 77 | 
.expect("earlier code should ensure that there is ample space"); 78 | 79 | // set header fields 80 | debug_assert!( 81 | job.counter < REJECT_AFTER_MESSAGES, 82 | "should be checked when assigning counters" 83 | ); 84 | header.f_type.set(TYPE_TRANSPORT); 85 | header.f_receiver.set(job.keypair.send.id); 86 | header.f_counter.set(job.counter); 87 | 88 | // create a nonce object 89 | let mut nonce = [0u8; 12]; 90 | debug_assert_eq!(nonce.len(), CHACHA20_POLY1305.nonce_len()); 91 | nonce[4..].copy_from_slice(header.f_counter.as_bytes()); 92 | let nonce = Nonce::assume_unique_for_key(nonce); 93 | 94 | // encrypt contents of transport message in-place 95 | let tag_offset = packet.len() - SIZE_TAG; 96 | let key = LessSafeKey::new( 97 | UnboundKey::new(&CHACHA20_POLY1305, &job.keypair.send.key[..]).unwrap(), 98 | ); 99 | let tag = key 100 | .seal_in_place_separate_tag(nonce, Aad::empty(), &mut packet[..tag_offset]) 101 | .unwrap(); 102 | 103 | // append tag 104 | packet[tag_offset..].copy_from_slice(tag.as_ref()); 105 | } 106 | 107 | // mark ready 108 | self.0.ready.store(true, Ordering::Release); 109 | } 110 | } 111 | 112 | impl> SequentialJob 113 | for SendJob 114 | { 115 | fn is_ready(&self) -> bool { 116 | self.0.ready.load(Ordering::Acquire) 117 | } 118 | 119 | fn sequential_work(self) { 120 | debug_assert_eq!( 121 | self.is_ready(), 122 | true, 123 | "doing sequential work 124 | on an incomplete job" 125 | ); 126 | log::trace!("processing sequential send job"); 127 | 128 | // send to peer 129 | let job = &self.0; 130 | let msg = job.buffer.lock(); 131 | let xmit = job.peer.send_raw(&msg[..]).is_ok(); 132 | 133 | // trigger callback (for timers) 134 | C::send(&job.peer.opaque, msg.len(), xmit, &job.keypair, job.counter); 135 | } 136 | } 137 | -------------------------------------------------------------------------------- /src/wireguard/router/tests/bench.rs: -------------------------------------------------------------------------------- 1 | #[cfg(feature = 
"unstable")] 2 | extern crate test; 3 | 4 | use super::*; 5 | 6 | use std::sync::atomic::AtomicUsize; 7 | use std::sync::atomic::Ordering; 8 | use std::sync::Arc; 9 | 10 | // only used in benchmark 11 | #[cfg(feature = "unstable")] 12 | use std::net::IpAddr; 13 | 14 | // only used in benchmark 15 | #[cfg(feature = "unstable")] 16 | use num_cpus; 17 | 18 | #[cfg(feature = "unstable")] 19 | use test::Bencher; 20 | 21 | // 22 | struct TransmissionCounter { 23 | sent: AtomicUsize, 24 | recv: AtomicUsize, 25 | } 26 | 27 | impl TransmissionCounter { 28 | #[allow(dead_code)] 29 | fn new() -> TransmissionCounter { 30 | TransmissionCounter { 31 | sent: AtomicUsize::new(0), 32 | recv: AtomicUsize::new(0), 33 | } 34 | } 35 | 36 | #[allow(dead_code)] 37 | fn reset(&self) { 38 | self.sent.store(0, Ordering::SeqCst); 39 | self.recv.store(0, Ordering::SeqCst); 40 | } 41 | 42 | #[allow(dead_code)] 43 | fn sent(&self) -> usize { 44 | self.sent.load(Ordering::Acquire) 45 | } 46 | 47 | #[allow(dead_code)] 48 | fn recv(&self) -> usize { 49 | self.recv.load(Ordering::Acquire) 50 | } 51 | } 52 | 53 | struct BencherCallbacks {} 54 | 55 | impl Callbacks for BencherCallbacks { 56 | type Opaque = Arc; 57 | fn send(t: &Self::Opaque, size: usize, _sent: bool, _keypair: &Arc, _counter: u64) { 58 | t.sent.fetch_add(size, Ordering::SeqCst); 59 | } 60 | fn recv(t: &Self::Opaque, size: usize, _sent: bool, _keypair: &Arc) { 61 | t.recv.fetch_add(size, Ordering::SeqCst); 62 | } 63 | fn need_key(_t: &Self::Opaque) {} 64 | fn key_confirmed(_t: &Self::Opaque) {} 65 | } 66 | 67 | #[cfg(feature = "profiler")] 68 | use cpuprofiler::PROFILER; 69 | 70 | #[cfg(feature = "profiler")] 71 | fn profiler_stop() { 72 | println!("Stopping profiler"); 73 | PROFILER.lock().unwrap().stop().unwrap(); 74 | } 75 | 76 | #[cfg(feature = "profiler")] 77 | fn profiler_start(name: &str) { 78 | use std::path::Path; 79 | 80 | // find first available path to save profiler output 81 | let mut n = 0; 82 | loop { 83 | let path = 
format!("./{}-{}.profile", name, n); 84 | if !Path::new(path.as_str()).exists() { 85 | println!("Starting profiler: {}", path); 86 | PROFILER.lock().unwrap().start(path).unwrap(); 87 | break; 88 | }; 89 | n += 1; 90 | } 91 | } 92 | 93 | #[cfg(feature = "unstable")] 94 | #[bench] 95 | fn bench_router_outbound(b: &mut Bencher) { 96 | // 10 GB transmission per iteration 97 | const BYTES_PER_ITER: usize = 100 * 1024 * 1024 * 1024; 98 | 99 | // inner payload of IPv4 packet is 1440 bytes 100 | const BYTES_PER_PACKET: usize = 1440; 101 | 102 | // create device 103 | let (_fake, _reader, tun_writer, _mtu) = dummy::TunTest::create(false); 104 | let router: Device<_, BencherCallbacks, dummy::TunWriter, dummy::VoidBind> = 105 | Device::new(num_cpus::get_physical(), tun_writer); 106 | 107 | // add peer to router 108 | let opaque = Arc::new(TransmissionCounter::new()); 109 | let peer = router.new_peer(opaque.clone()); 110 | peer.add_keypair(dummy_keypair(true)); 111 | 112 | // add subnet to peer 113 | let (mask, len, dst) = ("192.168.1.0", 24, "192.168.1.20"); 114 | let mask: IpAddr = mask.parse().unwrap(); 115 | peer.add_allowed_ip(mask, len); 116 | 117 | // create "IP packet" 118 | let dst = dst.parse().unwrap(); 119 | let src = match dst { 120 | IpAddr::V4(_) => "127.0.0.1".parse().unwrap(), 121 | IpAddr::V6(_) => "::1".parse().unwrap(), 122 | }; 123 | let packet = make_packet(BYTES_PER_PACKET, src, dst, 0); 124 | 125 | // suffix with zero and reserve capacity for tag 126 | // (normally done to enable in-place transport message construction) 127 | let mut msg = pad(&packet); 128 | msg.reserve(16); 129 | 130 | // setup profiler 131 | #[cfg(feature = "profiler")] 132 | profiler_start("outbound"); 133 | 134 | // repeatedly transmit 10 GB 135 | b.iter(|| { 136 | opaque.reset(); 137 | while opaque.sent() < BYTES_PER_ITER / packet.len() { 138 | router 139 | .send(msg.to_vec()) 140 | .expect("failed to crypto-route packet"); 141 | } 142 | }); 143 | 144 | // stop profiler 145 | 
#[cfg(feature = "profiler")] 146 | profiler_stop(); 147 | } 148 | 149 | /* 150 | #[test] 151 | fn bench_router_bidirectional(b: &mut Bencher) { 152 | const MAX_SIZE_BODY: usize = 1500; 153 | 154 | let tests = [ 155 | ( 156 | ("192.168.1.0", 24, "192.168.1.20", true), 157 | ("172.133.133.133", 32, "172.133.133.133", true), 158 | ), 159 | ( 160 | ("192.168.1.0", 24, "192.168.1.20", true), 161 | ("172.133.133.133", 32, "172.133.133.133", true), 162 | ), 163 | ( 164 | ( 165 | "2001:db8::ff00:42:8000", 166 | 113, 167 | "2001:db8::ff00:42:ffff", 168 | true, 169 | ), 170 | ( 171 | "2001:db8::ff40:42:8000", 172 | 113, 173 | "2001:db8::ff40:42:ffff", 174 | true, 175 | ), 176 | ), 177 | ( 178 | ( 179 | "2001:db8::ff00:42:8000", 180 | 113, 181 | "2001:db8::ff00:42:ffff", 182 | true, 183 | ), 184 | ( 185 | "2001:db8::ff40:42:8000", 186 | 113, 187 | "2001:db8::ff40:42:ffff", 188 | true, 189 | ), 190 | ), 191 | ]; 192 | 193 | let p1 = ("192.168.1.0", 24, "192.168.1.20"); 194 | let p2 = ("172.133.133.133", 32, "172.133.133.133"); 195 | 196 | let ((bind_reader1, bind_writer1), (bind_reader2, bind_writer2)) = dummy::PairBind::pair(); 197 | 198 | let mut confirm_packet_size = SIZE_KEEPALIVE; 199 | 200 | // create matching device 201 | let (_fake, _, tun_writer1, _) = dummy::TunTest::create(false); 202 | let (_fake, _, tun_writer2, _) = dummy::TunTest::create(false); 203 | 204 | let router1: Device<_, TestCallbacks, _, _> = Device::new(1, tun_writer1); 205 | router1.set_outbound_writer(bind_writer1); 206 | 207 | let router2: Device<_, TestCallbacks, _, _> = Device::new(1, tun_writer2); 208 | router2.set_outbound_writer(bind_writer2); 209 | 210 | // prepare opaque values for tracing callbacks 211 | 212 | let opaque1 = Opaque::new(); 213 | let opaque2 = Opaque::new(); 214 | 215 | // create peers with matching keypairs and assign subnets 216 | 217 | let peer1 = router1.new_peer(opaque1.clone()); 218 | let peer2 = router2.new_peer(opaque2.clone()); 219 | 220 | { 221 | let (mask, len, 
_ip, _okay) = p1; 222 | let mask: IpAddr = mask.parse().unwrap(); 223 | peer1.add_allowed_ip(mask, *len); 224 | peer1.add_keypair(dummy_keypair(false)); 225 | } 226 | 227 | { 228 | let (mask, len, _ip, _okay) = p2; 229 | let mask: IpAddr = mask.parse().unwrap(); 230 | peer2.add_allowed_ip(mask, *len); 231 | peer2.set_endpoint(dummy::UnitEndpoint::new()); 232 | } 233 | 234 | if confirm_with_staged_packet { 235 | // create IP packet 236 | let (_mask, _len, ip1, _okay) = p1; 237 | let (_mask, _len, ip2, _okay) = p2; 238 | 239 | let msg = make_packet( 240 | SIZE_MSG, 241 | ip1.parse().unwrap(), // src 242 | ip2.parse().unwrap(), // dst 243 | 0, 244 | ); 245 | 246 | // calculate size of encapsulated IP packet 247 | confirm_packet_size = msg.len() + SIZE_KEEPALIVE; 248 | 249 | // stage packet for sending 250 | router2 251 | .send(pad(&msg)) 252 | .expect("failed to sent staged packet"); 253 | 254 | // a new key should have been requested from the handshake machine 255 | assert_eq!( 256 | opaque2.need_key.wait(TIMEOUT), 257 | Some(()), 258 | "a new key should be requested since a packet was attempted transmitted" 259 | ); 260 | 261 | // no other events should fire 262 | no_events!(opaque1); 263 | no_events!(opaque2); 264 | } 265 | 266 | // add a keypair 267 | assert_eq!(peer1.get_endpoint(), None, "no endpoint has yet been set"); 268 | peer2.add_keypair(dummy_keypair(true)); 269 | 270 | // this should cause a key-confirmation packet (keepalive or staged packet) 271 | assert_eq!( 272 | opaque2.send.wait(TIMEOUT), 273 | Some((confirm_packet_size, true)), 274 | "expected successful transmission of a confirmation packet" 275 | ); 276 | 277 | // no other events should fire 278 | no_events!(opaque1); 279 | no_events!(opaque2); 280 | 281 | // read confirming message received by the other end ("across the internet") 282 | let mut buf = vec![0u8; SIZE_MSG * 2]; 283 | let (len, from) = bind_reader1.read(&mut buf).unwrap(); 284 | buf.truncate(len); 285 | 286 | assert_eq!( 287 | len, 
confirm_packet_size, 288 | "unexpected size of confirmation message" 289 | ); 290 | 291 | // pass to the router for processing 292 | router1 293 | .recv(from, buf) 294 | .expect("failed to receive confirmation message"); 295 | 296 | // check that a receive event is fired 297 | assert_eq!( 298 | opaque1.recv.wait(TIMEOUT), 299 | Some((confirm_packet_size, true)), 300 | "we expect processing to be successful" 301 | ); 302 | 303 | // the key is confirmed 304 | assert_eq!( 305 | opaque1.key_confirmed.wait(TIMEOUT), 306 | Some(()), 307 | "confirmation message should confirm the key" 308 | ); 309 | 310 | // peer1 learns the endpoint 311 | assert!( 312 | peer1.get_endpoint().is_some(), 313 | "peer1 should learn the endpoint of peer2 from the confirmation message (roaming)" 314 | ); 315 | 316 | // no other events should fire 317 | no_events!(opaque1); 318 | no_events!(opaque2); 319 | 320 | // now that peer1 has an endpoint 321 | // route packets in the other direction: peer1 -> peer2 322 | let mut sizes = vec![0, 1, 1500, MAX_SIZE_BODY]; 323 | for _ in 0..100 { 324 | let body_size: usize = rng.gen(); 325 | let body_size = body_size % MAX_SIZE_BODY; 326 | sizes.push(body_size); 327 | } 328 | for (id, body_size) in sizes.iter().enumerate() { 329 | println!("packet: id = {}, body_size = {}", id, body_size); 330 | 331 | // pass IP packet to router 332 | let (_mask, _len, ip1, _okay) = p1; 333 | let (_mask, _len, ip2, _okay) = p2; 334 | let msg = make_packet( 335 | *body_size, 336 | ip2.parse().unwrap(), // src 337 | ip1.parse().unwrap(), // dst 338 | id as u64, 339 | ); 340 | 341 | // calculate encrypted size 342 | let encrypted_size = msg.len() + SIZE_KEEPALIVE; 343 | 344 | router1 345 | .send(pad(&msg)) 346 | .expect("we expect routing to be successful"); 347 | 348 | // encryption succeeds and the correct size is logged 349 | assert_eq!( 350 | opaque1.send.wait(TIMEOUT), 351 | Some((encrypted_size, true)), 352 | "expected send event for peer1 -> peer2 payload" 353 | ); 354 | 
355 | // otherwise no events 356 | no_events!(opaque1); 357 | no_events!(opaque2); 358 | 359 | // receive ("across the internet") on the other end 360 | let mut buf = vec![0u8; MAX_SIZE_BODY + 512]; 361 | let (len, from) = bind_reader2.read(&mut buf).unwrap(); 362 | buf.truncate(len); 363 | router2.recv(from, buf).unwrap(); 364 | 365 | // check that decryption succeeds 366 | assert_eq!( 367 | opaque2.recv.wait(TIMEOUT), 368 | Some((msg.len() + SIZE_KEEPALIVE, true)), 369 | "decryption and routing should succeed" 370 | ); 371 | 372 | // otherwise no events 373 | no_events!(opaque1); 374 | no_events!(opaque2); 375 | } 376 | } 377 | 378 | #[bench] 379 | fn bench_router_inbound(b: &mut Bencher) { 380 | struct BencherCallbacks {} 381 | impl Callbacks for BencherCallbacks { 382 | type Opaque = Arc; 383 | fn send( 384 | _t: &Self::Opaque, 385 | _size: usize, 386 | _sent: bool, 387 | _keypair: &Arc, 388 | _counter: u64, 389 | ) { 390 | } 391 | fn recv(t: &Self::Opaque, size: usize, _sent: bool, _keypair: &Arc) { 392 | t.fetch_add(size, Ordering::SeqCst); 393 | } 394 | fn need_key(_t: &Self::Opaque) {} 395 | fn key_confirmed(_t: &Self::Opaque) {} 396 | } 397 | 398 | // create device 399 | let (_fake, _reader, tun_writer, _mtu) = dummy::TunTest::create(false); 400 | let router: Device<_, BencherCallbacks, dummy::TunWriter, dummy::VoidBind> = 401 | Device::new(num_cpus::get_physical(), tun_writer); 402 | 403 | // add new peer 404 | let opaque = Arc::new(AtomicUsize::new(0)); 405 | let peer = router.new_peer(opaque.clone()); 406 | peer.add_keypair(dummy_keypair(true)); 407 | 408 | // add subnet to peer 409 | let (mask, len, dst) = ("192.168.1.0", 24, "192.168.1.20"); 410 | let mask: IpAddr = mask.parse().unwrap(); 411 | peer.add_allowed_ip(mask, len); 412 | 413 | // create "IP packet" 414 | let dst = dst.parse().unwrap(); 415 | let src = match dst { 416 | IpAddr::V4(_) => "127.0.0.1".parse().unwrap(), 417 | IpAddr::V6(_) => "::1".parse().unwrap(), 418 | }; 419 | let mut msg = 
pad(&make_packet(1024, src, dst, 0)); 420 | 421 | msg.reserve(16); 422 | 423 | #[cfg(feature = "profiler")] 424 | profiler_start("outbound"); 425 | 426 | // every iteration sends 10 GB 427 | b.iter(|| { 428 | opaque.store(0, Ordering::SeqCst); 429 | while opaque.load(Ordering::Acquire) < 10 * 1024 * 1024 { 430 | router.send(msg.to_vec()).unwrap(); 431 | } 432 | }); 433 | 434 | #[cfg(feature = "profiler")] 435 | profiler_stop(); 436 | } 437 | */ 438 | -------------------------------------------------------------------------------- /src/wireguard/router/tests/mod.rs: -------------------------------------------------------------------------------- 1 | mod bench; 2 | mod tests; 3 | 4 | use super::message_data_len; 5 | use super::SIZE_MESSAGE_PREFIX; 6 | use super::{Callbacks, Device}; 7 | use super::{Key, KeyPair}; 8 | 9 | use super::super::dummy; 10 | use super::super::tests::make_packet; 11 | 12 | use std::time::Instant; 13 | 14 | fn init() { 15 | let _ = env_logger::builder().is_test(true).try_init(); 16 | } 17 | 18 | fn pad(msg: &[u8]) -> Vec { 19 | let mut o = vec![0; msg.len() + SIZE_MESSAGE_PREFIX]; 20 | o[SIZE_MESSAGE_PREFIX..SIZE_MESSAGE_PREFIX + msg.len()].copy_from_slice(msg); 21 | o 22 | } 23 | 24 | pub fn dummy_keypair(initiator: bool) -> KeyPair { 25 | let k1 = Key { 26 | key: [0x53u8; 32], 27 | id: 0x646e6573, 28 | }; 29 | let k2 = Key { 30 | key: [0x52u8; 32], 31 | id: 0x76636572, 32 | }; 33 | if initiator { 34 | KeyPair { 35 | birth: Instant::now(), 36 | initiator: true, 37 | send: k1, 38 | recv: k2, 39 | } 40 | } else { 41 | KeyPair { 42 | birth: Instant::now(), 43 | initiator: false, 44 | send: k2, 45 | recv: k1, 46 | } 47 | } 48 | } 49 | -------------------------------------------------------------------------------- /src/wireguard/router/types.rs: -------------------------------------------------------------------------------- 1 | use super::KeyPair; 2 | 3 | use alloc::sync::Arc; 4 | use core::fmt; 5 | 6 | // TODO: no_std alternatives 7 | use 
std::error::Error; 8 | 9 | pub trait Opaque: Send + Sync + 'static {} 10 | 11 | impl Opaque for T where T: Send + Sync + 'static {} 12 | 13 | /// A send/recv callback takes 3 arguments: 14 | /// 15 | /// * `0`, a reference to the opaque value assigned to the peer 16 | /// * `1`, a bool indicating whether the message contained data (not just keepalive) 17 | /// * `2`, a bool indicating whether the message was transmitted (i.e. did the peer have an associated endpoint?) 18 | pub trait Callback: Fn(&T, usize, bool) + Sync + Send + 'static {} 19 | 20 | impl Callback for F where F: Fn(&T, usize, bool) + Sync + Send + 'static {} 21 | 22 | /// A key callback takes 1 argument 23 | /// 24 | /// * `0`, a reference to the opaque value assigned to the peer 25 | pub trait KeyCallback: Fn(&T) + Sync + Send + 'static {} 26 | 27 | impl KeyCallback for F where F: Fn(&T) + Sync + Send + 'static {} 28 | 29 | pub trait Callbacks: Send + Sync + 'static { 30 | type Opaque: Opaque; 31 | fn send(opaque: &Self::Opaque, size: usize, sent: bool, keypair: &Arc, counter: u64); 32 | fn recv(opaque: &Self::Opaque, size: usize, sent: bool, keypair: &Arc); 33 | fn need_key(opaque: &Self::Opaque); 34 | fn key_confirmed(opaque: &Self::Opaque); 35 | } 36 | 37 | #[derive(Debug)] 38 | pub enum RouterError { 39 | NoCryptoKeyRoute, 40 | MalformedTransportMessage, 41 | UnknownReceiverId, 42 | NoEndpoint, 43 | SendError, 44 | } 45 | 46 | impl fmt::Display for RouterError { 47 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 48 | match self { 49 | RouterError::NoCryptoKeyRoute => write!(f, "No cryptokey route configured for subnet"), 50 | RouterError::MalformedTransportMessage => write!(f, "Transport header is malformed"), 51 | RouterError::UnknownReceiverId => { 52 | write!(f, "No decryption state associated with receiver id") 53 | } 54 | RouterError::NoEndpoint => write!(f, "No endpoint for peer"), 55 | RouterError::SendError => write!(f, "Failed to send packet on bind"), 56 | } 57 | } 58 | } 
59 | 60 | impl Error for RouterError { 61 | fn source(&self) -> Option<&(dyn Error + 'static)> { 62 | None 63 | } 64 | 65 | fn description(&self) -> &str { 66 | "Generic Handshake Error" 67 | } 68 | } 69 | -------------------------------------------------------------------------------- /src/wireguard/router/worker.rs: -------------------------------------------------------------------------------- 1 | use super::queue::ParallelJob; 2 | use super::receive::ReceiveJob; 3 | use super::send::SendJob; 4 | 5 | use super::super::{tun, udp, Endpoint}; 6 | use super::types::Callbacks; 7 | 8 | use crossbeam_channel::Receiver; 9 | 10 | pub enum JobUnion> { 11 | Outbound(SendJob), 12 | Inbound(ReceiveJob), 13 | } 14 | 15 | pub fn worker>( 16 | receiver: Receiver>, 17 | ) { 18 | loop { 19 | log::trace!("pool worker awaiting job"); 20 | match receiver.recv() { 21 | Err(e) => { 22 | log::debug!("worker stopped with {}", e); 23 | break; 24 | } 25 | Ok(JobUnion::Inbound(job)) => { 26 | job.parallel_work(); 27 | job.queue().consume(); 28 | } 29 | Ok(JobUnion::Outbound(job)) => { 30 | job.parallel_work(); 31 | job.queue().consume(); 32 | } 33 | } 34 | } 35 | } 36 | -------------------------------------------------------------------------------- /src/wireguard/tests.rs: -------------------------------------------------------------------------------- 1 | use super::dummy; 2 | use super::wireguard::WireGuard; 3 | 4 | use std::convert::TryInto; 5 | use std::net::IpAddr; 6 | 7 | use hex; 8 | use rand_chacha::ChaCha8Rng; 9 | use rand_core::{RngCore, SeedableRng}; 10 | use x25519_dalek::{PublicKey, StaticSecret}; 11 | 12 | use pnet::packet::ipv4::MutableIpv4Packet; 13 | use pnet::packet::ipv6::MutableIpv6Packet; 14 | 15 | pub fn make_packet(size: usize, src: IpAddr, dst: IpAddr, id: u64) -> Vec { 16 | // expand pseudo random payload 17 | let mut rng: _ = ChaCha8Rng::seed_from_u64(id); 18 | let mut p: Vec = vec![0; size]; 19 | rng.fill_bytes(&mut p); 20 | 21 | // create "IP packet" 22 | let 
mut msg = Vec::with_capacity(size); 23 | match dst { 24 | IpAddr::V4(dst) => { 25 | let length = size + MutableIpv4Packet::minimum_packet_size(); 26 | msg.resize(length, 0); 27 | 28 | let mut packet = MutableIpv4Packet::new(&mut msg[..]).unwrap(); 29 | packet.set_destination(dst); 30 | packet.set_total_length(length.try_into().expect("length too great for IPv4 packet")); 31 | packet.set_source(if let IpAddr::V4(src) = src { 32 | src 33 | } else { 34 | panic!("src.version != dst.version") 35 | }); 36 | packet.set_payload(&p); 37 | packet.set_version(4); 38 | } 39 | IpAddr::V6(dst) => { 40 | let length = size + MutableIpv6Packet::minimum_packet_size(); 41 | msg.resize(length, 0); 42 | 43 | let mut packet = MutableIpv6Packet::new(&mut msg[..]).unwrap(); 44 | packet.set_destination(dst); 45 | packet.set_payload_length(size.try_into().expect("length too great for IPv6 packet")); 46 | packet.set_source(if let IpAddr::V6(src) = src { 47 | src 48 | } else { 49 | panic!("src.version != dst.version") 50 | }); 51 | packet.set_payload(&p); 52 | packet.set_version(6); 53 | } 54 | } 55 | msg 56 | } 57 | 58 | fn init() { 59 | let _ = env_logger::builder().is_test(true).try_init(); 60 | } 61 | 62 | /* Create and configure 63 | * two matching pure (no side-effects) instances of WireGuard. 
64 | * 65 | * Test: 66 | * 67 | * - Handshaking completes successfully 68 | * - All packets up to MTU are delivered 69 | * - All packets are delivered in-order 70 | */ 71 | #[test] 72 | fn test_pure_wireguard() { 73 | init(); 74 | 75 | // create WG instances for dummy TUN devices 76 | 77 | let (fake1, tun_reader1, tun_writer1, _) = dummy::TunTest::create(true); 78 | let wg1: WireGuard = WireGuard::new(tun_writer1); 79 | wg1.add_tun_reader(tun_reader1); 80 | wg1.up(1500); 81 | 82 | let (fake2, tun_reader2, tun_writer2, _) = dummy::TunTest::create(true); 83 | let wg2: WireGuard = WireGuard::new(tun_writer2); 84 | wg2.add_tun_reader(tun_reader2); 85 | wg2.up(1500); 86 | 87 | // create pair bind to connect the interfaces "over the internet" 88 | 89 | let ((bind_reader1, bind_writer1), (bind_reader2, bind_writer2)) = dummy::PairBind::pair(); 90 | 91 | wg1.set_writer(bind_writer1); 92 | wg2.set_writer(bind_writer2); 93 | 94 | wg1.add_udp_reader(bind_reader1); 95 | wg2.add_udp_reader(bind_reader2); 96 | 97 | // configure (public, private) key pairs 98 | 99 | let sk1 = StaticSecret::from([ 100 | 0x3f, 0x69, 0x86, 0xd1, 0xc0, 0xec, 0x25, 0xa0, 0x9c, 0x8e, 0x56, 0xb5, 0x1d, 0xb7, 0x3c, 101 | 0xed, 0x56, 0x8e, 0x59, 0x9d, 0xd9, 0xc3, 0x98, 0x67, 0x74, 0x69, 0x90, 0xc3, 0x43, 0x36, 102 | 0x78, 0x89, 103 | ]); 104 | 105 | let sk2 = StaticSecret::from([ 106 | 0xfb, 0xd1, 0xd6, 0xe4, 0x65, 0x06, 0xd2, 0xe5, 0xc5, 0xdf, 0x6e, 0xab, 0x51, 0x71, 0xd8, 107 | 0x70, 0xb5, 0xb7, 0x77, 0x51, 0xb4, 0xbe, 0xfb, 0xbc, 0x88, 0x62, 0x40, 0xca, 0x2c, 0xc2, 108 | 0x66, 0xe2, 109 | ]); 110 | 111 | let pk1 = PublicKey::from(&sk1); 112 | 113 | let pk2 = PublicKey::from(&sk2); 114 | 115 | wg1.add_peer(pk2); 116 | wg2.add_peer(pk1); 117 | 118 | wg1.set_key(Some(sk1)); 119 | wg2.set_key(Some(sk2)); 120 | 121 | // configure crypto-key router 122 | 123 | { 124 | let peers1 = wg1.peers.read(); 125 | let peers2 = wg2.peers.read(); 126 | 127 | let peer2 = peers1.get(&pk2).unwrap(); 128 | let peer1 = 
peers2.get(&pk1).unwrap(); 129 | 130 | peer1.add_allowed_ip("192.168.1.0".parse().unwrap(), 24); 131 | 132 | peer2.add_allowed_ip("192.168.2.0".parse().unwrap(), 24); 133 | 134 | // set endpoint (the other should be learned dynamically) 135 | 136 | peer2.set_endpoint(dummy::UnitEndpoint::new()); 137 | } 138 | 139 | let num_packets = 20; 140 | 141 | // send IP packets (causing a new handshake) 142 | 143 | { 144 | let mut packets: Vec> = Vec::with_capacity(num_packets); 145 | 146 | for id in 0..num_packets { 147 | packets.push(make_packet( 148 | 50 * id as usize, // size 149 | "192.168.1.20".parse().unwrap(), // src 150 | "192.168.2.10".parse().unwrap(), // dst 151 | id as u64, // prng seed 152 | )); 153 | } 154 | 155 | let mut backup = packets.clone(); 156 | 157 | while let Some(p) = packets.pop() { 158 | println!("send"); 159 | fake1.write(p); 160 | } 161 | 162 | while let Some(p) = backup.pop() { 163 | println!("read"); 164 | assert_eq!( 165 | hex::encode(fake2.read()), 166 | hex::encode(p), 167 | "Failed to receive valid IPv4 packet unmodified and in-order" 168 | ); 169 | } 170 | } 171 | 172 | // send IP packets (other direction) 173 | 174 | { 175 | let mut packets: Vec> = Vec::with_capacity(num_packets); 176 | 177 | for id in 0..num_packets { 178 | packets.push(make_packet( 179 | 50 + 50 * id as usize, // size 180 | "192.168.2.10".parse().unwrap(), // src 181 | "192.168.1.20".parse().unwrap(), // dst 182 | (id + 100) as u64, // prng seed 183 | )); 184 | } 185 | 186 | let mut backup = packets.clone(); 187 | 188 | while let Some(p) = packets.pop() { 189 | fake2.write(p); 190 | } 191 | 192 | while let Some(p) = backup.pop() { 193 | assert_eq!( 194 | hex::encode(fake1.read()), 195 | hex::encode(p), 196 | "Failed to receive valid IPv4 packet unmodified and in-order" 197 | ); 198 | } 199 | } 200 | } 201 | -------------------------------------------------------------------------------- /src/wireguard/types.rs: 
-------------------------------------------------------------------------------- 1 | use clear_on_drop::clear::Clear; 2 | use std::fmt; 3 | use std::time::Instant; 4 | 5 | #[derive(Clone)] 6 | pub struct Key { 7 | pub key: [u8; 32], 8 | pub id: u32, 9 | } 10 | 11 | // zero key on drop 12 | impl Drop for Key { 13 | fn drop(&mut self) { 14 | self.key.clear() 15 | } 16 | } 17 | 18 | #[cfg(test)] 19 | impl PartialEq for Key { 20 | fn eq(&self, other: &Self) -> bool { 21 | self.id == other.id && self.key[..] == other.key[..] 22 | } 23 | } 24 | 25 | impl fmt::Debug for Key { 26 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 27 | write!(f, "Key {{ id = {} }}", self.id) 28 | } 29 | } 30 | 31 | #[derive(Clone)] 32 | pub struct KeyPair { 33 | pub birth: Instant, // when was the key-pair created 34 | pub initiator: bool, // has the key-pair been confirmed? 35 | pub send: Key, // key for outbound messages 36 | pub recv: Key, // key for inbound messages 37 | } 38 | 39 | impl fmt::Debug for KeyPair { 40 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 41 | write!( 42 | f, 43 | "KeyPair {{ initator = {}, age = {} secs, send = {:?}, recv = {:?}}}", 44 | self.initiator, 45 | self.birth.elapsed().as_secs(), 46 | self.send, 47 | self.recv 48 | ) 49 | } 50 | } 51 | 52 | impl KeyPair { 53 | pub fn local_id(&self) -> u32 { 54 | self.recv.id 55 | } 56 | } 57 | -------------------------------------------------------------------------------- /src/wireguard/wireguard.rs: -------------------------------------------------------------------------------- 1 | use super::constants::*; 2 | use super::handshake; 3 | use super::peer::PeerInner; 4 | use super::router; 5 | use super::timers::Timers; 6 | 7 | use super::queue::ParallelQueue; 8 | use super::workers::HandshakeJob; 9 | 10 | use super::tun::Tun; 11 | use super::udp::UDP; 12 | 13 | use super::workers::{handshake_worker, tun_worker, udp_worker}; 14 | 15 | use std::fmt; 16 | use std::thread; 17 | 18 | use 
std::ops::Deref; 19 | use std::sync::atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering}; 20 | use std::sync::Arc; 21 | use std::sync::Condvar; 22 | use std::sync::Mutex as StdMutex; 23 | use std::time::Instant; 24 | 25 | use rand::rngs::OsRng; 26 | use rand::Rng; 27 | 28 | use hjul::Runner; 29 | use spin::{Mutex, RwLock}; 30 | use x25519_dalek::{PublicKey, StaticSecret}; 31 | 32 | pub struct WireguardInner { 33 | // identifier (for logging) 34 | pub id: u32, 35 | 36 | // timer wheel 37 | pub runner: Mutex, 38 | 39 | // device enabled 40 | pub enabled: RwLock, 41 | 42 | // number of tun readers 43 | pub tun_readers: WaitCounter, 44 | 45 | // current MTU 46 | pub mtu: AtomicUsize, 47 | 48 | // peer map 49 | #[allow(clippy::type_complexity)] 50 | pub peers: RwLock< 51 | handshake::Device, T::Writer, B::Writer>>, 52 | >, 53 | 54 | // cryptokey router 55 | pub router: router::Device, T::Writer, B::Writer>, 56 | 57 | // handshake related state 58 | pub last_under_load: Mutex, 59 | pub pending: AtomicUsize, // number of pending handshake packets in queue 60 | pub queue: ParallelQueue>, 61 | } 62 | 63 | pub struct WireGuard { 64 | inner: Arc>, 65 | } 66 | 67 | pub struct WaitCounter(StdMutex, Condvar); 68 | 69 | impl fmt::Display for WireGuard { 70 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 71 | write!(f, "wireguard({:x})", self.id) 72 | } 73 | } 74 | 75 | impl Deref for WireGuard { 76 | type Target = WireguardInner; 77 | fn deref(&self) -> &Self::Target { 78 | &self.inner 79 | } 80 | } 81 | 82 | impl Clone for WireGuard { 83 | fn clone(&self) -> Self { 84 | WireGuard { 85 | inner: self.inner.clone(), 86 | } 87 | } 88 | } 89 | 90 | #[allow(clippy::mutex_atomic)] 91 | impl WaitCounter { 92 | pub fn wait(&self) { 93 | let mut nread = self.0.lock().unwrap(); 94 | while *nread > 0 { 95 | nread = self.1.wait(nread).unwrap(); 96 | } 97 | } 98 | 99 | fn new() -> Self { 100 | Self(StdMutex::new(0), Condvar::new()) 101 | } 102 | 103 | fn decrease(&self) { 104 
| let mut nread = self.0.lock().unwrap(); 105 | assert!(*nread > 0); 106 | *nread -= 1; 107 | if *nread == 0 { 108 | self.1.notify_all(); 109 | } 110 | } 111 | 112 | fn increase(&self) { 113 | *self.0.lock().unwrap() += 1; 114 | } 115 | } 116 | 117 | impl WireGuard { 118 | /// Brings the WireGuard device down. 119 | /// Usually called when the associated interface is brought down. 120 | /// 121 | /// This stops any further action/timer on any peer 122 | /// and prevents transmission of further messages, 123 | /// however the device retrains its state. 124 | /// 125 | /// The instance will continue to consume and discard messages 126 | /// on both ends of the device. 127 | pub fn down(&self) { 128 | // ensure exclusive access (to avoid race with "up" call) 129 | let mut enabled = self.enabled.write(); 130 | 131 | // check if already down 132 | if !(*enabled) { 133 | return; 134 | } 135 | 136 | // set mtu 137 | self.mtu.store(0, Ordering::Relaxed); 138 | 139 | // avoid transmission from router 140 | self.router.down(); 141 | 142 | // set all peers down (stops timers) 143 | for (_, peer) in self.peers.write().iter() { 144 | peer.stop_timers(); 145 | peer.down(); 146 | } 147 | 148 | *enabled = false; 149 | } 150 | 151 | /// Brings the WireGuard device up. 152 | /// Usually called when the associated interface is brought up. 
153 | pub fn up(&self, mtu: usize) { 154 | // ensure exclusive access (to avoid race with "up" call) 155 | let mut enabled = self.enabled.write(); 156 | 157 | // set mtu 158 | self.mtu.store(mtu, Ordering::Relaxed); 159 | 160 | // check if already up 161 | if *enabled { 162 | return; 163 | } 164 | 165 | // enable transmission from router 166 | self.router.up(); 167 | 168 | // set all peers up (restarts timers) 169 | for (_, peer) in self.peers.write().iter() { 170 | peer.up(); 171 | peer.start_timers(); 172 | } 173 | 174 | *enabled = true; 175 | } 176 | 177 | pub fn clear_peers(&self) { 178 | self.peers.write().clear(); 179 | } 180 | 181 | pub fn remove_peer(&self, pk: &PublicKey) { 182 | let _ = self.peers.write().remove(pk); 183 | } 184 | 185 | pub fn set_key(&self, sk: Option) { 186 | let mut peers = self.peers.write(); 187 | peers.set_sk(sk); 188 | self.router.clear_sending_keys(); 189 | } 190 | 191 | pub fn get_sk(&self) -> Option { 192 | self.peers 193 | .read() 194 | .get_sk() 195 | .map(|sk| StaticSecret::from(sk.to_bytes())) 196 | } 197 | 198 | pub fn set_psk(&self, pk: PublicKey, psk: [u8; 32]) -> bool { 199 | self.peers.write().set_psk(pk, psk).is_ok() 200 | } 201 | pub fn get_psk(&self, pk: &PublicKey) -> Option<[u8; 32]> { 202 | self.peers.read().get_psk(pk).ok() 203 | } 204 | 205 | pub fn add_peer(&self, pk: PublicKey) -> bool { 206 | let mut peers = self.peers.write(); 207 | if peers.contains_key(&pk) { 208 | return false; 209 | } 210 | 211 | // prevent up/down while inserting 212 | let enabled = self.enabled.read(); 213 | 214 | // create timers (lookup by public key) 215 | let timers = Timers::new::(self.clone(), pk, *enabled); 216 | 217 | // create new router peer 218 | let peer: router::PeerHandle, T::Writer, B::Writer> = 219 | self.router.new_peer(PeerInner { 220 | id: OsRng.gen(), 221 | pk, 222 | wg: self.clone(), 223 | walltime_last_handshake: Mutex::new(None), 224 | last_handshake_sent: Mutex::new(Instant::now() - TIME_HORIZON), 225 | 
handshake_queued: AtomicBool::new(false), 226 | rx_bytes: AtomicU64::new(0), 227 | tx_bytes: AtomicU64::new(0), 228 | timers: RwLock::new(timers), 229 | }); 230 | 231 | // finally, add the peer to the handshake device 232 | peers.add(pk, peer).is_ok() 233 | } 234 | 235 | /// Begin consuming messages from the reader. 236 | /// Multiple readers can be added to support multi-queue and individual Ipv6/Ipv4 sockets interfaces 237 | /// 238 | /// Any previous reader thread is stopped by closing the previous reader, 239 | /// which unblocks the thread and causes an error on reader.read 240 | pub fn add_udp_reader(&self, reader: B::Reader) { 241 | let wg = self.clone(); 242 | thread::spawn(move || { 243 | udp_worker(&wg, reader); 244 | }); 245 | } 246 | 247 | pub fn set_writer(&self, writer: B::Writer) { 248 | self.router.set_outbound_writer(writer); 249 | } 250 | 251 | pub fn add_tun_reader(&self, reader: T::Reader) { 252 | let wg = self.clone(); 253 | 254 | // increment reader count 255 | wg.tun_readers.increase(); 256 | 257 | // start worker 258 | thread::spawn(move || { 259 | tun_worker(&wg, reader); 260 | wg.tun_readers.decrease(); 261 | }); 262 | } 263 | 264 | pub fn wait(&self) { 265 | self.tun_readers.wait(); 266 | } 267 | 268 | pub fn new(writer: T::Writer) -> WireGuard { 269 | // workers equal to number of physical cores 270 | let cpus = num_cpus::get(); 271 | 272 | // create handshake queue 273 | let (tx, mut rxs) = ParallelQueue::new(cpus, 128); 274 | 275 | // create router 276 | let router: router::Device, T::Writer, B::Writer> = 277 | router::Device::new(num_cpus::get(), writer); 278 | 279 | // create arc to state 280 | let wg = WireGuard { 281 | inner: Arc::new(WireguardInner { 282 | enabled: RwLock::new(false), 283 | tun_readers: WaitCounter::new(), 284 | id: OsRng.gen(), 285 | mtu: AtomicUsize::new(0), 286 | last_under_load: Mutex::new(Instant::now() - TIME_HORIZON), 287 | router, 288 | pending: AtomicUsize::new(0), 289 | peers: 
RwLock::new(handshake::Device::new()), 290 | runner: Mutex::new(Runner::new(TIMERS_TICK, TIMERS_SLOTS, TIMERS_CAPACITY)), 291 | queue: tx, 292 | }), 293 | }; 294 | 295 | // start handshake workers 296 | while let Some(rx) = rxs.pop() { 297 | let wg = wg.clone(); 298 | thread::spawn(move || handshake_worker(&wg, rx)); 299 | } 300 | 301 | wg 302 | } 303 | } 304 | -------------------------------------------------------------------------------- /src/wireguard/workers.rs: -------------------------------------------------------------------------------- 1 | use std::sync::atomic::Ordering; 2 | use std::time::Instant; 3 | 4 | use byteorder::{ByteOrder, LittleEndian}; 5 | use crossbeam_channel::Receiver; 6 | use log::debug; 7 | use rand::rngs::OsRng; 8 | use x25519_dalek::PublicKey; 9 | 10 | // IO traits 11 | use super::Endpoint; 12 | 13 | use super::tun::Reader as TunReader; 14 | use super::tun::Tun; 15 | 16 | use super::udp::Reader as UDPReader; 17 | use super::udp::UDP; 18 | 19 | // constants 20 | use super::constants::{ 21 | DURATION_UNDER_LOAD, MAX_QUEUED_INCOMING_HANDSHAKES, MESSAGE_PADDING_MULTIPLE, 22 | THRESHOLD_UNDER_LOAD, 23 | }; 24 | use super::handshake::MAX_HANDSHAKE_MSG_SIZE; 25 | use super::handshake::{TYPE_COOKIE_REPLY, TYPE_INITIATION, TYPE_RESPONSE}; 26 | use super::router::{CAPACITY_MESSAGE_POSTFIX, SIZE_MESSAGE_PREFIX, TYPE_TRANSPORT}; 27 | 28 | use super::wireguard::WireGuard; 29 | 30 | pub enum HandshakeJob { 31 | Message(Vec, E), 32 | New(PublicKey), 33 | } 34 | 35 | /* Returns the padded length of a message: 36 | * 37 | * # Arguments 38 | * 39 | * - `size` : Size of unpadded message 40 | * - `mtu` : Maximum transmission unit of the device 41 | * 42 | * # Returns 43 | * 44 | * The padded length (always less than or equal to the MTU) 45 | */ 46 | #[inline(always)] 47 | const fn padding(size: usize, mtu: usize) -> usize { 48 | #[inline(always)] 49 | const fn min(a: usize, b: usize) -> usize { 50 | let m = (a < b) as usize; 51 | a * m + (1 - m) * b 52 | 
} 53 | let pad = MESSAGE_PADDING_MULTIPLE; 54 | min(mtu, size + (pad - size % pad) % pad) 55 | } 56 | 57 | pub fn tun_worker(wg: &WireGuard, reader: T::Reader) { 58 | loop { 59 | // create vector big enough for any transport message (based on MTU) 60 | let mtu = wg.mtu.load(Ordering::Relaxed); 61 | let size = mtu + SIZE_MESSAGE_PREFIX + 1; 62 | let mut msg: Vec = vec![0; size + CAPACITY_MESSAGE_POSTFIX]; 63 | 64 | // read a new IP packet 65 | let payload = match reader.read(&mut msg[..], SIZE_MESSAGE_PREFIX) { 66 | Ok(payload) => payload, 67 | Err(e) => { 68 | debug!("TUN worker, failed to read from tun device: {}", e); 69 | break; 70 | } 71 | }; 72 | debug!("TUN worker, IP packet of {} bytes (MTU = {})", payload, mtu); 73 | 74 | // check if device is down 75 | if mtu == 0 { 76 | continue; 77 | } 78 | 79 | // truncate padding 80 | let padded = padding(payload, mtu); 81 | log::trace!( 82 | "TUN worker, payload length = {}, padded length = {}", 83 | payload, 84 | padded 85 | ); 86 | msg.truncate(SIZE_MESSAGE_PREFIX + padded); 87 | debug_assert!(padded <= mtu); 88 | debug_assert_eq!( 89 | if padded < mtu { 90 | (msg.len() - SIZE_MESSAGE_PREFIX) % MESSAGE_PADDING_MULTIPLE 91 | } else { 92 | 0 93 | }, 94 | 0 95 | ); 96 | 97 | // crypt-key route 98 | let e = wg.router.send(msg); 99 | debug!("TUN worker, router returned {:?}", e); 100 | } 101 | } 102 | 103 | pub fn udp_worker(wg: &WireGuard, reader: B::Reader) { 104 | loop { 105 | // create vector big enough for any message given current MTU 106 | let mtu = wg.mtu.load(Ordering::Relaxed); 107 | let size = mtu + MAX_HANDSHAKE_MSG_SIZE; 108 | let mut msg: Vec = vec![0; size]; 109 | 110 | // read UDP packet into vector 111 | let (size, src) = match reader.read(&mut msg) { 112 | Err(e) => { 113 | debug!("Bind reader closed with {}", e); 114 | return; 115 | } 116 | Ok(v) => v, 117 | }; 118 | msg.truncate(size); 119 | 120 | // TODO: start device down 121 | if mtu == 0 { 122 | continue; 123 | } 124 | 125 | // message type 
de-multiplexer 126 | if msg.len() < std::mem::size_of::() { 127 | continue; 128 | } 129 | match LittleEndian::read_u32(&msg[..]) { 130 | TYPE_COOKIE_REPLY | TYPE_INITIATION | TYPE_RESPONSE => { 131 | debug!("{} : reader, received handshake message", wg); 132 | wg.pending.fetch_add(1, Ordering::SeqCst); 133 | wg.queue.send(HandshakeJob::Message(msg, src)); 134 | } 135 | TYPE_TRANSPORT => { 136 | debug!("{} : reader, received transport message", wg); 137 | 138 | // transport message 139 | let _ = wg.router.recv(src, msg).map_err(|e| { 140 | debug!("Failed to handle incoming transport message: {}", e); 141 | }); 142 | } 143 | _ => (), 144 | } 145 | } 146 | } 147 | 148 | pub fn handshake_worker( 149 | wg: &WireGuard, 150 | rx: Receiver>, 151 | ) { 152 | debug!("{} : handshake worker, started", wg); 153 | 154 | // process elements from the handshake queue 155 | for job in rx { 156 | // check if under load 157 | let mut under_load = false; 158 | let job: HandshakeJob = job; 159 | let pending = wg.pending.fetch_sub(1, Ordering::SeqCst); 160 | debug_assert!(pending < MAX_QUEUED_INCOMING_HANDSHAKES + (1 << 16)); 161 | 162 | // immediate go under load if too many handshakes pending 163 | if pending > THRESHOLD_UNDER_LOAD { 164 | log::trace!("{} : handshake worker, under load (above threshold)", wg); 165 | *wg.last_under_load.lock() = Instant::now(); 166 | under_load = true; 167 | } 168 | 169 | // remain under load for DURATION_UNDER_LOAD 170 | if !under_load { 171 | let elapsed = wg.last_under_load.lock().elapsed(); 172 | if DURATION_UNDER_LOAD >= elapsed { 173 | log::trace!("{} : handshake worker, under load (recent)", wg); 174 | under_load = true; 175 | } 176 | } 177 | 178 | // de-multiplex staged handshake jobs and handshake messages 179 | match job { 180 | HandshakeJob::Message(msg, mut src) => { 181 | // process message 182 | let device = wg.peers.read(); 183 | match device.process( 184 | &mut OsRng, 185 | &msg[..], 186 | if under_load { 187 | Some(src.into_address()) 
188 | } else { 189 | None 190 | }, 191 | ) { 192 | Ok((peer, resp, keypair)) => { 193 | // send response (might be cookie reply or handshake response) 194 | let mut resp_len: u64 = 0; 195 | if let Some(msg) = resp { 196 | resp_len = msg.len() as u64; 197 | // TODO: consider a more elegant solution for accessing the bind 198 | let _ = wg.router.send_raw(&msg[..], &mut src).map_err(|e| { 199 | debug!( 200 | "{} : handshake worker, failed to send response, error = {}", 201 | wg, e 202 | ); 203 | }); 204 | } 205 | 206 | // update peer state 207 | if let Some(peer) = peer { 208 | // authenticated handshake packet received 209 | 210 | // add to rx_bytes and tx_bytes 211 | let req_len = msg.len() as u64; 212 | peer.opaque().rx_bytes.fetch_add(req_len, Ordering::Relaxed); 213 | peer.opaque() 214 | .tx_bytes 215 | .fetch_add(resp_len, Ordering::Relaxed); 216 | 217 | // update endpoint 218 | peer.set_endpoint(src); 219 | 220 | if resp_len > 0 { 221 | // update timers after sending handshake response 222 | debug!("{} : handshake worker, handshake response sent", wg); 223 | peer.opaque().sent_handshake_response(); 224 | } else { 225 | // update timers after receiving handshake response 226 | debug!( 227 | "{} : handshake worker, handshake response was received", 228 | wg 229 | ); 230 | peer.opaque().timers_handshake_complete(); 231 | } 232 | 233 | // add any new keypair to peer 234 | if let Some(kp) = keypair { 235 | debug!("{} : handshake worker, new keypair for {}", wg, peer); 236 | 237 | // this means that a handshake response was processed or sent 238 | peer.opaque().timers_session_derived(); 239 | 240 | // free any unused ids 241 | for id in peer.add_keypair(kp) { 242 | device.release(id); 243 | } 244 | }; 245 | } 246 | } 247 | Err(e) => debug!("{} : handshake worker, error = {:?}", wg, e), 248 | } 249 | } 250 | HandshakeJob::New(pk) => { 251 | if let Some(peer) = wg.peers.read().get(&pk) { 252 | debug!( 253 | "{} : handshake worker, new handshake requested for {}", 254 | 
wg, peer 255 | ); 256 | let device = wg.peers.read(); 257 | let _ = device.begin(&mut OsRng, &pk).map(|msg| { 258 | let _ = peer.send_raw(&msg[..]).map_err(|e| { 259 | debug!("{} : handshake worker, failed to send handshake initiation, error = {}", wg, e) 260 | }); 261 | peer.opaque().sent_handshake_initiation(); 262 | }); 263 | peer.opaque() 264 | .handshake_queued 265 | .store(false, Ordering::SeqCst); 266 | } 267 | } 268 | } 269 | } 270 | } 271 | --------------------------------------------------------------------------------