├── .gitignore ├── readme.md ├── Cargo.toml ├── license.txt ├── src ├── nodes │ ├── arithmetic.rs │ ├── activation.rs │ ├── embedding.rs │ ├── matmul.rs │ ├── global_pool.rs │ ├── mod.rs │ └── conv.rs ├── optimizers.rs ├── lib.rs └── graph.rs └── examples ├── saving └── main.rs ├── mnist ├── input.rs └── main.rs └── ptb ├── text_dataset.rs ├── beam_search.rs ├── rnn.rs ├── ops.rs └── main.rs /.gitignore: -------------------------------------------------------------------------------- 1 | /target 2 | **/*.rs.bk 3 | Cargo.lock 4 | todo.md 5 | .DS_Store 6 | examples/data/* 7 | -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | This is a package for writing differential programs (i.e. neural networks) in 2 | rust. 3 | 4 | * Look at the documentation `docs.rs/drug` for a better description 5 | * Check out an example `cargo run --example mnist --release`! 6 | * Please give me feedback 7 | 8 | # Versions 9 | 10 | ## 0.0.2 11 | * Saving functionality 12 | * New optimizers: momentum, Adam, RMSProp 13 | * Nodes are now all part of one enum, rather than boxed traits 14 | * Changed some function type signatures as per clippy suggestions 15 | 16 | ## 0.0.1 17 | * Initial release 18 | -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "drug" 3 | version = "0.0.2" 4 | authors = ["Casper Neo "] 5 | license = "MIT" 6 | repository = "https://github.com/CasperN/drug" 7 | description = "A differentiable computation graph for neural networks." 8 | readme = "readme.md" 9 | keywords = ["differentiable", "neural", "network", "compuation", "graph"] 10 | categories = ["science"] 11 | exclude = ["examples/data/*", "doc/**/*.html"] 12 | # publish = false 13 | 14 | [dependencies] 15 | ndarray = { version="0.11.2", features=["serde-1"] } 16 | rand = "0.5.4" 17 | debug_stub_derive = "0.3.0" 18 | serde = "1.0" 19 | erased-serde = "0.3" 20 | serde_derive = "1.0" 21 | 22 | [dev-dependencies] 23 | byteorder = "1.2.4" 24 | itertools = "0.7.8" 25 | ron = "0.4.0" 26 | -------------------------------------------------------------------------------- /license.txt: -------------------------------------------------------------------------------- 1 | Copyright 2018 Casper Neo 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | this software and associated documentation files (the "Software"), to deal in 5 | the Software without restriction, including without limitation the rights to 6 | use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of 7 | the Software, and to permit persons to whom the Software is furnished to do so, 8 | subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in all 11 | copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS 15 | FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR 16 | COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER 17 | IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN 18 | CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
19 | -------------------------------------------------------------------------------- /src/nodes/arithmetic.rs: -------------------------------------------------------------------------------- 1 | use ndarray::prelude::*; 2 | use nodes::Operation; 3 | 4 | #[derive(Debug, Serialize, Deserialize)] 5 | /// Elementwise addition operation 6 | pub struct Add(); 7 | 8 | #[derive(Debug, Serialize, Deserialize)] 9 | /// Elementwise multiplication operation 10 | pub struct Mult(); 11 | 12 | impl Operation for Add { 13 | fn eval(&self, inputs: &[ArrayViewD]) -> ArrayD { 14 | let mut res = inputs[0].to_owned(); 15 | for i in inputs { 16 | res += i; 17 | } 18 | res 19 | } 20 | fn grad(&self, inputs: &[ArrayViewD], _loss: ArrayViewD) -> Vec> { 21 | inputs.into_iter().map(|i| i.to_owned()).collect() 22 | } 23 | } 24 | 25 | impl Operation for Mult { 26 | fn eval(&self, inputs: &[ArrayViewD]) -> ArrayD { 27 | let mut res = inputs[0].to_owned(); 28 | for i in inputs { 29 | res *= i; 30 | } 31 | res 32 | } 33 | fn grad(&self, inputs: &[ArrayViewD], _loss: ArrayViewD) -> Vec> { 34 | assert_eq!(inputs.len(), 2); 35 | inputs.iter().rev().map(|v| v.to_owned()).collect() 36 | } 37 | } 38 | -------------------------------------------------------------------------------- /examples/saving/main.rs: -------------------------------------------------------------------------------- 1 | extern crate drug; 2 | extern crate erased_serde; 3 | extern crate ndarray; 4 | extern crate ron; 5 | 6 | use ndarray::prelude::*; 7 | use std::fs::File; 8 | use std::io::Write; 9 | use std::path::Path; 10 | 11 | fn save() { 12 | let mut g = drug::Graph::default(); 13 | let a = g.constant(arr1(&[1.0, 2.0]).into_dyn()); 14 | let b = g.constant(arr1(&[3.0, 4.0]).into_dyn()); 15 | let c = g.mult(&[a, b]); 16 | g.forward(); 17 | assert_eq!(*g.get_value(c), arr1(&[3.0, 8.0]).into_dyn()); 18 | g.named_idxs.insert("c".to_string(), c); 19 | 20 | let path = Path::new("/tmp/drug.json"); 21 | let mut file = File::create(&path).expect("File creation error"); 22 | 23 | println!("Writing graph:\n{}", g); 24 | let g_str = ron::ser::to_string(&g).expect("Error serealizing graph"); 25 | file.write_all(g_str.as_bytes()).expect("Could not write"); 26 | } 27 | 28 | fn load() { 29 | let path = Path::new("/tmp/drug.json"); 30 | let f = File::open(&path).expect("File open error"); 31 | 32 | let g: drug::Graph = ron::de::from_reader(&f).unwrap(); 33 | let c = &g.named_idxs["c"]; 34 | 35 | assert_eq!(*g.get_value(*c), arr1(&[3.0, 8.0]).into_dyn()); 36 | println!("Read graph:\n{}", g); 37 | } 38 | 39 | fn main() { 40 | save(); 41 | load(); 42 | } 43 | -------------------------------------------------------------------------------- /src/nodes/activation.rs: -------------------------------------------------------------------------------- 1 | use ndarray::{ArrayD, ArrayViewD}; 2 | use nodes::Operation; 3 | 4 | /// Elementwise Activation, which could be leaky relu, sigmoid or tanh 5 | #[derive(Debug, Serialize, Deserialize)] 6 | pub enum Activation { 7 | Relu { leak: f32 }, 8 | Sigmoid, 9 | Tanh, 10 | } 11 | 12 | impl Operation for Activation { 13 | fn eval(&self, inputs: &[ArrayViewD]) -> ArrayD { 14 | assert_eq!(inputs.len(), 1, "Activation accepts one input"); 15 | match self { 16 | Activation::Relu { leak } => inputs[0].mapv(|x| if x > 0.0 { x } else { x * leak }), 17 | Activation::Sigmoid => inputs[0].mapv(sig), 18 | Activation::Tanh => inputs[0].mapv(f32::tanh), 19 | } 20 | } 21 | fn grad(&self, inputs: &[ArrayViewD], loss: ArrayViewD) -> Vec> { 22 | 
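        // The backward pass scales the incoming loss elementwise by the derivative of
        // the activation evaluated at the original input:
        //   leaky relu: 1 for x > 0, `leak` otherwise
        //   sigmoid:    sig(x) * (1 - sig(x))
        //   tanh:       1 - tanh(x)^2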
assert_eq!(inputs.len(), 1, "Activation accepts one input"); 23 | 24 | let mut res = loss.to_owned(); 25 | match self { 26 | Activation::Relu { leak } => { 27 | res.zip_mut_with(&inputs[0], |l, i| { 28 | if *i < 0.0 { 29 | *l *= leak 30 | } 31 | }); 32 | } 33 | Activation::Sigmoid => { 34 | res.zip_mut_with(&inputs[0], |l, i| { 35 | let s = sig(*i); 36 | *l *= s * (1.0 - s); 37 | }); 38 | } 39 | Activation::Tanh => { 40 | res.zip_mut_with(&inputs[0], |l, i| { 41 | *l *= 1.0 - i.tanh().powi(2); 42 | }); 43 | } 44 | } 45 | vec![res] 46 | } 47 | } 48 | fn sig(x: f32) -> f32 { 49 | 1.0 / (1.0 + (-x).exp()) 50 | } 51 | -------------------------------------------------------------------------------- /src/nodes/embedding.rs: -------------------------------------------------------------------------------- 1 | use ndarray::prelude::*; 2 | use nodes::Operation; 3 | 4 | /// Trainable embedding operation, given an index and a 2d-array of embedding vectors, 5 | /// index into the embedding vectors. `FIXME` drug hardcodes `ArrayD` inside the graph so 6 | /// the index should be a `batch_size` length `arrayD` where the values are integers. 7 | #[derive(Debug, Serialize, Deserialize)] 8 | pub struct Embedding(); 9 | 10 | impl Operation for Embedding { 11 | fn eval(&self, inputs: &[ArrayViewD]) -> ArrayD { 12 | assert_eq!(inputs.len(), 2, "Embedding operation takes two inputs"); 13 | let embedding = inputs[0].view().into_dimensionality::().unwrap(); 14 | let code = inputs[1].view().into_dimensionality::().unwrap(); 15 | let batch_size = code.shape()[0]; 16 | let embedding_dim = embedding.shape()[1]; 17 | 18 | Array::from_shape_fn([batch_size, embedding_dim], |(b, d)| { 19 | let x = code[(b)] as usize; 20 | embedding[(x, d)] 21 | }) 22 | .into_dyn() 23 | } 24 | 25 | fn grad(&self, inputs: &[ArrayViewD], loss: ArrayViewD) -> Vec> { 26 | assert_eq!(inputs.len(), 2, "Embedding operation takes two inputs"); 27 | let loss = loss.into_dimensionality::().unwrap(); 28 | let embedding = inputs[0].view().into_dimensionality::().unwrap(); 29 | let code = inputs[1].view().into_dimensionality::().unwrap(); 30 | let batch_size = code.shape()[0]; 31 | let num_embeddings = embedding.shape()[0]; 32 | let embedding_dim = embedding.shape()[1]; 33 | 34 | let mut grad = Array::zeros([num_embeddings, embedding_dim]); 35 | for b in 0..batch_size { 36 | let code = code[(b)] as usize; 37 | for d in 0..embedding_dim { 38 | grad[(code, d)] += loss[(b, d)] 39 | } 40 | } 41 | vec![grad.into_dyn(), Array::zeros([]).into_dyn()] 42 | } 43 | } 44 | -------------------------------------------------------------------------------- /src/nodes/matmul.rs: -------------------------------------------------------------------------------- 1 | use ndarray::prelude::*; 2 | use nodes::Operation; 3 | 4 | /// implements matrix multiply [Operation](trait.Operation.html). 5 | /// See [Node](enum.Node.html) constructor for full description. 
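/// Expects two inputs: a weight matrix of shape `[in_dim, out_dim]` and a batch of
/// vectors of shape `[batch_size, in_dim]`; `eval` computes `neurons.dot(weights)`,
/// yielding `[batch_size, out_dim]` (see `sample_eval` in the tests below).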
6 | #[derive(Debug, Serialize, Deserialize)] 7 | pub struct MatMul(); 8 | 9 | impl Operation for MatMul { 10 | fn eval(&self, inputs: &[ArrayViewD]) -> ArrayD { 11 | assert_eq!(inputs.len(), 2); 12 | let weights = inputs[0].view().into_dimensionality::().unwrap(); 13 | let neurons = inputs[1].view().into_dimensionality::().unwrap(); 14 | 15 | neurons.dot(&weights).into_dyn() 16 | } 17 | fn grad(&self, inputs: &[ArrayViewD], loss: ArrayViewD) -> Vec> { 18 | assert_eq!(inputs.len(), 2); 19 | let weights = inputs[0].view().into_dimensionality::().unwrap(); 20 | let neurons = inputs[1].view().into_dimensionality::().unwrap(); 21 | let loss = loss.into_dimensionality::().unwrap(); 22 | 23 | let grad_weights = neurons.t().dot(&loss).into_dyn(); 24 | let grad_neurons = loss.dot(&weights.t()).into_dyn(); 25 | vec![grad_weights, grad_neurons] 26 | } 27 | } 28 | 29 | #[cfg(test)] 30 | mod tests { 31 | use super::*; 32 | use test::Bencher; 33 | use xavier_initialize; 34 | 35 | #[test] 36 | fn sample_eval() { 37 | let weights = arr2(&[[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]]).into_dyn(); 38 | let vecs = arr2(&[[1.0, 2.0], [3.0, 4.0]]).into_dyn(); 39 | let m = MatMul(); 40 | 41 | let o = m.eval(&[weights.view(), vecs.view()]); 42 | assert_eq!( 43 | o, 44 | arr2(&[[11.0, 14.0, 17.0, 20.0], [23.0, 30.0, 37.0, 44.0]]).into_dyn() 45 | ) 46 | } 47 | #[bench] 48 | fn bench_matmul_eval(b: &mut Bencher) { 49 | let weights = xavier_initialize(&[100, 150]); 50 | let vecs = xavier_initialize(&[8, 100]); 51 | let m = MatMul(); 52 | b.iter(|| m.eval(&[weights.view(), vecs.view()])); 53 | } 54 | #[bench] 55 | fn bench_matmul_grad(b: &mut Bencher) { 56 | let weights = xavier_initialize(&[100, 150]); 57 | let vecs = xavier_initialize(&[8, 100]); 58 | let m = MatMul(); 59 | let o = m.eval(&[weights.view(), vecs.view()]); 60 | b.iter(|| m.grad(&[weights.view(), vecs.view()], o.view())); 61 | } 62 | } 63 | -------------------------------------------------------------------------------- /examples/mnist/input.rs: -------------------------------------------------------------------------------- 1 | use std::fs::File; 2 | use std::io::Read; 3 | use std::path::Path; 4 | 5 | use byteorder::BigEndian; 6 | use byteorder::ReadBytesExt; 7 | 8 | static IMG_MAGIC_NUMBER: u32 = 0x0000_0803; 9 | static LBL_MAGIC_NUMBER: u32 = 0x0000_0801; 10 | static ROWS: usize = 28; 11 | static COLS: usize = 28; 12 | 13 | pub fn images(path: &Path, expected_length: u32) -> Vec { 14 | // Read whole file in memory 15 | let mut content: Vec = Vec::new(); 16 | let mut file = { 17 | let mut fh = File::open(path) 18 | .unwrap_or_else(|_| panic!("Unable to find path to images at {:?}.", path)); 19 | let _ = fh 20 | .read_to_end(&mut content) 21 | .unwrap_or_else(|_| panic!("Unable to read whole file in memory ({})", path.display())); 22 | // The read_u32() method, coming from the byteorder crate's ReadBytesExt trait, cannot be 23 | // used with a `Vec` directly, it requires a slice. 24 | &content[..] 
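        // The IDX image file starts with four big-endian u32 header fields
        // (magic number, item count, rows, cols); the reads below consume that
        // header before the raw pixel bytes, which are rescaled to [0, 1].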
25 | }; 26 | 27 | let magic_number = file 28 | .read_u32::() 29 | .unwrap_or_else(|_| panic!("Unable to read magic number from {:?}.", path)); 30 | assert!(IMG_MAGIC_NUMBER == magic_number, "Incorrect Magic Number"); 31 | 32 | let length = file.read_u32::().unwrap() as u32; 33 | assert!(expected_length == length, "Unexpected Length"); 34 | 35 | let rows = file.read_u32::().unwrap() as usize; 36 | assert!(ROWS == rows, "Unexpected rows"); 37 | 38 | let cols = file.read_u32::().unwrap() as usize; 39 | assert!(COLS == cols, "Unexpected columns"); 40 | 41 | file.to_vec() 42 | .into_iter() 43 | .map(|x| f32::from(x) / 255.0) 44 | .collect() 45 | } 46 | 47 | pub fn labels(path: &Path, expected_length: u32) -> Vec { 48 | let mut file = 49 | File::open(path).unwrap_or_else(|_| panic!("Unable to find path to labels at {:?}.", path)); 50 | 51 | let magic_number = file 52 | .read_u32::() 53 | .unwrap_or_else(|_| panic!("Unable to read magic number from {:?}.", path)); 54 | assert!(LBL_MAGIC_NUMBER == magic_number, "Incorrect magic number"); 55 | 56 | let length = file 57 | .read_u32::() 58 | .unwrap_or_else(|_| panic!("Unable to length from {:?}.", path)); 59 | 60 | assert!(expected_length == length, "Unexpected length"); 61 | file.bytes().map(|b| b.unwrap() as usize).collect() 62 | } 63 | -------------------------------------------------------------------------------- /examples/ptb/text_dataset.rs: -------------------------------------------------------------------------------- 1 | use ndarray::prelude::*; 2 | use rand::{thread_rng, Rng}; 3 | use std::collections::HashMap; 4 | use std::fs::File; 5 | use std::io::Read; 6 | use std::path::Path; 7 | 8 | static DATA_DIR: &'static str = "examples/data/"; 9 | static TRAIN: &'static str = "ptb.train.txt"; 10 | 11 | #[allow(dead_code)] 12 | pub struct TextDataSet { 13 | pub char2idx: HashMap, 14 | pub idx2char: Vec, 15 | pub corpus: Vec>>, 16 | } 17 | impl TextDataSet { 18 | pub fn decode(&self, codes: &[usize]) -> String { 19 | codes.iter().map(|c| self.idx2char[*c]).collect() 20 | } 21 | 22 | pub fn new(batch_size: usize, seq_len: usize) -> Self { 23 | let mut contents = String::new(); 24 | let mut f = File::open(Path::new(DATA_DIR).join(TRAIN)).expect("train data not found"); 25 | f.read_to_string(&mut contents) 26 | .expect("something went wrong reading the file"); 27 | 28 | let mut coded_lines = Vec::new(); 29 | let mut char2idx = HashMap::new(); 30 | let mut idx2char = Vec::new(); 31 | // Tokenize characters 32 | for str_line in contents.lines() { 33 | let mut line = Vec::new(); 34 | 35 | for c in str_line.chars() { 36 | // Insert token `idx` and register new character if unseen. 
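                // `or_insert_with` hands out the next unused index (idx2char.len() - 1
                // right after the push), keeping `char2idx` and `idx2char` as inverse maps.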
37 | let token = char2idx.entry(c).or_insert_with(|| { 38 | idx2char.push(c); 39 | idx2char.len() - 1 40 | }); 41 | line.push(*token); 42 | } 43 | coded_lines.push(line); 44 | } 45 | // Cut up long lines to seq_len length 46 | let mut truncated: Vec> = coded_lines 47 | .into_iter() 48 | .flat_map(|l| { 49 | let v: Vec> = l.chunks_exact(seq_len).map(|x| x.to_vec()).collect(); 50 | v.into_iter() 51 | }) 52 | .collect(); 53 | thread_rng().shuffle(truncated.as_mut_slice()); 54 | 55 | // Batchify 56 | let corpus: Vec>> = truncated 57 | .chunks_exact(batch_size) 58 | .map(|chunk| { 59 | let mut batch = vec![]; 60 | for s in 0..seq_len { 61 | let x = Array::from_shape_fn([batch_size], |b| chunk[b][s] as f32).into_dyn(); 62 | batch.push(x); 63 | } 64 | batch 65 | }) 66 | .collect(); 67 | 68 | TextDataSet { 69 | corpus, 70 | char2idx, 71 | idx2char, 72 | } 73 | } 74 | } 75 | -------------------------------------------------------------------------------- /src/nodes/global_pool.rs: -------------------------------------------------------------------------------- 1 | use ndarray::prelude::*; 2 | use nodes::Operation; 3 | 4 | #[derive(Debug, Serialize, Deserialize)] 5 | /// Type of pooling operation (currently there is only average pooling). 6 | /// TODO enum max pool, avg pool, sum pool, min pool 7 | /// Implements [Operation](trait.Operation.html). 8 | /// See [Node](enum.Node.html) constructor for full description. 9 | pub enum GlobalPool { 10 | /// Reduces by taking the arithmetic mean 11 | Average, 12 | } 13 | 14 | #[allow(unused_variables)] 15 | impl Operation for GlobalPool { 16 | fn eval(&self, inputs: &[ArrayViewD]) -> ArrayD { 17 | assert_eq!(inputs.len(), 1, "GlobalPool takes one 4d-Array"); 18 | let input = &inputs[0]; 19 | 20 | match self { 21 | // Mean over axis 1 and 2 (but ndarray only supports mean over 1 axis at once) 22 | // In second mean_axis, axis 1 is original axis 2. 
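            // For an input of shape [batch, height, width, channels] this reduces
            // [b, h, w, c] -> [b, w, c] -> [b, c], i.e. a per-channel spatial mean.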
23 | GlobalPool::Average => input.mean_axis(Axis(1)).mean_axis(Axis(1)).into_dyn(), 24 | } 25 | } 26 | 27 | fn grad(&self, inputs: &[ArrayViewD], loss: ArrayViewD) -> Vec> { 28 | let loss = loss.into_dimensionality::().unwrap(); 29 | if let [n_b, n_i, n_j, n_c] = inputs[0].shape() { 30 | let res = match self { 31 | GlobalPool::Average => { 32 | let scale = 1.0 / *n_i as f32 / *n_j as f32; 33 | Array::from_shape_fn([*n_b, *n_i, *n_j, *n_c], |(b, _, _, c)| { 34 | loss[(b, c)] * scale 35 | }) 36 | .into_dyn() 37 | } 38 | }; 39 | vec![res] 40 | } else { 41 | unreachable!("Global pool grad should take in 2d-array or shape [batch, channels]") 42 | } 43 | } 44 | } 45 | 46 | #[cfg(test)] 47 | mod tests { 48 | use super::*; 49 | 50 | #[test] 51 | fn average_eval() { 52 | let x = Array::from_shape_vec([2, 3, 4, 5], (0..120).map(|x| x as f32).collect()).unwrap(); 53 | let g = GlobalPool::Average; 54 | let avg = g 55 | .eval(&[x.view().into_dyn()]) 56 | .into_dimensionality::() 57 | .unwrap(); 58 | assert_eq!( 59 | avg, 60 | aview2(&[ 61 | [27.5, 28.5, 29.5, 30.5, 31.5], 62 | [87.5, 88.5, 89.5, 90.5, 91.5] 63 | ]), 64 | "\nFailed comparision with `np.array(range(120)).reshape([2,3,4,5]).mean(axis=(1,2))`" 65 | ) 66 | } 67 | #[test] 68 | fn average_grad() { 69 | let inputs = Array::zeros([2, 3, 4, 5]).into_dyn(); 70 | let losses = Array::ones([2, 5]).into_dyn(); 71 | let g = GlobalPool::Average; 72 | let grad = g 73 | .grad(&[inputs.view().into_dyn()], losses.view()) 74 | .pop() 75 | .unwrap(); 76 | // .into_dimensionality::() 77 | // .unwrap(); 78 | assert_eq!( 79 | grad.into_dimensionality::().unwrap(), 80 | Array::ones([2, 3, 4, 5]) / 12.0, 81 | "\nFailed comparision with `np.array(range(120)).reshape([2,3,4,5]).mean(axis=(1,2))`" 82 | ) 83 | } 84 | } 85 | -------------------------------------------------------------------------------- /examples/ptb/beam_search.rs: -------------------------------------------------------------------------------- 1 | use drug::softmax; 2 | use ndarray::prelude::*; 3 | use rand::distributions::{Distribution, Uniform}; 4 | use rand::thread_rng; 5 | use std::cmp::Ordering; 6 | use std::collections::{HashMap, HashSet}; 7 | 8 | #[derive(Debug)] 9 | struct Beam { 10 | sequence: Vec, 11 | log_prob: f32, 12 | } 13 | #[derive(Debug)] 14 | pub struct BeamSearch { 15 | beams: Vec, 16 | width: usize, 17 | } 18 | // TODO support N layers of GRU 19 | impl BeamSearch { 20 | pub fn new(width: usize) -> Self { 21 | let mut beams = vec![]; 22 | for _ in 0..width { 23 | beams.push(Beam { 24 | sequence: vec![], 25 | log_prob: 0.0, 26 | }) 27 | } 28 | BeamSearch { beams, width } 29 | } 30 | pub fn into_codes(self) -> Vec> { 31 | self.beams.into_iter().map(|b| b.sequence).collect() 32 | } 33 | /// Find likelihood of all "next elements" of every sequence. 
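    /// Candidate codes are sampled per beam, scored by cumulative log-probability,
    /// and only the `width` most probable sequences are kept.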
34 | /// Returns next hidden state and next words of the RNN sequence 35 | pub fn search( 36 | &mut self, 37 | hidden: &[ArrayD], 38 | logits: ArrayViewD, 39 | temperature: f32, 40 | ) -> (Vec>, ArrayD) { 41 | let mut b = 0; 42 | let mut top = HashMap::new(); 43 | let probs = softmax(&logits.into_dyn().mapv(|x| x / temperature)); 44 | 45 | while top.len() < self.width { 46 | let codes = weighted_sample(probs.slice(s!(b, ..)), 1); 47 | for code in codes.iter() { 48 | let new_log_prob = self.beams[b].log_prob + probs[(b, *code)].ln(); 49 | let mut new_seq = self.beams[b].sequence.to_vec(); 50 | new_seq.push(*code); 51 | top.insert(new_seq, (new_log_prob, b)); 52 | } 53 | b += 1; 54 | b %= self.width; 55 | } 56 | 57 | let mut top: Vec<(Beam, usize)> = top 58 | .into_iter() 59 | .map(|(sequence, (log_prob, b))| (Beam { sequence, log_prob }, b)) 60 | .collect(); 61 | 62 | top.sort_by(|a, b| { 63 | if let Some(Ordering::Less) = (a.0.log_prob).partial_cmp(&b.0.log_prob) { 64 | Ordering::Greater 65 | } else { 66 | Ordering::Less 67 | } 68 | }); 69 | top.truncate(self.width); 70 | 71 | // Next set of words, hidden states and update beams 72 | let new_words = Array::from_iter( 73 | top.iter() 74 | .map(|(beam, _)| *beam.sequence.last().expect("Empty beam?") as f32), 75 | ) 76 | .into_dyn(); 77 | 78 | let mut new_hidden = vec![]; 79 | for hid in hidden.iter() { 80 | let hdim = hid.shape()[1]; 81 | let new_hid = Array::from_shape_fn([top.len(), hdim], |(b, d)| { 82 | let orig = top[b].1; 83 | hid[Dim([orig, d])] 84 | }) 85 | .into_dyn(); 86 | 87 | new_hidden.push(new_hid); 88 | } 89 | self.beams = top.into_iter().map(|(beam, _b)| beam).collect(); 90 | 91 | (new_hidden, new_words) 92 | } 93 | } 94 | 95 | /// Returns width samples from each column from weights. 96 | fn weighted_sample(weights: ArrayView1, width: usize) -> Vec { 97 | let len = weights.shape()[0]; 98 | let unif = Uniform::new(0.0, 1.0); 99 | let mut rng = thread_rng(); 100 | let mut res = HashSet::new(); 101 | 102 | while res.len() < width.min(weights.len()) { 103 | let mut x = unif.sample(&mut rng); 104 | let mut code = 0; 105 | for w in weights.iter().take(len) { 106 | x -= w; 107 | if x > 0.0 { 108 | code += 1; 109 | } else { 110 | break; 111 | } 112 | } 113 | res.insert(code); 114 | } 115 | res.into_iter().collect() 116 | } 117 | -------------------------------------------------------------------------------- /examples/ptb/rnn.rs: -------------------------------------------------------------------------------- 1 | use drug::*; 2 | use ops::*; 3 | use serde::Serialize; 4 | 5 | pub trait RecurrentCell: Serialize { 6 | /// Constructor. 
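    /// Registers the cell's parameters in the graph for the given input and hidden sizes.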
7 | fn new(g: &mut Graph, seq_in_dim: usize, hidden_dim: usize) -> Self; 8 | /// Adds an instance of itself, every instance shares the same parameters 9 | fn add_cell(&self, g: &mut Graph, hidden_in: Idx, seq_in: Idx) -> Idx; 10 | /// The index of hidden 0 11 | fn get_hidden0_idx(&self) -> Idx; 12 | } 13 | 14 | /// Holds stacked RecurrentCells in a graph 15 | #[derive(Serialize, Deserialize)] 16 | pub struct RecurrentLayers { 17 | layers: Vec, 18 | } 19 | impl RecurrentLayers { 20 | pub fn new(g: &mut Graph, dimensions: &[usize]) -> RecurrentLayers { 21 | assert!( 22 | dimensions.len() > 1, 23 | "Need to specify at least 1 input and output layer" 24 | ); 25 | let mut layers = vec![]; 26 | for i in 0..dimensions.len() - 1 { 27 | layers.push(T::new(g, dimensions[i], dimensions[i + 1])); 28 | } 29 | RecurrentLayers { layers } 30 | } 31 | pub fn get_hidden0_idxs(&self) -> Vec { 32 | self.layers.iter().map(|l| l.get_hidden0_idx()).collect() 33 | } 34 | pub fn add_cells(&self, g: &mut Graph, hiddens: &[Idx], seq_in: Idx) -> Vec { 35 | assert_eq!(self.layers.len(), hiddens.len()); 36 | let mut h = seq_in; 37 | let mut new_hiddens = vec![]; 38 | for (l, hid) in self.layers.iter().zip(hiddens.iter()) { 39 | h = l.add_cell(g, *hid, h); 40 | new_hiddens.push(h) 41 | } 42 | new_hiddens 43 | } 44 | } 45 | 46 | /// Basic vanilla RNN 47 | #[derive(Serialize, Deserialize)] 48 | pub struct RNNCell { 49 | hidden0: Idx, 50 | weights: Idx, 51 | } 52 | 53 | impl RecurrentCell for RNNCell { 54 | fn new(g: &mut Graph, seq_in_dim: usize, hidden_dim: usize) -> Self { 55 | RNNCell { 56 | // TODO hidden0 should be Ix2 but we add batch_size dim because im lazy 57 | // ideally there should be an op that stacks hidden0 batch_size times 58 | hidden0: g.param(&[1, hidden_dim]), 59 | weights: g.param(&[hidden_dim + seq_in_dim, hidden_dim]), 60 | } 61 | } 62 | fn add_cell(&self, g: &mut Graph, hidden_in: Idx, seq_in: Idx) -> Idx { 63 | let app = g.op(Append(), &[hidden_in, seq_in]); 64 | let update = g.matmul(self.weights, app); 65 | g.tanh(update) 66 | } 67 | fn get_hidden0_idx(&self) -> Idx { 68 | self.hidden0 69 | } 70 | } 71 | 72 | /// Gated recurrent unit. Computes a feature vector and reset vector at each step given previous 73 | /// state and input. The new state is a convex combination of the previous state and feature vector. 74 | /// This is mediated by the reset vector. 
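/// Concretely, with `r = sigmoid(W_r · [h, x])` and `f = sigmoid(W_f · [h, x])`, the next
/// state is `h' = r * h + (1 - r) * f` (see `ConvexCombine` in `ops.rs`).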
75 | #[derive(Serialize, Deserialize)] 76 | pub struct GatedRecurrentUnit { 77 | hidden0: Idx, 78 | feature: Idx, 79 | resets: Idx, 80 | } 81 | 82 | impl RecurrentCell for GatedRecurrentUnit { 83 | /// Register the params for one gated recurrent unit 84 | fn new(g: &mut Graph, seq_in_dim: usize, hidden_dim: usize) -> Self { 85 | GatedRecurrentUnit { 86 | // TODO hidden0 should be Ix2 but we add batch_size dim because im lazy 87 | // ideally there should be an op that stacks hidden0 batch_size times 88 | hidden0: g.param(&[1, hidden_dim]), 89 | feature: g.param(&[hidden_dim + seq_in_dim, hidden_dim]), 90 | resets: g.param(&[hidden_dim + seq_in_dim, hidden_dim]), 91 | } 92 | } 93 | /// Add an instance of the gated recurrent unit 94 | fn add_cell(&self, g: &mut Graph, hidden_in: Idx, seq_in: Idx) -> Idx { 95 | let app1 = g.op(Append(), &[hidden_in, seq_in]); 96 | 97 | // Extract features Gate 98 | let f_matmul = g.matmul(self.feature, app1); 99 | let feature = g.sigmoid(f_matmul); 100 | 101 | // Reset Gate 102 | let r_matmul = g.matmul(self.resets, app1); 103 | let reset = g.sigmoid(r_matmul); 104 | 105 | // Combine them and get predictions 106 | g.op(ConvexCombine(), &[hidden_in, feature, reset]) 107 | } 108 | fn get_hidden0_idx(&self) -> Idx { 109 | self.hidden0 110 | } 111 | } 112 | -------------------------------------------------------------------------------- /examples/ptb/ops.rs: -------------------------------------------------------------------------------- 1 | use drug::*; 2 | use ndarray::prelude::*; 3 | 4 | #[derive(Debug, Serialize, Deserialize)] 5 | /// Operation that does [x, y, a] -> a * x + (1 - a) * y. This is used in gated recurrent units 6 | /// forget gate. 7 | pub struct ConvexCombine(); 8 | 9 | #[derive(Debug, Serialize, Deserialize)] 10 | /// Operation that takes two batches of vectos xs, ys and appends ys below xs. Supports 11 | /// broadcasting if the batch dimension of xs or ys is 1. 
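/// For example, appending a `[1, hidden_dim]` hidden state to a `[batch_size, emb_dim]`
/// batch of embeddings yields a `[batch_size, hidden_dim + emb_dim]` array, with the
/// columns of `xs` first.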
12 | pub struct Append(); 13 | 14 | impl nodes::Operation for ConvexCombine { 15 | fn eval(&self, inputs: &[ArrayViewD]) -> ArrayD { 16 | assert_eq!(inputs.len(), 3, "Convex combine takes 3 arguments x, y, a"); 17 | 18 | let y = inputs[1].to_owned(); 19 | let a = inputs[2].to_owned(); 20 | let mut x = inputs[0] 21 | .broadcast(y.shape()) 22 | .expect("ConvexCombine: Broadcast Failed") 23 | .to_owned(); 24 | 25 | azip!(mut x, a, y in { *x = a * *x + (1.0 - a) * y}); 26 | x.into_dyn() 27 | } 28 | fn grad(&self, inputs: &[ArrayViewD], loss: ArrayViewD) -> Vec> { 29 | assert_eq!(inputs.len(), 3, "Convex combine takes 3 arguments x, y, a"); 30 | let x = inputs[0].view().into_dimensionality::().unwrap(); 31 | let y = inputs[1].view().into_dimensionality::().unwrap(); 32 | let a = inputs[2].view().into_dimensionality::().unwrap(); 33 | let loss = loss.into_dimensionality::().unwrap(); 34 | 35 | if x.shape() == y.shape() && x.shape() == a.shape() {} 36 | let x_bs = x.shape()[0]; 37 | let y_bs = y.shape()[0]; 38 | let a_bs = a.shape()[0]; 39 | let num_channels = a.shape()[1]; 40 | 41 | let mut a_grad = Array::zeros([a_bs, num_channels]); 42 | let mut x_grad = Array::zeros([x_bs, num_channels]); 43 | let mut y_grad = Array::zeros([y_bs, num_channels]); 44 | 45 | for b in 0..a_bs.max(x_bs).max(y_bs) { 46 | for c in 0..num_channels { 47 | // TODO make this prettier 48 | let ab = if a_bs == 1 { 0 } else { b }; 49 | let xb = if x_bs == 1 { 0 } else { b }; 50 | let yb = if y_bs == 1 { 0 } else { b }; 51 | let ai = a[(ab, c)]; 52 | let xi = x[(xb, c)]; 53 | let yi = y[(yb, c)]; 54 | let li = loss[(b, c)]; 55 | a_grad[(b, c)] += li * (xi - yi); 56 | x_grad[(xb, c)] += li * ai; 57 | y_grad[(yb, c)] += li * (1.0 - ai); 58 | } 59 | } 60 | vec![x_grad.into_dyn(), y_grad.into_dyn(), a_grad.into_dyn()] 61 | } 62 | } 63 | 64 | impl nodes::Operation for Append { 65 | fn eval(&self, inputs: &[ArrayViewD]) -> ArrayD { 66 | let x = inputs[0] 67 | .view() 68 | .into_dimensionality::() 69 | .expect("Append x dim error"); 70 | let y = inputs[1] 71 | .view() 72 | .into_dimensionality::() 73 | .expect("Append y dim error"); 74 | let x_bn = x.shape()[0]; 75 | let y_bn = y.shape()[0]; 76 | assert!( 77 | x_bn == y_bn || y_bn == 1 || x_bn == 1, 78 | "`Append::eval`: `x` and `y` batch sizes do not align and neither is 1." 79 | ); 80 | 81 | let x_len = x.shape()[1]; 82 | let y_len = y.shape()[1]; 83 | 84 | Array::from_shape_fn([x_bn.max(y_bn), x_len + y_len], |(b, i)| { 85 | if i < x_len && x_bn == 1 { 86 | x[(0, i)] 87 | } else if i < x_len { 88 | x[(b, i)] 89 | } else if y_bn == 1 { 90 | y[(0, i - x_len)] 91 | } else { 92 | y[(b, i - x_len)] 93 | } 94 | }) 95 | .into_dyn() 96 | } 97 | fn grad(&self, inputs: &[ArrayViewD], loss: ArrayViewD) -> Vec> { 98 | let x = inputs[0].view().into_dimensionality::().unwrap(); 99 | let y = inputs[1].view().into_dimensionality::().unwrap(); 100 | let loss = loss.into_dimensionality::().unwrap(); 101 | let x_bn = x.shape()[0]; 102 | let y_bn = y.shape()[0]; 103 | assert!( 104 | x_bn == y_bn || y_bn == 1 || x_bn == 1, 105 | "`Append::grad`: `x` and `y` batch sizes do not align and neither is 1." 
106 | ); 107 | let (x_len, y_len) = (x.shape()[1], y.shape()[1]); 108 | 109 | let x_grad = if x_bn == 1 { 110 | loss.sum_axis(Axis(0)) 111 | .slice_move(s![..x_len]) 112 | .insert_axis(Axis(0)) 113 | } else { 114 | Array::from_shape_fn([x_bn, x_len], |(b, xi)| loss[(b, xi)]) 115 | }; 116 | 117 | let y_grad = if y_bn == 1 { 118 | loss.sum_axis(Axis(0)) 119 | .slice_move(s![x_len..]) 120 | .insert_axis(Axis(0)) 121 | } else { 122 | Array::from_shape_fn([y_bn, y_len], |(b, yi)| loss[(b, yi + x_len)]) 123 | }; 124 | 125 | vec![x_grad.into_dyn(), y_grad.into_dyn()] 126 | } 127 | } 128 | -------------------------------------------------------------------------------- /src/optimizers.rs: -------------------------------------------------------------------------------- 1 | //! This module holds the various optimizers used to update parameters in a computation graph. 2 | //! Currently only one is implemented. 3 | use ndarray::{ArrayD, ArrayViewMutD}; 4 | use std::collections::HashMap; 5 | use std::{f32, fmt}; 6 | use Idx; 7 | 8 | #[derive(Debug, Serialize, Deserialize)] 9 | struct OptimizerInstance { 10 | /// Accumulates average gradient. Used in Momentum and Adam 11 | momentum: Option>, 12 | /// Accumulates gradient squared, used in RMSProp and Adam 13 | magnitude: Option>, 14 | // param_magnitude for Adadelta 15 | } 16 | 17 | /// Here is a good [blog that explains various optimizers](http://ruder.io/optimizing-gradient-descent/index.html). 18 | /// Currently only SGD, RMSProp, Adam, and SGD-with-momentum are implemented. 19 | /// The `Optimizer`struct builds and holds `OptimizerInstance`s which 20 | /// hold runtime information about every parameter that's being optimized. 21 | /// If `beta_momentum` or `beta_magnitude` are set to zero, then the optimizer does not keep 22 | /// momentum and magnitude correction information information about parameters. 23 | /// `epsilon` is added to denominators to avoid divide by zero errors. 24 | /// 25 | /// | | no `beta_momentum` |`beta_momentum`| 26 | /// |---|---|---| 27 | /// |**no `beta_magnitude`** |vanilla SGD | SGD with momentum 28 | /// |**`beta_magnitude`** | RMSProp | Adam 29 | #[derive(Debug, Serialize, Deserialize)] 30 | pub struct Optimizer { 31 | pub learning_rate: f32, 32 | pub beta_momentum: f32, 33 | pub beta_magnitude: f32, 34 | pub epsilon: f32, 35 | // QUESTION why keep this instance info inside the optimizer intead of the parameter node? 36 | // * Need to make parameter node its own type for easier accessing 37 | // * Tiny memory impact in forward only "production" graph. 38 | data: HashMap, 39 | } 40 | 41 | impl Default for Optimizer { 42 | fn default() -> Self { 43 | Self::sgd_default() 44 | } 45 | } 46 | impl fmt::Display for Optimizer { 47 | fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { 48 | // Customize so only `x` and `y` are denoted. 49 | write!(f, 50 | "Optimizer {{ learning_rate {:?}, beta_momentum: {:?}, beta_magnitude: {:?}, epsilon: {:?}}}", 51 | self.learning_rate, 52 | self.beta_momentum, 53 | self.beta_magnitude, 54 | self.epsilon, 55 | )?; 56 | Ok(()) 57 | } 58 | } 59 | 60 | impl Optimizer { 61 | pub fn new(learning_rate: f32, beta_momentum: f32, beta_magnitude: f32, epsilon: f32) -> Self { 62 | let data = HashMap::new(); 63 | Optimizer { 64 | learning_rate, 65 | beta_momentum, 66 | beta_magnitude, 67 | epsilon, 68 | data, 69 | } 70 | } 71 | /// Vanilla stochastic gradient descent with no added fluff. 
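    /// Equivalent to `Optimizer::new(1e-3, 0.0, 0.0, 1e-8)`; with both betas at zero no
    /// momentum or magnitude state is allocated and the raw gradient is applied directly.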
72 | pub fn sgd_default() -> Self { 73 | Self::new(1e-3, 0.0, 0.0, 1e-8) 74 | } 75 | /// SGD with a momentum component. Add the geometric average of past gradients to the parameter 76 | /// instead of the gradient itself. This averaging dampens the stochasticity of the stochastic 77 | /// gradient descent. 78 | pub fn momentum_default() -> Self { 79 | Self::new(1e-3, 0.9, 0.0, 1e-8) 80 | } 81 | /// SGD with a magnitude component. Rescale gradients by dividing by the geometric mean of 82 | /// previous gradients squared. Parameters with frequent large gradients will see those 83 | /// gradients shrink while parameters with sparse gradients will have their gradients grow. 84 | pub fn rmsprop_default() -> Self { 85 | Self::new(1e-2, 0.0, 0.9, 1e-8) 86 | } 87 | /// Adam (Adaptive Moment Estimation) Combines the momentum component from `momentum` and the 88 | /// magnitude component from `rmsprop`. 89 | pub fn adam_default() -> Self { 90 | Self::new(1e-2, 0.9, 0.999, 1e-8) 91 | } 92 | pub fn register(&mut self, i: Idx, shape: &[usize]) { 93 | let momentum = if self.beta_momentum > f32::EPSILON { 94 | Some(ArrayD::zeros(shape)) 95 | } else { 96 | None 97 | }; 98 | let magnitude = if self.beta_magnitude > f32::EPSILON { 99 | Some(ArrayD::zeros(shape)) 100 | } else { 101 | None 102 | }; 103 | let instance = OptimizerInstance { 104 | momentum, 105 | magnitude, 106 | }; 107 | self.data.insert(i, instance); 108 | } 109 | /// Apply gradient 110 | pub fn apply_gradient(&mut self, i: Idx, mut param: ArrayViewMutD, grad: &ArrayD) { 111 | let optimizer_instance = self 112 | .data 113 | .get_mut(&i) 114 | .expect("Attempted to apply gradient to unregistered parameter"); 115 | 116 | let mut delta = if let Some(ref mut mom) = optimizer_instance.momentum { 117 | let beta1 = self.beta_momentum; 118 | mom.zip_mut_with(&grad, |m, g| *m = (1.0 - beta1) * *g + beta1 * *m); 119 | mom.to_owned() / (1.0 - self.beta_momentum) 120 | } else { 121 | grad.to_owned() 122 | }; 123 | 124 | if let Some(ref mut mag) = optimizer_instance.magnitude { 125 | let beta2 = self.beta_magnitude; 126 | mag.zip_mut_with(&grad, |m, g| *m = (1.0 - beta2) * g.powi(2) + beta2 * *m); 127 | let e = self.epsilon; 128 | delta.zip_mut_with(mag, |d, m| *d /= (m / (1.0 - beta2)).sqrt() + e); 129 | } 130 | 131 | let lr = self.learning_rate; 132 | param.zip_mut_with(&delta, |p, d| *p += d * lr); 133 | } 134 | } 135 | -------------------------------------------------------------------------------- /src/lib.rs: -------------------------------------------------------------------------------- 1 | //! # ∂rug - Differentiable Rust Graph 2 | //! 3 | //! This crate is a collection of utilities to build build neural networks (differentiable 4 | //! programs). See [examples](https://github.com/CasperN/drug/tree/master/examples) 5 | //! for implementations of canonical neural networks. You may need to download those datasets 6 | //! yourself to use them. Examples include: 7 | //! * Mnist with dense networks 8 | //! * Mnist with convolutional neural networks (though embarassingly slowly) 9 | //! * Penn TreeBank character prediction with RNN and GRU 10 | //! 11 | //! ### Planned Future Features 12 | //! * Higher level API 13 | //! * Building complexes of nodes (conv + bias + relu) / RNN cells, with parameter reuse 14 | //! * Subgraphs / updating subsets of graphs (e.g. for GAN) with separate optimizers 15 | //! * Parallel backprop multiple arguments of 1 node 16 | //! * ndarray-parallel or OpenMPI for graph replication and parallelization 17 | //! 
* Link to some optimized OpenCL maths backend for GPU utilization 18 | //! 19 | //! Reinforcement learning applications may also challenge the archiecture but I don't understand 20 | //! the process well enough yet to consider adding it to the library. 21 | //! 22 | //! ### Wish list 23 | //! * Operator overloading API + Taking advantage of the type system and const generics 24 | //! * May require total overhaul.. or may be possible with a "Graph Cursor" trait and more 25 | //! sophisticaed handles beyond current Idxs 26 | //! * Automatic differentiation of operations defined only from loops (proc macros?) 27 | //! * Taking advantage of just in time compilation and fusion of operations / kernels 28 | //! * Other kinds of derivatives e.g. jacobian 29 | // TODO Consider greping for unwrap, expect, panic, assert, and replacing signatures with results. 30 | 31 | #![feature(test)] 32 | #[macro_use] 33 | pub extern crate ndarray; 34 | extern crate rand; 35 | extern crate test; 36 | #[macro_use] 37 | extern crate debug_stub_derive; 38 | #[macro_use] 39 | extern crate serde_derive; 40 | #[macro_use] 41 | extern crate erased_serde; 42 | extern crate serde; 43 | 44 | #[cfg(test)] 45 | #[macro_use(iproduct)] 46 | extern crate itertools; 47 | 48 | use ndarray::prelude::*; 49 | use rand::distributions::{Distribution, Normal}; 50 | use rand::thread_rng; 51 | 52 | mod graph; 53 | pub mod nodes; 54 | mod optimizers; 55 | 56 | pub use graph::*; 57 | pub use nodes::{GlobalPool, Operation, Padding}; 58 | pub use optimizers::Optimizer; 59 | 60 | // TODO initializers file, maybe initializers enum so they're serializable 61 | /// The default (and only provided) initializer. Only works with convolution kernels and matrices. 62 | pub fn xavier_initialize(shape: &[usize]) -> ArrayD { 63 | // let len: usize = shape.iter().product(); 64 | let (n_in, n_out) = match shape.len() { 65 | 4 => (shape[2], shape[3]), // Convolution kernel 66 | 2 => (shape[0], shape[1]), // Matrix 67 | 1 => (shape[0], shape[0]), // Vector 68 | x => unimplemented!("Initialize with {:?}", x), 69 | }; 70 | let var = 2.0 / (n_in as f64 + n_out as f64); 71 | let normal = Normal::new(0.0, var.sqrt()); 72 | let mut rng = thread_rng(); 73 | ArrayD::from_shape_fn(shape, |_| normal.sample(&mut rng) as f32) 74 | } 75 | 76 | /// Take the softmax of an array of shape `batch_size * num_classes` 77 | pub fn softmax(logits: &ArrayD) -> Array2 { 78 | let mut softmax = logits.to_owned().into_dimensionality::().unwrap(); 79 | // Calculate softmax 80 | let max = softmax.fold_axis(Axis(1), 0.0, |x, y| if *x > *y { *x } else { *y }); 81 | for ((b, _), x) in softmax.indexed_iter_mut() { 82 | *x = (*x - max[b]).exp(); 83 | } 84 | let sum = softmax.sum_axis(Axis(1)); 85 | for ((b, _), x) in softmax.indexed_iter_mut() { 86 | *x /= sum[b]; 87 | } 88 | softmax 89 | } 90 | 91 | /// A loss function used for classification. 92 | /// 93 | /// `logits` are a `batch_size * num_classes` array of values which will be compressed into the 94 | /// `[0,1]` range by a softmax operation. Given the correct categories `labels`, this function will 95 | /// calculate the negative log-probability of the logits and its gradient with respect to the logits. 
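/// Concretely, `loss = -mean_b ln(softmax(logits)[b, labels[b]])`, and the returned
/// gradient is `softmax(logits)` with 1.0 subtracted at each correct class, i.e.
/// `softmax(logits) - one_hot(labels)`.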
96 | pub fn softmax_cross_entropy_loss(logits: &ArrayD, labels: &[usize]) -> (f32, ArrayD) { 97 | let mut softmax = softmax(logits); 98 | let mut log_loss = 0.0; 99 | // Turn softmax into gradient and add up log_loss 100 | for (b, lbl) in labels.iter().enumerate() { 101 | let correct = *lbl; 102 | log_loss -= softmax[(b, correct)].ln(); 103 | softmax[(b, correct)] -= 1.0; 104 | } 105 | log_loss /= labels.len() as f32; 106 | 107 | (log_loss, softmax.into_dyn()) 108 | } 109 | 110 | #[cfg(test)] 111 | mod libc { 112 | use super::*; 113 | use graph::Graph; 114 | use std::f32; 115 | #[test] 116 | fn param_initialize() { 117 | let mut g = Graph::default(); 118 | let x = g.param(&[3, 3, 1, 8]); 119 | assert_eq!(g.get_value(x).shape(), [3, 3, 1, 8]); 120 | } 121 | #[test] 122 | fn softmax_vs_correct() { 123 | let logits = arr2(&[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]); 124 | let correct = arr2(&[ 125 | [ 126 | 9.003057317038046e-2, 127 | 0.24472847105479767, 128 | 0.6652409557748219, 129 | ], 130 | [ 131 | 9.003057317038045e-2, 132 | 0.24472847105479764, 133 | 0.6652409557748219, 134 | ], 135 | ]); 136 | let softmax = softmax(&logits.into_dyn()); 137 | for i in 0..2 { 138 | for j in 0..3 { 139 | assert!((softmax[(i, j)] - correct[(i, j)]).abs() < f32::EPSILON); 140 | } 141 | } 142 | } 143 | } 144 | -------------------------------------------------------------------------------- /examples/mnist/main.rs: -------------------------------------------------------------------------------- 1 | //! Example dense net classifier with MNIST. 2 | extern crate byteorder; 3 | extern crate drug; 4 | extern crate ndarray; 5 | extern crate ron; 6 | mod input; 7 | 8 | use std::f32; 9 | use std::fs::{create_dir_all, File}; 10 | use std::io::Write; 11 | use std::path::Path; 12 | 13 | use drug::*; 14 | use input::{images, labels}; 15 | use ndarray::prelude::*; 16 | 17 | static MODEL_DIR: &'static str = "/tmp/drug/mnist/"; 18 | static DATA: &'static str = "examples/data/"; 19 | static TR_IMG: &'static str = "train-images-idx3-ubyte"; 20 | static TR_LBL: &'static str = "train-labels-idx1-ubyte"; 21 | static TS_IMG: &'static str = "t10k-images-idx3-ubyte"; 22 | static TS_LBL: &'static str = "t10k-labels-idx1-ubyte"; 23 | static TR_LEN: u32 = 60_000; 24 | static TS_LEN: u32 = 10_000; 25 | static ROWS: usize = 28; 26 | static COLS: usize = 28; 27 | 28 | // TODO: Replace with chunker, shuffle and batchify function for multiple epochs. 
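// Chunks the flat pixel buffer into one array per batch, shaped [batch_size, 28 * 28]
// for the dense network or [batch_size, 28, 28] for the convolutional one, and returns
// the batches as a boxed iterator suitable for the graph's input node.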
29 | fn reshape_and_iter( 30 | data: &[f32], // The mnist data read from file 31 | batch_size: usize, // how many mnist examples to train with in one forward / backward pass 32 | as_vectors: bool, // output vectors for a dense network instead of images for convolutional 33 | ) -> Box>> { 34 | let shape = if as_vectors { 35 | vec![batch_size, ROWS * COLS] 36 | } else { 37 | vec![batch_size, ROWS, COLS] 38 | }; 39 | let batched: Vec> = data 40 | .chunks_exact(batch_size * ROWS * COLS) 41 | .map(move |x| Array::from_shape_vec(shape.as_slice(), x.to_vec()).unwrap()) 42 | .collect(); 43 | 44 | Box::new(batched.into_iter()) 45 | } 46 | 47 | /// Simple 3 layer neural network 48 | fn dense_network(g: &mut Graph, imgs: Idx) -> Idx { 49 | let weights_1 = g.param(&[ROWS * COLS, 110]); 50 | let weights_2 = g.param(&[110, 10]); 51 | let mat_mul_1 = g.matmul(weights_1, imgs); 52 | let sigmoid = g.sigmoid(mat_mul_1); 53 | g.matmul(weights_2, sigmoid) 54 | } 55 | 56 | /// 3 layer Convolutional neural network 57 | fn conv_network(g: &mut Graph, imgs: Idx) -> Idx { 58 | let conv_block = |g: &mut Graph, in_idx, in_channels, out_channels| { 59 | // Repeating block of our cnn 60 | let kernel = g.param(&[3, 3, in_channels, out_channels]); 61 | let conv = g.conv(kernel, in_idx, Padding::Same, 1); 62 | g.relu(conv) 63 | }; 64 | 65 | let b1 = conv_block(g, imgs, 1, 8); 66 | let b2 = conv_block(g, b1, 8, 16); 67 | let b3 = conv_block(g, b2, 16, 32); 68 | 69 | let kernel_1x1 = g.param(&[1, 1, 32, 10]); 70 | let conv_1x1 = g.conv(kernel_1x1, b3, Padding::Same, 1); 71 | 72 | g.global_pool(conv_1x1, GlobalPool::Average) 73 | } 74 | 75 | /// this is main 76 | fn main() { 77 | let learning_rate = 0.05; 78 | let batch_size = 8; 79 | let train_steps = TR_LEN as usize / batch_size; 80 | let use_dense = true; 81 | let summary_every = 500; 82 | 83 | println!("Reading data...",); 84 | let train_images = images(&Path::new(DATA).join(TR_IMG), TR_LEN); 85 | let train_labels = labels(&Path::new(DATA).join(TR_LBL), TR_LEN); 86 | let test_images = images(&Path::new(DATA).join(TS_IMG), TS_LEN); 87 | let test_labels = labels(&Path::new(DATA).join(TS_LBL), TS_LEN); 88 | 89 | let train_images = reshape_and_iter(&train_images, batch_size, use_dense); 90 | 91 | let (mut g, imgs, out) = load_model().unwrap_or_else(|e| { 92 | println!("Couldn't load graph because `{:?}`", e); 93 | println!("Building new graph..."); 94 | 95 | let mut g = Graph::default(); 96 | let imgs = g.input(None); // Set the iterator later 97 | 98 | let out = if use_dense { 99 | dense_network(&mut g, imgs) 100 | } else { 101 | conv_network(&mut g, imgs) 102 | }; 103 | 104 | // Save input and output idxs for the model 105 | g.named_idxs.insert("imgs".to_string(), imgs); 106 | g.named_idxs.insert("out".to_string(), out); 107 | (g, imgs, out) 108 | }); 109 | 110 | g.optimizer.learning_rate = learning_rate; 111 | 112 | println!("{}", g); 113 | 114 | g.replace_input_iterator(imgs, train_images).unwrap(); 115 | println!("Training..."); 116 | for step in 0..train_steps { 117 | g.forward(); 118 | 119 | let labels = &train_labels[step * batch_size..(step + 1) * batch_size]; 120 | let (loss, grad) = softmax_cross_entropy_loss(g.get_value(out), labels); 121 | 122 | g.set_loss(out, -grad); 123 | g.backward(); 124 | 125 | if step % summary_every == 0 { 126 | println!(" Step: {:?}\t log loss: {:?}", step, loss); 127 | } 128 | } 129 | // old input node exhausted, refresh with test images 130 | let test_images = reshape_and_iter(&test_images, batch_size, use_dense); 131 | 
g.replace_input_iterator(imgs, test_images).unwrap(); 132 | 133 | let test_steps = TS_LEN as usize / batch_size; 134 | let mut num_correct = 0; 135 | 136 | println!("Testing..."); 137 | for step in 0..test_steps { 138 | g.forward(); 139 | let labels = &test_labels[step * batch_size..(step + 1) * batch_size]; 140 | num_correct += count_correct(&g.get_value(out), labels); 141 | } 142 | println!( 143 | " Test accuracy: {:?}%", 144 | 100.0 * num_correct as f32 / TS_LEN as f32 145 | ); 146 | 147 | save_model(&g).expect("Error saving"); 148 | } 149 | 150 | fn save_model(g: &Graph) -> Result<(), Box> { 151 | create_dir_all(MODEL_DIR)?; 152 | let model_path = Path::new(MODEL_DIR).join("model.bin"); 153 | let mut f = File::create(&model_path)?; 154 | let gs = ron::ser::to_string(&g)?; 155 | f.write_all(gs.as_bytes())?; 156 | Ok(()) 157 | } 158 | 159 | fn load_model() -> Result<(Graph, Idx, Idx), Box> { 160 | let model_path = Path::new(MODEL_DIR).join("model.bin"); 161 | let f = File::open(&model_path)?; 162 | let g: Graph = ron::de::from_reader(&f)?; 163 | let imgs = *g 164 | .named_idxs 165 | .get("imgs") 166 | .expect("Expected named index `imgs`."); 167 | let out = *g 168 | .named_idxs 169 | .get("out") 170 | .expect("Expected named index `out`."); 171 | println!("Loaded saved model from {:?}", model_path); 172 | Ok((g, imgs, out)) 173 | } 174 | 175 | fn count_correct(logits: &ArrayD, labels: &[usize]) -> u32 { 176 | let logits = logits.to_owned().into_dimensionality::().unwrap(); 177 | let batch_size = labels.len(); 178 | let mut num_correct = 0; 179 | for b in 0..batch_size { 180 | let mut max = f32::MIN; 181 | let mut max_idx = 0; 182 | for i in 0..10 { 183 | if logits[(b, i)] > max { 184 | max = logits[(b, i)]; 185 | max_idx = i; 186 | } 187 | } 188 | if max_idx == labels[b] as usize { 189 | num_correct += 1; 190 | } 191 | } 192 | num_correct 193 | } 194 | // 195 | -------------------------------------------------------------------------------- /src/nodes/mod.rs: -------------------------------------------------------------------------------- 1 | //! This module holds the different types nodes that exist in a computation graph. Nodes that 2 | //! represent a differentiable computation are implemented by a struct with the "Operation" trait. 3 | //! Use [Graph](../struct.Graph.html) methods to create and register nodes inside a graph. 4 | //! See [Node](enum.Node.html) for the types of node available. 5 | //! This module may eventually be made private... 6 | 7 | pub use self::activation::*; 8 | pub use self::arithmetic::{Add, Mult}; 9 | pub use self::conv::Conv; 10 | pub use self::conv::Padding; 11 | pub use self::embedding::Embedding; 12 | pub use self::global_pool::GlobalPool; 13 | pub use self::matmul::MatMul; 14 | 15 | use graph::Idx; 16 | use ndarray::prelude::*; 17 | use std::fmt::Debug; 18 | mod activation; 19 | mod arithmetic; 20 | mod conv; 21 | mod embedding; 22 | mod global_pool; 23 | mod matmul; 24 | 25 | /// Represents a differentiable function in a computation graph. 26 | /// Operations hold their own hyperparameters but not their parameters, values or losses. 27 | /// Unfortunately boxed traits cannot be saved with serde. When reloaded they will be replaced 28 | /// by `Box` nodes. When reloading a model with custom Operations, you need to 29 | /// replace them manually. 30 | pub trait Operation: Debug { 31 | /// Mutates Outputs based on inputs. 
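    /// `eval` receives read-only views of every input value and returns a newly
    /// allocated output array.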
32 | /// TODO consider modifying output ArrayD in place 33 | fn eval(&self, inputs: &[ArrayViewD]) -> ArrayD; 34 | // fn eval(&self, inputs: Box<[ArrayViewD]>) -> ArrayD; 35 | 36 | /// Returns gradients of inputs wrt outputs. 37 | /// Note the inputs and output vectors should be the same length. 38 | /// TODO consider modifying output ArrayDs in place 39 | fn grad(&self, inputs: &[ArrayViewD], loss: ArrayViewD) -> Vec>; 40 | } 41 | serialize_trait_object!(Operation); 42 | 43 | #[derive(DebugStub, Serialize, Deserialize)] 44 | /// Nodes are the building blocks of the [computation graph](../struct.Graph.html). 45 | /// The variants of a node differ in how the value is produced and how loss is propagated back. 46 | /// Users typically interact with Nodes with their index `:Idx` which is returned by the graph 47 | /// when registered / created. 48 | pub enum Node { 49 | /// Produce Value from beyond the graph. 50 | /// * In a forward pass, its value is updates by the iterator or panics if its None 51 | /// * In a backward pass, its losses are currently calculated but unused. 52 | /// * When serializing, the internal iterator is ignored. It deserializes to None. 53 | Input { 54 | #[serde(skip)] 55 | #[debug_stub = "Option>>>"] 56 | it: Option>>>, 57 | }, 58 | 59 | /// Parameter nodes only hold a shape. Its values are initialized when inserted into the graph 60 | /// using the graph's initializer. 61 | /// * In a foward pass, parameters are ignored. 62 | /// * In a backward pass, their losses are applied by the graph's optimizer. 63 | Parameter(Box<[usize]>), 64 | /// See [Conv](struct.Conv.html) for more. 65 | Conv { kernel: Idx, img: Idx, conv: Conv }, 66 | /// See [Add](struct.Add.html) for more. 67 | Add { xs: Vec }, 68 | /// See [Mult](struct.Mult.html) for more. 69 | Mult { xs: Vec }, 70 | /// See [Matmul](struct.Matmul.html) for more. 71 | MatMul { mat: Idx, v: Idx }, 72 | /// See [Activation](enum.Activation.html) for more. 73 | Activation { x: Idx, a: Activation }, 74 | /// See [Embedding](struct.Embedding.html) for more. 75 | Embedding { emb: Idx, code: Idx }, 76 | /// See [GlobalPool](struct.GlobalPool.html) for more. 77 | GlobalPool { pool: GlobalPool, x: Idx }, 78 | /// An Operation node holds an [Operation trait object](trait.Operation.html) and the indices 79 | /// referring to its input values. 80 | /// * In a forward pass, its value is updated by the `operation` and the values indexed by 81 | /// `inputs`. 82 | /// * In a backward pass, gradients are calculated and losses are propagated backwards and added 83 | /// to the losses indexed by `inputs`. 84 | Operation { 85 | inputs: Box<[Idx]>, 86 | #[serde(skip_deserializing)] 87 | operation: Box, 88 | }, 89 | 90 | /// Ignored by the graph, you have to set the values yourself 91 | Constant, 92 | } 93 | 94 | impl Node { 95 | pub fn inputs(&self) -> Vec { 96 | match self { 97 | Node::Conv { kernel, img, .. } => vec![*kernel, *img], 98 | Node::Add { xs } => xs.to_vec(), 99 | Node::Mult { xs } => xs.to_vec(), 100 | Node::MatMul { mat, v } => vec![*mat, *v], 101 | Node::Activation { x, .. } => vec![*x], 102 | Node::Embedding { emb, code } => vec![*emb, *code], 103 | Node::GlobalPool { x, .. } => vec![*x], 104 | Node::Operation { inputs, .. } => inputs.to_vec(), 105 | Node::Input { .. } | Node::Parameter(..) | Node::Constant => vec![], 106 | } 107 | } 108 | pub fn forward(&mut self, inputs: &[ArrayViewD]) -> Option> { 109 | match self { 110 | Node::Conv { conv, .. } => Some(conv.eval(inputs)), 111 | Node::Add { .. 
} => Some(Add().eval(inputs)), 112 | Node::Mult { .. } => Some(Mult().eval(inputs)), 113 | Node::MatMul { .. } => Some(MatMul().eval(inputs)), 114 | Node::Activation { a, .. } => Some(a.eval(inputs)), 115 | Node::Embedding { .. } => Some(Embedding().eval(inputs)), 116 | Node::GlobalPool { pool, .. } => Some(pool.eval(inputs)), 117 | Node::Operation { operation, .. } => Some(operation.eval(inputs)), 118 | Node::Input { ref mut it } => it.as_mut().expect("Input node uninitialized.").next(), 119 | Node::Parameter(..) | Node::Constant => None, 120 | } 121 | } 122 | pub fn backward(&self, inputs: &[ArrayViewD], loss: &ArrayD) -> Vec> { 123 | match self { 124 | Node::Conv { conv, .. } => conv.grad(inputs, loss.view()), 125 | Node::Add { .. } => Add().grad(inputs, loss.view()), 126 | Node::Mult { .. } => Mult().grad(inputs, loss.view()), 127 | Node::MatMul { .. } => MatMul().grad(inputs, loss.view()), 128 | Node::Activation { a, .. } => a.grad(inputs, loss.view()), 129 | Node::Embedding { .. } => Embedding().grad(inputs, loss.view()), 130 | Node::GlobalPool { pool, .. } => pool.grad(inputs, loss.view()), 131 | Node::Operation { operation, .. } => operation.grad(inputs, loss.view()), 132 | Node::Input { .. } | Node::Constant | Node::Parameter(..) => vec![], 133 | } 134 | } 135 | } 136 | 137 | // TODO figure out serialization and deserialization of Boxed traits. This may not be possible :/ 138 | impl Default for Box { 139 | fn default() -> Self { 140 | Box::new(arithmetic::Add()) 141 | } 142 | } 143 | -------------------------------------------------------------------------------- /examples/ptb/main.rs: -------------------------------------------------------------------------------- 1 | #![allow(clippy::deref_addrof)] 2 | #[macro_use] 3 | extern crate ndarray; 4 | extern crate drug; 5 | extern crate rand; 6 | #[macro_use] 7 | extern crate serde_derive; 8 | extern crate ron; 9 | extern crate serde; 10 | 11 | use drug::*; 12 | use ndarray::prelude::*; 13 | use std::fs::{create_dir_all, File}; 14 | use std::io::Write; 15 | use std::path::Path; 16 | mod beam_search; 17 | mod ops; 18 | mod rnn; 19 | mod text_dataset; 20 | 21 | use beam_search::BeamSearch; 22 | use rnn::*; 23 | use text_dataset::TextDataSet; 24 | 25 | static MODEL_DIR: &'static str = "/tmp/drug/ptb/"; 26 | 27 | /// Adds batches of words to the graph by registering constants and passing the coded words through 28 | /// an embedding vector. 
Despite being categorical, word_batch is a vector of positive integers 29 | /// because the graph only holds ArrayD 30 | #[derive(Serialize, Deserialize)] 31 | struct Embedding(Idx); 32 | impl Embedding { 33 | fn new(g: &mut Graph, embedding_len: usize, embedding_dim: usize) -> Self { 34 | Embedding(g.param(&[embedding_len, embedding_dim])) 35 | } 36 | // Add batch to graph and return Idx of its embedding 37 | fn add_word(&self, g: &mut Graph, word_batch: &ArrayD) -> Idx { 38 | let word = g.constant(word_batch.to_owned()); 39 | g.embedding(self.0, word) 40 | } 41 | } 42 | 43 | #[derive(Serialize, Deserialize)] 44 | struct Predict(Idx); 45 | impl Predict { 46 | fn new(g: &mut Graph, hidden_dim: usize, pred_len: usize) -> Self { 47 | Predict(g.param(&[hidden_dim, pred_len])) 48 | } 49 | fn predict(&self, g: &mut Graph, hidden: Idx) -> Idx { 50 | g.matmul(self.0, hidden) 51 | } 52 | } 53 | 54 | fn save_model( 55 | g: &Graph, 56 | r: &RecurrentLayers, 57 | ) -> Result<(), Box> { 58 | create_dir_all(MODEL_DIR)?; 59 | let model_path = Path::new(MODEL_DIR).join("computation_graph.bin"); 60 | let mut f = File::create(&model_path)?; 61 | let gs = ron::ser::to_string(&g)?; 62 | f.write_all(gs.as_bytes())?; 63 | 64 | let rl_path = Path::new(MODEL_DIR).join("recurrent_layers.bin"); 65 | let mut f = File::create(&rl_path)?; 66 | let rs = ron::ser::to_string(&r)?; 67 | f.write_all(rs.as_bytes())?; 68 | 69 | Ok(()) 70 | } 71 | 72 | type StuffToRestore = (Graph, Embedding, Predict, RecurrentLayers); 73 | fn load_model() -> Result, Box> 74 | where 75 | T: RecurrentCell + serde::de::DeserializeOwned, 76 | { 77 | let model_path = Path::new(MODEL_DIR).join("computation_graph.bin"); 78 | let f = File::open(&model_path)?; 79 | let g: Graph = ron::de::from_reader(f)?; 80 | 81 | let rl_path = Path::new(MODEL_DIR).join("recurrent_layers.bin"); 82 | let f = File::open(&rl_path)?; 83 | let rl: RecurrentLayers = ron::de::from_reader(&f)?; 84 | 85 | let emb_idx = *g 86 | .named_idxs 87 | .get("embedding") 88 | .expect("Expected named index `embedding`"); 89 | let embedding = Embedding(emb_idx); 90 | 91 | let prd_idx = *g 92 | .named_idxs 93 | .get("predict") 94 | .expect("Expected named index `predict`"); 95 | let predict = Predict(prd_idx); 96 | 97 | println!("Loaded saved model"); 98 | Ok((g, embedding, predict, rl)) 99 | } 100 | 101 | /// Architecture Epoch 5 Train Perplexity 102 | /// -------------------- ------------------------ 103 | /// GRU [30, 30, 30] 5.35 104 | /// GRU [30, 30, 30, 30] 5.09 105 | /// GRU [50, 50, 50] 4.69 106 | /// GRU [50, 100, 100] 4.16 107 | /// GRU [50, 250, 250] 3.86 - 3.74 (10 epochs) 108 | fn main() { 109 | // dimensions[0] is embedding dimension, the rest are size of hidden dim in each layer 110 | let dimensions = vec![50, 50, 50]; 111 | let batch_size = 32; 112 | let sequence_len = 50; 113 | // Note the effective learning_rate is this * batch_size * sequence_len 114 | let learning_rate = 0.01 as f32; 115 | let summary_every = 250; 116 | let num_epochs = 1; 117 | 118 | println!("Reading dataset...",); 119 | let train = TextDataSet::new(batch_size, sequence_len); 120 | let num_symbols = train.idx2char.len(); 121 | println!(" Batch size {:?}", batch_size); 122 | println!(" Sequence len {:?}", sequence_len); 123 | println!(" Number of symbols: {:?}", num_symbols); 124 | println!(" Number of sequences: {:?}\n", train.corpus.len()); 125 | 126 | let (mut g, embedding, predict, rnn) = load_model().unwrap_or_else(|_| { 127 | println!("Defining new model"); 128 | let mut g = Graph::default(); 
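        // A default graph ships with the Xavier initializer and plain SGD; only the
        // learning rate is overridden below.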
129 | g.optimizer.learning_rate = learning_rate; 130 | 131 | // These structs hold Idx pointing to their parameters and have methods adding operations 132 | // to the graph. 133 | let embedding = Embedding::new(&mut g, num_symbols, dimensions[0]); 134 | let predict = Predict::new(&mut g, *dimensions.last().unwrap(), num_symbols); 135 | let rnn = RecurrentLayers::::new(&mut g, &dimensions); 136 | 137 | g.named_idxs.insert("embedding".to_string(), embedding.0); 138 | g.named_idxs.insert("predict".to_string(), predict.0); 139 | (g, embedding, predict, rnn) 140 | }); 141 | 142 | println!("Training..."); 143 | let mut total_loss = 0.0; 144 | let mut seen = 0; 145 | for epoch in 0..num_epochs { 146 | for (step, sequence) in train.corpus.iter().enumerate() { 147 | let mut hiddens = rnn.get_hidden0_idxs(); 148 | let mut output = vec![]; 149 | 150 | // Build RNN sequence dynamically based on the length of the sequence. 151 | for (i, word_batch) in sequence.iter().enumerate() { 152 | // Skip predicting first word because batch size incompatible. (FIXME) 153 | let pred = if i > 0 { 154 | Some(predict.predict(&mut g, *hiddens.last().unwrap())) 155 | } else { 156 | None 157 | }; 158 | let emb = embedding.add_word(&mut g, &word_batch); 159 | hiddens = rnn.add_cells(&mut g, &hiddens, emb); 160 | 161 | output.push((pred, word_batch)); 162 | } 163 | g.forward(); 164 | // Check 1 step predictions and compute loss 165 | for (pred, correct) in output.into_iter() { 166 | let correct: Vec = correct.iter().map(|x| *x as usize).collect(); 167 | 168 | if let Some(pred) = pred { 169 | let (loss, grad) = 170 | softmax_cross_entropy_loss(g.get_value(pred), correct.as_slice()); 171 | total_loss += loss; 172 | g.set_loss(pred, -grad) 173 | } 174 | } 175 | g.backward(); 176 | g.clear_non_parameters(); 177 | seen += sequence.len(); 178 | 179 | if step % summary_every == 0 { 180 | total_loss /= seen as f32 * batch_size as f32; 181 | println!( 182 | "Epoch: {:?} of {:?}\t Step: {:5} of {:?}\t Perplexity: {:2.2}", 183 | epoch, 184 | num_epochs, 185 | step, 186 | train.corpus.len(), 187 | total_loss.exp() 188 | ); 189 | total_loss = 0.0; 190 | seen = 0; 191 | } 192 | } 193 | } 194 | 195 | save_model(&g, &rnn).expect("Saving Error"); 196 | 197 | // BUG forward pass will fail if beam width > num characters 198 | let beam_width = 30; 199 | let gen_len = 80; 200 | 201 | for temp in [1.0, 0.9, 0.8, 0.7].into_iter() { 202 | println!("\nGenerating with temp {:?}...", temp); 203 | 204 | let mut beam_search = BeamSearch::new(beam_width); 205 | let mut hiddens = vec![]; 206 | 207 | for h in rnn.get_hidden0_idxs().iter() { 208 | let mean_h0 = g.get_value(*h).mean_axis(Axis(0)); 209 | let h_dim = mean_h0.shape()[0]; 210 | let hidden = 211 | Array::from_shape_fn([beam_width, h_dim], |(_b, h)| mean_h0[(h)]).into_dyn(); 212 | hiddens.push(hidden); 213 | } 214 | 215 | for _ in 0..gen_len { 216 | // predict next characters based on hidden state 217 | let h = g.constant(hiddens.last().unwrap().to_owned()); 218 | let pred_idx = predict.predict(&mut g, h); 219 | g.forward1(pred_idx); 220 | let next_word_logits = g.get_value(pred_idx).to_owned(); 221 | 222 | // Consider next hidden state and words based on probability of sequence 223 | let (hiddens, words) = beam_search.search(&hiddens, next_word_logits.view(), *temp); 224 | let mut hidden_idxs = vec![]; 225 | for h in hiddens.into_iter() { 226 | hidden_idxs.push(g.constant(h)); 227 | } 228 | // Update hidden state 229 | let emb = embedding.add_word(&mut g, &words); 230 | let hidden_idxs = 
rnn.add_cells(&mut g, &hidden_idxs, emb); 231 | g.forward(); 232 | 233 | // Take it out of the graph 234 | let mut hiddens = vec![]; 235 | for hi in hidden_idxs.into_iter() { 236 | hiddens.push(g.get_value(hi).to_owned()); 237 | } 238 | g.clear_non_parameters(); 239 | } 240 | 241 | let res = beam_search.into_codes(); 242 | for s in res.iter() { 243 | println!("{:?}", train.decode(s)); 244 | } 245 | } 246 | } 247 | -------------------------------------------------------------------------------- /src/graph.rs: -------------------------------------------------------------------------------- 1 | use ndarray::{Array, ArrayD, ArrayViewD}; 2 | use nodes::*; 3 | use std::collections::BTreeMap; 4 | use std::fmt; 5 | 6 | use optimizers::Optimizer; 7 | use xavier_initialize; 8 | 9 | /// A placeholder to help index into a graph. These should not be interchanged between graphs. 10 | #[derive(Debug, Clone, Copy, Serialize, Deserialize, Hash, PartialEq, Eq)] 11 | pub struct Idx { 12 | idx: usize, 13 | } 14 | 15 | /// A differentiable computation graph. Use this struct to hold your differentiable program 16 | /// which is a directed acyclic graph of Nodes, their associated values 17 | /// and losses (gradients). The graph computes values moving forward in insertion order (see 18 | /// `forward` method) and propagates losses backwards in reverse insertion order (see `backward` 19 | /// method). The default graph comes with an xavier initializer and a vanilla stochastic gradient 20 | /// descent optimizer. 21 | // QUESTION why not Option, Option?, even Option 22 | // TODO losses should be renamed gradients to be more descriptive 23 | #[derive(DebugStub, Serialize, Deserialize)] 24 | pub struct Graph { 25 | nodes: BTreeMap, 26 | values: BTreeMap>, 27 | losses: BTreeMap>, 28 | num_inserted: usize, 29 | #[debug_stub = "Initializer function"] 30 | #[serde(skip)] 31 | initializer: Initializer, 32 | pub optimizer: Optimizer, 33 | pub named_idxs: BTreeMap, 34 | } 35 | 36 | struct Initializer(Box<(Fn(&[usize]) -> ArrayD)>); 37 | 38 | impl Default for Initializer { 39 | fn default() -> Self { 40 | Initializer(Box::new(xavier_initialize)) 41 | } 42 | } 43 | 44 | impl fmt::Display for Graph { 45 | fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { 46 | // Customize so only `x` and `y` are denoted. 47 | writeln!(f, "Computation Graph with Optimizer:\n\t{}", self.optimizer)?; 48 | for (i, node) in self.nodes.iter() { 49 | writeln!( 50 | f, 51 | "\n{}\t{:?}\n\tvalue shape: {:?}\tloss shape: {:?}", 52 | i, 53 | node, 54 | self.values[&i].shape(), 55 | self.losses[&i].shape(), 56 | )? 57 | } 58 | Ok(()) 59 | } 60 | } 61 | 62 | // Shape information? 63 | impl Default for Graph { 64 | /// xavier initializer and normal gradient descent 65 | fn default() -> Self { 66 | Graph::new(Box::new(xavier_initialize), Optimizer::default()) 67 | } 68 | } 69 | 70 | impl Graph { 71 | /// Consider using `Graph::default()` if you don't want to choose your own optimizer and 72 | /// initializer. 
73 | pub fn new(initializer: Box<(Fn(&[usize]) -> ArrayD)>, optimizer: Optimizer) -> Self { 74 | Graph { 75 | nodes: BTreeMap::new(), 76 | values: BTreeMap::new(), 77 | losses: BTreeMap::new(), 78 | named_idxs: BTreeMap::new(), 79 | num_inserted: 0, 80 | initializer: Initializer(initializer), 81 | optimizer, 82 | } 83 | } 84 | /// Inserts the node into the graph and returns the index 85 | pub fn register(&mut self, node: Node) -> Idx { 86 | let idx = self.num_inserted; 87 | if let Node::Parameter(ref shape) = node { 88 | self.optimizer.register(Idx { idx }, shape) 89 | } 90 | self.nodes.insert(idx, node); 91 | self.values.insert(idx, Array::zeros(()).into_dyn()); 92 | self.losses.insert(idx, Array::zeros(()).into_dyn()); 93 | self.num_inserted += 1; 94 | Idx { idx } 95 | } 96 | /// Registers a parameter of the given shape and initializes the value using the graph's 97 | /// initializer. 98 | pub fn param(&mut self, shape: &[usize]) -> Idx { 99 | let idx = self.register(Node::Parameter(shape.to_vec().into_boxed_slice())); 100 | self.values.insert(idx.idx, (self.initializer.0)(shape)); 101 | self.losses.insert(idx.idx, Array::zeros(shape)); 102 | self.num_inserted += 1; 103 | idx 104 | } 105 | /// Registers an input node which advances the iterator `it` each forward pass. 106 | pub fn input(&mut self, it: Option>>>) -> Idx { 107 | self.register(Node::Input { it }) 108 | } 109 | /// Registers an operation and its inputs 110 | pub fn op(&mut self, op: impl Operation + 'static, inputs: &[Idx]) -> Idx { 111 | // TODO Verify Operation inputs -- make sure they're in the graph 112 | let o = Node::Operation { 113 | operation: Box::new(op), 114 | inputs: inputs.to_vec().into_boxed_slice(), 115 | }; 116 | self.register(o) 117 | } 118 | /// Registers a constant, sets its value to `c`, then returns the idx 119 | pub fn constant(&mut self, c: ArrayD) -> Idx { 120 | let idx = self.register(Node::Constant); 121 | self.set_value(idx, c); 122 | idx 123 | } 124 | fn _forward1(&mut self, i: usize) { 125 | if let Some(n) = self.nodes.get_mut(&i) { 126 | let inps = n.inputs(); 127 | if let Some(v) = n.forward(&view_at_idxs(&inps, &self.values)) { 128 | self.values.insert(i, v); 129 | } 130 | } 131 | // reset losses 132 | self.losses.insert(i, Array::zeros(self.values[&i].shape())); 133 | } 134 | fn _backward1(&mut self, i: usize) { 135 | if let Some(n) = self.nodes.get_mut(&i) { 136 | if let Node::Parameter(..) = n { 137 | self.optimizer.apply_gradient( 138 | Idx { idx: i }, 139 | self.values.get_mut(&i).unwrap().view_mut(), 140 | &self.losses[&i], 141 | ); 142 | } else { 143 | let inps = n.inputs(); 144 | let gradients = n.backward(&view_at_idxs(&inps, &self.values), &self.losses[&i]); 145 | for (grad, j) in gradients.iter().zip(inps.iter()) { 146 | if let Some(x) = self.losses.get_mut(&j.idx) { 147 | *x += grad; 148 | } 149 | } 150 | } 151 | } 152 | } 153 | /// Computes values for each node in insertion order. 154 | /// Parameters are unaffected. 155 | /// Inputs will set their value to the next output of their iterator, 156 | /// Operations will compute a new value based on the values of its inputs. 157 | pub fn forward(&mut self) { 158 | let keys: Vec = self.nodes.keys().cloned().collect(); 159 | for i in keys.into_iter() { 160 | self._forward1(i); 161 | } 162 | } 163 | /// Propagates gradients in reverse insertion order. 164 | /// Parameters will apply gradients with the graph's optimizer. 
165 | /// Inputs are unaffected 166 | /// Operations will compute gradient given values from their inputs and gradients from its outputs 167 | pub fn backward(&mut self) { 168 | let keys: Vec = self.nodes.keys().rev().cloned().collect(); 169 | for i in keys.into_iter() { 170 | self._backward1(i); 171 | } 172 | } 173 | /// Updates value and resets losses for node with Idx `i`. 174 | pub fn forward1(&mut self, i: Idx) { 175 | self._forward1(i.idx); 176 | } 177 | /// Back propagates losses for node with Idx `i`. 178 | pub fn backward1(&mut self, i: Idx) { 179 | self._backward1(i.idx); 180 | } 181 | /// Remove the node at `idx` as well as its associated value and loss. 182 | pub fn remove(&mut self, idx: Idx) { 183 | self.nodes.remove(&idx.idx); 184 | self.values.remove(&idx.idx); 185 | self.losses.remove(&idx.idx); 186 | } 187 | /// This op removes every node from the graph that is not a parameter. This is useful for 188 | /// dynamic graphs and recurrent neural networks when you want to rebuild everything each 189 | /// forward and backward pass of the network. 190 | pub fn clear_non_parameters(&mut self) { 191 | let mut keys = Vec::new(); 192 | for (i, n) in self.nodes.iter() { 193 | if let Node::Parameter(_) = n { 194 | //pass 195 | } else { 196 | keys.push(*i); 197 | } 198 | } 199 | for k in keys.into_iter() { 200 | self.nodes.remove(&k); 201 | self.values.remove(&k); 202 | self.losses.remove(&k); 203 | } 204 | } 205 | pub fn set_value(&mut self, idx: Idx, val: ArrayD) { 206 | if self.values.insert(idx.idx, val).is_none() { 207 | panic!("Tried to set value at a removed index") 208 | } 209 | } 210 | pub fn get_value(&self, idx: Idx) -> &ArrayD { 211 | &self.values[&idx.idx] 212 | } 213 | pub fn set_loss(&mut self, idx: Idx, loss: ArrayD) { 214 | if self.losses.insert(idx.idx, loss).is_none() { 215 | panic!("Tried to set loss at a removed index") 216 | } 217 | } 218 | pub fn get_loss(&self, idx: Idx) -> &ArrayD { 219 | &self.losses[&idx.idx] 220 | } 221 | /// Replace an Input node's iterator or converts Constant nodes into Input with this iterator. 222 | /// Note that Input node iterators are not saved when serialized with serde. 223 | pub fn replace_input_iterator( 224 | &mut self, 225 | idx: Idx, 226 | new: Box>>, 227 | ) -> Result<(), String> { 228 | if let Some(n) = self.nodes.get_mut(&idx.idx) { 229 | match n { 230 | Node::Input { it } => *it = Some(new), 231 | Node::Constant => *n = Node::Input { it: Some(new) }, 232 | _ => { 233 | return Err("Tried to replace input iter at non Input/Constant node.".to_string()) 234 | } 235 | } 236 | Ok(()) 237 | } else { 238 | Err("Tried to replace input iterator at invalid index.".to_string()) 239 | } 240 | } 241 | pub fn add(&mut self, inputs: &[Idx]) -> Idx { 242 | self.register(Node::Add { 243 | xs: inputs.to_vec(), 244 | }) 245 | } 246 | pub fn mult(&mut self, inputs: &[Idx]) -> Idx { 247 | self.register(Node::Mult { 248 | xs: inputs.to_vec(), 249 | }) 250 | } 251 | /// Registers a convolution operation node and returns the index 252 | pub fn conv(&mut self, kernel: Idx, img: Idx, padding: Padding, stride: usize) -> Idx { 253 | self.register(Node::Conv { 254 | kernel, 255 | img, 256 | conv: Conv::new(padding, stride), 257 | }) 258 | } 259 | /// Registers a pooling operation takes a `Batch * Height * Width * Channels` image and reduces 260 | /// it to a `Batch * Channels` vector. 
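/// For example, a `(batch, 8, 8, 16)` activation is reduced to a `(batch, 16)` matrix;
/// the spatial dimensions are pooled away (the concrete shape here is only illustrative).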
261 | pub fn global_pool(&mut self, x: Idx, pool: GlobalPool) -> Idx { 262 | self.register(Node::GlobalPool { x, pool }) 263 | } 264 | /// Registers a Relu operation which takes the elementwise maximum of the input array and 0. 265 | pub fn relu(&mut self, x: Idx) -> Idx { 266 | self.register(Node::Activation { 267 | x, 268 | a: Activation::Relu { leak: 0.0 }, 269 | }) 270 | } 271 | /// Registers a new sigmoid activation operation, an 272 | /// elementwise application of $\frac{1}{1 + e^{-x}}$. 273 | pub fn sigmoid(&mut self, x: Idx) -> Idx { 274 | self.register(Node::Activation { 275 | x, 276 | a: Activation::Sigmoid, 277 | }) 278 | } 279 | /// Registers a Tanh operation. 280 | pub fn tanh(&mut self, x: Idx) -> Idx { 281 | self.register(Node::Activation { 282 | x, 283 | a: Activation::Tanh, 284 | }) 285 | } 286 | /// Registers a matrix multiplication of vectors `v` by matrix `mat`. 287 | pub fn matmul(&mut self, mat: Idx, v: Idx) -> Idx { 288 | self.register(Node::MatMul { mat, v }) 289 | } 290 | /// Registers an embedding layer that converts integer codes to their vector representations. 291 | pub fn embedding(&mut self, emb: Idx, code: Idx) -> Idx { 292 | self.register(Node::Embedding { emb, code }) 293 | } 294 | } 295 | 296 | fn view_at_idxs<'a>( 297 | indices: &[Idx], 298 | nodes: &'a BTreeMap<usize, ArrayD<f32>>, 299 | ) -> Box<[ArrayViewD<'a, f32>]> { 300 | let mut vals = Vec::new(); 301 | for i in indices.iter() { 302 | vals.push(nodes[&i.idx].view()); 303 | } 304 | vals.into_boxed_slice() 305 | } 306 | -------------------------------------------------------------------------------- /src/nodes/conv.rs: -------------------------------------------------------------------------------- 1 | use ndarray::{Array4, ArrayD, ArrayViewD, Ix4}; 2 | use nodes::Operation; 3 | 4 | /// Implements a convolution [Operation](trait.Operation.html) that supports striding and padding. 5 | /// It takes two arguments, `kernel` and `input`. The output shape depends on striding, padding, 6 | /// and eventually dilation (not yet implemented). 7 | /// * `input ~ (Batch * Height * Width * Channels_in)`. 8 | /// * `kernel ~ (Kernel_height * Kernel_width * Channels_in * Channels_out)`. 9 | #[derive(Debug, Serialize, Deserialize)] 10 | pub struct Conv { 11 | _dialation: usize, 12 | stride: (usize, usize), 13 | padding: Padding, 14 | } 15 | 16 | /// Type of padding to use in a [Conv](nodes/struct.Conv.html) node. `No` padding means a non-strided 17 | /// convolution will shrink by the dimensions of the kernel, as pixels at the edge will not be the 18 | /// center of a convolution. `Same` padding allows for convolution of edge pixels by assuming 19 | /// the values beyond the image are equal to the edge. Other padding strategies, such as "Zero" 20 | /// and "Reflection" padding, are not implemented. 
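/// As a concrete illustration of the shape rule implemented in `Conv::eval` below: with a
/// 3x3 kernel and stride 1, a `(batch, 10, 10, c_in)` image stays `(batch, 10, 10, c_out)`
/// under `Same` padding, but shrinks to `(batch, 7, 7, c_out)` under `No` padding.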
21 | #[derive(Debug, Clone, Copy, Serialize, Deserialize)] 22 | pub enum Padding { 23 | Same, 24 | No, 25 | } 26 | 27 | impl Default for Conv { 28 | fn default() -> Self { 29 | Conv { 30 | _dialation: 1, 31 | stride: (1, 1), 32 | padding: Padding::Same, 33 | } 34 | } 35 | } 36 | 37 | struct ConvDims { 38 | i: usize, 39 | j: usize, 40 | di: usize, 41 | dj: usize, 42 | } 43 | 44 | impl Conv { 45 | pub fn new(padding: Padding, stride: usize) -> Self { 46 | Conv { 47 | _dialation: 1, 48 | stride: (stride, stride), 49 | padding, 50 | } 51 | } 52 | #[inline(always)] 53 | fn conv_point(&self, size: &ConvDims, idx: &ConvDims) -> Option<(usize, usize)> { 54 | // Returns the index of the point of the input image image multiplied by the Kernel 55 | // in the convolution. 56 | let kernel_offset_i = size.di >> 1; 57 | let kernel_offset_j = size.dj >> 1; 58 | 59 | match self.padding { 60 | Padding::Same => { 61 | // subtract kernel size / 2 to center kernel 62 | let ci = (idx.i * self.stride.0 + idx.di) 63 | .saturating_sub(kernel_offset_i) 64 | .min(size.i - 1); 65 | let cj = (idx.j * self.stride.1 + idx.dj) 66 | .saturating_sub(kernel_offset_j) 67 | .min(size.j - 1); 68 | Some((ci, cj)) 69 | } 70 | Padding::No => { 71 | // No padding so next image is di (dj) rows (cols) smaller 72 | if kernel_offset_i <= idx.i && idx.i + size.di < size.i { 73 | let ci = idx.i * self.stride.0 + idx.di - kernel_offset_i; 74 | if kernel_offset_j <= idx.j && idx.j + size.dj < size.j { 75 | let cj = idx.j * self.stride.1 + idx.dj - kernel_offset_j; 76 | return Some((ci, cj)); 77 | } 78 | } 79 | None 80 | } 81 | } 82 | } 83 | } 84 | 85 | impl Operation for Conv { 86 | #[allow(clippy::deref_addrof)] 87 | fn eval(&self, inputs: &[ArrayViewD]) -> ArrayD { 88 | assert!( 89 | inputs.len() == 2, 90 | "Convolution operation takes two arguments" 91 | ); 92 | let kernel = inputs[0].view().into_dimensionality::().unwrap(); 93 | let image = inputs[1].view().into_dimensionality::().unwrap(); 94 | 95 | // Attaching mut to n_i, n_j makes them variables that dont need to be mutable 96 | // but without it, they are references ==><== 97 | if let ([n_di, n_dj, n_c0, n_c1], [n_b, n_i, n_j, n_c0_]) = (kernel.shape(), image.shape()) 98 | { 99 | assert_eq!( 100 | n_c0_, n_c0, 101 | "number of channels in image do not match kernel's" 102 | ); 103 | // Striding shrinks image 104 | let (out_i, out_j) = match self.padding { 105 | Padding::Same => (n_i / self.stride.0, n_j / self.stride.1), 106 | Padding::No => ((n_i - n_di) / self.stride.0, (n_j - n_dj) / self.stride.1), 107 | }; 108 | let size = ConvDims { 109 | i: *n_i, 110 | j: *n_j, 111 | di: *n_di, 112 | dj: *n_dj, 113 | }; 114 | let mut output = Array4::zeros([*n_b, out_i, out_j, *n_c1]); 115 | 116 | for b in 0..*n_b { 117 | for i in 0..out_i { 118 | for j in 0..out_j { 119 | for di in 0..*n_di { 120 | for dj in 0..*n_dj { 121 | let idx = ConvDims { i, j, di, dj }; 122 | if let Some((ci, cj)) = self.conv_point(&size, &idx) { 123 | let ker = kernel.slice(s!(di, dj, .., ..)); 124 | let img = image.slice(s!(b, ci, cj, ..)); 125 | let mut out = output.slice_mut(s!(b, i, j, ..)); 126 | out += &img.dot(&ker); 127 | } 128 | } 129 | } 130 | } 131 | } 132 | } 133 | output.into_dyn() 134 | } else { 135 | unreachable!() 136 | } 137 | } 138 | fn grad(&self, inputs: &[ArrayViewD], loss: ArrayViewD) -> Vec> { 139 | assert!( 140 | inputs.len() == 2, 141 | "Convolution operation takes two arguments" 142 | ); 143 | let kernel = inputs[0].view().into_dimensionality::().unwrap(); 144 | let image = 
inputs[1].view().into_dimensionality::().unwrap(); 145 | let loss = loss.into_dimensionality::().unwrap(); 146 | 147 | if let ([n_di, n_dj, n_c0, n_c1], [n_b, n_i, n_j, n_c0_]) = (kernel.shape(), image.shape()) 148 | { 149 | assert_eq!( 150 | n_c0_, n_c0, 151 | "number of channels in image do not match kernel's" 152 | ); 153 | let out_i = n_i / self.stride.0; 154 | let out_j = n_j / self.stride.1; 155 | 156 | let mut grad_kernel = Array4::zeros([*n_di, *n_dj, *n_c0, *n_c1]); 157 | let mut grad_image = Array4::zeros([*n_b, *n_i, *n_j, *n_c0]); 158 | 159 | let size = ConvDims { 160 | i: *n_i, 161 | j: *n_j, 162 | di: *n_di, 163 | dj: *n_dj, 164 | }; 165 | // Benchmarks suggests that iproduct is in fact not zero cost (slower than this). 166 | // manually nrolling the loop or implementing blocking may increase performance... 167 | for b in 0..*n_b { 168 | for i in 0..out_i { 169 | for j in 0..out_j { 170 | for di in 0..*n_di { 171 | for dj in 0..*n_dj { 172 | let idx = ConvDims { i, j, di, dj }; 173 | if let Some((ci, cj)) = self.conv_point(&size, &idx) { 174 | // // OPTIMIZE Batch version is worse, I'm guessing due to cache 175 | // // inefficency because the stride for `b` is so large 176 | // let img = image.slice(s!(.., ci, cj, ..)); 177 | // let los = loss.slice(s!(.., i, j, ..)); 178 | // let ker = kernel.slice(s!(di, dj, .., ..)); 179 | // let mut gker = grad_kernel.slice_mut(s!(di, dj, .., ..)); 180 | // let mut gimg = grad_image.slice_mut(s!(.., ci, cj, ..)); 181 | // gker += &img.t().dot(&los); 182 | // gimg += &los.dot(&ker.t()); 183 | 184 | for c0 in 0..*n_c0 { 185 | let img = image[(b, ci, cj, c0)]; 186 | let gi = &mut grad_image[(b, ci, cj, c0)]; 187 | for c1 in 0..*n_c1 { 188 | let l = loss[(b, i, j, c1)]; 189 | let k = kernel[(di, dj, c0, c1)]; 190 | grad_kernel[(di, dj, c0, c1)] += l * img; 191 | *gi += l * k; 192 | } 193 | } 194 | } 195 | } 196 | } 197 | } 198 | } 199 | } 200 | vec![grad_kernel.into_dyn(), grad_image.into_dyn()] 201 | } else { 202 | unreachable!() 203 | } 204 | } 205 | } 206 | 207 | #[cfg(test)] 208 | mod tests { 209 | use super::*; 210 | use rand::distributions::{Distribution, Uniform}; 211 | use rand::thread_rng; 212 | use std::f32; 213 | use test::Bencher; 214 | use xavier_initialize; 215 | 216 | #[test] 217 | fn conv_point_same_padding() { 218 | let ker = Array4::zeros([3, 3, 1, 1]).into_dyn(); 219 | let img = Array4::zeros([4, 4, 4, 1]).into_dyn(); 220 | let c = Conv::new(Padding::Same, 1); 221 | c.eval(&[ker.view(), img.view()]); 222 | 223 | let size = ConvDims { 224 | i: 4, 225 | j: 4, 226 | di: 3, 227 | dj: 3, 228 | }; 229 | 230 | assert_eq!( 231 | c.conv_point( 232 | &size, 233 | &ConvDims { 234 | i: 0, 235 | j: 0, 236 | di: 0, 237 | dj: 0 238 | } 239 | ), 240 | Some((0, 0)), 241 | "Top left going up and left" 242 | ); 243 | assert_eq!( 244 | c.conv_point( 245 | &size, 246 | &ConvDims { 247 | i: 0, 248 | j: 3, 249 | di: 2, 250 | dj: 2 251 | } 252 | ), 253 | Some((1, 3)), 254 | "Top right going down and right" 255 | ); 256 | assert_eq!( 257 | c.conv_point( 258 | &size, 259 | &ConvDims { 260 | i: 2, 261 | j: 2, 262 | di: 1, 263 | dj: 1 264 | } 265 | ), 266 | Some((2, 2)), 267 | "Center going center" 268 | ); 269 | assert_eq!( 270 | c.conv_point( 271 | &size, 272 | &ConvDims { 273 | i: 3, 274 | j: 3, 275 | di: 0, 276 | dj: 0 277 | } 278 | ), 279 | Some((2, 2)), 280 | "Bottom right going up and left" 281 | ); 282 | assert_eq!( 283 | c.conv_point( 284 | &size, 285 | &ConvDims { 286 | i: 3, 287 | j: 3, 288 | di: 0, 289 | dj: 2 290 | } 291 | ), 292 | 
Some((2, 3)), 293 | "Bottom right going down and left" 294 | ); 295 | } 296 | 297 | // #[test] TODO test no_padding 298 | // fn conv_point_no_padding() { 299 | // unimplemented!() 300 | // } 301 | 302 | fn stripe_detector_kernel(horizontal: bool) -> ArrayD { 303 | Array4::from_shape_fn([3, 3, 1, 1], move |(row, col, _, _)| { 304 | if (horizontal && row == 1) || (!horizontal && col == 1) { 305 | 1.0 / 3.0 306 | } else { 307 | -1.0 / 6.0 308 | } 309 | }) 310 | .into_dyn() 311 | } 312 | 313 | fn stripes(horizontal: bool) -> ArrayD { 314 | Array4::from_shape_fn( 315 | [1, 10, 10, 1], 316 | move |(_, row, col, _)| if horizontal { row % 2 } else { col % 2 } as f32, 317 | ) 318 | .into_dyn() 319 | } 320 | 321 | #[test] 322 | fn stripe_detectors() { 323 | for (padding, det, st) in iproduct!( 324 | [Padding::Same, Padding::No].into_iter(), 325 | [true, false].into_iter(), 326 | [true, false].into_iter() 327 | ) { 328 | println!("{:?}", (*padding, *det, *st)); 329 | let kernel = stripe_detector_kernel(*det); 330 | let stripes = stripes(*st); 331 | let conv = Conv::new(*padding, 1); 332 | let detections = conv.eval(&[kernel.view(), stripes.view()]); 333 | let detections = detections.slice(s!(0, .., .., 0)); 334 | if *det != *st { 335 | assert!( 336 | detections.iter().all(|x| x.abs() < f32::EPSILON), 337 | "padding: {:?}; h_detector: {:?}; h_stripes: {:?}; detected orthogonal lines\n{:?}", 338 | padding, 339 | *det, 340 | *st, 341 | detections 342 | ); 343 | } else { 344 | assert!( 345 | detections.iter().any(|x| x.abs() != 0.0), 346 | "padding: {:?}; h_detector: {:?}; h_stripes: {:?}; detected nothing\n{:?}", 347 | padding, 348 | *det, 349 | *st, 350 | detections 351 | ); 352 | } 353 | } 354 | } 355 | 356 | #[test] 357 | fn identity_kernel_eval() { 358 | let identity_kernel = Array4::from_shape_fn([3, 3, 1, 1], |(di, dj, c0, c1)| { 359 | if di == 1 && dj == 1 && c0 == c1 { 360 | 1.0 361 | } else { 362 | 0.0 363 | } 364 | }) 365 | .into_dyn(); 366 | 367 | let img = stripes(true); 368 | let conv = Conv::new(Padding::Same, 1); 369 | let res = conv.eval(&[identity_kernel.view(), img.view()]); 370 | let conv = res.slice(s!(0, .., .., 0)); 371 | let orig = img.slice(s!(0, .., .., 0)); 372 | 373 | assert_eq!(orig, conv, "Identity Kernel failed\n"); 374 | } 375 | 376 | #[test] 377 | fn identity_kernel_grad() { 378 | let identity_kernel = Array4::from_shape_fn([3, 3, 1, 1], |(di, dj, c0, c1)| { 379 | if di == 1 && dj == 1 && c0 == c1 { 380 | 1.0 381 | } else { 382 | 0.0 383 | } 384 | }) 385 | .into_dyn(); 386 | 387 | let orig = stripes(true); 388 | let conv = Conv::new(Padding::Same, 1); 389 | let eval = conv.eval(&[identity_kernel.view(), orig.view()]); 390 | let grad = conv.grad(&[identity_kernel.view(), orig.view()], eval.view()); 391 | assert_eq!(grad.len(), 2); 392 | let g_img = grad[1].view(); 393 | assert_eq!(g_img, orig.view(), "backwards identity"); 394 | } 395 | 396 | #[test] 397 | fn minimize_from_positive_image() { 398 | let mut rng = thread_rng(); 399 | let unif = Uniform::new(1.0, 2.0); 400 | let conv = Conv::new(Padding::Same, 1); 401 | let mut kernel = xavier_initialize(&[3, 3, 2, 2]); 402 | 403 | for _ in 0..5 { 404 | for _ in 0..3 { 405 | let img = Array4::from_shape_fn([4, 5, 5, 2], |_| unif.sample(&mut rng)).into_dyn(); 406 | conv.eval(&[kernel.view(), img.view()]); 407 | let grad = conv.grad(&[kernel.view(), img.view()], img.view()); 408 | let g_ker = grad[0].view(); 409 | kernel = kernel - g_ker 410 | } 411 | assert!( 412 | kernel.iter().all(|x| *x < 0.0), 413 | "Kernel failed to learn 
to be all negative\n{:?}", 414 | kernel.view() 415 | ) 416 | } 417 | } 418 | 419 | #[bench] 420 | fn eval_3x3x8_kernel_64x64x3_img(b: &mut Bencher) { 421 | let kernel = xavier_initialize(&[3, 3, 3, 8]); 422 | let conv = Conv::new(Padding::Same, 1); 423 | let img = xavier_initialize(&[1, 64, 64, 3]); 424 | 425 | b.iter(|| conv.eval(&[kernel.view(), img.view()])); 426 | } 427 | #[bench] 428 | fn grad_3x3x8_kernel_64x64x3_img(b: &mut Bencher) { 429 | let kernel = xavier_initialize(&[3, 3, 3, 8]); 430 | let conv = Conv::new(Padding::Same, 1); 431 | let img = xavier_initialize(&[1, 64, 64, 3]); 432 | let out = conv.eval(&[kernel.view(), img.view()]); 433 | 434 | b.iter(|| conv.grad(&[kernel.view(), img.view()], out.view())); 435 | } 436 | 437 | } 438 | --------------------------------------------------------------------------------
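A minimal sketch of how the pieces above fit together (an editor's illustration, not a file from the repository): it exercises only the `Graph` methods documented in src/graph.rs, and the layer sizes, input values, and synthetic error signal are assumptions chosen for the example.

extern crate drug;
extern crate ndarray;

use ndarray::prelude::*;

fn main() {
    let mut g = drug::Graph::default();

    // A 2 -> 3 dense layer: the weight matrix is a parameter (initialized by the graph's
    // initializer, xavier by default) and the input batch is a constant.
    let w = g.param(&[2, 3]);
    let x = g.constant(Array::from_shape_fn((4, 2), |(i, j)| (i + j) as f32).into_dyn());
    let h = g.matmul(w, x); // `matmul(mat, v)` multiplies the batch `x` by the matrix `w`
    let y = g.relu(h);

    // Values are computed in insertion order.
    g.forward();
    println!("activations:\n{:?}", g.get_value(y));

    // Seed a gradient at the output. Following the convention used in examples/ptb/main.rs,
    // the error signal is negated before being handed to `set_loss`; `backward` then
    // propagates it in reverse insertion order and the optimizer updates `w`.
    let err = g.get_value(y).mapv(|v| v - 1.0);
    g.set_loss(y, -err);
    g.backward();

    // For dynamic graphs (such as the recurrent PTB example) everything except the
    // parameters can be dropped before the next step's graph is rebuilt.
    g.clear_non_parameters();
}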