├── .gitignore ├── Cargo.toml ├── Readme.md ├── benches └── bench.rs ├── examples ├── load_graph.rs └── simulate.rs ├── flops.py └── src ├── data.rs ├── generate.rs ├── lib.rs ├── pool.rs └── timeline.rs /.gitignore: -------------------------------------------------------------------------------- 1 | /target 2 | /Cargo.lock 3 | /data 4 | profile.json 5 | -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "twitterperf" 3 | version = "0.1.0" 4 | edition = "2021" 5 | 6 | # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html 7 | 8 | [dependencies] 9 | bytemuck = { version = "1.12.3", features = ["derive"] } 10 | criterion = "0.4.0" 11 | expect-test = "1.4.0" 12 | libc = "0.2.139" 13 | memmap2 = "0.5.8" 14 | nanorand = "0.7.0" 15 | rand = "0.8.5" 16 | rand-wyrand = "0.1.0" 17 | rand_distr = "0.4.3" 18 | ringbuffer = "0.10.0" 19 | signpost = "0.1.0" 20 | static_assertions = "1.1.0" 21 | 22 | [[bench]] 23 | name = "bench" 24 | harness = false 25 | 26 | [profile.release] 27 | debug = true 28 | -------------------------------------------------------------------------------- /Readme.md: -------------------------------------------------------------------------------- 1 | # Twitter performance prototype 2 | 3 | A Rust prototype of handling the full production load of Twitter's core timeline collation on a single core by only doing the very basics in-memory. This is purely a stunt systems design curiosity rather than anything practical.
4 | 
--------------------------------------------------------------------------------
/benches/bench.rs:
--------------------------------------------------------------------------------
1 | use criterion::{
2 |     black_box, criterion_group, criterion_main, Bencher, BenchmarkId, Criterion, Throughput,
3 | };
4 | use twitterperf::data::START_TIME;
5 | use twitterperf::generate::{LoadGraph, TweetGenerator, TweetGeneratorConfig, ViewGenerator};
6 | use twitterperf::timeline::{Timeline, TimelineFetcher};
7 | 
8 | /// Maximum number of tweets returned per timeline fetch.
9 | const TIMELINE_LEN: usize = 200;
10 | 
11 | /// Benchmarks `TimelineFetcher::for_user`, i.e. merging the feeds of everyone a
12 | /// randomly chosen viewing user follows into one timeline.
13 | ///
14 | /// Setup (not timed) loads the pre-baked graph from `data/` and inserts 4M
15 | /// synthetic tweets.
16 | fn criterion_benchmark(c: &mut Criterion) {
17 |     let loader = LoadGraph::new().unwrap();
18 |     let graph = loader.graph();
19 | 
20 |     let n_tweets = 4_000_000;
21 |     let config = TweetGeneratorConfig::default();
22 |     // `new` returns (generator, viewing users, datastore); viewers are drawn by a
23 |     // separate ViewGenerator, as in examples/simulate.rs.
24 |     let (mut gen, viewing_users, mut data) = TweetGenerator::new(config, graph);
25 | 
26 |     gen.add_tweets(&mut data, n_tweets);
27 |     let view_seed = gen.fork_seed();
28 | 
29 |     let mut group = c.benchmark_group("timeline");
30 |     // NOTE(review): 69 looks like a hand-measured average timeline length so that
31 |     // criterion reports tweets/s; confirm it against the generated data.
32 |     group.throughput(Throughput::Elements(69));
33 |     group.bench_function("merge", |b| {
34 |         let mut view_gen = ViewGenerator::new(view_seed, &viewing_users);
35 |         let mut fetcher = TimelineFetcher::default();
36 |         b.iter(|| {
37 |             let user_idx = black_box(view_gen.gen_view());
38 |             // returning the length lets criterion black-box the fetch result
39 |             fetcher.for_user(&data, user_idx, TIMELINE_LEN, START_TIME).tweets.len()
40 |         })
41 |     });
42 |     group.finish()
43 | }
44 | 
45 | criterion_group!(benches, criterion_benchmark);
46 | criterion_main!(benches);
47 | 
--------------------------------------------------------------------------------
/examples/load_graph.rs:
--------------------------------------------------------------------------------
1 | // process the graph from https://snap.stanford.edu/data/twitter-2010.html
2 | //
time cat /Users/tristan/Downloads/twitter-2010.txt.gz | gunzip | cargo run --release --example load_graph 3 | 4 | use bytemuck::cast_slice; 5 | use std::{ 6 | fs::File, 7 | io::{self, BufRead, Write}, 8 | }; 9 | 10 | use twitterperf::data::*; 11 | 12 | const TEST: bool = true; 13 | 14 | fn main() { 15 | let num_users = 41652230; 16 | let mut graph: Vec> = (0..num_users).map(|_| vec![]).collect(); 17 | 18 | let mut buffer = String::new(); 19 | let stdin = io::stdin(); 20 | let mut handle = stdin.lock(); 21 | 22 | let mut total_follows = 0; 23 | while handle.read_line(&mut buffer).unwrap() != 0 { 24 | let mut split = buffer.split(" "); 25 | // An edge from i to j indicates that j is a follower of i 26 | let i: u32 = split.next().unwrap().trim().parse().unwrap(); 27 | let j: u32 = split.next().unwrap().trim().parse().unwrap(); 28 | graph[j as usize].push(i); 29 | buffer.clear(); 30 | total_follows += 1; 31 | } 32 | 33 | eprintln!("Done phase 1"); 34 | 35 | let mut users: Vec = Vec::with_capacity(num_users); 36 | let mut follows: Vec = Vec::with_capacity(total_follows); 37 | 38 | for ls in &graph { 39 | let user = User { 40 | follows_idx: follows.len(), 41 | num_follows: ls.len() as u32, 42 | num_followers: 0, 43 | }; 44 | users.push(user); 45 | for x in ls { 46 | follows.push(*x); 47 | } 48 | } 49 | 50 | eprintln!("Done phase 2"); 51 | 52 | for f in &follows { 53 | users[*f as usize].num_followers += 1; 54 | } 55 | 56 | eprintln!("Done phase 3"); 57 | 58 | if TEST { 59 | return; 60 | } 61 | 62 | let mut users_f = File::create("data/users.bin").unwrap(); 63 | let user_bytes = cast_slice(&users[..]); 64 | users_f.write_all(user_bytes).unwrap(); 65 | 66 | let mut follows_f = File::create("data/follows.bin").unwrap(); 67 | let follows_bytes = cast_slice(&follows[..]); 68 | follows_f.write_all(follows_bytes).unwrap(); 69 | } 70 | -------------------------------------------------------------------------------- /examples/simulate.rs: 
-------------------------------------------------------------------------------- 1 | use std::thread; 2 | use std::time::Instant; 3 | 4 | use twitterperf::data::START_TIME; 5 | use twitterperf::generate::{LoadGraph, TweetGenerator, TweetGeneratorConfig, ViewGenerator}; 6 | use twitterperf::timeline::TimelineFetcher; 7 | 8 | use signpost::{trace_function, AutoTrace}; 9 | 10 | fn main() { 11 | let loader = LoadGraph::new().unwrap(); 12 | let graph = loader.graph(); 13 | 14 | let n_test_add = 15_000_000; 15 | let n_tweets = 30_000_000 - n_test_add; 16 | let config = TweetGeneratorConfig::default(); 17 | let (mut gen, viewing_users, mut data) = TweetGenerator::new(config, graph); 18 | 19 | let add_start = Instant::now(); 20 | trace_function(1, &[0; 4], || gen.add_tweets(&mut data, n_tweets)); 21 | let add_dur = Instant::now() - add_start; 22 | let add_rate = n_tweets as f64 / add_dur.as_secs_f64(); 23 | eprintln!("Initially added {n_tweets} tweets in {add_dur:?}: {add_rate:.3} tweets/s."); 24 | 25 | let add_start = Instant::now(); 26 | trace_function(1, &[0; 4], || gen.add_tweets(&mut data, n_tweets)); 27 | let add_dur = Instant::now() - add_start; 28 | let add_rate = n_test_add as f64 / add_dur.as_secs_f64(); 29 | eprintln!("Benchmarked adding {n_test_add} tweets in {add_dur:?}: {add_rate:.3} tweets/s."); 30 | 31 | let _x = AutoTrace::new(2, &[0usize; 4]); 32 | let n_views = 100_000; 33 | // let mut total_likes = 0u32; 34 | let n_threads = 8; 35 | eprintln!("Starting fetches from {n_threads} threads"); 36 | let viewing_users = &viewing_users[..]; 37 | let data = &data; 38 | thread::scope(|s| { 39 | for _ in 0..n_threads { 40 | let seed: u64 = gen.fork_seed(); 41 | s.spawn(move || { 42 | let mut view_gen = ViewGenerator::new(seed, viewing_users); 43 | let mut total_viewed = 0usize; 44 | let start = Instant::now(); 45 | let mut fetcher = TimelineFetcher::default(); 46 | for _ in 0..n_views { 47 | let user_idx = view_gen.gen_view(); 48 | let timeline = 
fetcher.for_user(data, user_idx, 256, START_TIME); 49 | total_viewed += timeline.tweets.len(); 50 | // total_likes += timeline.tweets.iter().map(|t| t.likes).sum::(); 51 | } 52 | let dur = Instant::now() - start; 53 | let rate = total_viewed as f64 / dur.as_secs_f64(); 54 | let avg_timeline_size = total_viewed as f64 / n_views as f64; 55 | let expansion = (avg_timeline_size * view_gen.viewing_users.len() as f64) / n_tweets as f64; 56 | eprintln!("Done {total_viewed} in {dur:?} at {rate:.3} tweets/s. Avg timeline size {avg_timeline_size:.2} -> expansion {expansion:.2}"); 57 | }); 58 | } 59 | }); 60 | // eprintln!("{total_likes}"); 61 | } 62 | -------------------------------------------------------------------------------- /flops.py: -------------------------------------------------------------------------------- 1 | """Computes the flops needed for training/running transformer networks.""" 2 | 3 | import collections 4 | 5 | # We checked this code with TensorFlow"s FLOPs counting, although we had to 6 | # correct for this issue: https://github.com/tensorflow/tensorflow/issues/22071 7 | # Assumptions going into the FLOPs counting 8 | # - An "operation" is a mathematical operation, not a machine instruction. So 9 | # an "exp" takes one opp like and add, even though in practice an exp 10 | # might be slower. This is not too bad an assumption because 11 | # matrix-multiplies dominate the compute for most models, so minor details 12 | # about activation functions don"t matter too much. Similarly, we count 13 | # matrix-multiplies as 2*m*n flops instead of m*n, as one might if 14 | # if considering fused multiply-add ops. 15 | # - Backward pass takes the same number of FLOPs as forward pass. No exactly 16 | # right (e.g., for softmax cross entropy loss the backward pass is faster). 17 | # Importantly, it really is the same for matrix-multiplies, which is most of 18 | # the compute anyway. 
19 | # - We assume "dense" embedding lookups (i.e., multiplication by a one-hot 20 | # vector). On some hardware accelerators, these dense operations are 21 | # actually faster than sparse lookups. 22 | # Please open a github issue if you spot a problem with this code! 23 | 24 | # I am not sure if the below constants are 100% right, but they are only applied 25 | # to O(hidden_size) activations, which is generally a lot less compute than the 26 | # matrix-multiplies, which are O(hidden_size^2), so they don't affect the total 27 | # number of FLOPs much. 28 | 29 | # random number, >=, multiply activations by dropout mask, multiply activations 30 | # by correction (1 / (1 - dropout_rate)) 31 | DROPOUT_FLOPS = 4 32 | 33 | # compute mean activation (sum), computate variance of activation 34 | # (square and sum), bias (add), scale (multiply) 35 | LAYER_NORM_FLOPS = 5 36 | 37 | # GELU: 0.5 * x * (1 + tanh(sqrt(2 / np.pi) * (x + 0.044715 * pow(x, 3)))) 38 | ACTIVATION_FLOPS = 8 39 | 40 | # max/substract (for stability), exp, sum, divide 41 | SOFTMAX_FLOPS = 5 42 | 43 | 44 | class TransformerHparams(object): 45 | """Computes the train/inference FLOPs for transformers.""" 46 | 47 | def __init__(self, h, l, s=512, v=30522, e=None, i=None, heads=None, 48 | head_size=None, output_frac=0.15625, sparse_embed_lookup=False, 49 | decoder=False): 50 | self.h = h # hidden size 51 | self.l = l # number of layers 52 | self.s = s # sequence length 53 | self.v = v # vocab size 54 | self.e = h if e is None else e # embedding size 55 | self.i = h * 4 if i is None else i # intermediate size 56 | self.kqv = h if head_size is None else head_size * heads # attn proj sizes 57 | self.heads = max(h // 64, 1) if heads is None else heads # attention heads 58 | self.output_frac = output_frac # percent of tokens using an output softmax 59 | self.sparse_embed_lookup = sparse_embed_lookup # sparse embedding lookups 60 | self.decoder = decoder # decoder has extra attn to encoder states 61 | 62 | def 
get_block_flops(self): 63 | """Get the forward-pass FLOPs for a single transformer block.""" 64 | attn_mul = 2 if self.decoder else 1 65 | block_flops = dict( 66 | kqv=3 * 2 * self.h * self.kqv * attn_mul, 67 | kqv_bias=3 * self.kqv * attn_mul, 68 | attention_scores=2 * self.kqv * self.s * attn_mul, 69 | attn_softmax=SOFTMAX_FLOPS * self.s * self.heads * attn_mul, 70 | attention_dropout=DROPOUT_FLOPS * self.s * self.heads * attn_mul, 71 | attention_scale=self.s * self.heads * attn_mul, 72 | attention_weighted_avg_values=2 * self.h * self.s * attn_mul, 73 | attn_output=2 * self.h * self.h * attn_mul, 74 | attn_output_bias=self.h * attn_mul, 75 | attn_output_dropout=DROPOUT_FLOPS * self.h * attn_mul, 76 | attn_output_residual=self.h * attn_mul, 77 | attn_output_layer_norm=LAYER_NORM_FLOPS * attn_mul, 78 | intermediate=2 * self.h * self.i, 79 | intermediate_act=ACTIVATION_FLOPS * self.i, 80 | intermediate_bias=self.i, 81 | output=2 * self.h * self.i, 82 | output_bias=self.h, 83 | output_dropout=DROPOUT_FLOPS * self.h, 84 | output_residual=self.h, 85 | output_layer_norm=LAYER_NORM_FLOPS * self.h, 86 | ) 87 | return sum(block_flops.values()) * self.s 88 | 89 | def get_embedding_flops(self, output=False): 90 | """Get the forward-pass FLOPs the transformer inputs or output softmax.""" 91 | embedding_flops = {} 92 | if output or (not self.sparse_embed_lookup): 93 | embedding_flops["main_multiply"] = 2 * self.e * self.v 94 | # input embedding post-processing 95 | if not output: 96 | embedding_flops.update(dict( 97 | tok_type_and_position=2 * self.e * (self.s + 2), 98 | add_tok_type_and_position=2 * self.e, 99 | emb_layer_norm=LAYER_NORM_FLOPS * self.e, 100 | emb_dropout=DROPOUT_FLOPS * self.e 101 | )) 102 | # projection layer if e != h 103 | if self.e != self.h or output: 104 | embedding_flops.update(dict( 105 | hidden_kernel=2 * self.h * self.e, 106 | hidden_bias=self.e if output else self.h 107 | )) 108 | # extra hidden layer and output softmax 109 | if output: 110 | 
embedding_flops.update(dict( 111 | hidden_activation=ACTIVATION_FLOPS * self.e, 112 | hidden_layernorm=LAYER_NORM_FLOPS * self.e, 113 | output_softmax=SOFTMAX_FLOPS * self.v, 114 | output_target_word=2 * self.v 115 | )) 116 | return self.output_frac * sum(embedding_flops.values()) * self.s 117 | return sum(embedding_flops.values()) * self.s 118 | 119 | def get_binary_classification_flops(self): 120 | classification_flops = dict( 121 | hidden=2 * self.h * self.h, 122 | hidden_bias=self.h, 123 | hidden_act=ACTIVATION_FLOPS * self.h, 124 | logits=2 * self.h 125 | ) 126 | return sum(classification_flops.values()) * self.s 127 | 128 | def get_train_flops(self, batch_size, train_steps, discriminator=False): 129 | """Get the FLOPs for pre-training the transformer.""" 130 | # 2* for forward/backward pass 131 | return 2 * batch_size * train_steps * ( 132 | (self.l * self.get_block_flops()) + 133 | self.get_embedding_flops(output=False) + 134 | (self.get_binary_classification_flops() if discriminator else 135 | self.get_embedding_flops(output=True)) 136 | ) 137 | 138 | def get_infer_flops(self): 139 | """Get the FLOPs for running inference with the transformer on a 140 | classification task.""" 141 | return ((self.l * self.get_block_flops()) + 142 | self.get_embedding_flops(output=False) + 143 | self.get_binary_classification_flops()) 144 | 145 | 146 | def get_electra_train_flops( 147 | h_d, l_d, h_g, l_g, batch_size, train_steps, tied_embeddings, 148 | e=None, s=512, output_frac=0.15625): 149 | """Get the FLOPs needed for pre-training ELECTRA.""" 150 | if e is None: 151 | e = h_d 152 | disc = TransformerHparams( 153 | h_d, l_d, s=s, e=e, 154 | output_frac=output_frac).get_train_flops(batch_size, train_steps, True) 155 | gen = TransformerHparams( 156 | h_g, l_g, s=s, e=e if tied_embeddings else None, 157 | output_frac=output_frac).get_train_flops(batch_size, train_steps) 158 | return disc + gen 159 | 160 | 161 | MODEL_FLOPS = collections.OrderedDict([ 162 | # These runtimes 
were computed with tensorflow FLOPs counting instead of the 163 | # script, as the neural architectures are quite different. 164 | # 768648884 words in LM1b benchmark, 10 epochs with batch size 20, 165 | # seq length 128, 568093262680 FLOPs per example. 166 | ("elmo", 2 * 10 * 768648884 * 568093262680 / (20.0 * 128)), 167 | # 15064773691518 is FLOPs for forward pass on 32 examples. 168 | # Therefore 2 * steps * batch_size * 15064773691518 / 32 is XLNet compute 169 | ("xlnet", 2 * 500000 * 8192 * 15064773691518 / 32.0), 170 | 171 | ("tinybert", TransformerHparams(312, 4, i=1200, s=128).get_infer_flops()), 172 | 173 | 174 | # Runtimes computed with the script 175 | ("gpt", TransformerHparams(768, 12, v=40000, output_frac=1.0).get_train_flops( 176 | 128, 960800)), 177 | ("bert_small", TransformerHparams(256, 12, e=128, s=128).get_infer_flops()), 178 | ("bert_base", TransformerHparams(768, 12).get_train_flops(256, 1e6)), 179 | ("bert_large", TransformerHparams(1024, 24).get_train_flops(256, 1e6)), 180 | ("electra_small", get_electra_train_flops(256, 12, 64, 12, 128, 1e6, True, s=128, e=128)), 181 | ("electra_base", get_electra_train_flops(768, 12, 256, 12, 256, 766000, True)), 182 | ("electra_400k", get_electra_train_flops(1024, 24, 256, 24, 2048, 400000, True)), 183 | ("electra_1.75M", get_electra_train_flops(1024, 24, 256, 24, 2048, 1750000, True)), 184 | 185 | # RoBERTa, ALBERT, and T5 have minor architectural differences from 186 | # BERT/ELECTRA, but I believe they don't significantly effect the runtime, 187 | # so we use this script for those models as well. 
188 | ("roberta", TransformerHparams(1024, 24, v=50265).get_train_flops(8000, 500000)), 189 | ("albert", TransformerHparams(4096, 12, v=30000, e=128).get_train_flops( 190 | 4096, 1.5e6)), 191 | ("t5_11b", TransformerHparams( 192 | 1024, # hidden size 193 | 24, # layers 194 | v=32000, # vocab size 195 | i=65536, # ff intermediate hidden size 196 | heads=128, head_size=128, # heads/head size 197 | output_frac=0.0 # encoder has no output softmax 198 | ).get_train_flops(2048, 1e6) + # 1M steps with batch size 2048 199 | TransformerHparams( 200 | 1024, 201 | 24, 202 | v=32000, 203 | i=65536, 204 | heads=128, head_size=128, 205 | output_frac=1.0, # decoder has output softmax for all positions 206 | decoder=True 207 | ).get_train_flops(2048, 1e6)) 208 | ]) 209 | 210 | 211 | def main(): 212 | for k, v in MODEL_FLOPS.items(): 213 | print(k, v) 214 | 215 | 216 | if __name__ == "__main__": 217 | main() 218 | -------------------------------------------------------------------------------- /src/data.rs: -------------------------------------------------------------------------------- 1 | use std::{num::NonZeroU32, sync::atomic::AtomicU64}; 2 | 3 | use bytemuck::{NoUninit, Pod, Zeroable}; 4 | use static_assertions::assert_eq_size; 5 | 6 | use crate::pool::SharedPool; 7 | 8 | /// Leave room for a full 280 character plus some accents or emoji. 9 | /// A real implementation would have an escape hatch for longer tweets. 
10 | pub const TWEET_BYTES: usize = 286; 11 | 12 | // non-zero so options including a timestamp don't take any more space 13 | // u32 since that's 100+ years of second-level precision and it lets us pack atomics 14 | pub type Timestamp = NonZeroU32; 15 | pub const START_TIME: Timestamp = unsafe { NonZeroU32::new_unchecked(1) }; 16 | 17 | #[derive(Clone)] 18 | pub struct Tweet { 19 | pub content: [u8; TWEET_BYTES], 20 | pub ts: Timestamp, 21 | 22 | pub likes: u32, 23 | pub quotes: u32, 24 | pub retweets: u32, 25 | } 26 | 27 | impl Tweet { 28 | pub fn dummy(ts: Timestamp) -> Self { 29 | Tweet { 30 | content: [0; TWEET_BYTES], 31 | ts, 32 | likes: 0, 33 | quotes: 0, 34 | retweets: 0, 35 | } 36 | } 37 | } 38 | 39 | // assert_eq_size!([u8; 304], Tweet); 40 | 41 | pub type TweetIdx = u32; 42 | 43 | /// linked list of tweets to make appending fast and avoid space overhead 44 | /// a linked list of chunks of tweets would probably be faster because of 45 | /// cache locality of fetches, but I haven't implemented that 46 | #[derive(PartialEq, Eq, PartialOrd, Ord, Clone, Copy, NoUninit)] 47 | #[repr(C)] 48 | pub struct NextLink { 49 | pub ts: Timestamp, 50 | pub tweet_idx: TweetIdx, 51 | } 52 | assert_eq_size!(AtomicU64, NextLink); 53 | 54 | pub type FeedChain = Option; 55 | 56 | #[derive(Clone, Copy, Zeroable, Pod)] 57 | #[repr(C)] 58 | struct PodNextLink(u32, u32); 59 | assert_eq_size!(AtomicU64, PodNextLink); 60 | 61 | /// Top level feeds use an atomic link so we can mutate concurrently 62 | /// This effectively works by casting NextLink to a u64 63 | pub struct AtomicChain(AtomicU64); 64 | 65 | impl AtomicChain { 66 | pub fn none() -> Self { 67 | AtomicChain(AtomicU64::new(0)) 68 | } 69 | 70 | pub fn set(&self, next: NextLink) { 71 | let as_u64: u64 = bytemuck::cast(next); 72 | self.0.store(as_u64, std::sync::atomic::Ordering::SeqCst); 73 | } 74 | 75 | pub fn fetch(&self) -> Option { 76 | let as_u64 = self.0.load(std::sync::atomic::Ordering::SeqCst); 77 | // we hope LLVM 
optimizes this into a no-op 78 | let pod: PodNextLink = bytemuck::cast(as_u64); 79 | match Timestamp::new(pod.0) { 80 | Some(ts) => Some(NextLink { 81 | ts, 82 | tweet_idx: pod.1, 83 | }), 84 | None => None, 85 | } 86 | } 87 | } 88 | 89 | #[repr(align(64))] 90 | pub struct ChainedTweet { 91 | pub tweet: Tweet, 92 | pub prev_tweet: FeedChain, 93 | } 94 | assert_eq_size!([u8; 320], ChainedTweet); 95 | 96 | pub type UserIdx = u32; 97 | 98 | #[derive(Copy, Clone, Pod, Zeroable)] 99 | #[repr(C)] 100 | pub struct User { 101 | // This would be better as a Vec for mutation but for fast data loading we use one giant slice 102 | pub follows_idx: usize, 103 | pub num_follows: u32, 104 | pub num_followers: u32, 105 | } 106 | 107 | /// We store the Graph in a format we can mmap from a pre-baked file 108 | /// so that our tests can load a real graph faster 109 | pub struct Graph<'a> { 110 | pub users: &'a [User], 111 | pub follows: &'a [UserIdx], 112 | } 113 | 114 | impl<'a> Graph<'a> { 115 | #[inline] 116 | pub fn user_follows(&'a self, user: &User) -> &'a [UserIdx] { 117 | &self.follows[user.follows_idx..][..user.num_follows as usize] 118 | } 119 | } 120 | 121 | pub struct Datastore<'a> { 122 | pub graph: Graph<'a>, 123 | pub tweets: SharedPool, 124 | pub feeds: Vec, 125 | } 126 | 127 | impl<'a> Datastore<'a> { 128 | /// This will clobber writes (in a safe way) if called concurrently 129 | /// from multiple threads. 
Ideally we'd have a separate &mut handle for this 130 | pub fn add_tweet(&self, tweet: Tweet, user_id: UserIdx) { 131 | let prev_tweet = self.feeds[user_id as usize].fetch(); 132 | let ts = tweet.ts; 133 | let chained = ChainedTweet { tweet, prev_tweet }; 134 | let tweet_idx = self.tweets.push(chained) as TweetIdx; 135 | self.feeds[user_id as usize].set(NextLink { ts, tweet_idx }); 136 | } 137 | 138 | pub fn prefetch_tweet(&self, tweet_idx: TweetIdx) { 139 | let tweet_ptr = &self.tweets[tweet_idx as usize] as *const ChainedTweet; 140 | unsafe { 141 | for cache_line in 0..3 { 142 | let line_ptr = (tweet_ptr as *const i8).offset(64 * cache_line); 143 | core::arch::x86_64::_mm_prefetch(line_ptr, core::arch::x86_64::_MM_HINT_T0) 144 | } 145 | } 146 | } 147 | } 148 | -------------------------------------------------------------------------------- /src/generate.rs: -------------------------------------------------------------------------------- 1 | use crate::data::*; 2 | use crate::pool::SharedPool; 3 | 4 | use bytemuck::cast_slice; 5 | use memmap2::Mmap; 6 | use rand::seq::SliceRandom; 7 | use rand::{Rng, SeedableRng}; 8 | use rand_wyrand::WyRand; 9 | use std::fs::File; 10 | use std::ops::Deref; 11 | 12 | pub struct TweetGeneratorConfig { 13 | pub seed: u64, 14 | pub tweeter_follower_thresh: u32, 15 | pub viewer_follow_thresh: u32, 16 | } 17 | 18 | impl Default for TweetGeneratorConfig { 19 | fn default() -> TweetGeneratorConfig { 20 | TweetGeneratorConfig { 21 | seed: 123, 22 | tweeter_follower_thresh: 20, 23 | viewer_follow_thresh: 20, 24 | } 25 | } 26 | } 27 | 28 | pub struct TweetGenerator { 29 | // config: TweetGeneratorConfig, 30 | tweeting_users: Vec, 31 | rng: WyRand, 32 | ts: Timestamp, 33 | } 34 | 35 | pub type ViewingUsers = Vec; 36 | 37 | impl TweetGenerator { 38 | pub fn new<'a>( 39 | config: TweetGeneratorConfig, 40 | graph: Graph<'a>, 41 | ) -> (Self, ViewingUsers, Datastore<'a>) { 42 | let feeds: Vec = (0..graph.users.len()) 43 | .map(|_| 
AtomicChain::none()) 44 | .collect(); 45 | let tweets = SharedPool::new().unwrap(); 46 | 47 | let mut rng = WyRand::from_seed(config.seed.to_le_bytes()); 48 | let mut tweeting_users: Vec = graph 49 | .users 50 | .iter() 51 | .enumerate() 52 | .filter(|(_, u)| u.num_followers > config.tweeter_follower_thresh) 53 | .map(|(i, _)| i as u32) 54 | .collect(); 55 | tweeting_users.shuffle(&mut rng); 56 | let mut viewing_users: Vec = graph 57 | .users 58 | .iter() 59 | .enumerate() 60 | .filter(|(_, u)| u.num_follows > config.viewer_follow_thresh) 61 | .map(|(i, _)| i as u32) 62 | .collect(); 63 | viewing_users.shuffle(&mut rng); 64 | let this = Self { 65 | // config, 66 | tweeting_users, 67 | rng, 68 | ts: START_TIME, 69 | }; 70 | 71 | let data = Datastore { 72 | graph, 73 | tweets, 74 | feeds, 75 | }; 76 | 77 | (this, viewing_users, data) 78 | } 79 | 80 | pub fn gen_tweet(&mut self) -> (UserIdx, Tweet) { 81 | // TODO Zipf distribution or something 82 | let user_id: UserIdx = *self.tweeting_users.choose(&mut self.rng).unwrap(); 83 | let tweet = Tweet::dummy(self.ts); 84 | self.ts = self.ts.saturating_add(1); 85 | (user_id, tweet) 86 | } 87 | 88 | pub fn add_tweets(&mut self, data: &mut Datastore, n: usize) { 89 | for _ in 0..n { 90 | let (user_id, tweet) = self.gen_tweet(); 91 | data.add_tweet(tweet, user_id); 92 | } 93 | } 94 | 95 | pub fn fork_seed(&mut self) -> u64 { 96 | self.rng.gen() 97 | } 98 | } 99 | 100 | pub struct ViewGenerator<'a> { 101 | pub viewing_users: &'a [UserIdx], 102 | rng: WyRand, 103 | } 104 | 105 | impl<'a> ViewGenerator<'a> { 106 | pub fn new(seed: u64, viewing_users: &'a [UserIdx]) -> Self { 107 | Self { 108 | rng: WyRand::from_seed(seed.to_le_bytes()), 109 | viewing_users, 110 | } 111 | } 112 | 113 | pub fn gen_view(&mut self) -> UserIdx { 114 | // TODO Zipf distribution or something 115 | let user_id: UserIdx = *self.viewing_users.choose(&mut self.rng).unwrap(); 116 | user_id 117 | } 118 | } 119 | 120 | pub struct LoadGraph { 121 | users: Mmap, 
122 | follows: Mmap, 123 | } 124 | 125 | impl LoadGraph { 126 | pub fn new() -> std::io::Result { 127 | Ok(Self { 128 | users: unsafe { Mmap::map(&File::open("data/users.bin")?)? }, 129 | follows: unsafe { Mmap::map(&File::open("data/follows.bin")?)? }, 130 | }) 131 | } 132 | 133 | pub fn graph<'a>(&'a self) -> Graph<'a> { 134 | Graph { 135 | users: cast_slice(self.users.deref()), 136 | follows: cast_slice(self.follows.deref()), 137 | } 138 | } 139 | } 140 | 141 | #[cfg(test)] 142 | mod tests { 143 | use crate::timeline::TimelineFetcher; 144 | 145 | use super::*; 146 | use expect_test::{expect, Expect}; 147 | 148 | pub fn n_eq(n: usize, ex: Expect) { 149 | ex.assert_eq(&n.to_string()); 150 | } 151 | 152 | pub fn f_eq(f: f64, ex: Expect) { 153 | ex.assert_eq(&format!("{f:.3}")); 154 | } 155 | 156 | #[test] 157 | fn loading() { 158 | let loader = LoadGraph::new().unwrap(); 159 | let graph = loader.graph(); 160 | 161 | n_eq(graph.users.len(), expect!["41652230"]); 162 | n_eq(graph.follows.len(), expect!["1468365182"]); 163 | 164 | let non_trivial = graph.users.iter().filter(|u| u.num_follows > 20).count(); 165 | n_eq(non_trivial, expect!["9031061"]); 166 | 167 | let max_follows = graph.users.iter().map(|u| u.num_follows).max().unwrap(); 168 | n_eq(max_follows as usize, expect!["770155"]); 169 | 170 | let max_followers = graph.users.iter().map(|u| u.num_followers).max().unwrap(); 171 | n_eq(max_followers as usize, expect!["2997469"]); 172 | } 173 | 174 | #[test] 175 | fn generating() { 176 | let loader = LoadGraph::new().unwrap(); 177 | let graph = loader.graph(); 178 | 179 | let n_tweets = 4_000_000; 180 | let config = TweetGeneratorConfig::default(); 181 | let (mut gen, mut data) = TweetGenerator::new(config, graph); 182 | 183 | n_eq(gen.viewing_users.len(), expect!["9031061"]); 184 | n_eq(gen.tweeting_users.len(), expect!["6746960"]); 185 | 186 | gen.add_tweets(&mut data, n_tweets); 187 | 188 | let n_views = 100_000; 189 | let mut total_viewed = 0usize; 190 | let 
mut fetcher = TimelineFetcher::default(); 191 | for _ in 0..n_views { 192 | let user_idx = gen.gen_view(); 193 | let timeline = fetcher.for_user(&data, user_idx, 200, START_TIME); 194 | total_viewed += timeline.tweets.len(); 195 | } 196 | let avg_timeline_size = total_viewed as f64 / n_views as f64; 197 | f_eq(avg_timeline_size, expect!["41.480"]); 198 | let expansion = (avg_timeline_size * gen.viewing_users.len() as f64) / n_tweets as f64; 199 | f_eq(expansion, expect!["93.652"]); 200 | } 201 | } 202 | -------------------------------------------------------------------------------- /src/lib.rs: -------------------------------------------------------------------------------- 1 | pub mod data; 2 | pub mod generate; 3 | pub mod pool; 4 | pub mod timeline; 5 | 6 | pub fn add(left: usize, right: usize) -> usize { 7 | left + right 8 | } 9 | 10 | #[cfg(test)] 11 | mod tests { 12 | use super::*; 13 | 14 | #[test] 15 | fn it_works() { 16 | let result = add(2, 2); 17 | assert_eq!(result, 4); 18 | } 19 | } 20 | -------------------------------------------------------------------------------- /src/pool.rs: -------------------------------------------------------------------------------- 1 | use std::io; 2 | use std::ops::Index; 3 | use std::ptr::{self, NonNull}; 4 | use std::sync::atomic::{AtomicUsize, Ordering}; 5 | use std::sync::Mutex; 6 | 7 | pub struct SharedPool { 8 | len: AtomicUsize, 9 | lock: Mutex<()>, 10 | buf: NonNull, 11 | } 12 | 13 | unsafe impl Sync for SharedPool {} 14 | unsafe impl Send for SharedPool {} 15 | 16 | // https://vgel.me/posts/mmap-arena-alloc/ 17 | impl SharedPool { 18 | /// It doesn't matter that much how large this is but let's go for 34GB 19 | const MAP_SIZE: usize = 1 << 35; 20 | 21 | pub fn new() -> io::Result { 22 | // TODO use hugepages only on linux 23 | let map = unsafe { 24 | libc::mmap( 25 | ptr::null_mut(), 26 | Self::MAP_SIZE, 27 | libc::PROT_READ | libc::PROT_WRITE, 28 | // MAP_PRIVATE: this is not shared memory 29 | // MAP_ANONYMOUS: 
this is RAM, not a file-backed mmap 30 | // MAP_NORESERVE: don't reserve swap 31 | // MAP_HUGETLB: use huge pages for better performance 32 | // (make sure huge pages are enabled or this will SIGBUS: 33 | // # sysctl -w vm.nr_hugepages=2048) 34 | libc::MAP_PRIVATE | libc::MAP_ANONYMOUS | libc::MAP_NORESERVE, // | libc::MAP_HUGETLB, 35 | -1, 36 | 0, 37 | ) 38 | }; 39 | 40 | let buf = if map == libc::MAP_FAILED || map.is_null() { 41 | return Err(io::Error::last_os_error()); 42 | } else { 43 | NonNull::new(map as *mut T).unwrap() 44 | }; 45 | 46 | Ok(Self { 47 | buf, 48 | lock: Mutex::new(()), 49 | len: AtomicUsize::new(0), 50 | }) 51 | } 52 | 53 | #[inline] 54 | pub fn push(&self, value: T) -> usize { 55 | // TODO either be clever about queueing these up or 56 | // split this type into a reader and a writer to avoid the lock 57 | let _guard = self.lock.lock(); 58 | let i = self.len.load(Ordering::SeqCst); 59 | unsafe { 60 | let end = self.buf.as_ptr().add(i); 61 | ptr::write(end, value); 62 | } 63 | self.len.fetch_add(1, Ordering::SeqCst); 64 | i 65 | } 66 | } 67 | 68 | impl Index for SharedPool { 69 | type Output = T; 70 | 71 | #[inline] 72 | fn index(&self, i: usize) -> &T { 73 | let len = self.len.load(Ordering::SeqCst); 74 | if i >= len { 75 | panic!("index out of bounds {i} for length {len}") 76 | } 77 | unsafe { 78 | let item = self.buf.as_ptr().add(i); 79 | &*item 80 | } 81 | } 82 | } 83 | 84 | #[cfg(test)] 85 | mod tests { 86 | use super::*; 87 | 88 | #[test] 89 | fn basic_pool() { 90 | let pool = SharedPool::new().unwrap(); 91 | pool.push(5); 92 | pool.push(6); 93 | assert_eq!(pool[0], 5); 94 | assert_eq!(pool[1], 6); 95 | } 96 | } 97 | -------------------------------------------------------------------------------- /src/timeline.rs: -------------------------------------------------------------------------------- 1 | #![allow(dead_code)] 2 | use ringbuffer::ConstGenericRingBuffer; 3 | use ringbuffer::RingBufferWrite; 4 | use 
static_assertions::assert_eq_size; 5 | use std::collections::BinaryHeap; 6 | 7 | use crate::data::*; 8 | 9 | pub struct Timeline<'a> { 10 | pub tweets: &'a [Tweet], 11 | } 12 | 13 | pub const CACHE_SIZE: usize = 123; 14 | 15 | #[derive(Clone)] 16 | #[repr(align(64))] 17 | pub struct CachedTimeline { 18 | tweets: ConstGenericRingBuffer, 19 | } 20 | assert_eq_size!([u8; 512], CachedTimeline); 21 | 22 | impl Default for CachedTimeline { 23 | fn default() -> Self { 24 | Self { 25 | tweets: ConstGenericRingBuffer::new(), 26 | } 27 | } 28 | } 29 | 30 | pub struct TimelineCache { 31 | pub timelines: Vec, 32 | } 33 | 34 | impl TimelineCache { 35 | fn new(graph: &Graph) -> Self { 36 | Self { 37 | timelines: vec![CachedTimeline::default(); graph.users.len()], 38 | } 39 | } 40 | 41 | fn publish_tweet(&mut self, graph: &Graph, user_idx: UserIdx, tweet_idx: TweetIdx) { 42 | let user = &graph.users[user_idx as usize]; 43 | for follow in graph.user_follows(user) { 44 | self.timelines[*follow as usize].tweets.push(tweet_idx); 45 | } 46 | } 47 | } 48 | 49 | #[derive(Default)] 50 | pub struct TimelineFetcher { 51 | tweets: Vec, 52 | heap: BinaryHeap, 53 | } 54 | 55 | impl TimelineFetcher { 56 | #[inline] 57 | fn push_after(&mut self, link: Option, after: Timestamp) { 58 | link.filter(|l| l.ts >= after).map(|l| self.heap.push(l)); 59 | } 60 | 61 | pub fn for_user<'a>( 62 | &'a mut self, 63 | data: &Datastore, 64 | user_idx: UserIdx, 65 | max_len: usize, 66 | after: Timestamp, 67 | ) -> Timeline<'a> { 68 | self.heap.clear(); 69 | self.tweets.clear(); 70 | let user = &data.graph.users[user_idx as usize]; 71 | 72 | // seed heap 73 | for follow in data.graph.user_follows(user) { 74 | self.push_after(data.feeds[*follow as usize].fetch(), after); 75 | } 76 | 77 | // compose timeline 78 | while let Some(NextLink { ts: _, tweet_idx }) = self.heap.pop() { 79 | let chain = &data.tweets[tweet_idx as usize]; 80 | // tweets.push(Tweet::dummy(NonZeroU64::new(1).unwrap())); 81 | 
self.tweets.push(chain.tweet.clone()); 82 | if self.tweets.len() >= max_len { 83 | break; 84 | } 85 | 86 | self.push_after(chain.prev_tweet, after); 87 | } 88 | 89 | Timeline { 90 | tweets: &self.tweets[..], 91 | } 92 | } 93 | } 94 | --------------------------------------------------------------------------------