├── .gitignore ├── Cargo.toml ├── LICENSE-APACHE ├── LICENSE-MIT ├── README.md ├── examples └── search.rs ├── src ├── action │ ├── bootstrap.rs │ ├── lookup.rs │ ├── mod.rs │ └── refresh.rs ├── bucket.rs ├── compact.rs ├── handler.rs ├── info_hash.rs ├── lib.rs ├── mainline_dht.rs ├── message.rs ├── node.rs ├── router.rs ├── socket.rs ├── storage.rs ├── table.rs ├── test.rs ├── time.rs ├── timer.rs ├── token.rs └── transaction.rs └── tests └── tests.rs /.gitignore: -------------------------------------------------------------------------------- 1 | # Compiled files 2 | *.o 3 | *.so 4 | *.rlib 5 | *.dll 6 | 7 | # Executables 8 | *.exe 9 | 10 | # Generated by Cargo 11 | target/ 12 | Cargo.lock 13 | 14 | # Generated by Rustfmt 15 | *.bk 16 | 17 | # Test Files 18 | spike/ 19 | *.torrent 20 | 21 | # Visual Studio Code Files 22 | .vscode 23 | -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "btdht" 3 | version = "1.0.1" 4 | description = "Implementation of the bittorrent mainline DHT" 5 | authors = ["Andrew ", "Adam Cigánek "] 6 | license = "MIT/Apache-2.0" 7 | edition = "2021" 8 | 9 | [dependencies] 10 | async-trait = "0.1.83" 11 | crc32c = "0.6.8" 12 | futures-util = { version = "0.3.31", default-features = false, features = ["alloc"] } 13 | log = "0.4.22" 14 | rand = "0.8.5" 15 | # FIXME: serde >= 1.0.181 breaks some tests. Pinning to 1.0.180 for now. 
16 | serde = { version = "=1.0.180", features = ["derive"] } 17 | serde_bencode = { package = "torrust-serde-bencode", version = "0.2.3" } 18 | serde_bytes = "0.11.15" 19 | sha-1 = "0.10.1" 20 | tokio = { version = "1.32", default-features = false, features = ["macros", "net", "rt", "rt-multi-thread", "sync", "time"] } 21 | thiserror = "1.0.64" 22 | 23 | [dev-dependencies] 24 | hex = "0.4.3" 25 | pretty_env_logger = "0.5.0" 26 | tokio = { version = "1.41.0", features = ["io-std", "io-util"] } 27 | test-log = "0.2.14" 28 | -------------------------------------------------------------------------------- /LICENSE-APACHE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 
29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 
61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [2016] [bip-rs Developers] 190 | Copyright [2021] [equalit.ie] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. -------------------------------------------------------------------------------- /LICENSE-MIT: -------------------------------------------------------------------------------- 1 | Copyright (c) 2016 bip-rs Developers 2 | Copyright (c) 2021 equalit.ie 3 | 4 | Permission is hereby granted, free of charge, to any 5 | person obtaining a copy of this software and associated 6 | documentation files (the "Software"), to deal in the 7 | Software without restriction, including without 8 | limitation the rights to use, copy, modify, merge, 9 | publish, distribute, sublicense, and/or sell copies of 10 | the Software, and to permit persons to whom the Software 11 | is furnished to do so, subject to the following 12 | conditions: 13 | 14 | The above copyright notice and this permission notice 15 | shall be included in all copies or substantial portions 16 | of the Software. 17 | 18 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF 19 | ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED 20 | TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A 21 | PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT 22 | SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY 23 | CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION 24 | OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR 25 | IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 26 | DEALINGS IN THE SOFTWARE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Bittorrent Mainline DHT (btdht) 2 | 3 | Implementation of the bittorrent mainline dht. Originally forked from [bip-rs](https://github.com/GGist/bip-rs) 4 | 5 | ## References 6 | 7 | - [DHT Protocol](https://www.bittorrent.org/beps/bep_0005.html) 8 | - [DHT Extensions for IPv6](https://www.bittorrent.org/beps/bep_0032.html) 9 | 10 | ## Terminology 11 | 12 | **Lookup**: Refers to the process of iteratively querying peers in the DHT to see if they have contact information for other peers 13 | that have announced themselves for a given info hash. 14 | 15 | **Announce**: Refers to the process of querying peers in the DHT, and telling the closest few nodes that you are interested in peers 16 | looking for a particular info hash, and that the node should store your contact information for later nodes that reach any of the nodes 17 | announced to. 18 | 19 | ## Important Usage Information 20 | - **Before The Bootstrap**: It is always a good idea to start up the DHT ahead of time if you know you will need it later in your 21 | application. This is because the DHT will not immediately be usable by us until bootstrapping has completed, you can feel free to 22 | make requests, but they will be executed after the bootstrap has finished (which may take up to 30 seconds). 23 | 24 | - **Announce Expire**: Nodes in the DHT will expire the contact information for announces that have gone stale (haven't heard from again 25 | for a while). This means if you are still looking for peers, you will want to announce periodically. 
All nodes have different expire 26 | times, the spec mentions the 24 hour expire period, however, you may want to announce more often than that as peers are constantly leaving 27 | and joining the DHT, so if the nodes you announced to all left the DHT, you would be out of luck. Luckily, for each announce, we do 28 | replicate your contact information to multiple of the closest nodes. 29 | 30 | - **Read Only Nodes**: By default, all nodes created are read only; this means that the node will not respond to requests. In theory 31 | this sounds good, however, in practice this means it will be harder (but possible) to keep a healthy routing table, especially for 32 | nodes that wish to run for long periods of time. I strongly encourage users who will be running nodes for long periods of time to 33 | set up some sort of nat traversal/port forwarding to the source address of the DHT and set read only to false (it is true by default). 34 | 35 | - **Source Port vs Connect Port**: One thing you should note is that, by either implementation error or intentionally, if the port that 36 | the DHT is bound to is different than the port that we want nodes to connect to us on (our announce/connect port) some nodes will 37 | incorrectly store the source port that we used to send the announce message instead of the port specified in the message. This is not 38 | a big deal as most nodes handle this correctly ( I have only seen a few that screw this up). If you are receiving TCP connections requests 39 | on the wrong port (the DHT source port), this is most likely why. 40 | 41 | - **DHT Spam**: Many nodes in the DHT will ban nodes that they feel are malicious. This includes sending a high number 42 | of requests, most likely for the same info hash, to the same node. As a user, you will not have control over what nodes we contact in a 43 | lookup/announce. 
Over time, we will get better at making sure our clients don't get banned, but to do your part, do not send an excessive 44 | amount of lookups/announces for the same info hash in a short period of time. Symptoms of getting banned include receiving fewer and fewer 45 | contacts back when doing a search for an info hash. If you feel you have gotten banned, you can always restart the DHT since all nodes 46 | (should) treat the (node id, source address) as the unique identifier for nodes and we always get a new node id on startup. 47 | 48 | - **Sloppy DHT**: The kademlia DHT is also referred to as a sloppy DHT. This means that you will be able to find most (if not all) 49 | nodes that announce for a given info hash. To make your applications more robust (and this is what torrent clients do), you should 50 | develop a mechanism for receiving the contact information for other peers from peers themselves. This means that if you had two 51 | segmented swarms of peers, only one person from one swarm has to be aware of one person from another swarm in order to join the 52 | two swarms so that everyone knows of everyone else. 53 | 54 | ## License 55 | 56 | Licensed under either of 57 | 58 | * Apache License, Version 2.0, ([LICENSE-APACHE](LICENSE-APACHE) or http://www.apache.org/licenses/LICENSE-2.0) 59 | * MIT license ([LICENSE-MIT](LICENSE-MIT) or http://opensource.org/licenses/MIT) 60 | 61 | at your option. 62 | 63 | ## Contribution 64 | 65 | Unless you explicitly state otherwise, any contribution intentionally submitted 66 | for inclusion in the work by you, as defined in the Apache-2.0 license, shall be dual licensed as above, without any 67 | additional terms or conditions. 
68 | 69 | -------------------------------------------------------------------------------- /examples/search.rs: -------------------------------------------------------------------------------- 1 | use btdht::{router, InfoHash, LengthError, MainlineDht}; 2 | use futures_util::StreamExt; 3 | use std::{ 4 | collections::HashSet, 5 | convert::TryFrom, 6 | net::{Ipv4Addr, SocketAddr}, 7 | str::FromStr, 8 | time::Instant, 9 | }; 10 | use tokio::{ 11 | io::{self, AsyncBufReadExt, AsyncWriteExt, BufReader}, 12 | net::UdpSocket, 13 | }; 14 | 15 | #[tokio::main] 16 | async fn main() { 17 | pretty_env_logger::init(); 18 | 19 | let addr = SocketAddr::from((Ipv4Addr::UNSPECIFIED, 0)); 20 | let socket = UdpSocket::bind(addr).await.unwrap(); 21 | 22 | let dht = MainlineDht::builder() 23 | .add_routers([router::BITTORRENT_DHT, router::TRANSMISSION_DHT]) 24 | .set_read_only(false) 25 | .start(socket) 26 | .unwrap(); 27 | 28 | println!("bootstrapping..."); 29 | let start = Instant::now(); 30 | let status = dht.bootstrapped().await; 31 | let elapsed = start.elapsed(); 32 | 33 | if status { 34 | println!( 35 | "bootstrap completed in {}.{:03} seconds", 36 | elapsed.as_secs(), 37 | elapsed.subsec_millis() 38 | ); 39 | } else { 40 | println!( 41 | "bootstrap failed in {}.{:03} seconds", 42 | elapsed.as_secs(), 43 | elapsed.subsec_millis() 44 | ); 45 | return; 46 | } 47 | 48 | let mut stdout = io::stdout(); 49 | let mut stdin = BufReader::new(io::stdin()); 50 | let mut line = String::new(); 51 | 52 | loop { 53 | stdout.write_all(b"> ").await.unwrap(); 54 | stdout.flush().await.unwrap(); 55 | 56 | line.clear(); 57 | 58 | if stdin.read_line(&mut line).await.unwrap() > 0 { 59 | if !handle_command(&dht, &line).await.unwrap() { 60 | break; 61 | } 62 | } else { 63 | break; 64 | } 65 | } 66 | } 67 | 68 | async fn handle_command(dht: &MainlineDht, command: &str) -> io::Result { 69 | match command.parse() { 70 | Ok(Command::Help) => { 71 | println!(" h shows this help message"); 72 | println!(" s 
search for the specified info hash"); 73 | println!(" a announce the specified info hash"); 74 | println!(" q quit"); 75 | println!(); 76 | println!( 77 | "Note: can be specified either as a 40-character hexadecimal string or \ 78 | an arbitrary string prefixed with '#'. In the first case it is interpreted \ 79 | directly as the info hash, in the second the info hash is obtained by computing a \ 80 | SHA-1 digest of the string excluding the leading '#' and trimming any leading or \ 81 | trailing whitespace." 82 | ); 83 | 84 | Ok(true) 85 | } 86 | Ok(Command::Search { 87 | info_hash, 88 | announce, 89 | }) => { 90 | if announce { 91 | println!("announcing {info_hash:?}...") 92 | } else { 93 | println!("searching for {info_hash:?}...") 94 | } 95 | 96 | let mut peers = HashSet::new(); 97 | let start = Instant::now(); 98 | 99 | let mut search = dht.search(info_hash, announce); 100 | 101 | while let Some(addr) = search.next().await { 102 | if peers.insert(addr) { 103 | println!("peer found: {addr}"); 104 | } 105 | } 106 | 107 | let elapsed = start.elapsed(); 108 | println!( 109 | "search completed: found {} peers in {}.{:03} seconds", 110 | peers.len(), 111 | elapsed.as_secs(), 112 | elapsed.subsec_millis() 113 | ); 114 | 115 | Ok(true) 116 | } 117 | Ok(Command::Quit) => Ok(false), 118 | Err(_) => { 119 | println!("invalid command (use 'h' for help)"); 120 | Ok(true) 121 | } 122 | } 123 | } 124 | 125 | enum Command { 126 | Help, 127 | Search { info_hash: InfoHash, announce: bool }, 128 | Quit, 129 | } 130 | 131 | impl FromStr for Command { 132 | type Err = ParseError; 133 | 134 | fn from_str(s: &str) -> Result { 135 | match &s[..1] { 136 | "h" | "?" 
=> Ok(Self::Help), 137 | "s" => Ok(Self::Search { 138 | info_hash: parse_info_hash(s[1..].trim())?, 139 | announce: false, 140 | }), 141 | "a" => Ok(Self::Search { 142 | info_hash: parse_info_hash(s[1..].trim())?, 143 | announce: true, 144 | }), 145 | "q" => Ok(Self::Quit), 146 | _ => Err(ParseError), 147 | } 148 | } 149 | } 150 | 151 | struct ParseError; 152 | 153 | impl From for ParseError { 154 | fn from(_: hex::FromHexError) -> Self { 155 | ParseError 156 | } 157 | } 158 | 159 | impl From for ParseError { 160 | fn from(_: LengthError) -> Self { 161 | ParseError 162 | } 163 | } 164 | 165 | fn parse_info_hash(s: &str) -> Result { 166 | if &s[..1] == "#" { 167 | Ok(InfoHash::sha1(s[1..].trim().as_bytes())) 168 | } else { 169 | Ok(InfoHash::try_from(hex::decode(s)?.as_ref())?) 170 | } 171 | } 172 | -------------------------------------------------------------------------------- /src/action/bootstrap.rs: -------------------------------------------------------------------------------- 1 | use super::{resolve, IpVersion, Responded, Socket}; 2 | use crate::bucket::Bucket; 3 | use crate::message::{FindNodeRequest, Message, MessageBody, Request}; 4 | use crate::node::{Node, NodeStatus}; 5 | use crate::table::{self, RoutingTable}; 6 | use crate::transaction::{MIDGenerator, TransactionID}; 7 | use crate::{info_hash::NodeId, node::NodeHandle}; 8 | use futures_util::{stream::FuturesUnordered, StreamExt}; 9 | use std::{ 10 | collections::HashSet, 11 | net::SocketAddr, 12 | pin::pin, 13 | sync::{Arc, Mutex}, 14 | time::Duration, 15 | }; 16 | use tokio::{ 17 | select, 18 | sync::{mpsc, watch}, 19 | task, time, 20 | time::sleep, 21 | }; 22 | 23 | const INITIAL_TIMEOUT: Duration = Duration::from_millis(2500); 24 | const NODE_TIMEOUT: Duration = Duration::from_millis(500); 25 | const NO_NETWORK_TIMEOUT: Duration = Duration::from_secs(5); 26 | const PERIODIC_CHECK_TIMEOUT: Duration = Duration::from_secs(5); 27 | 28 | // We try to rebootstrap when we have fewer nodes than this. 
29 | const GOOD_NODE_THRESHOLD: usize = 10; 30 | 31 | const PINGS_PER_BUCKET: usize = 8; 32 | const MAX_INITIAL_RESPONSES: usize = 8; 33 | 34 | pub(crate) struct TableBootstrap { 35 | start_tx: watch::Sender, 36 | pub state_rx: watch::Receiver, 37 | worker_handle: task::JoinHandle<()>, 38 | } 39 | 40 | impl Drop for TableBootstrap { 41 | fn drop(&mut self) { 42 | self.worker_handle.abort(); 43 | } 44 | } 45 | 46 | struct TableBootstrapInner { 47 | this_node_id: NodeId, 48 | ip_version: IpVersion, 49 | routers: HashSet, 50 | id_generator: Mutex, 51 | starting_nodes: HashSet, 52 | table: Arc>, 53 | socket: Arc, 54 | start_rx: watch::Receiver, 55 | state_tx: watch::Sender, 56 | } 57 | 58 | #[derive(Eq, PartialEq, Copy, Clone, Debug)] 59 | pub enum State { 60 | AwaitStart, 61 | InitialContact, 62 | Bootstrapping, 63 | Bootstrapped, 64 | // The starting state or state after a bootstrap has failed and new has been cheduled after a 65 | // timeout. 66 | IdleBeforeRebootstrap, 67 | } 68 | 69 | impl TableBootstrap { 70 | pub fn new( 71 | socket: Arc, 72 | table: Arc>, 73 | id_generator: MIDGenerator, 74 | routers: HashSet, 75 | nodes: HashSet, 76 | ) -> TableBootstrap { 77 | let this_node_id = table.lock().unwrap().node_id(); 78 | 79 | let (start_tx, start_rx) = watch::channel(false); 80 | let (state_tx, state_rx) = watch::channel(State::AwaitStart); 81 | 82 | let inner = TableBootstrapInner { 83 | this_node_id, 84 | ip_version: socket.ip_version(), 85 | routers, 86 | id_generator: Mutex::new(id_generator), 87 | starting_nodes: nodes, 88 | table, 89 | socket, 90 | start_rx, 91 | state_tx, 92 | }; 93 | 94 | let worker_handle = task::spawn(inner.run(this_node_id)); 95 | 96 | TableBootstrap { 97 | start_tx, 98 | state_rx, 99 | //inner, 100 | worker_handle, 101 | } 102 | } 103 | 104 | /// Return true if the bootstrap state changed. 105 | pub fn start(&self) { 106 | // Unwrap OK because the runner exits only if the `start_tx` is destroyed. 
107 | self.start_tx.send(true).unwrap(); 108 | } 109 | } 110 | 111 | impl TableBootstrapInner { 112 | // Return true we switched between being bootsrapped and not being bootstrapped. 113 | fn set_state(&self, new_state: State, from_line: u32) { 114 | let old_state = *self.state_tx.borrow(); 115 | 116 | if old_state == new_state { 117 | return; 118 | } 119 | 120 | self.state_tx.send(new_state).unwrap_or(()); 121 | 122 | log::info!( 123 | "{}: TableBootstrap state change {:?} -> {:?} (from_line: {})", 124 | self.ip_version, 125 | old_state, 126 | new_state, 127 | from_line 128 | ); 129 | } 130 | 131 | async fn run(mut self, table_id: NodeId) { 132 | loop { 133 | match self.start_rx.changed().await { 134 | Ok(()) => { 135 | if *self.start_rx.borrow() { 136 | break; 137 | } 138 | } 139 | Err(_) => return, 140 | } 141 | } 142 | 143 | let mut bootstrap_attempt = 0; 144 | 145 | loop { 146 | // If we have no bootstrap contacts it means we are the first node in the network and 147 | // other would bootstrap against us. We consider this node as already bootstrapped. 148 | if self.routers.is_empty() && self.starting_nodes.is_empty() { 149 | self.set_state(State::Bootstrapped, line!()); 150 | std::future::pending::<()>().await; 151 | unreachable!(); 152 | } 153 | 154 | let router_addresses = resolve(&self.routers, self.socket.ip_version()).await; 155 | self.table.lock().unwrap().routers = router_addresses.clone(); 156 | 157 | if router_addresses.is_empty() && self.starting_nodes.is_empty() { 158 | self.set_state(State::IdleBeforeRebootstrap, line!()); 159 | sleep(NO_NETWORK_TIMEOUT).await; 160 | continue; 161 | } 162 | 163 | self.set_state(State::InitialContact, line!()); 164 | log::debug!( 165 | "Have {} routers and {} starting nodes", 166 | router_addresses.len(), 167 | self.starting_nodes.len(), 168 | ); 169 | 170 | // In the initial round, we send the requests to contacts (nodes and routers) who are not in 171 | // our routing table. 
Because of that, we don't care who we receive a response from, only 172 | // that we receive sufficient number of unique ones. Thus we use the same transaction id 173 | // for all of them. 174 | // After the initial round we are sending only to nodes from the routing table, so we use 175 | // unique transaction id per node. 176 | let trans_id = self.id_generator.lock().unwrap().generate(); 177 | 178 | let find_node_msg = Self::make_find_node_request(trans_id, table_id, table_id); 179 | 180 | let mut receivers = FuturesUnordered::new(); 181 | let (new_receivers_tx, mut new_receivers_rx) = mpsc::unbounded_channel(); 182 | 183 | let contact_count = router_addresses.len() + self.starting_nodes.len(); 184 | let stop_at = std::cmp::min(contact_count, MAX_INITIAL_RESPONSES); 185 | let mut responses_received = 0; 186 | 187 | let mut send_finished = false; 188 | let mut new_receivers_closed = false; 189 | let mut send_to_initial_nodes = pin!(self.send_to_initial_nodes( 190 | find_node_msg, 191 | &router_addresses, 192 | new_receivers_tx 193 | )); 194 | 195 | loop { 196 | if send_finished && new_receivers_closed && receivers.is_empty() { 197 | break; 198 | } 199 | 200 | select! 
{ 201 | _ = &mut send_to_initial_nodes, if !send_finished => { 202 | send_finished = true; 203 | }, 204 | new_receiver = new_receivers_rx.recv(), if !new_receivers_closed => { 205 | if let Some(new_receiver) = new_receiver { 206 | receivers.push(new_receiver); 207 | } else { 208 | new_receivers_closed = true; 209 | } 210 | }, 211 | ret = receivers.next(), if !receivers.is_empty() => { 212 | if let Some(Some((message, from))) = ret { 213 | if self.handle_message(message, from) { 214 | responses_received += 1; 215 | 216 | if responses_received >= stop_at { 217 | break; 218 | } 219 | } 220 | } 221 | } 222 | } 223 | } 224 | 225 | if responses_received == 0 { 226 | self.set_state(State::IdleBeforeRebootstrap, line!()); 227 | time::sleep(self.calculate_retry_duration(bootstrap_attempt)).await; 228 | bootstrap_attempt += 1; 229 | continue; 230 | } 231 | 232 | self.set_state(State::Bootstrapping, line!()); 233 | 234 | for bucket_number in 0..table::MAX_BUCKETS { 235 | log::debug!( 236 | "{}: TableBootstrap::bootstrap_next_bucket {}/{}", 237 | self.ip_version, 238 | bucket_number, 239 | table::MAX_BUCKETS 240 | ); 241 | 242 | let (new_receivers_tx, mut new_receivers_rx) = mpsc::unbounded_channel(); 243 | 244 | let mut send_bucket_bootstrap = 245 | pin!(self.send_bucket_bootstrap_requests(bucket_number, new_receivers_tx)); 246 | 247 | let mut send_finished = false; 248 | let mut new_receivers_closed = false; 249 | let mut receivers = FuturesUnordered::new(); 250 | 251 | loop { 252 | if send_finished && new_receivers_closed && receivers.is_empty() { 253 | break; 254 | } 255 | 256 | select! 
{ 257 | _ = &mut send_bucket_bootstrap, if !send_finished => { 258 | send_finished = true; 259 | }, 260 | new_receiver = new_receivers_rx.recv() , if !new_receivers_closed => { 261 | if let Some(new_receiver) = new_receiver { 262 | receivers.push(new_receiver); 263 | } else { 264 | new_receivers_closed = true; 265 | } 266 | }, 267 | ret = receivers.next(), if !receivers.is_empty() => { 268 | if let Some(Some((message, from))) = ret { 269 | self.handle_message(message, from); 270 | } 271 | } 272 | } 273 | } 274 | } 275 | 276 | let (num_good_nodes, num_questionable_nodes) = { 277 | let table = self.table.lock().unwrap(); 278 | (table.num_good_nodes(), table.num_questionable_nodes()) 279 | }; 280 | 281 | log::debug!( 282 | "{}: TableBootstrap num_good_nodes:{} and num_questionable_nodes:{}", 283 | self.ip_version, 284 | num_good_nodes, 285 | num_questionable_nodes 286 | ); 287 | 288 | if num_good_nodes < GOOD_NODE_THRESHOLD { 289 | // If we don't have enought good nodes and the `router_addresses` array is empty 290 | // then we might be testing or we might be in a country that is blocked from the 291 | // outside world where no BtDHT exists yet and we're one of the first nodes 292 | // creating it. In those cases we'll claim that we've bootstrapped and repeat the 293 | // bootstrap process periodically. 294 | if !router_addresses.is_empty() { 295 | self.set_state(State::IdleBeforeRebootstrap, line!()); 296 | time::sleep(self.calculate_retry_duration(bootstrap_attempt)).await; 297 | bootstrap_attempt += 1; 298 | continue; 299 | } 300 | } 301 | 302 | self.set_state(State::Bootstrapped, line!()); 303 | 304 | // Reset the counter. 
                bootstrap_attempt = 0;

                // Bootstrapped: periodically re-check the table and fall back
                // into the bootstrap loop if it degrades below the threshold.
                loop {
                    time::sleep(PERIODIC_CHECK_TIMEOUT).await;

                    if self.table.lock().unwrap().num_good_nodes() < GOOD_NODE_THRESHOLD {
                        break;
                    }
                }
            }
        }

    /// Sends the initial bootstrap `message` to every router address and every
    /// configured starting node, pushing each pending-response receiver into
    /// `new_receivers_tx` so the caller can await the replies.
    async fn send_to_initial_nodes(
        &self,
        message: Message,
        router_addresses: &HashSet,
        new_receivers_tx: mpsc::UnboundedSender,
    ) {
        // Remember the last error kind so each distinct failure is logged once.
        let mut last_send_error = None;
        let mut count = 0;

        for addr in router_addresses.iter().chain(self.starting_nodes.iter()) {
            // Throttle sending if there is too many initial contacts
            if count > PINGS_PER_BUCKET {
                time::sleep(NODE_TIMEOUT.max(Self::nat_friendly_send_duration())).await;
            }

            match self
                .socket
                .send_request(&message, *addr, INITIAL_TIMEOUT)
                .await
            {
                Ok(receiver) => {
                    count += 1;
                    // The receiving side was dropped, so nobody is waiting for
                    // responses anymore - stop sending.
                    if new_receivers_tx.send(receiver).is_err() {
                        return;
                    }
                }
                Err(error) => {
                    if Some(error.kind()) != last_send_error {
                        log::error!(
                            "{}: Failed to send bootstrap message to router: {}",
                            self.ip_version,
                            error
                        );
                        last_send_error = Some(error.kind());
                    }
                }
            }
        }
    }

    // NOTE(review): this function returns `()`, not `false` - the old comment
    // was stale. Send failures are logged and skipped; once the returned future
    // completes (even if nothing was sent) the caller proceeds to the next
    // bucket.
    async fn send_bucket_bootstrap_requests(
        &self,
        bucket_number: usize,
        new_receivers_tx: mpsc::UnboundedSender,
    ) {
        // Flipping bit `bucket_number` of our own id yields a target id that
        // falls into that bucket.
        let target_id = self.this_node_id.flip_bit(bucket_number);
        let nodes = self.nodes_to_bootstrap_bucket(bucket_number, target_id);

        for node in nodes {
            // Generate a transaction id
            let trans_id = self.id_generator.lock().unwrap().generate();

            let find_node_msg =
                Self::make_find_node_request(trans_id, self.this_node_id, target_id);

            // Send the message to the node
            match self
                .socket
                .send_request(&find_node_msg, node.addr, NODE_TIMEOUT)
                .await
            {
                Ok(receiver) => {
                    if new_receivers_tx.send(receiver).is_err() {
                        break;
                    }
                }
                Err(error) => {
                    log::error!(
                        "{}: Could not send a bootstrap message: {}",
                        self.ip_version,
                        error
                    );
                    continue;
                }
            }

            // Mark that we requested from the node
            if let Some(node) = self.table.lock().unwrap().find_node_mut(&node) {
                node.local_request();
            }
        }
    }

    fn nat_friendly_send_duration() -> Duration {
        // An answer on serverfault.com[1] says the average home router may have from 2^10
        // to 2^14 NAT entries. To be conservative, and to account for the fact that the
        // user may be running one IPv4 and one IPv6 `MainlineDht`, let's assume we don't
        // want to exceed 256 NAT entries by much. A NAT entry stays open up to 20 seconds
        // before it's deleted. Thus let's sleep for 20s/256 so that after 20 seconds if we
        // contact another node, the first nodes we contacted shall begin being removed
        // from the NAT.
        //
        // [1] https://serverfault.com/a/57903
        Duration::from_millis(20_000 / 256)
    }

    /// Processes a message received during bootstrap. Returns `true` if the
    /// message was a response (its sender is recorded as a good node and any
    /// compact node contacts it carried are added to the routing table),
    /// `false` for anything else.
    fn handle_message(&self, message: Message, from: SocketAddr) -> bool {
        match message.body {
            MessageBody::Response(rsp) => {
                let node = Node::as_good(rsp.id, from);

                // Only take contacts of the address family our socket uses.
                let nodes = match self.socket.ip_version() {
                    IpVersion::V4 => &rsp.nodes_v4,
                    IpVersion::V6 => &rsp.nodes_v6,
                };

                self.table.lock().unwrap().add_nodes(node, nodes);

                true
            }
            _ => false,
        }
    }

    /// Builds a `find_node` request message with the given transaction id,
    /// sender id and lookup target.
    fn make_find_node_request(
        transaction_id: TransactionID,
        id: NodeId,
        target: NodeId,
    ) -> Message {
        Message {
            transaction_id: transaction_id.as_ref().to_vec(),
            body: MessageBody::Request(Request::FindNode(FindNodeRequest {
                id,
                target,
                want: None, // we want only contacts of the same address family we have.
            })),
        }
    }

    /// Exponential backoff between bootstrap attempts: 2^(attempt+1) seconds,
    /// capped at 2^9 = 512s.
    fn calculate_retry_duration(&self, bootstrap_attempt: u64) -> Duration {
        const BASE: u64 = 2;
        // Max is somewhere around 8.5 mins.
        Duration::from_secs(BASE.pow((bootstrap_attempt + 1).min(9) as u32))
    }

    /// Selects up to `PINGS_PER_BUCKET` questionable nodes to contact when
    /// bootstrapping bucket `bucket_number` (whose id range contains
    /// `target_id`).
    fn nodes_to_bootstrap_bucket(
        &self,
        bucket_number: usize,
        target_id: NodeId,
    ) -> Vec {
        let table = self.table.lock().unwrap();

        // Get the optimal iterator to bootstrap the current bucket
        if bucket_number == 0 || bucket_number == 1 {
            // For the first two buckets there are no "previous" buckets to
            // draw from, so just take the closest questionable nodes.
            table
                .closest_nodes(target_id)
                .filter(|n| n.status() == NodeStatus::Questionable)
                .take(PINGS_PER_BUCKET)
                .map(|node| *node.handle())
                .collect()
        } else {
            // Draw candidates from the two buckets preceding the target
            // bucket and the target bucket itself; missing buckets are
            // substituted with an empty dummy bucket.
            let mut buckets = table.buckets().skip(bucket_number - 2);
            let dummy_bucket = Bucket::new();

            // Sloppy probabilities of our target node residing at the node
            let percent_25_bucket = if let Some(bucket) = buckets.next() {
                bucket.iter()
            } else {
                dummy_bucket.iter()
            };
            let percent_50_bucket = if let Some(bucket) = buckets.next() {
                bucket.iter()
            } else {
                dummy_bucket.iter()
            };
            let percent_100_bucket = if let Some(bucket) = buckets.next() {
                bucket.iter()
            } else {
                dummy_bucket.iter()
            };

            // TODO: Figure out why chaining them in reverse gives us more total nodes on average, perhaps it allows us to fill up the lower
            // buckets faster at the cost of less nodes in the higher buckets (since lower buckets are very easy to fill)...Although it should
            // even out since we are stagnating buckets, so doing it in reverse may make sense since on the 3rd iteration, it allows us to ping
            // questionable nodes in our first buckets right off the bat.
            percent_25_bucket
                .chain(percent_50_bucket)
                .chain(percent_100_bucket)
                .filter(|n| n.status() == NodeStatus::Questionable)
                .take(PINGS_PER_BUCKET)
                .map(|node| *node.handle())
                .collect()
        }
    }
}
-------------------------------------------------------------------------------- /src/action/lookup.rs: --------------------------------------------------------------------------------
use super::{ActionStatus, IpVersion, ScheduledTaskCheck};
use crate::info_hash::{InfoHash, INFO_HASH_LEN};
use crate::message::{
    AnnouncePeerRequest, GetPeersRequest, Message, MessageBody, Request, Response,
};
use crate::{
    bucket,
    info_hash::NodeId,
    node::{Node, NodeHandle, NodeStatus},
    socket::Socket,
    table::RoutingTable,
    timer::{Timeout, Timer},
    transaction::{MIDGenerator, TransactionID},
};
use std::{
    collections::{HashMap, HashSet},
    net::{Ipv4Addr, SocketAddr},
    sync::{Arc, Mutex},
    time::Duration,
};
use tokio::sync::mpsc;

const LOOKUP_TIMEOUT: Duration = Duration::from_millis(1500);
const ENDGAME_TIMEOUT: Duration = Duration::from_millis(1500);
const ANNOUNCE_PICK_NUM: usize = 8;

// Currently using the aggressive variant of the standard lookup procedure.
// https://people.kth.se/~rauljc/p2p11/jimenez2011subsecond.pdf

// TODO: Handle case where a request round fails, should we fail the whole lookup (clear active lookups?)
// TODO: Clean up the code in this module.

const INITIAL_PICK_NUM: usize = 4; // Alpha
const ITERATIVE_PICK_NUM: usize = 3; // Beta

type Distance = InfoHash;
type DistanceToBeat = InfoHash;

/// State of a single iterative `get_peers` lookup for one info-hash.
pub(crate) struct TableLookup {
    table: Arc>,
    this_node_id: NodeId,
    ip_version: IpVersion,
    target_id: InfoHash,
    in_endgame: bool,
    // If we have received any values in the lookup.
    recv_values: bool,
    id_generator: MIDGenerator,
    will_announce: bool,
    // DistanceToBeat is the distance that the responses of the current lookup needs to beat,
    // interestingly enough (and super important), this distance may not be equal to the
    // requested node's distance
    active_lookups: HashMap,
    announce_tokens: HashMap>,
    requested_nodes: HashSet,
    // Storing whether or not it has ever been pinged so that we
    // can perform the brute force lookup if the lookup failed
    all_sorted_nodes: Vec<(Distance, NodeHandle, bool)>,
    // Send the found peers through this channel.
    tx: mpsc::UnboundedSender,
}

// Gather nodes

impl TableLookup {
    /// Creates the lookup state and immediately fires the first (alpha-wide)
    /// round of `get_peers` requests at the closest good nodes we know of.
    pub async fn new(
        target_id: InfoHash,
        will_announce: bool,
        tx: mpsc::UnboundedSender,
        id_generator: MIDGenerator,
        table: Arc>,
        socket: &Socket,
        timer: &mut Timer,
    ) -> TableLookup {
        // Pick a buckets worth of nodes and put them into the all_sorted_nodes list
        let mut all_sorted_nodes = Vec::with_capacity(bucket::MAX_BUCKET_SIZE);
        for node in table
            .lock()
            .unwrap()
            .closest_nodes(target_id)
            .filter(|n| n.status() == NodeStatus::Good)
            .take(bucket::MAX_BUCKET_SIZE)
        {
            insert_sorted_node(&mut all_sorted_nodes, target_id, *node.handle(), false);
        }

        // Call pick_initial_nodes with the all_sorted_nodes list as an iterator
        let initial_pick_nodes = pick_initial_nodes(all_sorted_nodes.iter_mut());
        // Each initial node's own distance to the target is the first
        // "distance to beat" for the responses it produces.
        let initial_pick_nodes_filtered =
            initial_pick_nodes
                .iter()
                .filter(|(_, good)| *good)
                .map(|(node, _)| {
                    let distance_to_beat = node.id ^ target_id;

                    (node, distance_to_beat)
                });

        let this_node_id = table.lock().unwrap().node_id();

        // Construct the lookup table structure
        let mut table_lookup = TableLookup {
            table,
            this_node_id,
            ip_version: socket.ip_version(),
target_id, 106 | in_endgame: false, 107 | recv_values: false, 108 | id_generator, 109 | will_announce, 110 | all_sorted_nodes, 111 | announce_tokens: HashMap::new(), 112 | requested_nodes: HashSet::new(), 113 | active_lookups: HashMap::with_capacity(INITIAL_PICK_NUM), 114 | tx, 115 | }; 116 | 117 | // Call start_request_round with the list of initial_nodes (return even if the search completed...for now :D) 118 | table_lookup 119 | .start_request_round(initial_pick_nodes_filtered, socket, timer) 120 | .await; 121 | 122 | table_lookup 123 | } 124 | 125 | pub fn completed(&self) -> bool { 126 | self.active_lookups.is_empty() 127 | } 128 | 129 | pub async fn recv_response( 130 | &mut self, 131 | node: Node, 132 | trans_id: &TransactionID, 133 | msg: Response, 134 | socket: &Socket, 135 | timer: &mut Timer, 136 | ) -> ActionStatus { 137 | // Process the message transaction id 138 | let (dist_to_beat, timeout) = if let Some(lookup) = self.active_lookups.remove(trans_id) { 139 | lookup 140 | } else { 141 | log::debug!( 142 | "{}: Received expired/unsolicited node response for an active table lookup", 143 | self.ip_version 144 | ); 145 | return self.current_lookup_status(); 146 | }; 147 | 148 | // Cancel the timeout (if this is not an endgame response) 149 | if !self.in_endgame { 150 | timer.cancel(timeout); 151 | } 152 | 153 | if let Some(token) = msg.token { 154 | // Add the announce token to our list of tokens 155 | self.announce_tokens.insert(*node.handle(), token); 156 | } 157 | 158 | let nodes = match socket.ip_version() { 159 | IpVersion::V4 => msg.nodes_v4, 160 | IpVersion::V6 => msg.nodes_v6, 161 | }; 162 | 163 | let values = msg.values; 164 | 165 | // Check if we beat the distance, get the next distance to beat 166 | let (iterate_nodes, next_dist_to_beat) = if !nodes.is_empty() { 167 | let requested_nodes = &self.requested_nodes; 168 | 169 | // Get the closest distance (or the current distance) 170 | let next_dist_to_beat = nodes 171 | .iter() 172 | 
.filter(|node| !requested_nodes.contains(node)) 173 | .fold(dist_to_beat, |closest, node| { 174 | let distance = self.target_id ^ node.id; 175 | 176 | if distance < closest { 177 | distance 178 | } else { 179 | closest 180 | } 181 | }); 182 | 183 | // Check if we got closer (equal to is not enough) 184 | let iterate_nodes = if next_dist_to_beat < dist_to_beat { 185 | let iterate_nodes = pick_iterate_nodes( 186 | nodes 187 | .iter() 188 | .filter(|node| !requested_nodes.contains(node)) 189 | .copied(), 190 | self.target_id, 191 | ); 192 | 193 | // Push nodes into the all nodes list 194 | for node in nodes { 195 | let will_ping = iterate_nodes.iter().any(|(n, _)| n == &node); 196 | 197 | insert_sorted_node(&mut self.all_sorted_nodes, self.target_id, node, will_ping); 198 | } 199 | 200 | Some(iterate_nodes) 201 | } else { 202 | // Push nodes into the all nodes list 203 | for node in nodes { 204 | insert_sorted_node(&mut self.all_sorted_nodes, self.target_id, node, false); 205 | } 206 | 207 | None 208 | }; 209 | 210 | (iterate_nodes, next_dist_to_beat) 211 | } else { 212 | (None, dist_to_beat) 213 | }; 214 | 215 | // Check if we need to iterate (not in the endgame already) 216 | if !self.in_endgame { 217 | // If the node gave us a closer id than its own to the target id, continue the search 218 | if let Some(nodes) = iterate_nodes { 219 | let filtered_nodes = nodes 220 | .iter() 221 | .filter(|(_, good)| *good) 222 | .map(|(n, _)| (n, next_dist_to_beat)); 223 | self.start_request_round(filtered_nodes, socket, timer) 224 | .await; 225 | } 226 | 227 | // If there are not more active lookups, start the endgame 228 | if self.active_lookups.is_empty() { 229 | self.start_endgame_round(socket, timer).await; 230 | } 231 | } 232 | 233 | for value in values { 234 | self.tx.send(value).unwrap_or(()) 235 | } 236 | 237 | self.current_lookup_status() 238 | } 239 | 240 | pub async fn recv_timeout( 241 | &mut self, 242 | trans_id: &TransactionID, 243 | socket: &Socket, 244 | timer: 
             &mut Timer,
    ) -> ActionStatus {
        if self.active_lookups.remove(trans_id).is_none() {
            log::warn!(
                "{}: Received expired/unsolicited node timeout for an active table lookup",
                self.ip_version
            );
            return self.current_lookup_status();
        }

        if !self.in_endgame {
            // If there are not more active lookups, start the endgame
            if self.active_lookups.is_empty() {
                self.start_endgame_round(socket, timer).await;
            }
        }

        self.current_lookup_status()
    }

    /// Finalizes the lookup. If announcing was requested, sends
    /// `announce_peer` (with `port`) to the closest nodes that gave us an
    /// announce token, then clears all remaining lookup state.
    pub async fn recv_finished(&mut self, port: Option, socket: &Socket) {
        // Announce if we were told to
        if self.will_announce {
            // Partial borrow so the filter function doesn't capture all of self
            let announce_tokens = &self.announce_tokens;

            // `all_sorted_nodes` is ordered by distance, so this picks the
            // ANNOUNCE_PICK_NUM closest nodes that handed us a token.
            for (_, node, _) in self
                .all_sorted_nodes
                .iter()
                .filter(|(_, node, _)| announce_tokens.contains_key(node))
                .take(ANNOUNCE_PICK_NUM)
            {
                let trans_id = self.id_generator.generate();
                let token = announce_tokens.get(node).unwrap();

                let announce_peer_req = AnnouncePeerRequest {
                    id: self.this_node_id,
                    info_hash: self.target_id,
                    token: token.clone(),
                    port,
                };
                let announce_peer_msg = Message {
                    transaction_id: trans_id.as_ref().to_vec(),
                    body: MessageBody::Request(Request::AnnouncePeer(announce_peer_req)),
                };

                match socket.send(&announce_peer_msg, node.addr).await {
                    Ok(()) => {
                        // We requested from the node, mark it down if the node is in our routing table
                        if let Some(n) = self.table.lock().unwrap().find_node_mut(node) {
                            n.local_request()
                        }
                    }
                    Err(error) => {
                        log::error!(
                            "{}: TableLookup announce request failed to send: {}",
                            self.ip_version,
                            error
                        )
                    }
                }
            }
        }

        // This may not be cleared since we didn't set a timeout for each node, any nodes that didn't respond would still be in here.
        self.active_lookups.clear();
        self.in_endgame = false;
    }

    /// `Ongoing` while requests are outstanding or the endgame is running,
    /// `Completed` otherwise.
    fn current_lookup_status(&self) -> ActionStatus {
        if self.in_endgame || !self.active_lookups.is_empty() {
            ActionStatus::Ongoing
        } else {
            ActionStatus::Completed
        }
    }

    /// Sends a `get_peers` request to each given node, registering a per-node
    /// timeout and the distance its responses must beat. Clears all active
    /// lookups if not a single message could be sent.
    async fn start_request_round<'a, I>(
        &mut self,
        nodes: I,
        socket: &Socket,
        timer: &mut Timer,
    ) where
        I: Iterator,
    {
        // Loop through the given nodes
        let mut messages_sent = 0;
        for (node, dist_to_beat) in nodes {
            // Generate a transaction id for this message
            let trans_id = self.id_generator.generate();

            // Try to start a timeout for the node
            let timeout =
                timer.schedule_in(LOOKUP_TIMEOUT, ScheduledTaskCheck::LookupTimeout(trans_id));

            // Associate the transaction id with the distance the returned nodes must beat and the timeout token
            self.active_lookups
                .insert(trans_id, (dist_to_beat, timeout));

            // Send the message to the node
            let get_peers_msg = Message {
                transaction_id: trans_id.as_ref().to_vec(),
                body: MessageBody::Request(Request::GetPeers(GetPeersRequest {
                    id: self.this_node_id,
                    info_hash: self.target_id,
                    want: None,
                })),
            };

            if let Err(error) = socket.send(&get_peers_msg, node.addr).await {
                log::error!(
                    "{}: Could not send a lookup message: {}",
                    self.ip_version,
                    error
                );
                continue;
            }

            // We requested from the node, mark it down
            self.requested_nodes.insert(*node);

            // Update the node in the routing table
            if let Some(n) = self.table.lock().unwrap().find_node_mut(node) {
                n.local_request()
            }

            messages_sent += 1;
        }

        if messages_sent == 0 {
            self.active_lookups.clear();
        }
    }

    async fn start_endgame_round(
        &mut self,
        socket: &Socket,
        timer: &mut Timer,
    ) -> ActionStatus {
        // Entering the endgame phase
        self.in_endgame = true;

        // Try to start a global message timeout for the endgame
        let timeout = timer.schedule_in(
            ENDGAME_TIMEOUT,
            ScheduledTaskCheck::LookupEndGame(self.id_generator.generate()),
        );

        // Request all unpinged nodes if we didn't receive any values
        if !self.recv_values {
            for node_info in self.all_sorted_nodes.iter_mut().filter(|(_, _, req)| !req) {
                let (node_dist, node, req) = node_info;

                // Generate a transaction id for this message
                let trans_id = self.id_generator.generate();

                // Associate the transaction id with this node's distance and its timeout token
                // We dont actually need to keep track of this information, but we do still need to
                // filter out unsolicited responses by using the active_lookups map!!!
                self.active_lookups.insert(trans_id, (*node_dist, timeout));

                // Send the message to the node
                let get_peers_msg = Message {
                    transaction_id: trans_id.as_ref().to_vec(),
                    body: MessageBody::Request(Request::GetPeers(GetPeersRequest {
                        id: self.this_node_id,
                        info_hash: self.target_id,
                        want: None,
                    })),
                };

                if let Err(error) = socket.send(&get_peers_msg, node.addr).await {
                    log::error!(
                        "{}: Could not send an endgame message: {}",
                        self.ip_version,
                        error
                    );
                    continue;
                }

                // Mark that we requested from the node in the RoutingTable
                if let Some(n) = self.table.lock().unwrap().find_node_mut(node) {
                    n.local_request()
                }

                // Mark that we requested from the node
                *req = true;
            }
        }

        ActionStatus::Ongoing
    }
}

/// Picks a number of nodes from the sorted distance iterator to ping on the first round.
fn pick_initial_nodes<'a, I>(sorted_nodes: I) -> [(NodeHandle, bool); INITIAL_PICK_NUM]
where
    I: Iterator,
{
    // Unused slots keep this dummy handle with the flag set to `false`.
    let dummy_id = [0u8; INFO_HASH_LEN].into();
    let default = (
        NodeHandle::new(dummy_id, SocketAddr::from((Ipv4Addr::UNSPECIFIED, 0))),
        false,
    );

    let mut pick_nodes = [default; INITIAL_PICK_NUM];
    for (src, dst) in sorted_nodes.zip(pick_nodes.iter_mut()) {
        dst.0 = src.1;
        dst.1 = true;

        // Mark that the node has been requested from
        src.2 = true;
    }

    pick_nodes
}

/// Picks a number of nodes from the unsorted distance iterator to ping on iterative rounds.
fn pick_iterate_nodes(
    unsorted_nodes: I,
    target_id: InfoHash,
) -> [(NodeHandle, bool); ITERATIVE_PICK_NUM]
where
    I: Iterator,
{
    // Unused slots keep this dummy handle with the flag set to `false`.
    let dummy_id = [0u8; INFO_HASH_LEN].into();
    let default = (
        NodeHandle::new(dummy_id, SocketAddr::from((Ipv4Addr::UNSPECIFIED, 0))),
        false,
    );

    // Keep only the ITERATIVE_PICK_NUM nodes closest to the target.
    let mut pick_nodes = [default; ITERATIVE_PICK_NUM];
    for node in unsorted_nodes {
        insert_closest_nodes(&mut pick_nodes, target_id, node);
    }

    pick_nodes
}

/// Inserts the node into the slice if a slot in the slice is unused or a node
/// in the slice is further from the target id than the node being inserted.
fn insert_closest_nodes(
    nodes: &mut [(NodeHandle, bool)],
    target_id: InfoHash,
    new_node: NodeHandle,
) {
    let new_distance = target_id ^ new_node.id;

    for &mut (ref mut old_node, ref mut used) in nodes.iter_mut() {
        if !*used {
            // Slot was not in use, go ahead and place the node
            *old_node = new_node;
            *used = true;
            return;
        } else {
            // Slot is in use, see if our node is closer to the target
            let old_distance = target_id ^ old_node.id;

            if new_distance < old_distance {
                // NOTE(review): this overwrites the slot rather than shifting
                // the displaced node into a later slot - the displaced node is
                // dropped. Presumably intentional ("sloppy") - confirm.
                *old_node = new_node;
                return;
            }
        }
    }
}

/// Inserts the Node into the list of nodes based on its distance from the target node.
///
/// Nodes at the start of the list are closer to the target node than nodes at the end.
fn insert_sorted_node(
    nodes: &mut Vec<(Distance, NodeHandle, bool)>,
    target: InfoHash,
    node: NodeHandle,
    pinged: bool,
) {
    let node_id = node.id;
    let node_dist = target ^ node_id;

    // Perform a search by distance from the target id
    let search_result = nodes.binary_search_by(|(dist, _, _)| dist.cmp(&node_dist));
    match search_result {
        Ok(dup_index) => {
            // TODO: Bug here, what happens when multiple nodes with the same distance are
            // present, but we dont get the index of the duplicate node (its in the list) from
            // the search, then we would have a duplicate node in the list!
            // Insert only if this node is different (it is ok if they have the same id)
            if nodes[dup_index].1 != node {
                nodes.insert(dup_index, (node_dist, node, pinged));
            }
        }
        Err(ins_index) => nodes.insert(ins_index, (node_dist, node, pinged)),
    };
}
-------------------------------------------------------------------------------- /src/action/mod.rs: --------------------------------------------------------------------------------
use crate::socket::{Responded, Socket};
use crate::{info_hash::InfoHash, transaction::TransactionID};
use std::{collections::HashSet, fmt, io, net::SocketAddr};
use thiserror::Error;
use tokio::sync::{mpsc, oneshot};

pub(crate) mod bootstrap;
pub(crate) mod lookup;
pub(crate) mod refresh;

/// Snapshot of the DHT worker's state, used for debugging/inspection.
#[derive(Copy, Clone, Debug)]
pub struct State {
    pub is_running: bool,
    pub bootstrapped: bool,
    pub good_node_count: usize,
    pub questionable_node_count: usize,
    pub bucket_count: usize,
}

#[derive(Debug, Eq, PartialEq, Clone, Copy)]
pub enum IpVersion {
    V4,
    V6,
}

impl fmt::Display for IpVersion {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match self {
            Self::V4 => write!(f, "IPv4"),
            Self::V6 => write!(f, "IPv6"),
        }
    }
}

/// Task that our DHT will execute immediately.
pub(crate) enum OneshotTask {
    /// Load a new bootstrap operation into worker storage.
    StartBootstrap(),
    /// Check bootstrap status. The given sender will be notified when the bootstrap completed.
    CheckBootstrap(oneshot::Sender<()>),
    /// Start a lookup for the given InfoHash.
    StartLookup(StartLookup),
    /// Get the local address the socket is bound to.
    GetLocalAddr(oneshot::Sender),
    /// Retrieve debug information.
    GetState(oneshot::Sender),
    /// Retrieve IP:PORT pairs of "good" and "questionable" nodes in the routing table.
    LoadContacts(oneshot::Sender<(HashSet, HashSet)>),
}

/// Parameters for a peer lookup: the target info-hash, whether to announce
/// ourselves afterwards, and the channel on which found peers are sent.
pub(crate) struct StartLookup {
    pub info_hash: InfoHash,
    pub announce: bool,
    pub tx: mpsc::UnboundedSender,
}

/// Task that our DHT will execute some time later.
#[derive(Copy, Clone, Debug)]
pub(crate) enum ScheduledTaskCheck {
    /// Check the progress of the bucket refresh.
    TableRefresh,
    /// Check the progress of the current bootstrap.
    //BootstrapTimeout(BootstrapTimeout),
    /// Check the progress of a current lookup.
    LookupTimeout(TransactionID),
    /// Check the progress of the lookup endgame.
    LookupEndGame(TransactionID),
}

#[derive(Error, Debug)]
pub(crate) enum WorkerError {
    #[error("invalid transaction id")]
    InvalidTransactionId,
    #[error("received unsolicited response")]
    UnsolicitedResponse,
    #[error("socket error")]
    SocketError(#[from] io::Error),
}

#[derive(Debug, PartialEq, Eq)]
pub(crate) enum ActionStatus {
    /// Action is in progress
    Ongoing,
    /// Action completed
    Completed,
}

/// Resolves router host names to socket addresses, keeping only those of the
/// requested address family. Hosts that fail to resolve are silently skipped
/// (best effort - some routers being down must not fail the bootstrap).
pub(crate) async fn resolve(routers: &HashSet, ip_v: IpVersion) -> HashSet {
    futures_util::future::join_all(routers.iter().map(tokio::net::lookup_host))
        .await
        .into_iter()
        .filter_map(|result| result.ok())
        .flatten()
        .filter(|addr| match ip_v {
            IpVersion::V4 => addr.is_ipv4(),
            IpVersion::V6 => addr.is_ipv6(),
        })
        .collect()
}
-------------------------------------------------------------------------------- /src/action/refresh.rs: --------------------------------------------------------------------------------
use super::ScheduledTaskCheck;
use crate::node::NodeStatus;
use crate::table::{self, RoutingTable};
use crate::transaction::{ActionID, MIDGenerator};
use crate::{
    message::{FindNodeRequest, Message, MessageBody, Request},
    socket::Socket,
    timer::Timer,
};
use std::{
    sync::{Arc, Mutex},
    time::Duration,
};

const REFRESH_INTERVAL_TIMEOUT: Duration = Duration::from_millis(6000);
const REFRESH_CONCURRENCY: usize = 4;

/// Periodically pings questionable nodes, walking the routing table one
/// bucket per invocation of `continue_refresh`.
pub(crate) struct TableRefresh {
    table: Arc>,
    id_generator: MIDGenerator,
    // Index of the bucket the next refresh round will target.
    curr_refresh_bucket: usize,
}

impl TableRefresh {
    pub fn new(id_generator: MIDGenerator, table: Arc>) -> TableRefresh {
        TableRefresh {
            table,
            id_generator,
            curr_refresh_bucket: 0,
        }
    }

    pub fn action_id(&self) -> ActionID {
        self.id_generator.action_id()
    }

    /// Refreshes the current bucket by pinging up to `REFRESH_CONCURRENCY`
    /// questionable nodes close to it, schedules the next refresh tick, and
    /// advances to the next bucket (wrapping at `MAX_BUCKETS`).
    pub async fn continue_refresh(
        &mut self,
        socket: &Socket,
        timer: &mut Timer,
    ) {
        if self.curr_refresh_bucket == table::MAX_BUCKETS {
            self.curr_refresh_bucket = 0;
        }

        // Grab everything we need under a single short-lived table lock.
        let (this_node_id, target_id, num_good_nodes, num_questionable_nodes, nodes_to_contact) = {
            let table = self.table.lock().unwrap();

            let this_node_id = table.node_id();
            // Flipping this bucket's bit produces an id inside the bucket.
            let target_id = this_node_id.flip_bit(self.curr_refresh_bucket);
            let num_good_nodes = table.num_good_nodes();
            let num_questionable_nodes = table.num_questionable_nodes();
            let nodes_to_contact = table
                .closest_nodes(target_id)
                .filter(|n| n.status() == NodeStatus::Questionable)
                .filter(|n| !n.recently_requested_from())
                .take(REFRESH_CONCURRENCY)
                .map(|node| *node.handle())
                .collect::>();

            (
                this_node_id,
                target_id,
                num_good_nodes,
                num_questionable_nodes,
                nodes_to_contact,
            )
        };

        log::debug!(
            "Performing a refresh for bucket {} (table total: num_good_nodes={}, num_questionable_nodes={})",
            self.curr_refresh_bucket,
            num_good_nodes,
            num_questionable_nodes,
        );

        // Ping the closest questionable nodes
        for node in nodes_to_contact {
            // Generate a transaction id for the request
            let trans_id = self.id_generator.generate();

            // Construct the message
            let find_node_req = FindNodeRequest {
                id: this_node_id,
                target: target_id,
                want: None,
            };
            let find_node_msg = Message {
                transaction_id: trans_id.as_ref().to_vec(),
                body: MessageBody::Request(Request::FindNode(find_node_req)),
            };

            // Send the message
            if let Err(error) = socket.send(&find_node_msg, node.addr).await {
                log::error!("TableRefresh failed to send a refresh message: {}", error);
            }

            // Mark that we requested from the node
            if let Some(node) = self.table.lock().unwrap().find_node_mut(&node) {
                node.local_request();
            }
        }

        // Start a timer for the next refresh
        timer.schedule_in(REFRESH_INTERVAL_TIMEOUT, ScheduledTaskCheck::TableRefresh);

        self.curr_refresh_bucket += 1;
    }
}
-------------------------------------------------------------------------------- /src/bucket.rs: --------------------------------------------------------------------------------
use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4};
use std::slice::Iter;

use crate::info_hash::{NodeId, NODE_ID_LEN};
use crate::node::{Node, NodeStatus};

/// Maximum number of nodes that should reside in any bucket.
pub const MAX_BUCKET_SIZE: usize = 8;

/// Bucket containing Nodes with identical bit prefixes.
pub struct Bucket {
    nodes: [Node; MAX_BUCKET_SIZE],
}

impl Bucket {
    /// Create a new Bucket with all Nodes default initialized.
17 | pub fn new() -> Bucket { 18 | let id = NodeId::from([0u8; NODE_ID_LEN]); 19 | 20 | let ip = Ipv4Addr::new(127, 0, 0, 1); 21 | let addr = SocketAddr::V4(SocketAddrV4::new(ip, 0)); 22 | 23 | Bucket { 24 | nodes: [ 25 | Node::as_bad(id, addr), 26 | Node::as_bad(id, addr), 27 | Node::as_bad(id, addr), 28 | Node::as_bad(id, addr), 29 | Node::as_bad(id, addr), 30 | Node::as_bad(id, addr), 31 | Node::as_bad(id, addr), 32 | Node::as_bad(id, addr), 33 | ], 34 | } 35 | } 36 | 37 | /// Iterator over all good nodes and questionable nodes in the bucket. 38 | pub fn pingable_nodes(&self) -> impl Iterator { 39 | self.nodes.iter().filter(|node| node.is_pingable()) 40 | } 41 | 42 | /// Iterator over all good nodes and questionable nodes in the bucket that allos modifying the 43 | /// nodes. 44 | pub fn pingable_nodes_mut(&mut self) -> impl Iterator { 45 | self.nodes.iter_mut().filter(|node| node.is_pingable()) 46 | } 47 | 48 | /// Iterator over each node within the bucket. 49 | /// 50 | /// For buckets newly created, the initial bad nodes are included. 51 | pub fn iter(&self) -> Iter { 52 | self.nodes.iter() 53 | } 54 | 55 | /// Indicates if the bucket needs to be refreshed. 56 | #[allow(unused)] 57 | pub fn needs_refresh(&self) -> bool { 58 | self.nodes 59 | .iter() 60 | .all(|node| node.status() != NodeStatus::Good) 61 | } 62 | 63 | /// Attempt to add the given Node to the bucket if it is not in a bad state. 64 | /// 65 | /// Returns false if the Node could not be placed in the bucket because it is full. 66 | pub fn add_node(&mut self, new_node: Node) -> bool { 67 | let new_node_status = new_node.status(); 68 | if new_node_status == NodeStatus::Bad { 69 | return true; 70 | } 71 | 72 | // See if this node is already in the table, in that case replace it if it 73 | // has a higher or equal status to the current node. 
74 | if let Some(index) = self.nodes.iter().position(|node| *node == new_node) { 75 | // Note, we can't just compare the status and if it's better or equal then replace the 76 | // old node with the new one. Doing so would erase information already stored locally. 77 | self.nodes[index].update(new_node); 78 | 79 | return true; 80 | } 81 | 82 | // See if any lower priority nodes are present in the table, we cant do 83 | // nodes that have equal status because we have to prefer longer lasting 84 | // nodes in the case of a good status which helps with stability. 85 | let replace_index = self 86 | .nodes 87 | .iter() 88 | .position(|node| node.status() < new_node_status); 89 | if let Some(index) = replace_index { 90 | self.nodes[index] = new_node; 91 | 92 | true 93 | } else { 94 | false 95 | } 96 | } 97 | 98 | /// Iterator over all good nodes in the bucket. 99 | #[cfg(test)] 100 | fn good_nodes(&self) -> impl Iterator { 101 | self.nodes 102 | .iter() 103 | .filter(|node| node.status() == NodeStatus::Good) 104 | } 105 | } 106 | 107 | // ----------------------------------------------------------------------------// 108 | 109 | #[cfg(test)] 110 | mod tests { 111 | 112 | use crate::bucket::Bucket; 113 | use crate::node::{Node, NodeStatus}; 114 | use crate::test; 115 | 116 | #[test] 117 | fn positive_initial_no_nodes() { 118 | let bucket = Bucket::new(); 119 | 120 | assert_eq!(bucket.good_nodes().count(), 0); 121 | assert_eq!(bucket.pingable_nodes().count(), 0); 122 | } 123 | 124 | #[test] 125 | fn positive_all_questionable_nodes() { 126 | let mut bucket = Bucket::new(); 127 | 128 | let dummy_addr = test::dummy_socket_addr_v4(); 129 | let dummy_ids = test::dummy_block_node_ids(super::MAX_BUCKET_SIZE as u8); 130 | for id in dummy_ids { 131 | let node = Node::as_questionable(id, dummy_addr); 132 | bucket.add_node(node); 133 | } 134 | 135 | assert_eq!(bucket.good_nodes().count(), 0); 136 | assert_eq!(bucket.pingable_nodes().count(), super::MAX_BUCKET_SIZE); 137 | } 138 | 139 
| #[test] 140 | fn positive_all_good_nodes() { 141 | let mut bucket = Bucket::new(); 142 | 143 | let dummy_addr = test::dummy_socket_addr_v4(); 144 | let dummy_ids = test::dummy_block_node_ids(super::MAX_BUCKET_SIZE as u8); 145 | for id in dummy_ids { 146 | let node = Node::as_good(id, dummy_addr); 147 | bucket.add_node(node); 148 | } 149 | 150 | assert_eq!(bucket.good_nodes().count(), super::MAX_BUCKET_SIZE); 151 | assert_eq!(bucket.pingable_nodes().count(), super::MAX_BUCKET_SIZE); 152 | } 153 | 154 | #[test] 155 | fn positive_replace_questionable_node() { 156 | let mut bucket = Bucket::new(); 157 | 158 | let dummy_addr = test::dummy_socket_addr_v4(); 159 | let dummy_ids = test::dummy_block_node_ids(super::MAX_BUCKET_SIZE as u8); 160 | for id in &dummy_ids { 161 | let node = Node::as_questionable(*id, dummy_addr); 162 | bucket.add_node(node); 163 | } 164 | 165 | assert_eq!(bucket.good_nodes().count(), 0); 166 | assert_eq!(bucket.pingable_nodes().count(), super::MAX_BUCKET_SIZE); 167 | 168 | let good_node = Node::as_good(dummy_ids[0], dummy_addr); 169 | bucket.add_node(good_node.clone()); 170 | 171 | assert_eq!(bucket.good_nodes().next().unwrap(), &good_node); 172 | assert_eq!(bucket.good_nodes().count(), 1); 173 | assert_eq!(bucket.pingable_nodes().count(), super::MAX_BUCKET_SIZE); 174 | } 175 | 176 | #[test] 177 | fn positive_resist_good_node_churn() { 178 | let mut bucket = Bucket::new(); 179 | 180 | let dummy_addr = test::dummy_socket_addr_v4(); 181 | let dummy_ids = test::dummy_block_node_ids((super::MAX_BUCKET_SIZE as u8) + 1); 182 | for id in &dummy_ids { 183 | let node = Node::as_good(*id, dummy_addr); 184 | bucket.add_node(node); 185 | } 186 | 187 | // All the nodes should be good 188 | assert_eq!(bucket.good_nodes().count(), super::MAX_BUCKET_SIZE); 189 | 190 | // Create a new good node 191 | let unused_id = dummy_ids[dummy_ids.len() - 1]; 192 | let new_good_node = Node::as_good(unused_id, dummy_addr); 193 | 194 | // Make sure the node is NOT in the 
bucket 195 | assert!(!bucket.good_nodes().any(|node| &new_good_node == node)); 196 | 197 | // Try to add it 198 | bucket.add_node(new_good_node.clone()); 199 | 200 | // Make sure the node is NOT in the bucket 201 | assert!(!bucket.good_nodes().any(|node| &new_good_node == node)); 202 | } 203 | 204 | #[test] 205 | fn positive_resist_questionable_node_churn() { 206 | let mut bucket = Bucket::new(); 207 | 208 | let dummy_addr = test::dummy_socket_addr_v4(); 209 | let dummy_ids = test::dummy_block_node_ids((super::MAX_BUCKET_SIZE as u8) + 1); 210 | for id in &dummy_ids { 211 | let node = Node::as_questionable(*id, dummy_addr); 212 | bucket.add_node(node); 213 | } 214 | 215 | // All the nodes should be questionable 216 | assert_eq!( 217 | bucket 218 | .pingable_nodes() 219 | .filter(|node| node.status() == NodeStatus::Questionable) 220 | .count(), 221 | super::MAX_BUCKET_SIZE 222 | ); 223 | 224 | // Create a new questionable node 225 | let unused_id = dummy_ids[dummy_ids.len() - 1]; 226 | let new_questionable_node = Node::as_questionable(unused_id, dummy_addr); 227 | 228 | // Make sure the node is NOT in the bucket 229 | assert!(!bucket 230 | .pingable_nodes() 231 | .any(|node| &new_questionable_node == node)); 232 | 233 | // Try to add it 234 | bucket.add_node(new_questionable_node); 235 | 236 | // Make sure the node is NOT in the bucket 237 | assert_eq!( 238 | bucket 239 | .pingable_nodes() 240 | .filter(|node| node.status() == NodeStatus::Questionable) 241 | .count(), 242 | super::MAX_BUCKET_SIZE 243 | ); 244 | } 245 | } 246 | -------------------------------------------------------------------------------- /src/compact.rs: -------------------------------------------------------------------------------- 1 | //! 
Compact representation 2 | 3 | use crate::info_hash::NODE_ID_LEN; 4 | use std::{ 5 | convert::TryInto, 6 | net::{Ipv4Addr, Ipv6Addr, SocketAddr}, 7 | }; 8 | 9 | const SOCKET_ADDR_V4_LEN: usize = 6; 10 | const SOCKET_ADDR_V6_LEN: usize = 18; 11 | 12 | /// Serialize/deserialize `Vec` of `SocketAddr` in compact format. 13 | pub(crate) mod values { 14 | use serde::{ 15 | de::{Deserializer, Error as _, SeqAccess, Visitor}, 16 | ser::{SerializeSeq, Serializer}, 17 | }; 18 | use serde_bytes::{ByteBuf, Bytes}; 19 | use std::{fmt, net::SocketAddr}; 20 | 21 | pub(crate) fn serialize(addrs: &[SocketAddr], s: S) -> Result 22 | where 23 | S: Serializer, 24 | { 25 | let mut seq = s.serialize_seq(Some(addrs.len()))?; 26 | for addr in addrs { 27 | seq.serialize_element(Bytes::new(&super::encode_socket_addr(addr)))? 28 | } 29 | seq.end() 30 | } 31 | 32 | pub(crate) fn deserialize<'de, D>(d: D) -> Result, D::Error> 33 | where 34 | D: Deserializer<'de>, 35 | { 36 | struct SocketAddrsVisitor; 37 | 38 | impl<'de> Visitor<'de> for SocketAddrsVisitor { 39 | type Value = Vec; 40 | 41 | fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result { 42 | write!(f, "list of byte strings") 43 | } 44 | 45 | fn visit_seq(self, mut seq: A) -> Result 46 | where 47 | A: SeqAccess<'de>, 48 | { 49 | let mut output = Vec::with_capacity(seq.size_hint().unwrap_or(0)); 50 | 51 | while let Some(bytes) = seq.next_element::()? { 52 | let item = super::decode_socket_addr(&bytes) 53 | .ok_or_else(|| A::Error::invalid_length(bytes.len(), &self))?; 54 | output.push(item); 55 | } 56 | 57 | Ok(output) 58 | } 59 | } 60 | 61 | d.deserialize_seq(SocketAddrsVisitor) 62 | } 63 | } 64 | 65 | /// Serialize/deserialize `Vec` of `NodeHandle` in compact format. Specialized for ipv4 addresses. 
66 | pub(crate) mod nodes_v4 { 67 | use crate::node::NodeHandle; 68 | use serde::{de::Deserializer, ser::Serializer}; 69 | 70 | pub(crate) fn serialize(nodes: &[NodeHandle], s: S) -> Result 71 | where 72 | S: Serializer, 73 | { 74 | super::nodes::serialize::(nodes, s) 75 | } 76 | 77 | pub(crate) fn deserialize<'de, D>(d: D) -> Result, D::Error> 78 | where 79 | D: Deserializer<'de>, 80 | { 81 | super::nodes::deserialize::(d) 82 | } 83 | } 84 | 85 | /// Serialize/deserialize `Vec` of `NodeHandle` in compact format. Specialized for ipv6 addresses. 86 | pub(crate) mod nodes_v6 { 87 | use crate::node::NodeHandle; 88 | use serde::{de::Deserializer, ser::Serializer}; 89 | 90 | pub(crate) fn serialize(nodes: &[NodeHandle], s: S) -> Result 91 | where 92 | S: Serializer, 93 | { 94 | super::nodes::serialize::(nodes, s) 95 | } 96 | 97 | pub(crate) fn deserialize<'de, D>(d: D) -> Result, D::Error> 98 | where 99 | D: Deserializer<'de>, 100 | { 101 | super::nodes::deserialize::(d) 102 | } 103 | } 104 | 105 | /// Serialize/deserialize `Vec` of `NodeHandle` in compact format. Generic over address family. 
106 | mod nodes { 107 | use crate::{info_hash::NodeId, node::NodeHandle}; 108 | use serde::{ 109 | de::{Deserialize, Deserializer, Error as _}, 110 | ser::{Error as _, Serializer}, 111 | }; 112 | use serde_bytes::ByteBuf; 113 | use std::convert::TryFrom; 114 | 115 | pub(crate) fn serialize( 116 | nodes: &[NodeHandle], 117 | s: S, 118 | ) -> Result 119 | where 120 | S: Serializer, 121 | { 122 | let mut buffer = Vec::with_capacity(nodes.len() * (super::NODE_ID_LEN + ADDR_LEN)); 123 | 124 | for node in nodes { 125 | let encoded_addr = super::encode_socket_addr(&node.addr); 126 | 127 | if encoded_addr.len() != ADDR_LEN { 128 | return Err(S::Error::custom("unexpected address family")); 129 | } 130 | 131 | buffer.extend(node.id.as_ref()); 132 | buffer.extend(encoded_addr); 133 | } 134 | 135 | s.serialize_bytes(&buffer) 136 | } 137 | 138 | pub(crate) fn deserialize<'de, D, const ADDR_LEN: usize>( 139 | d: D, 140 | ) -> Result, D::Error> 141 | where 142 | D: Deserializer<'de>, 143 | { 144 | let buffer = ByteBuf::deserialize(d)?; 145 | let chunks = buffer.chunks_exact(super::NODE_ID_LEN + ADDR_LEN); 146 | 147 | if !chunks.remainder().is_empty() { 148 | let msg = format!("multiple of {}", (super::NODE_ID_LEN + ADDR_LEN)); 149 | return Err(D::Error::invalid_length(buffer.len(), &msg.as_ref())); 150 | } 151 | 152 | let nodes = chunks 153 | .filter_map(|chunk| { 154 | let id = NodeId::try_from(&chunk[..super::NODE_ID_LEN]).ok()?; 155 | let addr = super::decode_socket_addr(&chunk[super::NODE_ID_LEN..])?; 156 | 157 | Some(NodeHandle { id, addr }) 158 | }) 159 | .collect(); 160 | 161 | Ok(nodes) 162 | } 163 | } 164 | 165 | fn decode_socket_addr(src: &[u8]) -> Option { 166 | if src.len() == SOCKET_ADDR_V4_LEN { 167 | let addr: [u8; 4] = src.get(..4)?.try_into().ok()?; 168 | let addr = Ipv4Addr::from(addr); 169 | let port = u16::from_be_bytes(src.get(4..)?.try_into().ok()?); 170 | Some((addr, port).into()) 171 | } else if src.len() == SOCKET_ADDR_V6_LEN { 172 | let addr: [u8; 16] = 
src.get(..16)?.try_into().ok()?; 173 | let addr = Ipv6Addr::from(addr); 174 | let port = u16::from_be_bytes(src.get(16..)?.try_into().ok()?); 175 | Some((addr, port).into()) 176 | } else { 177 | None 178 | } 179 | } 180 | 181 | // TODO: consider returning `ArrayVec` to avoid lot of small allocations. 182 | fn encode_socket_addr(addr: &SocketAddr) -> Vec { 183 | let mut buffer = match addr { 184 | SocketAddr::V4(addr) => { 185 | let mut buffer = Vec::with_capacity(6); 186 | buffer.extend(addr.ip().octets().as_ref()); 187 | buffer 188 | } 189 | SocketAddr::V6(addr) => { 190 | let mut buffer = Vec::with_capacity(18); 191 | buffer.extend(addr.ip().octets().as_ref()); 192 | buffer 193 | } 194 | }; 195 | 196 | buffer.extend(addr.port().to_be_bytes().as_ref()); 197 | buffer 198 | } 199 | 200 | #[cfg(test)] 201 | mod tests { 202 | use crate::{info_hash::NodeId, node::NodeHandle}; 203 | use serde::{Deserialize, Serialize}; 204 | use std::{ 205 | fmt::Debug, 206 | net::{Ipv4Addr, Ipv6Addr, SocketAddr}, 207 | }; 208 | 209 | #[test] 210 | fn encode_decode_values() { 211 | #[derive(Debug, Eq, PartialEq, Serialize, Deserialize)] 212 | #[serde(transparent)] 213 | struct Wrapper { 214 | #[serde(with = "super::values")] 215 | values: Vec, 216 | } 217 | 218 | // empty 219 | encode_decode(&Wrapper { values: Vec::new() }, b"le"); 220 | // one v4 221 | encode_decode( 222 | &Wrapper { 223 | values: vec![(Ipv4Addr::new(127, 0, 0, 1), 6789).into()], 224 | }, 225 | &[b'l', b'6', b':', 127, 0, 0, 1, 26, 133, b'e'], 226 | ); 227 | // two v4 228 | encode_decode( 229 | &Wrapper { 230 | values: vec![ 231 | (Ipv4Addr::new(127, 0, 0, 1), 6789).into(), 232 | (Ipv4Addr::new(127, 0, 0, 2), 1234).into(), 233 | ], 234 | }, 235 | &[ 236 | b'l', b'6', b':', 127, 0, 0, 1, 26, 133, b'6', b':', 127, 0, 0, 2, 4, 210, b'e', 237 | ], 238 | ); 239 | // one v6 240 | encode_decode( 241 | &Wrapper { 242 | values: vec![( 243 | Ipv6Addr::new( 244 | 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0001, 245 
| ), 246 | 6789, 247 | ) 248 | .into()], 249 | }, 250 | &[ 251 | b'l', b'1', b'8', b':', 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 26, 133, 252 | b'e', 253 | ], 254 | ); 255 | // two v6 256 | encode_decode( 257 | &Wrapper { 258 | values: vec![ 259 | ( 260 | Ipv6Addr::new( 261 | 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0001, 262 | ), 263 | 6789, 264 | ) 265 | .into(), 266 | ( 267 | Ipv6Addr::new( 268 | 0x2001, 0x0db8, 0x85a3, 0x0000, 0x0000, 0x8a2e, 0x0370, 0x7334, 269 | ), 270 | 1234, 271 | ) 272 | .into(), 273 | ], 274 | }, 275 | &[ 276 | b'l', b'1', b'8', b':', 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 26, 133, 277 | b'1', b'8', b':', 0x20, 0x01, 0x0d, 0xb8, 0x85, 0xa3, 0x00, 0x00, 0x00, 0x00, 0x8a, 278 | 0x2e, 0x03, 0x70, 0x73, 0x34, 4, 210, b'e', 279 | ], 280 | ); 281 | // hybrid (v4 + v6) 282 | encode_decode( 283 | &Wrapper { 284 | values: vec![ 285 | (Ipv4Addr::new(127, 0, 0, 1), 6789).into(), 286 | ( 287 | Ipv6Addr::new( 288 | 0x2001, 0x0db8, 0x85a3, 0x0000, 0x0000, 0x8a2e, 0x0370, 0x7334, 289 | ), 290 | 1234, 291 | ) 292 | .into(), 293 | ], 294 | }, 295 | &[ 296 | b'l', b'6', b':', 127, 0, 0, 1, 26, 133, b'1', b'8', b':', 0x20, 0x01, 0x0d, 0xb8, 297 | 0x85, 0xa3, 0x00, 0x00, 0x00, 0x00, 0x8a, 0x2e, 0x03, 0x70, 0x73, 0x34, 4, 210, 298 | b'e', 299 | ], 300 | ); 301 | } 302 | 303 | #[test] 304 | fn encode_decode_nodes_v4() { 305 | #[derive(Debug, Eq, PartialEq, Serialize, Deserialize)] 306 | #[serde(transparent)] 307 | struct Wrapper { 308 | #[serde(with = "super::nodes_v4")] 309 | nodes: Vec, 310 | } 311 | 312 | encode_decode(&Wrapper { nodes: Vec::new() }, b"0:"); 313 | encode_decode( 314 | &Wrapper { 315 | nodes: vec![NodeHandle { 316 | id: NodeId::from(*b"0123456789abcdefghij"), 317 | addr: (Ipv4Addr::new(127, 0, 0, 1), 6789).into(), 318 | }], 319 | }, 320 | &[ 321 | b'2', b'6', b':', b'0', b'1', b'2', b'3', b'4', b'5', b'6', b'7', b'8', b'9', b'a', 322 | b'b', b'c', b'd', b'e', b'f', b'g', b'h', b'i', b'j', 127, 0, 0, 1, 26, 
133, 323 | ], 324 | ); 325 | encode_decode( 326 | &Wrapper { 327 | nodes: vec![ 328 | NodeHandle { 329 | id: NodeId::from(*b"0123456789abcdefghij"), 330 | addr: (Ipv4Addr::new(127, 0, 0, 1), 6789).into(), 331 | }, 332 | NodeHandle { 333 | id: NodeId::from(*b"klmnopqrstuvwxyz0123"), 334 | addr: (Ipv4Addr::new(127, 0, 0, 2), 1234).into(), 335 | }, 336 | ], 337 | }, 338 | &[ 339 | b'5', b'2', b':', b'0', b'1', b'2', b'3', b'4', b'5', b'6', b'7', b'8', b'9', b'a', 340 | b'b', b'c', b'd', b'e', b'f', b'g', b'h', b'i', b'j', 127, 0, 0, 1, 26, 133, b'k', 341 | b'l', b'm', b'n', b'o', b'p', b'q', b'r', b's', b't', b'u', b'v', b'w', b'x', b'y', 342 | b'z', b'0', b'1', b'2', b'3', 127, 0, 0, 2, 4, 210, 343 | ], 344 | ); 345 | } 346 | 347 | #[test] 348 | fn encode_decode_nodes_v6() { 349 | #[derive(Debug, Eq, PartialEq, Serialize, Deserialize)] 350 | #[serde(transparent)] 351 | struct Wrapper { 352 | #[serde(with = "super::nodes_v6")] 353 | nodes: Vec, 354 | } 355 | 356 | encode_decode(&Wrapper { nodes: Vec::new() }, b"0:"); 357 | encode_decode( 358 | &Wrapper { 359 | nodes: vec![NodeHandle { 360 | id: NodeId::from(*b"0123456789abcdefghij"), 361 | addr: ( 362 | Ipv6Addr::new( 363 | 0x2001, 0x0db8, 0x85a3, 0x0000, 0x0000, 0x8a2e, 0x0370, 0x7334, 364 | ), 365 | 6789, 366 | ) 367 | .into(), 368 | }], 369 | }, 370 | &[ 371 | b'3', b'8', b':', b'0', b'1', b'2', b'3', b'4', b'5', b'6', b'7', b'8', b'9', b'a', 372 | b'b', b'c', b'd', b'e', b'f', b'g', b'h', b'i', b'j', 0x20, 0x01, 0x0d, 0xb8, 0x85, 373 | 0xa3, 0x00, 0x00, 0x00, 0x00, 0x8a, 0x2e, 0x03, 0x70, 0x73, 0x34, 26, 133, 374 | ], 375 | ); 376 | } 377 | 378 | #[test] 379 | fn attempt_to_encode_v4_nodes_as_v6() { 380 | #[derive(Debug, Eq, PartialEq, Serialize, Deserialize)] 381 | #[serde(transparent)] 382 | struct Wrapper { 383 | #[serde(with = "super::nodes_v6")] 384 | nodes: Vec, 385 | } 386 | 387 | let value = Wrapper { 388 | nodes: vec![NodeHandle { 389 | id: NodeId::from(*b"0123456789abcdefghij"), 390 | addr: 
(Ipv4Addr::new(127, 0, 0, 1), 1234).into(), 391 | }], 392 | }; 393 | 394 | assert!(serde_bencode::to_bytes(&value).is_err()); 395 | } 396 | 397 | #[test] 398 | fn attempt_to_encode_v6_nodes_as_v4() { 399 | #[derive(Debug, Eq, PartialEq, Serialize, Deserialize)] 400 | #[serde(transparent)] 401 | struct Wrapper { 402 | #[serde(with = "super::nodes_v4")] 403 | nodes: Vec, 404 | } 405 | 406 | let value = Wrapper { 407 | nodes: vec![NodeHandle { 408 | id: NodeId::from(*b"0123456789abcdefghij"), 409 | addr: ( 410 | Ipv6Addr::new( 411 | 0x2001, 0x0db8, 0x85a3, 0x0000, 0x0000, 0x8a2e, 0x0370, 0x7334, 412 | ), 413 | 1234, 414 | ) 415 | .into(), 416 | }], 417 | }; 418 | 419 | assert!(serde_bencode::to_bytes(&value).is_err()); 420 | } 421 | 422 | fn encode_decode<'de, T>(value: &T, expected_encoded: &'de [u8]) 423 | where 424 | T: Serialize + Deserialize<'de> + Eq + Debug, 425 | { 426 | let actual_encoded = serde_bencode::to_bytes(value).unwrap(); 427 | assert_eq!(actual_encoded, expected_encoded); 428 | 429 | let actual_decoded: T = serde_bencode::from_bytes(expected_encoded).unwrap(); 430 | assert_eq!(actual_decoded, *value); 431 | } 432 | } 433 | -------------------------------------------------------------------------------- /src/handler.rs: -------------------------------------------------------------------------------- 1 | use crate::action::{ 2 | bootstrap::{self, TableBootstrap}, 3 | lookup::TableLookup, 4 | refresh::TableRefresh, 5 | ActionStatus, IpVersion, OneshotTask, ScheduledTaskCheck, StartLookup, State, WorkerError, 6 | }; 7 | use crate::{ 8 | info_hash::{InfoHash, NodeId}, 9 | message::{error_code, Error, Message, MessageBody, Request, Response, Want}, 10 | node::{Node, NodeHandle}, 11 | socket::Socket, 12 | storage::AnnounceStorage, 13 | table::RoutingTable, 14 | timer::Timer, 15 | token::{Token, TokenStore}, 16 | transaction::{AIDGenerator, ActionID, TransactionID}, 17 | }; 18 | use futures_util::StreamExt; 19 | use std::{ 20 | collections::{HashMap, 
HashSet},
    convert::AsRef,
    net::SocketAddr,
    sync::{Arc, Mutex},
};
use tokio::{
    select,
    sync::{mpsc, oneshot},
};

/// Storage for our EventLoop to invoke actions upon.
pub(crate) struct DhtHandler {
    this_node_id: NodeId,
    // Cleared by `shutdown` to make `run` return.
    running: bool,
    // NOTE(review): generic parameters throughout this struct were lost in the
    // dump and have been reconstructed from usage - confirm against the
    // declarations in `action`, `timer` and `socket`.
    command_rx: mpsc::UnboundedReceiver<OneshotTask>,
    timer: Timer<ScheduledTaskCheck>,
    read_only: bool,
    announce_port: Option<u16>,
    socket: Arc<Socket>,
    token_store: TokenStore,
    aid_generator: AIDGenerator,
    routing_table: Arc<Mutex<RoutingTable>>,
    active_stores: AnnounceStorage,
    bootstrap: TableBootstrap,

    // Waiters registered via `CheckBootstrap`, keyed by an increasing id.
    next_bootstrap_txs_id: u64,
    bootstrap_txs: HashMap<u64, oneshot::Sender<()>>,

    // TableRefresh action.
    refresh: TableRefresh,
    // Ongoing TableLookups.
    lookups: HashMap<ActionID, TableLookup>,
}

impl DhtHandler {
    pub fn new(
        this_node_id: NodeId,
        socket: Socket,
        read_only: bool,
        // NOTE(review): the element types of `routers` and `nodes` were lost in
        // the dump; reconstructed - confirm against `TableBootstrap::new`.
        routers: HashSet<String>,
        nodes: HashSet<SocketAddr>,
        announce_port: Option<u16>,
        command_rx: mpsc::UnboundedReceiver<OneshotTask>,
    ) -> Self {
        let socket = Arc::new(socket);
        let table = Arc::new(Mutex::new(RoutingTable::new(this_node_id)));

        let mut aid_generator = AIDGenerator::new();

        // The refresh task to execute after the bootstrap
        let refresh = TableRefresh::new(aid_generator.generate(), table.clone());

        let bootstrap = TableBootstrap::new(
            socket.clone(),
            table.clone(),
            aid_generator.generate(),
            routers,
            nodes,
        );

        Self {
            this_node_id,
            running: true,
            command_rx,
            timer: Timer::new(),
            read_only,
            announce_port,
            socket,
            token_store: TokenStore::new(),
            aid_generator,
            routing_table: table,
            active_stores: AnnounceStorage::new(),
            bootstrap,
            next_bootstrap_txs_id: 0,
            bootstrap_txs: HashMap::new(),
            refresh,
            lookups: HashMap::new(),
        }
    }

    /// IP version (v4/v6) of this handler's socket.
    fn ip_version(&self) -> IpVersion {
        self.socket.ip_version()
    }

    /// Drives the event loop until shutdown.
    pub async fn run(mut self) {
        while self.running {
            self.run_once().await
        }
    }

    /// Waits for and dispatches exactly one event (timer, command, bootstrap
    /// state change or incoming packet).
    async fn run_once(&mut self) {
        select! {
            token = self.timer.next(), if !self.timer.is_empty() => {
                // `unwrap` is OK because we checked the timer is non-empty, so it should never
                // return `None`.
                self.handle_timeout(token.unwrap()).await
            }
            command = self.command_rx.recv() => {
                match command {
                    Some(command) => self.handle_command(command).await,
                    // All command senders dropped - stop the loop.
                    None => self.shutdown(),
                }
            }
            result = self.bootstrap.state_rx.changed() => {
                assert!(result.is_ok());
                if self.is_bootstrapped() {
                    self.handle_bootstrap_success().await;
                }
            }
            message = self.socket.recv() => {
                match message {
                    Ok((message, addr)) => if let Err(error) = self.handle_incoming(message, addr).await {
                        log::debug!("{}: Failed to handle incoming message: {} from:{addr:?}", self.ip_version(), error);
                    }
                    Err(error) => log::warn!("{}: Failed to receive incoming message: {}", self.ip_version(), error),
                }
            }
        }
    }

    fn is_bootstrapped(&self) -> bool {
        *self.bootstrap.state_rx.borrow() == bootstrap::State::Bootstrapped
    }

    async fn handle_command(&mut self, task: OneshotTask) {
        match task {
            OneshotTask::StartBootstrap() => {
                self.handle_start_bootstrap();
            }
            OneshotTask::CheckBootstrap(tx) => {
                self.handle_check_bootstrap(tx);
            }
            OneshotTask::StartLookup(lookup) => {
                self.handle_start_lookup(lookup).await;
            }
            OneshotTask::GetLocalAddr(tx) => self.handle_get_local_addr(tx),
            OneshotTask::GetState(tx) => self.handle_get_state(tx),
            OneshotTask::LoadContacts(tx) => self.handle_load_contacts(tx),
        }
    }

    async fn handle_timeout(&mut self, token: ScheduledTaskCheck) {
        match token {
            ScheduledTaskCheck::TableRefresh => {
                self.handle_check_table_refresh().await;
            }
            ScheduledTaskCheck::LookupTimeout(trans_id) => {
                self.handle_check_lookup_timeout(trans_id).await;
            }
            ScheduledTaskCheck::LookupEndGame(trans_id) => {
                self.handle_check_lookup_endgame(trans_id).await;
            }
        }
    }

    /// Marks the given node as having requested something from us, if we track it.
    fn note_remote_request(&self, node: NodeHandle) {
        if let Some(n) = self.routing_table.lock().unwrap().find_node_mut(&node) {
            n.remote_request()
        }
    }

    async fn handle_incoming(
        &mut self,
        message: Message,
        addr: SocketAddr,
    ) -> Result<(), WorkerError> {
        // Do not process requests if we are read only
        // TODO: Add read only flags to messages we send it we are read only!
        // Also, check for read only flags on responses we get before adding nodes
        // to our RoutingTable.
        if self.read_only && matches!(message.body, MessageBody::Request(_)) {
            return Ok(());
        }

        log::trace!("{}: Received {:?}", self.ip_version(), message);

        // Process the given message
        match message.body {
            MessageBody::Request(Request::Ping(p)) => {
                self.note_remote_request(NodeHandle::new(p.id, addr));

                let response = Message {
                    transaction_id: message.transaction_id,
                    body: MessageBody::Response(Response {
                        id: self.this_node_id,
                        values: vec![],
                        nodes_v4: vec![],
                        nodes_v6: vec![],
                        token: None,
                    }),
                };

                self.socket.send(&response, addr).await?
            }
            MessageBody::Request(Request::FindNode(f)) => {
                self.note_remote_request(NodeHandle::new(f.id, addr));

                let (nodes_v4, nodes_v6) = self.find_closest_nodes(f.target, f.want)?;

                let response = Message {
                    transaction_id: message.transaction_id,
                    body: MessageBody::Response(Response {
                        id: self.this_node_id,
                        values: vec![],
                        nodes_v4,
                        nodes_v6,
                        token: None,
                    }),
                };

                self.socket.send(&response, addr).await?
            }
            MessageBody::Request(Request::GetPeers(g)) => {
                self.note_remote_request(NodeHandle::new(g.id, addr));

                // TODO: Check what the maximum number of values we can give without overflowing a udp packet
                // Also, if we arent going to give all of the contacts, we may want to shuffle which ones we give
                //
                // According to the spec (BEP32), `values` should contain only addresses of the
                // same family as the address the request came from. The `want` field affects only
                // the `nodes` and `nodes6` fields, not the `values` field.
                let values: Vec<_> = self
                    .active_stores
                    .find_items(&g.info_hash)
                    .filter(|stored| stored.is_ipv4() == addr.is_ipv4())
                    .collect();

                // Grab the closest nodes
                let (nodes_v4, nodes_v6) = self.find_closest_nodes(g.info_hash, g.want)?;
                let token = self.token_store.checkout(addr.ip());

                let response = Message {
                    transaction_id: message.transaction_id,
                    body: MessageBody::Response(Response {
                        id: self.this_node_id,
                        values,
                        nodes_v4,
                        nodes_v6,
                        token: Some(token.as_ref().to_vec()),
                    }),
                };

                self.socket.send(&response, addr).await?
            }
            MessageBody::Request(Request::AnnouncePeer(a)) => {
                self.note_remote_request(NodeHandle::new(a.id, addr));

                // Validate the token
                let is_valid = Token::new(&a.token)
                    .map(|t| self.token_store.checkin(addr.ip(), t))
                    .unwrap_or(false);

                // Create a socket address based on the implied/explicit port number
                let connect_addr = match a.port {
                    None => addr,
                    Some(port) => {
                        let mut addr = addr;
                        addr.set_port(port);
                        addr
                    }
                };

                // Resolve type of response we are going to send
                let response = if !is_valid {
                    // Node gave us an invalid token
                    log::debug!(
                        "{}: Remote node sent us an invalid token for an AnnounceRequest",
                        self.ip_version()
                    );
                    Message {
                        transaction_id: message.transaction_id,
                        body: MessageBody::Error(Error {
                            code: error_code::PROTOCOL_ERROR,
                            message: "received an invalid token".to_owned(),
                        }),
                    }
                } else if self.active_stores.add_item(a.info_hash, connect_addr) {
                    // Node successfully stored the value with us, send an announce response
                    Message {
                        transaction_id: message.transaction_id,
                        body: MessageBody::Response(Response {
                            id: self.this_node_id,
                            values: vec![],
                            nodes_v4: vec![],
                            nodes_v6: vec![],
                            token: None,
                        }),
                    }
                } else {
                    // Node unsuccessfully stored the value with us, send them an error message
                    // TODO: Spec doesnt actually say what error message to send, or even if we should send one...
                    log::warn!(
                        "{}: AnnounceStorage failed to store contact information because it is full", self.ip_version()
                    );

                    Message {
                        transaction_id: message.transaction_id,
                        body: MessageBody::Error(Error {
                            code: error_code::SERVER_ERROR,
                            message: "announce storage is full".to_owned(),
                        }),
                    }
                };

                self.socket.send(&response, addr).await?
            }
            MessageBody::Response(rsp) => {
                let trans_id = TransactionID::from_bytes(&message.transaction_id)
                    .ok_or(WorkerError::InvalidTransactionId)?;
                self.handle_incoming_response(trans_id, addr, rsp).await?;
            }
            MessageBody::Error(_) => (),
        }

        Ok(())
    }

    async fn handle_incoming_response(
        &mut self,
        trans_id: TransactionID,
        addr: SocketAddr,
        rsp: Response,
    ) -> Result<(), WorkerError> {
        // A node that answered us is, by definition, good.
        let node = Node::as_good(rsp.id, addr);

        // Only consider the node list matching our own address family.
        let nodes = match self.socket.ip_version() {
            IpVersion::V4 => &rsp.nodes_v4,
            IpVersion::V6 => &rsp.nodes_v6,
        };

        if let Some(lookup) = self.lookups.get_mut(&trans_id.action_id()) {
            self.routing_table
                .lock()
                .unwrap()
                .add_nodes(node.clone(), nodes);

            match lookup
                .recv_response(node, &trans_id, rsp, &self.socket, &mut self.timer)
                .await
            {
                ActionStatus::Ongoing => (),
                ActionStatus::Completed => self.handle_lookup_completed(trans_id).await,
            }
        } else if self.refresh.action_id() == trans_id.action_id() {
            self.routing_table.lock().unwrap().add_nodes(node, nodes);
        } else {
            return Err(WorkerError::UnsolicitedResponse);
        }

        Ok(())
    }

    fn handle_start_bootstrap(&mut self) {
        self.bootstrap.start();
    }

    fn handle_check_bootstrap(&mut self, tx: oneshot::Sender<()>) {
        if self.is_bootstrapped() {
            // Already done - answer immediately.
            tx.send(()).unwrap_or(())
        } else {
            // Park the waiter until `handle_bootstrap_success` fires.
            let id = self.next_bootstrap_txs_id;
            self.next_bootstrap_txs_id += 1;
            self.bootstrap_txs.insert(id, tx);
        }
    }

    async fn handle_bootstrap_success(&mut self) {
        // Send notification that the bootstrap has completed.
        for (_, tx) in self.bootstrap_txs.drain() {
            tx.send(()).unwrap_or(())
        }

        // Start the refresh action.
        self.handle_check_table_refresh().await;
    }

    async fn handle_start_lookup(&mut self, lookup: StartLookup) {
        // Start the lookup right now if not bootstrapping
        let mid_generator = self.aid_generator.generate();
        let action_id = mid_generator.action_id();

        let mut lookup = TableLookup::new(
            lookup.info_hash,
            lookup.announce,
            lookup.tx,
            mid_generator,
            self.routing_table.clone(),
            &self.socket,
            &mut self.timer,
        )
        .await;

        if lookup.completed() {
            // Nothing left to wait for - finish immediately.
            lookup.recv_finished(self.announce_port, &self.socket).await;
        } else {
            self.lookups.insert(action_id, lookup);
        }
    }

    fn handle_get_state(&self, tx: oneshot::Sender<State>) {
        let table = self.routing_table.lock().unwrap();
        tx.send(State {
            is_running: self.running,
            bootstrapped: self.is_bootstrapped(),
            good_node_count: table.num_good_nodes(),
            questionable_node_count: table.num_questionable_nodes(),
            bucket_count: table.buckets().count(),
        })
        .unwrap_or(())
    }

    // NOTE(review): the sender's payload type was lost in the dump;
    // reconstructed from `Socket::local_addr` usage - confirm.
    fn handle_get_local_addr(&self, tx: oneshot::Sender<SocketAddr>) {
        tx.send(self.socket.local_addr()).unwrap_or(())
    }

    async fn handle_check_lookup_timeout(&mut self, trans_id: TransactionID) {
        let lookup = match self.lookups.get_mut(&trans_id.action_id()) {
            Some(lookup) => lookup,
            None => {
                log::error!(
                    "{}: Resolved a TransactionID to a check table lookup but no action found",
                    self.ip_version()
                );
                return;
            }
        };

        let lookup_status = lookup
            .recv_timeout(&trans_id, &self.socket, &mut self.timer)
            .await;

        match lookup_status {
            ActionStatus::Ongoing => (),
            ActionStatus::Completed => self.handle_lookup_completed(trans_id).await,
        }
    }

    async fn handle_check_lookup_endgame(&mut self, trans_id: TransactionID) {
        self.handle_lookup_completed(trans_id).await
    }

    async fn handle_lookup_completed(&mut self, trans_id: TransactionID) {
        match self.lookups.remove(&trans_id.action_id()) {
            Some(mut lookup) => lookup.recv_finished(self.announce_port, &self.socket).await,
            None => log::error!("{}: Lookup not found", self.ip_version()),
        }
    }

    async fn handle_check_table_refresh(&mut self) {
        self.refresh
            .continue_refresh(&self.socket, &mut self.timer)
            .await
    }

    fn shutdown(&mut self) {
        self.running = false;
    }

    /// Picks up to 8 nodes per requested address family closest to `target`.
    fn find_closest_nodes(
        &self,
        target: InfoHash,
        want: Option<Want>,
    ) -> Result<(Vec<NodeHandle>, Vec<NodeHandle>), WorkerError> {
        // When the requester did not specify, answer with our own address family.
        let want = want.unwrap_or(match self.socket.ip_version() {
            IpVersion::V4 => Want::V4,
            IpVersion::V6 => Want::V6,
        });

        let table = self.routing_table.lock().unwrap();

        let pick = |v4: bool| -> Vec<NodeHandle> {
            table
                .closest_nodes(target)
                .filter(|node| node.addr().is_ipv4() == v4)
                .take(8)
                .map(|node| *node.handle())
                .collect()
        };

        let nodes_v4 = if matches!(want, Want::V4 | Want::Both) {
            pick(true)
        } else {
            vec![]
        };

        let nodes_v6 = if matches!(want, Want::V6 | Want::Both) {
            pick(false)
        } else {
            vec![]
        };

        Ok((nodes_v4, nodes_v6))
    }

    // NOTE(review): the sender's payload element types were lost in the dump;
    // reconstructed - confirm against `RoutingTable::load_contacts`.
    fn handle_load_contacts(
        &self,
        tx: oneshot::Sender<(HashSet<SocketAddr>, HashSet<SocketAddr>)>,
    ) {
        tx.send(self.routing_table.lock().unwrap().load_contacts())
            .unwrap_or(());
    }
}

// ----------------------------------------------------------------------------//
/src/info_hash.rs: -------------------------------------------------------------------------------- 1 | use rand::{ 2 | distributions::{Distribution, Standard}, 3 | Rng, 4 | }; 5 | use serde::{Deserialize, Serialize}; 6 | use sha1::{Digest, Sha1}; 7 | use std::{ 8 | convert::{TryFrom, TryInto}, 9 | fmt, 10 | net::IpAddr, 11 | ops::BitXor, 12 | }; 13 | use thiserror::Error; 14 | 15 | /// Length of `InfoHash` in bytes. 16 | pub const INFO_HASH_LEN: usize = 20; 17 | 18 | /// 20-byte long identifier of nodes and objects on the DHT 19 | #[derive(Copy, Clone, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)] 20 | #[repr(transparent)] 21 | pub struct InfoHash(#[serde(with = "byte_array")] [u8; INFO_HASH_LEN]); 22 | 23 | impl InfoHash { 24 | /// Generate InfoHash from the IP address as described by BEP42 25 | /// https://www.bittorrent.org/beps/bep_0042.html 26 | pub fn from_ip(ip: IpAddr) -> Self { 27 | let v4_mask: [u8; 8] = [0x03, 0x0f, 0x3f, 0xff, 0, 0, 0, 0]; 28 | let v6_mask: [u8; 8] = [0x01, 0x03, 0x07, 0x0f, 0x1f, 0x3f, 0x7f, 0xff]; 29 | 30 | let (mut ip, num_octets, mask) = { 31 | let mut array = [0; 8]; 32 | let (mask, num_octets) = match ip { 33 | IpAddr::V4(ip) => { 34 | let num = 4; 35 | let octets = ip.octets(); 36 | array[..num].copy_from_slice(&octets[..num]); 37 | (v4_mask, num) 38 | } 39 | IpAddr::V6(ip) => { 40 | let num = 8; 41 | let octets = ip.octets(); 42 | array[..num].copy_from_slice(&octets[..num]); 43 | (v6_mask, num) 44 | } 45 | }; 46 | (array, num_octets, mask) 47 | }; 48 | 49 | for i in 0..num_octets { 50 | ip[i] &= mask[i]; 51 | } 52 | 53 | let rand = rand::random::(); 54 | ip[0] |= (rand & 0x7) << 5; 55 | 56 | let crc = crc32c::crc32c_append(0, &ip[0..num_octets]); 57 | 58 | let mut node_id: [u8; INFO_HASH_LEN] = [0; INFO_HASH_LEN]; 59 | 60 | node_id[0] = (crc >> 24).to_le_bytes()[0]; 61 | node_id[1] = (crc >> 16).to_le_bytes()[0]; 62 | node_id[2] = (crc >> 8).to_le_bytes()[0] & 0xf8 | (rand::random::() & 0x7); 63 | 64 | for 
byte in &mut node_id[3..19] { 65 | *byte = rand::random(); 66 | } 67 | 68 | node_id[19] = rand; 69 | 70 | Self(node_id) 71 | } 72 | 73 | /// Create a DhtId by hashing the given bytes using SHA-1. 74 | pub fn sha1(bytes: &[u8]) -> Self { 75 | let hash = Sha1::digest(bytes); 76 | Self(hash.into()) 77 | } 78 | 79 | /// Flip the bit at the given index. 80 | /// 81 | /// # Panics 82 | /// 83 | /// Panics if index is out of bounds (>= 160) 84 | pub(crate) fn flip_bit(self, index: usize) -> Self { 85 | let mut bytes = self.0; 86 | let (byte_index, bit_index) = (index / 8, index % 8); 87 | 88 | let actual_bit_index = 7 - bit_index; 89 | bytes[byte_index] ^= 1 << actual_bit_index; 90 | 91 | bytes.into() 92 | } 93 | 94 | /// Number of leading zero bits. 95 | pub(crate) fn leading_zeros(&self) -> u32 { 96 | let mut bits = 0; 97 | 98 | for byte in self.0 { 99 | bits += byte.leading_zeros(); 100 | 101 | if byte != 0 { 102 | break; 103 | } 104 | } 105 | 106 | bits 107 | } 108 | } 109 | 110 | impl AsRef<[u8]> for InfoHash { 111 | fn as_ref(&self) -> &[u8] { 112 | &self.0 113 | } 114 | } 115 | 116 | impl From for [u8; INFO_HASH_LEN] { 117 | fn from(hash: InfoHash) -> [u8; INFO_HASH_LEN] { 118 | hash.0 119 | } 120 | } 121 | 122 | impl From<[u8; INFO_HASH_LEN]> for InfoHash { 123 | fn from(hash: [u8; INFO_HASH_LEN]) -> InfoHash { 124 | Self(hash) 125 | } 126 | } 127 | 128 | #[derive(Debug, Error)] 129 | #[error("invalid id length")] 130 | pub struct LengthError; 131 | 132 | impl<'a> TryFrom<&'a [u8]> for InfoHash { 133 | type Error = LengthError; 134 | 135 | fn try_from(slice: &'a [u8]) -> Result { 136 | Ok(Self(slice.try_into().map_err(|_| LengthError)?)) 137 | } 138 | } 139 | 140 | impl BitXor for InfoHash { 141 | type Output = Self; 142 | 143 | fn bitxor(mut self, rhs: Self) -> Self { 144 | for (src, dst) in rhs.0.iter().zip(self.0.iter_mut()) { 145 | *dst ^= *src; 146 | } 147 | 148 | self 149 | } 150 | } 151 | 152 | impl Distribution for Standard { 153 | fn sample(&self, rng: 
&mut R) -> InfoHash { 154 | InfoHash(rng.gen()) 155 | } 156 | } 157 | 158 | impl fmt::LowerHex for InfoHash { 159 | fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { 160 | for b in &self.0 { 161 | write!(f, "{b:02x}")?; 162 | } 163 | 164 | Ok(()) 165 | } 166 | } 167 | 168 | impl fmt::Debug for InfoHash { 169 | fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { 170 | write!(f, "{self:x}") 171 | } 172 | } 173 | 174 | mod byte_array { 175 | use super::INFO_HASH_LEN; 176 | use serde::{ 177 | de::{Deserialize, Deserializer, Error}, 178 | ser::{Serialize, Serializer}, 179 | }; 180 | use serde_bytes::{ByteBuf, Bytes}; 181 | use std::convert::TryInto; 182 | 183 | pub(super) fn serialize( 184 | bytes: &[u8; INFO_HASH_LEN], 185 | s: S, 186 | ) -> Result { 187 | Bytes::new(bytes.as_ref()).serialize(s) 188 | } 189 | 190 | pub(super) fn deserialize<'de, D: Deserializer<'de>>( 191 | d: D, 192 | ) -> Result<[u8; INFO_HASH_LEN], D::Error> { 193 | let buf = ByteBuf::deserialize(d)?; 194 | let buf = buf.into_vec(); 195 | let len = buf.len(); 196 | 197 | buf.try_into().map_err(|_| { 198 | let expected = format!("{INFO_HASH_LEN}"); 199 | D::Error::invalid_length(len, &expected.as_ref()) 200 | }) 201 | } 202 | } 203 | 204 | // ----------------------------------------------------------------------------// 205 | 206 | /// Bittorrent `NodeId`. 207 | pub type NodeId = InfoHash; 208 | 209 | /// Length of a `NodeId`. 
210 | pub const NODE_ID_LEN: usize = INFO_HASH_LEN; 211 | 212 | // ----------------------------------------------------------------------------// 213 | 214 | #[cfg(test)] 215 | mod tests { 216 | use super::*; 217 | 218 | #[test] 219 | fn positive_no_leading_zeroes() { 220 | let zero_bits = InfoHash::from([0u8; INFO_HASH_LEN]); 221 | let one_bits = InfoHash::from([255u8; INFO_HASH_LEN]); 222 | 223 | let xor_hash = zero_bits ^ one_bits; 224 | 225 | assert_eq!(xor_hash.leading_zeros(), 0) 226 | } 227 | 228 | #[test] 229 | fn positive_all_leading_zeroes() { 230 | let first_one_bits = InfoHash::from([255u8; INFO_HASH_LEN]); 231 | let second_one_bits = InfoHash::from([255u8; INFO_HASH_LEN]); 232 | 233 | let xor_hash = first_one_bits ^ second_one_bits; 234 | 235 | assert_eq!(xor_hash.leading_zeros() as usize, INFO_HASH_LEN * 8); 236 | } 237 | 238 | #[test] 239 | fn positive_one_leading_zero() { 240 | let zero_bits = InfoHash::from([0u8; INFO_HASH_LEN]); 241 | 242 | let mut bytes = [255u8; INFO_HASH_LEN]; 243 | bytes[0] = 127; 244 | let mostly_one_bits = InfoHash::from(bytes); 245 | 246 | let xor_hash = zero_bits ^ mostly_one_bits; 247 | 248 | assert_eq!(xor_hash.leading_zeros(), 1); 249 | } 250 | 251 | #[test] 252 | fn positive_one_trailing_zero() { 253 | let zero_bits = InfoHash::from([0u8; INFO_HASH_LEN]); 254 | 255 | let mut bytes = [255u8; INFO_HASH_LEN]; 256 | bytes[super::INFO_HASH_LEN - 1] = 254; 257 | let mostly_zero_bits = InfoHash::from(bytes); 258 | 259 | let xor_hash = zero_bits ^ mostly_zero_bits; 260 | 261 | assert_eq!(xor_hash.leading_zeros(), 0); 262 | } 263 | } 264 | -------------------------------------------------------------------------------- /src/lib.rs: -------------------------------------------------------------------------------- 1 | //! Implementation of the Bittorrent Mainline Distributed Hash Table. 
2 | 3 | // Mainline DHT extensions supported on behalf of libtorrent: 4 | // - Always send 'nodes' on a get_peers response even if 'values' is present 5 | // - Unrecognized requests which contain either an 'info_hash' or 'target' arguments are interpreted as 'find_node' TODO 6 | // - Client identification will be present in all outgoing messages in the form of the 'v' key TODO 7 | // const CLIENT_IDENTIFICATION: &'static [u8] = &[b'B', b'I', b'P', 0, 1]; 8 | 9 | // TODO: The Vuze dht operates over a protocol that is different than the mainline dht. 10 | // It would be possible to create a dht client that can work over both dhts simultaneously, 11 | // this would require essentially a completely separate routing table of course and so it 12 | // might make sense to make this distinction available to the user and allow them to startup 13 | // two dhts using the different protocols on their own. 14 | // const VUZE_DHT: (&'static str, u16) = ("dht.aelitis.com", 6881); 15 | 16 | pub mod router; 17 | 18 | mod action; 19 | mod bucket; 20 | mod compact; 21 | mod handler; 22 | mod info_hash; 23 | mod mainline_dht; 24 | pub mod message; 25 | mod node; 26 | mod socket; 27 | mod storage; 28 | mod table; 29 | #[cfg(test)] 30 | mod test; 31 | mod time; 32 | mod timer; 33 | mod token; 34 | mod transaction; 35 | 36 | pub use crate::action::State; 37 | pub use crate::info_hash::{InfoHash, LengthError, NodeId, INFO_HASH_LEN}; 38 | pub use crate::mainline_dht::{DhtBuilder, MainlineDht}; 39 | 40 | pub type IpVersion = crate::action::IpVersion; 41 | 42 | use async_trait::async_trait; 43 | use std::{io, net::SocketAddr}; 44 | 45 | #[async_trait] 46 | pub trait SocketTrait { 47 | async fn send_to(&self, buf: &[u8], target: &SocketAddr) -> io::Result<()>; 48 | async fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)>; 49 | fn local_addr(&self) -> io::Result; 50 | } 51 | -------------------------------------------------------------------------------- 
/src/mainline_dht.rs: -------------------------------------------------------------------------------- 1 | use crate::{ 2 | action::{OneshotTask, StartLookup, State}, 3 | handler::DhtHandler, 4 | info_hash::{InfoHash, NodeId}, 5 | socket::Socket, 6 | SocketTrait, 7 | }; 8 | use futures_util::Stream; 9 | use std::{ 10 | collections::HashSet, 11 | io, 12 | net::SocketAddr, 13 | pin::Pin, 14 | task::{Context, Poll}, 15 | }; 16 | use tokio::{ 17 | sync::{mpsc, oneshot}, 18 | task, 19 | }; 20 | 21 | /// Maintains a Distributed Hash (Routing) Table. 22 | /// 23 | /// This type is cheaply cloneable where each clone refers to the same underlying DHT instance. This 24 | /// is useful to be able to issue DHT operations from multiple tasks/threads. 25 | /// 26 | /// # IPv6 27 | /// 28 | /// This implementation supports IPv6 as per [BEP32](https://www.bittorrent.org/beps/bep_0032.html). 29 | /// To enable dual-stack DHT (use both IPv4 and IPv6), one needs to create two separate 30 | /// `MainlineDht` instances, one bound to an IPv4 and the other to an IPv6 address. It is 31 | /// recommended that both instances use the same node id ([`DhtBuilder::set_node_id`]). Any lookup 32 | /// should then be performed on both instances and their results aggregated. 33 | #[derive(Clone)] 34 | pub struct MainlineDht { 35 | send: mpsc::UnboundedSender, 36 | } 37 | 38 | impl MainlineDht { 39 | /// Create a new DhtBuilder. 40 | pub fn builder() -> DhtBuilder { 41 | DhtBuilder { 42 | nodes: HashSet::new(), 43 | routers: HashSet::new(), 44 | read_only: true, 45 | announce_port: None, 46 | node_id: None, 47 | } 48 | } 49 | 50 | /// Start the MainlineDht with the given DhtBuilder. 51 | fn with_builder(builder: DhtBuilder, socket: Socket) -> Self { 52 | let (command_tx, command_rx) = mpsc::unbounded_channel(); 53 | 54 | // TODO: Utilize the security extension. 
55 | let node_id = builder.node_id.unwrap_or_else(rand::random); 56 | 57 | let handler = DhtHandler::new( 58 | node_id, 59 | socket, 60 | builder.read_only, 61 | builder.routers, 62 | builder.nodes, 63 | builder.announce_port, 64 | command_rx, 65 | ); 66 | 67 | if command_tx.send(OneshotTask::StartBootstrap()).is_err() { 68 | // `unreachable` is OK here because the corresponding receiver definitely exists at 69 | // this point inside `handler`. 70 | unreachable!() 71 | } 72 | 73 | task::spawn(handler.run()); 74 | 75 | Self { send: command_tx } 76 | } 77 | 78 | /// Get the state of the DHT state machine, can be used for debugging. 79 | pub async fn get_state(&self) -> Option { 80 | let (tx, rx) = oneshot::channel(); 81 | 82 | if self.send.send(OneshotTask::GetState(tx)).is_err() { 83 | None 84 | } else { 85 | rx.await.ok() 86 | } 87 | } 88 | 89 | /// Waits until the DHT bootstrap completes, or returns immediately if it already completed. 90 | /// Returns whether the bootstrap was successful. 91 | pub async fn bootstrapped(&self) -> bool { 92 | let (tx, rx) = oneshot::channel(); 93 | 94 | if self.send.send(OneshotTask::CheckBootstrap(tx)).is_err() { 95 | // handler has shut down, consider this as bootstrap failure. 96 | false 97 | } else { 98 | rx.await.is_ok() 99 | } 100 | } 101 | 102 | /// Perform a search for the given InfoHash with an optional announce on the closest nodes. 103 | /// 104 | /// 105 | /// Announcing will place your contact information in the DHT so others performing lookups 106 | /// for the InfoHash will be able to find your contact information and initiate a handshake. 107 | /// 108 | /// If the initial bootstrap has not finished, the search will be queued and executed once 109 | /// the bootstrap has completed. 
110 | pub fn search(&self, info_hash: InfoHash, announce: bool) -> SearchStream { 111 | let (tx, rx) = mpsc::unbounded_channel(); 112 | 113 | if self 114 | .send 115 | .send(OneshotTask::StartLookup(StartLookup { 116 | info_hash, 117 | announce, 118 | tx, 119 | })) 120 | .is_err() 121 | { 122 | log::error!("failed to start search - DhtHandler has shut down"); 123 | } 124 | 125 | SearchStream(rx) 126 | } 127 | 128 | /// Get the local address this DHT instance is bound to 129 | pub async fn local_addr(&self) -> io::Result { 130 | let (tx, rx) = oneshot::channel(); 131 | 132 | fn error() -> io::Error { 133 | io::Error::new(io::ErrorKind::Other, "DhtHandler has shut down") 134 | } 135 | 136 | self.send 137 | .send(OneshotTask::GetLocalAddr(tx)) 138 | .map_err(|_| error())?; 139 | 140 | rx.await.map_err(|_| error()) 141 | } 142 | 143 | /// Return IP:PORT pairs of "good" and "questionable" nodes from the routing table. 144 | pub async fn load_contacts(&self) -> io::Result<(HashSet, HashSet)> { 145 | let (tx, rx) = oneshot::channel(); 146 | 147 | fn error() -> io::Error { 148 | io::Error::new(io::ErrorKind::Other, "DhtHandler has shut down") 149 | } 150 | 151 | self.send 152 | .send(OneshotTask::LoadContacts(tx)) 153 | .map_err(|_| error())?; 154 | 155 | rx.await.map_err(|_| error()) 156 | } 157 | } 158 | 159 | /// Stream returned from [`MainlineDht::search()`] 160 | #[must_use = "streams do nothing unless polled"] 161 | pub struct SearchStream(mpsc::UnboundedReceiver); 162 | 163 | impl Stream for SearchStream { 164 | type Item = SocketAddr; 165 | 166 | fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { 167 | Pin::new(&mut self.0).poll_recv(cx) 168 | } 169 | } 170 | 171 | // ----------------------------------------------------------------------------// 172 | 173 | /// Stores information for initializing a DHT. 
174 | #[derive(Debug)] 175 | pub struct DhtBuilder { 176 | nodes: HashSet, 177 | routers: HashSet, 178 | read_only: bool, 179 | announce_port: Option, 180 | node_id: Option, 181 | } 182 | 183 | impl DhtBuilder { 184 | /// Add nodes which will be distributed within our routing table. 185 | pub fn add_node(mut self, node_addr: SocketAddr) -> DhtBuilder { 186 | self.nodes.insert(node_addr); 187 | self 188 | } 189 | 190 | /// Add a router which will let us gather nodes if our routing table is ever empty. 191 | /// 192 | /// The difference between routers and nodes is that routers are not added to the routing table. 193 | pub fn add_router(mut self, router: String) -> DhtBuilder { 194 | self.routers.insert(router); 195 | self 196 | } 197 | 198 | /// Add routers. Same as calling `add_router` multiple times but more convenient in some cases. 199 | pub fn add_routers(mut self, routers: I) -> DhtBuilder 200 | where 201 | I: IntoIterator, 202 | T: Into, 203 | { 204 | self.routers.extend(routers.into_iter().map(|r| r.into())); 205 | self 206 | } 207 | 208 | /// Set the read only flag when communicating with other nodes. Indicates 209 | /// that remote nodes should not add us to their routing table. 210 | /// 211 | /// Used when we are behind a restrictive NAT and/or we want to decrease 212 | /// incoming network traffic. Defaults value is true. 213 | pub fn set_read_only(mut self, read_only: bool) -> DhtBuilder { 214 | self.read_only = read_only; 215 | 216 | self 217 | } 218 | 219 | /// Provide a port to include in the `announce_peer` requests we send. 220 | /// 221 | /// If this is not supplied, will use implied port. 222 | pub fn set_announce_port(mut self, port: u16) -> Self { 223 | self.announce_port = Some(port); 224 | self 225 | } 226 | 227 | /// Set the id of this node. If not provided, a random node id is generated. 228 | /// 229 | /// NOTE: when creating a double-stack DHT (ipv4 + ipv6), it's recommended that both DHTs use 230 | /// the same node id. 
231 | pub fn set_node_id(mut self, id: NodeId) -> Self { 232 | self.node_id = Some(id); 233 | self 234 | } 235 | 236 | /// Start a mainline DHT with the current configuration and bind it to the provided socket. 237 | /// Fails only if `socket.local_addr()` fails. 238 | pub fn start( 239 | self, 240 | socket: S, 241 | ) -> io::Result { 242 | let socket = Socket::new(socket)?; 243 | Ok(MainlineDht::with_builder(self, socket)) 244 | } 245 | } 246 | -------------------------------------------------------------------------------- /src/node.rs: -------------------------------------------------------------------------------- 1 | use std::fmt::{self, Debug, Formatter}; 2 | use std::hash::{Hash, Hasher}; 3 | use std::net::SocketAddr; 4 | use std::time::Duration; 5 | 6 | use crate::info_hash::NodeId; 7 | use crate::time::Instant; 8 | 9 | // TODO: Should remove as_* functions and replace them with from_requested, from_responded, etc to hide the logic 10 | // of the nodes initial status. 11 | 12 | // TODO: Should address the subsecond lookup paper where questionable nodes should not automatically be replaced with 13 | // good nodes, instead, questionable nodes should be pinged twice and then become available to be replaced. This reduces 14 | // GOOD node churn since after 15 minutes, a long lasting node could potentially be replaced by a short lived good node. 15 | // This strategy is actually what is vaguely specified in the standard? 16 | 17 | /// Maximum wait period before a node becomes questionable. 18 | const MAX_LAST_SEEN_MINS: u64 = 15; 19 | 20 | /// Maximum number of requests before a Questionable node becomes Bad. 21 | const MAX_REFRESH_REQUESTS: usize = 2; 22 | 23 | /// Status of the node. 24 | /// Ordering of the enumerations is important, variants higher 25 | /// up are considered to be less than those further down. 
26 | #[derive(Copy, Clone, PartialEq, Eq, Hash, Debug, Ord, PartialOrd)] 27 | pub enum NodeStatus { 28 | Bad, 29 | Questionable, 30 | Good, 31 | } 32 | 33 | /// Node participating in the dht. 34 | #[derive(Clone)] 35 | pub struct Node { 36 | handle: NodeHandle, 37 | last_request: Option, 38 | last_response: Option, 39 | last_local_request: Option, 40 | refresh_requests: usize, 41 | } 42 | 43 | impl Node { 44 | /// Create a new node that has recently responded to us but never requested from us. 45 | pub fn as_good(id: NodeId, addr: SocketAddr) -> Node { 46 | Node { 47 | handle: NodeHandle { id, addr }, 48 | last_response: Some(Instant::now()), 49 | last_request: None, 50 | last_local_request: None, 51 | refresh_requests: 0, 52 | } 53 | } 54 | 55 | /// Create a questionable node that has responded to us before but never requested from us. 56 | pub fn as_questionable(id: NodeId, addr: SocketAddr) -> Node { 57 | let last_response_offset = Duration::from_secs(MAX_LAST_SEEN_MINS * 60); 58 | let last_response = Instant::now().checked_sub(last_response_offset).unwrap(); 59 | 60 | Node { 61 | handle: NodeHandle { id, addr }, 62 | last_response: Some(last_response), 63 | last_request: None, 64 | last_local_request: None, 65 | refresh_requests: 0, 66 | } 67 | } 68 | 69 | /// Create a new node that has never responded to us or requested from us. 
70 | pub fn as_bad(id: NodeId, addr: SocketAddr) -> Node { 71 | Node { 72 | handle: NodeHandle { id, addr }, 73 | last_response: None, 74 | last_request: None, 75 | last_local_request: None, 76 | refresh_requests: 0, 77 | } 78 | } 79 | 80 | pub fn update(&mut self, other: Node) { 81 | assert_eq!(self.handle, other.handle); 82 | 83 | let self_status = self.status(); 84 | let other_status = other.status(); 85 | 86 | match (self_status, other_status) { 87 | (NodeStatus::Good, NodeStatus::Good) => { 88 | *self = Self { 89 | handle: self.handle, 90 | last_response: other.last_response, 91 | last_request: self.last_request, 92 | last_local_request: self.last_local_request, 93 | refresh_requests: 0, 94 | }; 95 | } 96 | (NodeStatus::Good, NodeStatus::Questionable) => {} 97 | (NodeStatus::Good, NodeStatus::Bad) => {} 98 | (NodeStatus::Questionable, NodeStatus::Good) => { 99 | *self = other; 100 | } 101 | (NodeStatus::Questionable, NodeStatus::Questionable) => {} 102 | (NodeStatus::Questionable, NodeStatus::Bad) => {} 103 | (NodeStatus::Bad, NodeStatus::Good) => { 104 | *self = other; 105 | } 106 | (NodeStatus::Bad, NodeStatus::Questionable) => { 107 | *self = other; 108 | } 109 | (NodeStatus::Bad, NodeStatus::Bad) => {} 110 | } 111 | } 112 | 113 | /// Record that we sent the node a request. 114 | pub fn local_request(&mut self) { 115 | self.last_local_request = Some(Instant::now()); 116 | 117 | if self.status() != NodeStatus::Good { 118 | self.refresh_requests = self.refresh_requests.saturating_add(1); 119 | } 120 | } 121 | 122 | /// Record that the node sent us a request. 123 | pub fn remote_request(&mut self) { 124 | self.last_request = Some(Instant::now()); 125 | } 126 | 127 | /// Return true if we have sent this node a request recently. 128 | pub fn recently_requested_from(&self) -> bool { 129 | if let Some(time) = self.last_local_request { 130 | // TODO: I made the 30 seconds up, seems reasonable. 
131 | Instant::now() < time + Duration::from_secs(30) 132 | } else { 133 | false 134 | } 135 | } 136 | 137 | pub fn id(&self) -> NodeId { 138 | self.handle.id 139 | } 140 | 141 | pub fn addr(&self) -> SocketAddr { 142 | self.handle.addr 143 | } 144 | 145 | /// Current status of the node. 146 | /// 147 | /// The specification says: 148 | /// 149 | /// https://www.bittorrent.org/beps/bep_0005.html 150 | /// 151 | /// A good node is a node has responded to one of our queries within the last 15 minutes. A node is also good 152 | /// if it has ever responded to one of our queries and has sent us a query within the last 15 minutes. 153 | /// After 15 minutes of inactivity, a node becomes questionable. Nodes become bad when they fail to respond to 154 | /// multiple queries in a row. 155 | pub fn status(&self) -> NodeStatus { 156 | let curr_time = Instant::now(); 157 | 158 | // Check if node has ever responded to us 159 | let since_response = match self.last_response { 160 | Some(response_time) => curr_time - response_time, 161 | None => return NodeStatus::Bad, 162 | }; 163 | 164 | // Check if node has recently responded to us 165 | if since_response < Duration::from_secs(MAX_LAST_SEEN_MINS * 60) { 166 | return NodeStatus::Good; 167 | } 168 | 169 | // Check if we have request from node multiple times already without response 170 | if self.refresh_requests >= MAX_REFRESH_REQUESTS { 171 | return NodeStatus::Bad; 172 | } 173 | 174 | // Check if the node has recently requested from us 175 | if let Some(request_time) = self.last_request { 176 | let since_request = curr_time - request_time; 177 | 178 | if since_request < Duration::from_secs(MAX_LAST_SEEN_MINS * 60) { 179 | return NodeStatus::Good; 180 | } 181 | } 182 | 183 | NodeStatus::Questionable 184 | } 185 | 186 | /// Is node good or questionable? 
187 | pub fn is_pingable(&self) -> bool { 188 | // Function is moderately expensive 189 | let status = self.status(); 190 | status == NodeStatus::Good || status == NodeStatus::Questionable 191 | } 192 | 193 | pub(crate) fn handle(&self) -> &NodeHandle { 194 | &self.handle 195 | } 196 | } 197 | 198 | impl Eq for Node {} 199 | 200 | impl PartialEq for Node { 201 | fn eq(&self, other: &Node) -> bool { 202 | self.handle == other.handle 203 | } 204 | } 205 | 206 | impl Hash for Node { 207 | fn hash(&self, state: &mut H) 208 | where 209 | H: Hasher, 210 | { 211 | self.handle.hash(state); 212 | } 213 | } 214 | 215 | impl Debug for Node { 216 | fn fmt(&self, f: &mut Formatter) -> Result<(), fmt::Error> { 217 | f.debug_struct("Node") 218 | .field("id", &self.handle.id) 219 | .field("addr", &self.handle.addr) 220 | .field("last_request", &self.last_request) 221 | .field("last_response", &self.last_response) 222 | .field("refresh_requests", &self.refresh_requests) 223 | .finish() 224 | } 225 | } 226 | 227 | /// Node id + its socket address. 
228 | #[derive(Copy, Clone, PartialEq, Eq, Hash)] 229 | pub struct NodeHandle { 230 | pub id: NodeId, 231 | pub addr: SocketAddr, 232 | } 233 | 234 | impl NodeHandle { 235 | pub fn new(id: NodeId, addr: SocketAddr) -> Self { 236 | Self { id, addr } 237 | } 238 | } 239 | 240 | impl Debug for NodeHandle { 241 | fn fmt(&self, f: &mut Formatter) -> fmt::Result { 242 | write!(f, "{:?}@{:?}", self.id, self.addr) 243 | } 244 | } 245 | 246 | #[cfg(test)] 247 | mod tests { 248 | use crate::time::Instant; 249 | use std::time::Duration; 250 | 251 | use crate::node::{Node, NodeStatus}; 252 | use crate::test; 253 | 254 | #[test] 255 | fn positive_as_bad() { 256 | let node = Node::as_bad(test::dummy_node_id(), test::dummy_socket_addr_v4()); 257 | 258 | assert_eq!(node.status(), NodeStatus::Bad); 259 | } 260 | 261 | #[test] 262 | fn positive_as_questionable() { 263 | let node = Node::as_questionable(test::dummy_node_id(), test::dummy_socket_addr_v4()); 264 | 265 | assert_eq!(node.status(), NodeStatus::Questionable); 266 | } 267 | 268 | #[test] 269 | fn positive_as_good() { 270 | let node = Node::as_good(test::dummy_node_id(), test::dummy_socket_addr_v4()); 271 | 272 | assert_eq!(node.status(), NodeStatus::Good); 273 | } 274 | 275 | #[test] 276 | fn positive_request_renewal() { 277 | let mut node = Node::as_questionable(test::dummy_node_id(), test::dummy_socket_addr_v4()); 278 | 279 | node.remote_request(); 280 | 281 | assert_eq!(node.status(), NodeStatus::Good); 282 | } 283 | 284 | #[test] 285 | fn positive_node_idle() { 286 | let mut node = Node::as_good(test::dummy_node_id(), test::dummy_socket_addr_v4()); 287 | 288 | let time_offset = Duration::from_secs(super::MAX_LAST_SEEN_MINS * 60); 289 | let idle_time = Instant::now().checked_sub(time_offset).unwrap(); 290 | 291 | node.last_response = Some(idle_time); 292 | 293 | assert_eq!(node.status(), NodeStatus::Questionable); 294 | } 295 | 296 | #[test] 297 | fn positive_node_idle_reqeusts() { 298 | let mut node = 
Node::as_questionable(test::dummy_node_id(), test::dummy_socket_addr_v4()); 299 | 300 | for _ in 0..super::MAX_REFRESH_REQUESTS { 301 | node.local_request(); 302 | } 303 | 304 | assert_eq!(node.status(), NodeStatus::Bad); 305 | } 306 | 307 | #[test] 308 | fn positive_good_status_ordering() { 309 | assert!(NodeStatus::Good > NodeStatus::Questionable); 310 | assert!(NodeStatus::Good > NodeStatus::Bad); 311 | } 312 | 313 | #[test] 314 | fn positive_questionable_status_ordering() { 315 | assert!(NodeStatus::Questionable > NodeStatus::Bad); 316 | assert!(NodeStatus::Questionable < NodeStatus::Good); 317 | } 318 | 319 | #[test] 320 | fn positive_bad_status_ordering() { 321 | assert!(NodeStatus::Bad < NodeStatus::Good); 322 | assert!(NodeStatus::Bad < NodeStatus::Questionable); 323 | } 324 | } 325 | -------------------------------------------------------------------------------- /src/router.rs: -------------------------------------------------------------------------------- 1 | //! Some known public DHT routers. 2 | 3 | // FIXME: this doesn't seem to work (bootstrap timeout) 4 | pub const UTORRENT_DHT: &str = "router.utorrent.com:6881"; 5 | pub const BITTORRENT_DHT: &str = "router.bittorrent.com:6881"; 6 | // FIXME: this doesn't seem to work (fails the DNS request) 7 | pub const BITCOMET_DHT: &str = "router.bitcomet.com:6881"; 8 | pub const TRANSMISSION_DHT: &str = "dht.transmissionbt.com:6881"; 9 | -------------------------------------------------------------------------------- /src/socket.rs: -------------------------------------------------------------------------------- 1 | //! Helpers to simplify work with UdpSocket. 
2 | 3 | use super::IpVersion; 4 | use crate::{ 5 | message::{Message, TransactionId}, 6 | SocketTrait, 7 | }; 8 | use async_trait::async_trait; 9 | use std::{ 10 | collections::HashMap, 11 | future::Future, 12 | io, 13 | net::SocketAddr, 14 | pin::Pin, 15 | sync::{Arc, Mutex}, 16 | task::{Context, Poll, Waker}, 17 | time::Duration, 18 | }; 19 | use tokio::{net::UdpSocket, time::Sleep}; 20 | 21 | type Transactions = HashMap<(SocketAddr, TransactionId), Arc>>; 22 | 23 | pub struct Socket { 24 | inner_socket: Box, 25 | local_addr: SocketAddr, 26 | transactions: Arc>, 27 | } 28 | 29 | impl Socket { 30 | pub fn new(inner: S) -> io::Result { 31 | let local_addr = inner.local_addr()?; 32 | let inner_socket = Box::new(inner); 33 | Ok(Self { 34 | inner_socket, 35 | local_addr, 36 | transactions: Arc::new(Mutex::new(Default::default())), 37 | }) 38 | } 39 | 40 | pub(crate) async fn send(&self, message: &Message, addr: SocketAddr) -> io::Result<()> { 41 | log::trace!("Sending to {addr:?} {message:?}"); 42 | // Note: if the socket fails to send the entire buffer, then there is no point in trying to 43 | // send the rest (no node will attempt to reassemble two or more datagrams into a 44 | // meaningful message). 45 | self.inner_socket.send_to(&message.encode(), &addr).await?; 46 | Ok(()) 47 | } 48 | 49 | /// Send the message and return a future on which we can await the response. 50 | pub(crate) async fn send_request( 51 | &self, 52 | message: &Message, 53 | addr: SocketAddr, 54 | timeout: Duration, 55 | ) -> io::Result { 56 | let responded = self.responded(addr, message.transaction_id.clone(), timeout); 57 | self.send(message, addr).await?; 58 | Ok(responded) 59 | } 60 | 61 | /// This function is cancel safe: https://docs.rs/tokio/1.12.0/tokio/net/struct.UdpSocket.html#cancel-safety-6 62 | /// 63 | /// NOTE: This function is in a limbo state right now. 
Originally, it just received a message 64 | /// and returned the (Message, SocketAddr) pair, then the caller would decide what handler 65 | /// should handle it. Using it that way created a state machine that worked, but was hard to 66 | /// modify when we wanted more features (e.g. rebootstra, IP reuse,...). Later the 67 | /// `send_request` function was added which should allow us to receive responses directly in 68 | /// the code where requests are being sent. To avoid a complete and sudden rewrite, both 69 | /// approaches are now supported but would be good if we gradually switch to the latter. 70 | pub(crate) async fn recv(&self) -> io::Result<(Message, SocketAddr)> { 71 | let mut buffer = vec![0u8; 1500]; 72 | loop { 73 | let r = self.inner_socket.recv_from(&mut buffer).await; 74 | let (size, addr) = r?; 75 | match Message::decode(&buffer[0..size]) { 76 | Ok(message) => { 77 | if let Some(responded) = self 78 | .transactions 79 | .lock() 80 | .unwrap() 81 | .remove(&(addr, message.transaction_id.clone())) 82 | { 83 | responded.lock().unwrap().make_ready(message); 84 | } else { 85 | return Ok((message, addr)); 86 | } 87 | } 88 | Err(_) => { 89 | log::warn!( 90 | "{}: Failed decode incoming message from {addr:?}", 91 | self.ip_version() 92 | ); 93 | } 94 | } 95 | } 96 | } 97 | 98 | fn responded( 99 | &self, 100 | from: SocketAddr, 101 | transaction_id: TransactionId, 102 | timeout: Duration, 103 | ) -> Responded { 104 | let inner = Arc::new(Mutex::new(RespondedInner::new(timeout))); 105 | assert!(self 106 | .transactions 107 | .lock() 108 | .unwrap() 109 | .insert((from, transaction_id.clone()), inner.clone()) 110 | .is_none()); 111 | Responded { 112 | from, 113 | transaction_id, 114 | inner, 115 | transactions: self.transactions.clone(), 116 | } 117 | } 118 | 119 | pub fn local_addr(&self) -> SocketAddr { 120 | self.local_addr 121 | } 122 | 123 | pub fn ip_version(&self) -> IpVersion { 124 | match self.local_addr { 125 | SocketAddr::V4(_) => IpVersion::V4, 
126 | SocketAddr::V6(_) => IpVersion::V6, 127 | } 128 | } 129 | } 130 | 131 | #[async_trait] 132 | impl SocketTrait for UdpSocket { 133 | async fn send_to(&self, buf: &[u8], target: &SocketAddr) -> io::Result<()> { 134 | UdpSocket::send_to(self, buf, target).await.map(|_| ()) 135 | } 136 | 137 | async fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> { 138 | UdpSocket::recv_from(self, buf).await 139 | } 140 | 141 | fn local_addr(&self) -> io::Result { 142 | UdpSocket::local_addr(self) 143 | } 144 | } 145 | 146 | /// Future awaiting a response 147 | pub(crate) struct Responded { 148 | from: SocketAddr, 149 | transaction_id: TransactionId, 150 | inner: Arc>, 151 | transactions: Arc>, 152 | } 153 | 154 | impl Drop for Responded { 155 | fn drop(&mut self) { 156 | self.transactions 157 | .lock() 158 | .unwrap() 159 | .remove(&(self.from, self.transaction_id.clone())); 160 | } 161 | } 162 | 163 | struct RespondedInner { 164 | sleep: Pin>, 165 | message: Option, 166 | waker: Option, 167 | } 168 | 169 | impl RespondedInner { 170 | fn new(timeout: Duration) -> Self { 171 | Self { 172 | sleep: Box::pin(tokio::time::sleep(timeout)), 173 | message: None, 174 | waker: None, 175 | } 176 | } 177 | 178 | fn make_ready(&mut self, message: Message) { 179 | // Should not happen because we remove `self` from `transactions` when first response 180 | // arrives. 
181 | assert!(self.message.is_none()); 182 | 183 | self.message = Some(message); 184 | 185 | if let Some(waker) = self.waker.take() { 186 | waker.wake(); 187 | } 188 | } 189 | } 190 | 191 | impl Future for Responded { 192 | type Output = Option<(Message, SocketAddr)>; 193 | 194 | fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll { 195 | let mut this = self.inner.lock().unwrap(); 196 | 197 | if let Some(message) = this.message.take() { 198 | return Poll::Ready(Some((message, self.from))); 199 | } 200 | 201 | match this.sleep.as_mut().poll(cx) { 202 | Poll::Ready(()) => Poll::Ready(None), 203 | Poll::Pending => { 204 | if this.waker.is_none() { 205 | this.waker = Some(cx.waker().clone()); 206 | } 207 | 208 | Poll::Pending 209 | } 210 | } 211 | } 212 | } 213 | -------------------------------------------------------------------------------- /src/storage.rs: -------------------------------------------------------------------------------- 1 | use std::collections::hash_map::Entry; 2 | use std::collections::HashMap; 3 | use std::net::SocketAddr; 4 | use std::time::Duration; 5 | 6 | use crate::info_hash::InfoHash; 7 | use crate::time::Instant; 8 | 9 | const MAX_ITEMS_STORED: usize = 500; 10 | 11 | /// Manages storage and expiration of contact information for a number of InfoHashs. 12 | pub struct AnnounceStorage { 13 | storage: HashMap>, 14 | expires: Vec, 15 | } 16 | 17 | impl AnnounceStorage { 18 | /// Create a new AnnounceStorage object. 19 | pub fn new() -> AnnounceStorage { 20 | AnnounceStorage { 21 | storage: HashMap::new(), 22 | expires: Vec::new(), 23 | } 24 | } 25 | 26 | /// Returns true if the item was added/it's existing expiration updated, false otherwise. 
27 | pub fn add_item(&mut self, info_hash: InfoHash, address: SocketAddr) -> bool { 28 | self.add(info_hash, address, Instant::now()) 29 | } 30 | 31 | fn add(&mut self, info_hash: InfoHash, address: SocketAddr, curr_time: Instant) -> bool { 32 | // Clear out any old contacts that we have stored 33 | self.remove_expired_items(curr_time); 34 | let item = AnnounceItem::new(info_hash, address); 35 | let item_expiration = item.expiration(); 36 | 37 | // Check if we already have the item and want to update it's expiration 38 | match self.insert_contact(item) { 39 | Some(true) => { 40 | self.expires.retain(|i| i != &item_expiration); 41 | self.expires.push(item_expiration); 42 | 43 | true 44 | } 45 | Some(false) => { 46 | self.expires.push(item_expiration); 47 | 48 | true 49 | } 50 | None => false, 51 | } 52 | } 53 | 54 | /// Returns an iterator over all contacts for the given info hash. 55 | pub fn find_items<'a>( 56 | &'a mut self, 57 | info_hash: &'_ InfoHash, 58 | ) -> impl Iterator + 'a { 59 | self.find(info_hash, Instant::now()) 60 | } 61 | 62 | fn find<'a>( 63 | &'a mut self, 64 | info_hash: &'_ InfoHash, 65 | curr_time: Instant, 66 | ) -> impl Iterator + 'a { 67 | // Clear out any old contacts that we have stored 68 | self.remove_expired_items(curr_time); 69 | 70 | self.storage 71 | .get(info_hash) 72 | .into_iter() 73 | .flatten() 74 | .map(|item| item.address()) 75 | } 76 | 77 | /// Returns None if the contact could not be inserted, else, returns Some(true) if the contact was already 78 | /// in the table (and was replaced by the new entry) or Some(false) if the contact was not already in the 79 | /// table but was inserted. 
80 | fn insert_contact(&mut self, item: AnnounceItem) -> Option { 81 | let item_info_hash = item.info_hash(); 82 | 83 | // Check if the contact is already in our list 84 | let already_in_list = if let Some(items) = self.storage.get_mut(&item_info_hash) { 85 | items.iter().any(|a| a == &item) 86 | } else { 87 | false 88 | }; 89 | 90 | // Check if we need to insert it into the list and if we have room 91 | match (already_in_list, self.expires.len() < MAX_ITEMS_STORED) { 92 | (false, true) => { 93 | // Place it into the appropriate list 94 | match self.storage.entry(item_info_hash) { 95 | Entry::Occupied(mut occ) => occ.get_mut().push(item), 96 | Entry::Vacant(vac) => { 97 | vac.insert(vec![item]); 98 | } 99 | }; 100 | 101 | Some(false) 102 | } 103 | (false, false) => None, 104 | (true, false) => Some(true), 105 | (true, true) => Some(true), 106 | } 107 | } 108 | 109 | /// Prunes all expired items from the internal list. 110 | fn remove_expired_items(&mut self, curr_time: Instant) { 111 | let num_expired_items = self 112 | .expires 113 | .iter() 114 | .take_while(|i| i.is_expired(curr_time)) 115 | .count(); 116 | 117 | // Remove the numbers of expired elements from the head of the list 118 | for item_expiration in self.expires.drain(0..num_expired_items) { 119 | let info_hash = item_expiration.info_hash(); 120 | 121 | // Get a mutable reference to the list of contacts and remove all contacts that 122 | // are associated with the expiration (should only be one such contact). 
123 | let remove_info_hash = if let Some(items) = self.storage.get_mut(&info_hash) { 124 | items.retain(|a| a.expiration() != item_expiration); 125 | 126 | items.is_empty() 127 | } else { 128 | false 129 | }; 130 | 131 | // If we drained the list of contacts completely, remove the info hash entry 132 | if remove_info_hash { 133 | self.storage.remove(&info_hash); 134 | } 135 | } 136 | } 137 | } 138 | 139 | // ----------------------------------------------------------------------------// 140 | 141 | #[derive(Debug, Clone, PartialEq, Eq)] 142 | struct AnnounceItem { 143 | expiration: ItemExpiration, 144 | } 145 | 146 | impl AnnounceItem { 147 | pub fn new(info_hash: InfoHash, address: SocketAddr) -> AnnounceItem { 148 | AnnounceItem { 149 | expiration: ItemExpiration::new(info_hash, address), 150 | } 151 | } 152 | 153 | pub fn expiration(&self) -> ItemExpiration { 154 | self.expiration.clone() 155 | } 156 | 157 | pub fn address(&self) -> SocketAddr { 158 | self.expiration.address() 159 | } 160 | 161 | pub fn info_hash(&self) -> InfoHash { 162 | self.expiration.info_hash() 163 | } 164 | } 165 | 166 | // ----------------------------------------------------------------------------// 167 | 168 | const EXPIRATION_TIME: Duration = Duration::from_secs(24 * 60 * 60); 169 | 170 | #[derive(Debug, Clone)] 171 | struct ItemExpiration { 172 | address: SocketAddr, 173 | inserted: Instant, 174 | info_hash: InfoHash, 175 | } 176 | 177 | impl ItemExpiration { 178 | pub fn new(info_hash: InfoHash, address: SocketAddr) -> ItemExpiration { 179 | ItemExpiration { 180 | address, 181 | inserted: Instant::now(), 182 | info_hash, 183 | } 184 | } 185 | 186 | pub fn is_expired(&self, now: Instant) -> bool { 187 | now - self.inserted >= EXPIRATION_TIME 188 | } 189 | 190 | pub fn info_hash(&self) -> InfoHash { 191 | self.info_hash 192 | } 193 | 194 | pub fn address(&self) -> SocketAddr { 195 | self.address 196 | } 197 | } 198 | 199 | impl PartialEq for ItemExpiration { 200 | fn eq(&self, other: 
&ItemExpiration) -> bool {
        self.address() == other.address() && self.info_hash() == other.info_hash()
    }
}

impl Eq for ItemExpiration {}

#[cfg(test)]
mod tests {
    use crate::time::Instant;

    use crate::info_hash::INFO_HASH_LEN;
    use crate::storage::{self, AnnounceStorage};
    use crate::test;

    #[test]
    fn positive_add_and_retrieve_contact() {
        let mut announce_store = AnnounceStorage::new();
        let info_hash = [0u8; INFO_HASH_LEN].into();
        let sock_addr = test::dummy_socket_addr_v4();

        assert!(announce_store.add_item(info_hash, sock_addr));

        let items: Vec<_> = announce_store.find_items(&info_hash).collect();
        assert_eq!(items.len(), 1);

        assert_eq!(items[0], sock_addr);
    }

    #[test]
    fn positive_add_and_retrieve_contacts() {
        let mut announce_store = AnnounceStorage::new();
        let info_hash = [0u8; INFO_HASH_LEN].into();
        let sock_addrs = test::dummy_block_socket_addrs(storage::MAX_ITEMS_STORED as u16);

        for sock_addr in sock_addrs.iter() {
            assert!(announce_store.add_item(info_hash, *sock_addr));
        }

        let items: Vec<_> = announce_store.find_items(&info_hash).collect();
        assert_eq!(items.len(), storage::MAX_ITEMS_STORED);

        for item in items.iter() {
            assert!(sock_addrs.iter().any(|s| s == item));
        }
    }

    #[test]
    fn positive_renew_contacts() {
        let mut announce_store = AnnounceStorage::new();
        let info_hash = [0u8; INFO_HASH_LEN].into();
        let sock_addrs = test::dummy_block_socket_addrs((storage::MAX_ITEMS_STORED + 1) as u16);

        for sock_addr in sock_addrs.iter().take(storage::MAX_ITEMS_STORED) {
            assert!(announce_store.add_item(info_hash, *sock_addr));
        }

        // Try to add a new item
        let other_info_hash = [1u8; INFO_HASH_LEN].into();

        // Returns false because it wasn't added
        assert!(!announce_store.add_item(other_info_hash, sock_addrs[sock_addrs.len() - 1]));
        // Iterator is empty because it wasn't added
        let count = announce_store.find_items(&other_info_hash).count();
        assert_eq!(count, 0);

        // Try to add all of the initial nodes again (renew)
        for sock_addr in sock_addrs.iter().take(storage::MAX_ITEMS_STORED) {
            assert!(announce_store.add_item(info_hash, *sock_addr));
        }
    }

    #[test]
    fn positive_full_storage_expire_one_infohash() {
        let mut announce_store = AnnounceStorage::new();
        let info_hash = [0u8; INFO_HASH_LEN].into();
        let sock_addrs = test::dummy_block_socket_addrs((storage::MAX_ITEMS_STORED + 1) as u16);

        // Fill up the announce storage completely
        for sock_addr in sock_addrs.iter().take(storage::MAX_ITEMS_STORED) {
            assert!(announce_store.add_item(info_hash, *sock_addr));
        }

        // Try to add a new item into the storage (under a different info hash)
        let other_info_hash = [1u8; INFO_HASH_LEN].into();

        // Returned false because it wasn't added
        assert!(!announce_store.add_item(other_info_hash, sock_addrs[sock_addrs.len() - 1]));
        // Iterator is empty because it wasn't added
        let count = announce_store.find_items(&other_info_hash).count();
        assert_eq!(count, 0);

        // Try to add a new item into the storage mocking the current time
        let mock_current_time = Instant::now() + storage::EXPIRATION_TIME;
        assert!(announce_store.add(
            other_info_hash,
            sock_addrs[sock_addrs.len() - 1],
            mock_current_time
        ));
        // Iterator is not empty because it was added
        let count = announce_store.find_items(&other_info_hash).count();
        assert_eq!(count, 1);
    }

    #[test]
    fn positive_full_storage_expire_two_infohash() {
        let mut announce_store = AnnounceStorage::new();
        let info_hash_one = [0u8; INFO_HASH_LEN].into();
        let info_hash_two = [1u8; INFO_HASH_LEN].into();
        let sock_addrs = test::dummy_block_socket_addrs((storage::MAX_ITEMS_STORED + 1) as u16);

        // Fill up first info hash
        let num_contacts_first = storage::MAX_ITEMS_STORED / 2;
        for sock_addr in sock_addrs.iter().take(num_contacts_first) {
            assert!(announce_store.add_item(info_hash_one, *sock_addr));
        }

        // Fill up second info hash
        let num_contacts_second = storage::MAX_ITEMS_STORED - num_contacts_first;
        for sock_addr in sock_addrs
            .iter()
            .skip(num_contacts_first)
            .take(num_contacts_second)
        {
            assert!(announce_store.add_item(info_hash_two, *sock_addr));
        }

        // Try to add a third info hash with a contact
        let info_hash_three = [2u8; INFO_HASH_LEN].into();
        assert!(!announce_store.add_item(info_hash_three, sock_addrs[sock_addrs.len() - 1]));
        // Iterator is empty because it was not added
        let count = announce_store.find_items(&info_hash_three).count();
        assert_eq!(count, 0);

        // Try to add a new item into the storage mocking the current time
        let mock_current_time = Instant::now() + storage::EXPIRATION_TIME;
        assert!(announce_store.add(
            info_hash_three,
            sock_addrs[sock_addrs.len() - 1],
            mock_current_time
        ));
        // Iterator is not empty because it was added
        let count = announce_store.find_items(&info_hash_three).count();
        assert_eq!(count, 1);
    }
}
-------------------------------------------------------------------------------- /src/table.rs: --------------------------------------------------------------------------------
use super::{
    bucket::{self, Bucket},
    node::{Node, NodeHandle, NodeStatus},
};
use crate::info_hash::{NodeId, INFO_HASH_LEN};
use std::{cmp::Ordering, collections::HashSet, iter::Filter, net::SocketAddr, slice::Iter};

/// Maximum number of buckets the routing table can hold: one per bit of the node id.
pub const MAX_BUCKETS:
usize = INFO_HASH_LEN * 8; 9 | 10 | /// Routing table containing a table of routing nodes as well 11 | /// as the id of the local node participating in the dht. 12 | pub struct RoutingTable { 13 | // Important: Our node id will always fall within the range 14 | // of the last bucket in the buckets array. 15 | buckets: Vec, 16 | node_id: NodeId, 17 | pub routers: HashSet, 18 | } 19 | 20 | impl RoutingTable { 21 | /// Create a new RoutingTable with the given node id as our id. 22 | pub fn new(node_id: NodeId) -> RoutingTable { 23 | let buckets = vec![Bucket::new()]; 24 | 25 | RoutingTable { 26 | buckets, 27 | node_id, 28 | routers: Default::default(), 29 | } 30 | } 31 | 32 | /// Return the node id of the RoutingTable. 33 | pub fn node_id(&self) -> NodeId { 34 | self.node_id 35 | } 36 | 37 | /// Iterator over the closest good nodes to the given node id. 38 | /// 39 | /// The closeness of nodes has a maximum granularity of a bucket. For most use 40 | /// cases this is fine since we will usually be performing lookups and aggregating 41 | /// a number of results equal to the size of a bucket. 42 | pub fn closest_nodes(&self, node_id: NodeId) -> ClosestNodes { 43 | ClosestNodes::new(&self.buckets, self.node_id, node_id) 44 | } 45 | 46 | /// Number of good nodes in the RoutingTable. 47 | pub fn num_good_nodes(&self) -> usize { 48 | self.closest_nodes(self.node_id()) 49 | .filter(|n| n.status() == NodeStatus::Good) 50 | .count() 51 | } 52 | 53 | /// Number of questionable nodes in the RoutingTable. 54 | pub fn num_questionable_nodes(&self) -> usize { 55 | self.closest_nodes(self.node_id()) 56 | .filter(|n| n.status() == NodeStatus::Questionable) 57 | .count() 58 | } 59 | 60 | /// Iterator over all buckets in the routing table. 61 | pub fn buckets(&self) -> impl ExactSizeIterator { 62 | self.buckets.iter() 63 | } 64 | 65 | /// Find an instance of the target node in the RoutingTable, if it exists. 
66 | #[allow(unused)] 67 | pub fn find_node(&self, node: &NodeHandle) -> Option<&Node> { 68 | let bucket_index = self.bucket_index_for_node(node.id); 69 | let bucket = self.buckets.get(bucket_index)?; 70 | bucket.pingable_nodes().find(|n| n.handle() == node) 71 | } 72 | 73 | /// Find a mutable reference to an instance of the target node in the RoutingTable, if it 74 | /// exists. 75 | pub fn find_node_mut<'a>(&'a mut self, node: &'_ NodeHandle) -> Option<&'a mut Node> { 76 | let bucket_index = self.bucket_index_for_node(node.id); 77 | let bucket = self.buckets.get_mut(bucket_index)?; 78 | bucket.pingable_nodes_mut().find(|n| n.handle() == node) 79 | } 80 | 81 | fn bucket_index_for_node(&self, node_id: NodeId) -> usize { 82 | let bucket_index = leading_bit_count(self.node_id, node_id); 83 | 84 | // Check the sorted bucket 85 | if bucket_index < self.buckets.len() { 86 | // Got the sorted bucket 87 | bucket_index 88 | } else { 89 | // Grab the assorted bucket 90 | self.buckets 91 | .len() 92 | .checked_sub(1) 93 | .expect("no buckets present in RoutingTable - implementation error") 94 | } 95 | } 96 | 97 | /// Add the node to the RoutingTable if there is space for it. 98 | pub fn add_node(&mut self, node: Node) { 99 | if self.routers.contains(&node.addr()) { 100 | return; 101 | } 102 | 103 | // Doing some checks and calculations here, outside of the recursion 104 | if node.status() == NodeStatus::Bad { 105 | return; 106 | } 107 | let num_same_bits = leading_bit_count(self.node_id, node.id()); 108 | 109 | // Should not add a node that has the same id as us 110 | if num_same_bits != MAX_BUCKETS { 111 | self.bucket_node(node, num_same_bits); 112 | } 113 | } 114 | 115 | /// Convenience function to add what's usually the response payload. 
116 | pub fn add_nodes(&mut self, node: Node, questionable_nodes: &[NodeHandle]) { 117 | self.add_node(node); 118 | 119 | // Add the payload nodes as questionable 120 | for questionable_node in questionable_nodes { 121 | self.add_node(Node::as_questionable( 122 | questionable_node.id, 123 | questionable_node.addr, 124 | )); 125 | } 126 | } 127 | 128 | /// Recursively tries to place the node into some bucket. 129 | fn bucket_node(&mut self, node: Node, num_same_bits: usize) { 130 | let bucket_index = bucket_placement(num_same_bits, self.buckets.len()); 131 | 132 | // Try to place in correct bucket 133 | if !self.buckets[bucket_index].add_node(node.clone()) { 134 | // Bucket was full, try to split it 135 | if self.split_bucket(bucket_index) { 136 | // Bucket split successfully, try to add again 137 | self.bucket_node(node, num_same_bits); 138 | } 139 | } 140 | } 141 | 142 | /// Tries to split the bucket at the specified index. 143 | /// 144 | /// Returns false if the split cannot be performed. 145 | fn split_bucket(&mut self, bucket_index: usize) -> bool { 146 | if !can_split_bucket(self.buckets.len(), bucket_index) { 147 | return false; 148 | } 149 | 150 | // Implementation is easier if we just remove the whole bucket, pretty 151 | // cheap to copy and we can manipulate the new buckets while they are 152 | // in the RoutingTable already. 
153 | let split_bucket = match self.buckets.pop() { 154 | Some(bucket) => bucket, 155 | None => panic!("no buckets present in RoutingTable - implementation error"), 156 | }; 157 | 158 | // Push two more buckets to distribute nodes between 159 | self.buckets.push(Bucket::new()); 160 | self.buckets.push(Bucket::new()); 161 | 162 | for node in split_bucket.iter() { 163 | self.add_node(node.clone()); 164 | } 165 | 166 | true 167 | } 168 | 169 | pub fn load_contacts(&self) -> (HashSet, HashSet) { 170 | let mut good = HashSet::new(); 171 | let mut questionable = HashSet::new(); 172 | 173 | for bucket in &self.buckets { 174 | for node in bucket.iter() { 175 | if node.status() == NodeStatus::Good { 176 | good.insert(node.handle().addr); 177 | } else if node.status() == NodeStatus::Questionable { 178 | questionable.insert(node.handle().addr); 179 | } 180 | } 181 | } 182 | 183 | (good, questionable) 184 | } 185 | } 186 | 187 | /// Returns true if the bucket can be split. 188 | fn can_split_bucket(num_buckets: usize, bucket_index: usize) -> bool { 189 | bucket_index == num_buckets - 1 && bucket_index != MAX_BUCKETS - 1 190 | } 191 | 192 | /// Number of leading bits that are identical between the local and remote node ids. 193 | pub fn leading_bit_count(local_node: NodeId, remote_node: NodeId) -> usize { 194 | (local_node ^ remote_node).leading_zeros() as usize 195 | } 196 | 197 | /// Take the number of leading bits that are the same between our node and the remote 198 | /// node and calculate a bucket index for that node id. 199 | fn bucket_placement(num_same_bits: usize, num_buckets: usize) -> usize { 200 | // The index that the node should be placed in *eventually*, meaning 201 | // when we create enough buckets for that bucket to appear. 
202 | let ideal_index = num_same_bits; 203 | 204 | if ideal_index >= num_buckets { 205 | num_buckets - 1 206 | } else { 207 | ideal_index 208 | } 209 | } 210 | 211 | // ----------------------------------------------------------------------------// 212 | 213 | // Iterator filter for only good nodes. 214 | type GoodNodes<'a> = Filter, fn(&&Node) -> bool>; 215 | 216 | // So what we are going to do here is iterate over every bucket in a hypothetically filled 217 | // routing table (buckets slice). If the bucket we are interested in has not been created 218 | // yet (not in the slice), go through the last bucket (assorted nodes) and check if any nodes 219 | // would have been placed in that bucket. If we find one, return it and mark it in our assorted 220 | // nodes array. 221 | pub struct ClosestNodes<'a> { 222 | buckets: &'a [Bucket], 223 | current_iter: Option>, 224 | current_index: usize, 225 | start_index: usize, 226 | // Since we could have assorted nodes that are interleaved between our sorted 227 | // nodes as far as closest nodes are concerned, we need some way to hand the 228 | // assorted nodes out and keep track of which ones we have handed out. 
229 | // (Bucket Index, Node Reference, Returned Before) 230 | assorted_nodes: Option<[(usize, &'a Node, bool); bucket::MAX_BUCKET_SIZE]>, 231 | } 232 | 233 | impl<'a> ClosestNodes<'a> { 234 | fn new(buckets: &'a [Bucket], self_node_id: NodeId, other_node_id: NodeId) -> ClosestNodes<'a> { 235 | let start_index = leading_bit_count(self_node_id, other_node_id); 236 | 237 | let current_iter = bucket_iterator(buckets, start_index); 238 | let assorted_nodes = precompute_assorted_nodes(buckets, self_node_id); 239 | 240 | ClosestNodes { 241 | buckets, 242 | current_iter, 243 | current_index: start_index, 244 | start_index, 245 | assorted_nodes, 246 | } 247 | } 248 | } 249 | 250 | impl<'a> Iterator for ClosestNodes<'a> { 251 | type Item = &'a Node; 252 | 253 | fn next(&mut self) -> Option<&'a Node> { 254 | let current_index = self.current_index; 255 | 256 | // Check if we have any nodes left in the current iterator 257 | if let Some(ref mut iter) = self.current_iter { 258 | if let Some(node) = iter.next() { 259 | return Some(node); 260 | } 261 | } 262 | 263 | // Check if we have any nodes to give in the assorted bucket 264 | if let Some(ref mut nodes) = self.assorted_nodes { 265 | let mut nodes_iter = nodes.iter_mut().filter(|tup| is_good_node(&tup.1)); 266 | 267 | if let Some(node) = nodes_iter.find(|tup| tup.0 == current_index && !tup.2) { 268 | node.2 = true; 269 | return Some(node.1); 270 | }; 271 | } 272 | 273 | // Check if we can move to a new bucket 274 | match next_bucket_index(MAX_BUCKETS, self.start_index, self.current_index) { 275 | Some(new_index) => { 276 | self.current_index = new_index; 277 | self.current_iter = bucket_iterator(self.buckets, self.current_index); 278 | 279 | // Recurse back into this function to check the previous code paths again 280 | self.next() 281 | } 282 | None => None, 283 | } 284 | } 285 | } 286 | 287 | /// Optionally returns the precomputed bucket positions for all assorted nodes. 
288 | fn precompute_assorted_nodes( 289 | buckets: &[Bucket], 290 | self_node_id: NodeId, 291 | ) -> Option<[(usize, &Node, bool); bucket::MAX_BUCKET_SIZE]> { 292 | if buckets.len() == MAX_BUCKETS { 293 | return None; 294 | } 295 | let assorted_bucket = &buckets[buckets.len() - 1]; 296 | let mut assorted_iter = assorted_bucket.iter().peekable(); 297 | 298 | // So the bucket is not empty and now we have a reference to initialize our stack allocated array. 299 | if let Some(&init_reference) = assorted_iter.peek() { 300 | // Set all tuples to true in case our bucket is not full. 301 | let mut assorted_nodes = [(0, init_reference, true); bucket::MAX_BUCKET_SIZE]; 302 | 303 | for (index, node) in assorted_iter.enumerate() { 304 | let bucket_index = leading_bit_count(self_node_id, node.id()); 305 | 306 | assorted_nodes[index] = (bucket_index, node, false); 307 | } 308 | 309 | Some(assorted_nodes) 310 | } else { 311 | None 312 | } 313 | } 314 | 315 | /// Optionally returns the filter iterator for the bucket at the specified index. 316 | fn bucket_iterator(buckets: &[Bucket], index: usize) -> Option { 317 | if buckets.len() == MAX_BUCKETS { 318 | buckets 319 | } else { 320 | &buckets[..(buckets.len() - 1)] 321 | } 322 | .get(index) 323 | .map(|bucket| good_node_filter(bucket.iter())) 324 | } 325 | 326 | /// Converts the given iterator into a filter iterator to return only good nodes. 327 | fn good_node_filter(iter: Iter) -> GoodNodes { 328 | iter.filter(is_good_node) 329 | } 330 | 331 | /// Shakes fist at iterator making me take a double reference (could avoid it by mapping, but oh well) 332 | fn is_good_node(node: &&Node) -> bool { 333 | let status = node.status(); 334 | 335 | status == NodeStatus::Good || status == NodeStatus::Questionable 336 | } 337 | 338 | /// Computes the next bucket index that should be visited given the number of buckets, the starting index 339 | /// and the current index. 340 | /// 341 | /// Returns None if all of the buckets have been visited. 
342 | fn next_bucket_index(num_buckets: usize, start_index: usize, curr_index: usize) -> Option { 343 | // Since we prefer going right first, that means if we are on the right side then we want to go 344 | // to the same offset on the left, however, if we are on the left we want to go 1 past the offset 345 | // to the right. All assuming we can actually do this without going out of bounds. 346 | match curr_index.cmp(&start_index) { 347 | Ordering::Equal => { 348 | let right_index = start_index.checked_add(1); 349 | let left_index = start_index.checked_sub(1); 350 | 351 | if index_is_in_bounds(num_buckets, right_index) { 352 | Some(right_index.unwrap()) 353 | } else if index_is_in_bounds(num_buckets, left_index) { 354 | Some(left_index.unwrap()) 355 | } else { 356 | None 357 | } 358 | } 359 | Ordering::Greater => { 360 | let offset = curr_index - start_index; 361 | 362 | let left_index = start_index.checked_sub(offset); 363 | let right_index = curr_index.checked_add(1); 364 | 365 | if index_is_in_bounds(num_buckets, left_index) { 366 | Some(left_index.unwrap()) 367 | } else if index_is_in_bounds(num_buckets, right_index) { 368 | Some(right_index.unwrap()) 369 | } else { 370 | None 371 | } 372 | } 373 | Ordering::Less => { 374 | let offset = (start_index - curr_index) + 1; 375 | 376 | let right_index = start_index.checked_add(offset); 377 | let left_index = curr_index.checked_sub(1); 378 | 379 | if index_is_in_bounds(num_buckets, right_index) { 380 | Some(right_index.unwrap()) 381 | } else if index_is_in_bounds(num_buckets, left_index) { 382 | Some(left_index.unwrap()) 383 | } else { 384 | None 385 | } 386 | } 387 | } 388 | } 389 | 390 | /// Returns true if the overflow checked index is in bounds of the given length. 
391 | fn index_is_in_bounds(length: usize, checked_index: Option) -> bool { 392 | match checked_index { 393 | Some(index) => index < length, 394 | None => false, 395 | } 396 | } 397 | 398 | // ----------------------------------------------------------------------------// 399 | 400 | #[cfg(test)] 401 | mod tests { 402 | use crate::bucket; 403 | use crate::info_hash::{NodeId, NODE_ID_LEN}; 404 | use crate::node::Node; 405 | use crate::table::{self, RoutingTable}; 406 | use crate::test; 407 | 408 | #[test] 409 | fn positive_add_node_max_recursion() { 410 | let table_id = [1u8; NODE_ID_LEN]; 411 | let mut table = RoutingTable::new(table_id.into()); 412 | 413 | let mut node_id = table_id; 414 | // Modify the id so it is placed in the last bucket 415 | node_id[NODE_ID_LEN - 1] = 0; 416 | 417 | // Trigger a bucket overflow and since the ids are placed in the last bucket, all of 418 | // the buckets will be recursively created and inserted into the list of all buckets. 419 | let block_addrs = test::dummy_block_socket_addrs((bucket::MAX_BUCKET_SIZE + 1) as u16); 420 | for block_addr in block_addrs { 421 | let node = Node::as_good(node_id.into(), block_addr); 422 | 423 | table.add_node(node); 424 | } 425 | } 426 | 427 | #[test] 428 | fn positive_initial_empty_buckets() { 429 | let table_id = [1u8; NODE_ID_LEN]; 430 | let table = RoutingTable::new(table_id.into()); 431 | 432 | assert_eq!(table.buckets().count(), 1); 433 | for bucket in table.buckets() { 434 | assert_eq!(bucket.pingable_nodes().count(), 0) 435 | } 436 | } 437 | 438 | #[test] 439 | fn positive_first_bucket_sorted() { 440 | let table_id = [1u8; NODE_ID_LEN]; 441 | let mut table = RoutingTable::new(table_id.into()); 442 | 443 | let mut node_id = table_id; 444 | // Flip first bit so we are placed in the first bucket 445 | node_id[0] |= 128; 446 | 447 | let block_addrs = test::dummy_block_socket_addrs((bucket::MAX_BUCKET_SIZE + 1) as u16); 448 | for block_addr in block_addrs { 449 | let node = 
Node::as_good(node_id.into(), block_addr); 450 | 451 | table.add_node(node); 452 | } 453 | 454 | // First bucket should be sorted 455 | assert_eq!(table.buckets().take(1).count(), 1); 456 | for bucket in table.buckets().take(1) { 457 | assert_eq!(bucket.pingable_nodes().count(), bucket::MAX_BUCKET_SIZE) 458 | } 459 | 460 | // Assorted bucket should show up 461 | assert_eq!(table.buckets().skip(1).count(), 1); 462 | for bucket in table.buckets().skip(1) { 463 | assert_eq!(bucket.pingable_nodes().count(), 0) 464 | } 465 | 466 | // There should be only two buckets 467 | assert_eq!(table.buckets().skip(2).count(), 0); 468 | } 469 | 470 | #[test] 471 | fn positive_last_bucket_sorted() { 472 | let table_id = [1u8; NODE_ID_LEN]; 473 | let mut table = RoutingTable::new(table_id.into()); 474 | 475 | let mut node_id = table_id; 476 | // Flip last bit so we are placed in the last bucket 477 | node_id[NODE_ID_LEN - 1] = 0; 478 | 479 | let block_addrs = test::dummy_block_socket_addrs((bucket::MAX_BUCKET_SIZE + 1) as u16); 480 | for block_addr in block_addrs { 481 | let node = Node::as_good(node_id.into(), block_addr); 482 | 483 | table.add_node(node); 484 | } 485 | 486 | // First buckets should be sorted (although they are all empty) 487 | assert_eq!( 488 | table.buckets().take(table::MAX_BUCKETS - 1).count(), 489 | table::MAX_BUCKETS - 1 490 | ); 491 | for bucket in table.buckets().take(table::MAX_BUCKETS - 1) { 492 | assert_eq!(bucket.pingable_nodes().count(), 0) 493 | } 494 | 495 | // Last bucket should be sorted 496 | assert_eq!( 497 | table.buckets().skip(table::MAX_BUCKETS - 1).take(1).count(), 498 | 1 499 | ); 500 | for bucket in table.buckets().skip(table::MAX_BUCKETS - 1).take(1) { 501 | assert_eq!(bucket.pingable_nodes().count(), bucket::MAX_BUCKET_SIZE) 502 | } 503 | } 504 | 505 | #[test] 506 | fn positive_all_sorted_buckets() { 507 | let table_id = NodeId::from([1u8; NODE_ID_LEN]); 508 | let mut table = RoutingTable::new(table_id); 509 | 510 | let block_addrs = 
test::dummy_block_socket_addrs(bucket::MAX_BUCKET_SIZE as u16); 511 | for bit_flip_index in 0..table::MAX_BUCKETS { 512 | for block_addr in &block_addrs { 513 | let bucket_node_id = table_id.flip_bit(bit_flip_index); 514 | 515 | table.add_node(Node::as_good(bucket_node_id, *block_addr)); 516 | } 517 | } 518 | 519 | assert_eq!(table.buckets().count(), table::MAX_BUCKETS); 520 | for bucket in table.buckets() { 521 | assert_eq!(bucket.pingable_nodes().count(), bucket::MAX_BUCKET_SIZE) 522 | } 523 | } 524 | 525 | #[test] 526 | fn negative_node_id_equal_table_id() { 527 | let table_id = [1u8; NODE_ID_LEN]; 528 | let mut table = RoutingTable::new(table_id.into()); 529 | 530 | assert_eq!(table.closest_nodes(table_id.into()).count(), 0); 531 | 532 | let node = Node::as_good(table_id.into(), test::dummy_socket_addr_v4()); 533 | table.add_node(node); 534 | 535 | assert_eq!(table.closest_nodes(table_id.into()).count(), 0); 536 | } 537 | } 538 | -------------------------------------------------------------------------------- /src/test.rs: -------------------------------------------------------------------------------- 1 | use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4}; 2 | 3 | use crate::info_hash::{NodeId, NODE_ID_LEN}; 4 | 5 | /// Generates a dummy Ipv4 address as an `IpAddr`. 6 | pub fn dummy_ipv4_addr() -> IpAddr { 7 | let v4_addr = Ipv4Addr::new(127, 0, 0, 1); 8 | 9 | IpAddr::V4(v4_addr) 10 | } 11 | 12 | /// Generates a dummy ipv6 address as an `IpAddr`. 13 | pub fn dummy_ipv6_addr() -> IpAddr { 14 | let v6_addr = Ipv6Addr::new(127, 0, 0, 1, 0, 0, 0, 0); 15 | 16 | IpAddr::V6(v6_addr) 17 | } 18 | 19 | /// Generates a dummy socket address v4 as a `SocketAddr`. 
20 | pub fn dummy_socket_addr_v4() -> SocketAddr { 21 | let v4_addr = Ipv4Addr::new(127, 0, 0, 1); 22 | let v4_socket = SocketAddrV4::new(v4_addr, 0); 23 | 24 | SocketAddr::V4(v4_socket) 25 | } 26 | 27 | /// Generates a block of unique ipv4 addresses as Vec<`SocketAddr`> 28 | pub fn dummy_block_socket_addrs(num_addrs: u16) -> Vec { 29 | let mut addr_block = Vec::with_capacity(num_addrs as usize); 30 | 31 | for port in 0..num_addrs { 32 | let ip = Ipv4Addr::new(127, 0, 0, 1); 33 | let sock_addr = SocketAddrV4::new(ip, port); 34 | 35 | addr_block.push(SocketAddr::V4(sock_addr)); 36 | } 37 | 38 | addr_block 39 | } 40 | 41 | /// Generates a dummy node id as a `NodeId`. 42 | pub fn dummy_node_id() -> NodeId { 43 | NodeId::from([0u8; NODE_ID_LEN]) 44 | } 45 | 46 | /// Generates a block of unique dummy node ids as Vec<`NodeId`> 47 | pub fn dummy_block_node_ids(num_ids: u8) -> Vec { 48 | let mut id_block = Vec::with_capacity(num_ids as usize); 49 | 50 | for repeat in 0..num_ids { 51 | let mut id = [0u8; NODE_ID_LEN]; 52 | 53 | for byte in id.iter_mut() { 54 | *byte = repeat; 55 | } 56 | 57 | id_block.push(id.into()) 58 | } 59 | 60 | id_block 61 | } 62 | -------------------------------------------------------------------------------- /src/time.rs: -------------------------------------------------------------------------------- 1 | use std::{ 2 | fmt, 3 | ops::{Add, Sub}, 4 | time::{Duration, Instant as StdInstant}, 5 | }; 6 | 7 | const WEEK_IN_SECONDS: u64 = 7 * 24 * 60 * 60; 8 | const OFFSET: Duration = Duration::from_secs(WEEK_IN_SECONDS); 9 | 10 | /// This `Instant` structure is basically the same thing as `std::time::Instant` but internally 11 | /// shifted into the future by `OFFSET`. It is because on Windows we saw panics in this code: 12 | /// 13 | /// Instant::now().checked_sub(15 minutes).unwrap() 14 | /// 15 | /// Internally (at least on Windows) `Instant` is represented as `Duration` since some starting 16 | /// point. 
This starting point is not known by us (it may be since the device booted up, since the 17 | /// app started, since first use,...) and thus subtracting 15 minutes from the internal `Duration` 18 | /// could result in a negative value, which is not allowed and thus the `checked_sub` would return 19 | /// `None`. 20 | /// 21 | /// We considered other options to circumvent this issue but they all have their own cons: 22 | /// 23 | /// 1. `chrono::DateTime` and `std::time::SystemTime` are not monotonic. 24 | /// 2. We could just use `Some(Instant)` and `None` in calculations, but this would require careful 25 | /// refactor. 26 | #[derive(Clone, Copy, PartialOrd, PartialEq, Ord, Eq)] 27 | pub(crate) struct Instant { 28 | std_instant: StdInstant, 29 | } 30 | 31 | impl Instant { 32 | pub fn now() -> Self { 33 | Self { 34 | std_instant: StdInstant::now().checked_add(OFFSET).unwrap(), 35 | } 36 | } 37 | 38 | pub fn checked_sub(&self, rhs: Duration) -> Option { 39 | self.std_instant 40 | .checked_sub(rhs) 41 | .map(|std_instant| Self { std_instant }) 42 | } 43 | } 44 | 45 | impl Add for Instant { 46 | type Output = Self; 47 | 48 | fn add(self, rhs: Duration) -> Self { 49 | Self { 50 | std_instant: self.std_instant + rhs, 51 | } 52 | } 53 | } 54 | 55 | impl Sub for Instant { 56 | type Output = Self; 57 | 58 | fn sub(self, rhs: Duration) -> Self { 59 | Self { 60 | std_instant: self.std_instant - rhs, 61 | } 62 | } 63 | } 64 | 65 | impl Sub for Instant { 66 | type Output = Duration; 67 | 68 | fn sub(self, rhs: Instant) -> Duration { 69 | self.std_instant - rhs.std_instant 70 | } 71 | } 72 | 73 | impl fmt::Debug for Instant { 74 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { 75 | if let Some(instant) = self.std_instant.checked_sub(OFFSET) { 76 | instant.fmt(f) 77 | } else { 78 | f.write_fmt(format_args!("({:?} - one week)", self.std_instant)) 79 | } 80 | } 81 | } 82 | -------------------------------------------------------------------------------- 
/src/timer.rs: -------------------------------------------------------------------------------- 1 | use futures_util::Stream; 2 | use std::{ 3 | collections::BTreeMap, 4 | future::Future, 5 | pin::Pin, 6 | task::{Context, Poll}, 7 | time::{Duration, Instant}, 8 | }; 9 | use tokio::time::{self, Sleep}; 10 | 11 | #[derive(Clone, Copy, Ord, PartialOrd, Eq, PartialEq)] 12 | pub(crate) struct Timeout { 13 | deadline: Instant, 14 | id: u64, 15 | } 16 | 17 | pub(crate) struct Timer { 18 | next_id: u64, 19 | current: Option>, 20 | queue: BTreeMap, 21 | } 22 | 23 | impl Timer { 24 | pub fn new() -> Self { 25 | Self { 26 | next_id: 0, 27 | current: None, 28 | queue: BTreeMap::new(), 29 | } 30 | } 31 | 32 | /// Has the timer no scheduled timeouts? 33 | pub fn is_empty(&self) -> bool { 34 | self.current.is_none() && self.queue.is_empty() 35 | } 36 | 37 | pub fn schedule_in(&mut self, deadline: Duration, value: T) -> Timeout { 38 | self.schedule_at(Instant::now() + deadline, value) 39 | } 40 | 41 | pub fn schedule_at(&mut self, deadline: Instant, value: T) -> Timeout { 42 | // If the current timeout is later than the new one, push it back into the queue. 43 | if let Some(current) = &self.current { 44 | let key = current.key(); 45 | 46 | if deadline < key.deadline { 47 | let CurrentTimerEntry { value, .. 
} = self.current.take().unwrap(); 48 | self.queue.insert(key, value); 49 | } 50 | } 51 | 52 | let id = self.next_id(); 53 | let key = Timeout { deadline, id }; 54 | self.queue.insert(key, value); 55 | 56 | key 57 | } 58 | 59 | pub fn cancel(&mut self, timeout: Timeout) -> bool { 60 | if let Some(current) = &self.current { 61 | if current.key() == timeout { 62 | self.current = None; 63 | return true; 64 | } 65 | } 66 | 67 | self.queue.remove(&timeout).is_some() 68 | } 69 | 70 | fn next_id(&mut self) -> u64 { 71 | let id = self.next_id; 72 | self.next_id = self.next_id.wrapping_add(1); 73 | id 74 | } 75 | } 76 | 77 | impl Stream for Timer { 78 | type Item = T; 79 | 80 | fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { 81 | loop { 82 | if let Some(current) = &mut self.current { 83 | match current.sleep.as_mut().poll(cx) { 84 | Poll::Ready(()) => { 85 | let CurrentTimerEntry { value, .. } = self.current.take().unwrap(); 86 | return Poll::Ready(Some(value)); 87 | } 88 | Poll::Pending => return Poll::Pending, 89 | } 90 | } 91 | 92 | // TODO: use BTreeMap::pop_first when it becomes stable. 
93 | let (key, value) = if let Some(key) = self.queue.keys().next().copied() { 94 | self.queue.remove_entry(&key).unwrap() 95 | } else { 96 | return Poll::Ready(None); 97 | }; 98 | 99 | self.current = Some(CurrentTimerEntry { 100 | sleep: Box::pin(time::sleep_until(key.deadline.into())), 101 | value, 102 | id: key.id, 103 | }); 104 | } 105 | } 106 | } 107 | 108 | struct CurrentTimerEntry { 109 | sleep: Pin>, 110 | value: T, 111 | id: u64, 112 | } 113 | 114 | impl CurrentTimerEntry { 115 | fn key(&self) -> Timeout { 116 | Timeout { 117 | deadline: self.sleep.deadline().into_std(), 118 | id: self.id, 119 | } 120 | } 121 | } 122 | -------------------------------------------------------------------------------- /src/token.rs: -------------------------------------------------------------------------------- 1 | use crate::time::Instant; 2 | use std::convert::TryInto; 3 | use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; 4 | use std::time::Duration; 5 | 6 | use crate::info_hash::{InfoHash, LengthError, INFO_HASH_LEN}; 7 | 8 | /// We will partially follow the bittorrent implementation for issuing tokens to nodes, the 9 | /// secret will change every 10 minutes and tokens up to 10 minutes old will be accepted. This 10 | /// is in contrast with the bittorrent implementation where the secret changes every 5 minutes 11 | /// and tokens up to 10 minutes old are accepted. Updating of the token will take place lazily. 12 | /// However, with our implementation we are not going to store tokens that we have issued, instead, 13 | /// store the secret and check if the token they gave us is valid for the current or last secret. 14 | /// This is technically not what we want, but it will have essentially the same result when we 15 | /// assume that nobody other than us knows the secret. 16 | 17 | /// With this scheme we can guarantee that the minimum amount of time a token can be valid for 18 | /// is the maximum amount of time a token is valid for in bittorrent in order to provide interop. 
19 | /// Since we arent storing the tokens we generate (which is awesome) we CANT track how long each 20 | /// individual token has been checked out from the store and so each token is valid for some time 21 | /// between 10 and 20 minutes in contrast with 5 and 10 minutes. 22 | 23 | const REFRESH_INTERVAL: Duration = Duration::from_secs(10 * 60); 24 | 25 | const IPV4_SECRET_BUFFER_LEN: usize = 4 + 4; 26 | const IPV6_SECRET_BUFFER_LEN: usize = 16 + 4; 27 | 28 | #[derive(Copy, Clone, PartialEq, Eq, Hash, Debug)] 29 | pub struct Token { 30 | token: [u8; INFO_HASH_LEN], 31 | } 32 | 33 | impl Token { 34 | pub fn new(bytes: &[u8]) -> Result { 35 | Ok(Self { 36 | token: bytes.try_into().map_err(|_| LengthError)?, 37 | }) 38 | } 39 | } 40 | 41 | impl From for [u8; INFO_HASH_LEN] { 42 | fn from(token: Token) -> [u8; INFO_HASH_LEN] { 43 | token.token 44 | } 45 | } 46 | 47 | impl From<[u8; INFO_HASH_LEN]> for Token { 48 | fn from(token: [u8; INFO_HASH_LEN]) -> Token { 49 | Token { token } 50 | } 51 | } 52 | 53 | impl AsRef<[u8]> for Token { 54 | fn as_ref(&self) -> &[u8] { 55 | &self.token 56 | } 57 | } 58 | 59 | // ----------------------------------------------------------------------------// 60 | 61 | #[derive(Copy, Clone)] 62 | pub struct TokenStore { 63 | curr_secret: u32, 64 | last_secret: u32, 65 | last_refresh: Instant, 66 | } 67 | 68 | impl TokenStore { 69 | pub fn new() -> TokenStore { 70 | // We cant just use a placeholder for the last secret as that would allow external 71 | // nodes to exploit recently started dhts. Instead, just generate another placeholder 72 | // secret for the last secret with the assumption that we wont get a valid announce 73 | // under that secret. We could go the option route but that isnt as clean. 
74 | let curr_secret = rand::random::(); 75 | let last_secret = rand::random::(); 76 | let last_refresh = Instant::now(); 77 | 78 | TokenStore { 79 | curr_secret, 80 | last_secret, 81 | last_refresh, 82 | } 83 | } 84 | 85 | pub fn checkout(&mut self, addr: IpAddr) -> Token { 86 | self.refresh_check(); 87 | 88 | generate_token_from_addr(addr, self.curr_secret) 89 | } 90 | 91 | pub fn checkin(&mut self, addr: IpAddr, token: Token) -> bool { 92 | self.refresh_check(); 93 | 94 | validate_token_from_addr(addr, token, self.curr_secret, self.last_secret) 95 | } 96 | 97 | fn refresh_check(&mut self) { 98 | match intervals_passed(self.last_refresh) { 99 | 0 => (), 100 | 1 => { 101 | self.last_secret = self.curr_secret; 102 | self.curr_secret = rand::random::(); 103 | self.last_refresh = Instant::now(); 104 | } 105 | _ => { 106 | self.last_secret = rand::random::(); 107 | self.curr_secret = rand::random::(); 108 | self.last_refresh = Instant::now(); 109 | } 110 | }; 111 | } 112 | } 113 | 114 | /// Since we are lazily generating tokens, more than one interval could have passed since 115 | /// we last generated a token in which case our last secret AND current secret could be 116 | /// invalid. 117 | /// 118 | /// Returns the number of intervals that have passed since the last refresh time. 119 | fn intervals_passed(last_refresh: Instant) -> u64 { 120 | let curr_time = Instant::now(); 121 | let diff_time = curr_time - last_refresh; 122 | 123 | diff_time.as_secs() / REFRESH_INTERVAL.as_secs() 124 | } 125 | 126 | /// Generate a token from an ip address and a secret. 127 | fn generate_token_from_addr(addr: IpAddr, secret: u32) -> Token { 128 | match addr { 129 | IpAddr::V4(v4) => generate_token_from_addr_v4(v4, secret), 130 | IpAddr::V6(v6) => generate_token_from_addr_v6(v6, secret), 131 | } 132 | } 133 | 134 | /// Generate a token from an ipv4 address and a secret. 
135 | fn generate_token_from_addr_v4(v4_addr: Ipv4Addr, secret: u32) -> Token { 136 | let mut buffer = [0u8; IPV4_SECRET_BUFFER_LEN]; 137 | let v4_bytes = v4_addr.octets(); 138 | let secret_bytes = secret.to_be_bytes(); 139 | 140 | let source_iter = v4_bytes.iter().chain(secret_bytes.iter()); 141 | for (dst, src) in buffer.iter_mut().zip(source_iter) { 142 | *dst = *src; 143 | } 144 | 145 | let hash_buffer = InfoHash::sha1(&buffer); 146 | Into::<[u8; INFO_HASH_LEN]>::into(hash_buffer).into() 147 | } 148 | 149 | /// Generate a token from an ipv6 address and a secret. 150 | fn generate_token_from_addr_v6(v6_addr: Ipv6Addr, secret: u32) -> Token { 151 | let mut buffer = [0u8; IPV6_SECRET_BUFFER_LEN]; 152 | let v6_bytes = v6_addr.octets(); 153 | let secret_bytes = secret.to_be_bytes(); 154 | 155 | let source_iter = v6_bytes.iter().chain(secret_bytes.iter()); 156 | for (dst, src) in buffer.iter_mut().zip(source_iter) { 157 | *dst = *src; 158 | } 159 | 160 | let hash_buffer = InfoHash::sha1(&buffer); 161 | Into::<[u8; INFO_HASH_LEN]>::into(hash_buffer).into() 162 | } 163 | 164 | /// Validate a token given an ip address and the two current secrets. 165 | fn validate_token_from_addr(addr: IpAddr, token: Token, secret_one: u32, secret_two: u32) -> bool { 166 | match addr { 167 | IpAddr::V4(v4) => { 168 | validate_token_from_addr_v4(v4, token, secret_one) 169 | || validate_token_from_addr_v4(v4, token, secret_two) 170 | } 171 | IpAddr::V6(v6) => { 172 | validate_token_from_addr_v6(v6, token, secret_one) 173 | || validate_token_from_addr_v6(v6, token, secret_two) 174 | } 175 | } 176 | } 177 | 178 | /// Validate a token given an ipv4 address and one secret. 179 | fn validate_token_from_addr_v4(v4_addr: Ipv4Addr, token: Token, secret: u32) -> bool { 180 | generate_token_from_addr_v4(v4_addr, secret) == token 181 | } 182 | 183 | /// Validate a token given an ipv6 address and one secret. 
184 | fn validate_token_from_addr_v6(v6_addr: Ipv6Addr, token: Token, secret: u32) -> bool { 185 | generate_token_from_addr_v6(v6_addr, secret) == token 186 | } 187 | 188 | #[cfg(test)] 189 | mod tests { 190 | use crate::time::Instant; 191 | use std::time::Duration; 192 | 193 | use crate::test; 194 | use crate::token::TokenStore; 195 | 196 | #[test] 197 | fn positive_accept_valid_v4_token() { 198 | let mut store = TokenStore::new(); 199 | let v4_addr = test::dummy_ipv4_addr(); 200 | 201 | let valid_token = store.checkout(v4_addr); 202 | 203 | assert!(store.checkin(v4_addr, valid_token)); 204 | } 205 | 206 | #[test] 207 | fn positive_accept_valid_v6_token() { 208 | let mut store = TokenStore::new(); 209 | let v6_addr = test::dummy_ipv6_addr(); 210 | 211 | let valid_token = store.checkout(v6_addr); 212 | 213 | assert!(store.checkin(v6_addr, valid_token)); 214 | } 215 | 216 | #[test] 217 | fn positive_accept_v4_token_from_second_secret() { 218 | let mut store = TokenStore::new(); 219 | let v4_addr = test::dummy_ipv4_addr(); 220 | 221 | let valid_token = store.checkout(v4_addr); 222 | 223 | let past_offset = super::REFRESH_INTERVAL * 2 - Duration::from_secs(60); 224 | let past_time = Instant::now().checked_sub(past_offset).unwrap(); 225 | store.last_refresh = past_time; 226 | 227 | assert!(store.checkin(v4_addr, valid_token)); 228 | } 229 | 230 | #[test] 231 | fn positive_accept_v6_token_from_second_secret() { 232 | let mut store = TokenStore::new(); 233 | let v6_addr = test::dummy_ipv6_addr(); 234 | 235 | let valid_token = store.checkout(v6_addr); 236 | 237 | let past_offset = super::REFRESH_INTERVAL * 2 - Duration::from_secs(60); 238 | let past_time = Instant::now().checked_sub(past_offset).unwrap(); 239 | store.last_refresh = past_time; 240 | 241 | assert!(store.checkin(v6_addr, valid_token)); 242 | } 243 | 244 | #[test] 245 | #[should_panic] 246 | fn negative_reject_expired_v4_token() { 247 | let mut store = TokenStore::new(); 248 | let v4_addr = 
test::dummy_ipv4_addr(); 249 | 250 | let valid_token = store.checkout(v4_addr); 251 | 252 | let past_offset = super::REFRESH_INTERVAL * 2; 253 | let past_time = Instant::now().checked_sub(past_offset).unwrap(); 254 | store.last_refresh = past_time; 255 | 256 | assert!(store.checkin(v4_addr, valid_token)); 257 | } 258 | 259 | #[test] 260 | #[should_panic] 261 | fn negative_reject_expired_v6_token() { 262 | let mut store = TokenStore::new(); 263 | let v6_addr = test::dummy_ipv6_addr(); 264 | 265 | let valid_token = store.checkout(v6_addr); 266 | 267 | let past_offset = super::REFRESH_INTERVAL * 2; 268 | let past_time = Instant::now().checked_sub(past_offset).unwrap(); 269 | store.last_refresh = past_time; 270 | 271 | assert!(store.checkin(v6_addr, valid_token)); 272 | } 273 | } 274 | -------------------------------------------------------------------------------- /src/transaction.rs: -------------------------------------------------------------------------------- 1 | use rand::seq::SliceRandom; 2 | use std::convert::TryInto; 3 | 4 | // Transaction IDs are going to be vital for both scalability and performance concerns. 5 | // They allow us to both protect against unsolicited responses as well as dropping those 6 | // messages as soon as possible. We are taking an absurdly large, lazily generate, ringbuffer 7 | // approach to generating transaction ids. 8 | 9 | // We are going for a simple, stateless (for the most part) implementation for generating 10 | // the transaction ids. We chose to go this route because 1, we dont want to reuse transaction 11 | // ids used in recent requests that had subsequent responses as well because this would make us 12 | // vulnerable to nodes that we gave that transaction id to, they would know we would be reusing 13 | // it soon. And 2, that makes for an unscalable approach unless we also have a timeout for ids 14 | // that we never received responses for which would lend itself to messy code. 
15 | 16 | // Instead, we are going to pre-allocate a chunk of ids, shuffle them, and use them until they 17 | // run out, then pre-allocate some more, shuffle them, and use them. When we run out, (which wont 18 | // happen for a VERY long time) we will simply wrap around. Also, we are going to break down the 19 | // transaction id, so our transaction id will be made up of the first 5 bytes which will be the 20 | // action id, this would be something like an individual lookup, a bucket refresh, or a bootstrap. 21 | // Now, each of those actions have a number of messages associated with them, this is where the 22 | // last 3 bytes come in which will be the message id. This allows us to route messages appropriately 23 | // and associate them with some action we are performing right down to a message that the action is 24 | // expecting. The pre-allocation strategy is used both on the action id level as well as the message 25 | // id level. 26 | 27 | // To protect against timing attacks, where recently pinged nodes got our transaction id and wish 28 | // to guess other transaction ids in the block that we may have in flight, we will make the pre-allocation 29 | // space fairly large so that our shuffle provides a strong protection from these attacks. In the future, 30 | // we may want to dynamically ban nodes that we feel are guessing our transaction ids. 31 | 32 | // IMPORTANT: Allocation markers (not the actual allocated ids) are not shifted so that we can deal with 33 | // overflow by manually checking since I dont want to rely on langauge level overflows and whether they 34 | // cause a panic or not (debug and release should have similar semantics)! 
35 | 36 | // Together these make up 8 bytes, or, a u64 37 | const TRANSACTION_ID_BYTES: usize = ACTION_ID_BYTES + MESSAGE_ID_BYTES; 38 | const ACTION_ID_BYTES: usize = 5; 39 | const MESSAGE_ID_BYTES: usize = 3; 40 | 41 | // Maximum exclusive value for an action id 42 | const ACTION_ID_SHIFT: usize = ACTION_ID_BYTES * 8; 43 | const MAX_ACTION_ID: u64 = 1 << ACTION_ID_SHIFT; 44 | 45 | // Maximum exclusive value for a message id 46 | const MESSAGE_ID_SHIFT: usize = MESSAGE_ID_BYTES * 8; 47 | const MAX_MESSAGE_ID: u64 = 1 << MESSAGE_ID_SHIFT; 48 | 49 | // Multiple of two so we can wrap around nicely 50 | #[cfg(not(test))] 51 | const ACTION_ID_PREALLOC_LEN: usize = 2048; 52 | #[cfg(not(test))] 53 | const MESSAGE_ID_PREALLOC_LEN: usize = 2048; 54 | 55 | // Reduce the pre allocation length in tests to speed them up significantly 56 | #[cfg(test)] 57 | const ACTION_ID_PREALLOC_LEN: usize = 16; 58 | #[cfg(test)] 59 | const MESSAGE_ID_PREALLOC_LEN: usize = 16; 60 | 61 | pub struct AIDGenerator { 62 | // NOT SHIFTED, so that we can wrap around manually! 
63 | next_alloc: u64, 64 | curr_index: usize, 65 | action_ids: [u64; ACTION_ID_PREALLOC_LEN], 66 | } 67 | 68 | impl AIDGenerator { 69 | pub fn new() -> AIDGenerator { 70 | let (next_alloc, mut action_ids) = generate_aids(0); 71 | 72 | // Randomize the order of ids 73 | action_ids.shuffle(&mut rand::thread_rng()); 74 | 75 | AIDGenerator { 76 | next_alloc, 77 | curr_index: 0, 78 | action_ids, 79 | } 80 | } 81 | 82 | pub fn generate(&mut self) -> MIDGenerator { 83 | let opt_action_id = self.action_ids.get(self.curr_index).copied(); 84 | 85 | if let Some(action_id) = opt_action_id { 86 | self.curr_index += 1; 87 | 88 | // Shift the action id to make room for the message id 89 | MIDGenerator::new(action_id << MESSAGE_ID_SHIFT) 90 | } else { 91 | // Get a new block of action ids 92 | let (next_alloc, mut action_ids) = generate_aids(self.next_alloc); 93 | 94 | // Randomize the order of ids 95 | action_ids.shuffle(&mut rand::thread_rng()); 96 | 97 | self.next_alloc = next_alloc; 98 | self.action_ids = action_ids; 99 | self.curr_index = 0; 100 | 101 | self.generate() 102 | } 103 | } 104 | } 105 | 106 | // (next_alloc, aids) 107 | fn generate_aids(next_alloc: u64) -> (u64, [u64; ACTION_ID_PREALLOC_LEN]) { 108 | // Check if we need to wrap 109 | let (next_alloc_start, next_alloc_end) = if next_alloc == MAX_ACTION_ID { 110 | (0, ACTION_ID_PREALLOC_LEN as u64) 111 | } else { 112 | (next_alloc, next_alloc + ACTION_ID_PREALLOC_LEN as u64) 113 | }; 114 | let mut action_ids = [0u64; ACTION_ID_PREALLOC_LEN]; 115 | 116 | for (index, action_id) in (next_alloc_start..next_alloc_end).enumerate() { 117 | action_ids[index] = action_id; 118 | } 119 | 120 | (next_alloc_end, action_ids) 121 | } 122 | 123 | // ----------------------------------------------------------------------------// 124 | 125 | pub struct MIDGenerator { 126 | // ALREADY SHIFTED, for your convenience :) 127 | action_id: u64, 128 | // NOT SHIFTED, so that we can wrap around manually! 
129 | next_alloc: u64, 130 | curr_index: usize, 131 | message_ids: [u64; MESSAGE_ID_PREALLOC_LEN], 132 | } 133 | 134 | impl MIDGenerator { 135 | // Accepts an action id that has ALREADY BEEN SHIFTED! 136 | fn new(action_id: u64) -> MIDGenerator { 137 | // In order to speed up tests, we will generate the first block lazily. 138 | MIDGenerator { 139 | action_id, 140 | next_alloc: 0, 141 | curr_index: MESSAGE_ID_PREALLOC_LEN, 142 | message_ids: [0u64; MESSAGE_ID_PREALLOC_LEN], 143 | } 144 | } 145 | 146 | pub fn action_id(&self) -> ActionID { 147 | ActionID::from_transaction_id(self.action_id) 148 | } 149 | 150 | pub fn generate(&mut self) -> TransactionID { 151 | let opt_message_id = self.message_ids.get(self.curr_index).copied(); 152 | 153 | if let Some(message_id) = opt_message_id { 154 | self.curr_index += 1; 155 | 156 | TransactionID::new(self.action_id | message_id) 157 | } else { 158 | // Get a new block of message ids 159 | let (next_alloc, mut message_ids) = generate_mids(self.next_alloc); 160 | 161 | // Randomize the order of ids 162 | message_ids.shuffle(&mut rand::thread_rng()); 163 | 164 | self.next_alloc = next_alloc; 165 | self.message_ids = message_ids; 166 | self.curr_index = 0; 167 | 168 | self.generate() 169 | } 170 | } 171 | } 172 | 173 | // (next_alloc, mids) 174 | fn generate_mids(next_alloc: u64) -> (u64, [u64; MESSAGE_ID_PREALLOC_LEN]) { 175 | // Check if we need to wrap 176 | let (next_alloc_start, next_alloc_end) = if next_alloc == MAX_MESSAGE_ID { 177 | (0, MESSAGE_ID_PREALLOC_LEN as u64) 178 | } else { 179 | (next_alloc, next_alloc + MESSAGE_ID_PREALLOC_LEN as u64) 180 | }; 181 | let mut message_ids = [0u64; MESSAGE_ID_PREALLOC_LEN]; 182 | 183 | for (index, message_id) in (next_alloc_start..next_alloc_end).enumerate() { 184 | message_ids[index] = message_id; 185 | } 186 | 187 | (next_alloc_end, message_ids) 188 | } 189 | 190 | // ----------------------------------------------------------------------------// 191 | 192 | #[derive(Copy, Clone, 
PartialEq, Eq, Hash, Debug)] 193 | pub struct TransactionID { 194 | bytes: [u8; TRANSACTION_ID_BYTES], 195 | } 196 | 197 | impl TransactionID { 198 | fn new(trans_id: u64) -> TransactionID { 199 | let bytes = trans_id.to_be_bytes(); 200 | 201 | TransactionID { bytes } 202 | } 203 | 204 | /// Construct a transaction id from a series of bytes. 205 | pub fn from_bytes(bytes: &[u8]) -> Option { 206 | let bytes = bytes.try_into().ok()?; 207 | Some(Self { bytes }) 208 | } 209 | 210 | pub fn action_id(&self) -> ActionID { 211 | ActionID::from_transaction_id(u64::from_be_bytes(self.bytes)) 212 | } 213 | 214 | #[allow(unused)] 215 | pub fn message_id(&self) -> MessageID { 216 | MessageID::from_transaction_id(u64::from_be_bytes(self.bytes)) 217 | } 218 | } 219 | 220 | impl AsRef<[u8]> for TransactionID { 221 | fn as_ref(&self) -> &[u8] { 222 | &self.bytes 223 | } 224 | } 225 | 226 | // ----------------------------------------------------------------------------// 227 | 228 | #[derive(Copy, Clone, PartialEq, Eq, Hash, Debug)] 229 | pub struct ActionID { 230 | action_id: u64, 231 | } 232 | 233 | impl ActionID { 234 | fn from_transaction_id(trans_id: u64) -> ActionID { 235 | // The ACTUAL action id 236 | let shifted_action_id = trans_id >> MESSAGE_ID_SHIFT; 237 | 238 | ActionID { 239 | action_id: shifted_action_id, 240 | } 241 | } 242 | } 243 | 244 | // ----------------------------------------------------------------------------// 245 | 246 | #[derive(Copy, Clone, PartialEq, Eq, Hash, Debug)] 247 | pub struct MessageID { 248 | message_id: u64, 249 | } 250 | 251 | impl MessageID { 252 | fn from_transaction_id(trans_id: u64) -> MessageID { 253 | let clear_action_id = MAX_MESSAGE_ID - 1; 254 | // The ACTUAL message id 255 | let shifted_message_id = trans_id & clear_action_id; 256 | 257 | MessageID { 258 | message_id: shifted_message_id, 259 | } 260 | } 261 | } 262 | 263 | // ----------------------------------------------------------------------------// 264 | 265 | #[cfg(test)] 266 
| mod tests { 267 | use std::collections::HashSet; 268 | 269 | use super::{AIDGenerator, TransactionID}; 270 | 271 | #[test] 272 | fn positive_tid_from_bytes() { 273 | let mut aid_generator = AIDGenerator::new(); 274 | let mut mid_generator = aid_generator.generate(); 275 | 276 | let tid = mid_generator.generate(); 277 | let tid_from_bytes = TransactionID::from_bytes(tid.as_ref()).unwrap(); 278 | 279 | assert_eq!(tid, tid_from_bytes); 280 | } 281 | 282 | #[test] 283 | fn positive_unique_aid_blocks() { 284 | // Go through ten blocks worth of action ids, make sure they are unique 285 | let mut action_ids = HashSet::new(); 286 | let mut aid_generator = AIDGenerator::new(); 287 | 288 | for _ in 0..(super::ACTION_ID_PREALLOC_LEN * 10) { 289 | let action_id = aid_generator.generate().action_id(); 290 | 291 | assert!(!action_ids.contains(&action_id)); 292 | 293 | action_ids.insert(action_id); 294 | } 295 | } 296 | 297 | #[test] 298 | fn positive_unique_mid_blocks() { 299 | // Go through ten blocks worth of message ids, make sure they are unique 300 | let mut message_ids = HashSet::new(); 301 | let mut aid_generator = AIDGenerator::new(); 302 | let mut mid_generator = aid_generator.generate(); 303 | 304 | for _ in 0..(super::MESSAGE_ID_PREALLOC_LEN * 10) { 305 | let message_id = mid_generator.generate().message_id(); 306 | 307 | assert!(!message_ids.contains(&message_id)); 308 | 309 | message_ids.insert(message_id); 310 | } 311 | } 312 | 313 | #[test] 314 | fn positive_unique_tid_blocks() { 315 | // Go through two blocks of compound ids (transaction ids), make sure they are unique 316 | let mut transaction_ids = HashSet::new(); 317 | let mut aid_generator = AIDGenerator::new(); 318 | 319 | for _ in 0..(super::ACTION_ID_PREALLOC_LEN) { 320 | let mut mid_generator = aid_generator.generate(); 321 | 322 | for _ in 0..(super::MESSAGE_ID_PREALLOC_LEN) { 323 | let transaction_id = mid_generator.generate(); 324 | 325 | assert!(!transaction_ids.contains(&transaction_id)); 326 | 327 
| transaction_ids.insert(transaction_id); 328 | } 329 | } 330 | } 331 | 332 | #[test] 333 | fn positive_overflow_aid_generate() { 334 | let mut action_ids = HashSet::new(); 335 | let mut aid_generator = AIDGenerator::new(); 336 | 337 | // Track all action ids in the first block 338 | for _ in 0..(super::ACTION_ID_PREALLOC_LEN) { 339 | let action_id = aid_generator.generate().action_id(); 340 | 341 | assert!(!action_ids.contains(&action_id)); 342 | 343 | action_ids.insert(action_id); 344 | } 345 | 346 | // Modify private variables to overflow back to first block 347 | aid_generator.next_alloc = super::MAX_ACTION_ID; 348 | aid_generator.curr_index = super::ACTION_ID_PREALLOC_LEN; 349 | 350 | // Check all action ids in the block (should be first block) 351 | for _ in 0..(super::ACTION_ID_PREALLOC_LEN) { 352 | let action_id = aid_generator.generate().action_id(); 353 | 354 | assert!(action_ids.remove(&action_id)); 355 | } 356 | 357 | assert!(action_ids.is_empty()); 358 | } 359 | 360 | #[test] 361 | fn positive_overflow_mid_generate() { 362 | let mut message_ids = HashSet::new(); 363 | let mut aid_generator = AIDGenerator::new(); 364 | let mut mid_generator = aid_generator.generate(); 365 | 366 | // Track all message ids in the first block 367 | for _ in 0..(super::MESSAGE_ID_PREALLOC_LEN) { 368 | let message_id = mid_generator.generate().message_id(); 369 | 370 | assert!(!message_ids.contains(&message_id)); 371 | 372 | message_ids.insert(message_id); 373 | } 374 | 375 | // Modify private variables to overflow back to first block 376 | mid_generator.next_alloc = super::MAX_MESSAGE_ID; 377 | mid_generator.curr_index = super::MESSAGE_ID_PREALLOC_LEN; 378 | 379 | // Check all message ids in the block (should be first block) 380 | for _ in 0..(super::MESSAGE_ID_PREALLOC_LEN) { 381 | let message_id = mid_generator.generate().message_id(); 382 | 383 | assert!(message_ids.remove(&message_id)); 384 | } 385 | 386 | assert!(message_ids.is_empty()); 387 | } 388 | 389 | #[test] 390 
    // Wrapping the action id allocator must reissue the same compound
    // (transaction) ids as the first block — combining the aid and mid
    // overflow checks above.
    fn positive_overflow_tid_generate() {
        let mut transaction_ids = HashSet::new();
        let mut aid_generator = AIDGenerator::new();

        // Track all transaction ids in the first block
        for _ in 0..(super::ACTION_ID_PREALLOC_LEN) {
            let mut mid_generator = aid_generator.generate();

            for _ in 0..(super::MESSAGE_ID_PREALLOC_LEN) {
                let transaction_id = mid_generator.generate();

                assert!(!transaction_ids.contains(&transaction_id));

                transaction_ids.insert(transaction_id);
            }
        }

        // Modify private variables to overflow back to first block
        aid_generator.next_alloc = super::MAX_ACTION_ID;
        aid_generator.curr_index = super::ACTION_ID_PREALLOC_LEN;

        // Check all transaction ids in the block (should be first block)
        for _ in 0..(super::ACTION_ID_PREALLOC_LEN) {
            let mut mid_generator = aid_generator.generate();

            for _ in 0..(super::MESSAGE_ID_PREALLOC_LEN) {
                let transaction_id = mid_generator.generate();

                // Every wrapped-around id must have been seen in the first block.
                assert!(transaction_ids.remove(&transaction_id));
            }
        }

        assert!(transaction_ids.is_empty());
    }
}
-------------------------------------------------------------------------------- /tests/tests.rs: --------------------------------------------------------------------------------
use btdht::{InfoHash, MainlineDht};
use futures_util::StreamExt;
use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr};
use tokio::net::UdpSocket;

// End-to-end smoke test over IPv4 loopback: one node announces an infohash,
// another looks it up.
#[test_log::test(tokio::test(flavor = "multi_thread"))]
async fn announce_and_lookup_v4() {
    announce_and_lookup(AddrFamily::V4).await;
}

// Same end-to-end scenario over IPv6 loopback.
#[test_log::test(tokio::test(flavor = "multi_thread"))]
async fn announce_and_lookup_v6() {
    announce_and_lookup(AddrFamily::V6).await;
}

// Shared body for the v4/v6 tests: spins up a bootstrap node plus two peers,
// has peer A announce an infohash and asserts peer B's lookup finds A's address.
async fn announce_and_lookup(addr_family: AddrFamily) {
    // Start the router node for the other nodes to bootstrap against.
    let bootstrap_node_socket = UdpSocket::bind(localhost(addr_family)).await.unwrap();
    let bootstrap_node_addr = bootstrap_node_socket.local_addr().unwrap();
    let bootstrap_node = MainlineDht::builder()
        .set_read_only(false)
        .start(bootstrap_node_socket)
        .unwrap();

    assert!(bootstrap_node.bootstrapped().await);

    // Start node A
    let a_socket = UdpSocket::bind(localhost(addr_family)).await.unwrap();
    let a_addr = a_socket.local_addr().unwrap();
    let a_node = MainlineDht::builder()
        .add_node(bootstrap_node_addr)
        .set_read_only(false)
        .start(a_socket)
        .unwrap();

    // Start node B
    let b_socket = UdpSocket::bind(localhost(addr_family)).await.unwrap();
    let b_node = MainlineDht::builder()
        .add_node(bootstrap_node_addr)
        .set_read_only(false)
        .start(b_socket)
        .unwrap();

    // Wait for both nodes to bootstrap
    assert!(a_node.bootstrapped().await);
    assert!(b_node.bootstrapped().await);

    let the_info_hash = InfoHash::sha1(b"foo");

    // Perform a lookup with announce by A. It should not return any peers initially but it should
    // make the network aware that A has the infohash.
    let mut search = a_node.search(the_info_hash, true);
    assert_eq!(search.next().await, None);

    // Now perform the lookup by B. It should find A.
    let mut search = b_node.search(the_info_hash, false);
    assert_eq!(search.next().await, Some(a_addr));
}

// Address family selector for the loopback tests above.
#[derive(Copy, Clone)]
enum AddrFamily {
    V4,
    V6,
}

// Loopback address with port 0 (OS-assigned) for the requested family.
fn localhost(family: AddrFamily) -> SocketAddr {
    match family {
        AddrFamily::V4 => (Ipv4Addr::LOCALHOST, 0).into(),
        AddrFamily::V6 => (Ipv6Addr::LOCALHOST, 0).into(),
    }
}
--------------------------------------------------------------------------------