├── .gitignore ├── .vscode └── launch.json ├── Cargo.lock ├── Cargo.toml ├── README.md ├── hosts ├── rustfmt.toml └── src ├── balancer ├── algorithms │ ├── mod.rs │ └── round_robin.rs ├── balancer.rs ├── balancing_algorithm.rs ├── client.rs ├── host_manager.rs ├── mod.rs └── poller.rs └── main.rs /.gitignore: -------------------------------------------------------------------------------- 1 | /target 2 | -------------------------------------------------------------------------------- /.vscode/launch.json: -------------------------------------------------------------------------------- 1 | { 2 | // Use IntelliSense to learn about possible attributes. 3 | // Hover to view descriptions of existing attributes. 4 | // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 5 | "version": "0.2.0", 6 | "configurations": [ 7 | { 8 | "type": "lldb", 9 | "request": "launch", 10 | "name": "Debug executable 'load-balancer-rust'", 11 | "cargo": { 12 | "args": [ 13 | "build", 14 | "--bin=load-balancer-rust", 15 | "--package=load-balancer-rust" 16 | ], 17 | "filter": { 18 | "name": "load-balancer-rust", 19 | "kind": "bin" 20 | } 21 | }, 22 | "args": [ "7777" ], 23 | "cwd": "${workspaceFolder}" 24 | }, 25 | { 26 | "type": "lldb", 27 | "request": "launch", 28 | "name": "Debug unit tests in executable 'load-balancer-rust'", 29 | "cargo": { 30 | "args": [ 31 | "test", 32 | "--no-run", 33 | "--bin=load-balancer-rust", 34 | "--package=load-balancer-rust" 35 | ], 36 | "filter": { 37 | "name": "load-balancer-rust", 38 | "kind": "bin" 39 | } 40 | }, 41 | "args": [ "7777" ], 42 | "cwd": "${workspaceFolder}" 43 | } 44 | ] 45 | } -------------------------------------------------------------------------------- /Cargo.lock: -------------------------------------------------------------------------------- 1 | # This file is automatically @generated by Cargo. 2 | # It is not intended for manual editing. 
3 | version = 3 4 | 5 | [[package]] 6 | name = "bitflags" 7 | version = "1.3.2" 8 | source = "registry+https://github.com/rust-lang/crates.io-index" 9 | checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" 10 | 11 | [[package]] 12 | name = "cfg-if" 13 | version = "1.0.0" 14 | source = "registry+https://github.com/rust-lang/crates.io-index" 15 | checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" 16 | 17 | [[package]] 18 | name = "ctrlc" 19 | version = "3.2.2" 20 | source = "registry+https://github.com/rust-lang/crates.io-index" 21 | checksum = "b37feaa84e6861e00a1f5e5aa8da3ee56d605c9992d33e082786754828e20865" 22 | dependencies = [ 23 | "nix", 24 | "winapi", 25 | ] 26 | 27 | [[package]] 28 | name = "libc" 29 | version = "0.2.126" 30 | source = "registry+https://github.com/rust-lang/crates.io-index" 31 | checksum = "349d5a591cd28b49e1d1037471617a32ddcda5731b99419008085f72d5a53836" 32 | 33 | [[package]] 34 | name = "load-balancer-rust" 35 | version = "0.1.0" 36 | dependencies = [ 37 | "ctrlc", 38 | "mio", 39 | ] 40 | 41 | [[package]] 42 | name = "log" 43 | version = "0.4.17" 44 | source = "registry+https://github.com/rust-lang/crates.io-index" 45 | checksum = "abb12e687cfb44aa40f41fc3978ef76448f9b6038cad6aef4259d3c095a2382e" 46 | dependencies = [ 47 | "cfg-if", 48 | ] 49 | 50 | [[package]] 51 | name = "mio" 52 | version = "0.8.3" 53 | source = "registry+https://github.com/rust-lang/crates.io-index" 54 | checksum = "713d550d9b44d89174e066b7a6217ae06234c10cb47819a88290d2b353c31799" 55 | dependencies = [ 56 | "libc", 57 | "log", 58 | "wasi", 59 | "windows-sys", 60 | ] 61 | 62 | [[package]] 63 | name = "nix" 64 | version = "0.24.1" 65 | source = "registry+https://github.com/rust-lang/crates.io-index" 66 | checksum = "8f17df307904acd05aa8e32e97bb20f2a0df1728bbc2d771ae8f9a90463441e9" 67 | dependencies = [ 68 | "bitflags", 69 | "cfg-if", 70 | "libc", 71 | ] 72 | 73 | [[package]] 74 | name = "wasi" 75 | version = 
"0.11.0+wasi-snapshot-preview1" 76 | source = "registry+https://github.com/rust-lang/crates.io-index" 77 | checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" 78 | 79 | [[package]] 80 | name = "winapi" 81 | version = "0.3.9" 82 | source = "registry+https://github.com/rust-lang/crates.io-index" 83 | checksum = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419" 84 | dependencies = [ 85 | "winapi-i686-pc-windows-gnu", 86 | "winapi-x86_64-pc-windows-gnu", 87 | ] 88 | 89 | [[package]] 90 | name = "winapi-i686-pc-windows-gnu" 91 | version = "0.4.0" 92 | source = "registry+https://github.com/rust-lang/crates.io-index" 93 | checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" 94 | 95 | [[package]] 96 | name = "winapi-x86_64-pc-windows-gnu" 97 | version = "0.4.0" 98 | source = "registry+https://github.com/rust-lang/crates.io-index" 99 | checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" 100 | 101 | [[package]] 102 | name = "windows-sys" 103 | version = "0.36.1" 104 | source = "registry+https://github.com/rust-lang/crates.io-index" 105 | checksum = "ea04155a16a59f9eab786fe12a4a450e75cdb175f9e0d80da1e17db09f55b8d2" 106 | dependencies = [ 107 | "windows_aarch64_msvc", 108 | "windows_i686_gnu", 109 | "windows_i686_msvc", 110 | "windows_x86_64_gnu", 111 | "windows_x86_64_msvc", 112 | ] 113 | 114 | [[package]] 115 | name = "windows_aarch64_msvc" 116 | version = "0.36.1" 117 | source = "registry+https://github.com/rust-lang/crates.io-index" 118 | checksum = "9bb8c3fd39ade2d67e9874ac4f3db21f0d710bee00fe7cab16949ec184eeaa47" 119 | 120 | [[package]] 121 | name = "windows_i686_gnu" 122 | version = "0.36.1" 123 | source = "registry+https://github.com/rust-lang/crates.io-index" 124 | checksum = "180e6ccf01daf4c426b846dfc66db1fc518f074baa793aa7d9b9aaeffad6a3b6" 125 | 126 | [[package]] 127 | name = "windows_i686_msvc" 128 | version = "0.36.1" 129 | source = 
"registry+https://github.com/rust-lang/crates.io-index" 130 | checksum = "e2e7917148b2812d1eeafaeb22a97e4813dfa60a3f8f78ebe204bcc88f12f024" 131 | 132 | [[package]] 133 | name = "windows_x86_64_gnu" 134 | version = "0.36.1" 135 | source = "registry+https://github.com/rust-lang/crates.io-index" 136 | checksum = "4dcd171b8776c41b97521e5da127a2d86ad280114807d0b2ab1e462bc764d9e1" 137 | 138 | [[package]] 139 | name = "windows_x86_64_msvc" 140 | version = "0.36.1" 141 | source = "registry+https://github.com/rust-lang/crates.io-index" 142 | checksum = "c811ca4a8c853ef420abd8592ba53ddbbac90410fab6903b3e79972a631f7680" 143 | -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "load-balancer-rust" 3 | version = "0.1.0" 4 | authors = ["CryShana "] 5 | edition = "2018" 6 | 7 | [dependencies] 8 | ctrlc = "3.1.9" 9 | mio = "0.8.0" 10 | 11 | [features] 12 | default = ["mio/os-poll", "mio/net"] 13 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Load Balancer Rust 2 | Simple high-performance TCP-level load balancer / reverse proxy made in Rust 3 | 4 | ## Why? 5 | Because sometimes you just need a very simple tool with minimal configuration and that just works. 6 | 7 | Also because I just wanted to try using Rust and this was fun to do. 8 | 9 | ## Usage 10 | A `hosts` file is required in the same directory from where you're calling the program. Should contain all servers' `[HOSTNAME]:[PORT]` on every new line. 
11 | 12 | Example `hosts` file content: 13 | ``` 14 | localhost:5000 15 | 127.0.0.1:5001 16 | domain.com:80 17 | ``` 18 | 19 | Running the program (it will listen on port 7777): 20 | ```sh 21 | ./load-balancer-rust 7777 22 | ``` 23 | 24 | ## Balancing algorithms 25 | As of right now, only *Round Robin* is implemented. Every time a connection to a server is lost due to an error, the server is marked as unavailable and avoided for some time, so that clients are not repeatedly sent to a server that is offline. 26 | 27 | ## Issues 28 | Not yet fully optimized for Windows: some unexplained behavior there causes slower response times than on Linux. 29 | 30 | ## Performance testing 31 | Some load testing was done using the [k6](https://k6.io/) tool to get an idea of relative performance. All testing was done on a system running Ubuntu 20.04. 32 | 33 | The performed test was defined as: 34 | ```js 35 | import http from 'k6/http'; 36 | import { sleep } from 'k6'; 37 | 38 | export let options = { 39 | stages: [ 40 | { duration: '20s', target: 1000 }, // slowly ramp-up traffic from 1 to 1000 users over 20 seconds 41 | { duration: '1m', target: 1000 }, // remain at 1000 users for 1 minute 42 | { duration: '20s', target: 0 }, // slowly ramp-down to 0 users 43 | ] 44 | }; 45 | 46 | const BASE_URL = 'http://localhost:7777'; 47 | 48 | export default () => { 49 | let res = http.get(`${BASE_URL}/test_endpoint`); 50 | sleep(1); 51 | }; 52 | ``` 53 | 54 | ### Reference 55 | The test was first run directly against a local web server to get a reference point: 56 | ![](https://cryshana.me/f/T2bwGCVdYM04.png) 57 | 58 | Average response time was around **0.25ms**.
59 | 60 | ### nginx 1.18.0 61 | I then set up a reverse proxy on nginx like so: 62 | ```nginx 63 | server { 64 | listen 6666; 65 | listen [::]:6666; 66 | 67 | location / { 68 | proxy_pass http://localhost:5000 69 | } 70 | } 71 | ``` 72 | And ran the test against nginx and got the following results: 73 | ![](https://cryshana.me/f/uVmlKwSzzRJm.png) 74 | 75 | Average response time was around **0.55ms**. (an overhead of about 0.30ms) 76 | 77 | ### Load Balancer Rust 78 | And then I tried using my own tool - using just one host to test reverse proxying performance: 79 | 80 | ![](https://cryshana.me/f/CAXD08i5DyaH.png) 81 | 82 | Average response time was around **0.36ms**. (an overhead of about 0.11ms, about **52% faster** than nginx) 83 | -------------------------------------------------------------------------------- /hosts: -------------------------------------------------------------------------------- 1 | localhost:5000 2 | localhost:5001 -------------------------------------------------------------------------------- /rustfmt.toml: -------------------------------------------------------------------------------- 1 | max_width = 160 -------------------------------------------------------------------------------- /src/balancer/algorithms/mod.rs: -------------------------------------------------------------------------------- 1 | mod round_robin; 2 | 3 | pub use round_robin::RoundRobin; 4 | use super::BalancingAlgorithm; 5 | use super::HostManager; -------------------------------------------------------------------------------- /src/balancer/algorithms/round_robin.rs: -------------------------------------------------------------------------------- 1 | use std::net::SocketAddr; 2 | use std::time::Duration; 3 | use std::time::Instant; 4 | use std::usize; 5 | 6 | use super::BalancingAlgorithm; 7 | use super::HostManager; 8 | 9 | pub struct RoundRobin { 10 | current_host: usize, 11 | max_host: usize, 12 | host_manager: HostManager, 13 | cooldowns: Vec<(SocketAddr, 
Instant)>, 14 | } 15 | 16 | impl RoundRobin { 17 | // how long the host is avoided (on cooldown) when first error is reported 18 | const TARGET_DOWN_COOLDOWN: Duration = Duration::from_secs(30); 19 | 20 | pub fn new(host_manager: HostManager) -> Self { 21 | let max = host_manager.hosts.len(); 22 | RoundRobin { 23 | current_host: 0, 24 | host_manager: host_manager, 25 | max_host: max, 26 | cooldowns: vec![], 27 | } 28 | } 29 | 30 | fn get_host_cooldown_index(&self, addr: SocketAddr) -> i32 { 31 | let mut index: i32 = -1; 32 | for i in 0..self.cooldowns.len() { 33 | if self.cooldowns[i].0 == addr { 34 | index = i as i32; 35 | break; 36 | } 37 | } 38 | 39 | index 40 | } 41 | 42 | fn increment_host_counter(&mut self) { 43 | self.current_host = self.current_host + 1; 44 | if self.current_host >= self.max_host { 45 | self.current_host = 0 46 | } 47 | } 48 | } 49 | 50 | impl BalancingAlgorithm for RoundRobin { 51 | fn get_next_host(&mut self) -> SocketAddr { 52 | let mut val; 53 | let starting_host_index = self.current_host; 54 | 55 | loop { 56 | // select host 57 | val = self.host_manager.hosts[self.current_host]; 58 | 59 | // offset host selector to next one 60 | self.increment_host_counter(); 61 | 62 | // if host on cooldown, avoid it (but if we made a full cycle, just return the initial choice) 63 | let cooldown_index = self.get_host_cooldown_index(val); 64 | let cycle_reached = starting_host_index == self.current_host; 65 | if cooldown_index >= 0 && !cycle_reached { 66 | // check if cooldown has passed 67 | if Instant::now() > self.cooldowns[cooldown_index as usize].1 { 68 | // cooldown passed, remove it 69 | self.cooldowns.remove(cooldown_index as usize); 70 | break; 71 | } 72 | 73 | continue; 74 | } else if cycle_reached { 75 | // cycle reached, let's increment the counter to continue trying different hosts until one actually connects 76 | self.increment_host_counter(); 77 | } 78 | 79 | break; 80 | } 81 | 82 | val 83 | } 84 | 85 | fn report_error(&mut self, addr: 
SocketAddr) { 86 | let index: i32 = self.get_host_cooldown_index(addr); 87 | 88 | let new_limit = Instant::now() + RoundRobin::TARGET_DOWN_COOLDOWN; 89 | 90 | if index < 0 { 91 | // add it 92 | self.cooldowns.push((addr, new_limit)); 93 | } else { 94 | // update it 95 | self.cooldowns[index as usize].1 = new_limit; 96 | } 97 | } 98 | 99 | fn report_success(&mut self, addr: SocketAddr) { 100 | let index: i32 = self.get_host_cooldown_index(addr); 101 | if index < 0 { 102 | return; 103 | } 104 | 105 | self.cooldowns.remove(index as usize); 106 | } 107 | 108 | fn is_on_cooldown(&self, addr: SocketAddr) -> bool { 109 | let index: i32 = self.get_host_cooldown_index(addr); 110 | return index >= 0; 111 | } 112 | } 113 | -------------------------------------------------------------------------------- /src/balancer/balancer.rs: -------------------------------------------------------------------------------- 1 | use std::collections::HashMap; 2 | use std::io::ErrorKind; 3 | use std::sync::Arc; 4 | use std::sync::RwLock; 5 | use std::usize; 6 | use std::vec; 7 | use std::{thread, time::Duration, u16}; 8 | 9 | use super::BalancingAlgorithm; 10 | use super::RoundRobin; 11 | use super::TcpClient; 12 | use mio::net::TcpStream; 13 | use mio::Events; 14 | use mio::Interest; 15 | use mio::Poll; 16 | use mio::Token; 17 | 18 | // this is used as the total timeout allowed to connect before client is disconnected 19 | const TOTAL_CONNECTION_TIMEOUT: Duration = Duration::from_millis(4000); 20 | 21 | // this is used as the timeout to connect to a target host 22 | const CONNECTION_TIMEOUT: Duration = Duration::from_millis(400); 23 | 24 | pub struct LoadBalancer { 25 | /** 26 | Holds client counts for all threads 27 | */ 28 | client_counts: Arc>>>>, 29 | /** 30 | Newly added clients are added here, threads will add them to polling when they can 31 | */ 32 | client_lists_pending: Arc>>>>>, 33 | threads: u16, 34 | stopped: Arc>, 35 | debug: Arc>, 36 | balancing_algorithm: Arc>, 37 | } 38 | 39 
| impl LoadBalancer { 40 | pub fn new(balancing_algorithm: RoundRobin, threads: u16, debug: bool) -> Self { 41 | // prepare client lists for every thread 42 | let mut client_counts: Vec>> = vec![]; 43 | for _ in 0..threads { 44 | client_counts.push(Arc::new(RwLock::new(0))); 45 | } 46 | let client_counts = Arc::new(RwLock::new(client_counts)); 47 | 48 | // prepare pending client lists for every thread 49 | let mut client_lists_pending: Vec>>> = vec![]; 50 | for _ in 0..threads { 51 | let lists: Vec = vec![]; 52 | client_lists_pending.push(Arc::new(RwLock::new(lists))); 53 | } 54 | let client_lists_pending = Arc::new(RwLock::new(client_lists_pending)); 55 | 56 | let b = LoadBalancer { 57 | client_counts, 58 | client_lists_pending, 59 | threads, 60 | stopped: Arc::new(RwLock::new(false)), 61 | debug: Arc::new(RwLock::new(debug)), 62 | balancing_algorithm: Arc::new(RwLock::new(balancing_algorithm)), 63 | }; 64 | 65 | b 66 | } 67 | 68 | pub fn start(&mut self) { 69 | self.spawn_threads(); 70 | } 71 | 72 | pub fn add_client(&mut self, stream: TcpStream) { 73 | let client = TcpClient::new(stream); 74 | 75 | // pick client list with least clients and add it to pending list 76 | let client_counts = self.client_counts.read().unwrap(); 77 | let client_lists_pending = self.client_lists_pending.read().unwrap(); 78 | 79 | // find client list with least clients first 80 | let mut min_index = 0; 81 | let mut min_length = *client_counts[0].read().unwrap(); 82 | for i in 1..client_counts.len() { 83 | let len = *client_counts[i].read().unwrap(); 84 | if len < min_length { 85 | min_length = len; 86 | min_index = i; 87 | } 88 | } 89 | 90 | if *self.debug.read().unwrap() { 91 | println!("[Thread {}] Connected from {}", min_index, client.address); 92 | } 93 | 94 | // add client to pending list 95 | client_lists_pending[min_index].write().unwrap().push(client); 96 | } 97 | 98 | pub fn stop(&mut self) { 99 | *self.stopped.write().unwrap() = true; 100 | } 101 | 102 | fn spawn_threads(&mut 
self) { 103 | let th = self.threads as u32; 104 | 105 | // WORKERS 106 | for id in 0..th { 107 | let stopped = Arc::clone(&self.stopped); 108 | let d = Arc::clone(&self.debug); 109 | let b = Arc::clone(&self.balancing_algorithm); 110 | let client_counts = Arc::clone(&self.client_counts); 111 | let client_list_pending = Arc::clone(&self.client_lists_pending); 112 | 113 | thread::spawn(move || { 114 | let mut connected_sockets: HashMap = HashMap::new(); 115 | let mut next_token_id: usize = 0; 116 | 117 | let mut get_next_token = || { 118 | let token = Token(next_token_id); 119 | next_token_id += 1; 120 | if next_token_id >= usize::MAX { 121 | next_token_id = 1; 122 | } 123 | token 124 | }; 125 | 126 | let client_list_index = id as usize; 127 | 128 | let mut poll = Poll::new().unwrap(); 129 | let mut events = Events::with_capacity(1024); 130 | 131 | loop { 132 | // keep checking if balancer has been stopped 133 | if *stopped.read().unwrap() { 134 | break; 135 | } 136 | 137 | // ------------------------------- 138 | // EVENT POLLING 139 | // ------------------------------- 140 | match poll.poll(&mut events, Some(Duration::from_millis(10))) { 141 | Ok(_) => {} 142 | Err(ref e) if e.kind() == ErrorKind::Interrupted => { 143 | // this handler does not get called on Windows, so we use timeout and check it outside 144 | *stopped.write().unwrap() = true; 145 | } 146 | Err(e) => { 147 | println!("[Thread {}] Failed to poll for events! 
{}", id, e.to_string()); 148 | break; 149 | } 150 | }; 151 | 152 | // ------------------------------- 153 | // PROCESS PENDING CLIENTS 154 | // ------------------------------- 155 | { 156 | // check if any pending clients (try to read to avoid blocking) 157 | let r: i32 = match client_list_pending.read().unwrap()[client_list_index].try_read() { 158 | Ok(r) => r.len() as i32, 159 | Err(_) => -1, 160 | }; 161 | if r > 0 { 162 | let p_list = &*client_list_pending.read().unwrap()[client_list_index]; 163 | 164 | let pending = &mut *match p_list.try_write() { 165 | Ok(w) => w, 166 | Err(_) => continue, 167 | }; 168 | 169 | // move all pending clients over to our client_list and register them with poll 170 | let plen = pending.len(); 171 | for i in 0..plen { 172 | let index = (plen - 1) - i; 173 | let mut client = pending.remove(index); 174 | 175 | let token = get_next_token(); 176 | 177 | poll.registry().register(&mut client.stream, token, Interest::READABLE).unwrap(); 178 | 179 | // insert into hashmap for quick lookup 180 | connected_sockets.insert(token, client); 181 | } 182 | 183 | // update count 184 | *client_counts.read().unwrap()[client_list_index].write().unwrap() = connected_sockets.len(); 185 | } 186 | } 187 | 188 | // ------------------------------- 189 | // CLIENT CHECKING (timeout handling) 190 | // ------------------------------- 191 | { 192 | // check for connecting clients for time outs and their current state 193 | let mut tokens_to_remove: Vec> = vec![]; 194 | for (token, client) in &mut connected_sockets { 195 | // if client not connected, schedule for removal 196 | if !client.is_client_connected() { 197 | let t = Box::new(token.clone()); 198 | tokens_to_remove.push(t); 199 | continue; 200 | } 201 | 202 | // if client not in IN_CONNECTING state, we can't check for time outs 203 | if !client.is_connecting() { 204 | continue; 205 | } 206 | 207 | // HANDLE TIMEOUT TO SINGLE TARGET 208 | if client.started_connecting.elapsed() > CONNECTION_TIMEOUT { 209 | 
if *d.read().unwrap() { 210 | println!( 211 | "[Thread {}] Connection to target timed out ({} <-> {})", 212 | id, 213 | client.address, 214 | client.get_target_addr().unwrap() 215 | ); 216 | } 217 | 218 | // we timed out! Let's try another host 219 | client.close_connection_to_target(true); 220 | LoadBalancer::report_target_error(client, Arc::clone(&b)); 221 | LoadBalancer::start_connection(id, token.clone(), client, &poll, Arc::clone(&d), Arc::clone(&b)); 222 | } 223 | 224 | // HANDLE TOTAL TIMEOUT 225 | if client.last_connection_loss.elapsed() > TOTAL_CONNECTION_TIMEOUT { 226 | if *d.read().unwrap() { 227 | println!("[Thread {}] Timed out ({})", id, client.address); 228 | } 229 | 230 | // we timed out completely! 231 | client.close_connection(); 232 | } 233 | } 234 | 235 | // now remove the marked clients 236 | if tokens_to_remove.len() > 0 { 237 | for token in tokens_to_remove { 238 | let mut client = connected_sockets.remove(&token).unwrap(); 239 | poll.registry().deregister(&mut client.stream).unwrap(); 240 | 241 | if *d.read().unwrap() { 242 | println!( 243 | "[Thread {}] Connection ended ({}) [Remaining clients: {}]", 244 | id, 245 | client.address, 246 | connected_sockets.len() 247 | ); 248 | } 249 | } 250 | 251 | // update count 252 | *client_counts.read().unwrap()[client_list_index].write().unwrap() = connected_sockets.len(); 253 | } 254 | } 255 | 256 | // ------------------------------ 257 | // EVENT LOOP 258 | // ------------------------------ 259 | if events.is_empty() || *stopped.read().unwrap() { 260 | continue; 261 | } 262 | for event in events.iter() { 263 | match event.token() { 264 | token => { 265 | let client = match connected_sockets.get_mut(&token) { 266 | Some(c) => c, 267 | None => { 268 | // println!("ERROR - Tried getting client that was not present in hash map! 
-> token: {:?}", token); 269 | // TODO: maybe deregister from poll if this is ever even called 270 | continue; 271 | } 272 | }; 273 | 274 | if !client.is_client_connected() { 275 | // ignore, will be handled in later loop and cleaned 276 | continue; 277 | } 278 | 279 | // if client is in process of connecting, check if connection has been established 280 | if client.is_connecting() { 281 | LoadBalancer::try_confirm_connection(id, client, Arc::clone(&d), Arc::clone(&b)); 282 | } 283 | 284 | // if connected, process it normally, otherwise start a new connection to next host 285 | if client.is_connected() { 286 | LoadBalancer::process_client(client, Arc::clone(&b)); 287 | } else if !client.is_connecting() { 288 | LoadBalancer::start_connection(id, token, client, &poll, Arc::clone(&d), Arc::clone(&b)); 289 | } 290 | } 291 | } 292 | } 293 | } 294 | }); 295 | } 296 | } 297 | 298 | fn try_confirm_connection(id: u32, client: &mut TcpClient, d: Arc>, b: Arc>) { 299 | let server_connected = client.check_target_connected().unwrap_or_else(|e| { 300 | println!("Not connected unknown error -> {}", e.to_string()); 301 | // TODO: should probably disconnect - there was an error while connecting other than NotConnected 302 | false 303 | }); 304 | 305 | if server_connected { 306 | let addr = client.get_target_addr().unwrap(); 307 | 308 | if *d.read().unwrap() && !client.is_connecting() { 309 | println!("[Thread {}] Client connected to target ({} -> {})", id, client.address, addr); 310 | } 311 | 312 | // report success if connection succeeded 313 | if b.read().unwrap().is_on_cooldown(addr) { 314 | b.write().unwrap().report_success(addr); 315 | } 316 | } 317 | } 318 | 319 | fn process_client(client: &mut TcpClient, b: Arc>) { 320 | let success = client.process(); 321 | 322 | if success == false { 323 | // connection to either server or client has failed 324 | 325 | // removal from list is handled later 326 | 327 | LoadBalancer::report_target_error(client, Arc::clone(&b)); 328 | } 329 | 
} 330 | 331 | fn start_connection(id: u32, token: Token, client: &mut TcpClient, poll: &Poll, d: Arc>, b: Arc>) { 332 | // determine target host to connect to, using the balancing algorithm! 333 | let target_socket = match client.get_target_addr() { 334 | Some(s) => s, 335 | None => b.write().unwrap().get_next_host(), 336 | }; 337 | 338 | if *d.read().unwrap() && !client.is_connecting() { 339 | println!("[Thread {}] Connecting client ({} -> {})", id, client.address, target_socket); 340 | } 341 | 342 | // connect to target 343 | let success = match client.connect_to_target(target_socket) { 344 | Ok(s) => s, 345 | Err(e) => { 346 | println!( 347 | "[Thread {}] Unexpected error while trying to start a connection! {} ({} -> {})", 348 | id, 349 | e.to_string(), 350 | client.address, 351 | target_socket 352 | ); 353 | false 354 | } 355 | }; 356 | 357 | if success { 358 | // connection to target host started 359 | // add server to poll (with same token as client) 360 | client.register_target_with_poll(&poll, token); 361 | } else { 362 | // report host error to host manager 363 | LoadBalancer::report_target_error(client, Arc::clone(&b)); 364 | } 365 | } 366 | 367 | fn report_target_error(client: &mut TcpClient, b: Arc>) { 368 | // report host error to host manager 369 | let last_t = client.get_last_target_addr(); 370 | if client.last_target_errored() && last_t.is_some() { 371 | b.write().unwrap().report_error(last_t.unwrap()); 372 | } 373 | } 374 | } 375 | -------------------------------------------------------------------------------- /src/balancer/balancing_algorithm.rs: -------------------------------------------------------------------------------- 1 | use std::net::SocketAddr; 2 | pub trait BalancingAlgorithm: Sync + Send { 3 | /** 4 | Returns the next host for the client to try to connect to 5 | */ 6 | fn get_next_host(&mut self) -> SocketAddr; 7 | /** 8 | Reports error for the given host address. 
Host can then be placed on cooldown, this can affect the [get_next_host] call
9 |     */
10 |     fn report_error(&mut self, addr: SocketAddr);
11 |     /**
12 |     Reports success for the given host address. Host can be removed from cooldown
13 |     */
14 |     fn report_success(&mut self, addr: SocketAddr);
15 |     /**
16 |     Checks if host is currently on cooldown or in any way affected by the reported errors
17 |     */
18 |     fn is_on_cooldown(&self, addr: SocketAddr) -> bool;
19 | }
20 | 
--------------------------------------------------------------------------------
/src/balancer/client.rs:
--------------------------------------------------------------------------------
1 | use std::io::prelude::*;
2 | use std::io::ErrorKind;
3 | use std::io::Result;
4 | use std::net::Shutdown;
5 | use std::net::SocketAddr;
6 | 
7 | use std::time::Instant;
8 | 
9 | use mio::net::TcpStream;
10 | use mio::Interest;
11 | use mio::Poll;
12 | use mio::Token;
13 | 
14 | /// One proxied connection: the accepted client stream plus at most one connection to a
15 | /// target host, with the state used to track and retry that target connection.
16 | pub struct TcpClient {
17 |     pub stream: TcpStream,
18 |     // scratch buffer shared by both forwarding directions
19 |     buffer: [u8; 4096],
20 | 
21 |     pub address: SocketAddr,
22 |     target: Option,
23 |     target_stream: Option,
24 |     is_connected: bool,
25 |     is_connecting: bool,
26 |     is_client_connected: bool,
27 |     pub last_connection_loss: Instant,
28 |     pub started_connecting: Instant,
29 |     // target of the last target connection that ended in an error (None otherwise)
30 |     last_target: Option,
31 |     last_target_error: bool,
32 | }
33 | 
34 | impl TcpClient {
35 |     /// Wraps an accepted client stream; no target connection is started yet.
36 |     /// Panics if the peer address of `stream` cannot be read.
37 |     pub fn new(stream: TcpStream) -> Self {
38 |         let addr: SocketAddr = stream.peer_addr().unwrap();
39 | 
40 |         TcpClient {
41 |             stream: stream,
42 |             buffer: [0; 4096],
43 |             target: None,
44 |             target_stream: None,
45 |             address: addr,
46 |             is_connected: false,
47 |             is_connecting: false,
48 |             is_client_connected: true,
49 |             last_connection_loss: Instant::now(),
50 |             started_connecting: Instant::now(),
51 |             last_target: None,
52 |             last_target_error: false,
53 |         }
54 |     }
55 | 
56 |     /// Registers the target stream with `poll` for read and write readiness under `token`.
57 |     /// Returns `None` if there is no target stream or the registration failed; a failure is
58 |     /// logged instead of panicking the calling thread.
59 |     pub fn register_target_with_poll(&mut self, poll: &Poll, token: Token) -> Option<()> {
60 |         let stream = self.target_stream.as_mut()?;
61 | 
62 |         match poll.registry().register(stream, token, Interest::READABLE | Interest::WRITABLE) {
63 |             Ok(()) => Some(()),
64 |             Err(e) => {
65 |                 println!("[WARNING] Failed to register target connection of {} with poll: {}", self.address, e);
66 |                 None
67 |             }
68 |         }
69 |     }
70 | 
71 |     pub fn get_target_addr(&self) -> Option {
72 |         self.target
73 |     }
74 | 
75 |     pub fn get_last_target_addr(&self) -> Option {
76 |         self.last_target
77 |     }
78 | 
79 |     pub fn last_target_errored(&self) -> bool {
80 |         self.last_target_error
81 |     }
82 | 
83 |     pub fn is_connected(&self) -> bool {
84 |         self.is_connected
85 |     }
86 | 
87 |     pub fn is_connecting(&self) -> bool {
88 |         self.is_connecting
89 |     }
90 | 
91 |     pub fn is_client_connected(&self) -> bool {
92 |         self.is_client_connected
93 |     }
94 | 
95 |     /// Starts a non-blocking connection to `target`, closing any previous target connection.
96 |     ///
97 |     /// Returns `Ok(true)` if the attempt was started, `Ok(false)` if the client is already
98 |     /// connecting or `TcpStream::connect` failed right away. Such an immediate failure is
99 |     /// recorded as an error for `target` (see `last_target_errored`), so the caller can
100 |     /// report it to the balancing algorithm like any other target failure.
101 |     pub fn connect_to_target(&mut self, target: SocketAddr) -> Result {
102 |         if self.is_connecting {
103 |             println!("[WARNING] Already connecting, this shouldn't happen");
104 |             return Ok(false);
105 |         }
106 | 
107 |         self.close_connection_to_target(false);
108 | 
109 |         // start connecting
110 |         let stream = match TcpStream::connect(target) {
111 |             Ok(t) => t,
112 |             Err(_) => {
113 |                 // without this, an attempt that fails immediately would never put the host on cooldown
114 |                 self.last_target = Some(target);
115 |                 self.last_target_error = true;
116 |                 return Ok(false);
117 |             }
118 |         };
119 | 
120 |         self.is_connecting = true;
121 |         self.target = Some(target);
122 |         self.target_stream = Some(stream);
123 |         self.started_connecting = Instant::now();
124 | 
125 |         Ok(true)
126 |     }
127 | 
128 |     /// Checks whether the pending target connection has been established, using a 1-byte `peek`.
129 |     ///
130 |     /// Returns `Ok(false)` if `peek` reports `NotConnected`, and `Ok(true)` (after marking the
131 |     /// client as connected) if it returns data or would block. Other errors are returned.
132 |     /// Panics if there is no target stream.
133 |     pub fn check_target_connected(&mut self) -> Result {
134 |         let stream = self.target_stream.as_ref().unwrap();
135 | 
136 |         let mut buf: [u8; 1] = [0; 1];
137 |         match stream.peek(&mut buf) {
138 |             Ok(_) => true,
139 |             Err(ref e) if e.kind() == ErrorKind::NotConnected => return Ok(false),
140 |             Err(ref e) if e.kind() == ErrorKind::WouldBlock => true,
141 |             Err(e) => {
142 |                 return Err(e);
143 |             }
144 |         };
145 | 
146 |         self.set_connected();
147 |         Ok(true)
148 |     }
149 | 
150 |     fn set_connected(&mut self) {
151 |         self.is_connected = true;
152 |         self.is_connecting = false;
153 |     }
154 | 
155 |     /**
156 |     Reads from client and forwards it to server. Boolean represents processing success, will be [false] when connection to either client or server fails.
157 |     Equivalent of calling [forward_to_target] and [forward_from_target] methods
158 |     (the second one is skipped if the first one fails)
159 |     */
160 |     pub fn process(&mut self) -> bool {
161 |         self.forward_to_target() && self.forward_from_target()
162 |     }
163 | 
164 |     /**
165 |     Forwards client messages to connected target. (Reads from client stream and writes to target stream)
166 |     Reads until the client stream would block: mio only reports readiness edges, so data left
167 |     in the socket would not produce another event. Each chunk is written with `write_all`,
168 |     because a single `write` may accept only part of it. A target that cannot take a whole
169 |     chunk right now (`WouldBlock`) is treated like any other target write error.
170 |     Returns [false] when the connection to the client or the target fails or is closed.
171 |     Panics if data was read but there is no target stream.
172 |     */
173 |     pub fn forward_to_target(&mut self) -> bool {
174 |         loop {
175 |             // READ FROM CLIENT
176 |             let len = match self.stream.read(&mut self.buffer) {
177 |                 Ok(0) => {
178 |                     // client closed the connection
179 |                     self.close_connection();
180 |                     return false;
181 |                 }
182 |                 Ok(n) => n,
183 |                 // everything available has been forwarded
184 |                 Err(ref e) if e.kind() == ErrorKind::WouldBlock => return true,
185 |                 Err(ref e) if e.kind() == ErrorKind::Interrupted => continue,
186 |                 Err(_) => {
187 |                     // error with connection to client
188 |                     self.close_connection();
189 |                     return false;
190 |                 }
191 |             };
192 | 
193 |             // WRITE TO SERVER
194 |             let target = self.target_stream.as_mut().expect("forward_to_target called without a target stream");
195 |             if target.write_all(&self.buffer[..len]).is_err() {
196 |                 // error with connection to server
197 |                 self.close_connection_to_target(true);
198 |                 return false;
199 |             }
200 |         }
201 |     }
202 | 
203 |     /**
204 |     Forwards connected target messages to client. (Reads from target stream and writes to client stream)
205 |     Like [forward_to_target], reads until the target stream would block and writes every chunk
206 |     completely with `write_all`; a client that cannot take a whole chunk right now is treated
207 |     like any other client write error.
208 |     Returns [false] when the connection to the target or the client fails or is closed.
209 |     Panics if there is no target stream.
210 |     */
211 |     pub fn forward_from_target(&mut self) -> bool {
212 |         loop {
213 |             let target = self.target_stream.as_mut().expect("forward_from_target called without a target stream");
214 | 
215 |             // READ FROM SERVER
216 |             let len = match target.read(&mut self.buffer) {
217 |                 Ok(0) => {
218 |                     // target closed the connection; the client itself stays connected
219 |                     self.close_connection_to_target(false);
220 |                     return false;
221 |                 }
222 |                 Ok(n) => n,
223 |                 // everything available has been forwarded
224 |                 Err(ref e) if e.kind() == ErrorKind::WouldBlock => return true,
225 |                 Err(ref e) if e.kind() == ErrorKind::Interrupted => continue,
226 |                 Err(_) => {
227 |                     // error with connection to server
228 |                     self.close_connection_to_target(true);
229 |                     return false;
230 |                 }
231 |             };
232 | 
233 |             // WRITE TO CLIENT
234 |             if self.stream.write_all(&self.buffer[..len]).is_err() {
235 |                 // error with connection to client
236 |                 self.close_connection();
237 |                 return false;
238 |             }
239 |         }
240 |     }
241 | 
242 |     /// Closes the connection to the target (shutting it down first if it was established) and
243 |     /// resets all target state. If `target_errored` is true, the current target is kept as the
244 |     /// last errored target so it can be reported to the balancing algorithm.
245 |     pub fn close_connection_to_target(&mut self, target_errored: bool) {
246 |         // if connected to target, disconnect - mark last connection loss
247 |         if self.is_connected {
248 |             if let Some(stream) = self.target_stream.as_ref() {
249 |                 // best effort: the stream is dropped below regardless
250 |                 let _ = stream.shutdown(Shutdown::Both);
251 |             }
252 | 
253 |             self.last_connection_loss = Instant::now();
254 |         }
255 | 
256 |         // mark error
257 |         if target_errored {
258 |             self.last_target = self.target;
259 |             self.last_target_error = true;
260 |         } else {
261 |             self.last_target = None;
262 |             self.last_target_error = false;
263 |         }
264 | 
265 |         // reset (dropping the stream closes the socket)
266 |         self.target = None;
267 |         self.target_stream = None;
268 | 
269 |         self.is_connected = false;
270 |         self.is_connecting = false;
271 |     }
272 | 
273 |     /// Shuts down the client connection (only the first time it is called) and then closes the
274 |     /// target connection as well.
275 |     pub fn close_connection(&mut self) {
276 |         if !self.is_client_connected {
277 |             return;
278 |         }
279 | 
280 |         // best effort: shutdown errors are ignored
281 |         let _ = self.stream.shutdown(Shutdown::Both);
282 |         self.is_client_connected = false;
283 | 
284 |         // also close connection to target if connected - there is no reason to stay connected if client is not
285 |         self.close_connection_to_target(false);
286 |     }
287 | }
288 | 
289 | impl Drop for TcpClient {
290 |     fn drop(&mut self) {
291 |         // close the client (and any target) connection when the client is discarded
292 |         self.close_connection();
293 |     }
294 | }
295 | 
--------------------------------------------------------------------------------
/src/balancer/host_manager.rs:
--------------------------------------------------------------------------------
1 | use std::fs::File;
2 | use std::io::BufRead;
3 | use std::io::BufReader;
4 | use std::io::Result;
5 | use std::net::SocketAddr;
6 | use std::net::ToSocketAddrs;
7 | use std::path::Path;
8 | use std::str;
9 | 
10 | pub struct HostManager {
11 |     pub hosts: Vec,
12 | }
13 | 
14 | impl HostManager {
15 |     pub fn new(hostfile: &str) -> Self {
16 |         if !Path::exists(Path::new(hostfile)) {
17 |             println!("[Parser] Host file '{}' does not exist. Please create it and try again.", hostfile);
18 | 
19 |             return HostManager { hosts: vec![] };
20 |         }
21 | 
22 |         let hosts = match HostManager::parse_hosts(hostfile) {
23 |             Ok(h) => h,
24 |             Err(err) => {
25 |                 println!("[Parser] Failed to parse host file '{}' -> {}", hostfile, err.to_string());
26 |                 vec![]
27 |             }
28 |         };
29 | 
30 |         return HostManager { hosts: hosts };
31 |     }
32 | 
33 |     fn parse_hosts(hostfile: &str) -> Result> {
34 |         let mut hosts: Vec = vec![];
35 | 
36 |         let file = File::open(hostfile)?;
37 |         let bufreader = BufReader::new(file);
38 | 
39 |         for line in bufreader.lines() {
40 |             let l = line?;
41 |             let l = l.trim();
42 |             if l.len() < 2 {
43 |                 continue;
44 |             }
45 | 
46 |             // validate IP address and port - either IPv4 or IPv6 with valid port number
47 |             // this also accepts domains and tries to resolve them, the first resolved IP is used
48 |             let addr: Vec = match l.to_socket_addrs() {
49 |                 Ok(a) => a.collect(),
50 |                 Err(_) => {
51 |                     println!("[Parser] Invalid host: '{}'", l);
52 |                     continue;
53 |                 }
54 |             };
55 | 
56 |             let mut resolved_addr: SocketAddr = addr[0];
57 | 
58 |             // if there are more than 1 IP resolved, prioritize the IPv4
59 |             if addr.len() > 1 {
60 | 
for a in addr { 61 | if a.is_ipv4() { 62 | resolved_addr = a; 63 | break; 64 | } 65 | } 66 | } 67 | 68 | // push the resolved IP onto hosts list 69 | hosts.push(resolved_addr); 70 | } 71 | 72 | println!("[Parser] Registered {} valid hosts", hosts.len()); 73 | Ok(hosts) 74 | } 75 | } 76 | -------------------------------------------------------------------------------- /src/balancer/mod.rs: -------------------------------------------------------------------------------- 1 | mod client; 2 | mod balancer; 3 | mod host_manager; 4 | mod balancing_algorithm; 5 | mod algorithms; 6 | mod poller; 7 | 8 | pub use client::TcpClient; 9 | pub use balancer::LoadBalancer; 10 | pub use host_manager::HostManager; 11 | pub use balancing_algorithm::BalancingAlgorithm; 12 | pub use algorithms::RoundRobin; 13 | pub use poller::Poller; -------------------------------------------------------------------------------- /src/balancer/poller.rs: -------------------------------------------------------------------------------- 1 | use std::io::{ErrorKind, Result}; 2 | use std::sync::{Arc, RwLock}; 3 | use std::thread; 4 | use std::time::Duration; 5 | 6 | use mio::net::{TcpListener}; 7 | use mio::{Events, Interest, Poll, Token}; 8 | 9 | use super::LoadBalancer; 10 | 11 | pub struct Poller { 12 | balancer: LoadBalancer, 13 | should_cancel: Arc>, 14 | } 15 | 16 | impl Poller { 17 | pub fn new(mut balancer: LoadBalancer) -> Self { 18 | let should_cancel = Arc::new(RwLock::new(false)); 19 | balancer.start(); 20 | 21 | let mut p = Poller { 22 | balancer, 23 | should_cancel, 24 | }; 25 | 26 | p.initialize().unwrap(); 27 | 28 | p 29 | } 30 | 31 | fn initialize(&mut self) -> Result<()> { 32 | // prepare the ctrl+c handler for graceful stop 33 | let cancel = Arc::clone(&self.should_cancel); 34 | ctrlc::set_handler(move || { 35 | *cancel.write().unwrap() = true; 36 | }) 37 | .expect("Failed to set Ctrl+C handler!"); 38 | 39 | Ok(()) 40 | } 41 | 42 | pub fn start_listening(&mut self, listening_port: i32) -> 
Result<()> { 43 | let addr = format!("0.0.0.0:{}", listening_port).parse().unwrap(); 44 | let mut listener = TcpListener::bind(addr)?; 45 | 46 | let mut poll = Poll::new().unwrap(); 47 | let mut events = Events::with_capacity(512); 48 | poll.registry().register(&mut listener, Token(0), Interest::READABLE)?; 49 | 50 | // START LISTENING 51 | println!("[Listener] Started listening on port {}", listening_port); 52 | loop { 53 | if *self.should_cancel.read().unwrap() { 54 | self.balancer.stop(); 55 | println!("[Listener] Listening stopped"); 56 | 57 | // sleep a bit to allow all threads to exit gracefully 58 | thread::sleep(Duration::from_millis(10)); 59 | break; 60 | } 61 | 62 | // poll for events here (with timeout to check of [should_cancel]) 63 | match poll.poll(&mut events, Some(Duration::from_millis(5))) { 64 | Ok(_) => {} 65 | Err(ref e) if e.kind() == ErrorKind::Interrupted => { 66 | // this handler does not get called on Windows, so we use timeout and check it outside 67 | *self.should_cancel.write().unwrap() = true; 68 | } 69 | Err(e) => { 70 | println!("Failed to poll for events! {}", e.to_string()); 71 | break; 72 | } 73 | }; 74 | 75 | if events.is_empty() { 76 | continue; 77 | } 78 | 79 | for event in events.iter() { 80 | match event.token() { 81 | _ => { 82 | // accept a new client 83 | let connection = match listener.accept() { 84 | Ok(c) => c, 85 | Err(ref e) if e.kind() == ErrorKind::WouldBlock => { continue; }, 86 | Err(e) => { 87 | println!("Failed to accept socket! 
{}", e.to_string()); 88 | continue; 89 | } 90 | }; 91 | 92 | // we need to reregister to set the Interest again, othewise we won't get any more readiness events (only on Windows) 93 | poll.registry().reregister(&mut listener, Token(0), Interest::READABLE).unwrap(); 94 | self.balancer.add_client(connection.0); 95 | } 96 | } 97 | } 98 | } 99 | 100 | Ok(()) 101 | } 102 | } 103 | -------------------------------------------------------------------------------- /src/main.rs: -------------------------------------------------------------------------------- 1 | use std::io::Result; 2 | use std::process::exit; 3 | 4 | mod balancer; 5 | use balancer::Poller; 6 | use balancer::RoundRobin; 7 | use balancer::{HostManager, LoadBalancer}; 8 | fn main() -> Result<()> { 9 | // PARSE HOSTS 10 | let host_manager = HostManager::new("hosts"); 11 | if host_manager.hosts.len() == 0 { 12 | return Ok(()); 13 | } 14 | 15 | // INITIALIZE 16 | let debug_mode = true; 17 | let round_robin = RoundRobin::new(host_manager); 18 | let balancer = LoadBalancer::new(round_robin, 4, debug_mode); 19 | let mut poller = Poller::new(balancer); 20 | 21 | // PARSE PORT 22 | let port = get_port().unwrap_or_else(|| { 23 | println!("Invalid listening port provided!"); 24 | exit(1); 25 | }); 26 | 27 | // START 28 | poller.start_listening(port).unwrap_or_else(|e| { 29 | println!("{}", e.to_string()); 30 | exit(2); 31 | }); 32 | 33 | Ok(()) 34 | } 35 | 36 | fn get_port() -> Option { 37 | let listening_port = std::env::args().nth(1)?; 38 | let port: i32 = match listening_port.parse() { 39 | Ok(p) => p, 40 | Err(_) => return None, 41 | }; 42 | 43 | if port <= 0 || port > 65535 { 44 | return None; 45 | } 46 | 47 | Some(port) 48 | } 49 | --------------------------------------------------------------------------------