├── data ├── images │ └── cat.jpg └── fashion_mnist │ ├── md5sum.txt │ └── download.sh ├── docs ├── image_fit_output_relu_200.jpg ├── image_fit_output_relu-pe_200.jpg ├── image_fit_output_siren_200.jpg ├── image_fit_output_multi-hash_200.jpg ├── image_fit.gnuplot ├── fashion_mnist.gnuplot ├── fashion_mnist_network_linear.svg ├── fashion_mnist_network_single-layer.svg ├── array_api_values.svg └── array_api_grad.svg ├── .gitignore ├── src ├── device.rs ├── loss.rs ├── parameter.rs ├── device │ ├── descriptor_pool.rs │ ├── fence.rs │ ├── buffer_heap.rs │ ├── command_buffer.rs │ ├── context.rs │ ├── timestamp.rs │ ├── staging.rs │ └── heap.rs ├── optimizer.rs ├── kernel_matmul.glsl ├── kernel_common.glsl ├── op.rs ├── lib.rs └── module.rs ├── Cargo.toml ├── LICENSE ├── examples ├── array_api │ ├── main.rs │ └── README.md ├── image_fit │ ├── README.md │ └── main.rs ├── fashion_mnist │ ├── README.md │ └── main.rs └── sentiment │ └── main.rs ├── Makefile └── README.md /data/images/cat.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sjb3d/descent/HEAD/data/images/cat.jpg -------------------------------------------------------------------------------- /docs/image_fit_output_relu_200.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sjb3d/descent/HEAD/docs/image_fit_output_relu_200.jpg -------------------------------------------------------------------------------- /docs/image_fit_output_relu-pe_200.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sjb3d/descent/HEAD/docs/image_fit_output_relu-pe_200.jpg -------------------------------------------------------------------------------- /docs/image_fit_output_siren_200.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sjb3d/descent/HEAD/docs/image_fit_output_siren_200.jpg -------------------------------------------------------------------------------- /docs/image_fit_output_multi-hash_200.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sjb3d/descent/HEAD/docs/image_fit_output_multi-hash_200.jpg -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | /target 2 | /.vscode 3 | /data/cifar-10-batches-bin/* 4 | /data/fashion_mnist/*.gz 5 | /data/mnist/*.gz 6 | /temp 7 | *.dot 8 | -------------------------------------------------------------------------------- /data/fashion_mnist/md5sum.txt: -------------------------------------------------------------------------------- 1 | 8d4fb7e6c68d591d4c3dfef9ec88bf0d train-images-idx3-ubyte.gz 2 | 25c81989df183df01b3e8a0aad5dffbe train-labels-idx1-ubyte.gz 3 | bef4ecab320f06d8554ea6380940ec79 t10k-images-idx3-ubyte.gz 4 | bb300cfdad3c16e7a12a480ee83cd310 t10k-labels-idx1-ubyte.gz 5 | -------------------------------------------------------------------------------- /docs/image_fit.gnuplot: -------------------------------------------------------------------------------- 1 | set xlabel "Epoch" 2 | set ylabel "Loss" 3 | set tics out 4 | set output ARG1 5 | set terminal svg enhanced background rgb 'white' 6 | set logscale y 7 | plot ARG2 with lines title "ReLU",\ 8 | ARG3 with lines title "ReLU with PE",\ 9 | ARG4 with lines title "SIREN",\ 10 | ARG5 with 
lines title "MULTI-HASH" 11 | -------------------------------------------------------------------------------- /src/device.rs: -------------------------------------------------------------------------------- 1 | pub mod buffer_heap; 2 | pub mod command_buffer; 3 | pub mod context; 4 | pub mod descriptor_pool; 5 | pub mod fence; 6 | mod heap; 7 | pub mod staging; 8 | pub mod timestamp; 9 | pub(crate) mod common { 10 | pub(crate) use super::{ 11 | buffer_heap::*, command_buffer::*, context::*, descriptor_pool::*, fence::*, staging::*, 12 | timestamp::*, 13 | }; 14 | } 15 | -------------------------------------------------------------------------------- /data/fashion_mnist/download.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | curl -O http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz 4 | curl -O http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz 5 | curl -O http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz 6 | curl -O http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz 7 | md5sum --check md5sum.txt 8 | -------------------------------------------------------------------------------- /docs/fashion_mnist.gnuplot: -------------------------------------------------------------------------------- 1 | set key above 2 | set xlabel "Epoch" 3 | set ylabel "Loss" 4 | set y2label "Accuracy" 5 | set ytics nomirror 6 | set y2tics 7 | set xtics nomirror 8 | set tics out 9 | set y2range [0.7 : 1.0] 10 | set output ARG1 11 | set terminal svg enhanced background rgb 'white' 12 | plot ARG2 using 1:2 with lines axes x1y1 title "training loss",\ 13 | ARG2 using 1:3 with lines axes x1y1 title "test loss",\ 14 | ARG2 using 1:4 with lines axes x1y2 title "training accuracy",\ 15 | ARG2 using 1:5 with lines axes x1y2 title "test accuracy" 16 | -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "descent" 3 | version = "0.1.0" 4 | authors = ["Simon Brown "] 5 | edition = "2018" 6 | publish = false 7 | 8 | [dependencies] 9 | arrayvec = "0.7" 10 | rand = "0.8" 11 | petgraph = "0.6" 12 | spark = { git = "https://github.com/sjb3d/spark/" } 13 | shaderc = "0.7" 14 | slotmap = "1.0" 15 | trait-set = "0.2" 16 | ordered-float = "2.7" 17 | bytemuck = "1.7" 18 | tinyvec = "1.3" 19 | 20 | [dev-dependencies] 21 | rand_chacha = "0.3" 22 | flate2 = "1.0" 23 | structopt = { version = "0.3", default-features = false } 24 | strum = { version = "0.21", features = ["derive"] } 25 | stb = "0.3.2" 26 | serde = { version = "1.0", features = ["derive"] } 27 | serde_json = "1.0" 28 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2021 Simon Brown 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | this software and associated documentation files (the "Software"), to deal in 5 | the Software without restriction, including without limitation the rights to 6 | use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of 7 | the Software, and to permit persons to whom the Software is furnished to do so, 8 | subject to the following conditions: 9 | 10 | The above copyright notice and this permission 
notice shall be included in all 11 | copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS 15 | FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR 16 | COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER 17 | IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN 18 | CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 19 | -------------------------------------------------------------------------------- /src/loss.rs: -------------------------------------------------------------------------------- 1 | use crate::common::*; 2 | 3 | #[allow(clippy::many_single_char_names)] 4 | pub fn softmax_cross_entropy_loss<'s>(z: DualArray<'s>, y: impl IntoArray<'s>) -> DualArray<'s> { 5 | let (z, dz) = z.next_colour().into_inner(); 6 | let y = y.into_array(z.scope()); 7 | 8 | // softmax 9 | let t = (z - z.reduce_max(-1, true)).exp(); 10 | let p = t / t.reduce_sum(-1, true); 11 | 12 | // cross entropy loss 13 | let (loss, dloss) = y 14 | .select_eq(p.coord(-1), -p.log(), 0.0) 15 | .reduce_sum(-1, true) 16 | .with_empty_grad(); // TODO: pick element of p using value of y 17 | 18 | // backprop (softmax with cross entropy directly) 19 | let n = p.shape()[SignedIndex(-1)]; 20 | dz.accumulate((p - y.one_hot(n)) * dloss); 21 | 22 | (loss, dloss).into() 23 | } 24 | 25 | pub fn softmax_cross_entropy_accuracy<'s>(z: DualArray<'s>, y: impl IntoArray<'s>) -> Array<'s> { 26 | let z = z.value(); 27 | let y = y.into_array(z.scope()); 28 | 29 | // index of most likely choice 30 | let pred = z.argmax(-1, true); 31 | 32 | // set to 1 when correct, 0 when incorrect 33 | pred.select_eq(y, 1.0, 0.0) 34 | } 35 | -------------------------------------------------------------------------------- /examples/array_api/main.rs: -------------------------------------------------------------------------------- 1 | use descent::prelude::*; 2 | 3 | fn main() { 4 | let random_seed = 0x5EED5EED; 5 | 6 | let mut env = Environment::new(); 7 | 8 | let m_param = 9 | env.static_parameter_with_data([3, 3], "m", &[1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0]); 10 | let x_param = env.static_parameter_with_data([3, 1], "x", &[4.0, 5.0, 6.0]); 11 | let y_param = env.static_parameter_with_data([3, 1], "y", &[1.0, 2.0, 3.0]); 12 | let z_param = env.static_parameter([3, 1], "z"); 13 | 14 | let graph = env.build_graph(|scope| { 15 | let m = scope.parameter_value(&m_param); 16 | let x = scope.parameter_value(&x_param); 17 | let y = scope.parameter_value(&y_param); 18 | let z = 2.0 * m.matmul(x) + y * y + 1.0; 19 | scope.write_parameter_value(&z_param, z); 20 | }); 21 | graph.write_dot_file(KernelDotOutput::Cluster, "array_api_values.dot"); 22 | 23 | env.run(&graph, random_seed); 24 | assert_eq!(&env.read_parameter_to_vec(&z_param), &[10.0, 15.0, 22.0]); 25 | 26 | let x_param = env.trainable_parameter([1], "x", Initializer::Zero); 27 | 28 | let graph = env.build_graph(|scope| { 29 | let x = scope.parameter(&x_param); 30 | let y = x.sin(); 31 | let _loss = (y.square() + y * 3.0).set_loss(); 32 | scope.write_parameter_value(&x_param, x.value() - 0.1 * x.loss_grad()); 33 | }); 34 | graph.write_dot_file(KernelDotOutput::Cluster, "array_api_grad.dot"); 35 | } 36 | -------------------------------------------------------------------------------- /Makefile: 
-------------------------------------------------------------------------------- 1 | GNUPLOT=gnuplot 2 | 3 | TEMP_DIR=temp 4 | DOCS_DIR=docs 5 | 6 | FASHION_MNIST_APP=cargo run --release --example fashion_mnist -- 7 | FASHION_MNIST_PLOT=$(DOCS_DIR)/fashion_mnist.gnuplot 8 | FASHION_MNIST_STATS=\ 9 | $(TEMP_DIR)/fashion_mnist_stats_linear.csv \ 10 | $(TEMP_DIR)/fashion_mnist_stats_single-layer.csv \ 11 | $(TEMP_DIR)/fashion_mnist_stats_conv-net.csv \ 12 | $(TEMP_DIR)/fashion_mnist_stats_conv-blur-net.csv 13 | FASHION_MNIST_GRAPHS=$(FASHION_MNIST_STATS:$(TEMP_DIR)/%.csv=$(DOCS_DIR)/%.svg) 14 | 15 | IMAGE_FIT_APP=cargo run --release --example image_fit -- 16 | IMAGE_FIT_PLOT=$(DOCS_DIR)/image_fit.gnuplot 17 | IMAGE_FIT_STATS=\ 18 | $(TEMP_DIR)/image_fit_stats_relu.csv \ 19 | $(TEMP_DIR)/image_fit_stats_relu-pe.csv \ 20 | $(TEMP_DIR)/image_fit_stats_siren.csv \ 21 | $(TEMP_DIR)/image_fit_stats_multi-hash.csv 22 | IMAGE_FIT_GRAPHS=$(DOCS_DIR)/image_fit_stats.svg 23 | 24 | DIRS=$(TEMP_DIR) $(DOCS_DIR) 25 | 26 | $(info $(shell mkdir -p $(DIRS))) 27 | 28 | all: fashion_mnist image_fit 29 | .PHONY: all clean clean_fashion_mnist fashion_mnist clean_image_fit image_fit 30 | 31 | clean: clean_fashion_mnist clean_image_fit 32 | 33 | clean_fashion_mnist: 34 | $(RM) $(FASHION_MNIST_STATS) $(FASHION_MNIST_GRAPHS) 35 | 36 | fashion_mnist: $(FASHION_MNIST_GRAPHS) 37 | 38 | clean_image_fit: 39 | $(RM) $(IMAGE_FIT_STATS) $(IMAGE_FIT_GRAPHS) 40 | 41 | image_fit: $(IMAGE_FIT_GRAPHS) 42 | 43 | $(DOCS_DIR)/fashion_mnist_%.svg : $(TEMP_DIR)/fashion_mnist_%.csv $(FASHION_MNIST_PLOT) 44 | $(GNUPLOT) -c $(FASHION_MNIST_PLOT) "$@" "$<" 45 | 46 | $(TEMP_DIR)/fashion_mnist_stats_%.csv: 47 | $(FASHION_MNIST_APP) --quiet -t 4 --csv-file-name "$@" $* 48 | 49 | $(TEMP_DIR)/image_fit_stats_%.csv: 50 | $(IMAGE_FIT_APP) --quiet --csv-file-name "$@" --image-prefix "$(DOCS_DIR)/image_fit_output_$*" $* 51 | 52 | $(DOCS_DIR)/image_fit_stats.svg: $(IMAGE_FIT_STATS) $(IMAGE_FIT_PLOT) 53 | $(GNUPLOT) -c $(IMAGE_FIT_PLOT) "$@" $^ 54 | -------------------------------------------------------------------------------- /src/parameter.rs: -------------------------------------------------------------------------------- 1 | use crate::{common::*, device::common::*}; 2 | use slotmap::SlotMap; 3 | use std::{cell::RefCell, rc::Rc}; 4 | 5 | slotmap::new_key_type! 
{ 6 | pub(crate) struct ParameterId; 7 | } 8 | 9 | #[derive(Clone, Copy, Debug)] 10 | pub enum Initializer { 11 | Zero, 12 | RandNormal(f32), 13 | RandUniform(f32), 14 | } 15 | 16 | impl Initializer { 17 | pub fn for_relu(fan_in: usize) -> Self { 18 | let scale = (2.0 / (fan_in as f32)).sqrt(); 19 | Self::RandNormal(scale) 20 | } 21 | 22 | pub fn for_siren(fan_in: usize, is_first_layer: bool) -> Self { 23 | let scale = (6.0 / (fan_in as f32)).sqrt() * if is_first_layer { 30.0 } else { 1.0 }; 24 | Self::RandUniform(scale) 25 | } 26 | } 27 | 28 | pub(crate) struct ParameterStorage { 29 | pub(crate) shape: Shape, 30 | pub(crate) name: String, 31 | pub(crate) buffer_id: Option, 32 | pub(crate) reset_to: Option, 33 | } 34 | 35 | pub(crate) type SharedParameters = Rc>>; 36 | 37 | #[derive(Clone)] 38 | pub struct Parameter { 39 | id: ParameterId, 40 | owner: SharedParameters, 41 | } 42 | 43 | impl Parameter { 44 | pub(crate) fn new(parameter_id: ParameterId, owner: &SharedParameters) -> Self { 45 | Self { 46 | id: parameter_id, 47 | owner: SharedParameters::clone(owner), 48 | } 49 | } 50 | 51 | pub(crate) fn checked_id(&self, owner: &SharedParameters) -> ParameterId { 52 | if !SharedParameters::ptr_eq(&self.owner, owner) { 53 | panic!("parameter does not come from the same environment"); 54 | } 55 | self.id 56 | } 57 | 58 | pub fn shape(&self) -> Shape { 59 | self.owner.borrow().get(self.id).unwrap().shape 60 | } 61 | 62 | pub fn name(&self) -> String { 63 | self.owner.borrow().get(self.id).unwrap().name.clone() 64 | } 65 | 66 | pub fn reset_to(&self) -> Option { 67 | self.owner.borrow().get(self.id).unwrap().reset_to 68 | } 69 | 70 | pub fn is_trainable(&self) -> bool { 71 | self.owner.borrow().get(self.id).unwrap().reset_to.is_some() 72 | } 73 | } 74 | -------------------------------------------------------------------------------- /src/device/descriptor_pool.rs: -------------------------------------------------------------------------------- 1 | use super::common::*; 2 | use spark::{vk, Builder}; 3 | use std::collections::VecDeque; 4 | 5 | pub(crate) struct DescriptorPools { 6 | context: SharedContext, 7 | pools: VecDeque>, 8 | } 9 | 10 | impl DescriptorPools { 11 | const COUNT: usize = 2; 12 | 13 | const MAX_SETS: u32 = 2048; // TODO: pass in sizes (derive from graphs) 14 | const MAX_BUFFERS: u32 = 8 * Self::MAX_SETS; 15 | 16 | pub(crate) fn new(context: &SharedContext, fences: &FenceSet) -> Self { 17 | let device = &context.device; 18 | let descriptor_pool_sizes = [vk::DescriptorPoolSize { 19 | ty: vk::DescriptorType::STORAGE_BUFFER, 20 | descriptor_count: Self::MAX_BUFFERS, 21 | }]; 22 | let descriptor_pool_create_info = vk::DescriptorPoolCreateInfo::builder() 23 | .max_sets(Self::MAX_SETS) 24 | .p_pool_sizes(&descriptor_pool_sizes); 25 | 26 | let mut pools = VecDeque::new(); 27 | for _ in 0..Self::COUNT { 28 | let pool = unsafe { device.create_descriptor_pool(&descriptor_pool_create_info, None) } 29 | .unwrap(); 30 | pools.push_back(Fenced::new(pool, fences.old_id())); 31 | } 32 | Self { 33 | context: SharedContext::clone(context), 34 | pools, 35 | } 36 | } 37 | 38 | pub(crate) fn acquire(&mut self, fences: &FenceSet) -> ScopedDescriptorPool { 39 | let pool = self.pools.pop_front().unwrap().take_when_signaled(fences); 40 | unsafe { 41 | self.context 42 | .device 43 | .reset_descriptor_pool(pool, vk::DescriptorPoolResetFlags::empty()) 44 | .unwrap(); 45 | } 46 | ScopedDescriptorPool { pool, owner: self } 47 | } 48 | } 49 | 50 | impl Drop for DescriptorPools { 51 | fn drop(&mut self) { 52 | 
let device = &self.context.device; 53 | for pool in self.pools.iter() { 54 | unsafe { 55 | let pool = pool.get_unchecked(); 56 | device.destroy_descriptor_pool(Some(*pool), None); 57 | } 58 | } 59 | } 60 | } 61 | 62 | pub(crate) struct ScopedDescriptorPool<'a> { 63 | pool: vk::DescriptorPool, 64 | owner: &'a mut DescriptorPools, 65 | } 66 | 67 | impl<'a> ScopedDescriptorPool<'a> { 68 | pub(crate) fn get(&self) -> vk::DescriptorPool { 69 | self.pool 70 | } 71 | 72 | pub(crate) fn recycle(self, fence: FenceId) { 73 | self.owner.pools.push_back(Fenced::new(self.pool, fence)); 74 | } 75 | } 76 | -------------------------------------------------------------------------------- /examples/image_fit/README.md: -------------------------------------------------------------------------------- 1 | # image_fit example 2 | 3 | This example overfits a few different networks to the following test image: 4 | 5 | ![input image](../../data/images/cat.jpg) 6 | 7 | ## Overview 8 | 9 | The networks are: 10 | 11 | - **ReLU**: 2D coordinate into a 256-128-64-32 MLP with ReLU activation after each layer 12 | - **ReLU-PE**: 2D positional encoding (L=8 so 32 values) into the same MLP as above 13 | - **SIREN**: 2D coordinate into a 256-128-64-32 MLP with sine activation after each layer 14 | - This is implemented as described in [Implicit Neural Representations with Periodic Activation Functions](https://vsitzmann.github.io/siren/), using the initialisation scheme from the paper (including the extra scaling on the first layer). 15 | - **MULTI-HASH**: 2D coordinate into a 10-level hash encoding (up to 4096 entries per level, 2 values per entry) with concatenated outputs fed into a 64-64 MLP with ReLU activation after each layer 16 | - This is implemented as described in [Instant Neural Graphics Primitives with a Multiresolution Hash Encoding](https://github.com/NVlabs/instant-ngp), using smaller hash tables and fewer layers, in order to train a similar number of parameters to the other networks. 17 | 18 | For all networks there is then a final linear layer to an RGB triple. Here is a graph of how the training loss evolves over 200 epochs of training: 19 | 20 | ![](../../docs/image_fit_stats.svg) 21 | 22 | The loss function is the squared difference between the RGB values of the result and the training image. 23 | Although the SIREN network evolves similarly to the ReLU-PE network, the SIREN network seems to do better at capturing sharp features, and does so with fewer parameters (due to the lack of an initial encoding layer). 24 | 25 | The MULTI-HASH network trains quickly and has the lowest error by far amongst these networks. 26 | 27 | The number of trainable parameters for each network is as follows: 28 | 29 | Network Type | Trainable Parameters 30 | --- | --- 31 | ReLU | 44099 32 | ReLU with PE | 51779 33 | SIREN | 44099 34 | MULTI-HASH | 43977 35 | 36 | ## Running The Example 37 | 38 | The example can be run using: 39 | 40 | ``` 41 | cargo run --release --example image_fit 42 | ``` 43 | 44 | This will fit using a SIREN network by default. Other networks can be trained by passing different command-line arguments; run the following to show command-line help: 45 | 46 | ``` 47 | cargo run --release --example image_fit -- --help 48 | ``` 49 | 50 | ## Fitted Image Results 51 | 52 | Here are the results for each network type after 200 epochs of training. Each epoch trains using the same number of pixels as the input image, in mini-batches of 16K randomly sampled pixels.
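The example's own sampling code is not part of this listing; the sketch below shows the general idea in plain Rust, using the `rand` crate that the project already depends on. The `sample_batch` helper and the flat row-major RGB image layout are illustrative assumptions rather than the example's actual code: each mini-batch draws a set of random pixels, and each pixel becomes a normalised 2D input coordinate together with its target RGB value.

```rust
use rand::Rng;

// Hypothetical helper (not the example's actual code): sample one mini-batch of
// `batch_size` random pixels from a row-major RGB image. Each sample is a
// normalised 2D coordinate (the network input) and its RGB value (the target).
fn sample_batch(
    rng: &mut impl Rng,
    image: &[[f32; 3]],
    width: usize,
    height: usize,
    batch_size: usize,
) -> (Vec<[f32; 2]>, Vec<[f32; 3]>) {
    let mut coords = Vec::with_capacity(batch_size);
    let mut targets = Vec::with_capacity(batch_size);
    for _ in 0..batch_size {
        let x = rng.gen_range(0..width);
        let y = rng.gen_range(0..height);
        coords.push([
            (x as f32 + 0.5) / width as f32,
            (y as f32 + 0.5) / height as f32,
        ]);
        targets.push(image[y * width + x]);
    }
    (coords, targets)
}
```

The coordinates and targets are then uploaded to the device, the network predicts an RGB value for each coordinate, and the squared-difference loss described above is evaluated over the batch.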
53 | 54 | ### ReLU 55 | 56 | ![rulu output](../../docs/image_fit_output_relu_200.jpg) 57 | 58 | ### ReLU with Positional Encoding 59 | 60 | ![relu-pe output](../../docs/image_fit_output_relu-pe_200.jpg) 61 | 62 | ### SIREN 63 | 64 | ![siren output](../../docs/image_fit_output_siren_200.jpg) 65 | 66 | ### MULTI-HASH 67 | 68 | ![siren output](../../docs/image_fit_output_multi-hash_200.jpg) 69 | -------------------------------------------------------------------------------- /src/optimizer.rs: -------------------------------------------------------------------------------- 1 | use crate::common::*; 2 | 3 | pub fn add_weight_decay_to_grad(scope: &Scope, parameters: &[Parameter], weight_decay: f32) { 4 | if weight_decay == 0.0 { 5 | return; 6 | } 7 | 8 | scope.next_colour(); 9 | for param in parameters.iter() { 10 | let (w, g) = scope.parameter(param).into_inner(); 11 | g.accumulate(w * weight_decay); 12 | } 13 | } 14 | 15 | pub trait Optimizer { 16 | fn reset_state(&self, env: &mut Environment); 17 | } 18 | 19 | pub struct StochasticGradientDescent { 20 | state: Vec, 21 | } 22 | 23 | impl StochasticGradientDescent { 24 | pub fn new<'s>( 25 | env: &mut Environment, 26 | scope: &'s Scope, 27 | parameters: &[Parameter], 28 | learning_rate: impl IntoArray<'s>, 29 | momentum: f32, 30 | ) -> Self { 31 | scope.next_colour(); 32 | let mut state = Vec::new(); 33 | 34 | let learning_rate = learning_rate.into_array(scope); 35 | for param in parameters.iter() { 36 | let g = scope.parameter(param).loss_grad(); 37 | if momentum == 0.0 { 38 | scope.update_parameter_value(param, |theta| theta - learning_rate * g); 39 | } else { 40 | let shape = param.shape(); 41 | let v_param = env.static_parameter(shape, "v"); 42 | let v = scope.update_parameter_value(&v_param, |v| v * momentum + g); 43 | scope.update_parameter_value(param, |theta| theta - learning_rate * v); 44 | state.push(v_param); 45 | } 46 | } 47 | 48 | let tmp = Self { state }; 49 | tmp.reset_state(env); 50 | tmp 51 | } 52 | } 53 | 54 | impl Optimizer for StochasticGradientDescent { 55 | fn reset_state(&self, env: &mut Environment) { 56 | for param in self.state.iter() { 57 | env.writer(param).zero_fill() 58 | } 59 | } 60 | } 61 | 62 | pub struct Adam { 63 | state: Vec, 64 | } 65 | 66 | impl Adam { 67 | pub fn new<'s>( 68 | env: &mut Environment, 69 | scope: &'s Scope, 70 | parameters: &[Parameter], 71 | learning_rate: impl IntoArray<'s>, 72 | beta1: f32, 73 | beta2: f32, 74 | epsilon: f32, 75 | ) -> Self { 76 | scope.next_colour(); 77 | let mut state = Vec::new(); 78 | 79 | let t_param = env.static_parameter([1], "t"); 80 | let t = scope.update_parameter_value(&t_param, |t| t + 1.0); 81 | state.push(t_param); 82 | 83 | let alpha = learning_rate.into_array(scope) * (1.0 - (beta2.ln() * t).exp()).sqrt() 84 | / (1.0 - (beta1.ln() * t).exp()); 85 | 86 | for param in parameters.iter() { 87 | let shape = param.shape(); 88 | let m_param = env.static_parameter(shape, "m"); 89 | let v_param = env.static_parameter(shape, "v"); 90 | 91 | let g = scope.parameter(param).loss_grad(); 92 | let m = scope.update_parameter_value(&m_param, |m| m * beta1 + g * (1.0 - beta1)); 93 | let v = scope.update_parameter_value(&v_param, |v| v * beta2 + g * g * (1.0 - beta2)); 94 | state.push(m_param); 95 | state.push(v_param); 96 | 97 | scope.update_parameter_value(param, |theta| theta - alpha * m / (v.sqrt() + epsilon)); 98 | } 99 | 100 | let tmp = Self { state }; 101 | tmp.reset_state(env); 102 | tmp 103 | } 104 | } 105 | 106 | impl Optimizer for Adam { 107 | fn 
reset_state(&self, env: &mut Environment) { 108 | for param in self.state.iter() { 109 | env.writer(param).zero_fill() 110 | } 111 | } 112 | } 113 | -------------------------------------------------------------------------------- /src/device/fence.rs: -------------------------------------------------------------------------------- 1 | use super::common::*; 2 | use arrayvec::ArrayVec; 3 | use spark::vk; 4 | use std::slice; 5 | 6 | #[derive(Clone, Copy)] 7 | pub(crate) struct FenceId(usize); 8 | 9 | impl FenceId { 10 | fn index(&self) -> usize { 11 | self.0 % FenceSet::COUNT 12 | } 13 | } 14 | 15 | pub(crate) struct FenceSet { 16 | context: SharedContext, 17 | fences: [vk::Fence; Self::COUNT], 18 | counter: usize, 19 | } 20 | 21 | impl FenceSet { 22 | const COUNT: usize = 2; 23 | 24 | pub(crate) fn new(context: &SharedContext) -> Self { 25 | let mut fences = ArrayVec::new(); 26 | for _ in 0..Self::COUNT { 27 | let fence = { 28 | let fence_create_info = vk::FenceCreateInfo { 29 | flags: vk::FenceCreateFlags::SIGNALED, 30 | ..Default::default() 31 | }; 32 | unsafe { context.device.create_fence(&fence_create_info, None) }.unwrap() 33 | }; 34 | fences.push(fence); 35 | } 36 | Self { 37 | context: SharedContext::clone(context), 38 | fences: fences.into_inner().unwrap(), 39 | counter: 1, 40 | } 41 | } 42 | 43 | pub(crate) fn old_id(&self) -> FenceId { 44 | FenceId(self.counter.wrapping_sub(1)) 45 | } 46 | 47 | fn id_needs_wait(&self, id: FenceId) -> bool { 48 | id.0.wrapping_sub(self.counter) < Self::COUNT 49 | } 50 | 51 | pub(crate) fn next_unsignaled(&mut self) -> (FenceId, vk::Fence) { 52 | self.wait_for_signal(FenceId(self.counter)); 53 | 54 | let id = FenceId(self.counter.wrapping_add(Self::COUNT)); 55 | self.counter = self.counter.wrapping_add(1); 56 | 57 | let fence = self.fences[id.index()]; 58 | unsafe { 59 | self.context 60 | .device 61 | .reset_fences(slice::from_ref(&fence)) 62 | .unwrap(); 63 | } 64 | (id, fence) 65 | } 66 | 67 | pub(crate) fn wait_for_signal(&self, id: FenceId) { 68 | if !self.id_needs_wait(id) { 69 | return; 70 | } 71 | 72 | let fence = self.fences[id.index()]; 73 | let timeout_ns = 1000 * 1000 * 1000; 74 | loop { 75 | let res = unsafe { 76 | self.context 77 | .device 78 | .wait_for_fences(slice::from_ref(&fence), true, timeout_ns) 79 | }; 80 | match res { 81 | Ok(_) => break, 82 | Err(vk::Result::TIMEOUT) => {} 83 | Err(err_code) => panic!("failed to wait for fence {}", err_code), 84 | } 85 | } 86 | } 87 | } 88 | 89 | impl Drop for FenceSet { 90 | fn drop(&mut self) { 91 | for fence in self.fences.iter().copied() { 92 | unsafe { 93 | self.context.device.destroy_fence(Some(fence), None); 94 | } 95 | } 96 | } 97 | } 98 | 99 | pub(crate) struct Fenced { 100 | value: T, 101 | fence_id: FenceId, 102 | } 103 | 104 | impl Fenced { 105 | pub(crate) fn new(value: T, fence_id: FenceId) -> Self { 106 | Self { value, fence_id } 107 | } 108 | 109 | pub(crate) fn get_mut_when_signaled(&mut self, set: &FenceSet) -> &mut T { 110 | set.wait_for_signal(self.fence_id); 111 | &mut self.value 112 | } 113 | 114 | pub(crate) fn take_when_signaled(self, set: &FenceSet) -> T { 115 | set.wait_for_signal(self.fence_id); 116 | self.value 117 | } 118 | 119 | pub(crate) fn map(self, f: F) -> Fenced 120 | where 121 | F: FnOnce(T) -> U, 122 | { 123 | Fenced { 124 | value: f(self.value), 125 | fence_id: self.fence_id, 126 | } 127 | } 128 | 129 | pub(crate) unsafe fn get_unchecked(&self) -> &T { 130 | &self.value 131 | } 132 | } 133 | 
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # descent 2 | 3 | Toy library for neural networks in Rust using Vulkan compute shaders. 4 | 5 | ## Features 6 | 7 | - Multi-dimensional arrays backed by Vulkan device memory 8 | - Use Rust syntax to build a computation graph, run as Vulkan compute shaders 9 | - Supports vector arithmetic and per-element sin/cos/exp/log/etc 10 | - 1D reduction, 2D matrix multiply, 2D convolutions and 2D max pool supported 11 | - Concatenation, gather loads and scatter adds 12 | - Softmax cross entropy loss 13 | - Ops are fused into larger compute shaders where possible (to reduce bandwidth cost) 14 | - Implements broadcasts/padding/windowing/reshapes as views (zero copy) where possible 15 | - Supports one level of automatic derivatives for back-propagation 16 | - Some example optimisers: 17 | - Stochastic gradient descent (with momentum) 18 | - Adam 19 | - Optional higher-level API of neural network building blocks 20 | - Can generate different code for train vs test (e.g. dropout only affects training) 21 | - Deterministic results (except for scatter add which currently uses float atomics...) 22 | 23 | ## Example Network 24 | 25 | The top-level API of neural network building blocks can be used to compactly describe multi-layer networks. Here is a small convolutional neural network with dropout and (leaky) ReLU activation using this API: 26 | 27 | ```rust 28 | struct ConvNet { 29 | conv1: Conv2D, 30 | conv2: Conv2D, 31 | fc1: Dense, 32 | fc2: Dense, 33 | } 34 | 35 | impl ConvNet { 36 | fn new(env: &mut Environment) -> Self { 37 | // create and store parameters for layers that require them 38 | let c1 = 16; 39 | let c2 = 32; 40 | let hidden = 128; 41 | Self { 42 | conv1: Conv2D::builder(1, c1, 3, 3).with_pad(1).build(env), 43 | conv2: Conv2D::builder(c1, c2, 3, 3) 44 | .with_pad(1) 45 | .with_groups(2) 46 | .build(env), 47 | fc1: Dense::builder(7 * 7 * c2, hidden).build(env), 48 | fc2: Dense::builder(hidden, 10).build(env), 49 | } 50 | } 51 | } 52 | 53 | impl Module for ConvNet { 54 | fn eval<'s>(&self, input: DualArray<'s>, ctx: &EvalContext) -> DualArray<'s> { 55 | // generates ops for the value (forwards) and gradient (backwards) through the layers 56 | input 57 | .apply(&self.conv1, ctx) 58 | .leaky_relu(0.01) 59 | .max_pool2d((2, 2), (2, 2)) 60 | .apply(&self.conv2, ctx) 61 | .leaky_relu(0.01) 62 | .max_pool2d((2, 2), (2, 2)) 63 | .flatten() 64 | .apply(&Dropout::new(0.5), ctx) 65 | .apply(&self.fc1, ctx) 66 | .leaky_relu(0.01) 67 | .apply(&self.fc2, ctx) 68 | } 69 | } 70 | ``` 71 | 72 | See the [fashion_mnist example](examples/fashion_mnist) for more networks using this API. 73 | 74 | ## Examples 75 | 76 | Please follow the link in the name of each example to show a more detailed description of each one. 77 | 78 | Name | Description 79 | --- | --- 80 | [array_api](examples/array_api) | Demonstrates the low-level `Array` API for building computation graphs. See the README for more details. 81 | [fashion_mnist](examples/fashion_mnist) | Trains a few different network types on the Fashion-MNIST dataset. Demonstrates the use of anti-aliasing during max pooling for improved accuracy. See the README for a comparison of network performance. 82 | [image_fit](examples/image_fit) | Overfits a few different network types to a single RGB image. Compares ReLU with positional encoding to a SIREN network. 
_Update: now also compares to a multi-level hash encoding._ 83 | 84 | ## Dependencies 85 | 86 | The following crates have been very useful to develop this project: 87 | 88 | - [petgraph](https://github.com/petgraph/petgraph): used for all graph data structures 89 | - [slotmap](https://github.com/orlp/slotmap): storage with stable keys 90 | - [shaderc](https://github.com/google/shaderc-rs): interface to GLSL compiler to generate SPIR-V for shaders 91 | 92 | ## Potential Future Work 93 | 94 | - [ ] Lookahead optimiser? 95 | - [ ] Recurrent network 96 | - [ ] SDF fitting 97 | - [x] Multi-level hash encoding 98 | - [ ] Make concat zero-copy (writeable views) 99 | -------------------------------------------------------------------------------- /src/device/buffer_heap.rs: -------------------------------------------------------------------------------- 1 | use super::{common::*, heap::*}; 2 | use spark::vk; 3 | 4 | slotmap::new_key_type! { 5 | pub(crate) struct BufferId; 6 | } 7 | 8 | #[derive(Debug, Clone, Copy, PartialEq, Eq)] 9 | struct ChunkIndex(usize); 10 | 11 | struct Chunk { 12 | device_memory: vk::DeviceMemory, 13 | buffer: vk::Buffer, 14 | } 15 | 16 | pub(crate) struct BufferHeap { 17 | context: SharedContext, 18 | chunks: Vec, 19 | heap: Heap, 20 | } 21 | 22 | #[derive(Debug, Clone, Copy)] 23 | pub(crate) struct BufferInfo { 24 | pub(crate) buffer: vk::Buffer, 25 | pub(crate) range: HeapRange, 26 | } 27 | 28 | impl BufferHeap { 29 | const CHUNK_SIZE: usize = 256 * 1024 * 1024; 30 | 31 | pub(crate) fn new(context: &SharedContext) -> Self { 32 | Self { 33 | context: SharedContext::clone(context), 34 | chunks: Vec::new(), 35 | heap: Heap::default(), 36 | } 37 | } 38 | 39 | fn extend_heap_by_at_least(&mut self, capacity: usize) { 40 | let chunk_size = Self::CHUNK_SIZE.max(capacity); 41 | let device = &self.context.device; 42 | let buffer = { 43 | let buffer_create_info = vk::BufferCreateInfo { 44 | size: chunk_size as vk::DeviceSize, 45 | usage: vk::BufferUsageFlags::STORAGE_BUFFER 46 | | vk::BufferUsageFlags::TRANSFER_SRC 47 | | vk::BufferUsageFlags::TRANSFER_DST, 48 | ..Default::default() 49 | }; 50 | unsafe { device.create_buffer(&buffer_create_info, None) }.unwrap() 51 | }; 52 | let mem_req = unsafe { device.get_buffer_memory_requirements(buffer) }; 53 | let device_memory = { 54 | let memory_type_index = self 55 | .context 56 | .get_memory_type_index( 57 | mem_req.memory_type_bits, 58 | vk::MemoryPropertyFlags::DEVICE_LOCAL, 59 | ) 60 | .unwrap(); 61 | let memory_allocate_info = vk::MemoryAllocateInfo { 62 | allocation_size: mem_req.size, 63 | memory_type_index, 64 | ..Default::default() 65 | }; 66 | unsafe { device.allocate_memory(&memory_allocate_info, None) }.unwrap() 67 | }; 68 | unsafe { device.bind_buffer_memory(buffer, device_memory, 0) }.unwrap(); 69 | 70 | let chunk_index = ChunkIndex(self.chunks.len()); 71 | self.chunks.push(Chunk { 72 | device_memory, 73 | buffer, 74 | }); 75 | 76 | self.heap.extend_with(chunk_index, chunk_size); 77 | } 78 | 79 | pub(crate) fn alloc(&mut self, size: usize) -> Option { 80 | let align = self 81 | .context 82 | .physical_device_properties 83 | .limits 84 | .non_coherent_atom_size 85 | .max( 86 | self.context 87 | .physical_device_properties 88 | .limits 89 | .min_storage_buffer_offset_alignment, 90 | ) as usize; 91 | match self.heap.alloc(size, align) { 92 | Some(alloc) => Some(alloc), 93 | None => { 94 | self.extend_heap_by_at_least(size); 95 | self.heap.alloc(size, align) 96 | } 97 | } 98 | } 99 | 100 | pub(crate) fn free(&mut self, id: 
BufferId) { 101 | self.heap.free(id); 102 | } 103 | 104 | pub(crate) fn info(&self, id: BufferId) -> BufferInfo { 105 | let info = self.heap.info(id); 106 | BufferInfo { 107 | buffer: self.chunks[info.tag.0].buffer, 108 | range: info.range, 109 | } 110 | } 111 | 112 | #[allow(dead_code)] 113 | pub(crate) fn heap_stats(&self) -> HeapStats { 114 | self.heap.stats() 115 | } 116 | } 117 | 118 | impl Drop for BufferHeap { 119 | fn drop(&mut self) { 120 | let device = &self.context.device; 121 | for chunk in self.chunks.drain(..) { 122 | unsafe { 123 | device.destroy_buffer(Some(chunk.buffer), None); 124 | device.free_memory(Some(chunk.device_memory), None); 125 | } 126 | } 127 | } 128 | } 129 | -------------------------------------------------------------------------------- /src/kernel_matmul.glsl: -------------------------------------------------------------------------------- 1 | // float load_a(uint batch_index, uvec2 coord); 2 | // float load_b(uint batch_index, uvec2 coord); 3 | // void store_c(uint batch_index, uint k_chunk_index, uvec2 coord, float value); 4 | // const uint M 5 | // const uint N 6 | // const uint K 7 | // const uint TILE_M 8 | // const uint TILE_N 9 | // const uint TILE_K 10 | // const uint GROUP_SIZE 11 | // const uint K_CHUNK_SIZE_IN_TILES 12 | // const uint K_CHUNK_COUNT 13 | // const uint BATCH_COUNT 14 | // const bool LOAD_A_IN_COLUMNS 15 | // const bool LOAD_B_IN_COLUMNS 16 | 17 | layout(local_size_x = GROUP_SIZE) in; 18 | 19 | const uint A_TILE_W = TILE_K; 20 | const uint A_TILE_H = TILE_M; 21 | const uint A_TILE_SIZE = A_TILE_W * A_TILE_H; 22 | 23 | const uint B_TILE_W = TILE_N; 24 | const uint B_TILE_H = TILE_K; 25 | const uint B_TILE_SIZE = B_TILE_W * B_TILE_H; 26 | 27 | const uint SHARED_PAD = 1; 28 | const uint A_TILE_STRIDE = A_TILE_W + SHARED_PAD; 29 | const uint B_TILE_STRIDE = B_TILE_W + SHARED_PAD; 30 | 31 | const uint C_TILE_W = TILE_N; 32 | const uint C_TILE_H = TILE_M; 33 | const uint C_TILE_SIZE = C_TILE_W * C_TILE_H; 34 | const uint C_VALUES_PER_THREAD = C_TILE_SIZE / GROUP_SIZE; // must divide exactly! 35 | 36 | const uint M_TILE_COUNT = (M + (TILE_M - 1))/TILE_M; 37 | const uint N_TILE_COUNT = (N + (TILE_N - 1))/TILE_N; 38 | const uint K_TILE_COUNT = (K + (TILE_K - 1))/TILE_K; 39 | 40 | shared float s_a[A_TILE_H * A_TILE_STRIDE]; 41 | shared float s_b[B_TILE_H * B_TILE_STRIDE]; 42 | 43 | void main() { 44 | int c_output_coord[4]; 45 | compute_grid_coord(gl_WorkGroupID.x, c_output_coord, K_CHUNK_COUNT, BATCH_COUNT, M_TILE_COUNT, N_TILE_COUNT); 46 | uvec2 c_tile_coord = uvec2(c_output_coord[3], c_output_coord[2]); 47 | uint batch_index = c_output_coord[1]; 48 | uint k_chunk_index = c_output_coord[0]; 49 | 50 | uint thread_index = gl_LocalInvocationID.x; 51 | 52 | float result[C_VALUES_PER_THREAD]; 53 | for (uint i = 0; i < C_VALUES_PER_THREAD; ++i) { 54 | result[i] = 0.f; 55 | } 56 | 57 | uint k_tile_begin = k_chunk_index * K_CHUNK_SIZE_IN_TILES; 58 | uint k_tile_end = min(k_tile_begin + K_CHUNK_SIZE_IN_TILES, K_TILE_COUNT); 59 | for (uint k_tile_index = k_tile_begin; k_tile_index != k_tile_end; ++k_tile_index) { 60 | barrier(); 61 | for (uint load_index = thread_index; load_index < A_TILE_SIZE; load_index += GROUP_SIZE) { 62 | uvec2 a_coord_in_tile = LOAD_A_IN_COLUMNS 63 | ? 
uvec2(load_index/A_TILE_H, load_index % A_TILE_H) 64 | : uvec2(load_index % A_TILE_W, load_index/A_TILE_W); 65 | uvec2 a_tile_coord = uvec2(k_tile_index, c_tile_coord.y); 66 | float a = load_a(batch_index, a_tile_coord*uvec2(A_TILE_W, A_TILE_H) + a_coord_in_tile); 67 | s_a[a_coord_in_tile.y*A_TILE_STRIDE + a_coord_in_tile.x] = a; 68 | } 69 | for (uint load_index = thread_index; load_index < B_TILE_SIZE; load_index += GROUP_SIZE) { 70 | uvec2 b_coord_in_tile = LOAD_B_IN_COLUMNS 71 | ? uvec2(load_index/B_TILE_H, load_index % B_TILE_H) 72 | : uvec2(load_index % B_TILE_W, load_index/B_TILE_W); 73 | uvec2 b_tile_coord = uvec2(c_tile_coord.x, k_tile_index); 74 | float b = load_b(batch_index, b_tile_coord*uvec2(B_TILE_W, B_TILE_H) + b_coord_in_tile); 75 | s_b[b_coord_in_tile.y*B_TILE_STRIDE + b_coord_in_tile.x] = b; 76 | } 77 | barrier(); 78 | 79 | for (uint k_index = 0; k_index < TILE_K; ++k_index) { 80 | for (uint i = 0; i < C_VALUES_PER_THREAD; ++i) { 81 | uint c_index = i*GROUP_SIZE + thread_index; 82 | uvec2 c_coord_in_tile = uvec2(c_index % C_TILE_W, c_index / C_TILE_W); 83 | uvec2 a_coord_in_tile = uvec2(k_index, c_coord_in_tile.y); 84 | uvec2 b_coord_in_tile = uvec2(c_coord_in_tile.x, k_index); 85 | float a = s_a[a_coord_in_tile.y*A_TILE_STRIDE + a_coord_in_tile.x]; 86 | float b = s_b[b_coord_in_tile.y*B_TILE_STRIDE + b_coord_in_tile.x]; 87 | result[i] += a*b; 88 | } 89 | } 90 | } 91 | 92 | for (uint i = 0; i < C_VALUES_PER_THREAD; ++i) { 93 | uint c_index = i*GROUP_SIZE + thread_index; 94 | uvec2 c_coord_in_tile = uvec2(c_index % C_TILE_W, c_index / C_TILE_W); 95 | store_c(k_chunk_index, batch_index, c_tile_coord*uvec2(C_TILE_W, C_TILE_H) + c_coord_in_tile, result[i]); 96 | } 97 | } 98 | -------------------------------------------------------------------------------- /examples/array_api/README.md: -------------------------------------------------------------------------------- 1 | # array_api example 2 | 3 | This example creates some small graphs directly using the array API. 4 | 5 | ## Parameters 6 | 7 | A _parameter_ is a handle to a persistent array with some fixed shape. The memory for the array is allocated from Vulkan device memory when necessary. 8 | 9 | Parameters are created via an _environment_, which wraps a Vulkan device. 10 | 11 | ```rust 12 | let mut env = Environment::new(); 13 | 14 | let z_param = env.static_parameter([3, 1], "z"); 15 | ``` 16 | 17 | ## Graphs 18 | 19 | A _graph_ is a set of expressions that read and write to _parameters_, which can be run on the Vulkan device. 20 | 21 | At the lowest level, graphs are built by using the `Array` type. This type does not directly do any operations on parameters, but can be used to build up a computation graph using standard Rust syntax. 22 | 23 | A _scope_ is a temporary data structure used to track expressions involving the `Array` type. Many parameter updates can be added to a single scope. Once the scope is complete, it can be compiled into a graph.
24 | 25 | Here is some code from this example that builds a graph using the `Array` type: 26 | 27 | ```rust 28 | let graph = env.build_graph(|scope| { 29 | let m = scope.parameter_value(&m_param); 30 | let x = scope.parameter_value(&x_param); 31 | let y = scope.parameter_value(&y_param); 32 | 33 | // build an expression involving m, x and y 34 | let z = 2.0 * m.matmul(x) + y * y + 1.0; 35 | 36 | scope.write_parameter_value(&z_param, z); 37 | }); 38 | ``` 39 | 40 | In order to actually perform the operations on the parameters, the graph can be run on the Vulkan device using the environment. 41 | 42 | ```rust 43 | env.run(&graph, random_seed); 44 | ``` 45 | 46 | The graph is run as a set of compute shaders on the Vulkan device. To avoid needlessly wasting bandwidth, operations are fused into a single kernel where possible. A visualisation of the graph for the example above is as follows: 47 | 48 | ![array graph](../../docs/array_api_values.svg) 49 | 50 | The grey boxes above are individual compute shaders: one matrix multiply and one fused per-element shader. (In future this may become a single fused shader.) 51 | 52 | Temporary memory to pass data between shaders is allocated/freed automatically as the graph is run. 53 | 54 | ## Derivatives 55 | 56 | To simplify generating code for back-propagation, there is a higher-level API that manipulates `DualArray` values. 57 | 58 | A `DualArray` value is a pair of `Array` values: one for the (forward) value and one for the (backward) gradient. 59 | The gradient is expected to be an expression that computes the derivative of the loss function w.r.t. this variable, i.e. exactly what is required for gradient descent. 60 | 61 | API functions for `DualArray` values compute expressions for both the forward and backward passes. This lets us build the full graph for a gradient descent step in small (composable) chunks of forward-like code, with the final loss function connecting the value and grad ops into a single graph. 62 | 63 | Here is an example that directly constructs a toy loss function, then adds code for a step of gradient descent: 64 | 65 | ```rust 66 | let graph = env.build_graph(|scope| { 67 | let x = scope.parameter(&x_param); 68 | let y = x.sin(); 69 | let _loss = (y.square() + y * 3.0).set_loss(); 70 | scope.write_parameter_value(&x_param, x.value() - 0.1 * x.loss_grad()); 71 | }); 72 | ``` 73 | 74 | Since x is a `DualArray`, we did not have to explicitly write code for back-propagation. The graph for this example is as follows: 75 | 76 | ![](../../docs/array_api_grad.svg) 77 | 78 | Inspecting this graph, we can see that x is updated in proportion to `(2*sin(x) + 3)*cos(x)`, which matches what we expect for a loss function of `sin^2(x) + 3*sin(x)`. 79 | 80 | The `DualArray` API is usually implemented in terms of the `Array` API, and can easily be extended with new functions that have a known derivative.
81 | For example, here is the implemention of `sin()` with some additional comments: 82 | 83 | ```rust 84 | impl<'s> DualArray<'s> { 85 | pub fn sin(self) -> Self { 86 | // get Array for value, loss_grad ("dx" means dL/dx) 87 | let (a, da) = self.into_inner(); 88 | let (b, db) = a.sin().with_empty_grad(); 89 | 90 | // add back-propagation step to compute da from db 91 | // using dL/da = dL/db * db/da, db/da = cos(a) 92 | da.accumulate(db * a.cos()); 93 | 94 | // into DualArray 95 | (b, db).into() 96 | } 97 | } 98 | ``` -------------------------------------------------------------------------------- /docs/fashion_mnist_network_linear.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 16 | 46 | 51 | 52 | 54 | 59 | 66 | 73 | 1@28x28 89 | 96 | Dense 107 | 111 | 10 122 | 123 | 124 | -------------------------------------------------------------------------------- /src/device/command_buffer.rs: -------------------------------------------------------------------------------- 1 | use super::common::*; 2 | use spark::{vk, Builder}; 3 | use std::{collections::VecDeque, slice}; 4 | 5 | struct CommandBuffer { 6 | pool: vk::CommandPool, 7 | cmd: vk::CommandBuffer, 8 | } 9 | 10 | impl CommandBuffer { 11 | fn new(context: &SharedContext) -> Self { 12 | let pool = { 13 | let command_pool_create_info = vk::CommandPoolCreateInfo { 14 | queue_family_index: context.queue_family_index, 15 | ..Default::default() 16 | }; 17 | unsafe { 18 | context 19 | .device 20 | .create_command_pool(&command_pool_create_info, None) 21 | } 22 | .unwrap() 23 | }; 24 | 25 | let cmd = { 26 | let command_buffer_allocate_info = vk::CommandBufferAllocateInfo { 27 | command_pool: Some(pool), 28 | level: vk::CommandBufferLevel::PRIMARY, 29 | command_buffer_count: 1, 30 | ..Default::default() 31 | }; 32 | unsafe { 33 | context 34 | .device 35 | .allocate_command_buffers_single(&command_buffer_allocate_info) 36 | } 37 | .unwrap() 38 | }; 39 | 40 | Self { pool, cmd } 41 | } 42 | } 43 | 44 | pub(crate) struct ScopedCommandBuffer<'a> { 45 | buffer: CommandBuffer, 46 | owner: &'a mut CommandBuffers, 47 | } 48 | 49 | impl<'a> ScopedCommandBuffer<'a> { 50 | pub(crate) fn get(&self) -> vk::CommandBuffer { 51 | self.buffer.cmd 52 | } 53 | 54 | pub(crate) fn submit(self, fences: &mut FenceSet) -> FenceId { 55 | self.owner.submit(self.buffer, fences) 56 | } 57 | } 58 | 59 | pub(crate) struct CommandBuffers { 60 | context: SharedContext, 61 | buffers: VecDeque>, 62 | } 63 | 64 | impl CommandBuffers { 65 | const COUNT: usize = 2; 66 | 67 | pub(crate) fn new(context: &SharedContext, fences: &FenceSet) -> Self { 68 | let mut buffers = VecDeque::new(); 69 | for _ in 0..Self::COUNT { 70 | buffers.push_back(Fenced::new(CommandBuffer::new(context), fences.old_id())); 71 | } 72 | Self { 73 | context: SharedContext::clone(context), 74 | buffers, 75 | } 76 | } 77 | 78 | pub(crate) fn acquire(&mut self, fences: &FenceSet) -> ScopedCommandBuffer { 79 | let active = self.buffers.pop_front().unwrap().take_when_signaled(fences); 80 | 81 | unsafe { 82 | self.context 83 | .device 84 | .reset_command_pool(active.pool, vk::CommandPoolResetFlags::empty()) 85 | .unwrap(); 86 | } 87 | 88 | let command_buffer_begin_info = vk::CommandBufferBeginInfo { 89 | flags: vk::CommandBufferUsageFlags::ONE_TIME_SUBMIT, 90 | ..Default::default() 91 | }; 92 | unsafe { 93 | self.context 94 | .device 95 | .begin_command_buffer(active.cmd, &command_buffer_begin_info) 96 | .unwrap(); 97 | } 98 | 99 | ScopedCommandBuffer { 100 | buffer: 
active, 101 | owner: self, 102 | } 103 | } 104 | 105 | fn submit(&mut self, active: CommandBuffer, fences: &mut FenceSet) -> FenceId { 106 | unsafe { self.context.device.end_command_buffer(active.cmd) }.unwrap(); 107 | 108 | let (fence_id, fence) = fences.next_unsignaled(); 109 | let submit_info = vk::SubmitInfo::builder().p_command_buffers(slice::from_ref(&active.cmd)); 110 | unsafe { 111 | self.context 112 | .device 113 | .queue_submit( 114 | self.context.queue, 115 | slice::from_ref(&submit_info), 116 | Some(fence), 117 | ) 118 | .unwrap(); 119 | } 120 | self.buffers.push_back(Fenced::new(active, fence_id)); 121 | 122 | fence_id 123 | } 124 | } 125 | 126 | impl Drop for CommandBuffers { 127 | fn drop(&mut self) { 128 | for buffer in self.buffers.iter() { 129 | unsafe { 130 | let buffer = buffer.get_unchecked(); 131 | self.context 132 | .device 133 | .free_command_buffers(buffer.pool, slice::from_ref(&buffer.cmd)); 134 | self.context 135 | .device 136 | .destroy_command_pool(Some(buffer.pool), None); 137 | } 138 | } 139 | } 140 | } 141 | -------------------------------------------------------------------------------- /src/kernel_common.glsl: -------------------------------------------------------------------------------- 1 | void compute_grid_coord( 2 | uint remain, 3 | out int coord[1], 4 | uint /*shape0*/) 5 | { 6 | coord[0] = int(remain); 7 | } 8 | 9 | void compute_grid_coord( 10 | uint remain, 11 | out int coord[2], 12 | uint /*shape0*/, 13 | uint shape1) 14 | { 15 | uint tmp1 = remain; 16 | remain /= shape1; 17 | tmp1 -= remain*shape1; 18 | 19 | uint tmp0 = remain; 20 | 21 | coord[0] = int(tmp0); 22 | coord[1] = int(tmp1); 23 | } 24 | 25 | void compute_grid_coord( 26 | uint remain, 27 | out int coord[3], 28 | uint /*shape0*/, 29 | uint shape1, 30 | uint shape2) 31 | { 32 | uint tmp2 = remain; 33 | remain /= shape2; 34 | tmp2 -= remain*shape2; 35 | 36 | uint tmp1 = remain; 37 | remain /= shape1; 38 | tmp1 -= remain*shape1; 39 | 40 | uint tmp0 = remain; 41 | 42 | coord[0] = int(tmp0); 43 | coord[1] = int(tmp1); 44 | coord[2] = int(tmp2); 45 | } 46 | 47 | void compute_grid_coord( 48 | uint remain, 49 | out int coord[4], 50 | uint /*shape0*/, 51 | uint shape1, 52 | uint shape2, 53 | uint shape3) 54 | { 55 | uint tmp3 = remain; 56 | remain /= shape3; 57 | tmp3 -= remain*shape3; 58 | 59 | uint tmp2 = remain; 60 | remain /= shape2; 61 | tmp2 -= remain*shape2; 62 | 63 | uint tmp1 = remain; 64 | remain /= shape1; 65 | tmp1 -= remain*shape1; 66 | 67 | uint tmp0 = remain; 68 | 69 | coord[0] = int(tmp0); 70 | coord[1] = int(tmp1); 71 | coord[2] = int(tmp2); 72 | coord[3] = int(tmp3); 73 | } 74 | 75 | void compute_grid_coord( 76 | uint remain, 77 | out int coord[5], 78 | uint /*shape0*/, 79 | uint shape1, 80 | uint shape2, 81 | uint shape3, 82 | uint shape4) 83 | { 84 | uint tmp4 = remain; 85 | remain /= shape4; 86 | tmp4 -= remain*shape4; 87 | 88 | uint tmp3 = remain; 89 | remain /= shape3; 90 | tmp3 -= remain*shape3; 91 | 92 | uint tmp2 = remain; 93 | remain /= shape2; 94 | tmp2 -= remain*shape2; 95 | 96 | uint tmp1 = remain; 97 | remain /= shape1; 98 | tmp1 -= remain*shape1; 99 | 100 | uint tmp0 = remain; 101 | 102 | coord[0] = int(tmp0); 103 | coord[1] = int(tmp1); 104 | coord[2] = int(tmp2); 105 | coord[3] = int(tmp3); 106 | coord[4] = int(tmp4); 107 | } 108 | 109 | void compute_grid_coord( 110 | uint remain, 111 | out int coord[6], 112 | uint /*shape0*/, 113 | uint shape1, 114 | uint shape2, 115 | uint shape3, 116 | uint shape4, 117 | uint shape5) 118 | { 119 | uint tmp5 = remain; 120 | 
remain /= shape5; 121 | tmp5 -= remain*shape5; 122 | 123 | uint tmp4 = remain; 124 | remain /= shape4; 125 | tmp4 -= remain*shape4; 126 | 127 | uint tmp3 = remain; 128 | remain /= shape3; 129 | tmp3 -= remain*shape3; 130 | 131 | uint tmp2 = remain; 132 | remain /= shape2; 133 | tmp2 -= remain*shape2; 134 | 135 | uint tmp1 = remain; 136 | remain /= shape1; 137 | tmp1 -= remain*shape1; 138 | 139 | uint tmp0 = remain; 140 | 141 | coord[0] = int(tmp0); 142 | coord[1] = int(tmp1); 143 | coord[2] = int(tmp2); 144 | coord[3] = int(tmp3); 145 | coord[4] = int(tmp4); 146 | coord[5] = int(tmp5); 147 | } 148 | 149 | void compute_grid_coord( 150 | uint remain, 151 | out int coord[7], 152 | uint /*shape0*/, 153 | uint shape1, 154 | uint shape2, 155 | uint shape3, 156 | uint shape4, 157 | uint shape5, 158 | uint shape6) 159 | { 160 | uint tmp6 = remain; 161 | remain /= shape6; 162 | tmp6 -= remain*shape6; 163 | 164 | uint tmp5 = remain; 165 | remain /= shape5; 166 | tmp5 -= remain*shape5; 167 | 168 | uint tmp4 = remain; 169 | remain /= shape4; 170 | tmp4 -= remain*shape4; 171 | 172 | uint tmp3 = remain; 173 | remain /= shape3; 174 | tmp3 -= remain*shape3; 175 | 176 | uint tmp2 = remain; 177 | remain /= shape2; 178 | tmp2 -= remain*shape2; 179 | 180 | uint tmp1 = remain; 181 | remain /= shape1; 182 | tmp1 -= remain*shape1; 183 | 184 | uint tmp0 = remain; 185 | 186 | coord[0] = int(tmp0); 187 | coord[1] = int(tmp1); 188 | coord[2] = int(tmp2); 189 | coord[3] = int(tmp3); 190 | coord[4] = int(tmp4); 191 | coord[5] = int(tmp5); 192 | coord[6] = int(tmp6); 193 | } 194 | 195 | int dot(ivec3 a, ivec3 b) 196 | { 197 | return a.x*b.x + a.y*b.y + a.z*b.z; 198 | } 199 | 200 | layout(push_constant) uniform constants 201 | { 202 | uint rand_seed; 203 | }; 204 | 205 | uint pcg(uint v) 206 | { 207 | uint state = v*747796405u + 2891336453u; 208 | uint word = ((state >> ((state >> 28u) + 4u)) ^ state)*277803737u; 209 | return (word >> 22u) ^ word; 210 | } 211 | 212 | float rand_from_index(uint uid, int index) 213 | { 214 | uint hash = pcg(pcg(index) + rand_seed + uid); 215 | return float(hash)/float(0xffffffffu); 216 | } 217 | 218 | float U2F(uint x) { return uintBitsToFloat(x); } 219 | uint F2U(float x) { return floatBitsToUint(x); } 220 | int F2I(float x) { return floatBitsToInt(x); } 221 | -------------------------------------------------------------------------------- /examples/fashion_mnist/README.md: -------------------------------------------------------------------------------- 1 | # fashion_mnist example 2 | 3 | This example trains a few different network types on the [Fashion-MNIST dataset](https://github.com/zalandoresearch/fashion-mnist). 4 | The example also implements an anti-aliased variant of max pooling as described in [Making Convolutional Networks Shift-Invariant Again by Richard Zhang](https://richzhang.github.io/antialiased-cnns/) for improved accuracy. 5 | 6 | ## Overview 7 | 8 | The [Fashion-MNIST dataset](https://github.com/zalandoresearch/fashion-mnist) consists of 60000 training images and 10000 test images. 9 | Each image is 28 by 28 with a single greyscale channel, and has a corresponding label between 0 and 9 (signifying which of the 10 categories the image belongs to). 10 | It is considered to be a more challenging dataset than the original [MNIST data](http://yann.lecun.com/exdb/mnist/) of handwritten digits. 11 | 12 | The example trains a network with the full training set for multiple epochs, testing accuracy against the full test set after each one. 
The training set is shuffled into a random order for each epoch. 13 | 14 | ## Running The Example 15 | 16 | First run `download.sh` in `data/fashion_mnist` to download the dataset, or download the gz files manually. Then the example can be run using: 17 | 18 | ``` 19 | cargo run --release --example fashion_mnist 20 | ``` 21 | 22 | This will train a CNN by default. Other networks can be trained by passing different command-line arguments, run the following to show commandline help: 23 | 24 | ``` 25 | cargo run --release --example fashion_mnist -- --help 26 | ``` 27 | 28 | ## Networks 29 | 30 | The example evaluates 4 different networks using this dataset. Here is a short description of each network and how it performs. 31 | 32 | ### Linear Classifier 33 | 34 | This network can be trained by passing `linear` as the network type: 35 | 36 | ``` 37 | cargo run --release --example fashion_mnist -- linear 38 | ``` 39 | 40 | This network has no hidden layers or activation functions, just a single dense layer from 784 (28x28) pixels to 10 categories. Softmax is then used to convert the final values into a probability for each category. 41 | 42 | ![](../../docs/fashion_mnist_network_linear.svg) 43 | 44 | Training 4 different random starting conditions for 40 epochs produces the following statistics: 45 | 46 | ![](../../docs/fashion_mnist_stats_linear.svg) 47 | 48 | Training accuracy creeps up to 87%, but does not seem to generalize well since test accuracy quickly tops out just short of 85%. 49 | 50 | ### Single Hidden Layer 51 | 52 | This network can be trained by passing `single-layer` as the network type: 53 | 54 | ``` 55 | cargo run --release --example fashion_mnist -- single-layer 56 | ``` 57 | 58 | This network adds a single hidden layer with 300 units and a (leaky) ReLU activation function: 59 | 60 | ![](../../docs/fashion_mnist_network_single-layer.svg) 61 | 62 | Repeating the same experiment with 4 random starting conditions for 40 epochs produces the following: 63 | 64 | ![](../../docs/fashion_mnist_stats_single-layer.svg) 65 | 66 | Accuracy on the test set is up to 89%, but we very quickly start overfitting to the training set (the loss function on the test set starts to increase after only 20 epochs). 67 | 68 | ### Convolutional Neural Network 69 | 70 | This network can be trained by passing `conv-net` as the network type (the default if no network is provided): 71 | 72 | ``` 73 | cargo run --release --example fashion_mnist -- conv-net 74 | ``` 75 | 76 | This network performs two rounds of 3x3 convolutions, with a ReLU activation and 2x2 max pooling after each one. 77 | 78 | ![](../../docs/fashion_mnist_network_conv-net.svg) 79 | 80 | To make the network slightly smaller, the second convolution is split into 2 groups, with the first 16 output channels reading from the first 8 input channels, and the second 16 output channels reading from the second 8 input channels. 81 | 82 | Training this for 40 epochs produces the following results: 83 | 84 | ![](../../docs/fashion_mnist_stats_conv-net.svg) 85 | 86 | Test accuracy is now up to 92%, which seems to be fairly respectable for a small network on this dataset. However, the network is still not generalising very well to the test set since test performance levels out while the training performance is still rising. 
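For reference, the layer structure described above can be written compactly with the crate's builder API. The sketch below closely follows the `ConvNet` snippet in the top-level README; it assumes `env: &mut Environment`, and the hidden layer size of 128 is taken from that snippet rather than from this example's own source (which is not included here).

```rust
// Construct the layers described above: two 3x3 convolutions (the second split
// into 2 groups), each followed in the network by a (leaky) ReLU and 2x2 max
// pooling, then two dense layers down to the 10 categories.
let (c1, c2, hidden) = (16, 32, 128);
let conv1 = Conv2D::builder(1, c1, 3, 3).with_pad(1).build(env);
let conv2 = Conv2D::builder(c1, c2, 3, 3)
    .with_pad(1)
    .with_groups(2)
    .build(env);
let fc1 = Dense::builder(7 * 7 * c2, hidden).build(env);
let fc2 = Dense::builder(hidden, 10).build(env);
```

Splitting the second convolution into 2 groups is what keeps the parameter count down: each group of 16 output channels only reads from its own 8 input channels.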
 87 | 88 | ### Convolutional Neural Network With Anti-Aliasing 89 | 90 | This network can be trained by passing `conv-blur-net` as the network type: 91 | 92 | ``` 93 | cargo run --release --example fashion_mnist -- conv-blur-net 94 | ``` 95 | 96 | This network is identical to the CNN above, except max pooling is replaced with an anti-aliased alternative, as described in [Making Convolutional Networks Shift-Invariant Again by Richard Zhang](https://richzhang.github.io/antialiased-cnns/). 97 | 98 | Specifically, we replace both of the previous max pool steps (that use size=2, stride=2) with two steps: 99 | 100 | - Max pool (size=2, stride=1) 101 | - Convolution with fixed-weight 3x3 blur (stride=2, pad=1) 102 | 103 | This has the effect of smoothing out the response of max pool as features shift from pixel to pixel, and so should produce a result that is more robust. 104 | The convolution weights are not trainable, so this does not add any parameters to the model. 105 | Back-propagation through the extra convolution step does, however, add to the training time. 106 | 107 | Training this for 40 epochs produces the following: 108 | 109 | ![](../../docs/fashion_mnist_stats_conv-blur-net.svg) 110 | 111 | This seems to confirm the results of the paper: test accuracy has increased slightly to 92.5%, and there is less of a gap between the training set and test set performance, indicating that the model has generalised more effectively. 112 | -------------------------------------------------------------------------------- /docs/fashion_mnist_network_single-layer.svg: -------------------------------------------------------------------------------- (SVG markup not reproduced here; the diagram shows the single-hidden-layer network: 1@28x28 input -> Dense + ReLU -> 300 units -> Dense -> 10 outputs.) -------------------------------------------------------------------------------- /src/op.rs: -------------------------------------------------------------------------------- 1 | use crate::common::*; 2 | use ordered_float::NotNan; 3 | use petgraph::prelude::*; 4 | use slotmap::Key; 5 | use std::fmt; 6 | 7 | pub(crate) trait Only: Iterator { 8 | fn only(&mut self) -> Option<Self::Item>; 9 | } 10 | 11 | impl<I: Iterator> Only for I { 12 | fn only(&mut self) -> Option<Self::Item> { 13 | let first = self.next(); 14 | first.filter(|_| self.next().is_none()) 15 | } 16 | } 17 | 18 | pub(crate) type OpGraph = StableDiGraph<OpNode, OpEdge>; 19 | pub(crate) type OpNodeId = NodeIndex; 20 | pub(crate) type OpEdgeId = EdgeIndex; 21 | 22 | #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] 23 | pub(crate) enum Literal { 24 | F32(NotNan<f32>), 25 | U32(u32), 26 | } 27 | 28 | #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] 29 | pub(crate) enum ReduceOp { 30 | Max, 31 | Sum, 32 | } 33 | 34 | #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] 35 | pub(crate) enum BuiltInOp { 36 | Coord, 37 | Rand { uid: usize }, 38 | } 39 | 40 | #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] 41 | pub(crate) enum CompareMode { 42 | Eq, 43 | Gt, 44 | } 45 | 46 | #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] 47 | pub(crate) enum BinaryOp { 48 | Add, 49 | Sub, 50 | Mul, 51 | Div, 52 | Pow, 53 | UAdd, 54 | UMul, 55 | URem, 56 | UBitXor, 57 | } 58 | 59 | #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] 60 | pub(crate) enum UnaryOp { 61 | Mov, 62 | Neg, 63 | Sqrt, 64 | Exp, 65 | Log, 66 | Sin, 67 | Cos, 68 | FloatToUint, 69 | UintToFloat, 70 | } 71 | 72 | pub(crate) const MAX_OP_ARGS: usize = 4; 73 | 74 | pub(crate) const MATMUL_MAX_K_SIZE: usize = 1024; 75 | 76 | #[derive(Debug, Clone, Copy, 
PartialEq, Eq, Hash)] 77 | pub(crate) enum MatMulOutputMode { 78 | Batches, 79 | Rows, 80 | } 81 | 82 | #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] 83 | pub(crate) enum Op { 84 | Input { parameter_id: ParameterId }, 85 | Output { parameter_id: ParameterId }, 86 | Literal(Literal), 87 | BuiltIn(BuiltInOp), 88 | Unary(UnaryOp), 89 | Binary(BinaryOp), 90 | CompareAndSelect(CompareMode), 91 | MatMul { output_mode: MatMulOutputMode }, 92 | Reduce { reduce_op: ReduceOp, axis: Axis }, // TODO: 2D version? 93 | Unpad { axis: Axis, pad: usize }, // TODO: 2D version? 94 | WindowsToImage { stride: (usize, usize) }, 95 | Gather { axis: Axis }, 96 | ScatterAdd { axis: Axis }, 97 | } 98 | 99 | impl Op { 100 | pub(crate) fn input_parameter_id(&self) -> Option { 101 | match self { 102 | Self::Input { parameter_id } => Some(*parameter_id), 103 | _ => None, 104 | } 105 | } 106 | 107 | pub(crate) fn output_parameter_id(&self) -> Option { 108 | match self { 109 | Self::Output { parameter_id } => Some(*parameter_id), 110 | _ => None, 111 | } 112 | } 113 | 114 | pub(crate) fn is_per_element(&self) -> bool { 115 | matches!( 116 | self, 117 | Self::Unary(_) | Self::Binary(_) | Self::CompareAndSelect(_) | Self::Gather { .. } 118 | ) 119 | } 120 | 121 | pub(crate) fn is_gather_arg(&self, arg: usize) -> bool { 122 | match self { 123 | Self::Gather { .. } => arg == 0, 124 | _ => false, 125 | } 126 | } 127 | 128 | pub(crate) fn can_reshape(&self) -> bool { 129 | !matches!(self, Self::BuiltIn(_) | Self::Gather { .. }) 130 | } 131 | 132 | pub(crate) fn can_merge(&self) -> bool { 133 | !matches!(self, Self::Input { .. } | Self::Output { .. }) 134 | } 135 | } 136 | 137 | impl fmt::Display for Op { 138 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 139 | match self { 140 | Self::Input { parameter_id } => write!(f, "Input({:?})", parameter_id.data()), 141 | Self::Output { parameter_id } => write!(f, "Output({:?})", parameter_id.data()), 142 | Self::Literal(value) => write!(f, "{:?}", value), 143 | Self::BuiltIn(built_in_op) => match built_in_op { 144 | BuiltInOp::Coord => write!(f, "Coord"), 145 | BuiltInOp::Rand { .. } => write!(f, "Rand"), 146 | }, 147 | Self::Unary(unary_op) => write!(f, "{:?}", unary_op), 148 | Self::Binary(binary_op) => write!(f, "{:?}", binary_op), 149 | Self::CompareAndSelect(compare_mode) => write!(f, "Select{:?}", compare_mode), 150 | Self::MatMul { .. } => write!(f, "MatMul"), 151 | Self::Reduce { reduce_op, axis } => { 152 | write!(f, "Reduce{:?}({})", reduce_op, axis.index()) 153 | } 154 | Self::Unpad { axis, pad } => write!(f, "Unpad{}({})", pad, axis.index()), 155 | Self::WindowsToImage { .. 
} => write!(f, "WindowsToImage"), 156 | Self::Gather { axis } => write!(f, "Gather({})", axis.index()), 157 | Self::ScatterAdd { axis } => write!(f, "ScatterAdd({})", axis.index()), 158 | } 159 | } 160 | } 161 | 162 | #[derive(Debug, Clone)] 163 | pub(crate) struct OpNode { 164 | pub(crate) colour: usize, 165 | pub(crate) shape: Shape, 166 | pub(crate) op: Op, 167 | pub(crate) cluster_id: Option, 168 | } 169 | 170 | #[derive(Debug, Clone)] 171 | pub(crate) struct OpEdge { 172 | pub(crate) arg: usize, 173 | pub(crate) view: View, 174 | } 175 | 176 | impl OpEdge { 177 | pub(crate) fn is_per_element(&self, op: &Op) -> bool { 178 | !op.is_gather_arg(self.arg) && self.view.is_contiguous() 179 | } 180 | } 181 | 182 | pub(crate) trait OpGraphExt { 183 | fn new_node( 184 | &mut self, 185 | colour: usize, 186 | shape: impl Into, 187 | op: Op, 188 | inputs: &[OpNodeId], 189 | ) -> OpNodeId; 190 | } 191 | 192 | impl OpGraphExt for OpGraph { 193 | fn new_node( 194 | &mut self, 195 | colour: usize, 196 | shape: impl Into, 197 | op: Op, 198 | inputs: &[OpNodeId], 199 | ) -> OpNodeId { 200 | let shape = shape.into(); 201 | let node_id = self.add_node(OpNode { 202 | colour, 203 | shape, 204 | op, 205 | cluster_id: None, 206 | }); 207 | for (index, input_id) in inputs.iter().copied().enumerate() { 208 | self.add_edge( 209 | input_id, 210 | node_id, 211 | OpEdge { 212 | arg: index, 213 | view: self[input_id].shape.identity_view(), 214 | }, 215 | ); 216 | } 217 | node_id 218 | } 219 | } 220 | -------------------------------------------------------------------------------- /examples/sentiment/main.rs: -------------------------------------------------------------------------------- 1 | use descent::{loss::*, module::*, optimizer::*, prelude::*}; 2 | use rand::{prelude::SliceRandom, RngCore, SeedableRng}; 3 | use serde::Deserialize; 4 | use std::{ 5 | collections::HashMap, 6 | fs::File, 7 | io::{prelude::*, BufReader}, 8 | iter, 9 | }; 10 | 11 | #[derive(Debug, Deserialize)] 12 | #[serde(rename_all = "lowercase")] 13 | enum JsonLabel { 14 | Positive, 15 | Negative, 16 | Neutral, 17 | Mixed, 18 | } 19 | 20 | #[derive(Debug, Deserialize)] 21 | struct JsonRecord { 22 | sentence: String, 23 | gold_label: Option, 24 | } 25 | 26 | struct Record { 27 | sentence: Vec, 28 | label: u8, 29 | } 30 | 31 | fn vocab_iter<'s>(s: &'s str) -> impl Iterator + 's { 32 | s.split_whitespace().filter_map(|word| { 33 | let entry: String = word 34 | .chars() 35 | .filter(|c| c.is_alphabetic()) 36 | .flat_map(|c| c.to_lowercase()) 37 | .collect(); 38 | if entry.is_empty() { 39 | None 40 | } else { 41 | Some(entry) 42 | } 43 | }) 44 | } 45 | 46 | fn main() { 47 | const MAX_WORD_COUNT: usize = 32; 48 | let f = 49 | File::open("data/dynasent/dynasent-v1.1/dynasent-v1.1-round01-yelp-train.jsonl").unwrap(); 50 | let records: Vec<_> = BufReader::new(f) 51 | .lines() 52 | .filter_map(|line| { 53 | let record: JsonRecord = serde_json::from_str(&line.unwrap()).unwrap(); 54 | if matches!(record.gold_label, None | Some(JsonLabel::Mixed)) 55 | || vocab_iter(&record.sentence).count() > MAX_WORD_COUNT 56 | { 57 | None 58 | } else { 59 | Some(record) 60 | } 61 | }) 62 | .collect(); 63 | 64 | let mut word_map = HashMap::new(); 65 | for record in records.iter() { 66 | for word in vocab_iter(&record.sentence) { 67 | *word_map.entry(word).or_insert(0) += 1; 68 | } 69 | } 70 | let retain_limit = records.len() / 5; 71 | let mut next_index = 1; 72 | word_map.retain(|_key, value| { 73 | if *value < retain_limit { 74 | *value = next_index; 75 | next_index += 
1; 76 | true 77 | } else { 78 | false 79 | } 80 | }); 81 | let vocab_size = next_index; 82 | println!("vocab size: {}", vocab_size); 83 | 84 | let records: Vec<_> = records 85 | .iter() 86 | .map(|record| { 87 | let sentence: Vec<_> = vocab_iter(&record.sentence) 88 | .filter_map(|word| word_map.get(&word).map(|index| *index as u16)) 89 | .collect(); 90 | let label = match record.gold_label { 91 | Some(JsonLabel::Negative) => 0, 92 | Some(JsonLabel::Neutral) => 1, 93 | Some(JsonLabel::Positive) => 2, 94 | _ => unreachable!(), 95 | }; 96 | Record { sentence, label } 97 | }) 98 | .collect(); 99 | println!("records: {}", records.len()); 100 | println!( 101 | "sentence length max: {}, avg: {}", 102 | records 103 | .iter() 104 | .map(|record| record.sentence.len()) 105 | .max() 106 | .unwrap_or(0), 107 | records 108 | .iter() 109 | .map(|record| record.sentence.len()) 110 | .sum::() 111 | / records.len() 112 | ); 113 | 114 | let mut env = Environment::new(); 115 | 116 | let embedding_size = 128; 117 | let lstm_size = 64; 118 | let lstm = LSTMCell::new(&mut env, embedding_size, lstm_size); 119 | let fc = Dense::builder(lstm_size, 3).build(&mut env); 120 | 121 | let m = 256; 122 | let x_param = env.static_parameter([m, MAX_WORD_COUNT, 1], "x"); 123 | let y_param = env.static_parameter([m, 1], "y"); 124 | 125 | let embedding = env.trainable_parameter( 126 | [vocab_size, embedding_size], 127 | "em", 128 | Initializer::RandUniform(1.0), 129 | ); 130 | let loss_sum_param = env.static_parameter([1], "loss"); 131 | let accuracy_sum_param = env.static_parameter([1], "accuracy"); 132 | 133 | let (train_graph, parameters, _optimizer) = { 134 | let scope = env.scope(); 135 | 136 | let x: DualArray = scope 137 | .parameter(&x_param) 138 | .value() 139 | .one_hot(vocab_size) 140 | .with_empty_grad() 141 | .into(); 142 | let x = x 143 | .reshape([m * MAX_WORD_COUNT, vocab_size]) 144 | .matmul(&embedding) 145 | .reshape([m, MAX_WORD_COUNT, embedding_size]); 146 | let x = lstm.train(x); 147 | let x = fc.train(x); 148 | let loss = softmax_cross_entropy_loss(x, &y_param).set_loss(); 149 | let accuracy = softmax_cross_entropy_accuracy(x, &y_param); 150 | 151 | scope.update_parameter_value(&loss_sum_param, |loss_sum| { 152 | loss_sum + loss.reduce_sum(0, false) 153 | }); 154 | scope.update_parameter_value(&accuracy_sum_param, |accuracy_sum| { 155 | accuracy_sum + accuracy.reduce_sum(0, false) 156 | }); 157 | 158 | let parameters = scope.trainable_parameters(); 159 | let optimizer = Adam::new(&mut env, &scope, ¶meters, 0.002, 0.9, 0.999, 1.0E-8); 160 | 161 | (scope.build_graph(), parameters, optimizer) 162 | }; 163 | 164 | let mut rng = rand_chacha::ChaCha20Rng::seed_from_u64(0); 165 | for param in parameters.iter() { 166 | env.reset_parameter(param, &mut rng); 167 | } 168 | 169 | let mini_batch_per_epoch = records.len() / m; 170 | for epoch_index in 0..40 { 171 | env.writer(&loss_sum_param).zero_fill(); 172 | env.writer(&accuracy_sum_param).zero_fill(); 173 | 174 | for _ in 0..mini_batch_per_epoch { 175 | let mut xw = env.writer(&x_param); 176 | let mut labels = Vec::new(); 177 | for record in records.choose_multiple(&mut rng, m) { 178 | let sentence: Vec<_> = record 179 | .sentence 180 | .iter() 181 | .copied() 182 | .chain(iter::repeat(0)) 183 | .take(MAX_WORD_COUNT) 184 | .map(|w| w as f32) 185 | .collect(); 186 | xw.write(bytemuck::cast_slice(&sentence)).unwrap(); 187 | labels.push(record.label as f32); 188 | } 189 | xw.zero_fill(); 190 | let mut yw = env.writer(&y_param); 191 | 
yw.write(bytemuck::cast_slice(&labels)).unwrap(); 192 | yw.zero_fill(); 193 | 194 | env.run(&train_graph, rng.next_u32()); 195 | } 196 | 197 | if epoch_index < 2 { 198 | env.print_timings("training"); 199 | } 200 | 201 | let train_count = m * mini_batch_per_epoch; 202 | let train_loss = env.read_parameter_scalar(&loss_sum_param) / (train_count as f32); 203 | let train_accuracy = env.read_parameter_scalar(&accuracy_sum_param) / (train_count as f32); 204 | 205 | let done_counter = epoch_index + 1; 206 | println!( 207 | "epoch: {}, loss: {}, accuracy: {}", 208 | done_counter, train_loss, train_accuracy 209 | ); 210 | } 211 | } 212 | -------------------------------------------------------------------------------- /src/device/context.rs: -------------------------------------------------------------------------------- 1 | use spark::{vk, Builder, Device, DeviceExtensions, Instance, InstanceExtensions, Loader}; 2 | use std::rc::Rc; 3 | use std::{ffi::CStr, slice}; 4 | 5 | trait PhysicalDeviceMemoryPropertiesExt { 6 | fn types(&self) -> &[vk::MemoryType]; 7 | fn heaps(&self) -> &[vk::MemoryHeap]; 8 | } 9 | 10 | impl PhysicalDeviceMemoryPropertiesExt for vk::PhysicalDeviceMemoryProperties { 11 | fn types(&self) -> &[vk::MemoryType] { 12 | &self.memory_types[..self.memory_type_count as usize] 13 | } 14 | fn heaps(&self) -> &[vk::MemoryHeap] { 15 | &self.memory_heaps[..self.memory_heap_count as usize] 16 | } 17 | } 18 | 19 | pub(crate) struct Context { 20 | pub(crate) instance: Instance, 21 | pub(crate) _physical_device: vk::PhysicalDevice, 22 | pub(crate) physical_device_properties: vk::PhysicalDeviceProperties, 23 | pub(crate) physical_device_memory_properties: vk::PhysicalDeviceMemoryProperties, 24 | pub(crate) queue_family_index: u32, 25 | pub(crate) queue_family_properties: vk::QueueFamilyProperties, 26 | pub(crate) queue: vk::Queue, 27 | pub(crate) device: Device, 28 | pub(crate) has_shader_atomic_float_add: bool, 29 | } 30 | 31 | pub(crate) type SharedContext = Rc; 32 | 33 | impl Context { 34 | pub(crate) fn new() -> SharedContext { 35 | let version = vk::Version::default(); 36 | let instance = { 37 | let loader = Loader::new().unwrap(); 38 | 39 | let available_extensions = { 40 | let extension_properties = 41 | unsafe { loader.enumerate_instance_extension_properties_to_vec(None) }.unwrap(); 42 | InstanceExtensions::from_properties(version, &extension_properties) 43 | }; 44 | 45 | let mut extensions = InstanceExtensions::new(version); 46 | if available_extensions.supports_ext_debug_utils() { 47 | extensions.enable_ext_debug_utils(); 48 | } 49 | if available_extensions.supports_ext_shader_atomic_float() { 50 | extensions.enable_ext_shader_atomic_float(); 51 | } 52 | let extension_names = extensions.to_name_vec(); 53 | 54 | let app_info = vk::ApplicationInfo::builder() 55 | .p_application_name(Some(CStr::from_bytes_with_nul(b"caldera\0").unwrap())) 56 | .api_version(version); 57 | 58 | let extension_name_ptrs: Vec<_> = extension_names.iter().map(|s| s.as_ptr()).collect(); 59 | let instance_create_info = vk::InstanceCreateInfo::builder() 60 | .p_application_info(Some(&app_info)) 61 | .pp_enabled_extension_names(&extension_name_ptrs); 62 | unsafe { loader.create_instance(&instance_create_info, None) }.unwrap() 63 | }; 64 | 65 | let physical_device = { 66 | let physical_devices = unsafe { instance.enumerate_physical_devices_to_vec() }.unwrap(); 67 | for (i, physical_device) in physical_devices.iter().enumerate() { 68 | let props = unsafe { instance.get_physical_device_properties(*physical_device) 
}; 69 | println!( 70 | "physical device {}: {:?} ({})", 71 | i, 72 | unsafe { CStr::from_ptr(props.device_name.as_ptr()) }, 73 | props.device_type 74 | ); 75 | } 76 | physical_devices[0] 77 | }; 78 | let physical_device_properties = 79 | unsafe { instance.get_physical_device_properties(physical_device) }; 80 | 81 | let physical_device_memory_properties = 82 | unsafe { instance.get_physical_device_memory_properties(physical_device) }; 83 | 84 | let mut available_atomic_float_features = 85 | vk::PhysicalDeviceShaderAtomicFloatFeaturesEXT::default(); 86 | if instance 87 | .extensions 88 | .supports_khr_get_physical_device_properties2() 89 | { 90 | let mut features = vk::PhysicalDeviceFeatures2KHR::builder() 91 | .insert_next(&mut available_atomic_float_features); 92 | unsafe { 93 | instance.get_physical_device_features2_khr(physical_device, features.get_mut()); 94 | } 95 | } 96 | 97 | let (queue_family_index, queue_family_properties) = { 98 | let queue_flags = vk::QueueFlags::COMPUTE; 99 | 100 | unsafe { instance.get_physical_device_queue_family_properties_to_vec(physical_device) } 101 | .iter() 102 | .enumerate() 103 | .filter_map(|(index, info)| { 104 | if info.queue_flags.contains(queue_flags) { 105 | Some((index as u32, *info)) 106 | } else { 107 | None 108 | } 109 | }) 110 | .next() 111 | .unwrap() 112 | }; 113 | 114 | let mut has_shader_atomic_float_add = false; 115 | let device = { 116 | let queue_priorities = [1.0]; 117 | let device_queue_create_info = vk::DeviceQueueCreateInfo::builder() 118 | .queue_family_index(queue_family_index) 119 | .p_queue_priorities(&queue_priorities); 120 | 121 | let available_extensions = { 122 | let extension_properties = unsafe { 123 | instance.enumerate_device_extension_properties_to_vec(physical_device, None) 124 | } 125 | .unwrap(); 126 | DeviceExtensions::from_properties(version, &extension_properties) 127 | }; 128 | 129 | let mut extensions = DeviceExtensions::new(version); 130 | let mut shader_atomic_float_features = 131 | vk::PhysicalDeviceShaderAtomicFloatFeaturesEXT::default(); 132 | if available_extensions.supports_ext_shader_atomic_float() 133 | && available_atomic_float_features.shader_buffer_float32_atomic_add == vk::TRUE 134 | { 135 | extensions.enable_ext_shader_atomic_float(); 136 | shader_atomic_float_features.shader_buffer_float32_atomic_add = vk::TRUE; 137 | has_shader_atomic_float_add = true; 138 | } 139 | let extension_names = extensions.to_name_vec(); 140 | 141 | let extension_name_ptrs: Vec<_> = extension_names.iter().map(|s| s.as_ptr()).collect(); 142 | let device_create_info = vk::DeviceCreateInfo::builder() 143 | .p_queue_create_infos(slice::from_ref(&device_queue_create_info)) 144 | .pp_enabled_extension_names(&extension_name_ptrs) 145 | .insert_next(&mut shader_atomic_float_features); 146 | 147 | unsafe { instance.create_device(physical_device, &device_create_info, None, version) } 148 | .unwrap() 149 | }; 150 | 151 | let queue = unsafe { device.get_device_queue(queue_family_index, 0) }; 152 | 153 | SharedContext::new(Self { 154 | instance, 155 | _physical_device: physical_device, 156 | physical_device_properties, 157 | physical_device_memory_properties, 158 | queue_family_index, 159 | queue_family_properties, 160 | queue, 161 | device, 162 | has_shader_atomic_float_add, 163 | }) 164 | } 165 | 166 | pub(crate) fn get_memory_type_index( 167 | &self, 168 | type_filter: u32, 169 | property_flags: vk::MemoryPropertyFlags, 170 | ) -> Option { 171 | for (i, mt) in self 172 | .physical_device_memory_properties 173 | .types() 174 | 
.iter() 175 | .enumerate() 176 | { 177 | let i = i as u32; 178 | if (type_filter & (1 << i)) != 0 && mt.property_flags.contains(property_flags) { 179 | return Some(i); 180 | } 181 | } 182 | None 183 | } 184 | } 185 | 186 | impl Drop for Context { 187 | fn drop(&mut self) { 188 | unsafe { 189 | self.device.destroy_device(None); 190 | self.instance.destroy_instance(None); 191 | } 192 | } 193 | } 194 | -------------------------------------------------------------------------------- /src/lib.rs: -------------------------------------------------------------------------------- 1 | pub mod array; 2 | mod device; 3 | pub mod environment; 4 | pub mod prelude { 5 | pub use crate::{array::*, environment::*, graph::*, parameter::*, shape::*}; 6 | } 7 | mod common { 8 | pub(crate) use crate::{kernel::*, op::*, prelude::*}; 9 | } 10 | pub mod graph; 11 | mod kernel; 12 | pub mod loss; 13 | pub mod module; 14 | mod op; 15 | pub mod optimizer; 16 | pub mod parameter; 17 | pub mod shape; 18 | 19 | #[cfg(test)] 20 | mod tests { 21 | use crate::prelude::*; 22 | use std::iter; 23 | 24 | const TEST_RAND_SEED: u32 = 0x5EED5EED; 25 | 26 | #[test] 27 | fn parameters() { 28 | let mut env = Environment::new(); 29 | 30 | let a_data = vec![0f32, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]; 31 | let a_param = env.static_parameter_with_data([10], "a", &a_data); 32 | 33 | assert_eq!(env.read_parameter_to_vec(&a_param), a_data); 34 | } 35 | 36 | #[test] 37 | fn reduce() { 38 | let mut env = Environment::new(); 39 | 40 | let a_data: Vec = (0..100).map(|i| i as f32).collect(); 41 | let b_data: Vec = a_data.chunks(10).map(|v| v.iter().sum::()).collect(); 42 | 43 | let a_param = env.static_parameter_with_data([10, 10], "a", &a_data); 44 | let b_param = env.static_parameter([10, 1], "b"); 45 | 46 | let g = env.build_graph(|scope| { 47 | scope.write_parameter_value( 48 | &b_param, 49 | scope.parameter_value(&a_param).reduce_sum(-1, true), 50 | ); 51 | }); 52 | env.run(&g, TEST_RAND_SEED); 53 | 54 | assert_eq!(env.read_parameter_to_vec(&b_param), b_data); 55 | } 56 | 57 | #[test] 58 | fn pad_image() { 59 | let mut env = Environment::new(); 60 | 61 | let a_data: Vec = iter::repeat(1.0).take(64).collect(); 62 | let b_data: Vec = iter::repeat(1.0).take(100).collect(); 63 | 64 | let a_param = env.static_parameter_with_data([1, 8, 8, 1], "a", &a_data); 65 | let b_param = env.static_parameter([1, 10, 10, 1], "b"); 66 | 67 | let g = env.build_graph(|scope| { 68 | scope.write_parameter_value(&b_param, scope.parameter_value(&a_param).pad_image(1)); 69 | }); 70 | env.run(&g, TEST_RAND_SEED); 71 | 72 | assert_eq!(env.read_parameter_to_vec(&b_param), b_data); 73 | } 74 | 75 | #[test] 76 | fn unpad_image() { 77 | let mut env = Environment::new(); 78 | 79 | let a_data: Vec = iter::repeat(1.0).take(100).collect(); 80 | 81 | let unpad = |a| if a == 0 || a == 7 { 2.0 } else { 1.0 }; 82 | let b_data: Vec = (0..8) 83 | .flat_map(move |y| { 84 | let ny = unpad(y); 85 | (0..8).map(move |x| ny * unpad(x)) 86 | }) 87 | .collect(); 88 | 89 | let a_param = env.static_parameter_with_data([1, 10, 10, 1], "a", &a_data); 90 | let b_param = env.static_parameter([1, 8, 8, 1], "b"); 91 | 92 | let g = env.build_graph(|scope| { 93 | scope.write_parameter_value(&b_param, scope.parameter_value(&a_param).unpad_image(1)); 94 | }); 95 | env.run(&g, TEST_RAND_SEED); 96 | 97 | assert_eq!(env.read_parameter_to_vec(&b_param), b_data); 98 | } 99 | 100 | #[test] 101 | fn conv2d() { 102 | let mut env = Environment::new(); 103 | 104 | let a_data: Vec = 
iter::repeat(1.0).take(100).collect(); 105 | let b_data: Vec = iter::repeat(1.0).take(9).collect(); 106 | let c_data: Vec = iter::repeat(9.0).take(64).collect(); 107 | 108 | let a_param = env.static_parameter_with_data([1, 10, 10, 1], "a", &a_data); 109 | let b_param = env.static_parameter_with_data([1, 1, 3, 3, 1], "b", &b_data); 110 | let c_param = env.static_parameter([1, 8, 8, 1], "c"); 111 | 112 | let g = env.build_graph(|scope| { 113 | scope.write_parameter_value( 114 | &c_param, 115 | scope 116 | .parameter(&a_param) 117 | .conv2d(&b_param, 0, (1, 1)) 118 | .value(), 119 | ); 120 | }); 121 | env.run(&g, TEST_RAND_SEED); 122 | 123 | assert_eq!(env.read_parameter_to_vec(&c_param), c_data); 124 | } 125 | 126 | #[test] 127 | fn max_pool2d() { 128 | let mut env = Environment::new(); 129 | 130 | let a_data: Vec = (0..100).map(|i| i as f32).collect(); 131 | let b_data: Vec = (0..25) 132 | .map(|i| (11 + 2 * (i % 5) + 20 * (i / 5)) as f32) 133 | .collect(); 134 | 135 | let a_param = env.static_parameter_with_data([1, 10, 10, 1], "a", &a_data); 136 | let b_param = env.static_parameter([1, 5, 5, 1], "b"); 137 | 138 | let g = env.build_graph(|scope| { 139 | scope.write_parameter_value( 140 | &b_param, 141 | scope.parameter(&a_param).max_pool2d((2, 2), (2, 2)).value(), 142 | ); 143 | }); 144 | env.run(&g, TEST_RAND_SEED); 145 | 146 | assert_eq!(env.read_parameter_to_vec(&b_param), b_data); 147 | } 148 | 149 | #[test] 150 | fn gather() { 151 | let mut env = Environment::new(); 152 | 153 | let a_data: Vec = (0..200).map(|i| (i * i) as f32).collect(); 154 | let b_data: Vec = (0..100).map(|i| (99 - i) as f32).collect(); 155 | let c_data: Vec = (0..100).map(|i| ((99 - i) * (99 - i) + 1) as f32).collect(); 156 | 157 | let a_param = env.static_parameter_with_data([1, 200, 1], "a", &a_data); 158 | let b_param = env.static_parameter_with_data([100], "b", &b_data); 159 | let c_param = env.static_parameter([1, 100, 1], "c"); 160 | 161 | let g = env.build_graph(|scope| { 162 | scope.write_parameter_value( 163 | &c_param, 164 | scope 165 | .parameter_value(&a_param) 166 | .gather(1, scope.parameter_value(&b_param).into_u32()) 167 | + 1.0, 168 | ); 169 | }); 170 | env.run(&g, TEST_RAND_SEED); 171 | 172 | assert_eq!(env.read_parameter_to_vec(&c_param), c_data); 173 | } 174 | 175 | #[test] 176 | fn scatter_add() { 177 | let mut env = Environment::new(); 178 | 179 | let range = 10; 180 | 181 | let a_data: Vec = iter::repeat(1.0).take(100).collect(); 182 | let b_data: Vec = (0..range).map(|i| i as f32).cycle().take(100).collect(); 183 | let c_data: Vec = iter::repeat(10.0).take(10).collect(); 184 | 185 | let a_param = env.static_parameter_with_data([1, 100, 1], "a", &a_data); 186 | let b_param = env.static_parameter_with_data([100], "b", &b_data); 187 | let c_param = env.static_parameter([1, range, 1], "c"); 188 | 189 | let g = env.build_graph(|scope| { 190 | scope.write_parameter_value( 191 | &c_param, 192 | scope 193 | .literal(0.0) 194 | .value() 195 | .broadcast([1, range, 1]) 196 | .scatter_add(&a_param, -2, scope.parameter_value(&b_param).into_u32()), 197 | ); 198 | }); 199 | env.run(&g, TEST_RAND_SEED); 200 | 201 | assert_eq!(env.read_parameter_to_vec(&c_param), c_data); 202 | } 203 | 204 | #[test] 205 | fn concat() { 206 | let mut env = Environment::new(); 207 | 208 | let a_data: Vec = (0..200) 209 | .filter(|i| ((i / 10) & 1) == 0) 210 | .map(|i| i as f32) 211 | .collect(); 212 | let b_data: Vec = (0..200) 213 | .filter(|i| ((i / 10) & 1) == 1) 214 | .map(|i| i as f32) 215 | .collect(); 216 | let 
c_data: Vec = (0..200).map(|i| i as f32).collect(); 217 | 218 | let a_param = env.static_parameter_with_data([10, 10], "a", &a_data); 219 | let b_param = env.static_parameter_with_data([10, 10], "b", &b_data); 220 | let c_param = env.static_parameter([10, 20], "c"); 221 | 222 | let g = env.build_graph(|scope| { 223 | scope.write_parameter_value( 224 | &c_param, 225 | scope.parameter_value(&a_param).concat(&b_param, -1), 226 | ); 227 | }); 228 | env.run(&g, TEST_RAND_SEED); 229 | 230 | assert_eq!(env.read_parameter_to_vec(&c_param), c_data); 231 | } 232 | } 233 | -------------------------------------------------------------------------------- /src/device/timestamp.rs: -------------------------------------------------------------------------------- 1 | use super::common::*; 2 | use ordered_float::NotNan; 3 | use spark::vk; 4 | use std::{ 5 | collections::{BinaryHeap, HashMap, VecDeque}, 6 | mem, 7 | }; 8 | 9 | #[derive(Debug, Clone, Copy, PartialOrd, Ord, PartialEq, Eq)] 10 | struct NameId { 11 | index: u32, 12 | } 13 | 14 | struct TimestampSet { 15 | context: SharedContext, 16 | query_pool: vk::QueryPool, 17 | timestamp_ids: Vec, 18 | } 19 | 20 | impl TimestampSet { 21 | const MAX_QUERY_COUNT: usize = 128; 22 | 23 | fn new(context: &SharedContext) -> Self { 24 | let query_pool = { 25 | let create_info = vk::QueryPoolCreateInfo { 26 | query_type: vk::QueryType::TIMESTAMP, 27 | query_count: Self::MAX_QUERY_COUNT as u32, 28 | ..Default::default() 29 | }; 30 | 31 | unsafe { context.device.create_query_pool(&create_info, None) }.unwrap() 32 | }; 33 | Self { 34 | context: SharedContext::clone(context), 35 | query_pool, 36 | timestamp_ids: Vec::new(), 37 | } 38 | } 39 | 40 | fn write_timestamp(&mut self, cmd: vk::CommandBuffer, id: NameId) { 41 | if self.timestamp_ids.len() >= TimestampSet::MAX_QUERY_COUNT { 42 | return; 43 | } 44 | unsafe { 45 | self.context.device.cmd_write_timestamp( 46 | cmd, 47 | vk::PipelineStageFlags::BOTTOM_OF_PIPE, 48 | self.query_pool, 49 | self.timestamp_ids.len() as u32, 50 | ) 51 | }; 52 | self.timestamp_ids.push(id); 53 | } 54 | } 55 | 56 | #[derive(Debug, Clone, Copy, PartialOrd, Ord, PartialEq, Eq)] 57 | struct TimestampEntry { 58 | total: NotNan, 59 | id: NameId, 60 | } 61 | 62 | impl TimestampEntry { 63 | fn new(id: NameId, time: f32) -> Self { 64 | Self { 65 | total: NotNan::new(time).unwrap(), 66 | id, 67 | } 68 | } 69 | } 70 | 71 | struct TimestampAccumulator { 72 | context: SharedContext, 73 | time_total: f32, 74 | time_per_id: Vec, 75 | counter: u32, 76 | timestamp_valid_mask: u64, 77 | timestamp_period: f32, 78 | } 79 | 80 | impl TimestampAccumulator { 81 | fn new(context: &SharedContext) -> Self { 82 | Self { 83 | context: SharedContext::clone(context), 84 | time_total: 0.0, 85 | time_per_id: Vec::new(), 86 | counter: 0, 87 | timestamp_valid_mask: 1u64 88 | .checked_shl(context.queue_family_properties.timestamp_valid_bits) 89 | .unwrap_or(0) 90 | .wrapping_sub(1), 91 | timestamp_period: context.physical_device_properties.limits.timestamp_period 92 | / 1_000_000_000.0, 93 | } 94 | } 95 | 96 | fn accumulate_timings(&mut self, set: &mut TimestampSet) { 97 | if !set.timestamp_ids.is_empty() { 98 | let mut query_results = vec![0u64; set.timestamp_ids.len()]; 99 | unsafe { 100 | self.context.device.get_query_pool_results( 101 | set.query_pool, 102 | 0, 103 | query_results.len() as u32, 104 | &mut query_results, 105 | mem::size_of::() as vk::DeviceSize, 106 | vk::QueryResultFlags::N64 | vk::QueryResultFlags::WAIT, 107 | ) 108 | } 109 | .unwrap(); 110 | 111 | let 
query_deltas: Vec = (0..(set.timestamp_ids.len() - 1)) 112 | .map(|i| { 113 | let a = query_results[i]; 114 | let b = query_results[i + 1]; 115 | b.wrapping_sub(a) & self.timestamp_valid_mask 116 | }) 117 | .collect(); 118 | let query_times: Vec = query_deltas 119 | .iter() 120 | .copied() 121 | .map(|delta| (delta as f32) * self.timestamp_period) 122 | .collect(); 123 | let total_time = 124 | (query_deltas.iter().copied().sum::() as f32) * self.timestamp_period; 125 | 126 | if self.time_per_id.len() == query_times.len() 127 | && self 128 | .time_per_id 129 | .iter() 130 | .zip(set.timestamp_ids.iter().copied()) 131 | .all(|(entry, id)| entry.id == id) 132 | { 133 | self.time_total += total_time; 134 | for (entry, time) in self.time_per_id.iter_mut().zip(query_times.iter()) { 135 | entry.total += time; 136 | } 137 | self.counter += 1; 138 | } else { 139 | self.time_total = total_time; 140 | self.time_per_id.clear(); 141 | self.time_per_id.extend( 142 | set.timestamp_ids 143 | .iter() 144 | .copied() 145 | .zip(query_times.iter().copied()) 146 | .map(|(id, time)| TimestampEntry::new(id, time)), 147 | ); 148 | self.counter = 1; 149 | } 150 | 151 | set.timestamp_ids.clear(); 152 | } 153 | } 154 | 155 | fn print_timings(&self, label: &str, names: &[String]) { 156 | if self.counter != 0 { 157 | let norm = 1.0 / (self.counter as f32); 158 | println!( 159 | "{} total: {:.2} ms (average of {} runs)", 160 | label, 161 | norm * self.time_total * 1000.0, 162 | self.counter, 163 | ); 164 | let mut heap: BinaryHeap = self.time_per_id.iter().copied().collect(); 165 | for i in 0..5 { 166 | if let Some(entry) = heap.pop() { 167 | let name = &names[entry.id.index as usize]; 168 | let total = entry.total.into_inner(); 169 | println!( 170 | "({}) {:>6.2} ms ({:>4.1}%): {}", 171 | i + 1, 172 | norm * total * 1000.0, 173 | 100.0 * total / self.time_total, 174 | name 175 | ); 176 | } 177 | } 178 | } 179 | } 180 | 181 | fn reset_timings(&mut self) { 182 | self.time_total = 0.0; 183 | self.time_per_id.clear(); 184 | self.counter = 0; 185 | } 186 | } 187 | 188 | pub(crate) struct TimestampSets { 189 | context: SharedContext, 190 | sets: VecDeque>, 191 | name_ids: HashMap, 192 | names: Vec, 193 | accumulator: TimestampAccumulator, 194 | } 195 | 196 | impl TimestampSets { 197 | const COUNT: usize = 2; 198 | 199 | pub(crate) fn new(context: &SharedContext, fences: &FenceSet) -> Self { 200 | let mut sets = VecDeque::new(); 201 | for _ in 0..Self::COUNT { 202 | sets.push_back(Fenced::new(TimestampSet::new(context), fences.old_id())); 203 | } 204 | Self { 205 | context: SharedContext::clone(context), 206 | sets, 207 | name_ids: HashMap::new(), 208 | names: Vec::new(), 209 | accumulator: TimestampAccumulator::new(context), 210 | } 211 | } 212 | 213 | fn name_id(&mut self, name: &str) -> NameId { 214 | if let Some(id) = self.name_ids.get(name) { 215 | *id 216 | } else { 217 | let id = NameId { 218 | index: self.names.len() as u32, 219 | }; 220 | self.names.push(name.to_owned()); 221 | self.name_ids.insert(name.to_owned(), id); 222 | id 223 | } 224 | } 225 | 226 | pub(crate) fn print_timings(&mut self, label: &str, fences: &FenceSet) { 227 | // ensure all timings have been processed 228 | for set in self.sets.iter_mut() { 229 | self.accumulator 230 | .accumulate_timings(set.get_mut_when_signaled(fences)); 231 | } 232 | self.accumulator.print_timings(label, &self.names); 233 | self.accumulator.reset_timings(); 234 | } 235 | 236 | pub(crate) fn acquire( 237 | &mut self, 238 | cmd: vk::CommandBuffer, 239 | fences: &FenceSet, 
240 | ) -> ScopedTimestampSet { 241 | let mut set = self.sets.pop_front().unwrap().take_when_signaled(fences); 242 | self.accumulator.accumulate_timings(&mut set); 243 | 244 | unsafe { 245 | self.context.device.cmd_reset_query_pool( 246 | cmd, 247 | set.query_pool, 248 | 0, 249 | TimestampSet::MAX_QUERY_COUNT as u32, 250 | ) 251 | }; 252 | ScopedTimestampSet { set, owner: self } 253 | } 254 | } 255 | 256 | impl Drop for TimestampSets { 257 | fn drop(&mut self) { 258 | let device = &self.context.device; 259 | for set in self.sets.iter() { 260 | unsafe { 261 | let set = set.get_unchecked(); 262 | device.destroy_query_pool(Some(set.query_pool), None); 263 | } 264 | } 265 | } 266 | } 267 | 268 | pub(crate) struct ScopedTimestampSet<'a> { 269 | set: TimestampSet, 270 | owner: &'a mut TimestampSets, 271 | } 272 | 273 | impl<'a> ScopedTimestampSet<'a> { 274 | pub(crate) fn write_timestamp(&mut self, cmd: vk::CommandBuffer, name: &str) { 275 | self.set.write_timestamp(cmd, self.owner.name_id(name)) 276 | } 277 | 278 | pub(crate) fn end(&mut self, cmd: vk::CommandBuffer) { 279 | if let Some(id) = self.set.timestamp_ids.last().copied() { 280 | self.set.write_timestamp(cmd, id); 281 | } 282 | } 283 | 284 | pub(crate) fn recycle(self, fence: FenceId) { 285 | self.owner.sets.push_back(Fenced::new(self.set, fence)); 286 | } 287 | } 288 | -------------------------------------------------------------------------------- /src/module.rs: -------------------------------------------------------------------------------- 1 | use crate::common::*; 2 | use std::{io::Write, mem}; 3 | 4 | pub struct EvalContext { 5 | is_training: bool, 6 | } 7 | 8 | pub trait Module { 9 | fn eval<'s>(&self, input: DualArray<'s>, ctx: &EvalContext) -> DualArray<'s>; 10 | } 11 | 12 | pub trait ModuleExt: Module { 13 | fn train<'s>(&self, input: DualArray<'s>) -> DualArray<'s> { 14 | self.eval(input, &EvalContext { is_training: true }) 15 | } 16 | 17 | fn test<'s>(&self, input: DualArray<'s>) -> DualArray<'s> { 18 | self.eval(input, &EvalContext { is_training: false }) 19 | } 20 | } 21 | 22 | impl ModuleExt for T where T: Module + ?Sized {} 23 | 24 | pub trait ApplyModule { 25 | fn apply(self, module: &M, ctx: &EvalContext) -> Self; 26 | } 27 | 28 | impl<'s, M> ApplyModule for DualArray<'s> 29 | where 30 | M: Module + ?Sized, 31 | { 32 | fn apply(self, module: &M, ctx: &EvalContext) -> Self { 33 | self.map(|x| module.eval(x, ctx)) 34 | } 35 | } 36 | 37 | pub struct DenseBuilder { 38 | input: usize, 39 | output: usize, 40 | w_initializer: Initializer, 41 | b_initializer: Initializer, 42 | } 43 | 44 | impl DenseBuilder { 45 | pub fn with_w_initializer(mut self, w_initializer: Initializer) -> Self { 46 | self.w_initializer = w_initializer; 47 | self 48 | } 49 | 50 | pub fn with_b_initializer(mut self, b_initializer: Initializer) -> Self { 51 | self.b_initializer = b_initializer; 52 | self 53 | } 54 | 55 | pub fn build(self, env: &mut Environment) -> Dense { 56 | let DenseBuilder { 57 | input, 58 | output, 59 | w_initializer, 60 | b_initializer, 61 | } = self; 62 | 63 | let w = env.trainable_parameter([input, output], "w", w_initializer); 64 | let b = env.trainable_parameter([output], "b", b_initializer); 65 | 66 | Dense { w, b } 67 | } 68 | } 69 | 70 | pub struct Dense { 71 | w: Parameter, 72 | b: Parameter, // TODO: optional? 
73 | } 74 | 75 | impl Dense { 76 | pub fn builder(input: usize, output: usize) -> DenseBuilder { 77 | DenseBuilder { 78 | input, 79 | output, 80 | w_initializer: Initializer::for_relu(input), 81 | b_initializer: Initializer::Zero, 82 | } 83 | } 84 | } 85 | 86 | impl Module for Dense { 87 | fn eval<'s>(&self, input: DualArray<'s>, _ctx: &EvalContext) -> DualArray<'s> { 88 | input.next_colour().matmul(&self.w) + &self.b 89 | } 90 | } 91 | 92 | pub struct Conv2DBuilder { 93 | input_channels: usize, 94 | output_channels: usize, 95 | filter: (usize, usize), 96 | pad: usize, 97 | stride: (usize, usize), 98 | groups: usize, 99 | is_blur: bool, 100 | } 101 | 102 | impl Conv2DBuilder { 103 | pub fn with_pad(mut self, pad: usize) -> Self { 104 | self.pad = pad; 105 | self 106 | } 107 | 108 | pub fn with_stride(mut self, stride_w: usize, stride_h: usize) -> Self { 109 | self.stride = (stride_w, stride_h); 110 | self 111 | } 112 | 113 | pub fn with_groups(mut self, groups: usize) -> Self { 114 | self.groups = groups; 115 | self 116 | } 117 | 118 | pub fn with_blur(mut self) -> Self { 119 | self.is_blur = true; 120 | self 121 | } 122 | 123 | pub fn build(self, env: &mut Environment) -> Conv2D { 124 | let Self { 125 | input_channels, 126 | output_channels, 127 | filter, 128 | pad, 129 | stride, 130 | groups, 131 | is_blur, 132 | } = self; 133 | let filter_ic = input_channels / groups; 134 | let filter_oc = output_channels / groups; 135 | assert_eq!(filter_ic * groups, input_channels); 136 | assert_eq!(filter_oc * groups, output_channels); 137 | let (filter_w, filter_h) = filter; 138 | 139 | let (f, b) = if is_blur { 140 | let f = env.static_parameter([groups, filter_oc, filter_h, filter_w, filter_ic], "f"); 141 | let b = env.static_parameter([output_channels], "b"); 142 | 143 | assert_eq!([filter_oc, filter_h, filter_w, filter_ic], [1, 3, 3, 1]); 144 | let f_data: [f32; 9] = [ 145 | 1.0 / 16.0, 146 | 2.0 / 16.0, 147 | 1.0 / 16.0, 148 | 2.0 / 16.0, 149 | 4.0 / 16.0, 150 | 2.0 / 16.0, 151 | 1.0 / 16.0, 152 | 2.0 / 16.0, 153 | 1.0 / 16.0, 154 | ]; 155 | 156 | let mut w = env.writer(&f); 157 | for _ in 0..groups { 158 | w.write_all(bytemuck::bytes_of(&f_data)).unwrap(); 159 | } 160 | mem::drop(w); 161 | env.writer(&b).zero_fill(); 162 | 163 | (f, b) 164 | } else { 165 | let f = env.trainable_parameter( 166 | [groups, filter_oc, filter_h, filter_w, filter_ic], 167 | "f", 168 | Initializer::for_relu(filter_h * filter_w * filter_ic), 169 | ); 170 | let b = env.trainable_parameter([output_channels], "b", Initializer::Zero); 171 | (f, b) 172 | }; 173 | 174 | Conv2D { f, b, pad, stride } 175 | } 176 | } 177 | 178 | pub struct Conv2D { 179 | f: Parameter, 180 | b: Parameter, // TODO: optional? 
181 | pad: usize, 182 | stride: (usize, usize), 183 | } 184 | 185 | impl Conv2D { 186 | pub fn builder( 187 | input_channels: usize, 188 | output_channels: usize, 189 | filter_w: usize, 190 | filter_h: usize, 191 | ) -> Conv2DBuilder { 192 | Conv2DBuilder { 193 | input_channels, 194 | output_channels, 195 | filter: (filter_w, filter_h), 196 | pad: 0, 197 | stride: (1, 1), 198 | groups: 1, 199 | is_blur: false, 200 | } 201 | } 202 | } 203 | 204 | impl Module for Conv2D { 205 | fn eval<'s>(&self, input: DualArray<'s>, _ctx: &EvalContext) -> DualArray<'s> { 206 | let conv = input.next_colour().conv2d(&self.f, self.pad, self.stride); 207 | 208 | conv + &self.b 209 | } 210 | } 211 | 212 | #[derive(Default)] 213 | pub struct MaxPool2D {} 214 | 215 | impl Module for MaxPool2D { 216 | fn eval<'s>(&self, input: DualArray<'s>, _ctx: &EvalContext) -> DualArray<'s> { 217 | input.next_colour().max_pool2d((2, 2), (2, 2)) 218 | } 219 | } 220 | 221 | pub struct MaxBlurPool2D { 222 | blur: Conv2D, 223 | } 224 | 225 | impl MaxBlurPool2D { 226 | pub fn new(env: &mut Environment, channels: usize) -> Self { 227 | Self { 228 | blur: Conv2D::builder(channels, channels, 3, 3) 229 | .with_pad(1) 230 | .with_stride(2, 2) 231 | .with_groups(channels) 232 | .with_blur() 233 | .build(env), 234 | } 235 | } 236 | } 237 | 238 | impl Module for MaxBlurPool2D { 239 | fn eval<'s>(&self, input: DualArray<'s>, ctx: &EvalContext) -> DualArray<'s> { 240 | input 241 | .next_colour() 242 | .max_pool2d((2, 2), (1, 1)) 243 | .map(|x| self.blur.eval(x, ctx)) 244 | } 245 | } 246 | 247 | pub struct Dropout { 248 | amount: f32, 249 | } 250 | 251 | impl Dropout { 252 | pub fn new(amount: f32) -> Self { 253 | Self { amount } 254 | } 255 | } 256 | 257 | impl Module for Dropout { 258 | fn eval<'s>(&self, input: DualArray<'s>, ctx: &EvalContext) -> DualArray<'s> { 259 | if !ctx.is_training { 260 | return input; 261 | } 262 | 263 | let scope = input.scope(); 264 | let shape = input.shape(); 265 | 266 | scope.next_colour(); 267 | let rv = scope.rand(shape).value(); 268 | 269 | let (a, da) = input.into_inner(); 270 | 271 | let survivor_scale = 1.0 / (1.0 - self.amount); 272 | let (b, db) = rv 273 | .select_gt(self.amount, survivor_scale * a, 0.0) 274 | .with_empty_grad(); 275 | da.accumulate(rv.select_gt(self.amount, survivor_scale * db, 0.0)); 276 | 277 | (b, db).into() 278 | } 279 | } 280 | 281 | struct LSTMWeight { 282 | input: Parameter, 283 | hidden: Parameter, 284 | bias: Parameter, 285 | } 286 | 287 | impl LSTMWeight { 288 | fn new(env: &mut Environment, prefix: &str, input: usize, output: usize) -> Self { 289 | let input = env.trainable_parameter( 290 | [input, output], 291 | &format!("{}_wi", prefix), 292 | Initializer::RandNormal(0.01), 293 | ); 294 | let hidden = env.trainable_parameter( 295 | [output, output], 296 | &format!("{}_wh", prefix), 297 | Initializer::RandNormal(0.01), 298 | ); 299 | let bias = env.trainable_parameter([output], &format!("{}_b", prefix), Initializer::Zero); 300 | Self { 301 | input, 302 | hidden, 303 | bias, 304 | } 305 | } 306 | 307 | fn eval<'s>(&self, input: DualArray<'s>, hidden: Option>) -> DualArray<'s> { 308 | let mut x = input.matmul(&self.input); 309 | if let Some(hidden) = hidden { 310 | x += hidden.matmul(&self.hidden); 311 | } 312 | x + &self.bias 313 | } 314 | } 315 | 316 | pub struct LSTMCell { 317 | forget_gate: LSTMWeight, 318 | input_gate: LSTMWeight, 319 | output_gate: LSTMWeight, 320 | cell_input: LSTMWeight, 321 | } 322 | 323 | impl LSTMCell { 324 | pub fn new(env: &mut Environment, 
input: usize, output: usize) -> Self { 325 | Self { 326 | forget_gate: LSTMWeight::new(env, "forget", input, output), 327 | input_gate: LSTMWeight::new(env, "input", input, output), 328 | output_gate: LSTMWeight::new(env, "output", input, output), 329 | cell_input: LSTMWeight::new(env, "cell", input, output), 330 | } 331 | } 332 | } 333 | 334 | impl Module for LSTMCell { 335 | fn eval<'s>(&self, input: DualArray<'s>, _ctx: &EvalContext) -> DualArray<'s> { 336 | let time_axis = -2; 337 | let timestep_count = input.shape()[SignedIndex(time_axis)]; 338 | let mut prev_cell = None; 339 | let mut prev_hidden = None; 340 | for i in 0..timestep_count { 341 | let input = input.next_colour().lock_axis(time_axis, i, false); 342 | 343 | let input_gate = self.input_gate.eval(input, prev_hidden).sigmoid(); 344 | let output_gate = self.output_gate.eval(input, prev_hidden).sigmoid(); 345 | let cell_input = self.cell_input.eval(input, prev_hidden).tanh(); 346 | 347 | let mut cell = input_gate * cell_input; 348 | if let Some(prev_cell) = prev_cell { 349 | // TODO: fix dead code elimination for gradients 350 | // (disconnect accumulates that are not from the chosen loss) 351 | // then we can move the forget gate code out of this "if let" 352 | let forget_gate = self.forget_gate.eval(input, prev_hidden).sigmoid(); 353 | cell += forget_gate * prev_cell; 354 | } 355 | 356 | let hidden = output_gate * cell.tanh(); 357 | 358 | prev_cell = Some(cell); 359 | prev_hidden = Some(hidden); 360 | } 361 | prev_hidden.unwrap() 362 | } 363 | } 364 | -------------------------------------------------------------------------------- /src/device/staging.rs: -------------------------------------------------------------------------------- 1 | use super::common::*; 2 | use spark::vk; 3 | use std::{collections::VecDeque, slice}; 4 | 5 | #[derive(Debug, Clone, Copy)] 6 | struct StagingCursor { 7 | next: usize, 8 | end: usize, 9 | region: StagingBufferRegion, 10 | } 11 | 12 | impl StagingCursor { 13 | fn new(region: StagingBufferRegion, max_size: usize) -> Self { 14 | let begin = region.begin(); 15 | let size = StagingBuffer::REGION_SIZE.min(max_size); 16 | Self { 17 | next: begin, 18 | end: begin + size, 19 | region, 20 | } 21 | } 22 | 23 | fn is_empty(&self) -> bool { 24 | self.next == self.region.begin() 25 | } 26 | 27 | fn is_full(&self) -> bool { 28 | self.next == self.end 29 | } 30 | 31 | fn remaining(&self) -> usize { 32 | self.end - self.next 33 | } 34 | } 35 | 36 | #[derive(Debug, Clone, Copy)] 37 | struct BufferCursor { 38 | next: usize, 39 | info: BufferInfo, 40 | } 41 | 42 | impl BufferCursor { 43 | fn new(info: BufferInfo) -> Self { 44 | Self { 45 | next: info.range.begin, 46 | info, 47 | } 48 | } 49 | 50 | fn remaining(&self) -> usize { 51 | self.info.range.end - self.next 52 | } 53 | 54 | fn is_finished(&self) -> bool { 55 | self.next == self.info.range.end 56 | } 57 | } 58 | 59 | pub(crate) struct StagingWriter<'a> { 60 | owner: &'a mut StagingBuffer, 61 | command_buffers: &'a mut CommandBuffers, 62 | fences: &'a mut FenceSet, 63 | staging: Option, 64 | buffer: BufferCursor, 65 | } 66 | 67 | impl<'a> StagingWriter<'a> { 68 | pub(crate) fn new( 69 | owner: &'a mut StagingBuffer, 70 | command_buffers: &'a mut CommandBuffers, 71 | fences: &'a mut FenceSet, 72 | buffer_info: BufferInfo, 73 | ) -> Self { 74 | let mut writer = Self { 75 | owner, 76 | command_buffers, 77 | fences, 78 | staging: None, 79 | buffer: BufferCursor::new(buffer_info), 80 | }; 81 | writer.next_staging(); 82 | writer 83 | } 84 | 85 | pub(crate) fn 
write_slice(&mut self, mut buf: &[u8]) -> usize { 86 | let mut counter = 0; 87 | while let Some(staging) = self.staging.as_mut() { 88 | let copy_buf = self.owner.mapping(staging); 89 | let copy_size = copy_buf.len().min(buf.len()); 90 | copy_buf[..copy_size].copy_from_slice(&buf[..copy_size]); 91 | 92 | staging.next += copy_size; 93 | buf = &buf[copy_size..]; 94 | counter += copy_size; 95 | 96 | if staging.is_full() { 97 | self.flush_staging(); 98 | } 99 | if buf.is_empty() { 100 | break; 101 | } 102 | } 103 | counter 104 | } 105 | 106 | pub(crate) fn write_zeros(&mut self, mut count: usize) -> usize { 107 | let mut counter = 0; 108 | while let Some(staging) = self.staging.as_mut() { 109 | let copy_buf = self.owner.mapping(staging); 110 | let copy_size = copy_buf.len().min(count); 111 | for b in copy_buf.iter_mut().take(copy_size) { 112 | *b = 0; 113 | } 114 | 115 | staging.next += copy_size; 116 | count -= copy_size; 117 | counter += copy_size; 118 | 119 | if staging.is_full() { 120 | self.flush_staging(); 121 | } 122 | if count == 0 { 123 | break; 124 | } 125 | } 126 | counter 127 | } 128 | 129 | fn next_staging(&mut self) { 130 | assert!(self.staging.is_none()); 131 | let max_size = self.buffer.remaining(); 132 | if max_size > 0 { 133 | let range = self 134 | .owner 135 | .regions 136 | .pop_front() 137 | .unwrap() 138 | .take_when_signaled(self.fences); 139 | self.staging = Some(StagingCursor::new(range, max_size)); 140 | } 141 | } 142 | 143 | pub(crate) fn flush_staging(&mut self) { 144 | if let Some(staging) = self.staging.take() { 145 | if staging.is_empty() { 146 | self.staging = Some(staging); 147 | } else { 148 | let cmd = self.command_buffers.acquire(self.fences); 149 | 150 | let staging_begin = staging.region.begin(); 151 | let transfer_size = staging.next - staging_begin; 152 | { 153 | let region = vk::BufferCopy { 154 | src_offset: staging_begin as vk::DeviceSize, 155 | dst_offset: self.buffer.next as vk::DeviceSize, 156 | size: transfer_size as vk::DeviceSize, 157 | }; 158 | 159 | unsafe { 160 | self.owner.context.device.cmd_copy_buffer( 161 | cmd.get(), 162 | self.owner.buffer, 163 | self.buffer.info.buffer, 164 | slice::from_ref(®ion), 165 | ) 166 | }; 167 | } 168 | self.buffer.next += transfer_size; 169 | 170 | let fence_id = cmd.submit(self.fences); 171 | self.owner 172 | .regions 173 | .push_back(Fenced::new(staging.region, fence_id)); 174 | 175 | self.next_staging(); 176 | } 177 | } 178 | } 179 | } 180 | 181 | impl<'a> Drop for StagingWriter<'a> { 182 | fn drop(&mut self) { 183 | self.write_zeros(self.buffer.remaining()); 184 | self.flush_staging(); 185 | assert!(self.staging.is_none()); 186 | } 187 | } 188 | 189 | pub(crate) struct StagingReader<'a> { 190 | owner: &'a mut StagingBuffer, 191 | command_buffers: &'a mut CommandBuffers, 192 | fences: &'a mut FenceSet, 193 | buffer: BufferCursor, 194 | pending: VecDeque>, 195 | staging: Option, 196 | } 197 | 198 | impl<'a> StagingReader<'a> { 199 | pub(crate) fn new( 200 | owner: &'a mut StagingBuffer, 201 | command_buffers: &'a mut CommandBuffers, 202 | fences: &'a mut FenceSet, 203 | buffer_info: BufferInfo, 204 | ) -> Self { 205 | let mut reader = Self { 206 | owner, 207 | command_buffers, 208 | fences, 209 | buffer: BufferCursor::new(buffer_info), 210 | pending: VecDeque::new(), 211 | staging: None, 212 | }; 213 | while !reader.buffer.is_finished() { 214 | if let Some(region) = reader.owner.regions.pop_front() { 215 | let region = region.take_when_signaled(reader.fences); 216 | reader.add_pending(region); 217 | } else 
{ 218 | break; 219 | } 220 | } 221 | reader.next_staging(); 222 | reader 223 | } 224 | 225 | pub(crate) fn peek(&mut self) -> Option<&[u8]> { 226 | let owner = &mut self.owner; 227 | self.staging 228 | .as_ref() 229 | .map(move |staging| owner.mapping(staging) as &_) 230 | } 231 | 232 | pub(crate) fn advance(&mut self, amt: usize) { 233 | let staging = self.staging.as_mut().unwrap(); 234 | assert!(amt <= staging.remaining()); 235 | staging.next += amt; 236 | if staging.is_full() { 237 | let region = staging.region; 238 | self.staging = None; 239 | self.add_pending(region); 240 | self.next_staging(); 241 | } 242 | } 243 | 244 | pub(crate) fn read_slice(&mut self, buf: &mut [u8]) -> usize { 245 | let limit = buf.len(); 246 | let mut offset = 0; 247 | while let Some(copy_buf) = self.peek() { 248 | let copy_size = copy_buf.len().min(limit - offset); 249 | let next = offset + copy_size; 250 | buf[offset..next].copy_from_slice(©_buf[..copy_size]); 251 | 252 | offset = next; 253 | self.advance(copy_size); 254 | 255 | if offset == limit { 256 | break; 257 | } 258 | } 259 | offset 260 | } 261 | 262 | fn add_pending(&mut self, region: StagingBufferRegion) { 263 | if self.buffer.is_finished() { 264 | self.owner 265 | .regions 266 | .push_back(Fenced::new(region, self.fences.old_id())); 267 | } else { 268 | let cmd = self.command_buffers.acquire(self.fences); 269 | 270 | let staging = StagingCursor::new(region, self.buffer.remaining()); 271 | let staging_begin = staging.region.begin(); 272 | let transfer_size = staging.end - staging_begin; 273 | { 274 | let region = vk::BufferCopy { 275 | src_offset: self.buffer.next as vk::DeviceSize, 276 | dst_offset: staging_begin as vk::DeviceSize, 277 | size: transfer_size as vk::DeviceSize, 278 | }; 279 | 280 | unsafe { 281 | self.owner.context.device.cmd_copy_buffer( 282 | cmd.get(), 283 | self.buffer.info.buffer, 284 | self.owner.buffer, 285 | slice::from_ref(®ion), 286 | ) 287 | }; 288 | } 289 | self.buffer.next += transfer_size; 290 | 291 | let fence_id = cmd.submit(self.fences); 292 | self.pending.push_back(Fenced::new(staging, fence_id)); 293 | } 294 | } 295 | 296 | fn next_staging(&mut self) { 297 | assert!(self.staging.is_none()); 298 | if let Some(pending) = self.pending.pop_front() { 299 | self.staging = Some(pending.take_when_signaled(self.fences)); 300 | } 301 | } 302 | } 303 | 304 | impl<'a> Drop for StagingReader<'a> { 305 | fn drop(&mut self) { 306 | if let Some(staging) = self.staging.take() { 307 | self.owner 308 | .regions 309 | .push_back(Fenced::new(staging.region, self.fences.old_id())); 310 | } 311 | while let Some(pending) = self.pending.pop_front() { 312 | self.owner 313 | .regions 314 | .push_back(pending.map(|pending| pending.region)); 315 | } 316 | } 317 | } 318 | 319 | #[derive(Debug, Clone, Copy)] 320 | struct StagingBufferRegion(u8); 321 | 322 | impl StagingBufferRegion { 323 | fn begin(&self) -> usize { 324 | (self.0 as usize) * StagingBuffer::REGION_SIZE 325 | } 326 | } 327 | 328 | pub(crate) struct StagingBuffer { 329 | context: SharedContext, 330 | device_memory: vk::DeviceMemory, 331 | buffer: vk::Buffer, 332 | mapping: *mut u8, 333 | regions: VecDeque>, 334 | } 335 | 336 | impl StagingBuffer { 337 | const REGION_SIZE: usize = 4 * 1024 * 1024; 338 | const COUNT: usize = 2; 339 | 340 | pub(crate) fn new(context: &SharedContext, fences: &FenceSet) -> Self { 341 | let device = &context.device; 342 | let buffer = { 343 | let buffer_create_info = vk::BufferCreateInfo { 344 | size: (Self::REGION_SIZE * Self::COUNT) as vk::DeviceSize, 
345 | usage: vk::BufferUsageFlags::TRANSFER_DST | vk::BufferUsageFlags::TRANSFER_SRC, 346 | ..Default::default() 347 | }; 348 | unsafe { device.create_buffer(&buffer_create_info, None) }.unwrap() 349 | }; 350 | let mem_req = unsafe { device.get_buffer_memory_requirements(buffer) }; 351 | let device_memory = { 352 | let memory_type_index = context 353 | .get_memory_type_index( 354 | mem_req.memory_type_bits, 355 | vk::MemoryPropertyFlags::HOST_VISIBLE 356 | | vk::MemoryPropertyFlags::HOST_COHERENT 357 | | vk::MemoryPropertyFlags::HOST_CACHED, 358 | ) 359 | .unwrap(); 360 | let memory_allocate_info = vk::MemoryAllocateInfo { 361 | allocation_size: mem_req.size, 362 | memory_type_index, 363 | ..Default::default() 364 | }; 365 | unsafe { device.allocate_memory(&memory_allocate_info, None) }.unwrap() 366 | }; 367 | unsafe { device.bind_buffer_memory(buffer, device_memory, 0) }.unwrap(); 368 | let mapping = unsafe { 369 | context.device.map_memory( 370 | device_memory, 371 | 0, 372 | vk::WHOLE_SIZE, 373 | vk::MemoryMapFlags::empty(), 374 | ) 375 | } 376 | .unwrap(); 377 | 378 | let mut ranges = VecDeque::new(); 379 | for i in 0..Self::COUNT { 380 | ranges.push_back(Fenced::new(StagingBufferRegion(i as u8), fences.old_id())); 381 | } 382 | 383 | Self { 384 | context: SharedContext::clone(context), 385 | device_memory, 386 | buffer, 387 | mapping: mapping as *mut _, 388 | regions: ranges, 389 | } 390 | } 391 | 392 | fn mapping(&mut self, cursor: &StagingCursor) -> &mut [u8] { 393 | let full = 394 | unsafe { slice::from_raw_parts_mut(self.mapping, Self::REGION_SIZE * Self::COUNT) }; 395 | &mut full[cursor.next..cursor.end] 396 | } 397 | } 398 | 399 | impl Drop for StagingBuffer { 400 | fn drop(&mut self) { 401 | let device = &self.context.device; 402 | unsafe { 403 | device.destroy_buffer(Some(self.buffer), None); 404 | device.free_memory(Some(self.device_memory), None); 405 | } 406 | } 407 | } 408 | -------------------------------------------------------------------------------- /src/device/heap.rs: -------------------------------------------------------------------------------- 1 | use slotmap::{Key, SlotMap}; 2 | use std::fmt::Debug; 3 | use trait_set::trait_set; 4 | 5 | #[derive(Debug, Clone, Copy)] 6 | struct BlockListNode { 7 | prev_id: K, 8 | next_id: K, 9 | } 10 | 11 | impl BlockListNode { 12 | fn new(id: K) -> Self { 13 | Self { 14 | prev_id: id, 15 | next_id: id, 16 | } 17 | } 18 | } 19 | 20 | #[derive(Debug, Clone, Copy)] 21 | pub(crate) struct HeapRange { 22 | pub(crate) begin: usize, 23 | pub(crate) end: usize, 24 | } 25 | 26 | impl HeapRange { 27 | fn from_size(size: usize) -> Self { 28 | Self { 29 | begin: 0, 30 | end: size, 31 | } 32 | } 33 | 34 | pub(crate) fn size(&self) -> usize { 35 | self.end - self.begin 36 | } 37 | 38 | fn truncate(&mut self, new_size: usize) -> HeapRange { 39 | assert!(new_size > 0); 40 | let begin = self.begin + new_size; 41 | let end = self.end; 42 | assert!(begin < end); 43 | self.end = begin; 44 | HeapRange { begin, end } 45 | } 46 | 47 | fn append(&mut self, other: HeapRange) { 48 | assert_eq!(self.end, other.begin); 49 | self.end = other.end; 50 | } 51 | } 52 | 53 | trait_set! 
{ 54 | pub(crate) trait Tag = Debug + Clone; 55 | } 56 | 57 | #[derive(Debug, Clone, Copy)] 58 | struct Block { 59 | tag: T, 60 | range: HeapRange, 61 | tag_node: BlockListNode, // linked list of blocks with this tag 62 | free_node: Option>, // linked list of similarly sized free blocks 63 | } 64 | 65 | impl Block { 66 | fn new(id: K, tag: T, range: HeapRange) -> Self { 67 | Self { 68 | tag, 69 | range, 70 | tag_node: BlockListNode::new(id), 71 | free_node: None, 72 | } 73 | } 74 | 75 | fn can_append(&self, other: &Block) -> bool { 76 | self.range.end == other.range.begin 77 | } 78 | } 79 | 80 | #[derive(Debug, Clone)] 81 | pub(crate) struct HeapAllocInfo { 82 | pub(crate) tag: T, 83 | pub(crate) range: HeapRange, 84 | } 85 | 86 | #[derive(Debug, Clone, Copy)] 87 | pub(crate) struct HeapStats { 88 | pub(crate) alloc_count: usize, 89 | pub(crate) total_alloc_size: usize, 90 | pub(crate) total_free_size: usize, 91 | pub(crate) largest_free_size: usize, 92 | } 93 | 94 | type BlockSlotMap = SlotMap>; 95 | 96 | #[derive(Debug)] 97 | pub(crate) struct Heap { 98 | blocks: BlockSlotMap, 99 | free_lists: Vec>, 100 | } 101 | 102 | impl Default for Heap { 103 | fn default() -> Self { 104 | Self { 105 | blocks: BlockSlotMap::with_key(), 106 | free_lists: Vec::new(), 107 | } 108 | } 109 | } 110 | 111 | impl Heap { 112 | fn free_list_index(size: usize) -> usize { 113 | (0usize.leading_zeros() - size.leading_zeros()) as usize 114 | } 115 | 116 | pub(crate) fn extend_with(&mut self, tag: T, size: usize) { 117 | let free_list_index = Self::free_list_index(size); 118 | 119 | while free_list_index >= self.free_lists.len() { 120 | self.free_lists.push(None); 121 | } 122 | 123 | let id = self 124 | .blocks 125 | .insert_with_key(|key| Block::new(key, tag, HeapRange::from_size(size))); 126 | Self::register_free_block(&mut self.blocks, self.free_lists.as_mut_slice(), id); 127 | } 128 | 129 | fn register_free_block( 130 | blocks: &mut BlockSlotMap, 131 | free_lists: &mut [Option], 132 | alloc_id: K, 133 | ) { 134 | let size = { 135 | let block = &blocks[alloc_id]; 136 | assert!(block.free_node.is_none()); 137 | block.range.size() 138 | }; 139 | let free_list_index = Self::free_list_index(size); 140 | if let Some(next_id) = free_lists[free_list_index] { 141 | let prev_id = blocks[next_id].free_node.unwrap().prev_id; 142 | if prev_id == next_id { 143 | let [other, alloc] = blocks.get_disjoint_mut([prev_id, alloc_id]).unwrap(); 144 | other.free_node = Some(BlockListNode::new(alloc_id)); 145 | alloc.free_node = Some(BlockListNode::new(prev_id)); 146 | } else { 147 | let [prev, alloc, next] = blocks 148 | .get_disjoint_mut([prev_id, alloc_id, next_id]) 149 | .unwrap(); 150 | prev.free_node.as_mut().unwrap().next_id = alloc_id; 151 | alloc.free_node = Some(BlockListNode { prev_id, next_id }); 152 | next.free_node.as_mut().unwrap().prev_id = alloc_id; 153 | } 154 | } else { 155 | blocks[alloc_id].free_node = Some(BlockListNode::new(alloc_id)); 156 | } 157 | free_lists[free_list_index] = Some(alloc_id); 158 | } 159 | 160 | fn unregister_free_block( 161 | blocks: &mut BlockSlotMap, 162 | free_lists: &mut [Option], 163 | free_id: K, 164 | ) { 165 | let (size, BlockListNode { prev_id, next_id }) = { 166 | let block = &blocks[free_id]; 167 | (block.range.size(), block.free_node.unwrap()) 168 | }; 169 | let free_list_index = Self::free_list_index(size); 170 | let head_id = if prev_id == free_id { 171 | assert_eq!(next_id, free_id); 172 | None 173 | } else if prev_id == next_id { 174 | blocks[prev_id].free_node = 
Some(BlockListNode::new(prev_id)); 175 | Some(prev_id) 176 | } else { 177 | let [prev, next] = blocks.get_disjoint_mut([prev_id, next_id]).unwrap(); 178 | prev.free_node.as_mut().unwrap().next_id = next_id; 179 | next.free_node.as_mut().unwrap().prev_id = prev_id; 180 | Some(next_id) 181 | }; 182 | free_lists[free_list_index] = head_id; 183 | blocks[free_id].free_node = None; 184 | } 185 | 186 | fn truncate_block(blocks: &mut BlockSlotMap, orig_id: K, new_size: usize) -> K { 187 | let (next_id, new_id) = { 188 | let orig_block = &mut blocks[orig_id]; 189 | let next_id = orig_block.tag_node.next_id; 190 | let tag = orig_block.tag.clone(); 191 | let range = orig_block.range.truncate(new_size); 192 | let new_id = blocks.insert_with_key(|key| Block::new(key, tag, range)); 193 | (next_id, new_id) 194 | }; 195 | 196 | if orig_id == next_id { 197 | let [orig, new] = blocks.get_disjoint_mut([orig_id, new_id]).unwrap(); 198 | orig.tag_node = BlockListNode::new(new_id); 199 | new.tag_node = BlockListNode::new(orig_id); 200 | } else { 201 | let prev_id = orig_id; 202 | let [prev, new, next] = blocks.get_disjoint_mut([prev_id, new_id, next_id]).unwrap(); 203 | prev.tag_node.next_id = new_id; 204 | new.tag_node = BlockListNode { prev_id, next_id }; 205 | next.tag_node.prev_id = new_id; 206 | } 207 | 208 | new_id 209 | } 210 | 211 | fn append_block(blocks: &mut BlockSlotMap, orig_id: K, append_id: K) { 212 | let [orig_block, append_block] = blocks.get_disjoint_mut([orig_id, append_id]).unwrap(); 213 | orig_block.range.append(append_block.range); 214 | 215 | let next_id = append_block.tag_node.next_id; 216 | if orig_id == next_id { 217 | orig_block.tag_node = BlockListNode::new(orig_id); 218 | } else { 219 | let [orig_block, next_block] = blocks.get_disjoint_mut([orig_id, next_id]).unwrap(); 220 | orig_block.tag_node.next_id = next_id; 221 | next_block.tag_node.prev_id = orig_id; 222 | } 223 | 224 | blocks.remove(append_id).unwrap(); 225 | } 226 | 227 | #[allow(dead_code)] 228 | fn print_state(&self) { 229 | println!("stats: {:?}", self.stats()); 230 | for (index, first_block_id) in self.free_lists.iter().copied().enumerate() { 231 | println!("free list {}:", index); 232 | if let Some(first_block_id) = first_block_id { 233 | let mut block_id = first_block_id; 234 | loop { 235 | let block = &self.blocks[block_id]; 236 | println!("{:?} = {:?}", block_id, block); 237 | block_id = block.free_node.unwrap().next_id; 238 | if block_id == first_block_id { 239 | break; 240 | } 241 | } 242 | } 243 | } 244 | println!("allocated list:"); 245 | for (block_id, block) in self.blocks.iter() { 246 | if block.free_node.is_none() { 247 | println!("{:?} = {:?}", block_id, block); 248 | } 249 | } 250 | } 251 | 252 | pub(crate) fn stats(&self) -> HeapStats { 253 | let mut stats = HeapStats { 254 | alloc_count: 0, 255 | total_alloc_size: 0, 256 | total_free_size: 0, 257 | largest_free_size: 0, 258 | }; 259 | for block in self.blocks.values() { 260 | let size = block.range.size(); 261 | if block.free_node.is_none() { 262 | stats.alloc_count += 1; 263 | stats.total_alloc_size += size; 264 | } else { 265 | stats.total_free_size += size; 266 | stats.largest_free_size = stats.largest_free_size.max(size); 267 | } 268 | } 269 | stats 270 | } 271 | 272 | pub(crate) fn alloc(&mut self, size: usize, align: usize) -> Option { 273 | let blocks = &mut self.blocks; 274 | let free_lists = self.free_lists.as_mut_slice(); 275 | 276 | let align_mask = align - 1; 277 | let start_free_list_index = Self::free_list_index(size); 278 | for 
first_block_id in free_lists 279 | .get(start_free_list_index..)? 280 | .iter() 281 | .copied() 282 | .flatten() 283 | { 284 | let mut block_id = first_block_id; 285 | loop { 286 | let block_range = blocks[block_id].range; 287 | let aligned_begin = (block_range.begin + align_mask) & !align_mask; 288 | let aligned_end = aligned_begin + size; 289 | if aligned_end <= block_range.end { 290 | Self::unregister_free_block(blocks, free_lists, block_id); 291 | if aligned_begin != block_range.begin { 292 | let aligned_id = Self::truncate_block( 293 | blocks, 294 | block_id, 295 | aligned_begin - block_range.begin, 296 | ); 297 | Self::register_free_block(blocks, free_lists, block_id); 298 | block_id = aligned_id; 299 | } 300 | if aligned_end != block_range.end { 301 | let unused_id = Self::truncate_block(blocks, block_id, size); 302 | Self::register_free_block(blocks, free_lists, unused_id); 303 | } 304 | return Some(block_id); 305 | } 306 | block_id = blocks[block_id].free_node.unwrap().next_id; 307 | if block_id == first_block_id { 308 | break; 309 | } 310 | } 311 | } 312 | None 313 | } 314 | 315 | pub(crate) fn info(&self, id: K) -> HeapAllocInfo { 316 | let block = &self.blocks[id]; 317 | HeapAllocInfo { 318 | tag: block.tag.clone(), 319 | range: block.range, 320 | } 321 | } 322 | 323 | pub(crate) fn free(&mut self, id: K) { 324 | let blocks = &mut self.blocks; 325 | let free_lists = self.free_lists.as_mut_slice(); 326 | 327 | let block = &blocks[id]; 328 | assert!(block.free_node.is_none()); 329 | let next_id = block.tag_node.next_id; 330 | let next = &blocks[next_id]; 331 | if next.free_node.is_some() && block.can_append(next) { 332 | Self::unregister_free_block(blocks, free_lists, next_id); 333 | Self::append_block(blocks, id, next_id); 334 | } 335 | 336 | let block = &blocks[id]; 337 | let prev_id = block.tag_node.prev_id; 338 | let prev = &blocks[prev_id]; 339 | if prev.free_node.is_some() && prev.can_append(block) { 340 | Self::unregister_free_block(blocks, free_lists, prev_id); 341 | Self::append_block(blocks, prev_id, id); 342 | Self::register_free_block(blocks, free_lists, prev_id); 343 | } else { 344 | Self::register_free_block(blocks, free_lists, id); 345 | } 346 | } 347 | } 348 | 349 | #[cfg(test)] 350 | mod tests { 351 | use super::*; 352 | 353 | slotmap::new_key_type! 
{ 354 | struct Id; 355 | } 356 | 357 | #[test] 358 | fn heap() { 359 | let mut heap = Heap::default(); 360 | heap.extend_with(0usize, 1000); 361 | 362 | let ai: Id = heap.alloc(1000, 4).unwrap(); 363 | heap.free(ai); 364 | 365 | let ai = heap.alloc(500, 4).unwrap(); 366 | heap.print_state(); 367 | let bi = heap.alloc(500, 4).unwrap(); 368 | heap.print_state(); 369 | heap.free(ai); 370 | heap.print_state(); 371 | let ci = heap.alloc(250, 2).unwrap(); 372 | let di = heap.alloc(250, 2).unwrap(); 373 | heap.print_state(); 374 | heap.free(bi); 375 | heap.print_state(); 376 | heap.free(ci); 377 | heap.print_state(); 378 | heap.free(di); 379 | heap.print_state(); 380 | 381 | let ei = heap.alloc(1000, 4).unwrap(); 382 | heap.free(ei); 383 | } 384 | } 385 | -------------------------------------------------------------------------------- /examples/image_fit/main.rs: -------------------------------------------------------------------------------- 1 | use descent::{module::*, optimizer::*, prelude::*}; 2 | use rand::{Rng, RngCore, SeedableRng}; 3 | use stb::image; 4 | use std::{ 5 | f32::consts::PI, 6 | ffi::CString, 7 | fs::File, 8 | io::{BufWriter, Write}, 9 | mem, 10 | path::PathBuf, 11 | }; 12 | use structopt::StructOpt; 13 | use strum::{EnumString, EnumVariantNames, VariantNames}; 14 | 15 | #[derive(Debug, EnumString, EnumVariantNames)] 16 | #[strum(serialize_all = "kebab_case")] 17 | enum NetworkType { 18 | Relu, 19 | ReluPE, 20 | Siren, 21 | MultiHash, 22 | } 23 | 24 | #[derive(Debug, StructOpt)] 25 | #[structopt( 26 | no_version, 27 | name = "image_fit", 28 | about = "Example networks to fit a single image." 29 | )] 30 | struct AppParams { 31 | #[structopt(possible_values=&NetworkType::VARIANTS, default_value="siren")] 32 | network: NetworkType, 33 | 34 | #[structopt(long)] 35 | show_timings: bool, 36 | 37 | #[structopt(long)] 38 | quiet: bool, 39 | 40 | #[structopt(long)] 41 | csv_file_name: Option, 42 | 43 | #[structopt(long)] 44 | image_prefix: Option, 45 | 46 | #[structopt(long)] 47 | output_all_images: bool, 48 | } 49 | 50 | struct Relu { 51 | freq_count: usize, 52 | hidden_layers: Vec, 53 | final_layer: Dense, 54 | } 55 | 56 | impl Relu { 57 | fn new(env: &mut Environment, freq_count: usize, hidden_units: &[usize]) -> Self { 58 | let mut hidden_layers = Vec::new(); 59 | let mut prev_units = if freq_count == 0 { 2 } else { 4 * freq_count }; 60 | for hidden_units in hidden_units.iter().copied() { 61 | hidden_layers.push(Dense::builder(prev_units, hidden_units).build(env)); 62 | prev_units = hidden_units; 63 | } 64 | Self { 65 | freq_count, 66 | hidden_layers, 67 | final_layer: Dense::builder(prev_units, 3).build(env), 68 | } 69 | } 70 | } 71 | 72 | impl Module for Relu { 73 | fn eval<'s>(&self, input: DualArray<'s>, ctx: &EvalContext) -> DualArray<'s> { 74 | let mut x = input; 75 | if self.freq_count != 0 { 76 | x = positional_encoding(input, self.freq_count); 77 | } 78 | for layer in self.hidden_layers.iter() { 79 | x = x.apply(layer, ctx).leaky_relu(0.01); 80 | } 81 | x.apply(&self.final_layer, ctx) 82 | } 83 | } 84 | 85 | struct Siren { 86 | hidden_layers: Vec, 87 | final_layer: Dense, 88 | } 89 | 90 | impl Siren { 91 | fn new(env: &mut Environment, hidden_units: &[usize]) -> Self { 92 | let mut hidden_layers = Vec::new(); 93 | let mut prev_units = 2; 94 | for (index, hidden_units) in hidden_units.iter().copied().enumerate() { 95 | hidden_layers.push( 96 | Dense::builder(prev_units, hidden_units) 97 | .with_w_initializer(Initializer::for_siren(prev_units, index == 0)) 98 | 
.with_b_initializer(Initializer::RandUniform(1.0)) 99 | .build(env), 100 | ); 101 | prev_units = hidden_units; 102 | } 103 | Self { 104 | hidden_layers, 105 | final_layer: Dense::builder(prev_units, 3).build(env), 106 | } 107 | } 108 | } 109 | 110 | impl Module for Siren { 111 | fn eval<'s>(&self, input: DualArray<'s>, ctx: &EvalContext) -> DualArray<'s> { 112 | let mut x = input; 113 | for layer in self.hidden_layers.iter() { 114 | x = x.apply(layer, ctx).sin(); 115 | } 116 | x.apply(&self.final_layer, ctx) 117 | } 118 | } 119 | 120 | struct HashGrid { 121 | grid_size: usize, 122 | stride: usize, 123 | t: Parameter, 124 | } 125 | 126 | impl HashGrid { 127 | fn new( 128 | env: &mut Environment, 129 | grid_size: usize, 130 | entry_count: usize, 131 | values_per_entry: usize, 132 | ) -> Self { 133 | let grid_point_count = grid_size + 1; 134 | let max_entry_count = grid_point_count * grid_point_count; 135 | let entry_count = entry_count.min(max_entry_count); 136 | let stride = if entry_count == max_entry_count { 137 | grid_point_count 138 | } else { 139 | 1526263 // large prime 140 | }; 141 | let t = env.trainable_parameter( 142 | [entry_count, values_per_entry], 143 | "t", 144 | Initializer::RandUniform(1.0E-4), 145 | ); 146 | Self { 147 | grid_size, 148 | stride, 149 | t, 150 | } 151 | } 152 | } 153 | 154 | impl Module for HashGrid { 155 | fn eval<'s>(&self, input: DualArray<'s>, _ctx: &EvalContext) -> DualArray<'s> { 156 | let scope = input.scope(); 157 | let (x, _dx) = input.next_colour().into_inner(); 158 | 159 | let (t, dt) = scope.parameter(&self.t).into_inner(); 160 | let entry_count = t.shape()[0]; 161 | let stride = self.stride as u32; 162 | 163 | let cf = (x * 0.5 + 0.5) * (self.grid_size as f32); 164 | let c = cf.into_u32(); 165 | let f = cf - c.into_f32(); 166 | 167 | let c0 = c.lock_axis(-1, 0, false); 168 | let c1 = c.lock_axis(-1, 1, false); 169 | let f0 = f.lock_axis(-1, 0, true); 170 | let f1 = f.lock_axis(-1, 1, true); 171 | 172 | let ia = ((c0 + 0) ^ (c1 * stride + 0)) % (entry_count as u32); 173 | let ib = ((c0 + 1) ^ (c1 * stride + 0)) % (entry_count as u32); 174 | let ic = ((c0 + 0) ^ (c1 * stride + stride)) % (entry_count as u32); 175 | let id = ((c0 + 1) ^ (c1 * stride + stride)) % (entry_count as u32); 176 | 177 | let ta = t.gather(-2, ia); 178 | let tb = t.gather(-2, ib); 179 | let tc = t.gather(-2, ic); 180 | let td = t.gather(-2, id); 181 | let g0 = 1.0 - f0; 182 | let g1 = 1.0 - f1; 183 | let wa = g0 * g1; 184 | let wb = f0 * g1; 185 | let wc = g0 * f1; 186 | let wd = f0 * f1; 187 | 188 | let (y, dy) = (ta * wa + tb * wb + tc * wc + td * wd).with_empty_grad(); 189 | 190 | dt.accumulate( 191 | scope 192 | .literal(0.0) 193 | .value() 194 | .broadcast(dt.shape()) 195 | .scatter_add(dy * wa, -2, ia) 196 | .scatter_add(dy * wb, -2, ib) 197 | .scatter_add(dy * wc, -2, ic) 198 | .scatter_add(dy * wd, -2, id), 199 | ); 200 | 201 | (y, dy).into() 202 | } 203 | } 204 | 205 | struct MultiHashGrid { 206 | grids: Vec, 207 | hidden_layers: Vec, 208 | final_layer: Dense, 209 | } 210 | 211 | impl MultiHashGrid { 212 | fn new( 213 | env: &mut Environment, 214 | min_grid_size: usize, 215 | max_grid_size: usize, 216 | level_count: usize, 217 | entry_count: usize, 218 | hidden_units: &[usize], 219 | ) -> Self { 220 | let values_per_entry = 2; 221 | let mut grids = Vec::new(); 222 | let b = (((max_grid_size as f32).ln() - (min_grid_size as f32).ln()) 223 | / ((level_count - 1) as f32)) 224 | .exp(); 225 | println!("b = {}", b); 226 | for level_index in 0..level_count { 227 | let 
grid_size = ((min_grid_size as f32) * b.powi(level_index as i32)) as usize; 228 | grids.push(HashGrid::new(env, grid_size, entry_count, values_per_entry)); 229 | } 230 | let mut hidden_layers = Vec::new(); 231 | let mut prev_units = grids.len() * values_per_entry; 232 | for hidden_units in hidden_units.iter().copied() { 233 | hidden_layers.push(Dense::builder(prev_units, hidden_units).build(env)); 234 | prev_units = hidden_units; 235 | } 236 | Self { 237 | grids, 238 | hidden_layers, 239 | final_layer: Dense::builder(prev_units, 3).build(env), 240 | } 241 | } 242 | } 243 | 244 | impl Module for MultiHashGrid { 245 | fn eval<'s>(&self, input: DualArray<'s>, ctx: &EvalContext) -> DualArray<'s> { 246 | let mut x = self 247 | .grids 248 | .iter() 249 | .map(|grid| grid.eval(input, ctx)) 250 | .reduce(|a, b| a.concat(b, -1)) 251 | .unwrap(); 252 | for layer in self.hidden_layers.iter() { 253 | x = layer.eval(x, ctx).leaky_relu(0.01); 254 | } 255 | self.final_layer.eval(x, ctx) 256 | } 257 | } 258 | 259 | fn positional_encoding<'s>(x: DualArray<'s>, freq_count: usize) -> DualArray<'s> { 260 | let scope = x.scope(); 261 | 262 | let freq = scope.literal(2.0).pow(scope.coord(freq_count)) * PI; 263 | let phase = scope.coord(2).reshape([2, 1]) * 0.5 * PI; 264 | 265 | let shape = x.shape(); 266 | let calc_shape = shape + Shape::from([1, 1]); 267 | let output_shape = { 268 | let mut tmp = shape; 269 | tmp[SignedIndex(-1)] *= 2 * freq_count; 270 | tmp 271 | }; 272 | (x.reshape(calc_shape) * freq + phase) 273 | .sin() 274 | .reshape(output_shape) 275 | } 276 | 277 | fn main() { 278 | let (info, data) = image::stbi_load_from_reader( 279 | &mut File::open("data/images/cat.jpg").unwrap(), 280 | stb::image::Channels::Rgb, 281 | ) 282 | .unwrap(); 283 | let width = info.width as usize; 284 | let height = info.height as usize; 285 | 286 | let mut env = Environment::new(); 287 | 288 | let app_params = AppParams::from_args(); 289 | let pe_freq_count = 8; 290 | let module: Box = { 291 | let env = &mut env; 292 | let hidden_units = &[256, 128, 64, 32]; 293 | match app_params.network { 294 | NetworkType::Relu => Box::new(Relu::new(env, 0, hidden_units)), 295 | NetworkType::ReluPE => Box::new(Relu::new(env, pe_freq_count, hidden_units)), 296 | NetworkType::Siren => Box::new(Siren::new(env, hidden_units)), 297 | NetworkType::MultiHash => { 298 | Box::new(MultiHashGrid::new(env, 2, 512, 10, 4096, &[64, 64])) 299 | } 300 | } 301 | }; 302 | 303 | let m = 1 << 14; 304 | let x_param = env.static_parameter([m, 2], "x"); 305 | let y_param = env.static_parameter([m, 3], "y"); 306 | let learning_rate_scale_param = env.static_parameter([1], "lr_scale"); 307 | let loss_sum_param = env.static_parameter([1], "loss"); 308 | let (train_graph, parameters, _optimizer) = { 309 | let scope = env.scope(); 310 | 311 | let x = module.train(scope.parameter(&x_param)); 312 | let loss = (x - &y_param).square().reduce_sum(-1, true).set_loss(); 313 | scope.update_parameter_value(&loss_sum_param, |loss_sum| { 314 | loss_sum + loss.reduce_sum(0, false) 315 | }); 316 | 317 | let learning_rate_scale = scope.parameter_value(&learning_rate_scale_param); 318 | let parameters = scope.trainable_parameters(); 319 | let optimizer = Adam::new( 320 | &mut env, 321 | &scope, 322 | ¶meters, 323 | 0.02 * learning_rate_scale, 324 | 0.9, 325 | 0.99, 326 | 1.0E-8, 327 | ); 328 | 329 | (scope.build_graph(), parameters, optimizer) 330 | }; 331 | println!( 332 | "trainable parameters: {}", 333 | parameters 334 | .iter() 335 | .map(|param| 
param.shape().element_count()) 336 | .sum::() 337 | ); 338 | 339 | let pixel_count = height * width; 340 | let image_param = env.static_parameter([pixel_count, 3], "image"); 341 | let test_graph = env.build_graph(|scope| { 342 | let u = (scope.coord(width) + 0.5) * (2.0 / (width as f32)) - 1.0; 343 | let v = (scope.coord(height) + 0.5) * (2.0 / (height as f32)) - 1.0; 344 | let x = scope 345 | .coord(2) 346 | .select_eq(0.0, u.reshape([1, width, 1]), v.reshape([height, 1, 1])) 347 | .reshape([pixel_count, 2]); 348 | let x = module.test(x); 349 | scope.write_parameter_value(&image_param, x.value()); 350 | }); 351 | 352 | let mut rng = rand_chacha::ChaCha20Rng::seed_from_u64(0); 353 | for param in parameters.iter() { 354 | env.reset_parameter(param, &mut rng); 355 | } 356 | 357 | let mut stats_w = app_params 358 | .csv_file_name 359 | .map(|path| BufWriter::new(File::create(path).unwrap())); 360 | 361 | let epoch_count = 200; 362 | for epoch_index in 0..epoch_count { 363 | let epoch_t = (epoch_index as f32) + 0.5; 364 | let learning_rate_scale = (epoch_t / 10.0).min(1.0) * 0.5f32.powf(epoch_t / 40.0); 365 | env.writer(&learning_rate_scale_param) 366 | .write_all(bytemuck::bytes_of(&learning_rate_scale)) 367 | .unwrap(); 368 | 369 | // loop over batches to roughly cover the whole image 370 | env.writer(&loss_sum_param).zero_fill(); 371 | let mini_batch_count = (width * height) / m; 372 | for _ in 0..mini_batch_count { 373 | // generate batch from a random set of pixels 374 | let mut y_data: Vec = Vec::new(); 375 | let mut w = env.writer(&x_param); 376 | for _ in 0..m { 377 | let x0 = rng.gen_range(0..width); 378 | let x1 = rng.gen_range(0..height); 379 | let x_data: [f32; 2] = [ 380 | ((x0 as f32) + 0.5) * (2.0 / (width as f32)) - 1.0, 381 | ((x1 as f32) + 0.5) * (2.0 / (height as f32)) - 1.0, 382 | ]; 383 | w.write_all(bytemuck::cast_slice(&x_data)).unwrap(); 384 | let pixel_index = x1 * width + x0; 385 | for y in &data.as_slice()[3 * pixel_index..3 * (pixel_index + 1)] { 386 | y_data.push((*y as f32) / 255.0); 387 | } 388 | } 389 | mem::drop(w); 390 | env.writer(&y_param) 391 | .write_all(bytemuck::cast_slice(&y_data)) 392 | .unwrap(); 393 | 394 | // run training 395 | env.run(&train_graph, rng.next_u32()); 396 | } 397 | if app_params.show_timings && epoch_index < 2 { 398 | env.print_timings("training") 399 | } 400 | 401 | let done_counter = epoch_index + 1; 402 | let train_loss = env.read_parameter_scalar(&loss_sum_param) / (m as f32); 403 | if !app_params.quiet { 404 | println!( 405 | "epoch: {}, lr_scale: {}, loss: {}", 406 | done_counter, learning_rate_scale, train_loss 407 | ); 408 | } 409 | if let Some(w) = stats_w.as_mut() { 410 | if epoch_index == 0 { 411 | writeln!(w, "# epoch, loss").unwrap(); 412 | } 413 | writeln!(w, "{}, {}", done_counter, train_loss).unwrap(); 414 | if done_counter == epoch_count { 415 | writeln!(w).unwrap(); 416 | } 417 | } 418 | if let Some(image_prefix) = app_params.image_prefix.as_ref() { 419 | if app_params.output_all_images || done_counter == epoch_count { 420 | env.run(&test_graph, rng.next_u32()); 421 | let pixels: Vec = env 422 | .read_parameter_to_vec(&image_param) 423 | .iter() 424 | .map(|&x| (x * 255.0 + 0.5).clamp(0.0, 255.0) as u8) 425 | .collect(); 426 | let name = format!("{}_{}.jpg", image_prefix, done_counter); 427 | stb::image_write::stbi_write_jpg( 428 | CString::new(name).unwrap().as_c_str(), 429 | info.width, 430 | info.height, 431 | 3, 432 | &pixels, 433 | 90, 434 | ) 435 | .unwrap(); 436 | } 437 | } 438 | } 439 | } 440 | 
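// [Editor's sketch, not part of examples/image_fit/main.rs.] HashGrid::eval above expresses the
// hash-grid lookup with array ops; the standalone CPU-side function below mirrors the same
// XOR-with-prime-stride hashing and bilinear weighting for a single 2D coordinate in [-1, 1],
// which may make the indexing easier to follow. The function name and the table layout
// (`&[Vec<f32>]`, one row per entry) are illustrative assumptions, not APIs from the crate.
fn hash_grid_lookup(table: &[Vec<f32>], grid_size: usize, stride: u32, x: [f32; 2]) -> Vec<f32> {
    let entry_count = table.len() as u32;
    // map x from [-1, 1] into grid-cell space, then split into integer cell and fractional part
    let cf = [
        (x[0] * 0.5 + 0.5) * grid_size as f32,
        (x[1] * 0.5 + 0.5) * grid_size as f32,
    ];
    let c = [cf[0] as u32, cf[1] as u32];
    let f = [cf[0] - c[0] as f32, cf[1] - c[1] as f32];
    // hash each of the four surrounding grid points into the entry table
    let index = |dx: u32, dy: u32| (((c[0] + dx) ^ ((c[1] + dy) * stride)) % entry_count) as usize;
    let (ia, ib, ic, id) = (index(0, 0), index(1, 0), index(0, 1), index(1, 1));
    // bilinear interpolation weights for the fractional position within the cell
    let (wa, wb, wc, wd) = (
        (1.0 - f[0]) * (1.0 - f[1]),
        f[0] * (1.0 - f[1]),
        (1.0 - f[0]) * f[1],
        f[0] * f[1],
    );
    // blend the four table entries component-wise
    table[ia]
        .iter()
        .zip(table[ib].iter())
        .zip(table[ic].iter())
        .zip(table[id].iter())
        .map(|(((&a, &b), &c), &d)| a * wa + b * wb + c * wc + d * wd)
        .collect()
}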
-------------------------------------------------------------------------------- /docs/array_api_values.svg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sjb3d/descent/HEAD/docs/array_api_values.svg (graph diagram: inputs m[3, 3], x[3, 1] and y[3, 1] feed MatMul, Mul and Add nodes producing output z[3, 1]) -------------------------------------------------------------------------------- /docs/array_api_grad.svg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sjb3d/descent/HEAD/docs/array_api_grad.svg (graph diagram: input x[1] feeds Sin, Cos, Mul, Add and Sub nodes producing output x[1]) -------------------------------------------------------------------------------- /examples/fashion_mnist/main.rs: -------------------------------------------------------------------------------- 1 | use descent::{loss::*, module::*, optimizer::*, prelude::*}; 2 | use flate2::bufread::GzDecoder; 3 | use rand::{prelude::SliceRandom, RngCore, SeedableRng}; 4 | use std::{ 5 | convert::TryInto, 6 | fs::File, 7 | io::{self, prelude::*, BufReader, BufWriter}, 8 | path::{Path, PathBuf}, 9 | }; 10 | use structopt::StructOpt; 11 | use strum::{EnumString, EnumVariantNames, VariantNames}; 12 | 13 | fn load_gz_bytes(path: impl AsRef<Path>) -> io::Result<Vec<u8>> { 14 | let reader =
BufReader::new(File::open(path).unwrap()); 15 | let mut decoder = GzDecoder::new(reader); 16 | let mut bytes = Vec::new(); 17 | decoder.read_to_end(&mut bytes)?; 18 | Ok(bytes) 19 | } 20 | 21 | fn read_be_u32(bytes: &[u8]) -> (u32, &[u8]) { 22 | let (prefix, suffix) = bytes.split_at(4); 23 | (u32::from_be_bytes(prefix.try_into().unwrap()), suffix) 24 | } 25 | 26 | fn read_images_info(bytes: &[u8]) -> ((usize, usize, usize), &[u8]) { 27 | let (magic, bytes) = read_be_u32(bytes); 28 | assert_eq!(magic, 2051); 29 | let (images, bytes) = read_be_u32(bytes); 30 | let (rows, bytes) = read_be_u32(bytes); 31 | let (cols, bytes) = read_be_u32(bytes); 32 | ((images as usize, rows as usize, cols as usize), bytes) 33 | } 34 | 35 | fn read_labels_info(bytes: &[u8]) -> (usize, &[u8]) { 36 | let (magic, bytes) = read_be_u32(bytes); 37 | assert_eq!(magic, 2049); 38 | let (items, bytes) = read_be_u32(bytes); 39 | ((items as usize), bytes) 40 | } 41 | 42 | fn unpack_images( 43 | env: &mut Environment, 44 | parameter: &Parameter, 45 | bytes: &[u8], 46 | indices: &[usize], 47 | ) -> io::Result<()> { 48 | let ((_, rows, cols), bytes) = read_images_info(bytes); 49 | let pixel_count = rows * cols; 50 | let mut w = env.writer(parameter); 51 | let mut image = Vec::::with_capacity(pixel_count); 52 | for index in indices.iter().copied() { 53 | let begin = index * pixel_count; 54 | let end = begin + pixel_count; 55 | image.clear(); 56 | image.extend(bytes[begin..end].iter().map(|&c| (c as f32) / 255.0)); 57 | w.write_all(bytemuck::cast_slice(&image))?; 58 | } 59 | Ok(()) 60 | } 61 | 62 | fn unpack_labels( 63 | env: &mut Environment, 64 | parameter: &Parameter, 65 | bytes: &[u8], 66 | indices: &[usize], 67 | ) -> io::Result<()> { 68 | let (_, bytes) = read_labels_info(bytes); 69 | let labels: Vec = indices.iter().map(|&index| bytes[index] as f32).collect(); 70 | let mut w = env.writer(parameter); 71 | w.write_all(bytemuck::cast_slice(&labels)) 72 | } 73 | 74 | #[derive(Debug, EnumString, EnumVariantNames)] 75 | #[strum(serialize_all = "kebab_case")] 76 | enum NetworkType { 77 | Linear, 78 | SingleLayer, 79 | ConvNet, 80 | ConvBlurNet, 81 | } 82 | 83 | #[derive(Debug, EnumString, EnumVariantNames)] 84 | #[strum(serialize_all = "kebab_case")] 85 | enum OptimizerType { 86 | Descent, 87 | Adam, 88 | } 89 | 90 | #[derive(Debug, StructOpt)] 91 | #[structopt( 92 | no_version, 93 | name = "fashion_mnist", 94 | about = "Example networks to train using the Fashion MNIST dataset." 
95 | )] 96 | struct AppParams { 97 | #[structopt(possible_values=&NetworkType::VARIANTS, default_value="single-layer")] 98 | network: NetworkType, 99 | 100 | #[structopt(short, long, possible_values=&OptimizerType::VARIANTS, default_value="adam")] 101 | optimizer: OptimizerType, 102 | 103 | #[structopt(short, long, default_value = "1.0E-8")] 104 | weight_decay: f32, 105 | 106 | #[structopt(short, long, default_value = "1000")] 107 | mini_batch_size: usize, 108 | 109 | #[structopt(short, long, default_value = "40")] 110 | epoch_count: usize, 111 | 112 | #[structopt(short, long, default_value = "1")] 113 | trial_count: usize, 114 | 115 | #[structopt(long)] 116 | output_dot_files: bool, 117 | 118 | #[structopt(long)] 119 | show_timings: bool, 120 | 121 | #[structopt(long)] 122 | quiet: bool, 123 | 124 | #[structopt(long)] 125 | csv_file_name: Option, 126 | } 127 | 128 | struct Linear { 129 | fc: Dense, 130 | } 131 | 132 | impl Linear { 133 | fn new(env: &mut Environment) -> Self { 134 | Self { 135 | fc: Dense::builder(28 * 28, 10).build(env), 136 | } 137 | } 138 | } 139 | 140 | impl Module for Linear { 141 | fn eval<'s>(&self, input: DualArray<'s>, ctx: &EvalContext) -> DualArray<'s> { 142 | input.flatten().apply(&self.fc, ctx) 143 | } 144 | } 145 | 146 | struct SingleLayer { 147 | fc1: Dense, 148 | fc2: Dense, 149 | } 150 | 151 | impl SingleLayer { 152 | fn new(env: &mut Environment) -> Self { 153 | let hidden_units = 300; 154 | Self { 155 | fc1: Dense::builder(28 * 28, hidden_units).build(env), 156 | fc2: Dense::builder(hidden_units, 10).build(env), 157 | } 158 | } 159 | } 160 | 161 | impl Module for SingleLayer { 162 | fn eval<'s>(&self, input: DualArray<'s>, ctx: &EvalContext) -> DualArray<'s> { 163 | input 164 | .flatten() 165 | .apply(&self.fc1, ctx) 166 | .leaky_relu(0.01) 167 | .apply(&self.fc2, ctx) 168 | } 169 | } 170 | 171 | struct ConvNet { 172 | conv1: Conv2D, 173 | pool1: Box, 174 | conv2: Conv2D, 175 | pool2: Box, 176 | fc1: Dense, 177 | fc2: Dense, 178 | } 179 | 180 | impl ConvNet { 181 | fn new(env: &mut Environment, use_blur_pool: bool) -> Self { 182 | let c1 = 16; 183 | let c2 = 32; 184 | let hidden = 128; 185 | Self { 186 | conv1: Conv2D::builder(1, c1, 3, 3).with_pad(1).build(env), 187 | pool1: if use_blur_pool { 188 | Box::new(MaxBlurPool2D::new(env, c1)) 189 | } else { 190 | Box::new(MaxPool2D::default()) 191 | }, 192 | conv2: Conv2D::builder(c1, c2, 3, 3) 193 | .with_pad(1) 194 | .with_groups(2) 195 | .build(env), 196 | pool2: if use_blur_pool { 197 | Box::new(MaxBlurPool2D::new(env, c2)) 198 | } else { 199 | Box::new(MaxPool2D::default()) 200 | }, 201 | fc1: Dense::builder(7 * 7 * c2, hidden).build(env), 202 | fc2: Dense::builder(hidden, 10).build(env), 203 | } 204 | } 205 | } 206 | 207 | impl Module for ConvNet { 208 | fn eval<'s>(&self, input: DualArray<'s>, ctx: &EvalContext) -> DualArray<'s> { 209 | input 210 | .apply(&self.conv1, ctx) 211 | .leaky_relu(0.01) 212 | .apply(self.pool1.as_ref(), ctx) 213 | .apply(&self.conv2, ctx) 214 | .leaky_relu(0.01) 215 | .apply(self.pool2.as_ref(), ctx) 216 | .flatten() 217 | .apply(&Dropout::new(0.5), ctx) 218 | .apply(&self.fc1, ctx) 219 | .leaky_relu(0.01) 220 | .apply(&self.fc2, ctx) 221 | } 222 | } 223 | 224 | fn main() { 225 | let app_params = AppParams::from_args(); 226 | 227 | let mut env = Environment::new(); 228 | let module: Box = { 229 | let env = &mut env; 230 | match app_params.network { 231 | NetworkType::Linear => Box::new(Linear::new(env)), 232 | NetworkType::SingleLayer => Box::new(SingleLayer::new(env)), 233 
| NetworkType::ConvNet => Box::new(ConvNet::new(env, false)), 234 | NetworkType::ConvBlurNet => Box::new(ConvNet::new(env, true)), 235 | } 236 | }; 237 | 238 | let m = app_params.mini_batch_size; 239 | let x_param = env.static_parameter([m, 28, 28, 1], "x"); 240 | let y_param = env.static_parameter([m, 1], "y"); 241 | 242 | let learning_rate_scale_param = env.static_parameter([1], "lr_scale"); 243 | let loss_sum_param = env.static_parameter([1], "loss"); 244 | let accuracy_sum_param = env.static_parameter([1], "accuracy"); 245 | 246 | // build a graph for training, collect the trainable parameters 247 | let (train_graph, parameters, optimizer) = { 248 | let scope = env.scope(); 249 | 250 | // emit the ops for the network 251 | let x = module.train(scope.parameter(&x_param)); 252 | let loss = softmax_cross_entropy_loss(x, &y_param).set_loss(); 253 | let accuracy = softmax_cross_entropy_accuracy(x, &y_param); 254 | 255 | // update sum of loss and accuracy 256 | scope.update_parameter_value(&loss_sum_param, |loss_sum| { 257 | loss_sum + loss.reduce_sum(0, false) 258 | }); 259 | scope.update_parameter_value(&accuracy_sum_param, |accuracy_sum| { 260 | accuracy_sum + accuracy.reduce_sum(0, false) 261 | }); 262 | 263 | // train using gradient of the loss (scaled for size of mini batch) 264 | let learning_rate_scale = scope.parameter_value(&learning_rate_scale_param); 265 | let parameters = scope.trainable_parameters(); 266 | add_weight_decay_to_grad(&scope, ¶meters, app_params.weight_decay); 267 | let optimizer: Box = match app_params.optimizer { 268 | OptimizerType::Descent => Box::new(StochasticGradientDescent::new( 269 | &mut env, 270 | &scope, 271 | ¶meters, 272 | 0.1 * learning_rate_scale, 273 | 0.9, 274 | )), 275 | OptimizerType::Adam => Box::new(Adam::new( 276 | &mut env, 277 | &scope, 278 | ¶meters, 279 | 0.005 * learning_rate_scale, 280 | 0.9, 281 | 0.999, 282 | 1.0E-8, 283 | )), 284 | }; 285 | 286 | (scope.build_graph(), parameters, optimizer) 287 | }; 288 | println!( 289 | "trainable parameters: {}", 290 | parameters 291 | .iter() 292 | .map(|param| param.shape().element_count()) 293 | .sum::() 294 | ); 295 | 296 | // build a graph to evaluate the test set (keeps parameters unchanged) 297 | let test_graph = env.build_graph(|scope| { 298 | // emit the ops for the network 299 | let x = module.test(scope.parameter(&x_param)); 300 | let loss = softmax_cross_entropy_loss(x, &y_param).set_loss(); 301 | let accuracy = softmax_cross_entropy_accuracy(x, &y_param); 302 | 303 | // update sum of loss and accuracy 304 | scope.update_parameter_value(&loss_sum_param, |loss_sum| { 305 | loss_sum + loss.reduce_sum(0, false) 306 | }); 307 | scope.update_parameter_value(&accuracy_sum_param, |accuracy_sum| { 308 | accuracy_sum + accuracy.reduce_sum(0, false) 309 | }); 310 | }); 311 | 312 | // build a graph to evaluate the L2 norm of training parameters (to check weight decay) 313 | let norm_param = env.static_parameter([1], "norm"); 314 | let norm_graph = env.build_graph(|scope| { 315 | let mut sum = scope.literal(0.0).value(); 316 | for param in parameters.iter() { 317 | let x = scope.parameter_value(¶m); 318 | let x = x.reshape([x.shape().element_count()]); 319 | let x = x * x * 0.5; 320 | sum = sum + x.reduce_sum(0, true); 321 | } 322 | scope.write_parameter_value(&norm_param, sum); 323 | }); 324 | 325 | // write graphs out to disk if necessary 326 | if app_params.output_dot_files { 327 | train_graph.write_dot_file(KernelDotOutput::Cluster, "train.dot"); 328 | 
train_graph.write_dot_file(KernelDotOutput::None, "train_s.dot"); 329 | train_graph.write_dot_file(KernelDotOutput::Color, "train_k.dot"); 330 | test_graph.write_dot_file(KernelDotOutput::Cluster, "test.dot"); 331 | } 332 | 333 | // load training data 334 | let train_images = load_gz_bytes("data/fashion_mnist/train-images-idx3-ubyte.gz").unwrap(); 335 | let train_labels = load_gz_bytes("data/fashion_mnist/train-labels-idx1-ubyte.gz").unwrap(); 336 | let ((train_image_count, train_image_rows, train_image_cols), _) = 337 | read_images_info(&train_images); 338 | let (train_label_count, _) = read_labels_info(&train_labels); 339 | assert_eq!(train_image_count, train_label_count); 340 | assert_eq!(train_image_count % m, 0); 341 | assert_eq!(train_image_rows, 28); 342 | assert_eq!(train_image_cols, 28); 343 | 344 | // load test data 345 | let test_images = load_gz_bytes("data/fashion_mnist/t10k-images-idx3-ubyte.gz").unwrap(); 346 | let test_labels = load_gz_bytes("data/fashion_mnist/t10k-labels-idx1-ubyte.gz").unwrap(); 347 | let ((test_image_count, test_image_rows, test_image_cols), _) = read_images_info(&test_images); 348 | let (test_label_count, _) = read_labels_info(&test_labels); 349 | assert_eq!(test_image_count, test_label_count); 350 | assert_eq!(test_image_count % m, 0); 351 | assert_eq!(test_image_rows, 28); 352 | assert_eq!(test_image_cols, 28); 353 | 354 | // maybe writing stats to file 355 | let mut stats_w = app_params 356 | .csv_file_name 357 | .map(|path| BufWriter::new(File::create(path).unwrap())); 358 | 359 | // attempt to train 5 times with different random seeds 360 | for trial_index in 0..app_params.trial_count { 361 | // reset all trainable parameters and optimizer state 362 | let mut rng = rand_chacha::ChaCha20Rng::seed_from_u64(trial_index as u64); 363 | for param in parameters.iter() { 364 | env.reset_parameter(param, &mut rng); 365 | } 366 | optimizer.reset_state(&mut env); 367 | 368 | // run epochs 369 | let mut indices = Vec::new(); 370 | for epoch_index in 0..app_params.epoch_count { 371 | // update learning for this epoch (halve every 40 epochs) 372 | let learning_rate_scale = 0.5f32.powf((epoch_index as f32) / 40.0); 373 | env.writer(&learning_rate_scale_param) 374 | .write_all(bytemuck::bytes_of(&learning_rate_scale)) 375 | .unwrap(); 376 | 377 | // loop over training mini-batches 378 | env.writer(&loss_sum_param).zero_fill(); 379 | env.writer(&accuracy_sum_param).zero_fill(); 380 | indices.clear(); 381 | indices.extend(0..train_image_count); 382 | indices.shuffle(&mut rng); 383 | for batch_indices in indices.chunks(m) { 384 | unpack_images(&mut env, &x_param, &train_images, batch_indices).unwrap(); 385 | unpack_labels(&mut env, &y_param, &train_labels, batch_indices).unwrap(); 386 | env.run(&train_graph, rng.next_u32()); 387 | } 388 | if app_params.show_timings && epoch_index < 2 { 389 | env.print_timings("training"); 390 | } 391 | let train_loss = 392 | env.read_parameter_scalar(&loss_sum_param) / (train_image_count as f32); 393 | let train_accuracy = 394 | env.read_parameter_scalar(&accuracy_sum_param) / (train_image_count as f32); 395 | 396 | // loop over test mini-batches to evaluate loss and accuracy 397 | env.writer(&loss_sum_param).zero_fill(); 398 | env.writer(&accuracy_sum_param).zero_fill(); 399 | indices.clear(); 400 | indices.extend(0..test_image_count); 401 | for batch_indices in indices.chunks(m) { 402 | unpack_images(&mut env, &x_param, &test_images, batch_indices).unwrap(); 403 | unpack_labels(&mut env, &y_param, &test_labels, 
batch_indices).unwrap(); 404 | env.run(&test_graph, rng.next_u32()); 405 | } 406 | if app_params.show_timings && epoch_index < 2 { 407 | env.print_timings("testing"); 408 | } 409 | let test_loss = env.read_parameter_scalar(&loss_sum_param) / (test_image_count as f32); 410 | let test_accuracy = 411 | env.read_parameter_scalar(&accuracy_sum_param) / (test_image_count as f32); 412 | 413 | // compute the norm of all the parameters 414 | env.run(&norm_graph, rng.next_u32()); 415 | let norm = env.read_parameter_scalar(&norm_param); 416 | 417 | let done_counter = epoch_index + 1; 418 | if !app_params.quiet { 419 | println!( 420 | "epoch: {}, loss: {}/{}, accuracy: {}/{}, w_norm: {}", 421 | done_counter, 422 | train_loss, 423 | test_loss, 424 | train_accuracy, 425 | test_accuracy, 426 | norm.sqrt() 427 | ); 428 | } 429 | if let Some(w) = stats_w.as_mut() { 430 | if epoch_index == 0 { 431 | writeln!( 432 | w, 433 | "# epoch, train_loss, test_loss, train_accuracy, test_accuracy" 434 | ) 435 | .unwrap(); 436 | } 437 | writeln!( 438 | w, 439 | "{}, {}, {}, {}, {}", 440 | done_counter, train_loss, test_loss, train_accuracy, test_accuracy 441 | ) 442 | .unwrap(); 443 | if done_counter == app_params.epoch_count { 444 | writeln!(w).unwrap(); 445 | } 446 | } 447 | } 448 | } 449 | } 450 | --------------------------------------------------------------------------------
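Usage note (editor's addition, derived from the structopt/strum attributes in the two example sources above; file paths here are illustrative): both programs are ordinary cargo example binaries, so a typical run might look like "cargo run --release --example fashion_mnist -- conv-net --epoch-count 10 --csv-file-name stats.csv" (after fetching the Fashion MNIST archives with data/fashion_mnist/download.sh) or "cargo run --release --example image_fit -- multi-hash --image-prefix temp/image_fit_output_multi-hash" (which reads data/images/cat.jpg). The positional network names (relu, relu-pe, siren, multi-hash; linear, single-layer, conv-net, conv-blur-net) and the kebab-case long flags follow from the field names and serialize_all settings declared in each main.rs.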