├── .gitignore
├── Cargo.toml
├── README.md
├── examples
│   ├── config.toml
│   ├── read_saved.rs
│   └── write_data.rs
├── py
│   └── train.py
└── src
    ├── bench.rs
    ├── btree.rs
    ├── forwarding_model.rs
    ├── lib.rs
    ├── model.rs
    ├── neural.rs
    ├── synthetic.rs
    └── train.rs

/.gitignore:
--------------------------------------------------------------------------------
/target
**/*.rs.bk
Cargo.lock

--------------------------------------------------------------------------------
/Cargo.toml:
--------------------------------------------------------------------------------
[package]
name = "learned_index_structures"
version = "0.1.0"
authors = ["Michael Benfield"]

[dependencies]
rand = "0.5.5"
tempfile = "3.0.4"
toml = "0.4.7"

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Learned Index Structures

This is an implementation of Learned Index Structures as described in
[this](https://arxiv.org/abs/1712.01208) 2017 paper.

Normally in a database the mapping from key to index is done using a traditional
data structure, often a B Tree. In this paper, the authors proposed instead to
use function approximation to implement that mapping. Namely, they used neural
networks. To elaborate, they used a hierarchy of neural networks, with each
stage of the hierarchy doing model selection for the next stage. The hierarchy
ends either in a B Tree, which maps key to index for a much smaller selection of
records, or in a neural network, which provides an approximation to the correct
index. Since this is only an approximation, a linear or binary search must be
performed among nearby records to find the correct record.

My model is a simplified version of the one in the paper: I use only one top
level neural net, which feeds into one of many B Trees.
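In code, a lookup under this scheme is essentially the following (a condensed
version of `ForwardingModel::eval` in `src/forwarding_model.rs`; the scratch
buffer setup for `apply_buffer` is elided here):

```
fn eval(&self, key: f32) -> Option<u32> {
    // 1. the top level neural net predicts an approximate position
    let predicted_label = self.net.apply_buffer(key, &mut buf1, &mut buf2);
    // 2. the prediction is rescaled to select one of the B Trees
    let model =
        ((predicted_label / self.max_prediction as f32) * self.btrees.len() as f32) as usize;
    // 3. the selected (much smaller) B Tree resolves the exact index
    self.btrees[model].eval(key)
}
```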
You can see a few slides I made about this project
[here](https://docs.google.com/presentation/d/1lTMOZBnLd5YrKf26UCM2tFr3qWYijWIiY7W2es9g_tA/edit).

## Requirements

You'll need a recent Python 3 with the `toml` library (do `pip install toml`)
and Keras and Tensorflow.

You'll need Rust and related tools; the easiest way to install them is to use
[rustup](https://rustup.rs).

## How to use

Training and inference are done using a combination of Rust programs in
`examples/` and the Python script in `py/`.

To generate 100,000 data points (sampled from a log normal distribution), train
a model, and benchmark both a B Tree and the learned model on 10,000 index
lookups, do the following from the main directory of this repository:
```
$ cargo run --example write_data data_filename 100000
$ python py/train.py --config examples/config.toml --index data_filename --save out.toml
$ cargo run --release --example read_saved out.toml data_filename
```

If you want to modify the type of model used, you can modify the
`examples/config.toml` file. The format is simple: lines of the form `0 = 32`
indicate the width of each layer. `btree_count` indicates how many B Trees are
used.

## Implementation notes

The authors of the paper above implemented inference for their models in native
code on the CPU. They gave two reasons for this choice:

1. Latency. A roundtrip to the GPU takes on the order of two microseconds.
Python adds additional overhead, as does Tensorflow, which is optimized for
larger models than those used here.

2. Comparing apples to apples. Given that the new model is being compared
against a sequential data structure implemented on the CPU, the authors appear
to feel that it would not be justified to use different hardware for the new
model.

The first concern will hopefully be alleviated in the future as GPUs become more
closely integrated with CPUs and main memory.

In any case, I have followed the authors in this respect. Training in this
project happens in a Python script using Tensorflow. Inference happens in native
code, written in Rust.

For the reader unfamiliar with Rust, it's a newer systems programming language
suitable for the same domains as C or C++, but with many modern features,
notably a lifetime system that provides memory safety without garbage
collection.

## Benchmarks

Currently, on my system, executing the commands under `How to use` above gives
these performance results:

| Model         | Runtime (sec) |
| ------------- | ------------- |
| B Tree        | 0.001         |
| Learned Model | 0.052         |

That is, the learned model is slow.

This is entirely predictable: the current implementation simply does matrix
multiplication scalar by scalar. I looked at the assembly output by the compiler
and the code is not auto-vectorized at all.

I'm working on (and will hopefully finish very shortly) a more optimized version
using vectorized AVX instructions; a sketch of the kind of kernel involved
follows. This should be dramatically faster than the current code. I'm
interested to see whether it beats the B Tree.

Note that the authors of the paper used custom code generation techniques to
achieve substantially greater performance than an optimized B Tree.
Unfortunately their code is not publicly available.
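For illustration, here is a minimal sketch (not part of this repo yet, and only
one possible shape for the kernel) of an AVX dot product of the sort that
dominates inference time. It assumes an x86_64 CPU with AVX and uses unaligned
loads; the optimized version will want aligned loads, which is why
`src/neural.rs` already allocates 32-byte aligned buffers:

```
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx")]
unsafe fn dot_avx(a: &[f32], b: &[f32]) -> f32 {
    use std::arch::x86_64::*;
    assert_eq!(a.len(), b.len());
    let chunks = a.len() / 8;
    let mut acc = _mm256_setzero_ps();
    for i in 0..chunks {
        // multiply and accumulate eight lanes at a time
        let va = _mm256_loadu_ps(a.as_ptr().add(i * 8));
        let vb = _mm256_loadu_ps(b.as_ptr().add(i * 8));
        acc = _mm256_add_ps(acc, _mm256_mul_ps(va, vb));
    }
    // horizontal sum of the eight lanes
    let mut lanes = [0.0f32; 8];
    _mm256_storeu_ps(lanes.as_mut_ptr(), acc);
    let mut sum: f32 = lanes.iter().sum();
    // scalar tail for lengths that aren't a multiple of eight
    for i in 8 * chunks..a.len() {
        sum += a[i] * b[i];
    }
    sum
}
```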
## Discussion

I think the idea of treating the mapping from key to index as a function
approximation problem is really interesting.

Having worked with and thought about this idea for some time, though, I am
skeptical that neural networks are the right tool. Note that in the case of the
randomly generated log normal data, we are essentially just learning the
cumulative distribution function of the log normal distribution. This is a
relatively simple function, and a hierarchy of thousands of neural nets is an
extraordinarily heavyweight tool for this purpose.

There have been several interesting discussions online of this paper, notably
[this blog
post](http://databasearchitects.blogspot.com/2017/12/the-case-for-b-tree-index-structures.html).
The blogger mentions the possibility of using splines for function
approximation, and points out that B Trees in some sense already perform the
same partition/interpolate process as splines. See also comments by Tim Kraska,
one of the authors of the original paper, where he clarifies that the paper was
intended to introduce the idea of machine learning for this application, not
necessarily to suggest that neural networks are the best tool for the purpose.

Finally, I'll mention one more point of view to illustrate why I am skeptical of
neural nets for this application. The authors talk about their model in terms of
precision gain. A given neural network in their hierarchy may reduce the
potential error in predicted index from 100M to 10k, so this is a precision gain
of 10,000. But assuming for simplicity that the B Tree we're replacing is
binary, even with that large precision gain we're only doing the work of about
lg(10,000) ≈ 13 levels of the B Tree.

Traversing each level of a B Tree requires very little computation. In contrast,
even a small neural network as used in the paper requires many thousands of
operations. Because modern CPUs have wide SIMD instructions, the authors are
able, with substantial engineering effort, to make much of that massive amount
of computation happen in parallel and beat the performance of a B Tree, whose
traversal cannot be parallelized. Nevertheless, I believe in practice a more
work-efficient approach would be preferable, especially considering total
throughput.

In particular, if I had an infinite amount of time, I would investigate the
following approaches (a sketch of the second follows the list):

1. GPU-based B Trees. Although a single B Tree search cannot be parallelized, it
would be possible to batch index lookups and perform many thousands of searches
in parallel on the GPU. This would not solve the latency issue, but for
applications where throughput rather than latency is the concern, I'd be
interested to see the results.

2. Other function approximation techniques. There are so many methods it's hard
to know where to start. I mentioned splines earlier. There are also Chebyshev
polynomials, Remez's algorithm, and many more.
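To make the spline idea concrete, here is a hypothetical sketch (not
implemented in this repo) of a piecewise-linear index: keep every `step`-th key
as a knot, then predict a position by interpolating between the surrounding
knots. As with the neural model, a local search around the prediction would
recover the exact record. It assumes the keys are sorted, NaN-free, and that
consecutive knots are distinct:

```
pub struct LinearIndex {
    knots: Vec<f32>,
    step: usize,
}

impl LinearIndex {
    /// Build an index from sorted keys, keeping every `step`-th key as a knot.
    pub fn new(sorted_keys: &[f32], step: usize) -> Self {
        let knots = sorted_keys.iter().cloned().step_by(step).collect();
        LinearIndex { knots, step }
    }

    /// Predict an approximate index for `key` by linear interpolation.
    pub fn predict(&self, key: f32) -> usize {
        match self
            .knots
            .binary_search_by(|k| k.partial_cmp(&key).unwrap())
        {
            Ok(i) => i * self.step,
            Err(0) => 0,
            Err(i) if i == self.knots.len() => (self.knots.len() - 1) * self.step,
            Err(i) => {
                let (lo, hi) = (self.knots[i - 1], self.knots[i]);
                // how far key sits between the two surrounding knots
                let frac = (key - lo) / (hi - lo);
                (i - 1) * self.step + (frac * self.step as f32) as usize
            }
        }
    }
}
```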
--------------------------------------------------------------------------------
/examples/config.toml:
--------------------------------------------------------------------------------
[model]
0 = 32
1 = 32
2 = 32
3 = 32
btree_count = 1000
--------------------------------------------------------------------------------
/examples/read_saved.rs:
--------------------------------------------------------------------------------
extern crate learned_index_structures;

use std::env;
use std::time::Duration;

use learned_index_structures::bench;
use learned_index_structures::btree::BTree;
use learned_index_structures::forwarding_model::{self, ForwardingModel};

fn duration_to_secs(dur: Duration) -> f64 {
    let mut secs = dur.as_secs() as f64;
    secs += dur.subsec_nanos() as f64 / 1000000000.0;
    secs
}

fn main() {
    let args: Vec<String> = env::args().collect();
    let data = forwarding_model::read_data(&args[2]);
    let model = ForwardingModel::read_toml(&args[1], &data);
    println!(
        "Time for neural net model: {:.4}",
        duration_to_secs(bench::bench(&model, &data, 10000))
    );
    let mut btree = BTree::new();
    for i in 0..data.len() {
        btree.insert(data[i], i as u32);
    }
    println!(
        "Time for B Tree: {:.4}",
        duration_to_secs(bench::bench(&btree, &data, 10000))
    );
}
--------------------------------------------------------------------------------
/examples/write_data.rs:
--------------------------------------------------------------------------------
extern crate learned_index_structures;

use learned_index_structures::synthetic;

use std::env;
use std::fs::File;
use std::io::Write;

fn main() {
    let args: Vec<String> = env::args().collect();
    let count: usize = args[2]
        .parse()
        .expect("Couldn't parse command line arguments");
    let data = synthetic::gen_lognormal(count);

    let mut file = File::create(&args[1]).expect("Unable to open file");
    for &datum in data.iter() {
        writeln!(file, "{}", datum).expect("Unable to write to file");
    }
}
--------------------------------------------------------------------------------
/py/train.py:
--------------------------------------------------------------------------------
from keras import backend as K
from keras.models import Sequential
from keras.layers import Dense, LeakyReLU
import numpy as np
import argparse


def individual_model(keys, labels, config):
    model = Sequential()
    model.add(Dense(config['0'], input_dim=1))
    model.add(LeakyReLU())
    # layer 0 was added above; add the remaining configured layers
    for i in range(1, 1000):
        if str(i) not in config:
            break
        model.add(Dense(int(config[str(i)])))
        model.add(LeakyReLU())
    model.add(Dense(1))
    model.compile(optimizer='adam', loss='mse', metrics=[
        max_absolute_error, 'mse', 'mae'])
    model.fit(keys, labels, epochs=64, batch_size=32, verbose=1)
    # model.compile(optimizer='adam', loss=mean_fourth_error,
    #               metrics=[max_absolute_error, 'mse', 'mae'])
    # model.fit(keys, labels, epochs=1, batch_size=32, verbose=1)
    # model.compile(optimizer='adam', loss=max_absolute_error,
    #               metrics=[max_absolute_error, 'mse', 'mae'])
    # model.fit(keys, labels, epochs=1, batch_size=32, verbose=1)
    # model.compile(optimizer='adam', loss=max_absolute_error,
    #               metrics=[max_absolute_error, 'mse', 'mae'])
    # model.fit(keys, labels, epochs=1, batch_size=128, verbose=1)
    return model
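
# With the shipped examples/config.toml ({'0': 32, '1': 32, '2': 32, '3': 32}),
# individual_model builds a fully connected 1 -> 32 -> 32 -> 32 -> 32 -> 1 net
# with LeakyReLU activations, which is the layout the Rust-side Network then
# reads back from the saved TOML.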

def max_absolute_error_model(keys, labels, model):
    pred_labels = model.predict(keys)
    # for some inane reason, Python hangs if I use Keras functions here???
    return np.max(np.abs(labels - pred_labels[:, 0]))


def max_absolute_error(y_true, y_pred):
    return K.max(K.abs(y_true - y_pred))


def mean_fourth_error(y_true, y_pred):
    return K.mean(K.square(K.square(y_pred - y_true)), axis=-1)


def mean_sixth_error(y_true, y_pred):
    return K.mean(K.square(K.square(K.square(y_pred - y_true))), axis=-1)


def plot(keys, labels, model):
    pred_labels = model.predict(keys)
    import matplotlib.pyplot as plt
    plt.plot(keys, labels)
    plt.plot(keys, pred_labels)
    plt.show()


def select_next_model(pred_label, max_pred, model_count):
    # map a predicted label in [0, max_pred] to a B Tree in [0, model_count)
    index = (K.flatten(pred_label) / max_pred) * model_count
    return K.clip(index, 0, model_count - 1)


def max_overestimate(pred_label, label):
    return max(K.get_value(K.max(K.flatten(pred_label) - K.flatten(label))), 0)


def max_underestimate(pred_label, label):
    return max(K.get_value(K.max(K.flatten(label) - K.flatten(pred_label))), 0)


def train(toml_file, data_file):
    import toml
    with open(toml_file) as f:
        text = f.read()
    t = toml.loads(text)

    keys = np.loadtxt(data_file, dtype=np.float32)
    keys = keys[:, np.newaxis]
    labels = np.arange(len(keys), dtype=np.float32)

    model = individual_model(keys, labels, t['model'])

    btree_indices = [[] for _ in range(t['model']['btree_count'])]

    pred_labels = model.predict(keys)

    selected_models = select_next_model(
        pred_labels, labels[-1], t['model']['btree_count'])

    selected_models = K.eval(selected_models)

    for i in range(len(pred_labels)):
        selected_model = int(selected_models[i])
        btree_indices[selected_model].append(i)

    return model, btree_indices


def save(filename, model, btree_indices):
    with open(filename, 'w') as f:
        j = 0
        for layer in model.layers:
            weight_list = layer.get_weights()
            if not weight_list:
                continue
            f.write("layer{} = [".format(j))
            for array in weight_list:
                f.write('[')
                flattened = K.eval(K.flatten(array))
                for item in flattened:
                    f.write("{0:f}, ".format(item))
                f.write('], ')
            f.write("]\n")
            j += 1
        f.write("btree_indices = {}\n".format(btree_indices))


def main():
    parser = argparse.ArgumentParser()

    parser.add_argument(
        '--index', required=True,
        help="File containing newline separated numbers representing keys")

    parser.add_argument(
        '--config', required=True,
        help="TOML file describing model architecture",
    )

    parser.add_argument(
        '--save', required=True,
        help="File in which to save trained model",
    )

    args = parser.parse_args()

    model, btree_indices = train(args.config, args.index)

    save(args.save, model, btree_indices)


if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------
/src/bench.rs:
--------------------------------------------------------------------------------
//! Benchmarking models

use std::time::{Duration, Instant};

use rand::distributions::Uniform;
use rand::{FromEntropy, Rng, XorShiftRng};

use model::Model;

pub fn duration_to_secs(dur: Duration) -> f64 {
    let secs = dur.as_secs() as f64;
    let frac = dur.subsec_millis() as f64 / 1000.0;
    secs + frac
}

/// Randomly sample `count` keys from `data`, call `eval_many` on `model`, and
/// return how long `eval_many` took.
pub fn bench<M>(model: &M, data: &[f32], count: usize) -> Duration
where
    M: Model<f32, u32>,
{
    let mut rng = XorShiftRng::from_entropy();

    let keys: Vec<f32> = {
        // sample uniformly over all of `data`, not just its first `count` elements
        let dist = Uniform::new(0, data.len());
        let mut vec = Vec::with_capacity(count);
        for _ in 0..count {
            vec.push(data[rng.sample(dist)]);
        }
        vec
    };

    let mut indices = vec![None; count];

    let t1 = Instant::now();
    model.eval_many(&keys, &mut indices);
    let t2 = Instant::now();

    t2.duration_since(t1)
}
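
// Typical use, as in examples/read_saved.rs: `bench(&model, &data, 10000)`
// times ten thousand lookups through the `Model` trait and returns the elapsed
// `Duration` of the `eval_many` call alone.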
--------------------------------------------------------------------------------
/src/btree.rs:
--------------------------------------------------------------------------------
//! Simple implementation of B Trees. Refer to Cormen et al. (2009).
//!
//! Like Cormen's description, but also maintains an index associated with each
//! key. Currently delete is not implemented. Also, this is just an in-memory B
//! Tree, so I don't worry about disk reads and writes.

use model::Model;

const T: usize = 8;

#[derive(Copy, Clone, Default, Debug, Eq, PartialEq, Hash)]
struct BTreeNode<K, I> {
    keys: [K; 2 * T - 1],
    indices: [I; 2 * T - 1],
    key_count: u32,

    // Either 0xFFFFFFFF if a leaf, or else an index into the `children` Vec of
    // BTree
    children: u32,
}

/// A BTree, parametrized by key and index type.
///
/// The minimum degree is fixed at 8. I've made this a fixed constant rather
/// than some sort of parameter because I want it to be optimized away rather
/// than a runtime variable, and because Rust currently has issues with
/// associated consts as array lengths (see rustc github issue #29646). It's
/// possible to work around this but I'm not taking the time, at least for now.
#[derive(Clone, Debug, Eq, PartialEq, Hash)]
pub struct BTree<K, I> {
    nodes: Vec<BTreeNode<K, I>>,

    // each item is an index into `nodes`
    children: Vec<[u32; 2 * T]>,

    // index into `nodes` giving the root node
    root: u32,
}

impl<K, I> Default for BTree<K, I>
where
    K: Copy + Default,
    I: Copy + Default,
{
    fn default() -> Self {
        let mut root = BTreeNode::default();
        root.children = 0xFFFFFFFF;
        BTree {
            nodes: vec![root],
            children: Vec::new(),
            root: 0,
        }
    }
}

// Some notes on the private interface of this BTree:
//
// Nodes are identified by their `u32` index. Rather than calling a method to
// retrieve a BTreeNode, methods are provided to retrieve features of a node
// given the index. For instance, `fn keys(&self, node: u32) -> &[K]` retrieves
// a slice of the keys associated with a given node.
impl<K, I> BTree<K, I>
where
    K: Copy + Default + PartialEq + PartialOrd,
    I: Copy + Default,
{
    /// Create a new BTree.
    pub fn new() -> Self {
        Default::default()
    }

    /// Find the index with this key that was inserted before any other index
    /// with this key, or `None` if the key is not in the tree.
    pub fn search(&self, key: K) -> Option<I> {
        self.rsearch(self.root, key)
    }

    /// Insert `key` into the tree, mapping to `index`.
    ///
    /// As may be clear from the interface, no attempt is made to choose a
    /// reasonable, ordered index - the caller is responsible.
    pub fn insert(&mut self, key: K, index: I) {
        let r = self.root;
        if *self.key_count(r) == (2 * T - 1) as u32 {
            let s = self.nodes.len() as u32;
            self.root = s;
            self.nodes.push(Default::default());
            self.nodes[s as usize].children = self.children.len() as u32;
            self.children.push(Default::default());
            self.children_mut(s).unwrap()[0] = r;
            self.split_child(s, 0);
            self.insert_nonfull(s, key, index);
        } else {
            self.insert_nonfull(r, key, index);
        }
    }

    fn split_child(&mut self, x: u32, i: usize) {
        let z = self.nodes.len() as u32;
        self.nodes.push(Default::default());
        let y = self.children(x).expect("No children")[i];
        *self.key_count_mut(z) = (T - 1) as u32;
        for j in 0..T - 1 {
            self.keys_mut(z)[j] = self.keys(y)[j + T];
            self.indices_mut(z)[j] = self.indices(y)[j + T];
        }
        if let None = self.children(y) {
            self.nodes[z as usize].children = 0xFFFFFFFF;
        } else {
            self.nodes[z as usize].children = self.children.len() as u32;
            self.children.push(Default::default());
            for j in 0..T {
                self.children_mut(z).unwrap()[j] = self.children(y).unwrap()[j + T];
            }
        }
        *self.key_count_mut(y) = (T - 1) as u32;
        for j in i + 1..*self.key_count(x) as usize + 1 {
            let j0 = *self.key_count(x) as usize + 1 - j;
            let array = self.children_mut(x).unwrap();
            array[j0 + 1] = array[j0];
        }
        self.children_mut(x).unwrap()[i + 1] = z;
        for j in i..*self.key_count(x) as usize {
            self.keys_mut(x)[j + 1] = self.keys(x)[j];
            self.indices_mut(x)[j + 1] = self.indices(x)[j];
        }
        self.keys_mut(x)[i] = self.keys(y)[T - 1];
        self.indices_mut(x)[i] = self.indices(y)[T - 1];
        *self.key_count_mut(x) += 1;
    }

    fn keys(&self, node: u32) -> &[K] {
        &self.nodes[node as usize].keys
    }

    fn keys_mut(&mut self, node: u32) -> &mut [K] {
        &mut self.nodes[node as usize].keys
    }

    fn indices(&self, node: u32) -> &[I] {
        &self.nodes[node as usize].indices
    }

    fn indices_mut(&mut self, node: u32) -> &mut [I] {
        &mut self.nodes[node as usize].indices
    }

    fn children(&self, node: u32) -> Option<&[u32; 2 * T]> {
        let node = &self.nodes[node as usize];
        if node.children == 0xFFFFFFFF {
            None
        } else {
            Some(&self.children[node.children as usize])
        }
    }

    fn children_mut(&mut self, node: u32) -> Option<&mut [u32; 2 * T]> {
        let node = &self.nodes[node as usize];
        if node.children == 0xFFFFFFFF {
            None
        } else {
            Some(&mut self.children[node.children as usize])
        }
    }

    fn key_count(&self, node: u32) -> &u32 {
        &self.nodes[node as usize].key_count
    }

    fn key_count_mut(&mut self, node: u32) -> &mut u32 {
        &mut self.nodes[node as usize].key_count
    }
    fn rsearch(&self, node: u32, key: K) -> Option<I> {
        for (i, &nodekey) in self.keys(node)[..*self.key_count(node) as usize]
            .iter()
            .enumerate()
        {
            if key == nodekey {
                return Some(self.indices(node)[i]);
            } else if key < nodekey {
                match self.children(node) {
                    None => return None,
                    Some(c) => return self.rsearch(c[i], key),
                }
            }
        }
        match self.children(node) {
            None => None,
            Some(c) => self.rsearch(c[*self.key_count(node) as usize], key),
        }
    }

    fn insert_nonfull(&mut self, x: u32, key: K, index: I) {
        let mut i = *self.key_count(x) as isize - 1;
        if let None = self.children(x) {
            // x is a leaf
            while i >= 0 && key < self.keys(x)[i as usize] {
                self.keys_mut(x)[(i + 1) as usize] = self.keys(x)[i as usize];
                self.indices_mut(x)[(i + 1) as usize] = self.indices(x)[i as usize];
                i -= 1;
            }
            self.keys_mut(x)[(i + 1) as usize] = key;
            self.indices_mut(x)[(i + 1) as usize] = index;
            *self.key_count_mut(x) += 1;
        } else {
            // x is internal
            while i >= 0 && key < self.keys(x)[i as usize] {
                i -= 1;
            }
            i += 1;
            if *self.key_count(self.children(x).unwrap()[i as usize]) == 2 * T as u32 - 1 {
                self.split_child(x, i as usize);
                if key > self.keys(x)[i as usize] {
                    i += 1;
                }
            }
            let c = self.children(x).unwrap()[i as usize];
            self.insert_nonfull(c, key, index);
        }
    }
}

impl<K, I> Model<K, I> for BTree<K, I>
where
    K: Copy + Default + PartialEq + PartialOrd,
    I: Copy + Default,
{
    fn eval(&self, key: K) -> Option<I> {
        self.search(key)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn t() {
        let mut b: BTree<f32, u32> = Default::default();
        for i in 0..500 {
            b.insert(i as f32, i as u32);
        }
        for i in 0..500 {
            assert_eq!(b.search(i as f32).unwrap(), i as u32);
        }
    }
}
--------------------------------------------------------------------------------
/src/forwarding_model.rs:
--------------------------------------------------------------------------------
//! A model consisting of a top level neural net that selects one of several B
//! Trees.

use std::fs::File;
use std::io::{BufRead, BufReader, Read};
use std::path::Path;
use std::str::FromStr;

use toml::{self, Value};

use btree::BTree;
use model::Model;
use neural::Network;

use self::Value::*;

pub struct ForwardingModel {
    net: Network,
    btrees: Box<[BTree<f32, u32>]>,
    max_prediction: u32,
}

impl Model<f32, u32> for ForwardingModel {
    fn eval(&self, key: f32) -> Option<u32> {
        let buf_size = self.net.buf_size();
        let mut buf1 = vec![0.0f32; buf_size];
        let mut buf2 = vec![0.0f32; buf_size];
        let predicted_label = self.net.apply_buffer(key, &mut buf1, &mut buf2);
        let model =
            ((predicted_label / self.max_prediction as f32) * self.btrees.len() as f32) as usize;
        self.btrees[model].eval(key)
    }

    fn eval_many(&self, keys: &[f32], indices: &mut [Option<u32>]) {
        let buf_size = self.net.buf_size();
        let mut buf1 = vec![0.0f32; buf_size];
        let mut buf2 = vec![0.0f32; buf_size];
        for (i, &key) in keys.iter().enumerate() {
            let predicted_label = self.net.apply_buffer(key, &mut buf1, &mut buf2);
            let model = ((predicted_label / self.max_prediction as f32) * self.btrees.len() as f32)
                as usize;
            indices[i] = self.btrees[model].eval(key)
        }
    }
}
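
// Worked example of the routing arithmetic above: with max_prediction = 99999
// and 1000 B Trees, a predicted label of 25000.0 selects
// (25000.0 / 99999.0) * 1000.0 ≈ 250.0, i.e. B Tree number 250. Note that,
// unlike select_next_model in py/train.py, this does not clip the result, so a
// prediction outside [0, max_prediction] could index out of bounds.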

fn value_array_arrays(v: &Value) -> Box<[Box<[u32]>]> {
    if let Array(a) = v {
        let mut arrays: Vec<Box<[u32]>> = Vec::new();
        for value in a.iter() {
            if let Array(immediate_array) = value {
                let mut array: Vec<u32> = Vec::new();
                for integer in immediate_array.iter() {
                    if let Integer(i) = integer {
                        array.push(*i as u32);
                    } else {
                        panic!("Invalid TOML format");
                    }
                }
                arrays.push(array.into_boxed_slice());
            } else {
                panic!("Invalid TOML format");
            }
        }
        arrays.into_boxed_slice()
    } else {
        panic!("Invalid TOML format");
    }
}

pub fn read_data<P>(data_path: &P) -> Box<[f32]>
where
    P: AsRef<Path>,
{
    read_data0(data_path.as_ref())
}

fn read_data0(data_path: &Path) -> Box<[f32]> {
    let mut result = Vec::new();
    let mut buf = String::new();
    let mut file = BufReader::new(File::open(data_path).expect("Unable to open data file"));

    loop {
        if let Ok(_) = file.read_line(&mut buf) {
            if buf.len() == 0 {
                break;
            } else if buf.len() == 1 {
                // skip blank lines
                buf.clear();
                continue;
            }
            buf.pop(); // drop the newline
            let value = f32::from_str(&buf).expect("Invalid data format");
            result.push(value);
            buf.clear();
        } else {
            panic!("file read error");
        }
    }
    result.into_boxed_slice()
}

impl ForwardingModel {
    pub fn read_toml<P>(toml_path: &P, data: &[f32]) -> Self
    where
        P: AsRef<Path>,
    {
        Self::read_toml0(toml_path.as_ref(), data)
    }

    fn read_toml0(toml_path: &Path, data: &[f32]) -> Self {
        use std::cmp::max;

        let s = {
            let mut buf = String::new();
            let mut file = File::open(toml_path).expect("Unable to open TOML file");
            file.read_to_string(&mut buf)
                .expect("Unable to read TOML file");
            buf
        };

        let value: Value = toml::from_str(&s).expect("Unable to parse TOML file");

        let table = if let &Table(ref table) = &value {
            table
        } else {
            panic!("Bad TOML format");
        };

        let indices = if let Some(indices) = table.get("btree_indices") {
            indices
        } else {
            panic!("Invalid TOML format");
        };

        let arrays = value_array_arrays(&indices);

        let mut max_prediction: u32 = 0;

        let btrees: Vec<BTree<f32, u32>> = arrays
            .iter()
            .map(|array| {
                let mut btree = BTree::new();
                for &index in array.iter() {
                    max_prediction = max(index, max_prediction);
                    btree.insert(data[index as usize], index);
                }
                btree
            })
            .collect();

        let network = Network::from_toml(&value);

        ForwardingModel {
            net: network,
            btrees: btrees.into_boxed_slice(),
            max_prediction,
        }
    }
}
--------------------------------------------------------------------------------
/src/lib.rs:
--------------------------------------------------------------------------------
extern crate rand;
extern crate tempfile;
extern crate toml;

pub mod bench;
pub mod btree;
pub mod forwarding_model;
pub mod model;
pub mod neural;
pub mod synthetic;
pub mod train;

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn f() {
        let data = synthetic::gen_lognormal(10000);
        let mut b: btree::BTree<f32, u32> = Default::default();

        for (i, &v) in data.iter().enumerate() {
            b.insert(v, i as u32);
        }

        for &v in data.iter() {
            // we may not have the same index, in the case of duplicate values,
            // but the value at that index will be the same
            assert_eq!(data[b.search(v).unwrap() as usize], v);
        }
    }
}
--------------------------------------------------------------------------------
/src/model.rs:
--------------------------------------------------------------------------------
//! The generic Model type.

pub trait Model<K, I>
where
    K: Copy,
{
    fn eval(&self, key: K) -> Option<I>;

    fn eval_many(&self, keys: &[K], indices: &mut [Option<I>]) {
        for (i, &key) in keys.iter().enumerate() {
            indices[i] = self.eval(key);
        }
    }
}
--------------------------------------------------------------------------------
/src/neural.rs:
--------------------------------------------------------------------------------
//! A fully connected neural network with Leaky ReLU activations.
//!
//! Not optimized at the moment, but at least there are fewer superfluous
//! allocations now.
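
// Background on the `Aligned` hack below: aligned AVX loads such as
// `_mm256_load_ps` require 32-byte aligned addresses. Allocating a
// `Vec<Aligned>`, where `Aligned` is a #[repr(align(32))] wrapper, is one way
// to get such memory from the global allocator on stable Rust.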

use std::mem;
use std::ops::{Index, IndexMut};
use std::slice;

use toml::Value;

// since we will eventually need 32-byte aligned memory for AVX instructions, we
// have to jump through some hoops to allocate and deallocate

#[repr(align(32))]
struct Aligned(u8);

const LEAKY_SLOPE: f32 = 0.3;

fn ceil_div(dividend: usize, divisor: usize) -> usize {
    (dividend + divisor - 1) / divisor
}

fn allocate_aligned_f32(len: usize) -> *mut f32 {
    // number of 32-byte `Aligned` values needed to hold `len` f32s (4 bytes each)
    let aligned_len = ceil_div(4 * len, mem::align_of::<Aligned>());
    let mut v: Vec<Aligned> = Vec::with_capacity(aligned_len);
    let ptr = v.as_mut_ptr();
    mem::forget(v);
    unsafe { mem::transmute(ptr) }
}

fn deallocate_aligned_f32(ptr: *mut f32, len: usize) {
    // must match the capacity computed in `allocate_aligned_f32`
    let aligned_len = ceil_div(4 * len, mem::size_of::<Aligned>());
    unsafe {
        let _: Vec<Aligned> = Vec::from_raw_parts(mem::transmute(ptr), 0, aligned_len);
    }
}

fn value_array_arrays_float(v: &Value) -> Box<[Box<[f32]>]> {
    use self::Value::*;

    if let Array(a) = v {
        let mut arrays: Vec<Box<[f32]>> = Vec::new();
        for value in a.iter() {
            if let Array(immediate_array) = value {
                let mut array: Vec<f32> = Vec::new();
                for item in immediate_array.iter() {
                    if let Float(f) = item {
                        array.push(*f as f32);
                    } else {
                        panic!("Invalid TOML format");
                    }
                }
                arrays.push(array.into_boxed_slice());
            } else {
                panic!("Invalid TOML format");
            }
        }
        arrays.into_boxed_slice()
    } else {
        panic!("Invalid TOML format");
    }
}

#[repr(C)]
pub struct FirstLayer {
    data: *mut f32,
    bias: *mut f32,
    size: usize,
}

impl Index<usize> for FirstLayer {
    type Output = f32;

    fn index(&self, i: usize) -> &f32 {
        if self.size <= i {
            panic!("FirstLayer: index out of bounds");
        } else {
            unsafe { &*self.data.offset(i as isize) }
        }
    }
}

impl IndexMut<usize> for FirstLayer {
    fn index_mut(&mut self, i: usize) -> &mut f32 {
        if self.size <= i {
            panic!("FirstLayer: index out of bounds");
        } else {
            unsafe { &mut *self.data.offset(i as isize) }
        }
    }
}

impl FirstLayer {
    pub fn new(size: usize) -> Self {
        FirstLayer {
            data: allocate_aligned_f32(size),
            bias: allocate_aligned_f32(size),
            size,
        }
    }

    fn bias(&self) -> &[f32] {
        unsafe { slice::from_raw_parts(self.bias, self.size) }
    }

    fn bias_mut(&mut self) -> &mut [f32] {
        unsafe { slice::from_raw_parts_mut(self.bias, self.size) }
    }
}

impl Drop for FirstLayer {
    fn drop(&mut self) {
        deallocate_aligned_f32(self.data, self.size);
        deallocate_aligned_f32(self.bias, self.size);
    }
}

#[repr(C)]
pub struct LastLayer {
    data: *mut f32,
    size: usize,
    bias: f32,
}

impl Index<usize> for LastLayer {
    type Output = f32;

    fn index(&self, i: usize) -> &f32 {
        if self.size <= i {
            panic!("LastLayer: index out of bounds");
        } else {
            unsafe { &*self.data.offset(i as isize) }
        }
    }
}

impl IndexMut<usize> for LastLayer {
    fn index_mut(&mut self, i: usize) -> &mut f32 {
        if self.size <= i {
            panic!("LastLayer: index out of bounds");
        } else {
            unsafe { &mut *self.data.offset(i as isize) }
        }
    }
}

impl LastLayer
{
    pub fn new(size: usize) -> Self {
        LastLayer {
            data: allocate_aligned_f32(size),
            bias: 0.0,
            size,
        }
    }

    fn bias(&self) -> &f32 {
        &self.bias
    }

    fn bias_mut(&mut self) -> &mut f32 {
        &mut self.bias
    }
}

impl Drop for LastLayer {
    fn drop(&mut self) {
        deallocate_aligned_f32(self.data, self.size);
    }
}

#[repr(C)]
pub struct InteriorLayer {
    data: *mut f32,
    bias: *mut f32,
    rows: usize,
    columns: usize,
}

impl InteriorLayer {
    fn new(rows: usize, columns: usize) -> Self {
        InteriorLayer {
            data: allocate_aligned_f32(rows * columns),
            bias: allocate_aligned_f32(rows),
            rows,
            columns,
        }
    }

    fn bias(&self) -> &[f32] {
        unsafe { slice::from_raw_parts(self.bias, self.rows) }
    }

    fn bias_mut(&mut self) -> &mut [f32] {
        unsafe { slice::from_raw_parts_mut(self.bias, self.rows) }
    }
}

impl Drop for InteriorLayer {
    fn drop(&mut self) {
        deallocate_aligned_f32(self.data, self.rows * self.columns);
        deallocate_aligned_f32(self.bias, self.rows);
    }
}

impl Index<(usize, usize)> for InteriorLayer {
    type Output = f32;
    fn index(&self, i: (usize, usize)) -> &f32 {
        if i.0 >= self.rows || i.1 >= self.columns {
            panic!("InteriorLayer: index out of bounds")
        } else {
            unsafe { &*self.data.offset((self.columns * i.0 + i.1) as isize) }
        }
    }
}

impl IndexMut<(usize, usize)> for InteriorLayer {
    fn index_mut(&mut self, i: (usize, usize)) -> &mut f32 {
        if i.0 >= self.rows || i.1 >= self.columns {
            panic!("InteriorLayer: index out of bounds")
        } else {
            unsafe { &mut *self.data.offset((self.columns * i.0 + i.1) as isize) }
        }
    }
}

pub struct Network {
    first_layer: FirstLayer,
    last_layer: LastLayer,
    interior_layers: Box<[InteriorLayer]>,
}

impl Network {
    pub fn apply_buffer(&self, x: f32, buf1: &mut [f32], buf2: &mut [f32]) -> f32 {
        // first layer
        debug_assert!(buf1.len() >= self.first_layer.size);
        for i in 0..self.first_layer.size {
            buf1[i] = x * self.first_layer[i] + self.first_layer.bias()[i];
            if buf1[i] < 0.0 {
                buf1[i] *= LEAKY_SLOPE;
            }
        }

        // interior layers, ping-ponging between the two buffers
        fn write_layer(
            layers: &[InteriorLayer],
            layer_index: usize,
            read: &mut [f32],
            write: &mut [f32],
        ) {
            if layer_index >= layers.len() {
                return;
            }
            let layer = &layers[layer_index];
            debug_assert!(read.len() >= layer.columns);
            debug_assert!(write.len() >= layer.rows);
            for row in 0..layer.rows {
                write[row] = 0.0;
                for col in 0..layer.columns {
                    write[row] += layer[(row, col)] * read[col];
                }
                write[row] += layer.bias()[row];
                if write[row] < 0.0 {
                    write[row] *= LEAKY_SLOPE;
                }
            }
            write_layer(layers, layer_index + 1, write, read);
        }

        write_layer(&self.interior_layers, 0, buf1, buf2);

        let mut result = 0.0f32;

        // last layer: pick the buffer the interior layers ended on
        let read = if self.interior_layers.len() % 2 == 0 {
            buf1
        } else {
            buf2
        };

        debug_assert!(read.len() >= self.last_layer.size);

        for row in 0..self.last_layer.size {
            result += self.last_layer[row] * read[row];
        }

        result += *self.last_layer.bias();

        if result < 0.0 {
            result *= LEAKY_SLOPE;
        }

        result
    }

    /// What size of buffer is necessary to pass to `apply_buffer`?
    pub fn buf_size(&self) -> usize {
        use std::cmp::max;

        let mut bufsize = 0usize;
        bufsize = max(bufsize, self.first_layer.size);
        for layer in self.interior_layers.iter() {
            bufsize = max(bufsize, layer.rows);
        }

        bufsize
    }

    /// Create a Network from a TOML value in my custom format.
    ///
    /// This is the only way to create a Network outside this module at the
    /// moment.
    pub fn from_toml(v: &Value) -> Self {
        use self::Value::*;

        let table = if let Table(table) = v {
            table
        } else {
            panic!("Bad TOML format");
        };

        let mut last_layer_index = 0usize;
        for i in 0.. {
            let layer_var = format!("layer{}", i);
            if let Some(_layer) = table.get(&layer_var) {
                last_layer_index = i;
            } else {
                break;
            }
        }

        if last_layer_index < 2 {
            panic!("Need at least three layers")
        }

        // first layer

        let first_layer_toml = if let Some(layer) = table.get("layer0") {
            layer
        } else {
            panic!("Bad TOML format");
        };

        let arrays = value_array_arrays_float(first_layer_toml);

        let mut first_layer = FirstLayer::new(arrays[0].len());

        unsafe {
            slice::from_raw_parts_mut(first_layer.data, first_layer.size)
                .copy_from_slice(&arrays[0]);
        }
        first_layer.bias_mut().copy_from_slice(&arrays[1]);

        // interior layers

        let mut interior_layers = Vec::new();

        let mut previous_layer_rows = first_layer.size;

        for layer_index in 1..last_layer_index {
            let layer_toml = if let Some(layer) = table.get(&format!("layer{}", layer_index)) {
                layer
            } else {
                unreachable!();
            };

            let arrays = value_array_arrays_float(layer_toml);

            if arrays[0].len() % previous_layer_rows != 0 {
                panic!("Invalid layer sizes: layer {}", layer_index);
            }

            let columns = previous_layer_rows;
            let rows = arrays[0].len() / previous_layer_rows;

            let mut layer = InteriorLayer::new(rows, columns);

            unsafe {
                slice::from_raw_parts_mut(layer.data, rows * columns).copy_from_slice(&arrays[0]);
            }
            layer.bias_mut().copy_from_slice(&arrays[1]);

            interior_layers.push(layer);

            previous_layer_rows = rows;
        }

        // last layer

        let last_layer_toml = if let Some(layer) = table.get(&format!("layer{}", last_layer_index))
        {
            layer
        } else {
            panic!("Bad TOML format");
        };

        let arrays = value_array_arrays_float(last_layer_toml);

        let mut last_layer = LastLayer::new(arrays[0].len());
        unsafe {
            slice::from_raw_parts_mut(last_layer.data, last_layer.size).copy_from_slice(&arrays[0]);
        }
        *last_layer.bias_mut() = arrays[1][0];

        Network {
            first_layer,
            last_layer,
            interior_layers: interior_layers.into_boxed_slice(),
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn f() {
        let mut first = FirstLayer::new(2);
        first[0] = 1.0;
        first[1] = 2.0;
        first.bias_mut()[0] = -3.0;
        first.bias_mut()[1] = 4.0;

        let mut interior = InteriorLayer::new(2, 2);
        interior[(0, 0)] = 1.0;
        interior[(0, 1)] = 2.0;
        interior[(1, 0)] = 3.0;
        interior[(1, 1)] = 4.0;
        interior.bias_mut()[0] = 5.0;
        interior.bias_mut()[1] = 5.0;

        let mut last = LastLayer::new(2);
        last[0] = -1.0;
        last[1] = 1.0;
        *last.bias_mut() = 2.0;

        let network = Network {
            first_layer: first,
            interior_layers: vec![interior].into_boxed_slice(),
            last_layer: last,
        };

        let mut buf1 = vec![0.0, 0.0];
        let mut buf2 = vec![0.0, 0.0];

        let result = network.apply_buffer(1.0, &mut buf1, &mut buf2);

        // Hand computation: the first layer gives (-2.0, 6.0), leaky ReLU makes
        // it (-0.6, 6.0); the interior layer gives (16.4, 27.2); the last layer
        // gives -16.4 + 27.2 + 2.0 = 12.8.
        const GOLDEN: f32 = 12.8;

        assert!((result - GOLDEN).abs() < 0.0001);
    }
}
--------------------------------------------------------------------------------
/src/synthetic.rs:
--------------------------------------------------------------------------------
//! Generating synthetic data

use rand::distributions::{Distribution, LogNormal};
use rand::{FromEntropy, XorShiftRng};

pub fn gen_numbers<F>(mut f: F, count: usize) -> Box<[f32]>
where
    F: FnMut() -> f32,
{
    let mut result = Vec::with_capacity(count);
    for _ in 0..count {
        result.push(f());
    }
    result.sort_by(|a, b| a.partial_cmp(b).unwrap());
    result.into_boxed_slice()
}

/// Generate `count` samples drawn from a log-normal distribution whose
/// underlying normal has mean 0.0 and standard deviation 0.25, sorted.
pub fn gen_lognormal(count: usize) -> Box<[f32]> {
    let mut rng = XorShiftRng::from_entropy();
    let lognormal = LogNormal::new(0.0, 0.25);
    gen_numbers(|| lognormal.sample(&mut rng) as f32, count)
}
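
// Connection to the README's discussion: for sorted samples like these, the
// key -> index map is approximately count * F(key), where F is the log-normal
// CDF, F(x) = Phi((ln x - mu) / sigma) with mu = 0.0 and sigma = 0.25 above --
// the "relatively simple function" the network is really learning.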

--------------------------------------------------------------------------------
/src/train.rs:
--------------------------------------------------------------------------------
//! Run the Python script `train.py` to train a hierarchy of models.
//!
//! Note: this passes `--layers`, `--width`, and `--threshold`, but
//! `py/train.py` currently parses `--config`, `--index`, and `--save`, so the
//! two need to be brought back in sync before this module is usable.

use std::ffi::OsStr;
use std::fs::File;
use std::io::Write;
use std::path::Path;
use std::process::{Command, Stdio};

use tempfile::NamedTempFile;

pub fn train<P>(data: &[f32], layers: usize, width: usize, threshold: usize, py_path: &P)
where
    P: AsRef<Path>,
{
    train0(data, layers, width, threshold, py_path.as_ref());
}

fn train0(data: &[f32], layers: usize, width: usize, threshold: usize, py_path: &Path) {
    let os: &OsStr = py_path.as_ref();
    let file_name = NamedTempFile::new().expect("Unable to create temp file");
    {
        let mut file = File::create(&file_name).expect("Unable to open temp file");
        for &datum in data.iter() {
            writeln!(file, "{}", datum).expect("Unable to write to temp file");
        }
    }
    Command::new("python3.6")
        .arg(os)
        .args(&[
            "--layers",
            &format!("{}", layers),
            "--width",
            &format!("{}", width),
            "--threshold",
            &format!("{}", threshold),
        ])
        .arg("--index")
        .arg(file_name.path())
        .stdout(Stdio::inherit())
        .stderr(Stdio::inherit())
        .output()
        .expect("Failed to execute Python script");
}
--------------------------------------------------------------------------------