├── .cargo └── config.toml ├── .gitignore ├── Cargo.toml ├── LICENSE ├── README.md ├── easydns.toml ├── rust-toolchain.toml ├── src ├── cache │ ├── cache_record │ │ ├── ip_record.rs │ │ ├── mod.rs │ │ └── soa_record.rs │ ├── expired_strategy.rs │ ├── limit_map.rs │ ├── mod.rs │ └── timeout_strategy.rs ├── client.rs ├── config.rs ├── cursor │ ├── array_buf.rs │ └── mod.rs ├── filter.rs ├── handler │ ├── cache_handler.rs │ ├── domain_filter.rs │ ├── ip_maker.rs │ ├── legal_checker.rs │ ├── mod.rs │ ├── query_sender.rs │ └── server_group │ │ ├── combine_server_sender.rs │ │ ├── fast_server_sender.rs │ │ ├── mod.rs │ │ ├── prefer_server_sender.rs │ │ └── query_executor.rs ├── main.rs ├── protocol │ ├── answer │ │ ├── failure.rs │ │ ├── ipv4.rs │ │ ├── mod.rs │ │ ├── no_such_name.rs │ │ ├── resource │ │ │ ├── basic.rs │ │ │ ├── cname.rs │ │ │ ├── ipv4.rs │ │ │ ├── mod.rs │ │ │ └── soa.rs │ │ └── soa.rs │ ├── basic.rs │ ├── header.rs │ ├── mod.rs │ ├── query │ │ └── mod.rs │ └── question.rs └── system.rs └── tests ├── resources ├── covercast_filter.txt └── test_filter.txt └── test.rs /.cargo/config.toml: -------------------------------------------------------------------------------- 1 | [target.mipsel-unknown-linux-uclibc] 2 | runner = "cargo exec runner" 3 | linker = "/home/dmj/mipsel-linux-uclibc/bin/mipsel-linux-uclibc-cc" 4 | 5 | [profile.release] 6 | lto = true 7 | opt-level = 'z' 8 | codegen-units = 1 9 | 10 | 11 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .idea/* 2 | target/* 3 | Cargo.lock 4 | -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "easydns" 3 | version = "0.1.0" 4 | edition = "2018" 5 | 6 | # See more keys and their definitions at 
https://doc.rust-lang.org/cargo/reference/manifest.html 7 | 8 | [dependencies] 9 | dashmap = " 4.0.2" 10 | tokio = { version = "1.9.0", features = ["full"] } 11 | futures-util = "0.3.16" 12 | tokio-icmp = "0.4.0" 13 | async-trait = "0.1.51" 14 | regex = "1.5" 15 | simple_logger = "1.13" 16 | log = "0.4" 17 | toml = "0.5" 18 | 19 | [patch.crates-io] 20 | socket2 = { git = "https://github.com/dunmengjun/socket2.git" } 21 | tokio-icmp = { git = "https://github.com/dunmengjun/tokio-icmp.git", branch = "main" } 22 | 23 | [package.metadata.scripts] 24 | runner = """qemu-mipsel -L /home/dmj/mipsel-linux-uclibc/mipsel-linux-uclibc/sysroot {0}""" 25 | build = "cargo build --target mipsel-unknown-linux-uclibc -Z build-std" 26 | run = "cargo run --target mipsel-unknown-linux-uclibc -Z build-std" 27 | test = "cargo test --target mipsel-unknown-linux-uclibc -Z build-std" 28 | release = "cargo build --target mipsel-unknown-linux-uclibc --release -Z build-std" 29 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 dunmengjun 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Easydns 2 | 3 | 1. 此项目的目的是做一个足够简单高效的dns透传优选和屏蔽广告服务,类似smartdns, 但不会有 smartdns那么多的功能,只会提供核心关键的功能以保持简单高效。 4 | 2. 此项目编译到 [Padavan](https://github.com/hanwckf/rt-n56u) 算是rust依赖mipsel toolchain交叉编译到Padavan的模板 5 | 3. 百分之百用Rust语言开发 6 | 7 | ### 功能完成度 8 | 9 | 假定的场景是:dns请求中的问题只有一个, 多个的情况属于容错,事实上现在dns请求带多个问题的场景基本已经没有了,家用是碰不到的。 10 | 11 | - [x] A(ipv4)记录的透传并过整条链路(缓存和优选) 12 | - [ ] AAAA(ipv6)记录的透传并过整条链路(缓存和优选) 13 | - [x] 所有其他记录容错(直接给到上游服务器去返回,不过缓存和优选) 14 | - [x] 缓存(根据ttl时间, 最大条数限制) 15 | - [x] 多线程(tokio实现) 16 | - [x] 缓存持久化(存本地文件,下次启动时load) 17 | - [x] 域名过滤(过滤广告,返回soa) 18 | - [x] 返回soa 19 | - [x] 从文件中读 20 | - [x] 从网址中读 21 | - [x] dns优选 22 | - [x] 上游dns服务器优选 23 | - [x] 返回的IP地址优选 24 | - [x] ping协议 (需要root权限或者给程序设置cap_net_raw) 25 | - [ ] tcp协议 26 | - [ ] 80端口网页中的域名预加载到缓存 27 | - [x] 参数配置化(统一的配置文件) 28 | - [x] 标准日志 29 | - [ ] github action自动编译 30 | - [ ] 测试(单元测试,稳定性测试,性能测试) 31 | - [ ] 其他平台的编译(linux, windows, macos) 32 | 33 | ### 必须的依赖: 34 | 1. rust环境 35 | 2. mipsel 36 | toolchain [下载地址](https://github.com/hanwckf/padavan-toolchain/releases/download/v1.1/mipsel-linux-uclibc.tar.xz) 37 | 下载完解压即可 38 | 3. cargo exec [仓库地址](https://github.com/dunmengjun/cargo-exec) clone完成后 39 | ```shell 40 | cargo install --path . 41 | ``` 42 | 4. curl, 系统上必须安装 curl 43 | ### 可选依赖: 44 | 45 | 1.
在本机上测试和运行依赖于 qemu-mipsel 5.2.0, 单纯编译不需要。 46 | 47 | ## 运行项目 48 | 49 | ### 自动rust版本同步 50 | 51 | ```shell 52 | rustup show 53 | ``` 54 | 55 | ### 编译 56 | 57 | ```shell 58 | cargo exec build 59 | ``` 60 | 61 | ### 运行 62 | 63 | ```shell 64 | cargo exec run 65 | ``` 66 | 67 | ### 单元测试 68 | 69 | ```shell 70 | cargo exec test 71 | ``` 72 | 73 | ### 发布 74 | 75 | ```shell 76 | cargo exec release 77 | ``` 78 | 79 | ### 调试 80 | 81 | 调试依赖 82 | 83 | 1. qemu-mipsel 5.2.0 84 | 2. mipsel toolchain 85 | 3. mipsel gdb ([安装教程](https://blog.csdn.net/zqj6893/article/details/84662579)) 86 | 87 | 有了这三个工具调试就很简单了 88 | 89 | #### gdb server 启动 90 | 91 | ```shell 92 | qemu-mipsel -L /path to mipsel toolchain/mipsel-linux-uclibc/sysroot -g 1234 ./target/mipsel-unknown-linux-uclibc/debug/easydns 93 | ``` 94 | 95 | #### gdb client连接 96 | 97 | ```shell 98 | $ ./mipsel-linux-gdb 99 | GNU gdb (GDB) 9.2 100 | Copyright (C) 2020 Free Software Foundation, Inc. 101 | License GPLv3+: GNU GPL version 3 or later 102 | This is free software: you are free to change and redistribute it. 103 | There is NO WARRANTY, to the extent permitted by law. 104 | Type "show copying" and "show warranty" for details. 105 | This GDB was configured as "--host=x86_64-pc-linux-gnu --target=mipsel-linux". 106 | Type "show configuration" for configuration details. 107 | For bug reporting instructions, please see: 108 | . 109 | Find the GDB manual and other documentation resources online at: 110 | . 111 | 112 | For help, type "help". 113 | Type "apropos word" to search for commands related to "word". 
114 | (gdb) target remote localhost:1234 115 | Remote debugging using localhost:1234 116 | ``` 117 | 118 | 在clion这种IDE上可以以Remote GDB Server来配置,核心就是上面两个命令 119 | 120 | ### 常用命令 121 | 122 | 设置cap_net_raw 123 | 124 | ```shell 125 | sudo setcap cap_net_raw=eip {程序名称} 126 | ``` -------------------------------------------------------------------------------- /easydns.toml: -------------------------------------------------------------------------------- 1 | # 接受客户端请求的端口 2 | # 等于0是随机port 3 | port = 2053 4 | 5 | #上游dns服务器,目前没有分组的功能,主要是认为用处不大 6 | servers = [ 7 | "114.114.114.114:53", 8 | "8.8.8.8:53", 9 | "1.1.1.1:53" 10 | ] 11 | # 默认是0 定时优选,选取最快的server, (最快的server返回就返回,只会发一个请求,但由于是定时,所以一段时间内不会更新最快的server) 12 | # 1是每次新请求都优选,从最快的server获取结果,(最快的server返回就返回,但可能会实际发送n个请求) 13 | # 2是不优选,从所有server获取结果(会等待所有的server返回, 实际发送n个请求,等待耗时最长的那个返回就返回) 14 | server-choose-strategy = 0 15 | 16 | # server-choose-strategy=0 时此项生效, 代表定时优选的时间间隔 17 | # 单位是小时 18 | server-choose-duration-h = 12 19 | 20 | # 缓存设置为false,并且ip优选策略是1 会严重影响性能,因为会走两个串行的请求,一个是要从server获取返回的ip,二是要ping返回的ip 21 | # 这两个请求是不能并行的,所以推荐把缓存开着 22 | # 缓存是根据ttl时间设置的,ttl过期了会自动删除 23 | cache = true 24 | cache-num = 1000 25 | cache-file = "cache" 26 | 27 | # 缓存获取策略,默认是0, 就是严格遵循ttl值来,过期了就去同步的取上游dns server的返回值放入缓存 28 | # 1 是在ttl过期之后,请求进来还是先返回过期的记录,之后服务器再去异步的请求上游dns服务器的返回值放入缓存,保证下次用户取的是最新值 29 | cache-get-strategy = 0 30 | 31 | # 缓存策略是1的情况下, 此项生效, 会在dns记录的ttl时间过期之后返回过期的记录直到新的记录从上游服务器获取到并插入缓存 32 | # 或者超过了设置的时间就同步去上游服务器取新的记录 33 | cache-ttl-timeout-ms = 60000 34 | 35 | # 浏览器场景下,就算dns服务返回多个ip,浏览器也是默认取第一个,所以这里的默认策略是返回第一个ip 36 | # 其他大多数场景下也应该和浏览器是一致的 37 | # 默认是0 选择第一个ip, 1是利用ping协议选择最小延迟的一个ip 38 | ip-choose-strategy = 0 39 | 40 | # 在filter文件中的值会被拦截并返回soa记录,可以用于dns方式去广告 41 | # 值可以是文件路径或者是url路径, 会自动去重,里面的条目会从下往上覆盖 42 | # 格式是这种 1. address /00-gov.cn/# 加入到filter, 2. 
address /00-gov.cn/d, 从已经存在的filter集合中删除 43 | # 如果条目格式错误,会被忽略,不会报错,日志中只有debug模式会有这个日志 44 | filters = [ 45 | # "https://raw.githubusercontent.com/dunmengjun/SmartDNS-GFWList/master/smartdns_anti_ad.conf", 46 | "./tests/resources/covercast_filter.txt", 47 | ] 48 | 49 | # 六个值 trace debug info warn error off 从前往后日志越少。大小写都可 50 | # off是不输出日志 51 | # 日志分割,请直接搜索一下,linux下有现成的命令,很简单就可以配置,这里就不原生提供这个功能了 52 | log-level = "debug" -------------------------------------------------------------------------------- /rust-toolchain.toml: -------------------------------------------------------------------------------- 1 | [toolchain] 2 | channel = "nightly-2020-06-03" 3 | components = ["rust-src"] 4 | targets = ["mipsel-unknown-linux-uclibc"] 5 | profile = "default" -------------------------------------------------------------------------------- /src/cache/cache_record/ip_record.rs: -------------------------------------------------------------------------------- 1 | use crate::system::{get_now}; 2 | use crate::cache::cache_record::{CacheItem, IP_RECORD}; 3 | use crate::cursor::Cursor; 4 | use crate::protocol::{DnsAnswer, Ipv4Answer}; 5 | use std::net::Ipv4Addr; 6 | 7 | #[derive(Clone, PartialOrd, PartialEq, Debug)] 8 | pub struct IpCacheRecord { 9 | pub domain: String, 10 | pub address: Ipv4Addr, 11 | pub create_time: u128, 12 | pub ttl_ms: u128, 13 | } 14 | 15 | impl CacheItem for IpCacheRecord { 16 | fn get_create_time(&self) -> u128 { 17 | self.create_time 18 | } 19 | 20 | fn get_ttl_ms(&self) -> u128 { 21 | self.ttl_ms 22 | } 23 | 24 | fn get_key(&self) -> &String { 25 | &self.domain 26 | } 27 | 28 | fn to_bytes(&self) -> Vec { 29 | self.into() 30 | } 31 | 32 | fn to_answer(&self) -> DnsAnswer { 33 | Ipv4Answer::from(self).into() 34 | } 35 | } 36 | 37 | impl IpCacheRecord { 38 | pub fn get_address(&self) -> &Ipv4Addr { 39 | &self.address 40 | } 41 | } 42 | 43 | impl From<&IpCacheRecord> for Vec { 44 | fn from(record: &IpCacheRecord) -> Self { 45 | let mut vec = Vec::::new(); 46 | 
vec.push(IP_RECORD);//插入魔数 47 | vec.push(record.domain.len() as u8); 48 | vec.extend(record.domain.as_bytes()); 49 | vec.push(4); 50 | vec.extend(&(record.get_remain_time(get_now()) as u32).to_be_bytes()); 51 | vec.push(16); 52 | vec.extend(&record.create_time.to_be_bytes()); 53 | vec.push(4); 54 | vec.extend(&record.address.octets()); 55 | vec 56 | } 57 | } 58 | 59 | impl From<&[u8]> for IpCacheRecord { 60 | fn from(bytes: &[u8]) -> Self { 61 | let cursor = Cursor::form(Vec::from(bytes).into()); 62 | cursor.take(); //删掉魔数 63 | let len = cursor.take() as usize; 64 | let domain = String::from_utf8(Vec::from(cursor.take_slice(len))).unwrap(); 65 | cursor.take(); 66 | let ttl_ms = u32::from_be_bytes(cursor.take_bytes()) as u128; 67 | cursor.take(); 68 | let create_time = u128::from_be_bytes(cursor.take_bytes()); 69 | cursor.take(); 70 | let address = Ipv4Addr::from(cursor.take_bytes()); 71 | IpCacheRecord { 72 | domain, 73 | address, 74 | create_time, 75 | ttl_ms, 76 | } 77 | } 78 | } 79 | 80 | impl From<&Ipv4Answer> for IpCacheRecord { 81 | fn from(answer: &Ipv4Answer) -> Self { 82 | IpCacheRecord { 83 | domain: answer.get_name().clone(), 84 | address: answer.get_address().clone(), 85 | create_time: get_now(), 86 | ttl_ms: answer.get_ttl() as u128 * 1000, 87 | } 88 | } 89 | } 90 | 91 | #[cfg(test)] 92 | pub mod tests { 93 | use crate::cache::{IpCacheRecord, CacheItem, CacheRecord}; 94 | use crate::system::{TIME, get_now}; 95 | use crate::cache::limit_map::GetOrdKey; 96 | use crate::protocol::tests::get_ip_answer; 97 | use std::net::Ipv4Addr; 98 | 99 | #[test] 100 | fn should_return_valid_record_when_create_from_bytes_given_valid_bytes() { 101 | let vec = get_test_bytes(); 102 | let valid_bytes = vec.as_slice(); 103 | 104 | let result = IpCacheRecord::from(valid_bytes); 105 | 106 | let expected = get_ip_record(); 107 | assert_eq!(expected, result) 108 | } 109 | 110 | #[test] 111 | fn should_return_bytes_when_to_from_bytes_given_valid_bytes() { 112 | let record = 
get_ip_record(); 113 | 114 | let result = record.to_bytes(); 115 | 116 | let expected = get_test_bytes(); 117 | assert_eq!(expected, result) 118 | } 119 | 120 | #[test] 121 | fn should_return_valid_record_when_from_answer_given_valid_answer() { 122 | let answer = get_ip_answer(); 123 | TIME.with(|t| { 124 | t.borrow_mut().set_timestamp(0); 125 | }); 126 | 127 | let result = answer.to_cache().unwrap(); 128 | 129 | let expected: CacheRecord = get_ip_record().into(); 130 | assert!(expected.eq(&result)) 131 | } 132 | 133 | fn get_test_bytes() -> Vec { 134 | let bytes: [u8; 44] = [42, 15, 3, 119, 119, 119, 5, 98, 97, 105, 100, 117, 3, 99, 111, 109, 0, 4, 0, 0, 3, 232, 16, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 4, 1, 1, 1, 1]; 135 | let mut vec = Vec::with_capacity(44); 136 | for c in bytes.iter() { 137 | vec.push(c.clone()) 138 | } 139 | vec 140 | } 141 | 142 | pub fn get_ip_record() -> IpCacheRecord { 143 | IpCacheRecord { 144 | domain: "www.baidu.com".to_string(), 145 | address: Ipv4Addr::from([1, 1, 1, 1]), 146 | create_time: 0, 147 | ttl_ms: 1000, 148 | } 149 | } 150 | } -------------------------------------------------------------------------------- /src/cache/cache_record/mod.rs: -------------------------------------------------------------------------------- 1 | mod ip_record; 2 | mod soa_record; 3 | 4 | use crate::system::get_now; 5 | use crate::cache::limit_map::GetOrdKey; 6 | 7 | pub use ip_record::IpCacheRecord; 8 | pub use soa_record::SoaCacheRecord; 9 | use std::fmt::{Debug, Formatter}; 10 | use crate::protocol::DnsAnswer; 11 | 12 | pub type CacheRecord = Box; 13 | 14 | pub const IP_RECORD: u8 = '*' as u8; 15 | pub const SOA_RECORD: u8 = '#' as u8; 16 | 17 | pub trait Expired { 18 | fn is_expired(&self, timestamp: u128) -> bool; 19 | } 20 | 21 | impl Expired for CacheRecord { 22 | fn is_expired(&self, timestamp: u128) -> bool { 23 | let duration = timestamp - self.get_create_time(); 24 | self.get_ttl_ms() < duration 25 | } 26 | } 27 | 28 | pub 
trait CacheItem: Sync + Send + BoxedClone { 29 | fn get_remain_time(&self, timestamp: u128) -> u128 { 30 | let duration = timestamp - self.get_create_time(); 31 | if self.get_ttl_ms() > duration { 32 | self.get_ttl_ms() - duration 33 | } else { 34 | 0 35 | } 36 | } 37 | fn get_create_time(&self) -> u128; 38 | fn get_ttl_ms(&self) -> u128; 39 | fn get_key(&self) -> &String; 40 | fn to_bytes(&self) -> Vec; 41 | fn to_answer(&self) -> DnsAnswer; 42 | } 43 | 44 | pub trait BoxedClone { 45 | fn boxed_clone(&self) -> CacheRecord; 46 | } 47 | 48 | impl BoxedClone for T where T: 'static + Clone + CacheItem { 49 | fn boxed_clone(&self) -> CacheRecord { 50 | Box::new(self.clone()) 51 | } 52 | } 53 | 54 | impl Clone for CacheRecord { 55 | fn clone(&self) -> Self { 56 | self.boxed_clone() 57 | } 58 | } 59 | 60 | impl GetOrdKey for CacheRecord { 61 | type Output = u128; 62 | fn get_order_key(&self) -> Self::Output { 63 | self.get_remain_time(get_now()) 64 | } 65 | } 66 | 67 | impl Debug for CacheRecord { 68 | fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { 69 | f.debug_tuple("") 70 | .field(self.get_key()) 71 | .field(&self.get_create_time()) 72 | .field(&self.get_remain_time(get_now())) 73 | .field(&self.get_ttl_ms()) 74 | .finish() 75 | } 76 | } 77 | 78 | impl From for CacheRecord { 79 | fn from(record: IpCacheRecord) -> Self { 80 | Box::new(record) 81 | } 82 | } 83 | 84 | impl From for CacheRecord { 85 | fn from(record: SoaCacheRecord) -> Self { 86 | Box::new(record) 87 | } 88 | } 89 | 90 | #[cfg(test)] 91 | impl PartialEq for CacheRecord { 92 | fn eq(&self, other: &Self) -> bool { 93 | self.to_bytes().eq(&other.to_bytes()) 94 | } 95 | } 96 | 97 | #[cfg(test)] 98 | pub mod tests { 99 | pub use crate::cache::cache_record::ip_record::tests; 100 | use crate::cache::{CacheItem, CacheRecord}; 101 | use crate::cache::cache_record::BoxedClone; 102 | use crate::system::TIME; 103 | use crate::cache::limit_map::GetOrdKey; 104 | use crate::protocol::{DnsAnswer, 
FailureAnswer}; 105 | 106 | #[test] 107 | fn should_return_true_when_check_expired_given_expired() { 108 | let record = get_test_record(); 109 | 110 | let result = record.is_expired(1001); 111 | 112 | assert!(result) 113 | } 114 | 115 | #[test] 116 | fn should_return_false_when_check_expired_given_not_expired() { 117 | let record = get_test_record(); 118 | 119 | let result = record.is_expired(999); 120 | 121 | assert!(!result) 122 | } 123 | 124 | #[test] 125 | fn should_return_remain_time_when_get_remain_time_given_not_expired() { 126 | let record = get_test_record(); 127 | 128 | let result = record.get_remain_time(999); 129 | 130 | assert_eq!(1, result) 131 | } 132 | 133 | #[test] 134 | fn should_return_0_when_get_remain_time_given_expired() { 135 | let record = get_test_record(); 136 | 137 | let result = record.get_remain_time(1001); 138 | 139 | assert_eq!(0, result) 140 | } 141 | 142 | #[test] 143 | fn should_return_remain_time_when_get_order_key_given_test_record() { 144 | let record: CacheRecord = Box::new(get_test_record()); 145 | TIME.with(|t| { 146 | t.borrow_mut().set_timestamp(999); 147 | }); 148 | 149 | let result: u128 = record.get_order_key(); 150 | 151 | assert_eq!(1, result) 152 | } 153 | 154 | fn get_test_record() -> TestRecord { 155 | TestRecord { 156 | key: vec![], 157 | ttl: 1000, 158 | create_time: 0, 159 | } 160 | } 161 | 162 | #[derive(Clone)] 163 | struct TestRecord { 164 | key: Vec, 165 | ttl: u128, 166 | create_time: u128, 167 | } 168 | 169 | impl CacheItem for TestRecord { 170 | fn get_create_time(&self) -> u128 { 171 | self.create_time 172 | } 173 | 174 | fn get_ttl_ms(&self) -> u128 { 175 | self.ttl 176 | } 177 | 178 | fn get_key(&self) -> &Vec { 179 | &self.key 180 | } 181 | 182 | fn to_bytes(&self) -> Vec { 183 | vec![] 184 | } 185 | 186 | fn to_answer(&self) -> DnsAnswer { 187 | DnsAnswer::from(FailureAnswer::new(0, "".to_string())) 188 | } 189 | } 190 | } 191 | 
-------------------------------------------------------------------------------- /src/cache/cache_record/soa_record.rs: -------------------------------------------------------------------------------- 1 | use crate::cache::cache_record::{CacheItem, SOA_RECORD}; 2 | use crate::system::get_now; 3 | use crate::cursor::Cursor; 4 | use crate::protocol::{DnsAnswer, SoaAnswer}; 5 | 6 | #[derive(Clone, PartialOrd, PartialEq, Debug)] 7 | pub struct SoaCacheRecord { 8 | pub domain: String, 9 | pub create_time: u128, 10 | pub ttl_ms: u128, 11 | } 12 | 13 | impl CacheItem for SoaCacheRecord { 14 | fn get_create_time(&self) -> u128 { 15 | self.create_time 16 | } 17 | 18 | fn get_ttl_ms(&self) -> u128 { 19 | self.ttl_ms 20 | } 21 | 22 | fn get_key(&self) -> &String { 23 | &self.domain 24 | } 25 | 26 | fn to_bytes(&self) -> Vec { 27 | self.into() 28 | } 29 | 30 | fn to_answer(&self) -> DnsAnswer { 31 | SoaAnswer::from(self).into() 32 | } 33 | } 34 | 35 | impl From<&SoaCacheRecord> for Vec { 36 | fn from(record: &SoaCacheRecord) -> Self { 37 | let mut vec = Vec::::new(); 38 | vec.push(SOA_RECORD);//插入魔数 39 | vec.push(record.domain.len() as u8); 40 | vec.extend(record.domain.as_bytes()); 41 | vec.push(4); 42 | vec.extend(&(record.get_remain_time(get_now()) as u32).to_be_bytes()); 43 | vec.push(16); 44 | vec.extend(&record.create_time.to_be_bytes()); 45 | vec 46 | } 47 | } 48 | 49 | impl From<&[u8]> for SoaCacheRecord { 50 | fn from(bytes: &[u8]) -> Self { 51 | let cursor = Cursor::form(Vec::from(bytes).into()); 52 | cursor.take();//删掉魔数 53 | let len = cursor.take() as usize; 54 | let domain = String::from_utf8(Vec::from(cursor.take_slice(len))).unwrap(); 55 | cursor.take(); 56 | let ttl_ms = u32::from_be_bytes(cursor.take_bytes()) as u128; 57 | cursor.take(); 58 | let create_time = u128::from_be_bytes(cursor.take_bytes()); 59 | SoaCacheRecord { 60 | domain, 61 | create_time, 62 | ttl_ms, 63 | } 64 | } 65 | } 66 | 67 | impl From<&SoaAnswer> for SoaCacheRecord { 68 | fn from(answer: 
&SoaAnswer) -> Self { 69 | SoaCacheRecord { 70 | domain: answer.get_name().clone(), 71 | create_time: get_now(), 72 | ttl_ms: answer.get_ttl() as u128 * 1000, 73 | } 74 | } 75 | } 76 | 77 | #[cfg(test)] 78 | mod tests { 79 | use crate::cache::{SoaCacheRecord, CacheRecord, CacheItem}; 80 | use crate::system::TIME; 81 | use crate::protocol::tests::{get_ip_answer, get_soa_answer}; 82 | 83 | #[test] 84 | fn should_return_valid_record_when_create_from_bytes_given_valid_bytes() { 85 | let vec = get_test_bytes(); 86 | let valid_bytes = vec.as_slice(); 87 | 88 | let result = SoaCacheRecord::from(valid_bytes); 89 | 90 | let expected = get_soa_record(); 91 | assert_eq!(expected, result) 92 | } 93 | 94 | #[test] 95 | fn should_return_bytes_when_to_from_bytes_given_valid_bytes() { 96 | let record = get_soa_record(); 97 | 98 | let result = record.to_bytes(); 99 | 100 | let expected = get_test_bytes(); 101 | assert_eq!(expected, result) 102 | } 103 | 104 | #[test] 105 | fn should_return_valid_record_when_from_answer_given_valid_answer() { 106 | let answer = get_soa_answer(); 107 | TIME.with(|t| { 108 | t.borrow_mut().set_timestamp(0); 109 | }); 110 | 111 | let result = answer.to_cache().unwrap(); 112 | 113 | let expected: CacheRecord = get_soa_record().into(); 114 | assert!(expected.eq(&result)) 115 | } 116 | 117 | fn get_test_bytes() -> Vec { 118 | let bytes: [u8; 44] = [35, 15, 3, 119, 119, 119, 5, 98, 97, 105, 100, 117, 3, 99, 111, 109, 0, 4, 0, 0, 3, 232, 16, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 4, 1, 1, 1, 1]; 119 | let mut vec = Vec::with_capacity(44); 120 | for c in bytes.iter() { 121 | vec.push(c.clone()) 122 | } 123 | vec 124 | } 125 | 126 | pub fn get_soa_record() -> SoaCacheRecord { 127 | SoaCacheRecord { 128 | domain: "www.baidu.com".to_string(), 129 | create_time: 0, 130 | ttl_ms: 1000, 131 | } 132 | } 133 | 134 | pub fn build_soa_record(f: fn(&mut SoaCacheRecord)) -> SoaCacheRecord { 135 | let mut record = get_soa_record(); 136 | f(&mut record); 137 | 
record 138 | } 139 | } -------------------------------------------------------------------------------- /src/cache/expired_strategy.rs: -------------------------------------------------------------------------------- 1 | use crate::cache::{CacheStrategy, CacheMap, AnswerFuture}; 2 | use std::sync::Arc; 3 | use crate::system::get_now; 4 | use crate::system::Result; 5 | use crate::cache::cache_record::{CacheRecord, Expired}; 6 | use async_trait::async_trait; 7 | use crate::protocol::DnsAnswer; 8 | 9 | pub struct ExpiredCacheStrategy { 10 | map: Arc, 11 | } 12 | 13 | #[async_trait] 14 | impl CacheStrategy for ExpiredCacheStrategy { 15 | async fn handle(&self, record: CacheRecord, future: AnswerFuture) -> Result { 16 | if record.is_expired(get_now()) { 17 | let answer = future.await?; 18 | if let Some(r) = answer.to_cache() { 19 | self.map.insert(record.get_key().clone(), r); 20 | } 21 | Ok(answer) 22 | } else { 23 | Ok(record.to_answer()) 24 | } 25 | } 26 | } 27 | 28 | impl ExpiredCacheStrategy { 29 | pub fn from(map: Arc) -> Self { 30 | ExpiredCacheStrategy { 31 | map 32 | } 33 | } 34 | } 35 | 36 | // #[cfg(test)] 37 | // pub mod tests { 38 | // use crate::cache::expired_strategy::ExpiredCacheStrategy; 39 | // use crate::cache::limit_map::LimitedMap; 40 | // use crate::cache::{DNSCacheRecord, CacheStrategy}; 41 | // use std::sync::Arc; 42 | // use crate::cache::record::tests::{get_valid_record, build_valid_record}; 43 | // use crate::protocol::tests::get_ip_answer; 44 | // use crate::protocol::DNSAnswer; 45 | // use crate::system::{Result, TIME}; 46 | // use std::sync::atomic::{AtomicBool, Ordering}; 47 | // 48 | // #[test] 49 | // fn should_return_answer_when_call_handle_given_no_expired_record() { 50 | // let strategy = ExpiredCacheStrategy { 51 | // map: Arc::new(LimitedMap::, DNSCacheRecord>::from(0)) 52 | // }; 53 | // let (is_called, func) = get_test_func(); 54 | // let record = get_valid_record(); 55 | // 56 | // let result = 
strategy.handle(record.domain.clone(), record, func); 57 | // 58 | // assert!(!is_called.load(Ordering::Relaxed)); 59 | // assert!(result.is_ok()); 60 | // assert_eq!(get_ip_answer(), result.unwrap()); 61 | // assert!(strategy.map.is_empty()); 62 | // } 63 | // 64 | // #[test] 65 | // fn should_return_answer_and_call_get_data_func_and_insert_map_when_call_handle_given_expired_record() { 66 | // let strategy = ExpiredCacheStrategy { 67 | // map: Arc::new(LimitedMap::, DNSCacheRecord>::from(10)) 68 | // }; 69 | // let (is_called, func) = get_test_func(); 70 | // let record = get_expired_record(); 71 | // let key = record.domain.clone(); 72 | // 73 | // let result = strategy.handle(key.clone(), record, func); 74 | // 75 | // assert!(is_called.load(Ordering::Relaxed)); 76 | // assert!(result.is_ok()); 77 | // assert_eq!(get_ip_answer(), result.unwrap()); 78 | // let expected = build_valid_record(|r| { r.start_time = 1001; }); 79 | // assert_eq!(Some(expected), strategy.map.get(&key)) 80 | // } 81 | // 82 | // pub fn get_test_func() -> (Arc, Box Result + Send + 'static>) { 83 | // let is_called = Arc::new(AtomicBool::new(false)); 84 | // let rc = is_called.clone(); 85 | // let func = Box::new(move || -> Result{ 86 | // rc.fetch_or(true, Ordering::Relaxed); 87 | // Ok(get_ip_answer()) 88 | // }); 89 | // (is_called, func) 90 | // } 91 | // 92 | // fn get_expired_record() -> DNSCacheRecord { 93 | // let record = get_valid_record(); 94 | // TIME.with(|r| { 95 | // r.borrow_mut().set_timestamp(1001); 96 | // }); 97 | // record 98 | // } 99 | // } -------------------------------------------------------------------------------- /src/cache/limit_map.rs: -------------------------------------------------------------------------------- 1 | use dashmap::DashMap; 2 | use std::hash::Hash; 3 | use std::collections::hash_map::RandomState; 4 | use std::sync::Mutex; 5 | 6 | pub trait GetOrdKey { 7 | type Output: Ord + Clone; 8 | fn get_order_key(&self) -> Self::Output; 9 | } 10 | 11 | 
pub struct LimitedMap { 12 | records: DashMap, 13 | limit: usize, 14 | lock_key: Mutex, 15 | } 16 | 17 | impl LimitedMap 18 | where K: Eq + Hash + Clone, V: Clone + GetOrdKey { 19 | pub fn from(limit: usize) -> Self { 20 | LimitedMap { 21 | records: DashMap::with_capacity(limit), 22 | limit, 23 | lock_key: Mutex::new(0), 24 | } 25 | } 26 | pub fn get(&self, key: &K) -> Option { 27 | self.records.get(key).map(|e| e.value().clone()) 28 | } 29 | 30 | pub fn insert(&self, key: K, value: V) { 31 | //如果超过了限制的大小,则删除掉十分之一最小的记录 32 | let guard = self.lock_key.lock().unwrap(); 33 | if self.records.len() >= self.limit { 34 | let vec = &mut Vec::new(); 35 | self.records.iter().for_each(|e| { 36 | vec.push((e.key().clone(), e.value().get_order_key())) 37 | }); 38 | vec.sort_unstable_by_key(|(_, sort_key)| sort_key.clone()); 39 | let keys: Vec<&K> = vec[0..self.limit / 10].iter() 40 | .map(|(k, _)| k).collect(); 41 | self.records.retain(|r, _| { 42 | !keys.contains(&r) 43 | }) 44 | } 45 | drop(guard); 46 | self.records.insert(key, value); 47 | } 48 | 49 | pub fn iter(&self) -> dashmap::iter::Iter> { 50 | self.records.iter() 51 | } 52 | 53 | pub fn is_empty(&self) -> bool { 54 | self.records.is_empty() 55 | } 56 | } 57 | 58 | #[cfg(test)] 59 | mod tests { 60 | use crate::cache::limit_map::{LimitedMap, GetOrdKey}; 61 | 62 | impl GetOrdKey for i32 { 63 | type Output = i32; 64 | 65 | fn get_order_key(&self) -> Self::Output { 66 | self.clone() 67 | } 68 | } 69 | 70 | #[test] 71 | fn should_insert_into_map_when_call_insert_given_empty_map() { 72 | let map = LimitedMap::from(1); 73 | 74 | map.insert(1, 1); 75 | 76 | let result = map.records.get(&1) 77 | .map(|r| r.value().clone()); 78 | assert_eq!(Some(1), result) 79 | } 80 | 81 | #[test] 82 | fn should_insert_into_map_and_remove_10_persist_when_call_insert_given_full_map() { 83 | let map = LimitedMap::from(100); 84 | (0..100).for_each(|r| { 85 | map.records.insert(r, r); 86 | }); 87 | 88 | map.insert(1000, 1000); 89 | 90 | let result = 
map.records.get(&1000)
            .map(|r| r.value().clone());
        let over_result = map.records.get(&9).map(|r| r.value().clone());
        assert_eq!(Some(1000), result);
        assert_eq!(91, map.records.len());
        assert_eq!(None, over_result)
    }

    #[test]
    fn should_return_true_when_call_is_empty_given_empty_map() {
        // NOTE(review): the turbofish arguments were lost when this file was dumped;
        // reconstructed as <i32, i32> to match the other tests — confirm.
        let map = LimitedMap::<i32, i32>::from(1);

        let result = map.is_empty();

        assert!(result)
    }

    #[test]
    fn should_return_false_when_call_is_empty_given_has_value_in_map() {
        let map = LimitedMap::from(1);
        map.records.insert(1, 1);

        let result = map.is_empty();

        assert!(!result)
    }

    #[test]
    fn should_return_value_when_call_get_given_has_value() {
        let map = LimitedMap::from(1);
        map.records.insert(1, 1);

        let result = map.get(&1).unwrap().value().clone();

        assert_eq!(1, result)
    }

    #[test]
    fn should_return_none_when_call_get_given_no_value() {
        let map = LimitedMap::from(1);
        map.records.insert(2, 1);

        let result = map.get(&1);

        assert!(result.is_none());
    }
}

// ---- /src/cache/mod.rs ----
mod limit_map;
mod expired_strategy;
mod timeout_strategy;
mod cache_record;

use crate::config::Config;
use crate::system::{Result, get_now, block_on};
use std::sync::Arc;
use limit_map::{LimitedMap};
use tokio::fs::File;
use tokio::io::{AsyncReadExt, AsyncWriteExt};

pub use cache_record::CacheRecord;
pub use cache_record::IpCacheRecord;
pub use cache_record::SoaCacheRecord;
pub use cache_record::CacheItem;
use crate::cache::expired_strategy::ExpiredCacheStrategy;
use crate::cache::timeout_strategy::TimeoutCacheStrategy;
use crate::cache::cache_record::{IP_RECORD, Expired};
use crate::cache::cache_record::SOA_RECORD;
use crate::cursor::Cursor;
use async_trait::async_trait;
use futures_util::future::BoxFuture;
use crate::protocol::DnsAnswer;

/// Bounded map of cached records, keyed by the name passed to `CachePool::get`.
pub type CacheMap = LimitedMap<String, CacheRecord>;
type ExpiredStrategy = Box<dyn CacheStrategy>;
/// Lazily evaluated upstream lookup; it is only awaited when the cache cannot
/// answer on its own.
type AnswerFuture = BoxFuture<'static, Result<DnsAnswer>>;

/// Largest record the cache file can describe: every record is prefixed by a
/// single length byte, and a length byte of 0 marks the end of the file.
const MAX_FILE_RECORD_LEN: usize = u8::MAX as usize;

/// Decides how a cache hit is answered: serve the cached record, refresh it in
/// the background, or await `future` for a fresh upstream answer.
#[async_trait]
pub trait CacheStrategy: Send + Sync {
    async fn handle(&self, record: CacheRecord, future: AnswerFuture) -> Result<DnsAnswer>;
}

/// DNS answer cache backed by a bounded map that is preloaded from, and
/// persisted to, `file_name`.
pub struct CachePool {
    strategy: ExpiredStrategy,
    file_name: String,
    map: Arc<CacheMap>,
}

impl Drop for CachePool {
    /// Persists the cache when the pool goes away; failures are logged, not propagated.
    fn drop(&mut self) {
        block_on(async move {
            match self.write_to_file().await {
                Ok(_) => {}
                Err(e) => {
                    error!("把缓存写入文件出错: {:?}", e)
                }
            }
        });
    }
}

impl CachePool {
    /// Builds the pool from `config`, preloading unexpired records from
    /// `config.cache_file` when that file can be opened.
    ///
    /// `cache_get_strategy == 0` selects `ExpiredCacheStrategy`; any other value
    /// selects `TimeoutCacheStrategy` using `cache_ttl_timeout_ms`.
    pub async fn from(config: &Config) -> Result<Self> {
        let limit_map: Arc<CacheMap> = Arc::new(create_map_by_config(config).await?);
        let strategy: ExpiredStrategy = if config.cache_get_strategy == 0 {
            Box::new(ExpiredCacheStrategy::from(limit_map.clone()))
        } else {
            Box::new(TimeoutCacheStrategy::from(limit_map.clone(),
                                                config.cache_ttl_timeout_ms as u128))
        };
        Ok(CachePool {
            strategy,
            file_name: config.cache_file.clone(),
            map: limit_map,
        })
    }

    /// Answers `key` from the cache.
    ///
    /// On a hit the configured strategy decides what to return; on a miss
    /// `future` is awaited and its answer is cached when `to_cache` allows it.
    ///
    /// # Errors
    /// Propagates any error produced by the strategy or by `future`.
    pub async fn get(&self, key: String, future: AnswerFuture) -> Result<DnsAnswer> {
        match self.map.get(&key) {
            // Present in the cache: let the strategy decide.
            Some(record) => self.strategy.handle(record, future).await,
            // Not in the cache: ask upstream and remember the answer.
            None => {
                let answer = future.await?;
                if let Some(record) = answer.to_cache() {
                    self.map.insert(key, record);
                }
                Ok(answer)
            }
        }
    }

    /// Serializes the cache as `[len: u8][record bytes]...` followed by a 0 byte.
    fn to_file_bytes(&self) -> Vec<u8> {
        let mut vec = Vec::new();
        for e in self.map.iter() {
            let bytes = e.value().to_bytes();
            // A length of 0 would end the file early, and a length above 255 would be
            // truncated by the u8 prefix and desynchronize every following record.
            if bytes.is_empty() || bytes.len() > MAX_FILE_RECORD_LEN {
                error!(
                    "cache record of {} bytes cannot be stored in the cache file, skipped",
                    bytes.len()
                );
                continue;
            }
            vec.push(bytes.len() as u8);
            vec.extend(bytes);
        }
        vec.push(0);
        vec
    }

    /// Writes the whole cache to `file_name`, replacing any previous content.
    /// Does nothing when the cache is empty.
    pub async fn write_to_file(&self) -> Result<()> {
        if self.map.is_empty() {
            info!("没有缓存需要写入文件");
            return Ok(());
        }
        let mut file = File::create(&self.file_name).await?;
        file.write_all(self.to_file_bytes().as_slice()).await?;
        info!("缓存全部写入了文件! 文件名称是{}", self.file_name);
        Ok(())
    }
}

/// Creates the cache map, preloading it from `config.cache_file` when that file
/// exists and is non-empty. A file that cannot be opened yields an empty map.
async fn create_map_by_config(config: &Config) -> Result<CacheMap> {
    Ok(match File::open(&config.cache_file).await {
        Ok(mut file) => {
            let mut file_vec = Vec::new();
            file.read_to_end(&mut file_vec).await?;
            if file_vec.is_empty() {
                LimitedMap::from(config.cache_num)
            } else {
                create_map_by_vec_u8(config, file_vec)
            }
        }
        Err(_e) => {
            LimitedMap::from(config.cache_num)
        }
    })
}

/// Decodes the format written by `CachePool::to_file_bytes`, keeping only records
/// that are not yet expired.
///
/// A corrupt or truncated file never panics: records with an unknown type flag are
/// skipped, and decoding stops at the first length that runs past the end of data.
fn create_map_by_vec_u8(config: &Config, file_vec: Vec<u8>) -> CacheMap {
    let map = LimitedMap::from(config.cache_num);
    let total = file_vec.len();
    let cursor: Cursor<u8> = Cursor::form(file_vec.into());
    // A missing 0 terminator (truncated file) is treated like the terminator.
    let next_len = |cursor: &Cursor<u8>| -> usize {
        if cursor.get_current_index() < total {
            cursor.take() as usize
        } else {
            0
        }
    };
    let mut len = next_len(&cursor);
    while len > 0 {
        if cursor.get_current_index() + len > total {
            error!(
                "cache file is truncated, ignoring the last {} bytes",
                total - cursor.get_current_index()
            );
            break;
        }
        // The type flag is the first byte of the record itself, so only peek it.
        let record = match cursor.peek() {
            IP_RECORD => Some(CacheRecord::from(IpCacheRecord::from(cursor.take_slice(len)))),
            SOA_RECORD => Some(CacheRecord::from(SoaCacheRecord::from(cursor.take_slice(len)))),
            flag => {
                error!("unsupported cache record type {} in cache file, skipped", flag);
                cursor.move_to(len);
                None
            }
        };
        if let Some(record) = record {
            if !record.is_expired(get_now()) {
                map.insert(record.get_key().clone(), record);
            }
        }
        len = next_len(&cursor);
    }
    map
}
152 | #[cfg(test)] 153 | mod tests {} -------------------------------------------------------------------------------- /src/cache/timeout_strategy.rs: -------------------------------------------------------------------------------- 1 | use std::sync::Arc; 2 | use crate::cache::{CacheStrategy, CacheMap, AnswerFuture}; 3 | use crate::system::{get_sub_now, get_now}; 4 | use std::time::Duration; 5 | use crate::system::Result; 6 | use crate::cache::cache_record::{CacheRecord, Expired}; 7 | use async_trait::async_trait; 8 | use crate::protocol::DnsAnswer; 9 | 10 | pub struct TimeoutCacheStrategy { 11 | map: Arc, 12 | timeout: u128, 13 | } 14 | 15 | #[async_trait] 16 | impl CacheStrategy for TimeoutCacheStrategy { 17 | async fn handle(&self, record: CacheRecord, future: AnswerFuture) -> Result { 18 | let now = get_sub_now(Duration::from_millis(self.timeout as u64)); 19 | if record.is_expired(now) { 20 | let answer = future.await?; 21 | if let Some(r) = answer.to_cache() { 22 | self.map.insert(record.get_key().clone(), r); 23 | } 24 | Ok(answer) 25 | } else { 26 | if record.is_expired(get_now()) { 27 | let cloned_map = self.map.clone(); 28 | let key = record.get_key().clone(); 29 | let _joiner = tokio::spawn(async move { 30 | match future.await { 31 | Ok(answer) => { 32 | if let Some(r) = answer.to_cache() { 33 | cloned_map.insert(key, r); 34 | } 35 | } 36 | Err(e) => { 37 | error!("{}", e); 38 | } 39 | } 40 | }); 41 | if cfg!(test) { 42 | crate::system::block_on(async move { 43 | _joiner.await.unwrap(); 44 | }) 45 | } 46 | Ok(record.to_answer()) 47 | } else { 48 | Ok(record.to_answer()) 49 | } 50 | } 51 | } 52 | } 53 | 54 | impl TimeoutCacheStrategy { 55 | pub fn from(map: Arc, timeout: u128) -> Self { 56 | TimeoutCacheStrategy { 57 | map, 58 | timeout, 59 | } 60 | } 61 | } 62 | 63 | // #[cfg(test)] 64 | // mod tests { 65 | // use crate::cache::timeout_strategy::TimeoutCacheStrategy; 66 | // use std::sync::Arc; 67 | // use crate::cache::limit_map::LimitedMap; 68 | // use 
crate::cache::{CacheStrategy, CacheRecord, CacheItem}; 69 | // use std::sync::atomic::Ordering; 70 | // use crate::protocol::tests::{get_ip_answer, get_ip_answer_with_ttl}; 71 | // use crate::cache::expired_strategy::tests::get_test_func; 72 | // use crate::system::{set_time_base}; 73 | // use crate::cache::cache_record::tests::tests::{get_ip_record, build_ip_record}; 74 | // 75 | // #[test] 76 | // fn should_return_answer_when_call_handle_given_no_expired_record() { 77 | // let strategy = TimeoutCacheStrategy { 78 | // map: Arc::new(LimitedMap::, DNSCacheRecord>::from(0)), 79 | // timeout: 800, 80 | // }; 81 | // let (is_called, func) = get_test_func(); 82 | // let record = Box::new(get_ip_record()); 83 | // set_time_base(999); 84 | // 85 | // let result = strategy.handle(record.domain.clone(), record, func); 86 | // 87 | // assert!(!is_called.load(Ordering::Relaxed)); 88 | // assert!(result.is_ok()); 89 | // assert_eq!(get_ip_answer_with_ttl(0), result.unwrap()); 90 | // assert!(strategy.map.is_empty()); 91 | // } 92 | // 93 | // #[test] 94 | // fn should_return_answer_and_insert_to_map_when_call_handle_given_expired_record() { 95 | // let strategy = TimeoutCacheStrategy { 96 | // map: Arc::new(LimitedMap::, DNSCacheRecord>::from(0)), 97 | // timeout: 1000, 98 | // }; 99 | // let (is_called, func) = get_test_func(); 100 | // let record = Box::new(get_ip_record()); 101 | // set_time_base(2001); 102 | // let key = record.get_key().clone(); 103 | // 104 | // let result = strategy.handle(key.clone(), record, func); 105 | // 106 | // assert!(is_called.load(Ordering::Relaxed)); 107 | // assert!(result.is_ok()); 108 | // assert_eq!(get_ip_answer(), result.unwrap()); 109 | // let expected = build_ip_record(|r| { r.create_time = 2001; }); 110 | // assert_eq!(Some(Box::new(expected)), strategy.map.get(&key)); 111 | // } 112 | // 113 | // #[tokio::test(flavor = "multi_thread", worker_threads = 1)] 114 | // async fn 
should_return_answer_and_insert_to_map_when_call_handle_given_no_expired_but_timeout_record() { 115 | // //没找到好的办法测试内部的async调用,所以只能这样了 116 | // let strategy = TimeoutCacheStrategy { 117 | // map: Arc::new(LimitedMap::, DNSCacheRecord>::from(0)), 118 | // timeout: 1000, 119 | // }; 120 | // let (is_called, func) = get_test_func(); 121 | // let record = Box::new(get_ip_record()); 122 | // set_time_base(1999); 123 | // let key = record.get_key().clone(); 124 | // 125 | // let result = strategy.handle(key.clone(), record, func); 126 | // 127 | // assert!(is_called.load(Ordering::Relaxed)); 128 | // assert!(result.is_ok()); 129 | // assert_eq!(get_ip_answer_with_ttl(0), result.unwrap()); 130 | // let expected: CacheRecord = Box::new(get_ip_record()); 131 | // assert!(strategy.map.get(&key).unwrap().value().eq(&expected)) 132 | // // assert_eq!(Some(expected), ); 133 | // } 134 | // } -------------------------------------------------------------------------------- /src/client.rs: -------------------------------------------------------------------------------- 1 | use tokio::net::UdpSocket; 2 | use std::net::SocketAddr; 3 | use crate::system::{Result, QueryBuf, default_value}; 4 | use crate::protocol::DnsAnswer; 5 | 6 | pub struct ClientSocket { 7 | socket: UdpSocket, 8 | } 9 | 10 | impl ClientSocket { 11 | pub async fn new(port: u16) -> Result { 12 | let socket = UdpSocket::bind(("0.0.0.0", port)).await?; 13 | Ok(ClientSocket { 14 | socket 15 | }) 16 | } 17 | pub async fn recv(&self) -> Result<(QueryBuf, SocketAddr)> { 18 | let mut buf: QueryBuf = default_value(); 19 | let (_, src) = self.socket 20 | .recv_from(&mut buf) 21 | .await?; 22 | Ok((buf, src)) 23 | } 24 | 25 | pub async fn back_to(&self, client: SocketAddr, answer: DnsAnswer) -> Result<()> { 26 | self.socket.send_to(answer.to_bytes().as_slice(), client).await?; 27 | Ok(()) 28 | } 29 | } -------------------------------------------------------------------------------- /src/config.rs: 
-------------------------------------------------------------------------------- 1 | use crate::system::Result; 2 | use tokio::fs::File; 3 | use tokio::io::AsyncReadExt; 4 | use toml::Value; 5 | 6 | pub struct Config { 7 | pub cache_on: bool, 8 | pub cache_file: String, 9 | pub cache_num: usize, 10 | pub port: u16, 11 | pub servers: Vec, 12 | pub filters: Vec, 13 | pub log_level: String, 14 | pub ip_choose_strategy: usize, 15 | pub cache_get_strategy: usize, 16 | pub cache_ttl_timeout_ms: usize, 17 | pub server_choose_strategy: usize, 18 | pub server_choose_duration_h: usize, 19 | } 20 | 21 | impl Config { 22 | fn from(value: Value) -> Self { 23 | let cache_file = value["cache-file"].as_str().map(|e| String::from(e)) 24 | .unwrap_or("cache".into()); 25 | let cache_on = value["cache"].as_bool().unwrap_or(true); 26 | let cache_num = value["cache-num"].as_integer().unwrap_or(1000) as usize; 27 | let port = value["port"].as_integer().unwrap_or(2053) as u16; 28 | let servers = value["servers"].as_array().map(|e| { 29 | e.iter().map(|e| String::from(e.as_str().unwrap())).collect() 30 | }).unwrap_or(vec![]); 31 | let filters = value["filters"].as_array().map(|e| { 32 | e.iter().map(|e| String::from(e.as_str().unwrap())).collect() 33 | }).unwrap_or(vec![]); 34 | let log_level = value["log-level"].as_str().map(|e| String::from(e)) 35 | .unwrap_or("error".into()); 36 | let ip_choose_strategy = value["ip-choose-strategy"].as_integer() 37 | .unwrap_or(0) as usize; 38 | let cache_get_strategy = value["cache-get-strategy"].as_integer() 39 | .unwrap_or(0) as usize; 40 | let cache_ttl_timeout_ms = value["cache-ttl-timeout-ms"].as_integer() 41 | .unwrap_or(0) as usize; 42 | let server_choose_strategy = value["server-choose-strategy"].as_integer() 43 | .unwrap_or(0) as usize; 44 | let server_choose_duration_h = value["server-choose-duration-h"].as_integer() 45 | .unwrap_or(12) as usize; 46 | Config { 47 | cache_on, 48 | cache_file, 49 | cache_num, 50 | port, 51 | servers, 52 | 
filters, 53 | log_level, 54 | ip_choose_strategy, 55 | cache_get_strategy, 56 | cache_ttl_timeout_ms, 57 | server_choose_strategy, 58 | server_choose_duration_h, 59 | } 60 | } 61 | } 62 | 63 | pub async fn init_from_toml() -> Result { 64 | let mut file = File::open("easydns.toml").await?; 65 | let buf = &mut String::new(); 66 | file.read_to_string(buf).await?; 67 | Ok(Config::from(buf.parse::()?)) 68 | } 69 | -------------------------------------------------------------------------------- /src/cursor/array_buf.rs: -------------------------------------------------------------------------------- 1 | use crate::cursor::{Array, ArrayBuf}; 2 | use crate::system::{QueryBuf, AnswerBuf}; 3 | 4 | impl Array for Vec { 5 | fn get(&self, index: usize) -> u8 { 6 | self[index] 7 | } 8 | 9 | fn get_slice(&self, start: usize, end: usize) -> &[u8] { 10 | &self[start..end] 11 | } 12 | } 13 | 14 | impl From> for ArrayBuf { 15 | fn from(buf: Vec) -> Self { 16 | Box::new(buf) 17 | } 18 | } 19 | 20 | impl Array for AnswerBuf { 21 | fn get(&self, index: usize) -> u8 { 22 | self[index] 23 | } 24 | 25 | fn get_slice(&self, start: usize, end: usize) -> &[u8] { 26 | &self[start..end] 27 | } 28 | } 29 | 30 | impl From for ArrayBuf { 31 | fn from(buf: AnswerBuf) -> Self { 32 | Box::new(buf) 33 | } 34 | } 35 | 36 | impl Array for QueryBuf { 37 | fn get(&self, index: usize) -> u8 { 38 | self[index] 39 | } 40 | 41 | fn get_slice(&self, start: usize, end: usize) -> &[u8] { 42 | &self[start..end] 43 | } 44 | } 45 | 46 | impl From for ArrayBuf { 47 | fn from(buf: QueryBuf) -> Self { 48 | Box::new(buf) 49 | } 50 | } -------------------------------------------------------------------------------- /src/cursor/mod.rs: -------------------------------------------------------------------------------- 1 | use crate::system::default_value; 2 | use std::cell::RefCell; 3 | 4 | mod array_buf; 5 | 6 | pub type ArrayBuf = Box>; 7 | 8 | pub struct Cursor { 9 | array: ArrayBuf, 10 | current: RefCell, 11 | } 12 | 13 
| pub trait Array: Send + Sync { 14 | fn get(&self, index: usize) -> T; 15 | fn get_slice(&self, start: usize, end: usize) -> &[T]; 16 | } 17 | 18 | impl Cursor { 19 | pub fn form(array: ArrayBuf) -> Self { 20 | Cursor { 21 | array, 22 | current: RefCell::new(0), 23 | } 24 | } 25 | 26 | pub fn at(&self, index: usize) { 27 | *self.current.borrow_mut() = index; 28 | } 29 | 30 | #[inline] 31 | pub fn tmp_at(&self, index: usize, mut func: F) { 32 | let current_index_saved = self.get_current_index(); 33 | self.at(index); 34 | func(self); 35 | self.at(current_index_saved); 36 | } 37 | 38 | pub fn take(&self) -> R { 39 | let current = *self.current.borrow(); 40 | let result = self.array.get(current); 41 | *self.current.borrow_mut() = current + 1; 42 | result 43 | } 44 | 45 | pub fn move_to(&self, step: usize) { 46 | let current = *self.current.borrow(); 47 | *self.current.borrow_mut() = current + step; 48 | } 49 | 50 | pub fn peek(&self) -> R { 51 | self.array.get(*self.current.borrow()) 52 | } 53 | 54 | pub fn get_current_index(&self) -> usize { 55 | *self.current.borrow() 56 | } 57 | 58 | pub fn take_slice(&self, len: usize) -> &[R] { 59 | let current = *self.current.borrow(); 60 | let result = self.array.get_slice(current, current + len); 61 | *self.current.borrow_mut() = current + len; 62 | result 63 | } 64 | 65 | pub fn take_bytes(&self) -> [R; N] where R: Default + Copy { 66 | let mut k = default_value(); 67 | (0..N).into_iter().for_each(|index| { 68 | k[index] = self.take(); 69 | }); 70 | k 71 | } 72 | } -------------------------------------------------------------------------------- /src/filter.rs: -------------------------------------------------------------------------------- 1 | use crate::config::Config; 2 | use crate::system::{Result, FileNotFoundError}; 3 | use regex::Regex; 4 | use std::collections::HashSet; 5 | use std::process::Stdio; 6 | use tokio::fs::File; 7 | use tokio::io::{AsyncBufRead, AsyncBufReadExt, BufReader}; 8 | use tokio::process::Command; 9 
| 10 | const GET_DOMAIN_REGEX: &str = 11 | "address /([a-zA-Z0-9][-a-zA-Z0-9]{0,62}(?:\\.[a-zA-Z0-9][-a-zA-Z0-9]{0,62})+)/([#|d])"; 12 | 13 | #[derive(PartialEq, Eq, Hash, Debug)] 14 | struct FilterItem { 15 | domain: String, 16 | group: String, 17 | } 18 | 19 | impl From for FilterItem { 20 | fn from(domain: String) -> Self { 21 | FilterItem { 22 | domain, 23 | group: "#".into(), 24 | } 25 | } 26 | } 27 | 28 | pub struct Filter { 29 | set: HashSet, 30 | } 31 | 32 | impl Filter { 33 | pub async fn from(config: &Config) -> Self { 34 | let set = read_resources_to_filter(&config.filters).await; 35 | debug!("filter init done, set len = {}", set.len()); 36 | Filter { set } 37 | } 38 | 39 | pub fn contain(&self, domain: &String) -> bool { 40 | //拆分多级域名 41 | let split = domain.split("."); 42 | let vec: Vec<&str> = split.collect(); 43 | for i in (0..vec.len()).rev() { 44 | let mut string = String::new(); 45 | for j in i..vec.len() { 46 | string.push_str(vec[j]); 47 | string.push_str("."); 48 | } 49 | string.remove(string.len() - 1); 50 | if self.set.contains(&string.into()) { 51 | return true; 52 | } 53 | } 54 | false 55 | } 56 | } 57 | 58 | async fn read_resources_to_filter(paths: &Vec) -> HashSet { 59 | let mut set = HashSet::new(); 60 | for path in paths { 61 | let result = read_resource_to_filter(&path).await; 62 | match result { 63 | Ok(temp) => { 64 | for f in temp { 65 | if f.group == "#" { 66 | set.insert(f); 67 | } else { 68 | set.remove(&f.domain.into()); 69 | } 70 | } 71 | } 72 | Err(e) => { 73 | error!("{:?}", e); 74 | } 75 | }; 76 | } 77 | set 78 | } 79 | 80 | async fn read_resource_to_filter(path: &str) -> Result> { 81 | if path.starts_with("http") { 82 | read_url_to_filter(path).await 83 | } else { 84 | read_file_to_filter(path).await 85 | } 86 | } 87 | 88 | async fn read_url_to_filter(url: &str) -> Result> { 89 | let mut child = Command::new("curl") 90 | .arg("-k") 91 | .arg("-s") 92 | .arg(url) 93 | .stdout(Stdio::piped()) 94 | .spawn()?; 95 | let reader = 
BufReader::new(child.stdout.take().unwrap()); 96 | tokio::spawn(async move { 97 | let status = child 98 | .wait() 99 | .await 100 | .expect("filter curl process encountered an error"); 101 | debug!("filter curl status was: {}", status); 102 | }); 103 | read_to_filter(reader).await 104 | } 105 | 106 | async fn read_file_to_filter(file_path: &str) -> Result> { 107 | let file = File::open(file_path).await.map_err(|e| { 108 | FileNotFoundError { 109 | path: String::from(file_path), 110 | supper: Box::new(e), 111 | } 112 | })?; 113 | let reader = BufReader::new(file); 114 | read_to_filter(reader).await 115 | } 116 | 117 | async fn read_to_filter( 118 | mut reader: impl AsyncBufRead + std::marker::Unpin, 119 | ) -> Result> { 120 | let mut buffer = String::new(); 121 | let line_regex = Regex::new(GET_DOMAIN_REGEX).unwrap(); 122 | let mut set = HashSet::new(); 123 | while reader.read_line(&mut buffer).await? > 0 { 124 | match handle_one_line(&line_regex, &buffer) { 125 | None => {} 126 | Some(item) => { 127 | set.insert(item); 128 | } 129 | } 130 | buffer.clear(); 131 | } 132 | Ok(set) 133 | } 134 | 135 | fn handle_one_line(regex: &Regex, line: &String) -> Option { 136 | if line.starts_with("#") { 137 | return None; 138 | } 139 | regex 140 | .captures(line) 141 | .and_then(|cap| match cap.get(1).map(|l| String::from(l.as_str())) { 142 | Some(domain) => { 143 | if let Some(group) = cap.get(2).map(|f| String::from(f.as_str())) { 144 | Some(FilterItem { domain, group }) 145 | } else { 146 | None 147 | } 148 | } 149 | None => None, 150 | }) 151 | } 152 | 153 | #[cfg(test)] 154 | mod tests { 155 | use crate::filter::{ 156 | handle_one_line, read_resource_to_filter, read_resources_to_filter, FilterItem, 157 | GET_DOMAIN_REGEX, 158 | }; 159 | use crate::system::Result; 160 | use regex::Regex; 161 | use std::collections::HashSet; 162 | 163 | #[test] 164 | fn test_handle_one_line() { 165 | let line_regex = Regex::new(GET_DOMAIN_REGEX).unwrap(); 166 | let x = String::from("address 
/kwcscdn.000dn.com/#"); 167 | 168 | let result = handle_one_line(&line_regex, &x); 169 | 170 | assert_eq!(result, Some(String::from("kwcscdn.000dn.com").into())); 171 | } 172 | 173 | #[tokio::test] 174 | async fn test_read_file_to_filter() -> Result<()> { 175 | let filter = read_resource_to_filter( 176 | "./tests/resources/test_filter.txt").await?; 177 | 178 | let mut expected: HashSet = HashSet::new(); 179 | expected.insert(String::from("00-gov.cn").into()); 180 | expected.insert(String::from("kwcdn.000dn.com").into()); 181 | assert_eq!(expected, filter); 182 | Ok(()) 183 | } 184 | 185 | // #[tokio::test] 186 | // async fn test_read_url_to_filter() -> Result<()> { 187 | // let filter = read_resource_to_filter( 188 | // "https://raw.githubusercontent.com/dunmengjun\ 189 | // /SmartDNS-GFWList/master/test_url_filter.txt", 190 | // ).await?; 191 | // 192 | // let mut expected: HashSet = HashSet::new(); 193 | // expected.insert(String::from("00-gov.cn").into()); 194 | // expected.insert(String::from("kwcdn.000dn.com").into()); 195 | // assert_eq!(expected, filter); 196 | // Ok(()) 197 | // } 198 | 199 | #[tokio::test] 200 | async fn test_filter_item_overcast() -> Result<()> { 201 | let filters: Vec = vec!["./tests/resources/test_filter.txt".into(), 202 | "./tests/resources/covercast_filter.txt".into()]; 203 | 204 | let result = read_resources_to_filter(&filters).await; 205 | 206 | let mut expected: HashSet = HashSet::new(); 207 | expected.insert(String::from("00-gov.cn").into()); 208 | assert_eq!(expected, result); 209 | Ok(()) 210 | } 211 | 212 | #[tokio::test] 213 | async fn test_filter_path_empty() -> Result<()> { 214 | let filters: Vec = vec![]; 215 | 216 | let result = read_resources_to_filter(&filters).await; 217 | 218 | assert!(result.is_empty()); 219 | Ok(()) 220 | } 221 | } 222 | -------------------------------------------------------------------------------- /src/handler/cache_handler.rs: 
-------------------------------------------------------------------------------- 1 | use async_trait::async_trait; 2 | use crate::cache::CachePool; 3 | use std::sync::Arc; 4 | use crate::handler::{Clain, Handler}; 5 | use crate::system::{Result}; 6 | use futures_util::FutureExt; 7 | use crate::protocol::{DnsAnswer, DnsQuery}; 8 | 9 | #[derive(Clone)] 10 | pub struct CacheHandler { 11 | cache_pool: Arc, 12 | } 13 | 14 | impl CacheHandler { 15 | pub fn new(cache_pool: Arc) -> Self { 16 | CacheHandler { 17 | cache_pool 18 | } 19 | } 20 | } 21 | 22 | 23 | #[async_trait] 24 | impl Handler for CacheHandler { 25 | async fn handle(&self, clain: Clain, query: DnsQuery) -> Result { 26 | let id = query.get_id().clone(); 27 | self.cache_pool 28 | .get(query.get_name().clone(), clain.next(query).boxed()).await 29 | .map(|mut r| { 30 | r.set_id(id); 31 | r 32 | }) 33 | } 34 | } -------------------------------------------------------------------------------- /src/handler/domain_filter.rs: -------------------------------------------------------------------------------- 1 | use async_trait::async_trait; 2 | use crate::filter::Filter; 3 | use std::sync::Arc; 4 | use crate::handler::{Clain, Handler}; 5 | use crate::system::Result; 6 | use crate::protocol::{DnsAnswer, SoaAnswer, DnsQuery}; 7 | 8 | #[derive(Clone)] 9 | pub struct DomainFilter { 10 | filter: Arc, 11 | } 12 | 13 | impl DomainFilter { 14 | pub fn new(filter: Arc) -> Self { 15 | DomainFilter { 16 | filter 17 | } 18 | } 19 | } 20 | 21 | #[async_trait] 22 | impl Handler for DomainFilter { 23 | async fn handle(&self, clain: Clain, query: DnsQuery) -> Result { 24 | let domain = query.get_name().clone(); 25 | if self.filter.contain(&domain) { 26 | //返回soa 27 | return Ok(DnsAnswer::from(SoaAnswer::default_soa( 28 | query.get_id().clone(), domain))); 29 | } 30 | clain.next(query).await 31 | } 32 | } -------------------------------------------------------------------------------- /src/handler/ip_maker.rs: 
-------------------------------------------------------------------------------- 1 | use async_trait::async_trait; 2 | use tokio_icmp::Pinger; 3 | use std::sync::Arc; 4 | use crate::handler::{Clain, Handler}; 5 | use crate::system::Result; 6 | use futures_util::future::select_all; 7 | use crate::protocol::{DnsAnswer, Ipv4Answer, DnsQuery}; 8 | use std::net::IpAddr; 9 | 10 | #[derive(Clone)] 11 | pub struct IpChoiceMaker { 12 | pinger: Arc, 13 | } 14 | 15 | impl IpChoiceMaker { 16 | pub fn new(pinger: Arc) -> Self { 17 | IpChoiceMaker { 18 | pinger 19 | } 20 | } 21 | } 22 | 23 | #[async_trait] 24 | impl Handler for IpChoiceMaker { 25 | async fn handle(&self, clain: Clain, query: DnsQuery) -> Result { 26 | let mut answer = clain.next(query).await?; 27 | if let Some(ipv4_answer) = answer.as_mut_any().downcast_mut::() { 28 | let ip_vec = ipv4_answer.get_all_ips(); 29 | if ip_vec.len() == 1 { 30 | return Ok(answer); 31 | } 32 | let mut ping_future_vec = Vec::new(); 33 | ip_vec.iter().for_each(|r| { 34 | let ip = *r.clone(); 35 | let future = self.pinger.chain(IpAddr::V4(ip)).send(); 36 | ping_future_vec.push(future); 37 | }); 38 | let index = select_all(ping_future_vec).await.1; 39 | let ip = ip_vec[index].clone(); 40 | ipv4_answer.retain_ip(&ip); 41 | } 42 | Ok(answer) 43 | } 44 | } 45 | 46 | #[derive(Clone)] 47 | pub struct IpFirstMaker; 48 | 49 | #[async_trait] 50 | impl Handler for IpFirstMaker { 51 | async fn handle(&self, clain: Clain, query: DnsQuery) -> Result { 52 | let mut answer = clain.next(query).await?; 53 | if let Some(ipv4_answer) = answer.as_mut_any().downcast_mut::() { 54 | let vec = ipv4_answer.get_all_ips(); 55 | let addr = vec[0].clone(); 56 | ipv4_answer.retain_ip(&addr); 57 | } 58 | Ok(answer) 59 | } 60 | } -------------------------------------------------------------------------------- /src/handler/legal_checker.rs: -------------------------------------------------------------------------------- 1 | use async_trait::async_trait; 2 | use 
std::sync::Arc; 3 | use crate::handler::{Handler, Clain}; 4 | use crate::system::Result; 5 | use crate::handler::server_group::ServerGroup; 6 | use crate::protocol::{DnsAnswer, DnsQuery}; 7 | 8 | #[derive(Clone)] 9 | pub struct LegalChecker { 10 | server_group: Arc, 11 | } 12 | 13 | impl LegalChecker { 14 | pub fn new(server_group: Arc) -> Self { 15 | LegalChecker { 16 | server_group 17 | } 18 | } 19 | } 20 | 21 | #[async_trait] 22 | impl Handler for LegalChecker { 23 | async fn handle(&self, clain: Clain, query: DnsQuery) -> Result { 24 | if !query.is_supported() { 25 | debug!("The dns query is not supported , will not mit the cache!"); 26 | let answer = self.server_group.send_query(query).await?; 27 | debug!("dns answer: {}", answer); 28 | return Ok(answer); 29 | } else { 30 | clain.next(query).await 31 | } 32 | } 33 | } -------------------------------------------------------------------------------- /src/handler/mod.rs: -------------------------------------------------------------------------------- 1 | use std::sync::Arc; 2 | 3 | use async_trait::async_trait; 4 | use tokio_icmp::Pinger; 5 | 6 | use crate::cache::CachePool; 7 | use crate::config::Config; 8 | use crate::filter::Filter; 9 | use crate::handler::cache_handler::CacheHandler; 10 | use crate::handler::domain_filter::DomainFilter; 11 | use crate::handler::ip_maker::{IpChoiceMaker, IpFirstMaker}; 12 | use crate::handler::legal_checker::LegalChecker; 13 | use crate::handler::query_sender::QuerySender; 14 | use crate::system::{Result, QueryBuf}; 15 | use crate::handler::server_group::ServerGroup; 16 | use std::option::Option::Some; 17 | use crate::protocol::{DnsAnswer, DnsQuery}; 18 | 19 | mod legal_checker; 20 | mod cache_handler; 21 | mod query_sender; 22 | mod ip_maker; 23 | mod domain_filter; 24 | mod server_group; 25 | 26 | pub struct HandlerContext { 27 | server_group: Arc, 28 | pinger: Option>, 29 | cache_pool: Option>, 30 | filter: Arc, 31 | } 32 | 33 | impl HandlerContext { 34 | pub async fn 
from(config: Config) -> Result { 35 | let pinger = if config.ip_choose_strategy == 0 { 36 | None 37 | } else { 38 | Some(Arc::new(Pinger::new().await?)) 39 | }; 40 | let server_group = Arc::new(ServerGroup::from( 41 | config.servers.clone(), 42 | config.server_choose_strategy.clone(), 43 | (config.server_choose_duration_h * 60 * 60) as u64, 44 | ).await?); 45 | let cache_pool = if config.cache_on { 46 | Some(Arc::new(CachePool::from(&config).await?)) 47 | } else { 48 | None 49 | }; 50 | let filter = Arc::new(Filter::from(&config).await); 51 | Ok(HandlerContext { 52 | server_group, 53 | pinger, 54 | cache_pool, 55 | filter, 56 | }) 57 | } 58 | 59 | pub async fn handle_query(&self, buf: QueryBuf) -> Result { 60 | let mut query_clain = Clain::new(); 61 | query_clain.add(DomainFilter::new(self.filter.clone())); 62 | query_clain.add(LegalChecker::new(self.server_group.clone())); 63 | if let Some(pool) = self.cache_pool.clone() { 64 | query_clain.add(CacheHandler::new(pool)); 65 | } 66 | if let Some(pinger) = self.pinger.clone() { 67 | query_clain.add(IpChoiceMaker::new(pinger)); 68 | } else { 69 | query_clain.add(IpFirstMaker); 70 | } 71 | query_clain.add(QuerySender::new(self.server_group.clone())); 72 | query_clain.next(DnsQuery::from(buf)).await 73 | } 74 | } 75 | 76 | struct Clain { 77 | pub funcs: Vec>, 78 | } 79 | 80 | impl Clain { 81 | fn new() -> Self { 82 | Clain { funcs: Vec::new() } 83 | } 84 | 85 | fn add(&mut self, handler: impl Handler + Send + Sync + 'static) { 86 | self.funcs.push(Box::new(handler)); 87 | } 88 | 89 | async fn next(mut self, query: DnsQuery) -> Result { 90 | self.funcs.remove(0).handle(self, query).await 91 | } 92 | } 93 | 94 | #[async_trait] 95 | trait Handler: Send + Sync { 96 | async fn handle(&self, clain: Clain, query: DnsQuery) -> Result; 97 | } 98 | -------------------------------------------------------------------------------- /src/handler/query_sender.rs: 
-------------------------------------------------------------------------------- 1 | use async_trait::async_trait; 2 | use std::sync::Arc; 3 | use crate::handler::{Clain, Handler}; 4 | use crate::system::Result; 5 | use crate::handler::server_group::ServerGroup; 6 | use crate::protocol::{DnsAnswer, DnsQuery}; 7 | 8 | #[derive(Clone)] 9 | pub struct QuerySender { 10 | server_group: Arc, 11 | } 12 | 13 | impl QuerySender { 14 | pub fn new(server_group: Arc) -> Self { 15 | QuerySender { 16 | server_group 17 | } 18 | } 19 | } 20 | 21 | #[async_trait] 22 | impl Handler for QuerySender { 23 | async fn handle(&self, _: Clain, query: DnsQuery) -> Result { 24 | let answer = self.server_group.send_query(query).await?; 25 | // info!("answer: {:?}", answer); 26 | // if answer.is_empty() { 27 | // return Err("answer is empty".into()); 28 | // } 29 | Ok(answer) 30 | } 31 | } -------------------------------------------------------------------------------- /src/handler/server_group/combine_server_sender.rs: -------------------------------------------------------------------------------- 1 | use crate::system::Result; 2 | use async_trait::async_trait; 3 | use crate::handler::server_group::query_executor::QueryExecutor; 4 | use crate::handler::server_group::ServerSender; 5 | use crate::protocol::{DnsAnswer, Ipv4Answer, FailureAnswer, DnsQuery}; 6 | 7 | pub struct CombineServerSender { 8 | executor: QueryExecutor, 9 | servers: Vec, 10 | } 11 | 12 | #[async_trait] 13 | impl ServerSender for CombineServerSender { 14 | async fn send(&self, query: DnsQuery) -> Result { 15 | let servers = &self.servers; 16 | let mut future_vec = Vec::with_capacity(servers.len()); 17 | for address in servers.iter() { 18 | future_vec.push(self.executor.exec(address.as_str(), query.clone())); 19 | } 20 | let mut ipv4_answer = Ipv4Answer::empty_answer( 21 | query.get_id().clone(), query.get_name().clone()); 22 | for future in future_vec { 23 | match future.await { 24 | Ok(r) => { 25 | ipv4_answer.combine(r); 
26 | } 27 | Err(e) => { 28 | error!("{:?}", e); 29 | } 30 | } 31 | } 32 | if ipv4_answer.is_empty() { 33 | return Ok(FailureAnswer::new( 34 | query.get_id().clone(), query.get_name().clone()).into()); 35 | } else { 36 | Ok(ipv4_answer.into()) 37 | } 38 | } 39 | } 40 | 41 | impl CombineServerSender { 42 | pub fn from(executor: QueryExecutor, servers: Vec) -> Self { 43 | CombineServerSender { 44 | executor, 45 | servers, 46 | } 47 | } 48 | } -------------------------------------------------------------------------------- /src/handler/server_group/fast_server_sender.rs: -------------------------------------------------------------------------------- 1 | use std::sync::{Arc, Mutex}; 2 | use crate::system::Result; 3 | use tokio::time::Duration; 4 | use futures_util::FutureExt; 5 | use async_trait::async_trait; 6 | use futures_util::future::select_all; 7 | use tokio::time::interval; 8 | use crate::handler::server_group::query_executor::QueryExecutor; 9 | use crate::handler::server_group::ServerSender; 10 | use crate::protocol::{DnsAnswer, DnsQuery}; 11 | 12 | pub struct FastServerSender { 13 | executor: Arc, 14 | servers: Arc>, 15 | fast_server: Arc>, 16 | } 17 | 18 | #[async_trait] 19 | impl ServerSender for FastServerSender { 20 | async fn send(&self, query: DnsQuery) -> Result { 21 | let address = self.fast_server.lock().unwrap().clone(); 22 | self.executor.exec(address.as_str(), query).await 23 | } 24 | } 25 | 26 | impl FastServerSender { 27 | pub fn from( 28 | query_executor: QueryExecutor, 29 | servers: Vec, 30 | duration_secs: u64, 31 | ) -> Self { 32 | let executor = Arc::new(query_executor); 33 | let cloned_executor = executor.clone(); 34 | let arc_servers = Arc::new(servers); 35 | let cloned_servers = arc_servers.clone(); 36 | let fast_server = Arc::new(Mutex::new(String::new())); 37 | let cloned_fast_server = fast_server.clone(); 38 | let sender = FastServerSender { 39 | executor: cloned_executor, 40 | servers: cloned_servers, 41 | fast_server: 
cloned_fast_server, 42 | }; 43 | 44 | tokio::spawn(async move { 45 | let mut interval = interval(Duration::from_secs(duration_secs)); 46 | loop { 47 | interval.tick().await; 48 | let test_query: DnsQuery = "www.baidu.com".into(); 49 | if let Err(e) = sender.preferred_dns_server(test_query).await { 50 | error!("interval task upstream servers choose has error: {:?}", e) 51 | } 52 | } 53 | }); 54 | 55 | FastServerSender { 56 | executor, 57 | servers: arc_servers, 58 | fast_server, 59 | } 60 | } 61 | 62 | async fn preferred_dns_server(&self, query: DnsQuery) -> Result<()> { 63 | let (_, index) = self.get_answer_from_fast_server(query).await?; 64 | *self.fast_server.lock().unwrap() = self.servers[index].clone(); 65 | Ok(()) 66 | } 67 | 68 | async fn get_answer_from_fast_server(&self, query: DnsQuery) -> Result<(DnsAnswer, usize)> { 69 | let servers = &self.servers; 70 | let mut future_vec = Vec::with_capacity(servers.len()); 71 | for address in servers.iter() { 72 | future_vec.push(self.executor.exec(address.as_str(), query.clone()).boxed()); 73 | } 74 | let (result, index, _) = select_all(future_vec).await; 75 | let answer = result?; 76 | Ok((answer, index)) 77 | } 78 | } -------------------------------------------------------------------------------- /src/handler/server_group/mod.rs: -------------------------------------------------------------------------------- 1 | mod fast_server_sender; 2 | mod prefer_server_sender; 3 | mod combine_server_sender; 4 | mod query_executor; 5 | 6 | use crate::system::Result; 7 | use async_trait::async_trait; 8 | use crate::handler::server_group::fast_server_sender::FastServerSender; 9 | use crate::handler::server_group::prefer_server_sender::PreferServerSender; 10 | use crate::handler::server_group::combine_server_sender::CombineServerSender; 11 | use crate::handler::server_group::query_executor::QueryExecutor; 12 | use crate::protocol::{DnsAnswer, DnsQuery}; 13 | 14 | #[async_trait] 15 | pub trait ServerSender: Sync + Send { 16 | 
async fn send(&self, query: DnsQuery) -> Result; 17 | } 18 | 19 | pub struct ServerGroup { 20 | server_sender: Box, 21 | } 22 | 23 | impl ServerGroup { 24 | pub async fn from(servers: Vec, strategy: usize, duration_secs: u64) -> Result { 25 | let query_executor = QueryExecutor::create().await?; 26 | let server_sender: Box = match strategy { 27 | 0 => Box::new(FastServerSender::from(query_executor, servers, duration_secs)), 28 | 1 => Box::new(PreferServerSender::from(query_executor, servers)), 29 | 2 => Box::new(CombineServerSender::from(query_executor, servers)), 30 | _ => panic!("不支持的server strategy类型!"), 31 | }; 32 | Ok(ServerGroup { 33 | server_sender, 34 | }) 35 | } 36 | 37 | pub async fn send_query(&self, query: DnsQuery) -> Result { 38 | self.server_sender.send(query).await 39 | } 40 | } 41 | -------------------------------------------------------------------------------- /src/handler/server_group/prefer_server_sender.rs: -------------------------------------------------------------------------------- 1 | use async_trait::async_trait; 2 | use crate::system::Result; 3 | use futures_util::future::select_all; 4 | use futures_util::FutureExt; 5 | use crate::handler::server_group::query_executor::QueryExecutor; 6 | use crate::handler::server_group::ServerSender; 7 | use crate::protocol::{DnsAnswer, DnsQuery}; 8 | 9 | pub struct PreferServerSender { 10 | executor: QueryExecutor, 11 | servers: Vec, 12 | } 13 | 14 | #[async_trait] 15 | impl ServerSender for PreferServerSender { 16 | async fn send(&self, query: DnsQuery) -> Result { 17 | let servers = &self.servers; 18 | let mut future_vec = Vec::with_capacity(servers.len()); 19 | for address in servers.iter() { 20 | future_vec.push(self.executor.exec(address.as_str(), query.clone()).boxed()); 21 | } 22 | let (result, _, _) = select_all(future_vec).await; 23 | let answer = result?; 24 | Ok(answer) 25 | } 26 | } 27 | 28 | impl PreferServerSender { 29 | pub fn from(executor: QueryExecutor, servers: Vec) -> Self { 30 | 
PreferServerSender { 31 | executor, 32 | servers, 33 | } 34 | } 35 | } -------------------------------------------------------------------------------- /src/handler/server_group/query_executor.rs: -------------------------------------------------------------------------------- 1 | use tokio::net::UdpSocket; 2 | use dashmap::DashMap; 3 | use tokio::sync::oneshot::Sender; 4 | use crate::system::{Result, next_id, AnswerBuf, default_value}; 5 | use std::sync::Arc; 6 | use tokio::sync::oneshot; 7 | use tokio::time::timeout; 8 | use std::time::Duration; 9 | use crate::protocol::{DnsAnswer, FailureAnswer, DnsQuery}; 10 | 11 | pub struct QueryExecutor { 12 | socket: Arc, 13 | reg_table: Arc>>, 14 | } 15 | 16 | impl QueryExecutor { 17 | pub async fn create() -> Result { 18 | let socket = Arc::new(UdpSocket::bind("0.0.0.0:0").await?); 19 | let cloned_socket = socket.clone(); 20 | let reg_table = Arc::new(DashMap::new()); 21 | let cloned_reg_table = reg_table.clone(); 22 | 23 | let executor = QueryExecutor { 24 | socket: cloned_socket, 25 | reg_table: cloned_reg_table, 26 | }; 27 | 28 | tokio::spawn(async move { 29 | loop { 30 | match executor.recv().await { 31 | Ok(()) => {} 32 | Err(e) => error!("error occur here accept {:?}", e), 33 | } 34 | } 35 | }); 36 | 37 | Ok(QueryExecutor { 38 | socket, 39 | reg_table, 40 | }) 41 | } 42 | 43 | pub async fn exec(&self, address: &str, mut query: DnsQuery) -> Result { 44 | let (sender, receiver) = oneshot::channel(); 45 | let client_query_id = query.get_id(); 46 | let next_id = next_id(); 47 | self.reg_table.insert(next_id, sender); 48 | query.set_id(next_id); 49 | let query_vec: Vec = (&query).into(); 50 | self.socket 51 | .send_to(query_vec.as_slice(), address) 52 | .await?; 53 | let mut answer = match timeout(Duration::from_secs(3), receiver).await { 54 | Ok(result) => { 55 | let buf = result?; 56 | DnsAnswer::from(buf) 57 | } 58 | Err(_) => { 59 | FailureAnswer::new(client_query_id, query.get_name().clone()).into() 60 | } 61 | }; 
62 | self.reg_table.remove(&next_id); 63 | answer.set_id(client_query_id); 64 | Ok(answer) 65 | } 66 | 67 | async fn recv(&self) -> Result<()> { 68 | let mut buf: AnswerBuf = default_value(); 69 | self.socket.recv_from(&mut buf).await?; 70 | let id = u16::from_be_bytes([buf[0], buf[1]]); 71 | match self.reg_table.remove(&id) { 72 | Some((_, sender)) => { 73 | if let Err(_e) = sender.send(buf) { 74 | self.reg_table.remove(&id); 75 | } 76 | } 77 | None => {} 78 | } 79 | Ok(()) 80 | } 81 | } -------------------------------------------------------------------------------- /src/main.rs: -------------------------------------------------------------------------------- 1 | #![feature(panic_info_message)] 2 | #![feature(const_generics)] 3 | 4 | use std::sync::Arc; 5 | 6 | use simple_logger::SimpleLogger; 7 | 8 | use crate::handler::*; 9 | use crate::system::{Result}; 10 | use crate::client::ClientSocket; 11 | 12 | mod config; 13 | mod filter; 14 | mod system; 15 | mod handler; 16 | mod cache; 17 | mod client; 18 | mod cursor; 19 | mod protocol; 20 | 21 | #[macro_use] 22 | extern crate log; 23 | 24 | //dig @127.0.0.1 -p 2053 www.baidu.com 25 | //dig @127.0.0.1 -p 2053 0-100.com 26 | #[tokio::main] 27 | async fn main() -> Result<()> { 28 | SimpleLogger::new().init()?; 29 | system::setup_panic_hook(); 30 | 31 | let config = config::init_from_toml().await?; 32 | system::setup_log_level(&config)?; 33 | let client = Arc::new(ClientSocket::new(config.port).await?); 34 | let handler = Arc::new(HandlerContext::from(config).await?); 35 | //主循环 36 | loop { 37 | tokio::select! 
{ 38 | result = client.recv() => { 39 | let (query_buf, src) = result?; 40 | let arc_client = client.clone(); 41 | let arc_handler = handler.clone(); 42 | tokio::spawn(async move { 43 | let answer = match arc_handler.handle_query(query_buf).await { 44 | Ok(answer) => answer, 45 | Err(e) => { 46 | error!("Handle query task error: {:?}", e); 47 | return; 48 | }, 49 | }; 50 | info!("answer: {}", answer); 51 | if let Err(e) = arc_client.back_to(src, answer).await { 52 | error!("Send answer back to client error: {:?}", e) 53 | } 54 | }); 55 | }, 56 | //监听ctrl_c事件 57 | _ = tokio::signal::ctrl_c() => { 58 | break; 59 | } 60 | } 61 | } 62 | Ok(()) 63 | } 64 | -------------------------------------------------------------------------------- /src/protocol/answer/failure.rs: -------------------------------------------------------------------------------- 1 | use crate::protocol::answer::Answer; 2 | use crate::cache::CacheRecord; 3 | use std::fmt::{Display, Formatter}; 4 | use std::any::Any; 5 | use crate::protocol::basic::{BasicData, Builder}; 6 | 7 | pub struct FailureAnswer { 8 | data: BasicData, 9 | } 10 | 11 | impl Display for FailureAnswer { 12 | fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { 13 | write!(f, "(FAILURE, {})", self.data.get_name()) 14 | } 15 | } 16 | 17 | impl Answer for FailureAnswer { 18 | fn to_cache(&self) -> Option { 19 | None 20 | } 21 | 22 | fn to_bytes(&self) -> Vec { 23 | let data = &self.data; 24 | data.into() 25 | } 26 | 27 | fn as_any(&self) -> &(dyn Any + Send + Sync) { 28 | self 29 | } 30 | 31 | fn as_mut_any(&mut self) -> &mut (dyn Any + Send + Sync) { 32 | self 33 | } 34 | 35 | fn set_id(&mut self, id: u16) { 36 | self.data.set_id(id); 37 | } 38 | 39 | fn get_id(&self) -> u16 { 40 | self.data.get_id() 41 | } 42 | } 43 | 44 | impl FailureAnswer { 45 | pub fn from(mut data: BasicData) -> Self { 46 | data.set_authority_count(0); 47 | data.set_answer_count(0); 48 | FailureAnswer { 49 | data 50 | } 51 | } 52 | 53 | pub fn new(id: u16, 
name: String) -> Self { 54 | let data = Builder::new() 55 | .id(id) 56 | .name(name) 57 | .flags(0x8182) 58 | .build(); 59 | FailureAnswer { 60 | data 61 | } 62 | } 63 | } -------------------------------------------------------------------------------- /src/protocol/answer/ipv4.rs: -------------------------------------------------------------------------------- 1 | use crate::protocol::answer::Answer; 2 | use crate::cache::{CacheRecord, IpCacheRecord, CacheItem}; 3 | use crate::protocol::answer::resource::{Ipv4Resource, Resource}; 4 | use std::fmt::{Display, Formatter}; 5 | use std::any::Any; 6 | use crate::protocol::{DnsAnswer}; 7 | use std::net::Ipv4Addr; 8 | use crate::protocol::basic::{BasicData, Builder}; 9 | 10 | pub struct Ipv4Answer { 11 | data: BasicData, 12 | resources: Vec, 13 | } 14 | 15 | impl Display for Ipv4Answer { 16 | fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { 17 | write!(f, "(IP, {}, {}, {})", self.data.get_name(), 18 | self.resources[0].get_ttl(), self.resources[0].get_data()) 19 | } 20 | } 21 | 22 | impl Answer for Ipv4Answer { 23 | fn to_cache(&self) -> Option { 24 | Some(IpCacheRecord::from(self).into()) 25 | } 26 | 27 | fn to_bytes(&self) -> Vec { 28 | let data = &self.data; 29 | let mut vec: Vec = data.into(); 30 | self.resources.iter().for_each(|r| { 31 | let resource: Vec = r.into(); 32 | vec.extend(resource) 33 | }); 34 | vec 35 | } 36 | 37 | fn as_any(&self) -> &(dyn Any + Send + Sync) { 38 | self 39 | } 40 | 41 | fn as_mut_any(&mut self) -> &mut (dyn Any + Send + Sync) { 42 | self 43 | } 44 | 45 | fn set_id(&mut self, id: u16) { 46 | self.data.set_id(id) 47 | } 48 | 49 | fn get_id(&self) -> u16 { 50 | self.data.get_id() 51 | } 52 | } 53 | 54 | impl Ipv4Answer { 55 | pub fn create(mut data: BasicData, mut resources: Vec) -> Self { 56 | data.set_authority_count(0); 57 | resources.iter_mut().for_each(|e| { 58 | e.set_name(data.get_name().clone()); 59 | }); 60 | Ipv4Answer { 61 | data, 62 | resources, 63 | } 64 | } 65 | 66 
| pub fn combine(&mut self, mut other: DnsAnswer) { 67 | if let Some(answer) = other.as_mut_any().downcast_mut::() { 68 | if self.get_name() != answer.get_name() { 69 | return; 70 | } 71 | while let Some(r) = answer.resources.pop() { 72 | let flag = self.resources.iter().find(|e| { 73 | e.data != r.data 74 | }).is_none(); 75 | if flag { 76 | self.resources.push(r); 77 | } 78 | } 79 | self.data.set_answer_count(self.resources.len() as u16); 80 | } 81 | } 82 | pub fn empty_answer(id: u16, name: String) -> Self { 83 | let data = Builder::new() 84 | .id(id) 85 | .name(name) 86 | .flags(0x8180) 87 | .build(); 88 | Ipv4Answer { 89 | data, 90 | resources: vec![], 91 | } 92 | } 93 | 94 | pub fn is_empty(&self) -> bool { 95 | self.resources.is_empty() 96 | } 97 | 98 | pub fn get_all_ips(&self) -> Vec<&Ipv4Addr> { 99 | self.resources.iter().map(|r| { 100 | &r.data 101 | }).collect() 102 | } 103 | 104 | pub fn retain_ip(&mut self, ip: &Ipv4Addr) { 105 | self.resources.retain(|r| { 106 | r.data.eq(ip) 107 | }); 108 | self.data.set_answer_count(1); 109 | } 110 | 111 | pub fn get_name(&self) -> &String { 112 | self.data.get_name() 113 | } 114 | 115 | pub fn get_ttl(&self) -> u32 { 116 | self.resources[0].get_ttl() 117 | } 118 | 119 | pub fn get_address(&self) -> &Ipv4Addr { 120 | self.resources[0].get_data() 121 | } 122 | } 123 | 124 | impl From<&IpCacheRecord> for Ipv4Answer { 125 | fn from(record: &IpCacheRecord) -> Self { 126 | let data = Builder::new() 127 | .flags(0x8180) 128 | .name(record.get_key().clone()) 129 | .answer(1) 130 | .build(); 131 | let resource = Ipv4Resource::from(record); 132 | Ipv4Answer { 133 | data, 134 | resources: vec![resource], 135 | } 136 | } 137 | } -------------------------------------------------------------------------------- /src/protocol/answer/mod.rs: -------------------------------------------------------------------------------- 1 | mod failure; 2 | mod resource; 3 | mod no_such_name; 4 | mod soa; 5 | mod ipv4; 6 | 7 | use 
crate::cache::CacheRecord; 8 | use crate::system::AnswerBuf; 9 | use crate::cursor::Cursor; 10 | use crate::protocol::answer::resource::{CnameResource, Ipv4Resource, SoaResource}; 11 | use crate::protocol::answer::no_such_name::NoSuchNameAnswer; 12 | use std::fmt::{Display}; 13 | use std::any::Any; 14 | 15 | pub type DnsAnswer = Box; 16 | 17 | pub use ipv4::Ipv4Answer; 18 | pub use failure::FailureAnswer; 19 | pub use soa::SoaAnswer; 20 | use crate::protocol::basic::BasicData; 21 | 22 | pub trait Answer: Display + Send + Sync { 23 | fn to_cache(&self) -> Option; 24 | fn to_bytes(&self) -> Vec; 25 | fn as_any(&self) -> &(dyn Any + Send + Sync); 26 | fn as_mut_any(&mut self) -> &mut (dyn Any + Send + Sync); 27 | fn set_id(&mut self, id: u16); 28 | fn get_id(&self) -> u16; 29 | } 30 | 31 | impl From for DnsAnswer { 32 | fn from(buf: AnswerBuf) -> Self { 33 | // info!("buf: {:?}", &buf[0..buf.len()]); 34 | let cursor = Cursor::form(buf.into()); 35 | let data = BasicData::from(&cursor); 36 | if data.get_flags() == 0x8182 { 37 | return FailureAnswer::from(data).into(); 38 | } 39 | if data.get_answer_count() == 0 && data.get_authority_count() == 0 { 40 | return FailureAnswer::from(data).into(); 41 | } 42 | if data.get_flags() == 0x8183 { 43 | return NoSuchNameAnswer::from(data).into(); 44 | } 45 | let mut ipv4_records = Vec::new(); 46 | (0..data.get_answer_count() as usize).into_iter().for_each(|_| { 47 | let r_data = resource::BasicData::from(&cursor); 48 | if r_data.get_type() == 5 { 49 | // cname记录 目前的处理是移除 50 | let _resource = CnameResource::create(r_data, &cursor); 51 | } else if r_data.get_type() == 1 { 52 | // a记录 53 | ipv4_records.push(Ipv4Resource::create(r_data, &cursor)); 54 | } else { 55 | panic!("不支持的应答资源记录类型: name = {}, type = {}", 56 | r_data.get_name(), r_data.get_type()) 57 | }; 58 | }); 59 | let mut soa_records = Vec::new(); 60 | (0..data.get_authority_count() as usize).into_iter().for_each(|_| { 61 | let r_data = resource::BasicData::from(&cursor); 62 | 
if r_data.get_type() == 6 { 63 | soa_records.push(SoaResource::create(r_data, &cursor)); 64 | } else { 65 | panic!("不支持的认证资源记录类型: name = {}, type = {}", 66 | r_data.get_name(), r_data.get_type()) 67 | } 68 | }); 69 | if !ipv4_records.is_empty() { 70 | return Ipv4Answer::create(data, ipv4_records).into(); 71 | } 72 | if !soa_records.is_empty() { 73 | return SoaAnswer::create(data, soa_records.remove(0)).into(); 74 | } 75 | unreachable!() 76 | } 77 | } 78 | 79 | impl From for DnsAnswer { 80 | fn from(f: FailureAnswer) -> Self { 81 | Box::new(f) 82 | } 83 | } 84 | 85 | impl From for DnsAnswer { 86 | fn from(f: NoSuchNameAnswer) -> Self { 87 | Box::new(f) 88 | } 89 | } 90 | 91 | impl From for DnsAnswer { 92 | fn from(f: SoaAnswer) -> Self { 93 | Box::new(f) 94 | } 95 | } 96 | 97 | impl From for DnsAnswer { 98 | fn from(f: Ipv4Answer) -> Self { 99 | Box::new(f) 100 | } 101 | } -------------------------------------------------------------------------------- /src/protocol/answer/no_such_name.rs: -------------------------------------------------------------------------------- 1 | use crate::protocol::answer::Answer; 2 | use crate::cache::CacheRecord; 3 | use std::fmt::{Display, Formatter}; 4 | use std::any::Any; 5 | use crate::protocol::basic::BasicData; 6 | 7 | pub struct NoSuchNameAnswer { 8 | data: BasicData, 9 | } 10 | 11 | impl Display for NoSuchNameAnswer { 12 | fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { 13 | write!(f, "(NO_SUCH_NAME, {})", self.data.get_name()) 14 | } 15 | } 16 | 17 | impl Answer for NoSuchNameAnswer { 18 | fn to_cache(&self) -> Option { 19 | None 20 | } 21 | 22 | fn to_bytes(&self) -> Vec { 23 | let data = &self.data; 24 | data.into() 25 | } 26 | 27 | fn as_any(&self) -> &(dyn Any + Send + Sync) { 28 | self 29 | } 30 | 31 | fn as_mut_any(&mut self) -> &mut (dyn Any + Send + Sync) { 32 | self 33 | } 34 | 35 | fn set_id(&mut self, id: u16) { 36 | self.data.set_id(id); 37 | } 38 | 39 | fn get_id(&self) -> u16 { 40 | self.data.get_id() 
41 | } 42 | } 43 | 44 | impl NoSuchNameAnswer { 45 | pub fn from(mut data: BasicData) -> Self { 46 | data.set_answer_count(0); 47 | data.set_authority_count(0); 48 | NoSuchNameAnswer { 49 | data 50 | } 51 | } 52 | } -------------------------------------------------------------------------------- /src/protocol/answer/resource/basic.rs: -------------------------------------------------------------------------------- 1 | use crate::protocol::question::Question; 2 | use crate::cursor::Cursor; 3 | 4 | #[derive(Debug, Eq, PartialEq, Clone)] 5 | pub struct BasicData { 6 | question: Question, 7 | ttl: u32, 8 | data_len: u16, 9 | } 10 | 11 | impl From<&Cursor> for BasicData { 12 | fn from(cursor: &Cursor) -> Self { 13 | let question = Question::from(cursor); 14 | let ttl = u32::from_be_bytes(cursor.take_bytes()); 15 | let data_len = u16::from_be_bytes(cursor.take_bytes()); 16 | BasicData { 17 | question, 18 | ttl, 19 | data_len, 20 | } 21 | } 22 | } 23 | 24 | impl From<&BasicData> for Vec { 25 | fn from(data: &BasicData) -> Self { 26 | let question = &data.question; 27 | let mut vec: Vec = question.into(); 28 | vec.extend(&data.ttl.to_be_bytes()); 29 | vec.extend(&data.data_len.to_be_bytes()); 30 | vec 31 | } 32 | } 33 | 34 | impl BasicData { 35 | pub fn get_name(&self) -> &String { 36 | &self.question.name 37 | } 38 | 39 | pub fn set_name(&mut self, name: String) { 40 | self.question.name = name; 41 | } 42 | 43 | pub fn get_ttl(&self) -> u32 { 44 | self.ttl 45 | } 46 | 47 | pub fn get_type(&self) -> u16 { 48 | self.question._type 49 | } 50 | 51 | pub fn set_data_len(&mut self, len: u16) { 52 | self.data_len = len; 53 | } 54 | 55 | fn new() -> Self { 56 | let mut question = Question::new(); 57 | question._type = 0; 58 | question.class = 1; 59 | BasicData { 60 | question, 61 | ttl: 0, 62 | data_len: 0, 63 | } 64 | } 65 | } 66 | 67 | pub struct Builder { 68 | data: Option, 69 | } 70 | 71 | impl Builder { 72 | pub fn new() -> Self { 73 | Builder { 74 | data: 
Some(BasicData::new()) 75 | } 76 | } 77 | 78 | pub fn name(mut self, name: String) -> Self { 79 | self.data.as_mut().map(|e| { 80 | e.question.name = name; 81 | e 82 | }); 83 | self 84 | } 85 | 86 | pub fn ttl(mut self, ttl: u32) -> Self { 87 | self.data.as_mut().map(|e| { 88 | e.ttl = ttl; 89 | e 90 | }); 91 | self 92 | } 93 | 94 | pub fn _type(mut self, _type: u16) -> Self { 95 | self.data.as_mut().map(|e| { 96 | e.question._type = _type; 97 | e 98 | }); 99 | self 100 | } 101 | 102 | pub fn data_len(mut self, data_len: u16) -> Self { 103 | self.data.as_mut().map(|e| { 104 | e.data_len = data_len; 105 | e 106 | }); 107 | self 108 | } 109 | 110 | pub fn build(mut self) -> BasicData { 111 | self.data.take().unwrap() 112 | } 113 | } -------------------------------------------------------------------------------- /src/protocol/answer/resource/cname.rs: -------------------------------------------------------------------------------- 1 | use crate::cursor::Cursor; 2 | use crate::protocol::answer::resource::{Resource, BasicData}; 3 | use crate::protocol::unzip_domain; 4 | 5 | #[derive(Debug, Eq, PartialEq, Clone)] 6 | pub struct CnameResource { 7 | basic: BasicData, 8 | data: String, 9 | } 10 | 11 | impl Resource for CnameResource { 12 | fn get_name(&self) -> &String { 13 | self.basic.get_name() 14 | } 15 | 16 | fn get_ttl(&self) -> u32 { 17 | self.basic.get_ttl() 18 | } 19 | 20 | fn get_data(&self) -> &String { 21 | &self.data 22 | } 23 | } 24 | 25 | impl CnameResource { 26 | pub fn create(basic: BasicData, cursor: &Cursor) -> Self { 27 | let data = unzip_domain(cursor); 28 | CnameResource { 29 | basic, 30 | data, 31 | } 32 | } 33 | } -------------------------------------------------------------------------------- /src/protocol/answer/resource/ipv4.rs: -------------------------------------------------------------------------------- 1 | use crate::protocol::answer::resource::{Resource, BasicData}; 2 | use crate::cursor::Cursor; 3 | use std::net::Ipv4Addr; 4 | use 
crate::cache::{IpCacheRecord, CacheItem}; 5 | use crate::protocol::answer::resource::basic::Builder; 6 | use crate::system::get_now; 7 | 8 | #[derive(Debug, Eq, PartialEq, Clone)] 9 | pub struct Ipv4Resource { 10 | basic: BasicData, 11 | pub data: Ipv4Addr, 12 | } 13 | 14 | impl Resource for Ipv4Resource { 15 | fn get_name(&self) -> &String { 16 | self.basic.get_name() 17 | } 18 | 19 | fn get_ttl(&self) -> u32 { 20 | self.basic.get_ttl() 21 | } 22 | 23 | fn get_data(&self) -> &Ipv4Addr { 24 | &self.data 25 | } 26 | } 27 | 28 | impl From<&Ipv4Resource> for Vec { 29 | fn from(r: &Ipv4Resource) -> Self { 30 | let data = &r.basic; 31 | let mut vec: Vec = data.into(); 32 | vec.extend(&r.data.octets()); 33 | vec 34 | } 35 | } 36 | 37 | impl From<&IpCacheRecord> for Ipv4Resource { 38 | fn from(record: &IpCacheRecord) -> Self { 39 | let basic = Builder::new() 40 | .name(record.get_key().clone()) 41 | .ttl((record.get_remain_time(get_now()) / 1000) as u32) 42 | ._type(1) 43 | .data_len(4) 44 | .build(); 45 | Ipv4Resource { 46 | basic, 47 | data: record.get_address().clone(), 48 | } 49 | } 50 | } 51 | 52 | impl Ipv4Resource { 53 | pub fn create(basic: BasicData, cursor: &Cursor) -> Self { 54 | let data = Ipv4Addr::from(cursor.take_bytes()); 55 | Ipv4Resource { 56 | basic, 57 | data, 58 | } 59 | } 60 | 61 | pub fn set_name(&mut self, name: String) { 62 | self.basic.set_name(name); 63 | } 64 | } -------------------------------------------------------------------------------- /src/protocol/answer/resource/mod.rs: -------------------------------------------------------------------------------- 1 | mod cname; 2 | mod soa; 3 | mod ipv4; 4 | mod basic; 5 | 6 | pub use cname::CnameResource; 7 | pub use soa::SoaResource; 8 | pub use ipv4::Ipv4Resource; 9 | pub use basic::BasicData; 10 | 11 | pub trait Resource { 12 | fn get_name(&self) -> &String; 13 | fn get_ttl(&self) -> u32; 14 | fn get_data(&self) -> &T; 15 | } 
-------------------------------------------------------------------------------- /src/protocol/answer/resource/soa.rs: -------------------------------------------------------------------------------- 1 | use crate::protocol::answer::resource::{Resource, BasicData}; 2 | use crate::cursor::Cursor; 3 | use crate::protocol::{unzip_domain, wrap_name}; 4 | use crate::protocol::answer::resource::basic; 5 | 6 | #[derive(Debug, Eq, PartialEq, Clone)] 7 | pub struct SoaResource { 8 | basic: BasicData, 9 | data: Soa, 10 | } 11 | 12 | impl Resource for SoaResource { 13 | fn get_name(&self) -> &String { 14 | self.basic.get_name() 15 | } 16 | 17 | fn get_ttl(&self) -> u32 { 18 | self.basic.get_ttl() 19 | } 20 | 21 | fn get_data(&self) -> &Soa { 22 | &self.data 23 | } 24 | } 25 | 26 | impl From<&SoaResource> for Vec { 27 | fn from(r: &SoaResource) -> Self { 28 | let data = &r.basic; 29 | let mut vec: Vec = data.into(); 30 | let soa = &r.data; 31 | let data_vec: Vec = soa.into(); 32 | vec.extend(data_vec); 33 | vec 34 | } 35 | } 36 | 37 | impl SoaResource { 38 | pub fn create(mut basic: BasicData, cursor: &Cursor) -> Self { 39 | let data = Soa::from(cursor); 40 | basic.set_data_len(data.len as u16); 41 | SoaResource { 42 | basic, 43 | data, 44 | } 45 | } 46 | 47 | pub fn new_with_default_soa(name: String, ttl: u32) -> Self { 48 | let soa = Soa::default(); 49 | let basic = basic::Builder::new() 50 | ._type(6) 51 | .ttl(ttl) 52 | .name(name) 53 | .data_len(soa.len as u16) 54 | .build(); 55 | SoaResource { 56 | basic, 57 | data: soa, 58 | } 59 | } 60 | 61 | pub fn set_name(&mut self, name: String) { 62 | self.basic.set_name(name) 63 | } 64 | } 65 | 66 | #[derive(Clone, Debug, Eq, PartialEq)] 67 | struct NameServer { 68 | domain: String, 69 | len: usize, 70 | } 71 | 72 | impl From<&Cursor> for NameServer { 73 | fn from(cursor: &Cursor) -> Self { 74 | if cursor.peek() == 0 { 75 | cursor.take(); 76 | NameServer { 77 | domain: ".".to_string(), 78 | len: 1, 79 | } 80 | } else { 81 | let 
domain = unzip_domain(cursor); 82 | let len = domain.len() + 2; 83 | NameServer { 84 | domain, 85 | len, 86 | } 87 | } 88 | } 89 | } 90 | 91 | impl From<&str> for NameServer { 92 | fn from(str: &str) -> Self { 93 | NameServer { 94 | domain: str.to_string(), 95 | len: 2 + str.len(), 96 | } 97 | } 98 | } 99 | 100 | impl From<&NameServer> for Vec { 101 | fn from(name_server: &NameServer) -> Self { 102 | let mut vec = Vec::new(); 103 | if name_server.domain.eq(".") { 104 | vec.push(0u8); 105 | } else { 106 | vec.extend(wrap_name(&name_server.domain)); 107 | } 108 | vec 109 | } 110 | } 111 | 112 | impl NameServer { 113 | fn len(&self) -> usize { 114 | self.len 115 | } 116 | } 117 | 118 | #[derive(Debug, Eq, PartialEq, Clone)] 119 | pub struct Soa { 120 | name_server: NameServer, 121 | mailbox: String, 122 | serial_number: u32, 123 | interval_refresh: u32, 124 | interval_retry: u32, 125 | expire_limit: u32, 126 | minimum_ttl: u32, 127 | len: usize, 128 | } 129 | 130 | impl From<&Soa> for Vec { 131 | fn from(s: &Soa) -> Self { 132 | let server = &s.name_server; 133 | let mut vec: Vec = server.into(); 134 | vec.extend(wrap_name(&s.mailbox)); 135 | vec.extend(&s.serial_number.to_be_bytes()); 136 | vec.extend(&s.interval_refresh.to_be_bytes()); 137 | vec.extend(&s.interval_retry.to_be_bytes()); 138 | vec.extend(&s.expire_limit.to_be_bytes()); 139 | vec.extend(&s.minimum_ttl.to_be_bytes()); 140 | vec 141 | } 142 | } 143 | 144 | impl Soa { 145 | fn from(cursor: &Cursor) -> Self { 146 | let name_server = NameServer::from(cursor); 147 | let mailbox = unzip_domain(cursor); 148 | let serial_number = u32::from_be_bytes(cursor.take_bytes()); 149 | let interval_refresh = u32::from_be_bytes(cursor.take_bytes()); 150 | let interval_retry = u32::from_be_bytes(cursor.take_bytes()); 151 | let expire_limit = u32::from_be_bytes(cursor.take_bytes()); 152 | let minimum_ttl = u32::from_be_bytes(cursor.take_bytes()); 153 | let len = name_server.len() + mailbox.len() + 2 + 20; 154 | Soa { 155 | 
name_server, 156 | mailbox, 157 | serial_number, 158 | interval_refresh, 159 | interval_retry, 160 | expire_limit, 161 | minimum_ttl, 162 | len, 163 | } 164 | } 165 | 166 | fn default() -> Self { 167 | let name_server = NameServer::from("dns17.hichina.com"); 168 | let len = name_server.len(); 169 | Soa { 170 | name_server, 171 | mailbox: "hostmaster.hichina.com".to_string(), 172 | serial_number: 1, 173 | interval_refresh: 3600, 174 | interval_retry: 1200, 175 | expire_limit: 3600, 176 | minimum_ttl: 600, 177 | len: len + 24 + 20, 178 | } 179 | } 180 | } 181 | -------------------------------------------------------------------------------- /src/protocol/answer/soa.rs: -------------------------------------------------------------------------------- 1 | use crate::protocol::answer::Answer; 2 | use crate::cache::{CacheRecord, SoaCacheRecord, CacheItem}; 3 | use crate::protocol::answer::resource::{SoaResource, Resource}; 4 | use std::fmt::{Display, Formatter}; 5 | use std::any::Any; 6 | use crate::protocol::basic::{BasicData, Builder}; 7 | use crate::system::get_now; 8 | 9 | pub struct SoaAnswer { 10 | data: BasicData, 11 | resource: SoaResource, 12 | } 13 | 14 | impl Display for SoaAnswer { 15 | fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { 16 | write!(f, "(SOA, {}, {})", self.data.get_name(), self.resource.get_ttl()) 17 | } 18 | } 19 | 20 | impl From<&SoaCacheRecord> for SoaAnswer { 21 | fn from(record: &SoaCacheRecord) -> Self { 22 | let data = Builder::new() 23 | .name(record.get_key().clone()) 24 | .flags(0x8180) 25 | .authority(1) 26 | .build(); 27 | let resource = SoaResource::new_with_default_soa( 28 | record.get_key().clone(), record.get_remain_time(get_now()) as u32 / 1000); 29 | SoaAnswer { 30 | data, 31 | resource, 32 | } 33 | } 34 | } 35 | 36 | impl Answer for SoaAnswer { 37 | fn to_cache(&self) -> Option { 38 | Some(SoaCacheRecord::from(self).into()) 39 | } 40 | 41 | fn to_bytes(&self) -> Vec { 42 | let data = &self.data; 43 | let resource1 = 
&self.resource; 44 | let mut vec: Vec = data.into(); 45 | let resource: Vec = resource1.into(); 46 | vec.extend(resource); 47 | vec 48 | } 49 | 50 | fn as_any(&self) -> &(dyn Any + Send + Sync) { 51 | self 52 | } 53 | 54 | fn as_mut_any(&mut self) -> &mut (dyn Any + Send + Sync) { 55 | self 56 | } 57 | 58 | fn set_id(&mut self, id: u16) { 59 | self.data.set_id(id); 60 | } 61 | 62 | fn get_id(&self) -> u16 { 63 | self.data.get_id() 64 | } 65 | } 66 | 67 | impl SoaAnswer { 68 | pub fn create(mut data: BasicData, mut resource: SoaResource) -> Self { 69 | data.set_answer_count(0); 70 | resource.set_name(data.get_name().clone()); 71 | SoaAnswer { 72 | data, 73 | resource, 74 | } 75 | } 76 | 77 | pub fn default_soa(id: u16, name: String) -> Self { 78 | let data = Builder::new() 79 | .id(id) 80 | .name(name.clone()) 81 | .flags(0x8180) 82 | .authority(1) 83 | .build(); 84 | SoaAnswer { 85 | data, 86 | resource: SoaResource::new_with_default_soa(name, 600), 87 | } 88 | } 89 | 90 | pub fn get_name(&self) -> &String { 91 | self.data.get_name() 92 | } 93 | 94 | pub fn get_ttl(&self) -> u32 { 95 | self.resource.get_ttl() 96 | } 97 | } -------------------------------------------------------------------------------- /src/protocol/basic.rs: -------------------------------------------------------------------------------- 1 | use crate::protocol::header::Header; 2 | use crate::protocol::question::Question; 3 | use crate::cursor::Cursor; 4 | 5 | #[derive(Clone, Eq, PartialEq, Debug)] 6 | pub struct BasicData { 7 | header: Header, 8 | question: Question, 9 | } 10 | 11 | impl From<&Cursor> for BasicData { 12 | fn from(cursor: &Cursor) -> Self { 13 | let header = Header::from(cursor); 14 | if header.question_count > 1 { 15 | panic!("不支持多个域名查询") 16 | } 17 | let question = Question::from(cursor); 18 | BasicData { 19 | header, 20 | question, 21 | } 22 | } 23 | } 24 | 25 | impl BasicData { 26 | pub fn set_id(&mut self, id: u16) { 27 | self.header.id = id; 28 | } 29 | pub fn 
set_answer_count(&mut self, count: u16) { 30 | self.header.answer_count = count 31 | } 32 | pub fn set_authority_count(&mut self, count: u16) { 33 | self.header.authority_count = count 34 | } 35 | pub fn get_id(&self) -> u16 { 36 | self.header.id 37 | } 38 | pub fn get_flags(&self) -> u16 { 39 | self.header.flags 40 | } 41 | pub fn get_name(&self) -> &String { 42 | &self.question.name 43 | } 44 | 45 | fn new() -> Self { 46 | let mut header = Header::new(); 47 | header.question_count = 1; 48 | let mut question = Question::new(); 49 | question._type = 1; 50 | question.class = 1; 51 | BasicData { 52 | header, 53 | question, 54 | } 55 | } 56 | 57 | pub fn get_answer_count(&self) -> u16 { 58 | self.header.answer_count 59 | } 60 | 61 | pub fn get_authority_count(&self) -> u16 { 62 | self.header.authority_count 63 | } 64 | } 65 | 66 | impl From<&BasicData> for Vec { 67 | fn from(data: &BasicData) -> Self { 68 | let header = &data.header; 69 | let mut header_vec: Vec = header.into(); 70 | let question = &data.question; 71 | let question_vec: Vec = question.into(); 72 | header_vec.extend(question_vec); 73 | header_vec 74 | } 75 | } 76 | 77 | pub struct Builder { 78 | data: Option, 79 | } 80 | 81 | impl Builder { 82 | pub fn new() -> Self { 83 | Builder { 84 | data: Some(BasicData::new()) 85 | } 86 | } 87 | 88 | pub fn id(mut self, id: u16) -> Self { 89 | self.data.as_mut().map(|e| { 90 | e.header.id = id; 91 | e 92 | }); 93 | self 94 | } 95 | 96 | pub fn flags(mut self, flags: u16) -> Self { 97 | self.data.as_mut().map(|e| { 98 | e.header.flags = flags; 99 | e 100 | }); 101 | self 102 | } 103 | 104 | pub fn name(mut self, name: String) -> Self { 105 | self.data.as_mut().map(|e| { 106 | e.question.name = name; 107 | e 108 | }); 109 | self 110 | } 111 | 112 | pub fn authority(mut self, count: u16) -> Self { 113 | self.data.as_mut().map(|e| { 114 | e.header.authority_count = count; 115 | e 116 | }); 117 | self 118 | } 119 | 120 | pub fn answer(mut self, count: u16) -> Self { 
121 | self.data.as_mut().map(|e| { 122 | e.header.answer_count = count; 123 | e 124 | }); 125 | self 126 | } 127 | 128 | pub fn build(mut self) -> BasicData { 129 | self.data.take().unwrap() 130 | } 131 | } -------------------------------------------------------------------------------- /src/protocol/header.rs: -------------------------------------------------------------------------------- 1 | use crate::cursor::Cursor; 2 | 3 | #[derive(Debug, Clone, Eq, PartialEq)] 4 | pub struct Header { 5 | pub id: u16, 6 | pub flags: u16, 7 | pub question_count: u16, 8 | pub answer_count: u16, 9 | pub authority_count: u16, 10 | pub additional_count: u16, 11 | } 12 | 13 | impl From<&Header> for Vec { 14 | fn from(header: &Header) -> Self { 15 | let mut result = Vec::with_capacity(12); 16 | result.extend(&header.id.to_be_bytes()); 17 | result.extend(&header.flags.to_be_bytes()); 18 | result.extend(&header.question_count.to_be_bytes()); 19 | result.extend(&header.answer_count.to_be_bytes()); 20 | result.extend(&header.authority_count.to_be_bytes()); 21 | result.extend(&header.additional_count.to_be_bytes()); 22 | result 23 | } 24 | } 25 | 26 | impl From<&Cursor> for Header { 27 | fn from(cursor: &Cursor) -> Self { 28 | let header = Header { 29 | id: u16::from_be_bytes([cursor.take(), cursor.take()]), 30 | flags: u16::from_be_bytes([cursor.take(), cursor.take()]), 31 | question_count: u16::from_be_bytes([cursor.take(), cursor.take()]), 32 | answer_count: u16::from_be_bytes([cursor.take(), cursor.take()]), 33 | authority_count: u16::from_be_bytes([cursor.take(), cursor.take()]), 34 | additional_count: 0, 35 | }; 36 | cursor.move_to(2); 37 | header 38 | } 39 | } 40 | 41 | impl Header { 42 | pub fn new() -> Self { 43 | Header { 44 | id: 0, 45 | flags: 0, 46 | question_count: 0, 47 | answer_count: 0, 48 | authority_count: 0, 49 | additional_count: 0, 50 | } 51 | } 52 | } -------------------------------------------------------------------------------- /src/protocol/mod.rs: 
--------------------------------------------------------------------------------
1 | mod header;
2 | mod question;
3 | mod answer;
4 | mod basic;
5 | mod query;
6 |
7 | use crate::cursor::Cursor;
8 |
9 | /// Mask of the two high bits that mark a compression pointer (RFC 1035 §4.1.4).
10 | const C_FACTOR: u8 = 192u8;
11 | /// Mask that clears the pointer marker bits, leaving the 14-bit offset.
12 | const DC_FACTOR: u16 = 16383u16;
13 | /// Upper bound on recursion (labels plus compression jumps) while decoding a
14 | /// single name. A legal name is at most 255 bytes and so has at most 127
15 | /// labels; only malformed input, such as a pointer that points at itself,
16 | /// reaches this bound, and without it such input recurses until the stack overflows.
17 | const MAX_NAME_DEPTH: usize = 255;
18 |
19 | pub use answer::{DnsAnswer, Ipv4Answer, FailureAnswer, SoaAnswer};
20 | pub use query::DnsQuery;
21 |
22 | /// Appends the labels of the (possibly compressed) name at `cursor` to
23 | /// `name_vec`, writing a `'.'` byte before each label.
24 | fn parse_name(cursor: &Cursor, name_vec: &mut Vec<u8>) {
25 |     parse_name_bounded(cursor, name_vec, 0);
26 | }
27 |
28 | /// Recursive worker for `parse_name`; `depth` counts the labels and pointer
29 | /// jumps followed so far, and decoding stops once it reaches `MAX_NAME_DEPTH`.
30 | fn parse_name_bounded(cursor: &Cursor, name_vec: &mut Vec<u8>, depth: usize) {
31 |     if depth >= MAX_NAME_DEPTH {
32 |         return;
33 |     }
34 |     if cursor.peek() & C_FACTOR == C_FACTOR {
35 |         // Compression pointer: consume both bytes, then continue decoding at
36 |         // the 14-bit offset via `tmp_at` (assumed to leave this cursor's
37 |         // position untouched — confirm in the cursor module).
38 |         let c_index = u16::from_be_bytes([cursor.take(), cursor.take()]);
39 |         cursor.tmp_at((c_index & DC_FACTOR) as usize, |buf| {
40 |             parse_name_bounded(buf, name_vec, depth + 1);
41 |         })
42 |     } else {
43 |         let seg_len = cursor.take();
44 |         if seg_len > 0 {
45 |             let segment = cursor.take_slice(seg_len as usize);
46 |             name_vec.push(b'.');
47 |             name_vec.extend(segment);
48 |             parse_name_bounded(cursor, name_vec, depth + 1)
49 |         }
50 |     };
51 | }
52 |
53 | /// Decodes the (possibly compressed) domain name at `cursor` into dotted form
54 | /// without a leading dot, e.g. `www.example.com`. The root name decodes to an
55 | /// empty string, and bytes that are not valid UTF-8 are replaced with U+FFFD
56 | /// rather than panicking, because the bytes come from parsed DNS messages.
57 | fn unzip_domain(cursor: &Cursor) -> String {
58 |     let mut domain_vec = Vec::new();
59 |     parse_name(cursor, &mut domain_vec);
60 |     // Each label was prefixed with '.', so skip the first byte; the root name
61 |     // produces no labels and therefore nothing to skip.
62 |     let dotted = domain_vec.get(1..).unwrap_or(&[]);
63 |     String::from_utf8_lossy(dotted).into_owned()
64 | }
65 |
66 | /// Encodes a dotted domain name into DNS wire format: each label preceded by
67 | /// its length byte, terminated by a zero byte. Empty labels are skipped, so
68 | /// the root name `""` encodes as a lone `0` and a trailing dot is tolerated.
69 | // NOTE(review): labels longer than 63 bytes (the DNS limit) are not rejected
70 | // here and yield an invalid length byte.
71 | fn wrap_name(name: &String) -> Vec<u8> {
72 |     let mut vec = Vec::with_capacity(name.len() + 2);
73 |     for label in name.split('.').filter(|label| !label.is_empty()) {
74 |         vec.push(label.len() as u8);
75 |         vec.extend(label.bytes());
76 |     }
77 |     vec.push(0);
78 |     vec
79 | }
80 |
81 |
--------------------------------------------------------------------------------
/src/protocol/query/mod.rs:
--------------------------------------------------------------------------------
1 | use crate::protocol::basic::BasicData;
2 | use crate::protocol::{basic};
3 | use crate::system::{QueryBuf, next_id};
4 | use crate::cursor::Cursor;
5 |
6 | /// Header flags of a standard query with only the RD (recursion desired) bit set.
7 | const QUERY_ONLY_RECURSIVELY: u16 = 0x0100;
8 | /// Header flags of a standard query with RD plus the AD bit (0x0020) set.
9 | const QUERY_RECURSIVELY_AD: u16 = 0x0120;
10 |
11 | /// A DNS query message: a header plus a single question.
12 | #[derive(Clone, Eq, PartialEq, Debug)]
13 | pub struct DnsQuery {
14 |     basic: BasicData,
15 | }
16 |
17 | impl DnsQuery {
18 |     /// Transaction id from the header.
19 |     pub fn get_id(&self) -> u16 {
20 |         self.basic.get_id()
21 |     }
22 |     /// Overwrites the transaction id in the header.
23 |     pub fn set_id(&mut self, id: u16) {
24 |         self.basic.set_id(id)
25 |     }
26 |     /// Queried domain name in dotted form.
27 |     pub fn get_name(&self) -> &String {
28 |         self.basic.get_name()
29 |     }
30 |
31 |     /// True only for the two flag combinations handled here: RD alone, or
32 |     /// RD together with AD.
33 |     pub fn is_supported(&self) -> bool {
34 |         let flags = self.basic.get_flags();
35 |         flags == QUERY_ONLY_RECURSIVELY || flags == QUERY_RECURSIVELY_AD
36 |     }
37 | }
38 |
39 | /// Parses a query from a raw receive buffer.
40 | impl From<QueryBuf> for DnsQuery {
41 |     fn from(buf: QueryBuf) -> Self {
42 |         let cursor = Cursor::form(buf.into());
43 |         DnsQuery {
44 |             basic: BasicData::from(&cursor)
45 |         }
46 |     }
47 | }
48 |
49 | /// Builds a fresh recursive query (type 1, class 1) for `domain`, using the
50 | /// next id from `next_id`.
51 | impl From<&str> for DnsQuery {
52 |     fn from(domain: &str) -> Self {
53 |         let basic = basic::Builder::new()
54 |             .id(next_id())
55 |             .name(domain.to_string())
56 |             .flags(QUERY_ONLY_RECURSIVELY)
57 |             .build();
58 |         DnsQuery {
59 |             basic
60 |         }
61 |     }
62 | }
63 |
64 | /// Serializes the query (header followed by question) to wire bytes.
65 | impl From<&DnsQuery> for Vec<u8> {
66 |     fn from(query: &DnsQuery) -> Self {
67 |         let data = &query.basic;
68 |         data.into()
69 |     }
70 | }
--------------------------------------------------------------------------------
/src/protocol/question.rs:
--------------------------------------------------------------------------------
1 | use crate::cursor::Cursor;
2 | use crate::protocol::{unzip_domain, wrap_name};
3 |
4 | /// One entry of a DNS message's question section.
5 | #[derive(Debug, Clone, Eq, PartialEq)]
6 | pub struct Question {
7 |     pub name: String,
8 |     pub _type: u16,
9 |     pub class: u16,
10 | }
11 |
12 | /// Wire encoding: the encoded name, then type and class as big-endian u16.
13 | impl From<&Question> for Vec<u8> {
14 |     fn from(question: &Question) -> Self {
15 |         let mut result = Vec::new();
16 |         result.extend(wrap_name(&question.name));
17 |         result.extend(&question._type.to_be_bytes());
18 |         result.extend(&question.class.to_be_bytes());
19 |         result
20 |     }
21 | }
22 |
23 | /// Reads name, type and class starting at the cursor's current position.
24 | impl From<&Cursor> for Question {
25 |     fn from(cursor: &Cursor) -> Self {
26 |         let name = unzip_domain(cursor);
27 |         let _type = u16::from_be_bytes([cursor.take(), cursor.take()]);
28 |         let class = u16::from_be_bytes([cursor.take(), cursor.take()]);
29 |         Question {
30 |             name,
31 |             _type,
32 |             class,
33 |         }
34 |     }
35 | }
36 |
37 | impl Question {
38 |     // NOTE(review): placeholder — currently accepts every question.
39 |     fn is_legal(&self) -> bool {
40 |         true
41 |     }
42 |
43 |     /// True for legal questions of type 1 (A) and class 1 (IN).
44 |     pub fn is_supported(&self) -> bool {
45 |         self.is_legal()
46 |             && self._type == 1
47 |             && self.class == 1
48 |     }
49 |
50 |     /// An empty question: no name, type 0, class 0.
51 |     pub fn new() -> Self {
52 |         Question {
53 |             name: String::new(),
54 |             _type: 0,
55 |             class: 0,
56 |         }
57 |     }
58 | }
--------------------------------------------------------------------------------
/src/system.rs:
--------------------------------------------------------------------------------
1 | use std::error::Error;
2 | use std::time::{Duration};
3 | use std::sync::atomic::{AtomicU16, Ordering};
4 | use crate::config::Config;
5 | use log::LevelFilter;
6 | use std::str::FromStr;
7 | use std::fmt::{Debug, Formatter, Display};
8 | use std::cell::RefCell;
9 |
10 | pub type Result = core::result::Result>;
11 | pub type QueryBuf = [u8; 256];
12 | pub type AnswerBuf = [u8; 512];
13 |
14 | /// Returns an array of `N` copies of `T::default()`.
15 | pub fn default_value<T, const N: usize>() -> [T; N] where T: Default + Copy {
16 |     [T::default(); N]
17 | }
18 |
19 | /// Error for a file at `path` that could not be found; `supper` carries the
20 | /// underlying error (presumably the I/O error — confirm at construction sites).
21 | #[derive(Debug)]
22 | pub struct FileNotFoundError {
23 |     pub path: String,
24 |     pub supper: Box,
25 | }
26 |
27 | impl Display for FileNotFoundError {
28 |     fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
29 |         // Name the missing file instead of printing a fixed placeholder.
30 |         write!(f, "file not found: {}", self.path)
31 |     }
32 | }
33 |
34 | impl Error for FileNotFoundError {}
35 |
36 | /// Clock behind `get_now`/`get_sub_now`. Test builds return a settable fixed
37 | /// `timestamp`; other builds read the system clock. A pending
38 | /// `add_duration`/`sub_duration` is applied to the next reading only.
39 | pub struct TimeNow {
40 |     #[cfg(test)] timestamp: u128,
41 |     add_duration: Option<Duration>,
42 |     sub_duration: Option<Duration>,
43 | }
44 |
45 | impl TimeNow {
46 |     #[cfg(test)]
47 |     pub fn get(&mut self) -> u128 {
48 |         let add = self.add_duration.take().map(|d| d.as_millis()).unwrap_or(0);
49 |         let sub = self.sub_duration.take().map(|d| d.as_millis()).unwrap_or(0);
50 |         self.timestamp + add - sub
51 |     }
52 |     #[cfg(test)]
53 |     pub fn set_timestamp(&mut self, timestamp: u128) -> &mut Self {
54 |         self.timestamp = timestamp;
55 |         self
56 |     }
57 |
58 |     #[cfg(test)]
59 |     pub fn new() -> Self {
60 |         TimeNow {
61 |             timestamp: 0,
62 |             add_duration: None,
63 |             sub_duration: None,
64 |         }
65 |     }
66 |
67 |     /// Current Unix time in milliseconds, adjusted by any pending offsets.
68 |     #[cfg(not(test))]
69 |     pub fn get(&mut self) -> u128 {
70 |         let current_time = get_timestamp();
71 |         let add = self.add_duration.take().map(|d| d.as_millis()).unwrap_or(0);
72 |         let sub = self.sub_duration.take().map(|d| d.as_millis()).unwrap_or(0);
73 |         current_time + add - sub
74 |     }
75 |
76 |     #[cfg(not(test))]
77 |     pub fn new() -> Self {
78 |         TimeNow {
79 |             add_duration: None,
80 |             sub_duration: None,
81 |         }
82 |     }
83 |
84 |     /// Schedules `d` to be subtracted from the next reading.
85 |     pub fn sub(&mut self, d: Duration) -> &mut Self {
86 |         self.sub_duration = Some(d);
87 |         self
88 |     }
89 | }
90 |
91 | thread_local! {
92 |     pub static TIME: RefCell<TimeNow> = RefCell::new(TimeNow::new());
93 | }
94 |
95 | /// Current time in milliseconds from this thread's `TIME` clock.
96 | pub fn get_now() -> u128 {
97 |     TIME.with(|r| {
98 |         r.borrow_mut().get()
99 |     })
100 | }
101 |
102 | /// Current time in milliseconds minus `d`.
103 | pub fn get_sub_now(d: Duration) -> u128 {
104 |     TIME.with(|r| {
105 |         r.borrow_mut().sub(d).get()
106 |     })
107 | }
108 |
109 | #[cfg(not(test))]
110 | use std::time::{SystemTime, UNIX_EPOCH};
111 | use tokio::runtime::Handle;
112 | use std::future::Future;
113 |
114 | #[cfg(not(test))]
115 | fn get_timestamp() -> u128 {
116 |     SystemTime::now()
117 |         .duration_since(UNIX_EPOCH)
118 |         .expect("Time went backwards").as_millis()
119 | }
120 |
121 | #[cfg(test)]
122 | pub fn set_time_base(base: u128) {
123 |     TIME.with(|r| {
124 |         r.borrow_mut().set_timestamp(base);
125 |     })
126 | }
127 |
128 | static ID: AtomicU16 = AtomicU16::new(0);
129 |
130 | /// Returns the next query id from a process-wide counter.
131 | ///
132 | /// Ids are handed out as 0, 1, 2, ... and the counter resets to 0 once it
133 | /// exceeds `u16::MAX - 10000`, so `x + 1` can never overflow.
134 | pub fn next_id() -> u16 {
135 |     let advance = |x: u16| {
136 |         if x > u16::MAX - 10000 {
137 |             Some(0)
138 |         } else {
139 |             Some(x + 1)
140 |         }
141 |     };
142 |     // `advance` always yields `Some`, so `fetch_update` cannot actually fail;
143 |     // either way the previous counter value is the id handed out.
144 |     match ID.fetch_update(Ordering::SeqCst, Ordering::Relaxed, advance) {
145 |         Ok(id) | Err(id) => id,
146 |     }
147 | }
148 |
149 | /// Installs a panic hook that logs the panic message and its location.
150 | pub fn setup_panic_hook() {
151 |     std::panic::set_hook(Box::new(|panic_info| {
152 |         error!("panic message: {:?}, location in {:?}", panic_info.message(), panic_info.location());
153 |     }));
154 | }
155 |
156 | /// Sets the global `log` max level from `config.log_level`.
157 | ///
158 | /// # Errors
159 | /// Fails if the configured level string does not parse as a `LevelFilter`.
160 | pub fn setup_log_level(config: &Config) -> Result<()> {
161 |     let level = LevelFilter::from_str(&config.log_level)?;
162 |     log::set_max_level(level);
163 |     Ok(())
164 | }
165 |
166 | /// Synchronously drives `future` to completion on the current Tokio runtime.
167 | /// NOTE(review): per Tokio's docs, `block_in_place` panics when called on a
168 | /// current-thread runtime; callers must be on a multi-threaded runtime.
169 | pub fn block_on<F: Future>(future: F) -> F::Output {
170 |     tokio::task::block_in_place(move || {
171 |         Handle::current().block_on(async move {
172 |             future.await
173 |         })
174 |     })
175 | }
--------------------------------------------------------------------------------
/tests/resources/covercast_filter.txt:
--------------------------------------------------------------------------------
1 | #TITLE=anti-AD for SmartDNS
2 | #VER=20210806113134
3 | #URL=https://github.com/privacy-protection-tools/anti-AD
4 | #TOTAL_LINES=44642
5 | address /kwcdn.000dn.com/d
--------------------------------------------------------------------------------
/tests/resources/test_filter.txt:
--------------------------------------------------------------------------------
1 | #TITLE=anti-AD for SmartDNS
2 | #VER=20210806113134
3 | #URL=https://github.com/privacy-protection-tools/anti-AD
4 | #TOTAL_LINES=44642
5 | address /00-gov.cn/#
6 | address /kwcdn.000dn.com/#
--------------------------------------------------------------------------------
/tests/test.rs:
--------------------------------------------------------------------------------
1 | #[test]
2 | fn test() {
3 |     assert_eq!(1 + 1, 2)
4 | }
--------------------------------------------------------------------------------