├── .gitignore ├── Cargo.lock ├── Cargo.toml ├── LICENSE ├── benches ├── bench_simd.rs.dep └── mod.rs.dep ├── readme.md └── src ├── The Speed Of Lite —— Lite-Tls Specification.md ├── args.rs ├── client ├── inbound │ ├── http.rs │ ├── listener.rs │ ├── mod.rs │ └── socks5.rs ├── mod.rs ├── outbound │ ├── connect.rs │ ├── mod.rs │ ├── quic.rs │ ├── request_cmd.rs │ ├── tcp_tls.rs │ └── trojan_auth.rs ├── run.rs └── utils │ ├── client_server_connection.rs │ ├── client_tcp_stream.rs │ ├── client_udp_stream.rs │ ├── connection_mode.rs │ ├── data_transfer.rs │ ├── mod.rs │ └── rustls_utils.rs ├── main.rs ├── protocol.rs ├── proxy.rs ├── server ├── inbound │ ├── acceptor.rs │ ├── handler.rs │ ├── listener.rs │ ├── mod.rs │ ├── quic.rs │ └── tcp_tls.rs ├── mod.rs ├── outbound │ ├── connector.rs │ ├── fallback.rs │ └── mod.rs ├── run.rs └── utils │ ├── lite_tls.rs │ ├── mod.rs │ ├── rustls_utils.rs │ └── server_udp_stream.rs ├── simd ├── mod.rs └── simd_parse.rs └── utils ├── adapter.rs ├── buffered_recv.rs ├── buffers.rs ├── copy_tcp.rs ├── dns_utils ├── dns_resolver.rs └── mod.rs ├── either_io.rs ├── forked_copy ├── copy_bidirectional.rs ├── copy_buf.rs └── mod.rs ├── glommio_utils ├── copy_bidirectional.rs ├── copy_buf.rs ├── mod.rs └── start_tcp_relay_thread.rs ├── lite_tls ├── The Journey to The Speed of Lite.md ├── error.rs ├── leave_tls.rs ├── lite_tls_stream.rs ├── mod.rs └── tls_relay_buffer.rs ├── macros.rs ├── mix_addr.rs ├── mod.rs ├── timedout_duplex_io.rs ├── udp ├── copy_udp.rs ├── copy_udp_bidirectional.rs ├── mod.rs ├── trojan_udp_stream.rs ├── udp_relay_buffer.rs ├── udp_shutdown.rs └── udp_traits.rs └── wr_tuple.rs /.gitignore: -------------------------------------------------------------------------------- 1 | /target 2 | .VSCodeCounter 3 | .devcontainer -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = 
"trojan-oxide" 3 | version = "0.1.0" 4 | authors = ["3andne <3andne@github.com>"] 5 | edition = "2021" 6 | 7 | # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html 8 | 9 | [dependencies] 10 | tokio = { version = "1.15.0", features = ["full"] } 11 | clap = "2.34.0" 12 | anyhow = "1.0.51" 13 | tracing = "0.1.26" 14 | tracing-subscriber = "0.3.3" 15 | futures = "0.3" 16 | err-derive = "0.3.0" 17 | directories = "4.0.1" 18 | structopt = "0.3.21" 19 | sha2 = "0.10.0" 20 | bytes = "1.0.1" 21 | pin-project-lite = "0.2.7" 22 | lazy_static = "1.4.0" 23 | rustls-native-certs = { version = "0.6.1" } 24 | mimalloc = { version = "*", default-features = false } 25 | fxhash = "0.2.1" 26 | rustls-pemfile = "*" 27 | 28 | [dependencies.quinn] 29 | version = "0.8.0" 30 | optional = true 31 | 32 | [dependencies.quinn-proto] 33 | version = "0.8.0" 34 | optional = true 35 | 36 | [dependencies.tokio-rustls] 37 | version = "0.23.1" 38 | optional = true 39 | 40 | [dependencies.rcgen] 41 | version = "0.8.11" 42 | optional = true 43 | 44 | [target.'cfg(target_os = "linux")'.dependencies] 45 | glommio = "0.6.0" 46 | num_cpus = "*" 47 | 48 | [profile.release] 49 | # codegen-units = 1 50 | incremental = false 51 | # lto = true 52 | opt-level = 3 53 | 54 | [features] 55 | server = [] 56 | client = [] 57 | tcp_tls = ["tokio-rustls"] 58 | lite_tls = ["tokio-rustls"] 59 | zio = [] 60 | quic = ["quinn", "quinn-proto", "rcgen"] 61 | udp = [] 62 | udp_info = [] 63 | debug_info = ["udp_info"] 64 | full = ["server", "client", "tcp_tls", "quic", "udp", "lite_tls"] 65 | dev = ["server", "client", "tcp_tls", "quic", "udp", "lite_tls", "debug_info"] 66 | server_full = ["server", "tcp_tls", "quic", "udp", "lite_tls"] 67 | client_full = ["client", "tcp_tls", "quic", "udp", "lite_tls"] 68 | default = ["full"] 69 | -------------------------------------------------------------------------------- /benches/bench_simd.rs.dep: 
-------------------------------------------------------------------------------- 1 | // #![feature(aarch64_target_feature)] 2 | // #![feature(stdsimd)] 3 | 4 | // use criterion::{black_box, criterion_group, criterion_main, Criterion}; 5 | 6 | // #[path = "../src/simd/simd_parse.rs"] 7 | // mod simd_parse; 8 | // use simd_parse::*; 9 | 10 | // const TESTSET1: [u8; 16] = [ 11 | // b'_', b'_', b'_', b'_', 12 | // b'_', b'_', b'_', b'_', 13 | // b'_', b'_', b'\r', b'\n', 14 | // b'_', b'_', b'_', b'_', 15 | // ]; 16 | 17 | // const TESTSET2: [u8; 32] = [ 18 | // b'_', b'_', b'_', b'_', 19 | // b'_', b'_', b'_', b'_', 20 | // b'_', b'_', b'_', b'_', 21 | // b'_', b'_', b'_', b'_', 22 | // b'_', b'_', b'_', b'_', 23 | // b'_', b'_', b'_', b'_', 24 | // b'_', b'_', b'\r', b'\n', 25 | // b'_', b'_', b'_', b'_', 26 | // ]; 27 | 28 | // const TESTSET3: [u8; 64] = [ 29 | // b'_', b'_', b'_', b'_', 30 | // b'_', b'_', b'_', b'_', 31 | // b'_', b'_', b'_', b'_', 32 | // b'_', b'_', b'_', b'_', 33 | // b'_', b'_', b'_', b'_', 34 | // b'_', b'_', b'_', b'_', 35 | // b'_', b'_', b'_', b'_', 36 | // b'_', b'_', b'_', b'_', 37 | // b'_', b'_', b'_', b'_', 38 | // b'_', b'_', b'_', b'_', 39 | // b'_', b'_', b'_', b'_', 40 | // b'_', b'_', b'_', b'_', 41 | // b'_', b'_', b'_', b'_', 42 | // b'_', b'_', b'_', b'_', 43 | // b'_', b'_', b'\r', b'\n', 44 | // b'_', b'_', b'_', b'_', 45 | // ]; 46 | 47 | // const TESTSET4: [u8; 128] = [ 48 | // b'_', b'_', b'_', b'_', 49 | // b'_', b'_', b'_', b'_', 50 | // b'_', b'_', b'_', b'_', 51 | // b'_', b'_', b'_', b'_', 52 | // b'_', b'_', b'_', b'_', 53 | // b'_', b'_', b'_', b'_', 54 | // b'_', b'_', b'_', b'_', 55 | // b'_', b'_', b'_', b'_', 56 | // b'_', b'_', b'_', b'_', 57 | // b'_', b'_', b'_', b'_', 58 | // b'_', b'_', b'_', b'_', 59 | // b'_', b'_', b'_', b'_', 60 | // b'_', b'_', b'_', b'_', 61 | // b'_', b'_', b'_', b'_', 62 | // b'_', b'_', b'_', b'_', 63 | // b'_', b'_', b'_', b'_', 64 | // b'_', b'_', b'_', b'_', 65 | // b'_', b'_', 
b'_', b'_', 66 | // b'_', b'_', b'_', b'_', 67 | // b'_', b'_', b'_', b'_', 68 | // b'_', b'_', b'_', b'_', 69 | // b'_', b'_', b'_', b'_', 70 | // b'_', b'_', b'_', b'_', 71 | // b'_', b'_', b'_', b'_', 72 | // b'_', b'_', b'_', b'_', 73 | // b'_', b'_', b'_', b'_', 74 | // b'_', b'_', b'_', b'_', 75 | // b'_', b'_', b'_', b'_', 76 | // b'_', b'_', b'_', b'_', 77 | // b'_', b'_', b'_', b'_', 78 | // b'_', b'_', b'\r', b'\n', 79 | // b'_', b'_', b'_', b'_', 80 | // ]; 81 | 82 | 83 | // fn criterion_benchmark(c: &mut Criterion) { 84 | // c.bench_function("simd-neon-16-16", |b| b.iter(|| simd16_wrap(black_box(&TESTSET1)))); 85 | // c.bench_function("simd-neon-8-16", |b| b.iter(|| simd8_wrap(black_box(&TESTSET1)))); 86 | // c.bench_function("scalar-16", |b| b.iter(|| parse_scalar(black_box(&TESTSET1)))); 87 | // c.bench_function("simd-neon-16-32", |b| b.iter(|| simd16_wrap(black_box(&TESTSET2)))); 88 | // c.bench_function("simd-neon-8-32", |b| b.iter(|| simd8_wrap(black_box(&TESTSET2)))); 89 | // c.bench_function("scalar-32", |b| b.iter(|| parse_scalar(black_box(&TESTSET2)))); 90 | // c.bench_function("simd-neon-16-64", |b| b.iter(|| simd16_wrap(black_box(&TESTSET3)))); 91 | // c.bench_function("simd-neon-8-64", |b| b.iter(|| simd8_wrap(black_box(&TESTSET3)))); 92 | // c.bench_function("scalar-64", |b| b.iter(|| parse_scalar(black_box(&TESTSET3)))); 93 | // c.bench_function("simd-neon-16-128", |b| b.iter(|| simd16_wrap(black_box(&TESTSET4)))); 94 | // c.bench_function("simd-neon-8-128", |b| b.iter(|| simd8_wrap(black_box(&TESTSET4)))); 95 | // c.bench_function("scalar-128", |b| b.iter(|| parse_scalar(black_box(&TESTSET4)))); 96 | // } 97 | 98 | // criterion_group!(benches, criterion_benchmark); 99 | // criterion_main!(benches); -------------------------------------------------------------------------------- /benches/mod.rs.dep: -------------------------------------------------------------------------------- 1 | // #![feature(aarch64_target_feature)] 2 | // 
#![feature(stdsimd)] 3 | // pub mod bench_simd; 4 | // // use simd_parse::*; -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | # Trojan-Oxide 2 | 3 | A Rust implementation of Trojan with QUIC tunnel, Lite-TLS and more. 4 | 5 | ## Overview 6 | 7 | 1. Full support for the original [Trojan](https://github.com/trojan-gfw/trojan) protocol, including TCP and UDP traffic. 8 | 2. Pure Rust implementation with no compromise on security or speed. 9 | * Highly efficient [Tokio](https://github.com/tokio-rs/tokio)-based async network I/O. 10 | * Minimized memory consumption. 11 | * Predictable performance with no runtime garbage collector. 12 | * Poll-based UDP traffic. 13 | 3. [QUIC](https://en.wikipedia.org/wiki/QUIC) tunnel. The stealth Trojan implementation is still undetectable in the HTTP/3 era. 14 | 4. [Lite-TLS](https://github.com/3andne/trojan-oxide/wiki/The-Speed-of-Lite). Avoid redundant encryption of TLS traffic without modifying the underlying TLS library. We do things in the Trojan way, i.e., we imitate rather than create. 15 | 5. [Zero-Copy](https://kernel.dk/io_uring.pdf?source=techstories.org) (Linux kernel >= 5.8 is required). With Lite-TLS enabled, we can achieve maximum efficiency on both the client and server sides. Up to [60% improvement](https://github.com/frevib/io_uring-echo-server/blob/master/benchmarks/benchmarks.md) has been observed in a TCP echo server based on io_uring. 16 | 17 | ## License 18 | 19 | [GPL-3](https://github.com/3andne/trojan-oxide/blob/main/LICENSE) 20 | 21 | ## Examples 22 | 23 | ### Install Rust 24 | 25 | Please follow the [instructions](https://www.rust-lang.org/tools/install).
26 | 27 | ### Build Trojan-Oxide From Source 28 | 29 | ``` 30 | git clone https://github.com/3andne/trojan-oxide.git && cd ./trojan-oxide 31 | cargo build --release 32 | ``` 33 | 34 | The executable binary file is `./target/release/trojan-oxide`. 35 | 36 | #### Build Selected Features 37 | 38 | You can select features according to your needs. The default feature set (`full`) builds both the server and the client, so pass `--no-default-features` when building only one of them. 39 | 40 | ##### Server Only 41 | 42 | ``` 43 | cargo build --release --no-default-features --features server_full 44 | ``` 45 | 46 | ##### Client Only 47 | 48 | ``` 49 | cargo build --release --no-default-features --features client_full 50 | ``` 51 | 52 | ##### Zero Copy Feature 53 | 54 | This feature is disabled by default since it only works on Linux with a kernel >= 5.8. The following commands build it. 55 | 56 | ``` 57 | cargo build --release --no-default-features --features client_full,zio 58 | cargo build --release --no-default-features --features server_full,zio 59 | ``` 60 | 61 | ### Run Server 62 | 63 | Suppose you have a server `your.website.com`. 64 | 65 | * Your TLS certificate is in "/path/to/cert/fullchain.cert". 66 | 67 | * Your TLS private key is in "/path/to/key/private.key". 68 | 69 | * You want the server to listen on port `443` and redirect unauthenticated traffic to port `80`. 70 | * You set the password to `your_password`. **If your password contains '$', please write it as '\\$'.** 71 | 72 | Then you should start the server by: 73 | 74 | ``` 75 | ./target/release/trojan-oxide -s -w "your_password" -k "/path/to/key/private.key" -c "/path/to/cert/fullchain.cert" -u "your.website.com" -x 443 -f 80 76 | ``` 77 | 78 | Note that rustls (the underlying TLS library) **doesn't support ECC keys** at the moment. Please follow the [instructions](https://github.com/rustls/rustls/issues/767) if you have a pair of ECC keys. 79 | 80 | ### Run Client 81 | 82 | If you have a `tcp-tls` trojan service on `your.website.com:443` with the password `your_password`.
You can start your client by: 83 | 84 | ``` 85 | ./target/release/trojan-oxide -w "your_password" -u your.website.com -x 443 -m t 86 | ``` 87 | 88 | * the default tunnel is TCP-TLS 89 | 90 | * use `-m q` if you want to use the QUIC tunnel 91 | * use `-m l` if you want to use the Lite-TLS tunnel 92 | 93 | * you can also specify your server IP address by: 94 | 95 | ``` 96 | ./target/release/trojan-oxide -w "your_password" -u your.website.com -d 114.51.4.191 -x 443 -m t 97 | ``` 98 | 99 | * **The default http and socks5 ports are `8888` and `8889` respectively. Use `-h` and `-5` to specify other ports.** 100 | 101 | ### Run Zero Copy Endpoints 102 | 103 | Note that this feature only works on Linux kernel >= 5.8. Please build the client/server with the `zio` feature first. 104 | 105 | Then start the client in Lite-TLS mode. 106 | 107 | ``` 108 | ./target/release/trojan-oxide -w "your_password" -u your.website.com -d 114.51.4.191 -x 443 -m l 109 | ``` 110 | 111 | You don't need to configure the server. 112 | 113 | ### Manual 114 | 115 | ``` 116 | ./target/release/trojan-oxide --help 117 | ``` 118 | 119 | ``` 120 | USAGE: 121 | trojan-oxide [FLAGS] [OPTIONS] --password [remote-socket-addr] 122 | 123 | FLAGS: 124 | --help 125 | Prints help information 126 | 127 | -s, --server 128 | whether to start as server 129 | 130 | -V, --version 131 | Prints version information 132 | 133 | 134 | OPTIONS: 135 | --ca 136 | 137 | 138 | -c, --cert 139 | TLS certificate in PEM format 140 | 141 | -m, --connection-mode 142 | Connection Mode: 143 | 144 | - t (for tcp-tls) 145 | 146 | - q (for quic) 147 | 148 | - l (for lite-tls) [default: t] 149 | -f, --fallback-port 150 | port to re-direct unauthenticated connections [default: 0] 151 | 152 | -k, --key 153 | TLS private key in PEM format 154 | 155 | -h, --http_port 156 | client http proxy port [default: 8888] 157 | 158 | -5, --socks5_port 159 | client socks5 proxy port [default: 8889] 160 | 161 | -l, --log-level 162 | Log level (from least to most verbose):
163 | 164 | error < warn < info < debug < trace [default: info] 165 | -w, --password 166 | the password to authenticate connections 167 | 168 | -u, --server-hostname 169 | Server Name Indication (sni), or Hostname [default: localhost] 170 | 171 | -d, --server-ip 172 | server ip address [default: ] 173 | 174 | -x, --server-port 175 | server proxy port [default: 443] 176 | ``` 177 | -------------------------------------------------------------------------------- /src/The Speed Of Lite —— Lite-Tls Specification.md: -------------------------------------------------------------------------------- 1 | # The Speed Of Lite —— Lite-Tls Specification 2 | 3 | ## 术语定义 4 | 5 | ### TLS Packet Specification 6 | 7 | 我们需要简单了解一下tls(1.2/1.3)的包定义: 8 | ``` 9 | +-------------+-------------+--------+----------+ 10 | | Record Type | version | Length | Payload | 11 | +-------------+------+------+-------------------+ 12 | | 1 | 0x03 | 0x03 | 2 | Variable | 13 | +-------------+------+------+--------+----------+ 14 | ``` 15 | 16 | 根据标准,包头的Record Type有 17 | * 0x14: Change Cipher Spec 18 | * 0x16: Handshake 19 | * 0x17: Application Data 20 | * ... 21 | 22 | 其中,0x14和0x16会在握手过程中被使用,而0x17则是数据传输使用的类型,也就是可以被直接转发的包类型。 23 | 24 | ### 终端定义 25 | 26 | * `user` - 用户 27 | * `client` - 代理客户端 28 | * `server` - 代理服务端 29 | * `target` - 目标网站 30 | 31 | ### 其他术语 32 | 33 | `一手包`和`二手包`:从`user/target`那里直接获得的包是`一手包`,从`server/client`那里获得的包是`二手包`。例如,对于`client`而言,从`user`发来的包是`一手包`,从`server`发来的包是`二手包`。 34 | 35 | ## 握手流程 36 | 37 | ``` 38 | --->: tcp traffic 39 | ===>: tls over tcp traffic 40 | #################################################################### 41 | ---- 0x17 ---> [client] [server] 42 | 43 | [client] ==== 0x17 ====> [server] 44 | ^ the first 0x17 in this stream 45 | 46 | <== ...some traffics... 
==> 47 | 48 | [client] [server] <-{..., 0x17}-- 49 | ^ active side *1 50 | 51 | [client] <={..., 0xff}== [server]{0x17} < cached *3 52 | passive side *2 ^ ^ a 0xff is appended 53 | 54 | [client]{...} [server]{0x17} 55 | ^ cached *4 56 | 57 | [client]{...} == 0xff => [server]{0x17} 58 | ^ a 0xff is returned *5 59 | 60 | [client]{...} [server]{0x17} 61 | ^ quit tls ^ quit tls *6 62 | 63 | [client] <- Plain Tcp -> [server] 64 | 65 | [client]{...} <- 0x17 -- [server] 66 | 67 | <-{..., 0x17}-- [client] [server] 68 | ...... 69 | ``` 70 | ### 注释: 71 | 72 | 1. active side: 第二个收到`一手0x17`的endpoint进入active mode 73 | 2. passive side: 收到`0xff`的endpoint进入passive mode 74 | 3. active side会把收到的0x17先缓存起来,在0x17之前、往往有与尚未发送的0x16/0x14包,我们把`0xff`包附在这些pending包的尾部,将这些包一起发往passive side,表示之后随时可以退出tls隧道 75 | 4. passive side收到`3`发过来的包之后,会验证`0xff`(之后丢弃),并把它前面的包缓存起来,等0x17到达后,一同发给`user`,否则会导致`user`(浏览器)因为收到的包不完整而出现错误。 76 | 5. passive side验证完`0xff`后,会返回一个`0xff`,表示自己已经不会再通过tls隧道接收数据,之后便退出`tls`隧道 77 | 6. 当active side收到返回的`0xff`后,便也退出tls隧道 78 | 7. 之后active side和passive side之间便通过tcp直接通信,active side把之前缓存的`0x17`发给passive side,passive side收到`0x17`后,连同之前缓存的包一起一次性发给`user`。之后整个过程结束。 79 | 80 | -------------------------------------------------------------------------------- /src/args.rs: -------------------------------------------------------------------------------- 1 | #[cfg(feature = "client")] 2 | use crate::client::ConnectionMode; 3 | use crate::protocol::HASH_LEN; 4 | use sha2::{Digest, Sha224}; 5 | use std::fmt::Write; 6 | use std::net::SocketAddr; 7 | use std::path::PathBuf; 8 | use std::sync::Arc; 9 | use structopt::StructOpt; 10 | use tokio::sync::broadcast; 11 | 12 | fn parse_log_level(l: &str) -> tracing::Level { 13 | match &l.to_lowercase()[..] 
{ 14 | "info" => tracing::Level::INFO, 15 | "debug" => tracing::Level::DEBUG, 16 | "warn" => tracing::Level::WARN, 17 | "error" => tracing::Level::ERROR, 18 | "trace" => tracing::Level::TRACE, 19 | _ => tracing::Level::INFO, 20 | } 21 | } 22 | 23 | #[cfg(feature = "client")] 24 | fn parse_connection_mode(l: &str) -> ConnectionMode { 25 | use ConnectionMode::*; 26 | #[allow(unreachable_patterns)] 27 | match &l.to_lowercase()[..] { 28 | #[cfg(feature = "tcp_tls")] 29 | "tcp-tls" => TcpTLS, 30 | #[cfg(feature = "tcp_tls")] 31 | "t" => TcpTLS, 32 | #[cfg(feature = "tcp_tls")] 33 | "tcp" => TcpTLS, 34 | #[cfg(feature = "tcp_tls")] 35 | "tcp_tls" => TcpTLS, 36 | #[cfg(feature = "quic")] 37 | "quic" => Quic, 38 | #[cfg(feature = "quic")] 39 | "q" => Quic, 40 | #[cfg(feature = "lite_tls")] 41 | "l" | "lite" | "lite-tls" | "lite_tls" => LiteTLS, 42 | #[cfg(feature = "tcp_tls")] 43 | _ => TcpTLS, 44 | #[cfg(feature = "lite_tls")] 45 | #[allow(unreachable_patterns)] 46 | _ => LiteTLS, 47 | #[cfg(feature = "quic")] 48 | #[allow(unreachable_patterns)] 49 | _ => Quic, 50 | } 51 | } 52 | 53 | #[cfg(feature = "client")] 54 | fn into_local_addr(l: &str) -> SocketAddr { 55 | ("127.0.0.1:".to_owned() + l).parse::<SocketAddr>().unwrap() 56 | } 57 | /// Parses a decimal port number; panics on a non-digit character or a value above 65535. 58 | fn into_u16(l: &str) -> u16 { 59 | let mut res: u16 = 0; 60 | for i in l.bytes() { 61 | if i.is_ascii_digit() { 62 | res = res.checked_mul(10).and_then(|r| r.checked_add((i - b'0') as u16)).expect("invalid port value"); 63 | } else { 64 | panic!("invalid port value") 65 | } 66 | } 67 | res 68 | } 69 | 70 | fn password_to_hash(s: &str) -> String { 71 | let mut hasher = Sha224::new(); 72 | hasher.update(s); 73 | let h = hasher.finalize(); 74 | let mut s = String::with_capacity(HASH_LEN); 75 | for i in h { 76 | write!(&mut s, "{:02x}", i).unwrap(); 77 | } 78 | s 79 | } 80 | 81 | #[derive(StructOpt, Clone)] 82 | #[cfg_attr(feature = "debug_info", derive(Debug))] 83 | #[structopt(name = "basic")] 84 | pub struct Opt { 85 | /// client http proxy port 86 | #[cfg(feature = "client")] 87 | #[structopt(short = "h", long = "http_port", default_value = "8888",
parse(from_str = into_local_addr))] 88 | pub local_http_addr: SocketAddr, 89 | 90 | /// client socks5 proxy port 91 | #[cfg(feature = "client")] 92 | #[structopt(short = "5", long = "socks5_port", default_value = "8889", parse(from_str = into_local_addr))] 93 | pub local_socks5_addr: SocketAddr, 94 | 95 | /// Log level (from least to most verbose): 96 | /// 97 | /// error < warn < info < debug < trace 98 | #[structopt(short = "l", long, default_value = "info", parse(from_str = parse_log_level))] 99 | pub log_level: tracing::Level, 100 | 101 | #[structopt(parse(from_os_str), long = "ca")] 102 | pub ca: Option<PathBuf>, 103 | 104 | /// Server Name Indication (sni), or Hostname. 105 | #[structopt(short = "u", long, default_value = "localhost")] 106 | pub server_hostname: String, 107 | 108 | /// server proxy port 109 | #[structopt(short = "x", long, default_value = "443", parse(from_str = into_u16))] 110 | pub server_port: u16, 111 | 112 | /// server ip address 113 | #[structopt(short = "d", long, default_value = "")] 114 | pub server_ip: String, 115 | 116 | /// whether to start as server 117 | #[structopt(short, long)] 118 | pub server: bool, 119 | 120 | /// TLS private key in PEM format 121 | #[cfg(feature = "server")] 122 | #[structopt(parse(from_os_str), short = "k", long = "key", requires = "cert")] 123 | pub key: Option<PathBuf>, 124 | 125 | /// TLS certificate in PEM format 126 | #[cfg(feature = "server")] 127 | #[structopt(parse(from_os_str), short = "c", long = "cert", requires = "key")] 128 | pub cert: Option<PathBuf>, 129 | 130 | /// the password to authenticate connections 131 | #[structopt(short = "w", long, parse(from_str = password_to_hash))] 132 | pub password: String, 133 | 134 | /// port to re-direct unauthenticated connections 135 | #[cfg(feature = "server")] 136 | #[structopt(short = "f", long, default_value = "0", parse(from_str = into_u16))] 137 | pub fallback_port: u16, 138 | 139 | /// Connection Mode: 140 | /// 141 | /// - t (for tcp-tls) 142 | /// 143 | /// - q (for quic)
144 | /// 145 | /// - l (for lite-tls) 146 | #[cfg(feature = "client")] 147 | #[structopt(short = "m", long, default_value = "t", parse(from_str = parse_connection_mode))] 148 | pub connection_mode: ConnectionMode, 149 | 150 | pub remote_socket_addr: Option, 151 | } 152 | 153 | #[cfg_attr(feature = "debug_info", derive(Debug))] 154 | pub struct TrojanContext { 155 | pub options: Arc, 156 | pub shutdown: broadcast::Receiver<()>, 157 | } 158 | 159 | impl TrojanContext { 160 | pub fn clone_with_signal(&self, shutdown: broadcast::Receiver<()>) -> Self { 161 | Self { 162 | options: self.options.clone(), 163 | shutdown, 164 | } 165 | } 166 | } 167 | -------------------------------------------------------------------------------- /src/client/inbound/http.rs: -------------------------------------------------------------------------------- 1 | use crate::{ 2 | client::utils::new_client_tcp_stream, 3 | utils::ConnectionRequest, 4 | utils::{MixAddrType, ParserError}, 5 | }; 6 | 7 | use anyhow::{Error, Result}; 8 | use futures::Future; 9 | // use futures::future; 10 | // use std::io::IoSlice; 11 | // use std::pin::Pin; 12 | use tokio::io::*; 13 | use tokio::net::TcpStream; 14 | use tracing::*; 15 | 16 | use crate::client::utils::ClientConnectionRequest; 17 | 18 | use super::{listener::RequestFromClient, ClientRequestAcceptResult}; 19 | 20 | pub struct HttpRequest { 21 | is_https: bool, 22 | addr: MixAddrType, 23 | cursor: usize, 24 | inbound: Option, 25 | } 26 | 27 | const HEADER0: &'static [u8] = b"GET / HTTP/1.1\r\nHost: "; 28 | const HEADER1: &'static [u8] = b"\r\nConnection: keep-alive\r\n\r\n"; 29 | 30 | impl HttpRequest { 31 | fn set_stream_type(&mut self, buf: &Vec) -> Result<(), ParserError> { 32 | if buf.len() < 4 { 33 | return Err(ParserError::Incomplete( 34 | "HttpRequest::set_stream_type".into(), 35 | )); 36 | } 37 | 38 | if &buf[..4] == b"GET " { 39 | self.is_https = false; 40 | self.cursor = 4; 41 | return Ok(()); 42 | } 43 | 44 | if buf.len() < 8 { 45 | return 
Err(ParserError::Incomplete( 46 | "HttpRequest::set_stream_type".into(), 47 | )); 48 | } 49 | 50 | if &buf[..8] == b"CONNECT " { 51 | self.is_https = true; 52 | self.cursor = 8; 53 | return Ok(()); 54 | } 55 | 56 | return Err(ParserError::Invalid("HttpRequest::set_stream_type".into())); 57 | } 58 | 59 | fn set_host(&mut self, buf: &Vec) -> Result<(), ParserError> { 60 | #[cfg(feature = "debug_info")] 61 | debug!("set_host entered"); 62 | while self.cursor < buf.len() && buf[self.cursor] == b' ' { 63 | self.cursor += 1; 64 | } 65 | if !self.is_https { 66 | if self.cursor + 7 < buf.len() { 67 | if &buf[self.cursor..self.cursor + 7].to_ascii_lowercase()[..] == b"http://" { 68 | self.cursor += 7; 69 | } 70 | } else { 71 | return Err(ParserError::Incomplete("HttpRequest::set_host".into())); 72 | } 73 | } 74 | 75 | let start = self.cursor; 76 | let mut end = start; 77 | while end < buf.len() && buf[end] != b' ' && buf[end] != b'/' { 78 | end += 1; 79 | } 80 | 81 | if end == buf.len() { 82 | return Err(ParserError::Incomplete("HttpRequest::set_host".into())); 83 | } 84 | 85 | self.addr = MixAddrType::from_http_header(self.is_https, &buf[start..end])?; 86 | return Ok(()); 87 | } 88 | 89 | fn parse(&mut self, buf: &mut Vec) -> Result<(), ParserError> { 90 | #[cfg(feature = "debug_info")] 91 | debug!("parsing: {:?}", String::from_utf8(buf.clone())); 92 | if self.cursor == 0 { 93 | self.set_stream_type(buf)?; 94 | } 95 | 96 | #[cfg(feature = "debug_info")] 97 | debug!("stream is https: {}", self.is_https); 98 | 99 | if self.addr.is_none() { 100 | match self.set_host(buf) { 101 | Ok(_) => { 102 | #[cfg(feature = "debug_info")] 103 | debug!("stream target host: {:?}", self.addr); 104 | } 105 | err @ Err(_) => { 106 | #[cfg(feature = "debug_info")] 107 | debug!("stream target host err: {:?}", err); 108 | return err; 109 | } 110 | } 111 | } 112 | 113 | // `integrity` check 114 | if &buf[buf.len() - 4..] 
== b"\r\n\r\n" { 115 | #[cfg(feature = "debug_info")] 116 | debug!("integrity test passed"); 117 | return Ok(()); 118 | } 119 | 120 | for i in 0..4 { 121 | buf[i] = buf[buf.len() - 4 + i]; 122 | } 123 | 124 | unsafe { 125 | buf.set_len(4); 126 | } 127 | Err(ParserError::Incomplete("HttpRequest::parse".into())) 128 | } 129 | 130 | async fn impl_accept(&mut self) -> Result { 131 | let mut buffer = Vec::with_capacity(200); 132 | let mut inbound = self.inbound.take().unwrap(); 133 | loop { 134 | let read = inbound.read_buf(&mut buffer).await?; 135 | if read != 0 { 136 | match self.parse(&mut buffer) { 137 | Ok(_) => { 138 | #[cfg(feature = "debug_info")] 139 | debug!("http request parsed"); 140 | break; 141 | } 142 | Err(e @ ParserError::Invalid(_)) => { 143 | return Err(Error::new(e)); 144 | } 145 | _ => (), 146 | } 147 | } else { 148 | return Err(Error::new(ParserError::Invalid( 149 | "HttpRequest::accept unable to accept before EOF".into(), 150 | ))); 151 | } 152 | } 153 | 154 | let http_p0 = if self.is_https { 155 | inbound 156 | .write_all(b"HTTP/1.1 200 Connection established\r\n\r\n") 157 | .await?; 158 | inbound.flush().await?; 159 | debug!("https packet 0 sent"); 160 | None 161 | } else { 162 | let (host, port) = self.addr.as_host(); 163 | Some( 164 | [ 165 | HEADER0, 166 | host.as_bytes(), 167 | &[':' as u8], 168 | port.to_string().as_bytes(), 169 | HEADER1, 170 | ] 171 | .concat(), 172 | ) 173 | // let bufs = [ 174 | // IoSlice::new(HEADER0), 175 | // IoSlice::new(self.host_raw.as_bytes()), 176 | // IoSlice::new(HEADER1), 177 | // ]; 178 | 179 | // future::poll_fn(|cx| writer.as_mut().poll_write_vectored(cx, &bufs[..])) 180 | // .await 181 | // .map_err(|e| Box::new(e))?; 182 | 183 | // debug!("http packet 0 sent"); 184 | }; 185 | 186 | Ok(ConnectionRequest::TCP(new_client_tcp_stream( 187 | inbound, http_p0, 188 | ))) 189 | } 190 | } 191 | 192 | impl RequestFromClient for HttpRequest { 193 | type Accepting<'a> = impl Future + Send; 194 | 195 | fn 
accept<'a>(mut self) -> Self::Accepting<'a> { 196 | async move { Ok::<_, Error>((self.impl_accept().await?, self.addr)) } 197 | } 198 | 199 | fn new(inbound: TcpStream) -> Self { 200 | Self { 201 | is_https: false, 202 | addr: MixAddrType::None, 203 | cursor: 0, 204 | inbound: Some(inbound), 205 | } 206 | } 207 | } 208 | -------------------------------------------------------------------------------- /src/client/inbound/listener.rs: -------------------------------------------------------------------------------- 1 | use crate::{ 2 | args::TrojanContext, 3 | client::{ 4 | outbound::forward, 5 | utils::{ClientConnectionRequest, ClientServerConnection}, 6 | ConnectionMode, 7 | }, 8 | or_continue, try_recv, 9 | utils::MixAddrType, 10 | }; 11 | 12 | #[cfg(feature = "quic")] 13 | use crate::client::outbound::quic::*; 14 | #[cfg(any(feature = "tcp_tls", feature = "lite_tls"))] 15 | use crate::client::outbound::tcp_tls::*; 16 | use anyhow::Result; 17 | use futures::TryFutureExt; 18 | use std::{future::Future, net::SocketAddr, sync::Arc}; 19 | use tokio::{ 20 | net::{TcpListener, TcpStream}, 21 | sync::{broadcast, oneshot}, 22 | }; 23 | use tracing::*; 24 | 25 | pub type ClientRequestAcceptResult = Result<(ClientConnectionRequest, MixAddrType)>; 26 | 27 | pub trait RequestFromClient { 28 | type Accepting<'a>: Future + Send; 29 | 30 | fn new(inbound: TcpStream) -> Self; 31 | fn accept<'a>(self) -> Self::Accepting<'a>; 32 | } 33 | 34 | pub async fn user_endpoint_listener( 35 | service_addr: SocketAddr, 36 | mut context: TrojanContext, 37 | ) -> Result<()> 38 | where 39 | Acceptor: RequestFromClient + Send + 'static, 40 | { 41 | let (shutdown_tx, shutdown) = broadcast::channel::<()>(1); 42 | let service_listener = TcpListener::bind(&service_addr).await?; 43 | 44 | #[cfg(any(feature = "tcp_tls", feature = "lite_tls"))] 45 | let tls_config = Arc::new(tls_client_config().await); 46 | 47 | #[cfg(feature = "quic")] 48 | let (task_tx, task_rx) = tokio::sync::mpsc::channel(20); 49 | 
#[cfg(feature = "quic")] 50 | tokio::spawn(quic_connection_daemon( 51 | context.clone_with_signal(shutdown), 52 | task_rx, 53 | )); 54 | 55 | loop { 56 | try_recv!(broadcast, context.shutdown); 57 | let (stream, _) = service_listener.accept().await?; 58 | debug!("accepted http: {:?}", stream); 59 | let incoming: _ = Acceptor::new(stream).accept(); 60 | let new_context = context.clone_with_signal(shutdown_tx.subscribe()); 61 | match &context.options.connection_mode { 62 | #[cfg(feature = "tcp_tls")] 63 | ConnectionMode::TcpTLS => { 64 | let connecting: _ = TrojanTcpTlsConnector::new(tls_config.clone(), false) 65 | .connect(context.options.clone()); 66 | tokio::spawn( 67 | forward(new_context, incoming, connecting) 68 | .map_err(|e| error!("[tcp-tls]forward failed: {:?}", e)), 69 | ); 70 | } 71 | #[cfg(feature = "lite_tls")] 72 | ConnectionMode::LiteTLS => { 73 | let connecting: _ = TrojanTcpTlsConnector::new(tls_config.clone(), true) 74 | .connect(context.options.clone()); 75 | tokio::spawn( 76 | forward(new_context, incoming, connecting) 77 | .map_err(|e| error!("[lite]forward failed: {:?}", e)), 78 | ); 79 | } 80 | #[cfg(feature = "quic")] 81 | ConnectionMode::Quic => { 82 | let (conn_ret_tx, conn_ret_rx) = oneshot::channel(); 83 | or_continue!(task_tx.send(conn_ret_tx).await); 84 | tokio::spawn(forward(new_context, incoming, async move { 85 | Ok(ClientServerConnection::Quic(conn_ret_rx.await??)) 86 | })); 87 | } 88 | } 89 | } 90 | Ok(()) 91 | } 92 | -------------------------------------------------------------------------------- /src/client/inbound/mod.rs: -------------------------------------------------------------------------------- 1 | mod http; 2 | mod listener; 3 | mod socks5; 4 | 5 | pub use http::HttpRequest; 6 | pub use listener::{user_endpoint_listener, ClientRequestAcceptResult, RequestFromClient}; 7 | pub use socks5::Socks5Request; 8 | -------------------------------------------------------------------------------- /src/client/inbound/socks5.rs: 
-------------------------------------------------------------------------------- 1 | use crate::{ 2 | client::utils::{new_client_tcp_stream, ClientConnectionRequest}, 3 | expect_buf_len, 4 | utils::{ConnectionRequest, MixAddrType, ParserError}, 5 | }; 6 | 7 | #[cfg(feature = "udp")] 8 | use crate::client::utils::Socks5UdpStream; 9 | use futures::Future; 10 | #[cfg(feature = "udp")] 11 | use std::net::SocketAddr; 12 | #[cfg(feature = "udp")] 13 | use tokio::net::UdpSocket; 14 | #[cfg(feature = "udp")] 15 | use tokio::sync::oneshot; 16 | 17 | use anyhow::{Error, Result}; 18 | use tokio::io::*; 19 | use tokio::net::TcpStream; 20 | #[cfg(feature = "debug_info")] 21 | use tracing::*; 22 | 23 | use super::{listener::RequestFromClient, ClientRequestAcceptResult}; 24 | 25 | const SOCKS_VERSION_INDEX: usize = 0; 26 | const NUM_SUPPORTED_AUTH_METHOD_INDEX: usize = 1; 27 | const CONNECTION_TYPE_INDEX: usize = 1; 28 | const ADDR_TYPE_INDEX: usize = 3; 29 | // const LEN_OF_ADDR_INDEX: usize = 4; 30 | const PHASE1_SERVER_REPLY: [u8; 2] = [0x05, 0x00]; 31 | const PHASE2_SERVER_REPLY: [u8; 3] = [0x05, 0x00, 0x00]; 32 | 33 | pub struct Socks5Request { 34 | phase: Sock5ParsePhase, 35 | is_udp: bool, 36 | addr: MixAddrType, 37 | inbound: Option<TcpStream>, 38 | } 39 | 40 | enum Sock5ParsePhase { 41 | P1ClientHello, 42 | P2ClientRequest, 43 | } 44 | 45 | impl Socks5Request { 46 | async fn impl_accept(&mut self) -> Result<ClientConnectionRequest> { 47 | let mut buffer = Vec::with_capacity(200); 48 | let mut inbound = self.inbound.take().unwrap(); 49 | loop { 50 | let read = inbound.read_buf(&mut buffer).await?; 51 | if read != 0 { 52 | match self.parse(&mut buffer) { 53 | Ok(_) => { 54 | use Sock5ParsePhase::*; 55 | match self.phase { 56 | P1ClientHello => { 57 | inbound.write_all(&PHASE1_SERVER_REPLY).await?; 58 | #[cfg(feature = "debug_info")] 59 | debug!("socks5 Phase 1 parsed"); 60 | self.phase = P2ClientRequest; 61 | // reset the buffer before reading the phase-2 request (`clear` is safe; no `unsafe` needed) 62 | // NOTE(review): any bytes the client pipelined after the hello are discarded here 63 | buffer.clear(); 64 | 65 | } 66 | P2ClientRequest => { 67 |
#[cfg(feature = "debug_info")] 68 | debug!("socks5 Phase 2 parsed"); 69 | break; 70 | } 71 | } 72 | } 73 | Err(e @ ParserError::Invalid(_)) => { 74 | return Err(Error::new(e)); 75 | } 76 | _ => (), 77 | } 78 | } else { 79 | return Err(Error::new(ParserError::Invalid( 80 | "Socks5Request::accept unable to accept before EOF".into(), 81 | ))); 82 | } 83 | } 84 | 85 | let mut buf = Vec::with_capacity(3 + 1 + 16 + 2); 86 | buf.extend_from_slice(&PHASE2_SERVER_REPLY); 87 | 88 | match self.is_udp { 89 | false => { 90 | MixAddrType::from(&inbound.local_addr()?).write_buf(&mut buf); 91 | inbound.write_all(&buf).await?; 92 | Ok(ConnectionRequest::TCP(new_client_tcp_stream(inbound, None))) 93 | } 94 | #[cfg(feature = "udp")] 95 | true => { 96 | let local_ip = inbound.local_addr()?.ip(); 97 | let server_udp_socket = UdpSocket::bind(SocketAddr::new(local_ip, 0)).await?; 98 | MixAddrType::from(&server_udp_socket.local_addr()?).write_buf(&mut buf); 99 | inbound.write_all(&buf).await?; 100 | let (stream_reset_signal_tx, stream_reset_signal_rx) = oneshot::channel(); 101 | 102 | tokio::spawn(async move { 103 | let mut dummy = [0u8; 3]; 104 | let _ = inbound.read(&mut dummy).await; 105 | let _ = stream_reset_signal_tx.send(()); 106 | }); 107 | 108 | let udp_stream = Socks5UdpStream::new(server_udp_socket, stream_reset_signal_rx); 109 | Ok(ConnectionRequest::UDP(udp_stream)) 110 | } 111 | #[cfg(not(feature = "udp"))] 112 | _ => { 113 | panic!("Udp not included, re-compile to include") 114 | } 115 | } 116 | } 117 | 118 | fn parse(&mut self, buf: &Vec) -> Result<(), ParserError> { 119 | use Sock5ParsePhase::*; 120 | match self.phase { 121 | P1ClientHello => { 122 | expect_buf_len!(buf, 2, "Sock5ParsePhase::parse phase 1 incomplete[1]"); 123 | if buf[SOCKS_VERSION_INDEX] != 5 { 124 | return Err(ParserError::Invalid( 125 | "Socks5Request::parse only support socks v5".into(), 126 | )); 127 | } 128 | let num = buf[NUM_SUPPORTED_AUTH_METHOD_INDEX]; 129 | 130 | let expected_len = 2 + num as 
usize; 131 | expect_buf_len!( 132 | buf, 133 | expected_len, 134 | "Sock5ParsePhase::parse phase 1 incomplete[2]" 135 | ); 136 | 137 | for &method in buf[2..expected_len].iter() { 138 | if method == 0 { 139 | return Ok(()); 140 | } 141 | } 142 | return Err(ParserError::Invalid( 143 | "Socks5Request::parse method invalid".into(), 144 | )); 145 | } 146 | P2ClientRequest => { 147 | expect_buf_len!(buf, 5, "Sock5ParsePhase::parse phase 2 incomplete[1]"); 148 | if buf[SOCKS_VERSION_INDEX] != 5 { 149 | return Err(ParserError::Invalid( 150 | "Socks5Request::parse only support socks v5".into(), 151 | )); 152 | } 153 | 154 | match buf[CONNECTION_TYPE_INDEX] { 155 | 0x01 => { 156 | self.is_udp = false; 157 | } 158 | 0x03 => { 159 | self.is_udp = true; 160 | } 161 | _ => { 162 | return Err(ParserError::Invalid( 163 | "Socks5Request::parse invalid connection type".into(), 164 | )); 165 | } 166 | } 167 | 168 | self.addr = MixAddrType::from_encoded_bytes(&buf[ADDR_TYPE_INDEX..])?.0; 169 | 170 | return Ok(()); 171 | } 172 | } 173 | } 174 | } 175 | 176 | impl RequestFromClient for Socks5Request { 177 | type Accepting<'a> = impl Future + Send; 178 | 179 | fn new(inbound: TcpStream) -> Self { 180 | Self { 181 | phase: Sock5ParsePhase::P1ClientHello, 182 | is_udp: false, 183 | addr: MixAddrType::None, 184 | inbound: Some(inbound), 185 | } 186 | } 187 | 188 | fn accept<'a>(mut self) -> Self::Accepting<'a> { 189 | async { Ok::<_, Error>((self.impl_accept().await?, self.addr)) } 190 | } 191 | } 192 | -------------------------------------------------------------------------------- /src/client/mod.rs: -------------------------------------------------------------------------------- 1 | mod inbound; 2 | mod outbound; 3 | mod run; 4 | mod utils; 5 | 6 | pub use run::run_client; 7 | pub use utils::ConnectionMode; -------------------------------------------------------------------------------- /src/client/outbound/connect.rs: 
-------------------------------------------------------------------------------- 1 | use super::trojan_auth::trojan_auth; 2 | 3 | #[cfg(feature = "quic")] 4 | use crate::client::outbound::quic::send_echo; 5 | #[cfg(feature = "udp")] 6 | use crate::client::utils::relay_udp; 7 | 8 | use crate::client::inbound::ClientRequestAcceptResult; 9 | use crate::{ 10 | args::TrojanContext, 11 | client::{ 12 | outbound::request_cmd::ClientRequestCMD, 13 | utils::{relay_tcp, ClientServerConnection}, 14 | }, 15 | utils::ConnectionRequest, 16 | }; 17 | use anyhow::Result; 18 | use std::{ 19 | future::Future, 20 | sync::atomic::{AtomicUsize, Ordering}, 21 | }; 22 | use tracing::*; 23 | 24 | static TCP_CONNECTION_COUNTER: AtomicUsize = AtomicUsize::new(0); 25 | static UDP_CONNECTION_COUNTER: AtomicUsize = AtomicUsize::new(0); 26 | 27 | pub async fn forward( 28 | context: TrojanContext, 29 | incomming: Incomming, 30 | connecting: Connecting, 31 | ) -> Result<()> 32 | where 33 | Incomming: Future + Send, 34 | Connecting: Future> + Send, 35 | { 36 | let (conn_req, addr) = incomming.await.map_err(|e| { 37 | error!("forward error: {:#}", e); 38 | e 39 | })?; 40 | 41 | let mut outbound = connecting.await.map_err(|e| { 42 | error!("forward error: {:#}", e); 43 | e 44 | })?; 45 | 46 | let opt = &*context.options; 47 | let connection_cmd = ClientRequestCMD(&conn_req, &outbound).get_cmd(); 48 | trojan_auth(connection_cmd, &addr, &mut outbound, &opt.password).await?; 49 | 50 | use ConnectionRequest::*; 51 | match conn_req { 52 | TCP(inbound) => { 53 | let conn_id = TCP_CONNECTION_COUNTER.fetch_add(1, Ordering::Relaxed); 54 | relay_tcp(inbound, outbound, context.shutdown, conn_id, &addr).await?; 55 | } 56 | #[cfg(feature = "udp")] 57 | UDP(inbound) => { 58 | let conn_id = UDP_CONNECTION_COUNTER.fetch_add(1, Ordering::Relaxed); 59 | info!("[udp][{}] => {:?}", conn_id, &addr); 60 | relay_udp(inbound, outbound, context.shutdown, conn_id).await?; 61 | info!("[end][udp][{}]", conn_id); 62 | } 63 | 
#[cfg(feature = "quic")] 64 | ECHO(echo_rx) => match outbound { 65 | ClientServerConnection::Quic(outbound) => { 66 | send_echo(outbound, echo_rx).await; 67 | } 68 | _ => unreachable!(), 69 | }, 70 | _PHANTOM(_) => unreachable!(), 71 | #[allow(unreachable_patterns)] 72 | _ => panic!("functionality not included"), 73 | } 74 | 75 | Ok(()) 76 | } 77 | -------------------------------------------------------------------------------- /src/client/outbound/mod.rs: -------------------------------------------------------------------------------- 1 | mod connect; 2 | #[cfg(feature = "quic")] 3 | pub mod quic; 4 | mod request_cmd; 5 | #[cfg(any(feature = "tcp_tls", feature = "lite_tls"))] 6 | pub mod tcp_tls; 7 | pub mod trojan_auth; 8 | 9 | pub(super) use connect::forward; 10 | -------------------------------------------------------------------------------- /src/client/outbound/quic.rs: -------------------------------------------------------------------------------- 1 | use super::forward; 2 | use crate::{ 3 | args::{Opt, TrojanContext}, 4 | client::utils::{get_rustls_config, ClientConnectionRequest, ClientServerConnection}, 5 | protocol::*, 6 | utils::MixAddrType, 7 | }; 8 | use anyhow::{anyhow, Error, Result}; 9 | use quinn::*; 10 | use std::{ 11 | sync::{ 12 | atomic::{AtomicBool, Ordering::SeqCst}, 13 | Arc, 14 | }, 15 | time::Duration, 16 | }; 17 | use tokio::{ 18 | fs, io, select, 19 | sync::{broadcast, mpsc, oneshot}, 20 | time::{sleep, timeout}, 21 | }; 22 | use tokio_rustls::{rustls, rustls::RootCertStore}; 23 | use tracing::*; 24 | 25 | pub static IS_CONNECTION_OPENED: AtomicBool = AtomicBool::new(false); 26 | 27 | #[derive(Default)] 28 | struct QuicConnectionWrapper { 29 | connection: Option, 30 | concurrent_streams_counter: usize, 31 | } 32 | 33 | impl QuicConnectionWrapper { 34 | pub fn refresh(&mut self, conn: Connection) { 35 | self.connection = Some(conn); 36 | self.concurrent_streams_counter = 0; 37 | } 38 | 39 | pub fn open_bi(&mut self) -> Option { 40 | if 
self.has_remaining() { 41 | self.concurrent_streams_counter += 1; 42 | Some(self.connection.as_ref().unwrap().open_bi()) 43 | } else { 44 | None 45 | } 46 | } 47 | 48 | pub fn has_remaining(&self) -> bool { 49 | self.concurrent_streams_counter < MAX_CONCURRENT_BIDI_STREAMS 50 | } 51 | } 52 | 53 | pub async fn send_echo(echo_stream: (SendStream, RecvStream), mut echo_rx: mpsc::Receiver<()>) { 54 | let (mut write, mut read) = echo_stream; 55 | 56 | let mut buf = [0u8; ECHO_PHRASE.len()]; 57 | loop { 58 | let _ = echo_rx.recv().await; 59 | match timeout( 60 | Duration::from_secs(2), 61 | write.write_all(ECHO_PHRASE.as_bytes()), 62 | ) 63 | .await 64 | { 65 | Ok(Ok(_)) => { 66 | debug!("echo written"); 67 | } 68 | other => { 69 | info!( 70 | "[echo][send] connection reset detected: {:?}, buf {:?}", 71 | other, buf 72 | ); 73 | IS_CONNECTION_OPENED.store(false, SeqCst); 74 | echo_rx.close(); 75 | return; 76 | } 77 | } 78 | 79 | match timeout(Duration::from_secs(2), read.read_exact(&mut buf)).await { 80 | Ok(Ok(_)) => { 81 | debug!("echo received"); 82 | } 83 | other => { 84 | info!( 85 | "[echo][recv] connection reset detected: {:?}, buf {:?}", 86 | other, buf 87 | ); 88 | IS_CONNECTION_OPENED.store(false, SeqCst); 89 | echo_rx.close(); 90 | return; 91 | } 92 | } 93 | 94 | sleep(Duration::from_secs(5)).await; 95 | } 96 | } 97 | 98 | pub struct EndpointManager { 99 | outbound: Endpoint, 100 | connection: QuicConnectionWrapper, 101 | options: Arc, 102 | echo_tx: Option>, 103 | shudown_tx: broadcast::Sender<()>, 104 | } 105 | 106 | impl EndpointManager { 107 | pub async fn new(options: Arc) -> Result { 108 | let mut outbound = quinn::Endpoint::client("[::]:0".parse().unwrap())?; 109 | outbound.set_default_client_config(new_builder(&options).await?); 110 | 111 | let (shudown_tx, _) = broadcast::channel(1); 112 | let mut _self = Self { 113 | outbound, 114 | connection: QuicConnectionWrapper::default(), 115 | options, 116 | echo_tx: None, 117 | shudown_tx, 118 | }; 119 | 120 
| _self.new_connection().await?; 121 | 122 | Ok(_self) 123 | } 124 | 125 | fn echo_task_status(&self) -> bool { 126 | match self.echo_tx { 127 | Some(ref tx) => !tx.is_closed(), 128 | None => false, 129 | } 130 | } 131 | 132 | async fn echo(&mut self) -> Result { 133 | if !self.echo_task_status() { 134 | let open_bi = match self.connection.open_bi() { 135 | None => return Err(anyhow!("failed to open bi conn")), 136 | Some(open_bi) => open_bi, 137 | }; 138 | let connecting: _ = 139 | async move { Ok::<_, Error>(ClientServerConnection::Quic(open_bi.await?)) }; 140 | 141 | let (echo_tx, echo_rx) = mpsc::channel::<()>(1); 142 | let incomming: _ = async { 143 | Ok(( 144 | ClientConnectionRequest::ECHO(echo_rx), 145 | MixAddrType::new_null(), 146 | )) 147 | }; 148 | 149 | self.echo_tx = Some(echo_tx); 150 | let context = TrojanContext { 151 | options: self.options.clone(), 152 | shutdown: self.shudown_tx.subscribe(), 153 | }; 154 | 155 | tokio::spawn(forward(context, incomming, connecting)); 156 | } 157 | 158 | let _ = self.echo_tx.as_mut().unwrap().try_send(()); 159 | 160 | Ok(IS_CONNECTION_OPENED.load(SeqCst)) 161 | } 162 | 163 | pub async fn connect(&mut self) -> Result<(SendStream, RecvStream)> { 164 | if !self.connection.has_remaining() || !self.echo().await? { 165 | debug!("[connect] re-connecting"); 166 | self.new_connection().await?; 167 | } 168 | debug!("[connect] connection request"); 169 | // echo() opens a stream when it (re)starts the echo task, so it may have used this connection's last stream slot: re-connect in that case instead of panicking. 170 | let open_bi = match self.connection.open_bi() { Some(open_bi) => open_bi, None => { debug!("[connect] re-connecting"); self.new_connection().await?; self.connection.open_bi().ok_or_else(|| anyhow!("failed to open bi conn"))? } }; 171 | Ok(open_bi.await?) 172 | } 173 | 174 | async fn new_connection(&mut self) -> Result<()> { 175 | let new_conn = timeout( 176 | Duration::from_secs(2), 177 | self.outbound.connect( 178 | self.options.remote_socket_addr.unwrap(), 179 | &self.options.server_hostname, 180 | )?, 181 | ) 182 | .await 183 | .map_err(|e| Error::new(e))? 184 | .map_err(|e| Error::new(e))?; 185 | 186 | let quinn::NewConnection { 187 | connection: conn, .. 188 |
188 | } = new_conn; 189 | self.connection.refresh(conn); 190 | IS_CONNECTION_OPENED.store(true, SeqCst); 191 | Ok(()) 192 | } 193 | } 194 | 195 | async fn new_builder(options: &Opt) -> Result { 196 | let mut crypto_config = get_rustls_config(load_cert(options, RootCertStore::empty()).await?); 197 | crypto_config.alpn_protocols = ALPN_QUIC_HTTP.iter().map(|&x| x.into()).collect(); 198 | let mut config = quinn::ClientConfig::new(Arc::new(crypto_config)); 199 | 200 | let transport_cfg = Arc::get_mut(&mut config.transport).unwrap(); 201 | transport_cfg.max_idle_timeout(Some(QUIC_MAX_IDLE_TIMEOUT.try_into()?)); 202 | transport_cfg.persistent_congestion_threshold(6); 203 | transport_cfg.max_concurrent_bidi_streams(MAX_CONCURRENT_BIDI_STREAMS.try_into()?); 204 | transport_cfg.packet_threshold(4); 205 | Ok(config) 206 | } 207 | 208 | async fn load_cert(options: &Opt, mut roots: RootCertStore) -> Result { 209 | if let Some(ca_path) = &options.ca { 210 | roots.add(&rustls::Certificate(fs::read(&ca_path).await?))?; 211 | } else { 212 | let dirs = directories::ProjectDirs::from("org", "quinn", "quinn-examples").unwrap(); 213 | match fs::read(dirs.data_local_dir().join("cert.der")).await { 214 | Ok(cert) => { 215 | roots.add(&rustls::Certificate(cert))?; 216 | } 217 | Err(e) => { 218 | if e.kind() == io::ErrorKind::NotFound { 219 | info!("local server certificate not found"); 220 | } else { 221 | error!("failed to open local server certificate: {:#}", e); 222 | } 223 | return Err(anyhow::Error::new(e)); 224 | } 225 | } 226 | } 227 | Ok(roots) 228 | } 229 | 230 | pub async fn quic_connection_daemon( 231 | context: TrojanContext, 232 | mut task_rx: mpsc::Receiver>>, 233 | ) -> Result<()> { 234 | debug!("quic_connection_daemon enter"); 235 | let TrojanContext { 236 | mut shutdown, 237 | options, 238 | } = context; 239 | let mut endpoint = EndpointManager::new(options) 240 | .await 241 | .expect("EndpointManager::new"); 242 | 243 | loop { 244 | select! 
{ 245 | maybe_ret_tx = task_rx.recv() => { 246 | match maybe_ret_tx { 247 | None => break, 248 | Some(ret_tx) => { 249 | let _ =ret_tx.send(endpoint.connect().await); 250 | }, 251 | } 252 | } 253 | _ = shutdown.recv() => { 254 | break; 255 | } 256 | } 257 | } 258 | debug!("quic_connection_daemon leave"); 259 | Ok(()) 260 | } 261 | -------------------------------------------------------------------------------- /src/client/outbound/request_cmd.rs: -------------------------------------------------------------------------------- 1 | use crate::{ 2 | client::utils::{ClientConnectionRequest, ClientServerConnection}, 3 | protocol::{ECHO_REQUEST_CMD, LITE_TLS_REQUEST_CMD, TCP_REQUEST_CMD, UDP_REQUEST_CMD}, 4 | utils::ConnectionRequest, 5 | }; 6 | pub struct ClientRequestCMD<'a>( 7 | pub &'a ClientConnectionRequest, 8 | pub &'a ClientServerConnection, 9 | ); 10 | 11 | impl<'a> ClientRequestCMD<'a> { 12 | /// Returns the trojan request command byte for this (request, server connection) pair. 13 | pub fn get_cmd(&self) -> u8 { 14 | use ClientServerConnection::*; 15 | use ConnectionRequest::*; 16 | match (self.0, self.1) { 17 | #[cfg(feature = "udp")] 18 | (UDP(_), _) => UDP_REQUEST_CMD, 19 | #[cfg(feature = "quic")] 20 | (ECHO(_), _) => ECHO_REQUEST_CMD, 21 | // `ClientServerConnection::LiteTLS` only exists with the `lite_tls` feature, so this 22 | // arm must carry the same gate or builds without that feature do not compile. 23 | #[cfg(feature = "lite_tls")] 24 | (TCP(_), LiteTLS(_)) => LITE_TLS_REQUEST_CMD, 25 | (TCP(_), _) => TCP_REQUEST_CMD, 26 | _ => unreachable!(), 27 | } 28 | } 29 | } 30 | -------------------------------------------------------------------------------- /src/client/outbound/tcp_tls.rs: -------------------------------------------------------------------------------- 1 | use crate::{ 2 | args::Opt, 3 | client::utils::{get_rustls_config, ClientServerConnection}, 4 | }; 5 | use anyhow::Result; 6 | use std::sync::Arc; 7 | use tokio::net::TcpStream; 8 | use tokio_rustls::{ 9 | rustls::{ClientConfig, RootCertStore, ServerName}, 10 | TlsConnector, 11 | }; 12 | 13 | pub async fn tls_client_config() -> ClientConfig { 14 | get_rustls_config(RootCertStore::empty()) 15 | } 16 | 17 | pub struct TrojanTcpTlsConnector { 18 | tls_config: Arc, 19 | is_lite: bool, 20
| } 21 | 22 | impl TrojanTcpTlsConnector { 23 | pub fn new(tls_config: Arc, is_lite: bool) -> Self { 24 | Self { 25 | tls_config, 26 | is_lite, 27 | } 28 | } 29 | 30 | /// Connects to the remote server over TCP and performs the TLS handshake; the stream is wrapped as `LiteTLS` or `TcpTLS` according to `is_lite`. 31 | pub async fn connect(self, opt: Arc) -> Result { 32 | let Self { 33 | tls_config, 34 | is_lite, 35 | } = self; 36 | let opt = &*opt; 37 | // A malformed hostname is a configuration error, not a bug: report it to the caller 38 | // instead of panicking, and do so before any socket is opened. 39 | let server_name = ServerName::try_from(opt.server_hostname.as_str()) 40 | .map_err(|_| anyhow::anyhow!("invalid DNS name: {}", opt.server_hostname))?; 41 | let connector = TlsConnector::from(tls_config); 42 | let stream = TcpStream::connect(&opt.remote_socket_addr.unwrap()).await?; 43 | stream.set_nodelay(true)?; 44 | let stream = connector.connect(server_name, stream).await?; 45 | use ClientServerConnection::*; 46 | Ok(match is_lite { 47 | #[cfg(feature = "lite_tls")] 48 | true => LiteTLS(stream), 49 | #[cfg(feature = "tcp_tls")] 50 | false => TcpTLS(stream), 51 | #[allow(unreachable_patterns)] 52 | _ => unreachable!(), 53 | }) 54 | } 55 | } 56 | -------------------------------------------------------------------------------- /src/client/outbound/trojan_auth.rs: -------------------------------------------------------------------------------- 1 | use crate::{client::utils::ClientServerConnection, protocol::HASH_LEN, utils::MixAddrType}; 2 | use anyhow::Result; 3 | use tokio::io::{AsyncWrite, AsyncWriteExt}; 4 | use tracing::{debug, trace}; 5 | 6 | pub async fn trojan_auth( 7 | mode: u8, 8 | addr: &MixAddrType, 9 | outbound: &mut ClientServerConnection, 10 | password: &String, 11 | ) -> Result<()> { 12 | match outbound { 13 | #[cfg(feature = "quic")] 14 | ClientServerConnection::Quic((out_write, _)) => { 15 | send_trojan_auth(mode, addr, out_write, password).await 16 | } 17 | #[cfg(feature = "tcp_tls")] 18 | ClientServerConnection::TcpTLS(out_write) => { 19 | send_trojan_auth(mode, addr, out_write, password).await 20 | } 21 | #[cfg(feature = "lite_tls")] 22 | ClientServerConnection::LiteTLS(out_write) => { 23 | send_trojan_auth(mode, addr, out_write, password).await 24 | } 25 | } 26 | } 27 | 28 | async fn
send_trojan_auth( 29 | mode: u8, 30 | addr: &MixAddrType, 31 | outbound: &mut A, 32 | password: &String, 33 | ) -> Result<()> 34 | where 35 | A: AsyncWrite + Unpin + ?Sized, 36 | { 37 | let mut buf = Vec::with_capacity(HASH_LEN + 2 + 1 + addr.encoded_len() + 2); 38 | buf.extend_from_slice(password.as_bytes()); 39 | buf.extend_from_slice(&[b'\r', b'\n', mode]); 40 | addr.write_buf(&mut buf); 41 | buf.extend_from_slice(&[b'\r', b'\n']); 42 | trace!("trojan_connect: writing {:?}", buf); 43 | outbound.write_all(&buf).await?; 44 | // not using the following code because of quinn's bug. 45 | // let packet0 = [ 46 | // IoSlice::new(password_hash.as_bytes()), 47 | // IoSlice::new(&command0[..command0_len]), 48 | // IoSlice::new(self.host.as_bytes()), 49 | // IoSlice::new(&port_arr), 50 | // IoSlice::new(&[b'\r', b'\n']), 51 | // ]; 52 | // let mut writer = Pin::new(outbound); 53 | // future::poll_fn(|cx| writer.as_mut().poll_write_vectored(cx, &packet0[..])) 54 | // .await 55 | // .map_err(|e| Box::new(e))?; 56 | 57 | // writer.flush().await.map_err(|e| Box::new(e))?; 58 | // outbound.flush().await?; 59 | debug!("trojan packet 0 sent"); 60 | 61 | Ok(()) 62 | } 63 | -------------------------------------------------------------------------------- /src/client/run.rs: -------------------------------------------------------------------------------- 1 | use super::inbound::{user_endpoint_listener, HttpRequest, Socks5Request}; 2 | use crate::args::TrojanContext; 3 | use anyhow::Result; 4 | use tokio::sync::broadcast; 5 | 6 | pub async fn run_client(mut context: TrojanContext) -> Result<()> { 7 | // blocking by async in traits 8 | // trait ClientReqAcc { async fn accept(stream: TcpStream) -> ... 
} 9 | // impl ClientReqAcc for HttpAcc; 10 | // impl ClientReqAcc for Socks5Acc; 11 | // struct Listener { fn new(); async fn listen(); }; 12 | // listen(): user_endpoint_listener() 13 | // let http_listener = Listener::new::::(); 14 | // let socks5_listener = Listener::new::::(); 15 | // http_listener.listen().await; socks5_listener.listen().await; 16 | 17 | // ClientServerConnection would be elimiated 18 | let http_addr = context.options.local_http_addr; 19 | let socks5_addr = context.options.local_socks5_addr; 20 | let (shutdown_tx, shutdown) = broadcast::channel(1); 21 | tokio::spawn(user_endpoint_listener::( 22 | http_addr, 23 | context.clone_with_signal(shutdown), 24 | )); 25 | 26 | tokio::spawn(user_endpoint_listener::( 27 | socks5_addr, 28 | context.clone_with_signal(shutdown_tx.subscribe()), 29 | )); 30 | let _ = context.shutdown.recv().await; 31 | Ok(()) 32 | } 33 | -------------------------------------------------------------------------------- /src/client/utils/client_server_connection.rs: -------------------------------------------------------------------------------- 1 | use tokio::net::TcpStream; 2 | 3 | #[cfg(feature = "quic")] 4 | use quinn::*; 5 | #[cfg(any(feature = "tcp_tls", feature = "lite_tls"))] 6 | use tokio_rustls::client::TlsStream; 7 | 8 | pub enum ClientServerConnection { 9 | #[cfg(feature = "quic")] 10 | Quic((SendStream, RecvStream)), 11 | #[cfg(feature = "tcp_tls")] 12 | TcpTLS(TlsStream), 13 | #[cfg(feature = "lite_tls")] 14 | LiteTLS(TlsStream), 15 | } 16 | -------------------------------------------------------------------------------- /src/client/utils/client_tcp_stream.rs: -------------------------------------------------------------------------------- 1 | use crate::utils::BufferedRecv; 2 | use tokio::net::TcpStream; 3 | 4 | pub fn new_client_tcp_stream( 5 | inner: TcpStream, 6 | http_request_extension: Option>, 7 | ) -> BufferedRecv { 8 | BufferedRecv::new(inner, http_request_extension.map(|v| (0, v))) 9 | } 10 | 
-------------------------------------------------------------------------------- /src/client/utils/client_udp_stream.rs: -------------------------------------------------------------------------------- 1 | use crate::{ 2 | protocol::UDP_BUFFER_SIZE, 3 | utils::{CursoredBuffer, ExtendableFromSlice, MixAddrType, UdpRead, UdpRelayBuffer, UdpWrite}, 4 | }; 5 | use futures::ready; 6 | use std::pin::Pin; 7 | use std::task::{Context, Poll}; 8 | use std::{ 9 | net::SocketAddr, 10 | ops::{Deref, DerefMut}, 11 | }; 12 | use tokio::net::UdpSocket; 13 | use tokio::sync::oneshot; 14 | use tracing::{debug, warn}; 15 | 16 | #[cfg_attr(feature = "debug_info", derive(Debug))] 17 | struct Socks5UdpSpecifiedBuffer { 18 | inner: Vec, 19 | } 20 | 21 | impl Socks5UdpSpecifiedBuffer { 22 | fn new(capacity: usize) -> Self { 23 | let mut inner = Vec::with_capacity(capacity); 24 | // The fields in the UDP request header are: 25 | // o RSV Reserved X'0000' 26 | // o FRAG Current fragment number 27 | inner.extend_from_slice(&[0, 0, 0]); 28 | Self { inner } 29 | } 30 | 31 | fn reset(&mut self) { 32 | unsafe { 33 | self.inner.set_len(3); 34 | } 35 | } 36 | 37 | fn is_empty(&self) -> bool { 38 | assert!( 39 | self.inner.len() >= 3, 40 | "Socks5UdpSpecifiedBuffer unexpected len: {}", 41 | self.inner.len() 42 | ); 43 | self.inner.len() == 3 44 | } 45 | } 46 | 47 | impl ExtendableFromSlice for Socks5UdpSpecifiedBuffer { 48 | fn extend_from_slice(&mut self, src: &[u8]) { 49 | self.inner.extend_from_slice(src); 50 | } 51 | } 52 | 53 | impl Deref for Socks5UdpSpecifiedBuffer { 54 | type Target = Vec; 55 | fn deref(&self) -> &Self::Target { 56 | &self.inner 57 | } 58 | } 59 | 60 | impl DerefMut for Socks5UdpSpecifiedBuffer { 61 | fn deref_mut(&mut self) -> &mut Self::Target { 62 | &mut self.inner 63 | } 64 | } 65 | 66 | pub struct Socks5UdpStream { 67 | server_udp_socket: UdpSocket, 68 | client_udp_addr: Option, 69 | signal_reset: oneshot::Receiver<()>, 70 | buffer: Socks5UdpSpecifiedBuffer, 71 | } 72 
| 73 | impl Socks5UdpStream { 74 | pub fn new( 75 | server_udp_socket: UdpSocket, 76 | stream_reset_signal_rx: oneshot::Receiver<()>, 77 | ) -> Self { 78 | Self { 79 | server_udp_socket, 80 | client_udp_addr: None, 81 | signal_reset: stream_reset_signal_rx, 82 | buffer: Socks5UdpSpecifiedBuffer::new(UDP_BUFFER_SIZE), 83 | } 84 | } 85 | } 86 | 87 | impl Socks5UdpStream { 88 | fn poll_read( 89 | mut self: std::pin::Pin<&mut Self>, 90 | cx: &mut std::task::Context<'_>, 91 | buf: &mut tokio::io::ReadBuf<'_>, 92 | ) -> std::task::Poll> { 93 | let addr = match ready!(self.server_udp_socket.poll_recv_from(cx, buf)) { 94 | Ok(addr) => addr, 95 | Err(e) => { 96 | return Poll::Ready(Err(e)); 97 | } 98 | }; 99 | 100 | if self.client_udp_addr.is_none() { 101 | self.client_udp_addr = Some(addr.clone()); 102 | } else { 103 | if self.client_udp_addr.unwrap() != addr { 104 | return Poll::Ready(Err(std::io::ErrorKind::Interrupted.into())); 105 | } 106 | } 107 | Poll::Ready(Ok(())) 108 | } 109 | 110 | fn poll_write( 111 | self: &mut std::pin::Pin<&mut Self>, 112 | cx: &mut std::task::Context<'_>, 113 | ) -> Poll> { 114 | if self.client_udp_addr.is_none() { 115 | return Poll::Ready(Err(std::io::ErrorKind::Other.into())); 116 | } 117 | 118 | self.server_udp_socket 119 | .poll_send_to(cx, &self.buffer, self.client_udp_addr.unwrap()) 120 | } 121 | } 122 | 123 | impl UdpRead for Socks5UdpStream { 124 | /// ```not_rust 125 | /// +----+------+------+----------+----------+----------+ 126 | /// |RSV | FRAG | ATYP | DST.ADDR | DST.PORT | DATA | 127 | /// +----+------+------+----------+----------+----------+ 128 | /// | 2 | 1 | 1 | Variable | 2 | Variable | 129 | /// +----+------+------+----------+----------+----------+ 130 | /// The fields in the UDP request header are: 131 | /// o RSV Reserved X'0000' 132 | /// o FRAG Current fragment number 133 | /// o ATYP address type of following addresses: 134 | /// o IP V4 address: X'01' 135 | /// o DOMAINNAME: X'03' 136 | /// o IP V6 address: X'04' 
137 | /// o DST.ADDR desired destination address 138 | /// o DST.PORT desired destination port 139 | /// o DATA user data 140 | /// ``` 141 | fn poll_proxy_stream_read( 142 | mut self: Pin<&mut Self>, 143 | cx: &mut Context<'_>, 144 | buf: &mut UdpRelayBuffer, 145 | ) -> Poll> { 146 | debug!("Socks5UdpRecvStream::poll_proxy_stream_read()"); 147 | let mut buf_inner = buf.as_read_buf(); 148 | let ptr = buf_inner.filled().as_ptr(); 149 | 150 | crate::try_recv!( 151 | oneshot, 152 | self.signal_reset, 153 | return Poll::Ready(Ok(MixAddrType::None)) 154 | ); 155 | 156 | match ready!(self.poll_read(cx, &mut buf_inner)) { 157 | Ok(_) => { 158 | // Ensure the pointer does not change from under us 159 | assert_eq!(ptr, buf_inner.filled().as_ptr()); 160 | let n = buf_inner.filled().len(); 161 | 162 | if n < 3 { 163 | return Poll::Ready(Ok(MixAddrType::None)); 164 | } 165 | 166 | // Safety: This is guaranteed to be the number of initialized (and read) 167 | // bytes due to the invariants provided by `ReadBuf::filled`. 168 | unsafe { 169 | buf.advance_mut(n); 170 | } 171 | buf.advance(3); 172 | #[cfg(feature = "debug_info")] 173 | debug!( 174 | "Socks5UdpRecvStream::poll_proxy_stream_read() buf {:?}", 175 | buf 176 | ); 177 | Poll::Ready( 178 | MixAddrType::from_encoded(buf).map_err(|_| std::io::ErrorKind::Other.into()), 179 | ) 180 | } 181 | Err(e) => Poll::Ready(Err(e)), 182 | } 183 | } 184 | } 185 | 186 | impl UdpWrite for Socks5UdpStream { 187 | fn poll_proxy_stream_write( 188 | mut self: Pin<&mut Self>, 189 | cx: &mut Context<'_>, 190 | buf: &[u8], 191 | addr: &MixAddrType, 192 | ) -> Poll> { 193 | let just_filled_buf = if self.buffer.is_empty() { 194 | addr.write_buf(&mut self.buffer); 195 | self.buffer.extend_from_slice(buf); 196 | true 197 | } else { 198 | false 199 | }; 200 | 201 | // only if we write the whole buf in one write we reset the buffer 202 | // to accept new data. 203 | match self.poll_write(cx)? 
{ 204 | Poll::Ready(real_written_amt) => { 205 | if real_written_amt == self.buffer.len() { 206 | self.buffer.reset(); 207 | } else { 208 | warn!("Socks5UdpSendStream didn't send the entire buffer"); 209 | } 210 | } 211 | _ => (), 212 | } 213 | 214 | if just_filled_buf { 215 | Poll::Ready(Ok(buf.len())) 216 | } else { 217 | Poll::Pending 218 | } 219 | } 220 | 221 | fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { 222 | Poll::Ready(Ok(())) 223 | } 224 | 225 | fn poll_shutdown(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { 226 | Poll::Ready(Ok(())) 227 | } 228 | } 229 | -------------------------------------------------------------------------------- /src/client/utils/connection_mode.rs: -------------------------------------------------------------------------------- 1 | #[derive(Clone)] 2 | #[cfg_attr(feature = "debug_info", derive(Debug))] 3 | pub enum ConnectionMode { 4 | #[cfg(feature = "tcp_tls")] 5 | TcpTLS, 6 | #[cfg(feature = "quic")] 7 | Quic, 8 | #[cfg(feature = "lite_tls")] 9 | LiteTLS, 10 | } 11 | -------------------------------------------------------------------------------- /src/client/utils/data_transfer.rs: -------------------------------------------------------------------------------- 1 | use super::ClientServerConnection; 2 | #[cfg(feature = "udp")] 3 | use {super::Socks5UdpStream, crate::utils::TrojanUdpStream}; 4 | 5 | use crate::{ 6 | adapt, 7 | utils::{Adapter, BufferedRecv, MixAddrType, ParserError, WRTuple}, 8 | }; 9 | use anyhow::{anyhow, Context, Result}; 10 | use tokio::{net::TcpStream, sync::broadcast}; 11 | use tracing::info; 12 | 13 | #[cfg(feature = "lite_tls")] 14 | use crate::utils::lite_tls::LiteTlsStream; 15 | 16 | pub async fn relay_tcp( 17 | mut inbound: BufferedRecv, 18 | outbound: ClientServerConnection, 19 | shutdown: broadcast::Receiver<()>, 20 | conn_id: usize, 21 | target_host: &MixAddrType, 22 | ) -> Result<()> { 23 | match outbound { 24 | #[cfg(feature = "quic")] 25 | 
ClientServerConnection::Quic(outbound) => { 26 | let outbound = WRTuple::from_wr_tuple(outbound); 27 | adapt!([tcp][conn_id] 28 | inbound <=> outbound <=> target_host 29 | Until shutdown 30 | ); 31 | } 32 | #[cfg(feature = "tcp_tls")] 33 | ClientServerConnection::TcpTLS(outbound) => { 34 | adapt!([tcp][conn_id] 35 | inbound <=> outbound <=> target_host 36 | Until shutdown 37 | ); 38 | } 39 | #[cfg(feature = "lite_tls")] 40 | ClientServerConnection::LiteTLS(mut outbound) => { 41 | let mut lite_tls_endpoint = LiteTlsStream::new_client_endpoint(); 42 | 43 | // there is a potential bug here, if timeout is too short for a 44 | // valid handshake, it closes unexpectedly and immediately try for 45 | // another time. However for the second time, it is not recognised 46 | // as a tls stream and therefore fails again. 47 | // I set a reasonably large timeout here to avoid such problem, 48 | // but the reason for the failed second round is currently unknown. 49 | match lite_tls_endpoint 50 | .handshake_timeout(&mut outbound, &mut inbound) 51 | .await 52 | { 53 | Ok(_) => { 54 | let ver = lite_tls_endpoint.version; 55 | if ver.is_none() { 56 | return Ok(()); 57 | } 58 | info!("[{}]lite tls handshake succeed", ver.unwrap()); 59 | let (mut outbound, _) = outbound.into_inner(); 60 | let (mut inbound, _) = inbound.into_inner(); 61 | 62 | lite_tls_endpoint 63 | .flush_tls(&mut inbound, &mut outbound) 64 | .await?; 65 | 66 | adapt!([lite][conn_id] 67 | inbound <=> outbound <=> target_host 68 | Until shutdown 69 | ); 70 | } 71 | Err(e) => { 72 | if let Some(e @ ParserError::Invalid(_)) = e.downcast_ref::() { 73 | info!("not tls stream: {:#}", e); 74 | lite_tls_endpoint 75 | .flush_non_tls(&mut outbound, &mut inbound) 76 | .await?; 77 | adapt!([tcp][conn_id] 78 | inbound <=> outbound <=> target_host 79 | Until shutdown 80 | ); 81 | } else { 82 | return Err(e); 83 | } 84 | } 85 | } 86 | } 87 | } 88 | Ok(()) 89 | } 90 | 91 | #[cfg(feature = "udp")] 92 | pub async fn relay_udp( 93 | 
inbound: Socks5UdpStream, 94 | outbound: ClientServerConnection, 95 | upper_shutdown: broadcast::Receiver<()>, 96 | conn_id: usize, 97 | ) -> Result<()> { 98 | match outbound { 99 | #[cfg(feature = "quic")] 100 | ClientServerConnection::Quic(out_quic) => { 101 | let outbound: _ = TrojanUdpStream::new(WRTuple::from_wr_tuple(out_quic), None); 102 | adapt!([udp][conn_id]inbound <=> outbound 103 | Until upper_shutdown 104 | ); 105 | } 106 | #[cfg(feature = "tcp_tls")] 107 | ClientServerConnection::TcpTLS(out_tls) => { 108 | let outbound = TrojanUdpStream::new(out_tls, None); 109 | adapt!([udp][conn_id]inbound <=> outbound 110 | Until upper_shutdown 111 | ); 112 | } 113 | #[cfg(feature = "lite_tls")] 114 | ClientServerConnection::LiteTLS(out_tls) => { 115 | let outbound = TrojanUdpStream::new(out_tls, None); 116 | adapt!([udp][conn_id]inbound <=> outbound 117 | Until upper_shutdown 118 | ); 119 | } 120 | } 121 | Ok(()) 122 | } 123 | -------------------------------------------------------------------------------- /src/client/utils/mod.rs: -------------------------------------------------------------------------------- 1 | mod client_tcp_stream; 2 | pub use client_tcp_stream::*; 3 | #[cfg(feature = "udp")] 4 | mod client_udp_stream; 5 | #[cfg(feature = "udp")] 6 | pub use client_udp_stream::*; 7 | mod data_transfer; 8 | pub use data_transfer::*; 9 | mod client_server_connection; 10 | pub use client_server_connection::*; 11 | 12 | mod connection_mode; 13 | pub use connection_mode::ConnectionMode; 14 | 15 | mod rustls_utils; 16 | pub use rustls_utils::get_rustls_config; 17 | 18 | use tokio::net::TcpStream; 19 | #[cfg(feature = "udp")] 20 | use tokio::sync::mpsc; 21 | 22 | #[cfg(not(feature = "udp"))] 23 | use crate::utils::DummyRequest; 24 | 25 | use crate::utils::{BufferedRecv, ConnectionRequest}; 26 | 27 | #[cfg(feature = "udp")] 28 | pub type ClientConnectionRequest = 29 | ConnectionRequest, Socks5UdpStream, mpsc::Receiver<()>>; 30 | 31 | #[cfg(not(feature = "udp"))] 32 
| pub type ClientConnectionRequest = ConnectionRequest; 33 | -------------------------------------------------------------------------------- /src/client/utils/rustls_utils.rs: -------------------------------------------------------------------------------- 1 | use tokio_rustls::rustls::{Certificate, ClientConfig, RootCertStore}; 2 | 3 | pub fn get_rustls_config(mut roots: RootCertStore) -> ClientConfig { 4 | for cert in rustls_native_certs::load_native_certs().expect("could not load platform certs") { 5 | roots.add(&Certificate(cert.0)).unwrap(); 6 | } 7 | 8 | ClientConfig::builder() 9 | .with_safe_defaults() 10 | .with_root_certificates(roots) 11 | .with_no_client_auth() 12 | } 13 | -------------------------------------------------------------------------------- /src/main.rs: -------------------------------------------------------------------------------- 1 | // #![feature(aarch64_target_feature)] 2 | // #![feature(stdsimd)] 3 | #![feature(generic_associated_types)] 4 | #![feature(type_alias_impl_trait)] 5 | #![feature(associated_type_defaults)] 6 | 7 | use mimalloc::MiMalloc; 8 | #[global_allocator] 9 | static GLOBAL: MiMalloc = MiMalloc; 10 | 11 | #[cfg(feature = "client")] 12 | mod client; 13 | mod proxy; 14 | 15 | mod protocol; 16 | #[cfg(feature = "server")] 17 | mod server; 18 | // pub mod simd; 19 | 20 | #[cfg(not(any(feature = "client", feature = "server")))] 21 | mod must_choose_between_client_and_server; 22 | #[cfg(not(any(feature = "quic", feature = "tcp_tls", feature = "lite_tls")))] 23 | mod must_choose_between_quic_and_tcp_tls; 24 | 25 | mod args; 26 | use args::Opt; 27 | mod utils; 28 | 29 | use anyhow::{anyhow, Result}; 30 | use std::{net::ToSocketAddrs, sync::Arc}; 31 | use structopt::StructOpt; 32 | 33 | #[cfg(all(target_os = "linux", feature = "zio"))] 34 | use { 35 | tokio::sync::OnceCell, 36 | utils::{start_tcp_relay_threads, TcpTx}, 37 | }; 38 | 39 | #[cfg(all(target_os = "linux", feature = "zio"))] 40 | pub static VEC_TCP_TX: OnceCell> = 
OnceCell::const_new(); 41 | 42 | #[tokio::main] 43 | async fn main() -> Result<()> { 44 | let mut options = Opt::from_args(); 45 | let collector = tracing_subscriber::fmt() 46 | .with_max_level(options.log_level) 47 | .with_target(if cfg!(feature = "debug_info") { 48 | true 49 | } else { 50 | false 51 | }) 52 | .finish(); 53 | let _ = tracing::subscriber::set_global_default(collector); 54 | 55 | #[cfg(all(target_os = "linux", feature = "zio"))] 56 | { 57 | use tracing::info; 58 | let tcp_submit = start_tcp_relay_threads(); 59 | let _ = VEC_TCP_TX.set(tcp_submit); 60 | info!("glommio runtime started"); 61 | } 62 | 63 | options.remote_socket_addr = Some( 64 | ( 65 | if options.server_ip.len() > 0 { 66 | options.server_ip.to_owned() 67 | } else { 68 | options.server_hostname.to_owned() 69 | }, 70 | options.server_port, 71 | ) 72 | .to_socket_addrs()? 73 | .next() 74 | .ok_or(anyhow!("invalid remote address"))?, 75 | ); 76 | 77 | utils::start_dns_resolver_thread(); 78 | 79 | let _ = proxy::build_tunnel(tokio::signal::ctrl_c(), Arc::new(options)).await; 80 | Ok(()) 81 | } 82 | -------------------------------------------------------------------------------- /src/protocol.rs: -------------------------------------------------------------------------------- 1 | #![allow(dead_code)] 2 | use std::time::Duration; 3 | 4 | pub const HASH_LEN: usize = 56; 5 | pub const ALPN_QUIC_HTTP: &[&[u8]] = &[b"hq-29"]; 6 | pub const ECHO_PHRASE: &str = "echo"; 7 | pub const QUIC_MAX_IDLE_TIMEOUT: Duration = Duration::from_secs(600); 8 | pub const MAX_CONCURRENT_BIDI_STREAMS: usize = 30; 9 | 10 | pub const TCP_REQUEST_CMD: u8 = 0x01; 11 | pub const UDP_REQUEST_CMD: u8 = 0x03; 12 | pub const ECHO_REQUEST_CMD: u8 = 0xff; 13 | pub const LITE_TLS_REQUEST_CMD: u8 = 0x11; 14 | 15 | pub const TCP_MAX_IDLE_TIMEOUT: u16 = 10 * 60; 16 | pub const SERVER_OUTBOUND_CONNECT_TIMEOUT: u64 = 10; 17 | pub const LITE_TLS_HANDSHAKE_TIMEOUT: u64 = 60; 18 | 19 | pub const LEAVE_TLS_COMMAND: [u8; 6] = [0xff, 0x03, 
0x03, 0, 0x01, 0x01]; 20 | 21 | pub const RELAY_BUFFER_SIZE: usize = 8192 * 2; 22 | pub const UDP_BUFFER_SIZE: usize = 8192; 23 | 24 | pub const DNS_UPDATE_PERIOD_SEC: u64 = 60 * 4; 25 | 26 | pub const BLACK_HOLE_LOCAL_ADDR: [u8; 4] = [192, 0, 2, 0]; 27 | -------------------------------------------------------------------------------- /src/proxy.rs: -------------------------------------------------------------------------------- 1 | use std::sync::Arc; 2 | 3 | use crate::args::{Opt, TrojanContext}; 4 | use anyhow::Result; 5 | use tokio::{select, sync::broadcast}; 6 | use tracing::*; 7 | 8 | pub async fn build_tunnel(ctrl_c: impl std::future::Future, options: Arc) -> Result<()> { 9 | let (shutdown_tx, shutdown) = broadcast::channel(1); 10 | 11 | let context = TrojanContext { options, shutdown }; 12 | 13 | match context.options.server { 14 | #[cfg(feature = "server")] 15 | true => { 16 | use crate::server::run_server; 17 | info!("server-start"); 18 | select! { 19 | _ = ctrl_c => { 20 | info!("ctrl-c"); 21 | } 22 | res = run_server(context) => { 23 | match res { 24 | Err(err) => { 25 | error!("server quit due to {:#}", err); 26 | } 27 | ok => { 28 | info!("server end: {:?}", ok); 29 | } 30 | } 31 | } 32 | } 33 | } 34 | #[cfg(feature = "client")] 35 | false => { 36 | use crate::client::run_client; 37 | info!("client-start"); 38 | select! 
{ 39 | _ = ctrl_c => { 40 | info!("ctrl-c"); 41 | } 42 | res = run_client(context) => { 43 | match res { 44 | Err(err) => { 45 | error!("client quit due to {:#}", err); 46 | } 47 | ok => { 48 | info!("client end: {:?}", ok); 49 | 50 | } 51 | } 52 | } 53 | } 54 | } 55 | #[allow(unreachable_patterns)] 56 | _ => panic!("function not complied"), 57 | } 58 | 59 | drop(shutdown_tx); 60 | Ok(()) 61 | } 62 | -------------------------------------------------------------------------------- /src/server/inbound/acceptor.rs: -------------------------------------------------------------------------------- 1 | #[cfg(not(feature = "udp"))] 2 | use crate::utils::DummyRequest; 3 | #[cfg(feature = "udp")] 4 | use crate::utils::TrojanUdpStream; 5 | use crate::{ 6 | expect_buf_len, 7 | protocol::{ 8 | ECHO_REQUEST_CMD, HASH_LEN, LITE_TLS_REQUEST_CMD, TCP_REQUEST_CMD, UDP_REQUEST_CMD, 9 | }, 10 | server::{outbound::fallback, utils::TcpOption}, 11 | utils::{BufferedRecv, ConnectionRequest, MixAddrType, ParserError}, 12 | }; 13 | use anyhow::Result; 14 | use futures::TryFutureExt; 15 | use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite}; 16 | use tracing::*; 17 | #[cfg(feature = "udp")] 18 | type ServerConnectionRequest = 19 | ConnectionRequest>, TrojanUdpStream, BufferedRecv>; 20 | #[cfg(not(feature = "udp"))] 21 | type ServerConnectionRequest = 22 | ConnectionRequest>, DummyRequest, BufferedRecv>; 23 | #[derive(Default)] 24 | #[cfg_attr(feature = "debug_info", derive(Debug))] 25 | pub struct TrojanAcceptor<'a> { 26 | pub host: MixAddrType, 27 | cursor: usize, 28 | password_hash: &'a [u8], 29 | buf: Vec, 30 | cmd_code: u8, 31 | fallback_port: u16, 32 | } 33 | 34 | impl<'a> TrojanAcceptor<'a> { 35 | pub fn new(password_hash: &[u8], fallback_port: u16) -> TrojanAcceptor { 36 | TrojanAcceptor { 37 | password_hash, 38 | fallback_port, 39 | buf: Vec::with_capacity(1024), 40 | ..Default::default() 41 | } 42 | } 43 | 44 | fn verify(&mut self) -> Result<(), ParserError> { 45 | if self.buf.len() 
< HASH_LEN { 46 | return Err(ParserError::Incomplete( 47 | "Target::verify self.buf.len() < HASH_LEN".into(), 48 | )); 49 | } 50 | 51 | if &self.buf[..HASH_LEN] == self.password_hash { 52 | self.cursor = HASH_LEN + 2; 53 | Ok(()) 54 | } else { 55 | Err(ParserError::Invalid("Target::verify hash invalid".into())) 56 | } 57 | } 58 | 59 | fn set_host_and_port(&mut self) -> Result<(), ParserError> { 60 | expect_buf_len!( 61 | self.buf, 62 | HASH_LEN + 5, 63 | "TrojanAcceptor::set_host_and_port cmd" 64 | ); // HASH + \r\n + cmd(2 bytes) + host_len(1 byte, only valid when address is hostname) 65 | 66 | // unsafe: This is so buggy 67 | self.cursor = HASH_LEN + 3; 68 | 69 | self.cmd_code = self.buf[HASH_LEN + 2]; 70 | match self.cmd_code { 71 | TCP_REQUEST_CMD | UDP_REQUEST_CMD | LITE_TLS_REQUEST_CMD => { 72 | self.host = MixAddrType::from_encoded(&mut (&mut self.cursor, &self.buf))?; 73 | } 74 | ECHO_REQUEST_CMD => (), 75 | _ => { 76 | return Err(ParserError::Invalid( 77 | "Target::verify invalid connection type".into(), 78 | )) 79 | } 80 | }; 81 | Ok(()) 82 | } 83 | 84 | /// ```not_rust 85 | /// +-----------------------+---------+----------------+---------+----------+ 86 | /// | hex(SHA224(password)) | CRLF | Trojan Request | CRLF | Payload | 87 | /// +-----------------------+---------+----------------+---------+----------+ 88 | /// | 56 | X'0D0A' | Variable | X'0D0A' | Variable | 89 | /// +-----------------------+---------+----------------+---------+----------+ 90 | /// 91 | /// where Trojan Request is a SOCKS5-like request: 92 | /// 93 | /// +-----+------+----------+----------+ 94 | /// | CMD | ATYP | DST.ADDR | DST.PORT | 95 | /// +-----+------+----------+----------+ 96 | /// | 1 | 1 | Variable | 2 | 97 | /// +-----+------+----------+----------+ 98 | /// 99 | /// where: 100 | /// 101 | /// o CMD 102 | /// o CONNECT X'01' 103 | /// o UDP ASSOCIATE X'03' 104 | /// o PROBING X'FF' 105 | /// o ATYP address type of following address 106 | /// o IP V4 address: X'01' 107 | 
/// o DOMAINNAME: X'03' 108 | /// o IP V6 address: X'04' 109 | /// o DST.ADDR desired destination address 110 | /// o DST.PORT desired destination port in network octet order 111 | /// ``` 112 | pub async fn accept( 113 | &mut self, 114 | mut inbound: I, 115 | ) -> Result, ParserError> 116 | where 117 | I: AsyncRead + AsyncWrite + Unpin + Send + 'static, 118 | { 119 | // let (mut read_half, write_half) = inbound.split(); 120 | loop { 121 | let read = inbound 122 | .read_buf(&mut self.buf) 123 | .await 124 | .map_err(|_| ParserError::Invalid("Target::accept failed to read".into()))?; 125 | if read != 0 { 126 | match self.parse() { 127 | Err(err @ ParserError::Invalid(_)) => { 128 | error!("Target::accept failed: {:#}", err); 129 | let mut buf = Vec::new(); 130 | std::mem::swap(&mut buf, &mut self.buf); 131 | tokio::spawn( 132 | fallback(buf, self.fallback_port, inbound).unwrap_or_else(|e| { 133 | error!("connection to fallback failed {:#}", e) 134 | }), 135 | ); 136 | return Err(err); 137 | } 138 | Err(err @ ParserError::Incomplete(_)) => { 139 | debug!("Target::accept failed: {:?}", err); 140 | continue; 141 | } 142 | Ok(()) => { 143 | break; 144 | } 145 | } 146 | } else { 147 | return Err(ParserError::Incomplete("Target::accept EOF".into())); 148 | } 149 | } 150 | use ConnectionRequest::*; 151 | let buffered_request = if self.buf.len() == self.cursor { 152 | None 153 | } else { 154 | Some((self.cursor, std::mem::take(&mut self.buf))) 155 | }; 156 | 157 | use TcpOption::*; 158 | match self.cmd_code { 159 | #[cfg(feature = "udp")] 160 | UDP_REQUEST_CMD => Ok(UDP(TrojanUdpStream::new(inbound, buffered_request))), 161 | #[cfg(not(feature = "udp"))] 162 | UDP_REQUEST_CMD => Err(ParserError::Invalid( 163 | "udp functionality not included".into(), 164 | )), 165 | TCP_REQUEST_CMD => Ok(TCP(TLS(BufferedRecv::new(inbound, buffered_request)))), 166 | LITE_TLS_REQUEST_CMD => Ok(TCP(LiteTLS(BufferedRecv::new(inbound, buffered_request)))), 167 | #[cfg(feature = "quic")] 168 | 
ECHO_REQUEST_CMD => Ok(ECHO(BufferedRecv::new(inbound, buffered_request))), 169 | _ => unreachable!(), 170 | } 171 | } 172 | 173 | pub fn parse(&mut self) -> Result<(), ParserError> { 174 | #[cfg(feature = "debug_info")] 175 | debug!( 176 | "parse begin, cursor {}, buffer({}): {:?}", 177 | self.cursor, 178 | self.buf.len(), 179 | &self.buf[self.cursor..] 180 | ); 181 | if self.cursor == 0 { 182 | self.verify()?; 183 | #[cfg(feature = "debug_info")] 184 | debug!("verified"); 185 | } 186 | 187 | if self.host.is_none() { 188 | self.set_host_and_port()?; 189 | } 190 | 191 | #[cfg(feature = "debug_info")] 192 | debug!("target: {:?}", self); 193 | 194 | expect_buf_len!(self.buf, self.cursor + 2, "TrojanAcceptor::parse CRLF"); 195 | 196 | if &self.buf[self.cursor..self.cursor + 2] == b"\r\n" { 197 | self.cursor += 2; 198 | Ok(()) 199 | } else { 200 | Err(ParserError::Invalid("Target::accept expecting CRLF".into())) 201 | } 202 | } 203 | } 204 | -------------------------------------------------------------------------------- /src/server/inbound/handler.rs: -------------------------------------------------------------------------------- 1 | use crate::{args::TrojanContext, server::outbound::handle_outbound}; 2 | use anyhow::{anyhow, Context, Result}; 3 | use tokio::{ 4 | sync::broadcast, 5 | time::{timeout, Duration}, 6 | }; 7 | use tokio_rustls::Accept; 8 | #[cfg(feature = "quic")] 9 | use { 10 | futures::{StreamExt, TryFutureExt}, 11 | quinn::*, 12 | }; 13 | 14 | #[cfg(any(feature = "tcp_tls", feature = "lite_tls"))] 15 | use tokio::net::TcpStream; 16 | 17 | #[cfg(feature = "quic")] 18 | pub async fn handle_quic_connection( 19 | mut context: TrojanContext, 20 | mut streams: IncomingBiStreams, 21 | ) -> Result<()> { 22 | use crate::utils::WRTuple; 23 | use tokio::select; 24 | use tracing::{error, info}; 25 | let (shutdown_tx, _) = broadcast::channel(1); 26 | 27 | loop { 28 | let stream = select! 
{ 29 | s = streams.next() => { 30 | match s { 31 | Some(stream) => stream, 32 | None => {break;} 33 | } 34 | }, 35 | _ = context.shutdown.recv() => { 36 | // info 37 | break; 38 | } 39 | }; 40 | 41 | let stream = match stream { 42 | Err(quinn::ConnectionError::ApplicationClosed { .. }) => { 43 | info!("connection closed"); 44 | return Ok(()); 45 | } 46 | Err(e) => { 47 | return Err(anyhow::Error::new(e)); 48 | } 49 | Ok(s) => s, 50 | }; 51 | tokio::spawn( 52 | handle_outbound( 53 | context.clone_with_signal(shutdown_tx.subscribe()), 54 | WRTuple::from_wr_tuple(stream), 55 | ) 56 | .map_err(|e| { 57 | error!("handle_quic_outbound quit due to {:#}", e); 58 | e 59 | }), 60 | ); 61 | } 62 | Ok(()) 63 | } 64 | 65 | #[cfg(any(feature = "tcp_tls", feature = "lite_tls"))] 66 | pub async fn handle_tcp_tls_connection( 67 | context: TrojanContext, 68 | incoming: Accept, 69 | ) -> Result<()> { 70 | let stream = timeout(Duration::from_secs(5), incoming) 71 | .await 72 | .with_context(|| anyhow!("failed to accept TlsStream"))??; 73 | handle_outbound(context, stream).await?; 74 | Ok(()) 75 | } 76 | -------------------------------------------------------------------------------- /src/server/inbound/listener.rs: -------------------------------------------------------------------------------- 1 | use crate::{args::TrojanContext, try_recv}; 2 | 3 | #[cfg(feature = "quic")] 4 | use crate::{ 5 | server::inbound::handler::handle_quic_connection, server::inbound::quic::quic_tunnel_rx, 6 | }; 7 | 8 | use anyhow::{anyhow, Context, Result}; 9 | use futures::TryFutureExt; 10 | use tokio::sync::broadcast; 11 | use tracing::*; 12 | 13 | #[cfg(any(feature = "tcp_tls", feature = "lite_tls"))] 14 | use { 15 | crate::server::inbound::{ 16 | get_server_local_addr, 17 | {handler::handle_tcp_tls_connection, tcp_tls::*}, 18 | }, 19 | std::sync::Arc, 20 | tokio::net::TcpListener, 21 | tokio_rustls::TlsAcceptor, 22 | }; 23 | 24 | #[cfg(feature = "quic")] 25 | pub async fn quic_listener(mut context: 
TrojanContext) -> Result<()> { 26 | use futures::StreamExt; 27 | let (shutdown_tx, _) = broadcast::channel(1); 28 | let (endpoint, mut incoming) = quic_tunnel_rx(&context.options).await?; 29 | info!("listening on [udp]{}", endpoint.local_addr()?); 30 | while let Some(conn) = incoming.next().await { 31 | try_recv!(broadcast, context.shutdown); 32 | debug!("[quic]connection incoming"); 33 | 34 | let quinn::NewConnection { bi_streams, .. } = match conn.await { 35 | Ok(new_conn) => new_conn, 36 | Err(e) => { 37 | error!("[quic]error while awaiting connection {:#}", e); 38 | continue; 39 | } 40 | }; 41 | tokio::spawn( 42 | handle_quic_connection( 43 | context.clone_with_signal(shutdown_tx.subscribe()), 44 | bi_streams, 45 | ) 46 | .map_err(move |e| error!("[quic]connection failed: {:#}", e)), 47 | ); 48 | } 49 | Ok(()) 50 | } 51 | 52 | #[cfg(any(feature = "tcp_tls", feature = "lite_tls"))] 53 | pub async fn tcp_tls_listener(mut context: TrojanContext) -> Result<()> { 54 | let (shutdown_tx, _) = broadcast::channel(1); 55 | let config = tls_server_config(&context.options) 56 | .await 57 | .with_context(|| anyhow!("failed to get config"))?; 58 | let acceptor = TlsAcceptor::from(Arc::new(config)); 59 | let addr = get_server_local_addr(context.options.server_port); 60 | let mut listener = TcpListener::bind(&addr) 61 | .await 62 | .with_context(|| anyhow!("failed to bind tcp port"))?; 63 | info!("listening on [tcp]{}", listener.local_addr()?); 64 | loop { 65 | try_recv!(broadcast, context.shutdown); 66 | let (stream, _peer_addr) = match listener.accept().await { 67 | Ok(res) => res, 68 | Err(err) => { 69 | error!("failed to listen to tcp: {:?}", err); 70 | drop(listener); 71 | listener = TcpListener::bind(&addr) 72 | .await 73 | .with_context(|| anyhow!("[tcp]failed to bind tcp port"))?; 74 | continue; 75 | } 76 | }; 77 | stream.set_nodelay(true)?; 78 | tokio::spawn( 79 | handle_tcp_tls_connection( 80 | context.clone_with_signal(shutdown_tx.subscribe()), 81 | 
acceptor.accept(stream), 82 | ) 83 | .unwrap_or_else(move |e| error!("[tcp]failed to handle connection: {:#}", e)), 84 | ); 85 | } 86 | Ok(()) 87 | } 88 | -------------------------------------------------------------------------------- /src/server/inbound/mod.rs: -------------------------------------------------------------------------------- 1 | mod acceptor; 2 | mod handler; 3 | mod listener; 4 | #[cfg(feature = "quic")] 5 | mod quic; 6 | #[cfg(any(feature = "tcp_tls", feature = "lite_tls"))] 7 | mod tcp_tls; 8 | 9 | pub use acceptor::TrojanAcceptor; 10 | #[cfg(feature = "quic")] 11 | pub use listener::quic_listener; 12 | #[cfg(any(feature = "tcp_tls", feature = "lite_tls"))] 13 | pub use listener::tcp_tls_listener; 14 | 15 | use std::net::{IpAddr, SocketAddr}; 16 | pub fn get_server_local_addr(proxy_port: u16) -> SocketAddr { 17 | SocketAddr::from((IpAddr::from([0, 0, 0, 0]), proxy_port)) 18 | } 19 | -------------------------------------------------------------------------------- /src/server/inbound/quic.rs: -------------------------------------------------------------------------------- 1 | use super::get_server_local_addr; 2 | use crate::{ 3 | args::Opt, 4 | protocol::{ALPN_QUIC_HTTP, MAX_CONCURRENT_BIDI_STREAMS, QUIC_MAX_IDLE_TIMEOUT}, 5 | server::utils::get_server_certs_and_key, 6 | }; 7 | use anyhow::{bail, Context, Result}; 8 | use quinn::Endpoint; 9 | #[cfg(feature = "quic")] 10 | use quinn::*; 11 | use std::sync::Arc; 12 | use tokio::{fs, io}; 13 | use tokio_rustls::rustls::{self, ServerConfig}; 14 | use tracing::*; 15 | 16 | pub async fn quic_tunnel_rx(options: &Opt) -> Result<(Endpoint, Incoming)> { 17 | let (certs, key) = if let (Some(key_path), Some(cert_path)) = (&options.key, &options.cert) { 18 | get_server_certs_and_key(key_path, cert_path).await? 
19 | } else { 20 | let dirs = directories::ProjectDirs::from("org", "quinn", "quinn-examples").unwrap(); 21 | let path = dirs.data_local_dir(); 22 | let cert_path = path.join("cert.der"); 23 | let key_path = path.join("key.der"); 24 | let cert = fs::read(&cert_path).await; 25 | let key = fs::read(&key_path).await; 26 | let (cert, key) = match cert.and_then(|x| Ok((x, key?))) { 27 | Ok(x) => x, 28 | Err(ref e) if e.kind() == io::ErrorKind::NotFound => { 29 | info!("generating self-signed certificate"); 30 | let cert = rcgen::generate_simple_self_signed(vec!["localhost".into()]).unwrap(); 31 | let key = cert.serialize_private_key_der(); 32 | let cert = cert.serialize_der().unwrap(); 33 | fs::create_dir_all(&path) 34 | .await 35 | .context("failed to create certificate directory")?; 36 | fs::write(&cert_path, &cert) 37 | .await 38 | .context("failed to write certificate")?; 39 | fs::write(&key_path, &key) 40 | .await 41 | .context("failed to write private key")?; 42 | (cert, key) 43 | } 44 | Err(e) => { 45 | bail!("failed to read certificate: {}", e); 46 | } 47 | }; 48 | 49 | let key = rustls::PrivateKey(key); 50 | let cert = rustls::Certificate(cert); 51 | (vec![cert], key) 52 | }; 53 | 54 | let mut crypto_config = ServerConfig::builder() 55 | .with_safe_defaults() 56 | .with_no_client_auth() 57 | .with_single_cert(certs, key)?; 58 | crypto_config.alpn_protocols = ALPN_QUIC_HTTP.iter().map(|&x| x.into()).collect(); 59 | 60 | let mut server_config = quinn::ServerConfig::with_crypto(Arc::new(crypto_config)); 61 | 62 | let transport_config = Arc::get_mut(&mut server_config.transport).unwrap(); 63 | transport_config.max_idle_timeout(Some(QUIC_MAX_IDLE_TIMEOUT.try_into()?)); 64 | transport_config.persistent_congestion_threshold(6); 65 | transport_config.packet_threshold(4); 66 | transport_config.max_concurrent_bidi_streams(MAX_CONCURRENT_BIDI_STREAMS.try_into()?); 67 | 68 | let server_addr = get_server_local_addr(options.server_port); 69 | 
Ok(Endpoint::server(server_config, server_addr)?) 70 | } 71 | -------------------------------------------------------------------------------- /src/server/inbound/tcp_tls.rs: -------------------------------------------------------------------------------- 1 | use crate::args::Opt; 2 | use crate::server::utils::get_server_certs_and_key; 3 | use anyhow::Result; 4 | use std::io; 5 | 6 | use tokio_rustls::rustls::{ServerConfig, Ticketer}; 7 | 8 | pub async fn tls_server_config(options: &Opt) -> Result { 9 | let (cert, key) = get_server_certs_and_key( 10 | options.key.as_ref().unwrap(), 11 | options.cert.as_ref().unwrap(), 12 | ) 13 | .await?; 14 | let mut config = ServerConfig::builder() 15 | .with_safe_defaults() 16 | .with_no_client_auth() 17 | .with_single_cert(cert, key) 18 | .map_err(|err| io::Error::new(io::ErrorKind::InvalidInput, err))?; 19 | config.ticketer = Ticketer::new()?; 20 | Ok(config) 21 | } 22 | -------------------------------------------------------------------------------- /src/server/mod.rs: -------------------------------------------------------------------------------- 1 | mod inbound; 2 | mod outbound; 3 | mod run; 4 | mod utils; 5 | 6 | pub use run::run_server; 7 | -------------------------------------------------------------------------------- /src/server/outbound/connector.rs: -------------------------------------------------------------------------------- 1 | use crate::{ 2 | adapt, 3 | args::TrojanContext, 4 | protocol::{SERVER_OUTBOUND_CONNECT_TIMEOUT, TCP_MAX_IDLE_TIMEOUT}, 5 | server::inbound::TrojanAcceptor, 6 | utils::{lite_tls::LeaveTls, Adapter, ConnectionRequest, MixAddrType}, 7 | }; 8 | use anyhow::{anyhow, Context, Error, Result}; 9 | use std::sync::atomic::{AtomicUsize, Ordering}; 10 | use tokio::{ 11 | io::{AsyncRead, AsyncWrite}, 12 | net::TcpStream, 13 | select, 14 | time::{timeout, Duration}, 15 | }; 16 | use tracing::{debug, info}; 17 | #[cfg(feature = "udp")] 18 | use {crate::server::utils::ServerUdpStream, 
tokio::net::UdpSocket}; 19 | 20 | pub(crate) static TCP_CONNECTION_COUNTER: AtomicUsize = AtomicUsize::new(0); 21 | pub(crate) static UDP_CONNECTION_COUNTER: AtomicUsize = AtomicUsize::new(0); 22 | 23 | async fn outbound_connect(target_host: &MixAddrType) -> Result { 24 | let target_socket_addr = 25 | target_host.clone().resolve().await.with_context(|| { 26 | anyhow!("failed to resolve ip when connecting to {:?}", target_host) 27 | })?; 28 | 29 | let outbound = TcpStream::connect(target_socket_addr) 30 | .await 31 | .map_err(|e| Error::new(e)) 32 | .with_context(|| anyhow!("failed to connect to {:?}", target_host))?; 33 | 34 | outbound 35 | .set_nodelay(true) 36 | .map_err(|e| Error::new(e)) 37 | .with_context(|| { 38 | anyhow!( 39 | "failed to set tcp_nodelay for outbound stream {:?}", 40 | target_host 41 | ) 42 | })?; 43 | Ok(outbound) 44 | } 45 | 46 | pub async fn handle_outbound(mut context: TrojanContext, stream: I) -> Result<()> 47 | where 48 | I: AsyncRead + AsyncWrite + LeaveTls + Unpin + Send + 'static, 49 | { 50 | let opt = &*context.options; 51 | let mut target = TrojanAcceptor::new(opt.password.as_bytes(), opt.fallback_port); 52 | use ConnectionRequest::*; 53 | match timeout( 54 | Duration::from_secs(SERVER_OUTBOUND_CONNECT_TIMEOUT), 55 | target.accept(stream), 56 | ) 57 | .await? 
58 | { 59 | Ok(TCP(inbound)) => { 60 | let outbound = 61 | timeout(Duration::from_secs(2), outbound_connect(&target.host)).await??; 62 | let conn_id = TCP_CONNECTION_COUNTER.fetch_add(1, Ordering::Relaxed); 63 | inbound 64 | .forward(outbound, &target.host, context.shutdown, conn_id) 65 | .await?; 66 | } 67 | #[cfg(feature = "udp")] 68 | Ok(UDP(inbound)) => { 69 | let outbound = UdpSocket::bind("[::]:0") 70 | .await 71 | .map_err(|e| Error::new(e)) 72 | .with_context(|| anyhow!("failed to bind UdpSocket {:?}", target.host))?; 73 | 74 | let outbound = ServerUdpStream::new(outbound); 75 | let conn_id = UDP_CONNECTION_COUNTER.fetch_add(1, Ordering::Relaxed); 76 | let shutdown = context.shutdown; 77 | adapt!([udp][conn_id] 78 | inbound <=> outbound 79 | Until shutdown Or Sec TCP_MAX_IDLE_TIMEOUT 80 | ); 81 | } 82 | #[cfg(feature = "quic")] 83 | Ok(ECHO(mut inbound)) => { 84 | use tokio::io::AsyncReadExt; 85 | use tokio::io::AsyncWriteExt; 86 | let echo = async move { 87 | let mut buf = [0; 256]; 88 | loop { 89 | let num = inbound.read(&mut buf).await; 90 | let num = if num.is_err() { return } else { num.unwrap() }; 91 | if inbound.write(&buf[..num]).await.is_err() { 92 | return; 93 | } 94 | } 95 | }; 96 | debug!("[echo]start relaying"); 97 | select! 
{ 98 | _ = echo => { 99 | debug!("echo end"); 100 | }, 101 | _ = context.shutdown.recv() => { 102 | debug!("server shutdown signal received"); 103 | }, 104 | } 105 | debug!("[echo]end relaying"); 106 | } 107 | Ok(_PHANTOM(_)) => { 108 | unreachable!("") 109 | } 110 | Err(e) => { 111 | return Err(Error::new(e)) 112 | .with_context(|| anyhow!("failed to parse connection to {:?}", target.host)); 113 | } 114 | } 115 | 116 | Ok(()) 117 | } 118 | -------------------------------------------------------------------------------- /src/server/outbound/fallback.rs: -------------------------------------------------------------------------------- 1 | use std::{io::Cursor, net::IpAddr}; 2 | 3 | use crate::utils::copy_forked; 4 | use anyhow::{anyhow, Context, Error, Result}; 5 | use tokio::{ 6 | io::{split, AsyncRead, AsyncWrite, AsyncWriteExt}, 7 | net::TcpStream, 8 | select, 9 | }; 10 | use tracing::*; 11 | 12 | pub async fn fallback( 13 | buf: Vec, 14 | fallback_port: u16, 15 | inbound: I, 16 | ) -> Result<()> { 17 | let mut outbound = TcpStream::connect((IpAddr::from([127, 0, 0, 1]), fallback_port)) 18 | .await 19 | .map_err(|e| Error::new(e)) 20 | .with_context(|| anyhow!("failed to connect to fallback service"))?; 21 | 22 | outbound 23 | .write_all_buf(&mut Cursor::new(&buf)) 24 | .await 25 | .with_context(|| anyhow!("failed to write to fallback service"))?; 26 | 27 | let (mut out_read, mut out_write) = outbound.split(); 28 | let (mut in_read, mut in_write) = split(inbound); 29 | select! 
{ 30 | res = copy_forked(&mut out_read, &mut in_write) => { 31 | debug!("[fallback]relaying download end, {:?}", res); 32 | }, 33 | res = copy_forked(&mut in_read, &mut out_write) => { 34 | debug!("[fallback]relaying upload end, {:?}", res); 35 | }, 36 | } 37 | Ok(()) 38 | } 39 | -------------------------------------------------------------------------------- /src/server/outbound/mod.rs: -------------------------------------------------------------------------------- 1 | mod connector; 2 | mod fallback; 3 | 4 | pub use connector::handle_outbound; 5 | pub use fallback::fallback; -------------------------------------------------------------------------------- /src/server/run.rs: -------------------------------------------------------------------------------- 1 | #[cfg(feature = "quic")] 2 | use crate::server::inbound::quic_listener; 3 | #[cfg(any(feature = "tcp_tls", feature = "lite_tls"))] 4 | use crate::server::inbound::tcp_tls_listener; 5 | 6 | use crate::args::TrojanContext; 7 | use anyhow::Result; 8 | use futures::TryFutureExt; 9 | use tracing::error; 10 | 11 | // todo: refactor into Server class 12 | #[cfg(feature = "server")] 13 | pub async fn run_server(mut context: TrojanContext) -> Result<()> { 14 | use tokio::sync::broadcast; 15 | 16 | let (shutdown_tx, shutdown) = broadcast::channel(1); 17 | #[cfg(feature = "quic")] 18 | tokio::spawn( 19 | quic_listener(context.clone_with_signal(shutdown)) 20 | .unwrap_or_else(move |e| error!("quic server shutdown due to {:#}", e)), 21 | ); 22 | 23 | #[cfg(any(feature = "tcp_tls", feature = "lite_tls"))] 24 | tokio::spawn( 25 | tcp_tls_listener(context.clone_with_signal(shutdown_tx.subscribe())) 26 | .unwrap_or_else(move |e| error!("tcp_tls server shutdown due to {:#}", e)), 27 | ); 28 | let _ = context.shutdown.recv().await; 29 | Ok(()) 30 | } 31 | -------------------------------------------------------------------------------- /src/server/utils/lite_tls.rs: 
-------------------------------------------------------------------------------- 1 | use tokio::{ 2 | io::{AsyncRead, AsyncWrite}, 3 | net::TcpStream, 4 | sync::broadcast, 5 | }; 6 | use tracing::{debug, info}; 7 | 8 | use crate::{ 9 | adapt, 10 | protocol::TCP_MAX_IDLE_TIMEOUT, 11 | utils::{ 12 | lite_tls::{LeaveTls, LiteTlsStream}, 13 | Adapter, BufferedRecv, MixAddrType, ParserError, 14 | }, 15 | }; 16 | use anyhow::{anyhow, Context, Result}; 17 | 18 | pub enum TcpOption { 19 | TLS(I), 20 | LiteTLS(I), 21 | } 22 | 23 | impl TcpOption> 24 | where 25 | I: AsyncRead + AsyncWrite + LeaveTls + Unpin, 26 | { 27 | pub async fn forward( 28 | self, 29 | mut outbound: TcpStream, 30 | target_host: &MixAddrType, 31 | shutdown: broadcast::Receiver<()>, 32 | conn_id: usize, 33 | ) -> Result<()> { 34 | use TcpOption::*; 35 | match self { 36 | TLS(inbound) => { 37 | adapt!([tcp][conn_id] 38 | inbound <=> outbound <=> target_host 39 | Until shutdown Or Sec TCP_MAX_IDLE_TIMEOUT 40 | ); 41 | } 42 | LiteTLS(mut inbound) => { 43 | let mut lite_tls_endpoint = LiteTlsStream::new_server_endpoint(); 44 | match lite_tls_endpoint 45 | .handshake_timeout(&mut outbound, &mut inbound) 46 | .await 47 | { 48 | Ok(_) => { 49 | let ver = lite_tls_endpoint.version; 50 | if ver.is_none() { 51 | // EOF 52 | return Ok(()); 53 | } 54 | info!("[{}]lite tls handshake succeed", ver.unwrap()); 55 | let mut inbound = inbound.into_inner().0.leave(); 56 | lite_tls_endpoint 57 | .flush_tls(&mut outbound, &mut inbound) 58 | .await?; 59 | debug!("lite tls start relaying"); 60 | adapt!([lite][conn_id] 61 | inbound <=> outbound <=> target_host 62 | Until shutdown Or Sec TCP_MAX_IDLE_TIMEOUT 63 | ); 64 | } 65 | Err(e) => { 66 | if let Some(ParserError::Invalid(x)) = e.downcast_ref::() { 67 | debug!("not tls stream: {}", x); 68 | lite_tls_endpoint 69 | .flush_non_tls(&mut outbound, &mut inbound) 70 | .await?; 71 | adapt!([tcp][conn_id] 72 | inbound <=> outbound <=> target_host 73 | Until shutdown Or Sec 
TCP_MAX_IDLE_TIMEOUT 74 | ); 75 | } 76 | } 77 | } 78 | } 79 | } 80 | Ok(()) 81 | } 82 | } 83 | -------------------------------------------------------------------------------- /src/server/utils/mod.rs: -------------------------------------------------------------------------------- 1 | mod lite_tls; 2 | pub use lite_tls::TcpOption; 3 | 4 | #[cfg(feature = "udp")] 5 | mod server_udp_stream; 6 | #[cfg(feature = "udp")] 7 | pub use server_udp_stream::ServerUdpStream; 8 | 9 | mod rustls_utils; 10 | pub use rustls_utils::get_server_certs_and_key; -------------------------------------------------------------------------------- /src/server/utils/rustls_utils.rs: -------------------------------------------------------------------------------- 1 | use anyhow::{Context, Result}; 2 | use std::path::PathBuf; 3 | use tokio::fs; 4 | use tokio_rustls::rustls::{self, Certificate, PrivateKey}; 5 | 6 | pub async fn get_server_certs_and_key( 7 | key_path: &PathBuf, 8 | cert_path: &PathBuf, 9 | ) -> Result<(Vec, PrivateKey)> { 10 | let key = fs::read(key_path) 11 | .await 12 | .context("failed to read private key")?; 13 | let key = if key_path.extension().map_or(false, |x| x == "der") { 14 | rustls::PrivateKey(key) 15 | } else { 16 | let pkcs8 = rustls_pemfile::pkcs8_private_keys(&mut &*key) 17 | .context("malformed PKCS #8 private key")?; 18 | match pkcs8.into_iter().next() { 19 | Some(x) => rustls::PrivateKey(x), 20 | None => { 21 | let rsa = rustls_pemfile::rsa_private_keys(&mut &*key) 22 | .context("malformed PKCS #1 private key")?; 23 | match rsa.into_iter().next() { 24 | Some(x) => rustls::PrivateKey(x), 25 | None => { 26 | anyhow::bail!("no private keys found"); 27 | } 28 | } 29 | } 30 | } 31 | }; 32 | let cert_chain = fs::read(cert_path) 33 | .await 34 | .context("failed to read certificate chain")?; 35 | let cert_chain = if cert_path.extension().map_or(false, |x| x == "der") { 36 | vec![rustls::Certificate(cert_chain)] 37 | } else { 38 | rustls_pemfile::certs(&mut 
&*cert_chain) 39 | .context("invalid PEM-encoded certificate")? 40 | .into_iter() 41 | .map(rustls::Certificate) 42 | .collect() 43 | }; 44 | 45 | Ok((cert_chain, key)) 46 | } 47 | -------------------------------------------------------------------------------- /src/server/utils/server_udp_stream.rs: -------------------------------------------------------------------------------- 1 | use crate::utils::{MixAddrType, UdpRead, UdpRelayBuffer, UdpWrite, DNS_TX}; 2 | use futures::{ready, Future}; 3 | use std::io::Result; 4 | use std::net::{IpAddr, SocketAddr}; 5 | use std::sync::Arc; 6 | use std::{ 7 | pin::Pin, 8 | task::{Context, Poll}, 9 | }; 10 | use tokio::net::UdpSocket; 11 | use tokio::sync::oneshot; 12 | #[cfg(feature = "udp_info")] 13 | use tracing::*; 14 | 15 | #[cfg_attr(feature = "debug_info", derive(Debug))] 16 | pub struct ServerUdpStream { 17 | inner: Arc, 18 | addr_task: ResolveAddr, 19 | } 20 | 21 | impl ServerUdpStream { 22 | pub fn new(inner: UdpSocket) -> Self { 23 | Self { 24 | inner: Arc::new(inner), 25 | addr_task: ResolveAddr::None, 26 | } 27 | } 28 | } 29 | 30 | impl UdpWrite for ServerUdpStream { 31 | fn poll_proxy_stream_write( 32 | mut self: Pin<&mut Self>, 33 | cx: &mut Context<'_>, 34 | buf: &[u8], 35 | addr: &MixAddrType, 36 | ) -> Poll> { 37 | mc::debug_info!(send self, "enter", ""); 38 | loop { 39 | match self.addr_task { 40 | ResolveAddr::Pending((ref mut task, ref mut missed_wakeup)) => { 41 | let ip = match Pin::new(task).poll(cx) { 42 | Poll::Pending => { 43 | *missed_wakeup = true; 44 | mc::debug_info!(send self, "resolving pending & pending", ""); 45 | return Poll::Pending; 46 | } 47 | Poll::Ready(Ok(ip)) => { 48 | if *missed_wakeup { 49 | mc::debug_info!(send self, "adding back wakeups", ""); 50 | cx.waker().wake_by_ref(); 51 | } 52 | ip 53 | } 54 | Poll::Ready(Err(_)) => return Poll::Ready(Ok(0)), 55 | }; 56 | self.addr_task = ResolveAddr::Ready((ip, addr.port()).into()); 57 | } 58 | ResolveAddr::Ready(s_addr) => { 59 | let res = 
self.inner.poll_send_to(cx, buf, s_addr); 60 | 61 | if let Poll::Ready(Ok(val)) = res { 62 | if val == buf.len() { 63 | self.addr_task = ResolveAddr::None; 64 | } 65 | } 66 | mc::debug_info!(send self, 67 | "ResolveAddr::Ready({})", format!("addr {:?} buf len {:?}, res {:?}", s_addr, buf.len(), res) 68 | ); 69 | return res; 70 | } 71 | ResolveAddr::None => { 72 | mc::debug_info!(send self, "ResolveAddr::None", ""); 73 | 74 | use MixAddrType::*; 75 | self.addr_task = match addr { 76 | x @ V4(_) | x @ V6(_) => ResolveAddr::Ready(x.clone().to_socket_addrs()), 77 | Hostname((name, _)) => { 78 | let name = name.to_owned(); 79 | let (task_tx, task_rx) = oneshot::channel(); 80 | tokio::spawn(async move { 81 | DNS_TX.get().unwrap().send((name, task_tx)).await 82 | }); 83 | ResolveAddr::Pending((task_rx, false)) 84 | } 85 | _ => panic!("unprecedented MixAddrType variant"), 86 | }; 87 | } 88 | } 89 | } 90 | } 91 | 92 | fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { 93 | Poll::Ready(Ok(())) 94 | } 95 | 96 | fn poll_shutdown(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { 97 | Poll::Ready(Ok(())) 98 | } 99 | } 100 | 101 | #[cfg_attr(feature = "debug_info", derive(Debug))] 102 | enum ResolveAddr { 103 | Pending((oneshot::Receiver, bool)), 104 | Ready(SocketAddr), 105 | None, 106 | } 107 | 108 | impl UdpRead for ServerUdpStream { 109 | fn poll_proxy_stream_read( 110 | self: Pin<&mut Self>, 111 | cx: &mut Context<'_>, 112 | buf: &mut UdpRelayBuffer, 113 | ) -> Poll> { 114 | mc::debug_info!(recv self, "enter", format!("buf len {}", buf.len())); 115 | 116 | let mut read_buf = buf.as_read_buf(); 117 | 118 | let addr = ready!(self.inner.poll_recv_from(cx, &mut read_buf))?; 119 | 120 | let n = read_buf.filled().len(); 121 | if n == 0 { 122 | // EOF is seen 123 | return Poll::Ready(Ok(MixAddrType::None)); 124 | } 125 | 126 | // Safety: This is guaranteed to be the number of initialized (and read) 127 | // bytes due to the invariants provided by 
`ReadBuf::filled`. 128 | unsafe { 129 | buf.advance_mut(n); 130 | } 131 | 132 | mc::debug_info!(recv self, "read ok", format!("buf len {}, n {}", buf.len(), n)); 133 | 134 | Poll::Ready(Ok((&addr).into())) 135 | } 136 | } 137 | 138 | mod mc { 139 | macro_rules! debug_info { 140 | (recv $me:expr, $msg:expr, $addition:expr) => { 141 | #[cfg(feature = "udp_info")] 142 | debug!("ServerUdpRecv {} | {:?}", $msg, $addition); 143 | }; 144 | 145 | (send $me:expr, $msg:expr, $addition:expr) => { 146 | #[cfg(feature = "udp_info")] 147 | debug!("ServerUdpSend {} | {:?}", $msg, $addition,); 148 | }; 149 | } 150 | pub(crate) use debug_info; 151 | } 152 | -------------------------------------------------------------------------------- /src/simd/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod simd_parse; -------------------------------------------------------------------------------- /src/simd/simd_parse.rs: -------------------------------------------------------------------------------- 1 | // pub fn simd16_wrap(buf: &[u8]) -> usize { 2 | // unsafe { parse_simd_16(buf) } 3 | // } 4 | 5 | // pub fn simd8_wrap(buf: &[u8]) -> usize { 6 | // unsafe { parse_simd_8(buf) } 7 | // } 8 | 9 | #[cfg(target_arch = "aarch64")] 10 | #[inline] 11 | #[target_feature(enable = "neon")] 12 | #[allow(non_snake_case, overflowing_literals)] 13 | unsafe fn parse_simd_16(mut buf: &[u8]) -> usize { 14 | use core::arch::aarch64::*; 15 | 16 | let dash_r_mask = vdupq_n_u8(0x0d); 17 | let dash_n_mask = vdupq_n_u8(0x0a); 18 | 19 | const BYTE_MASK_DATA_HIGH: [u8; 16] = [0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 4, 8, 16, 32, 64, 128]; 20 | const BYTE_MASK_DATA_LOW: [u8; 16] = [1, 2, 4, 8, 16, 32, 64, 128, 0, 0, 0, 0, 0, 0, 0, 0]; 21 | 22 | let byte_mask_high = vld1q_u8(BYTE_MASK_DATA_HIGH.as_ptr()); 23 | let byte_mask_low = vld1q_u8(BYTE_MASK_DATA_LOW.as_ptr()); 24 | let mut res = 0; 25 | while buf.len() >= 16 { 26 | let ptr = buf.as_ptr(); 27 | let data = 
vld1q_u8(ptr); 28 | 29 | let bits1 = 30 | crate::neon_move_mask!(16 vceqq_u8(dash_r_mask, data), byte_mask_high, byte_mask_low); 31 | let bits2 = 32 | crate::neon_move_mask!(16 vceqq_u8(dash_n_mask, data), byte_mask_high, byte_mask_low); 33 | let ret = _clz_u64(_rbit_u64(((bits2 >> 1) & bits1) as u64) | 1 << 47) as usize; 34 | res += ret; 35 | if ret != 16 { 36 | break; 37 | } 38 | buf = &buf[16..]; 39 | } 40 | res 41 | } 42 | 43 | #[cfg(target_arch = "aarch64")] 44 | #[inline] 45 | #[target_feature(enable = "neon")] 46 | #[allow(non_snake_case, overflowing_literals)] 47 | unsafe fn parse_simd_8(mut buf: &[u8]) -> usize { 48 | use core::arch::aarch64::*; 49 | 50 | let dash_r_mask = vld1_dup_u8(&0x0d); 51 | let dash_n_mask = vld1_dup_u8(&0x0a); 52 | 53 | const BYTE_MASK_DATA: [u8; 8] = [1, 2, 4, 8, 16, 32, 64, 128]; 54 | let byte_mask = vld1_u8(BYTE_MASK_DATA.as_ptr()); 55 | let mut res = 0; 56 | while buf.len() >= 8 { 57 | let ptr = buf.as_ptr(); 58 | let data = vld1_u8(ptr); 59 | let bits1 = crate::neon_move_mask!(8 vceq_u8(dash_r_mask, data), byte_mask); 60 | let bits2 = crate::neon_move_mask!(8 vceq_u8(dash_n_mask, data), byte_mask); 61 | 62 | let ret = _clz_u64(_rbit_u64(((bits2 >> 1) & bits1) as u64) | 1 << 55) as usize; 63 | res += ret; 64 | if ret != 8 { 65 | break; 66 | } 67 | buf = &buf[8..]; 68 | } 69 | res 70 | } 71 | 72 | #[macro_export] 73 | macro_rules! 
neon_move_mask { 74 | (16 $mask:expr, $filter_high:expr, $filter_low:expr) => {{ 75 | let masked1 = vandq_u8($mask, $filter_high); 76 | let masked2 = vandq_u8($mask, $filter_low); 77 | 78 | ((vaddvq_u8(masked1) as u16) << 8) + vaddvq_u8(masked2) as u16 79 | }}; 80 | (8 $mask:expr, $filter:expr) => {{ 81 | let masked = vand_u8($mask, $filter); 82 | 83 | vaddv_u8(masked) 84 | }}; 85 | } 86 | 87 | pub fn parse_scalar(buf: &[u8]) -> usize { 88 | for i in 0..buf.len() - 1 { 89 | if buf[i] == b'\r' && buf[i + 1] == b'\n' { 90 | return i; 91 | } 92 | } 93 | buf.len() 94 | } 95 | -------------------------------------------------------------------------------- /src/utils/adapter.rs: -------------------------------------------------------------------------------- 1 | use std::{fmt, time::Duration}; 2 | 3 | use crate::utils::{ 4 | copy_bidirectional_forked, either_io::EitherIO, udp_copy_bidirectional, TimeoutMonitor, 5 | }; 6 | 7 | use anyhow::Result; 8 | use futures::future::{pending, Either}; 9 | use tokio::{ 10 | io::{AsyncRead, AsyncWrite}, 11 | select, 12 | sync::broadcast, 13 | }; 14 | use tracing::debug; 15 | 16 | #[cfg(all(target_os = "linux", feature = "zio"))] 17 | use {crate::VEC_TCP_TX, anyhow::Error, tokio::net::TcpStream}; 18 | 19 | use super::{UdpRead, UdpWrite}; 20 | 21 | #[macro_export] 22 | macro_rules! adapt { 23 | (tcp) => {"tcp"}; 24 | (lite) => {"lite"}; 25 | ([udp][$conn_id:ident]$inbound:ident <=> $outbound:ident Until $shutdown:ident$( Or Sec $timeout:expr)?) => { 26 | #[allow(unused_mut)] 27 | let mut adapter = Adapter::new(); 28 | $(adapter.set_timeout($timeout);)? 29 | info!("[udp][{}]", $conn_id); 30 | let reason = adapter.relay_udp($inbound, $outbound, $shutdown, $conn_id).await.with_context(|| anyhow!("[udp][{}] failed", $conn_id))?; 31 | info!("[udp][{}] end by {}", $conn_id, reason); 32 | }; 33 | ([lite][$conn_id:ident]$inbound:ident <=> $outbound:ident <=> $target_host:ident Until $shutdown:ident$( Or Sec $timeout:expr)?) 
=> { 34 | #[cfg(all(target_os = "linux", feature = "zio"))] 35 | { 36 | // timeout is not used here. In glommio, we set tcp socket's 37 | // timeout instead. 38 | info!("[lite+][{}] => {:?}", $conn_id, $target_host); 39 | let _ = Adapter::relay_tcp_zio($inbound, $outbound, $conn_id).await?; 40 | // the ending message is printed by glommio, since 41 | // it's not yet possible to get the reason back from glommio. 42 | // // info!("[lite+][{}] end by {}", $conn_id, reason); 43 | } 44 | 45 | #[cfg(not(all(target_os = "linux", feature = "zio")))] 46 | { 47 | #[allow(unused_mut)] 48 | let mut adapter = Adapter::new(); 49 | $(adapter.set_timeout($timeout);)? 50 | info!("[lite][{}] => {:?}", $conn_id, $target_host); 51 | let reason = adapter.relay_tcp($inbound, $outbound, $shutdown).await.with_context(|| anyhow!("[lite][{}] failed", $conn_id))?; 52 | info!("[lite][{}] end by {}", $conn_id, reason); 53 | } 54 | }; 55 | ([tcp][$conn_id:ident]$inbound:ident <=> $outbound:ident <=> $target_host:ident Until $shutdown:ident$( Or Sec $timeout:expr)?) => { 56 | #[allow(unused_mut)] 57 | let mut adapter = Adapter::new(); 58 | $(adapter.set_timeout($timeout);)? 
59 | info!("[tcp][{}] => {:?}", $conn_id, $target_host); 60 | let reason = adapter.relay_tcp($inbound, $outbound, $shutdown).await.with_context(|| anyhow!("[tcp][{}] failed", $conn_id))?; 61 | info!("[tcp][{}] end by {}", $conn_id, reason); 62 | }; 63 | } 64 | 65 | #[derive(Clone)] 66 | #[cfg_attr(feature = "debug_info", derive(Debug))] 67 | pub enum StreamStopReasons { 68 | Upload, 69 | Download, 70 | Timeout, 71 | Shutdown, 72 | } 73 | 74 | impl fmt::Display for StreamStopReasons { 75 | fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { 76 | use StreamStopReasons::*; 77 | match self { 78 | Upload => write!(f, "upload"), 79 | Download => write!(f, "download"), 80 | Timeout => write!(f, "timeout"), 81 | Shutdown => write!(f, "shutdown"), 82 | } 83 | } 84 | } 85 | 86 | pub struct Adapter { 87 | timeout: Option, 88 | } 89 | 90 | impl Adapter { 91 | pub fn new() -> Self { 92 | Self { timeout: None } 93 | } 94 | 95 | #[allow(dead_code)] 96 | pub fn set_timeout(&mut self, timeout: u16) { 97 | self.timeout = Some(timeout); 98 | } 99 | 100 | pub async fn relay_tcp( 101 | &self, 102 | mut inbound: I, 103 | outbound: O, 104 | mut shutdown: broadcast::Receiver<()>, 105 | ) -> Result 106 | where 107 | I: AsyncRead + AsyncWrite + Unpin, 108 | O: AsyncRead + AsyncWrite + Unpin, 109 | { 110 | let (mut outbound, timeout): _ = match self.timeout { 111 | Some(t) => { 112 | let deadline = Duration::from_secs(t as u64); 113 | let timeout_monitor = TimeoutMonitor::new(deadline); 114 | let outbound = EitherIO::Left(timeout_monitor.watch(outbound)); 115 | (outbound, Either::Left(timeout_monitor)) 116 | } 117 | None => (EitherIO::Right(outbound), Either::Right(pending::<()>())), 118 | }; 119 | 120 | let duplex_stream: _ = copy_bidirectional_forked(&mut inbound, &mut outbound); 121 | 122 | use StreamStopReasons::*; 123 | let reason = select! 
{ 124 | res = duplex_stream => { 125 | match res { 126 | Err((reason, e)) => { 127 | debug!("forward tcp failed: {:#}", e); 128 | reason 129 | } 130 | Ok(res) => res, 131 | } 132 | }, 133 | _ = timeout => { 134 | Timeout 135 | } 136 | _ = shutdown.recv() => { 137 | Shutdown 138 | }, 139 | }; 140 | Ok(reason) 141 | } 142 | 143 | pub async fn relay_udp( 144 | &self, 145 | mut inbound: I, 146 | outbound: O, 147 | mut shutdown: broadcast::Receiver<()>, 148 | conn_id: usize, 149 | ) -> Result 150 | where 151 | I: UdpRead + UdpWrite + Unpin, 152 | O: UdpRead + UdpWrite + Unpin, 153 | { 154 | let (mut outbound, timeout): _ = match self.timeout { 155 | Some(t) => { 156 | let deadline = Duration::from_secs(t as u64); 157 | let timeout_monitor = TimeoutMonitor::new(deadline); 158 | let outbound: _ = EitherIO::Left(timeout_monitor.watch(outbound)); 159 | (outbound, Either::Left(timeout_monitor)) 160 | } 161 | None => (EitherIO::Right(outbound), Either::Right(pending::<()>())), 162 | }; 163 | 164 | use StreamStopReasons::*; 165 | let reason = select! { 166 | res = udp_copy_bidirectional(&mut inbound, &mut outbound, conn_id) => { 167 | let (_, _, reason) = res?; 168 | reason 169 | } 170 | _ = timeout => { 171 | Timeout 172 | } 173 | _ = shutdown.recv() => { 174 | Shutdown 175 | }, 176 | }; 177 | Ok(reason) 178 | } 179 | 180 | #[cfg(all(target_os = "linux", feature = "zio"))] 181 | pub async fn relay_tcp_zio( 182 | inbound: TcpStream, 183 | outbound: TcpStream, 184 | conn_id: usize, 185 | ) -> Result<()> { 186 | let vec_tcp_tx = VEC_TCP_TX.get().unwrap(); 187 | let tcp_tx = vec_tcp_tx[conn_id % vec_tcp_tx.len()].clone(); 188 | 189 | // we transfer the ownership of the socket to glommio by sending 190 | // its std representation. tokio is no longer responsible for 191 | // releasing the socket. 
192 | let inbound_std = inbound.into_std()?; 193 | let outbound_std = outbound.into_std()?; 194 | tcp_tx 195 | .send((inbound_std, outbound_std, conn_id)) 196 | .await 197 | .map_err(|e| Error::new(e).context("failed on sending"))?; 198 | Ok(()) 199 | } 200 | } 201 | -------------------------------------------------------------------------------- /src/utils/buffered_recv.rs: -------------------------------------------------------------------------------- 1 | use std::pin::Pin; 2 | use std::task::Poll; 3 | use tokio::io::{AsyncRead, AsyncWrite}; 4 | 5 | #[cfg_attr(feature = "debug_info", derive(Debug))] 6 | pub struct BufferedRecv { 7 | buffered_request: Option<(usize, Vec)>, 8 | inner: T, 9 | } 10 | 11 | impl BufferedRecv { 12 | pub fn new(inner: T, buffered_request: Option<(usize, Vec)>) -> Self { 13 | Self { 14 | inner, 15 | buffered_request, 16 | } 17 | } 18 | 19 | #[allow(dead_code)] 20 | pub fn into_inner(self) -> (T, Option<(usize, Vec)>) { 21 | (self.inner, self.buffered_request) 22 | } 23 | } 24 | 25 | impl AsyncRead for BufferedRecv 26 | where 27 | T: AsyncRead + Unpin, 28 | { 29 | fn poll_read( 30 | mut self: std::pin::Pin<&mut Self>, 31 | cx: &mut std::task::Context<'_>, 32 | buf: &mut tokio::io::ReadBuf<'_>, 33 | ) -> Poll> { 34 | if self.buffered_request.is_some() { 35 | let (index, buffered_request) = self.buffered_request.as_ref().unwrap(); 36 | buf.put_slice(&buffered_request[*index..]); 37 | self.buffered_request = None; 38 | return Poll::Ready(Ok(())); 39 | } 40 | 41 | let reader = Pin::new(&mut self.inner); 42 | reader.poll_read(cx, buf) 43 | } 44 | } 45 | 46 | impl AsyncWrite for BufferedRecv 47 | where 48 | T: AsyncWrite + Unpin, 49 | { 50 | fn poll_write( 51 | mut self: Pin<&mut Self>, 52 | cx: &mut std::task::Context<'_>, 53 | buf: &[u8], 54 | ) -> Poll> { 55 | Pin::new(&mut self.inner).poll_write(cx, buf) 56 | } 57 | 58 | fn poll_flush( 59 | mut self: Pin<&mut Self>, 60 | cx: &mut std::task::Context<'_>, 61 | ) -> Poll> { 62 | Pin::new(&mut 
self.inner).poll_flush(cx) 63 | } 64 | 65 | fn poll_shutdown( 66 | mut self: Pin<&mut Self>, 67 | cx: &mut std::task::Context<'_>, 68 | ) -> Poll> { 69 | Pin::new(&mut self.inner).poll_shutdown(cx) 70 | } 71 | } 72 | -------------------------------------------------------------------------------- /src/utils/buffers.rs: -------------------------------------------------------------------------------- 1 | use bytes::BufMut; 2 | use tokio::io::ReadBuf; 3 | 4 | pub trait CursoredBuffer { 5 | fn chunk(&self) -> &[u8]; 6 | fn advance(&mut self, len: usize); 7 | fn remaining(&self) -> usize { 8 | self.chunk().len() 9 | } 10 | } 11 | 12 | impl<'a> CursoredBuffer for (&'a mut usize, &Vec) { 13 | fn chunk(&self) -> &[u8] { 14 | &self.1[*self.0..] 15 | } 16 | 17 | fn advance(&mut self, len: usize) { 18 | assert!( 19 | self.1.len() >= *self.0 + len, 20 | "(&'a mut usize, &Vec) was about to set a larger position than it's length" 21 | ); 22 | *self.0 += len; 23 | } 24 | } 25 | 26 | pub trait VecAsReadBufExt<'a> { 27 | fn as_read_buf(&'a mut self) -> ReadBuf<'a>; 28 | } 29 | 30 | impl<'a> VecAsReadBufExt<'a> for Vec { 31 | fn as_read_buf(&'a mut self) -> ReadBuf<'a> { 32 | let dst = self.chunk_mut(); 33 | let dst = unsafe { &mut *(dst as *mut _ as *mut [std::mem::MaybeUninit]) }; 34 | ReadBuf::uninit(dst) 35 | } 36 | } 37 | 38 | pub trait ExtendableFromSlice { 39 | fn extend_from_slice(&mut self, src: &[u8]); 40 | } 41 | 42 | impl ExtendableFromSlice for Vec { 43 | fn extend_from_slice(&mut self, src: &[u8]) { 44 | self.extend_from_slice(src); 45 | } 46 | } 47 | -------------------------------------------------------------------------------- /src/utils/copy_tcp.rs: -------------------------------------------------------------------------------- 1 | // use std::mem::MaybeUninit; 2 | // use crate::protocol::RELAY_BUFFER_SIZE; 3 | // use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; 4 | 5 | // pub async fn copy_to_tls( 6 | // r: &mut R, 7 | // w: &mut W, 8 | // ) 
-> std::io::Result { 9 | // // safety: We don't realy care what's previouly in the buffer 10 | // let mut buf = unsafe { 11 | // let buf: [MaybeUninit; RELAY_BUFFER_SIZE] = MaybeUninit::uninit().assume_init(); 12 | // std::mem::transmute::<_, [u8; RELAY_BUFFER_SIZE]>(buf) 13 | // }; 14 | 15 | // loop { 16 | // let len = r.read(&mut buf).await?; 17 | // if len == 0 { 18 | // return Ok(0); 19 | // } 20 | // // let mut writen = 0; 21 | // w.write(&buf[..len]).await?; 22 | // // loop { 23 | // // writen += w.write(&buf[writen..len]).await?; 24 | // // if writen == len { 25 | // // break; 26 | // // } 27 | // // } 28 | // if len != buf.len() { 29 | // w.flush().await?; 30 | // } 31 | // } 32 | // } 33 | -------------------------------------------------------------------------------- /src/utils/dns_utils/dns_resolver.rs: -------------------------------------------------------------------------------- 1 | use std::{net::IpAddr, time::Duration}; 2 | 3 | use fxhash::FxHashMap; 4 | use tokio::{ 5 | net::lookup_host, 6 | select, 7 | sync::{mpsc, oneshot, OnceCell}, 8 | time::{sleep_until, Instant, Sleep}, 9 | }; 10 | use tracing::{error, info}; 11 | 12 | use crate::protocol::{BLACK_HOLE_LOCAL_ADDR, DNS_UPDATE_PERIOD_SEC}; 13 | 14 | type DNSTask = (Box, oneshot::Sender); 15 | type DNSTx = mpsc::Sender; 16 | type DNSRx = mpsc::Receiver; 17 | 18 | pub static DNS_TX: OnceCell = OnceCell::const_new(); 19 | 20 | pub fn start_dns_resolver_thread() { 21 | info!("starting dns resolver"); 22 | let (dns_tx, dns_rx) = mpsc::channel(100); 23 | tokio::spawn(dns_resolver(dns_rx)); 24 | let _ = DNS_TX 25 | .set(dns_tx) 26 | .map_err(|e| error!("failed to set DNS_TX: {:#}", e)); 27 | } 28 | 29 | async fn dns_resolver(mut incoming_tasks: DNSRx) { 30 | let mut cache = FxHashMap::, (IpAddr, usize)>::default(); 31 | let (update_cache_tx, mut update_cache_rx) = mpsc::channel(100); 32 | let mut timer = Timer::new(); 33 | loop { 34 | select! 
{ 35 | _ = timer.sleep() => { 36 | timer.update(); 37 | timer.slower(); 38 | } 39 | maybe_cache_update = update_cache_rx.recv() => { 40 | match maybe_cache_update { 41 | Some((query, ip)) => { 42 | cache.insert(query, (ip, timer.counter)); 43 | }, 44 | None => error!("unexpected None received when updating cache"), 45 | } 46 | } 47 | maybe_task = incoming_tasks.recv() => { 48 | timer.faster(); 49 | match maybe_task { 50 | None => { 51 | info!("shutting down dns resolver"); 52 | return; 53 | } 54 | Some((addr, ret_tx)) => { 55 | if let Some((ip, timestamp)) = cache.get(&addr) { 56 | if timer.counter - timestamp < 2 { 57 | let _ = ret_tx.send(*ip); 58 | continue; 59 | } 60 | } 61 | tokio::spawn(single_resolve_task(addr, ret_tx, update_cache_tx.clone())); 62 | } 63 | } 64 | } 65 | } 66 | } 67 | } 68 | 69 | struct Timer { 70 | instant: Instant, 71 | counter: usize, 72 | slow: usize, 73 | } 74 | 75 | impl Timer { 76 | fn new() -> Self { 77 | Self { 78 | instant: Instant::now(), 79 | counter: 0, 80 | slow: 0, 81 | } 82 | } 83 | fn sleep(&self) -> Sleep { 84 | return sleep_until( 85 | self.instant 86 | + Duration::from_secs(if self.slow == 0 { 87 | DNS_UPDATE_PERIOD_SEC * 20 88 | } else { 89 | DNS_UPDATE_PERIOD_SEC 90 | }), 91 | ); 92 | } 93 | 94 | fn update(&mut self) { 95 | self.counter += 1; 96 | self.instant = Instant::now(); 97 | } 98 | 99 | fn slower(&mut self) { 100 | if self.slow > 0 { 101 | self.slow -= 1; 102 | } 103 | } 104 | 105 | fn faster(&mut self) { 106 | self.slow = 2; 107 | } 108 | } 109 | 110 | async fn single_resolve_task( 111 | query: Box, 112 | ret_tx: oneshot::Sender, 113 | update_cache_tx: mpsc::Sender<(Box, IpAddr)>, 114 | ) { 115 | let res = match lookup_host((&*query, 0)).await { 116 | Ok(mut iter) => match iter.next() { 117 | Some(res) => res.ip(), 118 | None => { 119 | error!("failed to lookup host: the result is empty"); 120 | BLACK_HOLE_LOCAL_ADDR.into() 121 | } 122 | }, 123 | Err(e) => { 124 | error!("failed to lookup host: {:#}", e); 125 | 
BLACK_HOLE_LOCAL_ADDR.into() 126 | } 127 | }; 128 | let _ = ret_tx.send(res); 129 | let _ = update_cache_tx.send((query, res)).await; 130 | } 131 | -------------------------------------------------------------------------------- /src/utils/dns_utils/mod.rs: -------------------------------------------------------------------------------- 1 | mod dns_resolver; 2 | pub use dns_resolver::{start_dns_resolver_thread, DNS_TX}; 3 | -------------------------------------------------------------------------------- /src/utils/either_io.rs: -------------------------------------------------------------------------------- 1 | use std::pin::Pin; 2 | 3 | use tokio::io::{AsyncRead, AsyncWrite}; 4 | 5 | use super::{UdpRead, UdpWrite}; 6 | 7 | pub enum EitherIO { 8 | Left(IO1), 9 | Right(IO2), 10 | } 11 | 12 | impl AsyncWrite for EitherIO 13 | where 14 | IO1: AsyncWrite + Unpin, 15 | IO2: AsyncWrite + Unpin, 16 | { 17 | fn poll_write( 18 | mut self: Pin<&mut Self>, 19 | cx: &mut std::task::Context<'_>, 20 | buf: &[u8], 21 | ) -> std::task::Poll> { 22 | match *self { 23 | EitherIO::Left(ref mut io) => Pin::new(io).poll_write(cx, buf), 24 | EitherIO::Right(ref mut io) => Pin::new(io).poll_write(cx, buf), 25 | } 26 | } 27 | 28 | fn poll_flush( 29 | mut self: Pin<&mut Self>, 30 | cx: &mut std::task::Context<'_>, 31 | ) -> std::task::Poll> { 32 | match *self { 33 | EitherIO::Left(ref mut io) => Pin::new(io).poll_flush(cx), 34 | EitherIO::Right(ref mut io) => Pin::new(io).poll_flush(cx), 35 | } 36 | } 37 | 38 | fn poll_shutdown( 39 | mut self: Pin<&mut Self>, 40 | cx: &mut std::task::Context<'_>, 41 | ) -> std::task::Poll> { 42 | match *self { 43 | EitherIO::Left(ref mut io) => Pin::new(io).poll_shutdown(cx), 44 | EitherIO::Right(ref mut io) => Pin::new(io).poll_shutdown(cx), 45 | } 46 | } 47 | } 48 | 49 | impl AsyncRead for EitherIO 50 | where 51 | IO1: AsyncRead + Unpin, 52 | IO2: AsyncRead + Unpin, 53 | { 54 | fn poll_read( 55 | mut self: Pin<&mut Self>, 56 | cx: &mut 
std::task::Context<'_>, 57 | buf: &mut tokio::io::ReadBuf<'_>, 58 | ) -> std::task::Poll> { 59 | match *self { 60 | EitherIO::Left(ref mut io) => Pin::new(io).poll_read(cx, buf), 61 | EitherIO::Right(ref mut io) => Pin::new(io).poll_read(cx, buf), 62 | } 63 | } 64 | } 65 | 66 | impl UdpRead for EitherIO 67 | where 68 | IO1: UdpRead + Unpin, 69 | IO2: UdpRead + Unpin, 70 | { 71 | fn poll_proxy_stream_read( 72 | mut self: Pin<&mut Self>, 73 | cx: &mut std::task::Context<'_>, 74 | buf: &mut super::UdpRelayBuffer, 75 | ) -> std::task::Poll> { 76 | use EitherIO::*; 77 | match *self { 78 | Left(ref mut io) => Pin::new(io).poll_proxy_stream_read(cx, buf), 79 | Right(ref mut io) => Pin::new(io).poll_proxy_stream_read(cx, buf), 80 | } 81 | } 82 | } 83 | 84 | impl UdpWrite for EitherIO 85 | where 86 | IO1: UdpWrite + Unpin, 87 | IO2: UdpWrite + Unpin, 88 | { 89 | fn poll_proxy_stream_write( 90 | mut self: Pin<&mut Self>, 91 | cx: &mut std::task::Context<'_>, 92 | buf: &[u8], 93 | addr: &super::MixAddrType, 94 | ) -> std::task::Poll> { 95 | use EitherIO::*; 96 | match *self { 97 | Left(ref mut io) => Pin::new(io).poll_proxy_stream_write(cx, buf, addr), 98 | Right(ref mut io) => Pin::new(io).poll_proxy_stream_write(cx, buf, addr), 99 | } 100 | } 101 | 102 | fn poll_flush( 103 | mut self: Pin<&mut Self>, 104 | cx: &mut std::task::Context<'_>, 105 | ) -> std::task::Poll> { 106 | use EitherIO::*; 107 | match *self { 108 | Left(ref mut io) => Pin::new(io).poll_flush(cx), 109 | Right(ref mut io) => Pin::new(io).poll_flush(cx), 110 | } 111 | } 112 | 113 | fn poll_shutdown( 114 | mut self: Pin<&mut Self>, 115 | cx: &mut std::task::Context<'_>, 116 | ) -> std::task::Poll> { 117 | use EitherIO::*; 118 | match *self { 119 | Left(ref mut io) => Pin::new(io).poll_shutdown(cx), 120 | Right(ref mut io) => Pin::new(io).poll_shutdown(cx), 121 | } 122 | } 123 | } 124 | -------------------------------------------------------------------------------- /src/utils/forked_copy/copy_bidirectional.rs: 
-------------------------------------------------------------------------------- 1 | use super::CopyBuffer; 2 | 3 | use tokio::io::{AsyncRead, AsyncWrite}; 4 | 5 | use crate::utils::StreamStopReasons; 6 | use futures::ready; 7 | use std::future::Future; 8 | use std::io; 9 | use std::pin::Pin; 10 | use std::task::{Context, Poll}; 11 | 12 | enum TransferState { 13 | Running(CopyBuffer), 14 | ShuttingDown(u64), 15 | Done(u64), 16 | } 17 | 18 | struct CopyBidirectional<'a, I, O> { 19 | i: &'a mut I, 20 | o: &'a mut O, 21 | upload: TransferState, 22 | download: TransferState, 23 | stop_reason: StreamStopReasons, 24 | } 25 | 26 | fn transfer_one_direction( 27 | cx: &mut Context<'_>, 28 | state: &mut TransferState, 29 | r: &mut A, 30 | w: &mut B, 31 | ) -> Poll> 32 | where 33 | A: AsyncRead + AsyncWrite + Unpin, 34 | B: AsyncRead + AsyncWrite + Unpin, 35 | { 36 | let mut r = Pin::new(r); 37 | let mut w = Pin::new(w); 38 | 39 | loop { 40 | match state { 41 | TransferState::Running(buf) => { 42 | let count = ready!(buf.poll_copy(cx, r.as_mut(), w.as_mut()))?; 43 | *state = TransferState::ShuttingDown(count); 44 | } 45 | TransferState::ShuttingDown(count) => { 46 | ready!(w.as_mut().poll_shutdown(cx))?; 47 | 48 | *state = TransferState::Done(*count); 49 | } 50 | TransferState::Done(count) => return Poll::Ready(Ok(*count)), 51 | } 52 | } 53 | } 54 | 55 | impl<'a, I, O> Future for CopyBidirectional<'a, I, O> 56 | where 57 | I: AsyncRead + AsyncWrite + Unpin, 58 | O: AsyncRead + AsyncWrite + Unpin, 59 | { 60 | type Output = Result; 61 | 62 | fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { 63 | // Unpack self into mut refs to each field to avoid borrow check issues. 
64 | use StreamStopReasons::*; 65 | let CopyBidirectional { 66 | i, 67 | o, 68 | upload, 69 | download, 70 | stop_reason, 71 | } = &mut *self; 72 | 73 | let upload = 74 | transfer_one_direction(cx, upload, &mut *i, &mut *o).map_err(|e| (Upload, e))?; 75 | let download = 76 | transfer_one_direction(cx, download, &mut *o, &mut *i).map_err(|e| (Download, e))?; 77 | 78 | // It is not a problem if ready! returns early because transfer_one_direction for the 79 | // other direction will keep returning TransferState::Done(count) in future calls to poll 80 | use Poll::*; 81 | match (upload, download) { 82 | (Pending, Pending) => Pending, 83 | (Ready(_), Pending) => { 84 | *stop_reason = Upload; 85 | Pending 86 | } 87 | (Pending, Ready(_)) => { 88 | *stop_reason = Download; 89 | Pending 90 | } 91 | (Ready(_), Ready(_)) => Ready(Ok(stop_reason.clone())), 92 | } 93 | } 94 | } 95 | 96 | pub async fn copy_bidirectional_forked( 97 | inbound: &mut I, 98 | outbound: &mut O, 99 | ) -> Result 100 | where 101 | I: AsyncRead + AsyncWrite + Unpin, 102 | O: AsyncRead + AsyncWrite + Unpin, 103 | { 104 | CopyBidirectional { 105 | i: inbound, 106 | o: outbound, 107 | upload: TransferState::Running(CopyBuffer::new()), 108 | download: TransferState::Running(CopyBuffer::new()), 109 | stop_reason: StreamStopReasons::Download, 110 | } 111 | .await 112 | } 113 | -------------------------------------------------------------------------------- /src/utils/forked_copy/copy_buf.rs: -------------------------------------------------------------------------------- 1 | // use tokio::io::{copy_bidirectional}; 2 | 3 | use crate::protocol::RELAY_BUFFER_SIZE; 4 | use futures::{ready, Future}; 5 | use std::io; 6 | use std::pin::Pin; 7 | use std::task::{Context, Poll}; 8 | use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; 9 | 10 | #[cfg_attr(feature = "debug_info", derive(Debug))] 11 | pub(super) struct CopyBuffer { 12 | read_done: bool, 13 | need_flush: bool, 14 | pos: usize, 15 | cap: usize, 16 | amt: u64, 17 
| buf: Box<[u8]>, 18 | } 19 | 20 | impl CopyBuffer { 21 | pub(super) fn new() -> Self { 22 | Self { 23 | read_done: false, 24 | need_flush: false, 25 | pos: 0, 26 | cap: 0, 27 | amt: 0, 28 | buf: vec![0; RELAY_BUFFER_SIZE].into_boxed_slice(), 29 | } 30 | } 31 | 32 | pub(super) fn poll_copy( // Copies reader -> writer until the reader reports EOF, then flushes and returns the total byte count. 33 | &mut self, 34 | cx: &mut Context<'_>, 35 | mut reader: Pin<&mut R>, 36 | mut writer: Pin<&mut W>, 37 | ) -> Poll> 38 | where 39 | R: AsyncRead, 40 | W: AsyncWrite, 41 | { 42 | loop { 43 | // If our buffer is empty, then we need to read some data to 44 | // continue. 45 | if self.pos == self.cap && !self.read_done { 46 | let me = &mut *self; 47 | let mut buf = ReadBuf::new(&mut me.buf); 48 | 49 | match reader.as_mut().poll_read(cx, &mut buf) { 50 | Poll::Ready(Ok(_)) => (), 51 | Poll::Ready(Err(err)) => return Poll::Ready(Err(err)), 52 | Poll::Pending => { 53 | // Try flushing when the reader has no progress to avoid deadlock 54 | // when the reader depends on buffered writer. 55 | if self.need_flush { 56 | ready!(writer.as_mut().poll_flush(cx))?; 57 | self.need_flush = false; 58 | } 59 | 60 | return Poll::Pending; 61 | } 62 | } 63 | 64 | let n = buf.filled().len(); 65 | if n == 0 { 66 | self.read_done = true; 67 | } else { 68 | self.pos = 0; 69 | self.cap = n; 70 | } 71 | } 72 | 73 | // If our buffer has some data, let's write it out! 74 | while self.pos < self.cap { 75 | let me = &mut *self; 76 | let i = ready!(writer.as_mut().poll_write(cx, &me.buf[me.pos..me.cap]))?; 77 | if i == 0 { 78 | return Poll::Ready(Err(io::Error::new( 79 | io::ErrorKind::WriteZero, 80 | "write zero byte into writer", 81 | ))); 82 | } else { 83 | self.pos += i; 84 | self.amt += i as u64; 85 | self.need_flush = true; 86 | } 87 | } 88 | 89 | // If we've written all the data and we've seen EOF, flush out the 90 | // data and finish the transfer.
91 | if self.pos == self.cap && self.read_done { 92 | ready!(writer.as_mut().poll_flush(cx))?; 93 | return Poll::Ready(Ok(self.amt)); 94 | } 95 | } 96 | } 97 | } 98 | 99 | /// A future that asynchronously copies the entire contents of a reader into a 100 | /// writer. 101 | #[cfg_attr(feature = "debug_info", derive(Debug))] 102 | #[must_use = "futures do nothing unless you `.await` or poll them"] 103 | struct Copy<'a, R, W> { 104 | reader: &'a mut R, 105 | writer: &'a mut W, 106 | buf: CopyBuffer, 107 | } 108 | 109 | /// Asynchronously copies the entire contents of a reader into a writer. 110 | /// 111 | /// This function returns a future that will continuously read data from 112 | /// `reader` and then write it into `writer` in a streaming fashion until 113 | /// `reader` returns EOF. 114 | /// 115 | /// On success, the total number of bytes that were copied from `reader` to 116 | /// `writer` is returned. 117 | /// 118 | /// This is an asynchronous version of [`std::io::copy`][std]. 119 | /// 120 | /// [std]: std::io::copy 121 | /// 122 | /// # Errors 123 | /// 124 | /// The returned future will return an error immediately if any call to 125 | /// `poll_read` or `poll_write` returns an error.
126 | /// 127 | /// # Examples 128 | /// 129 | /// ``` 130 | /// use tokio::io; 131 | /// 132 | /// # async fn dox() -> std::io::Result<()> { 133 | /// let mut reader: &[u8] = b"hello"; 134 | /// let mut writer: Vec = vec![]; 135 | /// 136 | /// io::copy(&mut reader, &mut writer).await?; 137 | /// 138 | /// assert_eq!(&b"hello"[..], &writer[..]); 139 | /// # Ok(()) 140 | /// # } 141 | /// ``` 142 | pub async fn copy_forked<'a, R, W>(reader: &'a mut R, writer: &'a mut W) -> io::Result 143 | where 144 | R: AsyncRead + Unpin, 145 | W: AsyncWrite + Unpin, 146 | { 147 | Copy { 148 | reader, 149 | writer, 150 | buf: CopyBuffer::new(), 151 | } 152 | .await 153 | } 154 | 155 | impl Future for Copy<'_, R, W> 156 | where 157 | R: AsyncRead + Unpin, 158 | W: AsyncWrite + Unpin, 159 | { 160 | type Output = io::Result; 161 | 162 | fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { 163 | let me = &mut *self; 164 | 165 | me.buf 166 | .poll_copy(cx, Pin::new(&mut *me.reader), Pin::new(&mut *me.writer)) 167 | } 168 | } 169 | -------------------------------------------------------------------------------- /src/utils/forked_copy/mod.rs: -------------------------------------------------------------------------------- 1 | /// this is a forked version of tokio::io::copy 2 | mod copy_buf; 3 | pub use copy_buf::copy_forked; 4 | use copy_buf::CopyBuffer; 5 | mod copy_bidirectional; 6 | pub use copy_bidirectional::copy_bidirectional_forked; -------------------------------------------------------------------------------- /src/utils/glommio_utils/copy_bidirectional.rs: -------------------------------------------------------------------------------- 1 | /// this is a replica of tokio::io::copy_bidirectional 2 | /// but based on `futures` traits 3 | use super::copy_buf::CopyBuffer; 4 | 5 | use futures::{ready, AsyncRead, AsyncWrite}; 6 | 7 | use std::future::Future; 8 | use std::io; 9 | use std::pin::Pin; 10 | use std::task::{Context, Poll}; 11 | 12 | use
crate::utils::StreamStopReasons; 13 | 14 | use tracing::debug; 15 | 16 | enum TransferState { 17 | Running(CopyBuffer), 18 | ShuttingDown(u64), 19 | Done(u64), 20 | } 21 | 22 | struct CopyBidirectional<'a, A, B> { 23 | a: &'a mut A, 24 | b: &'a mut B, 25 | a_to_b: TransferState, 26 | b_to_a: TransferState, 27 | stop_reason: StreamStopReasons, 28 | } 29 | 30 | fn transfer_one_direction( // Drives one direction: copy until EOF, then close the writer, then keep reporting the byte count. 31 | cx: &mut Context<'_>, 32 | state: &mut TransferState, 33 | r: &mut A, 34 | w: &mut B, 35 | ) -> Poll> 36 | where 37 | A: AsyncRead + AsyncWrite + Unpin, 38 | B: AsyncRead + AsyncWrite + Unpin, 39 | { 40 | let mut r = Pin::new(r); 41 | let mut w = Pin::new(w); 42 | 43 | loop { 44 | match state { 45 | TransferState::Running(buf) => { 46 | debug!("transfer_one_direction: running"); 47 | let count = ready!(buf.poll_copy(cx, r.as_mut(), w.as_mut()))?; 48 | *state = TransferState::ShuttingDown(count); 49 | } 50 | TransferState::ShuttingDown(count) => { 51 | debug!("transfer_one_direction: ShuttingDown"); 52 | ready!(w.as_mut().poll_close(cx))?; 53 | 54 | *state = TransferState::Done(*count); 55 | } 56 | TransferState::Done(count) => return Poll::Ready(Ok(*count)), 57 | } 58 | } 59 | } 60 | 61 | impl<'a, A, B> Future for CopyBidirectional<'a, A, B> 62 | where 63 | A: AsyncRead + AsyncWrite + Unpin, 64 | B: AsyncRead + AsyncWrite + Unpin, 65 | { 66 | type Output = io::Result<(u64, u64, StreamStopReasons)>; 67 | 68 | fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { 69 | // Unpack self into mut refs to each field to avoid borrow check issues. 70 | let CopyBidirectional { 71 | a, 72 | b, 73 | a_to_b, 74 | b_to_a, 75 | stop_reason, 76 | } = &mut *self; 77 | 78 | let a_to_b = transfer_one_direction(cx, a_to_b, &mut *a, &mut *b)?; 79 | let b_to_a = transfer_one_direction(cx, b_to_a, &mut *b, &mut *a)?; 80 | 81 | // It is not a problem if ready!
returns early because transfer_one_direction for the 82 | // other direction will keep returning TransferState::Done(count) in future calls to poll 83 | use Poll::*; 84 | use StreamStopReasons::*; 85 | match (a_to_b, b_to_a) { 86 | (Pending, Pending) => Pending, 87 | (Ready(_a_to_b), Pending) => { 88 | *stop_reason = Upload; 89 | Pending 90 | } 91 | (Pending, Ready(_b_to_a)) => { 92 | *stop_reason = Download; 93 | Pending 94 | } 95 | (Ready(a_to_b), Ready(b_to_a)) => { 96 | Ready(Ok((a_to_b, b_to_a, stop_reason.clone()))) 97 | } 98 | } 99 | } 100 | } 101 | 102 | pub async fn glommio_copy_bidirectional( // Relays a<->b concurrently; resolves with both byte counts and the direction that finished first (Download if both finish in the same poll). 103 | a: &mut A, 104 | b: &mut B, 105 | ) -> Result<(u64, u64, StreamStopReasons), std::io::Error> 106 | where 107 | A: AsyncRead + AsyncWrite + Unpin, 108 | B: AsyncRead + AsyncWrite + Unpin, 109 | { 110 | CopyBidirectional { 111 | a, 112 | b, 113 | a_to_b: TransferState::Running(CopyBuffer::new()), 114 | b_to_a: TransferState::Running(CopyBuffer::new()), 115 | stop_reason: StreamStopReasons::Download, 116 | } 117 | .await 118 | } 119 | -------------------------------------------------------------------------------- /src/utils/glommio_utils/copy_buf.rs: -------------------------------------------------------------------------------- 1 | use futures::{ready, AsyncRead, AsyncWrite}; 2 | use std::io; 3 | use std::pin::Pin; 4 | use std::task::{Context, Poll}; 5 | use tracing::debug; 6 | use crate::protocol::RELAY_BUFFER_SIZE; 7 | 8 | pub(super) struct CopyBuffer { 9 | read_done: bool, 10 | pos: usize, 11 | cap: usize, 12 | amt: u64, 13 | buf: Box<[u8]>, 14 | } 15 | 16 | impl CopyBuffer { 17 | pub(super) fn new() -> Self { 18 | Self { 19 | read_done: false, 20 | pos: 0, 21 | cap: 0, 22 | amt: 0, 23 | buf: vec![0; RELAY_BUFFER_SIZE].into_boxed_slice(), 24 | } 25 | } 26 | 27 | pub(super) fn poll_copy( // NOTE(review): unlike the tokio fork in forked_copy, this never flushes while the reader is pending; confirm the writer needs no flush to make progress. 28 | &mut self, 29 | cx: &mut Context<'_>, 30 | mut reader: Pin<&mut R>, 31 | mut writer: Pin<&mut W>, 32 | ) -> Poll> 33 | where 34 | R: AsyncRead + ?Sized, 35 | W:
AsyncWrite + ?Sized, 36 | { 37 | loop { 38 | // If our buffer is empty, then we need to read some data to 39 | // continue. 40 | if self.pos == self.cap && !self.read_done { 41 | debug!("poll_copy: poll_read"); 42 | let n = ready!(reader.as_mut().poll_read(cx, &mut *self.buf))?; 43 | debug!("poll_copy: poll_read {}", n); 44 | if n == 0 { 45 | self.read_done = true; 46 | } else { 47 | self.pos = 0; 48 | self.cap = n; 49 | } 50 | } 51 | 52 | // If our buffer has some data, let's write it out! 53 | while self.pos < self.cap { 54 | debug!("poll_copy: poll_write"); 55 | let i = ready!(writer 56 | .as_mut() 57 | .poll_write(cx, &self.buf[self.pos..self.cap]))?; 58 | debug!("poll_copy: poll_write {}", i); 59 | if i == 0 { 60 | return Poll::Ready(Err(io::Error::new( 61 | io::ErrorKind::WriteZero, 62 | "write zero byte into writer", 63 | ))); 64 | } else { 65 | self.pos += i; 66 | self.amt += i as u64; 67 | } 68 | } 69 | 70 | // If we've written all the data and we've seen EOF, flush out the 71 | // data and finish the transfer.
72 | if self.pos == self.cap && self.read_done { 73 | ready!(writer.as_mut().poll_flush(cx))?; 74 | return Poll::Ready(Ok(self.amt)); 75 | } 76 | } 77 | } 78 | } 79 | -------------------------------------------------------------------------------- /src/utils/glommio_utils/mod.rs: -------------------------------------------------------------------------------- 1 | mod copy_bidirectional; 2 | mod copy_buf; 3 | mod start_tcp_relay_thread; 4 | 5 | use std::net::TcpStream; 6 | 7 | pub use start_tcp_relay_thread::start_tcp_relay_threads; 8 | pub type TcpTx = tokio::sync::mpsc::Sender<(TcpStream, TcpStream, usize)>; 9 | pub type TcpRx = tokio::sync::mpsc::Receiver<(TcpStream, TcpStream, usize)>; 10 | -------------------------------------------------------------------------------- /src/utils/glommio_utils/start_tcp_relay_thread.rs: -------------------------------------------------------------------------------- 1 | use std::os::unix::io::FromRawFd; 2 | use std::os::unix::prelude::IntoRawFd; 3 | use std::time::Duration; 4 | 5 | use glommio::net::TcpStream; 6 | use glommio::{Local, LocalExecutorBuilder}; 7 | use tokio::sync::mpsc; 8 | 9 | use super::copy_bidirectional::glommio_copy_bidirectional; 10 | use super::{TcpRx, TcpTx}; 11 | use crate::protocol::TCP_MAX_IDLE_TIMEOUT; 12 | use glommio::Result; 13 | use tracing::*; 14 | 15 | async fn tcp_relay_task(mut inbound: TcpStream, mut outbound: TcpStream, conn_id: usize) { 16 | #[cfg(feature = "debug_info")] 17 | debug!("tcp_relay_task: entered"); 18 | match glommio_copy_bidirectional(&mut inbound, &mut outbound).await { 19 | Ok((_, _, reason)) => { 20 | info!("[lite+][{}] end by {}", conn_id, reason); 21 | } 22 | Err(e) => { 23 | info!("[lite+][{}] end by {:?}", conn_id, e); 24 | } 25 | } 26 | } 27 | 28 | fn init_tcp_stream(std_tcp_stream: std::net::TcpStream) -> Result { // Moves a std socket into a glommio TcpStream and applies nodelay plus the idle read/write timeouts. 29 | let stream: TcpStream = unsafe { 30 | // safety: both steps are infallible, therefore the socket 31 | // will always be under control.
32 | let raw_fd = std_tcp_stream.into_raw_fd(); 33 | TcpStream::from_raw_fd(raw_fd) 34 | }; 35 | stream.set_nodelay(true)?; 36 | stream.set_read_timeout(Some(Duration::from_secs(TCP_MAX_IDLE_TIMEOUT as u64)))?; 37 | stream.set_write_timeout(Some(Duration::from_secs(TCP_MAX_IDLE_TIMEOUT as u64)))?; 38 | Ok(stream) 39 | } 40 | 41 | async fn worker(mut tcp_rx: TcpRx) { // Per-core loop: turns each received (inbound, outbound, conn_id) socket pair into a detached relay task. 42 | while let Some((inbound_fd, outbound_fd, conn_id)) = tcp_rx.recv().await { 43 | let inbound = match init_tcp_stream(inbound_fd) { 44 | Ok(s) => s, 45 | Err(e) => { 46 | error!("glommio initing tcp inbound failed, {:?}", e); 47 | continue; // skip only this pair: `return` would end the loop and drop `tcp_rx`, losing every later pair routed to this core 48 | } 49 | }; 50 | let outbound = match init_tcp_stream(outbound_fd) { 51 | Ok(s) => s, 52 | Err(e) => { 53 | error!("glommio initing tcp outbound failed, {:?}", e); 54 | continue; // same as above; the already-initialised `inbound` is dropped here 55 | } 56 | }; 57 | Local::local(tcp_relay_task(inbound, outbound, conn_id)).detach(); 58 | } 59 | } 60 | 61 | pub fn start_tcp_relay_threads() -> Vec { // Spawns one glommio executor thread per CPU (pinned to that CPU) and returns the sender feeding each worker. 62 | let numc = num_cpus::get(); 63 | let mut tcp_submit = Vec::with_capacity(numc); 64 | for i in 0..numc { 65 | info!("starting glommio runtime: {}", i); 66 | let (tcp_tx, tcp_rx) = mpsc::channel(100); 67 | tcp_submit.push(tcp_tx); 68 | std::thread::spawn(move || { 69 | let ex = LocalExecutorBuilder::new().pin_to_cpu(i).make().unwrap(); 70 | ex.run(worker(tcp_rx)); 71 | }); 72 | } 73 | tcp_submit 74 | } 75 | 76 | #[tokio::test] 77 | async fn test_glommio() { 78 | use crate::utils::Adapter; 79 | use crate::VEC_TCP_TX; 80 | use tokio::io::AsyncReadExt; 81 | use tokio::io::AsyncWriteExt; 82 | use tokio::time::sleep; 83 | 84 | let collector = tracing_subscriber::fmt() 85 | .with_max_level(tracing::Level::INFO) 86 | .finish(); 87 | let _ = tracing::subscriber::set_global_default(collector); 88 | 89 | let tcp_submit = start_tcp_relay_threads(); 90 | let _ = VEC_TCP_TX.set(tcp_submit); 91 | 92 | // spawn a server in tokio 93 | tokio::spawn(async move { 94 | let server_listener = tokio::net::TcpListener::bind("0.0.0.0:5555").await.unwrap(); 95 |
info!("server started"); 96 | loop { 97 | let mut stream = match server_listener.accept().await { 98 | Ok((s, _)) => s, 99 | Err(e) => { 100 | error!("server[1] {:?}", e); 101 | continue; 102 | } 103 | }; 104 | // info!("server incoming"); 105 | 106 | tokio::spawn(async move { 107 | let mut buf = [0u8; 2048]; 108 | loop { 109 | let n = match stream.read(&mut buf).await { 110 | Ok(0) => return, Ok(n) => n, // Ok(0) means the peer closed: stop instead of spinning on empty reads 111 | Err(e) => { 112 | error!("server[2] {:?}", e); 113 | return; 114 | } 115 | }; 116 | 117 | let _ = stream.write(&buf[..n]).await; 118 | } 119 | }); 120 | } 121 | }); 122 | 123 | // start the proxy 124 | tokio::spawn(async move { 125 | let proxy_listener = tokio::net::TcpListener::bind("0.0.0.0:6666").await.unwrap(); 126 | info!("proxy started"); 127 | 128 | let mut conn_id = 0; 129 | loop { 130 | let mut inbound = match proxy_listener.accept().await { 131 | Ok((s, _)) => s, 132 | Err(e) => { 133 | error!("proxy[1] {:?}", e); 134 | continue; 135 | } 136 | }; 137 | // info!("proxy incoming"); 138 | 139 | conn_id += 1; 140 | 141 | tokio::spawn(async move { 142 | let mut outbound = match tokio::net::TcpStream::connect("127.0.0.1:5555").await { 143 | Ok(s) => s, 144 | Err(e) => { 145 | error!("proxy[2] {:?}", e); 146 | return; 147 | } 148 | }; 149 | let mut buf = [0; 4096]; 150 | let a = match inbound.read(&mut buf).await { 151 | Ok(x) => x, 152 | Err(x) => { 153 | error!("proxy[4] {:?}", x); 154 | return; 155 | } 156 | }; 157 | let _ = match outbound.write(&buf[..a]).await { 158 | Ok(x) => x, 159 | Err(x) => { 160 | error!("proxy[5] {:?}", x); 161 | return; 162 | } 163 | }; 164 | 165 | let a = match outbound.read(&mut buf).await { 166 | Ok(x) => x, 167 | Err(x) => { 168 | error!("proxy[6] {:?}", x); 169 | return; 170 | } 171 | }; 172 | let _ = match inbound.write(&buf[..a]).await { 173 | Ok(x) => x, 174 | Err(x) => { 175 | error!("proxy[7] {:?}", x); 176 | return; 177 | } 178 | }; 179 | 180 | info!("relay start"); 181 | match Adapter::relay_tcp_zio(inbound, outbound,
conn_id).await { 182 | Ok(_) => (), 183 | Err(x) => error!("proxy[3] {:?}", x), 184 | } 185 | info!("relay end"); 186 | }); 187 | } 188 | }); 189 | 190 | // spawn 50 clients in tokio 191 | let mut client_handles = Vec::new(); 192 | for _ in 0..50 { 193 | client_handles.push(tokio::spawn(async move { 194 | let mut client_sender = match tokio::net::TcpStream::connect("127.0.0.1:6666").await { 195 | Ok(s) => s, 196 | Err(e) => { 197 | error!("client[1] {:?}", e); 198 | return; 199 | } 200 | }; 201 | let data = [1u8; 2048]; 202 | let mut buf = [0u8; 2048]; 203 | for _ in 0..10 { 204 | let _ = client_sender.write(&data).await; 205 | let _ = client_sender.read(&mut buf).await; 206 | sleep(Duration::from_secs(1)).await; 207 | } 208 | })); 209 | } 210 | 211 | let mut i = 0; 212 | for handles in client_handles { 213 | info!("client {}", i); 214 | i += 1; 215 | let _ = handles.await; 216 | } 217 | sleep(Duration::from_secs(20)).await; 218 | } 219 | -------------------------------------------------------------------------------- /src/utils/lite_tls/The Journey to The Speed of Lite.md: -------------------------------------------------------------------------------- 1 | # The Journey to The Speed of Lite 2 | 3 | ## Background 4 | 5 | 在当下主流的代理工具中,trojan系协议在复杂度、速度、稳定性和隐蔽性方面脱颖而出,该协议的基本流程为 6 | 7 | ``` 8 | 1. user向代理客户端client发送请求 9 | [user] ---> [client] 10 | ^ 包含target address 11 | 12 | 2. client通过一个tls隧道向代理服务器发送password + target address 13 | [user] **** [client] ---> [server] 14 | ^ tls secured 15 | 16 | 3. 代理server认证client,然后建立与目标服务器的链接 17 | [user] **** [client] **** [server] ---> [target] 18 | 19 | 4.
最终,user通过这条隧道与target交换数据 20 | [user] <--> [client] <--> [server] <--> [target] 21 | ^ tls secured 22 | ``` 23 | 24 | ## Double encryption 25 | 26 | 仔细观察上述流程,client和server之间已经建立了一个tls隧道,而当下,user和target之间大部分的通讯也是基于tls加密的,这就使得在传输数据的时候,client和server之间的这层加密变得多余。 27 | 28 | 对此,我们探索了Lite-Tls机制,其核心思路非常简单:当user和target使用tls通讯并开始传输数据的时候,client和server一同退出tls隧道 —— 反正在监听者看来,这都是tls传输,没有任何区别。 29 | 30 | ## The pain starts 31 | 32 | ### What can we transfer without encryption? 33 | 34 | 我们首先需要解决的问题是:哪些数据可以无需加密、直接转发?直觉告诉我们,tls握手阶段应该是不能直接转发的,那么具体该如何区分呢? 35 | 36 | 我们需要简单了解一下tls(1.2/1.3)的包定义: 37 | ``` 38 | +-------------+-------------+--------+----------+ 39 | | Record Type | version | Length | Payload | 40 | +-------------+------+------+-------------------+ 41 | | 1 | 0x03 | 0x03 | 2 | Variable | 42 | +-------------+------+------+--------+----------+ 43 | ``` 44 | 45 | 根据标准,包头中与本文相关的Record Type有 46 | * 0x14: Change Cipher Spec 47 | * 0x16: Handshake 48 | * 0x17: Application Data 49 | 50 | 其中,0x14和0x16会在握手过程中被使用,而0x17则是数据传输使用的类型,也就是可以被直接转发的包类型。 51 | 52 | ### A First Try 53 | 54 | TLS标准规定,0x14 - Change Cipher Spec的含义是,该包之后的包全部使用协商好的加密方法进行加密。 55 | 56 | * 终端定义 57 | * `user` - 用户 58 | * `client` - 代理客户端 59 | * `server` - 代理服务端 60 | * `target` - 目标网站 61 | * `一手包`和`二手包`:从`user/target`那里直接获得的包是`一手包`,从`server/client`那里获得的包是`二手包`。例如,对于`client`而言,从`user`发来的包是`一手包`,从`server`发来的包是`二手包`。 62 | 63 | ``` 64 | ---- 0x17 ---> [client] [server] 65 | 66 | [client] ==== 0x17 ====> [server] 67 | ^ the first 0x17 in this stream 68 | 69 | <== ...some traffic...
==> 70 | 71 | [client] [server] <-{..., 0x17}-- 72 | ^ active side *1 73 | 74 | [client] <={..., 0xff}== [server]{0x17} < cached *3 75 | passive side *2 ^ ^ a 0xff is appended 76 | 77 | [client]{...} [server]{0x17} 78 | ^ cached *4 79 | 80 | [client]{...} == 0xff => [server]{0x17} 81 | ^ a 0xff is returned *5 82 | 83 | [client]{...} [server]{0x17} 84 | ^ quit tls ^ quit tls *6 85 | 86 | [client] <- Plain Tcp -> [server] 87 | 88 | [client]{...} <- 0x17 -- [server] 89 | 90 | <-{..., 0x17}-- [client] [server] 91 | ...... 92 | ``` 93 | 注: 94 | 1. active side: 第二个收到`一手0x17`的endpoint进入active mode。 95 | 2. passive side: 收到`0xff`的endpoint进入passive mode。 96 | 3. active side会把收到的0x17先缓存起来。在0x17之前,往往还有尚未发送的0x16/0x14包,我们把`0xff`包附在这些pending包的尾部,将这些包一起发往passive side,表示之后随时可以退出tls隧道。 97 | 4. passive side收到第3步发来的包之后,会验证`0xff`(之后丢弃),并把它前面的包缓存起来,等0x17到达后再一同发给`user`;如果提前单独转发,`user`(浏览器)会因为收到的包不完整而出现错误。 98 | 5. passive side验证完`0xff`后,会返回一个`0xff`,表示自己已经不会再通过tls隧道接收数据,之后便退出`tls`隧道。 99 | 6. 当active side收到返回的`0xff`后,便也退出tls隧道。 100 | 7.
之后active side和passive side之间便通过tcp直接通信,active side把之前缓存的`0x17`发给passive side,passive side收到`0x17`后,连同之前缓存的包一起一次性发给`user`。之后整个过程结束。 -------------------------------------------------------------------------------- /src/utils/lite_tls/error.rs: -------------------------------------------------------------------------------- 1 | use std::io::ErrorKind; 2 | use anyhow::Error; 3 | 4 | pub fn eof_err(msg: &str) -> Error { // Builds an anyhow::Error wrapping an io::Error of kind UnexpectedEof that carries `msg`. 5 | Error::new(std::io::Error::new(ErrorKind::UnexpectedEof, msg)) 6 | } -------------------------------------------------------------------------------- /src/utils/lite_tls/leave_tls.rs: -------------------------------------------------------------------------------- 1 | #[cfg(feature = "quic")] 2 | use quinn::*; 3 | use tokio::net::TcpStream; 4 | use tokio_rustls::{client, server}; 5 | 6 | #[cfg(feature = "quic")] 7 | use crate::utils::{TrojanUdpStream, WRTuple}; 8 | 9 | pub trait LeaveTls { // Consumes a stream wrapper and returns the plain TcpStream beneath it. 10 | fn leave(self) -> TcpStream; 11 | } 12 | 13 | impl LeaveTls for server::TlsStream { 14 | fn leave(self) -> TcpStream { 15 | self.into_inner().0 // first element of the pair returned by `into_inner`: the wrapped TcpStream 16 | } 17 | } 18 | 19 | impl LeaveTls for client::TlsStream { 20 | fn leave(self) -> TcpStream { 21 | self.into_inner().0 // first element of the pair returned by `into_inner`: the wrapped TcpStream 22 | } 23 | } 24 | 25 | #[cfg(feature = "quic")] 26 | impl LeaveTls for TrojanUdpStream { 27 | fn leave(self) -> TcpStream { 28 | unimplemented!() // NOTE(review): panics if reached; leaving TLS is not supported for this stream type 29 | } 30 | } 31 | 32 | #[cfg(feature = "quic")] 33 | impl LeaveTls for WRTuple { 34 | fn leave(self) -> TcpStream { 35 | unimplemented!() // NOTE(review): panics if reached; leaving TLS is not supported for this stream type 36 | } 37 | } 38 | -------------------------------------------------------------------------------- /src/utils/lite_tls/mod.rs: -------------------------------------------------------------------------------- 1 | mod error; 2 | mod leave_tls; 3 | mod lite_tls_stream; 4 | mod tls_relay_buffer; 5 | 6 | pub use leave_tls::LeaveTls; 7 | pub use lite_tls_stream::LiteTlsStream; 8 | pub use tls_relay_buffer::LeaveTlsMode; 9 | -------------------------------------------------------------------------------- 
/src/utils/lite_tls/tls_relay_buffer.rs: -------------------------------------------------------------------------------- 1 | use super::{error::eof_err, lite_tls_stream::Direction}; 2 | use crate::{expect_buf_len, utils::ParserError}; 3 | use anyhow::{Error, Result}; 4 | use std::{ 5 | cmp::min, 6 | fmt::Display, 7 | ops::{Deref, DerefMut}, 8 | }; 9 | use tokio::io::{AsyncReadExt, AsyncWriteExt}; 10 | 11 | #[cfg(feature = "debug_info")] 12 | use tracing::debug; 13 | 14 | #[derive(Debug, Clone, Copy)] 15 | pub enum LeaveTlsMode { 16 | Active, 17 | Passive, 18 | } 19 | 20 | impl Display for LeaveTlsMode { 21 | fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 22 | match self { 23 | &LeaveTlsMode::Active => write!(f, "a"), 24 | &LeaveTlsMode::Passive => write!(f, "p"), 25 | } 26 | } 27 | } 28 | 29 | #[cfg_attr(feature = "debug_info", derive(Debug))] 30 | pub struct TlsRelayBuffer { 31 | inner: Vec, 32 | /// read cursor 33 | cursor: usize, 34 | } 35 | 36 | impl Deref for TlsRelayBuffer { 37 | type Target = Vec; 38 | fn deref(&self) -> &Self::Target { 39 | &self.inner 40 | } 41 | } 42 | 43 | impl DerefMut for TlsRelayBuffer { 44 | fn deref_mut(&mut self) -> &mut Self::Target { 45 | &mut self.inner 46 | } 47 | } 48 | 49 | fn extract_len(buf: &[u8]) -> usize { 50 | buf[0] as usize * 256 + buf[1] as usize 51 | } 52 | 53 | // pub(super) enum Expecting { 54 | // Num(usize), 55 | // Packet(u8), 56 | // } 57 | 58 | // pub(super) enum LeaveTlsMode { 59 | // Passive, 60 | // Active, 61 | // } 62 | 63 | #[derive(Debug, Clone, Copy)] 64 | pub(super) enum Seen0x17 { 65 | None, 66 | FromInbound, 67 | FromOutbound, 68 | BothDirections, 69 | } 70 | 71 | impl Seen0x17 { 72 | fn witness(&mut self, dir: Direction) { 73 | use Direction::*; 74 | use Seen0x17::*; 75 | *self = match (*self, dir) { 76 | (None, Inbound) => FromInbound, 77 | (None, Outbound) => FromOutbound, 78 | (FromInbound, Outbound) | (FromOutbound, Inbound) => BothDirections, 79 | (BothDirections, _) => 
unreachable!(), 80 | _ => return, 81 | }; 82 | } 83 | 84 | fn is_complete(&self) -> bool { 85 | match self { 86 | &Seen0x17::BothDirections => true, 87 | _ => false, 88 | } 89 | } 90 | } 91 | 92 | impl TlsRelayBuffer { 93 | pub fn new() -> Self { 94 | Self { 95 | inner: Vec::with_capacity(2048), 96 | cursor: 0, 97 | } 98 | } 99 | pub fn len(&self) -> usize { 100 | self.inner.len() 101 | } 102 | 103 | pub fn reset(&mut self) { 104 | unsafe { 105 | self.inner.set_len(0); 106 | } 107 | self.cursor = 0; 108 | } 109 | 110 | pub fn checked_packets(&self) -> &[u8] { 111 | if self.cursor < self.inner.len() { 112 | &self.inner[..self.cursor] 113 | } else { 114 | &self.inner 115 | } 116 | } 117 | 118 | pub fn pop_checked_packets(&mut self) { 119 | let new_len = self.inner.len() - min(self.cursor, self.inner.len()); 120 | for i in 0..new_len { 121 | self.inner[i] = self.inner[self.cursor + i]; 122 | } 123 | 124 | self.cursor -= self.inner.len() - new_len; 125 | 126 | unsafe { 127 | self.inner.set_len(new_len); 128 | } 129 | } 130 | 131 | pub fn check_client_hello(&mut self) -> Result<(), ParserError> { 132 | expect_buf_len!(self.inner, 5, "client hello incomplete[1]"); 133 | if &self.inner[..3] != &[0x16, 0x03, 0x01] { 134 | // Not tls 1.2/1.3 135 | return Err(ParserError::Invalid("not tls 1.2/1.3[1]".into())); 136 | } 137 | self.cursor = 5 + extract_len(&self.inner[3..]); 138 | if self.cursor != self.inner.len() { 139 | // Not tls 1.2/1.3 140 | return Err(ParserError::Invalid("not tls 1.2/1.3[2]".into())); 141 | } 142 | Ok(()) 143 | } 144 | 145 | pub(super) fn check_tls_packet(&mut self) -> Result { 146 | expect_buf_len!( 147 | self.inner, 148 | self.cursor + 5, 149 | "packet 0x16 (or sth) incomplete[1]" 150 | ); 151 | let packet_type = self.inner[self.cursor]; 152 | self.cursor += 5 + extract_len(&self.inner[self.cursor + 3..]); 153 | expect_buf_len!( 154 | self.inner, 155 | self.cursor, 156 | "packet 0x16 (or sth) incomplete[2]" 157 | ); 158 | Ok(packet_type) 159 | } 160 | 
161 | pub(super) fn drop_0xff(&mut self) -> Result<()> { 162 | if self.cursor >= self.inner.len() { 163 | return Err(Error::new(ParserError::Invalid( 164 | "wants to drop 0xff but found nothing to drop".into(), 165 | ))); 166 | } else if self.cursor + 6 < self.inner.len() { 167 | return Err(Error::new(ParserError::Invalid( 168 | "wants to drop 0xff but it seems to be incomplete".into(), 169 | ))); 170 | } else if self.cursor + 6 > self.inner.len() { 171 | return Err(Error::new(ParserError::Invalid( 172 | "wants to drop 0xff but there seems to be packets after 0xff".into(), 173 | ))); 174 | } 175 | unsafe { 176 | self.inner.set_len(self.cursor); 177 | } 178 | Ok(()) 179 | } 180 | 181 | pub(super) fn find_key_packets( 182 | &mut self, 183 | seen_0x17: &mut Seen0x17, 184 | dir: Direction, 185 | ) -> Result { 186 | loop { 187 | expect_buf_len!(self.inner, self.cursor + 1, "find 0x17 incomplete"); 188 | match self.inner[self.cursor] { 189 | 0x17 => { 190 | #[cfg(feature = "debug_info")] 191 | debug!("found 0x17, already seen: {:?}", seen_0x17); 192 | seen_0x17.witness(dir); 193 | #[cfg(feature = "debug_info")] 194 | debug!("now seen 0x17: {:?}", seen_0x17); 195 | if seen_0x17.is_complete() { 196 | #[cfg(feature = "debug_info")] 197 | debug!("lite-tls active handshake"); 198 | return Ok(LeaveTlsMode::Active); 199 | } else { 200 | #[cfg(feature = "debug_info")] 201 | debug!("lite-tls 0x17 in first direction"); 202 | self.check_tls_packet()?; 203 | } 204 | } 205 | 0xff => { 206 | #[cfg(feature = "debug_info")] 207 | debug!("lite-tls passive handshake"); 208 | return Ok(LeaveTlsMode::Passive); 209 | } 210 | 0x15 | 0x16 | 0x14 => { 211 | self.check_tls_packet()?; 212 | } 213 | _ => { 214 | return Err(ParserError::Invalid("unexpected tls packet type".into())); 215 | } 216 | } 217 | } 218 | } 219 | 220 | pub(super) async fn flush_checked(&mut self, writer: &mut W) -> Result<()> 221 | where 222 | W: AsyncWriteExt + Unpin, 223 | { 224 | if self.checked_packets().len() > 0 { 225 
| if writer.write(self.checked_packets()).await? == 0 { 226 | return Err(eof_err("EOF on Parsing[]")); 227 | } 228 | writer.flush().await?; 229 | self.pop_checked_packets(); 230 | } 231 | Ok(()) 232 | } 233 | 234 | pub(super) async fn tls12_relay_until_0xff( 235 | &mut self, 236 | reader: &mut R, 237 | writer: &mut W, 238 | ) -> Result<()> 239 | where 240 | R: AsyncReadExt + Unpin, 241 | W: AsyncWriteExt + Unpin, 242 | { 243 | loop { 244 | while self.inner.len() < self.cursor + 5 { 245 | if self.checked_packets().len() > 0 { 246 | self.flush_checked(writer).await?; 247 | } 248 | if reader.read_buf(self.deref_mut()).await? == 0 { 249 | return Err(eof_err("EOF on Parsing[]")); 250 | } 251 | } 252 | 253 | match self.inner[self.cursor] { 254 | 0xff => { 255 | // relay pending 0x17 256 | if self.checked_packets().len() > 0 { 257 | self.flush_checked(writer).await?; 258 | } 259 | let next = self.cursor + 5 + extract_len(&self.inner[self.cursor + 3..]); 260 | while self.inner.len() < next { 261 | if reader.read_buf(self.deref_mut()).await? == 0 { 262 | return Err(eof_err("EOF on Parsing[]")); 263 | } 264 | } 265 | self.reset(); 266 | return Ok(()); 267 | } 268 | _ => (), 269 | } 270 | self.cursor += 5 + extract_len(&self.inner[self.cursor + 3..]); 271 | } 272 | } 273 | } 274 | -------------------------------------------------------------------------------- /src/utils/macros.rs: -------------------------------------------------------------------------------- 1 | 2 | #[macro_export] 3 | macro_rules! try_recv { 4 | ($T:tt, $instance:expr) => { 5 | try_recv!($T, $instance, break) 6 | }; 7 | ($T:tt, $instance:expr, $then_expr:expr) => { 8 | match $instance.try_recv() { 9 | Err($T::error::TryRecvError::Empty) => (), 10 | _ => { 11 | tracing::info!("{} received", stringify!($instance)); 12 | $then_expr; 13 | } 14 | } 15 | }; 16 | } 17 | 18 | #[macro_export] 19 | macro_rules! 
or_continue { 20 | ($res:expr) => { 21 | match $res { 22 | Ok(res) => res, 23 | Err(e) => { 24 | info!("{} failed due to {:?}", stringify!($res), e); 25 | continue; 26 | } 27 | } 28 | }; 29 | } 30 | 31 | #[macro_export] 32 | macro_rules! expect_buf_len { 33 | ($buf:expr, $len:expr) => { 34 | if $buf.len() < $len { 35 | return Err(ParserError::Incomplete(stringify!($len))); 36 | } 37 | }; 38 | ($buf:expr, $len:expr, $mark:expr) => { 39 | if $buf.len() < $len { 40 | // debug!("expect_buf_len {}", $mark); 41 | return Err(ParserError::Incomplete($mark.into())); 42 | } 43 | }; 44 | } -------------------------------------------------------------------------------- /src/utils/mod.rs: -------------------------------------------------------------------------------- 1 | #[cfg(feature = "udp")] 2 | mod udp; 3 | #[cfg(feature = "udp")] 4 | pub use udp::*; 5 | 6 | #[cfg(feature = "lite_tls")] 7 | pub mod lite_tls; 8 | 9 | // mod copy_tcp; 10 | // pub use copy_tcp::copy_to_tls; 11 | 12 | mod macros; 13 | mod mix_addr; 14 | pub use mix_addr::*; 15 | 16 | mod adapter; 17 | mod either_io; 18 | pub use adapter::*; 19 | 20 | mod timedout_duplex_io; 21 | pub use timedout_duplex_io::*; 22 | 23 | mod buffers; 24 | pub use buffers::*; 25 | 26 | mod forked_copy; 27 | pub use forked_copy::*; 28 | 29 | mod buffered_recv; 30 | pub use buffered_recv::BufferedRecv; 31 | 32 | mod wr_tuple; 33 | pub use wr_tuple::WRTuple; 34 | 35 | mod dns_utils; 36 | pub use dns_utils::*; 37 | 38 | #[cfg(all(target_os = "linux", feature = "zio"))] 39 | mod glommio_utils; 40 | #[cfg(all(target_os = "linux", feature = "zio"))] 41 | pub use glommio_utils::*; 42 | 43 | #[derive(Debug, err_derive::Error)] 44 | pub enum ParserError { 45 | #[error(display = "ParserError Incomplete: {:?}", _0)] 46 | Incomplete(String), 47 | #[error(display = "ParserError Invalid: {:?}", _0)] 48 | Invalid(String), 49 | } 50 | 51 | pub fn transmute_u16s_to_u8s(a: &[u16], b: &mut [u8]) { 52 | if b.len() < a.len() * 2 { 53 | return; 54 | } 
55 | for (i, val) in a.iter().enumerate() { 56 | let x = val.to_be_bytes(); 57 | b[2 * i] = x[0]; 58 | b[2 * i + 1] = x[1]; 59 | } 60 | } 61 | 62 | pub enum ConnectionRequest { 63 | TCP(TcpRequest), 64 | #[cfg(feature = "udp")] 65 | UDP(UdpRequest), 66 | #[cfg(feature = "quic")] 67 | ECHO(EchoRequest), 68 | _PHANTOM((TcpRequest, UdpRequest, EchoRequest)), 69 | } 70 | 71 | #[cfg(not(feature = "udp"))] 72 | pub struct DummyRequest {} 73 | -------------------------------------------------------------------------------- /src/utils/timedout_duplex_io.rs: -------------------------------------------------------------------------------- 1 | use futures::Future; 2 | use pin_project_lite::pin_project; 3 | use std::{ 4 | pin::Pin, 5 | sync::atomic::{AtomicU32, Ordering}, 6 | sync::Arc, 7 | task::Poll, 8 | }; 9 | 10 | use tokio::{ 11 | io::{AsyncRead, AsyncWrite}, 12 | time::{sleep_until, Duration, Instant, Sleep}, 13 | }; 14 | 15 | #[cfg(feature = "udp")] 16 | use crate::utils::MixAddrType; 17 | 18 | #[cfg(feature = "udp")] 19 | use super::{UdpRead, UdpWrite}; 20 | 21 | pin_project! 
{ 22 | pub struct TimeoutMonitor { 23 | created: Instant, 24 | last_active: Arc, 25 | deadline: u32, 26 | #[pin] 27 | sleep: Sleep, 28 | } 29 | } 30 | 31 | pub struct TimedoutIO { 32 | inner: R, 33 | created: Instant, 34 | last_active: Arc, 35 | } 36 | 37 | impl TimeoutMonitor { 38 | pub fn new(deadline: Duration) -> Self { 39 | let sleep = Instant::now() + deadline; 40 | Self { 41 | created: Instant::now(), 42 | last_active: Arc::new(AtomicU32::new(0)), 43 | deadline: deadline.as_secs() as u32, 44 | sleep: sleep_until(sleep), 45 | } 46 | } 47 | 48 | pub fn watch(&self, inner: R) -> TimedoutIO { 49 | TimedoutIO { 50 | inner, 51 | created: self.created, 52 | last_active: self.last_active.clone(), 53 | } 54 | } 55 | } 56 | 57 | impl Future for TimeoutMonitor { 58 | type Output = (); 59 | 60 | fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll { 61 | let mut me = self.project(); 62 | match me.sleep.as_mut().poll(cx) { 63 | Poll::Ready(()) => { 64 | let instant = Instant::now(); 65 | let time_elapsed = (instant - *me.created).as_secs() as u32; 66 | let last_active = me.last_active.load(Ordering::Relaxed); 67 | let inactive_time = time_elapsed - last_active; 68 | if inactive_time > *me.deadline { 69 | Poll::Ready(()) 70 | } else { 71 | me.sleep 72 | .as_mut() 73 | .reset(instant + Duration::from_secs(*me.deadline as u64)); 74 | Poll::Pending 75 | } 76 | } 77 | Poll::Pending => Poll::Pending, 78 | } 79 | } 80 | } 81 | 82 | macro_rules! 
poll_timeout { 83 | {With $me:expr, $poll:expr} => { 84 | match $poll { 85 | res @ Poll::Ready(_) => { 86 | let last_active = (Instant::now() - $me.created).as_secs() as u32; 87 | $me.last_active.store(last_active, Ordering::Relaxed); 88 | return res; 89 | }, 90 | Poll::Pending => Poll::Pending, 91 | } 92 | }; 93 | } 94 | 95 | #[cfg(feature = "udp")] 96 | impl UdpRead for TimedoutIO { 97 | fn poll_proxy_stream_read( 98 | mut self: std::pin::Pin<&mut Self>, 99 | cx: &mut std::task::Context<'_>, 100 | buf: &mut super::UdpRelayBuffer, 101 | ) -> Poll> { 102 | poll_timeout! { 103 | With self, 104 | Pin::new(&mut self.inner).poll_proxy_stream_read(cx, buf) 105 | } 106 | } 107 | } 108 | 109 | #[cfg(feature = "udp")] 110 | impl UdpWrite for TimedoutIO { 111 | fn poll_proxy_stream_write( 112 | mut self: Pin<&mut Self>, 113 | cx: &mut std::task::Context<'_>, 114 | buf: &[u8], 115 | addr: &MixAddrType, 116 | ) -> Poll> { 117 | poll_timeout! { 118 | With self, 119 | Pin::new(&mut self.inner).poll_proxy_stream_write(cx, buf, addr) 120 | } 121 | } 122 | 123 | fn poll_flush( 124 | mut self: Pin<&mut Self>, 125 | cx: &mut std::task::Context<'_>, 126 | ) -> Poll> { 127 | Pin::new(&mut self.inner).poll_flush(cx) 128 | } 129 | 130 | fn poll_shutdown( 131 | mut self: Pin<&mut Self>, 132 | cx: &mut std::task::Context<'_>, 133 | ) -> Poll> { 134 | Pin::new(&mut self.inner).poll_shutdown(cx) 135 | } 136 | } 137 | 138 | impl AsyncRead for TimedoutIO { 139 | fn poll_read( 140 | mut self: Pin<&mut Self>, 141 | cx: &mut std::task::Context<'_>, 142 | buf: &mut tokio::io::ReadBuf<'_>, 143 | ) -> Poll> { 144 | poll_timeout! { 145 | With self, 146 | Pin::new(&mut self.inner).poll_read(cx, buf) 147 | } 148 | } 149 | } 150 | 151 | impl AsyncWrite for TimedoutIO { 152 | fn poll_write( 153 | mut self: Pin<&mut Self>, 154 | cx: &mut std::task::Context<'_>, 155 | buf: &[u8], 156 | ) -> Poll> { 157 | poll_timeout! 
{ 158 | With self, 159 | Pin::new(&mut self.inner).poll_write(cx, buf) 160 | } 161 | } 162 | 163 | fn poll_flush( 164 | mut self: Pin<&mut Self>, 165 | cx: &mut std::task::Context<'_>, 166 | ) -> Poll> { 167 | Pin::new(&mut self.inner).poll_flush(cx) 168 | } 169 | 170 | fn poll_shutdown( 171 | mut self: Pin<&mut Self>, 172 | cx: &mut std::task::Context<'_>, 173 | ) -> Poll> { 174 | Pin::new(&mut self.inner).poll_shutdown(cx) 175 | } 176 | } 177 | -------------------------------------------------------------------------------- /src/utils/udp/copy_udp.rs: -------------------------------------------------------------------------------- 1 | use futures::ready; 2 | 3 | use crate::utils::{CursoredBuffer, MixAddrType, UdpRead, UdpRelayBuffer, UdpWrite}; 4 | use std::pin::Pin; 5 | use std::task::Poll; 6 | use std::{future::Future, u64}; 7 | #[cfg(feature = "udp_info")] 8 | use tracing::debug; 9 | use tracing::info; 10 | 11 | #[allow(unused)] 12 | pub async fn copy_udp<'a, R: UdpRead + Unpin, W: UdpWrite + Unpin>( 13 | reader: &'a mut R, 14 | writer: &'a mut W, 15 | conn_id: Option, 16 | ) -> std::io::Result { 17 | CopyUdp { 18 | reader, 19 | writer, 20 | udp_buf: UdpCopyBuf::new(conn_id), 21 | } 22 | .await 23 | } 24 | 25 | struct CopyUdp<'a, R, W> { 26 | reader: &'a mut R, 27 | writer: &'a mut W, 28 | udp_buf: UdpCopyBuf, 29 | } 30 | 31 | impl Future for CopyUdp<'_, R, W> 32 | where 33 | R: UdpRead + Unpin, 34 | W: UdpWrite + Unpin, 35 | { 36 | type Output = std::io::Result; 37 | 38 | fn poll(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll { 39 | let me = &mut *self; 40 | me.udp_buf 41 | .poll_copy(cx, Pin::new(me.reader), Pin::new(me.writer)) 42 | } 43 | } 44 | 45 | pub(crate) struct UdpCopyBuf { 46 | buf: UdpRelayBuffer, 47 | addr: Option, 48 | amt: u64, 49 | conn_id: Option, 50 | need_flush: bool, 51 | read_done: bool, 52 | } 53 | 54 | impl UdpCopyBuf { 55 | pub(crate) fn new(conn_id: Option) -> Self { 56 | Self { 57 | buf: UdpRelayBuffer::new(), 58 
| addr: None, 59 | amt: 0, 60 | conn_id, 61 | need_flush: false, 62 | read_done: false, 63 | } 64 | } 65 | 66 | pub(crate) fn poll_copy( 67 | self: &mut Self, 68 | cx: &mut std::task::Context<'_>, 69 | mut reader: Pin<&mut R>, 70 | mut writer: Pin<&mut W>, 71 | ) -> Poll> 72 | where 73 | R: UdpRead + Unpin, 74 | W: UdpWrite + Unpin, 75 | { 76 | loop { 77 | if !self.buf.has_remaining() && !self.read_done { 78 | #[cfg(feature = "udp_info")] 79 | debug!("[{:?}]CopyUdp::poll reset buffer", self.conn_id); 80 | unsafe { 81 | self.buf.reset(); 82 | } 83 | let new_addr = match reader.as_mut().poll_proxy_stream_read(cx, &mut self.buf)? { 84 | Poll::Ready(addr) => addr, 85 | Poll::Pending => { 86 | // Try flushing when the reader has no progress to avoid deadlock 87 | // when the reader depends on buffered writer. 88 | if self.need_flush { 89 | ready!(writer.as_mut().poll_flush(cx))?; 90 | self.need_flush = false; 91 | } 92 | 93 | return Poll::Pending; 94 | } 95 | }; 96 | if new_addr.is_none() { 97 | #[cfg(feature = "udp_info")] 98 | debug!("[{:?}]CopyUdp::poll new_addr.is_none()", self.conn_id); 99 | self.read_done = true; 100 | } else if self.addr.as_ref().map_or(true, |prev| prev != &new_addr) { 101 | if self.conn_id.is_some() { 102 | info!("[udp][{}] => {:?}", self.conn_id.unwrap(), &new_addr); 103 | } 104 | self.addr = Some(new_addr); 105 | } 106 | } 107 | 108 | #[cfg(feature = "udp_info")] 109 | debug!( 110 | "[{:?}]CopyUdp::poll self.addr {:?}, self.buff len: {:?}", 111 | self.conn_id, 112 | self.addr, 113 | &self.buf.chunk().len() 114 | ); 115 | 116 | while self.buf.has_remaining() { 117 | let x = ready!(writer.as_mut().poll_proxy_stream_write( 118 | cx, 119 | &self.buf.chunk(), 120 | self.addr.as_ref().unwrap() 121 | ))?; 122 | 123 | if x == 0 { 124 | return Poll::Ready(Err(std::io::Error::new( 125 | std::io::ErrorKind::WriteZero, 126 | "write zero byte into writer", 127 | ))); 128 | } 129 | 130 | #[cfg(feature = "udp_info")] 131 | debug!("[{:?}]CopyUdp::poll 
self.buf.advance({})", self.conn_id, x); 132 | self.buf.advance(x); 133 | self.amt += x as u64; 134 | self.need_flush = true; 135 | } 136 | 137 | if !self.buf.has_remaining() && self.read_done { 138 | ready!(writer.as_mut().poll_flush(cx))?; 139 | return Poll::Ready(Ok(self.amt)); 140 | } 141 | } 142 | } 143 | } 144 | -------------------------------------------------------------------------------- /src/utils/udp/copy_udp_bidirectional.rs: -------------------------------------------------------------------------------- 1 | use crate::utils::StreamStopReasons; 2 | use crate::utils::{UdpCopyBuf, UdpRead, UdpWrite}; 3 | use futures::{ready, Future}; 4 | use std::pin::Pin; 5 | use std::task::{Context, Poll}; 6 | 7 | enum TransferState { 8 | Running(UdpCopyBuf), 9 | ShuttingDown(u64), 10 | Done(u64), 11 | } 12 | 13 | struct CopyBidirectional<'a, I, O> { 14 | i: &'a mut I, 15 | o: &'a mut O, 16 | upload: TransferState, 17 | download: TransferState, 18 | stop_reason: StreamStopReasons, 19 | } 20 | 21 | fn transfer_one_direction( 22 | cx: &mut Context<'_>, 23 | state: &mut TransferState, 24 | r: &mut A, 25 | w: &mut B, 26 | ) -> Poll> 27 | where 28 | A: UdpRead + Unpin, 29 | B: UdpWrite + Unpin, 30 | { 31 | let mut r = Pin::new(r); 32 | let mut w = Pin::new(w); 33 | 34 | loop { 35 | match state { 36 | TransferState::Running(buf) => { 37 | let count = ready!(buf.poll_copy(cx, r.as_mut(), w.as_mut()))?; 38 | *state = TransferState::ShuttingDown(count); 39 | } 40 | TransferState::ShuttingDown(count) => { 41 | ready!(w.as_mut().poll_shutdown(cx))?; 42 | 43 | *state = TransferState::Done(*count); 44 | } 45 | TransferState::Done(count) => return Poll::Ready(Ok(*count)), 46 | } 47 | } 48 | } 49 | 50 | impl<'a, A, B> Future for CopyBidirectional<'a, A, B> 51 | where 52 | A: UdpRead + UdpWrite + Unpin, 53 | B: UdpRead + UdpWrite + Unpin, 54 | { 55 | type Output = std::io::Result<(u64, u64, StreamStopReasons)>; 56 | 57 | fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll 
{ 58 | // Unpack self into mut refs to each field to avoid borrow check issues. 59 | use StreamStopReasons::*; 60 | let CopyBidirectional { 61 | i, 62 | o, 63 | upload, 64 | download, 65 | stop_reason, 66 | } = &mut *self; 67 | 68 | let upload = transfer_one_direction(cx, upload, *i, *o)?; 69 | let download = transfer_one_direction(cx, download, *o, *i)?; 70 | 71 | // It is not a problem if ready! returns early because transfer_one_direction for the 72 | // other direction will keep returning TransferState::Done(count) in future calls to poll 73 | use Poll::*; 74 | match (upload, download) { 75 | (Pending, Pending) => Pending, 76 | (Ready(_upload), Pending) => { 77 | *stop_reason = Upload; 78 | Pending 79 | } 80 | (Pending, Ready(_download)) => { 81 | *stop_reason = Download; 82 | Pending 83 | } 84 | (Ready(upload), Ready(download)) => Ready(Ok((upload, download, stop_reason.clone()))), 85 | } 86 | } 87 | } 88 | 89 | pub async fn udp_copy_bidirectional( 90 | inbound: &mut I, 91 | outbound: &mut O, 92 | conn_id: usize, 93 | ) -> Result<(u64, u64, StreamStopReasons), std::io::Error> 94 | where 95 | I: UdpRead + UdpWrite + Unpin, 96 | O: UdpRead + UdpWrite + Unpin, 97 | { 98 | CopyBidirectional { 99 | i: inbound, 100 | o: outbound, 101 | upload: TransferState::Running(UdpCopyBuf::new(Some(conn_id))), 102 | download: TransferState::Running(UdpCopyBuf::new(None)), 103 | stop_reason: StreamStopReasons::Download, 104 | } 105 | .await 106 | } 107 | -------------------------------------------------------------------------------- /src/utils/udp/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod copy_udp; 2 | pub mod copy_udp_bidirectional; 3 | pub mod trojan_udp_stream; 4 | pub mod udp_relay_buffer; 5 | pub mod udp_shutdown; 6 | pub mod udp_traits; 7 | 8 | pub use { 9 | copy_udp::copy_udp, 10 | copy_udp_bidirectional::udp_copy_bidirectional, 11 | trojan_udp_stream::TrojanUdpStream, 12 | udp_relay_buffer::UdpRelayBuffer, 13 | 
udp_traits::{UdpRead, UdpWrite, UdpWriteExt}, 14 | }; 15 | 16 | pub(crate) use copy_udp::UdpCopyBuf; 17 | -------------------------------------------------------------------------------- /src/utils/udp/trojan_udp_stream.rs: -------------------------------------------------------------------------------- 1 | use crate::utils::{ 2 | BufferedRecv, CursoredBuffer, ExtendableFromSlice, MixAddrType, ParserError, UdpRead, 3 | UdpRelayBuffer, UdpWrite, 4 | }; 5 | use pin_project_lite::pin_project; 6 | use std::{ 7 | pin::Pin, 8 | task::{Context, Poll}, 9 | }; 10 | use tokio::io::{AsyncRead, AsyncWrite}; 11 | use tracing::*; 12 | 13 | pin_project! { 14 | #[cfg_attr(feature = "debug_info", derive(Debug))] 15 | pub struct TrojanUdpStream { 16 | #[pin] 17 | inner: BufferedRecv, 18 | // recv half 19 | recv_buffer: UdpRelayBuffer, 20 | expecting: Option, 21 | addr_buf: MixAddrType, 22 | want_to_extract: bool, 23 | // send half 24 | send_buffer: UdpRelayBuffer, 25 | data_len: usize, 26 | } 27 | } 28 | 29 | mod mc { 30 | macro_rules! 
debug_info { 31 | (recv $me:expr, $msg:expr, $addition:expr) => { 32 | #[cfg(feature = "udp_info")] 33 | debug!( 34 | "TrojanUdpRecv {} buf len {} expecting {:?} addr {:?} | {:?}", 35 | $msg, 36 | $me.recv_buffer.chunk().len(), 37 | $me.expecting, 38 | $me.addr_buf, 39 | $addition 40 | ); 41 | }; 42 | 43 | (send $me:expr, $msg:expr, $buf:expr, $addr:expr, $addition:expr) => { 44 | #[cfg(feature = "udp_info")] 45 | debug!( 46 | "TrojanUdpSend {} inner_buf len {} buf len {} addr {:?} | {:?}", 47 | $msg, 48 | $me.send_buffer.chunk().len(), 49 | $buf.len(), 50 | $addr, 51 | $addition, 52 | ); 53 | }; 54 | } 55 | pub(crate) use debug_info; 56 | } 57 | 58 | impl TrojanUdpStream { 59 | pub fn new(inner: IO, buffered_request: Option<(usize, Vec)>) -> Self { 60 | Self { 61 | inner: BufferedRecv::new(inner, buffered_request), 62 | recv_buffer: UdpRelayBuffer::new(), 63 | expecting: None, 64 | want_to_extract: false, 65 | addr_buf: MixAddrType::None, 66 | send_buffer: UdpRelayBuffer::new(), 67 | data_len: 0, 68 | } 69 | } 70 | } 71 | 72 | impl TrojanUdpStream { 73 | fn copy_into_inner(mut self: Pin<&mut Self>, buf: &[u8], addr: &MixAddrType) { 74 | self.data_len = buf.len(); 75 | self.send_buffer 76 | .reserve_by_cursor(addr.encoded_len() + 4 + buf.len()); 77 | addr.write_buf(&mut self.send_buffer); 78 | // unsafe: as u16 79 | self.send_buffer 80 | .extend_from_slice(&(buf.len() as u16).to_be_bytes()); 81 | self.send_buffer.extend_from_slice(&[b'\r', b'\n']); 82 | self.send_buffer.extend_from_slice(buf); 83 | mc::debug_info!(send self, "empty and refill", buf, addr, ""); 84 | } 85 | } 86 | 87 | impl UdpWrite for TrojanUdpStream { 88 | /// ```not_rust 89 | /// +------+----------+----------+--------+---------+----------+ 90 | /// | ATYP | DST.ADDR | DST.PORT | Length | CRLF | Payload | 91 | /// +------+----------+----------+--------+---------+----------+ 92 | /// | 1 | Variable | 2 | 2 | X'0D0A' | Variable | 93 | /// +------+----------+----------+--------+---------+----------+ 
94 | /// ``` 95 | fn poll_proxy_stream_write( 96 | mut self: Pin<&mut Self>, 97 | cx: &mut Context<'_>, 98 | buf: &[u8], 99 | addr: &MixAddrType, 100 | ) -> Poll> { 101 | mc::debug_info!(send self, "enter", buf, addr, ""); 102 | if self.send_buffer.is_empty() { 103 | self.as_mut().copy_into_inner(buf, addr); 104 | } 105 | let mut me = self.project(); 106 | 107 | mc::debug_info!(send me, "before sending", buf, addr, ""); 108 | 109 | loop { 110 | match me.inner.as_mut().poll_write(cx, &me.send_buffer)? { 111 | Poll::Ready(0) => { 112 | return Poll::Ready(Ok(0)); 113 | } 114 | Poll::Ready(x) => { 115 | if x < me.send_buffer.remaining() { 116 | mc::debug_info!(send me, "send and remain", buf, addr, x); 117 | me.send_buffer.advance(x); 118 | } else { 119 | mc::debug_info!(send me, "send all", buf, addr, x); 120 | unsafe { 121 | me.send_buffer.reset(); 122 | } 123 | return Poll::Ready(Ok(*me.data_len)); 124 | } 125 | } 126 | Poll::Pending => { 127 | mc::debug_info!(send me, "pending", buf, addr, ""); 128 | return Poll::Pending; 129 | } 130 | } 131 | } 132 | } 133 | 134 | fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { 135 | Pin::new(&mut self.inner).poll_flush(cx) 136 | } 137 | 138 | fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { 139 | Pin::new(&mut self.inner).poll_shutdown(cx) 140 | } 141 | } 142 | 143 | impl TrojanUdpStream { 144 | fn try_update_addr_buf(self: Pin<&mut Self>) -> Result<(), ParserError> { 145 | let me = self.project(); 146 | if me.addr_buf.is_none() { 147 | MixAddrType::from_encoded(me.recv_buffer).map(|addr| *me.addr_buf = addr) 148 | } else { 149 | Ok(()) 150 | } 151 | } 152 | 153 | fn try_update_expecting(mut self: Pin<&mut Self>) -> Result<(), ParserError> { 154 | if self.expecting.is_none() { 155 | if self.recv_buffer.remaining() < 2 { 156 | return Err(ParserError::Incomplete(String::new())); 157 | } 158 | let bytes = self.recv_buffer.chunk()[0..2].try_into().unwrap(); 159 | let expecting = 
u16::from_be_bytes(bytes) as usize; 160 | self.expecting = Some(expecting); 161 | self.recv_buffer.advance(2 + 2); // `len` + `\r\n` 162 | self.recv_buffer.reserve_by_cursor(expecting); 163 | } 164 | Ok(()) 165 | } 166 | 167 | fn copy_into_outer(mut self: Pin<&mut Self>, outer_buf: &mut UdpRelayBuffer) { 168 | let expecting = self.expecting.unwrap(); 169 | let _out_len = outer_buf.len(); 170 | outer_buf.extend_from_slice(&self.recv_buffer.chunk()[..expecting]); 171 | self.recv_buffer.advance(expecting); 172 | self.recv_buffer.compact(); 173 | self.expecting = None; 174 | mc::debug_info!(recv self, "can extract", format!("outer len: {} -> {}", _out_len, outer_buf.len())); 175 | } 176 | 177 | fn try_extract_packet( 178 | mut self: Pin<&mut Self>, 179 | outer_buf: &mut UdpRelayBuffer, 180 | ) -> Option> { 181 | mc::debug_info!(recv self, "try_extract_packet", ""); 182 | match self 183 | .as_mut() 184 | .try_update_addr_buf() 185 | .and_then(|_| self.as_mut().try_update_expecting()) 186 | { 187 | Ok(_) => { 188 | mc::debug_info!(recv self, "try to extract", ""); 189 | 190 | // udp shouldn't be fragmented 191 | // we read in the packet as a whole 192 | // or we return pending 193 | if self.expecting.unwrap() <= self.recv_buffer.remaining() { 194 | self.as_mut().copy_into_outer(outer_buf); 195 | Some(Ok(std::mem::replace(&mut self.addr_buf, MixAddrType::None))) 196 | } else { 197 | None 198 | } 199 | } 200 | Err(ParserError::Incomplete(_msg)) => { 201 | #[cfg(feature = "udp_info")] 202 | debug!("TrojanUdpRecvStream Incomplete({})", _msg); 203 | None 204 | } 205 | Err(ParserError::Invalid(msg)) => { 206 | error!("TrojanUdpRecvStream Invalid({})", msg); 207 | Some(Err(std::io::ErrorKind::Other.into())) 208 | } 209 | } 210 | } 211 | } 212 | 213 | impl UdpRead for TrojanUdpStream { 214 | /// ```not_rust 215 | /// +------+----------+----------+--------+---------+----------+ 216 | /// | ATYP | DST.ADDR | DST.PORT | Length | CRLF | Payload | 217 | /// 
+------+----------+----------+--------+---------+----------+ 218 | /// | 1 | Variable | 2 | 2 | X'0D0A' | Variable | 219 | /// +------+----------+----------+--------+---------+----------+ 220 | /// ``` 221 | fn poll_proxy_stream_read( 222 | mut self: Pin<&mut Self>, 223 | cx: &mut Context<'_>, 224 | outer_buf: &mut UdpRelayBuffer, // bug once occured: accidentally used outer_buf as inner_buf 225 | ) -> Poll> { 226 | mc::debug_info!(recv self, "enter", format!("{:p}", cx.waker())); 227 | 228 | loop { 229 | if self.want_to_extract { 230 | if let Some(res) = self.as_mut().try_extract_packet(outer_buf) { 231 | mc::debug_info!(recv self, "early return", res); 232 | return Poll::Ready(res); 233 | } else { 234 | self.want_to_extract = false; 235 | } 236 | } 237 | 238 | let mut me = self.as_mut().project(); 239 | let mut buf_inner = me.recv_buffer.as_read_buf(); 240 | match me.inner.as_mut().poll_read(cx, &mut buf_inner)? { 241 | Poll::Ready(_) => { 242 | match buf_inner.filled().len() { 243 | 0 => { 244 | mc::debug_info!(recv me, "n == 0", ""); 245 | // EOF is seen 246 | return Poll::Ready(Ok(MixAddrType::None)); 247 | } 248 | n => { 249 | // Safety: This is guaranteed to be the number of initialized (and read) bytes due to the invariants provided by `ReadBuf::filled`. 
250 | unsafe { 251 | me.recv_buffer.advance_mut(n); 252 | } 253 | 254 | *me.want_to_extract = true; 255 | 256 | mc::debug_info!(recv me, "read ready", n); 257 | } 258 | } 259 | } 260 | Poll::Pending => { 261 | mc::debug_info!(recv me, "pending", ""); 262 | return Poll::Pending; 263 | } 264 | } 265 | } 266 | } 267 | } 268 | -------------------------------------------------------------------------------- /src/utils/udp/udp_relay_buffer.rs: -------------------------------------------------------------------------------- 1 | use crate::{ 2 | protocol::UDP_BUFFER_SIZE, 3 | utils::{CursoredBuffer, ExtendableFromSlice, VecAsReadBufExt}, 4 | }; 5 | use bytes::BufMut; 6 | use std::ops::Deref; 7 | use tokio::io::ReadBuf; 8 | 9 | #[cfg_attr(feature = "udp_info", derive(Debug))] 10 | pub struct UdpRelayBuffer { 11 | cursor: usize, 12 | inner: Vec, 13 | } 14 | 15 | impl<'a> UdpRelayBuffer { 16 | pub fn new() -> Self { 17 | let buf = Vec::with_capacity(UDP_BUFFER_SIZE); 18 | Self { 19 | cursor: 0, 20 | inner: buf, 21 | } 22 | } 23 | 24 | pub fn as_read_buf(&'a mut self) -> ReadBuf<'a> { 25 | self.inner.as_read_buf() 26 | } 27 | 28 | pub unsafe fn advance_mut(&mut self, cnt: usize) { 29 | self.inner.advance_mut(cnt); 30 | } 31 | 32 | pub unsafe fn reset(&mut self) { 33 | self.inner.set_len(0); 34 | self.cursor = 0; 35 | } 36 | 37 | pub fn has_remaining(&self) -> bool { 38 | self.cursor < self.inner.len() 39 | } 40 | 41 | pub fn is_empty(&self) -> bool { 42 | self.inner.is_empty() 43 | } 44 | 45 | pub fn compact(&mut self) { 46 | if self.cursor == 0 { 47 | return; 48 | } 49 | let data_len = self.remaining(); 50 | for i in 0..data_len { 51 | self.inner[i] = self.inner[i + self.cursor]; 52 | } 53 | unsafe { 54 | self.inner.set_len(data_len); 55 | } 56 | self.cursor = 0; 57 | } 58 | 59 | pub fn reserve_by_cursor(&mut self, len: usize) { 60 | if len + self.cursor <= self.inner.capacity() { 61 | return; 62 | } 63 | let mut new_inner = Vec::with_capacity(len); 64 | 
new_inner.extend_from_slice(self.chunk()); 65 | self.inner = new_inner; 66 | self.cursor = 0; 67 | } 68 | } 69 | 70 | impl<'a> CursoredBuffer for UdpRelayBuffer { 71 | fn chunk(&self) -> &[u8] { 72 | &self.inner[self.cursor..] 73 | } 74 | 75 | fn advance(&mut self, len: usize) { 76 | assert!( 77 | self.inner.len() >= self.cursor + len, 78 | "UdpRelayBuffer was about to set a larger position({}+{}) than it's length({})", 79 | self.cursor, 80 | len, 81 | self.inner.len() 82 | ); 83 | self.cursor += len; 84 | } 85 | } 86 | 87 | impl Deref for UdpRelayBuffer { 88 | type Target = [u8]; 89 | 90 | fn deref(&self) -> &Self::Target { 91 | self.chunk() 92 | } 93 | } 94 | 95 | impl ExtendableFromSlice for UdpRelayBuffer { 96 | fn extend_from_slice(&mut self, src: &[u8]) { 97 | self.inner.extend_from_slice(src); 98 | } 99 | } 100 | -------------------------------------------------------------------------------- /src/utils/udp/udp_shutdown.rs: -------------------------------------------------------------------------------- 1 | use super::UdpWrite; 2 | use pin_project_lite::pin_project; 3 | use std::future::Future; 4 | use std::io; 5 | use std::marker::PhantomPinned; 6 | use std::pin::Pin; 7 | use std::task::{Context, Poll}; 8 | 9 | pin_project! { 10 | /// A future used to shutdown an I/O object. 11 | /// 12 | /// Created by the [`AsyncWriteExt::shutdown`][shutdown] function. 13 | /// [shutdown]: crate::io::AsyncWriteExt::shutdown 14 | #[must_use = "futures do nothing unless you `.await` or poll them"] 15 | #[cfg_attr(feature = "debug_info", derive(Debug))] 16 | pub struct Shutdown<'a, A: ?Sized> { 17 | a: &'a mut A, 18 | // Make this future `!Unpin` for compatibility with async trait methods. 19 | #[pin] 20 | _pin: PhantomPinned, 21 | } 22 | } 23 | 24 | /// Creates a future which will shutdown an I/O object. 
25 | #[allow(unused)] 26 | pub(super) fn shutdown(a: &mut A) -> Shutdown<'_, A> 27 | where 28 | A: UdpWrite + Unpin + ?Sized, 29 | { 30 | Shutdown { 31 | a, 32 | _pin: PhantomPinned, 33 | } 34 | } 35 | 36 | impl Future for Shutdown<'_, A> 37 | where 38 | A: UdpWrite + Unpin + ?Sized, 39 | { 40 | type Output = io::Result<()>; 41 | 42 | fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { 43 | let me = self.project(); 44 | Pin::new(&mut **me.a).poll_shutdown(cx) 45 | } 46 | } 47 | -------------------------------------------------------------------------------- /src/utils/udp/udp_traits.rs: -------------------------------------------------------------------------------- 1 | use crate::utils::{MixAddrType, UdpRelayBuffer}; 2 | use std::pin::Pin; 3 | use std::task::{Context, Poll}; 4 | 5 | use super::udp_shutdown::{shutdown, Shutdown}; 6 | 7 | pub trait UdpRead { 8 | /// Should return Poll::Ready(Ok(MixAddrType::None)) when 9 | /// EOF is seen. 10 | /// Udp packets should be read as a whole. 11 | /// If it's not complete, return Pending. 12 | fn poll_proxy_stream_read( 13 | self: Pin<&mut Self>, 14 | cx: &mut Context<'_>, 15 | buf: &mut UdpRelayBuffer, 16 | ) -> Poll>; 17 | } 18 | 19 | pub trait UdpWrite { 20 | /// Should return Ok(0) when the underlying object is no 21 | /// longer writable 22 | fn poll_proxy_stream_write( 23 | self: Pin<&mut Self>, 24 | cx: &mut Context<'_>, 25 | buf: &[u8], 26 | addr: &MixAddrType, 27 | ) -> Poll>; 28 | 29 | /// Should implement this if the underlying object e.g. 
30 | /// TlsStream requires you to manually flush after write 31 | fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll>; 32 | 33 | fn poll_shutdown( 34 | self: Pin<&mut Self>, 35 | cx: &mut Context<'_>, 36 | ) -> Poll>; 37 | } 38 | 39 | pub trait UdpWriteExt: UdpWrite { 40 | fn shutdown(&mut self) -> Shutdown<'_, Self> 41 | where 42 | Self: Unpin, 43 | { 44 | shutdown(self) 45 | } 46 | } 47 | 48 | impl UdpWriteExt for W {} 49 | -------------------------------------------------------------------------------- /src/utils/wr_tuple.rs: -------------------------------------------------------------------------------- 1 | use std::pin::Pin; 2 | use std::task::Poll; 3 | use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; 4 | 5 | #[cfg_attr(feature = "debug_info", derive(Debug))] 6 | pub struct WRTuple(pub W, pub R); 7 | 8 | impl WRTuple { 9 | pub fn from_wr_tuple((w, r): (W, R)) -> Self { 10 | Self(w, r) 11 | } 12 | } 13 | 14 | impl AsyncRead for WRTuple 15 | where 16 | R: AsyncRead + Unpin, 17 | W: AsyncWrite + Unpin, 18 | { 19 | fn poll_read( 20 | mut self: Pin<&mut Self>, 21 | cx: &mut std::task::Context<'_>, 22 | buf: &mut ReadBuf<'_>, 23 | ) -> Poll> { 24 | Pin::new(&mut self.1).poll_read(cx, buf) 25 | } 26 | } 27 | 28 | impl AsyncWrite for WRTuple 29 | where 30 | R: AsyncRead + Unpin, 31 | W: AsyncWrite + Unpin, 32 | { 33 | fn poll_write( 34 | mut self: Pin<&mut Self>, 35 | cx: &mut std::task::Context<'_>, 36 | buf: &[u8], 37 | ) -> Poll> { 38 | Pin::new(&mut self.0).poll_write(cx, buf) 39 | } 40 | 41 | fn poll_flush( 42 | mut self: Pin<&mut Self>, 43 | cx: &mut std::task::Context<'_>, 44 | ) -> Poll> { 45 | Pin::new(&mut self.0).poll_flush(cx) 46 | } 47 | 48 | fn poll_shutdown( 49 | mut self: Pin<&mut Self>, 50 | cx: &mut std::task::Context<'_>, 51 | ) -> Poll> { 52 | Pin::new(&mut self.0).poll_shutdown(cx) 53 | } 54 | } 55 | --------------------------------------------------------------------------------