├── LICENSE ├── bench ├── Cargo.toml ├── run-debug.sh ├── run-release.sh └── src │ ├── bench.rs │ ├── main.rs │ └── wait_group.rs ├── benchmark.md ├── client ├── Cargo.toml ├── examples │ ├── pub_subject.rs │ └── sub_subject.rs └── src │ ├── client.rs │ ├── error.rs │ ├── lib.rs │ └── parser.rs ├── readme.md └── server ├── Cargo.toml └── src ├── client.rs ├── error.rs ├── main.rs ├── parser.rs ├── server.rs ├── simple_sublist.rs └── sublist.rs /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 baizhenxuan 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /bench/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "bench" 3 | version = "0.1.0" 4 | authors = ["bai "] 5 | edition = "2018" 6 | 7 | # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html 8 | 9 | [dependencies] 10 | log="0.4" 11 | env_logger="0.7" 12 | serde="1.0" 13 | serde_json="1.0" 14 | serde_derive = "1.0" 15 | tokio = { version = "0.2.0", path = "../../tokio", features = ["full"] } #{ version = "0.2", features = ["full"] } 16 | tokio-util ={ version = "0.2.0", path = "../../tokio-util", features = ["full"] } #{ version = "0.2", features = ["full"] } # 17 | rand="0.7" 18 | bitflags="1.0" 19 | lazy_static="1.0" 20 | get_if_addrs="0.5" 21 | futures = { version = "0.3.0", features = ["async-await"] } 22 | bytes="0.5" 23 | csv="1.1" 24 | structopt="0.3" 25 | client={version="0.1",path="../client"} 26 | async-wg="0.1" 27 | futures-util = "0.3.3" -------------------------------------------------------------------------------- /bench/run-debug.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | # 一个pub,十个sub,默认消息数量100000 3 | cargo run -- --subject test --num-subs 10 4 | -------------------------------------------------------------------------------- /bench/run-release.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | cargo run --release -- --urls 127.0.0.1:4222 --subject test --num-subs 10 --num-msgs 100000 --num-pubs 10 -------------------------------------------------------------------------------- /bench/src/bench.rs: -------------------------------------------------------------------------------- 1 | use std::cmp::max; 2 | use std::fmt::{Error, Formatter}; 3 | 4 | use bytes::buf::BufMutExt; 5 | use bytes::BytesMut; 6 | use std::cmp::min; 7 | use std::fmt::Write; 
8 | use std::ops::Sub; 9 | use tokio::time::{Duration, Instant}; 10 | 11 | #[derive(Debug, Clone)] 12 | pub struct Sample { 13 | job_msg_cnt: usize, 14 | msg_count: u64, 15 | msg_bytes: u64, 16 | io_bytes: u64, 17 | start: Instant, 18 | end: Instant, 19 | } 20 | impl Default for Sample { 21 | fn default() -> Self { 22 | Self { 23 | job_msg_cnt: 0, 24 | msg_count: 0, 25 | msg_bytes: 0, 26 | io_bytes: 0, 27 | start: Instant::now() + Duration::from_secs(60 * 60 * 24), 28 | end: Instant::now() - Duration::from_secs(60 * 60 * 24), 29 | } 30 | } 31 | } 32 | impl Sample { 33 | pub fn new( 34 | job_count: usize, 35 | msg_size: usize, 36 | msg_count: u64, //发出和收到的消息总和 37 | io_bytes: u64, //发出和收到的字节数总和 38 | start: Instant, 39 | end: Instant, 40 | ) -> Self { 41 | Self { 42 | job_msg_cnt: job_count, 43 | msg_count, 44 | msg_bytes: (msg_size * job_count) as u64, 45 | io_bytes, 46 | start, 47 | end, 48 | } 49 | } 50 | // Throughput of bytes per second 51 | pub fn throughput(&self) -> f64 { 52 | self.msg_bytes as f64 / self.duration().as_secs_f64() 53 | } 54 | 55 | pub fn duration(&self) -> Duration { 56 | self.end.sub(self.start) 57 | } 58 | // Rate of meessages in the job per second 59 | pub fn rate(&self) -> i64 { 60 | (self.job_msg_cnt as f64 / self.duration().as_secs_f64()) as i64 61 | } 62 | pub fn add_statistics(&mut self, other: &Self) { 63 | self.msg_count += other.msg_count; 64 | self.job_msg_cnt += other.job_msg_cnt; 65 | self.io_bytes += other.io_bytes; 66 | self.msg_bytes += other.msg_bytes; 67 | if self.start > other.start { 68 | self.start = other.start; 69 | } 70 | if self.end < other.end { 71 | self.end = other.end 72 | } 73 | } 74 | } 75 | impl std::fmt::Display for Sample { 76 | fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), Error> { 77 | let rate = self.rate(); 78 | let throughput = self.throughput(); 79 | write!(f, "{} msgs/sec ~ {}/sec", rate, throughput) 80 | } 81 | } 82 | #[derive(Debug, Default, Clone)] 83 | struct SampleGroup { 84 | group_sample: 
Sample, 85 | samples: Vec, 86 | } 87 | impl SampleGroup { 88 | // AddSample adds a Sample to the SampleGroup. After adding a Sample it shouldn't be modified. 89 | pub fn add_sample(&mut self, s: Sample) { 90 | if self.samples.is_empty() { 91 | self.group_sample.start = s.start; 92 | self.group_sample.end = s.end; 93 | } 94 | self.group_sample.add_statistics(&s); 95 | if s.start < self.group_sample.start { 96 | self.group_sample.start = s.start; 97 | } 98 | if s.end > self.group_sample.end { 99 | self.group_sample.end = s.end; 100 | } 101 | self.samples.push(s) 102 | } 103 | 104 | pub fn has_samples(&self) -> bool { 105 | !self.samples.is_empty() 106 | } 107 | pub fn statistics(&self) -> String { 108 | format!( 109 | "min {} | avg {} | max {} | stddev {} msgs", 110 | self.min_rate(), 111 | self.avg_rate(), 112 | self.max_rate(), 113 | self.std_dev(), 114 | ) 115 | } 116 | // MinRate returns the smallest message rate in the SampleGroup 117 | pub fn min_rate(&self) -> i64 { 118 | let mut m = std::i64::MAX; 119 | for s in self.samples.iter() { 120 | m = min(m, s.rate()); 121 | } 122 | m 123 | } 124 | pub fn max_rate(&self) -> i64 { 125 | let mut m = std::i64::MIN; 126 | for s in self.samples.iter() { 127 | m = max(m, s.rate()); 128 | } 129 | m 130 | } 131 | pub fn avg_rate(&self) -> i64 { 132 | if self.samples.is_empty() { 133 | return 0; 134 | } 135 | let mut sum = 0; 136 | for s in self.samples.iter() { 137 | sum += s.rate(); 138 | } 139 | sum / self.samples.len() as i64 140 | } 141 | // StdDev returns the standard deviation the message rates in the SampleGroup 142 | //求速率的标准差 143 | pub fn std_dev(&self) -> f64 { 144 | if self.samples.is_empty() { 145 | return 0.0; 146 | } 147 | let avg = self.avg_rate() as f64; 148 | let mut sum = 0 as f64; 149 | for s in self.samples.iter() { 150 | sum += (s.rate() as f64 - avg).powf(2.0); 151 | } 152 | let variance = sum / self.samples.len() as f64; 153 | variance.sqrt() 154 | } 155 | } 156 | 157 | #[derive(Debug, Default, Clone)] 
158 | pub struct Benchmark { 159 | bench_sample: Sample, 160 | name: String, 161 | pubs: SampleGroup, 162 | subs: SampleGroup, 163 | } 164 | 165 | impl Benchmark { 166 | pub fn new>(name: S) -> Self { 167 | Self { 168 | bench_sample: Default::default(), 169 | name: name.into(), 170 | pubs: Default::default(), 171 | subs: Default::default(), 172 | } 173 | } 174 | pub fn add_pub_sample(&mut self, s: Sample) { 175 | self.bench_sample.add_statistics(&s); 176 | self.pubs.add_sample(s); 177 | } 178 | pub fn add_sub_sample(&mut self, s: Sample) { 179 | self.bench_sample.add_statistics(&s); 180 | self.subs.add_sample(s); 181 | } 182 | // Report returns a human readable report of the samples taken in the Benchmark 183 | pub fn report(&self) -> String { 184 | let mut buf = BytesMut::with_capacity(1024); 185 | let mut indent = "".to_string(); 186 | if !self.pubs.has_samples() && !self.subs.has_samples() { 187 | return "No publisher and subscribers".into(); 188 | } 189 | if self.pubs.has_samples() && self.subs.has_samples() { 190 | let _ = write!( 191 | buf, 192 | "{} pub/sub stats: {}\n", 193 | &self.name, &self.bench_sample 194 | ); 195 | indent += " "; 196 | } 197 | if self.pubs.has_samples() { 198 | let _ = write!(buf, "{}Pub stats: {}\n", indent, &self.pubs.group_sample); 199 | if self.pubs.samples.len() > 1 { 200 | for (i, stat) in self.pubs.samples.iter().enumerate() { 201 | let _ = write!( 202 | buf, 203 | "{} [{}] {} ({} msgs)\n", 204 | indent, 205 | i + 1, 206 | stat, 207 | stat.job_msg_cnt, 208 | ); 209 | } 210 | } 211 | } 212 | if self.subs.has_samples() { 213 | let _ = write!(buf, "{}Sub stats: {}\n", indent, self.subs.group_sample); 214 | if self.subs.samples.len() > 1 { 215 | for (i, stat) in self.subs.samples.iter().enumerate() { 216 | let _ = write!( 217 | buf, 218 | "{} [{}] {} ({} msgs)\n", 219 | indent, 220 | i + 1, 221 | stat, 222 | stat.job_msg_cnt, 223 | ); 224 | } 225 | } 226 | } 227 | let _ = write!(buf, "{} {}\n", indent, self.subs.statistics()); 228 
| String::from_utf8(buf.to_vec()).unwrap() 229 | } 230 | pub fn csv(&self) -> String { 231 | let buf = BytesMut::with_capacity(1024); 232 | let mut wtr = csv::Writer::from_writer(buf.writer()); 233 | let headers = vec![ 234 | "#RunID", 235 | "ClientID", 236 | "MsgCount", 237 | "MsgBytes", 238 | "MsgsPerSec", 239 | "BytesPerSec", 240 | "DurationSecs", 241 | ]; 242 | wtr.write_record(&headers).unwrap(); 243 | let mut pre = "S"; 244 | let groups = vec![&self.subs, &self.pubs]; 245 | for (i, g) in groups.iter().enumerate() { 246 | if i == 1 { 247 | pre = "P"; 248 | } 249 | for (j, s) in g.samples.iter().enumerate() { 250 | let r = vec![ 251 | "runid".into(), 252 | format!("{}{}", pre, j), 253 | format!("{}", s.msg_count), 254 | format!("{}", s.msg_bytes), 255 | format!("{}", s.rate()), 256 | format!("{}", s.throughput()), 257 | format!("{}", s.duration().as_secs_f64()), 258 | ]; 259 | wtr.write_record(r).unwrap(); 260 | } 261 | } 262 | wtr.flush().unwrap(); 263 | let buf = wtr.into_inner().unwrap().into_inner(); 264 | String::from_utf8(buf.to_vec()).unwrap() 265 | } 266 | } 267 | // MsgsPerClient divides the number of messages by the number of clients and tries to distribute them as evenly as possible 268 | pub fn msgs_per_client(num_msgs: usize, num_clients: usize) -> Vec { 269 | let mut counts = vec![0; num_clients]; 270 | if num_clients == 0 || num_msgs == 0 { 271 | return counts; 272 | } 273 | let mc = num_msgs / num_clients; 274 | for i in 0..num_clients { 275 | counts[i] = mc; 276 | } 277 | let extra = num_msgs % num_clients; 278 | for i in 0..extra { 279 | counts[i] += 1; 280 | } 281 | counts 282 | } 283 | #[cfg(test)] 284 | mod tests { 285 | use super::*; 286 | use std::ops::Add; 287 | 288 | const MSG_SIZE: usize = 8; 289 | const MILLION: usize = 1000 * 1000; 290 | use lazy_static::lazy_static; 291 | lazy_static! 
{ 292 | static ref BASE_TIME: Instant = { Instant::now() }; 293 | } 294 | fn million_messags_second_sample(seconds: i32) -> Sample { 295 | let messages = MILLION * seconds as usize; 296 | let start = BASE_TIME.clone(); 297 | let end = start.add(Duration::from_secs(seconds as u64)); 298 | let mut s = Sample::new(messages, MSG_SIZE, messages as u64, 0, start, end); 299 | s.msg_bytes = (messages * MSG_SIZE) as u64; 300 | s.io_bytes = s.msg_bytes; 301 | s 302 | } 303 | #[test] 304 | fn test() {} 305 | #[test] 306 | fn test_std_dev() { 307 | let mut sg = SampleGroup::default(); 308 | sg.add_sample(million_messags_second_sample(1)); 309 | sg.add_sample(million_messags_second_sample(1)); 310 | sg.add_sample(million_messags_second_sample(1)); 311 | assert_eq!(sg.std_dev(), 0.0); 312 | } 313 | #[test] 314 | fn test_bench_setup() { 315 | let mut bench = Benchmark::new("test"); 316 | bench.add_pub_sample(million_messags_second_sample(1)); 317 | bench.add_sub_sample(million_messags_second_sample(1)); 318 | assert_eq!(bench.pubs.samples.len(), 1); 319 | assert_eq!(bench.subs.samples.len(), 1); 320 | assert_eq!(bench.bench_sample.msg_count as usize, 2 * MILLION); 321 | assert_eq!(bench.bench_sample.io_bytes as usize, 2 * MILLION * MSG_SIZE); 322 | assert_eq!(bench.bench_sample.duration(), Duration::from_secs(1)); 323 | } 324 | fn make_bench(subs: usize, pubs: usize) -> Benchmark { 325 | let mut bench = Benchmark::default(); 326 | for _ in 0..subs { 327 | bench.add_sub_sample(million_messags_second_sample(1)); 328 | } 329 | for _ in 0..pubs { 330 | bench.add_pub_sample(million_messags_second_sample(1)); 331 | } 332 | bench 333 | } 334 | #[test] 335 | fn test_csv() { 336 | let bench = make_bench(2, 3); 337 | let csv = bench.csv(); 338 | println!("csv\n{}", csv); 339 | } 340 | #[test] 341 | fn test_report() { 342 | let bench = make_bench(2, 3); 343 | let r = bench.report(); 344 | println!("r=\n{}", r); 345 | } 346 | } 347 | 
-------------------------------------------------------------------------------- /bench/src/main.rs: -------------------------------------------------------------------------------- 1 | mod bench; 2 | mod wait_group; 3 | use crate::bench::{msgs_per_client, Benchmark, Sample}; 4 | use client::client::Client; 5 | use std::error::Error; 6 | use std::sync::Arc; 7 | use structopt::StructOpt; 8 | use tokio::sync::{oneshot, Mutex}; 9 | use tokio::time::Instant; 10 | use wait_group::WaitGroup; 11 | 12 | /// benchmark for simple nats 13 | #[derive(StructOpt, Debug, Clone)] 14 | #[structopt(name = "simple nats")] 15 | struct Opt { 16 | ///The nats server URLs (separated by comma) 17 | #[structopt(long, default_value = "127.0.0.1:4222")] 18 | urls: String, 19 | ///Save bench data to csv file 20 | #[structopt(long, default_value = "")] 21 | csv_file: String, 22 | ///Number of Concurrent Publishers 23 | #[structopt(long, default_value = "1")] 24 | num_pubs: usize, 25 | ///Number of Concurrent Subscribers 26 | #[structopt(long, default_value = "0")] 27 | num_subs: usize, 28 | ///Number of Messages to Publish 29 | #[structopt(long, default_value = "100000")] 30 | num_msgs: usize, 31 | ///Size of the message. 
32 | #[structopt(long, default_value = "128")] 33 | msg_size: usize, 34 | ///publish subject 35 | #[structopt(long, default_value = "test_subject")] 36 | subject: String, 37 | } 38 | #[tokio::main] 39 | async fn main() -> Result<(), Box> { 40 | let opt: Opt = Opt::from_args(); 41 | println!("opt={:?}", opt); 42 | println!("Hello, world!"); 43 | let mut start_wg = WaitGroup::new("start_wg1".into(), opt.num_subs); 44 | let mut done_wg = WaitGroup::new("donw_wg".into(), opt.num_pubs + opt.num_subs); 45 | let bench = Arc::new(Mutex::new(Benchmark::new("Nats"))); 46 | 47 | for _ in 0..opt.num_subs { 48 | let mut c = Client::connect(opt.urls.as_str()).await.unwrap(); 49 | let start_wg = start_wg.clone(); 50 | let done_wg = done_wg.clone(); 51 | let bench = bench.clone(); 52 | let opt = opt.clone(); 53 | tokio::spawn(async move { 54 | run_subscriber(&mut c, start_wg, done_wg, opt, bench).await; 55 | c.close(); 56 | println!("run_subscriber finished"); 57 | }); 58 | } 59 | println!("startwg1 start wait"); 60 | start_wg.wait().await; 61 | println!("subs all started."); 62 | let mut start_wg = WaitGroup::new("start_wg2".into(), opt.num_pubs); 63 | 64 | let pub_counts = msgs_per_client(opt.num_msgs, opt.num_pubs); 65 | for i in 0..opt.num_pubs { 66 | let mut c = Client::connect(opt.urls.as_str()).await.unwrap(); 67 | let start_wg = start_wg.clone(); 68 | let done_wg = done_wg.clone(); 69 | let bench = bench.clone(); 70 | let opt = opt.clone(); 71 | let num_msgs = pub_counts[i]; 72 | tokio::spawn(async move { 73 | run_publiser(&mut c, start_wg, done_wg, num_msgs, opt, bench).await; 74 | c.close(); 75 | println!("run_publiser finished"); 76 | }); 77 | } 78 | start_wg.wait().await; 79 | println!("pubs all started."); 80 | done_wg.wait().await; 81 | println!("all task stopped."); 82 | println!("{}\n", bench.lock().await.report()); 83 | if opt.csv_file.len() > 0 { 84 | tokio::fs::write(opt.csv_file.as_str(), bench.lock().await.csv()) 85 | .await 86 | .unwrap(); 87 | 
println!("saved metric data in csv file {}", opt.csv_file); 88 | } 89 | Ok(()) 90 | } 91 | 92 | async fn run_publiser( 93 | c: &mut Client, 94 | mut start_wg: WaitGroup, 95 | mut done_wg: WaitGroup, 96 | num_msgs: usize, 97 | opt: Opt, 98 | bench: Arc>, 99 | ) { 100 | start_wg.done().await; 101 | let msg = vec![0x4a; opt.msg_size]; 102 | let start = Instant::now(); 103 | let t = 0..num_msgs; 104 | let mut i = 0; 105 | let step = 455; 106 | let mut msgs = Vec::with_capacity(step); 107 | let mut subjects = Vec::with_capacity(step); 108 | while i < num_msgs { 109 | let mut j = i; 110 | let expect = i + step; 111 | while j < num_msgs && j < expect { 112 | msgs.push(msg.as_slice()); 113 | subjects.push(opt.subject.as_str()); 114 | j += 1; 115 | i += 1; 116 | } 117 | // println!("pub step"); 118 | if msgs.len() > 0 { 119 | // println!("send message len={}", subjects.len()); 120 | if let Err(e) = c.pub_messages(subjects.as_slice(), msgs.as_slice()).await { 121 | println!("pub message error {}", e); 122 | return; 123 | }; 124 | } 125 | msgs.clear(); 126 | subjects.clear(); 127 | } 128 | let s = Sample::new( 129 | num_msgs, 130 | opt.msg_size, 131 | num_msgs as u64, 132 | (num_msgs * opt.msg_size) as u64, 133 | start, 134 | Instant::now(), 135 | ); 136 | bench.lock().await.add_pub_sample(s); 137 | done_wg.done().await; 138 | println!("one pub stoped."); 139 | } 140 | 141 | async fn run_subscriber( 142 | c: &mut Client, 143 | mut start_wg: WaitGroup, 144 | mut done_wg: WaitGroup, 145 | opt: Opt, 146 | bench: Arc>, 147 | ) { 148 | start_wg.done().await; 149 | let start = Instant::now(); 150 | let mut received_msgs = 0; 151 | let mut received_bytes = 0; 152 | let (tx, rx) = oneshot::channel(); 153 | let mut tx = Some(tx); 154 | let expected_msgs = opt.num_msgs; 155 | let _ = c 156 | .sub_message( 157 | opt.subject.clone(), 158 | None, 159 | Box::new(move |msg| { 160 | received_msgs += 1; 161 | received_bytes += msg.len(); 162 | if received_msgs >= expected_msgs { 163 | if let 
Some(tx) = tx.take() { 164 | let _ = tx.send((received_msgs, received_bytes)); 165 | println!("sub message end."); 166 | } 167 | } 168 | Ok(()) 169 | }), 170 | ) 171 | .await; 172 | let (received_msgs, received_bytes) = rx.await.unwrap(); 173 | let s = Sample::new( 174 | opt.num_msgs, 175 | opt.msg_size, 176 | received_msgs as u64, 177 | received_bytes as u64, 178 | start, 179 | Instant::now(), 180 | ); 181 | bench.lock().await.add_sub_sample(s); 182 | println!("subsriber done"); 183 | done_wg.done().await; 184 | } 185 | #[test] 186 | fn test() {} 187 | -------------------------------------------------------------------------------- /bench/src/wait_group.rs: -------------------------------------------------------------------------------- 1 | //! # async-wg 2 | //! 3 | //! Async version WaitGroup for RUST. 4 | //! 5 | //! ## Examples 6 | //! 7 | //! ```rust 8 | //! #[tokio::main] 9 | //! async fn main() { 10 | //! use async_wg::WaitGroup; 11 | //! 12 | //! // Create a new wait group. 13 | //! let wg = WaitGroup::new(); 14 | //! 15 | //! for _ in 0..10 { 16 | //! let mut wg = wg.clone(); 17 | //! // Add count n. 18 | //! wg.add(1).await; 19 | //! 20 | //! tokio::spawn(async move { 21 | //! // Do some work. 22 | //! 23 | //! // Done count 1. 24 | //! wg.done().await; 25 | //! }); 26 | //! } 27 | //! 28 | //! // Wait for done count is equal to add count. 29 | //! wg.await; 30 | //! } 31 | //! ``` 32 | //! 33 | //! ## Benchmarks 34 | //! 35 | //! Simple benchmark comparison run on github actions. 36 | //! 37 | //! Code: [benchs/main.rs](https://github.com/jmjoy/async-wg/blob/master/benches/main.rs) 38 | //! 39 | //! ```text 40 | //! test bench_join_handle ... bench: 34,485 ns/iter (+/- 18,969) 41 | //! test bench_wait_group ... bench: 36,916 ns/iter (+/- 7,555) 42 | //! ``` 43 | //! 44 | //! ## License 45 | //! 46 | //! The Unlicense. 47 | //! 
48 | 49 | use futures_util::lock::Mutex; 50 | use std::future::Future; 51 | use std::pin::Pin; 52 | use std::sync::atomic::AtomicBool; 53 | use std::sync::Arc; 54 | use std::task::{Context, Poll, Waker}; 55 | use tokio::sync::mpsc; 56 | 57 | #[derive(Clone)] 58 | /// Enables multiple tasks to synchronize the beginning or end of some computation. 59 | pub struct WaitGroup { 60 | name: String, //for test 61 | count: usize, 62 | tx: mpsc::Sender<()>, 63 | rx: Arc>>, 64 | } 65 | 66 | impl WaitGroup { 67 | /// Creates a new wait group and returns the single reference to it. 68 | /// 69 | /// # Examples 70 | /// ```rust 71 | /// use async_wg::WaitGroup; 72 | /// let wg = WaitGroup::new(); 73 | /// ``` 74 | pub fn new(name: String, count: usize) -> WaitGroup { 75 | let mut count2 = count; 76 | if count2 == 0 { 77 | count2 = 1; 78 | } 79 | let (tx, rx) = mpsc::channel(count2); 80 | WaitGroup { 81 | name, 82 | tx, 83 | rx: Arc::new(Mutex::new(rx)), 84 | count, 85 | } 86 | } 87 | 88 | /// Done count 1. 
89 | pub async fn done(&mut self) { 90 | if let Err(e) = self.tx.send(()).await { 91 | panic!("{} send error", self.name); 92 | } 93 | } 94 | pub async fn wait(&mut self) { 95 | let mut rx = self.rx.lock().await; 96 | for i in 0..self.count { 97 | rx.recv().await; 98 | } 99 | } 100 | } 101 | 102 | #[cfg(test)] 103 | mod tests { 104 | use super::*; 105 | 106 | #[tokio::main] 107 | #[test] 108 | async fn can_quit() { 109 | let mut wg = WaitGroup::new("test".into(), 4); 110 | assert_eq!(wg.count, 4); 111 | let mut wg2 = wg.clone(); 112 | tokio::spawn(async move { 113 | tokio::time::delay_for(tokio::time::Duration::from_millis(10)).await; 114 | wg2.done().await; 115 | wg2.done().await; 116 | wg2.done().await; 117 | wg2.done().await; 118 | }); 119 | wg.wait().await 120 | } 121 | } 122 | -------------------------------------------------------------------------------- /benchmark.md: -------------------------------------------------------------------------------- 1 | before any optimaztion: debug server+client 2 | ``` 3 | 4 | Nats pub/sub stats: 28507 msgs/sec ~ 3649009.168600655/sec 5 | Pub stats: 2770 msgs/sec ~ 354636.4610452923/sec 6 | Sub stats: 25916 msgs/sec ~ 3317281.062364232/sec 7 | [1] 2591 msgs/sec ~ 331732.67399158835/sec (100000 msgs) 8 | [2] 2591 msgs/sec ~ 331737.8092221816/sec (100000 msgs) 9 | [3] 2591 msgs/sec ~ 331740.81480018055/sec (100000 msgs) 10 | [4] 2591 msgs/sec ~ 331738.8905870689/sec (100000 msgs) 11 | [5] 2591 msgs/sec ~ 331742.6321808266/sec (100000 msgs) 12 | [6] 2591 msgs/sec ~ 331731.7214811043/sec (100000 msgs) 13 | [7] 2591 msgs/sec ~ 331743.00491758913/sec (100000 msgs) 14 | [8] 2591 msgs/sec ~ 331732.89617403643/sec (100000 msgs) 15 | [9] 2591 msgs/sec ~ 331739.76043689373/sec (100000 msgs) 16 | [10] 2591 msgs/sec ~ 331744.08291154966/sec (100000 msgs) 17 | min 2591 | avg 2591 | max 2591 | stddev 0 msgs 18 | 19 | ``` 20 | 21 | 22 | before any optimaztion: release server+client 23 | ``` 24 | Nats pub/sub stats: 61389 msgs/sec ~ 
7857896.78813678/sec 25 | Pub stats: 5882 msgs/sec ~ 752915.7634376923/sec 26 | Sub stats: 55808 msgs/sec ~ 7143542.5346698/sec 27 | [1] 5580 msgs/sec ~ 714362.6022647332/sec (100000 msgs) 28 | [2] 5581 msgs/sec ~ 714379.9072697352/sec (100000 msgs) 29 | [3] 5581 msgs/sec ~ 714383.2669884142/sec (100000 msgs) 30 | [4] 5581 msgs/sec ~ 714392.3734657292/sec (100000 msgs) 31 | [5] 5581 msgs/sec ~ 714388.2732535391/sec (100000 msgs) 32 | [6] 5581 msgs/sec ~ 714394.9112222577/sec (100000 msgs) 33 | [7] 5581 msgs/sec ~ 714398.9512635452/sec (100000 msgs) 34 | [8] 5580 msgs/sec ~ 714365.3115635442/sec (100000 msgs) 35 | [9] 5581 msgs/sec ~ 714368.6850990714/sec (100000 msgs) 36 | [10] 5581 msgs/sec ~ 714399.5103139627/sec (100000 msgs) 37 | min 5580 | avg 5580 | max 5581 | stddev 0.8944271909999159 msgs 38 | ``` 39 | 40 | 41 | ## 使用BytesMut优化缓存,减少await调用,就有六倍的提升. 42 | ``` 43 | Nats pub/sub stats: 377984 msgs/sec ~ 48382031.15402761/sec 44 | Pub stats: 36107 msgs/sec ~ 4621775.446487484/sec 45 | Sub stats: 343622 msgs/sec ~ 43983664.68547965/sec 46 | [1] 34366 msgs/sec ~ 4398952.634534228/sec (100000 msgs) 47 | [2] 34375 msgs/sec ~ 4400015.705993563/sec (100000 msgs) 48 | [3] 34378 msgs/sec ~ 4400465.716413369/sec (100000 msgs) 49 | [4] 34363 msgs/sec ~ 4398490.865962944/sec (100000 msgs) 50 | [5] 34376 msgs/sec ~ 4400181.333672851/sec (100000 msgs) 51 | [6] 34370 msgs/sec ~ 4399367.908305072/sec (100000 msgs) 52 | [7] 34368 msgs/sec ~ 4399154.384783533/sec (100000 msgs) 53 | [8] 34372 msgs/sec ~ 4399698.322447861/sec (100000 msgs) 54 | [9] 34363 msgs/sec ~ 4398570.699846082/sec (100000 msgs) 55 | [10] 34370 msgs/sec ~ 4399408.850882781/sec (100000 msgs) 56 | min 34363 | avg 34370 | max 34378 | stddev 4.969909455915671 msgs 57 | ``` 58 | 59 | ## 使用多个pub,测试,代码和刚刚的一样, 60 | 61 | 性能有10%以上的提升,但是观察到nats-server cpu占用一致比较低,没有超过client 62 | cargo run --release -- --subject test --num-subs 10 --num-msgs 1000000 --num-pubs 10 63 | ``` 64 | Nats pub/sub stats: 433364 msgs/sec ~ 
55470702.813337795/sec 65 | Pub stats: 41424 msgs/sec ~ 5302354.359580678/sec 66 | [1] 4149 msgs/sec ~ 531126.7721818777/sec (100000 msgs) 67 | [2] 4149 msgs/sec ~ 531104.6321667176/sec (100000 msgs) 68 | [3] 4148 msgs/sec ~ 530950.1965164108/sec (100000 msgs) 69 | [4] 4148 msgs/sec ~ 530952.9439894602/sec (100000 msgs) 70 | [5] 4149 msgs/sec ~ 531135.9464654801/sec (100000 msgs) 71 | [6] 4147 msgs/sec ~ 530886.8992623787/sec (100000 msgs) 72 | [7] 4148 msgs/sec ~ 530979.3590014324/sec (100000 msgs) 73 | [8] 4148 msgs/sec ~ 530954.9749239235/sec (100000 msgs) 74 | [9] 4146 msgs/sec ~ 530720.1580503538/sec (100000 msgs) 75 | [10] 4146 msgs/sec ~ 530730.5288312284/sec (100000 msgs) 76 | Sub stats: 393968 msgs/sec ~ 50427911.6484889/sec 77 | [1] 39397 msgs/sec ~ 5042943.996870369/sec (1000000 msgs) 78 | [2] 39398 msgs/sec ~ 5042977.890046073/sec (1000000 msgs) 79 | [3] 39399 msgs/sec ~ 5043108.388063024/sec (1000000 msgs) 80 | [4] 39398 msgs/sec ~ 5043044.403963505/sec (1000000 msgs) 81 | [5] 39398 msgs/sec ~ 5043000.823317792/sec (1000000 msgs) 82 | [6] 39397 msgs/sec ~ 5042888.875959185/sec (1000000 msgs) 83 | [7] 39397 msgs/sec ~ 5042860.738008805/sec (1000000 msgs) 84 | [8] 39399 msgs/sec ~ 5043073.080630799/sec (1000000 msgs) 85 | [9] 39396 msgs/sec ~ 5042811.453494481/sec (1000000 msgs) 86 | [10] 39396 msgs/sec ~ 5042791.16484889/sec (1000000 msgs) 87 | min 39396 | avg 39397 | max 39399 | stddev 1.140175425099138 msgs 88 | ``` 89 | 90 | ## 优化,msg_buf不再每次都分配 91 | 但是这里有一个问题,这个buf什么时候释放呢?一直占着?并且不会缩小. 92 | 只要连接不断开,就会一直占用. 93 | 经测试影响不大,因为根本就没有用到msg_buf,目前的msg_size是128,不会用到. 
94 | ``` 95 | Nats pub/sub stats: 377535 msgs/sec ~ 48324568.42183404/sec 96 | Pub stats: 36157 msgs/sec ~ 4628158.023356056/sec 97 | Sub stats: 343214 msgs/sec ~ 43931425.83803094/sec 98 | [1] 34326 msgs/sec ~ 4393757.401522481/sec (100000 msgs) 99 | [2] 34324 msgs/sec ~ 4393530.965829699/sec (100000 msgs) 100 | [3] 34327 msgs/sec ~ 4393877.165810065/sec (100000 msgs) 101 | [4] 34330 msgs/sec ~ 4394268.074840311/sec (100000 msgs) 102 | [5] 34328 msgs/sec ~ 4394076.038781426/sec (100000 msgs) 103 | [6] 34333 msgs/sec ~ 4394714.7405391075/sec (100000 msgs) 104 | [7] 34335 msgs/sec ~ 4394895.1331806155/sec (100000 msgs) 105 | [8] 34332 msgs/sec ~ 4394508.732300149/sec (100000 msgs) 106 | [9] 34321 msgs/sec ~ 4393189.280530978/sec (100000 msgs) 107 | [10] 34330 msgs/sec ~ 4394268.596802779/sec (100000 msgs) 108 | min 34321 | avg 34328 | max 34335 | stddev 4.09878030638384 msgs 109 | 110 | Nats pub/sub stats: 385345 msgs/sec ~ 49324272.01149149/sec 111 | Pub stats: 36887 msgs/sec ~ 4721574.6730107395/sec 112 | Sub stats: 350314 msgs/sec ~ 44840247.28317408/sec 113 | [1] 35047 msgs/sec ~ 4486101.932145457/sec (100000 msgs) 114 | [2] 35043 msgs/sec ~ 4485538.845636878/sec (100000 msgs) 115 | [3] 35049 msgs/sec ~ 4486307.3576117465/sec (100000 msgs) 116 | [4] 35038 msgs/sec ~ 4484933.827490584/sec (100000 msgs) 117 | [5] 35044 msgs/sec ~ 4485645.19838767/sec (100000 msgs) 118 | [6] 35050 msgs/sec ~ 4486476.823528619/sec (100000 msgs) 119 | [7] 35051 msgs/sec ~ 4486560.542245931/sec (100000 msgs) 120 | [8] 35053 msgs/sec ~ 4486888.8000169415/sec (100000 msgs) 121 | [9] 35034 msgs/sec ~ 4484385.156475487/sec (100000 msgs) 122 | [10] 35031 msgs/sec ~ 4484024.728317408/sec (100000 msgs) 123 | min 35031 | avg 35044 | max 35053 | stddev 7.113367697511496 msgs 124 | ``` 125 | 126 | ## 并发发送 127 | send_message由串行改为并发, 128 | `--num-pubs 1`基本没变化,但是改为`--num-pubs 10`有一倍提升. 129 | 这种情况下,无论是client还是server内存占用都非常低。 考虑启用缓存,来空间换时间. 
130 | 131 | ### --num-pubs 1 132 | ``` 133 | Nats pub/sub stats: 376232 msgs/sec ~ 48157765.32076849/sec 134 | Pub stats: 34386 msgs/sec ~ 4401420.72194075/sec 135 | Sub stats: 342029 msgs/sec ~ 43779786.65524408/sec 136 | [1] 34205 msgs/sec ~ 4378252.77514042/sec (1000000 msgs) 137 | [2] 34204 msgs/sec ~ 4378216.303690606/sec (1000000 msgs) 138 | [3] 34203 msgs/sec ~ 4378102.962688762/sec (1000000 msgs) 139 | [4] 34204 msgs/sec ~ 4378131.4951717425/sec (1000000 msgs) 140 | [5] 34205 msgs/sec ~ 4378278.877766064/sec (1000000 msgs) 141 | [6] 34204 msgs/sec ~ 4378176.879110729/sec (1000000 msgs) 142 | [7] 34203 msgs/sec ~ 4377995.095039316/sec (1000000 msgs) 143 | [8] 34203 msgs/sec ~ 4378071.547512658/sec (1000000 msgs) 144 | [9] 34203 msgs/sec ~ 4378032.975034622/sec (1000000 msgs) 145 | [10] 34204 msgs/sec ~ 4378138.376806764/sec (1000000 msgs) 146 | min 34203 | avg 34203 | max 34205 | stddev 1.0954451150103321 msgs 147 | ``` 148 | ### --num-pubs 10 149 | ``` 150 | Nats pub/sub stats: 790452 msgs/sec ~ 101177900.76085752/sec 151 | Pub stats: 75425 msgs/sec ~ 9654401.307397576/sec 152 | [1] 7579 msgs/sec ~ 970234.0958495921/sec (100000 msgs) 153 | [2] 7574 msgs/sec ~ 969559.4115457184/sec (100000 msgs) 154 | [3] 7572 msgs/sec ~ 969224.644279701/sec (100000 msgs) 155 | [4] 7572 msgs/sec ~ 969298.5754256473/sec (100000 msgs) 156 | [5] 7572 msgs/sec ~ 969326.885318488/sec (100000 msgs) 157 | [6] 7573 msgs/sec ~ 969461.2974012174/sec (100000 msgs) 158 | [7] 7570 msgs/sec ~ 968976.6409039636/sec (100000 msgs) 159 | [8] 7571 msgs/sec ~ 969158.8835072055/sec (100000 msgs) 160 | [9] 7573 msgs/sec ~ 969358.9971748831/sec (100000 msgs) 161 | [10] 7569 msgs/sec ~ 968873.5004742752/sec (100000 msgs) 162 | Sub stats: 718593 msgs/sec ~ 91979909.78259775/sec 163 | [1] 71862 msgs/sec ~ 9198385.281672334/sec (1000000 msgs) 164 | [2] 71866 msgs/sec ~ 9198897.934254501/sec (1000000 msgs) 165 | [3] 71863 msgs/sec ~ 9198491.500543945/sec (1000000 msgs) 166 | [4] 71859 msgs/sec ~ 
9198034.709645862/sec (1000000 msgs) 167 | [5] 71864 msgs/sec ~ 9198643.851382144/sec (1000000 msgs) 168 | [6] 71865 msgs/sec ~ 9198771.76335013/sec (1000000 msgs) 169 | [7] 71862 msgs/sec ~ 9198411.781956475/sec (1000000 msgs) 170 | [8] 71861 msgs/sec ~ 9198245.238568164/sec (1000000 msgs) 171 | [9] 71860 msgs/sec ~ 9198140.883406043/sec (1000000 msgs) 172 | [10] 71863 msgs/sec ~ 9198546.939085985/sec (1000000 msgs) 173 | min 71859 | avg 71862 | max 71866 | stddev 2.1213203435596424 msgs 174 | ``` 175 | 176 | 177 | ## 加入本地cache 178 | 只是为了测试,没有更新机制 179 | 不完善,可以看到有一定提高,但是不明显,几个百分点的样子 180 | ### -pubs 1 181 | ``` 182 | Nats pub/sub stats: 390384 msgs/sec ~ 49969156.719011426/sec 183 | Pub stats: 35743 msgs/sec ~ 4575136.18287007/sec 184 | Sub stats: 354894 msgs/sec ~ 45426506.108192205/sec 185 | [1] 35490 msgs/sec ~ 4542791.626176854/sec (1000000 msgs) 186 | [2] 35489 msgs/sec ~ 4542685.696734685/sec (1000000 msgs) 187 | [3] 35490 msgs/sec ~ 4542804.556233796/sec (1000000 msgs) 188 | [4] 35490 msgs/sec ~ 4542841.2181290435/sec (1000000 msgs) 189 | [5] 35491 msgs/sec ~ 4542859.491013554/sec (1000000 msgs) 190 | [6] 35489 msgs/sec ~ 4542700.385740192/sec (1000000 msgs) 191 | [7] 35489 msgs/sec ~ 4542717.684684231/sec (1000000 msgs) 192 | [8] 35490 msgs/sec ~ 4542745.48128393/sec (1000000 msgs) 193 | [9] 35490 msgs/sec ~ 4542815.080036936/sec (1000000 msgs) 194 | [10] 35490 msgs/sec ~ 4542788.906937107/sec (1000000 msgs) 195 | min 35489 | avg 35489 | max 35491 | stddev 1 msgs 196 | ``` 197 | ### -pubs 10 198 | ``` 199 | Nats pub/sub stats: 813669 msgs/sec ~ 104149682.70495887/sec 200 | Pub stats: 77845 msgs/sec ~ 9964231.962461589/sec 201 | [1] 7821 msgs/sec ~ 1001095.0332549638/sec (100000 msgs) 202 | [2] 7812 msgs/sec ~ 1000051.0897194011/sec (100000 msgs) 203 | [3] 7811 msgs/sec ~ 999847.6086172981/sec (100000 msgs) 204 | [4] 7810 msgs/sec ~ 999774.9932175713/sec (100000 msgs) 205 | [5] 7815 msgs/sec ~ 1000337.8641917821/sec (100000 msgs) 206 | [6] 7807 msgs/sec ~ 
999348.5271154706/sec (100000 msgs) 207 | [7] 7805 msgs/sec ~ 999100.0203804227/sec (100000 msgs) 208 | [8] 7809 msgs/sec ~ 999660.2638192733/sec (100000 msgs) 209 | [9] 7805 msgs/sec ~ 999132.7749196282/sec (100000 msgs) 210 | [10] 7806 msgs/sec ~ 999202.2221919566/sec (100000 msgs) 211 | Sub stats: 739699 msgs/sec ~ 94681529.7317808/sec 212 | [1] 73976 msgs/sec ~ 9469015.62481706/sec (1000000 msgs) 213 | [2] 73975 msgs/sec ~ 9468881.113728898/sec (1000000 msgs) 214 | [3] 73974 msgs/sec ~ 9468771.64819417/sec (1000000 msgs) 215 | [4] 73979 msgs/sec ~ 9469414.929399932/sec (1000000 msgs) 216 | [5] 73978 msgs/sec ~ 9469243.13565144/sec (1000000 msgs) 217 | [6] 73977 msgs/sec ~ 9469088.617461707/sec (1000000 msgs) 218 | [7] 73970 msgs/sec ~ 9468177.11740739/sec (1000000 msgs) 219 | [8] 73973 msgs/sec ~ 9468598.682160331/sec (1000000 msgs) 220 | [9] 73971 msgs/sec ~ 9468336.636719728/sec (1000000 msgs) 221 | [10] 73972 msgs/sec ~ 9468442.361155936/sec (1000000 msgs) 222 | min 73970 | avg 73974 | max 73979 | stddev 2.9154759474226504 msgs 223 | ``` 224 | ## 客户端服务端批量消息处理 225 | 在测试过程中发现,单客户端批量消息发送,效果并不理想,两者结合起来,则有大幅度提升 226 | 可以听到到2500000 msgs/sec 227 | 228 | ### -pub 1 229 | ``` 230 | Nats pub/sub stats: 2820266 msgs/sec ~ 360994066.07799476/sec 231 | Pub stats: 272533 msgs/sec ~ 34884233.19733797/sec 232 | Sub stats: 2563878 msgs/sec ~ 328176423.70726794/sec 233 | [1] 259650 msgs/sec ~ 33235314.85993536/sec (100000 msgs) 234 | [2] 259853 msgs/sec ~ 33261248.67171009/sec (100000 msgs) 235 | [3] 258512 msgs/sec ~ 33089597.940031897/sec (100000 msgs) 236 | [4] 258564 msgs/sec ~ 33096266.07228048/sec (100000 msgs) 237 | [5] 258607 msgs/sec ~ 33101786.509107266/sec (100000 msgs) 238 | [6] 258163 msgs/sec ~ 33044888.545773942/sec (100000 msgs) 239 | [7] 258153 msgs/sec ~ 33043639.743700314/sec (100000 msgs) 240 | [8] 256826 msgs/sec ~ 32873816.19622194/sec (100000 msgs) 241 | [9] 257215 msgs/sec ~ 32923624.123998497/sec (100000 msgs) 242 | [10] 257401 msgs/sec ~ 
32947340.651538294/sec (100000 msgs) 243 | min 256826 | avg 258294 | max 259853 | stddev 929.5894792864214 msgs 244 | ``` 245 | 246 | ### --num-pubs 10 247 | 没有明显变化,和`--num-pubs 1`是差不多的 248 | ``` 249 | Nats pub/sub stats: 2836183 msgs/sec ~ 363031456.8459067/sec 250 | Pub stats: 455987 msgs/sec ~ 58366353.48983736/sec 251 | [1] 46964 msgs/sec ~ 6011486.664181387/sec (10000 msgs) 252 | [2] 47609 msgs/sec ~ 6093961.236245924/sec (10000 msgs) 253 | [3] 46671 msgs/sec ~ 5973945.188782198/sec (10000 msgs) 254 | [4] 46761 msgs/sec ~ 5985441.694337137/sec (10000 msgs) 255 | [5] 46871 msgs/sec ~ 5999539.835294633/sec (10000 msgs) 256 | [6] 46971 msgs/sec ~ 6012344.414030855/sec (10000 msgs) 257 | [7] 46890 msgs/sec ~ 6002020.064256032/sec (10000 msgs) 258 | [8] 47038 msgs/sec ~ 6020924.7738744095/sec (10000 msgs) 259 | [9] 47144 msgs/sec ~ 6034558.284076369/sec (10000 msgs) 260 | [10] 45624 msgs/sec ~ 5839918.179461352/sec (10000 msgs) 261 | Sub stats: 2578348 msgs/sec ~ 330028597.13264245/sec 262 | [1] 267107 msgs/sec ~ 34189728.241582826/sec (100000 msgs) 263 | [2] 265067 msgs/sec ~ 33928665.84947354/sec (100000 msgs) 264 | [3] 264266 msgs/sec ~ 33826095.06122654/sec (100000 msgs) 265 | [4] 264055 msgs/sec ~ 33799086.41650338/sec (100000 msgs) 266 | [5] 262109 msgs/sec ~ 33549999.75597617/sec (100000 msgs) 267 | [6] 262331 msgs/sec ~ 33578437.388171345/sec (100000 msgs) 268 | [7] 260781 msgs/sec ~ 33380085.707386095/sec (100000 msgs) 269 | [8] 259417 msgs/sec ~ 33205397.433585964/sec (100000 msgs) 270 | [9] 259557 msgs/sec ~ 33223381.159022104/sec (100000 msgs) 271 | [10] 258487 msgs/sec ~ 33086353.612291187/sec (100000 msgs) 272 | min 258487 | avg 262317 | max 267107 | stddev 2653.772804894948 msgs 273 | 274 | ``` -------------------------------------------------------------------------------- /client/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "client" 3 | version = "0.1.0" 4 | authors = ["bai "] 
5 | edition = "2018" 6 | license = "Apache-2.0" 7 | 8 | 9 | # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html 10 | 11 | [dependencies] 12 | log="0.4" 13 | env_logger="0.7" 14 | serde="1.0" 15 | serde_json="1.0" 16 | serde_derive = "1.0" 17 | tokio = { version = "0.2.0", path = "../../tokio", features = ["full"] } #{ version = "0.2", features = ["full"] } 18 | tokio-util ={ version = "0.2.0", path = "../../tokio-util", features = ["full"] } #{ version = "0.2", features = ["full"] } # 19 | rand="0.7" 20 | bitflags="1.0" 21 | lazy_static="1.0" 22 | get_if_addrs="0.5" 23 | futures = { version = "0.3.0", features = ["async-await"] } 24 | bytes="0.5" 25 | 26 | [[example]] 27 | name = "pub_subject" 28 | path="examples/pub_subject.rs" 29 | [[example]] 30 | name = "sub_subject" 31 | path="examples/sub_subject.rs" -------------------------------------------------------------------------------- /client/examples/pub_subject.rs: -------------------------------------------------------------------------------- 1 | use client::client::Client; 2 | use std::error::Error; 3 | 4 | #[tokio::main] 5 | async fn main() -> Result<(), Box> { 6 | let addr = "127.0.0.1:4222"; 7 | let mut c = Client::connect(addr).await?; 8 | for i in 0..10 { 9 | println!("pub {}", i); 10 | c.pub_message("test", format!("hello{}", i).as_bytes()) 11 | .await?; 12 | } 13 | println!("close connection"); 14 | c.close(); 15 | Ok(()) 16 | } 17 | -------------------------------------------------------------------------------- /client/examples/sub_subject.rs: -------------------------------------------------------------------------------- 1 | use client::client::Client; 2 | use std::error::Error; 3 | 4 | #[tokio::main] 5 | async fn main() -> Result<(), Box> { 6 | let addr = "127.0.0.1:4222"; 7 | let mut c = Client::connect(addr).await?; 8 | let mut rx = c.sub_message("test", None).await?; 9 | for i in 0..10 { 10 | let r = rx.recv().await; 11 | if r.is_none() { 12 | break; 
13 | } 14 | let r = r.unwrap(); 15 | println!("{} receive on test {}", i, unsafe { 16 | std::str::from_utf8_unchecked(r.as_slice()) 17 | }); 18 | } 19 | println!("close connection"); 20 | c.close(); 21 | Ok(()) 22 | } 23 | -------------------------------------------------------------------------------- /client/src/client.rs: -------------------------------------------------------------------------------- 1 | use crate::parser::Parser; 2 | use crate::parser::*; 3 | use bytes::buf::BufMutExt; 4 | use bytes::{Buf, BytesMut}; 5 | use std::collections::HashMap; 6 | use std::sync::Arc; 7 | use tokio::io::*; 8 | use tokio::net::TcpStream; 9 | use tokio::sync::{oneshot, Mutex}; 10 | 11 | type MessageHandler = Box std::result::Result<(), ()> + Sync + Send>; 12 | //#[derive(Debug)] 13 | pub struct Client { 14 | addr: String, 15 | writer: Arc>>, 16 | msg_buf: Option, 17 | pub stop: Option>, 18 | sid: u64, 19 | handler: Arc>>, 20 | } 21 | 22 | impl Client { 23 | pub async fn connect(addr: &str) -> std::io::Result { 24 | let conn = TcpStream::connect(addr).await?; 25 | let (reader, writer) = tokio::io::split(conn); 26 | let (tx, rx) = tokio::sync::oneshot::channel(); 27 | let msg_sender = Arc::new(Mutex::new(HashMap::new())); 28 | let writer = Arc::new(Mutex::new(writer)); 29 | tokio::spawn(Self::receive_task( 30 | reader, 31 | rx, 32 | msg_sender.clone(), 33 | writer.clone(), 34 | )); 35 | return Ok(Client { 36 | addr: addr.to_string(), 37 | writer, 38 | stop: Some(tx), 39 | sid: 0, 40 | handler: msg_sender, 41 | msg_buf: Some(BytesMut::with_capacity(512)), 42 | }); 43 | } 44 | async fn receive_task( 45 | mut reader: ReadHalf, 46 | stop: oneshot::Receiver<()>, 47 | handler: Arc>>, 48 | writer: Arc>>, 49 | ) { 50 | use futures::*; 51 | let mut buf = [0 as u8; 512]; 52 | let mut parser = Parser::new(); 53 | let mut stop = stop.fuse(); 54 | // let mut _r: Result; 55 | loop { 56 | select! 
{ 57 | _=stop=>{ 58 | println!("client stoped."); 59 | let r=writer.lock().await.shutdown().await; 60 | if r.is_err() { 61 | println!("receive_task err {:?}", r.unwrap_err()); 62 | return; 63 | } 64 | return; 65 | }, 66 | r = reader.read(&mut buf[..]).fuse()=>{ 67 | if r.is_err() { 68 | println!("receive_task err {:?}", r.unwrap_err()); 69 | return; 70 | } 71 | let r = r.unwrap(); 72 | if r == 0 { 73 | println!("connection closed"); 74 | return; 75 | } 76 | let mut buf = &buf[0..r]; 77 | // println!("read buf len={},buf={}", r, unsafe { 78 | // std::str::from_utf8_unchecked(buf) 79 | // }); 80 | loop { 81 | let r = parser.parse(buf); 82 | if r.is_err() { 83 | println!("msg error:{}", r.unwrap_err()); 84 | let r=writer.lock().await.shutdown().await; 85 | if r.is_err() { 86 | println!("shutdown err {:?}",r); 87 | } 88 | return; 89 | } 90 | let (r, n) = r.unwrap(); 91 | if let ParseResult::MsgArg(ref msg) = r { 92 | if let Some(handler) = handler.lock().await.get_mut(msg.subject) { 93 | let r = handler(msg.msg); 94 | if r.is_err() { 95 | println!("handler error {:?}", r.unwrap_err()); 96 | return; 97 | } 98 | } else { 99 | println!("receive msg on subject {}, not found receiver", msg.subject); 100 | } 101 | parser.clear_msg_buf(); 102 | } else if ParseResult::NoMsg == r { 103 | break; //NoMsg 104 | } 105 | // println!("n={},buf len={}", n, buf.len()); 106 | if n == buf.len() { 107 | break; 108 | } 109 | buf = &buf[n..]; 110 | } 111 | } 112 | } 113 | } 114 | } 115 | //pub消息格式为PUB subject size\r\n{message} 116 | pub async fn pub_message(&mut self, subject: &str, msg: &[u8]) -> std::io::Result<()> { 117 | use std::io::Write; 118 | let msg_buf = self.msg_buf.take().expect("must have"); 119 | let mut writer = msg_buf.writer(); 120 | writer.write("PUB ".as_bytes())?; 121 | writer.write(subject.as_bytes())?; 122 | // write!(writer, subject)?; 123 | write!(writer, " {}\r\n", msg.len())?; 124 | writer.write(msg)?; //todo 这个需要copy么?最好别copy 125 | 
writer.write("\r\n".as_bytes())?; 126 | let mut msg_buf = writer.into_inner(); 127 | let mut writer = self.writer.lock().await; 128 | writer.write_all(msg_buf.bytes()).await?; 129 | msg_buf.clear(); 130 | self.msg_buf = Some(msg_buf); 131 | Ok(()) 132 | } 133 | //批量pub, 134 | pub async fn pub_messages(&mut self, subjects: &[&str], msgs: &[&[u8]]) -> std::io::Result<()> { 135 | use std::io::Write; 136 | let msg_buf = self.msg_buf.take().expect("must have"); 137 | let mut writer = msg_buf.writer(); 138 | for i in 0..subjects.len() { 139 | writer.write("PUB ".as_bytes())?; 140 | writer.write(subjects[i].as_bytes())?; 141 | // write!(writer, subject)?; 142 | write!(writer, " {}\r\n", msgs[i].len())?; 143 | writer.write(msgs[i])?; //todo 这个需要copy么?最好别copy 144 | writer.write("\r\n".as_bytes())?; 145 | } 146 | let mut msg_buf = writer.into_inner(); 147 | let mut writer = self.writer.lock().await; 148 | 149 | writer.write_all(msg_buf.bytes()).await?; 150 | msg_buf.clear(); 151 | self.msg_buf = Some(msg_buf); 152 | Ok(()) 153 | } 154 | // type MessageHandler = Box Result<()> + Sync + Send >; 155 | //sub消息格式为SUB subject {queue} {sid}\r\n 156 | //可能由于rustc的bug,导致如果subject是&str,则会报错E0700,暂时使用String来替代 157 | pub async fn sub_message( 158 | &mut self, 159 | subject: String, 160 | queue: Option, 161 | handler: MessageHandler, 162 | ) -> std::io::Result<()> { 163 | self.sid += 1; 164 | let mut writer = self.writer.lock().await; 165 | if let Some(q) = queue { 166 | writer 167 | .write_all(format!("SUB {} {} {}\r\n", subject, q, self.sid).as_bytes()) 168 | .await?; 169 | } else { 170 | writer 171 | .write_all(format!("SUB {} {}\r\n", subject, self.sid).as_bytes()) 172 | .await?; 173 | } 174 | self.handler 175 | .lock() 176 | .await 177 | .insert(subject.to_string(), handler); 178 | Ok(()) 179 | } 180 | pub fn close(&mut self) { 181 | if let Some(stop) = self.stop.take() { 182 | let _ = stop.send(()); 183 | } 184 | } 185 | } 186 | 187 | #[cfg(test)] 188 | mod tests { 189 | struct A { 
/// Result alias used across the crate, with [`NError`] as the error type.
pub type Result<T> = std::result::Result<T, NError>;
/// The input did not follow the wire protocol.
pub const ERROR_PARSE: i32 = 1;
/// The announced message size is outside the accepted range.
pub const ERROR_MESSAGE_SIZE_TOO_LARGE: i32 = 2;
/// The subject is not valid.
pub const ERROR_INVALID_SUBJECT: i32 = 3;
/// No matching subscription exists.
pub const ERROR_SUBSCRIBTION_NOT_FOUND: i32 = 4;
/// Catch-all code for anything else.
pub const ERROR_UNKOWN_ERROR: i32 = 1000;

/// Error carrying one of the `ERROR_*` codes defined above.
#[derive(Debug)]
pub struct NError {
    pub err_code: i32,
}

impl NError {
    /// Wraps `err_code` in an error value.
    pub fn new(err_code: i32) -> Self {
        Self { err_code }
    }

    /// Returns a short static description of the error code.
    ///
    /// Every `ERROR_*` code defined in this module has its own text; any other value
    /// (including `ERROR_UNKOWN_ERROR`) maps to `"unkown error"`.
    pub fn error_description(&self) -> &'static str {
        match self.err_code {
            ERROR_PARSE => "parse error",
            ERROR_MESSAGE_SIZE_TOO_LARGE => "message size too large",
            ERROR_INVALID_SUBJECT => "invalid subject",
            ERROR_SUBSCRIBTION_NOT_FOUND => "subscription not found",
            _ => "unkown error",
        }
    }
}

impl Error for NError {}

impl Display for NError {
    /// Formats as `NError[<code>,<description>]`.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::result::Result<(), std::fmt::Error> {
        write!(f, "NError[{},{}]", self.err_code, self.error_description())
    }
}
28 | } 29 | } 30 | #[cfg(test)] 31 | mod tests { 32 | use super::*; 33 | #[test] 34 | fn test() { 35 | println!("{}", NError::new(ERROR_PARSE)); 36 | } 37 | } 38 | -------------------------------------------------------------------------------- /client/src/lib.rs: -------------------------------------------------------------------------------- 1 | #![recursion_limit = "512"] 2 | pub mod client; 3 | pub mod error; 4 | mod parser; 5 | -------------------------------------------------------------------------------- /client/src/parser.rs: -------------------------------------------------------------------------------- 1 | /** 2 | 3 | ``` 4 | ## MSG 5 | ``` 6 | MSG \r\n 7 | \r\n 8 | ``` 9 | */ 10 | use crate::error::*; 11 | #[macro_export] 12 | macro_rules! parse_error { 13 | ( ) => {{ 14 | return Err(NError::new(ERROR_PARSE)); 15 | }}; 16 | } 17 | 18 | #[derive(Debug, Clone)] 19 | enum ParseState { 20 | OpStart, 21 | OpM, 22 | OpMs, 23 | OpMsg, 24 | OpMsgSpc, 25 | OpMsgArg, 26 | OpMsgBody, //pub message 27 | OpMsgFull, 28 | } 29 | 30 | #[derive(Debug, PartialEq)] 31 | pub struct MsgArg<'a> { 32 | pub subject: &'a str, 33 | pub size: usize, //1024 整数形式 34 | pub sid: &'a str, 35 | pub msg: &'a [u8], 36 | } 37 | #[derive(Debug, PartialEq)] 38 | pub enum ParseResult<'a> { 39 | NoMsg, //buf="sub top.stevenbai.blog" sub消息不完整,我肯定不能处理 40 | MsgArg(MsgArg<'a>), 41 | } 42 | /* 43 | 这个长度很有关系,必须能够将一个完整的主题以及参数放进去, 44 | 所以要限制subject的长度 45 | */ 46 | const BUF_LEN: usize = 512; 47 | pub struct Parser { 48 | state: ParseState, 49 | buf: [u8; BUF_LEN], //消息解析缓冲区,如果消息不超过512,直接用这个,超过了就必须另分配 50 | arg_len: usize, 51 | msg_buf: Option>, 52 | //解析过程中受到新消息,那么 新消息的总长度是msg_total_len,已收到部分应该是msg_len 53 | msg_total_len: usize, 54 | msg_len: usize, 55 | debug: bool, 56 | } 57 | 58 | impl Parser { 59 | pub fn new() -> Self { 60 | Self { 61 | state: ParseState::OpStart, 62 | buf: [0; BUF_LEN], 63 | arg_len: 0, 64 | msg_buf: None, 65 | msg_total_len: 0, 66 | msg_len: 0, 67 | debug: false, 68 | } 69 | } 
70 | /** 71 | 对收到的字节序列进行解析,解析完毕后得到pub或者sub消息, 72 | 同时有可能没有消息或者缓冲区里面还有其他消息 73 | */ 74 | pub fn parse<'a, 'b>(&'b mut self, buf: &'a [u8]) -> Result<(ParseResult<'a>, usize)> 75 | where 76 | 'b: 'a, 77 | { 78 | let mut b; 79 | let mut i = 0; 80 | if self.debug { 81 | println!( 82 | "parse string:{},state={:?}", 83 | unsafe { std::str::from_utf8_unchecked(buf) }, 84 | self.state 85 | ); 86 | } 87 | while i < buf.len() { 88 | use ParseState::*; 89 | b = buf[i] as char; 90 | // println!("state={:?},b={}", self.state, b); 91 | match self.state { 92 | OpStart => match b { 93 | 'M' => self.state = OpM, 94 | _ => parse_error!(), 95 | }, 96 | OpM => match b { 97 | 'S' => self.state = OpMs, 98 | _ => parse_error!(), 99 | }, 100 | OpMs => match b { 101 | 'G' => self.state = OpMsg, 102 | _ => parse_error!(), 103 | }, 104 | OpMsg => match b { 105 | ' ' | '\t' => self.state = OpMsgSpc, 106 | _ => parse_error!(), 107 | }, 108 | OpMsgSpc => match b { 109 | ' ' | '\t' => {} 110 | _ => { 111 | self.state = OpMsgArg; 112 | self.arg_len = 0; 113 | continue; 114 | } 115 | }, 116 | OpMsgArg => match b { 117 | '\r' => {} 118 | '\n' => { 119 | self.state = OpMsgBody; 120 | let size = self.get_message_size()?; 121 | if size == 0 || size > 1 * 1024 * 1024 { 122 | //消息体长度不应该超过1M,防止Dos攻击 123 | return Err(NError::new(ERROR_MESSAGE_SIZE_TOO_LARGE)); 124 | } 125 | if size + self.arg_len > BUF_LEN { 126 | self.msg_buf = Some(Vec::with_capacity(size)); 127 | } 128 | self.msg_total_len = size; 129 | } 130 | _ => { 131 | self.add_arg(b as u8)?; 132 | } 133 | }, 134 | OpMsgBody => { 135 | //涉及消息长度 136 | if self.msg_len < self.msg_total_len { 137 | self.add_msg(b as u8); 138 | } else { 139 | self.state = OpMsgFull; 140 | } 141 | } 142 | OpMsgFull => match b { 143 | '\r' => {} 144 | '\n' => { 145 | self.state = OpStart; 146 | let r = self.process_msg()?; 147 | return Ok((r, i + 1)); 148 | } 149 | _ => { 150 | parse_error!(); 151 | } 152 | }, 153 | // _ => panic!("unkown state {:?}", self.state), 154 | } 
155 | i += 1; 156 | } 157 | Ok((ParseResult::NoMsg, 0)) 158 | } 159 | //一种是消息体比较短,可以直接放在buf中,无需另外分配内存 160 | //另一种是消息体很长,无法放在buf中,额外分配了msg_buf空间 161 | fn add_msg(&mut self, b: u8) { 162 | if let Some(buf) = self.msg_buf.as_mut() { 163 | buf.push(b); 164 | } else { 165 | //消息体比较短的情况 166 | if self.arg_len + self.msg_total_len > BUF_LEN { 167 | panic!("message should allocate space"); 168 | } 169 | self.buf[self.arg_len + self.msg_len] = b; 170 | } 171 | self.msg_len += 1; 172 | } 173 | fn add_arg(&mut self, b: u8) -> Result<()> { 174 | //太长的subject 175 | if self.arg_len >= self.buf.len() { 176 | parse_error!(); 177 | } 178 | self.buf[self.arg_len] = b; 179 | self.arg_len += 1; 180 | Ok(()) 181 | } 182 | 183 | //解析缓冲区中以及msg_buf中的形如stevenbai.top 5hello 184 | fn process_msg(&self) -> Result { 185 | let msg = if self.msg_buf.is_some() { 186 | self.msg_buf.as_ref().unwrap().as_slice() 187 | } else { 188 | &self.buf[self.arg_len..self.arg_len + self.msg_total_len] 189 | }; 190 | let mut arg_buf = [""; 3]; 191 | let mut arg_len = 0; 192 | let ss = unsafe { std::str::from_utf8_unchecked(&self.buf[0..self.arg_len]) }; 193 | for s in ss.split(' ') { 194 | if s.len() == 0 { 195 | continue; 196 | } 197 | if arg_len >= 3 { 198 | parse_error!() 199 | } 200 | arg_buf[arg_len] = s; 201 | arg_len += 1; 202 | } 203 | let msg_arg = MsgArg { 204 | subject: arg_buf[0], 205 | size: self.msg_total_len, 206 | sid: arg_buf[1], 207 | msg, 208 | }; 209 | Ok(ParseResult::MsgArg(msg_arg)) 210 | } 211 | pub fn clear_msg_buf(&mut self) { 212 | self.msg_buf = None; 213 | self.msg_len = 0; 214 | self.msg_total_len = 0; 215 | } 216 | //从接收到的pub消息中提前解析出来消息的长度 217 | fn get_message_size(&self) -> Result { 218 | //缓冲区中形如top.stevenbai.top 5 219 | let arg_buf = &self.buf[0..self.arg_len]; 220 | let pos = arg_buf 221 | .iter() 222 | .rev() 223 | .position(|b| *b == ' ' as u8 || *b == '\t' as u8); 224 | if pos.is_none() { 225 | parse_error!(); 226 | } 227 | let pos = pos.unwrap(); 228 | let size_buf = 
&arg_buf[arg_buf.len() - pos..]; 229 | let szb = unsafe { std::str::from_utf8_unchecked(size_buf) }; 230 | szb.parse::().map_err(|_| NError::new(ERROR_PARSE)) 231 | } 232 | } 233 | 234 | #[cfg(test)] 235 | mod tests { 236 | use super::*; 237 | 238 | #[test] 239 | fn test_msg() { 240 | let mut p = Parser::new(); 241 | assert!(p.parse("aa".as_bytes()).is_err()); 242 | let mut buf = "MSG subject 1 5\r\nhello\r\nMSG subject 1 5\r\nxxxxx\r\n".as_bytes(); 243 | let r = p.parse(buf); 244 | println!("r={:?}", r); 245 | assert!(r.is_ok()); 246 | let (r, n) = r.unwrap(); 247 | // assert_eq!(r.1, buf.len()); 248 | match r { 249 | ParseResult::MsgArg(p) => { 250 | assert_eq!(p.subject, "subject"); 251 | assert_eq!(p.size, 5); 252 | assert_eq!(p.msg, "hello".as_bytes()); 253 | } 254 | _ => assert!(false, "must be valid pub arg "), 255 | } 256 | p.clear_msg_buf(); 257 | if n < buf.len() { 258 | buf = &buf[n..]; 259 | let r = p.parse(buf); 260 | println!("r={:?}", r); 261 | assert!(r.is_ok()); 262 | let r = r.unwrap(); 263 | // assert_eq!(r.1, buf.len()); 264 | match r.0 { 265 | ParseResult::MsgArg(p) => { 266 | assert_eq!(p.subject, "subject"); 267 | assert_eq!(p.size, 5); 268 | assert_eq!(p.msg, "xxxxx".as_bytes()); 269 | } 270 | _ => assert!(false, "must be valid pub arg "), 271 | } 272 | } 273 | } 274 | } 275 | -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | # 从零实现消息中间件 2 | 3 | 消息中间件在现代系统中非常关键,包括阿里云,腾讯云都有直接的消息中间件服务,也就是你不用自己搭建服务器,直接使用它提供的服务就可以了.那么我们今天就从零开始一步一步搭建一个极简消息中间件. 当然我们不可能做到像阿里云的RocketMQ那么复杂,但是最核心功能还是要保证的. 4 | 5 | 天实现的消息中间件系统不是基于MQTT,而是基于[nats](https://nats.io/),当然也是为了教学的方便,我们只会实现最核心的消息订阅发布,而围绕其的权限,cluster之类的我们都先屏蔽.对完整nats感兴趣的可以上[nats官网](https://nats.io/)查看完整的功能. 
6 | 7 | ## 相关博客文章 8 | 9 | - [1.从零实现一个极简消息中间件](从零实现一个极简消息中间件) 10 | - [2.parser](从零实现消息中间件-parser) 11 | - [3.sublist](从零实现消息中间件-sublist) 12 | - [4.server](从零实现消息中间件-server) 13 | - [5.server-client](从零实现消息中间件-server.client) 14 | - [6.client library](从零实现消息中间件-client) 15 | 16 | 17 | ## B站相关视频 18 | 19 | - [1.parser](https://www.bilibili.com/video/av85936685) 20 | - [2.sublist](https://www.bilibili.com/video/av86899713/) 21 | - [3.server](https://www.bilibili.com/video/av90457400/) 22 | - [4.server-client](https://www.bilibili.com/video/av90457552/) 23 | - [5.client library](https://www.bilibili.com/video/av90458399/) 24 | 25 | ## 协议设计 26 | nats是一个文本格式的通信协议,本来就非常简单,加上我们这次教学的需要,只保留了最核心的订阅发布系统.那就更简单了. 消息总共只有三种(订阅,发布,消息推送). 27 | 为了简化实现,就不支持取消订阅功能,如果想取消订阅,只能断开连接了. 28 | ### 订阅主题 29 | 所谓订阅,首先是要订阅什么. nats中的主题是类似于域名格式,形如top.stevenbai.blog. 比如我订阅了top.stevenbai.blog,那么当有人在这个主题下发布消息的时候我就收得到. 30 | 当然为了使用的方便,我们还支持主题的模糊匹配,具体来说就是*和>. 31 | #### *匹配 32 | *只匹配.分割的一个字段. 33 | 比如top.\*.blog 则可以匹配top.stevenbai.blog,top.steven.blog等等 34 | 而top.\*,则可以匹配top.stevenbai,top.steven,但是不能匹配top.stevenbai.blog. 35 | #### >匹配 36 | `>`可以匹配所有的字段. 37 | 比如top.> 则可以匹配包括top.stevenbai,top.stevenbai.blog,top.steven.blog等 38 | 一般来说调试的时候我们可以订阅这么一个主题`>`,他会匹配所有的主题,也就是说所有人发布的消息都可以收到. 39 | 当然只是调试的时候,因为真实的生产环境中这么使用,很快这个client就会被淹没. 40 | 41 | ### 发布消息(PUB) 42 | ``` 43 | PUB <subject> <size>\r\n 44 | <message>\r\n 45 | ``` 46 | 发布消息格式很简单,就是我想在某个subject下发布一个长度为多少的消息,这个消息可以是纯文本,也可以是二进制. 47 | 48 | ### 订阅消息(SUB) 49 | ``` 50 | SUB <subject> <sid>\r\n 51 | ``` 52 | 具体来说就是表达对某个subject感兴趣,如果有人在这个subject下发布了消息,那么请推送给我.推送的格式见消息推送. 53 | 其中sid是对订阅的编号,是一个十进制整数. 因为同一个tcp连接是可以有任意多个订阅. 54 | 55 | #### 负载均衡 56 | 同一subject的消息发布方可能有很多个,比如一个物联网系统中,同一类型的设备都会在某个主题下发布消息. 而这个消息可能每秒钟有上百万条,这时候一个接收方肯定就忙不过来了. 这时候就可以多个接收方. 57 | 因此从设计角度来说nats的消息订阅发布系统是多对多的. 也就是说一个主题下可以有多个发送方,多个接收方. 58 | 带负载均衡的订阅: 59 | ``` 60 | SUB <subject> <queue> <sid> 61 | ``` 62 | 比如两个client,clientA和B分别订阅了`sub top.stevenbai.blog workers 3`和`sub top.stevenbai.blog workers 4`.这里的3和4分别是两个连接各自的订阅id,他们没有任何关系,可以相同也可以不同,是他们自己的安排. 
63 | 如果这时有一个client C发布了`pub top.stevenbai.blog 5\r\nfirst`和`pub top.stevenbai.blog 6\r\nsecond`两条消息,则A和B将分别收到`first`和`second`. 64 | ### 消息推送 65 | 订阅发布消息都是客户端向服务器发出,而消息推送则是服务器向客户端发出. 格式如下: 66 | ``` 67 | MSG <subject> <sid> <size>\r\n 68 | <message>\r\n 69 | ``` 70 | 这个格式看起来和pub消息的非常像,只不过关键字是MSG,而且多了一个`<sid>`表示这个连接上的订阅编号. 71 | 举例来说,上面的例子client C发布了`pub top.stevenbai.blog 5\r\nfirst`,那么ClientA收到的消息格式就是 72 | ``` 73 | MSG top.stevenbai.blog 3 5\r\n 74 | first\r\n 75 | ``` 76 | 77 | ## 系统设计 78 | 根据上面的协议设计. 79 | ### 客户端的一般工作流程. 80 | #### 消息订阅方的工作流程 81 | 1. 建立一个tcp连接 82 | 2. sub一个或者多个主题 83 | 3. 等待相关消息 84 | #### 消息发布方的工作流程 85 | 1. 建立一个tcp连接 86 | 2. 重复的在一个或者多个主题下pub消息 87 | 客户端的工作看起来非常直观. 88 | 89 | ### 服务端的工作流程 90 | #### 消息格式解析 91 | 目前就两种消息pub和sub. 92 | #### 主题的树状组织 93 | trie树,是一种字典树 a.b.c 94 | 95 | 按照前面的描述当客户端在一个主题下pub消息的时候,服务器要能找到所有对这个主题感兴趣的客户端,因为要支持*和>的模糊匹配,使用trie树来组织比较合理. 96 | 97 | 明显这里的trie树是系统的核心数据,每一次client的pub都要来这里查找所有相关的sub,如果这里设计的不好肯定会造成系统的瓶颈. 98 | 1. 这棵trie树是全局的,每一次新的订阅和连接的断开都需要更新 99 | 2. 每一次pub都需要在树中查找. 100 | 所以树的访问必须带锁;为了避免重复查找,要进行cache. 101 | 102 | #### client的管理 103 | 这是所有server都要做的,这也是这个系统的核心部分. 104 | 1. 计划使用tokio 0.2 105 | 2. trie树的管理 106 | 3. client的管理,新建连接,连接断开等. 
107 | 108 | 109 | 110 | https://github.com/nkbai/learnrustbynats 111 | 112 | https://github.com/nkbai/tokio/tree/readcode 113 | -------------------------------------------------------------------------------- /server/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "nats-server" 3 | version = "0.1.0" 4 | authors = ["bai "] 5 | edition = "2018" 6 | 7 | # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html 8 | 9 | [dependencies] 10 | log="0.4" 11 | env_logger="0.7" 12 | serde="1.0" 13 | serde_json="1.0" 14 | serde_derive = "1.0" 15 | tokio = { version = "0.2.0", path = "../../tokio",features = ["full"] } #{ version = "0.2", features = ["full"] } 16 | tokio-util ={ version = "0.2.0",path = "../../tokio-util", features = ["full"] } #{ version = "0.2", features = ["full"] } # 17 | rand="0.7" 18 | bitflags="1.0" 19 | lazy_static="1.0" 20 | get_if_addrs="0.5" 21 | futures = { version = "0.3.0", features = ["async-await"] } 22 | lru="0.4.3" 23 | bytes="0.5" 24 | jemallocator = "*" 25 | lru-cache="0.1.2" 26 | 27 | [dev-dependencies] 28 | tokio-test = { version = "0.2.0", path="../../tokio-test" } 29 | futures = { version = "0.3.0", features = ["async-await"] } -------------------------------------------------------------------------------- /server/src/client.rs: -------------------------------------------------------------------------------- 1 | use crate::error::*; 2 | use crate::parser::{ParseResult, Parser, PubArg, SubArg}; 3 | use crate::server::*; 4 | use crate::simple_sublist::{ArcSubResult, ArcSubscription, SubListTrait, Subscription}; 5 | use rand::{RngCore, SeedableRng}; 6 | use std::collections::{BTreeSet, HashMap}; 7 | use std::error::Error; 8 | use std::sync::Arc; 9 | use tokio::io::*; 10 | use tokio::net::TcpStream; 11 | use tokio::sync::Mutex; 12 | 13 | #[derive(Debug)] 14 | pub struct Client { 15 | pub srv: Arc>>, 16 | pub cid: u64, 17 | pub 
msg_sender: Arc>, 18 | } 19 | 20 | #[derive(Debug)] 21 | pub struct ClientMessageSender { 22 | writer: Option>, 23 | msg_buf: Option>, 24 | } 25 | impl ClientMessageSender { 26 | pub fn new(writer: WriteHalf) -> Self { 27 | Self { 28 | writer: Some(writer), 29 | msg_buf: Some(Vec::with_capacity(512)), 30 | } 31 | } 32 | async fn send_all(&mut self) -> std::io::Result<()> { 33 | if let Some(ref mut writer) = self.writer { 34 | let r = writer 35 | .write_all(self.msg_buf.as_ref().unwrap().as_slice()) 36 | .await; 37 | self.msg_buf.as_mut().unwrap().clear(); 38 | r 39 | } else { 40 | Ok(()) 41 | } 42 | } 43 | } 44 | #[derive(Debug, Clone)] 45 | pub struct ClientMessageSenderWrapper(Arc>, usize); 46 | impl std::cmp::PartialEq for ClientMessageSenderWrapper { 47 | fn eq(&self, other: &Self) -> bool { 48 | self.cmp(other) == Ordering::Equal 49 | } 50 | } 51 | /* 52 | 为了能够将ArcSubscription,必须实现下面这些Trait 53 | */ 54 | impl std::cmp::Eq for ClientMessageSenderWrapper {} 55 | impl std::cmp::PartialOrd for ClientMessageSenderWrapper { 56 | fn partial_cmp(&self, other: &Self) -> Option { 57 | Some(self.cmp(other)) 58 | } 59 | } 60 | impl std::cmp::Ord for ClientMessageSenderWrapper { 61 | fn cmp(&self, other: &Self) -> Ordering { 62 | self.1.cmp(&other.1) 63 | } 64 | } 65 | impl Client { 66 | pub fn process_connection( 67 | cid: u64, 68 | srv: Arc>>, 69 | conn: TcpStream, 70 | ) -> Arc> { 71 | let (reader, writer) = tokio::io::split(conn); 72 | let msg_sender = Arc::new(Mutex::new(ClientMessageSender::new(writer))); 73 | let c = Client { 74 | srv: srv, 75 | cid, 76 | msg_sender: msg_sender.clone(), 77 | }; 78 | tokio::spawn(async move { 79 | Client::client_task(c, reader).await; 80 | println!("client {} client_task quit", cid); 81 | }); 82 | msg_sender 83 | } 84 | async fn client_task(self, mut reader: ReadHalf) { 85 | let mut parser = Parser::new(); 86 | let mut count: i32 = 0; 87 | let mut subs = HashMap::new(); 88 | let mut buf = [0; 1024 * 64]; 89 | let mut rng = 
rand::rngs::StdRng::from_entropy(); 90 | let mut cache = HashMap::new(); 91 | let mut pendings = BTreeSet::new(); 92 | loop { 93 | // let mut buf: Vec = Vec::new(); 94 | // let r = tokio::io::copy(&mut reader, &mut buf).await; 95 | // match r { 96 | // Ok(r) => println!("recevied {} bytes", r), 97 | // Err(e) => println!("copy err: {}", e), 98 | // } 99 | // return; 100 | count += 1; 101 | let r = reader.read(&mut buf[..]).await; 102 | if r.is_err() { 103 | let e = r.unwrap_err(); 104 | self.process_error(e, subs).await; 105 | return; 106 | } 107 | let r = r.unwrap(); 108 | let n = r; 109 | if n == 0 { 110 | self.process_error(NError::new(ERROR_CONNECTION_CLOSED), subs) 111 | .await; 112 | return; 113 | } 114 | // pendings.clear(); 115 | let mut buf2 = &buf[0..n]; 116 | loop { 117 | let r = parser.parse(&buf2[..]); 118 | if r.is_err() { 119 | { 120 | let s = unsafe { std::str::from_utf8_unchecked(&buf2[..]) }; 121 | println!("parse err buf={}", s); 122 | } 123 | self.process_error(r.unwrap_err(), subs).await; 124 | return; 125 | } 126 | let (result, left) = r.unwrap(); 127 | 128 | match result { 129 | ParseResult::NoMsg => { 130 | break; 131 | } 132 | ParseResult::Sub(ref sub) => { 133 | if let Err(e) = self.process_sub(sub, &mut subs).await { 134 | self.process_error(e, subs).await; 135 | return; 136 | } 137 | } 138 | ParseResult::Pub(ref pub_arg) => { 139 | if let Err(e) = self 140 | .process_pub(pub_arg, &mut cache, &mut rng, &mut pendings) 141 | .await 142 | { 143 | self.process_error(e, subs).await; 144 | return; 145 | } 146 | parser.clear_msg_buf(); 147 | } 148 | } 149 | if left == buf2.len() { 150 | break; 151 | } 152 | buf2 = &buf2[left..]; 153 | } 154 | //批量处理发送 155 | for c in pendings.iter() { 156 | let c = c.clone(); 157 | tokio::spawn(async move { 158 | let mut sender = c.0.lock().await; 159 | if let Err(e) = sender.send_all().await { 160 | println!("send_all error {}", e); 161 | } 162 | }); 163 | } 164 | pendings.clear(); 165 | } 166 | } 167 | async fn 
process_error(&self, err: E, subs: HashMap) { 168 | println!("client {} process err {:?}", self.cid, err); 169 | { 170 | let sublist = &mut self.srv.lock().await.sublist; 171 | for (_, sub) in subs { 172 | if let Err(e) = sublist.remove(sub) { 173 | println!("client {} remove err {} ", self.cid, e); 174 | } 175 | } 176 | } 177 | let mut sender = self.msg_sender.lock().await; 178 | if let Some(mut writer) = sender.writer.take() { 179 | sender.msg_buf.take(); 180 | if let Err(e) = writer.shutdown().await { 181 | println!("shutdown err {:?}", e); 182 | } 183 | } 184 | } 185 | async fn process_sub( 186 | &self, 187 | sub: &SubArg<'_>, 188 | subs: &mut HashMap, 189 | ) -> crate::error::Result<()> { 190 | let sub = Subscription { 191 | subject: sub.subject.to_string(), 192 | queue: sub.queue.map(|q| q.to_string()), 193 | sid: sub.sid.to_string(), 194 | msg_sender: self.msg_sender.clone(), 195 | }; 196 | let sub = Arc::new(sub); 197 | subs.insert(sub.subject.clone(), sub.clone()); 198 | let sublist = &mut self.srv.lock().await.sublist; 199 | sublist.insert(sub)?; 200 | Ok(()) 201 | } 202 | async fn process_pub( 203 | &self, 204 | pub_arg: &PubArg<'_>, 205 | cache: &mut HashMap, 206 | rng: &mut rand::rngs::StdRng, 207 | pendings: &mut BTreeSet, 208 | ) -> crate::error::Result<()> { 209 | let sub_result = { 210 | if let Some(r) = cache.get(pub_arg.subject) { 211 | Arc::clone(r) 212 | } else { 213 | let sub_list = &mut self.srv.lock().await.sublist; 214 | let r = sub_list.match_subject(pub_arg.subject); 215 | cache.insert(pub_arg.subject.to_string(), Arc::clone(&r)); 216 | r 217 | } 218 | }; 219 | if sub_result.psubs.len() > 0 { 220 | for sub in sub_result.psubs.iter() { 221 | self.send_message(sub.as_ref(), pub_arg, pendings) 222 | .await 223 | .map_err(|e| { 224 | println!("send message error {}", e); 225 | NError::new(ERROR_CONNECTION_CLOSED) 226 | })?; 227 | } 228 | } 229 | if sub_result.qsubs.len() > 0 { 230 | //qsubs 要考虑负载均衡问题 231 | for qsubs in sub_result.qsubs.iter() 
{ 232 | let n = rng.next_u32(); 233 | let n = n as usize % qsubs.len(); 234 | let sub = qsubs.get(n).unwrap(); 235 | self.send_message(sub.as_ref(), pub_arg, pendings) 236 | .await 237 | .map_err(|_| NError::new(ERROR_CONNECTION_CLOSED))?; 238 | } 239 | } 240 | Ok(()) 241 | } 242 | ///Message format (the frame is only appended to the subscriber's buffer here, nothing is written to the socket; the sender is recorded in pendings) 243 | ///``` 244 | /// MSG subject sid size\r\n 245 | /// payload\r\n 246 | /// ``` 247 | async fn send_message( 248 | &self, 249 | sub: &Subscription, 250 | pub_arg: &PubArg<'_>, 251 | pendings: &mut BTreeSet, 252 | ) -> std::io::Result<()> { 253 | let mut msg_sender = sub.msg_sender.lock().await; 254 | if let Some(ref mut msg_buf) = msg_sender.msg_buf { 255 | /* id is the address of the locked ClientMessageSender; it is stored next to the sender handle in pendings. NOTE(review): the msg_buf bound by the if-let above is not used, the buffer is borrowed again on the next line. */ let id = msg_sender.deref() as *const ClientMessageSender as usize; 256 | let msg_buf = msg_sender.msg_buf.as_mut().unwrap(); 257 | 258 | msg_buf.extend_from_slice("MSG ".as_bytes()); 259 | msg_buf.extend_from_slice(sub.subject.as_bytes()); 260 | msg_buf.extend_from_slice(" ".as_bytes()); 261 | msg_buf.extend_from_slice(sub.sid.as_bytes()); 262 | msg_buf.extend_from_slice(" ".as_bytes()); 263 | msg_buf.extend_from_slice(pub_arg.size_buf.as_bytes()); 264 | msg_buf.extend_from_slice("\r\n".as_bytes()); 265 | msg_buf.extend_from_slice(pub_arg.msg); //Testing showed that without this buffer, i.e. with several awaits here instead, performance drops sharply. 266 | msg_buf.extend_from_slice("\r\n".as_bytes()); 267 | pendings.insert(ClientMessageSenderWrapper(sub.msg_sender.clone(), id)); 268 | } 269 | Ok(()) 270 | } 271 | /* async fn send_message2(sub: Arc, msg: Arc>) -> std::io::Result<()> { 272 | let mut msg_sender = sub.msg_sender.lock().await; 273 | let msg_buf = msg_sender.msg_buf.take().expect("must have"); 274 | let writer = &mut msg_sender.writer; 275 | let mut buf = msg_buf.writer(); 276 | buf.write("MSG ".as_bytes())?; 277 | buf.write(sub.subject.as_bytes())?; 278 | buf.write(" ".as_bytes())?; 279 | buf.write(sub.sid.as_bytes())?; 280 | buf.write(" ".as_bytes())?; 281 | write!(buf, "{}", msg.len())?; 282 | buf.write("\r\n".as_bytes())?; 283 | buf.write(msg.as_slice())?; //Testing showed that without this buffer, i.e. with several awaits here instead, performance drops sharply.
284 | buf.write("\r\n".as_bytes())?; 285 | let mut msg_buf = buf.into_inner(); 286 | writer.write_all(msg_buf.bytes()).await?; 287 | // writer.flush().await?; no flush needed for now, because no BufWriter is used 288 | msg_buf.clear(); 289 | msg_sender.msg_buf = Some(msg_buf); 290 | Ok(()) 291 | }*/ 292 | } 293 | /* Test support: lazily creates one shared ClientMessageSender whose writer is the write half of a real loopback TCP connection (listener bound to 127.0.0.1:0 and connected from the same process). The stream is converted and split inside a task spawned on a tokio runtime and handed back over a std mpsc channel. */ #[cfg(test)] 294 | pub mod test_helper { 295 | use super::*; 296 | use lazy_static::lazy_static; 297 | lazy_static! { 298 | static ref SENDER: Arc> = { 299 | let l = std::net::TcpListener::bind("127.0.0.1:0").unwrap(); 300 | // let port = l.local_addr().unwrap().port(); 301 | let conn = std::net::TcpStream::connect(l.local_addr().unwrap()).unwrap(); 302 | let (tx, rx) = std::sync::mpsc::channel(); 303 | 304 | let rt: tokio::runtime::Runtime = tokio::runtime::Builder::new() 305 | .threaded_scheduler() 306 | .enable_all() 307 | .build() 308 | .unwrap(); 309 | 310 | rt.spawn(async move { 311 | println!("tokio spawned"); 312 | let conn = tokio::net::TcpStream::from_std(conn).unwrap(); 313 | let (_, writer) = tokio::io::split(conn); 314 | println!("send start"); 315 | let _=tx.send(writer); 316 | println!("send complete") 317 | }); 318 | let writer = rx.recv().unwrap(); 319 | Arc::new(Mutex::new(ClientMessageSender::new(writer))) 320 | }; 321 | } 322 | /* Returns a clone of the shared SENDER handle, so every caller gets the same underlying sender. */ #[cfg(test)] 323 | pub fn new_test_tcp_writer() -> Arc> { 324 | SENDER.clone() 325 | } 326 | } 327 | use std::cmp::Ordering; 328 | use std::ops::Deref; 329 | #[cfg(test)] 330 | pub use test_helper::new_test_tcp_writer; 331 | 332 | #[cfg(test)] 333 | mod tests { 334 | use super::*; 335 | extern crate test; 336 | use std::io::Write; 337 | use test::Bencher; 338 | 339 | #[test] 340 | fn test() {} 341 | #[test] 342 | fn test_rng() { 343 | for _ in 0..10 { 344 | let mut r = rand::rngs::StdRng::from_entropy(); 345 | println!("next={}", r.next_u32()); 346 | } 347 | } 348 | /* Checks that Vec::clear keeps the allocated capacity. */ #[test] 349 | fn test_bytes() { 350 | let mut buf = Vec::with_capacity(100); 351 | buf.extend_from_slice("hello".as_bytes()); 352 | assert_eq!(buf.len(), 5); 353 |
assert_eq!(buf.capacity(), 100); 354 | buf.clear(); 355 | assert_eq!(buf.capacity(), 100); 356 | assert_eq!(buf.len(), 0); 357 | } 358 | /* Measures the cost of seeding a StdRng from entropy. */ #[bench] 359 | fn bench_gen_rng(b: &mut Bencher) { 360 | b.iter(|| { 361 | let r = rand::rngs::StdRng::from_entropy(); 362 | drop(r); 363 | }); 364 | } 365 | } 366 | -------------------------------------------------------------------------------- /server/src/error.rs: -------------------------------------------------------------------------------- 1 | use std::error::Error; 2 | use std::fmt::{Display, Formatter}; 3 | pub type Result = std::result::Result; 4 | /* Numeric error codes stored in NError::err_code. */ pub const ERROR_PARSE: i32 = 1; 5 | pub const ERROR_MESSAGE_SIZE_TOO_LARGE: i32 = 2; 6 | pub const ERROR_INVALID_SUBJECT: i32 = 3; 7 | pub const ERROR_SUBSCRIBTION_NOT_FOUND: i32 = 4; 8 | pub const ERROR_CONNECTION_CLOSED: i32 = 5; 9 | //pub const ERROR_UNKOWN_ERROR: i32 = 1000; 10 | /* Error value of the server: it carries nothing but one of the numeric codes above. */ #[derive(Debug)] 11 | pub struct NError { 12 | pub err_code: i32, 13 | } 14 | impl NError { 15 | pub fn new(err_code: i32) -> Self { 16 | Self { err_code } 17 | } 18 | /* Text for the stored code; only ERROR_PARSE has a text of its own, every other code yields "unkown error". */ pub fn error_description(&self) -> &'static str { 19 | match self.err_code { 20 | ERROR_PARSE => return "parse error", 21 | _ => return "unkown error", 22 | } 23 | } 24 | } 25 | impl Error for NError {} 26 | /* Formats as NError[code,description]. */ impl Display for NError { 27 | fn fmt(&self, f: &mut Formatter<'_>) -> std::result::Result<(), std::fmt::Error> { 28 | write!(f, "NError[{},{}]", self.err_code, self.error_description()) 29 | } 30 | } 31 | #[cfg(test)] 32 | mod tests { 33 | use super::*; 34 | #[test] 35 | fn test() { 36 | println!("{}", NError::new(ERROR_PARSE)); 37 | } 38 | } 39 | -------------------------------------------------------------------------------- /server/src/main.rs: -------------------------------------------------------------------------------- 1 | #![feature(test)] 2 | #![feature(hash_raw_entry)] 3 | 4 | use crate::server::Server; 5 | use crate::sublist::TrieSubList; 6 | use jemallocator::Jemalloc; 7 | use std::error::Error; 8 | 9 |
#[global_allocator] 10 | static GLOBAL: Jemalloc = Jemalloc; 11 | 12 | mod client; 13 | mod error; 14 | mod parser; 15 | mod server; 16 | mod simple_sublist; 17 | mod sublist; 18 | #[tokio::main] 19 | async fn main() -> Result<(), Box> { 20 | println!("server start.."); 21 | let s: Server = Server::default(); 22 | s.start().await 23 | } 24 | -------------------------------------------------------------------------------- /server/src/parser.rs: -------------------------------------------------------------------------------- 1 | /** 2 | ## pub 3 | ``` 4 | PUB subject size\r\n 5 | payload\r\n 6 | ``` 7 | ## sub 8 | ``` 9 | SUB subject sid\r\n 10 | SUB subject queue sid\r\n 11 | ``` 12 | ## MSG 13 | ``` 14 | MSG subject sid size\r\n 15 | payload\r\n 16 | ``` 17 | */ 18 | use crate::error::*; 19 | #[macro_export] 20 | macro_rules! parse_error { 21 | ( ) => {{ 22 | // panic!("parse error"); 23 | return Err(NError::new(ERROR_PARSE)); 24 | }}; 25 | } 26 | 27 | #[derive(Debug, Clone)] 28 | enum ParseState { 29 | OpStart, 30 | OpS, 31 | OpSu, 32 | OpSub, 33 | OPSubSpace, 34 | OpSubArg, 35 | OpP, 36 | OpPu, 37 | OpPub, //pub argument 38 | OpPubSpace, 39 | OpPubArg, 40 | OpMsg, //pub message 41 | OpMsgFull, 42 | } 43 | #[derive(Debug, PartialEq)] 44 | pub struct SubArg<'a> { 45 | pub subject: &'a str, //a str rather than a String, precisely to avoid a memory allocation 46 | pub sid: &'a str, 47 | pub queue: Option<&'a str>, 48 | } 49 | #[derive(Debug, PartialEq)] 50 | pub struct PubArg<'a> { 51 | pub subject: &'a str, 52 | pub size_buf: &'a str, // e.g. 1024 in string form, so it need not be converted again later 53 | pub size: usize, //e.g. 1024 in integer form 54 | pub msg: &'a [u8], 55 | } 56 | #[derive(Debug, PartialEq)] 57 | pub enum ParseResult<'a> { 58 | NoMsg, //buf="sub top.stevenbai.blog": the sub message is incomplete, so it cannot be processed yet 59 | Sub(SubArg<'a>), 60 | Pub(PubArg<'a>), 61 | } 62 | /* 63 | This length matters: a complete subject together with its arguments must fit into it, 64 | so the length of the subject has to be limited 65 | */ 66 | const BUF_LEN: usize = 512; 67 | pub struct Parser { 68 | state: ParseState, 69 | buf: [u8; BUF_LEN], //parse buffer; a message of at most 512 bytes uses it directly, a larger one needs a separate allocation 70 | arg_len: usize, 71 | msg_buf: Option>, 72 | //While a new message is being received during parsing,
新消息的总长度是msg_total_len,已收到部分应该是msg_len /* English for the comment text just before this note: msg_total_len is the total length of the message being received and msg_len the part received so far. */ 73 | msg_total_len: usize, 74 | msg_len: usize, 75 | debug: bool, 76 | } 77 | 78 | impl Parser { 79 | /* Creates a parser in the OpStart state with empty buffers and debug output off. */ pub fn new() -> Self { 80 | Self { 81 | state: ParseState::OpStart, 82 | buf: [0; BUF_LEN], 83 | arg_len: 0, 84 | msg_buf: None, 85 | msg_total_len: 0, 86 | msg_len: 0, 87 | debug: false, 88 | } 89 | } 90 | /** 91 | Parses the received byte sequence; once a frame is complete it yields a pub or sub message, 92 | it is also possible that there is no message yet, or that further messages remain in the buffer. The result is returned together with the number of bytes consumed from buf, and the parser state stays in self between calls. 93 | */ 94 | pub fn parse(&mut self, buf: &[u8]) -> Result<(ParseResult, usize)> { 95 | let mut b; 96 | let mut i = 0; 97 | if self.debug { 98 | println!( 99 | "parse string:{},state={:?}", 100 | unsafe { std::str::from_utf8_unchecked(buf) }, 101 | self.state 102 | ); 103 | } 104 | while i < buf.len() { 105 | use ParseState::*; 106 | b = buf[i] as char; 107 | // println!("state={:?},b={}", self.state, b); 108 | match self.state { 109 | OpStart => match b { 110 | 'S' => self.state = OpS, 111 | 'P' => self.state = OpP, 112 | _ => parse_error!(), 113 | }, 114 | OpS => match b { 115 | 'U' => self.state = OpSu, 116 | _ => parse_error!(), 117 | }, 118 | OpSu => match b { 119 | 'B' => self.state = OpSub, 120 | _ => parse_error!(), 121 | }, 122 | OpSub => match b { 123 | //"sub stevenbai.top 3" is ok, but "substevenbai.top 3" is not allowed 124 | ' ' | '\t' => self.state = OPSubSpace, 125 | _ => parse_error!(), 126 | }, 127 | OPSubSpace => match b { 128 | ' ' | '\t' => {} 129 | _ => { 130 | self.state = OpSubArg; 131 | self.arg_len = 0; 132 | continue; 133 | } 134 | }, 135 | OpSubArg => match b { 136 | '\r' => {} 137 | '\n' => { 138 | self.state = OpStart; 139 | let r = self.process_sub()?; 140 | return Ok((r, i + 1)); 141 | } 142 | _ => { 143 | self.add_arg(b as u8)?; 144 | } 145 | }, 146 | OpP => match b { 147 | 'U' => self.state = OpPu, 148 | _ => parse_error!(), 149 | }, 150 | OpPu => match b { 151 | 'B' => self.state = OpPub, 152 | _ => parse_error!(), 153 | }, 154 | OpPub => match b { 155 | ' ' | '\t' => self.state = OpPubSpace, 156 | _ => parse_error!(), 157 | }, 158 |
OpPubSpace => match b { 159 | ' ' | '\t' => {} 160 | _ => { 161 | self.state = OpPubArg; 162 | self.arg_len = 0; 163 | continue; 164 | } 165 | }, 166 | OpPubArg => match b { 167 | '\r' => {} 168 | '\n' => { 169 | //PUB top.stevenbai 5\r\n 170 | self.state = OpMsg; 171 | let size = self.get_message_size()?; 172 | if size == 0 || size > 1 * 1024 * 1024 { 173 | //the body must not exceed 1M, to guard against DoS attacks (a size of 0 is rejected with the same error code) 174 | return Err(NError::new(ERROR_MESSAGE_SIZE_TOO_LARGE)); 175 | } 176 | if size + self.arg_len > BUF_LEN { 177 | if self.msg_buf.is_none() { 178 | self.msg_buf = Some(Vec::with_capacity(size)); 179 | } 180 | } 181 | self.msg_total_len = size; 182 | } 183 | _ => { 184 | self.add_arg(b as u8)?; 185 | } 186 | }, 187 | OpMsg => { 188 | //depends on the message length 189 | if self.msg_len < self.msg_total_len { 190 | self.add_msg(b as u8); 191 | } else { 192 | self.state = OpMsgFull; 193 | } 194 | } 195 | OpMsgFull => match b { 196 | '\r' => {} 197 | '\n' => { 198 | self.state = OpStart; 199 | let r = self.process_msg()?; 200 | return Ok((r, i + 1)); 201 | } 202 | _ => { 203 | parse_error!(); 204 | } 205 | }, 206 | // _ => panic!("unkown state {:?}", self.state), 207 | } 208 | i += 1; 209 | } 210 | Ok((ParseResult::NoMsg, buf.len())) 211 | } 212 | //Case 1: the body is short and fits directly in buf, no extra allocation needed 213 | //Case 2: the body is long and does not fit in buf, so msg_buf was allocated for it 214 | fn add_msg(&mut self, b: u8) { 215 | if let Some(buf) = self.msg_buf.as_mut() { 216 | buf.push(b); 217 | } else { 218 | //the short-body case 219 | if self.arg_len + self.msg_total_len > BUF_LEN { 220 | panic!("message should allocate space"); 221 | } 222 | self.buf[self.arg_len + self.msg_len] = b; 223 | } 224 | self.msg_len += 1; 225 | } 226 | fn add_arg(&mut self, b: u8) -> Result<()> { 227 | //subject too long 228 | if self.arg_len >= self.buf.len() { 229 | parse_error!(); 230 | } 231 | self.buf[self.arg_len] = b; 232 | self.arg_len += 1; 233 | Ok(()) 234 | } 235 | //Parses buffer content of the form: stevenbai.top queue 3 236 | fn process_sub(&self) -> Result { 237 | let buf = &self.buf[0..self.arg_len]; 238 |
//A client may maliciously send bytes that are not valid utf8, which would cause errors here. NOTE(review): from_utf8_unchecked on unvalidated input is undefined behaviour for non-utf8 bytes - consider std::str::from_utf8 and a parse error instead. 239 | let ss = unsafe { std::str::from_utf8_unchecked(buf) }; 240 | let mut arg_buf = [""; 3]; //without a queue the length is 2, otherwise it is 3 241 | let mut arg_len = 0; 242 | for s in ss.split(' ') { 243 | if s.len() == 0 { 244 | continue; 245 | } 246 | if arg_len >= 3 { 247 | parse_error!(); 248 | } 249 | arg_buf[arg_len] = s; 250 | arg_len += 1; 251 | } 252 | let mut sub_arg = SubArg { 253 | subject: "", 254 | sid: "", 255 | queue: None, 256 | }; 257 | sub_arg.subject = arg_buf[0]; 258 | //length 2 means no queue, length 3 means a queue is present, anything else is a format error 259 | match arg_len { 260 | 2 => { 261 | sub_arg.sid = arg_buf[1]; 262 | } 263 | 3 => { 264 | sub_arg.sid = arg_buf[2]; 265 | sub_arg.queue = Some(arg_buf[1]); 266 | } 267 | _ => parse_error!(), 268 | } 269 | Ok(ParseResult::Sub(sub_arg)) 270 | } 271 | //Parses what is in the buffer and in msg_buf, of the form: stevenbai.top 5hello 272 | fn process_msg(&self) -> Result { 273 | let msg = if self.msg_buf.is_some() { 274 | self.msg_buf.as_ref().unwrap().as_slice() 275 | } else { 276 | &self.buf[self.arg_len..self.arg_len + self.msg_total_len] 277 | }; 278 | let mut arg_buf = [""; 2]; 279 | let mut arg_len = 0; 280 | let ss = unsafe { std::str::from_utf8_unchecked(&self.buf[0..self.arg_len]) }; 281 | /* NOTE(review): the arguments are split on the space character only, here and in process_sub, so a tab between arguments is not treated as a separator. */ for s in ss.split(' ') { 282 | if s.len() == 0 { 283 | continue; 284 | } 285 | if arg_len >= 2 { 286 | parse_error!() 287 | } 288 | arg_buf[arg_len] = s; 289 | arg_len += 1; 290 | } 291 | let pub_arg = PubArg { 292 | subject: arg_buf[0], 293 | size_buf: arg_buf[1], 294 | size: self.msg_total_len, 295 | msg, 296 | }; 297 | Ok(ParseResult::Pub(pub_arg)) 298 | } 299 | /* Resets the message bookkeeping after a Pub result has been handled: empties msg_buf (keeping the allocation) and zeroes both length counters. */ pub fn clear_msg_buf(&mut self) { 300 | //self.msg_buf = None; 301 | if let Some(ref mut v) = self.msg_buf { 302 | v.clear(); 303 | } 304 | self.msg_len = 0; 305 | self.msg_total_len = 0; 306 | } 307 | //Parses the message length ahead of time from the received pub arguments 308 | fn get_message_size(&self) -> Result { 309 | //the buffer holds something like: top.stevenbai.top 5 310 | let arg_buf = &self.buf[0..self.arg_len]; 311 | let pos = arg_buf 312 | .iter() 313 | .rev() 314 |
.position(|b| *b == ' ' as u8 || *b == '\t' as u8); 315 | if pos.is_none() { 316 | parse_error!(); 317 | } 318 | let pos = pos.unwrap(); 319 | let size_buf = &arg_buf[arg_buf.len() - pos..]; 320 | let szb = unsafe { std::str::from_utf8_unchecked(size_buf) }; 321 | szb.parse::().map_err(|_| NError::new(ERROR_PARSE)) 322 | } 323 | /* Returns an iterator that keeps calling parse on buf and advances past the consumed bytes after every successful call. */ pub fn iter<'a>(&'a mut self, buf: &'a [u8]) -> ParseIter<'a> { 324 | ParseIter { parser: self, buf } 325 | } 326 | } 327 | pub struct ParseIter<'a> { 328 | parser: *mut Parser, 329 | buf: &'a [u8], 330 | } 331 | 332 | impl<'a> Iterator for ParseIter<'a> { 333 | type Item = Result>; 334 | fn next(&mut self) -> Option { 335 | if self.buf.len() == 0 { 336 | return None; 337 | } 338 | /* 339 | For outside users of this type, the use of unsafe here is safe. 340 | First, the lifetime of ParseIter<'a> is necessarily shorter than that of self.parser, i.e. the parser pointer is always valid. 341 | Second, a ParseIter can only be constructed through Parser.iter, so parser is necessarily mutable 342 | so there is no memory-safety problem. NOTE(review): the items handed out borrow the parser's internal buffer, which a later call to next writes to again - confirm that callers never keep an item across calls. 343 | */ 344 | let parser = unsafe { &mut *self.parser }; 345 | let r: Result<(ParseResult<'a>, usize)> = parser.parse(self.buf); 346 | 347 | return Some(r.map(|r| { 348 | self.buf = &self.buf[r.1..]; 349 | r.0 350 | })); 351 | } 352 | } 353 | #[cfg(test)] 354 | mod tests { 355 | use super::*; 356 | #[test] 357 | fn test() {} 358 | #[test] 359 | fn test_get_message_size() { 360 | let mut p = Parser::new(); 361 | let buf = "subject 5".as_bytes(); 362 | p.buf[0..buf.len()].copy_from_slice(buf); 363 | p.arg_len = buf.len(); 364 | let r = p.get_message_size(); 365 | assert!(r.is_ok()); 366 | let r = r.unwrap(); 367 | assert!(r == 5); 368 | } 369 | #[test] 370 | fn test_process_sub() { 371 | let mut p = Parser::new(); 372 | let buf = "subject 5".as_bytes(); 373 | p.buf[0..buf.len()].copy_from_slice(buf); 374 | p.arg_len = buf.len(); 375 | let r = p.process_sub(); 376 | assert!(r.is_ok()); 377 | let r = r.unwrap(); 378 | if let ParseResult::Sub(sub) = r { 379 | assert_eq!(sub.subject, "subject"); 380 | assert_eq!(sub.sid, "5"); 381 | assert!(sub.queue.is_none()); 382 | } else { 383 |
assert!(false, "unkown error"); 384 | } 385 | //the case that includes a queue 386 | let buf = "subject queue 5".as_bytes(); 387 | p.buf[0..buf.len()].copy_from_slice(buf); 388 | p.arg_len = buf.len(); 389 | let r = p.process_sub(); 390 | assert!(r.is_ok()); 391 | let r = r.unwrap(); 392 | if let ParseResult::Sub(sub) = r { 393 | assert_eq!(sub.subject, "subject"); 394 | assert_eq!(sub.sid, "5"); 395 | assert_eq!(sub.queue.as_ref().unwrap(), &"queue"); 396 | } else { 397 | assert!(false, "unkown error"); 398 | } 399 | } 400 | /* Exercises process_msg twice: first with the payload stored in p.buf right after the arguments, then with the payload held in msg_buf. */ #[test] 401 | fn test_process_pub() { 402 | let mut p = Parser::new(); 403 | 404 | let buf = "subject 5hello".as_bytes(); 405 | p.buf[0..buf.len()].copy_from_slice(buf); 406 | p.arg_len = buf.len() - 5; 407 | p.msg_total_len = 5; 408 | p.msg_len = 5; 409 | let r = p.process_msg(); 410 | assert!(r.is_ok()); 411 | let r = r.unwrap(); 412 | if let ParseResult::Pub(pub_arg) = r { 413 | assert_eq!(pub_arg.subject, "subject"); 414 | assert_eq!(pub_arg.size_buf, "5"); 415 | assert_eq!(pub_arg.size, 5); 416 | assert_eq!(pub_arg.msg, "hello".as_bytes()); 417 | } else { 418 | assert!(false, "unkown error"); 419 | } 420 | let buf = "subject 5".as_bytes(); 421 | p.buf[0..buf.len()].copy_from_slice(buf); 422 | p.arg_len = buf.len(); 423 | p.msg_buf = Some(Vec::from("hello".as_bytes())); 424 | p.msg_total_len = 5; 425 | p.msg_len = 5; 426 | let r = p.process_msg(); 427 | assert!(r.is_ok()); 428 | let r = r.unwrap(); 429 | if let ParseResult::Pub(pub_arg) = r { 430 | assert_eq!(pub_arg.subject, "subject"); 431 | assert_eq!(pub_arg.size_buf, "5"); 432 | assert_eq!(pub_arg.size, 5); 433 | assert_eq!(pub_arg.msg, "hello".as_bytes()); 434 | } else { 435 | assert!(false, "unkown error"); 436 | } 437 | } 438 | /* Checks that invalid input is rejected and that one complete PUB frame is parsed in a single call. */ #[test] 439 | fn test_pub() { 440 | let mut p = Parser::new(); 441 | assert!(p.parse("aa".as_bytes()).is_err()); 442 | let buf = "PUB subject 5\r\nhello\r\n".as_bytes(); 443 | let r = p.parse(buf); 444 | println!("r={:?}", r); 445 | assert!(r.is_ok()); 446 | let r = r.unwrap();
447 | assert_eq!(r.1, buf.len()); 448 | match r.0 { 449 | ParseResult::Pub(p) => { 450 | assert_eq!(p.subject, "subject"); 451 | assert_eq!(p.size, 5); 452 | assert_eq!(p.size_buf, "5"); 453 | assert_eq!(p.msg, "hello".as_bytes()); 454 | } 455 | _ => assert!(false, "must be valid pub arg "), 456 | } 457 | } 458 | /* Feeds one complete PUB frame followed by a truncated one and expects every parse call to succeed with Pub or NoMsg, never with an error. */ #[test] 459 | fn test_pub2() { 460 | let mut p = Parser::new(); 461 | let mut buf = "PUB subject 5\r\nhello\r\nPUB subject 5\r\nhe".as_bytes(); 462 | loop { 463 | unsafe { 464 | let s = std::str::from_utf8_unchecked(buf); 465 | println!("buf={}", s); 466 | } 467 | let r = p.parse(buf); 468 | println!("r={:?}", r); 469 | assert!(!r.is_err()); 470 | let r = r.unwrap(); 471 | buf = &buf[r.1..]; 472 | println!("r.0={:?}", r.0); 473 | match r.0 { 474 | ParseResult::Pub(pub_arg) => { 475 | println!("pub_arg.subject={}", pub_arg.subject); 476 | p.clear_msg_buf(); 477 | } 478 | ParseResult::NoMsg => {} 479 | _ => panic!(), 480 | } 481 | if buf.len() == 0 { 482 | break; 483 | } 484 | } 485 | } 486 | /* Parses a SUB without and then with a queue name and checks subject, sid, queue and the number of consumed bytes. */ #[test] 487 | fn test_sub() { 488 | let mut p = Parser::new(); 489 | let buf = "SUB subject 1\r\n".as_bytes(); 490 | let r = p.parse(buf); 491 | assert!(r.is_ok()); 492 | println!("r={:?}", r); 493 | let r = r.unwrap(); 494 | assert_eq!(r.1, buf.len()); 495 | if let ParseResult::Sub(sub) = r.0 { 496 | assert_eq!(sub.subject, "subject"); 497 | assert_eq!(sub.sid, "1"); 498 | assert_eq!(sub.queue, None); 499 | } else { 500 | assert!(false, "unkown error"); 501 | } 502 | 503 | let buf = "SUB subject queue 1\r\n".as_bytes(); 504 | let r = p.parse(buf); 505 | println!("r={:?}", r); 506 | assert!(r.is_ok()); 507 | let r = r.unwrap(); 508 | assert_eq!(r.1, buf.len()); 509 | if let ParseResult::Sub(sub) = r.0 { 510 | assert_eq!(sub.subject, "subject"); 511 | assert_eq!(sub.sid, "1"); 512 | assert_eq!(sub.queue, Some("queue")); 513 | } else { 514 | assert!(false, "unkown error"); 515 | } 516 | } 517 | /* Parses two SUB frames that arrive in one buffer, one parse call per frame. */ #[test] 518 | fn test_sub2() { 519 | let mut p = Parser::new(); 520 | let mut
buf = "SUB subject 1\r\nSUB subject2 2\r\n".as_bytes(); 521 | loop { 522 | let r = p.parse(buf); 523 | assert!(!r.is_err()); 524 | let r = r.unwrap(); 525 | buf = &buf[r.1..]; 526 | match r.0 { 527 | ParseResult::Sub(sub) => { 528 | println!("sub.subect={}", sub.subject); 529 | } 530 | _ => panic!(), 531 | } 532 | if buf.len() == 0 { 533 | break; 534 | } 535 | } 536 | } 537 | /* Two SUB frames in one buffer, this time driven through Parser::iter. */ #[test] 538 | fn test_sub3() { 539 | let mut p = Parser::new(); 540 | let buf = "SUB subject 1\r\nSUB subject2 2\r\n".as_bytes(); 541 | for r in p.iter(buf) { 542 | assert!(!r.is_err()); 543 | let r = r.unwrap(); 544 | match r { 545 | ParseResult::Sub(sub) => { 546 | println!("sub.subect={}", sub.subject); 547 | } 548 | ParseResult::NoMsg => { 549 | break; 550 | } 551 | _ => panic!(), 552 | } 553 | } 554 | } 555 | #[test] 556 | fn test_sub4() { 557 | /* 558 | This way of implementing iter has a problem: because buf lives inside the Parser, 559 | the next result overwrites the previous one 560 | */ 561 | let mut p = Parser::new(); 562 | let buf = "SUB subject 1\r\nSUB xxxx2 2\r\n".as_bytes(); 563 | let mut it = p.iter(buf); 564 | let v1 = it.next(); 565 | let v2 = it.next(); 566 | println!("v1={:?}", v1); 567 | println!("v2={:?}", v2); 568 | } 569 | // #[test] 570 | // fn test_sub5() { 571 | // struct P { 572 | // buf: Option>, 573 | // } 574 | // let mut p = P { 575 | // buf: Some(Vec::from("hello,world")), 576 | // }; 577 | // let ss = p.buf.as_ref().unwrap().as_slice(); 578 | // println!("ss={:?}", ss); 579 | // p.buf.take(); 580 | // println!("ss={:?}", ss); 581 | // } 582 | /* An incomplete SUB line must yield NoMsg rather than an error. */ #[test] 583 | fn test_no_msg() { 584 | let mut p = Parser::new(); 585 | let buf = "SUB subject".as_bytes(); 586 | let r = p.parse(buf); 587 | assert!(r.is_ok()); 588 | println!("r={:?}", r); 589 | let r = r.unwrap(); 590 | assert_eq!(r.0, ParseResult::NoMsg); 591 | } 592 | } 593 | -------------------------------------------------------------------------------- /server/src/server.rs: -------------------------------------------------------------------------------- 1 | use crate::client::*; 2 |
use crate::simple_sublist::SubListTrait; 3 | use std::collections::HashMap; 4 | use std::error::Error; 5 | use std::sync::Arc; 6 | use tokio::net::{TcpListener, TcpStream}; 7 | use tokio::sync::Mutex; 8 | 9 | #[derive(Debug, Default)] 10 | pub struct Server { 11 | state: Arc>>, 12 | } 13 | #[derive(Debug, Default)] 14 | pub struct ServerState { 15 | clients: HashMap>>, 16 | pub sublist: T, 17 | pub gen_cid: u64, 18 | } 19 | 20 | impl Server { 21 | /* Binds 0.0.0.0:4222 and accepts connections forever, handing each one to new_client; a bind or accept error ends the loop and is returned. */ pub async fn start(self) -> Result<(), Box> { 22 | let addr = "0.0.0.0:4222"; 23 | let mut listener = TcpListener::bind(addr).await?; 24 | //go func(){} 25 | loop { 26 | let (conn, _) = listener.accept().await?; 27 | self.new_client(conn).await; 28 | } 29 | } 30 | /* Allocates the next client id by incrementing gen_cid under the state lock, then passes the connection to Client::process_connection. The returned value is bound to _c and not stored, since the insert into clients below is commented out. */ async fn new_client(&self, conn: TcpStream) { 31 | let state = self.state.clone(); 32 | let cid = { 33 | let mut state = state.lock().await; 34 | state.gen_cid += 1; 35 | state.gen_cid 36 | }; 37 | let _c = Client::process_connection(cid, state, conn); 38 | // self.state.lock().await.clients.insert(cid, c); 39 | } 40 | } 41 | 42 | #[cfg(test)] 43 | mod tests { 44 | #[test] 45 | fn test() {} 46 | } 47 | -------------------------------------------------------------------------------- /server/src/simple_sublist.rs: -------------------------------------------------------------------------------- 1 | use crate::client::ClientMessageSender; 2 | use crate::error::{NError, Result, ERROR_SUBSCRIBTION_NOT_FOUND}; 3 | use bitflags::_core::cmp::Ordering; 4 | use std::collections::{BTreeSet, HashMap}; 5 | use std::sync::Arc; 6 | use tokio::sync::Mutex; 7 | 8 | /** 9 | For ease of explanation, and because the Trie and the Cache implementations are both rather fiddly, 10 | a simple subscription lookup is implemented here on purpose; it does not support the two wildcard matches * and >. 11 | That makes it a plain string lookup, for which a map is enough.
12 | 但是为了后续的扩展性呢,我会定义SubListTrait,这样方便后续实现Trie树 13 | */ /* English for the comment line just above: for later extensibility SubListTrait is defined, which makes it easy to add a Trie implementation afterwards. */ 14 | #[derive(Debug)] 15 | pub struct Subscription { 16 | pub msg_sender: Arc>, 17 | pub subject: String, 18 | pub queue: Option, 19 | pub sid: String, 20 | } 21 | impl Subscription { 22 | /* Builds a Subscription with owned copies of subject, queue and sid. */ pub fn new( 23 | subject: &str, 24 | queue: Option<&str>, 25 | sid: &str, 26 | msg_sender: Arc>, 27 | ) -> Self { 28 | Self { 29 | subject: subject.to_string(), 30 | queue: queue.map(|s| s.to_string()), 31 | sid: sid.to_string(), 32 | msg_sender, 33 | } 34 | } 35 | } 36 | /* Result of a subject lookup: psubs holds the plain subscribers, qsubs one Vec per queue group. */ #[derive(Debug, Default)] 37 | pub struct SubResult { 38 | pub psubs: Vec, 39 | pub qsubs: Vec>, 40 | } 41 | impl SubResult { 42 | pub(crate) fn new() -> Self { 43 | Self { 44 | qsubs: Vec::new(), 45 | psubs: Vec::new(), 46 | } 47 | } 48 | } 49 | impl SubResult { 50 | fn is_empty(&self) -> bool { 51 | self.psubs.len() == 0 && self.qsubs.len() == 0 52 | } 53 | } 54 | pub type ArcSubscription = Arc; 55 | /* 56 | Because of the orphan rule, a separate type has to be defined for ArcSubscription 57 | */ 58 | #[derive(Debug, Clone)] 59 | pub(crate) struct ArcSubscriptionWrapper(pub ArcSubscription); 60 | impl std::cmp::PartialEq for ArcSubscriptionWrapper { 61 | fn eq(&self, other: &Self) -> bool { 62 | self.cmp(other) == Ordering::Equal 63 | } 64 | } 65 | /* 66 | For ArcSubscription to be usable in an ordered collection (presumably the BTreeSet imported above), the traits below must be implemented 67 | 68 | */ 69 | impl std::cmp::Eq for ArcSubscriptionWrapper {} 70 | impl std::cmp::PartialOrd for ArcSubscriptionWrapper { 71 | fn partial_cmp(&self, other: &Self) -> Option { 72 | Some(self.cmp(other)) 73 | } 74 | } 75 | /* Orders wrappers by the address of the Subscription they point to, so two wrappers compare equal only when they share the same allocation. */ impl std::cmp::Ord for ArcSubscriptionWrapper { 76 | fn cmp(&self, other: &Self) -> Ordering { 77 | let a = self.0.as_ref() as *const Subscription as usize; 78 | let b = other.0.as_ref() as *const Subscription as usize; 79 | a.cmp(&b) 80 | } 81 | } 82 | pub type ArcSubResult = Arc; 83 | pub trait SubListTrait { 84 | fn insert(&mut self, sub: ArcSubscription) -> Result<()>; 85 | fn remove(&mut self, sub: ArcSubscription) -> Result<()>; 86 | fn match_subject(&mut self, subject: &str) -> ArcSubResult;
87 | } 88 | /* Exact-match subscription list: plain subscriptions are kept per subject, queue subscriptions per subject and then per queue name. */ #[derive(Debug, Default)] 89 | pub struct SimpleSubList { 90 | subs: HashMap>, 91 | qsubs: HashMap>>, 92 | } 93 | 94 | impl SubListTrait for SimpleSubList { 95 | /* Adds sub under its subject: a subscription with a queue name goes into that queue group, any other into the plain set. Always returns Ok. */ fn insert(&mut self, sub: Arc) -> Result<()> { 96 | if let Some(ref q) = sub.queue { 97 | let entry = self 98 | .qsubs 99 | .entry(sub.subject.clone()) 100 | .or_insert(Default::default()); 101 | let queue = entry.entry(q.clone()).or_insert(Default::default()); 102 | queue.insert(ArcSubscriptionWrapper(sub)); 103 | } else { 104 | let subs = self 105 | .subs 106 | .entry(sub.subject.clone()) 107 | .or_insert(Default::default()); 108 | subs.insert(ArcSubscriptionWrapper(sub)); 109 | } 110 | Ok(()) 111 | } 112 | 113 | /* Removes sub and drops containers that become empty. NOTE(review): for a queue subscription an unknown subject or queue name yields ERROR_SUBSCRIBTION_NOT_FOUND, whereas a plain subscription with an unknown subject returns Ok - confirm the asymmetry is intended. */ fn remove(&mut self, sub: Arc) -> Result<()> { 114 | if let Some(ref q) = sub.queue { 115 | if let Some(subs) = self.qsubs.get_mut(&sub.subject) { 116 | if let Some(qsubs) = subs.get_mut(q) { 117 | qsubs.remove(&ArcSubscriptionWrapper(sub.clone())); 118 | if qsubs.is_empty() { 119 | subs.remove(q); 120 | } 121 | } else { 122 | return Err(NError::new(ERROR_SUBSCRIBTION_NOT_FOUND)); 123 | } 124 | if subs.is_empty() { 125 | self.qsubs.remove(&sub.subject); 126 | } 127 | } else { 128 | return Err(NError::new(ERROR_SUBSCRIBTION_NOT_FOUND)); 129 | } 130 | } else { 131 | if let Some(subs) = self.subs.get_mut(&sub.subject) { 132 | subs.remove(&ArcSubscriptionWrapper(sub.clone())); 133 | if subs.is_empty() { 134 | self.subs.remove(&sub.subject); 135 | } 136 | } 137 | } 138 | Ok(()) 139 | } 140 | 141 | /* Builds a fresh SubResult for an exact subject: every plain subscriber goes into psubs and each queue group becomes one Vec in qsubs. */ fn match_subject(&mut self, subject: &str) -> ArcSubResult { 142 | let mut r = SubResult::default(); 143 | if let Some(subs) = self.subs.get(subject) { 144 | for s in subs { 145 | r.psubs.push(s.0.clone()); 146 | } 147 | } 148 | if let Some(qsubs) = self.qsubs.get(subject) { 149 | for (_, qsub) in qsubs { 150 | let mut v = Vec::with_capacity(qsub.len()); 151 | for s in qsub { 152 | v.push(s.0.clone()); 153 | } 154 | r.qsubs.push(v); 155 | } 156 | } 157 | Arc::new(r) 158 | } 159 | } 160 | 161 |
#[cfg(test)] 162 | mod tests { 163 | use super::*; 164 | use crate::client::new_test_tcp_writer; 165 | 166 | /* Inserts two plain and three queue subscriptions on "test" (queue names q, q, q2), checking match_subject after every insert, then removes them in reverse order and checks the counts again. */ #[test] 167 | fn test_match() { 168 | let mut sl = SimpleSubList::default(); 169 | let mut subs = Vec::new(); 170 | let r = sl.match_subject("test"); 171 | assert_eq!(r.psubs.len(), 0); 172 | assert_eq!(r.qsubs.len(), 0); 173 | let sub = Arc::new(Subscription::new("test", None, "1", new_test_tcp_writer())); 174 | subs.push(sub.clone()); 175 | let r = sl.insert(sub); 176 | assert!(!r.is_err()); 177 | let r = sl.match_subject("test"); 178 | assert_eq!(r.psubs.len(), 1); 179 | assert_eq!(r.qsubs.len(), 0); 180 | let sub = Arc::new(Subscription::new("test", None, "1", new_test_tcp_writer())); 181 | subs.push(sub.clone()); 182 | let r = sl.insert(sub); 183 | assert!(!r.is_err()); 184 | let r = sl.match_subject("test"); 185 | assert_eq!(r.psubs.len(), 2); 186 | assert_eq!(r.qsubs.len(), 0); 187 | let sub = Arc::new(Subscription::new( 188 | "test", 189 | Some("q"), 190 | "1", 191 | new_test_tcp_writer(), 192 | )); 193 | subs.push(sub.clone()); 194 | let r = sl.insert(sub); 195 | assert!(!r.is_err()); 196 | let r = sl.match_subject("test"); 197 | assert_eq!(r.psubs.len(), 2); 198 | assert_eq!(r.qsubs.len(), 1); 199 | let sub = Arc::new(Subscription::new( 200 | "test", 201 | Some("q"), 202 | "1", 203 | new_test_tcp_writer(), 204 | )); 205 | subs.push(sub.clone()); 206 | let r = sl.insert(sub); 207 | assert!(!r.is_err()); 208 | let r = sl.match_subject("test"); 209 | assert_eq!(r.psubs.len(), 2); 210 | assert_eq!(r.qsubs.len(), 1); 211 | 212 | let sub = Arc::new(Subscription::new( 213 | "test", 214 | Some("q2"), 215 | "1", 216 | new_test_tcp_writer(), 217 | )); 218 | subs.push(sub.clone()); 219 | let r = sl.insert(sub); 220 | assert!(!r.is_err()); 221 | let r = sl.match_subject("test"); 222 | assert_eq!(r.psubs.len(), 2); 223 | assert_eq!(r.qsubs.len(), 2); 224 | 225 | let s = subs.pop().unwrap(); 226 | let r = sl.remove(s); 227 | assert!(!r.is_err()); 228 | let r =
sl.match_subject("test"); 229 | assert_eq!(r.psubs.len(), 2); 230 | assert_eq!(r.qsubs.len(), 1); 231 | 232 | let s = subs.pop().unwrap(); 233 | let r = sl.remove(s); 234 | assert!(!r.is_err()); 235 | let r = sl.match_subject("test"); 236 | assert_eq!(r.psubs.len(), 2); 237 | assert_eq!(r.qsubs.len(), 1); 238 | 239 | let s = subs.pop().unwrap(); 240 | let r = sl.remove(s); 241 | assert!(!r.is_err()); 242 | let r = sl.match_subject("test"); 243 | assert_eq!(r.psubs.len(), 2); 244 | assert_eq!(r.qsubs.len(), 0); 245 | 246 | let s = subs.pop().unwrap(); 247 | let r = sl.remove(s); 248 | assert!(!r.is_err()); 249 | let r = sl.match_subject("test"); 250 | assert_eq!(r.psubs.len(), 1); 251 | assert_eq!(r.qsubs.len(), 0); 252 | 253 | let s = subs.pop().unwrap(); 254 | let r = sl.remove(s); 255 | assert!(!r.is_err()); 256 | let r = sl.match_subject("test"); 257 | assert_eq!(r.psubs.len(), 0); 258 | assert_eq!(r.qsubs.len(), 0); 259 | } 260 | } 261 | -------------------------------------------------------------------------------- /server/src/sublist.rs: -------------------------------------------------------------------------------- 1 | /** 2 | ### The core trie 3 | This is the slightly more complicated part of the whole system 4 | At its core it is simply a trie 5 | 6 | node is a trie 7 | every node is one of the strings obtained by splitting on . 8 | foo.bar.aa 9 | foo.cc.aa 10 | foo.bb.dd 11 | foo.dd 12 | ``` 13 | foo 14 | / / | \ \ \ 15 | * > bar cc bb dd 16 | | | | | 17 | aa aa aa aa 18 | ``` 19 | When a subscription foo.> is inserted into this tree it is placed under >; call it sub1 20 | When foo.* is inserted, the subscription is placed under * (sub2) 21 | When a subscription foo.bar.aa arrives it is placed under foo.bar.aa (sub3) 22 | When someone publishes a message on foo.ff it matches sub1,sub2 23 | When someone publishes a message on foo.bar.aa it matches sub2,sub3 (NOTE(review): if * matches exactly one token and > all remaining tokens, this would be sub1,sub3 - confirm) 24 | 25 | ### The cache 26 | Although every lookup is LogN, it is still fairly expensive, hence the cache 27 | 28 | It caches trie traversals: when a publisher publishes a message, it will very likely publish on the same subject again, 29 | so all the subscribers found for that subject can be cached 30 | Downside: whenever a subscriber is added or removed, the cache has to be walked and updated as well.
31 | */ 32 | use crate::error::*; 33 | use crate::simple_sublist::*; 34 | use lru_cache::LruCache; 35 | use std::collections::{BTreeSet, HashMap}; 36 | use std::sync::Arc; 37 | 38 | const PWC: u8 = '*' as u8; 39 | const FWC: u8 = '>' as u8; 40 | const TSEP: &str = "."; 41 | const BTSEP: u8 = '.' as u8; 42 | // SL_CACHE_MAX is used to bound the size of the frontend cache 43 | const SL_CACHE_MAX: usize = 1024; 44 | #[derive(Debug, Default)] 45 | pub struct Level { 46 | pwc: Option>, //child for the * token 47 | fwc: Option>, //child for the > token 48 | nodes: HashMap>, //children for all other tokens 49 | } 50 | impl Level { 51 | /* True when neither wildcard child nor any other child holds anything. */ pub fn is_empty(&self) -> bool { 52 | (self.pwc.is_none() || self.pwc.as_ref().unwrap().is_empty()) 53 | && (self.fwc.is_none() || self.fwc.as_ref().unwrap().is_empty()) 54 | && self.nodes.is_empty() 55 | } 56 | } 57 | 58 | #[derive(Debug, Default)] 59 | pub struct TrieNode { 60 | next: Option>, 61 | subs: BTreeSet, 62 | qsubs: HashMap>, 63 | } 64 | impl TrieNode { 65 | fn new() -> Self { 66 | Self { 67 | next: None, 68 | subs: Default::default(), 69 | qsubs: Default::default(), 70 | } 71 | } 72 | /* True when the node has no plain or queue subscriptions and its next level is absent or empty. */ fn is_empty(&self) -> bool { 73 | self.subs.is_empty() 74 | && self.qsubs.is_empty() 75 | && (self.next.is_none() || self.next.as_ref().unwrap().is_empty()) 76 | } 77 | } 78 | #[derive(Debug)] 79 | struct SubResultCache { 80 | cache: LruCache, 81 | } 82 | impl SubResultCache { 83 | fn new(cache_size: usize) -> SubResultCache { 84 | Self { 85 | cache: LruCache::new(cache_size), 86 | } 87 | } 88 | /* 89 | On insert the cache has to be rebuilt. 90 | For example, when a.> is inserted 91 | the entries for a.b.c a.d a.d.c and so on must all receive the new subscription 92 | Also, the cached results are shared, so the affected entries must be rebuilt.
93 | */ 94 | fn insert(&mut self, sub: ArcSubscription) { 95 | let mut v = Vec::new(); 96 | for (subject, result) in self.cache.iter_mut() { 97 | if match_literal(subject, sub.subject.as_str()) { 98 | let mut r = SubResult::default(); 99 | r.psubs = result.psubs.clone(); 100 | r.qsubs = result.qsubs.clone(); 101 | if let Some(ref q) = sub.queue { 102 | let mut found = false; 103 | for (pos, subs) in result.qsubs.iter().enumerate() { 104 | if subs.get(0).unwrap().queue.as_ref().unwrap().as_str() == q.as_str() { 105 | r.qsubs[pos].push(sub.clone()); 106 | found = true; 107 | } 108 | } 109 | if !found { 110 | r.qsubs.push(vec![sub.clone()]); 111 | } 112 | } else { 113 | r.psubs.push(sub.clone()); 114 | } 115 | v.push((r, subject.clone())); 116 | } 117 | } 118 | for r in v { 119 | self.cache.insert(r.1, Arc::new(r.0)); 120 | } 121 | } 122 | fn remove(&mut self, sub: &ArcSubscription) { 123 | let mut v = Vec::new(); 124 | for (subject, result) in self.cache.iter_mut() { 125 | if match_literal(subject, sub.subject.as_str()) { 126 | let mut r = SubResult::default(); 127 | r.psubs = result.psubs.clone(); 128 | r.qsubs = result.qsubs.clone(); 129 | if let Some(ref q) = sub.queue { 130 | for (pos, subs) in result.qsubs.iter().enumerate() { 131 | if subs.get(0).unwrap().queue.as_ref().unwrap().as_str() == q.as_str() { 132 | for t in subs.iter().enumerate() { 133 | if std::ptr::eq(t.1.as_ref(), sub.as_ref()) { 134 | r.qsubs[pos].swap_remove(t.0); 135 | break; 136 | } 137 | } 138 | if r.qsubs[pos].len() == 0 { 139 | r.qsubs.swap_remove(pos); 140 | } 141 | break; 142 | } 143 | } 144 | } else { 145 | let pos = r 146 | .psubs 147 | .iter() 148 | .position(|it| std::ptr::eq(it.as_ref(), sub.as_ref())); 149 | if let Some(pos) = pos { 150 | r.psubs.swap_remove(pos); 151 | } else { 152 | println!(" not found {:?}", sub); 153 | } 154 | } 155 | v.push((r, subject.clone())); 156 | } 157 | } 158 | for r in v { 159 | self.cache.insert(r.1, Arc::new(r.0)); 160 | } 161 | } 162 | fn get(&mut 
self, subject: &str) -> Option { 163 | // return Some(ArcSubResult::default()); 164 | //todo 由于lru cache 自身问题,等修复后就不需要copy了 165 | self.cache.get_mut(subject).map(|r| Arc::clone(r)) 166 | // self.cache.get(&subject.to_string()) 167 | } 168 | fn insert_result(&mut self, subject: &str, result: ArcSubResult) { 169 | self.cache.insert(subject.to_string(), result); 170 | } 171 | } 172 | #[test] 173 | fn test_lru() { 174 | use lru::LruCache; 175 | let mut cache = LruCache::new(2); 176 | cache.put("apple", 3); 177 | cache.put("banana", 2); 178 | assert_eq!(*cache.get(&"apple").unwrap(), 3); 179 | assert_eq!(*cache.get(&"banana").unwrap(), 2); 180 | assert!(cache.get(&"pear").is_none()); 181 | 182 | assert_eq!(cache.put("banana", 4), Some(2)); 183 | assert_eq!(cache.put("pear", 5), None); 184 | 185 | assert_eq!(*cache.get(&"pear").unwrap(), 5); 186 | assert_eq!(*cache.get(&"banana").unwrap(), 4); 187 | assert!(cache.get(&"apple").is_none()); 188 | 189 | { 190 | let v = cache.get_mut(&"banana").unwrap(); 191 | *v = 6; 192 | } 193 | 194 | assert_eq!(*cache.get(&"banana").unwrap(), 6); 195 | } 196 | impl Default for SubResultCache { 197 | fn default() -> Self { 198 | Self::new(1024) 199 | } 200 | } 201 | #[derive(Debug, Default)] 202 | pub struct TrieSubList { 203 | cache: SubResultCache, 204 | root: Level, 205 | d: ArcSubResult, 206 | default_node: Box, //只是因为Insert的时候必须有一个初始化的值 207 | } 208 | impl TrieSubList { 209 | pub fn new() -> Self { 210 | Self { 211 | cache: Default::default(), 212 | root: Default::default(), 213 | d: ArcSubResult::default(), 214 | default_node: Default::default(), 215 | } 216 | } 217 | } 218 | impl SubListTrait for TrieSubList { 219 | /* 220 | 将合法的subject插入树中, 221 | 形如a.b.c a.*.c a.* a.>等 222 | 插入的时候要考虑重建cache. 223 | 比如插入一个a.> 224 | 那么a.b.c a.d a.d.c 等对应的项中都要插入 225 | 另外,cache是共享的,所以相应的项必须重建. 
226 | */ 227 | fn insert(&mut self, sub: Arc) -> Result<()> { 228 | if !is_valid_subject(sub.subject.as_str()) { 229 | return Err(NError::new(ERROR_INVALID_SUBJECT)); 230 | } 231 | // println!("insert {}", sub.subject); 232 | let mut l = &mut self.root; 233 | let mut n = &mut self.default_node; 234 | let mut tokens = split_subject(&sub.subject).peekable(); 235 | while tokens.peek().is_some() { 236 | let token = tokens.next().unwrap(); 237 | // println!("token:{}", token); 238 | let t = token.as_bytes()[0]; 239 | match t { 240 | PWC => { 241 | if l.pwc.is_none() { 242 | l.pwc = Some(Box::new(TrieNode::new())); 243 | } 244 | n = l.pwc.as_mut().unwrap(); 245 | } 246 | FWC => { 247 | if l.fwc.is_none() { 248 | l.fwc = Some(Box::new(TrieNode::new())); 249 | } 250 | n = l.fwc.as_mut().unwrap(); 251 | } 252 | _ => { 253 | n = l 254 | .nodes 255 | .entry(token.to_string()) 256 | .or_insert(Box::new(TrieNode::new())); 257 | } 258 | } 259 | let is_last = tokens.peek().is_none(); 260 | if is_last { 261 | break; 262 | } 263 | if n.next.is_none() { 264 | n.next = Some(Box::new(Level::default())); 265 | } 266 | l = n.next.as_mut().unwrap(); 267 | } 268 | 269 | if let Some(ref q) = sub.queue { 270 | let qsubs = n.qsubs.entry(q.clone()).or_insert(Default::default()); 271 | qsubs.insert(ArcSubscriptionWrapper(sub.clone())); 272 | } else { 273 | n.subs.insert(ArcSubscriptionWrapper(sub.clone())); 274 | } 275 | self.cache.insert(sub); 276 | Ok(()) 277 | } 278 | /* 279 | 将合法的subject从树中移除, 280 | 形如a.b.c a.*.c a.* a.>等 281 | 插入的时候要考虑重建cache. 282 | 比如插入一个a.> 283 | 那么a.b.c a.d a.d.c 等对应的项中都要删除 284 | 另外,cache是共享的,所以相应的项必须重建. 
*/
    /// Unregisters `sub` and refreshes the affected cache entries.
    ///
    /// # Errors
    /// Returns `ERROR_INVALID_SUBJECT` when the subject is not valid, and
    /// `ERROR_SUBSCRIBTION_NOT_FOUND` when the subscription is not in the trie.
    fn remove(&mut self, sub: Arc<Subscription>) -> Result<()> {
        if !is_valid_subject(sub.subject.as_str()) {
            return Err(NError::new(ERROR_INVALID_SUBJECT));
        }
        let tokens = sub.subject.split(".").peekable();
        if !Self::remove_internal(&mut self.root, tokens, &sub) {
            return Err(NError::new(ERROR_SUBSCRIBTION_NOT_FOUND));
        }
        self.cache.remove(&sub);
        Ok(())
    }

    // Used when publishing.
    // Publishing to a.b.c has to find the nodes for a.>, a.*.c, a.b.*, a.b.c
    // and > -- every subscription subject that can match.
    // Subscriptions on different subjects are never load-balanced together,
    // even with the same queue name: a.*.c and a.b.c stay in separate groups.
    /// Returns the subscriptions matching the published `subject`, served from
    /// the cache when possible.
    ///
    /// A subject rejected by `is_valid_literal_subject` yields an empty result
    /// that is not cached.
    fn match_subject(&mut self, subject: &str) -> ArcSubResult {
        if let Some(r) = self.cache.get(subject) {
            return r;
        }
        if !is_valid_literal_subject(subject) {
            // NOTE(review): this used to be `unreachable!`, which aborts the
            // caller on a malformed publish subject. An empty result is assumed
            // to be acceptable to callers -- confirm against the publish path.
            return ArcSubResult::default();
        }
        let mut r = SubResult::default();
        Self::match_internal(&self.root, split_subject(subject).peekable(), &mut r);
        let r = Arc::new(r);
        self.cache.insert_result(subject, r.clone());
        r
    }
}

impl TrieSubList {
    /// Number of subjects currently held in the result cache.
    fn cache_count(&self) -> usize {
        self.cache.cache.len()
    }

    /// Appends the subscriptions stored on `n` to `r`: plain subscriptions go
    /// to `psubs`, and each queue becomes one group in `qsubs`.
    fn add_node_to_result(n: &TrieNode, r: &mut SubResult) {
        r.psubs.extend(n.subs.iter().map(|s| s.0.clone()));
        for group in n.qsubs.values() {
            r.qsubs.push(group.iter().map(|s| s.0.clone()).collect());
        }
    }

    /// Walks the trie from level `l`, consuming `tokens`, and collects every
    /// matching node's subscriptions into `r`.
    fn match_internal(l: &Level, mut tokens: std::iter::Peekable<Split<'_>>, r: &mut SubResult) {
        let token = match tokens.next() {
            Some(t) => t,
            None => return,
        };
        let is_last = tokens.peek().is_none();
        // `>` matches whatever remains, as long as at least one token remains.
        if let Some(fwc) = l.fwc.as_ref() {
            Self::add_node_to_result(fwc, r);
        }
        // `*` consumes exactly this token.
        if let Some(pwc) = l.pwc.as_ref() {
            if is_last {
                Self::add_node_to_result(pwc, r);
            } else if let Some(next) = pwc.next.as_ref() {
                Self::match_internal(next, tokens.clone(), r);
            }
        }
        // Exact token match.
        if let Some(n) = l.nodes.get(token) {
            if is_last {
                Self::add_node_to_result(n, r);
            } else if let Some(next) = n.next.as_ref() {
                Self::match_internal(next, tokens, r);
            }
        }
    }

    /// Removes `sub` from the subtree under `l`, pruning nodes that become empty.
    ///
    /// Returns `true` when the subscription was found and removed.
    fn remove_internal(
        l: &mut Level,
        mut tokens: std::iter::Peekable<std::str::Split<'_, &str>>,
        sub: &Arc<Subscription>,
    ) -> bool {
        let token = match tokens.next() {
            Some(t) => t,
            None => return false,
        };
        let is_last = tokens.peek().is_none();
        let n = match token {
            "*" => l.pwc.as_mut(),
            ">" => l.fwc.as_mut(),
            _ => l.nodes.get_mut(token),
        };
        let n = match n {
            Some(n) => n,
            None => return false,
        };
        let removed = if is_last {
            Self::remove_sub(n.as_mut(), sub.clone())
        } else {
            match n.next.as_mut() {
                Some(next) => Self::remove_internal(next.as_mut(), tokens, sub),
                None => false,
            }
        };
        if !removed {
            return false;
        }
        // Something was removed below: drop this node if nothing is left in it.
        if n.is_empty() {
            match token {
                "*" => {
                    l.pwc = None;
                }
                ">" => {
                    l.fwc = None;
                }
                _ => {
                    l.nodes.remove(token);
                }
            }
        }
        true
    }

    /// Removes `sub` from node `n`; returns `true` when it was present.
    ///
    /// A queue whose last member is removed is dropped from the node. Otherwise
    /// `TrieNode::is_empty` would never become true, and later matches would
    /// report an empty queue group.
    fn remove_sub(n: &mut TrieNode, sub: ArcSubscription) -> bool {
        let key = ArcSubscriptionWrapper(sub);
        match key.0.queue {
            Some(ref q) => {
                let (removed, now_empty) = match n.qsubs.get_mut(q) {
                    Some(members) => {
                        let removed = members.remove(&key);
                        (removed, members.is_empty())
                    }
                    None => (false, false),
                };
                if now_empty {
                    n.qsubs.remove(q);
                }
                removed
            }
            None => n.subs.remove(&key),
        }
    }

    /// Benchmark baseline: returns the shared empty result without matching.
    fn match_test(&mut self, _subject: &str) -> ArcSubResult {
        self.d.clone()
    }

    /// Benchmark baseline: allocates a fresh boxed result without matching.
    fn match_test2(&mut self, _subject: &str) -> Box<SubResult> {
        Box::new(SubResult::new())
    }
}

/// Returns `true` if `subject` is a valid subscription subject.
///
/// Used when a SUB message is received. A subject is invalid when:
/// 1. it is empty
/// 2. it contains an empty token (leading, trailing or consecutive `.`)
/// 3. a token starts with `>` but is not exactly `>`, or `>` is not the last token
pub fn is_valid_subject(subject: &str) -> bool {
    if subject.is_empty() {
        return false;
    }
    let mut seen_fwc = false;
    for token in split_subject(subject) {
        // An empty token, or any token after `>`, is invalid.
        if token.is_empty() || seen_fwc {
            return false;
        }
        if token.as_bytes()[0] == FWC {
            // Only a lone `>` is allowed.
            if token.len() > 1 {
                return false;
            }
            seen_fwc = true;
        }
    }
    true
}

/// Used when publishing.
/// A subject is invalid when:
/// 1. `is_valid_subject` considers it invalid
/// 2. it contains `*`
/// 3. 
包含> 471 | pub fn is_valid_literal_subject(subject: &str) -> bool { 472 | if subject.len() == 0 { 473 | return false; 474 | } 475 | // split_subject(subject).any() 476 | !split_subject(subject).any(|s: &str| { 477 | if s.len() == 0 { 478 | return true; 479 | } 480 | if s.len() > 1 { 481 | return false; 482 | } 483 | let b = s.as_bytes()[0]; 484 | if b == FWC || b == PWC { 485 | return true; 486 | } 487 | false 488 | }) 489 | } 490 | #[derive(Clone, Debug)] 491 | struct Split<'a> { 492 | pos: usize, 493 | buf: &'a [u8], 494 | } 495 | fn split_subject<'a>(subject: &'a str) -> Split<'a> { 496 | Split { 497 | pos: 0, 498 | buf: subject.as_bytes(), 499 | } 500 | } 501 | impl<'a> std::iter::Iterator for Split<'a> { 502 | type Item = &'a str; 503 | fn next(&mut self) -> Option { 504 | if self.pos > self.buf.len() { 505 | return None; 506 | } 507 | let start = self.pos; 508 | while self.pos < self.buf.len() { 509 | if self.buf[self.pos] == BTSEP { 510 | break; 511 | } 512 | self.pos += 1; 513 | } 514 | let str = unsafe { std::str::from_utf8_unchecked(&self.buf[start..self.pos]) }; 515 | self.pos += 1; //无论哪种情况,都应该跳过. 516 | return Some(str); 517 | } 518 | } 519 | #[test] 520 | fn test_split() { 521 | let v = vec!["a", "b", "c"]; 522 | let r: Vec<_> = split_subject("a.b.c").collect(); 523 | assert_eq!(v, r); 524 | let v = vec!["a", "b", "", "c"]; 525 | let r: Vec<_> = split_subject("a.b..c").collect(); 526 | assert_eq!(v, r); 527 | let v = vec!["", "", "a", "b", "c"]; 528 | let r: Vec<_> = split_subject("..a.b.c").collect(); 529 | assert_eq!(v, r); 530 | let v = vec!["a", "b", "c", ""]; 531 | let r: Vec<_> = split_subject("a.b.c.").collect(); 532 | assert_eq!(v, r); 533 | } 534 | /// matchLiteral is used to test literal subjects, those that do not have any 535 | /// wildcards, with a target subject. This is used in the cache layer. 
536 | /// 判断a.b.c和a.*.c时否匹配 537 | fn match_literal(literal: &str, subject: &str) -> bool { 538 | let mut literal_iter = split_subject(literal).peekable(); 539 | let mut subject_iter = split_subject(subject).peekable(); 540 | 541 | while literal_iter.peek().is_some() { 542 | //a.b.c走完了,即使是a.b.c.>也不能匹配 543 | if subject_iter.peek().is_none() { 544 | return false; 545 | } 546 | let literal = literal_iter.next().unwrap(); 547 | let subject = subject_iter.next().unwrap(); 548 | if literal == subject { 549 | continue; 550 | } 551 | if subject == "*" { 552 | continue; 553 | } 554 | if subject == ">" { 555 | return true; 556 | } 557 | return false; 558 | } 559 | subject_iter.peek().is_none() 560 | } 561 | #[cfg(test)] 562 | fn test_new_sub(subject: &str) -> Subscription { 563 | use crate::client::new_test_tcp_writer; 564 | let writer = new_test_tcp_writer(); 565 | Subscription::new(subject, None, "1", writer.clone()) 566 | } 567 | #[cfg(test)] 568 | fn test_new_sub_arc(subject: &str) -> ArcSubscription { 569 | Arc::new(test_new_sub(subject)) 570 | } 571 | #[cfg(test)] 572 | mod tests { 573 | use super::*; 574 | fn verify_count(_sub: &TrieSubList, _count: usize) { 575 | // assert_eq!( 576 | // sub.count(), 577 | // count, 578 | // "expect count={},got={}", 579 | // count, 580 | // sub.count() 581 | // ); 582 | } 583 | fn verify_len(r: &[T], l: usize) { 584 | assert_eq!(r.len(), l, "results len expect={},got={}", l, r.len()); 585 | } 586 | fn verify_qlen(r: &Vec>>, l: usize) { 587 | assert_eq!(r.len(), l, "queue results len expect={},got={}", l, r.len()); 588 | } 589 | fn verify_num_levels(_s: &TrieSubList, l: usize) { 590 | // assert_eq!( 591 | // s.num_levels(), 592 | // l, 593 | // "numlevels expect={},got={}", 594 | // l, 595 | // s.num_levels() 596 | // ); 597 | } 598 | 599 | fn verify_member(r: &[Arc], val: &Subscription) { 600 | for s in r { 601 | if std::ptr::eq(s.as_ref(), val) { 602 | return; 603 | } 604 | } 605 | assert!(false, "sub:{:?} not found in results", 
val); 606 | } 607 | fn new_qsub(subject: &str, queue: &str) -> Subscription { 608 | let mut s = test_new_sub(subject); 609 | s.queue = Some(queue.into()); 610 | s 611 | } 612 | #[test] 613 | fn test_match_literal() { 614 | assert_eq!(true, match_literal("a.b.c", "a.*.c")); 615 | assert_eq!(true, match_literal("a.b.c", "a.>")); 616 | assert_eq!(false, match_literal("a.b.c.d", "a.*.c")); 617 | } 618 | 619 | #[test] 620 | fn test_sublist_insert_count() { 621 | let mut s = TrieSubList::new(); 622 | assert!(s.insert(Arc::new(test_new_sub("foo"))).is_ok()); 623 | assert!(s.insert(Arc::new(test_new_sub("bar"))).is_ok()); 624 | assert!(s.insert(Arc::new(test_new_sub("foo.bar"))).is_ok()); 625 | verify_count(&s, 3); 626 | } 627 | #[test] 628 | fn test_sublist_simple() { 629 | let mut s = TrieSubList::new(); 630 | let subject = "foo"; 631 | let sub = Arc::new(test_new_sub(subject)); 632 | assert!(s.insert(sub.clone()).is_ok()); 633 | let r = s.match_subject(subject); 634 | verify_len(r.psubs.as_slice(), 1); 635 | verify_member(r.psubs.as_slice(), sub.as_ref()); 636 | } 637 | #[test] 638 | fn test_sublist_simple_multi_tokens() { 639 | let mut s = TrieSubList::new(); 640 | let subject = "foo.bar.baz"; 641 | let sub = Arc::new(test_new_sub(subject)); 642 | assert!(s.insert(sub.clone()).is_ok()); 643 | let r = s.match_subject(subject); 644 | verify_len(r.psubs.as_slice(), 1); 645 | verify_member(r.psubs.as_slice(), sub.as_ref()); 646 | } 647 | #[test] 648 | fn test_sublist_partial_wildcard() { 649 | let mut s = TrieSubList::new(); 650 | let lsub = Arc::new(test_new_sub("a.b.c")); 651 | let psub = Arc::new(test_new_sub("a.*.c")); 652 | 653 | assert!(s.insert(lsub.clone()).is_ok()); 654 | assert!(s.insert(psub.clone()).is_ok()); 655 | let r = s.match_subject(&lsub.subject); 656 | verify_len(r.psubs.as_slice(), 2); 657 | verify_member(r.psubs.as_slice(), lsub.as_ref()); 658 | verify_member(r.psubs.as_slice(), psub.as_ref()); 659 | } 660 | #[test] 661 | fn 
test_sublist_partial_wildcard_at_end() { 662 | let mut s = TrieSubList::new(); 663 | let lsub = Arc::new(test_new_sub("a.b.c")); 664 | let psub = Arc::new(test_new_sub("a.b.*")); 665 | 666 | assert!(s.insert(lsub.clone()).is_ok()); 667 | assert!(s.insert(psub.clone()).is_ok()); 668 | let r = s.match_subject(&lsub.subject); 669 | verify_len(r.psubs.as_slice(), 2); 670 | verify_member(r.psubs.as_slice(), lsub.as_ref()); 671 | verify_member(r.psubs.as_slice(), psub.as_ref()); 672 | } 673 | #[test] 674 | fn test_sublist_partial_full_wildcard() { 675 | let mut s = TrieSubList::new(); 676 | let lsub = Arc::new(test_new_sub("a.b.c")); 677 | let psub = Arc::new(test_new_sub("a.>")); 678 | 679 | assert!(s.insert(lsub.clone()).is_ok()); 680 | assert!(s.insert(psub.clone()).is_ok()); 681 | let r = s.match_subject(&lsub.subject); 682 | verify_len(r.psubs.as_slice(), 2); 683 | verify_member(r.psubs.as_slice(), lsub.as_ref()); 684 | verify_member(r.psubs.as_slice(), psub.as_ref()); 685 | } 686 | #[test] 687 | fn test_sublist_remove() { 688 | let mut s = TrieSubList::new(); 689 | let sub = Arc::new(test_new_sub("a.b.c.d")); 690 | 691 | assert!(s.insert(sub.clone()).is_ok()); 692 | let r = s.match_subject(&sub.subject); 693 | verify_len(r.psubs.as_slice(), 1); 694 | verify_count(&s, 1); 695 | assert!(s.remove(Arc::new(test_new_sub("a.b.c"))).is_err()); 696 | verify_count(&s, 1); 697 | assert!(s.remove(sub.clone()).is_ok()); 698 | verify_count(&s, 0); 699 | let r = s.match_subject(&sub.subject); 700 | verify_len(r.psubs.as_slice(), 0); 701 | } 702 | 703 | #[test] 704 | fn test_sublist_remove_wildcard() { 705 | let mut s = TrieSubList::new(); 706 | let sub = Arc::new(test_new_sub("a.b.c.d")); 707 | let psub = Arc::new(test_new_sub("a.b.*.d")); 708 | let fsub = Arc::new(test_new_sub("a.b.>")); 709 | let _ = s.insert(sub.clone()); 710 | let _ = s.insert(psub.clone()); 711 | let _ = s.insert(fsub.clone()); 712 | verify_count(&s, 3); 713 | 714 | let r = s.match_subject(&sub.subject); 
715 | verify_len(r.psubs.as_slice(), 3); 716 | assert!(s.remove(sub.clone()).is_ok()); 717 | verify_count(&s, 2); 718 | assert!(s.remove(fsub.clone()).is_ok()); 719 | verify_count(&s, 1); 720 | assert!(s.remove(psub.clone()).is_ok()); 721 | verify_count(&s, 0); 722 | 723 | verify_count(&s, 1); 724 | assert!(s.remove(Arc::new(test_new_sub("a.b.c"))).is_err()); 725 | verify_count(&s, 0); 726 | 727 | let r = s.match_subject(&sub.subject); 728 | verify_len(r.psubs.as_slice(), 0); 729 | } 730 | #[test] 731 | fn test_sublist_remove_cleanup() { 732 | let mut s = TrieSubList::new(); 733 | let literal = "a.b.c.d.e.f"; 734 | let depth = literal.split(TSEP).count(); 735 | let sub = Arc::new(test_new_sub(literal)); 736 | verify_num_levels(&s, 0); 737 | let _ = s.insert(sub.clone()); 738 | verify_num_levels(&s, depth); 739 | let _ = s.remove(sub.clone()); 740 | verify_num_levels(&s, 0); 741 | } 742 | #[test] 743 | fn test_sublist_remove_cleanup_wildcards() { 744 | let mut s = TrieSubList::new(); 745 | let literal = "a.b.*.d.e.>"; 746 | let depth = literal.split(TSEP).count(); 747 | let sub = Arc::new(test_new_sub(literal)); 748 | verify_num_levels(&s, 0); 749 | let _ = s.insert(sub.clone()); 750 | verify_num_levels(&s, depth); 751 | let _ = s.remove(sub); 752 | verify_num_levels(&s, 0); 753 | } 754 | #[test] 755 | fn test_sublist_invalid_subjects_insert() { 756 | let mut s = TrieSubList::new(); 757 | assert!(s.insert(Arc::new(test_new_sub(".foo"))).is_err()); 758 | assert!(s.insert(Arc::new(test_new_sub("foo."))).is_err()); 759 | assert!(s.insert(Arc::new(test_new_sub("foo..bar"))).is_err()); 760 | assert!(s.insert(Arc::new(test_new_sub("foo.bar..baz"))).is_err()); 761 | assert!(s.insert(Arc::new(test_new_sub("foo.>.bar"))).is_err()); 762 | } 763 | #[test] 764 | fn test_sublist_cache() { 765 | let mut s = TrieSubList::new(); 766 | let subject = "a.b.c.d"; 767 | let sub = Arc::new(test_new_sub(subject)); 768 | let psub = Arc::new(test_new_sub("a.b.*.d")); 769 | let fsub = 
Arc::new(test_new_sub("a.b.>")); 770 | let _ = s.insert(sub.clone()); 771 | let _ = s.insert(psub.clone()); 772 | let _ = s.insert(fsub.clone()); 773 | verify_count(&s, 3); 774 | let r = s.match_subject(subject); 775 | verify_len(r.psubs.as_slice(), 3); 776 | let _ = s.remove(sub); 777 | verify_count(&s, 2); 778 | let _ = s.remove(fsub.clone()); 779 | verify_count(&s, 1); 780 | let _ = s.remove(psub.clone()); 781 | verify_count(&s, 0); 782 | assert_eq!(s.cache_count(), 0); 783 | let r = s.match_subject(subject); 784 | verify_len(r.psubs.as_slice(), 0); 785 | for i in 0..2 * SL_CACHE_MAX { 786 | s.match_subject(format!("foo-#{}", i).as_str()); 787 | } 788 | // assert!(s.cache_count() <= SL_CACHE_MAX); 789 | } 790 | #[test] 791 | fn test_sublist_basic_queue_results() { 792 | let mut s = TrieSubList::new(); 793 | let subject = "foo"; 794 | let sub1 = Arc::new(new_qsub(subject, "bar")); 795 | let sub2 = Arc::new(new_qsub(subject, "baz")); 796 | 797 | let _ = s.insert(sub1.clone()); 798 | let r = s.match_subject(subject); 799 | verify_len(r.psubs.as_slice(), 0); 800 | verify_qlen(&r.qsubs, 1); 801 | verify_len(r.qsubs[0].as_slice(), 1); 802 | verify_member(r.qsubs[0].as_slice(), sub1.as_ref()); 803 | 804 | let _ = s.insert(sub2.clone()); 805 | let r = s.match_subject(subject); 806 | verify_len(r.psubs.as_slice(), 0); 807 | println!("qsubs={:?}", r.qsubs); 808 | verify_qlen(&r.qsubs, 2); 809 | verify_len(r.qsubs[0].as_slice(), 1); 810 | verify_len(r.qsubs[1].as_slice(), 1); 811 | verify_member(r.qsubs[0].as_slice(), sub1.as_ref()); 812 | verify_member(r.qsubs[1].as_slice(), sub2.as_ref()); 813 | } 814 | #[test] 815 | fn test_valid_literal_subject() { 816 | assert_eq!(is_valid_literal_subject("foo"), true); 817 | assert_eq!(is_valid_literal_subject(".foo"), false); 818 | assert_eq!(is_valid_literal_subject("foo."), false); 819 | assert_eq!(is_valid_literal_subject("foo..bar"), false); 820 | assert_eq!(is_valid_literal_subject("foo.bar.*"), false); 821 | 
assert_eq!(is_valid_literal_subject("foo.bar.>"), false); 822 | assert_eq!(is_valid_literal_subject("*"), false); 823 | assert_eq!(is_valid_literal_subject(">"), false); 824 | } 825 | #[test] 826 | fn test_match_literal2() { 827 | println!("a={}", 3); 828 | assert_eq!(match_literal("foo", "foo"), true); 829 | assert_eq!(match_literal("foo", "bar"), false); 830 | assert_eq!(match_literal("foo", "*"), true); 831 | assert_eq!(match_literal("foo", ">"), true); 832 | assert_eq!(match_literal("foo.bar", ">"), true); 833 | assert_eq!(match_literal("foo.bar", "foo.>"), true); 834 | assert_eq!(match_literal("foo.bar", "bar.>"), false); 835 | assert_eq!(match_literal("stats.test.22", "stats.>"), true); 836 | assert_eq!(match_literal("stats.test.22", "stats.*.*"), true); 837 | assert_eq!(match_literal("foo.bar", "foo"), false); 838 | assert_eq!(match_literal("stats.test.foos", "stats.test.foos"), true); 839 | assert_eq!(match_literal("stats.test.foos", "stats.test.foo"), false); 840 | } 841 | #[test] 842 | fn test_sublist_two_token_pub_match_single_token_sub() { 843 | let mut s = TrieSubList::new(); 844 | let sub = Arc::new(test_new_sub("foo")); 845 | let _ = s.insert(sub.clone()); 846 | let r = s.match_subject("foo"); 847 | verify_len(r.psubs.as_slice(), 1); 848 | let r = s.match_subject("foo.bar"); 849 | verify_len(r.psubs.as_slice(), 0); 850 | } 851 | } 852 | 853 | #[cfg(test)] 854 | mod benchmark { 855 | extern crate test; 856 | use super::*; 857 | use test::Bencher; 858 | 859 | fn create_subs(pre: &str, subs: &mut Vec>) { 860 | let tokens = vec![ 861 | "apcera", 862 | "continuum", 863 | "component", 864 | "router", 865 | "api", 866 | "imgr", 867 | "jmgr", 868 | "auth", 869 | ]; 870 | for t in tokens { 871 | let sub; 872 | if pre.len() > 0 { 873 | sub = format!("{}.{}", pre, t); 874 | } else { 875 | sub = t.into(); 876 | } 877 | subs.push(Arc::new(test_new_sub(sub.as_str()))); 878 | if sub.split(TSEP).count() < 5 { 879 | create_subs(sub.as_str(), subs); 880 | } 881 | } 
882 | } 883 | fn add_wild_cards(s: &mut TrieSubList) { 884 | let _ = s.insert(Arc::new(test_new_sub("cloud.>"))); 885 | let _ = s.insert(Arc::new(test_new_sub("cloud.continuum.component.>"))); 886 | let _ = s.insert(Arc::new(test_new_sub("cloud.*.*.router.*"))); 887 | } 888 | fn create_test_subs() -> Vec> { 889 | let mut subs = Vec::new(); 890 | create_subs("", &mut subs); 891 | subs 892 | } 893 | fn get_test_sublist() -> TrieSubList { 894 | let mut s = TrieSubList::new(); 895 | let subs = create_test_subs(); 896 | for sub in subs { 897 | let _ = s.insert(sub.clone()); 898 | } 899 | add_wild_cards(&mut s); 900 | s 901 | } 902 | #[test] 903 | fn test_subs() { 904 | let subs = create_test_subs(); 905 | let l = subs.len(); 906 | println!("subs len={}", l); 907 | subs.iter().take(1000).for_each(|sub| { 908 | println!("subs={:?}", sub.subject); 909 | }); 910 | } 911 | use lazy_static::lazy_static; 912 | lazy_static! { 913 | static ref TEST_SUBS_COLLECT: Vec = { create_test_subs() }; 914 | } 915 | #[bench] 916 | fn benchmark1_sublist_insert(b: &mut Bencher) { 917 | //为什么go版本的insert只需要300ns,我这个需要3000ns,慢了一个数量级 918 | let mut s = TrieSubList::new(); 919 | let subs = &TEST_SUBS_COLLECT; 920 | let l = subs.len(); 921 | let mut i = 0; 922 | // println!("subs len={}", l); 923 | // println!("subs={:?}", subs); 924 | b.iter(|| { 925 | let _ = s.insert(subs[i % l].clone()); 926 | // println!("i={}", i); 927 | i += 1; 928 | }); 929 | println!("count={}", i); 930 | } 931 | 932 | #[bench] 933 | fn benchmark1_match_single_token(b: &mut Bencher) { 934 | let mut s = get_test_sublist(); 935 | b.iter(|| { 936 | let _ = s.match_subject("apcera"); 937 | }) 938 | } 939 | #[bench] 940 | fn benchmark1_match_twotokens(b: &mut Bencher) { 941 | let mut s = get_test_sublist(); 942 | b.iter(|| { 943 | let _ = s.match_subject("apcera.continuum"); 944 | }) 945 | } 946 | // #[test] 947 | // fn test_match() { 948 | // let mut s = get_test_sublist(); 949 | // for i in 0..10 { 950 | // let _ = 
s.match_subject("apcera.continuum.component"); 951 | // } 952 | // } 953 | #[bench] 954 | fn benchmark1_match_threetokens(b: &mut Bencher) { 955 | let mut s = get_test_sublist(); 956 | b.iter(|| { 957 | let _ = s.match_subject("apcera.continuum.component"); 958 | }) 959 | } 960 | #[bench] 961 | fn benchmark1_match_fourtokens(b: &mut Bencher) { 962 | let mut s = get_test_sublist(); 963 | let _ = s.match_subject("apcera.continuum.component.router"); 964 | let summary = b.bench(|b| { 965 | b.iter(|| { 966 | let _ = s.match_subject("apcera.continuum.component.router"); 967 | }); 968 | }); 969 | println!("summary={:?}", summary) 970 | } 971 | #[bench] 972 | fn benchmark1_match_fivetokens2(b: &mut Bencher) { 973 | let mut s = get_test_sublist(); 974 | b.iter(|| { 975 | let _ = s.match_test("apcera.continuum.component.router.ZZZZ"); 976 | }) 977 | } 978 | #[bench] 979 | fn benchmark1_match_fivetokens3(b: &mut Bencher) { 980 | let mut s = get_test_sublist(); 981 | b.iter(|| { 982 | let _ = s.match_test2("apcera.continuum.component.router.ZZZZ"); 983 | }) 984 | } 985 | #[bench] 986 | fn benchmark1_match_fivetokens(b: &mut Bencher) { 987 | let mut s = get_test_sublist(); 988 | b.iter(|| { 989 | let _ = s.match_subject("apcera.continuum.component.router.ZZZZ"); 990 | }) 991 | } 992 | fn get_test_array() -> Vec { 993 | let mut v = vec![32; 10000]; 994 | v[9000] = 999; 995 | return v; 996 | } 997 | fn search_order(v: &[u32]) -> i32 { 998 | for i in 0..v.len() { 999 | if v[i] == 999 { 1000 | return i as i32; 1001 | } 1002 | } 1003 | return -1; 1004 | } 1005 | //因为不能越界,确保最后一个不是999 1006 | fn search_order2(v: &mut [u32]) -> i32 { 1007 | let l = v.len() - 1; 1008 | let hold = v[l]; 1009 | v[l] = 999; 1010 | let mut i = 0; 1011 | loop { 1012 | if v[i] == 999 { 1013 | break; 1014 | } 1015 | i += 1; 1016 | } 1017 | v[l] = hold; 1018 | if i == l { 1019 | return -1; 1020 | } else { 1021 | return i as i32; 1022 | } 1023 | } 1024 | fn search_order3(v: &[u32]) -> i32 { 1025 | let l = 
v.len() - 1; 1026 | // let hold = v[l]; 1027 | // v[l] = 999; 1028 | let mut i = 0; 1029 | //访问越界 1030 | while i <= l { 1031 | if v[i] == 999 { 1032 | break; 1033 | } 1034 | if v[i + 1] == 999 { 1035 | i += 1; 1036 | break; 1037 | } 1038 | if v[i + 2] == 999 { 1039 | i += 2; 1040 | break; 1041 | } 1042 | if v[i + 3] == 999 { 1043 | i += 3; 1044 | break; 1045 | } 1046 | if v[i + 4] == 999 { 1047 | i += 4; 1048 | break; 1049 | } 1050 | if v[i + 5] == 999 { 1051 | i += 5; 1052 | break; 1053 | } 1054 | if v[i + 6] == 999 { 1055 | i += 6; 1056 | break; 1057 | } 1058 | if v[i + 7] == 999 { 1059 | i += 7; 1060 | break; 1061 | } 1062 | if v[i + 8] == 999 { 1063 | i += 8; 1064 | break; 1065 | } 1066 | i += 8; 1067 | } 1068 | // v[l] = holxd; 1069 | if i > l { 1070 | return -1; 1071 | } else { 1072 | return i as i32; 1073 | } 1074 | } 1075 | #[bench] 1076 | fn test_search_order(b: &mut Bencher) { 1077 | let v = get_test_array(); 1078 | b.iter(|| { 1079 | assert_eq!(9000, search_order(v.as_slice())); 1080 | }) 1081 | } 1082 | #[bench] 1083 | fn test_search_order2(b: &mut Bencher) { 1084 | let mut v = get_test_array(); 1085 | b.iter(|| { 1086 | assert_eq!(9000, search_order2(v.as_mut_slice())); 1087 | }) 1088 | } 1089 | #[bench] 1090 | fn test_search_order3(b: &mut Bencher) { 1091 | let v = get_test_array(); 1092 | b.iter(|| { 1093 | assert_eq!(9000, search_order3(v.as_slice())); 1094 | }) 1095 | } 1096 | } 1097 | --------------------------------------------------------------------------------