├── .github └── workflows │ └── ci.yml ├── .gitignore ├── Cargo.toml ├── LICENSE ├── README.md ├── examples ├── simple_run.rs ├── simulate_add_node.rs ├── simulate_node_failure.rs └── simulate_replica_repair.rs └── src ├── cluster.rs ├── error.rs ├── lib.rs ├── log.rs ├── network.rs ├── server.rs ├── state_mechine.rs └── storage.rs /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: RAFT-RS CI 2 | 3 | on: [push, pull_request] 4 | 5 | jobs: 6 | build: 7 | runs-on: ubuntu-latest 8 | 9 | steps: 10 | - name: Checkout code 11 | uses: actions/checkout@v2 12 | with: 13 | token: ${{ secrets.GITHUB_TOKEN }} 14 | 15 | - name: Set up Rust 16 | uses: actions-rs/toolchain@v1 17 | with: 18 | toolchain: stable 19 | override: true 20 | 21 | - name: Check formatting 22 | run: cargo fmt -- --check 23 | 24 | - name: Check Unused Dependencies 25 | run: | 26 | cargo install cargo-machete 27 | cargo machete 28 | 29 | - name: Check Linting 30 | run: cargo clippy -- -D warnings 31 | 32 | - name: Unit test 33 | run: cargo test 34 | 35 | - name: Build project 36 | run: cargo build 37 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Generated by Cargo 2 | # will have compiled files and executables 3 | debug/ 4 | target/ 5 | 6 | # Remove Cargo.lock from gitignore if creating an executable, leave it for libraries 7 | # More information here https://doc.rust-lang.org/cargo/guide/cargo-toml-vs-cargo-lock.html 8 | Cargo.lock 9 | 10 | # These are backup files generated by rustfmt 11 | **/*.rs.bk 12 | 13 | # MSVC Windows builds of rustc generate these, which store debugging information 14 | *.pdb 15 | 16 | # server logs 17 | *.log 18 | 19 | # Added by cargo 20 | 21 | /target 22 | -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "raft_rs" 3 | version = "0.1.0" 4 | edition = "2021" 5 | authors = ["Vipul Vaibhaw "] 6 | description = "A Raft implementation by SpacewalkHq" 7 | license = "MIT" 8 | readme = "README.md" 9 | repository = "https://github.com/spacewalkhq/raft-rs" 10 | homepage = "https://github.com/spacewalkhq/raft-rs" 11 | keywords = ["raft", "distributed-systems", "consensus"] 12 | categories = ["algorithms", "network-programming", "distributed-systems"] 13 | 14 | [dependencies] 15 | tokio = { version = "1", features = ["full"] } 16 | futures = "0.3" 17 | async-trait = "0.1" 18 | bincode = "1.3.1" 19 | serde = { version = "1.0", features = ["derive"] } 20 | hex = "0.4" 21 | sha2 = "0.10.8" 22 | slog = "2.7.0" 23 | slog-term = "2.9.1" 24 | rand = "0.8" 25 | chrono = "0.4" 26 | thiserror = "1.0" 27 | 28 | [dev-dependencies] 29 | tempfile = "3.10.1" 30 | 31 | [[example]] 32 | name = "simple_run" 33 | path = "examples/simple_run.rs" 34 | 35 | [[example]] 36 | name = "simulate_node_failure" 37 | path = "examples/simulate_node_failure.rs" 38 | 39 | [[example]] 40 | name = "simulate_add_node" 41 | path = "examples/simulate_add_node.rs" -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Spacewalk HQ 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated 
documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Raft-rs 2 | An understandable, fast, scalable, and optimized implementation of the [Raft consensus algorithm](https://en.wikipedia.org/wiki/Raft_(algorithm)). 3 | It is asynchronous (built on the tokio runtime) and supports zero-copy. It does not assume storage is trustworthy: if the logs are corrupted, it repairs them via peer-to-peer communication. 4 | 5 | ## Note 6 | - This project is still under development and is not yet production-ready. We do not recommend using it in production environments. 7 | 8 | - We are actively working on this project to make it better and more reliable. If you have any suggestions or feedback, please feel free to open an issue or a pull request. This will remain the case until we reach version 1.0.0. 9 | 10 | - We cut a release every two weeks. 11 | 12 | ## Goals 13 | - [x] Understandable 14 | - [x] Fast 15 | - [x] Scalable 16 | - [x] Zero-Copy support 17 | - [x] Asynchronous 18 | - [x] Default Leader 19 | - [x] Leadership preference 20 | - [x] Log compaction 21 | - [x] Snapshot Support 22 | - [x] TigerBeetle-style replica repair 23 | - [x] Dynamic cluster membership changes support 24 | - [x] Test for dynamic cluster membership changes 25 | 26 | ## To-Do 27 | - [ ] Production-ready 28 | - [ ] Test replica repair thoroughly 29 | - [ ] io_uring support for Linux 30 | - [ ] Complete batch write implementation 31 | - [ ] Improve log compaction 32 | - [ ] RDMA support 33 | - [ ] Deterministic Simulation Testing 34 | - [ ] Benchmarking 35 | 36 | ## How to Run the Project 37 | 1. Ensure you have Rust installed. If not, follow the instructions [here](https://www.rust-lang.org/tools/install). 38 | 2. Clone the repository: 39 | ```sh 40 | git clone https://github.com/spacewalkhq/raft-rs.git 41 | cd raft-rs 42 | ``` 43 | 3. Run the simple example: 44 | ```sh 45 | cargo run --example simple_run 46 | ``` 47 | 4. Build a release binary: 48 | ```sh 49 | cargo build --release 50 | ``` 51 | 52 | ## Contributing 53 | Contributions are welcome! If you have any ideas, suggestions, or issues, please feel free to open an issue or a pull request. We aim to make this project better with your help. 54 | 55 | ## License 56 | This project is licensed under the MIT License. For more information, please refer to the [LICENSE](LICENSE) file. 57 | 58 | ## Contact 59 | For any questions or feedback, please reach out to [vaibhaw.vipul@gmail.com](mailto:vaibhaw.vipul@gmail.com). 
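## Client request wire format (informal)

The bundled examples talk to a node over raw TCP using a simple big-endian layout: 4 bytes sender id, 4 bytes term, 4 bytes message type (6 maps to `MessageType::ClientRequest` in `src/server.rs`), then the payload. The sketch below mirrors what `examples/simple_run.rs` sends; the field meanings are inferred from `Server::handle_rpc`, so treat it as an illustration rather than a stable wire format, and note that the helper name `build_client_request` is ours, not part of the crate.

```rust
/// Build a client-request frame the way examples/simple_run.rs does.
/// Field meanings are inferred from Server::handle_rpc and may change.
fn build_client_request(client_id: u32, term: u32, payload: u32) -> Vec<u8> {
    let message_type = 6u32; // 6 => MessageType::ClientRequest in src/server.rs
    [
        client_id.to_be_bytes(),    // bytes 0..4:  sender/client id
        term.to_be_bytes(),         // bytes 4..8:  term (simple_run.rs passes 10 here)
        message_type.to_be_bytes(), // bytes 8..12: message type selector
        payload.to_be_bytes(),      // bytes 12..16: the value to replicate
    ]
    .concat()
}
```

On the leader, `handle_client_request` reads bytes 12..16 as the value and wraps it in a `LogEntry`; the `MessageType::ClientResponse` arm of `handle_rpc` treats a response payload of `1` as "Consensus reached!".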
60 | -------------------------------------------------------------------------------- /examples/simple_run.rs: -------------------------------------------------------------------------------- 1 | // Organization: SpacewalkHq 2 | // License: MIT License 3 | 4 | // make this file executable with `chmod +x examples/simple_run.rs` 5 | 6 | use raft_rs::cluster::{ClusterConfig, NodeMeta}; 7 | use raft_rs::log::get_logger; 8 | use slog::{error, info}; 9 | use std::collections::HashMap; 10 | use std::net::SocketAddr; 11 | use std::str::FromStr; 12 | use tokio::time::Duration; 13 | 14 | use raft_rs::network::{NetworkLayer, TCPManager}; 15 | use raft_rs::server::{Server, ServerConfig}; 16 | 17 | #[tokio::main] 18 | async fn main() { 19 | // Define cluster configuration 20 | let cluster_nodes = vec![1, 2, 3, 4, 5]; 21 | 22 | let peers = vec![ 23 | NodeMeta::from((1, SocketAddr::from_str("127.0.0.1:5001").unwrap())), 24 | NodeMeta::from((2, SocketAddr::from_str("127.0.0.1:5002").unwrap())), 25 | NodeMeta::from((3, SocketAddr::from_str("127.0.0.1:5003").unwrap())), 26 | NodeMeta::from((4, SocketAddr::from_str("127.0.0.1:5004").unwrap())), 27 | NodeMeta::from((5, SocketAddr::from_str("127.0.0.1:5005").unwrap())), 28 | ]; 29 | let cluster_config = ClusterConfig::new(peers.clone()); 30 | // Create server configs 31 | let configs: Vec<_> = peers 32 | .clone() 33 | .iter() 34 | .map(|n| ServerConfig { 35 | election_timeout: Duration::from_millis(1000), 36 | address: n.address, 37 | default_leader: Some(1u32), 38 | leadership_preferences: HashMap::new(), 39 | storage_location: Some("logs/".to_string()), 40 | }) 41 | .collect(); 42 | 43 | // Start servers in separate threads 44 | let mut handles = vec![]; 45 | for (i, config) in configs.into_iter().enumerate() { 46 | let id = cluster_nodes[i]; 47 | let cc = cluster_config.clone(); 48 | handles.push(tokio::spawn(async move { 49 | let mut server = Server::new(id, config, cc, None).await; 50 | server.start().await; 51 | })); 52 | } 53 | 54 | // Simulate a client request after some delay 55 | tokio::time::sleep(Duration::from_secs(20)).await; 56 | client_request(1, 42u32).await; 57 | tokio::time::sleep(Duration::from_secs(2)).await; 58 | for handle in handles { 59 | handle.await.unwrap(); 60 | } 61 | } 62 | 63 | async fn client_request(client_id: u32, data: u32) { 64 | let log = get_logger(); 65 | 66 | let server_address = SocketAddr::from_str("127.0.0.1:5001").unwrap(); // Assuming server 1 is the leader 67 | let network_manager = TCPManager::new(server_address); 68 | 69 | let request_data = vec![ 70 | client_id.to_be_bytes().to_vec(), 71 | 10u32.to_be_bytes().to_vec(), 72 | 6u32.to_be_bytes().to_vec(), 73 | data.to_be_bytes().to_vec(), 74 | ] 75 | .concat(); 76 | 77 | if let Err(e) = network_manager.send(&server_address, &request_data).await { 78 | error!(log, "Failed to send client request: {}", e); 79 | } 80 | 81 | // sleep for a while to allow the server to process the request 82 | tokio::time::sleep(Duration::from_secs(5)).await; 83 | 84 | let response = network_manager.receive().await.unwrap(); 85 | info!(log, "Received response: {:?}", response); 86 | } 87 | -------------------------------------------------------------------------------- /examples/simulate_add_node.rs: -------------------------------------------------------------------------------- 1 | // Organization: SpacewalkHq 2 | // License: MIT License 3 | 4 | use raft_rs::cluster::{ClusterConfig, NodeMeta}; 5 | use raft_rs::log::get_logger; 6 | use slog::error; 7 | use std::collections::HashMap; 8 
| use std::net::SocketAddr; 9 | use std::str::FromStr; 10 | use tokio::time::Duration; 11 | 12 | use raft_rs::network::{NetworkLayer, TCPManager}; 13 | use raft_rs::server::{Server, ServerConfig}; 14 | 15 | #[tokio::main] 16 | async fn main() { 17 | // Define cluster configuration 18 | let cluster_nodes = vec![1, 2, 3, 4, 5]; 19 | let peers = vec![ 20 | NodeMeta::from((1, SocketAddr::from_str("127.0.0.1:5001").unwrap())), 21 | NodeMeta::from((2, SocketAddr::from_str("127.0.0.1:5002").unwrap())), 22 | NodeMeta::from((3, SocketAddr::from_str("127.0.0.1:5003").unwrap())), 23 | NodeMeta::from((4, SocketAddr::from_str("127.0.0.1:5004").unwrap())), 24 | NodeMeta::from((5, SocketAddr::from_str("127.0.0.1:5005").unwrap())), 25 | ]; 26 | let cluster_config = ClusterConfig::new(peers.clone()); 27 | // Create server configs 28 | let configs: Vec<_> = peers 29 | .clone() 30 | .iter() 31 | .map(|n| ServerConfig { 32 | election_timeout: Duration::from_millis(1000), 33 | address: n.address, 34 | default_leader: Some(1 as u32), 35 | leadership_preferences: HashMap::new(), 36 | storage_location: Some("logs/".to_string()), 37 | }) 38 | .collect(); 39 | 40 | // Start servers in separate threads 41 | let mut handles = vec![]; 42 | for (i, config) in configs.into_iter().enumerate() { 43 | let id = cluster_nodes[i]; 44 | let cc = cluster_config.clone(); 45 | handles.push(tokio::spawn(async move { 46 | let mut server = Server::new(id, config, cc, None).await; 47 | server.start().await; 48 | })); 49 | } 50 | 51 | // Simulate adding a new node 52 | // The following defines the basic configuration of the new node 53 | tokio::time::sleep(Duration::from_secs(10)).await; 54 | let new_node_id = 6; 55 | let new_node_address = SocketAddr::from_str(format!("127.0.0.1:{}", 5006).as_str()).unwrap(); 56 | 57 | let new_node_conf = ServerConfig { 58 | election_timeout: Duration::from_millis(1000), 59 | address: new_node_address.clone().into(), 60 | default_leader: Some(1 as u32), 61 | leadership_preferences: HashMap::new(), 62 | storage_location: Some("logs/".to_string()), 63 | }; 64 | 65 | // Launching a new node 66 | handles.push(tokio::spawn(async move { 67 | let mut server = Server::new(new_node_id, new_node_conf, cluster_config, None).await; 68 | server.start().await; 69 | })); 70 | 71 | // Simulate sending a Raft Join request after a few seconds 72 | // Because we need to wait until the new node has started 73 | tokio::time::sleep(Duration::from_secs(3)).await; 74 | add_node_request(new_node_id, new_node_address).await; 75 | 76 | for handle in handles { 77 | handle.await.unwrap(); 78 | } 79 | } 80 | 81 | async fn add_node_request(new_node_id: u32, addr: SocketAddr) { 82 | let log = get_logger(); 83 | 84 | let network_manager = TCPManager::new(addr); 85 | 86 | let request_data = vec![ 87 | new_node_id.to_be_bytes().to_vec(), 88 | 0u32.to_be_bytes().to_vec(), 89 | 10u32.to_be_bytes().to_vec(), 90 | addr.to_string().as_bytes().to_vec(), 91 | ] 92 | .concat(); 93 | 94 | // Let's assume that 5001 is the port of the leader node. 
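    // NOTE (added; inferred from Server::handle_rpc in src/server.rs): the frame built
    // above follows the same big-endian layout as the other messages:
    // [node id | term (0 here) | message type | payload]. Message type 10 maps to
    // MessageType::JoinRequest, and the payload is the new node's socket address as
    // UTF-8 text. Treat this as an illustration, not a stable wire format.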
95 | if let Err(e) = network_manager 96 | .send( 97 | &SocketAddr::from_str("127.0.0.1:5001").unwrap(), 98 | &request_data, 99 | ) 100 | .await 101 | { 102 | error!(log, "Failed to send client request: {}", e); 103 | } 104 | } 105 | -------------------------------------------------------------------------------- /examples/simulate_node_failure.rs: -------------------------------------------------------------------------------- 1 | // Organization: SpacewalkHq 2 | // License: MIT License 3 | 4 | use raft_rs::cluster::{ClusterConfig, NodeMeta}; 5 | use raft_rs::log::get_logger; 6 | use raft_rs::server::{Server, ServerConfig}; 7 | use rand::Rng; 8 | use slog::{info, warn}; 9 | use std::collections::HashMap; 10 | use std::net::SocketAddr; 11 | use std::str::FromStr; 12 | use tokio::time::{sleep, Duration}; 13 | 14 | #[tokio::main] 15 | async fn main() { 16 | let log = get_logger(); 17 | 18 | // Define cluster configuration 19 | let cluster_nodes = vec![1, 2, 3, 4, 5]; 20 | let peers = vec![ 21 | NodeMeta::from((1, SocketAddr::from_str("127.0.0.1:5001").unwrap())), 22 | NodeMeta::from((2, SocketAddr::from_str("127.0.0.1:5002").unwrap())), 23 | NodeMeta::from((3, SocketAddr::from_str("127.0.0.1:5003").unwrap())), 24 | NodeMeta::from((4, SocketAddr::from_str("127.0.0.1:5004").unwrap())), 25 | NodeMeta::from((5, SocketAddr::from_str("127.0.0.1:5005").unwrap())), 26 | ]; 27 | let cluster_config = ClusterConfig::new(peers.clone()); 28 | 29 | // Create server configs 30 | let configs: Vec<_> = peers 31 | .clone() 32 | .iter() 33 | .map(|n| ServerConfig { 34 | election_timeout: Duration::from_millis(200), 35 | address: n.address, 36 | default_leader: Some(1), 37 | leadership_preferences: HashMap::new(), 38 | storage_location: Some("logs/".to_string()), 39 | }) 40 | .collect(); 41 | 42 | // Start servers asynchronously 43 | let mut server_handles = vec![]; 44 | for (i, config) in configs.into_iter().enumerate() { 45 | let id = cluster_nodes[i]; 46 | let cc = cluster_config.clone(); 47 | let server_handle = tokio::spawn(async move { 48 | let mut server = Server::new(id, config, cc, None).await; 49 | server.start().await; 50 | }); 51 | server_handles.push(server_handle); 52 | } 53 | 54 | // Simulate stopping and restarting servers 55 | let mut rng = rand::thread_rng(); 56 | for _ in 0..10 { 57 | let sleep_time = rng.gen_range(3..=5); 58 | sleep(Duration::from_secs(sleep_time)).await; 59 | 60 | let server_to_stop = rng.gen_range(1..=5); 61 | warn!(log, "Stopping server {}", server_to_stop); 62 | 63 | // Cancel the selected server's task 64 | server_handles[server_to_stop - 1].abort(); 65 | 66 | // Simulate Raft leader election process 67 | sleep(Duration::from_secs(3)).await; 68 | 69 | warn!(log, "Restarting server {}", server_to_stop); 70 | let config = ServerConfig { 71 | election_timeout: Duration::from_millis(200), 72 | address: SocketAddr::from_str(format!("127.0.0.1:{}", 5000 + server_to_stop).as_str()) 73 | .unwrap(), 74 | default_leader: Some(1), 75 | leadership_preferences: HashMap::new(), 76 | storage_location: Some("logs/".to_string()), 77 | }; 78 | let cc = cluster_config.clone(); 79 | let server_handle = tokio::spawn(async move { 80 | let mut server = 81 | Server::new(server_to_stop.try_into().unwrap(), config, cc, None).await; 82 | server.start().await; 83 | }); 84 | server_handles[server_to_stop - 1] = server_handle; 85 | } 86 | 87 | // Wait for all server tasks to complete (if they haven't been aborted) 88 | for handle in server_handles { 89 | let _ = handle.await; 90 | } 91 | 92 | 
info!(log, "Test completed successfully."); 93 | } 94 | -------------------------------------------------------------------------------- /examples/simulate_replica_repair.rs: -------------------------------------------------------------------------------- 1 | // Organization: SpacewalkHq 2 | // License: MIT License 3 | 4 | // We create a cluster of 5 nodes and simulate different scenarios of storage failure and recovery. 5 | 6 | use raft_rs::cluster::{ClusterConfig, NodeMeta}; 7 | use raft_rs::log::get_logger; 8 | use raft_rs::server::{Server, ServerConfig}; 9 | use rand::Rng; 10 | use slog::{info, warn}; 11 | use std::collections::HashMap; 12 | use std::fs; 13 | use std::net::SocketAddr; 14 | use std::str::FromStr; 15 | use tokio::time::{sleep, Duration}; 16 | 17 | #[tokio::main] 18 | async fn main() { 19 | let log = get_logger(); 20 | 21 | // Define cluster configuration 22 | let cluster_nodes = vec![1, 2, 3, 4, 5]; 23 | let peers = vec![ 24 | NodeMeta::from((1, SocketAddr::from_str("127.0.0.1:5001").unwrap())), 25 | NodeMeta::from((2, SocketAddr::from_str("127.0.0.1:5002").unwrap())), 26 | NodeMeta::from((3, SocketAddr::from_str("127.0.0.1:5003").unwrap())), 27 | NodeMeta::from((4, SocketAddr::from_str("127.0.0.1:5004").unwrap())), 28 | NodeMeta::from((5, SocketAddr::from_str("127.0.0.1:5005").unwrap())), 29 | ]; 30 | let cluster_config = ClusterConfig::new(peers.clone()); 31 | 32 | // Create server configs 33 | let configs: Vec<_> = peers 34 | .clone() 35 | .iter() 36 | .map(|n| ServerConfig { 37 | election_timeout: Duration::from_millis(200), 38 | address: n.address, 39 | default_leader: Some(1), 40 | leadership_preferences: HashMap::new(), 41 | storage_location: Some("logs/".to_string()), 42 | }) 43 | .collect(); 44 | 45 | // Start servers asynchronously 46 | let mut server_handles = vec![]; 47 | for (i, config) in configs.into_iter().enumerate() { 48 | let id = cluster_nodes[i]; 49 | let cc = cluster_config.clone(); 50 | let server_handle = tokio::spawn(async move { 51 | // Simulate storage corruption when starting up 52 | let storage_location = "logs/".to_string(); 53 | let corrupted = rand::thread_rng().gen_bool(0.3); // 30% chance of corruption 54 | if corrupted { 55 | fs::create_dir_all(&storage_location).unwrap(); 56 | fs::write(format!("{}server_{}.log", storage_location, id), b"").unwrap(); // Simulate corruption 57 | warn!(get_logger(), "Storage for server {} is corrupted", id); 58 | } 59 | 60 | let mut server = Server::new(id, config, cc, None).await; 61 | server.start().await; 62 | }); 63 | server_handles.push(server_handle); 64 | } 65 | info!(log, "Cluster is up and running"); 66 | 67 | // Simulate a random storage failure and recovery while servers are running 68 | let mut rng = rand::thread_rng(); 69 | for _ in 0..10 { 70 | let sleep_time = rng.gen_range(3..=5); 71 | sleep(Duration::from_secs(sleep_time)).await; 72 | 73 | let server_to_fail = rng.gen_range(1..=5); 74 | warn!( 75 | log, 76 | "Simulating storage corruption for server {}", server_to_fail 77 | ); 78 | 79 | // Simulate storage corruption on a running server 80 | let storage_path = format!("logs/"); 81 | fs::create_dir_all(&storage_path).unwrap(); 82 | fs::write( 83 | format!("{}server_{}.log", storage_path, server_to_fail), 84 | b"", 85 | ) 86 | .unwrap(); // Simulate corruption 87 | 88 | // Restart the corrupted server to simulate recovery 89 | let cc: ClusterConfig = cluster_config.clone(); 90 | let server_handle = tokio::spawn(async move { 91 | let config = ServerConfig { 92 | election_timeout: 
Duration::from_millis(200), 93 | address: SocketAddr::from_str( 94 | format!("127.0.0.1:{}", 5000 + server_to_fail).as_str(), 95 | ) 96 | .unwrap(), 97 | default_leader: Some(1), 98 | leadership_preferences: HashMap::new(), 99 | storage_location: Some(storage_path.clone()), 100 | }; 101 | let mut server = 102 | Server::new(server_to_fail.try_into().unwrap(), config, cc, None).await; 103 | server.start().await; 104 | // Handle recovery of corrupted storage 105 | info!( 106 | get_logger(), 107 | "Server {} has recovered from storage corruption", server_to_fail 108 | ); 109 | }); 110 | 111 | server_handles[server_to_fail - 1] = server_handle; 112 | } 113 | 114 | // Wait for all server tasks to complete (if they haven't been aborted) 115 | for handle in server_handles { 116 | let _ = handle.await; 117 | } 118 | 119 | info!(log, "Test completed successfully."); 120 | } 121 | -------------------------------------------------------------------------------- /src/cluster.rs: -------------------------------------------------------------------------------- 1 | use std::collections::HashMap; 2 | use std::net::SocketAddr; 3 | 4 | #[derive(Debug, Clone)] 5 | pub struct NodeMeta { 6 | pub id: u32, 7 | pub address: SocketAddr, 8 | } 9 | 10 | impl NodeMeta { 11 | fn new(id: u32, address: SocketAddr) -> NodeMeta { 12 | Self { id, address } 13 | } 14 | } 15 | 16 | impl From<(u32, SocketAddr)> for NodeMeta { 17 | fn from((id, address): (u32, SocketAddr)) -> Self { 18 | Self::new(id, address) 19 | } 20 | } 21 | 22 | #[derive(Debug, Clone)] 23 | pub struct ClusterConfig { 24 | peers: Vec<NodeMeta>, 25 | id_node_map: HashMap<u32, NodeMeta>, 26 | } 27 | 28 | impl ClusterConfig { 29 | pub fn new(peers: Vec<NodeMeta>) -> ClusterConfig { 30 | let id_node_map = peers 31 | .clone() 32 | .into_iter() 33 | .map(|x| (x.id, x)) 34 | .collect::<HashMap<u32, NodeMeta>>(); 35 | ClusterConfig { peers, id_node_map } 36 | } 37 | 38 | pub fn peers(&self) -> &[NodeMeta] { 39 | &self.peers 40 | } 41 | 42 | // Return meta of peers for a node 43 | pub fn peers_for(&self, id: u32) -> Vec<&NodeMeta> { 44 | self.peers.iter().filter(|x| x.id != id).collect::<Vec<_>>() 45 | } 46 | 47 | // Return address of peers for a node 48 | pub fn peer_address_for(&self, id: u32) -> Vec<SocketAddr> { 49 | self.peers 50 | .iter() 51 | .filter(|x| x.id != id) 52 | .map(|x| x.address) 53 | .collect::<Vec<_>>() 54 | } 55 | 56 | pub fn address(&self, id: u32) -> Option<SocketAddr> { 57 | self.id_node_map.get(&id).map(|x| x.address) 58 | } 59 | pub fn meta(&self, node_id: u32) -> Option<&NodeMeta> { 60 | self.id_node_map.get(&node_id) 61 | } 62 | 63 | pub fn contains_server(&self, node_id: u32) -> bool { 64 | self.id_node_map.contains_key(&node_id) 65 | } 66 | 67 | pub fn add_server(&mut self, n: NodeMeta) { 68 | self.peers.push(n.clone()); 69 | self.id_node_map.insert(n.id, n); 70 | } 71 | 72 | pub fn peer_count(&self, id: u32) -> usize { 73 | self.peers_for(id).len() 74 | } 75 | } 76 | -------------------------------------------------------------------------------- /src/error.rs: -------------------------------------------------------------------------------- 1 | // organization : SpacewalkHq 2 | // License : MIT License 3 | 4 | use std::net::SocketAddr; 5 | use thiserror::Error; 6 | 7 | /// wrapper around std library error 8 | pub type Result<T> = std::result::Result<T, Error>; 9 | 10 | #[derive(Error, Debug)] 11 | pub enum Error { 12 | // Storage layer specific error 13 | #[error("Storage error {0}")] 14 | Store(#[from] StorageError), 15 | // Network layer specific error 16 | #[error("Network error {0}")] 17 | Network(#[from] NetworkError), 18 | // To handle 
all std lib io error 19 | #[error("File error {0}")] 20 | Io(#[from] std::io::Error), 21 | /// Some other error occurred. 22 | #[error("unknown error {0}")] 23 | Unknown(#[from] Box<dyn std::error::Error + Sync + Send>), 24 | /// To handle all bincode error 25 | #[error("Bincode error {0}")] 26 | BincodeError(#[from] bincode::Error), 27 | } 28 | 29 | #[derive(Error, Debug)] 30 | pub enum NetworkError { 31 | #[error("Accepting incoming connection failed")] 32 | AcceptError, 33 | #[error("Connection is closed")] 34 | ConnectionClosedError, 35 | #[error("Connection to {0} failed")] 36 | ConnectError(SocketAddr), 37 | #[error("Failed binding to {0}")] 38 | BindError(SocketAddr), 39 | #[error("Broadcast failed, errmsg: {0}")] 40 | BroadcastError(String), 41 | } 42 | 43 | #[derive(Error, Debug)] 44 | pub enum StorageError { 45 | #[error("Path not found")] 46 | PathNotFound, 47 | #[error("File is empty")] 48 | EmptyFile, 49 | #[error("File is corrupted")] 50 | CorruptFile, 51 | #[error("Data integrity check failed!")] 52 | DataIntegrityError, 53 | #[error("Storing log failed")] 54 | StoreError, 55 | #[error("Log compaction failed")] 56 | CompactionError, 57 | #[error("Log retrieval failed")] 58 | RetrieveError, 59 | #[error("Reading file metadata failed")] 60 | MetaDataError, 61 | } 62 | -------------------------------------------------------------------------------- /src/lib.rs: -------------------------------------------------------------------------------- 1 | // organization : SpacewalkHq 2 | // License : MIT License 3 | 4 | pub mod cluster; 5 | pub mod error; 6 | pub mod log; 7 | pub mod network; 8 | pub mod server; 9 | pub mod state_mechine; 10 | pub mod storage; 11 | -------------------------------------------------------------------------------- /src/log.rs: -------------------------------------------------------------------------------- 1 | use chrono::prelude::*; 2 | use slog::{o, Drain, Logger}; 3 | 4 | pub fn get_logger() -> Logger { 5 | let decorator = slog_term::PlainSyncDecorator::new(std::io::stdout()); 6 | let drain = slog_term::FullFormat::new(decorator) 7 | .use_custom_timestamp(|io| write!(io, "{}", Utc::now().format("%Y-%m-%d %H:%M:%S"))) 8 | .build() 9 | .fuse(); 10 | 11 | Logger::root(drain, o!()) 12 | } 13 | 14 | #[cfg(test)] 15 | mod tests { 16 | use slog::{crit, debug, error, info, trace, warn}; 17 | 18 | use crate::log::get_logger; 19 | 20 | #[tokio::test] 21 | async fn test_slog() { 22 | let log = get_logger(); 23 | 24 | trace!(log, "trace log message"); 25 | debug!(log, "debug log message"); 26 | info!(log, "info log message"); 27 | warn!(log, "warn log message"); 28 | error!(log, "error log message"); 29 | crit!(log, "crit log message"); 30 | } 31 | } 32 | -------------------------------------------------------------------------------- /src/network.rs: -------------------------------------------------------------------------------- 1 | // organization : SpacewalkHq 2 | // License : MIT License 3 | 4 | use std::net::SocketAddr; 5 | use std::sync::Arc; 6 | 7 | use async_trait::async_trait; 8 | use futures::future::join_all; 9 | use tokio::io::{AsyncReadExt, AsyncWriteExt}; 10 | use tokio::net::{TcpListener, TcpStream}; 11 | use tokio::sync::Mutex; 12 | 13 | use crate::error::NetworkError::ConnectionClosedError; 14 | use crate::error::Result; 15 | use crate::error::{Error, NetworkError}; 16 | 17 | #[async_trait] 18 | pub trait NetworkLayer: Send + Sync { 19 | async fn send(&self, address: &SocketAddr, data: &[u8]) -> Result<()>; 20 | async fn receive(&self) -> Result<Vec<u8>>; 21 | async fn broadcast(&self, data: 
&[u8], addresses: &[SocketAddr]) -> Result<()>; 22 | async fn open(&self) -> Result<()>; 23 | async fn close(self) -> Result<()>; 24 | } 25 | 26 | #[derive(Debug, Clone)] 27 | pub struct TCPManager { 28 | address: SocketAddr, 29 | listener: Arc>>, 30 | is_open: Arc>, 31 | } 32 | 33 | impl TCPManager { 34 | pub fn new(address: SocketAddr) -> Self { 35 | TCPManager { 36 | address, 37 | listener: Arc::new(Mutex::new(None)), 38 | is_open: Arc::new(Mutex::new(false)), 39 | } 40 | } 41 | 42 | async fn async_send(data: &[u8], address: &SocketAddr) -> Result<()> { 43 | let mut stream = TcpStream::connect(address).await.map_err(Error::Io)?; 44 | stream.write_all(data).await.map_err(Error::Io)?; 45 | Ok(()) 46 | } 47 | 48 | async fn handle_receive(&self) -> Result> { 49 | let mut data = Vec::new(); 50 | let listener = self.listener.lock().await; 51 | if let Some(listener) = &*listener { 52 | let (mut stream, _) = listener.accept().await.map_err(Error::Io)?; 53 | let mut buffer = Vec::new(); 54 | let mut reader = tokio::io::BufReader::new(&mut stream); 55 | reader.read_to_end(&mut buffer).await.map_err(Error::Io)?; 56 | data = buffer; 57 | } 58 | Ok(data) 59 | } 60 | } 61 | 62 | #[async_trait] 63 | impl NetworkLayer for TCPManager { 64 | async fn send(&self, address: &SocketAddr, data: &[u8]) -> Result<()> { 65 | Self::async_send(data, address).await?; 66 | Ok(()) 67 | } 68 | 69 | async fn receive(&self) -> Result> { 70 | self.handle_receive().await 71 | } 72 | 73 | async fn broadcast(&self, data: &[u8], addresses: &[SocketAddr]) -> Result<()> { 74 | let futures = addresses 75 | .iter() 76 | .map(|address| Self::async_send(data, address)); 77 | join_all(futures) 78 | .await 79 | .into_iter() 80 | .collect::>>() 81 | // FIXME: We should let client decide what to do with the errors 82 | .map_err(|e| NetworkError::BroadcastError(e.to_string()))?; 83 | Ok(()) 84 | } 85 | 86 | async fn open(&self) -> Result<()> { 87 | let mut is_open = self.is_open.lock().await; 88 | if *is_open { 89 | return Err(Error::Unknown("Listener is already open".into())); 90 | } 91 | let addr: SocketAddr = self.address; 92 | let listener = TcpListener::bind(addr) 93 | .await 94 | .map_err(|_e| NetworkError::BindError(addr))?; 95 | *self.listener.lock().await = Some(listener); 96 | *is_open = true; 97 | Ok(()) 98 | } 99 | 100 | async fn close(self) -> Result<()> { 101 | let mut is_open = self.is_open.lock().await; 102 | if !*is_open { 103 | return Err(Error::Network(ConnectionClosedError)); 104 | } 105 | *self.listener.lock().await = None; 106 | *is_open = false; 107 | Ok(()) 108 | } 109 | } 110 | 111 | #[cfg(test)] 112 | mod tests { 113 | use std::net::SocketAddr; 114 | use tokio::task::JoinSet; 115 | 116 | use crate::network::{NetworkLayer, TCPManager}; 117 | 118 | const LOCALHOST: &str = "127.0.0.1"; 119 | 120 | fn sock_addr(host: &str, port: u32) -> SocketAddr { 121 | let addr = format!("{}:{}", host, port); 122 | addr.parse::().unwrap() 123 | } 124 | 125 | #[tokio::test] 126 | async fn test_send() { 127 | let network = TCPManager::new(sock_addr(LOCALHOST, 8082)); 128 | let data = vec![1, 2, 3]; 129 | network.open().await.unwrap(); 130 | let network_clone = network.clone(); 131 | let handler = tokio::spawn(async move { 132 | let _ = network_clone.receive().await.unwrap(); 133 | }); 134 | 135 | let send_result = network.send(&sock_addr(LOCALHOST, 8082), &data).await; 136 | assert!(send_result.is_ok()); 137 | 138 | handler.await.unwrap(); 139 | } 140 | 141 | #[tokio::test] 142 | async fn test_send_closed_connection() { 143 | let 
network = TCPManager::new(sock_addr(LOCALHOST, 8020)); 144 | let data = vec![1, 2, 3]; 145 | network.open().await.unwrap(); 146 | let network_clone = network.clone(); 147 | tokio::spawn(async move { 148 | let _ = network_clone.receive().await.unwrap(); 149 | }); 150 | 151 | let send_result = network.send(&sock_addr(LOCALHOST, 8021), &data).await; 152 | assert!(send_result.is_err()); 153 | } 154 | 155 | #[tokio::test] 156 | async fn test_receive_happy_case() { 157 | let network = TCPManager::new(sock_addr(LOCALHOST, 8030)); 158 | let data = vec![1, 2, 3]; 159 | network.open().await.unwrap(); 160 | let network_clone = network.clone(); 161 | let handler = tokio::spawn(async move { network_clone.receive().await.unwrap() }); 162 | 163 | network 164 | .send(&sock_addr(LOCALHOST, 8030), &data) 165 | .await 166 | .unwrap(); 167 | let rx_data = handler.await.unwrap(); 168 | assert_eq!(rx_data, data) 169 | } 170 | 171 | #[tokio::test] 172 | async fn test_open() { 173 | let network = TCPManager::new(sock_addr(LOCALHOST, 8040)); 174 | let status = network.open().await; 175 | assert!(status.is_ok()); 176 | assert!(*network.is_open.lock().await); 177 | } 178 | 179 | #[tokio::test] 180 | async fn test_reopen_opened_port() { 181 | let network = TCPManager::new(sock_addr(LOCALHOST, 8042)); 182 | let status = network.open().await; 183 | assert!(status.is_ok()); 184 | let another_network = network.clone(); 185 | let status = another_network.open().await; 186 | assert!(status.is_err()); 187 | } 188 | 189 | #[tokio::test] 190 | async fn test_close() { 191 | let network = TCPManager::new(sock_addr(LOCALHOST, 8046)); 192 | let _ = network.open().await; 193 | 194 | let close_status = network.close().await; 195 | assert!(close_status.is_ok()); 196 | } 197 | 198 | #[tokio::test] 199 | async fn test_broadcast_happy_case() { 200 | let data = vec![1, 2, 3, 4]; 201 | // server which is about to broadcast data 202 | let broadcasting_node = TCPManager::new(sock_addr(LOCALHOST, 8050)); 203 | broadcasting_node.open().await.unwrap(); 204 | assert!(*broadcasting_node.is_open.lock().await); 205 | 206 | // vec to keep track of all other server which should be receiving data 207 | let mut receivers = vec![]; 208 | // vec to keep track of the address of servers 209 | let mut receiver_addresses = vec![]; 210 | 211 | for p in 8051..8060 { 212 | // create receiver server 213 | let rx = TCPManager::new(sock_addr(LOCALHOST, p)); 214 | receiver_addresses.push( 215 | format!("{}:{}", LOCALHOST, p) 216 | .parse::() 217 | .unwrap(), 218 | ); 219 | 220 | rx.open().await.unwrap(); 221 | assert!(*rx.is_open.lock().await); 222 | receivers.push(rx) 223 | } 224 | 225 | let mut s = JoinSet::new(); 226 | for rx in receivers { 227 | s.spawn(async move { 228 | let rx_data = rx.receive().await; 229 | assert!(rx_data.is_ok()); 230 | // return the received data 231 | rx_data.unwrap() 232 | }); 233 | } 234 | 235 | // broadcast the message 236 | let broadcast_result = broadcasting_node 237 | .broadcast(&data, &receiver_addresses) 238 | .await; 239 | assert!(broadcast_result.is_ok()); 240 | 241 | // assert the data received on servers 242 | while let Some(res) = s.join_next().await { 243 | let rx_data = res.unwrap(); 244 | assert_eq!(data, rx_data) 245 | } 246 | } 247 | 248 | #[tokio::test] 249 | async fn test_broadcast_some_nodes_down() { 250 | let data = vec![1, 2, 3, 4]; 251 | // server which is about to broadcast data 252 | let broadcasting_node = TCPManager::new(sock_addr(LOCALHOST, 8061)); 253 | broadcasting_node.open().await.unwrap(); 254 | 
assert!(*broadcasting_node.is_open.lock().await); 255 | 256 | // vec to keep track of all servers which should be receiving data 257 | let mut receivers = vec![]; 258 | // vec to keep track of the address 259 | let mut receiver_addresses = vec![]; 260 | for p in 8062..8070 { 261 | // Create a receiver node 262 | let rx = TCPManager::new(sock_addr(LOCALHOST, p)); 263 | receiver_addresses.push( 264 | format!("{}:{}", LOCALHOST, p) 265 | .parse::() 266 | .unwrap(), 267 | ); 268 | // open connection for half server 269 | // mocking rest half to be down 270 | if p & 1 == 1 { 271 | rx.open().await.unwrap(); 272 | assert!(*rx.is_open.lock().await); 273 | } 274 | receivers.push(rx) 275 | } 276 | 277 | // broadcast the data 278 | let broadcast_result = broadcasting_node 279 | .broadcast(&data, &receiver_addresses) 280 | .await; 281 | assert!(broadcast_result.is_err()); 282 | } 283 | } 284 | -------------------------------------------------------------------------------- /src/server.rs: -------------------------------------------------------------------------------- 1 | // organization : SpacewalkHq 2 | // License : MIT License 3 | 4 | use crate::cluster::{ClusterConfig, NodeMeta}; 5 | use crate::error::Error; 6 | use crate::log::get_logger; 7 | use crate::network::{NetworkLayer, TCPManager}; 8 | use crate::state_mechine::{self, StateMachine}; 9 | use crate::storage::{LocalStorage, Storage, CHECKSUM_LEN}; 10 | use serde::{Deserialize, Serialize}; 11 | use slog::{error, info, o}; 12 | use std::collections::HashMap; 13 | use std::io::Cursor; 14 | use std::net::SocketAddr; 15 | use std::path::PathBuf; 16 | use std::sync::Arc; 17 | use std::time::{Duration, Instant}; 18 | use tokio::io::AsyncReadExt; 19 | use tokio::sync::Mutex; 20 | use tokio::time::sleep; 21 | 22 | #[derive(Debug, Clone, PartialEq)] 23 | enum RaftState { 24 | Follower, 25 | Candidate, 26 | Leader, 27 | } 28 | 29 | #[derive(Debug, Clone)] 30 | enum MessageType { 31 | RequestVote, 32 | RequestVoteResponse, 33 | AppendEntries, 34 | AppendEntriesResponse, 35 | Heartbeat, 36 | HeartbeatResponse, 37 | ClientRequest, 38 | ClientResponse, 39 | RepairRequest, 40 | RepairResponse, 41 | // dynamic membership changes 42 | JoinRequest, 43 | JoinResponse, 44 | 45 | BatchAppendEntries, 46 | BatchAppendEntriesResponse, 47 | } 48 | 49 | #[derive(Debug)] 50 | /// Represents the state of a Raft server, which includes information 51 | /// about the current term, election state, and log entries. 52 | struct ServerState { 53 | /// The current term number, which increases monotonically. 54 | /// It is used to identify the latest term known to this server. 55 | current_term: u32, 56 | 57 | /// The current state of the server in the Raft protocol 58 | /// (e.g., Leader, Follower, or Candidate). 59 | state: RaftState, 60 | 61 | /// The candidate ID that this server voted for in the current term. 62 | /// It is `None` if the server hasn't voted for anyone in this term. 63 | voted_for: Option, 64 | 65 | /// A deque of log entries that are replicated to the Raft cluster. 66 | // log: VecDeque, 67 | state_machine: Arc>>, 68 | 69 | /// The index of the highest log entry known to be committed. 70 | /// This indicates the index up to which the state machine is consistent. 71 | commit_index: u32, 72 | 73 | /// The index of the previous log entry used for consistency checks. 74 | /// Typically used during the append entries process. 75 | previous_log_index: u32, 76 | 77 | /// For each follower, the next log entry to send to that follower. 
78 | /// This is used by the leader to keep track of what entries have been 79 | /// sent to each follower. 80 | next_index: Vec, 81 | 82 | /// For each follower, the highest log entry index that is known to 83 | /// be replicated on that follower. 84 | match_index: Vec, 85 | 86 | /// The election timeout duration. If this time passes without receiving 87 | /// a valid heartbeat or a vote request, the server will trigger an election. 88 | election_timeout: Duration, 89 | 90 | /// The time when the last heartbeat from the current leader was received. 91 | /// Used by followers to detect if the leader has failed. 92 | last_heartbeat: Instant, 93 | 94 | /// A map of received votes in the current election term. The key is the 95 | /// peer's ID and the value is a boolean indicating whether the vote was 96 | /// granted. 97 | votes_received: HashMap, 98 | } 99 | 100 | #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] 101 | pub enum LogCommand { 102 | Noop, 103 | Set, 104 | Delete, 105 | } 106 | 107 | #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] 108 | pub struct LogEntry { 109 | pub leader_id: u32, 110 | pub server_id: u32, 111 | pub term: u32, 112 | pub command: LogCommand, 113 | pub data: u32, 114 | } 115 | 116 | #[derive(Debug)] 117 | pub struct ServerConfig { 118 | pub election_timeout: Duration, 119 | pub address: SocketAddr, 120 | // Include default leader and leadership preferences 121 | pub default_leader: Option, 122 | pub leadership_preferences: HashMap, 123 | pub storage_location: Option, 124 | } 125 | 126 | pub struct Server { 127 | pub id: u32, 128 | state: ServerState, 129 | config: ServerConfig, 130 | network_manager: TCPManager, 131 | cluster_config: ClusterConfig, 132 | // Add write buffer and debounce timer 133 | write_buffer: Vec, 134 | debounce_timer: Instant, 135 | storage: LocalStorage, 136 | log: slog::Logger, 137 | } 138 | 139 | impl Server { 140 | pub async fn new( 141 | id: u32, 142 | config: ServerConfig, 143 | cluster_config: ClusterConfig, 144 | state_machine: Option>, 145 | ) -> Server { 146 | let log = get_logger(); 147 | let log = log.new( 148 | o!("ip" => config.address.ip().to_string(), "port" => config.address.port(), "id" => id), 149 | ); 150 | 151 | // if storage location is provided, use it else set empty string to use default location 152 | let storage_location = match config.storage_location.clone() { 153 | Some(location) => location + &format!("server_{}.log", id), 154 | None => format!("server_{}.log", id), 155 | }; 156 | let storage = LocalStorage::new(storage_location.clone()).await; 157 | let parent_path = PathBuf::from(storage_location) 158 | .parent() // This returns Option<&Path> 159 | .map(|p| p.to_path_buf()) // Convert &Path to PathBuf 160 | .unwrap_or_else(|| PathBuf::from("logs")); // Provide default path 161 | 162 | // Use the provided state_machine or default to FileStateMachine if none is provided 163 | let state_machine = state_machine.unwrap_or_else(|| { 164 | // Default FileStateMachine initialization 165 | let snapshot_path = parent_path.join(format!("server_{}_snapshot.log", id)); 166 | 167 | Box::new(state_mechine::FileStateMachine::new( 168 | &snapshot_path, 169 | Duration::from_secs(60 * 60), 170 | )) 171 | }); 172 | 173 | let state_machine = Arc::new(Mutex::new(state_machine)); 174 | 175 | let peer_count = cluster_config.peer_count(id); 176 | let state = ServerState { 177 | current_term: 0, 178 | state: RaftState::Follower, 179 | voted_for: None, 180 | state_machine, 181 | commit_index: 0, 182 | 
previous_log_index: 0, 183 | next_index: vec![0; peer_count], 184 | match_index: vec![0; peer_count], 185 | election_timeout: config.election_timeout + Duration::from_millis(20 * id as u64), 186 | last_heartbeat: Instant::now(), 187 | votes_received: HashMap::new(), 188 | }; 189 | let network_manager = TCPManager::new(config.address); 190 | 191 | Server { 192 | id, 193 | state, 194 | config, 195 | network_manager, 196 | cluster_config, 197 | write_buffer: Vec::new(), 198 | debounce_timer: Instant::now(), 199 | storage, 200 | log, 201 | } 202 | } 203 | 204 | pub async fn start(&mut self) { 205 | if let Err(e) = self.network_manager.open().await { 206 | error!(self.log, "Failed to open network manager: {}", e); 207 | return; 208 | } 209 | 210 | // there should be at-least 3 peers to form a quorum 211 | if self.peers().len() < 2 { 212 | error!(self.log, "At least 3 peers are required to form a quorum"); 213 | return; 214 | } 215 | 216 | // if the storage path is not exist, create it 217 | if let Err(e) = self.storage.check_storage().await { 218 | error!(self.log, "Failed to check storage: {}", e); 219 | return; 220 | } 221 | 222 | loop { 223 | match self.state.state { 224 | RaftState::Follower => self.follower().await, 225 | RaftState::Candidate => self.candidate().await, 226 | RaftState::Leader => self.leader().await, 227 | } 228 | } 229 | } 230 | 231 | pub fn is_leader(&self) -> bool { 232 | self.state.state == RaftState::Leader 233 | } 234 | 235 | async fn follower(&mut self) { 236 | if self.state.state != RaftState::Follower { 237 | return; 238 | } 239 | 240 | let log_byte = match self.storage.retrieve().await { 241 | Ok(data) => data, 242 | Err(e) => { 243 | error!(self.log, "Follower failed to read from storage: {}", e); 244 | return; 245 | } 246 | }; 247 | 248 | let log_entry_size = std::mem::size_of::(); 249 | 250 | // Data integrity check failed 251 | // try repair the log from other peers 252 | if log_byte.len() % (log_entry_size + CHECKSUM_LEN) != 0 { 253 | error!(self.log, "Data integrity check failed"); 254 | 255 | // step1 delete the log file 256 | if let Err(e) = self.storage.delete().await { 257 | error!(self.log, "Failed to delete log file: {}", e); 258 | } 259 | 260 | // step2 get the log from other peers 261 | // ping all the peers to get the log 262 | let addresses: Vec = self.peers_address(); 263 | let data = [ 264 | self.id.to_be_bytes(), 265 | 0u32.to_be_bytes(), 266 | 2u32.to_be_bytes(), 267 | ] 268 | .concat(); 269 | if let Err(e) = self.network_manager.broadcast(&data, &addresses).await { 270 | error!(self.log, "Follower failed to broadcast message: {}", e) 271 | } 272 | return; 273 | } 274 | 275 | let state_machine = Arc::clone(&self.state.state_machine); 276 | // Attempting to recover LogEntry from a disk file 277 | let mut cursor = Cursor::new(&log_byte); 278 | loop { 279 | let mut bytes_data = vec![0u8; log_entry_size + CHECKSUM_LEN]; 280 | if cursor.read_exact(&mut bytes_data).await.is_err() { 281 | break; 282 | } 283 | bytes_data = bytes_data[0..log_entry_size].to_vec(); 284 | 285 | let res = self.deserialize_log_entries(&bytes_data); 286 | let Ok(log_entry) = res else { 287 | error!( 288 | self.log, 289 | "Failed to deserialize log entry: {}", 290 | res.err().unwrap() 291 | ); 292 | return; 293 | }; 294 | 295 | if log_entry.term > self.state.current_term { 296 | self.state.current_term = log_entry.term; 297 | } 298 | 299 | state_machine 300 | .lock() 301 | .await 302 | .apply_log_entry( 303 | self.state.current_term, 304 | self.state.commit_index, 305 | 
log_entry.clone(), 306 | ) 307 | .await; 308 | 309 | // After restoring the LogEntry, the node's state information should be updated 310 | self.state.current_term = log_entry.term; 311 | } 312 | 313 | info!( 314 | self.log, 315 | "Log after reading from disk: {:?}, current term: {}", 316 | state_machine.lock().await.get_log_entry().await, 317 | self.state.current_term 318 | ); 319 | 320 | self.state.match_index = vec![0; self.peer_count() + 1]; 321 | self.state.next_index = vec![0; self.peer_count() + 1]; 322 | 323 | info!(self.log, "Server {} is a follower", self.id); 324 | // default leader 325 | if self.state.current_term == 0 { 326 | self.state.current_term += 1; 327 | if let Some(leader_id) = self.config.default_leader { 328 | if self.id == leader_id { 329 | self.state.state = RaftState::Leader; 330 | return; 331 | } 332 | } 333 | } 334 | 335 | let state_machine = Arc::clone(&self.state.state_machine); 336 | loop { 337 | if state_machine.lock().await.need_create_snapshot().await { 338 | let state_machine_clone = Arc::clone(&self.state.state_machine); 339 | let log_clone = self.log.clone(); 340 | let node_id_clone = self.id; 341 | tokio::spawn(async move { 342 | let mut state_machine_lock = state_machine_clone.lock().await; 343 | if let Err(e) = state_machine_lock.create_snapshot().await { 344 | error!( 345 | log_clone, 346 | "Node: {}, failed to create snapshot: {:?}", node_id_clone, e 347 | ); 348 | } else { 349 | info!( 350 | log_clone, 351 | "Node: {}, snapshot created successfully.", node_id_clone 352 | ); 353 | } 354 | }); 355 | } 356 | 357 | let timeout_duration = self.state.election_timeout; 358 | 359 | let timeout_future = async { 360 | sleep(timeout_duration).await; 361 | }; 362 | 363 | let rpc_future = self.receive_rpc(); 364 | 365 | tokio::select! { 366 | _ = timeout_future => { 367 | self.state.state = RaftState::Candidate; 368 | self.state.last_heartbeat = Instant::now(); 369 | break 370 | } 371 | _ = rpc_future => { 372 | } 373 | } 374 | } 375 | } 376 | 377 | async fn candidate(&mut self) { 378 | if self.state.state != RaftState::Candidate { 379 | return; 380 | } 381 | info!(self.log, "Server {} is a candidate", self.id); 382 | self.state.last_heartbeat = Instant::now(); // reset election timeout 383 | 384 | self.state.current_term += 1; 385 | 386 | // Vote for self 387 | self.state.voted_for = Some(self.id); 388 | self.state.votes_received.insert(self.id, true); 389 | let data = self.prepare_request_vote(self.id, self.state.current_term); 390 | let addresses: Vec = self.peers_address(); 391 | info!( 392 | self.log, 393 | "Starting election, id: {}, term: {}", self.id, self.state.current_term 394 | ); 395 | let _ = self.network_manager.broadcast(&data, &addresses).await; 396 | 397 | loop { 398 | let timeout_duration = self.state.election_timeout; 399 | 400 | let timeout_future = async { 401 | sleep(timeout_duration).await; 402 | }; 403 | 404 | let rpc_future = self.receive_rpc(); 405 | tokio::select! 
{ 406 | _ = timeout_future => { 407 | if Instant::now().duration_since(self.state.last_heartbeat) >= timeout_duration { 408 | info!(self.log, "Election timeout"); 409 | self.state.state = RaftState::Follower; 410 | self.state.votes_received.clear(); 411 | break; 412 | } 413 | } 414 | _ = rpc_future => { 415 | if self.is_quorum(self.state.votes_received.len() as u32) { 416 | info!(self.log, "Quorum reached"); 417 | info!(self.log, "I am the leader {}", self.id); 418 | self.state.state = RaftState::Leader; 419 | break; 420 | } 421 | } 422 | } 423 | } 424 | 425 | if self.state.state == RaftState::Leader { 426 | self.state.current_term += 1; 427 | } else { 428 | self.state.state = RaftState::Follower; 429 | self.state.votes_received.clear(); 430 | } 431 | } 432 | 433 | async fn leader(&mut self) { 434 | if self.state.state != RaftState::Leader { 435 | return; 436 | } 437 | info!( 438 | self.log, 439 | "Server {} is the leader, term: {}", self.id, self.state.current_term 440 | ); 441 | 442 | let mut heartbeat_interval = tokio::time::interval(Duration::from_millis(300)); 443 | 444 | let state_machine = Arc::clone(&self.state.state_machine); 445 | loop { 446 | if state_machine.lock().await.need_create_snapshot().await { 447 | let state_machine_clone = Arc::clone(&self.state.state_machine); 448 | let log_clone = self.log.clone(); 449 | let node_id_clone = self.id; 450 | tokio::spawn(async move { 451 | let mut state_machine_lock = state_machine_clone.lock().await; 452 | if let Err(e) = state_machine_lock.create_snapshot().await { 453 | error!( 454 | log_clone, 455 | "Node: {}, failed to create snapshot: {:?}", node_id_clone, e 456 | ); 457 | } else { 458 | info!( 459 | log_clone, 460 | "Node: {}, snapshot created successfully.", node_id_clone 461 | ); 462 | } 463 | }); 464 | } 465 | 466 | let rpc_future = self.receive_rpc(); 467 | tokio::select! 
{ 468 | _ = heartbeat_interval.tick() => { 469 | if self.state.state != RaftState::Leader { 470 | break; 471 | } 472 | 473 | let now = Instant::now(); 474 | self.state.last_heartbeat = now; 475 | 476 | let heartbeat_data = self.prepare_heartbeat(); 477 | let addresses: Vec = self.peers_address(); 478 | 479 | if let Err(e) = self.network_manager.broadcast(&heartbeat_data, &addresses).await { 480 | error!(self.log, "Failed to send heartbeats: {}", e); 481 | } 482 | }, 483 | _ = rpc_future => { 484 | if self.state.state != RaftState::Leader { 485 | break; 486 | } 487 | // TODO: Write coalescing with debouncing 488 | // Move this to a separate thread to avoid blocking the main loop 489 | if !self.write_buffer.is_empty() { 490 | let append_batch = self.prepare_append_batch(self.id, self.state.current_term, self.state.previous_log_index, self.state.commit_index, self.write_buffer.clone()); 491 | 492 | for entry in self.write_buffer.clone() { 493 | match bincode::serialize(&entry) { 494 | Ok(data) => self.persist_to_disk(self.id, &data).await, 495 | Err(e) => error!(self.log, "Failed to serialize entry: {}", e), 496 | } 497 | } 498 | 499 | let addresses: Vec = self.peers_address(); 500 | if let Err(e) = self.network_manager.broadcast(&append_batch, &addresses).await { 501 | error!(self.log, "Failed to send append batch: {}", e); 502 | } 503 | 504 | self.write_buffer.clear(); 505 | self.debounce_timer = Instant::now(); 506 | } 507 | }, 508 | } 509 | } 510 | } 511 | 512 | async fn receive_rpc(&mut self) { 513 | match self.network_manager.receive().await { 514 | Ok(data) => self.handle_rpc(data).await, 515 | Err(e) => error!(self.log, "Failed to receive rpc: {}", e), 516 | }; 517 | } 518 | 519 | fn prepare_append_batch( 520 | &self, 521 | id: u32, 522 | term: u32, 523 | prev_log_index: u32, 524 | commit_index: u32, 525 | write_buffer: Vec, 526 | ) -> Vec { 527 | let mut data = [ 528 | id.to_be_bytes(), 529 | term.to_be_bytes(), 530 | 2u32.to_be_bytes(), 531 | prev_log_index.to_be_bytes(), 532 | commit_index.to_be_bytes(), 533 | ] 534 | .concat(); 535 | for entry in write_buffer { 536 | let entry_data = [entry.term.to_be_bytes(), entry.data.to_be_bytes()].concat(); 537 | data.extend_from_slice(&entry_data); 538 | } 539 | data 540 | } 541 | 542 | fn prepare_request_vote(&self, id: u32, term: u32) -> Vec { 543 | [id.to_be_bytes(), term.to_be_bytes(), 0u32.to_be_bytes()].concat() 544 | } 545 | 546 | fn prepare_heartbeat(&self) -> Vec { 547 | [ 548 | self.id.to_be_bytes(), 549 | self.state.current_term.to_be_bytes(), 550 | 4u32.to_be_bytes(), 551 | ] 552 | .concat() 553 | } 554 | 555 | async fn handle_rpc(&mut self, data: Vec) { 556 | let message_type: u32 = u32::from_be_bytes(data[8..12].try_into().unwrap()); 557 | 558 | let message_type = match message_type { 559 | 0 => MessageType::RequestVote, 560 | 1 => MessageType::RequestVoteResponse, 561 | 2 => MessageType::AppendEntries, 562 | 3 => MessageType::AppendEntriesResponse, 563 | 4 => MessageType::Heartbeat, 564 | 5 => MessageType::HeartbeatResponse, 565 | 6 => MessageType::ClientRequest, 566 | 7 => MessageType::ClientResponse, 567 | 8 => MessageType::RepairRequest, 568 | 9 => MessageType::RepairResponse, 569 | 10 => MessageType::JoinRequest, 570 | 11 => MessageType::JoinResponse, 571 | 12 => MessageType::BatchAppendEntries, 572 | 13 => MessageType::BatchAppendEntriesResponse, 573 | _ => return, 574 | }; 575 | 576 | match message_type { 577 | MessageType::RequestVote => { 578 | self.handle_request_vote(&data).await; 579 | } 580 | 
MessageType::RequestVoteResponse => { 581 | self.handle_request_vote_response(&data).await; 582 | } 583 | MessageType::AppendEntries => { 584 | self.handle_append_entries(data).await; 585 | } 586 | MessageType::AppendEntriesResponse => { 587 | self.handle_append_entries_response(&data).await; 588 | } 589 | MessageType::Heartbeat => { 590 | self.handle_heartbeat(&data).await; 591 | } 592 | MessageType::HeartbeatResponse => { 593 | self.handle_heartbeat_response().await; 594 | } 595 | MessageType::ClientRequest => { 596 | self.handle_client_request(data).await; 597 | } 598 | MessageType::ClientResponse => { 599 | // TODO: get implementation from user based on the application 600 | info!(self.log, "Received client response: {:?}", data); 601 | let data = u32::from_be_bytes(data[12..16].try_into().unwrap()); 602 | if data == 1 { 603 | info!(self.log, "Consensus reached!"); 604 | } else { 605 | info!(self.log, "Consensus not reached!"); 606 | } 607 | } 608 | MessageType::RepairRequest => { 609 | self.handle_repair_request(&data).await; 610 | } 611 | MessageType::RepairResponse => { 612 | self.handle_repair_response(&data).await; 613 | } 614 | MessageType::JoinRequest => { 615 | info!( 616 | self.log, 617 | "Received join request: {:?}", 618 | String::from_utf8_lossy(&data) 619 | ); 620 | self.handle_join_request(&data).await; 621 | } 622 | MessageType::JoinResponse => { 623 | self.handle_join_response(&data).await; 624 | } 625 | MessageType::BatchAppendEntries => { 626 | self.handle_batch_append_entries(&data).await; 627 | } 628 | MessageType::BatchAppendEntriesResponse => { 629 | self.handle_batch_append_entries_response(&data).await; 630 | } 631 | } 632 | } 633 | 634 | async fn handle_client_request(&mut self, data: Vec) { 635 | if self.state.state != RaftState::Leader { 636 | return; 637 | } 638 | 639 | self.state.previous_log_index += 1; 640 | self.state.commit_index += 1; 641 | self.state.current_term += 1; 642 | 643 | let command = LogCommand::Set; 644 | let data = u32::from_be_bytes(data[12..16].try_into().unwrap()); 645 | let entry = LogEntry { 646 | leader_id: self.id, 647 | server_id: self.id, 648 | term: self.state.current_term, 649 | command, 650 | data, 651 | }; 652 | info!(self.log, "Received client request: {:?}", entry); 653 | self.write_buffer.push(entry.clone()); 654 | 655 | let state_machine = Arc::clone(&self.state.state_machine); 656 | state_machine 657 | .lock() 658 | .await 659 | .apply_log_entry(self.state.current_term, self.state.commit_index, entry) 660 | .await; 661 | } 662 | 663 | async fn handle_request_vote(&mut self, data: &[u8]) { 664 | // Only Follower can vote, because Candidate voted for itself 665 | let candidate_id = u32::from_be_bytes(data[0..4].try_into().unwrap()); 666 | let candidate_term = u32::from_be_bytes(data[4..8].try_into().unwrap()); 667 | 668 | if self.state.state != RaftState::Follower { 669 | return; 670 | } 671 | 672 | if candidate_term < self.state.current_term { 673 | return; 674 | } 675 | 676 | self.state.voted_for = Some(candidate_id); 677 | self.state.current_term = candidate_term; 678 | 679 | // get candidate address from config 680 | let candidate_address = self.cluster_config.address(candidate_id); 681 | if candidate_address.is_none() { 682 | // no dynamic membership changes 683 | info!(self.log, "Candidate address not found"); 684 | return; 685 | } 686 | 687 | let data = [ 688 | self.id.to_be_bytes(), 689 | self.state.current_term.to_be_bytes(), 690 | 1u32.to_be_bytes(), 691 | 1u32.to_be_bytes(), 692 | ] 693 | .concat(); 694 | 695 | 
let voter_response = self 696 | .network_manager 697 | .send(&candidate_address.unwrap(), &data) 698 | .await; 699 | if let Err(e) = voter_response { 700 | error!(self.log, "Failed to send vote response: {}", e); 701 | } 702 | } 703 | 704 | async fn handle_request_vote_response(&mut self, data: &[u8]) { 705 | if self.state.state != RaftState::Candidate { 706 | return; 707 | } 708 | 709 | let voter_id = u32::from_be_bytes(data[0..4].try_into().unwrap()); 710 | let term = u32::from_be_bytes(data[4..8].try_into().unwrap()); 711 | let vote_granted = u32::from_be_bytes(data[8..12].try_into().unwrap()) == 1; 712 | 713 | // if follower and your term and candidate term are same, and your id is less than candidate id, vote for candidate 714 | // leader preference 715 | if term >= self.state.current_term 716 | && self.id > voter_id 717 | && self.state.state == RaftState::Candidate 718 | { 719 | self.state.state = RaftState::Follower; 720 | } 721 | 722 | self.state.votes_received.insert(voter_id, vote_granted); 723 | info!(self.log, "Votes received: {:?}", self.state.votes_received); 724 | } 725 | 726 | async fn handle_append_entries(&mut self, data: Vec) { 727 | if self.state.state != RaftState::Follower { 728 | return; 729 | } 730 | 731 | self.state.last_heartbeat = Instant::now(); 732 | 733 | let id = u32::from_be_bytes(data[0..4].try_into().unwrap()); 734 | let leader_term = u32::from_be_bytes(data[4..8].try_into().unwrap()); 735 | let message_type = u32::from_be_bytes(data[8..12].try_into().unwrap()); 736 | let prev_log_index = u32::from_be_bytes(data[12..16].try_into().unwrap()); 737 | let commit_index = u32::from_be_bytes(data[16..20].try_into().unwrap()); 738 | info!( 739 | self.log, 740 | "Node {} received append entries request from Node {}, \ 741 | (term: self={}, receive={}), \ 742 | (prev_log_index: self={}, receive={}), \ 743 | (commit_index: self={}, receive={})", 744 | self.id, 745 | id, 746 | self.state.current_term, 747 | leader_term, 748 | self.state.previous_log_index, 749 | prev_log_index, 750 | self.state.commit_index, 751 | commit_index 752 | ); 753 | 754 | if leader_term < self.state.current_term { 755 | return; 756 | } 757 | 758 | if message_type != 2 { 759 | return; 760 | } 761 | 762 | if prev_log_index > self.state.previous_log_index { 763 | self.state.previous_log_index = prev_log_index; 764 | } else { 765 | return; 766 | } 767 | 768 | if commit_index > self.state.commit_index { 769 | self.state.commit_index = commit_index; 770 | } else { 771 | return; 772 | } 773 | 774 | let log_entry: LogEntry = LogEntry { 775 | leader_id: id, 776 | server_id: self.id, 777 | term: leader_term, 778 | command: LogCommand::Set, 779 | data: u32::from_be_bytes(data[24..28].try_into().unwrap()), 780 | }; 781 | 782 | // serialize log entry and append to log 783 | let data = match bincode::serialize(&log_entry) { 784 | Ok(data) => data, 785 | Err(e) => { 786 | error!(self.log, "Failed to serialize the log: {}", e); 787 | return; 788 | } 789 | }; 790 | 791 | let _ = self.persist_to_disk(id, &data).await; 792 | 793 | self.state.current_term += 1; // increment term on successful append for follower 794 | 795 | let response = [ 796 | self.id.to_be_bytes(), 797 | self.state.current_term.to_be_bytes(), 798 | 3u32.to_be_bytes(), 799 | 1u32.to_be_bytes(), 800 | ] 801 | .concat(); 802 | 803 | let leader_address = self.cluster_config.address(id); 804 | if leader_address.is_none() { 805 | // no dynamic membership changes 806 | info!(self.log, "Leader address not found"); 807 | return; 808 | } 809 | info!( 
810 | self.log, 811 | "Sending append entries response to leader: {}", id 812 | ); 813 | if let Err(e) = self 814 | .network_manager 815 | .send(&leader_address.unwrap(), &response) 816 | .await 817 | { 818 | info!(self.log, "Failed to send append entries response: {}", e); 819 | } 820 | } 821 | 822 | async fn handle_append_entries_response(&mut self, data: &[u8]) { 823 | if self.state.state != RaftState::Leader { 824 | return; 825 | } 826 | 827 | let sender_id = u32::from_be_bytes(data[0..4].try_into().unwrap()); 828 | let term = u32::from_be_bytes(data[4..8].try_into().unwrap()); 829 | let success = u32::from_be_bytes(data[12..16].try_into().unwrap()) == 1; 830 | 831 | info!( 832 | self.log, 833 | "Append entries response from peer: {} with term: {} and success: {}", 834 | sender_id, 835 | term, 836 | success 837 | ); 838 | 839 | if term > self.state.current_term { 840 | return; 841 | } 842 | 843 | if success { 844 | // check if you got a quorum 845 | let last_log_index = self.state.previous_log_index; 846 | self.state.match_index[sender_id as usize - 1] = last_log_index; 847 | self.state.next_index[sender_id as usize - 1] = last_log_index + 1; 848 | 849 | let mut match_indices = self.state.match_index.clone(); 850 | match_indices.sort(); 851 | let quorum_index = match_indices[self.peer_count() / 2]; 852 | 853 | info!( 854 | self.log, 855 | "Append entry response received from node {}: (match_index = {}, next_index = {}), current quorum_index: {}", 856 | sender_id, 857 | self.state.match_index[sender_id as usize - 1], 858 | self.state.next_index[sender_id as usize - 1], 859 | quorum_index 860 | ); 861 | 862 | if quorum_index >= self.state.commit_index { 863 | self.state.commit_index = quorum_index; 864 | // return client response 865 | let response_data = [ 866 | self.id.to_be_bytes(), 867 | self.state.current_term.to_be_bytes(), 868 | 7u32.to_be_bytes(), 869 | 1u32.to_be_bytes(), 870 | ] 871 | .concat(); 872 | if let Err(e) = self 873 | .network_manager 874 | .send(&self.config.address, &response_data) 875 | .await 876 | { 877 | error!(self.log, "Failed to send client response: {}", e); 878 | } 879 | info!( 880 | self.log, 881 | "Quorum decision reached to commit index: {}", self.state.commit_index 882 | ); 883 | } 884 | } else { 885 | self.state.next_index[sender_id as usize - 1] -= 1; 886 | } 887 | } 888 | 889 | async fn handle_heartbeat(&mut self, data: &[u8]) { 890 | if self.state.state != RaftState::Follower || self.state.state != RaftState::Candidate { 891 | return; 892 | } 893 | let term = u32::from_be_bytes(data[4..8].try_into().unwrap()); 894 | if term < self.state.current_term { 895 | return; 896 | } 897 | 898 | // if a leader gets a heartbeat from a leader with a higher term, it should step down 899 | if term > self.state.current_term { 900 | self.state.state = RaftState::Follower; 901 | } 902 | 903 | // if a leader gets a heartbeat from a leader same term, it should step down if it has a higher id 904 | if term == self.state.current_term { 905 | let leader_id = u32::from_be_bytes(data[0..4].try_into().unwrap()); 906 | if self.config.default_leader.is_none() { 907 | if self.id < leader_id { 908 | self.state.state = RaftState::Follower; 909 | self.state.current_term = term; 910 | } 911 | } else if self.config.default_leader.is_some() 912 | && self.id != self.config.default_leader.unwrap() 913 | { 914 | self.state.state = RaftState::Follower; 915 | self.state.current_term = term; 916 | } else { 917 | self.state.state = RaftState::Leader; 918 | } 919 | } 920 | 921 | 
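handle_append_entries_response above advances the leader's commit index with a median-of-sorted-match-indices rule. The following is a self-contained sketch of that rule, assuming one replicated index per peer in `match_index`; the function name and test values are illustrative.

```rust
// Standalone sketch of the quorum rule in handle_append_entries_response:
// sort the peers' match indices and take the element at peer_count / 2, so an
// index counts as committed once at least half of the peers have replicated it
// (a cluster majority once the leader itself is counted).
fn quorum_commit_index(match_index: &[u32], current_commit: u32) -> u32 {
    let mut sorted = match_index.to_vec();
    sorted.sort_unstable();
    let quorum_index = sorted[match_index.len() / 2];
    // Never move the commit index backwards.
    quorum_index.max(current_commit)
}

fn main() {
    // Peers have replicated up to indices 5, 7 and 2; the leader can commit 5.
    assert_eq!(quorum_commit_index(&[5, 7, 2], 3), 5);
    // A lagging quorum never pulls the commit index back.
    assert_eq!(quorum_commit_index(&[1, 1, 9], 4), 4);
}
```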
self.state.last_heartbeat = Instant::now(); 922 | } 923 | 924 | async fn handle_heartbeat_response(&mut self) { 925 | // Noop 926 | } 927 | 928 | async fn handle_repair_request(&mut self, data: &[u8]) { 929 | if self.state.state != RaftState::Follower || self.state.state != RaftState::Leader { 930 | return; 931 | } 932 | 933 | let peer_id = u32::from_be_bytes(data[0..4].try_into().unwrap()); 934 | 935 | let log_byte = match self.storage.retrieve().await { 936 | Ok(data) => data, 937 | Err(e) => { 938 | error!(self.log, "Failed to read log from storage: {}", e); 939 | return; 940 | } 941 | }; 942 | let log_entry_size = std::mem::size_of::(); 943 | 944 | // Data integrity check failed 945 | if log_byte.len() % (log_entry_size + CHECKSUM_LEN) != 0 { 946 | error!(self.log, "Data integrity check failed"); 947 | return; 948 | } 949 | 950 | let mut cursor = Cursor::new(&log_byte); 951 | let mut repair_data = Vec::new(); 952 | let state_machine = Arc::clone(&self.state.state_machine); 953 | loop { 954 | let mut bytes_data = vec![0u8; log_entry_size + CHECKSUM_LEN]; 955 | if cursor.read_exact(&mut bytes_data).await.is_err() { 956 | break; 957 | } 958 | repair_data.extend_from_slice(&bytes_data[0..log_entry_size]); 959 | } 960 | info!( 961 | self.log, 962 | "Send repair data from {} to {}, log_entry: {:?}", 963 | self.id, 964 | peer_id, 965 | state_machine.lock().await.get_log_entry().await 966 | ); 967 | 968 | let mut response = [ 969 | self.id.to_be_bytes(), 970 | self.state.current_term.to_be_bytes(), 971 | 9u32.to_be_bytes(), 972 | 1u32.to_be_bytes(), 973 | ] 974 | .concat(); 975 | 976 | for entry in repair_data { 977 | response = [response.clone(), entry.to_be_bytes().to_vec()].concat(); 978 | } 979 | 980 | let peer_address = self.cluster_config.address(peer_id); 981 | if peer_address.is_none() { 982 | // no dynamic membership changes 983 | info!(self.log, "Peer address not found"); 984 | return; 985 | } 986 | if let Err(e) = self 987 | .network_manager 988 | .send(&peer_address.unwrap(), &response) 989 | .await 990 | { 991 | error!(self.log, "Failed to send repair response: {}", e); 992 | } 993 | } 994 | 995 | async fn handle_repair_response(&mut self, data: &[u8]) { 996 | if self.state.state != RaftState::Leader { 997 | return; 998 | } 999 | 1000 | if self.storage.turned_malicious().await.is_err() { 1001 | self.state.state = RaftState::Follower; 1002 | return; 1003 | } 1004 | 1005 | let term = u32::from_be_bytes(data[4..8].try_into().unwrap()); 1006 | if term < self.state.current_term { 1007 | return; 1008 | } 1009 | 1010 | let log_entries = data[16..].to_vec(); 1011 | if let Err(e) = self.storage.store(&log_entries).await { 1012 | error!(self.log, "Failed to store log entries to disk: {}", e); 1013 | } 1014 | } 1015 | 1016 | async fn handle_join_request(&mut self, data: &[u8]) { 1017 | if self.state.state != RaftState::Leader { 1018 | return; 1019 | } 1020 | 1021 | let node_id = u32::from_be_bytes(data[0..4].try_into().unwrap()); 1022 | let term = u32::from_be_bytes(data[4..8].try_into().unwrap()); 1023 | let node_ip_address = String::from_utf8(data[12..].to_vec()).unwrap(); 1024 | 1025 | info!( 1026 | self.log, 1027 | "Current cluster nodes: {:?}, want join node: {}", 1028 | self.cluster_config 1029 | .peers() 1030 | .iter() 1031 | .map(|x| x.address.to_string()) 1032 | .collect::>(), 1033 | node_ip_address 1034 | ); 1035 | 1036 | if self.cluster_config.contains_server(node_id) { 1037 | error!( 1038 | self.log, 1039 | "Node already exists in the cluster, Ignoring join request." 
1040 | ); 1041 | return; 1042 | } 1043 | 1044 | if term != 0 { 1045 | error!(self.log, "Invalid term for join request, term should be 0."); 1046 | return; 1047 | } 1048 | 1049 | // Add the new node's information to the cluster, ready to receive subsequent data 1050 | self.cluster_config 1051 | .add_server((node_id, node_ip_address.parse::().unwrap()).into()); 1052 | 1053 | let mut response = [ 1054 | self.id.to_be_bytes(), 1055 | self.state.current_term.to_be_bytes(), 1056 | 11u32.to_be_bytes(), 1057 | ] 1058 | .concat(); 1059 | response.extend_from_slice(&self.state.commit_index.to_be_bytes()); 1060 | response.extend_from_slice(&self.state.previous_log_index.to_be_bytes()); 1061 | response.extend_from_slice(&self.peer_count().to_be_bytes()); 1062 | 1063 | let Some(peer_address) = self.cluster_config.address(node_id) else { 1064 | // no dynamic membership changes 1065 | info!(self.log, "Peer address not found"); 1066 | return; 1067 | }; 1068 | if let Err(e) = self.network_manager.send(&peer_address, &response).await { 1069 | error!(self.log, "Failed to send join response: {}", e); 1070 | } 1071 | 1072 | // Here we will send the snapshot data to the new node 1073 | let state_machine = Arc::clone(&self.state.state_machine); 1074 | let log_entrys = if let Ok(data) = state_machine.lock().await.get_log_entry().await { 1075 | data 1076 | } else { 1077 | error!(self.log, "Failed to get log entrys from state machine."); 1078 | return; 1079 | }; 1080 | info!(self.log, "Sending log entrys to new node: {:?}", log_entrys); 1081 | 1082 | let log_entry_bytes = if let Ok(b) = bincode::serialize(&log_entrys) { 1083 | b 1084 | } else { 1085 | error!(self.log, "Failed to serialize log entrys."); 1086 | return; 1087 | }; 1088 | 1089 | let mut batch_append_entry_request: Vec = Vec::new(); 1090 | batch_append_entry_request.extend_from_slice(&self.id.to_be_bytes()); 1091 | batch_append_entry_request 1092 | .extend_from_slice(&state_machine.lock().await.get_term().await.to_be_bytes()); 1093 | batch_append_entry_request.extend_from_slice(&12u32.to_be_bytes()); 1094 | batch_append_entry_request 1095 | .extend_from_slice(&state_machine.lock().await.get_index().await.to_be_bytes()); 1096 | batch_append_entry_request.extend_from_slice(&log_entry_bytes); 1097 | if let Err(e) = self 1098 | .network_manager 1099 | .send(&peer_address, &batch_append_entry_request) 1100 | .await 1101 | { 1102 | error!( 1103 | self.log, 1104 | "Failed send batch append entry request to {}, err: {}", peer_address, e 1105 | ); 1106 | } 1107 | } 1108 | 1109 | async fn handle_join_response(&mut self, data: &[u8]) { 1110 | if self.state.state != RaftState::Follower { 1111 | return; 1112 | } 1113 | 1114 | let leader_id = u32::from_be_bytes(data[0..4].try_into().unwrap()); 1115 | let current_term = u32::from_be_bytes(data[4..8].try_into().unwrap()); 1116 | let commit_index = u32::from_be_bytes(data[12..16].try_into().unwrap()); 1117 | let previous_log_index = u32::from_be_bytes(data[16..20].try_into().unwrap()); 1118 | // FIXME 1119 | let _peers_count = u32::from_be_bytes(data[20..24].try_into().unwrap()); 1120 | 1121 | self.state.current_term = current_term; 1122 | self.state.commit_index = commit_index; 1123 | self.state.previous_log_index = previous_log_index; 1124 | 1125 | let request_data = [ 1126 | self.id.to_be_bytes(), 1127 | self.state.current_term.to_be_bytes(), 1128 | 8u32.to_be_bytes(), 1129 | ] 1130 | .concat(); 1131 | let leader_address = self.cluster_config.address(leader_id); 1132 | if leader_address.is_none() { 1133 | // no 
dynamic membership changes 1134 | info!(self.log, "Leader address not found"); 1135 | return; 1136 | } 1137 | 1138 | if let Err(e) = self 1139 | .network_manager 1140 | .send(&leader_address.unwrap(), &request_data) 1141 | .await 1142 | { 1143 | error!(self.log, "Failed to send repair request: {}", e); 1144 | } 1145 | 1146 | info!( 1147 | self.log, 1148 | "Joined the cluster with leader: {}, own id: {}", leader_id, self.id 1149 | ); 1150 | } 1151 | 1152 | async fn handle_batch_append_entries(&mut self, data: &[u8]) { 1153 | let leader_id = u32::from_be_bytes(data[0..4].try_into().unwrap()); 1154 | let last_included_term = u32::from_be_bytes(data[4..8].try_into().unwrap()); 1155 | let last_included_index = u32::from_be_bytes(data[12..16].try_into().unwrap()); 1156 | let log_entrys = if let Ok(data) = bincode::deserialize::>(&data[16..]) { 1157 | data 1158 | } else { 1159 | info!(self.log, "Failed to deserialize log entrys."); 1160 | return; 1161 | }; 1162 | 1163 | self.state.current_term = last_included_term; 1164 | self.state.commit_index = last_included_index; 1165 | self.state.previous_log_index = last_included_index; 1166 | let state_machine = Arc::clone(&self.state.state_machine); 1167 | state_machine 1168 | .lock() 1169 | .await 1170 | .apply_log_entrys(last_included_term, last_included_index, log_entrys) 1171 | .await; 1172 | 1173 | let response = [ 1174 | self.id.to_be_bytes(), 1175 | self.state.current_term.to_be_bytes(), 1176 | 13u32.to_be_bytes(), 1177 | ] 1178 | .concat(); 1179 | let peer_address = if let Some(addr) = self.cluster_config.address(leader_id) { 1180 | addr 1181 | } else { 1182 | info!(self.log, "Failed to get peer address."); 1183 | return; 1184 | }; 1185 | if let Err(e) = self.network_manager.send(&peer_address, &response).await { 1186 | error!(self.log, "Failed to send join response: {}", e); 1187 | } 1188 | } 1189 | 1190 | async fn handle_batch_append_entries_response(&mut self, data: &[u8]) { 1191 | if self.state.state != RaftState::Leader { 1192 | return; 1193 | } 1194 | 1195 | let peer_id = u32::from_be_bytes(data[0..4].try_into().unwrap()); 1196 | let last_included_term = u32::from_be_bytes(data[4..8].try_into().unwrap()); 1197 | 1198 | info!(self.log, "Received batch append entries response from peer: {}, current peer last_included_term: {}", peer_id, last_included_term); 1199 | } 1200 | 1201 | async fn persist_to_disk(&mut self, id: u32, data: &[u8]) { 1202 | info!( 1203 | self.log, 1204 | "Persisting logs to disk from peer: {} to server: {}", id, self.id 1205 | ); 1206 | 1207 | // Log Compaction 1208 | if let Err(e) = self.storage.compaction().await { 1209 | error!(self.log, "Failed to do compaction on disk: {}", e); 1210 | } 1211 | 1212 | let state_machine = Arc::clone(&self.state.state_machine); 1213 | // let mut state_machine_lock = state_machine.lock().await; 1214 | if self.state.state == RaftState::Follower { 1215 | // deserialize log entries and append to log 1216 | let res = self.deserialize_log_entries(data); 1217 | let Ok(log_entry) = res else { 1218 | error!( 1219 | self.log, 1220 | "Failed to deserialize log entry: {}", 1221 | res.err().unwrap() 1222 | ); 1223 | return; 1224 | }; 1225 | state_machine 1226 | .lock() 1227 | .await 1228 | .apply_log_entry(self.state.current_term, self.state.commit_index, log_entry) 1229 | .await; 1230 | } 1231 | if let Err(e) = self.storage.store(data).await { 1232 | error!(self.log, "Failed to store log entry to disk: {}", e); 1233 | } 1234 | 1235 | info!( 1236 | self.log, 1237 | "Log persistence complete, 
current log count: {}", 1238 | state_machine 1239 | .lock() 1240 | .await 1241 | .get_log_entry() 1242 | .await 1243 | .unwrap() 1244 | .len() 1245 | ); 1246 | } 1247 | 1248 | fn deserialize_log_entries(&self, data: &[u8]) -> Result { 1249 | bincode::deserialize(data).map_err(Error::BincodeError) 1250 | } 1251 | 1252 | fn is_quorum(&self, votes: u32) -> bool { 1253 | votes > (self.peer_count() / 2).try_into().unwrap_or_default() 1254 | } 1255 | 1256 | #[allow(dead_code)] 1257 | async fn stop(self) { 1258 | if let Err(e) = self.network_manager.close().await { 1259 | error!(self.log, "Failed to close network manager: {}", e); 1260 | } 1261 | } 1262 | 1263 | // Helper function to access cluster config 1264 | fn peers(&self) -> Vec<&NodeMeta> { 1265 | self.cluster_config.peers_for(self.id) 1266 | } 1267 | 1268 | fn peers_address(&self) -> Vec { 1269 | self.cluster_config.peer_address_for(self.id) 1270 | } 1271 | 1272 | fn peer_count(&self) -> usize { 1273 | self.peers().len() 1274 | } 1275 | } 1276 | -------------------------------------------------------------------------------- /src/state_mechine.rs: -------------------------------------------------------------------------------- 1 | use std::path::PathBuf; 2 | use std::time::Duration; 3 | use std::{fmt::Debug, path::Path}; 4 | 5 | use async_trait::async_trait; 6 | use serde::{Deserialize, Serialize}; 7 | use tokio::fs; 8 | use tokio::io::AsyncReadExt; 9 | use tokio::{fs::OpenOptions, io::AsyncWriteExt, time::Instant}; 10 | 11 | use crate::error::StorageError::PathNotFound; 12 | use crate::{error::Error, error::Result, server::LogEntry}; 13 | 14 | #[async_trait] 15 | pub trait StateMachine: Debug + Send + Sync { 16 | // Retrieve the current term stored in the state machine 17 | async fn get_term(&self) -> u32; 18 | 19 | // Retrieve the current log index stored in the state machine 20 | async fn get_index(&self) -> u32; 21 | 22 | // Apply a single log entry to the state machine, updating term, index, and log entries 23 | async fn apply_log_entry( 24 | &mut self, 25 | last_included_term: u32, 26 | last_included_index: u32, 27 | log_entry: LogEntry, 28 | ); 29 | 30 | // Apply multiple log entries to the state machine in bulk 31 | async fn apply_log_entrys( 32 | &mut self, 33 | last_included_term: u32, 34 | last_included_index: u32, 35 | mut log_entrys: Vec, 36 | ); 37 | 38 | // Retrieve all log entries currently stored in the state machine 39 | async fn get_log_entry(&mut self) -> Result>; 40 | 41 | // Create a snapshot of the current state to the file system, and clear the log data after 42 | async fn create_snapshot(&mut self) -> Result<()>; 43 | 44 | // Check if a snapshot is required based on the interval since the last snapshot 45 | async fn need_create_snapshot(&mut self) -> bool; 46 | } 47 | 48 | #[derive(Debug, Clone, Serialize, Deserialize, Default)] 49 | pub struct FileStateMachine { 50 | last_included_term: u32, 51 | last_included_index: u32, 52 | data: Vec, 53 | 54 | #[serde(skip)] 55 | snapshot_path: Option>, 56 | 57 | /// The time interval between snapshots. 
58 | #[serde(skip)] 59 | snapshot_interval: Duration, 60 | 61 | /// The time when snapshot generation started 62 | #[serde(skip)] 63 | snapshot_start_time: Option, 64 | 65 | /// Whether the state machine is currently generating a snapshot 66 | #[serde(skip)] 67 | is_snapshotting: bool, 68 | 69 | /// The time when the last snapshot was completed 70 | #[serde(skip)] 71 | last_snapshot_complete_time: Option, 72 | } 73 | 74 | impl FileStateMachine { 75 | /// Create a new FileStateMachine with initial values 76 | pub fn new(snapshot_path: &Path, snapshot_interval: Duration) -> Self { 77 | let current_time = Instant::now(); 78 | Self { 79 | snapshot_path: Some(PathBuf::from(snapshot_path).into_boxed_path()), 80 | last_included_term: 0, 81 | last_included_index: 0, 82 | data: Vec::new(), 83 | snapshot_interval, 84 | snapshot_start_time: Some(current_time), 85 | is_snapshotting: false, 86 | last_snapshot_complete_time: Some(current_time), 87 | } 88 | } 89 | } 90 | 91 | /// Implement the StateMachine trait for FileStateMachine 92 | /// Generate snapshots based on time intervals as I record start, end time 93 | /// The data in memory is cleared after generating a snapshot, to save memory 94 | #[async_trait] 95 | impl StateMachine for FileStateMachine { 96 | async fn get_term(&self) -> u32 { 97 | self.last_included_term 98 | } 99 | 100 | async fn get_index(&self) -> u32 { 101 | self.last_included_index 102 | } 103 | 104 | async fn apply_log_entry( 105 | &mut self, 106 | last_included_term: u32, 107 | last_included_index: u32, 108 | log_entry: LogEntry, 109 | ) { 110 | self.last_included_term = last_included_term; 111 | self.last_included_index = last_included_index; 112 | self.data.push(log_entry); 113 | } 114 | 115 | async fn apply_log_entrys( 116 | &mut self, 117 | last_included_term: u32, 118 | last_included_index: u32, 119 | mut log_entrys: Vec, 120 | ) { 121 | self.last_included_term = last_included_term; 122 | self.last_included_index = last_included_index; 123 | self.data.append(&mut log_entrys); 124 | } 125 | 126 | async fn create_snapshot(&mut self) -> Result<()> { 127 | let snapshot_path = if let Some(ref path) = self.snapshot_path { 128 | path 129 | } else { 130 | return Err(Error::Store(PathNotFound)); 131 | }; 132 | 133 | self.snapshot_start_time = Some(Instant::now()); 134 | self.is_snapshotting = true; 135 | 136 | // Step 1: Read the existing snapshot from file if it exists 137 | let mut existing_fsm = FileStateMachine::new(snapshot_path, Duration::from_secs(0)); 138 | if fs::metadata(snapshot_path).await.is_ok() { 139 | let mut file = OpenOptions::new().read(true).open(snapshot_path).await?; 140 | let mut buffer = Vec::new(); 141 | file.read_to_end(&mut buffer).await.map_err(Error::Io)?; 142 | 143 | if !buffer.is_empty() { 144 | existing_fsm = bincode::deserialize(&buffer).map_err(Error::BincodeError)?; 145 | } 146 | } 147 | 148 | // Step 2: Merge existing snapshot data 149 | self.data.splice(0..0, existing_fsm.data.drain(..)); 150 | 151 | // Step 3: Write the merged state back to the snapshot file 152 | let mut file = OpenOptions::new() 153 | .read(true) 154 | .write(true) 155 | .create(true) 156 | .truncate(true) 157 | .open(snapshot_path) 158 | .await?; 159 | let bytes = bincode::serialize(&self).map_err(Error::BincodeError)?; 160 | file.write_all(&bytes).await?; 161 | file.sync_all().await.map_err(Error::Io)?; 162 | 163 | // Step 4: Clear `data` after snapshot is created successfully 164 | self.data.clear(); 165 | 166 | self.last_snapshot_complete_time = Some(Instant::now()); 167 
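The merge in step 2 of create_snapshot above is easy to misread, so here is a minimal illustration of what `splice(0..0, existing_fsm.data.drain(..))` does to ordering; plain integers stand in for LogEntry values.

```rust
// Minimal illustration of the snapshot merge: entries recovered from the
// previous snapshot are spliced in front of the entries accumulated in memory,
// so the older entries keep their position. Integers stand in for LogEntry.
fn main() {
    let mut existing: Vec<u32> = vec![1, 2, 3]; // entries read back from the snapshot file
    let mut in_memory: Vec<u32> = vec![4, 5]; // entries applied since that snapshot

    in_memory.splice(0..0, existing.drain(..));

    assert_eq!(in_memory, vec![1, 2, 3, 4, 5]);
    assert!(existing.is_empty()); // drain(..) leaves the source empty
}
```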
| self.is_snapshotting = false; 168 | 169 | Ok(()) 170 | } 171 | 172 | /// Check if a new snapshot is needed. 173 | async fn need_create_snapshot(&mut self) -> bool { 174 | if self.is_snapshotting { 175 | return false; // If we are currently snapshotting, we don't need another snapshot 176 | } 177 | 178 | if let Some(last_snapshot_time) = self.last_snapshot_complete_time { 179 | // Calculate the time since the last snapshot was completed 180 | let time_since_last_snapshot = Instant::now().duration_since(last_snapshot_time); 181 | 182 | // If the time since the last snapshot is greater than the snapshot interval, return true 183 | if time_since_last_snapshot >= self.snapshot_interval { 184 | // self.last_snapshot_complete_time = Some(Instant::now()); 185 | return true; 186 | } 187 | } else { 188 | // If we never completed a snapshot, we need to create one 189 | return true; 190 | } 191 | 192 | false 193 | } 194 | 195 | async fn get_log_entry(&mut self) -> Result> { 196 | let snapshot_path = if let Some(ref path) = self.snapshot_path { 197 | path 198 | } else { 199 | return Err(Error::Store(PathNotFound)); 200 | }; 201 | 202 | let mut existing_fsm = FileStateMachine::new(snapshot_path, Duration::from_secs(0)); 203 | if fs::metadata(snapshot_path).await.is_ok() { 204 | let mut file = OpenOptions::new().read(true).open(snapshot_path).await?; 205 | let mut buffer = Vec::new(); 206 | file.read_to_end(&mut buffer).await.map_err(Error::Io)?; 207 | 208 | if !buffer.is_empty() { 209 | existing_fsm = bincode::deserialize(&buffer).map_err(Error::BincodeError)?; 210 | } 211 | } 212 | 213 | self.data.splice(0..0, existing_fsm.data.drain(..)); 214 | 215 | Ok(self.data.clone()) 216 | } 217 | } 218 | 219 | #[cfg(test)] 220 | mod tests { 221 | use crate::server::LogCommand; 222 | 223 | use super::*; 224 | use tempfile::NamedTempFile; 225 | use tokio::time::sleep; 226 | 227 | #[tokio::test] 228 | async fn test_apply_log_entry() { 229 | let tmp_file = NamedTempFile::new().unwrap(); 230 | let snapshot_path = tmp_file.path().to_str().unwrap(); 231 | 232 | let mut fsm = FileStateMachine { 233 | snapshot_path: Option::Some(PathBuf::from(snapshot_path).into_boxed_path()), 234 | last_included_term: 0, 235 | last_included_index: 0, 236 | data: vec![], 237 | snapshot_interval: Duration::from_secs(300), 238 | snapshot_start_time: None, 239 | is_snapshotting: false, 240 | last_snapshot_complete_time: None, 241 | }; 242 | 243 | let log_entry = LogEntry { 244 | term: 1, 245 | command: LogCommand::Set, 246 | leader_id: 1, 247 | server_id: 1, 248 | data: 1, 249 | }; 250 | 251 | fsm.apply_log_entry(1, 1, log_entry.clone()).await; 252 | 253 | let log_entries = fsm.get_log_entry().await.unwrap(); 254 | assert_eq!(log_entries.len(), 1); 255 | assert_eq!(log_entries[0], log_entry); 256 | assert_eq!(fsm.last_included_term, 1); 257 | assert_eq!(fsm.last_included_index, 1); 258 | } 259 | 260 | #[tokio::test] 261 | async fn test_need_create_snapshot() { 262 | let mut fsm = FileStateMachine { 263 | snapshot_path: None, 264 | last_included_term: 0, 265 | last_included_index: 0, 266 | data: vec![], 267 | snapshot_interval: Duration::from_secs(1), 268 | snapshot_start_time: None, 269 | is_snapshotting: false, 270 | last_snapshot_complete_time: Some(Instant::now()), 271 | }; 272 | 273 | // Immediately after completing snapshot, no snapshot should be needed 274 | assert!(!fsm.need_create_snapshot().await); 275 | 276 | // Wait for more than the interval and check again 277 | sleep(Duration::from_secs(2)).await; 278 | 
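The interval check implemented by need_create_snapshot above, and exercised by this test, reduces to a simple time gate. A condensed sketch of the same logic in free-function form; the signature is illustrative.

```rust
use std::time::{Duration, Instant};

// The gate behind need_create_snapshot, reduced to its core: never start a new
// snapshot while one is in flight, snapshot immediately if none has ever
// completed, otherwise wait for the configured interval to elapse.
fn should_snapshot(
    is_snapshotting: bool,
    last_complete: Option<Instant>,
    interval: Duration,
) -> bool {
    if is_snapshotting {
        return false;
    }
    match last_complete {
        Some(done_at) => Instant::now().duration_since(done_at) >= interval,
        None => true,
    }
}

fn main() {
    // No snapshot has ever completed: create one now.
    assert!(should_snapshot(false, None, Duration::from_secs(1)));
    // A snapshot is already running: do not start another.
    assert!(!should_snapshot(true, Some(Instant::now()), Duration::from_secs(1)));
}
```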
assert!(fsm.need_create_snapshot().await); 279 | } 280 | 281 | #[tokio::test] 282 | async fn test_create_snapshot() { 283 | let tmp_file = NamedTempFile::new().unwrap(); 284 | let snapshot_path = tmp_file.path().to_str().unwrap(); 285 | 286 | // Create a FileStateMachine with some data 287 | let mut fsm = FileStateMachine { 288 | snapshot_path: Some(PathBuf::from(snapshot_path).into_boxed_path()), 289 | last_included_term: 1, 290 | last_included_index: 1, 291 | data: vec![ 292 | LogEntry { 293 | term: 1, 294 | command: LogCommand::Set, 295 | leader_id: 1, 296 | server_id: 1, 297 | data: 1, 298 | }, 299 | LogEntry { 300 | term: 2, 301 | command: LogCommand::Set, 302 | leader_id: 2, 303 | server_id: 2, 304 | data: 2, 305 | }, 306 | ], 307 | snapshot_interval: Duration::from_secs(300), 308 | snapshot_start_time: None, 309 | is_snapshotting: false, 310 | last_snapshot_complete_time: None, 311 | }; 312 | 313 | // Call create_snapshot and check result 314 | let result = fsm.create_snapshot().await; 315 | assert!(result.is_ok(), "Snapshot creation failed"); 316 | 317 | // Check that the snapshot file is created 318 | let metadata = std::fs::metadata(snapshot_path); 319 | assert!(metadata.is_ok(), "Snapshot file was not created"); 320 | 321 | // Read the file back and deserialize it to check contents 322 | let snapshot_data = std::fs::read(snapshot_path).unwrap(); 323 | let deserialized_fsm: FileStateMachine = bincode::deserialize(&snapshot_data).unwrap(); 324 | 325 | // Check that the deserialized data matches the original state machine 326 | assert_eq!(deserialized_fsm.data.len(), 2); 327 | 328 | // Check that the snapshot start and complete times were set 329 | assert!( 330 | fsm.snapshot_start_time.is_some(), 331 | "Snapshot start time was not set" 332 | ); 333 | assert!( 334 | fsm.last_snapshot_complete_time.is_some(), 335 | "Snapshot complete time was not set" 336 | ); 337 | 338 | // Check that is_snapshotting was set to false after completion 339 | assert!( 340 | !fsm.is_snapshotting, 341 | "is_snapshotting should be false after snapshot completion" 342 | ); 343 | } 344 | } 345 | -------------------------------------------------------------------------------- /src/storage.rs: -------------------------------------------------------------------------------- 1 | // Organization: SpacewalkHq 2 | // License: MIT License 3 | 4 | use std::io::{self, Cursor, SeekFrom}; 5 | use std::path::{Path, PathBuf}; 6 | use std::sync::Arc; 7 | 8 | use async_trait::async_trait; 9 | use hex; 10 | use sha2::{Digest, Sha256}; 11 | use tokio::fs::{self, File, OpenOptions}; 12 | use tokio::io::{AsyncReadExt, AsyncSeekExt, AsyncWriteExt}; 13 | use tokio::sync::Mutex; 14 | 15 | use crate::error::StorageError::CorruptFile; 16 | use crate::error::{Error, Result}; 17 | use crate::server::LogEntry; 18 | 19 | const MAX_FILE_SIZE: u64 = 1_000_000; 20 | pub const CHECKSUM_LEN: usize = 64; 21 | 22 | #[async_trait] 23 | pub trait Storage { 24 | async fn store(&self, data: &[u8]) -> Result<()>; 25 | async fn retrieve(&self) -> Result>; 26 | async fn compaction(&self) -> Result<()>; 27 | async fn delete(&self) -> Result<()>; 28 | async fn turned_malicious(&self) -> Result<()>; 29 | } 30 | 31 | #[derive(Clone)] 32 | pub struct LocalStorage { 33 | path: PathBuf, 34 | file: Arc>, 35 | } 36 | 37 | impl LocalStorage { 38 | pub async fn new(path: String) -> Self { 39 | // Ensure the parent directory exists 40 | if let Some(parent) = Path::new(&path).parent() { 41 | fs::create_dir_all(parent).await.unwrap(); 42 | } 43 | 44 | let file = 
OpenOptions::new() 45 | .read(true) 46 | .write(true) 47 | .create(true) 48 | .truncate(false) 49 | .open(path.clone()) 50 | .await 51 | .unwrap(); 52 | 53 | LocalStorage { 54 | path: path.into(), 55 | file: Arc::new(Mutex::new(file)), 56 | } 57 | } 58 | 59 | pub async fn new_from_path(path: &Path) -> Self { 60 | let file = OpenOptions::new() 61 | .read(true) 62 | .write(true) 63 | .create(true) 64 | .truncate(false) 65 | .open(path) 66 | .await 67 | .unwrap(); 68 | 69 | LocalStorage { 70 | path: path.into(), 71 | file: Arc::new(Mutex::new(file)), 72 | } 73 | } 74 | 75 | pub async fn check_storage(&self) -> io::Result<()> { 76 | if let Some(parent_path) = self.path.parent() { 77 | fs::create_dir_all(parent_path).await?; 78 | } 79 | 80 | if !self.path.exists() { 81 | fs::File::create(&self.path).await?; 82 | } 83 | 84 | Ok(()) 85 | } 86 | 87 | /// Asynchronously stores the provided data along with its checksum into a file. 88 | /// 89 | /// # Arguments 90 | /// * `data` - A slice of bytes representing the data to be stored. 91 | async fn store_async(&self, data: &[u8]) -> Result<()> { 92 | let checksum = calculate_checksum(data); 93 | let data_with_checksum = [data, checksum.as_slice()].concat(); 94 | 95 | let file = Arc::clone(&self.file); 96 | let mut locked_file = file.lock().await; 97 | 98 | locked_file.seek(SeekFrom::End(0)).await.unwrap(); 99 | 100 | locked_file 101 | .write_all(&data_with_checksum) 102 | .await 103 | .map_err(Error::Io)?; 104 | 105 | // Attempts to sync all OS-internal metadata to disk. 106 | locked_file.sync_all().await.map_err(Error::Io)?; 107 | 108 | Ok(()) 109 | } 110 | 111 | /// Asynchronously retrieves all data from the file. 112 | async fn retrieve_async(&self) -> Result> { 113 | let file = Arc::clone(&self.file); 114 | let mut locked_file = file.lock().await; 115 | locked_file.seek(SeekFrom::Start(0)).await.unwrap(); 116 | 117 | let mut buffer = Vec::new(); 118 | locked_file 119 | .read_to_end(&mut buffer) 120 | .await 121 | .map_err(Error::Io)?; 122 | 123 | Ok(buffer) 124 | } 125 | 126 | async fn delete_async(&self) -> Result<()> { 127 | fs::remove_file(&self.path).await.map_err(Error::Io)?; 128 | Ok(()) 129 | } 130 | 131 | async fn compaction_async(&self) -> Result<()> { 132 | // If file size is greater than 1MB, then compact it 133 | let file = Arc::clone(&self.file); 134 | let locked_file = file.lock().await; 135 | let metadata = locked_file.metadata().await.map_err(Error::Io)?; 136 | if metadata.len() > MAX_FILE_SIZE { 137 | self.delete_async().await?; 138 | } 139 | Ok(()) 140 | } 141 | 142 | async fn is_file_size_exceeded(&self) -> Result<()> { 143 | let file = Arc::clone(&self.file); 144 | let locked_file = file.lock().await; 145 | 146 | let md = locked_file.metadata().await.map_err(Error::Io)?; 147 | if md.len() > MAX_FILE_SIZE { 148 | return Err(Error::Store(CorruptFile)); 149 | } 150 | 151 | Ok(()) 152 | } 153 | } 154 | 155 | #[async_trait] 156 | impl Storage for LocalStorage { 157 | async fn store(&self, data: &[u8]) -> Result<()> { 158 | self.store_async(data).await 159 | } 160 | 161 | async fn retrieve(&self) -> Result> { 162 | self.retrieve_async().await 163 | } 164 | 165 | async fn compaction(&self) -> Result<()> { 166 | self.compaction_async().await 167 | } 168 | 169 | async fn delete(&self) -> Result<()> { 170 | self.delete_async().await 171 | } 172 | 173 | async fn turned_malicious(&self) -> Result<()> { 174 | self.is_file_size_exceeded().await.unwrap(); 175 | 176 | let disk_data = self.retrieve().await?; 177 | let log_entry_size = 
std::mem::size_of::(); 178 | 179 | if disk_data.len() % (log_entry_size + CHECKSUM_LEN) != 0 { 180 | return Err(Error::Store(CorruptFile)); 181 | } 182 | 183 | let mut cursor = Cursor::new(&disk_data); 184 | loop { 185 | let mut bytes_data = vec![0u8; log_entry_size]; 186 | if let Err(err) = cursor.read_exact(&mut bytes_data).await { 187 | if err.kind() == std::io::ErrorKind::UnexpectedEof { 188 | break; 189 | } else { 190 | return Err(Error::Io(err)); 191 | } 192 | } 193 | 194 | let byte_data_checksum = calculate_checksum(&bytes_data); 195 | 196 | let mut checksum = [0u8; CHECKSUM_LEN]; 197 | if let Err(err) = cursor.read_exact(&mut checksum).await { 198 | if err.kind() == std::io::ErrorKind::UnexpectedEof { 199 | break; 200 | } else { 201 | return Err(Error::Io(err)); 202 | } 203 | } 204 | 205 | if byte_data_checksum.ne(&checksum) { 206 | return Err(Error::Store(CorruptFile)); 207 | } 208 | } 209 | 210 | Ok(()) 211 | } 212 | } 213 | 214 | /// This function computes the SHA-256 hash of the given byte slice and returns 215 | /// a fixed-size array of bytes (`[u8; CHECKSUM_LEN]`). 216 | /// The resulting checksum is encoded in hexadecimal format. 217 | fn calculate_checksum(data: &[u8]) -> [u8; CHECKSUM_LEN] { 218 | let mut hasher = Sha256::new(); 219 | hasher.update(data); 220 | let result = hasher.finalize(); 221 | let mut checksum = [0u8; 64]; 222 | checksum.copy_from_slice(hex::encode(result).as_bytes()); 223 | checksum 224 | } 225 | 226 | #[cfg(test)] 227 | mod tests { 228 | 229 | use std::io::{Cursor, SeekFrom}; 230 | 231 | use tempfile::NamedTempFile; 232 | use tokio::{ 233 | fs::OpenOptions, 234 | io::{AsyncReadExt, AsyncSeekExt, AsyncWriteExt}, 235 | }; 236 | 237 | use crate::{ 238 | server::{LogCommand, LogEntry}, 239 | storage::{calculate_checksum, LocalStorage, Storage, CHECKSUM_LEN}, 240 | }; 241 | 242 | /// Helper function to extract the checksum from the end of a given byte slice. 243 | /// It assumes that the checksum is of a fixed length `CHECKSUM_LEN` and is located 244 | /// at the end of the provided data slice. 245 | /// 246 | /// This function will panic if the length of the provided data slice is less than `CHECKSUM_LEN`. 247 | fn retrieve_checksum(data: &[u8]) -> [u8; CHECKSUM_LEN] { 248 | assert!(data.len() >= CHECKSUM_LEN); 249 | let mut op = [0; 64]; 250 | op.copy_from_slice(&data[data.len() - CHECKSUM_LEN..]); 251 | op 252 | } 253 | 254 | #[test] 255 | fn test_retrieve_checksum() { 256 | let data_str = "Some data followed by a checksum".as_bytes(); 257 | let calculated_checksum = calculate_checksum(data_str); 258 | 259 | let data = [data_str, calculated_checksum.as_slice()].concat(); 260 | 261 | let retrieved_checksum = retrieve_checksum(&data); 262 | assert_eq!(calculated_checksum, retrieved_checksum); 263 | } 264 | 265 | #[tokio::test] 266 | async fn test_store_async() { 267 | let tmp_file = NamedTempFile::new().unwrap(); 268 | let storage: Box = 269 | Box::new(LocalStorage::new_from_path(tmp_file.path()).await); 270 | 271 | let payload_data = "Some data to test raft".as_bytes(); 272 | let store_result = storage.store(payload_data).await; 273 | assert!(store_result.is_ok()); 274 | 275 | let buffer = storage.retrieve().await.unwrap(); 276 | 277 | // Verify the length of the stored data (original data + checksum). 278 | assert_eq!( 279 | payload_data.len() + CHECKSUM_LEN, 280 | buffer.len(), 281 | "Stored data length mismatch" 282 | ); 283 | 284 | let stored_data = &buffer[..buffer.len() - CHECKSUM_LEN]; 285 | // Verify the original data matches the input data. 
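The integrity scan in turned_malicious above walks the file as fixed-width records: an entry-sized chunk followed by a 64-byte checksum. Below is a sketch of the same framing using `chunks_exact`; the entry length is a made-up placeholder, and only the stride and the modulo pre-check mirror the code above.

```rust
// Framing sketch: the file is a sequence of fixed-width records, each being an
// entry of ENTRY_LEN bytes followed by a CHECKSUM_LEN-byte checksum. ENTRY_LEN
// is a placeholder here; in the code above it is the size of a LogEntry.
const ENTRY_LEN: usize = 20;
const CHECKSUM_LEN: usize = 64;

fn split_frames(disk_data: &[u8]) -> Option<Vec<(&[u8], &[u8])>> {
    let frame_len = ENTRY_LEN + CHECKSUM_LEN;
    // Mirrors the `len % (entry + checksum) != 0` integrity pre-check above.
    if disk_data.len() % frame_len != 0 {
        return None;
    }
    Some(
        disk_data
            .chunks_exact(frame_len)
            .map(|frame| frame.split_at(ENTRY_LEN)) // (entry bytes, checksum bytes)
            .collect(),
    )
}

fn main() {
    let data = vec![0u8; 2 * (ENTRY_LEN + CHECKSUM_LEN)];
    assert_eq!(split_frames(&data).unwrap().len(), 2);
    assert!(split_frames(&data[1..]).is_none()); // a truncated file fails the pre-check
}
```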
286 | assert_eq!(payload_data, stored_data, "Stored data mismatch"); 287 | } 288 | 289 | #[tokio::test] 290 | async fn test_delete() { 291 | let tmp_file = NamedTempFile::new().unwrap(); 292 | let storage: Box = 293 | Box::new(LocalStorage::new_from_path(tmp_file.path()).await); 294 | 295 | let delete_result = storage.delete().await; 296 | assert!(delete_result.is_ok()); 297 | assert!(!tmp_file.path().exists()); 298 | } 299 | 300 | #[tokio::test] 301 | async fn test_compaction_file_lt_max_file_size() { 302 | let tmp_file = NamedTempFile::new().unwrap(); 303 | let storage: Box = 304 | Box::new(LocalStorage::new_from_path(tmp_file.path()).await); 305 | let mock_data = vec![0u8; 1_000_000 /*1 MB*/ - 500]; 306 | 307 | let store_result = storage.store(&mock_data).await; 308 | assert!(store_result.is_ok()); 309 | 310 | let compaction_result = storage.compaction().await; 311 | assert!(compaction_result.is_ok()); 312 | 313 | assert!(tmp_file.path().exists()); 314 | } 315 | 316 | #[tokio::test] 317 | async fn test_compaction_file_gt_max_file_size() { 318 | let tmp_file = NamedTempFile::new().unwrap(); 319 | let storage: Box = 320 | Box::new(LocalStorage::new_from_path(tmp_file.path()).await); 321 | let mock_data = vec![0u8; 1_000_000 /*1 MB*/]; 322 | 323 | let store_result = storage.store(&mock_data).await; 324 | assert!(store_result.is_ok()); 325 | 326 | let compaction_result = storage.compaction().await; 327 | assert!(compaction_result.is_ok()); 328 | 329 | assert!(!tmp_file.path().exists()); 330 | } 331 | 332 | #[tokio::test] 333 | async fn test_retrieve() { 334 | let tmp_file = NamedTempFile::new().unwrap(); 335 | let storage: Box = 336 | Box::new(LocalStorage::new_from_path(tmp_file.path()).await); 337 | let log_entry_size = std::mem::size_of::(); 338 | 339 | // Insert the first data first 340 | let entry1 = LogEntry { 341 | leader_id: 1, 342 | server_id: 1, 343 | term: 1, 344 | command: LogCommand::Set, 345 | data: 1, 346 | }; 347 | let serialize_data = bincode::serialize(&entry1).unwrap(); 348 | storage.store(&serialize_data).await.unwrap(); 349 | let disk_data = storage.retrieve().await.unwrap(); 350 | let log_entry_bytes = &disk_data[0..log_entry_size]; 351 | let disk_entry: LogEntry = bincode::deserialize(log_entry_bytes).unwrap(); 352 | assert_eq!(entry1, disk_entry); 353 | 354 | // Then insert the second data 355 | let entry2 = LogEntry { 356 | leader_id: 2, 357 | server_id: 2, 358 | term: 2, 359 | command: LogCommand::Set, 360 | data: 2, 361 | }; 362 | let serialize_data = bincode::serialize(&entry2).unwrap(); 363 | storage.store(&serialize_data).await.unwrap(); 364 | let disk_data = storage.retrieve().await.unwrap(); 365 | 366 | // Try to read two pieces of data and sit down to compare 367 | let mut log_entrys = vec![]; 368 | let mut cursor = Cursor::new(&disk_data); 369 | loop { 370 | let mut bytes_data = vec![0u8; log_entry_size]; 371 | if cursor.read_exact(&mut bytes_data).await.is_err() { 372 | break; 373 | } 374 | let struct_data: LogEntry = bincode::deserialize(&bytes_data).unwrap(); 375 | 376 | let mut checksum = [0u8; CHECKSUM_LEN]; 377 | if cursor.read_exact(&mut checksum).await.is_err() { 378 | break; 379 | } 380 | 381 | log_entrys.push(struct_data); 382 | } 383 | 384 | assert_eq!(vec![entry1, entry2], log_entrys); 385 | } 386 | 387 | #[tokio::test] 388 | async fn test_turned_malicious_file_corrupted() { 389 | let tmp_file = NamedTempFile::new().unwrap(); 390 | let storage: Box = 391 | Box::new(LocalStorage::new_from_path(tmp_file.path()).await); 392 | 393 | // Try to write 
the data once 394 | let entry1 = LogEntry { 395 | leader_id: 1, 396 | server_id: 1, 397 | term: 1, 398 | command: LogCommand::Set, 399 | data: 1, 400 | }; 401 | let serialize_data = bincode::serialize(&entry1).unwrap(); 402 | let store_result = storage.store(&serialize_data).await; 403 | assert!(store_result.is_ok()); 404 | 405 | // We will go to simulate that the data is corrupted and does not conform to the original format 406 | // [(LogEntry, checksum), ...] 407 | let mut file = OpenOptions::new() 408 | .read(true) 409 | .write(true) 410 | .create(true) 411 | .open(tmp_file.path()) 412 | .await 413 | .unwrap(); 414 | file.seek(SeekFrom::Start(0)).await.unwrap(); 415 | 416 | file.write_all("Raft".as_bytes()).await.unwrap(); 417 | file.seek(SeekFrom::Start(0)).await.unwrap(); 418 | let mut buffer = vec![]; 419 | file.read_to_end(&mut buffer).await.unwrap(); 420 | file.sync_all().await.unwrap(); 421 | 422 | let storage: Box = 423 | Box::new(LocalStorage::new_from_path(tmp_file.path()).await); 424 | let result = storage.turned_malicious().await; 425 | assert!(result.is_err()); 426 | } 427 | 428 | #[tokio::test] 429 | async fn test_turned_malicious_happy_case() { 430 | let tmp_file = NamedTempFile::new().unwrap(); 431 | let storage: Box = 432 | Box::new(LocalStorage::new_from_path(tmp_file.path()).await); 433 | 434 | // Try to write the data once 435 | let entry1 = LogEntry { 436 | leader_id: 1, 437 | server_id: 1, 438 | term: 1, 439 | command: LogCommand::Set, 440 | data: 1, 441 | }; 442 | let serialize_data = bincode::serialize(&entry1).unwrap(); 443 | let store_result = storage.store(&serialize_data).await; 444 | assert!(store_result.is_ok()); 445 | 446 | let result = storage.turned_malicious().await; 447 | assert!(result.is_ok()); 448 | } 449 | } 450 | --------------------------------------------------------------------------------
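For completeness, here is a condensed sketch of the record format these tests exercise: each stored record is the raw payload followed by the hex-encoded SHA-256 digest that calculate_checksum produces. It relies only on the sha2 and hex crates already imported at the top of this file; the helper names are mine.

```rust
use sha2::{Digest, Sha256};

const CHECKSUM_LEN: usize = 64;

// Build a record the way store() does: payload followed by the hex-encoded
// SHA-256 of the payload (64 ASCII bytes).
fn frame(payload: &[u8]) -> Vec<u8> {
    let digest = hex::encode(Sha256::digest(payload));
    [payload, digest.as_bytes()].concat()
}

// Verify a record the way turned_malicious() does: recompute the digest over
// the payload portion and compare it with the stored checksum.
fn verify(record: &[u8]) -> bool {
    if record.len() < CHECKSUM_LEN {
        return false;
    }
    let (payload, checksum) = record.split_at(record.len() - CHECKSUM_LEN);
    hex::encode(Sha256::digest(payload)).as_bytes() == checksum
}

fn main() {
    let record = frame(b"some raft log entry");
    assert!(verify(&record));

    // Flip one payload byte and the stored checksum no longer matches.
    let mut corrupted = record.clone();
    corrupted[0] ^= 0xFF;
    assert!(!verify(&corrupted));
}
```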