├── .circleci └── config.yml ├── .gitignore ├── Cargo.toml ├── LICENSE ├── Makefile ├── README.md ├── src ├── cid.rs ├── congestion.rs ├── conn.rs ├── event.rs ├── lib.rs ├── packet.rs ├── peer.rs ├── recv.rs ├── send.rs ├── sent.rs ├── seq.rs ├── socket.rs ├── stream.rs ├── testutils.rs ├── time.rs └── udp.rs └── tests ├── socket.rs └── stream.rs /.circleci/config.yml: -------------------------------------------------------------------------------- 1 | version: 2.1 2 | orbs: 3 | rust: circleci/rust@1.6.0 4 | workflows: 5 | prod: 6 | jobs: 7 | - rust/lint-test-build: 8 | clippy_arguments: '--all-targets --all-features -- --deny warnings' 9 | release: true 10 | version: 1.81.0 11 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | /target 2 | /Cargo.lock 3 | -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "utp-rs" 3 | version = "0.1.0-alpha.17" 4 | edition = "2021" 5 | authors = ["Jacob Kaufmann", "Jason Carver"] 6 | description = "uTorrent transport protocol" 7 | readme = "README.md" 8 | repository = "https://github.com/ethereum/utp/" 9 | license = "MIT" 10 | keywords = ["utp"] 11 | categories = ["network-programming"] 12 | 13 | # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html 14 | 15 | [dependencies] 16 | async-trait = "0.1.64" 17 | delay_map = "0.3.0" 18 | futures = "0.3.26" 19 | rand = "0.8.5" 20 | tokio = { version = "1.25.0", features = ["io-util", "rt-multi-thread", "macros", "net", "sync", "time"] } 21 | tracing = { version = "0.1.37", features = ["std", "attributes", "log"] } 22 | 23 | [dev-dependencies] 24 | quickcheck = "1.0.3" 25 | tokio = { version = "1.25.0", features = ["test-util"] } 26 | tracing-subscriber = "0.3.16" 27 | 28 | 
[profile.test] 29 | opt-level = 3 30 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Jacob Kaufmann 4 | Copyright (c) 2023-2025 Trin Contributors 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | SOFTWARE. 
23 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: lint 2 | lint: # Run clippy and rustfmt 3 | cargo fmt --all 4 | cargo clippy --all --all-targets --all-features --no-deps -- --deny warnings 5 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # utp 2 | A Rust library for the [uTorrent transport protocol (uTP)](https://www.bittorrent.org/beps/bep_0029.html). 3 | 4 | ## 🚧 WARNING: UNDER CONSTRUCTION 🚧 5 | This library is currently unstable, with known issues. Use at your own discretion. 6 | 7 | # Usage 8 | 9 | ```rust 10 | use std::net::SocketAddr; 11 | 12 | use utp_rs::conn::ConnectionConfig; 13 | use utp_rs::socket::UtpSocket; 14 | use utp_rs::udp::AsyncUdpSocket; 15 | 16 | #[tokio::main] 17 | fn main() { 18 | // bind a standard UDP socket. (transport is over a `tokio::net::UdpSocket`.) 19 | let socket_addr = SocketAddr::from(([127, 0, 0, 1], 3400)); 20 | let udp_based_socket = UtpSocket::bind(socket_addr).await.unwrap(); 21 | 22 | // bind a custom UDP socket. here we assume `CustomSocket` implements `AsyncUdpSocket`. 23 | let async_udp_socket = CustomSocket::new(..); 24 | let custom_socket = UtpSocket::with_socket(async_udp_socket).await.unwrap(); 25 | 26 | // connect to a remote peer over uTP. 27 | let remote = SocketAddr::from(..); 28 | let config = ConnectionConfig::default(); 29 | let mut stream = udp_socket::connect(remote, config).await.unwrap(); 30 | 31 | // write data to the remote peer over the stream. 32 | let data = vec![0xef; 2048]; 33 | let n = stream.write(data.as_slice()).await.unwrap(); 34 | 35 | // accept a connection from a remote peer. 
36 | let config = ConnectionConfig::default(); 37 | let stream = udp_socket.accept(config).await; 38 | 39 | // read data from the remote peer until the peer indicates there is no data left to write. 40 | let mut data = vec![]; 41 | let n = stream.read_to_eof(&mut data).await.unwrap(); 42 | } 43 | ``` 44 | 45 | -------------------------------------------------------------------------------- /src/cid.rs: -------------------------------------------------------------------------------- 1 | #[derive(Clone, Copy, Debug, Hash, Eq, PartialEq)] 2 | pub struct ConnectionId

{ 3 | pub send: u16, 4 | pub recv: u16, 5 | pub peer_id: P, 6 | } 7 | -------------------------------------------------------------------------------- /src/congestion.rs: -------------------------------------------------------------------------------- 1 | use std::cmp; 2 | use std::collections::{BinaryHeap, HashMap}; 3 | use std::time::{Duration, Instant}; 4 | 5 | pub(crate) const DEFAULT_TARGET_MICROS: u32 = 100_000; 6 | pub(crate) const DEFAULT_INITIAL_TIMEOUT: Duration = Duration::from_secs(1); 7 | pub(crate) const DEFAULT_MIN_TIMEOUT: Duration = Duration::from_millis(500); 8 | pub(crate) const DEFAULT_MAX_TIMEOUT: Duration = Duration::from_secs(60); 9 | pub(crate) const DEFAULT_MAX_PACKET_SIZE_BYTES: u32 = 1024; 10 | const DEFAULT_GAIN: f32 = 1.0; 11 | const DEFAULT_DELAY_WINDOW: Duration = Duration::from_secs(120); 12 | 13 | #[derive(Clone, Debug)] 14 | struct Packet { 15 | size_bytes: u32, 16 | num_transmissions: u32, 17 | acked: bool, 18 | } 19 | 20 | #[derive(Clone, Debug)] 21 | pub enum Transmit { 22 | Initial { bytes: u32 }, 23 | Retransmission, 24 | } 25 | 26 | #[derive(Clone, Debug)] 27 | pub struct Ack { 28 | pub delay: Duration, 29 | pub rtt: Duration, 30 | pub received_at: Instant, 31 | } 32 | 33 | #[derive(Clone, Debug, Eq, PartialEq)] 34 | pub enum Error { 35 | InsufficientWindowSize, 36 | UnknownSeqNum, 37 | DuplicateTransmission, 38 | } 39 | 40 | #[derive(Clone, Debug)] 41 | pub struct Config { 42 | pub target_delay_micros: u32, 43 | pub initial_timeout: Duration, 44 | pub min_timeout: Duration, 45 | pub max_timeout: Duration, 46 | pub max_packet_size_bytes: u32, 47 | pub max_window_size_inc_bytes: u32, 48 | pub gain: f32, 49 | pub delay_window: Duration, 50 | } 51 | 52 | impl Default for Config { 53 | fn default() -> Self { 54 | Self { 55 | target_delay_micros: DEFAULT_TARGET_MICROS, 56 | initial_timeout: DEFAULT_INITIAL_TIMEOUT, 57 | min_timeout: DEFAULT_MIN_TIMEOUT, 58 | max_timeout: DEFAULT_MAX_TIMEOUT, 59 | max_packet_size_bytes: 
DEFAULT_MAX_PACKET_SIZE_BYTES, 60 | max_window_size_inc_bytes: DEFAULT_MAX_PACKET_SIZE_BYTES, 61 | gain: DEFAULT_GAIN, 62 | delay_window: DEFAULT_DELAY_WINDOW, 63 | } 64 | } 65 | } 66 | 67 | #[derive(Clone, Debug)] 68 | pub struct Controller { 69 | target_delay_micros: u32, 70 | timeout: Duration, 71 | min_timeout: Duration, 72 | max_timeout: Duration, 73 | window_size_bytes: u32, 74 | max_window_size_bytes: u32, 75 | min_window_size_bytes: u32, 76 | max_window_size_inc_bytes: u32, 77 | gain: f32, 78 | rtt: Duration, 79 | rtt_variance_micros: u64, 80 | transmissions: HashMap, 81 | delay_acc: DelayAccumulator, 82 | } 83 | 84 | impl Controller { 85 | /// Creates a `Controller` configured by `config`. 86 | pub fn new(config: Config) -> Self { 87 | Self { 88 | target_delay_micros: config.target_delay_micros, 89 | timeout: config.initial_timeout, 90 | min_timeout: config.min_timeout, 91 | max_timeout: config.max_timeout, 92 | window_size_bytes: 0, 93 | max_window_size_bytes: 2 * config.max_packet_size_bytes, 94 | min_window_size_bytes: 2 * config.max_packet_size_bytes, 95 | max_window_size_inc_bytes: config.max_window_size_inc_bytes, 96 | gain: config.gain, 97 | rtt: Duration::ZERO, 98 | rtt_variance_micros: 0, 99 | transmissions: HashMap::new(), 100 | delay_acc: DelayAccumulator::new(config.delay_window), 101 | } 102 | } 103 | 104 | /// Returns the congestion timeout. 105 | pub fn timeout(&self) -> Duration { 106 | self.timeout 107 | } 108 | 109 | /// Returns the number of bytes available in the congestion window. 110 | pub fn bytes_available_in_window(&self) -> u32 { 111 | // Use saturating arithmetic because the max window (capacity) may drop below the current 112 | // window (data in flight). 113 | self.max_window_size_bytes 114 | .saturating_sub(self.window_size_bytes) 115 | } 116 | 117 | /// Registers the transmission of a packet with the controller. 
118 | pub fn on_transmit(&mut self, seq_num: u16, transmission: Transmit) -> Result<(), Error> { 119 | // If the transmission is an initial transmission, then record the transmission. If the 120 | // transmission is a retransmission, then increment the number of transmissions for that 121 | // record. 122 | match transmission { 123 | Transmit::Initial { bytes } => { 124 | if self.transmissions.contains_key(&seq_num) { 125 | return Err(Error::DuplicateTransmission); 126 | } 127 | self.transmissions.insert( 128 | seq_num, 129 | Packet { 130 | size_bytes: bytes, 131 | num_transmissions: 1, 132 | acked: false, 133 | }, 134 | ); 135 | } 136 | Transmit::Retransmission => { 137 | let packet = self 138 | .transmissions 139 | .get_mut(&seq_num) 140 | .ok_or(Error::UnknownSeqNum)?; 141 | packet.num_transmissions += 1; 142 | } 143 | }; 144 | 145 | // The key-value pair should exist by logic above. 146 | let packet = self.transmissions.get(&seq_num).unwrap(); 147 | 148 | // If this is the initial transmission of this packet, then increase the size of the 149 | // window. Return an error if there is not sufficient space. 150 | if packet.num_transmissions == 1 { 151 | if self.window_size_bytes + packet.size_bytes > self.max_window_size_bytes { 152 | return Err(Error::InsufficientWindowSize); 153 | } 154 | self.window_size_bytes += packet.size_bytes; 155 | } 156 | 157 | Ok(()) 158 | } 159 | 160 | /// Registers a packet `Ack` with the controller. 161 | pub fn on_ack(&mut self, seq_num: u16, ack: Ack) -> Result<(), Error> { 162 | let packet = self 163 | .transmissions 164 | .get_mut(&seq_num) 165 | .ok_or(Error::UnknownSeqNum)?; 166 | 167 | // Mark the packet acknowledged. If the packet was already acknowledged, then short-circuit 168 | // return. There are no newly acknowledged bytes. 169 | if packet.acked { 170 | return Ok(()); 171 | } 172 | packet.acked = true; 173 | 174 | let packet = packet.clone(); 175 | 176 | // Add the delay to the accumulator. 
177 | self.delay_acc.push(ack.delay, ack.received_at); 178 | 179 | // Adjust the maximum congestion window based on the delay of the acknowledged packet. 180 | // Bound the base delay (in microseconds) by `u32::MAX`. Bound the delay (in microseconds) 181 | // of the acknowledged packet by `u32::MAX`. 182 | let base_delay_micros = self 183 | .delay_acc 184 | .base_delay() 185 | .unwrap_or(Duration::ZERO) 186 | .as_micros(); 187 | let base_delay_micros = u32::try_from(base_delay_micros).unwrap_or(u32::MAX); 188 | let packet_delay_micros = u32::try_from(ack.delay.as_micros()).unwrap_or(u32::MAX); 189 | let max_window_size_adjustment = compute_max_window_size_adjustment( 190 | self.target_delay_micros, 191 | base_delay_micros, 192 | packet_delay_micros, 193 | self.window_size_bytes, 194 | packet.size_bytes, 195 | self.max_window_size_inc_bytes, 196 | self.gain, 197 | ); 198 | self.apply_max_window_size_adjustment(max_window_size_adjustment); 199 | 200 | // Adjust the current window size to account for the acknowledged packet. 201 | // 202 | // An overflow panic occurs if the window size is less than the packet size. This would 203 | // correspond to an invalid operation as the window size should account for the size of the 204 | // packet being acknowledged. 205 | self.window_size_bytes -= packet.size_bytes; 206 | 207 | // Only adjust the round trip time (RTT) estimation if the acknowledgement corresponds to 208 | // the first transmission. The congestion timeout is also adjusted each time the RTT 209 | // estimation is adjusted. 210 | if packet.num_transmissions == 1 { 211 | // Adjust round trip time variance. 
212 | let rtt_var_adjustment = compute_rtt_variance_adjustment( 213 | self.rtt.as_micros(), 214 | self.rtt_variance_micros, 215 | ack.rtt.as_micros(), 216 | ); 217 | if rtt_var_adjustment.is_negative() { 218 | self.rtt_variance_micros = self 219 | .rtt_variance_micros 220 | .saturating_sub(rtt_var_adjustment.unsigned_abs()); 221 | } else { 222 | self.rtt_variance_micros = self 223 | .rtt_variance_micros 224 | .saturating_add(rtt_var_adjustment as u64); 225 | } 226 | 227 | // Adjust round trip time. 228 | let rtt_adjustment = compute_rtt_adjustment(self.rtt.as_micros(), ack.rtt.as_micros()); 229 | if rtt_adjustment.is_negative() { 230 | self.rtt = self 231 | .rtt 232 | .saturating_sub(Duration::from_micros(rtt_adjustment.unsigned_abs())); 233 | } else { 234 | self.rtt = self 235 | .rtt 236 | .saturating_add(Duration::from_micros(rtt_adjustment as u64)); 237 | } 238 | 239 | // Adjust congestion timeout. 240 | self.apply_timeout_adjustment(); 241 | } 242 | 243 | Ok(()) 244 | } 245 | 246 | /// Registers a lost packet with the controller. 247 | pub fn on_lost_packet(&mut self, seq_num: u16, retransmitting: bool) -> Result<(), Error> { 248 | let packet = self 249 | .transmissions 250 | .get(&seq_num) 251 | .ok_or(Error::UnknownSeqNum)?; 252 | 253 | self.max_window_size_bytes = 254 | cmp::max(self.max_window_size_bytes / 2, self.min_window_size_bytes); 255 | 256 | // If the packet is not to be retransmitted, then account for those lost bytes in the 257 | // congestion window. 258 | if !retransmitting { 259 | self.window_size_bytes -= packet.size_bytes; 260 | } 261 | 262 | Ok(()) 263 | } 264 | 265 | /// Registers a timeout with the controller. 266 | pub fn on_timeout(&mut self) { 267 | self.max_window_size_bytes = self.min_window_size_bytes; 268 | self.timeout = cmp::min(self.timeout * 2, self.max_timeout); 269 | } 270 | 271 | /// Adjusts the maximum window (i.e. 
congestion window) by `adjustment`, keeping the size of 272 | /// the window within the configured interval, and not allowing it to grow by more than the 273 | /// configured maximum increment. 274 | fn apply_max_window_size_adjustment(&mut self, adjustment: i64) { 275 | // Apply the adjustment. 276 | let adj_max_window_size_bytes = i64::from(self.max_window_size_bytes) + adjustment; 277 | 278 | // The maximum congestion window must not fall below the minimum. 279 | let adj_max_window_size_bytes = 280 | cmp::max(adj_max_window_size_bytes as u32, self.min_window_size_bytes); 281 | 282 | // The maximum congestion window cannot increase by more than the configured maximum 283 | // increment. 284 | self.max_window_size_bytes = cmp::min( 285 | adj_max_window_size_bytes, 286 | self.max_window_size_bytes 287 | .saturating_add(self.max_window_size_inc_bytes), 288 | ); 289 | } 290 | 291 | /// Adjusts the congestion timeout based on the current round trip time (RTT) estimate and the 292 | /// current RTT variance. 293 | /// 294 | /// The congestion timeout cannot fall below the configured minimum. 295 | fn apply_timeout_adjustment(&mut self) { 296 | // Do not let timeout go below minimum. 297 | self.timeout = cmp::max( 298 | self.rtt + Duration::from_micros(self.rtt_variance_micros * 4), 299 | self.min_timeout, 300 | ); 301 | 302 | // Do not let timeout go above maximum. 303 | self.timeout = cmp::min(self.timeout, self.max_timeout) 304 | } 305 | } 306 | 307 | /// Returns the adjustment in bytes to the maximum window (i.e. congestion window) size based on 308 | /// the delta between the packet delay and the target delay and on the portion of the total 309 | /// in-flight bytes that the packet corresponds to. 
310 | fn compute_max_window_size_adjustment( 311 | target_delay_micros: u32, 312 | base_delay_micros: u32, 313 | packet_delay_micros: u32, 314 | window_size_bytes: u32, 315 | packet_size_bytes: u32, 316 | max_window_size_inc_bytes: u32, 317 | gain: f32, 318 | ) -> i64 { 319 | // Adjust the delay based on the base delay. 320 | // 321 | // The `i64::try_from` should not fail because `packet_delay_micros` is a `u32`. 322 | // 323 | // The base delay adjustment should not panic because the base delay is non-negative and 324 | // should not be larger than the current delay. 325 | let delay_micros = i64::from(packet_delay_micros - base_delay_micros); 326 | 327 | let off_target_micros = i64::from(target_delay_micros) - delay_micros; 328 | let delay_factor = (off_target_micros as f64) / f64::from(target_delay_micros); 329 | let window_factor = f64::from(packet_size_bytes) / f64::from(window_size_bytes); 330 | 331 | let scaled_gain = 332 | f64::from(gain) * f64::from(max_window_size_inc_bytes) * delay_factor * window_factor; 333 | 334 | scaled_gain as i64 335 | } 336 | 337 | /// Returns the adjustment to the round trip time (RTT) estimate in microseconds based on the 338 | /// packet RTT and the current RTT estimate. 339 | fn compute_rtt_adjustment(rtt_micros: u128, packet_rtt_micros: u128) -> i64 { 340 | ((packet_rtt_micros as f64 - rtt_micros as f64) / 8.0) as i64 341 | } 342 | 343 | /// Returns the adjustment to round trip time (RTT) variance in microseconds based on the packet 344 | /// RTT, current RTT estimate, and current RTT variance. 
345 | fn compute_rtt_variance_adjustment( 346 | rtt_micros: u128, 347 | rtt_variance_micros: u64, 348 | packet_rtt_micros: u128, 349 | ) -> i64 { 350 | let abs_delta_micros = rtt_micros.abs_diff(packet_rtt_micros); 351 | 352 | (((abs_delta_micros as f64) - (rtt_variance_micros as f64)) / 4.0) as i64 353 | } 354 | 355 | #[derive(Clone, Debug, Eq)] 356 | struct Delay { 357 | value: Duration, 358 | deadline: Instant, 359 | } 360 | 361 | impl PartialEq for Delay { 362 | fn eq(&self, other: &Self) -> bool { 363 | self.value == other.value 364 | } 365 | } 366 | 367 | impl PartialOrd for Delay { 368 | fn partial_cmp(&self, other: &Self) -> Option { 369 | Some(self.cmp(other)) 370 | } 371 | } 372 | 373 | impl Ord for Delay { 374 | fn cmp(&self, other: &Self) -> cmp::Ordering { 375 | self.value.cmp(&other.value) 376 | } 377 | } 378 | 379 | #[derive(Clone, Debug)] 380 | struct DelayAccumulator { 381 | delays: BinaryHeap>, 382 | window: Duration, 383 | } 384 | 385 | impl DelayAccumulator { 386 | /// Creates a `DelayAccumulator` with a sliding window of length `window`. 387 | pub fn new(window: Duration) -> Self { 388 | Self { 389 | delays: BinaryHeap::new(), 390 | window, 391 | } 392 | } 393 | 394 | // TODO: Handle `received_at` from the future (i.e. beyond `Instant::now`). 395 | /// Pushes `delay` onto the accumulator set to remain in the sliding window based on 396 | /// `received_at`. 397 | pub fn push(&mut self, delay: Duration, received_at: Instant) { 398 | let delay = Delay { 399 | value: delay, 400 | deadline: received_at + self.window, 401 | }; 402 | self.delays.push(cmp::Reverse(delay)); 403 | } 404 | 405 | // TODO: Here, delay measurements that fall outside of the window are deleted lazily. Evaluate 406 | // a non-lazy alternative. The number of elements in the accumulator should not be too large. 407 | // The lazy solution also means that `base_delay` requires `&mut self`, which is not 408 | // preferable. 
409 | /// Returns a baseline delay measure given the delay measurements present in the accumulator. 410 | /// Returns `None` if there are no delay measurements within the current sliding window. 411 | pub fn base_delay(&mut self) -> Option { 412 | while let Some(min) = self.delays.peek() { 413 | // If the deadline of the delay has been reached, then remove the delay and continue to 414 | // the next iteration. 415 | if Instant::now() >= min.0.deadline { 416 | self.delays.pop(); 417 | continue; 418 | } 419 | 420 | return Some(min.0.value); 421 | } 422 | 423 | // No delay exists within the window. 424 | None 425 | } 426 | } 427 | 428 | #[cfg(test)] 429 | mod tests { 430 | use super::*; 431 | 432 | mod controller { 433 | use super::*; 434 | 435 | #[test] 436 | fn on_transmit() { 437 | let mut ctrl = Controller::new(Config::default()); 438 | 439 | let initial_timeout = ctrl.timeout(); 440 | 441 | // Register the initial transmission of a packet with sequence number 1. 442 | let mut seq_num = 1; 443 | let packet_one_size_bytes = 32; 444 | let transmission = Transmit::Initial { 445 | bytes: packet_one_size_bytes, 446 | }; 447 | ctrl.on_transmit(seq_num, transmission) 448 | .expect("transmission registration failed"); 449 | 450 | let transmission_record = ctrl 451 | .transmissions 452 | .get(&seq_num) 453 | .expect("transmission not recorded"); 454 | assert_eq!(transmission_record.size_bytes, packet_one_size_bytes); 455 | assert_eq!(transmission_record.num_transmissions, 1); 456 | 457 | assert_eq!(ctrl.window_size_bytes, packet_one_size_bytes); 458 | 459 | // Register the initial transmission of a packet with sequence number 2. 
460 | seq_num = 2; 461 | let packet_two_size_bytes = 128; 462 | let transmission = Transmit::Initial { 463 | bytes: packet_two_size_bytes, 464 | }; 465 | ctrl.on_transmit(seq_num, transmission) 466 | .expect("transmission registration failed"); 467 | 468 | let transmission_record = ctrl 469 | .transmissions 470 | .get(&seq_num) 471 | .expect("transmission not recorded"); 472 | assert_eq!(transmission_record.size_bytes, packet_two_size_bytes); 473 | assert_eq!(transmission_record.num_transmissions, 1); 474 | assert_eq!( 475 | ctrl.window_size_bytes, 476 | packet_one_size_bytes + packet_two_size_bytes, 477 | ); 478 | 479 | // Register the retransmission of the packet with sequence number 2. 480 | ctrl.on_transmit(seq_num, Transmit::Retransmission) 481 | .expect("transmission registration failed"); 482 | 483 | let transmission_record = ctrl 484 | .transmissions 485 | .get(&seq_num) 486 | .expect("transmission not recorded"); 487 | assert_eq!(transmission_record.size_bytes, packet_two_size_bytes); 488 | assert_eq!(transmission_record.num_transmissions, 2); 489 | assert_eq!( 490 | ctrl.window_size_bytes, 491 | packet_one_size_bytes + packet_two_size_bytes, 492 | ); 493 | 494 | assert_eq!(ctrl.timeout(), initial_timeout); 495 | } 496 | 497 | #[test] 498 | fn on_transmit_duplicate_transmission() { 499 | let mut ctrl = Controller::new(Config::default()); 500 | 501 | // Register the initial transmission of a packet with sequence number 1. 502 | let seq_num = 1; 503 | let bytes = 32; 504 | let transmission = Transmit::Initial { bytes }; 505 | ctrl.on_transmit(seq_num, transmission) 506 | .expect("transmission registration failed"); 507 | 508 | assert_eq!(ctrl.window_size_bytes, bytes); 509 | 510 | // Register the initial transmission of the SAME packet. 
511 | let transmission = Transmit::Initial { bytes }; 512 | let result = ctrl.on_transmit(seq_num, transmission); 513 | assert_eq!(result, Err(Error::DuplicateTransmission)); 514 | 515 | assert_eq!(ctrl.window_size_bytes, bytes); 516 | } 517 | 518 | #[test] 519 | fn on_transmit_unknown_seq_num() { 520 | let mut ctrl = Controller::new(Config::default()); 521 | 522 | // Register the retransmission of the packet with sequence number 1. 523 | let seq_num = 1; 524 | let result = ctrl.on_transmit(seq_num, Transmit::Retransmission); 525 | assert_eq!(result, Err(Error::UnknownSeqNum)); 526 | 527 | assert_eq!(ctrl.window_size_bytes, 0); 528 | } 529 | 530 | #[test] 531 | fn on_transmit_insufficient_window_size() { 532 | let mut ctrl = Controller::new(Config::default()); 533 | 534 | // Register the transmission of a packet with sequence number 1 whose size EXCEEDS the 535 | // maximum window size. 536 | let seq_num = 1; 537 | let bytes = ctrl.max_window_size_bytes + 1; 538 | let result = ctrl.on_transmit(seq_num, Transmit::Initial { bytes }); 539 | assert_eq!(result, Err(Error::InsufficientWindowSize)); 540 | 541 | assert_eq!(ctrl.window_size_bytes, 0); 542 | } 543 | 544 | #[test] 545 | fn on_ack() { 546 | let mut ctrl = Controller::new(Config::default()); 547 | 548 | // Register the initial transmission of a packet with sequence number 1. 549 | let seq_num = 1; 550 | let bytes = 32; 551 | let transmission = Transmit::Initial { bytes }; 552 | ctrl.on_transmit(seq_num, transmission) 553 | .expect("transmission registration failed"); 554 | 555 | // Register the acknowledgement for the packet with sequence number 1. 
556 | let ack_delay = Duration::from_millis(150); 557 | let ack_rtt = Duration::from_millis(300); 558 | let ack_received_at = Instant::now(); 559 | let ack = Ack { 560 | delay: ack_delay, 561 | rtt: ack_rtt, 562 | received_at: ack_received_at, 563 | }; 564 | ctrl.on_ack(seq_num, ack).expect("ack registration failed"); 565 | 566 | assert_eq!( 567 | ctrl.delay_acc 568 | .base_delay() 569 | .expect("delay not pushed into accumulator"), 570 | ack_delay, 571 | ); 572 | 573 | // TODO: max window 574 | 575 | assert_eq!(ctrl.window_size_bytes, 0); 576 | 577 | // TODO: RTT variance 578 | 579 | // TODO: RTT 580 | 581 | assert!(ctrl.timeout() >= ctrl.min_timeout); 582 | } 583 | 584 | #[test] 585 | fn on_ack_unknown_seq_num() { 586 | let mut ctrl = Controller::new(Config::default()); 587 | 588 | // Register the acknowledgement for the packet with sequence number 1. 589 | let seq_num = 1; 590 | let ack_delay = Duration::from_millis(150); 591 | let ack_rtt = Duration::from_millis(300); 592 | let ack_received_at = Instant::now(); 593 | let ack = Ack { 594 | delay: ack_delay, 595 | rtt: ack_rtt, 596 | received_at: ack_received_at, 597 | }; 598 | let result = ctrl.on_ack(seq_num, ack); 599 | assert_eq!(result, Err(Error::UnknownSeqNum)); 600 | } 601 | 602 | #[test] 603 | fn on_lost_packet_retransmitting() { 604 | let mut ctrl = Controller::new(Config::default()); 605 | 606 | let initial_max_window_size_bytes = ctrl.min_window_size_bytes * 10; 607 | ctrl.max_window_size_bytes = initial_max_window_size_bytes; 608 | 609 | // Register the initial transmission of a packet with sequence number 1. 610 | let seq_num = 1; 611 | let bytes = 32; 612 | let transmission = Transmit::Initial { bytes }; 613 | ctrl.on_transmit(seq_num, transmission) 614 | .expect("transmission registration failed"); 615 | 616 | assert_eq!(ctrl.window_size_bytes, bytes); 617 | 618 | // Register the loss of the packet with sequence number 1. Specify that we WILL attempt 619 | // to retransmit. 
620 | ctrl.on_lost_packet(seq_num, true) 621 | .expect("lost packet registration failed"); 622 | assert_eq!(ctrl.window_size_bytes, bytes); 623 | assert!(ctrl.max_window_size_bytes >= ctrl.min_window_size_bytes); 624 | assert_eq!( 625 | ctrl.max_window_size_bytes, 626 | initial_max_window_size_bytes / 2, 627 | ); 628 | } 629 | 630 | #[test] 631 | fn on_lost_packet_not_retransmitting() { 632 | let mut ctrl = Controller::new(Config::default()); 633 | 634 | let initial_max_window_size_bytes = ctrl.min_window_size_bytes * 10; 635 | ctrl.max_window_size_bytes = initial_max_window_size_bytes; 636 | 637 | // Register the initial transmission of a packet with sequence number 1. 638 | let seq_num = 1; 639 | let bytes = 32; 640 | let transmission = Transmit::Initial { bytes }; 641 | ctrl.on_transmit(seq_num, transmission) 642 | .expect("transmission registration failed"); 643 | 644 | assert_eq!(ctrl.window_size_bytes, bytes); 645 | 646 | // Register the loss of the packet with sequence number 1. Specify that we WILL NOT 647 | // attempt to retransmit. 648 | ctrl.on_lost_packet(seq_num, false) 649 | .expect("lost packet registration failed"); 650 | assert_eq!(ctrl.window_size_bytes, 0); 651 | assert!(ctrl.max_window_size_bytes >= ctrl.min_window_size_bytes); 652 | assert_eq!( 653 | ctrl.max_window_size_bytes, 654 | initial_max_window_size_bytes / 2, 655 | ); 656 | } 657 | 658 | #[test] 659 | fn on_lost_packet_unknown_seq_num() { 660 | let mut ctrl = Controller::new(Config::default()); 661 | 662 | let initial_max_window_size_bytes = ctrl.min_window_size_bytes * 10; 663 | ctrl.max_window_size_bytes = initial_max_window_size_bytes; 664 | 665 | // Register the loss of the packet with sequence number 1. 
666 | let seq_num = 1; 667 | let result = ctrl.on_lost_packet(seq_num, false); 668 | assert_eq!(result, Err(Error::UnknownSeqNum)); 669 | assert_eq!(ctrl.window_size_bytes, 0); 670 | assert_eq!(ctrl.max_window_size_bytes, initial_max_window_size_bytes); 671 | } 672 | 673 | #[test] 674 | fn on_timeout() { 675 | let mut ctrl = Controller::new(Config::default()); 676 | 677 | let initial_max_window_size_bytes = ctrl.min_window_size_bytes * 10; 678 | ctrl.max_window_size_bytes = initial_max_window_size_bytes; 679 | 680 | let initial_timeout = ctrl.timeout(); 681 | 682 | // Register a timeout. 683 | ctrl.on_timeout(); 684 | assert_eq!(ctrl.max_window_size_bytes, ctrl.min_window_size_bytes); 685 | assert_eq!(ctrl.timeout, initial_timeout * 2); 686 | } 687 | 688 | #[test] 689 | fn on_timeout_not_exceed_max() { 690 | const MAX_TIMEOUT: Duration = Duration::from_secs(3); 691 | let config = Config { 692 | initial_timeout: Duration::from_secs(2), 693 | max_timeout: MAX_TIMEOUT, 694 | ..Default::default() 695 | }; 696 | 697 | let mut ctrl = Controller::new(config); 698 | 699 | // Register a timeout. 
700 | ctrl.on_timeout(); 701 | assert_eq!(ctrl.timeout, MAX_TIMEOUT); 702 | } 703 | } 704 | 705 | mod delay_accumulator { 706 | use super::*; 707 | 708 | #[test] 709 | fn push() { 710 | let window = Duration::from_millis(100); 711 | let mut acc = DelayAccumulator::new(window); 712 | 713 | let delay = Duration::from_millis(50); 714 | let delay_received_at = Instant::now(); 715 | acc.push(delay, delay_received_at); 716 | 717 | let item = acc 718 | .delays 719 | .peek() 720 | .expect("delay not pushed onto accumulator"); 721 | assert_eq!(item.0.value, delay); 722 | assert_eq!(item.0.deadline, delay_received_at + window); 723 | } 724 | 725 | #[test] 726 | fn base_delay() { 727 | let window = Duration::from_millis(100); 728 | let mut acc = DelayAccumulator::new(window); 729 | 730 | let delay_small = Duration::from_millis(50); 731 | let delay_small_received_at = Instant::now(); 732 | acc.push(delay_small, delay_small_received_at); 733 | 734 | let delay_smaller = Duration::from_millis(25); 735 | let delay_smaller_received_at = Instant::now(); 736 | acc.push(delay_smaller, delay_smaller_received_at); 737 | 738 | let delay_smallest = Duration::from_millis(5); 739 | let delay_smallest_received_at = Instant::now(); 740 | acc.push(delay_smallest, delay_smallest_received_at); 741 | 742 | let delay_expired = Duration::from_millis(1); 743 | let delay_expired_received_at = Instant::now() - window; 744 | acc.push(delay_expired, delay_expired_received_at); 745 | 746 | // Check that all delays are present within the accumulator. 747 | assert_eq!(acc.delays.len(), 4); 748 | 749 | let base_delay = acc 750 | .base_delay() 751 | .expect("base delay not present in accumulator"); 752 | assert_eq!(base_delay, delay_smallest); 753 | 754 | // Check that the expired delay was popped. 
755 | assert_eq!(acc.delays.len(), 3); 756 | } 757 | 758 | #[test] 759 | fn base_delay_empty() { 760 | let window = Duration::from_millis(100); 761 | let mut acc = DelayAccumulator::new(window); 762 | 763 | let base_delay = acc.base_delay(); 764 | assert!(base_delay.is_none()); 765 | } 766 | } 767 | } 768 | -------------------------------------------------------------------------------- /src/event.rs: -------------------------------------------------------------------------------- 1 | use crate::cid::ConnectionId; 2 | use crate::packet::Packet; 3 | use crate::peer::{ConnectionPeer, Peer}; 4 | 5 | #[derive(Clone, Debug)] 6 | pub enum StreamEvent { 7 | Incoming(Packet), 8 | Shutdown, 9 | } 10 | 11 | #[derive(Clone, Debug)] 12 | pub enum SocketEvent { 13 | Outgoing((Packet, Peer

)), 14 | Shutdown(ConnectionId), 15 | } 16 | -------------------------------------------------------------------------------- /src/lib.rs: -------------------------------------------------------------------------------- 1 | pub mod cid; 2 | pub mod congestion; 3 | pub mod conn; 4 | pub mod event; 5 | pub mod packet; 6 | pub mod peer; 7 | pub mod recv; 8 | pub mod send; 9 | pub mod sent; 10 | pub mod seq; 11 | pub mod socket; 12 | pub mod stream; 13 | pub mod testutils; 14 | pub mod time; 15 | pub mod udp; 16 | -------------------------------------------------------------------------------- /src/packet.rs: -------------------------------------------------------------------------------- 1 | use std::fmt; 2 | use std::fmt::{Debug, Formatter}; 3 | 4 | /// Size of an encoded uTP header in bytes. 5 | const PACKET_HEADER_LEN: usize = 20; 6 | 7 | /// Size of a Selective ACK segment in bits. 8 | const SELECTIVE_ACK_BITS: usize = 32; 9 | 10 | /// Size of an extension identifier in bytes. 11 | const EXTENSION_TYPE_LEN: usize = 1; 12 | 13 | /// Size of an extension length specifier in bytes. 
14 | const EXTENSION_LEN_LEN: usize = 1; 15 | 16 | #[derive(Copy, Clone, Debug)] 17 | pub struct InvalidPacketType; 18 | 19 | impl fmt::Display for InvalidPacketType { 20 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 21 | write!(f, "invalid uTP packet type") 22 | } 23 | } 24 | 25 | #[derive(Copy, Clone, Debug)] 26 | pub struct InvalidVersion; 27 | 28 | impl fmt::Display for InvalidVersion { 29 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 30 | write!(f, "invalid uTP version") 31 | } 32 | } 33 | 34 | #[derive(Copy, Clone, Debug)] 35 | pub enum SelectiveAckError { 36 | InsufficientLen, 37 | InvalidLen, 38 | } 39 | 40 | impl From for PacketError { 41 | fn from(value: SelectiveAckError) -> Self { 42 | Self::InvalidExtension(ExtensionError::from(value)) 43 | } 44 | } 45 | 46 | impl fmt::Display for SelectiveAckError { 47 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 48 | let s = match self { 49 | Self::InsufficientLen => "selective ACK len must be at least 32 bits", 50 | Self::InvalidLen => "selective ACK len must be a multiple of 32 bits", 51 | }; 52 | 53 | write!(f, "{}", s) 54 | } 55 | } 56 | 57 | #[derive(Clone, Debug)] 58 | pub enum ExtensionError { 59 | InsufficientLen, 60 | InvalidSelectiveAck(SelectiveAckError), 61 | } 62 | 63 | impl From for ExtensionError { 64 | fn from(value: SelectiveAckError) -> Self { 65 | Self::InvalidSelectiveAck(value) 66 | } 67 | } 68 | 69 | impl fmt::Display for ExtensionError { 70 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 71 | let s: String = match self { 72 | Self::InsufficientLen => String::from("insufficient extension len"), 73 | Self::InvalidSelectiveAck(err) => err.to_string(), 74 | }; 75 | 76 | write!(f, "{}", s) 77 | } 78 | } 79 | 80 | #[derive(Clone, Debug)] 81 | pub enum PacketHeaderError { 82 | InvalidPacketType(InvalidPacketType), 83 | InvalidVersion(InvalidVersion), 84 | InvalidExtension(ExtensionError), 85 | InvalidLen, 86 | } 87 | 88 | impl From for 
PacketHeaderError { 89 | fn from(value: InvalidPacketType) -> Self { 90 | Self::InvalidPacketType(value) 91 | } 92 | } 93 | 94 | impl From for PacketHeaderError { 95 | fn from(value: InvalidVersion) -> Self { 96 | Self::InvalidVersion(value) 97 | } 98 | } 99 | 100 | impl From for PacketHeaderError { 101 | fn from(value: ExtensionError) -> Self { 102 | Self::InvalidExtension(value) 103 | } 104 | } 105 | 106 | #[derive(Clone, Debug)] 107 | pub enum PacketError { 108 | InvalidHeader(PacketHeaderError), 109 | InvalidExtension(ExtensionError), 110 | InvalidLen, 111 | EmptyDataPayload, 112 | } 113 | 114 | impl From for PacketError { 115 | fn from(value: PacketHeaderError) -> Self { 116 | Self::InvalidHeader(value) 117 | } 118 | } 119 | 120 | impl From for PacketError { 121 | fn from(value: ExtensionError) -> Self { 122 | Self::InvalidExtension(value) 123 | } 124 | } 125 | 126 | #[derive(Copy, Clone, Debug, Eq, PartialEq)] 127 | pub enum PacketType { 128 | Data, 129 | Fin, 130 | State, 131 | Reset, 132 | Syn, 133 | } 134 | 135 | impl fmt::Display for PacketType { 136 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 137 | let s = match self { 138 | Self::Data => "ST_DATA".to_string(), 139 | Self::Fin => "ST_FIN".to_string(), 140 | Self::State => "ST_STATE".to_string(), 141 | Self::Reset => "ST_RESET".to_string(), 142 | Self::Syn => "ST_SYN".to_string(), 143 | }; 144 | 145 | write!(f, "{s}") 146 | } 147 | } 148 | 149 | impl TryFrom for PacketType { 150 | type Error = InvalidPacketType; 151 | 152 | fn try_from(value: u8) -> Result { 153 | match value { 154 | 0 => Ok(Self::Data), 155 | 1 => Ok(Self::Fin), 156 | 2 => Ok(Self::State), 157 | 3 => Ok(Self::Reset), 158 | 4 => Ok(Self::Syn), 159 | _ => Err(InvalidPacketType), 160 | } 161 | } 162 | } 163 | 164 | impl From for u8 { 165 | fn from(value: PacketType) -> u8 { 166 | match value { 167 | PacketType::Data => 0, 168 | PacketType::Fin => 1, 169 | PacketType::State => 2, 170 | PacketType::Reset => 3, 171 | 
PacketType::Syn => 4, 172 | } 173 | } 174 | } 175 | 176 | #[derive(Copy, Clone, Debug, Eq, PartialEq)] 177 | enum Version { 178 | One, 179 | } 180 | 181 | impl TryFrom for Version { 182 | type Error = InvalidVersion; 183 | 184 | fn try_from(value: u8) -> Result { 185 | match value { 186 | 1 => Ok(Self::One), 187 | _ => Err(InvalidVersion), 188 | } 189 | } 190 | } 191 | 192 | impl From for u8 { 193 | fn from(value: Version) -> u8 { 194 | match value { 195 | Version::One => 1, 196 | } 197 | } 198 | } 199 | 200 | #[derive(Copy, Clone, Debug, Eq, PartialEq)] 201 | enum Extension { 202 | None, 203 | SelectiveAck, 204 | Unknown(u8), 205 | } 206 | 207 | impl From for Extension { 208 | fn from(value: u8) -> Self { 209 | match value { 210 | 0 => Self::None, 211 | 1 => Self::SelectiveAck, 212 | unknown => Self::Unknown(unknown), 213 | } 214 | } 215 | } 216 | 217 | impl From for u8 { 218 | fn from(value: Extension) -> u8 { 219 | match value { 220 | Extension::None => 0, 221 | Extension::SelectiveAck => 1, 222 | Extension::Unknown(ext) => ext, 223 | } 224 | } 225 | } 226 | 227 | #[derive(Clone, Debug, Eq, PartialEq)] 228 | struct PacketHeader { 229 | packet_type: PacketType, 230 | version: Version, 231 | extension: Extension, 232 | conn_id: u16, 233 | ts_micros: u32, 234 | ts_diff_micros: u32, 235 | window_size: u32, 236 | seq_num: u16, 237 | ack_num: u16, 238 | } 239 | 240 | impl PacketHeader { 241 | pub fn encode(&self) -> Vec { 242 | let mut bytes = vec![]; 243 | 244 | let packet_type = Into::::into(self.packet_type).to_be_bytes()[0]; 245 | let version = Into::::into(self.version).to_be_bytes()[0]; 246 | let type_version = (packet_type << 4) | version; 247 | bytes.push(type_version); 248 | 249 | let extension = u8::from(self.extension); 250 | bytes.push(extension); 251 | 252 | bytes.extend(self.conn_id.to_be_bytes()); 253 | bytes.extend(self.ts_micros.to_be_bytes()); 254 | bytes.extend(self.ts_diff_micros.to_be_bytes()); 255 | bytes.extend(self.window_size.to_be_bytes()); 
256 | bytes.extend(self.seq_num.to_be_bytes()); 257 | bytes.extend(self.ack_num.to_be_bytes()); 258 | 259 | bytes 260 | } 261 | 262 | pub fn decode(value: &[u8]) -> Result { 263 | if value.len() != PACKET_HEADER_LEN { 264 | return Err(PacketHeaderError::InvalidLen); 265 | } 266 | 267 | let packet_type = value[0] >> 4; 268 | let packet_type = PacketType::try_from(u8::from_be(packet_type))?; 269 | 270 | let version = value[0] & 0x0F; 271 | let version = Version::try_from(u8::from_be(version))?; 272 | 273 | let extension = u8::from_be(value[1]); 274 | let extension = Extension::from(extension); 275 | 276 | let conn_id = [value[2], value[3]]; 277 | let conn_id = u16::from_be_bytes(conn_id); 278 | 279 | let ts_micros = [value[4], value[5], value[6], value[7]]; 280 | let ts_micros = u32::from_be_bytes(ts_micros); 281 | 282 | let ts_diff_micros = [value[8], value[9], value[10], value[11]]; 283 | let ts_diff_micros = u32::from_be_bytes(ts_diff_micros); 284 | 285 | let window_size = [value[12], value[13], value[14], value[15]]; 286 | let window_size = u32::from_be_bytes(window_size); 287 | 288 | let seq_num = [value[16], value[17]]; 289 | let seq_num = u16::from_be_bytes(seq_num); 290 | 291 | let ack_num = [value[18], value[19]]; 292 | let ack_num = u16::from_be_bytes(ack_num); 293 | 294 | Ok(Self { 295 | packet_type, 296 | version, 297 | extension, 298 | conn_id, 299 | ts_micros, 300 | ts_diff_micros, 301 | window_size, 302 | seq_num, 303 | ack_num, 304 | }) 305 | } 306 | } 307 | 308 | #[derive(Clone, Debug, Eq, PartialEq)] 309 | pub struct SelectiveAck { 310 | acked: Vec<[bool; SELECTIVE_ACK_BITS]>, 311 | } 312 | 313 | impl fmt::Display for SelectiveAck { 314 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 315 | let mut s = String::new(); 316 | for chunk in &self.acked { 317 | for bit in chunk { 318 | if *bit { 319 | s.push('1'); 320 | } else { 321 | s.push('0'); 322 | } 323 | } 324 | } 325 | write!(f, "{}", s) 326 | } 327 | } 328 | 329 | impl SelectiveAck { 
330 | pub fn new(acked: Vec) -> Self { 331 | let chunks = acked.as_slice().chunks_exact(SELECTIVE_ACK_BITS); 332 | let remainder = chunks.remainder(); 333 | 334 | let mut acked = Vec::new(); 335 | for chunk in chunks { 336 | let mut fragment: [bool; SELECTIVE_ACK_BITS] = [false; SELECTIVE_ACK_BITS]; 337 | fragment.copy_from_slice(chunk); 338 | acked.push(fragment); 339 | } 340 | 341 | if !remainder.is_empty() { 342 | let mut fragment: [bool; SELECTIVE_ACK_BITS] = [false; SELECTIVE_ACK_BITS]; 343 | fragment[..remainder.len()].copy_from_slice(remainder); 344 | acked.push(fragment); 345 | } 346 | 347 | Self { acked } 348 | } 349 | 350 | /// Returns the length in bytes of the encoded Selective ACK. 351 | pub fn encoded_len(&self) -> usize { 352 | (self.acked.len() * SELECTIVE_ACK_BITS) / 8 353 | } 354 | 355 | pub fn acked(&self) -> Vec { 356 | self.acked 357 | .clone() 358 | .into_iter() 359 | .flatten() 360 | .collect::>() 361 | } 362 | 363 | pub fn encode(&self) -> Vec { 364 | let mut bitmask = vec![]; 365 | 366 | for word in &self.acked { 367 | let chunks = word.as_slice().chunks_exact(8); 368 | 369 | for chunk in chunks { 370 | let mut byte = 0; 371 | 372 | byte |= u8::from(chunk[7]) << 7; 373 | byte |= u8::from(chunk[6]) << 6; 374 | byte |= u8::from(chunk[5]) << 5; 375 | byte |= u8::from(chunk[4]) << 4; 376 | byte |= u8::from(chunk[3]) << 3; 377 | byte |= u8::from(chunk[2]) << 2; 378 | byte |= u8::from(chunk[1]) << 1; 379 | byte |= u8::from(chunk[0]); 380 | 381 | bitmask.push(byte); 382 | } 383 | } 384 | 385 | bitmask 386 | } 387 | 388 | pub fn decode(value: &[u8]) -> Result { 389 | if value.len() < 4 { 390 | return Err(SelectiveAckError::InsufficientLen); 391 | } 392 | if value.len() % 4 != 0 { 393 | return Err(SelectiveAckError::InvalidLen); 394 | } 395 | 396 | let mut acked: Vec<[bool; 32]> = vec![]; 397 | let mut tmp = [false; 32]; 398 | for (index, byte) in value.iter().enumerate() { 399 | tmp[(index * 8) % 32] = (*byte & 0b0000_0001) != 0; 400 | tmp[(index * 
8 + 1) % 32] = (*byte & 0b0000_0010) != 0; 401 | tmp[(index * 8 + 2) % 32] = (*byte & 0b0000_0100) != 0; 402 | tmp[(index * 8 + 3) % 32] = (*byte & 0b0000_1000) != 0; 403 | tmp[(index * 8 + 4) % 32] = (*byte & 0b0001_0000) != 0; 404 | tmp[(index * 8 + 5) % 32] = (*byte & 0b0010_0000) != 0; 405 | tmp[(index * 8 + 6) % 32] = (*byte & 0b0100_0000) != 0; 406 | tmp[(index * 8 + 7) % 32] = (*byte & 0b1000_0000) != 0; 407 | 408 | if (index + 1) % 4 == 0 { 409 | acked.push(tmp); 410 | tmp = [false; 32]; 411 | } 412 | } 413 | 414 | if value.len() % 4 != 0 { 415 | acked.push(tmp); 416 | } 417 | 418 | Ok(Self { acked }) 419 | } 420 | } 421 | 422 | #[derive(Clone, Eq, PartialEq)] 423 | pub struct Packet { 424 | header: PacketHeader, 425 | selective_ack: Option, 426 | payload: Vec, 427 | } 428 | 429 | impl Packet { 430 | pub fn packet_type(&self) -> PacketType { 431 | self.header.packet_type 432 | } 433 | 434 | pub fn conn_id(&self) -> u16 { 435 | self.header.conn_id 436 | } 437 | 438 | pub fn ts_micros(&self) -> u32 { 439 | self.header.ts_micros 440 | } 441 | 442 | pub fn ts_diff_micros(&self) -> u32 { 443 | self.header.ts_diff_micros 444 | } 445 | 446 | pub fn window_size(&self) -> u32 { 447 | self.header.window_size 448 | } 449 | 450 | pub fn seq_num(&self) -> u16 { 451 | self.header.seq_num 452 | } 453 | 454 | pub fn ack_num(&self) -> u16 { 455 | self.header.ack_num 456 | } 457 | 458 | pub fn selective_ack(&self) -> Option<&SelectiveAck> { 459 | self.selective_ack.as_ref() 460 | } 461 | 462 | pub fn payload(&self) -> &Vec { 463 | &self.payload 464 | } 465 | 466 | /// Returns the length in bytes of the encoded packet. 
467 | pub fn encoded_len(&self) -> usize { 468 | let mut len = PACKET_HEADER_LEN; 469 | if let Some(ref sack) = self.selective_ack { 470 | len += sack.encoded_len() + EXTENSION_TYPE_LEN + EXTENSION_LEN_LEN; 471 | } 472 | len += self.payload.len(); 473 | 474 | len 475 | } 476 | 477 | pub fn encode(&self) -> Vec { 478 | let mut bytes = vec![]; 479 | 480 | bytes.extend(self.header.encode()); 481 | if let Some(ack) = &self.selective_ack { 482 | let ack = ack.encode(); 483 | bytes.push(Extension::None.into()); 484 | bytes.push((ack.len() as u8).to_be_bytes()[0]); 485 | bytes.extend(ack); 486 | } 487 | bytes.extend_from_slice(self.payload.as_slice()); 488 | 489 | bytes 490 | } 491 | 492 | pub fn decode(value: &[u8]) -> Result { 493 | if value.len() < PACKET_HEADER_LEN { 494 | return Err(PacketError::InvalidHeader(PacketHeaderError::InvalidLen)); 495 | } 496 | 497 | let mut header: [u8; PACKET_HEADER_LEN] = [0; PACKET_HEADER_LEN]; 498 | header.copy_from_slice(&value[..PACKET_HEADER_LEN]); 499 | let header = PacketHeader::decode(&header)?; 500 | 501 | let (extensions, extensions_len) = 502 | Self::decode_raw_extensions(header.extension, &value[PACKET_HEADER_LEN..])?; 503 | 504 | // Look for the first (if any) Selective ACK extension, and attempt to decode it. 505 | // TODO: Evaluate whether duplicate extensions should constitute an error. 506 | let selective_ack = extensions 507 | .iter() 508 | .find(|(ext, _)| *ext == Extension::SelectiveAck); 509 | let selective_ack = match selective_ack { 510 | Some((_, data)) => Some(SelectiveAck::decode(data)?), 511 | None => None, 512 | }; 513 | 514 | // TODO: Save all raw extensions and make them accessible. People should be able to use 515 | // custom extensions. 516 | 517 | // The packet payload is the remainder of the packet. 
518 | let payload_start_index = PACKET_HEADER_LEN + extensions_len; 519 | let payload = if payload_start_index == value.len() { 520 | vec![] 521 | } else { 522 | value[payload_start_index..].to_vec() 523 | }; 524 | 525 | if header.packet_type == PacketType::Data && payload.is_empty() { 526 | return Err(PacketError::EmptyDataPayload); 527 | } 528 | 529 | Ok(Self { 530 | header, 531 | selective_ack, 532 | payload, 533 | }) 534 | } 535 | 536 | // TODO: Resolve disabled clippy lint. 537 | #[allow(clippy::type_complexity)] 538 | fn decode_raw_extensions( 539 | first_ext: Extension, 540 | data: &[u8], 541 | ) -> Result<(Vec<(Extension, Vec)>, usize), ExtensionError> { 542 | let mut ext = first_ext; 543 | let mut index = 0; 544 | 545 | let mut extensions: Vec<(Extension, Vec)> = Vec::new(); 546 | while ext != Extension::None { 547 | if data[index..].len() < 2 { 548 | return Err(ExtensionError::InsufficientLen); 549 | } 550 | 551 | let next_ext = u8::from_be_bytes([data[index]]); 552 | 553 | let ext_len = u8::from_be_bytes([data[index + 1]]); 554 | let ext_len = usize::from(ext_len); 555 | 556 | let ext_start = index + 2; 557 | if data[ext_start..].len() < ext_len { 558 | return Err(ExtensionError::InsufficientLen); 559 | } 560 | 561 | let ext_data = data[ext_start..ext_start + ext_len].to_vec(); 562 | extensions.push((ext, ext_data)); 563 | 564 | ext = Extension::from(next_ext); 565 | index = ext_start + ext_len; 566 | } 567 | 568 | Ok((extensions, index)) 569 | } 570 | } 571 | 572 | impl Debug for Packet { 573 | fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { 574 | write!(f, "packet cid={} packetType={} seqNr={} ackNr={} timestamp={} timestampDiff={} remoteWindow={}", 575 | self.conn_id(), 576 | self.packet_type(), 577 | self.seq_num(), 578 | self.ack_num(), 579 | self.ts_micros(), 580 | self.ts_diff_micros(), 581 | self.window_size(), 582 | ) 583 | } 584 | } 585 | 586 | #[derive(Clone, Debug)] 587 | pub struct PacketBuilder { 588 | packet_type: PacketType, 589 | 
conn_id: u16, 590 | ts_micros: u32, 591 | ts_diff_micros: u32, 592 | window_size: u32, 593 | seq_num: u16, 594 | ack_num: u16, 595 | selective_ack: Option, 596 | payload: Option>, 597 | } 598 | 599 | impl PacketBuilder { 600 | pub fn new( 601 | packet_type: PacketType, 602 | conn_id: u16, 603 | ts_micros: u32, 604 | window_size: u32, 605 | seq_num: u16, 606 | ) -> Self { 607 | Self { 608 | packet_type, 609 | conn_id, 610 | ts_micros, 611 | ts_diff_micros: 0, 612 | window_size, 613 | seq_num, 614 | ack_num: 0, 615 | selective_ack: None, 616 | payload: None, 617 | } 618 | } 619 | 620 | pub fn ts_micros(mut self, ts_micros: u32) -> Self { 621 | self.ts_micros = ts_micros; 622 | self 623 | } 624 | 625 | pub fn ts_diff_micros(mut self, ts_diff_micros: u32) -> Self { 626 | self.ts_diff_micros = ts_diff_micros; 627 | self 628 | } 629 | 630 | pub fn window_size(mut self, window_size: u32) -> Self { 631 | self.window_size = window_size; 632 | self 633 | } 634 | 635 | pub fn ack_num(mut self, ack_num: u16) -> Self { 636 | self.ack_num = ack_num; 637 | self 638 | } 639 | 640 | pub fn selective_ack(mut self, selective_ack: Option) -> Self { 641 | self.selective_ack = selective_ack; 642 | self 643 | } 644 | 645 | pub fn payload(mut self, payload: Vec) -> Self { 646 | self.payload = Some(payload); 647 | self 648 | } 649 | 650 | pub fn build(self) -> Packet { 651 | let extension = match self.selective_ack { 652 | Some(..) 
=> Extension::SelectiveAck, 653 | None => Extension::None, 654 | }; 655 | 656 | Packet { 657 | header: PacketHeader { 658 | packet_type: self.packet_type, 659 | version: Version::One, 660 | extension, 661 | conn_id: self.conn_id, 662 | ts_micros: self.ts_micros, 663 | ts_diff_micros: self.ts_diff_micros, 664 | window_size: self.window_size, 665 | seq_num: self.seq_num, 666 | ack_num: self.ack_num, 667 | }, 668 | selective_ack: self.selective_ack, 669 | payload: self.payload.unwrap_or_default(), 670 | } 671 | } 672 | } 673 | 674 | impl From for PacketBuilder { 675 | fn from(packet: Packet) -> Self { 676 | let payload = if packet.payload.is_empty() { 677 | None 678 | } else { 679 | Some(packet.payload) 680 | }; 681 | 682 | Self { 683 | packet_type: packet.header.packet_type, 684 | conn_id: packet.header.conn_id, 685 | ts_micros: packet.header.ts_micros, 686 | ts_diff_micros: packet.header.ts_diff_micros, 687 | window_size: packet.header.window_size, 688 | seq_num: packet.header.seq_num, 689 | ack_num: packet.header.ack_num, 690 | selective_ack: packet.selective_ack, 691 | payload, 692 | } 693 | } 694 | } 695 | 696 | #[cfg(test)] 697 | mod tests { 698 | use super::*; 699 | 700 | use quickcheck::{quickcheck, Arbitrary, Gen, TestResult}; 701 | 702 | impl Arbitrary for PacketHeader { 703 | fn arbitrary(g: &mut Gen) -> Self { 704 | let packet_type = u8::arbitrary(g); 705 | let packet_type = if packet_type % 5 == 0 { 706 | PacketType::Data 707 | } else if packet_type % 5 == 1 { 708 | PacketType::Fin 709 | } else if packet_type % 5 == 2 { 710 | PacketType::State 711 | } else if packet_type % 5 == 3 { 712 | PacketType::Reset 713 | } else { 714 | PacketType::Syn 715 | }; 716 | 717 | let extension = u8::arbitrary(g); 718 | let extension = Extension::from(extension); 719 | 720 | Self { 721 | packet_type, 722 | version: Version::One, 723 | extension, 724 | conn_id: u16::arbitrary(g), 725 | ts_micros: u32::arbitrary(g), 726 | ts_diff_micros: u32::arbitrary(g), 727 | window_size: 
u32::arbitrary(g), 728 | ack_num: u16::arbitrary(g), 729 | seq_num: u16::arbitrary(g), 730 | } 731 | } 732 | } 733 | 734 | impl Arbitrary for SelectiveAck { 735 | fn arbitrary(g: &mut Gen) -> Self { 736 | let bits: Vec = Vec::arbitrary(g); 737 | 738 | let mut acked: Vec<[bool; 32]> = vec![]; 739 | 740 | let mut tmp = [false; 32]; 741 | for (index, bit) in bits.iter().enumerate() { 742 | tmp[index % 32] = *bit; 743 | 744 | if (index + 1) % 32 == 0 { 745 | acked.push(tmp); 746 | tmp = [false; 32]; 747 | } 748 | } 749 | 750 | if bits.len() % 32 != 0 || acked.is_empty() { 751 | acked.push(tmp); 752 | } 753 | 754 | Self { acked } 755 | } 756 | } 757 | 758 | // TODO: Add more tests. For example, packet encoding and decoding should test for arbitrary 759 | // extensions. 760 | 761 | #[test] 762 | fn header_encode_decode() { 763 | fn prop(header: PacketHeader) -> TestResult { 764 | let encoded = header.encode(); 765 | let encoded: [u8; 20] = encoded 766 | .try_into() 767 | .expect("invalid length for encoded uTP packet header"); 768 | let decoded = 769 | PacketHeader::decode(&encoded).expect("failed to decode uTP packet header"); 770 | 771 | TestResult::from_bool(decoded == header) 772 | } 773 | quickcheck(prop as fn(PacketHeader) -> TestResult); 774 | } 775 | 776 | #[test] 777 | fn selective_ack_encode_decode() { 778 | fn prop(selective_ack: SelectiveAck) -> TestResult { 779 | let encoded_len = selective_ack.encoded_len(); 780 | 781 | let encoded = selective_ack.encode(); 782 | 783 | assert!(encoded.len() % (SELECTIVE_ACK_BITS / 8) == 0); 784 | assert_eq!(encoded.len(), encoded_len); 785 | 786 | let decoded = SelectiveAck::decode(&encoded).expect("failed to decode Selective ACK"); 787 | 788 | TestResult::from_bool(decoded == selective_ack) 789 | } 790 | quickcheck(prop as fn(SelectiveAck) -> TestResult); 791 | } 792 | 793 | #[test] 794 | fn packet_encode_decode() { 795 | fn prop( 796 | mut header: PacketHeader, 797 | selective_ack: SelectiveAck, 798 | payload: Vec, 799 | 
) -> TestResult { 800 | if payload.is_empty() { 801 | return TestResult::discard(); 802 | } 803 | 804 | let selective_ack = if selective_ack.acked.is_empty() { 805 | None 806 | } else { 807 | Some(selective_ack) 808 | }; 809 | match selective_ack { 810 | Some(..) => { 811 | header.extension = Extension::SelectiveAck; 812 | } 813 | None => { 814 | header.extension = Extension::None; 815 | } 816 | } 817 | 818 | let packet = Packet { 819 | header, 820 | selective_ack, 821 | payload, 822 | }; 823 | 824 | let encoded_len = packet.encoded_len(); 825 | 826 | let encoded = packet.encode(); 827 | 828 | assert_eq!(encoded.len(), encoded_len); 829 | 830 | let decoded = Packet::decode(&encoded).expect("failed to decode uTP packet"); 831 | 832 | TestResult::from_bool(decoded == packet) 833 | } 834 | quickcheck(prop as fn(PacketHeader, SelectiveAck, Vec) -> TestResult); 835 | } 836 | } 837 | -------------------------------------------------------------------------------- /src/peer.rs: -------------------------------------------------------------------------------- 1 | use std::fmt::Debug; 2 | use std::hash::Hash; 3 | use std::net::SocketAddr; 4 | 5 | /// A trait that describes remote peer 6 | pub trait ConnectionPeer: Debug + Clone + Send + Sync { 7 | type Id: Debug + Clone + PartialEq + Eq + Hash + Send + Sync; 8 | 9 | /// Returns peer's id 10 | fn id(&self) -> Self::Id; 11 | 12 | /// Consolidates two peers into one. 13 | /// 14 | /// It's possible that we have two instances that represent the same peer (equal `peer_id`), 15 | /// and we need to consolidate them into one. This can happen when [Peer]-s passed with 16 | /// [UtpSocket::accept_with_cid](crate::socket::UtpSocket::accept_with_cid) or 17 | /// [UtpSocket::connect_with_cid](crate::socket::UtpSocket::connect_with_cid), and returned by 18 | /// [AsyncUdpSocket::recv_from](crate::udp::AsyncUdpSocket::recv_from) contain peers (not just 19 | /// `peer_id`). 
20 | /// 21 | /// The structure implementing this trait can decide on the exact behavior. Some examples: 22 | /// - If structure is simple (i.e. two peers are the same iff all fields are the same), return 23 | /// either (see implementation for `SocketAddr`) 24 | /// - If we can determine which peer is newer (e.g. using timestamp or version field), return 25 | /// newer peer 26 | /// - If structure behaves more like a key-value map whose values don't change over time, 27 | /// merge key-value pairs from both instances into one 28 | /// 29 | /// Should panic if ids are not matching. 30 | fn consolidate(a: Self, b: Self) -> Self; 31 | } 32 | 33 | impl ConnectionPeer for SocketAddr { 34 | type Id = Self; 35 | 36 | fn id(&self) -> Self::Id { 37 | *self 38 | } 39 | 40 | fn consolidate(a: Self, b: Self) -> Self { 41 | assert!(a == b, "Consolidating non-equal peers"); 42 | a 43 | } 44 | } 45 | 46 | /// Structure that stores peer's id, and maybe peer as well. 47 | #[derive(Debug, Clone)] 48 | pub struct Peer { 49 | id: P::Id, 50 | peer: Option

, 51 | } 52 | 53 | impl Peer

{ 54 | /// Creates new instance that stores peer 55 | pub fn new(peer: P) -> Self { 56 | Self { 57 | id: peer.id(), 58 | peer: Some(peer), 59 | } 60 | } 61 | 62 | /// Creates new instance that only stores peer's id 63 | pub fn new_id(peer_id: P::Id) -> Self { 64 | Self { 65 | id: peer_id, 66 | peer: None, 67 | } 68 | } 69 | 70 | /// Returns peer's id 71 | pub fn id(&self) -> &P::Id { 72 | &self.id 73 | } 74 | 75 | /// Returns optional reference to peer 76 | pub fn peer(&self) -> Option<&P> { 77 | self.peer.as_ref() 78 | } 79 | 80 | /// Consolidates given peer into `Self` whilst consuming it. 81 | /// 82 | /// See [ConnectionPeer::consolidate] for details. 83 | /// 84 | /// Panics if ids are not matching. 85 | pub fn consolidate(&mut self, other: Self) { 86 | assert!(self.id == other.id, "Consolidating with non-equal peer"); 87 | let Some(other_peer) = other.peer else { 88 | return; 89 | }; 90 | 91 | self.peer = match self.peer.take() { 92 | Some(peer) => Some(P::consolidate(peer, other_peer)), 93 | None => Some(other_peer), 94 | }; 95 | } 96 | } 97 | -------------------------------------------------------------------------------- /src/recv.rs: -------------------------------------------------------------------------------- 1 | use std::collections::{BTreeMap, HashSet}; 2 | 3 | use crate::packet::SelectiveAck; 4 | use crate::seq::CircularRangeInclusive; 5 | 6 | type Bytes = Vec; 7 | 8 | // https://github.com/ethereum/utp/issues/139 9 | // Maximum number of selective ACKs that can fit within the length of 8 bits 10 | const MAX_SELECTIVE_ACK_COUNT: usize = 32 * 63; 11 | 12 | #[derive(Clone, Debug)] 13 | pub struct ReceiveBuffer { 14 | buf: Box<[u8; N]>, 15 | offset: usize, 16 | pending: BTreeMap, 17 | init_seq_num: u16, 18 | consumed: u16, 19 | } 20 | 21 | impl ReceiveBuffer { 22 | /// Returns a new buffer. 
23 | pub fn new(init_seq_num: u16) -> Self { 24 | Self { 25 | buf: Box::new([0; N]), 26 | offset: 0, 27 | pending: BTreeMap::new(), 28 | init_seq_num, 29 | consumed: 0, 30 | } 31 | } 32 | 33 | /// Returns the number of available bytes remaining in the buffer. 34 | pub fn available(&self) -> usize { 35 | N - self.offset - self.pending.values().fold(0, |acc, data| acc + data.len()) 36 | } 37 | 38 | /// Returns `true` if the buffer is empty. 39 | pub fn is_empty(&self) -> bool { 40 | self.offset == 0 && self.pending.is_empty() 41 | } 42 | 43 | /// Returns the initial sequence number of the buffer. 44 | pub fn init_seq_num(&self) -> u16 { 45 | self.init_seq_num 46 | } 47 | 48 | /// Returns `true` if data was already written into the buffer for `seq_num`. 49 | pub fn was_written(&self, seq_num: u16) -> bool { 50 | let range = CircularRangeInclusive::new( 51 | self.init_seq_num, 52 | self.init_seq_num.wrapping_add(self.consumed), 53 | ); 54 | range.contains(seq_num) || self.pending.contains_key(&seq_num) 55 | } 56 | 57 | /// Reads data from the buffer into `buf`, returning the number of bytes read. 58 | pub fn read(&mut self, buf: &mut [u8]) -> std::io::Result { 59 | if buf.is_empty() { 60 | return Ok(0); 61 | } 62 | 63 | let n = std::cmp::min(buf.len(), self.offset); 64 | buf[..n].copy_from_slice(&self.buf.as_slice()[..n]); 65 | 66 | let remaining = self.offset - n; 67 | self.buf.as_mut_slice().copy_within(n..n + remaining, 0); 68 | self.offset = remaining; 69 | 70 | Ok(n) 71 | } 72 | 73 | /// Writes `data` into the buffer at `seq_num`. 74 | /// 75 | /// # Panics 76 | /// 77 | /// Panics if `data.len()` is greater than the amount of available bytes in the buffer and the 78 | /// data for `seq_num` has not already been written. 79 | pub fn write(&mut self, data: &[u8], seq_num: u16) { 80 | if self.was_written(seq_num) { 81 | return; 82 | } 83 | 84 | // TODO: Return error instead of panicking. 
85 | if data.len() > self.available() { 86 | panic!("insufficient space in buffer"); 87 | } 88 | 89 | // Read all sequential data from pending packets. 90 | self.pending.insert(seq_num, data.to_vec()); 91 | let start = self.init_seq_num.wrapping_add(1); 92 | let mut next = start.wrapping_add(self.consumed); 93 | while let Some(data) = self.pending.remove(&next) { 94 | let end = self.offset + data.len(); 95 | self.buf.as_mut_slice()[self.offset..end].copy_from_slice(&data[..]); 96 | 97 | self.offset = end; 98 | self.consumed += 1; 99 | next = next.wrapping_add(1); 100 | } 101 | } 102 | 103 | /// Returns the last sequence number in a contiguous sequence from the initial sequence number. 104 | pub fn ack_num(&self) -> u16 { 105 | self.init_seq_num.wrapping_add(self.consumed) 106 | } 107 | 108 | /// Returns a selective ACK based on the sequence of data written into the buffer. 109 | pub fn selective_ack(&self) -> Option { 110 | if self.pending.is_empty() { 111 | return None; 112 | } 113 | 114 | // If there are pending packets, then the data for `ack_num + 1` must be missing. 115 | let mut last_ack = self.ack_num().wrapping_add(2); 116 | let mut pending_packets = self.pending.keys().copied().collect::>(); 117 | 118 | let mut acked = vec![]; 119 | while !pending_packets.is_empty() && acked.len() < MAX_SELECTIVE_ACK_COUNT { 120 | if pending_packets.remove(&last_ack) { 121 | acked.push(true); 122 | } else { 123 | acked.push(false); 124 | } 125 | last_ack = last_ack.wrapping_add(1); 126 | } 127 | 128 | Some(SelectiveAck::new(acked)) 129 | } 130 | } 131 | 132 | #[cfg(test)] 133 | mod test { 134 | use super::*; 135 | 136 | const SIZE: usize = 1024; 137 | 138 | #[test] 139 | fn available() { 140 | let init_seq_num = u16::MAX; 141 | let mut buf = ReceiveBuffer::::new(init_seq_num); 142 | 143 | assert_eq!(buf.available(), SIZE); 144 | 145 | const DATA_LEN: usize = 256; 146 | 147 | // Write out-of-order packet. 
148 | let data = vec![0xef; DATA_LEN]; 149 | let seq_num = init_seq_num.wrapping_add(2); 150 | buf.write(&data, seq_num); 151 | assert_eq!(buf.available(), SIZE - DATA_LEN); 152 | 153 | // Write in-order packet. 154 | let data = vec![0xef; DATA_LEN]; 155 | let seq_num = init_seq_num.wrapping_add(1); 156 | buf.write(&data, seq_num); 157 | assert_eq!(buf.available(), SIZE - (DATA_LEN * 2)); 158 | 159 | // Read all data. 160 | let mut data = [0; DATA_LEN * 2]; 161 | buf.read(&mut data).unwrap(); 162 | assert_eq!(buf.available(), SIZE); 163 | } 164 | 165 | #[test] 166 | fn was_written() { 167 | let init_seq_num = u16::MAX; 168 | let mut buf = ReceiveBuffer::::new(init_seq_num); 169 | 170 | let seq_num = init_seq_num.wrapping_add(2); 171 | assert!(!buf.was_written(seq_num)); 172 | 173 | const DATA_LEN: usize = 64; 174 | let data = vec![0xef; DATA_LEN]; 175 | buf.write(&data, seq_num); 176 | assert!(buf.was_written(seq_num)); 177 | } 178 | 179 | #[test] 180 | fn write() { 181 | let init_seq_num = u16::MAX; 182 | let mut buf = ReceiveBuffer::::new(init_seq_num); 183 | 184 | const DATA_LEN: usize = 256; 185 | 186 | // Write out-of-order packet. 187 | let data_second = vec![0xef; DATA_LEN]; 188 | let seq_num = init_seq_num.wrapping_add(2); 189 | buf.write(&data_second, seq_num); 190 | assert_eq!(buf.offset, 0); 191 | assert_eq!(buf.consumed, 0); 192 | assert!(buf.pending.contains_key(&seq_num)); 193 | assert_eq!(*buf.pending.get(&seq_num).unwrap(), data_second); 194 | 195 | // Write in-order packet. 
196 | let data_first = vec![0xfe; DATA_LEN]; 197 | let seq_num = init_seq_num.wrapping_add(1); 198 | buf.write(&data_first, seq_num); 199 | assert_eq!(buf.offset, DATA_LEN * 2); 200 | assert_eq!(buf.consumed, 2); 201 | assert!(buf.pending.is_empty()); 202 | assert_eq!(buf.buf[..DATA_LEN], data_first[..]); 203 | assert_eq!(buf.buf[DATA_LEN..DATA_LEN * 2], data_second[..]); 204 | } 205 | 206 | #[test] 207 | #[should_panic] 208 | fn write_exceeds_available() { 209 | let init_seq_num = u16::MAX; 210 | let mut buf = ReceiveBuffer::::new(init_seq_num); 211 | 212 | let seq_num = init_seq_num.wrapping_add(1); 213 | let data = vec![0xef; SIZE + 1]; 214 | buf.write(&data, seq_num); 215 | } 216 | 217 | #[test] 218 | fn read() { 219 | let init_seq_num = u16::MAX; 220 | let mut buf = ReceiveBuffer::::new(init_seq_num); 221 | 222 | const DATA_LEN: usize = 256; 223 | 224 | // Write out-of-order packet. 225 | let data_second = vec![0xef; DATA_LEN]; 226 | let seq_num = init_seq_num.wrapping_add(2); 227 | buf.write(&data_second, seq_num); 228 | 229 | let mut read_buf = vec![0; SIZE]; 230 | let read = buf.read(&mut read_buf).unwrap(); 231 | assert_eq!(read, 0); 232 | 233 | // Write in-order packet. 234 | let data_first = vec![0xfe; DATA_LEN]; 235 | let seq_num = init_seq_num.wrapping_add(1); 236 | buf.write(&data_first, seq_num); 237 | 238 | let read = buf.read(&mut read_buf).unwrap(); 239 | assert_eq!(read, DATA_LEN * 2); 240 | assert_eq!(buf.offset, 0); 241 | assert_eq!(read_buf[..DATA_LEN], data_first[..]); 242 | assert_eq!(read_buf[DATA_LEN..DATA_LEN * 2], data_second[..]); 243 | 244 | let read = buf.read(&mut read_buf).unwrap(); 245 | assert_eq!(read, 0); 246 | } 247 | 248 | #[test] 249 | fn ack_num() { 250 | let init_seq_num = u16::MAX; 251 | let mut buf = ReceiveBuffer::::new(init_seq_num); 252 | 253 | assert_eq!(buf.ack_num(), init_seq_num); 254 | 255 | const DATA_LEN: usize = 64; 256 | let data = vec![0xef; DATA_LEN]; 257 | 258 | // Write out-of-order packet. 
259 | let second_seq_num = init_seq_num.wrapping_add(2); 260 | buf.write(&data, second_seq_num); 261 | 262 | assert_eq!(buf.ack_num(), init_seq_num); 263 | 264 | // Write in-order packet. 265 | let first_seq_num = init_seq_num.wrapping_add(1); 266 | buf.write(&data, first_seq_num); 267 | 268 | assert_eq!(buf.ack_num(), second_seq_num); 269 | } 270 | 271 | #[test] 272 | fn selective_ack() { 273 | let init_seq_num = u16::MAX; 274 | let mut buf = ReceiveBuffer::::new(init_seq_num); 275 | 276 | let selective_ack = buf.selective_ack(); 277 | assert!(selective_ack.is_none()); 278 | 279 | const DATA_LEN: usize = 64; 280 | let data = vec![0xef; DATA_LEN]; 281 | 282 | // Write out-of-order packet. 283 | let seq_num = init_seq_num.wrapping_add(2); 284 | buf.write(&data, seq_num); 285 | 286 | let selective_ack = buf.selective_ack().unwrap(); 287 | let acked = selective_ack.acked(); 288 | assert!(acked[0]); 289 | for ack in acked[1..].iter() { 290 | assert!(!ack); 291 | } 292 | 293 | // Write in-order packet. 294 | let seq_num = init_seq_num.wrapping_add(1); 295 | buf.write(&data, seq_num); 296 | 297 | let selective_ack = buf.selective_ack(); 298 | assert!(selective_ack.is_none()); 299 | } 300 | 301 | #[test] 302 | fn selective_ack_overflow() { 303 | let init_seq_num = u16::MAX - 2; 304 | let mut buf = ReceiveBuffer::::new(init_seq_num); 305 | 306 | let selective_ack = buf.selective_ack(); 307 | assert!(selective_ack.is_none()); 308 | 309 | const DATA_LEN: usize = 64; 310 | let data = vec![0xef; DATA_LEN]; 311 | 312 | // Write out-of-order packet. 313 | let seq_num = init_seq_num.wrapping_add(2); 314 | buf.write(&data, seq_num); 315 | // Write overflow packet, which is at seq_num 0. 316 | let seq_num = init_seq_num.wrapping_add(3); 317 | buf.write(&data, seq_num); 318 | // Write another out of order but received packet. 319 | let seq_num = init_seq_num.wrapping_add(5); 320 | buf.write(&data, seq_num); 321 | 322 | // Selective ACK should mark received packets as set. 
323 | // Selective ACK begins with ack_num + 2, onwards. 324 | // Hence since we received packets 65535, 0, and 2, we should have 3 packets set, in the respective positions. 325 | let selective_ack = buf.selective_ack().unwrap(); 326 | let mut acked = vec![false; 32]; 327 | acked[0] = true; 328 | acked[1] = true; 329 | acked[3] = true; 330 | assert_eq!(selective_ack.acked(), acked); 331 | } 332 | } 333 | -------------------------------------------------------------------------------- /src/send.rs: -------------------------------------------------------------------------------- 1 | use std::collections::VecDeque; 2 | use std::io; 3 | 4 | type Bytes = Vec; 5 | 6 | #[derive(Clone, Debug)] 7 | pub struct SendBuffer { 8 | pending: VecDeque, 9 | offset: usize, 10 | } 11 | 12 | impl Default for SendBuffer { 13 | fn default() -> Self { 14 | Self { 15 | pending: VecDeque::new(), 16 | offset: 0, 17 | } 18 | } 19 | } 20 | 21 | impl SendBuffer { 22 | /// Creates a new buffer. 23 | pub fn new() -> Self { 24 | Self { 25 | pending: VecDeque::new(), 26 | offset: 0, 27 | } 28 | } 29 | 30 | /// Returns the number of bytes available in the buffer. 31 | pub fn available(&self) -> usize { 32 | N + self.offset - self.pending.iter().fold(0, |acc, x| acc + x.len()) 33 | } 34 | 35 | /// Returns `true` if the buffer is empty. 36 | pub fn is_empty(&self) -> bool { 37 | self.pending.is_empty() 38 | } 39 | 40 | /// Writes `data` into the buffer, returning the number of bytes written. 41 | pub fn write(&mut self, data: &[u8]) -> io::Result { 42 | let available = self.available(); 43 | if data.len() <= available { 44 | self.pending.push_back(data.to_vec()); 45 | Ok(data.len()) 46 | } else { 47 | self.pending.push_back(data[..available].to_vec()); 48 | Ok(available) 49 | } 50 | } 51 | 52 | /// Reads data from the buffer into `buf`, returning the number of bytes read. 53 | /// 54 | /// Data from at most one previous write can be read into `buf`. 
Data from different writes 55 | /// will not go into a single read. 56 | pub fn read(&mut self, buf: &mut [u8]) -> io::Result { 57 | if buf.is_empty() { 58 | return Ok(0); 59 | } 60 | 61 | let mut written = 0; 62 | if let Some(data) = self.pending.front() { 63 | let n = std::cmp::min(data.len() - self.offset, buf.len()); 64 | buf[..n].copy_from_slice(&data[self.offset..self.offset + n]); 65 | 66 | written += n; 67 | if self.offset + n == data.len() { 68 | self.offset = 0; 69 | self.pending.pop_front(); 70 | } else { 71 | self.offset += n; 72 | } 73 | } 74 | 75 | Ok(written) 76 | } 77 | } 78 | 79 | #[cfg(test)] 80 | mod test { 81 | use super::*; 82 | 83 | const SIZE: usize = 8192; 84 | 85 | #[test] 86 | fn available() { 87 | let mut buf = SendBuffer::::new(); 88 | assert_eq!(buf.available(), SIZE); 89 | 90 | const WRITE_LEN: usize = 512; 91 | const NUM_WRITES: usize = 3; 92 | 93 | const READ_LEN: usize = 64; 94 | 95 | for _ in 0..NUM_WRITES { 96 | let data = vec![0; WRITE_LEN]; 97 | buf.write(&data).unwrap(); 98 | } 99 | assert_eq!(buf.available(), SIZE - (WRITE_LEN * NUM_WRITES)); 100 | 101 | let mut data = vec![0; READ_LEN]; 102 | buf.read(&mut data).unwrap(); 103 | assert_eq!(buf.available(), SIZE - (WRITE_LEN * NUM_WRITES) + READ_LEN); 104 | 105 | for _ in 0..NUM_WRITES { 106 | let mut data = vec![0; WRITE_LEN]; 107 | buf.read(&mut data).unwrap(); 108 | } 109 | assert_eq!(buf.available(), SIZE); 110 | } 111 | 112 | #[test] 113 | #[allow(clippy::read_zero_byte_vec)] 114 | fn read() { 115 | let mut buf = SendBuffer::::new(); 116 | 117 | // Read of empty buffer returns zero. 
118 | let mut read_buf = vec![0; SIZE]; 119 | let read = buf.read(&mut read_buf).unwrap(); 120 | assert_eq!(read, 0); 121 | 122 | const WRITE_LEN: usize = 1024; 123 | 124 | const READ_LEN: usize = 784; 125 | 126 | let mut read_buf = vec![0; READ_LEN]; 127 | 128 | let write_one = vec![0xef; WRITE_LEN]; 129 | let write_two = vec![0xfe; WRITE_LEN]; 130 | buf.write(&write_one).unwrap(); 131 | buf.write(&write_two).unwrap(); 132 | 133 | // Read first chunk of first write. 134 | let read = buf.read(&mut read_buf).unwrap(); 135 | assert_eq!(read, READ_LEN); 136 | assert_eq!(read_buf[..READ_LEN], write_one[..READ_LEN]); 137 | 138 | // Read remaining chunk of first write. 139 | let mut read_buf = vec![0; READ_LEN]; 140 | let read = buf.read(&mut read_buf).unwrap(); 141 | assert_eq!(read, WRITE_LEN - READ_LEN); 142 | assert_eq!(read_buf[..WRITE_LEN - READ_LEN], write_one[READ_LEN..]); 143 | 144 | // Read first chunk of second write. 145 | let read = buf.read(&mut read_buf).unwrap(); 146 | assert_eq!(read, READ_LEN); 147 | assert_eq!(read_buf[..READ_LEN], write_two[..READ_LEN]); 148 | 149 | // Read with empty buffer returns zero. 
150 | let mut empty = vec![]; 151 | let read = buf.read(&mut empty).unwrap(); 152 | assert_eq!(read, 0); 153 | } 154 | 155 | #[test] 156 | fn write() { 157 | let mut buf = SendBuffer::::new(); 158 | 159 | const WRITE_LEN: usize = 1024; 160 | 161 | let data = vec![0xef; WRITE_LEN]; 162 | let written = buf.write(data.as_slice()).unwrap(); 163 | assert_eq!(written, WRITE_LEN); 164 | assert_eq!(&buf.pending.pop_front().unwrap(), &data); 165 | } 166 | } 167 | -------------------------------------------------------------------------------- /src/sent.rs: -------------------------------------------------------------------------------- 1 | use std::collections::BTreeSet; 2 | use std::fmt::{self, Formatter}; 3 | use std::time::{Duration, Instant}; 4 | 5 | use crate::congestion; 6 | use crate::packet::{PacketType, SelectiveAck}; 7 | use crate::seq::CircularRangeInclusive; 8 | 9 | const LOSS_THRESHOLD: usize = 3; 10 | 11 | type Bytes = Vec; 12 | 13 | #[derive(Clone, Debug)] 14 | struct SentPacket { 15 | pub seq_num: u16, 16 | pub packet_type: PacketType, 17 | pub data: Option, 18 | pub transmission: Instant, 19 | pub retransmissions: Vec, 20 | pub acks: Vec, 21 | } 22 | 23 | impl SentPacket { 24 | fn rtt(&self, now: Instant) -> Duration { 25 | let last_transmission = self.retransmissions.first().unwrap_or(&self.transmission); 26 | now.duration_since(*last_transmission) 27 | } 28 | } 29 | 30 | #[derive(Clone, Debug)] 31 | pub struct SentPackets { 32 | packets: Vec, 33 | init_seq_num: u16, 34 | lost_packets: BTreeSet, 35 | congestion_ctrl: congestion::Controller, 36 | } 37 | 38 | #[derive(Clone, Copy, Debug, PartialEq, Eq)] 39 | pub enum SentPacketsError { 40 | InvalidAckNum, 41 | } 42 | 43 | impl fmt::Display for SentPacketsError { 44 | fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { 45 | match self { 46 | Self::InvalidAckNum => write!(f, "invalid ack number"), 47 | } 48 | } 49 | } 50 | 51 | impl std::error::Error for SentPacketsError {} 52 | 53 | impl SentPackets { 54 | 
/// Note: `init_seq_num` corresponds to the sequence number just before the sequence number of 55 | /// the first packet to track. 56 | pub fn new(init_seq_num: u16, congestion_ctrl: congestion::Controller) -> Self { 57 | Self { 58 | packets: Vec::new(), 59 | init_seq_num, 60 | lost_packets: BTreeSet::new(), 61 | congestion_ctrl, 62 | } 63 | } 64 | 65 | pub fn next_seq_num(&self) -> u16 { 66 | // Assume cast is okay, meaning that `packets` never contains more than `u16::MAX` 67 | // elements. 68 | self.init_seq_num 69 | .wrapping_add(self.packets.len() as u16) 70 | .wrapping_add(1) 71 | } 72 | 73 | pub fn ack_num(&self) -> u16 { 74 | self.last_ack_num().unwrap_or(0) 75 | } 76 | 77 | pub fn seq_num_range(&self) -> CircularRangeInclusive { 78 | let end = self.next_seq_num().wrapping_sub(1); 79 | CircularRangeInclusive::new(self.init_seq_num, end) 80 | } 81 | 82 | pub fn timeout(&self) -> Duration { 83 | self.congestion_ctrl.timeout() 84 | } 85 | 86 | pub fn on_timeout(&mut self) { 87 | self.congestion_ctrl.on_timeout() 88 | } 89 | 90 | pub fn window(&self) -> u32 { 91 | self.congestion_ctrl.bytes_available_in_window() 92 | } 93 | 94 | pub fn has_unacked_packets(&self) -> bool { 95 | self.first_unacked_seq_num().is_some() 96 | } 97 | 98 | pub fn has_lost_packets(&self) -> bool { 99 | !self.lost_packets.is_empty() 100 | } 101 | 102 | pub fn lost_packets(&self) -> Vec<(u16, PacketType, Option)> { 103 | self.lost_packets 104 | .iter() 105 | .map(|seq| { 106 | let index = self.seq_num_index(*seq); 107 | 108 | // The unwrap is safe because only sent packets may be lost. 109 | let packet = self.packets.get(index).unwrap(); 110 | 111 | (packet.seq_num, packet.packet_type, packet.data.clone()) 112 | }) 113 | .collect() 114 | } 115 | 116 | /// # Panics 117 | /// 118 | /// Panics if `seq_num` does not correspond to the next expected packet or a previously sent 119 | /// packet. 
120 | /// 121 | /// Panics if the transmit is not a retransmission and `len` is greater than the amount of 122 | /// available space in the window. 123 | pub fn on_transmit( 124 | &mut self, 125 | seq_num: u16, 126 | packet_type: PacketType, 127 | data: Option, 128 | len: u32, 129 | now: Instant, 130 | ) { 131 | let index = self.seq_num_index(seq_num); 132 | let is_retransmission = index < self.packets.len(); 133 | 134 | // If the packet sequence number is beyond the next sequence number, then panic. 135 | if index > self.packets.len() { 136 | panic!("out of order transmit"); 137 | } 138 | 139 | // If this is not a retransmission and the length of the packet is greater than the amount 140 | // of available space in the window, then panic. 141 | if !is_retransmission && len > self.window() { 142 | panic!("transmit exceeds available send window"); 143 | } 144 | 145 | match self.packets.get_mut(index) { 146 | Some(sent) => { 147 | sent.retransmissions.push(now); 148 | } 149 | None => { 150 | let sent = SentPacket { 151 | seq_num, 152 | packet_type, 153 | data, 154 | transmission: now, 155 | retransmissions: Vec::new(), 156 | acks: Vec::new(), 157 | }; 158 | self.packets.push(sent); 159 | } 160 | } 161 | 162 | let transmit = if is_retransmission { 163 | congestion::Transmit::Retransmission 164 | } else { 165 | congestion::Transmit::Initial { bytes: len } 166 | }; 167 | 168 | // The unwrap is safe given the check above on the available window. 169 | self.congestion_ctrl.on_transmit(seq_num, transmit).unwrap(); 170 | } 171 | 172 | /// Handle an ACK for a packet with sequence number `ack_num`, and an optional selective_ack. 173 | /// 174 | /// Returns Error, if `ack_num` is not within the sequence number range of the sent packets. 
175 | pub fn on_ack( 176 | &mut self, 177 | ack_num: u16, 178 | selective_ack: Option<&SelectiveAck>, 179 | delay: Duration, 180 | now: Instant, 181 | ) -> Result<(CircularRangeInclusive, Vec), SentPacketsError> { 182 | let range = self.seq_num_range(); 183 | if !range.contains(ack_num) { 184 | return Err(SentPacketsError::InvalidAckNum); 185 | } 186 | 187 | // Do not ACK if ACK num corresponds to initial packet. 188 | if ack_num != range.start() { 189 | self.on_ack_num(ack_num, selective_ack, delay, now); 190 | } 191 | 192 | // Mark all packets up to `ack_num` acknowledged by retaining all packets in 193 | // the range beyond `ack_num`. 194 | let full_acked = CircularRangeInclusive::new(range.start(), ack_num); 195 | 196 | let selected_acks = if let Some(selective_ack) = selective_ack { 197 | selective_ack 198 | .acked() 199 | .iter() 200 | .enumerate() 201 | .filter(|(_, &acked)| acked) 202 | .map(|(i, _)| ack_num.wrapping_add(2).wrapping_add(i as u16)) 203 | .collect() 204 | } else { 205 | Vec::new() 206 | }; 207 | 208 | Ok((full_acked, selected_acks)) 209 | } 210 | 211 | /// # Panics 212 | /// 213 | /// Panics if `ack_num` does not correspond to a previously sent packet. 214 | fn on_ack_num( 215 | &mut self, 216 | ack_num: u16, 217 | selective_ack: Option<&SelectiveAck>, 218 | delay: Duration, 219 | now: Instant, 220 | ) { 221 | if let Some(sack) = selective_ack { 222 | self.on_selective_ack(ack_num, sack, delay, now); 223 | } else { 224 | self.ack(ack_num, delay, now); 225 | } 226 | 227 | // An ACK for `ack_num` implicitly ACKs all sequence numbers that precede `ack_num`. 228 | // Account for any preceding unacked packets. 229 | self.ack_prior_unacked(ack_num, delay, now); 230 | 231 | // Account for (newly) lost packets. 
232 | let lost = self.detect_lost_packets(); 233 | for packet in lost { 234 | if self.lost_packets.insert(packet) { 235 | self.on_lost(packet, true); 236 | } 237 | } 238 | } 239 | 240 | /// # Panics 241 | /// 242 | /// Panics if `ack_num` does not correspond to a previously sent packet. 243 | fn on_selective_ack( 244 | &mut self, 245 | ack_num: u16, 246 | selective_ack: &SelectiveAck, 247 | delay: Duration, 248 | now: Instant, 249 | ) { 250 | self.ack(ack_num, delay, now); 251 | 252 | let range = self.seq_num_range(); 253 | 254 | // The first bit of the selective ACK corresponds to `ack_num.wrapping_add(2)`, where 255 | // `ack_num.wrapping_add(1)` is assumed to have been dropped. 256 | let mut sack_num = ack_num.wrapping_add(2); 257 | for ack in selective_ack.acked() { 258 | // Break once we exhaust all sent sequence numbers. The selective ACK length is a 259 | // multiple of 32, so it may be padded beyond the actual range of sequence numbers. 260 | if !range.contains(sack_num) { 261 | break; 262 | } 263 | 264 | if ack { 265 | self.ack(sack_num, delay, now); 266 | } 267 | 268 | sack_num = sack_num.wrapping_add(1); 269 | } 270 | } 271 | 272 | /// Returns a set containing the sequence numbers of lost packets. 273 | /// 274 | /// A packet is lost if it has not been acknowledged and some threshold number of packets sent 275 | /// after it have been acknowledged. 
276 | fn detect_lost_packets(&self) -> BTreeSet { 277 | let mut acked = 0; 278 | let mut lost = BTreeSet::new(); 279 | 280 | let start = match self.first_unacked_seq_num() { 281 | Some(first_unacked) => self.seq_num_index(first_unacked), 282 | None => return lost, 283 | }; 284 | 285 | for packet in self.packets[start..].iter().rev() { 286 | if packet.acks.is_empty() && acked >= LOSS_THRESHOLD { 287 | lost.insert(packet.seq_num); 288 | } 289 | 290 | if !packet.acks.is_empty() { 291 | acked += 1; 292 | } 293 | } 294 | 295 | lost 296 | } 297 | 298 | /// # Panics 299 | /// 300 | /// Panics if `seq_num` does not correspond to a previously sent packet. 301 | fn ack(&mut self, seq_num: u16, delay: Duration, now: Instant) { 302 | let index = self.seq_num_index(seq_num); 303 | let packet = self.packets.get_mut(index).unwrap(); 304 | 305 | let ack = congestion::Ack { 306 | delay, 307 | rtt: packet.rtt(now), 308 | received_at: now, 309 | }; 310 | self.congestion_ctrl.on_ack(packet.seq_num, ack).unwrap(); 311 | 312 | packet.acks.push(now); 313 | 314 | self.lost_packets.remove(&packet.seq_num); 315 | } 316 | 317 | /// Acknowledges any unacknowledged packets that precede `seq_num`. 318 | fn ack_prior_unacked(&mut self, seq_num: u16, delay: Duration, now: Instant) { 319 | if let Some(first_unacked) = self.first_unacked_seq_num() { 320 | let start = self.seq_num_index(first_unacked); 321 | let end = self.seq_num_index(seq_num); 322 | if start >= end { 323 | return; 324 | } 325 | 326 | let to_ack: Vec = self.packets[start..end].iter().map(|p| p.seq_num).collect(); 327 | for seq_num in to_ack { 328 | self.ack(seq_num, delay, now); 329 | } 330 | } 331 | } 332 | 333 | /// # Panics 334 | /// 335 | /// Panics if `seq_num` does not correspond to a previously sent packet. 
336 | fn on_lost(&mut self, seq_num: u16, retransmitting: bool) { 337 | if !self.seq_num_range().contains(seq_num) { 338 | panic!("cannot mark unsent packet lost"); 339 | } 340 | 341 | // The unwrap is safe assuming that we do not panic above. 342 | self.congestion_ctrl 343 | .on_lost_packet(seq_num, retransmitting) 344 | .expect("lost packet was previously sent"); 345 | } 346 | 347 | /// Returns the "normalized" index for `seq_num` based on the initial sequence number. 348 | fn seq_num_index(&self, seq_num: u16) -> usize { 349 | // The first sequence number is equal to `self.init_seq_num.wrapping_add(1)`. 350 | if seq_num > self.init_seq_num { 351 | usize::from(seq_num - self.init_seq_num - 1) 352 | } else { 353 | usize::from((u16::MAX - self.init_seq_num).wrapping_add(seq_num)) 354 | } 355 | } 356 | 357 | /// Returns the sequence number of the last (i.e. latest) packet in a contiguous sequence of 358 | /// acknowledged packets. 359 | /// 360 | /// Returns `None` if none of the (possibly zero) packets have been acknowledged. 361 | // TODO: Cache this value, (possibly) updating on each ACK. 362 | pub fn last_ack_num(&self) -> Option { 363 | if self.packets.is_empty() { 364 | return None; 365 | } 366 | 367 | let mut num = None; 368 | for packet in &self.packets { 369 | if !packet.acks.is_empty() { 370 | num = Some(packet.seq_num); 371 | } else { 372 | break; 373 | } 374 | } 375 | 376 | num 377 | } 378 | 379 | /// Returns the sequence number of the first (i.e. earliest) packet that has not been 380 | /// acknowledged. 381 | /// 382 | /// Returns `None` if all (possibly zero) sent packets have been acknowledged. 383 | fn first_unacked_seq_num(&self) -> Option { 384 | if self.packets.is_empty() { 385 | return None; 386 | } 387 | 388 | let seq_num = match self.last_ack_num() { 389 | Some(last_ack_num) => { 390 | // If the last ACK num corresponds to the last packet, then return `None`. 
391 | if self.packets.last().unwrap().seq_num == last_ack_num { 392 | return None; 393 | } 394 | last_ack_num.wrapping_add(1) 395 | } 396 | None => self.init_seq_num.wrapping_add(1), 397 | }; 398 | 399 | Some(seq_num) 400 | } 401 | } 402 | 403 | #[cfg(test)] 404 | mod test { 405 | use super::*; 406 | 407 | use quickcheck::{quickcheck, TestResult}; 408 | 409 | const DELAY: Duration = Duration::from_millis(100); 410 | 411 | // TODO: Bolster tests. 412 | 413 | #[test] 414 | fn next_seq_num() { 415 | fn prop(init_seq_num: u16, len: u8) -> TestResult { 416 | let congestion_ctrl = congestion::Controller::new(congestion::Config::default()); 417 | let mut sent_packets = SentPackets::new(init_seq_num, congestion_ctrl); 418 | if len == 0 { 419 | return TestResult::from_bool( 420 | sent_packets.next_seq_num() == init_seq_num.wrapping_add(1), 421 | ); 422 | } 423 | 424 | let final_seq_num = init_seq_num.wrapping_add(u16::from(len)); 425 | let range = CircularRangeInclusive::new(init_seq_num.wrapping_add(1), final_seq_num); 426 | let transmission = Instant::now(); 427 | for seq_num in range { 428 | sent_packets.packets.push(SentPacket { 429 | seq_num, 430 | packet_type: PacketType::Data, 431 | data: None, 432 | transmission, 433 | acks: Default::default(), 434 | retransmissions: Default::default(), 435 | }); 436 | } 437 | 438 | TestResult::from_bool(sent_packets.next_seq_num() == final_seq_num.wrapping_add(1)) 439 | } 440 | quickcheck(prop as fn(u16, u8) -> TestResult) 441 | } 442 | 443 | #[test] 444 | fn on_transmit_initial() { 445 | let init_seq_num = u16::MAX; 446 | let congestion_ctrl = congestion::Controller::new(congestion::Config::default()); 447 | let mut sent_packets = SentPackets::new(init_seq_num, congestion_ctrl); 448 | 449 | let seq_num = sent_packets.next_seq_num(); 450 | let data = vec![0]; 451 | let len = data.len() as u32; 452 | let now = Instant::now(); 453 | sent_packets.on_transmit(seq_num, PacketType::Data, Some(data), len, now); 454 | 455 | 
assert_eq!(sent_packets.packets.len(), 1); 456 | 457 | let packet = &sent_packets.packets[0]; 458 | assert_eq!(packet.seq_num, seq_num); 459 | assert_eq!(packet.transmission, now); 460 | assert!(packet.acks.is_empty()); 461 | assert!(packet.retransmissions.is_empty()); 462 | } 463 | 464 | #[test] 465 | fn on_transmit_retransmit() { 466 | let init_seq_num = u16::MAX; 467 | let congestion_ctrl = congestion::Controller::new(congestion::Config::default()); 468 | let mut sent_packets = SentPackets::new(init_seq_num, congestion_ctrl); 469 | 470 | let seq_num = sent_packets.next_seq_num(); 471 | let data = vec![0]; 472 | let len = data.len() as u32; 473 | let first = Instant::now(); 474 | let second = Instant::now(); 475 | sent_packets.on_transmit(seq_num, PacketType::Data, Some(data.clone()), len, first); 476 | sent_packets.on_transmit(seq_num, PacketType::Data, Some(data), len, second); 477 | 478 | assert_eq!(sent_packets.packets.len(), 1); 479 | 480 | let packet = &sent_packets.packets[0]; 481 | assert_eq!(packet.seq_num, seq_num); 482 | assert_eq!(packet.transmission, first); 483 | assert!(packet.acks.is_empty()); 484 | assert_eq!(packet.retransmissions.len(), 1); 485 | assert_eq!(packet.retransmissions[0], second); 486 | } 487 | 488 | #[test] 489 | #[should_panic] 490 | fn on_transmit_out_of_order() { 491 | let init_seq_num = u16::MAX; 492 | let congestion_ctrl = congestion::Controller::new(congestion::Config::default()); 493 | let mut sent_packets = SentPackets::new(init_seq_num, congestion_ctrl); 494 | 495 | let out_of_order_seq_num = init_seq_num.wrapping_add(2); 496 | let data = vec![0]; 497 | let len = data.len() as u32; 498 | let now = Instant::now(); 499 | 500 | sent_packets.on_transmit(out_of_order_seq_num, PacketType::Data, Some(data), len, now); 501 | } 502 | 503 | #[test] 504 | fn on_selective_ack() { 505 | let init_seq_num = u16::MAX; 506 | let congestion_ctrl = congestion::Controller::new(congestion::Config::default()); 507 | let mut sent_packets = 
SentPackets::new(init_seq_num, congestion_ctrl); 508 | 509 | let data = vec![0]; 510 | let len = data.len() as u32; 511 | 512 | const COUNT: usize = 10; 513 | for _ in 0..COUNT { 514 | let now = Instant::now(); 515 | let seq_num = sent_packets.next_seq_num(); 516 | sent_packets.on_transmit(seq_num, PacketType::Data, Some(data.clone()), len, now); 517 | } 518 | 519 | const SACK_LEN: usize = COUNT - 2; 520 | let mut acked = vec![false; SACK_LEN]; 521 | for (i, ack) in acked.iter_mut().enumerate() { 522 | if i % 2 == 0 { 523 | *ack = true; 524 | } 525 | } 526 | let selective_ack = SelectiveAck::new(acked); 527 | 528 | let now = Instant::now(); 529 | sent_packets 530 | .on_ack( 531 | init_seq_num.wrapping_add(1), 532 | Some(&selective_ack), 533 | DELAY, 534 | now, 535 | ) 536 | .unwrap(); 537 | assert_eq!(sent_packets.packets[0].acks.len(), 1); 538 | assert!(sent_packets.packets[1].acks.is_empty()); 539 | for i in 2..COUNT { 540 | let is_empty = i % 2 != 0; 541 | assert_eq!(sent_packets.packets[i].acks.is_empty(), is_empty); 542 | } 543 | } 544 | 545 | #[test] 546 | fn detect_lost_packets() { 547 | let init_seq_num = u16::MAX; 548 | let congestion_ctrl = congestion::Controller::new(congestion::Config::default()); 549 | let mut sent_packets = SentPackets::new(init_seq_num, congestion_ctrl); 550 | 551 | let data = vec![0]; 552 | let len = data.len() as u32; 553 | 554 | const COUNT: usize = 10; 555 | const START: usize = COUNT - LOSS_THRESHOLD; 556 | for i in 0..COUNT { 557 | let now = Instant::now(); 558 | let seq_num = sent_packets.next_seq_num(); 559 | sent_packets.on_transmit(seq_num, PacketType::Data, Some(data.clone()), len, now); 560 | 561 | if i >= START { 562 | sent_packets.ack(seq_num, DELAY, now); 563 | } 564 | } 565 | 566 | let lost = sent_packets.detect_lost_packets(); 567 | for i in 0..START { 568 | let packet = &sent_packets.packets[i]; 569 | assert!(lost.contains(&packet.seq_num)); 570 | } 571 | } 572 | 573 | #[test] 574 | fn ack() { 575 | let init_seq_num 
= u16::MAX; 576 | let congestion_ctrl = congestion::Controller::new(congestion::Config::default()); 577 | let mut sent_packets = SentPackets::new(init_seq_num, congestion_ctrl); 578 | 579 | let seq_num = sent_packets.next_seq_num(); 580 | let data = vec![0]; 581 | let len = data.len() as u32; 582 | let now = Instant::now(); 583 | sent_packets.on_transmit(seq_num, PacketType::Data, Some(data), len, now); 584 | 585 | // Artificially insert packet into lost packets. 586 | sent_packets.lost_packets.insert(seq_num); 587 | assert!(sent_packets.lost_packets.contains(&seq_num)); 588 | 589 | let now = Instant::now(); 590 | sent_packets.ack(seq_num, DELAY, now); 591 | 592 | let index = sent_packets.seq_num_index(seq_num); 593 | let packet = sent_packets.packets.get(index).unwrap(); 594 | 595 | assert_eq!(packet.acks.len(), 1); 596 | assert_eq!(packet.acks[0], now); 597 | assert!(!sent_packets.lost_packets.contains(&seq_num)); 598 | } 599 | 600 | #[test] 601 | fn ack_prior_unacked() { 602 | let init_seq_num = u16::MAX; 603 | let congestion_ctrl = congestion::Controller::new(congestion::Config::default()); 604 | let mut sent_packets = SentPackets::new(init_seq_num, congestion_ctrl); 605 | 606 | let data = vec![0]; 607 | let len = data.len() as u32; 608 | 609 | const COUNT: usize = 10; 610 | for _ in 0..COUNT { 611 | let now = Instant::now(); 612 | let seq_num = sent_packets.next_seq_num(); 613 | sent_packets.on_transmit(seq_num, PacketType::Data, Some(data.clone()), len, now); 614 | } 615 | 616 | const ACK_NUM: u16 = 3; 617 | assert!(usize::from(ACK_NUM) < COUNT); 618 | assert!(COUNT - usize::from(ACK_NUM) > 2); 619 | 620 | let now = Instant::now(); 621 | sent_packets.ack_prior_unacked(ACK_NUM, DELAY, now); 622 | for i in 0..usize::from(ACK_NUM) { 623 | assert_eq!(sent_packets.packets[i].acks.len(), 1); 624 | } 625 | } 626 | 627 | #[test] 628 | #[should_panic] 629 | fn ack_unsent() { 630 | let init_seq_num = u16::MAX; 631 | let congestion_ctrl = 
congestion::Controller::new(congestion::Config::default()); 632 | let mut sent_packets = SentPackets::new(init_seq_num, congestion_ctrl); 633 | 634 | let unsent_ack_num = init_seq_num.wrapping_add(2); 635 | let now = Instant::now(); 636 | sent_packets.ack(unsent_ack_num, DELAY, now); 637 | } 638 | 639 | #[test] 640 | fn seq_num_index() { 641 | let init_seq_num = u16::MAX; 642 | let congestion_ctrl = congestion::Controller::new(congestion::Config::default()); 643 | let sent_packets = SentPackets::new(init_seq_num, congestion_ctrl); 644 | 645 | assert_eq!( 646 | sent_packets.seq_num_index(init_seq_num), 647 | usize::from(u16::MAX) 648 | ); 649 | 650 | let zero = init_seq_num.wrapping_add(1); 651 | assert_eq!(sent_packets.seq_num_index(zero), 0); 652 | } 653 | } 654 | -------------------------------------------------------------------------------- /src/seq.rs: -------------------------------------------------------------------------------- 1 | /// A range bounded inclusively below and above that supports wrapping arithmetic. 2 | /// 3 | /// If `end < start`, then the range contains all values `x` such that `start <= x <= u16::MAX` and 4 | /// `0 <= x <= end`. 5 | #[derive(Clone, Debug)] 6 | pub struct CircularRangeInclusive { 7 | start: u16, 8 | end: u16, 9 | exhausted: bool, 10 | } 11 | 12 | impl CircularRangeInclusive { 13 | /// Returns a new range. 14 | pub fn new(start: u16, end: u16) -> Self { 15 | Self { 16 | start, 17 | end, 18 | exhausted: false, 19 | } 20 | } 21 | 22 | /// Returns the start of the range (inclusive). 23 | pub fn start(&self) -> u16 { 24 | self.start 25 | } 26 | 27 | /// Returns the end of the range (inclusive). 28 | pub fn end(&self) -> u16 { 29 | self.end 30 | } 31 | 32 | /// Returns `true` if `item` is contained in the range. 
33 | pub fn contains(&self, item: u16) -> bool { 34 | if self.end >= self.start { 35 | item >= self.start && item <= self.end 36 | } else if item >= self.start { 37 | true 38 | } else { 39 | item <= self.end 40 | } 41 | } 42 | } 43 | 44 | impl std::iter::Iterator for CircularRangeInclusive { 45 | type Item = u16; 46 | 47 | fn next(&mut self) -> Option { 48 | if self.exhausted { 49 | None 50 | } else if self.start == self.end { 51 | self.exhausted = true; 52 | Some(self.end) 53 | } else { 54 | let step = self.start.wrapping_add(1); 55 | Some(std::mem::replace(&mut self.start, step)) 56 | } 57 | } 58 | } 59 | 60 | #[cfg(test)] 61 | mod tests { 62 | use super::*; 63 | 64 | use quickcheck::{quickcheck, TestResult}; 65 | 66 | #[test] 67 | fn contains_start() { 68 | fn prop(start: u16, end: u16) -> TestResult { 69 | let range = CircularRangeInclusive::new(start, end); 70 | TestResult::from_bool(range.contains(start)) 71 | } 72 | quickcheck(prop as fn(u16, u16) -> TestResult); 73 | } 74 | 75 | #[test] 76 | fn contains_end() { 77 | fn prop(start: u16, end: u16) -> TestResult { 78 | let range = CircularRangeInclusive::new(start, end); 79 | TestResult::from_bool(range.contains(end)) 80 | } 81 | quickcheck(prop as fn(u16, u16) -> TestResult); 82 | } 83 | 84 | #[test] 85 | fn iterator() { 86 | fn prop(start: u16, end: u16) -> TestResult { 87 | let range = CircularRangeInclusive::new(start, end); 88 | 89 | let mut len: usize = 0; 90 | let mut expected_idx = start; 91 | for idx in range { 92 | assert_eq!(idx, expected_idx); 93 | expected_idx = expected_idx.wrapping_add(1); 94 | len += 1; 95 | } 96 | 97 | let expected_len = if start <= end { 98 | usize::from(end - start) + 1 99 | } else { 100 | usize::from(u16::MAX - start) + usize::from(end) + 2 101 | }; 102 | assert_eq!(len, expected_len); 103 | 104 | TestResult::passed() 105 | } 106 | quickcheck(prop as fn(u16, u16) -> TestResult); 107 | } 108 | 109 | #[test] 110 | fn iterator_single() { 111 | fn prop(x: u16) -> TestResult { 
112 | let mut range = CircularRangeInclusive::new(x, x); 113 | assert_eq!(range.next(), Some(x)); 114 | assert!(range.next().is_none()); 115 | 116 | TestResult::passed() 117 | } 118 | quickcheck(prop as fn(u16) -> TestResult); 119 | } 120 | } 121 | -------------------------------------------------------------------------------- /src/socket.rs: -------------------------------------------------------------------------------- 1 | use std::collections::HashMap; 2 | use std::io; 3 | use std::net::SocketAddr; 4 | use std::sync::{Arc, RwLock}; 5 | use std::time::Duration; 6 | 7 | use delay_map::HashMapDelay; 8 | use futures::StreamExt; 9 | use rand::{thread_rng, Rng}; 10 | use tokio::net::UdpSocket; 11 | use tokio::sync::mpsc::UnboundedSender; 12 | use tokio::sync::{mpsc, oneshot}; 13 | 14 | use crate::cid::ConnectionId; 15 | use crate::conn::ConnectionConfig; 16 | use crate::event::{SocketEvent, StreamEvent}; 17 | use crate::packet::{Packet, PacketBuilder, PacketType}; 18 | use crate::peer::{ConnectionPeer, Peer}; 19 | use crate::stream::UtpStream; 20 | use crate::udp::AsyncUdpSocket; 21 | 22 | type ConnChannel = UnboundedSender; 23 | 24 | struct Accept { 25 | stream: oneshot::Sender>>, 26 | config: ConnectionConfig, 27 | } 28 | 29 | struct AcceptWithCidPeer { 30 | cid: ConnectionId, 31 | peer: Peer

, 32 | accept: Accept

, 33 | } 34 | 35 | const MAX_UDP_PAYLOAD_SIZE: usize = u16::MAX as usize; 36 | const CID_GENERATION_TRY_WARNING_COUNT: usize = 10; 37 | 38 | /// accept_with_cid() has unique interactions compared to accept() 39 | /// accept() pulls awaiting requests off a queue, but accept_with_cid() only 40 | /// takes a connection off if CID matches. Because of this if we are awaiting a CID 41 | /// eventually we need to timeout the await, or the queue would never stop growing with stale awaits 42 | /// 20 seconds is arbitrary, after the uTP config refactor is done that can replace this constant. 43 | /// but thee uTP config refactor is currently very low priority. 44 | const AWAITING_CONNECTION_TIMEOUT: Duration = Duration::from_secs(20); 45 | 46 | pub struct UtpSocket { 47 | conns: Arc, ConnChannel>>>, 48 | accepts: UnboundedSender>, 49 | accepts_with_cid: UnboundedSender>, 50 | socket_events: UnboundedSender>, 51 | } 52 | 53 | impl UtpSocket { 54 | pub async fn bind(addr: SocketAddr) -> io::Result { 55 | let socket = UdpSocket::bind(addr).await?; 56 | let socket = Self::with_socket(socket); 57 | Ok(socket) 58 | } 59 | } 60 | 61 | impl

UtpSocket

62 | where 63 | P: ConnectionPeer + Unpin + 'static, 64 | { 65 | pub fn with_socket(mut socket: S) -> Self 66 | where 67 | S: AsyncUdpSocket

+ 'static, 68 | { 69 | let conns = HashMap::new(); 70 | let conns = Arc::new(RwLock::new(conns)); 71 | 72 | let mut awaiting: HashMapDelay, AcceptWithCidPeer

> = 73 | HashMapDelay::new(AWAITING_CONNECTION_TIMEOUT); 74 | 75 | let mut incoming_conns: HashMapDelay, (Peer

, Packet)> = 76 | HashMapDelay::new(AWAITING_CONNECTION_TIMEOUT); 77 | 78 | let (socket_event_tx, mut socket_event_rx) = mpsc::unbounded_channel(); 79 | let (accepts_tx, mut accepts_rx) = mpsc::unbounded_channel(); 80 | let (accepts_with_cid_tx, mut accepts_with_cid_rx) = mpsc::unbounded_channel(); 81 | 82 | let utp = Self { 83 | conns: Arc::clone(&conns), 84 | accepts: accepts_tx, 85 | accepts_with_cid: accepts_with_cid_tx, 86 | socket_events: socket_event_tx.clone(), 87 | }; 88 | 89 | tokio::spawn(async move { 90 | let mut buf = [0; MAX_UDP_PAYLOAD_SIZE]; 91 | loop { 92 | tokio::select! { 93 | biased; 94 | Ok((n, mut peer)) = socket.recv_from(&mut buf) => { 95 | let peer_id = peer.id(); 96 | let packet = match Packet::decode(&buf[..n]) { 97 | Ok(pkt) => pkt, 98 | Err(..) => { 99 | tracing::warn!(?peer, "unable to decode uTP packet"); 100 | continue; 101 | } 102 | }; 103 | 104 | let peer_init_cid = cid_from_packet::

(&packet, peer_id, IdType::SendIdPeerInitiated); 105 | let we_init_cid = cid_from_packet::

(&packet, peer_id, IdType::SendIdWeInitiated); 106 | let acc_cid = cid_from_packet::

(&packet, peer_id, IdType::RecvId); 107 | let mut conns = conns.write().unwrap(); 108 | let conn = conns 109 | .get(&acc_cid) 110 | .or_else(|| conns.get(&we_init_cid)) 111 | .or_else(|| conns.get(&peer_init_cid)); 112 | match conn { 113 | Some(conn) => { 114 | let _ = conn.send(StreamEvent::Incoming(packet)); 115 | } 116 | None => { 117 | if std::matches!(packet.packet_type(), PacketType::Syn) { 118 | let cid = acc_cid; 119 | 120 | // If there was an awaiting connection with the CID, then 121 | // create a new stream for that connection. Otherwise, add the 122 | // connection to the incoming connections. 123 | if let Some(accept_with_cid) = awaiting.remove(&cid) { 124 | peer.consolidate(accept_with_cid.peer); 125 | 126 | let (connected_tx, connected_rx) = oneshot::channel(); 127 | let (events_tx, events_rx) = mpsc::unbounded_channel(); 128 | 129 | conns.insert(cid.clone(), events_tx); 130 | 131 | let stream = UtpStream::new( 132 | cid, 133 | peer, 134 | accept_with_cid.accept.config, 135 | Some(packet), 136 | socket_event_tx.clone(), 137 | events_rx, 138 | connected_tx 139 | ); 140 | 141 | tokio::spawn(async move { 142 | Self::await_connected(stream, accept_with_cid.accept.stream, connected_rx).await 143 | }); 144 | } else { 145 | incoming_conns.insert(cid, (peer, packet)); 146 | } 147 | } else { 148 | tracing::debug!( 149 | cid = %packet.conn_id(), 150 | packet = ?packet.packet_type(), 151 | seq = %packet.seq_num(), 152 | ack = %packet.ack_num(), 153 | peer_init_cid = ?peer_init_cid, 154 | we_init_cid = ?we_init_cid, 155 | acc_cid = ?acc_cid, 156 | "received uTP packet for non-existing conn" 157 | ); 158 | // don't send a reset if we are receiving a reset 159 | if packet.packet_type() != PacketType::Reset { 160 | // if we get a packet from an unknown source send a reset packet. 
161 | let random_seq_num = thread_rng().gen_range(0..=65535); 162 | let reset_packet = 163 | PacketBuilder::new(PacketType::Reset, packet.conn_id(), crate::time::now_micros(), 100_000, random_seq_num) 164 | .build(); 165 | let event = SocketEvent::Outgoing((reset_packet, peer)); 166 | if socket_event_tx.send(event).is_err() { 167 | tracing::warn!("Cannot transmit reset packet: socket closed channel"); 168 | return; 169 | } 170 | } 171 | } 172 | }, 173 | } 174 | } 175 | Some(accept_with_cid) = accepts_with_cid_rx.recv() => { 176 | let Some((mut peer, syn)) = incoming_conns.remove(&accept_with_cid.cid) else { 177 | awaiting.insert(accept_with_cid.cid.clone(), accept_with_cid); 178 | continue; 179 | }; 180 | peer.consolidate(accept_with_cid.peer); 181 | Self::select_accept_helper(accept_with_cid.cid, peer, syn, conns.clone(), accept_with_cid.accept, socket_event_tx.clone()); 182 | } 183 | Some(accept) = accepts_rx.recv(), if !incoming_conns.is_empty() => { 184 | let cid = incoming_conns.keys().next().expect("at least one incoming connection"); 185 | let cid = cid.clone(); 186 | let (peer, packet) = incoming_conns.remove(&cid).expect("to delete incoming connection"); 187 | Self::select_accept_helper(cid, peer, packet, conns.clone(), accept, socket_event_tx.clone()); 188 | } 189 | Some(event) = socket_event_rx.recv() => { 190 | match event { 191 | SocketEvent::Outgoing((packet, dst)) => { 192 | let encoded = packet.encode(); 193 | if let Err(err) = socket.send_to(&encoded, &dst).await { 194 | tracing::debug!( 195 | %err, 196 | cid = %packet.conn_id(), 197 | packet = ?packet.packet_type(), 198 | seq = %packet.seq_num(), 199 | ack = %packet.ack_num(), 200 | "unable to send uTP packet over socket" 201 | ); 202 | } 203 | } 204 | SocketEvent::Shutdown(cid) => { 205 | tracing::debug!(%cid.send, %cid.recv, "uTP conn shutdown"); 206 | conns.write().unwrap().remove(&cid); 207 | } 208 | } 209 | } 210 | Some(Ok((cid, accept_with_cid))) = awaiting.next() => { 211 | // 
accept_with_cid didn't receive an inbound connection within the timeout period 212 | // log it and return a timeout error 213 | tracing::debug!(%cid.send, %cid.recv, "accept_with_cid timed out"); 214 | let _ = accept_with_cid.accept 215 | .stream 216 | .send(Err(io::Error::from(io::ErrorKind::TimedOut))); 217 | } 218 | Some(Ok((cid, _packet))) = incoming_conns.next() => { 219 | // didn't handle inbound connection within the timeout period 220 | // log it and return a timeout error 221 | tracing::debug!(%cid.send, %cid.recv, "inbound connection timed out"); 222 | } 223 | } 224 | } 225 | }); 226 | 227 | utp 228 | } 229 | 230 | /// Internal cid generation 231 | fn generate_cid( 232 | &self, 233 | peer_id: P::Id, 234 | is_initiator: bool, 235 | event_tx: Option>, 236 | ) -> ConnectionId { 237 | let mut cid = ConnectionId { 238 | send: 0, 239 | recv: 0, 240 | peer_id, 241 | }; 242 | let mut generation_attempt_count = 0; 243 | loop { 244 | if generation_attempt_count > CID_GENERATION_TRY_WARNING_COUNT { 245 | tracing::error!("cid() tried to generate a cid {generation_attempt_count} times") 246 | } 247 | let recv: u16 = rand::random(); 248 | let send = if is_initiator { 249 | recv.wrapping_add(1) 250 | } else { 251 | recv.wrapping_sub(1) 252 | }; 253 | cid.send = send; 254 | cid.recv = recv; 255 | 256 | if !self.conns.read().unwrap().contains_key(&cid) { 257 | if let Some(event_tx) = event_tx { 258 | self.conns.write().unwrap().insert(cid.clone(), event_tx); 259 | } 260 | return cid; 261 | } 262 | generation_attempt_count += 1; 263 | } 264 | } 265 | 266 | pub fn cid(&self, peer_id: P::Id, is_initiator: bool) -> ConnectionId { 267 | self.generate_cid(peer_id, is_initiator, None) 268 | } 269 | 270 | /// Returns the number of connections currently open, both inbound and outbound. 271 | pub fn num_connections(&self) -> usize { 272 | self.conns.read().unwrap().len() 273 | } 274 | 275 | /// WARNING: only accept() or accept_with_cid() can be used in an application. 
276 | /// they aren't compatible to use interchangeably in a program 277 | pub async fn accept(&self, config: ConnectionConfig) -> io::Result> { 278 | let (stream_tx, stream_rx) = oneshot::channel(); 279 | let accept = Accept { 280 | stream: stream_tx, 281 | config, 282 | }; 283 | self.accepts 284 | .send(accept) 285 | .map_err(|_| io::Error::from(io::ErrorKind::NotConnected))?; 286 | match stream_rx.await { 287 | Ok(stream) => Ok(stream?), 288 | Err(..) => Err(io::Error::from(io::ErrorKind::TimedOut)), 289 | } 290 | } 291 | 292 | /// WARNING: only accept() or accept_with_cid() can be used in an application. 293 | /// they aren't compatible to use interchangeably in a program 294 | pub async fn accept_with_cid( 295 | &self, 296 | cid: ConnectionId, 297 | peer: Peer

, 298 | config: ConnectionConfig, 299 | ) -> io::Result> { 300 | let (stream_tx, stream_rx) = oneshot::channel(); 301 | let accept = AcceptWithCidPeer { 302 | cid, 303 | peer, 304 | accept: Accept { 305 | stream: stream_tx, 306 | config, 307 | }, 308 | }; 309 | self.accepts_with_cid 310 | .send(accept) 311 | .map_err(|_| io::Error::from(io::ErrorKind::NotConnected))?; 312 | match stream_rx.await { 313 | Ok(stream) => Ok(stream?), 314 | Err(..) => Err(io::Error::from(io::ErrorKind::TimedOut)), 315 | } 316 | } 317 | 318 | pub async fn connect( 319 | &self, 320 | peer: Peer

, 321 | config: ConnectionConfig, 322 | ) -> io::Result> { 323 | let (connected_tx, connected_rx) = oneshot::channel(); 324 | let (events_tx, events_rx) = mpsc::unbounded_channel(); 325 | let cid = self.generate_cid(peer.id().clone(), true, Some(events_tx)); 326 | 327 | let stream = UtpStream::new( 328 | cid, 329 | peer, 330 | config, 331 | None, 332 | self.socket_events.clone(), 333 | events_rx, 334 | connected_tx, 335 | ); 336 | 337 | match connected_rx.await { 338 | Ok(Ok(..)) => Ok(stream), 339 | Ok(Err(err)) => Err(err), 340 | Err(..) => Err(io::Error::from(io::ErrorKind::TimedOut)), 341 | } 342 | } 343 | 344 | pub async fn connect_with_cid( 345 | &self, 346 | cid: ConnectionId, 347 | peer: Peer

, 348 | config: ConnectionConfig, 349 | ) -> io::Result> { 350 | let (connected_tx, connected_rx) = oneshot::channel(); 351 | let (events_tx, events_rx) = mpsc::unbounded_channel(); 352 | { 353 | // Check and insert under one write lock (dropped before the `.await` below) so concurrent registrations cannot both claim `cid`. 354 | let mut conns = self.conns.write().unwrap(); 355 | if conns.contains_key(&cid) { 356 | return Err(io::Error::new( 357 | io::ErrorKind::Other, 358 | "connection ID unavailable".to_string(), 359 | )); 360 | } 361 | conns.insert(cid.clone(), events_tx); 362 | } 363 | 364 | let stream = UtpStream::new( 365 | cid.clone(), 366 | peer, 367 | config, 368 | None, 369 | self.socket_events.clone(), 370 | events_rx, 371 | connected_tx, 372 | ); 373 | 374 | match connected_rx.await { 375 | Ok(Ok(..)) => Ok(stream), 376 | Ok(Err(err)) => { 377 | tracing::error!(%err, "failed to open connection with {cid:?}"); 378 | Err(err) 379 | } 380 | Err(err) => { 381 | tracing::error!(%err, "failed to open connection with {cid:?}"); 382 | Err(io::Error::from(io::ErrorKind::TimedOut)) 383 | } 384 | } 385 | } 386 | 387 | async fn await_connected( 388 | stream: UtpStream

, 389 | callback: oneshot::Sender>>, 390 | connected: oneshot::Receiver>, 391 | ) { 392 | match connected.await { 393 | Ok(Ok(..)) => { 394 | let _ = callback.send(Ok(stream)); 395 | } 396 | Ok(Err(err)) => { 397 | let _ = callback.send(Err(err)); 398 | } 399 | Err(..) => { 400 | let _ = callback.send(Err(io::Error::from(io::ErrorKind::ConnectionAborted))); 401 | } 402 | } 403 | } 404 | 405 | fn select_accept_helper( 406 | cid: ConnectionId, 407 | peer: Peer

, 408 | syn: Packet, 409 | conns: Arc, ConnChannel>>>, 410 | accept: Accept

, 411 | socket_event_tx: UnboundedSender>, 412 | ) { 413 | let (connected_tx, connected_rx) = oneshot::channel(); 414 | let (events_tx, events_rx) = mpsc::unbounded_channel(); 415 | { 416 | // Check and insert under one write lock so a concurrent `connect*` call on another task cannot claim `cid` in between. 417 | let mut conns = conns.write().unwrap(); 418 | if conns.contains_key(&cid) { 419 | let _ = accept.stream.send(Err(io::Error::new( 420 | io::ErrorKind::Other, 421 | "connection ID unavailable".to_string(), 422 | ))); 423 | return; 424 | } 425 | conns.insert(cid.clone(), events_tx); 426 | } 427 | 428 | let stream = UtpStream::new( 429 | cid, 430 | peer, 431 | accept.config, 432 | Some(syn), 433 | socket_event_tx, 434 | events_rx, 435 | connected_tx, 436 | ); 437 | 438 | tokio::spawn( 439 | async move { Self::await_connected(stream, accept.stream, connected_rx).await }, 440 | ); 441 | } 442 | } 443 | 444 | #[derive(Copy, Clone, Debug)] 445 | enum IdType { 446 | RecvId, 447 | SendIdWeInitiated, 448 | SendIdPeerInitiated, 449 | } 450 | 451 | fn cid_from_packet( 452 | packet: &Packet, 453 | peer_id: &P::Id, 454 | id_type: IdType, 455 | ) -> ConnectionId { 456 | let peer_id = peer_id.clone(); 457 | match id_type { 458 | IdType::RecvId => { 459 | let (send, recv) = match packet.packet_type() { 460 | PacketType::Syn => (packet.conn_id(), packet.conn_id().wrapping_add(1)), 461 | PacketType::State | PacketType::Data | PacketType::Fin | PacketType::Reset => { 462 | (packet.conn_id().wrapping_sub(1), packet.conn_id()) 463 | } 464 | }; 465 | ConnectionId { 466 | send, 467 | recv, 468 | peer_id, 469 | } 470 | } 471 | IdType::SendIdWeInitiated => { 472 | let (send, recv) = (packet.conn_id().wrapping_add(1), packet.conn_id()); 473 | ConnectionId { 474 | send, 475 | recv, 476 | peer_id, 477 | } 478 | } 479 | IdType::SendIdPeerInitiated => { 480 | let (send, recv) = (packet.conn_id(), packet.conn_id().wrapping_sub(1)); 481 | ConnectionId { 482 | send, 483 | recv, 484 | peer_id, 485 | } 486 | } 487 | } 488 | } 489 | 490 | impl Drop for UtpSocket

{ 491 | fn drop(&mut self) { 492 | for conn in self.conns.read().unwrap().values() { 493 | let _ = conn.send(StreamEvent::Shutdown); 494 | } 495 | } 496 | } 497 | -------------------------------------------------------------------------------- /src/stream.rs: -------------------------------------------------------------------------------- 1 | use std::io; 2 | 3 | use tokio::sync::{mpsc, oneshot}; 4 | use tokio::task; 5 | use tracing::Instrument; 6 | 7 | use crate::cid::ConnectionId; 8 | use crate::congestion::DEFAULT_MAX_PACKET_SIZE_BYTES; 9 | use crate::conn; 10 | use crate::event::{SocketEvent, StreamEvent}; 11 | use crate::packet::Packet; 12 | use crate::peer::{ConnectionPeer, Peer}; 13 | 14 | /// The size of the send and receive buffers. 15 | // TODO: Make the buffer size configurable. 16 | const BUF: usize = 1024 * 1024; 17 | 18 | pub struct UtpStream { 19 | cid: ConnectionId, 20 | reads: mpsc::UnboundedReceiver, 21 | writes: mpsc::UnboundedSender, 22 | shutdown: Option>, 23 | conn_handle: Option>>, 24 | } 25 | 26 | impl

UtpStream

27 | where 28 | P: ConnectionPeer + 'static, 29 | { 30 | pub(crate) fn new( 31 | cid: ConnectionId, 32 | peer: Peer

, 33 | config: conn::ConnectionConfig, 34 | syn: Option, 35 | socket_events: mpsc::UnboundedSender>, 36 | stream_events: mpsc::UnboundedReceiver, 37 | connected: oneshot::Sender>, 38 | ) -> Self { 39 | let (shutdown_tx, shutdown_rx) = oneshot::channel(); 40 | let (reads_tx, reads_rx) = mpsc::unbounded_channel(); 41 | let (writes_tx, writes_rx) = mpsc::unbounded_channel(); 42 | let mut conn = conn::Connection::::new( 43 | cid.clone(), 44 | peer, 45 | config, 46 | syn, 47 | connected, 48 | socket_events, 49 | reads_tx, 50 | ); 51 | let conn_handle = tokio::spawn(async move { 52 | conn.event_loop(stream_events, writes_rx, shutdown_rx) 53 | .instrument(tracing::info_span!("uTP", send = cid.send, recv = cid.recv)) 54 | .await 55 | }); 56 | 57 | Self { 58 | cid, 59 | reads: reads_rx, 60 | writes: writes_tx, 61 | shutdown: Some(shutdown_tx), 62 | conn_handle: Some(conn_handle), 63 | } 64 | } 65 | 66 | pub fn cid(&self) -> &ConnectionId { 67 | &self.cid 68 | } 69 | /// Reads chunks until an empty chunk marks EOF, appending them to `buf`, and returns the number of bytes read; errors from the connection are propagated. 70 | pub async fn read_to_eof(&mut self, buf: &mut Vec) -> io::Result { 71 | // Reserve space in the buffer to avoid expensive allocation for small reads. 72 | buf.reserve(DEFAULT_MAX_PACKET_SIZE_BYTES as usize); 73 | 74 | let mut n = 0; 75 | while let Some(data) = self.reads.recv().await { 76 | let mut data = data?; 77 | if data.is_empty() { 78 | return Ok(n); 79 | } 80 | let chunk_len = data.len(); 81 | n += chunk_len; 82 | buf.append(&mut data); 83 | // Reserve additional space in the buffer proportional to the amount of data read; 84 | // `append` drains `data`, so use the length captured before the call. 85 | buf.reserve(chunk_len); 86 | }
87 | 88 | // `recv()` returns `None` only once the connection has dropped its sender (`reads_tx`, 89 | // handed to `conn::Connection` in `new`) without first delivering the empty EOF chunk. 90 | // A closed channel never yields again, so looping here would spin forever; report the 91 | // truncated stream to the caller instead. 92 | tracing::debug!("read buffer was sent None"); 93 | Err(io::Error::from(io::ErrorKind::UnexpectedEof)) 94 | } 95 | 96 | pub async fn write(&mut self, buf: &[u8]) -> io::Result { 97 | if self.shutdown.is_none() { 98 | return Err(io::Error::from(io::ErrorKind::NotConnected)); 99 | } 100 | 101 | let (tx, rx) = oneshot::channel(); 102 | self.writes 103 | .send((buf.to_vec(), tx)) 104 | .map_err(|_| io::Error::from(io::ErrorKind::NotConnected))?; 105 | 106 | match rx.await { 107 | Ok(n) => Ok(n?), 108 | Err(err) => Err(io::Error::new(io::ErrorKind::Other, format!("{err:?}"))), 109 | } 110 | } 111 | 112 | /// Closes the stream gracefully. 113 | /// Completes when the remote peer acknowledges all sent data. 114 | pub async fn close(&mut self) -> io::Result<()> { 115 | self.shutdown()?; 116 | match self.conn_handle.take() { 117 | Some(conn_handle) => conn_handle.await?, 118 | None => Err(io::Error::from(io::ErrorKind::NotConnected)), 119 | } 120 | } 121 | } 122 | 123 | impl UtpStream

{ 124 | // Send signal to the connection event loop to exit, after all outgoing writes have completed. 125 | // Public callers should use close() instead. 126 | fn shutdown(&mut self) -> io::Result<()> { 127 | match self.shutdown.take() { 128 | Some(shutdown) => Ok(shutdown 129 | .send(()) 130 | .map_err(|_| io::Error::from(io::ErrorKind::NotConnected))?), 131 | None => Err(io::Error::from(io::ErrorKind::NotConnected)), 132 | } 133 | } 134 | } 135 | 136 | impl Drop for UtpStream

{ 137 | fn drop(&mut self) { 138 | let _ = self.shutdown(); 139 | } 140 | } 141 | -------------------------------------------------------------------------------- /src/testutils.rs: -------------------------------------------------------------------------------- 1 | use std::io; 2 | use std::sync::atomic::{AtomicBool, Ordering}; 3 | use std::sync::Arc; 4 | 5 | use async_trait::async_trait; 6 | use tokio::sync::mpsc; 7 | 8 | use crate::cid::ConnectionId; 9 | use crate::peer::{ConnectionPeer, Peer}; 10 | use crate::udp::AsyncUdpSocket; 11 | 12 | /// Decides whether the link between peers is up or down. 13 | trait LinkDecider { 14 | /// Returns true if the link should send the packet, false otherwise. 15 | /// 16 | /// This must only be called once per packet, as it may have side-effects. 17 | fn should_send(&mut self) -> bool; 18 | } 19 | 20 | /// A mock socket that can be used to simulate a perfect link. 21 | #[derive(Debug)] 22 | pub struct MockUdpSocket { 23 | outbound: mpsc::UnboundedSender>, 24 | inbound: mpsc::UnboundedReceiver>, 25 | /// Peers identified by a letter 26 | pub only_peer: char, 27 | /// Defines whether the link is up. If not up, link will SILENTLY drop all sent packets. 
28 | pub link: Link, 29 | } 30 | 31 | #[derive(Clone)] 32 | pub struct ManualLinkDecider { 33 | pub up_switch: Arc, 34 | } 35 | 36 | impl ManualLinkDecider { 37 | fn new() -> Self { 38 | Self { 39 | up_switch: Arc::new(AtomicBool::new(true)), 40 | } 41 | } 42 | } 43 | 44 | impl LinkDecider for ManualLinkDecider { 45 | fn should_send(&mut self) -> bool { 46 | self.up_switch.load(Ordering::SeqCst) 47 | } 48 | } 49 | 50 | pub struct LinkDropsFirstNSent { 51 | target_drops: usize, 52 | actual_drops: usize, 53 | } 54 | 55 | impl LinkDropsFirstNSent { 56 | fn new(n: usize) -> Self { 57 | Self { 58 | target_drops: n, 59 | actual_drops: 0, 60 | } 61 | } 62 | } 63 | 64 | impl LinkDecider for LinkDropsFirstNSent { 65 | fn should_send(&mut self) -> bool { 66 | if self.actual_drops < self.target_drops { 67 | self.actual_drops += 1; 68 | false 69 | } else { 70 | true 71 | } 72 | } 73 | } 74 | 75 | #[async_trait] 76 | impl AsyncUdpSocket 77 | for MockUdpSocket 78 | { 79 | /// # Panics 80 | /// 81 | /// Panics if `target` is not equal to `self.only_peer`. This socket is built to support 82 | /// exactly two peers communicating with each other, so it will panic if used with more. 83 | async fn send_to(&mut self, buf: &[u8], peer: &Peer) -> io::Result { 84 | if peer.id() != &self.only_peer { 85 | panic!("MockUdpSocket only supports sending to one peer"); 86 | } 87 | if !self.link.should_send() { 88 | tracing::warn!("Dropping packet to {peer:?}: {buf:?}"); 89 | return Ok(buf.len()); 90 | } 91 | if let Err(err) = self.outbound.send(buf.to_vec()) { 92 | Err(io::Error::new( 93 | io::ErrorKind::UnexpectedEof, 94 | format!("channel closed: {err}"), 95 | )) 96 | } else { 97 | Ok(buf.len()) 98 | } 99 | } 100 | 101 | /// # Panics 102 | /// 103 | /// Panics if `buf` is smaller than the packet size. 
104 | async fn recv_from(&mut self, buf: &mut [u8]) -> io::Result<(usize, Peer)> { 105 | let packet = self 106 | .inbound 107 | .recv() 108 | .await 109 | .ok_or_else(|| io::Error::new(io::ErrorKind::UnexpectedEof, "channel closed"))?; 110 | if buf.len() < packet.len() { 111 | panic!("buffer too small for perfect link"); 112 | } 113 | let packet_len = packet.len(); 114 | buf[..packet_len].copy_from_slice(&packet[..]); 115 | Ok((packet_len, Peer::new(self.only_peer))) 116 | } 117 | } 118 | 119 | impl ConnectionPeer for char { 120 | type Id = char; 121 | 122 | fn id(&self) -> Self::Id { 123 | *self 124 | } 125 | 126 | fn consolidate(a: Self, b: Self) -> Self { 127 | assert!(a == b, "Consolidating non-equal peers"); 128 | a 129 | } 130 | } 131 | 132 | fn build_link_pair( 133 | a_to_b_link: LinkAtoB, 134 | b_to_a_link: LinkBtoA, 135 | ) -> (MockUdpSocket, MockUdpSocket) { 136 | let (peer_a, peer_b): (char, char) = ('A', 'B'); 137 | let (a_tx, a_rx) = mpsc::unbounded_channel(); 138 | let (b_tx, b_rx) = mpsc::unbounded_channel(); 139 | let a = MockUdpSocket { 140 | outbound: a_tx, 141 | inbound: b_rx, 142 | only_peer: peer_b, 143 | link: a_to_b_link, 144 | }; 145 | let b = MockUdpSocket { 146 | outbound: b_tx, 147 | inbound: a_rx, 148 | only_peer: peer_a, 149 | link: b_to_a_link, 150 | }; 151 | (a, b) 152 | } 153 | 154 | fn build_connection_id_pair( 155 | socket_a: &MockUdpSocket, 156 | socket_b: &MockUdpSocket, 157 | ) -> (ConnectionId, ConnectionId) { 158 | build_connection_id_pair_starting_at(socket_a, socket_b, 100) 159 | } 160 | 161 | fn build_connection_id_pair_starting_at( 162 | socket_a: &MockUdpSocket, 163 | socket_b: &MockUdpSocket, 164 | lower_id: u16, 165 | ) -> (ConnectionId, ConnectionId) { 166 | let higher_id = lower_id.wrapping_add(1); 167 | let a_cid = ConnectionId { 168 | send: higher_id, 169 | recv: lower_id, 170 | peer_id: socket_a.only_peer, 171 | }; 172 | let b_cid = ConnectionId { 173 | send: lower_id, 174 | recv: higher_id, 175 | peer_id: 
socket_b.only_peer, 176 | }; 177 | (a_cid, b_cid) 178 | } 179 | 180 | /// Build a link between sockets, which we can manually control whether it is up or down 181 | #[allow(clippy::type_complexity)] 182 | pub fn build_manually_linked_pair() -> ( 183 | (MockUdpSocket, ConnectionId), 184 | (MockUdpSocket, ConnectionId), 185 | ) { 186 | let (socket_a, socket_b) = build_link_pair(ManualLinkDecider::new(), ManualLinkDecider::new()); 187 | let (a_cid, b_cid) = build_connection_id_pair(&socket_a, &socket_b); 188 | ((socket_a, a_cid), (socket_b, b_cid)) 189 | } 190 | 191 | /// Build a link between sockets, where the first n packets sent by the 2nd socket are dropped. 192 | /// 193 | /// The first socket is the one with the lower connection ID, which must be the one initiating the 194 | /// connection. 195 | #[allow(clippy::type_complexity)] 196 | pub fn build_link_drops_first_n_sent_pair( 197 | n: usize, 198 | ) -> ( 199 | (MockUdpSocket, ConnectionId), 200 | (MockUdpSocket, ConnectionId), 201 | ) { 202 | let link_a_to_b = ManualLinkDecider::new(); 203 | let link_b_to_a = LinkDropsFirstNSent::new(n); 204 | let (socket_a, socket_b) = build_link_pair(link_a_to_b, link_b_to_a); 205 | let (a_cid, b_cid) = build_connection_id_pair(&socket_a, &socket_b); 206 | ((socket_a, a_cid), (socket_b, b_cid)) 207 | } 208 | -------------------------------------------------------------------------------- /src/time.rs: -------------------------------------------------------------------------------- 1 | use std::time::{Duration, SystemTime, UNIX_EPOCH}; 2 | 3 | /// Returns the UNIX timestamp truncated to a `u32`. 4 | pub fn now_micros() -> u32 { 5 | let now = SystemTime::now().duration_since(UNIX_EPOCH).unwrap(); 6 | now.as_micros() as u32 7 | } 8 | 9 | /// Returns the amount of time elapsed between `earlier_micros` and `later_micros`. 10 | /// 11 | /// If `later_micros` is less than `earlier_micros`, then we assume that `later_micros` 12 | /// has wrapped around the `u32` boundary. 
13 | pub fn duration_between(earlier_micros: u32, later_micros: u32) -> Duration { 14 | if later_micros < earlier_micros { 15 | Duration::from_micros((u32::MAX - earlier_micros + later_micros).into()) 16 | } else { 17 | Duration::from_micros((later_micros - earlier_micros).into()) 18 | } 19 | } 20 | -------------------------------------------------------------------------------- /src/udp.rs: -------------------------------------------------------------------------------- 1 | use std::io; 2 | use std::net::SocketAddr; 3 | 4 | use async_trait::async_trait; 5 | use tokio::net::UdpSocket; 6 | 7 | use crate::peer::{ConnectionPeer, Peer}; 8 | 9 | /// An abstract representation of an asynchronous UDP socket. 10 | #[async_trait] 11 | pub trait AsyncUdpSocket: Send + Sync { 12 | /// Attempts to send data on the socket to a given peer. 13 | /// Note that this should return nearly immediately, rather than awaiting something internally. 14 | async fn send_to(&mut self, buf: &[u8], peer: &Peer

) -> io::Result; 15 | /// Attempts to receive a single datagram on the socket. 16 | async fn recv_from(&mut self, buf: &mut [u8]) -> io::Result<(usize, Peer

)>; 17 | } 18 | 19 | #[async_trait] 20 | impl AsyncUdpSocket for UdpSocket { 21 | async fn send_to(&mut self, buf: &[u8], peer: &Peer) -> io::Result { 22 | UdpSocket::send_to(self, buf, peer.id()).await 23 | } 24 | 25 | async fn recv_from(&mut self, buf: &mut [u8]) -> io::Result<(usize, Peer)> { 26 | UdpSocket::recv_from(self, buf) 27 | .await 28 | .map(|(len, peer)| (len, Peer::new(peer))) 29 | } 30 | } 31 | -------------------------------------------------------------------------------- /tests/socket.rs: -------------------------------------------------------------------------------- 1 | use futures::stream::{FuturesUnordered, StreamExt}; 2 | use std::net::SocketAddr; 3 | use std::sync::Arc; 4 | use utp_rs::peer::Peer; 5 | 6 | use tokio::task::JoinHandle; 7 | use tokio::time::Instant; 8 | 9 | use utp_rs::cid; 10 | use utp_rs::conn::ConnectionConfig; 11 | use utp_rs::socket::UtpSocket; 12 | 13 | const TEST_DATA: &[u8] = &[0xf0; 1_000_000]; 14 | 15 | #[tokio::test(flavor = "multi_thread", worker_threads = 16)] 16 | async fn many_concurrent_transfers() { 17 | let _ = tracing_subscriber::fmt::try_init(); 18 | 19 | tracing::info!("starting socket test"); 20 | 21 | let recv_addr = SocketAddr::from(([127, 0, 0, 1], 3400)); 22 | let send_addr = SocketAddr::from(([127, 0, 0, 1], 3401)); 23 | 24 | let recv = UtpSocket::bind(recv_addr).await.unwrap(); 25 | let recv = Arc::new(recv); 26 | let send = UtpSocket::bind(send_addr).await.unwrap(); 27 | let send = Arc::new(send); 28 | let mut handles = FuturesUnordered::new(); 29 | 30 | let start = Instant::now(); 31 | let num_transfers = 1000; 32 | for i in 0..num_transfers { 33 | // step up cid by two to avoid collisions 34 | let handle = initiate_transfer( 35 | i * 2, 36 | recv_addr, 37 | recv.clone(), 38 | send_addr, 39 | send.clone(), 40 | TEST_DATA, 41 | ) 42 | .await; 43 | handles.push(handle.0); 44 | handles.push(handle.1); 45 | } 46 | 47 | while let Some(res) = handles.next().await { 48 | res.unwrap(); 49 | } 50 | let 
elapsed = Instant::now() - start; 51 | let megabits_sent = num_transfers as f64 * TEST_DATA.len() as f64 * 8.0 / 1_000_000.0; 52 | let transfer_rate = megabits_sent / elapsed.as_secs_f64(); 53 | tracing::info!("finished high concurrency load test of {} simultaneous transfers, in {:?}, at a rate of {:.0} Mbps", num_transfers, elapsed, transfer_rate); 54 | } 55 | 56 | #[tokio::test] 57 | /// Test that a socket can send and receive a large amount of data 58 | async fn one_huge_data_transfer() { 59 | // TODO: test 100MiB or more. Currently, it fails (perhaps due to a rollover at 2^16 packets) 60 | 61 | // At the time of writing, 1024 * 1024 + 1 will hang, because it's bigger than the send buffer, 62 | // and the sending logic pauses until the buffer is larger than the pending data. 63 | const HUGE_DATA: &[u8] = &[0xf0; 1024 * 1024 * 50]; 64 | 65 | let _ = tracing_subscriber::fmt::try_init(); 66 | 67 | tracing::info!("starting single transfer of huge data test"); 68 | 69 | let recv_addr = SocketAddr::from(([127, 0, 0, 1], 3500)); 70 | let send_addr = SocketAddr::from(([127, 0, 0, 1], 3501)); 71 | 72 | let recv = UtpSocket::bind(recv_addr).await.unwrap(); 73 | let recv = Arc::new(recv); 74 | let send = UtpSocket::bind(send_addr).await.unwrap(); 75 | let send = Arc::new(send); 76 | 77 | let start = Instant::now(); 78 | let handle = initiate_transfer( 79 | 0, 80 | recv_addr, 81 | recv.clone(), 82 | send_addr, 83 | send.clone(), 84 | HUGE_DATA, 85 | ) 86 | .await; 87 | 88 | // Wait for the sending side of the transfer to complete 89 | handle.0.await.unwrap(); 90 | // Wait for the receiving side of the transfer to complete 91 | handle.1.await.unwrap(); 92 | 93 | let elapsed = Instant::now() - start; 94 | let megabytes_sent = HUGE_DATA.len() as f64 / 1_000_000.0; 95 | let megabits_sent = megabytes_sent * 8.0; 96 | let transfer_rate = megabits_sent / elapsed.as_secs_f64(); 97 | tracing::info!( 98 | "finished single large transfer test with {:.0} MB, in {:?}, at a rate of {:.1} 
Mbps", 99 | megabytes_sent, 100 | elapsed, 101 | transfer_rate 102 | ); 103 | } 104 | 105 | async fn initiate_transfer( 106 | i: u16, 107 | recv_addr: SocketAddr, 108 | recv: Arc>, 109 | send_addr: SocketAddr, 110 | send: Arc>, 111 | data: &'static [u8], 112 | ) -> (JoinHandle<()>, JoinHandle<()>) { 113 | let conn_config = ConnectionConfig::default(); 114 | let initiator_cid = 100 + i; 115 | let responder_cid = 100 + i + 1; 116 | let recv_cid = cid::ConnectionId { 117 | send: initiator_cid, 118 | recv: responder_cid, 119 | peer_id: send_addr, 120 | }; 121 | let send_cid = cid::ConnectionId { 122 | send: responder_cid, 123 | recv: initiator_cid, 124 | peer_id: recv_addr, 125 | }; 126 | 127 | let recv_handle = tokio::spawn(async move { 128 | let mut stream = recv 129 | .accept_with_cid(recv_cid, Peer::new(send_addr), conn_config) 130 | .await 131 | .unwrap(); 132 | let mut buf = vec![]; 133 | let n = match stream.read_to_eof(&mut buf).await { 134 | Ok(num_bytes) => num_bytes, 135 | Err(err) => { 136 | let cid = stream.cid(); 137 | tracing::error!(?cid, "read to eof error: {:?}", err); 138 | panic!("fail to read data"); 139 | } 140 | }; 141 | tracing::info!(cid.send = %recv_cid.send, cid.recv = %recv_cid.recv, "read {n} bytes from uTP stream"); 142 | 143 | assert_eq!(n, data.len()); 144 | assert_eq!(buf, data); 145 | }); 146 | 147 | let send_handle = tokio::spawn(async move { 148 | let mut stream = send 149 | .connect_with_cid(send_cid, Peer::new(recv_addr), conn_config) 150 | .await 151 | .unwrap(); 152 | let n = stream.write(data).await.unwrap(); 153 | assert_eq!(n, data.len()); 154 | 155 | stream.close().await.unwrap(); 156 | }); 157 | (send_handle, recv_handle) 158 | } 159 | 160 | // Test that a new socket has zero connections 161 | #[tokio::test] 162 | async fn test_empty_socket_conn_count() { 163 | let socket_addr = SocketAddr::from(([127, 0, 0, 1], 3402)); 164 | let socket = UtpSocket::bind(socket_addr).await.unwrap(); 165 | assert_eq!(socket.num_connections(), 
0); 166 | } 167 | 168 | // Test that a socket returns 2 from num_connections after connecting twice 169 | #[tokio::test] 170 | async fn test_socket_reports_two_connections() { 171 | let conn_config = ConnectionConfig::default(); 172 | 173 | let recv_addr = SocketAddr::from(([127, 0, 0, 1], 3404)); 174 | let recv = UtpSocket::bind(recv_addr).await.unwrap(); 175 | let recv = Arc::new(recv); 176 | 177 | let send_addr = SocketAddr::from(([127, 0, 0, 1], 3405)); 178 | let send = UtpSocket::bind(send_addr).await.unwrap(); 179 | let send = Arc::new(send); 180 | 181 | let recv_one_cid = cid::ConnectionId { 182 | send: 100, 183 | recv: 101, 184 | peer_id: send_addr, 185 | }; 186 | let send_one_cid = cid::ConnectionId { 187 | send: 101, 188 | recv: 100, 189 | peer_id: recv_addr, 190 | }; 191 | 192 | let recv_one = Arc::clone(&recv); 193 | let recv_one_handle = tokio::spawn(async move { 194 | recv_one 195 | .accept_with_cid(recv_one_cid, Peer::new(send_addr), conn_config) 196 | .await 197 | .unwrap() 198 | }); 199 | 200 | let send_one = Arc::clone(&send); 201 | let send_one_handle = tokio::spawn(async move { 202 | send_one 203 | .connect_with_cid(send_one_cid, Peer::new(recv_addr), conn_config) 204 | .await 205 | .unwrap() 206 | }); 207 | 208 | let recv_two_cid = cid::ConnectionId { 209 | send: 200, 210 | recv: 201, 211 | peer_id: send_addr, 212 | }; 213 | let send_two_cid = cid::ConnectionId { 214 | send: 201, 215 | recv: 200, 216 | peer_id: recv_addr, 217 | }; 218 | 219 | let recv_two = Arc::clone(&recv); 220 | let recv_two_handle = tokio::spawn(async move { 221 | recv_two 222 | .accept_with_cid(recv_two_cid, Peer::new(send_addr), conn_config) 223 | .await 224 | .unwrap() 225 | }); 226 | 227 | let send_two = Arc::clone(&send); 228 | let send_two_handle = tokio::spawn(async move { 229 | send_two 230 | .connect_with_cid(send_two_cid, Peer::new(recv_addr), conn_config) 231 | .await 232 | .unwrap() 233 | }); 234 | 235 | let (tx_one, rx_one, tx_two, rx_two) = tokio::join!( 236 | 
send_one_handle, 237 | recv_one_handle, 238 | send_two_handle, 239 | recv_two_handle 240 | ); 241 | tx_one.unwrap(); 242 | rx_one.unwrap(); 243 | tx_two.unwrap(); 244 | rx_two.unwrap(); 245 | 246 | assert_eq!(recv.num_connections(), 2); 247 | assert_eq!(send.num_connections(), 2); 248 | } 249 | -------------------------------------------------------------------------------- /tests/stream.rs: -------------------------------------------------------------------------------- 1 | use std::io::ErrorKind; 2 | use std::sync::atomic::Ordering; 3 | use std::sync::Arc; 4 | use std::time::Duration; 5 | 6 | use tokio::time::timeout; 7 | 8 | use utp_rs::conn::{ConnectionConfig, DEFAULT_MAX_IDLE_TIMEOUT}; 9 | use utp_rs::peer::Peer; 10 | use utp_rs::socket::UtpSocket; 11 | 12 | use utp_rs::testutils; 13 | 14 | // How long should tests expect the connection to wait before timing out due to inactivity? 15 | const EXPECTED_IDLE_TIMEOUT: Duration = DEFAULT_MAX_IDLE_TIMEOUT; 16 | 17 | // Test that close() returns successful, after transfer is complete 18 | #[tokio::test] 19 | async fn close_is_successful_when_write_completes() { 20 | let conn_config = ConnectionConfig::default(); 21 | 22 | let ((send_socket, send_cid), (recv_socket, recv_cid)) = 23 | testutils::build_manually_linked_pair(); 24 | 25 | let recv = UtpSocket::with_socket(recv_socket); 26 | let recv = Arc::new(recv); 27 | 28 | let send = UtpSocket::with_socket(send_socket); 29 | let send = Arc::new(send); 30 | 31 | let recv_one = Arc::clone(&recv); 32 | let recv_one_handle = tokio::spawn(async move { 33 | recv_one 34 | .accept_with_cid(recv_cid, Peer::new_id(recv_cid.peer_id), conn_config) 35 | .await 36 | .unwrap() 37 | }); 38 | 39 | // Keep a clone of the socket so that it doesn't drop when moved into the task. 40 | // Dropping it causes all connections to exit. 
41 | let send_one = Arc::clone(&send); 42 | let send_one_handle = tokio::spawn(async move { 43 | send_one 44 | .connect_with_cid(send_cid, Peer::new_id(send_cid.peer_id), conn_config) 45 | .await 46 | .unwrap() 47 | }); 48 | 49 | let (tx_one, rx_one) = tokio::join!(send_one_handle, recv_one_handle,); 50 | let mut send_stream = tx_one.unwrap(); 51 | let mut recv_stream = rx_one.unwrap(); 52 | 53 | // data to send 54 | const DATA_LEN: usize = 100; 55 | let data = [0xa5; DATA_LEN]; 56 | 57 | // send data 58 | let send_stream_handle = tokio::spawn(async move { 59 | match send_stream.write(&data).await { 60 | Ok(written_len) => assert_eq!(written_len, DATA_LEN), 61 | Err(e) => panic!("Error sending data: {:?}", e), 62 | }; 63 | send_stream 64 | }); 65 | 66 | // recv data 67 | let recv_stream_handle = tokio::spawn(async move { 68 | let mut read_buf = vec![]; 69 | let _ = recv_stream.read_to_eof(&mut read_buf).await.unwrap(); 70 | assert_eq!(read_buf, data.to_vec()); 71 | }); 72 | 73 | // wait for send to start 74 | let mut send_stream = send_stream_handle.await.unwrap(); 75 | 76 | // close stream, which will wait for write to complete, and exit without a problem 77 | // This should happen extremely quickly. 
78 | match timeout(Duration::from_millis(20), send_stream.close()).await { 79 | Ok(Ok(_)) => {} 80 | Ok(Err(e)) => panic!("Error closing stream: {:?}", e), 81 | Err(_) => panic!("Timeout closing stream"), 82 | }; 83 | 84 | // confirm that data is received as expected 85 | recv_stream_handle.await.unwrap(); 86 | } 87 | 88 | // Test that close() returns a timeout, if recipient is not ACKing (after a successful connection) 89 | #[tokio::test(start_paused = true)] 90 | async fn close_errors_if_all_packets_dropped() { 91 | let conn_config = ConnectionConfig::default(); 92 | 93 | let ((send_socket, send_cid), (recv_socket, recv_cid)) = 94 | testutils::build_manually_linked_pair(); 95 | let tx_link_switch = send_socket.link.up_switch.clone(); 96 | 97 | let recv = UtpSocket::with_socket(recv_socket); 98 | let recv = Arc::new(recv); 99 | 100 | let send = UtpSocket::with_socket(send_socket); 101 | let send = Arc::new(send); 102 | 103 | let recv_one = Arc::clone(&recv); 104 | let recv_one_handle = tokio::spawn(async move { 105 | recv_one 106 | .accept_with_cid(recv_cid, Peer::new_id(recv_cid.peer_id), conn_config) 107 | .await 108 | .unwrap() 109 | }); 110 | 111 | // Keep a clone of the socket so that it doesn't drop when moved into the task. 112 | // Dropping it causes all connections to exit. 
113 | let send_one = Arc::clone(&send); 114 | let send_one_handle = tokio::spawn(async move { 115 | send_one 116 | .connect_with_cid(send_cid, Peer::new_id(send_cid.peer_id), conn_config) 117 | .await 118 | .unwrap() 119 | }); 120 | 121 | let (tx_one, rx_one) = tokio::join!(send_one_handle, recv_one_handle,); 122 | let mut send_stream = tx_one.unwrap(); 123 | let mut recv_stream = rx_one.unwrap(); 124 | 125 | // ******* DISABLE NETWORK LINK ******** 126 | tx_link_switch.store(false, Ordering::SeqCst); 127 | 128 | // data to send 129 | const DATA_LEN: usize = 100; 130 | let data = [0xa5; DATA_LEN]; 131 | 132 | // send data 133 | let send_stream_handle = tokio::spawn(async move { 134 | match send_stream.write(&data).await { 135 | Ok(written_len) => assert_eq!(written_len, DATA_LEN), 136 | Err(e) => panic!("Error sending data: {:?}", e), 137 | }; 138 | send_stream 139 | }); 140 | 141 | // recv data 142 | let recv_stream_handle = tokio::spawn(async move { 143 | let mut read_buf = vec![]; 144 | let read_err = recv_stream.read_to_eof(&mut read_buf).await.unwrap_err(); 145 | assert_eq!(read_err.kind(), ErrorKind::TimedOut); 146 | }); 147 | 148 | // Wait for send to start 149 | let mut send_stream = send_stream_handle.await.unwrap(); 150 | 151 | // Close stream, which will fail because network is disabled. 152 | match timeout(EXPECTED_IDLE_TIMEOUT * 2, send_stream.close()).await { 153 | Ok(Ok(_)) => panic!("Stream closed successfully, but should have timed out"), 154 | Ok(Err(e)) => { 155 | // The stream must time out when waiting to close, if the network is disabled. 156 | assert_eq!(e.kind(), ErrorKind::TimedOut); 157 | } 158 | Err(e) => { 159 | panic!("The stream did not timeout on close() fast enough, giving up after: {e:?}") 160 | } 161 | }; 162 | 163 | // Wait to confirm that the read will time out, also. 
164 | recv_stream_handle.await.unwrap(); 165 | } 166 | 167 | // Test that close() succeeds, if the connection is only missing the FIN-ACK 168 | #[tokio::test(start_paused = true)] 169 | async fn close_succeeds_if_only_fin_ack_dropped() { 170 | let conn_config = ConnectionConfig::default(); 171 | 172 | let ((send_socket, send_cid), (recv_socket, recv_cid)) = 173 | testutils::build_manually_linked_pair(); 174 | let rx_link_switch = recv_socket.link.up_switch.clone(); 175 | 176 | let recv = UtpSocket::with_socket(recv_socket); 177 | let recv = Arc::new(recv); 178 | 179 | let send = UtpSocket::with_socket(send_socket); 180 | let send = Arc::new(send); 181 | 182 | let recv_one = Arc::clone(&recv); 183 | let recv_one_handle = tokio::spawn(async move { 184 | recv_one 185 | .accept_with_cid(recv_cid, Peer::new_id(recv_cid.peer_id), conn_config) 186 | .await 187 | .unwrap() 188 | }); 189 | 190 | // Keep a clone of the socket so that it doesn't drop when moved into the task. 191 | // Dropping it causes all connections to exit. 
192 | let send_one = Arc::clone(&send); 193 | let send_one_handle = tokio::spawn(async move { 194 | send_one 195 | .connect_with_cid(send_cid, Peer::new_id(send_cid.peer_id), conn_config) 196 | .await 197 | .unwrap() 198 | }); 199 | 200 | let (tx_one, rx_one) = tokio::join!(send_one_handle, recv_one_handle,); 201 | let mut send_stream = tx_one.unwrap(); 202 | let mut recv_stream = rx_one.unwrap(); 203 | 204 | // data to send 205 | const DATA_LEN: usize = 100; 206 | let data = [0xa5; DATA_LEN]; 207 | 208 | // send data 209 | let send_stream_handle = tokio::spawn(async move { 210 | match send_stream.write(&data).await { 211 | Ok(written_len) => assert_eq!(written_len, DATA_LEN), 212 | Err(e) => panic!("Error sending data: {:?}", e), 213 | }; 214 | send_stream 215 | }); 216 | 217 | // recv data 218 | let recv_stream_handle = tokio::spawn(async move { 219 | let mut read_buf = vec![]; 220 | let _ = recv_stream.read_to_eof(&mut read_buf).await.unwrap(); 221 | assert_eq!(read_buf, data.to_vec()); 222 | recv_stream 223 | }); 224 | 225 | // Wait for send to start 226 | let mut send_stream = send_stream_handle.await.unwrap(); 227 | 228 | // Wait for the full data to be sent before dropping the link 229 | // This is a timeless sleep, because tokio time is paused 230 | tokio::time::sleep(EXPECTED_IDLE_TIMEOUT / 2).await; 231 | 232 | // ******* DISABLE NETWORK LINK ******** 233 | // This only drops the connection from the recipient to the sender, leading to the following 234 | // scenario: 235 | // - Sender sends FIN 236 | // - Recipient receives FIN, sends FIN-ACK and its own FIN 237 | // - Sender receives nothing, because link is down 238 | // - Recipient is only missing its inbound FIN-ACK and closes with success 239 | // - Sender is missing the recipient's FIN and times out with failure 240 | rx_link_switch.store(false, Ordering::SeqCst); 241 | 242 | match timeout(EXPECTED_IDLE_TIMEOUT * 2, send_stream.close()).await { 243 | Ok(Ok(_)) => panic!("Send stream closed 
successfully, but should have timed out"), 244 | Ok(Err(e)) => { 245 | // The stream must time out when waiting to close, because recipient's FIN is missing 246 | assert_eq!(e.kind(), ErrorKind::TimedOut); 247 | } 248 | Err(e) => { 249 | panic!("The send stream did not timeout on close() fast enough, giving up after: {e:?}") 250 | } 251 | }; 252 | 253 | let mut recv_stream = recv_stream_handle.await.unwrap(); 254 | 255 | // Since switching to one-way FIN-ACK, closing after reading is not allowed. We only explicitly 256 | // close after write() now, and close after reading should error. 257 | match timeout(EXPECTED_IDLE_TIMEOUT * 2, recv_stream.close()).await { 258 | Ok(Ok(_)) => panic!("Closing after reading should have errored, but succeeded"), 259 | Ok(Err(e)) => { 260 | // The stream will already be disconnected by the read_to_eof() call, so we expect a 261 | // NotConnected error here. 262 | assert_eq!(e.kind(), ErrorKind::NotConnected); 263 | } 264 | Err(e) => { 265 | panic!("The recv stream did not timeout on close() fast enough, giving up after: {e:?}") 266 | } 267 | }; 268 | } 269 | 270 | // Test that data is delivered successfully, even if the original SYN-STATE is dropped 271 | // 272 | // At the time of writing, a bug in this scenario causes the first bytes (2048 of them) to be 273 | // silently lost. The recipient thinks it received everything, but is missing the first bytes of 274 | // the transfer. When the recipient, who started the connection, times out the original SYN, it 275 | // resends. The sender has already sent some data. When the bug is active, the resent STATE 276 | // packet in response to the SYN uses the sequence number after incrementing from the previously 277 | // sent state data. This causes the recipient to ignore all data sent previously. 
278 | #[tokio::test(start_paused = true)] 279 | async fn test_data_valid_when_resending_syn_state_response() { 280 | let _ = tracing_subscriber::fmt::try_init(); 281 | 282 | let conn_config = ConnectionConfig::default(); 283 | 284 | let ((connector_socket, connector_cid), (acceptor_socket, acceptor_cid)) = 285 | testutils::build_link_drops_first_n_sent_pair(1); 286 | 287 | let acceptor = UtpSocket::with_socket(acceptor_socket); 288 | let acceptor = Arc::new(acceptor); 289 | 290 | let connector = UtpSocket::with_socket(connector_socket); 291 | let connector = Arc::new(connector); 292 | 293 | // It's important for this scenario that the data recipient is the one creating the connection. 294 | let acceptor_one = Arc::clone(&acceptor); 295 | let acceptor_one_handle = tokio::spawn(async move { 296 | acceptor_one 297 | .accept_with_cid( 298 | acceptor_cid, 299 | Peer::new_id(acceptor_cid.peer_id), 300 | conn_config, 301 | ) 302 | .await 303 | .unwrap() 304 | }); 305 | 306 | // Keep a clone of the socket so that it doesn't drop when moved into the task. 307 | // Dropping it causes all connections to exit. 308 | let connector_one = Arc::clone(&connector); 309 | let connector_one_handle = tokio::spawn(async move { 310 | connector_one 311 | .connect_with_cid( 312 | connector_cid, 313 | Peer::new_id(connector_cid.peer_id), 314 | conn_config, 315 | ) 316 | .await 317 | .unwrap() 318 | }); 319 | 320 | let mut acceptor_stream = acceptor_one_handle.await.unwrap(); 321 | tracing::debug!("Acceptor stream established"); 322 | // Must not wait for connection to complete before writing data here, so that we can trigger 323 | // the bug. 
324 | 325 | // data to send, must be longer than 2048 bytes, which are lost in the bug scenario 326 | const DATA_LEN: usize = 9000; 327 | let data = [0xa5; DATA_LEN]; 328 | 329 | // send data 330 | let acceptor_stream_handle = tokio::spawn(async move { 331 | match acceptor_stream.write(&data).await { 332 | Ok(written_len) => assert_eq!(written_len, DATA_LEN), 333 | Err(err) => panic!("Error sending data: {err:?}"), 334 | }; 335 | acceptor_stream 336 | }); 337 | 338 | // Finally, we can wait for the connection to complete. If we await this any earlier in the 339 | // test, then the data won't be sent, and we don't trigger the bug scenario. 340 | let mut connector_stream = connector_one_handle.await.unwrap(); 341 | 342 | // Test that complete data is received 343 | let connector_stream_handle = tokio::spawn(async move { 344 | let mut read_buf = vec![]; 345 | let _ = connector_stream.read_to_eof(&mut read_buf).await.unwrap(); 346 | assert_eq!(read_buf.len(), data.len()); 347 | assert_eq!(read_buf, data.to_vec()); 348 | connector_stream 349 | }); 350 | 351 | // Complete the streams 352 | let mut acceptor_stream = acceptor_stream_handle.await.unwrap(); 353 | acceptor_stream.close().await.unwrap(); 354 | connector_stream_handle.await.unwrap(); 355 | } 356 | --------------------------------------------------------------------------------