├── .cargo └── config.toml ├── src ├── utils │ ├── error.rs │ ├── mod.rs │ ├── clean.rs │ └── metrics.rs ├── mux │ ├── mod.rs │ ├── event.rs │ ├── stream.rs │ └── connection.rs ├── tunnel │ ├── mod.rs │ ├── local.rs │ ├── http_local.rs │ ├── quic_remote.rs │ ├── tls_remote.rs │ ├── socks5_local.rs │ ├── stream.rs │ ├── tls_local.rs │ └── client.rs └── main.rs ├── .gitignore ├── README.md ├── ci ├── build_linux.sh └── build_other.sh ├── Cargo.toml └── .github └── workflows └── build-release.yml /.cargo/config.toml: -------------------------------------------------------------------------------- 1 | [target.arm-unknown-linux-musleabi] 2 | linker = "arm-linux-musleabi-ld" 3 | -------------------------------------------------------------------------------- /src/utils/error.rs: -------------------------------------------------------------------------------- 1 | pub fn make_io_error(desc: &str) -> std::io::Error { 2 | std::io::Error::new(std::io::ErrorKind::Other, desc) 3 | } 4 | -------------------------------------------------------------------------------- /src/mux/mod.rs: -------------------------------------------------------------------------------- 1 | mod connection; 2 | pub mod event; 3 | mod stream; 4 | 5 | pub use connection::Connection; 6 | pub use connection::Mode; 7 | pub use stream::MuxStream; 8 | -------------------------------------------------------------------------------- /src/utils/mod.rs: -------------------------------------------------------------------------------- 1 | mod clean; 2 | mod error; 3 | mod metrics; 4 | pub use clean::clean_rotate_logs; 5 | pub use error::make_io_error; 6 | pub use metrics::MetricsLogRecorder; 7 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | 2 | /target 3 | **/*.rs.bk 4 | Cargo.lock 5 | .vscode 6 | *.der 7 | 8 | 9 | # Added by cargo 10 | # 11 | # already existing elements were commented 
out 12 | 13 | #/target 14 | -------------------------------------------------------------------------------- /src/tunnel/mod.rs: -------------------------------------------------------------------------------- 1 | mod client; 2 | mod http_local; 3 | mod local; 4 | mod quic_remote; 5 | mod socks5_local; 6 | mod stream; 7 | mod tls_local; 8 | mod tls_remote; 9 | 10 | // pub const DEFAULT_TLS_HOST: &str = "google.com"; 11 | pub const ALPN_QUIC_HTTP: &[&[u8]] = &[b"hq-29"]; 12 | pub const DefaultTimeoutSecs: u64 = 30; 13 | pub const CheckTimeoutSecs: u64 = 5; 14 | 15 | pub use self::client::Message; 16 | pub use self::client::MuxClient; 17 | pub use self::client::QuicInnerConnection; 18 | pub use self::client::TlsInnerConnection; 19 | pub use self::local::start_local_tunnel_server; 20 | pub use self::quic_remote::start_quic_remote_server; 21 | pub use self::tls_remote::start_tls_remote_server; 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | Rust practice project 2 | 3 | ## Features 4 | 5 | 6 | # Getting Started 7 | 8 | **Examples** 9 | 10 | **Build** 11 | ```sh 12 | $ cargo build --release 13 | ``` 14 | 15 | **Generate cert/key for TLS/QUIC** 16 | ```sh 17 | $ ./target/release/rsnova --rcgen 18 | ``` 19 | 20 | **Launch Server At Remote Server** 21 | ```sh 22 | $ ./rsnova --role server --protocol tls --key ./key.der --cert ./cert.der --listen 0.0.0.0:48100 23 | ``` 24 | 25 | **Launch Server At Local Client** 26 | ```sh 27 | $ ./rsnova --role client --cert ./cert.der --listen 127.0.0.1:48100 --remote tls:// 28 | ``` 29 | 30 | **Use Proxy** 31 | Now you can configure `socks5://127.0.0.1:48100` as the proxy for browser/tools. 
32 | 33 | -------------------------------------------------------------------------------- /ci/build_linux.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | if [ "$#" -ne 1 ]; then 3 | echo "Illegal number of parameters" 4 | exit 1 5 | fi 6 | 7 | TARGET=$1 8 | CUR_DIR=$( cd $( dirname $0 ) && pwd ) 9 | PKG_DIR="${CUR_DIR}/release" 10 | mkdir -p "${PKG_DIR}" 11 | VERSION=$(grep -E '^version' ${CUR_DIR}/../Cargo.toml | awk '{print $3}' | sed 's/"//g') 12 | 13 | cross build --target=$TARGET --release 14 | EXITCODE=$? 15 | if [ $EXITCODE -ne 0 ]; then 16 | echo "cross build failed" 17 | exit $EXITCODE 18 | fi 19 | 20 | # Package up the release binary 21 | tar -C target/$TARGET/release -cf rsnova-$VERSION-$TARGET.tar rsnova 22 | tar uf rsnova-$VERSION-$TARGET.tar 23 | gzip rsnova-$VERSION-$TARGET.tar 24 | mv rsnova-$VERSION-$TARGET.tar.gz $PKG_DIR -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "rsnova" 3 | version = "0.3.0" 4 | edition = "2021" 5 | 6 | # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html 7 | [[bin]] 8 | name = "rsnova" 9 | path = "src/main.rs" 10 | 11 | 12 | [dependencies] 13 | clap = { version = "4.1.4", features = ["derive", "env"] } 14 | tracing = "0.1" 15 | tracing-subscriber = "0.3" 16 | tracing-appender = "0.2" 17 | veil = "0.1" 18 | tokio = { version = "1.0", features = ["full"] } 19 | quinn = "0.10" 20 | rcgen = "0.11" 21 | rustls = "0.21" 22 | rustls-pemfile = "1.0.2" 23 | rustls-native-certs = "0.6.1" 24 | anyhow = "1.0" 25 | url = "2.5.0" 26 | bincode = "2.0.0-rc.3" 27 | # yamux = "0.13.1" 28 | tokio-rustls = "0.23.4" 29 | tokio-util = { version = "0.7", features = ["compat"] } 30 | pki-types = { package = "rustls-pki-types", version = "1" } 31 | httparse = "1.8.0" 32 | bytes = "1" 33 | futures 
= "0.3" 34 | metrics = "0.21" 35 | metrics-util = "0.15" 36 | time = "0.3" 37 | -------------------------------------------------------------------------------- /ci/build_other.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | if [ "$#" -ne 1 ]; then 4 | echo "Illegal number of parameters" 5 | exit 1 6 | fi 7 | 8 | TARGET=$1 9 | CUR_DIR=$( cd $( dirname $0 ) && pwd ) 10 | PKG_DIR="${CUR_DIR}/release" 11 | mkdir -p "${PKG_DIR}" 12 | VERSION=$(grep -E '^version' ${CUR_DIR}/../Cargo.toml | awk '{print $3}' | sed 's/"//g') 13 | 14 | # if [ "$TARGET" = "arm-unknown-linux-musleabi" ] 15 | # then 16 | # wget https://musl.cc/arm-linux-musleabi-cross.tgz 17 | # tar zxf ./arm-linux-musleabi-cross.tgz -C /tmp 18 | # export PATH=/tmp/arm-linux-musleabi-cross/bin:$PATH 19 | # fi 20 | 21 | # Compile the binary for the current target 22 | 23 | cargo build --target=$TARGET --release 24 | EXITCODE=$? 25 | if [ $EXITCODE -ne 0 ]; then 26 | echo "cargo build failed" 27 | exit $EXITCODE 28 | fi 29 | 30 | # Package up the release binary 31 | if [ "$TARGET" = "x86_64-pc-windows-msvc" ] 32 | then 33 | echo "Build for windows" 34 | #echo "`ls -l target/$TARGET/release`" 35 | tar -C target/$TARGET/release -cf rsnova-$VERSION-$TARGET.tar rsnova.exe 36 | else 37 | tar -C target/$TARGET/release -cf rsnova-$VERSION-$TARGET.tar rsnova 38 | fi 39 | tar uf rsnova-$VERSION-$TARGET.tar 40 | gzip rsnova-$VERSION-$TARGET.tar 41 | mv rsnova-$VERSION-$TARGET.tar.gz $PKG_DIR -------------------------------------------------------------------------------- /src/utils/clean.rs: -------------------------------------------------------------------------------- 1 | pub async fn clean_rotate_logs(path: String) { 2 | let format = 3 | time::format_description::parse("[year]-[month]-[day]").expect("Unable to create format"); 4 | // let format = time::format_description::parse("[year]-[month]-[day]-[hour]-[minute]") 5 | // .expect("Unable to create 
format"); 6 | 7 | loop { 8 | tokio::time::sleep(tokio::time::Duration::from_secs(60)).await; 9 | let now = time::OffsetDateTime::now_utc(); 10 | for i in 7..=30 { 11 | let check_date = now.clone(); 12 | match check_date.checked_sub(time::Duration::days(i)) { 13 | Some(date) => { 14 | let suffix = date 15 | .format(&format) 16 | .expect("Unable to format OffsetDateTime"); 17 | let file_path = format!("{}.{}", path, suffix); 18 | let path = std::path::Path::new(file_path.as_str()); 19 | if path.exists() { 20 | tracing::info!("remove log file:{}", file_path); 21 | let _ = std::fs::remove_file(file_path); 22 | } else { 23 | //tracing::info!("remove log file:{} not exist", file_path); 24 | break; 25 | } 26 | } 27 | None => { 28 | break; 29 | } 30 | } 31 | } 32 | } 33 | } 34 | -------------------------------------------------------------------------------- /src/utils/metrics.rs: -------------------------------------------------------------------------------- 1 | use metrics::{Counter, Gauge, Histogram, Key, KeyName, Recorder, SharedString, Unit}; 2 | use metrics_util::registry::{AtomicStorage, Registry}; 3 | use std::sync::Arc; 4 | use tokio::time::{sleep, Duration}; 5 | 6 | pub struct MetricsLogRecorder { 7 | registry: Arc>, 8 | } 9 | 10 | impl Recorder for MetricsLogRecorder { 11 | fn describe_counter(&self, key: KeyName, unit: Option, description: SharedString) {} 12 | fn describe_gauge(&self, key: KeyName, unit: Option, description: SharedString) {} 13 | fn describe_histogram(&self, key: KeyName, unit: Option, description: SharedString) {} 14 | fn register_counter(&self, key: &Key) -> Counter { 15 | self.registry 16 | .get_or_create_counter(key, |c| Counter::from_arc(c.clone())) 17 | } 18 | fn register_gauge(&self, key: &Key) -> Gauge { 19 | self.registry 20 | .get_or_create_gauge(key, |g| Gauge::from_arc(g.clone())) 21 | } 22 | fn register_histogram(&self, key: &Key) -> Histogram { 23 | self.registry 24 | .get_or_create_histogram(key, |h| 
Histogram::from_arc(h.clone())) 25 | } 26 | } 27 | 28 | async fn period_print_metrics(registry: Arc>, duration: Duration) { 29 | loop { 30 | sleep(duration).await; 31 | let mut metrics_info = String::new(); 32 | metrics_info.push_str("\n=================Metrics=====================\n"); 33 | metrics_info.push_str("Guages:\n"); 34 | registry.visit_gauges(|name, guage| { 35 | //guage.load(order) 36 | let n = guage.load(std::sync::atomic::Ordering::Relaxed); 37 | metrics_info.push_str(format!("{}:{}\n", name, f64::from_bits(n) as u64).as_str()); 38 | }); 39 | metrics_info.push_str("Counters:\n"); 40 | registry.visit_counters(|name, counter| { 41 | metrics_info.push_str( 42 | format!( 43 | "{}:{}\n", 44 | name, 45 | counter.load(std::sync::atomic::Ordering::SeqCst) 46 | ) 47 | .as_str(), 48 | ); 49 | }); 50 | tracing::info!("{}", metrics_info); 51 | } 52 | } 53 | 54 | impl MetricsLogRecorder { 55 | pub fn new(duration: Duration) -> MetricsLogRecorder { 56 | let recorder = MetricsLogRecorder { 57 | registry: Arc::new(Registry::atomic()), 58 | }; 59 | tokio::spawn(period_print_metrics(recorder.registry.clone(), duration)); 60 | recorder 61 | } 62 | } 63 | -------------------------------------------------------------------------------- /src/tunnel/local.rs: -------------------------------------------------------------------------------- 1 | use std::net::SocketAddr; 2 | 3 | use crate::tunnel::http_local::{handle_http, handle_https}; 4 | use crate::tunnel::socks5_local::handle_socks5; 5 | use crate::tunnel::tls_local::{handle_tls, valid_tls_version}; 6 | use crate::tunnel::Message; 7 | use anyhow::{anyhow, Result}; 8 | use tokio::net::{TcpListener, TcpStream}; 9 | use tokio::sync::mpsc; 10 | 11 | async fn handle_local_tunnel( 12 | inbound: TcpStream, 13 | tunnel_id: u32, 14 | sender: mpsc::UnboundedSender, 15 | ) -> Result<()> { 16 | //stream.peek(buf) 17 | let mut peek_buf = [0u8; 3]; 18 | inbound.peek(&mut peek_buf).await?; 19 | match peek_buf[0] { 20 | 5 => { 21 | 
//socks5 22 | tracing::info!("[{}]Accept client as SOCKS5 proxy.", tunnel_id); 23 | handle_socks5(tunnel_id, inbound, sender).await?; 24 | return Ok(()); 25 | } 26 | 4 => { 27 | //socks4 28 | tracing::error!("socks4 not supported!"); 29 | return Err(anyhow!("socks4 unimplemented")); 30 | } 31 | _ => { 32 | //info!("Not socks protocol:{}", _data[0]); 33 | } 34 | } 35 | if valid_tls_version(&peek_buf[..]) { 36 | tracing::info!("[{}]Accept client as TLS proxy.", tunnel_id); 37 | handle_tls(tunnel_id, inbound, sender).await?; 38 | return Ok(()); 39 | } 40 | if let Ok(prefix_str) = std::str::from_utf8(&peek_buf) { 41 | let prefix_str = prefix_str.to_uppercase(); 42 | match prefix_str.as_str() { 43 | "GET" | "PUT" | "POS" | "DEL" | "OPT" | "TRA" | "PAT" | "HEA" | "CON" | "UPG" => { 44 | tracing::info!( 45 | "[{}]Accept client as HTTP proxy with method:{}", 46 | tunnel_id, 47 | prefix_str 48 | ); 49 | //http proxy 50 | if prefix_str.as_str() == "CON" { 51 | handle_https(tunnel_id, inbound, sender).await?; 52 | } else { 53 | handle_http(tunnel_id, inbound, sender).await?; 54 | } 55 | return Ok(()); 56 | } 57 | _ => { 58 | //nothing 59 | } 60 | }; 61 | } 62 | Ok(()) 63 | } 64 | 65 | pub async fn start_local_tunnel_server( 66 | addr: &SocketAddr, 67 | sender: mpsc::UnboundedSender, 68 | ) -> Result<(), std::io::Error> { 69 | let listener = TcpListener::bind(addr).await?; 70 | tracing::info!("Start local TCP listen at {}", addr); 71 | let mut tunnel_id_seed: u32 = 0; 72 | while let Ok((inbound, _)) = listener.accept().await { 73 | let tunnel_id = tunnel_id_seed; 74 | tunnel_id_seed += 1; 75 | let tunnel_sender = sender.clone(); 76 | tokio::spawn(async move { 77 | if let Err(e) = handle_local_tunnel(inbound, tunnel_id, tunnel_sender).await { 78 | tracing::error!("handle local tunnel error:{}", e); 79 | } 80 | }); 81 | } 82 | Ok(()) 83 | } 84 | -------------------------------------------------------------------------------- /src/tunnel/http_local.rs: 
-------------------------------------------------------------------------------- 1 | use crate::tunnel::Message; 2 | use anyhow::{anyhow, Result}; 3 | 4 | use tokio::io::{AsyncReadExt, AsyncWriteExt}; 5 | use tokio::net::TcpStream; 6 | use tokio::sync::mpsc; 7 | 8 | use crate::tunnel::tls_local; 9 | 10 | async fn read_http_headers(inbound: &mut TcpStream) -> Result> { 11 | let mut buf: Vec = Vec::new(); 12 | let crlf2 = "\r\n\r\n".as_bytes(); 13 | loop { 14 | let mut tmp_buf = [0; 4096]; 15 | let n = inbound.read(&mut tmp_buf).await?; 16 | buf.extend_from_slice(&tmp_buf[0..n]); 17 | if let Some(pos) = buf.windows(crlf2.len()).position(|window| window == crlf2) { 18 | return Ok(buf); 19 | } 20 | } 21 | } 22 | 23 | fn extract_target(headers_buf: &Vec, default_port: &str) -> Result { 24 | let mut headers = [httparse::EMPTY_HEADER; 64]; 25 | let mut req = httparse::Request::new(&mut headers); 26 | req.parse(headers_buf.as_slice())?; 27 | let mut target_addr: String = String::new(); 28 | if let Some(path) = req.path { 29 | if path.starts_with("http://") { 30 | let url = url::Url::parse(path)?; 31 | if url.has_host() { 32 | target_addr.push_str(url.host_str().unwrap()); 33 | } 34 | match url.port() { 35 | Some(p) => { 36 | target_addr.push(':'); 37 | target_addr.push_str(p.to_string().as_str()); 38 | } 39 | None => {} 40 | } 41 | } 42 | } 43 | if target_addr.is_empty() { 44 | for h in req.headers { 45 | if h.name.eq_ignore_ascii_case("Host") { 46 | target_addr = String::from(std::str::from_utf8(h.value)?); 47 | break; 48 | } 49 | } 50 | } 51 | if target_addr.is_empty() { 52 | return Err(anyhow!("Can not get target addr.")); 53 | } 54 | if target_addr.find(':').is_none() { 55 | target_addr.push_str(default_port); 56 | } 57 | Ok(target_addr) 58 | } 59 | 60 | pub async fn handle_http( 61 | tunnel_id: u32, 62 | mut inbound: TcpStream, 63 | sender: mpsc::UnboundedSender, 64 | ) -> Result<()> { 65 | let headers_buf = read_http_headers(&mut inbound).await?; 66 | let target_addr 
= extract_target(&headers_buf, ":80")?; 67 | 68 | tracing::info!("{}", target_addr); 69 | let msg = Message::open_stream(inbound, target_addr, Some(headers_buf)); 70 | sender.send(msg)?; 71 | Ok(()) 72 | } 73 | 74 | pub async fn handle_https( 75 | tunnel_id: u32, 76 | mut inbound: TcpStream, 77 | sender: mpsc::UnboundedSender, 78 | ) -> Result<()> { 79 | let headers_buf = read_http_headers(&mut inbound).await?; 80 | let conn_res = "HTTP/1.0 200 Connection established\r\n\r\n"; 81 | inbound.write_all(conn_res.as_bytes()).await?; 82 | let target_addr = match tls_local::peek_sni(&mut inbound).await { 83 | Ok(mut sni) => { 84 | sni.push_str(":443"); 85 | sni 86 | } 87 | Err(_) => extract_target(&headers_buf, ":443")?, 88 | }; 89 | tracing::info!("[{}]Handle HTTPS proxy to {} ", tunnel_id, target_addr); 90 | let msg = Message::open_stream(inbound, target_addr, None); 91 | sender.send(msg)?; 92 | Ok(()) 93 | } 94 | -------------------------------------------------------------------------------- /.github/workflows/build-release.yml: -------------------------------------------------------------------------------- 1 | name: Build Releases 2 | on: 3 | push: 4 | tags: 5 | - v* 6 | workflow_dispatch: 7 | inputs: 8 | tag: 9 | description: 'Release Tag' 10 | required: true 11 | type: string 12 | 13 | env: 14 | CARGO_TERM_COLOR: always 15 | 16 | jobs: 17 | build-cross: 18 | runs-on: ubuntu-latest 19 | env: 20 | RUST_BACKTRACE: full 21 | strategy: 22 | fail-fast: false 23 | matrix: 24 | target: 25 | - x86_64-unknown-linux-musl 26 | - arm-unknown-linux-musleabi 27 | - arm-unknown-linux-musleabihf 28 | 29 | steps: 30 | - uses: actions/checkout@v3 31 | 32 | - name: Install Rust 33 | run: | 34 | rustup set profile minimal 35 | rustup toolchain install nightly 36 | rustup default nightly 37 | rustup override set nightly 38 | rustup target add --toolchain nightly ${{ matrix.target }} 39 | 40 | - name: Install cross 41 | run: cargo install cross 42 | 43 | - name: Build ${{ matrix.target 
}} 44 | timeout-minutes: 120 45 | run: | 46 | compile_target=${{ matrix.target }} 47 | chmod +x ./ci/build_linux.sh 48 | ./ci/build_linux.sh ${{ matrix.target }} 49 | 50 | - name: Upload Github Assets 51 | uses: softprops/action-gh-release@v1 52 | env: 53 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 54 | with: 55 | files: ci/release/* 56 | prerelease: ${{ contains(github.ref_name, '-') }} 57 | tag_name: ${{ inputs.tag || github.ref_name }} 58 | 59 | build-unix: 60 | runs-on: ${{ matrix.os }} 61 | env: 62 | RUST_BACKTRACE: full 63 | strategy: 64 | fail-fast: false 65 | matrix: 66 | # os: [ubuntu-latest, macos-latest] 67 | os: [macos-latest] 68 | target: 69 | - x86_64-apple-darwin 70 | - aarch64-apple-darwin 71 | steps: 72 | - uses: actions/checkout@v3 73 | 74 | - name: Install GNU tar 75 | if: runner.os == 'macOS' 76 | run: | 77 | brew install gnu-tar 78 | # echo "::add-path::/usr/local/opt/gnu-tar/libexec/gnubin" 79 | echo "/usr/local/opt/gnu-tar/libexec/gnubin" >> $GITHUB_PATH 80 | 81 | - name: Install Rust 82 | run: | 83 | rustup set profile minimal 84 | rustup toolchain install nightly 85 | rustup default nightly 86 | rustup override set nightly 87 | rustup target add --toolchain nightly ${{ matrix.target }} 88 | 89 | - name: Build release 90 | shell: bash 91 | run: | 92 | chmod +x ./ci/build_other.sh 93 | ./ci/build_other.sh ${{ matrix.target }} 94 | 95 | - name: Upload Github Assets 96 | uses: softprops/action-gh-release@v1 97 | env: 98 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 99 | with: 100 | files: ci/release/* 101 | prerelease: ${{ contains(github.ref_name, '-') }} 102 | tag_name: ${{ inputs.tag || github.ref_name }} 103 | 104 | build-windows: 105 | runs-on: windows-latest 106 | env: 107 | RUSTFLAGS: "-C target-feature=+crt-static" 108 | RUST_BACKTRACE: full 109 | steps: 110 | - uses: actions/checkout@v3 111 | 112 | - name: Install Rust 113 | run: | 114 | rustup set profile minimal 115 | rustup toolchain install nightly 116 | rustup default nightly 117 | 
rustup override set nightly 118 | 119 | - name: Build release 120 | shell: bash 121 | run: | 122 | chmod +x ./ci/build_other.sh 123 | ./ci/build_other.sh x86_64-pc-windows-msvc 124 | 125 | - name: Upload Github Assets 126 | uses: softprops/action-gh-release@v1 127 | env: 128 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 129 | with: 130 | files: ci/release/* 131 | prerelease: ${{ contains(github.ref_name, '-') }} 132 | tag_name: ${{ inputs.tag || github.ref_name }} -------------------------------------------------------------------------------- /src/tunnel/quic_remote.rs: -------------------------------------------------------------------------------- 1 | use anyhow::{Context, Result}; 2 | use quinn::ConnectionError; 3 | use rustls_pemfile::Item; 4 | use std::{fs, net::SocketAddr, path::PathBuf, sync::Arc}; 5 | 6 | use crate::tunnel::stream::handle_server_stream; 7 | 8 | use crate::tunnel::ALPN_QUIC_HTTP; 9 | 10 | fn print_type_of(_: &T) { 11 | println!("{}", std::any::type_name::()) 12 | } 13 | 14 | pub async fn start_quic_remote_server( 15 | listen: &SocketAddr, 16 | cert_path: &PathBuf, 17 | key_path: &PathBuf, 18 | ) -> Result<()> { 19 | let key = fs::read(key_path.clone()).context("failed to read private key")?; 20 | let key = if key_path.extension().map_or(false, |x| x == "der") { 21 | tracing::debug!("private key with DER format"); 22 | rustls::PrivateKey(key) 23 | } else { 24 | match rustls_pemfile::read_one(&mut &*key) { 25 | Ok(x) => match x.unwrap() { 26 | Item::RSAKey(key) => { 27 | tracing::debug!("private key with PKCS #1 format"); 28 | rustls::PrivateKey(key) 29 | } 30 | Item::PKCS8Key(key) => { 31 | tracing::debug!("private key with PKCS #8 format"); 32 | rustls::PrivateKey(key) 33 | } 34 | Item::ECKey(key) => { 35 | tracing::debug!("private key with SEC1 format"); 36 | rustls::PrivateKey(key) 37 | } 38 | Item::X509Certificate(_) => { 39 | anyhow::bail!("you should provide a key file instead of cert"); 40 | } 41 | _ => { 42 | anyhow::bail!("no private 
keys found"); 43 | } 44 | }, 45 | Err(_) => { 46 | anyhow::bail!("malformed private key"); 47 | } 48 | } 49 | }; 50 | 51 | let certs = fs::read(cert_path.clone()).context("failed to read certificate chain")?; 52 | let certs = if cert_path.extension().map_or(false, |x| x == "der") { 53 | vec![rustls::Certificate(certs)] 54 | } else { 55 | rustls_pemfile::certs(&mut &*certs) 56 | .context("invalid PEM-encoded certificate")? 57 | .into_iter() 58 | .map(rustls::Certificate) 59 | .collect() 60 | }; 61 | 62 | let mut server_crypto = rustls::ServerConfig::builder() 63 | .with_safe_defaults() 64 | .with_no_client_auth() 65 | .with_single_cert(certs, key)?; 66 | server_crypto.alpn_protocols = ALPN_QUIC_HTTP.iter().map(|&x| x.into()).collect(); 67 | 68 | let mut server_config = quinn::ServerConfig::with_crypto(Arc::new(server_crypto)); 69 | let transport_config = Arc::get_mut(&mut server_config.transport).unwrap(); 70 | transport_config.max_concurrent_uni_streams(0_u8.into()); 71 | 72 | let endpoint = quinn::Endpoint::server(server_config, listen.clone())?; 73 | tracing::info!("QUIC server listening on {}", endpoint.local_addr()?); 74 | 75 | while let Some(conn) = endpoint.accept().await { 76 | tracing::info!("QUIC connection incoming"); 77 | 78 | let fut = handle_quic_connection(conn); 79 | tokio::spawn(async move { 80 | match fut.await { 81 | Err(e) => { 82 | tracing::error!("connection failed: {reason}", reason = e.to_string()) 83 | } 84 | _ => {} 85 | } 86 | }); 87 | } 88 | 89 | Ok(()) 90 | } 91 | 92 | async fn handle_quic_connection(conn: quinn::Connecting) -> Result<()> { 93 | let connection = conn.await?; 94 | 95 | async { 96 | // tracing::info!("QUIC connection established"); 97 | 98 | // Each stream initiated by the client constitutes a new request. 
99 | loop { 100 | let stream = connection.accept_bi().await; 101 | metrics::increment_gauge!("quic_server_proxy_streams", 1.0); 102 | let (mut send_stream, mut recv_stream) = match stream { 103 | Err(quinn::ConnectionError::ApplicationClosed { .. }) => { 104 | tracing::info!("connection closed"); 105 | return Ok(()); 106 | } 107 | Err(e) => { 108 | return Err(e); 109 | } 110 | Ok(s) => s, 111 | }; 112 | tokio::spawn(async move { 113 | //tracing::info!("handle quic stream"); 114 | if let Err(e) = handle_server_stream(&mut recv_stream, &mut send_stream).await { 115 | //print_type_of(&e); 116 | tracing::error!("failed: {reason}", reason = e.to_string()); 117 | } 118 | metrics::decrement_gauge!("quic_server_proxy_streams", 1.0); 119 | }); 120 | } 121 | } 122 | .await?; 123 | Ok(()) 124 | } 125 | -------------------------------------------------------------------------------- /src/tunnel/tls_remote.rs: -------------------------------------------------------------------------------- 1 | use anyhow::{Context, Result}; 2 | // use pki_types::{CertificateDer, PrivateKeyDer}; 3 | use crate::mux; 4 | use crate::tunnel::stream::handle_server_stream; 5 | use rustls_pemfile::Item; 6 | use std::{collections::VecDeque, fs, io, net::SocketAddr, path::PathBuf, sync::Arc, sync::Mutex}; 7 | 8 | use tokio::net::TcpListener; 9 | use tokio_rustls::TlsAcceptor; 10 | 11 | pub const ALPN_QUIC_HTTP: &[&[u8]] = &[b"hq-29"]; 12 | pub async fn start_tls_remote_server( 13 | listen: &SocketAddr, 14 | cert_path: &PathBuf, 15 | key_path: &PathBuf, 16 | ) -> Result<()> { 17 | let key = fs::read(key_path.clone()).context("failed to read private key")?; 18 | let key = if key_path.extension().map_or(false, |x| x == "der") { 19 | tracing::debug!("private key with DER format"); 20 | tokio_rustls::rustls::PrivateKey(key) 21 | } else { 22 | match rustls_pemfile::read_one(&mut &*key) { 23 | Ok(x) => match x.unwrap() { 24 | Item::RSAKey(key) => { 25 | tracing::debug!("private key with PKCS #1 format"); 26 | 
tokio_rustls::rustls::PrivateKey(key) 27 | } 28 | Item::PKCS8Key(key) => { 29 | tracing::debug!("private key with PKCS #8 format"); 30 | tokio_rustls::rustls::PrivateKey(key) 31 | } 32 | Item::ECKey(key) => { 33 | tracing::debug!("private key with SEC1 format"); 34 | tokio_rustls::rustls::PrivateKey(key) 35 | } 36 | Item::X509Certificate(_) => { 37 | anyhow::bail!("you should provide a key file instead of cert"); 38 | } 39 | _ => { 40 | anyhow::bail!("no private keys found"); 41 | } 42 | }, 43 | Err(_) => { 44 | anyhow::bail!("malformed private key"); 45 | } 46 | } 47 | }; 48 | 49 | let certs = fs::read(cert_path.clone()).context("failed to read certificate chain")?; 50 | let certs = if cert_path.extension().map_or(false, |x| x == "der") { 51 | vec![tokio_rustls::rustls::Certificate(certs)] 52 | } else { 53 | rustls_pemfile::certs(&mut &*certs) 54 | .context("invalid PEM-encoded certificate")? 55 | .into_iter() 56 | .map(tokio_rustls::rustls::Certificate) 57 | .collect() 58 | }; 59 | 60 | let mut server_crypto = tokio_rustls::rustls::ServerConfig::builder() 61 | .with_safe_defaults() 62 | .with_no_client_auth() 63 | .with_single_cert(certs, key)?; 64 | server_crypto.alpn_protocols = ALPN_QUIC_HTTP.iter().map(|&x| x.into()).collect(); 65 | 66 | let acceptor = TlsAcceptor::from(Arc::new(server_crypto)); 67 | let listener = TcpListener::bind(listen).await?; 68 | tracing::info!("TLS server listening on {:?}", listen); 69 | 70 | let mut id: u32 = 0; 71 | let free_ids = Arc::new(Mutex::new(VecDeque::new())); 72 | loop { 73 | let (stream, _) = listener.accept().await?; 74 | let conn_id = if free_ids.lock().unwrap().is_empty() { 75 | id += 1; 76 | id - 1 77 | } else { 78 | free_ids.lock().unwrap().pop_front().unwrap() 79 | }; 80 | let acceptor = acceptor.clone(); 81 | let fut_free_ids = free_ids.clone(); 82 | let fut = async move { 83 | let stream = acceptor.accept(stream).await?; 84 | tracing::info!("TLS connection incoming"); 85 | handle_tls_connection(stream, 
conn_id).await?; 86 | Ok(()) as Result<()> 87 | }; 88 | 89 | tokio::spawn(async move { 90 | if let Err(e) = fut.await { 91 | tracing::error!("connection failed: {reason}", reason = e.to_string()) 92 | } 93 | fut_free_ids.lock().unwrap().push_back(conn_id); 94 | }); 95 | } 96 | } 97 | 98 | async fn handle_tls_connection( 99 | conn: tokio_rustls::server::TlsStream, 100 | id: u32, 101 | ) -> Result<()> { 102 | let (r, w) = tokio::io::split(conn); 103 | let mux_conn = mux::Connection::new(r, w, mux::Mode::Server, id); 104 | 105 | loop { 106 | let stream = mux_conn.accept_stream().await?; 107 | metrics::increment_gauge!("tls_server_proxy_streams", 1.0); 108 | tokio::spawn(async move { 109 | let stream_id = stream.id(); 110 | let (mut stream_reader, mut stream_writer) = tokio::io::split(stream); 111 | if let Err(e) = handle_server_stream(&mut stream_reader, &mut stream_writer).await { 112 | tracing::error!( 113 | "[{}/{}]failed: {reason}", 114 | id, 115 | stream_id, 116 | reason = e.to_string() 117 | ); 118 | } 119 | metrics::decrement_gauge!("tls_server_proxy_streams", 1.0); 120 | }); 121 | } 122 | } 123 | -------------------------------------------------------------------------------- /src/tunnel/socks5_local.rs: -------------------------------------------------------------------------------- 1 | use anyhow::{anyhow, Result}; 2 | use std::net::{Ipv4Addr, Ipv6Addr}; 3 | use tokio::io::AsyncReadExt; 4 | use tokio::io::AsyncWriteExt; 5 | use tokio::net::TcpStream; 6 | use tokio::sync::mpsc; 7 | 8 | use crate::tunnel::Message; 9 | 10 | mod v5 { 11 | pub const VERSION: u8 = 5; 12 | 13 | pub const METH_NO_AUTH: u8 = 0; 14 | pub const METH_GSSAPI: u8 = 1; 15 | pub const METH_USER_PASS: u8 = 2; 16 | 17 | pub const CMD_CONNECT: u8 = 1; 18 | pub const CMD_BIND: u8 = 2; 19 | pub const CMD_UDP_ASSOCIATE: u8 = 3; 20 | 21 | pub const ATYP_IPV4: u8 = 1; 22 | pub const ATYP_IPV6: u8 = 4; 23 | pub const ATYP_DOMAIN: u8 = 3; 24 | 25 | pub const SOCKS_RESP_SUUCESS: u8 = 0; 26 | } 27 | 28 
| // Extracts the name and port from addr_buf and returns them, converting 29 | // the name to the form that the trust-dns client can use. If the original 30 | // name can be parsed as an IP address, makes a SocketAddr from that 31 | // address and the port and returns it; we skip DNS resolution in that 32 | // case. 33 | fn name_port(addr_buf: &[u8]) -> Option { 34 | // The last two bytes of the buffer are the port, and the other parts of it 35 | // are the hostname. 36 | let hostname = &addr_buf[..addr_buf.len() - 2]; 37 | let hostname = match std::str::from_utf8(hostname) { 38 | Ok(s) => s, 39 | Err(_e) => { 40 | return None; 41 | } 42 | }; 43 | let pos = addr_buf.len() - 2; 44 | let port = ((addr_buf[pos] as u16) << 8) | (addr_buf[pos + 1] as u16); 45 | Some(format!("{}:{}", hostname, port)) 46 | } 47 | 48 | pub async fn handle_socks5( 49 | tunnel_id: u32, 50 | mut inbound: TcpStream, 51 | sender: mpsc::UnboundedSender, 52 | ) -> Result<()> { 53 | //let mut peek_buf = Vec::new(); 54 | let mut num_methods_buf = [0u8; 2]; 55 | inbound.read_exact(&mut num_methods_buf).await?; 56 | let mut vdata = vec![0; num_methods_buf[1] as usize]; 57 | inbound.read_exact(&mut vdata).await?; 58 | if !vdata.contains(&v5::METH_NO_AUTH) { 59 | return Err(anyhow!("no supported method given")); 60 | } 61 | inbound.write_all(&[v5::VERSION, v5::METH_NO_AUTH]).await?; 62 | let mut head = [0u8; 4]; 63 | inbound.read_exact(&mut head).await?; 64 | if head[0] != v5::VERSION { 65 | return Err(anyhow!("didn't confirm with v5 version")); 66 | } 67 | if head[1] != v5::CMD_CONNECT { 68 | return Err(anyhow!("unsupported command")); 69 | } 70 | let target_addr = match head[3] { 71 | v5::ATYP_IPV4 => { 72 | let mut addr_buf = [0u8; 6]; 73 | inbound.read_exact(&mut addr_buf).await?; 74 | let addr = Ipv4Addr::new(addr_buf[0], addr_buf[1], addr_buf[2], addr_buf[3]); 75 | let port = ((addr_buf[4] as u16) << 8) | (addr_buf[5] as u16); 76 | format!("{}:{}", addr.to_string(), port) 77 | } 78 | 
v5::ATYP_IPV6 => { 79 | let mut addr_buf = [0u8; 18]; 80 | inbound.read_exact(&mut addr_buf).await?; 81 | let a = ((addr_buf[0] as u16) << 8) | (addr_buf[1] as u16); 82 | let b = ((addr_buf[2] as u16) << 8) | (addr_buf[3] as u16); 83 | let c = ((addr_buf[4] as u16) << 8) | (addr_buf[5] as u16); 84 | let d = ((addr_buf[6] as u16) << 8) | (addr_buf[7] as u16); 85 | let e = ((addr_buf[8] as u16) << 8) | (addr_buf[9] as u16); 86 | let f = ((addr_buf[10] as u16) << 8) | (addr_buf[11] as u16); 87 | let g = ((addr_buf[12] as u16) << 8) | (addr_buf[13] as u16); 88 | let h = ((addr_buf[14] as u16) << 8) | (addr_buf[15] as u16); 89 | let addr = Ipv6Addr::new(a, b, c, d, e, f, g, h); 90 | let port = ((addr_buf[16] as u16) << 8) | (addr_buf[17] as u16); 91 | format!("{}:{}", addr.to_string(), port) 92 | } 93 | v5::ATYP_DOMAIN => { 94 | // 95 | let mut len_buf = [0u8; 1]; 96 | inbound.read_exact(&mut len_buf).await?; 97 | let mut addr_buf = vec![0u8; len_buf[0] as usize + 2]; 98 | inbound.read_exact(&mut addr_buf).await?; 99 | match name_port(&addr_buf) { 100 | Some(addr) => addr, 101 | None => { 102 | return Err(anyhow!("can not get addr with domian")); 103 | } 104 | } 105 | } 106 | n => { 107 | return Err(anyhow!("unknown ATYP received: {}", n)); 108 | } 109 | }; 110 | let mut resp = [0u8; 10]; 111 | // VER - protocol version 112 | resp[0] = 5; 113 | // REP - "reply field" -- what happened with the actual connect. 114 | // 115 | // In theory this should reply back with a bunch more kinds of 116 | // errors if possible, but for now we just recognize a few concrete 117 | // errors. 
118 | resp[1] = v5::SOCKS_RESP_SUUCESS; 119 | 120 | // RSV - reserved 121 | resp[2] = 0; 122 | resp[3] = 1; // socksAtypeV4 = 0x01 123 | inbound.write_all(&resp).await?; 124 | 125 | tracing::info!( 126 | "[{}]Handle SOCKS5 proxy to {} with local:{} remote:{}", 127 | tunnel_id, 128 | target_addr, 129 | inbound.local_addr().unwrap(), 130 | inbound.peer_addr().unwrap() 131 | ); 132 | 133 | let msg = Message::open_stream(inbound, target_addr, None); 134 | sender.send(msg)?; 135 | Ok(()) 136 | } 137 | -------------------------------------------------------------------------------- /src/mux/event.rs: -------------------------------------------------------------------------------- 1 | //use tokio::codec::{Decoder, Encoder}; 2 | use anyhow::{anyhow, Result}; 3 | use bincode::{config, Decode, Encode}; 4 | use bytes::{Buf, BufMut, BytesMut}; 5 | use tokio::io::{self, AsyncReadExt, AsyncWriteExt}; 6 | use tokio::sync::oneshot; 7 | 8 | pub const FLAG_OPEN: u8 = 1; 9 | pub const FLAG_FIN: u8 = 2; 10 | pub const FLAG_SYN: u8 = 4; 11 | pub const FLAG_DATA: u8 = 3; 12 | // pub const FLAG_WIN_UPDATE: u8 = 4; 13 | pub const FLAG_PING: u8 = 5; 14 | // pub const FLAG_SHUTDOWN: u8 = 7; 15 | // pub const FLAG_PONG: u8 = 8; 16 | // pub const FLAG_ROUTINE: u8 = 9; 17 | 18 | pub const EVENT_HEADER_LEN: usize = 8; 19 | 20 | pub fn get_event_type_str(flags: u8) -> &'static str { 21 | match flags { 22 | FLAG_OPEN => "FLAG_SYN", 23 | // FLAG_FIN => "FLAG_FIN", 24 | // FLAG_DATA => "FLAG_DATA", 25 | // FLAG_WIN_UPDATE => "FLAG_WIN_UPDATE", 26 | // FLAG_PING => "FLAG_PING", 27 | // FLAG_SHUTDOWN => "FLAG_SHUTDOWN", 28 | // FLAG_PONG => "FLAG_PONG", 29 | _ => "INVALID", 30 | } 31 | } 32 | 33 | #[derive(Debug, Clone, Copy)] 34 | pub struct Header { 35 | pub flag_len: u32, 36 | pub stream_id: u32, 37 | //pub reserved: [u8; 2], 38 | } 39 | 40 | fn get_flag_len(len: u32, flag: u8) -> u32 { 41 | (len << 8) | u32::from(flag) 42 | } 43 | 44 | impl Header { 45 | fn set_flag_len(&mut self, len: u32, flag: 
u8) { 46 | self.flag_len = (len << 8) | u32::from(flag); 47 | } 48 | pub fn flags(&self) -> u8 { 49 | (self.flag_len & 0xFF) as u8 50 | } 51 | pub fn len(&self) -> u32 { 52 | (self.flag_len >> 8) 53 | } 54 | #[allow(dead_code)] 55 | pub fn set_len(&mut self, v: u32) { 56 | let f = self.flags(); 57 | self.set_flag_len(v, f); 58 | } 59 | pub fn set_flag(&mut self, v: u8) { 60 | let l = self.len(); 61 | self.set_flag_len(l, v); 62 | } 63 | } 64 | 65 | #[derive(Encode, Decode, PartialEq, Debug)] 66 | pub struct OpenStreamEvent { 67 | pub proto: String, 68 | pub addr: String, 69 | } 70 | 71 | #[derive(Debug, Clone)] 72 | pub struct Event { 73 | pub header: Header, 74 | pub body: Vec, 75 | } 76 | 77 | impl Event { 78 | #[allow(dead_code)] 79 | pub fn is_empty(&self) -> bool { 80 | self.header.flags() == 0 as u8 81 | } 82 | } 83 | 84 | #[allow(dead_code)] 85 | pub fn new_empty_event(inbound: bool) -> Event { 86 | Event { 87 | header: Header { 88 | flag_len: get_flag_len(0, 0), 89 | stream_id: 0, 90 | }, 91 | body: Vec::new(), 92 | } 93 | } 94 | fn new_event(sid: u32, buf: &[u8]) -> Event { 95 | Event { 96 | header: Header { 97 | flag_len: get_flag_len(buf.len() as u32, 0), 98 | stream_id: sid, 99 | }, 100 | body: Vec::from(buf), 101 | } 102 | } 103 | pub fn new_data_event(sid: u32, buf: Vec) -> Event { 104 | Event { 105 | header: Header { 106 | flag_len: get_flag_len(buf.len() as u32, FLAG_DATA), 107 | stream_id: sid, 108 | }, 109 | body: buf, 110 | } 111 | } 112 | 113 | pub fn new_fin_event(sid: u32) -> Event { 114 | Event { 115 | header: Header { 116 | flag_len: get_flag_len(0, FLAG_FIN), 117 | stream_id: sid, 118 | }, 119 | body: Vec::new(), 120 | } 121 | } 122 | pub fn new_syn_event(sid: u32) -> Event { 123 | Event { 124 | header: Header { 125 | flag_len: get_flag_len(0, FLAG_SYN), 126 | stream_id: sid, 127 | }, 128 | body: Vec::new(), 129 | } 130 | } 131 | pub fn new_ping_event() -> Event { 132 | Event { 133 | header: Header { 134 | flag_len: get_flag_len(0, 
FLAG_PING), 135 | stream_id: 0, 136 | }, 137 | body: Vec::new(), 138 | } 139 | } 140 | 141 | pub fn new_open_stream_event(sid: u32, msg: &OpenStreamEvent) -> Event { 142 | let config = config::standard(); 143 | let data: Vec = bincode::encode_to_vec(msg, config).unwrap(); 144 | let mut ev = new_event(sid, &data[..]); 145 | ev.header.set_flag(FLAG_OPEN); 146 | ev 147 | } 148 | 149 | pub async fn write_event<'a, T>(writer: &'a mut T, ev: Event) -> anyhow::Result<()> 150 | where 151 | T: AsyncWriteExt + Unpin, 152 | { 153 | let mut out = BytesMut::new(); 154 | out.reserve(EVENT_HEADER_LEN + ev.body.len()); 155 | out.put_u32_le(ev.header.flag_len); 156 | out.put_u32_le(ev.header.stream_id); 157 | if !ev.body.is_empty() { 158 | out.put_slice(&ev.body[..]); 159 | } 160 | writer.write_all(&out).await?; 161 | Ok(()) 162 | } 163 | 164 | pub async fn read_event<'a, T>(reader: &'a mut T) -> Result 165 | where 166 | T: AsyncReadExt + Unpin + ?Sized, 167 | { 168 | let mut hbuf = vec![0; EVENT_HEADER_LEN]; 169 | let _ = reader.read_exact(&mut hbuf).await?; 170 | 171 | let mut xbuf: [u8; 4] = Default::default(); 172 | xbuf.copy_from_slice(&hbuf[0..4]); 173 | let e1 = u32::from_le_bytes(xbuf); 174 | xbuf.copy_from_slice(&hbuf[4..8]); 175 | let e2 = u32::from_le_bytes(xbuf); 176 | let header = Header { 177 | flag_len: e1, 178 | stream_id: e2, 179 | }; 180 | let body_data_len = header.len(); 181 | let mut dbuf = vec![0; body_data_len as usize]; 182 | if body_data_len > 0 { 183 | let _ = reader.read_exact(&mut dbuf).await?; 184 | } 185 | let ev = Event { header, body: dbuf }; 186 | Ok(ev) 187 | } 188 | -------------------------------------------------------------------------------- /src/tunnel/stream.rs: -------------------------------------------------------------------------------- 1 | use anyhow::{anyhow, Result}; 2 | use futures::future::try_join; 3 | use std::sync::atomic::AtomicBool; 4 | use std::sync::atomic::AtomicU64; 5 | use std::sync::atomic::Ordering::SeqCst; 6 | use 
std::sync::Arc; 7 | use std::time::Duration; 8 | use std::time::SystemTime; 9 | use std::time::UNIX_EPOCH; 10 | use tokio::io::{self, AsyncReadExt, AsyncWrite, AsyncWriteExt}; 11 | use tokio::time::timeout; 12 | 13 | use crate::mux::event::{self, OpenStreamEvent}; 14 | use crate::tunnel::CheckTimeoutSecs; 15 | use crate::tunnel::DefaultTimeoutSecs; 16 | 17 | struct TransferState { 18 | abort: AtomicBool, 19 | io_active_timestamp_secs: AtomicU64, 20 | } 21 | 22 | impl TransferState { 23 | fn new() -> Self { 24 | Self { 25 | abort: AtomicBool::new(false), 26 | io_active_timestamp_secs: AtomicU64::new(0), 27 | } 28 | } 29 | } 30 | 31 | pub struct Stream<'a, LR, LW, RR, RW> { 32 | local_reader: &'a mut LR, 33 | local_writer: &'a mut LW, 34 | remote_reader: &'a mut RR, 35 | remote_writer: &'a mut RW, 36 | } 37 | 38 | async fn timeout_copy_impl( 39 | r: &mut R, 40 | w: &mut W, 41 | timeout_sec: u64, 42 | state: Arc, 43 | ) -> Result<()> { 44 | let mut buf = [0u8; 8192]; 45 | let check_timeout_secs = Duration::from_secs(CheckTimeoutSecs); 46 | loop { 47 | if state.abort.load(SeqCst) { 48 | return Err(anyhow!("abort")); 49 | } 50 | match timeout(check_timeout_secs, r.read(&mut buf)).await { 51 | Err(_) => { 52 | let now_secs = SystemTime::now() 53 | .duration_since(UNIX_EPOCH) 54 | .unwrap() 55 | .as_secs(); 56 | if now_secs > (state.io_active_timestamp_secs.load(SeqCst) + timeout_sec) { 57 | return Err(anyhow!(format!("timeout after inactive {}secs", now_secs- state.io_active_timestamp_secs.load(SeqCst)))); 58 | }else{ 59 | continue; 60 | } 61 | } 62 | Ok(Ok(n)) => { 63 | if n == 0 { 64 | break; 65 | }; 66 | state.io_active_timestamp_secs.store( 67 | SystemTime::now() 68 | .duration_since(UNIX_EPOCH) 69 | .unwrap() 70 | .as_secs(), 71 | SeqCst, 72 | ); 73 | if let Err(ex) = w.write_all(&buf[0..n]).await { 74 | state.abort.store(true, SeqCst); 75 | return Err(ex.into()); 76 | } 77 | } 78 | Ok(Err(e)) => { 79 | state.abort.store(true, SeqCst); 80 | return Err(e.into()); 81 
| } 82 | } 83 | } 84 | Ok(()) 85 | } 86 | async fn timeout_copy( 87 | r: &mut R, 88 | w: &mut W, 89 | timeout_sec: u64, 90 | state: Arc, 91 | ) -> Result<()> { 92 | let result = timeout_copy_impl(r, w, timeout_sec, state).await; 93 | w.shutdown().await?; 94 | result 95 | } 96 | 97 | impl<'a, LR, LW, RR, RW> Stream<'a, LR, LW, RR, RW> 98 | where 99 | LR: AsyncReadExt + Unpin, 100 | LW: AsyncWriteExt + Unpin, 101 | RR: AsyncReadExt + Unpin, 102 | RW: AsyncWriteExt + Unpin, 103 | { 104 | pub fn new(lr: &'a mut LR, lw: &'a mut LW, rr: &'a mut RR, rw: &'a mut RW) -> Self { 105 | Self { 106 | local_reader: lr, 107 | local_writer: lw, 108 | remote_reader: rr, 109 | remote_writer: rw, 110 | } 111 | } 112 | 113 | pub async fn transfer(&mut self) -> Result<()> { 114 | let state = Arc::new(TransferState::new()); 115 | let client_to_server = timeout_copy( 116 | &mut self.local_reader, 117 | &mut self.remote_writer, 118 | DefaultTimeoutSecs, 119 | state.clone(), 120 | ); 121 | let server_to_client = timeout_copy( 122 | &mut self.remote_reader, 123 | &mut self.local_writer, 124 | DefaultTimeoutSecs, 125 | state.clone(), 126 | ); 127 | try_join(client_to_server, server_to_client).await?; 128 | Ok(()) 129 | } 130 | } 131 | 132 | pub async fn handle_server_stream<'a, LR: AsyncReadExt + Unpin, LW: AsyncWriteExt + Unpin>( 133 | mut lr: &'a mut LR, 134 | mut lw: &'a mut LW, 135 | ) -> Result<()> { 136 | let timeout_secs = Duration::from_secs(DefaultTimeoutSecs); 137 | match timeout(timeout_secs, event::read_event(&mut lr)).await? 
{ 138 | Err(e) => match e.kind() { 139 | std::io::ErrorKind::UnexpectedEof => Ok(()), 140 | _ => Err(anyhow::Error::new(e)), 141 | }, 142 | Ok(ev) => { 143 | if ev.header.flags() != event::FLAG_OPEN { 144 | return Err(anyhow!("unexpected flag:{}", ev.header.flags())); 145 | } 146 | let config = bincode::config::standard(); 147 | let (open_event, len): (OpenStreamEvent, usize) = 148 | bincode::decode_from_slice(&ev.body[..], config)?; 149 | tracing::info!("[{}]recv open event:{:?}", ev.header.stream_id, open_event); 150 | let mut remote_stream = timeout( 151 | timeout_secs, 152 | tokio::net::TcpStream::connect(&open_event.addr), 153 | ) 154 | .await??; 155 | let (mut remote_receiver, mut remote_sender) = remote_stream.split(); 156 | let mut stream: Stream< 157 | LR, 158 | LW, 159 | tokio::net::tcp::ReadHalf<'_>, 160 | tokio::net::tcp::WriteHalf<'_>, 161 | > = Stream::new(&mut lr, &mut lw, &mut remote_receiver, &mut remote_sender); 162 | stream.transfer().await 163 | } 164 | } 165 | } 166 | -------------------------------------------------------------------------------- /src/mux/stream.rs: -------------------------------------------------------------------------------- 1 | use crate::utils; 2 | use anyhow::{anyhow, Result}; 3 | use bytes::Buf; 4 | use bytes::BytesMut; 5 | use futures::ready; 6 | use std::pin::Pin; 7 | use std::task::{Context, Poll}; 8 | use tokio::io::AsyncRead; 9 | use tokio::io::AsyncWrite; 10 | use tokio::io::ReadBuf; 11 | use tokio::sync::mpsc; 12 | use tokio::sync::oneshot; 13 | use tokio_util::sync::PollSender; 14 | 15 | pub struct MuxStream { 16 | id: u32, 17 | ev_writer: PollSender, 18 | inbound_reader: mpsc::Receiver>>, 19 | recv_buf: BytesMut, 20 | initial_close: bool, 21 | close_by_remote: bool, 22 | } 23 | 24 | pub enum Control { 25 | AcceptStream(oneshot::Sender>), 26 | NewStream( 27 | ( 28 | u32, 29 | mpsc::Sender>>, 30 | Option>>>, 31 | ), 32 | ), 33 | StreamData(u32, Vec, bool), 34 | StreamShutdown(u32, bool), 35 | StreamClose(u32), 
36 | Ping, 37 | Close, 38 | } 39 | 40 | fn fill_read_buf(src: &mut BytesMut, dst: &mut ReadBuf<'_>) -> usize { 41 | if src.is_empty() { 42 | return 0; 43 | } 44 | let mut n = src.len(); 45 | if n > dst.remaining() { 46 | n = dst.remaining(); 47 | } 48 | 49 | dst.put_slice(&src[0..n]); 50 | src.advance(n); 51 | if src.is_empty() { 52 | src.clear(); 53 | } 54 | n 55 | } 56 | 57 | impl MuxStream { 58 | pub fn new( 59 | id: u32, 60 | ev_writer: mpsc::Sender, 61 | inbound_reader: mpsc::Receiver>>, 62 | ) -> Self { 63 | Self { 64 | id, 65 | ev_writer: PollSender::new(ev_writer), 66 | inbound_reader, 67 | recv_buf: BytesMut::new(), 68 | initial_close: false, 69 | close_by_remote: false, 70 | } 71 | } 72 | 73 | pub fn id(&self) -> u32 { 74 | self.id 75 | } 76 | } 77 | 78 | impl AsyncRead for MuxStream { 79 | fn poll_read( 80 | mut self: Pin<&mut Self>, 81 | cx: &mut Context<'_>, 82 | buf: &mut ReadBuf<'_>, 83 | ) -> Poll> { 84 | if !self.recv_buf.is_empty() { 85 | fill_read_buf(&mut self.recv_buf, buf); 86 | if buf.remaining() == 0 { 87 | return Poll::Ready(Ok(())); 88 | } 89 | }; 90 | match self.inbound_reader.poll_recv(cx) { 91 | Poll::Ready(Some(data)) => match data { 92 | Some(b) => { 93 | let mut copy_n: usize = b.len(); 94 | if 0 == copy_n { 95 | return Poll::Ready(Ok(())); 96 | } 97 | if copy_n > buf.remaining() { 98 | copy_n = buf.remaining(); 99 | } 100 | buf.put_slice(&b[0..copy_n]); 101 | if copy_n < b.len() { 102 | self.recv_buf.extend_from_slice(&b[copy_n..]); 103 | } 104 | Poll::Ready(Ok(())) 105 | } 106 | None => { 107 | self.close_by_remote = true; 108 | Poll::Ready(Err(std::io::Error::new( 109 | std::io::ErrorKind::ConnectionReset, 110 | "close by remote", 111 | ))) 112 | } 113 | }, 114 | Poll::Ready(None) => { 115 | // self.eof_close = true; 116 | //error!("[{}]####3 Close", state.stream_id); 117 | Poll::Ready(Ok(())) 118 | } 119 | Poll::Pending => Poll::Pending, 120 | } 121 | } 122 | } 123 | 124 | impl AsyncWrite for MuxStream { 125 | fn poll_write( 126 
| mut self: Pin<&mut Self>, 127 | cx: &mut Context<'_>, 128 | buf: &[u8], 129 | ) -> Poll> { 130 | let ctrl = Control::StreamData(self.id, Vec::from(buf), false); 131 | match ready!(self.ev_writer.poll_reserve(cx)) { 132 | Err(e) => Poll::Ready(Err(utils::make_io_error(&e.to_string()))), 133 | Ok(v) => match self.ev_writer.send_item(ctrl) { 134 | Ok(()) => Poll::Ready(Ok(buf.len())), 135 | Err(ex) => Poll::Ready(Err(utils::make_io_error(&ex.to_string()))), 136 | }, 137 | } 138 | } 139 | fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { 140 | Poll::Ready(Ok(())) 141 | } 142 | fn poll_shutdown( 143 | mut self: Pin<&mut Self>, 144 | cx: &mut Context<'_>, 145 | ) -> Poll> { 146 | if !self.initial_close { 147 | self.initial_close = true; 148 | let ctrl = Control::StreamShutdown(self.id, false); 149 | match ready!(self.ev_writer.poll_reserve(cx)) { 150 | Err(e) => Poll::Ready(Err(utils::make_io_error(&e.to_string()))), 151 | Ok(v) => match self.ev_writer.send_item(ctrl) { 152 | Ok(()) => Poll::Ready(Ok(())), 153 | Err(ex) => Poll::Ready(Err(utils::make_io_error(&ex.to_string()))), 154 | }, 155 | } 156 | } else { 157 | Poll::Ready(Ok(())) 158 | } 159 | } 160 | } 161 | 162 | impl Drop for MuxStream { 163 | fn drop(&mut self) { 164 | match self.ev_writer.get_ref() { 165 | Some(sender) => { 166 | if !self.initial_close { 167 | let ctrl_sender = sender.clone(); 168 | let stream_close = Control::StreamShutdown(self.id, false); 169 | tokio::spawn(async move { 170 | let _ = ctrl_sender.send(stream_close).await; 171 | }); 172 | } 173 | if !self.close_by_remote { 174 | let ctrl_sender = sender.clone(); 175 | let stream_drop = Control::StreamClose(self.id); 176 | tokio::spawn(async move { 177 | let _ = ctrl_sender.send(stream_drop).await; 178 | }); 179 | } 180 | } 181 | None => {} 182 | } 183 | } 184 | } 185 | -------------------------------------------------------------------------------- /src/main.rs: 
-------------------------------------------------------------------------------- 1 | #![feature(map_try_insert)] 2 | 3 | use anyhow::{anyhow, Result}; 4 | use clap::{Parser, ValueEnum}; 5 | 6 | use metrics::gauge; 7 | use std::fs; 8 | use std::net::SocketAddr; 9 | use std::path::PathBuf; 10 | use std::time::Duration; 11 | use tokio::sync::mpsc::UnboundedSender; 12 | use tokio::{task, time}; 13 | use tracing_subscriber; 14 | use url::{ParseError, Url}; 15 | use veil::Redact; // 1.3.0 16 | 17 | mod mux; 18 | mod tunnel; 19 | mod utils; 20 | 21 | #[derive(ValueEnum, Clone, Debug)] 22 | enum Protocol { 23 | TLS, 24 | QUIC, 25 | } 26 | 27 | #[derive(ValueEnum, Clone, Debug)] 28 | enum Role { 29 | CLIENT, 30 | SERVER, 31 | } 32 | 33 | #[derive(Parser, Redact)] 34 | #[clap(author, version, about, long_about = None)] 35 | struct Args { 36 | // #[clap(default_value = "", long, env)] 37 | // #[redact(partial)] 38 | // model_id: String, 39 | #[structopt(long = "listen", default_value = "127.0.0.1:48100")] 40 | listen: SocketAddr, 41 | 42 | #[clap(long, value_enum, default_value_t=Protocol::TLS)] 43 | protocol: Protocol, 44 | 45 | #[structopt(long = "remote")] 46 | remote: Option, 47 | 48 | #[clap(default_value = "127.0.0.1:48101", long, env)] 49 | admin: String, 50 | 51 | #[clap(long = "key", requires = "cert", default_value = "key.der")] 52 | #[redact(partial)] 53 | key: Option, 54 | /// TLS certificate in PEM format 55 | #[clap(long = "cert", default_value = "cert.der")] 56 | cert: Option, 57 | 58 | #[clap(long, value_enum, default_value_t=Role::CLIENT)] 59 | role: Role, 60 | 61 | #[clap(default_value = "5", long)] 62 | concurrent: usize, 63 | 64 | #[clap(default_value = "mydomain.io", long)] 65 | tls_host: String, 66 | 67 | #[clap(default_value = "false", long)] 68 | rcgen: bool, 69 | 70 | #[clap(default_value = "", long)] 71 | log: String, 72 | } 73 | 74 | // async fn handler() -> Html<&'static str> { 75 | // Html("

Hello, World!

") 76 | // } 77 | 78 | /* Generates a self-signed certificate for tls_host and writes it as DER to ./cert.der and its private key as DER to ./key.der. Write failures are printed, not returned. NOTE(review): the key-write failure also prints "failed to write certificate", and the generate/serialize steps unwrap, so they panic on failure. */ fn rcgen(tls_host: &String) { 79 | let cert_path = std::path::PathBuf::from(r"./cert.der"); 80 | let key_path = std::path::PathBuf::from(r"./key.der"); 81 | // let cert_der_path = std::path::PathBuf::from(r"./cert.der"); 82 | 83 | println!( 84 | "generating self-signed certificate at {:?} & {:?} with host:{}", 85 | cert_path, key_path, tls_host, 86 | ); 87 | let cert = rcgen::generate_simple_self_signed(vec![tls_host.into()]).unwrap(); 88 | let key = cert.serialize_private_key_der(); 89 | let cert = cert.serialize_der().unwrap(); 90 | // let cert = cert.serialize_pem().unwrap(); 91 | 92 | if let Err(e) = fs::write(&cert_path, &cert) { 93 | println!("failed to write certificate:{}", e); 94 | return; 95 | } 96 | if let Err(e) = fs::write(&key_path, &key) { 97 | println!("failed to write certificate:{}", e); 98 | return; 99 | } 100 | } 101 | 102 | /* Entry point. Parses CLI args. Logs to stdout, or with --log to a daily-rolling file in ./ (also spawning utils::clean_rotate_logs for that name). With --rcgen it only writes the cert/key files and exits. Otherwise it installs utils::MetricsLogRecorder and runs in one of two roles. CLIENT builds a tunnel::MuxClient chosen by the --remote URL scheme ("quic" or "tls"), sends tunnel::Message::HealthCheck every second, and then runs start_local_tunnel_server on --listen. SERVER runs start_quic_remote_server or start_tls_remote_server on --listen, depending on --protocol. NOTE(review): client mode unwraps --remote and panics when it is absent. Server-mode errors are only logged, and main still returns Ok(()). */ #[tokio::main] 103 | async fn main() -> anyhow::Result<()> { 104 | let args: Args = Args::parse(); 105 | if args.log.is_empty() { 106 | tracing_subscriber::fmt::init(); 107 | } else { 108 | let file_appender = tracing_appender::rolling::daily("./", args.log.as_str()); 109 | //let (non_blocking_appender, _guard) = tracing_appender::non_blocking(file_appender); 110 | tracing_subscriber::fmt().with_writer(file_appender).init(); 111 | tokio::spawn(utils::clean_rotate_logs(format!("./{}", args.log.as_str()))); 112 | } 113 | tracing::info!("{args:?}"); 114 | 115 | if args.rcgen { 116 | rcgen(&args.tls_host); 117 | return Ok(()); 118 | } 119 | 120 | let recorder = utils::MetricsLogRecorder::new(Duration::from_secs(10)); 121 | metrics::set_boxed_recorder(Box::new(recorder)).unwrap(); 122 | 123 | match args.role { 124 | Role::CLIENT => { 125 | let tunnel_sender: UnboundedSender; 126 | match args.remote.as_ref().unwrap().scheme() { 127 | "quic" => { 128 | tunnel_sender = tunnel::MuxClient::::from( 129 | &args.remote.as_ref().unwrap(), 130 | &args.cert.as_ref().unwrap(), 131 | &args.tls_host, 132 |
args.concurrent, 133 | ) 134 | .await?; 135 | } 136 | "tls" => { 137 | tunnel_sender = tunnel::MuxClient::::from( 138 | &args.remote.as_ref().unwrap(), 139 | &args.cert.as_ref().unwrap(), 140 | &args.tls_host, 141 | args.concurrent, 142 | ) 143 | .await?; 144 | } 145 | _ => { 146 | tracing::error!("unsupported"); 147 | return Err(anyhow!("unsupported")); 148 | } 149 | }; 150 | let health_checker = tunnel_sender.clone(); 151 | tokio::spawn(async move { 152 | let mut interval = time::interval(Duration::from_secs(1)); 153 | loop { 154 | interval.tick().await; 155 | if let Err(e) = health_checker.send(tunnel::Message::HealthCheck) { 156 | tracing::error!("health check error:{}", e); 157 | } 158 | } 159 | }); 160 | 161 | if let Err(e) = tunnel::start_local_tunnel_server(&args.listen, tunnel_sender) 162 | .await 163 | .map_err(anyhow::Error::from) 164 | { 165 | tracing::error!("{e:?}"); 166 | return Err(e); 167 | } 168 | } 169 | Role::SERVER => match args.protocol { 170 | Protocol::QUIC => { 171 | if let Err(e) = tunnel::start_quic_remote_server( 172 | &args.listen, 173 | args.cert.as_ref().unwrap(), 174 | args.key.as_ref().unwrap(), 175 | ) 176 | .await 177 | { 178 | tracing::error!("{e:?}"); 179 | } 180 | } 181 | Protocol::TLS => { 182 | if let Err(e) = tunnel::start_tls_remote_server( 183 | &args.listen, 184 | args.cert.as_ref().unwrap(), 185 | args.key.as_ref().unwrap(), 186 | ) 187 | .await 188 | { 189 | tracing::error!("{e:?}"); 190 | } 191 | } 192 | }, 193 | } 194 | Ok(()) 195 | } 196 | -------------------------------------------------------------------------------- /src/tunnel/tls_local.rs: -------------------------------------------------------------------------------- 1 | use anyhow::{anyhow, Result}; 2 | 3 | use tokio::net::TcpStream; 4 | use tokio::sync::mpsc; 5 | 6 | use crate::tunnel::Message; 7 | 8 | /* Cheap pre-check that buf starts like a TLS handshake record. It requires at least 3 bytes, content type 0x16 (handshake), and a major version byte of 3 or more. */ pub fn valid_tls_version(buf: &[u8]) -> bool { 9 | if buf.len() < 3 { 10 | return false; 11 | } 12 | //recordTypeHandshake 13 | if buf[0] != 0x16 { 14 |
//info!("###1 here {}", buf[0]); 15 | return false; 16 | } 17 | let tls_major_ver = buf[1]; 18 | //let tlsMinorVer = buf[2]; 19 | 20 | if tls_major_ver < 3 { 21 | //no SNI before sslv3 22 | //info!("###2 here {}", tls_major_ver); 23 | return false; 24 | } 25 | true 26 | } 27 | 28 | /* NOTE(review): lines 28-114 are an older read_exact-based peek_sni kept as commented-out code; it is dead and can be deleted. */ // pub async fn peek_sni(inbound: &mut TcpStream) -> Result<(String)> { 29 | // let mut peek_buf = Vec::new(); 30 | // let mut ver_len_buf = [0u8; 5]; 31 | // inbound.read_exact(&mut ver_len_buf).await?; 32 | // peek_buf.extend_from_slice(&ver_len_buf); 33 | // let mut n = ver_len_buf[3] as u16; 34 | // n = (n << 8) + ver_len_buf[4] as u16; 35 | // if n < 42 { 36 | // return Err(make_error("no sufficient space for sni")); 37 | // } 38 | // let mut vdata = vec![0; n as usize]; 39 | // inbound.read_exact(&mut vdata).await?; 40 | // peek_buf.extend_from_slice(&vdata[..]); 41 | // if vdata[0] != 0x01 { 42 | // return Err(make_error("not clienthello handshake")); 43 | // } 44 | // let rest_buf = &vdata[38..]; 45 | // let sid_len = rest_buf[0] as usize; 46 | // let rest_buf = &rest_buf[(1 + sid_len)..]; 47 | // if rest_buf.len() < 2 { 48 | // return Err(make_error("no sufficient space for sni0")); 49 | // } 50 | // let mut cipher_len = rest_buf[0] as usize; 51 | // cipher_len = (cipher_len << 8) + rest_buf[1] as usize; 52 | // if cipher_len % 2 == 1 || rest_buf.len() < 3 + cipher_len { 53 | // return Err(make_error("invalid cipher_len")); 54 | // } 55 | // let rest_buf = &rest_buf[(2 + cipher_len)..]; 56 | // let compress_method_len = rest_buf[0] as usize; 57 | // if rest_buf.len() < 1 + compress_method_len { 58 | // return Err(make_error("invalid compress_method_len")); 59 | // } 60 | // let rest_buf = &rest_buf[(1 + compress_method_len)..]; 61 | // if rest_buf.len() < 2 { 62 | // return Err(make_error("invalid after compress_method")); 63 | // } 64 | // let mut ext_len = rest_buf[0] as usize; 65 | // ext_len = (ext_len << 8) + rest_buf[1] as usize; 66 | // let rest_buf = &rest_buf[2..]; 67 |
// if rest_buf.len() < ext_len { 68 | // return Err(make_error("invalid ext_len")); 69 | // } 70 | // if ext_len == 0 { 71 | // return Err(make_error("no extension in client_hello")); 72 | // } 73 | // let mut ext_buf = rest_buf; 74 | // loop { 75 | // if ext_buf.len() < 4 { 76 | // return Err(make_error("invalid ext buf len")); 77 | // } 78 | // let mut extension = ext_buf[0] as usize; 79 | // extension = (extension << 8) + ext_buf[1] as usize; 80 | // let mut length = ext_buf[2] as usize; 81 | // length = (length << 8) + ext_buf[3] as usize; 82 | // ext_buf = &ext_buf[4..]; 83 | // if ext_buf.len() < length { 84 | // return Err(make_error("invalid ext buf content")); 85 | // } 86 | // if extension == 0 { 87 | // if length < 2 { 88 | // return Err(make_error("invalid ext buf length")); 89 | // } 90 | // let mut num_names = ext_buf[0] as usize; 91 | // num_names = (num_names << 8) + ext_buf[1] as usize; 92 | // let mut data = &ext_buf[2..]; 93 | // for _ in 0..num_names { 94 | // if data.len() < 3 { 95 | // return Err(make_error("invalid ext data length")); 96 | // } 97 | // let name_type = data[0]; 98 | // let mut name_len = data[1] as usize; 99 | // name_len = (name_len << 8) + data[2] as usize; 100 | // data = &data[3..]; 101 | // if data.len() < name_len { 102 | // return Err(make_error("invalid ext name data")); 103 | // } 104 | // if name_type == 0 { 105 | // let server_name = String::from_utf8_lossy(&data[0..name_len]); 106 | // debug!("####Peek SNI:{}", server_name); 107 | // return Ok((String::from(server_name), peek_buf)); 108 | // } 109 | // data = &data[name_len..]; 110 | // } 111 | // } 112 | // ext_buf = &ext_buf[length..]; 113 | // } 114 | // } 115 | 116 | /* Peeks up to 4096 bytes from inbound without consuming them and parses them as a TLS ClientHello record. Returns the first host_name entry (name_type 0) of the server_name extension, lossily decoded as UTF-8. Errors if fewer bytes were peeked than the 5-byte record header plus the declared record length, if the record is not a ClientHello, or if no host_name is found. NOTE(review): after the length check, the slices come from the whole 4096-byte buffer rather than the peeked record. The bounds checks therefore compare against the buffer size and can walk into zero padding past the peeked bytes. Also, the 2 bytes read as num_names are the server_name_list byte length (RFC 6066), not an entry count. A ClientHello split across TCP segments is rejected rather than awaited. */ pub async fn peek_sni(inbound: &mut TcpStream) -> Result { 117 | let mut peek_buf: Vec = Vec::new(); 118 | peek_buf.resize(4096, 0); 119 | let ver_len: usize = 5; 120 | let mut peek_cursor: usize = 0; 121 | let peek_n = inbound.peek(peek_buf.as_mut_slice()).await?; 122 | if peek_n < ver_len { 123 |
return Err(anyhow!("no sufficient peek space for sni")); 124 | } 125 | let mut n = peek_buf[3] as u16; 126 | n = (n << 8) + peek_buf[4] as u16; 127 | if n < 42 { 128 | return Err(anyhow!("no sufficient space for sni")); 129 | } 130 | peek_cursor += ver_len; 131 | if peek_n < (peek_cursor + n as usize) { 132 | return Err(anyhow!("no sufficient peek buffer space for sni")); 133 | } 134 | if peek_buf[peek_cursor] != 0x01 { 135 | return Err(anyhow!("not clienthello handshake")); 136 | } 137 | let rest_buf = &peek_buf.as_slice()[peek_cursor + 38..]; 138 | let sid_len = rest_buf[0] as usize; 139 | let rest_buf = &rest_buf[(1 + sid_len)..]; 140 | if rest_buf.len() < 2 { 141 | return Err(anyhow!("no sufficient space for sni")); 142 | } 143 | let mut cipher_len = rest_buf[0] as usize; 144 | cipher_len = (cipher_len << 8) + rest_buf[1] as usize; 145 | if cipher_len % 2 == 1 || rest_buf.len() < 3 + cipher_len { 146 | return Err(anyhow!("invalid cipher_len")); 147 | } 148 | let rest_buf = &rest_buf[(2 + cipher_len)..]; 149 | let compress_method_len = rest_buf[0] as usize; 150 | if rest_buf.len() < 1 + compress_method_len { 151 | return Err(anyhow!("invalid compress_method_len")); 152 | } 153 | let rest_buf = &rest_buf[(1 + compress_method_len)..]; 154 | if rest_buf.len() < 2 { 155 | return Err(anyhow!("invalid after compress_method")); 156 | } 157 | let mut ext_len = rest_buf[0] as usize; 158 | ext_len = (ext_len << 8) + rest_buf[1] as usize; 159 | let rest_buf = &rest_buf[2..]; 160 | if rest_buf.len() < ext_len { 161 | return Err(anyhow!("invalid ext_len")); 162 | } 163 | if ext_len == 0 { 164 | return Err(anyhow!("no extension in client_hello")); 165 | } 166 | let mut ext_buf = rest_buf; 167 | loop { 168 | if ext_buf.len() < 4 { 169 | return Err(anyhow!("invalid ext buf len")); 170 | } 171 | let mut extension = ext_buf[0] as usize; 172 | extension = (extension << 8) + ext_buf[1] as usize; 173 | let mut length = ext_buf[2] as usize; 174 | length = (length << 8) + ext_buf[3]
as usize; 175 | ext_buf = &ext_buf[4..]; 176 | if ext_buf.len() < length { 177 | return Err(anyhow!("invalid ext buf content")); 178 | } 179 | if extension == 0 { 180 | if length < 2 { 181 | return Err(anyhow!("invalid ext buf length")); 182 | } 183 | let mut num_names = ext_buf[0] as usize; 184 | num_names = (num_names << 8) + ext_buf[1] as usize; 185 | let mut data = &ext_buf[2..]; 186 | for _ in 0..num_names { 187 | if data.len() < 3 { 188 | return Err(anyhow!("invalid ext data length")); 189 | } 190 | let name_type = data[0]; 191 | let mut name_len = data[1] as usize; 192 | name_len = (name_len << 8) + data[2] as usize; 193 | data = &data[3..]; 194 | if data.len() < name_len { 195 | return Err(anyhow!("invalid ext name data")); 196 | } 197 | if name_type == 0 { 198 | let server_name = String::from_utf8_lossy(&data[0..name_len]); 199 | tracing::info!("Peek SNI:{}", server_name); 200 | return Ok(String::from(server_name)); 201 | } 202 | data = &data[name_len..]; 203 | } 204 | } 205 | ext_buf = &ext_buf[length..]; 206 | } 207 | } 208 | 209 | /* Uses the SNI peeked from the ClientHello waiting on inbound, with ":443" appended, as the proxy target. It then queues Message::open_stream on sender, handing over the socket with the ClientHello still unread. Errors when no SNI can be peeked (the peek error itself is discarded) or when sender is closed. */ pub async fn handle_tls( 210 | tunnel_id: u32, 211 | mut inbound: TcpStream, 212 | sender: mpsc::UnboundedSender, 213 | ) -> Result<()> { 214 | let target_addr = match peek_sni(&mut inbound).await { 215 | Ok(mut sni) => { 216 | sni.push_str(":443"); 217 | sni 218 | } 219 | Err(_) => String::from(""), 220 | }; 221 | if target_addr.is_empty() { 222 | return Err(anyhow!("no sni found")); 223 | } 224 | tracing::info!("[{}]Handle TLS proxy to {} ", tunnel_id, target_addr); 225 | let msg = Message::open_stream(inbound, target_addr, None); 226 | sender.send(msg)?; 227 | Ok(()) 228 | } 229 | -------------------------------------------------------------------------------- /src/mux/connection.rs: -------------------------------------------------------------------------------- 1 | use crate::mux::stream::MuxStream; 2 | use anyhow::{anyhow, Result}; 3 | use std::collections::{HashMap, VecDeque}; 4 | use std::sync::atomic::{AtomicU32,
Ordering}; 5 | use tokio::io::{AsyncRead, AsyncWrite}; 6 | use tokio::sync::mpsc; 7 | use tokio::sync::oneshot; 8 | use tracing_subscriber::fmt::format; 9 | 10 | use super::event; 11 | 12 | use super::stream::Control; 13 | /// Capacity of each per-stream data channel; the control loop waits on `send` while a stream's buffer is full. 14 | const DEFAULT_STREAM_CHANNEL_SIZE: usize = 16; 15 | /// Handle to a multiplexed connection; each call is forwarded as a `Control` message to the I/O task spawned by `new`. 16 | pub struct Connection { 17 | ev_writer: mpsc::Sender, 18 | stream_id_seed: AtomicU32, 19 | } 20 | /// Connection side, fixing stream-id parity: `Client` allocates 0, 2, 4, ... and `Server` 1, 3, 5, ... 21 | pub enum Mode { 22 | Server, 23 | Client, 24 | } 25 | 26 | impl Connection { 27 | pub fn new( // Spawns handle_mux_connection over `r`/`w`; `id` only labels metrics and logs. 28 | r: R, 29 | w: W, 30 | mode: Mode, 31 | id: u32, 32 | ) -> Self { 33 | let (sender_orig, receiver) = mpsc::channel::(4096); 34 | let sender = sender_orig.clone(); 35 | tokio::spawn(async move { 36 | handle_mux_connection(id, r, w, receiver, sender).await; 37 | }); 38 | match mode { 39 | Mode::Client => Self { 40 | ev_writer: sender_orig, 41 | stream_id_seed: AtomicU32::new(0), 42 | }, 43 | Mode::Server => Self { 44 | ev_writer: sender_orig, 45 | stream_id_seed: AtomicU32::new(1), 46 | }, 47 | } 48 | } 49 | pub async fn ping(&self) -> Result<()> { // Asks the control loop to write a PING frame. 50 | if let Err(e) = self.ev_writer.send(Control::Ping).await { 51 | return Err(anyhow::Error::new(e)); 52 | } 53 | Ok(()) 54 | } 55 | pub async fn open_stream(&self) -> Result { // Allocates the next local id; the control loop registers it and writes SYN. 56 | let (sender, receiver) = mpsc::channel::>>(DEFAULT_STREAM_CHANNEL_SIZE); 57 | let id = self.stream_id_seed.fetch_add(2, Ordering::SeqCst); 58 | let stream = MuxStream::new(id, self.ev_writer.clone(), receiver); 59 | if let Err(e) = self 60 | .ev_writer 61 | .send(Control::NewStream((id, sender, None))) 62 | .await 63 | { 64 | return Err(anyhow::Error::new(e)); 65 | } 66 | Ok(stream) 67 | } 68 | pub async fn accept_stream(&self) -> Result { // Resolves with the next peer-opened stream; only one accept may be pending. 69 | let (sender, receiver) = oneshot::channel::>(); 70 | if let Err(e) = self.ev_writer.send(Control::AcceptStream(sender)).await { 71 | return Err(anyhow::Error::new(e)); 72 | } 73 | match receiver.await { 74 | Ok(v) => v, 75 | Err(e) => Err(anyhow::Error::new(e)), 76 | } 77 | } 78 | } 79 | /// Runs one multiplexed connection: `read_connection_fut` decodes frames from `r` into `Control` messages, while `read_ctrl_fut` owns `w` and the per-stream senders and serves both local requests and decoded frames. 80 | async fn handle_mux_connection( 81 |
conn_id: u32, 82 | r: R, 83 | mut w: W, 84 | mut ev_reader: mpsc::Receiver, 85 | ev_writer_orig: mpsc::Sender, 86 | ) { 87 | let ev_writer = ev_writer_orig.clone(); 88 | let read_connection_fut = async move { 89 | let mut buf_reader = tokio::io::BufReader::new(r); 90 | while let Ok(ev) = event::read_event(&mut buf_reader).await { 91 | match ev.header.flags() { 92 | event::FLAG_SYN => { 93 | //tracing::info!("recv syn:{}", ev.header.stream_id); 94 | let (sender, receiver) = mpsc::channel::>>(DEFAULT_STREAM_CHANNEL_SIZE); 95 | let _ = ev_writer 96 | .send(Control::NewStream(( 97 | ev.header.stream_id, 98 | sender, 99 | Some(receiver), 100 | ))) 101 | .await; 102 | } 103 | event::FLAG_FIN => { 104 | let _ = ev_writer 105 | .send(Control::StreamShutdown(ev.header.stream_id, true)) 106 | .await; 107 | } 108 | event::FLAG_DATA => { 109 | // if ev.header.flags() == event::FLAG_FIN { 110 | // tracing::info!("recv fin:{}", ev.header.stream_id); 111 | // } 112 | let _ = ev_writer 113 | .send(Control::StreamData(ev.header.stream_id, ev.body, true)) 114 | .await; 115 | } 116 | event::FLAG_PING => { 117 | // Keep-alive only; nothing to forward. 118 | } 119 | _ => { 120 | tracing::error!( 121 | "Unexpected event:{}/{}", 122 | ev.header.flags(), 123 | ev.header.stream_id 124 | ); 125 | } 126 | } 127 | } 128 | let _ = ev_writer.send(Control::Close).await; 129 | }; 130 | // Control loop: the only owner of `w` and of the stream table. 131 | let ev_writer = ev_writer_orig.clone(); 132 | let read_ctrl_fut = async move { 133 | let labels: [(&str, String); 1] = [("idx", conn_id.to_string())]; 134 | let mut incoming_streams: VecDeque = VecDeque::new(); 135 | let mut accept_callback: Option>> = None; 136 | let mut stream_senders: HashMap>>> = HashMap::new(); 137 | 138 | while let Some(ctrl) = ev_reader.recv().await { 139 | metrics::gauge!("mux.streams", stream_senders.len() as f64, &labels); 140 | match ctrl { 141 | Control::AcceptStream(callback) => { 142 | if accept_callback.is_some() { 143 | let _ = callback.send(Err(anyhow!("duplicate accept"))); 144 | continue; 145 | } 146 |
accept_callback = Some(callback); 147 | } 148 | Control::NewStream((sid, sender, receiver)) => { 149 | match stream_senders.try_insert(sid, sender) { 150 | Err(e) => { 151 | tracing::error!("Duplicate stream id:{}", sid); 152 | let _ = e.value.send(Some(Vec::new())).await; 153 | } 154 | _ => { 155 | if let Some(receiver) = receiver { 156 | let stream = 157 | MuxStream::new(sid, ev_writer.clone(), receiver); 158 | incoming_streams.push_back(stream); 159 | } else { 160 | let ev = event::new_syn_event(sid); 161 | if let Err(e) = event::write_event(&mut w, ev).await { 162 | tracing::error!("write syn failed:{}", e); 163 | break; 164 | } 165 | } 166 | } 167 | } 168 | } 169 | Control::StreamData(sid, data, incoming) => { 170 | match stream_senders.get(&sid) { 171 | Some(stream_sender) => { 172 | if incoming { 173 | let data_len = data.len(); 174 | if let Err(e) = stream_sender.send(Some(data)).await { 175 | // The stream's receiver was dropped; this chunk is discarded. 176 | tracing::error!( 177 | "Stream:{}/{} send error:{} with data len:{}", 178 | conn_id, 179 | sid, 180 | e, 181 | data_len 182 | ); 183 | } 184 | } else { 185 | let ev = event::new_data_event(sid, data); 186 | if let Err(e) = event::write_event(&mut w, ev).await { 187 | tracing::error!("write stream data failed:{}", e); 188 | break; 189 | } 190 | } 191 | } 192 | None => { 193 | if !data.is_empty() { 194 | tracing::error!( 195 | "No stream:{}/{} found for data incoming:{} with data len:{}", 196 | conn_id, sid, 197 | incoming, 198 | data.len() 199 | ); 200 | } 201 | } 202 | } 203 | } 204 | Control::StreamShutdown(sid, remote) => match stream_senders.get(&sid) { 205 | Some(sender) => { 206 | if !remote { 207 | tracing::info!("[{}/{}]Stream shutdown write.", conn_id, sid); 208 | //stream_senders.remove(&sid); 209 | let ev = event::new_fin_event(sid); 210 | if let Err(e) = event::write_event(&mut w, ev).await { 211 | tracing::error!("write fin failed:{}", e); 212 | break; 213 | } 214 | } else { 215 | tracing::info!("[{}/{}]Stream shutdown read.", conn_id,
sid); 216 | let _ = sender.send(Some(Vec::new())).await; 217 | //stream_senders.remove(&sid); 218 | } 219 | } 220 | None => {} 221 | }, 222 | Control::StreamClose(sid) => { 223 | tracing::info!("[{}/{}]Stream close.", conn_id, sid); 224 | match stream_senders.remove(&sid) { 225 | Some(sender) => { 226 | let _ = sender.send(None).await; 227 | } 228 | None => { 229 | // Unknown or already-closed stream; nothing to do. 230 | } 231 | } 232 | } 233 | Control::Ping => { 234 | let ev = event::new_ping_event(); 235 | if let Err(e) = event::write_event(&mut w, ev).await { 236 | tracing::error!("write ping failed:{}", e); 237 | break; 238 | } 239 | } 240 | Control::Close => { 241 | break; 242 | } 243 | } 244 | 245 | if accept_callback.is_some() && !incoming_streams.is_empty() { 246 | let stream = incoming_streams.pop_front().unwrap(); 247 | let _ = accept_callback.unwrap().send(Ok(stream)); 248 | accept_callback = None; 249 | } 250 | } 251 | 252 | // Connection is gone: signal EOF (an empty chunk, as for a remote FIN) to every remaining stream, not just one. 253 | for (_, sender) in stream_senders.drain() { 254 | let _ = sender.send(Some(Vec::new())).await; 255 | } 256 | if let Some(callback) = accept_callback.take() { 257 | let _ = callback.send(Err(anyhow!("connection closed"))); 258 | // A waiting accept_stream() must not hang once the connection is closed. 259 | } 260 | }; 261 | tokio::join!(read_connection_fut, read_ctrl_fut); 262 | } 263 | -------------------------------------------------------------------------------- /src/tunnel/client.rs: -------------------------------------------------------------------------------- 1 | use metrics::{decrement_gauge, increment_gauge}; 2 | use std::io; 3 | use std::net::ToSocketAddrs; 4 | use std::path::PathBuf; 5 | use std::sync::Arc; 6 | use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; 7 | use tokio::net::TcpStream; 8 | use tokio::sync::mpsc; 9 | use tokio_rustls::TlsConnector; 10 | use url::Url; 11 | 12 | use anyhow::anyhow; 13 | 14 | use crate::mux::event::OpenStreamEvent; 15 | use crate::mux::MuxStream; 16 | use crate::mux::{self, event}; 17 | use crate::tunnel::stream::Stream; 18 | use
crate::tunnel::ALPN_QUIC_HTTP; 19 | /// A pending proxy request: the accepted local TCP stream, the remote target, and optional bytes to send before relaying. 20 | struct OpenStreamRequest { 21 | tcp_stream: tokio::net::TcpStream, 22 | event: OpenStreamEvent, 23 | payload: Option>, 24 | } 25 | 26 | impl OpenStreamRequest { 27 | pub fn from(stream: tokio::net::TcpStream, target: String, payload: Option>) -> Self { 28 | Self { 29 | tcp_stream: stream, 30 | event: OpenStreamEvent { 31 | proto: String::from("tcp"), 32 | addr: target, 33 | }, 34 | payload, 35 | } 36 | } 37 | } 38 | /// Requests consumed by the background `mux_client_loop` task. 39 | pub enum Message { 40 | OpenStream(OpenStreamRequest), 41 | HealthCheck, 42 | } 43 | impl Message { 44 | pub fn open_stream( 45 | stream: tokio::net::TcpStream, 46 | target: String, 47 | payload: Option>, 48 | ) -> Message { 49 | let req = OpenStreamRequest::from(stream, target, payload); 50 | Message::OpenStream(req) 51 | } 52 | } 53 | /// One transport connection (QUIC, or TLS + `mux::Connection`) that can be (re)connected, pinged, and asked for bidirectional streams. 54 | trait MuxConnection { 55 | type SendStream: AsyncWrite + Unpin + Send; 56 | type RecvStream: AsyncRead + Unpin + Send; 57 | async fn ping(&mut self) -> anyhow::Result<()>; 58 | async fn connect(&mut self, url: &Url, key_path: &PathBuf, host: &String) 59 | -> anyhow::Result<()>; 60 | async fn open_stream(&mut self) -> anyhow::Result<(Self::SendStream, Self::RecvStream)>; 61 | fn is_valid(&self) -> bool; 62 | // async fn accept_stream(&self) -> anyhow::Result<(Self::SendStream, Self::RecvStream)>; 63 | } 64 | /// A single QUIC connection on a shared endpoint; `inner` is None until `connect` succeeds or after `open_stream` fails. 65 | pub struct QuicInnerConnection { 66 | inner: Option, 67 | endpoint: Arc, 68 | } 69 | 70 | impl MuxConnection for QuicInnerConnection { 71 | type SendStream = quinn::SendStream; 72 | type RecvStream = quinn::RecvStream; 73 | fn is_valid(&self) -> bool { 74 | self.inner.is_some() 75 | } 76 | async fn ping(&mut self) -> anyhow::Result<()> { 77 | match &mut self.inner { 78 | None => Err(anyhow!("null connection")), 79 | Some(_) => { 80 | let _ = self.open_stream().await?; 81 | Ok(()) 82 | } 83 | } 84 | } 85 | async fn connect( 86 | &mut self, 87 | url: &Url, 88 | key_path: &PathBuf, 89 | host: &String, 90 | ) -> anyhow::Result<()> { 91 | match &mut self.inner { 92 |
None => match new_quic_connection(&self.endpoint, url, host).await { 93 | Ok(c) => { 94 | self.inner = Some(c); 95 | 96 | Ok(()) 97 | } 98 | Err(e) => Err(e), 99 | }, 100 | Some(_) => Err(anyhow!("non null connection")), 101 | } 102 | } 103 | async fn open_stream(&mut self) -> anyhow::Result<(Self::SendStream, Self::RecvStream)> { 104 | match &mut self.inner { 105 | None => Err(anyhow!("null connection")), 106 | Some(c) => match c.open_bi().await { 107 | Err(e) => { 108 | let _ = c.close( 109 | quinn::VarInt::from_u32(48100), 110 | "open stream failed".as_bytes(), 111 | ); 112 | self.inner = None; 113 | Err(e.into()) 114 | } 115 | Ok((send, recv)) => Ok((send, recv)), 116 | }, 117 | } 118 | } 119 | } 120 | /// A single TLS connection wrapped in `mux::Connection`; `inner` is None until `connect` succeeds or after `open_stream` fails. 121 | pub struct TlsInnerConnection { 122 | inner: Option, 123 | id: u32, 124 | } 125 | 126 | impl MuxConnection for TlsInnerConnection { 127 | type SendStream = tokio::io::WriteHalf; 128 | type RecvStream = tokio::io::ReadHalf; 129 | fn is_valid(&self) -> bool { 130 | self.inner.is_some() 131 | } 132 | 133 | async fn ping(&mut self) -> anyhow::Result<()> { 134 | match &mut self.inner { 135 | None => Err(anyhow!("null connection")), 136 | Some(c) => c.ping().await, 137 | } 138 | } 139 | async fn connect( 140 | &mut self, 141 | url: &Url, 142 | key_path: &PathBuf, 143 | host: &String, 144 | ) -> anyhow::Result<()> { 145 | if self.inner.is_some() { 146 | return Err(anyhow!("non null connection")); 147 | } 148 | let c = new_tls_connection(url, key_path, host).await?; 149 | let (r, w) = tokio::io::split(c); 150 | let mux_conn = mux::Connection::new(r, w, mux::Mode::Client, self.id); 151 | // Like QuicInnerConnection::connect, a live connection is never silently replaced. 152 | self.inner = Some(mux_conn); 153 | Ok(()) 154 | } 155 | async fn open_stream(&mut self) -> anyhow::Result<(Self::SendStream, Self::RecvStream)> { 156 | match &mut self.inner { 157 | None => Err(anyhow!("null connection")), 158 | Some(c) => match c.open_stream().await { 159 | Ok(stream) => { 160 | let (r, w) = tokio::io::split(stream); 161 | Ok((w, r)) 162 | } 163 | Err(e) => { 164 | self.inner = None; 165 |
tracing::error!("failed to open stream: {}", e); 166 | Err(e) 167 | } 168 | }, 169 | } 170 | } 171 | } 172 | /// Pool-level operations used by `mux_client_loop`. 173 | trait MuxClientTrait { 174 | type SendStream: AsyncWrite + Unpin + Send; 175 | type RecvStream: AsyncRead + Unpin + Send; 176 | async fn open_stream(&mut self) -> anyhow::Result<(Self::SendStream, Self::RecvStream)>; 177 | async fn health_check(&mut self) -> anyhow::Result<()>; 178 | } 179 | /// Round-robin pool of transport connections to `url`; `host` and `cert` are reused for reconnects. 180 | pub struct MuxClient { 181 | url: url::Url, 182 | conns: Vec, 183 | host: String, 184 | cursor: usize, 185 | cert: Option, 186 | } 187 | 188 | impl MuxClientTrait for MuxClient { 189 | type SendStream = T::SendStream; 190 | type RecvStream = T::RecvStream; 191 | async fn open_stream(&mut self) -> anyhow::Result<(Self::SendStream, Self::RecvStream)> { 192 | for _i in 0..self.conns.len() { 193 | let idx = self.cursor % self.conns.len(); 194 | self.cursor += 1; 195 | if let Ok((send, recv)) = self.conns[idx].open_stream().await { 196 | return Ok((send, recv)); 197 | } 198 | } 199 | Err(anyhow!("no available stream")) 200 | } 201 | async fn health_check(&mut self) -> anyhow::Result<()> { 202 | for c in &mut self.conns { 203 | if !c.is_valid() { 204 | if let Err(e) = c 205 | .connect(&self.url, self.cert.as_ref().unwrap(), &self.host) 206 | .await 207 | { 208 | tracing::error!("reconnect error:{}", e); 209 | } 210 | } else { 211 | if let Err(e) = c.ping().await { 212 | tracing::error!("open stream failed:{}", e); 213 | } 214 | } 215 | } 216 | Ok(()) 217 | } 218 | } 219 | /// Serves `Message`s until every sender is dropped: `OpenStream` gets a remote stream plus a spawned relay task; `HealthCheck` reconnects invalid connections and pings the rest. 220 | async fn mux_client_loop( 221 | mut client: T, 222 | mut receiver: mpsc::UnboundedReceiver, 223 | ) where 224 | ::SendStream: 'static, 225 | ::RecvStream: 'static, 226 | { 227 | while let Some(msg) = receiver.recv().await { 228 | match msg { 229 | Message::OpenStream(mut event) => { 230 | tracing::info!("Proxy request to {}", event.event.addr); 231 | if let Ok((mut send, mut recv)) = client.open_stream().await { 232 | increment_gauge!("client_proxy_streams", 1.0); 233 | tokio::spawn(async
move { 234 | // Relay task: send the open-stream event, then any payload, then hand both directions to Stream::transfer. 235 | let (mut local_reader, mut local_writer) = event.tcp_stream.split(); 236 | let ev = event::new_open_stream_event(0, &event.event); 237 | if let Err(e) = event::write_event(&mut send, ev).await { 238 | tracing::error!("write open stream event failed:{}", e); 239 | } else { 240 | let payload_result = match event.payload { Some(payload) => send.write_all(&payload).await, None => Ok(()) }; 241 | if let Err(e) = payload_result { 242 | tracing::error!("write payload failed:{}", e); 243 | } else { 244 | // Open event and payload delivered; start relaying. 245 | let mut stream = Stream::new( 246 | &mut local_reader, &mut local_writer, 247 | &mut recv, 248 | &mut send, 249 | ); 250 | if let Err(e) = stream.transfer().await { 251 | tracing::error!("transfer finish:{}", e); 252 | } 253 | } 254 | } 255 | // No early return above, so every path balances the increment_gauge! done before spawning. 256 | decrement_gauge!("client_proxy_streams", 1.0); 257 | }); 258 | } else { 259 | tracing::error!("create remote proxy stream failed"); 260 | } 261 | } 262 | Message::HealthCheck => { 263 | let _ = client.health_check().await; 264 | } 265 | } 266 | } 267 | } 268 | /// QUIC (scheme `quic`) pool: connects `count` connections and spawns the message loop, returning its sender. Fails on another scheme, endpoint setup error, or if connection 0 fails; later failures are left to health checks. 269 | impl MuxClient { 270 | pub async fn from( 271 | url: &Url, 272 | cert_path: &PathBuf, 273 | host: &String, 274 | count: usize, 275 | ) -> anyhow::Result> { 276 | match url.scheme() { 277 | "quic" => { 278 | let (sender, receiver) = mpsc::unbounded_channel::(); 279 | let mut client: MuxClient = MuxClient { 280 | url: url.clone(), 281 | conns: Vec::new(), 282 | host: String::from(host), 283 | cursor: 0, 284 | cert: Some(cert_path.clone()), 285 | }; 286 | let endpoint = new_quic_endpoint(url, cert_path)?; 287 | let endpoint = Arc::new(endpoint); 288 | for i in 0..count { 289 | let mut quic_conn = QuicInnerConnection { 290 | endpoint: endpoint.clone(), 291 | inner: None, 292 | }; 293 | match quic_conn.connect(url, cert_path, &host).await { 294 | Err(e) if i == 0 => return Err(e), 295 | Err(e) => { 296 | // Keep the slot (inner stays None) so health_check reconnects it later. 297 | tracing::error!("QUIC connection:{} failed:{}", i, e); 298 | } 299 | _ => { 300 | tracing::info!("QUIC connection:{} established!", i); 301 | } 302 | } 303 | client.conns.push(quic_conn); 304 | }
305 | tokio::spawn(mux_client_loop(client, receiver)); 306 | return Ok(sender); 307 | } 308 | _ => { 309 | return Err(anyhow!("unsupported schema:{:?}", url.scheme())); 310 | } 311 | } 312 | } 313 | } 314 | /// TLS (scheme `tls`) pool: connects `count` mux-over-TLS connections and spawns the message loop, returning its sender. Fails on another scheme or if connection 0 fails; later failures are left to health checks. 315 | impl MuxClient { 316 | pub async fn from( 317 | url: &Url, 318 | cert_path: &PathBuf, 319 | host: &String, 320 | count: usize, 321 | ) -> anyhow::Result> { 322 | match url.scheme() { 323 | "tls" => { 324 | let (sender, receiver) = mpsc::unbounded_channel::(); 325 | let mut client: MuxClient = MuxClient { 326 | url: url.clone(), 327 | conns: Vec::new(), 328 | host: String::from(host), 329 | cursor: 0, 330 | cert: Some(cert_path.clone()), 331 | }; 332 | for i in 0..count { 333 | let mut tls_conn: TlsInnerConnection = TlsInnerConnection { 334 | inner: None, 335 | id: i as u32, 336 | }; 337 | match tls_conn.connect(url, cert_path, &host).await { 338 | Err(e) if i == 0 => return Err(e), 339 | Err(e) => { 340 | // Keep the slot (inner stays None) so health_check reconnects it later. 341 | tracing::error!("TLS connection:{} failed:{}", i, e); 342 | } 343 | _ => { 344 | tracing::info!("TLS connection:{} established!", i); 345 | } 346 | } 347 | client.conns.push(tls_conn); 348 | } 349 | tokio::spawn(mux_client_loop(client, receiver)); 350 | return Ok(sender); 351 | } 352 | _ => { 353 | return Err(anyhow!("unsupported schema:{:?}", url.scheme())); 354 | } 355 | } 356 | } 357 | } 358 | /// Client QUIC endpoint that trusts only the certificate read from `cert_path`. 359 | fn new_quic_endpoint(url: &Url, cert_path: &PathBuf) -> anyhow::Result { 360 | let mut roots = rustls::RootCertStore::empty(); 361 | roots.add(&rustls::Certificate(std::fs::read(cert_path)?))?; 362 | let mut client_crypto = rustls::ClientConfig::builder() 363 | .with_safe_defaults() 364 | .with_root_certificates(roots) 365 | .with_no_client_auth(); 366 | 367 | client_crypto.alpn_protocols = ALPN_QUIC_HTTP.iter().map(|&x| x.into()).collect(); 368 | 369 | let client_config = quinn::ClientConfig::new(Arc::new(client_crypto)); 370 | let mut endpoint = quinn::Endpoint::client("[::]:0".parse().unwrap())?; 371 | endpoint.set_default_client_config(client_config); 372 | Ok(endpoint) 373 | } 374 | async fn
new_quic_connection( 375 | endpoint: &quinn::Endpoint, 376 | url: &Url, 377 | host: &String, 378 | ) -> anyhow::Result { // Resolves `url` without blocking (default port 4433) and dials it with `host` as the server name. 379 | let remote_host = url.host_str().ok_or_else(|| anyhow!("missing host in url:{}", url))?; 380 | let remote = tokio::net::lookup_host((remote_host, url.port().unwrap_or(4433))) 381 | .await? 382 | .next() 383 | .ok_or_else(|| anyhow!("couldn't resolve to an address"))?; 384 | let conn: quinn::Connection = endpoint 385 | .connect(remote, host)? 386 | .await 387 | .map_err(|e: quinn::ConnectionError| anyhow!("failed to connect: {}", e))?; 388 | Ok(conn) 389 | } 390 | /// TCP + TLS connection to `url` (default port 4433) for server name `domain`, trusting only the certificate at `cert_path`. 391 | async fn new_tls_connection( 392 | url: &Url, 393 | cert_path: &PathBuf, 394 | domain: &String, 395 | ) -> anyhow::Result> { 396 | let remote_host = url.host_str().ok_or_else(|| anyhow!("missing host in url:{}", url))?; 397 | let remote = tokio::net::lookup_host((remote_host, url.port().unwrap_or(4433))) 398 | .await?.next() 399 | .ok_or_else(|| anyhow!("couldn't resolve to an address"))?; 400 | let mut roots = tokio_rustls::rustls::RootCertStore::empty(); 401 | roots.add(&tokio_rustls::rustls::Certificate(std::fs::read( 402 | cert_path, 403 | )?))?; 404 | let mut client_crypto = tokio_rustls::rustls::ClientConfig::builder() 405 | .with_safe_defaults() 406 | .with_root_certificates(roots) 407 | .with_no_client_auth(); 408 | // Same ALPN list as new_quic_endpoint. 409 | client_crypto.alpn_protocols = ALPN_QUIC_HTTP.iter().map(|&x| x.into()).collect(); 410 | 411 | let connector = TlsConnector::from(Arc::new(client_crypto)); 412 | let stream = TcpStream::connect(&remote).await?; 413 | 414 | let domain = tokio_rustls::rustls::ServerName::try_from(domain.as_str()) 415 | .map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "invalid dnsname"))? 416 | .to_owned(); 417 | 418 | let stream: tokio_rustls::client::TlsStream = 419 | connector.connect(domain, stream).await?; 420 | Ok(stream) 421 | } 422 | --------------------------------------------------------------------------------