├── .gitignore ├── Cargo.toml ├── LICENSE ├── README.md ├── download_mnist.sh ├── examples ├── char_rnn_unrolled.rs └── mnist.rs └── src ├── graph.rs ├── init.rs ├── layers.rs ├── lib.rs ├── op.rs ├── train.rs ├── util.rs └── var_store.rs /.gitignore: -------------------------------------------------------------------------------- 1 | data 2 | target 3 | Cargo.lock 4 | *.swp 5 | -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "deeplearn" 3 | version = "0.1.0" 4 | authors = ["Theodore DeRego "] 5 | 6 | [dependencies.gpuarray] 7 | git = "https://github.com/tedsta/gpuarray-rs" 8 | 9 | [dependencies] 10 | rand = "*" 11 | num = "*" 12 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2016 Theodore DeRego 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 4 | 5 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 6 | 7 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 8 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # deeplearn-rs 2 | 3 | Deep learning in Rust! This is my first shot at this. It's mostly just a proof of concept right now. The API will change. 4 | 5 | ### Status 6 | 7 | We have these models implemented (check out the examples folder): 8 | - MNIST handwritten digit recognition 9 | - char-rnn using LSTM 10 | 11 | So far, we have the following layers implemented: 12 | 13 | - Matrix multiply (fully connected) 14 | - Add (for bias, for example) 15 | - LSTM 16 | - Softmax 17 | - MSE loss 18 | - Cross entropy loss 19 | 20 | We have the following optimizers: 21 | - SGD 22 | - RMSProp 23 | 24 | ### Road map 25 | 26 | - More layer types (in the order that I'll probably get to them) 27 | - Conv2d 28 | - Pooling 29 | - Dropout 30 | - Allow datatypes other than `f32` and implement casting between arrays of primitive numeric types. 
31 | - Provide utilities for working with data 32 | - images 33 | - tsv and csv 34 | - raw text data and word embeddings 35 | 36 | ### Goals 37 | 38 | We have a looong way to go :) 39 | 40 | - Fast 41 | - Easy to use 42 | - Portable 43 | - More control when you need it 44 | - Easy to define custom layers 45 | - Readable internal codebase 46 | 47 | ### License 48 | 49 | MIT 50 | -------------------------------------------------------------------------------- /download_mnist.sh: -------------------------------------------------------------------------------- 1 | mkdir -p data/mnist 2 | cd data/mnist 3 | wget http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz 4 | wget http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz 5 | wget http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz 6 | wget http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz 7 | gunzip *.gz 8 | -------------------------------------------------------------------------------- /examples/char_rnn_unrolled.rs: -------------------------------------------------------------------------------- 1 | extern crate deeplearn; 2 | extern crate gpuarray as ga; 3 | 4 | use std::collections::{HashMap, HashSet}; 5 | use std::fs::File; 6 | use std::io::{ 7 | self, 8 | BufReader, 9 | Write, 10 | }; 11 | use std::path::Path; 12 | use std::rc::Rc; 13 | 14 | use deeplearn::{init, layers, util, train, Graph}; 15 | use deeplearn::op::Softmax; 16 | use deeplearn::train::Optimizer; 17 | use ga::Array; 18 | 19 | fn main() { 20 | let batch_size = 1; 21 | 22 | let (char_map, rev_char_map, lines) = load_char_rnn_data("data/bible_no_verse.txt").unwrap(); 23 | let char_classes = char_map.len(); 24 | println!("Loaded char rnn data"); 25 | println!("Character types: {}", char_classes); 26 | 27 | //////////////////////////////////////////////////////////////////////////////////////////////// 28 | // Build the graph 29 | 30 | let batch_size = 1; 31 | let l1_size = 200; 32 | let l2_size = char_classes; 33 | 34 | let ctx = Rc::new(ga::Context::new()); 35 | let ref mut graph = Graph::new(ctx.clone()); 36 | 37 | let l1_w = graph.add_variable(vec![1+char_classes+l1_size, 4*l1_size], true, init::Normal(0.001, 0.003)); 38 | let l2_w = graph.add_variable(vec![1+l1_size+l2_size, 4*l2_size], true, init::Normal(0.001, 0.003)); 39 | //let l3_w = graph.add_variable(vec![l2_size, char_classes], true, init::Normal(0.001, 0.005)); 40 | //let l3_b = graph.add_variable(vec![1, char_classes], true, init::Normal(0.001, 0.005)); 41 | let l1_c0 = graph.add_variable(vec![batch_size, l1_size], false, 0.0); 42 | let l1_h0 = graph.add_variable(vec![batch_size, l1_size], false, 0.0); 43 | let l2_c0 = graph.add_variable(vec![batch_size, l2_size], false, 0.0); 44 | let l2_h0 = graph.add_variable(vec![batch_size, l2_size], false, 0.0); 45 | 46 | let net_step = 47 | |graph: &mut Graph, (l1_prev_h, l1_prev_c, l2_prev_h, l2_prev_c)| { 48 | let input = graph.add_variable(vec![batch_size, char_classes], false, 0.0); 49 | let (l1_out, l1_c) = layers::lstm_unrolled(graph, input, l1_w, l1_prev_h, l1_prev_c); 50 | let (l2_out, l2_c) = layers::lstm_unrolled(graph, l1_out, l2_w, l2_prev_h, l2_prev_c); 51 | //let l3_fcb = layers::dense_biased_manual(graph, l2_out, l3_w, l3_b); 52 | //let l3_out = layers::activation(graph, Softmax(l3_fcb)); 53 | // Loss 54 | let (loss_out, train_out) = layers::cross_entropy(graph, l2_out); 55 | let loss_d = graph.add_gradient(loss_out); // Create a gradient to apply to the loss function 56 | // We apply a gradient of -0.1 to the loss function 
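// (The constant written on the next line is -1.0; it is the seed value for
// backpropagation. graph.backward() propagates this array back through every
// unrolled step, so the gradients accumulated on the learnable weights all end
// up scaled by this value.)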
57 | loss_d.write(graph, &Array::new(loss_d.get(graph).shape().to_owned(), -1.0)); 58 | 59 | ((l1_out, l1_c, l2_out, l2_c), (input, l2_out, train_out)) 60 | }; 61 | 62 | let (last_recur_in, steps) = util::unrolled_net(graph, 25, (l1_h0, l1_c0, l2_h0, l2_c0), net_step); 63 | 64 | //////////////////////////////////////////////////////////////////////////////////////////////// 65 | // Train and validate the network 66 | 67 | let mut l2_out_cpu = Array::new(vec![batch_size, l2_size], 0.0); 68 | let mut l3_out_cpu = Array::new(vec![batch_size, char_classes], 0.0); 69 | let mut l3_out_d_cpu = Array::new(vec![batch_size, char_classes], 0.0); 70 | let mut l1_w_cpu = Array::new(vec![1+char_classes+l1_size, 4*l1_size], 0.0); 71 | let mut argmax_out = Array::new(vec![batch_size], 0usize); 72 | 73 | let rmsprop = train::RmsProp::new(graph, 0.001, 0.9); 74 | 75 | let samples = 10000; 76 | let mut i = 0; 77 | for line in &lines { 78 | for (t, &(input, _, train_out)) in (1..line.len()).zip(steps.iter()) { 79 | input.write(graph, &char_map[&line[t-1]]); 80 | train_out.write(graph, &char_map[&line[t]]); 81 | } 82 | 83 | for (t, &(_, _, _)) in (1..line.len()).zip(steps.iter()) { 84 | print!("{}", line[t] as char); 85 | } 86 | println!(""); 87 | 88 | graph.forward(); 89 | graph.backward(); 90 | 91 | for (_, &(_, l2_out, _)) in (1..line.len()).zip(steps.iter()) { 92 | l2_out.read(graph, &mut l2_out_cpu); 93 | //l1_w.read(graph, &mut l1_w_cpu); 94 | //graph.get_gradient(l3_out).read(graph, &mut l3_out_d_cpu); 95 | //graph.get_gradient(l1_w).read(graph, &mut l1_w_cpu); 96 | //graph.get_gradient(last_recur_in.2).read(graph, &mut l2_out_cpu); 97 | //println!("{:?}", l3_out_cpu); 98 | //println!("{:?}", l3_out_d_cpu); 99 | //println!("{:?}", l1_w_cpu); 100 | //println!("{:?}", l2_out_cpu); 101 | util::argmax_rows(&l2_out_cpu, &mut argmax_out); 102 | let next_char = rev_char_map[&argmax_out[&[0]]]; 103 | print!("{}", next_char as char); 104 | } 105 | println!(""); 106 | 107 | //train::apply_gradients(graph); 108 | rmsprop.update(graph); 109 | 110 | println!("{}", i); 111 | 112 | i += 1; 113 | if i > samples { 114 | break; 115 | } 116 | } 117 | 118 | /*let mut l3_out_cpu = Array::new(vec![batch_size, char_classes], 0.0); 119 | let mut argmax_out = Array::new(vec![batch_size], 0usize); 120 | 121 | loop { 122 | // Print prompt 123 | print!(">"); 124 | io::stdout().flush().unwrap(); 125 | 126 | // Get seed string 127 | let mut seed = String::new(); 128 | io::stdin().read_line(&mut seed).unwrap(); 129 | 130 | // Input the seed string 131 | let mut last_char = 0u8; 132 | for c in seed.trim_right().as_bytes() { 133 | print!("{}", *c as char); 134 | input.write(graph, &char_map[c]); 135 | graph.forward(); 136 | l3_out.read(graph, &mut l3_out_cpu); 137 | util::argmax_rows(&l3_out_cpu, &mut argmax_out); 138 | last_char = rev_char_map[&argmax_out[&[0]]]; 139 | } 140 | 141 | // Generate the rest of the output 142 | for _ in 0..150 { 143 | print!("{}", last_char as char); 144 | io::stdout().flush().unwrap(); 145 | 146 | input.write(graph, &char_map[&last_char]); 147 | graph.forward(); 148 | l3_out.read(graph, &mut l3_out_cpu); 149 | //println!("{:?}", l3_out_cpu); 150 | util::argmax_rows(&l3_out_cpu, &mut argmax_out); 151 | last_char = rev_char_map[&argmax_out[&[0]]]; 152 | 153 | if last_char == b'\n' { 154 | break; 155 | } 156 | } 157 | println!(""); 158 | }*/ 159 | } 160 | 161 | pub fn load_char_rnn_data>(path: P) 162 | -> io::Result<(HashMap>, HashMap, Vec>)> 163 | { 164 | use std::io::BufRead; 165 | 166 | let ref mut file = 
BufReader::new(File::open(path).unwrap()); 167 | 168 | let mut unique_chars = HashSet::new(); 169 | let mut lines = vec![]; 170 | 171 | for line in file.lines() { 172 | let line: String = try!(line) + "\n"; 173 | for c in line.as_bytes() { 174 | unique_chars.insert(*c); 175 | } 176 | lines.push(line.as_bytes().to_owned()); 177 | } 178 | 179 | let char_classes = unique_chars.len(); 180 | let mut char_map = HashMap::new(); 181 | let mut rev_char_map = HashMap::new(); 182 | for (i, c) in unique_chars.into_iter().enumerate() { 183 | char_map.insert(c, util::one_hot_row(i, char_classes)); 184 | rev_char_map.insert(i, c); 185 | } 186 | 187 | Ok((char_map, rev_char_map, lines)) 188 | } 189 | -------------------------------------------------------------------------------- /examples/mnist.rs: -------------------------------------------------------------------------------- 1 | extern crate deeplearn; 2 | extern crate gpuarray as ga; 3 | 4 | use std::fs::File; 5 | use std::io::{ 6 | self, 7 | BufReader, 8 | Read, 9 | }; 10 | use std::path::Path; 11 | use std::rc::Rc; 12 | 13 | use deeplearn::{init, layers, train, util, Graph}; 14 | use deeplearn::op::Relu; 15 | use ga::Array; 16 | 17 | fn main() { 18 | let batch_size = 5; 19 | 20 | // Training data 21 | println!("Reading training labels..."); 22 | let train_labels = read_mnist_labels("data/mnist/train-labels-idx1-ubyte", None).unwrap(); 23 | println!("Label count: {}", train_labels.len()); 24 | 25 | // Build label batches 26 | let train_labels_logits: Vec> = 27 | (0usize..train_labels.len()/batch_size) 28 | .map(|i| i*batch_size) 29 | .map(|i| util::one_hot_rows_batch(&train_labels[i..i+batch_size], 10)) 30 | .collect(); 31 | 32 | println!("Reading training images..."); 33 | let (rows, columns, mut train_images) = 34 | read_mnist_images("data/mnist/train-images-idx3-ubyte", batch_size, None).unwrap(); 35 | 36 | // Flatten the validation images from [batch_size, rows, columns] to [batch_size, rows*columns] 37 | for image in &mut train_images { 38 | image.reshape(vec![batch_size, rows*columns]); 39 | } 40 | 41 | // Validation data 42 | println!("Reading validation labels..."); 43 | let val_labels = read_mnist_labels("data/mnist/t10k-labels-idx1-ubyte", Some(1000)).unwrap(); 44 | println!("Label count: {}", val_labels.len()); 45 | 46 | println!("Reading validation images..."); 47 | let (_, _, mut val_images) = 48 | read_mnist_images("data/mnist/t10k-images-idx3-ubyte", batch_size, Some(1000)).unwrap(); 49 | 50 | // Flatten the training images from [batch_size, rows, columns] to [batch_size, rows*columns] 51 | for image in &mut val_images { 52 | image.reshape(vec![batch_size, rows*columns]); 53 | } 54 | 55 | //////////////////////////////////////////////////////////////////////////////////////////////// 56 | // Build the graph 57 | 58 | let ctx = Rc::new(ga::Context::new()); 59 | let ref mut graph = Graph::new(ctx.clone()); 60 | 61 | ////////////////////////// 62 | // Layer 1 63 | 64 | // Input. 
1 batch of rows*columns inputs 65 | let input = graph.add_variable(vec![batch_size, rows*columns], false, 0.0); 66 | 67 | // Biased fully connected layer with 300 neurons 68 | let (l1_fcb, _, _) = layers::dense_biased(graph, input, 300, 69 | init::Normal(0.001, 0.005), // Weights initializer 70 | init::Normal(0.001, 0.005)); // Bias initializer 71 | let l1_out = layers::activation(graph, Relu(l1_fcb)); 72 | 73 | ////////////////////////// 74 | // Layer 2 75 | 76 | // Biased fully connected layer with 10 neurons 77 | let (l2_fcb, _, _) = layers::dense_biased(graph, l1_out, 10, 78 | init::Normal(0.001, 0.005), // Weights initializer 79 | init::Normal(0.001, 0.005)); // Bias initializer 80 | let l2_out = layers::activation(graph, Relu(l2_fcb)); 81 | let l2_out_d = graph.get_gradient(l2_out); 82 | 83 | ////////////////////////// 84 | // Loss 85 | 86 | let (loss_out, train_out) = layers::mse(graph, l2_out); 87 | let loss_d = graph.add_gradient(loss_out); // Create a gradient to apply to the loss function 88 | 89 | //////////////////////////////////////////////////////////////////////////////////////////////// 90 | // Train and validate the network 91 | 92 | // We apply a gradient of -0.001 to the loss function 93 | let loss_d_cpu = Array::new(vec![batch_size, 10], -1.0); 94 | loss_d.write(graph, &loss_d_cpu); 95 | 96 | let mut loss_out_cpu = Array::new(vec![batch_size, 10], 0.0); 97 | let mut l2_out_cpu = Array::new(vec![batch_size, 10], 0.0); 98 | let mut l2_out_d_cpu = Array::new(vec![batch_size, 10], 0.0); 99 | 100 | let mut predictions = Array::new(vec![batch_size], 0usize); 101 | let mut num_correct = 0; 102 | 103 | { 104 | // Put this in it's own scope so that our train_update closure doesn't hold onto all of our 105 | // stuff until the end of main() 106 | let train_update = |graph: &mut Graph, epoch: usize| { 107 | // Get the output 108 | l2_out.read(graph, &mut l2_out_cpu); 109 | 110 | for b in 0..batch_size { 111 | // Get the most likely digit (the index of the neuron with the highest output) 112 | util::argmax_rows(&l2_out_cpu, &mut predictions); 113 | let prediction = predictions[&[b]]; 114 | 115 | // Check if the model was correct 116 | if prediction == train_labels[epoch*batch_size + b] as usize { 117 | num_correct += 1; 118 | } 119 | } 120 | 121 | if epoch % 1000 == 999 { 122 | l2_out_d.read(graph, &mut l2_out_d_cpu); 123 | loss_out.read(graph, &mut loss_out_cpu); 124 | println!("==================="); 125 | println!("Epoch: {}", epoch); 126 | println!("out = {:?}", l2_out_cpu); 127 | println!("out_d = {:?}", l2_out_d_cpu); 128 | println!("loss = {:?}", loss_out_cpu); 129 | println!("Accuracy: {}%", (num_correct as f32)/((batch_size*1000) as f32) * 100.0); 130 | num_correct = 0; 131 | } 132 | }; 133 | 134 | let trainer = train::Trainer; 135 | let rms_prop = train::RmsProp::new(graph, 0.0001, 0.9); 136 | trainer.train(graph, &rms_prop, train_images.len(), train_update, 137 | &[(input, &train_images), (train_out, &train_labels_logits)]); 138 | } 139 | 140 | ///////////////////////// 141 | // Validate the network 142 | println!("#######################################"); 143 | println!("Validating"); 144 | num_correct = 0; 145 | for epoch in 0..val_images.len() { 146 | // Upload training data 147 | input.write(graph, &val_images[epoch]); 148 | 149 | // Run the graph 150 | graph.forward(); 151 | 152 | // Get the output 153 | l2_out.read(graph, &mut l2_out_cpu); 154 | 155 | // Get the most likely digit (the index of the neuron with the highest output) 156 | for b in 0..batch_size { 
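// util::argmax_rows fills `predictions` with the index of the largest column
// in every row of l2_out_cpu; indexing with &[b] then picks out the predicted
// digit for batch element b. (The call is identical on every iteration of this
// loop, so it could equally be hoisted above it.)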
157 | util::argmax_rows(&l2_out_cpu, &mut predictions); 158 | let prediction = predictions[&[b]]; 159 | 160 | // Check if the model was correct 161 | if prediction == val_labels[epoch*batch_size + b] as usize { 162 | num_correct += 1; 163 | } 164 | } 165 | } 166 | println!("Validation Accuracy: {}%", (num_correct as f32)/((batch_size*val_images.len()) as f32) * 100.0); 167 | } 168 | 169 | fn read_mnist_labels>(path: P, num_samples: Option) -> io::Result> { 170 | use std::cmp; 171 | use std::io::{Error, ErrorKind}; 172 | 173 | let ref mut file = BufReader::new(File::open(path).unwrap()); 174 | 175 | let magic = u32::from_be(try!(read_u32(file))); 176 | if magic != 2049 { 177 | return Err(Error::new(ErrorKind::Other, 178 | format!("Invalid magic number. Got expect 2049, got {}", 179 | magic).as_ref())) 180 | } 181 | 182 | let label_count = u32::from_be(try!(read_u32(file))) as usize; 183 | let label_count = cmp::min(label_count, num_samples.unwrap_or(label_count)); 184 | 185 | let mut labels = Vec::with_capacity(label_count); 186 | for _ in 0..label_count { 187 | labels.push(try!(read_u8(file))); 188 | } 189 | 190 | Ok(labels) 191 | } 192 | 193 | fn read_mnist_images>(path: P, batch_size: usize, num_samples: Option) 194 | -> io::Result<(usize, usize, Vec>)> { 195 | use std::cmp; 196 | use std::io::{Error, ErrorKind}; 197 | 198 | let ref mut file = BufReader::new(File::open(path).unwrap()); 199 | 200 | let magic = u32::from_be(try!(read_u32(file))); 201 | if magic != 2051 { 202 | return Err(Error::new(ErrorKind::Other, 203 | format!("Invalid magic number. Got expect 2051, got {}", 204 | magic).as_ref())) 205 | } 206 | 207 | let image_count = u32::from_be(try!(read_u32(file))) as usize; 208 | let rows = u32::from_be(try!(read_u32(file))) as usize; 209 | let columns = u32::from_be(try!(read_u32(file))) as usize; 210 | 211 | let image_count = cmp::min(image_count, num_samples.unwrap_or(image_count)); 212 | 213 | let mut images = Vec::with_capacity(image_count); 214 | for _ in 0..image_count/batch_size { 215 | let mut pixel_buf = vec![0u8; batch_size*rows*columns]; 216 | try!(file.read_exact(pixel_buf.as_mut())); 217 | let array = Array::from_vec(vec![batch_size, rows, columns], 218 | pixel_buf.into_iter().map(|x| (x as f32)/255.0).collect()); 219 | images.push(array); 220 | } 221 | 222 | Ok((rows, columns, images)) 223 | } 224 | 225 | fn read_u8(reader: &mut T) -> io::Result { 226 | use std::mem; 227 | 228 | let mut buf: [u8; 1] = [0]; 229 | reader.read_exact(&mut buf).map(|_| { 230 | let data: u8 = unsafe { mem::transmute(buf) }; 231 | data 232 | }) 233 | } 234 | 235 | fn read_u32(reader: &mut T) -> io::Result { 236 | use std::mem; 237 | 238 | let mut buf: [u8; 4] = [0, 0, 0, 0]; 239 | reader.read_exact(&mut buf).map(|_| { 240 | let data: u32 = unsafe { mem::transmute(buf) }; 241 | data 242 | }) 243 | } 244 | -------------------------------------------------------------------------------- /src/graph.rs: -------------------------------------------------------------------------------- 1 | use std::cell::Ref; 2 | use std::collections::HashMap; 3 | use std::rc::Rc; 4 | 5 | use ga::{self, Array, Tensor, TensorMode}; 6 | use rand; 7 | 8 | use super::init::Initializer; 9 | use super::op::{OpBuilder, OpDescriptor, Operation}; 10 | use super::var_store::{VarIndex, VarStore}; 11 | 12 | #[derive(Copy, Clone)] 13 | pub enum NodeInput { 14 | Var(VarIndex), // Regular input variable 15 | Recurrent(usize), // Recurrent connection 16 | } 17 | 18 | pub struct Node { 19 | pub inputs: Vec, 20 | pub outputs: Vec, 
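// `inputs` and `outputs` list the variables this node reads and writes; the
// fields below hold the per-input/per-output gradient bookkeeping plus the
// backward-pass dependencies (and their cached values) used when the graph is
// run as an unrolled/recurrent network.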
21 | pub in_grad: Vec, // gradients on inputs 22 | pub out_grad: Vec, // gradients on outputs 23 | pub back_dep: Vec, 24 | pub back_dep_cache: Vec>>, 25 | } 26 | 27 | pub struct Graph { 28 | ctx: Rc, 29 | nodes: Vec, 30 | node_ops: Vec>, 31 | pub var_store: VarStore, 32 | out_var_map: HashMap, // Maps output variable to its node and index within node 33 | // Gradients on variables that are inputs to the graph - they have no corresponding node 34 | in_var_map: HashMap, 35 | learnables: Vec<(VarIndex, GradIndex)>, // Learnable variables 36 | rnn_learnable_accum: Vec, 37 | in_grad: Vec, // Gradients on variables that are inputs to the graph 38 | 39 | rng: rand::ThreadRng, 40 | } 41 | 42 | impl Graph { 43 | pub fn new(ctx: Rc) -> Self { 44 | Graph { 45 | ctx: ctx, 46 | nodes: vec![], 47 | node_ops: vec![], 48 | var_store: VarStore::new(), 49 | out_var_map: HashMap::new(), 50 | in_var_map: HashMap::new(), 51 | learnables: vec![], 52 | rnn_learnable_accum: vec![], 53 | in_grad: vec![], 54 | rng: rand::thread_rng(), 55 | } 56 | } 57 | 58 | pub fn add_node(&mut self, op: T) -> NodeIndex { 59 | let node_index = NodeIndex(self.nodes.len()); 60 | 61 | let OpDescriptor { op, inputs: node_inputs, out_shapes, back_dep } = 62 | op.build(&self.ctx, &mut self.var_store).unwrap(); 63 | 64 | // Create output variables 65 | let mut outputs = vec![]; 66 | for (i, shape) in out_shapes.into_iter().enumerate() { 67 | let var_index = self.var_store.add(Tensor::new(self.ctx.as_ref(), shape, TensorMode::Mut)); 68 | outputs.push(var_index); 69 | self.out_var_map.insert(var_index, (node_index, i)); 70 | } 71 | let mut out_grad = vec![OutGrad::new(); outputs.len()]; 72 | // Set up inputs and gradients on inputs 73 | let mut inputs = vec![]; 74 | let mut in_grad = vec![]; 75 | for input in node_inputs { 76 | let (v, gradient) = 77 | match input { 78 | NodeInput::Var(v) => (v, self.add_gradient(v)), 79 | NodeInput::Recurrent(out) => { 80 | let v = outputs[out]; 81 | let gradient = Self::create_gradient(&self.ctx, &mut self.var_store, v, 82 | &mut out_grad[out]); 83 | (v, gradient) 84 | } 85 | }; 86 | inputs.push(v); 87 | in_grad.push(gradient); 88 | } 89 | // Create the node 90 | self.nodes.push(Node { inputs: inputs, 91 | outputs: outputs, 92 | in_grad: in_grad, 93 | out_grad: out_grad, 94 | back_dep: back_dep, 95 | back_dep_cache: vec![] }); 96 | // Add the corresponding node op 97 | self.node_ops.push(Box::new(op)); 98 | node_index 99 | } 100 | 101 | pub fn add_variable(&mut self, 102 | shape: Vec, 103 | learnable: bool, 104 | init: I) -> VarIndex { 105 | let a = init.init(&mut self.rng, shape.clone()); 106 | let v = self.var_store.add(Tensor::from_array(&self.ctx, &a, TensorMode::Mut)); 107 | self.in_var_map.insert(v, self.in_grad.len()); 108 | if learnable { 109 | self.learnables.push((v, GradIndex::InVar(self.in_grad.len()))); 110 | self.rnn_learnable_accum.push(self.var_store.add(Tensor::new(&self.ctx, shape, TensorMode::Mut))); 111 | } 112 | self.in_grad.push(OutGrad::new()); 113 | v 114 | } 115 | 116 | pub fn get_gradient(&self, v: VarIndex) -> GradIndex { 117 | match self.out_var_map.get(&v).map(|x| *x) { 118 | Some((node, out_index)) => { 119 | // v is the output of a node 120 | GradIndex::OutVar(node, out_index) 121 | }, 122 | None => { 123 | // v is an input to the graph - it has no corresponding node 124 | let in_grad_index = *self.in_var_map.get(&v) 125 | .expect("Variable is neither input nor output. 
Nonsense!"); 126 | GradIndex::InVar(in_grad_index) 127 | }, 128 | } 129 | } 130 | 131 | pub fn add_gradient(&mut self, v: VarIndex) -> VarIndex { 132 | match self.out_var_map.get(&v).map(|x| *x) { 133 | Some((node, out_index)) => { 134 | // v is the output of a node 135 | /*self.nodes[node.0].out_grad[out_index] 136 | .fork(&self.ctx, &mut self.var_store, gradient);*/ 137 | Self::create_gradient(&self.ctx, &mut self.var_store, v, 138 | &mut self.nodes[node.0].out_grad[out_index]) 139 | }, 140 | None => { 141 | // v is an input to the graph - it has no corresponding node 142 | let in_grad_index = *self.in_var_map.get(&v) 143 | .expect("Variable is neither input nor output. Nonsense!"); 144 | /*self.in_grad[in_grad_index] 145 | .fork(&self.ctx, &mut self.var_store, gradient);*/ 146 | Self::create_gradient(&self.ctx, &mut self.var_store, v, 147 | &mut self.in_grad[in_grad_index]) 148 | }, 149 | } 150 | } 151 | 152 | fn create_gradient(ctx: &ga::Context, 153 | var_store: &mut VarStore, 154 | v: VarIndex, 155 | out_grad: &mut OutGrad) 156 | -> VarIndex { 157 | let shape = var_store.get(v).shape().to_owned(); 158 | let gradient = var_store.add(Tensor::new(ctx, shape, TensorMode::Mut)); 159 | out_grad.fork(ctx, var_store, gradient); 160 | gradient 161 | } 162 | 163 | pub fn forward(&mut self) { 164 | // Forward pass 165 | // 166 | // NOTE: We just execute the nodes in order. We can do this because of the way the graph is 167 | // built. When a user wants to add a node, he/she must also supply the inputs. This means 168 | // any dependencies must already be added before the node can be added. Therefore, we can 169 | // assert that all dependents come after their dependencies in the `self.nodes` array. 170 | for (node, op) in self.nodes.iter().zip(&mut self.node_ops) { 171 | op.forward(&self.ctx, &self.var_store, node); 172 | } 173 | } 174 | 175 | pub fn backward(&mut self) { 176 | // Backward pass 177 | // We run through the nodes in reverse order. 
See note in Graph::forward 178 | for (node, op) in self.nodes.iter_mut().rev().zip(self.node_ops.iter_mut().rev()) { 179 | // Sum the gradients on each output if there are multiple gradients 180 | for out_grad in &node.out_grad { 181 | out_grad.maybe_sum(self.ctx.as_ref(), &mut self.var_store); 182 | } 183 | op.backward(&self.ctx, &mut self.var_store, node); 184 | } 185 | for grad in &self.in_grad { 186 | grad.maybe_sum(self.ctx.as_ref(), &mut self.var_store); 187 | } 188 | } 189 | 190 | pub fn forward_rnn(&mut self, t: usize) { 191 | for (node, op) in self.nodes.iter_mut().zip(&mut self.node_ops) { 192 | op.forward(&self.ctx, &self.var_store, node); 193 | let mut back_dep_step = vec![]; 194 | for back_dep in &node.back_dep { 195 | back_dep_step.push(self.var_store.get(*back_dep).get(&self.ctx)); 196 | } 197 | node.back_dep_cache.push(back_dep_step); 198 | } 199 | } 200 | 201 | pub fn backward_rnn(&mut self, t: usize) { 202 | for (node, op) in self.nodes.iter_mut().rev().zip(self.node_ops.iter_mut().rev()) { 203 | // Sum the gradients on each output if there are multiple gradients 204 | for out_grad in &node.out_grad { 205 | out_grad.maybe_sum(self.ctx.as_ref(), &self.var_store); 206 | } 207 | for (back_dep, cached) in node.back_dep.iter().zip(node.back_dep_cache[t].iter()) { 208 | self.var_store.get(*back_dep).set(&self.ctx, cached); 209 | } 210 | op.backward(&self.ctx, &self.var_store, node); 211 | } 212 | for (&(_, learn_grad), learn_accum) in self.learnables.iter().zip(self.rnn_learnable_accum.iter()) { 213 | if let GradIndex::InVar(in_grad_index) = learn_grad { 214 | ga::add(&self.ctx, &self.var_store.get(self.in_grad[in_grad_index].get()), -1, 215 | &self.var_store.get(*learn_accum), &self.var_store.get(*learn_accum)); 216 | } else { 217 | unreachable!(); 218 | } 219 | } 220 | } 221 | 222 | pub fn reset_rnn(&mut self) { 223 | for (&(_, learn_grad), learn_accum) in self.learnables.iter().zip(self.rnn_learnable_accum.iter()) { 224 | if let GradIndex::InVar(in_grad_index) = learn_grad { 225 | ga::copy_to(&self.ctx, &self.var_store.get(*learn_accum), 226 | &self.var_store.get(self.in_grad[in_grad_index].get())); 227 | } else { 228 | unreachable!(); 229 | } 230 | } 231 | for (node, op) in self.nodes.iter_mut().zip(&mut self.node_ops) { 232 | node.back_dep_cache.clear(); 233 | op.reset_rnn(&self.ctx, &mut self.var_store, node); 234 | } 235 | } 236 | 237 | pub fn context(&self) -> &ga::Context { 238 | &self.ctx 239 | } 240 | 241 | pub fn learnables(&self) -> &[(VarIndex, GradIndex)] { 242 | &self.learnables 243 | } 244 | } 245 | 246 | //////////////////////////////////////////////////////////////////////////////////////////////////// 247 | 248 | #[derive(Clone)] 249 | pub struct OutGrad { 250 | gradient: Option, // The gradient or sum of gradients 251 | gradients: Vec, 252 | } 253 | 254 | impl OutGrad { 255 | pub fn new() -> Self { 256 | OutGrad { 257 | gradient: None, 258 | gradients: vec![], 259 | } 260 | } 261 | 262 | pub fn get(&self) -> VarIndex { 263 | self.gradient.unwrap() 264 | } 265 | 266 | pub fn try_get(&self) -> Option { 267 | self.gradient 268 | } 269 | 270 | fn maybe_sum(&self, ctx: &ga::Context, var_store: &VarStore) { 271 | if self.gradients.len() > 0 { 272 | if let Some(sum) = self.gradient { 273 | ga::copy_to(ctx, &var_store.get(self.gradients[0]), &var_store.get(sum)); 274 | for grad in &self.gradients[1..] 
{ 275 | ga::add(ctx, &var_store.get(sum), -1, &var_store.get(*grad), &var_store.get(sum)); 276 | } 277 | } 278 | } 279 | } 280 | 281 | fn fork(&mut self, ctx: &ga::Context, var_store: &mut VarStore, v: VarIndex) { 282 | if self.gradients.len() > 0 { 283 | // There are multiple gradients already, just add the new one to the list 284 | self.gradients.push(v); 285 | } else if let Some(gradient) = self.gradient { 286 | // There is still only one gradient, switch it to a fork 287 | let shape = { 288 | let grad = var_store.get(gradient); 289 | grad.shape().to_vec() 290 | }; 291 | // Create variable for gradient sum 292 | self.gradient = Some(var_store.add(Tensor::new(ctx, shape, TensorMode::Mut))); 293 | self.gradients.push(gradient); 294 | self.gradients.push(v); 295 | } else { 296 | // This is the only gradient so far, so we don't need to sum anything 297 | self.gradient = Some(v); 298 | } 299 | } 300 | } 301 | 302 | #[derive(Copy, Clone)] 303 | pub enum GradIndex { 304 | InVar(usize), 305 | OutVar(NodeIndex, usize), 306 | } 307 | 308 | impl GradIndex { 309 | pub fn get<'a>(&self, graph: &'a Graph) -> Ref<'a, Tensor> { 310 | match *self { 311 | GradIndex::InVar(in_grad_index) => { 312 | graph.in_grad[in_grad_index].get().get(graph) 313 | }, 314 | GradIndex::OutVar(node, out_index) => { 315 | node.get(graph).out_grad[out_index].get().get(graph) 316 | }, 317 | } 318 | } 319 | 320 | pub fn read(&self, g: &Graph, a: &mut Array) { 321 | self.get(g).read(g.context(), a); 322 | } 323 | } 324 | 325 | //////////////////////////////////////////////////////////////////////////////////////////////////// 326 | 327 | #[derive(Copy, Clone)] 328 | pub struct NodeIndex(usize); 329 | 330 | impl NodeIndex { 331 | pub fn get<'a>(&self, g: &'a Graph) -> &'a Node { 332 | &g.nodes[self.0] 333 | } 334 | } 335 | 336 | #[test] 337 | fn it_works() { 338 | use super::op::MatMul; 339 | 340 | let ctx = Rc::new(ga::Context::new()); 341 | 342 | // Setup the graph 343 | let mut graph = Graph::new(ctx.clone()); 344 | let a = graph.add_variable(vec![1, 2], true, vec![1.4, 0.3]); 345 | let wa = graph.add_variable(vec![2, 3], true, vec![0.5, 0.3, 0.2, 346 | 0.6, 0.7, 0.7]); 347 | let node = graph.add_node(MatMul(a, wa)); 348 | let node_out = node.get(&graph).outputs[0]; 349 | let node_g = graph.add_gradient(node_out); 350 | 351 | // Send some input data 352 | let node_g_cpu = ga::Array::from_vec(vec![1, 3], vec![1.0, -1.0, 0.5]); 353 | node_g.get(&graph).set(&ctx, &node_g_cpu); 354 | 355 | // Run the network 356 | graph.forward(); 357 | graph.backward(); 358 | let out = node.get(&graph).outputs[0].get(&graph).get(&ctx); 359 | let wa_d = graph.get_gradient(wa).get(&graph).get(&ctx); 360 | println!("out = {:?}", out); 361 | println!("wa_d = {:?}", wa_d); 362 | assert!(out.buffer() == &[0.88, 0.63, 0.49]); 363 | assert!(wa_d.buffer() == &[1.4, -1.4, 0.7, 364 | 0.3, -0.3, 0.15]); 365 | } 366 | -------------------------------------------------------------------------------- /src/init.rs: -------------------------------------------------------------------------------- 1 | use ga::Array; 2 | use rand; 3 | 4 | pub trait Initializer { 5 | fn init(self, rng: &mut rand::ThreadRng, shape: Vec) -> Array; 6 | } 7 | 8 | //////////////////////////////////////////////////////////////////////////////////////////////////// 9 | 10 | impl Initializer for f32 { 11 | fn init(self, _: &mut rand::ThreadRng, shape: Vec) -> Array { 12 | Array::new(shape, self) 13 | } 14 | } 15 | 16 | 
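// The Initializer trait above makes it straightforward to plug in custom
// weight-initialization schemes. As an illustrative sketch only (not part of
// this crate), a Xavier/Glorot-style uniform initializer can follow the same
// pattern as the built-in Uniform and Normal initializers below. The name
// `XavierUniform` is hypothetical, and like the built-ins it assumes a 2-D
// [fan_in, fan_out] shape.
pub struct XavierUniform;

impl Initializer for XavierUniform {
    fn init(self, rng: &mut rand::ThreadRng, shape: Vec<usize>) -> Array<f32> {
        use rand::Rng;

        // limit = sqrt(6 / (fan_in + fan_out))
        let limit = (6.0 / (shape[0] + shape[1]) as f32).sqrt();
        let vec = (0..shape[0]*shape[1])
                      .map(|_| rng.next_f32()*2.0*limit - limit)
                      .collect();
        Array::from_vec(shape, vec)
    }
}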
//////////////////////////////////////////////////////////////////////////////////////////////////// 17 | 18 | impl Initializer for Vec { 19 | fn init(self, _: &mut rand::ThreadRng, shape: Vec) -> Array { 20 | Array::from_vec(shape, self) 21 | } 22 | } 23 | 24 | //////////////////////////////////////////////////////////////////////////////////////////////////// 25 | 26 | /// Uniform(min, max) 27 | pub struct Uniform(pub f32, pub f32); 28 | 29 | impl Initializer for Uniform { 30 | fn init(self, rng: &mut rand::ThreadRng, shape: Vec) -> Array { 31 | use rand::Rng; 32 | 33 | let Uniform(min, max) = self; 34 | let vec = (0..shape[0]*shape[1]).map(|_| rng.next_f32()*(max-min) + min).collect(); 35 | Array::from_vec(shape, vec) 36 | } 37 | } 38 | 39 | //////////////////////////////////////////////////////////////////////////////////////////////////// 40 | 41 | /// Normal(mean, standard deviation) 42 | pub struct Normal(pub f32, pub f32); 43 | 44 | impl Initializer for Normal { 45 | fn init(self, rng: &mut rand::ThreadRng, shape: Vec) -> Array { 46 | use rand::distributions::Sample; 47 | 48 | let Normal(mean, std_dev) = self; 49 | let mut dist = rand::distributions::Normal::new(mean as f64, std_dev as f64); 50 | let vec = (0..shape[0]*shape[1]).map(|_| dist.sample(rng) as f32).collect(); 51 | Array::from_vec(shape, vec) 52 | } 53 | } 54 | -------------------------------------------------------------------------------- /src/layers.rs: -------------------------------------------------------------------------------- 1 | use graph::Graph; 2 | use init::Initializer; 3 | use op::{Add, CrossEntropy, Lstm, LstmUnrolled, MatMul, Mse, OpBuilder}; 4 | use var_store::VarIndex; 5 | 6 | pub fn dense(graph: &mut Graph, 7 | input: VarIndex, 8 | layer_size: usize, 9 | w_init: WI) 10 | -> (VarIndex, VarIndex) { 11 | // Input shape is [batch_size x input_size] 12 | let input_size = input.get(graph).shape()[1]; 13 | 14 | // Weights for layer 1: [input_size x layer_size] 15 | let weights = graph.add_variable(vec![input_size, layer_size], true, w_init); 16 | 17 | // Use matrix multiplication to do a fully connected layer 18 | let mat_mul = graph.add_node(MatMul(input, weights)); 19 | let mat_mul_out = mat_mul.get(&graph).outputs[0]; 20 | 21 | (mat_mul_out, weights) 22 | } 23 | 24 | pub fn dense_biased(graph: &mut Graph, 25 | input: VarIndex, 26 | layer_size: usize, 27 | w_init: WI, 28 | b_init: BI) 29 | -> (VarIndex, VarIndex, VarIndex) { 30 | // Input shape is [batch_size x input_size] 31 | let input_size = input.get(graph).shape()[1]; 32 | 33 | // Weights for layer 1: [input_size x layer_size] 34 | let weights = graph.add_variable(vec![input_size, layer_size], true, w_init); 35 | 36 | // Use matrix multiplication to do a fully connected layer 37 | let mat_mul = graph.add_node(MatMul(input, weights)); 38 | let mat_mul_out = mat_mul.get(&graph).outputs[0]; 39 | 40 | // Biases, one for each neuron in layer 41 | let bias = graph.add_variable(vec![1, layer_size], true, b_init); 42 | // Add the biases to the matrix multiplication output 43 | let biased = graph.add_node(Add(mat_mul_out, bias, 0)); 44 | // Grab VarIndex for biased's output 45 | let biased_out = biased.get(&graph).outputs[0]; 46 | 47 | (biased_out, weights, bias) 48 | } 49 | 50 | pub fn dense_biased_manual(graph: &mut Graph, 51 | input: VarIndex, 52 | weights: VarIndex, 53 | bias: VarIndex) 54 | -> VarIndex { 55 | // Use matrix multiplication to do a fully connected layer 56 | let mat_mul = graph.add_node(MatMul(input, weights)); 57 | let mat_mul_out = 
mat_mul.get(&graph).outputs[0]; 58 | 59 | // Add the biases to the matrix multiplication output 60 | let biased = graph.add_node(Add(mat_mul_out, bias, 0)); 61 | // Grab VarIndex for biased's output 62 | let biased_out = biased.get(&graph).outputs[0]; 63 | 64 | biased_out 65 | } 66 | 67 | pub fn activation(graph: &mut Graph, op: A) -> VarIndex { 68 | // Run the biased input*weight sums through an ReLU activation 69 | let activation = graph.add_node(op); 70 | // Grab VarIndex for l2_relu's output 71 | let activation_out = activation.get(&graph).outputs[0]; 72 | 73 | activation_out 74 | } 75 | 76 | pub fn lstm(graph: &mut Graph, 77 | input: VarIndex, 78 | layer_size: usize, 79 | w_init: WI) 80 | -> (VarIndex, VarIndex) { 81 | // Input shape is [batch_size, input_size] 82 | let input_size = input.get(graph).shape()[1]; 83 | 84 | // Weights for layer 1: [1+input_size+layer_size, 4*layer_size] 85 | let weights = graph.add_variable(vec![1+input_size+layer_size, 4*layer_size], true, w_init); 86 | 87 | // Use matrix multiplication to do a fully connected layer 88 | let lstm = graph.add_node(Lstm(input, weights, layer_size)); 89 | let lstm_out = lstm.get(&graph).outputs[0]; 90 | 91 | (lstm_out, weights) 92 | } 93 | 94 | pub fn lstm_unrolled(graph: &mut Graph, 95 | input: VarIndex, 96 | weights: VarIndex, 97 | prev_h: VarIndex, 98 | prev_c: VarIndex) 99 | -> (VarIndex, VarIndex) { 100 | // Use matrix multiplication to do a fully connected layer 101 | let lstm = graph.add_node(LstmUnrolled(input, weights, prev_h, prev_c)); 102 | let lstm_out = lstm.get(&graph).outputs[0]; 103 | let c = lstm.get(&graph).outputs[1]; 104 | 105 | (lstm_out, c) 106 | } 107 | 108 | pub fn mse(graph: &mut Graph, out: VarIndex) -> (VarIndex, VarIndex) { 109 | let out_shape = out.get(graph).shape().to_owned(); 110 | 111 | // Expected output 112 | let train_out = graph.add_variable(out_shape, false, 0.0); 113 | 114 | let loss = graph.add_node(Mse(out, train_out)); 115 | let loss_out = loss.get(&graph).outputs[0]; 116 | 117 | (loss_out, train_out) 118 | } 119 | 120 | pub fn cross_entropy(graph: &mut Graph, out: VarIndex) -> (VarIndex, VarIndex) { 121 | let out_shape = out.get(graph).shape().to_owned(); 122 | 123 | // Expected output 124 | let train_out = graph.add_variable(out_shape, false, 0.0); 125 | 126 | let loss = graph.add_node(CrossEntropy(out, train_out)); 127 | let loss_out = loss.get(&graph).outputs[0]; 128 | 129 | (loss_out, train_out) 130 | } 131 | -------------------------------------------------------------------------------- /src/lib.rs: -------------------------------------------------------------------------------- 1 | #[macro_use] extern crate gpuarray as ga; 2 | extern crate num; 3 | extern crate rand; 4 | 5 | pub use graph::{Graph, NodeIndex}; 6 | pub use var_store::VarIndex; 7 | pub use op::Operation; 8 | pub use train::Trainer; 9 | 10 | pub mod graph; 11 | pub mod init; 12 | pub mod layers; 13 | pub mod op; 14 | pub mod train; 15 | pub mod util; 16 | pub mod var_store; 17 | -------------------------------------------------------------------------------- /src/op.rs: -------------------------------------------------------------------------------- 1 | use ga::{self, Tensor}; 2 | use ga::tensor::TensorMode; 3 | 4 | use super::graph::{Node, NodeInput}; 5 | use super::var_store::{VarIndex, VarStore}; 6 | 7 | pub trait Operation : 'static { 8 | fn forward(&mut self, &ga::Context, &VarStore, &Node); 9 | fn backward(&mut self, &ga::Context, &VarStore, &Node); 10 | fn reset_rnn(&mut self, _: &ga::Context, _: 
&VarStore, _: &Node) { } 11 | } 12 | 13 | pub trait OpBuilder { 14 | type Op: Operation; 15 | 16 | fn build(&self, ctx: &ga::Context, v: &mut VarStore) 17 | -> Result, String>; 18 | } 19 | 20 | pub struct OpDescriptor { 21 | pub op: T, 22 | pub inputs: Vec, 23 | pub out_shapes: Vec>, 24 | pub back_dep: Vec, // Dependencies for backward pass 25 | } 26 | 27 | //////////////////////////////////////////////////////////////////////////////////////////////////// 28 | 29 | pub struct MatMul(pub VarIndex, pub VarIndex); 30 | 31 | impl OpBuilder for MatMul { 32 | type Op = MatMulImpl; 33 | 34 | fn build(&self, ctx: &ga::Context, v: &mut VarStore) 35 | -> Result, String> { 36 | let a_shape = v.get(self.0).shape().to_vec(); 37 | let b_shape = v.get(self.1).shape().to_vec(); 38 | if a_shape[1] != b_shape[0] { 39 | return Err(format!("DIM ERROR: Shapes must be of form [I, J] and [J, K] 40 | (got {:?} and {:?}) for MatMul", 41 | a_shape, b_shape)); 42 | } 43 | let out_shape = vec![a_shape[0], b_shape[1]]; 44 | Ok(OpDescriptor { 45 | op: MatMulImpl::new(ctx, v, a_shape, b_shape), 46 | inputs: vec![NodeInput::Var(self.0), NodeInput::Var(self.1)], 47 | out_shapes: vec![out_shape], 48 | back_dep: vec![self.0, self.1], 49 | }) 50 | } 51 | } 52 | 53 | pub struct MatMulImpl { 54 | a_t: VarIndex, 55 | b_t: VarIndex, 56 | } 57 | 58 | impl MatMulImpl { 59 | pub fn new(ctx: &ga::Context, v: &mut VarStore, a_shape: Vec, b_shape: Vec) -> Self { 60 | MatMulImpl { 61 | a_t: v.add(Tensor::new(ctx, vec![a_shape[1], a_shape[0]], TensorMode::Mut)), 62 | b_t: v.add(Tensor::new(ctx, vec![b_shape[1], b_shape[0]], TensorMode::Mut)), 63 | } 64 | } 65 | } 66 | 67 | impl Operation for MatMulImpl { 68 | fn forward(&mut self, ctx: &ga::Context, v: &VarStore, n: &Node) { 69 | let a = &v.get(n.inputs[0]); 70 | let b = &v.get(n.inputs[1]); 71 | let c = &v.get(n.outputs[0]); 72 | ga::matmul(ctx, a, b, c); // c = a*b 73 | } 74 | 75 | fn backward(&mut self, ctx: &ga::Context, v: &VarStore, n: &Node) { 76 | let a = &v.get(n.inputs[0]); 77 | let b = &v.get(n.inputs[1]); 78 | let a_d = &v.get(n.in_grad[0]); 79 | let b_d = &v.get(n.in_grad[1]); 80 | let g = &v.get(n.out_grad[0].get()); 81 | 82 | // Derivative with respect to first input 83 | // a_d = g*b_t 84 | ga::transpose(ctx, b, &v.get(self.b_t)); 85 | ga::matmul(ctx, g, &v.get(self.b_t), a_d); 86 | 87 | // Derivative with respect to second input 88 | // b_d = a_t*g 89 | ga::transpose(ctx, a, &v.get(self.a_t)); 90 | ga::matmul(ctx, &v.get(self.a_t), g, b_d); 91 | } 92 | } 93 | 94 | //////////////////////////////////////////////////////////////////////////////////////////////////// 95 | 96 | pub struct Add(pub VarIndex, pub VarIndex, pub i32); 97 | 98 | impl OpBuilder for Add { 99 | type Op = AddImpl; 100 | 101 | fn build(&self, _: &ga::Context, v: &mut VarStore) 102 | -> Result, String> { 103 | let a = &v.get(self.0); 104 | let b = &v.get(self.1); 105 | let add_axis = self.2; 106 | match add_axis { 107 | -1 => { 108 | if a.shape() != b.shape() { 109 | return Err("DIM ERROR: Shapes must be equal for Add".to_string()); 110 | } 111 | }, 112 | 0 => { 113 | if b.shape()[0] != 1 || a.shape()[1] != b.shape()[1] { 114 | return Err(format!("DIM ERROR: Shapes must be [M, N] and [1, N] 115 | (got {:?} and {:?}) for Add with broadcast axis of 0", 116 | a.shape(), b.shape())); 117 | } 118 | }, 119 | 1 => { 120 | if b.shape()[1] != 1 || a.shape()[0] != b.shape()[0] { 121 | return Err(format!("DIM ERROR: Shapes must be [M, N] and [M, 1] 122 | (got {:?} and {:?}) for Add with broadcast axis of 1", 123 
| a.shape(), b.shape())); 124 | } 125 | }, 126 | _ => { 127 | return Err(format!("BROADCAST AXIS ERROR: Invalid broadcast axis {}", add_axis)); 128 | } 129 | } 130 | Ok(OpDescriptor { 131 | op: AddImpl::new(add_axis), 132 | inputs: vec![NodeInput::Var(self.0), NodeInput::Var(self.1)], 133 | out_shapes: vec![a.shape().to_vec()], 134 | back_dep: vec![], 135 | }) 136 | } 137 | } 138 | 139 | pub struct AddImpl { 140 | axis: i32, 141 | } 142 | 143 | impl AddImpl { 144 | pub fn new(axis: i32) -> Self { 145 | AddImpl { 146 | axis: axis, 147 | } 148 | } 149 | } 150 | 151 | impl Operation for AddImpl { 152 | fn forward(&mut self, ctx: &ga::Context, v: &VarStore, n: &Node) { 153 | let a = &v.get(n.inputs[0]); 154 | let b = &v.get(n.inputs[1]); 155 | let c = &v.get(n.outputs[0]); 156 | ga::add(ctx, a, self.axis, b, c); // c = a+b 157 | } 158 | 159 | fn backward(&mut self, ctx: &ga::Context, v: &VarStore, n: &Node) { 160 | let a_d = &v.get(n.in_grad[0]); 161 | let b_d = &v.get(n.in_grad[1]); 162 | let g = &v.get(n.out_grad[0].get()); 163 | ga::copy_to(ctx, g, a_d); 164 | ga::sum(ctx, g, self.axis as usize, b_d); 165 | } 166 | } 167 | 168 | //////////////////////////////////////////////////////////////////////////////////////////////////// 169 | 170 | // Softmax(input) 171 | pub struct Softmax(pub VarIndex); 172 | 173 | impl OpBuilder for Softmax { 174 | type Op = SoftmaxImpl; 175 | 176 | fn build(&self, ctx: &ga::Context, v: &mut VarStore) 177 | -> Result, String> { 178 | let batches = v.get(self.0).shape()[0]; 179 | let classes = v.get(self.0).shape()[1]; 180 | Ok(OpDescriptor { 181 | op: SoftmaxImpl::new(ctx, v, batches), 182 | inputs: vec![NodeInput::Var(self.0)], 183 | out_shapes: vec![vec![batches, classes]], 184 | back_dep: vec![], 185 | }) 186 | } 187 | } 188 | 189 | pub struct SoftmaxImpl { 190 | exp_sum: VarIndex, 191 | } 192 | 193 | impl SoftmaxImpl { 194 | pub fn new(ctx: &ga::Context, v: &mut VarStore, batches: usize) -> Self { 195 | SoftmaxImpl { 196 | exp_sum: v.add(Tensor::new(ctx, vec![batches, 1], TensorMode::Mut)), 197 | } 198 | } 199 | } 200 | 201 | impl Operation for SoftmaxImpl { 202 | fn forward(&mut self, ctx: &ga::Context, v: &VarStore, n: &Node) { 203 | let input = &v.get(n.inputs[0]); 204 | let prob = &v.get(n.outputs[0]); 205 | 206 | let exp_sum = &v.get(self.exp_sum); 207 | 208 | ga::exp(ctx, input, prob); 209 | ga::sum(ctx, prob, 1, exp_sum); 210 | //println!("{:?}", input.get(ctx)); 211 | //println!("{:?}", exp_sum.get(ctx)[&[0, 0]]); 212 | if exp_sum.get(ctx)[&[0, 0]].is_nan() { 213 | panic!("NaN :("); 214 | } 215 | if exp_sum.get(ctx)[&[0, 0]] == 0.0 { 216 | panic!("Zero :("); 217 | } 218 | ga::divide(ctx, prob, 1, exp_sum, prob); 219 | } 220 | 221 | fn backward(&mut self, ctx: &ga::Context, v: &VarStore, n: &Node) { 222 | let input_d = &v.get(n.in_grad[0]); 223 | let prob_d = &v.get(n.out_grad[0].get()); 224 | 225 | ga::copy_to(ctx, prob_d, input_d); 226 | } 227 | } 228 | 229 | //////////////////////////////////////////////////////////////////////////////////////////////////// 230 | 231 | // CrossEntropy(prob, true_prob) 232 | pub struct CrossEntropy(pub VarIndex, pub VarIndex); 233 | 234 | impl OpBuilder for CrossEntropy { 235 | type Op = CrossEntropyImpl; 236 | 237 | fn build(&self, _: &ga::Context, v: &mut VarStore) 238 | -> Result, String> { 239 | let prob = &v.get(self.0); 240 | let true_prob = &v.get(self.1); 241 | if prob.shape() != true_prob.shape() { 242 | return Err("DIM ERROR: Shapes must be equal for CrossEntropy".to_string()); 243 | } 244 | Ok(OpDescriptor 
{ 245 | op: CrossEntropyImpl::new(), 246 | inputs: vec![NodeInput::Var(self.0), NodeInput::Var(self.1)], 247 | out_shapes: vec![prob.shape().to_vec()], 248 | back_dep: vec![], 249 | }) 250 | } 251 | } 252 | 253 | pub struct CrossEntropyImpl; 254 | 255 | impl CrossEntropyImpl { 256 | pub fn new() -> Self { 257 | CrossEntropyImpl 258 | } 259 | } 260 | 261 | impl Operation for CrossEntropyImpl { 262 | fn forward(&mut self, ctx: &ga::Context, v: &VarStore, n: &Node) { 263 | let prob = &v.get(n.inputs[0]); 264 | let true_prob = &v.get(n.inputs[1]); 265 | let loss = &v.get(n.outputs[0]); 266 | 267 | ga::log(ctx, prob, loss); 268 | ga::negate(ctx, loss, loss); 269 | ga::multiply(ctx, loss, -1, true_prob, loss); 270 | } 271 | 272 | fn backward(&mut self, ctx: &ga::Context, v: &VarStore, n: &Node) { 273 | let prob = &v.get(n.inputs[0]); 274 | let true_prob = &v.get(n.inputs[1]); 275 | 276 | let prob_d = &v.get(n.in_grad[0]); 277 | let loss_d = &v.get(n.out_grad[0].get()); 278 | 279 | ga::sub(ctx, prob, true_prob, prob_d); 280 | ga::multiply(ctx, prob_d, -1, loss_d, prob_d); 281 | } 282 | } 283 | 284 | //////////////////////////////////////////////////////////////////////////////////////////////////// 285 | 286 | pub struct Relu(pub VarIndex); 287 | 288 | impl OpBuilder for Relu { 289 | type Op = ReluImpl; 290 | 291 | fn build(&self, _: &ga::Context, v: &mut VarStore) 292 | -> Result, String> { 293 | let a = &v.get(self.0); 294 | Ok(OpDescriptor { 295 | op: ReluImpl, 296 | inputs: vec![NodeInput::Var(self.0)], 297 | out_shapes: vec![a.shape().to_vec()], 298 | back_dep: vec![self.0], 299 | }) 300 | } 301 | } 302 | 303 | pub struct ReluImpl; 304 | 305 | impl Operation for ReluImpl { 306 | fn forward(&mut self, ctx: &ga::Context, v: &VarStore, n: &Node) { 307 | let a = &v.get(n.inputs[0]); 308 | let b = &v.get(n.outputs[0]); 309 | ga::max(ctx, a, 0.0, b); // b = max(0, a) 310 | } 311 | 312 | fn backward(&mut self, ctx: &ga::Context, v: &VarStore, n: &Node) { 313 | let a = &v.get(n.inputs[0]); 314 | let a_d = &v.get(n.in_grad[0]); 315 | let g = &v.get(n.out_grad[0].get()); 316 | ga::dmax(ctx, a, 0.0, a_d); 317 | ga::multiply(ctx, g, -1, a_d, a_d); 318 | } 319 | } 320 | 321 | //////////////////////////////////////////////////////////////////////////////////////////////////// 322 | 323 | pub struct Mse(pub VarIndex, pub VarIndex); 324 | 325 | impl OpBuilder for Mse { 326 | type Op = MseImpl; 327 | 328 | fn build(&self, _: &ga::Context, v: &mut VarStore) 329 | -> Result, String> { 330 | let a = &v.get(self.0); 331 | let b = &v.get(self.1); 332 | if a.shape() != b.shape() { 333 | return Err("DIM ERROR: Shapes must be equal for MSE".to_string()); 334 | } 335 | Ok(OpDescriptor { 336 | op: MseImpl, 337 | inputs: vec![NodeInput::Var(self.0), NodeInput::Var(self.1)], 338 | out_shapes: vec![a.shape().to_vec()], 339 | back_dep: vec![self.0, self.1], 340 | }) 341 | } 342 | } 343 | 344 | pub struct MseImpl; 345 | 346 | impl Operation for MseImpl { 347 | fn forward(&mut self, ctx: &ga::Context, v: &VarStore, n: &Node) { 348 | let h = &v.get(n.inputs[0]); // predictions 349 | let y = &v.get(n.inputs[1]); // training output 350 | let out = &v.get(n.outputs[0]); 351 | ga::mse(ctx, h, y, out); // out = mse(h, y) 352 | } 353 | 354 | fn backward(&mut self, ctx: &ga::Context, v: &VarStore, n: &Node) { 355 | let h = &v.get(n.inputs[0]); // predictions 356 | let h_d = &v.get(n.in_grad[0]); 357 | let y = &v.get(n.inputs[1]); // training output 358 | let g = &v.get(n.out_grad[0].get()); 359 | ga::dmse(ctx, h, y, h_d); // h_d = 
dmse(h, y) 360 | ga::multiply(ctx, h_d, 0, g, h_d); // h_d = g*h_d 361 | } 362 | } 363 | 364 | //////////////////////////////////////////////////////////////////////////////////////////////////// 365 | 366 | // Lstm(input, state) 367 | pub struct Lstm(pub VarIndex, pub VarIndex, pub usize); 368 | 369 | impl OpBuilder for Lstm { 370 | type Op = LstmImpl; 371 | 372 | fn build(&self, ctx: &ga::Context, v: &mut VarStore) 373 | -> Result, String> { 374 | let batch_size = v.get(self.0).shape()[0]; 375 | let input_size = v.get(self.0).shape()[1]; 376 | let hidden_size = self.2; 377 | if v.get(self.1).shape()[0] != 1+input_size+hidden_size || v.get(self.1).shape()[1] != 4*hidden_size { 378 | return Err(format!("DIM ERROR: LSTM expects weight matrix shape of 379 | [1+input_size+hidden_size, 4*hidden_size], got {:?}", 380 | v.get(self.1).shape())); 381 | } 382 | let lstm_impl = LstmImpl::new(ctx, v, batch_size, input_size, hidden_size); 383 | let h_in = lstm_impl.h_in; 384 | let ifog_f = lstm_impl.ifog_f; 385 | Ok(OpDescriptor { 386 | op: lstm_impl, 387 | inputs: vec![NodeInput::Var(self.0), NodeInput::Var(self.1), 388 | NodeInput::Recurrent(0), NodeInput::Recurrent(1)], 389 | out_shapes: vec![vec![batch_size, hidden_size], vec![batch_size, hidden_size]], 390 | back_dep: vec![h_in, ifog_f], 391 | }) 392 | } 393 | } 394 | 395 | // LstmUnrolled(input, state, prev_h, prev_c) 396 | pub struct LstmUnrolled(pub VarIndex, pub VarIndex, pub VarIndex, pub VarIndex); 397 | 398 | impl OpBuilder for LstmUnrolled { 399 | type Op = LstmImpl; 400 | 401 | fn build(&self, ctx: &ga::Context, v: &mut VarStore) 402 | -> Result, String> { 403 | let batch_size = v.get(self.0).shape()[0]; 404 | let input_size = v.get(self.0).shape()[1]; 405 | let hidden_size = v.get(self.2).shape()[1]; 406 | if v.get(self.1).shape()[0] != 1+input_size+hidden_size || v.get(self.1).shape()[1] != 4*hidden_size { 407 | return Err(format!("DIM ERROR: LSTM expects weight matrix shape of 408 | [1+input_size+hidden_size, 4*hidden_size], got {:?}", 409 | v.get(self.1).shape())); 410 | } 411 | let lstm_impl = LstmImpl::new(ctx, v, batch_size, input_size, hidden_size); 412 | let h_in = lstm_impl.h_in; 413 | let ifog_f = lstm_impl.ifog_f; 414 | Ok(OpDescriptor { 415 | op: lstm_impl, 416 | inputs: vec![NodeInput::Var(self.0), NodeInput::Var(self.1), 417 | NodeInput::Var(self.2), NodeInput::Var(self.3)], 418 | out_shapes: vec![vec![batch_size, hidden_size], vec![batch_size, hidden_size]], 419 | back_dep: vec![h_in, ifog_f], 420 | }) 421 | } 422 | } 423 | 424 | pub struct LstmImpl { 425 | h_in: VarIndex, // Concatonated input and previous output 426 | d_h_in: VarIndex, 427 | ifog: VarIndex, // input, forget, output, gate (IFOG): input sums 428 | d_ifog: VarIndex, 429 | ifog_f: VarIndex, // input, forget, output, gate: activations 430 | d_ifog_f: VarIndex, 431 | d_c_inner: VarIndex, 432 | c_f: VarIndex, // tanh of C 433 | d_c_f: VarIndex, 434 | 435 | h_in_t: VarIndex, 436 | wlstm_t: VarIndex, 437 | 438 | input_size: usize, 439 | hidden_size: usize, 440 | } 441 | 442 | impl LstmImpl { 443 | fn new(ctx: &ga::Context, 444 | v: &mut VarStore, 445 | batch_size: usize, 446 | input_size: usize, 447 | hidden_size: usize) -> Self { 448 | let b = batch_size; 449 | let d = hidden_size; 450 | 451 | let h_in = Tensor::new(ctx, vec![b, 1 + input_size + d], TensorMode::Mut); 452 | 453 | // Fill first column of h_in with 1's to be multiplied by the biases in the weights matrix 454 | ga::fill_slice(ctx, &h_in.slice(s![.., 0]), 1.0); 455 | 456 | LstmImpl { 457 | h_in: 
v.add(h_in), 458 | d_h_in: v.add(Tensor::new(ctx, vec![b, 1 + input_size + d], TensorMode::Mut)), 459 | ifog: v.add(Tensor::new(ctx, vec![b, d*4], TensorMode::Mut)), 460 | d_ifog: v.add(Tensor::new(ctx, vec![b, d*4], TensorMode::Mut)), 461 | ifog_f: v.add(Tensor::new(ctx, vec![b, d*4], TensorMode::Mut)), 462 | d_ifog_f: v.add(Tensor::new(ctx, vec![b, d*4], TensorMode::Mut)), 463 | d_c_inner: v.add(Tensor::new(ctx, vec![b, d], TensorMode::Mut)), 464 | c_f: v.add(Tensor::new(ctx, vec![b, d], TensorMode::Mut)), 465 | d_c_f: v.add(Tensor::new(ctx, vec![b, d], TensorMode::Mut)), 466 | 467 | h_in_t: v.add(Tensor::new(ctx, vec![1+input_size+d, b], TensorMode::Mut)), 468 | wlstm_t: v.add(Tensor::new(ctx, vec![4*d, 1+input_size+d], TensorMode::Mut)), 469 | 470 | input_size: input_size, 471 | hidden_size: hidden_size, 472 | } 473 | } 474 | } 475 | 476 | impl Operation for LstmImpl { 477 | fn forward(&mut self, ctx: &ga::Context, v: &VarStore, node: &Node) { 478 | let d = self.hidden_size; 479 | let input_size = self.input_size; 480 | 481 | let x = &v.get(node.inputs[0]); // input 482 | let wlstm = &v.get(node.inputs[1]); // all of the weights for all the cells 483 | let prev_h = &v.get(node.inputs[2]); // Output from last timestep 484 | let prev_c = &v.get(node.inputs[3]); // C from last timestep 485 | 486 | let ref h_in = v.get(self.h_in); 487 | 488 | let ref ifog = v.get(self.ifog); 489 | let ref ifog_f = v.get(self.ifog_f); 490 | let ref c_f = v.get(self.c_f); 491 | 492 | let h = &v.get(node.outputs[0]); // output 493 | let c = &v.get(node.outputs[1]); // cell 494 | 495 | // NOTE: unless the layer is unrolled, c and prev_c are actually the same underlying buffer. 496 | // We use different aliases for clarity. 497 | // NOTE: unless the layer is unrolled, h and prev_h are actually the same underlying buffer. 498 | // We use different aliases for clarity. 499 | 500 | // Input 501 | ga::copy_to_slice(ctx, &x.slice(s![..]), &h_in.slice(s![.., 1..input_size+1])); 502 | ga::copy_to_slice(ctx, &prev_h.slice(s![..]), &h_in.slice(s![.., input_size+1..])); 503 | // Multiply inputs and weights, and add biases - all in one dot product! 504 | ga::matmul(ctx, &h_in, &wlstm, &ifog); 505 | // Compute internal activations 506 | ga::sigmoid_slice(ctx, &ifog.slice(s![.., ..3*d]), &ifog_f.slice(s![.., ..3*d])); // sigmoids 507 | ga::tanh_slice(ctx, &ifog.slice(s![.., 3*d..]), &ifog_f.slice(s![.., 3*d..])); // tanh 508 | // compute the LSTM cell activation 509 | // NOTE: we're using c_f as a temporary buffer here - we overwrite it later anyway 510 | // c[t] = ifog_f[t, .., ..d]*ifog_f[t, .., 3*d..] 
+ ifog_f[t, .., d..2*d]*c[t-1] 511 | ga::multiply_slice(ctx, &ifog_f.slice(s![.., d..2*d]), &prev_c.slice(s![..]), &c_f.slice(s![..])); 512 | ga::multiply_slice(ctx, &ifog_f.slice(s![.., ..d]), &ifog_f.slice(s![.., 3*d..]), &c.slice(s![..])); 513 | ga::add(ctx, c, -1, c_f, c); 514 | // c_f[t] = tanh(c[t]) 515 | ga::tanh(ctx, &c, &c_f); 516 | ga::multiply_slice(ctx, &ifog_f.slice(s![.., 2*d..3*d]), &c_f.slice(s![..]), &h.slice(s![..])); 517 | } 518 | 519 | fn backward(&mut self, ctx: &ga::Context, v: &VarStore, node: &Node) { 520 | let d = self.hidden_size; 521 | let input_size = self.input_size; 522 | 523 | let wlstm = &v.get(node.inputs[1]); // all of the weights for all the cells 524 | let prev_c = &v.get(node.inputs[3]); // C from last timestep 525 | 526 | let d_x = &v.get(node.in_grad[0]); 527 | let d_wlstm = &v.get(node.in_grad[1]); 528 | let d_prev_h = &v.get(node.in_grad[2]); 529 | let d_prev_c = &v.get(node.in_grad[3]); 530 | 531 | let ref h_in = v.get(self.h_in); 532 | let ref d_h_in = v.get(self.d_h_in); 533 | 534 | let ref h_in_t = v.get(self.h_in_t); 535 | let ref wlstm_t = v.get(self.wlstm_t); 536 | 537 | let ref d_ifog = v.get(self.d_ifog); 538 | let ref ifog_f = v.get(self.ifog_f); 539 | let ref d_ifog_f = v.get(self.d_ifog_f); 540 | let ref d_c_inner = v.get(self.d_c_inner); 541 | let ref c_f = v.get(self.c_f); 542 | let ref d_c_f = v.get(self.d_c_f); 543 | 544 | let c = &v.get(node.outputs[1]); // cell 545 | 546 | let d_h = &v.get(node.out_grad[0].get()); 547 | 548 | // NOTE: unless the layer is unrolled, d_c and d_prev_c are actually the same underlying 549 | // buffer. We use different aliases for clarity. 550 | // NOTE: unless the layer is unrolled, d_h and d_prev_h are actually the same underlying 551 | // buffer. We use different aliases for clarity. 
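// (For reference, the IFOG column layout established in forward() is:
//    [..d]       input gate  i  (sigmoid)
//    [d..2*d]    forget gate f  (sigmoid)
//    [2*d..3*d]  output gate o  (sigmoid)
//    [3*d..]     candidate   g  (tanh)
//  and the forward cell update was c = i*g + f*c_prev, h = o*tanh(c). The steps
//  below walk those operations backwards, accumulating gradients into d_ifog_f,
//  d_prev_c, d_wlstm, d_x, and d_prev_h.)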
/src/util.rs:
--------------------------------------------------------------------------------
1 | use num::{Zero, One};
2 | 
3 | use ga::Array;
4 | 
5 | use graph::Graph;
6 | 
7 | pub fn one_hot_row<N, M>(label: N, classes: N) -> Array<M>
8 |     where usize: From<N>,
9 |           M: Clone+Zero+One,
10 | {
11 |     let classes: usize = From::from(classes); // Cast class count to usize
12 |     let label: usize = From::from(label); // Cast label to usize
13 |     let mut buf: Vec<M> = vec![Zero::zero(); classes]; // Create array of zeroes
14 |     buf[label] = One::one(); // Set the one-hot component
15 |     Array::from_vec(vec![1, classes], buf) // Construct the array
16 | }
17 | 
18 | pub fn one_hot_rows_batch<N, M>(labels: &[N], classes: N) -> Array<M>
19 |     where usize: From<N>,
20 |           N: Copy,
21 |           M: Clone+Zero+One,
22 | {
23 |     let classes: usize = From::from(classes); // Cast class count to usize
24 |     let batch_size = labels.len();
25 |     let mut one_hot_batch = Array::new(vec![batch_size, classes], Zero::zero()); // Construct the array
26 |     // Set the one hot components
27 |     for b in 0..batch_size {
28 |         one_hot_batch[&[b, From::from(labels[b])]] = One::one();
29 |     }
30 |     one_hot_batch
31 | }
32 | 
33 | pub fn argmax_rows(a: &Array<f32>, out: &mut Array<usize>) {
34 |     let (rows, columns) = (a.shape()[0], a.shape()[1]);
35 |     for row in 0..rows {
36 |         // TODO: I would do this:
37 |         // let max_col = (0..a.columns()).max_by_key(|col| a[&[row, *col]]);
38 |         // But f32 does not implement Ord :'(
39 |         let (mut max_col, mut max_val) = (0, a[&[row, 0]]);
40 |         for col in 1..columns {
41 |             let val = a[&[row, col]];
42 |             if val > max_val {
43 |                 max_col = col;
44 |                 max_val = val;
45 |             }
46 |         }
47 |         out[&[row]] = max_col;
48 |     }
49 | }
50 | 
51 | pub fn unrolled_net<I, O, F>(graph: &mut Graph, unroll_steps: usize, first_inputs: I, mut net_step: F) -> (I, Vec<O>)
52 |     where F: FnMut(&mut Graph, I) -> (I, O),
53 | {
54 | 
55 |     let mut next_inputs = first_inputs;
56 |     let mut steps = vec![];
57 |     for _ in 0..unroll_steps {
58 |         let (_next_inputs, step_out) = net_step(graph, next_inputs);
59 |         next_inputs = _next_inputs;
60 |         steps.push(step_out);
61 |     }
62 | 
63 |     (next_inputs, steps)
64 | }
65 | 
--------------------------------------------------------------------------------
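A minimal usage sketch of the helpers above, assuming the reconstructed generic signatures; the concrete `f32` and `usize` element types here are illustrative choices, not something util.rs forces:

    extern crate deeplearn;
    extern crate gpuarray as ga;

    use deeplearn::util;
    use ga::Array;

    fn main() {
        // One-hot encode labels 1 and 3 over 4 classes: shape [2, 4].
        let batch: Array<f32> = util::one_hot_rows_batch(&[1usize, 3usize], 4usize);

        // Recover the labels with a row-wise argmax; `out` should end up holding [1, 3].
        let mut out: Array<usize> = Array::new(vec![2], 0);
        util::argmax_rows(&batch, &mut out);
    }

Both helpers stay on the host side (plain `Array`s); the result is typically uploaded to a graph variable afterwards via `VarIndex::write`.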
/src/var_store.rs:
--------------------------------------------------------------------------------
1 | use std::cell::{Ref, RefCell, RefMut};
2 | 
3 | use ga::{Array, Tensor};
4 | 
5 | use super::graph::Graph;
6 | 
7 | pub struct VarStore {
8 |     vars: Vec<RefCell<Tensor<f32>>>,
9 | }
10 | 
11 | impl VarStore {
12 |     pub fn new() -> Self {
13 |         VarStore {
14 |             vars: vec![],
15 |         }
16 |     }
17 | 
18 |     pub fn add(&mut self, v: Tensor<f32>) -> VarIndex {
19 |         self.vars.push(RefCell::new(v));
20 |         VarIndex(self.vars.len()-1)
21 |     }
22 | 
23 |     pub fn get<'a>(&'a self, v: VarIndex) -> Ref<'a, Tensor<f32>> {
24 |         self.vars[v.0].borrow()
25 |     }
26 | 
27 |     pub fn get_mut<'a>(&'a self, v: VarIndex) -> RefMut<'a, Tensor<f32>> {
28 |         self.vars[v.0].borrow_mut()
29 |     }
30 | }
31 | 
32 | ////////////////////////////////////////////////////////////////////////////////////////////////////
33 | 
34 | #[derive(Copy, Clone, Eq, PartialEq, Hash)]
35 | pub struct VarIndex(usize);
36 | 
37 | impl VarIndex {
38 |     pub fn get<'a>(self, g: &'a Graph) -> Ref<'a, Tensor<f32>> {
39 |         g.var_store.get(self)
40 |     }
41 | 
42 |     pub fn read(self, g: &Graph, a: &mut Array<f32>) {
43 |         g.var_store.get(self).read(g.context(), a);
44 |     }
45 | 
46 |     pub fn write(self, g: &Graph, a: &Array<f32>) {
47 |         g.var_store.get(self).set(g.context(), a);
48 |     }
49 | }
50 | 
--------------------------------------------------------------------------------
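var_store.rs is essentially a small arena: every tensor sits in a `RefCell` inside a `Vec`, and a copyable `VarIndex` stands in for a reference, so the graph can hand out shared or exclusive borrows on demand with the checks done at runtime rather than through lifetimes. A standalone sketch of the same pattern, using a plain `Vec<f32>` in place of `Tensor<f32>` (not code from the crate):

    use std::cell::RefCell;

    // Toy version of VarStore: values are stored in RefCells and addressed by index,
    // so borrow rules are enforced at runtime instead of at compile time.
    struct Store {
        vars: Vec<RefCell<Vec<f32>>>,
    }

    impl Store {
        fn add(&mut self, v: Vec<f32>) -> usize {
            self.vars.push(RefCell::new(v));
            self.vars.len() - 1
        }
    }

    fn main() {
        let mut store = Store { vars: vec![] };
        let idx = store.add(vec![1.0, 2.0]);

        // Exclusive borrow to mutate, then a shared borrow once it is released.
        store.vars[idx].borrow_mut()[0] = 3.0;
        assert_eq!(store.vars[idx].borrow()[0], 3.0);
    }

The same trade-off applies in the real VarStore: indices are cheap to copy into nodes and closures, and a mistaken overlapping mutable borrow panics at runtime instead of failing to compile.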