├── .DS_Store ├── .github └── workflows │ └── rust.yml ├── .gitignore ├── Cargo.toml ├── README.md ├── dataflow_nlp ├── .gitignore ├── Cargo.toml └── src │ ├── batching │ ├── mod.rs │ └── tests.rs │ ├── lib.rs │ ├── pipelines │ ├── line_loader.rs │ └── mod.rs │ ├── resources │ ├── bpe_merges.txt │ ├── bpe_vocab.json │ └── wordpiece_vocab.txt │ ├── tests.rs │ ├── tokenization │ ├── alphabet.rs │ ├── bpe.rs │ ├── mod.rs │ ├── sentence.rs │ ├── tests.rs │ ├── whitespace.rs │ └── wordpiece.rs │ └── vocab │ ├── basic.rs │ ├── bpe.rs │ ├── mod.rs │ ├── tests.rs │ └── wordpiece.rs └── src ├── .DS_Store ├── dataloader ├── dataloader.rs ├── mod.rs └── tests.rs ├── lib.rs └── pipeline ├── connectors.rs ├── loader ├── file.rs ├── keyed.rs ├── mod.rs └── vec.rs ├── mod.rs ├── node.rs ├── premade ├── batch.rs ├── map.rs ├── mapreduce.rs ├── mod.rs ├── selector.rs ├── shuffle.rs ├── sort.rs └── stateful.rs └── tests.rs /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jafioti/dataflow/85f98f5ee8736835b26d51609d722f48b4f65b88/.DS_Store -------------------------------------------------------------------------------- /.github/workflows/rust.yml: -------------------------------------------------------------------------------- 1 | name: Rust 2 | 3 | on: 4 | push: 5 | branches: [ main ] 6 | pull_request: 7 | branches: [ main ] 8 | 9 | env: 10 | CARGO_TERM_COLOR: always 11 | 12 | jobs: 13 | build: 14 | defaults: 15 | run: 16 | working-directory: ./ 17 | runs-on: ubuntu-latest 18 | 19 | steps: 20 | - uses: actions/checkout@v2 21 | - name: Build 22 | run: cargo build --verbose 23 | - name: Run tests 24 | run: cargo test --verbose 25 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | /target 2 | Cargo.lock 3 | 4 | dataflow_nlp/target 5 | dataflow_nlp/Cargo.lock -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "dataflow" 3 | version = "0.4.0" 4 | authors = ["Joe Fioti "] 5 | edition = "2021" 6 | description = "Dataflow is a data processing library, primarily for machine learning." 
7 | license = "MIT OR Apache-2.0" 8 | 9 | # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html 10 | 11 | [workspace] 12 | 13 | members = [ 14 | "dataflow_nlp", 15 | ] 16 | 17 | [dependencies] 18 | rand = "0.8" 19 | thread-control = "0.1" 20 | itertools = "0.9" 21 | 22 | #rayon = "1.7" 23 | #multiqueue = "0.3" -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Dataflow 2 | 3 | ![image](https://www.sidekickai.co/static/images/other/dag.png) 4 | 5 | [![CI Status](https://img.shields.io/github/actions/workflow/status/Sidekick-AI/dataflow/rust.yml?style=for-the-badge&logo=github-actions&logoColor=white&branch=main)](https://github.com/Sidekick-AI/dataflow/actions) 6 | [![Current Crates.io Version](https://img.shields.io/crates/v/dataflow.svg?style=for-the-badge&logo=rust)](https://crates.io/crates/dataflow) 7 | [![Documentation](https://img.shields.io/badge/docs-online-5023dd.svg?style=for-the-badge&logoColor=white&logo=data:image/svg+xml;base64,PHN2ZyByb2xlPSJpbWciIHhtbG5zPSJodHRwOi8vd3d3LnczLm9yZy8yMDAwL3N2ZyIgdmlld0JveD0iMCAwIDUxMiA1MTIiPjxwYXRoIGZpbGw9IiNmNWY1ZjUiIGQ9Ik00ODguNiAyNTAuMkwzOTIgMjE0VjEwNS41YzAtMTUtOS4zLTI4LjQtMjMuNC0zMy43bC0xMDAtMzcuNWMtOC4xLTMuMS0xNy4xLTMuMS0yNS4zIDBsLTEwMCAzNy41Yy0xNC4xIDUuMy0yMy40IDE4LjctMjMuNCAzMy43VjIxNGwtOTYuNiAzNi4yQzkuMyAyNTUuNSAwIDI2OC45IDAgMjgzLjlWMzk0YzAgMTMuNiA3LjcgMjYuMSAxOS45IDMyLjJsMTAwIDUwYzEwLjEgNS4xIDIyLjEgNS4xIDMyLjIgMGwxMDMuOS01MiAxMDMuOSA1MmMxMC4xIDUuMSAyMi4xIDUuMSAzMi4yIDBsMTAwLTUwYzEyLjItNi4xIDE5LjktMTguNiAxOS45LTMyLjJWMjgzLjljMC0xNS05LjMtMjguNC0yMy40LTMzLjd6TTM1OCAyMTQuOGwtODUgMzEuOXYtNjguMmw4NS0zN3Y3My4zek0xNTQgMTA0LjFsMTAyLTM4LjIgMTAyIDM4LjJ2LjZsLTEwMiA0MS40LTEwMi00MS40di0uNnptODQgMjkxLjFsLTg1IDQyLjV2LTc5LjFsODUtMzguOHY3NS40em0wLTExMmwtMTAyIDQxLjQtMTAyLTQxLjR2LS42bDEwMi0zOC4yIDEwMiAzOC4ydi42em0yNDAgMTEybC04NSA0Mi41di03OS4xbDg1LTM4Ljh2NzUuNHptMC0xMTJsLTEwMiA0MS40LTEwMi00MS40di0uNmwxMDItMzguMiAxMDIgMzguMnYuNnoiPjwvcGF0aD48L3N2Zz4K)](https://docs.rs/dataflow/0.1.0/dataflow/) 8 | 9 | Dataflow is a data processing library that provides composable primatives to build flexible, fast and statically typed data pipelines. The pipeline is a directed acyclic dataflow graph, which a dataloader can run on a seperate thread to feed data-hungry applications. 10 | 11 | ## Usage 12 | To build a pipeline, first start with a loader Node: 13 | ```rust 14 | use dataflow::prelude::*; 15 | 16 | fn main() { 17 | let pipeline = FileLoader::from_directory("my_data_directory"); 18 | } 19 | ``` 20 | The FileLoader loads the files from the directory in a random order. Next add a transformation to it with the `map()` function: 21 | ```rust 22 | let pipeline = FileLoader::from_directory("my_data_directory") 23 | .map(|(_, text)| format!("Hello {}", text)) // Add hello to each file 24 | ``` 25 | `map()` takes in a Node that processes a single sample at a time. If we want to do batch processing, we can use `.chain()` which takes a Node that can process a batch at a time. 26 | 27 | Important note: **All functions and closures are also Nodes!** This means that whenever we want to add a stateless transformation, we could just use a function. In this case, the closure takes in a single datapoint and outputs a single datapoint. 
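For example, a plain function with a matching signature can be dropped into `map()` just like the closure above. A minimal sketch based on that claim (the `shout` helper below is hypothetical, not part of the crate):
```rust
use dataflow::prelude::*;

// Any function from one sample type to another can act as a stateless Node.
fn shout(text: String) -> String {
    text.to_uppercase()
}

let pipeline = FileLoader::from_directory("my_data_directory")
    .map(|(_, text)| format!("Hello {}", text)) // closure as a Node
    .map(shout);                                // plain function as a Node
```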
28 | 29 | Now that we've added "Hello " to every line, let's use a tokenizer from `dataflow_nlp` in our pipeline: 30 | ```rust 31 | // Our tokenizer 32 | let tokenizer = WordpieceTokenizer::load(); 33 | 34 | // Our pipeline 35 | let pipeline = FileLoader::from_directory("my_data_directory") 36 | .map(|(_, text)| format!("Hello {}", text)) // Add hello to each file 37 | .chain(tokenizer); // Tokenize the lines 38 | 39 | ``` 40 | Great! Now our data gets efficiently tokenized in batches. Right now, we will get single tokenized sentences out of the pipeline one at a time. But what if we wanted to get batches out? Let's use a Batch node: 41 | ```rust 42 | 43 | // Our tokenizer 44 | let tokenizer = dataflow_nlp::tokenization::WordpieceTokenizer::load(); 45 | 46 | // Our pipeline 47 | let pipeline = FileLoader::from_directory("my_data_directory") 48 | .map(|(_, text)| format!("Hello {}", text)) // Add hello to each file 49 | .chain(tokenizer) // Tokenize the files 50 | .chain(Batch::new(64)); // Create batches of 64 51 | ``` 52 | That's it! We'll now get batches of 64 tokenized sentences. 53 | 54 | ### Loader Nodes 55 | As discussed before, everything in the pipeline implements the `Node` trait. Loaders like `FileLoader` and `RandomLoader` are Nodes too! So the question arises: since data originates from them, and since every Node needs an *input* and an *output*, what do they take as input? Simple: they take `Vec<()>` as input, which is what the pipeline starts with, and produce data (for example, `Vec<String>` for the `RandomLoader`) to send through the pipeline. This pattern is the same across all Nodes where data originates. 56 | 57 | ### Custom Nodes 58 | In fact, you can implement your own Nodes as well by implementing the `Node` trait! 59 | ```rust 60 | pub trait Node<Input> { 61 | type Output; 62 | 63 | /// Process a batch of data 64 | fn process(&mut self, input: Input) -> Self::Output; 65 | /// Reset signal propagates through the pipeline 66 | fn reset(&mut self) {} 67 | /// Get number of examples left 68 | fn data_remaining(&self, before: usize) -> usize { 69 | before // Defaults to same as previous remaining data 70 | } 71 | } 72 | ``` 73 | Your custom nodes can then be inserted directly into the pipeline! 74 | 75 | ### Dataloader 76 | Since we built this cool pipeline, what can we do with it? Well, for starters, we could simply call process() and feed in some data: 77 | ```rust 78 | // The loader takes in a () for each sample, so we pass in a batch as Vec<()> 79 | let output: Vec<Vec<Vec<String>>> = pipeline.process(vec![(); 128]); 80 | 81 | // Output should now contain 2 batches of 64 tokenized sentences from our files with "Hello" prepended. 82 | ``` 83 | 84 | Let's do something cooler. Let's put it in a Dataloader and use it in an ML training loop: 85 | ```rust 86 | // Make the dataloader 87 | let mut dataloader = Dataloader::new(pipeline); 88 | 89 | // Training loop 90 | for example in &mut dataloader { 91 | // Now example is a batch of tokenized sentences! 92 | // Do with them what you please... 93 | } 94 | ``` 95 | 96 | To Do: 97 | - [ ] Make dataloader use a multiqueue instead of draining all examples into a buffer on the main thread 98 | - [ ] Make auto-parallel pipeline Node using rayon 99 | - [ ] Add async ability and remote sources. 
(blocked by stable async traits) -------------------------------------------------------------------------------- /dataflow_nlp/.gitignore: -------------------------------------------------------------------------------- 1 | /target 2 | Cargo.lock 3 | -------------------------------------------------------------------------------- /dataflow_nlp/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "dataflow_nlp" 3 | version = "0.1.1" 4 | authors = ["Joe Fioti "] 5 | edition = "2021" 6 | description = "Dataflow is a data processing library, primarily for machine learning." 7 | license = "MIT OR Apache-2.0" 8 | 9 | # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html 10 | 11 | [dependencies] 12 | dataflow = "0.4" 13 | tokenizers = "0.11" 14 | serde_json = "1.0" 15 | serde = { version = "1.0", features = ["derive"] } 16 | regex = "1.5" 17 | rand = "0.8" 18 | lentrait = "0.2" 19 | -------------------------------------------------------------------------------- /dataflow_nlp/src/batching/mod.rs: -------------------------------------------------------------------------------- 1 | #[cfg(test)] 2 | mod tests; 3 | 4 | use lentrait::Len; 5 | 6 | /// Create a pad mask based on the values in the batch (batch shape: batch size, seq len) 7 | pub fn pad_mask<T: PartialEq>(batch: &[Vec<T>], pad_value: T) -> Vec<Vec<bool>> { 8 | let mut mask: Vec<Vec<bool>> = vec![Vec::with_capacity(batch[0].len()); batch.len()]; 9 | for (i, seq) in batch.iter().enumerate() { 10 | for token in seq { 11 | mask[i].push(*token == pad_value); 12 | } 13 | } 14 | mask 15 | } 16 | 17 | /// Pad all sequences to the length of the longest sequence 18 | pub fn pad_batch<T: Clone>(batch: &mut [Vec<T>], pad_value: T) { 19 | // Get longest example 20 | let mut longest = 0; 21 | for example in batch.iter() { 22 | if example.len() > longest { 23 | longest = example.len(); 24 | } 25 | } 26 | 27 | // Pad all sequences to be longest 28 | for example in batch.iter_mut() { 29 | while example.len() < longest { 30 | example.push(pad_value.clone()); 31 | } 32 | } 33 | } 34 | 35 | /// Filter elements by length in place: if an element in any list falls outside the min/max length bounds, the element at that position is removed from every list 36 | pub fn filter_by_length<T: Len>( 37 | lists: &mut [Vec<T>], 38 | min_length: Option<usize>, 39 | max_length: Option<usize>, 40 | ) { 41 | // Loop through elements in all lists 42 | for i in (0..lists[0].len()).rev() { 43 | // Loop through each list 44 | for x in 0..lists.len() { 45 | // If element length is greater than max_length or less than min_length, remove element i from every list 46 | if lists[x][i].len() > max_length.unwrap_or(usize::MAX) 47 | || lists[x][i].len() < min_length.unwrap_or(0) 48 | { 49 | for list in lists.iter_mut() { 50 | list.remove(i); 51 | } 52 | break; 53 | } 54 | } 55 | } 56 | } 57 | 58 | /// Shuffles multiple lists of the same length in the same way 59 | pub fn shuffle_lists<T: Clone>(lists: &mut [Vec<T>]) { 60 | use rand::seq::SliceRandom; 61 | use rand::thread_rng; 62 | 63 | // Zip lists 64 | let mut zipped: Vec<Vec<T>> = vec![Vec::with_capacity(lists.len()); lists[0].len()]; 65 | for list in lists.iter() { 66 | for (i, item) in list.iter().enumerate() { 67 | zipped[i].push(item.clone()); 68 | } 69 | } 70 | // Shuffle 71 | zipped.shuffle(&mut thread_rng()); 72 | // Unzip lists 73 | for (x, list) in lists.iter_mut().enumerate() { 74 | for (i, item) in list.iter_mut().enumerate() { 75 | *item = zipped[i][x].clone(); 76 | } 77 | } 78 | } 79 | 80 | /// Sort lists by length. 
Uses the lengths of the elements in the first list passed in 81 | pub fn sort_lists_by_length( 82 | lists: &mut [Vec], 83 | longest_first: Option, 84 | ) { 85 | for i in 1..lists.len() { 86 | assert!(lists[i].len() == lists[0].len()) 87 | } // Ensure all lists are the same length 88 | 89 | // Zip lists 90 | let mut zipped: Vec> = vec![Vec::with_capacity(lists.len()); lists[0].len()]; 91 | for list in lists.iter() { 92 | for (i, item) in list.iter().enumerate() { 93 | zipped[i].push(item.clone()); 94 | } 95 | } 96 | // Sort lists 97 | zipped.sort_unstable_by(|a, b| { 98 | a[0].len() 99 | .partial_cmp(&b[0].len()) 100 | .expect("NaN found in lengths!") 101 | }); 102 | // Reverse if longest first 103 | if longest_first.unwrap_or(false) { 104 | zipped.reverse() 105 | } 106 | // Unzip lists 107 | for (x, list) in lists.iter_mut().enumerate() { 108 | for (i, item) in list.iter_mut().enumerate() { 109 | *item = zipped[i][x].clone(); 110 | } 111 | } 112 | } 113 | -------------------------------------------------------------------------------- /dataflow_nlp/src/batching/tests.rs: -------------------------------------------------------------------------------- 1 | // BATCHING TESTS 2 | use crate::batching; 3 | 4 | #[test] 5 | fn pad_mask_test() { 6 | let batch = vec![vec!["d", "hello", "how"], vec!["hi", "yo", "PAD"]]; 7 | let expected_mask = vec![vec![false, false, false], vec![false, false, true]]; 8 | let pad_mask = batching::pad_mask(&batch, "PAD"); 9 | assert_eq!(expected_mask, pad_mask); 10 | } 11 | 12 | #[test] 13 | fn pad_batch_test() { 14 | let mut seqs = vec![vec![1, 2, 3, 1], vec![1, 4, 6, 2, 3, 5, 67]]; 15 | let expected_padded_batch = vec![vec![1, 2, 3, 1, 0, 0, 0], vec![1, 4, 6, 2, 3, 5, 67]]; 16 | batching::pad_batch(&mut seqs, 0); 17 | assert_eq!(seqs, expected_padded_batch); 18 | } 19 | 20 | #[test] 21 | fn filter_by_length_test() { 22 | let mut seqs = vec![ 23 | vec![vec![1, 2, 3, 1], vec![1, 4, 6, 2, 3, 5, 67], vec![1, 2, 3]], 24 | vec![vec![1, 1], vec![1, 67], vec![1, 2, 3]], 25 | ]; 26 | let expected_seqs = vec![ 27 | vec![vec![1, 2, 3, 1], vec![1, 2, 3]], 28 | vec![vec![1, 1], vec![1, 2, 3]], 29 | ]; 30 | batching::filter_by_length(&mut seqs, None, Some(6)); 31 | assert_eq!(seqs, expected_seqs); 32 | } 33 | 34 | #[test] 35 | fn shuffle_lists_test() { 36 | let mut seqs = vec![ 37 | vec![vec![1, 2, 3, 1], vec![1, 4, 6, 2, 3, 5, 67], vec![1, 2, 3]], 38 | vec![vec![1, 1], vec![1, 67], vec![1, 2, 3]], 39 | ]; 40 | let orig_seqs = seqs.clone(); 41 | for _ in 0..10 { 42 | batching::shuffle_lists(&mut seqs); 43 | if seqs != orig_seqs { 44 | break; 45 | } 46 | } 47 | assert_ne!(seqs, orig_seqs); 48 | } 49 | 50 | #[test] 51 | fn sort_lists_by_length_test() { 52 | let mut seqs = vec![ 53 | vec![ 54 | "hello".to_string(), 55 | "how are you".to_string(), 56 | "yo".to_string(), 57 | ], 58 | vec!["hey".to_string(), "wow".to_string(), "who".to_string()], 59 | ]; 60 | let sorted_seqs = vec![ 61 | vec![ 62 | "yo".to_string(), 63 | "hello".to_string(), 64 | "how are you".to_string(), 65 | ], 66 | vec!["who".to_string(), "hey".to_string(), "wow".to_string()], 67 | ]; 68 | let reverse_sorted_seqs = vec![ 69 | vec![ 70 | "how are you".to_string(), 71 | "hello".to_string(), 72 | "yo".to_string(), 73 | ], 74 | vec!["wow".to_string(), "hey".to_string(), "who".to_string()], 75 | ]; 76 | batching::sort_lists_by_length(&mut seqs, Some(false)); 77 | assert_eq!(seqs, sorted_seqs); 78 | batching::sort_lists_by_length(&mut seqs, Some(true)); 79 | assert_eq!(seqs, reverse_sorted_seqs); 80 | } 81 | 
-------------------------------------------------------------------------------- /dataflow_nlp/src/lib.rs: -------------------------------------------------------------------------------- 1 | /// Utilities for dealing with batches, such as shuffling and sorting batches 2 | pub mod batching; 3 | /// Dataflow pipeline nodes 4 | pub mod pipelines; 5 | /// All tokenization and untokenization 6 | pub mod tokenization; 7 | /// Vocab object and the functions to load different vocabularies 8 | pub mod vocab; 9 | 10 | #[cfg(test)] 11 | mod tests; 12 | -------------------------------------------------------------------------------- /dataflow_nlp/src/pipelines/line_loader.rs: -------------------------------------------------------------------------------- 1 | use std::{ 2 | fs::File, 3 | io::{BufRead, BufReader}, 4 | path::Path, 5 | }; 6 | 7 | use dataflow::pipeline::*; 8 | 9 | /// Given files, randomly load segments separated by a delimiter 10 | pub struct RandomLoader { 11 | files: Vec<String>, // The files to load from 12 | delimeter: String, // The delimiter to split examples by 13 | total_examples: usize, 14 | currently_loaded_index: usize, // The last example we loaded as an index of the load_order vector (starts at 0) 15 | max_index: usize, // The max index to load 16 | min_index: usize, // The min index to load 17 | } 18 | 19 | impl RandomLoader { 20 | pub fn new<T: ToString>(files: &[T]) -> Self { 21 | RandomLoader { 22 | files: files.iter().map(|s| s.to_string()).collect(), 23 | delimeter: "\n".to_string(), 24 | total_examples: 0, 25 | currently_loaded_index: 0, 26 | min_index: 0, 27 | max_index: usize::MAX, 28 | } 29 | } 30 | 31 | /// Create a new RandomLoader with all files in a directory 32 | pub fn from_directory<T: AsRef<Path>>(path: T) -> Self { 33 | let files = std::fs::read_dir(path) 34 | .unwrap() 35 | .map(|r| r.unwrap().path().to_str().unwrap().to_string()) 36 | .collect(); 37 | RandomLoader { 38 | files, 39 | delimeter: "\n".to_string(), 40 | total_examples: 0, 41 | currently_loaded_index: 0, 42 | min_index: 0, 43 | max_index: usize::MAX, 44 | } 45 | } 46 | 47 | pub fn with_delimeter(self, delimeter: String) -> Self { 48 | RandomLoader { delimeter, ..self } 49 | } 50 | 51 | pub fn max_index(self, max_index: usize) -> Self { 52 | RandomLoader { max_index, ..self } 53 | } 54 | 55 | pub fn min_index(self, min_index: usize) -> Self { 56 | RandomLoader { min_index, ..self } 57 | } 58 | } 59 | 60 | /// Load segments of text separated by a delimiter, from start to end segments 61 | fn load_text_segments( 62 | path: &str, 63 | indexes: &[usize], 64 | current_segment_index: &mut usize, 65 | delimiter: &str, 66 | ) -> Result<Vec<String>, std::io::Error> { 67 | let file = File::open(path)?; 68 | let reader = BufReader::new(file); 69 | let mut segments = Vec::new(); 70 | let mut current_segment = String::new(); 71 | 72 | for line in reader.lines().flatten() { 73 | if line.contains(delimiter) { 74 | let mut line_segments = line.split(delimiter); 75 | 76 | // Handle beginning segment 77 | if let Some(segment) = line_segments.next() { 78 | if *current_segment_index == indexes[segments.len()] { 79 | segments.push(format!("{current_segment}{segment}")); 80 | current_segment.clear(); 81 | } 82 | *current_segment_index += 1; 83 | } 84 | // Handle middle segments 85 | for segment in line_segments { 86 | let Some(&ind) = indexes.get(segments.len()) else { 87 | return Ok(segments); 88 | }; 89 | if *current_segment_index == ind { 90 | segments.push(format!("{current_segment}{segment}")); 91 | } 92 | *current_segment_index += 1; 93 | } 94 | // We 
aren't supposed to finalize the last segment 95 | if let Some(last) = segments.pop() { 96 | current_segment = last; 97 | current_segment.push('\n'); 98 | *current_segment_index -= 1; 99 | } 100 | } else if *current_segment_index == indexes[segments.len()] { 101 | current_segment.push_str(&line); 102 | current_segment.push('\n'); 103 | } 104 | 105 | if segments.len() >= indexes.len() { 106 | break; 107 | } 108 | } 109 | 110 | Ok(segments) 111 | } 112 | 113 | impl Node> for RandomLoader { 114 | type Output = Vec; 115 | 116 | fn process(&mut self, input: Vec<()>) -> Self::Output { 117 | // Run through each example in each file 118 | let mut current_index = 0; 119 | let mut loaded = vec![]; 120 | for file in &self.files { 121 | loaded.append( 122 | &mut load_text_segments( 123 | file, 124 | &(self.currently_loaded_index 125 | ..(self.currently_loaded_index + input.len()).min(self.max_index)) 126 | .collect::>(), 127 | &mut current_index, 128 | &self.delimeter, 129 | ) 130 | .unwrap(), 131 | ); 132 | if loaded.len() >= input.len() { 133 | break; 134 | } 135 | } 136 | 137 | loaded.truncate(input.len()); 138 | self.currently_loaded_index += loaded.len(); 139 | loaded 140 | } 141 | 142 | fn reset(&mut self) { 143 | // Count the total number of examples 144 | self.total_examples = 0; 145 | for file in &self.files { 146 | let reader = BufReader::new(File::open(file).unwrap()); 147 | let mut delimeter_count = 0; 148 | if self.delimeter == "\n" { 149 | delimeter_count = reader.lines().count(); 150 | } else { 151 | delimeter_count += reader 152 | .lines() 153 | .flatten() 154 | .map(|line| line.matches(&self.delimeter).count()) 155 | .sum::(); 156 | delimeter_count += 1; // Since delimeters divide the examples, there should be 1 more example than delimeter 157 | } 158 | self.total_examples += delimeter_count; 159 | if self.total_examples >= self.max_index { 160 | break; 161 | } 162 | } 163 | self.total_examples = self.total_examples.min(self.max_index - self.min_index); 164 | self.currently_loaded_index = self.min_index; 165 | } 166 | 167 | fn data_remaining(&self, _before: usize) -> usize { 168 | self.total_examples - (self.currently_loaded_index - self.min_index) 169 | } 170 | } 171 | -------------------------------------------------------------------------------- /dataflow_nlp/src/pipelines/mod.rs: -------------------------------------------------------------------------------- 1 | mod line_loader; 2 | pub use line_loader::*; 3 | -------------------------------------------------------------------------------- /dataflow_nlp/src/tests.rs: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jafioti/dataflow/85f98f5ee8736835b26d51609d722f48b4f65b88/dataflow_nlp/src/tests.rs -------------------------------------------------------------------------------- /dataflow_nlp/src/tokenization/alphabet.rs: -------------------------------------------------------------------------------- 1 | use super::Tokenizer; 2 | 3 | use dataflow::prelude::Node; 4 | use serde::{Deserialize, Serialize}; 5 | 6 | #[derive(Serialize, Deserialize, Debug, Clone)] 7 | pub struct AlphabetTokenizer {} 8 | 9 | impl Tokenizer for AlphabetTokenizer { 10 | fn load() -> Self { 11 | AlphabetTokenizer {} 12 | } 13 | 14 | fn tokenize(&self, string: &str) -> Vec { 15 | let tokens: Vec = string.split("").map(|f| f.to_string()).collect(); 16 | tokens[1..tokens.len() - 1].to_vec() // For some reason, the split adds empty strings to each end 17 | } 18 | 19 | fn batch_tokenize(&self, 
strings: Vec) -> Vec> { 20 | strings 21 | .iter() 22 | .map(|string| string.split("").map(|f| f.to_string()).collect()) 23 | .collect() 24 | } 25 | 26 | fn untokenize(&self, tokens: Vec) -> String { 27 | tokens.join("") 28 | } 29 | 30 | fn batch_untokenize(&self, tokens: Vec>) -> Vec { 31 | tokens.iter().map(|tokens| tokens.join("")).collect() 32 | } 33 | } 34 | 35 | impl Node for AlphabetTokenizer { 36 | type Output = Vec; 37 | fn process(&mut self, input: String) -> Self::Output { 38 | self.tokenize(&input) 39 | } 40 | } 41 | 42 | impl Node<&str> for AlphabetTokenizer { 43 | type Output = Vec; 44 | fn process(&mut self, input: &str) -> Self::Output { 45 | self.tokenize(input) 46 | } 47 | } 48 | 49 | impl Node> for AlphabetTokenizer { 50 | type Output = Vec>; 51 | 52 | fn process(&mut self, input: Vec) -> Self::Output { 53 | self.batch_tokenize(input) 54 | } 55 | } 56 | -------------------------------------------------------------------------------- /dataflow_nlp/src/tokenization/bpe.rs: -------------------------------------------------------------------------------- 1 | use std::collections::HashMap; 2 | 3 | use super::Tokenizer; 4 | 5 | use dataflow::prelude::Node; 6 | use serde::{Deserialize, Serialize}; 7 | use tokenizers::Tokenizer as HFTokenizer; 8 | 9 | #[derive(Serialize, Deserialize, Debug)] 10 | pub struct BPETokenizer { 11 | hf_tokenizer: HFTokenizer, 12 | } 13 | 14 | impl Tokenizer for BPETokenizer { 15 | fn load() -> Self { 16 | use serde_json::Value; 17 | use tokenizers::models::bpe::BPE; 18 | // Create token2index map 19 | // Open vocab file 20 | let json: HashMap = serde_json::from_str( 21 | &include_str!("../resources/bpe_vocab.json") 22 | .replace('/', "") 23 | .replace('Ġ', ""), 24 | ) 25 | .expect("Error parsing BPE vocab file!"); 26 | // Build sorted vector of tokens from hashmap 27 | let mut token_vec: Vec = vec![String::from(""); 50265]; // Happen to know the largest index in the json is 50264, this is a bad system 28 | for token in json.keys() { 29 | token_vec[json[token].as_u64().unwrap() as usize] = token.clone(); 30 | } 31 | // Build vocab 32 | let mut token2index = HashMap::with_capacity(token_vec.len()); 33 | for token in token_vec { 34 | if !token.is_empty() { 35 | token2index.insert(token.to_string(), token2index.len() as u32); 36 | } 37 | } 38 | // Create tokenizer 39 | let bpe_builder = BPE::builder(); 40 | let mut merges: Vec<(String, String)> = Vec::new(); 41 | let lines: Vec<&str> = include_str!("../resources/bpe_merges.txt") 42 | .split('\n') 43 | .collect(); 44 | for line in lines { 45 | let line = String::from(line) 46 | .replace(['Ġ', '\n'], "") 47 | .replace("##", ""); 48 | // Filter out junk 49 | if line.contains(' ') && !line.contains('#') { 50 | let line: Vec<&str> = line.split(' ').collect(); 51 | // Make sure vocab contains both tokens and combined token 52 | if token2index.contains_key(&line[0].to_string()) 53 | && token2index.contains_key(&line[1].to_string()) 54 | && token2index.contains_key(&format!("{}{}", line[0], line[1])) 55 | { 56 | merges.push((line[0].to_string(), line[1].to_string())); 57 | } 58 | } 59 | } 60 | 61 | let bpe_builder = bpe_builder.vocab_and_merges(token2index, merges); 62 | let bpe = bpe_builder 63 | .unk_token("[UNK]".into()) 64 | .build() 65 | .expect("BPE Tokenizer failed to build!"); 66 | 67 | BPETokenizer { 68 | hf_tokenizer: HFTokenizer::new(bpe), 69 | } 70 | } 71 | 72 | fn tokenize(&self, string: &str) -> Vec { 73 | tokenizers::utils::parallelism::set_parallelism(true); 74 | // Create tokenizer and tokenize 75 
| let encoding = self 76 | .hf_tokenizer 77 | .encode(string, false) 78 | .expect("BPE tokenization failed!"); 79 | // Convert back to string 80 | encoding.get_tokens().to_vec() 81 | } 82 | 83 | fn batch_tokenize(&self, strings: Vec) -> Vec> { 84 | tokenizers::utils::parallelism::set_parallelism(true); 85 | // Create tokenizer and tokenize 86 | let encodings = self 87 | .hf_tokenizer 88 | .encode_batch(strings, false) 89 | .expect("BPE tokenization failed!"); 90 | // Convert back to strings 91 | let mut tokens: Vec> = Vec::with_capacity(encodings.len()); 92 | for encoding in encodings { 93 | tokens.push(encoding.get_tokens().to_vec()); 94 | } 95 | tokens 96 | } 97 | 98 | fn untokenize(&self, tokens: Vec) -> String { 99 | tokens.join("") 100 | } 101 | 102 | fn batch_untokenize(&self, tokens: Vec>) -> Vec { 103 | tokens.iter().map(|tokens| tokens.join("")).collect() 104 | } 105 | } 106 | 107 | impl Node for BPETokenizer { 108 | type Output = Vec; 109 | fn process(&mut self, input: String) -> Self::Output { 110 | self.tokenize(&input) 111 | } 112 | } 113 | 114 | impl Node<&str> for BPETokenizer { 115 | type Output = Vec; 116 | fn process(&mut self, input: &str) -> Self::Output { 117 | self.tokenize(input) 118 | } 119 | } 120 | 121 | impl Node> for BPETokenizer { 122 | type Output = Vec>; 123 | fn process(&mut self, input: Vec) -> Self::Output { 124 | self.batch_tokenize(input) 125 | } 126 | } 127 | -------------------------------------------------------------------------------- /dataflow_nlp/src/tokenization/mod.rs: -------------------------------------------------------------------------------- 1 | #[cfg(test)] 2 | mod tests; 3 | use std::fmt::Debug; 4 | 5 | // Tokenizers 6 | mod wordpiece; 7 | pub use wordpiece::WordpieceTokenizer; 8 | mod bpe; 9 | pub use bpe::BPETokenizer; 10 | mod whitespace; 11 | pub use whitespace::WhitespaceTokenizer; 12 | mod alphabet; 13 | pub use alphabet::AlphabetTokenizer; 14 | mod sentence; 15 | pub use sentence::SentenceTokenizer; 16 | 17 | /// A trait to implement for all tokenizers, contains basic tokenizing and untokenizing functions 18 | pub trait Tokenizer: Debug + Send + Sync { 19 | /// Load the tokenizer 20 | fn load() -> Self; 21 | /// Tokenize a single string 22 | fn tokenize(&self, string: &str) -> Vec; 23 | /// Tokenize a batch of strings 24 | fn batch_tokenize(&self, strings: Vec) -> Vec>; 25 | /// Untokenize a single string 26 | fn untokenize(&self, tokens: Vec) -> String; 27 | /// Untokenize a batch of strings 28 | fn batch_untokenize(&self, tokens: Vec>) -> Vec; 29 | } 30 | -------------------------------------------------------------------------------- /dataflow_nlp/src/tokenization/sentence.rs: -------------------------------------------------------------------------------- 1 | use super::Tokenizer; 2 | 3 | use dataflow::prelude::Node; 4 | use regex::Regex; 5 | use serde::{Deserialize, Serialize}; 6 | 7 | #[derive(Serialize, Deserialize, Debug, Clone)] 8 | pub struct SentenceTokenizer { 9 | keep_punctuation: bool, 10 | } 11 | 12 | impl SentenceTokenizer { 13 | pub fn new(keep_punctuation: bool) -> Self { 14 | SentenceTokenizer { keep_punctuation } 15 | } 16 | } 17 | 18 | impl Tokenizer for SentenceTokenizer { 19 | fn load() -> Self { 20 | SentenceTokenizer { 21 | keep_punctuation: true, 22 | } 23 | } 24 | 25 | fn tokenize(&self, string: &str) -> Vec { 26 | if self.keep_punctuation { 27 | split_keep(string) 28 | .into_iter() 29 | .map(|i| i.trim().to_string()) 30 | .collect() 31 | } else { 32 | let regex = Regex::new(r"!|.|\?|;").unwrap(); 33 | 
regex.split(string).map(|i| i.trim().to_string()).collect() 34 | } 35 | } 36 | 37 | fn batch_tokenize(&self, strings: Vec) -> Vec> { 38 | let regex = Regex::new(r"!|.|\?|;").unwrap(); 39 | if self.keep_punctuation { 40 | strings 41 | .iter() 42 | .map(|string| { 43 | split_keep(string) 44 | .into_iter() 45 | .map(|i| i.trim().to_string()) 46 | .collect() 47 | }) 48 | .collect() 49 | } else { 50 | strings 51 | .iter() 52 | .map(|string| regex.split(string).map(|i| i.trim().to_string()).collect()) 53 | .collect() 54 | } 55 | } 56 | 57 | fn untokenize(&self, tokens: Vec) -> String { 58 | tokens.join(" ") 59 | } 60 | 61 | fn batch_untokenize(&self, tokens: Vec>) -> Vec { 62 | tokens.iter().map(|tokens| tokens.join(" ")).collect() 63 | } 64 | } 65 | 66 | fn split_keep(text: &str) -> Vec<&str> { 67 | let mut result = Vec::new(); 68 | let chars = text.chars(); 69 | let mut last_match = 0; 70 | for (i, char) in chars.enumerate() { 71 | if char == '!' || char == '.' || char == '?' || char == ';' { 72 | // If we have more than one letter that came before, add it to the results 73 | if i - last_match > 0 { 74 | result.push(&text[last_match..i + 1]); 75 | last_match = i + 1; 76 | } 77 | } 78 | } 79 | if last_match < text.len() - 1 { 80 | result.push(&text[last_match..]); 81 | } 82 | result 83 | } 84 | 85 | impl Node for SentenceTokenizer { 86 | type Output = Vec; 87 | 88 | fn process(&mut self, input: String) -> Self::Output { 89 | self.tokenize(&input) 90 | } 91 | } 92 | 93 | impl Node<&str> for SentenceTokenizer { 94 | type Output = Vec; 95 | 96 | fn process(&mut self, input: &str) -> Self::Output { 97 | self.tokenize(input) 98 | } 99 | } 100 | 101 | impl Node> for SentenceTokenizer { 102 | type Output = Vec>; 103 | 104 | fn process(&mut self, input: Vec) -> Self::Output { 105 | self.batch_tokenize(input) 106 | } 107 | } 108 | -------------------------------------------------------------------------------- /dataflow_nlp/src/tokenization/tests.rs: -------------------------------------------------------------------------------- 1 | // TOKENIZATION TESTS 2 | use super::*; 3 | 4 | #[test] 5 | fn tokenize_alphabet() { 6 | let letters: Vec = vec!["h", "e", "l", "l", "o"] 7 | .iter() 8 | .map(|t| (*t).to_string()) 9 | .collect(); 10 | let tokenizer = AlphabetTokenizer::load(); 11 | assert_eq!(tokenizer.tokenize("hello"), letters); 12 | } 13 | 14 | #[test] 15 | fn tokenize_spaces() { 16 | let tokens: Vec = vec!["hello", "how", "are", "you"] 17 | .iter() 18 | .map(|t| (*t).to_string()) 19 | .collect(); 20 | let tokenizer = WhitespaceTokenizer::load(); 21 | assert_eq!(tokenizer.tokenize("hello how are you"), tokens); 22 | } 23 | 24 | #[test] 25 | fn tokenize_sentences() { 26 | let tokens: Vec = vec!["hello, how are you?", "good, how are you?"] 27 | .iter() 28 | .map(|t| (*t).to_string()) 29 | .collect(); 30 | let tokenizer = SentenceTokenizer::load(); 31 | assert_eq!( 32 | tokenizer.tokenize("hello, how are you? 
good, how are you?"), 33 | tokens 34 | ); 35 | } 36 | 37 | #[test] 38 | fn tokenize_bpe() { 39 | let tokens: Vec = vec!["hello", ",", " ", "how", " ", "are", " ", "you"] 40 | .iter() 41 | .map(|str| str.to_string()) 42 | .collect(); 43 | let tokenizer = BPETokenizer::load(); 44 | assert_eq!( 45 | tokenizer.batch_tokenize(vec!["hello, how are you".to_string()]), 46 | vec![tokens.clone()] 47 | ); 48 | assert_eq!(tokenizer.tokenize("hello, how are you"), tokens); 49 | } 50 | 51 | #[test] 52 | fn tokenize_wordpiece() { 53 | let tokens: Vec = vec!["hello", ",", "how", "are", "you"] 54 | .iter() 55 | .map(|str| str.to_string()) 56 | .collect(); 57 | let tokenizer = WordpieceTokenizer::load(); 58 | assert_eq!( 59 | tokenizer.batch_tokenize(vec!["hello, how are you".to_string()]), 60 | vec![tokens.clone()] 61 | ); 62 | assert_eq!(tokenizer.tokenize("hello, how are you"), tokens); 63 | } 64 | 65 | #[test] 66 | fn untokenize_bpe() { 67 | let sentence = "hello, how are you?"; 68 | let tokenizer = BPETokenizer::load(); 69 | let tokens = tokenizer.tokenize(sentence); 70 | assert_eq!(tokenizer.untokenize(tokens), sentence.to_string()); 71 | } 72 | 73 | #[test] 74 | fn untokenize_wordpiece() { 75 | let sentence = "hello, how are you?"; 76 | let tokenizer = WordpieceTokenizer::load(); 77 | let tokens = tokenizer.tokenize(sentence); 78 | assert_eq!(tokenizer.untokenize(tokens), sentence.to_string()); 79 | } 80 | 81 | #[test] 82 | fn untokenize_alphabet() { 83 | let sentence = "hello, how are you?"; 84 | let tokenizer = AlphabetTokenizer::load(); 85 | let tokens = tokenizer.tokenize(sentence); 86 | assert_eq!(tokenizer.untokenize(tokens), sentence.to_string()); 87 | } 88 | 89 | #[test] 90 | fn untokenize_spaces() { 91 | let sentence = "hello, how are you?"; 92 | let tokenizer = WhitespaceTokenizer::load(); 93 | let tokens = tokenizer.tokenize(sentence); 94 | assert_eq!(tokenizer.untokenize(tokens), sentence.to_string()); 95 | } 96 | 97 | #[test] 98 | fn untokenize_sentences() { 99 | let sentence = "hello how are you? 
good, how are you?"; 100 | let tokenizer = WhitespaceTokenizer::load(); 101 | let tokens = tokenizer.tokenize(sentence); 102 | assert_eq!(tokenizer.untokenize(tokens), sentence.to_string()); 103 | } 104 | -------------------------------------------------------------------------------- /dataflow_nlp/src/tokenization/whitespace.rs: -------------------------------------------------------------------------------- 1 | use super::Tokenizer; 2 | 3 | use dataflow::prelude::Node; 4 | use serde::{Deserialize, Serialize}; 5 | 6 | #[derive(Serialize, Deserialize, Debug, Clone)] 7 | pub struct WhitespaceTokenizer {} 8 | 9 | impl Tokenizer for WhitespaceTokenizer { 10 | fn load() -> Self { 11 | WhitespaceTokenizer {} 12 | } 13 | 14 | fn tokenize(&self, string: &str) -> Vec { 15 | string.split(' ').map(|f| f.to_string()).collect() 16 | } 17 | 18 | fn batch_tokenize(&self, strings: Vec) -> Vec> { 19 | strings 20 | .iter() 21 | .map(|string| { 22 | let tokens: Vec = string.split("").map(|f| f.to_string()).collect(); 23 | tokens[1..tokens.len() - 1].to_vec() // For some reason, the split adds empty strings to each end 24 | }) 25 | .collect() 26 | } 27 | 28 | fn untokenize(&self, tokens: Vec) -> String { 29 | tokens.join(" ") 30 | } 31 | 32 | fn batch_untokenize(&self, tokens: Vec>) -> Vec { 33 | tokens.iter().map(|tokens| tokens.join(" ")).collect() 34 | } 35 | } 36 | 37 | impl Node for WhitespaceTokenizer { 38 | type Output = Vec; 39 | 40 | fn process(&mut self, input: String) -> Self::Output { 41 | self.tokenize(&input) 42 | } 43 | } 44 | 45 | impl Node<&str> for WhitespaceTokenizer { 46 | type Output = Vec; 47 | fn process(&mut self, input: &str) -> Self::Output { 48 | self.tokenize(input) 49 | } 50 | } 51 | 52 | impl Node> for WhitespaceTokenizer { 53 | type Output = Vec>; 54 | fn process(&mut self, input: Vec) -> Self::Output { 55 | self.batch_tokenize(input) 56 | } 57 | } 58 | -------------------------------------------------------------------------------- /dataflow_nlp/src/tokenization/wordpiece.rs: -------------------------------------------------------------------------------- 1 | use super::Tokenizer; 2 | 3 | use dataflow::prelude::Node; 4 | use serde::{Deserialize, Serialize}; 5 | use tokenizers::Tokenizer as HFTokenizer; 6 | 7 | #[derive(Serialize, Deserialize, Debug)] 8 | pub struct WordpieceTokenizer { 9 | hf_tokenizer: HFTokenizer, 10 | } 11 | 12 | impl Tokenizer for WordpieceTokenizer { 13 | fn load() -> Self { 14 | tokenizers::utils::parallelism::set_parallelism(true); 15 | use std::collections::HashMap; 16 | use tokenizers::models::wordpiece::WordPiece; 17 | use tokenizers::pre_tokenizers::whitespace::Whitespace; 18 | // Build tokenizer 19 | let wordpiece_builder = WordPiece::builder(); 20 | let lines: Vec<&str> = include_str!("../resources/wordpiece_vocab.txt") 21 | .split('\n') 22 | .collect(); 23 | let mut hashmap: HashMap = HashMap::new(); 24 | for (i, line) in lines.iter().enumerate() { 25 | hashmap.insert(line.to_string(), i as u32); 26 | } 27 | let wordpiece_builder = wordpiece_builder.vocab(hashmap); 28 | let wordpiece = wordpiece_builder 29 | .build() 30 | .expect("WordPiece Tokenizer failed to build!"); 31 | 32 | let mut tokenizer = HFTokenizer::new(wordpiece); 33 | tokenizer.with_pre_tokenizer(Whitespace::default()); 34 | WordpieceTokenizer { 35 | hf_tokenizer: tokenizer, 36 | } 37 | } 38 | 39 | fn tokenize(&self, string: &str) -> Vec { 40 | tokenizers::utils::parallelism::set_parallelism(true); 41 | // Create tokenizer and tokenize 42 | let encoding = self 43 | .hf_tokenizer 44 
| .encode(string, false) 45 | .expect("BPE tokenization failed!"); 46 | // Convert back to string 47 | encoding.get_tokens().to_vec() 48 | } 49 | 50 | fn batch_tokenize(&self, strings: Vec) -> Vec> { 51 | tokenizers::utils::parallelism::set_parallelism(true); 52 | // Create tokenizer and tokenize 53 | let encodings = self 54 | .hf_tokenizer 55 | .encode_batch(strings, false) 56 | .expect("WordPiece tokenization failed!"); 57 | // Convert back to strings 58 | let mut tokens: Vec> = Vec::with_capacity(encodings.len()); 59 | for encoding in encodings.iter() { 60 | tokens.push(encoding.get_tokens().to_vec()); 61 | } 62 | tokens 63 | } 64 | 65 | fn untokenize(&self, tokens: Vec) -> String { 66 | let mut untokenized_string = String::new(); 67 | for (i, token) in tokens.iter().enumerate() { 68 | if *token != "[PAD]" && *token != "[EOS]" { 69 | if token.contains("##") 70 | || [".", "?", "!", ",", "'", r#"""#] 71 | .iter() 72 | .any(|x| *x == token) 73 | || i == 0 74 | { 75 | untokenized_string = 76 | format!("{}{}", untokenized_string, token.replace("##", "")) 77 | } else { 78 | untokenized_string = format!("{} {}", untokenized_string, token) 79 | } 80 | } 81 | } 82 | untokenized_string 83 | } 84 | 85 | fn batch_untokenize(&self, tokens: Vec>) -> Vec { 86 | let mut untokenized_strings = vec![String::new(); tokens.len()]; 87 | for i in 0..tokens.len() { 88 | for x in 0..tokens[i].len() { 89 | if *tokens[i][x] != *"[PAD]" && *tokens[i][x] != *"[EOS]" { 90 | if tokens[i][x].contains("##") 91 | || [".", "?", "!", ",", "'", r#"""#] 92 | .iter() 93 | .any(|t| **t == tokens[i][x]) 94 | || x == 0 95 | { 96 | untokenized_strings[i] = format!( 97 | "{}{}", 98 | untokenized_strings[i], 99 | tokens[i][x].replace("##", "") 100 | ) 101 | } else { 102 | untokenized_strings[i] = 103 | format!("{} {}", untokenized_strings[i], tokens[i][x]) 104 | } 105 | } 106 | } 107 | } 108 | untokenized_strings 109 | } 110 | } 111 | 112 | impl Node for WordpieceTokenizer { 113 | type Output = Vec; 114 | 115 | fn process(&mut self, input: String) -> Self::Output { 116 | self.tokenize(&input) 117 | } 118 | } 119 | 120 | impl Node<&str> for WordpieceTokenizer { 121 | type Output = Vec; 122 | 123 | fn process(&mut self, input: &str) -> Self::Output { 124 | self.tokenize(input) 125 | } 126 | } 127 | 128 | impl Node> for WordpieceTokenizer { 129 | type Output = Vec>; 130 | 131 | fn process(&mut self, input: Vec) -> Self::Output { 132 | self.batch_tokenize(input) 133 | } 134 | } 135 | -------------------------------------------------------------------------------- /dataflow_nlp/src/vocab/basic.rs: -------------------------------------------------------------------------------- 1 | use super::TokenNotFoundError; 2 | use dataflow::prelude::Node; 3 | use serde::{Deserialize, Serialize}; 4 | use std::collections::HashMap; 5 | 6 | /// The basic vocab type used internally in WordpieceVocab and BPEVocab 7 | #[derive(Serialize, Deserialize, Default, Clone)] 8 | pub(crate) struct BasicVocab { 9 | pub num_tokens: usize, 10 | pub token2index: HashMap, 11 | pub index2token: Vec, 12 | pub pad_token: usize, 13 | pub sos_token: usize, 14 | pub eos_token: usize, 15 | pub sep_token: usize, 16 | } 17 | 18 | impl BasicVocab { 19 | /// Make a new vocab 20 | pub fn new() -> Self { 21 | let mut voc = BasicVocab { 22 | num_tokens: 0, 23 | token2index: HashMap::new(), 24 | index2token: Vec::new(), 25 | pad_token: 0, 26 | sos_token: 1, 27 | eos_token: 2, 28 | sep_token: 3, 29 | }; 30 | voc.add_tokens(vec![ 31 | "[PAD]".to_string(), 32 | "[SOS]".to_string(), 33 
| "[EOS]".to_string(), 34 | "[SEP]".to_string(), 35 | ]); 36 | voc 37 | } 38 | 39 | /// Returns num_tokens 40 | pub fn len(&self) -> usize { 41 | self.num_tokens 42 | } 43 | 44 | /// Add token to vocab 45 | pub fn add_token(&mut self, token: String) { 46 | self.token2index.insert(token.clone(), self.num_tokens); 47 | self.index2token.push(token); 48 | self.num_tokens += 1; 49 | } 50 | 51 | /// Add a vec of tokens to vocab 52 | pub fn add_tokens(&mut self, tokens: Vec) { 53 | self.index2token.extend(tokens.clone()); 54 | for (i, token) in tokens.iter().enumerate() { 55 | // Probably a more efficient way to do this and avoid the loop 56 | self.token2index.insert(token.clone(), self.num_tokens + i); 57 | } 58 | self.num_tokens += tokens.len(); 59 | } 60 | 61 | /// Remove a vec of tokens from vocab (NOT SURE IF THIS SHOULD BE KEPT) 62 | pub fn _remove_tokens(&mut self, tokens: Vec<&String>) { 63 | for token in tokens { 64 | if self.token2index.contains_key(token) { 65 | self._remove_token(token); 66 | } 67 | } 68 | } 69 | 70 | /// Remove token from vocab 71 | pub fn _remove_token(&mut self, token: &str) { 72 | // Loop through all higher token2index mappings and decrement (must be a more efficient way to do this) 73 | for i in (self.token2index[token]) + 1..self.index2token.len() { 74 | *self.token2index.get_mut(&self.index2token[i]).unwrap() -= 1; 75 | } 76 | self.index2token.remove(self.token2index[token]); 77 | self.token2index.remove(token); 78 | self.num_tokens -= 1; 79 | } 80 | 81 | /// Get vec of tokens from vec of indexes 82 | pub fn tokens_from_indexes( 83 | &self, 84 | indexes: &[usize], 85 | ) -> Result, TokenNotFoundError> { 86 | if *indexes.iter().max().unwrap() >= self.num_tokens { 87 | return Err(TokenNotFoundError); 88 | } // Make sure we aren't trying to get an index too big 89 | 90 | let mut tokens: Vec = Vec::with_capacity(indexes.len()); 91 | for index in indexes { 92 | tokens.push(self.index2token[*index].to_string()); 93 | } 94 | Ok(tokens) 95 | } 96 | 97 | /// Batched version of tokens_from_indexes 98 | pub fn batch_tokens_from_indexes( 99 | &self, 100 | indexes: &[Vec], 101 | ) -> Result>, TokenNotFoundError> { 102 | let mut tokens: Vec> = Vec::with_capacity(indexes.len()); 103 | for sent in indexes { 104 | tokens.push(self.tokens_from_indexes(sent)?); 105 | } 106 | Ok(tokens) 107 | } 108 | 109 | /// Get vec of indexes from vec of tokens 110 | pub fn indexes_from_tokens(&self, tokens: &[String]) -> Result, TokenNotFoundError> { 111 | let mut indexes: Vec = Vec::with_capacity(tokens.len()); 112 | for token in tokens { 113 | indexes.push(match self.token2index.get(token) { 114 | Some(index) => *index, 115 | None => { 116 | return Err(TokenNotFoundError); 117 | } 118 | }); 119 | } 120 | Ok(indexes) 121 | } 122 | 123 | /// Batched version of indexes_from_tokens 124 | pub fn batch_indexes_from_tokens( 125 | &self, 126 | tokens: &[Vec], 127 | ) -> Result>, TokenNotFoundError> { 128 | let mut indexes: Vec> = Vec::with_capacity(tokens.len()); 129 | for sent in tokens { 130 | indexes.push(self.indexes_from_tokens(sent)?); 131 | } 132 | Ok(indexes) 133 | } 134 | } 135 | 136 | impl Node>> for BasicVocab { 137 | type Output = Vec>; 138 | 139 | fn process(&mut self, input: Vec>) -> Self::Output { 140 | self.batch_indexes_from_tokens(&input).unwrap() 141 | } 142 | } 143 | 144 | impl Node> for BasicVocab { 145 | type Output = Vec; 146 | 147 | fn process(&mut self, input: Vec) -> Self::Output { 148 | self.indexes_from_tokens(&input).unwrap() 149 | } 150 | } 151 | 152 | impl Node>> for 
BasicVocab { 153 | type Output = Vec>; 154 | 155 | fn process(&mut self, input: Vec>) -> Self::Output { 156 | self.batch_tokens_from_indexes(&input).unwrap() 157 | } 158 | } 159 | 160 | impl Node> for BasicVocab { 161 | type Output = Vec; 162 | 163 | fn process(&mut self, input: Vec) -> Self::Output { 164 | self.tokens_from_indexes(&input).unwrap() 165 | } 166 | } 167 | -------------------------------------------------------------------------------- /dataflow_nlp/src/vocab/bpe.rs: -------------------------------------------------------------------------------- 1 | use std::collections::HashMap; 2 | 3 | use super::{BasicVocab, TokenNotFoundError, Vocab}; 4 | use dataflow::prelude::Node; 5 | use serde::{Deserialize, Serialize}; 6 | 7 | #[derive(Serialize, Deserialize, Default, Clone)] 8 | pub struct BPEVocab { 9 | vocab: BasicVocab, 10 | } 11 | 12 | impl Vocab for BPEVocab { 13 | fn load() -> Self { 14 | use serde_json::Value; 15 | 16 | // Open vocab file 17 | let json: HashMap = serde_json::from_str( 18 | &include_str!("../resources/bpe_vocab.json") 19 | .replace('/', "") 20 | .replace('Ġ', ""), 21 | ) 22 | .expect("Error parsing BPE vocab file!"); 23 | // Build sorted vector of tokens from hashmap 24 | let mut token_vec: Vec = vec![String::from(""); 50265]; // Happen to know the largest index in the json is 50264, this is a bad system 25 | for token in json.keys() { 26 | token_vec[json[token].as_u64().unwrap() as usize] = token.clone(); 27 | } 28 | // Build vocab 29 | let mut vocab = BasicVocab::new(); 30 | let mut temp_vec: Vec = Vec::new(); 31 | for token in token_vec { 32 | if !token.is_empty() { 33 | vocab.add_token(token.clone()); 34 | temp_vec.push(token); 35 | } 36 | } 37 | BPEVocab { vocab } 38 | } 39 | 40 | fn len(&self) -> usize { 41 | self.vocab.len() 42 | } 43 | 44 | fn tokens_from_indexes(&self, indexes: &[usize]) -> Result, TokenNotFoundError> { 45 | self.vocab.tokens_from_indexes(indexes) 46 | } 47 | 48 | fn batch_tokens_from_indexes( 49 | &self, 50 | indexes: &[Vec], 51 | ) -> Result>, TokenNotFoundError> { 52 | self.vocab.batch_tokens_from_indexes(indexes) 53 | } 54 | 55 | fn indexes_from_tokens(&self, tokens: &[String]) -> Result, TokenNotFoundError> { 56 | self.vocab.indexes_from_tokens(tokens) 57 | } 58 | 59 | fn batch_indexes_from_tokens( 60 | &self, 61 | tokens: &[Vec], 62 | ) -> Result>, TokenNotFoundError> { 63 | self.vocab.batch_indexes_from_tokens(tokens) 64 | } 65 | } 66 | 67 | impl Node>> for BPEVocab { 68 | type Output = Vec>; 69 | 70 | fn process(&mut self, input: Vec>) -> Self::Output { 71 | self.batch_indexes_from_tokens(&input).unwrap() 72 | } 73 | } 74 | 75 | impl Node> for BPEVocab { 76 | type Output = Vec; 77 | 78 | fn process(&mut self, input: Vec) -> Self::Output { 79 | self.indexes_from_tokens(&input).unwrap() 80 | } 81 | } 82 | 83 | impl Node>> for BPEVocab { 84 | type Output = Vec>; 85 | 86 | fn process(&mut self, input: Vec>) -> Self::Output { 87 | self.batch_tokens_from_indexes(&input).unwrap() 88 | } 89 | } 90 | 91 | impl Node> for BPEVocab { 92 | type Output = Vec; 93 | 94 | fn process(&mut self, input: Vec) -> Self::Output { 95 | self.tokens_from_indexes(&input).unwrap() 96 | } 97 | } 98 | -------------------------------------------------------------------------------- /dataflow_nlp/src/vocab/mod.rs: -------------------------------------------------------------------------------- 1 | #[cfg(test)] 2 | mod tests; 3 | 4 | mod basic; 5 | pub(crate) use basic::BasicVocab; 6 | mod wordpiece; 7 | pub use wordpiece::WordPieceVocab; 8 | mod bpe; 9 | pub 
use bpe::BPEVocab; 10 | 11 | use serde::Serialize; 12 | 13 | /// A trait every vocab object must implement 14 | pub trait Vocab: Serialize + Default + Clone { 15 | fn load() -> Self; 16 | fn len(&self) -> usize; 17 | 18 | fn is_empty(&self) -> bool { 19 | self.len() == 0 20 | } 21 | 22 | fn tokens_from_indexes(&self, indexes: &[usize]) -> Result, TokenNotFoundError>; 23 | fn batch_tokens_from_indexes( 24 | &self, 25 | indexes: &[Vec], 26 | ) -> Result>, TokenNotFoundError>; 27 | fn indexes_from_tokens(&self, tokens: &[String]) -> Result, TokenNotFoundError>; 28 | fn batch_indexes_from_tokens( 29 | &self, 30 | tokens: &[Vec], 31 | ) -> Result>, TokenNotFoundError>; 32 | } 33 | 34 | /// Custom Error Types 35 | #[derive(Debug)] 36 | pub struct TokenNotFoundError; 37 | -------------------------------------------------------------------------------- /dataflow_nlp/src/vocab/tests.rs: -------------------------------------------------------------------------------- 1 | use crate::vocab::{BPEVocab, Vocab, WordPieceVocab}; 2 | #[test] 3 | fn creating_vocab() { 4 | let _wordpiece_vocab = WordPieceVocab::load(); 5 | let _bpe_vocab = BPEVocab::load(); 6 | } 7 | 8 | #[test] 9 | fn indexes_from_tokens_bpe() { 10 | let bpe_vocab = BPEVocab::load(); 11 | let tokens = ["Hello", ",", " ", "how", " ", "are", " ", "you", "?"]; 12 | let mut tokens_vec: Vec = Vec::new(); 13 | for token in tokens.iter() { 14 | tokens_vec.push(String::from(*token)); 15 | } 16 | let indexes = bpe_vocab.indexes_from_tokens(&tokens_vec); 17 | assert_eq!( 18 | indexes.unwrap(), 19 | vec![23858, 37861, 4, 4786, 4, 290, 4, 3258, 22092] 20 | ); 21 | } 22 | 23 | #[test] 24 | fn indexes_from_tokens_wordpiece() { 25 | let wordpiece_vocab = WordPieceVocab::load(); 26 | let tokens = ["hello", ",", "how", "are", "you", "?"]; 27 | let mut tokens_vec: Vec = Vec::new(); 28 | for token in tokens.iter() { 29 | tokens_vec.push(String::from(*token)); 30 | } 31 | let indexes = wordpiece_vocab.indexes_from_tokens(&tokens_vec); 32 | assert_eq!(indexes.unwrap(), vec![7596, 1014, 2133, 2028, 2021, 1033]); 33 | } 34 | 35 | #[test] 36 | fn tokens_from_indexes_bpe() { 37 | let bpe_vocab = BPEVocab::load(); 38 | let tokens = ["Hello", ",", " ", "how", " ", "are", " ", "you", "?"]; 39 | let mut tokens_vec: Vec = Vec::new(); 40 | for token in tokens.iter() { 41 | tokens_vec.push(String::from(*token)); 42 | } 43 | let tokens = bpe_vocab.tokens_from_indexes(&[23858, 37861, 4, 4786, 4, 290, 4, 3258, 22092]); 44 | assert_eq!(tokens.unwrap(), tokens_vec); 45 | } 46 | 47 | #[test] 48 | fn tokens_from_indexes_wordpiece() { 49 | let wordpiece_vocab = WordPieceVocab::load(); 50 | let tokens = ["hello", ",", "how", "are", "you", "?"]; 51 | let mut tokens_vec: Vec = Vec::new(); 52 | for token in tokens.iter() { 53 | tokens_vec.push(String::from(*token)); 54 | } 55 | let tokens = wordpiece_vocab.tokens_from_indexes(&[7596, 1014, 2133, 2028, 2021, 1033]); 56 | assert_eq!(tokens.unwrap(), tokens_vec); 57 | } 58 | 59 | #[test] 60 | fn batch_indexes_from_tokens() { 61 | let bpe_vocab = BPEVocab::load(); 62 | let tokens = ["Hello", ",", " ", "how", " ", "are", " ", "you", "?"]; 63 | let mut tokens_vec: Vec> = vec![Vec::new()]; 64 | for token in tokens.iter() { 65 | tokens_vec[0].push(String::from(*token)); 66 | } 67 | let indexes = bpe_vocab.batch_indexes_from_tokens(&tokens_vec); 68 | assert_eq!( 69 | indexes.unwrap()[0], 70 | vec![23858, 37861, 4, 4786, 4, 290, 4, 3258, 22092] 71 | ); 72 | } 73 | 74 | #[test] 75 | fn batch_tokens_from_indexes() { 76 | let bpe_vocab = 
BPEVocab::load(); 77 | let tokens = ["Hello", ",", " ", "how", " ", "are", " ", "you", "?"]; 78 | let mut tokens_vec: Vec = Vec::new(); 79 | for token in tokens.iter() { 80 | tokens_vec.push(String::from(*token)); 81 | } 82 | let tokens = 83 | bpe_vocab.batch_tokens_from_indexes(&[vec![23858, 37861, 4, 4786, 4, 290, 4, 3258, 22092]]); 84 | assert_eq!(tokens.unwrap()[0], tokens_vec); 85 | } 86 | -------------------------------------------------------------------------------- /dataflow_nlp/src/vocab/wordpiece.rs: -------------------------------------------------------------------------------- 1 | use super::{BasicVocab, TokenNotFoundError, Vocab}; 2 | use dataflow::prelude::Node; 3 | use serde::{Deserialize, Serialize}; 4 | 5 | #[derive(Serialize, Deserialize, Default, Clone)] 6 | pub struct WordPieceVocab { 7 | vocab: BasicVocab, 8 | } 9 | 10 | impl Vocab for WordPieceVocab { 11 | fn load() -> Self { 12 | // Open vocab file 13 | let tokens: Vec<&str> = include_str!("../resources/wordpiece_vocab.txt") 14 | .split('\n') 15 | .collect(); 16 | // Build vocab 17 | let mut vocab = BasicVocab::new(); 18 | for token in tokens { 19 | vocab.add_token(String::from(token)); 20 | } 21 | WordPieceVocab { vocab } 22 | } 23 | 24 | fn len(&self) -> usize { 25 | self.vocab.len() 26 | } 27 | 28 | fn tokens_from_indexes(&self, indexes: &[usize]) -> Result, TokenNotFoundError> { 29 | self.vocab.tokens_from_indexes(indexes) 30 | } 31 | 32 | fn batch_tokens_from_indexes( 33 | &self, 34 | indexes: &[Vec], 35 | ) -> Result>, TokenNotFoundError> { 36 | self.vocab.batch_tokens_from_indexes(indexes) 37 | } 38 | 39 | fn indexes_from_tokens(&self, tokens: &[String]) -> Result, TokenNotFoundError> { 40 | self.vocab.indexes_from_tokens(tokens) 41 | } 42 | 43 | fn batch_indexes_from_tokens( 44 | &self, 45 | tokens: &[Vec], 46 | ) -> Result>, TokenNotFoundError> { 47 | self.vocab.batch_indexes_from_tokens(tokens) 48 | } 49 | } 50 | 51 | impl Node>> for WordPieceVocab { 52 | type Output = Vec>; 53 | 54 | fn process(&mut self, input: Vec>) -> Self::Output { 55 | self.batch_indexes_from_tokens(&input).unwrap() 56 | } 57 | } 58 | 59 | impl Node> for WordPieceVocab { 60 | type Output = Vec; 61 | 62 | fn process(&mut self, input: Vec) -> Self::Output { 63 | self.indexes_from_tokens(&input).unwrap() 64 | } 65 | } 66 | 67 | impl Node>> for WordPieceVocab { 68 | type Output = Vec>; 69 | 70 | fn process(&mut self, input: Vec>) -> Self::Output { 71 | self.batch_tokens_from_indexes(&input).unwrap() 72 | } 73 | } 74 | 75 | impl Node> for WordPieceVocab { 76 | type Output = Vec; 77 | 78 | fn process(&mut self, input: Vec) -> Self::Output { 79 | self.tokens_from_indexes(&input).unwrap() 80 | } 81 | } 82 | -------------------------------------------------------------------------------- /src/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jafioti/dataflow/85f98f5ee8736835b26d51609d722f48b4f65b88/src/.DS_Store -------------------------------------------------------------------------------- /src/dataloader/dataloader.rs: -------------------------------------------------------------------------------- 1 | use rand::{prelude::SliceRandom, thread_rng}; 2 | use std::{collections::VecDeque, thread}; 3 | 4 | use crate::pipeline::Node; 5 | 6 | pub struct Dataloader { 7 | pipeline: Option, Output = Vec> + Send>>, 8 | last_pipeline_length: usize, 9 | buffer: VecDeque, 10 | load_block_size: usize, 11 | buffer_size: usize, 12 | #[allow(clippy::type_complexity)] 13 | 
loading_process: 14 | Option, Output = Vec> + Send>, Vec)>>, 15 | } 16 | 17 | impl Dataloader { 18 | pub fn new(mut pipeline: impl Node, Output = Vec> + Send + 'static) -> Self { 19 | pipeline.reset(); 20 | Dataloader { 21 | pipeline: Some(Box::new(pipeline)), 22 | buffer: VecDeque::new(), 23 | last_pipeline_length: 0, 24 | load_block_size: 1000, 25 | buffer_size: 1000, 26 | loading_process: None, 27 | } 28 | } 29 | 30 | pub fn load_block_size(self, load_block_size: usize) -> Self { 31 | Dataloader { 32 | load_block_size, 33 | ..self 34 | } 35 | } 36 | 37 | pub fn buffer_size(self, buffer_size: usize) -> Self { 38 | Dataloader { 39 | buffer_size, 40 | ..self 41 | } 42 | } 43 | 44 | fn load_block(&mut self) { 45 | if self.loading_process.is_some() 46 | || self.pipeline.is_none() 47 | || self.pipeline.as_ref().unwrap().data_remaining(0) == 0 48 | { 49 | return; 50 | } 51 | 52 | // Launch loading thread 53 | let mut pipeline = self.pipeline.take().unwrap(); 54 | let load_block_size = self.load_block_size; 55 | self.loading_process = Some(thread::spawn(move || { 56 | let mut data = pipeline.process(vec![(); load_block_size]); 57 | data.shuffle(&mut thread_rng()); 58 | (pipeline, data) 59 | })); 60 | } 61 | 62 | pub fn len(&mut self) -> usize { 63 | let pipeline_data = if let Some(p) = &self.pipeline { 64 | let l = p.data_remaining(0); 65 | self.last_pipeline_length = l; 66 | l 67 | } else { 68 | self.last_pipeline_length 69 | }; 70 | pipeline_data + self.buffer.len() 71 | } 72 | 73 | pub fn is_empty(&mut self) -> bool { 74 | self.len() == 0 75 | } 76 | 77 | pub fn iter_len(&mut self) -> LenIterDataloader { 78 | LenIterDataloader { dataloader: self } 79 | } 80 | } 81 | 82 | impl Iterator for Dataloader { 83 | type Item = T; 84 | 85 | fn next(&mut self) -> Option { 86 | loop { 87 | // Check if the loading thread is finished 88 | if let Some(process) = &self.loading_process { 89 | if process.is_finished() || self.buffer.is_empty() { 90 | // Unload thread 91 | let process = self.loading_process.take().unwrap(); 92 | let (pipeline, data) = process.join().unwrap(); 93 | self.pipeline = Some(pipeline); 94 | self.buffer.extend(data); 95 | } 96 | } 97 | // Launch thread if not currently running and buffer running low 98 | if self.buffer.len() < self.buffer_size && self.pipeline.is_some() { 99 | if self.pipeline.as_ref().unwrap().data_remaining(0) == 0 && self.buffer.is_empty() 100 | { 101 | self.pipeline.as_mut().unwrap().reset(); 102 | return None; 103 | } 104 | self.load_block(); 105 | } 106 | 107 | // Get data from buffer 108 | if let Some(d) = self.buffer.pop_front() { 109 | return Some(d); 110 | } else if let Some(process) = self.loading_process.take() { 111 | let (pipeline, data) = process.join().unwrap(); 112 | self.pipeline = Some(pipeline); 113 | self.buffer.extend(data); 114 | if let Some(d) = self.buffer.pop_front() { 115 | return Some(d); 116 | } 117 | } 118 | } 119 | } 120 | } 121 | 122 | pub struct LenIterDataloader<'a, T> { 123 | dataloader: &'a mut Dataloader, 124 | } 125 | 126 | impl<'a, T: Send + 'static> Iterator for LenIterDataloader<'a, T> { 127 | type Item = (T, usize); 128 | 129 | fn next(&mut self) -> Option { 130 | let item = self.dataloader.next()?; 131 | Some((item, self.dataloader.len())) 132 | } 133 | } 134 | -------------------------------------------------------------------------------- /src/dataloader/mod.rs: -------------------------------------------------------------------------------- 1 | #[allow(clippy::module_inception)] 2 | mod dataloader; 3 | pub use 
dataloader::*; 4 | 5 | //mod threaded_dataloader; 6 | //pub use threaded_dataloader::*; 7 | 8 | #[cfg(test)] 9 | mod tests; 10 | -------------------------------------------------------------------------------- /src/dataloader/tests.rs: -------------------------------------------------------------------------------- 1 | use super::Dataloader; 2 | use crate::pipeline::*; 3 | use rand::{prelude::SliceRandom, thread_rng}; 4 | 5 | /// A "loader" to load a full range of numbers randomly 6 | struct CreateRange { 7 | nums_to_make: Vec, 8 | current_progress: usize, 9 | } 10 | 11 | impl CreateRange { 12 | pub fn new(max: usize) -> Self { 13 | CreateRange { 14 | nums_to_make: (0..max).collect(), 15 | current_progress: 0, 16 | } 17 | } 18 | } 19 | 20 | impl Node> for CreateRange { 21 | type Output = Vec; 22 | 23 | fn process(&mut self, input: Vec<()>) -> Self::Output { 24 | let data = 25 | self.nums_to_make[self.current_progress..self.current_progress + input.len()].to_vec(); 26 | self.current_progress += input.len(); 27 | data 28 | } 29 | 30 | fn reset(&mut self) { 31 | self.nums_to_make.shuffle(&mut thread_rng()); 32 | self.current_progress = 0; 33 | } 34 | 35 | fn data_remaining(&self, _before: usize) -> usize { 36 | self.nums_to_make.len() - self.current_progress 37 | } 38 | } 39 | 40 | #[test] 41 | fn test_dataloader() { 42 | // Write a dataloader test 43 | let pipeline = CreateRange::new(10_000) 44 | .map(|i: usize| i * 10) 45 | .chain(Batch::new(10)); 46 | let mut loader = Dataloader::new(pipeline); 47 | assert_eq!(loader.len(), 1000); 48 | 49 | // Run for 5_000 steps and collect results 50 | let mut data = Vec::with_capacity(10_000); 51 | for example in &mut loader { 52 | data.extend(example.into_iter()); 53 | if data.len() == 5_000 { 54 | break; 55 | } 56 | } 57 | 58 | // Check the examples, 5_000 should be retrieved 59 | assert_eq!(data.len(), 5_000); 60 | 61 | // Run for the rest of the data and store it 62 | for example in &mut loader { 63 | data.extend(example.into_iter()); 64 | } 65 | assert_eq!(loader.len(), 1000); // Make sure the loader reset 66 | 67 | // Sort read data 68 | data.sort_unstable(); 69 | 70 | // Compare data 71 | assert_eq!(data, (0..10_000).map(|i| i * 10).collect::>()) 72 | } 73 | -------------------------------------------------------------------------------- /src/lib.rs: -------------------------------------------------------------------------------- 1 | /// Dataloader module contains the main dataloader struct, as well as dataloader utilities 2 | pub mod dataloader; 3 | /// Pipeline module contains the dataflow pipeline struct, as well as all pipeline utilities 4 | pub mod pipeline; 5 | 6 | pub mod prelude { 7 | pub use crate::dataloader::*; 8 | pub use crate::pipeline::*; 9 | } 10 | -------------------------------------------------------------------------------- /src/pipeline/connectors.rs: -------------------------------------------------------------------------------- 1 | use std::marker::PhantomData; 2 | 3 | use super::Node; 4 | 5 | /// A node that takes in T and outputs (T, T) 6 | pub struct Duplicator { 7 | _phantom: PhantomData, 8 | } 9 | 10 | impl Default for Duplicator { 11 | fn default() -> Self { 12 | Duplicator { 13 | _phantom: PhantomData::default(), 14 | } 15 | } 16 | } 17 | 18 | impl Node for Duplicator { 19 | type Output = (T, T); 20 | fn process(&mut self, input: T) -> Self::Output { 21 | (input.clone(), input) 22 | } 23 | } 24 | 25 | /// Pair contains two nodes that run in parallel (TODO: actually make parallel) 26 | pub struct Pair, N2: Node> { 27 
| pub node1: N1, 28 | pub node2: N2, 29 | _phantom: PhantomData<(I0, I1)>, 30 | } 31 | 32 | impl, I2, O2, N2: Node> Pair { 33 | pub fn new(node1: N1, node2: N2) -> Self { 34 | Pair { 35 | node1, 36 | node2, 37 | _phantom: Default::default(), 38 | } 39 | } 40 | } 41 | 42 | impl, N2: Node> Node<(I1, I2)> for Pair { 43 | type Output = (N1::Output, N2::Output); 44 | 45 | fn process(&mut self, (a, b): (I1, I2)) -> Self::Output { 46 | (self.node1.process(a), self.node2.process(b)) 47 | } 48 | 49 | fn reset(&mut self) { 50 | self.node1.reset(); 51 | self.node2.reset(); 52 | } 53 | 54 | fn data_remaining(&self, before: usize) -> usize { 55 | usize::min( 56 | self.node1.data_remaining(before), 57 | self.node2.data_remaining(before), 58 | ) 59 | } 60 | } 61 | 62 | macro_rules! tuple_impls { 63 | ([$($name:ident),+] [$($idx:tt),+], $last:ident, [$($rev_tail:ident),+]) => { 64 | impl< 65 | Input, 66 | $last: 67 | $(Node::<$rev_tail ::Output>, $rev_tail: )+ 68 | Node 69 | > Node for ($($name,)+) { 70 | type Output = $last ::Output; 71 | 72 | fn process(&mut self, x: Input) -> Self::Output { 73 | $(let x = self.$idx.process(x);)+ 74 | x 75 | } 76 | 77 | fn reset(&mut self) { 78 | $(self.$idx.reset();)+ 79 | } 80 | 81 | fn data_remaining(&self, mut before: usize) -> usize { 82 | $( before = self.$idx.data_remaining(before); )+ 83 | before 84 | } 85 | } 86 | }; 87 | } 88 | 89 | tuple_impls!([M1, M2] [0, 1], M2, [M1]); 90 | tuple_impls!([M1, M2, M3] [0, 1, 2], M3, [M2, M1]); 91 | tuple_impls!([M1, M2, M3, M4] [0, 1, 2, 3], M4, [M3, M2, M1]); 92 | tuple_impls!([M1, M2, M3, M4, M5] [0, 1, 2, 3, 4], M5, [M4, M3, M2, M1]); 93 | tuple_impls!([M1, M2, M3, M4, M5, M6] [0, 1, 2, 3, 4, 5], M6, [M5, M4, M3, M2, M1]); 94 | -------------------------------------------------------------------------------- /src/pipeline/loader/file.rs: -------------------------------------------------------------------------------- 1 | use std::{ 2 | fs::File, 3 | io::Read, 4 | path::{Path, PathBuf}, 5 | }; 6 | 7 | use itertools::Itertools; 8 | use rand::{prelude::SliceRandom, thread_rng}; 9 | 10 | use crate::pipeline::*; 11 | 12 | #[derive(Clone)] 13 | pub struct FileLoader { 14 | files: Vec, 15 | currently_loaded_index: usize, // The last example we loaded as an index of the load_order vector (starts at 0) 16 | } 17 | 18 | impl FileLoader { 19 | pub fn new(mut files: Vec) -> Self { 20 | FileLoader { 21 | files: { 22 | files.shuffle(&mut thread_rng()); 23 | files 24 | }, 25 | currently_loaded_index: 0, 26 | } 27 | } 28 | 29 | pub fn from_directory>(path: P) -> Self { 30 | FileLoader { 31 | files: std::fs::read_dir(path) 32 | .unwrap() 33 | .flatten() 34 | .map(|f| f.path()) 35 | .collect_vec(), 36 | currently_loaded_index: 0, 37 | } 38 | } 39 | } 40 | 41 | impl Node> for FileLoader { 42 | type Output = Vec<(PathBuf, Vec)>; 43 | 44 | fn process(&mut self, input: Vec<()>) -> Self::Output { 45 | let mut read_data = vec![]; 46 | for file in self.files[self.currently_loaded_index 47 | ..(self.currently_loaded_index + input.len()).min(self.files.len() - 1)] 48 | .iter() 49 | { 50 | let mut data = Vec::new(); 51 | let mut f = File::open(file).expect("FileLoader failed to load file!"); 52 | f.read_to_end(&mut data).expect("Failed to read file!"); 53 | read_data.push((file.clone(), data)); 54 | } 55 | self.currently_loaded_index = 56 | (self.currently_loaded_index + input.len()).min(self.files.len()); 57 | read_data 58 | } 59 | 60 | fn reset(&mut self) { 61 | self.files.shuffle(&mut thread_rng()); 62 | self.currently_loaded_index = 0; 63 | } 64 
| 65 | fn data_remaining(&self, _before: usize) -> usize { 66 | self.files.len() - self.currently_loaded_index 67 | } 68 | } 69 | -------------------------------------------------------------------------------- /src/pipeline/loader/keyed.rs: -------------------------------------------------------------------------------- 1 | use std::{ 2 | fs::File, 3 | io::{BufRead, BufReader}, 4 | }; 5 | 6 | use itertools::Itertools; 7 | 8 | use crate::pipeline::*; 9 | 10 | /// A loader with a key generating function 11 | #[derive(Clone)] 12 | pub struct KeyedLoader { 13 | files: Vec, 14 | file_sizes: Vec, 15 | delimeter: String, 16 | } 17 | 18 | impl KeyedLoader { 19 | pub fn new(files: &[&str], delimeter: &str) -> Self { 20 | // Get file sizes 21 | let file_sizes: Vec = files 22 | .iter() 23 | .map(|f| { 24 | let file = File::open(f).unwrap(); 25 | let reader = BufReader::new(file); 26 | let mut delimeter_count = 0; 27 | if delimeter == "\n" { 28 | delimeter_count = reader.lines().count(); 29 | } else { 30 | for line in reader.lines().flatten() { 31 | delimeter_count += line.matches(delimeter).count(); 32 | } 33 | delimeter_count += 1; // Since delimeters divide the examples, there should be 1 more example than delimeter 34 | } 35 | delimeter_count 36 | }) 37 | .collect(); 38 | 39 | KeyedLoader { 40 | files: files.iter().map(|s| s.to_string()).collect(), 41 | file_sizes, 42 | delimeter: delimeter.to_string(), 43 | } 44 | } 45 | } 46 | 47 | impl Node> for KeyedLoader { 48 | type Output = Vec; 49 | 50 | fn process(&mut self, input: Vec) -> Self::Output { 51 | // Get bounds to load from 52 | let (min, max) = input.iter().minmax().into_option().unwrap().to_owned(); 53 | let (mut min, mut max) = (*min, *max); 54 | let (mut min_file, mut max_file) = (0, 0); 55 | let mut counter = 0; 56 | for (index, file_size) in self.file_sizes.iter().enumerate() { 57 | counter += file_size; 58 | if counter > min { 59 | min_file = index; 60 | min -= counter + file_size; 61 | } 62 | if counter + file_size > max { 63 | max_file = index; 64 | max -= counter + file_size; 65 | } 66 | } 67 | // Sort inputs and keep track of order (orig order, sorted indexes) 68 | let mut sorted_inputs: Vec<(usize, usize)> = input.into_iter().enumerate().collect(); 69 | sorted_inputs.sort_by(|a, b| a.1.cmp(&b.1)); 70 | 71 | // Load all segments from min to max 72 | let mut buffer = Vec::with_capacity(sorted_inputs.len()); 73 | for file_index in min_file..max_file + 1 { 74 | let file = File::open(&self.files[file_index]).unwrap(); 75 | let reader = BufReader::new(file); 76 | 77 | let mut index_counter = 0; 78 | let mut segment_counter = if file_index == min_file { min } else { 0 }; 79 | let segments_to_take = if file_index == max_file { 80 | max 81 | } else { 82 | self.file_sizes[file_index] 83 | }; 84 | if self.delimeter == "\n" { 85 | for line in reader.lines().flatten() { 86 | if segment_counter == sorted_inputs[index_counter].1 { 87 | buffer.push(line); 88 | index_counter += 1; 89 | if index_counter == sorted_inputs.len() { 90 | return buffer; 91 | } 92 | } 93 | segment_counter += 1; 94 | } 95 | } else { 96 | let mut intermediate_segment = "".to_string(); 97 | for line in reader.lines().flatten() { 98 | let line_segments: Vec<&str> = line.split(&self.delimeter).collect(); 99 | 100 | if segment_counter == sorted_inputs[index_counter].1 { 101 | buffer.push(format!("{}{}", intermediate_segment, line_segments[0])); 102 | index_counter += 1; 103 | if index_counter == sorted_inputs.len() { 104 | return buffer; 105 | } 106 | } 107 | for line_segment 
in line_segments 108 | .iter() 109 | .take((segments_to_take - counter).min(line_segments.len() - 1)) 110 | { 111 | if segment_counter == sorted_inputs[index_counter].1 { 112 | buffer.push(line_segment.to_string()); 113 | index_counter += 1; 114 | if index_counter == sorted_inputs.len() { 115 | return buffer; 116 | } 117 | } 118 | } 119 | intermediate_segment = line_segments.last().unwrap().to_string(); 120 | 121 | segment_counter += line_segments.len() - 1; 122 | if segment_counter >= segments_to_take { 123 | break; 124 | } 125 | } 126 | } 127 | } 128 | 129 | buffer 130 | } 131 | 132 | fn reset(&mut self) { 133 | // Recalculate file sizes 134 | let file_sizes = self 135 | .files 136 | .iter() 137 | .map(|f| { 138 | let file = File::open(f).unwrap(); 139 | let reader = BufReader::new(file); 140 | let mut delimeter_count = 0; 141 | if self.delimeter == "\n" { 142 | delimeter_count = reader.lines().count(); 143 | } else { 144 | for line in reader.lines().flatten() { 145 | delimeter_count += line.matches(&self.delimeter).count(); 146 | } 147 | delimeter_count += 1; // Since delimeters divide the examples, there should be 1 more example than delimeter 148 | } 149 | delimeter_count 150 | }) 151 | .collect(); 152 | self.file_sizes = file_sizes; 153 | } 154 | 155 | fn data_remaining(&self, before: usize) -> usize { 156 | before 157 | } 158 | } 159 | -------------------------------------------------------------------------------- /src/pipeline/loader/mod.rs: -------------------------------------------------------------------------------- 1 | mod keyed; 2 | pub use keyed::*; 3 | mod file; 4 | pub use file::*; 5 | mod vec; 6 | pub use vec::*; 7 | -------------------------------------------------------------------------------- /src/pipeline/loader/vec.rs: -------------------------------------------------------------------------------- 1 | use rand::{seq::SliceRandom, thread_rng}; 2 | 3 | use crate::prelude::Node; 4 | 5 | pub struct VecLoader { 6 | elements: Vec, 7 | shuffle: bool, 8 | current_progress: usize, 9 | } 10 | 11 | impl VecLoader { 12 | pub fn new(elements: Vec) -> Self { 13 | Self { 14 | elements, 15 | shuffle: false, 16 | current_progress: 0, 17 | } 18 | } 19 | 20 | pub fn shuffle(mut self, shuffle: bool) -> Self { 21 | self.shuffle = shuffle; 22 | self 23 | } 24 | } 25 | 26 | impl Node> for VecLoader { 27 | type Output = Vec; 28 | 29 | fn reset(&mut self) { 30 | if self.shuffle { 31 | self.elements.shuffle(&mut thread_rng()); 32 | } 33 | self.current_progress = 0; 34 | } 35 | 36 | fn process(&mut self, input: Vec<()>) -> Self::Output { 37 | if self.current_progress >= self.elements.len() { 38 | return vec![]; 39 | } 40 | let elements = self.elements 41 | [self.current_progress..(self.current_progress + input.len()).min(self.elements.len())] 42 | .to_vec(); 43 | self.current_progress += input.len(); 44 | elements 45 | } 46 | 47 | fn data_remaining(&self, _: usize) -> usize { 48 | self.elements.len().saturating_sub(self.current_progress) 49 | } 50 | } 51 | -------------------------------------------------------------------------------- /src/pipeline/mod.rs: -------------------------------------------------------------------------------- 1 | mod node; 2 | pub use node::*; 3 | mod premade; 4 | pub use premade::*; 5 | mod loader; 6 | pub use loader::*; 7 | mod connectors; 8 | pub use connectors::*; 9 | 10 | #[cfg(test)] 11 | mod tests; 12 | 13 | #[derive(Clone, Copy)] 14 | pub struct Pipeline; 15 | 16 | impl Node for Pipeline { 17 | type Output = I; 18 | 19 | fn process(&mut self, input: I) 
-> I { 20 | input 21 | } 22 | } 23 | -------------------------------------------------------------------------------- /src/pipeline/node.rs: -------------------------------------------------------------------------------- 1 | use super::{Duplicator, Pair}; 2 | 3 | pub trait Node { 4 | type Output; 5 | 6 | /// Process a batch of data 7 | fn process(&mut self, input: Input) -> Self::Output; 8 | /// Reset signal propogates through pipeline 9 | fn reset(&mut self) {} 10 | /// Get number of examples left 11 | fn data_remaining(&self, before: usize) -> usize { 12 | before 13 | } // Defaults to same as previous remaining data 14 | } 15 | 16 | impl O> Node for F { 17 | type Output = O; 18 | fn process(&mut self, input: I) -> Self::Output { 19 | (self)(input) 20 | } 21 | } 22 | 23 | pub trait ExtendNode> { 24 | fn chain>(self, node: N) -> (E, N); 25 | } 26 | 27 | impl> ExtendNode for E { 28 | fn chain>(self, node: N) -> (E, N) 29 | where 30 | Self: std::marker::Sized, 31 | { 32 | (self, node) 33 | } 34 | } 35 | 36 | pub trait ExtendNodeSplit> { 37 | #[allow(clippy::type_complexity)] 38 | fn split, E2: Node>( 39 | self, 40 | node1: E1, 41 | node2: E2, 42 | ) -> (E, Duplicator, Pair); 43 | } 44 | 45 | impl> ExtendNodeSplit 46 | for E 47 | { 48 | #[allow(clippy::type_complexity)] 49 | fn split, E2: Node>( 50 | self, 51 | node1: E1, 52 | node2: E2, 53 | ) -> (E, Duplicator, Pair) { 54 | (self, Duplicator::default(), Pair::new(node1, node2)) 55 | } 56 | } 57 | 58 | pub trait ExtendNodePair> { 59 | #[allow(clippy::type_complexity)] 60 | fn pair, N2: Node>( 61 | self, 62 | node1: N1, 63 | node2: N2, 64 | ) -> (E, Pair); 65 | } 66 | 67 | impl> ExtendNodePair 68 | for E 69 | { 70 | fn pair, N2: Node>( 71 | self, 72 | node1: N1, 73 | node2: N2, 74 | ) -> (E, Pair) { 75 | (self, Pair::new(node1, node2)) 76 | } 77 | } 78 | 79 | /// Feed a bunch of empty types until processing is done, returns result as vector 80 | pub trait RunNode { 81 | fn run(self, block_size: usize) -> Vec; 82 | } 83 | 84 | impl, Output = Vec>> RunNode for N { 85 | fn run(mut self, block_size: usize) -> Vec { 86 | let mut results = vec![]; 87 | while self.data_remaining(usize::MAX) > 0 { 88 | results.append(&mut self.process(vec![(); block_size])); 89 | } 90 | results 91 | } 92 | } 93 | -------------------------------------------------------------------------------- /src/pipeline/premade/batch.rs: -------------------------------------------------------------------------------- 1 | use std::marker::PhantomData; 2 | 3 | use itertools::Itertools; 4 | 5 | use crate::pipeline::Node; 6 | 7 | /// Create batches from examples 8 | #[derive(Clone, Copy)] 9 | pub struct Batch { 10 | _phantom: PhantomData, 11 | batch_size: usize, 12 | } 13 | 14 | impl Batch { 15 | pub fn new(batch_size: usize) -> Self { 16 | Batch { 17 | _phantom: PhantomData::default(), 18 | batch_size, 19 | } 20 | } 21 | } 22 | 23 | impl Node> for Batch { 24 | type Output = Vec>; 25 | 26 | fn process(&mut self, mut input: Vec) -> Self::Output { 27 | let mut batches = Vec::with_capacity(input.len() / self.batch_size); 28 | while !input.is_empty() { 29 | batches.push( 30 | input 31 | .drain(..usize::min(self.batch_size, input.len())) 32 | .collect(), 33 | ); 34 | } 35 | batches 36 | } 37 | 38 | fn data_remaining(&self, before: usize) -> usize { 39 | before / self.batch_size 40 | } 41 | } 42 | 43 | /// Create batches from examples 44 | pub struct ArrayBatch { 45 | _phantom: PhantomData, 46 | } 47 | 48 | impl Default for ArrayBatch { 49 | fn default() -> Self { 50 | Self { 51 | 
_phantom: Default::default(), 52 | } 53 | } 54 | } 55 | 56 | impl Node> for ArrayBatch { 57 | type Output = Vec<[T; B]>; 58 | fn process(&mut self, input: Vec) -> Self::Output { 59 | let mut batches = Vec::with_capacity(input.len() / B); 60 | let chunks = input.into_iter().chunks(B); 61 | let mut chunks_iter = chunks.into_iter(); 62 | while let Some(Ok(b)) = chunks_iter.next().map(|i| i.collect::>().try_into()) { 63 | batches.push(b); 64 | } 65 | batches 66 | } 67 | 68 | fn data_remaining(&self, before: usize) -> usize { 69 | before / B 70 | } 71 | } 72 | -------------------------------------------------------------------------------- /src/pipeline/premade/map.rs: -------------------------------------------------------------------------------- 1 | #![allow(clippy::type_complexity)] 2 | 3 | use std::marker::PhantomData; 4 | 5 | use crate::pipeline::Node; 6 | 7 | pub struct Map> { 8 | _phantom: PhantomData, 9 | node: N, 10 | } 11 | 12 | impl + Clone> Clone for Map { 13 | fn clone(&self) -> Self { 14 | Self { 15 | _phantom: self._phantom, 16 | node: self.node.clone(), 17 | } 18 | } 19 | } 20 | 21 | impl> Map { 22 | pub fn new(node: E) -> Self { 23 | Map { 24 | _phantom: PhantomData::default(), 25 | node, 26 | } 27 | } 28 | } 29 | 30 | impl> Node> for Map { 31 | type Output = Vec; 32 | 33 | fn process(&mut self, input: Vec) -> Self::Output { 34 | input.into_iter().map(|i| self.node.process(i)).collect() 35 | } 36 | } 37 | 38 | pub trait ExtendNodeMap>> { 39 | fn map>(self, node: N) -> (E, Map); 40 | fn filter_map>>( 41 | self, 42 | node: N, 43 | ) -> (E, FilterMap); 44 | fn filter bool>(self, function: F) -> (E, Filter); 45 | } 46 | 47 | impl>> ExtendNodeMap for E { 48 | fn map>(self, node: N) -> (E, Map) 49 | where 50 | Self: std::marker::Sized, 51 | { 52 | (self, Map::new(node)) 53 | } 54 | 55 | fn filter_map>>( 56 | self, 57 | node: N, 58 | ) -> (E, FilterMap) 59 | where 60 | Self: std::marker::Sized, 61 | { 62 | (self, FilterMap::new(node)) 63 | } 64 | 65 | fn filter bool>(self, function: F) -> (E, Filter) { 66 | (self, Filter::new(function)) 67 | } 68 | } 69 | 70 | pub trait ExtendNodeFlatten, Output = Vec>> { 71 | fn flatten(self) -> (N, Flatten); 72 | } 73 | 74 | impl, Output = Vec>> ExtendNodeFlatten for N { 75 | fn flatten(self) -> (Self, Flatten) { 76 | (self, Flatten) 77 | } 78 | } 79 | 80 | pub struct Flatten; 81 | 82 | impl Node> for Flatten { 83 | type Output = Vec; 84 | 85 | fn process(&mut self, input: Vec) -> Self::Output { 86 | input.into_iter().flatten().collect() 87 | } 88 | } 89 | 90 | pub struct FilterMap> { 91 | _phantom: PhantomData, 92 | node: N, 93 | } 94 | 95 | impl>> FilterMap { 96 | pub fn new(node: E) -> Self { 97 | FilterMap { 98 | _phantom: PhantomData::default(), 99 | node, 100 | } 101 | } 102 | } 103 | 104 | impl>> Node> for FilterMap { 105 | type Output = Vec; 106 | 107 | fn process(&mut self, input: Vec) -> Self::Output { 108 | input 109 | .into_iter() 110 | .filter_map(|i| self.node.process(i)) 111 | .collect() 112 | } 113 | } 114 | 115 | pub struct Filter bool> { 116 | _phantom: PhantomData, 117 | function: F, 118 | } 119 | 120 | impl bool> Filter { 121 | pub fn new(function: F) -> Self { 122 | Filter { 123 | _phantom: PhantomData::default(), 124 | function, 125 | } 126 | } 127 | } 128 | 129 | impl bool> Node> for Filter { 130 | type Output = Vec; 131 | 132 | fn process(&mut self, input: Vec) -> Self::Output { 133 | input.into_iter().filter(|i| (self.function)(i)).collect() 134 | } 135 | } 136 | 
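The map-style nodes above are normally reached through the `ExtendNodeMap` extension methods (`map`, `filter`, `filter_map`) rather than constructed directly. Below is a minimal sketch of how they compose on top of the `VecLoader` from `loader/vec.rs`; the numbers, closures, and expected counts are illustrative only and are not taken from the crate's tests.

```rust
use dataflow::prelude::*;

fn main() {
    // Loader -> double each number -> keep multiples of 4 -> stringify.
    let mut pipeline = VecLoader::new((0..10).collect::<Vec<usize>>())
        .map(|n: usize| n * 2)
        .filter(|n: &usize| n % 4 == 0)
        .map(|n: usize| format!("value {}", n));
    // Loaders are driven by a Vec<()> whose length says how many examples to pull.
    let out = Node::process(&mut pipeline, vec![(); 10]);
    assert_eq!(out.len(), 5); // 0, 4, 8, 12, 16 survive the filter
}
```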
-------------------------------------------------------------------------------- /src/pipeline/premade/mapreduce.rs: -------------------------------------------------------------------------------- 1 | use std::{collections::BTreeMap, marker::PhantomData}; 2 | 3 | use crate::pipeline::Node; 4 | 5 | /// Implements the MapReduce operation as seen here: https://research.google/pubs/pub62/ 6 | pub struct MapReduce Vec<(K, V)>, Reduce: Fn((K, Vec)) -> Vec> { 7 | map: Map, 8 | reduce: Reduce, 9 | _phantom: PhantomData<(I, K, V, O)>, 10 | } 11 | 12 | impl Vec<(K, V)> + Clone, Reduce: Fn((K, Vec)) -> Vec + Clone> Clone 13 | for MapReduce 14 | { 15 | fn clone(&self) -> Self { 16 | Self { 17 | map: self.map.clone(), 18 | reduce: self.reduce.clone(), 19 | _phantom: self._phantom, 20 | } 21 | } 22 | } 23 | 24 | impl Vec<(K, V)>, Reduce: Fn((K, Vec)) -> Vec> 25 | MapReduce 26 | { 27 | pub fn new(map: Map, reduce: Reduce) -> Self { 28 | Self { 29 | map, 30 | reduce, 31 | _phantom: PhantomData::default(), 32 | } 33 | } 34 | } 35 | 36 | impl Vec<(K, V)>, Reduce: Fn((K, Vec)) -> Vec> Node> 37 | for MapReduce 38 | { 39 | type Output = Vec; 40 | 41 | fn process(&mut self, input: Vec) -> Self::Output { 42 | group(input.into_iter().flat_map(&self.map)) 43 | .into_iter() 44 | .flat_map(&self.reduce) 45 | .collect() 46 | } 47 | } 48 | 49 | fn group(v: I) -> BTreeMap> 50 | where 51 | A: Ord, 52 | I: IntoIterator, 53 | { 54 | let mut result = BTreeMap::>::new(); 55 | for (a, b) in v { 56 | result.entry(a).or_default().push(b); 57 | } 58 | result 59 | } 60 | -------------------------------------------------------------------------------- /src/pipeline/premade/mod.rs: -------------------------------------------------------------------------------- 1 | mod stateful; 2 | pub use stateful::*; 3 | mod batch; 4 | pub use batch::*; 5 | mod sort; 6 | pub use sort::*; 7 | mod mapreduce; 8 | pub use mapreduce::*; 9 | mod map; 10 | pub use map::*; 11 | mod shuffle; 12 | pub use shuffle::*; 13 | mod selector; 14 | pub use selector::*; 15 | -------------------------------------------------------------------------------- /src/pipeline/premade/selector.rs: -------------------------------------------------------------------------------- 1 | use std::marker::PhantomData; 2 | 3 | use itertools::Itertools; 4 | 5 | use crate::pipeline::Node; 6 | 7 | /// Equally selects from N nodes that all take in the same input and give the same output 8 | /// 9 | /// ### Example 10 | /// ``` 11 | /// use dataflow::prelude::*; 12 | /// 13 | /// // Combine two file loading pipelines together, each taking Vec<()> as input and outputting Vec<(PathBuf, Vec)> 14 | /// let combined_pipe = BalancedSelector::default() 15 | /// .add_node(FileLoader::new(vec!["file1".into()])) 16 | /// .add_node(FileLoader::new(vec!["file2".into()]).map(|(path, file)| { 17 | /// // Per-pipeline processing here 18 | /// (path, file) 19 | /// })); 20 | /// ``` 21 | #[derive(Default)] 22 | pub struct BalancedSelector { 23 | nodes: Vec, Output = Vec> + Send>>, 24 | _phantom: PhantomData<(I, O)>, 25 | } 26 | 27 | impl BalancedSelector { 28 | pub fn add_node, Output = Vec> + 'static + Send>(mut self, node: N) -> Self { 29 | self.nodes.push(Box::new(node)); 30 | self 31 | } 32 | } 33 | 34 | impl Node> for BalancedSelector { 35 | type Output = Vec; 36 | 37 | fn process(&mut self, mut input: Vec) -> Self::Output { 38 | // Distribute the inputs amoung the nodes in proportion to their remaining data 39 | let remaining_data = self 40 | .nodes 41 | .iter() 42 | .map(|i| 
i.data_remaining(usize::MAX)) 43 | .collect_vec(); 44 | let total_remaining_data = remaining_data.iter().sum::() as f64; 45 | remaining_data 46 | .into_iter() 47 | .enumerate() 48 | .flat_map(|(index, i)| { 49 | let proportion = i as f64 / total_remaining_data; 50 | if input.is_empty() { 51 | return vec![]; 52 | } 53 | self.nodes[index].process( 54 | input 55 | .drain(..((input.len() as f64 * proportion) as usize).min(input.len())) 56 | .collect(), 57 | ) 58 | }) 59 | .collect() 60 | } 61 | 62 | fn data_remaining(&self, before: usize) -> usize { 63 | self.nodes.iter().map(|n| n.data_remaining(before)).sum() 64 | } 65 | 66 | fn reset(&mut self) { 67 | for node in &mut self.nodes { 68 | node.reset(); 69 | } 70 | } 71 | } 72 | -------------------------------------------------------------------------------- /src/pipeline/premade/shuffle.rs: -------------------------------------------------------------------------------- 1 | use rand::{rngs::StdRng, seq::SliceRandom, SeedableRng}; 2 | 3 | use crate::pipeline::Node; 4 | use std::marker::PhantomData; 5 | 6 | pub struct Shuffle { 7 | rng: StdRng, 8 | _phantom: PhantomData, 9 | } 10 | 11 | impl Default for Shuffle { 12 | fn default() -> Self { 13 | Self { 14 | rng: StdRng::from_entropy(), 15 | _phantom: Default::default(), 16 | } 17 | } 18 | } 19 | 20 | impl Node> for Shuffle { 21 | type Output = Vec; 22 | 23 | fn process(&mut self, mut input: Vec) -> Self::Output { 24 | input.shuffle(&mut self.rng); 25 | input 26 | } 27 | } 28 | -------------------------------------------------------------------------------- /src/pipeline/premade/sort.rs: -------------------------------------------------------------------------------- 1 | use crate::pipeline::Node; 2 | use std::{cmp::Ordering, marker::PhantomData}; 3 | 4 | pub struct Sort Ordering> { 5 | _phantom: PhantomData, 6 | sort_fn: F, 7 | } 8 | 9 | impl Ordering> Clone for Sort { 10 | fn clone(&self) -> Self { 11 | Self { 12 | _phantom: self._phantom, 13 | sort_fn: self.sort_fn.clone(), 14 | } 15 | } 16 | } 17 | 18 | impl Ordering> Sort { 19 | pub fn new(sort_fn: F) -> Self { 20 | Sort { 21 | _phantom: PhantomData::default(), 22 | sort_fn, 23 | } 24 | } 25 | } 26 | 27 | impl Ordering> Node> for Sort { 28 | type Output = Vec; 29 | 30 | fn process(&mut self, mut input: Vec) -> Self::Output { 31 | input.sort_by(&self.sort_fn); 32 | input 33 | } 34 | } 35 | -------------------------------------------------------------------------------- /src/pipeline/premade/stateful.rs: -------------------------------------------------------------------------------- 1 | use crate::pipeline::Node; 2 | use std::marker::PhantomData; 3 | 4 | pub struct Stateful O, R: Fn(usize) -> usize> { 5 | _phantom: PhantomData<(I, O)>, 6 | function: F, 7 | state: S, 8 | remaining: R, 9 | } 10 | 11 | impl O + Clone, R: Fn(usize) -> usize + Clone> Clone 12 | for Stateful 13 | { 14 | fn clone(&self) -> Self { 15 | Self { 16 | _phantom: self._phantom, 17 | function: self.function.clone(), 18 | state: self.state.clone(), 19 | remaining: self.remaining.clone(), 20 | } 21 | } 22 | } 23 | 24 | fn identity_remaining(before: usize) -> usize { 25 | before 26 | } 27 | 28 | impl O> Stateful usize> { 29 | /// Initialize a new stateful node, with a state and a process function. 
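///
/// A minimal usage sketch (the running count and closure below are illustrative, not part of the crate):
/// ```ignore
/// let mut tagger = Stateful::new(0usize, |text: String, seen: &mut usize| {
///     *seen += 1;
///     format!("{}: {}", *seen, text)
/// });
/// assert_eq!(tagger.process(String::from("hi")), "1: hi");
/// ```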
30 | pub fn new(state: S, function: F) -> Self { 31 | Stateful { 32 | _phantom: PhantomData::default(), 33 | function, 34 | state, 35 | remaining: identity_remaining, 36 | } 37 | } 38 | } 39 | 40 | impl O, R: Fn(usize) -> usize> Stateful { 41 | pub fn remaining usize>(self, remaining_fn: N) -> Stateful { 42 | Stateful { 43 | _phantom: PhantomData::default(), 44 | function: self.function, 45 | state: self.state, 46 | remaining: remaining_fn, 47 | } 48 | } 49 | } 50 | 51 | impl O, R: Fn(usize) -> usize> Node for Stateful { 52 | type Output = O; 53 | 54 | fn process(&mut self, input: I) -> Self::Output { 55 | (self.function)(input, &mut self.state) 56 | } 57 | 58 | fn data_remaining(&self, before: usize) -> usize { 59 | (self.remaining)(before) 60 | } 61 | } 62 | -------------------------------------------------------------------------------- /src/pipeline/tests.rs: -------------------------------------------------------------------------------- 1 | use std::{marker::PhantomData, thread}; 2 | 3 | use crate::pipeline::*; 4 | 5 | // Helper functions 6 | fn add_ten(nums: Vec) -> Vec { 7 | nums.into_iter().map(|n| n + 10).collect() 8 | } 9 | fn convert_to_int(inp: Vec) -> Vec { 10 | inp.into_iter().map(|i| i.parse::().unwrap()).collect() 11 | } 12 | fn greet(inp: Vec) -> Vec { 13 | inp.into_iter().map(|i| format!("Hello {}", i)).collect() 14 | } 15 | fn concat_strings(inp: Vec<(String, String)>) -> Vec { 16 | inp.into_iter() 17 | .map(|(a, b)| format!("{}{}", a, b)) 18 | .collect() 19 | } 20 | 21 | #[test] 22 | fn test_single_pipeline() { 23 | let mut pipeline = add_ten.map(|i: i32| i.to_string()).chain(greet); 24 | 25 | let inputs = vec![12, 3443, 123, 98543]; 26 | assert_eq!( 27 | Node::process(&mut pipeline, inputs), 28 | vec![ 29 | "Hello 22".to_string(), 30 | "Hello 3453".to_string(), 31 | "Hello 133".to_string(), 32 | "Hello 98553".to_string() 33 | ] 34 | ) 35 | } 36 | 37 | #[test] 38 | fn test_pair_pipeline() { 39 | let pipeline = add_ten 40 | .map(|i: i32| i.to_string()) 41 | .split( 42 | greet, 43 | convert_to_int.chain(add_ten).map(|i: i32| i.to_string()), 44 | ) 45 | .chain(|(a, b): (Vec, Vec)| { 46 | a.into_iter() 47 | .zip(b.into_iter()) 48 | .collect::>() 49 | }) 50 | .chain(concat_strings) 51 | .chain(greet); 52 | let inputs = vec![12, 3443, 123, 98543]; 53 | let mut holder = PipelineHolder { 54 | pipeline: Some(pipeline), 55 | _phantom: Default::default(), 56 | }; 57 | let outputs = run_pipeline(&mut holder, inputs); 58 | 59 | println!( 60 | "Examples left: {}", 61 | Node::data_remaining(&holder.pipeline.unwrap(), 0) 62 | ); 63 | assert_eq!( 64 | outputs, 65 | vec![ 66 | "Hello Hello 2232".to_string(), 67 | "Hello Hello 34533463".to_string(), 68 | "Hello Hello 133143".to_string(), 69 | "Hello Hello 9855398563".to_string() 70 | ] 71 | ); 72 | } 73 | 74 | #[test] 75 | fn test_map_reduce_pipeline() { 76 | let mut pipeline = MapReduce::new( 77 | // Count even and odd numbers 78 | |mut num: i32| { 79 | num += 10; 80 | vec![(num % 2 == 0, num)] 81 | }, 82 | |(is_even, nums)| { 83 | vec![format!( 84 | "{}: {nums:?}", 85 | if is_even { "Even" } else { "Odd" } 86 | )] 87 | }, 88 | ); 89 | 90 | let inputs = vec![12, 3443, 124, 98543]; 91 | assert_eq!( 92 | Node::process(&mut pipeline, inputs), 93 | vec!["Odd: [3453, 98553]", "Even: [22, 134]",] 94 | ) 95 | } 96 | 97 | struct PipelineHolder> { 98 | pub pipeline: Option, 99 | _phantom: PhantomData, 100 | } 101 | 102 | fn run_pipeline + Send + 'static>( 103 | pipeline_holder: &mut PipelineHolder, 104 | input: I, 105 | ) -> N::Output 106 | 
where 107 | I: Send + 'static, 108 | N::Output: Send + 'static, 109 | { 110 | let mut pipeline = pipeline_holder.pipeline.take().unwrap(); 111 | let handle = thread::spawn(move || (pipeline.process(input), pipeline)); 112 | let (output, pipeline) = handle.join().unwrap(); 113 | pipeline_holder.pipeline = Some(pipeline); 114 | output 115 | } 116 | --------------------------------------------------------------------------------
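Putting the pieces together: the same composition used in the tests above can also feed a `Dataloader`, which owns the pipeline and refills its buffer from a background thread. A short end-to-end sketch follows; the element count, closure, batch size, and buffer settings are illustrative rather than taken from the repository.

```rust
use dataflow::prelude::*;

fn main() {
    // Loader -> per-example map -> batching into chunks of 10.
    let pipeline = VecLoader::new((0..1_000).collect::<Vec<usize>>())
        .map(|i: usize| i + 1)
        .chain(Batch::new(10));

    // The dataloader drives the pipeline on a worker thread and buffers the results.
    let mut loader = Dataloader::new(pipeline)
        .load_block_size(100)
        .buffer_size(200);

    let mut seen = 0;
    for batch in &mut loader {
        seen += batch.len(); // each item is one batch of up to 10 examples
    }
    assert_eq!(seen, 1_000);
}
```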