├── .github └── workflows │ └── rust.yml ├── .gitignore ├── Cargo.toml ├── LICENSE ├── README.md ├── examples ├── icmp.rs ├── ipv4_ipv6.rs ├── tcp.rs ├── tcp_connect.rs ├── tcp_proxy.rs └── udp.rs ├── rustfmt.toml └── src ├── address └── mod.rs ├── buffer.rs ├── icmp └── mod.rs ├── ip └── mod.rs ├── ip_stack.rs ├── lib.rs ├── tcp ├── mod.rs ├── sys.rs ├── tcb.rs └── tcp_queue.rs └── udp └── mod.rs /.github/workflows/rust.yml: -------------------------------------------------------------------------------- 1 | name: Push or PR 2 | 3 | on: 4 | [push, pull_request] 5 | 6 | env: 7 | CARGO_TERM_COLOR: always 8 | 9 | jobs: 10 | build_n_test: 11 | strategy: 12 | fail-fast: false 13 | matrix: 14 | os: [ubuntu-latest] 15 | 16 | runs-on: ${{ matrix.os }} 17 | 18 | steps: 19 | - uses: actions/checkout@v3 20 | - name: rustfmt 21 | run: cargo fmt --all -- --check 22 | - name: check 23 | run: cargo check --verbose 24 | - name: clippy 25 | run: cargo clippy --all-targets --all-features -- -D warnings 26 | - name: build 27 | run: cargo build --verbose --examples --tests --all-features 28 | - name: test 29 | run: cargo test --all-features --examples -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .vscode/ 2 | .VSCodeCounter/ 3 | .idea/ 4 | /target 5 | Cargo.lock 6 | -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "tcp_ip" 3 | version = "0.1.10" 4 | edition = "2021" 5 | license = "Apache-2.0" 6 | readme = "README.md" 7 | description = "User-space TCP/IP stack" 8 | repository = "https://github.com/rustp2p/tcp_ip" 9 | keywords = ["ip", "tcp", "udp", "icmp", "network"] 10 | 11 | [dependencies] 12 | bytes = "1.9" 13 | dashmap = "6.1" 14 | flume = { version = "0.11", features = ["async"] } 15 | log = "0.4" 16 
| parking_lot = "0.12" 17 | pnet_packet = "0.35" 18 | rand = "0.9" 19 | tokio = { version = "1.42", features = ["macros", "rt", "time"] } 20 | tokio-util = "0.7" 21 | num_enum = "0.7" 22 | lazy_static = { version = "1.5.0", optional = true } 23 | 24 | [features] 25 | default = [] 26 | global-ip-stack = ["lazy_static"] 27 | 28 | [dev-dependencies] 29 | tokio = { version = "1.42", features = ["full"] } 30 | anyhow = "1" 31 | env_logger = "0.11" 32 | tun-rs = { version = "1.5.0", features = ["async"] } 33 | clap = { version = "4", features = ["derive"] } 34 | 35 | [[example]] 36 | name = "tcp" 37 | required-features = ["global-ip-stack"] 38 | [[example]] 39 | name = "udp" 40 | required-features = ["global-ip-stack"] 41 | [[example]] 42 | name = "icmp" 43 | required-features = ["global-ip-stack"] 44 | [[example]] 45 | name = "ipv4_ipv6" 46 | required-features = ["global-ip-stack"] 47 | [[example]] 48 | name = "tcp_proxy" 49 | required-features = ["global-ip-stack"] 50 | [[example]] 51 | name = "tcp_connect" 52 | required-features = ["global-ip-stack"] -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. 
For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![Crates.io](https://img.shields.io/crates/v/tcp_ip.svg)](https://crates.io/crates/tcp_ip) 2 | [![tcp_ip](https://docs.rs/tcp_ip/badge.svg)](https://docs.rs/tcp_ip/latest/tcp_ip/) 3 | 4 | # tcp_ip 5 | 6 | User-space TCP/IP stack 7 | 8 | ## Features 9 | 10 | #### IPv4 11 | 12 | - IPv4 fragmentation and reassembly is supported. 13 | - IPv4 options are not supported and are silently ignored. 14 | 15 | #### IPv6 16 | 17 | - In development, currently does not support any extended protocols. 18 | 19 | #### UDP 20 | 21 | Use UdpSocket. Supported over IPv4 and IPv6. 22 | 23 | #### ICMPv4 & ICMPv6 24 | 25 | Use IcmpSocket or IcmpV6Socket. The user needs to handle the ICMP header themselves and calculate the checksum. 26 | 27 | #### TCP 28 | 29 | Use TcpListener and TcpStream. Supported over IPv4 and IPv6. 30 | 31 | - MSS is negotiated 32 | - Window scaling is negotiated. 
33 | - Reassembly of out-of-order segments is supported 34 | - The timeout waiting time is fixed and can be configured 35 | - Selective acknowledgements permitted. (Proactively ACK the need for improvement) 36 | 37 | #### Other 38 | 39 | Using IpSocket to send and receive packets of other protocols.(Handles all IP upper-layer protocols without requiring 40 | the user to consider IP fragmentation.) 41 | 42 | ## example 43 | 44 | - [tcp](https://github.com/rustp2p/tcp_ip/blob/main/examples/tcp.rs) 45 | - [udp](https://github.com/rustp2p/tcp_ip/blob/main/examples/udp.rs) 46 | - [icmp](https://github.com/rustp2p/tcp_ip/blob/main/examples/icmp.rs) 47 | - [ipv4_ipv6](https://github.com/rustp2p/tcp_ip/blob/main/examples/ipv4_ipv6.rs) 48 | - [proxy](https://github.com/rustp2p/tcp_ip/blob/main/examples/tcp_proxy.rs) 49 | - [tcp_connect](https://github.com/rustp2p/tcp_ip/blob/main/examples/tcp_connect.rs) 50 | 51 | ## iperf test 52 | 53 | ### LAN Speed Test 54 | 55 | ![image](https://github.com/user-attachments/assets/135c2ff9-9515-46c2-9439-e035f3422d54) 56 | 57 | ### Example:[Proxy](https://github.com/rustp2p/tcp_ip/blob/main/examples/tcp_proxy.rs)-Windows 58 | 59 | ![image](https://github.com/user-attachments/assets/9a56de87-2e89-4a42-9587-8f1923935739) 60 | 61 | ### Example: [Proxy](https://github.com/rustp2p/tcp_ip/blob/main/examples/tcp_proxy.rs)-Linux 62 | 63 | ![image](https://github.com/user-attachments/assets/23d7863a-475a-4602-b56a-a1444cfa155d) 64 | 65 | -------------------------------------------------------------------------------- /examples/icmp.rs: -------------------------------------------------------------------------------- 1 | #![allow(unused, unused_variables)] 2 | use pnet_packet::icmp::IcmpTypes; 3 | use pnet_packet::icmpv6::Icmpv6Types; 4 | use pnet_packet::Packet; 5 | use std::net::IpAddr; 6 | use std::sync::Arc; 7 | use tcp_ip::icmp::{IcmpSocket, IcmpV6Socket}; 8 | use tcp_ip::ip::IpSocket; 9 | use tcp_ip::{ip_stack, IpStackConfig, IpStackRecv, 
IpStackSend}; 10 | use tun_rs::{AsyncDevice, Configuration}; 11 | 12 | const MTU: u16 = 1420; 13 | 14 | /// After starting the program,ping 10.0.0.0/24 (e.g. ping 10.0.0.2), 15 | /// and you can receive a response from IcmpSocket 16 | #[tokio::main] 17 | pub async fn main() -> anyhow::Result<()> { 18 | env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("trace")).init(); 19 | let mut config = Configuration::default(); 20 | config 21 | .mtu(MTU) 22 | .address_with_prefix_multi(&[("CDCD:910A:2222:5498:8475:1111:3900:2025", 64), ("10.0.0.29", 24)]) 23 | .up(); 24 | let dev = tun_rs::create_as_async(&config)?; 25 | let dev = Arc::new(dev); 26 | let ip_stack_config = IpStackConfig { 27 | mtu: MTU, 28 | ..Default::default() 29 | }; 30 | let (ip_stack_send, ip_stack_recv) = ip_stack(ip_stack_config)?; 31 | let icmp_socket = IcmpSocket::bind_all().await?; 32 | let icmp_v6_socket = IcmpV6Socket::bind_all().await?; 33 | 34 | let h1 = tokio::spawn(async { 35 | if let Err(e) = icmp_v4_recv(icmp_socket).await { 36 | log::error!("icmp {e:?}"); 37 | } 38 | }); 39 | let h1_1 = tokio::spawn(async { 40 | if let Err(e) = icmp_v6_recv(icmp_v6_socket).await { 41 | log::error!("icmpv6 {e:?}"); 42 | } 43 | }); 44 | let dev1 = dev.clone(); 45 | let h2 = tokio::spawn(async { 46 | if let Err(e) = tun_to_ip_stack(dev1, ip_stack_send).await { 47 | log::error!("tun_to_ip_stack {e:?}"); 48 | } 49 | }); 50 | let h3 = tokio::spawn(async { 51 | if let Err(e) = ip_stack_to_tun(ip_stack_recv, dev).await { 52 | log::error!("ip_stack_to_tun {e:?}"); 53 | } 54 | }); 55 | let _ = tokio::try_join!(h1, h1_1, h2, h3); 56 | Ok(()) 57 | } 58 | 59 | async fn icmp_v4_recv(icmp_socket: IcmpSocket) -> anyhow::Result<()> { 60 | let mut buf = [0; 65536]; 61 | loop { 62 | let (len, src, dst) = icmp_socket.recv_from_to(&mut buf).await?; 63 | log::info!("src={src},dst={dst},len={len},buf={:?}", &buf[..len]); 64 | if let Some(mut packet) = pnet_packet::icmp::MutableIcmpPacket::new(&mut 
buf[..len]) { 65 | if packet.get_icmp_type() == IcmpTypes::EchoRequest { 66 | log::info!("icmpv4 {packet:?}"); 67 | packet.set_icmp_type(IcmpTypes::EchoReply); 68 | let checksum = pnet_packet::icmp::checksum(&packet.to_immutable()); 69 | packet.set_checksum(checksum); 70 | 71 | icmp_socket.send_from_to(packet.packet(), dst, src).await?; 72 | } 73 | } 74 | } 75 | } 76 | async fn icmp_v6_recv(icmp_socket: IcmpV6Socket) -> anyhow::Result<()> { 77 | let mut buf = [0; 65536]; 78 | loop { 79 | let (len, src, dst) = icmp_socket.recv_from_to(&mut buf).await?; 80 | let src_ip = match src { 81 | IpAddr::V6(ip) => ip, 82 | IpAddr::V4(_) => unimplemented!(), 83 | }; 84 | let dst_ip = match dst { 85 | IpAddr::V6(ip) => ip, 86 | IpAddr::V4(_) => unimplemented!(), 87 | }; 88 | log::info!("src={src},dst={dst},len={len},buf={:?}", &buf[..len]); 89 | if let Some(mut packet) = pnet_packet::icmpv6::MutableIcmpv6Packet::new(&mut buf[..len]) { 90 | if packet.get_icmpv6_type() == Icmpv6Types::EchoRequest { 91 | log::info!("icmpv6 {packet:?}"); 92 | packet.set_icmpv6_type(Icmpv6Types::EchoReply); 93 | let checksum = pnet_packet::icmpv6::checksum(&packet.to_immutable(), &dst_ip, &src_ip); 94 | packet.set_checksum(checksum); 95 | 96 | icmp_socket.send_from_to(packet.packet(), dst, src).await?; 97 | } 98 | } 99 | } 100 | } 101 | 102 | async fn tun_to_ip_stack(dev: Arc, mut ip_stack_send: IpStackSend) -> anyhow::Result<()> { 103 | let mut buf = [0; MTU as usize]; 104 | loop { 105 | let len = dev.recv(&mut buf).await?; 106 | if let Err(e) = ip_stack_send.send_ip_packet(&buf[..len]).await { 107 | log::error!("ip_stack_send.send_ip_packet e={e:?}") 108 | } 109 | } 110 | } 111 | 112 | async fn ip_stack_to_tun(mut ip_stack_recv: IpStackRecv, dev: Arc) -> anyhow::Result<()> { 113 | let mut buf = [0; MTU as usize]; 114 | loop { 115 | let len = ip_stack_recv.recv(&mut buf).await?; 116 | dev.send(&buf[..len]).await?; 117 | } 118 | } 119 | 
-------------------------------------------------------------------------------- /examples/ipv4_ipv6.rs: -------------------------------------------------------------------------------- 1 | #![allow(unused, unused_variables)] 2 | use pnet_packet::icmp::IcmpTypes; 3 | use pnet_packet::ip::{IpNextHeaderProtocol, IpNextHeaderProtocols}; 4 | use pnet_packet::Packet; 5 | use std::sync::Arc; 6 | use tun_rs::{AsyncDevice, Configuration}; 7 | 8 | use tcp_ip::icmp::IcmpSocket; 9 | use tcp_ip::ip::IpSocket; 10 | use tcp_ip::{ip_stack, IpStackConfig, IpStackRecv, IpStackSend}; 11 | 12 | const MTU: u16 = 1420; 13 | 14 | /// Handles all IPv4 upper-layer protocols without requiring the user to consider IP fragmentation. 15 | #[tokio::main] 16 | pub async fn main() -> anyhow::Result<()> { 17 | env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("trace")).init(); 18 | let mut config = Configuration::default(); 19 | 20 | config 21 | .mtu(MTU) 22 | .address_with_prefix_multi(&[("CDCD:910A:2222:5498:8475:1111:3900:2025", 64), ("10.0.0.29", 24)]) 23 | .up(); 24 | let dev = tun_rs::create_as_async(&config)?; 25 | let dev = Arc::new(dev); 26 | let ip_stack_config = IpStackConfig { 27 | mtu: MTU, 28 | ..Default::default() 29 | }; 30 | let (ip_stack_send, ip_stack_recv) = ip_stack(ip_stack_config)?; 31 | // None means receiving all protocols. 32 | let ip_socket = IpSocket::bind_all(None).await?; 33 | 34 | let h1 = tokio::spawn(async { 35 | if let Err(e) = ip_recv(ip_socket).await { 36 | log::error!("ip packet {e:?}"); 37 | } 38 | }); 39 | let dev1 = dev.clone(); 40 | let h2 = tokio::spawn(async { 41 | // Reads packet from TUN and sends to stack. 42 | if let Err(e) = tun_to_ip_stack(dev1, ip_stack_send).await { 43 | log::error!("tun_to_ip_stack {e:?}"); 44 | } 45 | }); 46 | let h3 = tokio::spawn(async { 47 | // Reads packet from stack and sends to TUN. 
48 | if let Err(e) = ip_stack_to_tun(ip_stack_recv, dev).await { 49 | log::error!("ip_stack_to_tun {e:?}"); 50 | } 51 | }); 52 | let _ = tokio::try_join!(h1, h2, h3); 53 | Ok(()) 54 | } 55 | 56 | async fn ip_recv(ip_socket: IpSocket) -> anyhow::Result<()> { 57 | let mut buf = [0; 65536]; 58 | loop { 59 | let (len, p, src, dst) = ip_socket.recv_protocol_from_to(&mut buf).await?; 60 | // The read and write operations of Ipv4Socket do not include the IP header. 61 | log::info!("protocol={p},src={src},dst={dst},len={len},buf={:?}", &buf[..len]); 62 | match p { 63 | IpNextHeaderProtocols::Icmp => { 64 | if let Some(mut packet) = pnet_packet::icmp::MutableIcmpPacket::new(&mut buf[..len]) { 65 | if packet.get_icmp_type() == IcmpTypes::EchoRequest { 66 | log::info!("icmp {packet:?}"); 67 | packet.set_icmp_type(IcmpTypes::EchoReply); 68 | let checksum = pnet_packet::icmp::checksum(&packet.to_immutable()); 69 | packet.set_checksum(checksum); 70 | 71 | ip_socket.send_from_to(packet.packet(), dst, src).await?; 72 | } 73 | } 74 | } 75 | IpNextHeaderProtocols::Udp => { 76 | let udp_packet = pnet_packet::udp::UdpPacket::new(&buf[..len]).unwrap(); 77 | // When using this socket to send UDP packets, you need to calculate the UDP checksum yourself. 
78 | log::info!("recv udp {:?}", std::str::from_utf8(udp_packet.payload())); 79 | } 80 | _ => {} 81 | } 82 | } 83 | } 84 | 85 | async fn tun_to_ip_stack(dev: Arc, mut ip_stack_send: IpStackSend) -> anyhow::Result<()> { 86 | let mut buf = [0; MTU as usize]; 87 | loop { 88 | let len = dev.recv(&mut buf).await?; 89 | if let Err(e) = ip_stack_send.send_ip_packet(&buf[..len]).await { 90 | log::error!("ip_stack_send.send_ip_packet e={e:?}") 91 | } 92 | } 93 | } 94 | 95 | async fn ip_stack_to_tun(mut ip_stack_recv: IpStackRecv, dev: Arc) -> anyhow::Result<()> { 96 | let mut buf = [0; MTU as usize]; 97 | loop { 98 | let len = ip_stack_recv.recv(&mut buf).await?; 99 | log::debug!("ip_stack_to_tun num={len}"); 100 | dev.send(&buf[..len]).await?; 101 | } 102 | } 103 | -------------------------------------------------------------------------------- /examples/tcp.rs: -------------------------------------------------------------------------------- 1 | #![allow(unused, unused_variables)] 2 | use std::sync::Arc; 3 | 4 | use bytes::BytesMut; 5 | use pnet_packet::Packet; 6 | use tokio::io::{AsyncReadExt, AsyncWriteExt}; 7 | use tun_rs::{AsyncDevice, Configuration}; 8 | 9 | use tcp_ip::tcp::TcpListener; 10 | use tcp_ip::{ip_stack, IpStackConfig, IpStackRecv, IpStackSend}; 11 | 12 | const MTU: u16 = 1420; 13 | 14 | /// After starting, use a TCP connection to any port in the 10.0.0.0/24 subnet (e.g., telnet 10.0.0.2 8080). 15 | /// Sending data will receive a response. 
16 | #[tokio::main] 17 | pub async fn main() -> anyhow::Result<()> { 18 | env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("info")).init(); 19 | let mut config = Configuration::default(); 20 | config 21 | .mtu(MTU) 22 | .address_with_prefix_multi(&[("CDCD:910A:2222:5498:8475:1111:3900:2025", 64), ("10.0.0.29", 24)]) 23 | .up(); 24 | let dev = tun_rs::create_as_async(&config)?; 25 | let dev = Arc::new(dev); 26 | let ip_stack_config = IpStackConfig { 27 | mtu: MTU, 28 | ..Default::default() 29 | }; 30 | let (ip_stack_send, ip_stack_recv) = ip_stack(ip_stack_config)?; 31 | let mut tcp_listener = TcpListener::bind_all().await?; 32 | 33 | let h1 = tokio::spawn(async move { 34 | loop { 35 | let (mut tcp_stream, addr) = match tcp_listener.accept().await { 36 | Ok(rs) => rs, 37 | Err(e) => { 38 | log::error!("tcp_listener accept {e:?}"); 39 | break; 40 | } 41 | }; 42 | log::info!("tcp_stream addr:{addr}"); 43 | tokio::spawn(async move { 44 | let mut buf = [0; 1024]; 45 | loop { 46 | match tcp_stream.read(&mut buf).await { 47 | Ok(len) => { 48 | log::info!("tcp_stream read len={len},buf={:?}", &buf[..len]); 49 | if let Err(e) = tcp_stream.write(&buf[..len]).await { 50 | log::error!("tcp_stream write {e:?}"); 51 | break; 52 | } 53 | } 54 | Err(e) => { 55 | log::error!("tcp_stream read {e:?}"); 56 | break; 57 | } 58 | } 59 | } 60 | }); 61 | } 62 | }); 63 | 64 | let dev1 = dev.clone(); 65 | let h2 = tokio::spawn(async { 66 | if let Err(e) = tun_to_ip_stack(dev1, ip_stack_send).await { 67 | log::error!("tun_to_ip_stack {e:?}"); 68 | } 69 | }); 70 | let h3 = tokio::spawn(async { 71 | if let Err(e) = ip_stack_to_tun(ip_stack_recv, dev).await { 72 | log::error!("ip_stack_to_tun {e:?}"); 73 | } 74 | }); 75 | let _ = tokio::try_join!(h1, h2, h3,); 76 | Ok(()) 77 | } 78 | 79 | async fn tun_to_ip_stack(dev: Arc, mut ip_stack_send: IpStackSend) -> anyhow::Result<()> { 80 | let mut buf = [0; MTU as usize]; 81 | loop { 82 | let len = dev.recv(&mut buf).await?; 
83 | let packet = pnet_packet::ipv4::Ipv4Packet::new(&buf[..len]).unwrap(); 84 | if packet.get_next_level_protocol() == pnet_packet::ip::IpNextHeaderProtocols::Tcp { 85 | // log::debug!("tun_to_ip_stack {packet:?}"); 86 | let tcp_packet = pnet_packet::tcp::TcpPacket::new(packet.payload()).unwrap(); 87 | log::debug!("tun_to_ip_stack tcp_packet={tcp_packet:?} payload={:?}", tcp_packet.payload()); 88 | } 89 | 90 | if let Err(e) = ip_stack_send.send_ip_packet(&buf[..len]).await { 91 | log::error!("ip_stack_send.send_ip_packet e={e:?}") 92 | } 93 | } 94 | } 95 | 96 | async fn ip_stack_to_tun(mut ip_stack_recv: IpStackRecv, dev: Arc) -> anyhow::Result<()> { 97 | let mut bufs = Vec::with_capacity(128); 98 | let mut sizes = vec![0; 128]; 99 | for _ in 0..128 { 100 | bufs.push(BytesMut::zeroed(MTU as usize)) 101 | } 102 | loop { 103 | let num = ip_stack_recv.recv_ip_packet(&mut bufs, &mut sizes).await?; 104 | // log::debug!("ip_stack_to_tun num={num}"); 105 | for index in 0..num { 106 | let buf = &bufs[index]; 107 | let len = sizes[index]; 108 | let packet = pnet_packet::ipv4::Ipv4Packet::new(&buf[..len]).unwrap(); 109 | // log::debug!("ip_stack_to_tun {packet:?}"); 110 | if packet.get_next_level_protocol() == pnet_packet::ip::IpNextHeaderProtocols::Tcp { 111 | let tcp_packet = pnet_packet::tcp::TcpPacket::new(packet.payload()).unwrap(); 112 | log::debug!("ip_stack_to_tun tcp_packet={tcp_packet:?} payload={:?}", tcp_packet.payload()); 113 | } 114 | 115 | dev.send(&buf[..len]).await?; 116 | } 117 | } 118 | } 119 | -------------------------------------------------------------------------------- /examples/tcp_connect.rs: -------------------------------------------------------------------------------- 1 | #![allow(unused, unused_variables)] 2 | use std::net::{Ipv4Addr, SocketAddrV4}; 3 | use std::sync::Arc; 4 | use std::time::Duration; 5 | 6 | use bytes::BytesMut; 7 | use pnet_packet::Packet; 8 | use tokio::io::{AsyncReadExt, AsyncWriteExt}; 9 | use tun_rs::{AsyncDevice, 
Configuration};

use tcp_ip::{ip_stack, IpStackConfig, IpStackRecv, IpStackSend};

const MTU: u16 = 1420;

/// This example demonstrates how to use a TCP active connection to a userspace TCP/IP protocol stack,
/// which can convert TCP data into IP packets.
#[tokio::main]
pub async fn main() -> anyhow::Result<()> {
    env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("debug")).init();
    let mut config = Configuration::default();
    let local_ip = Ipv4Addr::new(10, 0, 0, 29);
    config.mtu(MTU).address_with_prefix(local_ip, 24).up();
    let dev = tun_rs::create_as_async(&config)?;
    let dev = Arc::new(dev);
    let ip_stack_config = IpStackConfig {
        mtu: MTU,
        ..Default::default()
    };
    let (ip_stack_send, ip_stack_recv) = ip_stack(ip_stack_config)?;
    // Pump packets in both directions between the TUN device and the user-space stack.
    let dev1 = dev.clone();
    tokio::spawn(async {
        if let Err(e) = tun_to_ip_stack(dev1, ip_stack_send).await {
            log::error!("tun_to_ip_stack {e:?}");
        }
    });
    tokio::spawn(async {
        if let Err(e) = ip_stack_to_tun(ip_stack_recv, dev).await {
            log::error!("ip_stack_to_tun {e:?}");
        }
    });
    let listen_addr = SocketAddrV4::new(local_ip, 18888);
    // Waiting for the Tun network card to take effect.
    // Otherwise, it cannot be bound to the IP address
    tokio::time::sleep(Duration::from_secs(10)).await;
    let tokio_tcp_listener = tokio::net::TcpListener::bind(listen_addr).await?;
    tokio::spawn(async move {
        log::info!("tokio_tcp_listener accept {:?}", tokio_tcp_listener.local_addr());
        loop {
            let (mut tokio_tcp_stream, addr) = match tokio_tcp_listener.accept().await {
                Ok(rs) => rs,
                Err(e) => {
                    log::error!("tokio_tcp_listener accept {e:?}");
                    break;
                }
            };
            log::info!("tokio_tcp_stream addr:{addr}");
            tokio::spawn(async move {
                let mut buf = [0; 1024];
                match tokio_tcp_stream.read(&mut buf).await {
                    Ok(len) => {
                        log::info!("tokio_tcp_stream read len={len},buf={:?}", &buf[..len]);
                        // BUGFIX: `write` may perform a partial write and silently drop
                        // the rest of the reply; `write_all` guarantees the whole
                        // message is sent (matches the `write_all` used below).
                        if let Err(e) = tokio_tcp_stream.write_all(b"hello").await {
                            log::error!("tokio_tcp_stream write {e:?}");
                        }
                    }
                    Err(e) => {
                        log::error!("tokio_tcp_stream read {e:?}");
                    }
                }
            });
        }
    });
    let peer_addr = SocketAddrV4::new(local_ip, 18888);
    log::info!("tcp_ip_stream connecting. addr:{peer_addr}");
    let mut tcp_ip_stream = tcp_ip::tcp::TcpStream::bind("10.0.0.2:18889")?.connect_to(peer_addr).await?;
    log::info!("tcp_ip_stream connection successful. addr:{peer_addr}");
    tcp_ip_stream.write_all(b"hi").await?;
    let mut buf = [0; 1024];
    let len = tcp_ip_stream.read(&mut buf).await?;
    log::info!("tcp_ip_stream read len={len},buf={:?}", &buf[..len]);
    Ok(())
}

/// Reads raw IP packets from the TUN device and feeds them into the user-space stack.
async fn tun_to_ip_stack(dev: Arc<AsyncDevice>, mut ip_stack_send: IpStackSend) -> anyhow::Result<()> {
    let mut buf = [0; MTU as usize];
    loop {
        let len = dev.recv(&mut buf).await?;
        let packet = pnet_packet::ipv4::Ipv4Packet::new(&buf[..len]).unwrap();
        if packet.get_next_level_protocol() == pnet_packet::ip::IpNextHeaderProtocols::Tcp {
            // log::debug!("tun_to_ip_stack {packet:?}");
            let tcp_packet = pnet_packet::tcp::TcpPacket::new(packet.payload()).unwrap();
            log::debug!("tun_to_ip_stack tcp_packet={tcp_packet:?} payload={:?}", tcp_packet.payload());
        }

        if let Err(e) = ip_stack_send.send_ip_packet(&buf[..len]).await {
            log::error!("ip_stack_send.send_ip_packet e={e:?}")
        }
    }
}

/// Drains IP packets produced by the user-space stack and writes them to the TUN device.
async fn ip_stack_to_tun(mut ip_stack_recv: IpStackRecv, dev: Arc<AsyncDevice>) -> anyhow::Result<()> {
    let mut bufs = Vec::with_capacity(128);
    let mut sizes = vec![0; 128];
    for _ in 0..128 {
        bufs.push(BytesMut::zeroed(MTU as usize))
    }
    loop {
        let num = ip_stack_recv.recv_ip_packet(&mut bufs, &mut sizes).await?;
        // log::debug!("ip_stack_to_tun num={num}");
        for index in 0..num {
            let buf = &bufs[index];
            let len = sizes[index];
            let packet = pnet_packet::ipv4::Ipv4Packet::new(&buf[..len]).unwrap();
            // log::debug!("ip_stack_to_tun {packet:?}");
            if packet.get_next_level_protocol() == pnet_packet::ip::IpNextHeaderProtocols::Tcp {
                let tcp_packet = pnet_packet::tcp::TcpPacket::new(packet.payload()).unwrap();
                log::debug!("ip_stack_to_tun tcp_packet={tcp_packet:?} payload={:?}", tcp_packet.payload());
            }

            dev.send(&buf[..len]).await?;
        }
    }
}
-------------------------------------------------------------------------------- /examples/tcp_proxy.rs: -------------------------------------------------------------------------------- 1 | #![allow(unused, unused_variables)] 2 | use std::net::SocketAddr; 3 | use std::sync::Arc; 4 | 5 | use bytes::BytesMut; 6 | use clap::Parser; 7 | use pnet_packet::Packet; 8 | use tun_rs::{AsyncDevice, Configuration}; 9 | 10 | use tcp_ip::tcp::TcpListener; 11 | use tcp_ip::{ip_stack, IpStackConfig, IpStackRecv, IpStackSend}; 12 | 13 | const MTU: u16 = 1420; 14 | #[derive(Parser)] 15 | pub struct Args { 16 | #[arg(short, long)] 17 | server_addr: SocketAddr, 18 | } 19 | 20 | /// Convert the IP packet into transport layer data and forward it to the target address. 21 | /// This functionality is typically required by VPN clients. 22 | #[tokio::main] 23 | pub async fn main() -> anyhow::Result<()> { 24 | let args = Args::parse(); 25 | let server_addr = args.server_addr; 26 | env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("info")).init(); 27 | let mut config = Configuration::default(); 28 | config 29 | .mtu(MTU) 30 | .address_with_prefix_multi(&[("CDCD:910A:2222:5498:8475:1111:3900:2025", 64), ("10.0.0.29", 24)]) 31 | .up(); 32 | let dev = tun_rs::create_as_async(&config)?; 33 | let dev = Arc::new(dev); 34 | let ip_stack_config = IpStackConfig { 35 | mtu: MTU, 36 | ..Default::default() 37 | }; 38 | let (ip_stack_send, ip_stack_recv) = ip_stack(ip_stack_config)?; 39 | let mut tcp_listener = TcpListener::bind_all().await?; 40 | 41 | let h1 = tokio::spawn(async move { 42 | loop { 43 | let (tcp_stream, addr) = match tcp_listener.accept().await { 44 | Ok(rs) => rs, 45 | Err(e) => { 46 | log::error!("tcp_listener accept {e:?}"); 47 | break; 48 | } 49 | }; 50 | log::info!("tcp_stream addr:{addr}"); 51 | let server_stream = tokio::net::TcpStream::connect(server_addr).await.unwrap(); 52 | 53 | tokio::spawn(async move { 54 | let (mut client_write, mut client_read) = 
tcp_stream.split().unwrap(); 55 | let (mut server_read, mut server_write) = server_stream.into_split(); 56 | let h1 = tokio::io::copy(&mut client_read, &mut server_write); 57 | let h2 = tokio::io::copy(&mut server_read, &mut client_write); 58 | let rs = tokio::join!(h1, h2); 59 | log::info!("copy rs:{rs:?}"); 60 | }); 61 | } 62 | }); 63 | 64 | let dev1 = dev.clone(); 65 | let h2 = tokio::spawn(async { 66 | if let Err(e) = tun_to_ip_stack(dev1, ip_stack_send).await { 67 | log::error!("tun_to_ip_stack {e:?}"); 68 | } 69 | }); 70 | let h3 = tokio::spawn(async { 71 | if let Err(e) = ip_stack_to_tun(ip_stack_recv, dev).await { 72 | log::error!("ip_stack_to_tun {e:?}"); 73 | } 74 | }); 75 | let _ = tokio::try_join!(h1, h2, h3,); 76 | Ok(()) 77 | } 78 | 79 | async fn tun_to_ip_stack(dev: Arc, mut ip_stack_send: IpStackSend) -> anyhow::Result<()> { 80 | let mut buf = [0; MTU as usize]; 81 | loop { 82 | let len = dev.recv(&mut buf).await?; 83 | let packet = pnet_packet::ipv4::Ipv4Packet::new(&buf[..len]).unwrap(); 84 | if packet.get_next_level_protocol() == pnet_packet::ip::IpNextHeaderProtocols::Tcp { 85 | let tcp_packet = pnet_packet::tcp::TcpPacket::new(packet.payload()).unwrap(); 86 | log::debug!("tun_to_ip_stack tcp_packet={tcp_packet:?}"); 87 | } 88 | 89 | if let Err(e) = ip_stack_send.send_ip_packet(&buf[..len]).await { 90 | log::error!("ip_stack_send.send_ip_packet e={e:?}") 91 | } 92 | } 93 | } 94 | 95 | async fn ip_stack_to_tun(mut ip_stack_recv: IpStackRecv, dev: Arc) -> anyhow::Result<()> { 96 | let mut bufs = Vec::with_capacity(128); 97 | let mut sizes = vec![0; 128]; 98 | for _ in 0..128 { 99 | bufs.push(BytesMut::zeroed(MTU as usize)) 100 | } 101 | loop { 102 | let num = ip_stack_recv.recv_ip_packet(&mut bufs, &mut sizes).await?; 103 | // log::debug!("ip_stack_to_tun num={num}"); 104 | for index in 0..num { 105 | let buf = &bufs[index]; 106 | let len = sizes[index]; 107 | let packet = pnet_packet::ipv4::Ipv4Packet::new(&buf[..len]).unwrap(); 108 | // 
log::debug!("ip_stack_to_tun {packet:?}"); 109 | if packet.get_next_level_protocol() == pnet_packet::ip::IpNextHeaderProtocols::Tcp { 110 | let tcp_packet = pnet_packet::tcp::TcpPacket::new(packet.payload()).unwrap(); 111 | log::debug!("ip_stack_to_tun tcp_packet={tcp_packet:?}"); 112 | } 113 | 114 | match dev.send(&buf[..len]).await { 115 | Ok(_) => {} 116 | Err(e) => { 117 | log::error!("{e:?}, buf={:?}", &buf[..len]) 118 | } 119 | } 120 | } 121 | } 122 | } 123 | -------------------------------------------------------------------------------- /examples/udp.rs: -------------------------------------------------------------------------------- 1 | #![allow(unused, unused_variables)] 2 | use bytes::BytesMut; 3 | use pnet_packet::Packet; 4 | use std::sync::Arc; 5 | use tcp_ip::udp::UdpSocket; 6 | use tcp_ip::{ip_stack, IpStackConfig, IpStackRecv, IpStackSend}; 7 | use tun_rs::{AsyncDevice, Configuration}; 8 | 9 | const MTU: u16 = 1420; 10 | 11 | /// After starting, use a UDP send to any port in the 10.0.0.0/24 subnet 12 | /// Sending data will receive a response. 
13 | /// Specifically, if sent to port 8080, a response will be received from udp_socket_8080 14 | #[tokio::main] 15 | pub async fn main() -> anyhow::Result<()> { 16 | env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("trace")).init(); 17 | let mut config = Configuration::default(); 18 | config 19 | .mtu(MTU) 20 | .address_with_prefix_multi(&[("CDCD:910A:2222:5498:8475:1111:3900:2025", 64), ("10.0.0.29", 24)]) 21 | .up(); 22 | let dev = tun_rs::create_as_async(&config)?; 23 | let dev = Arc::new(dev); 24 | let ip_stack_config = IpStackConfig { 25 | mtu: MTU, 26 | ..Default::default() 27 | }; 28 | let (ip_stack_send, ip_stack_recv) = ip_stack(ip_stack_config)?; 29 | let udp_socket = UdpSocket::bind_all().await?; 30 | // Bind to a specific address 31 | let udp_socket_8080 = UdpSocket::bind("0.0.0.0:8080").await?; 32 | 33 | let h1 = tokio::spawn(async { 34 | if let Err(e) = udp_recv(udp_socket, "hello".into()).await { 35 | log::error!("udp {e:?}"); 36 | } 37 | }); 38 | let h2 = tokio::spawn(async { 39 | if let Err(e) = udp_recv(udp_socket_8080, "hello8080".into()).await { 40 | log::error!("udp {e:?}"); 41 | } 42 | }); 43 | let dev1 = dev.clone(); 44 | let h3 = tokio::spawn(async { 45 | if let Err(e) = tun_to_ip_stack(dev1, ip_stack_send).await { 46 | log::error!("tun_to_ip_stack {e:?}"); 47 | } 48 | }); 49 | let h4 = tokio::spawn(async { 50 | if let Err(e) = ip_stack_to_tun(ip_stack_recv, dev).await { 51 | log::error!("ip_stack_to_tun {e:?}"); 52 | } 53 | }); 54 | let _ = tokio::try_join!(h1, h2, h3, h4); 55 | Ok(()) 56 | } 57 | 58 | async fn udp_recv(udp_socket: UdpSocket, prefix: String) -> anyhow::Result<()> { 59 | let mut buf = [0; 65536]; 60 | loop { 61 | let (len, src, dst) = udp_socket.recv_from_to(&mut buf).await?; 62 | log::info!("src={src},dst={dst},len={len},buf={:?}", &buf[..len]); 63 | match String::from_utf8(buf[..len].to_vec()) { 64 | Ok(msg) => { 65 | let res = format!("{prefix}:{msg}"); 66 | 
udp_socket.send_from_to(res.as_bytes(), dst, src).await?; 67 | } 68 | Err(e) => { 69 | log::error!("from_utf8 {e}") 70 | } 71 | } 72 | } 73 | } 74 | 75 | async fn tun_to_ip_stack(dev: Arc, mut ip_stack_send: IpStackSend) -> anyhow::Result<()> { 76 | let mut buf = [0; MTU as usize]; 77 | loop { 78 | let len = dev.recv(&mut buf).await?; 79 | let packet = pnet_packet::ipv4::Ipv4Packet::new(&buf[..len]).unwrap(); 80 | if packet.get_next_level_protocol() == pnet_packet::ip::IpNextHeaderProtocols::Udp { 81 | log::debug!("tun_to_ip_stack {packet:?}"); 82 | let udp_packet = pnet_packet::udp::UdpPacket::new(packet.payload()).unwrap(); 83 | log::debug!("tun_to_ip_stack udp_packet={udp_packet:?} payload={:?}", udp_packet.payload()); 84 | } 85 | 86 | if let Err(e) = ip_stack_send.send_ip_packet(&buf[..len]).await { 87 | log::error!("ip_stack_send.send_ip_packet e={e:?}") 88 | } 89 | } 90 | } 91 | 92 | async fn ip_stack_to_tun(mut ip_stack_recv: IpStackRecv, dev: Arc) -> anyhow::Result<()> { 93 | let mut bufs = Vec::with_capacity(128); 94 | let mut sizes = vec![0; 128]; 95 | for _ in 0..128 { 96 | bufs.push(BytesMut::zeroed(MTU as usize)) 97 | } 98 | loop { 99 | let num = ip_stack_recv.recv_ip_packet(&mut bufs, &mut sizes).await?; 100 | log::debug!("ip_stack_to_tun num={num}"); 101 | for index in 0..num { 102 | let buf = &bufs[index]; 103 | let len = sizes[index]; 104 | let packet = pnet_packet::ipv4::Ipv4Packet::new(&buf[..len]).unwrap(); 105 | log::debug!("ip_stack_to_tun {packet:?}"); 106 | if packet.get_next_level_protocol() == pnet_packet::ip::IpNextHeaderProtocols::Udp { 107 | let udp_packet = pnet_packet::udp::UdpPacket::new(packet.payload()).unwrap(); 108 | log::debug!("ip_stack_to_tun udp_packet={udp_packet:?} payload={:?}", udp_packet.payload()); 109 | } 110 | 111 | dev.send(&buf[..len]).await?; 112 | } 113 | } 114 | } 115 | -------------------------------------------------------------------------------- /rustfmt.toml: 
--------------------------------------------------------------------------------
max_width = 140
-------------------------------------------------------------------------------- /src/address/mod.rs:
--------------------------------------------------------------------------------
use std::io;
use std::net::{SocketAddr, SocketAddrV4, SocketAddrV6};

/// Conversion of the various address forms accepted by this crate's bind/connect
/// APIs into a concrete [`SocketAddr`].
pub trait ToSocketAddr {
    /// Returns the resolved socket address, or `InvalidInput` if parsing fails.
    fn to_addr(&self) -> io::Result<SocketAddr>;
}
impl ToSocketAddr for &SocketAddr {
    fn to_addr(&self) -> io::Result<SocketAddr> {
        Ok(**self)
    }
}
impl ToSocketAddr for SocketAddr {
    fn to_addr(&self) -> io::Result<SocketAddr> {
        Ok(*self)
    }
}
impl ToSocketAddr for &SocketAddrV4 {
    fn to_addr(&self) -> io::Result<SocketAddr> {
        Ok((**self).into())
    }
}
impl ToSocketAddr for SocketAddrV4 {
    fn to_addr(&self) -> io::Result<SocketAddr> {
        Ok((*self).into())
    }
}
impl ToSocketAddr for &SocketAddrV6 {
    fn to_addr(&self) -> io::Result<SocketAddr> {
        Ok((**self).into())
    }
}
impl ToSocketAddr for SocketAddrV6 {
    fn to_addr(&self) -> io::Result<SocketAddr> {
        Ok((*self).into())
    }
}
impl ToSocketAddr for &str {
    fn to_addr(&self) -> io::Result<SocketAddr> {
        self.parse()
            .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, format!("{e}")))
    }
}
impl ToSocketAddr for &String {
    fn to_addr(&self) -> io::Result<SocketAddr> {
        self.as_str().to_addr()
    }
}
impl ToSocketAddr for String {
    fn to_addr(&self) -> io::Result<SocketAddr> {
        self.as_str().to_addr()
    }
}
-------------------------------------------------------------------------------- /src/buffer.rs:
--------------------------------------------------------------------------------
#![allow(dead_code)]

use bytes::BytesMut;

/// A bounded byte buffer with a read cursor (`offset`) over a `BytesMut`.
/// Writers append at the tail with `extend_from_slice`; readers consume from
/// `offset` via `bytes()`/`advance()` and may rewind with `back()`.
#[derive(Debug)]
pub struct FixedBuffer {
    // Read cursor: everything before `offset` has been consumed.
    offset: usize,
    buf: BytesMut,
}
impl FixedBuffer {
    pub fn with_capacity(capacity: usize) -> Self {
        Self {
            offset: 0,
            buf: BytesMut::with_capacity(capacity),
        }
    }
    /// Bytes of free space left for `extend_from_slice`.
    pub fn available(&self) -> usize {
        self.buf.capacity() - self.buf.len()
    }
    pub fn clear(&mut self) {
        self.offset = 0;
        self.buf.clear();
    }
    pub fn offset(&self) -> usize {
        self.offset
    }
    /// Moves the read cursor forward by `n`, zero-filling if it passes the
    /// current write position (allows reserving space before data arrives).
    pub fn advance(&mut self, n: usize) {
        self.offset += n;
        assert!(self.offset <= self.buf.capacity());
        if self.offset > self.buf.len() {
            self.buf.resize(self.offset, 0);
        }
    }
    /// Rewinds the read cursor by `n`; panics if that would go before the start.
    pub fn back(&mut self, n: usize) {
        assert!(self.offset >= n);
        self.offset -= n;
    }
    /// Appends as much of `buf` as fits; returns the number of bytes copied.
    pub fn extend_from_slice(&mut self, buf: &[u8]) -> usize {
        let n = buf.len().min(self.available());
        if n == 0 {
            return 0;
        }
        self.buf.extend_from_slice(&buf[..n]);
        n
    }
    /// Number of unread bytes (write position minus read cursor).
    pub fn len(&self) -> usize {
        self.buf.len() - self.offset
    }
    pub fn bytes(&self) -> &[u8] {
        &self.buf[self.offset..]
    }
    pub fn bytes_mut(&mut self) -> &mut [u8] {
        &mut self.buf[self.offset..]
    }
}

/// A fixed-size circular byte buffer. The size is required to be a power of two
/// so wrap-around can be computed with a bit mask instead of a modulo.
pub struct RingBuffer {
    buffer: Vec<u8>,
    head: usize,
    tail: usize,
    // Number of bytes currently stored (distinguishes full from empty).
    size: usize,
}
impl RingBuffer {
    pub fn new(capacity: usize) -> Self {
        assert!(capacity.is_power_of_two(), "Capacity must be a power of 2");
        Self {
            buffer: vec![0; capacity],
            head: 0,
            tail: 0,
            size: 0,
        }
    }
    pub fn is_empty(&self) -> bool {
        self.size == 0
    }
    pub fn len(&self) -> usize {
        self.size
    }

    pub fn is_full(&self) -> bool {
        self.size == self.buffer.len()
    }
    /// The usable ring size.
    ///
    /// BUGFIX: this previously returned `Vec::capacity()`, but `vec![0; n]` only
    /// guarantees `capacity() >= n` — an over-allocating allocator would yield a
    /// value larger than `len()` (and possibly not a power of two), corrupting
    /// the wrap mask below and panicking on out-of-`len` slice indexing.
    /// `buffer.len()` is exactly the power-of-two size asserted in `new`.
    pub fn capacity(&self) -> usize {
        self.buffer.len()
    }
    pub fn available(&self) -> usize {
        self.capacity() - self.size
    }
    /// Copies as much of `data` as fits after `tail`; returns bytes written.
    pub fn push(&mut self, data: &[u8]) -> usize {
        let push_len = self.available().min(data.len());
        if push_len == 0 {
            return 0;
        }
        // Contiguous run up to the physical end of the ring, then wrap.
        let first_part = self.capacity() - self.tail;
        if push_len <= first_part {
            self.buffer[self.tail..self.tail + push_len].copy_from_slice(&data[..push_len]);
        } else {
            self.buffer[self.tail..].copy_from_slice(&data[..first_part]);
            self.buffer[..push_len - first_part].copy_from_slice(&data[first_part..push_len]);
        }
        self.tail = (self.tail + push_len) & (self.capacity() - 1);
        self.size += push_len;
        push_len
    }
    /// Copies up to `buf.len()` stored bytes out from `head`; returns bytes read.
    pub fn pop(&mut self, buf: &mut [u8]) -> usize {
        let len = buf.len().min(self.len());
        if len == 0 {
            return 0;
        }
        let mask = self.capacity() - 1;
        let first_part = self.capacity() - (self.head & mask);
        if len <= first_part {
            buf[..len].copy_from_slice(&self.buffer[self.head & mask..(self.head & mask) + len]);
        } else {
            buf[..first_part].copy_from_slice(&self.buffer[self.head & mask..]);
            buf[first_part..len].copy_from_slice(&self.buffer[..len - first_part]);
        }

        self.head = (self.head + len) & mask;
        self.size -= len;
        len
    }
}

| #[cfg(test)] 126 | mod tests { 127 | use super::*; 128 | #[test] 129 | fn test_fixed_buffer() { 130 | let mut buffer = FixedBuffer::with_capacity(10); 131 | 132 | // Test extend_from_slice 133 | let buf = &[1, 2, 3, 4, 5]; 134 | let extended_len = buffer.extend_from_slice(buf); 135 | assert_eq!(extended_len, 5); // 5 bytes should be added 136 | assert_eq!(buffer.len(), 5); // Buffer length should be 5 137 | assert_eq!(buffer.bytes(), &[1, 2, 3, 4, 5]); // Buffer content should match 138 | 139 | // Test advance 140 | buffer.advance(3); 141 | assert_eq!(buffer.len(), 2); // After advancing 3, length should be 2 142 | assert_eq!(buffer.bytes(), &[4, 5]); // Remaining bytes should be [4, 5] 143 | 144 | // Test back 145 | buffer.back(1); 146 | assert_eq!(buffer.len(), 3); // After going back by 1, length should be 3 147 | assert_eq!(buffer.bytes(), &[3, 4, 5]); // Remaining bytes should be [3, 4, 5] 148 | 149 | // Test available space 150 | assert_eq!(buffer.available(), 5); 151 | buffer.extend_from_slice(&[6, 7]); 152 | assert_eq!(buffer.available(), 3); 153 | // Test clear 154 | buffer.clear(); 155 | assert_eq!(buffer.available(), 10); 156 | assert_eq!(buffer.len(), 0); // After clear, buffer should be empty 157 | assert_eq!(buffer.bytes(), &[]); // No bytes left after clearing 158 | } 159 | #[test] 160 | fn test_ring_buffer() { 161 | let mut ring_buffer = RingBuffer::new(8); 162 | 163 | // Test pushing data 164 | let pushed_len = ring_buffer.push(&[1, 2, 3, 4, 5]); 165 | assert_eq!(pushed_len, 5); // 5 bytes pushed 166 | assert_eq!(ring_buffer.len(), 5); // Buffer size should be 5 167 | 168 | // Test popping data 169 | let mut buf = vec![0; 4]; 170 | let popped_len = ring_buffer.pop(&mut buf); 171 | assert_eq!(popped_len, 4); // 4 bytes popped 172 | assert_eq!(buf, [1, 2, 3, 4]); // Popped data should match 173 | 174 | // Test remaining data in buffer 175 | assert_eq!(ring_buffer.len(), 1); // Only 1 byte should be left 176 | 177 | // Test pushing more data 178 | 
let pushed_len = ring_buffer.push(&[6, 7, 8]); 179 | assert_eq!(pushed_len, 3); // 3 bytes pushed 180 | assert_eq!(ring_buffer.len(), 4); // Buffer size should now be 4 181 | 182 | // Test popping remaining data 183 | let mut buf = vec![0; 4]; 184 | let popped_len = ring_buffer.pop(&mut buf); 185 | assert_eq!(popped_len, 4); // 4 bytes popped 186 | assert_eq!(buf, [5, 6, 7, 8]); // Popped data should match 187 | let pushed_len = ring_buffer.push(&[1, 2, 3, 4, 5, 6, 7, 8, 9]); 188 | assert_eq!(pushed_len, 8); 189 | let mut buf = vec![0; 10]; 190 | let popped_len = ring_buffer.pop(&mut buf); 191 | assert_eq!(popped_len, 8); 192 | assert_eq!(&buf[..popped_len], &[1, 2, 3, 4, 5, 6, 7, 8]); 193 | } 194 | } 195 | -------------------------------------------------------------------------------- /src/icmp/mod.rs: -------------------------------------------------------------------------------- 1 | use std::io; 2 | use std::net::IpAddr; 3 | use std::ops::Deref; 4 | 5 | use pnet_packet::ip::IpNextHeaderProtocols; 6 | 7 | use crate::ip::IpSocket; 8 | use crate::ip_stack::{IpStack, UNSPECIFIED_ADDR_V4, UNSPECIFIED_ADDR_V6}; 9 | 10 | pub struct IcmpSocket { 11 | raw_ip_socket: IpSocket, 12 | } 13 | #[cfg(feature = "global-ip-stack")] 14 | impl IcmpSocket { 15 | pub async fn bind_all() -> io::Result { 16 | Self::bind(UNSPECIFIED_ADDR_V4.ip()).await 17 | } 18 | pub async fn bind(local_ip: IpAddr) -> io::Result { 19 | if local_ip.is_ipv6() { 20 | return Err(io::Error::new(io::ErrorKind::Unsupported, "need to use IcmpV6Socket")); 21 | } 22 | let ip_stack = IpStack::get()?; 23 | let raw_ip_socket = IpSocket::bind0( 24 | ip_stack.config.icmp_channel_size, 25 | Some(IpNextHeaderProtocols::Icmp), 26 | ip_stack, 27 | Some(local_ip), 28 | ) 29 | .await?; 30 | Ok(Self { raw_ip_socket }) 31 | } 32 | } 33 | #[cfg(not(feature = "global-ip-stack"))] 34 | impl IcmpSocket { 35 | pub async fn bind_all(ip_stack: IpStack) -> io::Result { 36 | Self::bind(ip_stack, UNSPECIFIED_ADDR_V4.ip()).await 37 | 
} 38 | pub async fn bind(ip_stack: IpStack, local_ip: IpAddr) -> io::Result { 39 | if local_ip.is_ipv6() { 40 | return Err(io::Error::new(io::ErrorKind::Unsupported, "need to use IcmpV6Socket")); 41 | } 42 | let raw_ip_socket = IpSocket::bind0( 43 | ip_stack.config.icmp_channel_size, 44 | Some(IpNextHeaderProtocols::Icmp), 45 | ip_stack, 46 | Some(local_ip), 47 | ) 48 | .await?; 49 | Ok(Self { raw_ip_socket }) 50 | } 51 | } 52 | 53 | impl Deref for IcmpSocket { 54 | type Target = IpSocket; 55 | 56 | fn deref(&self) -> &Self::Target { 57 | &self.raw_ip_socket 58 | } 59 | } 60 | 61 | pub struct IcmpV6Socket { 62 | raw_ip_socket: IpSocket, 63 | } 64 | #[cfg(feature = "global-ip-stack")] 65 | impl IcmpV6Socket { 66 | pub async fn bind_all() -> io::Result { 67 | Self::bind(UNSPECIFIED_ADDR_V6.ip()).await 68 | } 69 | pub async fn bind(local_ip: IpAddr) -> io::Result { 70 | if local_ip.is_ipv4() { 71 | return Err(io::Error::new(io::ErrorKind::Unsupported, "need to use IcmpSocket")); 72 | } 73 | let ip_stack = IpStack::get()?; 74 | let raw_ip_socket = IpSocket::bind0( 75 | ip_stack.config.icmp_channel_size, 76 | Some(IpNextHeaderProtocols::Icmpv6), 77 | ip_stack, 78 | Some(local_ip), 79 | ) 80 | .await?; 81 | Ok(Self { raw_ip_socket }) 82 | } 83 | } 84 | #[cfg(not(feature = "global-ip-stack"))] 85 | impl IcmpV6Socket { 86 | pub async fn bind_all(ip_stack: IpStack) -> io::Result { 87 | Self::bind(ip_stack, UNSPECIFIED_ADDR_V6.ip()).await 88 | } 89 | pub async fn bind(ip_stack: IpStack, local_ip: IpAddr) -> io::Result { 90 | if local_ip.is_ipv4() { 91 | return Err(io::Error::new(io::ErrorKind::Unsupported, "need to use IcmpSocket")); 92 | } 93 | let raw_ip_socket = IpSocket::bind0( 94 | ip_stack.config.icmp_channel_size, 95 | Some(IpNextHeaderProtocols::Icmpv6), 96 | ip_stack, 97 | Some(local_ip), 98 | ) 99 | .await?; 100 | Ok(Self { raw_ip_socket }) 101 | } 102 | } 103 | 104 | impl Deref for IcmpV6Socket { 105 | type Target = IpSocket; 106 | 107 | fn deref(&self) -> 
&Self::Target { 108 | &self.raw_ip_socket 109 | } 110 | } 111 | -------------------------------------------------------------------------------- /src/ip/mod.rs: -------------------------------------------------------------------------------- 1 | use std::io; 2 | use std::net::{IpAddr, SocketAddr}; 3 | 4 | use crate::ip_stack::{check_ip, default_ip, BindAddr, IpStack, NetworkTuple, TransportPacket}; 5 | use bytes::BytesMut; 6 | pub use pnet_packet::ip::IpNextHeaderProtocol; 7 | pub use pnet_packet::ip::IpNextHeaderProtocols; 8 | 9 | /// Internally handles the splitting and reassembly of IP fragmentation. 10 | /// The read and write operations of Ipv4Socket do not include the IP header. 11 | /// For reading and writing, only the upper-layer protocol data of IP needs to be considered. 12 | pub struct IpSocket { 13 | _bind_addr: Option, 14 | protocol: Option, 15 | ip_stack: IpStack, 16 | packet_receiver: flume::Receiver, 17 | local_addr: Option, 18 | } 19 | #[cfg(feature = "global-ip-stack")] 20 | impl IpSocket { 21 | pub async fn bind_all(protocol: Option) -> io::Result { 22 | let ip_stack = IpStack::get()?; 23 | Self::bind0(ip_stack.config.ip_channel_size, protocol, ip_stack, None).await 24 | } 25 | pub async fn bind(protocol: Option, local_ip: IpAddr) -> io::Result { 26 | let ip_stack = IpStack::get()?; 27 | Self::bind0(ip_stack.config.ip_channel_size, protocol, ip_stack, Some(local_ip)).await 28 | } 29 | } 30 | #[cfg(not(feature = "global-ip-stack"))] 31 | impl IpSocket { 32 | pub async fn bind_all(protocol: Option, ip_stack: IpStack) -> io::Result { 33 | Self::bind0(ip_stack.config.ip_channel_size, protocol, ip_stack, None).await 34 | } 35 | pub async fn bind(protocol: Option, ip_stack: IpStack, local_ip: IpAddr) -> io::Result { 36 | Self::bind0(ip_stack.config.ip_channel_size, protocol, ip_stack, Some(local_ip)).await 37 | } 38 | } 39 | impl IpSocket { 40 | pub(crate) async fn bind0( 41 | channel_size: usize, 42 | protocol: Option, 43 | ip_stack: IpStack, 44 | 
local_ip: Option, 45 | ) -> io::Result { 46 | let local_addr = local_ip.map(|ip| SocketAddr::new(ip, 0)); 47 | let _bind_addr = if let (Some(protocol), Some(local_addr)) = (protocol, local_addr) { 48 | Some(ip_stack.bind_ip(protocol, local_addr)?) 49 | } else { 50 | None 51 | }; 52 | 53 | let (packet_sender, packet_receiver) = flume::bounded(channel_size); 54 | ip_stack.add_ip_socket(protocol, local_addr, packet_sender)?; 55 | Ok(Self { 56 | _bind_addr, 57 | protocol, 58 | ip_stack, 59 | packet_receiver, 60 | local_addr, 61 | }) 62 | } 63 | } 64 | 65 | impl IpSocket { 66 | pub fn local_ip(&self) -> io::Result { 67 | self.local_addr 68 | .map(|v| v.ip()) 69 | .ok_or_else(|| io::Error::from(io::ErrorKind::NotFound)) 70 | } 71 | pub async fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, IpAddr)> { 72 | let (len, src, _dst) = self.recv_from_to(buf).await?; 73 | Ok((len, src)) 74 | } 75 | pub async fn send_to(&self, buf: &[u8], addr: IpAddr) -> io::Result { 76 | let from = if let Some(from) = self.local_addr { 77 | from.ip() 78 | } else { 79 | default_ip(addr.is_ipv4()) 80 | }; 81 | self.send_from_to(buf, from, addr).await 82 | } 83 | pub async fn recv_from_to(&self, buf: &mut [u8]) -> io::Result<(usize, IpAddr, IpAddr)> { 84 | let (len, _p, src, dst) = self.recv_protocol_from_to(buf).await?; 85 | Ok((len, src, dst)) 86 | } 87 | pub async fn recv_protocol_from_to(&self, buf: &mut [u8]) -> io::Result<(usize, IpNextHeaderProtocol, IpAddr, IpAddr)> { 88 | let Ok(packet) = self.packet_receiver.recv_async().await else { 89 | return Err(io::Error::from(io::ErrorKind::UnexpectedEof)); 90 | }; 91 | 92 | let len = packet.buf.len(); 93 | if buf.len() < len { 94 | return Err(io::Error::new( 95 | io::ErrorKind::InvalidInput, 96 | format!("buf too short: {}<{len}", buf.len()), 97 | )); 98 | } 99 | buf[..len].copy_from_slice(&packet.buf); 100 | Ok(( 101 | len, 102 | packet.network_tuple.protocol, 103 | packet.network_tuple.src.ip(), 104 | packet.network_tuple.dst.ip(), 105 | 
)) 106 | } 107 | pub async fn send_from_to(&self, buf: &[u8], src: IpAddr, dst: IpAddr) -> io::Result { 108 | let Some(protocol) = self.protocol else { 109 | return Err(io::Error::new(io::ErrorKind::InvalidInput, "need to specify protocol")); 110 | }; 111 | self.send_protocol_from_to(buf, protocol, src, dst).await 112 | } 113 | pub async fn send_protocol_from_to( 114 | &self, 115 | buf: &[u8], 116 | protocol: IpNextHeaderProtocol, 117 | mut src: IpAddr, 118 | dst: IpAddr, 119 | ) -> io::Result { 120 | if let Some(p) = self.protocol { 121 | if p != protocol { 122 | return Err(io::Error::new(io::ErrorKind::InvalidInput, "inconsistent protocol")); 123 | } 124 | } 125 | if buf.len() > u16::MAX as usize - 8 { 126 | return Err(io::Error::new(io::ErrorKind::InvalidInput, "buf too long")); 127 | } 128 | if src.is_ipv4() != dst.is_ipv4() { 129 | return Err(io::Error::new(io::ErrorKind::InvalidInput, "address error")); 130 | } 131 | check_ip(dst)?; 132 | if let Err(e) = check_ip(src) { 133 | if let Some(v) = self.ip_stack.routes().route(dst) { 134 | src = v; 135 | } else { 136 | Err(e)? 
137 | } 138 | } 139 | let data: BytesMut = buf.into(); 140 | let src = SocketAddr::new(src, 0); 141 | let dst = SocketAddr::new(dst, 0); 142 | let network_tuple = NetworkTuple::new(src, dst, protocol); 143 | 144 | let packet = TransportPacket::new(data, network_tuple); 145 | if self.ip_stack.inner.packet_sender.send(packet).await.is_err() { 146 | return Err(io::Error::from(io::ErrorKind::UnexpectedEof)); 147 | } 148 | Ok(buf.len()) 149 | } 150 | } 151 | impl Drop for IpSocket { 152 | fn drop(&mut self) { 153 | self.ip_stack.remove_ip_socket(self.protocol, self.local_addr); 154 | } 155 | } 156 | -------------------------------------------------------------------------------- /src/ip_stack.rs: -------------------------------------------------------------------------------- 1 | use bytes::BytesMut; 2 | use dashmap::{DashMap, Entry}; 3 | use parking_lot::Mutex; 4 | use pnet_packet::ip::{IpNextHeaderProtocol, IpNextHeaderProtocols}; 5 | use pnet_packet::ipv4::{Ipv4Flags, Ipv4Packet}; 6 | use pnet_packet::ipv6::Ipv6Packet; 7 | use pnet_packet::Packet; 8 | use rand::Rng; 9 | use std::collections::{HashMap, HashSet}; 10 | use std::hash::Hash; 11 | use std::io; 12 | use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6}; 13 | use std::sync::atomic::{AtomicBool, Ordering}; 14 | use std::sync::Arc; 15 | use std::time::{Duration, Instant, UNIX_EPOCH}; 16 | use tokio::sync::mpsc::{channel, Receiver, Sender}; 17 | use tokio::sync::Notify; 18 | 19 | pub(crate) const UNSPECIFIED_ADDR_V4: SocketAddr = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0)); 20 | pub(crate) const UNSPECIFIED_ADDR_V6: SocketAddr = SocketAddr::V6(SocketAddrV6::new(Ipv6Addr::UNSPECIFIED, 0, 0, 0)); 21 | pub(crate) const fn default_addr(is_v4: bool) -> SocketAddr { 22 | if is_v4 { 23 | UNSPECIFIED_ADDR_V4 24 | } else { 25 | UNSPECIFIED_ADDR_V6 26 | } 27 | } 28 | pub(crate) const fn default_ip(is_v4: bool) -> IpAddr { 29 | if is_v4 { 30 | UNSPECIFIED_ADDR_V4.ip() 31 | } 
else { 32 | UNSPECIFIED_ADDR_V6.ip() 33 | } 34 | } 35 | pub(crate) fn check_addr(addr: SocketAddr) -> io::Result<()> { 36 | if addr.port() == 0 { 37 | return Err(io::Error::new(io::ErrorKind::InvalidInput, "invalid port")); 38 | } 39 | check_ip(addr.ip()) 40 | } 41 | pub(crate) fn check_ip(ip: IpAddr) -> io::Result<()> { 42 | if match ip { 43 | IpAddr::V4(ip) => ip.is_unspecified(), 44 | IpAddr::V6(ip) => ip.is_unspecified(), 45 | } { 46 | Err(io::Error::new(io::ErrorKind::InvalidInput, "invalid ip")) 47 | } else { 48 | Ok(()) 49 | } 50 | } 51 | /// Configure the protocol stack 52 | #[derive(Copy, Clone, Debug)] 53 | pub struct IpStackConfig { 54 | pub mtu: u16, 55 | pub ip_fragment_timeout: Duration, 56 | pub tcp_config: crate::tcp::TcpConfig, 57 | pub channel_size: usize, 58 | pub tcp_syn_channel_size: usize, 59 | pub tcp_channel_size: usize, 60 | pub udp_channel_size: usize, 61 | pub icmp_channel_size: usize, 62 | pub ip_channel_size: usize, 63 | } 64 | 65 | impl IpStackConfig { 66 | pub fn check(&self) -> io::Result<()> { 67 | if self.mtu < 576 { 68 | return Err(io::Error::new(io::ErrorKind::InvalidData, "mtu<576")); 69 | } 70 | if self.ip_fragment_timeout.is_zero() { 71 | return Err(io::Error::new(io::ErrorKind::InvalidData, "ip_fragment_timeout is zero")); 72 | } 73 | 74 | if self.channel_size == 0 { 75 | return Err(io::Error::new(io::ErrorKind::InvalidData, "channel_size is zero")); 76 | } 77 | if self.tcp_syn_channel_size == 0 { 78 | return Err(io::Error::new(io::ErrorKind::InvalidData, "tcp_syn_channel_size is zero")); 79 | } 80 | if self.tcp_channel_size == 0 { 81 | return Err(io::Error::new(io::ErrorKind::InvalidData, "tcp_channel_size is zero")); 82 | } 83 | if self.udp_channel_size == 0 { 84 | return Err(io::Error::new(io::ErrorKind::InvalidData, "udp_channel_size is zero")); 85 | } 86 | if self.icmp_channel_size == 0 { 87 | return Err(io::Error::new(io::ErrorKind::InvalidData, "icmp_channel_size is zero")); 88 | } 89 | if self.ip_channel_size == 0 { 
90 | return Err(io::Error::new(io::ErrorKind::InvalidData, "ip_channel_size is zero")); 91 | } 92 | self.tcp_config.check() 93 | } 94 | } 95 | 96 | impl Default for IpStackConfig { 97 | fn default() -> Self { 98 | Self { 99 | mtu: 1500, 100 | ip_fragment_timeout: Duration::from_secs(10), 101 | tcp_config: Default::default(), 102 | channel_size: 1024, 103 | tcp_syn_channel_size: 128, 104 | tcp_channel_size: 2048, 105 | udp_channel_size: 1024, 106 | icmp_channel_size: 128, 107 | ip_channel_size: 128, 108 | } 109 | } 110 | } 111 | 112 | /// Context information of protocol stack 113 | #[derive(Clone, Debug)] 114 | pub struct IpStack { 115 | routes: SafeRoutes, 116 | pub(crate) config: Box, 117 | pub(crate) inner: Arc, 118 | } 119 | 120 | #[derive(Debug)] 121 | pub(crate) struct IpStackInner { 122 | active_state: AtomicBool, 123 | pub(crate) tcp_stream_map: DashMap>, 124 | pub(crate) tcp_listener_map: DashMap, Sender>, 125 | pub(crate) udp_socket_map: DashMap<(Option, Option), flume::Sender>, 126 | pub(crate) raw_socket_map: DashMap<(Option, Option), flume::Sender>, 127 | pub(crate) packet_sender: Sender, 128 | bind_addrs: Mutex>, 129 | } 130 | impl IpStackInner { 131 | fn remove_all(&self) { 132 | self.active_state.store(false, Ordering::SeqCst); 133 | self.tcp_listener_map.clear(); 134 | self.tcp_stream_map.clear(); 135 | self.udp_socket_map.clear(); 136 | self.raw_socket_map.clear(); 137 | } 138 | fn check_state(&self) -> io::Result<()> { 139 | if self.active_state.load(Ordering::SeqCst) { 140 | Ok(()) 141 | } else { 142 | Err(io::Error::new(io::ErrorKind::Other, "shutdown")) 143 | } 144 | } 145 | fn check_state_and_remove(&self) -> io::Result<()> { 146 | let rs = self.check_state(); 147 | if rs.is_err() { 148 | self.remove_all(); 149 | } 150 | rs 151 | } 152 | } 153 | /// Send IP packets to the protocol stack using `IpStackSend` 154 | #[derive(Clone)] 155 | pub struct IpStackSend { 156 | ip_stack: IpStack, 157 | ident_fragments_map: Arc>>, 158 | notify: Arc, 159 | } 
// NOTE(review): generic type parameters in this chunk were stripped by the
// extraction tool and have been reconstructed from usage — confirm upstream.

/// Dropping the send half shuts the whole stack down: wake the
/// fragment-timeout task and clear every registered socket.
impl Drop for IpStackSend {
    fn drop(&mut self) {
        self.notify.notify_one();
        self.ip_stack.inner.remove_all();
    }
}

impl IpStackSend {
    pub(crate) fn new(ip_stack: IpStack) -> Self {
        Self {
            ip_stack,
            ident_fragments_map: Default::default(),
            notify: Arc::new(Notify::new()),
        }
    }
}

/// Receive IP packets from the protocol stack using `IpStackRecv`
pub struct IpStackRecv {
    inner: IpStackRecvInner,
    // Cursor (`index`) and fill level (`num`) over `bufs`/`sizes`, used by the
    // single-packet `recv` API to drain a previously received batch.
    index: usize,
    num: usize,
    sizes: Vec<usize>,
    bufs: Vec<BytesMut>,
}
struct IpStackRecvInner {
    mtu: u16,
    // Rolling IPv4 identification value used when fragmenting outbound packets.
    identification: u16,
    packet_receiver: Receiver<TransportPacket>,
}

impl IpStackRecv {
    pub(crate) fn new(mtu: u16, packet_receiver: Receiver<TransportPacket>) -> Self {
        // Seed the IPv4 identification counter from the wall clock so
        // successive stack instances are unlikely to reuse the same sequence.
        let identification = std::time::SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .map(|v| (v.as_millis() & 0xFFFF) as u16)
            .unwrap_or(0);
        let inner = IpStackRecvInner {
            mtu,
            identification,
            packet_receiver,
        };
        Self {
            inner,
            index: 0,
            num: 0,
            sizes: Vec::new(),
            bufs: Vec::new(),
        }
    }
}

/// (src, dst, protocol) triple identifying one transport flow.
#[derive(Eq, Hash, PartialEq, Debug, Clone, Copy)]
pub(crate) struct NetworkTuple {
    pub src: SocketAddr,
    pub dst: SocketAddr,
    pub protocol: IpNextHeaderProtocol,
}

impl NetworkTuple {
    /// Panics if `src` and `dst` are not of the same IP family.
    pub fn new(src: SocketAddr, dst: SocketAddr, protocol: IpNextHeaderProtocol) -> Self {
        assert_eq!(src.is_ipv4(), dst.is_ipv4());
        Self { src, dst, protocol }
    }
    pub fn is_ipv4(&self) -> bool {
        self.src.is_ipv4()
    }
}

/// Key identifying one IPv4 reassembly queue: src, dst, protocol and the
/// datagram's identification field (RFC 791).
#[derive(Eq, Hash, PartialEq, Debug, Clone, Copy)]
struct IdKey {
    pub src: IpAddr,
    pub dst: IpAddr,
    pub protocol: IpNextHeaderProtocol,
    pub identification: u16,
}

impl IdKey {
    fn new(src: IpAddr, dst: IpAddr, protocol: IpNextHeaderProtocol, identification: u16) -> Self {
        Self {
            src,
            dst,
            protocol,
            identification,
        }
    }
}

/// Create a user-space protocol stack.
///
/// # Examples
/// ```rust
/// use tcp_ip::tcp::TcpListener;
/// #[cfg(not(feature = "global-ip-stack"))]
/// async fn main(){
/// let (ip_stack, ip_stack_send, ip_stack_recv) = tcp_ip::ip_stack(Default::default())?;
/// // Use ip_stack_send and ip_stack_recv to interface
/// // with the input and output of IP packets.
/// // ...
/// let mut tcp_listener = TcpListener::bind_all(ip_stack.clone()).await?;
/// }
/// ```
#[cfg(not(feature = "global-ip-stack"))]
pub fn ip_stack(config: IpStackConfig) -> io::Result<(IpStack, IpStackSend, IpStackRecv)> {
    ip_stack0(config)
}

/// Create a user-space protocol stack.
///
/// # Examples
/// ```rust
/// use tcp_ip::tcp::TcpListener;
/// #[cfg(feature = "global-ip-stack")]
/// async fn main(){
/// let (ip_stack_send, ip_stack_recv) = tcp_ip::ip_stack(Default::default())?;
/// // Use ip_stack_send and ip_stack_recv to interface
/// // with the input and output of IP packets.
/// // ...
/// let mut tcp_listener = TcpListener::bind_all().await?;
/// }
/// ```
#[cfg(feature = "global-ip-stack")]
pub fn ip_stack(config: IpStackConfig) -> io::Result<(IpStackSend, IpStackRecv)> {
    let (ip_stack, ip_stack_send, ip_stack_recv) = ip_stack0(config)?;
    // Install the stack as the process-wide singleton.
    IpStack::set(ip_stack);
    Ok((ip_stack_send, ip_stack_recv))
}
/// Shared constructor: validates the config, wires the outbound channel and
/// spawns the fragment-timeout sweeper.
fn ip_stack0(config: IpStackConfig) -> io::Result<(IpStack, IpStackSend, IpStackRecv)> {
    config.check()?;
    let (packet_sender, packet_receiver) = channel(config.channel_size);
    let ip_stack = IpStack::new(config, packet_sender);
    let ip_stack_send = IpStackSend::new(ip_stack.clone());
    let ip_stack_recv = IpStackRecv::new(ip_stack.config.mtu, packet_receiver);
    {
        // Background task that expires stale fragment-reassembly state; it
        // stops when `IpStackSend` is dropped (signalled through `notify`).
        let ident_fragments_map = ip_stack_send.ident_fragments_map.clone();
        let notify = ip_stack_send.notify.clone();
        let timeout = ip_stack.config.ip_fragment_timeout;
        tokio::spawn(async move {
            loop_check_timeouts(timeout, ident_fragments_map, notify).await;
        });
    }
    Ok((ip_stack, ip_stack_send, ip_stack_recv))
}
/// Periodically sweep the fragment map until `notify` fires (stack shutdown).
async fn loop_check_timeouts(timeout: Duration, ident_fragments_map: Arc<Mutex<HashMap<IdKey, IpFragments>>>, notify: Arc<Notify>) {
    let notified = notify.notified();
    tokio::pin!(notified);
    loop {
        tokio::select! {
            _=&mut notified=>{
                break;
            }
            _=tokio::time::sleep(timeout)=>{
                check_timeouts(&ident_fragments_map,timeout);
            }

        }
    }
}

fn check_timeouts(ident_fragments_map: &Mutex<HashMap<IdKey, IpFragments>>, timeout: Duration) {
    // `try_lock`: if the map is busy, skip this sweep instead of blocking;
    // the next tick will retry.
    if let Some(mut ident_fragments_map) = ident_fragments_map.try_lock() {
        let now = Instant::now();
        // Clear timeout IP segmentation
        ident_fragments_map.retain(|_id_key, p| p.time + timeout > now)
    }
}
impl IpStack {
    pub fn routes(&self) -> &SafeRoutes {
        &self.routes
    }
}
impl IpStack {
    pub(crate) fn new(config: IpStackConfig, packet_sender: Sender<TransportPacket>) -> Self {
        Self {
            routes: Default::default(),
            config: Box::new(config),
            inner: Arc::new(IpStackInner {
                active_state: AtomicBool::new(true),
                tcp_stream_map: Default::default(),
                tcp_listener_map: Default::default(),
                udp_socket_map: Default::default(),
                raw_socket_map: Default::default(),
                packet_sender,
                bind_addrs: Default::default(),
            }),
        }
    }
    /// Register a raw-IP socket; `None` fields act as wildcards in the lookup.
    pub(crate) fn add_ip_socket(
        &self,
        protocol: Option<IpNextHeaderProtocol>,
        local_addr: Option<SocketAddr>,
        packet_sender: flume::Sender<TransportPacket>,
    ) -> io::Result<()> {
        Self::add_socket0(&self.inner, &self.inner.raw_socket_map, (protocol, local_addr), packet_sender)
    }
    /// Register a UDP socket under (local, peer); `None` acts as a wildcard.
    pub(crate) fn add_udp_socket(
        &self,
        local_addr: Option<SocketAddr>,
        peer_addr: Option<SocketAddr>,
        packet_sender: flume::Sender<TransportPacket>,
    ) -> io::Result<()> {
        Self::add_socket0(&self.inner, &self.inner.udp_socket_map, (local_addr, peer_addr), packet_sender)
    }

    /// Re-key an existing UDP socket (e.g. after connect): the sender is
    /// registered under `new` before `old` is removed, so inbound packets are
    /// never routed to a missing entry during the swap.
    pub(crate) fn replace_udp_socket(
        &self,
        old: (Option<SocketAddr>, Option<SocketAddr>),
        new: (Option<SocketAddr>, Option<SocketAddr>),
    ) -> io::Result<()> {
        let packet_sender = if let Some(v) = self.inner.udp_socket_map.get(&old) {
            v.value().clone()
        } else {
            return Err(io::Error::from(io::ErrorKind::NotFound));
        };
        // NOTE(review): generic parameters in this chunk were stripped by the
        // extraction tool and were reconstructed from call sites — confirm upstream.
        Self::add_socket0(&self.inner, &self.inner.udp_socket_map, new, packet_sender)?;
        _ = self.inner.udp_socket_map.remove(&old);
        Ok(())
    }

    /// Register a TCP listener; `None` listens on every address.
    pub(crate) fn add_tcp_listener(&self, local_addr: Option<SocketAddr>, packet_sender: Sender<TransportPacket>) -> io::Result<()> {
        Self::add_socket0(&self.inner, &self.inner.tcp_listener_map, local_addr, packet_sender)
    }
    pub(crate) fn remove_tcp_listener(&self, local_addr: &Option<SocketAddr>) {
        self.inner.tcp_listener_map.remove(local_addr);
    }
    /// Register an established TCP connection under its full flow tuple.
    pub(crate) fn add_tcp_socket(&self, network_tuple: NetworkTuple, packet_sender: Sender<TransportPacket>) -> io::Result<()> {
        Self::add_socket0(&self.inner, &self.inner.tcp_stream_map, network_tuple, packet_sender)
    }
    pub(crate) fn remove_tcp_socket(&self, network_tuple: &NetworkTuple) {
        self.inner.tcp_stream_map.remove(network_tuple);
    }
    pub(crate) fn remove_udp_socket(&self, local_addr: Option<SocketAddr>, peer_addr: Option<SocketAddr>) {
        self.inner.udp_socket_map.remove(&(local_addr, peer_addr));
    }
    pub(crate) fn remove_ip_socket(&self, protocol: Option<IpNextHeaderProtocol>, local_addr: Option<SocketAddr>) {
        self.inner.raw_socket_map.remove(&(protocol, local_addr));
    }
    /// Insert `packet_sender` under `local_addr`, failing with `AddrInUse` if
    /// the key is taken. The second state check (after insert) closes a race
    /// with shutdown: an entry added while `remove_all` runs is cleaned up here.
    fn add_socket0<K: std::hash::Hash + Eq, V>(
        ip_stack_inner: &IpStackInner,
        map: &DashMap<K, V>,
        local_addr: K,
        packet_sender: V,
    ) -> io::Result<()> {
        ip_stack_inner.check_state()?;
        let entry = map.entry(local_addr);
        let rs = match entry {
            Entry::Occupied(_entry) => Err(io::Error::from(io::ErrorKind::AddrInUse)),
            Entry::Vacant(entry) => {
                entry.insert(packet_sender);
                Ok(())
            }
        };
        ip_stack_inner.check_state_and_remove()?;
        rs
    }
    /// Queue an outbound transport packet for `IpStackRecv` to encode and emit.
    pub(crate) async fn send_packet(&self, transport_packet: TransportPacket) -> io::Result<()> {
        match self.inner.packet_sender.send(transport_packet).await {
            Ok(_) => Ok(()),
            Err(_) => Err(io::Error::new(io::ErrorKind::WriteZero, "ip stack close")),
        }
    }
    /// Reserve a local address for `protocol`; may rewrite `addr` (the `true`
    /// flag lets `add_bind_addr` assign a port) and returns a guard that
    /// releases the reservation when dropped.
    pub(crate) fn bind(&self, protocol: IpNextHeaderProtocol, addr: &mut SocketAddr) -> io::Result<BindAddr> {
        let bind_address = self.inner.add_bind_addr(protocol, *addr, true)?;
        *addr = bind_address;
        Ok(BindAddr {
            protocol,
            addr: bind_address,
            inner: self.inner.clone(),
        })
    }
    /// Like `bind`, but without port auto-assignment (raw IP sockets have no port).
    pub(crate) fn bind_ip(&self, protocol: IpNextHeaderProtocol, addr: SocketAddr) -> io::Result<BindAddr> {
        _ = self.inner.add_bind_addr(protocol, addr, false)?;
        Ok(BindAddr {
            protocol,
            addr,
            inner: self.inner.clone(),
        })
    }
}

impl IpStackSend {
    /// Send the IP packet to this protocol stack.
    pub async fn send_ip_packet(&self, buf: &[u8]) -> io::Result<()> {
        // IP version lives in the top nibble of the first byte.
        // NOTE(review): panics if `buf` is empty — callers must guarantee
        // a non-empty packet; confirm whether an explicit check is wanted.
        let p = buf[0] >> 4;
        match p {
            4 => {
                let Some(packet) = Ipv4Packet::new(buf) else {
                    return Err(io::Error::from(io::ErrorKind::InvalidInput));
                };

                let id_key = convert_id_key(&packet);

                // `None` means this is a non-first fragment whose flow is not
                // known yet; it has been buffered for later reassembly.
                let Some(network_tuple) = self.prepare_ipv4_fragments(&packet, id_key)? else {
                    return Ok(());
                };
                let mut sender = match packet.get_next_level_protocol() {
                    IpNextHeaderProtocols::Tcp => self.get_tcp_sender(&network_tuple),
                    IpNextHeaderProtocols::Udp => self.get_udp_sender(&network_tuple),
                    _ => None,
                };
                if sender.is_none() {
                    // Fall back to any raw socket bound to this protocol/address.
                    sender = self.get_raw_sender(packet.get_next_level_protocol(), &network_tuple);
                }
                if let Some(sender) = sender {
                    let rs = self.transmit_ip_packet(sender, packet, id_key, network_tuple).await;
                    if rs.is_err() {
                        self.clear_fragment_cache(&id_key);
                    }
                    rs
                } else {
                    // Nobody is listening: drop any reassembly state for this id.
                    self.clear_fragment_cache(&id_key);
                    Ok(())
                }
            }
            6 => {
                let Some(packet) = Ipv6Packet::new(buf) else {
                    return Err(io::Error::from(io::ErrorKind::InvalidInput));
                };
                // todo Need to handle fragmentation, routing, and other header information.
                // NOTE(review): generic parameters in this chunk were stripped
                // by the extraction tool and reconstructed from usage — confirm upstream.
                let network_tuple = self.prepare_ipv6_fragments(&packet)?;
                let mut sender = match packet.get_next_header() {
                    IpNextHeaderProtocols::Tcp => self.get_tcp_sender(&network_tuple),
                    IpNextHeaderProtocols::Udp => self.get_udp_sender(&network_tuple),
                    _ => None,
                };
                if sender.is_none() {
                    sender = self.get_raw_sender(packet.get_next_header(), &network_tuple);
                }
                if let Some(sender) = sender {
                    // A full (closed) channel silently drops the packet.
                    _ = sender.send(TransportPacket::new(packet.payload().into(), network_tuple)).await;
                }
                Ok(())
            }
            _ => Err(io::Error::from(io::ErrorKind::InvalidInput)),
        }
    }
    /// Locate the TCP receiver for a flow: established stream first, then
    /// listeners from most to least specific — exact address, wildcard IP with
    /// the same port, fully unspecified address, finally the `None` catch-all.
    fn get_tcp_sender(&self, network_tuple: &NetworkTuple) -> Option<SenderBox<TransportPacket>> {
        let stack = &self.ip_stack.inner;
        if let Some(tcp) = stack.tcp_stream_map.get(network_tuple) {
            Some(SenderBox::Mpsc(tcp.value().clone()))
        } else if let Some(tcp) = stack.tcp_listener_map.get(&Some(network_tuple.dst)) {
            Some(SenderBox::Mpsc(tcp.value().clone()))
        } else {
            let dst = SocketAddr::new(default_ip(network_tuple.is_ipv4()), network_tuple.dst.port());
            if let Some(tcp) = stack.tcp_listener_map.get(&Some(dst)) {
                Some(SenderBox::Mpsc(tcp.value().clone()))
            } else if let Some(tcp) = stack.tcp_listener_map.get(&Some(default_addr(network_tuple.is_ipv4()))) {
                Some(SenderBox::Mpsc(tcp.value().clone()))
            } else {
                stack.tcp_listener_map.get(&None).map(|tcp| SenderBox::Mpsc(tcp.value().clone()))
            }
        }
    }
    /// Same most-specific-first lookup for UDP: connected (local, peer) socket,
    /// then (local, None), wildcard-IP variants, finally the (None, None) catch-all.
    fn get_udp_sender(&self, network_tuple: &NetworkTuple) -> Option<SenderBox<TransportPacket>> {
        let stack = &self.ip_stack.inner;
        if let Some(udp) = stack.udp_socket_map.get(&(Some(network_tuple.dst), Some(network_tuple.src))) {
            return Some(SenderBox::Mpmc(udp.value().clone()));
        }
        if let Some(udp) = stack.udp_socket_map.get(&(Some(network_tuple.dst), None)) {
            Some(SenderBox::Mpmc(udp.value().clone()))
        } else {
            let dst = SocketAddr::new(default_ip(network_tuple.is_ipv4()), network_tuple.dst.port());
            if let Some(udp) = stack.udp_socket_map.get(&(Some(dst), None)) {
                Some(SenderBox::Mpmc(udp.value().clone()))
            } else if let Some(udp) = stack.udp_socket_map.get(&(Some(default_addr(network_tuple.is_ipv4())), None)) {
                Some(SenderBox::Mpmc(udp.value().clone()))
            } else {
                stack
                    .udp_socket_map
                    .get(&(None, None))
                    .map(|udp| SenderBox::Mpmc(udp.value().clone()))
            }
        }
    }
    /// Raw-socket lookup: sockets bound to this exact protocol first, then
    /// protocol-wildcard sockets.
    fn get_raw_sender(&self, protocol: IpNextHeaderProtocol, network_tuple: &NetworkTuple) -> Option<SenderBox<TransportPacket>> {
        if let Some(v) = self.get_raw_sender0(Some(protocol), network_tuple) {
            Some(v)
        } else {
            self.get_raw_sender0(None, network_tuple)
        }
    }
    fn get_raw_sender0(&self, protocol: Option<IpNextHeaderProtocol>, network_tuple: &NetworkTuple) -> Option<SenderBox<TransportPacket>> {
        let stack = &self.ip_stack.inner;
        if let Some(socket) = stack.raw_socket_map.get(&(protocol, Some(network_tuple.dst))) {
            Some(SenderBox::Mpmc(socket.value().clone()))
        } else {
            let dst = SocketAddr::new(default_ip(network_tuple.is_ipv4()), network_tuple.dst.port());
            if let Some(socket) = stack.raw_socket_map.get(&(protocol, Some(dst))) {
                Some(SenderBox::Mpmc(socket.value().clone()))
            } else if let Some(socket) = stack.raw_socket_map.get(&(protocol, Some(default_addr(network_tuple.is_ipv4())))) {
                Some(SenderBox::Mpmc(socket.value().clone()))
            } else {
                stack
                    .raw_socket_map
                    .get(&(protocol, None))
                    .map(|icmp| SenderBox::Mpmc(icmp.value().clone()))
            }
        }
    }
    /// Deliver one IPv4 packet's payload, reassembling it first when fragmented.
    async fn transmit_ip_packet(
        &self,
        sender: SenderBox<TransportPacket>,
        packet: Ipv4Packet<'_>,
        id_key: IdKey,
        network_tuple: NetworkTuple,
    ) -> io::Result<()> {
        let more_fragments = packet.get_flags() & Ipv4Flags::MoreFragments == Ipv4Flags::MoreFragments;
        let offset = packet.get_fragment_offset();
        let segmented = more_fragments || offset > 0;
        let buf = if segmented {
            // merge ip fragments
            if let Some(buf) = self.merge_ip_fragments(&packet, id_key, network_tuple)? {
                buf
            } else {
                // Need to wait for all shards to arrive
                return Ok(());
            }
        } else {
            // confirm that the id is not occupied
            self.clear_fragment_cache(&id_key);
            packet.payload().into()
        };
        _ = sender.send(TransportPacket::new(buf, network_tuple)).await;
        Ok(())
    }
    /// Work out the flow tuple for an IPv4 packet, buffering out-of-order
    /// TCP/UDP fragments whose first fragment (which carries the ports) has
    /// not arrived yet. Returns `Ok(None)` when the fragment was buffered.
    fn prepare_ipv4_fragments(&self, ip_packet: &Ipv4Packet<'_>, id_key: IdKey) -> io::Result<Option<NetworkTuple>> {
        let offset = ip_packet.get_fragment_offset();
        let network_tuple = if offset == 0
            || (ip_packet.get_next_level_protocol() != IpNextHeaderProtocols::Udp
                && ip_packet.get_next_level_protocol() != IpNextHeaderProtocols::Tcp)
        {
            // No segmentation or the first segmentation
            convert_network_tuple(ip_packet)?
        } else {
            let mut guard = self.ident_fragments_map.lock();
            let p = guard.entry(id_key).or_default();

            if let Some(v) = p.network_tuple {
                v
            } else {
                // Perhaps the first IP segment has not yet arrived,
                // so the network tuple cannot be obtained.
                let last_fragment = ip_packet.get_flags() & Ipv4Flags::MoreFragments != Ipv4Flags::MoreFragments;
                p.add_fragment(ip_packet.into(), last_fragment)?;
                return Ok(None);
            }
        };
        Ok(Some(network_tuple))
    }
    /// IPv6 extension headers are not handled yet; anything other than a plain
    /// transport next-header is rejected as unsupported.
    fn prepare_ipv6_fragments(&self, ip_packet: &Ipv6Packet<'_>) -> io::Result<NetworkTuple> {
        match ip_packet.get_next_header() {
            IpNextHeaderProtocols::Ipv6Frag
            | IpNextHeaderProtocols::Ipv6Route
            | IpNextHeaderProtocols::Ipv6Opts
            | IpNextHeaderProtocols::Ipv6NoNxt => {
                // todo Handle IP fragmentation.
                return Err(io::Error::new(io::ErrorKind::Unsupported, "ipv6 option"));
            }
            _ => {}
        }
        convert_network_tuple_v6(ip_packet)
    }
    /// Add one fragment to the reassembly buffer; returns the concatenated
    /// payload once all fragments are present, `None` while still waiting.
    fn merge_ip_fragments(&self, ip_packet: &Ipv4Packet<'_>, id_key: IdKey, network_tuple: NetworkTuple) -> io::Result<Option<BytesMut>> {
        let mut map = self.ident_fragments_map.lock();
        let ip_fragments = map
            .entry(id_key)
            .and_modify(|p| p.update_time())
            .or_insert_with(|| IpFragments::new(network_tuple));

        let last_fragment = ip_packet.get_flags() & Ipv4Flags::MoreFragments == Ipv4Flags::MoreFragments;
        // Fragment offset is encoded in 8-byte units; convert to bytes.
        let offset = ip_packet.get_fragment_offset() << 3;
        if last_fragment {
            ip_fragments.last_offset.replace(offset);
        }
        ip_fragments.add_fragment(ip_packet.into(), last_fragment)?;

        if ip_fragments.is_complete() {
            //This place cannot be None
            let mut fragments = map.remove(&id_key).unwrap();
            fragments
                .bufs
                .sort_by(|ip_fragment1, ip_fragment2| ip_fragment1.offset.cmp(&ip_fragment2.offset));
            // Verify the sorted fragments are contiguous before concatenating.
            let mut total_payload_len = 0;
            for ip_fragment in &fragments.bufs {
                if total_payload_len as u16 != ip_fragment.offset {
                    return Err(io::Error::new(
                        io::ErrorKind::InvalidData,
                        format!("fragment offset error:{total_payload_len}!={}", ip_fragment.offset),
                    ));
                }
                total_payload_len += ip_fragment.payload.len();
            }
            let mut p = BytesMut::with_capacity(total_payload_len);
            for ip_fragment in &fragments.bufs {
                p.extend_from_slice(&ip_fragment.payload);
            }
            return Ok(Some(p));
        }

        Ok(None)
    }
    fn clear_fragment_cache(&self, id_key: &IdKey) {
        let mut guard = self.ident_fragments_map.lock();
        guard.remove(id_key);
    }
}

impl IpStackRecv {
    /// Read a single IP packet from the protocol stack.
    // NOTE(review): generic parameters in this chunk were stripped by the
    // extraction tool and were reconstructed from usage — confirm upstream.
    pub async fn recv(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        loop {
            // First drain any packets left over from the previous batch.
            if self.num > self.index {
                let index = self.index;
                let len = self.sizes[index];
                if buf.len() < len {
                    return Err(io::Error::new(io::ErrorKind::InvalidInput, "bufs too short"));
                }
                buf[..len].copy_from_slice(&self.bufs[index][..len]);
                self.index += 1;
                return Ok(len);
            }
            self.index = 0;
            self.num = 0;
            // Lazily allocate a batch of 128 MTU-sized staging buffers.
            if self.sizes.is_empty() {
                self.sizes.resize(128, 0);
            }
            if self.bufs.is_empty() {
                for _ in 0..128 {
                    self.bufs.push(BytesMut::zeroed(self.inner.mtu as usize));
                }
            }
            self.num = self.inner.recv_ip_packet(&mut self.bufs, &mut self.sizes).await?;
            if self.num == 0 {
                return Err(io::Error::new(io::ErrorKind::UnexpectedEof, "read 0"));
            }
        }
    }
    /// Read multiple IP packets from the protocol stack at once.
    pub async fn recv_ip_packet<B: AsMut<[u8]>>(&mut self, bufs: &mut [B], sizes: &mut [usize]) -> io::Result<usize> {
        self.inner.recv_ip_packet(bufs, sizes).await
    }
}
impl IpStackRecvInner {
    /// Pop one transport packet and encode it as one IPv6 packet or one or
    /// more (possibly fragmented) IPv4 packets into `bufs`; returns the count.
    async fn recv_ip_packet<B: AsMut<[u8]>>(&mut self, bufs: &mut [B], sizes: &mut [usize]) -> io::Result<usize> {
        if bufs.is_empty() {
            return Err(io::Error::new(io::ErrorKind::InvalidInput, "bufs is empty"));
        }
        if bufs.len() != sizes.len() {
            return Err(io::Error::new(io::ErrorKind::InvalidInput, "bufs.len!=sizes.len"));
        }
        if let Some(packet) = self.packet_receiver.recv().await {
            match (packet.network_tuple.src.is_ipv6(), packet.network_tuple.dst.is_ipv6()) {
                (true, true) => self.wrap_in_ipv6(bufs, sizes, packet),
                (false, false) => self.split_ip_packet(bufs, sizes, packet),
                (_, _) => Err(io::Error::new(io::ErrorKind::InvalidInput, "address error")),
            }
        } else {
            Err(io::Error::new(io::ErrorKind::UnexpectedEof, "close"))
        }
    }
    /// Build a single IPv6 packet (40-byte header + payload) into `bufs[0]`.
    fn wrap_in_ipv6<B: AsMut<[u8]>>(&mut self, bufs: &mut [B], sizes: &mut [usize], packet: TransportPacket) -> io::Result<usize> {
        let buf = bufs[0].as_mut();
        let total_length = 40 + packet.buf.len();
        if buf.len() < total_length {
            return Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                format!("bufs[0] too short.{total_length}>{:?}", buf.len()),
            ));
        }
        let src_ip = match packet.network_tuple.src.ip() {
            IpAddr::V6(ip) => ip,
            IpAddr::V4(_) => unimplemented!(),
        };
        let dst_ip = match packet.network_tuple.dst.ip() {
            IpAddr::V6(ip) => ip,
            IpAddr::V4(_) => unimplemented!(),
        };
        // Create a mutable IPv6 packet view over the output buffer.
        let Some(mut ipv6_packet) = pnet_packet::ipv6::MutableIpv6Packet::new(&mut buf[..total_length]) else {
            return Err(io::Error::new(io::ErrorKind::InvalidInput, "ipv6 data error"));
        };
        ipv6_packet.set_version(6);
        ipv6_packet.set_traffic_class(0);
        ipv6_packet.set_flow_label(0);
        ipv6_packet.set_payload_length(packet.buf.len() as u16); // Set the payload length
        ipv6_packet.set_next_header(packet.network_tuple.protocol);
        ipv6_packet.set_hop_limit(64);
        ipv6_packet.set_source(src_ip);
        ipv6_packet.set_destination(dst_ip);
        // Copy in the payload data.
        ipv6_packet.set_payload(&packet.buf);
        sizes[0] = total_length;
        Ok(1)
    }
    /// Encode an IPv4 packet, fragmenting the payload across `bufs` when it
    /// exceeds the MTU. Non-final fragment payloads are kept multiples of 8
    /// bytes, as required by the fragment-offset encoding.
    fn split_ip_packet<B: AsMut<[u8]>>(&mut self, bufs: &mut [B], sizes: &mut [usize], packet: TransportPacket) -> io::Result<usize> {
        let mtu = self.mtu;
        // Fresh identification for this (possibly fragmented) datagram.
        self.identification = self.identification.wrapping_sub(1);
        let identification = self.identification;
        let mut offset = 0;
        let mut total_packets = 0;

        const IPV4_HEADER_SIZE: usize = 20; // IPv4 header fixed size
        let max_payload_size = mtu as usize - IPV4_HEADER_SIZE;
        // Round down to a multiple of 8 for non-final fragments.
        let max_payload_size_8 = max_payload_size & !0b111;
        let src_ip = match packet.network_tuple.src.ip() {
            IpAddr::V4(ip) => ip,
            IpAddr::V6(_) => unimplemented!(),
        };
        let dst_ip = match packet.network_tuple.dst.ip() {
            IpAddr::V4(ip) => ip,
            IpAddr::V6(_) => unimplemented!(),
        };
        let protocol = packet.network_tuple.protocol.0;

        while offset < packet.buf.len() {
            let remaining = packet.buf.len() - offset;
            let fragment_size = if remaining > max_payload_size {
                max_payload_size_8
            } else {
                remaining
            };
            let total_length = IPV4_HEADER_SIZE + fragment_size;
            if total_packets >= bufs.len() {
                return Err(io::Error::new(io::ErrorKind::InvalidInput, "bufs too short"));
            }
            let buf = bufs[total_packets].as_mut();
            if total_length > buf.len() {
                return Err(io::Error::new(
                    io::ErrorKind::InvalidInput,
                    format!("bufs[{total_packets}] too short.{total_length}>{:?}", buf.len()),
                ));
            }
            let more_fragments = if remaining > fragment_size { 1 } else { 0 };
            assert_eq!(offset & 0b111, 0, "Offset must be a multiple of 8");
            // Flags (bit 13 = More Fragments) plus offset in 8-byte units.
            let fragment_offset = ((offset & !0b111) as u16) >> 3;
            let flags_fragment_offset = (more_fragments << 13) | fragment_offset;

            // Hand-roll the 20-byte IPv4 header.
            let ip_header = &mut buf[..IPV4_HEADER_SIZE];
            ip_header[0] = (4 << 4) | (IPV4_HEADER_SIZE / 4) as u8; // Version (4) + IHL
            ip_header[1] = 0; // Type of Service
            ip_header[2..4].copy_from_slice(&(total_length as u16).to_be_bytes());
            ip_header[4..6].copy_from_slice(&identification.to_be_bytes());
            ip_header[6..8].copy_from_slice(&flags_fragment_offset.to_be_bytes());
            ip_header[8] = 64; // TTL
            ip_header[9] = protocol;
            ip_header[12..16].copy_from_slice(&src_ip.octets());
            ip_header[16..20].copy_from_slice(&dst_ip.octets());

            // Header checksum, skipping the checksum field itself (word 5).
            let checksum = pnet_packet::util::checksum(ip_header, 5);
            ip_header[10..12].copy_from_slice(&checksum.to_be_bytes());
            let ip_payload = &mut buf[IPV4_HEADER_SIZE..total_length];
            ip_payload.copy_from_slice(&packet.buf[offset..offset + fragment_size]);
            offset += fragment_size;
            sizes[total_packets] = total_length;
            total_packets += 1;
        }
        Ok(total_packets)
    }
}

/// One transport-layer payload plus the flow it belongs to.
#[derive(Debug)]
pub(crate) struct TransportPacket {
    pub buf: BytesMut,
    pub network_tuple: NetworkTuple,
}

impl TransportPacket {
    pub fn new(buf: BytesMut, network_tuple: NetworkTuple) -> Self {
        Self { buf, network_tuple }
    }
}

/// Reassembly state for one fragmented IPv4 datagram.
struct IpFragments {
    // Flow tuple parsed from the first fragment; `None` until it arrives.
    network_tuple: Option<NetworkTuple>,
    bufs: Vec<IpFragment>,
    // Read IP payload length(Excluding the last IP segment).
    read_len: u16,
    // The offset of the last segment.
    // If last_offset == read_len, it means all fragments have been received.
    last_offset: Option<u16>,
    // Last-touched timestamp, used by the timeout sweeper.
    time: Instant,
}

struct IpFragment {
    // Byte offset of this fragment's payload within the original datagram.
    offset: u16,
    payload: BytesMut,
}

impl From<&Ipv4Packet<'_>> for IpFragment {
    fn from(value: &Ipv4Packet<'_>) -> Self {
        Self {
            // Header stores the offset in 8-byte units; convert to bytes.
            offset: value.get_fragment_offset() << 3,
            payload: value.payload().into(),
        }
    }
}

impl Default for IpFragments {
    fn default() -> Self {
        Self {
            network_tuple: None,
            bufs: Vec::with_capacity(8),
            read_len: 0,
            last_offset: None,
            time: Instant::now(),
        }
    }
}

impl IpFragments {
    fn new(network_tuple: NetworkTuple) -> Self {
        Self {
            network_tuple: Some(network_tuple),
            ..Self::default()
        }
    }
    fn update_time(&mut self) {
        self.time = Instant::now();
    }
    /// Buffer a fragment. Non-final fragments grow `read_len`; completion is
    /// detected when `read_len` reaches the final fragment's offset.
    fn add_fragment(&mut self, ip_fragment: IpFragment, last_fragment: bool) -> io::Result<()> {
        if !last_fragment {
            let (read_len, overflow) = self.read_len.overflowing_add(ip_fragment.payload.len() as u16);
            if overflow {
                return Err(io::Error::new(io::ErrorKind::InvalidData, "IP segment length overflow"));
            }
            self.read_len = read_len;
        }

        self.bufs.push(ip_fragment);
        Ok(())
    }
    fn is_complete(&self) -> bool {
        if let Some(last_offset) = self.last_offset {
            last_offset == self.read_len
        } else {
            false
        }
    }
}

/// Parse (src, dst, protocol) from an IPv4 packet; TCP/UDP payloads supply
/// real ports, every other protocol gets port 0.
fn convert_network_tuple(packet: &Ipv4Packet) -> io::Result<NetworkTuple> {
    let src_ip = packet.get_source();
    let dest_ip = packet.get_destination();
    let (src_port, dest_port) = match packet.get_next_level_protocol() {
        IpNextHeaderProtocols::Tcp => {
            let Some(tcp_packet) = pnet_packet::tcp::TcpPacket::new(packet.payload()) else {
                return Err(io::Error::from(io::ErrorKind::InvalidData));
            };
            (tcp_packet.get_source(), tcp_packet.get_destination())
        }
        IpNextHeaderProtocols::Udp => {
            let Some(udp_packet) = pnet_packet::udp::UdpPacket::new(packet.payload()) else {
                return Err(io::Error::from(io::ErrorKind::InvalidData));
            };
            (udp_packet.get_source(), udp_packet.get_destination())
        }
        _ => (0, 0),
    };

    let src_addr = SocketAddrV4::new(src_ip, src_port);
    let dest_addr = SocketAddrV4::new(dest_ip, dest_port);
    let protocol = packet.get_next_level_protocol();
    let network_tuple = NetworkTuple::new(src_addr.into(), dest_addr.into(), protocol);
    Ok(network_tuple)
}
/// IPv6 counterpart of `convert_network_tuple`.
fn convert_network_tuple_v6(packet: &Ipv6Packet) -> io::Result<NetworkTuple> {
    let src_ip = packet.get_source();
    let dest_ip = packet.get_destination();
    let protocol = packet.get_next_header();

    let (src_port, dest_port) = match protocol {
        IpNextHeaderProtocols::Tcp => {
            let Some(tcp_packet) = pnet_packet::tcp::TcpPacket::new(packet.payload()) else {
                return Err(io::Error::from(io::ErrorKind::InvalidData));
            };
            (tcp_packet.get_source(), tcp_packet.get_destination())
        }
        IpNextHeaderProtocols::Udp => {
            let Some(udp_packet) = pnet_packet::udp::UdpPacket::new(packet.payload()) else {
                return Err(io::Error::from(io::ErrorKind::InvalidData));
            };
            // NOTE(review): generic parameters in this chunk were stripped by
            // the extraction tool and reconstructed from usage — confirm upstream.
            (udp_packet.get_source(), udp_packet.get_destination())
        }
        _ => (0, 0),
    };

    let src_addr = SocketAddrV6::new(src_ip, src_port, 0, 0);
    let dest_addr = SocketAddrV6::new(dest_ip, dest_port, 0, 0);
    let network_tuple = NetworkTuple::new(src_addr.into(), dest_addr.into(), protocol);
    Ok(network_tuple)
}

/// Build the IPv4 reassembly key (src, dst, protocol, identification).
fn convert_id_key(packet: &Ipv4Packet) -> IdKey {
    let src_ip = packet.get_source();
    let dest_ip = packet.get_destination();
    let protocol = packet.get_next_level_protocol();
    let identification = packet.get_identification();
    IdKey::new(src_ip.into(), dest_ip.into(), protocol, identification)
}

/// Unifies the two channel flavors used by the stack: tokio mpsc (TCP paths)
/// and flume mpmc (UDP/raw paths).
enum SenderBox<T> {
    Mpsc(Sender<T>),
    Mpmc(flume::Sender<T>),
}

impl<T> SenderBox<T> {
    /// Returns `true` when the receiver accepted the value.
    async fn send(&self, t: T) -> bool {
        match self {
            SenderBox::Mpsc(sender) => sender.send(t).await.is_ok(),
            SenderBox::Mpmc(sender) => sender.send_async(t).await.is_ok(),
        }
    }
}
/// Thread-safe wrapper around the stack's routing table.
#[derive(Clone, Default, Debug)]
pub struct SafeRoutes {
    routes: Arc<Mutex<Routes>>,
}
impl SafeRoutes {
    /// Reject binds to addresses the stack does not own. An empty address
    /// list means "accept any address" (see `Routes::exists_v4/v6`).
    pub(crate) fn check_bind_ip(&self, ip: IpAddr) -> io::Result<()> {
        if check_ip(ip).is_ok() && !self.exists_ip(&ip) {
            return Err(io::Error::new(io::ErrorKind::AddrNotAvailable, "cannot assign requested address"));
        }
        Ok(())
    }
    pub(crate) fn exists_ip(&self, ip: &IpAddr) -> bool {
        match ip {
            IpAddr::V4(ip) => self.exists_v4(ip),
            IpAddr::V6(ip) => self.exists_v6(ip),
        }
    }
    pub(crate) fn exists_v4(&self, ip: &Ipv4Addr) -> bool {
        self.routes.lock().exists_v4(ip)
    }
    pub(crate) fn exists_v6(&self, ip: &Ipv6Addr) -> bool {
        self.routes.lock().exists_v6(ip)
    }
    pub fn ipv4_list(&self) -> Vec<Ipv4Addr> {
        self.routes.lock().v4_list.clone()
    }
    pub fn ipv6_list(&self) -> Vec<Ipv6Addr> {
        self.routes.lock().v6_list.clone()
    }
    /// Pick the local IP to use when sending towards `dst`.
    pub fn route(&self, dst: IpAddr) -> Option<IpAddr> {
        match dst {
            IpAddr::V4(ip) => self.route_v4(ip).map(|v| v.into()),
            IpAddr::V6(ip) => self.route_v6(ip).map(|v| v.into()),
        }
    }
    pub fn route_v4(&self, dst: Ipv4Addr) -> Option<Ipv4Addr> {
        self.routes.lock().route_v4(dst)
    }
    pub fn route_v6(&self, dst: Ipv6Addr) -> Option<Ipv6Addr> {
        self.routes.lock().route_v6(dst)
    }
    pub fn add_v4(&self, dest: Ipv4Addr, mask: Ipv4Addr, ip: Ipv4Addr) -> io::Result<()> {
        self.routes.lock().add_v4(dest, mask, ip)
    }
    pub fn add_v6(&self, dest: Ipv6Addr, mask: Ipv6Addr, ip: Ipv6Addr) -> io::Result<()> {
        self.routes.lock().add_v6(dest, mask, ip)
    }
    pub fn remove_v4(&self, dest: Ipv4Addr, mask: Ipv4Addr) -> io::Result<()> {
        self.routes.lock().remove_v4(dest, mask)
    }
    pub fn remove_v6(&self, dest: Ipv6Addr, mask: Ipv6Addr) -> io::Result<()> {
        self.routes.lock().remove_v6(dest, mask)
    }
    pub fn clear_v4(&self) {
        self.routes.lock().clear_v4()
    }
    pub fn clear_v6(&self) {
        self.routes.lock().clear_v6()
    }
    pub fn set_default_v4(&self, ip: Ipv4Addr) {
        self.routes.lock().set_default_v4(ip)
    }
    pub fn set_default_v6(&self, ip: Ipv6Addr) {
        self.routes.lock().set_default_v6(ip)
    }
    pub fn default_v4(&self) -> Option<Ipv4Addr> {
        self.routes.lock().default_v4()
    }
    pub fn default_v6(&self) -> Option<Ipv6Addr> {
        self.routes.lock().default_v6()
    }
}

/// Plain routing table: the list of local addresses plus longest-prefix-first
/// route entries of (masked destination, mask, local ip).
#[derive(Default, Debug)]
struct Routes {
    v4_list: Vec<Ipv4Addr>,
    default_v4: Option<Ipv4Addr>,
    v4_table: Vec<(u32, u32, Ipv4Addr)>,
    v6_list: Vec<Ipv6Addr>,
    default_v6: Option<Ipv6Addr>,
    v6_table: Vec<(u128, u128, Ipv6Addr)>,
}
impl Routes {
    /// An empty address list accepts every address (unconfigured stack).
    fn exists_v4(&self, ip: &Ipv4Addr) -> bool {
        if self.v4_list.is_empty() {
            return true;
        }
        self.v4_list.contains(ip)
    }
    fn exists_v6(&self, ip: &Ipv6Addr) -> bool {
        if self.v6_list.is_empty() {
            return true;
        }
        self.v6_list.contains(ip)
    }
    /// First match wins; the table is kept sorted longest-mask-first by `add_v4`.
    fn route_v4(&self, dst: Ipv4Addr) -> Option<Ipv4Addr> {
        let dst = u32::from(dst);
        for (dest_cur, mask_cur, ip_cur) in self.v4_table.iter() {
            if dst & *mask_cur == *dest_cur {
                return Some(*ip_cur);
            }
        }
        self.default_v4
    }
    fn route_v6(&self, dst: Ipv6Addr) -> Option<Ipv6Addr> {
        let dst = u128::from(dst);
        for (dest_cur, mask_cur, ip_cur) in self.v6_table.iter() {
            if dst & *mask_cur == *dest_cur {
                return Some(*ip_cur);
            }
        }
        self.default_v6
    }
    /// Add or update a route; rejects non-contiguous masks and keeps the
    /// table sorted with the longest (most specific) masks first.
    fn add_v4(&mut self, dest: Ipv4Addr, mask: Ipv4Addr, ip: Ipv4Addr) -> io::Result<()> {
        let mask = u32::from(mask);
        // A valid mask has all its one-bits leading (e.g. /24 = 0xFFFFFF00).
        if mask.count_ones() != mask.leading_ones() {
            return Err(io::Error::new(io::ErrorKind::InvalidInput, "invalid mask"));
        }
        if !self.v4_list.contains(&ip) {
            self.v4_list.push(ip);
        }
        let dest = u32::from(dest) & mask;
        for (dest_cur, mask_cur, ip_cur) in self.v4_table.iter_mut() {
            if dest == *dest_cur && mask == *mask_cur {
                *ip_cur = ip;
                return Ok(());
            }
        }
        self.v4_table.push((dest, mask, ip));
        self.v4_table.sort_by(|a, b| b.1.cmp(&a.1));
        Ok(())
    }
    fn add_v6(&mut self, dest: Ipv6Addr, mask: Ipv6Addr, ip: Ipv6Addr) -> io::Result<()> {
        let mask = u128::from(mask);
        if mask.count_ones() != mask.leading_ones() {
            return Err(io::Error::new(io::ErrorKind::InvalidInput, "invalid mask"));
        }
        if !self.v6_list.contains(&ip) {
            self.v6_list.push(ip);
        }
        let dest = u128::from(dest) & mask;
        for (dest_cur, mask_cur, ip_cur) in self.v6_table.iter_mut() {
            if dest == *dest_cur && mask == *mask_cur {
                *ip_cur = ip;
                return Ok(());
            }
        }
        self.v6_table.push((dest, mask, ip));
        self.v6_table.sort_by(|a, b| b.1.cmp(&a.1));
        Ok(())
    }
    fn remove_v4(&mut self, dest: Ipv4Addr, mask: Ipv4Addr) -> io::Result<()> {
        let mask = u32::from(mask);
        let dest = u32::from(dest) & mask;
        let len = self.v4_table.len();

        self.v4_table
            .retain(|(dest_cur, mask_cur, _)| !(dest == *dest_cur && mask == *mask_cur));
        if len == self.v4_table.len() {
            Err(io::Error::new(io::ErrorKind::NotFound, "not found route"))
        } else {
            // NOTE(review): rebuilding v4_list from v4_table drops an address
            // that was added only via `set_default_v4` — confirm this is intended.
            self.v4_list = self.v4_table.iter().map(|v| v.2).collect();
            Ok(())
        }
    }
    fn remove_v6(&mut self, dest: Ipv6Addr, mask: Ipv6Addr) -> io::Result<()> {
        let mask = u128::from(mask);
        let dest = u128::from(dest) & mask;
        let len = self.v6_table.len();
        self.v6_table
            .retain(|(dest_cur, mask_cur, _)| !(dest == *dest_cur && mask == *mask_cur));
        if len == self.v6_table.len() {
            Err(io::Error::new(io::ErrorKind::NotFound, "not found route"))
        } else {
            // NOTE(review): same default-address caveat as `remove_v4`.
            self.v6_list = self.v6_table.iter().map(|v| v.2).collect();
            Ok(())
        }
    }
    fn clear_v4(&mut self) {
        self.v4_table.clear();
    }
    fn clear_v6(&mut self) {
        self.v6_table.clear();
    }
    fn set_default_v4(&mut self, ip: Ipv4Addr) {
        if !self.v4_list.contains(&ip) {
            self.v4_list.push(ip);
        }
        self.default_v4 = Some(ip)
    }
    fn set_default_v6(&mut self, ip: Ipv6Addr) {
        if !self.v6_list.contains(&ip) {
            self.v6_list.push(ip);
        }
        self.default_v6 = Some(ip)
    }
    fn default_v4(&self) -> Option<Ipv4Addr> {
        self.default_v4
    }
    fn default_v6(&self) -> Option<Ipv6Addr> {
        self.default_v6
    }
}

impl IpStackInner {
    /// Reserve (protocol, addr); with `set_port` a free port may be assigned.
    /// (Body continues beyond this chunk of the file.)
    fn add_bind_addr(&self, protocol: IpNextHeaderProtocol, mut addr: SocketAddr, set_port: bool) -> io::Result<SocketAddr> {
        let mut guard =
self.bind_addrs.lock(); 1181 | if set_port && addr.port() == 0 { 1182 | let port_start: u16 = rand::rng().random_range(1..=65535); 1183 | for i in 0..65535 { 1184 | let port = port_start.wrapping_add(i); 1185 | if port == 0 { 1186 | continue; 1187 | } 1188 | addr.set_port(port); 1189 | if !guard.contains(&(protocol, addr)) { 1190 | guard.insert((protocol, addr)); 1191 | return Ok(addr); 1192 | } 1193 | } 1194 | return Err(io::Error::new(io::ErrorKind::AddrInUse, "Address already in use")); 1195 | } 1196 | if guard.contains(&(protocol, addr)) { 1197 | return Err(io::Error::new(io::ErrorKind::AddrInUse, "Address already in use")); 1198 | } 1199 | guard.insert((protocol, addr)); 1200 | Ok(addr) 1201 | } 1202 | fn remove_bind_addr(&self, protocol: IpNextHeaderProtocol, addr: SocketAddr) { 1203 | let mut guard = self.bind_addrs.lock(); 1204 | guard.remove(&(protocol, addr)); 1205 | } 1206 | } 1207 | #[derive(Debug)] 1208 | pub(crate) struct BindAddr { 1209 | protocol: IpNextHeaderProtocol, 1210 | pub(crate) addr: SocketAddr, 1211 | inner: Arc, 1212 | } 1213 | impl Drop for BindAddr { 1214 | fn drop(&mut self) { 1215 | self.inner.remove_bind_addr(self.protocol, self.addr); 1216 | } 1217 | } 1218 | 1219 | #[cfg(feature = "global-ip-stack")] 1220 | lazy_static::lazy_static! { 1221 | static ref IP_STACK: Mutex> = Mutex::new(None); 1222 | } 1223 | #[cfg(feature = "global-ip-stack")] 1224 | impl IpStack { 1225 | pub fn get() -> io::Result { 1226 | if let Some(v) = IP_STACK.lock().clone() { 1227 | Ok(v) 1228 | } else { 1229 | Err(io::Error::new(io::ErrorKind::Other, "Not initialized IpStack")) 1230 | } 1231 | } 1232 | pub fn release() { 1233 | _ = IP_STACK.lock().take(); 1234 | } 1235 | pub(crate) fn set(ip_stack: IpStack) { 1236 | IP_STACK.lock().replace(ip_stack); 1237 | } 1238 | } 1239 | -------------------------------------------------------------------------------- /src/lib.rs: -------------------------------------------------------------------------------- 1 | /*! 
2 | # Example 3 | ```no_run 4 | #[tokio::main] 5 | pub async fn main() -> std::io::Result<()> { 6 | use tokio::io::AsyncReadExt; 7 | let (ip_stack, _ip_stack_send, mut ip_stack_recv) = tcp_ip::ip_stack(tcp_ip::IpStackConfig::default())?; 8 | tokio::spawn(async move { 9 | loop { 10 | // ip_stack_send.send_ip_packet() 11 | todo!("Send IP packets to the protocol stack using 'ip_stack_send'") 12 | } 13 | }); 14 | tokio::spawn(async move { 15 | let mut buf = [0; 65535]; 16 | loop { 17 | match ip_stack_recv.recv(&mut buf).await { 18 | Ok(_len) => {} 19 | Err(e) => println!("{e:?}"), 20 | } 21 | todo!("Receive IP packets from the protocol stack using 'ip_stack_recv'") 22 | } 23 | }); 24 | let mut tcp_listener = tcp_ip::tcp::TcpListener::bind(ip_stack.clone(), "0.0.0.0:80".parse().unwrap()).await?; 25 | loop { 26 | let (mut tcp_stream, addr) = tcp_listener.accept().await?; 27 | tokio::spawn(async move { 28 | let mut buf = [0; 1024]; 29 | match tcp_stream.read(&mut buf).await { 30 | Ok(len) => println!("read:{:?},addr={addr}", &buf[..len]), 31 | Err(e) => println!("{e:?}"), 32 | } 33 | }); 34 | } 35 | } 36 | ``` 37 | */ 38 | 39 | mod buffer; 40 | pub mod icmp; 41 | mod ip_stack; 42 | pub use ip_stack::*; 43 | pub mod address; 44 | pub mod ip; 45 | pub mod tcp; 46 | pub mod udp; 47 | -------------------------------------------------------------------------------- /src/tcp/mod.rs: -------------------------------------------------------------------------------- 1 | use std::collections::HashMap; 2 | use std::io; 3 | use std::io::Error; 4 | use std::net::SocketAddr; 5 | use std::pin::Pin; 6 | use std::task::{Context, Poll}; 7 | 8 | use bytes::{Buf, BytesMut}; 9 | use pnet_packet::ip::IpNextHeaderProtocols; 10 | use pnet_packet::tcp::TcpFlags::{ACK, RST, SYN}; 11 | use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; 12 | use tokio::sync::mpsc::{channel, Receiver}; 13 | use tokio_util::sync::PollSender; 14 | 15 | pub use tcb::TcpConfig; 16 | 17 | use crate::address::ToSocketAddr; 18 
| use crate::ip_stack::{check_addr, check_ip, default_addr, BindAddr, IpStack, NetworkTuple, TransportPacket}; 19 | use crate::tcp::sys::{ReadNotify, TcpStreamTask}; 20 | use crate::tcp::tcb::Tcb; 21 | 22 | mod sys; 23 | mod tcb; 24 | mod tcp_queue; 25 | 26 | /// A TCP socket server, listening for connections. 27 | /// You can accept a new connection by using the accept method. 28 | /// # Example 29 | /// ```no_run 30 | /// use std::io; 31 | /// 32 | /// async fn process_socket(socket: T) { 33 | /// // do work with socket here 34 | /// } 35 | /// 36 | /// #[tokio::main] 37 | /// #[cfg(not(feature = "global-ip-stack"))] 38 | /// async fn main() -> io::Result<()> { 39 | /// let (ip_stack, _ip_stack_send, _ip_stack_recv) = 40 | /// tcp_ip::ip_stack(tcp_ip::IpStackConfig::default())?; 41 | /// // Read and write IP packets using _ip_stack_send and _ip_stack_recv 42 | /// let src = "10.0.0.2:8080".parse().unwrap(); 43 | /// let mut listener = tcp_ip::tcp::TcpListener::bind(ip_stack.clone(),src).await?; 44 | /// 45 | /// loop { 46 | /// let (socket, _) = listener.accept().await?; 47 | /// process_socket(socket).await; 48 | /// } 49 | /// } 50 | /// ``` 51 | pub struct TcpListener { 52 | _bind_addr: Option, 53 | ip_stack: IpStack, 54 | packet_receiver: Receiver, 55 | local_addr: Option, 56 | tcb_map: HashMap, 57 | } 58 | 59 | /// A TCP stream between a local and a remote socket. 
60 | /// 61 | /// # Example 62 | /// ```no_run 63 | /// #[tokio::main] 64 | /// #[cfg(not(feature = "global-ip-stack"))] 65 | /// async fn main() -> std::io::Result<()> { 66 | /// // Connect to a peer 67 | /// use tokio::io::AsyncWriteExt; 68 | /// let (ip_stack, _ip_stack_send, _ip_stack_recv) = 69 | /// tcp_ip::ip_stack(tcp_ip::IpStackConfig::default())?; 70 | /// // Read and write IP packets using _ip_stack_send and _ip_stack_recv 71 | /// let src = "10.0.0.2:8080".parse().unwrap(); 72 | /// let dst = "10.0.0.3:8080".parse().unwrap(); 73 | /// let mut stream = tcp_ip::tcp::TcpStream::bind(ip_stack.clone(),src)? 74 | /// .connect(dst).await?; 75 | /// 76 | /// // Write some data. 77 | /// stream.write_all(b"hello world!").await?; 78 | /// 79 | /// Ok(()) 80 | /// } 81 | /// ``` 82 | pub struct TcpStream { 83 | bind_addr: Option, 84 | ip_stack: Option, 85 | local_addr: SocketAddr, 86 | peer_addr: Option, 87 | read: Option, 88 | write: Option, 89 | } 90 | 91 | pub struct TcpStreamReadHalf { 92 | read_notify: ReadNotify, 93 | last_buf: Option, 94 | payload_receiver: Receiver, 95 | } 96 | 97 | pub struct TcpStreamWriteHalf { 98 | mss: usize, 99 | payload_sender: PollSender, 100 | } 101 | #[cfg(feature = "global-ip-stack")] 102 | impl TcpListener { 103 | pub async fn bind_all() -> io::Result { 104 | Self::bind0(IpStack::get()?, None).await 105 | } 106 | pub async fn bind(local_addr: A) -> io::Result { 107 | let ip_stack = IpStack::get()?; 108 | let local_addr = local_addr.to_addr()?; 109 | ip_stack.routes().check_bind_ip(local_addr.ip())?; 110 | Self::bind0(ip_stack, Some(local_addr)).await 111 | } 112 | } 113 | #[cfg(not(feature = "global-ip-stack"))] 114 | impl TcpListener { 115 | pub async fn bind_all(ip_stack: IpStack) -> io::Result { 116 | Self::bind0(ip_stack, None).await 117 | } 118 | pub async fn bind(ip_stack: IpStack, local_addr: A) -> io::Result { 119 | let local_addr = local_addr.to_addr()?; 120 | ip_stack.routes().check_bind_ip(local_addr.ip())?; 121 | 
Self::bind0(ip_stack, Some(local_addr)).await 122 | } 123 | } 124 | impl TcpListener { 125 | async fn bind0(ip_stack: IpStack, mut local_addr: Option) -> io::Result { 126 | let (packet_sender, packet_receiver) = channel(ip_stack.config.tcp_syn_channel_size); 127 | let _bind_addr = if let Some(addr) = &mut local_addr { 128 | Some(ip_stack.bind(IpNextHeaderProtocols::Tcp, addr)?) 129 | } else { 130 | None 131 | }; 132 | ip_stack.add_tcp_listener(local_addr, packet_sender)?; 133 | Ok(Self { 134 | _bind_addr, 135 | ip_stack, 136 | packet_receiver, 137 | local_addr, 138 | tcb_map: Default::default(), 139 | }) 140 | } 141 | pub fn local_addr(&self) -> io::Result { 142 | self.local_addr.ok_or_else(|| io::Error::from(io::ErrorKind::NotFound)) 143 | } 144 | pub async fn accept(&mut self) -> io::Result<(TcpStream, SocketAddr)> { 145 | loop { 146 | if let Some(packet) = self.packet_receiver.recv().await { 147 | let network_tuple = &packet.network_tuple; 148 | if let Some(v) = self.ip_stack.inner.tcp_stream_map.get(network_tuple).as_deref().cloned() { 149 | // If a TCP stream has already been generated, hand it over to the corresponding stream 150 | _ = v.send(packet).await; 151 | continue; 152 | } 153 | let Some(tcp_packet) = pnet_packet::tcp::TcpPacket::new(&packet.buf) else { 154 | return Err(Error::new(io::ErrorKind::InvalidInput, "not tcp")); 155 | }; 156 | let acknowledgement = tcp_packet.get_acknowledgement(); 157 | let sequence = tcp_packet.get_sequence(); 158 | let local_addr = network_tuple.dst; 159 | let peer_addr = network_tuple.src; 160 | if tcp_packet.get_flags() & SYN == SYN { 161 | // LISTEN -> SYN_RECEIVED 162 | let tcp_config = self.ip_stack.config.tcp_config; 163 | let mut tcb = Tcb::new_listen(local_addr, peer_addr, tcp_config); 164 | if let Some(relay_packet) = tcb.try_syn_received(&tcp_packet) { 165 | self.ip_stack.send_packet(relay_packet).await?; 166 | self.tcb_map.insert(*network_tuple, tcb); 167 | continue; 168 | } 169 | } else if let Some(tcb) = 
self.tcb_map.get_mut(network_tuple) { 170 | // SYN_RECEIVED -> ESTABLISHED 171 | if tcb.try_syn_received_to_established(packet.buf) { 172 | let tcb = self.tcb_map.remove(network_tuple).unwrap(); 173 | return Ok((TcpStream::new(self.ip_stack.clone(), tcb)?, peer_addr)); 174 | } 175 | if tcb.is_close() { 176 | self.tcb_map.remove(network_tuple).unwrap(); 177 | } 178 | } else if tcp_packet.get_flags() & RST == RST { 179 | continue; 180 | } 181 | let data = tcb::create_transport_packet_raw( 182 | &local_addr, 183 | &peer_addr, 184 | acknowledgement, 185 | sequence.wrapping_add(1), 186 | 0, 187 | RST | ACK, 188 | &[], 189 | ); 190 | self.ip_stack.send_packet(data).await?; 191 | } else { 192 | return Err(Error::from(io::ErrorKind::UnexpectedEof)); 193 | } 194 | } 195 | } 196 | } 197 | #[cfg(feature = "global-ip-stack")] 198 | impl TcpStream { 199 | pub fn bind(local_addr: A) -> io::Result { 200 | let ip_stack = IpStack::get()?; 201 | let mut local_addr = local_addr.to_addr()?; 202 | ip_stack.routes().check_bind_ip(local_addr.ip())?; 203 | let bind_addr = ip_stack.bind(IpNextHeaderProtocols::Tcp, &mut local_addr)?; 204 | Ok(Self::new_uncheck(Some(bind_addr), Some(ip_stack), local_addr, None, None, None)) 205 | } 206 | pub async fn connect(dest: A) -> io::Result { 207 | let dest = dest.to_addr()?; 208 | TcpStream::bind(default_addr(dest.is_ipv4()))?.connect_to(dest).await 209 | } 210 | } 211 | #[cfg(not(feature = "global-ip-stack"))] 212 | impl TcpStream { 213 | pub fn bind(ip_stack: IpStack, local_addr: A) -> io::Result { 214 | let mut local_addr = local_addr.to_addr()?; 215 | ip_stack.routes().check_bind_ip(local_addr.ip())?; 216 | let bind_addr = ip_stack.bind(IpNextHeaderProtocols::Tcp, &mut local_addr)?; 217 | Ok(Self::new_uncheck(Some(bind_addr), Some(ip_stack), local_addr, None, None, None)) 218 | } 219 | pub async fn connect(ip_stack: IpStack, dest: A) -> io::Result { 220 | let dest = dest.to_addr()?; 221 | TcpStream::bind(ip_stack, 
default_addr(dest.is_ipv4()))?.connect_to(dest).await
    }
}
impl TcpStream {
    /// Connects this bound stream to `dest`.
    ///
    /// When the bound source IP is unspecified, a source address is picked
    /// from the routing table; mixing IPv4 and IPv6 is rejected.
    pub async fn connect_to<A: ToSocketAddr>(self, dest: A) -> io::Result<Self> {
        let dest = dest.to_addr()?;
        check_addr(dest)?;
        let Some(ip_stack) = self.ip_stack else {
            return Err(Error::new(io::ErrorKind::AlreadyExists, "transport endpoint is already connected"));
        };
        let mut src = self.local_addr;
        if src.is_ipv4() != dest.is_ipv4() {
            return Err(Error::new(io::ErrorKind::InvalidInput, "address error"));
        }
        if let Err(e) = check_ip(src.ip()) {
            // Unspecified source address: fall back to the route table.
            if let Some(v) = ip_stack.routes().route(dest.ip()) {
                src.set_ip(v);
            } else {
                Err(e)?
            }
        }

        Self::connect0(self.bind_addr, ip_stack, src, dest).await
    }
    pub fn local_addr(&self) -> io::Result<SocketAddr> {
        Ok(self.local_addr)
    }
    pub fn peer_addr(&self) -> io::Result<SocketAddr> {
        if let Some(v) = self.peer_addr {
            Ok(v)
        } else {
            Err(Error::from(io::ErrorKind::NotConnected))
        }
    }
    /// Splits the stream into independently usable write and read halves.
    pub fn split(self) -> io::Result<(TcpStreamWriteHalf, TcpStreamReadHalf)> {
        match (self.write, self.read) {
            (Some(write), Some(read)) => Ok((write, read)),
            _ => Err(Error::from(io::ErrorKind::NotConnected)),
        }
    }
}

impl TcpStream {
    fn as_mut_read(&mut self) -> io::Result<&mut TcpStreamReadHalf> {
        if let Some(v) = self.read.as_mut() {
            Ok(v)
        } else {
            Err(Error::from(io::ErrorKind::NotConnected))
        }
    }
    fn as_mut_write(&mut self) -> io::Result<&mut TcpStreamWriteHalf> {
        if let Some(v) = self.write.as_mut() {
            Ok(v)
        } else {
            Err(Error::from(io::ErrorKind::NotConnected))
        }
    }
    /// Active open: creates the TCB and channels, runs the SYN handshake,
    /// then spawns the background stream task.
    pub(crate) async fn connect0(
        bind_addr: Option<BindAddr>,
        ip_stack: IpStack,
        local_addr: SocketAddr,
        peer_addr: SocketAddr,
    ) -> io::Result<Self> {
        let (payload_sender_w, payload_receiver_w) = channel(ip_stack.config.tcp_channel_size);
        let (payload_sender, payload_receiver) = channel(ip_stack.config.tcp_channel_size);
        let (packet_sender, packet_receiver) = channel(ip_stack.config.tcp_channel_size);
        let network_tuple = NetworkTuple::new(peer_addr, local_addr, IpNextHeaderProtocols::Tcp);
        ip_stack.add_tcp_socket(network_tuple, packet_sender)?;
        let mut tcp_config = ip_stack.config.tcp_config;
        if tcp_config.mss.is_none() {
            // Derive MSS from the configured MTU when not set explicitly.
            tcp_config.mss.replace(ip_stack.config.mtu - tcb::IP_TCP_HEADER_LEN as u16);
        }
        // BUG FIX: pass the adjusted `tcp_config`; previously the unmodified
        // `ip_stack.config.tcp_config` was passed, silently discarding the
        // MSS derived above.
        let tcb = Tcb::new_listen(local_addr, peer_addr, tcp_config);
        let mut stream_task = TcpStreamTask::new(bind_addr, tcb, ip_stack, payload_sender, payload_receiver_w, packet_receiver);
        stream_task.connect().await?;
        let read_notify = stream_task.read_notify();
        let mss = stream_task.mss() as usize;
        tokio::spawn(async move {
            if let Err(e) = stream_task.run().await {
                log::warn!("stream_task run {local_addr}->{peer_addr}: {e:?}")
            }
        });
        let read = TcpStreamReadHalf {
            read_notify,
            last_buf: None,
            payload_receiver,
        };
        let write = TcpStreamWriteHalf {
            mss,
            payload_sender: PollSender::new(payload_sender_w),
        };
        let stream = Self::new_uncheck(None, None, local_addr, Some(peer_addr), Some(read), Some(write));
        Ok(stream)
    }
    fn new_uncheck(
        bind_addr: Option<BindAddr>,
        ip_stack: Option<IpStack>,
        local_addr: SocketAddr,
        peer_addr: Option<SocketAddr>,
        read: Option<TcpStreamReadHalf>,
        write: Option<TcpStreamWriteHalf>,
    ) -> Self {
        Self {
            bind_addr,
            ip_stack,
            local_addr,
            peer_addr,
            read,
            write,
        }
    }
    /// Builds a stream from an already-established TCB without spawning its
    /// task; the caller drives the returned `TcpStreamTask`.
    pub(crate) fn new0(ip_stack: IpStack, tcb: Tcb) -> io::Result<(Self, TcpStreamTask)> {
        let peer_addr = tcb.peer_addr();
        let local_addr = tcb.local_addr();
        let (payload_sender_w, payload_receiver_w) = channel(ip_stack.config.tcp_channel_size);
        let (payload_sender, payload_receiver) = channel(ip_stack.config.tcp_channel_size);
        let (packet_sender, packet_receiver) = channel(ip_stack.config.tcp_channel_size);
        let network_tuple = NetworkTuple::new(peer_addr, local_addr, IpNextHeaderProtocols::Tcp);
        ip_stack.add_tcp_socket(network_tuple, packet_sender)?;
        let mss = tcb.mss() as usize;
        let stream_task = TcpStreamTask::new(None, tcb, ip_stack, payload_sender, payload_receiver_w, packet_receiver);
        let read_notify = stream_task.read_notify();
        let read = TcpStreamReadHalf {
            read_notify,
            last_buf: None,
            payload_receiver,
        };
        let write = TcpStreamWriteHalf {
            mss,
            payload_sender: PollSender::new(payload_sender_w),
        };
        let stream = Self::new_uncheck(None, None, local_addr, Some(peer_addr), Some(read), Some(write));
        Ok((stream, stream_task))
    }
    /// Like `new0`, but spawns the stream task onto the tokio runtime.
    pub(crate) fn new(ip_stack: IpStack, tcb: Tcb) -> io::Result<Self> {
        let peer_addr = tcb.peer_addr();
        let local_addr = tcb.local_addr();
        let (stream, mut stream_task) = Self::new0(ip_stack, tcb)?;
        tokio::spawn(async move {
            if let Err(e) = stream_task.run().await {
                log::warn!("stream_task run {local_addr}->{peer_addr}: {e:?}")
            }
        });
        Ok(stream)
    }
}

impl AsyncRead for TcpStream {
    fn poll_read(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<io::Result<()>> {
        Pin::new(self.as_mut_read()?).poll_read(cx, buf)
    }
}

impl AsyncRead for TcpStreamReadHalf {
    fn poll_read(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<io::Result<()>> {
        // Drain a previously partially-consumed payload first.
        if let Some(p) = self.last_buf.as_mut() {
            let len = buf.remaining().min(p.len());
            buf.put_slice(&p[..len]);
            p.advance(len);
            if p.is_empty() {
                self.last_buf.take();
                if self.try_read0(buf) {
                    self.read_notify.notify();
                }
            }
            return Poll::Ready(Ok(()));
        }
        let poll = self.payload_receiver.poll_recv(cx);
        match poll {
            Poll::Ready(None) => Poll::Ready(Ok(())),
            Poll::Ready(Some(mut p)) => {
                // An empty payload is the EOF marker from the stream task.
                if p.is_empty() {
                    self.payload_receiver.close();
                    return Poll::Ready(Ok(()));
                }
                let len = buf.remaining().min(p.len());
                buf.put_slice(&p[..len]);
                p.advance(len);
                if p.is_empty() {
                    self.try_read0(buf);
                } else {
                    self.last_buf.replace(p);
                }
                self.read_notify.notify();
                Poll::Ready(Ok(()))
            }
            Poll::Pending => Poll::Pending,
        }
    }
}

impl Drop for TcpStreamReadHalf {
    fn drop(&mut self) {
        self.payload_receiver.close();
        self.read_notify.close();
    }
}
impl TcpStreamReadHalf {
    /// Greedily pulls already-queued payloads into `buf` without waiting;
    /// returns whether anything was received.
    fn try_read0(&mut self, buf: &mut ReadBuf<'_>) -> bool {
        let mut rs = false;
        while buf.remaining() > 0 {
            let Ok(mut p) = self.payload_receiver.try_recv() else {
                break;
            };
            rs = true;
            if p.is_empty() {
                self.payload_receiver.close();
                break;
            }
            let len = buf.remaining().min(p.len());
            buf.put_slice(&p[..len]);
            p.advance(len);
            if !p.is_empty() {
                self.last_buf.replace(p);
            }
        }
        rs
    }
}

impl AsyncWrite for TcpStream {
    fn poll_write(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<io::Result<usize>> {
        Pin::new(self.as_mut_write()?).poll_write(cx, buf)
    }

    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        Pin::new(self.as_mut_write()?).poll_flush(cx)
    }

    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        Pin::new(self.as_mut_write()?).poll_shutdown(cx)
    }
}

impl AsyncWrite for TcpStreamWriteHalf {
    fn poll_write(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<io::Result<usize>> {
        if buf.is_empty() {
return Poll::Ready(Err(io::Error::from(io::ErrorKind::WriteZero)));
        }
        match self.payload_sender.poll_reserve(cx) {
            Poll::Ready(Ok(_)) => {
                // Cap a single chunk at 10 * MSS to bound per-send work.
                let len = buf.len().min(self.mss * 10);
                let buf = &buf[..len];
                match self.payload_sender.send_item(buf.into()) {
                    Ok(_) => {}
                    Err(_) => return Poll::Ready(Err(io::Error::from(io::ErrorKind::WriteZero))),
                };
                Poll::Ready(Ok(buf.len()))
            }
            Poll::Ready(Err(_)) => Poll::Ready(Err(io::Error::from(io::ErrorKind::WriteZero))),
            Poll::Pending => Poll::Pending,
        }
    }

    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        Poll::Ready(Ok(()))
    }

    fn poll_shutdown(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        // Closing the channel signals the stream task to send FIN.
        self.payload_sender.close();
        Poll::Ready(Ok(()))
    }
}

impl Drop for TcpListener {
    fn drop(&mut self) {
        self.ip_stack.remove_tcp_listener(&self.local_addr)
    }
}
--------------------------------------------------------------------------------
/src/tcp/sys.rs:
--------------------------------------------------------------------------------
use std::io;
use std::io::Error;
use std::ops::Add;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::time::Duration;

use bytes::{Buf, BytesMut};
use pnet_packet::ip::IpNextHeaderProtocols;
use tokio::sync::mpsc::error::TrySendError;
use tokio::sync::mpsc::{Receiver, Sender};
use tokio::sync::Notify;
use tokio::time::Instant;

use crate::ip_stack::{BindAddr, IpStack, NetworkTuple, TransportPacket};
use crate::tcp::tcb::Tcb;

/// Background driver for one connection: owns the TCB and shuttles data
/// between the IP stack and the application-layer channels.
#[derive(Debug)]
pub struct TcpStreamTask {
    // Keeps the port registration alive while the task runs.
    _bind_addr: Option<BindAddr>,
    quick_end: bool,
    tcb: Tcb,
    ip_stack: IpStack,
    application_layer_receiver: Receiver<BytesMut>,
    // Outbound data that did not fit the send window yet.
    last_buffer: Option<BytesMut>,
    packet_receiver: Receiver<TransportPacket>,
    // `None` once the read half has been closed.
    application_layer_sender: Option<Sender<BytesMut>>,
    write_half_closed: bool,
    retransmission: bool,
    read_notify: ReadNotify,
}

/// Wakeup handle used by the read half to prod the stream task when the
/// application has consumed data.
#[derive(Clone, Default, Debug)]
pub struct ReadNotify {
    readable: Arc<AtomicBool>,
    notify: Arc<Notify>,
}

impl ReadNotify {
    pub fn notify(&self) {
        // Only wake the task when it has flagged pending readable data.
        if self.readable.load(Ordering::Acquire) {
            self.notify.notify_one();
        }
    }
    pub fn close(&self) {
        self.notify.notify_one();
    }
    async fn notified(&self) {
        self.notify.notified().await
    }
    fn set_state(&self, readable: bool) {
        self.readable.store(readable, Ordering::Release);
    }
}

impl Drop for TcpStreamTask {
    fn drop(&mut self) {
        // Deregister this connection's demux entry.
        let peer_addr = self.tcb.peer_addr();
        let local_addr = self.tcb.local_addr();
        let network_tuple = NetworkTuple::new(peer_addr, local_addr, IpNextHeaderProtocols::Tcp);
        self.ip_stack.remove_tcp_socket(&network_tuple);
    }
}

impl TcpStreamTask {
    pub fn new(
        _bind_addr: Option<BindAddr>,
        tcb: Tcb,
        ip_stack: IpStack,
        application_layer_sender: Sender<BytesMut>,
        application_layer_receiver: Receiver<BytesMut>,
        packet_receiver: Receiver<TransportPacket>,
    ) -> Self {
        Self {
            _bind_addr,
            quick_end: ip_stack.config.tcp_config.quick_end,
            tcb,
            ip_stack,
            application_layer_receiver,
            last_buffer: None,
            packet_receiver,
            application_layer_sender: Some(application_layer_sender),
            write_half_closed: false,
            retransmission: false,
            read_notify: Default::default(),
        }
    }
    pub fn read_notify(&self) -> ReadNotify {
        self.read_notify.clone()
    }
}

impl TcpStreamTask {
    pub async fn run(&mut self) -> io::Result<()> {
        let result = self.run0().await;
        // Deliver anything still readable before the task exits.
        self.push_application_layer();
        result
    }
    /// Main event loop: multiplexes inbound segments, outbound payloads,
    /// timeouts and read-notify wakeups until the connection closes.
    pub async fn run0(&mut self) -> io::Result<()> {
        loop {
            if self.tcb.is_close() {
                return Ok(());
            }
            if self.quick_end && self.read_half_closed() &&
self.write_half_closed { 105 | return Ok(()); 106 | } 107 | if !self.write_half_closed && !self.retransmission { 108 | self.flush().await?; 109 | } 110 | let data = self.recv_data().await; 111 | 112 | match data { 113 | TaskRecvData::In(mut buf) => { 114 | let mut count = 0; 115 | loop { 116 | if let Some(reply_packet) = self.tcb.push_packet(buf) { 117 | self.send_packet(reply_packet).await?; 118 | } 119 | 120 | if self.tcb.is_close() { 121 | return Ok(()); 122 | } 123 | if !self.tcb.readable_state() { 124 | break; 125 | } 126 | count += 1; 127 | if count >= 10 { 128 | break; 129 | } 130 | if let Some(v) = self.try_recv_in() { 131 | buf = v 132 | } else { 133 | break; 134 | } 135 | } 136 | self.push_application_layer(); 137 | // if self.tcb.readable_state() && self.application_layer_sender.is_some() && self.tcb.readable() && self.tcb.recv_busy() { 138 | // // The window is too small and requires blocking to wait; otherwise, it will lead to severe packet loss 139 | // self.read_notify.notified().await; 140 | // self.push_application_layer(); 141 | // } 142 | } 143 | TaskRecvData::Out(buf) => { 144 | self.write(buf).await?; 145 | } 146 | TaskRecvData::InClose => return Err(Error::new(io::ErrorKind::Other, "NetworkDown")), 147 | TaskRecvData::OutClose => { 148 | assert!(self.last_buffer.is_none()); 149 | self.write_half_closed = true; 150 | let packet = self.tcb.fin_packet(); 151 | self.send_packet(packet).await?; 152 | self.tcb.sent_fin(); 153 | } 154 | TaskRecvData::Timeout => { 155 | self.tcb.timeout(); 156 | if self.tcb.is_close() { 157 | return Ok(()); 158 | } 159 | if self.tcb.cannot_write() { 160 | let packet = self.tcb.fin_packet(); 161 | self.send_packet(packet).await?; 162 | } 163 | if self.read_half_closed() && self.write_half_closed { 164 | return Ok(()); 165 | } 166 | } 167 | TaskRecvData::ReadNotify => { 168 | self.push_application_layer(); 169 | self.try_send_ack().await?; 170 | } 171 | } 172 | self.retransmission = self.try_retransmission().await?; 173 
| self.try_send_ack().await?; 174 | self.tcb.perform_post_ack_action(); 175 | if !self.read_half_closed() && self.tcb.cannot_read() { 176 | self.close_read(); 177 | } 178 | } 179 | } 180 | async fn send_packet(&mut self, transport_packet: TransportPacket) -> io::Result<()> { 181 | self.ip_stack.send_packet(transport_packet).await?; 182 | self.tcb.perform_post_ack_action(); 183 | Ok(()) 184 | } 185 | fn read_half_closed(&self) -> bool { 186 | if let Some(v) = self.application_layer_sender.as_ref() { 187 | v.is_closed() 188 | } else { 189 | true 190 | } 191 | } 192 | pub fn mss(&self) -> u16 { 193 | self.tcb.mss() 194 | } 195 | fn only_recv_in(&self) -> bool { 196 | self.retransmission || self.last_buffer.is_some() || self.write_half_closed || self.tcb.limit() 197 | } 198 | fn push_application_layer(&mut self) { 199 | if let Some(sender) = self.application_layer_sender.as_ref() { 200 | let mut read_half_closed = false; 201 | while self.tcb.readable() { 202 | match sender.try_reserve() { 203 | Ok(sender) => { 204 | if let Some(buffer) = self.tcb.read() { 205 | sender.send(buffer); 206 | } 207 | } 208 | Err(e) => match e { 209 | TrySendError::Full(_) => break, 210 | TrySendError::Closed(_) => { 211 | read_half_closed = true; 212 | break; 213 | } 214 | }, 215 | } 216 | self.read_notify.set_state(self.tcb.readable()); 217 | } 218 | if self.tcb.cannot_read() || read_half_closed { 219 | self.close_read(); 220 | } 221 | } else { 222 | self.tcb.read_none(); 223 | } 224 | } 225 | fn close_read(&mut self) { 226 | if let Some(sender) = self.application_layer_sender.take() { 227 | _ = sender.try_send(BytesMut::new()); 228 | } 229 | } 230 | async fn write_slice0(tcb: &mut Tcb, ip_stack: &IpStack, mut buf: &[u8]) -> io::Result { 231 | let len = buf.len(); 232 | while !buf.is_empty() { 233 | if let Some((packet, len)) = tcb.write(buf) { 234 | if len == 0 { 235 | break; 236 | } 237 | ip_stack.send_packet(packet).await?; 238 | tcb.perform_post_ack_action(); 239 | buf = &buf[len..]; 
240 | } else { 241 | break; 242 | } 243 | } 244 | Ok(len - buf.len()) 245 | } 246 | async fn write_slice(&mut self, buf: &[u8]) -> io::Result { 247 | Self::write_slice0(&mut self.tcb, &self.ip_stack, buf).await 248 | } 249 | async fn write(&mut self, mut buf: BytesMut) -> io::Result { 250 | let len = self.write_slice(&buf).await?; 251 | if len != buf.len() { 252 | // Buffer is full 253 | buf.advance(len); 254 | self.last_buffer.replace(buf); 255 | } 256 | Ok(len) 257 | } 258 | async fn flush(&mut self) -> io::Result<()> { 259 | if let Some(buf) = self.last_buffer.as_mut() { 260 | let len = Self::write_slice0(&mut self.tcb, &self.ip_stack, buf).await?; 261 | if buf.len() == len { 262 | self.last_buffer.take(); 263 | } else { 264 | buf.advance(len); 265 | } 266 | } 267 | Ok(()) 268 | } 269 | 270 | async fn try_retransmission(&mut self) -> io::Result { 271 | if self.write_half_closed { 272 | return Ok(false); 273 | } 274 | if let Some(v) = self.tcb.retransmission() { 275 | self.send_packet(v).await?; 276 | return Ok(true); 277 | } 278 | if self.tcb.no_inflight_packet() { 279 | return Ok(false); 280 | } 281 | if self.tcb.need_retransmission() { 282 | if let Some(v) = self.tcb.retransmission() { 283 | self.send_packet(v).await?; 284 | return Ok(true); 285 | } 286 | } 287 | Ok(false) 288 | } 289 | async fn try_send_ack(&mut self) -> io::Result<()> { 290 | if self.tcb.need_ack() { 291 | let packet = self.tcb.ack_packet(); 292 | self.ip_stack.send_packet(packet).await?; 293 | } 294 | Ok(()) 295 | } 296 | 297 | async fn recv_data(&mut self) -> TaskRecvData { 298 | let deadline = if let Some(v) = self.tcb.time_wait() { 299 | Some(v.into()) 300 | } else { 301 | self.tcb.write_timeout().map(|v| v.into()) 302 | }; 303 | 304 | if let Some(deadline) = deadline { 305 | if self.only_recv_in() { 306 | self.recv_in_timeout_at(deadline).await 307 | } else { 308 | self.recv_timeout_at(deadline).await 309 | } 310 | } else if self.write_half_closed { 311 | let timeout_at = 
                // Write half closed: bound the wait by the TIME_WAIT timeout
                // so the task cannot linger forever waiting for peer input.
                Instant::now().add(self.ip_stack.config.tcp_config.time_wait_timeout);
            self.recv_in_timeout_at(timeout_at).await
        } else {
            self.recv().await
        }
    }
    /// Waits for the next event with no deadline: an inbound packet, outbound
    /// application data, or a read-notify wakeup.
    async fn recv(&mut self) -> TaskRecvData {
        tokio::select! {
            rs=self.packet_receiver.recv()=>{
                rs.map(|v| TaskRecvData::In(v.buf)).unwrap_or(TaskRecvData::InClose)
            }
            rs=self.application_layer_receiver.recv()=>{
                rs.map(TaskRecvData::Out).unwrap_or(TaskRecvData::OutClose)
            }
            _=self.read_notify.notified()=>{
                TaskRecvData::ReadNotify
            }
        }
    }
    /// Like `recv`, but also resolves with `Timeout` at `deadline`.
    async fn recv_timeout_at(&mut self, deadline: Instant) -> TaskRecvData {
        tokio::select! {
            rs=self.packet_receiver.recv()=>{
                rs.map(|v| TaskRecvData::In(v.buf)).unwrap_or(TaskRecvData::InClose)
            }
            rs=self.application_layer_receiver.recv()=>{
                rs.map(TaskRecvData::Out).unwrap_or(TaskRecvData::OutClose)
            }
            _=tokio::time::sleep_until(deadline)=>{
                TaskRecvData::Timeout
            }
            _=self.read_notify.notified()=>{
                TaskRecvData::ReadNotify
            }
        }
    }

    /// Inbound-only wait: ignores application-layer output (used while the
    /// send path is blocked, closing, or retransmitting).
    async fn recv_in_timeout_at(&mut self, deadline: Instant) -> TaskRecvData {
        tokio::select! {
            rs=self.packet_receiver.recv()=>{
                rs.map(|v| TaskRecvData::In(v.buf)).unwrap_or(TaskRecvData::InClose)
            }
            _=tokio::time::sleep_until(deadline)=>{
                TaskRecvData::Timeout
            }
            _=self.read_notify.notified()=>{
                TaskRecvData::ReadNotify
            }
        }
    }
    async fn recv_in_timeout(&mut self, duration: Duration) -> TaskRecvData {
        self.recv_in_timeout_at(Instant::now().add(duration)).await
    }

    /// Non-blocking poll of the inbound packet channel.
    fn try_recv_in(&mut self) -> Option<BytesMut> {
        self.packet_receiver.try_recv().map(|v| v.buf).ok()
    }
}

impl TcpStreamTask {
    /// Active open: sends SYN with exponential backoff (100 ms doubling,
    /// capped at 3 s, at most 50 attempts) and completes the three-way
    /// handshake. Any failure maps to `ConnectionRefused`.
    pub async fn connect(&mut self) -> io::Result<()> {
        let mut count = 0;
        let mut time = 50;
        // try_syn_sent returns a SYN while the TCB is still in Listen/SynSent.
        while let Some(packet) = self.tcb.try_syn_sent() {
            count += 1;
            if count > 50 {
                break;
            }
            self.send_packet(packet).await?;
            time *= 2;
            // `continue` on Timeout re-enters the loop for another SYN;
            // every other arm resolves the connect attempt.
            return match self.recv_in_timeout(Duration::from_millis(time.min(3000))).await {
                TaskRecvData::In(buf) => {
                    if let Some(relay) = self.tcb.try_syn_sent_to_established(buf) {
                        // Handshake done: send the final ACK.
                        self.send_packet(relay).await?;
                        Ok(())
                    } else {
                        Err(io::Error::from(io::ErrorKind::ConnectionRefused))
                    }
                }
                TaskRecvData::InClose => Err(io::Error::from(io::ErrorKind::ConnectionRefused)),
                TaskRecvData::Timeout => continue,
                _ => {
                    // Inbound-only wait cannot yield Out/OutClose/ReadNotify.
                    unreachable!()
                }
            };
        }
        Err(io::Error::from(io::ErrorKind::ConnectionRefused))
    }
}

/// Events the stream task multiplexes over.
enum TaskRecvData {
    /// Packet arrived from the network.
    In(BytesMut),
    /// Data arrived from the application layer.
    Out(BytesMut),
    /// Reader consumed data; window/readable state may have changed.
    ReadNotify,
    /// Inbound packet channel closed.
    InClose,
    /// Application-layer channel closed.
    OutClose,
    /// Deadline expired.
    Timeout,
}

// ===== src/tcp/tcb.rs =====
use std::cmp::Ordering;
use std::collections::VecDeque;
use std::io;
use std::net::{IpAddr, SocketAddr};
use std::ops::{Add, Sub};
use
std::time::{Duration, Instant}; 7 | 8 | use bytes::{Buf, BufMut, BytesMut}; 9 | use pnet_packet::ip::IpNextHeaderProtocols; 10 | use pnet_packet::tcp::TcpFlags::{ACK, FIN, PSH, RST, SYN}; 11 | use pnet_packet::tcp::{TcpOptionNumber, TcpOptionNumbers, TcpPacket}; 12 | use pnet_packet::Packet; 13 | use rand::RngCore; 14 | 15 | use crate::buffer::FixedBuffer; 16 | use crate::ip_stack::{NetworkTuple, TransportPacket}; 17 | use crate::tcp::tcp_queue::{TcpOfoQueue, TcpReceiveQueue}; 18 | 19 | const IP_HEADER_LEN: usize = 20; 20 | const TCP_HEADER_LEN: usize = 20; 21 | pub const IP_TCP_HEADER_LEN: usize = IP_HEADER_LEN + TCP_HEADER_LEN; 22 | const MAX_DIFF: u32 = u32::MAX / 2; 23 | const MSS_MIN: u16 = 536; 24 | 25 | /// Enum representing the various states of a TCP connection. 26 | #[derive(Debug, Clone, Copy, PartialEq, Eq, num_enum::IntoPrimitive, num_enum::TryFromPrimitive)] 27 | #[repr(u8)] 28 | pub enum TcpState { 29 | /// The listening state, waiting for incoming connection requests. 30 | Listen, 31 | /// The state after sending a SYN message, awaiting acknowledgment. 32 | SynSent, 33 | /// The state after receiving a SYN+ACK message, awaiting final acknowledgment. 34 | SynReceived, 35 | /// The state after completing the three-way handshake; the connection is established. 36 | Established, 37 | /// The state where the connection is in the process of being closed (after sending FIN). 38 | FinWait1, 39 | /// The state where the other side has acknowledged the connection termination. 40 | FinWait2, 41 | /// The state after receiving a FIN message, waiting for acknowledgment from the other side. 42 | CloseWait, 43 | /// The state where the connection is actively closing (waiting for all data to be sent/acknowledged). 44 | Closing, 45 | /// The state where the sender has sent the final FIN message and is waiting for acknowledgment from the other side. 46 | LastAck, 47 | /// The state after both sides have sent FIN messages, indicating the connection is fully closed. 
48 | TimeWait, 49 | /// The state where the connection is completely closed. 50 | Closed, 51 | } 52 | 53 | #[derive(Debug)] 54 | pub struct Tcb { 55 | state: TcpState, 56 | local_addr: SocketAddr, 57 | peer_addr: SocketAddr, 58 | // Send snd_seq to the other party 59 | snd_seq: SeqNum, 60 | // Send snd_ack to the other party 61 | snd_ack: AckNum, 62 | last_snd_ack: AckNum, 63 | // Received ordered maximum seq 64 | // rcv_seq: SeqNum, 65 | // Received ack,Its starting point is snd_seq 66 | rcv_ack: AckNum, 67 | snd_wnd: u16, 68 | rcv_wnd: u16, 69 | mss: u16, 70 | sack_permitted: bool, 71 | snd_window_shift_cnt: u8, 72 | rcv_window_shift_cnt: u8, 73 | duplicate_ack_count: usize, 74 | tcp_receive_queue: TcpReceiveQueue, 75 | tcp_out_of_order_queue: TcpOfoQueue, 76 | back_seq: Option, 77 | inflight_packets: VecDeque, 78 | time_wait: Option, 79 | time_wait_timeout: Duration, 80 | write_timeout: Option, 81 | retransmission_timeout: Duration, 82 | timeout_count: (AckNum, usize), 83 | congestion_window: CongestionWindow, 84 | last_snd_wnd: u16, 85 | requires_ack_repeat: bool, 86 | } 87 | 88 | #[derive(Eq, PartialEq, Debug, Copy, Clone)] 89 | #[repr(transparent)] 90 | struct SeqNum(u32); 91 | 92 | type AckNum = SeqNum; 93 | 94 | impl From for SeqNum { 95 | fn from(value: u32) -> Self { 96 | Self(value) 97 | } 98 | } 99 | 100 | impl From for u32 { 101 | fn from(value: SeqNum) -> Self { 102 | value.0 103 | } 104 | } 105 | 106 | impl PartialOrd for SeqNum { 107 | fn partial_cmp(&self, other: &Self) -> Option { 108 | Some(self.cmp(other)) 109 | } 110 | } 111 | 112 | impl Ord for SeqNum { 113 | fn cmp(&self, other: &Self) -> Ordering { 114 | let diff = self.0.wrapping_sub(other.0); 115 | if diff == 0 { 116 | Ordering::Equal 117 | } else if diff < MAX_DIFF { 118 | Ordering::Greater 119 | } else { 120 | Ordering::Less 121 | } 122 | } 123 | } 124 | 125 | impl Add for SeqNum { 126 | type Output = SeqNum; 127 | 128 | fn add(self, rhs: Self) -> Self::Output { 129 | 
SeqNum(self.0.wrapping_add(rhs.0)) 130 | } 131 | } 132 | 133 | impl Sub for SeqNum { 134 | type Output = SeqNum; 135 | 136 | fn sub(self, rhs: Self) -> Self::Output { 137 | SeqNum(self.0.wrapping_sub(rhs.0)) 138 | } 139 | } 140 | 141 | impl SeqNum { 142 | fn add_num(self, n: u32) -> Self { 143 | SeqNum(self.0.wrapping_add(n)) 144 | } 145 | fn add_update(&mut self, n: u32) { 146 | self.0 = self.0.wrapping_add(n) 147 | } 148 | } 149 | 150 | #[derive(Debug)] 151 | pub(crate) struct UnreadPacket { 152 | seq: SeqNum, 153 | flags: u8, 154 | payload: BytesMut, 155 | } 156 | 157 | impl Eq for UnreadPacket {} 158 | 159 | impl PartialEq for UnreadPacket { 160 | fn eq(&self, other: &Self) -> bool { 161 | self.seq.eq(&other.seq) 162 | } 163 | } 164 | 165 | impl PartialOrd for UnreadPacket { 166 | fn partial_cmp(&self, other: &Self) -> Option { 167 | Some(self.cmp(other)) 168 | } 169 | } 170 | 171 | impl Ord for UnreadPacket { 172 | fn cmp(&self, other: &Self) -> Ordering { 173 | self.seq.cmp(&other.seq) 174 | } 175 | } 176 | 177 | impl UnreadPacket { 178 | fn new(seq: SeqNum, flags: u8, payload: BytesMut) -> Self { 179 | Self { seq, flags, payload } 180 | } 181 | pub(crate) fn len(&self) -> usize { 182 | if self.flags & FIN == FIN { 183 | self.payload.len() + 1 184 | } else { 185 | self.payload.len() 186 | } 187 | } 188 | fn advance(&mut self, cnt: usize) { 189 | self.seq.add_update(cnt as u32); 190 | self.payload.advance(cnt) 191 | } 192 | fn start(&self) -> SeqNum { 193 | self.seq 194 | } 195 | fn end(&self) -> SeqNum { 196 | self.seq.add_num(self.payload.len() as u32) 197 | } 198 | fn into_bytes(self) -> BytesMut { 199 | self.payload 200 | } 201 | } 202 | 203 | #[derive(Debug)] 204 | struct InflightPacket { 205 | seq: SeqNum, 206 | // Need to support SACK 207 | confirmed: bool, 208 | buf: FixedBuffer, 209 | } 210 | 211 | impl InflightPacket { 212 | pub fn new(seq: SeqNum, buf: FixedBuffer) -> Self { 213 | let mut packet = Self { 214 | seq, 215 | confirmed: false, 216 | buf, 
217 | }; 218 | packet.init(); 219 | packet 220 | } 221 | pub fn init(&mut self) { 222 | self.buf.clear(); 223 | } 224 | pub fn len(&self) -> usize { 225 | self.buf.len() 226 | } 227 | 228 | pub fn advance(&mut self, cnt: usize) { 229 | self.seq.add_update(cnt as u32); 230 | self.buf.advance(cnt) 231 | } 232 | pub fn start(&self) -> SeqNum { 233 | self.seq 234 | } 235 | pub fn end(&self) -> SeqNum { 236 | self.seq.add_num(self.buf.len() as u32) 237 | } 238 | pub fn write(&mut self, buf: &[u8]) -> usize { 239 | self.buf.extend_from_slice(buf) 240 | } 241 | pub fn bytes(&self) -> &[u8] { 242 | self.buf.bytes() 243 | } 244 | } 245 | 246 | #[derive(Debug, Clone, Copy)] 247 | pub struct TcpConfig { 248 | pub retransmission_timeout: Duration, 249 | pub time_wait_timeout: Duration, 250 | pub mss: Option, 251 | pub rcv_wnd: u16, 252 | pub window_shift_cnt: u8, 253 | pub quick_end: bool, 254 | } 255 | 256 | impl Default for TcpConfig { 257 | fn default() -> Self { 258 | Self { 259 | retransmission_timeout: Duration::from_millis(1000), 260 | time_wait_timeout: Duration::from_secs(10), 261 | mss: None, 262 | rcv_wnd: u16::MAX, 263 | // Window size too large can cause packet loss 264 | window_shift_cnt: 2, 265 | // If the stream is closed, exit the corresponding task immediately 266 | quick_end: true, 267 | } 268 | } 269 | } 270 | 271 | impl TcpConfig { 272 | pub fn check(&self) -> io::Result<()> { 273 | if let Some(mss) = self.mss { 274 | if mss < MSS_MIN { 275 | return Err(io::Error::new(io::ErrorKind::InvalidData, "mss cannot be less than 536")); 276 | } 277 | } 278 | 279 | if self.retransmission_timeout.is_zero() { 280 | return Err(io::Error::new(io::ErrorKind::InvalidData, "retransmission_timeout is zero")); 281 | } 282 | Ok(()) 283 | } 284 | } 285 | 286 | /// Implementation related to initialization connection 287 | impl Tcb { 288 | pub fn new_listen(local_addr: SocketAddr, peer_addr: SocketAddr, config: TcpConfig) -> Self { 289 | let snd_seq = 
SeqNum::from(rand::rng().next_u32()); 290 | Self { 291 | state: TcpState::Listen, 292 | local_addr, 293 | peer_addr, 294 | snd_seq, 295 | snd_ack: AckNum::from(0), 296 | last_snd_ack: AckNum::from(0), 297 | snd_wnd: 0, 298 | rcv_wnd: config.rcv_wnd, 299 | rcv_ack: snd_seq, 300 | mss: config.mss.unwrap_or(MSS_MIN), 301 | sack_permitted: false, 302 | snd_window_shift_cnt: 0, 303 | rcv_window_shift_cnt: config.window_shift_cnt, 304 | duplicate_ack_count: 0, 305 | // rcv_seq: SeqNum(0), 306 | tcp_receive_queue: Default::default(), 307 | tcp_out_of_order_queue: Default::default(), 308 | back_seq: None, 309 | inflight_packets: Default::default(), 310 | time_wait: None, 311 | time_wait_timeout: config.time_wait_timeout, 312 | write_timeout: None, 313 | retransmission_timeout: config.retransmission_timeout, 314 | timeout_count: (AckNum::from(0), 0), 315 | congestion_window: CongestionWindow::default(), 316 | last_snd_wnd: 0, 317 | requires_ack_repeat: false, 318 | } 319 | } 320 | pub fn try_syn_sent(&mut self) -> Option { 321 | if self.state == TcpState::Listen || self.state == TcpState::SynSent { 322 | self.sent_syn(); 323 | let options = self.get_options(); 324 | let packet = self.create_option_transport_packet(SYN, &[], Some(&options)); 325 | Some(packet) 326 | } else { 327 | None 328 | } 329 | } 330 | pub fn try_syn_received(&mut self, tcp_packet: &TcpPacket<'_>) -> Option { 331 | let flags = tcp_packet.get_flags(); 332 | if flags & RST == RST { 333 | self.recv_rst(); 334 | return None; 335 | } 336 | if self.state == TcpState::Listen || self.state == TcpState::SynReceived { 337 | self.option(tcp_packet); 338 | self.snd_ack = AckNum::from(tcp_packet.get_sequence()).add_num(1); 339 | self.last_snd_ack = self.snd_ack; 340 | // self.rcv_seq = self.snd_ack; 341 | self.snd_wnd = tcp_packet.get_window(); 342 | self.recv_syn(); 343 | let options = self.get_options(); 344 | let relay = self.create_option_transport_packet(SYN | ACK, &[], Some(&options)); 345 | Some(relay) 346 | 
} else { 347 | None 348 | } 349 | } 350 | pub fn try_syn_received_to_established(&mut self, mut buf: BytesMut) -> bool { 351 | let Some(packet) = TcpPacket::new(&buf) else { 352 | self.error(); 353 | return false; 354 | }; 355 | let flags = packet.get_flags(); 356 | if flags & RST == RST { 357 | self.recv_rst(); 358 | return false; 359 | } 360 | let header_len = packet.get_data_offset() as usize * 4; 361 | let flags = packet.get_flags(); 362 | if self.state == TcpState::SynReceived 363 | && flags & ACK == ACK 364 | && self.snd_ack.0 == packet.get_sequence() 365 | && self.snd_seq.add_num(1).0 == packet.get_acknowledgement() 366 | { 367 | self.snd_wnd = packet.get_window(); 368 | self.snd_seq = SeqNum(packet.get_acknowledgement()); 369 | self.rcv_ack = SeqNum(packet.get_acknowledgement()); 370 | self.recv_syn_ack(); 371 | self.init_congestion_window(); 372 | if !packet.payload().is_empty() { 373 | let seq = SeqNum(packet.get_sequence()); 374 | buf.advance(header_len); 375 | let unread_packet = UnreadPacket::new(seq, flags, buf); 376 | self.recv(unread_packet) 377 | } 378 | return true; 379 | } 380 | false 381 | } 382 | pub fn try_syn_sent_to_established(&mut self, buf: BytesMut) -> Option { 383 | let packet = TcpPacket::new(&buf)?; 384 | let flags = packet.get_flags(); 385 | if self.state == TcpState::SynSent && flags & ACK == ACK && flags & SYN == SYN { 386 | self.snd_seq.add_update(1); 387 | self.snd_ack = SeqNum::from(packet.get_sequence()).add_num(1); 388 | self.last_snd_ack = self.snd_ack; 389 | self.rcv_ack = SeqNum(packet.get_acknowledgement()); 390 | self.snd_wnd = packet.get_window(); 391 | self.recv_syn_ack(); 392 | self.init_congestion_window(); 393 | let relay = self.create_option_transport_packet(ACK, &[], None); 394 | return Some(relay); 395 | } 396 | None 397 | } 398 | fn init_congestion_window(&mut self) { 399 | let initial_cwnd = self.mss as usize * 4; 400 | let max_cwnd = (self.snd_wnd as usize) << self.snd_window_shift_cnt; 401 | 
self.congestion_window 402 | .init(initial_cwnd, (initial_cwnd + max_cwnd) / 2, max_cwnd, self.mss as usize); 403 | } 404 | } 405 | 406 | impl Tcb { 407 | pub fn local_addr(&self) -> SocketAddr { 408 | self.local_addr 409 | } 410 | pub fn peer_addr(&self) -> SocketAddr { 411 | self.peer_addr 412 | } 413 | pub fn mss(&self) -> u16 { 414 | self.mss 415 | } 416 | fn get_options(&self) -> BytesMut { 417 | let mut options = BytesMut::with_capacity(40); 418 | let mss = self.mss; 419 | options.put_u8(TcpOptionNumbers::MSS.0); 420 | options.put_u8(4); 421 | options.put_u16(mss); 422 | 423 | options.put_u8(TcpOptionNumbers::NOP.0); 424 | options.put_u8(TcpOptionNumbers::WSCALE.0); 425 | options.put_u8(3); 426 | options.put_u8(self.rcv_window_shift_cnt); 427 | options.put_u8(TcpOptionNumbers::NOP.0); 428 | options.put_u8(TcpOptionNumbers::NOP.0); 429 | options.put_u8(TcpOptionNumbers::SACK_PERMITTED.0); 430 | options.put_u8(2); 431 | options 432 | } 433 | fn option(&mut self, tcp_packet: &TcpPacket<'_>) { 434 | for tcp_option in tcp_packet.get_options_iter() { 435 | let payload = tcp_option.payload(); 436 | match tcp_option.get_number() { 437 | TcpOptionNumbers::WSCALE => { 438 | if let Some(window_shift_cnt) = payload.first() { 439 | self.snd_window_shift_cnt = (*window_shift_cnt).min(14); 440 | } 441 | } 442 | TcpOptionNumbers::MSS => { 443 | if payload.len() == 2 { 444 | self.mss = ((payload[0] as u16) << 8) | (payload[1] as u16); 445 | } 446 | } 447 | TcpOptionNumbers::SACK_PERMITTED => { 448 | // Selective acknowledgements permitted. 
449 | self.sack_permitted = true; 450 | } 451 | TcpOptionNumber(_) => {} 452 | } 453 | } 454 | } 455 | fn option_sack(&mut self, tcp_packet: &TcpPacket<'_>) { 456 | if !self.sack_permitted { 457 | return; 458 | } 459 | for tcp_option in tcp_packet.get_options_iter() { 460 | if tcp_option.get_number() == TcpOptionNumbers::SACK { 461 | let payload = tcp_option.payload(); 462 | if payload.len() & 7 != 0 { 463 | continue; 464 | } 465 | let n = payload.len() >> 3; 466 | for inflight_packet in self.inflight_packets.iter_mut() { 467 | for index in 0..n { 468 | let offset = index * 8; 469 | let left: SeqNum = payload[offset..4 + offset].try_into().map(u32::from_be_bytes).unwrap().into(); 470 | let right: SeqNum = payload[4 + offset..8 + offset].try_into().map(u32::from_be_bytes).unwrap().into(); 471 | if inflight_packet.confirmed || inflight_packet.end() <= left { 472 | break; 473 | } 474 | if inflight_packet.start() >= left && inflight_packet.end() <= right { 475 | inflight_packet.confirmed = true; 476 | } 477 | } 478 | } 479 | } 480 | } 481 | } 482 | fn create_transport_packet(&self, flags: u8, payload: &[u8]) -> TransportPacket { 483 | let data = self.create_packet(flags, self.snd_seq.0, self.snd_ack.0, payload, None); 484 | TransportPacket::new(data, NetworkTuple::new(self.local_addr, self.peer_addr, IpNextHeaderProtocols::Tcp)) 485 | } 486 | fn create_option_transport_packet(&self, flags: u8, payload: &[u8], options: Option<&[u8]>) -> TransportPacket { 487 | let data = self.create_packet(flags, self.snd_seq.0, self.snd_ack.0, payload, options); 488 | TransportPacket::new(data, NetworkTuple::new(self.local_addr, self.peer_addr, IpNextHeaderProtocols::Tcp)) 489 | } 490 | fn create_transport_packet_seq(&self, flags: u8, seq: u32, payload: &[u8]) -> TransportPacket { 491 | let data = self.create_packet(flags, seq, self.snd_ack.0, payload, None); 492 | TransportPacket::new(data, NetworkTuple::new(self.local_addr, self.peer_addr, IpNextHeaderProtocols::Tcp)) 493 | } 494 | 
    /// Builds a raw TCP segment carrying the current advertised receive
    /// window, addressed from `local_addr` to `peer_addr`.
    fn create_packet(&self, flags: u8, seq: u32, ack: u32, payload: &[u8], options: Option<&[u8]>) -> BytesMut {
        create_packet_raw(
            &self.local_addr,
            &self.peer_addr,
            seq,
            ack,
            self.recv_window(),
            flags,
            payload,
            options,
        )
    }
}

/// Implementation related to reading data
impl Tcb {
    /// States in which the peer may still deliver new data.
    pub fn readable_state(&self) -> bool {
        matches!(self.state, TcpState::Established | TcpState::FinWait1 | TcpState::FinWait2)
    }
    /// No new data can arrive and nothing buffered remains to read.
    pub fn cannot_read(&self) -> bool {
        !self.readable_state() && !self.readable()
    }

    /// Feeds one inbound TCP segment into the state machine. Returns a reply
    /// segment (RST or ACK) when one must be sent immediately, else None.
    pub fn push_packet(&mut self, mut buf: BytesMut) -> Option<TransportPacket> {
        let Some(packet) = TcpPacket::new(&buf) else {
            // Unparseable segment: tear the connection down.
            self.error();
            return None;
        };
        let flags = packet.get_flags();
        if flags & RST == RST {
            self.recv_rst();
            return None;
        }
        // A SYN on an established tuple is answered with RST.
        if flags & SYN == SYN {
            let reply_packet = self.create_transport_packet(RST, &[]);
            return Some(reply_packet);
        }

        let header_len = packet.get_data_offset() as usize * 4;
        match self.state {
            TcpState::Established | TcpState::FinWait1 | TcpState::FinWait2 => {
                if flags & ACK == ACK {
                    let acknowledgement = AckNum::from(packet.get_acknowledgement());
                    // Duplicate-ACK detection: same ack while data is still
                    // outstanding; >3 duplicates triggers go-back-N.
                    if acknowledgement == self.rcv_ack {
                        if self.rcv_ack != self.snd_seq {
                            self.duplicate_ack_count += 1;
                            if self.duplicate_ack_count > 3 {
                                self.back_n();
                            }
                        }
                        self.snd_wnd = packet.get_window();
                    }

                    self.update_last_ack(&packet);
                    self.option_sack(&packet);
                }
                let seq = SeqNum(packet.get_sequence());
                buf.advance(header_len);
                let unread_packet = UnreadPacket::new(seq, flags, buf);
                // Zero receive window: ack past the data without buffering it.
                if self.rcv_wnd == 0 {
                    self.snd_ack = unread_packet.end();
                }
                if self.recv_buffer_full() {
                    // Packet loss occurs when the buffer is full
                    return None;
                }
                // Drop segments entirely to the left of our ack point.
                if unread_packet.end() >= self.snd_ack {
                    self.recv(unread_packet);
                }
                return None;
            }
            TcpState::CloseWait | TcpState::Closing | TcpState::LastAck | TcpState::TimeWait => {
                if flags & ACK == ACK {
                    let acknowledgement = AckNum::from(packet.get_acknowledgement());
                    // An ack beyond snd_seq acknowledges our FIN.
                    if acknowledgement > self.snd_seq {
                        // acknowledgement == self.snd_seq + 1
                        self.recv_fin_ack()
                    }
                }
                if flags & FIN == FIN {
                    self.recv_fin();
                    // reply ACK
                    let reply_packet = self.create_transport_packet(ACK, &[]);
                    return Some(reply_packet);
                }
                return None;
            }
            _ => {
                // RST
            }
        }
        // Segment for a state that cannot accept it: reset.
        self.error();
        let reply_packet = self.create_transport_packet(RST, &[]);
        Some(reply_packet)
    }
    /// True when the receive queue holds bytes for the application.
    pub fn readable(&self) -> bool {
        self.tcp_receive_queue.total_bytes() != 0
    }
    /// Reader gone: drop buffered data and advertise a zero window.
    pub fn read_none(&mut self) {
        self.rcv_wnd = 0;
        self.tcp_receive_queue.clear();
    }
    /// Pops the next in-order chunk for the application.
    pub fn read(&mut self) -> Option<BytesMut> {
        self.tcp_receive_queue.pop()
    }

    /// Accepts a segment at or left of the ack point into the receive queue
    /// (trimming the overlap), or parks it in the out-of-order queue.
    fn recv(&mut self, mut unread_packet: UnreadPacket) {
        let start = unread_packet.start();
        if self.snd_ack >= start {
            let flags = unread_packet.flags;
            let end = unread_packet.end();
            if end > self.snd_ack {
                // Trim already-acked prefix, then advance the ack point.
                unread_packet.advance((self.snd_ack - start).0 as usize);
                self.snd_ack = end;
                self.tcp_receive_queue.push(unread_packet.into_bytes())
            }
            if flags & FIN == FIN {
                self.recv_fin();
            }
        } else {
            self.tcp_out_of_order_queue.push(unread_packet);
            // Try to consume any now-contiguous out-of-order segments.
            self.advice_ack();
            if !self.tcp_out_of_order_queue.is_empty() {
                // If out-of-order packets are present, a duplicate ACK is required to trigger the peer's fast retransmit.
                self.requires_ack_repeat = true;
            }
        }
    }
    /// Drains the out-of-order queue while its head is contiguous with the
    /// current ack point, pushing recovered bytes to the receive queue.
    fn advice_ack(&mut self) {
        while let Some(packet) = self.tcp_out_of_order_queue.peek() {
            let start = packet.start();
            if self.snd_ack < start {
                //unordered
                break;
            }
            let flags = packet.flags;
            let end = packet.end();
            let mut unread_packet = self.tcp_out_of_order_queue.pop().unwrap();
            if end > self.snd_ack {
                let offset = (self.snd_ack - start).0;
                self.snd_ack = end;
                unread_packet.advance(offset as usize);
                self.tcp_receive_queue.push(unread_packet.into_bytes());
            }
            if flags & FIN == FIN {
                // FIN consumes a sequence number; nothing follows it.
                self.recv_fin();
                break;
            }
        }
    }
    /// A pure ACK must be sent when the advertised window or ack point moved
    /// since the last sent segment, or a duplicate ACK was requested.
    pub fn need_ack(&self) -> bool {
        self.last_snd_wnd != self.recv_window() || self.snd_ack != self.last_snd_ack || self.requires_ack_repeat
    }
    /// Window to advertise, scaled back down by our window-shift count and
    /// reduced by bytes buffered but not yet consumed by the application.
    pub fn recv_window(&self) -> u16 {
        let src_rcv_wnd = (self.rcv_wnd as usize) << self.rcv_window_shift_cnt;
        let unread_total_bytes = self.tcp_out_of_order_queue.total_bytes() + self.tcp_receive_queue.total_bytes();
        let rcv_wnd = src_rcv_wnd.saturating_sub(unread_total_bytes);
        (rcv_wnd >> self.rcv_window_shift_cnt) as u16
    }
    fn recv_buffer_full(&self) -> bool {
        // To reduce packet loss, the actual receivable window size is larger than the recv_window()
        let src_rcv_wnd = ((self.rcv_wnd as usize) << self.rcv_window_shift_cnt) << 1;
        let unread_total_bytes = self.tcp_out_of_order_queue.total_bytes() + self.tcp_receive_queue.total_bytes();
        src_rcv_wnd <= unread_total_bytes
    }
    // pub fn recv_busy(&self) -> bool {
    //     if !self.readable_state() || self.rcv_wnd == 0 {
    //         return false;
    //     }
    //     let src_rcv_wnd = (self.rcv_wnd as usize) << self.rcv_window_shift_cnt;
    //     let unread_total_bytes = self.tcp_out_of_order_queue.total_bytes() + self.tcp_receive_queue.total_bytes();
    //     let rcv_wnd =
    // src_rcv_wnd.saturating_sub(unread_total_bytes);
    //     rcv_wnd <= 2 * self.mss as usize
    // }
}

/// Implementation related to writing data
impl Tcb {
    /// Bytes sent but not yet acknowledged.
    #[inline]
    fn ack_distance(&self) -> u32 {
        (self.snd_seq - self.rcv_ack).0
    }
    /// Usable send window: min(congestion window, peer window scaled up by
    /// the peer's shift count), minus bytes already in flight.
    fn send_window(&self) -> usize {
        let distance = self.ack_distance();
        let snd_wnd = (self.snd_wnd as usize) << self.snd_window_shift_cnt;
        let wnd = self.congestion_window.current_window_size().min(snd_wnd);
        // log::info!("snd_wnd1 ={snd_wnd1} snd_wnd = {snd_wnd:?},distance={distance}");
        wnd.saturating_sub(distance as usize)
    }

    /// Records that the current ack/window state was just sent to the peer;
    /// clears the duplicate-ACK request flag.
    pub fn perform_post_ack_action(&mut self) {
        self.last_snd_wnd = self.recv_window();
        self.last_snd_ack = self.snd_ack;
        self.requires_ack_repeat = false;
    }
    /// Applies a (strictly newer) acknowledgement: advances rcv_ack, trims
    /// fully-acked inflight packets, and manages the write timeout.
    fn update_last_ack(&mut self, tcp_packet: &TcpPacket<'_>) {
        let ack = AckNum::from(tcp_packet.get_acknowledgement());
        if ack <= self.rcv_ack {
            // Old or duplicate ack; duplicates are handled in push_packet.
            return;
        }
        self.snd_wnd = tcp_packet.get_window();
        self.congestion_window.on_ack();
        self.duplicate_ack_count = 0;
        let mut distance = (ack - self.rcv_ack).0 as usize;
        self.rcv_ack = ack;
        // Drop fully-acknowledged inflight packets; partially acked head is
        // advanced in place.
        while let Some(inflight_packet) = self.inflight_packets.front_mut() {
            if inflight_packet.len() > distance {
                inflight_packet.advance(distance);
                break;
            } else {
                distance -= inflight_packet.len();
                _ = self.inflight_packets.pop_front();
            }
        }
        if self.inflight_packets.is_empty() {
            self.write_timeout.take();
        } else if let Some(write_timeout) = self.write_timeout.as_mut() {
            // NOTE(review): this extension is immediately overwritten by
            // reset_write_timeout() below when packets remain inflight —
            // looks like dead code; confirm intent.
            *write_timeout += self.retransmission_timeout
        }
        // An ack beyond snd_seq acknowledges our FIN (we are past writeable
        // states at that point).
        if !self.writeable_state() && self.rcv_ack > self.snd_seq {
            self.recv_fin_ack()
        }
        self.reset_write_timeout();
    }
    /// Allocates a fresh MSS-sized inflight buffer starting at snd_seq.
    fn take_send_buf(&mut self) -> Option<InflightPacket> {
        let bytes_mut = FixedBuffer::with_capacity(self.mss as usize);
        Some(InflightPacket::new(self.snd_seq, bytes_mut))
    }
    /// Queues up to one segment of `buf` for transmission. Returns the
    /// segment to send and the number of bytes consumed.
    pub fn write(&mut self, buf: &[u8]) -> Option<(TransportPacket, usize)> {
        let rs = self.write0(buf);
        self.init_write_timeout();
        rs
    }
    fn write0(&mut self, mut buf: &[u8]) -> Option<(TransportPacket, usize)> {
        if !self.writeable_state() {
            return None;
        }
        let seq = self.snd_seq.0;
        let snd_wnd = self.send_window();
        // Clamp to the available window.
        if snd_wnd < buf.len() {
            buf = &buf[..snd_wnd];
        }
        if buf.is_empty() {
            return None;
        }
        // Near window exhaustion, set PSH to ask the peer to deliver promptly.
        let flags = if self.decelerate() { PSH | ACK } else { ACK };
        // Prefer topping up the last inflight buffer to keep segments full.
        if let Some(packet) = self.inflight_packets.back_mut() {
            let n = packet.write(buf);
            if n > 0 {
                let packet = self.create_transport_packet_seq(flags, seq, &buf[..n]);
                self.snd_seq.add_update(n as u32);
                return Some((packet, n));
            }
        }

        // Otherwise open a new inflight buffer.
        if let Some(mut packet) = self.take_send_buf() {
            let n = packet.write(buf);
            assert!(n > 0);
            self.inflight_packets.push_back(packet);
            let packet = self.create_transport_packet_seq(flags, seq, &buf[..n]);
            self.snd_seq.add_update(n as u32);
            return Some((packet, n));
        }
        None
    }
    /// Deadline after which unacked data should be retransmitted.
    pub fn write_timeout(&self) -> Option<Instant> {
        self.write_timeout
    }
    fn reset_write_timeout(&mut self) {
        if !self.inflight_packets.is_empty() {
            self.write_timeout.replace(Instant::now() + self.retransmission_timeout);
        }
    }
    fn init_write_timeout(&mut self) {
        if self.write_timeout.is_none() {
            self.reset_write_timeout();
        }
    }

    /// Go-back-N replay: returns the next unconfirmed inflight packet at or
    /// after `back_seq`, advancing the replay cursor; None ends the replay.
    pub fn retransmission(&mut self) -> Option<TransportPacket> {
        let back_seq = self.back_seq?;
        for packet in self.inflight_packets.iter() {
            // SACK-confirmed packets are skipped.
            if packet.confirmed {
                continue;
            }
            if packet.end() > back_seq {
                self.back_seq.replace(packet.end());
                return Some(self.create_transport_packet_seq(ACK, packet.start().0, packet.bytes()));
            }
        }
        self.back_seq.take();
        None
    }
    /// Starts a go-back-N replay from the oldest inflight packet and shrinks
    /// the congestion window. Returns false when nothing is inflight.
    fn back_n(&mut self) -> bool {
        if let Some(v) = self.inflight_packets.front() {
            self.back_seq.replace(v.start());
            self.congestion_window.on_loss();
            self.reset_write_timeout();
            true
        } else {
            false
        }
    }
    /// Send window nearly exhausted (≤ 16×MSS remaining).
    pub fn decelerate(&self) -> bool {
        let snd_wnd = self.send_window();
        snd_wnd <= (self.mss as usize) << 4
    }
    /// Send window fully exhausted.
    pub fn limit(&self) -> bool {
        let snd_wnd = self.send_window();
        snd_wnd == 0
    }
    pub fn no_inflight_packet(&self) -> bool {
        self.inflight_packets.is_empty()
    }
    /// States in which the application may still send data.
    pub fn writeable_state(&self) -> bool {
        self.state == TcpState::Established || self.state == TcpState::CloseWait
    }
    pub fn cannot_write(&self) -> bool {
        !self.writeable_state()
    }
    pub fn is_close(&self) -> bool {
        self.state == TcpState::Closed
    }
    /// Deadline at which TIME_WAIT expires, when armed.
    pub fn time_wait(&self) -> Option<Instant> {
        self.time_wait
    }
    /// Deadline expiry handler: TIME_WAIT closes; otherwise start a replay
    /// and give up after 10 consecutive timeouts with no ack progress.
    pub fn timeout(&mut self) {
        if self.state == TcpState::TimeWait {
            self.timeout_wait();
            return;
        }
        if !self.back_n() {
            return;
        }
        if self.timeout_count.0 == self.rcv_ack {
            self.timeout_count.1 += 1;
            if self.timeout_count.1 > 10 {
                self.error();
            }
        } else {
            // Ack advanced since the last timeout: restart the counter.
            self.timeout_count.0 = self.rcv_ack;
            self.timeout_count.1 = 0;
        }
    }
    pub fn need_retransmission(&self) -> bool {
        self.back_seq.is_some()
    }
}

/// TCP state rotation
impl Tcb {
    fn sent_syn(&mut self) {
        if self.state == TcpState::Listen {
            self.state = TcpState::SynSent
        }
    }
    fn recv_syn(&mut self) {
        if self.state == TcpState::Listen {
            self.state = TcpState::SynReceived
        }
    }
    /// Handshake completion from either the passive or active side.
    fn recv_syn_ack(&mut self) {
        match self.state {
            TcpState::SynReceived => self.state = TcpState::Established,
859 | TcpState::SynSent => self.state = TcpState::Established, 860 | _ => {} 861 | } 862 | } 863 | 864 | pub fn sent_fin(&mut self) { 865 | match self.state { 866 | TcpState::Established => self.state = TcpState::FinWait1, 867 | TcpState::CloseWait => self.state = TcpState::LastAck, 868 | _ => {} 869 | } 870 | } 871 | fn recv_fin(&mut self) { 872 | match self.state { 873 | TcpState::Established => { 874 | self.snd_ack.add_update(1); 875 | self.state = TcpState::CloseWait 876 | } 877 | TcpState::FinWait1 => { 878 | self.snd_ack.add_update(1); 879 | self.state = TcpState::Closing 880 | } 881 | TcpState::FinWait2 => { 882 | self.snd_ack.add_update(1); 883 | self.time_wait = Some(Instant::now() + self.time_wait_timeout); 884 | self.state = TcpState::TimeWait 885 | } 886 | _ => {} 887 | } 888 | } 889 | fn recv_fin_ack(&mut self) { 890 | match self.state { 891 | TcpState::FinWait1 => self.state = TcpState::FinWait2, 892 | TcpState::Closing => self.state = TcpState::TimeWait, 893 | TcpState::LastAck => self.state = TcpState::Closed, 894 | _ => {} 895 | } 896 | } 897 | fn recv_rst(&mut self) { 898 | self.state = TcpState::Closed 899 | } 900 | fn timeout_wait(&mut self) { 901 | assert_eq!(self.state, TcpState::TimeWait); 902 | self.state = TcpState::Closed 903 | } 904 | fn error(&mut self) { 905 | self.state = TcpState::Closed 906 | } 907 | pub fn fin_packet(&self) -> TransportPacket { 908 | let seq = self.snd_seq.0; 909 | self.create_transport_packet_seq(FIN | ACK, seq, &[]) 910 | } 911 | pub fn ack_packet(&self) -> TransportPacket { 912 | let seq = self.snd_seq.0; 913 | self.create_transport_packet_seq(ACK, seq, &[]) 914 | } 915 | } 916 | 917 | pub fn create_transport_packet_raw( 918 | local_addr: &SocketAddr, 919 | peer_addr: &SocketAddr, 920 | snd_seq: u32, 921 | rcv_ack: u32, 922 | rcv_wnd: u16, 923 | flags: u8, 924 | payload: &[u8], 925 | ) -> TransportPacket { 926 | let data = create_packet_raw(local_addr, peer_addr, snd_seq, rcv_ack, rcv_wnd, flags, payload, None); 
927 | TransportPacket::new(data, NetworkTuple::new(*local_addr, *peer_addr, IpNextHeaderProtocols::Tcp)) 928 | } 929 | 930 | #[allow(clippy::too_many_arguments)] 931 | pub fn create_packet_raw( 932 | local_addr: &SocketAddr, 933 | peer_addr: &SocketAddr, 934 | snd_seq: u32, 935 | snd_ack: u32, 936 | rcv_wnd: u16, 937 | flags: u8, 938 | payload: &[u8], 939 | options: Option<&[u8]>, 940 | ) -> BytesMut { 941 | let mut bytes = BytesMut::with_capacity(TCP_HEADER_LEN + payload.len()); 942 | bytes.put_u16(local_addr.port()); 943 | bytes.put_u16(peer_addr.port()); 944 | bytes.put_u32(snd_seq); 945 | bytes.put_u32(snd_ack); 946 | let head_len = options 947 | .filter(|op| !op.is_empty()) 948 | .map(|op| { 949 | assert_eq!(op.len() & 3, 0, "Options must be aligned with four bytes"); 950 | TCP_HEADER_LEN + op.len() 951 | }) 952 | .unwrap_or(TCP_HEADER_LEN); 953 | // Data Offset 954 | bytes.put_u8((head_len as u8 / 4) << 4); 955 | bytes.put_u8(flags); 956 | bytes.put_u16(rcv_wnd); 957 | // Checksum 958 | bytes.put_u16(0); 959 | // Urgent Pointer 960 | bytes.put_u16(0); 961 | if let Some(op) = options { 962 | if !op.is_empty() { 963 | bytes.extend_from_slice(op); 964 | } 965 | } 966 | bytes.extend_from_slice(payload); 967 | let checksum = match (local_addr.ip(), peer_addr.ip()) { 968 | (IpAddr::V4(src_ip), IpAddr::V4(dst_ip)) => { 969 | pnet_packet::util::ipv4_checksum(&bytes, 8, &[], &src_ip, &dst_ip, IpNextHeaderProtocols::Tcp) 970 | } 971 | (IpAddr::V6(src_ip), IpAddr::V6(dst_ip)) => { 972 | pnet_packet::util::ipv6_checksum(&bytes, 8, &[], &src_ip, &dst_ip, IpNextHeaderProtocols::Tcp) 973 | } 974 | (_, _) => { 975 | unreachable!() 976 | } 977 | }; 978 | bytes[16..18].copy_from_slice(&checksum.to_be_bytes()); 979 | bytes 980 | } 981 | 982 | #[derive(Copy, Clone, Debug, Default)] 983 | struct CongestionWindow { 984 | cwnd: usize, 985 | ssthresh: usize, 986 | max_cwnd: usize, 987 | mss: usize, 988 | } 989 | 990 | impl CongestionWindow { 991 | pub fn init(&mut self, 
initial_cwnd: usize, initial_ssthresh: usize, max_cwnd: usize, mss: usize) { 992 | self.cwnd = initial_cwnd; 993 | self.ssthresh = initial_ssthresh; 994 | self.max_cwnd = max_cwnd; 995 | self.mss = mss; 996 | } 997 | 998 | pub fn on_ack(&mut self) { 999 | if self.cwnd < self.ssthresh { 1000 | self.cwnd *= 2; 1001 | } else { 1002 | self.cwnd += (self.cwnd as f64).sqrt() as usize; 1003 | } 1004 | 1005 | self.cwnd = self.cwnd.min(self.max_cwnd); 1006 | } 1007 | 1008 | pub fn on_loss(&mut self) { 1009 | self.ssthresh = self.cwnd / 2; 1010 | self.cwnd = self.mss; 1011 | } 1012 | 1013 | pub fn current_window_size(&self) -> usize { 1014 | self.cwnd 1015 | } 1016 | } 1017 | -------------------------------------------------------------------------------- /src/tcp/tcp_queue.rs: -------------------------------------------------------------------------------- 1 | #![allow(unused, unused_variables)] 2 | use crate::tcp::tcb::UnreadPacket; 3 | use bytes::{Buf, BytesMut}; 4 | use std::cmp::Ordering; 5 | use std::collections::LinkedList; 6 | use std::marker::PhantomData; 7 | use std::mem; 8 | use std::ops::Deref; 9 | use std::ptr::NonNull; 10 | 11 | #[derive(Debug, Default)] 12 | pub(crate) struct TcpReceiveQueue { 13 | total_bytes: usize, 14 | queue: LinkedList, 15 | } 16 | pub(crate) struct TcpReceiveQueueItem<'a> { 17 | total_bytes: &'a mut usize, 18 | payload: &'a mut BytesMut, 19 | } 20 | impl Deref for TcpReceiveQueueItem<'_> { 21 | type Target = BytesMut; 22 | 23 | fn deref(&self) -> &Self::Target { 24 | self.payload 25 | } 26 | } 27 | impl TcpReceiveQueueItem<'_> { 28 | pub fn advance(&mut self, cnt: usize) { 29 | self.payload.advance(cnt); 30 | *self.total_bytes -= cnt; 31 | } 32 | } 33 | impl TcpReceiveQueue { 34 | pub fn push(&mut self, elt: BytesMut) { 35 | self.total_bytes += elt.len(); 36 | self.queue.push_back(elt); 37 | } 38 | pub fn pop(&mut self) -> Option { 39 | if let Some(v) = self.queue.pop_front() { 40 | self.total_bytes -= v.len(); 41 | Some(v) 42 | } else { 
43 | None 44 | } 45 | } 46 | pub fn peek(&mut self) -> Option { 47 | let total_bytes = &mut self.total_bytes; 48 | self.queue.front_mut().map(|payload| TcpReceiveQueueItem { total_bytes, payload }) 49 | } 50 | pub fn clear(&mut self) { 51 | self.queue.clear(); 52 | self.total_bytes = 0; 53 | } 54 | pub fn total_bytes(&self) -> usize { 55 | self.total_bytes 56 | } 57 | pub fn is_empty(&self) -> bool { 58 | self.queue.is_empty() 59 | } 60 | } 61 | #[derive(Debug, Default)] 62 | pub(crate) struct TcpOfoQueue { 63 | total_bytes: usize, 64 | queue: OrderQueue, 65 | } 66 | fn handle_duplicate_seq(p1: &UnreadPacket, p2: &UnreadPacket) -> bool { 67 | p1.len() < p2.len() 68 | } 69 | impl TcpOfoQueue { 70 | pub fn total_bytes(&self) -> usize { 71 | self.total_bytes 72 | } 73 | pub fn push(&mut self, elt: UnreadPacket) { 74 | self.total_bytes += elt.len(); 75 | self.queue.push(elt, handle_duplicate_seq); 76 | } 77 | pub fn pop(&mut self) -> Option { 78 | if let Some(v) = self.queue.pop() { 79 | self.total_bytes -= v.len(); 80 | Some(v) 81 | } else { 82 | None 83 | } 84 | } 85 | pub fn peek(&self) -> Option<&UnreadPacket> { 86 | self.queue.peek() 87 | } 88 | pub fn clear(&mut self) { 89 | self.queue.clear(); 90 | self.total_bytes = 0; 91 | } 92 | pub fn len(&self) -> usize { 93 | self.queue.len() 94 | } 95 | pub fn is_empty(&self) -> bool { 96 | self.queue.is_empty() 97 | } 98 | } 99 | 100 | impl<'a> IntoIterator for &'a TcpOfoQueue { 101 | type Item = &'a UnreadPacket; 102 | type IntoIter = Iter<'a, UnreadPacket>; 103 | 104 | fn into_iter(self) -> Iter<'a, UnreadPacket> { 105 | self.queue.iter() 106 | } 107 | } 108 | 109 | #[derive(Debug)] 110 | pub struct OrderQueue { 111 | head: Option>>, 112 | tail: Option>>, 113 | len: usize, 114 | } 115 | 116 | struct Node { 117 | next: Option>>, 118 | prev: Option>>, 119 | element: T, 120 | } 121 | 122 | impl Node { 123 | fn new(element: T) -> Self { 124 | Node { 125 | next: None, 126 | prev: None, 127 | element, 128 | } 129 | } 130 | } 
131 | impl Default for OrderQueue { 132 | fn default() -> Self { 133 | OrderQueue::new() 134 | } 135 | } 136 | impl OrderQueue { 137 | pub fn push(&mut self, elt: T, compute: F) 138 | where 139 | F: Fn(&T, &T) -> bool, 140 | { 141 | let mut prev = self.tail; 142 | while let Some(mut v) = prev { 143 | unsafe { 144 | let curr_elt = &v.as_ref().element; 145 | match curr_elt.cmp(&elt) { 146 | Ordering::Less => break, 147 | Ordering::Equal => { 148 | if compute(curr_elt, &elt) { 149 | v.as_mut().element = elt; 150 | } 151 | return; 152 | } 153 | Ordering::Greater => { 154 | prev = v.as_ref().prev; 155 | } 156 | } 157 | } 158 | } 159 | 160 | let mut node = Box::new(Node::new(elt)); 161 | node.prev = prev; 162 | let node_ptr = NonNull::from(Box::leak(node)); 163 | let node = Some(node_ptr); 164 | 165 | unsafe { 166 | match prev { 167 | None => { 168 | (*node_ptr.as_ptr()).next = self.head; 169 | self.head = node 170 | } 171 | Some(prev) => { 172 | (*node_ptr.as_ptr()).next = (*prev.as_ptr()).next; 173 | (*prev.as_ptr()).next = node 174 | } 175 | } 176 | match (*node_ptr.as_ptr()).next { 177 | None => { 178 | self.tail = node; 179 | } 180 | Some(next) => { 181 | (*next.as_ptr()).prev = node; 182 | } 183 | } 184 | } 185 | 186 | self.len += 1; 187 | } 188 | } 189 | 190 | impl OrderQueue { 191 | pub fn new() -> Self { 192 | Self { 193 | head: None, 194 | tail: None, 195 | len: 0, 196 | } 197 | } 198 | #[inline] 199 | pub fn peek(&self) -> Option<&T> { 200 | self.head.map(|v| unsafe { &(*v.as_ptr()).element }) 201 | } 202 | pub fn pop(&mut self) -> Option { 203 | self.head.map(|node| { 204 | unsafe { 205 | let node = Box::from_raw(node.as_ptr()); 206 | self.head = node.next; 207 | 208 | match self.head { 209 | None => self.tail = None, 210 | // Not creating new mutable (unique!) references overlapping `element`. 
211 | Some(head) => (*head.as_ptr()).prev = None, 212 | } 213 | self.len -= 1; 214 | node.element 215 | } 216 | }) 217 | } 218 | pub fn clear(&mut self) { 219 | drop(OrderQueue { 220 | head: self.head.take(), 221 | tail: self.tail.take(), 222 | len: mem::take(&mut self.len), 223 | }); 224 | } 225 | pub fn len(&self) -> usize { 226 | self.len 227 | } 228 | pub fn is_empty(&self) -> bool { 229 | self.len == 0 230 | } 231 | pub fn iter(&self) -> Iter<'_, T> { 232 | Iter { 233 | head: self.head, 234 | tail: self.tail, 235 | len: self.len, 236 | marker: PhantomData, 237 | } 238 | } 239 | } 240 | 241 | pub struct Iter<'a, T: 'a> { 242 | head: Option>>, 243 | tail: Option>>, 244 | len: usize, 245 | marker: PhantomData<&'a Node>, 246 | } 247 | 248 | pub struct IntoIter { 249 | list: OrderQueue, 250 | } 251 | 252 | impl<'a, T> Iterator for Iter<'a, T> { 253 | type Item = &'a T; 254 | 255 | #[inline] 256 | fn next(&mut self) -> Option<&'a T> { 257 | if self.len == 0 { 258 | None 259 | } else { 260 | self.head.map(|node| unsafe { 261 | // Need an unbound lifetime to get 'a 262 | let node = &*node.as_ptr(); 263 | self.len -= 1; 264 | self.head = node.next; 265 | &node.element 266 | }) 267 | } 268 | } 269 | 270 | #[inline] 271 | fn size_hint(&self) -> (usize, Option) { 272 | (self.len, Some(self.len)) 273 | } 274 | } 275 | 276 | impl Iterator for IntoIter { 277 | type Item = T; 278 | 279 | #[inline] 280 | fn next(&mut self) -> Option { 281 | self.list.pop() 282 | } 283 | 284 | #[inline] 285 | fn size_hint(&self) -> (usize, Option) { 286 | (self.list.len, Some(self.list.len)) 287 | } 288 | } 289 | 290 | impl<'a, T> IntoIterator for &'a OrderQueue { 291 | type Item = &'a T; 292 | type IntoIter = Iter<'a, T>; 293 | 294 | fn into_iter(self) -> Iter<'a, T> { 295 | self.iter() 296 | } 297 | } 298 | 299 | impl IntoIterator for OrderQueue { 300 | type Item = T; 301 | type IntoIter = IntoIter; 302 | 303 | /// Consumes the list into an iterator yielding elements by value. 
304 | #[inline] 305 | fn into_iter(self) -> IntoIter { 306 | IntoIter { list: self } 307 | } 308 | } 309 | 310 | impl Drop for OrderQueue { 311 | fn drop(&mut self) { 312 | while self.pop().is_some() {} 313 | } 314 | } 315 | unsafe impl Send for OrderQueue {} 316 | 317 | unsafe impl Sync for OrderQueue {} 318 | 319 | #[cfg(test)] 320 | mod tests { 321 | use std::sync::Arc; 322 | 323 | use super::*; 324 | 325 | #[test] 326 | fn test_push_and_peek() { 327 | let mut queue = OrderQueue::new(); 328 | queue.push(10, |_, _| false); 329 | assert_eq!(queue.peek(), Some(&10)); 330 | queue.push(20, |_, _| false); 331 | assert_eq!(queue.peek(), Some(&10)); 332 | queue.push(5, |_, _| false); 333 | assert_eq!(queue.peek(), Some(&5)); 334 | queue.push(6, |_, _| false); 335 | assert_eq!(queue.peek(), Some(&5)); 336 | queue.push(7, |_, _| false); 337 | assert_eq!(queue.peek(), Some(&5)); 338 | queue.push(1, |_, _| false); 339 | assert_eq!(queue.peek(), Some(&1)); 340 | assert_eq!(queue.len(), 6); 341 | let list: Vec = queue.iter().copied().collect(); 342 | assert_eq!(&list, &[1, 5, 6, 7, 10, 20]); 343 | } 344 | 345 | #[test] 346 | fn test_push_with_duplicate_handling() { 347 | let mut queue = OrderQueue::new(); 348 | 349 | queue.push(10, |_, _| false); 350 | queue.push(10, |_, _| false); 351 | assert_eq!(queue.len(), 1); 352 | assert_eq!(queue.peek(), Some(&10)); 353 | assert_eq!(queue.len(), 1); 354 | 355 | queue.push(10, |_, _| true); 356 | assert_eq!(queue.peek(), Some(&10)); 357 | assert_eq!(queue.len(), 1); 358 | } 359 | 360 | #[test] 361 | fn test_pop() { 362 | let mut queue = OrderQueue::new(); 363 | 364 | queue.push(10, |_, _| false); 365 | queue.push(20, |_, _| false); 366 | queue.push(5, |_, _| false); 367 | queue.push(6, |_, _| false); 368 | queue.push(7, |_, _| false); 369 | queue.push(1, |_, _| false); 370 | queue.push(0, |_, _| false); 371 | queue.push(100, |_, _| false); 372 | queue.push(99, |_, _| false); 373 | assert_eq!(queue.pop(), Some(0)); 374 | 
assert_eq!(queue.pop(), Some(1)); 375 | assert_eq!(queue.pop(), Some(5)); 376 | assert_eq!(queue.pop(), Some(6)); 377 | assert_eq!(queue.pop(), Some(7)); 378 | assert_eq!(queue.pop(), Some(10)); 379 | assert_eq!(queue.pop(), Some(20)); 380 | assert_eq!(queue.pop(), Some(99)); 381 | assert_eq!(queue.pop(), Some(100)); 382 | assert_eq!(queue.pop(), None); 383 | assert_eq!(queue.len(), 0); 384 | } 385 | 386 | #[test] 387 | fn test_clear() { 388 | let mut queue = OrderQueue::new(); 389 | 390 | queue.push(10, |_, _| false); 391 | queue.push(20, |_, _| false); 392 | queue.push(30, |_, _| false); 393 | 394 | queue.clear(); 395 | assert_eq!(queue.len(), 0); 396 | assert_eq!(queue.peek(), None); 397 | assert_eq!(queue.pop(), None); 398 | } 399 | 400 | #[test] 401 | fn test_len() { 402 | let mut queue = OrderQueue::new(); 403 | 404 | assert_eq!(queue.len, 0); 405 | 406 | queue.push(10, |_, _| false); 407 | assert_eq!(queue.len, 1); 408 | 409 | queue.push(20, |_, _| false); 410 | assert_eq!(queue.len, 2); 411 | 412 | queue.pop(); 413 | assert_eq!(queue.len, 1); 414 | 415 | queue.pop(); 416 | assert_eq!(queue.len, 0); 417 | } 418 | 419 | #[test] 420 | fn test_ordering() { 421 | let mut queue = OrderQueue::new(); 422 | 423 | queue.push(15, |_, _| false); 424 | queue.push(10, |_, _| false); 425 | queue.push(20, |_, _| false); 426 | queue.push(5, |_, _| false); 427 | 428 | assert_eq!(queue.pop(), Some(5)); 429 | assert_eq!(queue.pop(), Some(10)); 430 | assert_eq!(queue.pop(), Some(15)); 431 | assert_eq!(queue.pop(), Some(20)); 432 | } 433 | 434 | #[test] 435 | fn test_drop_after_pop() { 436 | let mut queue = OrderQueue::new(); 437 | 438 | // Create elements that track drop count 439 | let mut elem1 = Arc::new(10); 440 | let mut elem2 = Arc::new(100); 441 | let mut elem3 = Arc::new(5); 442 | let mut elem4 = Arc::new(6); 443 | 444 | // Push elements into the queue 445 | queue.push(elem1.clone(), |_, _| false); 446 | queue.push(elem2.clone(), |_, _| false); 447 | 
queue.push(elem3.clone(), |_, _| false); 448 | queue.push(elem4.clone(), |_, _| false); 449 | assert_eq!(Arc::strong_count(&elem1), 2); 450 | assert_eq!(Arc::strong_count(&elem2), 2); 451 | assert_eq!(Arc::strong_count(&elem3), 2); 452 | assert_eq!(Arc::strong_count(&elem4), 2); 453 | queue.clear(); 454 | assert_eq!(Arc::strong_count(&elem1), 1); 455 | assert_eq!(Arc::strong_count(&elem2), 1); 456 | assert_eq!(Arc::strong_count(&elem3), 1); 457 | assert_eq!(Arc::strong_count(&elem4), 1); 458 | } 459 | } 460 | -------------------------------------------------------------------------------- /src/udp/mod.rs: -------------------------------------------------------------------------------- 1 | use std::io; 2 | use std::net::{IpAddr, SocketAddr}; 3 | 4 | use crate::address::ToSocketAddr; 5 | use crate::ip_stack::{check_addr, check_ip, BindAddr, IpStack, NetworkTuple, TransportPacket}; 6 | use bytes::{BufMut, BytesMut}; 7 | use pnet_packet::ip::IpNextHeaderProtocols; 8 | use pnet_packet::Packet; 9 | 10 | /// A UDP socket. 11 | /// 12 | /// UDP is "connectionless", unlike TCP. Meaning, regardless of what address you've bound to, a `UdpSocket` 13 | /// is free to communicate with many different remotes. 
In tcp_ip there are basically two main ways to use `UdpSocket`: 14 | /// 15 | /// * one to many: [`bind`](`UdpSocket::bind`) and use [`send_to`](`UdpSocket::send_to`) 16 | /// and [`recv_from`](`UdpSocket::recv_from`) to communicate with many different addresses 17 | /// * many to many: [`bind_all`](`UdpSocket::bind_all`) and use [`send_from_to`](`UdpSocket::send_from_to`) 18 | /// and [`recv_from_to`](`UdpSocket::recv_from_to`) to communicate with many different addresses 19 | /// * one to one: [`connect`](`UdpSocket::connect`) and associate with a single address, using [`send`](`UdpSocket::send`) 20 | /// and [`recv`](`UdpSocket::recv`) to communicate only with that remote address 21 | /// 22 | /// This type does not provide a `split` method, because this functionality 23 | /// can be achieved by instead wrapping the socket in an [`Arc`]. Note that 24 | /// you do not need a `Mutex` to share the `UdpSocket` — an `Arc` 25 | /// is enough. This is because all of the methods take `&self` instead of 26 | /// `&mut self`. Once you have wrapped it in an `Arc`, you can call 27 | /// `.clone()` on the `Arc` to get multiple shared handles to the 28 | /// same socket. 
29 | /// 30 | /// [`Arc`]: std::sync::Arc 31 | pub struct UdpSocket { 32 | _bind_addr: Option, 33 | ip_stack: IpStack, 34 | packet_receiver: flume::Receiver, 35 | local_addr: Option, 36 | peer_addr: Option, 37 | } 38 | #[cfg(feature = "global-ip-stack")] 39 | impl UdpSocket { 40 | pub async fn bind_all() -> io::Result { 41 | let ip_stack = IpStack::get()?; 42 | Self::bind0(ip_stack, None, None).await 43 | } 44 | pub async fn bind(local_addr: A) -> io::Result { 45 | let ip_stack = IpStack::get()?; 46 | let local_addr = local_addr.to_addr()?; 47 | ip_stack.routes().check_bind_ip(local_addr.ip())?; 48 | Self::bind0(ip_stack, Some(local_addr), None).await 49 | } 50 | } 51 | #[cfg(not(feature = "global-ip-stack"))] 52 | impl UdpSocket { 53 | pub async fn bind_all(ip_stack: IpStack) -> io::Result { 54 | Self::bind0(ip_stack, None, None).await 55 | } 56 | pub async fn bind(ip_stack: IpStack, local_addr: A) -> io::Result { 57 | let local_addr = local_addr.to_addr()?; 58 | ip_stack.routes().check_bind_ip(local_addr.ip())?; 59 | Self::bind0(ip_stack, Some(local_addr), None).await 60 | } 61 | } 62 | 63 | impl UdpSocket { 64 | async fn bind0(ip_stack: IpStack, mut local_addr: Option, peer_addr: Option) -> io::Result { 65 | let (packet_sender, packet_receiver) = flume::bounded(ip_stack.config.udp_channel_size); 66 | let _bind_addr = if let Some(addr) = &mut local_addr { 67 | Some(ip_stack.bind(IpNextHeaderProtocols::Udp, addr)?) 
68 | } else { 69 | None 70 | }; 71 | ip_stack.add_udp_socket(local_addr, peer_addr, packet_sender)?; 72 | Ok(Self { 73 | _bind_addr, 74 | ip_stack, 75 | packet_receiver, 76 | local_addr, 77 | peer_addr, 78 | }) 79 | } 80 | } 81 | 82 | impl UdpSocket { 83 | pub fn local_addr(&self) -> io::Result { 84 | self.local_addr.ok_or_else(|| io::Error::from(io::ErrorKind::NotFound)) 85 | } 86 | pub async fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> { 87 | let (len, src, _dst) = self.recv_from_to(buf).await?; 88 | Ok((len, src)) 89 | } 90 | pub async fn send_to(&self, buf: &[u8], addr: A) -> io::Result { 91 | let Some(from) = self.local_addr else { 92 | return Err(io::Error::new(io::ErrorKind::InvalidInput, "need to specify source address")); 93 | }; 94 | self.send_from_to(buf, from, addr).await 95 | } 96 | pub async fn recv_from_to(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr, SocketAddr)> { 97 | let Ok(packet) = self.packet_receiver.recv_async().await else { 98 | return Err(io::Error::from(io::ErrorKind::UnexpectedEof)); 99 | }; 100 | let Some(udp_packet) = pnet_packet::udp::UdpPacket::new(&packet.buf) else { 101 | return Err(io::Error::new(io::ErrorKind::InvalidInput, "not udp")); 102 | }; 103 | let len = udp_packet.payload().len(); 104 | if buf.len() < len { 105 | return Err(io::Error::new(io::ErrorKind::InvalidInput, "buf too short")); 106 | } 107 | buf[..len].copy_from_slice(udp_packet.payload()); 108 | Ok((len, packet.network_tuple.src, packet.network_tuple.dst)) 109 | } 110 | pub async fn send_from_to(&self, buf: &[u8], src: A1, dst: A2) -> io::Result { 111 | self.send_from_to0(buf, src.to_addr()?, dst.to_addr()?).await 112 | } 113 | async fn send_from_to0(&self, buf: &[u8], src: SocketAddr, dst: SocketAddr) -> io::Result { 114 | let src = self.src_addr0(src, dst)?; 115 | if buf.len() > u16::MAX as usize - 8 { 116 | return Err(io::Error::new(io::ErrorKind::InvalidInput, "buf too long")); 117 | } 118 | 119 | let mut data = 
BytesMut::with_capacity(8 + buf.len()); 120 | 121 | data.put_u16(src.port()); 122 | data.put_u16(dst.port()); 123 | data.put_u16(8 + buf.len() as u16); 124 | // checksum 125 | data.put_u16(0); 126 | data.extend_from_slice(buf); 127 | 128 | let checksum = match (src.ip(), dst.ip()) { 129 | (IpAddr::V4(src_ip), IpAddr::V4(dst_ip)) => { 130 | pnet_packet::util::ipv4_checksum(&data, 3, &[], &src_ip, &dst_ip, IpNextHeaderProtocols::Udp) 131 | } 132 | (IpAddr::V6(src_ip), IpAddr::V6(dst_ip)) => { 133 | pnet_packet::util::ipv6_checksum(&data, 3, &[], &src_ip, &dst_ip, IpNextHeaderProtocols::Udp) 134 | } 135 | (_, _) => { 136 | return Err(io::Error::new(io::ErrorKind::InvalidInput, "address error")); 137 | } 138 | }; 139 | 140 | data[6..8].copy_from_slice(&checksum.to_be_bytes()); 141 | let network_tuple = NetworkTuple::new(src, dst, IpNextHeaderProtocols::Udp); 142 | 143 | let packet = TransportPacket::new(data, network_tuple); 144 | self.ip_stack.send_packet(packet).await?; 145 | Ok(buf.len()) 146 | } 147 | fn src_addr(&self, peer_addr: SocketAddr) -> io::Result { 148 | let Some(local_addr) = self.local_addr else { 149 | return Err(io::Error::new(io::ErrorKind::InvalidInput, "need to specify source address")); 150 | }; 151 | self.src_addr0(local_addr, peer_addr) 152 | } 153 | fn src_addr0(&self, mut local_addr: SocketAddr, peer_addr: SocketAddr) -> io::Result { 154 | check_addr(peer_addr)?; 155 | if local_addr.port() == 0 { 156 | return Err(io::Error::new(io::ErrorKind::InvalidInput, "invalid port")); 157 | } 158 | if let Err(e) = check_ip(local_addr.ip()) { 159 | if let Some(v) = self.ip_stack.routes().route(peer_addr.ip()) { 160 | local_addr.set_ip(v); 161 | } else { 162 | Err(e)? 
163 | } 164 | } 165 | Ok(local_addr) 166 | } 167 | } 168 | impl UdpSocket { 169 | pub async fn connect(&mut self, peer_addr: SocketAddr) -> io::Result<()> { 170 | let local_addr = self.src_addr(peer_addr)?; 171 | self.ip_stack 172 | .replace_udp_socket((self.local_addr, self.peer_addr), (Some(local_addr), Some(peer_addr)))?; 173 | self.local_addr = Some(local_addr); 174 | self.peer_addr = Some(peer_addr); 175 | Ok(()) 176 | } 177 | pub async fn connect_from_local(ip_stack: IpStack, local_addr: SocketAddr, peer_addr: SocketAddr) -> io::Result { 178 | Self::bind0(ip_stack, Some(local_addr), Some(peer_addr)).await 179 | } 180 | pub async fn send(&self, buf: &[u8]) -> io::Result { 181 | let Some(from) = self.local_addr else { 182 | return Err(io::Error::new(io::ErrorKind::InvalidInput, "need to specify source address")); 183 | }; 184 | let Some(to) = self.peer_addr else { 185 | return Err(io::Error::new(io::ErrorKind::InvalidInput, "need to specify destination address")); 186 | }; 187 | self.send_from_to(buf, from, to).await 188 | } 189 | pub async fn recv(&self, buf: &mut [u8]) -> io::Result { 190 | let (len, _src, _dst) = self.recv_from_to(buf).await?; 191 | Ok(len) 192 | } 193 | } 194 | impl Drop for UdpSocket { 195 | fn drop(&mut self) { 196 | self.ip_stack.remove_udp_socket(self.local_addr, self.peer_addr); 197 | } 198 | } 199 | --------------------------------------------------------------------------------