├── .circleci
│   └── config.yml
├── .gitignore
├── Cargo.toml
├── LICENSE
├── Makefile
├── README.md
├── src
│   ├── cid.rs
│   ├── congestion.rs
│   ├── conn.rs
│   ├── event.rs
│   ├── lib.rs
│   ├── packet.rs
│   ├── peer.rs
│   ├── recv.rs
│   ├── send.rs
│   ├── sent.rs
│   ├── seq.rs
│   ├── socket.rs
│   ├── stream.rs
│   ├── testutils.rs
│   ├── time.rs
│   └── udp.rs
└── tests
    ├── socket.rs
    └── stream.rs
/.circleci/config.yml:
--------------------------------------------------------------------------------
1 | version: 2.1
2 | orbs:
3 | rust: circleci/rust@1.6.0
4 | workflows:
5 | prod:
6 | jobs:
7 | - rust/lint-test-build:
8 | clippy_arguments: '--all-targets --all-features -- --deny warnings'
9 | release: true
10 | version: 1.81.0
11 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | /target
2 | /Cargo.lock
3 |
--------------------------------------------------------------------------------
/Cargo.toml:
--------------------------------------------------------------------------------
1 | [package]
2 | name = "utp-rs"
3 | version = "0.1.0-alpha.17"
4 | edition = "2021"
5 | authors = ["Jacob Kaufmann", "Jason Carver"]
6 | description = "uTorrent transport protocol"
7 | readme = "README.md"
8 | repository = "https://github.com/ethereum/utp/"
9 | license = "MIT"
10 | keywords = ["utp"]
11 | categories = ["network-programming"]
12 |
13 | # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
14 |
15 | [dependencies]
16 | async-trait = "0.1.64"
17 | delay_map = "0.3.0"
18 | futures = "0.3.26"
19 | rand = "0.8.5"
20 | tokio = { version = "1.25.0", features = ["io-util", "rt-multi-thread", "macros", "net", "sync", "time"] }
21 | tracing = { version = "0.1.37", features = ["std", "attributes", "log"] }
22 |
23 | [dev-dependencies]
24 | quickcheck = "1.0.3"
25 | tokio = { version = "1.25.0", features = ["test-util"] }
26 | tracing-subscriber = "0.3.16"
27 |
28 | [profile.test]
29 | opt-level = 3
30 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Jacob Kaufmann
4 | Copyright (c) 2023-2025 Trin Contributors
5 |
6 | Permission is hereby granted, free of charge, to any person obtaining a copy
7 | of this software and associated documentation files (the "Software"), to deal
8 | in the Software without restriction, including without limitation the rights
9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 | copies of the Software, and to permit persons to whom the Software is
11 | furnished to do so, subject to the following conditions:
12 |
13 | The above copyright notice and this permission notice shall be included in all
14 | copies or substantial portions of the Software.
15 |
16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 | SOFTWARE.
23 |
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 | .PHONY: lint
2 | lint: # Run clippy and rustfmt
3 | cargo fmt --all
4 | cargo clippy --all --all-targets --all-features --no-deps -- --deny warnings
5 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # utp
2 | A Rust library for the [uTorrent transport protocol (uTP)](https://www.bittorrent.org/beps/bep_0029.html).
3 |
4 | ## 🚧 WARNING: UNDER CONSTRUCTION 🚧
5 | This library is currently unstable, with known issues. Use at your own discretion.
6 |
7 | # Usage
8 |
9 | ```rust
10 | use std::net::SocketAddr;
11 |
12 | use utp_rs::conn::ConnectionConfig;
13 | use utp_rs::socket::UtpSocket;
14 | use utp_rs::udp::AsyncUdpSocket;
15 |
16 | #[tokio::main]
17 | async fn main() {
18 |     // bind a standard UDP socket. (transport is over a `tokio::net::UdpSocket`.)
19 |     let socket_addr = SocketAddr::from(([127, 0, 0, 1], 3400));
20 |     let udp_based_socket = UtpSocket::bind(socket_addr).await.unwrap();
21 |
22 |     // bind a custom UDP socket. here we assume `CustomSocket` implements `AsyncUdpSocket`.
23 |     let async_udp_socket = CustomSocket::new(..);
24 |     let custom_socket = UtpSocket::with_socket(async_udp_socket).await.unwrap();
25 |
26 |     // connect to a remote peer over uTP.
27 |     let remote = SocketAddr::from(..);
28 |     let config = ConnectionConfig::default();
29 |     let mut stream = udp_based_socket.connect(remote, config).await.unwrap();
30 |
31 |     // write data to the remote peer over the stream.
32 |     let data = vec![0xef; 2048];
33 |     let n = stream.write(data.as_slice()).await.unwrap();
34 |
35 |     // accept a connection from a remote peer.
36 |     let config = ConnectionConfig::default();
37 |     let mut stream = udp_based_socket.accept(config).await.unwrap();
38 |
39 |     // read data from the remote peer until the peer indicates there is no data left to write.
40 |     let mut data = vec![];
41 |     let n = stream.read_to_eof(&mut data).await.unwrap();
42 | }
43 | ```
44 |
45 |
--------------------------------------------------------------------------------
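Building on the README snippet above, here is a minimal end-to-end sketch that exercises only the calls shown there (bind, connect, accept, write, read_to_eof). The localhost ports, payload size, and the assumption that a `close` method flushes and signals end-of-stream are illustrative, not a definitive usage guide.

```rust
use std::net::SocketAddr;

use utp_rs::conn::ConnectionConfig;
use utp_rs::socket::UtpSocket;

#[tokio::main]
async fn main() {
    // Two uTP sockets on localhost (ports are illustrative).
    let addr_a = SocketAddr::from(([127, 0, 0, 1], 3400));
    let addr_b = SocketAddr::from(([127, 0, 0, 1], 3401));
    let socket_a = UtpSocket::bind(addr_a).await.unwrap();
    let socket_b = UtpSocket::bind(addr_b).await.unwrap();

    // One task accepts a connection and reads until the peer signals EOF.
    let reader = tokio::spawn(async move {
        let mut stream = socket_b.accept(ConnectionConfig::default()).await.unwrap();
        let mut data = vec![];
        stream.read_to_eof(&mut data).await.unwrap();
        data
    });

    // The other side connects, writes a payload, and closes the stream.
    let mut stream = socket_a
        .connect(addr_b, ConnectionConfig::default())
        .await
        .unwrap();
    stream.write(&[0xef; 2048]).await.unwrap();
    // Assumption: `close` flushes outstanding data and sends the FIN that ends `read_to_eof`.
    stream.close().await.unwrap();

    let received = reader.await.unwrap();
    assert_eq!(received.len(), 2048);
}
```
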
/src/cid.rs:
--------------------------------------------------------------------------------
1 | #[derive(Clone, Copy, Debug, Hash, Eq, PartialEq)]
2 | pub struct ConnectionId<P> {
3 | pub send: u16,
4 | pub recv: u16,
5 | pub peer_id: P,
6 | }
7 |
--------------------------------------------------------------------------------
/src/congestion.rs:
--------------------------------------------------------------------------------
1 | use std::cmp;
2 | use std::collections::{BinaryHeap, HashMap};
3 | use std::time::{Duration, Instant};
4 |
5 | pub(crate) const DEFAULT_TARGET_MICROS: u32 = 100_000;
6 | pub(crate) const DEFAULT_INITIAL_TIMEOUT: Duration = Duration::from_secs(1);
7 | pub(crate) const DEFAULT_MIN_TIMEOUT: Duration = Duration::from_millis(500);
8 | pub(crate) const DEFAULT_MAX_TIMEOUT: Duration = Duration::from_secs(60);
9 | pub(crate) const DEFAULT_MAX_PACKET_SIZE_BYTES: u32 = 1024;
10 | const DEFAULT_GAIN: f32 = 1.0;
11 | const DEFAULT_DELAY_WINDOW: Duration = Duration::from_secs(120);
12 |
13 | #[derive(Clone, Debug)]
14 | struct Packet {
15 | size_bytes: u32,
16 | num_transmissions: u32,
17 | acked: bool,
18 | }
19 |
20 | #[derive(Clone, Debug)]
21 | pub enum Transmit {
22 | Initial { bytes: u32 },
23 | Retransmission,
24 | }
25 |
26 | #[derive(Clone, Debug)]
27 | pub struct Ack {
28 | pub delay: Duration,
29 | pub rtt: Duration,
30 | pub received_at: Instant,
31 | }
32 |
33 | #[derive(Clone, Debug, Eq, PartialEq)]
34 | pub enum Error {
35 | InsufficientWindowSize,
36 | UnknownSeqNum,
37 | DuplicateTransmission,
38 | }
39 |
40 | #[derive(Clone, Debug)]
41 | pub struct Config {
42 | pub target_delay_micros: u32,
43 | pub initial_timeout: Duration,
44 | pub min_timeout: Duration,
45 | pub max_timeout: Duration,
46 | pub max_packet_size_bytes: u32,
47 | pub max_window_size_inc_bytes: u32,
48 | pub gain: f32,
49 | pub delay_window: Duration,
50 | }
51 |
52 | impl Default for Config {
53 | fn default() -> Self {
54 | Self {
55 | target_delay_micros: DEFAULT_TARGET_MICROS,
56 | initial_timeout: DEFAULT_INITIAL_TIMEOUT,
57 | min_timeout: DEFAULT_MIN_TIMEOUT,
58 | max_timeout: DEFAULT_MAX_TIMEOUT,
59 | max_packet_size_bytes: DEFAULT_MAX_PACKET_SIZE_BYTES,
60 | max_window_size_inc_bytes: DEFAULT_MAX_PACKET_SIZE_BYTES,
61 | gain: DEFAULT_GAIN,
62 | delay_window: DEFAULT_DELAY_WINDOW,
63 | }
64 | }
65 | }
66 |
67 | #[derive(Clone, Debug)]
68 | pub struct Controller {
69 | target_delay_micros: u32,
70 | timeout: Duration,
71 | min_timeout: Duration,
72 | max_timeout: Duration,
73 | window_size_bytes: u32,
74 | max_window_size_bytes: u32,
75 | min_window_size_bytes: u32,
76 | max_window_size_inc_bytes: u32,
77 | gain: f32,
78 | rtt: Duration,
79 | rtt_variance_micros: u64,
80 | transmissions: HashMap<u16, Packet>,
81 | delay_acc: DelayAccumulator,
82 | }
83 |
84 | impl Controller {
85 | /// Creates a `Controller` configured by `config`.
86 | pub fn new(config: Config) -> Self {
87 | Self {
88 | target_delay_micros: config.target_delay_micros,
89 | timeout: config.initial_timeout,
90 | min_timeout: config.min_timeout,
91 | max_timeout: config.max_timeout,
92 | window_size_bytes: 0,
93 | max_window_size_bytes: 2 * config.max_packet_size_bytes,
94 | min_window_size_bytes: 2 * config.max_packet_size_bytes,
95 | max_window_size_inc_bytes: config.max_window_size_inc_bytes,
96 | gain: config.gain,
97 | rtt: Duration::ZERO,
98 | rtt_variance_micros: 0,
99 | transmissions: HashMap::new(),
100 | delay_acc: DelayAccumulator::new(config.delay_window),
101 | }
102 | }
103 |
104 | /// Returns the congestion timeout.
105 | pub fn timeout(&self) -> Duration {
106 | self.timeout
107 | }
108 |
109 | /// Returns the number of bytes available in the congestion window.
110 | pub fn bytes_available_in_window(&self) -> u32 {
111 | // Use saturating arithmetic because the max window (capacity) may drop below the current
112 | // window (data in flight).
113 | self.max_window_size_bytes
114 | .saturating_sub(self.window_size_bytes)
115 | }
116 |
117 | /// Registers the transmission of a packet with the controller.
118 | pub fn on_transmit(&mut self, seq_num: u16, transmission: Transmit) -> Result<(), Error> {
119 | // If the transmission is an initial transmission, then record the transmission. If the
120 | // transmission is a retransmission, then increment the number of transmissions for that
121 | // record.
122 | match transmission {
123 | Transmit::Initial { bytes } => {
124 | if self.transmissions.contains_key(&seq_num) {
125 | return Err(Error::DuplicateTransmission);
126 | }
127 | self.transmissions.insert(
128 | seq_num,
129 | Packet {
130 | size_bytes: bytes,
131 | num_transmissions: 1,
132 | acked: false,
133 | },
134 | );
135 | }
136 | Transmit::Retransmission => {
137 | let packet = self
138 | .transmissions
139 | .get_mut(&seq_num)
140 | .ok_or(Error::UnknownSeqNum)?;
141 | packet.num_transmissions += 1;
142 | }
143 | };
144 |
145 | // The key-value pair should exist by logic above.
146 | let packet = self.transmissions.get(&seq_num).unwrap();
147 |
148 | // If this is the initial transmission of this packet, then increase the size of the
149 | // window. Return an error if there is not sufficient space.
150 | if packet.num_transmissions == 1 {
151 | if self.window_size_bytes + packet.size_bytes > self.max_window_size_bytes {
152 | return Err(Error::InsufficientWindowSize);
153 | }
154 | self.window_size_bytes += packet.size_bytes;
155 | }
156 |
157 | Ok(())
158 | }
159 |
160 | /// Registers a packet `Ack` with the controller.
161 | pub fn on_ack(&mut self, seq_num: u16, ack: Ack) -> Result<(), Error> {
162 | let packet = self
163 | .transmissions
164 | .get_mut(&seq_num)
165 | .ok_or(Error::UnknownSeqNum)?;
166 |
167 | // Mark the packet acknowledged. If the packet was already acknowledged, then short-circuit
168 | // return. There are no newly acknowledged bytes.
169 | if packet.acked {
170 | return Ok(());
171 | }
172 | packet.acked = true;
173 |
174 | let packet = packet.clone();
175 |
176 | // Add the delay to the accumulator.
177 | self.delay_acc.push(ack.delay, ack.received_at);
178 |
179 | // Adjust the maximum congestion window based on the delay of the acknowledged packet.
180 | // Bound the base delay (in microseconds) by `u32::MAX`. Bound the delay (in microseconds)
181 | // of the acknowledged packet by `u32::MAX`.
182 | let base_delay_micros = self
183 | .delay_acc
184 | .base_delay()
185 | .unwrap_or(Duration::ZERO)
186 | .as_micros();
187 | let base_delay_micros = u32::try_from(base_delay_micros).unwrap_or(u32::MAX);
188 | let packet_delay_micros = u32::try_from(ack.delay.as_micros()).unwrap_or(u32::MAX);
189 | let max_window_size_adjustment = compute_max_window_size_adjustment(
190 | self.target_delay_micros,
191 | base_delay_micros,
192 | packet_delay_micros,
193 | self.window_size_bytes,
194 | packet.size_bytes,
195 | self.max_window_size_inc_bytes,
196 | self.gain,
197 | );
198 | self.apply_max_window_size_adjustment(max_window_size_adjustment);
199 |
200 | // Adjust the current window size to account for the acknowledged packet.
201 | //
202 | // An overflow panic occurs if the window size is less than the packet size. This would
203 | // correspond to an invalid operation as the window size should account for the size of the
204 | // packet being acknowledged.
205 | self.window_size_bytes -= packet.size_bytes;
206 |
207 | // Only adjust the round trip time (RTT) estimation if the acknowledgement corresponds to
208 | // the first transmission. The congestion timeout is also adjusted each time the RTT
209 | // estimation is adjusted.
210 | if packet.num_transmissions == 1 {
211 | // Adjust round trip time variance.
212 | let rtt_var_adjustment = compute_rtt_variance_adjustment(
213 | self.rtt.as_micros(),
214 | self.rtt_variance_micros,
215 | ack.rtt.as_micros(),
216 | );
217 | if rtt_var_adjustment.is_negative() {
218 | self.rtt_variance_micros = self
219 | .rtt_variance_micros
220 | .saturating_sub(rtt_var_adjustment.unsigned_abs());
221 | } else {
222 | self.rtt_variance_micros = self
223 | .rtt_variance_micros
224 | .saturating_add(rtt_var_adjustment as u64);
225 | }
226 |
227 | // Adjust round trip time.
228 | let rtt_adjustment = compute_rtt_adjustment(self.rtt.as_micros(), ack.rtt.as_micros());
229 | if rtt_adjustment.is_negative() {
230 | self.rtt = self
231 | .rtt
232 | .saturating_sub(Duration::from_micros(rtt_adjustment.unsigned_abs()));
233 | } else {
234 | self.rtt = self
235 | .rtt
236 | .saturating_add(Duration::from_micros(rtt_adjustment as u64));
237 | }
238 |
239 | // Adjust congestion timeout.
240 | self.apply_timeout_adjustment();
241 | }
242 |
243 | Ok(())
244 | }
245 |
246 | /// Registers a lost packet with the controller.
247 | pub fn on_lost_packet(&mut self, seq_num: u16, retransmitting: bool) -> Result<(), Error> {
248 | let packet = self
249 | .transmissions
250 | .get(&seq_num)
251 | .ok_or(Error::UnknownSeqNum)?;
252 |
253 | self.max_window_size_bytes =
254 | cmp::max(self.max_window_size_bytes / 2, self.min_window_size_bytes);
255 |
256 | // If the packet is not to be retransmitted, then account for those lost bytes in the
257 | // congestion window.
258 | if !retransmitting {
259 | self.window_size_bytes -= packet.size_bytes;
260 | }
261 |
262 | Ok(())
263 | }
264 |
265 | /// Registers a timeout with the controller.
266 | pub fn on_timeout(&mut self) {
267 | self.max_window_size_bytes = self.min_window_size_bytes;
268 | self.timeout = cmp::min(self.timeout * 2, self.max_timeout);
269 | }
270 |
271 | /// Adjusts the maximum window (i.e. congestion window) by `adjustment`, keeping the size of
272 | /// the window within the configured interval, and not allowing it to grow by more than the
273 | /// configured maximum increment.
274 | fn apply_max_window_size_adjustment(&mut self, adjustment: i64) {
275 | // Apply the adjustment.
276 | let adj_max_window_size_bytes = i64::from(self.max_window_size_bytes) + adjustment;
277 |
278 | // The maximum congestion window must not fall below the minimum.
279 | let adj_max_window_size_bytes =
280 | cmp::max(adj_max_window_size_bytes as u32, self.min_window_size_bytes);
281 |
282 | // The maximum congestion window cannot increase by more than the configured maximum
283 | // increment.
284 | self.max_window_size_bytes = cmp::min(
285 | adj_max_window_size_bytes,
286 | self.max_window_size_bytes
287 | .saturating_add(self.max_window_size_inc_bytes),
288 | );
289 | }
290 |
291 | /// Adjusts the congestion timeout based on the current round trip time (RTT) estimate and the
292 | /// current RTT variance.
293 | ///
294 | /// The congestion timeout cannot fall below the configured minimum.
295 | fn apply_timeout_adjustment(&mut self) {
296 | // Do not let timeout go below minimum.
297 | self.timeout = cmp::max(
298 | self.rtt + Duration::from_micros(self.rtt_variance_micros * 4),
299 | self.min_timeout,
300 | );
301 |
302 | // Do not let timeout go above maximum.
303 | self.timeout = cmp::min(self.timeout, self.max_timeout)
304 | }
305 | }
306 |
307 | /// Returns the adjustment in bytes to the maximum window (i.e. congestion window) size based on
308 | /// the delta between the packet delay and the target delay and on the portion of the total
309 | /// in-flight bytes that the packet corresponds to.
310 | fn compute_max_window_size_adjustment(
311 | target_delay_micros: u32,
312 | base_delay_micros: u32,
313 | packet_delay_micros: u32,
314 | window_size_bytes: u32,
315 | packet_size_bytes: u32,
316 | max_window_size_inc_bytes: u32,
317 | gain: f32,
318 | ) -> i64 {
319 | // Adjust the delay based on the base delay.
320 | //
321 | // The `i64::try_from` should not fail because `packet_delay_micros` is a `u32`.
322 | //
323 | // The base delay adjustment should not panic because the base delay is non-negative and
324 | // should not be larger than the current delay.
325 | let delay_micros = i64::from(packet_delay_micros - base_delay_micros);
326 |
327 | let off_target_micros = i64::from(target_delay_micros) - delay_micros;
328 | let delay_factor = (off_target_micros as f64) / f64::from(target_delay_micros);
329 | let window_factor = f64::from(packet_size_bytes) / f64::from(window_size_bytes);
330 |
331 | let scaled_gain =
332 | f64::from(gain) * f64::from(max_window_size_inc_bytes) * delay_factor * window_factor;
333 |
334 | scaled_gain as i64
335 | }
336 |
337 | /// Returns the adjustment to the round trip time (RTT) estimate in microseconds based on the
338 | /// packet RTT and the current RTT estimate.
339 | fn compute_rtt_adjustment(rtt_micros: u128, packet_rtt_micros: u128) -> i64 {
340 | ((packet_rtt_micros as f64 - rtt_micros as f64) / 8.0) as i64
341 | }
342 |
343 | /// Returns the adjustment to round trip time (RTT) variance in microseconds based on the packet
344 | /// RTT, current RTT estimate, and current RTT variance.
345 | fn compute_rtt_variance_adjustment(
346 | rtt_micros: u128,
347 | rtt_variance_micros: u64,
348 | packet_rtt_micros: u128,
349 | ) -> i64 {
350 | let abs_delta_micros = rtt_micros.abs_diff(packet_rtt_micros);
351 |
352 | (((abs_delta_micros as f64) - (rtt_variance_micros as f64)) / 4.0) as i64
353 | }
354 |
355 | #[derive(Clone, Debug, Eq)]
356 | struct Delay {
357 | value: Duration,
358 | deadline: Instant,
359 | }
360 |
361 | impl PartialEq for Delay {
362 | fn eq(&self, other: &Self) -> bool {
363 | self.value == other.value
364 | }
365 | }
366 |
367 | impl PartialOrd for Delay {
368 | fn partial_cmp(&self, other: &Self) -> Option<cmp::Ordering> {
369 | Some(self.cmp(other))
370 | }
371 | }
372 |
373 | impl Ord for Delay {
374 | fn cmp(&self, other: &Self) -> cmp::Ordering {
375 | self.value.cmp(&other.value)
376 | }
377 | }
378 |
379 | #[derive(Clone, Debug)]
380 | struct DelayAccumulator {
381 | delays: BinaryHeap<cmp::Reverse<Delay>>,
382 | window: Duration,
383 | }
384 |
385 | impl DelayAccumulator {
386 | /// Creates a `DelayAccumulator` with a sliding window of length `window`.
387 | pub fn new(window: Duration) -> Self {
388 | Self {
389 | delays: BinaryHeap::new(),
390 | window,
391 | }
392 | }
393 |
394 | // TODO: Handle `received_at` from the future (i.e. beyond `Instant::now`).
395 | /// Pushes `delay` onto the accumulator set to remain in the sliding window based on
396 | /// `received_at`.
397 | pub fn push(&mut self, delay: Duration, received_at: Instant) {
398 | let delay = Delay {
399 | value: delay,
400 | deadline: received_at + self.window,
401 | };
402 | self.delays.push(cmp::Reverse(delay));
403 | }
404 |
405 | // TODO: Here, delay measurements that fall outside of the window are deleted lazily. Evaluate
406 | // a non-lazy alternative. The number of elements in the accumulator should not be too large.
407 | // The lazy solution also means that `base_delay` requires `&mut self`, which is not
408 | // preferable.
409 | /// Returns a baseline delay measure given the delay measurements present in the accumulator.
410 | /// Returns `None` if there are no delay measurements within the current sliding window.
411 | pub fn base_delay(&mut self) -> Option<Duration> {
412 | while let Some(min) = self.delays.peek() {
413 | // If the deadline of the delay has been reached, then remove the delay and continue to
414 | // the next iteration.
415 | if Instant::now() >= min.0.deadline {
416 | self.delays.pop();
417 | continue;
418 | }
419 |
420 | return Some(min.0.value);
421 | }
422 |
423 | // No delay exists within the window.
424 | None
425 | }
426 | }
427 |
428 | #[cfg(test)]
429 | mod tests {
430 | use super::*;
431 |
432 | mod controller {
433 | use super::*;
434 |
435 | #[test]
436 | fn on_transmit() {
437 | let mut ctrl = Controller::new(Config::default());
438 |
439 | let initial_timeout = ctrl.timeout();
440 |
441 | // Register the initial transmission of a packet with sequence number 1.
442 | let mut seq_num = 1;
443 | let packet_one_size_bytes = 32;
444 | let transmission = Transmit::Initial {
445 | bytes: packet_one_size_bytes,
446 | };
447 | ctrl.on_transmit(seq_num, transmission)
448 | .expect("transmission registration failed");
449 |
450 | let transmission_record = ctrl
451 | .transmissions
452 | .get(&seq_num)
453 | .expect("transmission not recorded");
454 | assert_eq!(transmission_record.size_bytes, packet_one_size_bytes);
455 | assert_eq!(transmission_record.num_transmissions, 1);
456 |
457 | assert_eq!(ctrl.window_size_bytes, packet_one_size_bytes);
458 |
459 | // Register the initial transmission of a packet with sequence number 2.
460 | seq_num = 2;
461 | let packet_two_size_bytes = 128;
462 | let transmission = Transmit::Initial {
463 | bytes: packet_two_size_bytes,
464 | };
465 | ctrl.on_transmit(seq_num, transmission)
466 | .expect("transmission registration failed");
467 |
468 | let transmission_record = ctrl
469 | .transmissions
470 | .get(&seq_num)
471 | .expect("transmission not recorded");
472 | assert_eq!(transmission_record.size_bytes, packet_two_size_bytes);
473 | assert_eq!(transmission_record.num_transmissions, 1);
474 | assert_eq!(
475 | ctrl.window_size_bytes,
476 | packet_one_size_bytes + packet_two_size_bytes,
477 | );
478 |
479 | // Register the retransmission of the packet with sequence number 2.
480 | ctrl.on_transmit(seq_num, Transmit::Retransmission)
481 | .expect("transmission registration failed");
482 |
483 | let transmission_record = ctrl
484 | .transmissions
485 | .get(&seq_num)
486 | .expect("transmission not recorded");
487 | assert_eq!(transmission_record.size_bytes, packet_two_size_bytes);
488 | assert_eq!(transmission_record.num_transmissions, 2);
489 | assert_eq!(
490 | ctrl.window_size_bytes,
491 | packet_one_size_bytes + packet_two_size_bytes,
492 | );
493 |
494 | assert_eq!(ctrl.timeout(), initial_timeout);
495 | }
496 |
497 | #[test]
498 | fn on_transmit_duplicate_transmission() {
499 | let mut ctrl = Controller::new(Config::default());
500 |
501 | // Register the initial transmission of a packet with sequence number 1.
502 | let seq_num = 1;
503 | let bytes = 32;
504 | let transmission = Transmit::Initial { bytes };
505 | ctrl.on_transmit(seq_num, transmission)
506 | .expect("transmission registration failed");
507 |
508 | assert_eq!(ctrl.window_size_bytes, bytes);
509 |
510 | // Register the initial transmission of the SAME packet.
511 | let transmission = Transmit::Initial { bytes };
512 | let result = ctrl.on_transmit(seq_num, transmission);
513 | assert_eq!(result, Err(Error::DuplicateTransmission));
514 |
515 | assert_eq!(ctrl.window_size_bytes, bytes);
516 | }
517 |
518 | #[test]
519 | fn on_transmit_unknown_seq_num() {
520 | let mut ctrl = Controller::new(Config::default());
521 |
522 | // Register the retransmission of the packet with sequence number 1.
523 | let seq_num = 1;
524 | let result = ctrl.on_transmit(seq_num, Transmit::Retransmission);
525 | assert_eq!(result, Err(Error::UnknownSeqNum));
526 |
527 | assert_eq!(ctrl.window_size_bytes, 0);
528 | }
529 |
530 | #[test]
531 | fn on_transmit_insufficient_window_size() {
532 | let mut ctrl = Controller::new(Config::default());
533 |
534 | // Register the transmission of a packet with sequence number 1 whose size EXCEEDS the
535 | // maximum window size.
536 | let seq_num = 1;
537 | let bytes = ctrl.max_window_size_bytes + 1;
538 | let result = ctrl.on_transmit(seq_num, Transmit::Initial { bytes });
539 | assert_eq!(result, Err(Error::InsufficientWindowSize));
540 |
541 | assert_eq!(ctrl.window_size_bytes, 0);
542 | }
543 |
544 | #[test]
545 | fn on_ack() {
546 | let mut ctrl = Controller::new(Config::default());
547 |
548 | // Register the initial transmission of a packet with sequence number 1.
549 | let seq_num = 1;
550 | let bytes = 32;
551 | let transmission = Transmit::Initial { bytes };
552 | ctrl.on_transmit(seq_num, transmission)
553 | .expect("transmission registration failed");
554 |
555 | // Register the acknowledgement for the packet with sequence number 1.
556 | let ack_delay = Duration::from_millis(150);
557 | let ack_rtt = Duration::from_millis(300);
558 | let ack_received_at = Instant::now();
559 | let ack = Ack {
560 | delay: ack_delay,
561 | rtt: ack_rtt,
562 | received_at: ack_received_at,
563 | };
564 | ctrl.on_ack(seq_num, ack).expect("ack registration failed");
565 |
566 | assert_eq!(
567 | ctrl.delay_acc
568 | .base_delay()
569 | .expect("delay not pushed into accumulator"),
570 | ack_delay,
571 | );
572 |
573 | // TODO: max window
574 |
575 | assert_eq!(ctrl.window_size_bytes, 0);
576 |
577 | // TODO: RTT variance
578 |
579 | // TODO: RTT
580 |
581 | assert!(ctrl.timeout() >= ctrl.min_timeout);
582 | }
583 |
584 | #[test]
585 | fn on_ack_unknown_seq_num() {
586 | let mut ctrl = Controller::new(Config::default());
587 |
588 | // Register the acknowledgement for the packet with sequence number 1.
589 | let seq_num = 1;
590 | let ack_delay = Duration::from_millis(150);
591 | let ack_rtt = Duration::from_millis(300);
592 | let ack_received_at = Instant::now();
593 | let ack = Ack {
594 | delay: ack_delay,
595 | rtt: ack_rtt,
596 | received_at: ack_received_at,
597 | };
598 | let result = ctrl.on_ack(seq_num, ack);
599 | assert_eq!(result, Err(Error::UnknownSeqNum));
600 | }
601 |
602 | #[test]
603 | fn on_lost_packet_retransmitting() {
604 | let mut ctrl = Controller::new(Config::default());
605 |
606 | let initial_max_window_size_bytes = ctrl.min_window_size_bytes * 10;
607 | ctrl.max_window_size_bytes = initial_max_window_size_bytes;
608 |
609 | // Register the initial transmission of a packet with sequence number 1.
610 | let seq_num = 1;
611 | let bytes = 32;
612 | let transmission = Transmit::Initial { bytes };
613 | ctrl.on_transmit(seq_num, transmission)
614 | .expect("transmission registration failed");
615 |
616 | assert_eq!(ctrl.window_size_bytes, bytes);
617 |
618 | // Register the loss of the packet with sequence number 1. Specify that we WILL attempt
619 | // to retransmit.
620 | ctrl.on_lost_packet(seq_num, true)
621 | .expect("lost packet registration failed");
622 | assert_eq!(ctrl.window_size_bytes, bytes);
623 | assert!(ctrl.max_window_size_bytes >= ctrl.min_window_size_bytes);
624 | assert_eq!(
625 | ctrl.max_window_size_bytes,
626 | initial_max_window_size_bytes / 2,
627 | );
628 | }
629 |
630 | #[test]
631 | fn on_lost_packet_not_retransmitting() {
632 | let mut ctrl = Controller::new(Config::default());
633 |
634 | let initial_max_window_size_bytes = ctrl.min_window_size_bytes * 10;
635 | ctrl.max_window_size_bytes = initial_max_window_size_bytes;
636 |
637 | // Register the initial transmission of a packet with sequence number 1.
638 | let seq_num = 1;
639 | let bytes = 32;
640 | let transmission = Transmit::Initial { bytes };
641 | ctrl.on_transmit(seq_num, transmission)
642 | .expect("transmission registration failed");
643 |
644 | assert_eq!(ctrl.window_size_bytes, bytes);
645 |
646 | // Register the loss of the packet with sequence number 1. Specify that we WILL NOT
647 | // attempt to retransmit.
648 | ctrl.on_lost_packet(seq_num, false)
649 | .expect("lost packet registration failed");
650 | assert_eq!(ctrl.window_size_bytes, 0);
651 | assert!(ctrl.max_window_size_bytes >= ctrl.min_window_size_bytes);
652 | assert_eq!(
653 | ctrl.max_window_size_bytes,
654 | initial_max_window_size_bytes / 2,
655 | );
656 | }
657 |
658 | #[test]
659 | fn on_lost_packet_unknown_seq_num() {
660 | let mut ctrl = Controller::new(Config::default());
661 |
662 | let initial_max_window_size_bytes = ctrl.min_window_size_bytes * 10;
663 | ctrl.max_window_size_bytes = initial_max_window_size_bytes;
664 |
665 | // Register the loss of the packet with sequence number 1.
666 | let seq_num = 1;
667 | let result = ctrl.on_lost_packet(seq_num, false);
668 | assert_eq!(result, Err(Error::UnknownSeqNum));
669 | assert_eq!(ctrl.window_size_bytes, 0);
670 | assert_eq!(ctrl.max_window_size_bytes, initial_max_window_size_bytes);
671 | }
672 |
673 | #[test]
674 | fn on_timeout() {
675 | let mut ctrl = Controller::new(Config::default());
676 |
677 | let initial_max_window_size_bytes = ctrl.min_window_size_bytes * 10;
678 | ctrl.max_window_size_bytes = initial_max_window_size_bytes;
679 |
680 | let initial_timeout = ctrl.timeout();
681 |
682 | // Register a timeout.
683 | ctrl.on_timeout();
684 | assert_eq!(ctrl.max_window_size_bytes, ctrl.min_window_size_bytes);
685 | assert_eq!(ctrl.timeout, initial_timeout * 2);
686 | }
687 |
688 | #[test]
689 | fn on_timeout_not_exceed_max() {
690 | const MAX_TIMEOUT: Duration = Duration::from_secs(3);
691 | let config = Config {
692 | initial_timeout: Duration::from_secs(2),
693 | max_timeout: MAX_TIMEOUT,
694 | ..Default::default()
695 | };
696 |
697 | let mut ctrl = Controller::new(config);
698 |
699 | // Register a timeout.
700 | ctrl.on_timeout();
701 | assert_eq!(ctrl.timeout, MAX_TIMEOUT);
702 | }
703 | }
704 |
705 | mod delay_accumulator {
706 | use super::*;
707 |
708 | #[test]
709 | fn push() {
710 | let window = Duration::from_millis(100);
711 | let mut acc = DelayAccumulator::new(window);
712 |
713 | let delay = Duration::from_millis(50);
714 | let delay_received_at = Instant::now();
715 | acc.push(delay, delay_received_at);
716 |
717 | let item = acc
718 | .delays
719 | .peek()
720 | .expect("delay not pushed onto accumulator");
721 | assert_eq!(item.0.value, delay);
722 | assert_eq!(item.0.deadline, delay_received_at + window);
723 | }
724 |
725 | #[test]
726 | fn base_delay() {
727 | let window = Duration::from_millis(100);
728 | let mut acc = DelayAccumulator::new(window);
729 |
730 | let delay_small = Duration::from_millis(50);
731 | let delay_small_received_at = Instant::now();
732 | acc.push(delay_small, delay_small_received_at);
733 |
734 | let delay_smaller = Duration::from_millis(25);
735 | let delay_smaller_received_at = Instant::now();
736 | acc.push(delay_smaller, delay_smaller_received_at);
737 |
738 | let delay_smallest = Duration::from_millis(5);
739 | let delay_smallest_received_at = Instant::now();
740 | acc.push(delay_smallest, delay_smallest_received_at);
741 |
742 | let delay_expired = Duration::from_millis(1);
743 | let delay_expired_received_at = Instant::now() - window;
744 | acc.push(delay_expired, delay_expired_received_at);
745 |
746 | // Check that all delays are present within the accumulator.
747 | assert_eq!(acc.delays.len(), 4);
748 |
749 | let base_delay = acc
750 | .base_delay()
751 | .expect("base delay not present in accumulator");
752 | assert_eq!(base_delay, delay_smallest);
753 |
754 | // Check that the expired delay was popped.
755 | assert_eq!(acc.delays.len(), 3);
756 | }
757 |
758 | #[test]
759 | fn base_delay_empty() {
760 | let window = Duration::from_millis(100);
761 | let mut acc = DelayAccumulator::new(window);
762 |
763 | let base_delay = acc.base_delay();
764 | assert!(base_delay.is_none());
765 | }
766 | }
767 | }
768 |
--------------------------------------------------------------------------------
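To make the window-size adjustment in `compute_max_window_size_adjustment` above concrete, here is a standalone numeric walkthrough (not part of the crate; the delay and in-flight figures are made up, while the target, gain, and increment are the defaults from this file):

```rust
fn main() {
    // Defaults from congestion.rs: target delay 100_000 µs, gain 1.0,
    // max window increment 1024 bytes.
    let target_delay_micros = 100_000.0_f64;

    // Suppose the base delay is 10_000 µs and the acknowledged packet reported
    // a delay of 60_000 µs, so the filtered one-way delay is 50_000 µs.
    let delay_micros = 60_000.0 - 10_000.0;
    let off_target_micros = target_delay_micros - delay_micros; // 50_000 µs under target

    // The acknowledged 1024-byte packet accounts for 1/8 of the 8192 bytes in flight.
    let delay_factor = off_target_micros / target_delay_micros; // 0.5
    let window_factor = 1024.0 / 8192.0; // 0.125

    let scaled_gain = 1.0 * 1024.0 * delay_factor * window_factor;
    assert_eq!(scaled_gain as i64, 64); // the congestion window grows by 64 bytes
    println!("max window adjustment: {} bytes", scaled_gain as i64);
}
```
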
/src/event.rs:
--------------------------------------------------------------------------------
1 | use crate::cid::ConnectionId;
2 | use crate::packet::Packet;
3 | use crate::peer::{ConnectionPeer, Peer};
4 |
5 | #[derive(Clone, Debug)]
6 | pub enum StreamEvent {
7 | Incoming(Packet),
8 | Shutdown,
9 | }
10 |
11 | #[derive(Clone, Debug)]
12 | pub enum SocketEvent<P: ConnectionPeer> {
13 | Outgoing((Packet, Peer<P>)),
14 | Shutdown(ConnectionId<P::Id>),
15 | }
16 |
--------------------------------------------------------------------------------
/src/lib.rs:
--------------------------------------------------------------------------------
1 | pub mod cid;
2 | pub mod congestion;
3 | pub mod conn;
4 | pub mod event;
5 | pub mod packet;
6 | pub mod peer;
7 | pub mod recv;
8 | pub mod send;
9 | pub mod sent;
10 | pub mod seq;
11 | pub mod socket;
12 | pub mod stream;
13 | pub mod testutils;
14 | pub mod time;
15 | pub mod udp;
16 |
--------------------------------------------------------------------------------
/src/packet.rs:
--------------------------------------------------------------------------------
1 | use std::fmt;
2 | use std::fmt::{Debug, Formatter};
3 |
4 | /// Size of an encoded uTP header in bytes.
5 | const PACKET_HEADER_LEN: usize = 20;
6 |
7 | /// Size of a Selective ACK segment in bits.
8 | const SELECTIVE_ACK_BITS: usize = 32;
9 |
10 | /// Size of an extension identifier in bytes.
11 | const EXTENSION_TYPE_LEN: usize = 1;
12 |
13 | /// Size of an extension length specifier in bytes.
14 | const EXTENSION_LEN_LEN: usize = 1;
15 |
16 | #[derive(Copy, Clone, Debug)]
17 | pub struct InvalidPacketType;
18 |
19 | impl fmt::Display for InvalidPacketType {
20 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
21 | write!(f, "invalid uTP packet type")
22 | }
23 | }
24 |
25 | #[derive(Copy, Clone, Debug)]
26 | pub struct InvalidVersion;
27 |
28 | impl fmt::Display for InvalidVersion {
29 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
30 | write!(f, "invalid uTP version")
31 | }
32 | }
33 |
34 | #[derive(Copy, Clone, Debug)]
35 | pub enum SelectiveAckError {
36 | InsufficientLen,
37 | InvalidLen,
38 | }
39 |
40 | impl From<SelectiveAckError> for PacketError {
41 | fn from(value: SelectiveAckError) -> Self {
42 | Self::InvalidExtension(ExtensionError::from(value))
43 | }
44 | }
45 |
46 | impl fmt::Display for SelectiveAckError {
47 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
48 | let s = match self {
49 | Self::InsufficientLen => "selective ACK len must be at least 32 bits",
50 | Self::InvalidLen => "selective ACK len must be a multiple of 32 bits",
51 | };
52 |
53 | write!(f, "{}", s)
54 | }
55 | }
56 |
57 | #[derive(Clone, Debug)]
58 | pub enum ExtensionError {
59 | InsufficientLen,
60 | InvalidSelectiveAck(SelectiveAckError),
61 | }
62 |
63 | impl From<SelectiveAckError> for ExtensionError {
64 | fn from(value: SelectiveAckError) -> Self {
65 | Self::InvalidSelectiveAck(value)
66 | }
67 | }
68 |
69 | impl fmt::Display for ExtensionError {
70 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
71 | let s: String = match self {
72 | Self::InsufficientLen => String::from("insufficient extension len"),
73 | Self::InvalidSelectiveAck(err) => err.to_string(),
74 | };
75 |
76 | write!(f, "{}", s)
77 | }
78 | }
79 |
80 | #[derive(Clone, Debug)]
81 | pub enum PacketHeaderError {
82 | InvalidPacketType(InvalidPacketType),
83 | InvalidVersion(InvalidVersion),
84 | InvalidExtension(ExtensionError),
85 | InvalidLen,
86 | }
87 |
88 | impl From<InvalidPacketType> for PacketHeaderError {
89 | fn from(value: InvalidPacketType) -> Self {
90 | Self::InvalidPacketType(value)
91 | }
92 | }
93 |
94 | impl From<InvalidVersion> for PacketHeaderError {
95 | fn from(value: InvalidVersion) -> Self {
96 | Self::InvalidVersion(value)
97 | }
98 | }
99 |
100 | impl From<ExtensionError> for PacketHeaderError {
101 | fn from(value: ExtensionError) -> Self {
102 | Self::InvalidExtension(value)
103 | }
104 | }
105 |
106 | #[derive(Clone, Debug)]
107 | pub enum PacketError {
108 | InvalidHeader(PacketHeaderError),
109 | InvalidExtension(ExtensionError),
110 | InvalidLen,
111 | EmptyDataPayload,
112 | }
113 |
114 | impl From<PacketHeaderError> for PacketError {
115 | fn from(value: PacketHeaderError) -> Self {
116 | Self::InvalidHeader(value)
117 | }
118 | }
119 |
120 | impl From<ExtensionError> for PacketError {
121 | fn from(value: ExtensionError) -> Self {
122 | Self::InvalidExtension(value)
123 | }
124 | }
125 |
126 | #[derive(Copy, Clone, Debug, Eq, PartialEq)]
127 | pub enum PacketType {
128 | Data,
129 | Fin,
130 | State,
131 | Reset,
132 | Syn,
133 | }
134 |
135 | impl fmt::Display for PacketType {
136 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
137 | let s = match self {
138 | Self::Data => "ST_DATA".to_string(),
139 | Self::Fin => "ST_FIN".to_string(),
140 | Self::State => "ST_STATE".to_string(),
141 | Self::Reset => "ST_RESET".to_string(),
142 | Self::Syn => "ST_SYN".to_string(),
143 | };
144 |
145 | write!(f, "{s}")
146 | }
147 | }
148 |
149 | impl TryFrom<u8> for PacketType {
150 | type Error = InvalidPacketType;
151 |
152 | fn try_from(value: u8) -> Result<Self, Self::Error> {
153 | match value {
154 | 0 => Ok(Self::Data),
155 | 1 => Ok(Self::Fin),
156 | 2 => Ok(Self::State),
157 | 3 => Ok(Self::Reset),
158 | 4 => Ok(Self::Syn),
159 | _ => Err(InvalidPacketType),
160 | }
161 | }
162 | }
163 |
164 | impl From<PacketType> for u8 {
165 | fn from(value: PacketType) -> u8 {
166 | match value {
167 | PacketType::Data => 0,
168 | PacketType::Fin => 1,
169 | PacketType::State => 2,
170 | PacketType::Reset => 3,
171 | PacketType::Syn => 4,
172 | }
173 | }
174 | }
175 |
176 | #[derive(Copy, Clone, Debug, Eq, PartialEq)]
177 | enum Version {
178 | One,
179 | }
180 |
181 | impl TryFrom<u8> for Version {
182 | type Error = InvalidVersion;
183 |
184 | fn try_from(value: u8) -> Result<Self, Self::Error> {
185 | match value {
186 | 1 => Ok(Self::One),
187 | _ => Err(InvalidVersion),
188 | }
189 | }
190 | }
191 |
192 | impl From<Version> for u8 {
193 | fn from(value: Version) -> u8 {
194 | match value {
195 | Version::One => 1,
196 | }
197 | }
198 | }
199 |
200 | #[derive(Copy, Clone, Debug, Eq, PartialEq)]
201 | enum Extension {
202 | None,
203 | SelectiveAck,
204 | Unknown(u8),
205 | }
206 |
207 | impl From<u8> for Extension {
208 | fn from(value: u8) -> Self {
209 | match value {
210 | 0 => Self::None,
211 | 1 => Self::SelectiveAck,
212 | unknown => Self::Unknown(unknown),
213 | }
214 | }
215 | }
216 |
217 | impl From<Extension> for u8 {
218 | fn from(value: Extension) -> u8 {
219 | match value {
220 | Extension::None => 0,
221 | Extension::SelectiveAck => 1,
222 | Extension::Unknown(ext) => ext,
223 | }
224 | }
225 | }
226 |
227 | #[derive(Clone, Debug, Eq, PartialEq)]
228 | struct PacketHeader {
229 | packet_type: PacketType,
230 | version: Version,
231 | extension: Extension,
232 | conn_id: u16,
233 | ts_micros: u32,
234 | ts_diff_micros: u32,
235 | window_size: u32,
236 | seq_num: u16,
237 | ack_num: u16,
238 | }
239 |
240 | impl PacketHeader {
241 | pub fn encode(&self) -> Vec<u8> {
242 | let mut bytes = vec![];
243 |
244 | let packet_type = Into::<u8>::into(self.packet_type).to_be_bytes()[0];
245 | let version = Into::<u8>::into(self.version).to_be_bytes()[0];
246 | let type_version = (packet_type << 4) | version;
247 | bytes.push(type_version);
248 |
249 | let extension = u8::from(self.extension);
250 | bytes.push(extension);
251 |
252 | bytes.extend(self.conn_id.to_be_bytes());
253 | bytes.extend(self.ts_micros.to_be_bytes());
254 | bytes.extend(self.ts_diff_micros.to_be_bytes());
255 | bytes.extend(self.window_size.to_be_bytes());
256 | bytes.extend(self.seq_num.to_be_bytes());
257 | bytes.extend(self.ack_num.to_be_bytes());
258 |
259 | bytes
260 | }
261 |
262 | pub fn decode(value: &[u8]) -> Result<Self, PacketHeaderError> {
263 | if value.len() != PACKET_HEADER_LEN {
264 | return Err(PacketHeaderError::InvalidLen);
265 | }
266 |
267 | let packet_type = value[0] >> 4;
268 | let packet_type = PacketType::try_from(u8::from_be(packet_type))?;
269 |
270 | let version = value[0] & 0x0F;
271 | let version = Version::try_from(u8::from_be(version))?;
272 |
273 | let extension = u8::from_be(value[1]);
274 | let extension = Extension::from(extension);
275 |
276 | let conn_id = [value[2], value[3]];
277 | let conn_id = u16::from_be_bytes(conn_id);
278 |
279 | let ts_micros = [value[4], value[5], value[6], value[7]];
280 | let ts_micros = u32::from_be_bytes(ts_micros);
281 |
282 | let ts_diff_micros = [value[8], value[9], value[10], value[11]];
283 | let ts_diff_micros = u32::from_be_bytes(ts_diff_micros);
284 |
285 | let window_size = [value[12], value[13], value[14], value[15]];
286 | let window_size = u32::from_be_bytes(window_size);
287 |
288 | let seq_num = [value[16], value[17]];
289 | let seq_num = u16::from_be_bytes(seq_num);
290 |
291 | let ack_num = [value[18], value[19]];
292 | let ack_num = u16::from_be_bytes(ack_num);
293 |
294 | Ok(Self {
295 | packet_type,
296 | version,
297 | extension,
298 | conn_id,
299 | ts_micros,
300 | ts_diff_micros,
301 | window_size,
302 | seq_num,
303 | ack_num,
304 | })
305 | }
306 | }
307 |
308 | #[derive(Clone, Debug, Eq, PartialEq)]
309 | pub struct SelectiveAck {
310 | acked: Vec<[bool; SELECTIVE_ACK_BITS]>,
311 | }
312 |
313 | impl fmt::Display for SelectiveAck {
314 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
315 | let mut s = String::new();
316 | for chunk in &self.acked {
317 | for bit in chunk {
318 | if *bit {
319 | s.push('1');
320 | } else {
321 | s.push('0');
322 | }
323 | }
324 | }
325 | write!(f, "{}", s)
326 | }
327 | }
328 |
329 | impl SelectiveAck {
330 | pub fn new(acked: Vec<bool>) -> Self {
331 | let chunks = acked.as_slice().chunks_exact(SELECTIVE_ACK_BITS);
332 | let remainder = chunks.remainder();
333 |
334 | let mut acked = Vec::new();
335 | for chunk in chunks {
336 | let mut fragment: [bool; SELECTIVE_ACK_BITS] = [false; SELECTIVE_ACK_BITS];
337 | fragment.copy_from_slice(chunk);
338 | acked.push(fragment);
339 | }
340 |
341 | if !remainder.is_empty() {
342 | let mut fragment: [bool; SELECTIVE_ACK_BITS] = [false; SELECTIVE_ACK_BITS];
343 | fragment[..remainder.len()].copy_from_slice(remainder);
344 | acked.push(fragment);
345 | }
346 |
347 | Self { acked }
348 | }
349 |
350 | /// Returns the length in bytes of the encoded Selective ACK.
351 | pub fn encoded_len(&self) -> usize {
352 | (self.acked.len() * SELECTIVE_ACK_BITS) / 8
353 | }
354 |
355 | pub fn acked(&self) -> Vec<bool> {
356 | self.acked
357 | .clone()
358 | .into_iter()
359 | .flatten()
360 | .collect::<Vec<bool>>()
361 | }
362 |
363 | pub fn encode(&self) -> Vec<u8> {
364 | let mut bitmask = vec![];
365 |
366 | for word in &self.acked {
367 | let chunks = word.as_slice().chunks_exact(8);
368 |
369 | for chunk in chunks {
370 | let mut byte = 0;
371 |
372 | byte |= u8::from(chunk[7]) << 7;
373 | byte |= u8::from(chunk[6]) << 6;
374 | byte |= u8::from(chunk[5]) << 5;
375 | byte |= u8::from(chunk[4]) << 4;
376 | byte |= u8::from(chunk[3]) << 3;
377 | byte |= u8::from(chunk[2]) << 2;
378 | byte |= u8::from(chunk[1]) << 1;
379 | byte |= u8::from(chunk[0]);
380 |
381 | bitmask.push(byte);
382 | }
383 | }
384 |
385 | bitmask
386 | }
387 |
388 | pub fn decode(value: &[u8]) -> Result<Self, SelectiveAckError> {
389 | if value.len() < 4 {
390 | return Err(SelectiveAckError::InsufficientLen);
391 | }
392 | if value.len() % 4 != 0 {
393 | return Err(SelectiveAckError::InvalidLen);
394 | }
395 |
396 | let mut acked: Vec<[bool; 32]> = vec![];
397 | let mut tmp = [false; 32];
398 | for (index, byte) in value.iter().enumerate() {
399 | tmp[(index * 8) % 32] = (*byte & 0b0000_0001) != 0;
400 | tmp[(index * 8 + 1) % 32] = (*byte & 0b0000_0010) != 0;
401 | tmp[(index * 8 + 2) % 32] = (*byte & 0b0000_0100) != 0;
402 | tmp[(index * 8 + 3) % 32] = (*byte & 0b0000_1000) != 0;
403 | tmp[(index * 8 + 4) % 32] = (*byte & 0b0001_0000) != 0;
404 | tmp[(index * 8 + 5) % 32] = (*byte & 0b0010_0000) != 0;
405 | tmp[(index * 8 + 6) % 32] = (*byte & 0b0100_0000) != 0;
406 | tmp[(index * 8 + 7) % 32] = (*byte & 0b1000_0000) != 0;
407 |
408 | if (index + 1) % 4 == 0 {
409 | acked.push(tmp);
410 | tmp = [false; 32];
411 | }
412 | }
413 |
414 | if value.len() % 4 != 0 {
415 | acked.push(tmp);
416 | }
417 |
418 | Ok(Self { acked })
419 | }
420 | }
421 |
422 | #[derive(Clone, Eq, PartialEq)]
423 | pub struct Packet {
424 | header: PacketHeader,
425 | selective_ack: Option<SelectiveAck>,
426 | payload: Vec<u8>,
427 | }
428 |
429 | impl Packet {
430 | pub fn packet_type(&self) -> PacketType {
431 | self.header.packet_type
432 | }
433 |
434 | pub fn conn_id(&self) -> u16 {
435 | self.header.conn_id
436 | }
437 |
438 | pub fn ts_micros(&self) -> u32 {
439 | self.header.ts_micros
440 | }
441 |
442 | pub fn ts_diff_micros(&self) -> u32 {
443 | self.header.ts_diff_micros
444 | }
445 |
446 | pub fn window_size(&self) -> u32 {
447 | self.header.window_size
448 | }
449 |
450 | pub fn seq_num(&self) -> u16 {
451 | self.header.seq_num
452 | }
453 |
454 | pub fn ack_num(&self) -> u16 {
455 | self.header.ack_num
456 | }
457 |
458 | pub fn selective_ack(&self) -> Option<&SelectiveAck> {
459 | self.selective_ack.as_ref()
460 | }
461 |
462 | pub fn payload(&self) -> &Vec<u8> {
463 | &self.payload
464 | }
465 |
466 | /// Returns the length in bytes of the encoded packet.
467 | pub fn encoded_len(&self) -> usize {
468 | let mut len = PACKET_HEADER_LEN;
469 | if let Some(ref sack) = self.selective_ack {
470 | len += sack.encoded_len() + EXTENSION_TYPE_LEN + EXTENSION_LEN_LEN;
471 | }
472 | len += self.payload.len();
473 |
474 | len
475 | }
476 |
477 | pub fn encode(&self) -> Vec<u8> {
478 | let mut bytes = vec![];
479 |
480 | bytes.extend(self.header.encode());
481 | if let Some(ack) = &self.selective_ack {
482 | let ack = ack.encode();
483 | bytes.push(Extension::None.into());
484 | bytes.push((ack.len() as u8).to_be_bytes()[0]);
485 | bytes.extend(ack);
486 | }
487 | bytes.extend_from_slice(self.payload.as_slice());
488 |
489 | bytes
490 | }
491 |
492 | pub fn decode(value: &[u8]) -> Result<Self, PacketError> {
493 | if value.len() < PACKET_HEADER_LEN {
494 | return Err(PacketError::InvalidHeader(PacketHeaderError::InvalidLen));
495 | }
496 |
497 | let mut header: [u8; PACKET_HEADER_LEN] = [0; PACKET_HEADER_LEN];
498 | header.copy_from_slice(&value[..PACKET_HEADER_LEN]);
499 | let header = PacketHeader::decode(&header)?;
500 |
501 | let (extensions, extensions_len) =
502 | Self::decode_raw_extensions(header.extension, &value[PACKET_HEADER_LEN..])?;
503 |
504 | // Look for the first (if any) Selective ACK extension, and attempt to decode it.
505 | // TODO: Evaluate whether duplicate extensions should constitute an error.
506 | let selective_ack = extensions
507 | .iter()
508 | .find(|(ext, _)| *ext == Extension::SelectiveAck);
509 | let selective_ack = match selective_ack {
510 | Some((_, data)) => Some(SelectiveAck::decode(data)?),
511 | None => None,
512 | };
513 |
514 | // TODO: Save all raw extensions and make them accessible. People should be able to use
515 | // custom extensions.
516 |
517 | // The packet payload is the remainder of the packet.
518 | let payload_start_index = PACKET_HEADER_LEN + extensions_len;
519 | let payload = if payload_start_index == value.len() {
520 | vec![]
521 | } else {
522 | value[payload_start_index..].to_vec()
523 | };
524 |
525 | if header.packet_type == PacketType::Data && payload.is_empty() {
526 | return Err(PacketError::EmptyDataPayload);
527 | }
528 |
529 | Ok(Self {
530 | header,
531 | selective_ack,
532 | payload,
533 | })
534 | }
535 |
536 | // TODO: Resolve disabled clippy lint.
537 | #[allow(clippy::type_complexity)]
538 | fn decode_raw_extensions(
539 | first_ext: Extension,
540 | data: &[u8],
541 | ) -> Result<(Vec<(Extension, Vec<u8>)>, usize), ExtensionError> {
542 | let mut ext = first_ext;
543 | let mut index = 0;
544 |
545 | let mut extensions: Vec<(Extension, Vec<u8>)> = Vec::new();
546 | while ext != Extension::None {
547 | if data[index..].len() < 2 {
548 | return Err(ExtensionError::InsufficientLen);
549 | }
550 |
551 | let next_ext = u8::from_be_bytes([data[index]]);
552 |
553 | let ext_len = u8::from_be_bytes([data[index + 1]]);
554 | let ext_len = usize::from(ext_len);
555 |
556 | let ext_start = index + 2;
557 | if data[ext_start..].len() < ext_len {
558 | return Err(ExtensionError::InsufficientLen);
559 | }
560 |
561 | let ext_data = data[ext_start..ext_start + ext_len].to_vec();
562 | extensions.push((ext, ext_data));
563 |
564 | ext = Extension::from(next_ext);
565 | index = ext_start + ext_len;
566 | }
567 |
568 | Ok((extensions, index))
569 | }
570 | }
571 |
572 | impl Debug for Packet {
573 | fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
574 | write!(f, "packet cid={} packetType={} seqNr={} ackNr={} timestamp={} timestampDiff={} remoteWindow={}",
575 | self.conn_id(),
576 | self.packet_type(),
577 | self.seq_num(),
578 | self.ack_num(),
579 | self.ts_micros(),
580 | self.ts_diff_micros(),
581 | self.window_size(),
582 | )
583 | }
584 | }
585 |
586 | #[derive(Clone, Debug)]
587 | pub struct PacketBuilder {
588 | packet_type: PacketType,
589 | conn_id: u16,
590 | ts_micros: u32,
591 | ts_diff_micros: u32,
592 | window_size: u32,
593 | seq_num: u16,
594 | ack_num: u16,
595 | selective_ack: Option<SelectiveAck>,
596 | payload: Option<Vec<u8>>,
597 | }
598 |
599 | impl PacketBuilder {
600 | pub fn new(
601 | packet_type: PacketType,
602 | conn_id: u16,
603 | ts_micros: u32,
604 | window_size: u32,
605 | seq_num: u16,
606 | ) -> Self {
607 | Self {
608 | packet_type,
609 | conn_id,
610 | ts_micros,
611 | ts_diff_micros: 0,
612 | window_size,
613 | seq_num,
614 | ack_num: 0,
615 | selective_ack: None,
616 | payload: None,
617 | }
618 | }
619 |
620 | pub fn ts_micros(mut self, ts_micros: u32) -> Self {
621 | self.ts_micros = ts_micros;
622 | self
623 | }
624 |
625 | pub fn ts_diff_micros(mut self, ts_diff_micros: u32) -> Self {
626 | self.ts_diff_micros = ts_diff_micros;
627 | self
628 | }
629 |
630 | pub fn window_size(mut self, window_size: u32) -> Self {
631 | self.window_size = window_size;
632 | self
633 | }
634 |
635 | pub fn ack_num(mut self, ack_num: u16) -> Self {
636 | self.ack_num = ack_num;
637 | self
638 | }
639 |
640 | pub fn selective_ack(mut self, selective_ack: Option<SelectiveAck>) -> Self {
641 | self.selective_ack = selective_ack;
642 | self
643 | }
644 |
645 | pub fn payload(mut self, payload: Vec<u8>) -> Self {
646 | self.payload = Some(payload);
647 | self
648 | }
649 |
650 | pub fn build(self) -> Packet {
651 | let extension = match self.selective_ack {
652 | Some(..) => Extension::SelectiveAck,
653 | None => Extension::None,
654 | };
655 |
656 | Packet {
657 | header: PacketHeader {
658 | packet_type: self.packet_type,
659 | version: Version::One,
660 | extension,
661 | conn_id: self.conn_id,
662 | ts_micros: self.ts_micros,
663 | ts_diff_micros: self.ts_diff_micros,
664 | window_size: self.window_size,
665 | seq_num: self.seq_num,
666 | ack_num: self.ack_num,
667 | },
668 | selective_ack: self.selective_ack,
669 | payload: self.payload.unwrap_or_default(),
670 | }
671 | }
672 | }
673 |
674 | impl From<Packet> for PacketBuilder {
675 | fn from(packet: Packet) -> Self {
676 | let payload = if packet.payload.is_empty() {
677 | None
678 | } else {
679 | Some(packet.payload)
680 | };
681 |
682 | Self {
683 | packet_type: packet.header.packet_type,
684 | conn_id: packet.header.conn_id,
685 | ts_micros: packet.header.ts_micros,
686 | ts_diff_micros: packet.header.ts_diff_micros,
687 | window_size: packet.header.window_size,
688 | seq_num: packet.header.seq_num,
689 | ack_num: packet.header.ack_num,
690 | selective_ack: packet.selective_ack,
691 | payload,
692 | }
693 | }
694 | }
695 |
696 | #[cfg(test)]
697 | mod tests {
698 | use super::*;
699 |
700 | use quickcheck::{quickcheck, Arbitrary, Gen, TestResult};
701 |
702 | impl Arbitrary for PacketHeader {
703 | fn arbitrary(g: &mut Gen) -> Self {
704 | let packet_type = u8::arbitrary(g);
705 | let packet_type = if packet_type % 5 == 0 {
706 | PacketType::Data
707 | } else if packet_type % 5 == 1 {
708 | PacketType::Fin
709 | } else if packet_type % 5 == 2 {
710 | PacketType::State
711 | } else if packet_type % 5 == 3 {
712 | PacketType::Reset
713 | } else {
714 | PacketType::Syn
715 | };
716 |
717 | let extension = u8::arbitrary(g);
718 | let extension = Extension::from(extension);
719 |
720 | Self {
721 | packet_type,
722 | version: Version::One,
723 | extension,
724 | conn_id: u16::arbitrary(g),
725 | ts_micros: u32::arbitrary(g),
726 | ts_diff_micros: u32::arbitrary(g),
727 | window_size: u32::arbitrary(g),
728 | ack_num: u16::arbitrary(g),
729 | seq_num: u16::arbitrary(g),
730 | }
731 | }
732 | }
733 |
734 | impl Arbitrary for SelectiveAck {
735 | fn arbitrary(g: &mut Gen) -> Self {
736 | let bits: Vec<bool> = Vec::arbitrary(g);
737 |
738 | let mut acked: Vec<[bool; 32]> = vec![];
739 |
740 | let mut tmp = [false; 32];
741 | for (index, bit) in bits.iter().enumerate() {
742 | tmp[index % 32] = *bit;
743 |
744 | if (index + 1) % 32 == 0 {
745 | acked.push(tmp);
746 | tmp = [false; 32];
747 | }
748 | }
749 |
750 | if bits.len() % 32 != 0 || acked.is_empty() {
751 | acked.push(tmp);
752 | }
753 |
754 | Self { acked }
755 | }
756 | }
757 |
758 | // TODO: Add more tests. For example, packet encoding and decoding should test for arbitrary
759 | // extensions.
760 |
761 | #[test]
762 | fn header_encode_decode() {
763 | fn prop(header: PacketHeader) -> TestResult {
764 | let encoded = header.encode();
765 | let encoded: [u8; 20] = encoded
766 | .try_into()
767 | .expect("invalid length for encoded uTP packet header");
768 | let decoded =
769 | PacketHeader::decode(&encoded).expect("failed to decode uTP packet header");
770 |
771 | TestResult::from_bool(decoded == header)
772 | }
773 | quickcheck(prop as fn(PacketHeader) -> TestResult);
774 | }
775 |
776 | #[test]
777 | fn selective_ack_encode_decode() {
778 | fn prop(selective_ack: SelectiveAck) -> TestResult {
779 | let encoded_len = selective_ack.encoded_len();
780 |
781 | let encoded = selective_ack.encode();
782 |
783 | assert!(encoded.len() % (SELECTIVE_ACK_BITS / 8) == 0);
784 | assert_eq!(encoded.len(), encoded_len);
785 |
786 | let decoded = SelectiveAck::decode(&encoded).expect("failed to decode Selective ACK");
787 |
788 | TestResult::from_bool(decoded == selective_ack)
789 | }
790 | quickcheck(prop as fn(SelectiveAck) -> TestResult);
791 | }
792 |
793 | #[test]
794 | fn packet_encode_decode() {
795 | fn prop(
796 | mut header: PacketHeader,
797 | selective_ack: SelectiveAck,
798 | payload: Vec<u8>,
799 | ) -> TestResult {
800 | if payload.is_empty() {
801 | return TestResult::discard();
802 | }
803 |
804 | let selective_ack = if selective_ack.acked.is_empty() {
805 | None
806 | } else {
807 | Some(selective_ack)
808 | };
809 | match selective_ack {
810 | Some(..) => {
811 | header.extension = Extension::SelectiveAck;
812 | }
813 | None => {
814 | header.extension = Extension::None;
815 | }
816 | }
817 |
818 | let packet = Packet {
819 | header,
820 | selective_ack,
821 | payload,
822 | };
823 |
824 | let encoded_len = packet.encoded_len();
825 |
826 | let encoded = packet.encode();
827 |
828 | assert_eq!(encoded.len(), encoded_len);
829 |
830 | let decoded = Packet::decode(&encoded).expect("failed to decode uTP packet");
831 |
832 | TestResult::from_bool(decoded == packet)
833 | }
834 | quickcheck(prop as fn(PacketHeader, SelectiveAck, Vec<u8>) -> TestResult);
835 | }
836 | }
837 |
--------------------------------------------------------------------------------
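A small usage sketch of the builder and codec defined above; the connection id, timestamps, and window size are arbitrary values, and the length arithmetic mirrors `encoded_len` from this file:

```rust
use utp_rs::packet::{Packet, PacketBuilder, PacketType, SelectiveAck};

fn main() {
    // An ST_STATE packet whose selective ACK marks the two packets following `ack_num`.
    let packet = PacketBuilder::new(PacketType::State, 12345, 1_000_000, 1_048_576, 7)
        .ack_num(3)
        .selective_ack(Some(SelectiveAck::new(vec![true, true])))
        .build();

    // 20-byte header + 1 byte next-extension type + 1 byte extension length + 4-byte bitmask.
    let bytes = packet.encode();
    assert_eq!(bytes.len(), packet.encoded_len());
    assert_eq!(bytes.len(), 26);

    // Decoding the wire bytes yields an identical packet.
    let decoded = Packet::decode(&bytes).expect("failed to decode uTP packet");
    assert_eq!(decoded, packet);
}
```
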
/src/peer.rs:
--------------------------------------------------------------------------------
1 | use std::fmt::Debug;
2 | use std::hash::Hash;
3 | use std::net::SocketAddr;
4 |
5 | /// A trait that describes a remote peer
6 | pub trait ConnectionPeer: Debug + Clone + Send + Sync {
7 | type Id: Debug + Clone + PartialEq + Eq + Hash + Send + Sync;
8 |
9 | /// Returns peer's id
10 | fn id(&self) -> Self::Id;
11 |
12 | /// Consolidates two peers into one.
13 | ///
14 | /// It's possible that we have two instances that represent the same peer (equal `peer_id`),
15 | /// and we need to consolidate them into one. This can happen when [Peer]-s passed with
16 | /// [UtpSocket::accept_with_cid](crate::socket::UtpSocket::accept_with_cid) or
17 | /// [UtpSocket::connect_with_cid](crate::socket::UtpSocket::connect_with_cid), and returned by
18 | /// [AsyncUdpSocket::recv_from](crate::udp::AsyncUdpSocket::recv_from) contain peers (not just
19 | /// `peer_id`).
20 | ///
21 | /// The structure implementing this trait can decide on the exact behavior. Some examples:
22 | /// - If the structure is simple (i.e. two peers are the same iff all fields are the same),
23 | /// return either one (see the implementation for `SocketAddr`)
24 | /// - If we can determine which peer is newer (e.g. using a timestamp or version field), return
25 | /// the newer peer
26 | /// - If the structure behaves more like a key-value map whose values don't change over time,
27 | /// merge the key-value pairs from both instances into one
28 | ///
29 | /// Should panic if the ids don't match.
30 | fn consolidate(a: Self, b: Self) -> Self;
31 | }
32 |
33 | impl ConnectionPeer for SocketAddr {
34 | type Id = Self;
35 |
36 | fn id(&self) -> Self::Id {
37 | *self
38 | }
39 |
40 | fn consolidate(a: Self, b: Self) -> Self {
41 | assert!(a == b, "Consolidating non-equal peers");
42 | a
43 | }
44 | }
45 |
46 | /// Structure that stores a peer's id, and possibly the peer itself.
47 | #[derive(Debug, Clone)]
48 | pub struct Peer<P: ConnectionPeer> {
49 | id: P::Id,
50 | peer: Option<P>,
51 | }
52 |
53 | impl<P: ConnectionPeer> Peer<P> {
54 | /// Creates new instance that stores peer
55 | pub fn new(peer: P) -> Self {
56 | Self {
57 | id: peer.id(),
58 | peer: Some(peer),
59 | }
60 | }
61 |
62 | /// Creates new instance that only stores peer's id
63 | pub fn new_id(peer_id: P::Id) -> Self {
64 | Self {
65 | id: peer_id,
66 | peer: None,
67 | }
68 | }
69 |
70 | /// Returns peer's id
71 | pub fn id(&self) -> &P::Id {
72 | &self.id
73 | }
74 |
75 | /// Returns optional reference to peer
76 | pub fn peer(&self) -> Option<&P> {
77 | self.peer.as_ref()
78 | }
79 |
80 | /// Consolidates the given peer into `self`, consuming the given instance.
81 | ///
82 | /// See [ConnectionPeer::consolidate] for details.
83 | ///
84 | /// Panics if the ids don't match.
85 | pub fn consolidate(&mut self, other: Self) {
86 | assert!(self.id == other.id, "Consolidating with non-equal peer");
87 | let Some(other_peer) = other.peer else {
88 | return;
89 | };
90 |
91 | self.peer = match self.peer.take() {
92 | Some(peer) => Some(P::consolidate(peer, other_peer)),
93 | None => Some(other_peer),
94 | };
95 | }
96 | }
97 |
--------------------------------------------------------------------------------
/src/recv.rs:
--------------------------------------------------------------------------------
1 | use std::collections::{BTreeMap, HashSet};
2 |
3 | use crate::packet::SelectiveAck;
4 | use crate::seq::CircularRangeInclusive;
5 |
6 | type Bytes = Vec<u8>;
7 |
8 | // https://github.com/ethereum/utp/issues/139
9 | // Maximum number of selective ACK bits that fit given the 8-bit extension length field (63 four-byte words).
10 | const MAX_SELECTIVE_ACK_COUNT: usize = 32 * 63;
11 |
12 | #[derive(Clone, Debug)]
13 | pub struct ReceiveBuffer<const N: usize> {
14 | buf: Box<[u8; N]>,
15 | offset: usize,
16 | pending: BTreeMap<u16, Bytes>,
17 | init_seq_num: u16,
18 | consumed: u16,
19 | }
20 |
21 | impl<const N: usize> ReceiveBuffer<N> {
22 | /// Returns a new buffer.
23 | pub fn new(init_seq_num: u16) -> Self {
24 | Self {
25 | buf: Box::new([0; N]),
26 | offset: 0,
27 | pending: BTreeMap::new(),
28 | init_seq_num,
29 | consumed: 0,
30 | }
31 | }
32 |
33 | /// Returns the number of available bytes remaining in the buffer.
34 | pub fn available(&self) -> usize {
35 | N - self.offset - self.pending.values().fold(0, |acc, data| acc + data.len())
36 | }
37 |
38 | /// Returns `true` if the buffer is empty.
39 | pub fn is_empty(&self) -> bool {
40 | self.offset == 0 && self.pending.is_empty()
41 | }
42 |
43 | /// Returns the initial sequence number of the buffer.
44 | pub fn init_seq_num(&self) -> u16 {
45 | self.init_seq_num
46 | }
47 |
48 | /// Returns `true` if data was already written into the buffer for `seq_num`.
49 | pub fn was_written(&self, seq_num: u16) -> bool {
50 | let range = CircularRangeInclusive::new(
51 | self.init_seq_num,
52 | self.init_seq_num.wrapping_add(self.consumed),
53 | );
54 | range.contains(seq_num) || self.pending.contains_key(&seq_num)
55 | }
56 |
57 | /// Reads data from the buffer into `buf`, returning the number of bytes read.
58 | pub fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
59 | if buf.is_empty() {
60 | return Ok(0);
61 | }
62 |
63 | let n = std::cmp::min(buf.len(), self.offset);
64 | buf[..n].copy_from_slice(&self.buf.as_slice()[..n]);
65 |
66 | let remaining = self.offset - n;
67 | self.buf.as_mut_slice().copy_within(n..n + remaining, 0);
68 | self.offset = remaining;
69 |
70 | Ok(n)
71 | }
72 |
73 | /// Writes `data` into the buffer at `seq_num`.
74 | ///
75 | /// # Panics
76 | ///
77 | /// Panics if `data.len()` is greater than the amount of available bytes in the buffer and the
78 | /// data for `seq_num` has not already been written.
79 | pub fn write(&mut self, data: &[u8], seq_num: u16) {
80 | if self.was_written(seq_num) {
81 | return;
82 | }
83 |
84 | // TODO: Return error instead of panicking.
85 | if data.len() > self.available() {
86 | panic!("insufficient space in buffer");
87 | }
88 |
89 | // Read all sequential data from pending packets.
90 | self.pending.insert(seq_num, data.to_vec());
91 | let start = self.init_seq_num.wrapping_add(1);
92 | let mut next = start.wrapping_add(self.consumed);
93 | while let Some(data) = self.pending.remove(&next) {
94 | let end = self.offset + data.len();
95 | self.buf.as_mut_slice()[self.offset..end].copy_from_slice(&data[..]);
96 |
97 | self.offset = end;
98 | self.consumed += 1;
99 | next = next.wrapping_add(1);
100 | }
101 | }
102 |
103 | /// Returns the last sequence number in a contiguous sequence from the initial sequence number.
104 | pub fn ack_num(&self) -> u16 {
105 | self.init_seq_num.wrapping_add(self.consumed)
106 | }
107 |
108 | /// Returns a selective ACK based on the sequence of data written into the buffer.
109 | pub fn selective_ack(&self) -> Option<SelectiveAck> {
110 | if self.pending.is_empty() {
111 | return None;
112 | }
113 |
114 | // If there are pending packets, then the data for `ack_num + 1` must be missing.
115 | let mut last_ack = self.ack_num().wrapping_add(2);
116 | let mut pending_packets = self.pending.keys().copied().collect::<HashSet<u16>>();
117 |
118 | let mut acked = vec![];
119 | while !pending_packets.is_empty() && acked.len() < MAX_SELECTIVE_ACK_COUNT {
120 | if pending_packets.remove(&last_ack) {
121 | acked.push(true);
122 | } else {
123 | acked.push(false);
124 | }
125 | last_ack = last_ack.wrapping_add(1);
126 | }
127 |
128 | Some(SelectiveAck::new(acked))
129 | }
130 | }
131 |
132 | #[cfg(test)]
133 | mod test {
134 | use super::*;
135 |
136 | const SIZE: usize = 1024;
137 |
138 | #[test]
139 | fn available() {
140 | let init_seq_num = u16::MAX;
141 | let mut buf = ReceiveBuffer::<SIZE>::new(init_seq_num);
142 |
143 | assert_eq!(buf.available(), SIZE);
144 |
145 | const DATA_LEN: usize = 256;
146 |
147 | // Write out-of-order packet.
148 | let data = vec![0xef; DATA_LEN];
149 | let seq_num = init_seq_num.wrapping_add(2);
150 | buf.write(&data, seq_num);
151 | assert_eq!(buf.available(), SIZE - DATA_LEN);
152 |
153 | // Write in-order packet.
154 | let data = vec![0xef; DATA_LEN];
155 | let seq_num = init_seq_num.wrapping_add(1);
156 | buf.write(&data, seq_num);
157 | assert_eq!(buf.available(), SIZE - (DATA_LEN * 2));
158 |
159 | // Read all data.
160 | let mut data = [0; DATA_LEN * 2];
161 | buf.read(&mut data).unwrap();
162 | assert_eq!(buf.available(), SIZE);
163 | }
164 |
165 | #[test]
166 | fn was_written() {
167 | let init_seq_num = u16::MAX;
168 | let mut buf = ReceiveBuffer::<SIZE>::new(init_seq_num);
169 |
170 | let seq_num = init_seq_num.wrapping_add(2);
171 | assert!(!buf.was_written(seq_num));
172 |
173 | const DATA_LEN: usize = 64;
174 | let data = vec![0xef; DATA_LEN];
175 | buf.write(&data, seq_num);
176 | assert!(buf.was_written(seq_num));
177 | }
178 |
179 | #[test]
180 | fn write() {
181 | let init_seq_num = u16::MAX;
182 | let mut buf = ReceiveBuffer::<SIZE>::new(init_seq_num);
183 |
184 | const DATA_LEN: usize = 256;
185 |
186 | // Write out-of-order packet.
187 | let data_second = vec![0xef; DATA_LEN];
188 | let seq_num = init_seq_num.wrapping_add(2);
189 | buf.write(&data_second, seq_num);
190 | assert_eq!(buf.offset, 0);
191 | assert_eq!(buf.consumed, 0);
192 | assert!(buf.pending.contains_key(&seq_num));
193 | assert_eq!(*buf.pending.get(&seq_num).unwrap(), data_second);
194 |
195 | // Write in-order packet.
196 | let data_first = vec![0xfe; DATA_LEN];
197 | let seq_num = init_seq_num.wrapping_add(1);
198 | buf.write(&data_first, seq_num);
199 | assert_eq!(buf.offset, DATA_LEN * 2);
200 | assert_eq!(buf.consumed, 2);
201 | assert!(buf.pending.is_empty());
202 | assert_eq!(buf.buf[..DATA_LEN], data_first[..]);
203 | assert_eq!(buf.buf[DATA_LEN..DATA_LEN * 2], data_second[..]);
204 | }
205 |
206 | #[test]
207 | #[should_panic]
208 | fn write_exceeds_available() {
209 | let init_seq_num = u16::MAX;
210 | let mut buf = ReceiveBuffer::<SIZE>::new(init_seq_num);
211 |
212 | let seq_num = init_seq_num.wrapping_add(1);
213 | let data = vec![0xef; SIZE + 1];
214 | buf.write(&data, seq_num);
215 | }
216 |
217 | #[test]
218 | fn read() {
219 | let init_seq_num = u16::MAX;
220 | let mut buf = ReceiveBuffer::<SIZE>::new(init_seq_num);
221 |
222 | const DATA_LEN: usize = 256;
223 |
224 | // Write out-of-order packet.
225 | let data_second = vec![0xef; DATA_LEN];
226 | let seq_num = init_seq_num.wrapping_add(2);
227 | buf.write(&data_second, seq_num);
228 |
229 | let mut read_buf = vec![0; SIZE];
230 | let read = buf.read(&mut read_buf).unwrap();
231 | assert_eq!(read, 0);
232 |
233 | // Write in-order packet.
234 | let data_first = vec![0xfe; DATA_LEN];
235 | let seq_num = init_seq_num.wrapping_add(1);
236 | buf.write(&data_first, seq_num);
237 |
238 | let read = buf.read(&mut read_buf).unwrap();
239 | assert_eq!(read, DATA_LEN * 2);
240 | assert_eq!(buf.offset, 0);
241 | assert_eq!(read_buf[..DATA_LEN], data_first[..]);
242 | assert_eq!(read_buf[DATA_LEN..DATA_LEN * 2], data_second[..]);
243 |
244 | let read = buf.read(&mut read_buf).unwrap();
245 | assert_eq!(read, 0);
246 | }
247 |
248 | #[test]
249 | fn ack_num() {
250 | let init_seq_num = u16::MAX;
251 | let mut buf = ReceiveBuffer::<SIZE>::new(init_seq_num);
252 |
253 | assert_eq!(buf.ack_num(), init_seq_num);
254 |
255 | const DATA_LEN: usize = 64;
256 | let data = vec![0xef; DATA_LEN];
257 |
258 | // Write out-of-order packet.
259 | let second_seq_num = init_seq_num.wrapping_add(2);
260 | buf.write(&data, second_seq_num);
261 |
262 | assert_eq!(buf.ack_num(), init_seq_num);
263 |
264 | // Write in-order packet.
265 | let first_seq_num = init_seq_num.wrapping_add(1);
266 | buf.write(&data, first_seq_num);
267 |
268 | assert_eq!(buf.ack_num(), second_seq_num);
269 | }
270 |
271 | #[test]
272 | fn selective_ack() {
273 | let init_seq_num = u16::MAX;
274 | let mut buf = ReceiveBuffer::<SIZE>::new(init_seq_num);
275 |
276 | let selective_ack = buf.selective_ack();
277 | assert!(selective_ack.is_none());
278 |
279 | const DATA_LEN: usize = 64;
280 | let data = vec![0xef; DATA_LEN];
281 |
282 | // Write out-of-order packet.
283 | let seq_num = init_seq_num.wrapping_add(2);
284 | buf.write(&data, seq_num);
285 |
286 | let selective_ack = buf.selective_ack().unwrap();
287 | let acked = selective_ack.acked();
288 | assert!(acked[0]);
289 | for ack in acked[1..].iter() {
290 | assert!(!ack);
291 | }
292 |
293 | // Write in-order packet.
294 | let seq_num = init_seq_num.wrapping_add(1);
295 | buf.write(&data, seq_num);
296 |
297 | let selective_ack = buf.selective_ack();
298 | assert!(selective_ack.is_none());
299 | }
300 |
301 | #[test]
302 | fn selective_ack_overflow() {
303 | let init_seq_num = u16::MAX - 2;
304 | let mut buf = ReceiveBuffer::<SIZE>::new(init_seq_num);
305 |
306 | let selective_ack = buf.selective_ack();
307 | assert!(selective_ack.is_none());
308 |
309 | const DATA_LEN: usize = 64;
310 | let data = vec![0xef; DATA_LEN];
311 |
312 | // Write out-of-order packet.
313 | let seq_num = init_seq_num.wrapping_add(2);
314 | buf.write(&data, seq_num);
315 | // Write overflow packet, which is at seq_num 0.
316 | let seq_num = init_seq_num.wrapping_add(3);
317 | buf.write(&data, seq_num);
318 | // Write another out of order but received packet.
319 | let seq_num = init_seq_num.wrapping_add(5);
320 | buf.write(&data, seq_num);
321 |
322 | // Selective ACK should mark received packets as set.
323 | // Selective ACK begins with ack_num + 2, onwards.
324 | // Hence since we received packets 65535, 0, and 2, we should have 3 packets set, in the respective positions.
325 | let selective_ack = buf.selective_ack().unwrap();
326 | let mut acked = vec![false; 32];
327 | acked[0] = true;
328 | acked[1] = true;
329 | acked[3] = true;
330 | assert_eq!(selective_ack.acked(), acked);
331 | }
332 | }
333 |
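The selective-ACK bit mapping used above (the first bit covers `ack_num + 2`, because `ack_num + 1` is by definition the first missing packet) can be made concrete with a small sketch; `sack_bit_to_seq_num` is a hypothetical helper, not part of this crate:

// Sequence number covered by bit `i` of a selective ACK, given the ACK number.
fn sack_bit_to_seq_num(ack_num: u16, bit_index: u16) -> u16 {
    ack_num.wrapping_add(2).wrapping_add(bit_index)
}

#[test]
fn sack_bit_mapping_example() {
    // Mirrors `selective_ack_overflow` above: with ack_num = u16::MAX - 2,
    // bit 0 covers 65535, bit 1 wraps to 0, and bit 3 covers 2.
    let ack_num = u16::MAX - 2;
    assert_eq!(sack_bit_to_seq_num(ack_num, 0), u16::MAX);
    assert_eq!(sack_bit_to_seq_num(ack_num, 1), 0);
    assert_eq!(sack_bit_to_seq_num(ack_num, 3), 2);
}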
--------------------------------------------------------------------------------
/src/send.rs:
--------------------------------------------------------------------------------
1 | use std::collections::VecDeque;
2 | use std::io;
3 |
4 | type Bytes = Vec<u8>;
5 |
6 | #[derive(Clone, Debug)]
7 | pub struct SendBuffer<const N: usize> {
8 | pending: VecDeque,
9 | offset: usize,
10 | }
11 |
12 | impl<const N: usize> Default for SendBuffer<N> {
13 | fn default() -> Self {
14 | Self {
15 | pending: VecDeque::new(),
16 | offset: 0,
17 | }
18 | }
19 | }
20 |
21 | impl<const N: usize> SendBuffer<N> {
22 | /// Creates a new buffer.
23 | pub fn new() -> Self {
24 | Self {
25 | pending: VecDeque::new(),
26 | offset: 0,
27 | }
28 | }
29 |
30 | /// Returns the number of bytes available in the buffer.
31 | pub fn available(&self) -> usize {
32 | N + self.offset - self.pending.iter().fold(0, |acc, x| acc + x.len())
33 | }
34 |
35 | /// Returns `true` if the buffer is empty.
36 | pub fn is_empty(&self) -> bool {
37 | self.pending.is_empty()
38 | }
39 |
40 | /// Writes `data` into the buffer, returning the number of bytes written.
41 | pub fn write(&mut self, data: &[u8]) -> io::Result<usize> {
42 | let available = self.available();
43 | if data.len() <= available {
44 | self.pending.push_back(data.to_vec());
45 | Ok(data.len())
46 | } else {
47 | self.pending.push_back(data[..available].to_vec());
48 | Ok(available)
49 | }
50 | }
51 |
52 | /// Reads data from the buffer into `buf`, returning the number of bytes read.
53 | ///
54 | /// Data from at most one previous write can be read into `buf`. Data from different writes
55 | /// will not go into a single read.
56 | pub fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
57 | if buf.is_empty() {
58 | return Ok(0);
59 | }
60 |
61 | let mut written = 0;
62 | if let Some(data) = self.pending.front() {
63 | let n = std::cmp::min(data.len() - self.offset, buf.len());
64 | buf[..n].copy_from_slice(&data[self.offset..self.offset + n]);
65 |
66 | written += n;
67 | if self.offset + n == data.len() {
68 | self.offset = 0;
69 | self.pending.pop_front();
70 | } else {
71 | self.offset += n;
72 | }
73 | }
74 |
75 | Ok(written)
76 | }
77 | }
78 |
79 | #[cfg(test)]
80 | mod test {
81 | use super::*;
82 |
83 | const SIZE: usize = 8192;
84 |
85 | #[test]
86 | fn available() {
87 | let mut buf = SendBuffer::<SIZE>::new();
88 | assert_eq!(buf.available(), SIZE);
89 |
90 | const WRITE_LEN: usize = 512;
91 | const NUM_WRITES: usize = 3;
92 |
93 | const READ_LEN: usize = 64;
94 |
95 | for _ in 0..NUM_WRITES {
96 | let data = vec![0; WRITE_LEN];
97 | buf.write(&data).unwrap();
98 | }
99 | assert_eq!(buf.available(), SIZE - (WRITE_LEN * NUM_WRITES));
100 |
101 | let mut data = vec![0; READ_LEN];
102 | buf.read(&mut data).unwrap();
103 | assert_eq!(buf.available(), SIZE - (WRITE_LEN * NUM_WRITES) + READ_LEN);
104 |
105 | for _ in 0..NUM_WRITES {
106 | let mut data = vec![0; WRITE_LEN];
107 | buf.read(&mut data).unwrap();
108 | }
109 | assert_eq!(buf.available(), SIZE);
110 | }
111 |
112 | #[test]
113 | #[allow(clippy::read_zero_byte_vec)]
114 | fn read() {
115 | let mut buf = SendBuffer::<SIZE>::new();
116 |
117 | // Read of empty buffer returns zero.
118 | let mut read_buf = vec![0; SIZE];
119 | let read = buf.read(&mut read_buf).unwrap();
120 | assert_eq!(read, 0);
121 |
122 | const WRITE_LEN: usize = 1024;
123 |
124 | const READ_LEN: usize = 784;
125 |
126 | let mut read_buf = vec![0; READ_LEN];
127 |
128 | let write_one = vec![0xef; WRITE_LEN];
129 | let write_two = vec![0xfe; WRITE_LEN];
130 | buf.write(&write_one).unwrap();
131 | buf.write(&write_two).unwrap();
132 |
133 | // Read first chunk of first write.
134 | let read = buf.read(&mut read_buf).unwrap();
135 | assert_eq!(read, READ_LEN);
136 | assert_eq!(read_buf[..READ_LEN], write_one[..READ_LEN]);
137 |
138 | // Read remaining chunk of first write.
139 | let mut read_buf = vec![0; READ_LEN];
140 | let read = buf.read(&mut read_buf).unwrap();
141 | assert_eq!(read, WRITE_LEN - READ_LEN);
142 | assert_eq!(read_buf[..WRITE_LEN - READ_LEN], write_one[READ_LEN..]);
143 |
144 | // Read first chunk of second write.
145 | let read = buf.read(&mut read_buf).unwrap();
146 | assert_eq!(read, READ_LEN);
147 | assert_eq!(read_buf[..READ_LEN], write_two[..READ_LEN]);
148 |
149 | // Read with empty buffer returns zero.
150 | let mut empty = vec![];
151 | let read = buf.read(&mut empty).unwrap();
152 | assert_eq!(read, 0);
153 | }
154 |
155 | #[test]
156 | fn write() {
157 | let mut buf = SendBuffer::<SIZE>::new();
158 |
159 | const WRITE_LEN: usize = 1024;
160 |
161 | let data = vec![0xef; WRITE_LEN];
162 | let written = buf.write(data.as_slice()).unwrap();
163 | assert_eq!(written, WRITE_LEN);
164 | assert_eq!(&buf.pending.pop_front().unwrap(), &data);
165 | }
166 | }
167 |
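A short sketch of the read semantics documented on `SendBuffer::read` (a single read never spans two writes), assuming `SendBuffer` is in scope:

#[test]
fn single_write_per_read_example() {
    let mut buf = SendBuffer::<64>::new();
    buf.write(&[1, 2, 3]).unwrap();
    buf.write(&[4, 5, 6]).unwrap();

    // Even though `out` has room for both writes, each read returns data from
    // at most one of them.
    let mut out = [0u8; 64];
    assert_eq!(buf.read(&mut out).unwrap(), 3);
    assert_eq!(&out[..3], &[1, 2, 3]);
    assert_eq!(buf.read(&mut out).unwrap(), 3);
    assert_eq!(&out[..3], &[4, 5, 6]);
}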
--------------------------------------------------------------------------------
/src/sent.rs:
--------------------------------------------------------------------------------
1 | use std::collections::BTreeSet;
2 | use std::fmt::{self, Formatter};
3 | use std::time::{Duration, Instant};
4 |
5 | use crate::congestion;
6 | use crate::packet::{PacketType, SelectiveAck};
7 | use crate::seq::CircularRangeInclusive;
8 |
9 | const LOSS_THRESHOLD: usize = 3;
10 |
11 | type Bytes = Vec<u8>;
12 |
13 | #[derive(Clone, Debug)]
14 | struct SentPacket {
15 | pub seq_num: u16,
16 | pub packet_type: PacketType,
17 | pub data: Option<Bytes>,
18 | pub transmission: Instant,
19 | pub retransmissions: Vec<Instant>,
20 | pub acks: Vec<Instant>,
21 | }
22 |
23 | impl SentPacket {
24 | fn rtt(&self, now: Instant) -> Duration {
25 | let last_transmission = self.retransmissions.first().unwrap_or(&self.transmission);
26 | now.duration_since(*last_transmission)
27 | }
28 | }
29 |
30 | #[derive(Clone, Debug)]
31 | pub struct SentPackets {
32 | packets: Vec<SentPacket>,
33 | init_seq_num: u16,
34 | lost_packets: BTreeSet<u16>,
35 | congestion_ctrl: congestion::Controller,
36 | }
37 |
38 | #[derive(Clone, Copy, Debug, PartialEq, Eq)]
39 | pub enum SentPacketsError {
40 | InvalidAckNum,
41 | }
42 |
43 | impl fmt::Display for SentPacketsError {
44 | fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
45 | match self {
46 | Self::InvalidAckNum => write!(f, "invalid ack number"),
47 | }
48 | }
49 | }
50 |
51 | impl std::error::Error for SentPacketsError {}
52 |
53 | impl SentPackets {
54 | /// Note: `init_seq_num` corresponds to the sequence number just before the sequence number of
55 | /// the first packet to track.
56 | pub fn new(init_seq_num: u16, congestion_ctrl: congestion::Controller) -> Self {
57 | Self {
58 | packets: Vec::new(),
59 | init_seq_num,
60 | lost_packets: BTreeSet::new(),
61 | congestion_ctrl,
62 | }
63 | }
64 |
65 | pub fn next_seq_num(&self) -> u16 {
66 | // Assume cast is okay, meaning that `packets` never contains more than `u16::MAX`
67 | // elements.
68 | self.init_seq_num
69 | .wrapping_add(self.packets.len() as u16)
70 | .wrapping_add(1)
71 | }
72 |
73 | pub fn ack_num(&self) -> u16 {
74 | self.last_ack_num().unwrap_or(0)
75 | }
76 |
77 | pub fn seq_num_range(&self) -> CircularRangeInclusive {
78 | let end = self.next_seq_num().wrapping_sub(1);
79 | CircularRangeInclusive::new(self.init_seq_num, end)
80 | }
81 |
82 | pub fn timeout(&self) -> Duration {
83 | self.congestion_ctrl.timeout()
84 | }
85 |
86 | pub fn on_timeout(&mut self) {
87 | self.congestion_ctrl.on_timeout()
88 | }
89 |
90 | pub fn window(&self) -> u32 {
91 | self.congestion_ctrl.bytes_available_in_window()
92 | }
93 |
94 | pub fn has_unacked_packets(&self) -> bool {
95 | self.first_unacked_seq_num().is_some()
96 | }
97 |
98 | pub fn has_lost_packets(&self) -> bool {
99 | !self.lost_packets.is_empty()
100 | }
101 |
102 | pub fn lost_packets(&self) -> Vec<(u16, PacketType, Option<Bytes>)> {
103 | self.lost_packets
104 | .iter()
105 | .map(|seq| {
106 | let index = self.seq_num_index(*seq);
107 |
108 | // The unwrap is safe because only sent packets may be lost.
109 | let packet = self.packets.get(index).unwrap();
110 |
111 | (packet.seq_num, packet.packet_type, packet.data.clone())
112 | })
113 | .collect()
114 | }
115 |
116 | /// # Panics
117 | ///
118 | /// Panics if `seq_num` does not correspond to the next expected packet or a previously sent
119 | /// packet.
120 | ///
121 | /// Panics if the transmit is not a retransmission and `len` is greater than the amount of
122 | /// available space in the window.
123 | pub fn on_transmit(
124 | &mut self,
125 | seq_num: u16,
126 | packet_type: PacketType,
127 | data: Option<Bytes>,
128 | len: u32,
129 | now: Instant,
130 | ) {
131 | let index = self.seq_num_index(seq_num);
132 | let is_retransmission = index < self.packets.len();
133 |
134 | // If the packet sequence number is beyond the next sequence number, then panic.
135 | if index > self.packets.len() {
136 | panic!("out of order transmit");
137 | }
138 |
139 | // If this is not a retransmission and the length of the packet is greater than the amount
140 | // of available space in the window, then panic.
141 | if !is_retransmission && len > self.window() {
142 | panic!("transmit exceeds available send window");
143 | }
144 |
145 | match self.packets.get_mut(index) {
146 | Some(sent) => {
147 | sent.retransmissions.push(now);
148 | }
149 | None => {
150 | let sent = SentPacket {
151 | seq_num,
152 | packet_type,
153 | data,
154 | transmission: now,
155 | retransmissions: Vec::new(),
156 | acks: Vec::new(),
157 | };
158 | self.packets.push(sent);
159 | }
160 | }
161 |
162 | let transmit = if is_retransmission {
163 | congestion::Transmit::Retransmission
164 | } else {
165 | congestion::Transmit::Initial { bytes: len }
166 | };
167 |
168 | // The unwrap is safe given the check above on the available window.
169 | self.congestion_ctrl.on_transmit(seq_num, transmit).unwrap();
170 | }
171 |
172 | /// Handles an ACK for the packet with sequence number `ack_num`, together with an optional selective ACK.
173 | ///
174 | /// Returns an error if `ack_num` is not within the sequence number range of the sent packets.
175 | pub fn on_ack(
176 | &mut self,
177 | ack_num: u16,
178 | selective_ack: Option<&SelectiveAck>,
179 | delay: Duration,
180 | now: Instant,
181 | ) -> Result<(CircularRangeInclusive, Vec<u16>), SentPacketsError> {
182 | let range = self.seq_num_range();
183 | if !range.contains(ack_num) {
184 | return Err(SentPacketsError::InvalidAckNum);
185 | }
186 |
187 | // Do not ACK if ACK num corresponds to initial packet.
188 | if ack_num != range.start() {
189 | self.on_ack_num(ack_num, selective_ack, delay, now);
190 | }
191 |
192 | // Mark all packets up to `ack_num` acknowledged by retaining all packets in
193 | // the range beyond `ack_num`.
194 | let full_acked = CircularRangeInclusive::new(range.start(), ack_num);
195 |
196 | let selected_acks = if let Some(selective_ack) = selective_ack {
197 | selective_ack
198 | .acked()
199 | .iter()
200 | .enumerate()
201 | .filter(|(_, &acked)| acked)
202 | .map(|(i, _)| ack_num.wrapping_add(2).wrapping_add(i as u16))
203 | .collect()
204 | } else {
205 | Vec::new()
206 | };
207 |
208 | Ok((full_acked, selected_acks))
209 | }
210 |
211 | /// # Panics
212 | ///
213 | /// Panics if `ack_num` does not correspond to a previously sent packet.
214 | fn on_ack_num(
215 | &mut self,
216 | ack_num: u16,
217 | selective_ack: Option<&SelectiveAck>,
218 | delay: Duration,
219 | now: Instant,
220 | ) {
221 | if let Some(sack) = selective_ack {
222 | self.on_selective_ack(ack_num, sack, delay, now);
223 | } else {
224 | self.ack(ack_num, delay, now);
225 | }
226 |
227 | // An ACK for `ack_num` implicitly ACKs all sequence numbers that precede `ack_num`.
228 | // Account for any preceding unacked packets.
229 | self.ack_prior_unacked(ack_num, delay, now);
230 |
231 | // Account for (newly) lost packets.
232 | let lost = self.detect_lost_packets();
233 | for packet in lost {
234 | if self.lost_packets.insert(packet) {
235 | self.on_lost(packet, true);
236 | }
237 | }
238 | }
239 |
240 | /// # Panics
241 | ///
242 | /// Panics if `ack_num` does not correspond to a previously sent packet.
243 | fn on_selective_ack(
244 | &mut self,
245 | ack_num: u16,
246 | selective_ack: &SelectiveAck,
247 | delay: Duration,
248 | now: Instant,
249 | ) {
250 | self.ack(ack_num, delay, now);
251 |
252 | let range = self.seq_num_range();
253 |
254 | // The first bit of the selective ACK corresponds to `ack_num.wrapping_add(2)`, where
255 | // `ack_num.wrapping_add(1)` is assumed to have been dropped.
256 | let mut sack_num = ack_num.wrapping_add(2);
257 | for ack in selective_ack.acked() {
258 | // Break once we exhaust all sent sequence numbers. The selective ACK length is a
259 | // multiple of 32, so it may be padded beyond the actual range of sequence numbers.
260 | if !range.contains(sack_num) {
261 | break;
262 | }
263 |
264 | if ack {
265 | self.ack(sack_num, delay, now);
266 | }
267 |
268 | sack_num = sack_num.wrapping_add(1);
269 | }
270 | }
271 |
272 | /// Returns a set containing the sequence numbers of lost packets.
273 | ///
274 | /// A packet is lost if it has not been acknowledged and some threshold number of packets sent
275 | /// after it have been acknowledged.
276 | fn detect_lost_packets(&self) -> BTreeSet<u16> {
277 | let mut acked = 0;
278 | let mut lost = BTreeSet::new();
279 |
280 | let start = match self.first_unacked_seq_num() {
281 | Some(first_unacked) => self.seq_num_index(first_unacked),
282 | None => return lost,
283 | };
284 |
285 | for packet in self.packets[start..].iter().rev() {
286 | if packet.acks.is_empty() && acked >= LOSS_THRESHOLD {
287 | lost.insert(packet.seq_num);
288 | }
289 |
290 | if !packet.acks.is_empty() {
291 | acked += 1;
292 | }
293 | }
294 |
295 | lost
296 | }
297 |
298 | /// # Panics
299 | ///
300 | /// Panics if `seq_num` does not correspond to a previously sent packet.
301 | fn ack(&mut self, seq_num: u16, delay: Duration, now: Instant) {
302 | let index = self.seq_num_index(seq_num);
303 | let packet = self.packets.get_mut(index).unwrap();
304 |
305 | let ack = congestion::Ack {
306 | delay,
307 | rtt: packet.rtt(now),
308 | received_at: now,
309 | };
310 | self.congestion_ctrl.on_ack(packet.seq_num, ack).unwrap();
311 |
312 | packet.acks.push(now);
313 |
314 | self.lost_packets.remove(&packet.seq_num);
315 | }
316 |
317 | /// Acknowledges any unacknowledged packets that precede `seq_num`.
318 | fn ack_prior_unacked(&mut self, seq_num: u16, delay: Duration, now: Instant) {
319 | if let Some(first_unacked) = self.first_unacked_seq_num() {
320 | let start = self.seq_num_index(first_unacked);
321 | let end = self.seq_num_index(seq_num);
322 | if start >= end {
323 | return;
324 | }
325 |
326 | let to_ack: Vec<u16> = self.packets[start..end].iter().map(|p| p.seq_num).collect();
327 | for seq_num in to_ack {
328 | self.ack(seq_num, delay, now);
329 | }
330 | }
331 | }
332 |
333 | /// # Panics
334 | ///
335 | /// Panics if `seq_num` does not correspond to a previously sent packet.
336 | fn on_lost(&mut self, seq_num: u16, retransmitting: bool) {
337 | if !self.seq_num_range().contains(seq_num) {
338 | panic!("cannot mark unsent packet lost");
339 | }
340 |
341 | // The unwrap is safe assuming that we do not panic above.
342 | self.congestion_ctrl
343 | .on_lost_packet(seq_num, retransmitting)
344 | .expect("lost packet was previously sent");
345 | }
346 |
347 | /// Returns the "normalized" index for `seq_num` based on the initial sequence number.
348 | fn seq_num_index(&self, seq_num: u16) -> usize {
349 | // The first sequence number is equal to `self.init_seq_num.wrapping_add(1)`.
350 | if seq_num > self.init_seq_num {
351 | usize::from(seq_num - self.init_seq_num - 1)
352 | } else {
353 | usize::from((u16::MAX - self.init_seq_num).wrapping_add(seq_num))
354 | }
355 | }
356 |
357 | /// Returns the sequence number of the last (i.e. latest) packet in a contiguous sequence of
358 | /// acknowledged packets.
359 | ///
360 | /// Returns `None` if none of the (possibly zero) packets have been acknowledged.
361 | // TODO: Cache this value, (possibly) updating on each ACK.
362 | pub fn last_ack_num(&self) -> Option<u16> {
363 | if self.packets.is_empty() {
364 | return None;
365 | }
366 |
367 | let mut num = None;
368 | for packet in &self.packets {
369 | if !packet.acks.is_empty() {
370 | num = Some(packet.seq_num);
371 | } else {
372 | break;
373 | }
374 | }
375 |
376 | num
377 | }
378 |
379 | /// Returns the sequence number of the first (i.e. earliest) packet that has not been
380 | /// acknowledged.
381 | ///
382 | /// Returns `None` if all (possibly zero) sent packets have been acknowledged.
383 | fn first_unacked_seq_num(&self) -> Option<u16> {
384 | if self.packets.is_empty() {
385 | return None;
386 | }
387 |
388 | let seq_num = match self.last_ack_num() {
389 | Some(last_ack_num) => {
390 | // If the last ACK num corresponds to the last packet, then return `None`.
391 | if self.packets.last().unwrap().seq_num == last_ack_num {
392 | return None;
393 | }
394 | last_ack_num.wrapping_add(1)
395 | }
396 | None => self.init_seq_num.wrapping_add(1),
397 | };
398 |
399 | Some(seq_num)
400 | }
401 | }
402 |
403 | #[cfg(test)]
404 | mod test {
405 | use super::*;
406 |
407 | use quickcheck::{quickcheck, TestResult};
408 |
409 | const DELAY: Duration = Duration::from_millis(100);
410 |
411 | // TODO: Bolster tests.
412 |
413 | #[test]
414 | fn next_seq_num() {
415 | fn prop(init_seq_num: u16, len: u8) -> TestResult {
416 | let congestion_ctrl = congestion::Controller::new(congestion::Config::default());
417 | let mut sent_packets = SentPackets::new(init_seq_num, congestion_ctrl);
418 | if len == 0 {
419 | return TestResult::from_bool(
420 | sent_packets.next_seq_num() == init_seq_num.wrapping_add(1),
421 | );
422 | }
423 |
424 | let final_seq_num = init_seq_num.wrapping_add(u16::from(len));
425 | let range = CircularRangeInclusive::new(init_seq_num.wrapping_add(1), final_seq_num);
426 | let transmission = Instant::now();
427 | for seq_num in range {
428 | sent_packets.packets.push(SentPacket {
429 | seq_num,
430 | packet_type: PacketType::Data,
431 | data: None,
432 | transmission,
433 | acks: Default::default(),
434 | retransmissions: Default::default(),
435 | });
436 | }
437 |
438 | TestResult::from_bool(sent_packets.next_seq_num() == final_seq_num.wrapping_add(1))
439 | }
440 | quickcheck(prop as fn(u16, u8) -> TestResult)
441 | }
442 |
443 | #[test]
444 | fn on_transmit_initial() {
445 | let init_seq_num = u16::MAX;
446 | let congestion_ctrl = congestion::Controller::new(congestion::Config::default());
447 | let mut sent_packets = SentPackets::new(init_seq_num, congestion_ctrl);
448 |
449 | let seq_num = sent_packets.next_seq_num();
450 | let data = vec![0];
451 | let len = data.len() as u32;
452 | let now = Instant::now();
453 | sent_packets.on_transmit(seq_num, PacketType::Data, Some(data), len, now);
454 |
455 | assert_eq!(sent_packets.packets.len(), 1);
456 |
457 | let packet = &sent_packets.packets[0];
458 | assert_eq!(packet.seq_num, seq_num);
459 | assert_eq!(packet.transmission, now);
460 | assert!(packet.acks.is_empty());
461 | assert!(packet.retransmissions.is_empty());
462 | }
463 |
464 | #[test]
465 | fn on_transmit_retransmit() {
466 | let init_seq_num = u16::MAX;
467 | let congestion_ctrl = congestion::Controller::new(congestion::Config::default());
468 | let mut sent_packets = SentPackets::new(init_seq_num, congestion_ctrl);
469 |
470 | let seq_num = sent_packets.next_seq_num();
471 | let data = vec![0];
472 | let len = data.len() as u32;
473 | let first = Instant::now();
474 | let second = Instant::now();
475 | sent_packets.on_transmit(seq_num, PacketType::Data, Some(data.clone()), len, first);
476 | sent_packets.on_transmit(seq_num, PacketType::Data, Some(data), len, second);
477 |
478 | assert_eq!(sent_packets.packets.len(), 1);
479 |
480 | let packet = &sent_packets.packets[0];
481 | assert_eq!(packet.seq_num, seq_num);
482 | assert_eq!(packet.transmission, first);
483 | assert!(packet.acks.is_empty());
484 | assert_eq!(packet.retransmissions.len(), 1);
485 | assert_eq!(packet.retransmissions[0], second);
486 | }
487 |
488 | #[test]
489 | #[should_panic]
490 | fn on_transmit_out_of_order() {
491 | let init_seq_num = u16::MAX;
492 | let congestion_ctrl = congestion::Controller::new(congestion::Config::default());
493 | let mut sent_packets = SentPackets::new(init_seq_num, congestion_ctrl);
494 |
495 | let out_of_order_seq_num = init_seq_num.wrapping_add(2);
496 | let data = vec![0];
497 | let len = data.len() as u32;
498 | let now = Instant::now();
499 |
500 | sent_packets.on_transmit(out_of_order_seq_num, PacketType::Data, Some(data), len, now);
501 | }
502 |
503 | #[test]
504 | fn on_selective_ack() {
505 | let init_seq_num = u16::MAX;
506 | let congestion_ctrl = congestion::Controller::new(congestion::Config::default());
507 | let mut sent_packets = SentPackets::new(init_seq_num, congestion_ctrl);
508 |
509 | let data = vec![0];
510 | let len = data.len() as u32;
511 |
512 | const COUNT: usize = 10;
513 | for _ in 0..COUNT {
514 | let now = Instant::now();
515 | let seq_num = sent_packets.next_seq_num();
516 | sent_packets.on_transmit(seq_num, PacketType::Data, Some(data.clone()), len, now);
517 | }
518 |
519 | const SACK_LEN: usize = COUNT - 2;
520 | let mut acked = vec![false; SACK_LEN];
521 | for (i, ack) in acked.iter_mut().enumerate() {
522 | if i % 2 == 0 {
523 | *ack = true;
524 | }
525 | }
526 | let selective_ack = SelectiveAck::new(acked);
527 |
528 | let now = Instant::now();
529 | sent_packets
530 | .on_ack(
531 | init_seq_num.wrapping_add(1),
532 | Some(&selective_ack),
533 | DELAY,
534 | now,
535 | )
536 | .unwrap();
537 | assert_eq!(sent_packets.packets[0].acks.len(), 1);
538 | assert!(sent_packets.packets[1].acks.is_empty());
539 | for i in 2..COUNT {
540 | let is_empty = i % 2 != 0;
541 | assert_eq!(sent_packets.packets[i].acks.is_empty(), is_empty);
542 | }
543 | }
544 |
545 | #[test]
546 | fn detect_lost_packets() {
547 | let init_seq_num = u16::MAX;
548 | let congestion_ctrl = congestion::Controller::new(congestion::Config::default());
549 | let mut sent_packets = SentPackets::new(init_seq_num, congestion_ctrl);
550 |
551 | let data = vec![0];
552 | let len = data.len() as u32;
553 |
554 | const COUNT: usize = 10;
555 | const START: usize = COUNT - LOSS_THRESHOLD;
556 | for i in 0..COUNT {
557 | let now = Instant::now();
558 | let seq_num = sent_packets.next_seq_num();
559 | sent_packets.on_transmit(seq_num, PacketType::Data, Some(data.clone()), len, now);
560 |
561 | if i >= START {
562 | sent_packets.ack(seq_num, DELAY, now);
563 | }
564 | }
565 |
566 | let lost = sent_packets.detect_lost_packets();
567 | for i in 0..START {
568 | let packet = &sent_packets.packets[i];
569 | assert!(lost.contains(&packet.seq_num));
570 | }
571 | }
572 |
573 | #[test]
574 | fn ack() {
575 | let init_seq_num = u16::MAX;
576 | let congestion_ctrl = congestion::Controller::new(congestion::Config::default());
577 | let mut sent_packets = SentPackets::new(init_seq_num, congestion_ctrl);
578 |
579 | let seq_num = sent_packets.next_seq_num();
580 | let data = vec![0];
581 | let len = data.len() as u32;
582 | let now = Instant::now();
583 | sent_packets.on_transmit(seq_num, PacketType::Data, Some(data), len, now);
584 |
585 | // Artificially insert packet into lost packets.
586 | sent_packets.lost_packets.insert(seq_num);
587 | assert!(sent_packets.lost_packets.contains(&seq_num));
588 |
589 | let now = Instant::now();
590 | sent_packets.ack(seq_num, DELAY, now);
591 |
592 | let index = sent_packets.seq_num_index(seq_num);
593 | let packet = sent_packets.packets.get(index).unwrap();
594 |
595 | assert_eq!(packet.acks.len(), 1);
596 | assert_eq!(packet.acks[0], now);
597 | assert!(!sent_packets.lost_packets.contains(&seq_num));
598 | }
599 |
600 | #[test]
601 | fn ack_prior_unacked() {
602 | let init_seq_num = u16::MAX;
603 | let congestion_ctrl = congestion::Controller::new(congestion::Config::default());
604 | let mut sent_packets = SentPackets::new(init_seq_num, congestion_ctrl);
605 |
606 | let data = vec![0];
607 | let len = data.len() as u32;
608 |
609 | const COUNT: usize = 10;
610 | for _ in 0..COUNT {
611 | let now = Instant::now();
612 | let seq_num = sent_packets.next_seq_num();
613 | sent_packets.on_transmit(seq_num, PacketType::Data, Some(data.clone()), len, now);
614 | }
615 |
616 | const ACK_NUM: u16 = 3;
617 | assert!(usize::from(ACK_NUM) < COUNT);
618 | assert!(COUNT - usize::from(ACK_NUM) > 2);
619 |
620 | let now = Instant::now();
621 | sent_packets.ack_prior_unacked(ACK_NUM, DELAY, now);
622 | for i in 0..usize::from(ACK_NUM) {
623 | assert_eq!(sent_packets.packets[i].acks.len(), 1);
624 | }
625 | }
626 |
627 | #[test]
628 | #[should_panic]
629 | fn ack_unsent() {
630 | let init_seq_num = u16::MAX;
631 | let congestion_ctrl = congestion::Controller::new(congestion::Config::default());
632 | let mut sent_packets = SentPackets::new(init_seq_num, congestion_ctrl);
633 |
634 | let unsent_ack_num = init_seq_num.wrapping_add(2);
635 | let now = Instant::now();
636 | sent_packets.ack(unsent_ack_num, DELAY, now);
637 | }
638 |
639 | #[test]
640 | fn seq_num_index() {
641 | let init_seq_num = u16::MAX;
642 | let congestion_ctrl = congestion::Controller::new(congestion::Config::default());
643 | let sent_packets = SentPackets::new(init_seq_num, congestion_ctrl);
644 |
645 | assert_eq!(
646 | sent_packets.seq_num_index(init_seq_num),
647 | usize::from(u16::MAX)
648 | );
649 |
650 | let zero = init_seq_num.wrapping_add(1);
651 | assert_eq!(sent_packets.seq_num_index(zero), 0);
652 | }
653 | }
654 |
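The loss rule implemented by `detect_lost_packets` above (a packet counts as lost once it is unacked and at least `LOSS_THRESHOLD` packets sent after it have been acked) can be restated as a standalone sketch; `lost_indices` is a hypothetical helper for illustration only:

// `acked[i]` is whether the i-th sent packet (in send order) has been acked.
fn lost_indices(acked: &[bool], loss_threshold: usize) -> Vec<usize> {
    let mut acked_after = 0;
    let mut lost = Vec::new();
    for (i, &is_acked) in acked.iter().enumerate().rev() {
        if !is_acked && acked_after >= loss_threshold {
            lost.push(i);
        }
        if is_acked {
            acked_after += 1;
        }
    }
    lost.reverse();
    lost
}

#[test]
fn loss_rule_example() {
    // Mirrors the `detect_lost_packets` test above: packets 0..=6 unacked and
    // packets 7..=9 acked means packets 0..=6 are flagged as lost.
    let mut acked = vec![false; 10];
    for slot in acked.iter_mut().skip(7) {
        *slot = true;
    }
    assert_eq!(lost_indices(&acked, 3), (0..7).collect::<Vec<_>>());
}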
--------------------------------------------------------------------------------
/src/seq.rs:
--------------------------------------------------------------------------------
1 | /// A range bounded inclusively below and above that supports wrapping arithmetic.
2 | ///
3 | /// If `end < start`, then the range contains all values `x` such that `start <= x <= u16::MAX` and
4 | /// `0 <= x <= end`.
5 | #[derive(Clone, Debug)]
6 | pub struct CircularRangeInclusive {
7 | start: u16,
8 | end: u16,
9 | exhausted: bool,
10 | }
11 |
12 | impl CircularRangeInclusive {
13 | /// Returns a new range.
14 | pub fn new(start: u16, end: u16) -> Self {
15 | Self {
16 | start,
17 | end,
18 | exhausted: false,
19 | }
20 | }
21 |
22 | /// Returns the start of the range (inclusive).
23 | pub fn start(&self) -> u16 {
24 | self.start
25 | }
26 |
27 | /// Returns the end of the range (inclusive).
28 | pub fn end(&self) -> u16 {
29 | self.end
30 | }
31 |
32 | /// Returns `true` if `item` is contained in the range.
33 | pub fn contains(&self, item: u16) -> bool {
34 | if self.end >= self.start {
35 | item >= self.start && item <= self.end
36 | } else if item >= self.start {
37 | true
38 | } else {
39 | item <= self.end
40 | }
41 | }
42 | }
43 |
44 | impl std::iter::Iterator for CircularRangeInclusive {
45 | type Item = u16;
46 |
47 | fn next(&mut self) -> Option<Self::Item> {
48 | if self.exhausted {
49 | None
50 | } else if self.start == self.end {
51 | self.exhausted = true;
52 | Some(self.end)
53 | } else {
54 | let step = self.start.wrapping_add(1);
55 | Some(std::mem::replace(&mut self.start, step))
56 | }
57 | }
58 | }
59 |
60 | #[cfg(test)]
61 | mod tests {
62 | use super::*;
63 |
64 | use quickcheck::{quickcheck, TestResult};
65 |
66 | #[test]
67 | fn contains_start() {
68 | fn prop(start: u16, end: u16) -> TestResult {
69 | let range = CircularRangeInclusive::new(start, end);
70 | TestResult::from_bool(range.contains(start))
71 | }
72 | quickcheck(prop as fn(u16, u16) -> TestResult);
73 | }
74 |
75 | #[test]
76 | fn contains_end() {
77 | fn prop(start: u16, end: u16) -> TestResult {
78 | let range = CircularRangeInclusive::new(start, end);
79 | TestResult::from_bool(range.contains(end))
80 | }
81 | quickcheck(prop as fn(u16, u16) -> TestResult);
82 | }
83 |
84 | #[test]
85 | fn iterator() {
86 | fn prop(start: u16, end: u16) -> TestResult {
87 | let range = CircularRangeInclusive::new(start, end);
88 |
89 | let mut len: usize = 0;
90 | let mut expected_idx = start;
91 | for idx in range {
92 | assert_eq!(idx, expected_idx);
93 | expected_idx = expected_idx.wrapping_add(1);
94 | len += 1;
95 | }
96 |
97 | let expected_len = if start <= end {
98 | usize::from(end - start) + 1
99 | } else {
100 | usize::from(u16::MAX - start) + usize::from(end) + 2
101 | };
102 | assert_eq!(len, expected_len);
103 |
104 | TestResult::passed()
105 | }
106 | quickcheck(prop as fn(u16, u16) -> TestResult);
107 | }
108 |
109 | #[test]
110 | fn iterator_single() {
111 | fn prop(x: u16) -> TestResult {
112 | let mut range = CircularRangeInclusive::new(x, x);
113 | assert_eq!(range.next(), Some(x));
114 | assert!(range.next().is_none());
115 |
116 | TestResult::passed()
117 | }
118 | quickcheck(prop as fn(u16) -> TestResult);
119 | }
120 | }
121 |
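A small wrap-around example for `CircularRangeInclusive`, complementing the property tests above (assuming the type is in scope):

#[test]
fn wrap_around_example() {
    // When `end` precedes `start`, the range wraps through u16::MAX back to 0.
    let range = CircularRangeInclusive::new(u16::MAX - 1, 1);
    assert!(range.contains(u16::MAX));
    assert!(range.contains(0));
    assert!(!range.contains(2));
    assert_eq!(range.collect::<Vec<_>>(), vec![u16::MAX - 1, u16::MAX, 0, 1]);
}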
--------------------------------------------------------------------------------
/src/socket.rs:
--------------------------------------------------------------------------------
1 | use std::collections::HashMap;
2 | use std::io;
3 | use std::net::SocketAddr;
4 | use std::sync::{Arc, RwLock};
5 | use std::time::Duration;
6 |
7 | use delay_map::HashMapDelay;
8 | use futures::StreamExt;
9 | use rand::{thread_rng, Rng};
10 | use tokio::net::UdpSocket;
11 | use tokio::sync::mpsc::UnboundedSender;
12 | use tokio::sync::{mpsc, oneshot};
13 |
14 | use crate::cid::ConnectionId;
15 | use crate::conn::ConnectionConfig;
16 | use crate::event::{SocketEvent, StreamEvent};
17 | use crate::packet::{Packet, PacketBuilder, PacketType};
18 | use crate::peer::{ConnectionPeer, Peer};
19 | use crate::stream::UtpStream;
20 | use crate::udp::AsyncUdpSocket;
21 |
22 | type ConnChannel = UnboundedSender<StreamEvent>;
23 |
24 | struct Accept<P: ConnectionPeer> {
25 | stream: oneshot::Sender<io::Result<UtpStream<P>>>,
26 | config: ConnectionConfig,
27 | }
28 |
29 | struct AcceptWithCidPeer<P: ConnectionPeer> {
30 | cid: ConnectionId<P::Id>,
31 | peer: Peer<P>,
32 | accept: Accept<P>,
33 | }
34 |
35 | const MAX_UDP_PAYLOAD_SIZE: usize = u16::MAX as usize;
36 | const CID_GENERATION_TRY_WARNING_COUNT: usize = 10;
37 |
38 | /// accept_with_cid() has unique interactions compared to accept().
39 | /// accept() pulls awaiting requests off a queue, but accept_with_cid() only takes a
40 | /// connection off the queue if the CID matches. Because of this, an await for a CID must
41 | /// eventually time out, or the queue would grow without bound from stale awaits.
42 | /// 20 seconds is arbitrary; once the uTP config refactor is done it can replace this constant,
43 | /// but that refactor is currently very low priority.
44 | const AWAITING_CONNECTION_TIMEOUT: Duration = Duration::from_secs(20);
45 |
46 | pub struct UtpSocket<P: ConnectionPeer> {
47 | conns: Arc<RwLock<HashMap<ConnectionId<P::Id>, ConnChannel>>>,
48 | accepts: UnboundedSender<Accept<P>>,
49 | accepts_with_cid: UnboundedSender<AcceptWithCidPeer<P>>,
50 | socket_events: UnboundedSender<SocketEvent<P>>,
51 | }
52 |
53 | impl UtpSocket<SocketAddr> {
54 | pub async fn bind(addr: SocketAddr) -> io::Result<Self> {
55 | let socket = UdpSocket::bind(addr).await?;
56 | let socket = Self::with_socket(socket);
57 | Ok(socket)
58 | }
59 | }
60 |
61 | impl<P> UtpSocket<P>
62 | where
63 | P: ConnectionPeer + Unpin + 'static,
64 | {
65 | pub fn with_socket<S>(mut socket: S) -> Self
66 | where
67 | S: AsyncUdpSocket<P> + 'static,
68 | {
69 | let conns = HashMap::new();
70 | let conns = Arc::new(RwLock::new(conns));
71 |
72 | let mut awaiting: HashMapDelay<ConnectionId<P::Id>, AcceptWithCidPeer<P>> =
73 | HashMapDelay::new(AWAITING_CONNECTION_TIMEOUT);
74 |
75 | let mut incoming_conns: HashMapDelay<ConnectionId<P::Id>, (Peer<P>, Packet)> =
76 | HashMapDelay::new(AWAITING_CONNECTION_TIMEOUT);
77 |
78 | let (socket_event_tx, mut socket_event_rx) = mpsc::unbounded_channel();
79 | let (accepts_tx, mut accepts_rx) = mpsc::unbounded_channel();
80 | let (accepts_with_cid_tx, mut accepts_with_cid_rx) = mpsc::unbounded_channel();
81 |
82 | let utp = Self {
83 | conns: Arc::clone(&conns),
84 | accepts: accepts_tx,
85 | accepts_with_cid: accepts_with_cid_tx,
86 | socket_events: socket_event_tx.clone(),
87 | };
88 |
89 | tokio::spawn(async move {
90 | let mut buf = [0; MAX_UDP_PAYLOAD_SIZE];
91 | loop {
92 | tokio::select! {
93 | biased;
94 | Ok((n, mut peer)) = socket.recv_from(&mut buf) => {
95 | let peer_id = peer.id();
96 | let packet = match Packet::decode(&buf[..n]) {
97 | Ok(pkt) => pkt,
98 | Err(..) => {
99 | tracing::warn!(?peer, "unable to decode uTP packet");
100 | continue;
101 | }
102 | };
103 |
104 | let peer_init_cid = cid_from_packet::<P>(&packet, &peer_id, IdType::SendIdPeerInitiated);
105 | let we_init_cid = cid_from_packet::<P>(&packet, &peer_id, IdType::SendIdWeInitiated);
106 | let acc_cid = cid_from_packet::<P>(&packet, &peer_id, IdType::RecvId);
107 | let mut conns = conns.write().unwrap();
108 | let conn = conns
109 | .get(&acc_cid)
110 | .or_else(|| conns.get(&we_init_cid))
111 | .or_else(|| conns.get(&peer_init_cid));
112 | match conn {
113 | Some(conn) => {
114 | let _ = conn.send(StreamEvent::Incoming(packet));
115 | }
116 | None => {
117 | if std::matches!(packet.packet_type(), PacketType::Syn) {
118 | let cid = acc_cid;
119 |
120 | // If there was an awaiting connection with the CID, then
121 | // create a new stream for that connection. Otherwise, add the
122 | // connection to the incoming connections.
123 | if let Some(accept_with_cid) = awaiting.remove(&cid) {
124 | peer.consolidate(accept_with_cid.peer);
125 |
126 | let (connected_tx, connected_rx) = oneshot::channel();
127 | let (events_tx, events_rx) = mpsc::unbounded_channel();
128 |
129 | conns.insert(cid.clone(), events_tx);
130 |
131 | let stream = UtpStream::new(
132 | cid,
133 | peer,
134 | accept_with_cid.accept.config,
135 | Some(packet),
136 | socket_event_tx.clone(),
137 | events_rx,
138 | connected_tx
139 | );
140 |
141 | tokio::spawn(async move {
142 | Self::await_connected(stream, accept_with_cid.accept.stream, connected_rx).await
143 | });
144 | } else {
145 | incoming_conns.insert(cid, (peer, packet));
146 | }
147 | } else {
148 | tracing::debug!(
149 | cid = %packet.conn_id(),
150 | packet = ?packet.packet_type(),
151 | seq = %packet.seq_num(),
152 | ack = %packet.ack_num(),
153 | peer_init_cid = ?peer_init_cid,
154 | we_init_cid = ?we_init_cid,
155 | acc_cid = ?acc_cid,
156 | "received uTP packet for non-existing conn"
157 | );
158 | // don't send a reset if we are receiving a reset
159 | if packet.packet_type() != PacketType::Reset {
160 | // if we get a packet from an unknown source send a reset packet.
161 | let random_seq_num = thread_rng().gen_range(0..=65535);
162 | let reset_packet =
163 | PacketBuilder::new(PacketType::Reset, packet.conn_id(), crate::time::now_micros(), 100_000, random_seq_num)
164 | .build();
165 | let event = SocketEvent::Outgoing((reset_packet, peer));
166 | if socket_event_tx.send(event).is_err() {
167 | tracing::warn!("Cannot transmit reset packet: socket closed channel");
168 | return;
169 | }
170 | }
171 | }
172 | },
173 | }
174 | }
175 | Some(accept_with_cid) = accepts_with_cid_rx.recv() => {
176 | let Some((mut peer, syn)) = incoming_conns.remove(&accept_with_cid.cid) else {
177 | awaiting.insert(accept_with_cid.cid.clone(), accept_with_cid);
178 | continue;
179 | };
180 | peer.consolidate(accept_with_cid.peer);
181 | Self::select_accept_helper(accept_with_cid.cid, peer, syn, conns.clone(), accept_with_cid.accept, socket_event_tx.clone());
182 | }
183 | Some(accept) = accepts_rx.recv(), if !incoming_conns.is_empty() => {
184 | let cid = incoming_conns.keys().next().expect("at least one incoming connection");
185 | let cid = cid.clone();
186 | let (peer, packet) = incoming_conns.remove(&cid).expect("to delete incoming connection");
187 | Self::select_accept_helper(cid, peer, packet, conns.clone(), accept, socket_event_tx.clone());
188 | }
189 | Some(event) = socket_event_rx.recv() => {
190 | match event {
191 | SocketEvent::Outgoing((packet, dst)) => {
192 | let encoded = packet.encode();
193 | if let Err(err) = socket.send_to(&encoded, &dst).await {
194 | tracing::debug!(
195 | %err,
196 | cid = %packet.conn_id(),
197 | packet = ?packet.packet_type(),
198 | seq = %packet.seq_num(),
199 | ack = %packet.ack_num(),
200 | "unable to send uTP packet over socket"
201 | );
202 | }
203 | }
204 | SocketEvent::Shutdown(cid) => {
205 | tracing::debug!(%cid.send, %cid.recv, "uTP conn shutdown");
206 | conns.write().unwrap().remove(&cid);
207 | }
208 | }
209 | }
210 | Some(Ok((cid, accept_with_cid))) = awaiting.next() => {
211 | // accept_with_cid didn't receive an inbound connection within the timeout period
212 | // log it and return a timeout error
213 | tracing::debug!(%cid.send, %cid.recv, "accept_with_cid timed out");
214 | let _ = accept_with_cid.accept
215 | .stream
216 | .send(Err(io::Error::from(io::ErrorKind::TimedOut)));
217 | }
218 | Some(Ok((cid, _packet))) = incoming_conns.next() => {
219 | // didn't handle inbound connection within the timeout period
220 | // log it and return a timeout error
221 | tracing::debug!(%cid.send, %cid.recv, "inbound connection timed out");
222 | }
223 | }
224 | }
225 | });
226 |
227 | utp
228 | }
229 |
230 | /// Internal cid generation
231 | fn generate_cid(
232 | &self,
233 | peer_id: P::Id,
234 | is_initiator: bool,
235 | event_tx: Option<UnboundedSender<StreamEvent>>,
236 | ) -> ConnectionId<P::Id> {
237 | let mut cid = ConnectionId {
238 | send: 0,
239 | recv: 0,
240 | peer_id,
241 | };
242 | let mut generation_attempt_count = 0;
243 | loop {
244 | if generation_attempt_count > CID_GENERATION_TRY_WARNING_COUNT {
245 | tracing::error!("cid() tried to generate a cid {generation_attempt_count} times")
246 | }
247 | let recv: u16 = rand::random();
248 | let send = if is_initiator {
249 | recv.wrapping_add(1)
250 | } else {
251 | recv.wrapping_sub(1)
252 | };
253 | cid.send = send;
254 | cid.recv = recv;
255 |
256 | if !self.conns.read().unwrap().contains_key(&cid) {
257 | if let Some(event_tx) = event_tx {
258 | self.conns.write().unwrap().insert(cid.clone(), event_tx);
259 | }
260 | return cid;
261 | }
262 | generation_attempt_count += 1;
263 | }
264 | }
265 |
266 | pub fn cid(&self, peer_id: P::Id, is_initiator: bool) -> ConnectionId<P::Id> {
267 | self.generate_cid(peer_id, is_initiator, None)
268 | }
269 |
270 | /// Returns the number of connections currently open, both inbound and outbound.
271 | pub fn num_connections(&self) -> usize {
272 | self.conns.read().unwrap().len()
273 | }
274 |
275 | /// WARNING: an application must use either accept() or accept_with_cid(), not both;
276 | /// the two are not compatible when used interchangeably in the same program.
277 | pub async fn accept(&self, config: ConnectionConfig) -> io::Result<UtpStream<P>> {
278 | let (stream_tx, stream_rx) = oneshot::channel();
279 | let accept = Accept {
280 | stream: stream_tx,
281 | config,
282 | };
283 | self.accepts
284 | .send(accept)
285 | .map_err(|_| io::Error::from(io::ErrorKind::NotConnected))?;
286 | match stream_rx.await {
287 | Ok(stream) => Ok(stream?),
288 | Err(..) => Err(io::Error::from(io::ErrorKind::TimedOut)),
289 | }
290 | }
291 |
292 | /// WARNING: an application must use either accept() or accept_with_cid(), not both;
293 | /// the two are not compatible when used interchangeably in the same program.
294 | pub async fn accept_with_cid(
295 | &self,
296 | cid: ConnectionId<P::Id>,
297 | peer: Peer<P>,
298 | config: ConnectionConfig,
299 | ) -> io::Result<UtpStream<P>> {
300 | let (stream_tx, stream_rx) = oneshot::channel();
301 | let accept = AcceptWithCidPeer {
302 | cid,
303 | peer,
304 | accept: Accept {
305 | stream: stream_tx,
306 | config,
307 | },
308 | };
309 | self.accepts_with_cid
310 | .send(accept)
311 | .map_err(|_| io::Error::from(io::ErrorKind::NotConnected))?;
312 | match stream_rx.await {
313 | Ok(stream) => Ok(stream?),
314 | Err(..) => Err(io::Error::from(io::ErrorKind::TimedOut)),
315 | }
316 | }
317 |
318 | pub async fn connect(
319 | &self,
320 | peer: Peer<P>,
321 | config: ConnectionConfig,
322 | ) -> io::Result<UtpStream<P>> {
323 | let (connected_tx, connected_rx) = oneshot::channel();
324 | let (events_tx, events_rx) = mpsc::unbounded_channel();
325 | let cid = self.generate_cid(peer.id().clone(), true, Some(events_tx));
326 |
327 | let stream = UtpStream::new(
328 | cid,
329 | peer,
330 | config,
331 | None,
332 | self.socket_events.clone(),
333 | events_rx,
334 | connected_tx,
335 | );
336 |
337 | match connected_rx.await {
338 | Ok(Ok(..)) => Ok(stream),
339 | Ok(Err(err)) => Err(err),
340 | Err(..) => Err(io::Error::from(io::ErrorKind::TimedOut)),
341 | }
342 | }
343 |
344 | pub async fn connect_with_cid(
345 | &self,
346 | cid: ConnectionId<P::Id>,
347 | peer: Peer<P>,
348 | config: ConnectionConfig,
349 | ) -> io::Result<UtpStream<P>> {
350 | if self.conns.read().unwrap().contains_key(&cid) {
351 | return Err(io::Error::new(
352 | io::ErrorKind::Other,
353 | "connection ID unavailable".to_string(),
354 | ));
355 | }
356 |
357 | let (connected_tx, connected_rx) = oneshot::channel();
358 | let (events_tx, events_rx) = mpsc::unbounded_channel();
359 |
360 | {
361 | self.conns.write().unwrap().insert(cid.clone(), events_tx);
362 | }
363 |
364 | let stream = UtpStream::new(
365 | cid.clone(),
366 | peer,
367 | config,
368 | None,
369 | self.socket_events.clone(),
370 | events_rx,
371 | connected_tx,
372 | );
373 |
374 | match connected_rx.await {
375 | Ok(Ok(..)) => Ok(stream),
376 | Ok(Err(err)) => {
377 | tracing::error!(%err, "failed to open connection with {cid:?}");
378 | Err(err)
379 | }
380 | Err(err) => {
381 | tracing::error!(%err, "failed to open connection with {cid:?}");
382 | Err(io::Error::from(io::ErrorKind::TimedOut))
383 | }
384 | }
385 | }
386 |
387 | async fn await_connected(
388 | stream: UtpStream<P>,
389 | callback: oneshot::Sender<io::Result<UtpStream<P>>>,
390 | connected: oneshot::Receiver<io::Result<()>>,
391 | ) {
392 | match connected.await {
393 | Ok(Ok(..)) => {
394 | let _ = callback.send(Ok(stream));
395 | }
396 | Ok(Err(err)) => {
397 | let _ = callback.send(Err(err));
398 | }
399 | Err(..) => {
400 | let _ = callback.send(Err(io::Error::from(io::ErrorKind::ConnectionAborted)));
401 | }
402 | }
403 | }
404 |
405 | fn select_accept_helper(
406 | cid: ConnectionId<P::Id>,
407 | peer: Peer<P>,
408 | syn: Packet,
409 | conns: Arc<RwLock<HashMap<ConnectionId<P::Id>, ConnChannel>>>,
410 | accept: Accept<P>,
411 | socket_event_tx: UnboundedSender<SocketEvent<P>>,
412 | ) {
413 | if conns.read().unwrap().contains_key(&cid) {
414 | let _ = accept.stream.send(Err(io::Error::new(
415 | io::ErrorKind::Other,
416 | "connection ID unavailable".to_string(),
417 | )));
418 | return;
419 | }
420 |
421 | let (connected_tx, connected_rx) = oneshot::channel();
422 | let (events_tx, events_rx) = mpsc::unbounded_channel();
423 |
424 | {
425 | conns.write().unwrap().insert(cid.clone(), events_tx);
426 | }
427 |
428 | let stream = UtpStream::new(
429 | cid,
430 | peer,
431 | accept.config,
432 | Some(syn),
433 | socket_event_tx,
434 | events_rx,
435 | connected_tx,
436 | );
437 |
438 | tokio::spawn(
439 | async move { Self::await_connected(stream, accept.stream, connected_rx).await },
440 | );
441 | }
442 | }
443 |
444 | #[derive(Copy, Clone, Debug)]
445 | enum IdType {
446 | RecvId,
447 | SendIdWeInitiated,
448 | SendIdPeerInitiated,
449 | }
450 |
451 | fn cid_from_packet<P: ConnectionPeer>(
452 | packet: &Packet,
453 | peer_id: &P::Id,
454 | id_type: IdType,
455 | ) -> ConnectionId<P::Id> {
456 | let peer_id = peer_id.clone();
457 | match id_type {
458 | IdType::RecvId => {
459 | let (send, recv) = match packet.packet_type() {
460 | PacketType::Syn => (packet.conn_id(), packet.conn_id().wrapping_add(1)),
461 | PacketType::State | PacketType::Data | PacketType::Fin | PacketType::Reset => {
462 | (packet.conn_id().wrapping_sub(1), packet.conn_id())
463 | }
464 | };
465 | ConnectionId {
466 | send,
467 | recv,
468 | peer_id,
469 | }
470 | }
471 | IdType::SendIdWeInitiated => {
472 | let (send, recv) = (packet.conn_id().wrapping_add(1), packet.conn_id());
473 | ConnectionId {
474 | send,
475 | recv,
476 | peer_id,
477 | }
478 | }
479 | IdType::SendIdPeerInitiated => {
480 | let (send, recv) = (packet.conn_id(), packet.conn_id().wrapping_sub(1));
481 | ConnectionId {
482 | send,
483 | recv,
484 | peer_id,
485 | }
486 | }
487 | }
488 | }
489 |
490 | impl<P: ConnectionPeer> Drop for UtpSocket<P> {
491 | fn drop(&mut self) {
492 | for conn in self.conns.read().unwrap().values() {
493 | let _ = conn.send(StreamEvent::Shutdown);
494 | }
495 | }
496 | }
497 |
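A minimal end-to-end usage sketch of the socket API above. It assumes `ConnectionConfig` implements `Default` and that the in-crate module paths used below are accessible; both are assumptions, since `conn.rs` and `lib.rs` are outside this listing:

use std::net::SocketAddr;

use crate::conn::ConnectionConfig;
use crate::peer::Peer;
use crate::socket::UtpSocket;

async fn example() -> std::io::Result<()> {
    let recv_addr: SocketAddr = "127.0.0.1:3400".parse().unwrap();
    let send_addr: SocketAddr = "127.0.0.1:3401".parse().unwrap();

    let recv_socket = UtpSocket::bind(recv_addr).await?;
    let send_socket = UtpSocket::bind(send_addr).await?;

    // Accept on one socket while the other connects to it.
    let (accepted, connected) = tokio::join!(
        recv_socket.accept(ConnectionConfig::default()),
        send_socket.connect(Peer::new(recv_addr), ConnectionConfig::default()),
    );
    let mut inbound = accepted?;
    let mut outbound = connected?;

    // Write on one side, close, and read to EOF on the other.
    outbound.write(b"hello").await?;
    outbound.close().await?;

    let mut buf = Vec::new();
    inbound.read_to_eof(&mut buf).await?;
    assert_eq!(buf, b"hello".to_vec());

    Ok(())
}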
--------------------------------------------------------------------------------
/src/stream.rs:
--------------------------------------------------------------------------------
1 | use std::io;
2 |
3 | use tokio::sync::{mpsc, oneshot};
4 | use tokio::task;
5 | use tracing::Instrument;
6 |
7 | use crate::cid::ConnectionId;
8 | use crate::congestion::DEFAULT_MAX_PACKET_SIZE_BYTES;
9 | use crate::conn;
10 | use crate::event::{SocketEvent, StreamEvent};
11 | use crate::packet::Packet;
12 | use crate::peer::{ConnectionPeer, Peer};
13 |
14 | /// The size of the send and receive buffers.
15 | // TODO: Make the buffer size configurable.
16 | const BUF: usize = 1024 * 1024;
17 |
18 | pub struct UtpStream<P: ConnectionPeer> {
19 | cid: ConnectionId<P::Id>,
20 | reads: mpsc::UnboundedReceiver<io::Result<Vec<u8>>>,
21 | writes: mpsc::UnboundedSender<(Vec<u8>, oneshot::Sender<io::Result<usize>>)>,
22 | shutdown: Option<oneshot::Sender<()>>,
23 | conn_handle: Option<task::JoinHandle<io::Result<()>>>,
24 | }
25 |
26 | impl<P> UtpStream<P>
27 | where
28 | P: ConnectionPeer + 'static,
29 | {
30 | pub(crate) fn new(
31 | cid: ConnectionId<P::Id>,
32 | peer: Peer<P>,
33 | config: conn::ConnectionConfig,
34 | syn: Option<Packet>,
35 | socket_events: mpsc::UnboundedSender<SocketEvent<P>>,
36 | stream_events: mpsc::UnboundedReceiver<StreamEvent>,
37 | connected: oneshot::Sender<io::Result<()>>,
38 | ) -> Self {
39 | let (shutdown_tx, shutdown_rx) = oneshot::channel();
40 | let (reads_tx, reads_rx) = mpsc::unbounded_channel();
41 | let (writes_tx, writes_rx) = mpsc::unbounded_channel();
42 | let mut conn = conn::Connection::<BUF, P>::new(
43 | cid.clone(),
44 | peer,
45 | config,
46 | syn,
47 | connected,
48 | socket_events,
49 | reads_tx,
50 | );
51 | let conn_handle = tokio::spawn(async move {
52 | conn.event_loop(stream_events, writes_rx, shutdown_rx)
53 | .instrument(tracing::info_span!("uTP", send = cid.send, recv = cid.recv))
54 | .await
55 | });
56 |
57 | Self {
58 | cid,
59 | reads: reads_rx,
60 | writes: writes_tx,
61 | shutdown: Some(shutdown_tx),
62 | conn_handle: Some(conn_handle),
63 | }
64 | }
65 |
66 | pub fn cid(&self) -> &ConnectionId<P::Id> {
67 | &self.cid
68 | }
69 |
70 | pub async fn read_to_eof(&mut self, buf: &mut Vec<u8>) -> io::Result<usize> {
71 | // Reserve space in the buffer to avoid expensive allocation for small reads.
72 | buf.reserve(DEFAULT_MAX_PACKET_SIZE_BYTES as usize);
73 |
74 | let mut n = 0;
75 | loop {
76 | match self.reads.recv().await {
77 | Some(data) => match data {
78 | Ok(mut data) => {
79 | if data.is_empty() {
80 | return Ok(n);
81 | }
82 | n += data.len();
83 |
84 | // Reserve additional space in the buffer proportional to the amount of
85 | // data read. Reserve before `append`, since `append` drains `data`.
86 | buf.reserve(data.len());
87 | buf.append(&mut data);
88 | }
89 | Err(err) => return Err(err),
90 | },
91 | None => tracing::debug!("read buffer was sent None"),
92 | }
93 | }
94 | }
95 |
96 | pub async fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
97 | if self.shutdown.is_none() {
98 | return Err(io::Error::from(io::ErrorKind::NotConnected));
99 | }
100 |
101 | let (tx, rx) = oneshot::channel();
102 | self.writes
103 | .send((buf.to_vec(), tx))
104 | .map_err(|_| io::Error::from(io::ErrorKind::NotConnected))?;
105 |
106 | match rx.await {
107 | Ok(n) => Ok(n?),
108 | Err(err) => Err(io::Error::new(io::ErrorKind::Other, format!("{err:?}"))),
109 | }
110 | }
111 |
112 | /// Closes the stream gracefully.
113 | /// Completes when the remote peer acknowledges all sent data.
114 | pub async fn close(&mut self) -> io::Result<()> {
115 | self.shutdown()?;
116 | match self.conn_handle.take() {
117 | Some(conn_handle) => conn_handle.await?,
118 | None => Err(io::Error::from(io::ErrorKind::NotConnected)),
119 | }
120 | }
121 | }
122 |
123 | impl<P: ConnectionPeer> UtpStream<P> {
124 | // Send signal to the connection event loop to exit, after all outgoing writes have completed.
125 | // Public callers should use close() instead.
126 | fn shutdown(&mut self) -> io::Result<()> {
127 | match self.shutdown.take() {
128 | Some(shutdown) => Ok(shutdown
129 | .send(())
130 | .map_err(|_| io::Error::from(io::ErrorKind::NotConnected))?),
131 | None => Err(io::Error::from(io::ErrorKind::NotConnected)),
132 | }
133 | }
134 | }
135 |
136 | impl<P: ConnectionPeer> Drop for UtpStream<P> {
137 | fn drop(&mut self) {
138 | let _ = self.shutdown();
139 | }
140 | }
141 |
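
For orientation, a minimal usage sketch of the stream API above (illustrative only, mirroring what the integration tests under `tests/` do; it assumes `UtpStream` is exposed at `utp_rs::stream::UtpStream`): the writing side calls `write` followed by `close`, and the reading side drains the stream with `read_to_eof`.

use std::net::SocketAddr;

use utp_rs::stream::UtpStream;

// Write the full payload, then close so the FIN is sent and acknowledged.
async fn send_all(mut stream: UtpStream<SocketAddr>, data: &[u8]) -> std::io::Result<usize> {
    let n = stream.write(data).await?;
    stream.close().await?;
    Ok(n)
}

// Drain the stream until the remote side signals EOF.
async fn recv_all(mut stream: UtpStream<SocketAddr>) -> std::io::Result<Vec<u8>> {
    let mut buf = vec![];
    stream.read_to_eof(&mut buf).await?;
    Ok(buf)
}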
--------------------------------------------------------------------------------
/src/testutils.rs:
--------------------------------------------------------------------------------
1 | use std::io;
2 | use std::sync::atomic::{AtomicBool, Ordering};
3 | use std::sync::Arc;
4 |
5 | use async_trait::async_trait;
6 | use tokio::sync::mpsc;
7 |
8 | use crate::cid::ConnectionId;
9 | use crate::peer::{ConnectionPeer, Peer};
10 | use crate::udp::AsyncUdpSocket;
11 |
12 | /// Decides whether the link between peers is up or down.
13 | trait LinkDecider {
14 | /// Returns true if the link should send the packet, false otherwise.
15 | ///
16 | /// This must only be called once per packet, as it may have side-effects.
17 | fn should_send(&mut self) -> bool;
18 | }
19 |
20 | /// A mock socket that can be used to simulate a perfect link.
21 | #[derive(Debug)]
22 | pub struct MockUdpSocket<Link: LinkDecider> {
23 | outbound: mpsc::UnboundedSender<Vec<u8>>,
24 | inbound: mpsc::UnboundedReceiver<Vec<u8>>,
25 | /// The single peer this socket communicates with, identified by a letter.
26 | pub only_peer: char,
27 | /// Defines whether the link is up. If not up, link will SILENTLY drop all sent packets.
28 | pub link: Link,
29 | }
30 |
31 | #[derive(Clone)]
32 | pub struct ManualLinkDecider {
33 | pub up_switch: Arc,
34 | }
35 |
36 | impl ManualLinkDecider {
37 | fn new() -> Self {
38 | Self {
39 | up_switch: Arc::new(AtomicBool::new(true)),
40 | }
41 | }
42 | }
43 |
44 | impl LinkDecider for ManualLinkDecider {
45 | fn should_send(&mut self) -> bool {
46 | self.up_switch.load(Ordering::SeqCst)
47 | }
48 | }
49 |
50 | pub struct LinkDropsFirstNSent {
51 | target_drops: usize,
52 | actual_drops: usize,
53 | }
54 |
55 | impl LinkDropsFirstNSent {
56 | fn new(n: usize) -> Self {
57 | Self {
58 | target_drops: n,
59 | actual_drops: 0,
60 | }
61 | }
62 | }
63 |
64 | impl LinkDecider for LinkDropsFirstNSent {
65 | fn should_send(&mut self) -> bool {
66 | if self.actual_drops < self.target_drops {
67 | self.actual_drops += 1;
68 | false
69 | } else {
70 | true
71 | }
72 | }
73 | }
74 |
75 | #[async_trait]
76 | impl<Link: LinkDecider + Send + Sync> AsyncUdpSocket<char>
77 | for MockUdpSocket<Link>
78 | {
79 | /// # Panics
80 | ///
81 | /// Panics if `peer` is not equal to `self.only_peer`. This socket is built to support
82 | /// exactly two peers communicating with each other, so it will panic if used with more.
83 | async fn send_to(&mut self, buf: &[u8], peer: &Peer<char>) -> io::Result<usize> {
84 | if peer.id() != &self.only_peer {
85 | panic!("MockUdpSocket only supports sending to one peer");
86 | }
87 | if !self.link.should_send() {
88 | tracing::warn!("Dropping packet to {peer:?}: {buf:?}");
89 | return Ok(buf.len());
90 | }
91 | if let Err(err) = self.outbound.send(buf.to_vec()) {
92 | Err(io::Error::new(
93 | io::ErrorKind::UnexpectedEof,
94 | format!("channel closed: {err}"),
95 | ))
96 | } else {
97 | Ok(buf.len())
98 | }
99 | }
100 |
101 | /// # Panics
102 | ///
103 | /// Panics if `buf` is smaller than the packet size.
104 | async fn recv_from(&mut self, buf: &mut [u8]) -> io::Result<(usize, Peer<char>)> {
105 | let packet = self
106 | .inbound
107 | .recv()
108 | .await
109 | .ok_or_else(|| io::Error::new(io::ErrorKind::UnexpectedEof, "channel closed"))?;
110 | if buf.len() < packet.len() {
111 | panic!("buffer too small for perfect link");
112 | }
113 | let packet_len = packet.len();
114 | buf[..packet_len].copy_from_slice(&packet[..]);
115 | Ok((packet_len, Peer::new(self.only_peer)))
116 | }
117 | }
118 |
119 | impl ConnectionPeer for char {
120 | type Id = char;
121 |
122 | fn id(&self) -> Self::Id {
123 | *self
124 | }
125 |
126 | fn consolidate(a: Self, b: Self) -> Self {
127 | assert!(a == b, "Consolidating non-equal peers");
128 | a
129 | }
130 | }
131 |
132 | fn build_link_pair<LinkAtoB: LinkDecider, LinkBtoA: LinkDecider>(
133 | a_to_b_link: LinkAtoB,
134 | b_to_a_link: LinkBtoA,
135 | ) -> (MockUdpSocket<LinkAtoB>, MockUdpSocket<LinkBtoA>) {
136 | let (peer_a, peer_b): (char, char) = ('A', 'B');
137 | let (a_tx, a_rx) = mpsc::unbounded_channel();
138 | let (b_tx, b_rx) = mpsc::unbounded_channel();
139 | let a = MockUdpSocket {
140 | outbound: a_tx,
141 | inbound: b_rx,
142 | only_peer: peer_b,
143 | link: a_to_b_link,
144 | };
145 | let b = MockUdpSocket {
146 | outbound: b_tx,
147 | inbound: a_rx,
148 | only_peer: peer_a,
149 | link: b_to_a_link,
150 | };
151 | (a, b)
152 | }
153 |
154 | fn build_connection_id_pair(
155 | socket_a: &MockUdpSocket<impl LinkDecider>,
156 | socket_b: &MockUdpSocket<impl LinkDecider>,
157 | ) -> (ConnectionId<char>, ConnectionId<char>) {
158 | build_connection_id_pair_starting_at(socket_a, socket_b, 100)
159 | }
160 |
161 | fn build_connection_id_pair_starting_at(
162 | socket_a: &MockUdpSocket<impl LinkDecider>,
163 | socket_b: &MockUdpSocket<impl LinkDecider>,
164 | lower_id: u16,
165 | ) -> (ConnectionId<char>, ConnectionId<char>) {
166 | let higher_id = lower_id.wrapping_add(1);
167 | let a_cid = ConnectionId {
168 | send: higher_id,
169 | recv: lower_id,
170 | peer_id: socket_a.only_peer,
171 | };
172 | let b_cid = ConnectionId {
173 | send: lower_id,
174 | recv: higher_id,
175 | peer_id: socket_b.only_peer,
176 | };
177 | (a_cid, b_cid)
178 | }
179 |
180 | /// Build a link between sockets, which we can manually control whether it is up or down
181 | #[allow(clippy::type_complexity)]
182 | pub fn build_manually_linked_pair() -> (
183 | (MockUdpSocket<ManualLinkDecider>, ConnectionId<char>),
184 | (MockUdpSocket<ManualLinkDecider>, ConnectionId<char>),
185 | ) {
186 | let (socket_a, socket_b) = build_link_pair(ManualLinkDecider::new(), ManualLinkDecider::new());
187 | let (a_cid, b_cid) = build_connection_id_pair(&socket_a, &socket_b);
188 | ((socket_a, a_cid), (socket_b, b_cid))
189 | }
190 |
191 | /// Build a link between sockets, where the first n packets sent by the 2nd socket are dropped.
192 | ///
193 | /// The first socket is the one with the lower connection ID, which must be the one initiating the
194 | /// connection.
195 | #[allow(clippy::type_complexity)]
196 | pub fn build_link_drops_first_n_sent_pair(
197 | n: usize,
198 | ) -> (
199 | (MockUdpSocket<ManualLinkDecider>, ConnectionId<char>),
200 | (MockUdpSocket<LinkDropsFirstNSent>, ConnectionId<char>),
201 | ) {
202 | let link_a_to_b = ManualLinkDecider::new();
203 | let link_b_to_a = LinkDropsFirstNSent::new(n);
204 | let (socket_a, socket_b) = build_link_pair(link_a_to_b, link_b_to_a);
205 | let (a_cid, b_cid) = build_connection_id_pair(&socket_a, &socket_b);
206 | ((socket_a, a_cid), (socket_b, b_cid))
207 | }
208 |
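
To illustrate the `LinkDecider` abstraction above, here is a sketch of another decider that could live in this module (illustrative only, not part of the crate): it drops every other packet, exercising the stateful `should_send` contract.

/// Drops every other packet sent across the link.
struct DropAlternate {
    send_next: bool,
}

impl LinkDecider for DropAlternate {
    fn should_send(&mut self) -> bool {
        // Flip the flag on every call, so sends and drops alternate.
        let send = self.send_next;
        self.send_next = !self.send_next;
        send
    }
}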
--------------------------------------------------------------------------------
/src/time.rs:
--------------------------------------------------------------------------------
1 | use std::time::{Duration, SystemTime, UNIX_EPOCH};
2 |
3 | /// Returns the current UNIX timestamp in microseconds, truncated to a `u32`.
4 | pub fn now_micros() -> u32 {
5 | let now = SystemTime::now().duration_since(UNIX_EPOCH).unwrap();
6 | now.as_micros() as u32
7 | }
8 |
9 | /// Returns the amount of time elapsed between `earlier_micros` and `later_micros`.
10 | ///
11 | /// If `later_micros` is less than `earlier_micros`, then we assume that `later_micros`
12 | /// has wrapped around the `u32` boundary.
13 | pub fn duration_between(earlier_micros: u32, later_micros: u32) -> Duration {
14 | if later_micros < earlier_micros {
15 | Duration::from_micros(later_micros.wrapping_sub(earlier_micros).into())
16 | } else {
17 | Duration::from_micros((later_micros - earlier_micros).into())
18 | }
19 | }
20 |
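
A short worked example of the wrap-around branch above (illustrative only, not a test from the crate): an `earlier` timestamp just below `u32::MAX` and a small `later` timestamp still yield a small elapsed duration.

#[cfg(test)]
mod wrap_example {
    use super::duration_between;
    use std::time::Duration;

    #[test]
    fn wraps_across_u32_boundary() {
        let earlier = u32::MAX - 10;
        let later = 5;
        // Elapsed time is (u32::MAX - earlier) + later + 1 = 10 + 5 + 1 = 16 microseconds.
        assert_eq!(duration_between(earlier, later), Duration::from_micros(16));
    }
}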
--------------------------------------------------------------------------------
/src/udp.rs:
--------------------------------------------------------------------------------
1 | use std::io;
2 | use std::net::SocketAddr;
3 |
4 | use async_trait::async_trait;
5 | use tokio::net::UdpSocket;
6 |
7 | use crate::peer::{ConnectionPeer, Peer};
8 |
9 | /// An abstract representation of an asynchronous UDP socket.
10 | #[async_trait]
11 | pub trait AsyncUdpSocket<P: ConnectionPeer>: Send + Sync {
12 | /// Attempts to send data on the socket to a given peer.
13 | /// Note that this should return nearly immediately, rather than awaiting something internally.
14 | async fn send_to(&mut self, buf: &[u8], peer: &Peer<P>) -> io::Result<usize>;
15 | /// Attempts to receive a single datagram on the socket.
16 | async fn recv_from(&mut self, buf: &mut [u8]) -> io::Result<(usize, Peer<P>)>;
17 | }
18 |
19 | #[async_trait]
20 | impl AsyncUdpSocket<SocketAddr> for UdpSocket {
21 | async fn send_to(&mut self, buf: &[u8], peer: &Peer<SocketAddr>) -> io::Result<usize> {
22 | UdpSocket::send_to(self, buf, peer.id()).await
23 | }
24 |
25 | async fn recv_from(&mut self, buf: &mut [u8]) -> io::Result<(usize, Peer<SocketAddr>)> {
26 | UdpSocket::recv_from(self, buf)
27 | .await
28 | .map(|(len, peer)| (len, Peer::new(peer)))
29 | }
30 | }
31 |
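
Because the crate talks to the network only through this trait, any transport that implements it can back a `UtpSocket`; the mock socket in `src/testutils.rs` does exactly that. A minimal sketch, assuming `UtpSocket::with_socket` accepts any `AsyncUdpSocket` implementation by value, as the tests suggest:

use std::net::SocketAddr;

use tokio::net::UdpSocket;
use utp_rs::socket::UtpSocket;

// Build a uTP socket on top of a plain tokio UDP socket, going through the
// AsyncUdpSocket impl above instead of UtpSocket::bind.
async fn utp_over_udp(addr: SocketAddr) -> std::io::Result<UtpSocket<SocketAddr>> {
    let udp = UdpSocket::bind(addr).await?;
    Ok(UtpSocket::with_socket(udp))
}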
--------------------------------------------------------------------------------
/tests/socket.rs:
--------------------------------------------------------------------------------
1 | use futures::stream::{FuturesUnordered, StreamExt};
2 | use std::net::SocketAddr;
3 | use std::sync::Arc;
4 | use utp_rs::peer::Peer;
5 |
6 | use tokio::task::JoinHandle;
7 | use tokio::time::Instant;
8 |
9 | use utp_rs::cid;
10 | use utp_rs::conn::ConnectionConfig;
11 | use utp_rs::socket::UtpSocket;
12 |
13 | const TEST_DATA: &[u8] = &[0xf0; 1_000_000];
14 |
15 | #[tokio::test(flavor = "multi_thread", worker_threads = 16)]
16 | async fn many_concurrent_transfers() {
17 | let _ = tracing_subscriber::fmt::try_init();
18 |
19 | tracing::info!("starting socket test");
20 |
21 | let recv_addr = SocketAddr::from(([127, 0, 0, 1], 3400));
22 | let send_addr = SocketAddr::from(([127, 0, 0, 1], 3401));
23 |
24 | let recv = UtpSocket::bind(recv_addr).await.unwrap();
25 | let recv = Arc::new(recv);
26 | let send = UtpSocket::bind(send_addr).await.unwrap();
27 | let send = Arc::new(send);
28 | let mut handles = FuturesUnordered::new();
29 |
30 | let start = Instant::now();
31 | let num_transfers = 1000;
32 | for i in 0..num_transfers {
33 | // step up cid by two to avoid collisions
34 | let handle = initiate_transfer(
35 | i * 2,
36 | recv_addr,
37 | recv.clone(),
38 | send_addr,
39 | send.clone(),
40 | TEST_DATA,
41 | )
42 | .await;
43 | handles.push(handle.0);
44 | handles.push(handle.1);
45 | }
46 |
47 | while let Some(res) = handles.next().await {
48 | res.unwrap();
49 | }
50 | let elapsed = Instant::now() - start;
51 | let megabits_sent = num_transfers as f64 * TEST_DATA.len() as f64 * 8.0 / 1_000_000.0;
52 | let transfer_rate = megabits_sent / elapsed.as_secs_f64();
53 | tracing::info!("finished high concurrency load test of {} simultaneous transfers, in {:?}, at a rate of {:.0} Mbps", num_transfers, elapsed, transfer_rate);
54 | }
55 |
56 | #[tokio::test]
57 | /// Test that a socket can send and receive a large amount of data
58 | async fn one_huge_data_transfer() {
59 | // TODO: test 100MiB or more. Currently, it fails (perhaps due to a rollover at 2^16 packets)
60 |
61 | // At the time of writing, 1024 * 1024 + 1 will hang, because it's bigger than the send buffer,
62 | // and the sending logic pauses until the buffer is larger than the pending data.
63 | const HUGE_DATA: &[u8] = &[0xf0; 1024 * 1024 * 50];
64 |
65 | let _ = tracing_subscriber::fmt::try_init();
66 |
67 | tracing::info!("starting single transfer of huge data test");
68 |
69 | let recv_addr = SocketAddr::from(([127, 0, 0, 1], 3500));
70 | let send_addr = SocketAddr::from(([127, 0, 0, 1], 3501));
71 |
72 | let recv = UtpSocket::bind(recv_addr).await.unwrap();
73 | let recv = Arc::new(recv);
74 | let send = UtpSocket::bind(send_addr).await.unwrap();
75 | let send = Arc::new(send);
76 |
77 | let start = Instant::now();
78 | let handle = initiate_transfer(
79 | 0,
80 | recv_addr,
81 | recv.clone(),
82 | send_addr,
83 | send.clone(),
84 | HUGE_DATA,
85 | )
86 | .await;
87 |
88 | // Wait for the sending side of the transfer to complete
89 | handle.0.await.unwrap();
90 | // Wait for the receiving side of the transfer to complete
91 | handle.1.await.unwrap();
92 |
93 | let elapsed = Instant::now() - start;
94 | let megabytes_sent = HUGE_DATA.len() as f64 / 1_000_000.0;
95 | let megabits_sent = megabytes_sent * 8.0;
96 | let transfer_rate = megabits_sent / elapsed.as_secs_f64();
97 | tracing::info!(
98 | "finished single large transfer test with {:.0} MB, in {:?}, at a rate of {:.1} Mbps",
99 | megabytes_sent,
100 | elapsed,
101 | transfer_rate
102 | );
103 | }
104 |
105 | async fn initiate_transfer(
106 | i: u16,
107 | recv_addr: SocketAddr,
108 | recv: Arc<UtpSocket<SocketAddr>>,
109 | send_addr: SocketAddr,
110 | send: Arc<UtpSocket<SocketAddr>>,
111 | data: &'static [u8],
112 | ) -> (JoinHandle<()>, JoinHandle<()>) {
113 | let conn_config = ConnectionConfig::default();
114 | let initiator_cid = 100 + i;
115 | let responder_cid = 100 + i + 1;
116 | let recv_cid = cid::ConnectionId {
117 | send: initiator_cid,
118 | recv: responder_cid,
119 | peer_id: send_addr,
120 | };
121 | let send_cid = cid::ConnectionId {
122 | send: responder_cid,
123 | recv: initiator_cid,
124 | peer_id: recv_addr,
125 | };
126 |
127 | let recv_handle = tokio::spawn(async move {
128 | let mut stream = recv
129 | .accept_with_cid(recv_cid, Peer::new(send_addr), conn_config)
130 | .await
131 | .unwrap();
132 | let mut buf = vec![];
133 | let n = match stream.read_to_eof(&mut buf).await {
134 | Ok(num_bytes) => num_bytes,
135 | Err(err) => {
136 | let cid = stream.cid();
137 | tracing::error!(?cid, "read to eof error: {:?}", err);
138 | panic!("fail to read data");
139 | }
140 | };
141 | tracing::info!(cid.send = %recv_cid.send, cid.recv = %recv_cid.recv, "read {n} bytes from uTP stream");
142 |
143 | assert_eq!(n, data.len());
144 | assert_eq!(buf, data);
145 | });
146 |
147 | let send_handle = tokio::spawn(async move {
148 | let mut stream = send
149 | .connect_with_cid(send_cid, Peer::new(recv_addr), conn_config)
150 | .await
151 | .unwrap();
152 | let n = stream.write(data).await.unwrap();
153 | assert_eq!(n, data.len());
154 |
155 | stream.close().await.unwrap();
156 | });
157 | (send_handle, recv_handle)
158 | }
159 |
160 | // Test that a new socket has zero connections
161 | #[tokio::test]
162 | async fn test_empty_socket_conn_count() {
163 | let socket_addr = SocketAddr::from(([127, 0, 0, 1], 3402));
164 | let socket = UtpSocket::bind(socket_addr).await.unwrap();
165 | assert_eq!(socket.num_connections(), 0);
166 | }
167 |
168 | // Test that a socket returns 2 from num_connections after connecting twice
169 | #[tokio::test]
170 | async fn test_socket_reports_two_connections() {
171 | let conn_config = ConnectionConfig::default();
172 |
173 | let recv_addr = SocketAddr::from(([127, 0, 0, 1], 3404));
174 | let recv = UtpSocket::bind(recv_addr).await.unwrap();
175 | let recv = Arc::new(recv);
176 |
177 | let send_addr = SocketAddr::from(([127, 0, 0, 1], 3405));
178 | let send = UtpSocket::bind(send_addr).await.unwrap();
179 | let send = Arc::new(send);
180 |
181 | let recv_one_cid = cid::ConnectionId {
182 | send: 100,
183 | recv: 101,
184 | peer_id: send_addr,
185 | };
186 | let send_one_cid = cid::ConnectionId {
187 | send: 101,
188 | recv: 100,
189 | peer_id: recv_addr,
190 | };
191 |
192 | let recv_one = Arc::clone(&recv);
193 | let recv_one_handle = tokio::spawn(async move {
194 | recv_one
195 | .accept_with_cid(recv_one_cid, Peer::new(send_addr), conn_config)
196 | .await
197 | .unwrap()
198 | });
199 |
200 | let send_one = Arc::clone(&send);
201 | let send_one_handle = tokio::spawn(async move {
202 | send_one
203 | .connect_with_cid(send_one_cid, Peer::new(recv_addr), conn_config)
204 | .await
205 | .unwrap()
206 | });
207 |
208 | let recv_two_cid = cid::ConnectionId {
209 | send: 200,
210 | recv: 201,
211 | peer_id: send_addr,
212 | };
213 | let send_two_cid = cid::ConnectionId {
214 | send: 201,
215 | recv: 200,
216 | peer_id: recv_addr,
217 | };
218 |
219 | let recv_two = Arc::clone(&recv);
220 | let recv_two_handle = tokio::spawn(async move {
221 | recv_two
222 | .accept_with_cid(recv_two_cid, Peer::new(send_addr), conn_config)
223 | .await
224 | .unwrap()
225 | });
226 |
227 | let send_two = Arc::clone(&send);
228 | let send_two_handle = tokio::spawn(async move {
229 | send_two
230 | .connect_with_cid(send_two_cid, Peer::new(recv_addr), conn_config)
231 | .await
232 | .unwrap()
233 | });
234 |
235 | let (tx_one, rx_one, tx_two, rx_two) = tokio::join!(
236 | send_one_handle,
237 | recv_one_handle,
238 | send_two_handle,
239 | recv_two_handle
240 | );
241 | tx_one.unwrap();
242 | rx_one.unwrap();
243 | tx_two.unwrap();
244 | rx_two.unwrap();
245 |
246 | assert_eq!(recv.num_connections(), 2);
247 | assert_eq!(send.num_connections(), 2);
248 | }
249 |
--------------------------------------------------------------------------------
/tests/stream.rs:
--------------------------------------------------------------------------------
1 | use std::io::ErrorKind;
2 | use std::sync::atomic::Ordering;
3 | use std::sync::Arc;
4 | use std::time::Duration;
5 |
6 | use tokio::time::timeout;
7 |
8 | use utp_rs::conn::{ConnectionConfig, DEFAULT_MAX_IDLE_TIMEOUT};
9 | use utp_rs::peer::Peer;
10 | use utp_rs::socket::UtpSocket;
11 |
12 | use utp_rs::testutils;
13 |
14 | // How long should tests expect the connection to wait before timing out due to inactivity?
15 | const EXPECTED_IDLE_TIMEOUT: Duration = DEFAULT_MAX_IDLE_TIMEOUT;
16 |
17 | // Test that close() returns successful, after transfer is complete
18 | #[tokio::test]
19 | async fn close_is_successful_when_write_completes() {
20 | let conn_config = ConnectionConfig::default();
21 |
22 | let ((send_socket, send_cid), (recv_socket, recv_cid)) =
23 | testutils::build_manually_linked_pair();
24 |
25 | let recv = UtpSocket::with_socket(recv_socket);
26 | let recv = Arc::new(recv);
27 |
28 | let send = UtpSocket::with_socket(send_socket);
29 | let send = Arc::new(send);
30 |
31 | let recv_one = Arc::clone(&recv);
32 | let recv_one_handle = tokio::spawn(async move {
33 | recv_one
34 | .accept_with_cid(recv_cid, Peer::new_id(recv_cid.peer_id), conn_config)
35 | .await
36 | .unwrap()
37 | });
38 |
39 | // Keep a clone of the socket so that it doesn't drop when moved into the task.
40 | // Dropping it causes all connections to exit.
41 | let send_one = Arc::clone(&send);
42 | let send_one_handle = tokio::spawn(async move {
43 | send_one
44 | .connect_with_cid(send_cid, Peer::new_id(send_cid.peer_id), conn_config)
45 | .await
46 | .unwrap()
47 | });
48 |
49 | let (tx_one, rx_one) = tokio::join!(send_one_handle, recv_one_handle,);
50 | let mut send_stream = tx_one.unwrap();
51 | let mut recv_stream = rx_one.unwrap();
52 |
53 | // data to send
54 | const DATA_LEN: usize = 100;
55 | let data = [0xa5; DATA_LEN];
56 |
57 | // send data
58 | let send_stream_handle = tokio::spawn(async move {
59 | match send_stream.write(&data).await {
60 | Ok(written_len) => assert_eq!(written_len, DATA_LEN),
61 | Err(e) => panic!("Error sending data: {:?}", e),
62 | };
63 | send_stream
64 | });
65 |
66 | // recv data
67 | let recv_stream_handle = tokio::spawn(async move {
68 | let mut read_buf = vec![];
69 | let _ = recv_stream.read_to_eof(&mut read_buf).await.unwrap();
70 | assert_eq!(read_buf, data.to_vec());
71 | });
72 |
73 | // wait for send to start
74 | let mut send_stream = send_stream_handle.await.unwrap();
75 |
76 | // close stream, which will wait for write to complete, and exit without a problem
77 | // This should happen extremely quickly.
78 | match timeout(Duration::from_millis(20), send_stream.close()).await {
79 | Ok(Ok(_)) => {}
80 | Ok(Err(e)) => panic!("Error closing stream: {:?}", e),
81 | Err(_) => panic!("Timeout closing stream"),
82 | };
83 |
84 | // confirm that data is received as expected
85 | recv_stream_handle.await.unwrap();
86 | }
87 |
88 | // Test that close() returns a timeout, if recipient is not ACKing (after a successful connection)
89 | #[tokio::test(start_paused = true)]
90 | async fn close_errors_if_all_packets_dropped() {
91 | let conn_config = ConnectionConfig::default();
92 |
93 | let ((send_socket, send_cid), (recv_socket, recv_cid)) =
94 | testutils::build_manually_linked_pair();
95 | let tx_link_switch = send_socket.link.up_switch.clone();
96 |
97 | let recv = UtpSocket::with_socket(recv_socket);
98 | let recv = Arc::new(recv);
99 |
100 | let send = UtpSocket::with_socket(send_socket);
101 | let send = Arc::new(send);
102 |
103 | let recv_one = Arc::clone(&recv);
104 | let recv_one_handle = tokio::spawn(async move {
105 | recv_one
106 | .accept_with_cid(recv_cid, Peer::new_id(recv_cid.peer_id), conn_config)
107 | .await
108 | .unwrap()
109 | });
110 |
111 | // Keep a clone of the socket so that it doesn't drop when moved into the task.
112 | // Dropping it causes all connections to exit.
113 | let send_one = Arc::clone(&send);
114 | let send_one_handle = tokio::spawn(async move {
115 | send_one
116 | .connect_with_cid(send_cid, Peer::new_id(send_cid.peer_id), conn_config)
117 | .await
118 | .unwrap()
119 | });
120 |
121 | let (tx_one, rx_one) = tokio::join!(send_one_handle, recv_one_handle,);
122 | let mut send_stream = tx_one.unwrap();
123 | let mut recv_stream = rx_one.unwrap();
124 |
125 | // ******* DISABLE NETWORK LINK ********
126 | tx_link_switch.store(false, Ordering::SeqCst);
127 |
128 | // data to send
129 | const DATA_LEN: usize = 100;
130 | let data = [0xa5; DATA_LEN];
131 |
132 | // send data
133 | let send_stream_handle = tokio::spawn(async move {
134 | match send_stream.write(&data).await {
135 | Ok(written_len) => assert_eq!(written_len, DATA_LEN),
136 | Err(e) => panic!("Error sending data: {:?}", e),
137 | };
138 | send_stream
139 | });
140 |
141 | // recv data
142 | let recv_stream_handle = tokio::spawn(async move {
143 | let mut read_buf = vec![];
144 | let read_err = recv_stream.read_to_eof(&mut read_buf).await.unwrap_err();
145 | assert_eq!(read_err.kind(), ErrorKind::TimedOut);
146 | });
147 |
148 | // Wait for send to start
149 | let mut send_stream = send_stream_handle.await.unwrap();
150 |
151 | // Close stream, which will fail because network is disabled.
152 | match timeout(EXPECTED_IDLE_TIMEOUT * 2, send_stream.close()).await {
153 | Ok(Ok(_)) => panic!("Stream closed successfully, but should have timed out"),
154 | Ok(Err(e)) => {
155 | // The stream must time out when waiting to close, if the network is disabled.
156 | assert_eq!(e.kind(), ErrorKind::TimedOut);
157 | }
158 | Err(e) => {
159 | panic!("The stream did not timeout on close() fast enough, giving up after: {e:?}")
160 | }
161 | };
162 |
163 | // Wait to confirm that the read will time out, also.
164 | recv_stream_handle.await.unwrap();
165 | }
166 |
167 | // Test that close() succeeds, if the connection is only missing the FIN-ACK
168 | #[tokio::test(start_paused = true)]
169 | async fn close_succeeds_if_only_fin_ack_dropped() {
170 | let conn_config = ConnectionConfig::default();
171 |
172 | let ((send_socket, send_cid), (recv_socket, recv_cid)) =
173 | testutils::build_manually_linked_pair();
174 | let rx_link_switch = recv_socket.link.up_switch.clone();
175 |
176 | let recv = UtpSocket::with_socket(recv_socket);
177 | let recv = Arc::new(recv);
178 |
179 | let send = UtpSocket::with_socket(send_socket);
180 | let send = Arc::new(send);
181 |
182 | let recv_one = Arc::clone(&recv);
183 | let recv_one_handle = tokio::spawn(async move {
184 | recv_one
185 | .accept_with_cid(recv_cid, Peer::new_id(recv_cid.peer_id), conn_config)
186 | .await
187 | .unwrap()
188 | });
189 |
190 | // Keep a clone of the socket so that it doesn't drop when moved into the task.
191 | // Dropping it causes all connections to exit.
192 | let send_one = Arc::clone(&send);
193 | let send_one_handle = tokio::spawn(async move {
194 | send_one
195 | .connect_with_cid(send_cid, Peer::new_id(send_cid.peer_id), conn_config)
196 | .await
197 | .unwrap()
198 | });
199 |
200 | let (tx_one, rx_one) = tokio::join!(send_one_handle, recv_one_handle,);
201 | let mut send_stream = tx_one.unwrap();
202 | let mut recv_stream = rx_one.unwrap();
203 |
204 | // data to send
205 | const DATA_LEN: usize = 100;
206 | let data = [0xa5; DATA_LEN];
207 |
208 | // send data
209 | let send_stream_handle = tokio::spawn(async move {
210 | match send_stream.write(&data).await {
211 | Ok(written_len) => assert_eq!(written_len, DATA_LEN),
212 | Err(e) => panic!("Error sending data: {:?}", e),
213 | };
214 | send_stream
215 | });
216 |
217 | // recv data
218 | let recv_stream_handle = tokio::spawn(async move {
219 | let mut read_buf = vec![];
220 | let _ = recv_stream.read_to_eof(&mut read_buf).await.unwrap();
221 | assert_eq!(read_buf, data.to_vec());
222 | recv_stream
223 | });
224 |
225 | // Wait for send to start
226 | let mut send_stream = send_stream_handle.await.unwrap();
227 |
228 | // Wait for the full data to be sent before dropping the link
229 | // This is a timeless sleep, because tokio time is paused
230 | tokio::time::sleep(EXPECTED_IDLE_TIMEOUT / 2).await;
231 |
232 | // ******* DISABLE NETWORK LINK ********
233 | // This only drops the connection from the recipient to the sender, leading to the following
234 | // scenario:
235 | // - Sender sends FIN
236 | // - Recipient receives FIN, sends FIN-ACK and its own FIN
237 | // - Sender receives nothing, because link is down
238 | // - Recipient is only missing its inbound FIN-ACK and closes with success
239 | // - Sender is missing the recipient's FIN and times out with failure
240 | rx_link_switch.store(false, Ordering::SeqCst);
241 |
242 | match timeout(EXPECTED_IDLE_TIMEOUT * 2, send_stream.close()).await {
243 | Ok(Ok(_)) => panic!("Send stream closed successfully, but should have timed out"),
244 | Ok(Err(e)) => {
245 | // The stream must time out when waiting to close, because recipient's FIN is missing
246 | assert_eq!(e.kind(), ErrorKind::TimedOut);
247 | }
248 | Err(e) => {
249 | panic!("The send stream did not timeout on close() fast enough, giving up after: {e:?}")
250 | }
251 | };
252 |
253 | let mut recv_stream = recv_stream_handle.await.unwrap();
254 |
255 | // Since switching to one-way FIN-ACK, closing after reading is not allowed. We only explicitly
256 | // close after write() now, and close after reading should error.
257 | match timeout(EXPECTED_IDLE_TIMEOUT * 2, recv_stream.close()).await {
258 | Ok(Ok(_)) => panic!("Closing after reading should have errored, but succeeded"),
259 | Ok(Err(e)) => {
260 | // The stream will already be disconnected by the read_to_eof() call, so we expect a
261 | // NotConnected error here.
262 | assert_eq!(e.kind(), ErrorKind::NotConnected);
263 | }
264 | Err(e) => {
265 | panic!("The recv stream did not timeout on close() fast enough, giving up after: {e:?}")
266 | }
267 | };
268 | }
269 |
270 | // Test that data is delivered successfully, even if the original SYN-STATE is dropped
271 | //
272 | // At the time of writing, a bug in this scenario causes the first bytes (2048 of them) to be
273 | // silently lost. The recipient thinks it received everything, but is missing the first bytes of
274 | // the transfer. When the recipient, who started the connection, times out the original SYN, it
275 | // resends. The sender has already sent some data. When the bug is active, the resent STATE
276 | // packet in response to the SYN uses the sequence number after incrementing from the previously
277 | // sent state data. This causes the recipient to ignore all data sent previously.
278 | #[tokio::test(start_paused = true)]
279 | async fn test_data_valid_when_resending_syn_state_response() {
280 | let _ = tracing_subscriber::fmt::try_init();
281 |
282 | let conn_config = ConnectionConfig::default();
283 |
284 | let ((connector_socket, connector_cid), (acceptor_socket, acceptor_cid)) =
285 | testutils::build_link_drops_first_n_sent_pair(1);
286 |
287 | let acceptor = UtpSocket::with_socket(acceptor_socket);
288 | let acceptor = Arc::new(acceptor);
289 |
290 | let connector = UtpSocket::with_socket(connector_socket);
291 | let connector = Arc::new(connector);
292 |
293 | // It's important for this scenario that the data recipient is the one creating the connection.
294 | let acceptor_one = Arc::clone(&acceptor);
295 | let acceptor_one_handle = tokio::spawn(async move {
296 | acceptor_one
297 | .accept_with_cid(
298 | acceptor_cid,
299 | Peer::new_id(acceptor_cid.peer_id),
300 | conn_config,
301 | )
302 | .await
303 | .unwrap()
304 | });
305 |
306 | // Keep a clone of the socket so that it doesn't drop when moved into the task.
307 | // Dropping it causes all connections to exit.
308 | let connector_one = Arc::clone(&connector);
309 | let connector_one_handle = tokio::spawn(async move {
310 | connector_one
311 | .connect_with_cid(
312 | connector_cid,
313 | Peer::new_id(connector_cid.peer_id),
314 | conn_config,
315 | )
316 | .await
317 | .unwrap()
318 | });
319 |
320 | let mut acceptor_stream = acceptor_one_handle.await.unwrap();
321 | tracing::debug!("Acceptor stream established");
322 | // Must not wait for connection to complete before writing data here, so that we can trigger
323 | // the bug.
324 |
325 | // data to send, must be longer than 2048 bytes, which are lost in the bug scenario
326 | const DATA_LEN: usize = 9000;
327 | let data = [0xa5; DATA_LEN];
328 |
329 | // send data
330 | let acceptor_stream_handle = tokio::spawn(async move {
331 | match acceptor_stream.write(&data).await {
332 | Ok(written_len) => assert_eq!(written_len, DATA_LEN),
333 | Err(err) => panic!("Error sending data: {err:?}"),
334 | };
335 | acceptor_stream
336 | });
337 |
338 | // Finally, we can wait for the connection to complete. If we await this any earlier in the
339 | // test, then the data won't be sent, and we don't trigger the bug scenario.
340 | let mut connector_stream = connector_one_handle.await.unwrap();
341 |
342 | // Test that complete data is received
343 | let connector_stream_handle = tokio::spawn(async move {
344 | let mut read_buf = vec![];
345 | let _ = connector_stream.read_to_eof(&mut read_buf).await.unwrap();
346 | assert_eq!(read_buf.len(), data.len());
347 | assert_eq!(read_buf, data.to_vec());
348 | connector_stream
349 | });
350 |
351 | // Complete the streams
352 | let mut acceptor_stream = acceptor_stream_handle.await.unwrap();
353 | acceptor_stream.close().await.unwrap();
354 | connector_stream_handle.await.unwrap();
355 | }
356 |
--------------------------------------------------------------------------------