/// Encodes an IP address into a fixed 16-byte array.
///
/// An IPv6 address is returned as its 16 octets. An IPv4 address occupies the
/// first four bytes of the result and the remaining twelve bytes are zero.
pub fn ip_to_bytes(ip: IpAddr) -> [u8; 16] {
    let v4_octets = match ip {
        IpAddr::V6(v6) => return v6.octets(),
        IpAddr::V4(v4) => v4.octets(),
    };

    let mut out = [0u8; 16];
    out[..v4_octets.len()].copy_from_slice(&v4_octets);
    out
}
-------------------------------------------------------------------------------- 1 | on: 2 | push: 3 | pull_request: 4 | 5 | jobs: 6 | test: 7 | name: Run tests 8 | strategy: 9 | matrix: 10 | os: [ubuntu-20.04, macOS-latest] 11 | runs-on: ${{ matrix.os }} 12 | steps: 13 | - uses: actions/checkout@v2 14 | - uses: actions-rs/toolchain@v1 15 | with: 16 | toolchain: stable 17 | profile: minimal 18 | - name: cargo test 19 | run: cargo test 20 | -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "tokio-udt" 3 | version = "0.1.0-alpha.8" 4 | edition = "2021" 5 | license = "AGPL-3.0" 6 | description = """ 7 | An implementation of UDP-based Data Transfer Protocol (UDT) based on Tokio primitives 8 | """ 9 | repository = "https://github.com/amatissart/tokio-udt" 10 | keywords = ["udt", "udt4", "networking", "transport", "protocol"] 11 | 12 | # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html 13 | 14 | [dependencies] 15 | rand = "0.8" 16 | tokio = { version = "1.*", features = [ "macros", "net", "io-util", "sync", "time", "rt-multi-thread" ] } 17 | sha2 = "0.10.2" 18 | once_cell = "1.12" 19 | socket2 = "0.4.4" 20 | nix = "0.24.2" 21 | bytes = "1.1" 22 | 23 | [target.'cfg(target_os="linux")'.dependencies] 24 | tokio-timerfd = "0.2" 25 | 26 | [dev-dependencies] 27 | doc-comment = "0.3.3" 28 | -------------------------------------------------------------------------------- /src/ack_window.rs: -------------------------------------------------------------------------------- 1 | use crate::seq_number::{AckSeqNumber, SeqNumber}; 2 | use std::collections::{BTreeMap, VecDeque}; 3 | use tokio::time::{Duration, Instant}; 4 | 5 | #[derive(Debug)] 6 | pub(crate) struct AckWindow { 7 | size: usize, 8 | acks: BTreeMap, 9 | keys: VecDeque, 10 | } 11 | 12 | impl AckWindow { 13 | pub fn new(size: usize) 
-> Self { 14 | Self { 15 | size, 16 | acks: BTreeMap::new(), 17 | keys: VecDeque::with_capacity(size), 18 | } 19 | } 20 | 21 | pub fn store(&mut self, seq: SeqNumber, ack: AckSeqNumber) { 22 | if self.keys.len() >= self.size { 23 | let oldest = self.keys.pop_front().unwrap(); 24 | self.acks.remove(&oldest); 25 | } 26 | self.keys.push_back(ack); 27 | self.acks.insert(ack, (seq, Instant::now())); 28 | } 29 | 30 | pub fn get(&mut self, ack: AckSeqNumber) -> Option<(SeqNumber, Duration)> { 31 | self.acks.get(&ack).map(|(seq, ts)| (*seq, ts.elapsed())) 32 | } 33 | } 34 | -------------------------------------------------------------------------------- /src/bin/udt_sender.rs: -------------------------------------------------------------------------------- 1 | use std::time::{Duration, Instant}; 2 | use tokio::io::AsyncWriteExt; 3 | use tokio_udt::UdtConnection; 4 | 5 | #[tokio::main] 6 | async fn main() { 7 | let mut connection = UdtConnection::connect("127.0.0.1:9000", None) 8 | .await 9 | .unwrap(); 10 | 11 | println!("Connected!"); 12 | 13 | let buffer: Vec = std::iter::repeat(b"Hello World!") 14 | .take(100000) 15 | .flat_map(|b| *b) 16 | .collect(); 17 | println!("Message length: {}", buffer.len()); 18 | 19 | let mut last = Instant::now(); 20 | let mut count = 0; 21 | 22 | loop { 23 | connection 24 | .write_all(&buffer) 25 | .await 26 | .map(|_| { 27 | count += 1; 28 | }) 29 | .unwrap(); 30 | 31 | if last.elapsed() > Duration::new(1, 0) { 32 | last = Instant::now(); 33 | println!("Sent {} messages", count); 34 | println!( 35 | "Period {:?}", 36 | connection.rate_control().get_pkt_send_period() 37 | ); 38 | println!( 39 | "Window {:?}", 40 | connection.rate_control().get_congestion_window_size() 41 | ); 42 | } 43 | } 44 | 45 | // connection.close().await 46 | } 47 | -------------------------------------------------------------------------------- /src/bin/udt_receiver.rs: -------------------------------------------------------------------------------- 1 | use 
std::time::{Duration, Instant}; 2 | use tokio_udt::UdtListener; 3 | 4 | use tokio::io::AsyncReadExt; 5 | 6 | #[tokio::main] 7 | async fn main() { 8 | let listener = UdtListener::bind("0.0.0.0:9000".parse().unwrap(), None) 9 | .await 10 | .unwrap(); 11 | 12 | println!("Waiting for connections..."); 13 | 14 | loop { 15 | let (addr, mut connection) = listener.accept().await.unwrap(); 16 | 17 | println!("Accepted connection from {}", addr); 18 | 19 | let mut buffer = Vec::with_capacity(20_000_000); 20 | 21 | tokio::task::spawn({ 22 | let mut bytes = 0; 23 | let mut last = Instant::now(); 24 | async move { 25 | loop { 26 | match connection.read_buf(&mut buffer).await { 27 | Ok(size) => { 28 | bytes += size; 29 | } 30 | Err(_err) => { 31 | eprintln!("Connnection with {} closed", addr); 32 | println!("Received {} MB", bytes as f64 / 1e6); 33 | break; 34 | } 35 | } 36 | 37 | if last.elapsed() > Duration::new(1, 0) { 38 | last = Instant::now(); 39 | println!("Received {} MB", bytes as f64 / 1e6); 40 | } 41 | 42 | if buffer.len() >= 10_000_000 { 43 | buffer.clear(); 44 | } 45 | } 46 | } 47 | }); 48 | } 49 | } 50 | -------------------------------------------------------------------------------- /src/packet.rs: -------------------------------------------------------------------------------- 1 | use super::control_packet::{ControlPacketType, HandShakeInfo, UdtControlPacket}; 2 | use super::data_packet::UdtDataPacket; 3 | use tokio::io::{Error, ErrorKind, Result}; 4 | 5 | #[derive(Debug)] 6 | pub(crate) enum UdtPacket { 7 | Control(UdtControlPacket), 8 | Data(UdtDataPacket), 9 | } 10 | 11 | impl UdtPacket { 12 | pub fn get_dest_socket_id(&self) -> u32 { 13 | match self { 14 | Self::Control(p) => p.dest_socket_id, 15 | Self::Data(p) => p.header.dest_socket_id, 16 | } 17 | } 18 | 19 | pub fn serialize(&self) -> Vec { 20 | match self { 21 | Self::Control(p) => p.serialize(), 22 | Self::Data(p) => p.serialize(), 23 | } 24 | } 25 | 26 | pub fn deserialize(raw: &[u8]) -> Result { 27 | 
if raw.is_empty() { 28 | return Err(Error::new( 29 | ErrorKind::InvalidData, 30 | "cannot deserialize empty packet", 31 | )); 32 | } 33 | let first_bit = (raw[0] >> 7) != 0; 34 | let packet = match first_bit { 35 | false => Self::Data(UdtDataPacket::deserialize(raw)?), 36 | true => Self::Control(UdtControlPacket::deserialize(raw)?), 37 | }; 38 | Ok(packet) 39 | } 40 | 41 | pub fn handshake(&self) -> Option<&HandShakeInfo> { 42 | match self { 43 | Self::Control(ctrl) => match &ctrl.packet_type { 44 | ControlPacketType::Handshake(info) => Some(info), 45 | _ => None, 46 | }, 47 | _ => None, 48 | } 49 | } 50 | } 51 | 52 | impl From for UdtPacket { 53 | fn from(ctrl: UdtControlPacket) -> Self { 54 | Self::Control(ctrl) 55 | } 56 | } 57 | 58 | impl From for UdtPacket { 59 | fn from(data_packet: UdtDataPacket) -> Self { 60 | Self::Data(data_packet) 61 | } 62 | } 63 | -------------------------------------------------------------------------------- /src/state/socket_state.rs: -------------------------------------------------------------------------------- 1 | use crate::ack_window::AckWindow; 2 | use crate::configuration::UdtConfiguration; 3 | use crate::loss_list::LossList; 4 | use crate::seq_number::{AckSeqNumber, SeqNumber}; 5 | use crate::socket::SYN_INTERVAL; 6 | use tokio::time::{Duration, Instant}; 7 | 8 | #[derive(Debug)] 9 | pub(crate) struct SocketState { 10 | pub last_rsp_time: Instant, 11 | 12 | // Receiving related, 13 | pub last_sent_ack: SeqNumber, 14 | pub last_sent_ack_time: Instant, 15 | pub curr_rcv_seq_number: SeqNumber, 16 | pub last_ack_seq_number: AckSeqNumber, 17 | pub rcv_loss_list: LossList, 18 | pub last_ack2_received: SeqNumber, 19 | 20 | // Sending related 21 | pub last_ack_received: SeqNumber, 22 | pub last_data_ack_processed: SeqNumber, 23 | pub last_ack2_sent_back: AckSeqNumber, 24 | pub curr_snd_seq_number: SeqNumber, 25 | pub last_ack2_time: Instant, 26 | pub snd_loss_list: LossList, 27 | 28 | pub next_ack_time: Instant, 29 | pub 
interpacket_interval: Duration, 30 | pub interpacket_time_diff: Duration, 31 | pub pkt_count: usize, 32 | pub light_ack_counter: usize, 33 | pub exp_count: u32, 34 | 35 | pub next_data_target_time: Instant, 36 | 37 | pub ack_window: AckWindow, 38 | } 39 | 40 | impl SocketState { 41 | pub fn new(isn: SeqNumber, _configuration: &UdtConfiguration) -> Self { 42 | let now = Instant::now(); 43 | 44 | Self { 45 | last_rsp_time: now, 46 | last_ack_seq_number: AckSeqNumber::zero(), 47 | rcv_loss_list: LossList::new(), 48 | curr_rcv_seq_number: isn - 1, 49 | 50 | next_ack_time: now + SYN_INTERVAL, 51 | interpacket_interval: Duration::from_micros(1), 52 | interpacket_time_diff: Duration::ZERO, 53 | pkt_count: 0, 54 | light_ack_counter: 0, 55 | 56 | exp_count: 1, 57 | last_ack_received: isn, 58 | last_sent_ack: isn - 1, 59 | last_sent_ack_time: now, 60 | last_ack2_received: isn.number().into(), 61 | 62 | curr_snd_seq_number: isn - 1, 63 | last_ack2_sent_back: isn.number().into(), 64 | last_ack2_time: now, 65 | last_data_ack_processed: isn, 66 | snd_loss_list: LossList::new(), 67 | 68 | next_data_target_time: now, 69 | 70 | ack_window: AckWindow::new(1024), 71 | } 72 | } 73 | } 74 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # tokio-udt 2 | 3 | An implementation of UDP-based Data Transfer Protocol (UDT) based on Tokio primitives. 4 | 5 | [![Crates.io][crates-badge]][crates-url] 6 | [![Docs][docs-badge]][docs-url] 7 | 8 | [crates-badge]: https://img.shields.io/crates/v/tokio-udt.svg 9 | [crates-url]: https://crates.io/crates/tokio-udt 10 | [docs-badge]: https://img.shields.io/docsrs/tokio-udt.svg 11 | [docs-url]: https://docs.rs/tokio-udt/ 12 | 13 | ## What is UDT? 14 | 15 | UDT is a high performance data transport protocol. 
It was designed specifically for data intensive 16 | applications over high speed wide area networks, to overcome the efficiency and fairness 17 | problems of TCP. As its name indicates, UDT is built on top of UDP and it provides both 18 | reliable data streaming and messaging services. 19 | 20 | To learn more about UDT, see https://udt.sourceforge.io/ 21 | 22 | You can also find the reference C++ implementation on https://github.com/eminence/udt 23 | 24 | 25 | ## Examples 26 | 27 | ### UDT listener 28 | 29 | ```rust,no_run 30 | use std::net::Ipv4Addr; 31 | use tokio::io::{AsyncReadExt, Result}; 32 | use tokio_udt::UdtListener; 33 | 34 | #[tokio::main] 35 | async fn main() -> Result<()> { 36 | let port = 9000; 37 | let listener = UdtListener::bind((Ipv4Addr::UNSPECIFIED, port).into(), None).await?; 38 | 39 | println!("Waiting for connections..."); 40 | 41 | loop { 42 | let (addr, mut connection) = listener.accept().await?; 43 | println!("Accepted connection from {}", addr); 44 | let mut buffer = Vec::with_capacity(1_000_000); 45 | tokio::task::spawn({ 46 | async move { 47 | loop { 48 | match connection.read_buf(&mut buffer).await { 49 | Ok(_size) => {} 50 | Err(e) => { 51 | eprintln!("Connnection with {} failed: {}", addr, e); 52 | break; 53 | } 54 | } 55 | } 56 | } 57 | }); 58 | } 59 | } 60 | ``` 61 | 62 | ### UDT client 63 | 64 | ```rust,no_run 65 | use std::net::Ipv4Addr; 66 | use tokio::io::{AsyncWriteExt, Result}; 67 | use tokio_udt::UdtConnection; 68 | 69 | #[tokio::main] 70 | async fn main() -> Result<()> { 71 | let port = 9000; 72 | let mut connection = UdtConnection::connect((Ipv4Addr::LOCALHOST, port), None).await?; 73 | loop { 74 | connection.write_all(b"Hello World!").await?; 75 | } 76 | } 77 | ``` 78 | -------------------------------------------------------------------------------- /src/lib.rs: -------------------------------------------------------------------------------- 1 | /*!
2 | An implementation of UDP-based Data Transfer Protocol (UDT) based on Tokio primitives. 3 | 4 | UDT is a high performance data transport protocol. It was designed for data intensive 5 | applications over high speed wide area networks, to overcome the efficiency and fairness 6 | problems of TCP. As its name indicates, UDT is built on top of UDP and it provides both 7 | reliable data streaming and messaging services. 8 | 9 | 10 | # Usage 11 | 12 | ## UDT server example 13 | 14 | Bind a port with [`UdtListener`]: 15 | 16 | ```no_run 17 | use std::net::Ipv4Addr; 18 | use tokio::io::{AsyncReadExt, Result}; 19 | use tokio_udt::UdtListener; 20 | 21 | #[tokio::main] 22 | async fn main() -> Result<()> { 23 | let port = 9000; 24 | let listener = UdtListener::bind((Ipv4Addr::UNSPECIFIED, port).into(), None).await?; 25 | 26 | println!("Waiting for connections..."); 27 | 28 | loop { 29 | let (addr, mut connection) = listener.accept().await?; 30 | println!("Accepted connection from {}", addr); 31 | let mut buffer = Vec::with_capacity(1_000_000); 32 | tokio::task::spawn({ 33 | async move { 34 | loop { 35 | match connection.read_buf(&mut buffer).await { 36 | Ok(_size) => {} 37 | Err(e) => { 38 | eprintln!("Connnection with {} failed: {}", addr, e); 39 | break; 40 | } 41 | } 42 | } 43 | } 44 | }); 45 | } 46 | } 47 | ``` 48 | 49 | ## UDT client example 50 | 51 | Open a connection with [`UdtConnection`] 52 | 53 | ```no_run 54 | use std::net::Ipv4Addr; 55 | use tokio::io::{AsyncWriteExt, Result}; 56 | use tokio_udt::UdtConnection; 57 | 58 | #[tokio::main] 59 | async fn main() -> Result<()> { 60 | let port = 9000; 61 | let mut connection = UdtConnection::connect((Ipv4Addr::LOCALHOST, port), None).await?; 62 | loop { 63 | connection.write_all(b"Hello World!").await?; 64 | } 65 | } 66 | ``` 67 | */ 68 | mod ack_window; 69 | mod common; 70 | mod configuration; 71 | mod connection; 72 | mod control_packet; 73 | mod data_packet; 74 | mod flow; 75 | mod listener; 76 | mod loss_list; 77 |
mod multiplexer; 78 | mod packet; 79 | mod queue; 80 | mod rate_control; 81 | mod seq_number; 82 | mod socket; 83 | mod state; 84 | mod udt; 85 | 86 | pub use configuration::UdtConfiguration; 87 | pub use connection::UdtConnection; 88 | pub use listener::UdtListener; 89 | pub use rate_control::RateControl; 90 | pub use seq_number::SeqNumber; 91 | 92 | #[cfg(doctest)] 93 | doc_comment::doctest!("../README.md"); 94 | -------------------------------------------------------------------------------- /src/configuration.rs: -------------------------------------------------------------------------------- 1 | use std::time::Duration; 2 | 3 | const DEFAULT_MSS: u32 = 1500; 4 | const DEFAULT_UDT_BUF_SIZE: u32 = 81920; 5 | const DEFAULT_UDP_BUF_SIZE: usize = 8_000_000; 6 | const UDT_VERSION: u32 = 4; 7 | 8 | /// Options for UDT protocol 9 | #[derive(Debug, Clone)] 10 | pub struct UdtConfiguration { 11 | /// Packet size: the optimal size is the network MTU size. The default value is 1500 bytes. 12 | /// A UDT connection will choose the smaller value of the MSS between the two peer sides. 13 | pub mss: u32, 14 | /// Maximum window size (nb of packets). 15 | /// Internal parameter: you should set it to not less than `rcv_buf_size`. 16 | /// Default: 256000 17 | pub flight_flag_size: u32, 18 | /// Size of temporary storage for packets to send (nb of packets) 19 | pub snd_buf_size: u32, 20 | /// Size of temporary storage for packets to receive (nb of packets) 21 | pub rcv_buf_size: u32, 22 | /// UDT uses UDP as the data channel, so the UDP buffer size may affect the performance. 23 | /// The sending buffer size is applied on the UDP socket. The actual value used 24 | /// by the kernel is bounded by "net.core.wmem_max". 25 | pub udp_snd_buf_size: usize, 26 | /// UDT uses UDP as the data channel, so the UDP buffer size may affect the performance. 27 | /// The receiving buffer size is applied on the UDP socket. 
The actual value used 28 | /// by the kernel is bounded by "net.core.rmem_max". 29 | pub udp_rcv_buf_size: usize, 30 | /// Whether SO_REUSEPORT option should be set on the UDP socket. 31 | /// On Linux, this option can be useful to load-balance packets 32 | /// from multiple clients to distinct threads and distinct UDT multiplexers. 33 | /// Default: false. 34 | pub udp_reuse_port: bool, 35 | /// Whether a potential existing UDT multiplexer (and associated UDP socket) 36 | /// should be reused when binding the same port. The preexisting listener 37 | /// must have been created with this option set to true. 38 | /// For optimal throughput from multiple clients, using 39 | /// `udp_reuse_port` may be preferable. 40 | /// Default: true 41 | pub reuse_mux: bool, 42 | /// UDT rendez-vous mode. (NOT IMPLEMENTED) 43 | pub rendezvous: bool, 44 | /// Maximum number of pending UDT connections to accept. Default: 1000 45 | pub accept_queue_size: usize, 46 | /// Linger time on close(). Default: 10 seconds 47 | pub linger_timeout: Option, 48 | } 49 | 50 | impl UdtConfiguration { 51 | pub fn udt_version(&self) -> u32 { 52 | UDT_VERSION 53 | } 54 | } 55 | 56 | impl Default for UdtConfiguration { 57 | fn default() -> Self { 58 | Self { 59 | mss: DEFAULT_MSS, 60 | flight_flag_size: 256000, 61 | snd_buf_size: DEFAULT_UDT_BUF_SIZE, 62 | rcv_buf_size: DEFAULT_UDT_BUF_SIZE * 2, 63 | udp_snd_buf_size: DEFAULT_UDP_BUF_SIZE, 64 | udp_rcv_buf_size: DEFAULT_UDP_BUF_SIZE, 65 | udp_reuse_port: false, 66 | linger_timeout: Some(Duration::from_secs(10)), 67 | reuse_mux: true, 68 | rendezvous: false, 69 | accept_queue_size: 1000, 70 | } 71 | } 72 | } 73 | -------------------------------------------------------------------------------- /src/queue/rcv_buffer.rs: -------------------------------------------------------------------------------- 1 | use crate::data_packet::UdtDataPacket; 2 | use crate::seq_number::{MsgNumber, SeqNumber}; 3 | use std::collections::BTreeMap; 4 | use tokio::io::ReadBuf; 5 
| 6 | #[derive(Debug)] 7 | pub(crate) struct RcvBuffer { 8 | packets: BTreeMap, 9 | max_size: u32, 10 | next_to_read: SeqNumber, 11 | next_to_ack: SeqNumber, 12 | } 13 | 14 | impl RcvBuffer { 15 | pub fn new(max_size: u32, initial_seq_number: SeqNumber) -> Self { 16 | Self { 17 | max_size, 18 | packets: BTreeMap::new(), 19 | next_to_read: initial_seq_number, 20 | next_to_ack: initial_seq_number, 21 | } 22 | } 23 | 24 | pub fn get_available_buf_size(&self) -> u32 { 25 | self.max_size - self.packets.len() as u32 26 | } 27 | 28 | pub fn insert(&mut self, packet: UdtDataPacket) { 29 | let seq_number = packet.header.seq_number; 30 | self.packets.entry(seq_number).or_insert(packet); 31 | } 32 | 33 | pub fn drop_msg(&mut self, msg: MsgNumber) { 34 | self.packets 35 | .retain(|_k, packet| packet.header.msg_number != msg); 36 | } 37 | 38 | pub fn ack_data(&mut self, to: SeqNumber) { 39 | if (to - self.next_to_ack) > 0 { 40 | self.next_to_ack = to; 41 | } 42 | } 43 | 44 | pub fn has_data_to_read(&self) -> bool { 45 | let first = self.next_to_read; 46 | let last = self.next_to_ack; 47 | if first <= last { 48 | return self.packets.range(first..last).next().is_some(); 49 | } else { 50 | return self 51 | .packets 52 | .range(first..=SeqNumber::max()) 53 | .next() 54 | .is_some() 55 | || self.packets.range(SeqNumber::zero()..last).next().is_some(); 56 | } 57 | } 58 | 59 | pub fn read_buffer(&mut self, buf: &mut ReadBuf<'_>) -> usize { 60 | if self.next_to_read == self.next_to_ack { 61 | return 0; 62 | } 63 | 64 | let packets = { 65 | if self.next_to_read <= self.next_to_ack { 66 | self.packets 67 | .range(self.next_to_read..self.next_to_ack) 68 | .chain( 69 | self.packets.range(SeqNumber::zero()..SeqNumber::zero()), //empty 70 | ) 71 | } else { 72 | self.packets 73 | .range(self.next_to_read..=SeqNumber::max()) 74 | .chain(self.packets.range(SeqNumber::zero()..self.next_to_ack)) 75 | } 76 | }; 77 | 78 | let mut written = 0; 79 | let mut to_remove = vec![]; 80 | for (key, packet) 
in packets { 81 | let packet_len = packet.data.len(); 82 | if buf.remaining() < packet_len { 83 | break; 84 | } 85 | buf.put_slice(&packet.data); 86 | written += packet_len; 87 | to_remove.push(*key); 88 | self.next_to_read = *key + 1; 89 | } 90 | 91 | to_remove.iter().for_each(|k| { 92 | self.packets.remove(k); 93 | }); 94 | 95 | written 96 | } 97 | } 98 | -------------------------------------------------------------------------------- /src/seq_number.rs: -------------------------------------------------------------------------------- 1 | use rand::Rng; 2 | use std::marker::PhantomData; 3 | 4 | pub trait SeqConstants: Clone { 5 | const MAX_NUMBER: u32; 6 | 7 | fn threshold() -> u32 { 8 | Self::MAX_NUMBER / 2 9 | } 10 | } 11 | 12 | /// A sequence number, with a valid values in `[0, T::MAX_NUMBER]` 13 | /// that implements cyclic arithmetic 14 | #[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy)] 15 | pub struct GenericSeqNumber 16 | where 17 | T: SeqConstants, 18 | { 19 | number: u32, 20 | phantom: PhantomData, 21 | } 22 | 23 | impl From for GenericSeqNumber { 24 | fn from(number: u32) -> Self { 25 | Self { 26 | number, 27 | phantom: PhantomData, 28 | } 29 | } 30 | } 31 | 32 | impl GenericSeqNumber { 33 | pub const MAX_NUMBER: u32 = T::MAX_NUMBER; 34 | 35 | pub fn number(self) -> u32 { 36 | self.number 37 | } 38 | 39 | pub fn random() -> Self { 40 | rand::thread_rng().gen_range(0..=T::MAX_NUMBER).into() 41 | } 42 | 43 | pub fn zero() -> Self { 44 | 0.into() 45 | } 46 | 47 | pub fn max() -> Self { 48 | T::MAX_NUMBER.into() 49 | } 50 | } 51 | 52 | impl std::ops::Sub for GenericSeqNumber { 53 | type Output = i32; 54 | 55 | #[allow(clippy::neg_multiply)] 56 | fn sub(self, other: Self) -> Self::Output { 57 | if self.number.abs_diff(other.number) <= T::threshold() { 58 | self.number as i32 - other.number as i32 59 | } else if self.number < T::threshold() { 60 | (self.number + T::MAX_NUMBER + 1 - other.number) as i32 * -1 61 | } else { 62 | (other.number + 
T::MAX_NUMBER + 1 - self.number) as i32 63 | } 64 | } 65 | } 66 | 67 | impl std::ops::Add for GenericSeqNumber { 68 | type Output = GenericSeqNumber; 69 | 70 | fn add(self, rhs: i32) -> Self { 71 | let resp = ((self.number as i64 + rhs as i64).rem_euclid(T::MAX_NUMBER as i64 + 1)) as u32; 72 | resp.into() 73 | } 74 | } 75 | 76 | impl std::ops::Sub for GenericSeqNumber { 77 | type Output = GenericSeqNumber; 78 | 79 | #[allow(clippy::neg_multiply)] 80 | fn sub(self, rhs: i32) -> Self { 81 | self + (rhs * -1) 82 | } 83 | } 84 | 85 | #[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Copy)] 86 | pub struct SeqNumberConstants; 87 | impl SeqConstants for SeqNumberConstants { 88 | const MAX_NUMBER: u32 = 0x7fff_ffff; 89 | } 90 | 91 | /// Packet Sequence Number (on 31 bits) 92 | pub type SeqNumber = GenericSeqNumber; 93 | 94 | #[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Copy)] 95 | pub struct AckSeqNumberConstants; 96 | impl SeqConstants for AckSeqNumberConstants { 97 | const MAX_NUMBER: u32 = 0x7fff_ffff; 98 | } 99 | 100 | /// ACK sequence number (on 31 bits) 101 | pub type AckSeqNumber = GenericSeqNumber; 102 | 103 | #[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Copy)] 104 | pub struct MsgNumberConstants; 105 | impl SeqConstants for MsgNumberConstants { 106 | const MAX_NUMBER: u32 = 0x1fff_ffff; 107 | } 108 | 109 | /// Message Number (on 29 bits) 110 | pub type MsgNumber = GenericSeqNumber; 111 | -------------------------------------------------------------------------------- /src/listener.rs: -------------------------------------------------------------------------------- 1 | use crate::configuration::UdtConfiguration; 2 | use crate::connection::UdtConnection; 3 | use crate::socket::{SocketType, UdtStatus}; 4 | use crate::udt::{SocketRef, Udt}; 5 | use std::net::SocketAddr; 6 | use tokio::io::{Error, ErrorKind, Result}; 7 | 8 | /// An object representing a UDT socket listening for incoming connections 9 | pub struct UdtListener { 10 | socket: 
SocketRef, 11 | } 12 | 13 | impl UdtListener { 14 | pub async fn bind(bind_addr: SocketAddr, config: Option) -> Result { 15 | let socket = { 16 | let mut udt = Udt::get().write().await; 17 | udt.new_socket(SocketType::Stream, config)?.clone() 18 | }; 19 | 20 | if socket.configuration.read().unwrap().rendezvous { 21 | return Err(Error::new( 22 | ErrorKind::Unsupported, 23 | "listen is not supported in rendezvous connection setup", 24 | )); 25 | } 26 | 27 | let socket_id = socket.socket_id; 28 | 29 | { 30 | let mut udt = Udt::get().write().await; 31 | udt.bind(socket_id, bind_addr).await?; 32 | } 33 | 34 | { 35 | let socket_ref = socket.clone(); 36 | let mux = socket 37 | .multiplexer() 38 | .expect("multiplexer is not initialized"); 39 | *mux.listener.write().await = Some(socket_ref); 40 | *socket.status.lock().unwrap() = UdtStatus::Listening; 41 | 42 | println!("Now listening on {:?}", bind_addr); 43 | } 44 | 45 | Ok(Self { socket }) 46 | } 47 | 48 | pub async fn accept(&self) -> Result<(SocketAddr, UdtConnection)> { 49 | { 50 | if self.socket.configuration.read().unwrap().rendezvous { 51 | return Err(Error::new( 52 | ErrorKind::Unsupported, 53 | "no 'accept' in rendezvous connection setup", 54 | )); 55 | } 56 | } 57 | 58 | let accepted_socket_id = loop { 59 | let notified = { 60 | if self.socket.status() != UdtStatus::Listening { 61 | return Err(Error::new( 62 | ErrorKind::Other, 63 | "socket is not in listening state", 64 | )); 65 | } 66 | 67 | let mut queue = self.socket.queued_sockets.write().await; 68 | if let Some(socket_id) = queue.iter().next() { 69 | let socket_id = *socket_id; 70 | queue.remove(&socket_id); 71 | break socket_id; 72 | } 73 | self.socket.accept_notify.notified() 74 | }; 75 | notified.await 76 | }; 77 | 78 | let udt = Udt::get().read().await; 79 | let accepted_socket = udt.get_socket(accepted_socket_id).ok_or_else(|| { 80 | Error::new( 81 | ErrorKind::Other, 82 | "invalid socket id when accepting connection", 83 | ) 84 | })?; 85 | 86 | let 
peer_addr = accepted_socket.peer_addr().ok_or_else(|| { 87 | Error::new( 88 | ErrorKind::Other, 89 | "unknown peer address for accepted connection", 90 | ) 91 | })?; 92 | 93 | Ok((peer_addr, UdtConnection::new(accepted_socket))) 94 | } 95 | 96 | pub fn local_addr(&self) -> Result { 97 | self.socket.multiplexer().unwrap().channel.local_addr() 98 | } 99 | 100 | pub fn socket_id(&self) -> u32 { 101 | self.socket.socket_id 102 | } 103 | } 104 | -------------------------------------------------------------------------------- /src/data_packet.rs: -------------------------------------------------------------------------------- 1 | use crate::seq_number::{MsgNumber, SeqNumber}; 2 | use bytes::Bytes; 3 | use tokio::io::{Error, ErrorKind, Result}; 4 | 5 | pub const UDT_DATA_HEADER_SIZE: usize = 16; 6 | 7 | #[derive(Debug)] 8 | pub(crate) struct UdtDataPacket { 9 | pub header: UdtDataPacketHeader, 10 | pub data: Bytes, 11 | } 12 | 13 | impl UdtDataPacket { 14 | pub fn deserialize(raw: &[u8]) -> Result { 15 | let header = UdtDataPacketHeader::deserialize(&raw[..UDT_DATA_HEADER_SIZE])?; 16 | let data = Bytes::copy_from_slice(&raw[UDT_DATA_HEADER_SIZE..]); 17 | Ok(Self { header, data }) 18 | } 19 | 20 | pub fn payload_len(&self) -> u32 { 21 | self.data.len() as u32 22 | } 23 | 24 | pub fn serialize(&self) -> Vec { 25 | let mut buffer = Vec::with_capacity(1500); 26 | buffer.extend_from_slice(&self.header.serialize()); 27 | buffer.extend_from_slice(&self.data); 28 | buffer 29 | } 30 | } 31 | 32 | #[derive(Debug)] 33 | pub(crate) struct UdtDataPacketHeader { 34 | // bit 0 = 0 35 | pub seq_number: SeqNumber, // bits 1-31 36 | pub position: PacketPosition, // bits 32-33 37 | pub in_order: bool, // bit 34 38 | pub msg_number: MsgNumber, // bits 35-63 39 | pub timestamp: u32, // bits 64-95 40 | pub dest_socket_id: u32, // bits 96-127 41 | } 42 | 43 | impl UdtDataPacketHeader { 44 | pub fn deserialize(raw: &[u8]) -> Result { 45 | if raw.len() < 16 { 46 | return Err(Error::new( 47 | 
/// Position of a data packet within its message, encoded on two bits.
#[derive(Debug, Clone, Copy)]
pub(crate) enum PacketPosition {
    Middle = 0,
    Last = 1,
    First = 2,
    Only = 3,
}

impl TryFrom<u8> for PacketPosition {
    type Error = Error;

    /// Decodes the two-bit position field.
    ///
    /// # Errors
    ///
    /// Returns `ErrorKind::InvalidData` for any value above `0b11`.
    fn try_from(raw_position: u8) -> Result<Self> {
        // Indexed by wire value, i.e. by each variant's discriminant.
        const BY_CODE: [PacketPosition; 4] = [
            PacketPosition::Middle,
            PacketPosition::Last,
            PacketPosition::First,
            PacketPosition::Only,
        ];

        BY_CODE
            .get(usize::from(raw_position))
            .copied()
            .ok_or_else(|| Error::new(ErrorKind::InvalidData, "invalid packet position"))
    }
}
std::collections::VecDeque; 2 | use tokio::time::{Duration, Instant}; 3 | 4 | const ARRIVAL_WINDOW_SIZE: usize = 16; 5 | const PROBE_WINDOW_SIZE: usize = 64; 6 | pub const PROBE_MODULO: u32 = 16; 7 | 8 | #[derive(Debug)] 9 | pub(crate) struct UdtFlow { 10 | pub flow_window_size: u32, 11 | 12 | arrival_window: VecDeque, 13 | probe_window: VecDeque, 14 | last_arrival_time: Instant, 15 | probe_time: Instant, 16 | pub rtt: Duration, 17 | pub rtt_var: Duration, 18 | pub peer_bandwidth: u32, 19 | pub peer_delivery_rate: u32, 20 | } 21 | 22 | impl Default for UdtFlow { 23 | fn default() -> Self { 24 | let now = Instant::now(); 25 | Self { 26 | flow_window_size: 0, 27 | last_arrival_time: now, 28 | arrival_window: VecDeque::new(), 29 | probe_time: now, 30 | probe_window: VecDeque::new(), 31 | rtt: Duration::from_millis(100), 32 | rtt_var: Duration::from_millis(50), 33 | peer_bandwidth: 1, 34 | peer_delivery_rate: 16, 35 | } 36 | } 37 | } 38 | 39 | impl UdtFlow { 40 | pub fn on_pkt_arrival(&mut self, now: Instant) { 41 | self.arrival_window.push_back(now - self.last_arrival_time); 42 | if self.arrival_window.len() > ARRIVAL_WINDOW_SIZE { 43 | self.arrival_window.pop_front(); 44 | } 45 | self.last_arrival_time = now; 46 | } 47 | 48 | pub fn on_probe1_arrival(&mut self) { 49 | self.probe_time = Instant::now(); 50 | } 51 | 52 | pub fn on_probe2_arrival(&mut self) { 53 | let now = Instant::now(); 54 | self.probe_window.push_back(now - self.probe_time); 55 | if self.probe_window.len() > PROBE_WINDOW_SIZE { 56 | self.probe_window.pop_front(); 57 | } 58 | } 59 | 60 | /// Returns a number of packets per second 61 | pub fn get_pkt_rcv_speed(&self) -> u32 { 62 | let length = self.arrival_window.len(); 63 | let mut values = self.arrival_window.clone(); 64 | let (_, median, _) = values.make_contiguous().select_nth_unstable(length / 2); 65 | let median = *median; 66 | let values: Vec<_> = values 67 | .into_iter() 68 | .filter(|x| *x > median / 8 && *x < median * 8) 69 | .collect(); 70 | 
if values.len() < ARRIVAL_WINDOW_SIZE / 2 { 71 | return 0; 72 | } 73 | let total_duration: Duration = values.iter().sum(); 74 | (values.len() as f64 / total_duration.as_secs_f64()).ceil() as u32 75 | } 76 | 77 | pub fn get_bandwidth(&self) -> u32 { 78 | if self.probe_window.is_empty() { 79 | return 0; 80 | } 81 | let length = self.probe_window.len(); 82 | let mut values = self.probe_window.clone(); 83 | let (_, median, _) = values.make_contiguous().select_nth_unstable(length / 2); 84 | let median = *median; 85 | let values: Vec<_> = values 86 | .into_iter() 87 | .filter(|x| *x > median / 8 && *x < median * 8) 88 | .collect(); 89 | let total_duration: Duration = values.iter().sum(); 90 | if total_duration.is_zero() { 91 | return 0; 92 | } 93 | (values.len() as f64 / total_duration.as_secs_f64()).ceil() as u32 94 | } 95 | 96 | pub fn update_rtt(&mut self, new_val: Duration) { 97 | self.rtt = (7 * self.rtt + new_val) / 8; 98 | } 99 | 100 | pub fn update_rtt_var(&mut self, new_val: Duration) { 101 | self.rtt_var = (3 * self.rtt_var + new_val) / 4; 102 | } 103 | 104 | pub fn update_bandwidth(&mut self, new_val: u32) { 105 | self.peer_bandwidth = (7 * self.peer_bandwidth + new_val) / 8; 106 | } 107 | 108 | pub fn update_peer_delivery_rate(&mut self, new_val: u32) { 109 | self.peer_delivery_rate = (7 * self.peer_delivery_rate + new_val) / 8; 110 | } 111 | } 112 | -------------------------------------------------------------------------------- /src/connection.rs: -------------------------------------------------------------------------------- 1 | use crate::configuration::UdtConfiguration; 2 | use crate::socket::{SocketType, UdtStatus}; 3 | use crate::udt::{SocketRef, Udt}; 4 | use std::net::SocketAddr; 5 | use std::pin::Pin; 6 | use std::task::{Context, Poll}; 7 | use tokio::io::{AsyncRead, AsyncWrite, Error, ErrorKind, ReadBuf, Result}; 8 | use tokio::net::{lookup_host, ToSocketAddrs}; 9 | 10 | pub struct UdtConnection { 11 | socket: SocketRef, 12 | } 13 | 14 | impl 
UdtConnection { 15 | pub(crate) fn new(socket: SocketRef) -> Self { 16 | Self { socket } 17 | } 18 | 19 | pub async fn connect( 20 | addr: impl ToSocketAddrs, 21 | config: Option, 22 | ) -> Result { 23 | Self::_bind_and_connect(None, addr, config).await 24 | } 25 | 26 | pub async fn bind_and_connect( 27 | bind_addr: SocketAddr, 28 | connect_addr: impl ToSocketAddrs, 29 | config: Option, 30 | ) -> Result { 31 | Self::_bind_and_connect(Some(bind_addr), connect_addr, config).await 32 | } 33 | 34 | async fn _bind_and_connect( 35 | bind_addr: Option, 36 | addrs: impl ToSocketAddrs, 37 | config: Option, 38 | ) -> Result { 39 | let socket = { 40 | let mut udt = Udt::get().write().await; 41 | udt.new_socket(SocketType::Stream, config)?.clone() 42 | }; 43 | 44 | let mut last_err = None; 45 | let mut connected = false; 46 | 47 | for addr in lookup_host(addrs).await? { 48 | match socket.connect(addr, bind_addr).await { 49 | Ok(()) => { 50 | connected = true; 51 | break; 52 | } 53 | Err(e) => { 54 | last_err = Some(e); 55 | } 56 | } 57 | } 58 | 59 | if !connected { 60 | return Err(last_err.unwrap_or_else(|| { 61 | Error::new(ErrorKind::InvalidInput, "could not resolve address") 62 | })); 63 | } 64 | 65 | loop { 66 | let status = socket.wait_for_connection().await; 67 | if status != UdtStatus::Connecting { 68 | break; 69 | } 70 | } 71 | Ok(Self::new(socket)) 72 | } 73 | 74 | pub async fn send(&self, msg: &[u8]) -> Result<()> { 75 | self.socket.send(msg) 76 | } 77 | 78 | pub async fn recv(&self, buf: &mut [u8]) -> Result { 79 | let nbytes = self.socket.recv(buf).await?; 80 | Ok(nbytes) 81 | } 82 | 83 | pub fn rate_control( 84 | &self, 85 | ) -> std::sync::RwLockWriteGuard<'_, crate::rate_control::RateControl> { 86 | self.socket.rate_control.write().unwrap() 87 | } 88 | 89 | pub async fn close(&self) { 90 | self.socket.close().await 91 | } 92 | 93 | pub fn socket_id(&self) -> u32 { 94 | self.socket.socket_id 95 | } 96 | } 97 | 98 | impl AsyncRead for UdtConnection { 99 | fn 
poll_read( 100 | self: Pin<&mut Self>, 101 | cx: &mut Context<'_>, 102 | buf: &mut ReadBuf<'_>, 103 | ) -> Poll> { 104 | match self.socket.poll_recv(buf) { 105 | Poll::Ready(res) => Poll::Ready(res.map(|_| ())), 106 | Poll::Pending => { 107 | let waker = cx.waker().clone(); 108 | let socket = self.socket.clone(); 109 | tokio::spawn(async move { 110 | socket.wait_for_data_to_read().await; 111 | waker.wake(); 112 | }); 113 | Poll::Pending 114 | } 115 | } 116 | } 117 | } 118 | 119 | impl AsyncWrite for UdtConnection { 120 | fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll> { 121 | let buf_len = buf.len(); 122 | match self.socket.send(buf) { 123 | Ok(_) => Poll::Ready(Ok(buf_len)), 124 | Err(err) => match err.kind() { 125 | ErrorKind::OutOfMemory => { 126 | let waker = cx.waker().clone(); 127 | let socket = self.socket.clone(); 128 | tokio::spawn(async move { 129 | socket.wait_for_next_ack_or_empty_snd_buffer().await; 130 | waker.wake(); 131 | }); 132 | Poll::Pending 133 | } 134 | _ => Poll::Ready(Err(err)), 135 | }, 136 | } 137 | } 138 | 139 | fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { 140 | match self.socket.snd_buffer_is_empty() { 141 | true => Poll::Ready(Ok(())), 142 | false => { 143 | let waker = cx.waker().clone(); 144 | let socket = self.socket.clone(); 145 | tokio::spawn(async move { 146 | socket.wait_for_next_ack_or_empty_snd_buffer().await; 147 | waker.wake(); 148 | }); 149 | Poll::Pending 150 | } 151 | } 152 | } 153 | 154 | fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { 155 | if self.socket.status() == UdtStatus::Closed { 156 | return Poll::Ready(Ok(())); 157 | } 158 | let socket = self.socket.clone(); 159 | let waker = cx.waker().clone(); 160 | tokio::spawn(async move { 161 | socket.close().await; 162 | waker.wake(); 163 | }); 164 | Poll::Pending 165 | } 166 | } 167 | -------------------------------------------------------------------------------- /src/queue/snd_buffer.rs: 
-------------------------------------------------------------------------------- 1 | use crate::data_packet::{PacketPosition, UdtDataPacket, UdtDataPacketHeader}; 2 | use crate::seq_number::MsgNumber; 3 | use crate::seq_number::SeqNumber; 4 | use crate::socket::SocketId; 5 | use bytes::Bytes; 6 | use std::collections::VecDeque; 7 | use tokio::io::{Error, ErrorKind, Result as IoResult}; 8 | use tokio::time::{Duration, Instant}; 9 | 10 | const FETCH_BATCH_SIZE: usize = 100; 11 | const DEFAULT_PAYLOAD_SIZE: usize = 1500; 12 | 13 | #[derive(Debug, Clone)] 14 | pub(crate) struct SndBufferBlock { 15 | data: Bytes, 16 | msg_number: MsgNumber, 17 | origin_time: Instant, 18 | ttl: Option, // milliseconds, 19 | in_order: bool, 20 | position: PacketPosition, 21 | } 22 | 23 | impl SndBufferBlock { 24 | fn has_expired(&self) -> bool { 25 | if let Some(ttl) = self.ttl { 26 | return self.origin_time.elapsed() > Duration::from_millis(ttl); 27 | } 28 | false 29 | } 30 | 31 | fn as_data_packet( 32 | &self, 33 | seq_number: SeqNumber, 34 | dest_socket_id: SocketId, 35 | start_time: Instant, 36 | ) -> UdtDataPacket { 37 | UdtDataPacket { 38 | data: self.data.clone(), 39 | header: UdtDataPacketHeader { 40 | msg_number: self.msg_number, 41 | dest_socket_id, 42 | seq_number, 43 | in_order: self.in_order, 44 | position: self.position, 45 | timestamp: (start_time.elapsed().as_micros() & (u32::MAX as u128)) as u32, 46 | }, 47 | } 48 | } 49 | } 50 | 51 | #[derive(Debug)] 52 | pub(crate) struct SndBuffer { 53 | max_size: u32, 54 | buffer: VecDeque, 55 | payload_size: usize, 56 | next_msg_number: MsgNumber, 57 | current_position: usize, 58 | } 59 | 60 | impl SndBuffer { 61 | pub fn new(max_size: u32) -> Self { 62 | Self { 63 | max_size, 64 | buffer: VecDeque::new(), 65 | payload_size: DEFAULT_PAYLOAD_SIZE, // overwritten after connection 66 | next_msg_number: MsgNumber::zero(), 67 | current_position: 0, 68 | } 69 | } 70 | 71 | pub fn add_message(&mut self, data: &[u8], ttl: Option, in_order: 
bool) -> IoResult<()> { 72 | let msg_number = self.next_msg_number; 73 | let now = Instant::now(); 74 | let chunks = data.chunks(self.payload_size); 75 | let chunks_len = chunks.len(); 76 | 77 | if self.buffer.len() + chunks_len > self.max_size as usize { 78 | return Err(Error::new(ErrorKind::OutOfMemory, "Send buffer is full")); 79 | } 80 | 81 | self.buffer 82 | .extend(chunks.enumerate().map(|(idx, chunk)| SndBufferBlock { 83 | data: Bytes::copy_from_slice(chunk), 84 | msg_number, 85 | origin_time: now, 86 | ttl, 87 | in_order, 88 | position: { 89 | if idx == 0 && chunks_len == 1 { 90 | PacketPosition::Only 91 | } else if idx == 0 { 92 | PacketPosition::First 93 | } else if idx == chunks_len - 1 { 94 | PacketPosition::Last 95 | } else { 96 | PacketPosition::Middle 97 | } 98 | }, 99 | })); 100 | self.next_msg_number = self.next_msg_number + 1; 101 | Ok(()) 102 | } 103 | 104 | pub fn ack_data(&mut self, offset: i32) { 105 | for _ in 0..offset { 106 | if self.buffer.pop_front().is_some() { 107 | self.current_position -= 1; 108 | } 109 | } 110 | } 111 | 112 | pub fn read_data( 113 | &mut self, 114 | offset: usize, 115 | seq_number: SeqNumber, 116 | dest_socket_id: SocketId, 117 | start_time: Instant, 118 | ) -> Result { 119 | if let Some(block) = self.buffer.get(offset) { 120 | if block.has_expired() { 121 | // Move current_position to next message 122 | let mut pos = offset + 1; 123 | let mut msg_len = 1; 124 | while pos < self.buffer.len() { 125 | if self.buffer[pos].msg_number == block.msg_number { 126 | msg_len += 1; 127 | } else { 128 | break; 129 | } 130 | pos += 1; 131 | } 132 | if offset <= self.current_position && self.current_position < pos { 133 | self.current_position = pos; 134 | } 135 | Err((block.msg_number, msg_len)) 136 | } else { 137 | Ok(block.as_data_packet(seq_number, dest_socket_id, start_time)) 138 | } 139 | } else { 140 | Err((MsgNumber::zero(), 0)) // No msg found 141 | } 142 | } 143 | 144 | pub fn fetch_batch( 145 | &mut self, 146 | mut 
seq_number: SeqNumber, 147 | dest_socket_id: SocketId, 148 | start_time: Instant, 149 | ) -> Vec { 150 | let blocks: Vec<_> = self 151 | .buffer 152 | .range(self.current_position..) 153 | .take(FETCH_BATCH_SIZE) 154 | .map(|block| { 155 | let packet = block.as_data_packet(seq_number, dest_socket_id, start_time); 156 | seq_number = seq_number + 1; 157 | packet 158 | }) 159 | .collect(); 160 | self.current_position += blocks.len(); 161 | blocks 162 | } 163 | 164 | pub fn is_empty(&self) -> bool { 165 | self.buffer.is_empty() 166 | } 167 | 168 | pub fn set_payload_size(&mut self, payload_size: usize) { 169 | self.payload_size = payload_size; 170 | } 171 | } 172 | -------------------------------------------------------------------------------- /src/queue/snd_queue.rs: -------------------------------------------------------------------------------- 1 | use crate::socket::{SocketId, UdtSocket}; 2 | use crate::udt::{SocketRef, Udt}; 3 | use std::cmp::Ordering; 4 | use std::collections::{BTreeMap, BinaryHeap}; 5 | use std::sync::{Arc, Mutex, Weak}; 6 | use tokio::io::Result; 7 | use tokio::sync::Notify; 8 | use tokio::time::Instant; 9 | 10 | const TOKIO_CHANNEL_CAPACITY: usize = 50; 11 | 12 | #[derive(Debug, PartialEq, Eq, Clone)] 13 | struct SendQueueNode { 14 | timestamp: Instant, 15 | socket_id: SocketId, 16 | } 17 | 18 | impl Ord for SendQueueNode { 19 | // Send queue should be sorted by smaller timestamp first 20 | fn cmp(&self, other: &Self) -> Ordering { 21 | self.timestamp.cmp(&other.timestamp).reverse() 22 | } 23 | } 24 | 25 | impl PartialOrd for SendQueueNode { 26 | fn partial_cmp(&self, other: &Self) -> Option { 27 | Some(self.cmp(other)) 28 | } 29 | } 30 | 31 | #[derive(Debug)] 32 | pub(crate) struct UdtSndQueue { 33 | queue: Mutex>, 34 | notify: Notify, 35 | start_time: Instant, 36 | socket_refs: Mutex>>, 37 | } 38 | 39 | impl UdtSndQueue { 40 | pub fn new() -> Self { 41 | UdtSndQueue { 42 | queue: Mutex::new(BinaryHeap::new()), 43 | notify: Notify::new(), 44 
| start_time: Instant::now(), 45 | socket_refs: Mutex::new(BTreeMap::new()), 46 | } 47 | } 48 | 49 | async fn get_socket(&self, socket_id: SocketId) -> Option { 50 | let known_socket = self.socket_refs.lock().unwrap().get(&socket_id).cloned(); 51 | if let Some(socket) = known_socket { 52 | socket.upgrade() 53 | } else if let Some(socket) = Udt::get().read().await.get_socket(socket_id) { 54 | self.socket_refs 55 | .lock() 56 | .unwrap() 57 | .insert(socket_id, Arc::downgrade(&socket)); 58 | Some(socket) 59 | } else { 60 | None 61 | } 62 | } 63 | 64 | pub async fn worker(&self) -> Result<()> { 65 | let (tx, mut rx) = tokio::sync::mpsc::channel(TOKIO_CHANNEL_CAPACITY); 66 | 67 | tokio::spawn(async move { 68 | while let Some((socket, packets)) = rx.recv().await { 69 | let socket: SocketRef = socket; 70 | socket 71 | .send_data_packets(packets) 72 | .await 73 | .expect("failed to send packets") 74 | } 75 | }); 76 | 77 | loop { 78 | let next_node = { 79 | let mut sockets = self.queue.lock().unwrap(); 80 | let first_node = sockets.peek(); 81 | match first_node { 82 | Some(node) => { 83 | if node.timestamp <= Instant::now() { 84 | Ok(sockets.pop().unwrap()) 85 | } else { 86 | Err(Some(node.timestamp)) 87 | } 88 | } 89 | None => Err(None), 90 | } 91 | }; 92 | match next_node { 93 | Ok(node) => { 94 | if let Some(socket) = self.get_socket(node.socket_id).await { 95 | if let Some((packets, ts)) = socket.next_data_packets().await? { 96 | self.insert(ts, node.socket_id); 97 | tx.send((socket, packets)).await.unwrap(); 98 | } 99 | } 100 | } 101 | Err(Some(ts)) => { 102 | tokio::select! 
{ 103 | _ = Self::sleep_until(ts) => {} 104 | _ = self.notify.notified() => {} 105 | } 106 | } 107 | _ => { 108 | self.notify.notified().await; 109 | } 110 | } 111 | } 112 | } 113 | 114 | pub fn insert(&self, ts: Instant, socket_id: SocketId) { 115 | let mut sockets = self.queue.lock().unwrap(); 116 | sockets.push(SendQueueNode { 117 | socket_id, 118 | timestamp: ts, 119 | }); 120 | if let Some(node) = sockets.peek() { 121 | if node.socket_id == socket_id { 122 | self.notify.notify_one(); 123 | } 124 | } 125 | } 126 | 127 | pub fn update(&self, socket_id: SocketId, reschedule: bool) { 128 | if reschedule { 129 | let mut sockets = self.queue.lock().unwrap(); 130 | if let Some(mut node) = sockets.peek_mut() { 131 | if node.socket_id == socket_id { 132 | node.timestamp = self.start_time; 133 | self.notify.notify_one(); 134 | return; 135 | } 136 | }; 137 | }; 138 | if !self 139 | .queue 140 | .lock() 141 | .unwrap() 142 | .iter() 143 | .any(|n| n.socket_id == socket_id) 144 | { 145 | self.insert(self.start_time, socket_id); 146 | } else if reschedule { 147 | self.remove(socket_id); 148 | self.insert(self.start_time, socket_id); 149 | } 150 | } 151 | 152 | pub fn remove(&self, socket_id: SocketId) { 153 | let mut sockets = self.queue.lock().unwrap(); 154 | *sockets = sockets 155 | .iter() 156 | .filter(|n| n.socket_id != socket_id) 157 | .cloned() 158 | .collect(); 159 | } 160 | 161 | #[cfg(target_os = "linux")] 162 | async fn sleep_until(instant: tokio::time::Instant) { 163 | tokio_timerfd::Delay::new(instant.into_std()) 164 | .expect("failed to init delay") 165 | .await 166 | .expect("timerfd failed") 167 | } 168 | 169 | #[cfg(not(target_os = "linux"))] 170 | async fn sleep_until(instant: tokio::time::Instant) { 171 | tokio::time::sleep_until(instant).await 172 | } 173 | } 174 | -------------------------------------------------------------------------------- /src/multiplexer.rs: -------------------------------------------------------------------------------- 1 | use 
super::configuration::UdtConfiguration; 2 | use super::packet::UdtPacket; 3 | use crate::queue::{UdtRcvQueue, UdtSndQueue}; 4 | use crate::udt::SocketRef; 5 | use socket2::{Domain, Socket, Type}; 6 | use std::io::Result; 7 | use std::net::{Ipv4Addr, SocketAddr}; 8 | use std::sync::Arc; 9 | use tokio::net::UdpSocket; 10 | use tokio::sync::RwLock; 11 | 12 | pub type MultiplexerId = u32; 13 | 14 | #[derive(Debug)] 15 | pub struct UdtMultiplexer { 16 | pub id: MultiplexerId, 17 | pub port: u16, 18 | pub channel: Arc, 19 | pub reusable: bool, 20 | pub mss: u32, 21 | 22 | pub(crate) snd_queue: UdtSndQueue, 23 | pub(crate) rcv_queue: UdtRcvQueue, 24 | pub listener: RwLock>, 25 | } 26 | 27 | impl UdtMultiplexer { 28 | async fn new_udp_socket( 29 | config: &UdtConfiguration, 30 | bind_addr: Option, 31 | ) -> Result { 32 | let bind_addr = bind_addr.unwrap_or_else(|| (Ipv4Addr::UNSPECIFIED, 0).into()); 33 | let domain = if bind_addr.is_ipv6() { 34 | Domain::IPV6 35 | } else { 36 | Domain::IPV4 37 | }; 38 | tokio::task::spawn_blocking({ 39 | let config = config.clone(); 40 | move || { 41 | let socket = Socket::new(domain, Type::DGRAM, None)?; 42 | socket.set_recv_buffer_size(config.udp_rcv_buf_size)?; 43 | socket.set_send_buffer_size(config.udp_snd_buf_size)?; 44 | socket.set_reuse_port(config.udp_reuse_port)?; 45 | socket.set_nonblocking(true)?; 46 | socket.bind(&bind_addr.into())?; 47 | UdpSocket::from_std(socket.into()) 48 | } 49 | }) 50 | .await? 
51 | } 52 | 53 | pub(crate) async fn new( 54 | id: MultiplexerId, 55 | config: &UdtConfiguration, 56 | ) -> Result<(MultiplexerId, Arc)> { 57 | let udp_socket = Self::new_udp_socket(config, None).await?; 58 | let channel = Arc::new(udp_socket); 59 | let port = channel.local_addr()?.port(); 60 | 61 | let mux = Self { 62 | id, 63 | port, 64 | reusable: config.reuse_mux, 65 | mss: config.mss, 66 | channel: channel.clone(), 67 | snd_queue: UdtSndQueue::new(), 68 | rcv_queue: UdtRcvQueue::new(channel, config.mss), 69 | listener: RwLock::new(None), 70 | }; 71 | 72 | let mux = Arc::new(mux); 73 | mux.rcv_queue.set_multiplexer(&mux); 74 | Ok((id, mux)) 75 | } 76 | 77 | pub(crate) async fn bind( 78 | id: MultiplexerId, 79 | bind_addr: SocketAddr, 80 | config: &UdtConfiguration, 81 | ) -> Result<(MultiplexerId, Arc)> { 82 | let udp_socket = Self::new_udp_socket(config, Some(bind_addr)).await?; 83 | let port = udp_socket.local_addr()?.port(); 84 | 85 | let channel = Arc::new(udp_socket); 86 | let mux = Self { 87 | id, 88 | port, 89 | reusable: config.reuse_mux, 90 | mss: config.mss, 91 | channel: channel.clone(), 92 | snd_queue: UdtSndQueue::new(), 93 | rcv_queue: UdtRcvQueue::new(channel, config.mss), 94 | listener: RwLock::new(None), 95 | }; 96 | 97 | let mux = Arc::new(mux); 98 | mux.rcv_queue.set_multiplexer(&mux); 99 | Ok((id, mux)) 100 | } 101 | 102 | pub(crate) async fn send_to(&self, addr: &SocketAddr, packet: UdtPacket) -> Result { 103 | self.channel.send_to(&packet.serialize(), addr).await 104 | } 105 | 106 | #[cfg(target_os = "linux")] 107 | pub(crate) async fn send_mmsg_to( 108 | &self, 109 | addr: &SocketAddr, 110 | packets: impl Iterator, 111 | ) -> Result { 112 | use nix::sys::socket::{sendmmsg, MsgFlags, SendMmsgData, SockaddrStorage}; 113 | use std::io::IoSlice; 114 | use std::os::unix::io::AsRawFd; 115 | use tokio::io::{Error, ErrorKind, Interest}; 116 | let data: Vec<_> = packets.map(|p| p.serialize()).collect(); 117 | let dest: SockaddrStorage = 
(*addr).into(); 118 | let buffers: Vec> = data 119 | .iter() 120 | .map(|packet| SendMmsgData { 121 | iov: [IoSlice::new(packet)], 122 | cmsgs: &[], 123 | addr: Some(dest), 124 | _lt: Default::default(), 125 | }) 126 | .collect(); 127 | self.channel.writable().await?; 128 | let sent = self 129 | .channel 130 | .try_io(Interest::WRITABLE, || { 131 | let sock_fd = self.channel.as_raw_fd(); 132 | let sent: usize = sendmmsg(sock_fd, &buffers, MsgFlags::MSG_DONTWAIT) 133 | .map_err(|err| { 134 | if err == nix::errno::Errno::EWOULDBLOCK { 135 | return Error::new(ErrorKind::WouldBlock, "sendmmsg would block"); 136 | } 137 | Error::new(ErrorKind::Other, err) 138 | })? 139 | .into_iter() 140 | .sum(); 141 | Ok(sent) 142 | }) 143 | .unwrap_or(0); 144 | Ok(sent) 145 | } 146 | 147 | #[cfg(not(target_os = "linux"))] 148 | pub(crate) async fn send_mmsg_to( 149 | &self, 150 | addr: &SocketAddr, 151 | packets: impl Iterator, 152 | ) -> Result { 153 | self.channel.writable().await?; 154 | let mut sent = 0; 155 | for data in packets.map(|p| p.serialize()) { 156 | sent += self.channel.send_to(&data, addr).await?; 157 | } 158 | Ok(sent) 159 | } 160 | 161 | // pub fn get_local_addr(&self) -> SocketAddr { 162 | // self.channel 163 | // .local_addr() 164 | // .expect("failed to retrieve udp local addr") 165 | // } 166 | 167 | pub fn run(mux: Arc) { 168 | tokio::spawn({ 169 | let mux = mux.clone(); 170 | async move { mux.rcv_queue.worker().await.unwrap() } 171 | }); 172 | tokio::spawn(async move { mux.snd_queue.worker().await.unwrap() }); 173 | } 174 | } 175 | -------------------------------------------------------------------------------- /src/rate_control.rs: -------------------------------------------------------------------------------- 1 | use crate::flow::UdtFlow; 2 | use crate::seq_number::SeqNumber; 3 | use crate::socket::SYN_INTERVAL; 4 | use rand::Rng; 5 | use tokio::time::{Duration, Instant}; 6 | 7 | #[derive(Debug)] 8 | pub struct RateControl { 9 | pkt_send_period: Duration, 
10 | congestion_window_size: f64, 11 | max_window_size: f64, 12 | recv_rate: u32, 13 | bandwidth: u32, 14 | rtt: Duration, 15 | mss: f64, 16 | 17 | curr_snd_seq_number: SeqNumber, 18 | rc_interval: Duration, 19 | last_rate_increase: Instant, 20 | slow_start: bool, 21 | last_ack: SeqNumber, 22 | loss: bool, // has lost happenened since last rate increase 23 | last_dec_seq: SeqNumber, 24 | last_dec_period: Duration, 25 | nak_count: usize, 26 | dec_random: usize, 27 | avg_nak_num: usize, 28 | dec_count: usize, 29 | 30 | ack_period: Duration, 31 | ack_pkt_interval: usize, 32 | } 33 | 34 | impl RateControl { 35 | pub(crate) fn new() -> Self { 36 | Self { 37 | pkt_send_period: Duration::from_micros(1), 38 | congestion_window_size: 16.0, 39 | max_window_size: 16.0, 40 | recv_rate: 0, 41 | bandwidth: 0, 42 | rtt: Duration::default(), 43 | mss: 1500.0, 44 | 45 | curr_snd_seq_number: SeqNumber::zero(), 46 | rc_interval: SYN_INTERVAL, 47 | last_rate_increase: Instant::now(), 48 | slow_start: true, 49 | last_ack: SeqNumber::zero(), 50 | loss: false, 51 | last_dec_seq: SeqNumber::zero() - 1, 52 | last_dec_period: Duration::from_micros(1), 53 | nak_count: 0, 54 | avg_nak_num: 0, 55 | dec_random: 1, 56 | dec_count: 0, 57 | 58 | ack_period: SYN_INTERVAL, 59 | ack_pkt_interval: 0, 60 | } 61 | } 62 | 63 | pub(crate) fn init(&mut self, mss: u32, flow: &UdtFlow, seq_number: SeqNumber) { 64 | self.last_rate_increase = Instant::now(); 65 | self.mss = mss as f64; 66 | self.max_window_size = flow.flow_window_size as f64; 67 | 68 | self.slow_start = true; 69 | self.loss = false; 70 | self.curr_snd_seq_number = seq_number; 71 | self.last_ack = seq_number; 72 | self.last_dec_seq = seq_number - 1; 73 | 74 | self.recv_rate = flow.peer_delivery_rate; 75 | self.bandwidth = flow.peer_bandwidth; 76 | self.rtt = flow.rtt; 77 | } 78 | 79 | pub fn get_pkt_send_period(&self) -> Duration { 80 | self.pkt_send_period 81 | } 82 | 83 | pub fn get_congestion_window_size(&self) -> u32 { 84 | 
self.congestion_window_size as u32 85 | } 86 | 87 | pub fn get_ack_pkt_interval(&self) -> usize { 88 | self.ack_pkt_interval 89 | } 90 | 91 | pub fn get_ack_period(&self) -> Duration { 92 | std::cmp::min(SYN_INTERVAL, self.ack_period) 93 | } 94 | 95 | pub fn set_rtt(&mut self, rtt: Duration) { 96 | self.rtt = rtt; 97 | } 98 | 99 | pub fn set_rcv_rate(&mut self, pkt_per_sec: u32) { 100 | self.recv_rate = pkt_per_sec; 101 | } 102 | 103 | pub fn set_bandwidth(&mut self, pkt_per_sec: u32) { 104 | self.bandwidth = pkt_per_sec; 105 | } 106 | 107 | pub fn set_pkt_interval(&mut self, nb_pkts: usize) { 108 | self.ack_pkt_interval = nb_pkts; 109 | } 110 | 111 | pub fn on_ack(&mut self, ack: SeqNumber) { 112 | const MIN_INC: f64 = 0.01; 113 | 114 | let now = Instant::now(); 115 | if (now - self.last_rate_increase) < self.rc_interval { 116 | return; 117 | } 118 | self.last_rate_increase = now; 119 | 120 | if self.slow_start { 121 | self.congestion_window_size += (ack - self.last_ack) as f64; 122 | self.last_ack = ack; 123 | if self.congestion_window_size > self.max_window_size { 124 | self.slow_start = false; 125 | if self.recv_rate > 0 { 126 | self.pkt_send_period = Duration::from_secs(1) / self.recv_rate; 127 | } else { 128 | self.pkt_send_period = 129 | (self.rtt + self.rc_interval).div_f64(self.congestion_window_size); 130 | } 131 | } 132 | } else { 133 | self.congestion_window_size = 134 | self.recv_rate as f64 * (self.rtt + self.rc_interval).as_secs_f64() + 16.0 135 | } 136 | 137 | if self.slow_start { 138 | return; 139 | } 140 | 141 | if self.loss { 142 | self.loss = false; 143 | return; 144 | } 145 | 146 | let mut b = self.bandwidth as f64 - 1.0 / self.pkt_send_period.as_secs_f64(); 147 | if (self.pkt_send_period > self.last_dec_period) && (self.bandwidth as f64 / 9.0 < b) { 148 | b = self.bandwidth as f64 / 9.0; 149 | } 150 | let increase = if b <= 0.0 { 151 | MIN_INC 152 | } else { 153 | let inc = 10.0_f64.powf((b * self.mss as f64 * 8.0).log10().ceil()) * 1.5e-6 / 
self.mss; 154 | if inc < MIN_INC { 155 | MIN_INC 156 | } else { 157 | inc 158 | } 159 | }; 160 | self.pkt_send_period = Duration::from_secs_f64( 161 | (self.pkt_send_period.as_secs_f64() * self.rc_interval.as_secs_f64()) 162 | / (self.pkt_send_period.mul_f64(increase) + self.rc_interval).as_secs_f64(), 163 | ); 164 | } 165 | 166 | pub fn on_loss(&mut self, loss_seq: SeqNumber) { 167 | if self.slow_start { 168 | self.slow_start = false; 169 | if self.recv_rate > 0 { 170 | self.pkt_send_period = Duration::from_secs(1) / self.recv_rate; 171 | return; 172 | } 173 | self.pkt_send_period = 174 | (self.rtt + self.rc_interval).div_f64(self.congestion_window_size); 175 | } 176 | 177 | self.loss = true; 178 | if (loss_seq - self.last_dec_seq) > 0 { 179 | self.last_dec_period = self.pkt_send_period; 180 | self.pkt_send_period = self.pkt_send_period.mul_f64(1.125); 181 | self.avg_nak_num = 182 | (self.avg_nak_num as f64 * 0.875 + self.nak_count as f64 * 0.125).ceil() as usize; 183 | self.nak_count = 1; 184 | self.dec_count = 1; 185 | self.last_dec_seq = self.curr_snd_seq_number; 186 | 187 | self.dec_random = if self.avg_nak_num == 0 { 188 | 1 189 | } else { 190 | rand::thread_rng().gen_range(1..=self.avg_nak_num) 191 | }; 192 | } else { 193 | self.dec_count += 1; 194 | if self.dec_count <= 5 { 195 | self.nak_count += 1; 196 | if self.nak_count % self.dec_random == 0 { 197 | self.pkt_send_period = self.pkt_send_period.mul_f64(1.125); 198 | self.last_dec_seq = self.curr_snd_seq_number; 199 | } 200 | } 201 | } 202 | } 203 | 204 | pub fn set_curr_snd_seq_number(&mut self, seq: SeqNumber) { 205 | self.curr_snd_seq_number = seq; 206 | } 207 | 208 | pub fn on_timeout(&mut self) { 209 | if self.slow_start { 210 | self.slow_start = false; 211 | if self.recv_rate > 0 { 212 | self.pkt_send_period = Duration::from_secs(1) / self.recv_rate; 213 | } else { 214 | self.pkt_send_period = 215 | (self.rtt + self.rc_interval).div_f64(self.congestion_window_size); 216 | } 217 | } 218 | } 219 | } 
220 | -------------------------------------------------------------------------------- /src/loss_list.rs: -------------------------------------------------------------------------------- 1 | use crate::seq_number::SeqNumber; 2 | use std::collections::BTreeMap; 3 | 4 | #[derive(Debug)] 5 | pub(crate) struct LossList { 6 | sequences: BTreeMap, 7 | } 8 | 9 | impl LossList { 10 | pub fn new() -> Self { 11 | Self { 12 | sequences: BTreeMap::new(), 13 | } 14 | } 15 | 16 | pub fn insert(&mut self, n1: SeqNumber, n2: SeqNumber) { 17 | // TODO: limit size of loss list 18 | 19 | if n1.number() > n2.number() { 20 | self.insert(n1, SeqNumber::max()); 21 | self.insert(SeqNumber::zero(), n2); 22 | return; 23 | } 24 | 25 | let mut n2 = n2; 26 | if n2.number() > n1.number() { 27 | let mut keys_to_remove = vec![]; 28 | for (key, (_start, end)) in self.sequences.range((n1 + 1)..=n2) { 29 | keys_to_remove.push(*key); 30 | if *end > n2 { 31 | n2 = *end; 32 | } 33 | } 34 | for key in keys_to_remove { 35 | self.sequences.remove(&key); 36 | } 37 | } 38 | 39 | if let Some((_, (_start, end))) = self.sequences.range_mut(..=n1).next_back() { 40 | if *end >= n1 - 1 { 41 | *end = std::cmp::max(*end, n2); 42 | return; 43 | } 44 | } 45 | self.sequences.insert(n1, (n1, n2)); 46 | } 47 | 48 | pub fn remove(&mut self, num: SeqNumber) { 49 | if let Some((key, (start, end))) = self.sequences.range_mut(..=num).next_back() { 50 | if *start == num { 51 | let key = *key; 52 | let end = *end; 53 | self.sequences.remove(&key); 54 | if end > num { 55 | self.sequences.insert(num + 1, (num + 1, end)); 56 | } 57 | } else if *end >= num { 58 | let current_end = *end; 59 | *end = num - 1; 60 | if current_end > num { 61 | self.sequences.insert(num + 1, (num + 1, current_end)); 62 | } 63 | } 64 | } 65 | } 66 | 67 | pub fn remove_all(&mut self, n1: SeqNumber, n2: SeqNumber) { 68 | if n1 <= n2 { 69 | for i in (n1.number()..=n2.number()).rev() { 70 | self.remove(i.into()); 71 | } 72 | } else { 73 | 
self.remove_all(n1, SeqNumber::max()); 74 | self.remove_all(SeqNumber::zero(), n2); 75 | } 76 | } 77 | 78 | // fn find(&self, n1: SeqNumber, n2: SeqNumber) -> bool { 79 | // if n1 > n2 { 80 | // return self.find(n1, SeqNumber::max()) || self.find(SeqNumber::zero(), n2); 81 | // } 82 | 83 | // if let Some((_, (_start, end))) = self.sequences.range(..n1).next_back() { 84 | // if *end >= n1 { 85 | // return true; 86 | // } 87 | // } else if let Some((_, (start, _end))) = self.sequences.range(n1..).next() { 88 | // if *start <= n2 { 89 | // return true; 90 | // } 91 | // } 92 | // false 93 | // } 94 | 95 | pub fn is_empty(&self) -> bool { 96 | self.sequences.is_empty() 97 | } 98 | 99 | // pub fn get_loss_array(&self, limit: usize) -> Vec { 100 | // let mut array: Vec<_> = self 101 | // .sequences 102 | // .values() 103 | // .flat_map(|(start, end)| { 104 | // if start == end { 105 | // vec![start.number()] 106 | // } else { 107 | // vec![start.number() | 0x8000000, end.number()] 108 | // } 109 | // }) 110 | // .take(limit) 111 | // .collect(); 112 | 113 | // if let Some(v) = array.last() { 114 | // if *v >= 0x8000000 { 115 | // array.pop(); 116 | // } 117 | // } 118 | // array 119 | // } 120 | 121 | pub fn pop_after(&mut self, after: SeqNumber) -> Option { 122 | if self.sequences.is_empty() { 123 | return None; 124 | } 125 | if let Some((_, (_start, end))) = self.sequences.range(..=after).next_back() { 126 | if *end >= after { 127 | self.remove(after); 128 | return Some(after); 129 | } 130 | } 131 | if let Some((_, (start, _end))) = self.sequences.range(after..).next() { 132 | let start = *start; 133 | self.remove(start); 134 | return Some(start); 135 | } 136 | if let Some((_, (start, _end))) = self.sequences.iter().next() { 137 | let start = *start; 138 | self.remove(start); 139 | return Some(start); 140 | } 141 | None 142 | } 143 | 144 | pub fn peek_after(&self, after: SeqNumber) -> Option { 145 | if self.sequences.is_empty() { 146 | return None; 147 | } 148 | if let 
Some((_, (_start, end))) = self.sequences.range(..=after).next_back() { 149 | if *end >= after { 150 | return Some(after); 151 | } 152 | } 153 | if let Some((_, (start, _end))) = self.sequences.range(after..).next() { 154 | return Some(*start); 155 | } 156 | if let Some((_, (start, _end))) = self.sequences.iter().next() { 157 | return Some(*start); 158 | } 159 | None 160 | } 161 | } 162 | 163 | #[test] 164 | fn test_insert_sequences() { 165 | let mut loss_list = crate::loss_list::LossList::new(); 166 | loss_list.insert(5.into(), 10.into()); 167 | loss_list.insert(1.into(), 2.into()); 168 | assert_eq!(loss_list.sequences.len(), 2); 169 | 170 | let items: Vec<_> = loss_list.sequences.clone().into_iter().collect(); 171 | assert_eq!( 172 | items, 173 | [ 174 | (1.into(), (1.into(), 2.into())), 175 | (5.into(), (5.into(), 10.into())), 176 | ] 177 | ); 178 | 179 | assert_eq!(loss_list.peek_after(1.into()), Some(1.into())); 180 | assert_eq!(loss_list.peek_after(4.into()), Some(5.into())); 181 | assert_eq!(loss_list.peek_after(10.into()), Some(10.into())); 182 | assert_eq!(loss_list.peek_after(11.into()), Some(1.into())); 183 | } 184 | 185 | #[test] 186 | fn test_insert_overlapping_sequence() { 187 | let mut loss_list = crate::loss_list::LossList::new(); 188 | loss_list.insert(1.into(), 10.into()); 189 | loss_list.insert(5.into(), 20.into()); 190 | assert_eq!(loss_list.sequences.len(), 1); 191 | let items: Vec<_> = loss_list.sequences.into_iter().collect(); 192 | assert_eq!(items, [(1.into(), (1.into(), 20.into())),]); 193 | } 194 | 195 | #[test] 196 | fn test_insert_with_multiple_overlapping_sequences() { 197 | let mut loss_list = crate::loss_list::LossList::new(); 198 | loss_list.insert(6.into(), 10.into()); 199 | loss_list.insert(12.into(), 25.into()); 200 | loss_list.insert(1.into(), 22.into()); 201 | assert_eq!(loss_list.sequences.len(), 1); 202 | let items: Vec<_> = loss_list.sequences.into_iter().collect(); 203 | assert_eq!(items, [(1.into(), (1.into(), 
25.into())),]); 204 | } 205 | 206 | #[test] 207 | fn test_insert_with_bigger_existing_sequence() { 208 | let mut loss_list = crate::loss_list::LossList::new(); 209 | loss_list.insert(10.into(), 30.into()); 210 | loss_list.insert(10.into(), 20.into()); 211 | assert_eq!(loss_list.sequences.len(), 1); 212 | let items: Vec<_> = loss_list.sequences.into_iter().collect(); 213 | assert_eq!(items, [(10.into(), (10.into(), 30.into())),]); 214 | } 215 | 216 | #[test] 217 | fn test_remove_seq_inside_sequence() { 218 | let mut loss_list = crate::loss_list::LossList::new(); 219 | loss_list.insert(1.into(), 10.into()); 220 | loss_list.remove(5.into()); 221 | 222 | assert_eq!(loss_list.sequences.len(), 2); 223 | let items: Vec<_> = loss_list.sequences.into_iter().collect(); 224 | assert_eq!( 225 | items, 226 | [ 227 | (1.into(), (1.into(), 4.into())), 228 | (6.into(), (6.into(), 10.into())), 229 | ] 230 | ); 231 | } 232 | 233 | #[test] 234 | fn test_remove_first_item() { 235 | let mut loss_list = crate::loss_list::LossList::new(); 236 | loss_list.insert(1.into(), 10.into()); 237 | loss_list.remove(1.into()); 238 | 239 | assert_eq!(loss_list.sequences.len(), 1); 240 | let items: Vec<_> = loss_list.sequences.into_iter().collect(); 241 | assert_eq!(items, [(2.into(), (2.into(), 10.into())),]); 242 | } 243 | -------------------------------------------------------------------------------- /src/queue/rcv_queue.rs: -------------------------------------------------------------------------------- 1 | use crate::multiplexer::UdtMultiplexer; 2 | use crate::packet::UdtPacket; 3 | use crate::socket::{SocketId, UdtSocket}; 4 | use crate::udt::{SocketRef, Udt, UDT_DEBUG}; 5 | use nix::sys::socket::{SockaddrIn, SockaddrIn6}; 6 | use std::collections::{BTreeMap, VecDeque}; 7 | use std::net::SocketAddr; 8 | use std::sync::{Arc, Mutex, Weak}; 9 | use tokio::io::{Error, ErrorKind, Result}; 10 | use tokio::net::UdpSocket; 11 | use tokio::time::{Duration, Instant}; 12 | 13 | #[cfg(not(target_os = 
"linux"))] 14 | use tokio::time::sleep; 15 | #[cfg(target_os = "linux")] 16 | use tokio_timerfd::sleep; 17 | 18 | const TIMERS_CHECK_INTERVAL: Duration = Duration::from_millis(100); 19 | const UDP_RCV_TIMEOUT: Duration = Duration::from_micros(30); 20 | 21 | #[derive(Debug)] 22 | pub(crate) struct UdtRcvQueue { 23 | sockets: Mutex>, 24 | mss: u32, 25 | channel: Arc, 26 | multiplexer: Mutex>, 27 | socket_refs: Mutex>>, 28 | } 29 | 30 | impl UdtRcvQueue { 31 | pub fn new(channel: Arc, mss: u32) -> Self { 32 | Self { 33 | sockets: Mutex::new(VecDeque::new()), 34 | mss, 35 | channel, 36 | multiplexer: Mutex::new(Weak::new()), 37 | socket_refs: Mutex::new(BTreeMap::new()), 38 | } 39 | } 40 | 41 | pub fn push_back(&self, socket_id: SocketId) { 42 | self.sockets 43 | .lock() 44 | .unwrap() 45 | .push_back((Instant::now(), socket_id)); 46 | } 47 | 48 | fn update(&self, socket_id: SocketId) { 49 | let mut queue = self.sockets.lock().unwrap(); 50 | queue.retain(|(_, id)| socket_id != *id); 51 | queue.push_back((Instant::now(), socket_id)); 52 | } 53 | 54 | pub fn set_multiplexer(&self, mux: &Arc) { 55 | *self.multiplexer.lock().unwrap() = Arc::downgrade(mux); 56 | } 57 | 58 | async fn get_socket(&self, socket_id: SocketId) -> Option { 59 | let known_socket = self.socket_refs.lock().unwrap().get(&socket_id).cloned(); 60 | if let Some(socket) = known_socket { 61 | socket.upgrade() 62 | } else if let Some(socket) = Udt::get().read().await.get_socket(socket_id) { 63 | self.socket_refs 64 | .lock() 65 | .unwrap() 66 | .insert(socket_id, Arc::downgrade(&socket)); 67 | Some(socket) 68 | } else { 69 | None 70 | } 71 | } 72 | 73 | #[cfg(target_os = "linux")] 74 | fn receive_packets(&self, buf: &mut [u8]) -> Result> { 75 | use nix::sys::socket::{ 76 | recvmmsg, AddressFamily, MsgFlags, RecvMmsgData, SockaddrLike, SockaddrStorage, 77 | }; 78 | use std::io::IoSliceMut; 79 | use std::os::unix::io::AsRawFd; 80 | use tokio::io::Interest; 81 | let bufs = buf.chunks_exact_mut(self.mss as 
usize); 82 | let mut recv_mesg_data: Vec> = bufs 83 | .map(|b| RecvMmsgData { 84 | iov: [IoSliceMut::new(&mut b[..])], 85 | cmsg_buffer: None, 86 | }) 87 | .collect(); 88 | 89 | self.channel.try_io(Interest::READABLE, || { 90 | let msgs = recvmmsg( 91 | self.channel.as_raw_fd(), 92 | &mut recv_mesg_data, 93 | MsgFlags::MSG_DONTWAIT, 94 | None, 95 | ) 96 | .map_err(|err| { 97 | if err == nix::errno::Errno::EWOULDBLOCK { 98 | return Error::new(ErrorKind::WouldBlock, "recvmmsg would block"); 99 | } 100 | Error::new(ErrorKind::Other, err) 101 | })? 102 | .iter() 103 | .map(|msg| { 104 | let addr: SockaddrStorage = msg.address.unwrap(); 105 | let socket_addr: SocketAddr = match addr.family() { 106 | Some(AddressFamily::Inet) => { 107 | Self::addr_v4_from_sockaddrin(*addr.as_sockaddr_in().unwrap()).into() 108 | } 109 | Some(AddressFamily::Inet6) => { 110 | Self::addr_v6_from_sockaddrin6(*addr.as_sockaddr_in6().unwrap()).into() 111 | } 112 | _ => unreachable!(), 113 | }; 114 | (msg.bytes, socket_addr) 115 | }) 116 | .collect(); 117 | Ok(msgs) 118 | }) 119 | } 120 | 121 | #[cfg(not(target_os = "linux"))] 122 | fn receive_packets(&self, buf: &mut [u8]) -> Result> { 123 | let bufs = buf.chunks_exact_mut(self.mss as usize); 124 | let mut msgs = vec![]; 125 | for mut buf in bufs { 126 | match self.channel.try_recv_from(&mut buf) { 127 | Ok(msg) => { 128 | msgs.push(msg); 129 | } 130 | Err(e) if e.kind() == ErrorKind::WouldBlock => break, 131 | Err(e) => return Err(e), 132 | } 133 | } 134 | Ok(msgs) 135 | } 136 | 137 | pub(crate) async fn worker(&self) -> Result<()> { 138 | let mut buf = vec![0_u8; self.mss as usize * 100]; 139 | loop { 140 | let packets = { 141 | let msgs = self.receive_packets(&mut buf).unwrap_or_default(); 142 | if !msgs.is_empty() { 143 | let packets: Vec<_> = msgs 144 | .into_iter() 145 | .zip(buf.chunks_exact_mut(self.mss as usize)) 146 | .filter_map(|((nbytes, addr), buf)| { 147 | let packet = UdtPacket::deserialize(&buf[..nbytes]).ok()?; 148 | 149 | 
Some((packet, addr)) 150 | }) 151 | .collect(); 152 | Some(packets) 153 | } else { 154 | tokio::select! { 155 | _ = sleep(UDP_RCV_TIMEOUT) => (), 156 | _ = self.channel.readable() => () 157 | }; 158 | None 159 | } 160 | }; 161 | 162 | for (packet, addr) in packets.into_iter().flatten() { 163 | let socket_id = packet.get_dest_socket_id(); 164 | if socket_id == 0 { 165 | if let Some(handshake) = packet.handshake() { 166 | let mux = { 167 | let lock = self.multiplexer.lock().unwrap(); 168 | lock.upgrade() 169 | }; 170 | if let Some(mux) = mux { 171 | let listener = mux.listener.read().await; 172 | if let Some(listener) = &*listener { 173 | listener.listen_on_handshake(addr, handshake).await?; 174 | } 175 | } 176 | } else { 177 | return Err(Error::new( 178 | ErrorKind::InvalidData, 179 | "received non-hanshake packet with socket 0", 180 | )); 181 | } 182 | } else if let Some(socket) = self.get_socket(socket_id).await { 183 | if socket.peer_addr() == Some(addr) && socket.status().is_alive() { 184 | socket.process_packet(packet).await?; 185 | socket.check_timers().await; 186 | self.update(socket_id); 187 | } else if *UDT_DEBUG { 188 | eprintln!("Ignoring packet {:?}", packet); 189 | } 190 | } else { 191 | // TODO: implement rendezvous queue for rendezvous mode 192 | 193 | if *UDT_DEBUG { 194 | eprintln!("socket not found for socket_id {}", socket_id); 195 | dbg!(packet); 196 | } 197 | } 198 | } 199 | 200 | let to_check = { 201 | let mut to_check = vec![]; 202 | let mut sockets = self.sockets.lock().unwrap(); 203 | while sockets 204 | .front() 205 | .map(|(ts, _)| ts.elapsed() > TIMERS_CHECK_INTERVAL) 206 | .unwrap_or(false) 207 | { 208 | to_check.push(sockets.pop_front().unwrap().1); 209 | } 210 | to_check 211 | }; 212 | 213 | for socket_id in to_check { 214 | if let Some(socket) = self.get_socket(socket_id).await { 215 | if socket.status().is_alive() { 216 | socket.check_timers().await; 217 | self.update(socket_id); 218 | } 219 | } 220 | } 221 | } 222 | } 223 | 224 | // 
TEMP: waiting for "nix" next release (> 0.24.2) to include these conversions 225 | fn addr_v4_from_sockaddrin(addr: SockaddrIn) -> std::net::SocketAddrV4 { 226 | std::net::SocketAddrV4::new(std::net::Ipv4Addr::from(addr.ip()), addr.port()) 227 | } 228 | 229 | fn addr_v6_from_sockaddrin6(addr: SockaddrIn6) -> std::net::SocketAddrV6 { 230 | std::net::SocketAddrV6::new( 231 | addr.ip(), 232 | addr.port(), 233 | u32::from_be(addr.flowinfo()), 234 | u32::from_be(addr.scope_id()), 235 | ) 236 | } 237 | } 238 | -------------------------------------------------------------------------------- /src/udt.rs: -------------------------------------------------------------------------------- 1 | use super::configuration::UdtConfiguration; 2 | use crate::control_packet::{HandShakeInfo, UdtControlPacket}; 3 | use crate::multiplexer::{MultiplexerId, UdtMultiplexer}; 4 | use crate::seq_number::SeqNumber; 5 | use crate::socket::{SocketId, SocketType, UdtSocket, UdtStatus}; 6 | use once_cell::sync::{Lazy, OnceCell}; 7 | use std::collections::btree_map::Entry; 8 | use std::collections::{BTreeMap, BTreeSet}; 9 | use std::io::{Error, ErrorKind, Result}; 10 | use std::net::SocketAddr; 11 | use std::sync::Arc; 12 | use tokio::sync::RwLock; 13 | use tokio::time::sleep; 14 | 15 | pub(crate) type SocketRef = Arc; 16 | 17 | static UDT_INSTANCE: OnceCell> = OnceCell::new(); 18 | pub(crate) static UDT_DEBUG: Lazy = 19 | Lazy::new(|| std::env::var("UDT_DEBUG").unwrap_or_default() != ""); 20 | 21 | #[derive(Default, Debug)] 22 | pub(crate) struct Udt { 23 | sockets: BTreeMap, 24 | multiplexers: BTreeMap>, 25 | next_socket_id: SocketId, 26 | peers: BTreeMap<(SocketId, SeqNumber), BTreeSet>, // peer socket id -> local socket id 27 | } 28 | 29 | impl Udt { 30 | fn new() -> Self { 31 | Self { 32 | next_socket_id: rand::random(), 33 | ..Default::default() 34 | } 35 | } 36 | 37 | pub fn get() -> &'static RwLock { 38 | UDT_INSTANCE.get_or_init(|| { 39 | Udt::cleanup_worker(); 40 | RwLock::new(Udt::new()) 
41 | }) 42 | } 43 | 44 | fn get_new_socket_id(&mut self) -> SocketId { 45 | let socket_id = self.next_socket_id; 46 | self.next_socket_id = self.next_socket_id.wrapping_sub(1); 47 | socket_id 48 | } 49 | 50 | pub(crate) fn get_socket(&self, socket_id: SocketId) -> Option { 51 | if let Some(socket) = self.sockets.get(&socket_id) { 52 | if socket.status() != UdtStatus::Closed { 53 | return Some(socket.clone()); 54 | } 55 | } 56 | None 57 | } 58 | 59 | pub(crate) async fn get_peer_socket( 60 | &self, 61 | peer: SocketAddr, 62 | socket_id: SocketId, 63 | initial_seq_number: SeqNumber, 64 | ) -> Option { 65 | for socket in self 66 | .peers 67 | .get(&(socket_id, initial_seq_number))? 68 | .iter() 69 | .filter_map(|id| self.sockets.get(id)) 70 | { 71 | if socket.peer_addr() == Some(peer) { 72 | return self.sockets.get(&socket.socket_id).cloned(); 73 | } 74 | } 75 | None 76 | } 77 | 78 | pub fn new_socket( 79 | &mut self, 80 | socket_type: SocketType, 81 | config: Option, 82 | ) -> Result<&SocketRef> { 83 | let socket = UdtSocket::new(self.get_new_socket_id(), socket_type, None, config); 84 | let socket_id = socket.socket_id; 85 | if let Entry::Vacant(e) = self.sockets.entry(socket_id) { 86 | return Ok(e.insert(Arc::new(socket))); 87 | } 88 | Err(Error::new( 89 | ErrorKind::AlreadyExists, 90 | "socket_id already exists", 91 | )) 92 | } 93 | 94 | pub(crate) async fn new_connection( 95 | &mut self, 96 | listener_socket: &UdtSocket, 97 | peer: SocketAddr, 98 | hs: &HandShakeInfo, 99 | ) -> Result<()> { 100 | if let Some(existing_peer_socket) = self 101 | .get_peer_socket(peer, hs.socket_id, hs.initial_seq_number) 102 | .await 103 | { 104 | let socket = existing_peer_socket; 105 | if socket.status() == UdtStatus::Broken { 106 | // last connection from the "peer" address has been broken 107 | if *UDT_DEBUG { 108 | eprintln!("Existing connection to peer {} is broken", peer); 109 | } 110 | listener_socket 111 | .queued_sockets 112 | .write() 113 | .await 114 | 
.remove(&socket.socket_id); 115 | } else { 116 | // Respond with existing socket configuration. 117 | let source_socket_id = hs.socket_id; 118 | let hs = { 119 | let mut hs = hs.clone(); 120 | let configuration = socket.configuration.read().unwrap(); 121 | hs.initial_seq_number = socket.initial_seq_number; 122 | hs.max_packet_size = configuration.mss; 123 | hs.max_window_size = configuration.flight_flag_size; 124 | hs.connection_type = -1; 125 | hs.socket_id = socket.socket_id; 126 | hs 127 | }; 128 | let packet = UdtControlPacket::new_handshake(hs, source_socket_id); 129 | socket.send_to(&peer, packet.into()).await?; 130 | return Ok(()); 131 | } 132 | } 133 | 134 | let new_socket_id = self.get_new_socket_id(); 135 | 136 | let new_socket = { 137 | let multiplexer = listener_socket 138 | .multiplexer 139 | .read() 140 | .unwrap() 141 | .upgrade() 142 | .ok_or_else(|| Error::new(ErrorKind::Other, "Listener has no multiplexer"))?; 143 | 144 | let config = listener_socket.configuration.read().unwrap().clone(); 145 | if listener_socket.queued_sockets.read().await.len() >= config.accept_queue_size { 146 | return Err(Error::new(ErrorKind::Other, "Too many queued sockets")); 147 | } 148 | 149 | let new_socket = UdtSocket::new( 150 | new_socket_id, 151 | hs.socket_type, 152 | Some(hs.initial_seq_number), 153 | Some(config), 154 | ) 155 | .with_peer(peer, hs.socket_id) 156 | .with_listen_socket(listener_socket.socket_id, multiplexer); 157 | new_socket.open(); 158 | new_socket 159 | }; 160 | 161 | let ns_id = new_socket.socket_id; 162 | let ns_isn = new_socket.initial_seq_number; 163 | let ns_peer_socket_id = hs.socket_id; 164 | let new_socket_ref = new_socket.connect_on_handshake(peer, hs.clone()).await?; 165 | 166 | self.peers 167 | .entry((ns_peer_socket_id, ns_isn)) 168 | .or_default() 169 | .insert(new_socket_ref.socket_id); 170 | self.sockets.insert(ns_id, new_socket_ref); 171 | 172 | listener_socket.queued_sockets.write().await.insert(ns_id); 173 | 
listener_socket.accept_notify.notify_one(); 174 | Ok(()) 175 | } 176 | 177 | pub async fn bind(&mut self, socket_id: SocketId, addr: SocketAddr) -> Result<()> { 178 | let socket = self 179 | .get_socket(socket_id) 180 | .ok_or_else(|| Error::new(ErrorKind::Other, "unknown socket id"))?; 181 | 182 | if socket.status() != UdtStatus::Init { 183 | return Err(Error::new(ErrorKind::Other, "socket already binded")); 184 | } 185 | 186 | self.update_mux(&socket, Some(addr)).await?; 187 | socket.open(); 188 | Ok(()) 189 | } 190 | 191 | pub(crate) async fn update_mux( 192 | &mut self, 193 | socket: &UdtSocket, 194 | bind_addr: Option, 195 | ) -> Result<()> { 196 | if socket.configuration.read().unwrap().reuse_mux { 197 | if let Some(bind_addr) = bind_addr { 198 | let port = bind_addr.port(); 199 | if port > 0 { 200 | for mux in self.multiplexers.values() { 201 | let socket_mss = socket.configuration.read().unwrap().mss; 202 | if mux.reusable && mux.port == port && mux.mss == socket_mss { 203 | socket.set_multiplexer(mux); 204 | return Ok(()); 205 | } 206 | } 207 | } 208 | } 209 | } 210 | 211 | // A new multiplexer is needed 212 | let mux = { 213 | let configuration = socket.configuration.read().unwrap().clone(); 214 | let (mux_id, mux) = if let Some(bind_addr) = bind_addr { 215 | UdtMultiplexer::bind(socket.socket_id, bind_addr, &configuration).await? 216 | } else { 217 | UdtMultiplexer::new(socket.socket_id, &configuration).await? 
218 | }; 219 | self.multiplexers.insert(mux_id, mux.clone()); 220 | mux 221 | }; 222 | socket.set_multiplexer(&mux); 223 | UdtMultiplexer::run(mux); 224 | Ok(()) 225 | } 226 | 227 | async fn garbage_collect_sockets(&mut self) { 228 | for (_, sock) in self 229 | .sockets 230 | .iter() 231 | .filter(|(_, s)| s.status() == UdtStatus::Broken) 232 | { 233 | if let Some(listen_socket_id) = sock.listen_socket { 234 | if let Some(listener) = self.sockets.get(&listen_socket_id) { 235 | listener 236 | .queued_sockets 237 | .write() 238 | .await 239 | .remove(&sock.socket_id); 240 | } 241 | } 242 | tokio::spawn({ 243 | let sock = sock.clone(); 244 | async move { sock.close().await } 245 | }); 246 | } 247 | 248 | let to_remove: Vec<_> = self 249 | .sockets 250 | .iter() 251 | .filter(|(_, s)| s.status() == UdtStatus::Closing) 252 | .map(|(socket_id, _)| *socket_id) 253 | .collect(); 254 | for socket_id in to_remove { 255 | if let Some(sock) = self.sockets.remove(&socket_id) { 256 | *sock.status.lock().unwrap() = UdtStatus::Closed; 257 | } 258 | } 259 | } 260 | 261 | fn cleanup_worker() { 262 | tokio::spawn(async { 263 | let udt = Self::get(); 264 | loop { 265 | udt.write().await.garbage_collect_sockets().await; 266 | sleep(std::time::Duration::from_secs(1)).await; 267 | } 268 | }); 269 | } 270 | } 271 | -------------------------------------------------------------------------------- /src/control_packet.rs: -------------------------------------------------------------------------------- 1 | use super::socket::{SocketId, SocketType}; 2 | use crate::common::ip_to_bytes; 3 | use crate::seq_number::{AckSeqNumber, MsgNumber, SeqNumber}; 4 | use std::net::IpAddr; 5 | use tokio::io::{Error, ErrorKind, Result}; 6 | 7 | #[derive(Debug)] 8 | pub(crate) struct UdtControlPacket { 9 | // bit 0 = 1 10 | pub packet_type: ControlPacketType, // bits 1-15 + Control Information Field (bits 128+) 11 | pub reserved: u16, // bits 16-31 12 | pub additional_info: u32, // bits 32-63 13 | pub timestamp: 
u32, // bits 64-95 14 | pub dest_socket_id: SocketId, // bits 96-127 15 | } 16 | 17 | impl UdtControlPacket { 18 | pub fn new_handshake(hs: HandShakeInfo, dest_socket_id: SocketId) -> Self { 19 | Self { 20 | packet_type: ControlPacketType::Handshake(hs), 21 | reserved: 0, 22 | additional_info: 0, 23 | timestamp: 0, 24 | dest_socket_id, 25 | } 26 | } 27 | 28 | pub fn new_nak(loss_list: Vec, dest_socket_id: SocketId) -> Self { 29 | Self { 30 | packet_type: ControlPacketType::Nak(NakInfo { 31 | loss_info: loss_list, 32 | }), 33 | reserved: 0, 34 | additional_info: 0, 35 | timestamp: 0, 36 | dest_socket_id, 37 | } 38 | } 39 | 40 | pub fn new_ack2(seq: AckSeqNumber, dest_socket_id: SocketId) -> Self { 41 | Self { 42 | packet_type: ControlPacketType::Ack2, 43 | additional_info: seq.number(), 44 | dest_socket_id, 45 | reserved: 0, 46 | timestamp: 0, 47 | } 48 | } 49 | 50 | pub fn new_drop( 51 | msg_id: MsgNumber, 52 | first: SeqNumber, 53 | last: SeqNumber, 54 | dest_socket_id: SocketId, 55 | ) -> Self { 56 | Self { 57 | packet_type: ControlPacketType::MsgDropRequest(DropRequestInfo { 58 | first_seq_number: first, 59 | last_seq_number: last, 60 | }), 61 | additional_info: msg_id.number(), 62 | dest_socket_id, 63 | reserved: 0, 64 | timestamp: 0, 65 | } 66 | } 67 | 68 | pub fn new_keep_alive(dest_socket_id: SocketId) -> Self { 69 | Self { 70 | packet_type: ControlPacketType::KeepAlive, 71 | dest_socket_id, 72 | additional_info: 0, 73 | reserved: 0, 74 | timestamp: 0, 75 | } 76 | } 77 | 78 | pub fn new_shutdown(dest_socket_id: SocketId) -> Self { 79 | Self { 80 | packet_type: ControlPacketType::Shutdown, 81 | dest_socket_id, 82 | additional_info: 0, 83 | reserved: 0, 84 | timestamp: 0, 85 | } 86 | } 87 | 88 | pub fn new_ack( 89 | ack_number: AckSeqNumber, 90 | next_seq_number: SeqNumber, 91 | dest_socket_id: SocketId, 92 | info: Option, 93 | ) -> Self { 94 | Self { 95 | packet_type: ControlPacketType::Ack(AckInfo { 96 | next_seq_number, 97 | info, 98 | }), 99 | 
dest_socket_id, 100 | additional_info: ack_number.number(), 101 | reserved: 0, 102 | timestamp: 0, 103 | } 104 | } 105 | 106 | pub fn ack_seq_number(&self) -> Option { 107 | match self.packet_type { 108 | ControlPacketType::Ack(_) => Some(self.additional_info.into()), 109 | ControlPacketType::Ack2 => Some(self.additional_info.into()), 110 | _ => None, 111 | } 112 | } 113 | 114 | pub fn msg_seq_number(&self) -> Option { 115 | match self.packet_type { 116 | ControlPacketType::MsgDropRequest(_) => { 117 | Some((self.additional_info & MsgNumber::MAX_NUMBER).into()) 118 | } 119 | _ => None, 120 | } 121 | } 122 | 123 | pub fn serialize(&self) -> Vec { 124 | let mut buffer: Vec = Vec::with_capacity(8); 125 | buffer.extend_from_slice(&(0x8000 + self.packet_type.type_as_u15()).to_be_bytes()); 126 | buffer.extend_from_slice(&self.reserved.to_be_bytes()); 127 | buffer.extend_from_slice(&self.additional_info.to_be_bytes()); 128 | buffer.extend_from_slice(&self.timestamp.to_be_bytes()); 129 | buffer.extend_from_slice(&self.dest_socket_id.to_be_bytes()); 130 | buffer.extend_from_slice(&self.packet_type.control_info_field()); 131 | buffer 132 | } 133 | 134 | pub fn deserialize(raw: &[u8]) -> Result { 135 | if raw.len() < 16 { 136 | return Err(Error::new( 137 | ErrorKind::InvalidData, 138 | "control packet header is too short", 139 | )); 140 | } 141 | let reserved = u16::from_be_bytes(raw[2..4].try_into().unwrap()); 142 | let additional_info = u32::from_be_bytes(raw[4..8].try_into().unwrap()); 143 | let timestamp = u32::from_be_bytes(raw[8..12].try_into().unwrap()); 144 | let dest_socket_id = u32::from_be_bytes(raw[12..16].try_into().unwrap()); 145 | 146 | let packet_type = ControlPacketType::deserialize(raw)?; 147 | Ok(Self { 148 | reserved, 149 | additional_info, 150 | timestamp, 151 | dest_socket_id, 152 | packet_type, 153 | }) 154 | } 155 | } 156 | 157 | #[derive(Debug)] 158 | pub(crate) enum ControlPacketType { 159 | Handshake(HandShakeInfo), 160 | KeepAlive, 161 | 
Ack(AckInfo), 162 | Nak(NakInfo), 163 | Shutdown, 164 | Ack2, 165 | MsgDropRequest(DropRequestInfo), 166 | UserDefined, 167 | } 168 | 169 | impl ControlPacketType { 170 | pub fn type_as_u15(&self) -> u16 { 171 | match self { 172 | Self::Handshake(_) => 0x0000, 173 | Self::KeepAlive => 0x0001, 174 | Self::Ack(_) => 0x0002, 175 | Self::Nak(_) => 0x0003, 176 | Self::Shutdown => 0x0005, 177 | Self::Ack2 => 0x0006, 178 | Self::MsgDropRequest(_) => 0x0007, 179 | Self::UserDefined => 0x7fff, 180 | } 181 | } 182 | 183 | pub fn control_info_field(&self) -> Vec { 184 | match self { 185 | Self::Handshake(hs) => hs.serialize(), 186 | Self::Ack(ack) => ack.serialize(), 187 | Self::Nak(nak) => nak.serialize(), 188 | Self::MsgDropRequest(drop) => drop.serialize(), 189 | _ => vec![], 190 | } 191 | } 192 | 193 | pub fn deserialize(raw_control_packet: &[u8]) -> Result { 194 | let type_id = u16::from_be_bytes(raw_control_packet[0..2].try_into().unwrap()) & 0x7FFF; 195 | let packet = match type_id { 196 | 0x0000 => Self::Handshake(HandShakeInfo::deserialize(&raw_control_packet[16..])?), 197 | 0x0001 => Self::KeepAlive, 198 | 0x0002 => Self::Ack(AckInfo::deserialize(&raw_control_packet[16..])?), 199 | 0x0003 => Self::Nak(NakInfo::deserialize(&raw_control_packet[16..])?), 200 | 0x0005 => Self::Shutdown, 201 | 0x0006 => Self::Ack2, 202 | 0x0007 => { 203 | Self::MsgDropRequest(DropRequestInfo::deserialize(&raw_control_packet[16..])?) 
204 | } 205 | 0x7fff => Self::UserDefined, 206 | _ => { 207 | return Err(Error::new( 208 | ErrorKind::InvalidData, 209 | "unknown control packet type", 210 | )); 211 | } 212 | }; 213 | Ok(packet) 214 | } 215 | } 216 | 217 | #[derive(Debug, Clone)] 218 | pub(crate) struct HandShakeInfo { 219 | pub udt_version: u32, 220 | pub socket_type: SocketType, 221 | pub initial_seq_number: SeqNumber, 222 | pub max_packet_size: u32, 223 | pub max_window_size: u32, 224 | pub connection_type: i32, // regular or rendezvous 225 | pub socket_id: SocketId, 226 | pub syn_cookie: u32, 227 | pub ip_address: IpAddr, 228 | } 229 | 230 | impl HandShakeInfo { 231 | pub fn serialize(&self) -> Vec { 232 | [ 233 | self.udt_version, 234 | self.socket_type as u32, 235 | self.initial_seq_number.number(), 236 | self.max_packet_size, 237 | self.max_window_size, 238 | ] 239 | .iter() 240 | .flat_map(|v| v.to_be_bytes()) 241 | .chain(self.connection_type.to_be_bytes().into_iter()) 242 | .chain(self.socket_id.to_be_bytes().into_iter()) 243 | .chain(self.syn_cookie.to_be_bytes().into_iter()) 244 | .chain(ip_to_bytes(self.ip_address)) 245 | .collect() 246 | } 247 | 248 | pub fn deserialize(raw: &[u8]) -> Result { 249 | let get_u32 = 250 | |idx: usize| u32::from_be_bytes(raw[(idx * 4)..(idx + 1) * 4].try_into().unwrap()); 251 | let addr: IpAddr = { 252 | if raw[36..48].iter().all(|b| *b == 0) { 253 | // IPv4 254 | let octets: [u8; 4] = raw[32..36].try_into().unwrap(); 255 | octets.into() 256 | } else { 257 | // IPv6 258 | let octets: [u8; 16] = raw[32..48].try_into().unwrap(); 259 | octets.into() 260 | } 261 | }; 262 | 263 | Ok(Self { 264 | udt_version: get_u32(0), 265 | socket_type: get_u32(1).try_into()?, 266 | initial_seq_number: get_u32(2).into(), 267 | max_packet_size: get_u32(3), 268 | max_window_size: get_u32(4), 269 | connection_type: i32::from_be_bytes(raw[20..24].try_into().unwrap()), 270 | socket_id: get_u32(6), 271 | syn_cookie: get_u32(7), 272 | ip_address: addr, 273 | }) 274 | } 275 | } 276 
| 277 | #[derive(Debug)] 278 | pub(crate) struct AckInfo { 279 | /// The packet sequence number to which all the 280 | /// previous packets have been received (excluding) 281 | pub next_seq_number: SeqNumber, 282 | pub info: Option, 283 | } 284 | 285 | impl AckInfo { 286 | pub fn deserialize(raw: &[u8]) -> Result { 287 | let get_u32 = 288 | |idx: usize| u32::from_be_bytes(raw[(idx * 4)..(idx + 1) * 4].try_into().unwrap()); 289 | 290 | let next_seq_number: SeqNumber = get_u32(0).into(); 291 | 292 | if raw.len() <= 4 { 293 | return Ok(Self { 294 | next_seq_number, 295 | info: None, 296 | }); 297 | } 298 | let info = AckOptionalInfo { 299 | rtt: get_u32(1), 300 | rtt_variance: get_u32(2), 301 | available_buf_size: get_u32(3), 302 | pack_recv_rate: get_u32(4), 303 | link_capacity: get_u32(5), 304 | }; 305 | Ok(Self { 306 | next_seq_number, 307 | info: Some(info), 308 | }) 309 | } 310 | 311 | pub fn serialize(&self) -> Vec { 312 | match &self.info { 313 | None => self.next_seq_number.number().to_be_bytes().to_vec(), 314 | Some(extra) => [ 315 | self.next_seq_number.number(), 316 | extra.rtt, 317 | extra.rtt_variance, 318 | extra.available_buf_size, 319 | extra.pack_recv_rate, 320 | extra.link_capacity, 321 | ] 322 | .iter() 323 | .flat_map(|v| v.to_be_bytes()) 324 | .collect(), 325 | } 326 | } 327 | } 328 | 329 | #[derive(Debug)] 330 | pub(crate) struct AckOptionalInfo { 331 | /// RTT in microseconds 332 | pub rtt: u32, 333 | pub rtt_variance: u32, 334 | pub available_buf_size: u32, 335 | pub pack_recv_rate: u32, 336 | pub link_capacity: u32, 337 | } 338 | 339 | #[derive(Debug)] 340 | pub(crate) struct NakInfo { 341 | pub loss_info: Vec, 342 | } 343 | 344 | impl NakInfo { 345 | pub fn deserialize(raw: &[u8]) -> Result { 346 | let losses: Vec = raw 347 | .chunks(4) 348 | .filter_map(|chunk| { 349 | if chunk.len() < 4 { 350 | return None; 351 | } 352 | Some(u32::from_be_bytes(chunk.try_into().unwrap())) 353 | }) 354 | .collect(); 355 | Ok(Self { loss_info: losses }) 356 | 
} 357 | 358 | pub fn serialize(&self) -> Vec { 359 | self.loss_info 360 | .iter() 361 | .flat_map(|x| x.to_be_bytes()) 362 | .collect() 363 | } 364 | } 365 | 366 | #[derive(Debug)] 367 | pub(crate) struct DropRequestInfo { 368 | pub first_seq_number: SeqNumber, 369 | pub last_seq_number: SeqNumber, 370 | } 371 | 372 | impl DropRequestInfo { 373 | pub fn deserialize(raw: &[u8]) -> Result { 374 | let get_u32 = 375 | |idx: usize| u32::from_be_bytes(raw[(idx * 4)..(idx + 1) * 4].try_into().unwrap()); 376 | 377 | Ok(Self { 378 | first_seq_number: get_u32(0).into(), 379 | last_seq_number: get_u32(1).into(), 380 | }) 381 | } 382 | 383 | pub fn serialize(&self) -> Vec { 384 | [ 385 | self.first_seq_number.number(), 386 | self.last_seq_number.number(), 387 | ] 388 | .iter() 389 | .flat_map(|x| x.to_be_bytes()) 390 | .collect() 391 | } 392 | } 393 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | GNU AFFERO GENERAL PUBLIC LICENSE 2 | Version 3, 19 November 2007 3 | 4 | Copyright (C) 2007 Free Software Foundation, Inc. 5 | Everyone is permitted to copy and distribute verbatim copies 6 | of this license document, but changing it is not allowed. 7 | 8 | Preamble 9 | 10 | The GNU Affero General Public License is a free, copyleft license for 11 | software and other kinds of works, specifically designed to ensure 12 | cooperation with the community in the case of network server software. 13 | 14 | The licenses for most software and other practical works are designed 15 | to take away your freedom to share and change the works. By contrast, 16 | our General Public Licenses are intended to guarantee your freedom to 17 | share and change all versions of a program--to make sure it remains free 18 | software for all its users. 19 | 20 | When we speak of free software, we are referring to freedom, not 21 | price. 
Our General Public Licenses are designed to make sure that you 22 | have the freedom to distribute copies of free software (and charge for 23 | them if you wish), that you receive source code or can get it if you 24 | want it, that you can change the software or use pieces of it in new 25 | free programs, and that you know you can do these things. 26 | 27 | Developers that use our General Public Licenses protect your rights 28 | with two steps: (1) assert copyright on the software, and (2) offer 29 | you this License which gives you legal permission to copy, distribute 30 | and/or modify the software. 31 | 32 | A secondary benefit of defending all users' freedom is that 33 | improvements made in alternate versions of the program, if they 34 | receive widespread use, become available for other developers to 35 | incorporate. Many developers of free software are heartened and 36 | encouraged by the resulting cooperation. However, in the case of 37 | software used on network servers, this result may fail to come about. 38 | The GNU General Public License permits making a modified version and 39 | letting the public access it on a server without ever releasing its 40 | source code to the public. 41 | 42 | The GNU Affero General Public License is designed specifically to 43 | ensure that, in such cases, the modified source code becomes available 44 | to the community. It requires the operator of a network server to 45 | provide the source code of the modified version running there to the 46 | users of that server. Therefore, public use of a modified version, on 47 | a publicly accessible server, gives the public access to the source 48 | code of the modified version. 49 | 50 | An older license, called the Affero General Public License and 51 | published by Affero, was designed to accomplish similar goals. 
This is 52 | a different license, not a version of the Affero GPL, but Affero has 53 | released a new version of the Affero GPL which permits relicensing under 54 | this license. 55 | 56 | The precise terms and conditions for copying, distribution and 57 | modification follow. 58 | 59 | TERMS AND CONDITIONS 60 | 61 | 0. Definitions. 62 | 63 | "This License" refers to version 3 of the GNU Affero General Public License. 64 | 65 | "Copyright" also means copyright-like laws that apply to other kinds of 66 | works, such as semiconductor masks. 67 | 68 | "The Program" refers to any copyrightable work licensed under this 69 | License. Each licensee is addressed as "you". "Licensees" and 70 | "recipients" may be individuals or organizations. 71 | 72 | To "modify" a work means to copy from or adapt all or part of the work 73 | in a fashion requiring copyright permission, other than the making of an 74 | exact copy. The resulting work is called a "modified version" of the 75 | earlier work or a work "based on" the earlier work. 76 | 77 | A "covered work" means either the unmodified Program or a work based 78 | on the Program. 79 | 80 | To "propagate" a work means to do anything with it that, without 81 | permission, would make you directly or secondarily liable for 82 | infringement under applicable copyright law, except executing it on a 83 | computer or modifying a private copy. Propagation includes copying, 84 | distribution (with or without modification), making available to the 85 | public, and in some countries other activities as well. 86 | 87 | To "convey" a work means any kind of propagation that enables other 88 | parties to make or receive copies. Mere interaction with a user through 89 | a computer network, with no transfer of a copy, is not conveying. 
90 | 91 | An interactive user interface displays "Appropriate Legal Notices" 92 | to the extent that it includes a convenient and prominently visible 93 | feature that (1) displays an appropriate copyright notice, and (2) 94 | tells the user that there is no warranty for the work (except to the 95 | extent that warranties are provided), that licensees may convey the 96 | work under this License, and how to view a copy of this License. If 97 | the interface presents a list of user commands or options, such as a 98 | menu, a prominent item in the list meets this criterion. 99 | 100 | 1. Source Code. 101 | 102 | The "source code" for a work means the preferred form of the work 103 | for making modifications to it. "Object code" means any non-source 104 | form of a work. 105 | 106 | A "Standard Interface" means an interface that either is an official 107 | standard defined by a recognized standards body, or, in the case of 108 | interfaces specified for a particular programming language, one that 109 | is widely used among developers working in that language. 110 | 111 | The "System Libraries" of an executable work include anything, other 112 | than the work as a whole, that (a) is included in the normal form of 113 | packaging a Major Component, but which is not part of that Major 114 | Component, and (b) serves only to enable use of the work with that 115 | Major Component, or to implement a Standard Interface for which an 116 | implementation is available to the public in source code form. A 117 | "Major Component", in this context, means a major essential component 118 | (kernel, window system, and so on) of the specific operating system 119 | (if any) on which the executable work runs, or a compiler used to 120 | produce the work, or an object code interpreter used to run it. 
121 | 122 | The "Corresponding Source" for a work in object code form means all 123 | the source code needed to generate, install, and (for an executable 124 | work) run the object code and to modify the work, including scripts to 125 | control those activities. However, it does not include the work's 126 | System Libraries, or general-purpose tools or generally available free 127 | programs which are used unmodified in performing those activities but 128 | which are not part of the work. For example, Corresponding Source 129 | includes interface definition files associated with source files for 130 | the work, and the source code for shared libraries and dynamically 131 | linked subprograms that the work is specifically designed to require, 132 | such as by intimate data communication or control flow between those 133 | subprograms and other parts of the work. 134 | 135 | The Corresponding Source need not include anything that users 136 | can regenerate automatically from other parts of the Corresponding 137 | Source. 138 | 139 | The Corresponding Source for a work in source code form is that 140 | same work. 141 | 142 | 2. Basic Permissions. 143 | 144 | All rights granted under this License are granted for the term of 145 | copyright on the Program, and are irrevocable provided the stated 146 | conditions are met. This License explicitly affirms your unlimited 147 | permission to run the unmodified Program. The output from running a 148 | covered work is covered by this License only if the output, given its 149 | content, constitutes a covered work. This License acknowledges your 150 | rights of fair use or other equivalent, as provided by copyright law. 151 | 152 | You may make, run and propagate covered works that you do not 153 | convey, without conditions so long as your license otherwise remains 154 | in force. 
You may convey covered works to others for the sole purpose 155 | of having them make modifications exclusively for you, or provide you 156 | with facilities for running those works, provided that you comply with 157 | the terms of this License in conveying all material for which you do 158 | not control copyright. Those thus making or running the covered works 159 | for you must do so exclusively on your behalf, under your direction 160 | and control, on terms that prohibit them from making any copies of 161 | your copyrighted material outside their relationship with you. 162 | 163 | Conveying under any other circumstances is permitted solely under 164 | the conditions stated below. Sublicensing is not allowed; section 10 165 | makes it unnecessary. 166 | 167 | 3. Protecting Users' Legal Rights From Anti-Circumvention Law. 168 | 169 | No covered work shall be deemed part of an effective technological 170 | measure under any applicable law fulfilling obligations under article 171 | 11 of the WIPO copyright treaty adopted on 20 December 1996, or 172 | similar laws prohibiting or restricting circumvention of such 173 | measures. 174 | 175 | When you convey a covered work, you waive any legal power to forbid 176 | circumvention of technological measures to the extent such circumvention 177 | is effected by exercising rights under this License with respect to 178 | the covered work, and you disclaim any intention to limit operation or 179 | modification of the work as a means of enforcing, against the work's 180 | users, your or third parties' legal rights to forbid circumvention of 181 | technological measures. 182 | 183 | 4. Conveying Verbatim Copies. 
184 | 185 | You may convey verbatim copies of the Program's source code as you 186 | receive it, in any medium, provided that you conspicuously and 187 | appropriately publish on each copy an appropriate copyright notice; 188 | keep intact all notices stating that this License and any 189 | non-permissive terms added in accord with section 7 apply to the code; 190 | keep intact all notices of the absence of any warranty; and give all 191 | recipients a copy of this License along with the Program. 192 | 193 | You may charge any price or no price for each copy that you convey, 194 | and you may offer support or warranty protection for a fee. 195 | 196 | 5. Conveying Modified Source Versions. 197 | 198 | You may convey a work based on the Program, or the modifications to 199 | produce it from the Program, in the form of source code under the 200 | terms of section 4, provided that you also meet all of these conditions: 201 | 202 | a) The work must carry prominent notices stating that you modified 203 | it, and giving a relevant date. 204 | 205 | b) The work must carry prominent notices stating that it is 206 | released under this License and any conditions added under section 207 | 7. This requirement modifies the requirement in section 4 to 208 | "keep intact all notices". 209 | 210 | c) You must license the entire work, as a whole, under this 211 | License to anyone who comes into possession of a copy. This 212 | License will therefore apply, along with any applicable section 7 213 | additional terms, to the whole of the work, and all its parts, 214 | regardless of how they are packaged. This License gives no 215 | permission to license the work in any other way, but it does not 216 | invalidate such permission if you have separately received it. 
217 | 218 | d) If the work has interactive user interfaces, each must display 219 | Appropriate Legal Notices; however, if the Program has interactive 220 | interfaces that do not display Appropriate Legal Notices, your 221 | work need not make them do so. 222 | 223 | A compilation of a covered work with other separate and independent 224 | works, which are not by their nature extensions of the covered work, 225 | and which are not combined with it such as to form a larger program, 226 | in or on a volume of a storage or distribution medium, is called an 227 | "aggregate" if the compilation and its resulting copyright are not 228 | used to limit the access or legal rights of the compilation's users 229 | beyond what the individual works permit. Inclusion of a covered work 230 | in an aggregate does not cause this License to apply to the other 231 | parts of the aggregate. 232 | 233 | 6. Conveying Non-Source Forms. 234 | 235 | You may convey a covered work in object code form under the terms 236 | of sections 4 and 5, provided that you also convey the 237 | machine-readable Corresponding Source under the terms of this License, 238 | in one of these ways: 239 | 240 | a) Convey the object code in, or embodied in, a physical product 241 | (including a physical distribution medium), accompanied by the 242 | Corresponding Source fixed on a durable physical medium 243 | customarily used for software interchange. 
244 | 245 | b) Convey the object code in, or embodied in, a physical product 246 | (including a physical distribution medium), accompanied by a 247 | written offer, valid for at least three years and valid for as 248 | long as you offer spare parts or customer support for that product 249 | model, to give anyone who possesses the object code either (1) a 250 | copy of the Corresponding Source for all the software in the 251 | product that is covered by this License, on a durable physical 252 | medium customarily used for software interchange, for a price no 253 | more than your reasonable cost of physically performing this 254 | conveying of source, or (2) access to copy the 255 | Corresponding Source from a network server at no charge. 256 | 257 | c) Convey individual copies of the object code with a copy of the 258 | written offer to provide the Corresponding Source. This 259 | alternative is allowed only occasionally and noncommercially, and 260 | only if you received the object code with such an offer, in accord 261 | with subsection 6b. 262 | 263 | d) Convey the object code by offering access from a designated 264 | place (gratis or for a charge), and offer equivalent access to the 265 | Corresponding Source in the same way through the same place at no 266 | further charge. You need not require recipients to copy the 267 | Corresponding Source along with the object code. If the place to 268 | copy the object code is a network server, the Corresponding Source 269 | may be on a different server (operated by you or a third party) 270 | that supports equivalent copying facilities, provided you maintain 271 | clear directions next to the object code saying where to find the 272 | Corresponding Source. Regardless of what server hosts the 273 | Corresponding Source, you remain obligated to ensure that it is 274 | available for as long as needed to satisfy these requirements. 
275 | 276 | e) Convey the object code using peer-to-peer transmission, provided 277 | you inform other peers where the object code and Corresponding 278 | Source of the work are being offered to the general public at no 279 | charge under subsection 6d. 280 | 281 | A separable portion of the object code, whose source code is excluded 282 | from the Corresponding Source as a System Library, need not be 283 | included in conveying the object code work. 284 | 285 | A "User Product" is either (1) a "consumer product", which means any 286 | tangible personal property which is normally used for personal, family, 287 | or household purposes, or (2) anything designed or sold for incorporation 288 | into a dwelling. In determining whether a product is a consumer product, 289 | doubtful cases shall be resolved in favor of coverage. For a particular 290 | product received by a particular user, "normally used" refers to a 291 | typical or common use of that class of product, regardless of the status 292 | of the particular user or of the way in which the particular user 293 | actually uses, or expects or is expected to use, the product. A product 294 | is a consumer product regardless of whether the product has substantial 295 | commercial, industrial or non-consumer uses, unless such uses represent 296 | the only significant mode of use of the product. 297 | 298 | "Installation Information" for a User Product means any methods, 299 | procedures, authorization keys, or other information required to install 300 | and execute modified versions of a covered work in that User Product from 301 | a modified version of its Corresponding Source. The information must 302 | suffice to ensure that the continued functioning of the modified object 303 | code is in no case prevented or interfered with solely because 304 | modification has been made. 
305 | 306 | If you convey an object code work under this section in, or with, or 307 | specifically for use in, a User Product, and the conveying occurs as 308 | part of a transaction in which the right of possession and use of the 309 | User Product is transferred to the recipient in perpetuity or for a 310 | fixed term (regardless of how the transaction is characterized), the 311 | Corresponding Source conveyed under this section must be accompanied 312 | by the Installation Information. But this requirement does not apply 313 | if neither you nor any third party retains the ability to install 314 | modified object code on the User Product (for example, the work has 315 | been installed in ROM). 316 | 317 | The requirement to provide Installation Information does not include a 318 | requirement to continue to provide support service, warranty, or updates 319 | for a work that has been modified or installed by the recipient, or for 320 | the User Product in which it has been modified or installed. Access to a 321 | network may be denied when the modification itself materially and 322 | adversely affects the operation of the network or violates the rules and 323 | protocols for communication across the network. 324 | 325 | Corresponding Source conveyed, and Installation Information provided, 326 | in accord with this section must be in a format that is publicly 327 | documented (and with an implementation available to the public in 328 | source code form), and must require no special password or key for 329 | unpacking, reading or copying. 330 | 331 | 7. Additional Terms. 332 | 333 | "Additional permissions" are terms that supplement the terms of this 334 | License by making exceptions from one or more of its conditions. 335 | Additional permissions that are applicable to the entire Program shall 336 | be treated as though they were included in this License, to the extent 337 | that they are valid under applicable law. 
If additional permissions 338 | apply only to part of the Program, that part may be used separately 339 | under those permissions, but the entire Program remains governed by 340 | this License without regard to the additional permissions. 341 | 342 | When you convey a copy of a covered work, you may at your option 343 | remove any additional permissions from that copy, or from any part of 344 | it. (Additional permissions may be written to require their own 345 | removal in certain cases when you modify the work.) You may place 346 | additional permissions on material, added by you to a covered work, 347 | for which you have or can give appropriate copyright permission. 348 | 349 | Notwithstanding any other provision of this License, for material you 350 | add to a covered work, you may (if authorized by the copyright holders of 351 | that material) supplement the terms of this License with terms: 352 | 353 | a) Disclaiming warranty or limiting liability differently from the 354 | terms of sections 15 and 16 of this License; or 355 | 356 | b) Requiring preservation of specified reasonable legal notices or 357 | author attributions in that material or in the Appropriate Legal 358 | Notices displayed by works containing it; or 359 | 360 | c) Prohibiting misrepresentation of the origin of that material, or 361 | requiring that modified versions of such material be marked in 362 | reasonable ways as different from the original version; or 363 | 364 | d) Limiting the use for publicity purposes of names of licensors or 365 | authors of the material; or 366 | 367 | e) Declining to grant rights under trademark law for use of some 368 | trade names, trademarks, or service marks; or 369 | 370 | f) Requiring indemnification of licensors and authors of that 371 | material by anyone who conveys the material (or modified versions of 372 | it) with contractual assumptions of liability to the recipient, for 373 | any liability that these contractual assumptions directly impose on 
374 | those licensors and authors. 375 | 376 | All other non-permissive additional terms are considered "further 377 | restrictions" within the meaning of section 10. If the Program as you 378 | received it, or any part of it, contains a notice stating that it is 379 | governed by this License along with a term that is a further 380 | restriction, you may remove that term. If a license document contains 381 | a further restriction but permits relicensing or conveying under this 382 | License, you may add to a covered work material governed by the terms 383 | of that license document, provided that the further restriction does 384 | not survive such relicensing or conveying. 385 | 386 | If you add terms to a covered work in accord with this section, you 387 | must place, in the relevant source files, a statement of the 388 | additional terms that apply to those files, or a notice indicating 389 | where to find the applicable terms. 390 | 391 | Additional terms, permissive or non-permissive, may be stated in the 392 | form of a separately written license, or stated as exceptions; 393 | the above requirements apply either way. 394 | 395 | 8. Termination. 396 | 397 | You may not propagate or modify a covered work except as expressly 398 | provided under this License. Any attempt otherwise to propagate or 399 | modify it is void, and will automatically terminate your rights under 400 | this License (including any patent licenses granted under the third 401 | paragraph of section 11). 402 | 403 | However, if you cease all violation of this License, then your 404 | license from a particular copyright holder is reinstated (a) 405 | provisionally, unless and until the copyright holder explicitly and 406 | finally terminates your license, and (b) permanently, if the copyright 407 | holder fails to notify you of the violation by some reasonable means 408 | prior to 60 days after the cessation. 
409 | 410 | Moreover, your license from a particular copyright holder is 411 | reinstated permanently if the copyright holder notifies you of the 412 | violation by some reasonable means, this is the first time you have 413 | received notice of violation of this License (for any work) from that 414 | copyright holder, and you cure the violation prior to 30 days after 415 | your receipt of the notice. 416 | 417 | Termination of your rights under this section does not terminate the 418 | licenses of parties who have received copies or rights from you under 419 | this License. If your rights have been terminated and not permanently 420 | reinstated, you do not qualify to receive new licenses for the same 421 | material under section 10. 422 | 423 | 9. Acceptance Not Required for Having Copies. 424 | 425 | You are not required to accept this License in order to receive or 426 | run a copy of the Program. Ancillary propagation of a covered work 427 | occurring solely as a consequence of using peer-to-peer transmission 428 | to receive a copy likewise does not require acceptance. However, 429 | nothing other than this License grants you permission to propagate or 430 | modify any covered work. These actions infringe copyright if you do 431 | not accept this License. Therefore, by modifying or propagating a 432 | covered work, you indicate your acceptance of this License to do so. 433 | 434 | 10. Automatic Licensing of Downstream Recipients. 435 | 436 | Each time you convey a covered work, the recipient automatically 437 | receives a license from the original licensors, to run, modify and 438 | propagate that work, subject to this License. You are not responsible 439 | for enforcing compliance by third parties with this License. 440 | 441 | An "entity transaction" is a transaction transferring control of an 442 | organization, or substantially all assets of one, or subdividing an 443 | organization, or merging organizations. 
If propagation of a covered 444 | work results from an entity transaction, each party to that 445 | transaction who receives a copy of the work also receives whatever 446 | licenses to the work the party's predecessor in interest had or could 447 | give under the previous paragraph, plus a right to possession of the 448 | Corresponding Source of the work from the predecessor in interest, if 449 | the predecessor has it or can get it with reasonable efforts. 450 | 451 | You may not impose any further restrictions on the exercise of the 452 | rights granted or affirmed under this License. For example, you may 453 | not impose a license fee, royalty, or other charge for exercise of 454 | rights granted under this License, and you may not initiate litigation 455 | (including a cross-claim or counterclaim in a lawsuit) alleging that 456 | any patent claim is infringed by making, using, selling, offering for 457 | sale, or importing the Program or any portion of it. 458 | 459 | 11. Patents. 460 | 461 | A "contributor" is a copyright holder who authorizes use under this 462 | License of the Program or a work on which the Program is based. The 463 | work thus licensed is called the contributor's "contributor version". 464 | 465 | A contributor's "essential patent claims" are all patent claims 466 | owned or controlled by the contributor, whether already acquired or 467 | hereafter acquired, that would be infringed by some manner, permitted 468 | by this License, of making, using, or selling its contributor version, 469 | but do not include claims that would be infringed only as a 470 | consequence of further modification of the contributor version. For 471 | purposes of this definition, "control" includes the right to grant 472 | patent sublicenses in a manner consistent with the requirements of 473 | this License. 
474 | 475 | Each contributor grants you a non-exclusive, worldwide, royalty-free 476 | patent license under the contributor's essential patent claims, to 477 | make, use, sell, offer for sale, import and otherwise run, modify and 478 | propagate the contents of its contributor version. 479 | 480 | In the following three paragraphs, a "patent license" is any express 481 | agreement or commitment, however denominated, not to enforce a patent 482 | (such as an express permission to practice a patent or covenant not to 483 | sue for patent infringement). To "grant" such a patent license to a 484 | party means to make such an agreement or commitment not to enforce a 485 | patent against the party. 486 | 487 | If you convey a covered work, knowingly relying on a patent license, 488 | and the Corresponding Source of the work is not available for anyone 489 | to copy, free of charge and under the terms of this License, through a 490 | publicly available network server or other readily accessible means, 491 | then you must either (1) cause the Corresponding Source to be so 492 | available, or (2) arrange to deprive yourself of the benefit of the 493 | patent license for this particular work, or (3) arrange, in a manner 494 | consistent with the requirements of this License, to extend the patent 495 | license to downstream recipients. "Knowingly relying" means you have 496 | actual knowledge that, but for the patent license, your conveying the 497 | covered work in a country, or your recipient's use of the covered work 498 | in a country, would infringe one or more identifiable patents in that 499 | country that you have reason to believe are valid. 
500 | 501 | If, pursuant to or in connection with a single transaction or 502 | arrangement, you convey, or propagate by procuring conveyance of, a 503 | covered work, and grant a patent license to some of the parties 504 | receiving the covered work authorizing them to use, propagate, modify 505 | or convey a specific copy of the covered work, then the patent license 506 | you grant is automatically extended to all recipients of the covered 507 | work and works based on it. 508 | 509 | A patent license is "discriminatory" if it does not include within 510 | the scope of its coverage, prohibits the exercise of, or is 511 | conditioned on the non-exercise of one or more of the rights that are 512 | specifically granted under this License. You may not convey a covered 513 | work if you are a party to an arrangement with a third party that is 514 | in the business of distributing software, under which you make payment 515 | to the third party based on the extent of your activity of conveying 516 | the work, and under which the third party grants, to any of the 517 | parties who would receive the covered work from you, a discriminatory 518 | patent license (a) in connection with copies of the covered work 519 | conveyed by you (or copies made from those copies), or (b) primarily 520 | for and in connection with specific products or compilations that 521 | contain the covered work, unless you entered into that arrangement, 522 | or that patent license was granted, prior to 28 March 2007. 523 | 524 | Nothing in this License shall be construed as excluding or limiting 525 | any implied license or other defenses to infringement that may 526 | otherwise be available to you under applicable patent law. 527 | 528 | 12. No Surrender of Others' Freedom. 529 | 530 | If conditions are imposed on you (whether by court order, agreement or 531 | otherwise) that contradict the conditions of this License, they do not 532 | excuse you from the conditions of this License. 
If you cannot convey a 533 | covered work so as to satisfy simultaneously your obligations under this 534 | License and any other pertinent obligations, then as a consequence you may 535 | not convey it at all. For example, if you agree to terms that obligate you 536 | to collect a royalty for further conveying from those to whom you convey 537 | the Program, the only way you could satisfy both those terms and this 538 | License would be to refrain entirely from conveying the Program. 539 | 540 | 13. Remote Network Interaction; Use with the GNU General Public License. 541 | 542 | Notwithstanding any other provision of this License, if you modify the 543 | Program, your modified version must prominently offer all users 544 | interacting with it remotely through a computer network (if your version 545 | supports such interaction) an opportunity to receive the Corresponding 546 | Source of your version by providing access to the Corresponding Source 547 | from a network server at no charge, through some standard or customary 548 | means of facilitating copying of software. This Corresponding Source 549 | shall include the Corresponding Source for any work covered by version 3 550 | of the GNU General Public License that is incorporated pursuant to the 551 | following paragraph. 552 | 553 | Notwithstanding any other provision of this License, you have 554 | permission to link or combine any covered work with a work licensed 555 | under version 3 of the GNU General Public License into a single 556 | combined work, and to convey the resulting work. The terms of this 557 | License will continue to apply to the part which is the covered work, 558 | but the work with which it is combined will remain governed by version 559 | 3 of the GNU General Public License. 560 | 561 | 14. Revised Versions of this License. 562 | 563 | The Free Software Foundation may publish revised and/or new versions of 564 | the GNU Affero General Public License from time to time. 
Such new versions 565 | will be similar in spirit to the present version, but may differ in detail to 566 | address new problems or concerns. 567 | 568 | Each version is given a distinguishing version number. If the 569 | Program specifies that a certain numbered version of the GNU Affero General 570 | Public License "or any later version" applies to it, you have the 571 | option of following the terms and conditions either of that numbered 572 | version or of any later version published by the Free Software 573 | Foundation. If the Program does not specify a version number of the 574 | GNU Affero General Public License, you may choose any version ever published 575 | by the Free Software Foundation. 576 | 577 | If the Program specifies that a proxy can decide which future 578 | versions of the GNU Affero General Public License can be used, that proxy's 579 | public statement of acceptance of a version permanently authorizes you 580 | to choose that version for the Program. 581 | 582 | Later license versions may give you additional or different 583 | permissions. However, no additional obligations are imposed on any 584 | author or copyright holder as a result of your choosing to follow a 585 | later version. 586 | 587 | 15. Disclaimer of Warranty. 588 | 589 | THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY 590 | APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT 591 | HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY 592 | OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, 593 | THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR 594 | PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM 595 | IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF 596 | ALL NECESSARY SERVICING, REPAIR OR CORRECTION. 597 | 598 | 16. Limitation of Liability. 
599 | 600 | IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING 601 | WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS 602 | THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY 603 | GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE 604 | USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF 605 | DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD 606 | PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS), 607 | EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF 608 | SUCH DAMAGES. 609 | 610 | 17. Interpretation of Sections 15 and 16. 611 | 612 | If the disclaimer of warranty and limitation of liability provided 613 | above cannot be given local legal effect according to their terms, 614 | reviewing courts shall apply local law that most closely approximates 615 | an absolute waiver of all civil liability in connection with the 616 | Program, unless a warranty or assumption of liability accompanies a 617 | copy of the Program in return for a fee. 618 | 619 | END OF TERMS AND CONDITIONS 620 | 621 | How to Apply These Terms to Your New Programs 622 | 623 | If you develop a new program, and you want it to be of the greatest 624 | possible use to the public, the best way to achieve this is to make it 625 | free software which everyone can redistribute and change under these terms. 626 | 627 | To do so, attach the following notices to the program. It is safest 628 | to attach them to the start of each source file to most effectively 629 | state the exclusion of warranty; and each file should have at least 630 | the "copyright" line and a pointer to where the full notice is found. 
631 | 632 | <one line to give the program's name and a brief idea of what it does.> 633 | Copyright (C) <year> <name of author> 634 | 635 | This program is free software: you can redistribute it and/or modify 636 | it under the terms of the GNU Affero General Public License as published 637 | by the Free Software Foundation, either version 3 of the License, or 638 | (at your option) any later version. 639 | 640 | This program is distributed in the hope that it will be useful, 641 | but WITHOUT ANY WARRANTY; without even the implied warranty of 642 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 643 | GNU Affero General Public License for more details. 644 | 645 | You should have received a copy of the GNU Affero General Public License 646 | along with this program. If not, see <https://www.gnu.org/licenses/>. 647 | 648 | Also add information on how to contact you by electronic and paper mail. 649 | 650 | If your software can interact with users remotely through a computer 651 | network, you should also make sure that it provides a way for users to 652 | get its source. For example, if your program is a web application, its 653 | interface could display a "Source" link that leads users to an archive 654 | of the code. There are many ways you could offer source, and different 655 | solutions will be better for different programs; see section 13 for the 656 | specific requirements. 657 | 658 | You should also get your employer (if you work as a programmer) or school, 659 | if any, to sign a "copyright disclaimer" for the program, if necessary. 660 | For more information on this, and how to apply and follow the GNU AGPL, see 661 | <https://www.gnu.org/licenses/>.
662 | -------------------------------------------------------------------------------- /src/socket.rs: -------------------------------------------------------------------------------- 1 | use crate::configuration::UdtConfiguration; 2 | use crate::control_packet::{AckOptionalInfo, ControlPacketType, HandShakeInfo, UdtControlPacket}; 3 | use crate::data_packet::{UdtDataPacket, UDT_DATA_HEADER_SIZE}; 4 | use crate::flow::{UdtFlow, PROBE_MODULO}; 5 | use crate::multiplexer::UdtMultiplexer; 6 | use crate::packet::UdtPacket; 7 | use crate::queue::{RcvBuffer, SndBuffer}; 8 | use crate::rate_control::RateControl; 9 | use crate::seq_number::SeqNumber; 10 | use crate::state::SocketState; 11 | use crate::udt::{SocketRef, Udt, UDT_DEBUG}; 12 | use once_cell::sync::Lazy; 13 | use rand::distributions::Alphanumeric; 14 | use rand::Rng; 15 | use sha2::{Digest, Sha256}; 16 | use std::cmp::Ordering; 17 | use std::collections::BTreeSet; 18 | use std::net::{IpAddr, SocketAddr}; 19 | use std::sync::{Arc, Mutex, RwLock, Weak}; 20 | use std::task::Poll; 21 | use tokio::io::{Error, ErrorKind, ReadBuf, Result}; 22 | use tokio::sync::{Notify, RwLock as TokioRwLock}; 23 | use tokio::time::{Duration, Instant}; 24 | 25 | pub(crate) const SYN_INTERVAL: Duration = Duration::from_millis(10); 26 | const MIN_EXP_INTERVAL: Duration = Duration::from_millis(300); 27 | const PACKETS_BETWEEN_LIGHT_ACK: usize = 64; 28 | 29 | static SALT: Lazy = Lazy::new(|| { 30 | rand::thread_rng() 31 | .sample_iter(&Alphanumeric) 32 | .take(30) 33 | .map(char::from) 34 | .collect() 35 | }); 36 | 37 | pub type SocketId = u32; 38 | 39 | #[derive(Debug, Copy, Clone, PartialEq, Eq)] 40 | pub enum SocketType { 41 | Stream = 1, 42 | Datagram = 2, 43 | } 44 | 45 | impl TryFrom for SocketType { 46 | type Error = Error; 47 | 48 | fn try_from(value: u32) -> Result { 49 | match value { 50 | 1 => Ok(SocketType::Stream), 51 | 2 => Ok(SocketType::Datagram), 52 | _ => Err(Error::new( 53 | ErrorKind::InvalidData, 54 | "unknown value 
for socket type", 55 | )), 56 | } 57 | } 58 | } 59 | 60 | #[derive(Debug)] 61 | pub struct UdtSocket { 62 | pub socket_id: SocketId, 63 | pub status: Mutex, 64 | pub socket_type: SocketType, 65 | pub(crate) listen_socket: Option, 66 | peer_addr: Mutex>, 67 | peer_socket_id: Mutex>, 68 | pub initial_seq_number: SeqNumber, 69 | 70 | pub(crate) queued_sockets: TokioRwLock>, 71 | pub(crate) accept_notify: Notify, 72 | pub(crate) multiplexer: RwLock>, 73 | pub configuration: RwLock, 74 | 75 | rcv_buffer: Mutex, 76 | snd_buffer: Mutex, 77 | flow: RwLock, 78 | pub rate_control: RwLock, 79 | start_time: Instant, 80 | 81 | state: Mutex, 82 | 83 | connect_notify: Notify, 84 | rcv_notify: Notify, 85 | ack_notify: Notify, 86 | } 87 | 88 | impl UdtSocket { 89 | pub(crate) fn new( 90 | socket_id: SocketId, 91 | socket_type: SocketType, 92 | isn: Option, 93 | configuration: Option, 94 | ) -> Self { 95 | let now = Instant::now(); 96 | let initial_seq_number = isn.unwrap_or_else(SeqNumber::random); 97 | let configuration = configuration.unwrap_or_default(); 98 | Self { 99 | socket_id, 100 | socket_type, 101 | status: Mutex::new(UdtStatus::Init), 102 | initial_seq_number, 103 | peer_addr: Mutex::new(None), 104 | peer_socket_id: Mutex::new(None), 105 | listen_socket: None, 106 | queued_sockets: TokioRwLock::new(BTreeSet::new()), 107 | accept_notify: Notify::new(), 108 | multiplexer: RwLock::new(Weak::new()), 109 | snd_buffer: Mutex::new(SndBuffer::new(configuration.snd_buf_size)), 110 | rcv_buffer: Mutex::new(RcvBuffer::new( 111 | configuration.rcv_buf_size, 112 | initial_seq_number, 113 | )), 114 | flow: RwLock::new(UdtFlow::default()), 115 | rate_control: RwLock::new(RateControl::new()), 116 | start_time: now, 117 | 118 | state: Mutex::new(SocketState::new(initial_seq_number, &configuration)), 119 | connect_notify: Notify::new(), 120 | rcv_notify: Notify::new(), 121 | ack_notify: Notify::new(), 122 | configuration: RwLock::new(configuration), 123 | } 124 | } 125 | 126 | pub fn 
with_peer(self, peer: SocketAddr, peer_socket_id: SocketId) -> Self { 127 | self.set_peer_addr(peer); 128 | *self.peer_socket_id.lock().unwrap() = Some(peer_socket_id); 129 | self 130 | } 131 | 132 | fn set_peer_addr(&self, peer: SocketAddr) { 133 | *self.peer_addr.lock().unwrap() = Some(peer); 134 | self.snd_buffer 135 | .lock() 136 | .unwrap() 137 | .set_payload_size(self.get_max_payload_size() as usize); 138 | } 139 | 140 | pub fn with_listen_socket( 141 | mut self, 142 | listen_socket_id: SocketId, 143 | mux: Arc, 144 | ) -> Self { 145 | self.listen_socket = Some(listen_socket_id); 146 | *self.multiplexer.write().unwrap() = Arc::downgrade(&mux); 147 | self 148 | } 149 | 150 | pub fn open(&self) { 151 | *self.status.lock().unwrap() = UdtStatus::Opened; 152 | } 153 | 154 | fn rcv_buffer(&self) -> std::sync::MutexGuard { 155 | self.rcv_buffer.lock().unwrap() 156 | } 157 | 158 | pub(crate) fn peer_addr(&self) -> Option { 159 | *self.peer_addr.lock().unwrap() 160 | } 161 | 162 | pub(crate) fn peer_socket_id(&self) -> Option { 163 | *self.peer_socket_id.lock().unwrap() 164 | } 165 | 166 | fn state(&self) -> std::sync::MutexGuard { 167 | self.state.lock().unwrap() 168 | } 169 | 170 | pub(crate) async fn connect_on_handshake( 171 | self, 172 | peer: SocketAddr, 173 | mut hs: HandShakeInfo, 174 | ) -> Result { 175 | { 176 | let mut configuration = self.configuration.write().unwrap(); 177 | if hs.max_packet_size > configuration.mss { 178 | hs.max_packet_size = configuration.mss; 179 | } else { 180 | configuration.mss = hs.max_packet_size; 181 | } 182 | 183 | self.flow.write().unwrap().flow_window_size = hs.max_window_size; 184 | hs.max_window_size = 185 | std::cmp::min(configuration.rcv_buf_size, configuration.flight_flag_size); 186 | } 187 | // self.set_self_ip(hs.ip_address); 188 | hs.ip_address = peer.ip(); 189 | hs.socket_id = self.socket_id; 190 | 191 | // TODO: use network information cache to set RTT, bandwidth, etc. 
192 | 193 | { 194 | let mut rate_control = self.rate_control.write().unwrap(); 195 | rate_control.init( 196 | self.configuration.read().unwrap().mss, 197 | &self.flow.read().unwrap(), 198 | self.state().curr_snd_seq_number, 199 | ) 200 | } 201 | 202 | *self.status.lock().unwrap() = UdtStatus::Connected; 203 | 204 | let packet = UdtControlPacket::new_handshake( 205 | hs, 206 | self.peer_socket_id().expect("peer_socket_id not defined"), 207 | ); 208 | 209 | if let Some(mux) = self.multiplexer() { 210 | mux.rcv_queue.push_back(self.socket_id); 211 | mux.send_to(&peer, packet.into()).await?; 212 | } 213 | 214 | let socket = Arc::new(self); 215 | Ok(socket) 216 | } 217 | 218 | pub fn set_multiplexer(&self, mux: &Arc) { 219 | *self.multiplexer.write().unwrap() = Arc::downgrade(mux); 220 | } 221 | 222 | pub(crate) fn multiplexer(&self) -> Option> { 223 | self.multiplexer.read().unwrap().upgrade() 224 | } 225 | 226 | pub(crate) async fn next_data_packets(&self) -> Result, Instant)>> { 227 | if !self.status().is_alive() { 228 | if *UDT_DEBUG { 229 | eprintln!( 230 | "No data to send: socket {} has status {:?}", 231 | self.socket_id, 232 | self.status() 233 | ); 234 | }; 235 | return Ok(None); 236 | } 237 | let now = Instant::now(); 238 | let mut probe = false; 239 | 240 | let to_resend = { 241 | let mut state = self.state(); 242 | let data_delay = now - state.next_data_target_time; 243 | 244 | if !data_delay.is_zero() { 245 | state.interpacket_time_diff += data_delay; 246 | } 247 | 248 | let last_data_ack_processed = state.last_data_ack_processed; 249 | state 250 | .snd_loss_list 251 | .pop_after(last_data_ack_processed) 252 | .map(|seq| (seq, seq - last_data_ack_processed)) 253 | }; 254 | 255 | let packets = match to_resend { 256 | Some((seq, offset)) => { 257 | // Loss retransmission has priority 258 | if offset < 0 { 259 | if *UDT_DEBUG { 260 | eprintln!("unexpected offset {} in sender loss list", offset); 261 | } 262 | return Ok(None); 263 | } 264 | let to_send = 
self.snd_buffer.lock().unwrap().read_data( 265 | offset as usize, 266 | seq, 267 | self.peer_socket_id().unwrap(), 268 | self.start_time, 269 | ); 270 | match to_send { 271 | Err((msg_number, msg_len)) => { 272 | if msg_len == 0 { 273 | return Ok(None); 274 | } 275 | let (start, end) = (seq, seq + msg_len as i32 - 1); 276 | let drop = UdtControlPacket::new_drop( 277 | msg_number, 278 | start, 279 | end, 280 | self.peer_socket_id().unwrap(), 281 | ); 282 | self.send_packet(drop.into()).await?; 283 | 284 | let mut state = self.state(); 285 | let last_data_ack_processed = state.last_data_ack_processed; 286 | state.snd_loss_list.remove_all(last_data_ack_processed, end); 287 | if (end + 1) - state.curr_snd_seq_number > 0 { 288 | state.curr_snd_seq_number = end + 1; 289 | } 290 | return Ok(None); 291 | } 292 | Ok(packet) => vec![packet], 293 | } 294 | } 295 | None => { 296 | let congestion_window_size = self 297 | .rate_control 298 | .read() 299 | .unwrap() 300 | .get_congestion_window_size(); 301 | let window_size = std::cmp::min( 302 | self.flow.read().unwrap().flow_window_size, 303 | congestion_window_size, 304 | ); 305 | let mut state = self.state(); 306 | if (state.curr_snd_seq_number - state.last_ack_received) > window_size as i32 { 307 | state.next_data_target_time = now; 308 | state.interpacket_time_diff = Duration::ZERO; 309 | return Ok(None); 310 | } 311 | match self.snd_buffer.lock().unwrap().fetch_batch( 312 | state.curr_snd_seq_number + 1, 313 | self.peer_socket_id().unwrap(), 314 | self.start_time, 315 | ) { 316 | packets if !packets.is_empty() => { 317 | let new_snd_seq_number = state.curr_snd_seq_number + packets.len() as i32; 318 | state.curr_snd_seq_number = new_snd_seq_number; 319 | self.rate_control 320 | .write() 321 | .unwrap() 322 | .set_curr_snd_seq_number(new_snd_seq_number); 323 | if state.curr_snd_seq_number.number() % 16 == 0 { 324 | probe = true; 325 | } 326 | packets 327 | } 328 | _ => { 329 | state.next_data_target_time = now; 330 | 
state.interpacket_time_diff = Duration::ZERO; 331 | return Ok(None); 332 | } 333 | } 334 | } 335 | }; 336 | 337 | // update stats 338 | if probe { 339 | return Ok(Some((packets, now))); 340 | } 341 | 342 | let mut state = self.state(); 343 | let interval = state.interpacket_interval * packets.len() as u32; 344 | let target_time = if state.interpacket_time_diff >= interval { 345 | state.interpacket_time_diff -= interval; 346 | now 347 | } else { 348 | let target = now + interval - state.interpacket_time_diff; 349 | state.interpacket_time_diff = Duration::ZERO; 350 | target 351 | }; 352 | 353 | Ok(Some((packets, target_time))) 354 | } 355 | 356 | fn compute_cookie(&self, addr: &SocketAddr, offset: Option) -> u32 { 357 | let timestamp = (self.start_time.elapsed().as_secs() / 60) + offset.unwrap_or(0) as u64; // secret changes every one minute 358 | let host = addr.ip(); 359 | let port = addr.port(); 360 | let salt: &str = &(*SALT); 361 | u32::from_be_bytes( 362 | Sha256::digest(format!("{salt}:{host}:{port}:{timestamp}").as_bytes())[..4] 363 | .try_into() 364 | .unwrap(), 365 | ) 366 | } 367 | 368 | pub(crate) async fn send_to(&self, addr: &SocketAddr, packet: UdtPacket) -> Result<()> { 369 | self.multiplexer() 370 | .expect("multiplexer not initialized") 371 | .send_to(addr, packet) 372 | .await?; 373 | Ok(()) 374 | } 375 | 376 | pub(crate) async fn listen_on_handshake( 377 | &self, 378 | addr: SocketAddr, 379 | hs: &HandShakeInfo, 380 | ) -> Result<()> { 381 | if !self.status().is_alive() { 382 | return Err(Error::new(ErrorKind::ConnectionRefused, "socket closed")); 383 | } 384 | 385 | if hs.connection_type == 1 { 386 | // Regular connection, respond to handshake 387 | let mut hs_response = hs.clone(); 388 | let dest_socket_id = hs_response.socket_id; 389 | hs_response.syn_cookie = self.compute_cookie(&addr, None); 390 | let hs_packet = UdtControlPacket::new_handshake(hs_response, dest_socket_id); 391 | self.send_to(&addr, hs_packet.into()).await?; 392 | return 
Ok(()); 393 | } 394 | 395 | if hs.connection_type != -1 { 396 | return Err(Error::new( 397 | ErrorKind::InvalidData, 398 | format!("invalid connection_type: {}", hs.connection_type), 399 | )); 400 | } 401 | 402 | // Validate client response 403 | let syn_cookie = hs.syn_cookie; 404 | if syn_cookie != self.compute_cookie(&addr, None) 405 | && syn_cookie != self.compute_cookie(&addr, Some(-1)) 406 | { 407 | // Invalid cookie; 408 | return Err(Error::new(ErrorKind::PermissionDenied, "invalid cookie")); 409 | } 410 | 411 | let dest_socket_id = hs.socket_id; 412 | let udt_version = self.configuration.read().unwrap().udt_version(); 413 | if hs.udt_version != udt_version || hs.socket_type != self.socket_type { 414 | // Reject request 415 | let mut hs_response = hs.clone(); 416 | hs_response.connection_type = 1002; // Error codes defined in C++ implementation 417 | let hs_packet = UdtControlPacket::new_handshake(hs_response, dest_socket_id); 418 | self.send_to(&addr, hs_packet.into()).await?; 419 | return Err(Error::new( 420 | ErrorKind::ConnectionRefused, 421 | "configuration mismatch", 422 | )); 423 | } 424 | 425 | Udt::get() 426 | .write() 427 | .await 428 | .new_connection(self, addr, hs) 429 | .await?; 430 | // Send handshake packet in case of errors on connection? 
431 | 432 | Ok(()) 433 | } 434 | 435 | pub(crate) async fn process_packet(&self, packet: UdtPacket) -> Result<()> { 436 | match packet { 437 | UdtPacket::Control(ctrl) => self.process_ctrl(ctrl).await, 438 | UdtPacket::Data(data) => self.process_data(data).await, 439 | } 440 | } 441 | 442 | async fn process_ctrl(&self, packet: UdtControlPacket) -> Result<()> { 443 | { 444 | let mut state = self.state(); 445 | state.exp_count = 1; 446 | state.last_rsp_time = Instant::now(); 447 | } 448 | 449 | match packet.packet_type { 450 | ControlPacketType::Handshake(hs) => { 451 | if self.status() != UdtStatus::Connecting { 452 | return Err(Error::new( 453 | ErrorKind::InvalidData, 454 | format!( 455 | "unexpected handshake for socket with status {:?}", 456 | self.status 457 | ), 458 | )); 459 | } 460 | 461 | // TODO: handle rendezvous mode 462 | if hs.connection_type > 0 { 463 | let mut hs = hs.clone(); 464 | hs.connection_type = -1; 465 | hs.socket_id = self.socket_id; 466 | let hs_packet = UdtControlPacket::new_handshake(hs, 0); 467 | self.send_packet(hs_packet.into()).await?; 468 | } else { 469 | // post connect 470 | let mut configuration = self.configuration.write().unwrap(); 471 | configuration.mss = hs.max_packet_size; 472 | configuration.flight_flag_size = hs.max_window_size; 473 | let mut state = self.state(); 474 | state.last_sent_ack = hs.initial_seq_number; 475 | state.last_ack2_received = hs.initial_seq_number; 476 | state.curr_rcv_seq_number = hs.initial_seq_number - 1; 477 | *self.peer_socket_id.lock().unwrap() = Some(hs.socket_id); 478 | 479 | { 480 | let mut rate_control = self.rate_control.write().unwrap(); 481 | rate_control.init( 482 | configuration.mss, 483 | &self.flow.read().unwrap(), 484 | state.curr_snd_seq_number, 485 | ) 486 | } 487 | 488 | *self.status.lock().unwrap() = UdtStatus::Connected; 489 | self.connect_notify.notify_waiters(); 490 | } 491 | } 492 | ControlPacketType::KeepAlive => (), 493 | ControlPacketType::Ack(ref ack) => { 494 | match 
&ack.info { 495 | None => { 496 | let mut state = self.state(); 497 | let seq = ack.next_seq_number; 498 | let nb_acked = seq - state.last_ack_received; 499 | if nb_acked >= 0 { 500 | state.last_ack_received = seq; 501 | self.flow.write().unwrap().flow_window_size -= (nb_acked) as u32; 502 | } 503 | } 504 | Some(extra) => { 505 | let ack_seq = packet.ack_seq_number().unwrap(); 506 | let send_ack2 = { 507 | let state = self.state(); 508 | state.last_ack2_time.elapsed() > SYN_INTERVAL 509 | || ack_seq == state.last_ack2_sent_back 510 | }; 511 | if send_ack2 { 512 | if let Some(peer) = self.peer_socket_id() { 513 | let ack2_packet = UdtControlPacket::new_ack2(ack_seq, peer); 514 | self.send_packet(ack2_packet.into()).await?; 515 | let mut state = self.state(); 516 | state.last_ack2_sent_back = ack_seq; 517 | state.last_ack2_time = Instant::now(); 518 | } 519 | } 520 | 521 | let seq = ack.next_seq_number; 522 | 523 | { 524 | let mut state = self.state(); 525 | if (seq - state.curr_snd_seq_number) > 1 { 526 | // This should not happen 527 | if *UDT_DEBUG { 528 | eprintln!( 529 | "Udt socket broken: seq number is larger than expected" 530 | ); 531 | }; 532 | *self.status.lock().unwrap() = UdtStatus::Broken; 533 | } 534 | 535 | if (seq - state.last_ack_received) >= 0 { 536 | self.flow.write().unwrap().flow_window_size = 537 | extra.available_buf_size; 538 | state.last_ack_received = seq; 539 | } 540 | 541 | let offset = seq - state.last_data_ack_processed; 542 | if offset <= 0 { 543 | // Ignore repeated acks 544 | return Ok(()); 545 | } 546 | 547 | self.snd_buffer.lock().unwrap().ack_data(offset); 548 | let last_data_ack_processed = state.last_data_ack_processed; 549 | state 550 | .snd_loss_list 551 | .remove_all(last_data_ack_processed, seq - 1); 552 | // TODO record times for monitoring purposes 553 | state.last_data_ack_processed = seq; 554 | self.update_snd_queue(false); 555 | self.ack_notify.notify_waiters(); 556 | } 557 | 558 | let mut flow = 
self.flow.write().unwrap(); 559 | flow.update_rtt(Duration::from_micros(extra.rtt.into())); 560 | flow.update_rtt_var(Duration::from_micros(extra.rtt_variance.into())); 561 | 562 | { 563 | let mut rate_control = self.rate_control.write().unwrap(); 564 | rate_control.set_rtt(flow.rtt); 565 | 566 | if extra.pack_recv_rate > 0 { 567 | flow.update_peer_delivery_rate(extra.pack_recv_rate); 568 | rate_control.set_rcv_rate(flow.peer_delivery_rate); 569 | } 570 | if extra.link_capacity > 0 { 571 | flow.update_bandwidth(extra.link_capacity); 572 | rate_control.set_bandwidth(flow.peer_bandwidth); 573 | } 574 | 575 | rate_control.on_ack(seq); 576 | } 577 | self.cc_update(); 578 | } 579 | } 580 | } 581 | ControlPacketType::Ack2 => { 582 | let ack_seq = packet.ack_seq_number().unwrap(); 583 | let window = self.state().ack_window.get(ack_seq); 584 | if let Some((seq, rtt)) = window { 585 | let mut flow = self.flow.write().unwrap(); 586 | let rtt_abs_diff = { 587 | if rtt > flow.rtt { 588 | rtt - flow.rtt 589 | } else { 590 | flow.rtt - rtt 591 | } 592 | }; 593 | flow.update_rtt_var(rtt_abs_diff); 594 | flow.update_rtt(rtt); 595 | drop(flow); 596 | let mut state = self.state(); 597 | if (seq - state.last_ack2_received) > 0 { 598 | state.last_ack2_received = seq; 599 | } 600 | } 601 | } 602 | ControlPacketType::Nak(ref nak) => { 603 | let mut broken = false; 604 | { 605 | let mut rate_control = self.rate_control.write().unwrap(); 606 | if nak.loss_info.is_empty() { 607 | if *UDT_DEBUG { 608 | eprintln!("Received NAK with empty list"); 609 | } 610 | return Ok(()); 611 | } 612 | rate_control.on_loss((nak.loss_info[0] & 0x7fff_ffff).into()); 613 | } 614 | self.cc_update(); 615 | 616 | let loss_iter = &mut nak.loss_info.iter(); 617 | let mut state = self.state(); 618 | while let Some(loss) = loss_iter.next() { 619 | let (seq_start, seq_end) = { 620 | if loss & 0x8000_0000 != 0 { 621 | if let Some(seq_end) = loss_iter.next() { 622 | let seq_start: SeqNumber = (loss & 
0x7fff_ffff).into(); 623 | let seq_end: SeqNumber = (*seq_end).into(); 624 | (seq_start, seq_end) 625 | } else { 626 | broken = true; 627 | break; 628 | } 629 | } else { 630 | ((*loss).into(), (*loss).into()) 631 | } 632 | }; 633 | if (seq_start - seq_end > 0) || (seq_end - state.curr_snd_seq_number > 0) { 634 | broken = true; 635 | break; 636 | } 637 | if seq_start - state.last_ack_received >= 0 { 638 | state.snd_loss_list.insert(seq_start, seq_end); 639 | } else if seq_end - state.last_ack_received >= 0 { 640 | let last_ack_received = state.last_ack_received; 641 | state.snd_loss_list.insert(last_ack_received, seq_end); 642 | } 643 | } 644 | 645 | if broken { 646 | println!("NAK is broken: {:?} {:?}", nak, state); 647 | *self.status.lock().unwrap() = UdtStatus::Broken; 648 | return Ok(()); 649 | } 650 | 651 | self.update_snd_queue(true); 652 | } 653 | ControlPacketType::Shutdown => { 654 | *self.status.lock().unwrap() = UdtStatus::Closing; 655 | self.notify_all(); 656 | } 657 | ControlPacketType::MsgDropRequest(ref drop) => { 658 | let msg_number = packet.msg_seq_number().unwrap(); 659 | self.rcv_buffer.lock().unwrap().drop_msg(msg_number); 660 | let mut state = self.state(); 661 | state 662 | .rcv_loss_list 663 | .remove_all(drop.first_seq_number, drop.last_seq_number); 664 | if (drop.first_seq_number - (state.curr_rcv_seq_number + 1)) <= 0 665 | && (drop.last_seq_number - state.curr_rcv_seq_number) > 0 666 | { 667 | state.curr_rcv_seq_number = drop.last_seq_number; 668 | } 669 | } 670 | ControlPacketType::UserDefined => unimplemented!(), 671 | } 672 | Ok(()) 673 | } 674 | 675 | async fn process_data(&self, packet: UdtDataPacket) -> Result<()> { 676 | let now = Instant::now(); 677 | { 678 | let mut state = self.state(); 679 | state.last_rsp_time = now; 680 | state.pkt_count += 1; 681 | } 682 | 683 | let seq_number = packet.header.seq_number; 684 | 685 | { 686 | let mut flow = self.flow.write().unwrap(); 687 | flow.on_pkt_arrival(now); 688 | 689 | if 
seq_number.number() % PROBE_MODULO == 0 { 690 | flow.on_probe1_arrival(); 691 | } else if seq_number.number() % PROBE_MODULO == 1 { 692 | flow.on_probe2_arrival(); 693 | } 694 | } 695 | 696 | // trace_rcv++ 697 | // recv_total++ 698 | let offset = seq_number - self.state().last_sent_ack; 699 | if offset < 0 { 700 | // seq number is too late 701 | return Ok(()); 702 | } 703 | 704 | let payload_len = { 705 | let mut rcv_buffer = self.rcv_buffer(); 706 | let available_buf_size = rcv_buffer.get_available_buf_size(); 707 | if available_buf_size < offset as u32 { 708 | if *UDT_DEBUG { 709 | eprintln!("not enough space in rcv buffer"); 710 | } 711 | return Ok(()); 712 | } 713 | 714 | let payload_len = packet.payload_len(); 715 | rcv_buffer.insert(packet); 716 | payload_len 717 | }; 718 | 719 | if (seq_number - self.state().curr_rcv_seq_number) > 1 { 720 | // some packets have been lost in between 721 | let nak_packet = { 722 | let mut state = self.state(); 723 | let curr_rcv_seq_number = state.curr_rcv_seq_number; 724 | state 725 | .rcv_loss_list 726 | .insert(curr_rcv_seq_number + 1, seq_number - 1); 727 | 728 | // send NAK immediately 729 | let loss_list = { 730 | if state.curr_rcv_seq_number + 1 == seq_number - 1 { 731 | vec![(seq_number - 1).number()] 732 | } else { 733 | vec![ 734 | (state.curr_rcv_seq_number + 1).number() | 0x8000_0000, 735 | (seq_number - 1).number(), 736 | ] 737 | } 738 | }; 739 | UdtControlPacket::new_nak(loss_list, self.peer_socket_id().unwrap_or(0)) 740 | }; 741 | self.send_packet(nak_packet.into()).await?; 742 | } 743 | 744 | if payload_len < self.get_max_payload_size() { 745 | self.state().next_ack_time = Instant::now(); 746 | } 747 | 748 | let mut state = self.state(); 749 | 750 | if seq_number - state.curr_rcv_seq_number > 0 { 751 | state.curr_rcv_seq_number = seq_number; 752 | } else { 753 | state.rcv_loss_list.remove(seq_number); 754 | } 755 | 756 | Ok(()) 757 | } 758 | 759 | pub fn get_max_payload_size(&self) -> u32 { 760 | let 
configuration = self.configuration.read().unwrap(); 761 | match self.peer_addr().map(|a| a.ip()) { 762 | Some(IpAddr::V6(_)) => configuration.mss - 40 - UDT_DATA_HEADER_SIZE as u32, 763 | _ => configuration.mss - 28 - UDT_DATA_HEADER_SIZE as u32, 764 | } 765 | } 766 | 767 | pub(crate) async fn send_packet(&self, packet: UdtPacket) -> Result<()> { 768 | if let Some(addr) = self.peer_addr() { 769 | self.send_to(&addr, packet).await?; 770 | } 771 | Ok(()) 772 | } 773 | 774 | pub(crate) async fn send_data_packets(&self, packets: Vec) -> Result<()> { 775 | if let Some(addr) = self.peer_addr() { 776 | self.multiplexer() 777 | .expect("multiplexer not initialized") 778 | .send_mmsg_to(&addr, packets.into_iter().map(|p| p.into())) 779 | .await?; 780 | } 781 | Ok(()) 782 | } 783 | 784 | async fn send_ack(&self, light: bool) -> Result<()> { 785 | let seq_number = { 786 | let state = self.state(); 787 | let seq_number = match state 788 | .rcv_loss_list 789 | .peek_after(state.curr_rcv_seq_number + 1) 790 | { 791 | Some(num) => num, 792 | None => state.curr_rcv_seq_number + 1, 793 | }; 794 | if seq_number == state.last_ack2_received { 795 | return Ok(()); 796 | } 797 | seq_number 798 | }; 799 | 800 | if light { 801 | // Save time on buffer procesing and bandwith measurement 802 | let ack_packet = UdtControlPacket::new_ack( 803 | 0.into(), 804 | seq_number, 805 | self.peer_socket_id().unwrap(), 806 | None, 807 | ); 808 | self.send_packet(ack_packet.into()).await?; 809 | return Ok(()); 810 | } 811 | 812 | { 813 | let mut state = self.state(); 814 | let to_ack: i32 = seq_number - state.last_sent_ack; 815 | match to_ack.cmp(&0) { 816 | Ordering::Greater => { 817 | self.rcv_buffer().ack_data(seq_number); 818 | state.last_sent_ack = seq_number; 819 | self.rcv_notify.notify_waiters(); 820 | } 821 | Ordering::Equal => { 822 | let last_sent_ack_elapsed = state.last_sent_ack_time.elapsed(); 823 | drop(state); 824 | let flow = self.flow.read().unwrap(); 825 | if last_sent_ack_elapsed < 
(flow.rtt + 4 * flow.rtt_var) { 826 | return Ok(()); 827 | } 828 | } 829 | _ => { 830 | return Ok(()); 831 | } 832 | } 833 | } 834 | 835 | let ack_packet = { 836 | let mut state = self.state(); 837 | if (state.last_sent_ack - state.last_ack2_received) > 0 { 838 | state.last_ack_seq_number = state.last_ack_seq_number + 1; 839 | drop(state); 840 | let mut ack_info = { 841 | let flow = self.flow.read().unwrap(); 842 | AckOptionalInfo { 843 | rtt: flow.rtt.as_micros().try_into().unwrap_or(u32::MAX), 844 | rtt_variance: flow.rtt_var.as_micros().try_into().unwrap_or(u32::MAX), 845 | available_buf_size: std::cmp::max( 846 | self.rcv_buffer().get_available_buf_size(), 847 | 2, 848 | ), 849 | pack_recv_rate: 0, 850 | link_capacity: 0, 851 | } 852 | }; 853 | if self.state().last_sent_ack_time.elapsed() > SYN_INTERVAL { 854 | let flow = self.flow.read().unwrap(); 855 | ack_info.pack_recv_rate = flow.get_pkt_rcv_speed(); 856 | ack_info.link_capacity = flow.get_bandwidth(); 857 | self.state().last_sent_ack_time = Instant::now(); 858 | } 859 | let state = self.state(); 860 | Some(UdtControlPacket::new_ack( 861 | state.last_ack_seq_number, 862 | state.last_sent_ack, 863 | self.peer_socket_id().unwrap(), 864 | Some(ack_info), 865 | )) 866 | } else { 867 | None 868 | } 869 | }; 870 | 871 | if let Some(ack_packet) = ack_packet { 872 | self.send_packet(ack_packet.into()).await?; 873 | let mut state = self.state(); 874 | let last_sent_ack = state.last_sent_ack; 875 | let last_ack_seq_number = state.last_ack_seq_number; 876 | state.ack_window.store(last_sent_ack, last_ack_seq_number); 877 | } 878 | 879 | Ok(()) 880 | } 881 | 882 | fn cc_update(&self) { 883 | let mut state = self.state(); 884 | state.interpacket_interval = self.rate_control.read().unwrap().get_pkt_send_period(); 885 | } 886 | 887 | pub(crate) async fn check_timers(&self) { 888 | self.cc_update(); 889 | let now = Instant::now(); 890 | 891 | let ack_interval = self.rate_control.read().unwrap().get_ack_pkt_interval(); 892 
| if now > self.state().next_ack_time 893 | || (ack_interval > 0 && ack_interval <= self.state().pkt_count) 894 | { 895 | self.send_ack(false).await.unwrap_or_else(|err| { 896 | if *UDT_DEBUG { 897 | eprintln!("failed to send ack: {:?}", err); 898 | } 899 | }); 900 | let ack_period = self.rate_control.read().unwrap().get_ack_period(); 901 | let mut state = self.state(); 902 | state.next_ack_time = now + ack_period; 903 | state.pkt_count = 0; 904 | state.light_ack_counter = 0; 905 | } else { 906 | let send_light_ack = { 907 | let state = self.state(); 908 | (state.light_ack_counter + 1) * PACKETS_BETWEEN_LIGHT_ACK <= state.pkt_count 909 | }; 910 | if send_light_ack { 911 | self.send_ack(true).await.unwrap_or_else(|err| { 912 | if *UDT_DEBUG { 913 | eprintln!("failed to send ack: {:?}", err); 914 | } 915 | }); 916 | self.state().light_ack_counter += 1; 917 | } 918 | } 919 | 920 | let next_exp_time = { 921 | let (rtt, rtt_var) = { 922 | let flow = self.flow.read().unwrap(); 923 | (flow.rtt, flow.rtt_var) 924 | }; 925 | let state = self.state(); 926 | let exp_int = state.exp_count * (rtt + 4 * rtt_var) + SYN_INTERVAL; 927 | let next_exp = std::cmp::max(exp_int, state.exp_count * MIN_EXP_INTERVAL); 928 | state.last_rsp_time + next_exp 929 | }; 930 | if now > next_exp_time { 931 | { 932 | let state = self.state(); 933 | if state.exp_count > 16 && state.last_rsp_time.elapsed() > Duration::from_secs(5) { 934 | // Connection is broken 935 | *self.status.lock().unwrap() = UdtStatus::Broken; 936 | self.update_snd_queue(true); 937 | return; 938 | } 939 | } 940 | 941 | if self.snd_buffer.lock().unwrap().is_empty() { 942 | if let Some(peer_socket_id) = self.peer_socket_id() { 943 | let keep_alive = UdtControlPacket::new_keep_alive(peer_socket_id); 944 | self.send_packet(keep_alive.into()) 945 | .await 946 | .unwrap_or_else(|err| { 947 | if *UDT_DEBUG { 948 | eprintln!("failed to send keep alive: {:?}", err); 949 | }; 950 | }); 951 | } 952 | } else { 953 | { 954 | let mut state = 
self.state();
                    if (state.last_ack_received != state.curr_snd_seq_number + 1)
                        && state.snd_loss_list.is_empty()
                    {
                        let last_ack_received = state.last_ack_received;
                        let curr_snd_seq_number = state.curr_snd_seq_number;
                        state
                            .snd_loss_list
                            .insert(last_ack_received, curr_snd_seq_number);
                    }
                }

                self.rate_control.write().unwrap().on_timeout();
                self.cc_update();
                self.update_snd_queue(true);
            }

            let mut state = self.state();
            state.exp_count += 1;
            // Reset last response time since we just sent a heart-beat.
            state.last_rsp_time = now;
        }
    }

    /// Asks the send queue of the multiplexer this socket is attached to (if any)
    /// to update its entry for this socket, forwarding the `reschedule` flag.
    fn update_snd_queue(&self, reschedule: bool) {
        if let Some(mux) = self.multiplexer() {
            mux.snd_queue.update(self.socket_id, reschedule);
        }
    }

    /// Appends `data` to the send buffer and signals the send queue.
    ///
    /// Sending an empty slice is a no-op that returns `Ok(())`.
    ///
    /// # Errors
    ///
    /// * `ErrorKind::InvalidInput` if the socket is not a stream socket.
    /// * `ErrorKind::NotConnected` if the socket status is not `Connected`.
    /// * Any error returned by the send buffer when adding the message.
    pub fn send(&self, data: &[u8]) -> Result<()> {
        if self.socket_type != SocketType::Stream {
            return Err(Error::new(
                ErrorKind::InvalidInput,
                "socket needs to be configured in stream mode to send data buffer",
            ));
        }
        if self.status() != UdtStatus::Connected {
            return Err(Error::new(
                ErrorKind::NotConnected,
                "UDT socket is not connected",
            ));
        }

        if data.is_empty() {
            return Ok(());
        }

        if self.snd_buffer.lock().unwrap().is_empty() {
            // delay the EXP timer to avoid mis-fired timeout
            self.state().last_rsp_time = Instant::now();
        }

        self.snd_buffer
            .lock()
            .unwrap()
            .add_message(data, None, false)?;
        self.update_snd_queue(false);
        Ok(())
    }

    /// Checks that the current socket status permits reading.
    ///
    /// A socket that is no longer alive may still be read from as long as the
    /// receive buffer holds data; otherwise the socket must be `Connected`.
    ///
    /// # Errors
    ///
    /// * `ErrorKind::BrokenPipe` if the socket is not alive and has no data left to read.
    /// * `ErrorKind::NotConnected` if the socket is alive but not `Connected`.
    fn ensure_readable(&self) -> Result<()> {
        let status = self.status();
        if !status.is_alive() {
            if !self.rcv_buffer().has_data_to_read() {
                return Err(Error::new(
                    ErrorKind::BrokenPipe,
                    "connection was closed or broken",
                ));
            }
        } else if status != UdtStatus::Connected {
            return Err(Error::new(
                ErrorKind::NotConnected,
                "UDT socket not connected",
            ));
        }
        Ok(())
    }

    /// Waits until data can be read (or the socket stops being alive), then reads
    /// from the receive buffer into `buf` and returns the number of bytes written.
    ///
    /// Returns `Ok(0)` immediately when `buf` is empty.
    ///
    /// # Errors
    ///
    /// * `ErrorKind::InvalidInput` if the socket is not a stream socket.
    /// * `ErrorKind::BrokenPipe` if the socket is closed or broken and has no data left.
    /// * `ErrorKind::NotConnected` if the socket is alive but not `Connected`.
    pub async fn recv(&self, buf: &mut [u8]) -> Result<usize> {
        if self.socket_type != SocketType::Stream {
            return Err(Error::new(
                ErrorKind::InvalidInput,
                "cannot recv on non-stream socket",
            ));
        }
        self.ensure_readable()?;

        if buf.is_empty() {
            return Ok(0);
        }

        self.wait_for_data_to_read().await;

        // The status may have changed while waiting, so validate it again.
        self.ensure_readable()?;

        let mut buf = ReadBuf::new(buf);
        let written = self.rcv_buffer().read_buffer(&mut buf);

        // TODO: implement configurable UDT timeout
        Ok(written)
    }

    /// Non-blocking variant of [`recv`](Self::recv).
    ///
    /// Returns `Poll::Pending` when the receive buffer has nothing to read. This
    /// function takes no task context and registers no waker, so the caller has to
    /// arrange to be polled again.
    pub(crate) fn poll_recv(&self, buf: &mut ReadBuf<'_>) -> Poll<Result<usize>> {
        if self.socket_type != SocketType::Stream {
            return Poll::Ready(Err(Error::new(
                ErrorKind::InvalidInput,
                "cannot recv on non-stream socket",
            )));
        }
        if let Err(err) = self.ensure_readable() {
            return Poll::Ready(Err(err));
        }

        if !self.rcv_buffer().has_data_to_read() {
            return Poll::Pending;
        }

        if buf.remaining() == 0 {
            return Poll::Ready(Ok(0));
        }
        let written = self.rcv_buffer().read_buffer(buf);
        Poll::Ready(Ok(written))
    }

    /// Starts connecting to `addr` by opening the socket, attaching it to a
    /// multiplexer and sending the initial handshake packet.
    ///
    /// This only sends the handshake; use `wait_for_connection` to wait for the
    /// status to leave `Connecting`.
    ///
    /// # Errors
    ///
    /// * `ErrorKind::Unsupported` if the socket status is not `Init`.
    /// * Any error returned while updating the multiplexer or sending the handshake.
    pub(crate) async fn connect(
        &self,
        addr: SocketAddr,
        bind_addr: Option<SocketAddr>,
    ) -> Result<()> {
        let status = self.status();
        if status != UdtStatus::Init {
            return Err(Error::new(
                ErrorKind::Unsupported,
                // Report the status value itself, not the mutex wrapping it.
                format!("expected status Init, found {:?}", status),
            ));
        }

        self.open();
        {
            let mut udt = Udt::get().write().await;
            udt.update_mux(self, bind_addr).await?;
        }

        *self.status.lock().unwrap() = UdtStatus::Connecting;
        self.set_peer_addr(addr);

        // TODO: register the current socket in the rendezvous queue
        // This is used to temporarily store incoming handshakes and possibly retry connections, including for non-rendezvous connections.

        let hs_packet = {
            let configuration = self.configuration.read().unwrap();
            let hs = HandShakeInfo {
                udt_version: configuration.udt_version(),
                initial_seq_number: self.initial_seq_number,
                max_packet_size: configuration.mss,
                max_window_size: std::cmp::min(
                    self.flow.read().unwrap().flow_window_size,
                    self.rcv_buffer().get_available_buf_size(),
                ),
                connection_type: 1,
                socket_type: self.socket_type,
                socket_id: self.socket_id,
                ip_address: addr.ip(),
                syn_cookie: 0,
            };
            UdtControlPacket::new_handshake(hs, 0)
        };
        self.send_to(&addr, hs_packet.into()).await?;

        Ok(())
    }

    /// Returns a copy of the current socket status.
    pub fn status(&self) -> UdtStatus {
        *self.status.lock().unwrap()
    }

    /// Returns `true` when the send buffer is empty.
    pub fn snd_buffer_is_empty(&self) -> bool {
        self.snd_buffer.lock().unwrap().is_empty()
    }

    /// Closes the socket.
    ///
    /// Does nothing if the socket is already `Closed` or `Closing`. Otherwise:
    ///
    /// 1. While the socket is `Connected` and the send buffer is not empty, waits
    ///    for acknowledgements for at most the configured linger timeout (zero when
    ///    no linger timeout is configured).
    /// 2. Removes the socket from the multiplexer's send queue and clears the
    ///    multiplexer's listener if it is this socket.
    /// 3. Sends a shutdown packet if the socket is still `Connected`; a send failure
    ///    is only reported on stderr when `UDT_DEBUG` is set.
    /// 4. Sets the status to `Closing` and wakes all waiters.
    ///
    /// The linger wait uses `tokio::time::timeout`, so a non-zero linger timeout
    /// needs a Tokio runtime with the time driver enabled.
    pub async fn close(&self) {
        let status = self.status();
        if status == UdtStatus::Closed || status == UdtStatus::Closing {
            return;
        }
        let now = Instant::now();
        let linger_timeout = self
            .configuration
            .read()
            .unwrap()
            .linger_timeout
            .unwrap_or(Duration::ZERO);

        while self.status() == UdtStatus::Connected && !self.snd_buffer_is_empty() {
            let remaining = linger_timeout.saturating_sub(now.elapsed());
            if remaining.is_zero() {
                break;
            }
            // Bound each wait by the remaining linger time: without this, a peer
            // that never acknowledges would keep `close` waiting forever.
            let wait = self.wait_for_next_ack_or_empty_snd_buffer();
            if tokio::time::timeout(remaining, wait).await.is_err() {
                break;
            }
        }

        if let Some(mux) = self.multiplexer() {
            mux.snd_queue.remove(self.socket_id);
            let listener_id = mux.listener.read().await.clone().map(|s| s.socket_id);
            if listener_id == Some(self.socket_id) {
                *mux.listener.write().await = None;
            }
        }

        // TODO: remove socket from rendezvous queue

        if self.status() == UdtStatus::Connected {
            let shutdown = UdtControlPacket::new_shutdown(self.peer_socket_id().unwrap());
            self.send_packet(shutdown.into())
                .await
                .unwrap_or_else(|err| {
                    if *UDT_DEBUG {
                        eprintln!("Failed to send shutdown packet: {}", err);
                    }
                });
        }

        // TODO: keep channel stats (RTT, bandwidth, etc.) in a cache for more efficient reconnections.

        *self.status.lock().unwrap() = UdtStatus::Closing;
        self.notify_all();
    }

    /// Wakes every task currently waiting on the accept, receive or connect
    /// notifications.
    fn notify_all(&self) {
        self.accept_notify.notify_waiters();
        self.rcv_notify.notify_waiters();
        self.connect_notify.notify_waiters();
    }

    /// Waits for a receive notification, unless the socket is not alive or the
    /// receive buffer already has data to read, in which case it returns at once.
    ///
    /// The notification future is created while the status lock is held and is
    /// awaited only after the locks have been released.
    pub(crate) async fn wait_for_data_to_read(&self) {
        if let Some(notified) = {
            let status = self.status.lock().unwrap();
            if !status.is_alive() {
                None
            } else {
                let rcv_buffer = self.rcv_buffer();
                if rcv_buffer.has_data_to_read() {
                    None
                } else {
                    Some(self.rcv_notify.notified())
                }
            }
        } {
            notified.await
        }
    }

    /// Waits for a connect notification if the socket is currently `Connecting`,
    /// then returns the socket status.
    pub(crate) async fn wait_for_connection(&self) -> UdtStatus {
        if let Some(notified) = {
            let status = self.status.lock().unwrap();
            if *status != UdtStatus::Connecting {
                None
            } else {
                Some(self.connect_notify.notified())
            }
        } {
            notified.await
        }
        self.status()
    }

    /// Returns immediately if the send buffer is empty; otherwise waits for the
    /// next ACK notification.
    pub(crate) async fn wait_for_next_ack_or_empty_snd_buffer(&self) {
        if let Some(notified) = {
            let snd_buffer = self.snd_buffer.lock().unwrap();
            if snd_buffer.is_empty() {
                None
            } else {
                Some(self.ack_notify.notified())
            }
        } {
            notified.await
        }
    }
}

/// Sockets are ordered by their socket id.
impl Ord for UdtSocket {
    fn cmp(&self, other: &Self) -> Ordering {
        self.socket_id.cmp(&other.socket_id)
    }
}

impl PartialOrd for UdtSocket {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

/// Two sockets are equal when they have the same socket id.
impl PartialEq for UdtSocket {
    fn eq(&self, other: &Self) -> bool {
        self.socket_id == other.socket_id
    }
}

impl Eq for UdtSocket {}

/// Lifecycle status of a UDT socket.
#[derive(Debug, PartialEq, Clone, Copy, Eq)]
pub enum UdtStatus {
    Init,
    Opened,
    Listening,
    Connecting,
    Connected,
    Broken,
    Closing,
    Closed,
}

impl UdtStatus {
    /// Returns `true` unless the status is `Broken`, `Closing` or `Closed`.
    pub(crate) fn is_alive(&self) -> bool {
        !matches!(
            self,
            UdtStatus::Broken | UdtStatus::Closing | UdtStatus::Closed
        )
    }
}
--------------------------------------------------------------------------------