├── .gitignore ├── Cargo.toml ├── LICENSE ├── README.md ├── examples └── kvstore │ ├── .gitignore │ ├── Cargo.toml │ └── src │ └── main.rs ├── logo.png └── src ├── lib.rs └── test.rs /.gitignore: -------------------------------------------------------------------------------- 1 | /target 2 | /Cargo.lock 3 | -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "redcon" 3 | version = "0.1.2" 4 | authors = ["Josh Baker "] 5 | edition = "2021" 6 | license = "MIT" 7 | readme = "README.md" 8 | repository = "https://github.com/tidwall/redcon.rs" 9 | documentation = "https://docs.rs/redcon/" 10 | description = "Redis compatible server framework for Rust" 11 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2022 Josh Baker 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy of 6 | this software and associated documentation files (the "Software"), to deal in 7 | the Software without restriction, including without limitation the rights to 8 | use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of 9 | the Software, and to permit persons to whom the Software is furnished to do so, 10 | subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS 17 | FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE AUTHORS OR 18 | COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER 19 | IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN 20 | CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

2 |

Redis compatible server framework for Rust

3 | 4 | ## Features 5 | 6 | - Create a fast custom Redis compatible server in Rust 7 | - Simple API. 8 | - Support for pipelining and telnet commands. 9 | - Works with Redis clients such as [redis-rs](https://github.com/redis-rs/redis-rs), [redigo](https://github.com/garyburd/redigo), [redis-py](https://github.com/andymccurdy/redis-py), [node_redis](https://github.com/NodeRedis/node_redis), and [jedis](https://github.com/xetorthio/jedis) 10 | - Multithreaded 11 | 12 | *This library is also available for [Go](https://github.com/tidwall/redcon) and [C](https://github.com/tidwall/redcon.c).* 13 | 14 | ## Example 15 | 16 | Here's a full [example](examples/kvstore) of a Redis clone that accepts: 17 | 18 | - SET key value 19 | - GET key 20 | - DEL key 21 | - PING 22 | - QUIT 23 | 24 | ```rust 25 | use std::collections::HashMap; 26 | use std::sync::Mutex; 27 | 28 | fn main() { 29 | let db: Mutex<HashMap<Vec<u8>, Vec<u8>>> = Mutex::new(HashMap::new()); 30 | 31 | let mut s = redcon::listen("127.0.0.1:6380", db).unwrap(); 32 | s.command = Some(|conn, db, args|{ 33 | let name = String::from_utf8_lossy(&args[0]).to_lowercase(); 34 | match name.as_str() { 35 | "ping" => conn.write_string("PONG"), 36 | "set" => { 37 | if args.len() < 3 { 38 | conn.write_error("ERR wrong number of arguments"); 39 | return; 40 | } 41 | let mut db = db.lock().unwrap(); 42 | db.insert(args[1].to_owned(), args[2].to_owned()); 43 | conn.write_string("OK"); 44 | } 45 | "get" => { 46 | if args.len() < 2 { 47 | conn.write_error("ERR wrong number of arguments"); 48 | return; 49 | } 50 | let db = db.lock().unwrap(); 51 | match db.get(&args[1]) { 52 | Some(val) => conn.write_bulk(val), 53 | None => conn.write_null(), 54 | } 55 | } 56 | _ => conn.write_error("ERR unknown command"), 57 | } 58 | }); 59 | println!("Serving at {}", s.local_addr()); 60 | s.serve().unwrap(); 61 | } 62 | ``` 63 | -------------------------------------------------------------------------------- /examples/kvstore/.gitignore: 
-------------------------------------------------------------------------------- 1 | /target 2 | /Cargo.lock 3 | -------------------------------------------------------------------------------- /examples/kvstore/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "kvstore" 3 | version = "0.1.0" 4 | edition = "2021" 5 | 6 | # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html 7 | 8 | [dependencies] 9 | redcon = "0.1" -------------------------------------------------------------------------------- /examples/kvstore/src/main.rs: -------------------------------------------------------------------------------- 1 | use std::collections::HashMap; 2 | use std::sync::Mutex; 3 | 4 | fn main() { 5 | let db: Mutex, Vec>> = Mutex::new(HashMap::new()); 6 | 7 | let mut s = redcon::listen("127.0.0.1:6380", db).unwrap(); 8 | s.command = Some(|conn, db, args|{ 9 | let name = String::from_utf8_lossy(&args[0]).to_lowercase(); 10 | match name.as_str() { 11 | "ping" => conn.write_string("PONG"), 12 | "quit" => { 13 | conn.write_string("OK"); 14 | conn.close(); 15 | } 16 | "set" => { 17 | if args.len() != 3 { 18 | conn.write_error("ERR wrong number of arguments"); 19 | return; 20 | } 21 | let mut db = db.lock().unwrap(); 22 | db.insert(args[1].to_owned(), args[2].to_owned()); 23 | conn.write_string("OK"); 24 | } 25 | "get" => { 26 | if args.len() != 2 { 27 | conn.write_error("ERR wrong number of arguments"); 28 | return; 29 | } 30 | let db = db.lock().unwrap(); 31 | match db.get(&args[1]) { 32 | Some(val) => conn.write_bulk(val), 33 | None => conn.write_null(), 34 | } 35 | } 36 | "del" => { 37 | if args.len() != 2 { 38 | conn.write_error("ERR wrong number of arguments"); 39 | return; 40 | } 41 | let mut db = db.lock().unwrap(); 42 | match db.remove(&args[1]) { 43 | Some(_) => conn.write_integer(1), 44 | None => conn.write_integer(0), 45 | } 46 | } 47 | _ => conn.write_error("ERR 
unknown command"), 48 | } 49 | }); 50 | println!("Serving at {}", s.local_addr()); 51 | s.serve().unwrap(); 52 | } 53 | -------------------------------------------------------------------------------- /logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/atlas-2192/Redis_Server_Rust/41646aa0e776aed130e0237185a67e2a114d5587/logo.png -------------------------------------------------------------------------------- /src/lib.rs: -------------------------------------------------------------------------------- 1 | // Copyright 2022 Joshua J Baker. All rights reserved. 2 | // Use of this source code is governed by an MIT-style 3 | // license that can be found in the LICENSE file. 4 | 5 | //! Redis compatible server framework for Rust 6 | 7 | #[cfg(test)] 8 | mod test; 9 | 10 | use std::any::Any; 11 | use std::collections::HashMap; 12 | use std::io; 13 | use std::io::{BufRead, BufReader, Read, Write}; 14 | use std::net::{SocketAddr, TcpListener, TcpStream, ToSocketAddrs}; 15 | use std::sync::atomic::{AtomicBool, AtomicI64, Ordering}; 16 | use std::sync::{Arc, Mutex}; 17 | use std::thread; 18 | use std::time::Duration; 19 | 20 | /// A error type that is returned by the [`listen`] and [`Server::serve`] 21 | /// functions and passed to the [`Server::closed`] handler. 22 | #[derive(Debug)] 23 | pub enum Error { 24 | // A Protocol error that was caused by malformed input by the client 25 | // connection. 26 | Protocol(String), 27 | // An I/O error that was caused by the network, such as a closed TCP 28 | // connection, or a failure or listen on at a socket address. 
29 | IoError(io::Error), 30 | } 31 | 32 | impl From for Error { 33 | fn from(error: io::Error) -> Self { 34 | Error::IoError(error) 35 | } 36 | } 37 | 38 | impl Error { 39 | fn new(msg: &str) -> Error { 40 | Error::Protocol(msg.to_owned()) 41 | } 42 | } 43 | 44 | impl std::fmt::Display for Error { 45 | fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 46 | match &self { 47 | Error::Protocol(s) => write!(f, "{}", s), 48 | Error::IoError(e) => write!(f, "{}", e), 49 | } 50 | } 51 | } 52 | 53 | impl std::error::Error for Error {} 54 | 55 | /// A client connection. 56 | pub struct Conn { 57 | id: u64, 58 | addr: SocketAddr, 59 | reader: BufReader>, 60 | wbuf: Vec, 61 | writer: Box, 62 | closed: bool, 63 | shutdown: bool, 64 | cmds: Vec>>, 65 | conns: Arc>>>, 66 | /// A custom user-defined context. 67 | pub context: Option>, 68 | } 69 | 70 | impl Conn { 71 | /// A distinct identifier for the connection. 72 | pub fn id(&self) -> u64 { 73 | self.id 74 | } 75 | 76 | /// The connection socket address. 77 | pub fn addr(&self) -> &SocketAddr { 78 | &self.addr 79 | } 80 | 81 | /// Read the next command in the pipeline, if any. 82 | /// 83 | /// This method is not typically needed, but it can be used for reading 84 | /// additional incoming commands that may be present. Which, may come in 85 | /// handy for specialized stuff like batching operations or for optimizing 86 | /// a locking strategy. 87 | pub fn next_command(&mut self) -> Option>> { 88 | self.cmds.pop() 89 | } 90 | 91 | /// Write a RESP Simple String to client connection. 92 | /// 93 | /// 94 | pub fn write_string(&mut self, msg: &str) { 95 | if !self.closed { 96 | self.extend_lossy_line(b'+', msg); 97 | } 98 | } 99 | /// Write a RESP Null Bulk String to client connection. 100 | /// 101 | /// 102 | pub fn write_null(&mut self) { 103 | if !self.closed { 104 | self.wbuf.extend("$-1\r\n".as_bytes()); 105 | } 106 | } 107 | /// Write a RESP Error to client connection. 
108 | /// 109 | /// 110 | pub fn write_error(&mut self, msg: &str) { 111 | if !self.closed { 112 | self.extend_lossy_line(b'-', msg); 113 | } 114 | } 115 | /// Write a RESP Integer to client connection. 116 | /// 117 | /// 118 | pub fn write_integer(&mut self, x: i64) { 119 | if !self.closed { 120 | self.wbuf.extend(format!(":{}\r\n", x).as_bytes()); 121 | } 122 | } 123 | /// Write a RESP Array to client connection. 124 | /// 125 | /// 126 | pub fn write_array(&mut self, count: usize) { 127 | if !self.closed { 128 | self.wbuf.extend(format!("*{}\r\n", count).as_bytes()); 129 | } 130 | } 131 | /// Write a RESP Bulk String to client connection. 132 | /// 133 | /// 134 | pub fn write_bulk(&mut self, msg: &[u8]) { 135 | if !self.closed { 136 | self.wbuf.extend(format!("${}\r\n", msg.len()).as_bytes()); 137 | self.wbuf.extend(msg); 138 | self.wbuf.push(b'\r'); 139 | self.wbuf.push(b'\n'); 140 | } 141 | } 142 | 143 | /// Write raw bytes to the client connection. 144 | pub fn write_raw(&mut self, raw: &[u8]) { 145 | if !self.closed { 146 | self.wbuf.extend(raw); 147 | } 148 | } 149 | 150 | /// Close the client connection. 151 | pub fn close(&mut self) { 152 | self.closed = true; 153 | } 154 | 155 | /// Shutdown the server that was started by [`Server::serve`]. 156 | /// 157 | /// This operation will gracefully shutdown the server by closing the all 158 | /// client connections, stopping the server listener, and waiting for the 159 | /// server resources to free. 160 | pub fn shutdown(&mut self) { 161 | self.closed = true; 162 | self.shutdown = true; 163 | } 164 | 165 | /// Close a client connection that is not this one. 166 | /// 167 | /// The identifier is for a client connection that was connection to the 168 | /// same server as `self`. This operation can safely be called on the same 169 | /// identifier multiple time. 
170 | pub fn cross_close(&mut self, id: u64) { 171 | if let Some(xcloser) = self.conns.lock().unwrap().get(&id) { 172 | xcloser.store(true, Ordering::SeqCst); 173 | } 174 | } 175 | 176 | fn pl_read_array(&mut self, line: Vec) -> Result>>, Error> { 177 | let n = match String::from_utf8_lossy(&line[1..]).parse::() { 178 | Ok(n) => n, 179 | Err(_) => { 180 | return Err(Error::new("invalid multibulk length")); 181 | } 182 | }; 183 | let mut arr = Vec::new(); 184 | for _ in 0..n { 185 | let line = match self.pl_read_line()? { 186 | Some(line) => line, 187 | None => return Ok(None), 188 | }; 189 | if line.len() == 0 { 190 | return Err(Error::new("expected '$', got ' '")); 191 | } 192 | if line[0] != b'$' { 193 | return Err(Error::new(&format!( 194 | "expected '$', got '{}'", 195 | if line[0] < 20 || line[0] > b'~' { 196 | ' ' 197 | } else { 198 | line[0] as char 199 | }, 200 | ))); 201 | } 202 | let n = match String::from_utf8_lossy(&line[1..]).parse::() { 203 | Ok(n) => n, 204 | Err(_) => -1, 205 | }; 206 | if n < 0 || n > 536870912 { 207 | // Spec limits the number of bytes in a bulk. 208 | // https://redis.io/docs/reference/protocol-spec 209 | return Err(Error::new("invalid bulk length")); 210 | } 211 | let mut buf = vec![0u8; n as usize]; 212 | self.reader.read_exact(&mut buf)?; 213 | let mut crnl = [0u8; 2]; 214 | self.reader.read_exact(&mut crnl)?; 215 | // Actual redis ignores the last two characters even though 216 | // they should be looking for '\r\n'. 
217 | arr.push(buf); 218 | } 219 | Ok(Some(arr)) 220 | } 221 | fn pl_read_line(&mut self) -> Result>, Error> { 222 | let mut line = Vec::new(); 223 | let size = self.reader.read_until(b'\n', &mut line)?; 224 | if size == 0 { 225 | return Ok(None); 226 | } 227 | if line.len() > 1 && line[line.len() - 2] == b'\r' { 228 | line.truncate(line.len() - 2); 229 | } else { 230 | line.truncate(line.len() - 1); 231 | } 232 | Ok(Some(line)) 233 | } 234 | fn pl_read_inline(&mut self, line: Vec) -> Result>>, Error> { 235 | const UNBALANCED: &str = "unbalanced quotes in request"; 236 | let mut arr = Vec::new(); 237 | let mut arg = Vec::new(); 238 | let mut i = 0; 239 | loop { 240 | if i >= line.len() || line[i] == b' ' || line[i] == b'\t' { 241 | if arg.len() > 0 { 242 | arr.push(arg); 243 | arg = Vec::new(); 244 | } 245 | if i >= line.len() { 246 | break; 247 | } 248 | } else if line[i] == b'\'' || line[i] == b'\"' { 249 | let quote = line[i]; 250 | i += 1; 251 | loop { 252 | if i == line.len() { 253 | return Err(Error::new(UNBALANCED)); 254 | } 255 | if line[i] == quote { 256 | i += 1; 257 | break; 258 | } 259 | if line[i] == b'\\' && quote == b'"' { 260 | if i == line.len() - 1 { 261 | return Err(Error::new(UNBALANCED)); 262 | } 263 | i += 1; 264 | match line[i] { 265 | b't' => arg.push(b'\t'), 266 | b'n' => arg.push(b'\n'), 267 | b'r' => arg.push(b'\r'), 268 | b'b' => arg.push(8), 269 | b'v' => arg.push(11), 270 | b'x' => { 271 | if line.len() < 3 { 272 | return Err(Error::new(UNBALANCED)); 273 | } 274 | let hline = &line[i + 1..i + 3]; 275 | let hex = String::from_utf8_lossy(hline); 276 | match u8::from_str_radix(&hex, 16) { 277 | Ok(b) => arg.push(b), 278 | Err(_) => arg.extend(&line[i..i + 3]), 279 | } 280 | i += 2; 281 | } 282 | _ => arg.push(line[i]), 283 | } 284 | } else { 285 | arg.push(line[i]); 286 | } 287 | i += 1; 288 | } 289 | if i < line.len() && line[i] != b' ' && line[i] != b'\t' { 290 | return Err(Error::new(UNBALANCED)); 291 | } 292 | } else { 293 | 
arg.push(line[i]); 294 | } 295 | i += 1; 296 | } 297 | Ok(Some(arr)) 298 | } 299 | 300 | // Read a pipeline of commands. 301 | // Each command will *always* have at least one argument. 302 | fn read_pipeline(&mut self) -> Result>>, Error> { 303 | let mut cmds = Vec::new(); 304 | loop { 305 | // read line 306 | let line = match self.pl_read_line()? { 307 | Some(line) => line, 308 | None => { 309 | self.closed = true; 310 | break; 311 | } 312 | }; 313 | if line.len() == 0 { 314 | // empty lines are ignored. 315 | continue; 316 | } 317 | let args = if line[0] == b'*' { 318 | // read RESP array 319 | self.pl_read_array(line)? 320 | } else { 321 | // read inline array 322 | self.pl_read_inline(line)? 323 | }; 324 | let args = match args { 325 | Some(args) => args, 326 | None => { 327 | self.closed = true; 328 | break; 329 | } 330 | }; 331 | if args.len() > 0 { 332 | cmds.push(args); 333 | } 334 | if cmds.len() > 0 && self.reader.buffer().len() == 0 { 335 | break; 336 | } 337 | } 338 | Ok(cmds) 339 | } 340 | 341 | fn extend_lossy_line(&mut self, prefix: u8, msg: &str) { 342 | self.wbuf.push(prefix); 343 | for b in msg.bytes() { 344 | self.wbuf.push(if b < b' ' { b' ' } else { b }) 345 | } 346 | self.wbuf.push(b'\r'); 347 | self.wbuf.push(b'\n'); 348 | } 349 | 350 | fn flush(&mut self) -> Result<(), Error> { 351 | if self.wbuf.len() > 0 { 352 | self.writer.write_all(&self.wbuf)?; 353 | if self.wbuf.len() > 1048576 { 354 | self.wbuf = Vec::new(); 355 | } else { 356 | self.wbuf.truncate(0); 357 | } 358 | } 359 | Ok(()) 360 | } 361 | } 362 | 363 | pub struct Server { 364 | listener: Option, 365 | data: Option, 366 | local_addr: SocketAddr, 367 | 368 | /// Handle incoming RESP commands. 369 | pub command: Option>)>, 370 | 371 | /// Handle incoming connections. 372 | pub opened: Option, 373 | 374 | /// Handle closed connections. 375 | /// 376 | /// If the connection was closed due to an error then that error is 377 | /// provided. 
378 | pub closed: Option)>, 379 | 380 | /// Handle ticks at intervals as defined by the returned [`Duration`]. 381 | /// 382 | /// The next tick will happen following the elapsed returned `Duration`. 383 | /// 384 | /// Returning `None` will shutdown the server. 385 | pub tick: Option Option>, 386 | } 387 | 388 | /// Creates a new `Server` which will be listening for incoming connections on 389 | /// the specified address using the provided `data`. 390 | /// 391 | /// The returned server is ready for serving. 392 | /// 393 | /// # Examples 394 | /// 395 | /// Creates a Redcon server listening at `127.0.0.1:6379`: 396 | /// 397 | /// ```no_run 398 | /// let my_data = "hello"; 399 | /// let server = redcon::listen("127.0.0.1:6379", my_data).unwrap(); 400 | /// ``` 401 | pub fn listen(addr: A, data: T) -> Result, Error> { 402 | let listener = TcpListener::bind(addr)?; 403 | let local_addr = listener.local_addr()?; 404 | let svr = Server { 405 | data: Some(data), 406 | listener: Some(listener), 407 | local_addr: local_addr, 408 | command: None, 409 | opened: None, 410 | closed: None, 411 | tick: None, 412 | }; 413 | Ok(svr) 414 | } 415 | 416 | impl Server { 417 | pub fn serve(&mut self) -> Result<(), Error> { 418 | serve(self) 419 | } 420 | pub fn local_addr(&self) -> SocketAddr { 421 | self.local_addr 422 | } 423 | } 424 | 425 | fn serve(s: &mut Server) -> Result<(), Error> { 426 | // Take all of the server fields at once. 
427 | let listener = match s.listener.take() { 428 | Some(listener) => listener, 429 | None => return Err(Error::IoError(io::Error::from(io::ErrorKind::Other))), 430 | }; 431 | let data = s.data.take().unwrap(); 432 | let command = s.command.take(); 433 | let opened = s.opened.take(); 434 | let closed = s.closed.take(); 435 | let tick = s.tick.take(); 436 | let laddr = s.local_addr; 437 | drop(s); 438 | 439 | let wg = Arc::new(AtomicI64::new(0)); 440 | let conns: HashMap> = HashMap::new(); 441 | let conns = Arc::new(Mutex::new(conns)); 442 | let data = Arc::new(data); 443 | let mut next_id: u64 = 1; 444 | let shutdown = Arc::new(AtomicBool::new(false)); 445 | let init_shutdown = |shutdown: Arc, laddr: &SocketAddr| { 446 | let aord = Ordering::SeqCst; 447 | if shutdown.compare_exchange(false, true, aord, aord).is_err() { 448 | // Shutdown has already been initiated. 449 | return; 450 | } 451 | // Connect to self to force initiate a shutdown. 452 | let _ = TcpStream::connect(&laddr); 453 | }; 454 | if let Some(tick) = tick { 455 | let data = data.clone(); 456 | let shutdown = shutdown.clone(); 457 | let wg = wg.clone(); 458 | wg.fetch_add(1, Ordering::SeqCst); 459 | thread::spawn(move || { 460 | while !shutdown.load(Ordering::SeqCst) { 461 | match (tick)(&data) { 462 | Some(delay) => thread::sleep(delay), 463 | None => { 464 | init_shutdown(shutdown.clone(), &laddr); 465 | break; 466 | } 467 | } 468 | } 469 | wg.fetch_add(-1, Ordering::SeqCst) 470 | }); 471 | } 472 | for stream in listener.incoming() { 473 | let shutdown = shutdown.clone(); 474 | if shutdown.load(Ordering::SeqCst) { 475 | break; 476 | } 477 | match stream { 478 | Ok(stream) => { 479 | if stream 480 | .set_read_timeout(Some(Duration::from_millis(100))) 481 | .is_err() 482 | { 483 | continue; 484 | } 485 | let addr = match stream.peer_addr() { 486 | Ok(addr) => addr, 487 | _ => continue, 488 | }; 489 | // create two streams (input, output) 490 | let streams = ( 491 | match stream.try_clone() { 492 | 
Ok(stream) => stream, 493 | _ => continue, 494 | }, 495 | stream, 496 | ); 497 | let data = data.clone(); 498 | let conn_id = next_id; 499 | next_id += 1; 500 | let xcloser = Arc::new(AtomicBool::new(false)); 501 | let conns = conns.clone(); 502 | conns.lock().unwrap().insert(conn_id, xcloser.clone()); 503 | let wg = wg.clone(); 504 | wg.fetch_add(1, Ordering::SeqCst); 505 | thread::spawn(move || { 506 | let mut conn = Conn { 507 | id: conn_id, 508 | cmds: Vec::new(), 509 | context: None, 510 | addr, 511 | reader: BufReader::new(Box::new(streams.0)), 512 | wbuf: Vec::new(), 513 | writer: Box::new(streams.1), 514 | closed: false, 515 | shutdown: false, 516 | conns: conns.clone(), 517 | }; 518 | let mut final_err: Option = None; 519 | if let Some(opened) = opened { 520 | (opened)(&mut conn, &data); 521 | } 522 | loop { 523 | if let Err(e) = conn.flush() { 524 | if final_err.is_none() { 525 | final_err = Some(From::from(e)); 526 | } 527 | conn.closed = true; 528 | } 529 | if conn.closed { 530 | break; 531 | } 532 | match conn.read_pipeline() { 533 | Ok(cmds) => { 534 | conn.cmds = cmds; 535 | conn.cmds.reverse(); 536 | while let Some(cmd) = conn.next_command() { 537 | if let Some(command) = command { 538 | (command)(&mut conn, &data, cmd); 539 | } 540 | if conn.closed { 541 | break; 542 | } 543 | } 544 | } 545 | Err(e) => { 546 | if let Error::Protocol(msg) = &e { 547 | // Write the protocol error to the 548 | // client before closing the connection. 549 | conn.write_error(&format!("ERR Protocol error: {}", msg)); 550 | } else if let Error::IoError(e) = &e { 551 | if let io::ErrorKind::WouldBlock = e.kind() { 552 | // Look to see if there is a pending 553 | // server shutdown or a cross close 554 | // request. 
555 | if shutdown.load(Ordering::SeqCst) { 556 | conn.closed = true; 557 | } 558 | if xcloser.load(Ordering::SeqCst) { 559 | conn.closed = true; 560 | } 561 | continue; 562 | } 563 | } 564 | final_err = Some(e); 565 | conn.closed = true; 566 | } 567 | } 568 | } 569 | if conn.shutdown { 570 | init_shutdown(shutdown.clone(), &laddr); 571 | } 572 | if let Some(closed) = closed { 573 | (closed)(&mut conn, &data, final_err); 574 | } 575 | conns.lock().unwrap().remove(&conn.id); 576 | wg.fetch_add(-1, Ordering::SeqCst); 577 | }); 578 | } 579 | Err(_) => {} 580 | } 581 | } 582 | // Wait for all connections to complete and for their threads to terminate. 583 | while wg.load(Ordering::SeqCst) > 0 { 584 | thread::sleep(Duration::from_millis(10)); 585 | } 586 | Ok(()) 587 | } 588 | -------------------------------------------------------------------------------- /src/test.rs: -------------------------------------------------------------------------------- 1 | use super::*; 2 | use std::net::TcpStream; 3 | use std::sync::Mutex; 4 | use std::thread; 5 | use std::time::{Duration, Instant}; 6 | 7 | struct Data { 8 | value1: usize, 9 | value2: usize, 10 | } 11 | 12 | #[test] 13 | fn connections() { 14 | const N: usize = 11; 15 | const ADDR: &str = "127.0.0.1:11099"; 16 | let db = Arc::new(Mutex::new(Data { 17 | value1: 0, 18 | value2: 0, 19 | })); 20 | for i in 0..N { 21 | thread::spawn(move || test_conn(i, ADDR).unwrap()); 22 | } 23 | let mut s = listen(ADDR, db.clone()).unwrap(); 24 | s.command = Some(|conn, data, args| { 25 | let name = String::from_utf8_lossy(&args[0]).to_lowercase(); 26 | match name.as_str() { 27 | "shutdown" => { 28 | conn.write_string("OK"); 29 | conn.shutdown(); 30 | } 31 | "incr" => { 32 | data.lock().unwrap().value2 += 1; 33 | conn.write_string("OK"); 34 | } 35 | _ => { 36 | conn.write_error(&format!("ERR unknown command '{}'", name)); 37 | } 38 | } 39 | }); 40 | s.opened = Some(|conn, data| { 41 | data.lock().unwrap().value1 += 1; 42 | if conn.id() == 5 
{ 43 | conn.write_error("ERR unauthorized"); 44 | conn.close(); 45 | } 46 | // println!("opened: {}", conn.id()); 47 | }); 48 | s.closed = Some(|_conn, data, _| { 49 | data.lock().unwrap().value1 -= 1; 50 | // println!("closed: {}", conn.id()); 51 | }); 52 | s.tick = Some(|data| { 53 | if data.lock().unwrap().value2 == N - 1 { 54 | None 55 | } else { 56 | Some(Duration::from_millis(10)) 57 | } 58 | }); 59 | s.serve().unwrap(); 60 | assert_eq!(db.lock().unwrap().value1, 0); 61 | assert_eq!(s.serve().is_err(), true); 62 | } 63 | 64 | fn make_conn(addr: &str) -> Result> { 65 | let start = Instant::now(); 66 | loop { 67 | thread::sleep(Duration::from_millis(10)); 68 | if start.elapsed() > Duration::from_secs(5) { 69 | return Err(From::from("Connection timeout")); 70 | } 71 | if let Ok(stream) = TcpStream::connect(addr) { 72 | return Ok(stream); 73 | } 74 | } 75 | } 76 | 77 | fn test_conn(_i: usize, addr: &str) -> Result<(), Box> { 78 | let mut wr = make_conn(addr)?; 79 | let mut rd = BufReader::new(wr.try_clone()?); 80 | wr.write("NooP 1 2 3\r\n".as_bytes())?; 81 | let mut line = String::new(); 82 | rd.read_line(&mut line)?; 83 | if line == "-ERR unauthorized\r\n" { 84 | return Ok(()); 85 | } 86 | assert_eq!(line, "-ERR unknown command 'noop'\r\n"); 87 | wr.write("incr\r\n".as_bytes())?; 88 | let mut line = String::new(); 89 | rd.read_line(&mut line)?; 90 | assert_eq!(line, "+OK\r\n"); 91 | Ok(()) 92 | } 93 | --------------------------------------------------------------------------------