├── .github ├── FUNDING.yml └── workflows │ └── build-package.yml ├── .gitignore ├── Cargo.lock ├── Cargo.toml ├── LICENSE ├── MIGRATION.md ├── README.md ├── build.rs ├── conf ├── config.env ├── moproxy.service ├── policy.rules ├── proxy.ini └── simple_score.lua ├── src ├── cli.rs ├── client │ ├── connect.rs │ ├── mod.rs │ └── tls_parser.rs ├── futures_stream.rs ├── lib.rs ├── linux │ ├── mod.rs │ ├── systemd.rs │ └── tcp.rs ├── main.rs ├── monitor │ ├── alive_test.rs │ ├── graphite.rs │ ├── mod.rs │ └── traffic.rs ├── policy │ ├── capabilities.rs │ ├── mod.rs │ └── parser.rs ├── proxy │ ├── copy.rs │ ├── http.rs │ ├── mod.rs │ └── socks5.rs ├── server.rs └── web │ ├── helpers.rs │ ├── index.html │ ├── mod.rs │ ├── open_metrics.rs │ └── rich.rs └── tests └── socks5.rs /.github/FUNDING.yml: -------------------------------------------------------------------------------- 1 | # These are supported funding model platforms 2 | 3 | github: # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2] 4 | patreon: xierch 5 | open_collective: # Replace with a single Open Collective username 6 | ko_fi: # Replace with a single Ko-fi username 7 | tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel 8 | community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry 9 | liberapay: # Replace with a single Liberapay username 10 | issuehunt: # Replace with a single IssueHunt username 11 | otechie: # Replace with a single Otechie username 12 | custom: # Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2'] 13 | -------------------------------------------------------------------------------- /.github/workflows/build-package.yml: -------------------------------------------------------------------------------- 1 | name: build-package 2 | 3 | on: 4 | push: 5 | tags: 6 | - 'v*' 7 | 8 | jobs: 9 | release: 10 | name: Release - ${{ matrix.platform.release_for }} 11 | if: startsWith(github.ref, 
'refs/tags/v') 12 | strategy: 13 | fail-fast: false 14 | matrix: 15 | platform: 16 | - release_for: Linux-x86_64 17 | os: ubuntu-latest 18 | target: x86_64-unknown-linux-gnu 19 | suffix: linux_amd64.bin 20 | toolchain: stable 21 | 22 | - release_for: Windows-x86_64 23 | os: windows-latest 24 | target: x86_64-pc-windows-msvc 25 | suffix: windows_amd64.exe 26 | toolchain: stable 27 | 28 | - release_for: Linux-x86_64-musl 29 | os: ubuntu-latest 30 | target: x86_64-unknown-linux-musl 31 | suffix: linux_x86_64_musl.bin 32 | toolchain: stable 33 | 34 | - release_for: Linux-armv7 35 | os: ubuntu-latest 36 | target: armv7-unknown-linux-gnueabihf 37 | suffix: linux_armv7_gnueabihf.bin 38 | toolchain: stable 39 | 40 | - release_for: Linux-aarch64 41 | os: ubuntu-latest 42 | target: aarch64-unknown-linux-gnu 43 | suffix: linux_aarch64.bin 44 | toolchain: stable 45 | 46 | # Disabled until cross releases a new version 47 | # https://github.com/cross-rs/cross/issues/1222 48 | # 49 | #- release_for: Android-aarch64 50 | # os: ubuntu-latest 51 | # target: aarch64-linux-android 52 | # suffix: linux_aarch64_android.bin 53 | # toolchain: stable 54 | 55 | runs-on: ${{ matrix.platform.os }} 56 | steps: 57 | - name: Checkout 58 | uses: actions/checkout@v4 59 | - name: Setup cache 60 | uses: Swatinem/rust-cache@v2 61 | - name: Install musl-tools 62 | if: contains(matrix.platform.target, 'linux-musl') 63 | run: sudo apt install musl-tools 64 | - name: Install RPM tools 65 | if: contains(matrix.platform.target, 'linux-gnu') 66 | run: sudo apt-get install rpm 67 | - name: Install aarch64 binutils 68 | if: matrix.platform.target == 'aarch64-unknown-linux-gnu' 69 | run: | 70 | sudo apt-get install binutils-aarch64-linux-gnu build-essential crossbuild-essential-arm64 pkg-config libc6-arm64-cross libgcc1-arm64-cross libstdc++6-arm64-cross 71 | sudo dpkg --add-architecture arm64 72 | - name: Build 73 | uses: houseabsolute/actions-rust-cross@v0 74 | with: 75 | command: build 76 | target: ${{ 
matrix.platform.target }} 77 | toolchain: ${{ matrix.platform.toolchain }} 78 | args: --release 79 | strip: yes 80 | - name: Packaging for Debian - x86_64 81 | if: matrix.platform.target == 'x86_64-unknown-linux-gnu' 82 | run: | 83 | cargo install cargo-deb 84 | cargo deb --target=${{ matrix.platform.target }} --no-build 85 | - name: Packaging for Debian - aarch64 86 | if: matrix.platform.target == 'aarch64-unknown-linux-gnu' 87 | run: | 88 | cargo install cargo-deb 89 | cargo deb --target=${{ matrix.platform.target }} --no-build --no-strip 90 | - name: Packaging for RPM - x86_64 91 | if: matrix.platform.target == 'x86_64-unknown-linux-gnu' 92 | run: | 93 | cargo install cargo-generate-rpm 94 | cargo generate-rpm --target=${{ matrix.platform.target }} 95 | - name: Packaging for RPM - aarch64 96 | if: matrix.platform.target == 'aarch64-unknown-linux-gnu' 97 | run: | 98 | cargo install cargo-generate-rpm 99 | cargo generate-rpm --target=${{ matrix.platform.target }} 100 | - name: Packaging binary for Linux 101 | if: contains(matrix.platform.os, 'ubuntu') 102 | run: xz -kfS "_${GITHUB_REF#*/v}_${{ matrix.platform.suffix }}.xz" target/${{ matrix.platform.target }}/release/moproxy 103 | - name: Packaging binary for Windows 104 | if: contains(matrix.platform.os, 'windows') 105 | shell: bash # default pwsh can't expand ${GITHUB_REF#*/v}; `xz -c` streams the archive to stdout for the redirect 106 | run: xz -c target/${{ matrix.platform.target }}/release/moproxy.exe > "moproxy_${GITHUB_REF#*/v}_${{ matrix.platform.suffix }}.xz" 107 | - name: Release 108 | uses: ncipollo/release-action@v1 109 | with: 110 | artifacts: "*.xz,target/**/*.xz,target/**/*.deb,target/**/*.rpm" 111 | draft: true 112 | allowUpdates: true 113 | updateOnlyUnreleased: true 114 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | target/ 2 | **/*.rs.bk 3 | -------------------------------------------------------------------------------- /Cargo.toml:
-------------------------------------------------------------------------------- 1 | [package] 2 | name = "moproxy" 3 | version = "0.5.1" 4 | authors = ["sorz "] 5 | edition = "2021" 6 | description = "Transparent TCP to SOCKSv5/HTTP proxy on Linux written in Rust" 7 | readme = "README.md" 8 | repository = "https://github.com/sorz/moproxy" 9 | license = "MIT" 10 | keywords = ["proxy", "socksv5"] 11 | categories = ["command-line-utilities"] 12 | rust-version = "1.75.0" 13 | 14 | [dependencies] 15 | rand = "0.8" 16 | tokio = { version = "1", features = ["full"] } 17 | tokio-stream = "0.1" 18 | net2 = "0.2" 19 | clap = { version = "4", features = ["derive", "wrap_help"] } 20 | tracing = "0.1" 21 | tracing-subscriber = "0.3" 22 | serde = { version = "1", features = ["rc"] } 23 | serde_json = "1" 24 | serde_derive = "1" 25 | serde_with = "3" 26 | rust-ini = "0.20" 27 | hyper = { version = "1", optional = true, features = [ 28 | "http1", 29 | "server", 30 | ] } 31 | hyper-util = { version = "0.1", features = ["tokio"] } 32 | http-body-util = "0.1" 33 | parking_lot = { version = "0.12", features = ["serde", "deadlock_detection"] } 34 | http = "1" 35 | prettytable-rs = { version = "0.10", default-features = false } 36 | regex = "1" 37 | once_cell = "1" 38 | number_prefix = "0.4" 39 | futures-core = "0.3" 40 | futures-util = "0.3" 41 | httparse = "1" 42 | rlua = { version = "0.19", optional = true } 43 | bytes = "1" 44 | zip = { version = "0.6", optional = true, default-features = false, features = [ 45 | "deflate" 46 | ] } 47 | base64 = "0.21" 48 | nom = "7" 49 | flexstr = { version = "0.9", features = ["serde"] } 50 | anyhow = "1" 51 | ip_network_table-deps-treebitmap = "0.5.0" 52 | 53 | [target.'cfg(target_os = "linux")'.dependencies] 54 | libc = "0.2" 55 | nix = { version = "0.27", features = ["fs", "net", "socket"] } 56 | sd-notify = { version = "0.4", optional = true } 57 | tracing-journald = { version = "0.3", optional = true } 58 | 59 | [features] 60 | default = 
["web_console", "score_script", "systemd", "rich_web"] 61 | web_console = ["hyper"] 62 | rich_web = ["web_console", "zip"] 63 | score_script = ["rlua"] 64 | systemd = ["sd-notify", "tracing-journald"] 65 | 66 | [build-dependencies] 67 | reqwest = { version = "0.11", default-features = false, features = ["rustls-tls", "blocking"] } 68 | 69 | [package.metadata.deb] 70 | section = "net" 71 | priority = "optional" 72 | assets = [ 73 | ["target/release/moproxy", "usr/bin/", "755"], 74 | ["README.md", "usr/share/doc/moproxy/README", "644"], 75 | ["conf/moproxy.service", "usr/lib/systemd/system/", "644"], 76 | ["conf/config.env", "etc/moproxy/", "644"], 77 | ["conf/proxy.ini", "etc/moproxy/", "644"], 78 | ["conf/simple_score.lua", "etc/moproxy/", "644"], 79 | ] 80 | 81 | [package.metadata.generate-rpm] 82 | assets = [ 83 | { source = "target/release/moproxy", dest = "/usr/bin/moproxy", mode = "755" }, 84 | { source = "README.md", dest = "/usr/share/doc/moproxy/README", mode = "644" }, 85 | { source = "conf/moproxy.service", dest = "/usr/lib/systemd/system/moproxy.service", mode = "644" }, 86 | { source = "conf/config.env", dest = "/etc/moproxy/config.env", mode = "644" }, 87 | { source = "conf/proxy.ini", dest = "/etc/moproxy/proxy.ini", mode = "644" }, 88 | { source = "conf/simple_score.lua", dest = "/etc/moproxy/simple_score.lua", mode = "644" }, 89 | ] 90 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017-2021 Shell Chen 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the 
Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /MIGRATION.md: -------------------------------------------------------------------------------- 1 | # Migration guide 2 | 3 | 4 | ## v0.4 to v0.5 5 | 6 | - Multiple values for a single CLI argument now need to be delimited by commas 7 | 8 | Before: `--port 2081 2082 2083` 9 | 10 | After: `--port 2081,2082,2083` 11 | 12 | - `listen ports =` in the proxy list INI file is no longer supported. 13 | Users should migrate them to the proxy selection policy. 14 | 15 | Before: 16 | ```ini 17 | # Proxy list 18 | [server-1] 19 | listen ports=2081 20 | # ... 21 | [server-2] 22 | listen ports=2082 23 | # ... 24 | ``` 25 | 26 | After: 27 | ```ini 28 | # Proxy list 29 | [server-1] 30 | capabilities = cap1 31 | # ... 32 | [server-2] 33 | capabilities = cap2 34 | # ... 35 | ``` 36 | ``` 37 | # Ruleset 38 | listen port 2081 require cap1 39 | listen port 2082 require cap2 40 | ``` 41 | 42 | ## v0.3 to v0.4 43 | 44 | `-h` (listen host) has been renamed to `-b` 45 | 46 | Before: 47 | ```bash 48 | moproxy -h ::1 -p 2080 ... 49 | ``` 50 | 51 | After: 52 | ```bash 53 | moproxy -b ::1 -p 2080 ... 
54 | ``` 55 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # moproxy 2 | 3 | A transparent TCP to SOCKSv5/HTTP proxy on *Linux* written in Rust. 4 | 5 | Features: 6 | 7 | * Transparent TCP proxy with `iptables -j REDIRECT` or `nft redirect to` 8 | * Downstream SOCKSv5 as a supplement to transparent proxy 9 | * Multiple SOCKSv5/HTTP upstream proxy servers 10 | * SOCKS/HTTP-layer alive & latency probe for upstreams 11 | * Prioritize upstreams according to connection quality (latency & error rate) 12 | * Full IPv6 support 13 | * Proxy selection policy (see [conf/policy.rules](conf/policy.rules)) 14 | * Multiple downstream listen ports (for proxy selection policy) 15 | * Remote DNS resolving for TLS with SNI (extract domain name from TLS 16 | handshaking) 17 | * Optional try-in-parallel for TLS (try multiple proxies and choose the one 18 | that responds first) 19 | * Optional status web page (latency, traffic, etc. w/ curl-friendly output) 20 | * Optional [Graphite](https://graphite.readthedocs.io/) and 21 | OpenMetrics ([Prometheus](https://prometheus.io/)) support 22 | (to build fancy dashboards with [Grafana](https://grafana.com/) for example) 23 | * Customizable proxy selection algorithm with Lua script (see 24 | [conf/simple_score.lua](conf/simple_score.lua)). 
25 | 26 | ``` 27 | +-----+ TCP +-----------+ SOCKSv5 +---------+ 28 | | App |------>| firewall | +----------->| Proxy 1 |---> 29 | +-----+ +-----------+ | +---------+ 30 | redirect | | 31 | +-----+ to v | HTTP +---------+ 32 | | App | //=========\\ | +------->| Proxy 2 |---> 33 | +-----+ || ||----+ | +---------+ 34 | | || MOPROXY ||--------+ : 35 | +--------->|| ||-----------··· : 36 | SOCKSv5 \\=========// Selection | +---------+ 37 | | policy +-->| Proxy N |---> 38 | | +---------+ 39 | | 40 | +----------- Direct ------------> 41 | ``` 42 | 43 | ## Breaking changes 44 | 45 | There are CLI and/or configure changes among: 46 | 47 | - [v0.4 => v0.5](MIGRATION.md/#v04-to-v05) 48 | - [v0.3 => v0.4](MIGRATION.md/#v03-to-v04) 49 | 50 | See [MIGRATION.md](MIGRATION.md) 51 | 52 | ## Usage 53 | 54 | ### Print usage 55 | ```bash 56 | moproxy --help 57 | ``` 58 | ### Examples 59 | 60 | Assume there are three SOCKSv5 servers on `localhost:2001`, `localhost:2002`, 61 | and `localhost:2003`, and two HTTP proxy servers listen on `localhost:3128` 62 | and `192.0.2.0:3128`. 63 | Following commands forward all TCP connections that connect to 80 and 443 to 64 | these proxy servers. 
65 | 66 | ```bash 67 | moproxy --port 2080 --socks5 2001,2002,2003 --http 3128,192.0.2.0:3128 68 | 69 | # redirect local-initiated connections 70 | nft add rule nat output tcp dport {80, 443} redirect to 2080 71 | # redirect connections initiated by other hosts (if you are a router) 72 | nft add rule nat prerouting tcp dport {80, 443} redirect to 2080 73 | 74 | # or the legacy iptables equivalent 75 | iptables -t nat -A OUTPUT -p tcp -m multiport --dports 80,443 -j REDIRECT --to-port 2080 76 | iptables -t nat -A PREROUTING -p tcp -m multiport --dports 80,443 -j REDIRECT --to-port 2080 77 | ``` 78 | 79 | A SOCKSv5 server is also launched alongside the transparent proxy on the same port: 80 | ```bash 81 | http_proxy=socks5h://localhost:2080 curl ifconfig.co 82 | ``` 83 | 84 | ### Server list file 85 | Put upstream proxies in a file to avoid messy CLI arguments and enable features 86 | like priority (score base), username/password auth, capabilities, etc. 87 | 88 | [See proxy.ini example](conf/proxy.ini) for details. 89 | 90 | Pass the file path to `moproxy` via the `--list` argument. 91 | 92 | Signal `SIGHUP` will trigger the program to reload the list. 93 | 94 | ### Proxy selection policy file 95 | Let specified connections use only a subset of upstream proxies. 96 | 97 | [See policy.rules example](conf/policy.rules) for details. 98 | 99 | Pass the file path to `moproxy` via the `--policy` argument. 100 | 101 | Signal `SIGHUP` will trigger the program to reload the list. 102 | 103 | ### Custom proxy selection 104 | Proxy servers are sorted by their *score*, which is re-calculated after each 105 | round of alive/latency probing. Servers with lower scores are prioritized. 106 | 107 | The current scoring algorithm is a kind of weighted moving average of latency 108 | with penalty for recent connection errors. This can be replaced with your own 109 | algorithm written in Lua. See [conf/simple_score.lua](conf/simple_score.lua) 110 | for details. 
111 | 112 | Source address–based proxy selection is not directly supported (destination 113 | IPs can be matched with `dst ip` policy rules). One workaround is to let moproxy 114 | bind multiple ports, route each port to different proxy servers with proxy 115 | `capabilities` and `listen port` policy rules, then do address-based selection on your firewall. 116 | 117 | ### Monitoring 118 | Metrics (latency, traffic, number of connections, etc.) are useful for 119 | diagnosis and customizing your own proxy selection. You can access these 120 | metrics with various methods, from a simple web page, curl, to specialized 121 | tools like Graphite or Prometheus. 122 | 123 | `--stats-bind [::1]:8080` turns on the internal stats page, via HTTP, on the 124 | given IP address and port number. It returns an HTML page for web browsers, 125 | or an ASCII table for `curl`. 126 | 127 | The stats page only provides current metrics and a few aggregations. Graphite 128 | (via `--graphite`) or OpenMetrics (via `--stats-bind` then `/metrics`) should 129 | be used if you want a full history. 130 | 131 | Some examples of Prometheus queries (Grafana variant): 132 | 133 | ``` 134 | Inbound bandwidth: 135 | rate(moproxy_proxy_server_bytes_rx_total[$__range]) 136 | 137 | Total outbound traffic: 138 | sum(increase(moproxy_proxy_server_bytes_tx_total[$__range])) 139 | 140 | No. of connection errors per minute: 141 | sum(increase(moproxy_proxy_server_connections_error[1m])) 142 | 143 | Average delay for each proxy server: 144 | avg_over_time(moproxy_proxy_server_dns_delay_seconds[$__interval]) 145 | ``` 146 | 147 | ### Systemd integration 148 | 149 | Sample service file: [conf/moproxy.service](conf/moproxy.service) 150 | 151 | Implemented features: 152 | 153 | - Watchdog 154 | - Reloading (via SIGHUP signal) 155 | - Notify (`Type=notify`, reloading, status string) 156 | 157 | Get simple status without turning on the HTTP stats page: 158 | 159 | ``` 160 | $ systemctl status moproxy 161 | > ... 162 | > Status: "serving (7/11 upstream proxies up)" 163 | > ... 
164 | ``` 165 | 166 | ## Install 167 | 168 | You may download the binary executable file on 169 | [releases page](https://github.com/sorz/moproxy/releases). 170 | 171 | Arch Linux user can install it from 172 | [AUR/moproxy](https://aur.archlinux.org/packages/moproxy/). 173 | 174 | Or compile it manually: 175 | 176 | ```bash 177 | # Install Rust 178 | curl https://sh.rustup.rs -sSf | sh 179 | 180 | # Clone source code 181 | git clone https://github.com/sorz/moproxy 182 | cd moproxy 183 | 184 | # Build 185 | cargo build --release 186 | target/release/moproxy --help 187 | 188 | # If you are in Debian 189 | cargo install cargo-deb 190 | cargo deb 191 | sudo dpkg -i target/debian/*.deb 192 | moproxy --help 193 | ``` 194 | 195 | Refer to [conf/](conf/) for config & systemd service files. 196 | -------------------------------------------------------------------------------- /build.rs: -------------------------------------------------------------------------------- 1 | use std::{ 2 | env, 3 | fs::File, 4 | path::{Path, PathBuf}, 5 | }; 6 | 7 | const ZIP_URL: &str = "https://github.com/sorz/moproxy-web/releases/download/{VERSION}/build.zip"; 8 | const VERSION: &str = "v0.1.8"; 9 | 10 | fn main() { 11 | if env::var("CARGO_FEATURE_RICH_WEB").is_err() { 12 | return; 13 | } 14 | let output_dir = env::var("OUT_DIR").expect("OUT_DIR environment variable not set"); 15 | let zip_path = PathBuf::from(output_dir).join(format!("moproxy-web-{}.zip", VERSION)); 16 | if !zip_path.exists() { 17 | download_zip(&zip_path); 18 | } 19 | println!( 20 | "cargo:rustc-env=MOPROXY_WEB_BUNDLE={}", 21 | zip_path.into_os_string().into_string().unwrap() 22 | ); 23 | } 24 | 25 | fn download_zip(path: &Path) { 26 | let url = ZIP_URL.replace("{VERSION}", VERSION); 27 | let mut resp = reqwest::blocking::get(url) 28 | .expect("error on get moproxy-web bundle") 29 | .error_for_status() 30 | .expect("unexpect HTTP response"); 31 | let mut zip = File::create(path).expect("cannot create file"); 32 | 
resp.copy_to(&mut zip) 33 | .expect("error on download/write out moproxy-web bundle"); 34 | } 35 | -------------------------------------------------------------------------------- /conf/config.env: -------------------------------------------------------------------------------- 1 | ## Configuraion file loaded by moproxy.service 2 | 3 | ## TCP listen address 4 | HOST="::" 5 | PORT="2080" 6 | # Multiple ports: 7 | # PORT="2080,2081" 8 | 9 | ## Web status page listen on 10 | WEB_BIND="127.0.0.1:8080" 11 | # Or a UNIX domain socket starts with "/" 12 | # WEB_BIND="/run/moproxy/web.sock" 13 | 14 | ## List of backend proxy servers 15 | PROXY_LIST="/etc/moproxy/proxy.ini" 16 | 17 | ## Other arguments passed to moproxy daemon 18 | DAEMON_ARGS="--stats-bind ${WEB_BIND}" 19 | 20 | ## Enable remote DNS 21 | # DAEMON_ARGS="${DAEMON_ARGS} --remote-dns" 22 | 23 | ## Enable policy rules 24 | # DAEMON_ARGS="${DAEMON_ARGS} --policy /etc/moproxy/policy.rules" 25 | -------------------------------------------------------------------------------- /conf/moproxy.service: -------------------------------------------------------------------------------- 1 | [Unit] 2 | Description=MoProxy transparent TCP proxy daemon. 
3 | After=network.target 4 | 5 | [Service] 6 | Type=notify 7 | DynamicUser=yes 8 | 9 | WatchdogSec=5 10 | Restart=on-failure 11 | 12 | EnvironmentFile=/etc/moproxy/config.env 13 | ExecStart=/usr/bin/moproxy -b $HOST -p $PORT --list $PROXY_LIST $DAEMON_ARGS 14 | ExecReload=/usr/bin/kill -HUP $MAINPID 15 | 16 | LimitNOFILE=32768 17 | PrivateTmp=true 18 | PrivateDevices=true 19 | ProtectSystem=strict 20 | ProtectHome=true 21 | ProtectKernelModules=true 22 | ProtectControlGroups=true 23 | 24 | ## Wait for the WAN interface to come up 25 | # ExecStartPre=/usr/lib/systemd/systemd-networkd-wait-online --interface=ppp0 26 | 27 | [Install] 28 | WantedBy=multi-user.target 29 | -------------------------------------------------------------------------------- /conf/policy.rules: -------------------------------------------------------------------------------- 1 | ## Example of moproxy policy rulesets 2 | 3 | # One rule per line; comments start with a hash sign (#). 4 | # Each rule is composed of two parts, FILTER and ACTION. 5 | # Keywords and domain names are case-insensitive. 6 | # 7 | # Supported filters: 8 | # - DEFAULT (matches everything / no filter) 9 | # - LISTEN PORT (moproxy's TCP listen port number) 10 | # - DST IP [/] (destination IP address, won't resolve) 11 | # - DST DOMAIN (domain name in TLS SNI or SOCKSv5 request) 12 | # 13 | # Supported actions: 14 | # - REQUIRE [or |...] 
(limit available upstream proxies) 15 | # - DIRECT (do not use proxy, go direct, even if --allow-direct is unset) 16 | # - REJECT (close connection immediately) 17 | # 18 | # Evaluation order: 19 | # For each incoming connection, rules are evaluated in the order of 20 | # their filter type: DEFAULT -> LISTEN PORT -> DST IP -> DST DOMAIN 21 | # 22 | # Multiple matches: 23 | # One connection may be matched by multiple rules, depending on their actions: 24 | # - REQUIRE actions accumulate with each other 25 | # - DIRECT & REJECT are exclusive: they override other actions and are 26 | # overridden by others 27 | # (Can be tweaked by priority) 28 | # 29 | # Action priority: 30 | # One or more exclamation marks (!) after an action promote its priority (up to 5). 31 | # Actions with higher priority always override lower ones. 32 | # 33 | # Example: 34 | # 35 | 36 | # Connections to TCP 8001 require "cap1" in the proxy's capabilities 37 | # TCP 8002 requires "cap1" or "cap2" 38 | # TCP 8003 requires "cap3" only. It ignores all rules without 3 or more "!". 39 | listen port 8001 require cap1 40 | listen port 8002 require cap1 or cap2 41 | listen port 8003 require!!! cap3 42 | 43 | # *.netflix.com goes to proxies with BOTH "streaming" AND "us". 44 | dst domain netflix.com require streaming 45 | dst domain netflix.com require us 46 | 47 | # *.cn will not use any proxy, except *.edu.cn, which requires proxies with "edu" 48 | # (more specific matches override less specific ones) 49 | dst domain cn direct 50 | dst domain edu.cn require edu 51 | 52 | # *.edu.au will match both rules, thus requires BOTH "au" AND "edu" 53 | # However, *.anu.edu.au requires just "au" due to its higher priority 54 | dst domain au require au 55 | dst domain edu.au require edu 56 | dst domain anu.edu.au require! au 57 | 58 | # `dst domain` matches the SOCKSv5 hostname if it exists, or the TLS SNI if 59 | # `--remote-dns` is enabled. An explicit SOCKSv5 hostname takes priority. 
60 | # `dst domain .` will match any domain (but not connections without a domain). 61 | -------------------------------------------------------------------------------- /conf/proxy.ini: -------------------------------------------------------------------------------- 1 | ## Example of moproxy server list file. 2 | 3 | # Each server starts with a unique `[SERVER-TAG]`, 4 | # followed by a list of attributes. 5 | # 6 | # Use `moproxy [...] policy get [..]` to test it. 7 | # 8 | # Common attributes 9 | # - address: IP-addr:port of the server. 10 | # - protocol: HTTP or SOCKSv5. 11 | # - test dns: IP-addr:port of a DNS server with TCP support. 12 | # - score base: A fixed +/- integer added to the server's score. 13 | # - capabilities: List of capabilities, used by --policy rules. 14 | # 15 | # Attributes for SOCKSv5 16 | # - socks username, socks password: 17 | # Username/password authentication (RFC 1929) for upstream proxy 18 | # 19 | # Attributes for HTTP 20 | # - http username, http password: 21 | # HTTP basic access authentication for upstream proxy 22 | # 23 | # `address` and `protocol` are mandatory, others are optional. 24 | 25 | [server-1] 26 | address=127.0.0.1:2001 ;required 27 | protocol=socks5 ;required 28 | ;all other attributes are optional 29 | 30 | [server-2] 31 | address=127.0.0.1:2002 32 | protocol=http 33 | http username = user 34 | http password = pAsSwoRd ;optional upstream HTTP Basic Auth 35 | test dns=127.0.0.53:53 ;use remote's local dns server to calculate delay 36 | capabilities = cap1 cap2 ;used by policy rules 37 | 38 | [server-3] 39 | address=127.0.0.1:2003 40 | protocol=http 41 | ; server-1, server-3 and backup have no capabilities, so they only serve 42 | ; connections that no REQUIRE rule applies to. server-2 (cap1 cap2) can also 43 | ; serve connections requiring cap1 or cap2 (e.g. ports 8001 & 8002 in policy.rules). 44 | 45 | [backup] 46 | address=127.0.0.1:2002 47 | protocol=socks5 48 | socks username = user 49 | socks password = pAsSwoRd 50 | score base=5000 ;add 5k to pull away from preferred server. 
51 | max wait=10 ;waiting up to 10 seconds before giving up. 52 | -------------------------------------------------------------------------------- /conf/simple_score.lua: -------------------------------------------------------------------------------- 1 | -- A simple demo of using a Lua script to customize proxy scoring. 2 | -- Run moproxy with `--score-script /path/to/simple_score.lua` to enable it. 3 | 4 | -- Calculate the score for a given proxy server and delay 5 | -- proxy: a table describing the proxy server 6 | -- delay: probe delay in seconds, as a float; may be nil 7 | -- Return a signed integer score, or nil to disable the proxy 8 | function calc_score(proxy, delay) 9 | -- proxy.addr, proxy.proto, proxy.tag: 10 | -- Basic information about the proxy. 11 | -- proxy.config: 12 | -- Proxy's configs, 13 | -- includes test_dns, max_wait, and score_base. 14 | -- proxy.traffic: 15 | -- tx_bytes: total bytes uploaded to the proxy server 16 | -- rx_bytes: total bytes downloaded from the proxy server 17 | -- proxy.status: 18 | -- delay: the delay before this update, in secs in float. 19 | -- nil = initial value; -1 = timed out. 20 | -- score: the score before this update, may be nil. 21 | -- conn_alive, conn_total, conn_error: connection counters 22 | -- close_history: 23 | -- History of the 64 most recent closed connections, stored as 24 | -- bitmap in a 64-bit int. 0 for closed without any error, 1 for 25 | -- connection closed due to error. The least significant bit is 26 | -- the most recently closed connection.
27 | 28 | -- print out tag & delay for debugging 29 | print(proxy.tag, delay) 30 | 31 | if delay == nil then 32 | -- disable proxy if delay probing failed 33 | return nil 34 | else 35 | -- use delay in milliseconds (delay is in seconds) plus score_base as score 36 | return math.floor(delay * 1000 + proxy.config.score_base) 37 | end 38 | end 39 | -------------------------------------------------------------------------------- /src/cli.rs: -------------------------------------------------------------------------------- 1 | use std::{ 2 | net::{IpAddr, Ipv6Addr, SocketAddr}, 3 | path::PathBuf, 4 | time::Duration, 5 | }; 6 | 7 | use clap::{arg, command, Parser, Subcommand}; 8 | use tracing::metadata::LevelFilter; 9 | 10 | #[derive(Parser, Debug)] 11 | #[command(author, version, about, long_about = None)] 12 | pub(crate) struct CliArgs { 13 | /// Address to bind on 14 | #[arg(short = 'b', long, value_name = "IP-ADDRESS")] 15 | #[arg(default_value_t = Ipv6Addr::UNSPECIFIED.into())] 16 | pub(crate) host: IpAddr, 17 | 18 | /// Port number to bind on. Multiple ports can be delimited by comma (,) 19 | #[arg( 20 | short = 'p', 21 | long, 22 | value_name = "PORTS", 23 | required = true, 24 | value_delimiter = ',' 25 | )] 26 | pub(crate) port: Vec, 27 | 28 | /// SOCKSv5 server list. IP address can omit for localhost. 29 | #[arg( 30 | short = 's', 31 | long = "socks5", 32 | value_name = "SOCKS5-SERVERS", 33 | value_delimiter = ',', 34 | value_parser = parse_socket_addr_default_on_localhost 35 | )] 36 | pub(crate) socks5_servers: Vec, 37 | 38 | /// HTTP proxy server list. IP address can omit for localhost. 39 | #[arg( 40 | short = 't', 41 | long = "http", 42 | value_name = "HTTP-SERVERS", 43 | value_delimiter = ',', 44 | value_parser = parse_socket_addr_default_on_localhost 45 | )] 46 | pub(crate) http_servers: Vec, 47 | 48 | /// INI file contains list of proxy servers. 
49 | #[arg(short = 'l', long = "list", value_name = "SERVER-LIST")] 50 | pub(crate) server_list: Option, 51 | 52 | /// Rule file for proxy selection policy 53 | #[arg(long = "policy", value_name = "POLICY")] 54 | pub(crate) policy: Option, 55 | 56 | /// Period of time to make one probe. 57 | #[arg(short = 'i', long = "probe", value_name = "SECONDS")] 58 | #[arg(default_value_t = 30)] 59 | pub(crate) probe_secs: u64, 60 | 61 | /// Address of a DNS server with TCP support to do delay probing. 62 | #[arg(long, value_name = "IP-ADDR:PORT", default_value = "8.8.8.8:53")] 63 | pub(crate) test_dns: SocketAddr, 64 | 65 | /// Where the web server that shows statistics bind. 66 | #[cfg(feature = "web_console")] 67 | #[arg(long = "stats-bind", value_name = "IP-ADDR:PORT")] 68 | pub(crate) web_bind: Option, 69 | 70 | /// Try to obtain domain name from TLS SNI, and sent it to remote 71 | /// proxy server. Only apply for port number 443. 72 | #[arg(long)] 73 | pub(crate) remote_dns: bool, 74 | 75 | /// Connect and send application data to N proxies in parallel, use 76 | /// the first proxy that return valid data. Currently only support 77 | /// TLS as application layer. Must turn on --remote-dns otherwise it 78 | /// will be ignored. 79 | #[arg(long, value_name = "N", default_value_t = 0)] 80 | pub(crate) n_parallel: usize, 81 | 82 | /// Set TCP congestion control algorithm on local (client) side. 83 | #[cfg(target_os = "linux")] 84 | #[arg(long = "congestion-local", value_name = "ALG-NAME")] 85 | pub(crate) cong_local: Option, 86 | 87 | /// Fallback to direct connect (without proxy) if all proxies failed. 88 | #[arg(long)] 89 | pub(crate) allow_direct: bool, 90 | 91 | /// Send metrics to graphite (carbon) daemon in plaintext format with 92 | /// TCP. 
93 | #[arg(long, value_name = "IP-ADDR:PORT")] 94 | pub(crate) graphite: Option, 95 | 96 | /// Level of verbosity [possible values: off, error, warn, info, debug, 97 | /// trace] 98 | #[arg(long, default_value = "info")] 99 | pub(crate) log_level: LevelFilter, 100 | 101 | /// Lua script that customize proxy score 102 | #[cfg(feature = "score_script")] 103 | #[arg(long, value_name = "LUA-SCRIPT")] 104 | pub(crate) score_script: Option, 105 | 106 | /// Max waiting time in seconds for connection establishment before 107 | /// timeout. Applied for both probe & regular proxy connections. 108 | #[arg(long, value_name = "SECONDS", default_value = "4", value_parser = parse_duration_in_seconds)] 109 | pub(crate) max_wait: Duration, 110 | 111 | #[command(subcommand)] 112 | pub(crate) command: Option, 113 | } 114 | 115 | #[derive(Debug, Subcommand)] 116 | pub(crate) enum Commands { 117 | /// Load & check configure and then exit 118 | Check { 119 | /// Do not try to bind socket 120 | #[arg(long)] 121 | no_bind: bool, 122 | }, 123 | 124 | /// Policy ruleset related commands 125 | Policy { 126 | #[command(subcommand)] 127 | command: PolicyCommands, 128 | }, 129 | } 130 | 131 | #[derive(Debug, Subcommand)] 132 | pub(crate) enum PolicyCommands { 133 | /// Given connection info, return policy with filtered upstream proxies 134 | Get { 135 | #[arg(long)] 136 | listen_port: Option, 137 | #[arg(long)] 138 | dst_ip: Option, 139 | #[arg(long)] 140 | dst_domain: Option, 141 | }, 142 | } 143 | 144 | fn parse_duration_in_seconds(s: &str) -> Result { 145 | s.parse() 146 | .map_err(|_| format!("`{}` isn't a number", s)) 147 | .map(Duration::from_secs) 148 | } 149 | 150 | fn parse_socket_addr_default_on_localhost(addr: &str) -> Result { 151 | if addr.contains(':') { 152 | addr.parse() 153 | } else { 154 | format!("127.0.0.1:{}", addr).parse() 155 | } 156 | .map_err(|_| format!("`{}` isn't a valid server address", addr)) 157 | } 158 | 159 | #[test] 160 | fn verify_cli() { 161 | use 
clap::CommandFactory; 162 | CliArgs::command().debug_assert() 163 | } 164 | -------------------------------------------------------------------------------- /src/client/connect.rs: -------------------------------------------------------------------------------- 1 | use bytes::Bytes; 2 | use std::{ 3 | collections::VecDeque, 4 | future::Future, 5 | io::{self, ErrorKind}, 6 | pin::Pin, 7 | sync::Arc, 8 | task::{Context, Poll}, 9 | }; 10 | use tokio::{net::TcpStream, time::timeout}; 11 | use tracing::{info, instrument}; 12 | 13 | use crate::proxy::{Destination, ProxyServer}; 14 | 15 | #[derive(Debug, Clone)] 16 | struct Request { 17 | dest: Destination, 18 | pending_data: Option, 19 | wait_response: bool, 20 | } 21 | 22 | #[instrument(skip_all, fields(proxy = %server.tag))] 23 | async fn try_connect(request: Request, server: Arc) -> io::Result { 24 | let max_wait = server.max_wait(); 25 | // waiting for proxy server connected 26 | let stream = timeout( 27 | max_wait, 28 | server.connect(&request.dest, request.pending_data), 29 | ) 30 | .await??; 31 | 32 | // waiting for response data 33 | if request.wait_response { 34 | let mut buf = [0u8; 4]; 35 | let len = timeout(max_wait, stream.peek(&mut buf)).await??; 36 | if len == 0 { 37 | return Err(io::Error::new(ErrorKind::UnexpectedEof, "no response data")); 38 | } 39 | } 40 | Ok(stream) 41 | } 42 | 43 | type PinnedConnectFuture = Pin> + Send>>; 44 | 45 | /// Try to connect one of the proxy servers. 46 | /// Pick `parallel_n` servers from `queue` to `connecting` and wait for 47 | /// connect. Once any of them connected, move that to `reading` and wait 48 | /// for read respone. Once any of handshakings done, return it and cancel 49 | /// others. 
50 | pub struct TryConnectAll { 51 | request: Request, 52 | parallel_n: usize, 53 | standby: VecDeque>, 54 | connects: VecDeque<(Arc, PinnedConnectFuture)>, 55 | last_error: Option, 56 | } 57 | 58 | pub fn try_connect_all( 59 | dest: &Destination, 60 | servers: Vec>, 61 | parallel_n: usize, 62 | wait_response: bool, 63 | pending_data: Option, 64 | ) -> TryConnectAll { 65 | let parallel_n = parallel_n.clamp(1, if wait_response { servers.len() } else { 1 }); 66 | let servers = servers.into_iter().collect(); 67 | let request = Request { 68 | dest: dest.clone(), 69 | pending_data, 70 | wait_response, 71 | }; 72 | TryConnectAll { 73 | request, 74 | parallel_n, 75 | standby: servers, 76 | connects: VecDeque::with_capacity(parallel_n), 77 | last_error: None, 78 | } 79 | } 80 | 81 | impl Future for TryConnectAll { 82 | type Output = io::Result<(Arc, TcpStream)>; 83 | 84 | fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll { 85 | loop { 86 | // if current connections less than parallel_n, 87 | // pick servers from queue to connect. 88 | while !self.standby.is_empty() && self.connects.len() < self.parallel_n { 89 | let server = self.standby.pop_front().unwrap(); 90 | let conn = try_connect(self.request.clone(), server.clone()); 91 | self.connects.push_back((server, Box::pin(conn))); 92 | } 93 | 94 | // poll all connects 95 | let mut i = 0; 96 | while i < self.connects.len() { 97 | let (server, conn) = &mut self.connects[i]; 98 | match conn.as_mut().poll(cx) { 99 | // error, stop trying, drop it. 100 | Poll::Ready(Err(err)) => { 101 | info!(proxy = %server.tag, ?err, "Failed to connect upstream proxy"); 102 | self.last_error = Some(err); 103 | drop(self.connects.remove(i)); 104 | } 105 | // not ready, keep here, poll next one. 106 | Poll::Pending => i += 1, 107 | // ready, return it. 
108 | Poll::Ready(Ok(conn)) => return Poll::Ready(Ok((server.clone(), conn))), 109 | } 110 | } 111 | 112 | // if all servers failed, return error 113 | if self.connects.is_empty() && self.standby.is_empty() { 114 | let err = self.last_error.take().unwrap_or_else(|| { 115 | io::Error::new(io::ErrorKind::InvalidInput, "no upstream proxy") 116 | }); 117 | return Poll::Ready(Err(err)); 118 | } 119 | 120 | // if not need to connect standby server, wait for events. 121 | if self.connects.len() >= self.parallel_n || self.standby.is_empty() { 122 | return Poll::Pending; 123 | } 124 | } 125 | } 126 | } 127 | -------------------------------------------------------------------------------- /src/client/mod.rs: -------------------------------------------------------------------------------- 1 | mod connect; 2 | mod tls_parser; 3 | use bytes::{Bytes, BytesMut}; 4 | use flexstr::SharedStr; 5 | use std::{ 6 | borrow::Cow, 7 | io, 8 | net::{IpAddr, SocketAddr}, 9 | sync::Arc, 10 | time::Duration, 11 | }; 12 | use tokio::{ 13 | io::{AsyncReadExt, AsyncWriteExt}, 14 | net::TcpStream, 15 | time::timeout, 16 | }; 17 | use tracing::{debug, info, instrument, warn}; 18 | 19 | #[cfg(target_os = "linux")] 20 | use crate::linux::tcp::TcpStreamExt; 21 | use crate::{ 22 | client::connect::try_connect_all, 23 | policy::RequestFeatures, 24 | proxy::{copy::pipe, Traffic}, 25 | proxy::{Address, Destination, ProxyServer}, 26 | }; 27 | 28 | #[derive(Debug, Default)] 29 | pub struct TlsData { 30 | pending_data: Option, 31 | has_full_tls_hello: bool, 32 | pub sni: Option, 33 | } 34 | 35 | #[derive(Debug)] 36 | pub struct NewClient { 37 | left: TcpStream, 38 | /// Destination IP address or domain name with port number. 39 | /// Retrived from firewall or SOCKSv5 request initially, may be override 40 | /// by TLS SNI. 41 | pub dest: Destination, 42 | /// Destination IP address. Unlike `dest`, it won't be override by SNI. 43 | dest_ip_addr: Option, 44 | /// Server's TCP port number. 
45 | from_port: u16, 46 | pub tls: Option, 47 | } 48 | 49 | #[derive(Debug)] 50 | pub struct ConnectedClient { 51 | orig: NewClient, 52 | right: TcpStream, 53 | server: Arc, 54 | } 55 | 56 | #[derive(Debug)] 57 | pub enum FailedClient { 58 | Recoverable(NewClient), 59 | Unrecoverable(io::Error), 60 | } 61 | 62 | impl From for FailedClient { 63 | fn from(value: io::Error) -> Self { 64 | Self::Unrecoverable(value) 65 | } 66 | } 67 | 68 | fn error_invalid_input(msg: &'static str) -> io::Result { 69 | Err(io::Error::new(io::ErrorKind::InvalidInput, msg)) 70 | } 71 | 72 | trait SocketAddrExt { 73 | fn normalize(&self) -> Cow; 74 | } 75 | 76 | impl SocketAddrExt for SocketAddr { 77 | fn normalize(&self) -> Cow { 78 | match self { 79 | SocketAddr::V4(sock) => { 80 | let addr = sock.ip().to_ipv6_mapped(); 81 | let sock = SocketAddr::new(addr.into(), sock.port()); 82 | Cow::Owned(sock) 83 | } 84 | _ => Cow::Borrowed(self), 85 | } 86 | } 87 | } 88 | 89 | #[instrument(skip_all)] 90 | async fn accept_socks5(client: &mut TcpStream) -> io::Result { 91 | // Not a NATed connection, treated as SOCKSv5 92 | // Parse version 93 | // TODO: add timeout 94 | // TODO: use buffered reader 95 | let ver = client.read_u8().await?; 96 | if ver != 0x05 { 97 | return error_invalid_input("Neither a NATed or SOCKSv5 connection"); 98 | } 99 | // Parse auth methods 100 | let n_methods = client.read_u8().await?; 101 | let mut buf = vec![0u8; n_methods as usize]; 102 | client.read_exact(&mut buf).await?; 103 | if !buf.iter().any(|&m| m == 0) { 104 | return error_invalid_input("SOCKSv5: No auth is required"); 105 | } 106 | // Select no auth 107 | client.write_all(&[0x05, 0x00]).await?; 108 | // Parse request 109 | buf.resize(4, 0); 110 | client.read_exact(&mut buf).await?; 111 | if buf[0..2] != [0x05, 0x01] { 112 | return error_invalid_input("SOCKSv5: CONNECT is required"); 113 | } 114 | let addr: Address = match buf[3] { 115 | 0x01 => { 116 | // IPv4 117 | let mut buf = [0u8; 4]; 118 | 
client.read_exact(&mut buf).await?; 119 | buf.into() 120 | } 121 | 0x03 => { 122 | // Domain name 123 | let len = client.read_u8().await? as usize; 124 | buf.resize(len, 0); 125 | client.read_exact(&mut buf).await?; 126 | 127 | let domain = std::str::from_utf8(&buf).map_err(|_| { 128 | io::Error::new(io::ErrorKind::InvalidInput, "SOCKSv5: Invalid domain name") 129 | })?; 130 | Address::Domain(domain.into()) 131 | } 132 | 0x04 => { 133 | // IPv6 134 | let mut buf = [0u8; 16]; 135 | client.read_exact(&mut buf).await?; 136 | buf.into() 137 | } 138 | _ => return error_invalid_input("SOCKSv5: unknown address type"), 139 | }; 140 | let port = client.read_u16().await?; 141 | // Send response 142 | client.write_all(&[5, 0, 0, 1, 0, 0, 0, 0, 0, 0]).await?; 143 | Ok((addr, port).into()) 144 | } 145 | 146 | impl NewClient { 147 | #[instrument(name = "retrieve_dest", skip_all)] 148 | pub async fn from_socket(mut left: TcpStream) -> io::Result { 149 | let from_port = left.local_addr()?.port(); 150 | 151 | // Try to get original destination before NAT 152 | #[cfg(target_os = "linux")] 153 | let dest = match left.get_original_dest()? { 154 | // Redirecting to itself is possible. Treat it as non-redirect. 
155 | Some(dest) if dest.normalize() != left.local_addr()?.normalize() => Some(dest), 156 | _ => None, 157 | }; 158 | 159 | // No NAT supported 160 | #[cfg(not(target_os = "linux"))] 161 | let dest: Option = None; 162 | 163 | let dest = if let Some(dest) = dest { 164 | debug!(?dest, "Retrived destination via NAT info"); 165 | dest.into() 166 | } else { 167 | let dest = accept_socks5(&mut left).await?; 168 | debug!(?dest, "Retrived destination via SOCKSv5"); 169 | dest 170 | }; 171 | 172 | let dest_ip_addr = match dest.host { 173 | Address::Ip(ip) => Some(ip), 174 | Address::Domain(_) => None, 175 | }; 176 | 177 | Ok(NewClient { 178 | left, 179 | dest, 180 | dest_ip_addr, 181 | from_port, 182 | tls: None, 183 | }) 184 | } 185 | 186 | fn pending_data(&self) -> Option { 187 | Some(self.tls.as_ref()?.pending_data.as_ref()?.clone()) 188 | } 189 | 190 | pub fn features(&self) -> RequestFeatures { 191 | RequestFeatures { 192 | listen_port: Some(self.from_port), 193 | dst_domain: self.dest.host.domain(), 194 | dst_ip: self.dest_ip_addr, 195 | } 196 | } 197 | 198 | pub fn override_dest_with_sni(&mut self) -> bool { 199 | match ( 200 | &mut self.dest.host, 201 | &self.tls.as_ref().and_then(|tls| tls.sni.clone()), 202 | ) { 203 | (Address::Domain(_), _) => false, 204 | (_, None) => false, 205 | (dst, Some(host)) => { 206 | *dst = Address::Domain(host.clone()); 207 | true 208 | } 209 | } 210 | } 211 | 212 | #[instrument(level = "error", skip_all, fields(dest=?self.dest))] 213 | pub async fn direct_connect( 214 | self, 215 | pseudo_server: Arc, 216 | ) -> io::Result { 217 | let mut right = match self.dest.host { 218 | Address::Ip(addr) => TcpStream::connect((addr, self.dest.port)).await?, 219 | Address::Domain(ref name) => { 220 | TcpStream::connect((name.as_ref(), self.dest.port)).await? 
221 | } 222 | }; 223 | right.set_nodelay(true)?; 224 | 225 | if let Some(data) = self.pending_data() { 226 | right.write_all(&data).await?; 227 | } 228 | 229 | info!(remote = %right.peer_addr()?, "Connected w/o proxy"); 230 | Ok(ConnectedClient { 231 | orig: self, 232 | right, 233 | server: pseudo_server, 234 | }) 235 | } 236 | 237 | #[instrument(level = "error", skip_all, fields(dest=?self.dest))] 238 | pub async fn retrieve_dest_from_sni(&mut self) -> io::Result<()> { 239 | if self.tls.is_some() { 240 | return Ok(()); 241 | } 242 | let mut tls = TlsData::default(); 243 | let wait = Duration::from_millis(500); 244 | let mut buf = BytesMut::with_capacity(2048); 245 | buf.resize(buf.capacity(), 0); 246 | if let Ok(len) = timeout(wait, self.left.read(&mut buf)).await { 247 | buf.truncate(len?); 248 | // only TLS is safe to duplicate requests. 249 | match tls_parser::parse_client_hello(&buf) { 250 | Err(err) => info!("fail to parse hello: {}", err), 251 | Ok(hello) => { 252 | tls.has_full_tls_hello = true; 253 | if let Some(name) = hello.server_name { 254 | tls.sni = Some(name.into()); 255 | debug!(sni = name, "SNI found"); 256 | } 257 | if hello.early_data { 258 | debug!("TLS with early data"); 259 | } 260 | } 261 | } 262 | tls.pending_data = Some(buf.freeze()); 263 | } else { 264 | info!("no tls request received before timeout"); 265 | } 266 | self.tls = Some(tls); 267 | Ok(()) 268 | } 269 | 270 | #[instrument(level = "error", skip_all, fields(dest=?self.dest))] 271 | pub async fn connect_server( 272 | self, 273 | proxies: Vec>, 274 | n_parallel: usize, 275 | ) -> Result { 276 | if proxies.is_empty() { 277 | warn!("No avaiable proxy"); 278 | return Err(FailedClient::Recoverable(self)); 279 | } 280 | let (n_parallel, wait_response) = match self.tls { 281 | Some(ref tls) if tls.has_full_tls_hello => (n_parallel.clamp(1, proxies.len()), true), 282 | _ => (1, false), 283 | }; 284 | let proxies_len = proxies.len(); 285 | match try_connect_all( 286 | &self.dest, 287 | 
proxies, 288 | n_parallel, 289 | wait_response, 290 | self.pending_data(), 291 | ) 292 | .await 293 | { 294 | Ok((server, right)) => { 295 | info!(proxy = %server.tag, "Proxy connected"); 296 | Ok(ConnectedClient { 297 | orig: self, 298 | right, 299 | server, 300 | }) 301 | } 302 | Err(err) => { 303 | warn!("Tried {} proxies but failed: {}", proxies_len, err); 304 | Err(FailedClient::Recoverable(self)) 305 | } 306 | } 307 | } 308 | } 309 | 310 | impl FailedClient { 311 | pub fn recovery(self) -> io::Result { 312 | match self { 313 | Self::Recoverable(client) => Ok(client), 314 | Self::Unrecoverable(err) => Err(err), 315 | } 316 | } 317 | } 318 | 319 | impl ConnectedClient { 320 | #[instrument(level = "error", skip_all, fields(dest=?self.orig.dest, proxy=%self.server.tag))] 321 | pub async fn serve(self) -> io::Result<()> { 322 | let ConnectedClient { 323 | orig, 324 | right, 325 | server, 326 | .. 327 | } = self; 328 | // TODO: make keepalive configurable 329 | // FIXME: set_cookies 330 | /* 331 | let timeout = Some(Duration::from_secs(180)); 332 | FIXME: keepalive 333 | https://github.com/tokio-rs/tokio/issues/3109 334 | 335 | if let Err(e) = left 336 | .set_keepalive(timeout) 337 | .and(right.set_keepalive(timeout)) 338 | { 339 | warn!("fail to set keepalive: {}", e); 340 | } 341 | */ 342 | server.update_stats_conn_open(); 343 | match pipe(orig.left, right, server.clone()).await { 344 | Ok(Traffic { tx_bytes, rx_bytes }) => { 345 | server.update_stats_conn_close(false); 346 | debug!(tx_bytes, rx_bytes, "Closed"); 347 | Ok(()) 348 | } 349 | Err(err) => { 350 | server.update_stats_conn_close(true); 351 | info!(?err, "Closed"); 352 | Err(err) 353 | } 354 | } 355 | } 356 | } 357 | -------------------------------------------------------------------------------- /src/client/tls_parser.rs: -------------------------------------------------------------------------------- 1 | use std::ops::Range; 2 | use std::str::from_utf8; 3 | 4 | const EXT_SERVER_NAME: &[u8] = &[0, 0]; 
5 | const EXT_EARLY_DATA: &[u8] = &[0, 42]; 6 | 7 | pub struct TlsClientHello<'a> { 8 | pub server_name: Option<&'a str>, 9 | pub early_data: bool, 10 | } 11 | 12 | struct TlsRecord<'a> { 13 | content_type: &'a u8, 14 | version_major: &'a u8, 15 | #[allow(dead_code)] 16 | version_minor: &'a u8, 17 | fragment: &'a [u8], 18 | } 19 | 20 | fn truncate(data: &[u8], len_pos: Range) -> Result<&[u8], &'static str> { 21 | let len_bits = data 22 | .get(len_pos.clone()) 23 | .ok_or("lack data to decode length")?; 24 | let mut len = 0usize; 25 | for bit in len_bits { 26 | len = len << 8 | (*bit as usize); 27 | } 28 | data.get(len_pos.end..len_pos.end + len) 29 | .ok_or("not enough data") 30 | } 31 | 32 | fn drop_before(data: &[u8], len_pos: Range) -> Result<&[u8], &'static str> { 33 | let len = truncate(data, len_pos.clone())?.len(); 34 | Ok(&data[len_pos.end + len..]) 35 | } 36 | 37 | fn parse_tls_record(data: &[u8]) -> Result { 38 | let fragment = truncate(data, 3..5)?; 39 | Ok(TlsRecord { 40 | content_type: &data[0], 41 | version_major: &data[1], 42 | version_minor: &data[2], 43 | fragment, 44 | }) 45 | } 46 | 47 | pub fn parse_client_hello(data: &[u8]) -> Result { 48 | let TlsRecord { 49 | content_type: &ctype, 50 | version_major: &version, 51 | fragment, 52 | .. 
53 | } = parse_tls_record(data)?; 54 | if version != 3 { 55 | return Err("unknown tls version"); 56 | } 57 | if ctype != 22 { 58 | return Err("not handshake"); 59 | } 60 | 61 | // 0: handshake type 62 | if fragment.first() != Some(&1) { 63 | return Err("not client hello"); 64 | } 65 | let hello = truncate(fragment, 1..4)?; 66 | // 0..2: client version 67 | if hello.first() != Some(&3) { 68 | return Err("unsupported client version"); 69 | } 70 | // 2..34: 32-bytes random, dropped 71 | // 34+: session id, dropped 72 | let remaining = drop_before(hello, 34..35)?; 73 | // cipher suite, dropped 74 | let remaining = drop_before(remaining, 0..2)?; 75 | // compression methods, dropped 76 | let remaining = drop_before(remaining, 0..1)?; 77 | // 2-byte length of extensions 78 | let mut exts = truncate(remaining, 0..2)?; 79 | let mut server_name = None; 80 | let mut early_data = false; 81 | while exts.len() >= 4 { 82 | // 0..2: extension type 83 | let ext_type = &exts[0..2]; 84 | // 2..4: extension length 85 | let ext_data = truncate(exts, 2..4)?; 86 | exts = drop_before(exts, 2..4)?; 87 | if ext_type == EXT_SERVER_NAME { 88 | server_name = parse_server_name_ext(ext_data)?; 89 | } else if ext_type == EXT_EARLY_DATA { 90 | early_data = true; 91 | } 92 | } 93 | 94 | Ok(TlsClientHello { 95 | server_name, 96 | early_data, 97 | }) 98 | } 99 | 100 | /// Parse SNI data, return hostname. 101 | fn parse_server_name_ext(ext_data: &[u8]) -> Result, &'static str> { 102 | let mut data = truncate(ext_data, 0..2)?; 103 | while data.len() > 3 { 104 | let value = truncate(data, 1..3)?; 105 | let name_type = data[0]; 106 | data = drop_before(data, 1..3)?; 107 | if name_type == 0 { 108 | // hostname 109 | let name = from_utf8(value).map_err(|_| "server name not utf-8 string")?; 110 | if name.as_bytes().len() > 255 { 111 | return Err("server name too long"); 112 | } 113 | if !name 114 | .chars() 115 | .all(|c| c.is_digit(36) || c == '.' 
|| c == '-' || c == '_') 116 | { 117 | return Err("illegal char in server name"); 118 | } 119 | return Ok(Some(name)); 120 | } 121 | } 122 | Ok(None) 123 | } 124 | 125 | #[test] 126 | fn test_parse_without_server_name() { 127 | let data = [ 128 | 0x16, 0x03, 0x01, 0x00, 0xa1, 0x01, 0x00, 0x00, 0x9d, 0x03, 0x03, 0x52, 0x36, 0x2c, 0x10, 129 | 0x12, 0xcf, 0x23, 0x62, 0x82, 0x56, 0xe7, 0x45, 0xe9, 0x03, 0xce, 0xa6, 0x96, 0xe9, 0xf6, 130 | 0x2a, 0x60, 0xba, 0x0a, 0xe8, 0x31, 0x1d, 0x70, 0xde, 0xa5, 0xe4, 0x19, 0x49, 0x00, 0x00, 131 | 0x04, 0xc0, 0x30, 0x00, 0xff, 0x02, 0x01, 0x00, 0x00, 0x6f, 0x00, 0x0b, 0x00, 0x04, 0x03, 132 | 0x00, 0x01, 0x02, 0x00, 0x0a, 0x00, 0x34, 0x00, 0x32, 0x00, 0x0e, 0x00, 0x0d, 0x00, 0x19, 133 | 0x00, 0x0b, 0x00, 0x0c, 0x00, 0x18, 0x00, 0x09, 0x00, 0x0a, 0x00, 0x16, 0x00, 0x17, 0x00, 134 | 0x08, 0x00, 0x06, 0x00, 0x07, 0x00, 0x14, 0x00, 0x15, 0x00, 0x04, 0x00, 0x05, 0x00, 0x12, 135 | 0x00, 0x13, 0x00, 0x01, 0x00, 0x02, 0x00, 0x03, 0x00, 0x0f, 0x00, 0x10, 0x00, 0x11, 0x00, 136 | 0x23, 0x00, 0x00, 0x00, 0x0d, 0x00, 0x22, 0x00, 0x20, 0x06, 0x01, 0x06, 0x02, 0x06, 0x03, 137 | 0x05, 0x01, 0x05, 0x02, 0x05, 0x03, 0x04, 0x01, 0x04, 0x02, 0x04, 0x03, 0x03, 0x01, 0x03, 138 | 0x02, 0x03, 0x03, 0x02, 0x01, 0x02, 0x02, 0x02, 0x03, 0x01, 0x01, 0x00, 0x0f, 0x00, 0x01, 139 | 0x01, 140 | ]; 141 | let TlsRecord { 142 | content_type: &content_type, 143 | version_major: &version_major, 144 | version_minor: &version_minor, 145 | fragment, 146 | } = parse_tls_record(&data).unwrap(); 147 | assert_eq!(22, content_type); 148 | assert_eq!(3, version_major); 149 | assert_eq!(1, version_minor); 150 | assert_eq!(161, fragment.len()); 151 | assert_eq!(1, fragment[0]); 152 | assert_eq!(Some(&1), fragment.last()); 153 | 154 | let TlsClientHello { server_name, .. 
} = parse_client_hello(&data).unwrap(); 155 | assert_eq!(None, server_name); 156 | } 157 | 158 | #[test] 159 | fn test_parse_with_server_name() { 160 | let data = [ 161 | 0x16, 0x03, 0x01, 0x00, 0xba, 0x01, 0x00, 0x00, 0xb6, 0x03, 0x03, 0xce, 0xf3, 0xc8, 0x77, 162 | 0x36, 0x6a, 0x81, 0x3b, 0x2f, 0x22, 0xc8, 0xd3, 0x29, 0xed, 0xf8, 0xb6, 0xec, 0xd9, 0x73, 163 | 0xfb, 0x76, 0x66, 0x6c, 0xbb, 0xa0, 0x50, 0xbd, 0x42, 0x13, 0xd5, 0xc4, 0xf1, 0x00, 0x00, 164 | 0x1e, 0xc0, 0x2b, 0xc0, 0x2f, 0xcc, 0xa9, 0xcc, 0xa8, 0xc0, 0x2c, 0xc0, 0x30, 0xc0, 0x0a, 165 | 0xc0, 0x09, 0xc0, 0x13, 0xc0, 0x14, 0x00, 0x33, 0x00, 0x39, 0x00, 0x2f, 0x00, 0x35, 0x00, 166 | 0x0a, 0x01, 0x00, 0x00, 0x6f, 0x00, 0x00, 0x00, 0x13, 0x00, 0x11, 0x00, 0x00, 0x0e, 0x77, 167 | 0x77, 0x77, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x63, 0x6f, 0x6d, 0x00, 0x17, 168 | 0x00, 0x00, 0xff, 0x01, 0x00, 0x01, 0x00, 0x00, 0x0a, 0x00, 0x0a, 0x00, 0x08, 0x00, 0x1d, 169 | 0x00, 0x17, 0x00, 0x18, 0x00, 0x19, 0x00, 0x0b, 0x00, 0x02, 0x01, 0x00, 0x00, 0x23, 0x00, 170 | 0x00, 0x00, 0x10, 0x00, 0x0e, 0x00, 0x0c, 0x02, 0x68, 0x32, 0x08, 0x68, 0x74, 0x74, 0x70, 171 | 0x2f, 0x31, 0x2e, 0x31, 0x00, 0x05, 0x00, 0x05, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0d, 172 | 0x00, 0x18, 0x00, 0x16, 0x04, 0x03, 0x05, 0x03, 0x06, 0x03, 0x08, 0x04, 0x08, 0x05, 0x08, 173 | 0x06, 0x04, 0x01, 0x05, 0x01, 0x06, 0x01, 0x02, 0x03, 0x02, 0x01, 174 | ]; 175 | let TlsClientHello { server_name, .. } = parse_client_hello(&data).unwrap(); 176 | assert_eq!(Some("www.google.com"), server_name); 177 | } 178 | -------------------------------------------------------------------------------- /src/futures_stream.rs: -------------------------------------------------------------------------------- 1 | use futures_core::{ready, Stream}; 2 | use std::{ 3 | io::Result, 4 | task::{Context, Poll}, 5 | }; 6 | use tokio::net::{TcpListener, TcpStream}; 7 | #[cfg(unix)] 8 | use tokio::net::{UnixListener, UnixStream}; 9 | 10 | macro_rules! 
impl_stream { 11 | ($name:ident : $listener:ty => $stream:ty) => { 12 | pub struct $name(pub $listener); 13 | 14 | impl Stream for $name { 15 | type Item = Result<$stream>; 16 | 17 | fn poll_next( 18 | self: std::pin::Pin<&mut Self>, 19 | cx: &mut Context<'_>, 20 | ) -> Poll> { 21 | let (stream, _) = ready!(self.0.poll_accept(cx))?; 22 | Poll::Ready(Some(Ok(stream))) 23 | } 24 | } 25 | }; 26 | } 27 | 28 | impl_stream!(TcpListenerStream: TcpListener => TcpStream); 29 | 30 | #[cfg(unix)] 31 | impl_stream!(UnixListenerStream: UnixListener => UnixStream); 32 | -------------------------------------------------------------------------------- /src/lib.rs: -------------------------------------------------------------------------------- 1 | pub mod client; 2 | pub mod futures_stream; 3 | #[cfg(target_os = "linux")] 4 | pub mod linux; 5 | pub mod monitor; 6 | pub mod policy; 7 | pub mod proxy; 8 | #[cfg(feature = "web_console")] 9 | pub mod web; 10 | -------------------------------------------------------------------------------- /src/linux/mod.rs: -------------------------------------------------------------------------------- 1 | #[cfg(feature = "systemd")] 2 | pub mod systemd; 3 | pub mod tcp; 4 | -------------------------------------------------------------------------------- /src/linux/systemd.rs: -------------------------------------------------------------------------------- 1 | use libc::{dev_t as Dev, ino_t as Inode}; 2 | use nix::sys::stat::fstat; 3 | use sd_notify::{notify, NotifyState}; 4 | use std::{borrow::Cow, env, io, os::unix::prelude::AsRawFd, process, time::Duration}; 5 | use tokio::time::sleep; 6 | use tracing::{info, instrument, trace, warn}; 7 | 8 | fn notify_enabled() -> bool { 9 | env::var_os("NOTIFY_SOCKET").is_some() 10 | } 11 | 12 | pub fn notify_ready() { 13 | if notify_enabled() && notify(false, &[NotifyState::Ready]).is_err() { 14 | warn!("fail to notify systemd (ready)") 15 | } 16 | } 17 | 18 | pub fn notify_realoding() { 19 | if 
notify_enabled() && notify(false, &[NotifyState::Reloading]).is_err() { 20 | warn!("fail to notify systemd (reloading)") 21 | } 22 | } 23 | 24 | pub fn set_status(status: Cow) { 25 | if notify_enabled() && notify(false, &[NotifyState::Status(&status)]).is_err() { 26 | warn!("fail to notify systemd (set status)"); 27 | } 28 | } 29 | 30 | /// Return the watchdog timeout if it's enabled by systemd. 31 | pub fn watchdog_timeout() -> Option { 32 | if !notify_enabled() { 33 | return None; 34 | } 35 | let pid: u32 = env::var("WATCHDOG_PID").ok()?.parse().ok()?; 36 | if pid != process::id() { 37 | info!( 38 | "WATCHDOG_PID was set to {}, not ours {}", 39 | pid, 40 | process::id() 41 | ); 42 | return None; 43 | } 44 | let usec: u64 = env::var("WATCHDOG_USEC").ok()?.parse().ok()?; 45 | Some(Duration::from_micros(usec)) 46 | } 47 | 48 | #[instrument(skip_all)] 49 | pub async fn watchdog_loop(timeout: Duration) { 50 | info!("Watchdog enabled, poke for every {}ms", timeout.as_millis()); 51 | loop { 52 | trace!("poke the watchdog"); 53 | if notify(false, &[NotifyState::Watchdog]).is_err() { 54 | warn!("fail to poke watchdog"); 55 | } 56 | sleep(timeout).await; 57 | } 58 | } 59 | 60 | /// Try to read the device & inode number from environment variable `JOURNAL_STREAM`. 61 | fn get_journal_stream_dev_ino() -> Option<(Dev, Inode)> { 62 | let stream_env = env::var_os("JOURNAL_STREAM")?; 63 | let (dev, ino) = stream_env.to_str()?.split_once(':')?; 64 | Some((dev.parse().ok()?, ino.parse().ok()?)) 65 | } 66 | 67 | /// Check if STDERR is connected with systemd's journal service. 
68 | pub fn is_stderr_connected_to_journal() -> bool { 69 | if let Some((dev, ino)) = get_journal_stream_dev_ino() { 70 | if let Ok(stat) = fstat(io::stderr().as_raw_fd()) { 71 | return stat.st_dev == dev && stat.st_ino == ino; 72 | } 73 | } 74 | false 75 | } 76 | -------------------------------------------------------------------------------- /src/linux/tcp.rs: -------------------------------------------------------------------------------- 1 | use nix::sys::socket::{ 2 | getsockopt, setsockopt, 3 | sockopt::{Ip6tOriginalDst, OriginalDst, TcpCongestion}, 4 | }; 5 | use std::{ 6 | ffi::OsStr, 7 | io::{self, ErrorKind}, 8 | net::{SocketAddr, SocketAddrV4, SocketAddrV6}, 9 | os::fd::AsFd, 10 | }; 11 | use tokio::net::{TcpListener, TcpStream}; 12 | 13 | pub trait TcpStreamExt { 14 | fn get_original_dest(&self) -> io::Result>; 15 | } 16 | 17 | pub trait TcpListenerExt { 18 | fn set_congestion>(&self, alg: S) -> io::Result<()>; 19 | } 20 | 21 | impl TcpStreamExt for TcpStream { 22 | fn get_original_dest(&self) -> io::Result> { 23 | match get_original_dest_v4(self) { 24 | Ok(addr) => Ok(Some(SocketAddr::V4(addr))), 25 | Err(err) if err.kind() == ErrorKind::NotFound => match get_original_dest_v6(self) { 26 | Ok(addr) => Ok(Some(SocketAddr::V6(addr))), 27 | Err(err) if err.kind() == ErrorKind::NotFound => Ok(None), 28 | Err(err) => Err(err), 29 | }, 30 | Err(err) => Err(err), 31 | } 32 | } 33 | } 34 | 35 | impl TcpListenerExt for TcpListener { 36 | fn set_congestion>(&self, alg: S) -> io::Result<()> { 37 | let val = alg.as_ref().into(); 38 | setsockopt(self, TcpCongestion, &val)?; 39 | Ok(()) 40 | } 41 | } 42 | 43 | fn get_original_dest_v4(fd: &F) -> io::Result 44 | where 45 | F: AsFd, 46 | { 47 | let addr = getsockopt(fd, OriginalDst)?; 48 | Ok(SocketAddrV4::new( 49 | u32::from_be(addr.sin_addr.s_addr).into(), 50 | u16::from_be(addr.sin_port), 51 | )) 52 | } 53 | 54 | fn get_original_dest_v6(fd: &F) -> io::Result 55 | where 56 | F: AsFd, 57 | { 58 | let sockaddr = 
getsockopt(fd, Ip6tOriginalDst)?; 59 | Ok(SocketAddrV6::new( 60 | sockaddr.sin6_addr.s6_addr.into(), 61 | u16::from_be(sockaddr.sin6_port), 62 | sockaddr.sin6_flowinfo, 63 | sockaddr.sin6_scope_id, 64 | )) 65 | } 66 | -------------------------------------------------------------------------------- /src/main.rs: -------------------------------------------------------------------------------- 1 | mod cli; 2 | mod server; 3 | 4 | use clap::Parser; 5 | use cli::{Commands, PolicyCommands}; 6 | use moproxy::policy::{ActionType, RequestFeatures}; 7 | use server::MoProxy; 8 | use std::str::FromStr; 9 | #[cfg(unix)] 10 | use tokio::signal::unix::{signal, SignalKind}; 11 | use tracing::{debug, error, info, instrument, warn}; 12 | 13 | #[cfg(all(feature = "systemd", target_os = "linux"))] 14 | use moproxy::linux::systemd; 15 | use tracing_subscriber::prelude::*; 16 | /* Parse an optional string-like value into an optional T: None stays None, Some(s) goes through T::from_str and its error is propagated. */ 17 | trait FromOptionStr> { 18 | fn parse(&self) -> Result, E>; 19 | } 20 | 21 | impl FromOptionStr for Option 22 | where 23 | T: FromStr, 24 | S: AsRef, 25 | { 26 | fn parse(&self) -> Result, E> { 27 | if let Some(s) = self { 28 | let t = T::from_str(s.as_ref())?; 29 | Ok(Some(t)) 30 | } else { 31 | Ok(None) 32 | } 33 | } 34 | } 35 | 36 | #[tokio::main] 37 | async fn main() { 38 | let mut args = cli::CliArgs::parse(); 39 | let command = args.command.take(); /* taken out first because args is moved into MoProxy::new below */ 40 | let mut log_registry: Option<_> = tracing_subscriber::registry().with(args.log_level).into(); /* Option so exactly one sink consumes the registry: journald if it connects, otherwise the fmt layer */ 41 | 42 | #[cfg(all(feature = "systemd", target_os = "linux"))] 43 | { 44 | if systemd::is_stderr_connected_to_journal() { 45 | match tracing_journald::layer() { 46 | Ok(layer) => { 47 | log_registry.take().unwrap().with(layer).init(); 48 | debug!("Use native journal protocol"); 49 | } 50 | Err(err) => eprintln!( 51 | "Failed to connect systemd-journald: {}; fallback to STDERR.", 52 | err 53 | ), 54 | } 55 | } 56 | } 57 | if let Some(registry) = log_registry { /* NOTE(review): tracing_subscriber::fmt::layer() writes to stdout by default, while the journald error above says it falls back to STDERR; confirm which is intended */ 58 | registry.with(tracing_subscriber::fmt::layer()).init(); 59 | } 60 | 61 | // Init moproxy (read
config files, etc.) 62 | let moproxy = MoProxy::new(args).await.expect("failed to start moproxy"); 63 | 64 | // Setup signal listener for reloading server list 65 | #[cfg(unix)] 66 | { 67 | let moproxy = moproxy.clone(); 68 | let mut signals = signal(SignalKind::hangup()).expect("cannot catch signal"); 69 | tokio::spawn(async move { 70 | while signals.recv().await.is_some() { 71 | reload_daemon(&moproxy); 72 | } 73 | }); 74 | } 75 | 76 | match &command { 77 | Some(Commands::Check { no_bind }) if *no_bind => { /* check with no_bind: config already loaded, return before moproxy.listen() binds any port */ 78 | info!("Configuration checked"); 79 | return; 80 | } 81 | Some(Commands::Policy { command }) => match command { /* evaluate the policy for the given request features, print the result and exit */ 82 | PolicyCommands::Get { 83 | listen_port, 84 | dst_ip, 85 | dst_domain, 86 | } => { 87 | let policy = moproxy.policy.read(); 88 | let features = RequestFeatures { 89 | listen_port: *listen_port, 90 | dst_ip: *dst_ip, 91 | dst_domain: dst_domain.as_deref(), 92 | }; 93 | let action = policy.matches(&features); 94 | println!("Policy: {action}"); 95 | if let ActionType::Require(caps) = action.action { /* list tags of servers satisfying every required capability set */ 96 | let mut tags: Vec<_> = moproxy 97 | .monitor 98 | .servers() 99 | .iter() 100 | .filter(|s| caps.iter().all(|c| s.capable_anyof(c))) 101 | .map(|s| s.tag.clone()) 102 | .collect(); 103 | tags.sort(); 104 | println!("Allowed: {}", tags.join(", ")); 105 | } 106 | return; 107 | } 108 | }, 109 | _ => {} 110 | } 111 | 112 | // Listen on TCP ports 113 | let listener = moproxy 114 | .listen() 115 | .await 116 | .expect("cannot listen on given TCP port"); 117 | 118 | // Watchdog 119 | #[cfg(all(feature = "systemd", target_os = "linux"))] 120 | { 121 | if let Some(timeout) = systemd::watchdog_timeout() { 122 | tokio::spawn(systemd::watchdog_loop(timeout / 2)); /* passes half the watchdog timeout, presumably as the keep-alive interval */ 123 | } 124 | } 125 | 126 | // Notify systemd 127 | #[cfg(all(feature = "systemd", target_os = "linux"))] 128 | systemd::notify_ready(); 129 | 130 | match &command { 131 | None => listener.handle_forever().await, 132 | Some(Commands::Check { ..
}) => info!("Configuration checked"), 133 | _ => unreachable!(), /* Policy and check-with-no_bind already returned above */ 134 | } 135 | } 136 | 137 | #[instrument(skip_all)] 138 | fn reload_daemon(moproxy: &MoProxy) { /* SIGHUP handler: report parking_lot deadlocks, then reload the server list, bracketed by systemd reloading/ready notifications */ 139 | #[cfg(all(feature = "systemd", target_os = "linux"))] 140 | systemd::notify_realoding(); /* NOTE(review): misspelled name is defined in linux::systemd; rename there if desired */ 141 | 142 | // feature: check deadlocks on signal 143 | let deadlocks = parking_lot::deadlock::check_deadlock(); /* presumably needs parking_lot's deadlock_detection feature; confirm in Cargo.toml */ 144 | if !deadlocks.is_empty() { 145 | error!("{} deadlocks detected!", deadlocks.len()); 146 | for (i, threads) in deadlocks.iter().enumerate() { 147 | debug!("Deadlock #{}", i); 148 | for t in threads { 149 | debug!("Thread Id {:#?}", t.thread_id()); 150 | debug!("{:#?}", t.backtrace()); 151 | } 152 | } 153 | } 154 | 155 | // actual reload 156 | debug!("SIGHUP received, reload server list."); 157 | if let Err(err) = moproxy.reload() { 158 | error!("fail to reload servers: {}", err); 159 | } 160 | 161 | #[cfg(all(feature = "systemd", target_os = "linux"))] 162 | systemd::notify_ready(); 163 | } 164 | -------------------------------------------------------------------------------- /src/monitor/alive_test.rs: -------------------------------------------------------------------------------- 1 | use futures_util::future::join_all; 2 | use std::{self, io, net::Shutdown, time::Duration}; 3 | #[cfg(all(feature = "systemd", target_os = "linux"))] 4 | use std::{ 5 | fmt, 6 | sync::atomic::{AtomicUsize, Ordering}, 7 | }; 8 | use tokio::{ 9 | io::AsyncReadExt, 10 | time::{timeout, Instant}, 11 | }; 12 | use tracing::{debug, instrument, warn}; 13 | 14 | use super::Monitor; 15 | #[cfg(all(feature = "systemd", target_os = "linux"))] 16 | use crate::linux::systemd; 17 | use crate::proxy::ProxyServer; 18 | 19 | #[cfg(all(feature = "systemd", target_os = "linux"))] 20 | struct TestProgress { /* counts finished probes of one test round and reports progress to systemd */ 21 | total: usize, 22 | pass: AtomicUsize, 23 | fail: AtomicUsize, 24 | } 25 | 26 | #[cfg(all(feature = "systemd", target_os = "linux"))] 27 | impl TestProgress { 28 | fn new(total: usize) -> Self { 29 | Self { 30 | total, 31 |
pass: AtomicUsize::new(0), 32 | fail: AtomicUsize::new(0), 33 | } 34 | } 35 | 36 | fn increase(&self, passed: bool) { /* record one finished probe and push the updated progress line to systemd */ 37 | let counter = if passed { &self.pass } else { &self.fail }; 38 | counter.fetch_add(1, Ordering::Relaxed); 39 | systemd::set_status(self.to_string().into()); 40 | } 41 | } 42 | 43 | #[cfg(all(feature = "systemd", target_os = "linux"))] 44 | impl fmt::Display for TestProgress { 45 | fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { 46 | let pass = self.pass.load(Ordering::Relaxed); 47 | let fail = self.fail.load(Ordering::Relaxed); 48 | if (pass + fail) < self.total { 49 | write!(fmt, "probing ({}/{} done)", pass + fail, self.total) 50 | } else { 51 | write!( 52 | fmt, 53 | "serving ({}/{} upstream {} up)", 54 | pass, 55 | self.total, 56 | if pass != 1 { "proxies" } else { "proxy" } /* plural also for zero passes */ 57 | ) 58 | } 59 | } 60 | } 61 | 62 | #[instrument(skip_all)] 63 | pub(crate) async fn test_all(monitor: &Monitor) { /* probe every server concurrently, update each delay/score, then re-sort the list */ 64 | debug!("Start testing all servers"); 65 | #[cfg(all(feature = "systemd", target_os = "linux"))] 66 | let progress = TestProgress::new(monitor.servers().len()); 67 | #[cfg(all(feature = "systemd", target_os = "linux"))] 68 | let progress_ref = &progress; 69 | let tests: Vec<_> = monitor 70 | .servers() 71 | .into_iter() 72 | .map(move |server| { 73 | Box::pin(async move { 74 | let delay = alive_test(&server).await.ok(); 75 | 76 | #[cfg(all(feature = "systemd", target_os = "linux"))] 77 | progress_ref.increase(delay.is_some()); 78 | 79 | #[cfg(feature = "score_script")] 80 | { 81 | let mut calculated = false; 82 | if let Some(lua) = &monitor.lua { 83 | match lua 84 | .lock() 85 | .context(|ctx| server.update_delay_with_lua(delay, ctx)) 86 | { 87 | Ok(()) => calculated = true, 88 | Err(err) => warn!("fail to update score w/ Lua script: {}", err), 89 | } 90 | } 91 | if !calculated { /* no script loaded, or it failed: use the built-in scoring */ 92 | server.update_delay(delay); 93 | } 94 | } 95 | #[cfg(not(feature = "score_script"))] 96 | server.update_delay(delay); 97 | }) 98 | }) 99 |
.collect(); 100 | 101 | join_all(tests).await; 102 | monitor.resort(); 103 | } 104 | 105 | #[instrument(skip_all, fields(proxy = %server.tag))] 106 | async fn alive_test(server: &ProxyServer) -> io::Result { /* send a minimal length-prefixed DNS query (root, type A) via server.connect to its test DNS address; return the elapsed time if the reply echoes the transaction ID */ 107 | let request = [ 108 | 0, 109 | 17, // length 110 | rand::random(), 111 | rand::random(), // transaction ID 112 | 1, 113 | 32, // standard query 114 | 0, 115 | 1, // one query 116 | 0, 117 | 0, // answer 118 | 0, 119 | 0, // authority 120 | 0, 121 | 0, // addition 122 | 0, // query: root 123 | 0, 124 | 1, // query: type A 125 | 0, 126 | 1, // query: class IN 127 | ]; 128 | let tid = |req: &[u8]| (req[2] as u16) << 8 | (req[3] as u16); /* transaction ID sits at bytes 2..4, after the 2-byte length prefix */ 129 | let req_tid = tid(&request); 130 | let now = Instant::now(); 131 | 132 | let mut buf = [0u8; 12]; /* length prefix plus the first 10 header bytes; only the ID is compared */ 133 | let test_dns = server.test_dns().into(); 134 | let result = timeout(server.max_wait(), async { 135 | let mut stream = server.connect(&test_dns, Some(request)).await?; 136 | stream.read_exact(&mut buf).await?; 137 | stream.into_std()?.shutdown(Shutdown::Both) 138 | }) 139 | .await; 140 | 141 | match result { 142 | Err(_) => return Err(io::Error::new(io::ErrorKind::TimedOut, "test timeout")), 143 | Ok(Err(e)) => return Err(e), 144 | Ok(Ok(_)) => (), 145 | } 146 | 147 | if req_tid == tid(&buf) { 148 | let t = now.elapsed(); 149 | debug!("{}ms", t.as_millis()); 150 | Ok(t) 151 | } else { 152 | Err(io::Error::new(io::ErrorKind::Other, "unknown response")) 153 | } 154 | } 155 | -------------------------------------------------------------------------------- /src/monitor/graphite.rs: -------------------------------------------------------------------------------- 1 | use std::{ 2 | io::{self, Write}, 3 | net::SocketAddr, 4 | time::{Duration, SystemTime}, 5 | }; 6 | use tokio::{io::AsyncWriteExt, net::TcpStream, time::timeout}; 7 | use tracing::{debug, instrument, warn}; 8 | 9 | static GRAPHITE_TIMEOUT_SECS: u64 = 5; /* timeout for writing one batch in write_records */ 10 | 11 | #[derive(Debug)] 12 | pub struct Graphite { 13 | server_addr: SocketAddr, 14 | stream: Option, 15 |
} 16 | 17 | #[derive(Clone, Debug)] 18 | pub struct Record { 19 | path: String, 20 | value: u64, 21 | time: Option, 22 | } 23 | 24 | impl Graphite { 25 | pub fn new(server_addr: SocketAddr) -> Self { 26 | Graphite { 27 | stream: None, 28 | server_addr, 29 | } 30 | } 31 | 32 | #[instrument(skip_all)] 33 | pub async fn write_records(&mut self, records: Vec) -> io::Result<()> { /* send all records in one write, reusing the cached connection or opening a new one; the stream is cached again only after a successful write */ 34 | let Graphite { 35 | ref server_addr, 36 | stream: ref mut stream_opt, 37 | } = self; 38 | 39 | let mut stream = if let Some(stream) = stream_opt.take() { 40 | stream 41 | } else { 42 | debug!("start new connection to graphite server"); 43 | TcpStream::connect(&server_addr).await? 44 | }; 45 | 46 | let mut buf = Vec::new(); 47 | for record in records { 48 | record.write_paintext(&mut buf).unwrap(); /* writing into a Vec cannot fail */ 49 | } 50 | 51 | let max_wait = Duration::from_secs(GRAPHITE_TIMEOUT_SECS); 52 | match timeout(max_wait, stream.write_all(&buf)).await { 53 | Err(_) => Err(io::Error::from(io::ErrorKind::TimedOut)), /* the taken stream is dropped here, so the next call reconnects */ 54 | Ok(Err(err)) => { 55 | warn!("fail to send metrics: {}", err); 56 | self.stream = None; /* NOTE(review): redundant, stream_opt.take() already left None */ 57 | Err(err) 58 | } 59 | Ok(Ok(_)) => { 60 | stream_opt.replace(stream); 61 | Ok(()) 62 | } 63 | } 64 | } 65 | } 66 | 67 | impl Record { 68 | pub fn new(path: String, value: u64, time: Option) -> Self { /* panics on paths that would break the space/newline-delimited plaintext line */ 69 | if !path.is_ascii() || path.contains(' ') || path.contains('\n') { 70 | panic!( 71 | "Graphite path contains space, line break, \ 72 | or non-ASCII characters."
73 | ); 74 | } 75 | Record { path, value, time } 76 | } 77 | 78 | fn write_paintext(&self, buf: &mut B) -> io::Result<()> /* one plaintext line: path, value, timestamp in seconds */ 79 | where 80 | B: Write, 81 | { 82 | let time = match self.time { 83 | None => -1.0, /* NOTE(review): presumably asks carbon to use its receive time; confirm against the Graphite docs */ 84 | Some(time) => { 85 | let t = time 86 | .duration_since(SystemTime::UNIX_EPOCH) 87 | .expect("SystemTime before UNIX EPOCH"); 88 | (t.as_secs() as f64) + (t.subsec_micros() as f64) / 1_000_000.0 89 | } 90 | }; 91 | writeln!(buf, "{} {} {}", self.path, self.value, time) 92 | } 93 | } 94 | -------------------------------------------------------------------------------- /src/monitor/mod.rs: -------------------------------------------------------------------------------- 1 | mod graphite; 2 | #[cfg(feature = "score_script")] 3 | use rlua::prelude::*; 4 | mod alive_test; 5 | mod traffic; 6 | use parking_lot::Mutex; 7 | use rand::{self, Rng}; 8 | use std::{ 9 | self, 10 | collections::{HashMap, HashSet}, 11 | io, 12 | net::SocketAddr, 13 | sync::Arc, 14 | time::{Duration, SystemTime}, 15 | }; 16 | #[cfg(feature = "score_script")] 17 | use std::{fs::File, io::Read, path::Path}; 18 | use tokio::time::{interval_at, Instant}; 19 | use tracing::{debug, instrument, warn}; 20 | 21 | pub use self::traffic::Throughput; 22 | use self::{ 23 | graphite::{Graphite, Record}, 24 | traffic::Meter, 25 | }; 26 | #[cfg(all(feature = "systemd", target_os = "linux"))] 27 | use crate::linux::systemd; 28 | use crate::proxy::ProxyServer; 29 | 30 | static THROUGHPUT_INTERVAL_SECS: u64 = 1; 31 | 32 | pub type ServerList = Vec>; 33 | 34 | #[derive(Clone)] 35 | pub struct Monitor { /* cheap to clone: servers and meters are Arc-shared, so clones observe the same state */ 36 | servers: Arc>, 37 | meters: Arc, Meter>>>, 38 | graphite: Option, 39 | #[cfg(feature = "score_script")] 40 | lua: Option>>, 41 | } 42 | 43 | impl Monitor { 44 | pub fn new(servers: Vec>, graphite: Option) -> Monitor { 45 | let meters = servers 46 | .iter() 47 | .map(|server| (server.clone(), Meter::new())) 48 | .collect(); 49 | #[cfg(all(feature = "systemd", target_os = "linux"))] 50 | systemd::set_status(
51 | format!( 52 | "serving ({} upstream {}, state unknown)", 53 | servers.len(), 54 | if servers.len() != 1 { /* plural also for zero servers */ 55 | "proxies" 56 | } else { 57 | "proxy" 58 | } 59 | ) 60 | .into(), 61 | ); 62 | Monitor { 63 | servers: Arc::new(Mutex::new(servers)), 64 | meters: Arc::new(Mutex::new(meters)), 65 | graphite, 66 | #[cfg(feature = "score_script")] 67 | lua: None, 68 | } 69 | } 70 | 71 | #[cfg(feature = "score_script")] 72 | pub fn load_score_script>(&mut self, path: T) -> anyhow::Result<()> { /* read at most 2^26 bytes (64 MiB) of Lua; the script must define a global function calc_score */ 73 | use anyhow::{bail, Context}; 74 | 75 | let mut buf = Vec::new(); 76 | File::open(path)?.take(2u64.pow(26)).read_to_end(&mut buf)?; 77 | 78 | let lua = Lua::new(); 79 | lua.context(|ctx| -> anyhow::Result<()> { 80 | let globals = ctx.globals(); 81 | ctx.load(&buf).exec().context("failed to load Lua script")?; 82 | if !globals.contains_key("calc_score")? { 83 | bail!("calc_score() not found in Lua globals"); 84 | } 85 | let _: LuaFunction = match globals.get("calc_score") { 86 | Err(LuaError::FromLuaConversionError { .. }) => { 87 | bail!("calc_score is not a function"); 88 | } 89 | other => other, 90 | }?; 91 | Ok(()) 92 | })?; 93 | 94 | self.lua.replace(Arc::new(Mutex::new(lua))); 95 | Ok(()) 96 | } 97 | 98 | /// Return a snapshot of the servers, in the order of the last re-sort. 99 | pub fn servers(&self) -> ServerList { 100 | self.servers.lock().clone() 101 | } 102 | 103 | /// Replace internal servers with provided list. 104 | pub fn update_servers(&self, new_servers: Vec>) { 105 | let oldset: HashSet<_> = self.servers().into_iter().collect(); 106 | let newset = HashSet::from_iter(new_servers); 107 | let mut new_servers = Vec::with_capacity(newset.len()); 108 | 109 | // Copy config from new server objects to old ones. 110 | // That also ensures their `status` remains unchanged. 111 |
111 | for server in oldset.intersection(&newset) { 112 | let old = oldset.get(server).unwrap(); /* server comes from the intersection, so it is in both sets */ 113 | let new = newset.get(server).unwrap(); 114 | old.copy_config_from(new); 115 | new_servers.push(old.clone()); 116 | } 117 | 118 | // Add brand new server objects 119 | new_servers.extend(newset.difference(&oldset).cloned()); 120 | 121 | // Create new meters 122 | let mut meters = self.meters.lock(); 123 | meters.clear(); 124 | for server in new_servers.iter() { 125 | meters.insert(server.clone(), Meter::new()); 126 | } 127 | 128 | *self.servers.lock() = new_servers; 129 | self.resort(); 130 | } 131 | 132 | fn resort(&self) { /* Order servers by score (None sorts last as i32::MAX) minus a random jitter, so near-equal servers get shuffled. sort_by_cached_key draws the jitter once per server; sort_by_key re-runs the key closure on every comparison, which makes the ordering inconsistent (newer std sort implementations may panic on that). */ 133 | let mut rng = rand::thread_rng(); 134 | let mut servers = self.servers.lock(); 135 | servers.sort_by_cached_key(move |server| { 136 | server.score().unwrap_or(i32::MAX) - (rng.gen::() % 30) as i32 137 | }); 138 | debug!("scores:{}", info_stats(&servers)); 139 | } 140 | 141 | /// Start monitoring delays. 142 | /// Returned Future won't return unless error on timer. 143 | #[instrument(skip_all)] 144 | pub async fn monitor_delay(self, probe: u64) { 145 | let mut graphite = self.graphite.map(Graphite::new); 146 | let interval = Duration::from_secs(probe); 147 | 148 | alive_test::test_all(&self).await; 149 | 150 | let mut interval = interval_at(Instant::now() + interval, interval); 151 | loop { 152 | interval.tick().await; 153 | alive_test::test_all(&self).await; 154 | if let Some(ref mut graphite) = graphite { 155 | match send_metrics(&self, graphite).await { 156 | Ok(_) => debug!("metrics sent"), 157 | Err(e) => warn!("fail to send metrics {:?}", e), 158 | } 159 | } 160 | } 161 | } 162 | 163 | /// Start monitoring throughput. 164 | /// Returned Future won't return unless error on timer.
165 | pub async fn monitor_throughput(self) { 166 | let interval = Duration::from_secs(THROUGHPUT_INTERVAL_SECS); 167 | let mut interval = interval_at(Instant::now() + interval, interval); 168 | loop { 169 | interval.tick().await; 170 | for (server, meter) in self.meters.lock().iter_mut() { /* record one traffic sample per server per tick */ 171 | meter.add_sample(server.traffic()); 172 | } 173 | } 174 | } 175 | 176 | /// Return average throughputs of all servers in the recent monitor 177 | /// period. The `monitor_throughput()` task should be running before calling this. 178 | pub fn throughputs(&self) -> HashMap, Throughput> { 179 | self.meters 180 | .lock() 181 | .iter() 182 | .map(|(server, meter)| (server.clone(), meter.throughput(server.traffic()))) 183 | .collect() 184 | } 185 | } 186 | 187 | fn info_stats(infos: &[Arc]) -> String { /* debug summary of the first five servers: tag/score, or tag: -- when unscored */ 188 | let mut stats = String::new(); 189 | for info in infos.iter().take(5) { 190 | stats += &match info.score() { 191 | None => format!(" {}: --,", info.tag), 192 | Some(t) => format!(" {}/{},", info.tag, t), 193 | }; 194 | } 195 | stats.pop(); /* drop the trailing comma */ 196 | stats 197 | } 198 | 199 | // Send graphite metrics (only called when a graphite server is configured) 200 | #[instrument(skip_all)] 201 | async fn send_metrics(monitor: &Monitor, graphite: &mut Graphite) -> io::Result<()> { 202 | let records = monitor 203 | .servers() 204 | .iter() 205 | .flat_map(|server| { 206 | let now = Some(SystemTime::now()); /* sampled per server, so timestamps within one batch may differ slightly */ 207 | let r = |path, value| Record::new(server.graphite_path(path), value, now); 208 | let status = server.status_snapshot(); 209 | let traffic = server.traffic(); 210 | vec![ 211 | status.delay.map(|t| r("delay", t.as_millis() as u64)), 212 | status.score.map(|s| r("score", s as u64)), /* NOTE(review): if scores can be negative, this cast wraps to a huge u64 */ 213 | Some(r("tx_bytes", traffic.tx_bytes as u64)), 214 | Some(r("rx_bytes", traffic.rx_bytes as u64)), 215 | Some(r("conns.total", status.conn_total as u64)), 216 | Some(r("conns.alive", status.conn_alive as u64)), 217 | Some(r("conns.error", status.conn_error as u64)), 218 | ] 219 | }) 220 | .flatten() 221 | .collect(); // FIXME: avoid allocate large memory
222 | graphite.write_records(records).await 223 | } 224 | -------------------------------------------------------------------------------- /src/monitor/traffic.rs: -------------------------------------------------------------------------------- 1 | use crate::proxy::Traffic; 2 | use serde_derive::Serialize; 3 | use std::{collections::VecDeque, ops::Add, time::Instant}; 4 | 5 | /// Monitor & calculate throughput using traffic samples. 6 | #[derive(Debug)] 7 | pub struct Meter { 8 | samples: VecDeque, 9 | } 10 | 11 | #[derive(Debug)] 12 | pub struct TrafficSample { 13 | time: Instant, 14 | amt: Traffic, 15 | } 16 | 17 | #[derive(Clone, Copy, Debug, Default, Serialize)] 18 | pub struct Throughput { 19 | pub tx_bps: usize, 20 | pub rx_bps: usize, 21 | } 22 | 23 | impl From for TrafficSample { 24 | fn from(traffic: Traffic) -> Self { 25 | Self { 26 | time: Instant::now(), 27 | amt: traffic, 28 | } 29 | } 30 | } 31 | 32 | impl Throughput { 33 | fn from_samples(t0: &TrafficSample, t1: &TrafficSample) -> Self { /* bits per second between two samples of cumulative byte counters */ 34 | let t = t1.time - t0.time; /* current std saturates this to zero if t1 precedes t0 */ 35 | let t = t.as_secs() as f64 + t.subsec_nanos() as f64 / 1e9; 36 | let f = |x0, x1| (((x1 - x0) as f64) / t * 8.0).round() as usize; /* NOTE(review): t == 0 would divide by zero, and a decreasing counter would underflow */ 37 | Throughput { 38 | tx_bps: f(t0.amt.tx_bytes, t1.amt.tx_bytes), 39 | rx_bps: f(t0.amt.rx_bytes, t1.amt.rx_bytes), 40 | } 41 | } 42 | } 43 | 44 | impl Add for Throughput { 45 | type Output = Throughput; 46 | 47 | fn add(self, other: Self) -> Self { 48 | Throughput { 49 | tx_bps: self.tx_bps + other.tx_bps, 50 | rx_bps: self.rx_bps + other.rx_bps, 51 | } 52 | } 53 | } 54 | 55 | impl Meter { 56 | pub fn new() -> Self { 57 | Meter { 58 | samples: VecDeque::with_capacity(2), 59 | } 60 | } 61 | 62 | pub fn add_sample(&mut self, sample: T) 63 | where 64 | T: Into, 65 | { 66 | self.samples.truncate(1); /* keep only the previous newest; after the push: front = newest, back = previous */ 67 | self.samples.push_front(sample.into()); 68 | } 69 | 70 | pub fn throughput(&self, sample: T) -> Throughput 71 | where 72 | T: Into, 73 | { 74 | let current = sample.into(); 75 | if let Some(oldest) =
self.samples.back() { 76 | Throughput::from_samples(oldest, ¤t) 77 | } else { 78 | Default::default() 79 | } 80 | } 81 | } 82 | -------------------------------------------------------------------------------- /src/policy/capabilities.rs: -------------------------------------------------------------------------------- 1 | use std::{fmt::Display, mem}; 2 | 3 | use flexstr::SharedStr; 4 | use serde::Serialize; 5 | 6 | #[derive(Debug, Clone, Eq, PartialEq, Hash, Default, Serialize, PartialOrd, Ord)] 7 | pub struct CapSet(Box<[SharedStr]>); 8 | 9 | impl CapSet { 10 | pub fn new(caps: I) -> Self 11 | where 12 | I: Iterator, 13 | I::Item: Into, 14 | { 15 | let mut caps: Vec<_> = caps.map(|s| s.into()).collect(); 16 | caps.sort(); 17 | Self(caps.into()) 18 | } 19 | 20 | pub fn has_intersection(&self, other: &Self) -> bool { 21 | let mut a = &self.0[..]; 22 | let mut b = &other.0[..]; 23 | if a.len() < b.len() { 24 | mem::swap(&mut a, &mut b); 25 | } 26 | while !(a.is_empty() || b.is_empty()) { 27 | match a.binary_search(&b[0]) { 28 | Ok(_) => return true, 29 | Err(n) => { 30 | a = &a[n..]; 31 | b = &b[1..]; 32 | } 33 | } 34 | } 35 | false 36 | } 37 | 38 | pub fn is_empty(&self) -> bool { 39 | self.0.is_empty() 40 | } 41 | } 42 | 43 | impl Display for CapSet { 44 | fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 45 | if self.0.len() > 1 { 46 | write!(f, "(")?; 47 | } 48 | match self.0.first() { 49 | Some(cap) => write!(f, "{}", cap), 50 | None => write!(f, "(EMPTY)"), 51 | }?; 52 | for cap in self.0.iter().skip(1) { 53 | write!(f, " OR {}", cap)?; 54 | } 55 | if self.0.len() > 1 { 56 | write!(f, ")")?; 57 | } 58 | Ok(()) 59 | } 60 | } 61 | 62 | pub trait CheckAllCapsMeet { 63 | fn all_meet_by(self, caps: &CapSet) -> bool; 64 | } 65 | 66 | impl<'a, T> CheckAllCapsMeet for T 67 | where 68 | T: IntoIterator, 69 | { 70 | fn all_meet_by(self, caps: &CapSet) -> bool { 71 | self.into_iter().all(|req| req.has_intersection(caps)) 72 | } 73 | } 74 | 75 | #[test] 
76 | fn test_capset_intersection() { 77 | let abc = CapSet::new(["a", "b", "c"].into_iter()); 78 | let def = CapSet::new(["d", "e", "f"].into_iter()); 79 | let bcg = CapSet::new(["b", "c", "g"].into_iter()); 80 | let aeg = CapSet::new(["a", "e", "g"].into_iter()); 81 | assert!(!abc.has_intersection(&def)); 82 | assert!(!def.has_intersection(&abc)); 83 | assert!(!def.has_intersection(&bcg)); 84 | assert!(!bcg.has_intersection(&def)); 85 | assert!(def.has_intersection(&aeg)); 86 | assert!(aeg.has_intersection(&def)); 87 | assert!(abc.has_intersection(&aeg)); 88 | } 89 | 90 | #[test] 91 | fn test_capset_display() { 92 | assert_eq!("(EMPTY)", CapSet::default().to_string()); 93 | assert_eq!("a", CapSet::new(["a"].into_iter()).to_string()); 94 | assert_eq!("(a OR b)", CapSet::new(["a", "b"].into_iter()).to_string()); 95 | } 96 | -------------------------------------------------------------------------------- /src/policy/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod capabilities; 2 | pub mod parser; 3 | 4 | use std::{ 5 | collections::{HashMap, HashSet}, 6 | fmt::Display, 7 | fs::File, 8 | hash::Hash, 9 | io::{self, BufRead, BufReader}, 10 | net::{IpAddr, Ipv4Addr, Ipv6Addr}, 11 | path::Path, 12 | }; 13 | 14 | use flexstr::{SharedStr, ToSharedStr}; 15 | 16 | use capabilities::CapSet; 17 | use ip_network_table_deps_treebitmap::{address::Address, IpLookupTable}; 18 | use tracing::info; 19 | 20 | use self::parser::{Filter, Rule}; 21 | 22 | #[derive(Debug, Clone, Default, PartialEq, Eq)] 23 | pub struct Action { /* a policy decision plus its priority; combined with other matches via extend */ 24 | priority: u8, 25 | pub action: ActionType, 26 | } 27 | 28 | #[derive(Debug, Clone, PartialEq, Eq)] 29 | pub enum ActionType { 30 | Require(HashSet), /* every cap set must be met (AND across sets, OR within a set) */ 31 | Direct, 32 | Reject, 33 | } 34 | 35 | impl Default for ActionType { 36 | fn default() -> Self { 37 | Self::Require(Default::default()) 38 | } 39 | } 40 | 41 | impl ActionType { 42 | fn wrap(self, priority: u8) -> Action { 43 | Action { 44 | priority,
45 | action: self, 46 | } 47 | } 48 | } 49 | 50 | impl From for Action { 51 | fn from(action: ActionType) -> Self { 52 | Self { 53 | priority: 0, 54 | action, 55 | } 56 | } 57 | } 58 | 59 | impl Action { 60 | fn len(&self) -> usize { 61 | match &self.action { 62 | ActionType::Direct | ActionType::Reject => 1, 63 | ActionType::Require(set) => set.len(), 64 | } 65 | } 66 | 67 | fn extend(&mut self, other: Self) { /* higher priority replaces; at equal priority a later DIRECT/REJECT replaces, a later REQUIRE adds caps (or replaces DIRECT/REJECT), so call order matters */ 68 | if self.priority < other.priority { 69 | *self = other; 70 | } else if self.priority == other.priority { 71 | match other.action { 72 | ActionType::Direct | ActionType::Reject => *self = other, 73 | ActionType::Require(new_caps) => { 74 | if let ActionType::Require(caps) = &mut self.action { 75 | caps.extend(new_caps.into_iter()) 76 | } else { 77 | self.action = ActionType::Require(new_caps) 78 | } 79 | } 80 | } 81 | } 82 | // Do nothing if self.priority > other.priority 83 | } 84 | } 85 | 86 | #[derive(Default)] 87 | struct RuleSet(HashMap); 88 | 89 | type ListenPortRuleSet = RuleSet; 90 | type DstDomainRuleSet = RuleSet; 91 | 92 | impl RuleSet { 93 | fn add(&mut self, key: K, action: Action) { 94 | // TODO: warning duplicated rules 95 | let value = self.0.entry(key).or_default(); 96 | value.extend(action) 97 | } 98 | 99 | fn get<'a>(&'a self, key: &'a K) -> impl Iterator { 100 | self.0.get(key).into_iter() 101 | } 102 | } 103 | 104 | impl DstDomainRuleSet { 105 | fn get_recursive<'a>(&'a self, name: &'a str) -> impl Iterator { /* yields actions for the root key, then each parent suffix, then the full name: least to most specific */ 106 | let name = name.trim_end_matches('.'); // Add back later 107 | let mut skip = name.len() + 1; // pretend ending with dot 108 | let parts = name.rsplit('.').map(move |part| { 109 | skip -= part.len() + 1; // +1 for the dot 110 | &name[skip..]
111 | }); 112 | ["."] // add back the dot 113 | .into_iter() 114 | .chain(parts) 115 | .filter_map(|key| self.0.get(key)) 116 | } 117 | } 118 | 119 | struct IpRuleSet(IpLookupTable); 120 | 121 | type Ipv4RuleSet = IpRuleSet; 122 | type Ipv6RuleSet = IpRuleSet; 123 | 124 | impl Default for IpRuleSet { 125 | fn default() -> Self { 126 | Self(IpLookupTable::new()) 127 | } 128 | } 129 | 130 | impl IpRuleSet { 131 | fn add(&mut self, net: (A, u8), action: Action) { /* merge into an identical existing prefix, otherwise insert a new one */ 132 | let (ip, len) = net; 133 | let len = len as u32; 134 | match self.0.exact_match_mut(ip, len) { 135 | Some(item) => item.extend(action), 136 | None => { 137 | self.0.insert(ip, len, action); 138 | } 139 | } 140 | } 141 | 142 | fn get<'a>(&'a self, ip: &'a A) -> impl Iterator { /* NOTE(review): extend is order-sensitive; test_policy_dst_ip expects the /32 rule to win, which relies on matches() yielding shorter prefixes first */ 143 | self.0.matches(*ip).map(|(_, _, action)| action) 144 | } 145 | 146 | fn actions(&self) -> impl Iterator { 147 | self.0.iter().map(|(_, _, action)| action) 148 | } 149 | } 150 | 151 | #[derive(Debug, Default, Clone)] 152 | pub struct RequestFeatures> { 153 | pub listen_port: Option, 154 | pub dst_ip: Option, 155 | pub dst_domain: Option, 156 | } 157 | 158 | #[derive(Default)] 159 | pub struct Policy { 160 | default_action: Action, 161 | listen_port_ruleset: ListenPortRuleSet, 162 | dst_ipv4_ruleset: Ipv4RuleSet, 163 | dst_ipv6_ruleset: Ipv6RuleSet, 164 | dst_domain_ruleset: DstDomainRuleSet, 165 | } 166 | 167 | impl Policy { 168 | pub fn load(read: R) -> io::Result { /* one rule per line; the first parse error aborts with InvalidData */ 169 | let mut router: Self = Default::default(); 170 | for line in read.lines() { 171 | match parser::line_no_ending(&line?)
{ 172 | Ok((_, None)) => (), 173 | Ok((_, Some(rule))) => router.add_rule(rule), 174 | Err(err) => return Err(io::Error::new(io::ErrorKind::InvalidData, err.to_owned())), 175 | } 176 | } 177 | Ok(router) 178 | } 179 | 180 | pub fn load_from_file>(path: T) -> io::Result { 181 | let file = File::open(path)?; 182 | let reader = BufReader::new(file); 183 | let this = Self::load(reader)?; 184 | info!("policy: {} rule(s) loaded", this.rule_count()); 185 | Ok(this) 186 | } 187 | 188 | fn add_rule(&mut self, rule: parser::Rule) { 189 | let Rule { filter, action } = rule; 190 | match filter { 191 | Filter::Default => self.default_action.extend(action), 192 | Filter::ListenPort(port) => { 193 | self.listen_port_ruleset.add(port, action); 194 | } 195 | Filter::DstSni(parts) => { 196 | self.dst_domain_ruleset.add(parts.to_shared_str(), action); 197 | } 198 | Filter::DstIp((IpAddr::V4(ip), len)) => { 199 | self.dst_ipv4_ruleset.add((ip, len), action); 200 | } 201 | Filter::DstIp((IpAddr::V6(ip), len)) => { 202 | self.dst_ipv6_ruleset.add((ip, len), action); 203 | } 204 | } 205 | } 206 | 207 | pub fn rule_count(&self) -> usize { /* counts merged entries: DIRECT/REJECT as 1, REQUIRE as its number of cap sets; the default action is not counted */ 208 | self.listen_port_ruleset 209 | .0 210 | .values() 211 | .chain(self.dst_domain_ruleset.0.values()) 212 | .chain(self.dst_ipv4_ruleset.actions()) 213 | .chain(self.dst_ipv6_ruleset.actions()) 214 | .fold(0, |acc, v| acc + v.len()) 215 | } 216 | 217 | pub fn matches>(&self, features: &RequestFeatures) -> Action { /* start from the default action, then extend with listen-port, destination-IP and destination-domain matches, in that order */ 218 | let mut action: Action = self.default_action.clone(); 219 | if let Some(port) = features.listen_port { 220 | self.listen_port_ruleset 221 | .get(&port) 222 | .for_each(|a| action.extend(a.clone())) 223 | } 224 | 225 | // Canonicalize IP address 226 | // Waiting for stabilization of IpAddr::to_canonical() (NOTE(review): believed stable since Rust 1.75; switch once the MSRV allows) 227 | let dst_ip = match features.dst_ip { 228 | None => None, 229 | ip @ Some(IpAddr::V4(_)) => ip, 230 | ip @ Some(IpAddr::V6(v6)) => match v6.to_ipv4_mapped() { 231 | None => ip, 232 | Some(v4) => Some(IpAddr::V4(v4)), 233 | }, 234 |
}; 235 | if let Some(IpAddr::V4(ip)) = dst_ip { 236 | self.dst_ipv4_ruleset 237 | .get(&ip) 238 | .for_each(|a| action.extend(a.clone())) 239 | } 240 | if let Some(IpAddr::V6(ip)) = dst_ip { 241 | self.dst_ipv6_ruleset 242 | .get(&ip) 243 | .for_each(|a| action.extend(a.clone())) 244 | } 245 | 246 | if let Some(name) = &features.dst_domain { 247 | self.dst_domain_ruleset 248 | .get_recursive(name.as_ref()) 249 | .for_each(|a| action.extend(a.clone())); 250 | } 251 | action 252 | } 253 | } 254 | 255 | impl Display for Action { 256 | fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 257 | match &self.action { 258 | ActionType::Direct => write!(f, "DIRECT"), 259 | ActionType::Reject => write!(f, "REJECT"), 260 | ActionType::Require(_) => write!(f, "REQUIRE"), 261 | }?; 262 | for _ in 0..self.priority { 263 | write!(f, "!")?; 264 | } 265 | if let ActionType::Require(ref caps) = self.action { 266 | let mut caps = Vec::from_iter(caps); 267 | caps.sort_unstable(); 268 | match caps.first() { 269 | Some(cap) => write!(f, " {}", cap)?, 270 | None => write!(f, " NOTHING")?, 271 | } 272 | for cap in caps.iter().skip(1) { 273 | write!(f, " AND {}", cap)?; 274 | } 275 | } 276 | Ok(()) 277 | } 278 | } 279 | 280 | #[test] 281 | fn test_policy_listen_port() { 282 | use capabilities::CheckAllCapsMeet; 283 | 284 | let rules = " 285 | listen port 1 require a 286 | listen port 2 require b 287 | listen port 2 require c or d 288 | "; 289 | let policy = Policy::load(rules.as_bytes()).unwrap(); 290 | assert_eq!(3, policy.rule_count()); 291 | let mut features: RequestFeatures<&'static str> = Default::default(); 292 | features.listen_port = Some(1); 293 | let p1 = match policy.matches(&features).action { 294 | ActionType::Require(a) => a, 295 | _ => panic!(), 296 | }; 297 | features.listen_port = Some(2); 298 | let p2 = match policy.matches(&features).action { 299 | ActionType::Require(a) => a, 300 | _ => panic!(), 301 | }; 302 | let abc = CapSet::new(["a", "b", 
"c"].into_iter()); 303 | let bc = CapSet::new(["b", "c"].into_iter()); 304 | let c = CapSet::new(["c"].into_iter()); 305 | assert!(p1.all_meet_by(&abc)); 306 | assert!(!p1.all_meet_by(&bc)); 307 | assert!(!p1.all_meet_by(&c)); 308 | assert!(p2.all_meet_by(&abc)); 309 | assert!(p2.all_meet_by(&bc)); 310 | assert!(!p2.all_meet_by(&c)); 311 | } 312 | 313 | #[test] 314 | fn test_policy_dst_ip() { 315 | use std::str::FromStr; 316 | 317 | let rules = " 318 | dst ip 0.0.0.0/0 require a4 319 | dst ip ::1 require a6 320 | dst ip 127.0.0.0/8 require b4 321 | dst ip 127.0.0.1 direct 322 | "; 323 | let policy = Policy::load(rules.as_bytes()).unwrap(); 324 | assert_eq!(4, policy.rule_count()); 325 | 326 | let mut features: RequestFeatures<&str> = Default::default(); 327 | 328 | // Match 0.0.0.0/0 329 | features.dst_ip = IpAddr::from_str("1.2.3.4").ok(); 330 | let action = policy.matches(&features).action; 331 | assert!(matches!(action, ActionType::Require(a) if a.len() == 1)); 332 | 333 | // Match 0.0.0.0/0 & 127.0.0.0/8 334 | features.dst_ip = IpAddr::from_str("127.1.1.1").ok(); 335 | let action1 = policy.matches(&features).action; 336 | features.dst_ip = IpAddr::from_str("::ffff:127.1.1.1").ok(); 337 | let action2 = policy.matches(&features).action; 338 | assert_eq!(action1, action2); 339 | assert!(matches!(action1, ActionType::Require(a) if a.len() == 2)); 340 | 341 | // Match 0.0.0.0/0 & 127.0.0.0/8, then override by 127.0.0.1/32 DIRECT 342 | features.dst_ip = IpAddr::from_str("127.0.0.1").ok(); 343 | let action = policy.matches(&features).action; 344 | assert!(matches!(action, ActionType::Direct)); 345 | 346 | // Match ::1/128 347 | features.dst_ip = IpAddr::from_str("::1").ok(); 348 | let action = policy.matches(&features).action; 349 | assert!(matches!(action, ActionType::Require(a) if a.len() == 1)); 350 | } 351 | 352 | #[test] 353 | fn test_policy_get_domain_caps_requirements() { 354 | let policy = Policy::load( 355 | " 356 | dst domain . 
require root 357 | dst domain com. require com 358 | dst domain example.com require example 359 | " 360 | .as_bytes(), 361 | ) 362 | .unwrap(); 363 | let set = policy.dst_domain_ruleset; 364 | assert_eq!(3, set.get_recursive("test.example.com").count()); 365 | assert_eq!(3, set.get_recursive("example.com").count()); 366 | assert_eq!(2, set.get_recursive("com").count()); 367 | assert_eq!(1, set.get_recursive("net").count()); 368 | } 369 | 370 | #[test] 371 | fn test_policy_action() { 372 | let rules = " 373 | default require def 374 | listen port 1 require a 375 | listen port 2 direct 376 | dst domain test require c 377 | dst domain d.test direct 378 | "; 379 | let policy = Policy::load(rules.as_bytes()).unwrap(); 380 | // listen-port/direct override default/require 381 | let direct1 = policy.matches(&RequestFeatures { 382 | listen_port: Some(2), 383 | dst_domain: Some("abcd"), 384 | ..Default::default() 385 | }); 386 | assert!(matches!(direct1.action, ActionType::Direct)); 387 | // d.test/direct override others 388 | let direct2 = policy.matches(&RequestFeatures { 389 | listen_port: Some(1), 390 | dst_domain: Some("a.d.test"), 391 | ..Default::default() 392 | }); 393 | assert!(matches!(direct2.action, ActionType::Direct)); 394 | // just default/require 395 | let require1 = policy.matches(&RequestFeatures { 396 | listen_port: Some(3), 397 | dst_domain: Some("abcd"), 398 | ..Default::default() 399 | }); 400 | assert!(matches!(require1.action, ActionType::Require(a) if a.len() == 1)); 401 | // default/require + dst-domain/require 402 | let require2 = policy.matches(&RequestFeatures { 403 | dst_domain: Some("test"), 404 | ..Default::default() 405 | }); 406 | assert!(matches!(require2.action, ActionType::Require(a) if a.len() == 2)); 407 | // default/require + dst-domain/require + listen-port/require 408 | let require3 = policy.matches(&RequestFeatures { 409 | listen_port: Some(1), 410 | dst_domain: Some("test"), 411 | ..Default::default() 412 | }); 413 | 
assert!(matches!(require3.action, ActionType::Require(a) if a.len() == 3)); 414 | } 415 | 416 | #[test] 417 | fn test_policy_action_priority() { 418 | let rules = " 419 | default require! def 420 | listen port 1 reject # always ignore 421 | dst domain a require! a 422 | dst domain a.a reject! # same-level override 423 | dst domain a.a.a require!! aaa # level override 424 | dst domain a.a.a.a require!! aaaa #same-level append 425 | "; 426 | let policy = Policy::load(rules.as_bytes()).unwrap(); 427 | let mut features: RequestFeatures<&'static str> = Default::default(); 428 | 429 | features.listen_port = Some(10); 430 | let def = policy.matches(&features); 431 | assert!(matches!(&def.action, ActionType::Require(a) if a.len() == 1)); 432 | assert_eq!(1, def.priority); 433 | 434 | features.listen_port = Some(1); 435 | let action = policy.matches(&features); 436 | assert_eq!(def, action); 437 | 438 | features.dst_domain = Some("a.a"); 439 | let action = policy.matches(&features); 440 | assert!(matches!(&action.action, ActionType::Reject)); 441 | 442 | features.dst_domain = Some("a.a.a"); 443 | let action = policy.matches(&features); 444 | assert!(matches!(&action.action, ActionType::Require(a) if a.len() == 1)); 445 | assert_eq!(2, action.priority); 446 | 447 | features.dst_domain = Some("a.a.a.a"); 448 | let action = policy.matches(&features); 449 | assert!(matches!(&action.action, ActionType::Require(a) if a.len() == 2)); 450 | assert_eq!(2, action.priority); 451 | } 452 | 453 | #[test] 454 | fn test_action_type_display() { 455 | assert_eq!( 456 | "DIRECT", 457 | Action { 458 | action: ActionType::Direct, 459 | priority: 0 460 | } 461 | .to_string() 462 | ); 463 | assert_eq!( 464 | "REJECT!!", 465 | Action { 466 | action: ActionType::Reject, 467 | priority: 2 468 | } 469 | .to_string() 470 | ); 471 | assert_eq!("REQUIRE NOTHING", Action::default().to_string()); 472 | let caps = HashSet::from_iter(vec![ 473 | CapSet::new(["a"].into_iter()), 474 | CapSet::new(["b", 
"c"].into_iter()), 475 | ]); 476 | let action = ActionType::Require(caps); 477 | assert_eq!( 478 | "REQUIRE! a AND (b OR c)", 479 | Action { 480 | action, 481 | priority: 1 482 | } 483 | .to_string() 484 | ); 485 | } 486 | -------------------------------------------------------------------------------- /src/policy/parser.rs: -------------------------------------------------------------------------------- 1 | use std::{ 2 | collections::HashSet, 3 | net::{IpAddr, Ipv4Addr, Ipv6Addr}, 4 | str::FromStr, 5 | }; 6 | 7 | use flexstr::{shared_str, SharedStr, ToCase}; 8 | use nom::{ 9 | branch::alt, 10 | bytes::complete::{tag, tag_no_case, take_till1}, 11 | character::complete::{char, hex_digit1, not_line_ending, space0, space1, u16, u8}, 12 | combinator::{eof, fail, opt, recognize, verify}, 13 | multi::{many0_count, many1, many_m_n, separated_list0, separated_list1}, 14 | sequence::tuple, 15 | IResult, Parser, 16 | }; 17 | 18 | use super::{capabilities::CapSet, Action, ActionType}; 19 | 20 | #[derive(Debug, PartialEq, Eq)] 21 | pub enum Filter { 22 | Default, 23 | ListenPort(u16), 24 | DstSni(SharedStr), 25 | DstIp((IpAddr, u8)), 26 | } 27 | 28 | #[derive(Debug, PartialEq, Eq)] 29 | pub struct Rule { 30 | pub filter: Filter, 31 | pub action: Action, 32 | } 33 | 34 | fn port_number(input: &str) -> IResult<&str, u16> { 35 | verify(u16, |&n| n != 0)(input) 36 | } 37 | 38 | fn ipv4_addr(input: &str) -> IResult<&str, Ipv4Addr> { 39 | tuple((u8, tag("."), u8, tag("."), u8, tag("."), u8)) 40 | .map(|(a, _, b, _, c, _, d)| Ipv4Addr::new(a, b, c, d)) 41 | .parse(input) 42 | } 43 | 44 | fn ipv6_addr(input: &str) -> IResult<&str, Ipv6Addr> { 45 | let (rem, str) = 46 | recognize(separated_list1(tag(":"), many_m_n(0, 4, hex_digit1))).parse(input)?; 47 | match Ipv6Addr::from_str(str) { 48 | Ok(addr) => Ok((rem, addr)), 49 | Err(_) => fail(str), 50 | } 51 | } 52 | 53 | fn ip_addr_prefix_len(input: &str) -> IResult<&str, (IpAddr, u8)> { 54 | let v4 = ipv4_addr.map(IpAddr::V4); 55 | let 
v6 = ipv6_addr.map(IpAddr::V6); 56 | let prefix_len = tuple((tag("/"), u8)).map(|(_, n)| n); 57 | let (rem, (ip, len)) = tuple((alt((v4, v6)), opt(prefix_len))).parse(input)?; 58 | let len = len.unwrap_or(if ip.is_ipv4() { 32 } else { 128 }); // no /len means a single host 59 | match ip { 60 | IpAddr::V4(_) if len > 32 => fail(input), 61 | IpAddr::V6(_) if len > 128 => fail(input), 62 | _ => Ok((rem, (ip, len))), 63 | } 64 | } 65 | 66 | fn id_chars(input: &str) -> IResult<&str, &str> { // one or more alphanumeric (Unicode-aware), '-' or '_' characters 67 | take_till1(|c: char| !c.is_alphanumeric() && c != '-' && c != '_')(input) 68 | } 69 | 70 | fn domain_name_part(input: &str) -> IResult<&str, SharedStr> { // one label plus an optional trailing '.' 71 | tuple((id_chars, opt(char('.')))) 72 | .map(|(name, _)| name.into()) 73 | .parse(input) 74 | } 75 | 76 | fn domain_name_root(input: &str) -> IResult<&str, ()> { // the bare root '.' 77 | char('.').map(|_| ()).parse(input) 78 | } 79 | 80 | fn domain_name(input: &str) -> IResult<&str, SharedStr> { // lower-cased name with any trailing '.' removed, or "." for the root 81 | alt(( 82 | recognize(many1(domain_name_part)).map(|n| { 83 | match n.strip_suffix('.') { 84 | Some(n) => SharedStr::from(n), 85 | None => n.into(), 86 | } 87 | .to_lower() 88 | }), 89 | domain_name_root.map(|_| shared_str!(".")), 90 | ))(input) 91 | } 92 | 93 | fn filter_dst_ip(input: &str) -> IResult<&str, Filter> { // dst ip ADDR[/LEN] 94 | tuple((tag_no_case("dst ip"), space1, ip_addr_prefix_len)) 95 | .map(|(_, _, net)| Filter::DstIp(net)) 96 | .parse(input) 97 | } 98 | 99 | fn filter_dst_domain(input: &str) -> IResult<&str, Filter> { // dst domain NAME 100 | tuple((tag_no_case("dst domain"), space1, domain_name)) 101 | .map(|(_, _, parts)| Filter::DstSni(parts)) 102 | .parse(input) 103 | } 104 | 105 | fn filter_listen_port(input: &str) -> IResult<&str, Filter> { // listen port N 106 | tuple((tag_no_case("listen port"), space1, port_number)) 107 | .map(|(_, _, n)| Filter::ListenPort(n)) 108 | .parse(input) 109 | } 110 | 111 | fn filter_default(input: &str) -> IResult<&str, Filter> { 112 | tag_no_case("default").map(|_| Filter::Default).parse(input) 113 | } 114 | 115 | fn rule_filter(input: &str) -> IResult<&str, Filter> { // any supported filter; keywords are matched case-insensitively 116 |
alt(( 117 | filter_dst_ip, 118 | filter_dst_domain, 119 | filter_listen_port, 120 | filter_default, 121 | ))(input) 122 | } 123 | 124 | fn cap_name(input: &str) -> IResult<&str, SharedStr> { 125 | id_chars.map(SharedStr::from).parse(input) 126 | } 127 | 128 | fn caps1(input: &str) -> IResult<&str, Vec> { // `a or b ...`: the alternatives of one requirement 129 | separated_list1(tuple((space1, tag_no_case("or"), space1)), cap_name)(input) 130 | } 131 | 132 | fn action_priority(input: &str) -> IResult<&str, u8> { // 0 to 5 '!'s; the count is the priority 133 | verify(many0_count(tag("!")), |n| *n <= 5) 134 | .map(|n| n as u8) 135 | .parse(input) 136 | } 137 | 138 | fn action_require(input: &str) -> IResult<&str, Action> { // require[!...] CAPS; whitespace after the '!'s is mandatory 139 | tuple((tag_no_case("require"), action_priority, space1, caps1)) 140 | .map(|(_, priority, _, caps)| { 141 | let mut set = HashSet::new(); 142 | set.insert(CapSet::new(caps.into_iter())); 143 | ActionType::Require(set).wrap(priority) 144 | }) 145 | .parse(input) 146 | } 147 | 148 | fn action_direct(input: &str) -> IResult<&str, Action> { 149 | tuple((tag_no_case("direct"), action_priority)) 150 | .map(|(_, priority)| ActionType::Direct.wrap(priority)) 151 | .parse(input) 152 | } 153 | 154 | fn action_reject(input: &str) -> IResult<&str, Action> { 155 | tuple((tag_no_case("reject"), action_priority)) 156 | .map(|(_, priority)| ActionType::Reject.wrap(priority)) 157 | .parse(input) 158 | } 159 | 160 | fn rule_action(input: &str) -> IResult<&str, Action> { 161 | alt((action_require, action_direct, action_reject)).parse(input) 162 | } 163 | 164 | fn rule(input: &str) -> IResult<&str, Rule> { // FILTER ACTION, separated by whitespace 165 | tuple((rule_filter, space1, rule_action)) 166 | .map(|(filter, _, action)| Rule { filter, action }) 167 | .parse(input) 168 | } 169 | 170 | fn comment(input: &str) -> IResult<&str, ()> { // '#' up to the end of the line 171 | tuple((char('#'), not_line_ending)).map(|_| ()).parse(input) 172 | } 173 | 174 | pub fn capabilities(input: &str) -> IResult<&str, CapSet> { // names separated by ',' and/or spaces; may be empty 175 | separated_list0( 176 | alt((recognize(tuple((space0, tag(","), space0))), space1)), 177 | cap_name, 178 | ) 179
| .map(|caps| CapSet::new(caps.into_iter())) 180 | .parse(input) 181 | } 182 | 183 | pub fn line_no_ending(input: &str) -> IResult<&str, Option> { // one line without its newline: None for blank or comment-only, Some(rule) for exactly one rule 184 | alt(( 185 | tuple((space0, opt(comment), space0, eof)).map(|_| None), 186 | tuple((space0, rule, space0, opt(comment), eof)).map(|(_, rule, _, _, _)| Some(rule)), 187 | )) 188 | .parse(input) 189 | } 190 | 191 | #[test] 192 | fn test_parse_domain_name_root() { 193 | let (empty, parts) = domain_name(".").unwrap(); 194 | assert!(empty.is_empty()); 195 | assert_eq!(shared_str!("."), parts); 196 | } 197 | 198 | #[test] 199 | fn test_parse_domain_name() { 200 | let (rem, parts) = domain_name("Test_-123.Example.Com.\n").unwrap(); 201 | assert_eq!("\n", rem); 202 | assert_eq!(shared_str!("test_-123.example.com"), parts); 203 | 204 | let (rem, parts) = domain_name("example\n").unwrap(); 205 | assert_eq!("\n", rem); 206 | assert_eq!(shared_str!("example"), parts); 207 | } 208 | 209 | #[test] 210 | fn test_listen_port_filter() { 211 | let (rem, port) = filter_listen_port("listen port 1234\n").unwrap(); 212 | assert_eq!("\n", rem); 213 | assert_eq!(Filter::ListenPort(1234), port); 214 | } 215 | 216 | #[test] 217 | fn test_dst_domain_filter() { 218 | let (rem, parts) = filter_dst_domain("dst domain test\n").unwrap(); 219 | assert_eq!("\n", rem); 220 | assert_eq!(Filter::DstSni(shared_str!("test")), parts); 221 | } 222 | 223 | #[test] 224 | fn test_dst_ip_filter() { 225 | let (rem, filter) = filter_dst_ip("dst ip ::\n").unwrap(); 226 | assert_eq!("\n", rem); 227 | assert!(matches!(filter, Filter::DstIp((_, 128)))); 228 | } 229 | 230 | #[test] 231 | fn test_dst_default_filter() { 232 | let (rem, parts) = filter_default("default\n").unwrap(); 233 | assert_eq!("\n", rem); 234 | assert_eq!(Filter::Default, parts); 235 | } 236 | 237 | #[test] 238 | fn test_ipv4_addr() { 239 | let (_, ip) = ipv4_addr("255.0.1.2").unwrap(); 240 | assert_eq!(ip, Ipv4Addr::new(255, 0, 1, 2)); 241 | assert!(ipv4_addr("256.0.0.0").is_err()); 242 |
assert!(ipv4_addr("127.0.1").is_err()); 243 | assert!(ipv4_addr("0").is_err()); 244 | assert!(ipv4_addr("").is_err()); 245 | } 246 | 247 | #[test] 248 | fn test_ipv6_addr() { // parsed results must agree with std's Ipv6Addr::from_str 249 | let addrs = ["::", "::1", "1::", "1::1", "1:0:ffff:fff:ff:0f:0000:8"]; 250 | for addr in addrs { 251 | let (_, parsed) = ipv6_addr(addr).unwrap(); 252 | assert_eq!(parsed, Ipv6Addr::from_str(addr).unwrap()); 253 | } 254 | let addrs = ["", "0", "1::2::3", "g::0", "1:2:3:4:5:6:7:8:9"]; 255 | for addr in addrs { 256 | assert!(ipv6_addr(addr).is_err()); 257 | } 258 | } 259 | 260 | #[test] 261 | fn test_ip_addr_prefix_len() { 262 | let (_, (ip, len)) = ip_addr_prefix_len("0.0.0.0/0").unwrap(); 263 | assert_eq!(ip, IpAddr::from_str("0.0.0.0").unwrap()); 264 | assert_eq!(len, 0); 265 | let (_, (ip, len)) = ip_addr_prefix_len("127.0.0.1").unwrap(); 266 | assert_eq!(ip, IpAddr::from_str("127.0.0.1").unwrap()); 267 | assert_eq!(len, 32); 268 | let (_, (ip, len)) = ip_addr_prefix_len("::/0").unwrap(); 269 | assert_eq!(ip, IpAddr::from_str("::").unwrap()); 270 | assert_eq!(len, 0); 271 | let (_, (ip, len)) = ip_addr_prefix_len("::1").unwrap(); 272 | assert_eq!(ip, IpAddr::from_str("::1").unwrap()); 273 | assert_eq!(len, 128); 274 | 275 | assert!(ip_addr_prefix_len("0.0.0.0/33").is_err()); 276 | assert!(ip_addr_prefix_len("::/129").is_err()); 277 | } 278 | 279 | #[test] 280 | fn test_action() { 281 | let (rem, action) = rule_action("require a or b\n").unwrap(); 282 | assert_eq!("\n", rem); 283 | assert!(matches!( 284 | action.action, 285 | ActionType::Require(caps) if caps.iter().next().unwrap() == &CapSet::new(["a", "b"].into_iter()), 286 | )); 287 | let (_, action) = rule_action("direct\n").unwrap(); 288 | assert_eq!(ActionType::Direct, action.action); 289 | let (_, action) = rule_action("reject\n").unwrap(); 290 | assert_eq!(ActionType::Reject, action.action); 291 | } 292 | 293 | #[test] 294 | fn test_action_priority() { 295 | let (_, action) = rule_action("require a").unwrap(); 296 |
assert_eq!(0, action.priority); 297 | let (_, action) = rule_action("require! a").unwrap(); 298 | assert_eq!(1, action.priority); 299 | let (_, action) = rule_action("direct!!").unwrap(); 300 | assert_eq!(2, action.priority); 301 | let (_, action) = rule_action("reject!!!!!").unwrap(); 302 | assert_eq!(5, action.priority); 303 | assert!(rule_action("reject!!!!!!").is_err()); 304 | assert!(rule_action("require!!!1 a").is_err()); 305 | } 306 | 307 | #[test] 308 | fn test_rule() { 309 | let (_, result) = rule("listen port 1 require!! a\n").unwrap(); 310 | let mut set = HashSet::new(); 311 | set.insert(CapSet::new(["a"].into_iter())); 312 | assert_eq!( 313 | Rule { 314 | filter: Filter::ListenPort(1), 315 | action: ActionType::Require(set).wrap(2) 316 | }, 317 | result 318 | ); 319 | assert!(rule("default require!!!1 a\n").is_err()); 320 | assert!(rule("default require ! a\n").is_err()); 321 | } 322 | 323 | #[test] 324 | fn test_comment() { 325 | comment("# test\n").unwrap(); 326 | comment("#\n").unwrap(); 327 | } 328 | 329 | #[test] 330 | fn test_line_no_ending() { 331 | let (_, rules) = line_no_ending("").unwrap(); 332 | assert!(rules.is_none()); 333 | let (_, rules) = line_no_ending(" \t ").unwrap(); 334 | assert!(rules.is_none()); 335 | let (_, rules) = line_no_ending("#").unwrap(); 336 | assert!(rules.is_none()); 337 | let (_, rules) = line_no_ending(" # test ").unwrap(); 338 | assert!(rules.is_none()); 339 | let (_, rules) = line_no_ending("dst domain test require a #1").unwrap(); 340 | assert!(rules.is_some()); 341 | } 342 | 343 | #[test] 344 | fn test_line_no_ending_error() { 345 | assert!(line_no_ending("dst").is_err()); 346 | assert!(line_no_ending("dst domain test error").is_err()); 347 | assert!(line_no_ending("dst domain require a b").is_err()); 348 | } 349 | 350 | #[test] 351 | fn test_capabilities() { 352 | let (_, caps) = capabilities("a b c ").unwrap(); 353 | assert_eq!( 354 | CapSet::new(["a", "b", "c"].into_iter().map(SharedStr::from)), 355 | caps
356 | ); 357 | let (_, caps) = capabilities("a, b ,c,d,").unwrap(); 358 | assert_eq!( 359 | CapSet::new("abcd".chars().into_iter().map(SharedStr::from)), 360 | caps 361 | ); 362 | let (_, caps) = capabilities(" ").unwrap(); 363 | assert!(caps.is_empty()); 364 | } 365 | -------------------------------------------------------------------------------- /src/proxy/copy.rs: -------------------------------------------------------------------------------- 1 | use std::{ 2 | cell::RefCell, 3 | cmp, fmt, 4 | future::Future, 5 | io, 6 | ops::Neg, 7 | pin::Pin, 8 | sync::Arc, 9 | task::{Context, Poll}, 10 | thread_local, 11 | time::Duration, 12 | }; 13 | use tokio::{ 14 | io::{AsyncRead, AsyncWrite, ReadBuf}, 15 | net::TcpStream, 16 | time::{sleep, Instant, Sleep}, 17 | }; 18 | use tracing::{debug, trace}; 19 | 20 | use self::Side::{Left, Right}; 21 | use crate::proxy::{ProxyServer, Traffic}; 22 | 23 | #[derive(Debug, Clone)] 24 | enum Side { // Left = local side, Right = remote side (see Display) 25 | Left, 26 | Right, 27 | } 28 | 29 | impl fmt::Display for Side { 30 | fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { 31 | match *self { 32 | Left => write!(f, "local"), 33 | Right => write!(f, "remote"), 34 | } 35 | } 36 | } 37 | 38 | impl Neg for Side { // -side is the opposite side 39 | type Output = Side; 40 | 41 | fn neg(self) -> Side { 42 | match self { 43 | Left => Right, 44 | Right => Left, 45 | } 46 | } 47 | } 48 | 49 | macro_rules! try_poll { // like `?` for a Poll of io::Result: early-return on Pending or Err, otherwise yield the value 50 | ($expr:expr) => { 51 | match $expr { 52 | Poll::Pending => return Poll::Pending, 53 | Poll::Ready(Err(err)) => return Poll::Ready(Err(err)), 54 | Poll::Ready(Ok(v)) => v, 55 | } 56 | }; 57 | } 58 | 59 | // The number of shared buffers is fixed (equals to no. of CPU threads), 60 | // so we can use a larger one for better performance.
61 | const SHARED_BUF_SIZE: usize = 1024 * 64; 62 | const PRIVATE_BUF_SIZE: usize = 1024 * 8; 63 | 64 | thread_local!( 65 | static SHARED_BUFFER: RefCell<[u8; SHARED_BUF_SIZE]> = RefCell::new([0u8; SHARED_BUF_SIZE]); 66 | ); 67 | 68 | struct StreamWithBuffer { 69 | pub stream: TcpStream, 70 | buf: Option>, 71 | pos: usize, 72 | cap: usize, 73 | pub read_eof: bool, 74 | pub all_done: bool, 75 | } 76 | 77 | impl StreamWithBuffer { 78 | pub fn new(stream: TcpStream) -> Self { 79 | StreamWithBuffer { 80 | stream, 81 | buf: None, 82 | pos: 0, 83 | cap: 0, 84 | read_eof: false, 85 | all_done: false, 86 | } 87 | } 88 | 89 | pub fn is_empty(&self) -> bool { 90 | self.pos == self.cap 91 | } 92 | 93 | pub fn poll_read_to_buffer(&mut self, cx: &mut Context) -> Poll> { 94 | let mut read = |buf: &mut [u8]| { 95 | let stream = Pin::new(&mut self.stream); 96 | let mut buf = ReadBuf::new(buf); 97 | stream.poll_read(cx, &mut buf).map_ok(|_| { 98 | let n = buf.filled().len(); 99 | if n == 0 { 100 | self.read_eof = true; 101 | } else { 102 | self.pos = 0; 103 | self.cap = n; 104 | } 105 | trace!("{} bytes read", n); 106 | n 107 | }) 108 | }; 109 | 110 | if let Some(buf) = self.buf.as_deref_mut() { 111 | read(buf) 112 | } else { 113 | SHARED_BUFFER.with(|buf| read(&mut *buf.borrow_mut())) 114 | } 115 | } 116 | 117 | pub fn poll_write_buffer_to( 118 | &mut self, 119 | cx: &mut Context, 120 | writer: &mut TcpStream, 121 | ) -> Poll> { 122 | let writer = Pin::new(writer); 123 | 124 | let result = if let Some(buf) = self.buf.as_deref() { 125 | writer.poll_write(cx, &buf[self.pos..self.cap]) 126 | } else { 127 | SHARED_BUFFER.with(|buf| { 128 | let buf = &buf.borrow()[self.pos..self.cap]; 129 | match writer.poll_write(cx, buf) { 130 | Poll::Pending => { 131 | // Move remaining data to the private buffer 132 | let n = self.cap - self.pos; 133 | trace!("allocate private buffer for {} bytes", n); 134 | let mut shared = Vec::with_capacity(cmp::max(PRIVATE_BUF_SIZE, n)); 135 | 
shared.extend_from_slice(buf); 136 | self.pos = 0; 137 | self.cap = n; 138 | self.buf = Some(shared.into_boxed_slice()); 139 | Poll::Pending 140 | } 141 | any => any, 142 | } 143 | }) 144 | }; 145 | match result { 146 | Poll::Ready(Ok(0)) => Poll::Ready(Err(io::Error::new( 147 | io::ErrorKind::WriteZero, 148 | "write zero byte into writer", 149 | ))), 150 | Poll::Ready(Ok(n)) => { 151 | self.pos += n; 152 | trace!("{} bytes written", n); 153 | Poll::Ready(Ok(n)) 154 | } 155 | _ => result, 156 | } 157 | } 158 | 159 | pub fn shrink_private_buffer_if_need(&mut self) { 160 | assert!(self.is_empty()); 161 | if let Some(ref mut buf) = self.buf { 162 | if buf.len() > PRIVATE_BUF_SIZE { 163 | trace!( 164 | "shrink private buffer from {} to {} bytes", 165 | buf.len(), 166 | PRIVATE_BUF_SIZE 167 | ); 168 | *buf = vec![0; PRIVATE_BUF_SIZE].into_boxed_slice() 169 | } 170 | } 171 | } 172 | } 173 | 174 | // Pipe two TcpStream in both direction, 175 | // update traffic amount to ProxyServer on the fly. 176 | pub struct BiPipe { 177 | left: StreamWithBuffer, 178 | right: StreamWithBuffer, 179 | server: Arc, 180 | traffic: Traffic, 181 | half_close_deadline: Option>>, 182 | } 183 | 184 | // Half-closed connections will be forcibly closed if there is no traffic 185 | // after the following duration. 186 | const HALF_CLOSE_TIMEOUT: Duration = Duration::from_secs(60); 187 | 188 | pub fn pipe(left: TcpStream, right: TcpStream, server: Arc) -> BiPipe { 189 | let (left, right) = (StreamWithBuffer::new(left), StreamWithBuffer::new(right)); 190 | BiPipe { 191 | left, 192 | right, 193 | server, 194 | traffic: Default::default(), 195 | half_close_deadline: Default::default(), 196 | } 197 | } 198 | 199 | impl BiPipe { 200 | fn poll_one_side(&mut self, cx: &mut Context, side: Side) -> Poll> { 201 | let Self { 202 | ref mut left, 203 | ref mut right, 204 | ref mut server, 205 | ref mut traffic, 206 | .. 
207 | } = *self; 208 | let (reader, writer) = match side { 209 | Left => (left, right), 210 | Right => (right, left), 211 | }; 212 | loop { 213 | // read something if buffer is empty 214 | if reader.is_empty() && !reader.read_eof { 215 | let n = try_poll!(reader.poll_read_to_buffer(cx)); 216 | let amt = match side { // bytes read from Left count as tx, bytes read from Right as rx 217 | Left => (n, 0), 218 | Right => (0, n), 219 | } 220 | .into(); 221 | server.add_traffic(amt); 222 | *traffic += amt; 223 | } 224 | 225 | // write out if buffer is not empty 226 | while !reader.is_empty() { 227 | try_poll!(reader.poll_write_buffer_to(cx, &mut writer.stream)); 228 | } 229 | reader.shrink_private_buffer_if_need(); 230 | 231 | // flush and does half close if seen eof 232 | if reader.read_eof { 233 | // shutdown implies flush 234 | match Pin::new(&mut writer.stream).poll_shutdown(cx) { 235 | Poll::Pending => return Poll::Pending, 236 | Poll::Ready(Ok(_)) => (), 237 | Poll::Ready(Err(err)) => debug!("fail to shutdown: {}", err), 238 | } 239 | reader.all_done = true; 240 | return Poll::Ready(Ok(())); 241 | } 242 | } 243 | } 244 | } 245 | 246 | impl Future for BiPipe { // Ready(traffic) once both sides are done, or once one side is done and HALF_CLOSE_TIMEOUT passes without another poll (each poll re-arms the timer) 247 | type Output = io::Result; 248 | 249 | fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { 250 | if !self.left.all_done { 251 | trace!("(BiPipe) poll left"); 252 | if let Poll::Ready(Err(err)) = self.poll_one_side(cx, Left) { 253 | return Poll::Ready(Err(err)); 254 | } 255 | } 256 | if !self.right.all_done { 257 | trace!("(BiPipe) poll right"); 258 | if let Poll::Ready(Err(err)) = self.poll_one_side(cx, Right) { 259 | return Poll::Ready(Err(err)); 260 | } 261 | } 262 | match (self.left.all_done, self.right.all_done) { 263 | (true, true) => Poll::Ready(Ok(self.traffic)), 264 | (false, false) => Poll::Pending, 265 | _ => { 266 | // Half close 267 | match &mut self.half_close_deadline { 268 | None => { 269 | // Setup a deadline then poll it 270 | let mut deadline = Box::pin(sleep(HALF_CLOSE_TIMEOUT)); 271 | let _ = deadline.as_mut().poll(cx); // always return
pending 272 | self.half_close_deadline = Some(deadline); 273 | Poll::Pending 274 | } 275 | Some(deadline) if !deadline.is_elapsed() => { 276 | trace!("(BiPipe) reset half-close timer"); 277 | deadline.as_mut().reset(Instant::now() + HALF_CLOSE_TIMEOUT); 278 | Poll::Pending 279 | } 280 | Some(_) => { 281 | debug!("(BiPipe) half-close conn timed out"); 282 | Poll::Ready(Ok(self.traffic)) 283 | } 284 | } 285 | } 286 | } 287 | } 288 | } 289 | -------------------------------------------------------------------------------- /src/proxy/http.rs: -------------------------------------------------------------------------------- 1 | use base64::prelude::{Engine, BASE64_STANDARD}; 2 | use httparse::{Response, Status, EMPTY_HEADER}; 3 | use std::io::{self, ErrorKind}; 4 | use std::net::IpAddr; 5 | use tokio::{ 6 | io::{AsyncReadExt, AsyncWriteExt}, 7 | net::TcpStream, 8 | }; 9 | use tracing::{debug, instrument, trace}; 10 | 11 | use crate::proxy::{Address, Destination}; 12 | 13 | use super::UserPassAuthCredential; 14 | 15 | macro_rules!
ensure_200 { 16 | ($code:expr) => { 17 | if $code != 200 { 18 | return Err(io::Error::new( 19 | ErrorKind::Other, 20 | format!("proxy return error: {}", $code), 21 | )); 22 | } 23 | }; 24 | } 25 | 26 | const BUF_LEN: usize = 1024; 27 | 28 | #[instrument(name = "http_handshake", skip_all)] 29 | pub async fn handshake( 30 | stream: &mut TcpStream, 31 | addr: &Destination, 32 | data: Option, 33 | with_playload: bool, 34 | user_pass_auth: &Option, 35 | ) -> io::Result<()> 36 | where 37 | T: AsRef<[u8]> + 'static, 38 | { 39 | let mut buf = build_request(addr, user_pass_auth).into_bytes(); 40 | stream.write_all(&buf).await?; 41 | 42 | if with_playload { 43 | // violate the protocol but save latency 44 | if let Some(ref data) = data { 45 | stream.write_all(data.as_ref()).await?; 46 | } 47 | } 48 | 49 | // Parse HTTP response 50 | buf.clear(); 51 | let mut bytes_read = 0; 52 | let mut sink = [0u8; BUF_LEN]; 53 | loop { 54 | let mut headers = [EMPTY_HEADER; 16]; 55 | let mut response = Response::new(&mut headers); 56 | buf.resize(bytes_read + BUF_LEN, 0); 57 | let peek_len = stream.peek(&mut buf).await?; 58 | bytes_read += peek_len; 59 | trace!("bytes peek: {}", bytes_read); 60 | 61 | match response.parse(&buf[..bytes_read]) { 62 | Err(e) => return Err(io::Error::new(ErrorKind::Other, e)), 63 | Ok(Status::Partial) => { 64 | debug!("partial http reponse read; wait for more data"); 65 | if let Some(code) = response.code { 66 | ensure_200!(code); 67 | } 68 | if bytes_read > 64_000 { 69 | return Err(io::Error::new(ErrorKind::Other, "response too large")); 70 | } 71 | // Drop peeked data from socket buffer 72 | stream.read_exact(&mut sink[..peek_len]).await?; 73 | } 74 | Ok(Status::Complete(bytes_request)) => { 75 | trace!( 76 | "response {}, {} bytes", 77 | response.code.unwrap(), 78 | bytes_request 79 | ); 80 | ensure_200!(response.code.unwrap()); 81 | let len = peek_len - (bytes_read - bytes_request); 82 | stream.read_exact(&mut sink[..len]).await?; 83 | break; 84 | } 85 | } 
86 | } 87 | 88 | // Write out payload if exist 89 | if !with_playload { 90 | if let Some(ref data) = data { 91 | stream.write_all(data.as_ref()).await?; 92 | } 93 | } 94 | trace!("HTTP CONNECT handshaking done"); 95 | Ok(()) 96 | } 97 | 98 | fn build_request(addr: &Destination, user_pass_auth: &Option) -> String { 99 | let port = addr.port; 100 | let host = match addr.host { 101 | Address::Ip(ip) => match ip { 102 | IpAddr::V4(ip) => format!("{}:{}", ip, port), 103 | IpAddr::V6(ip) => format!("[{}]:{}", ip, port), 104 | }, 105 | Address::Domain(ref s) => format!("{}:{}", s, port), 106 | }; 107 | 108 | if let Some(user_pass_auth) = user_pass_auth { 109 | let auth = format!( 110 | "{username}:{password}", 111 | username = user_pass_auth.username, 112 | password = user_pass_auth.password 113 | ); 114 | let basic_auth = BASE64_STANDARD.encode(auth); 115 | format!( 116 | "CONNECT {host} HTTP/1.1\r\n\ 117 | Host: {host}\r\n\ 118 | Proxy-Authorization: Basic {basic_auth}\r\n\r\n", 119 | host = host, 120 | basic_auth = basic_auth, 121 | ) 122 | } else { 123 | format!( 124 | "CONNECT {host} HTTP/1.1\r\n\ 125 | Host: {host}\r\n\r\n", 126 | host = host 127 | ) 128 | } 129 | } 130 | -------------------------------------------------------------------------------- /src/proxy/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod copy; 2 | pub mod http; 3 | use flexstr::{shared_fmt, SharedStr}; 4 | #[cfg(feature = "score_script")] 5 | use rlua::prelude::*; 6 | pub mod socks5; 7 | use parking_lot::{Mutex, RwLock}; 8 | use serde::{Serialize, Serializer}; 9 | use serde_with::{serde_as, DisplayFromStr}; 10 | use std::{ 11 | cmp, fmt, 12 | hash::{Hash, Hasher}, 13 | io, 14 | net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}, 15 | ops::{Add, AddAssign}, 16 | str::FromStr, 17 | sync::atomic::{AtomicUsize, Ordering}, 18 | time::Duration, 19 | }; 20 | use tokio::net::TcpStream; 21 | use tracing::{debug, instrument}; 22 | 23 | use 
crate::policy::capabilities::CapSet; 24 | 25 | const GRAPHITE_PATH_PREFIX: &str = "moproxy.proxy_servers"; 26 | 27 | #[derive(Hash, Eq, PartialEq, Clone, Debug, Serialize)] 28 | pub enum ProxyProto { 29 | #[serde(rename = "SOCKSv5")] 30 | Socks5 { 31 | /// Don't wait for the SOCKSv5 handshake replies: send all bytes the 32 | /// protocol requires at once, followed by the actual request. 33 | /// This saves round-trip delay (TODO: quantify) but may cause problems 34 | /// on some servers. 35 | fake_handshaking: bool, 36 | user_pass_auth: Option, 37 | }, 38 | #[serde(rename = "HTTP")] 39 | Http { 40 | /// Allow sending app-level data as payload of the CONNECT request. 41 | /// This can eliminate 1 round-trip delay but may cause problems on 42 | /// some servers. 43 | /// 44 | /// RFC 7231: 45 | /// >> A payload within a CONNECT request message has no defined 46 | /// >> semantics; sending a payload body on a CONNECT request might 47 | /// >> cause some existing implementations to reject the request. 48 | connect_with_payload: bool, 49 | user_pass_auth: Option, 50 | }, 51 | Direct, 52 | } 53 | 54 | #[derive(Hash, Eq, PartialEq, Clone, Debug, Serialize)] 55 | pub struct UserPassAuthCredential { 56 | username: SharedStr, 57 | #[serde(skip_serializing)] // never expose the password when serializing 58 | password: SharedStr, 59 | } 60 | 61 | impl UserPassAuthCredential { 62 | pub fn new>(username: T, password: T) -> Self { 63 | Self { 64 | username: username.as_ref().into(), 65 | password: password.as_ref().into(), 66 | } 67 | } 68 | } 69 | 70 | #[allow(clippy::mutable_key_type)] // NOTE(review): presumably for the interior-mutable RwLock/Mutex fields; Hash/Eq use only addr, proto and tag 71 | #[derive(Debug, Serialize)] 72 | pub struct ProxyServer { 73 | pub addr: SocketAddr, 74 | pub proto: ProxyProto, 75 | pub tag: SharedStr, 76 | config: RwLock, 77 | status: Mutex, 78 | traffic: AtomicTraffic, 79 | } 80 | 81 | #[derive(Debug, Serialize, Clone)] 82 | pub struct ProxyServerConfig { 83 | pub test_dns: SocketAddr, 84 | pub max_wait: Duration, 85 | pub capabilities: CapSet, 86 | score_base: i32, 87 | } 88 | 89 | #[cfg(feature =
"score_script")] 90 | impl ToLua<'_> for ProxyServerConfig { 91 | fn to_lua(self, ctx: LuaContext<'_>) -> LuaResult> { 92 | let table = ctx.create_table()?; 93 | table.set("test_dns", self.test_dns.to_string())?; 94 | table.set("max_wait", self.max_wait.as_secs_f32())?; 95 | table.set("score_base", self.score_base)?; 96 | table.to_lua(ctx) 97 | } 98 | } 99 | 100 | #[derive(Debug, Serialize, Clone, Copy, Default)] 101 | pub enum Delay { // Unknown until measured; a missing measurement (None) converts to TimedOut 102 | #[default] 103 | Unknown, 104 | Some(Duration), 105 | TimedOut, 106 | } 107 | 108 | impl Delay { 109 | pub fn map(self, func: F) -> Option // applies func only to a measured delay 110 | where 111 | F: FnOnce(Duration) -> T, 112 | { 113 | if let Delay::Some(d) = self { 114 | Some(func(d)) 115 | } else { 116 | None 117 | } 118 | } 119 | } 120 | 121 | impl From> for Delay { 122 | fn from(d: Option) -> Self { 123 | d.map(Self::Some).unwrap_or(Self::TimedOut) 124 | } 125 | } 126 | 127 | #[cfg(feature = "score_script")] 128 | impl ToLua<'_> for Delay { 129 | fn to_lua(self, ctx: LuaContext<'_>) -> LuaResult> { 130 | match self { 131 | Delay::Some(d) => Some(d.as_secs_f32()), 132 | Delay::TimedOut => Some(-1f32), // Lua gets seconds, -1 for a timeout, None when unknown 133 | Delay::Unknown => None, 134 | } 135 | .to_lua(ctx) 136 | } 137 | } 138 | 139 | #[serde_as] 140 | #[derive(Debug, Serialize, Clone, Copy, Default)] 141 | pub struct ProxyServerStatus { 142 | pub delay: Delay, 143 | pub score: Option, 144 | pub conn_alive: u32, 145 | pub conn_total: u32, 146 | pub conn_error: u32, 147 | #[serde_as(as = "DisplayFromStr")] 148 | pub close_history: u64, // serialized as a string via DisplayFromStr 149 | } 150 | 151 | #[cfg(feature = "score_script")] 152 | impl ToLua<'_> for ProxyServerStatus { 153 | fn to_lua(self, ctx: LuaContext<'_>) -> LuaResult> { 154 | let status = ctx.create_table()?; 155 | status.set("delay", self.delay)?; 156 | status.set("score", self.score)?; 157 | status.set("conn_alive", self.conn_alive)?; 158 | status.set("conn_total", self.conn_total)?; 159 | status.set("conn_error", self.conn_error)?; 160 | status.set("close_history", self.close_history)?; 161
| status.to_lua(ctx) 162 | } 163 | } 164 | 165 | impl Hash for ProxyServer { // identity is (addr, proto, tag), consistent with PartialEq; config, status and traffic are ignored 166 | fn hash(&self, state: &mut H) { 167 | self.addr.hash(state); 168 | self.proto.hash(state); 169 | self.tag.hash(state); 170 | } 171 | } 172 | 173 | impl PartialEq for ProxyServer { 174 | fn eq(&self, other: &ProxyServer) -> bool { 175 | self.addr == other.addr && self.proto == other.proto && self.tag == other.tag 176 | } 177 | } 178 | 179 | impl Eq for ProxyServer {} 180 | 181 | #[cfg(feature = "score_script")] 182 | impl ToLua<'_> for &ProxyServer { 183 | fn to_lua(self, ctx: LuaContext<'_>) -> LuaResult> { 184 | let table = ctx.create_table()?; 185 | table.set("addr", self.addr.to_string())?; 186 | table.set("proto", self.proto.to_string())?; 187 | table.set("tag", self.tag.to_string())?; 188 | table.set("config", self.config.read().clone())?; 189 | table.set("status", *self.status.lock())?; 190 | table.set("traffic", self.traffic())?; 191 | table.to_lua(ctx) 192 | } 193 | } 194 | 195 | #[derive(Hash, Clone)] 196 | pub enum Address { 197 | Ip(IpAddr), 198 | Domain(SharedStr), 199 | } 200 | 201 | impl Address { 202 | pub fn is_domain(&self) -> bool { 203 | matches!(self, Address::Domain(_)) 204 | } 205 | 206 | pub fn domain(&self) -> Option { // copy of the domain name; None for an IP address 207 | match self { 208 | Self::Ip(_) => None, 209 | Self::Domain(name) => Some(SharedStr::from(name.as_ref())), 210 | } 211 | } 212 | } 213 | 214 | impl fmt::Debug for Address { // IPv6 is bracketed so an appended ':port' stays unambiguous 215 | fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { 216 | match self { 217 | Address::Domain(name) => write!(f, "{}", name), 218 | Address::Ip(IpAddr::V4(addr)) => write!(f, "{}", addr), 219 | Address::Ip(IpAddr::V6(addr)) => write!(f, "[{}]", addr), 220 | } 221 | } 222 | } 223 | 224 | impl From<[u8; 4]> for Address { 225 | fn from(buf: [u8; 4]) -> Self { 226 | Address::Ip(IpAddr::V4(Ipv4Addr::from(buf))) 227 | } 228 | } 229 | 230 | impl From<[u8; 16]> for Address { 231 | fn from(buf: [u8; 16]) -> Self { 232 | Address::Ip(IpAddr::V6(Ipv6Addr::from(buf))) 233 | }
234 | } 235 | 236 | #[derive(Clone)] 237 | pub struct Destination { 238 | pub host: Address, 239 | pub port: u16, 240 | } 241 | 242 | impl fmt::Debug for Destination { 243 | fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { 244 | write!(f, "{}:{}", self.host, self.port) // NOTE(review): `{}` needs a Display impl for Address, which is not in this chunk 245 | } 246 | } 247 | 248 | impl From for Destination { 249 | fn from(addr: SocketAddr) -> Self { 250 | Destination { 251 | host: Address::Ip(addr.ip()), 252 | port: addr.port(), 253 | } 254 | } 255 | } 256 | 257 | impl<'a> From<(&'a str, u16)> for Destination { 258 | fn from(addr: (&'a str, u16)) -> Self { 259 | Destination { 260 | host: Address::Domain(addr.0.into()), 261 | port: addr.1, 262 | } 263 | } 264 | } 265 | 266 | impl From<(Address, u16)> for Destination { 267 | fn from(addr_port: (Address, u16)) -> Self { 268 | Destination { 269 | host: addr_port.0, 270 | port: addr_port.1, 271 | } 272 | } 273 | } 274 | 275 | #[derive(Debug)] 276 | pub struct AtomicTraffic { 277 | tx_bytes: AtomicUsize, 278 | rx_bytes: AtomicUsize, 279 | } 280 | 281 | impl Default for AtomicTraffic { 282 | fn default() -> Self { 283 | Self { 284 | tx_bytes: AtomicUsize::new(0), 285 | rx_bytes: AtomicUsize::new(0), 286 | } 287 | } 288 | } 289 | 290 | impl AtomicTraffic { 291 | pub fn add(&self, amt: Traffic) { // Relaxed: the two counters are updated independently, so a concurrent read() may observe one update without the other 292 | self.rx_bytes.fetch_add(amt.rx_bytes, Ordering::Relaxed); 293 | self.tx_bytes.fetch_add(amt.tx_bytes, Ordering::Relaxed); 294 | } 295 | 296 | pub fn read(&self) -> Traffic { 297 | Traffic { 298 | tx_bytes: self.tx_bytes.load(Ordering::Relaxed), 299 | rx_bytes: self.rx_bytes.load(Ordering::Relaxed), 300 | } 301 | } 302 | } 303 | 304 | impl Serialize for AtomicTraffic { 305 | fn serialize(&self, serializer: S) -> Result { 306 | self.read().serialize(serializer) 307 | } 308 | } 309 | 310 | #[derive(Copy, Clone, Debug, Default, Eq, PartialEq, Serialize)] 311 | pub struct Traffic { 312 | pub tx_bytes: usize, 313 | pub rx_bytes: usize, 314 | } 315 | 316 | impl From<(usize, usize)> for Traffic { // the tuple is (tx_bytes, rx_bytes) 317 | fn
from(tx_rx_bytes: (usize, usize)) -> Self { 318 | Self { 319 | tx_bytes: tx_rx_bytes.0, 320 | rx_bytes: tx_rx_bytes.1, 321 | } 322 | } 323 | } 324 | 325 | impl Add for Traffic { 326 | type Output = Traffic; 327 | 328 | fn add(self, other: Traffic) -> Traffic { 329 | Traffic { 330 | tx_bytes: self.tx_bytes + other.tx_bytes, 331 | rx_bytes: self.rx_bytes + other.rx_bytes, 332 | } 333 | } 334 | } 335 | 336 | impl AddAssign for Traffic { 337 | fn add_assign(&mut self, other: Traffic) { 338 | *self = *self + other; 339 | } 340 | } 341 | 342 | #[cfg(feature = "score_script")] 343 | impl ToLua<'_> for Traffic { 344 | fn to_lua(self, ctx: LuaContext<'_>) -> LuaResult> { 345 | let table = ctx.create_table()?; 346 | table.set("tx_bytes", self.tx_bytes)?; 347 | table.set("rx_bytes", self.rx_bytes)?; 348 | table.to_lua(ctx) 349 | } 350 | } 351 | 352 | impl ProxyProto { 353 | pub fn socks5(fake_handshaking: bool) -> Self { 354 | ProxyProto::Socks5 { 355 | fake_handshaking, 356 | user_pass_auth: None, 357 | } 358 | } 359 | 360 | pub fn socks5_with_auth(credential: UserPassAuthCredential) -> Self { // fake handshaking is always off when credentials are used 361 | ProxyProto::Socks5 { 362 | fake_handshaking: false, 363 | user_pass_auth: Some(credential), 364 | } 365 | } 366 | 367 | pub fn http(connect_with_payload: bool, credential: Option) -> Self { 368 | ProxyProto::Http { 369 | connect_with_payload, 370 | user_pass_auth: credential, 371 | } 372 | } 373 | } 374 | 375 | impl ProxyServerConfig { 376 | fn new( 377 | test_dns: SocketAddr, 378 | score_base: Option, 379 | capabilities: Option, 380 | max_wait: Duration, 381 | ) -> Self { 382 | Self { 383 | test_dns, 384 | max_wait, 385 | capabilities: capabilities.unwrap_or_default(), 386 | score_base: score_base.unwrap_or(0), 387 | } 388 | } 389 | } 390 | 391 | impl ProxyServer { 392 | pub fn new( 393 | addr: SocketAddr, 394 | proto: ProxyProto, 395 | test_dns: SocketAddr, 396 | max_wait: Duration, 397 | capabilities: Option, 398 | tag: Option<&str>, 399 | score_base: Option, 400 | ) ->
ProxyServer { 401 | ProxyServer { 402 | addr, 403 | proto, 404 | tag: match tag { 405 | None => shared_fmt!("{}", addr.port()), 406 | Some(s) => { 407 | if !s.is_ascii() || s.contains(' ') || s.contains('\n') { 408 | panic!( 409 | "Tag \"{}\" contains white spaces, line \ 410 | breaks, or non-ASCII characters.", 411 | s 412 | ); 413 | } 414 | SharedStr::from(s) 415 | } 416 | }, 417 | config: ProxyServerConfig::new(test_dns, score_base, capabilities, max_wait).into(), 418 | status: Default::default(), 419 | traffic: Default::default(), 420 | } 421 | } 422 | 423 | pub fn direct(max_wait: Duration) -> Self { 424 | let stub_addr = "0.0.0.0:0".parse().unwrap(); 425 | Self { 426 | addr: stub_addr, 427 | proto: ProxyProto::Direct, 428 | tag: "__DIRECT__".into(), 429 | config: ProxyServerConfig::new(stub_addr, None, None, max_wait).into(), 430 | status: Default::default(), 431 | traffic: Default::default(), 432 | } 433 | } 434 | 435 | pub fn copy_config_from(&self, from: &Self) { 436 | if !std::ptr::eq(&from.config, &self.config) { 437 | *self.config.write() = from.config.read().clone(); 438 | } 439 | } 440 | 441 | #[instrument(skip_all)] 442 | pub async fn connect(&self, addr: &Destination, data: Option) -> io::Result 443 | where 444 | T: AsRef<[u8]> + 'static, 445 | { 446 | let mut stream = TcpStream::connect(&self.addr).await?; 447 | debug!(remote = %stream.peer_addr()?, "TCP established"); 448 | stream.set_nodelay(true)?; 449 | 450 | match &self.proto { 451 | ProxyProto::Direct => unimplemented!(), 452 | ProxyProto::Socks5 { 453 | fake_handshaking, 454 | user_pass_auth, 455 | } => { 456 | socks5::handshake(&mut stream, addr, data, *fake_handshaking, user_pass_auth) 457 | .await? 458 | } 459 | ProxyProto::Http { 460 | connect_with_payload, 461 | user_pass_auth, 462 | } => { 463 | http::handshake( 464 | &mut stream, 465 | addr, 466 | data, 467 | *connect_with_payload, 468 | user_pass_auth, 469 | ) 470 | .await? 
471 | } 472 | } 473 | Ok(stream) 474 | } 475 | 476 | pub fn status_snapshot(&self) -> ProxyServerStatus { 477 | *self.status.lock() 478 | } 479 | 480 | pub fn score(&self) -> Option { 481 | self.status.lock().score 482 | } 483 | 484 | pub fn traffic(&self) -> Traffic { 485 | self.traffic.read() 486 | } 487 | 488 | pub fn max_wait(&self) -> Duration { 489 | self.config.read().max_wait 490 | } 491 | 492 | pub fn test_dns(&self) -> SocketAddr { 493 | self.config.read().test_dns 494 | } 495 | 496 | pub fn update_delay(&self, delay: Option) { 497 | let mut status = self.status.lock(); 498 | let config = self.config.read(); 499 | 500 | if let Some(delay) = delay { 501 | let last_score = status.score.unwrap_or_else(|| { 502 | match status.delay { 503 | Delay::Some(d) => d, 504 | Delay::Unknown => delay, 505 | Delay::TimedOut => config.max_wait, 506 | } 507 | .as_millis() as i32 508 | + config.score_base 509 | }); 510 | let err_rate = status 511 | .recent_error_rate(16) 512 | .min(status.recent_error_rate(64)); 513 | 514 | let score = delay.as_millis() as i32 + config.score_base; 515 | // give penalty for continuous errors 516 | let score = score + (score as f32 * err_rate * 10f32).round() as i32; 517 | // moving average on score 518 | // give more weight to delays exceed the mean for network jitter penalty 519 | let score = if score < last_score { 520 | (last_score * 9 + score) / 10 521 | } else { 522 | (last_score * 8 + score * 2) / 10 523 | }; 524 | status.score = Some(score); 525 | status.delay = Delay::Some(delay); 526 | 527 | // Shift error history 528 | // This give the server with high error penalty a chance to recovery. 
529 | status.close_history <<= 1; 530 | } else { 531 | // Timed out 532 | status.delay = Delay::TimedOut; 533 | status.score = None; 534 | }; 535 | } 536 | 537 | #[cfg(feature = "score_script")] 538 | pub fn update_delay_with_lua(&self, delay: Option, ctx: LuaContext) -> LuaResult<()> { 539 | let func: LuaFunction = ctx.globals().get("calc_score")?; 540 | let delay_secs = delay.map(|t| t.as_secs_f32()); 541 | let score: Option = func.call((self, delay_secs))?; 542 | 543 | let mut status = self.status.lock(); 544 | status.score = score; 545 | status.delay = delay.into(); 546 | Ok(()) 547 | } 548 | 549 | pub fn add_traffic(&self, traffic: Traffic) { 550 | self.traffic.add(traffic); 551 | } 552 | 553 | pub fn update_stats_conn_open(&self) { 554 | let mut status = self.status.lock(); 555 | status.conn_alive += 1; 556 | status.conn_total += 1; 557 | } 558 | 559 | pub fn update_stats_conn_close(&self, has_error: bool) { 560 | let mut status = self.status.lock(); 561 | status.conn_alive -= 1; 562 | status.close_history <<= 1; 563 | if has_error { 564 | status.conn_error += 1; 565 | status.close_history += 1; 566 | } 567 | } 568 | 569 | pub fn graphite_path(&self, suffix: &str) -> String { 570 | format!( 571 | "{}.{}.{}", 572 | GRAPHITE_PATH_PREFIX, 573 | self.tag.replace('.', "_"), 574 | suffix 575 | ) 576 | } 577 | 578 | pub fn capable_anyof(&self, caps: &CapSet) -> bool { 579 | self.config.read().capabilities.has_intersection(caps) 580 | } 581 | } 582 | 583 | impl ProxyServerStatus { 584 | pub fn recent_error_count(&self, n: u8) -> u8 { 585 | let n = 64 - cmp::min(n, 64); 586 | (self.close_history << n).count_ones() as u8 587 | } 588 | 589 | pub fn recent_error_rate(&self, n: u8) -> f32 { 590 | self.recent_error_count(n) as f32 / (cmp::min(n, 64) as f32) 591 | } 592 | } 593 | 594 | impl fmt::Display for ProxyServer { 595 | fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { 596 | if self.proto == ProxyProto::Direct { 597 | f.write_str("DIRECT") 598 | } else { 599 | 
write!(f, "{} ({} {})", self.tag, self.proto, self.addr) 600 | } 601 | } 602 | } 603 | 604 | impl fmt::Display for ProxyProto { 605 | fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { 606 | match *self { 607 | ProxyProto::Socks5 { .. } => write!(f, "SOCKSv5"), 608 | ProxyProto::Http { .. } => write!(f, "HTTP"), 609 | ProxyProto::Direct { .. } => write!(f, "DIRECT"), 610 | } 611 | } 612 | } 613 | 614 | impl fmt::Display for Address { 615 | fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { 616 | match *self { 617 | Address::Ip(ref ip) => write!(f, "{}", ip), 618 | Address::Domain(ref s) => write!(f, "{}", s), 619 | } 620 | } 621 | } 622 | 623 | impl fmt::Display for Destination { 624 | fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { 625 | write!(f, "{}:{}", self.host, self.port) 626 | } 627 | } 628 | 629 | impl FromStr for ProxyProto { 630 | type Err = (); 631 | fn from_str(s: &str) -> Result { 632 | match s.to_lowercase().as_str() { 633 | // default to disable fake handshaking 634 | "socks5" | "socksv5" => Ok(ProxyProto::socks5(false)), 635 | // default to disable connect with payload 636 | "http" => Ok(ProxyProto::http(false, None)), 637 | _ => Err(()), 638 | } 639 | } 640 | } 641 | -------------------------------------------------------------------------------- /src/proxy/socks5.rs: -------------------------------------------------------------------------------- 1 | use crate::proxy::{Address, Destination}; 2 | use std::io::{self, ErrorKind}; 3 | use std::net::IpAddr; 4 | use tokio::{ 5 | io::{AsyncReadExt, AsyncWriteExt}, 6 | net::TcpStream, 7 | }; 8 | use tracing::{instrument, trace}; 9 | 10 | use super::UserPassAuthCredential; 11 | 12 | #[instrument(name = "socks5_handshake", skip_all)] 13 | pub async fn handshake( 14 | stream: &mut TcpStream, 15 | addr: &Destination, 16 | data: Option, 17 | fake_handshaking: bool, 18 | user_pass_auth: &Option, 19 | ) -> io::Result<()> 20 | where 21 | T: AsRef<[u8]>, 22 | { 23 | if fake_handshaking && 
user_pass_auth.is_none() { 24 | trace!("socks: do FAKE handshake w/ {:?}", addr); 25 | fake_handshake(stream, addr, data).await 26 | } else { 27 | trace!("socks: do FULL handshake w/ {:?}", addr); 28 | full_handshake(stream, addr, data, user_pass_auth).await 29 | } 30 | } 31 | 32 | pub async fn fake_handshake( 33 | stream: &mut TcpStream, 34 | addr: &Destination, 35 | data: Option, 36 | ) -> io::Result<()> 37 | where 38 | T: AsRef<[u8]>, 39 | { 40 | let mut buf = Vec::with_capacity(16); 41 | buf.extend_from_slice(&[5, 1, 0]); 42 | build_request(&mut buf, addr); 43 | stream.write_all(&buf).await?; 44 | if let Some(data) = data { 45 | stream.write_all(data.as_ref()).await?; 46 | } 47 | buf.resize(12, 0); 48 | stream.read_exact(&mut buf).await?; 49 | Ok(()) 50 | } 51 | 52 | macro_rules! err { 53 | ($msg:expr) => { 54 | return Err(io::Error::new(ErrorKind::Other, $msg)) 55 | }; 56 | } 57 | 58 | pub async fn full_handshake( 59 | stream: &mut TcpStream, 60 | addr: &Destination, 61 | data: Option, 62 | user_pass_auth: &Option, 63 | ) -> io::Result<()> 64 | where 65 | T: AsRef<[u8]>, 66 | { 67 | let mut buf = vec![]; 68 | if user_pass_auth.is_none() { 69 | // Send request w/ auth method 0x00 (no auth) 70 | buf.extend(&[0x05, 0x01, 0x00]) 71 | } else { 72 | // Or, include 0x02 (username/password auth) 73 | buf.extend(&[0x05, 0x02, 0x00, 0x02]) 74 | }; 75 | trace!("socks: write {:?}", buf); 76 | stream.write_all(&buf).await?; 77 | 78 | // Server select auth method 79 | let mut buf = vec![0; 2]; 80 | stream.read_exact(&mut buf).await?; 81 | trace!("socks: read {:?}", buf); 82 | match buf[..2] { 83 | // 0xff: no acceptable method 84 | [0x05, 0xff] => err!("auth required by socks server"), 85 | // 0x00: no auth required 86 | [0x05, 0x00] => (), 87 | // 0x02: username/password method 88 | [0x05, 0x02] => { 89 | if let Some(auth) = user_pass_auth { 90 | if auth.username.len() > 255 || auth.password.len() > 255 { 91 | panic!("SOCKSv5 username/password exceeds 255 bytes"); 92 | } 
93 | buf.clear(); 94 | buf.push(0x01); // version 95 | buf.push(auth.username.len() as u8); 96 | buf.extend(auth.username.as_bytes()); 97 | buf.push(auth.password.len() as u8); 98 | buf.extend(auth.password.as_bytes()); 99 | trace!("socks: write auth {:?}", buf); 100 | stream.write_all(&buf).await?; 101 | 102 | // Parse response 103 | buf.resize(2, 0); 104 | stream.read_exact(&mut buf).await?; 105 | trace!("socks: read {:?}", buf); 106 | if buf != [0x01, 0x00] { 107 | err!("auth rejected by SOCKSv5 server") 108 | } 109 | } else { 110 | err!("missing username/password required by socks server"); 111 | } 112 | } 113 | _ => err!("unrecognized reply from socks server"), 114 | } 115 | 116 | // Write the actual request 117 | buf.clear(); 118 | build_request(&mut buf, addr); 119 | trace!("socks: write request {:?}", buf); 120 | stream.write_all(&buf).await?; 121 | 122 | // Check server's reply 123 | buf.resize(10, 0); 124 | stream.read_exact(&mut buf).await?; 125 | trace!("socks: read reply {:?}", buf); 126 | if !buf.starts_with(&[0x05, 0x00]) { 127 | err!("socks server reply error"); 128 | } 129 | if buf[3] == 4 { 130 | // Consume truncted IPv6 address 131 | buf.resize(16 - 4, 0); 132 | stream.read_exact(&mut buf).await?; 133 | } 134 | 135 | // Write out payload if exist 136 | if let Some(data) = data { 137 | trace!("socks: write payload {:?}", data.as_ref()); 138 | stream.write_all(data.as_ref()).await?; 139 | } 140 | Ok(()) 141 | } 142 | 143 | fn build_request(buffer: &mut Vec, addr: &Destination) { 144 | buffer.extend_from_slice(&[5, 1, 0]); 145 | match addr.host { 146 | Address::Ip(ip) => match ip { 147 | IpAddr::V4(ip) => { 148 | buffer.push(0x01); 149 | buffer.extend_from_slice(&ip.octets()); 150 | } 151 | IpAddr::V6(ip) => { 152 | buffer.push(0x04); 153 | buffer.extend_from_slice(&ip.octets()); 154 | } 155 | }, 156 | Address::Domain(ref host) => { 157 | buffer.push(0x03); 158 | buffer.push(host.len() as u8); 159 | buffer.extend_from_slice(host.as_bytes()); 160 | } 
161 | }; 162 | buffer.push((addr.port >> 8) as u8); 163 | buffer.push(addr.port as u8); 164 | } 165 | -------------------------------------------------------------------------------- /src/server.rs: -------------------------------------------------------------------------------- 1 | use anyhow::{anyhow, bail, Context}; 2 | use futures_util::{stream, StreamExt}; 3 | use ini::Ini; 4 | use parking_lot::RwLock; 5 | use std::{ 6 | collections::HashSet, io, net::SocketAddr, net::ToSocketAddrs, path::PathBuf, sync::Arc, 7 | time::Duration, 8 | }; 9 | use tokio::net::{TcpListener, TcpStream}; 10 | use tracing::{error, info, instrument, warn}; 11 | 12 | use crate::{cli::CliArgs, FromOptionStr}; 13 | #[cfg(feature = "web_console")] 14 | use moproxy::web::WebServer; 15 | use moproxy::{ 16 | client::{FailedClient, NewClient}, 17 | futures_stream::TcpListenerStream, 18 | monitor::Monitor, 19 | policy::{parser, ActionType, Policy}, 20 | proxy::{ProxyProto, ProxyServer, UserPassAuthCredential}, 21 | web::WebServerListener, 22 | }; 23 | 24 | #[derive(Clone)] 25 | pub(crate) struct MoProxy { 26 | cli_args: Arc, 27 | server_list_config: Arc, 28 | pub(crate) monitor: Monitor, 29 | direct_server: Arc, 30 | pub(crate) policy: Arc>, 31 | #[cfg(feature = "web_console")] 32 | web_server: Option, 33 | } 34 | 35 | pub(crate) struct MoProxyListener { 36 | moproxy: MoProxy, 37 | listeners: Vec, 38 | #[cfg(feature = "web_console")] 39 | web_server: Option, 40 | } 41 | 42 | #[derive(Debug)] 43 | enum PolicyResult { 44 | Filtered(Vec>), 45 | Direct, 46 | Reject, 47 | } 48 | 49 | impl MoProxy { 50 | pub(crate) async fn new(args: CliArgs) -> anyhow::Result { 51 | // Load proxy server list 52 | let server_list_config = ServerListConfig::new(&args); 53 | let servers = server_list_config.load().context("fail to load servers")?; 54 | let direct_server = Arc::new(ProxyServer::direct(args.max_wait)); 55 | 56 | // Load policy 57 | let policy = { 58 | if let Some(ref path) = args.policy { 59 | let policy 
= Policy::load_from_file(path).context("cannot to load policy")?; 60 | Arc::new(RwLock::new(policy)) 61 | } else { 62 | Default::default() 63 | } 64 | }; 65 | 66 | // Setup proxy monitor 67 | let graphite = args.graphite; 68 | #[cfg(feature = "score_script")] 69 | let mut monitor = Monitor::new(servers, graphite); 70 | #[cfg(not(feature = "score_script"))] 71 | let monitor = Monitor::new(servers, graphite); 72 | #[cfg(feature = "score_script")] 73 | { 74 | if let Some(ref path) = args.score_script { 75 | monitor 76 | .load_score_script(path) 77 | .context("fail to load Lua script")?; 78 | } 79 | } 80 | 81 | // Setup web console 82 | #[cfg(feature = "web_console")] 83 | let web_server = if let Some(addr) = &args.web_bind { 84 | Some(WebServer::new(monitor.clone(), addr.into())?) 85 | } else { 86 | None 87 | }; 88 | 89 | // Launch monitor 90 | if args.probe_secs > 0 { 91 | tokio::spawn(monitor.clone().monitor_delay(args.probe_secs)); 92 | } 93 | 94 | Ok(Self { 95 | cli_args: Arc::new(args), 96 | server_list_config: Arc::new(server_list_config), 97 | direct_server, 98 | monitor, 99 | policy, 100 | #[cfg(feature = "web_console")] 101 | web_server, 102 | }) 103 | } 104 | 105 | pub(crate) fn reload(&self) -> anyhow::Result<()> { 106 | // Load proxy server list 107 | let servers = self.server_list_config.load()?; 108 | // Load policy 109 | let policy = match &self.cli_args.policy { 110 | Some(path) => Policy::load_from_file(path).context("cannot to load policy")?, 111 | _ => Default::default(), 112 | }; 113 | // TODO: reload lua script 114 | 115 | // Apply only if no error occur 116 | self.monitor.update_servers(servers); 117 | *self.policy.write() = policy; 118 | Ok(()) 119 | } 120 | 121 | pub(crate) async fn listen(&self) -> anyhow::Result { 122 | let ports: HashSet<_> = self.cli_args.port.iter().collect(); 123 | let mut listeners = Vec::with_capacity(ports.len()); 124 | for port in ports { 125 | let addr = SocketAddr::new(self.cli_args.host, *port); 126 | let listener 
= TcpListener::bind(&addr) 127 | .await 128 | .context("cannot bind to port")?; 129 | info!("listen on {}", addr); 130 | #[cfg(target_os = "linux")] 131 | if let Some(ref alg) = self.cli_args.cong_local { 132 | use moproxy::linux::tcp::TcpListenerExt; 133 | 134 | info!("set {} on {}", alg, addr); 135 | listener.set_congestion(alg).expect( 136 | "fail to set tcp congestion algorithm. \ 137 | check tcp_allowed_congestion_control?", 138 | ); 139 | } 140 | listeners.push(TcpListenerStream(listener)); 141 | } 142 | #[cfg(feature = "web_console")] 143 | let web_server = if let Some(web) = &self.web_server { 144 | Some(web.listen().await?) 145 | } else { 146 | None 147 | }; 148 | 149 | Ok(MoProxyListener { 150 | moproxy: self.clone(), 151 | listeners, 152 | #[cfg(feature = "web_console")] 153 | web_server, 154 | }) 155 | } 156 | 157 | fn apply_policy(&self, client: &NewClient) -> PolicyResult { 158 | let action = self.policy.read().matches(&client.features()); 159 | match action.action { 160 | ActionType::Reject => PolicyResult::Reject, 161 | ActionType::Direct => PolicyResult::Direct, 162 | ActionType::Require(caps) => { 163 | let servers = self 164 | .monitor 165 | .servers() 166 | .into_iter() 167 | .filter(|s| caps.iter().all(|c| s.capable_anyof(c))) 168 | .collect(); 169 | PolicyResult::Filtered(servers) 170 | } 171 | } 172 | } 173 | 174 | #[instrument(level = "error", skip_all, fields(on_port=sock.local_addr()?.port(), peer=?sock.peer_addr()?))] 175 | async fn handle_client(&self, sock: TcpStream) -> io::Result<()> { 176 | let mut client = NewClient::from_socket(sock).await?; 177 | let args = &self.cli_args; 178 | 179 | if (args.remote_dns || args.n_parallel > 1) && client.dest.port == 443 { 180 | // Try parse TLS client hello 181 | client.retrieve_dest_from_sni().await?; 182 | if args.remote_dns { 183 | client.override_dest_with_sni(); 184 | } 185 | } 186 | let result = match self.apply_policy(&client) { 187 | PolicyResult::Reject => { 188 | info!("rejected by 
policy"); 189 | return Ok(()); 190 | } 191 | PolicyResult::Direct => client 192 | .direct_connect(self.direct_server.clone()) 193 | .await 194 | .map_err(|err| err.into()), 195 | PolicyResult::Filtered(proxies) => { 196 | client.connect_server(proxies, args.n_parallel).await 197 | } 198 | }; 199 | let client = match result { 200 | Ok(client) => client, 201 | Err(FailedClient::Recoverable(client)) if args.allow_direct => { 202 | client.direct_connect(self.direct_server.clone()).await? 203 | } 204 | Err(_) => return Ok(()), 205 | }; 206 | client.serve().await 207 | } 208 | } 209 | 210 | impl MoProxyListener { 211 | pub(crate) async fn handle_forever(mut self) { 212 | #[cfg(feature = "web_console")] 213 | if let Some(web) = self.web_server { 214 | web.run_background() 215 | } 216 | 217 | let mut clients = stream::select_all(self.listeners.iter_mut()); 218 | while let Some(sock) = clients.next().await { 219 | let moproxy = self.moproxy.clone(); 220 | match sock { 221 | Ok(sock) => { 222 | tokio::spawn(async move { 223 | if let Err(e) = moproxy.handle_client(sock).await { 224 | info!("error on hanle client: {}", e); 225 | } 226 | }); 227 | } 228 | Err(err) => info!("error on accept client: {}", err), 229 | } 230 | } 231 | } 232 | } 233 | 234 | struct ServerListConfig { 235 | default_test_dns: SocketAddr, 236 | default_max_wait: Duration, 237 | cli_servers: Vec>, 238 | path: Option, 239 | allow_direct: bool, 240 | } 241 | 242 | impl ServerListConfig { 243 | fn new(args: &CliArgs) -> Self { 244 | let default_test_dns = args.test_dns; 245 | let default_max_wait = args.max_wait; 246 | 247 | let mut cli_servers = vec![]; 248 | for addr in &args.socks5_servers { 249 | cli_servers.push(Arc::new(ProxyServer::new( 250 | *addr, 251 | ProxyProto::socks5(false), 252 | default_test_dns, 253 | default_max_wait, 254 | None, 255 | None, 256 | None, 257 | ))); 258 | } 259 | 260 | for addr in &args.http_servers { 261 | cli_servers.push(Arc::new(ProxyServer::new( 262 | *addr, 263 | 
ProxyProto::http(false, None), 264 | default_test_dns, 265 | default_max_wait, 266 | None, 267 | None, 268 | None, 269 | ))); 270 | } 271 | 272 | let path = args.server_list.clone(); 273 | Self { 274 | default_test_dns, 275 | default_max_wait, 276 | cli_servers, 277 | path, 278 | allow_direct: args.allow_direct, 279 | } 280 | } 281 | 282 | #[instrument(skip_all)] 283 | fn load(&self) -> anyhow::Result>> { 284 | let mut servers = self.cli_servers.clone(); 285 | if let Some(path) = &self.path { 286 | let ini = Ini::load_from_file(path).context("cannot read server list file")?; 287 | for (section, props) in ini.iter() { 288 | if section.is_none() && props.is_empty() { 289 | // `rust-ini` always return empty general section on 0.19 & 0.20 290 | continue; 291 | } 292 | let server = self 293 | .load_proxy_from_ini_section(section, props) 294 | .with_context(|| { 295 | format!( 296 | "load [{}] from {}", 297 | section.unwrap_or(""), 298 | path.display() 299 | ) 300 | })?; 301 | servers.push(Arc::new(server)); 302 | } 303 | } 304 | if servers.is_empty() && !self.allow_direct { 305 | bail!("missing server list"); 306 | } 307 | info!("total {} server(s) loaded", servers.len()); 308 | Ok(servers) 309 | } 310 | 311 | fn load_proxy_from_ini_section( 312 | &self, 313 | section: Option<&str>, 314 | props: &ini::Properties, 315 | ) -> anyhow::Result { 316 | let tag = props.get("tag").or(section); 317 | let addr: SocketAddr = props 318 | .get("address") 319 | .ok_or(anyhow!("address not specified"))? 320 | .to_socket_addrs() 321 | .context("not a valid socket address")? 322 | .next() 323 | .unwrap(); 324 | let base = props 325 | .get("score base") 326 | .parse() 327 | .context("score base not a integer")?; 328 | let test_dns = props 329 | .get("test dns") 330 | .parse() 331 | .context("not a valid socket address")? 332 | .unwrap_or(self.default_test_dns); 333 | let max_wait = props 334 | .get("max wait") 335 | .parse() 336 | .context("not a valid number")? 
337 | .map(Duration::from_secs) 338 | .unwrap_or(self.default_max_wait); 339 | if props.get("listen ports").is_some() { 340 | // TODO: add a link to how-to --policy 341 | error!("`listen ports` is not longer supported, use --policy instead"); 342 | } 343 | let (_, capabilities) = parser::capabilities(props.get("capabilities").unwrap_or_default()) 344 | .map_err(|e| e.to_owned()) 345 | .context("not a valid list of capabilities")?; 346 | let proto = match props 347 | .get("protocol") 348 | .context("protocol not specified")? 349 | .to_lowercase() 350 | .as_str() 351 | { 352 | "socks5" | "socksv5" => { 353 | let fake_hs = props 354 | .get("socks fake handshaking") 355 | .parse() 356 | .context("not a boolean value")? 357 | .unwrap_or(false); 358 | let username = props.get("socks username").unwrap_or(""); 359 | let password = props.get("socks password").unwrap_or(""); 360 | match (username.len(), password.len()) { 361 | (0, 0) => ProxyProto::socks5(fake_hs), 362 | (0, _) | (_, 0) => bail!("socks username/password is empty"), 363 | (u, p) if u > 255 || p > 255 => { 364 | bail!("socks username/password too long") 365 | } 366 | _ => ProxyProto::socks5_with_auth(UserPassAuthCredential::new( 367 | username, password, 368 | )), 369 | } 370 | } 371 | "http" => { 372 | let cwp = props 373 | .get("http allow connect payload") 374 | .parse() 375 | .context("not a boolean value")? 
376 | .unwrap_or(false); 377 | let credential = match (props.get("http username"), props.get("http password")) { 378 | (None, None) => None, 379 | (Some(user), _) if user.contains(':') => { 380 | bail!("semicolon (:) in http username") 381 | } 382 | (user, pass) => Some(UserPassAuthCredential::new( 383 | user.unwrap_or(""), 384 | pass.unwrap_or(""), 385 | )), 386 | }; 387 | ProxyProto::http(cwp, credential) 388 | } 389 | _ => bail!("unknown proxy protocol"), 390 | }; 391 | Ok(ProxyServer::new( 392 | addr, 393 | proto, 394 | test_dns, 395 | max_wait, 396 | Some(capabilities), 397 | tag, 398 | base, 399 | )) 400 | } 401 | } 402 | -------------------------------------------------------------------------------- /src/web/helpers.rs: -------------------------------------------------------------------------------- 1 | use hyper::Request; 2 | use number_prefix::NumberPrefix::{self, Prefixed, Standalone}; 3 | use once_cell::sync::Lazy; 4 | use regex::Regex; 5 | use std::{fmt::Write, time::Duration}; 6 | 7 | pub trait RequestExt { 8 | fn accept_html(&self) -> bool; 9 | } 10 | 11 | impl RequestExt for Request { 12 | fn accept_html(&self) -> bool { 13 | if let Some(accpet) = self.headers().get("accpet").and_then(|v| v.to_str().ok()) { 14 | if accpet.starts_with("text/html") { 15 | return true; 16 | } else if accpet.starts_with("text/plain") { 17 | return false; 18 | } 19 | } 20 | 21 | static RE: Lazy = 22 | Lazy::new(|| Regex::new(r"(^(curl|Lynx)/|PowerShell/[\d\.]+$)").unwrap()); 23 | if let Some(ua) = self 24 | .headers() 25 | .get("user-agent") 26 | .and_then(|v| v.to_str().ok()) 27 | { 28 | !RE.is_match(ua) 29 | } else { 30 | true 31 | } 32 | } 33 | } 34 | 35 | pub trait DurationExt { 36 | fn format(&self) -> String; 37 | fn format_millis(&self) -> String; 38 | } 39 | 40 | impl DurationExt for Duration { 41 | fn format(&self) -> String { 42 | let secs = self.as_secs(); 43 | let d = secs / 86400; 44 | let h = (secs % 86400) / 3600; 45 | let m = (secs % 3600) / 60; 46 | let 
s = secs % 60; 47 | let mut buf = String::new(); 48 | vec![(d, 'd'), (h, 'h'), (m, 'm'), (s, 's')] 49 | .into_iter() 50 | .filter(|(v, _)| *v > 0) 51 | .take(2) 52 | .for_each(|(v, u)| { 53 | write!(&mut buf, "{}{}", v, u).unwrap(); 54 | }); 55 | buf 56 | } 57 | 58 | fn format_millis(&self) -> String { 59 | format!("{} ms", self.as_millis()) 60 | } 61 | } 62 | 63 | pub fn to_human_bytes(n: usize) -> String { 64 | if n == 0 { 65 | String::new() 66 | } else { 67 | match NumberPrefix::binary(n as f64) { 68 | Standalone(bytes) => format!("{} bytes", bytes), 69 | Prefixed(prefix, n) => format!("{:.1} {}B", n, prefix), 70 | } 71 | } 72 | } 73 | 74 | pub fn to_human_bps(n: usize) -> String { 75 | match NumberPrefix::decimal(n as f64) { 76 | Standalone(n) => format!("{} bps", n), 77 | Prefixed(prefix, n) => format!("{:.0} {}bps", n, prefix), 78 | } 79 | } 80 | 81 | pub fn to_human_bps_prefix_only(n: usize) -> String { 82 | match NumberPrefix::decimal(n as f64) { 83 | Standalone(n) => format!("{} ", n), 84 | Prefixed(prefix, n) => format!("{:.0}{}", n, prefix), 85 | } 86 | } 87 | -------------------------------------------------------------------------------- /src/web/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | moproxy 6 | 22 | 23 | 24 |

moproxy

25 |

moproxy is running. 26 | 27 |

28 | 29 | 30 | 31 | 32 |

Proxy servers

33 |

34 | Connections: 35 | 0 36 | Throughput: 37 | 38 | - 39 | 40 | - 41 |

42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 |
ServerScoreDelayCUR / TTLUp / Down
51 | 52 | 184 | 185 | 186 | -------------------------------------------------------------------------------- /src/web/mod.rs: -------------------------------------------------------------------------------- 1 | mod helpers; 2 | mod open_metrics; 3 | #[cfg(feature = "rich_web")] 4 | mod rich; 5 | use anyhow::Context; 6 | use bytes::Bytes; 7 | use flexstr::SharedStr; 8 | use helpers::{DurationExt, RequestExt}; 9 | use http_body_util::Full; 10 | use hyper::{ 11 | body::Incoming, server::conn::http1, service::service_fn, Method, Request, Response, StatusCode, 12 | }; 13 | use hyper_util::rt::TokioIo; 14 | #[cfg(feature = "rich_web")] 15 | use once_cell::sync::Lazy; 16 | use prettytable::{cell, format::consts::FORMAT_NO_LINESEP_WITH_TITLE, row, Table}; 17 | use serde_derive::Serialize; 18 | use std::{ 19 | fmt::Write, 20 | fs, io, 21 | net::SocketAddr, 22 | path::Path, 23 | sync::Arc, 24 | time::{Duration, Instant}, 25 | }; 26 | #[cfg(unix)] 27 | use tokio::net::{UnixListener, UnixStream}; 28 | use tokio::{ 29 | self, 30 | io::{AsyncRead, AsyncWrite}, 31 | net::{TcpListener, TcpStream}, 32 | }; 33 | use tracing::{info, instrument, warn}; 34 | 35 | use crate::{ 36 | monitor::{Monitor, Throughput}, 37 | proxy::{Delay, ProxyServer}, 38 | }; 39 | /* Embedded web UI bundle (rich_web feature only), loaded on first use. */ 40 | #[cfg(feature = "rich_web")] 41 | static BUNDLE: Lazy = Lazy::new(rich::ResourceBundle::new); 42 | /* One server and its current throughput, if monitor.throughputs() has an entry for it. */ 43 | #[derive(Debug, Serialize)] 44 | struct ServerStatus { 45 | server: Arc, 46 | throughput: Option, 47 | } 48 | /* Snapshot returned as JSON by the status endpoint and rendered by the plain-text and OpenMetrics pages. */ 49 | #[derive(Debug, Serialize)] 50 | struct Status { 51 | servers: Vec, 52 | uptime: Duration, 53 | throughput: Throughput, 54 | } 55 | /* Status::from pairs every server with its entry from monitor.throughputs(), sums all entries into the total throughput, and records the time elapsed since start_time. */ 56 | impl Status { 57 | fn from(start_time: &Instant, monitor: &Monitor) -> Self { 58 | let mut thps = monitor.throughputs(); 59 | let throughput = thps.values().fold(Default::default(), |a, b| a + *b); 60 | let servers = monitor 61 | .servers() 62 | .into_iter() 63 | .map(|server| ServerStatus { 64 | throughput: thps.remove(&server), 65 | server, 66 | }) 67 | .collect(); 68 |
Status { 69 | servers, 70 | throughput, 71 | uptime: start_time.elapsed(), 72 | } 73 | } 74 | } 75 | /* Response type shared by the HTTP handlers. */ 76 | type BytesResult = Result>, http::Error>; 77 | /* Clients accepting HTML get the console page (the rich_web bundle's index.html when compiled in and present, otherwise the built-in index.html); other clients get the plain-text status. */ 78 | fn home_page(req: &Request, start_time: &Instant, monitor: &Monitor) -> BytesResult { 79 | if req.accept_html() { 80 | #[cfg(feature = "rich_web")] 81 | let resp = BUNDLE.get("/index.html").map(|(mime, content)| { 82 | Response::builder() 83 | .header("Content-Type", mime) 84 | .body(content.into()) 85 | }); 86 | #[cfg(not(feature = "rich_web"))] 87 | let resp = None; 88 | resp.unwrap_or_else(|| { 89 | Response::builder() 90 | .header("Content-Type", "text/html") 91 | .body(include_str!("index.html").into()) 92 | }) 93 | } else { 94 | plaintext_status_response(start_time, monitor) 95 | } 96 | } 97 | /* Renders the text report: a version and uptime line, then a line with the total alive connections and overall throughput, followed by a table with one row per server. */ 98 | fn plaintext_status(start_time: &Instant, monitor: &Monitor) -> String { 99 | let status = Status::from(start_time, monitor); 100 | let mut buf = String::new(); 101 | 102 | writeln!( 103 | &mut buf, 104 | "moproxy ({}) is running. {}", 105 | env!("CARGO_PKG_VERSION"), 106 | status.uptime.format() 107 | ) 108 | .unwrap(); 109 | 110 | let mut table = Table::new(); 111 | table.add_row(row![ 112 | "Server", 113 | "Score", 114 | "Delay", 115 | "CUR", 116 | "TTL", 117 | "E16:64", 118 | "Up", 119 | "Down", 120 | "↑↓ bps" 121 | ]); 122 | table.set_format(*FORMAT_NO_LINESEP_WITH_TITLE); 123 | let mut total_alive_conns = 0; 124 | for ServerStatus { server, throughput } in status.servers { 125 | let status = server.status_snapshot(); 126 | let traffic = server.traffic(); 127 | total_alive_conns += status.conn_alive; 128 | let row = table.add_empty_row(); 129 | // Server 130 | row.add_cell(cell!(l -> server.tag)); 131 | // Score 132 | if let Some(v) = status.score { 133 | row.add_cell(cell!(r -> v)); 134 | } else { 135 | row.add_cell(cell!(r -> "-")); 136 | } 137 | // Delay 138 | if let Delay::Some(v) = status.delay { 139 | row.add_cell(cell!(r -> v.format_millis())); 140 | } else { 141 | row.add_cell(cell!(r -> "-"));
142 | } 143 | // CUR TTL 144 | row.add_cell(cell!(r -> status.conn_alive)); 145 | row.add_cell(cell!(r -> status.conn_total)); 146 | // Error rate 147 | // E16:64 column = recent_error_count(16) and recent_error_count(64), each zero-padded to 2 digits; presumably error counts over the last 16 and 64 connections, TODO confirm and document 148 | row.add_cell(cell!(r -> 149 | format!("{:02}:{:02}", 150 | status.recent_error_count( 16), 151 | status.recent_error_count(64), 152 | ) 153 | )); 154 | // Up Down 155 | row.add_cell(cell!(r -> helpers::to_human_bytes(traffic.tx_bytes))); 156 | row.add_cell(cell!(r -> helpers::to_human_bytes(traffic.rx_bytes))); 157 | // ↑↓ 158 | if let Some(tp) = throughput { 159 | let sum = tp.tx_bps + tp.rx_bps; 160 | if sum > 0 { 161 | row.add_cell(cell!(r -> helpers::to_human_bps_prefix_only(sum))); 162 | } 163 | } 164 | } 165 | 166 | writeln!( 167 | &mut buf, 168 | "[{}] ↑ {} ↓ {}\n{}", 169 | total_alive_conns, 170 | helpers::to_human_bps(status.throughput.tx_bps), 171 | helpers::to_human_bps(status.throughput.rx_bps), 172 | table 173 | ) 174 | .unwrap(); 175 | buf 176 | } 177 | /* The text report as a UTF-8 plain-text response. */ 178 | fn plaintext_status_response(start_time: &Instant, monitor: &Monitor) -> BytesResult { 179 | Response::builder() 180 | .header("Content-Type", "text/plain; charset=utf-8") 181 | .body(plaintext_status(start_time, monitor).into()) 182 | } 183 | /* Router: non-GET requests get 405 with an Allow: GET header; GET requests are dispatched on the URI path, and unknown paths are looked up in the rich_web bundle (when compiled in) before falling back to 404. */ 184 | fn response(req: &Request, start_time: Instant, monitor: Monitor) -> BytesResult { 185 | if req.method() != Method::GET { 186 | return Response::builder() 187 | .status(StatusCode::METHOD_NOT_ALLOWED) 188 | .header("Allow", "GET") 189 | .header("Content-Type", "text/plain") 190 | .body("only GET is allowed".into()); 191 | } 192 | 193 | match req.uri().path() { 194 | "/" | "/index.html" => home_page(req, &start_time, &monitor), 195 | "/plain" => plaintext_status_response(&start_time, &monitor), 196 | "/version" => Response::builder() 197 | .header("Content-Type", "text/plain") 198 | .body(env!("CARGO_PKG_VERSION").into()), 199 | "/status" => { 200 | let json = serde_json::to_string(&Status::from(&start_time, &monitor)) 201 | .expect("fail to serialize servers to
json"); 202 | Response::builder() 203 | .header("Content-Type", "application/json") 204 | .body(json.into()) 205 | } 206 | "/metrics" => open_metrics::exporter(&start_time, &monitor), 207 | path => { 208 | #[cfg(feature = "rich_web")] 209 | let resp = BUNDLE.get(path).map(|(mime, body)| { 210 | Response::builder() 211 | .header("Content-Type", mime) 212 | .body(body.into()) 213 | }); 214 | #[cfg(not(feature = "rich_web"))] 215 | let resp = None; 216 | resp.unwrap_or_else(|| { 217 | Response::builder() 218 | .status(StatusCode::NOT_FOUND) 219 | .header("Content-Type", "text/plain") 220 | .body("page not found".into()) 221 | }) 222 | } 223 | } 224 | } 225 | /* Where the web console listens: a TCP address, or (Unix only) a Unix domain socket path. */ 226 | #[derive(Debug, Clone)] 227 | enum ListenAddr { 228 | TcpSocket(SocketAddr), 229 | #[cfg(unix)] 230 | UnixPath(SharedStr), 231 | } 232 | /* A bound listener; the Unix variant owns an AutoRemoveFile so the socket path is removed when it is dropped. */ 233 | enum Listener { 234 | Tcp(TcpListener), 235 | #[cfg(unix)] 236 | Unix { 237 | listener: UnixListener, 238 | file: AutoRemoveFile, 239 | }, 240 | } 241 | /* Accepts the next connection from either listener type, discarding the peer address. */ 242 | trait Accept { 243 | async fn accept(&self) -> io::Result; 244 | } 245 | 246 | impl Accept for TcpListener { 247 | async fn accept(&self) -> io::Result { 248 | let (client, _) = self.accept().await?; 249 | Ok(client) 250 | } 251 | } 252 | 253 | #[cfg(unix)] 254 | impl Accept for UnixListener { 255 | async fn accept(&self) -> io::Result { 256 | let (client, _) = self.accept().await?; 257 | Ok(client) 258 | } 259 | } 260 | /* Web console configuration; listen() binds it. */ 261 | #[derive(Clone)] 262 | pub struct WebServer { 263 | monitor: Monitor, 264 | bind_addr: ListenAddr, 265 | } 266 | /* A bound web console; run_background() starts serving it. */ 267 | pub struct WebServerListener { 268 | monitor: Monitor, 269 | listener: Listener, 270 | } 271 | /* new(): on Unix targets, a bind address starting with a slash is a Unix socket path; anything else is parsed as a TCP socket address. NOTE(review): on non-Unix targets the TCP branch is always taken, so the bail! branch is unreachable and a slash-prefixed path reports the TCP parse error instead. */ 272 | impl WebServer { 273 | pub fn new(monitor: Monitor, bind_addr: SharedStr) -> anyhow::Result { 274 | let bind_addr = if !bind_addr.starts_with('/') || cfg!(not(unix)) { 275 | // TCP socket 276 | let addr = str::parse(bind_addr.as_str()) 277 | .context("Not valid TCP socket address for web server")?; 278 | ListenAddr::TcpSocket(addr) 279 | } else { 280 | #[cfg(unix)] 281 | {
282 | ListenAddr::UnixPath(bind_addr) 283 | } 284 | #[cfg(not(unix))] 285 | anyhow::bail!("No UNIX domain socket support on this system") 286 | }; 287 | Ok(Self { monitor, bind_addr }) 288 | } 289 | /* listen() binds the configured address. NOTE(review): for a Unix path the AutoRemoveFile guard is created before bind. If bind fails (e.g. the path already exists), the early return drops the guard and deletes whatever is at that path, possibly a socket still in use. */ 290 | pub async fn listen(&self) -> anyhow::Result { 291 | let listener = match &self.bind_addr { 292 | ListenAddr::TcpSocket(addr) => { 293 | info!("Web console listen on tcp:{}", addr); 294 | let listener = TcpListener::bind(&addr) 295 | .await 296 | .context("fail to bind web server")?; 297 | Listener::Tcp(listener) 298 | } 299 | #[cfg(unix)] 300 | ListenAddr::UnixPath(addr) => { 301 | info!("Web console listen on unix:{}", addr); 302 | let file = AutoRemoveFile(addr.clone()); 303 | let listener = UnixListener::bind(&file).context("fail to bind web server")?; 304 | Listener::Unix { listener, file } 305 | } 306 | }; 307 | Ok(WebServerListener { 308 | monitor: self.monitor.clone(), 309 | listener, 310 | }) 311 | } 312 | } 313 | /* run_background() spawns the accept loop. For a Unix socket, the AutoRemoveFile is held until run_server returns and is then dropped, which removes the socket file. */ 314 | impl WebServerListener { 315 | pub fn run_background(self) { 316 | match self.listener { 317 | Listener::Tcp(tcp) => { 318 | tokio::spawn(run_server(tcp, self.monitor)); 319 | } 320 | #[cfg(unix)] 321 | Listener::Unix { listener, file } => { 322 | tokio::spawn(async move { 323 | run_server(listener, self.monitor).await; 324 | drop(file); 325 | }); 326 | } 327 | } 328 | } 329 | } 330 | /* Accept loop: spawns the throughput monitor task, then serves each accepted connection over HTTP/1 on its own task; all requests share one start_time. NOTE(review): any accept() error breaks the loop and stops the console permanently, including errors that may be transient (e.g. running out of file descriptors). */ 331 | #[instrument(name = "web_server", skip_all)] 332 | async fn run_server(listener: L, monitor: Monitor) 333 | where 334 | L: Accept + Unpin, 335 | IO: AsyncRead + AsyncWrite + Unpin + Send + 'static, 336 | { 337 | tokio::spawn(monitor.clone().monitor_throughput()); 338 | let start_time = Instant::now(); 339 | 340 | loop { 341 | let stream = match listener.accept().await { 342 | Ok(stream) => stream, 343 | Err(err) => { 344 | warn!("failed to accept: {}", err); 345 | break; 346 | } 347 | }; 348 | let monitor = monitor.clone(); 349 | let service = service_fn(move |req: Request| { 350 | let monitor = monitor.clone(); 351 | async move { response(&req,
start_time, monitor) } 352 | }); 353 | 354 | tokio::spawn(async move { 355 | let conn = http1::Builder::new().serve_connection(TokioIo::new(stream), service); 356 | if let Err(e) = conn.await { 357 | warn!("web server error: {}", e); 358 | } 359 | }); 360 | } 361 | 362 | warn!("web server stopped"); 363 | } 364 | 365 | /// File on this path will be removed on `drop()`; a failed removal is only logged, not propagated. 366 | struct AutoRemoveFile(SharedStr); 367 | 368 | impl Drop for AutoRemoveFile { 369 | fn drop(&mut self) { 370 | if let Err(err) = fs::remove_file(self.0.as_str()) { 371 | warn!("fail to remove {}: {}", self.0, err); 372 | } 373 | } 374 | } 375 | 376 | impl AsRef for AutoRemoveFile { 377 | fn as_ref(&self) -> &Path { 378 | self.0.as_str().as_ref() 379 | } 380 | } 381 | -------------------------------------------------------------------------------- /src/web/open_metrics.rs: -------------------------------------------------------------------------------- 1 | use hyper::Response; 2 | use std::{ 3 | fmt::{Display, Write}, 4 | time::Instant, 5 | }; 6 | 7 | use super::{BytesResult, ServerStatus, Status}; 8 | use crate::{monitor::Monitor, proxy::Delay}; 9 | 10 | const CONTENT_TYPE: &str = "application/openmetrics-text; version=1.0.0; charset=utf-8"; 11 | /* Writes the # HELP and # TYPE header lines for the metric family moproxy_{name}. */ 12 | fn new_metric(buf: &mut String, name: &str, metric_type: &str, help: &str) { 13 | writeln!(buf, "# HELP moproxy_{} {}", name, help).unwrap(); 14 | writeln!(buf, "# TYPE moproxy_{} {}", name, metric_type).unwrap(); 15 | } 16 | /* Writes one moproxy_{name} sample, labelled with the server tag, for every server where metric returns Some. NOTE(review): the tag is written into the label value unescaped, so a tag containing a double quote, backslash or newline would produce invalid exposition text. */ 17 | fn each_server(buf: &mut String, name: &str, servers: &[ServerStatus], metric: F) 18 | where 19 | F: Fn(&ServerStatus) -> Option, 20 | D: Display, 21 | { 22 | for s in servers { 23 | if let Some(value) = metric(s) { 24 | writeln!( 25 | buf, 26 | "moproxy_{}{{server=\"{}\"}} {}", 27 | name, s.server.tag, value 28 | ) 29 | .unwrap(); 30 | } 31 | } 32 | } 33 | /* Metrics page handler: per-server gauges in OpenMetrics text format, terminated by # EOF. The DNS delay is reported in seconds, truncated to millisecond precision via subsec_millis. NOTE(review): the families ending in _total are typed gauge, while OpenMetrics associates the _total suffix with counters; confirm that strict parsers accept this. */ 34 | pub fn exporter(start_time: &Instant, monitor: &Monitor) -> BytesResult { 35 | let status = Status::from(start_time, monitor); 36 | let mut buf = String::new();
37 | /* server_gauge! writes one gauge family header via new_metric, then one sample per server via each_server. */ 38 | macro_rules! server_gauge { 39 | ($name:expr, $help:expr, $func:expr) => { 40 | new_metric(&mut buf, $name, "gauge", $help); 41 | each_server(&mut buf, $name, &status.servers, $func); 42 | }; 43 | } 44 | 45 | server_gauge!( 46 | "proxy_server_bytes_tx_total", 47 | "Current total of outgoing bytes", 48 | |s| Some(s.server.traffic().tx_bytes) 49 | ); 50 | server_gauge!( 51 | "proxy_server_bytes_rx_total", 52 | "Current total of incoming bytes", 53 | |s| Some(s.server.traffic().rx_bytes) 54 | ); 55 | server_gauge!( 56 | "proxy_server_connections_alive", 57 | "Current number of alive connections", 58 | |s| Some(s.server.status_snapshot().conn_alive) 59 | ); 60 | server_gauge!( 61 | "proxy_server_connections_error", 62 | "Current number of connections closed with error", 63 | |s| Some(s.server.status_snapshot().conn_error) 64 | ); 65 | server_gauge!( 66 | "proxy_server_connections_total", 67 | "Current total number of connections", 68 | |s| Some(s.server.status_snapshot().conn_total) 69 | ); 70 | server_gauge!( 71 | "proxy_server_dns_delay_seconds", 72 | "Total seconds for the last DNS query test", 73 | |s| match s.server.status_snapshot().delay { 74 | Delay::Some(d) => Some(d.as_secs() as f32 + d.subsec_millis() as f32 / 1000.0), 75 | _ => None, 76 | } 77 | ); 78 | server_gauge!( 79 | "proxy_server_score", 80 | "Score of server based on the last DNS query test", 81 | |s| s.server.status_snapshot().score 82 | ); 83 | 84 | writeln!(buf, "# EOF").unwrap(); 85 | Response::builder() 86 | .header("Content-Type", CONTENT_TYPE) 87 | .body(buf.into()) 88 | } 89 | -------------------------------------------------------------------------------- /src/web/rich.rs: -------------------------------------------------------------------------------- 1 | use bytes::Bytes; 2 | use parking_lot::Mutex; 3 | use std::io::{Cursor, Read}; 4 | use zip::read::ZipArchive; 5 | /* Web UI assets: a zip archive embedded at compile time from the file named by the MOPROXY_WEB_BUNDLE environment variable. Entries are read through the locked, mutable archive while the Mutex is held. */ 6 | pub struct ResourceBundle { 7 | zip: Mutex>>, 8 | } 9 | /* new() panics if the embedded bundle is not a valid zip archive. */ 10 | impl ResourceBundle { 11 | pub fn new() -> Self {
12 | let bytes = Bytes::from_static(include_bytes!(env!("MOPROXY_WEB_BUNDLE"))); 13 | let zip = ZipArchive::new(Cursor::new(bytes)) 14 | .expect("broken moproxy-web bundle") 15 | .into(); 16 | ResourceBundle { zip } 17 | } 18 | /* Returns the MIME type and contents of a bundled file. A leading slash is stripped from path, and missing or non-file entries give None. The MIME type comes from the file extension, defaulting to application/octet-stream. Panics if an entry cannot be read. */ 19 | pub fn get(&self, path: &str) -> Option<(&'static str, Vec)> { 20 | let name = path.strip_prefix('/').unwrap_or(path); 21 | let mut zip = self.zip.lock(); 22 | let mut file = zip.by_name(name).ok()?; 23 | if !file.is_file() { 24 | return None; 25 | } 26 | let mut content = Vec::with_capacity(file.size() as usize); 27 | file.read_to_end(&mut content) 28 | .expect("error on read moproxy-web bundle"); 29 | 30 | let name_ext = name.rsplit_once('.').map(|x| x.1); 31 | let mime = match name_ext { 32 | Some("html") => "text/html", 33 | Some("js") => "application/javascript", 34 | Some("css") => "text/css", 35 | Some("txt") => "text/plain", 36 | Some("json") | Some("map") => "application/json", 37 | _ => "application/octet-stream", 38 | }; 39 | 40 | (mime, content).into() 41 | } 42 | } 43 | -------------------------------------------------------------------------------- /tests/socks5.rs: -------------------------------------------------------------------------------- 1 | use moproxy::proxy::socks5::handshake; 2 | use std::net::SocketAddr; 3 | use tokio::{ 4 | self, 5 | io::{AsyncReadExt, AsyncWriteExt}, 6 | net::{TcpListener, TcpStream}, 7 | }; 8 | /* SOCKS5 handshake to a domain destination against a scripted fake server. It expects a no-auth greeting and a CONNECT request with address type 3 for example.com:80, then checks that the early payload is forwarded and the response can be read back. NOTE(review): the fake server assumes each read() returns a whole client message. */ 9 | #[tokio::test] 10 | async fn test_socks5_domain() { 11 | let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); 12 | let addr = listener.local_addr().unwrap(); 13 | 14 | tokio::spawn(async move { 15 | let (mut stream, _) = listener.accept().await.unwrap(); 16 | let mut buf = [0u8; 128]; 17 | stream.read_exact(&mut buf[..3]).await.unwrap(); 18 | assert_eq!(&[5, 1, 0], &buf[..3]); // ver 5, no auth 19 | stream.write_all(&[5, 0]).await.unwrap(); // no auth 20 | 21 | stream.read(&mut buf).await.unwrap(); 22 | assert!(buf.starts_with(&[5, 1, 0, 3, 11])); // domain, len 11
23 | assert!(buf[5..].starts_with(b"example.com")); 24 | assert!(buf[16..].starts_with(&[0, 80])); 25 | stream 26 | .write_all(&[5, 0, 0, 1, 0, 0, 0, 0, 0, 80]) 27 | .await 28 | .unwrap(); 29 | 30 | stream.read(&mut buf).await.unwrap(); 31 | assert!(buf.starts_with(b"early-payload")); 32 | stream.write_all(b"response").await.unwrap(); 33 | }); 34 | 35 | let mut stream = TcpStream::connect(&addr).await.unwrap(); 36 | let dest = ("example.com", 80).into(); 37 | let payload = b"early-payload"; 38 | handshake(&mut stream, &dest, Some(payload), false, &None) 39 | .await 40 | .unwrap(); 41 | let mut buf = [0u8; 128]; 42 | let n = stream.read(&mut buf).await.unwrap(); 43 | assert_eq!(&buf[..n], b"response"); 44 | } 45 | /* SOCKS5 handshake to an IPv6 destination, [2001:db8::1]:80 with address type 4; the fake server replies with an IPv6 bound address. */ 46 | #[tokio::test] 47 | async fn test_socks5_ipv6() { 48 | let listener = TcpListener::bind("[::1]:0").await.unwrap(); 49 | let addr = listener.local_addr().unwrap(); 50 | 51 | tokio::spawn(async move { 52 | let (mut stream, _) = listener.accept().await.unwrap(); 53 | let mut buf = [0u8; 128]; 54 | stream.read_exact(&mut buf[..3]).await.unwrap(); 55 | assert_eq!(&[5, 1, 0], &buf[..3]); // ver 5, no auth 56 | stream.write_all(&[5, 0]).await.unwrap(); // no auth 57 | 58 | stream.read(&mut buf).await.unwrap(); 59 | assert_eq!( 60 | &buf[..22], 61 | &[ 62 | 5, 1, 0, 4, 0x20, 0x01, 0x0d, 0xb8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 63 | 1, // 2001:db8::1 64 | 0, 80, // port number 65 | ] 66 | ); 67 | stream 68 | .write_all(&[ 69 | 5, 0, 0, 4, 1, 2, 3, 4, 5, 6, 7, 8, 8, 7, 6, 5, 4, 3, 2, 1, // v6 addr 70 | 0, 80, // port number 71 | ]) 72 | .await 73 | .unwrap(); 74 | 75 | stream.read(&mut buf).await.unwrap(); 76 | assert!(buf.starts_with(b"early-payload")); 77 | stream.write_all(b"response").await.unwrap(); 78 | }); 79 | 80 | let mut stream = TcpStream::connect(&addr).await.unwrap(); 81 | let dest = "[2001:db8::1]:80".parse::().unwrap().into(); 82 | let payload = b"early-payload"; 83 | handshake(&mut stream, &dest, Some(payload), false, &None) 84 | .await 85 | .unwrap();
86 | let mut buf = [0u8; 128]; 87 | let n = stream.read(&mut buf).await.unwrap(); 88 | assert_eq!(&buf[..n], b"response"); 89 | } 90 | --------------------------------------------------------------------------------