├── .gitignore ├── .rustfmt.toml ├── Cargo.toml ├── README.md ├── bypass-china.acl └── src ├── acl.rs ├── lib.rs ├── main.rs └── socks5.rs /.gitignore: -------------------------------------------------------------------------------- 1 | /target 2 | *.iml 3 | /.idea 4 | Cargo.lock 5 | -------------------------------------------------------------------------------- /.rustfmt.toml: -------------------------------------------------------------------------------- 1 | use_try_shorthand = true 2 | -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "trust-china-dns" 3 | version = "0.1.0" 4 | authors = ["Max Lv "] 5 | repository = "https://github.com/madeye/trust-china-dns" 6 | license = "MIT" 7 | edition = "2018" 8 | 9 | [[bin]] 10 | name = "trust-china-dns" 11 | path = "src/main.rs" 12 | 13 | # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html 14 | 15 | [dependencies] 16 | tokio = { version = "^0.2.7", features = ["full"] } 17 | iprange = "0.6" 18 | ipnet = "2.2" 19 | bytes = "0.5" 20 | regex = "1" 21 | futures = "0.3" 22 | clap = "2.33" 23 | lru = "0.4" 24 | trust-dns-proto = "0.19" -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # trust-china-dns 2 | 3 | A China-DNS implementation in Rust. 4 | -------------------------------------------------------------------------------- /src/acl.rs: -------------------------------------------------------------------------------- 1 | //! Access Control List (ACL) for shadowsocks 2 | //! 3 | //! This is for advanced control of server behavior in both local and proxy servers.
4 | 5 | use std::{ 6 | fmt, 7 | fs::File, 8 | io::{self, BufRead, BufReader, Error, ErrorKind}, 9 | net::{IpAddr, SocketAddr}, 10 | path::Path, 11 | }; 12 | 13 | use ipnet::{IpNet, Ipv4Net, Ipv6Net}; 14 | use iprange::IpRange; 15 | use regex::RegexSet; 16 | 17 | use crate::socks5::Address; 18 | 19 | /// Strategy mode that ACL is running 20 | #[derive(Debug, Copy, Clone, Eq, PartialEq)] 21 | pub enum Mode { 22 | /// BlackList mode, rejects or bypasses all requests by default 23 | BlackList, 24 | /// WhiteList mode, accepts or proxies all requests by default 25 | WhiteList, 26 | } 27 | 28 | #[derive(Clone)] 29 | struct Rules { 30 | ipv4: IpRange, 31 | ipv6: IpRange, 32 | rule: RegexSet, 33 | } 34 | 35 | impl fmt::Debug for Rules { 36 | fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { 37 | write!( 38 | f, 39 | "Rules {{ ipv4: {:?}, ipv6: {:?}, rule: [", 40 | self.ipv4, self.ipv6 41 | )?; 42 | 43 | let max_len = 2; 44 | let has_more = self.rule.len() > max_len; 45 | 46 | for (idx, r) in self.rule.patterns().iter().take(max_len).enumerate() { 47 | if idx > 0 { 48 | f.write_str(", ")?; 49 | } 50 | f.write_str(r)?; 51 | } 52 | 53 | if has_more { 54 | f.write_str(", ...")?; 55 | } 56 | 57 | f.write_str("] }") 58 | } 59 | } 60 | 61 | impl Rules { 62 | /// Create a new rule 63 | fn new(mut ipv4: IpRange, mut ipv6: IpRange, rule: RegexSet) -> Rules { 64 | // Optimization, merging networks 65 | ipv4.simplify(); 66 | ipv6.simplify(); 67 | 68 | Rules { ipv4, ipv6, rule } 69 | } 70 | 71 | /// Check if the specified address matches these rules 72 | fn check_address_matched(&self, addr: &Address) -> bool { 73 | match *addr { 74 | Address::SocketAddress(ref saddr) => self.check_ip_matched(saddr), 75 | Address::DomainNameAddress(ref domain, ..) 
=> self.check_host_matched(domain), 76 | } 77 | } 78 | 79 | /// Check if the specified address matches any rules 80 | fn check_ip_matched(&self, addr: &SocketAddr) -> bool { 81 | match addr.ip() { 82 | IpAddr::V4(v4) => self.ipv4.contains(&v4), 83 | IpAddr::V6(v6) => self.ipv6.contains(&v6), 84 | } 85 | } 86 | 87 | /// Check if the specified host matches any rules 88 | fn check_host_matched(&self, host: &str) -> bool { 89 | self.rule.is_match(host) 90 | } 91 | } 92 | 93 | /// ACL rules 94 | /// 95 | /// ## Sections 96 | /// 97 | /// ACL File is formatted in sections, each section has a name with surrounded by brackets `[` and `]` 98 | /// followed by Rules line by line. 99 | /// 100 | /// ```plain 101 | /// [SECTION-1] 102 | /// RULE-1 103 | /// RULE-2 104 | /// RULE-3 105 | /// 106 | /// [SECTION-2] 107 | /// RULE-1 108 | /// RULE-2 109 | /// RULE-3 110 | /// ``` 111 | /// 112 | /// Available sections are 113 | /// 114 | /// - For local servers (`sslocal`, `ssredir`, ...) 115 | /// * `[bypass_all]` - ACL runs in `BlackList` mode. 116 | /// - `[bypass_list]` - Rules for connecting directly 117 | /// * `[proxy_all]` - ACL runs in `WhiteList` mode. 118 | /// - `[proxy_list]` - Rules for connecting through proxies 119 | /// - For remote servers (`ssserver`) 120 | /// * `[reject_all]` - ACL runs in `BlackList` mode. 121 | /// * `[accept_all]` - ACL runs in `WhiteList` mode. 122 | /// * `[outbound_block_list]` - Rules for blocking outbound addresses. 123 | /// 124 | /// ## Mode 125 | /// 126 | /// - `WhiteList` (reject / bypass all, except ...) 127 | /// 128 | /// Only hosts / clients that matches rules in 129 | /// - `[proxy_list]` - will be connected through remote proxies, others will be connected directly 130 | /// - `[white_list]` - will be allowed, others will be rejected 131 | /// 132 | /// - `BlackList` (accept / proxy all, except ...) 
133 | /// 134 | /// Only hosts / clients that matches rules in 135 | /// - `[bypass_list]` - will be connected directly instead of connecting through remote proxies 136 | /// - `[black_list]` - will be rejected (close connection) 137 | /// 138 | /// ## Rules 139 | /// 140 | /// Rules can be either 141 | /// 142 | /// - CIDR form network addresses, like `10.9.0.32/16` 143 | /// - IP addresses, like `127.0.0.1` or `::1` 144 | /// - Regular Expression for matching hosts, like `(^|\.)gmail\.com$` 145 | #[derive(Debug, Clone)] 146 | pub struct AccessControl { 147 | outbound_block: Rules, 148 | black_list: Rules, 149 | white_list: Rules, 150 | mode: Mode, 151 | } 152 | 153 | impl AccessControl { 154 | /// Load ACL rules from a file 155 | pub fn load_from_file>(p: P) -> io::Result { 156 | let fp = File::open(p)?; 157 | let r = BufReader::new(fp); 158 | 159 | let mut mode = Mode::BlackList; 160 | 161 | let mut outbound_block_ipv4 = IpRange::new(); 162 | let mut outbound_block_ipv6 = IpRange::new(); 163 | let mut outbound_block_rules = Vec::new(); 164 | let mut bypass_ipv4 = IpRange::new(); 165 | let mut bypass_ipv6 = IpRange::new(); 166 | let mut bypass_rules = Vec::new(); 167 | let mut proxy_ipv4 = IpRange::new(); 168 | let mut proxy_ipv6 = IpRange::new(); 169 | let mut proxy_rules = Vec::new(); 170 | 171 | let mut curr_ipv4 = &mut bypass_ipv4; 172 | let mut curr_ipv6 = &mut bypass_ipv6; 173 | let mut curr_rules = &mut proxy_rules; 174 | 175 | for line in r.lines() { 176 | let line = line?; 177 | if line.is_empty() { 178 | continue; 179 | } 180 | 181 | // Comments 182 | if line.starts_with('#') { 183 | continue; 184 | } 185 | 186 | match line.as_str() { 187 | "[reject_all]" | "[bypass_all]" => { 188 | mode = Mode::WhiteList; 189 | } 190 | "[accept_all]" | "[proxy_all]" => { 191 | mode = Mode::BlackList; 192 | } 193 | "[outbound_block_list]" => { 194 | curr_ipv4 = &mut outbound_block_ipv4; 195 | curr_ipv6 = &mut outbound_block_ipv6; 196 | curr_rules = &mut 
outbound_block_rules; 197 | } 198 | "[black_list]" | "[bypass_list]" => { 199 | curr_ipv4 = &mut bypass_ipv4; 200 | curr_ipv6 = &mut bypass_ipv6; 201 | curr_rules = &mut bypass_rules; 202 | } 203 | "[white_list]" | "[proxy_list]" => { 204 | curr_ipv4 = &mut proxy_ipv4; 205 | curr_ipv6 = &mut proxy_ipv6; 206 | curr_rules = &mut proxy_rules; 207 | } 208 | _ => { 209 | match line.parse::() { 210 | Ok(IpNet::V4(v4)) => { 211 | curr_ipv4.add(v4); 212 | } 213 | Ok(IpNet::V6(v6)) => { 214 | curr_ipv6.add(v6); 215 | } 216 | Err(..) => { 217 | // Maybe it is a pure IpAddr 218 | match line.parse::() { 219 | Ok(IpAddr::V4(v4)) => { 220 | curr_ipv4.add(Ipv4Net::from(v4)); 221 | } 222 | Ok(IpAddr::V6(v6)) => { 223 | curr_ipv6.add(Ipv6Net::from(v6)); 224 | } 225 | Err(..) => { 226 | // FIXME: If this line is not a valid regex, how can we know without actually compile it? 227 | curr_rules.push(line); 228 | } 229 | } 230 | } 231 | } 232 | } 233 | } 234 | } 235 | 236 | let outbound_block_regex = match RegexSet::new(outbound_block_rules) { 237 | Ok(r) => r, 238 | Err(err) => { 239 | let err = Error::new( 240 | ErrorKind::Other, 241 | format!("[outbound_block_list] regex error: {}", err), 242 | ); 243 | return Err(err); 244 | } 245 | }; 246 | 247 | let bypass_regex = match RegexSet::new(bypass_rules) { 248 | Ok(r) => r, 249 | Err(err) => { 250 | let err = Error::new( 251 | ErrorKind::Other, 252 | format!("[black_list] or [bypass_list] regex error: {}", err), 253 | ); 254 | return Err(err); 255 | } 256 | }; 257 | 258 | let proxy_regex = match RegexSet::new(proxy_rules) { 259 | Ok(r) => r, 260 | Err(err) => { 261 | let err = Error::new( 262 | ErrorKind::Other, 263 | format!("[white_list] or [proxy_list] regex error: {}", err), 264 | ); 265 | return Err(err); 266 | } 267 | }; 268 | 269 | Ok(AccessControl { 270 | outbound_block: Rules::new( 271 | outbound_block_ipv4, 272 | outbound_block_ipv6, 273 | outbound_block_regex, 274 | ), 275 | black_list: Rules::new(bypass_ipv4, bypass_ipv6, 
bypass_regex), 276 | white_list: Rules::new(proxy_ipv4, proxy_ipv6, proxy_regex), 277 | mode, 278 | }) 279 | } 280 | 281 | /// Check if target address should be bypassed (for client) 282 | /// 283 | /// FIXME: This function may perform a DNS resolution 284 | pub async fn check_target_bypassed(&self, addr: &Address) -> bool { 285 | // Always redirect TCP DNS query (for android) 286 | if cfg!(target_os = "android") { 287 | let port = match *addr { 288 | Address::SocketAddress(ref saddr) => saddr.port(), 289 | Address::DomainNameAddress(ref _host, port) => port, 290 | }; 291 | if port == 53 { 292 | return false; 293 | } 294 | } 295 | // Addresses in bypass_list will be bypassed 296 | if self.black_list.check_address_matched(addr) { 297 | return true; 298 | } 299 | // Addresses in proxy_list will be proxied 300 | if self.white_list.check_address_matched(addr) { 301 | return false; 302 | } 303 | // default rule 304 | match self.mode { 305 | Mode::BlackList => false, 306 | Mode::WhiteList => true, 307 | } 308 | } 309 | 310 | /// Check if client address should be blocked (for server) 311 | pub fn check_client_blocked(&self, addr: &SocketAddr) -> bool { 312 | match self.mode { 313 | Mode::BlackList => { 314 | // Only clients in black_list will be blocked 315 | self.black_list.check_ip_matched(addr) 316 | } 317 | Mode::WhiteList => { 318 | // Only clients in white_list will be proxied 319 | !self.white_list.check_ip_matched(addr) 320 | } 321 | } 322 | } 323 | 324 | /// Check if outbound address is blocked (for server) 325 | /// 326 | /// NOTE: `Address::DomainName` is only validated by regex rules, 327 | /// resolved addresses are checked in the `lookup_outbound_then!` macro 328 | pub fn check_outbound_blocked(&self, outbound: &Address) -> bool { 329 | self.outbound_block.check_address_matched(outbound) 330 | } 331 | 332 | /// Check resolved outbound address is blocked (for server) 333 | pub fn check_resolved_outbound_blocked(&self, outbound: &SocketAddr) -> bool { 334 | 
self.outbound_block.check_ip_matched(outbound) 335 | } 336 | } 337 | -------------------------------------------------------------------------------- /src/lib.rs: -------------------------------------------------------------------------------- 1 | pub mod acl; 2 | pub mod socks5; 3 | -------------------------------------------------------------------------------- /src/main.rs: -------------------------------------------------------------------------------- 1 | use std::net::{IpAddr, SocketAddr}; 2 | use std::{future::Future, time::Duration}; 3 | 4 | use clap::{App, Arg}; 5 | 6 | use tokio::io::Result; 7 | use tokio::net::{TcpStream, UdpSocket}; 8 | use tokio::prelude::*; 9 | use tokio::time; 10 | 11 | use lru::LruCache; 12 | 13 | use trust_china_dns::acl::AccessControl; 14 | use trust_china_dns::socks5::*; 15 | use trust_dns_proto::op::*; 16 | use trust_dns_proto::rr::*; 17 | 18 | pub async fn try_timeout(fut: F, timeout: Option) -> io::Result 19 | where 20 | F: Future>, 21 | { 22 | match timeout { 23 | Some(t) => time::timeout(t, fut).await?, 24 | None => fut.await, 25 | } 26 | .map_err(From::from) 27 | } 28 | 29 | async fn udp_lookup(qname: &Name, qtype: RecordType, server: SocketAddr) -> Result { 30 | let mut socket = UdpSocket::bind(("0.0.0.0", 0)).await?; 31 | 32 | let mut message = Message::new(); 33 | let mut query = Query::new(); 34 | 35 | query.set_query_type(qtype); 36 | query.set_name(qname.clone()); 37 | 38 | message.set_id(6666); 39 | message.set_recursion_desired(true); 40 | message.add_query(query); 41 | 42 | let req_buffer = message.to_vec()?; 43 | socket.send_to(&req_buffer, server).await?; 44 | 45 | let mut res_buffer = vec![0; 512]; 46 | socket.recv_from(&mut res_buffer).await?; 47 | 48 | Ok(Message::from_vec(&mut res_buffer)?) 49 | } 50 | 51 | async fn socks5_lookup( 52 | qname: &Name, 53 | qtype: RecordType, 54 | socks5: SocketAddr, 55 | ns: SocketAddr, 56 | ) -> Result { 57 | let mut stream = TcpStream::connect(socks5).await?; 58 | 59 | // 1. 
Handshake 60 | let hs = HandshakeRequest::new(vec![SOCKS5_AUTH_METHOD_NONE]); 61 | hs.write_to(&mut stream).await?; 62 | stream.flush().await?; 63 | 64 | let hsp = HandshakeResponse::read_from(&mut stream).await?; 65 | assert_eq!(hsp.chosen_method, SOCKS5_AUTH_METHOD_NONE); 66 | 67 | // 2. Send request header 68 | let addr = Address::SocketAddress(ns); 69 | let h = TcpRequestHeader::new(Command::TcpConnect, addr); 70 | h.write_to(&mut stream).await?; 71 | stream.flush().await?; 72 | 73 | let hp = TcpResponseHeader::read_from(&mut stream).await?; 74 | match hp.reply { 75 | Reply::Succeeded => (), 76 | r => { 77 | let err = io::Error::new(io::ErrorKind::Other, format!("{}", r)); 78 | return Err(err); 79 | } 80 | } 81 | 82 | let mut message = Message::new(); 83 | let mut query = Query::new(); 84 | 85 | query.set_query_type(qtype); 86 | query.set_name(qname.clone()); 87 | 88 | message.set_id(6666); 89 | message.set_recursion_desired(true); 90 | message.add_query(query); 91 | 92 | let req_buffer = message.to_vec()?; 93 | let size = req_buffer.len(); 94 | let mut size_buffer: [u8; 2] = [((size >> 8) & 0xFF) as u8, ((size >> 0) & 0xFF) as u8]; 95 | let mut send_buffer: [u8; 512 + 2] = [0; 512 + 2]; 96 | send_buffer[..2].copy_from_slice(&size_buffer[..2]); 97 | send_buffer[2..size + 2].copy_from_slice(&req_buffer[0..size]); 98 | stream.write_all(&send_buffer[0..size + 2]).await?; 99 | 100 | stream.read_exact(&mut size_buffer[0..2]).await?; 101 | 102 | let mut res_buffer = vec![0; 512]; 103 | let size = ((size_buffer[0] as usize) << 8) + (size_buffer[1] as usize); 104 | stream.read_exact(&mut res_buffer[0..size]).await?; 105 | 106 | Ok(Message::from_vec(&mut res_buffer)?) 
107 | } 108 | 109 | async fn acl_lookup( 110 | acl: &AccessControl, 111 | local: SocketAddr, 112 | remote: SocketAddr, 113 | socks5: SocketAddr, 114 | qname: &Name, 115 | qtype: RecordType, 116 | ) -> Result { 117 | // Start querying name servers 118 | println!( 119 | "attempting lookup of {:?} {} with ns {} and {}", 120 | qtype, qname, local, remote 121 | ); 122 | 123 | let ten_seconds = Some(Duration::new(5, 0)); 124 | 125 | let local_response = try_timeout(udp_lookup(qname, qtype.clone(), local), ten_seconds).await?; 126 | let remote_response = try_timeout( 127 | socks5_lookup(qname, qtype.clone(), socks5, remote), 128 | ten_seconds, 129 | ) 130 | .await?; 131 | 132 | let addr = Address::DomainNameAddress(qname.to_string(), 0); 133 | let qname_bypassed = acl.check_target_bypassed(&addr).await; 134 | 135 | let mut ip_bypassed = false; 136 | for rec in local_response.answers() { 137 | let bypassed = match rec.rdata() { 138 | RData::A(ref ip) => { 139 | let addr = Address::SocketAddress(SocketAddr::new(IpAddr::from(*ip), 0)); 140 | acl.check_target_bypassed(&addr).await 141 | } 142 | RData::AAAA(ref ip) => { 143 | let addr = Address::SocketAddress(SocketAddr::new(IpAddr::from(*ip), 0)); 144 | acl.check_target_bypassed(&addr).await 145 | } 146 | _ => false, 147 | }; 148 | if bypassed { 149 | ip_bypassed = true; 150 | } 151 | } 152 | 153 | if qname_bypassed { 154 | println!("Pick local response"); 155 | Ok(local_response.clone()) 156 | } else if ip_bypassed { 157 | println!("Pick local response"); 158 | Ok(local_response.clone()) 159 | } else { 160 | println!("Pick remote response"); 161 | Ok(remote_response.clone()) 162 | } 163 | } 164 | 165 | #[tokio::main] 166 | async fn main() -> Result<()> { 167 | let mut reverse_resolver_cache = LruCache::new(8192); 168 | 169 | let matches = App::new("trust-china-dns") 170 | .version("0.1") 171 | .about("Yet another ChinaDNS in Rust") 172 | .author("Max Lv ") 173 | .arg( 174 | Arg::with_name("local") 175 | .long("local") 176 | 
.value_name("LOCAL_DNS") 177 | .help("Sets a custom local DNS server") 178 | .takes_value(true), 179 | ) 180 | .arg( 181 | Arg::with_name("remote") 182 | .long("remote") 183 | .value_name("REMOTE_DNS") 184 | .help("Sets a custom remote DNS server") 185 | .takes_value(true), 186 | ) 187 | .arg( 188 | Arg::with_name("socks5") 189 | .long("socks5") 190 | .value_name("SOCKS5") 191 | .help("Sets a custom SOCKS5 proxy") 192 | .takes_value(true), 193 | ) 194 | .arg( 195 | Arg::with_name("listen") 196 | .long("listen") 197 | .value_name("LISTEN") 198 | .help("Sets a custom listen address") 199 | .takes_value(true), 200 | ) 201 | .arg( 202 | Arg::with_name("acl") 203 | .long("acl") 204 | .value_name("ACL") 205 | .help("Sets a custom ACL path") 206 | .takes_value(true), 207 | ) 208 | .get_matches(); 209 | 210 | let local = matches.value_of("local").unwrap_or("114.114.114.114:53"); 211 | println!("Local DNS server: {}", local); 212 | 213 | let remote = matches.value_of("remote").unwrap_or("8.8.8.8:53"); 214 | println!("Remote DNS server: {}", remote); 215 | 216 | let socks5 = matches.value_of("socks5").unwrap_or("127.0.0.1:1080"); 217 | println!("SOCKS5 server: {}", socks5); 218 | 219 | let listen = matches.value_of("listen").unwrap_or("127.0.0.1:2053"); 220 | println!("Listen on {}", listen); 221 | 222 | let acl = matches.value_of("acl").unwrap_or("bypass-china.acl"); 223 | println!("Load ACL file: {}", acl); 224 | 225 | let local_addr: SocketAddr = local.parse().expect("Unable to parse local address"); 226 | let remote_addr: SocketAddr = remote.parse().expect("Unable to parse remote address"); 227 | let socks5_addr: SocketAddr = socks5.parse().expect("Unable to parse socks5 address"); 228 | let listen_addr: SocketAddr = listen.parse().expect("Unable to parse listen address"); 229 | 230 | let mut socket = UdpSocket::bind(listen_addr).await?; 231 | let acl = AccessControl::load_from_file(acl).expect("Failed to load ACL file"); 232 | 233 | loop { 234 | let mut req_buffer: [u8; 
512] = [0; 512]; 235 | let (_, src) = match socket.recv_from(&mut req_buffer).await { 236 | Ok(x) => x, 237 | Err(e) => { 238 | println!("Failed to read from UDP socket: {:?}", e); 239 | continue; 240 | } 241 | }; 242 | 243 | let request = match Message::from_vec(&mut req_buffer) { 244 | Ok(x) => x, 245 | Err(e) => { 246 | println!("Failed to parse UDP query message: {:?}", e); 247 | continue; 248 | } 249 | }; 250 | 251 | let mut message = Message::new(); 252 | message.set_id(request.id()); 253 | message.set_recursion_desired(true); 254 | message.set_recursion_available(true); 255 | message.set_message_type(header::MessageType::Response); 256 | 257 | if request.queries().is_empty() { 258 | message.set_response_code(response_code::ResponseCode::FormErr); 259 | } else { 260 | let question = &request.queries()[0]; 261 | println!("Received query: {:?}", question); 262 | 263 | if let Ok(result) = acl_lookup( 264 | &acl, 265 | local_addr, 266 | remote_addr, 267 | socks5_addr, 268 | question.name(), 269 | question.query_type(), 270 | ) 271 | .await 272 | { 273 | message.add_query(question.clone()); 274 | message.set_response_code(result.response_code()); 275 | 276 | for rec in result.answers() { 277 | println!("Answer: {:?}", rec); 278 | match rec.rdata() { 279 | RData::A(ref ip) => reverse_resolver_cache 280 | .put(IpAddr::from(*ip), question.name().to_ascii()), 281 | RData::AAAA(ref ip) => reverse_resolver_cache 282 | .put(IpAddr::from(*ip), question.name().to_ascii()), 283 | _ => None, 284 | }; 285 | message.add_answer(rec.clone()); 286 | } 287 | for rec in result.additionals() { 288 | println!("Additionals: {:?}", rec); 289 | message.add_additional(rec.clone()); 290 | } 291 | } else { 292 | message.set_response_code(ResponseCode::ServFail); 293 | } 294 | } 295 | 296 | let res_buffer = message.to_vec()?; 297 | match socket.send_to(&res_buffer, src).await { 298 | Ok(_) => {} 299 | Err(e) => { 300 | println!("Failed to send response: {:?}", e); 301 | continue; 302 | } 
303 | }; 304 | } 305 | } 306 | -------------------------------------------------------------------------------- /src/socks5.rs: -------------------------------------------------------------------------------- 1 | //! Socks5 protocol definition (RFC1928) 2 | //! 3 | //! Implements [SOCKS Protocol Version 5](https://www.ietf.org/rfc/rfc1928.txt) proxy protocol 4 | 5 | use std::{ 6 | convert::From, 7 | error, 8 | fmt::{self, Debug, Formatter}, 9 | io::{self, Cursor}, 10 | net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6, ToSocketAddrs}, 11 | str::FromStr, 12 | u8, vec, 13 | }; 14 | 15 | use bytes::{buf::BufExt, Buf, BufMut, BytesMut}; 16 | use tokio::prelude::*; 17 | 18 | pub use self::consts::{ 19 | SOCKS5_AUTH_METHOD_GSSAPI, SOCKS5_AUTH_METHOD_NONE, SOCKS5_AUTH_METHOD_NOT_ACCEPTABLE, 20 | SOCKS5_AUTH_METHOD_PASSWORD, 21 | }; 22 | 23 | #[rustfmt::skip] 24 | mod consts { 25 | pub const SOCKS5_VERSION: u8 = 0x05; 26 | 27 | pub const SOCKS5_AUTH_METHOD_NONE: u8 = 0x00; 28 | pub const SOCKS5_AUTH_METHOD_GSSAPI: u8 = 0x01; 29 | pub const SOCKS5_AUTH_METHOD_PASSWORD: u8 = 0x02; 30 | pub const SOCKS5_AUTH_METHOD_NOT_ACCEPTABLE: u8 = 0xff; 31 | 32 | pub const SOCKS5_CMD_TCP_CONNECT: u8 = 0x01; 33 | pub const SOCKS5_CMD_TCP_BIND: u8 = 0x02; 34 | pub const SOCKS5_CMD_UDP_ASSOCIATE: u8 = 0x03; 35 | 36 | pub const SOCKS5_ADDR_TYPE_IPV4: u8 = 0x01; 37 | pub const SOCKS5_ADDR_TYPE_DOMAIN_NAME: u8 = 0x03; 38 | pub const SOCKS5_ADDR_TYPE_IPV6: u8 = 0x04; 39 | 40 | pub const SOCKS5_REPLY_SUCCEEDED: u8 = 0x00; 41 | pub const SOCKS5_REPLY_GENERAL_FAILURE: u8 = 0x01; 42 | pub const SOCKS5_REPLY_CONNECTION_NOT_ALLOWED: u8 = 0x02; 43 | pub const SOCKS5_REPLY_NETWORK_UNREACHABLE: u8 = 0x03; 44 | pub const SOCKS5_REPLY_HOST_UNREACHABLE: u8 = 0x04; 45 | pub const SOCKS5_REPLY_CONNECTION_REFUSED: u8 = 0x05; 46 | pub const SOCKS5_REPLY_TTL_EXPIRED: u8 = 0x06; 47 | pub const SOCKS5_REPLY_COMMAND_NOT_SUPPORTED: u8 = 0x07; 48 | pub const SOCKS5_REPLY_ADDRESS_TYPE_NOT_SUPPORTED: u8 
= 0x08; 49 | } 50 | 51 | /// SOCKS5 command 52 | #[derive(Clone, Debug, Copy)] 53 | pub enum Command { 54 | /// CONNECT command (TCP tunnel) 55 | TcpConnect, 56 | /// BIND command (Not supported in ShadowSocks) 57 | TcpBind, 58 | /// UDP ASSOCIATE command 59 | UdpAssociate, 60 | } 61 | 62 | impl Command { 63 | #[inline] 64 | #[rustfmt::skip] 65 | fn as_u8(self) -> u8 { 66 | match self { 67 | Command::TcpConnect => consts::SOCKS5_CMD_TCP_CONNECT, 68 | Command::TcpBind => consts::SOCKS5_CMD_TCP_BIND, 69 | Command::UdpAssociate => consts::SOCKS5_CMD_UDP_ASSOCIATE, 70 | } 71 | } 72 | 73 | #[inline] 74 | #[rustfmt::skip] 75 | fn from_u8(code: u8) -> Option { 76 | match code { 77 | consts::SOCKS5_CMD_TCP_CONNECT => Some(Command::TcpConnect), 78 | consts::SOCKS5_CMD_TCP_BIND => Some(Command::TcpBind), 79 | consts::SOCKS5_CMD_UDP_ASSOCIATE => Some(Command::UdpAssociate), 80 | _ => None, 81 | } 82 | } 83 | } 84 | 85 | /// SOCKS5 reply code 86 | #[derive(Clone, Debug, Copy)] 87 | pub enum Reply { 88 | Succeeded, 89 | GeneralFailure, 90 | ConnectionNotAllowed, 91 | NetworkUnreachable, 92 | HostUnreachable, 93 | ConnectionRefused, 94 | TtlExpired, 95 | CommandNotSupported, 96 | AddressTypeNotSupported, 97 | 98 | OtherReply(u8), 99 | } 100 | 101 | impl Reply { 102 | #[inline] 103 | #[rustfmt::skip] 104 | fn as_u8(self) -> u8 { 105 | match self { 106 | Reply::Succeeded => consts::SOCKS5_REPLY_SUCCEEDED, 107 | Reply::GeneralFailure => consts::SOCKS5_REPLY_GENERAL_FAILURE, 108 | Reply::ConnectionNotAllowed => consts::SOCKS5_REPLY_CONNECTION_NOT_ALLOWED, 109 | Reply::NetworkUnreachable => consts::SOCKS5_REPLY_NETWORK_UNREACHABLE, 110 | Reply::HostUnreachable => consts::SOCKS5_REPLY_HOST_UNREACHABLE, 111 | Reply::ConnectionRefused => consts::SOCKS5_REPLY_CONNECTION_REFUSED, 112 | Reply::TtlExpired => consts::SOCKS5_REPLY_TTL_EXPIRED, 113 | Reply::CommandNotSupported => consts::SOCKS5_REPLY_COMMAND_NOT_SUPPORTED, 114 | Reply::AddressTypeNotSupported => 
consts::SOCKS5_REPLY_ADDRESS_TYPE_NOT_SUPPORTED, 115 | Reply::OtherReply(c) => c, 116 | } 117 | } 118 | 119 | #[inline] 120 | #[rustfmt::skip] 121 | fn from_u8(code: u8) -> Reply { 122 | match code { 123 | consts::SOCKS5_REPLY_SUCCEEDED => Reply::Succeeded, 124 | consts::SOCKS5_REPLY_GENERAL_FAILURE => Reply::GeneralFailure, 125 | consts::SOCKS5_REPLY_CONNECTION_NOT_ALLOWED => Reply::ConnectionNotAllowed, 126 | consts::SOCKS5_REPLY_NETWORK_UNREACHABLE => Reply::NetworkUnreachable, 127 | consts::SOCKS5_REPLY_HOST_UNREACHABLE => Reply::HostUnreachable, 128 | consts::SOCKS5_REPLY_CONNECTION_REFUSED => Reply::ConnectionRefused, 129 | consts::SOCKS5_REPLY_TTL_EXPIRED => Reply::TtlExpired, 130 | consts::SOCKS5_REPLY_COMMAND_NOT_SUPPORTED => Reply::CommandNotSupported, 131 | consts::SOCKS5_REPLY_ADDRESS_TYPE_NOT_SUPPORTED => Reply::AddressTypeNotSupported, 132 | _ => Reply::OtherReply(code), 133 | } 134 | } 135 | } 136 | 137 | impl fmt::Display for Reply { 138 | #[rustfmt::skip] 139 | fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { 140 | match *self { 141 | Reply::Succeeded => write!(f, "Succeeded"), 142 | Reply::AddressTypeNotSupported => write!(f, "Address type not supported"), 143 | Reply::CommandNotSupported => write!(f, "Command not supported"), 144 | Reply::ConnectionNotAllowed => write!(f, "Connection not allowed"), 145 | Reply::ConnectionRefused => write!(f, "Connection refused"), 146 | Reply::GeneralFailure => write!(f, "General failure"), 147 | Reply::HostUnreachable => write!(f, "Host unreachable"), 148 | Reply::NetworkUnreachable => write!(f, "Network unreachable"), 149 | Reply::OtherReply(u) => write!(f, "Other reply ({})", u), 150 | Reply::TtlExpired => write!(f, "TTL expired"), 151 | } 152 | } 153 | } 154 | 155 | /// SOCKS5 protocol error 156 | #[derive(Clone)] 157 | pub struct Error { 158 | /// Reply code 159 | pub reply: Reply, 160 | /// Error message 161 | pub message: String, 162 | } 163 | 164 | impl Error { 165 | pub fn new(reply: Reply, 
message: S) -> Error 166 | where 167 | S: Into, 168 | { 169 | Error { 170 | reply, 171 | message: message.into(), 172 | } 173 | } 174 | } 175 | 176 | impl Debug for Error { 177 | #[inline] 178 | fn fmt(&self, f: &mut Formatter) -> fmt::Result { 179 | write!(f, "{}", self.message) 180 | } 181 | } 182 | 183 | impl fmt::Display for Error { 184 | #[inline] 185 | fn fmt(&self, f: &mut Formatter) -> fmt::Result { 186 | write!(f, "{}", self.message) 187 | } 188 | } 189 | 190 | impl error::Error for Error {} 191 | 192 | impl From for Error { 193 | fn from(err: io::Error) -> Error { 194 | Error::new(Reply::GeneralFailure, err.to_string()) 195 | } 196 | } 197 | 198 | impl From for io::Error { 199 | fn from(err: Error) -> io::Error { 200 | io::Error::new(io::ErrorKind::Other, err.message) 201 | } 202 | } 203 | 204 | /// SOCKS5 address type 205 | #[derive(Clone, PartialEq, Eq, Hash)] 206 | pub enum Address { 207 | /// Socket address (IP Address) 208 | SocketAddress(SocketAddr), 209 | /// Domain name address 210 | DomainNameAddress(String, u16), 211 | } 212 | 213 | impl Address { 214 | pub async fn read_from(stream: &mut R) -> Result 215 | where 216 | R: AsyncRead + Unpin, 217 | { 218 | let mut addr_type_buf = [0u8; 1]; 219 | let _ = stream.read_exact(&mut addr_type_buf).await?; 220 | 221 | let addr_type = addr_type_buf[0]; 222 | match addr_type { 223 | consts::SOCKS5_ADDR_TYPE_IPV4 => { 224 | let mut buf = BytesMut::with_capacity(6); 225 | buf.resize(6, 0); 226 | let _ = stream.read_exact(&mut buf).await?; 227 | 228 | let mut cursor = buf.to_bytes(); 229 | let v4addr = Ipv4Addr::new( 230 | cursor.get_u8(), 231 | cursor.get_u8(), 232 | cursor.get_u8(), 233 | cursor.get_u8(), 234 | ); 235 | let port = cursor.get_u16(); 236 | Ok(Address::SocketAddress(SocketAddr::V4(SocketAddrV4::new( 237 | v4addr, port, 238 | )))) 239 | } 240 | consts::SOCKS5_ADDR_TYPE_IPV6 => { 241 | let mut buf = [0u8; 18]; 242 | let _ = stream.read_exact(&mut buf).await?; 243 | 244 | let mut cursor = 
Cursor::new(&buf); 245 | let v6addr = Ipv6Addr::new( 246 | cursor.get_u16(), 247 | cursor.get_u16(), 248 | cursor.get_u16(), 249 | cursor.get_u16(), 250 | cursor.get_u16(), 251 | cursor.get_u16(), 252 | cursor.get_u16(), 253 | cursor.get_u16(), 254 | ); 255 | let port = cursor.get_u16(); 256 | 257 | Ok(Address::SocketAddress(SocketAddr::V6(SocketAddrV6::new( 258 | v6addr, port, 0, 0, 259 | )))) 260 | } 261 | consts::SOCKS5_ADDR_TYPE_DOMAIN_NAME => { 262 | let mut length_buf = [0u8; 1]; 263 | let _ = stream.read_exact(&mut length_buf).await?; 264 | let length = length_buf[0] as usize; 265 | 266 | // Len(Domain) + Len(Port) 267 | let buf_length = length + 2; 268 | let mut buf = BytesMut::with_capacity(buf_length); 269 | buf.resize(buf_length, 0); 270 | let _ = stream.read_exact(&mut buf).await?; 271 | 272 | let mut cursor = buf.to_bytes(); 273 | let mut raw_addr = Vec::with_capacity(length); 274 | raw_addr.put(&mut BufExt::take(&mut cursor, length)); 275 | let addr = match String::from_utf8(raw_addr) { 276 | Ok(addr) => addr, 277 | Err(..) => { 278 | return Err(Error::new( 279 | Reply::GeneralFailure, 280 | "invalid address encoding", 281 | )) 282 | } 283 | }; 284 | let port = cursor.get_u16(); 285 | 286 | Ok(Address::DomainNameAddress(addr, port)) 287 | } 288 | _ => { 289 | // Wrong Address Type . 
// Socks5 only supports ipv4, ipv6 and domain name
                Err(Error::new(
                    Reply::AddressTypeNotSupported,
                    format!("not supported address type {:#x}", addr_type),
                ))
            }
        }
    }

    /// Serializes this address (ATYP, address, port) and writes it to `writer`.
    ///
    /// # Errors
    ///
    /// Propagates any I/O error returned by `writer`.
    #[inline]
    pub async fn write_to<W>(&self, writer: &mut W) -> io::Result<()>
    where
        W: AsyncWrite + Unpin,
    {
        let mut buf = BytesMut::with_capacity(self.serialized_len());
        self.write_to_buf(&mut buf);
        writer.write_all(&buf).await
    }

    /// Appends the SOCKS5 wire encoding of this address to `buf`.
    ///
    /// # Panics
    ///
    /// Panics if this is a domain name longer than 255 bytes, which the
    /// one-byte SOCKS5 length prefix cannot represent.
    #[inline]
    pub fn write_to_buf<B: BufMut>(&self, buf: &mut B) {
        write_address(self, buf)
    }

    /// Number of bytes `write_to_buf` appends for this address.
    #[inline]
    pub fn serialized_len(&self) -> usize {
        get_addr_len(self)
    }
}

impl Debug for Address {
    #[inline]
    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
        // Debug output is intentionally identical to Display (`host:port`).
        fmt::Display::fmt(self, f)
    }
}

impl fmt::Display for Address {
    #[inline]
    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
        match *self {
            Address::SocketAddress(ref addr) => write!(f, "{}", addr),
            Address::DomainNameAddress(ref addr, ref port) => write!(f, "{}:{}", addr, port),
        }
    }
}

impl ToSocketAddrs for Address {
    type Iter = vec::IntoIter<SocketAddr>;

    /// Resolves this address.
    ///
    /// Socket addresses are returned as-is. Domain names are delegated to the
    /// standard library's `(&str, u16)` implementation, which may perform a
    /// blocking DNS lookup.
    fn to_socket_addrs(&self) -> io::Result<vec::IntoIter<SocketAddr>> {
        // Match by reference: no need to clone the whole address (and its
        // domain `String`) just to resolve it.
        match *self {
            Address::SocketAddress(addr) => Ok(vec![addr].into_iter()),
            Address::DomainNameAddress(ref addr, port) => (addr.as_str(), port).to_socket_addrs(),
        }
    }
}

impl From<SocketAddr> for Address {
    fn from(s: SocketAddr) -> Address {
        Address::SocketAddress(s)
    }
}

impl From<(String, u16)> for Address {
    fn from((dn, port): (String, u16)) -> Address {
        Address::DomainNameAddress(dn, port)
    }
}

/// Error returned when a string cannot be parsed into an `Address`.
#[derive(Debug)]
pub struct AddressError;

impl FromStr for Address {
    type Err = AddressError;

    /// Parses `ip:port`, `[ipv6]:port`, `host:port`, or a bare `host`.
    ///
    /// A bare host defaults to port 80 (http's default port).
    ///
    /// # Errors
    ///
    /// Returns `AddressError` in these cases:
    /// - the port is not a valid `u16`, including when extra text follows it
    ///   (e.g. `host:80:x`);
    /// - the host part is empty;
    /// - the host is longer than 255 bytes, the SOCKS5 domain-name limit.
    fn from_str(s: &str) -> Result<Address, AddressError> {
        if let Ok(addr) = s.parse::<SocketAddr>() {
            return Ok(Address::SocketAddress(addr));
        }

        // Split only on the first ':'. Any trailing text then stays in the
        // port part and is rejected, instead of being silently dropped.
        let mut parts = s.splitn(2, ':');
        let dn = parts.next().unwrap_or("");
        let port = match parts.next() {
            Some(port) => port.parse::<u16>().map_err(|_| AddressError)?,
            // Assume it is 80 (http's default port)
            None => 80,
        };

        // An empty host is never valid. A host over 255 bytes cannot be
        // encoded and would make `write_domain_name_address` panic.
        if dn.is_empty() || dn.len() > u8::max_value() as usize {
            return Err(AddressError);
        }

        Ok(Address::DomainNameAddress(dn.to_owned(), port))
    }
}

/// Writes `ATYP(IPv4) | 4 address bytes | port` (big-endian).
fn write_ipv4_address<B: BufMut>(addr: &SocketAddrV4, buf: &mut B) {
    buf.put_u8(consts::SOCKS5_ADDR_TYPE_IPV4); // Address type
    buf.put_slice(&addr.ip().octets()); // Ipv4 bytes
    buf.put_u16(addr.port()); // Port
}

/// Writes `ATYP(IPv6) | 16 address bytes | port` (big-endian).
fn write_ipv6_address<B: BufMut>(addr: &SocketAddrV6, buf: &mut B) {
    buf.put_u8(consts::SOCKS5_ADDR_TYPE_IPV6); // Address type
    // `octets()` yields the 16 bytes in network order. That is the same as
    // writing the eight segments with `put_u16`.
    buf.put_slice(&addr.ip().octets()); // Ipv6 bytes
    buf.put_u16(addr.port()); // Port
}

/// Writes `ATYP(DOMAIN) | len | name bytes | port` (big-endian).
///
/// # Panics
///
/// Panics if `dnaddr` is longer than 255 bytes.
fn write_domain_name_address<B: BufMut>(dnaddr: &str, port: u16, buf: &mut B) {
    assert!(
        dnaddr.len() <= u8::max_value() as usize,
        "domain name too long for SOCKS5: {} bytes",
        dnaddr.len()
    );

    buf.put_u8(consts::SOCKS5_ADDR_TYPE_DOMAIN_NAME);
    buf.put_u8(dnaddr.len() as u8);
    buf.put_slice(dnaddr.as_bytes());
    buf.put_u16(port);
}

/// Dispatches a `SocketAddr` to the IPv4 or IPv6 writer.
fn write_socket_address<B: BufMut>(addr: &SocketAddr, buf: &mut B) {
    match *addr {
        SocketAddr::V4(ref addr) => write_ipv4_address(addr, buf),
        SocketAddr::V6(ref addr) => write_ipv6_address(addr, buf),
    }
}

/// Dispatches an `Address` to the matching wire-format writer.
fn write_address<B: BufMut>(addr: &Address, buf: &mut B) {
    match *addr {
        Address::SocketAddress(ref addr) => write_socket_address(addr, buf),
        Address::DomainNameAddress(ref dnaddr, ref port) => {
            write_domain_name_address(dnaddr, *port, buf)
        }
    }
}

/// Encoded length of `atyp`: ATYP byte + address bytes + 2-byte port.
#[inline]
fn get_addr_len(atyp: &Address) -> usize {
    match *atyp {
        Address::SocketAddress(SocketAddr::V4(..)) => 1 + 4 + 2,
        Address::SocketAddress(SocketAddr::V6(..)) => 1 + 8 * 2 + 2,
        // ATYP + length byte + name + port
        Address::DomainNameAddress(ref dmname, _) => 1 + 1 + dmname.len() + 2,
    }
}

/// TCP request header after handshake
///
/// ```plain
/// +----+-----+-------+------+----------+----------+
/// |VER | CMD |  RSV  | ATYP | DST.ADDR | DST.PORT |
/// +----+-----+-------+------+----------+----------+
/// | 1  |  1  | X'00' |  1   | Variable |    2     |
/// +----+-----+-------+------+----------+----------+
/// ```
#[derive(Clone, Debug)]
pub struct TcpRequestHeader {
    /// SOCKS5 command
    pub command: Command,
    /// Remote address
    pub address: Address,
}

impl TcpRequestHeader {
    /// Creates a request header
    pub fn new(cmd: Command, addr: Address) -> TcpRequestHeader {
        TcpRequestHeader {
            command: cmd,
            address: addr,
        }
    }

    /// Reads `VER | CMD | RSV` followed by an address from `r`.
    ///
    /// The RSV byte is read but not validated.
    ///
    /// # Errors
    ///
    /// - `Reply::ConnectionRefused` if VER is not SOCKS5.
    /// - `Reply::CommandNotSupported` if CMD is unknown.
    /// - Any error from reading the address or the underlying stream.
    pub async fn read_from<R>(r: &mut R) -> Result<TcpRequestHeader, Error>
    where
        R: AsyncRead + Unpin,
    {
        let mut buf = [0u8; 3];
        let _ = r.read_exact(&mut buf).await?;

        let ver = buf[0];
        if ver != consts::SOCKS5_VERSION {
            return Err(Error::new(
                Reply::ConnectionRefused,
                format!("unsupported socks version {:#x}", ver),
            ));
        }

        let cmd = buf[1];
        let command = match Command::from_u8(cmd) {
            Some(c) => c,
            None => {
                return Err(Error::new(
                    Reply::CommandNotSupported,
                    format!("unsupported command {:#x}", cmd),
                ));
            }
        };

        let address = Address::read_from(r).await?;
        Ok(TcpRequestHeader { command, address })
    }

    /// Write data into a writer
    ///
    /// # Errors
    ///
    /// Propagates any I/O error returned by `w`.
    pub async fn write_to<W>(&self, w: &mut W) -> io::Result<()>
    where
        W: AsyncWrite + Unpin,
    {
        let mut buf = BytesMut::with_capacity(self.serialized_len());
        self.write_to_buf(&mut buf);
        w.write_all(&buf).await
    }

    /// Appends `VER | CMD | RSV(0x00) | address` to `buf`.
    pub fn write_to_buf<B: BufMut>(&self, buf: &mut B) {
        let TcpRequestHeader {
            ref address,
            ref command,
        } = *self;

        buf.put_slice(&[consts::SOCKS5_VERSION, command.as_u8(), 0x00]);
        address.write_to_buf(buf);
    }

    /// Length in bytes
    #[inline]
    pub fn serialized_len(&self) -> usize {
        self.address.serialized_len() + 3
    }
}

/// TCP response header
///
/// ```plain
/// +----+-----+-------+------+----------+----------+
/// |VER | REP |  RSV  | ATYP | BND.ADDR | BND.PORT |
/// +----+-----+-------+------+----------+----------+
/// | 1  |  1  | X'00' |  1   | Variable |    2     |
/// +----+-----+-------+------+----------+----------+
/// ```
#[derive(Clone, Debug)]
pub struct TcpResponseHeader {
    /// SOCKS5 reply
    pub reply: Reply,
    /// Reply address
    pub address: Address,
}

impl TcpResponseHeader {
    /// Creates a response header
    pub fn new(reply: Reply, address: Address) -> TcpResponseHeader {
        TcpResponseHeader { reply, address }
    }

    /// Reads `VER | REP | RSV` followed by an address from `r`.
    ///
    /// REP is converted with `Reply::from_u8`. RSV is not validated.
    ///
    /// # Errors
    ///
    /// - `Reply::ConnectionRefused` if VER is not SOCKS5.
    /// - Any error from reading the address or the underlying stream.
    pub async fn read_from<R>(r: &mut R) -> Result<TcpResponseHeader, Error>
    where
        R: AsyncRead + Unpin,
    {
        let mut buf = [0u8; 3];
        let _ = r.read_exact(&mut buf).await?;

        let ver = buf[0];
        let reply_code = buf[1];

        if ver != consts::SOCKS5_VERSION {
            return Err(Error::new(
                Reply::ConnectionRefused,
                format!("unsupported socks version {:#x}", ver),
            ));
        }

        let address = Address::read_from(r).await?;

        Ok(TcpResponseHeader {
            reply: Reply::from_u8(reply_code),
            address,
        })
    }

    /// Write to a writer
    ///
    /// # Errors
    ///
    /// Propagates any I/O error returned by `w`.
    pub async fn write_to<W>(&self, w: &mut W) -> io::Result<()>
    where
        W: AsyncWrite + Unpin,
    {
        let mut buf = BytesMut::with_capacity(self.serialized_len());
        self.write_to_buf(&mut buf);
        w.write_all(&buf).await
    }

    /// Appends `VER | REP | RSV(0x00) | address` to `buf`.
    pub fn write_to_buf<B: BufMut>(&self, buf: &mut B) {
        let TcpResponseHeader {
            ref reply,
            ref address,
        } = *self;
        buf.put_slice(&[consts::SOCKS5_VERSION, reply.as_u8(), 0x00]);
        address.write_to_buf(buf);
    }

    /// Length in bytes
    #[inline]
    pub fn serialized_len(&self) -> usize {
        self.address.serialized_len() + 3
    }
}

/// SOCKS5 handshake request packet
///
/// ```plain
/// +----+----------+----------+
/// |VER | NMETHODS | METHODS  |
/// +----+----------+----------+
/// | 5  |    1     | 1 to 255 |
/// +----+----------+----------|
/// ```
#[derive(Clone, Debug)]
pub struct HandshakeRequest {
    pub methods: Vec<u8>,
}

impl HandshakeRequest {
    /// Creates a handshake request
    pub fn new(methods: Vec<u8>) -> HandshakeRequest {
        HandshakeRequest { methods }
    }

    /// Reads `VER | NMETHODS` and then NMETHODS method bytes from `r`.
    ///
    /// # Errors
    ///
    /// - `io::ErrorKind::InvalidData` if VER is not SOCKS5.
    /// - Any error from the underlying stream.
    pub async fn read_from<R>(r: &mut R) -> io::Result<HandshakeRequest>
    where
        R: AsyncRead + Unpin,
    {
        let mut buf = [0u8; 2];
        let _ = r.read_exact(&mut buf).await?;

        let ver = buf[0];
        let nmet = buf[1];

        if ver != consts::SOCKS5_VERSION {
            return Err(io::Error::new(
                io::ErrorKind::InvalidData,
                format!("unsupported socks version {:#x}", ver),
            ));
        }

        let mut methods = vec![0u8; nmet as usize];
        let _ = r.read_exact(&mut methods).await?;

        Ok(HandshakeRequest { methods })
    }

    /// Write to a writer
    ///
    /// # Errors
    ///
    /// Propagates any I/O error returned by `w`.
    pub async fn write_to<W>(&self, w: &mut W) -> io::Result<()>
    where
        W: AsyncWrite + Unpin,
    {
        let mut buf = BytesMut::with_capacity(self.serialized_len());
        self.write_to_buf(&mut buf);
        w.write_all(&buf).await
    }

    /// Appends `VER | NMETHODS | METHODS` to `buf`.
    ///
    /// NOTE(review): NMETHODS is a single byte. A list longer than 255 would
    /// be truncated by the `as u8` cast and produce a malformed frame, so
    /// callers must keep `methods.len() <= 255`.
    pub fn write_to_buf<B: BufMut>(&self, buf: &mut B) {
        let HandshakeRequest { ref methods } = *self;
        buf.put_slice(&[consts::SOCKS5_VERSION, methods.len() as u8]);
        buf.put_slice(methods);
    }

    /// Get length of bytes
    pub fn serialized_len(&self) -> usize {
        2 + self.methods.len()
    }
}

/// SOCKS5 handshake response packet
///
/// ```plain
/// +----+--------+
/// |VER | METHOD |
/// +----+--------+
/// | 1  |   1    |
/// +----+--------+
/// ```
#[derive(Clone, Debug, Copy)]
pub struct HandshakeResponse {
    pub chosen_method: u8,
}

impl HandshakeResponse {
    /// Creates a handshake response
    pub fn new(cm: u8) -> HandshakeResponse {
        HandshakeResponse { chosen_method: cm }
    }

    /// Reads `VER | METHOD` from `r`.
    ///
    /// # Errors
    ///
    /// - `io::ErrorKind::InvalidData` if VER is not SOCKS5.
    /// - Any error from the underlying stream.
    pub async fn read_from<R>(r: &mut R) -> io::Result<HandshakeResponse>
    where
        R: AsyncRead + Unpin,
    {
        let mut buf = [0u8; 2];
        let _ = r.read_exact(&mut buf).await?;

        let ver = buf[0];
        let met = buf[1];

        if ver != consts::SOCKS5_VERSION {
            Err(io::Error::new(
                io::ErrorKind::InvalidData,
                format!("unsupported socks version {:#x}", ver),
            ))
        } else {
            Ok(HandshakeResponse { chosen_method: met })
        }
    }

    /// Write to a writer
    ///
    /// # Errors
    ///
    /// Propagates any I/O error returned by `w`.
    pub async fn write_to<W>(self, w: &mut W) -> io::Result<()>
    where
        W: AsyncWrite + Unpin,
    {
        let mut buf = BytesMut::with_capacity(self.serialized_len());
        self.write_to_buf(&mut buf);
        w.write_all(&buf).await
    }

    /// Appends `VER | METHOD` to `buf`.
    pub fn write_to_buf<B: BufMut>(self, buf: &mut B) {
        buf.put_slice(&[consts::SOCKS5_VERSION, self.chosen_method]);
    }

    /// Length in bytes (always 2)
    pub fn serialized_len(self) -> usize {
        2
    }
}

/// UDP ASSOCIATE request header
///
/// ```plain
/// +----+------+------+----------+----------+----------+
/// |RSV | FRAG | ATYP | DST.ADDR | DST.PORT |   DATA   |
/// +----+------+------+----------+----------+----------+
/// | 2  |  1   |  1   | Variable |    2     | Variable |
/// +----+------+------+----------+----------+----------+
/// ```
#[derive(Clone, Debug)]
pub struct UdpAssociateHeader {
    /// Fragment
    ///
    /// ShadowSocks does not support fragment, so this frag must be 0x00
    pub frag: u8,
    /// Remote address
    pub address: Address,
}

impl UdpAssociateHeader {
    /// Creates a header
    pub fn new(frag: u8, address: Address) -> UdpAssociateHeader {
        UdpAssociateHeader { frag, address }
    }

    /// Reads `RSV(2) | FRAG` followed by an address from `r`.
    ///
    /// The two RSV bytes are discarded without validation, and FRAG is
    /// returned as-is.
    ///
    /// # Errors
    ///
    /// Any error from reading the address or the underlying stream.
    pub async fn read_from<R>(r: &mut R) -> Result<UdpAssociateHeader, Error>
    where
        R: AsyncRead + Unpin,
    {
        let mut buf = [0u8; 3];
        let _ = r.read_exact(&mut buf).await?;

        let frag = buf[2];
        let address = Address::read_from(r).await?;
        Ok(UdpAssociateHeader::new(frag, address))
    }

    /// Write to a writer
    ///
    /// # Errors
    ///
    /// Propagates any I/O error returned by `w`.
    pub async fn write_to<W>(&self, w: &mut W) -> io::Result<()>
    where
        W: AsyncWrite + Unpin,
    {
        let mut buf = BytesMut::with_capacity(self.serialized_len());
        self.write_to_buf(&mut buf);
        w.write_all(&buf).await
    }

    /// Appends `RSV(0x0000) | FRAG | address` to `buf`.
    pub fn write_to_buf<B: BufMut>(&self, buf: &mut B) {
        let UdpAssociateHeader {
            ref frag,
            ref address,
        } = *self;
        buf.put_slice(&[0x00, 0x00, *frag]);
        address.write_to_buf(buf);
    }

    /// Length in bytes
    #[inline]
    pub fn serialized_len(&self) -> usize {
        3 + self.address.serialized_len()
    }
}
--------------------------------------------------------------------------------